use std::collections::VecDeque;
use std::net::{SocketAddr, UdpSocket};
use std::time::{Duration, Instant};
use crate::prng::Xorshift64;
/// Knobs controlling how [`UdpChaosProxy`] degrades the datagrams it relays.
///
/// Probabilities are evaluated once per datagram by the proxy's PRNG.
/// NOTE(review): values outside `0.0..=1.0` are not validated here.
#[derive(Clone, Debug)]
pub struct ChaosConfig {
    /// Probability that a datagram triggers a loss event and is dropped.
    pub loss: f64,
    /// Number of consecutive datagrams dropped per loss event, counting the
    /// one that triggered it (a value of 0 behaves like 1).
    pub loss_burst: u32,
    /// Probability that a forwarded datagram is sent twice.
    pub duplicate: f64,
    /// Probability that a datagram is held back and only released once a
    /// later datagram is processed.
    pub reorder: f64,
    /// Lower bound of the random delay added before forwarding.
    pub jitter_min: Duration,
    /// Upper bound of the random delay added before forwarding.
    pub jitter_max: Duration,
    /// Seed for the proxy's PRNG; a zero seed is replaced by a fixed
    /// non-zero value when the proxy is constructed.
    pub seed: u64,
}
impl Default for ChaosConfig {
fn default() -> Self {
Self {
loss: 0.0,
loss_burst: 1,
duplicate: 0.0,
reorder: 0.0,
jitter_min: Duration::ZERO,
jitter_max: Duration::ZERO,
seed: 0x000D_DEAD_BEEF_CAFE_u64.wrapping_mul(2),
}
}
}
/// Running counters describing what a [`UdpChaosProxy`] has done so far.
#[derive(Debug, Clone, Default)]
pub struct ChaosStats {
    /// Datagrams forwarded; every duplicate copy counts on its own.
    pub forwarded: u64,
    /// Datagrams discarded instead of being forwarded.
    pub dropped: u64,
    /// Datagrams for which an extra copy was queued.
    pub duplicated: u64,
    /// Reorder events recorded by the proxy.
    pub reordered: u64,
    /// Payload bytes received on the listening socket, dropped datagrams included.
    pub bytes_in: u64,
    /// Payload bytes of forwarded datagrams.
    pub bytes_out: u64,
}
/// A datagram parked inside the proxy until it is due to be sent.
struct PendingPacket {
    /// Instant at which the datagram becomes due.
    when: Instant,
    /// Address the datagram will be sent to.
    target: SocketAddr,
    /// Datagram payload, copied out of the receive buffer.
    bytes: Vec<u8>,
}
/// A UDP relay between clients and a single upstream address that injects
/// loss, duplication, reordering and delay according to a [`ChaosConfig`].
pub struct UdpChaosProxy {
    /// Non-blocking socket that receives from, and sends to, every peer.
    listen: UdpSocket,
    /// Upstream address; datagrams from any other address are relayed here.
    forward_to: SocketAddr,
    /// Chaos knobs supplied at construction.
    cfg: ChaosConfig,
    /// Source of every random chaos decision.
    rng: Xorshift64,
    /// Datagrams waiting for their due time, kept in due-time order.
    pending: VecDeque<PendingPacket>,
    /// Datagram held back by the reorder knob, if any.
    held: Option<PendingPacket>,
    /// Datagrams still to be dropped in the current loss burst.
    drops_left: u32,
    /// Most recent non-upstream sender; upstream replies are relayed to it.
    last_client: Option<SocketAddr>,
    /// Traffic counters, updated as the proxy runs.
    pub stats: ChaosStats,
}
impl UdpChaosProxy {
    /// Binds the listening socket on `bind` and prepares to relay traffic to
    /// and from `forward` under the chaos rules in `cfg`.
    ///
    /// The socket is made non-blocking so [`Self::run_for`] can interleave
    /// receiving with flushing delayed datagrams.
    ///
    /// # Errors
    ///
    /// Returns any error from binding the socket or making it non-blocking.
    pub fn new(bind: SocketAddr, forward: SocketAddr, cfg: ChaosConfig) -> std::io::Result<Self> {
        // Substituted for a zero seed. NOTE(review): presumably because a
        // xorshift state of 0 never leaves 0 — confirm against prng::Xorshift64.
        const FALLBACK_SEED: u64 = 0x000D_DEAD_BEEF_CAFE_u64.wrapping_mul(2);
        let listen = UdpSocket::bind(bind)?;
        listen.set_nonblocking(true)?;
        let seed = if cfg.seed == 0 { FALLBACK_SEED } else { cfg.seed };
        Ok(Self {
            listen,
            forward_to: forward,
            cfg,
            rng: Xorshift64::new(seed),
            pending: VecDeque::new(),
            held: None,
            drops_left: 0,
            last_client: None,
            stats: ChaosStats::default(),
        })
    }

