use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
/// Identifier of a fault held in a `FaultRegistry`.
///
/// A thin newtype over `u64`. Because of `#[serde(transparent)]` it is
/// serialized as the bare inner integer rather than as a wrapper object.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(transparent)]
pub struct FaultId(pub u64);
impl FaultId {
pub fn new(n: u64) -> Self {
Self(n)
}
}
impl std::fmt::Display for FaultId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "fault_{}", self.0)
}
}
/// The kinds of fault that can be registered and later consulted by a
/// `FaultyNetwork` implementation.
///
/// Serialized with serde's internally tagged representation: the variant name
/// is stored under the `"kind"` key next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "kind")]
pub enum FaultKind {
/// Drop messages on a link.
///
/// `from_peer` / `to_peer` restrict the fault to one sender / receiver;
/// `None` matches any peer. `probability` is compared against a value in
/// `[0, 1)`, so `1.0` (or more) always matches and `0.0` (or less) never does.
DropMessages {
from_peer: Option<u32>,
to_peer: Option<u32>,
probability: f64,
},
/// Delay messages on a link.
///
/// Peer filters behave as for `DropMessages`. `min_ms` / `max_ms` bound the
/// delay in milliseconds; `RegistryFaultyNetwork` reports their midpoint.
InjectLatency {
from_peer: Option<u32>,
to_peer: Option<u32>,
min_ms: u64,
max_ms: u64,
},
/// Split the cluster in two: `RegistryFaultyNetwork` drops any message whose
/// sender is listed on one side and whose receiver is listed on the other.
Partition { side_a: Vec<u32>, side_b: Vec<u32> },
/// Silence one node: `RegistryFaultyNetwork` drops every message whose
/// sender is `node_id`.
PauseLeader { node_id: u32 },
/// Corrupt messages on a link; fields have the same meaning as in
/// `DropMessages`.
CorruptBytes {
from_peer: Option<u32>,
to_peer: Option<u32>,
probability: f64,
},
}
impl FaultKind {
pub fn variant_name(&self) -> &'static str {
match self {
FaultKind::DropMessages { .. } => "DropMessages",
FaultKind::InjectLatency { .. } => "InjectLatency",
FaultKind::Partition { .. } => "Partition",
FaultKind::PauseLeader { .. } => "PauseLeader",
FaultKind::CorruptBytes { .. } => "CorruptBytes",
}
}
}
/// A registered fault together with its bookkeeping metadata.
#[derive(Debug, Clone, Serialize)]
pub struct FaultRecord {
/// Identifier assigned when the fault was injected.
pub id: FaultId,
/// What the fault does.
pub kind: FaultKind,
/// Wall-clock injection time, in microseconds since the Unix epoch.
pub created_at_unix_micros: i64,
/// Time-to-live requested at injection, in seconds; `None` means no TTL.
pub ttl_secs: Option<u64>,
/// Monotonic-clock deadline after which the record counts as expired;
/// `None` means it never expires. Not serialized (`Instant` has no
/// meaning outside this process).
#[serde(skip)]
expires_at: Option<Instant>,
}
impl FaultRecord {
    /// Reports whether this record's deadline has been reached at `now`.
    ///
    /// A record without a deadline never expires. The deadline itself counts
    /// as expired (`now == deadline` returns `true`).
    pub fn is_expired(&self, now: Instant) -> bool {
        match self.expires_at {
            Some(deadline) => deadline <= now,
            None => false,
        }
    }
}
/// Shared handle to a set of injected faults.
///
/// Cloning is cheap: every clone points at the same `Arc`-held state, so a
/// fault injected through one handle is visible through all of them.
#[derive(Clone)]
pub struct FaultRegistry {
inner: Arc<RegistryInner>,
}
/// State shared by all clones of a `FaultRegistry`.
struct RegistryInner {
/// Source of the next `FaultId`; incremented atomically on each injection.
next_id: AtomicU64,
/// Currently registered faults, keyed by their identifier.
faults: RwLock<HashMap<FaultId, FaultRecord>>,
}
impl FaultRegistry {
    /// Creates an empty registry. The first injected fault receives ID 1.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(RegistryInner {
                next_id: AtomicU64::new(1),
                faults: RwLock::new(HashMap::new()),
            }),
        }
    }

    /// Registers `kind` and returns the identifier assigned to it.
    ///
    /// `ttl_secs` is an optional time-to-live measured on the monotonic clock
    /// from the moment of injection; `None` means the fault stays until it is
    /// removed explicitly. A TTL so large that the deadline cannot be
    /// represented as an `Instant` is treated as "never expires" instead of
    /// panicking.
    pub fn inject(&self, kind: FaultKind, ttl_secs: Option<u64>) -> FaultId {
        // Relaxed is sufficient: the counter only needs to hand out unique values.
        let id = FaultId(self.inner.next_id.fetch_add(1, Ordering::Relaxed));
        let now_inst = Instant::now();
        // Saturate rather than truncate if the microsecond count exceeds i64;
        // fall back to 0 when the system clock reads earlier than the epoch.
        let created_at_unix_micros = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| i64::try_from(d.as_micros()).unwrap_or(i64::MAX))
            .unwrap_or(0);
        // `Instant + Duration` panics on overflow; `checked_add` yields `None`,
        // which `is_expired` already interprets as "no deadline".
        let expires_at = ttl_secs.and_then(|s| now_inst.checked_add(Duration::from_secs(s)));
        let record = FaultRecord {
            id,
            kind,
            created_at_unix_micros,
            ttl_secs,
            expires_at,
        };
        self.inner.faults.write().insert(id, record);
        id
    }

    /// Removes the fault with the given `id`.
    ///
    /// Returns `true` if a fault was removed and `false` if `id` was unknown.
    pub fn remove(&self, id: FaultId) -> bool {
        self.inner.faults.write().remove(&id).is_some()
    }

    /// Removes every fault and returns how many records were stored.
    ///
    /// The count includes records that had expired but were not yet pruned.
    pub fn clear(&self) -> usize {
        let mut faults = self.inner.faults.write();
        let n = faults.len();
        faults.clear();
        n
    }

    /// Prunes expired faults, then returns a snapshot of the remaining ones.
    ///
    /// The order of the returned records is unspecified (hash-map order).
    pub fn list(&self) -> Vec<FaultRecord> {
        let mut faults = self.inner.faults.write();
        Self::purge_expired(&mut faults);
        faults.values().cloned().collect()
    }

    /// Prunes expired faults, then returns how many remain.
    ///
    /// Counts in place instead of cloning every record through `list`.
    pub fn active_count(&self) -> usize {
        let mut faults = self.inner.faults.write();
        Self::purge_expired(&mut faults);
        faults.len()
    }

    /// Drops every record whose deadline has passed as of now.
    fn purge_expired(faults: &mut HashMap<FaultId, FaultRecord>) {
        let now = Instant::now();
        faults.retain(|_, record| !record.is_expired(now));
    }
}
impl Default for FaultRegistry {
fn default() -> Self {
Self::new()
}
}
/// Decision hooks a transport can consult for each message sent from peer
/// `from` to peer `to`.
///
/// Implementations must be `Send + Sync`.
pub trait FaultyNetwork: Send + Sync {
/// Returns `true` if the message from `from` to `to` should be dropped.
fn should_drop(&self, from: u32, to: u32) -> bool;
/// Returns the extra delay to apply to the message, or `None` for no delay.
fn inject_latency(&self, from: u32, to: u32) -> Option<Duration>;
/// Returns `true` if the message from `from` to `to` should be corrupted.
fn should_corrupt(&self, from: u32, to: u32) -> bool;
}
/// A `FaultyNetwork` that injects no faults: nothing is dropped, delayed, or corrupted.
pub struct NoopFaultyNetwork;
impl FaultyNetwork for NoopFaultyNetwork {
    /// Never drops: always `false`, whatever the peers.
    fn should_drop(&self, _: u32, _: u32) -> bool {
        false
    }

