1use anyhow::{Result, anyhow};
2use lisp_rpc_rust_serializer::lisp_rpc_from_str;
3use serde::Serialize;
4use serde::de::DeserializeOwned;
5use std::collections::HashMap;
6use std::fmt::Debug;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
11use super::*;
12
13pub trait RpcFunc<T>: Send + Sync + 'static {
15 type Resp: Serialize + ToRPCType + 'static;
16 fn call(&self, req: T) -> Result<Self::Resp>;
17}
18
19impl<T, R, F> RpcFunc<T> for F
20where
21 T: DeserializeOwned + Debug + Send + Sync + ToRPCType + 'static,
22 R: Serialize + ToRPCType + 'static,
23 F: Fn(T) -> Result<R> + Send + Sync + 'static,
24{
25 type Resp = R;
26 fn call(&self, req: T) -> Result<Self::Resp> {
27 (self)(req)
28 }
29}
30
31pub trait AsyncRpcFunc<T>: Send + Sync + 'static {
33 type Resp: Serialize + ToRPCType + 'static;
34 type Fut: Future<Output = Result<Self::Resp>> + Send + 'static;
35 fn call(&self, req: T) -> Self::Fut;
36}
37
38impl<T, R, F, Fut> AsyncRpcFunc<T> for F
39where
40 T: DeserializeOwned + Debug + Send + Sync + ToRPCType + 'static,
41 R: Serialize + ToRPCType + 'static,
42 Fut: Future<Output = Result<R>> + Send + 'static,
43 F: Fn(T) -> Fut + Send + Sync + 'static,
44{
45 type Resp = R;
46 type Fut = Fut;
47 fn call(&self, req: T) -> Self::Fut {
48 (self)(req)
49 }
50}
51
52pub trait RpcHandler: Send + Sync {
54 fn handle(&self, raw_data: &str) -> Result<Box<dyn ToRPCType>>;
55}
56
57struct Handler<T, F> {
59 func: F,
60 _phantom: std::marker::PhantomData<T>,
61}
62
63impl<T, F> RpcHandler for Handler<T, F>
64where
65 T: DeserializeOwned + Debug + Send + Sync + ToRPCType + 'static,
66 F: RpcFunc<T>,
67{
68 fn handle(&self, raw_data: &str) -> Result<Box<dyn ToRPCType>> {
69 let req: T =
70 lisp_rpc_from_str(raw_data).map_err(|e| anyhow!("Deserialization failed: {}", e))?;
71 let resp = self.func.call(req)?;
72 Ok(Box::new(resp))
73 }
74}
75
76pub trait AsyncRpcHandler: Send + Sync {
78 fn handle(
79 &self,
80 raw_data: &str,
81 ) -> Pin<Box<dyn Future<Output = Result<Box<dyn ToRPCType>>> + Send>>;
82}
83
84struct AsyncHandler<T, F> {
86 func: F,
87 _phantom: std::marker::PhantomData<T>,
88}
89
90impl<T, F> AsyncRpcHandler for AsyncHandler<T, F>
91where
92 T: DeserializeOwned + Debug + Send + Sync + ToRPCType + 'static,
93 F: AsyncRpcFunc<T>,
94{
95 fn handle(
96 &self,
97 raw_data: &str,
98 ) -> Pin<Box<dyn Future<Output = Result<Box<dyn ToRPCType>>> + Send>> {
99 let req_res =
100 lisp_rpc_from_str(raw_data).map_err(|e| anyhow!("Deserialization failed: {}", e));
101 match req_res {
102 Ok(req) => {
103 let fut = self.func.call(req);
104 Box::pin(async move {
105 let resp = fut.await?;
106 Ok(Box::new(resp) as Box<dyn ToRPCType>)
107 })
108 }
109 Err(e) => Box::pin(async move { Err(e) }),
110 }
111 }
112}
113
114#[derive(Clone)]
116pub struct RPCServer {
117 pub handlers: Arc<HashMap<String, Box<dyn RpcHandler>>>,
118 pub async_handlers: Arc<HashMap<String, Box<dyn AsyncRpcHandler>>>,
119}
120
121impl RPCServer {
122 pub fn new() -> Self {
123 Self {
124 handlers: Arc::new(HashMap::new()),
125 async_handlers: Arc::new(HashMap::new()),
126 }
127 }
128
129 pub fn register<T, F>(mut self, func: F) -> Result<Self>
131 where
132 T: DeserializeOwned + Debug + Send + Sync + ToRPCType + 'static,
133 F: RpcFunc<T>,
134 {
135 let command = match <T as ToRPCType>::to_rpc_type() {
137 RPCType::RPC(s) => s,
138 _ => anyhow::bail!("Handler function argument has to be RPCType::RPC"),
139 };
140
141 let handler = Handler {
142 func,
143 _phantom: std::marker::PhantomData,
144 };
145
146 Arc::get_mut(&mut self.handlers)
147 .unwrap()
148 .insert(command, Box::new(handler));
149
150 Ok(self)
151 }
152
153 pub fn register_async<T, F>(mut self, func: F) -> Result<Self>
155 where
156 T: DeserializeOwned + Debug + Send + Sync + ToRPCType + 'static,
157 F: AsyncRpcFunc<T>,
158 {
159 let command = match <T as ToRPCType>::to_rpc_type() {
161 RPCType::RPC(s) => s,
162 _ => anyhow::bail!("Handler function argument has to be RPCType::RPC"),
163 };
164
165 let handler = AsyncHandler {
166 func,
167 _phantom: std::marker::PhantomData,
168 };
169
170 Arc::get_mut(&mut self.async_handlers)
171 .unwrap()
172 .insert(command, Box::new(handler));
173
174 Ok(self)
175 }
176
177 pub fn handle(&self, raw_data: &str) -> Result<String> {
179 let command =
181 extract_command_name(raw_data).ok_or_else(|| anyhow!("Invalid RPC format"))?;
182
183 let handler = self
185 .handlers
186 .get(&command)
187 .ok_or_else(|| anyhow!("Unknown command: {}", command))?;
188
189 let resp_obj = handler.handle(raw_data)?;
191
192 resp_obj.serialize_lisp()
194 }
195
196 pub async fn handle_async(&self, raw_data: &str) -> Result<String> {
198 let command =
200 extract_command_name(raw_data).ok_or_else(|| anyhow!("Invalid RPC format"))?;
201
202 if let Some(handler) = self.handlers.get(&command) {
204 let resp_obj = handler.handle(raw_data)?;
205 return resp_obj.serialize_lisp();
206 }
207
208 if let Some(handler) = self.async_handlers.get(&command) {
210 let resp_obj = handler.handle(raw_data).await?;
211 return resp_obj.serialize_lisp();
212 }
213
214 anyhow::bail!("Unknown command: {}", command)
215 }
216}
217
/// Extracts the command name (the head symbol) from a raw lisp-rpc form such as
/// `(dummy :val "x")`.
///
/// Leading whitespace and opening parentheses are skipped. The name ends at the first
/// whitespace character or parenthesis, so an argument-less form like `(ping)` yields
/// `ping` rather than `ping)`. Returns `None` when no name is present, e.g. for `""`,
/// `"("` or `"()"`.
fn extract_command_name(raw: &str) -> Option<String> {
    let body = raw.trim_start_matches(|c: char| c == '(' || c.is_whitespace());
    let end = body
        .find(|c: char| c.is_whitespace() || c == '(' || c == ')')
        .unwrap_or(body.len());
    let name = &body[..end];
    (!name.is_empty()).then(|| name.to_string())
}
223
#[cfg(test)]
mod tests {
    use super::*;

