use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::aggregate::Aggregate;
use crate::error::DispatchError;
use crate::store::AggregateStore;
/// Caller-supplied context that travels alongside a command.
///
/// Every field is optional: [`CommandContext::default`] yields a context with
/// nothing set, and the `with_*` builder methods fill fields in one at a time.
/// `PartialEq` is derived so whole contexts can be compared directly instead
/// of field by field.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct CommandContext {
    /// Identifier of whoever issued the command (e.g. a user or service name).
    pub actor: Option<String>,
    /// Optional identifier for correlating this command with related work.
    pub correlation_id: Option<String>,
    /// Arbitrary caller-defined JSON metadata.
    pub metadata: Option<Value>,
    /// Identifier of the device the command originated from.
    ///
    /// Omitted from serialized output when `None`, and defaults to `None` when
    /// the key is absent, so JSON written before this field existed still
    /// deserializes.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub source_device: Option<String>,
}
impl CommandContext {
    /// Returns this context with `actor` set to `actor`, replacing any
    /// previous value.
    ///
    /// The method consumes `self`; ignoring the returned value discards the
    /// update, hence `#[must_use]`.
    #[must_use]
    pub fn with_actor(mut self, actor: impl Into<String>) -> Self {
        self.actor = Some(actor.into());
        self
    }

    /// Returns this context with `correlation_id` set to `id`, replacing any
    /// previous value.
    #[must_use]
    pub fn with_correlation_id(mut self, id: impl Into<String>) -> Self {
        self.correlation_id = Some(id.into());
        self
    }

    /// Returns this context with `metadata` set to `meta`, replacing any
    /// previous value.
    #[must_use]
    pub fn with_metadata(mut self, meta: Value) -> Self {
        self.metadata = Some(meta);
        self
    }

    /// Returns this context with `source_device` set to `device_id`, replacing
    /// any previous value.
    #[must_use]
    pub fn with_source_device(mut self, device_id: impl Into<String>) -> Self {
        self.source_device = Some(device_id.into());
        self
    }
}
/// A command together with its addressing information and context, in a
/// serializable form whose payload is untyped JSON.
///
/// Unlike `CommandBus::dispatch` in this module, which takes a strongly typed
/// command value, the payload here is a `serde_json::Value`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommandEnvelope {
    /// Name of the target aggregate type — presumably matching the aggregate's
    /// `AGGREGATE_TYPE` constant (TODO confirm against the consumer).
    pub aggregate_type: String,
    /// Identifier of the target aggregate instance.
    pub instance_id: String,
    /// Command payload as untyped JSON.
    pub command: Value,
    /// Context carried along with the command.
    pub context: CommandContext,
}
/// Type-erased handler that executes a boxed command against one aggregate
/// type.
///
/// Stored as `Box<dyn CommandRoute>` inside `CommandBus`, so routes for many
/// different aggregate types can live in a single map.
trait CommandRoute: Send + Sync {
    /// Executes `cmd` against instance `instance_id` loaded from `store`.
    ///
    /// `cmd` arrives as `Box<dyn Any + Send>` and must be downcast by the
    /// implementation. The returned future is boxed and bound to `'a` (the
    /// borrowed arguments) because the trait must stay object-safe, which rules
    /// out returning `impl Future` or using `async fn` here.
    fn dispatch<'a>(
        &'a self,
        store: &'a AggregateStore,
        instance_id: &'a str,
        cmd: Box<dyn Any + Send>,
        ctx: CommandContext,
    ) -> Pin<Box<dyn Future<Output = Result<(), DispatchError>> + Send + 'a>>;
}
/// `CommandRoute` implementation that dispatches to aggregate type `A`.
///
/// Holds no runtime data; the `PhantomData<A>` marker only records which
/// aggregate type this route targets.
struct TypedCommandRoute<A: Aggregate> {
    _marker: std::marker::PhantomData<A>,
}
impl<A: Aggregate> CommandRoute for TypedCommandRoute<A> {
    /// Recovers the concrete `A::Command`, loads the `A` instance named by
    /// `instance_id` from `store`, and executes the command on it with `ctx`.
    ///
    /// # Errors
    ///
    /// - `DispatchError::UnknownCommand` if `cmd` is not an `A::Command`.
    /// - `DispatchError::Io` if the store fails to load the instance.
    /// - `DispatchError::Execution` wrapping the error returned by the
    ///   handle's `execute`.
    fn dispatch<'a>(
        &'a self,
        store: &'a AggregateStore,
        instance_id: &'a str,
        cmd: Box<dyn Any + Send>,
        ctx: CommandContext,
    ) -> Pin<Box<dyn Future<Output = Result<(), DispatchError>> + Send + 'a>> {
        Box::pin(async move {
            // `CommandBus` looks routes up by `TypeId::of::<A::Command>()`, so
            // the bus never hands this route a different type; a mismatch is
            // still reported as an error rather than a panic.
            let command = match cmd.downcast::<A::Command>() {
                Ok(command) => command,
                Err(_) => return Err(DispatchError::UnknownCommand),
            };
            let aggregate = store.get::<A>(instance_id).await.map_err(DispatchError::Io)?;
            aggregate
                .execute(*command, ctx)
                .await
                .map(|_| ())
                .map_err(|err| DispatchError::Execution(Box::new(err)))
        })
    }
}
/// Routes strongly typed commands to the aggregate type registered for them.
///
/// Aggregate types are added with `CommandBus::register`. Commands are then
/// sent with `CommandBus::dispatch`, which looks up the route by the command's
/// `TypeId` and executes it against instances loaded from `store`.
pub struct CommandBus {
    // Store the routes load aggregate instances from.
    store: AggregateStore,
    // Registered routes, keyed by the `TypeId` of each aggregate's command type.
    routes: HashMap<TypeId, Box<dyn CommandRoute>>,
}
impl CommandBus {
pub fn new(store: AggregateStore) -> Self {
Self {
store,
routes: HashMap::new(),
}
}
pub fn register<A: Aggregate>(&mut self) {
let type_id = TypeId::of::<A::Command>();
self.routes.insert(
type_id,
Box::new(TypedCommandRoute::<A> {
_marker: std::marker::PhantomData,
}),
);
}
pub async fn dispatch<C: Send + 'static>(
&self,
instance_id: &str,
cmd: C,
ctx: CommandContext,
) -> Result<(), DispatchError> {
let type_id = TypeId::of::<C>();
let route = self
.routes
.get(&type_id)
.ok_or(DispatchError::UnknownCommand)?;
route
.dispatch(&self.store, instance_id, Box::new(cmd), ctx)
.await
}
}
#[cfg(test)]
mod tests {
    // `use super::*` also brings in the parent's `DispatchError` and
    // `AggregateStore` imports, so they are not re-imported here.
    use super::*;
    use crate::aggregate::test_fixtures::{Counter, CounterCommand};
    use serde_json::json;
    use tempfile::TempDir;

