umbral_socket/stream/client.rs

1use bytes::Bytes;
2use deadpool::managed;
3use std::io::{self, Result};
4use tokio::io::{AsyncReadExt, AsyncWriteExt};
5use tokio::net::UnixStream;
6
/// deadpool connection manager for a single Unix domain socket.
///
/// It knows only where to connect; opening and health-checking connections is
/// done in its `managed::Manager` implementation.
struct UnixStreamManager {
    /// Filesystem path of the socket passed to `UnixStream::connect`.
    socket: String,
}
10
11impl managed::Manager for UnixStreamManager {
12    type Type = UnixStream;
13    type Error = io::Error;
14
15    async fn create(&self) -> Result<Self::Type> {
16        UnixStream::connect(&self.socket).await
17    }
18
19    async fn recycle(
20        &self,
21        conn: &mut Self::Type,
22        _metrics: &managed::Metrics,
23    ) -> managed::RecycleResult<Self::Error> {
24        match conn.try_write(&[]) {
25            Ok(_) => Ok(()),
26            Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(()),
27            Err(e) => Err(e.into()),
28        }
29    }
30}
31
32#[derive(Clone)]
33pub struct UmbralClient {
34    pool: managed::Pool<UnixStreamManager>,
35}
36
37impl UmbralClient {
38    pub fn new(socket: &str, max_size: usize) -> UmbralClient {
39        let manager = UnixStreamManager {
40            socket: socket.to_string(),
41        };
42        let pool = managed::Pool::builder(manager)
43            .max_size(max_size)
44            .build()
45            .unwrap();
46        UmbralClient { pool }
47    }
48
49    pub async fn send(&self, method: &str, payload: Bytes) -> Result<Bytes> {
50        let mut conn = self
51            .pool
52            .get()
53            .await
54            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
55
56        let mut message = Vec::new();
57        message.extend_from_slice(method.as_bytes());
58        message.extend_from_slice(b"[%]");
59        message.extend_from_slice(&payload);
60
61        conn.write_all(&message).await?;
62
63        let mut len_bytes = [0u8; 4];
64        conn.read_exact(&mut len_bytes).await?;
65        let len = u32::from_be_bytes(len_bytes);
66
67        let mut response_buffer = vec![0u8; len as usize];
68        conn.read_exact(&mut response_buffer).await?;
69
70        Ok(Bytes::from(response_buffer))
71    }
72}