    // Request type for the sync handler. `impl_to_rpc!` below maps it to the RPC command
    // name "dummy", which `RPCServer::register` uses as the dispatch key.
    // NOTE(review): `#[serde(rename)]` presumably sets the head symbol emitted by
    // `lisp_rpc_to_str`, which must match that command name — confirm.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    #[serde(rename = "dummy")]
    struct DummyReq {
        val: String,
    }

    // Request type for the async handler, mapped below to the command name "dummy-async".
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    #[serde(rename = "dummy-async")]
    struct DummyAsyncReq {
        val: String,
    }

    // Response type shared by both handlers.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct DummyResp {
        res: String,
    }

    impl_to_rpc!(DummyReq, RPCType::RPC("dummy".to_string()));
    impl_to_rpc!(DummyAsyncReq, RPCType::RPC("dummy-async".to_string()));
    // NOTE(review): presumably `RPCType::V` marks a plain value (non-command) type — confirm.
    impl_to_rpc!(DummyResp, RPCType::V);

    // End-to-end check:
    // - one sync handler and one async handler are registered;
    // - each is invoked through its own entry point (`handle` / `handle_async`) with a
    //   request serialized by `lisp_rpc_to_str`;
    // - the serialized response must contain that handler's output.
    #[actix_web::test]
    async fn test_async_register_and_handle() {
        let server = RPCServer::new()
            .register(|req: DummyReq| {
                Ok(DummyResp {
                    res: format!("sync-{}", req.val),
                })
            })
            .unwrap()
            .register_async(|req: DummyAsyncReq| async move {
                Ok(DummyResp {
                    res: format!("async-{}", req.val),
                })
            })
            .unwrap();

        // Sync path: dispatched via the `handlers` table.
        let sync_req = lisp_rpc_rust_serializer::lisp_rpc_to_str(&DummyReq {
            val: "test".to_string(),
        })
        .unwrap();
        let sync_res = server.handle(&sync_req).unwrap();
        assert!(sync_res.contains("sync-test"));

        // Async path: no sync handler exists for "dummy-async", so `handle_async` falls
        // through to the `async_handlers` table.
        let async_req = lisp_rpc_rust_serializer::lisp_rpc_to_str(&DummyAsyncReq {
            val: "async-test".to_string(),
        })
        .unwrap();
        let async_res = server.handle_async(&async_req).await.unwrap();
        assert!(async_res.contains("async-async-test"));
    }
}
281}