1use std::any::{Any, TypeId};
4use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10
11use crate::aggregate::Aggregate;
12use crate::error::DispatchError;
13use crate::store::AggregateStore;
14
15#[derive(Debug, Clone, Default, Serialize, Deserialize)]
37pub struct CommandContext {
38 pub actor: Option<String>,
40 pub correlation_id: Option<String>,
42 pub metadata: Option<Value>,
44 #[serde(skip_serializing_if = "Option::is_none", default)]
50 pub source_device: Option<String>,
51}
52
53impl CommandContext {
54 pub fn with_actor(mut self, actor: impl Into<String>) -> Self {
65 self.actor = Some(actor.into());
66 self
67 }
68
69 pub fn with_correlation_id(mut self, id: impl Into<String>) -> Self {
80 self.correlation_id = Some(id.into());
81 self
82 }
83
84 pub fn with_metadata(mut self, meta: Value) -> Self {
95 self.metadata = Some(meta);
96 self
97 }
98
99 pub fn with_source_device(mut self, device_id: impl Into<String>) -> Self {
115 self.source_device = Some(device_id.into());
116 self
117 }
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct CommandEnvelope {
135 pub aggregate_type: String,
137 pub instance_id: String,
139 pub command: Value,
141 pub context: CommandContext,
143}
144
145trait CommandRoute: Send + Sync {
153 fn dispatch<'a>(
155 &'a self,
156 store: &'a AggregateStore,
157 instance_id: &'a str,
158 cmd: Box<dyn Any + Send>,
159 ctx: CommandContext,
160 ) -> Pin<Box<dyn Future<Output = Result<(), DispatchError>> + Send + 'a>>;
161}
162
163struct TypedCommandRoute<A: Aggregate> {
168 _marker: std::marker::PhantomData<A>,
169}
170
171impl<A: Aggregate> CommandRoute for TypedCommandRoute<A> {
172 fn dispatch<'a>(
173 &'a self,
174 store: &'a AggregateStore,
175 instance_id: &'a str,
176 cmd: Box<dyn Any + Send>,
177 ctx: CommandContext,
178 ) -> Pin<Box<dyn Future<Output = Result<(), DispatchError>> + Send + 'a>> {
179 Box::pin(async move {
180 let typed_cmd = cmd
185 .downcast::<A::Command>()
186 .map_err(|_| DispatchError::UnknownCommand)?;
187
188 let handle = store
189 .get::<A>(instance_id)
190 .await
191 .map_err(DispatchError::Io)?;
192
193 handle
194 .execute(*typed_cmd, ctx)
195 .await
196 .map_err(|e| DispatchError::Execution(Box::new(e)))?;
197
198 Ok(())
199 })
200 }
201}
202
203pub struct CommandBus {
223 store: AggregateStore,
224 routes: HashMap<TypeId, Box<dyn CommandRoute>>,
225}
226
227impl CommandBus {
228 pub fn new(store: AggregateStore) -> Self {
234 Self {
235 store,
236 routes: HashMap::new(),
237 }
238 }
239
240 pub fn register<A: Aggregate>(&mut self) {
250 let type_id = TypeId::of::<A::Command>();
251 self.routes.insert(
252 type_id,
253 Box::new(TypedCommandRoute::<A> {
254 _marker: std::marker::PhantomData,
255 }),
256 );
257 }
258
259 pub async fn dispatch<C: Send + 'static>(
276 &self,
277 instance_id: &str,
278 cmd: C,
279 ctx: CommandContext,
280 ) -> Result<(), DispatchError> {
281 let type_id = TypeId::of::<C>();
282 let route = self
283 .routes
284 .get(&type_id)
285 .ok_or(DispatchError::UnknownCommand)?;
286 route
287 .dispatch(&self.store, instance_id, Box::new(cmd), ctx)
288 .await
289 }
290}
291
#[cfg(test)]
mod tests {
    // Unit tests for CommandContext/CommandEnvelope, plus CommandBus integration tests
    // backed by a temporary on-disk AggregateStore.
    use super::*;
    use serde_json::json;
    use tempfile::TempDir;

    use crate::aggregate::test_fixtures::{Counter, CounterCommand};
    use crate::error::DispatchError;
    use crate::store::AggregateStore;

