// umbral_socket/stream/server.rs

1use std::collections::HashMap;
2use std::io::Result;
3use std::os::unix::fs::PermissionsExt;
4use std::path::Path;
5use std::sync::Arc;
6
7use bytes::Bytes;
8use futures::future::BoxFuture;
9use tokio::io::AsyncReadExt;
10use tokio::io::AsyncWriteExt;
11use tokio::net::UnixListener;
12use tokio::net::UnixStream;
13
14type Handler<S> = Arc<dyn Fn(Arc<S>, Bytes) -> BoxFuture<'static, Result<Bytes>> + Send + Sync>;
15
16pub struct UmbralServer<S> {
17    state: Arc<S>,
18    handlers: HashMap<String, Handler<S>>,
19}
20
21impl<S: Send + Sync + 'static> UmbralServer<S> {
22    pub fn new(state: S) -> Self {
23        Self {
24            state: Arc::new(state),
25            handlers: HashMap::new(),
26        }
27    }
28
29    pub fn route<F, Fut>(mut self, method: &str, handler: F) -> Self
30    where
31        F: Fn(Arc<S>, Bytes) -> Fut + Send + Sync + 'static,
32        Fut: futures::Future<Output = Result<Bytes>> + Send + 'static,
33    {
34        let handler_arc: Handler<S> =
35            Arc::new(move |state, payload| Box::pin(handler(state, payload)));
36        self.handlers.insert(method.to_string(), handler_arc);
37        self
38    }
39
40    pub async fn run(self, socket: &str) -> Result<()> {
41        let path = Path::new(socket);
42        if path.exists() {
43            tokio::fs::remove_file(path).await?;
44        }
45        let listener = UnixListener::bind(path)?;
46        let permissions = std::fs::Permissions::from_mode(0o766);
47        std::fs::set_permissions(path, permissions)?;
48        let server_arc = Arc::new(self);
49        println!("Umbral Server listening on \"{}\"", socket);
50        loop {
51            let (stream, _) = listener.accept().await?;
52            let server_clone = server_arc.clone();
53            tokio::spawn(async move {
54                if let Err(e) = server_clone.handle_connection(stream).await {
55                    eprintln!("Error processing connection: {}", e);
56                }
57            });
58        }
59    }
60
61    async fn handle_connection(&self, mut stream: UnixStream) -> Result<()> {
62        let mut buffer = [0u8; 1024];
63        loop {
64            let n = match stream.read(&mut buffer).await {
65                Ok(0) => return Ok(()),
66                Ok(n) => n,
67                Err(e) => return Err(e),
68            };
69            let message = String::from_utf8_lossy(&buffer[..n]);
70
71            let response = if let Some((method, payload)) = message.trim().split_once("[%]") {
72                if let Some(handler) = self.handlers.get(method) {
73                    let state_clone = self.state.clone();
74                    let payload_bytes = Bytes::from(payload.as_bytes().to_vec());
75                    handler(state_clone, payload_bytes).await
76                } else {
77                    Ok(Bytes::from_static(b"METHOD NOT FOUND"))
78                }
79            } else {
80                Ok(Bytes::from_static(b"INVALID PROTOCOL"))
81            };
82
83            match response {
84                Ok(response_bytes) => {
85                    let len = response_bytes.len() as u32;
86                    stream.write_all(&len.to_be_bytes()).await?;
87                    stream.write_all(&response_bytes).await?;
88                }
89                Err(e) => {
90                    let err_msg = Bytes::from(format!("Handler error: {}", e));
91                    let len = err_msg.len() as u32;
92                    stream.write_all(&len.to_be_bytes()).await?;
93                    stream.write_all(&err_msg).await?;
94                }
95            }
96        }
97    }
98}