use dashmap::DashMap;
use std::{
net::SocketAddr,
sync::{Arc, LazyLock},
time::Duration,
};
use crate::simulation::{FaultConfig, SimulationRng, TimeSource, VirtualTime};
/// A message held in `FaultInjectorState::pending_messages` until virtual time reaches its
/// deadline.
#[derive(Debug)]
pub struct PendingMessage {
    /// Virtual-time deadline in nanoseconds; the message is due once
    /// `VirtualTime::now_nanos()` is greater than or equal to this value.
    pub deadline: u64,
    /// Raw payload passed to `deliver_packet_to_network` on delivery.
    pub data: Vec<u8>,
    /// Sender address (also used as a tie-breaker when ordering due messages).
    pub from: SocketAddr,
    /// Destination address the payload is delivered to.
    pub target: SocketAddr,
}
/// Running counters describing what happened to traffic on a simulated network.
#[derive(Debug, Clone, Default)]
pub struct NetworkStats {
    /// Messages submitted for sending; the denominator of [`NetworkStats::loss_ratio`].
    pub messages_sent: u64,
    /// Messages delivered (not updated in this file — presumably counted by the transport).
    pub messages_delivered: u64,
    /// Dropped-message counter for the "loss" category; part of [`NetworkStats::total_dropped`].
    pub messages_dropped_loss: u64,
    /// Dropped-message counter for the "partition" category; part of
    /// [`NetworkStats::total_dropped`].
    pub messages_dropped_partition: u64,
    /// Dropped-message counter for the "crash" category; part of
    /// [`NetworkStats::total_dropped`].
    pub messages_dropped_crash: u64,
    /// Queued-message counter (not read or written in this file).
    pub messages_queued: u64,
    /// Delayed messages successfully delivered; the divisor of
    /// [`NetworkStats::average_latency`].
    pub messages_delayed_delivered: u64,
    /// Accumulated latency in nanoseconds, averaged over `messages_delayed_delivered`.
    pub total_latency_nanos: u64,
}

impl NetworkStats {
    /// Sum of all three drop categories (loss, partition, crash).
    pub fn total_dropped(&self) -> u64 {
        [
            self.messages_dropped_loss,
            self.messages_dropped_partition,
            self.messages_dropped_crash,
        ]
        .iter()
        .sum()
    }

    /// Fraction of sent messages that were dropped, or `0.0` when nothing was sent.
    pub fn loss_ratio(&self) -> f64 {
        match self.messages_sent {
            0 => 0.0,
            sent => self.total_dropped() as f64 / sent as f64,
        }
    }

    /// Mean latency per delayed delivery, or [`Duration::ZERO`] when none were recorded.
    pub fn average_latency(&self) -> Duration {
        // `checked_div` yields `None` exactly when there were no delayed deliveries.
        self.total_latency_nanos
            .checked_div(self.messages_delayed_delivered)
            .map_or(Duration::ZERO, Duration::from_nanos)
    }
}
/// Mutable fault-injection state for one simulated network.
///
/// Shared through the global registry as `Arc<std::sync::Mutex<FaultInjectorState>>`
/// (see [`set_fault_injector`] / [`get_fault_injector`]).
pub struct FaultInjectorState {
    /// Fault configuration for this network.
    pub config: FaultConfig,
    /// Seeded RNG built in [`FaultInjectorState::new`].
    // Not read anywhere in this file; the allowance keeps the field without a lint warning.
    #[allow(dead_code)]
    pub rng: SimulationRng,
    /// Virtual clock; when `None`, [`FaultInjectorState::advance_time`] is a no-op.
    pub virtual_time: Option<VirtualTime>,
    /// Messages waiting for virtual time to reach their deadline.
    pub pending_messages: Vec<PendingMessage>,
    /// Running traffic counters.
    pub stats: NetworkStats,
    /// Name this state is registered under. It is filled in by [`set_fault_injector`] and is
    /// required by [`FaultInjectorState::advance_time`] to route deliveries.
    pub network_name: Option<String>,
}

impl FaultInjectorState {
    /// Creates a state with the given fault config and RNG seed.
    ///
    /// The new state has no virtual clock, no pending messages, zeroed stats and no network name.
    pub fn new(config: FaultConfig, seed: u64) -> Self {
        Self {
            config,
            rng: SimulationRng::new(seed),
            virtual_time: None,
            pending_messages: Vec::new(),
            stats: NetworkStats::default(),
            network_name: None,
        }
    }

    /// Attaches a virtual clock (builder style), enabling deadline-based delivery in
    /// [`advance_time`](Self::advance_time).
    pub fn with_virtual_time(mut self, vt: VirtualTime) -> Self {
        self.virtual_time = Some(vt);
        self
    }

    /// Returns the attached virtual clock, if any.
    // Not called in this file; the allowance silences the unused-item lint.
    #[allow(dead_code)]
    pub fn virtual_time(&self) -> Option<&VirtualTime> {
        self.virtual_time.as_ref()
    }

