1use std::{
2 any::{TypeId, type_name, type_name_of_val},
3 sync::Arc,
4};
5
6use async_trait::async_trait;
7use dashmap::DashMap;
8
9use crate::{
10 bus_types::{BoxAnySend, HandlerFn},
11 command_bus::CommandBus,
12 command_handler::CommandHandler,
13 context::AppContext,
14 error::AppError,
15};
16
17pub struct InMemoryCommandBus {
21 handlers: DashMap<(TypeId, TypeId), (&'static str, HandlerFn)>,
22}
23
24impl Default for InMemoryCommandBus {
25 fn default() -> Self {
26 Self {
27 handlers: DashMap::new(),
28 }
29 }
30}
31
32impl InMemoryCommandBus {
33 pub fn new() -> Self {
34 Self::default()
35 }
36
37 pub fn register<C, R, H>(&self, handler: Arc<H>) -> Result<(), AppError>
39 where
40 C: Send + 'static,
41 R: Send + 'static,
42 H: CommandHandler<C, R> + Send + Sync + 'static,
43 {
44 let key = (TypeId::of::<C>(), TypeId::of::<R>());
45
46 let f: HandlerFn = {
47 let handler = handler.clone();
48
49 Arc::new(move |boxed_cmd, ctx| {
50 let handler = handler.clone();
51
52 Box::pin(async move {
53 match boxed_cmd.downcast::<C>() {
54 Ok(cmd) => {
55 let result = handler.handle(ctx, *cmd).await?;
56 Ok(Box::new(result) as BoxAnySend)
57 }
58 Err(e) => {
59 let found = type_name_of_val(&e);
60 Err(AppError::type_mismatch(type_name::<C>(), found))
61 }
62 }
63 })
64 })
65 };
66
67 if self.handlers.contains_key(&key) {
68 return Err(AppError::handler_already_registered(&format!(
69 "{}->{}",
70 type_name::<C>(),
71 type_name::<R>()
72 )));
73 }
74
75 self.handlers.insert(key, (type_name::<C>(), f));
76
77 Ok(())
78 }
79}
80
81#[async_trait]
82impl CommandBus for InMemoryCommandBus {
83 async fn dispatch<C, R>(&self, ctx: &AppContext, cmd: C) -> Result<R, AppError>
84 where
85 C: Send + 'static,
86 R: Send + 'static,
87 {
88 self.dispatch_impl::<C, R>(ctx, cmd).await
89 }
90}
91
92impl InMemoryCommandBus {
93 async fn dispatch_impl<C, R>(&self, ctx: &AppContext, cmd: C) -> Result<R, AppError>
94 where
95 C: Send + 'static,
96 R: Send + 'static,
97 {
98 let key = (TypeId::of::<C>(), TypeId::of::<R>());
99 let Some((_name, f)) = self.handlers.get(&key).map(|h| h.clone()) else {
100 return Err(AppError::handler_not_found(type_name::<C>()));
101 };
102
103 let out = (f)(Box::new(cmd), ctx).await?;
104
105 match out.downcast::<R>() {
106 Ok(result) => Ok(*result),
107 Err(e) => Err(AppError::type_mismatch(
108 type_name::<R>(),
109 type_name_of_val(&e),
110 )),
111 }
112 }
113}
114
115impl InMemoryCommandBus {
116 pub fn registered_commands(&self) -> Vec<&'static str> {
118 self.handlers.iter().map(|e| e.value().0).collect()
119 }
120}
121
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use eventide_domain::error::ErrorCode;
    use tokio::task::JoinSet;

    use super::*;
    use crate::{command_handler::CommandHandler, error::AppError};

    /// Test command that bumps a shared counter.
    #[derive(Debug)]
    struct Add;

    /// Result of [`Add`]: the counter value after the increment.
    #[derive(Debug, PartialEq, Eq)]
    struct AddResult(pub usize);

    struct AddHandler {
        counter: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl CommandHandler<Add, AddResult> for AddHandler {
        async fn handle(&self, _ctx: &AppContext, _cmd: Add) -> Result<AddResult, AppError> {
            let v = self.counter.fetch_add(1, Ordering::SeqCst) + 1;
            Ok(AddResult(v))
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn register_and_dispatch_works() {
        let bus = InMemoryCommandBus::new();
        let counter = Arc::new(AtomicUsize::new(0));
        bus.register::<Add, AddResult, _>(Arc::new(AddHandler {
            counter: counter.clone(),
        }))
        .unwrap();

        let ctx = AppContext::default();
        let AddResult(n) = bus.dispatch::<Add, AddResult>(&ctx, Add).await.unwrap();
        assert_eq!(n, 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn not_found_error_when_unregistered() {
        let bus = InMemoryCommandBus::new();
        let ctx = AppContext::default();
        let err = bus.dispatch::<Add, AddResult>(&ctx, Add).await.unwrap_err();
        assert_eq!(err.code(), "HANDLER_NOT_FOUND");
        assert!(err.to_string().contains("Add"));
    }

    /// Deliberately wrong output type, used to force a result-downcast failure.
    #[derive(Debug)]
    struct WrongResult;

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn type_mismatch_error_when_result_downcast_fails() {
        let bus = InMemoryCommandBus::new();
        // Bypass `register` so the stored handler can return a value of the wrong type.
        let f: HandlerFn = Arc::new(|_boxed_cmd, _ctx| {
            Box::pin(async move { Ok(Box::new(WrongResult) as BoxAnySend) })
        });
        bus.handlers.insert(
            (TypeId::of::<Add>(), TypeId::of::<AddResult>()),
            (type_name::<Add>(), f),
        );

        let ctx = AppContext::default();
        let err = bus.dispatch::<Add, AddResult>(&ctx, Add).await.unwrap_err();
        assert_eq!(err.code(), "TYPE_MISMATCH");
        assert!(err.to_string().contains("AddResult"));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_dispatch_is_safe() {
        let bus = Arc::new(InMemoryCommandBus::new());
        let counter = Arc::new(AtomicUsize::new(0));
        bus.register::<Add, AddResult, _>(Arc::new(AddHandler {
            counter: counter.clone(),
        }))
        .unwrap();

        let mut set = JoinSet::new();
        let ctx = AppContext::default();
        for _ in 0..100 {
            let bus = bus.clone();
            let ctx = ctx.clone();
            set.spawn(async move { bus.dispatch::<Add, AddResult>(&ctx, Add).await.unwrap() });
        }

        let mut results = Vec::new();
        while let Some(res) = set.join_next().await {
            results.push(res.unwrap().0);
        }
        results.sort_unstable();
        assert_eq!(results.len(), 100);
        assert_eq!(results[0], 1);
        assert_eq!(results[99], 100);
    }

    /// Test command whose handler returns the unit type.
    #[derive(Debug)]
    struct VoidCmd;

    struct VoidHandler;

    #[async_trait]
    impl CommandHandler<VoidCmd, ()> for VoidHandler {
        async fn handle(&self, _ctx: &AppContext, _cmd: VoidCmd) -> Result<(), AppError> {
            Ok(())
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn void_result_works() {
        let bus = InMemoryCommandBus::new();
        bus.register::<VoidCmd, (), _>(Arc::new(VoidHandler))
            .unwrap();

        let ctx = AppContext::default();
        bus.dispatch::<VoidCmd, ()>(&ctx, VoidCmd).await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn duplicate_registration_is_rejected_and_keeps_first_handler() {
        let bus = InMemoryCommandBus::new();
        let first = Arc::new(AtomicUsize::new(0));
        let second = Arc::new(AtomicUsize::new(0));
        bus.register::<Add, AddResult, _>(Arc::new(AddHandler {
            counter: first.clone(),
        }))
        .unwrap();

        let duplicate = bus.register::<Add, AddResult, _>(Arc::new(AddHandler {
            counter: second.clone(),
        }));
        assert!(duplicate.is_err());

        // The originally registered handler must still be the one that runs.
        let ctx = AppContext::default();
        let AddResult(n) = bus.dispatch::<Add, AddResult>(&ctx, Add).await.unwrap();
        assert_eq!(n, 1);
        assert_eq!(first.load(Ordering::SeqCst), 1);
        assert_eq!(second.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn registered_commands_lists_command_type_names() {
        let bus = InMemoryCommandBus::new();
        assert!(bus.registered_commands().is_empty());

        bus.register::<VoidCmd, (), _>(Arc::new(VoidHandler))
            .unwrap();
        assert_eq!(bus.registered_commands(), vec![type_name::<VoidCmd>()]);
    }
}