use std::sync::{Arc, Mutex};
/// Type of a process-wide panic hook, as accepted by `std::panic::set_hook` and
/// returned by `std::panic::take_hook`.
type PanicHookType = dyn Fn(&std::panic::PanicInfo<'_>) + Send + Sync + 'static;

/// Scope guard that runs `callback` at most once: from the panic hook the first time a
/// panic is reported (before the previously installed hook runs), or when the guard is
/// dropped, whichever comes first.
///
/// While the guard is alive, the process-wide panic hook is replaced by a wrapper that
/// first runs the callback and then forwards to the hook that was installed when the
/// guard was created. Because the hook is process-wide, a panic on *any* thread
/// triggers the callback.
///
/// When dropped outside of a panic, the guard re-installs (a wrapper around) the hook
/// captured at construction, discarding any hook set in the meantime. When dropped
/// while the thread is unwinding, the hook is left untouched, because
/// `std::panic::set_hook` must not be called from a panicking thread.
///
/// Nested guards should be dropped in reverse order of creation, which is what plain
/// scoping does. If an earlier guard is dropped first, it re-installs a hook chain that
/// no longer includes the later guards' callbacks.
#[must_use = "dropping the guard immediately runs the callback and restores the previous panic hook"]
pub struct GuardWithHook<F>
where
    F: FnOnce() + Send + 'static,
{
    /// Hook that was installed when the guard was created; the guard's hook forwards to it.
    original_hook: Arc<PanicHookType>,
    /// Pending callback, `None` once it has been taken. Shared with the installed hook.
    callback: Arc<Mutex<Option<F>>>,
}

impl<F> GuardWithHook<F>
where
    F: FnOnce() + Send + 'static,
{
    /// Installs a panic hook that runs `callback` and then delegates to the current hook.
    ///
    /// Returns a guard that runs `callback` on drop if a panic has not already run it.
    ///
    /// # Panics
    ///
    /// Panics if called from a thread that is already panicking, as
    /// `std::panic::take_hook` and `std::panic::set_hook` do.
    pub fn new(callback: F) -> Self {
        let callback = Arc::new(Mutex::new(Some(callback)));
        let original_hook: Arc<PanicHookType> = Arc::from(std::panic::take_hook());

        let hook_callback = Arc::clone(&callback);
        let next_hook = Arc::clone(&original_hook);
        std::panic::set_hook(Box::new(move |info| {
            // Run the cleanup before chaining to the previous hook (e.g. the default
            // one that prints the panic message).
            Self::run_callback_once(&hook_callback);
            (*next_hook)(info);
        }));

        Self {
            original_hook,
            callback,
        }
    }

    /// Takes the pending callback out of `slot` and runs it, if it is still there.
    ///
    /// The lock is held only for `Option::take`, never while the callback runs.
    /// The mutex therefore cannot be poisoned by a panicking callback, and user code
    /// never runs under the lock. `try_lock` is used so the panic hook never blocks. A
    /// contended lock means another caller is taking the callback at that moment, so
    /// skipping is correct.
    fn run_callback_once(slot: &Mutex<Option<F>>) {
        let pending = slot.try_lock().ok().and_then(|mut guard| guard.take());
        if let Some(callback) = pending {
            callback();
        }
    }
}

impl<F> Drop for GuardWithHook<F>
where
    F: FnOnce() + Send + 'static,
{
    /// Restores the captured hook (unless unwinding), then runs the callback if no panic
    /// hook has run it yet.
    fn drop(&mut self) {
        // `set_hook` panics on a panicking thread, and a panic during unwinding aborts
        // the process, so the hook is only restored on a normal drop.
        if !std::thread::panicking() {
            let original_hook = Arc::clone(&self.original_hook);
            std::panic::set_hook(Box::new(move |info| (*original_hook)(info)));
        }
        Self::run_callback_once(&self.callback);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The panic hook is process-global state, so these tests must not run concurrently.
    static TEST_MUTEX: Mutex<()> = Mutex::new(());

    /// Acquires `TEST_MUTEX`, recovering from poisoning so that a single failing test
    /// does not make every later test fail with a `PoisonError`.
    fn serialize_tests() -> std::sync::MutexGuard<'static, ()> {
        TEST_MUTEX
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn test_callback_called_on_drop() {
        let _m = serialize_tests();
        let original_hook = std::panic::take_hook();
        let called = Arc::new(Mutex::new(0));
        {
            let called = called.clone();
            let _guard = GuardWithHook::new(move || {
                *called.lock().unwrap() += 1;
            });
        }
        std::panic::set_hook(original_hook);
        assert_eq!(*called.lock().unwrap(), 1);
    }

    #[test]
    fn test_callback_called_once_only_panic() {
        let _m = serialize_tests();
        let original_hook = std::panic::take_hook();
        let calls = Arc::new(Mutex::new(0));
        let calls2 = calls.clone();
        {
            let _guard = GuardWithHook::new(move || {
                *calls2.lock().unwrap() += 1;
            });
            assert!(std::panic::catch_unwind(|| panic!("test")).is_err());
        }
        // The guard is dropped before the hook is restored; otherwise its drop would
        // re-install its own hook afterwards.
        std::panic::set_hook(original_hook);
        // Checked after the drop, so a second invocation on drop would be caught.
        assert_eq!(*calls.lock().unwrap(), 1);
    }

    #[test]
    fn test_callback_called_before_panic_hook() {
        let _m = serialize_tests();
        let original_hook = std::panic::take_hook();
        let calls: Arc<Mutex<Vec<&str>>> = Arc::new(Mutex::new(vec![]));
        let calls2 = calls.clone();
        let calls3 = calls.clone();
        std::panic::set_hook(Box::new(move |_| calls2.lock().unwrap().push("hook")));
        {
            let _guard = GuardWithHook::new(move || calls3.lock().unwrap().push("cleanup"));
            assert!(std::panic::catch_unwind(|| panic!("test")).is_err());
        }
        // Restoring after the guard's drop keeps the test's recording hook from leaking
        // into later tests.
        std::panic::set_hook(original_hook);
        assert_eq!(*calls.lock().unwrap(), vec!["cleanup", "hook"]);
    }

    #[test]
    fn test_nested_callback() {
        let _m = serialize_tests();
        let original_hook = std::panic::take_hook();
        let calls: Arc<Mutex<Vec<&str>>> = Arc::new(Mutex::new(vec![]));
        let calls2 = calls.clone();
        let calls3 = calls.clone();
        {
            let _outer = GuardWithHook::new(move || calls2.lock().unwrap().push("outer"));
            {
                let _inner = GuardWithHook::new(move || calls3.lock().unwrap().push("inner"));
            }
        }
        std::panic::set_hook(original_hook);
        assert_eq!(*calls.lock().unwrap(), vec!["inner", "outer"]);
    }

    #[test]
    fn test_nested_callback_with_panic() {
        let _m = serialize_tests();
        let original_hook = std::panic::take_hook();
        let calls: Arc<Mutex<Vec<&str>>> = Arc::new(Mutex::new(vec![]));
        let calls2 = calls.clone();
        let calls3 = calls.clone();
        {
            let _outer = GuardWithHook::new(move || calls2.lock().unwrap().push("outer"));
            {
                let _inner = GuardWithHook::new(move || calls3.lock().unwrap().push("inner"));
                // The panic runs the inner callback, then the outer one via the hook chain.
                assert!(std::panic::catch_unwind(|| panic!("test")).is_err());
            }
        }
        // Both guards are gone, so this restore is not undone by a later drop.
        std::panic::set_hook(original_hook);
        // Dropping the guards must not run either callback a second time.
        assert_eq!(*calls.lock().unwrap(), vec!["inner", "outer"]);
    }

    #[test]
    fn test_nested_callback_hook_restored() {
        let _m = serialize_tests();
        let original_hook = std::panic::take_hook();
        let calls: Arc<Mutex<Vec<&str>>> = Arc::new(Mutex::new(vec![]));
        let calls2 = calls.clone();
        let calls3 = calls.clone();
        let calls4 = calls.clone();
        let calls5 = calls.clone();
        std::panic::set_hook(Box::new(move |_| calls2.lock().unwrap().push("outer hook")));
        {
            let _outer = GuardWithHook::new(move || calls3.lock().unwrap().push("inner cleanup"));
            {
                let _inner =
                    GuardWithHook::new(move || calls4.lock().unwrap().push("inner inner cleanup"));
                // Installed while guards are alive; the guards' drops must discard it.
                std::panic::set_hook(Box::new(move |_| calls5.lock().unwrap().push("wrong")));
            }
        }
        assert_eq!(
            *calls.lock().unwrap(),
            vec!["inner inner cleanup", "inner cleanup"]
        );
        assert!(std::panic::catch_unwind(|| panic!("test")).is_err());
        std::panic::set_hook(original_hook);
        assert_eq!(
            *calls.lock().unwrap(),
            vec!["inner inner cleanup", "inner cleanup", "outer hook"]
        );
    }
}