Skip to main content

eventide_application/
inmemory_command_bus.rs

1use std::{
2    any::{TypeId, type_name, type_name_of_val},
3    sync::Arc,
4};
5
6use async_trait::async_trait;
7use dashmap::DashMap;
8
9use crate::{
10    bus_types::{BoxAnySend, HandlerFn},
11    command_bus::CommandBus,
12    command_handler::CommandHandler,
13    context::AppContext,
14    error::AppError,
15};
16
17/// 基于内存的 CommandBus 实现
18/// - 通过 (CommandTypeId, ResultTypeId) 注册不同 Command 对应的 Handler
19/// - 运行时以类型擦除(Any)方式进行调度,并在调用端进行结果还原
20pub struct InMemoryCommandBus {
21    handlers: DashMap<(TypeId, TypeId), (&'static str, HandlerFn)>,
22}
23
24impl Default for InMemoryCommandBus {
25    fn default() -> Self {
26        Self {
27            handlers: DashMap::new(),
28        }
29    }
30}
31
32impl InMemoryCommandBus {
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    /// 注册命令处理器
38    pub fn register<C, R, H>(&self, handler: Arc<H>) -> Result<(), AppError>
39    where
40        C: Send + 'static,
41        R: Send + 'static,
42        H: CommandHandler<C, R> + Send + Sync + 'static,
43    {
44        let key = (TypeId::of::<C>(), TypeId::of::<R>());
45
46        let f: HandlerFn = {
47            let handler = handler.clone();
48
49            Arc::new(move |boxed_cmd, ctx| {
50                let handler = handler.clone();
51
52                Box::pin(async move {
53                    match boxed_cmd.downcast::<C>() {
54                        Ok(cmd) => {
55                            let result = handler.handle(ctx, *cmd).await?;
56                            Ok(Box::new(result) as BoxAnySend)
57                        }
58                        Err(e) => {
59                            let found = type_name_of_val(&e);
60                            Err(AppError::type_mismatch(type_name::<C>(), found))
61                        }
62                    }
63                })
64            })
65        };
66
67        if self.handlers.contains_key(&key) {
68            return Err(AppError::handler_already_registered(&format!(
69                "{}->{}",
70                type_name::<C>(),
71                type_name::<R>()
72            )));
73        }
74
75        self.handlers.insert(key, (type_name::<C>(), f));
76
77        Ok(())
78    }
79}
80
81#[async_trait]
82impl CommandBus for InMemoryCommandBus {
83    async fn dispatch<C, R>(&self, ctx: &AppContext, cmd: C) -> Result<R, AppError>
84    where
85        C: Send + 'static,
86        R: Send + 'static,
87    {
88        self.dispatch_impl::<C, R>(ctx, cmd).await
89    }
90}
91
92impl InMemoryCommandBus {
93    async fn dispatch_impl<C, R>(&self, ctx: &AppContext, cmd: C) -> Result<R, AppError>
94    where
95        C: Send + 'static,
96        R: Send + 'static,
97    {
98        let key = (TypeId::of::<C>(), TypeId::of::<R>());
99        let Some((_name, f)) = self.handlers.get(&key).map(|h| h.clone()) else {
100            return Err(AppError::handler_not_found(type_name::<C>()));
101        };
102
103        let out = (f)(Box::new(cmd), ctx).await?;
104
105        match out.downcast::<R>() {
106            Ok(result) => Ok(*result),
107            Err(e) => Err(AppError::type_mismatch(
108                type_name::<R>(),
109                type_name_of_val(&e),
110            )),
111        }
112    }
113}
114
115impl InMemoryCommandBus {
116    /// 获取已注册的命令类型名列表(只读视图)
117    pub fn registered_commands(&self) -> Vec<&'static str> {
118        self.handlers.iter().map(|e| e.value().0).collect()
119    }
120}
121
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use eventide_domain::error::ErrorCode;
    use tokio::task::JoinSet;

    use super::*;
    use crate::{command_handler::CommandHandler, error::AppError};

    #[derive(Debug)]
    struct Add;

    #[derive(Debug, PartialEq, Eq)]
    struct AddResult(pub usize);

    struct AddHandler {
        counter: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl CommandHandler<Add, AddResult> for AddHandler {
        async fn handle(&self, _ctx: &AppContext, _cmd: Add) -> Result<AddResult, AppError> {
            let v = self.counter.fetch_add(1, Ordering::SeqCst) + 1;
            Ok(AddResult(v))
        }
    }

    // Registering a handler and dispatching a command works end to end.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn register_and_dispatch_works() {
        let bus = InMemoryCommandBus::new();
        let counter = Arc::new(AtomicUsize::new(0));
        bus.register::<Add, AddResult, _>(Arc::new(AddHandler {
            counter: counter.clone(),
        }))
        .unwrap();

        let ctx = AppContext::default();
        let AddResult(n) = bus.dispatch::<Add, AddResult>(&ctx, Add).await.unwrap();
        assert_eq!(n, 1);
    }

    // Dispatching without a registered handler yields HANDLER_NOT_FOUND.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn not_found_error_when_unregistered() {
        let bus = InMemoryCommandBus::new();
        let ctx = AppContext::default();
        let err = bus.dispatch::<Add, AddResult>(&ctx, Add).await.unwrap_err();
        assert_eq!(err.code(), "HANDLER_NOT_FOUND");
        assert!(err.to_string().contains("Add"));
    }

    // Registering a second handler for the same (command, result) pair is rejected, and the
    // original handler stays in place.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn duplicate_registration_is_rejected() {
        let bus = InMemoryCommandBus::new();
        let first = Arc::new(AtomicUsize::new(0));
        let second = Arc::new(AtomicUsize::new(100));

        bus.register::<Add, AddResult, _>(Arc::new(AddHandler {
            counter: first.clone(),
        }))
        .unwrap();
        let dup = bus.register::<Add, AddResult, _>(Arc::new(AddHandler {
            counter: second.clone(),
        }));
        assert!(dup.is_err());

        let ctx = AppContext::default();
        let AddResult(n) = bus.dispatch::<Add, AddResult>(&ctx, Add).await.unwrap();
        assert_eq!(n, 1);
        assert_eq!(second.load(Ordering::SeqCst), 100);
        assert_eq!(bus.registered_commands(), vec![type_name::<Add>()]);
    }

    #[derive(Debug)]
    struct WrongResult;

    // A result that fails to downcast yields TYPE_MISMATCH.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn type_mismatch_error_when_result_downcast_fails() {
        let bus = InMemoryCommandBus::new();
        // Insert a bogus entry by hand: keyed as (Add, AddResult), but the closure returns
        // WrongResult.
        let f: HandlerFn = Arc::new(|_boxed_cmd, _ctx| {
            Box::pin(async move { Ok(Box::new(WrongResult) as BoxAnySend) })
        });
        bus.handlers.insert(
            (TypeId::of::<Add>(), TypeId::of::<AddResult>()),
            (type_name::<Add>(), f),
        );

        let ctx = AppContext::default();
        let err = bus.dispatch::<Add, AddResult>(&ctx, Add).await.unwrap_err();
        assert_eq!(err.code(), "TYPE_MISMATCH");
        assert!(err.to_string().contains("AddResult"));
    }

    // Concurrent dispatches are safe: every one runs exactly once.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_dispatch_is_safe() {
        let bus = Arc::new(InMemoryCommandBus::new());
        let counter = Arc::new(AtomicUsize::new(0));
        bus.register::<Add, AddResult, _>(Arc::new(AddHandler {
            counter: counter.clone(),
        }))
        .unwrap();

        let mut set = JoinSet::new();
        let ctx = AppContext::default();
        for _ in 0..100 {
            let bus = bus.clone();
            let ctx = ctx.clone();
            set.spawn(async move { bus.dispatch::<Add, AddResult>(&ctx, Add).await.unwrap() });
        }

        let mut results = Vec::new();
        while let Some(res) = set.join_next().await {
            results.push(res.unwrap().0);
        }
        results.sort_unstable();
        assert_eq!(results.len(), 100);
        assert_eq!(results[0], 1);
        assert_eq!(results[99], 100);
    }

    // A command whose handler returns `()`.
    #[derive(Debug)]
    struct VoidCmd;

    struct VoidHandler;

    #[async_trait]
    impl CommandHandler<VoidCmd, ()> for VoidHandler {
        async fn handle(&self, _ctx: &AppContext, _cmd: VoidCmd) -> Result<(), AppError> {
            Ok(())
        }
    }

    // Handlers with a `()` result can be registered and dispatched.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn void_result_works() {
        let bus = InMemoryCommandBus::new();
        bus.register::<VoidCmd, (), _>(Arc::new(VoidHandler))
            .unwrap();

        let ctx = AppContext::default();
        bus.dispatch::<VoidCmd, ()>(&ctx, VoidCmd).await.unwrap();
    }
}