1use crate::traits::{Handled, Handler, HandlerError};
2use crate::{CanonicalMessage, MessageContext};
3use async_trait::async_trait;
4use serde::de::DeserializeOwned;
5use std::collections::HashMap;
6use std::future::Future;
7use std::sync::Arc;
8
9#[derive(Clone)]
27pub struct TypeHandler {
28 pub(crate) handlers: HashMap<String, Arc<dyn Handler>>,
29 pub(crate) type_key: String, pub(crate) fallback: Option<Arc<dyn Handler>>,
31}
32
33pub const KIND_KEY: &str = "kind";
34
35pub trait IntoTypedHandler<T, Args>: Send + Sync + 'static {
37 type Future: Future<Output = Result<Handled, HandlerError>> + Send + 'static;
38 fn call(&self, msg: T, ctx: MessageContext) -> Self::Future;
39}
40
41impl<F, Fut, T> IntoTypedHandler<T, (T,)> for F
42where
43 T: DeserializeOwned + Send + Sync + 'static,
44 F: Fn(T) -> Fut + Send + Sync + 'static,
45 Fut: Future<Output = Result<Handled, HandlerError>> + Send + 'static,
46{
47 type Future = Fut;
48 fn call(&self, msg: T, _ctx: MessageContext) -> Self::Future {
49 (self)(msg)
50 }
51}
52
53impl<F, Fut, T> IntoTypedHandler<T, (T, MessageContext)> for F
54where
55 T: DeserializeOwned + Send + Sync + 'static,
56 F: Fn(T, MessageContext) -> Fut + Send + Sync + 'static,
57 Fut: Future<Output = Result<Handled, HandlerError>> + Send + 'static,
58{
59 type Future = Fut;
60 fn call(&self, msg: T, ctx: MessageContext) -> Self::Future {
61 (self)(msg, ctx)
62 }
63}
64
65impl Default for TypeHandler {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl TypeHandler {
72 pub fn new() -> Self {
74 Self {
75 handlers: HashMap::new(),
76 type_key: KIND_KEY.into(),
77 fallback: None,
78 }
79 }
80
81 pub fn add_handler(mut self, type_name: &str, handler: impl Handler + 'static) -> Self {
83 self.handlers
84 .insert(type_name.to_string(), Arc::new(handler));
85 self
86 }
87
88 pub fn with_fallback(mut self, handler: Arc<dyn Handler>) -> Self {
90 self.fallback = Some(handler);
91 self
92 }
93
94 #[doc(hidden)]
95 pub fn add_simple<T, F, Fut>(self, type_name: &str, handler: F) -> Self
96 where
97 T: DeserializeOwned + Send + Sync + 'static,
98 F: Fn(T) -> Fut + Send + Sync + 'static,
99 Fut: Future<Output = Result<Handled, HandlerError>> + Send + 'static,
100 {
101 self.add(type_name, handler)
102 }
103
104 pub fn add<T, H, Args>(mut self, type_name: &str, handler: H) -> Self
110 where
111 T: DeserializeOwned + Send + Sync + 'static,
112 H: IntoTypedHandler<T, Args>,
113 Args: Send + Sync + 'static,
114 {
115 let handler = Arc::new(handler);
116 let wrapper = move |msg: CanonicalMessage| {
117 let handler = handler.clone();
118 async move {
119 let data = msg.parse::<T>().map_err(|e| {
120 HandlerError::NonRetryable(anyhow::anyhow!("Deserialization failed: {}", e))
121 })?;
122 let ctx = MessageContext::from(msg);
123 handler.call(data, ctx).await
124 }
125 };
126 self.handlers
127 .insert(type_name.to_string(), Arc::new(wrapper));
128 self
129 }
130}
131
132#[async_trait]
133impl Handler for TypeHandler {
134 async fn handle(&self, msg: CanonicalMessage) -> Result<Handled, HandlerError> {
135 if let Some(type_val) = msg.metadata.get(&self.type_key) {
136 if let Some(handler) = self.handlers.get(type_val) {
137 return handler.handle(msg).await;
138 }
139 }
140
141 if let Some(fallback) = &self.fallback {
142 return fallback.handle(msg).await;
143 }
144
145 Err(HandlerError::NonRetryable(anyhow::anyhow!(
146 "No handler registered for type: '{:?}' and no fallback provided",
147 msg.metadata.get(&self.type_key)
148 )))
149 }
150
151 fn register_handler(
152 &self,
153 type_name: &str,
154 handler: Arc<dyn Handler>,
155 ) -> Option<Arc<dyn Handler>> {
156 let mut th = self.clone();
157 th.handlers.insert(type_name.to_string(), handler);
158 Some(Arc::new(th))
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::msg;
    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize)]
    struct TestMsg {
        val: String,
    }

    #[tokio::test]
    async fn test_typed_handler_dispatch() {
        let handler = TypeHandler::new().add("test_a", |msg: TestMsg| async move {
            assert_eq!(msg.val, "hello");
            Ok(Handled::Ack)
        });

        let msg = msg!(
            &TestMsg {
                val: "hello".into(),
            },
            "test_a"
        );

        let res = handler.handle(msg).await;
        assert!(matches!(res, Ok(Handled::Ack)));
    }

    #[tokio::test]
    async fn test_typed_handler_with_context() {
        let handler =
            TypeHandler::new().add("test_ctx", |msg: TestMsg, ctx: MessageContext| async move {
                assert_eq!(msg.val, "hello");
                assert_eq!(ctx.metadata.get("meta").map(|s| s.as_str()), Some("data"));
                Ok(Handled::Ack)
            });

        let msg = CanonicalMessage::from_type(&TestMsg {
            val: "hello".into(),
        })
        .unwrap()
        .with_metadata(HashMap::from([
            ("kind".to_string(), "test_ctx".to_string()),
            ("meta".to_string(), "data".to_string()),
        ]));

        let res = handler.handle(msg).await;
        assert!(matches!(res, Ok(Handled::Ack)));
    }

    #[tokio::test]
    async fn test_typed_handler_no_match_error() {
        let handler = TypeHandler::new();
        let msg = msg!(b"{}".to_vec(), "unknown");

        let res = handler.handle(msg).await;
        assert!(res.is_err());
        match res.unwrap_err() {
            HandlerError::NonRetryable(e) => {
                assert!(e.to_string().contains("No handler registered"))
            }
            _ => panic!("Expected NonRetryable error"),
        }
    }

    #[tokio::test]
    async fn test_typed_handler_fallback_ack() {
        let fallback = Arc::new(|_: CanonicalMessage| async { Ok(Handled::Ack) });
        let handler = TypeHandler::new().with_fallback(fallback);

        let msg = msg!(b"{}".to_vec(), "unknown");

        let res = handler.handle(msg).await;
        assert!(matches!(res, Ok(Handled::Ack)));
    }

    #[tokio::test]
    async fn test_typed_handler_failure() {
        let handler = TypeHandler::new().add("fail", |_: TestMsg| async {
            Err(HandlerError::Retryable(anyhow::anyhow!("failure")))
        });

        let msg = CanonicalMessage::from_type(&TestMsg { val: "x".into() })
            .unwrap()
            .with_type_key("fail");

        let res = handler.handle(msg).await;
        assert!(matches!(res, Err(HandlerError::Retryable(_))));
    }

    #[tokio::test]
    async fn test_typed_handler_missing_type_key() {
        let handler = TypeHandler::new().add("test", |_: TestMsg| async { Ok(Handled::Ack) });

        let msg = CanonicalMessage::new(b"{}".to_vec(), None);

        // With no type metadata and no fallback, dispatch must fail permanently.
        let res = handler.handle(msg).await;
        assert!(matches!(res, Err(HandlerError::NonRetryable(_))));
    }

    #[tokio::test]
    async fn test_typed_handler_deserialization_failure() {
        let handler = TypeHandler::new().add("test", |_: TestMsg| async { Ok(Handled::Ack) });

        let msg = CanonicalMessage::new(b"{}".to_vec(), None)
            .with_metadata(HashMap::from([("kind".to_string(), "test".to_string())]));

        let res = handler.handle(msg).await;
        assert!(matches!(res, Err(HandlerError::NonRetryable(_))));
    }

    #[tokio::test]
    async fn test_cqrs_pattern_example() {
        #[derive(Serialize, Deserialize)]
        struct SubmitOrder {
            id: u32,
        }

        #[derive(Serialize, Deserialize)]
        struct OrderSubmitted {
            id: u32,
        }

        let command_bus = TypeHandler::new().add("submit_order", |cmd: SubmitOrder| async move {
            let evt = OrderSubmitted { id: cmd.id };
            Ok(Handled::Publish(msg!(&evt, "order_submitted")))
        });

        let projection_handler =
            TypeHandler::new().add("order_submitted", |evt: OrderSubmitted| async move {
                assert_eq!(evt.id, 101);
                Ok(Handled::Ack)
            });

        let cmd = SubmitOrder { id: 101 };
        let cmd_msg = msg!(&cmd, "submit_order");

        let result = command_bus.handle(cmd_msg).await.unwrap();

        if let Handled::Publish(event_msg) = result {
            assert_eq!(
                event_msg.metadata.get("kind").map(|s| s.as_str()),
                Some("order_submitted")
            );

            let proj_result = projection_handler.handle(event_msg).await.unwrap();
            assert!(matches!(proj_result, Handled::Ack));
        } else {
            panic!("Expected Handled::Publish");
        }
    }

    #[tokio::test]
    async fn test_cqrs_integration_with_routes() {
        use crate::models::{Endpoint, Route};
        use std::sync::atomic::{AtomicU32, Ordering};

        #[derive(Serialize, Deserialize)]
        struct SubmitOrder {
            id: u32,
        }

        #[derive(Serialize, Deserialize)]
        struct OrderSubmitted {
            id: u32,
        }

        let read_model_state = Arc::new(AtomicU32::new(0));
        let read_model_clone = read_model_state.clone();

        let command_handler =
            TypeHandler::new().add("submit_order", |cmd: SubmitOrder| async move {
                let evt = OrderSubmitted { id: cmd.id };
                Ok(Handled::Publish(msg!(&evt, "order_submitted")))
            });

        let event_handler =
            TypeHandler::new().add("order_submitted", move |evt: OrderSubmitted| {
                let state = read_model_clone.clone();
                async move {
                    state.store(evt.id, Ordering::SeqCst);
                    Ok(Handled::Ack)
                }
            });

        let cmd_in_ep = Endpoint::new_memory("cmd_in", 10);
        let event_bus_ep = Endpoint::new_memory("event_bus", 10);
        let proj_out_ep = Endpoint::new_memory("proj_out", 10);

        let command_route =
            Route::new(cmd_in_ep.clone(), event_bus_ep.clone()).with_handler(command_handler);

        let event_route =
            Route::new(event_bus_ep.clone(), proj_out_ep.clone()).with_handler(event_handler);

        let h1 = tokio::spawn(async move {
            command_route
                .run_until_err("command_route", None, None)
                .await
        });
        let h2 =
            tokio::spawn(async move { event_route.run_until_err("event_route", None, None).await });

        let cmd_channel = cmd_in_ep.channel().unwrap();
        let cmd = SubmitOrder { id: 777 };
        let msg = CanonicalMessage::from_type(&cmd)
            .unwrap()
            .with_type_key("submit_order");
        cmd_channel.send_message(msg).await.unwrap();

        let mut attempts = 0;
        while read_model_state.load(Ordering::SeqCst) != 777 && attempts < 50 {
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
            attempts += 1;
        }

        assert_eq!(read_model_state.load(Ordering::SeqCst), 777);

        cmd_channel.close();
        event_bus_ep.channel().unwrap().close();

        // Bound the shutdown wait so that a route which fails to stop becomes
        // a test failure instead of hanging the whole test run. The routes'
        // own exit results are deliberately ignored.
        let shutdown_timeout = std::time::Duration::from_secs(5);
        let _ = tokio::time::timeout(shutdown_timeout, h1)
            .await
            .expect("command_route did not stop after its input channel was closed");
        let _ = tokio::time::timeout(shutdown_timeout, h2)
            .await
            .expect("event_route did not stop after its input channel was closed");
    }
}