use crate::{Error, Event, Listener};
use parking_lot::RwLock;
use std::any::TypeId;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tracing::{debug, error, info};
type ListenerFn<E> =
Arc<dyn Fn(&E) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send>> + Send + Sync>;
/// A registered listener together with the key that orders it among its siblings.
struct ListenerEntry {
    /// Ordering key: entries with a larger value are placed earlier in the list.
    priority: i32,
    /// Type-erased handler. An entry stored under `TypeId::of::<E>()` holds a
    /// `ListenerFn<E>`, which is recovered at dispatch time with `downcast_ref`.
    handler: Box<dyn std::any::Any + Send + Sync>,
}
/// Per-event-type listener lists, keyed by the event's `TypeId`.
type ListenerMap = HashMap<TypeId, Vec<ListenerEntry>>;

/// Routes each dispatched event to the listeners registered for its concrete type.
///
/// The registry sits behind a read/write lock, so listeners can be registered and events
/// dispatched through a shared `&EventDispatcher`.
pub struct EventDispatcher {
    listeners: RwLock<ListenerMap>,
}
impl Default for EventDispatcher {
fn default() -> Self {
Self::new()
}
}
impl EventDispatcher {
    /// Creates an empty dispatcher with no registered listeners.
    pub fn new() -> Self {
        Self {
            listeners: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `listener` for events of type `E` at the default priority (`0`).
    pub fn listen<E, L>(&self, listener: L)
    where
        E: Event,
        L: Listener<E>,
    {
        self.listen_with_priority(listener, 0);
    }

    /// Registers `listener` for events of type `E` at the given `priority`.
    ///
    /// Listeners with a higher priority are invoked first. Listeners that share a priority
    /// are invoked in registration order.
    pub fn listen_with_priority<E, L>(&self, listener: L, priority: i32)
    where
        E: Event,
        L: Listener<E>,
    {
        let listener = Arc::new(listener);
        let handler: ListenerFn<E> = Arc::new(move |event: &E| {
            let listener = Arc::clone(&listener);
            // The returned future must own its data, so each call works on a clone of the event.
            let event = event.clone();
            Box::pin(async move { listener.handle(&event).await })
        });
        self.insert_handler(handler, priority);
    }

    /// Registers an async closure for events of type `E` at the default priority (`0`).
    ///
    /// The closure receives its own clone of each dispatched event.
    pub fn on<E, F, Fut>(&self, handler: F)
    where
        E: Event,
        F: Fn(E) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<(), Error>> + Send + 'static,
    {
        let handler = Arc::new(handler);
        let listener_fn: ListenerFn<E> = Arc::new(move |event: &E| {
            let handler = Arc::clone(&handler);
            let event = event.clone();
            Box::pin(async move { handler(event).await })
        });
        // Use the same ordered insertion as `listen_with_priority`. This places the closure
        // before any listener registered with a negative priority.
        self.insert_handler(listener_fn, 0);
    }

    /// Dispatches `event` to every listener registered for `E`.
    ///
    /// Listeners are awaited one after another, in priority order. Returns `Ok(())`
    /// immediately when no listener is registered for `E`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a listener. The remaining listeners are skipped.
    pub async fn dispatch<E: Event>(&self, event: E) -> Result<(), Error> {
        let event_name = event.name();
        debug!(event = event_name, "Dispatching event");
        let Some(handlers) = self.handlers_for::<E>() else {
            debug!(event = event_name, "No listeners registered");
            return Ok(());
        };
        info!(
            event = event_name,
            listener_count = handlers.len(),
            "Calling listeners"
        );
        for handler in handlers {
            if let Err(e) = handler(&event).await {
                error!(event = event_name, error = %e, "Listener failed");
                return Err(e);
            }
        }
        debug!(event = event_name, "Event dispatched successfully");
        Ok(())
    }

    /// Runs the listeners registered for `E` on a newly spawned Tokio task and returns
    /// without waiting for them.
    ///
    /// Inside the task, listeners run sequentially in priority order. A failing listener is
    /// logged and the remaining listeners still run. Does nothing when no listener is
    /// registered for `E`.
    ///
    /// # Panics
    ///
    /// Panics if listeners are registered for `E` and this is called outside a Tokio
    /// runtime, because the work is started with [`tokio::spawn`].
    pub fn dispatch_async<E: Event + 'static>(&self, event: E) {
        let event_name = event.name();
        let Some(handlers) = self.handlers_for::<E>() else {
            return;
        };
        tokio::spawn(async move {
            for handler in handlers {
                if let Err(e) = handler(&event).await {
                    error!(event = event_name, error = %e, "Async listener failed");
                }
            }
        });
    }

    /// Returns `true` if at least one listener is registered for `E`.
    pub fn has_listeners<E: Event>(&self) -> bool {
        let listeners = self.listeners.read();
        listeners
            .get(&TypeId::of::<E>())
            .is_some_and(|list| !list.is_empty())
    }

    /// Removes every listener registered for `E`.
    pub fn forget<E: Event>(&self) {
        let mut listeners = self.listeners.write();
        listeners.remove(&TypeId::of::<E>());
    }

    /// Removes every listener for every event type.
    pub fn flush(&self) {
        let mut listeners = self.listeners.write();
        listeners.clear();
    }

    /// Inserts `handler` into the list for `E` and keeps the list sorted by descending
    /// priority.
    ///
    /// A new entry goes after existing entries of equal priority, so ties run in
    /// registration order.
    fn insert_handler<E: Event>(&self, handler: ListenerFn<E>, priority: i32) {
        let entry = ListenerEntry {
            handler: Box::new(handler),
            priority,
        };
        let mut listeners = self.listeners.write();
        let list = listeners.entry(TypeId::of::<E>()).or_default();
        // The list is already sorted (descending), so binary-search for the first entry with
        // a strictly lower priority instead of re-sorting the whole list.
        let index = list.partition_point(|existing| existing.priority >= priority);
        list.insert(index, entry);
    }

    /// Clones out the handlers registered for `E`.
    ///
    /// Returns `None` when no list exists for `E` (nothing registered yet, or removed by
    /// `forget`/`flush`). The read lock is released before returning, so callers can await
    /// the handlers without holding it.
    fn handlers_for<E: Event>(&self) -> Option<Vec<ListenerFn<E>>> {
        let listeners = self.listeners.read();
        let entries = listeners.get(&TypeId::of::<E>())?;
        Some(
            entries
                .iter()
                // Entries are keyed by `TypeId::of::<E>()`, so each one should hold a `ListenerFn<E>`.
                .filter_map(|entry| entry.handler.downcast_ref::<ListenerFn<E>>().cloned())
                .collect(),
        )
    }
}
/// Returns the process-wide dispatcher, creating it lazily on first access.
pub fn global_dispatcher() -> &'static EventDispatcher {
    // Scoped to this function because nothing else touches the instance directly.
    static GLOBAL_DISPATCHER: std::sync::OnceLock<EventDispatcher> = std::sync::OnceLock::new();
    GLOBAL_DISPATCHER.get_or_init(EventDispatcher::new)
}
pub async fn dispatch<E: Event>(event: E) -> Result<(), Error> {
global_dispatcher().dispatch(event).await
}
pub fn dispatch_sync<E: Event + 'static>(event: E) {
global_dispatcher().dispatch_async(event);
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[derive(Clone)]
    struct TestEvent {
        value: u32,
    }

    impl Event for TestEvent {
        fn name(&self) -> &'static str {
            "TestEvent"
        }
    }

    /// Registers a raw handler with an explicit priority.
    ///
    /// This bypasses the public API so priority ordering can be tested without a `Listener`
    /// implementation. The list stays sorted by descending priority, with ties kept in
    /// insertion order.
    fn register_raw(dispatcher: &EventDispatcher, handler: ListenerFn<TestEvent>, priority: i32) {
        let entry = ListenerEntry {
            handler: Box::new(handler),
            priority,
        };
        let mut listeners = dispatcher.listeners.write();
        let list = listeners.entry(TypeId::of::<TestEvent>()).or_default();
        list.push(entry);
        list.sort_by(|a, b| b.priority.cmp(&a.priority));
    }

    #[tokio::test]
    async fn test_dispatch_to_closure() {
        let dispatcher = EventDispatcher::new();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);
        dispatcher.on::<TestEvent, _, _>(move |event| {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(event.value, Ordering::SeqCst);
                Ok(())
            }
        });
        dispatcher.dispatch(TestEvent { value: 5 }).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    #[tokio::test]
    async fn test_multiple_listeners() {
        let dispatcher = EventDispatcher::new();
        let counter = Arc::new(AtomicU32::new(0));
        for _ in 0..3 {
            let counter_clone = Arc::clone(&counter);
            dispatcher.on::<TestEvent, _, _>(move |_| {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                }
            });
        }
        dispatcher.dispatch(TestEvent { value: 1 }).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_priority_order() {
        let dispatcher = EventDispatcher::new();
        let order = Arc::new(RwLock::new(Vec::new()));
        for priority in [1, 3, 2] {
            let order_clone = Arc::clone(&order);
            let handler: ListenerFn<TestEvent> = Arc::new(move |_| {
                let order = Arc::clone(&order_clone);
                let p = priority;
                Box::pin(async move {
                    order.write().push(p);
                    Ok(())
                })
            });
            register_raw(&dispatcher, handler, priority);
        }
        dispatcher.dispatch(TestEvent { value: 0 }).await.unwrap();
        let result = order.read().clone();
        assert_eq!(result, vec![3, 2, 1]);
    }

    #[tokio::test]
    async fn test_equal_priority_keeps_registration_order() {
        let dispatcher = EventDispatcher::new();
        let order = Arc::new(RwLock::new(Vec::new()));
        for id in 0..3u32 {
            let order_clone = Arc::clone(&order);
            dispatcher.on::<TestEvent, _, _>(move |_| {
                let order = Arc::clone(&order_clone);
                async move {
                    order.write().push(id);
                    Ok(())
                }
            });
        }
        dispatcher.dispatch(TestEvent { value: 0 }).await.unwrap();
        let result = order.read().clone();
        assert_eq!(result, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn test_has_listeners() {
        let dispatcher = EventDispatcher::new();
        assert!(!dispatcher.has_listeners::<TestEvent>());
        dispatcher.on::<TestEvent, _, _>(|_| async { Ok(()) });
        assert!(dispatcher.has_listeners::<TestEvent>());
        dispatcher.forget::<TestEvent>();
        assert!(!dispatcher.has_listeners::<TestEvent>());
    }

    #[test]
    fn test_flush_removes_all_listeners() {
        let dispatcher = EventDispatcher::new();
        dispatcher.on::<TestEvent, _, _>(|_| async { Ok(()) });
        assert!(dispatcher.has_listeners::<TestEvent>());
        dispatcher.flush();
        assert!(!dispatcher.has_listeners::<TestEvent>());
    }

    #[tokio::test]
    async fn test_no_listeners() {
        let dispatcher = EventDispatcher::new();
        let result = dispatcher.dispatch(TestEvent { value: 1 }).await;
        assert!(result.is_ok());
    }
}