Skip to main content

conduit_core/
router.rs

1//! Command dispatch table with synchronous handlers.
2//!
3//! [`Router`] is a thread-safe named registry: each command name maps
4//! to a boxed function that receives a payload and returns a response.
5//! Handlers are synchronous — callers that need async work should use a
6//! channel or spawn internally.
7
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
11use serde::Serialize;
12use serde::de::DeserializeOwned;
13
14use crate::codec::{Decode, Encode};
15use crate::error::Error;
16
17/// Boxed synchronous handler: takes payload bytes and an opaque context,
18/// returns response bytes or an [`Error`].
19///
20/// The context parameter (`&dyn std::any::Any`) allows handlers generated
21/// by the `#[command]` macro to extract `State<T>` from an `AppHandle`.
22/// Existing handler registration methods ignore the context parameter for
23/// backward compatibility.
24type BoxedHandler =
25    Box<dyn Fn(Vec<u8>, &dyn std::any::Any) -> Result<Vec<u8>, Error> + Send + Sync>;
26
27/// Named command registry with synchronous dispatch.
28pub struct Router {
29    handlers: RwLock<HashMap<String, Arc<BoxedHandler>>>,
30}
31
32impl Router {
33    /// Create an empty dispatch table.
34    pub fn new() -> Self {
35        Self {
36            handlers: RwLock::new(HashMap::new()),
37        }
38    }
39
40    /// Register a handler for a command name.
41    ///
42    /// If a handler was already registered under `name` it is replaced.
43    pub fn register<F>(&self, name: impl Into<String>, handler: F)
44    where
45        F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync + 'static,
46    {
47        let boxed: BoxedHandler = Box::new(move |payload, _ctx| Ok(handler(payload)));
48        crate::write_or_recover(&self.handlers).insert(name.into(), Arc::new(boxed));
49    }
50
51    /// Register a handler that takes no payload.
52    ///
53    /// The incoming payload bytes are silently discarded.
54    pub fn register_simple<F>(&self, name: impl Into<String>, handler: F)
55    where
56        F: Fn() -> Vec<u8> + Send + Sync + 'static,
57    {
58        let boxed: BoxedHandler = Box::new(move |_payload, _ctx| Ok(handler()));
59        crate::write_or_recover(&self.handlers).insert(name.into(), Arc::new(boxed));
60    }
61
62    /// Register a JSON handler for a command name.
63    ///
64    /// The incoming payload is deserialised from JSON into `A`, the handler
65    /// is called with the typed value, and the return value `R` is serialised
66    /// back to JSON bytes. Returns [`Error::Serialize`] on deserialisation
67    /// failure.
68    pub fn register_json<F, A, R>(&self, name: impl Into<String>, handler: F)
69    where
70        F: Fn(A) -> R + Send + Sync + 'static,
71        A: DeserializeOwned + 'static,
72        R: Serialize + 'static,
73    {
74        let boxed: BoxedHandler = Box::new(move |payload, _ctx| {
75            let arg: A = sonic_rs::from_slice(&payload).map_err(Error::from)?;
76            let result = handler(arg);
77            sonic_rs::to_vec(&result).map_err(Error::from)
78        });
79        crate::write_or_recover(&self.handlers).insert(name.into(), Arc::new(boxed));
80    }
81
82    /// Register a fallible JSON handler for a command name.
83    ///
84    /// Like [`register_json`](Self::register_json), but the handler returns
85    /// `Result<R, E>`. On `Ok(value)`, the value is serialised to JSON. On
86    /// `Err(e)`, the error's `Display` text is returned as
87    /// [`Error::Handler`].
88    pub fn register_json_result<F, A, R, E>(&self, name: impl Into<String>, handler: F)
89    where
90        F: Fn(A) -> Result<R, E> + Send + Sync + 'static,
91        A: DeserializeOwned + 'static,
92        R: Serialize + 'static,
93        E: std::fmt::Display + 'static,
94    {
95        let boxed: BoxedHandler = Box::new(move |payload, _ctx| {
96            let arg: A = sonic_rs::from_slice(&payload).map_err(Error::from)?;
97            let result = handler(arg).map_err(|e| Error::Handler(e.to_string()))?;
98            sonic_rs::to_vec(&result).map_err(Error::from)
99        });
100        crate::write_or_recover(&self.handlers).insert(name.into(), Arc::new(boxed));
101    }
102
103    /// Register a binary handler for a command name.
104    ///
105    /// The incoming payload is decoded via the [`Decode`] trait into `A`,
106    /// the handler is called with the typed value, and the return value `R`
107    /// is encoded via [`Encode`] back to bytes. Returns
108    /// [`Error::DecodeFailed`] if the payload cannot be decoded.
109    pub fn register_binary<F, A, R>(&self, name: impl Into<String>, handler: F)
110    where
111        F: Fn(A) -> R + Send + Sync + 'static,
112        A: Decode + 'static,
113        R: Encode + 'static,
114    {
115        let boxed: BoxedHandler = Box::new(move |payload, _ctx| {
116            let (arg, _consumed) = A::decode(&payload).ok_or(Error::DecodeFailed)?;
117            let result = handler(arg);
118            let mut buf = Vec::with_capacity(result.encode_size());
119            result.encode(&mut buf);
120            Ok(buf)
121        });
122        crate::write_or_recover(&self.handlers).insert(name.into(), Arc::new(boxed));
123    }
124
125    /// Register a context-aware handler.
126    ///
127    /// Handlers generated by the `#[conduit::command]` macro have the
128    /// signature `fn(Vec<u8>, &dyn Any) -> Result<Vec<u8>, Error>` and
129    /// handle their own deserialization, State extraction, and
130    /// serialization internally.
131    pub fn register_with_context<F>(&self, name: impl Into<String>, handler: F)
132    where
133        F: Fn(Vec<u8>, &dyn std::any::Any) -> Result<Vec<u8>, Error> + Send + Sync + 'static,
134    {
135        let boxed: BoxedHandler = Box::new(handler);
136        crate::write_or_recover(&self.handlers).insert(name.into(), Arc::new(boxed));
137    }
138
139    /// Dispatch a command by name with an opaque context.
140    ///
141    /// The context is passed through to the handler. For handlers
142    /// registered via `register_with_context` (i.e., `#[command]`-generated
143    /// handlers), the context is typically an `&AppHandle<Wry>` that enables
144    /// `State<T>` extraction.
145    pub fn call_with_context(
146        &self,
147        name: &str,
148        payload: Vec<u8>,
149        ctx: &dyn std::any::Any,
150    ) -> Result<Vec<u8>, Error> {
151        let handler = {
152            let handlers = crate::read_or_recover(&self.handlers);
153            handlers.get(name).cloned()
154        };
155        match handler {
156            Some(h) => h(payload, ctx),
157            None => Err(Error::UnknownCommand(name.to_string())),
158        }
159    }
160
161    /// Dispatch a command by name with context, returning raw bytes in all
162    /// cases.
163    ///
164    /// On success the handler's response bytes are returned. On failure the
165    /// error's `Display` text is returned as UTF-8 bytes.
166    #[must_use]
167    pub fn call_or_error_bytes_with_context(
168        &self,
169        name: &str,
170        payload: Vec<u8>,
171        ctx: &dyn std::any::Any,
172    ) -> Vec<u8> {
173        match self.call_with_context(name, payload, ctx) {
174            Ok(bytes) => bytes,
175            Err(e) => e.to_string().into_bytes(),
176        }
177    }
178
179    /// Dispatch a command by name.
180    ///
181    /// Returns the handler's response bytes on success, or
182    /// [`Error::UnknownCommand`] if no handler is registered for `name`.
183    pub fn call(&self, name: &str, payload: Vec<u8>) -> Result<Vec<u8>, Error> {
184        self.call_with_context(name, payload, &())
185    }
186
187    /// Dispatch a command by name, returning raw bytes in all cases.
188    ///
189    /// On success the handler's response bytes are returned. On failure the
190    /// error's `Display` text is returned as UTF-8 bytes. This is a
191    /// convenience wrapper for call sites (such as the custom protocol
192    /// handler) that must always produce a `Vec<u8>`.
193    #[must_use]
194    pub fn call_or_error_bytes(&self, name: &str, payload: Vec<u8>) -> Vec<u8> {
195        self.call_or_error_bytes_with_context(name, payload, &())
196    }
197
198    /// Check whether a command is registered.
199    #[must_use]
200    pub fn has(&self, name: &str) -> bool {
201        crate::read_or_recover(&self.handlers).contains_key(name)
202    }
203}
204
205impl std::fmt::Debug for Router {
206    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207        let count = crate::read_or_recover(&self.handlers).len();
208        f.debug_struct("Router")
209            .field("handler_count", &count)
210            .finish()
211    }
212}
213
214impl Default for Router {
215    fn default() -> Self {
216        Self::new()
217    }
218}
219
// Unit tests for `Router` registration, dispatch, error reporting, and the
// `Debug`/`Default` impls.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn register_and_dispatch() {
        let table = Router::new();
        table.register("echo", |payload: Vec<u8>| payload);
        let resp = table.call("echo", b"hello".to_vec()).unwrap();
        assert_eq!(resp, b"hello");
    }

    #[test]
    fn unknown_command() {
        let table = Router::new();
        let err = table.call("nope", vec![]).unwrap_err();
        assert!(matches!(err, Error::UnknownCommand(ref name) if name == "nope"));
        assert_eq!(err.to_string(), "unknown command: nope");
    }

    #[test]
    fn has_command() {
        let table = Router::new();
        assert!(!table.has("ping"));
        table.register("ping", |_payload: Vec<u8>| b"pong".to_vec());
        assert!(table.has("ping"));
    }

    #[test]
    fn register_simple_test() {
        let table = Router::new();
        table.register_simple("version", || b"1.0".to_vec());
        // The payload is discarded, so arbitrary bytes are accepted.
        let resp = table.call("version", vec![0xFF]).unwrap();
        assert_eq!(resp, b"1.0");
    }

    #[test]
    fn call_or_error_bytes_success() {
        let table = Router::new();
        table.register("echo", |payload: Vec<u8>| payload);
        let resp = table.call_or_error_bytes("echo", b"hello".to_vec());
        assert_eq!(resp, b"hello");
    }

    #[test]
    fn call_or_error_bytes_unknown() {
        let table = Router::new();
        let resp = table.call_or_error_bytes("nope", vec![]);
        assert_eq!(resp, b"unknown command: nope");
    }

    // -- JSON handler tests --------------------------------------------------

    #[test]
    fn register_json_roundtrip() {
        let table = Router::new();
        table.register_json("add", |args: (i32, i32)| args.0 + args.1);
        let payload = sonic_rs::to_vec(&(3, 4)).unwrap();
        let resp = table.call("add", payload).unwrap();
        let result: i32 = sonic_rs::from_slice(&resp).unwrap();
        assert_eq!(result, 7);
    }

    #[test]
    fn register_json_bad_input() {
        let table = Router::new();
        table.register_json("add", |args: (i32, i32)| args.0 + args.1);
        let err = table.call("add", b"not json!".to_vec()).unwrap_err();
        assert!(matches!(err, Error::Serialize(_)));
    }

    // -- Fallible JSON handler tests -----------------------------------------

    #[test]
    fn register_json_result_ok() {
        let table = Router::new();
        table.register_json_result("divide", |args: (f64, f64)| -> Result<f64, String> {
            if args.1 == 0.0 {
                Err("division by zero".into())
            } else {
                Ok(args.0 / args.1)
            }
        });
        let payload = sonic_rs::to_vec(&(10.0_f64, 2.0_f64)).unwrap();
        let resp = table.call("divide", payload).unwrap();
        let result: f64 = sonic_rs::from_slice(&resp).unwrap();
        assert!((result - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn register_json_result_err() {
        let table = Router::new();
        table.register_json_result("divide", |args: (f64, f64)| -> Result<f64, String> {
            if args.1 == 0.0 {
                Err("division by zero".into())
            } else {
                Ok(args.0 / args.1)
            }
        });
        let payload = sonic_rs::to_vec(&(10.0_f64, 0.0_f64)).unwrap();
        let err = table.call("divide", payload).unwrap_err();
        assert!(matches!(err, Error::Handler(ref msg) if msg == "division by zero"));
    }

    #[test]
    fn register_json_result_bad_input() {
        let table = Router::new();
        table.register_json_result("divide", |args: (f64, f64)| -> Result<f64, String> {
            Ok(args.0 / args.1)
        });
        let err = table.call("divide", b"garbage".to_vec()).unwrap_err();
        assert!(matches!(err, Error::Serialize(_)));
    }

    // -- Binary handler tests ------------------------------------------------

    /// Minimal newtype that implements Encode/Decode for testing.
    #[derive(Debug, PartialEq)]
    struct Pair(u32, u32);

    impl crate::codec::Encode for Pair {
        fn encode(&self, buf: &mut Vec<u8>) {
            self.0.encode(buf);
            self.1.encode(buf);
        }
        fn encode_size(&self) -> usize {
            8
        }
    }

    impl crate::codec::Decode for Pair {
        fn decode(data: &[u8]) -> Option<(Self, usize)> {
            let (a, ca) = u32::decode(data)?;
            let (b, cb) = u32::decode(&data[ca..])?;
            Some((Pair(a, b), ca + cb))
        }
    }

    #[test]
    fn register_binary_roundtrip() {
        let table = Router::new();
        table.register_binary("sum", |p: Pair| p.0 + p.1);
        let mut payload = Vec::new();
        Pair(10, 20).encode(&mut payload);
        let resp = table.call("sum", payload).unwrap();
        let (result, _) = u32::decode(&resp).unwrap();
        assert_eq!(result, 30);
    }

    #[test]
    fn register_binary_bad_input() {
        let table = Router::new();
        table.register_binary("sum", |p: Pair| p.0 + p.1);
        // Only 3 bytes — too short for two u32 values.
        let err = table.call("sum", vec![1, 2, 3]).unwrap_err();
        assert!(matches!(err, Error::DecodeFailed));
    }

    // -- Context-aware handler tests -----------------------------------------

    #[test]
    fn register_with_context_basic() {
        let table = Router::new();
        table.register_with_context("echo_ctx", |payload: Vec<u8>, _ctx: &dyn std::any::Any| {
            Ok(payload)
        });
        let resp = table.call("echo_ctx", b"hello".to_vec()).unwrap();
        assert_eq!(resp, b"hello");
    }

    #[test]
    fn call_with_context_passes_through() {
        let table = Router::new();
        table.register_with_context("check_ctx", |_payload: Vec<u8>, ctx: &dyn std::any::Any| {
            // Check that we can downcast the context
            if ctx.downcast_ref::<String>().is_some() {
                Ok(b"got string".to_vec())
            } else {
                Ok(b"no string".to_vec())
            }
        });

        let ctx = String::from("hello");
        let resp = table.call_with_context("check_ctx", vec![], &ctx).unwrap();
        assert_eq!(resp, b"got string");

        // call() passes &() as context, so downcast to String fails
        let resp = table.call("check_ctx", vec![]).unwrap();
        assert_eq!(resp, b"no string");
    }

    #[test]
    fn call_or_error_bytes_with_context_success() {
        let table = Router::new();
        table.register("echo", |payload: Vec<u8>| payload);
        let ctx = String::from("unused");
        let resp = table.call_or_error_bytes_with_context("echo", b"hello".to_vec(), &ctx);
        assert_eq!(resp, b"hello");
    }

    #[test]
    fn call_or_error_bytes_with_context_unknown() {
        let table = Router::new();
        let resp = table.call_or_error_bytes_with_context("nope", vec![], &());
        assert_eq!(resp, b"unknown command: nope");
    }

    #[test]
    fn register_replaces_handler() {
        let table = Router::new();
        table.register("cmd", |_payload: Vec<u8>| b"first".to_vec());
        table.register("cmd", |_payload: Vec<u8>| b"second".to_vec());
        let resp = table.call("cmd", vec![]).unwrap();
        assert_eq!(resp, b"second");
    }

    #[test]
    fn register_with_context_error_propagation() {
        let table = Router::new();
        table.register_with_context("fail", |_payload: Vec<u8>, _ctx: &dyn std::any::Any| {
            Err(Error::Handler("context handler failed".into()))
        });
        let err = table.call("fail", vec![]).unwrap_err();
        assert!(matches!(err, Error::Handler(ref msg) if msg == "context handler failed"));

        // Also verify error bytes path
        let bytes = table.call_or_error_bytes("fail", vec![]);
        assert_eq!(bytes, b"handler error: context handler failed");
    }

    // -- Trait impl and lock-scope tests -------------------------------------

    #[test]
    fn default_router_is_empty() {
        let table = Router::default();
        assert!(!table.has("echo"));
        assert_eq!(format!("{table:?}"), "Router { handler_count: 0 }");
    }

    #[test]
    fn debug_reports_handler_count() {
        let table = Router::new();
        table.register("a", |payload: Vec<u8>| payload);
        table.register_simple("b", || b"b".to_vec());
        // Replacing an existing name must not change the count.
        table.register("a", |_payload: Vec<u8>| b"a".to_vec());
        assert_eq!(format!("{table:?}"), "Router { handler_count: 2 }");
    }

    #[test]
    fn handler_can_reenter_router_via_context() {
        let table = Router::new();
        table.register("inner", |payload: Vec<u8>| payload);
        table.register_with_context("outer", |payload: Vec<u8>, ctx: &dyn std::any::Any| {
            let router = ctx
                .downcast_ref::<Router>()
                .ok_or_else(|| Error::Handler("context is not a Router".into()))?;
            // Both calls need the dispatch-table lock. They can only succeed
            // if dispatch released it before invoking this handler.
            router.register("late", |_payload: Vec<u8>| b"late".to_vec());
            router.call("inner", payload)
        });
        let resp = table
            .call_with_context("outer", b"nested".to_vec(), &table)
            .unwrap();
        assert_eq!(resp, b"nested");
        assert_eq!(table.call("late", vec![]).unwrap(), b"late");
    }
}