// umbral_socket/stream/server.rs
use std::collections::HashMap;
use std::io::Error;
use std::io::ErrorKind;
use std::io::Result;
use std::path::Path;
use std::sync::Arc;

use bytes::Bytes;
use futures::future::BoxFuture;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::UnixListener;
use tokio::net::UnixStream;

13type Handler<S> = Arc<dyn Fn(Arc<S>, Bytes) -> BoxFuture<'static, Result<Bytes>> + Send + Sync>;
14
15pub struct UmbralServer<S> {
16 state: Arc<S>,
17 handlers: HashMap<String, Handler<S>>,
18}
19
20impl<S: Send + Sync + 'static> UmbralServer<S> {
21 pub fn new(state: S) -> Self {
22 Self {
23 state: Arc::new(state),
24 handlers: HashMap::new(),
25 }
26 }
27
28 pub fn route<F, Fut>(mut self, method: &str, handler: F) -> Self
29 where
30 F: Fn(Arc<S>, Bytes) -> Fut + Send + Sync + 'static,
31 Fut: futures::Future<Output = Result<Bytes>> + Send + 'static,
32 {
33 let handler_arc: Handler<S> =
34 Arc::new(move |state, payload| Box::pin(handler(state, payload)));
35 self.handlers.insert(method.to_string(), handler_arc);
36 self
37 }
38
39 pub async fn run(self, socket: &str) -> Result<()> {
40 let path = Path::new(socket);
41 if path.exists() {
42 tokio::fs::remove_file(path).await?;
43 }
44 let listener = UnixListener::bind(path)?;
45 let server_arc = Arc::new(self);
46 println!("Umbral Server listening on \"{}\"", socket);
47 loop {
48 let (stream, _) = listener.accept().await?;
49 let server_clone = server_arc.clone();
50 tokio::spawn(async move {
51 if let Err(e) = server_clone.handle_connection(stream).await {
52 eprintln!("Error processing connection: {}", e);
53 }
54 });
55 }
56 }
57
58 async fn handle_connection(&self, mut stream: UnixStream) -> Result<()> {
59 loop {
60 let mut buffer = [0; 1024];
61 let n = match stream.read(&mut buffer).await {
62 Ok(0) => return Ok(()),
63 Ok(n) => n,
64 Err(e) => return Err(e),
65 };
66 let message = String::from_utf8_lossy(&buffer[..n]);
67
68 let response = if let Some((method, payload)) = message.trim().split_once("[%]") {
69 if let Some(handler) = self.handlers.get(method) {
70 let state_clone = self.state.clone();
71 let payload_bytes = Bytes::from(payload.as_bytes().to_vec());
72 handler(state_clone, payload_bytes).await
73 } else {
74 Ok(Bytes::from_static(b"METHOD NOT FOUND"))
75 }
76 } else {
77 Ok(Bytes::from_static(b"INVALID PROTOCOL"))
78 };
79
80 match response {
81 Ok(response_bytes) => {
82 let len = response_bytes.len() as u32;
83 stream.write_all(&len.to_be_bytes()).await?;
84 stream.write_all(&response_bytes).await?;
85 }
86 Err(e) => {
87 let err_msg = Bytes::from(format!("Handler error: {}", e));
88 let len = err_msg.len() as u32;
89 stream.write_all(&len.to_be_bytes()).await?;
90 stream.write_all(&err_msg).await?;
91 }
92 }
93 }
94 }
95}