// mq_bridge/type_handler.rs

1use crate::traits::{Handled, Handler, HandlerError};
2use crate::{CanonicalMessage, MessageContext};
3use async_trait::async_trait;
4use serde::de::DeserializeOwned;
5use std::collections::HashMap;
6use std::future::Future;
7use std::sync::Arc;
8
9/// A handler that dispatches messages to other handlers based on a metadata field (e.g., "type").
10///
11/// # Example
12/// ```rust
13/// use mq_bridge::type_handler::TypeHandler;
14/// use mq_bridge::{CanonicalMessage, Handled};
15/// use serde::Deserialize;
16///
17/// #[derive(Deserialize)]
18/// struct MyCommand { id: String }
19///
20/// let handler = TypeHandler::new()
21///     .add("my_command", |cmd: MyCommand| async move {
22///         println!("Received command: {}", cmd.id);
23///         Ok(Handled::Ack)
24///     });
25/// ```
26#[derive(Clone)]
27pub struct TypeHandler {
28    pub(crate) handlers: HashMap<String, Arc<dyn Handler>>,
29    pub(crate) type_key: String, // will be the key in msg metadata, default is "kind"
30    pub(crate) fallback: Option<Arc<dyn Handler>>,
31}
32
33pub const KIND_KEY: &str = "kind";
34
35/// A helper trait to allow registering handlers with or without context.
36pub trait IntoTypedHandler<T, Args>: Send + Sync + 'static {
37    type Future: Future<Output = Result<Handled, HandlerError>> + Send + 'static;
38    fn call(&self, msg: T, ctx: MessageContext) -> Self::Future;
39}
40
41impl<F, Fut, T> IntoTypedHandler<T, (T,)> for F
42where
43    T: DeserializeOwned + Send + Sync + 'static,
44    F: Fn(T) -> Fut + Send + Sync + 'static,
45    Fut: Future<Output = Result<Handled, HandlerError>> + Send + 'static,
46{
47    type Future = Fut;
48    fn call(&self, msg: T, _ctx: MessageContext) -> Self::Future {
49        (self)(msg)
50    }
51}
52
53impl<F, Fut, T> IntoTypedHandler<T, (T, MessageContext)> for F
54where
55    T: DeserializeOwned + Send + Sync + 'static,
56    F: Fn(T, MessageContext) -> Fut + Send + Sync + 'static,
57    Fut: Future<Output = Result<Handled, HandlerError>> + Send + 'static,
58{
59    type Future = Fut;
60    fn call(&self, msg: T, ctx: MessageContext) -> Self::Future {
61        (self)(msg, ctx)
62    }
63}
64
65impl Default for TypeHandler {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl TypeHandler {
72    /// Creates a new TypeHandler that looks for the specified key in message metadata to determine the message type.
73    pub fn new() -> Self {
74        Self {
75            handlers: HashMap::new(),
76            type_key: KIND_KEY.into(),
77            fallback: None,
78        }
79    }
80
81    /// Registers a generic handler for a specific type name.
82    pub fn add_handler(mut self, type_name: &str, handler: impl Handler + 'static) -> Self {
83        self.handlers
84            .insert(type_name.to_string(), Arc::new(handler));
85        self
86    }
87
88    /// Sets a fallback handler to be used when no type match is found.
89    pub fn with_fallback(mut self, handler: Arc<dyn Handler>) -> Self {
90        self.fallback = Some(handler);
91        self
92    }
93
94    #[doc(hidden)]
95    pub fn add_simple<T, F, Fut>(self, type_name: &str, handler: F) -> Self
96    where
97        T: DeserializeOwned + Send + Sync + 'static,
98        F: Fn(T) -> Fut + Send + Sync + 'static,
99        Fut: Future<Output = Result<Handled, HandlerError>> + Send + 'static,
100    {
101        self.add(type_name, handler)
102    }
103
104    /// Registers a typed handler function.
105    ///
106    /// The handler can accept either:
107    /// - `fn(T) -> Future<Output = Result<Handled, HandlerError>>`
108    /// - `fn(T, MessageContext) -> Future<Output = Result<Handled, HandlerError>>`
109    pub fn add<T, H, Args>(mut self, type_name: &str, handler: H) -> Self
110    where
111        T: DeserializeOwned + Send + Sync + 'static,
112        H: IntoTypedHandler<T, Args>,
113        Args: Send + Sync + 'static,
114    {
115        let handler = Arc::new(handler);
116        let wrapper = move |msg: CanonicalMessage| {
117            let handler = handler.clone();
118            async move {
119                let data = msg.parse::<T>().map_err(|e| {
120                    HandlerError::NonRetryable(anyhow::anyhow!("Deserialization failed: {}", e))
121                })?;
122                let ctx = MessageContext::from(msg);
123                handler.call(data, ctx).await
124            }
125        };
126        self.handlers
127            .insert(type_name.to_string(), Arc::new(wrapper));
128        self
129    }
130}
131
132#[async_trait]
133impl Handler for TypeHandler {
134    async fn handle(&self, msg: CanonicalMessage) -> Result<Handled, HandlerError> {
135        if let Some(type_val) = msg.metadata.get(&self.type_key) {
136            if let Some(handler) = self.handlers.get(type_val) {
137                return handler.handle(msg).await;
138            }
139        }
140
141        if let Some(fallback) = &self.fallback {
142            return fallback.handle(msg).await;
143        }
144
145        Err(HandlerError::NonRetryable(anyhow::anyhow!(
146            "No handler registered for type: '{:?}' and no fallback provided",
147            msg.metadata.get(&self.type_key)
148        )))
149    }
150
151    fn register_handler(
152        &self,
153        type_name: &str,
154        handler: Arc<dyn Handler>,
155    ) -> Option<Arc<dyn Handler>> {
156        let mut th = self.clone();
157        th.handlers.insert(type_name.to_string(), handler);
158        Some(Arc::new(th))
159    }
160}
161
// Tests for `TypeHandler`: dispatch by the "kind" metadata key, context-aware handlers,
// fallback / no-match behaviour, error propagation, and CQRS-style usage with routes.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::msg;
    use serde::{Deserialize, Serialize};

    // Minimal payload used by most tests; `val` is a required field.
    #[derive(Serialize, Deserialize)]
    struct TestMsg {
        val: String,
    }

    // A message whose type matches a registered name reaches that typed handler
    // with its payload deserialized.
    #[tokio::test]
    async fn test_typed_handler_dispatch() {
        let handler = TypeHandler::new().add("test_a", |msg: TestMsg| async move {
            assert_eq!(msg.val, "hello");
            Ok(Handled::Ack)
        });

        // NOTE(review): `msg!(payload, kind)` appears to set the "kind" metadata
        // (see the assertion in `test_cqrs_pattern_example`) — confirm in the macro.
        let msg = msg!(
            &TestMsg {
                val: "hello".into(),
            },
            "test_a"
        );

        let res = handler.handle(msg).await;
        assert!(res.is_ok());
    }

    // Two-argument handlers also receive a `MessageContext` whose metadata mirrors
    // the message's metadata.
    #[tokio::test]
    async fn test_typed_handler_with_context() {
        let handler =
            TypeHandler::new().add("test_ctx", |msg: TestMsg, ctx: MessageContext| async move {
                assert_eq!(msg.val, "hello");
                assert_eq!(ctx.metadata.get("meta").map(|s| s.as_str()), Some("data"));
                Ok(Handled::Ack)
            });

        let msg = CanonicalMessage::from_type(&TestMsg {
            val: "hello".into(),
        })
        .unwrap()
        .with_metadata(HashMap::from([
            ("kind".to_string(), "test_ctx".to_string()),
            ("meta".to_string(), "data".to_string()),
        ]));

        let res = handler.handle(msg).await;
        assert!(res.is_ok());
    }

    // With no handlers and no fallback, an unknown type is rejected as NonRetryable.
    #[tokio::test]
    async fn test_typed_handler_no_match_error() {
        let handler = TypeHandler::new();
        let msg = msg!(b"{}".to_vec(), "unknown");

        let res = handler.handle(msg).await;
        assert!(res.is_err());
        match res.unwrap_err() {
            HandlerError::NonRetryable(e) => {
                assert!(e.to_string().contains("No handler registered"))
            }
            _ => panic!("Expected NonRetryable error"),
        }
    }

    // An unmatched type is routed to the fallback handler, whose result is returned.
    #[tokio::test]
    async fn test_typed_handler_fallback_ack() {
        let fallback = Arc::new(|_: CanonicalMessage| async { Ok(Handled::Ack) });
        let handler = TypeHandler::new().with_fallback(fallback);

        let msg = msg!(b"{}".to_vec(), "unknown");

        let res = handler.handle(msg).await;
        assert!(matches!(res, Ok(Handled::Ack)));
    }

    // Errors returned by a typed handler propagate unchanged (Retryable stays Retryable).
    #[tokio::test]
    async fn test_typed_handler_failure() {
        let handler = TypeHandler::new().add("fail", |_: TestMsg| async {
            Err(HandlerError::Retryable(anyhow::anyhow!("failure")))
        });

        // NOTE(review): assumes `CanonicalMessage::with_type_key` sets the "kind"
        // metadata entry — confirm.
        let msg = CanonicalMessage::from_type(&TestMsg { val: "x".into() })
            .unwrap()
            .with_type_key("fail");

        let res = handler.handle(msg).await;
        assert!(matches!(res, Err(HandlerError::Retryable(_))));
    }

    // A message without the type key (and no fallback) is an error.
    #[tokio::test]
    async fn test_typed_handler_missing_type_key() {
        let handler = TypeHandler::new().add("test", |_: TestMsg| async { Ok(Handled::Ack) });

        // Message without "kind" metadata
        let msg = CanonicalMessage::new(b"{}".to_vec(), None);

        let res = handler.handle(msg).await;
        assert!(res.is_err());
    }

    // A payload that cannot be deserialized into the handler's type is NonRetryable.
    #[tokio::test]
    async fn test_typed_handler_deserialization_failure() {
        let handler = TypeHandler::new().add("test", |_: TestMsg| async { Ok(Handled::Ack) });

        // Invalid JSON for TestMsg (missing required field)
        let msg = CanonicalMessage::new(b"{}".to_vec(), None)
            .with_metadata(HashMap::from([("kind".to_string(), "test".to_string())]));

        let res = handler.handle(msg).await;
        assert!(matches!(res, Err(HandlerError::NonRetryable(_))));
    }

    // CQRS example without routes: the command handler publishes an event message,
    // which is then fed directly into a projection handler.
    #[tokio::test]
    async fn test_cqrs_pattern_example() {
        #[derive(Serialize, Deserialize)]
        struct SubmitOrder {
            id: u32,
        }

        #[derive(Serialize, Deserialize)]
        struct OrderSubmitted {
            id: u32,
        }

        // 1. Command Handler (Write Side)
        let command_bus = TypeHandler::new().add("submit_order", |cmd: SubmitOrder| async move {
            // Execute business logic...
            // Emit event
            let evt = OrderSubmitted { id: cmd.id };
            Ok(Handled::Publish(msg!(&evt, "order_submitted")))
        });

        // 2. Event Handler (Read Side / Projection)
        let projection_handler =
            TypeHandler::new().add("order_submitted", |evt: OrderSubmitted| async move {
                // Update read database / cache
                assert_eq!(evt.id, 101);
                Ok(Handled::Ack)
            });

        // Simulate incoming command
        let cmd = SubmitOrder { id: 101 };
        let cmd_msg = msg!(&cmd, "submit_order");

        // Process command
        let result = command_bus.handle(cmd_msg).await.unwrap();

        if let Handled::Publish(event_msg) = result {
            // Verify event type
            assert_eq!(
                event_msg.metadata.get("kind").map(|s| s.as_str()),
                Some("order_submitted")
            );

            // Process event (Projection)
            let proj_result = projection_handler.handle(event_msg).await.unwrap();
            assert!(matches!(proj_result, Handled::Ack));
        } else {
            panic!("Expected Handled::Publish");
        }
    }

    // End-to-end CQRS flow over in-memory endpoints:
    // cmd_in -> (command route) -> event_bus -> (event route) -> proj_out.
    // The projection writes the event id into a shared atomic that the test polls.
    #[tokio::test]
    async fn test_cqrs_integration_with_routes() {
        use crate::models::{Endpoint, Route};
        use std::sync::atomic::{AtomicU32, Ordering};

        #[derive(Serialize, Deserialize)]
        struct SubmitOrder {
            id: u32,
        }

        #[derive(Serialize, Deserialize)]
        struct OrderSubmitted {
            id: u32,
        }

        // Shared state to verify projection update
        let read_model_state = Arc::new(AtomicU32::new(0));
        let read_model_clone = read_model_state.clone();

        // 1. Command Handler (Write Side)
        let command_handler =
            TypeHandler::new().add("submit_order", |cmd: SubmitOrder| async move {
                let evt = OrderSubmitted { id: cmd.id };
                Ok(Handled::Publish(msg!(&evt, "order_submitted")))
            });

        // 2. Event Handler (Read Side)
        let event_handler =
            TypeHandler::new().add("order_submitted", move |evt: OrderSubmitted| {
                // Clone per call so the returned future owns its handle to the state.
                let state = read_model_clone.clone();
                async move {
                    state.store(evt.id, Ordering::SeqCst);
                    Ok(Handled::Ack)
                }
            });

        // 3. Define Endpoints & Routes
        // NOTE(review): the second argument to `new_memory` looks like a channel
        // capacity — confirm against `Endpoint::new_memory`.
        let cmd_in_ep = Endpoint::new_memory("cmd_in", 10);
        let event_bus_ep = Endpoint::new_memory("event_bus", 10);
        let proj_out_ep = Endpoint::new_memory("proj_out", 10);

        let command_route =
            Route::new(cmd_in_ep.clone(), event_bus_ep.clone()).with_handler(command_handler);

        let event_route =
            Route::new(event_bus_ep.clone(), proj_out_ep.clone()).with_handler(event_handler);

        // 4. Run Routes
        let h1 = tokio::spawn(async move {
            command_route
                .run_until_err("command_route", None, None)
                .await
        });
        let h2 =
            tokio::spawn(async move { event_route.run_until_err("event_route", None, None).await });

        // 5. Send Command
        let cmd_channel = cmd_in_ep.channel().unwrap();
        let cmd = SubmitOrder { id: 777 };
        let msg = CanonicalMessage::from_type(&cmd)
            .unwrap()
            .with_type_key("submit_order");
        cmd_channel.send_message(msg).await.unwrap();

        // 6. Wait for consistency
        // Poll at most 50 times, 20 ms apart (about 1 s), for the projection to run.
        let mut attempts = 0;
        while read_model_state.load(Ordering::SeqCst) != 777 && attempts < 50 {
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
            attempts += 1;
        }

        assert_eq!(read_model_state.load(Ordering::SeqCst), 777);

        // Cleanup
        // Presumably closing the input channels lets `run_until_err` finish so the
        // spawned tasks can be joined; their results are deliberately ignored.
        cmd_channel.close();
        event_bus_ep.channel().unwrap().close();

        let _ = h1.await;
        let _ = h2.await;
    }
}