use std::panic::{AssertUnwindSafe, catch_unwind};
use crate::boundary::{PluginError, PluginErrorCode, PluginResult};
/// Runs the fallible plug-in closure `f` and converts any panic it raises into an error.
///
/// - `Ok(value)` returned by `f` becomes `PluginResult::Ok(value)`.
/// - `Err(err)` returned by `f` becomes `PluginResult::Err(err)`.
/// - An unwinding panic inside `f` becomes `PluginResult::Err` with code
///   `PluginErrorCode::Panic`. The message is taken from the panic payload when the
///   payload is a string; otherwise a fixed placeholder text is used.
///
/// `f` is wrapped in `AssertUnwindSafe`, so any state it shares with the caller may be
/// left logically inconsistent after a caught panic.
///
/// `catch_unwind` only intercepts *unwinding* panics. With an abort panic strategy the
/// process terminates instead of returning here.
pub fn guard<T>(f: impl FnOnce() -> Result<T, PluginError>) -> PluginResult<T> {
    let outcome = catch_unwind(AssertUnwindSafe(f)).unwrap_or_else(|payload| {
        // Read the message while the payload is still alive, then release the payload
        // through the helper that contains panics raised by its destructor.
        let message = panic_message(payload.as_ref());
        drop_payload(payload);
        Err(PluginError::new(PluginErrorCode::Panic, message))
    });
    match outcome {
        Ok(value) => PluginResult::Ok(value),
        Err(err) => PluginResult::Err(err),
    }
}
/// Runs the infallible plug-in thunk `f` and returns its value, aborting the process if
/// `f` panics.
///
/// The return type `T` has no error channel, so a panic cannot be reported to the
/// caller. Instead, the handling is:
/// 1. The panic message is logged at error level under target `nautilus_plugin`.
///    `thunk_name` appears only in that message.
/// 2. The process is terminated with `std::process::abort`.
pub fn guard_infallible<T>(thunk_name: &str, f: impl FnOnce() -> T) -> T {
    let payload = match catch_unwind(AssertUnwindSafe(f)) {
        Ok(value) => return value,
        Err(payload) => payload,
    };
    // Capture the text before the payload is released.
    let msg = panic_message(payload.as_ref());
    drop_payload(payload);
    log::error!(
        target: "nautilus_plugin",
        "plug-in panicked in `{thunk_name}` thunk; aborting process: {msg}",
    );
    std::process::abort()
}
/// Drops `payload` (typically a panic payload from `catch_unwind`) without letting a
/// panic from its destructor escape this function.
///
/// If the payload's destructor panics, that secondary panic is caught. Its own payload
/// is then leaked with `std::mem::forget` rather than dropped, because dropping it could
/// panic yet again. The `catch_unwind` documentation warns that dropping its `Err`
/// payload "may in turn panic". The leak only occurs in this double-fault case.
pub fn drop_payload(payload: Box<dyn std::any::Any + Send>) {
    if let Err(nested) = catch_unwind(AssertUnwindSafe(move || drop(payload))) {
        // Never drop the nested payload: its destructor is untrusted plug-in code and
        // could unwind out of here (e.g. across an FFI boundary).
        std::mem::forget(nested);
    }
}
/// Extracts a human-readable message from a panic payload.
///
/// Handles the two payload types `panic!` produces: `&'static str` (a plain literal)
/// and `String` (a formatted message). Any other payload type yields a fixed
/// placeholder text.
fn panic_message(payload: &(dyn std::any::Any + Send)) -> String {
    payload
        .downcast_ref::<&'static str>()
        .copied()
        .map(str::to_owned)
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "plug-in panicked with non-string payload".to_owned())
}
#[cfg(test)]
mod tests {
    // Shared by the payload-destructor tests below; imported once here instead of being
    // re-imported inside individual tests.
    use std::{
        any::Any,
        sync::atomic::{AtomicUsize, Ordering},
    };

    use rstest::rstest;

    use super::*;

    #[rstest]
    fn returns_ok_on_success() {
        let r = guard(|| Ok::<u32, PluginError>(7));
        assert_eq!(r.into_result().unwrap(), 7);
    }

    #[rstest]
    fn returns_err_on_returned_error() {
        let r = guard(|| Err::<u32, _>(PluginError::generic("boom")));
        let e = r.into_result().unwrap_err();
        assert_eq!(e.code, PluginErrorCode::Generic);
        assert_eq!(e.message_string(), "boom");
    }

    /// A bare `panic!("literal")` produces a `&'static str` payload.
    #[rstest]
    fn returns_err_on_string_panic() {
        let r = guard(|| -> Result<u32, PluginError> { panic!("oops") });
        let e = r.into_result().unwrap_err();
        assert_eq!(e.code, PluginErrorCode::Panic);
        assert!(e.message_string().contains("oops"));
    }

    /// A `panic!` with a runtime format argument produces a `String` payload, which
    /// exercises the second downcast in `panic_message`. A variable is used rather than
    /// a literal because rustc may inline literal arguments into a static string.
    #[rstest]
    fn returns_err_on_formatted_string_panic() {
        let r = guard(|| -> Result<u32, PluginError> {
            let n = 7;
            panic!("oops {n}")
        });
        let e = r.into_result().unwrap_err();
        assert_eq!(e.code, PluginErrorCode::Panic);
        assert!(e.message_string().contains("oops 7"));
    }

    #[rstest]
    fn returns_err_on_non_string_panic() {
        let r = guard(|| -> Result<u32, PluginError> {
            std::panic::panic_any(42_u32);
        });
        let e = r.into_result().unwrap_err();
        assert_eq!(e.code, PluginErrorCode::Panic);
        assert!(e.message_string().contains("non-string"));
    }

    #[rstest]
    fn guard_infallible_returns_inner_on_success() {
        let v = guard_infallible("test", || 42u64);
        assert_eq!(v, 42);
    }

    #[rstest]
    fn drop_payload_swallows_panicking_drop() {
        static DROPS_OBSERVED: AtomicUsize = AtomicUsize::new(0);
        struct Bomb;
        impl Drop for Bomb {
            fn drop(&mut self) {
                DROPS_OBSERVED.fetch_add(1, Ordering::SeqCst);
                panic!("drop panic");
            }
        }
        DROPS_OBSERVED.store(0, Ordering::SeqCst);
        let payload: Box<dyn Any + Send> = Box::new(Bomb);
        drop_payload(payload);
        assert_eq!(DROPS_OBSERVED.load(Ordering::SeqCst), 1);
    }

    #[rstest]
    fn guard_survives_panic_any_with_panicking_drop() {
        static DROPS_OBSERVED: AtomicUsize = AtomicUsize::new(0);
        struct Bomb;
        impl Drop for Bomb {
            fn drop(&mut self) {
                DROPS_OBSERVED.fetch_add(1, Ordering::SeqCst);
                panic!("drop panic");
            }
        }
        DROPS_OBSERVED.store(0, Ordering::SeqCst);
        let r = guard(|| -> Result<u32, PluginError> {
            std::panic::panic_any(Bomb);
        });
        let e = r.into_result().unwrap_err();
        assert_eq!(e.code, PluginErrorCode::Panic);
        assert_eq!(DROPS_OBSERVED.load(Ordering::SeqCst), 1);
    }
}