    #[test]
    fn default_context_has_no_fields_set() {
        let ctx = CommandContext::default();
        assert_eq!(ctx.actor, None);
        assert_eq!(ctx.correlation_id, None);
        assert_eq!(ctx.metadata, None);
        assert_eq!(ctx.source_device, None);
    }

    #[test]
    fn builder_sets_actor() {
        let ctx = CommandContext::default().with_actor("user-1");
        assert_eq!(ctx.actor.as_deref(), Some("user-1"));
    }

    #[test]
    fn builder_sets_correlation_id() {
        let ctx = CommandContext::default().with_correlation_id("corr-99");
        assert_eq!(ctx.correlation_id.as_deref(), Some("corr-99"));
    }

    #[test]
    fn builder_sets_metadata() {
        let meta = json!({"key": "value"});
        let ctx = CommandContext::default().with_metadata(meta.clone());
        assert_eq!(ctx.metadata, Some(meta));
    }

    #[test]
    fn builder_chains_all_fields() {
        let ctx = CommandContext::default()
            .with_actor("admin")
            .with_correlation_id("req-abc")
            .with_metadata(json!({"source": "test"}))
            .with_source_device("phone-42");

        assert_eq!(ctx.actor.as_deref(), Some("admin"));
        assert_eq!(ctx.correlation_id.as_deref(), Some("req-abc"));
        assert_eq!(ctx.metadata, Some(json!({"source": "test"})));
        assert_eq!(ctx.source_device.as_deref(), Some("phone-42"));
    }

    #[test]
    fn builder_sets_source_device() {
        let ctx = CommandContext::default().with_source_device("device-abc");
        assert_eq!(ctx.source_device.as_deref(), Some("device-abc"));
    }

    #[test]
    fn builder_accepts_string_owned() {
        let ctx = CommandContext::default()
            .with_actor(String::from("svc-payments"))
            .with_correlation_id(String::from("id-007"))
            .with_source_device(String::from("laptop-01"));

        assert_eq!(ctx.actor.as_deref(), Some("svc-payments"));
        assert_eq!(ctx.correlation_id.as_deref(), Some("id-007"));
        assert_eq!(ctx.source_device.as_deref(), Some("laptop-01"));
    }

    #[test]
    fn clone_produces_independent_copy() {
        let original = CommandContext::default()
            .with_actor("user-1")
            .with_metadata(json!({"a": 1}));

        let mut cloned = original.clone();
        assert_eq!(original.actor, cloned.actor);
        assert_eq!(original.metadata, cloned.metadata);

        // Independence: mutating the clone must leave the original untouched.
        cloned.actor = Some("user-2".to_string());
        cloned.metadata = Some(json!({"a": 2}));
        assert_eq!(original.actor.as_deref(), Some("user-1"));
        assert_eq!(original.metadata, Some(json!({"a": 1})));
        assert_eq!(cloned.actor.as_deref(), Some("user-2"));
    }

    #[test]
    fn debug_format_is_readable() {
        let ctx = CommandContext::default().with_actor("dbg-user");
        let debug_output = format!("{ctx:?}");
        assert!(debug_output.contains("dbg-user"));
    }

    #[test]
    fn command_context_serde_roundtrip() {
        let ctx = CommandContext::default()
            .with_actor("user-1")
            .with_correlation_id("corr-1")
            .with_metadata(json!({"key": "value"}))
            .with_source_device("device-xyz");

        let json = serde_json::to_string(&ctx).expect("serialization should succeed");
        let deserialized: CommandContext =
            serde_json::from_str(&json).expect("deserialization should succeed");

        assert_eq!(deserialized.actor, ctx.actor);
        assert_eq!(deserialized.correlation_id, ctx.correlation_id);
        assert_eq!(deserialized.metadata, ctx.metadata);
        assert_eq!(deserialized.source_device, ctx.source_device);
    }

    #[test]
    fn source_device_none_omitted_from_json() {
        let ctx = CommandContext::default().with_actor("user-1");
        let json = serde_json::to_string(&ctx).expect("serialization should succeed");
        assert!(
            !json.contains("source_device"),
            "source_device key should be absent when None, got: {json}"
        );
    }

