use crate::metrics::EventMetricsCounters;
use crate::{
DispatchResult, Event, EventMetadata, ListenerError, ListenerId, ListenerWrapper,
MiddlewareManager, Priority,
};
use parking_lot::RwLock;
use std::any::TypeId;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
#[cfg(feature = "async")]
use crate::{AsyncEventResult, AsyncListenerWrapper};
/// Shared, thread-safe handle to a type-erased asynchronous listener callback.
///
/// The callback borrows an event as `&dyn Event` and returns an
/// `AsyncEventResult` bound to the lifetime of that borrow.
#[cfg(feature = "async")]
type AsyncHandler = Arc<dyn for<'a> Fn(&'a dyn Event) -> AsyncEventResult<'a> + Send + Sync>;
/// Registry that routes events to listeners keyed by the event's `TypeId`.
///
/// Every piece of state sits behind a lock or an atomic, so all methods take
/// `&self`.
pub struct EventDispatcher {
/// Synchronous listeners per event type; `subscribe_with_priority` inserts
/// each one at its priority position.
listeners: Arc<RwLock<HashMap<TypeId, Vec<ListenerWrapper>>>>,
/// Asynchronous listeners per event type, inserted the same way (only
/// present with the `async` feature).
#[cfg(feature = "async")]
async_listeners: Arc<RwLock<HashMap<TypeId, Vec<AsyncListenerWrapper>>>>,
/// Source of listener ids; one counter serves both the sync and async
/// registries.
next_id: AtomicUsize,
/// Per-event-type counters, created the first time a type is subscribed to
/// or dispatched.
metrics: Arc<RwLock<HashMap<TypeId, Arc<EventMetricsCounters>>>>,
/// Middleware consulted by `dispatch`/`dispatch_async` before any listener
/// runs.
middleware: Arc<RwLock<MiddlewareManager>>,
}
impl EventDispatcher {
/// Creates a dispatcher with empty listener registries, no recorded metrics,
/// and a freshly constructed [`MiddlewareManager`]. Listener ids start at 0.
#[must_use]
pub fn new() -> Self {
    // Components are built in field order, then moved into the struct.
    let listeners = Arc::new(RwLock::new(HashMap::new()));
    #[cfg(feature = "async")]
    let async_listeners = Arc::new(RwLock::new(HashMap::new()));
    let next_id = AtomicUsize::new(0);
    let metrics = Arc::new(RwLock::new(HashMap::new()));
    let middleware = Arc::new(RwLock::new(MiddlewareManager::new()));

    Self {
        listeners,
        #[cfg(feature = "async")]
        async_listeners,
        next_id,
        metrics,
        middleware,
    }
}
/// Registers a fallible synchronous `listener` for events of type `T` at
/// [`Priority::Normal`].
///
/// Delegates to [`Self::subscribe_with_priority`] and returns the
/// [`ListenerId`] it produces.
pub fn subscribe<T, F>(&self, listener: F) -> ListenerId
where
    T: Event + 'static,
    F: Fn(&T) -> Result<(), ListenerError> + Send + Sync + 'static,
{
    let default_priority = Priority::Normal;
    self.subscribe_with_priority(listener, default_priority)
}
/// Registers a fallible synchronous `listener` for events of type `T` with an
/// explicit `priority`.
///
/// The listener is inserted after every already-registered listener whose
/// priority is greater than or equal to `priority`, so among equal priorities
/// the earlier registration stays first. Metrics counters for `T` are created
/// if they do not exist yet.
///
/// Returns a [`ListenerId`] built from the freshly allocated id and the
/// `TypeId` of `T`.
pub fn subscribe_with_priority<T, F>(&self, listener: F, priority: Priority) -> ListenerId
where
    T: Event + 'static,
    F: Fn(&T) -> Result<(), ListenerError> + Send + Sync + 'static,
{
    let type_id = TypeId::of::<T>();
    let id = self.next_id.fetch_add(1, Ordering::Relaxed);
    let wrapper = ListenerWrapper::new(listener, priority, id);

    let mut registry = self.listeners.write();
    let bucket = registry.entry(type_id).or_default();
    let insert_at = bucket.partition_point(|registered| registered.priority >= priority);
    bucket.insert(insert_at, wrapper);
    // Release the listener lock before touching the metrics lock.
    drop(registry);

    // Called only for its side effect: makes sure counters for `T` exist.
    let _ensured = self.counters_for::<T>();
    ListenerId::new(id, type_id)
}
pub fn on<T, F>(&self, listener: F) -> ListenerId
where
T: Event + 'static,
F: Fn(&T) + Send + Sync + 'static,
{
self.subscribe(move |event: &T| {
listener(event);
Ok(())
})
}
/// Registers an asynchronous `listener` for events of type `T` at
/// [`Priority::Normal`].
///
/// Delegates to [`Self::subscribe_async_with_priority`] and returns the
/// [`ListenerId`] it produces.
#[cfg(feature = "async")]
pub fn subscribe_async<T, F, Fut>(&self, listener: F) -> ListenerId
where
    T: Event + 'static,
    F: Fn(&T) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<(), ListenerError>> + Send + 'static,
{
    let default_priority = Priority::Normal;
    self.subscribe_async_with_priority(listener, default_priority)
}
/// Registers an asynchronous `listener` for events of type `T` with an
/// explicit `priority`.
///
/// The listener is inserted after every already-registered async listener
/// whose priority is greater than or equal to `priority`, so among equal
/// priorities the earlier registration stays first. Ids come from the same
/// counter as synchronous listeners. Metrics counters for `T` are created if
/// they do not exist yet.
#[cfg(feature = "async")]
pub fn subscribe_async_with_priority<T, F, Fut>(
    &self,
    listener: F,
    priority: Priority,
) -> ListenerId
where
    T: Event + 'static,
    F: Fn(&T) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<(), ListenerError>> + Send + 'static,
{
    let type_id = TypeId::of::<T>();
    let id = self.next_id.fetch_add(1, Ordering::Relaxed);
    let wrapper = AsyncListenerWrapper::new(listener, priority, id);

    let mut registry = self.async_listeners.write();
    let bucket = registry.entry(type_id).or_default();
    let insert_at = bucket.partition_point(|registered| registered.priority >= priority);
    bucket.insert(insert_at, wrapper);
    // Release the listener lock before touching the metrics lock.
    drop(registry);

    // Called only for its side effect: makes sure counters for `T` exist.
    let _ensured = self.counters_for::<T>();
    ListenerId::new(id, type_id)
}
/// Delivers `event` to every synchronous listener registered for `T`, in
/// stored order, and collects each listener's result.
///
/// The dispatch is recorded in the metrics before middleware is consulted, so
/// blocked events are counted too. If middleware rejects the event,
/// [`DispatchResult::blocked`] is returned and no listener runs.
///
/// The read lock on the listener registry is held while the listeners run.
/// NOTE(review): a listener that calls `subscribe`, `unsubscribe` or `clear`
/// on this same dispatcher would request the write lock while this read lock
/// is still held — confirm whether that can deadlock.
pub fn dispatch<T: Event>(&self, event: T) -> DispatchResult {
    self.update_metrics(&event);
    if !self.check_middleware(&event) {
        return DispatchResult::blocked();
    }

    // Named guard: it stays alive until the function returns.
    let registry = self.listeners.read();
    let results = match registry.get(&TypeId::of::<T>()) {
        Some(registered) => registered
            .iter()
            .map(|entry| (entry.handler)(&event))
            .collect::<Vec<_>>(),
        None => Vec::new(),
    };
    DispatchResult::new(results)
}
/// Delivers `event` to every asynchronous listener registered for `T`,
/// awaiting them one after another in stored order, and collects each result.
///
/// The dispatch is recorded in the metrics before middleware is consulted. If
/// middleware rejects the event, [`DispatchResult::blocked`] is returned and
/// no listener runs. Only the async registry is consulted here.
#[cfg(feature = "async")]
pub async fn dispatch_async<T: Event>(&self, event: T) -> DispatchResult {
    self.update_metrics(&event);
    if !self.check_middleware(&event) {
        return DispatchResult::blocked();
    }

    let type_id = TypeId::of::<T>();
    let mut handlers: Vec<AsyncHandler> = Vec::new();
    {
        // Copy the handler handles out so the lock guard is gone before the
        // first `.await`.
        let registry = self.async_listeners.read();
        if let Some(registered) = registry.get(&type_id) {
            handlers.extend(registered.iter().map(|entry| Arc::clone(&entry.handler)));
        }
    }

    let mut results = Vec::with_capacity(handlers.len());
    for callback in handlers {
        let outcome = callback(&event).await;
        results.push(outcome);
    }
    DispatchResult::new(results)
}
/// Dispatches `event` through [`Self::dispatch`] and discards the resulting
/// [`DispatchResult`].
pub fn emit<T: Event>(&self, event: T) {
    // The outcome is intentionally ignored.
    let _outcome = self.dispatch(event);
}
/// Adds a middleware predicate to the dispatcher's [`MiddlewareManager`].
///
/// `dispatch` and `dispatch_async` return a blocked result whenever the
/// manager's `process` call yields `false` for an event.
/// NOTE(review): how individual predicates combine is decided inside
/// `MiddlewareManager::process` — presumably any `false` blocks; confirm there.
pub fn add_middleware<F>(&self, middleware: F)
where
    F: Fn(&dyn Event) -> bool + Send + Sync + 'static,
{
    let mut chain = self.middleware.write();
    chain.add(middleware);
}
/// Removes the listener identified by `listener_id`.
///
/// The synchronous registry is searched first; with the `async` feature the
/// asynchronous registry is searched next. Returns `true` when a listener was
/// removed and `false` when no listener with that id exists for the id's
/// event type.
pub fn unsubscribe(&self, listener_id: ListenerId) -> bool {
    let removed_sync = {
        let mut registry = self.listeners.write();
        match registry.get_mut(&listener_id.type_id) {
            Some(bucket) => match bucket.iter().position(|entry| entry.id == listener_id.id) {
                Some(index) => {
                    let _removed = bucket.remove(index);
                    true
                }
                None => false,
            },
            None => false,
        }
    };
    if removed_sync {
        return true;
    }

    #[cfg(feature = "async")]
    {
        let removed_async = {
            let mut registry = self.async_listeners.write();
            match registry.get_mut(&listener_id.type_id) {
                Some(bucket) => {
                    match bucket.iter().position(|entry| entry.id == listener_id.id) {
                        Some(index) => {
                            let _removed = bucket.remove(index);
                            true
                        }
                        None => false,
                    }
                }
                None => false,
            }
        };
        if removed_async {
            return true;
        }
    }

    false
}
/// Returns how many listeners are currently registered for event type `T`:
/// synchronous ones plus, with the `async` feature, asynchronous ones.
#[must_use]
pub fn listener_count<T: Event + 'static>(&self) -> usize {
    let key = TypeId::of::<T>();

