use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
use klauthed_error::DomainError;
use super::CqrsError;
/// A message that can be routed through a [`CommandBus`].
///
/// Implementors must be `Send + 'static`. The bus identifies a command by its
/// [`TypeId`], which requires `'static`. The command is also moved into the
/// handler's `Send` future.
pub trait Command
where
    Self: Send + 'static,
{
    /// The value a successful handler produces for this command.
    type Output: 'static + Send;
}
#[async_trait::async_trait]
pub trait CommandHandler<C: Command>: Send + Sync {
type Error: DomainError + Send + Sync + 'static;
async fn handle(&self, command: C) -> Result<C::Output, Self::Error>;
}
/// Error-normalized, object-safe view of a [`CommandHandler`].
///
/// Every handler's own `Error` type is converted into [`CqrsError`]. As a result,
/// handlers with different error types can all be stored behind the same
/// `Arc<dyn ErasedCommandHandler<C>>` inside the bus.
#[async_trait::async_trait]
trait ErasedCommandHandler<C>: Send + Sync
where
    C: Command,
{
    /// Runs the underlying handler and returns its error as a [`CqrsError`].
    async fn handle_erased(&self, command: C) -> Result<C::Output, CqrsError>;
}
/// Blanket adapter: every [`CommandHandler`] can be used as an [`ErasedCommandHandler`].
#[async_trait::async_trait]
impl<C, H> ErasedCommandHandler<C> for H
where
    C: Command,
    H: CommandHandler<C>,
{
    async fn handle_erased(&self, command: C) -> Result<C::Output, CqrsError> {
        match self.handle(command).await {
            Ok(output) => Ok(output),
            // Fold the handler-specific error into the bus-wide error type.
            Err(domain_err) => Err(CqrsError::handler(domain_err)),
        }
    }
}
/// Routes each [`Command`] type to the handler registered for it.
///
/// Handlers are stored type-erased and keyed by the command's [`TypeId`]. Each
/// value is an `Arc<dyn ErasedCommandHandler<C>>` boxed as `dyn Any`, which the
/// bus downcasts back to the concrete command type when dispatching.
#[derive(Default)]
pub struct CommandBus {
    // TypeId of command type `C` -> boxed `Arc<dyn ErasedCommandHandler<C>>`.
    handlers: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

// A derived `Debug` would only print opaque `TypeId`s and `Any { .. }`
// placeholders. Reporting the number of registered handlers is more useful,
// and the bus can still be embedded in `#[derive(Debug)]` types.
impl std::fmt::Debug for CommandBus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CommandBus")
            .field("registered_handlers", &self.handlers.len())
            .finish()
    }
}
impl CommandBus {
pub fn new() -> Self {
Self::default()
}
pub fn register<C, H>(&mut self, handler: H) -> &mut Self
where
C: Command,
H: CommandHandler<C> + 'static,
{
let erased: Arc<dyn ErasedCommandHandler<C>> = Arc::new(handler);
self.handlers.insert(TypeId::of::<C>(), Box::new(erased));
self
}
pub async fn dispatch<C: Command>(&self, command: C) -> Result<C::Output, CqrsError> {
let handler = self
.handlers
.get(&TypeId::of::<C>())
.and_then(|h| h.downcast_ref::<Arc<dyn ErasedCommandHandler<C>>>())
.cloned()
.ok_or_else(CqrsError::no_handler::<C>)?;
handler.handle_erased(command).await
}
}