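//! Documentation: a type-keyed async event dispatcher whose handlers declare their inputs
//! through `FromEvent` extractors.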
use std::{
	any::{Any, TypeId},
	sync::Arc,
};

use arc_swap::ArcSwap;
use funcext::FuncExt;
use futures_util::{FutureExt, future::join_all};
use scc::HashMap;

/// A cheaply clonable handle to a single, atomically replaceable, type-erased value.
///
/// All clones share one slot: a value stored through any handle is what every other handle
/// loads next. The slot holds exactly one value at a time, so storing a value of a new type
/// replaces the previous value rather than adding to it.
#[derive(Clone)]
pub struct DynContext(Arc<ArcSwap<Arc<dyn Any + Send + Sync>>>);

impl Default for DynContext {
	/// Equivalent to [`DynContext::new`].
	fn default() -> Self {
		DynContext::new()
	}
}

impl DynContext {
	/// Creates a context whose initial value is the unit value `()`.
	pub fn new() -> Self {
		let initial: Arc<dyn Any + Send + Sync> = Arc::new(());
		let slot = ArcSwap::from_pointee(initial);
		DynContext(Arc::new(slot))
	}

	/// Stores `val` (wrapped in a fresh `Arc`), replacing whatever value was there before.
	pub fn set<T: Send + Sync + 'static>(&self, val: T) {
		let erased: Arc<dyn Any + Send + Sync> = Arc::new(val);
		self.store(erased);
	}

	/// Replaces the current value with the already type-erased `val`.
	pub fn store(&self, val: Arc<dyn Any + Send + Sync>) {
		let wrapped = Arc::new(val);
		self.0.store(wrapped);
	}

	/// Returns a new handle to the value currently stored.
	pub fn load(&self) -> Arc<dyn Any + Send + Sync> {
		let guard = self.0.load();
		Arc::clone(&**guard)
	}

	/// Returns the current value as a `T`, or `None` if the stored value has another type.
	pub fn get<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
		match self.load().downcast::<T>() {
			Ok(typed) => Some(typed),
			Err(_) => None,
		}
	}
}

/// Handler argument that yields the value currently stored in the dispatcher's
/// `DynContext`, provided that value has type `T`.
///
/// When the stored value is not a `T`, extraction fails and the handler is skipped.
pub struct Context<T>(pub Arc<T>);

// Manual impl: `#[derive(Clone)]` would add an unnecessary `T: Clone` bound. Cloning a
// `Context` only bumps the `Arc` reference count.
impl<T> Clone for Context<T> {
	fn clone(&self) -> Self {
		Context(Arc::clone(&self.0))
	}
}

// `Context` is a thin smart-pointer wrapper, so expose the wrapped value directly.
impl<T> std::ops::Deref for Context<T> {
	type Target = T;

	fn deref(&self) -> &T {
		&self.0
	}
}

#[async_trait::async_trait]
pub trait FromEvent<E>: Send + Sync + Sized {
	async fn from_event(event: Arc<E>, extensions: &DynContext) -> Option<Self>;
}

#[async_trait::async_trait]
impl<E: Send + Sync + 'static> FromEvent<E> for Arc<E> {
	async fn from_event(event: Arc<E>, _extensions: &DynContext) -> Option<Self> {
		Some(event)
	}
}

#[async_trait::async_trait]
impl<E: Send + Sync + 'static, T: Send + Sync + 'static> FromEvent<E> for Context<T> {
	async fn from_event(_event: Arc<E>, extensions: &DynContext) -> Option<Self> {
		extensions.get::<T>().map(Context)
	}
}

/// An async callback that reacts to events of type `E`.
///
/// `T` names the handler's argument list (a tuple of `FromEvent` types), so that callables
/// of different arities receive distinct implementations.
#[async_trait::async_trait]
pub trait Handler<E, T>: Send + Sync {
	/// Runs the handler for `event`, drawing any shared state from `extensions`.
	async fn handle(&self, event: Arc<E>, extensions: &DynContext);
}

/// Implements `Handler` for every `Fn` whose arguments all implement `FromEvent`.
///
/// The argument types are recorded in the `T` tuple of `Handler<E, T>`, which keeps the
/// impls generated for different arities from overlapping.
macro_rules! impl_handler {
	($($ty:ident),*) => {
		#[async_trait::async_trait]
		impl<E, F, Fut, $($ty,)*> Handler<E, ($($ty,)*)> for F
		where
			E: Send + Sync + 'static,
			F: Fn($($ty,)*) -> Fut + Send + Sync + 'static,
			Fut: std::future::Future<Output = ()> + Send + 'static,
			$($ty: FromEvent<E> + Send + Sync + 'static,)*
		{
			// The zero-argument expansion never reads `event` or `extensions`.
			#[allow(unused_variables)]
			async fn handle(&self, event: Arc<E>, extensions: &DynContext) {
				// Extract each argument in place instead of binding it to a local named after
				// its type parameter. Call arguments are evaluated left to right, so the
				// extractors still run in declaration order, and the first `None` returns
				// before the handler is called.
				(self)($(
					match <$ty as FromEvent<E>>::from_event(event.clone(), extensions).await {
						Some(val) => val,
						None => return,
					},
				)*)
				.await;
			}
		}
	};
}

impl_handler!();
impl_handler!(T1);
impl_handler!(T1, T2);
impl_handler!(T1, T2, T3);
impl_handler!(T1, T2, T3, T4);
impl_handler!(T1, T2, T3, T4, T5);
impl_handler!(T1, T2, T3, T4, T5, T6);
impl_handler!(T1, T2, T3, T4, T5, T6, T7);
impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8);

/// Object-safe form of a `Handler` that accepts a type-erased event.
#[async_trait::async_trait]
trait ErasedHandler: Send + Sync {
	/// Runs the wrapped handler if `e` is the concrete event type it was registered for.
	async fn handle_erased(&self, e: Arc<dyn Any + Send + Sync>, extensions: &DynContext);
}

/// Couples a concrete handler with the event and argument types it was registered under.
///
/// `PhantomData<fn(E, T)>` records those types without storing values of them; function
/// pointers are always `Send + Sync`, so the wrapper's auto traits follow `H` alone.
struct HandlerWrapper<E, T, H> {
	handler: H,
	_marker: std::marker::PhantomData<fn(E, T)>,
}

#[async_trait::async_trait]
impl<E, T, H> ErasedHandler for HandlerWrapper<E, T, H>
where
	E: Any + Send + Sync + 'static,
	T: Send + Sync + 'static,
	H: Handler<E, T> + Send + Sync + 'static,
{
	async fn handle_erased(&self, erased: Arc<dyn Any + Send + Sync>, extensions: &DynContext) {
		// An event of any other type is ignored.
		let Ok(event) = erased.downcast::<E>() else {
			return;
		};
		self.handler.handle(event, extensions).await;
	}
}

/// Routes each emitted event to the handlers registered for its concrete type.
pub struct EventDispatcher {
	/// Registered handlers, keyed by the `TypeId` of the event type they accept.
	handlers: HashMap<TypeId, Vec<Arc<dyn ErasedHandler>>>,
	/// Shared context that handlers can read through the `Context` extractor.
	pub extensions: DynContext,
}

impl Default for EventDispatcher {
	/// Equivalent to [`EventDispatcher::new`].
	fn default() -> Self {
		EventDispatcher::new()
	}
}

impl EventDispatcher {
	pub fn new() -> Self {
		Self {
			handlers: HashMap::new(),
			extensions: DynContext::new(),
		}
	}

	pub async fn register<E, T, H>(&self, handler: H)
	where
		E: Any + Send + Sync + 'static,
		T: Send + Sync + 'static,
		H: Handler<E, T> + Send + Sync + 'static,
	{
		self.handlers
			.entry_async(TypeId::of::<E>())
			.await
			.or_default()
			.push(Arc::new(HandlerWrapper {
				handler,
				_marker: std::marker::PhantomData,
			}));
	}

	pub async fn emit<E: Send + Sync + 'static>(&self, event: E) {
		self.emit_arc(Arc::new(event)).await;
	}

	pub async fn emit_arc<E: Send + Sync + 'static>(&self, event: Arc<E>) {
		let Some(handlers) = self.handlers.get_async(&TypeId::of::<E>()).await else {
			return;
		};
		let event: Arc<dyn Any + Send + Sync> = event;
		handlers
			.iter()
			.map(|h| {
				std::panic::AssertUnwindSafe(h.handle_erased(event.clone(), &self.extensions))
					.catch_unwind()
			})
			.R(join_all)
			.await;
	}
}