#[cfg(feature = "oom-injection")]
pub mod alloc;
#[cfg(feature = "oom-injection")]
use std::sync::{Mutex, OnceLock};
/// Result of a single probe run.
///
/// Produced by `probe` / `count_allocations`. When the `oom-injection`
/// feature is disabled, a placeholder report is returned instead.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ProbeReport {
    /// The `fail_at` value the probe was invoked with (0 for `count_allocations`).
    pub fail_at: usize,
    /// Number of allocations reported by the allocator hook after the run.
    /// Always 0 when the feature is disabled.
    pub allocations: usize,
    /// How the probed closure finished.
    pub outcome: ProbeOutcome,
}
/// How the closure passed to a probe terminated.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ProbeOutcome {
    /// The closure returned normally; its return value is discarded.
    Returned,
    /// The closure panicked and the panic was attributed to the injected
    /// allocation failure.
    OomInjected,
    /// The closure panicked for some other reason; carries the panic message.
    /// Also used by the disabled-feature placeholder report.
    Panicked(String),
}
/// Runs `entry` once through [`probe`] with `fail_at` set to 0 and returns the
/// resulting report.
///
/// 0 is presumably the "count only, never inject" setting of the allocator
/// hook; confirm against `alloc::arm_thread`.
#[cfg(feature = "oom-injection")]
#[inline]
pub fn count_allocations<R>(entry: impl FnOnce() -> R) -> ProbeReport {
    const COUNT_ONLY_FAIL_AT: usize = 0;
    probe(COUNT_ONLY_FAIL_AT, entry)
}
/// Runs `entry` on the current thread with allocation-failure injection armed
/// at `fail_at`, and reports how it ended.
///
/// `fail_at` is forwarded to `alloc::arm_thread`. [`count_allocations`] passes
/// 0, which is presumably the "count only, never fail" setting; confirm
/// against `alloc::arm_thread`.
///
/// Probes are serialized through a process-wide mutex. For the duration of a
/// probe, the process-global panic hook is replaced by a no-op so the expected
/// panic prints nothing. The previous hook is reinstalled before returning.
/// NOTE(review): this also silences panics raised on other threads while a
/// probe is running.
///
/// The value returned by `entry` is discarded.
///
/// # Outcome classification
///
/// * `entry` returned → [`ProbeOutcome::Returned`].
/// * `entry` panicked, and one of the following holds → [`ProbeOutcome::OomInjected`]:
///   * `alloc::is_oom_payload` recognises the payload;
///   * injection was armed (`fail_at != 0`) and `alloc::triggered_at()`
///     equals `fail_at`.
/// * Any other panic → [`ProbeOutcome::Panicked`] with the panic message.
///
/// Panics are caught with `catch_unwind`, so this only works when the crate is
/// built with `panic = "unwind"`.
#[cfg(feature = "oom-injection")]
pub fn probe<R>(fail_at: usize, entry: impl FnOnce() -> R) -> ProbeReport {
    use std::panic::{catch_unwind, set_hook, take_hook, AssertUnwindSafe};

    static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    // The mutex guards no data (`()`), so a poisoned lock carries no broken
    // state and can simply be reclaimed.
    let guard = LOCK
        .get_or_init(|| Mutex::new(()))
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());

    let previous_hook = take_hook();
    set_hook(Box::new(|_| {}));

    alloc::arm_thread(fail_at);
    let result = catch_unwind(AssertUnwindSafe(entry));
    alloc::disarm_thread();
    let allocations = alloc::allocation_count();
    let triggered = alloc::triggered_at();
    alloc::clear_thread();

    set_hook(previous_hook);
    drop(guard);

    // With `fail_at == 0` nothing was armed, so `triggered == fail_at` says
    // nothing about an injected failure. Suppose `triggered_at()` reports 0
    // for "never fired" (assumption; confirm in `alloc`). Then every ordinary
    // panic under `count_allocations` would otherwise be misreported as
    // `OomInjected`.
    let injected_failure_fired = fail_at != 0 && triggered == fail_at;

    let outcome = match result {
        Ok(_) => ProbeOutcome::Returned,
        Err(payload) if alloc::is_oom_payload(payload.as_ref()) || injected_failure_fired => {
            ProbeOutcome::OomInjected
        }
        Err(payload) => ProbeOutcome::Panicked(panic_payload_message(payload.as_ref())),
    };

    ProbeReport {
        fail_at,
        allocations,
        outcome,
    }
}
/// Placeholder used when the `oom-injection` feature is disabled.
///
/// `_entry` is never invoked. The returned report has `fail_at = 0`, zero
/// allocations, and a `Panicked` outcome explaining how to enable the feature.
#[cfg(not(feature = "oom-injection"))]
#[inline]
pub fn count_allocations<R>(_entry: impl FnOnce() -> R) -> ProbeReport {
    const COUNT_ONLY_FAIL_AT: usize = 0;
    disabled_report(COUNT_ONLY_FAIL_AT)
}
/// Placeholder used when the `oom-injection` feature is disabled.
///
/// `_entry` is never invoked. The returned report echoes `fail_at`, records
/// zero allocations, and carries a `Panicked` outcome explaining how to enable
/// the feature.
#[cfg(not(feature = "oom-injection"))]
#[inline]
pub fn probe<R>(fail_at: usize, _entry: impl FnOnce() -> R) -> ProbeReport {
    disabled_report(fail_at)
}
/// Builds the report returned by every probe entry point when the
/// `oom-injection` feature is compiled out.
///
/// It echoes `fail_at`, records zero allocations, and uses a `Panicked`
/// outcome whose message tells the caller how to enable the feature.
#[cfg(not(feature = "oom-injection"))]
fn disabled_report(fail_at: usize) -> ProbeReport {
    const DISABLED_MESSAGE: &str =
        "oom-injection feature not enabled. Fix: enable the oom-injection feature.";

    let outcome = ProbeOutcome::Panicked(String::from(DISABLED_MESSAGE));
    ProbeReport {
        fail_at,
        allocations: 0,
        outcome,
    }
}
/// Turns a panic payload into a printable message.
///
/// `&'static str` and `String` payloads (the types `panic!` typically
/// produces) are copied out verbatim. Any other payload type yields the fixed
/// text `"non-string panic payload"`.
fn panic_payload_message(payload: &(dyn std::any::Any + Send)) -> String {
    payload
        .downcast_ref::<&'static str>()
        .map(|text| String::from(*text))
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "non-string panic payload".to_string())
}