// lisp_rpc_rust_server/server.rs

1use anyhow::{Result, anyhow};
2use lisp_rpc_rust_serializer::lisp_rpc_from_str;
3use serde::Serialize;
4use serde::de::DeserializeOwned;
5use std::collections::HashMap;
6use std::fmt::Debug;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
11use super::*;
12
13/// A trait that captures the relationship between a request type T and its response.
14pub trait RpcFunc<T>: Send + Sync + 'static {
15    type Resp: Serialize + ToRPCType + 'static;
16    fn call(&self, req: T) -> Result<Self::Resp>;
17}
18
19impl<T, R, F> RpcFunc<T> for F
20where
21    T: DeserializeOwned + Debug + Send + Sync + ToRPCType + 'static,
22    R: Serialize + ToRPCType + 'static,
23    F: Fn(T) -> Result<R> + Send + Sync + 'static,
24{
25    type Resp = R;
26    fn call(&self, req: T) -> Result<Self::Resp> {
27        (self)(req)
28    }
29}
30
31/// A trait that captures the relationship between a request type T and its async response.
32pub trait AsyncRpcFunc<T>: Send + Sync + 'static {
33    type Resp: Serialize + ToRPCType + 'static;
34    type Fut: Future<Output = Result<Self::Resp>> + Send + 'static;
35    fn call(&self, req: T) -> Self::Fut;
36}
37
38impl<T, R, F, Fut> AsyncRpcFunc<T> for F
39where
40    T: DeserializeOwned + Debug + Send + Sync + ToRPCType + 'static,
41    R: Serialize + ToRPCType + 'static,
42    Fut: Future<Output = Result<R>> + Send + 'static,
43    F: Fn(T) -> Fut + Send + Sync + 'static,
44{
45    type Resp = R;
46    type Fut = Fut;
47    fn call(&self, req: T) -> Self::Fut {
48        (self)(req)
49    }
50}
51
52/// The type-erased handler trait
53pub trait RpcHandler: Send + Sync {
54    fn handle(&self, raw_data: &str) -> Result<Box<dyn ToRPCType>>;
55}
56
/// Pairs a synchronous handler function with the request type `T` it accepts.
struct Handler<T, F> {
    /// The user-supplied handler.
    func: F,
    /// Records `T` in the type without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
62
63impl<T, F> RpcHandler for Handler<T, F>
64where
65    T: DeserializeOwned + Debug + Send + Sync + ToRPCType + 'static,
66    F: RpcFunc<T>,
67{
68    fn handle(&self, raw_data: &str) -> Result<Box<dyn ToRPCType>> {
69        let req: T =
70            lisp_rpc_from_str(raw_data).map_err(|e| anyhow!("Deserialization failed: {}", e))?;
71        let resp = self.func.call(req)?;
72        Ok(Box::new(resp))
73    }
74}
75
76/// The type-erased async handler trait
77pub trait AsyncRpcHandler: Send + Sync {
78    fn handle(
79        &self,
80        raw_data: &str,
81    ) -> Pin<Box<dyn Future<Output = Result<Box<dyn ToRPCType>>> + Send>>;
82}
83
/// Pairs an asynchronous handler function with the request type `T` it accepts.
struct AsyncHandler<T, F> {
    /// The user-supplied async handler.
    func: F,
    /// Records `T` in the type without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
89
90impl<T, F> AsyncRpcHandler for AsyncHandler<T, F>
91where
92    T: DeserializeOwned + Debug + Send + Sync + ToRPCType + 'static,
93    F: AsyncRpcFunc<T>,
94{
95    fn handle(
96        &self,
97        raw_data: &str,
98    ) -> Pin<Box<dyn Future<Output = Result<Box<dyn ToRPCType>>> + Send>> {
99        let req_res =
100            lisp_rpc_from_str(raw_data).map_err(|e| anyhow!("Deserialization failed: {}", e));
101        match req_res {
102            Ok(req) => {
103                let fut = self.func.call(req);
104                Box::pin(async move {
105                    let resp = fut.await?;
106                    Ok(Box::new(resp) as Box<dyn ToRPCType>)
107                })
108            }
109            Err(e) => Box::pin(async move { Err(e) }),
110        }
111    }
112}
113
114/// RPCServer manages a registry of handlers and dispatches incoming raw Lisp RPC strings
115#[derive(Clone)]
116pub struct RPCServer {
117    pub handlers: Arc<HashMap<String, Box<dyn RpcHandler>>>,
118    pub async_handlers: Arc<HashMap<String, Box<dyn AsyncRpcHandler>>>,
119}
120
121impl RPCServer {
122    pub fn new() -> Self {
123        Self {
124            handlers: Arc::new(HashMap::new()),
125            async_handlers: Arc::new(HashMap::new()),
126        }
127    }
128
129    /// Register a handler for a specific command
130    pub fn register<T, F>(mut self, func: F) -> Result<Self>
131    where
132        T: DeserializeOwned + Debug + Send + Sync + ToRPCType + 'static,
133        F: RpcFunc<T>,
134    {
135        // has to be RPCType::RPC
136        let command = match <T as ToRPCType>::to_rpc_type() {
137            RPCType::RPC(s) => s,
138            _ => anyhow::bail!("Handler function argument has to be RPCType::RPC"),
139        };
140
141        let handler = Handler {
142            func,
143            _phantom: std::marker::PhantomData,
144        };
145
146        Arc::get_mut(&mut self.handlers)
147            .unwrap()
148            .insert(command, Box::new(handler));
149
150        Ok(self)
151    }
152
153    /// Register an async handler for a specific command
154    pub fn register_async<T, F>(mut self, func: F) -> Result<Self>
155    where
156        T: DeserializeOwned + Debug + Send + Sync + ToRPCType + 'static,
157        F: AsyncRpcFunc<T>,
158    {
159        // has to be RPCType::RPC
160        let command = match <T as ToRPCType>::to_rpc_type() {
161            RPCType::RPC(s) => s,
162            _ => anyhow::bail!("Handler function argument has to be RPCType::RPC"),
163        };
164
165        let handler = AsyncHandler {
166            func,
167            _phantom: std::marker::PhantomData,
168        };
169
170        Arc::get_mut(&mut self.async_handlers)
171            .unwrap()
172            .insert(command, Box::new(handler));
173
174        Ok(self)
175    }
176
177    /// Dispatch a raw Lisp RPC string to the appropriate handler
178    pub fn handle(&self, raw_data: &str) -> Result<String> {
179        // 1. Extract the command name from the Lisp string (e.g., "(command-name ...)")
180        let command =
181            extract_command_name(raw_data).ok_or_else(|| anyhow!("Invalid RPC format"))?;
182
183        // 2. Find the registered handler
184        let handler = self
185            .handlers
186            .get(&command)
187            .ok_or_else(|| anyhow!("Unknown command: {}", command))?;
188
189        // 3. Execute the handler to get the trait object
190        let resp_obj = handler.handle(raw_data)?;
191
192        // 4. Serialize the response using the trait object's method
193        resp_obj.serialize_lisp()
194    }
195
196    /// Dispatch a raw Lisp RPC string to the appropriate handler asynchronously
197    pub async fn handle_async(&self, raw_data: &str) -> Result<String> {
198        // 1. Extract the command name from the Lisp string (e.g., "(command-name ...)")
199        let command =
200            extract_command_name(raw_data).ok_or_else(|| anyhow!("Invalid RPC format"))?;
201
202        // 2. Caution: Find the registered handler (check sync first)
203        if let Some(handler) = self.handlers.get(&command) {
204            let resp_obj = handler.handle(raw_data)?;
205            return resp_obj.serialize_lisp();
206        }
207
208        // 3. Find the registered async handler
209        if let Some(handler) = self.async_handlers.get(&command) {
210            let resp_obj = handler.handle(raw_data).await?;
211            return resp_obj.serialize_lisp();
212        }
213
214        anyhow::bail!("Unknown command: {}", command)
215    }
216}
217
/// Extract the command symbol from a Lisp call such as `(symbol ...)`.
///
/// Leading whitespace and opening parentheses are skipped. The symbol ends at the
/// first whitespace or parenthesis, so a call with no arguments like `(ping)`
/// yields `ping` rather than `ping)`.
///
/// Returns `None` when no symbol is present (e.g. `""`, `"()"`).
fn extract_command_name(raw: &str) -> Option<String> {
    let body = raw.trim().trim_start_matches('(');
    let token = body.split_whitespace().next()?;
    // Cut at the first parenthesis so a closing `)` directly after the symbol
    // is not treated as part of the command name.
    let name = token.split(|c: char| c == '(' || c == ')').next()?;
    if name.is_empty() {
        None
    } else {
        Some(name.to_string())
    }
}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Request routed to the synchronous handler under the `dummy` command.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    #[serde(rename = "dummy")]
    struct DummyReq {
        val: String,
    }

