use core::any::Any;
use core::fmt;
use core::future::Future;
use core::pin::Pin;
use std::sync::Arc;
use arc_swap::{ArcSwap, ArcSwapOption};
use crate::HandlerId;
use crate::future_ext::{CatchUnwind, JoinAll};
use crate::handler_id::HandlerIdGenerator;
use crate::panic::{PanicCallbackHolder, PanicInfo};
/// Heap-allocated, pinned, type-erased future that is `Send` and owns all of its data
/// (`'static`), so it can outlive the call that created it.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
pub mod guard;
pub use guard::AsyncHandlerGuard;
/// Type-erased, shareable handler: called with a borrowed event, it returns a boxed future.
/// Because the returned future is `'static`, it cannot keep borrowing the event after the
/// call returns; anything it needs from the event must be copied or cloned up front.
type StoredAsyncHandler<E> = Arc<dyn Fn(&E) -> BoxFuture<()> + Send + Sync + 'static>;
/// One registered handler together with its identifier and dispatch priority.
struct AsyncHandlerEntry<E>
where
    E: Send + Sync + 'static,
{
    /// Identifier handed back to the caller when the handler was registered.
    id: HandlerId,
    /// Dispatch priority; the registry keeps higher values earlier in its handler list.
    priority: i32,
    /// Type-erased callable that turns a borrowed event into a boxed future.
    handler: StoredAsyncHandler<E>,
}
// Written by hand rather than derived: `#[derive(Clone)]` would demand `E: Clone`, but only
// the reference count of the shared handler needs bumping.
impl<E> Clone for AsyncHandlerEntry<E>
where
    E: Send + Sync + 'static,
{
    /// Copies the id and priority and shares the same underlying handler.
    #[inline]
    fn clone(&self) -> Self {
        let Self {
            id,
            priority,
            handler,
        } = self;
        Self {
            id: *id,
            priority: *priority,
            handler: Arc::clone(handler),
        }
    }
}
impl<E> fmt::Debug for AsyncHandlerEntry<E>
where
    E: Send + Sync + 'static,
{
    /// Prints the id and priority; the handler closure itself is opaque and is elided.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("AsyncHandlerEntry");
        out.field("id", &self.id);
        out.field("priority", &self.priority);
        out.finish_non_exhaustive()
    }
}
/// A registry of asynchronous handlers that are notified with a borrowed event of type `E`.
///
/// The handler list is held in an `ArcSwap`, so readers take a snapshot while writers
/// publish a replacement list.
pub struct AsyncRegistry<E>
where
    E: Send + Sync + 'static,
{
    /// Registered handlers, ordered by descending priority.
    handlers: ArcSwap<Vec<AsyncHandlerEntry<E>>>,
    /// Source of ids for newly registered handlers.
    id_generator: HandlerIdGenerator,
    /// Optional callback invoked when a handler panics.
    panic_callback: ArcSwapOption<PanicCallbackHolder>,
}
impl<E: Send + Sync + 'static> AsyncRegistry<E> {
#[must_use]
pub fn new() -> Self {
Self {
handlers: ArcSwap::from_pointee(Vec::new()),
id_generator: HandlerIdGenerator::new(),
panic_callback: ArcSwapOption::empty(),
}
}
#[must_use]
pub fn with_capacity(capacity: usize) -> Self {
Self {
handlers: ArcSwap::from_pointee(Vec::with_capacity(capacity)),
id_generator: HandlerIdGenerator::new(),
panic_callback: ArcSwapOption::empty(),
}
}
pub fn register<F, Fut>(&self, handler: F) -> HandlerId
where
F: Fn(&E) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
self.register_with_priority(0, handler)
}
pub fn register_with_priority<F, Fut>(&self, priority: i32, handler: F) -> HandlerId
where
F: Fn(&E) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let id = self.id_generator.next();
let boxed: StoredAsyncHandler<E> = Arc::new(move |event: &E| {
let fut = handler(event);
Box::pin(fut) as BoxFuture<()>
});
let entry = AsyncHandlerEntry {
id,
priority,
handler: boxed,
};
drop(self.handlers.rcu(|current| {
let mut new_vec: Vec<AsyncHandlerEntry<E>> = Vec::with_capacity(current.len() + 1);
new_vec.extend(current.iter().cloned());
let pos = new_vec.partition_point(|e| e.priority >= entry.priority);
new_vec.insert(pos, entry.clone());
Arc::new(new_vec)
}));
id
}
pub fn register_guard<F, Fut>(self: &Arc<Self>, handler: F) -> AsyncHandlerGuard<E>
where
F: Fn(&E) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let id = self.register(handler);
AsyncHandlerGuard::new(id, Arc::downgrade(self))
}
pub fn register_guard_with_priority<F, Fut>(
self: &Arc<Self>,
priority: i32,
handler: F,
) -> AsyncHandlerGuard<E>
where
F: Fn(&E) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let id = self.register_with_priority(priority, handler);
AsyncHandlerGuard::new(id, Arc::downgrade(self))
}
pub fn unregister(&self, id: HandlerId) -> bool {
let mut removed = false;
drop(self.handlers.rcu(|current| {
let mut new_vec: Vec<AsyncHandlerEntry<E>> = Vec::with_capacity(current.len());
new_vec.extend(current.iter().filter(|e| e.id != id).cloned());
removed = new_vec.len() != current.len();
Arc::new(new_vec)
}));
removed
}
pub fn clear(&self) {
self.handlers.store(Arc::new(Vec::new()));
}
#[inline]
#[must_use]
pub fn handler_count(&self) -> usize {
self.handlers.load().len()
}
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
self.handlers.load().is_empty()
}
#[must_use]
pub fn contains(&self, id: HandlerId) -> bool {
self.handlers.load().iter().any(|e| e.id == id)
}
pub fn on_panic<F>(&self, callback: F)
where
F: Fn(&PanicInfo<'_>) + Send + Sync + 'static,
{
let holder = Arc::new(PanicCallbackHolder::new(callback));
self.panic_callback.store(Some(holder));
}
pub fn clear_panic_callback(&self) {
self.panic_callback.store(None);
}
pub async fn notify(&self, event: &E) {
let snapshot = self.handlers.load();
if snapshot.is_empty() {
return;
}
let n = snapshot.len();
let mut ids: Vec<HandlerId> = Vec::with_capacity(n);
let mut wrapped = Vec::with_capacity(n);
for entry in snapshot.iter() {
ids.push(entry.id);
wrapped.push(CatchUnwind::new((entry.handler)(event)));
}
let results = JoinAll::new(wrapped).await;
for (id, outcome) in ids.into_iter().zip(results) {
if let Err(payload) = outcome {
self.handle_panic(id, payload);
}
}
}
pub async fn notify_sequential(&self, event: &E) {
let snapshot = self.handlers.load();
for entry in snapshot.iter() {
let fut = (entry.handler)(event);
match CatchUnwind::new(fut).await {
Ok(()) => {}
Err(payload) => self.handle_panic(entry.id, payload),
}
}
}
#[cold]
fn handle_panic(&self, handler_id: HandlerId, payload: Box<dyn Any + Send + 'static>) {
let guard = self.panic_callback.load();
if let Some(holder) = guard.as_ref() {
let info = PanicInfo::new(handler_id, payload.as_ref());
drop(std::panic::catch_unwind(std::panic::AssertUnwindSafe(
|| {
holder.invoke(&info);
},
)));
}
drop(payload);
}
}
impl<E: Send + Sync + 'static> Default for AsyncRegistry<E> {
fn default() -> Self {
Self::new()
}
}
impl<E> fmt::Debug for AsyncRegistry<E>
where
    E: Send + Sync + 'static,
{
    /// Shows how many handlers are registered and whether a panic callback is installed.
    /// The handlers themselves are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let handler_count = self.handlers.load().len();
        let has_panic_callback = self.panic_callback.load().is_some();
        f.debug_struct("AsyncRegistry")
            .field("handler_count", &handler_count)
            .field("has_panic_callback", &has_panic_callback)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::AsyncRegistry;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};

    #[tokio::test]
    async fn empty_registry_notify_is_noop() {
        let registry: AsyncRegistry<u32> = AsyncRegistry::new();
        registry.notify(&42).await;
        registry.notify_sequential(&42).await;
    }

    #[tokio::test]
    async fn concurrent_notify_fires_every_handler_once() {
        let registry: AsyncRegistry<u32> = AsyncRegistry::new();
        let count = Arc::new(AtomicU32::new(0));
        for _ in 0..5 {
            let sink = Arc::clone(&count);
            let _ = registry.register(move |_| {
                let sink = Arc::clone(&sink);
                async move {
                    let _ = sink.fetch_add(1, Ordering::Relaxed);
                }
            });
        }
        registry.notify(&0).await;
        assert_eq!(count.load(Ordering::Relaxed), 5);
    }

    #[tokio::test]
    async fn unregister_removes_only_the_target_handler() {
        let registry: AsyncRegistry<u32> = AsyncRegistry::new();
        let count = Arc::new(AtomicU32::new(0));
        let ids: Vec<_> = (0..2)
            .map(|_| {
                let sink = Arc::clone(&count);
                registry.register(move |_| {
                    let sink = Arc::clone(&sink);
                    async move {
                        let _ = sink.fetch_add(1, Ordering::Relaxed);
                    }
                })
            })
            .collect();

        assert!(registry.unregister(ids[0]));
        // A second removal of the same id finds nothing.
        assert!(!registry.unregister(ids[0]));
        assert!(!registry.contains(ids[0]));
        assert!(registry.contains(ids[1]));
        assert_eq!(registry.handler_count(), 1);

        registry.notify(&0).await;
        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn clear_drops_every_handler() {
        let registry: AsyncRegistry<u32> = AsyncRegistry::new();
        let count = Arc::new(AtomicU32::new(0));
        let sink = Arc::clone(&count);
        let id = registry.register(move |_| {
            let sink = Arc::clone(&sink);
            async move {
                let _ = sink.fetch_add(1, Ordering::Relaxed);
            }
        });

        registry.clear();
        assert!(registry.is_empty());
        assert!(!registry.contains(id));

        registry.notify(&0).await;
        registry.notify_sequential(&0).await;
        assert_eq!(count.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn sequential_notify_runs_in_priority_order() {
        let registry: AsyncRegistry<u32> = AsyncRegistry::new();
        let order = Arc::new(Mutex::new(Vec::new()));
        for (priority, tag) in [(1, 'a'), (5, 'b'), (0, 'c'), (5, 'd')] {
            let log = Arc::clone(&order);
            let _ = registry.register_with_priority(priority, move |_| {
                let log = Arc::clone(&log);
                async move {
                    log.lock().unwrap().push(tag);
                }
            });
        }
        registry.notify_sequential(&0).await;
        // Higher priority first; equal priorities keep registration order.
        assert_eq!(*order.lock().unwrap(), vec!['b', 'd', 'a', 'c']);
    }
}