1use std::{
7 any::{Any, TypeId},
8 collections::HashMap,
9 sync::Arc,
10};
11
12use async_trait::async_trait;
13use tokio::sync::RwLock;
14
15use super::Event;
16
/// Marker trait for values that can be dispatched through a `CommandBus`.
///
/// The `'static` bound is what allows the bus to key handlers by `TypeId` and to carry a
/// command as `Box<dyn Any + Send>`; `Send + Sync` allow commands to move between threads.
pub trait Command: 'static + Send + Sync {}
19
/// A single field-level validation failure reported by a command handler.
#[derive(Clone, Debug)]
pub struct ValidationError {
    /// Name of the input field that failed validation.
    pub field: String,
    /// Human-readable description of the problem.
    pub message: String,
    /// Error code; `ValidationError::new` fills in `"validation_failed"`.
    pub code: String,
}
30
31impl ValidationError {
32 pub fn new(field: impl Into<String>, message: impl Into<String>) -> Self {
34 Self {
35 field: field.into(),
36 message: message.into(),
37 code: "validation_failed".to_string(),
38 }
39 }
40
41 pub fn with_code(
43 field: impl Into<String>,
44 message: impl Into<String>,
45 code: impl Into<String>,
46 ) -> Self {
47 Self {
48 field: field.into(),
49 message: message.into(),
50 code: code.into(),
51 }
52 }
53}
54
55pub type CommandResult<E> = Result<Vec<E>, CommandError>;
57
58#[derive(Debug, Clone)]
60pub enum CommandError {
61 Validation(Vec<ValidationError>),
63 BusinessLogic(String),
65 NotFound(String),
67 AlreadyExecuted(String),
69 Internal(String),
71}
72
73impl std::fmt::Display for CommandError {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 match self {
76 CommandError::Validation(errors) => {
77 write!(f, "Validation failed: ")?;
78 for (i, err) in errors.iter().enumerate() {
79 if i > 0 {
80 write!(f, ", ")?;
81 }
82 write!(f, "{}: {}", err.field, err.message)?;
83 }
84 Ok(())
85 }
86 CommandError::BusinessLogic(msg) => write!(f, "Business logic error: {}", msg),
87 CommandError::NotFound(msg) => write!(f, "Handler not found: {}", msg),
88 CommandError::AlreadyExecuted(msg) => write!(f, "Already executed: {}", msg),
89 CommandError::Internal(msg) => write!(f, "Internal error: {}", msg),
90 }
91 }
92}
93
94impl std::error::Error for CommandError {}
95
96#[async_trait]
98pub trait CommandHandler<C: Command, E: Event>: Send + Sync {
99 async fn handle(&self, command: C) -> CommandResult<E>;
101}
102
103#[async_trait]
105trait ErasedHandler<E: Event>: Send + Sync {
106 async fn handle_erased(&self, command: Box<dyn Any + Send>) -> CommandResult<E>;
107}
108
109struct HandlerWrapper<C: Command, E: Event, H: CommandHandler<C, E>> {
111 handler: Arc<H>,
112 _phantom: std::marker::PhantomData<(C, E)>,
113}
114
115#[async_trait]
116impl<C: Command, E: Event, H: CommandHandler<C, E>> ErasedHandler<E> for HandlerWrapper<C, E, H> {
117 async fn handle_erased(&self, command: Box<dyn Any + Send>) -> CommandResult<E> {
118 match command.downcast::<C>() {
119 Ok(cmd) => self.handler.handle(*cmd).await,
120 Err(_) => Err(CommandError::Internal(
121 "Type mismatch in command dispatch".to_string(),
122 )),
123 }
124 }
125}
126
127type HandlerMap<E> = HashMap<TypeId, Arc<dyn ErasedHandler<E>>>;
129
130pub struct CommandBus<E: Event> {
132 handlers: Arc<RwLock<HandlerMap<E>>>,
133 idempotency_keys: Arc<RwLock<HashMap<String, Vec<E>>>>,
134}
135
136impl<E: Event> CommandBus<E> {
137 pub fn new() -> Self {
139 Self {
140 handlers: Arc::new(RwLock::new(HashMap::new())),
141 idempotency_keys: Arc::new(RwLock::new(HashMap::new())),
142 }
143 }
144
145 pub async fn register<C: Command, H: CommandHandler<C, E> + 'static>(&self, handler: H) {
147 let type_id = TypeId::of::<C>();
148 let wrapper = HandlerWrapper {
149 handler: Arc::new(handler),
150 _phantom: std::marker::PhantomData,
151 };
152 let mut handlers = self.handlers.write().await;
153 handlers.insert(type_id, Arc::new(wrapper));
154 }
155
156 pub async fn dispatch<C: Command>(&self, command: C) -> CommandResult<E> {
158 let type_id = TypeId::of::<C>();
159 let handlers = self.handlers.read().await;
160
161 match handlers.get(&type_id) {
162 Some(handler) => {
163 let boxed_command: Box<dyn Any + Send> = Box::new(command);
164 handler.handle_erased(boxed_command).await
165 }
166 None => Err(CommandError::NotFound(format!(
167 "No handler registered for command type: {}",
168 std::any::type_name::<C>()
169 ))),
170 }
171 }
172
173 pub async fn dispatch_idempotent<C: Command>(
175 &self,
176 command: C,
177 idempotency_key: String,
178 ) -> CommandResult<E> {
179 {
181 let keys = self.idempotency_keys.read().await;
182 if let Some(events) = keys.get(&idempotency_key) {
183 return Ok(events.clone());
184 }
185 }
186
187 let events = self.dispatch(command).await?;
189
190 {
192 let mut keys = self.idempotency_keys.write().await;
193 keys.insert(idempotency_key, events.clone());
194 }
195
196 Ok(events)
197 }
198
199 pub async fn handlers_count(&self) -> usize {
201 self.handlers.read().await.len()
202 }
203}
204
205impl<E: Event> Default for CommandBus<E> {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
211impl<E: Event> Clone for CommandBus<E> {
212 fn clone(&self) -> Self {
213 Self {
214 handlers: Arc::clone(&self.handlers),
215 idempotency_keys: Arc::clone(&self.idempotency_keys),
216 }
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    // `Debug` + `PartialEq` let tests compare whole event lists, not just their lengths.
    #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    enum TestEvent {
        UserCreated { _id: String },
    }

    impl Event for TestEvent {}

    struct CreateUserCommand {
        email: String,
    }

    impl Command for CreateUserCommand {}

    fn create_user(email: &str) -> CreateUserCommand {
        CreateUserCommand {
            email: email.to_string(),
        }
    }

    /// Rejects an empty email; otherwise emits one `UserCreated` event with a fixed id.
    struct CreateUserHandler;

    #[async_trait]
    impl CommandHandler<CreateUserCommand, TestEvent> for CreateUserHandler {
        async fn handle(&self, command: CreateUserCommand) -> CommandResult<TestEvent> {
            if command.email.is_empty() {
                return Err(CommandError::Validation(vec![ValidationError::new(
                    "email",
                    "Email is required",
                )]));
            }

            Ok(vec![TestEvent::UserCreated {
                _id: "123".to_string(),
            }])
        }
    }

    /// Echoes the command's email as the event id and counts how often it actually runs, so
    /// tests can tell a replayed result apart from a re-execution.
    struct CountingHandler {
        calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl CommandHandler<CreateUserCommand, TestEvent> for CountingHandler {
        async fn handle(&self, command: CreateUserCommand) -> CommandResult<TestEvent> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(vec![TestEvent::UserCreated { _id: command.email }])
        }
    }

    #[tokio::test]
    async fn test_command_dispatch() {
        let bus = CommandBus::new();
        bus.register(CreateUserHandler).await;

        let events = bus.dispatch(create_user("test@example.com")).await.unwrap();

        assert_eq!(
            events,
            vec![TestEvent::UserCreated {
                _id: "123".to_string()
            }]
        );
    }

    #[tokio::test]
    async fn test_validation_error() {
        let bus = CommandBus::new();
        bus.register(CreateUserHandler).await;

        let result = bus.dispatch(create_user("")).await;

        assert!(matches!(result, Err(CommandError::Validation(_))));
    }

    #[tokio::test]
    async fn test_handler_not_found() {
        let bus: CommandBus<TestEvent> = CommandBus::new();

        let result = bus.dispatch(create_user("test@example.com")).await;

        assert!(matches!(
            result,
            Err(CommandError::NotFound(ref msg)) if msg.contains("CreateUserCommand")
        ));
    }

    #[tokio::test]
    async fn test_idempotency() {
        let calls = Arc::new(AtomicUsize::new(0));
        let bus = CommandBus::new();
        bus.register(CountingHandler {
            calls: Arc::clone(&calls),
        })
        .await;

        let first = bus
            .dispatch_idempotent(create_user("test@example.com"), "key1".to_string())
            .await
            .unwrap();
        let second = bus
            .dispatch_idempotent(create_user("different@example.com"), "key1".to_string())
            .await
            .unwrap();

        // The replay must return the first command's events without running the second command.
        assert_eq!(second, first);
        assert_eq!(
            first,
            vec![TestEvent::UserCreated {
                _id: "test@example.com".to_string()
            }]
        );
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_idempotency_distinct_keys_run_separately() {
        let calls = Arc::new(AtomicUsize::new(0));
        let bus = CommandBus::new();
        bus.register(CountingHandler {
            calls: Arc::clone(&calls),
        })
        .await;

        bus.dispatch_idempotent(create_user("a@example.com"), "key1".to_string())
            .await
            .unwrap();
        bus.dispatch_idempotent(create_user("b@example.com"), "key2".to_string())
            .await
            .unwrap();

        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_idempotency_does_not_cache_failures() {
        let bus = CommandBus::new();
        bus.register(CreateUserHandler).await;

        let failed = bus
            .dispatch_idempotent(create_user(""), "key1".to_string())
            .await;
        assert!(matches!(failed, Err(CommandError::Validation(_))));

        // The failure was not recorded, so a retry with the same key runs its command.
        let retried = bus
            .dispatch_idempotent(create_user("test@example.com"), "key1".to_string())
            .await;
        assert!(retried.is_ok());
    }
}