use std::collections::{HashMap, HashSet, VecDeque};
use vortex_core::{DetRng, NodeId};
/// Fault-injection parameters for one directed `(from, to)` link of [`SimNetwork`].
///
/// Links without an explicit configuration behave like `LinkConfig::default()`.
#[derive(Debug, Clone, PartialEq)]
pub struct LinkConfig {
    /// Base delivery delay, in ticks, after the tick at which the message is sent.
    /// The effective delay (after jitter) is never less than one tick.
    pub latency_ticks: u64,
    /// Probability that a message is dropped when it is sent. `0.0` disables the check.
    pub drop_probability: f64,
    /// Symmetric jitter: the delay is offset by an integer in
    /// `-jitter_ticks..=jitter_ticks` (assuming the RNG yields values in `[0, 1)`).
    pub jitter_ticks: u64,
    /// NOTE(review): not read by `SimNetwork::send`/`tick`; with non-zero jitter,
    /// messages on a link can be delivered out of send order even when this is `false`.
    pub reorder: bool,
    /// Probability that a second copy of a message is scheduled 1-5 ticks after the first.
    pub duplicate_probability: f64,
    /// Probability, evaluated at delivery, that one random bit of the payload is flipped.
    pub corrupt_probability: f64,
    /// Maximum number of messages delivered on this link per tick; excess messages are
    /// deferred to the next tick. `0` means unlimited.
    pub bandwidth_limit: u64,
}
impl Default for LinkConfig {
fn default() -> Self {
Self {
latency_ticks: 1,
drop_probability: 0.0,
jitter_ticks: 0,
reorder: false,
duplicate_probability: 0.0,
corrupt_probability: 0.0,
bandwidth_limit: 0,
}
}
}
/// A message queued inside [`SimNetwork`], waiting for its delivery tick.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InFlightMessage {
    /// Sending node.
    pub from: NodeId,
    /// Destination node.
    pub to: NodeId,
    /// Message bytes as sent; any corruption is applied later, at delivery time.
    pub payload: Vec<u8>,
    /// First tick at which the message becomes deliverable.
    pub deliver_at: u64,
}
/// Deterministic, tick-driven simulated network.
///
/// Messages are queued by `send`, become deliverable once `tick` advances the clock to
/// their scheduled tick, and are collected per destination with `drain`. Every
/// probabilistic fault decision draws from a single seeded [`DetRng`].
pub struct SimNetwork {
    // ---- clock and randomness ----
    /// Number of `tick` calls so far; starts at 0.
    current_tick: u64,
    /// Source of all random fault decisions.
    rng: DetRng,

    // ---- fault configuration ----
    /// Per directed link overrides; links without an entry use `LinkConfig::default()`.
    link_configs: HashMap<(NodeId, NodeId), LinkConfig>,
    /// Directed links whose traffic is currently dropped.
    partitions: HashSet<(NodeId, NodeId)>,

    // ---- message queues ----
    /// Messages scheduled but not yet delivered or dropped.
    in_flight: VecDeque<InFlightMessage>,
    /// Per-destination inbox of `(sender, payload)` pairs awaiting `drain`.
    delivered: HashMap<NodeId, VecDeque<(NodeId, Vec<u8>)>>,
    /// Per-link delivery counts for the current tick, used for bandwidth limiting.
    tick_delivery_counts: HashMap<(NodeId, NodeId), u64>,