    /// Never delays: always `None`, whatever the peers.
    fn inject_latency(&self, _: u32, _: u32) -> Option<Duration> {
        None
    }

    /// Never corrupts: always `false`, whatever the peers.
    fn should_corrupt(&self, _: u32, _: u32) -> bool {
        false
    }
}
/// A `FaultyNetwork` whose decisions are driven by the faults currently held
/// in a `FaultRegistry`.
pub struct RegistryFaultyNetwork {
/// Registry consulted on every decision.
registry: FaultRegistry,
}
impl RegistryFaultyNetwork {
pub fn new(registry: FaultRegistry) -> Self {
Self { registry }
}
fn matches_peers(from_filter: Option<u32>, to_filter: Option<u32>, from: u32, to: u32) -> bool {
from_filter.map(|f| f == from).unwrap_or(true) && to_filter.map(|t| t == to).unwrap_or(true)
}
}
impl FaultyNetwork for RegistryFaultyNetwork {
    /// Returns `true` if any live fault says the `from` -> `to` message must be dropped.
    ///
    /// * `DropMessages`: peers match and the pseudo-random sample is below `probability`.
    /// * `Partition`: sender and receiver are listed on opposite sides.
    /// * `PauseLeader`: the sender is the paused node.
    ///
    /// NOTE(review): the sample is derived from the peer IDs and the current
    /// Unix second only, so every call for the same pair within one second
    /// reaches the same verdict — confirm this burstiness is intended.
    fn should_drop(&self, from: u32, to: u32) -> bool {
        let seed = u64::from(from) ^ (u64::from(to) << 8) ^ Self::unix_secs();
        let unit = Self::unit_sample(seed);
        self.registry.list().iter().any(|fault| match &fault.kind {
            FaultKind::DropMessages {
                from_peer,
                to_peer,
                probability,
            } => Self::matches_peers(*from_peer, *to_peer, from, to) && unit < *probability,
            FaultKind::Partition { side_a, side_b } => {
                let from_a = side_a.contains(&from);
                let from_b = side_b.contains(&from);
                let to_a = side_a.contains(&to);
                let to_b = side_b.contains(&to);
                (from_a && to_b) || (from_b && to_a)
            }
            FaultKind::PauseLeader { node_id } => from == *node_id,
            // Listed explicitly (no `_` arm) so that adding a variant forces
            // a decision here at compile time.
            FaultKind::InjectLatency { .. } | FaultKind::CorruptBytes { .. } => false,
        })
    }

