1use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
11use crate::error::Error;
12
13type BoxedHandler = Box<dyn Fn(Vec<u8>) -> Vec<u8> + Send + Sync>;
15
16pub struct Router {
18 handlers: RwLock<HashMap<String, Arc<BoxedHandler>>>,
19}
20
21impl Router {
22 pub fn new() -> Self {
24 Self {
25 handlers: RwLock::new(HashMap::new()),
26 }
27 }
28
29 pub fn register<F>(&self, name: impl Into<String>, handler: F)
33 where
34 F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync + 'static,
35 {
36 let boxed: BoxedHandler = Box::new(handler);
37 self.handlers
38 .write()
39 .unwrap_or_else(|e| e.into_inner())
40 .insert(name.into(), Arc::new(boxed));
41 }
42
43 pub fn register_simple<F>(&self, name: impl Into<String>, handler: F)
47 where
48 F: Fn() -> Vec<u8> + Send + Sync + 'static,
49 {
50 let boxed: BoxedHandler = Box::new(move |_payload| handler());
51 self.handlers
52 .write()
53 .unwrap_or_else(|e| e.into_inner())
54 .insert(name.into(), Arc::new(boxed));
55 }
56
57 pub fn call(&self, name: &str, payload: Vec<u8>) -> Result<Vec<u8>, Error> {
62 let handler = {
63 let handlers = self.handlers.read().unwrap_or_else(|e| e.into_inner());
64 handlers.get(name).cloned()
65 };
66 match handler {
67 Some(h) => Ok(h(payload)),
68 None => Err(Error::UnknownCommand(name.to_string())),
69 }
70 }
71
72 #[must_use]
79 pub fn call_or_error_bytes(&self, name: &str, payload: Vec<u8>) -> Vec<u8> {
80 match self.call(name, payload) {
81 Ok(bytes) => bytes,
82 Err(e) => e.to_string().into_bytes(),
83 }
84 }
85
86 #[must_use]
88 pub fn has(&self, name: &str) -> bool {
89 self.handlers
90 .read()
91 .unwrap_or_else(|e| e.into_inner())
92 .contains_key(name)
93 }
94}
95
96impl std::fmt::Debug for Router {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 let count = self
99 .handlers
100 .read()
101 .unwrap_or_else(|e| e.into_inner())
102 .len();
103 f.debug_struct("Router")
104 .field("handler_count", &count)
105 .finish()
106 }
107}
108
109impl Default for Router {
110 fn default() -> Self {
111 Self::new()
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn register_and_dispatch() {
        let table = Router::new();
        table.register("echo", |payload: Vec<u8>| payload);
        let resp = table.call("echo", b"hello".to_vec()).unwrap();
        assert_eq!(resp, b"hello");
    }

    #[test]
    fn unknown_command() {
        let table = Router::new();
        let err = table.call("nope", vec![]).unwrap_err();
        assert!(matches!(err, Error::UnknownCommand(ref name) if name == "nope"));
        assert_eq!(err.to_string(), "unknown command: nope");
    }

    #[test]
    fn has_command() {
        let table = Router::new();
        assert!(!table.has("ping"));
        table.register("ping", |_payload: Vec<u8>| b"pong".to_vec());
        assert!(table.has("ping"));
    }

    #[test]
    fn register_simple_test() {
        let table = Router::new();
        table.register_simple("version", || b"1.0".to_vec());
        // The payload is ignored by simple handlers.
        let resp = table.call("version", vec![0xFF]).unwrap();
        assert_eq!(resp, b"1.0");
    }

    #[test]
    fn call_or_error_bytes_success() {
        let table = Router::new();
        table.register("echo", |payload: Vec<u8>| payload);
        let resp = table.call_or_error_bytes("echo", b"hello".to_vec());
        assert_eq!(resp, b"hello");
    }

    #[test]
    fn call_or_error_bytes_unknown() {
        let table = Router::new();
        let resp = table.call_or_error_bytes("nope", vec![]);
        assert_eq!(resp, b"unknown command: nope");
    }

    /// Registering the same name twice keeps only the most recent handler.
    #[test]
    fn reregistering_replaces_handler() {
        let table = Router::new();
        table.register("cmd", |_payload: Vec<u8>| b"first".to_vec());
        table.register("cmd", |_payload: Vec<u8>| b"second".to_vec());
        assert_eq!(table.call("cmd", vec![]).unwrap(), b"second");
    }

    /// `call` must release its lock before running the handler, so a handler
    /// that registers a new command must not deadlock.
    #[test]
    fn handler_can_reenter_router() {
        let table = std::sync::Arc::new(Router::new());
        // A `Weak` handle avoids a reference cycle between router and handler.
        let weak = std::sync::Arc::downgrade(&table);
        table.register("install", move |_payload: Vec<u8>| {
            if let Some(router) = weak.upgrade() {
                router.register("late", |payload: Vec<u8>| payload);
            }
            b"installed".to_vec()
        });
        assert_eq!(table.call("install", vec![]).unwrap(), b"installed");
        assert!(table.has("late"));
    }

    #[test]
    fn debug_reports_handler_count() {
        let table = Router::default();
        table.register("a", |payload: Vec<u8>| payload);
        table.register_simple("b", Vec::new);
        assert_eq!(format!("{:?}", table), "Router { handler_count: 2 }");
    }
}