    /// Delivers every pending message whose deadline is at or before the current virtual time.
    ///
    /// Due messages are delivered in `(deadline, from, target)` order. Messages that tie on all
    /// three keys keep the order in which they sat in `pending_messages`, so same-link packets
    /// with equal deadlines are not reordered. Messages that are not yet due stay queued in
    /// their original order.
    ///
    /// Returns the number of messages successfully handed to the network. It returns 0 without
    /// touching the queue when no virtual clock is attached, or when `network_name` is unset
    /// (which also logs a warning). A failed delivery is traced and the message is discarded.
    pub fn advance_time(&mut self) -> usize {
        use crate::transport::in_memory_socket::deliver_packet_to_network;
        let Some(ref vt) = self.virtual_time else {
            return 0;
        };
        let Some(ref network_name) = self.network_name else {
            tracing::warn!("advance_time called but network_name not set");
            return 0;
        };
        let now = vt.now_nanos();

        // Split the queue in a single O(n) pass rather than repeated `Vec::remove` (O(n^2)).
        // `partition` preserves relative order in both halves.
        let (mut ready_messages, still_pending): (Vec<PendingMessage>, Vec<PendingMessage>) =
            std::mem::take(&mut self.pending_messages)
                .into_iter()
                .partition(|p| p.deadline <= now);
        self.pending_messages = still_pending;

        // Stable sort: exact ties keep queue order (the previous reverse-removal reversed them).
        ready_messages.sort_by_key(|p| (p.deadline, p.from, p.target));

        let mut delivered = 0;
        for pending in ready_messages {
            if deliver_packet_to_network(network_name, pending.target, pending.data, pending.from) {
                delivered += 1;
                self.stats.messages_delayed_delivered += 1;
            } else {
                tracing::trace!(
                    target = %pending.target,
                    "VirtualTime delivery failed (socket may have been dropped)"
                );
            }
        }
        delivered
    }

    /// Returns the current traffic counters.
    pub fn stats(&self) -> &NetworkStats {
        &self.stats
    }

    /// Zeroes every traffic counter.
    pub fn reset_stats(&mut self) {
        self.stats = NetworkStats::default();
    }
}
/// Process-wide registry mapping a network name to its shared fault-injector state.
/// The map is created lazily, on first access.
static FAULT_INJECTORS: LazyLock<DashMap<String, Arc<std::sync::Mutex<FaultInjectorState>>>> =
    LazyLock::new(DashMap::default);
/// Registers (`Some`) or unregisters (`None`) the fault injector for `network_name`.
///
/// When registering, the state's `network_name` is stamped first so that
/// `FaultInjectorState::advance_time` knows which network to deliver into. Any previous entry
/// under the same name is replaced.
///
/// # Panics
/// Panics if the supplied state's mutex is poisoned.
pub fn set_fault_injector(
    network_name: &str,
    state: Option<Arc<std::sync::Mutex<FaultInjectorState>>>,
) {
    let Some(injector) = state else {
        FAULT_INJECTORS.remove(network_name);
        return;
    };
    // The temporary guard is released at the end of this statement, before the map insert.
    injector.lock().unwrap().network_name = Some(network_name.to_owned());
    FAULT_INJECTORS.insert(network_name.to_owned(), injector);
}
/// Returns a shared handle to the fault injector registered for `network_name`, if any.
pub fn get_fault_injector(network_name: &str) -> Option<Arc<std::sync::Mutex<FaultInjectorState>>> {
    let entry = FAULT_INJECTORS.get(network_name)?;
    Some(Arc::clone(entry.value()))
}
/// Removes every registered fault injector from the global registry.
// Not called in this file; the allowance silences the unused-item lint where no caller exists.
#[allow(dead_code)]
pub fn clear_all_fault_injectors() {
    let registry = &*FAULT_INJECTORS;
    registry.clear();
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::FaultConfigBuilder;

    /// Fixed RNG seed shared by every state built in these tests.
    const SEED: u64 = 12345;

    /// Builds a fault-injector state from the default fault config.
    fn default_state() -> FaultInjectorState {
        FaultInjectorState::new(FaultConfigBuilder::default().build(), SEED)
    }

    #[test]
    fn test_network_stats_default() {
        let zeroed = NetworkStats::default();
        assert_eq!(zeroed.messages_sent, 0);
        assert_eq!(zeroed.messages_delivered, 0);
        assert_eq!(zeroed.total_dropped(), 0);
        assert_eq!(zeroed.loss_ratio(), 0.0);
        assert_eq!(zeroed.average_latency(), Duration::ZERO);
    }

    #[test]
    fn test_network_stats_calculations() {
        // 20 of 100 messages dropped; 200 ms of latency spread over 20 delayed deliveries.
        let populated = NetworkStats {
            messages_sent: 100,
            messages_delivered: 80,
            messages_dropped_loss: 10,
            messages_dropped_partition: 5,
            messages_dropped_crash: 5,
            messages_delayed_delivered: 20,
            total_latency_nanos: 200_000_000,
            ..NetworkStats::default()
        };
        assert_eq!(populated.total_dropped(), 20);
        assert!((populated.loss_ratio() - 0.2).abs() < 0.001);
        assert_eq!(populated.average_latency(), Duration::from_millis(10));
    }

    #[test]
    fn test_fault_injector_state_new() {
        let state = default_state();
        assert!(state.virtual_time.is_none());
        assert!(state.pending_messages.is_empty());
        assert_eq!(state.stats.messages_sent, 0);
    }

    #[test]
    fn test_fault_injector_state_with_virtual_time() {
        let state = default_state().with_virtual_time(VirtualTime::new());
        assert!(state.virtual_time.is_some());
    }

    #[test]
    fn test_fault_injector_reset_stats() {
        let mut state = default_state();
        state.stats.messages_sent = 100;
        state.stats.messages_dropped_loss = 10;
        state.reset_stats();
        assert_eq!(state.stats.messages_sent, 0);
        assert_eq!(state.stats.messages_dropped_loss, 0);
    }
}