// umbral_socket/stream/client.rs

1use std::io::{self, Result};
2use std::sync::Arc;
3
4use bytes::Bytes;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::UnixStream;
7use tokio::sync::Mutex;
8
9struct ClientState {
10    stream: Mutex<Option<UnixStream>>,
11    socket: String,
12}
13
14#[derive(Clone)]
15pub struct UmbralClient {
16    state: Arc<ClientState>,
17}
18
19impl UmbralClient {
20    pub fn new(socket: &str) -> UmbralClient {
21        let state = Arc::new(ClientState {
22            stream: Mutex::new(None),
23            socket: String::from(socket),
24        });
25        return UmbralClient { state };
26    }
27
28    pub async fn send(&self, method: &str, payload: Bytes) -> Result<Bytes> {
29        let mut message = Vec::new();
30        message.extend_from_slice(method.as_bytes());
31        message.extend_from_slice(b"[%]");
32        message.extend_from_slice(&payload);
33
34        let mut stream_guard = self.state.stream.lock().await;
35        if stream_guard.is_none() {
36            match UnixStream::connect(&self.state.socket).await {
37                Ok(stream) => *stream_guard = Some(stream),
38                Err(e) => return Err(e),
39            }
40        }
41
42        if let Some(stream) = stream_guard.as_mut() {
43            if let Err(e) = stream.write_all(&message).await {
44                *stream_guard = None;
45                return Err(e);
46            }
47
48            let mut len_bytes = [0u8; 4];
49            if let Err(e) = stream.read_exact(&mut len_bytes).await {
50                *stream_guard = None;
51                return Err(e);
52            }
53            let len = u32::from_be_bytes(len_bytes);
54
55            let mut response_buffer = vec![0u8; len as usize];
56            if let Err(e) = stream.read_exact(&mut response_buffer).await {
57                *stream_guard = None;
58                return Err(e);
59            }
60
61            return Ok(Bytes::from(response_buffer));
62        }
63
64        Err(io::Error::new(io::ErrorKind::Other, "Failed to get stream"))
65    }
66}