use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::future::Future;
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot, Mutex, Notify, Semaphore};
use tokio::task::AbortHandle;
use tracing::{error, warn};
#[cfg(feature = "metrics")]
use metrics::counter;
use crate::error::{EventBusError, HandlerResult};
use crate::middleware::MiddlewareDecision;
use crate::types::{DeadLetter, FailurePolicy, ListenerInfo, SubscriptionId};
#[cfg(feature = "metrics")]
use crate::metrics::TimerGuard;
/// Type-erased event payload, shared between middlewares and listeners via `Arc`.
pub(crate) type EventType = Arc<dyn Any + Send + Sync>;
/// Boxed future returned by an async handler for a single invocation.
pub(crate) type HandlerFuture = Pin<Box<dyn Future<Output = HandlerResult> + Send>>;
/// Boxed future returned by an async middleware; resolves to a `MiddlewareDecision`.
pub(crate) type MiddlewareFuture = Pin<Box<dyn Future<Output = MiddlewareDecision> + Send>>;
/// Async global middleware: called with the event name and a shared handle to the event.
pub(crate) type ErasedAsyncMiddleware = Arc<dyn Fn(&'static str, EventType) -> MiddlewareFuture + Send + Sync>;
/// Sync global middleware: called with the event name and a borrow of the event.
pub(crate) type ErasedSyncMiddleware = Arc<dyn Fn(&'static str, &(dyn Any + Send + Sync)) -> MiddlewareDecision + Send + Sync>;
/// A global (not type-specific) middleware, stored in registration order and run
/// for every published event before any typed middleware or listener.
#[derive(Clone)]
pub(crate) enum ErasedMiddleware {
    /// Awaited during dispatch.
    Async(ErasedAsyncMiddleware),
    /// Called inline during dispatch.
    Sync(ErasedSyncMiddleware),
}
/// Type-erased async listener: takes a shared handle to the event and returns a boxed future.
pub(crate) type ErasedAsyncHandlerFn = Arc<dyn Fn(EventType) -> HandlerFuture + Send + Sync + 'static>;
/// Type-erased sync listener: borrows the event and returns its result directly.
pub(crate) type ErasedSyncHandlerFn = Arc<dyn Fn(&(dyn Any + Send + Sync)) -> HandlerResult + Send + Sync + 'static>;
/// Type-erased async middleware registered for one event type.
/// Same signature as `ErasedAsyncMiddleware` (the trait object is `'static` in both).
pub(crate) type ErasedTypedAsyncMiddlewareFn = Arc<dyn Fn(&'static str, EventType) -> MiddlewareFuture + Send + Sync + 'static>;
/// Type-erased sync middleware registered for one event type.
/// Same signature as `ErasedSyncMiddleware` (the trait object is `'static` in both).
pub(crate) type ErasedTypedSyncMiddlewareFn = Arc<dyn Fn(&'static str, &(dyn Any + Send + Sync)) -> MiddlewareDecision + Send + Sync + 'static>;
/// How a listener is invoked. Snapshots split listeners by this tag into
/// separate sync and async lists.
#[derive(Clone)]
pub(crate) enum ListenerKind {
    /// Spawned as a tracked task (with optional timeout and retries) during dispatch.
    Async(ErasedAsyncHandlerFn),
    /// Called inline during dispatch, with panics caught.
    Sync(ErasedSyncHandlerFn),
}
/// One registered listener together with its delivery options.
#[derive(Clone)]
pub(crate) struct ListenerEntry {
    /// Subscription id, also used as the key in the registry index.
    pub id: SubscriptionId,
    /// The handler to invoke.
    pub kind: ListenerKind,
    /// Retry and dead-letter behaviour applied when the handler fails.
    pub failure_policy: FailurePolicy,
    /// Optional human-readable name, surfaced in stats and failure reports.
    pub name: Option<&'static str>,
    /// When true, the listener is meant to fire at most once.
    pub once: bool,
    /// Shared "already fired" flag for `once` listeners; dispatch flips it
    /// false -> true atomically so only one dispatch wins.
    pub fired: Option<Arc<AtomicBool>>,
}
/// A middleware registered for one specific event type.
#[derive(Clone)]
pub(crate) enum TypedMiddlewareEntry {
    /// Awaited during dispatch.
    Async(ErasedTypedAsyncMiddlewareFn),
    /// Called inline during dispatch.
    Sync(ErasedTypedSyncMiddlewareFn),
}
/// A typed middleware paired with the subscription id used to remove it.
#[derive(Clone)]
pub(crate) struct TypedMiddlewareSlot {
    pub id: SubscriptionId,
    pub middleware: TypedMiddlewareEntry,
}
/// Immutable, per-event-type view used by dispatch.
///
/// The listener and middleware lists are `Arc<[_]>` so cloning a snapshot is cheap.
#[derive(Clone)]
pub(crate) struct TypeSlot {
    /// Listeners whose kind is `ListenerKind::Sync`, in registration order.
    pub sync_listeners: Arc<[ListenerEntry]>,
    /// Listeners whose kind is `ListenerKind::Async`, in registration order.
    pub async_listeners: Arc<[ListenerEntry]>,
    /// Typed middlewares for this event type, in registration order.
    pub middlewares: Arc<[TypedMiddlewareSlot]>,
    /// Held while this type's sync listeners run, serializing them across dispatches.
    pub sync_gate: Arc<Mutex<()>>,
    /// When present, each async listener task acquires a permit before running.
    pub async_semaphore: Option<Arc<Semaphore>>,
}
/// Point-in-time copy of the registry that dispatch reads without holding a lock
/// on the mutable registry.
#[derive(Default)]
pub(crate) struct RegistrySnapshot {
    /// Per-event-type listeners and middlewares.
    pub by_type: HashMap<TypeId, Arc<TypeSlot>>,
    /// Global middlewares in registration order.
    pub global_middlewares: Arc<[(SubscriptionId, ErasedMiddleware)]>,
}
/// Mutable, per-event-type registration state. It is turned into a `TypeSlot`
/// when a snapshot is taken.
struct MutableTypeSlot {
    /// Name recorded when the slot was first created.
    event_name: &'static str,
    /// All listeners (sync and async mixed), in registration order.
    listeners: Vec<ListenerEntry>,
    /// Typed middlewares, in registration order.
    middlewares: Vec<TypedMiddlewareSlot>,
    /// Shared with every snapshot taken while this slot exists.
    sync_gate: Arc<Mutex<()>>,
    /// Shared with every snapshot taken while this slot exists.
    async_semaphore: Option<Arc<Semaphore>>,
}
impl MutableTypeSlot {
fn to_snapshot_slot(&self) -> Arc<TypeSlot> {
let mut sync_listeners = Vec::new();
let mut async_listeners = Vec::new();
for listener in &self.listeners {
match listener.kind {
ListenerKind::Sync(_) => sync_listeners.push(listener.clone()),
ListenerKind::Async(_) => async_listeners.push(listener.clone()),
}
}
Arc::new(TypeSlot {
sync_listeners: sync_listeners.into(),
async_listeners: async_listeners.into(),
middlewares: self.middlewares.clone().into(),
sync_gate: Arc::clone(&self.sync_gate),
async_semaphore: self.async_semaphore.as_ref().map(Arc::clone),
})
}
}
/// Where a subscription id is registered, so removal can find it without scanning.
enum IndexEntry {
    /// A listener in the slot for this event type.
    Listener(TypeId),
    /// A typed middleware in the slot for this event type.
    TypedMiddleware(TypeId),
    /// An entry in the global middleware list.
    GlobalMiddleware,
}
/// Mutable registration state from which immutable `RegistrySnapshot`s are built.
pub(crate) struct MutableRegistry {
    /// Per-event-type listeners and middlewares. A slot is removed once it has neither.
    slots: HashMap<TypeId, MutableTypeSlot>,
    /// Global middlewares in registration order.
    global_middlewares: Vec<(SubscriptionId, ErasedMiddleware)>,
    /// Subscription id -> location, used for removal.
    index: HashMap<SubscriptionId, IndexEntry>,
    /// Event name per type, preferred by `stats` over the slot's own name.
    type_names: HashMap<TypeId, &'static str>,
    /// Permit count for each new slot's async semaphore; `None` means no limit.
    max_concurrent_async: Option<usize>,
}
impl MutableRegistry {
    /// Creates an empty registry.
    ///
    /// `max_concurrent_async`, when set, is the permit count of the semaphore created
    /// for each new event-type slot. Dispatch acquires one permit per async listener task.
    pub(crate) fn new(max_concurrent_async: Option<usize>) -> Self {
        Self {
            slots: HashMap::new(),
            global_middlewares: Vec::new(),
            index: HashMap::new(),
            type_names: HashMap::new(),
            max_concurrent_async,
        }
    }

    /// Returns the slot for `event_type`, creating it on first use.
    ///
    /// A new slot gets a fresh sync gate and, if configured, a fresh async semaphore.
    /// The first name registered for a type is the one that is kept.
    fn ensure_slot(&mut self, event_type: TypeId, event_name: &'static str) -> &mut MutableTypeSlot {
        self.type_names.entry(event_type).or_insert(event_name);
        let max_concurrent_async = self.max_concurrent_async;
        self.slots.entry(event_type).or_insert_with(|| MutableTypeSlot {
            event_name,
            listeners: Vec::new(),
            middlewares: Vec::new(),
            sync_gate: Arc::new(Mutex::new(())),
            async_semaphore: max_concurrent_async.map(|n| Arc::new(Semaphore::new(n))),
        })
    }

    /// Registers `listener` for `event_type` and indexes it by its subscription id.
    pub(crate) fn add_listener(&mut self, event_type: TypeId, event_name: &'static str, listener: ListenerEntry) {
        let id = listener.id;
        self.ensure_slot(event_type, event_name).listeners.push(listener);
        self.index.insert(id, IndexEntry::Listener(event_type));
    }

    /// Registers a typed middleware for `event_type` and indexes it by its subscription id.
    pub(crate) fn add_typed_middleware(&mut self, event_type: TypeId, event_name: &'static str, middleware: TypedMiddlewareSlot) {
        let id = middleware.id;
        self.ensure_slot(event_type, event_name).middlewares.push(middleware);
        self.index.insert(id, IndexEntry::TypedMiddleware(event_type));
    }

    /// Appends a global middleware and indexes it by `id`.
    pub(crate) fn add_global_middleware(&mut self, id: SubscriptionId, middleware: ErasedMiddleware) {
        self.global_middlewares.push((id, middleware));
        self.index.insert(id, IndexEntry::GlobalMiddleware);
    }

    /// Drops the slot (and its cached name) for `event_type` once it has no
    /// listeners and no typed middlewares left.
    fn prune_slot_if_empty(&mut self, event_type: TypeId) {
        let is_empty = matches!(
            self.slots.get(&event_type),
            Some(slot) if slot.listeners.is_empty() && slot.middlewares.is_empty()
        );
        if is_empty {
            self.slots.remove(&event_type);
            self.type_names.remove(&event_type);
        }
    }

    /// Removes the listener `subscription_id` from the slot for `event_type`,
    /// pruning the slot if it became empty.
    ///
    /// Returns whether a listener was actually removed. The index is not touched.
    fn remove_listener_from_slot(&mut self, event_type: TypeId, subscription_id: SubscriptionId) -> bool {
        let Some(slot) = self.slots.get_mut(&event_type) else {
            return false;
        };
        let before = slot.listeners.len();
        slot.listeners.retain(|l| l.id != subscription_id);
        let removed = before != slot.listeners.len();
        self.prune_slot_if_empty(event_type);
        removed
    }

    /// Removes a fired `once` listener.
    ///
    /// This is a no-op unless `subscription_id` is indexed as a listener.
    /// Middleware ids are left untouched.
    pub(crate) fn remove_once(&mut self, subscription_id: SubscriptionId) {
        let Some(&IndexEntry::Listener(event_type)) = self.index.get(&subscription_id) else {
            return;
        };
        self.remove_listener_from_slot(event_type, subscription_id);
        self.index.remove(&subscription_id);
    }

    /// Removes any subscription (listener, typed middleware or global middleware).
    ///
    /// Returns `true` if an entry was actually removed.
    pub(crate) fn remove_subscription(&mut self, subscription_id: SubscriptionId) -> bool {
        match self.index.remove(&subscription_id) {
            Some(IndexEntry::GlobalMiddleware) => {
                let before = self.global_middlewares.len();
                self.global_middlewares.retain(|(id, _)| *id != subscription_id);
                before != self.global_middlewares.len()
            }
            Some(IndexEntry::Listener(event_type)) => self.remove_listener_from_slot(event_type, subscription_id),
            Some(IndexEntry::TypedMiddleware(event_type)) => {
                let Some(slot) = self.slots.get_mut(&event_type) else {
                    return false;
                };
                let before = slot.middlewares.len();
                slot.middlewares.retain(|m| m.id != subscription_id);
                let removed = before != slot.middlewares.len();
                self.prune_slot_if_empty(event_type);
                removed
            }
            None => false,
        }
    }

    /// Builds an immutable snapshot of every slot and the global middleware list.
    pub(crate) fn snapshot(&self) -> RegistrySnapshot {
        RegistrySnapshot {
            by_type: self
                .slots
                .iter()
                .map(|(type_id, slot)| (*type_id, slot.to_snapshot_slot()))
                .collect(),
            global_middlewares: self.global_middlewares.clone().into(),
        }
    }

    /// Summarizes current registrations.
    ///
    /// Only event types with at least one listener are reported. Types that only
    /// carry typed middlewares are skipped. `registered_event_types` is sorted.
    pub(crate) fn stats(&self, in_flight_async: usize, queue_capacity: usize, shutdown_called: bool) -> crate::types::BusStats {
        let mut subscriptions_by_event: HashMap<&'static str, Vec<ListenerInfo>> = HashMap::new();
        let mut total_subscriptions = 0usize;
        let mut registered_event_types = Vec::new();
        for (type_id, slot) in &self.slots {
            if slot.listeners.is_empty() {
                continue;
            }
            let event_name = self.type_names.get(type_id).copied().unwrap_or(slot.event_name);
            registered_event_types.push(event_name);
            let infos: Vec<ListenerInfo> = slot
                .listeners
                .iter()
                .map(|l| ListenerInfo {
                    subscription_id: l.id,
                    name: l.name,
                })
                .collect();
            total_subscriptions += infos.len();
            subscriptions_by_event.insert(event_name, infos);
        }
        registered_event_types.sort_unstable();
        crate::types::BusStats {
            total_subscriptions,
            subscriptions_by_event,
            registered_event_types,
            queue_capacity,
            in_flight_async,
            shutdown_called,
        }
    }
}
/// A listener that failed, or exhausted its retries, for one event.
#[derive(Debug, Clone)]
pub(crate) struct ListenerFailure {
    pub event_name: &'static str,
    pub subscription_id: SubscriptionId,
    /// Total attempts made, including the first (always 1 for sync listeners).
    pub attempts: usize,
    /// Display text of the last error, panic message, or timeout.
    pub error: String,
    /// Copied from the listener's `FailurePolicy::dead_letter`.
    pub dead_letter: bool,
    /// The event that was being handled.
    pub event: EventType,
    pub listener_name: Option<&'static str>,
}
/// Messages sent to the bus control loop over `DispatchContext::notify_tx`.
pub(crate) enum ControlNotification {
    /// A listener failure produced during dispatch.
    Failure(ListenerFailure),
    /// Carries a completion sender; presumably answered once earlier
    /// notifications have been processed. TODO confirm against the receiver.
    Flush(oneshot::Sender<()>),
}
/// Shared settings and channels used by `dispatch_with_snapshot`.
#[derive(Clone)]
pub(crate) struct DispatchContext {
    /// Tracks spawned async listener tasks.
    pub tracker: Arc<AsyncTaskTracker>,
    /// Receives listener failures.
    pub notify_tx: mpsc::UnboundedSender<ControlNotification>,
    /// Per-attempt timeout applied to async handlers, if any.
    pub handler_timeout: Option<Duration>,
    /// When false, dispatch does not invoke async listeners at all.
    pub spawn_async_handlers: bool,
}
/// Tracks spawned async listener tasks so they can be awaited or aborted on shutdown.
#[derive(Default)]
pub(crate) struct AsyncTaskTracker {
    /// Source of per-task ids (incremented with `fetch_add`).
    next_id: AtomicU64,
    /// Number of tracked tasks currently counted as running. `shutdown` waits for zero.
    in_flight: AtomicUsize,
    /// Abort handles by task id. `shutdown` drains and aborts these on timeout,
    /// skipping `None` entries.
    tasks: Mutex<HashMap<u64, Option<AbortHandle>>>,
    /// Woken when `in_flight` drops to zero.
    notify: Notify,
}
impl AsyncTaskTracker {
pub(crate) async fn spawn_tracked<F>(self: &Arc<Self>, fut: F)
where
F: Future<Output = ()> + Send + 'static,
{
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
self.in_flight.fetch_add(1, Ordering::AcqRel);
let tracker = Arc::clone(self);
let handle = tokio::spawn(async move {
fut.await;
tracker.finish_task(id).await;
});
let abort_handle = handle.abort_handle();
self.tasks.lock().await.insert(id, Some(abort_handle));
}
pub(crate) fn in_flight(&self) -> usize {
self.in_flight.load(Ordering::Acquire)
}
pub(crate) async fn shutdown(&self, timeout: Option<Duration>) -> bool {
if self.in_flight() == 0 {
return false;
}
if let Some(timeout) = timeout {
let wait = async {
loop {
let notified = self.notify.notified();
if self.in_flight() == 0 {
return;
}
notified.await;
}
};
if tokio::time::timeout(timeout, wait).await.is_err() {
let handles: Vec<AbortHandle> = {
let mut guard = self.tasks.lock().await;
guard.drain().filter_map(|(_, h)| h).collect()
};
for handle in &handles {
handle.abort();
}
return true;
}
false
} else {
loop {
let notified = self.notify.notified();
if self.in_flight() == 0 {
break;
}
notified.await;
}
false
}
}
async fn finish_task(&self, id: u64) {
self.in_flight.fetch_sub(1, Ordering::AcqRel);
self.remove_abort_handle(id).await;
if self.in_flight.load(Ordering::Acquire) == 0 {
self.notify.notify_waiters();
}
}
async fn remove_abort_handle(&self, id: u64) {
self.tasks.lock().await.remove(&id);
}
}
fn sync_listener_failed(listener: &ListenerEntry, event_name: &'static str, event: &EventType, err: String) -> ListenerFailure {
ListenerFailure {
event_name,
subscription_id: listener.id,
attempts: 1,
error: err,
dead_letter: listener.failure_policy.dead_letter,
event: Arc::clone(event),
listener_name: listener.name,
}
}
async fn execute_async_listener(
handler: ErasedAsyncHandlerFn,
event: EventType,
event_name: &'static str,
listener: ListenerEntry,
handler_timeout: Option<Duration>,
) -> Option<ListenerFailure> {
let mut retries_left = listener.failure_policy.max_retries;
let mut attempts = 0usize;
loop {
attempts += 1;
#[cfg(feature = "metrics")]
let _timer = TimerGuard::start("eventbus.handler.duration", event_name);
let result = match handler_timeout {
Some(timeout) => {
let mut join = tokio::spawn(handler(Arc::clone(&event)));
match tokio::time::timeout(timeout, &mut join).await {
Ok(Ok(inner)) => inner,
Ok(Err(join_error)) => Err(format!("handler task failed: {join_error}").into()),
Err(_) => {
join.abort();
let _ = join.await;
Err(format!("handler timed out after {timeout:?}").into())
}
}
}
None => match tokio::spawn(handler(Arc::clone(&event))).await {
Ok(inner) => inner,
Err(join_error) => Err(format!("handler task failed: {join_error}").into()),
},
};
match result {
Ok(()) => return None,
Err(err) => {
let error_message = err.to_string();
if retries_left == 0 {
return Some(ListenerFailure {
event_name,
subscription_id: listener.id,
attempts,
error: error_message,
dead_letter: listener.failure_policy.dead_letter,
event: Arc::clone(&event),
listener_name: listener.name,
});
}
retries_left -= 1;
warn!(
event = event_name,
listener_id = listener.id.as_u64(),
attempts,
retries_left,
error = %error_message,
"handler.retry"
);
if let Some(strategy) = listener.failure_policy.retry_strategy {
tokio::time::sleep(strategy.delay_for_attempt(attempts - 1)).await;
}
}
}
}
}
/// Decides whether `listener` runs for the current dispatch.
///
/// Listeners that are not `once` always run, and so do `once` listeners that have
/// no `fired` flag. A `once` listener with a flag runs only if this call flips the
/// flag from `false` to `true`, so exactly one concurrent dispatch wins.
fn should_fire_once(listener: &ListenerEntry) -> bool {
    match (listener.once, listener.fired.as_ref()) {
        (true, Some(already_fired)) => already_fired
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_ok(),
        _ => true,
    }
}
/// Dispatches one published event against a registry snapshot.
///
/// Order of operations:
/// 1. Global middlewares run in registration order. The first `Reject` returns
///    `EventBusError::MiddlewareRejected` before any listener runs.
/// 2. If the snapshot has no slot for `event_type`, dispatch ends with `Ok`.
/// 3. Typed middlewares for this event type run the same way.
/// 4. If `dispatch_ctx.spawn_async_handlers` is set, each async listener is
///    spawned as a tracked task, gated by the slot's semaphore when present.
///    A final failure is sent as `ControlNotification::Failure`.
/// 5. Sync listeners run inline, one after another, while holding the slot's
///    `sync_gate`. Errors and caught panics are sent as failures.
///
/// Returns the ids of `once` listeners that fired during this call. Presumably
/// the caller passes them to `MutableRegistry::remove_once`. Verify at call site.
///
/// NOTE(review): when `spawn_async_handlers` is false, async listeners are not
/// invoked here at all. Confirm that callers run them elsewhere in that mode.
pub(crate) async fn dispatch_with_snapshot(
    snapshot: &RegistrySnapshot,
    event_type: TypeId,
    event: EventType,
    event_name: &'static str,
    dispatch_ctx: &DispatchContext,
) -> Result<Vec<SubscriptionId>, EventBusError> {
    let mut once_removed = Vec::new();
    #[cfg(feature = "metrics")]
    counter!("eventbus.publish", "event" => event_name).increment(1);
    // Global middlewares: the first rejection stops the whole dispatch.
    if !snapshot.global_middlewares.is_empty() {
        for (_id, mw) in snapshot.global_middlewares.iter() {
            let decision = match mw {
                ErasedMiddleware::Async(f) => f(event_name, Arc::clone(&event)).await,
                ErasedMiddleware::Sync(f) => f(event_name, event.as_ref()),
            };
            if let MiddlewareDecision::Reject(reason) = decision {
                return Err(EventBusError::MiddlewareRejected(reason));
            }
        }
    }
    // No listeners or typed middlewares registered for this type.
    let Some(slot) = snapshot.by_type.get(&event_type) else {
        return Ok(once_removed);
    };
    // Typed middlewares: same rejection semantics as the global ones.
    if !slot.middlewares.is_empty() {
        for slot_mw in slot.middlewares.iter() {
            let decision = match &slot_mw.middleware {
                TypedMiddlewareEntry::Async(mw) => mw(event_name, Arc::clone(&event)).await,
                TypedMiddlewareEntry::Sync(mw) => mw(event_name, event.as_ref()),
            };
            if let MiddlewareDecision::Reject(reason) = decision {
                return Err(EventBusError::MiddlewareRejected(reason));
            }
        }
    }
    // Async listeners are spawned (not awaited) before the sync listeners run.
    if dispatch_ctx.spawn_async_handlers {
        let tracker = Arc::clone(&dispatch_ctx.tracker);
        let handler_timeout = dispatch_ctx.handler_timeout;
        for listener in slot.async_listeners.iter() {
            let ListenerKind::Async(handler) = &listener.kind else {
                continue;
            };
            // For `once` listeners this atomically claims the single firing.
            if !should_fire_once(listener) {
                continue;
            }
            if listener.once {
                once_removed.push(listener.id);
            }
            // Owned copies moved into the `'static` tracked task.
            let listener = listener.clone();
            let handler = Arc::clone(handler);
            let event = Arc::clone(&event);
            let notify = dispatch_ctx.notify_tx.clone();
            let semaphore = slot.async_semaphore.as_ref().map(Arc::clone);
            tracker
                .spawn_tracked(async move {
                    let task = async {
                        if let Some(failure) = execute_async_listener(handler, event, event_name, listener, handler_timeout).await {
                            // Send errors (receiver dropped) are ignored.
                            let _ = notify.send(ControlNotification::Failure(failure));
                        }
                    };
                    if let Some(semaphore) = semaphore {
                        // The permit is held for the whole listener run. If the
                        // semaphore is closed, the listener is skipped silently.
                        if let Ok(_permit) = semaphore.acquire().await {
                            task.await;
                        }
                    } else {
                        task.await;
                    }
                })
                .await;
        }
    }
    if slot.sync_listeners.is_empty() {
        return Ok(once_removed);
    }
    // Serializes sync listener execution for this event type across concurrent dispatches.
    let _guard = slot.sync_gate.lock().await;
    for listener in slot.sync_listeners.iter() {
        let ListenerKind::Sync(handler) = &listener.kind else {
            continue;
        };
        if !should_fire_once(listener) {
            continue;
        }
        if listener.once {
            once_removed.push(listener.id);
        }
        #[cfg(feature = "metrics")]
        let _timer = TimerGuard::start("eventbus.handler.duration", event_name);
        // A panicking sync handler becomes an error instead of unwinding through dispatch.
        let result = catch_unwind(AssertUnwindSafe(|| {
            handler(event.as_ref())
        }))
        .unwrap_or_else(|panic_payload| {
            // Panic payloads are usually `&str` or `String`. Anything else gets a generic message.
            let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
                (*s).to_string()
            } else if let Some(s) = panic_payload.downcast_ref::<String>() {
                s.clone()
            } else {
                "handler panicked".to_string()
            };
            Err(msg.into())
        });
        if let Err(err) = result {
            // Sync listeners are not retried. Send errors are ignored.
            let _ = dispatch_ctx.notify_tx.send(ControlNotification::Failure(sync_listener_failed(
                listener,
                event_name,
                &event,
                err.to_string(),
            )));
        }
    }
    Ok(once_removed)
}
pub(crate) fn dead_letter_from_failure(failure: &ListenerFailure) -> Option<DeadLetter> {
error!(
event = failure.event_name,
listener_id = failure.subscription_id.as_u64(),
attempts = failure.attempts,
error = %failure.error,
"handler.failed"
);
#[cfg(feature = "metrics")]
counter!("eventbus.handler.error", "event" => failure.event_name).increment(1);
let dead_letter_type = std::any::type_name::<DeadLetter>();
if failure.dead_letter && failure.event_name != dead_letter_type {
Some(DeadLetter {
event_name: failure.event_name,
subscription_id: failure.subscription_id,
attempts: failure.attempts,
error: failure.error.clone(),
event: failure.event.clone(),
failed_at: std::time::SystemTime::now(),
listener_name: failure.listener_name,
})
} else {
None
}
}
#[cfg(test)]
mod tests {
    use super::AsyncTaskTracker;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::Barrier;

    /// A tracked task that completes normally must leave no abort handle behind.
    #[tokio::test]
    async fn tracker_does_not_leak_handles_for_fast_tasks() {
        let tracker = Arc::new(AsyncTaskTracker::default());
        let rendezvous = Arc::new(Barrier::new(2));

        let task_side = Arc::clone(&rendezvous);
        tracker
            .spawn_tracked(async move {
                task_side.wait().await;
            })
            .await;

        // Let the tracked task finish, then drain the tracker.
        rendezvous.wait().await;
        let timed_out = tracker.shutdown(Some(Duration::from_secs(1))).await;
        assert!(!timed_out, "tracker shutdown should complete without timeout");

        let remaining = tracker.tasks.lock().await;
        assert!(remaining.is_empty(), "tracker task map should be empty after completion");
    }
}