use crate::store::StoreError;
use std::sync::Arc;
/// A named point in the store's write path at which a [`FaultInjector`] is
/// consulted.
///
/// Each variant carries identifying context (batch id, item position, entity
/// name, segment numbers) so that injector filters can target a specific
/// occurrence. The injectors in this module include the point's `Debug`
/// rendering in the message of every error they inject.
///
/// `Hash` is derived alongside `Eq` so points can be collected into sets or
/// used as map keys (e.g. to record which points a test run reached).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InjectionPoint {
    /// Start of a batch of `item_count` items.
    /// NOTE(review): presumably checked before anything is written — confirm
    /// at the call site.
    BatchStart {
        batch_id: u64,
        item_count: usize,
    },
    /// The batch's BEGIN marker has been written.
    BatchBeginWritten {
        batch_id: u64,
        item_count: usize,
    },
    /// Item `item_index` of a batch of `total_items` items has been written.
    /// NOTE(review): `item_index` is presumably 0-based — confirm at the call
    /// site.
    BatchItemWritten {
        batch_id: u64,
        item_index: usize,
        total_items: usize,
    },
    /// All `item_count` items of the batch have been written.
    /// NOTE(review): presumably reached before the COMMIT marker — confirm.
    BatchItemsComplete {
        batch_id: u64,
        item_count: usize,
    },
    /// The batch's COMMIT marker has been written but the batch has not yet
    /// been fsynced.
    BatchCommitWritten {
        batch_id: u64,
    },
    /// The batch's fsync step.
    /// NOTE(review): whether this is checked before or after the fsync call is
    /// not visible here — confirm at the call site.
    BatchFsync {
        batch_id: u64,
    },
    /// Just before a batch of `item_count` items is published.
    /// NOTE(review): what "publish" entails (e.g. making entries visible to
    /// readers) is not visible here — confirm.
    BatchPrePublish {
        batch_id: u64,
        item_count: usize,
    },
    /// Start of a single (presumably non-batch) append for `entity`.
    SingleAppendStart {
        entity: &'static str,
    },
    /// A single (presumably non-batch) append for `entity` has been written.
    SingleAppendWritten {
        entity: &'static str,
    },
    /// Rotation from segment `old_segment` to segment `new_segment`.
    SegmentRotation {
        old_segment: u64,
        new_segment: u64,
    },
}
/// Decides, at each [`InjectionPoint`], whether the store should fail with a
/// simulated error.
///
/// Implementors must be `Send + Sync` so that a single injector can be shared,
/// e.g. as the `Arc<dyn FaultInjector>` accepted by [`maybe_inject`].
pub trait FaultInjector: Send + Sync {
    /// Inspects `point` and returns `Some(error)` to inject that error, or
    /// `None` to let the operation proceed.
    ///
    /// Takes `&self`, so implementations that keep state (such as a counter)
    /// need interior mutability.
    fn check(&self, point: InjectionPoint) -> Option<StoreError>;
    /// Checks the injector's own configuration, returning a description of the
    /// problem on failure. The default implementation accepts everything.
    ///
    /// NOTE(review): no caller is visible in this file — confirm where (and
    /// whether) the store invokes this.
    fn validate(&self) -> Result<(), String> {
        Ok(())
    }
}
/// A [`FaultInjector`] that ignores matching checks until the
/// `trigger_after`-th one. It then applies its [`CountdownAction`] to that
/// check and to every later matching check (the trigger latches).
///
/// Build one with [`CountdownInjector::new`] or one of the presets such as
/// [`CountdownInjector::after_batch_items`].
pub struct CountdownInjector {
    /// 1-based position, counted among checks accepted by `filter`, at which
    /// the action starts being applied. `0` behaves like `1`.
    trigger_after: usize,
    /// Number of checks accepted by `filter` so far.
    current: std::sync::atomic::AtomicUsize,
    /// Optional predicate restricting which points count. `None` accepts all.
    filter: Option<Box<dyn Fn(InjectionPoint) -> bool + Send + Sync>>,
    /// Effect applied once the threshold has been reached.
    action: CountdownAction,
}