    #[test]
    fn default_context_has_no_fields_set() {
        let ctx = CommandContext::default();
        assert_eq!(ctx.actor, None);
        assert_eq!(ctx.correlation_id, None);
        assert_eq!(ctx.metadata, None);
        assert_eq!(ctx.source_device, None);
    }

    #[test]
    fn builder_sets_actor() {
        let ctx = CommandContext::default().with_actor("user-1");
        assert_eq!(ctx.actor.as_deref(), Some("user-1"));
    }

    #[test]
    fn builder_sets_correlation_id() {
        let ctx = CommandContext::default().with_correlation_id("corr-99");
        assert_eq!(ctx.correlation_id.as_deref(), Some("corr-99"));
    }

    #[test]
    fn builder_sets_metadata() {
        let meta = json!({"key": "value"});
        let ctx = CommandContext::default().with_metadata(meta.clone());
        assert_eq!(ctx.metadata, Some(meta));
    }

    #[test]
    fn builder_chains_all_fields() {
        let ctx = CommandContext::default()
            .with_actor("admin")
            .with_correlation_id("req-abc")
            .with_metadata(json!({"source": "test"}))
            .with_source_device("phone-42");
        assert_eq!(ctx.actor.as_deref(), Some("admin"));
        assert_eq!(ctx.correlation_id.as_deref(), Some("req-abc"));
        assert_eq!(ctx.metadata, Some(json!({"source": "test"})));
        assert_eq!(ctx.source_device.as_deref(), Some("phone-42"));
    }

    #[test]
    fn builder_sets_source_device() {
        let ctx = CommandContext::default().with_source_device("device-abc");
        assert_eq!(ctx.source_device.as_deref(), Some("device-abc"));
    }

    #[test]
    fn builder_accepts_string_owned() {
        let ctx = CommandContext::default()
            .with_actor(String::from("svc-payments"))
            .with_correlation_id(String::from("id-007"))
            .with_source_device(String::from("laptop-01"));
        assert_eq!(ctx.actor.as_deref(), Some("svc-payments"));
        assert_eq!(ctx.correlation_id.as_deref(), Some("id-007"));
        assert_eq!(ctx.source_device.as_deref(), Some("laptop-01"));
    }

    #[test]
    fn clone_produces_independent_copy() {
        let original = CommandContext::default()
            .with_actor("user-1")
            .with_metadata(json!({"a": 1}));
        let mut cloned = original.clone();
        assert_eq!(original.actor, cloned.actor);
        assert_eq!(original.metadata, cloned.metadata);

        // Independence: mutating the clone must leave the original untouched.
        cloned.actor = Some("user-2".to_string());
        cloned.metadata = Some(json!({"a": 2}));
        assert_eq!(original.actor.as_deref(), Some("user-1"));
        assert_eq!(original.metadata, Some(json!({"a": 1})));
    }

    #[test]
    fn debug_format_is_readable() {
        let ctx = CommandContext::default().with_actor("dbg-user");
        let debug_output = format!("{ctx:?}");
        assert!(debug_output.contains("dbg-user"));
    }

