use std::fmt;
/// Error value returned by the fault-injection hooks when an armed trigger fires.
///
/// It carries no payload: the type itself is the signal that a failure was
/// deliberately injected by a test rather than produced by real code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InjectedFailure;

impl fmt::Display for InjectedFailure {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "injected test failure")
    }
}

/// Lets the marker travel through `Box<dyn Error>` / `?` based error paths.
impl std::error::Error for InjectedFailure {}
#[cfg(feature = "testing-statement-injection")]
mod inner {
    //! Real fault-injection hooks, compiled only with the
    //! `testing-statement-injection` feature.
    //!
    //! All state lives in `thread_local!` cells, so arming a trigger affects only
    //! the thread that armed it.

    use super::InjectedFailure;
    use std::cell::Cell;

    thread_local! {
        // Countdown for `maybe_fail_statement`: the call that observes `1` fails.
        // `u64::MAX` is the idle value restored by `disable_injection`.
        static STATEMENT_COUNTER: Cell<u64> = const { Cell::new(u64::MAX) };
        // Gate for statement injection; while `false` the counter is never touched.
        static STATEMENT_ENABLED: Cell<bool> = const { Cell::new(false) };
        // One-shot flag consumed by the next `maybe_fail_commit` call.
        static COMMIT_TRIGGER: Cell<bool> = const { Cell::new(false) };
    }

    /// Statement-level injection point.
    ///
    /// Returns `Ok(())` unless statement injection is enabled on this thread and
    /// the countdown armed by [`enable_statement_failure_after`] selects this call.
    ///
    /// # Errors
    ///
    /// Returns [`InjectedFailure`] at most once per arming: on the `count`-th call
    /// after `enable_statement_failure_after(count)` (for `count >= 1`).
    #[inline]
    pub fn maybe_fail_statement() -> Result<(), InjectedFailure> {
        if !STATEMENT_ENABLED.with(Cell::get) {
            return Ok(());
        }
        STATEMENT_COUNTER.with(|counter| {
            let remaining = counter.get();
            // Saturate rather than wrap: once the countdown reaches zero it must stay
            // there instead of rolling over to `u64::MAX` and quietly re-arming.
            counter.set(remaining.saturating_sub(1));
            if remaining == 1 {
                Err(InjectedFailure)
            } else {
                Ok(())
            }
        })
    }

    /// Commit-level injection point.
    ///
    /// # Errors
    ///
    /// Returns [`InjectedFailure`] if [`enable_commit_failure_once`] armed the
    /// trigger on this thread. The trigger is consumed, so only the first call
    /// after arming fails.
    #[inline]
    pub fn maybe_fail_commit() -> Result<(), InjectedFailure> {
        // `replace` reads and clears the flag in one step, making the trigger one-shot.
        if COMMIT_TRIGGER.with(|trigger| trigger.replace(false)) {
            Err(InjectedFailure)
        } else {
            Ok(())
        }
    }

    /// Arms statement injection so that the `count`-th subsequent call to
    /// [`maybe_fail_statement`] on this thread fails.
    ///
    /// `count == 1` fails the very next call. `count == 0` never fires, because
    /// the countdown never passes through `1`.
    pub fn enable_statement_failure_after(count: u64) {
        STATEMENT_COUNTER.with(|c| c.set(count));
        STATEMENT_ENABLED.with(|e| e.set(true));
    }

    /// Arms the commit trigger so that the next [`maybe_fail_commit`] call on
    /// this thread fails.
    pub fn enable_commit_failure_once() {
        COMMIT_TRIGGER.with(|t| t.set(true));
    }

    /// Clears all injection state on this thread: disables statement injection,
    /// resets its countdown to the idle value and disarms the commit trigger.
    pub fn disable_injection() {
        STATEMENT_ENABLED.with(|e| e.set(false));
        STATEMENT_COUNTER.with(|c| c.set(u64::MAX));
        COMMIT_TRIGGER.with(|t| t.set(false));
    }
}
#[cfg(not(feature = "testing-statement-injection"))]
mod inner {
    //! No-op stand-ins used when the `testing-statement-injection` feature is off.
    //!
    //! Every injection point reports success and every control function does
    //! nothing, so call sites compile unchanged in both configurations.

    use super::InjectedFailure;

    /// Shared body of both injection points: with injection compiled out,
    /// nothing ever fails.
    #[inline]
    fn never_fails() -> Result<(), InjectedFailure> {
        Ok(())
    }

    /// Statement-level injection point; always `Ok(())` in this configuration.
    #[inline]
    pub fn maybe_fail_statement() -> Result<(), InjectedFailure> {
        never_fails()
    }

    /// Commit-level injection point; always `Ok(())` in this configuration.
    #[inline]
    pub fn maybe_fail_commit() -> Result<(), InjectedFailure> {
        never_fails()
    }

    /// Accepted for API parity with the feature-enabled build; ignored here.
    pub fn enable_statement_failure_after(_count: u64) {}