// The boxed filter closure is not `Debug`, so `Debug` cannot be derived; this
// manual impl reports only whether a filter is installed.
impl std::fmt::Debug for CountdownInjector {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CountdownInjector")
            .field("trigger_after", &self.trigger_after)
            .field("current", &self.current)
            .field("has_filter", &self.filter.is_some())
            .field("action", &self.action)
            .finish()
    }
}
/// Effect an injector applies when it triggers.
///
/// Despite its name, this is shared by [`CountdownInjector`] and
/// [`ProbabilisticInjector`]. It is a small `Copy` value, so it also derives
/// equality and hashing for comparison in tests and configuration.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CountdownAction {
    /// Inject `StoreError::FaultInjected`. In this module's injectors the
    /// message is this text followed by ` at {point:?}`.
    Fail(&'static str),
    /// Emit a debug-level `tracing` event and let the operation proceed.
    Noop,
}
impl CountdownInjector {
    /// Creates an injector that applies `action` once `trigger_after`
    /// matching checks have been observed.
    ///
    /// The threshold is 1-based: with `trigger_after == 3` the first two
    /// matching checks pass, and `action` is applied on the third and on every
    /// matching check after it. A value of `0` behaves like `1`.
    #[must_use]
    pub fn new(trigger_after: usize, action: CountdownAction) -> Self {
        Self {
            trigger_after,
            current: std::sync::atomic::AtomicUsize::new(0),
            filter: None,
            action,
        }
    }

    /// Restricts the injector to points for which `filter` returns `true`,
    /// replacing any previously installed filter.
    ///
    /// Points rejected by the filter are ignored entirely and do not advance
    /// the countdown. Dropping the returned value discards the filter, hence
    /// `#[must_use]`.
    #[must_use]
    pub fn with_filter<F>(mut self, filter: F) -> Self
    where
        F: Fn(InjectionPoint) -> bool + Send + Sync + 'static,
    {
        self.filter = Some(Box::new(filter));
        self
    }

    /// Fails the `n`-th `BatchItemWritten` check and every later one.
    /// Other injection points are ignored. `n == 0` behaves like `n == 1`.
    #[must_use]
    pub fn after_batch_items(n: usize) -> Self {
        Self::new(
            n,
            CountdownAction::Fail("simulated fault during batch item write"),
        )
        .with_filter(|p| matches!(p, InjectionPoint::BatchItemWritten { .. }))
    }

    /// Fails the first `BatchBeginWritten` check and every later one.
    #[must_use]
    pub fn after_batch_begin() -> Self {
        Self::new(
            1,
            CountdownAction::Fail("simulated fault after BEGIN marker"),
        )
        .with_filter(|p| matches!(p, InjectionPoint::BatchBeginWritten { .. }))
    }

    /// Fails the first `BatchCommitWritten` check and every later one.
    #[must_use]
    pub fn after_commit_before_fsync() -> Self {
        Self::new(
            1,
            CountdownAction::Fail("simulated fault after COMMIT before fsync"),
        )
        .with_filter(|p| matches!(p, InjectionPoint::BatchCommitWritten { .. }))
    }
}
impl FaultInjector for CountdownInjector {
    /// Counts points accepted by the filter and, from the `trigger_after`-th
    /// one onwards, applies the configured action.
    fn check(&self, point: InjectionPoint) -> Option<StoreError> {
        // Points rejected by the filter neither advance the counter nor fire.
        if self.filter.as_ref().is_some_and(|accepts| !accepts(point)) {
            return None;
        }

        // `fetch_add` returns the previous value, so `position` is the 1-based
        // index of this point among all accepted points.
        let previous = self
            .current
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        let position = previous + 1;
        if position < self.trigger_after {
            return None;
        }

        match self.action {
            CountdownAction::Noop => {
                tracing::debug!("FaultInjector noop at {point:?}");
                None
            }
            CountdownAction::Fail(msg) => {
                let message = format!("{msg} at {point:?}");
                Some(StoreError::FaultInjected(message))
            }
        }
    }
}
/// A [`FaultInjector`] that applies its [`CountdownAction`] to each matching
/// check independently, with a fixed probability.
///
/// Build one with [`ProbabilisticInjector::new`].
pub struct ProbabilisticInjector {
    /// Chance in `[0.0, 1.0]` that a matching check triggers. It is validated
    /// by `new`.
    probability: f64,
    /// Optional predicate restricting which points are eligible. `None`
    /// accepts all.
    filter: Option<Box<dyn Fn(InjectionPoint) -> bool + Send + Sync>>,
    /// Effect applied when a check triggers.
    action: CountdownAction,
}

