#[cfg(feature = "testing-crash-injection")]
mod inner {
    use std::cell::Cell;

    thread_local! {
        // Per-thread countdown: `None` while injection is disarmed, otherwise the
        // number of `maybe_crash` calls left; the call that observes 1 is the one
        // that panics.
        static REMAINING: Cell<Option<u64>> = const { Cell::new(None) };
    }

    /// Crash-injection point for the current thread.
    ///
    /// Does nothing while injection is disarmed. Otherwise it decrements this
    /// thread's countdown (wrapping) and panics with
    /// `"crash injection at: {point}"` if this call is the armed one.
    ///
    /// # Panics
    ///
    /// Panics on the `count`-th call after [`enable_crash_at`]`(count)` on this
    /// thread. After firing, the countdown sits at 0 and the next call wraps it
    /// to `u64::MAX`, so in practice no further crash fires until re-armed.
    #[inline]
    pub fn maybe_crash(point: &'static str) {
        REMAINING.with(|remaining| {
            if let Some(left) = remaining.get() {
                remaining.set(Some(left.wrapping_sub(1)));
                if left == 1 {
                    panic!("crash injection at: {point}");
                }
            }
        });
    }

    /// Arms crash injection on the current thread so that the `count`-th
    /// subsequent [`maybe_crash`] call panics.
    ///
    /// A `count` of 0 wraps the countdown to `u64::MAX` on the first call, so in
    /// practice it never fires.
    pub fn enable_crash_at(count: u64) {
        REMAINING.with(|remaining| remaining.set(Some(count)));
    }

    /// Disarms crash injection on the current thread.
    pub fn disable_crash() {
        REMAINING.with(|remaining| remaining.set(None));
    }
}
#[cfg(not(feature = "testing-crash-injection"))]
mod inner {
    //! Empty stand-ins used when the `testing-crash-injection` feature is off,
    //! so every call site compiles unchanged and never panics.

    /// Would arm crash injection; does nothing in this build.
    pub fn enable_crash_at(_count: u64) {}

    /// Would disarm crash injection; does nothing in this build.
    pub fn disable_crash() {}

    /// Crash-injection point; does nothing in this build.
    #[inline(always)]
    pub fn maybe_crash(_point: &'static str) {}
}
pub use inner::*;
/// Outcome of running a closure under [`with_crash_at`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CrashResult<T> {
    /// The closure returned normally with this value.
    Completed(T),
    /// The closure unwound with a panic instead of returning a value.
    Crashed,
}
/// Runs `f` with crash injection armed so that the `crash_after`-th
/// [`maybe_crash`] call on this thread panics, and reports whether it did.
///
/// Crash injection is always disarmed again before this function returns,
/// whether `f` completed or unwound.
///
/// # Panics
///
/// Panics raised inside `f` that were *not* produced by [`maybe_crash`] are
/// re-raised with [`std::panic::resume_unwind`] rather than being reported as
/// [`CrashResult::Crashed`]. This keeps genuine bugs in the code under test
/// from being mistaken for simulated crashes.
pub fn with_crash_at<F, T>(crash_after: u64, f: F) -> CrashResult<T>
where
    F: FnOnce() -> T + std::panic::UnwindSafe,
{
    enable_crash_at(crash_after);
    let result = std::panic::catch_unwind(f);
    // Disarm before inspecting the outcome so no later `maybe_crash` on this
    // thread can fire, however `f` finished (including if we re-raise below).
    disable_crash();
    match result {
        Ok(value) => CrashResult::Completed(value),
        Err(payload) if is_injected_crash(&*payload) => CrashResult::Crashed,
        Err(payload) => std::panic::resume_unwind(payload),
    }
}

/// Returns `true` if a panic payload carries the message that [`maybe_crash`]
/// uses when an injected crash fires.
fn is_injected_crash(payload: &(dyn std::any::Any + Send)) -> bool {
    // Must stay in sync with the panic message emitted by `maybe_crash`.
    const PREFIX: &str = "crash injection at: ";
    // Formatted panics carry a `String`; literal-only panics carry `&'static str`.
    if let Some(msg) = payload.downcast_ref::<String>() {
        msg.starts_with(PREFIX)
    } else if let Some(msg) = payload.downcast_ref::<&'static str>() {
        msg.starts_with(PREFIX)
    } else {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Arming at 3 must let the first two points pass and fire on the third.
    #[test]
    #[cfg(feature = "testing-crash-injection")]
    fn crash_at_exact_count() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let survived = AtomicUsize::new(0);
        let result = with_crash_at(3, || {
            for point in ["point_1", "point_2", "point_3", "point_4"] {
                maybe_crash(point);
                survived.fetch_add(1, Ordering::SeqCst);
            }
            42
        });
        assert!(matches!(result, CrashResult::Crashed));
        // A crash at an earlier point (or none at all) would change this count.
        assert_eq!(survived.load(Ordering::SeqCst), 2);
    }

    /// Once `with_crash_at` returns after a fired crash, injection is disarmed.
    #[test]
    #[cfg(feature = "testing-crash-injection")]
    fn disarmed_after_crash() {
        let result = with_crash_at(1, || maybe_crash("fires"));
        assert!(matches!(result, CrashResult::Crashed));
        maybe_crash("after_1");
        maybe_crash("after_2");
    }

    #[test]
    fn completes_when_count_exceeds_calls() {
        let result = with_crash_at(100, || {
            maybe_crash("a");
            maybe_crash("b");
            42
        });
        match result {
            CrashResult::Completed(v) => assert_eq!(v, 42),
            CrashResult::Crashed => panic!("should not crash"),
        }
    }

    /// A fresh test thread starts with injection disarmed.
    #[test]
    fn disabled_by_default() {
        maybe_crash("should_not_crash");
    }

    /// `disable_crash` cancels a pending countdown.
    #[test]
    fn disable_resets_state() {
        enable_crash_at(2);
        disable_crash();
        maybe_crash("a");
        maybe_crash("b");
        maybe_crash("c");
    }
}