// mq_bridge/type_handler.rs

1use crate::response::ErgonomicResponse;
2use crate::traits::{Handled, Handler, HandlerError};
3use crate::{CanonicalMessage, MessageContext};
4use async_trait::async_trait;
5use futures::future::BoxFuture;
6use serde::de::DeserializeOwned;
7use std::collections::HashMap;
8use std::future::Future;
9use std::sync::Arc;
10
11/// A handler that dispatches messages to other handlers based on a metadata field (e.g., "type").
12///
13/// # Example
14/// ```rust
15/// use mq_bridge::type_handler::TypeHandler;
16/// use mq_bridge::{CanonicalMessage, Handled};
17/// use serde::Deserialize;
18///
19/// #[derive(Deserialize)]
20/// struct MyCommand { id: String }
21///
22/// # fn main() {
23/// let handler = TypeHandler::new()
24///     .add("my_command", |cmd: MyCommand| async move {
25///         println!("Received command: {}", cmd.id);
26///     }); // No return needed! Implicitly returns (), which maps to Handled::Ack.
27/// # }
28/// ```
29#[derive(Clone)]
30pub struct TypeHandler {
31    pub(crate) handlers: HashMap<String, Arc<dyn Handler>>,
32    pub(crate) type_key: String, // will be the key in msg metadata, default is "kind"
33    pub(crate) fallback: Option<Arc<dyn Handler>>,
34}
35
36pub const KIND_KEY: &str = "kind";
37
38/// A helper trait to allow registering handlers with or without context.
39pub trait IntoTypedHandler<T, Args>: Send + Sync + 'static {
40    type Future: Future<Output = Result<Handled, HandlerError>> + Send + 'static;
41    fn call(&self, msg: T, ctx: MessageContext) -> Self::Future;
42}
43
44/// Implementation for standard closures returning Result<Handled, HandlerError>.
45/// This path is unique and allows correct inference for legacy code.
46impl<F, Fut, T> IntoTypedHandler<T, (T, Result<Handled, HandlerError>)> for F
47where
48    T: DeserializeOwned + Send + Sync + 'static,
49    F: Fn(T) -> Fut + Send + Sync + 'static,
50    Fut: Future<Output = Result<Handled, HandlerError>> + Send + 'static,
51{
52    type Future = Fut;
53    fn call(&self, msg: T, _ctx: MessageContext) -> Self::Future {
54        (self)(msg)
55    }
56}
57
58impl<F, Fut, T> IntoTypedHandler<T, (T, MessageContext, Result<Handled, HandlerError>)> for F
59where
60    T: DeserializeOwned + Send + Sync + 'static,
61    F: Fn(T, MessageContext) -> Fut + Send + Sync + 'static,
62    Fut: Future<Output = Result<Handled, HandlerError>> + Send + 'static,
63{
64    type Future = Fut;
65    fn call(&self, msg: T, ctx: MessageContext) -> Self::Future {
66        (self)(msg, ctx)
67    }
68}
69
70/// Ergonomic implementation for closures returning types that implement ErgonomicResponse.
71impl<F, Fut, T, R> IntoTypedHandler<T, (T, R)> for F
72where
73    T: DeserializeOwned + Send + Sync + 'static,
74    R: ErgonomicResponse,
75    F: Fn(T) -> Fut + Send + Sync + 'static,
76    Fut: Future<Output = R> + Send + 'static,
77{
78    type Future = BoxFuture<'static, Result<Handled, HandlerError>>;
79    fn call(&self, msg: T, _ctx: MessageContext) -> Self::Future {
80        let fut = (self)(msg);
81        Box::pin(async move { fut.await.into_handler_result() })
82    }
83}
84
85impl<F, Fut, T, R> IntoTypedHandler<T, (T, MessageContext, R)> for F
86where
87    T: DeserializeOwned + Send + Sync + 'static,
88    R: ErgonomicResponse,
89    F: Fn(T, MessageContext) -> Fut + Send + Sync + 'static,
90    Fut: Future<Output = R> + Send + 'static,
91{
92    type Future = BoxFuture<'static, Result<Handled, HandlerError>>;
93    fn call(&self, msg: T, ctx: MessageContext) -> Self::Future {
94        let fut = (self)(msg, ctx);
95        Box::pin(async move { fut.await.into_handler_result() })
96    }
97}
98
99impl Default for TypeHandler {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105impl TypeHandler {
106    /// Creates a new TypeHandler that looks for the specified key in message metadata to determine the message type.
107    pub fn new() -> Self {
108        Self {
109            handlers: HashMap::new(),
110            type_key: KIND_KEY.into(),
111            fallback: None,
112        }
113    }
114
115    /// Registers a generic handler for a specific type name.
116    pub fn add_handler(mut self, type_name: &str, handler: impl Handler + 'static) -> Self {
117        self.handlers
118            .insert(type_name.to_string(), Arc::new(handler));
119        self
120    }
121
122    /// Sets a fallback handler to be used when no type match is found.
123    pub fn with_fallback(mut self, handler: Arc<dyn Handler>) -> Self {
124        self.fallback = Some(handler);
125        self
126    }
127
128    #[doc(hidden)]
129    pub fn add_simple<T, H, Args>(self, type_name: &str, handler: H) -> Self
130    where
131        T: DeserializeOwned + Send + Sync + 'static,
132        H: IntoTypedHandler<T, Args>,
133        Args: Send + Sync + 'static,
134    {
135        self.add(type_name, handler)
136    }
137
138    /// Registers a typed handler function.
139    ///
140    /// The handler can accept either:
141    /// - `fn(T) -> Future<Output = Result<Handled, HandlerError>>`
142    /// - `fn(T, MessageContext) -> Future<Output = Result<Handled, HandlerError>>`
143    pub fn add<T, H, Args>(mut self, type_name: &str, handler: H) -> Self
144    where
145        T: DeserializeOwned + Send + Sync + 'static,
146        H: IntoTypedHandler<T, Args>,
147        Args: Send + Sync + 'static,
148    {
149        let handler = Arc::new(handler);
150        let wrapper = move |msg: CanonicalMessage| {
151            let handler = handler.clone();
152            async move {
153                let data = msg.parse::<T>().map_err(|e| {
154                    HandlerError::NonRetryable(anyhow::anyhow!("Deserialization failed: {}", e))
155                })?;
156                let ctx = MessageContext::from(msg);
157                handler.call(data, ctx).await
158            }
159        };
160        self.handlers
161            .insert(type_name.to_string(), Arc::new(wrapper));
162        self
163    }
164}
165
166#[async_trait]
167impl Handler for TypeHandler {
168    async fn handle(&self, msg: CanonicalMessage) -> Result<Handled, HandlerError> {
169        if let Some(type_val) = msg.metadata.get(&self.type_key) {
170            if let Some(handler) = self.handlers.get(type_val) {
171                return handler.handle(msg).await;
172            }
173        }
174
175        if let Some(fallback) = &self.fallback {
176            return fallback.handle(msg).await;
177        }
178
179        Err(HandlerError::NonRetryable(anyhow::anyhow!(
180            "No handler registered for type: '{:?}' and no fallback provided",
181            msg.metadata.get(&self.type_key)
182        )))
183    }
184
185    fn register_handler(
186        &self,
187        type_name: &str,
188        handler: Arc<dyn Handler>,
189    ) -> Option<Arc<dyn Handler>> {
190        let mut th = self.clone();
191        th.handlers.insert(type_name.to_string(), handler);
192        Some(Arc::new(th))
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::msg;
    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize)]
    struct TestMsg {
        val: String,
    }

    #[tokio::test]
    async fn test_typed_handler_dispatch() {
        let handler = TypeHandler::new().add("test_a", |msg: TestMsg| async move {
            assert_eq!(msg.val, "hello");
            Ok(Handled::Ack)
        });

        let msg = msg!(
            &TestMsg {
                val: "hello".into(),
            },
            "test_a"
        );

        let res = handler.handle(msg).await;
        // Pin the exact outcome, not just "some Ok".
        assert!(matches!(res, Ok(Handled::Ack)));
    }

    #[tokio::test]
    async fn test_typed_handler_with_context() {
        let handler =
            TypeHandler::new().add("test_ctx", |msg: TestMsg, ctx: MessageContext| async move {
                assert_eq!(msg.val, "hello");
                assert_eq!(ctx.metadata.get("meta").map(|s| s.as_str()), Some("data"));
                Ok(Handled::Ack)
            });

        let msg = CanonicalMessage::from_type(&TestMsg {
            val: "hello".into(),
        })
        .unwrap()
        .with_metadata(HashMap::from([
            ("kind".to_string(), "test_ctx".to_string()),
            ("meta".to_string(), "data".to_string()),
        ]));

        let res = handler.handle(msg).await;
        assert!(matches!(res, Ok(Handled::Ack)));
    }

    #[tokio::test]
    async fn test_typed_handler_no_match_error() {
        let handler = TypeHandler::new();
        let msg = msg!(b"{}".to_vec(), "unknown");

        let res = handler.handle(msg).await;
        assert!(res.is_err());
        match res.unwrap_err() {
            HandlerError::NonRetryable(e) => {
                assert!(e.to_string().contains("No handler registered"))
            }
            _ => panic!("Expected NonRetryable error"),
        }
    }

    #[tokio::test]
    async fn test_typed_handler_fallback_ack() {
        let fallback = Arc::new(|_: CanonicalMessage| async { Ok(Handled::Ack) });
        let handler = TypeHandler::new().with_fallback(fallback);

        let msg = msg!(b"{}".to_vec(), "unknown");

        let res = handler.handle(msg).await;
        assert!(matches!(res, Ok(Handled::Ack)));
    }

    #[tokio::test]
    async fn test_typed_handler_fallback_on_missing_type_key() {
        let fallback = Arc::new(|_: CanonicalMessage| async { Ok(Handled::Ack) });
        let handler = TypeHandler::new()
            .add("test", |_: TestMsg| async { Ok(Handled::Ack) })
            .with_fallback(fallback);

        // No "kind" metadata at all: the fallback must still be consulted.
        let msg = CanonicalMessage::new(b"{}".to_vec(), None);

        let res = handler.handle(msg).await;
        assert!(matches!(res, Ok(Handled::Ack)));
    }

    #[tokio::test]
    async fn test_typed_handler_prefers_registered_over_fallback() {
        let fallback = Arc::new(|_: CanonicalMessage| async {
            Result::<Handled, HandlerError>::Err(HandlerError::Retryable(anyhow::anyhow!(
                "fallback should not run"
            )))
        });
        let handler = TypeHandler::new()
            .add("test_a", |_: TestMsg| async { Ok(Handled::Ack) })
            .with_fallback(fallback);

        let msg = msg!(&TestMsg { val: "x".into() }, "test_a");

        let res = handler.handle(msg).await;
        assert!(matches!(res, Ok(Handled::Ack)));
    }

    #[tokio::test]
    async fn test_typed_handler_failure() {
        let handler = TypeHandler::new().add("fail", |_: TestMsg| async {
            Result::<Handled, HandlerError>::Err(HandlerError::Retryable(anyhow::anyhow!(
                "failure"
            )))
        });

        let msg = CanonicalMessage::from_type(&TestMsg { val: "x".into() })
            .unwrap()
            .with_type_key("fail");

        let res = handler.handle(msg).await;
        assert!(matches!(res, Err(HandlerError::Retryable(_))));
    }

    #[tokio::test]
    async fn test_typed_handler_missing_type_key() {
        let handler = TypeHandler::new().add("test", |_: TestMsg| async { Ok(Handled::Ack) });

        // Message without "kind" metadata
        let msg = CanonicalMessage::new(b"{}".to_vec(), None);

        let res = handler.handle(msg).await;
        // Without a fallback this is a permanent routing failure.
        assert!(matches!(res, Err(HandlerError::NonRetryable(_))));
    }

    #[tokio::test]
    async fn test_typed_handler_deserialization_failure() {
        let handler = TypeHandler::new().add("test", |_: TestMsg| async { Ok(Handled::Ack) });

        // Invalid JSON for TestMsg (missing required field)
        let msg = CanonicalMessage::new(b"{}".to_vec(), None)
            .with_metadata(HashMap::from([("kind".to_string(), "test".to_string())]));

        let res = handler.handle(msg).await;
        assert!(matches!(res, Err(HandlerError::NonRetryable(_))));
    }

    #[tokio::test]
    async fn test_cqrs_pattern_example() {
        #[derive(Serialize, Deserialize)]
        struct SubmitOrder {
            id: u32,
        }

        #[derive(Serialize, Deserialize)]
        struct OrderSubmitted {
            id: u32,
        }

        // 1. Command Handler (Write Side)
        let command_bus = TypeHandler::new().add("submit_order", |cmd: SubmitOrder| async move {
            // Execute business logic...
            // Emit event
            let evt = OrderSubmitted { id: cmd.id };
            Ok(Handled::Publish(msg!(&evt, "order_submitted")))
        });

        // 2. Event Handler (Read Side / Projection)
        let projection_handler =
            TypeHandler::new().add("order_submitted", |evt: OrderSubmitted| async move {
                // Update read database / cache
                assert_eq!(evt.id, 101);
                Ok(Handled::Ack)
            });

        // Simulate incoming command
        let cmd = SubmitOrder { id: 101 };
        let cmd_msg = msg!(&cmd, "submit_order");

        // Process command
        let result = command_bus.handle(cmd_msg).await.unwrap();

        if let Handled::Publish(event_msg) = result {
            // Verify event type
            assert_eq!(
                event_msg.metadata.get("kind").map(|s| s.as_str()),
                Some("order_submitted")
            );

            // Process event (Projection)
            let proj_result = projection_handler.handle(event_msg).await.unwrap();
            assert!(matches!(proj_result, Handled::Ack));
        } else {
            panic!("Expected Handled::Publish");
        }
    }

    #[tokio::test]
    async fn test_cqrs_integration_with_routes() {
        use crate::models::{Endpoint, Route};
        use std::sync::atomic::{AtomicU32, Ordering};

        #[derive(Serialize, Deserialize)]
        struct SubmitOrder {
            id: u32,
        }

        #[derive(Serialize, Deserialize)]
        struct OrderSubmitted {
            id: u32,
        }

        // Shared state to verify projection update
        let read_model_state = Arc::new(AtomicU32::new(0));
        let read_model_clone = read_model_state.clone();

        // 1. Command Handler (Write Side)
        let command_handler =
            TypeHandler::new().add("submit_order", |cmd: SubmitOrder| async move {
                let evt = OrderSubmitted { id: cmd.id };
                Ok(Handled::Publish(msg!(&evt, "order_submitted")))
            });

        // 2. Event Handler (Read Side)
        let event_handler =
            TypeHandler::new().add("order_submitted", move |evt: OrderSubmitted| {
                let state = read_model_clone.clone();
                async move {
                    state.store(evt.id, Ordering::SeqCst);
                    Ok(Handled::Ack)
                }
            });

        // 3. Define Endpoints & Routes
        let cmd_in_ep = Endpoint::new_memory("cmd_in", 10);
        let event_bus_ep = Endpoint::new_memory("event_bus", 10);
        let proj_out_ep = Endpoint::new_memory("proj_out", 10);

        let command_route =
            Route::new(cmd_in_ep.clone(), event_bus_ep.clone()).with_handler(command_handler);

        let event_route =
            Route::new(event_bus_ep.clone(), proj_out_ep.clone()).with_handler(event_handler);

        // 4. Run Routes
        let h1 = tokio::spawn(async move {
            command_route
                .run_until_err("command_route", None, None)
                .await
        });
        let h2 =
            tokio::spawn(async move { event_route.run_until_err("event_route", None, None).await });

        // 5. Send Command
        let cmd_channel = cmd_in_ep.channel().unwrap();
        let cmd = SubmitOrder { id: 777 };
        let msg = CanonicalMessage::from_type(&cmd)
            .unwrap()
            .with_type_key("submit_order");
        cmd_channel.send_message(msg).await.unwrap();

        // 6. Wait for consistency
        let mut attempts = 0;
        while read_model_state.load(Ordering::SeqCst) != 777 && attempts < 50 {
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
            attempts += 1;
        }

        assert_eq!(read_model_state.load(Ordering::SeqCst), 777);

        // Cleanup
        cmd_channel.close();
        event_bus_ep.channel().unwrap().close();

        let _ = h1.await;
        let _ = h2.await;
    }
}