    // ---- statistics ----
    total_sent: u64,
    total_dropped: u64,
    total_delivered: u64,
    total_corrupted: u64,
    total_throttled: u64,
}
impl SimNetwork {
pub fn new(seed: u64) -> Self {
Self {
current_tick: 0,
in_flight: VecDeque::new(),
link_configs: HashMap::new(),
partitions: HashSet::new(),
delivered: HashMap::new(),
rng: DetRng::derive(seed, "network"),
total_sent: 0,
total_dropped: 0,
total_delivered: 0,
total_corrupted: 0,
total_throttled: 0,
tick_delivery_counts: HashMap::new(),
}
}
pub fn set_link_config(&mut self, from: NodeId, to: NodeId, config: LinkConfig) {
self.link_configs.insert((from, to), config);
}
pub fn set_all_links(&mut self, nodes: &[NodeId], config: LinkConfig) {
for &a in nodes {
for &b in nodes {
if a != b {
self.link_configs.insert((a, b), config.clone());
}
}
}
}
pub fn inject_partition(&mut self, a: NodeId, b: NodeId) {
self.partitions.insert((a, b));
self.partitions.insert((b, a));
}
pub fn heal_partition(&mut self, a: NodeId, b: NodeId) {
self.partitions.remove(&(a, b));
self.partitions.remove(&(b, a));
}
pub fn inject_one_way_partition(&mut self, from: NodeId, to: NodeId) {
self.partitions.insert((from, to));
}
pub fn heal_one_way_partition(&mut self, from: NodeId, to: NodeId) {
self.partitions.remove(&(from, to));
}
pub fn heal_all_partitions(&mut self) {
self.partitions.clear();
}
pub fn send(&mut self, from: NodeId, to: NodeId, payload: Vec<u8>) {
self.total_sent += 1;
if self.partitions.contains(&(from, to)) {
self.total_dropped += 1;
return;
}
let config = self
.link_configs
.get(&(from, to))
.cloned()
.unwrap_or_default();
if config.drop_probability > 0.0 && self.rng.next_f64() < config.drop_probability {
self.total_dropped += 1;
return;
}
let base_latency = config.latency_ticks as i64;
let jitter = if config.jitter_ticks > 0 {
let j = config.jitter_ticks as i64;
let range = 2 * j + 1;
(self.rng.next_f64() * range as f64) as i64 - j
} else {
0
};
let effective_latency = (base_latency + jitter).max(1) as u64;
let deliver_at = self.current_tick + effective_latency;
let should_duplicate = config.duplicate_probability > 0.0
&& self.rng.next_f64() < config.duplicate_probability;
self.in_flight.push_back(InFlightMessage {
from,
to,
payload: payload.clone(),
deliver_at,
});
if should_duplicate {
let dup_delay = (self.rng.next_f64() * 5.0) as u64 + 1;
self.in_flight.push_back(InFlightMessage {
from,
to,
payload,
deliver_at: deliver_at + dup_delay,
});
}
}
pub fn tick(&mut self) {
self.current_tick += 1;
self.tick_delivery_counts.clear();
let mut still_in_flight = VecDeque::new();
while let Some(msg) = self.in_flight.pop_front() {
if msg.deliver_at <= self.current_tick {
if self.partitions.contains(&(msg.from, msg.to)) {
self.total_dropped += 1;
continue;
}
let config = self
.link_configs
.get(&(msg.from, msg.to))
.cloned()
.unwrap_or_default();
if config.bandwidth_limit > 0 {
let count = self
.tick_delivery_counts
.entry((msg.from, msg.to))
.or_insert(0);
if *count >= config.bandwidth_limit {
still_in_flight.push_back(InFlightMessage {
deliver_at: self.current_tick + 1,
..msg
});
self.total_throttled += 1;
continue;
}
*count += 1;
}
let payload = if config.corrupt_probability > 0.0
&& self.rng.next_f64() < config.corrupt_probability
{
let mut corrupted = msg.payload;
if !corrupted.is_empty() {
let idx = self.rng.next_u64_below(corrupted.len() as u64) as usize;
corrupted[idx] ^= 1u8 << (self.rng.next_u64_below(8) as u8);
}
self.total_corrupted += 1;
corrupted
} else {
msg.payload
};
self.delivered
.entry(msg.to)
.or_default()
.push_back((msg.from, payload));
self.total_delivered += 1;
} else {
still_in_flight.push_back(msg);
}
}
self.in_flight = still_in_flight;
}
pub fn drain(&mut self, node_id: NodeId) -> Vec<(NodeId, Vec<u8>)> {
self.delivered
.get_mut(&node_id)
.map(|q| q.drain(..).collect())
.unwrap_or_default()
}
pub fn current_tick(&self) -> u64 {
self.current_tick
}
pub fn in_flight_count(&self) -> usize {
self.in_flight.len()
}
pub fn total_sent(&self) -> u64 {
self.total_sent
}
pub fn total_dropped(&self) -> u64 {
self.total_dropped
}
pub fn total_delivered(&self) -> u64 {
self.total_delivered
}
pub fn total_corrupted(&self) -> u64 {
self.total_corrupted
}
pub fn total_throttled(&self) -> u64 {
self.total_throttled
}
pub fn partition_pairs(&self) -> Vec<(NodeId, NodeId)> {
self.partitions.iter().copied().collect()
}
}
#[cfg(test)]
mod tests {
    // Behavioural tests for `SimNetwork`: delivery timing, partitions, fault injection
    // (drop / jitter / duplicate / corrupt / throttle), and statistics.
    use super::*;

