// umbral_socket/stream/server.rs
use std::collections::HashMap;
use std::io::Result;
use std::os::unix::fs::FileTypeExt;
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::sync::Arc;

use bytes::Bytes;
use futures::future::BoxFuture;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::UnixListener;
use tokio::net::UnixStream;

14type Handler<S> = Arc<dyn Fn(Arc<S>, Bytes) -> BoxFuture<'static, Result<Bytes>> + Send + Sync>;
15
16pub struct UmbralServer<S> {
17 state: Arc<S>,
18 handlers: HashMap<String, Handler<S>>,
19}
20
21impl<S: Send + Sync + 'static> UmbralServer<S> {
22 pub fn new(state: S) -> Self {
23 Self {
24 state: Arc::new(state),
25 handlers: HashMap::new(),
26 }
27 }
28
29 pub fn route<F, Fut>(mut self, method: &str, handler: F) -> Self
30 where
31 F: Fn(Arc<S>, Bytes) -> Fut + Send + Sync + 'static,
32 Fut: futures::Future<Output = Result<Bytes>> + Send + 'static,
33 {
34 let handler_arc: Handler<S> =
35 Arc::new(move |state, payload| Box::pin(handler(state, payload)));
36 self.handlers.insert(method.to_string(), handler_arc);
37 self
38 }
39
40 pub async fn run(self, socket: &str) -> Result<()> {
41 let path = Path::new(socket);
42 if path.exists() {
43 tokio::fs::remove_file(path).await?;
44 }
45 let listener = UnixListener::bind(path)?;
46 let permissions = std::fs::Permissions::from_mode(0o766);
47 std::fs::set_permissions(path, permissions)?;
48 let server_arc = Arc::new(self);
49 println!("Umbral Server listening on \"{}\"", socket);
50 loop {
51 let (stream, _) = listener.accept().await?;
52 let server_clone = server_arc.clone();
53 tokio::spawn(async move {
54 if let Err(e) = server_clone.handle_connection(stream).await {
55 eprintln!("Error processing connection: {}", e);
56 }
57 });
58 }
59 }
60
61 async fn handle_connection(&self, mut stream: UnixStream) -> Result<()> {
62 loop {
63 let mut buffer = [0; 1024];
64 let n = match stream.read(&mut buffer).await {
65 Ok(0) => return Ok(()),
66 Ok(n) => n,
67 Err(e) => return Err(e),
68 };
69 let message = String::from_utf8_lossy(&buffer[..n]);
70
71 let response = if let Some((method, payload)) = message.trim().split_once("[%]") {
72 if let Some(handler) = self.handlers.get(method) {
73 let state_clone = self.state.clone();
74 let payload_bytes = Bytes::from(payload.as_bytes().to_vec());
75 handler(state_clone, payload_bytes).await
76 } else {
77 Ok(Bytes::from_static(b"METHOD NOT FOUND"))
78 }
79 } else {
80 Ok(Bytes::from_static(b"INVALID PROTOCOL"))
81 };
82
83 match response {
84 Ok(response_bytes) => {
85 let len = response_bytes.len() as u32;
86 stream.write_all(&len.to_be_bytes()).await?;
87 stream.write_all(&response_bytes).await?;
88 }
89 Err(e) => {
90 let err_msg = Bytes::from(format!("Handler error: {}", e));
91 let len = err_msg.len() as u32;
92 stream.write_all(&len.to_be_bytes()).await?;
93 stream.write_all(&err_msg).await?;
94 }
95 }
96 }
97 }
98}