1use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
11use serde::Serialize;
12use serde::de::DeserializeOwned;
13
14use crate::codec::{Decode, Encode};
15use crate::error::Error;
16
17type BoxedHandler =
25 Box<dyn Fn(Vec<u8>, &dyn std::any::Any) -> Result<Vec<u8>, Error> + Send + Sync>;
26
27pub struct Router {
29 handlers: RwLock<HashMap<String, Arc<BoxedHandler>>>,
30}
31
32impl Router {
33 pub fn new() -> Self {
35 Self {
36 handlers: RwLock::new(HashMap::new()),
37 }
38 }
39
40 pub fn register<F>(&self, name: impl Into<String>, handler: F)
44 where
45 F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync + 'static,
46 {
47 let boxed: BoxedHandler = Box::new(move |payload, _ctx| Ok(handler(payload)));
48 crate::write_or_recover(&self.handlers).insert(name.into(), Arc::new(boxed));
49 }
50
51 pub fn register_simple<F>(&self, name: impl Into<String>, handler: F)
55 where
56 F: Fn() -> Vec<u8> + Send + Sync + 'static,
57 {
58 let boxed: BoxedHandler = Box::new(move |_payload, _ctx| Ok(handler()));
59 crate::write_or_recover(&self.handlers).insert(name.into(), Arc::new(boxed));
60 }
61
62 pub fn register_json<F, A, R>(&self, name: impl Into<String>, handler: F)
69 where
70 F: Fn(A) -> R + Send + Sync + 'static,
71 A: DeserializeOwned + 'static,
72 R: Serialize + 'static,
73 {
74 let boxed: BoxedHandler = Box::new(move |payload, _ctx| {
75 let arg: A = sonic_rs::from_slice(&payload).map_err(Error::from)?;
76 let result = handler(arg);
77 sonic_rs::to_vec(&result).map_err(Error::from)
78 });
79 crate::write_or_recover(&self.handlers).insert(name.into(), Arc::new(boxed));
80 }
81
82 pub fn register_json_result<F, A, R, E>(&self, name: impl Into<String>, handler: F)
89 where
90 F: Fn(A) -> Result<R, E> + Send + Sync + 'static,
91 A: DeserializeOwned + 'static,
92 R: Serialize + 'static,
93 E: std::fmt::Display + 'static,
94 {
95 let boxed: BoxedHandler = Box::new(move |payload, _ctx| {
96 let arg: A = sonic_rs::from_slice(&payload).map_err(Error::from)?;
97 let result = handler(arg).map_err(|e| Error::Handler(e.to_string()))?;
98 sonic_rs::to_vec(&result).map_err(Error::from)
99 });
100 crate::write_or_recover(&self.handlers).insert(name.into(), Arc::new(boxed));
101 }
102
103 pub fn register_binary<F, A, R>(&self, name: impl Into<String>, handler: F)
110 where
111 F: Fn(A) -> R + Send + Sync + 'static,
112 A: Decode + 'static,
113 R: Encode + 'static,
114 {
115 let boxed: BoxedHandler = Box::new(move |payload, _ctx| {
116 let (arg, _consumed) = A::decode(&payload).ok_or(Error::DecodeFailed)?;
117 let result = handler(arg);
118 let mut buf = Vec::with_capacity(result.encode_size());
119 result.encode(&mut buf);
120 Ok(buf)
121 });
122 crate::write_or_recover(&self.handlers).insert(name.into(), Arc::new(boxed));
123 }
124
125 pub fn register_with_context<F>(&self, name: impl Into<String>, handler: F)
132 where
133 F: Fn(Vec<u8>, &dyn std::any::Any) -> Result<Vec<u8>, Error> + Send + Sync + 'static,
134 {
135 let boxed: BoxedHandler = Box::new(handler);
136 crate::write_or_recover(&self.handlers).insert(name.into(), Arc::new(boxed));
137 }
138
139 pub fn call_with_context(
146 &self,
147 name: &str,
148 payload: Vec<u8>,
149 ctx: &dyn std::any::Any,
150 ) -> Result<Vec<u8>, Error> {
151 let handler = {
152 let handlers = crate::read_or_recover(&self.handlers);
153 handlers.get(name).cloned()
154 };
155 match handler {
156 Some(h) => h(payload, ctx),
157 None => Err(Error::UnknownCommand(name.to_string())),
158 }
159 }
160
161 #[must_use]
167 pub fn call_or_error_bytes_with_context(
168 &self,
169 name: &str,
170 payload: Vec<u8>,
171 ctx: &dyn std::any::Any,
172 ) -> Vec<u8> {
173 match self.call_with_context(name, payload, ctx) {
174 Ok(bytes) => bytes,
175 Err(e) => e.to_string().into_bytes(),
176 }
177 }
178
179 pub fn call(&self, name: &str, payload: Vec<u8>) -> Result<Vec<u8>, Error> {
184 self.call_with_context(name, payload, &())
185 }
186
187 #[must_use]
194 pub fn call_or_error_bytes(&self, name: &str, payload: Vec<u8>) -> Vec<u8> {
195 self.call_or_error_bytes_with_context(name, payload, &())
196 }
197
198 #[must_use]
200 pub fn has(&self, name: &str) -> bool {
201 crate::read_or_recover(&self.handlers).contains_key(name)
202 }
203}
204
205impl std::fmt::Debug for Router {
206 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207 let count = crate::read_or_recover(&self.handlers).len();
208 f.debug_struct("Router")
209 .field("handler_count", &count)
210 .finish()
211 }
212}
213
214impl Default for Router {
215 fn default() -> Self {
216 Self::new()
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    // ---- raw byte handlers and basic dispatch ----

    #[test]
    fn register_and_dispatch() {
        let table = Router::new();
        table.register("echo", |payload: Vec<u8>| payload);
        let resp = table.call("echo", b"hello".to_vec()).unwrap();
        assert_eq!(resp, b"hello");
    }

    #[test]
    fn unknown_command() {
        let table = Router::new();
        let err = table.call("nope", vec![]).unwrap_err();
        assert!(matches!(err, Error::UnknownCommand(ref name) if name == "nope"));
        // `call_or_error_bytes` returns this Display text verbatim, so pin it.
        assert_eq!(err.to_string(), "unknown command: nope");
    }

    #[test]
    fn has_command() {
        let table = Router::new();
        assert!(!table.has("ping"));
        table.register("ping", |_payload: Vec<u8>| b"pong".to_vec());
        assert!(table.has("ping"));
    }

    #[test]
    fn register_simple_test() {
        let table = Router::new();
        table.register_simple("version", || b"1.0".to_vec());
        // The payload is ignored by simple handlers.
        let resp = table.call("version", vec![0xFF]).unwrap();
        assert_eq!(resp, b"1.0");
    }

    #[test]
    fn call_or_error_bytes_success() {
        let table = Router::new();
        table.register("echo", |payload: Vec<u8>| payload);
        let resp = table.call_or_error_bytes("echo", b"hello".to_vec());
        assert_eq!(resp, b"hello");
    }

    #[test]
    fn call_or_error_bytes_unknown() {
        let table = Router::new();
        let resp = table.call_or_error_bytes("nope", vec![]);
        assert_eq!(resp, b"unknown command: nope");
    }

    // ---- JSON handlers ----

    #[test]
    fn register_json_roundtrip() {
        let table = Router::new();
        table.register_json("add", |args: (i32, i32)| args.0 + args.1);
        let payload = sonic_rs::to_vec(&(3, 4)).unwrap();
        let resp = table.call("add", payload).unwrap();
        let result: i32 = sonic_rs::from_slice(&resp).unwrap();
        assert_eq!(result, 7);
    }

    #[test]
    fn register_json_bad_input() {
        let table = Router::new();
        table.register_json("add", |args: (i32, i32)| args.0 + args.1);
        let err = table.call("add", b"not json!".to_vec()).unwrap_err();
        assert!(matches!(err, Error::Serialize(_)));
    }

    #[test]
    fn register_json_result_ok() {
        let table = Router::new();
        table.register_json_result("divide", |args: (f64, f64)| -> Result<f64, String> {
            if args.1 == 0.0 {
                Err("division by zero".into())
            } else {
                Ok(args.0 / args.1)
            }
        });
        let payload = sonic_rs::to_vec(&(10.0_f64, 2.0_f64)).unwrap();
        let resp = table.call("divide", payload).unwrap();
        let result: f64 = sonic_rs::from_slice(&resp).unwrap();
        assert!((result - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn register_json_result_err() {
        let table = Router::new();
        table.register_json_result("divide", |args: (f64, f64)| -> Result<f64, String> {
            if args.1 == 0.0 {
                Err("division by zero".into())
            } else {
                Ok(args.0 / args.1)
            }
        });
        let payload = sonic_rs::to_vec(&(10.0_f64, 0.0_f64)).unwrap();
        let err = table.call("divide", payload).unwrap_err();
        // The handler's own error is surfaced through `Error::Handler`.
        assert!(matches!(err, Error::Handler(ref msg) if msg == "division by zero"));
    }

    #[test]
    fn register_json_result_bad_input() {
        let table = Router::new();
        table.register_json_result("divide", |args: (f64, f64)| -> Result<f64, String> {
            Ok(args.0 / args.1)
        });
        let err = table.call("divide", b"garbage".to_vec()).unwrap_err();
        assert!(matches!(err, Error::Serialize(_)));
    }

    // ---- binary codec handlers ----

    /// Two-field fixture used to exercise `register_binary`.
    #[derive(Debug, PartialEq)]
    struct Pair(u32, u32);

    impl crate::codec::Encode for Pair {
        fn encode(&self, buf: &mut Vec<u8>) {
            self.0.encode(buf);
            self.1.encode(buf);
        }
        fn encode_size(&self) -> usize {
            // Derived from the fields' own sizes rather than hard-coding the
            // width of a `u32` encoding, so it stays in sync with `encode`.
            self.0.encode_size() + self.1.encode_size()
        }
    }

    impl crate::codec::Decode for Pair {
        fn decode(data: &[u8]) -> Option<(Self, usize)> {
            let (a, ca) = u32::decode(data)?;
            let (b, cb) = u32::decode(&data[ca..])?;
            Some((Pair(a, b), ca + cb))
        }
    }

    #[test]
    fn register_binary_roundtrip() {
        let table = Router::new();
        table.register_binary("sum", |p: Pair| p.0 + p.1);
        let mut payload = Vec::new();
        Pair(10, 20).encode(&mut payload);
        // The size hint must agree with what `encode` actually wrote.
        assert_eq!(payload.len(), Pair(10, 20).encode_size());
        let resp = table.call("sum", payload).unwrap();
        let (result, _) = u32::decode(&resp).unwrap();
        assert_eq!(result, 30);
    }

    #[test]
    fn register_binary_bad_input() {
        let table = Router::new();
        table.register_binary("sum", |p: Pair| p.0 + p.1);
        // Too short to hold a `Pair`, so decoding must fail.
        let err = table.call("sum", vec![1, 2, 3]).unwrap_err();
        assert!(matches!(err, Error::DecodeFailed));
    }

    // ---- context-aware handlers ----

    #[test]
    fn register_with_context_basic() {
        let table = Router::new();
        table.register_with_context("echo_ctx", |payload: Vec<u8>, _ctx: &dyn std::any::Any| {
            Ok(payload)
        });
        let resp = table.call("echo_ctx", b"hello".to_vec()).unwrap();
        assert_eq!(resp, b"hello");
    }

    #[test]
    fn call_with_context_passes_through() {
        let table = Router::new();
        table.register_with_context("check_ctx", |_payload: Vec<u8>, ctx: &dyn std::any::Any| {
            // Report whether the context downcasts to `String`.
            if ctx.downcast_ref::<String>().is_some() {
                Ok(b"got string".to_vec())
            } else {
                Ok(b"no string".to_vec())
            }
        });

        let ctx = String::from("hello");
        let resp = table.call_with_context("check_ctx", vec![], &ctx).unwrap();
        assert_eq!(resp, b"got string");

        // Plain `call` supplies `()` as the context, which is not a `String`.
        let resp = table.call("check_ctx", vec![]).unwrap();
        assert_eq!(resp, b"no string");
    }

    #[test]
    fn call_or_error_bytes_with_context_success() {
        let table = Router::new();
        table.register("echo", |payload: Vec<u8>| payload);
        let ctx = String::from("unused");
        let resp = table.call_or_error_bytes_with_context("echo", b"hello".to_vec(), &ctx);
        assert_eq!(resp, b"hello");
    }

    #[test]
    fn call_or_error_bytes_with_context_unknown() {
        let table = Router::new();
        let resp = table.call_or_error_bytes_with_context("nope", vec![], &());
        assert_eq!(resp, b"unknown command: nope");
    }

    // ---- table management ----

    #[test]
    fn register_replaces_handler() {
        let table = Router::new();
        table.register("cmd", |_payload: Vec<u8>| b"first".to_vec());
        table.register("cmd", |_payload: Vec<u8>| b"second".to_vec());
        let resp = table.call("cmd", vec![]).unwrap();
        assert_eq!(resp, b"second");
    }

    #[test]
    fn register_with_context_error_propagation() {
        let table = Router::new();
        table.register_with_context("fail", |_payload: Vec<u8>, _ctx: &dyn std::any::Any| {
            Err(Error::Handler("context handler failed".into()))
        });
        let err = table.call("fail", vec![]).unwrap_err();
        assert!(matches!(err, Error::Handler(ref msg) if msg == "context handler failed"));

        // The byte-returning variant renders the same error via Display.
        let bytes = table.call_or_error_bytes("fail", vec![]);
        assert_eq!(bytes, b"handler error: context handler failed");
    }

    #[test]
    fn debug_reports_handler_count() {
        let table = Router::new();
        assert_eq!(format!("{table:?}"), "Router { handler_count: 0 }");
        table.register("echo", |payload: Vec<u8>| payload);
        assert_eq!(format!("{table:?}"), "Router { handler_count: 1 }");
    }

    #[test]
    fn default_is_empty() {
        let table = Router::default();
        assert!(!table.has("anything"));
        assert!(matches!(
            table.call("anything", vec![]),
            Err(Error::UnknownCommand(_))
        ));
    }
}