1use std::{
7 any::{Any, TypeId},
8 collections::HashMap,
9 sync::Arc,
10};
11
12use async_trait::async_trait;
13use tokio::sync::RwLock;
14
15use super::Event;
16
/// Marker trait for command types that can be sent through a [`CommandBus`].
///
/// The `'static` bound is what allows a command to be identified by its
/// `TypeId` and boxed as `dyn Any` during dispatch; `Send + Sync` allow it to
/// move across tasks.
pub trait Command: 'static + Send + Sync {}
19
/// A single field-level validation failure reported by a command handler.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Name of the offending field.
    pub field: String,
    /// Human-readable description of the failure.
    pub message: String,
    /// Machine-readable error code; `"validation_failed"` when built with
    /// [`ValidationError::new`].
    pub code: String,
}

impl ValidationError {
    /// Creates a validation error carrying the default code `"validation_failed"`.
    pub fn new(field: impl Into<String>, message: impl Into<String>) -> Self {
        Self::with_code(field, message, "validation_failed")
    }

    /// Creates a validation error with an explicit machine-readable `code`.
    pub fn with_code(
        field: impl Into<String>,
        message: impl Into<String>,
        code: impl Into<String>,
    ) -> Self {
        Self {
            field: field.into(),
            message: message.into(),
            code: code.into(),
        }
    }
}

/// Outcome of handling a command: the events it produced, or why it failed.
pub type CommandResult<E> = Result<Vec<E>, CommandError>;

