use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::bus::{EventBus, EventBusBuilder};
use crate::error::EventBusError;
use crate::handler::SyncEventHandler;
use crate::types::{Event, SubscriptionPolicy};
/// A single captured event, boxed and type-erased so buffers for different event
/// types can share one map.
type CapturedEvent = Box<dyn Any + Send + Sync>;

/// Shared, lockable log of captured events for one event type.
type AnyBuffer = Arc<Mutex<Vec<CapturedEvent>>>;

/// Test helper wrapping an [`EventBus`] that can record published events per type
/// so tests can inspect and assert on them afterwards.
pub struct TestBus {
    /// The wrapped bus that events are published on and captured from.
    bus: EventBus,
    /// Capture buffers keyed by the `TypeId` of the captured event type.
    buffers: Arc<Mutex<HashMap<TypeId, AnyBuffer>>>,
}
impl TestBus {
    /// Creates a test bus wrapping `EventBus::new(256)` with no event types captured yet.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `EventBus::new`.
    pub fn new() -> Result<Self, EventBusError> {
        Ok(Self {
            bus: EventBus::new(256)?,
            buffers: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Returns a [`TestBusBuilder`] for configuring the underlying [`EventBus`].
    pub fn builder() -> TestBusBuilder {
        TestBusBuilder { inner: EventBus::builder() }
    }

    /// Borrows the wrapped [`EventBus`], e.g. to publish events on it.
    pub fn inner(&self) -> &EventBus {
        &self.bus
    }

    /// Starts recording every event of type `E` delivered by the wrapped bus.
    ///
    /// Subscribes a synchronous handler that clones each received `E` into a
    /// per-type buffer, readable later via `published::<E>()`. The subscription uses a
    /// `SyncSubscriptionPolicy` with `with_dead_letter(false)`. Calling this again for
    /// an `E` that is already being captured is a no-op returning `Ok(())`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `EventBus::subscribe_with_policy`. On failure the
    /// buffer for `E` is unregistered again, so a later `capture::<E>()` retries the
    /// subscription instead of returning `Ok(())` with no handler attached.
    ///
    /// # Panics
    ///
    /// Panics if the internal buffer-map mutex is poisoned.
    pub async fn capture<E>(&self) -> Result<(), EventBusError>
    where
        E: Event + Clone + Send + Sync + 'static,
    {
        /// Handler that clones each received event into a shared, type-erased buffer.
        struct CaptureHandler<E> {
            buffer: AnyBuffer,
            _marker: std::marker::PhantomData<E>,
        }

        impl<E: Event + Clone + Send + Sync + 'static> SyncEventHandler<E> for CaptureHandler<E> {
            fn handle(&self, event: &E) -> crate::error::HandlerResult {
                let cloned: Box<dyn Any + Send + Sync> = Box::new(event.clone());
                self.buffer.lock().expect("capture buffer lock poisoned").push(cloned);
                Ok(())
            }
        }

        let type_id = TypeId::of::<E>();

        // Check and register under ONE lock acquisition so two concurrent
        // `capture::<E>()` calls cannot both pass the check and subscribe twice.
        // The guard is dropped at the end of this block, before the `.await` below,
        // because a std `Mutex` guard must not be held across an await point.
        let buffer: AnyBuffer = {
            let mut buffers = self.buffers.lock().expect("TestBus buffers lock poisoned");
            if buffers.contains_key(&type_id) {
                return Ok(());
            }
            let fresh: AnyBuffer = Arc::new(Mutex::new(Vec::new()));
            buffers.insert(type_id, Arc::clone(&fresh));
            fresh
        };

        let handler = CaptureHandler::<E> {
            buffer: Arc::clone(&buffer),
            _marker: std::marker::PhantomData,
        };
        let policy = crate::SyncSubscriptionPolicy::default().with_dead_letter(false);
        let subscribed = self
            .bus
            .subscribe_with_policy::<E, _, crate::handler::SyncMode>(handler, policy)
            .await;

        if subscribed.is_err() {
            // Roll back the registration so the type is not reported as captured
            // without a live handler. Only remove the entry if it is still the buffer
            // this call inserted.
            let mut buffers = self.buffers.lock().expect("TestBus buffers lock poisoned");
            if matches!(buffers.get(&type_id), Some(existing) if Arc::ptr_eq(existing, &buffer)) {
                buffers.remove(&type_id);
            }
        }

        // NOTE(review): the subscription handle is dropped when this function returns,
        // as in the original; capturing relies on the bus keeping the handler registered
        // after that — confirm against EventBus's subscription-handle semantics.
        let _sub = subscribed?;
        Ok(())
    }

    /// Returns clones of all captured events of type `E`, in the order the capture
    /// handler recorded them.
    ///
    /// Returns an empty `Vec` if `capture::<E>()` has not registered a buffer for `E`.
    ///
    /// # Panics
    ///
    /// Panics if either the buffer-map mutex or the per-type buffer mutex is poisoned.
    pub fn published<E>(&self) -> Vec<E>
    where
        E: Event + Clone + 'static,
    {
        // Clone the Arc out so the map lock is released before the per-type buffer
        // is locked; the map guard is a temporary dropped at the end of this statement.
        let buffer = self
            .buffers
            .lock()
            .expect("TestBus buffers lock poisoned")
            .get(&TypeId::of::<E>())
            .cloned();
        let Some(buffer) = buffer else {
            return Vec::new();
        };
        let guard = buffer.lock().expect("capture buffer lock poisoned");
        // Buffers are keyed by `TypeId::of::<E>()`, so every entry downcasts to `E`.
        guard
            .iter()
            .filter_map(|any| any.downcast_ref::<E>().cloned())
            .collect()
    }

    /// Asserts that exactly `expected` events of type `E` have been captured.
    ///
    /// # Panics
    ///
    /// Panics (test failure) if the captured count differs from `expected`.
    pub fn assert_count<E>(&self, expected: usize)
    where
        E: Event + Clone + 'static,
    {
        let actual = self.published::<E>().len();
        assert_eq!(
            actual,
            expected,
            "TestBus::assert_count<{}>: expected {} events, got {}",
            std::any::type_name::<E>(),
            expected,
            actual,
        );
    }

    /// Asserts that no events of type `E` have been captured.
    ///
    /// # Panics
    ///
    /// Panics (test failure) if at least one `E` event was captured.
    pub fn assert_empty<E>(&self)
    where
        E: Event + Clone + 'static,
    {
        self.assert_count::<E>(0);
    }

    /// Shuts down the wrapped bus by delegating to `EventBus::shutdown`.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `EventBus::shutdown`.
    pub async fn shutdown(&self) -> Result<(), EventBusError> {
        self.bus.shutdown().await
    }
}
/// Builder for [`TestBus`]; each setter forwards to the wrapped [`EventBusBuilder`].
///
/// Obtain one from `TestBus::builder()` and finish with `build()`. A builder that is
/// configured but never built does nothing, hence `#[must_use]`.
#[must_use = "builders do nothing unless `build()` is called"]
pub struct TestBusBuilder {
    /// Underlying bus builder that receives every configuration call.
    inner: EventBusBuilder,
}
impl TestBusBuilder {
    /// Sets the bus buffer size (forwards to `EventBusBuilder::buffer_size`).
    pub fn buffer_size(self, size: usize) -> Self {
        Self {
            inner: self.inner.buffer_size(size),
        }
    }

    /// Sets the per-handler timeout (forwards to `EventBusBuilder::handler_timeout`).
    pub fn handler_timeout(self, timeout: std::time::Duration) -> Self {
        Self {
            inner: self.inner.handler_timeout(timeout),
        }
    }

    /// Sets the async concurrency limit (forwards to `EventBusBuilder::max_concurrent_async`).
    pub fn max_concurrent_async(self, max: usize) -> Self {
        Self {
            inner: self.inner.max_concurrent_async(max),
        }
    }

    /// Sets the shutdown timeout (forwards to `EventBusBuilder::shutdown_timeout`).
    pub fn shutdown_timeout(self, timeout: std::time::Duration) -> Self {
        Self {
            inner: self.inner.shutdown_timeout(timeout),
        }
    }

    /// Sets the default subscription policy (forwards to
    /// `EventBusBuilder::default_subscription_policy`).
    pub fn default_subscription_policy(self, policy: SubscriptionPolicy) -> Self {
        Self {
            inner: self.inner.default_subscription_policy(policy),
        }
    }

    /// Deprecated alias for `default_subscription_policy`.
    #[deprecated(since = "0.3.3", note = "renamed to default_subscription_policy")]
    pub fn default_failure_policy(self, policy: SubscriptionPolicy) -> Self {
        Self::default_subscription_policy(self, policy)
    }

    /// Builds the underlying bus and wraps it in a [`TestBus`] with no captured types.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `EventBusBuilder::build`.
    pub fn build(self) -> Result<TestBus, EventBusError> {
        let bus = self.inner.build()?;
        Ok(TestBus {
            bus,
            buffers: Arc::default(),
        })
    }
}