#![cfg(feature = "hooks")]
use crate::AsyncError;
use downcast_rs::{impl_downcast, DowncastSync};
use once_cell::sync::Lazy;
use parking_lot::RwLock;
use std::{
any::TypeId,
collections::HashMap,
error::Error,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc,
},
};
/// Process-wide switch read by `AsyncErrorHookDefault::on_error`: when `true`, the default
/// report header is prefixed with the current time. Starts out `false`.
static TIMESTAMP_ENABLED: AtomicBool = AtomicBool::new(false);

/// Writes `enabled` into `TIMESTAMP_ENABLED` using sequentially consistent ordering.
fn set_hook_timestamps(enabled: bool) {
    TIMESTAMP_ENABLED.store(enabled, Ordering::SeqCst);
}

/// Turns on timestamp prefixes in the default hook report.
pub fn enable_hook_timestamps() {
    set_hook_timestamps(true);
}

/// Turns off timestamp prefixes in the default hook report.
pub fn disable_hook_timestamps() {
    set_hook_timestamps(false);
}
/// A callback that `invoke_hooks` notifies for each reported `AsyncError<E>`.
///
/// Implementors must be `Send + Sync + 'static` so they can live behind an `Arc` in the
/// process-wide registry. The `DowncastSync` supertrait, together with the `impl_downcast!`
/// invocation below, lets a `dyn AsyncErrorHook<E>` be downcast to its concrete type.
pub trait AsyncErrorHook<E>: Send + Sync + 'static + DowncastSync
where
    E: Error + 'static,
{
    /// Handles one reported error.
    fn on_error(&self, error: &AsyncError<E>);
}
impl_downcast!(sync AsyncErrorHook<E> where E: Error + 'static);
/// Supplies a ready-made `on_error` that writes a report for the error to stderr.
///
/// The blanket impl at the end of this block gives this trait to every `AsyncErrorHook<E>`
/// implementor. `AsyncErrorHook` declares a method with the same name, so on a concrete type
/// with both traits in scope, call this one with fully qualified syntax:
/// `AsyncErrorHookDefault::on_error(hook, error)`.
pub trait AsyncErrorHookDefault<E>: AsyncErrorHook<E>
where
    E: Error + 'static,
{
    /// Prints a multi-line report for `error` to stderr.
    ///
    /// The header reads `AsyncError Hook Triggered`. After `enable_hook_timestamps` has been
    /// called, the header gets a timestamp prefix:
    /// - with the `chrono` feature, the local time formatted as `%Y-%m-%d %H:%M:%S`;
    /// - otherwise, `[<whole seconds since UNIX_EPOCH>]`, or `[time unknown]` if the clock
    ///   reads earlier than the epoch.
    ///
    /// The report also includes the error's context (`<none>` when `error.context()` has no
    /// value) and its inner error.
    fn on_error(&self, error: &AsyncError<E>) {
        let header = if !TIMESTAMP_ENABLED.load(Ordering::SeqCst) {
            "AsyncError Hook Triggered".to_string()
        } else {
            #[cfg(feature = "chrono")]
            {
                use chrono::Local;
                let local_now = Local::now();
                format!(
                    "{} | AsyncError Hook Triggered",
                    local_now.format("%Y-%m-%d %H:%M:%S")
                )
            }
            #[cfg(not(feature = "chrono"))]
            {
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .map(|elapsed| {
                        format!("[{}] | AsyncError Hook Triggered", elapsed.as_secs())
                    })
                    .unwrap_or_else(|_| "[time unknown] | AsyncError Hook Triggered".to_string())
            }
        };
        // Build the full report first, then emit it with a single `eprintln!`.
        let report = format!(
            "{}\n Context: {}\n Inner error: {}\n------------------------------",
            header,
            error.context().unwrap_or("<none>"),
            error.inner_error()
        );
        eprintln!("{}", report);
    }
}

impl<E, T> AsyncErrorHookDefault<E> for T
where
    E: Error + 'static,
    T: AsyncErrorHook<E>,
{
}
/// The hooks registered for one error type `E`, stored type-erased inside `GLOBAL_HOOKS`.
struct HookRegistry<E>
where
    E: Error + 'static,
{
    /// Hooks in registration order; `register_hook` appends and skips `Arc::ptr_eq` duplicates.
    hooks: Vec<Arc<dyn AsyncErrorHook<E>>>,
}
/// Process-wide hook storage, created lazily on first use.
///
/// `register_hook` stores, under the key `TypeId::of::<E>()`, a boxed `HookRegistry<E>`.
/// Values are type-erased as `dyn Any` so that registries for different error types can
/// share one map; readers downcast back to `HookRegistry<E>`.
static GLOBAL_HOOKS: Lazy<RwLock<HashMap<TypeId, Box<dyn std::any::Any + Send + Sync>>>> =
    Lazy::new(|| RwLock::new(HashMap::default()));
/// Registers `hook` to be run by `invoke_hooks` for errors of type `AsyncError<E>`.
///
/// Hooks are stored per error type `E`. If an `Arc` equal by `Arc::ptr_eq` to `hook` is
/// already registered for `E`, the call has no effect.
///
/// # Panics
///
/// Panics if the registry slot for `E` does not hold a `HookRegistry<E>`. Since slots are
/// only ever created here, keyed by `TypeId::of::<E>()`, that would mean an internal
/// invariant was broken.
pub fn register_hook<E: Error + 'static>(hook: Arc<dyn AsyncErrorHook<E>>) {
    let mut map = GLOBAL_HOOKS.write();
    let slot = map
        .entry(TypeId::of::<E>())
        .or_insert_with(|| Box::new(HookRegistry::<E> { hooks: Vec::new() }));
    let registry: &mut HookRegistry<E> = slot
        .downcast_mut()
        .expect("Type mismatch in global hooks registry");

    let already_registered = registry
        .hooks
        .iter()
        .any(|existing| Arc::ptr_eq(existing, &hook));
    if !already_registered {
        registry.hooks.push(hook);
    }
}
/// Returns a snapshot of the hooks currently registered for error type `E`.
///
/// The `Arc`s are cloned into an owned `Vec`, so the registry's read lock is released when
/// this function returns. The result is empty if no hook has been registered for `E`.
pub fn get_hooks<E: Error + 'static>() -> Vec<Arc<dyn AsyncErrorHook<E>>> {
    let map = GLOBAL_HOOKS.read();
    match map.get(&TypeId::of::<E>()) {
        Some(slot) => match slot.downcast_ref::<HookRegistry<E>>() {
            Some(registry) => registry.hooks.clone(),
            None => Vec::new(),
        },
        None => Vec::new(),
    }
}
/// Dispatch flag for `invoke_hooks`: `1` while some call is running hooks, otherwise `0`.
static HOOK_INVOKE_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Runs every hook registered for `E` against `error`, in registration order.
///
/// Dispatch is guarded by the process-wide `HOOK_INVOKE_COUNTER` flag. If the flag is
/// already set, this call returns without running any hooks. That happens when a hook on
/// this thread calls `invoke_hooks` again, and also when another thread is dispatching at
/// the same moment.
///
/// The flag is cleared by a drop guard, so it is reset even when a hook panics and the
/// panic is caught further up. Otherwise a single panicking hook would leave the flag set
/// and silently disable every later dispatch.
pub fn invoke_hooks<E: Error + 'static>(error: &AsyncError<E>) {
    /// Clears `HOOK_INVOKE_COUNTER` on drop, including while unwinding from a panic.
    struct ResetGuard;

    impl Drop for ResetGuard {
        fn drop(&mut self) {
            HOOK_INVOKE_COUNTER.store(0, Ordering::Release);
        }
    }

    // Acquire pairs with the Release store in `ResetGuard::drop` of the previous dispatch.
    if HOOK_INVOKE_COUNTER
        .compare_exchange(0, 1, Ordering::Acquire, Ordering::Relaxed)
        .is_err()
    {
        return;
    }
    let _reset = ResetGuard;

    // `get_hooks` returns an owned snapshot, so no registry lock is held while hooks run.
    for hook in get_hooks::<E>() {
        hook.on_error(error);
    }
}