    /// Returns the delay dictated by the first live `InjectLatency` fault that
    /// matches the `from` -> `to` link, or `None` if there is none.
    ///
    /// The delay is the midpoint of `min_ms` and `max_ms` (rounded down).
    /// With several matching faults, which one wins is unspecified because
    /// the registry snapshot has no defined order.
    fn inject_latency(&self, from: u32, to: u32) -> Option<Duration> {
        self.registry.list().into_iter().find_map(|fault| {
            if let FaultKind::InjectLatency {
                from_peer,
                to_peer,
                min_ms,
                max_ms,
            } = fault.kind
            {
                if Self::matches_peers(from_peer, to_peer, from, to) {
                    // floor((min + max) / 2) computed without forming the
                    // sum, which could overflow u64 for large bounds.
                    let mid = min_ms / 2 + max_ms / 2 + (min_ms & max_ms & 1);
                    return Some(Duration::from_millis(mid));
                }
            }
            None
        })
    }

    /// Returns `true` if any live `CorruptBytes` fault matches the link and the
    /// pseudo-random sample is below its `probability`.
    ///
    /// NOTE(review): like `should_drop`, the sample changes only once per
    /// second for a given peer pair.
    fn should_corrupt(&self, from: u32, to: u32) -> bool {
        let seed = u64::from(from) ^ (u64::from(to) << 16) ^ (Self::unix_secs() << 32);
        let unit = Self::unit_sample(seed);
        self.registry.list().iter().any(|fault| {
            if let FaultKind::CorruptBytes {
                from_peer,
                to_peer,
                probability,
            } = &fault.kind
            {
                Self::matches_peers(*from_peer, *to_peer, from, to) && unit < *probability
            } else {
                false
            }
        })
    }
}