    #[test]
    fn command_context_serde_roundtrip() {
        let ctx = CommandContext::default()
            .with_actor("user-1")
            .with_correlation_id("corr-1")
            .with_metadata(json!({"key": "value"}))
            .with_source_device("device-xyz");
        let json = serde_json::to_string(&ctx).expect("serialization should succeed");
        let deserialized: CommandContext =
            serde_json::from_str(&json).expect("deserialization should succeed");
        assert_eq!(deserialized.actor, ctx.actor);
        assert_eq!(deserialized.correlation_id, ctx.correlation_id);
        assert_eq!(deserialized.metadata, ctx.metadata);
        assert_eq!(deserialized.source_device, ctx.source_device);
    }

    #[test]
    fn source_device_none_omitted_from_json() {
        let ctx = CommandContext::default().with_actor("user-1");
        let json = serde_json::to_string(&ctx).expect("serialization should succeed");
        assert!(
            !json.contains("source_device"),
            "source_device key should be absent when None, got: {json}"
        );
    }

    #[test]
    fn deserialize_legacy_json_without_source_device() {
        let legacy_json = r#"{"actor":"old-user","correlation_id":"old-corr","metadata":null}"#;
        let ctx: CommandContext =
            serde_json::from_str(legacy_json).expect("deserialization should succeed");
        assert_eq!(ctx.actor.as_deref(), Some("old-user"));
        assert_eq!(ctx.source_device, None);
    }

    #[test]
    fn command_envelope_serde_roundtrip() {
        let envelope = CommandEnvelope {
            aggregate_type: "counter".to_string(),
            instance_id: "c-1".to_string(),
            command: json!({"type": "Increment"}),
            context: CommandContext::default().with_actor("saga"),
        };
        let json = serde_json::to_string(&envelope).expect("serialization should succeed");
        let deserialized: CommandEnvelope =
            serde_json::from_str(&json).expect("deserialization should succeed");
        assert_eq!(deserialized.aggregate_type, envelope.aggregate_type);
        assert_eq!(deserialized.instance_id, envelope.instance_id);
        assert_eq!(deserialized.command, envelope.command);
        assert_eq!(deserialized.context.actor, envelope.context.actor);
    }

    // Minimal second aggregate so the bus can be exercised with two distinct
    // command types.
    #[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize)]
    struct Toggle {
        pub on: bool,
    }

    #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
    #[serde(tag = "type", content = "data")]
    enum ToggleEvent {
        Toggled,
    }

    #[derive(Debug, thiserror::Error)]
    enum ToggleError {}

    struct ToggleCmd;

    impl crate::aggregate::Aggregate for Toggle {
        const AGGREGATE_TYPE: &'static str = "toggle";
        type Command = ToggleCmd;
        type DomainEvent = ToggleEvent;
        type Error = ToggleError;

        fn handle(&self, _cmd: ToggleCmd) -> Result<Vec<ToggleEvent>, ToggleError> {
            Ok(vec![ToggleEvent::Toggled])
        }

        fn apply(mut self, _event: &ToggleEvent) -> Self {
            self.on = !self.on;
            self
        }
    }

    #[tokio::test]
    async fn command_bus_dispatch_to_two_aggregate_types() {
        let tmp = TempDir::new().expect("failed to create temp dir");
        let store = AggregateStore::open(tmp.path())
            .await
            .expect("open should succeed");
        let mut bus = CommandBus::new(store.clone());
        bus.register::<Counter>();
        bus.register::<Toggle>();
        bus.dispatch("c-1", CounterCommand::Increment, CommandContext::default())
            .await
            .expect("counter dispatch should succeed");
        bus.dispatch("c-1", CounterCommand::Increment, CommandContext::default())
            .await
            .expect("second counter dispatch should succeed");
        bus.dispatch("t-1", ToggleCmd, CommandContext::default())
            .await
            .expect("toggle dispatch should succeed");
        let counter_state = store
            .get::<Counter>("c-1")
            .await
            .expect("get counter should succeed")
            .state()
            .await
            .expect("counter state should succeed");
        assert_eq!(counter_state.value, 2);
        let toggle_state = store
            .get::<Toggle>("t-1")
            .await
            .expect("get toggle should succeed")
            .state()
            .await
            .expect("toggle state should succeed");
        assert!(toggle_state.on);
    }

    #[tokio::test]
    async fn command_bus_unknown_command_returns_error() {
        let tmp = TempDir::new().expect("failed to create temp dir");
        let store = AggregateStore::open(tmp.path())
            .await
            .expect("open should succeed");
        let bus = CommandBus::new(store);
        let result = bus
            .dispatch("c-1", CounterCommand::Increment, CommandContext::default())
            .await;
        assert!(
            matches!(result, Err(DispatchError::UnknownCommand)),
            "expected UnknownCommand, got: {result:?}"
        );
    }
}