use std::{
any::{Any, TypeId},
sync::Arc,
};
use arc_swap::ArcSwap;
use funcext::FuncExt;
use futures_util::{FutureExt, future::join_all};
use scc::HashMap;
/// Shared, atomically swappable slot holding a single type-erased value.
///
/// Cloning a `DynContext` clones the outer `Arc`, so every clone refers to the
/// same slot: a value stored through one clone is seen by all of them.
/// A freshly created context holds `()`.
#[derive(Clone)]
pub struct DynContext(Arc<ArcSwap<Arc<dyn Any + Send + Sync>>>);

impl Default for DynContext {
    /// Same as [`DynContext::new`].
    fn default() -> Self {
        DynContext::new()
    }
}

impl DynContext {
    /// Creates a context whose slot initially holds `()`.
    pub fn new() -> Self {
        let initial: Arc<dyn Any + Send + Sync> = Arc::new(());
        let slot = ArcSwap::from_pointee(initial);
        DynContext(Arc::new(slot))
    }

    /// Replaces the stored value with `val`, erasing its concrete type.
    pub fn set<T: Send + Sync + 'static>(&self, val: T) {
        let erased: Arc<dyn Any + Send + Sync> = Arc::new(val);
        self.store(erased);
    }

    /// Replaces the stored value with an already type-erased `Arc`.
    pub fn store(&self, val: Arc<dyn Any + Send + Sync>) {
        // `ArcSwap` stores `Arc<T>`, and `T` here is itself an `Arc<dyn Any ...>`,
        // hence the extra layer of wrapping.
        let wrapped = Arc::new(val);
        self.0.store(wrapped);
    }

    /// Returns a new handle to the value currently stored in the slot.
    pub fn load(&self) -> Arc<dyn Any + Send + Sync> {
        let guard = self.0.load();
        Arc::clone(&**guard)
    }

    /// Returns the stored value if it is a `T`, or `None` if it is some other type.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
        match self.load().downcast::<T>() {
            Ok(value) => Some(value),
            Err(_) => None,
        }
    }
}
/// Handler argument that extracts the value currently stored in the
/// dispatcher's [`DynContext`].
///
/// Extraction yields nothing when the stored value is not a `T`.
pub struct Context<T>(pub Arc<T>);

// Written by hand because `#[derive(Clone)]` would add a needless `T: Clone`
// bound; cloning only bumps the `Arc` reference count.
impl<T> Clone for Context<T> {
    fn clone(&self) -> Self {
        Context(Arc::clone(&self.0))
    }
}

/// `Context<T>` is a thin smart-pointer wrapper, so it dereferences to the
/// shared `T` (like the `Arc` it holds).
impl<T> std::ops::Deref for Context<T> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.0
    }
}
#[async_trait::async_trait]
pub trait FromEvent<E>: Send + Sync + Sized {
async fn from_event(event: Arc<E>, extensions: &DynContext) -> Option<Self>;
}
#[async_trait::async_trait]
impl<E: Send + Sync + 'static> FromEvent<E> for Arc<E> {
async fn from_event(event: Arc<E>, _extensions: &DynContext) -> Option<Self> {
Some(event)
}
}
#[async_trait::async_trait]
impl<E: Send + Sync + 'static, T: Send + Sync + 'static> FromEvent<E> for Context<T> {
async fn from_event(_event: Arc<E>, extensions: &DynContext) -> Option<Self> {
extensions.get::<T>().map(Context)
}
}
/// A callable that reacts to events of type `E`.
///
/// `T` is the tuple of argument (extractor) types the callable accepts. It only
/// exists so that one function type can implement `Handler` for several
/// arities without the blanket impls below overlapping.
#[async_trait::async_trait]
pub trait Handler<E, T>: Send + Sync {
    /// Extracts each argument from `event`/`extensions` and, if every
    /// extraction succeeds, invokes the handler and awaits it.
    async fn handle(&self, event: Arc<E>, extensions: &DynContext);
}