    #[test]
    fn deserialize_legacy_json_without_source_device() {
        // JSON shaped like a context serialized before `source_device` existed.
        let legacy_json = r#"{"actor":"old-user","correlation_id":"old-corr","metadata":null}"#;
        let ctx: CommandContext =
            serde_json::from_str(legacy_json).expect("deserialization should succeed");
        assert_eq!(ctx.actor.as_deref(), Some("old-user"));
        assert_eq!(ctx.source_device, None);
    }

    #[test]
    fn command_envelope_serde_roundtrip() {
        let envelope = CommandEnvelope {
            aggregate_type: "counter".to_string(),
            instance_id: "c-1".to_string(),
            command: json!({"type": "Increment"}),
            context: CommandContext::default()
                .with_actor("saga")
                .with_correlation_id("corr-7")
                .with_metadata(json!({"step": 1}))
                .with_source_device("dev-1"),
        };

        let json = serde_json::to_string(&envelope).expect("serialization should succeed");
        let deserialized: CommandEnvelope =
            serde_json::from_str(&json).expect("deserialization should succeed");

        assert_eq!(deserialized.aggregate_type, envelope.aggregate_type);
        assert_eq!(deserialized.instance_id, envelope.instance_id);
        assert_eq!(deserialized.command, envelope.command);
        // Compare every context field, not just `actor`.
        assert_eq!(deserialized.context.actor, envelope.context.actor);
        assert_eq!(
            deserialized.context.correlation_id,
            envelope.context.correlation_id
        );
        assert_eq!(deserialized.context.metadata, envelope.context.metadata);
        assert_eq!(
            deserialized.context.source_device,
            envelope.context.source_device
        );
    }

    // A second, minimal aggregate so the bus can be exercised with two command types.
    #[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize)]
    struct Toggle {
        pub on: bool,
    }

    #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
    #[serde(tag = "type", content = "data")]
    enum ToggleEvent {
        Toggled,
    }

    #[derive(Debug, thiserror::Error)]
    enum ToggleError {}

    struct ToggleCmd;

    impl crate::aggregate::Aggregate for Toggle {
        const AGGREGATE_TYPE: &'static str = "toggle";
        type Command = ToggleCmd;
        type DomainEvent = ToggleEvent;
        type Error = ToggleError;

        fn handle(&self, _cmd: ToggleCmd) -> Result<Vec<ToggleEvent>, ToggleError> {
            Ok(vec![ToggleEvent::Toggled])
        }

        fn apply(mut self, _event: &ToggleEvent) -> Self {
            self.on = !self.on;
            self
        }
    }

    #[tokio::test]
    async fn command_bus_dispatch_to_two_aggregate_types() {
        let tmp = TempDir::new().expect("failed to create temp dir");
        let store = AggregateStore::open(tmp.path())
            .await
            .expect("open should succeed");

        let mut bus = CommandBus::new(store.clone());
        bus.register::<Counter>();
        bus.register::<Toggle>();

        bus.dispatch("c-1", CounterCommand::Increment, CommandContext::default())
            .await
            .expect("counter dispatch should succeed");
        bus.dispatch("c-1", CounterCommand::Increment, CommandContext::default())
            .await
            .expect("second counter dispatch should succeed");

        bus.dispatch("t-1", ToggleCmd, CommandContext::default())
            .await
            .expect("toggle dispatch should succeed");

        let counter_state = store
            .get::<Counter>("c-1")
            .await
            .expect("get counter should succeed")
            .state()
            .await
            .expect("counter state should succeed");
        assert_eq!(counter_state.value, 2);

        let toggle_state = store
            .get::<Toggle>("t-1")
            .await
            .expect("get toggle should succeed")
            .state()
            .await
            .expect("toggle state should succeed");
        assert!(toggle_state.on);
    }

    #[tokio::test]
    async fn command_bus_unknown_command_returns_error() {
        let tmp = TempDir::new().expect("failed to create temp dir");
        let store = AggregateStore::open(tmp.path())
            .await
            .expect("open should succeed");

        // No aggregates registered, so every command type is unknown.
        let bus = CommandBus::new(store);
        let result = bus
            .dispatch("c-1", CounterCommand::Increment, CommandContext::default())
            .await;

        assert!(
            matches!(result, Err(DispatchError::UnknownCommand)),
            "expected UnknownCommand, got: {result:?}"
        );
    }
}