1use crate::response::ErgonomicResponse;
2use crate::traits::{Handled, Handler, HandlerError};
3use crate::{CanonicalMessage, MessageContext};
4use async_trait::async_trait;
5use futures::future::BoxFuture;
6use serde::de::DeserializeOwned;
7use std::collections::HashMap;
8use std::future::Future;
9use std::sync::Arc;
10
11#[derive(Clone)]
30pub struct TypeHandler {
31 pub(crate) handlers: HashMap<String, Arc<dyn Handler>>,
32 pub(crate) type_key: String, pub(crate) fallback: Option<Arc<dyn Handler>>,
34}
35
/// Metadata key that a freshly constructed [`TypeHandler`] reads to decide
/// which registered handler receives a message.
pub const KIND_KEY: &str = "kind";
37
38pub trait IntoTypedHandler<T, Args>: Send + Sync + 'static {
40 type Future: Future<Output = Result<Handled, HandlerError>> + Send + 'static;
41 fn call(&self, msg: T, ctx: MessageContext) -> Self::Future;
42}
43
44impl<F, Fut, T> IntoTypedHandler<T, (T, Result<Handled, HandlerError>)> for F
47where
48 T: DeserializeOwned + Send + Sync + 'static,
49 F: Fn(T) -> Fut + Send + Sync + 'static,
50 Fut: Future<Output = Result<Handled, HandlerError>> + Send + 'static,
51{
52 type Future = Fut;
53 fn call(&self, msg: T, _ctx: MessageContext) -> Self::Future {
54 (self)(msg)
55 }
56}
57
58impl<F, Fut, T> IntoTypedHandler<T, (T, MessageContext, Result<Handled, HandlerError>)> for F
59where
60 T: DeserializeOwned + Send + Sync + 'static,
61 F: Fn(T, MessageContext) -> Fut + Send + Sync + 'static,
62 Fut: Future<Output = Result<Handled, HandlerError>> + Send + 'static,
63{
64 type Future = Fut;
65 fn call(&self, msg: T, ctx: MessageContext) -> Self::Future {
66 (self)(msg, ctx)
67 }
68}
69
70impl<F, Fut, T, R> IntoTypedHandler<T, (T, R)> for F
72where
73 T: DeserializeOwned + Send + Sync + 'static,
74 R: ErgonomicResponse,
75 F: Fn(T) -> Fut + Send + Sync + 'static,
76 Fut: Future<Output = R> + Send + 'static,
77{
78 type Future = BoxFuture<'static, Result<Handled, HandlerError>>;
79 fn call(&self, msg: T, _ctx: MessageContext) -> Self::Future {
80 let fut = (self)(msg);
81 Box::pin(async move { fut.await.into_handler_result() })
82 }
83}
84
85impl<F, Fut, T, R> IntoTypedHandler<T, (T, MessageContext, R)> for F
86where
87 T: DeserializeOwned + Send + Sync + 'static,
88 R: ErgonomicResponse,
89 F: Fn(T, MessageContext) -> Fut + Send + Sync + 'static,
90 Fut: Future<Output = R> + Send + 'static,
91{
92 type Future = BoxFuture<'static, Result<Handled, HandlerError>>;
93 fn call(&self, msg: T, ctx: MessageContext) -> Self::Future {
94 let fut = (self)(msg, ctx);
95 Box::pin(async move { fut.await.into_handler_result() })
96 }
97}
98
99impl Default for TypeHandler {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105impl TypeHandler {
106 pub fn new() -> Self {
108 Self {
109 handlers: HashMap::new(),
110 type_key: KIND_KEY.into(),
111 fallback: None,
112 }
113 }
114
115 pub fn add_handler(mut self, type_name: &str, handler: impl Handler + 'static) -> Self {
117 self.handlers
118 .insert(type_name.to_string(), Arc::new(handler));
119 self
120 }
121
122 pub fn with_fallback(mut self, handler: Arc<dyn Handler>) -> Self {
124 self.fallback = Some(handler);
125 self
126 }
127
128 #[doc(hidden)]
129 pub fn add_simple<T, H, Args>(self, type_name: &str, handler: H) -> Self
130 where
131 T: DeserializeOwned + Send + Sync + 'static,
132 H: IntoTypedHandler<T, Args>,
133 Args: Send + Sync + 'static,
134 {
135 self.add(type_name, handler)
136 }
137
138 pub fn add<T, H, Args>(mut self, type_name: &str, handler: H) -> Self
144 where
145 T: DeserializeOwned + Send + Sync + 'static,
146 H: IntoTypedHandler<T, Args>,
147 Args: Send + Sync + 'static,
148 {
149 let handler = Arc::new(handler);
150 let wrapper = move |msg: CanonicalMessage| {
151 let handler = handler.clone();
152 async move {
153 let data = msg.parse::<T>().map_err(|e| {
154 HandlerError::NonRetryable(anyhow::anyhow!("Deserialization failed: {}", e))
155 })?;
156 let ctx = MessageContext::from(msg);
157 handler.call(data, ctx).await
158 }
159 };
160 self.handlers
161 .insert(type_name.to_string(), Arc::new(wrapper));
162 self
163 }
164}
165
166#[async_trait]
167impl Handler for TypeHandler {
168 async fn handle(&self, msg: CanonicalMessage) -> Result<Handled, HandlerError> {
169 if let Some(type_val) = msg.metadata.get(&self.type_key) {
170 if let Some(handler) = self.handlers.get(type_val) {
171 return handler.handle(msg).await;
172 }
173 }
174
175 if let Some(fallback) = &self.fallback {
176 return fallback.handle(msg).await;
177 }
178
179 Err(HandlerError::NonRetryable(anyhow::anyhow!(
180 "No handler registered for type: '{:?}' and no fallback provided",
181 msg.metadata.get(&self.type_key)
182 )))
183 }
184
185 fn register_handler(
186 &self,
187 type_name: &str,
188 handler: Arc<dyn Handler>,
189 ) -> Option<Arc<dyn Handler>> {
190 let mut th = self.clone();
191 th.handlers.insert(type_name.to_string(), handler);
192 Some(Arc::new(th))
193 }
194}
195
// Tests for `TypeHandler` routing: typed dispatch, context access, fallback,
// error propagation, and end-to-end CQRS-style wiring over in-memory routes.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::msg;
    use serde::{Deserialize, Serialize};

    // Payload type shared by most tests below.
    #[derive(Serialize, Deserialize)]
    struct TestMsg {
        val: String,
    }

    // A payload-only closure runs when the message's `kind` matches its
    // registered type name.
    #[tokio::test]
    async fn test_typed_handler_dispatch() {
        let handler = TypeHandler::new().add("test_a", |msg: TestMsg| async move {
            assert_eq!(msg.val, "hello");
            Ok(Handled::Ack)
        });

        // `msg!(payload, kind)` builds a message whose `kind` metadata is the
        // second argument (asserted directly in `test_cqrs_pattern_example`).
        let msg = msg!(
            &TestMsg {
                val: "hello".into(),
            },
            "test_a"
        );

        let res = handler.handle(msg).await;
        assert!(res.is_ok());
    }

    // A closure that also takes `MessageContext` can read the message metadata.
    #[tokio::test]
    async fn test_typed_handler_with_context() {
        let handler =
            TypeHandler::new().add("test_ctx", |msg: TestMsg, ctx: MessageContext| async move {
                assert_eq!(msg.val, "hello");
                assert_eq!(ctx.metadata.get("meta").map(|s| s.as_str()), Some("data"));
                Ok(Handled::Ack)
            });

        let msg = CanonicalMessage::from_type(&TestMsg {
            val: "hello".into(),
        })
        .unwrap()
        .with_metadata(HashMap::from([
            ("kind".to_string(), "test_ctx".to_string()),
            ("meta".to_string(), "data".to_string()),
        ]));

        let res = handler.handle(msg).await;
        assert!(res.is_ok());
    }

    // With no routes and no fallback, dispatch fails with a non-retryable error.
    #[tokio::test]
    async fn test_typed_handler_no_match_error() {
        let handler = TypeHandler::new();
        let msg = msg!(b"{}".to_vec(), "unknown");

        let res = handler.handle(msg).await;
        assert!(res.is_err());
        match res.unwrap_err() {
            HandlerError::NonRetryable(e) => {
                assert!(e.to_string().contains("No handler registered"))
            }
            _ => panic!("Expected NonRetryable error"),
        }
    }

    // An unmatched type is delegated to the fallback handler.
    #[tokio::test]
    async fn test_typed_handler_fallback_ack() {
        let fallback = Arc::new(|_: CanonicalMessage| async { Ok(Handled::Ack) });
        let handler = TypeHandler::new().with_fallback(fallback);

        let msg = msg!(b"{}".to_vec(), "unknown");

        let res = handler.handle(msg).await;
        assert!(matches!(res, Ok(Handled::Ack)));
    }

    // Errors returned by the typed handler propagate unchanged (retryable stays
    // retryable).
    #[tokio::test]
    async fn test_typed_handler_failure() {
        let handler = TypeHandler::new().add("fail", |_: TestMsg| async {
            Result::<Handled, HandlerError>::Err(HandlerError::Retryable(anyhow::anyhow!(
                "failure"
            )))
        });

        // NOTE(review): presumably `with_type_key` sets the `kind` metadata
        // entry — confirm against `CanonicalMessage`.
        let msg = CanonicalMessage::from_type(&TestMsg { val: "x".into() })
            .unwrap()
            .with_type_key("fail");

        let res = handler.handle(msg).await;
        assert!(matches!(res, Err(HandlerError::Retryable(_))));
    }

    // A message with no type metadata matches nothing, and there is no fallback.
    #[tokio::test]
    async fn test_typed_handler_missing_type_key() {
        let handler = TypeHandler::new().add("test", |_: TestMsg| async { Ok(Handled::Ack) });

        let msg = CanonicalMessage::new(b"{}".to_vec(), None);

        let res = handler.handle(msg).await;
        assert!(res.is_err());
    }

    // The route matches, but `{}` has no `val` field, so parsing `TestMsg`
    // fails (assuming a JSON payload) and the route reports NonRetryable.
    #[tokio::test]
    async fn test_typed_handler_deserialization_failure() {
        let handler = TypeHandler::new().add("test", |_: TestMsg| async { Ok(Handled::Ack) });

        let msg = CanonicalMessage::new(b"{}".to_vec(), None)
            .with_metadata(HashMap::from([("kind".to_string(), "test".to_string())]));

        let res = handler.handle(msg).await;
        assert!(matches!(res, Err(HandlerError::NonRetryable(_))));
    }

    // A command handler publishes an event, and a second router consumes it.
    // No transport is involved; the two routers are invoked directly.
    #[tokio::test]
    async fn test_cqrs_pattern_example() {
        #[derive(Serialize, Deserialize)]
        struct SubmitOrder {
            id: u32,
        }

        #[derive(Serialize, Deserialize)]
        struct OrderSubmitted {
            id: u32,
        }

        let command_bus = TypeHandler::new().add("submit_order", |cmd: SubmitOrder| async move {
            // Translate the command into an event tagged for the projection router.
            let evt = OrderSubmitted { id: cmd.id };
            Ok(Handled::Publish(msg!(&evt, "order_submitted")))
        });

        let projection_handler =
            TypeHandler::new().add("order_submitted", |evt: OrderSubmitted| async move {
                assert_eq!(evt.id, 101);
                Ok(Handled::Ack)
            });

        let cmd = SubmitOrder { id: 101 };
        let cmd_msg = msg!(&cmd, "submit_order");

        let result = command_bus.handle(cmd_msg).await.unwrap();

        if let Handled::Publish(event_msg) = result {
            // The published event carries its own routing key.
            assert_eq!(
                event_msg.metadata.get("kind").map(|s| s.as_str()),
                Some("order_submitted")
            );

            let proj_result = projection_handler.handle(event_msg).await.unwrap();
            assert!(matches!(proj_result, Handled::Ack));
        } else {
            panic!("Expected Handled::Publish");
        }
    }

    // End-to-end wiring over in-memory endpoints:
    // cmd_in -> command handler -> event_bus -> event handler -> proj_out.
    // The event handler records the order id in a shared atomic "read model".
    #[tokio::test]
    async fn test_cqrs_integration_with_routes() {
        use crate::models::{Endpoint, Route};
        use std::sync::atomic::{AtomicU32, Ordering};

        #[derive(Serialize, Deserialize)]
        struct SubmitOrder {
            id: u32,
        }

        #[derive(Serialize, Deserialize)]
        struct OrderSubmitted {
            id: u32,
        }

        let read_model_state = Arc::new(AtomicU32::new(0));
        let read_model_clone = read_model_state.clone();

        let command_handler =
            TypeHandler::new().add("submit_order", |cmd: SubmitOrder| async move {
                let evt = OrderSubmitted { id: cmd.id };
                Ok(Handled::Publish(msg!(&evt, "order_submitted")))
            });

        let event_handler =
            TypeHandler::new().add("order_submitted", move |evt: OrderSubmitted| {
                // Clone per call so the returned future owns its handle.
                let state = read_model_clone.clone();
                async move {
                    state.store(evt.id, Ordering::SeqCst);
                    Ok(Handled::Ack)
                }
            });

        let cmd_in_ep = Endpoint::new_memory("cmd_in", 10);
        let event_bus_ep = Endpoint::new_memory("event_bus", 10);
        let proj_out_ep = Endpoint::new_memory("proj_out", 10);

        let command_route =
            Route::new(cmd_in_ep.clone(), event_bus_ep.clone()).with_handler(command_handler);

        let event_route =
            Route::new(event_bus_ep.clone(), proj_out_ep.clone()).with_handler(event_handler);

        let h1 = tokio::spawn(async move {
            command_route
                .run_until_err("command_route", None, None)
                .await
        });
        let h2 =
            tokio::spawn(async move { event_route.run_until_err("event_route", None, None).await });

        let cmd_channel = cmd_in_ep.channel().unwrap();
        let cmd = SubmitOrder { id: 777 };
        let msg = CanonicalMessage::from_type(&cmd)
            .unwrap()
            .with_type_key("submit_order");
        cmd_channel.send_message(msg).await.unwrap();

        // Poll the read model for up to ~1s (50 x 20 ms) rather than sleeping a
        // fixed amount.
        let mut attempts = 0;
        while read_model_state.load(Ordering::SeqCst) != 777 && attempts < 50 {
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
            attempts += 1;
        }

        assert_eq!(read_model_state.load(Ordering::SeqCst), 777);

        // NOTE(review): presumably closing the input channels lets
        // `run_until_err` return so the spawned tasks finish — confirm.
        cmd_channel.close();
        event_bus_ep.channel().unwrap().close();

        // Route results are intentionally ignored; only the projection is asserted.
        let _ = h1.await;
        let _ = h2.await;
    }
}
442}