impl RegistryFaultyNetwork {
    /// Whole seconds since the Unix epoch, or 0 if the clock reads earlier than the epoch.
    fn unix_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0)
    }

    /// Maps `seed` to a value in `[0, 1)` with a multiplicative hash
    /// (wrapping multiply, then the low six decimal digits).
    fn unit_sample(seed: u64) -> f64 {
        let mixed = seed.wrapping_mul(2654435761);
        ((mixed % 1_000_000) as f64) / 1_000_000.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `PauseLeader` fault aimed at `node_id`.
    fn pause(node_id: u32) -> FaultKind {
        FaultKind::PauseLeader { node_id }
    }

    /// A fresh registry reports no faults through either accessor.
    #[test]
    fn registry_starts_empty() {
        let registry = FaultRegistry::new();
        assert_eq!(registry.active_count(), 0);
        assert!(registry.list().is_empty());
    }

    /// Two injections yield distinct IDs and both show up in the listing.
    #[test]
    fn inject_and_list() {
        let registry = FaultRegistry::new();
        let drop_fault = FaultKind::DropMessages {
            from_peer: Some(1),
            to_peer: Some(2),
            probability: 0.5,
        };
        let first = registry.inject(drop_fault, None);
        let second = registry.inject(pause(1), Some(60));
        assert_ne!(first, second, "fault IDs must be unique");

        let listed = registry.list();
        assert_eq!(listed.len(), 2);

        let mut saw_drop = false;
        let mut saw_pause = false;
        for record in &listed {
            match record.kind {
                FaultKind::DropMessages { .. } => saw_drop = true,
                FaultKind::PauseLeader { .. } => saw_pause = true,
                _ => {}
            }
        }
        assert!(saw_drop && saw_pause);
    }

    /// Removing by ID succeeds once and is a no-op afterwards.
    #[test]
    fn remove_one_by_id() {
        let registry = FaultRegistry::new();
        let id = registry.inject(pause(1), None);
        assert!(registry.remove(id));
        assert!(!registry.remove(id), "second remove should be no-op");
        assert_eq!(registry.active_count(), 0);
    }

    /// `clear` reports how many faults it discarded and leaves the registry empty.
    #[test]
    fn clear_returns_count() {
        let registry = FaultRegistry::new();
        for node in 1..=2 {
            registry.inject(pause(node), None);
        }
        let removed = registry.clear();
        assert_eq!(removed, 2);
        assert_eq!(registry.active_count(), 0);
    }

    /// The no-op network lets every message between distinct peers through untouched.
    #[test]
    fn noop_network_passes_everything() {
        let net = NoopFaultyNetwork;
        let pairs = (1..=3u32)
            .flat_map(|from| (1..=3u32).map(move |to| (from, to)))
            .filter(|(from, to)| from != to);
        for (from, to) in pairs {
            assert!(!net.should_drop(from, to));
            assert!(net.inject_latency(from, to).is_none());
            assert!(!net.should_corrupt(from, to));
        }
    }

    /// A partition drops cross-side traffic in both directions and nothing else.
    #[test]
    fn registry_network_drops_with_partition() {
        let registry = FaultRegistry::new();
        let net = RegistryFaultyNetwork::new(registry.clone());
        assert!(!net.should_drop(1, 2));

        registry.inject(
            FaultKind::Partition {
                side_a: vec![1, 2],
                side_b: vec![3, 4],
            },
            None,
        );

        let same_side = [(1, 2), (3, 4)];
        for &(from, to) in same_side.iter() {
            assert!(!net.should_drop(from, to));
        }
        let cross_side = [(1, 3), (2, 4), (3, 1), (4, 2)];
        for &(from, to) in cross_side.iter() {
            assert!(net.should_drop(from, to));
        }
    }

    /// A paused node's outgoing messages are dropped; other senders are unaffected.
    #[test]
    fn registry_network_pauses_leader() {
        let registry = FaultRegistry::new();
        let net = RegistryFaultyNetwork::new(registry.clone());
        registry.inject(pause(1), None);
        assert!(net.should_drop(1, 2));
        assert!(net.should_drop(1, 3));
        assert!(!net.should_drop(2, 3));
    }

    /// Each variant reports its own name through `variant_name`.
    #[test]
    fn fault_kind_variant_names_are_stable() {
        let cases = vec![
            (
                FaultKind::DropMessages {
                    from_peer: None,
                    to_peer: None,
                    probability: 0.0,
                },
                "DropMessages",
            ),
            (
                FaultKind::InjectLatency {
                    from_peer: None,
                    to_peer: None,
                    min_ms: 0,
                    max_ms: 0,
                },
                "InjectLatency",
            ),
            (
                FaultKind::Partition {
                    side_a: vec![],
                    side_b: vec![],
                },
                "Partition",
            ),
            (FaultKind::PauseLeader { node_id: 0 }, "PauseLeader"),
            (
                FaultKind::CorruptBytes {
                    from_peer: None,
                    to_peer: None,
                    probability: 0.0,
                },
                "CorruptBytes",
            ),
        ];
        for (kind, expected) in cases {
            assert_eq!(kind.variant_name(), expected);
        }
    }

    /// `Display` renders the ID with the `fault_` prefix.
    #[test]
    fn fault_id_display_format() {
        assert_eq!(format!("{}", FaultId(42)), "fault_42");
    }
}