/// Reasons a command can fail to produce events.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommandError {
    /// One or more fields of the command failed validation.
    Validation(Vec<ValidationError>),
    /// A business-rule failure reported by a handler.
    BusinessLogic(String),
    /// Something could not be found; the bus itself returns this when no
    /// handler is registered for the dispatched command type.
    NotFound(String),
    /// The command was reported as already executed (never raised by the bus
    /// itself in this module).
    AlreadyExecuted(String),
    /// An unexpected internal failure, e.g. a command type mismatch during
    /// type-erased dispatch.
    Internal(String),
}
72
73impl std::fmt::Display for CommandError {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 match self {
76 CommandError::Validation(errors) => {
77 write!(f, "Validation failed: ")?;
78 for (i, err) in errors.iter().enumerate() {
79 if i > 0 {
80 write!(f, ", ")?;
81 }
82 write!(f, "{}: {}", err.field, err.message)?;
83 }
84 Ok(())
85 }
86 CommandError::BusinessLogic(msg) => write!(f, "Business logic error: {}", msg),
87 CommandError::NotFound(msg) => write!(f, "Handler not found: {}", msg),
88 CommandError::AlreadyExecuted(msg) => write!(f, "Already executed: {}", msg),
89 CommandError::Internal(msg) => write!(f, "Internal error: {}", msg),
90 }
91 }
92}
93
94impl std::error::Error for CommandError {}
95
96#[async_trait]
98pub trait CommandHandler<C: Command, E: Event>: Send + Sync {
99 async fn handle(&self, command: C) -> CommandResult<E>;
101}
102
103#[async_trait]
105trait ErasedHandler<E: Event>: Send + Sync {
106 async fn handle_erased(&self, command: Box<dyn Any + Send>) -> CommandResult<E>;
107}
108
/// Adapts a typed `CommandHandler<C, E>` to the type-erased `ErasedHandler<E>`
/// interface stored in the bus registry.
///
/// The struct itself carries no trait bounds; they are required only by the
/// `ErasedHandler` impl, which is the one place that needs them.
struct HandlerWrapper<C, E, H> {
    /// The wrapped, strongly typed handler.
    handler: Arc<H>,
    // Binds the otherwise-unused command (`C`) and event (`E`) parameters.
    _phantom: std::marker::PhantomData<(C, E)>,
}
114
115#[async_trait]
116impl<C: Command, E: Event, H: CommandHandler<C, E>> ErasedHandler<E> for HandlerWrapper<C, E, H> {
117 async fn handle_erased(&self, command: Box<dyn Any + Send>) -> CommandResult<E> {
118 match command.downcast::<C>() {
119 Ok(cmd) => self.handler.handle(*cmd).await,
120 Err(_) => Err(CommandError::Internal(
121 "Type mismatch in command dispatch".to_string(),
122 )),
123 }
124 }
125}
126
127type HandlerMap<E> = HashMap<TypeId, Arc<dyn ErasedHandler<E>>>;
129
130pub struct CommandBus<E: Event> {
132 handlers: Arc<RwLock<HandlerMap<E>>>,
133 idempotency_keys: Arc<RwLock<HashMap<String, Vec<E>>>>,
134}
135
136impl<E: Event> CommandBus<E> {
137 pub fn new() -> Self {
139 Self {
140 handlers: Arc::new(RwLock::new(HashMap::new())),
141 idempotency_keys: Arc::new(RwLock::new(HashMap::new())),
142 }
143 }
144
145 pub async fn register<C: Command, H: CommandHandler<C, E> + 'static>(&self, handler: H) {
147 let type_id = TypeId::of::<C>();
148 let wrapper = HandlerWrapper {
149 handler: Arc::new(handler),
150 _phantom: std::marker::PhantomData,
151 };
152 let mut handlers = self.handlers.write().await;
153 handlers.insert(type_id, Arc::new(wrapper));
154 }
155
156 pub async fn dispatch<C: Command>(&self, command: C) -> CommandResult<E> {
158 let type_id = TypeId::of::<C>();
159 let handlers = self.handlers.read().await;
160
161 match handlers.get(&type_id) {
162 Some(handler) => {
163 let boxed_command: Box<dyn Any + Send> = Box::new(command);
164 handler.handle_erased(boxed_command).await
165 }
166 None => Err(CommandError::NotFound(format!(
167 "No handler registered for command type: {}",
168 std::any::type_name::<C>()
169 ))),
170 }
171 }
172
173 pub async fn dispatch_idempotent<C: Command>(
175 &self,
176 command: C,
177 idempotency_key: String,
178 ) -> CommandResult<E> {
179 {
181 let keys = self.idempotency_keys.read().await;
182 if let Some(events) = keys.get(&idempotency_key) {
183 return Ok(events.clone());
184 }
185 }
186
187 let events = self.dispatch(command).await?;
189
190 {
192 let mut keys = self.idempotency_keys.write().await;
193 keys.insert(idempotency_key, events.clone());
194 }
195
196 Ok(events)
197 }
198
199 pub async fn handlers_count(&self) -> usize {
201 self.handlers.read().await.len()
202 }
203}
204
205impl<E: Event> Default for CommandBus<E> {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
211impl<E: Event> Clone for CommandBus<E> {
212 fn clone(&self) -> Self {
213 Self {
214 handlers: Arc::clone(&self.handlers),
215 idempotency_keys: Arc::clone(&self.idempotency_keys),
216 }
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cqrs::EventTypeName;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Clone, serde::Serialize, serde::Deserialize)]
    enum TestEvent {
        UserCreated { _id: String },
    }

    impl EventTypeName for TestEvent {}
    impl Event for TestEvent {}

    struct CreateUserCommand {
        email: String,
    }

    impl Command for CreateUserCommand {}

    fn create_user(email: &str) -> CreateUserCommand {
        CreateUserCommand {
            email: email.to_string(),
        }
    }

    /// Rejects an empty email; otherwise emits a single `UserCreated` event.
    struct CreateUserHandler;

    #[async_trait]
    impl CommandHandler<CreateUserCommand, TestEvent> for CreateUserHandler {
        async fn handle(&self, command: CreateUserCommand) -> CommandResult<TestEvent> {
            if command.email.is_empty() {
                return Err(CommandError::Validation(vec![ValidationError::new(
                    "email",
                    "Email is required",
                )]));
            }

            Ok(vec![TestEvent::UserCreated {
                _id: "123".to_string(),
            }])
        }
    }

    /// Always succeeds and records how many times it ran, so tests can tell
    /// whether the bus actually re-executed a command.
    struct CountingHandler {
        calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl CommandHandler<CreateUserCommand, TestEvent> for CountingHandler {
        async fn handle(&self, _command: CreateUserCommand) -> CommandResult<TestEvent> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(vec![TestEvent::UserCreated {
                _id: "123".to_string(),
            }])
        }
    }

    #[tokio::test]
    async fn test_command_dispatch() {
        let bus = CommandBus::new();
        bus.register(CreateUserHandler).await;

        let result = bus.dispatch(create_user("test@example.com")).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_validation_error() {
        let bus = CommandBus::new();
        bus.register(CreateUserHandler).await;

        let result = bus.dispatch(create_user("")).await;

        assert!(matches!(result, Err(CommandError::Validation(_))));
    }

    #[tokio::test]
    async fn test_handler_not_found() {
        let bus: CommandBus<TestEvent> = CommandBus::new();

        let result = bus.dispatch(create_user("test@example.com")).await;

        assert!(matches!(result, Err(CommandError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_idempotency() {
        let bus = CommandBus::new();
        bus.register(CreateUserHandler).await;

        // Same key, different payload: the cached result of the first call is returned.
        let result1 = bus
            .dispatch_idempotent(create_user("test@example.com"), "key1".to_string())
            .await
            .unwrap();
        let result2 = bus
            .dispatch_idempotent(create_user("different@example.com"), "key1".to_string())
            .await
            .unwrap();

        assert_eq!(result1.len(), result2.len());
    }

    #[tokio::test]
    async fn test_idempotency_runs_handler_once_per_key() {
        let calls = Arc::new(AtomicUsize::new(0));
        let bus = CommandBus::new();
        bus.register(CountingHandler {
            calls: Arc::clone(&calls),
        })
        .await;

        for email in ["test@example.com", "different@example.com"] {
            bus.dispatch_idempotent(create_user(email), "key1".to_string())
                .await
                .unwrap();
        }
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // A different key is a different logical request and executes again.
        bus.dispatch_idempotent(create_user("test@example.com"), "key2".to_string())
            .await
            .unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_idempotency_does_not_cache_failures() {
        let bus = CommandBus::new();
        bus.register(CreateUserHandler).await;

        let first = bus
            .dispatch_idempotent(create_user(""), "key1".to_string())
            .await;
        assert!(matches!(first, Err(CommandError::Validation(_))));

        // The failed attempt must not poison the key: a retry executes normally.
        let second = bus
            .dispatch_idempotent(create_user("test@example.com"), "key1".to_string())
            .await;
        assert_eq!(second.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_register_replaces_existing_handler() {
        let calls = Arc::new(AtomicUsize::new(0));
        let bus = CommandBus::new();
        bus.register(CreateUserHandler).await;
        bus.register(CountingHandler {
            calls: Arc::clone(&calls),
        })
        .await;

        assert_eq!(bus.handlers_count().await, 1);
        // The empty email would fail validation in `CreateUserHandler`; success
        // proves the later registration won.
        let result = bus.dispatch(create_user("")).await;
        assert_eq!(result.unwrap().len(), 1);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}