use std::sync::Mutex;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
use crate::llm::error::LlmError;
#[derive(Debug, Clone)]
pub enum FaultType {
RequestFailed,
RateLimited { retry_after: Option<Duration> },
AuthFailed,
InvalidResponse,
IoError,
ContextLengthExceeded,
SessionExpired,
}
impl FaultType {
pub fn to_llm_error(&self, provider: &str) -> LlmError {
match self {
FaultType::RequestFailed => LlmError::RequestFailed {
provider: provider.to_string(),
reason: "injected fault: request failed".to_string(),
},
FaultType::RateLimited { retry_after } => LlmError::RateLimited {
provider: provider.to_string(),
retry_after: *retry_after,
},
FaultType::AuthFailed => LlmError::AuthFailed {
provider: provider.to_string(),
},
FaultType::InvalidResponse => LlmError::InvalidResponse {
provider: provider.to_string(),
reason: "injected fault: invalid response".to_string(),
},
FaultType::IoError => LlmError::Io(std::io::Error::new(
std::io::ErrorKind::ConnectionReset,
"injected fault: connection reset",
)),
FaultType::ContextLengthExceeded => LlmError::ContextLengthExceeded {
used: 100_000,
limit: 50_000,
},
FaultType::SessionExpired => LlmError::SessionExpired {
provider: provider.to_string(),
},
}
}
}
#[derive(Debug, Clone)]
pub enum FaultAction {
Succeed,
Fail(FaultType),
Delay(Duration),
}
#[derive(Debug, Clone)]
pub enum FaultMode {
SequenceOnce,
SequenceLoop,
Random {
error_rate: f64,
fault: FaultType,
seed: u64,
},
}
pub struct FaultInjector {
actions: Vec<FaultAction>,
mode: FaultMode,
call_index: AtomicU32,
rng_state: Mutex<u64>,
}
impl std::fmt::Debug for FaultInjector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FaultInjector")
.field("call_index", &self.call_index.load(Ordering::Relaxed))
.field("mode", &self.mode)
.finish()
}
}
impl FaultInjector {
pub fn sequence(actions: impl IntoIterator<Item = FaultAction>) -> Self {
Self {
actions: actions.into_iter().collect(),
mode: FaultMode::SequenceOnce,
call_index: AtomicU32::new(0),
rng_state: Mutex::new(0),
}
}
pub fn sequence_loop(actions: impl IntoIterator<Item = FaultAction>) -> Self {
Self {
actions: actions.into_iter().collect(),
mode: FaultMode::SequenceLoop,
call_index: AtomicU32::new(0),
rng_state: Mutex::new(0),
}
}
pub fn random(error_rate: f64, fault: FaultType, seed: u64) -> Self {
assert!(
!error_rate.is_nan() && (0.0..=1.0).contains(&error_rate),
"error_rate must be in 0.0..=1.0 and not NaN, got {error_rate}"
);
let seed = if seed == 0 { 1 } else { seed };
Self {
actions: Vec::new(),
mode: FaultMode::Random {
error_rate,
fault,
seed,
},
call_index: AtomicU32::new(0),
rng_state: Mutex::new(seed),
}
}
pub fn next_action(&self) -> FaultAction {
let index = self.call_index.fetch_add(1, Ordering::Relaxed) as usize;
match &self.mode {
FaultMode::SequenceOnce => {
if index < self.actions.len() {
self.actions[index].clone()
} else {
FaultAction::Succeed
}
}
FaultMode::SequenceLoop => {
if self.actions.is_empty() {
FaultAction::Succeed
} else {
self.actions[index % self.actions.len()].clone()
}
}
FaultMode::Random {
error_rate, fault, ..
} => {
let random_val = {
let mut state = self.rng_state.lock().unwrap_or_else(|p| p.into_inner());
*state ^= *state << 13;
*state ^= *state >> 7;
*state ^= *state << 17;
(*state as f64) / (u64::MAX as f64)
};
if random_val <= *error_rate {
FaultAction::Fail(fault.clone())
} else {
FaultAction::Succeed
}
}
}
}
pub fn call_count(&self) -> u32 {
self.call_index.load(Ordering::Relaxed)
}
pub fn reset(&self) {
self.call_index.store(0, Ordering::Relaxed);
if let FaultMode::Random { seed, .. } = &self.mode {
let mut state = self.rng_state.lock().unwrap_or_else(|p| p.into_inner());
*state = *seed;
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn sequence_once_plays_then_succeeds() {
let injector = FaultInjector::sequence([
FaultAction::Fail(FaultType::RequestFailed),
FaultAction::Fail(FaultType::RateLimited { retry_after: None }),
FaultAction::Succeed,
]);
assert!(matches!(
injector.next_action(),
FaultAction::Fail(FaultType::RequestFailed)
));
assert!(matches!(
injector.next_action(),
FaultAction::Fail(FaultType::RateLimited { .. })
));
assert!(matches!(injector.next_action(), FaultAction::Succeed));
assert!(matches!(injector.next_action(), FaultAction::Succeed));
assert!(matches!(injector.next_action(), FaultAction::Succeed));
assert_eq!(injector.call_count(), 5);
}
#[test]
fn sequence_loop_repeats() {
let injector = FaultInjector::sequence_loop([
FaultAction::Fail(FaultType::RequestFailed),
FaultAction::Succeed,
]);
assert!(matches!(injector.next_action(), FaultAction::Fail(_)));
assert!(matches!(injector.next_action(), FaultAction::Succeed));
assert!(matches!(injector.next_action(), FaultAction::Fail(_)));
assert!(matches!(injector.next_action(), FaultAction::Succeed));
}
#[test]
fn random_mode_is_deterministic_with_seed() {
let injector1 = FaultInjector::random(0.5, FaultType::RequestFailed, 42);
let injector2 = FaultInjector::random(0.5, FaultType::RequestFailed, 42);
let results1: Vec<bool> = (0..20)
.map(|_| matches!(injector1.next_action(), FaultAction::Fail(_)))
.collect();
let results2: Vec<bool> = (0..20)
.map(|_| matches!(injector2.next_action(), FaultAction::Fail(_)))
.collect();
assert_eq!(results1, results2, "Same seed should produce same sequence");
}
#[test]
fn fault_type_produces_correct_llm_errors() {
let provider = "test-provider";
assert!(matches!(
FaultType::RequestFailed.to_llm_error(provider),
LlmError::RequestFailed { .. }
));
assert!(matches!(
FaultType::RateLimited {
retry_after: Some(Duration::from_secs(5))
}
.to_llm_error(provider),
LlmError::RateLimited { .. }
));
assert!(matches!(
FaultType::AuthFailed.to_llm_error(provider),
LlmError::AuthFailed { .. }
));
assert!(matches!(
FaultType::InvalidResponse.to_llm_error(provider),
LlmError::InvalidResponse { .. }
));
assert!(matches!(
FaultType::IoError.to_llm_error(provider),
LlmError::Io(_)
));
assert!(matches!(
FaultType::ContextLengthExceeded.to_llm_error(provider),
LlmError::ContextLengthExceeded { .. }
));
assert!(matches!(
FaultType::SessionExpired.to_llm_error(provider),
LlmError::SessionExpired { .. }
));
}
#[test]
fn delay_action_exists() {
let injector = FaultInjector::sequence([FaultAction::Delay(Duration::from_millis(100))]);
assert!(matches!(injector.next_action(), FaultAction::Delay(_)));
}
#[test]
fn random_seed_zero_does_not_always_fail() {
let injector = FaultInjector::random(0.5, FaultType::RequestFailed, 0);
let failures = (0..100)
.filter(|_| matches!(injector.next_action(), FaultAction::Fail(_)))
.count();
assert!(failures < 100, "seed=0 must not produce stuck RNG");
}
#[test]
fn empty_sequence_always_succeeds() {
let injector = FaultInjector::sequence([]);
for _ in 0..10 {
assert!(matches!(injector.next_action(), FaultAction::Succeed));
}
}
#[test]
fn reset_restores_random_rng_from_stored_seed() {
let injector = FaultInjector::random(0.5, FaultType::RequestFailed, 42);
let run1: Vec<bool> = (0..20)
.map(|_| matches!(injector.next_action(), FaultAction::Fail(_)))
.collect();
injector.reset();
assert_eq!(injector.call_count(), 0);
let run2: Vec<bool> = (0..20)
.map(|_| matches!(injector.next_action(), FaultAction::Fail(_)))
.collect();
assert_eq!(run1, run2, "reset() should reproduce the same sequence");
}
#[test]
#[should_panic(expected = "error_rate must be in 0.0..=1.0")]
fn random_rejects_error_rate_above_one() {
FaultInjector::random(1.5, FaultType::RequestFailed, 42);
}
#[test]
#[should_panic(expected = "error_rate must be in 0.0..=1.0")]
fn random_rejects_negative_error_rate() {
FaultInjector::random(-0.1, FaultType::RequestFailed, 42);
}
#[test]
#[should_panic(expected = "error_rate must be in 0.0..=1.0 and not NaN")]
fn random_rejects_nan_error_rate() {
FaultInjector::random(f64::NAN, FaultType::RequestFailed, 42);
}
#[test]
fn error_rate_one_always_fails() {
let injector = FaultInjector::random(1.0, FaultType::RequestFailed, 42);
for _ in 0..100 {
assert!(
matches!(injector.next_action(), FaultAction::Fail(_)),
"error_rate=1.0 must always produce failures"
);
}
}
#[test]
fn error_rate_zero_never_fails() {
let injector = FaultInjector::random(0.0, FaultType::RequestFailed, 42);
for _ in 0..100 {
assert!(
matches!(injector.next_action(), FaultAction::Succeed),
"error_rate=0.0 must never produce failures"
);
}
}
#[tokio::test]
async fn delay_action_pauses_execution() {
tokio::time::pause();
let injector = FaultInjector::sequence([
FaultAction::Delay(Duration::from_secs(10)),
FaultAction::Succeed,
]);
let action = injector.next_action();
assert!(matches!(action, FaultAction::Delay(d) if d == Duration::from_secs(10)));
if let FaultAction::Delay(d) = action {
let start = tokio::time::Instant::now();
tokio::time::sleep(d).await;
let elapsed = start.elapsed();
assert!(
elapsed >= Duration::from_secs(10),
"delay should have paused for at least 10s, got {elapsed:?}"
);
}
assert!(matches!(injector.next_action(), FaultAction::Succeed));
}
}