/// Implements [`Handler`] for every `Fn` whose parameters are the listed
/// extractor types, each of which must implement [`FromEvent`].
macro_rules! impl_handler {
    ($($ty:ident),*) => {
        #[async_trait::async_trait]
        // Each extractor type parameter is reused as the name of the local that
        // holds its extracted value, which would trip `non_snake_case`; the
        // zero-argument expansion never reads `event`/`extensions`, which would
        // trip `unused_variables`.
        #[allow(non_snake_case, unused_variables)]
        impl<E, F, Fut, $($ty,)*> Handler<E, ($($ty,)*)> for F
        where
            E: Send + Sync + 'static,
            F: Fn($($ty,)*) -> Fut + Send + Sync + 'static,
            Fut: std::future::Future<Output = ()> + Send + 'static,
            $($ty: FromEvent<E> + Send + Sync + 'static,)*
        {
            async fn handle(&self, event: Arc<E>, extensions: &DynContext) {
                // Extract arguments in declaration order; the first extractor
                // that yields `None` skips the handler (and any remaining
                // extractors) for this event.
                $(
                    let $ty = match $ty::from_event(event.clone(), extensions).await {
                        Some(val) => val,
                        None => return,
                    };
                )*
                (self)($($ty,)*).await;
            }
        }
    };
}
impl_handler!();
impl_handler!(T1);
impl_handler!(T1, T2);
impl_handler!(T1, T2, T3);
impl_handler!(T1, T2, T3, T4);
impl_handler!(T1, T2, T3, T4, T5);
impl_handler!(T1, T2, T3, T4, T5, T6);
impl_handler!(T1, T2, T3, T4, T5, T6, T7);
impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8);
/// Object-safe, type-erased view of a [`Handler`], so handlers for different
/// event and argument types can be stored in one collection.
#[async_trait::async_trait]
trait ErasedHandler: Send + Sync {
    /// Runs the wrapped handler if `e` is the event type it was registered
    /// for; otherwise does nothing.
    async fn handle_erased(&self, e: Arc<dyn Any + Send + Sync>, extensions: &DynContext);
}

/// Pairs a concrete handler with the event type `E` and argument tuple `T`
/// it was registered under.
struct HandlerWrapper<E, T, H> {
    handler: H,
    // `fn(E, T)` names the type parameters without owning an `E` or `T`;
    // function pointers are always `Send + Sync`, so the wrapper's auto
    // traits depend only on `H`.
    _marker: std::marker::PhantomData<fn(E, T)>,
}

#[async_trait::async_trait]
impl<E, T, H> ErasedHandler for HandlerWrapper<E, T, H>
where
    E: Any + Send + Sync + 'static,
    T: Send + Sync + 'static,
    H: Handler<E, T> + Send + Sync + 'static,
{
    async fn handle_erased(&self, e: Arc<dyn Any + Send + Sync>, extensions: &DynContext) {
        let Ok(event) = e.downcast::<E>() else {
            return;
        };
        self.handler.handle(event, extensions).await;
    }
}
/// Routes each emitted event to the handlers registered for its concrete type.
pub struct EventDispatcher {
    /// Registered handlers, keyed by the `TypeId` of the event type they accept.
    handlers: HashMap<TypeId, Vec<Arc<dyn ErasedHandler>>>,
    /// Shared context handed to every handler invocation; `Context<T>`
    /// extractors read from it.
    pub extensions: DynContext,
}

impl Default for EventDispatcher {
    /// Same as [`EventDispatcher::new`].
    fn default() -> Self {
        EventDispatcher::new()
    }
}
impl EventDispatcher {
pub fn new() -> Self {
Self {
handlers: HashMap::new(),
extensions: DynContext::new(),
}
}
pub async fn register<E, T, H>(&self, handler: H)
where
E: Any + Send + Sync + 'static,
T: Send + Sync + 'static,
H: Handler<E, T> + Send + Sync + 'static,
{
self.handlers
.entry_async(TypeId::of::<E>())
.await
.or_default()
.push(Arc::new(HandlerWrapper {
handler,
_marker: std::marker::PhantomData,
}));
}
pub async fn emit<E: Send + Sync + 'static>(&self, event: E) {
self.emit_arc(Arc::new(event)).await;
}
pub async fn emit_arc<E: Send + Sync + 'static>(&self, event: Arc<E>) {
let Some(handlers) = self.handlers.get_async(&TypeId::of::<E>()).await else {
return;
};
let event: Arc<dyn Any + Send + Sync> = event;
handlers
.iter()
.map(|h| {
std::panic::AssertUnwindSafe(h.handle_erased(event.clone(), &self.extensions))
.catch_unwind()
})
.R(join_all)
.await;
}
}