// The boxed filter closure is not `Debug`, so `Debug` cannot be derived; this
// manual impl reports only whether a filter is installed.
impl std::fmt::Debug for ProbabilisticInjector {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProbabilisticInjector")
            .field("probability", &self.probability)
            .field("has_filter", &self.filter.is_some())
            .field("action", &self.action)
            .finish()
    }
}
impl ProbabilisticInjector {
    /// Creates an injector that applies `action` to each matching check
    /// independently, with the given `probability`.
    ///
    /// `0.0` never triggers and `1.0` always triggers.
    ///
    /// # Panics
    ///
    /// Panics if `probability` is outside `[0.0, 1.0]`, which includes NaN.
    #[must_use]
    pub fn new(probability: f64, action: CountdownAction) -> Self {
        assert!(
            (0.0..=1.0).contains(&probability),
            "probability must be in [0.0, 1.0]"
        );
        Self {
            probability,
            filter: None,
            action,
        }
    }

    /// Restricts the injector to points for which `filter` returns `true`,
    /// replacing any previously installed filter.
    ///
    /// Dropping the returned value discards the filter, hence `#[must_use]`.
    #[must_use]
    pub fn with_filter<F>(mut self, filter: F) -> Self
    where
        F: Fn(InjectionPoint) -> bool + Send + Sync + 'static,
    {
        self.filter = Some(Box::new(filter));
        self
    }
}
impl FaultInjector for ProbabilisticInjector {
    /// Applies the configured action to a filter-accepted point with
    /// probability `self.probability`. Every other point passes untouched.
    fn check(&self, point: InjectionPoint) -> Option<StoreError> {
        let rejected_by_filter = self
            .filter
            .as_ref()
            .is_some_and(|accepts| !accepts(point));
        if rejected_by_filter {
            return None;
        }

        // The injector stores no seed: each roll comes from a newly
        // constructed generator and is uniform in [0.0, 1.0).
        let roll = fastrand::Rng::new().f64();
        if roll >= self.probability {
            return None;
        }

        match self.action {
            CountdownAction::Noop => {
                tracing::debug!("ProbabilisticInjector noop at {point:?}");
                None
            }
            CountdownAction::Fail(msg) => {
                let message = format!("{msg} at {point:?}");
                Some(StoreError::FaultInjected(message))
            }
        }
    }
}
/// Consults `injector`, if one is installed, at `point`.
///
/// # Errors
///
/// Returns the error produced by the injector when it decides to fire.
/// Returns `Ok(())` when no injector is installed or the injector declines.
pub fn maybe_inject(
    point: InjectionPoint,
    injector: &Option<Arc<dyn FaultInjector>>,
) -> Result<(), StoreError> {
    let injected = injector.as_deref().and_then(|active| active.check(point));
    match injected {
        Some(fault) => Err(fault),
        None => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `BatchItemWritten` point reused by several tests.
    fn item_point() -> InjectionPoint {
        InjectionPoint::BatchItemWritten {
            batch_id: 1,
            item_index: 0,
            total_items: 5,
        }
    }

    /// A `BatchBeginWritten` point reused by several tests.
    fn begin_point() -> InjectionPoint {
        InjectionPoint::BatchBeginWritten {
            batch_id: 1,
            item_count: 5,
        }
    }

    #[test]
    fn countdown_triggers_at_count() {
        let injector = CountdownInjector::new(3, CountdownAction::Fail("boom"));
        assert!(injector.check(item_point()).is_none());
        assert!(injector.check(item_point()).is_none());
        assert!(injector.check(item_point()).is_some());
    }

    #[test]
    fn countdown_keeps_firing_after_threshold() {
        let injector = CountdownInjector::new(2, CountdownAction::Fail("boom"));
        assert!(injector.check(item_point()).is_none());
        for _ in 0..3 {
            assert!(injector.check(item_point()).is_some());
        }
    }

    #[test]
    fn countdown_zero_threshold_fires_on_first_check() {
        let injector = CountdownInjector::new(0, CountdownAction::Fail("boom"));
        assert!(injector.check(item_point()).is_some());
    }

    #[test]
    fn countdown_noop_never_faults() {
        let injector = CountdownInjector::new(1, CountdownAction::Noop);
        assert!(injector.check(item_point()).is_none());
        assert!(injector.check(item_point()).is_none());
    }

    #[test]
    fn countdown_respects_filter() {
        let injector = CountdownInjector::new(1, CountdownAction::Fail("boom"))
            .with_filter(|p| matches!(p, InjectionPoint::BatchBeginWritten { .. }));
        assert!(injector.check(item_point()).is_none());
        assert!(injector.check(begin_point()).is_some());
    }

    #[test]
    fn filtered_out_points_do_not_advance_countdown() {
        let injector = CountdownInjector::new(2, CountdownAction::Fail("boom"))
            .with_filter(|p| matches!(p, InjectionPoint::BatchBeginWritten { .. }));
        for _ in 0..5 {
            assert!(injector.check(item_point()).is_none());
        }
        assert!(injector.check(begin_point()).is_none());
        assert!(injector.check(begin_point()).is_some());
    }

    #[test]
    fn after_batch_items_counts_only_item_points() {
        let injector = CountdownInjector::after_batch_items(2);
        assert!(injector.check(begin_point()).is_none());
        assert!(injector.check(item_point()).is_none());
        assert!(injector.check(begin_point()).is_none());
        assert!(injector.check(item_point()).is_some());
    }

    #[test]
    fn fault_injected_error_is_store_error() {
        let injector = CountdownInjector::new(1, CountdownAction::Fail("test fault"));
        let point = InjectionPoint::BatchStart {
            batch_id: 42,
            item_count: 3,
        };
        let err = injector.check(point).expect("should produce error");
        let StoreError::FaultInjected(msg) = err else {
            panic!("expected FaultInjected variant");
        };
        assert_eq!(msg, "test fault at BatchStart { batch_id: 42, item_count: 3 }");
    }

    #[test]
    fn probabilistic_extremes_are_deterministic() {
        let never = ProbabilisticInjector::new(0.0, CountdownAction::Fail("boom"));
        let always = ProbabilisticInjector::new(1.0, CountdownAction::Fail("boom"));
        for _ in 0..32 {
            assert!(never.check(item_point()).is_none());
            assert!(always.check(item_point()).is_some());
        }
    }

    #[test]
    fn probabilistic_respects_filter() {
        let injector = ProbabilisticInjector::new(1.0, CountdownAction::Fail("boom"))
            .with_filter(|p| matches!(p, InjectionPoint::BatchBeginWritten { .. }));
        assert!(injector.check(item_point()).is_none());
        assert!(injector.check(begin_point()).is_some());
    }

    #[test]
    #[should_panic(expected = "probability must be in [0.0, 1.0]")]
    fn probabilistic_rejects_out_of_range_probability() {
        let _ = ProbabilisticInjector::new(1.5, CountdownAction::Noop);
    }

    #[test]
    fn maybe_inject_passes_through_injector_decision() {
        let absent: Option<std::sync::Arc<dyn FaultInjector>> = None;
        assert!(maybe_inject(item_point(), &absent).is_ok());

        let present: Option<std::sync::Arc<dyn FaultInjector>> = Some(std::sync::Arc::new(
            CountdownInjector::new(1, CountdownAction::Fail("boom")),
        ));
        assert!(maybe_inject(item_point(), &present).is_err());
    }
}