Skip to main content

eventide_application/
inmemory_query_bus.rs

1use std::{
2    any::{TypeId, type_name, type_name_of_val},
3    sync::Arc,
4};
5
6use async_trait::async_trait;
7use dashmap::DashMap;
8
9use crate::{
10    bus_types::{BoxAnySend, HandlerFn},
11    context::AppContext,
12    error::AppError,
13    query_bus::QueryBus,
14    query_handler::QueryHandler,
15};
16
17/// 基于内存的 QueryBus 实现
18/// - 通过 TypeId 注册不同 Query 对应的 Handler
19/// - 以类型擦除方式调度,并在调用端进行结果还原
20pub struct InMemoryQueryBus {
21    // 使用 (QueryTypeId, ResultTypeId) 作为键,避免相同 Query 不同返回类型的冲突
22    handlers: DashMap<(TypeId, TypeId), (&'static str, HandlerFn)>,
23}
24
25impl Default for InMemoryQueryBus {
26    fn default() -> Self {
27        Self {
28            handlers: DashMap::new(),
29        }
30    }
31}
32
33impl InMemoryQueryBus {
34    pub fn new() -> Self {
35        Self::default()
36    }
37
38    /// 注册查询处理器
39    pub fn register<Q, R, H>(&self, handler: Arc<H>) -> Result<(), AppError>
40    where
41        Q: Send + 'static,
42        R: Send + 'static,
43        H: QueryHandler<Q, R> + Send + Sync + 'static,
44    {
45        let key = (TypeId::of::<Q>(), TypeId::of::<R>());
46
47        let f: HandlerFn = {
48            let handler = handler.clone();
49
50            Arc::new(move |boxed_q, ctx| {
51                let handler = handler.clone();
52
53                Box::pin(async move {
54                    match boxed_q.downcast::<Q>() {
55                        Ok(q) => {
56                            let dto_opt = handler.handle(ctx, *q).await?;
57                            Ok(Box::new(dto_opt) as BoxAnySend)
58                        }
59                        Err(e) => Err(AppError::type_mismatch(
60                            type_name::<Q>(),
61                            type_name_of_val(&e),
62                        )),
63                    }
64                })
65            })
66        };
67
68        if self.handlers.contains_key(&key) {
69            return Err(AppError::handler_already_registered(&format!(
70                "{}->{}",
71                type_name::<Q>(),
72                type_name::<R>()
73            )));
74        }
75
76        self.handlers.insert(key, (type_name::<Q>(), f));
77
78        Ok(())
79    }
80}
81
82#[async_trait]
83impl QueryBus for InMemoryQueryBus {
84    async fn dispatch<Q, R>(&self, ctx: &AppContext, q: Q) -> Result<R, AppError>
85    where
86        Q: Send + 'static,
87        R: Send + 'static,
88    {
89        self.dispatch_impl::<Q, R>(ctx, q).await
90    }
91}
92
93impl InMemoryQueryBus {
94    async fn dispatch_impl<Q, R>(&self, ctx: &AppContext, q: Q) -> Result<R, AppError>
95    where
96        Q: Send + 'static,
97        R: Send + 'static,
98    {
99        let key = (TypeId::of::<Q>(), TypeId::of::<R>());
100        let Some((_name, f)) = self.handlers.get(&key).map(|h| h.clone()) else {
101            return Err(AppError::handler_not_found(type_name::<Q>()));
102        };
103
104        let out = (f)(Box::new(q), ctx).await?;
105
106        match out.downcast::<R>() {
107            Ok(dto_opt) => Ok(*dto_opt),
108            Err(e) => Err(AppError::type_mismatch(
109                type_name::<R>(),
110                type_name_of_val(&e),
111            )),
112        }
113    }
114}
115
116impl InMemoryQueryBus {
117    /// 获取已注册的查询类型名列表(只读视图)
118    pub fn registered_queries(&self) -> Vec<&'static str> {
119        self.handlers.iter().map(|e| e.value().0).collect()
120    }
121}
122
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use eventide_domain::error::ErrorCode;
    use serde::Serialize;
    use tokio::task::JoinSet;

    use super::*;
    use crate::{error::AppError, query_handler::QueryHandler};

    #[derive(Debug)]
    struct Get;

    #[derive(Debug, Serialize)]
    struct NumDto(pub usize);

    /// Answers `Get` with the post-increment value of a shared counter, so tests
    /// can observe how many times (and which instance of) the handler ran.
    struct GetHandler {
        counter: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl QueryHandler<Get, NumDto> for GetHandler {
        async fn handle(&self, _ctx: &AppContext, _q: Get) -> Result<NumDto, AppError> {
            let v = self.counter.fetch_add(1, Ordering::SeqCst) + 1;
            Ok(NumDto(v))
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn register_and_dispatch_works() {
        let bus = InMemoryQueryBus::new();
        let counter = Arc::new(AtomicUsize::new(0));
        bus.register::<Get, NumDto, _>(Arc::new(GetHandler {
            counter: counter.clone(),
        }))
        .unwrap();

        let ctx = AppContext::default();
        let NumDto(n) = bus.dispatch::<Get, NumDto>(&ctx, Get).await.unwrap();
        assert_eq!(n, 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn not_found_error_when_unregistered() {
        let bus = InMemoryQueryBus::new();
        let ctx = AppContext::default();
        let err = bus.dispatch::<Get, NumDto>(&ctx, Get).await.unwrap_err();
        assert_eq!(err.code(), "HANDLER_NOT_FOUND");
        assert!(err.to_string().contains("Get"));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn duplicate_registration_is_rejected_and_first_handler_kept() {
        let bus = InMemoryQueryBus::new();
        let first = Arc::new(AtomicUsize::new(0));
        let second = Arc::new(AtomicUsize::new(100));

        bus.register::<Get, NumDto, _>(Arc::new(GetHandler {
            counter: Arc::clone(&first),
        }))
        .unwrap();
        let again = bus.register::<Get, NumDto, _>(Arc::new(GetHandler {
            counter: Arc::clone(&second),
        }));
        assert!(again.is_err());

        // The original handler must still be the one that answers.
        let ctx = AppContext::default();
        let NumDto(n) = bus.dispatch::<Get, NumDto>(&ctx, Get).await.unwrap();
        assert_eq!(n, 1);
        assert_eq!(second.load(Ordering::SeqCst), 100);
    }

    #[derive(Debug, Serialize)]
    struct WrongDto;

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn type_mismatch_error_when_result_downcast_fails() {
        let bus = InMemoryQueryBus::new();
        // Insert a bogus entry by hand: it is keyed as Get -> NumDto, but the
        // closure returns WrongDto instead of NumDto.
        let f: HandlerFn = Arc::new(|_boxed_q, _ctx| {
            Box::pin(async move { Ok(Box::new(WrongDto) as BoxAnySend) })
        });
        bus.handlers.insert(
            (TypeId::of::<Get>(), TypeId::of::<NumDto>()),
            (type_name::<Get>(), f),
        );

        let ctx = AppContext::default();
        let err = bus.dispatch::<Get, NumDto>(&ctx, Get).await.unwrap_err();
        assert_eq!(err.code(), "TYPE_MISMATCH");
        assert!(err.to_string().contains("NumDto"));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_dispatch_is_safe() {
        let bus = Arc::new(InMemoryQueryBus::new());
        let counter = Arc::new(AtomicUsize::new(0));
        bus.register::<Get, NumDto, _>(Arc::new(GetHandler {
            counter: counter.clone(),
        }))
        .unwrap();

        let mut set = JoinSet::new();
        let ctx = AppContext::default();
        for _ in 0..100 {
            let bus = bus.clone();
            let ctx = ctx.clone();
            set.spawn(async move { bus.dispatch::<Get, NumDto>(&ctx, Get).await.unwrap() });
        }
        let mut results = Vec::new();
        while let Some(res) = set.join_next().await {
            results.push(res.unwrap().0);
        }
        results.sort_unstable();
        assert_eq!(results.len(), 100);
        assert_eq!(results[0], 1);
        assert_eq!(results[99], 100);
    }

    #[derive(Debug)]
    struct Get2;

    #[derive(Debug, Serialize, PartialEq, Eq)]
    struct NameDto(pub String);

    struct Get2NumHandler;
    struct Get2NameHandler;

    #[async_trait]
    impl QueryHandler<Get2, NumDto> for Get2NumHandler {
        async fn handle(&self, _ctx: &AppContext, _q: Get2) -> Result<NumDto, AppError> {
            Ok(NumDto(42))
        }
    }

    #[async_trait]
    impl QueryHandler<Get2, NameDto> for Get2NameHandler {
        async fn handle(&self, _ctx: &AppContext, _q: Get2) -> Result<NameDto, AppError> {
            Ok(NameDto("Alice".to_string()))
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn same_query_with_different_results() {
        // One query type (Get2) with two handlers registered: one returns NumDto,
        // the other returns NameDto.
        let bus = InMemoryQueryBus::new();
        bus.register::<Get2, NumDto, _>(Arc::new(Get2NumHandler))
            .unwrap();
        bus.register::<Get2, NameDto, _>(Arc::new(Get2NameHandler))
            .unwrap();

        let ctx = AppContext::default();
        let NumDto(n) = bus.dispatch::<Get2, NumDto>(&ctx, Get2).await.unwrap();
        let NameDto(name) = bus.dispatch::<Get2, NameDto>(&ctx, Get2).await.unwrap();

        assert_eq!(n, 42);
        assert_eq!(name, "Alice");
    }

    #[test]
    fn registered_queries_lists_one_name_per_pair() {
        let bus = InMemoryQueryBus::new();
        assert!(bus.registered_queries().is_empty());

        bus.register::<Get2, NumDto, _>(Arc::new(Get2NumHandler))
            .unwrap();
        bus.register::<Get2, NameDto, _>(Arc::new(Get2NameHandler))
            .unwrap();

        let names = bus.registered_queries();
        assert_eq!(names.len(), 2);
        assert!(names.iter().all(|name| *name == type_name::<Get2>()));
    }
}