    /// Returns the address the proxy is listening on (useful after binding port 0).
    ///
    /// # Errors
    ///
    /// Propagates the error from [`UdpSocket::local_addr`].
    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
        self.listen.local_addr()
    }

    /// Relays traffic until `total` has elapsed, then flushes everything left.
    ///
    /// Datagrams from any address other than the upstream are forwarded
    /// upstream and their sender becomes the current client; datagrams from
    /// the upstream go to the most recent client and are silently discarded
    /// if no client has been seen yet. Chaos is applied in both directions.
    /// When the deadline passes, the held datagram and all queued ones are
    /// sent immediately regardless of their remaining delay.
    ///
    /// # Errors
    ///
    /// Returns the first `recv_from` error other than `WouldBlock`,
    /// `ConnectionReset` or `Interrupted`, or the first send error other
    /// than `WouldBlock` while flushing due datagrams.
    pub fn run_for(&mut self, total: Duration) -> std::io::Result<()> {
        let deadline = Instant::now() + total;
        // Large enough for any UDP datagram.
        let mut buf = vec![0u8; 65_536];
        while Instant::now() < deadline {
            self.flush_due_packets()?;
            match self.listen.recv_from(&mut buf) {
                Ok((n, from)) => {
                    self.stats.bytes_in = self.stats.bytes_in.saturating_add(n as u64);
                    if let Some(target) = self.route(from) {
                        self.process_inbound(&buf[..n], target);
                    }
                }
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => self.idle_wait(),
                Err(ref e) if Self::is_benign_recv_error(e.kind()) => {}
                Err(e) => return Err(e),
            }
        }
        self.flush_due_packets()?;
        self.drain_all();
        Ok(())
    }

    /// Receive errors that do not indicate a broken listening socket.
    ///
    /// On Windows, an ICMP "port unreachable" provoked by an earlier send is
    /// reported as `ConnectionReset` on a later `recv_from` of an unconnected
    /// UDP socket; `Interrupted` only means the call should be retried.
    fn is_benign_recv_error(kind: std::io::ErrorKind) -> bool {
        matches!(
            kind,
            std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::Interrupted
        )
    }

    /// Chooses where a datagram received from `from` should go.
    ///
    /// Traffic from the upstream returns to the most recent client (`None`
    /// if no client has been seen); traffic from anyone else goes upstream
    /// and makes its sender the current client.
    fn route(&mut self, from: SocketAddr) -> Option<SocketAddr> {
        if from == self.forward_to {
            self.last_client
        } else {
            self.last_client = Some(from);
            Some(self.forward_to)
        }
    }

    /// Sleeps briefly after an empty receive, aiming to wake by the next
    /// queued datagram's due time so jitter stays accurate.
    ///
    /// Always yielding (rather than spinning while anything is pending or
    /// held) keeps a parked reorder datagram from pinning a CPU core.
    fn idle_wait(&self) {
        const MAX_NAP: Duration = Duration::from_micros(200);
        let nap = self.pending.front().map_or(MAX_NAP, |p| {
            p.when.saturating_duration_since(Instant::now()).min(MAX_NAP)
        });
        if !nap.is_zero() {
            std::thread::sleep(nap);
        }
    }

    /// Best-effort end-of-run flush: sends the held datagram and every queued
    /// one immediately, ignoring any remaining delay.
    ///
    /// Send failures are deliberately not propagated (the run is over), but
    /// they are counted as drops rather than as forwarded traffic.
    fn drain_all(&mut self) {
        if let Some(h) = self.held.take() {
            self.pending.push_back(h);
        }
        while let Some(p) = self.pending.pop_front() {
            match self.listen.send_to(&p.bytes, p.target) {
                Ok(_) => self.record_sent(p.bytes.len()),
                Err(_) => {
                    self.stats.dropped = self.stats.dropped.saturating_add(1);
                }
            }
        }
    }

    /// Applies the chaos pipeline to one received datagram bound for `target`.
    ///
    /// In order:
    /// - An ongoing loss burst, or a fresh loss event, drops it.
    /// - Otherwise a reorder roll may park it in the hold slot.
    /// - Otherwise it is queued with jitter, any held datagram is released
    ///   *behind* it, and a duplicate roll may queue a second copy with its
    ///   own jitter.
    fn process_inbound(&mut self, bytes: &[u8], target: SocketAddr) {
        // Still inside a loss burst: discard without consulting the RNG.
        if self.drops_left > 0 {
            self.drops_left -= 1;
            self.stats.dropped = self.stats.dropped.saturating_add(1);
            return;
        }
        if self.rng.bernoulli(self.cfg.loss) {
            self.stats.dropped = self.stats.dropped.saturating_add(1);
            // This datagram is the first of the burst; a burst of 0 acts like 1.
            self.drops_left = self.cfg.loss_burst.saturating_sub(1);
            return;
        }
        if self.cfg.reorder > 0.0 && self.rng.bernoulli(self.cfg.reorder) {
            let when = self.compute_jitter(Instant::now());
            let pkt = PendingPacket {
                when,
                target,
                bytes: bytes.to_vec(),
            };
            // The newcomer takes over the hold slot, so the previously held
            // datagram was never overtaken: release it in order, uncounted.
            if let Some(prev) = self.held.replace(pkt) {
                self.enqueue(prev);
            }
            return;
        }
        let dup = self.rng.bernoulli(self.cfg.duplicate);
        let now = Instant::now();
        let when = self.compute_jitter(now);
        self.enqueue(PendingPacket {
            when,
            target,
            bytes: bytes.to_vec(),
        });
        // Release the held datagram *behind* this one. Its due time was fixed
        // earlier, so without bumping it the sorted queue would send it
        // first and no reordering would ever be observed.
        if let Some(mut prev) = self.held.take() {
            prev.when = prev.when.max(when);
            self.enqueue(prev);
            self.stats.reordered = self.stats.reordered.saturating_add(1);
        }
        if dup {
            let when2 = self.compute_jitter(now);
            self.enqueue(PendingPacket {
                when: when2,
                target,
                bytes: bytes.to_vec(),
            });
            self.stats.duplicated = self.stats.duplicated.saturating_add(1);
        }
    }

    /// Inserts `p` keeping `pending` sorted by due time; datagrams with equal
    /// due times keep their insertion order.
    fn enqueue(&mut self, p: PendingPacket) {
        // `pending` only grows through this method, so it is sorted and a
        // binary search finds the first entry due strictly after `p`.
        let idx = self.pending.partition_point(|q| q.when <= p.when);
        self.pending.insert(idx, p);
    }

    /// Returns the instant at which a datagram that arrived at `now` is due.
    ///
    /// The delay is `jitter_min` plus an extra number of microseconds drawn
    /// with `range_u64` over the span `jitter_max - jitter_min`. A
    /// `jitter_max` below `jitter_min` (including a zero `jitter_max`)
    /// yields a fixed delay of `jitter_min`. No random number is drawn when
    /// the span is empty.
    fn compute_jitter(&mut self, now: Instant) -> Instant {
        let min = self.cfg.jitter_min;
        // Clamp an inverted range so jitter_min is honoured even when
        // jitter_max is left at zero.
        let max = self.cfg.jitter_max.max(min);
        if max.is_zero() {
            return now;
        }
        let span_us = u64::try_from(max.saturating_sub(min).as_micros()).unwrap_or(u64::MAX);
        let extra_us = if span_us == 0 {
            0
        } else {
            self.rng.range_u64(span_us)
        };
        now + min + Duration::from_micros(extra_us)
    }

    /// Sends every queued datagram whose due time has passed.
    ///
    /// # Errors
    ///
    /// Returns the first send error other than `WouldBlock`. On `WouldBlock`
    /// (socket buffer full) the datagram stays queued for the next pass
    /// instead of being lost.
    fn flush_due_packets(&mut self) -> std::io::Result<()> {
        let now = Instant::now();
        while let Some(p) = self.pending.front() {
            if p.when > now {
                break;
            }
            match self.listen.send_to(&p.bytes, p.target) {
                Ok(_) => {}
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
                Err(e) => return Err(e),
            }
            let len = p.bytes.len();
            self.pending.pop_front();
            self.record_sent(len);
        }
        Ok(())
    }

    /// Accounts for one datagram of `len` payload bytes leaving the proxy.
    fn record_sent(&mut self, len: usize) {
        self.stats.forwarded = self.stats.forwarded.saturating_add(1);
        self.stats.bytes_out = self.stats.bytes_out.saturating_add(len as u64);
    }
}
#[cfg(test)]
// Test setup failures should abort the test; unwrap keeps the setup terse.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, SocketAddrV4};

    /// Loopback address on `port` (0 lets the OS pick a free port).
    fn local(port: u16) -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port))
    }

    #[test]
    fn loss_drops_at_configured_rate() {
        let echo = UdpSocket::bind(local(0)).unwrap();
        let echo_addr = echo.local_addr().unwrap();
        let mut proxy = UdpChaosProxy::new(
            local(0),
            echo_addr,
            ChaosConfig {
                loss: 1.0,
                ..Default::default()
            },
        )
        .unwrap();
        let proxy_addr = proxy.local_addr().unwrap();
        let sender = UdpSocket::bind(local(0)).unwrap();
        for _ in 0..50 {
            sender.send_to(b"hello", proxy_addr).unwrap();
        }
        proxy.run_for(Duration::from_millis(200)).unwrap();
        assert_eq!(proxy.stats.forwarded, 0);
        assert!(proxy.stats.dropped >= 50);
    }

    #[test]
    fn no_chaos_passes_everything() {
        let recv = UdpSocket::bind(local(0)).unwrap();
        recv.set_nonblocking(true).unwrap();
        let recv_addr = recv.local_addr().unwrap();
        let mut proxy = UdpChaosProxy::new(local(0), recv_addr, ChaosConfig::default()).unwrap();
        let proxy_addr = proxy.local_addr().unwrap();
        let sender = UdpSocket::bind(local(0)).unwrap();
        for _ in 0..30 {
            sender.send_to(b"x", proxy_addr).unwrap();
        }
        proxy.run_for(Duration::from_millis(200)).unwrap();
        assert_eq!(proxy.stats.dropped, 0);
        assert_eq!(proxy.stats.forwarded, 30);
        assert_eq!(proxy.stats.bytes_in, 30);
        assert_eq!(proxy.stats.bytes_out, 30);
    }

    #[test]
    fn duplicate_bumps_forwarded() {
        let recv = UdpSocket::bind(local(0)).unwrap();
        let recv_addr = recv.local_addr().unwrap();
        let mut proxy = UdpChaosProxy::new(
            local(0),
            recv_addr,
            ChaosConfig {
                duplicate: 1.0,
                ..Default::default()
            },
        )
        .unwrap();
        let proxy_addr = proxy.local_addr().unwrap();
        let sender = UdpSocket::bind(local(0)).unwrap();
        for _ in 0..10 {
            sender.send_to(b"y", proxy_addr).unwrap();
        }
        proxy.run_for(Duration::from_millis(200)).unwrap();
        assert_eq!(proxy.stats.duplicated, 10);
        assert_eq!(proxy.stats.forwarded, 20);
    }

    #[test]
    fn jitter_delays_packets() {
        let recv = UdpSocket::bind(local(0)).unwrap();
        recv.set_read_timeout(Some(Duration::from_secs(2))).unwrap();
        let recv_addr = recv.local_addr().unwrap();
        let mut proxy = UdpChaosProxy::new(
            local(0),
            recv_addr,
            ChaosConfig {
                jitter_min: Duration::from_millis(20),
                jitter_max: Duration::from_millis(40),
                ..Default::default()
            },
        )
        .unwrap();
        let proxy_addr = proxy.local_addr().unwrap();
        // Timestamp the arrival on its own thread: run_for always lasts its
        // full duration, so timing the whole call proves nothing about jitter.
        let receiver = std::thread::spawn(move || {
            let mut buf = [0u8; 16];
            recv.recv_from(&mut buf).ok().map(|_| Instant::now())
        });
        let sender = UdpSocket::bind(local(0)).unwrap();
        let t0 = Instant::now();
        sender.send_to(b"z", proxy_addr).unwrap();
        proxy.run_for(Duration::from_millis(60)).unwrap();
        let arrived = receiver.join().unwrap().unwrap();
        assert_eq!(proxy.stats.forwarded, 1);
        assert!(arrived.duration_since(t0) >= Duration::from_millis(20));
    }
}