    /// Accepted for API parity with the feature-enabled build; ignored here.
    pub fn enable_commit_failure_once() {}

    /// Accepted for API parity with the feature-enabled build; ignored here.
    pub fn disable_injection() {}
}
pub use inner::*;
/// Scope guard that calls [`disable_injection`] when dropped.
///
/// The `with_*` helpers hold one of these while running the caller's closure.
/// Thread-local injection state is therefore cleared even if the closure panics.
/// Without it, an armed trigger could leak into whatever runs next on the same
/// thread.
struct DisableOnDrop;

impl Drop for DisableOnDrop {
    fn drop(&mut self) {
        disable_injection();
    }
}

/// Runs `f` with statement injection armed to fail the `fail_after`-th call to
/// [`maybe_fail_statement`], then disables injection and returns `f`'s result.
///
/// Injection is disabled whether `f` returns normally or unwinds. Without the
/// `testing-statement-injection` feature the arming is a no-op and this just
/// returns `f()`.
///
/// On exit this clears *all* injection state, including any armed commit
/// trigger. Nesting these helpers therefore leaves the outer scope disarmed.
pub fn with_statement_failure_after<F, T>(fail_after: u64, f: F) -> T
where
    F: FnOnce() -> T,
{
    enable_statement_failure_after(fail_after);
    let _reset = DisableOnDrop;
    f()
}

/// Runs `f` with the one-shot commit trigger armed, so the first
/// [`maybe_fail_commit`] call inside `f` fails. Afterwards it disables injection
/// and returns `f`'s result.
///
/// Injection is disabled whether `f` returns normally or unwinds. Without the
/// `testing-statement-injection` feature the arming is a no-op and this just
/// returns `f()`.
///
/// On exit this clears *all* injection state, including any statement
/// countdown. Nesting these helpers therefore leaves the outer scope disarmed.
pub fn with_commit_failure<F, T>(f: F) -> T
where
    F: FnOnce() -> T,
{
    enable_commit_failure_once();
    let _reset = DisableOnDrop;
    f()
}
#[cfg(test)]
mod tests {
    //! Unit tests for the injection hooks.
    //!
    //! Tests gated on `testing-statement-injection` exercise the real trigger
    //! logic. The ungated ones must pass in both configurations, so they only
    //! assert outcomes that hold whether injection is compiled in or stubbed out.

    use super::*;

    #[test]
    #[cfg(feature = "testing-statement-injection")]
    fn statement_fires_at_exact_count() {
        let result = with_statement_failure_after(3, || {
            let a = maybe_fail_statement();
            let b = maybe_fail_statement();
            let c = maybe_fail_statement();
            // One call past the trigger point: the failure must not repeat.
            let d = maybe_fail_statement();
            (a, b, c, d)
        });
        assert_eq!(result.0, Ok(()));
        assert_eq!(result.1, Ok(()));
        assert_eq!(result.2, Err(InjectedFailure));
        assert_eq!(result.3, Ok(()), "statement trigger is one-shot");
    }

    #[test]
    #[cfg(feature = "testing-statement-injection")]
    fn commit_fires_once() {
        let result = with_commit_failure(|| {
            let a = maybe_fail_commit();
            let b = maybe_fail_commit();
            (a, b)
        });
        assert_eq!(result.0, Err(InjectedFailure));
        assert_eq!(result.1, Ok(()), "commit trigger is one-shot");
    }

    #[test]
    #[cfg(feature = "testing-statement-injection")]
    fn helpers_disarm_on_return() {
        // Count 2 with a single call inside: if the helper left injection armed,
        // the next call outside would be the 2nd and would fail.
        let inside = with_statement_failure_after(2, maybe_fail_statement);
        assert_eq!(inside, Ok(()));
        assert_eq!(maybe_fail_statement(), Ok(()), "statement injection disarmed");

        // The closure never consumes the commit trigger; the helper must clear it.
        with_commit_failure(|| ());
        assert_eq!(maybe_fail_commit(), Ok(()), "commit trigger disarmed");
    }

    #[test]
    fn completes_when_count_exceeds_calls() {
        let result = with_statement_failure_after(100, || {
            let a = maybe_fail_statement();
            let b = maybe_fail_statement();
            (a, b)
        });
        assert_eq!(result.0, Ok(()));
        assert_eq!(result.1, Ok(()));
    }

    #[test]
    fn disabled_by_default() {
        assert_eq!(maybe_fail_statement(), Ok(()));
        assert_eq!(maybe_fail_commit(), Ok(()));
    }

    #[test]
    fn disable_resets_state() {
        enable_statement_failure_after(2);
        enable_commit_failure_once();
        disable_injection();
        // Had the countdown survived, the second call here would fail.
        assert_eq!(maybe_fail_statement(), Ok(()));
        assert_eq!(maybe_fail_statement(), Ok(()));
        assert_eq!(maybe_fail_commit(), Ok(()));
    }
}