1use std::{
2 any::{TypeId, type_name, type_name_of_val},
3 sync::Arc,
4};
5
6use async_trait::async_trait;
7use dashmap::DashMap;
8
9use crate::{
10 bus_types::{BoxAnySend, HandlerFn},
11 context::AppContext,
12 error::AppError,
13 query_bus::QueryBus,
14 query_handler::QueryHandler,
15};
16
17pub struct InMemoryQueryBus {
21 handlers: DashMap<(TypeId, TypeId), (&'static str, HandlerFn)>,
23}
24
25impl Default for InMemoryQueryBus {
26 fn default() -> Self {
27 Self {
28 handlers: DashMap::new(),
29 }
30 }
31}
32
33impl InMemoryQueryBus {
34 pub fn new() -> Self {
35 Self::default()
36 }
37
38 pub fn register<Q, R, H>(&self, handler: Arc<H>) -> Result<(), AppError>
40 where
41 Q: Send + 'static,
42 R: Send + 'static,
43 H: QueryHandler<Q, R> + Send + Sync + 'static,
44 {
45 let key = (TypeId::of::<Q>(), TypeId::of::<R>());
46
47 let f: HandlerFn = {
48 let handler = handler.clone();
49
50 Arc::new(move |boxed_q, ctx| {
51 let handler = handler.clone();
52
53 Box::pin(async move {
54 match boxed_q.downcast::<Q>() {
55 Ok(q) => {
56 let dto_opt = handler.handle(ctx, *q).await?;
57 Ok(Box::new(dto_opt) as BoxAnySend)
58 }
59 Err(e) => Err(AppError::type_mismatch(
60 type_name::<Q>(),
61 type_name_of_val(&e),
62 )),
63 }
64 })
65 })
66 };
67
68 if self.handlers.contains_key(&key) {
69 return Err(AppError::handler_already_registered(&format!(
70 "{}->{}",
71 type_name::<Q>(),
72 type_name::<R>()
73 )));
74 }
75
76 self.handlers.insert(key, (type_name::<Q>(), f));
77
78 Ok(())
79 }
80}
81
82#[async_trait]
83impl QueryBus for InMemoryQueryBus {
84 async fn dispatch<Q, R>(&self, ctx: &AppContext, q: Q) -> Result<R, AppError>
85 where
86 Q: Send + 'static,
87 R: Send + 'static,
88 {
89 self.dispatch_impl::<Q, R>(ctx, q).await
90 }
91}
92
93impl InMemoryQueryBus {
94 async fn dispatch_impl<Q, R>(&self, ctx: &AppContext, q: Q) -> Result<R, AppError>
95 where
96 Q: Send + 'static,
97 R: Send + 'static,
98 {
99 let key = (TypeId::of::<Q>(), TypeId::of::<R>());
100 let Some((_name, f)) = self.handlers.get(&key).map(|h| h.clone()) else {
101 return Err(AppError::handler_not_found(type_name::<Q>()));
102 };
103
104 let out = (f)(Box::new(q), ctx).await?;
105
106 match out.downcast::<R>() {
107 Ok(dto_opt) => Ok(*dto_opt),
108 Err(e) => Err(AppError::type_mismatch(
109 type_name::<R>(),
110 type_name_of_val(&e),
111 )),
112 }
113 }
114}
115
116impl InMemoryQueryBus {
117 pub fn registered_queries(&self) -> Vec<&'static str> {
119 self.handlers.iter().map(|e| e.value().0).collect()
120 }
121}
122
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use eventide_domain::error::ErrorCode;
    use serde::Serialize;
    use tokio::task::JoinSet;

    use super::*;
    use crate::{error::AppError, query_handler::QueryHandler};

    /// Unit query used by most tests.
    #[derive(Debug)]
    struct Get;

    /// Result carrying the handler's invocation count.
    #[derive(Debug, Serialize)]
    struct NumDto(pub usize);

    /// Handler that returns how many times it has been invoked (1-based).
    struct GetHandler {
        counter: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl QueryHandler<Get, NumDto> for GetHandler {
        async fn handle(&self, _ctx: &AppContext, _q: Get) -> Result<NumDto, AppError> {
            let v = self.counter.fetch_add(1, Ordering::SeqCst) + 1;
            Ok(NumDto(v))
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn register_and_dispatch_works() {
        let bus = InMemoryQueryBus::new();
        let counter = Arc::new(AtomicUsize::new(0));
        bus.register::<Get, NumDto, _>(Arc::new(GetHandler {
            counter: counter.clone(),
        }))
        .unwrap();

        let ctx = AppContext::default();
        let NumDto(n) = bus.dispatch::<Get, NumDto>(&ctx, Get).await.unwrap();
        assert_eq!(n, 1);
        // The handler must have run exactly once.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn duplicate_registration_is_rejected() {
        let bus = InMemoryQueryBus::new();
        let first = Arc::new(AtomicUsize::new(0));
        let second = Arc::new(AtomicUsize::new(0));

        bus.register::<Get, NumDto, _>(Arc::new(GetHandler {
            counter: first.clone(),
        }))
        .unwrap();
        let duplicate = bus.register::<Get, NumDto, _>(Arc::new(GetHandler {
            counter: second.clone(),
        }));
        assert!(duplicate.is_err());

        // The original handler must still be the one that serves the query.
        let ctx = AppContext::default();
        let NumDto(n) = bus.dispatch::<Get, NumDto>(&ctx, Get).await.unwrap();
        assert_eq!(n, 1);
        assert_eq!(first.load(Ordering::SeqCst), 1);
        assert_eq!(second.load(Ordering::SeqCst), 0);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn not_found_error_when_unregistered() {
        let bus = InMemoryQueryBus::new();
        let ctx = AppContext::default();
        let err = bus.dispatch::<Get, NumDto>(&ctx, Get).await.unwrap_err();
        assert_eq!(err.code(), "HANDLER_NOT_FOUND");
        assert!(err.to_string().contains("Get"));
    }

    /// Result type deliberately different from what the caller asks for.
    #[derive(Debug, Serialize)]
    struct WrongDto;

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn type_mismatch_error_when_result_downcast_fails() {
        let bus = InMemoryQueryBus::new();
        // Bypass `register` and plant a handler under the (Get, NumDto) key
        // that actually returns a `WrongDto`.
        let f: HandlerFn = Arc::new(|_boxed_q, _ctx| {
            Box::pin(async move { Ok(Box::new(WrongDto) as BoxAnySend) })
        });
        bus.handlers.insert(
            (TypeId::of::<Get>(), TypeId::of::<NumDto>()),
            (type_name::<Get>(), f),
        );

        let ctx = AppContext::default();
        let err = bus.dispatch::<Get, NumDto>(&ctx, Get).await.unwrap_err();
        assert_eq!(err.code(), "TYPE_MISMATCH");
        assert!(err.to_string().contains("NumDto"));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_dispatch_is_safe() {
        let bus = Arc::new(InMemoryQueryBus::new());
        let counter = Arc::new(AtomicUsize::new(0));
        bus.register::<Get, NumDto, _>(Arc::new(GetHandler {
            counter: counter.clone(),
        }))
        .unwrap();

        let mut set = JoinSet::new();
        let ctx = AppContext::default();
        for _ in 0..100 {
            let bus = bus.clone();
            let ctx = ctx.clone();
            set.spawn(async move { bus.dispatch::<Get, NumDto>(&ctx, Get).await.unwrap() });
        }
        let mut results = Vec::new();
        while let Some(res) = set.join_next().await {
            results.push(res.unwrap().0);
        }
        results.sort_unstable();

        // Every dispatch must observe a distinct counter value. Checking only
        // the ends would miss duplicated or skipped values in between.
        assert_eq!(results, (1..=100).collect::<Vec<usize>>());
        assert_eq!(counter.load(Ordering::SeqCst), 100);
    }

    /// Query that is registered with two different result types.
    #[derive(Debug)]
    struct Get2;

    #[derive(Debug, Serialize, PartialEq, Eq)]
    struct NameDto(pub String);

    struct Get2NumHandler;
    struct Get2NameHandler;

    #[async_trait]
    impl QueryHandler<Get2, NumDto> for Get2NumHandler {
        async fn handle(&self, _ctx: &AppContext, _q: Get2) -> Result<NumDto, AppError> {
            Ok(NumDto(42))
        }
    }

    #[async_trait]
    impl QueryHandler<Get2, NameDto> for Get2NameHandler {
        async fn handle(&self, _ctx: &AppContext, _q: Get2) -> Result<NameDto, AppError> {
            Ok(NameDto("Alice".to_string()))
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn same_query_with_different_results() {
        let bus = InMemoryQueryBus::new();
        // The key is (query, result), so both registrations must coexist.
        bus.register::<Get2, NumDto, _>(Arc::new(Get2NumHandler))
            .unwrap();
        bus.register::<Get2, NameDto, _>(Arc::new(Get2NameHandler))
            .unwrap();

        let ctx = AppContext::default();
        let NumDto(n) = bus.dispatch::<Get2, NumDto>(&ctx, Get2).await.unwrap();
        let NameDto(name) = bus.dispatch::<Get2, NameDto>(&ctx, Get2).await.unwrap();

        assert_eq!(n, 42);
        assert_eq!(name, "Alice");
    }
}