    /// Request routed to the async handler under the `dummy-async` command.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    #[serde(rename = "dummy-async")]
    struct DummyAsyncReq {
        val: String,
    }

    /// Response type shared by both handlers.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct DummyResp {
        res: String,
    }

    impl_to_rpc!(DummyReq, RPCType::RPC("dummy".to_string()));
    impl_to_rpc!(DummyAsyncReq, RPCType::RPC("dummy-async".to_string()));
    impl_to_rpc!(DummyResp, RPCType::V);

    /// Build a server with one sync handler and one async handler registered.
    fn build_server() -> RPCServer {
        RPCServer::new()
            .register(|req: DummyReq| {
                Ok(DummyResp {
                    res: format!("sync-{}", req.val),
                })
            })
            .unwrap()
            .register_async(|req: DummyAsyncReq| async move {
                Ok(DummyResp {
                    res: format!("async-{}", req.val),
                })
            })
            .unwrap()
    }

    #[actix_web::test]
    async fn test_async_register_and_handle() {
        let server = build_server();

        // Sync handler, dispatched through the blocking entry point.
        let sync_wire = lisp_rpc_rust_serializer::lisp_rpc_to_str(&DummyReq {
            val: "test".to_string(),
        })
        .unwrap();
        let sync_out = server.handle(&sync_wire).unwrap();
        assert!(sync_out.contains("sync-test"));

        // Async handler, dispatched through `handle_async`.
        let async_wire = lisp_rpc_rust_serializer::lisp_rpc_to_str(&DummyAsyncReq {
            val: "async-test".to_string(),
        })
        .unwrap();
        let async_out = server.handle_async(&async_wire).await.unwrap();
        assert!(async_out.contains("async-async-test"));
    }
}