    let sync_total = self.listeners.read().get(&key).map_or(0, Vec::len);

    #[cfg(feature = "async")]
    let async_total = self.async_listeners.read().get(&key).map_or(0, Vec::len);
    #[cfg(not(feature = "async"))]
    let async_total = 0;

    sync_total + async_total
}
/// Returns a snapshot of the metrics for every event type that has counters.
///
/// Each entry is that type's counter snapshot with `listener_count` overwritten
/// by the number of listeners currently registered (sync plus, with the
/// `async` feature, async). Read locks on the metrics and listener registries
/// are held while the snapshot is assembled.
#[must_use]
pub fn metrics(&self) -> HashMap<TypeId, EventMetadata> {
    // Locks are taken in this fixed order: metrics, sync listeners, async listeners.
    let counters_map = self.metrics.read();
    let sync_registry = self.listeners.read();
    #[cfg(feature = "async")]
    let async_registry = self.async_listeners.read();

    let mut report = HashMap::with_capacity(counters_map.len());
    for (type_id, counters) in counters_map.iter() {
        let sync_total = sync_registry.get(type_id).map_or(0, Vec::len);
        #[cfg(feature = "async")]
        let async_total = async_registry.get(type_id).map_or(0, Vec::len);
        #[cfg(not(feature = "async"))]
        let async_total = 0;

        let mut snapshot = counters.snapshot();
        snapshot.listener_count = sync_total + async_total;
        let _previous = report.insert(*type_id, snapshot);
    }
    report
}
/// Removes every registered listener, synchronous and (with the `async`
/// feature) asynchronous. Metrics counters and middleware are left untouched.
pub fn clear(&self) {
    {
        let mut registry = self.listeners.write();
        registry.clear();
    }
    #[cfg(feature = "async")]
    {
        let mut registry = self.async_listeners.write();
        registry.clear();
    }
}
/// Records one dispatch against the counters for event type `T`, creating the
/// counters first if needed. The event value itself is not inspected.
fn update_metrics<T: Event>(&self, _event: &T) {
    self.counters_for::<T>().record_dispatch();
}
/// Returns the shared counters for event type `T`, creating them on first use.
///
/// The lookup first runs under a read lock. Only when no counters exist yet is
/// the write lock taken, and the entry API then inserts only if the slot is
/// still vacant.
fn counters_for<T: Event + 'static>(&self) -> Arc<EventMetricsCounters> {
    let type_id = TypeId::of::<T>();

    // The read guard is a temporary of this statement and is released before
    // the write lock below is requested.
    let cached = self.metrics.read().get(&type_id).map(Arc::clone);
    if let Some(found) = cached {
        return found;
    }

    let mut table = self.metrics.write();
    let slot = table
        .entry(type_id)
        .or_insert_with(|| Arc::new(EventMetricsCounters::new::<T>()));
    Arc::clone(slot)
}
/// Runs `event` through the middleware manager under a read lock and returns
/// its verdict; callers treat `false` as "blocked".
fn check_middleware(&self, event: &dyn Event) -> bool {
    let chain = self.middleware.read();
    chain.process(event)
}
}
impl Default for EventDispatcher {
fn default() -> Self {
Self::new()
}
}