    #[test]
    fn test_basic_delivery() {
        let mut net = SimNetwork::new(42);
        net.send(1, 2, b"hello".to_vec());
        // Minimum latency is one tick, so nothing is deliverable before `tick`.
        assert!(net.drain(2).is_empty());
        net.tick();
        let msgs = net.drain(2);
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].0, 1);
        assert_eq!(msgs[0].1, b"hello");
    }

    #[test]
    fn test_partition() {
        let mut net = SimNetwork::new(42);
        net.inject_partition(1, 2);
        net.send(1, 2, b"lost".to_vec());
        net.tick();
        assert!(net.drain(2).is_empty());
        net.heal_partition(1, 2);
        net.send(1, 2, b"found".to_vec());
        net.tick();
        assert_eq!(net.drain(2).len(), 1);
    }

    #[test]
    fn test_one_way_partition() {
        let mut net = SimNetwork::new(42);
        net.inject_one_way_partition(1, 2);
        net.send(1, 2, b"blocked".to_vec());
        net.send(2, 1, b"ok".to_vec());
        net.tick();
        assert!(net.drain(2).is_empty());
        assert_eq!(net.drain(1).len(), 1);
    }

    #[test]
    fn test_heal_all_partitions() {
        let mut net = SimNetwork::new(42);
        net.inject_partition(1, 2);
        net.inject_one_way_partition(3, 1);
        net.heal_all_partitions();
        assert!(net.partition_pairs().is_empty());
        net.send(1, 2, b"a".to_vec());
        net.send(3, 1, b"b".to_vec());
        net.tick();
        assert_eq!(net.drain(2).len(), 1);
        assert_eq!(net.drain(1).len(), 1);
    }

    #[test]
    fn test_partition_drops_in_flight_messages() {
        let mut net = SimNetwork::new(42);
        net.set_link_config(
            1,
            2,
            LinkConfig {
                latency_ticks: 3,
                ..Default::default()
            },
        );
        net.send(1, 2, b"x".to_vec());
        net.tick();
        // Partition appears while the message is still in flight.
        net.inject_partition(1, 2);
        net.tick();
        net.tick();
        assert!(net.drain(2).is_empty());
        assert_eq!(net.in_flight_count(), 0);
        assert_eq!(net.total_dropped(), 1);
    }

    #[test]
    fn test_latency() {
        let mut net = SimNetwork::new(42);
        net.set_link_config(
            1,
            2,
            LinkConfig {
                latency_ticks: 5,
                ..Default::default()
            },
        );
        net.send(1, 2, b"delayed".to_vec());
        for _ in 0..4 {
            net.tick();
            assert!(net.drain(2).is_empty());
        }
        net.tick();
        assert_eq!(net.drain(2).len(), 1);
    }

    #[test]
    fn test_set_all_links() {
        let mut net = SimNetwork::new(42);
        net.set_all_links(
            &[1, 2, 3],
            LinkConfig {
                latency_ticks: 3,
                ..Default::default()
            },
        );
        net.send(1, 3, b"a".to_vec());
        net.send(3, 2, b"b".to_vec());
        net.tick();
        net.tick();
        assert!(net.drain(3).is_empty());
        assert!(net.drain(2).is_empty());
        net.tick();
        assert_eq!(net.drain(3).len(), 1);
        assert_eq!(net.drain(2).len(), 1);
    }

    #[test]
    fn test_deterministic() {
        let mut net1 = SimNetwork::new(100);
        let mut net2 = SimNetwork::new(100);
        let config = LinkConfig {
            drop_probability: 0.5,
            ..Default::default()
        };
        for net in [&mut net1, &mut net2] {
            net.set_link_config(1, 2, config.clone());
            for i in 0..10 {
                net.send(1, 2, vec![i]);
            }
            net.tick();
        }
        // Same seed + same calls => identical drops, payloads, and delivery order.
        assert_eq!(net1.total_dropped(), net2.total_dropped());
        assert_eq!(net1.drain(2), net2.drain(2));
    }

    #[test]
    fn test_jitter_varies_delivery() {
        let mut net = SimNetwork::new(42);
        net.set_link_config(
            1,
            2,
            LinkConfig {
                latency_ticks: 10,
                jitter_ticks: 5,
                reorder: true,
                ..Default::default()
            },
        );
        for i in 0..20u8 {
            net.send(1, 2, vec![i]);
        }
        let mut delivery_ticks = Vec::new();
        for tick in 1..=20 {
            net.tick();
            let count = net.drain(2).len();
            delivery_ticks.extend(std::iter::repeat(tick).take(count));
        }
        assert_eq!(delivery_ticks.len(), 20);
        let first = delivery_ticks[0];
        assert!(
            delivery_ticks.iter().any(|&t| t != first),
            "jitter should vary delivery"
        );
    }

    #[test]
    fn test_duplication() {
        let mut net = SimNetwork::new(42);
        net.set_link_config(
            1,
            2,
            LinkConfig {
                duplicate_probability: 1.0,
                ..Default::default()
            },
        );
        net.send(1, 2, b"hello".to_vec());
        for _ in 0..10 {
            net.tick();
        }
        assert_eq!(net.drain(2).len(), 2);
    }

    #[test]
    fn test_corruption() {
        let mut net = SimNetwork::new(42);
        net.set_link_config(
            1,
            2,
            LinkConfig {
                corrupt_probability: 1.0,
                ..Default::default()
            },
        );
        net.send(1, 2, b"hello".to_vec());
        net.tick();
        let msgs = net.drain(2);
        assert_eq!(msgs.len(), 1);
        assert_ne!(msgs[0].1, b"hello");
        assert_eq!(net.total_corrupted(), 1);
    }

    #[test]
    fn test_bandwidth_limit() {
        let mut net = SimNetwork::new(42);
        net.set_link_config(
            1,
            2,
            LinkConfig {
                bandwidth_limit: 2,
                ..Default::default()
            },
        );
        for i in 0..5u8 {
            net.send(1, 2, vec![i]);
        }
        net.tick();
        assert_eq!(net.drain(2).len(), 2);
        net.tick();
        assert_eq!(net.drain(2).len(), 2);
        net.tick();
        assert_eq!(net.drain(2).len(), 1);
        assert!(net.total_throttled() >= 3);
    }

    #[test]
    fn test_stats() {
        let mut net = SimNetwork::new(42);
        net.inject_partition(1, 2);
        net.send(1, 2, b"dropped".to_vec());
        net.send(2, 1, b"also_dropped".to_vec());
        net.send(3, 1, b"ok".to_vec());
        net.tick();
        assert_eq!(net.total_sent(), 3);
        assert_eq!(net.total_dropped(), 2);
        assert_eq!(net.total_delivered(), 1);
    }
}