use std::collections::BTreeSet;
use std::time::{Duration, Instant};
use tracing::{debug, info, warn};
// ---- Congestion-window tuning (AIMD) ----

/// Smallest congestion window, in packets; cwnd is reset to this on loss and
/// ssthresh never drops below it.
const CWND_MIN: f64 = 1.0;
/// Growth step: added to cwnd per ACK in slow start, and divided by cwnd per ACK
/// in congestion avoidance.
const AIMD_INCREASE_STEP: f64 = 1.0;
/// Factor applied to cwnd to compute the new ssthresh on loss.
const AIMD_DECREASE_FACTOR: f64 = 0.5;

// ---- Retransmission timer ----

/// Retransmission timeout, in milliseconds, given to a newly constructed controller.
const DEFAULT_RTO_MS: u64 = 200;
/// A packet that has been transmitted and is still waiting for its acknowledgement.
#[derive(Debug, Clone)]
pub struct InFlightPacket {
    /// Sequence number supplied by the sender.
    pub sequence: u64,
    /// Time of the most recent transmission (reset when the packet is retransmitted).
    pub sent_at: Instant,
    /// Number of retransmissions queued for this packet so far.
    pub retransmit_count: u32,
    /// Payload, kept so that the packet can be sent again.
    pub data: Vec<u8>,
}

impl InFlightPacket {
    /// Wraps a just-sent payload, timestamping it with the current instant.
    fn new(sequence: u64, data: Vec<u8>) -> Self {
        let sent_at = Instant::now();
        Self {
            sequence,
            sent_at,
            retransmit_count: 0,
            data,
        }
    }

    /// Reports whether strictly more than `rto` has passed since `sent_at`.
    pub fn is_timed_out(&self, rto: Duration) -> bool {
        let age = self.sent_at.elapsed();
        rto < age
    }
}
/// A request to resend a packet whose retransmission timer expired.
///
/// Produced by `FlowController::timed_out_packets`; carries everything needed to
/// put the packet back on the wire. Derives `PartialEq`/`Eq` so callers and tests
/// can compare requests directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetransmitRequest {
    /// Sequence number of the packet to resend.
    pub sequence: u64,
    /// Copy of the packet payload.
    pub data: Vec<u8>,
    /// Retransmissions already queued for this packet *before* this one
    /// (0 for the first retransmission).
    pub retransmit_count: u32,
}
/// Sender-side flow and congestion control for a sequenced packet stream.
///
/// Combines a receive-window limit, an AIMD congestion window with slow start,
/// and an SRTT/RTTVAR-based retransmission timeout, and keeps running
/// transmission statistics. Derives `Debug` so it can be logged and inspected.
#[derive(Debug)]
pub struct FlowController {
    /// Receive-window limit on packets in flight.
    window_size: usize,
    /// Congestion window in packets; fractional during congestion avoidance.
    cwnd: f64,
    /// Slow-start threshold: slow start ends once `cwnd` reaches it.
    ssthresh: f64,
    /// Whether the controller is currently in slow start.
    in_slow_start: bool,
    /// Unacknowledged packets, in the order they were first sent.
    in_flight: Vec<InFlightPacket>,
    /// Every sequence number acknowledged so far. Only cleared by a reset, so it
    /// grows with the number of acknowledged packets.
    acked: BTreeSet<u64>,
    /// Current retransmission timeout.
    rto: Duration,
    /// Smoothed round-trip time; `None` until the first RTT sample.
    srtt: Option<Duration>,
    /// Round-trip time variation; set together with `srtt`.
    rttvar: Option<Duration>,
    /// Packets accepted for sending (retransmissions not included).
    total_sent: u64,
    /// Packets acknowledged.
    total_acked: u64,
    /// Timeout (loss) events, one per expired packet per check.
    total_lost: u64,
    /// Retransmissions requested.
    total_retransmits: u64,
}
impl FlowController {
    /// Lower bound on an RTO computed from RTT samples.
    const MIN_RTO: Duration = Duration::from_millis(50);
    /// Upper bound on the RTO, applied both to computed values and to backoff.
    const MAX_RTO: Duration = Duration::from_secs(60);
    /// Floor for the variance term of the RTO (the clock granularity `G` of RFC 6298).
    const RTO_GRANULARITY: Duration = Duration::from_millis(1);
    /// Weight `K` applied to RTTVAR when computing the RTO.
    const RTO_K: u32 = 4;

    /// Creates a controller whose receive window allows `window_size` packets in
    /// flight (though the effective window is never below one).
    ///
    /// The congestion window starts at 1 packet in slow start, ssthresh starts at
    /// `window_size / 2`, and the RTO starts at `DEFAULT_RTO_MS` milliseconds.
    pub fn new(window_size: usize) -> Self {
        debug!(window_size, "FlowController created");
        FlowController {
            window_size,
            cwnd: 1.0,
            ssthresh: window_size as f64 / 2.0,
            in_slow_start: true,
            in_flight: Vec::new(),
            acked: BTreeSet::new(),
            rto: Duration::from_millis(DEFAULT_RTO_MS),
            srtt: None,
            rttvar: None,
            total_sent: 0,
            total_acked: 0,
            total_lost: 0,
            total_retransmits: 0,
        }
    }

    /// Like [`FlowController::new`], but with an initial RTO of `rto_ms` milliseconds.
    pub fn with_rto(window_size: usize, rto_ms: u64) -> Self {
        let mut fc = Self::new(window_size);
        fc.rto = Duration::from_millis(rto_ms);
        fc
    }

    /// Returns `true` if one more packet fits in the effective window.
    pub fn can_send(&self) -> bool {
        self.in_flight.len() < self.effective_window()
    }

    /// Number of further packets that may be sent now (0 when the window is full).
    pub fn available_slots(&self) -> usize {
        self.effective_window().saturating_sub(self.in_flight.len())
    }

    /// The receive-window limit on packets in flight.
    pub fn window_size(&self) -> usize {
        self.window_size
    }

    /// The current congestion window, in packets.
    pub fn cwnd(&self) -> f64 {
        self.cwnd
    }

    /// Packets allowed in flight: the congestion window (truncated to a whole
    /// number) capped by the receive window, and never less than one, even when
    /// `window_size` is 0.
    pub fn effective_window(&self) -> usize {
        (self.cwnd as usize).min(self.window_size).max(1)
    }

    /// Whether the controller is in slow start rather than congestion avoidance.
    pub fn in_slow_start(&self) -> bool {
        self.in_slow_start
    }

    /// Replaces the receive-window limit.
    ///
    /// [`FlowController::effective_window`] applies the new limit immediately;
    /// `cwnd` itself is re-capped against it on the next ACK.
    pub fn set_window_size(&mut self, size: usize) {
        debug!(old = self.window_size, new = size, "Window size updated");
        self.window_size = size;
    }

    /// Number of packets sent but not yet acknowledged.
    pub fn in_flight_count(&self) -> usize {
        self.in_flight.len()
    }

    /// Sequence number of the earliest-sent packet still awaiting an ACK, if any.
    pub fn oldest_unacked_sequence(&self) -> Option<u64> {
        self.in_flight.first().map(|p| p.sequence)
    }

    /// Records that packet `sequence` carrying `data` has been sent.
    ///
    /// Returns `false`, and records nothing, if the effective window is already full.
    pub fn on_send(&mut self, sequence: u64, data: Vec<u8>) -> bool {
        if !self.can_send() {
            warn!(
                sequence,
                in_flight = self.in_flight.len(),
                cwnd = self.cwnd,
                "on_send() called but window is full"
            );
            return false;
        }
        self.in_flight.push(InFlightPacket::new(sequence, data));
        self.total_sent += 1;
        debug!(
            sequence,
            in_flight = self.in_flight.len(),
            cwnd = self.cwnd,
            effective_window = self.effective_window(),
            "Packet sent"
        );
        true
    }

    /// Handles an acknowledgement for `sequence`.
    ///
    /// Removes the packet from the in-flight set, marks it acked and grows the
    /// congestion window. Following Karn's algorithm, an RTT sample is taken only
    /// for packets that were never retransmitted. After a retransmission, `sent_at`
    /// has been reset and the ACK cannot be attributed to a specific transmission.
    /// Such a sample would underestimate the RTT and overwrite the backed-off RTO.
    ///
    /// Returns `false` if `sequence` is not in flight (unknown or already acked).
    pub fn on_ack(&mut self, sequence: u64) -> bool {
        let pos = match self.in_flight.iter().position(|p| p.sequence == sequence) {
            Some(pos) => pos,
            None => {
                warn!(sequence, "on_ack() for unknown or duplicate sequence");
                return false;
            }
        };
        let packet = self.in_flight.remove(pos);
        let rtt = packet.sent_at.elapsed();
        if packet.retransmit_count == 0 {
            self.update_rtt(rtt);
        }
        self.acked.insert(sequence);
        self.total_acked += 1;
        self.on_ack_cwnd();
        debug!(
            sequence,
            rtt_ms = rtt.as_millis(),
            in_flight = self.in_flight.len(),
            cwnd = self.cwnd,
            in_slow_start = self.in_slow_start,
            "Packet acked"
        );
        true
    }

    /// Returns a retransmission request for every in-flight packet whose RTO expired.
    ///
    /// Each expired packet stays in flight, with `retransmit_count` incremented and
    /// `sent_at` reset to now, and counts as one loss and one retransmission. Each
    /// request carries the count from *before* the increment. If anything expired,
    /// a single loss event is applied to the congestion state, however many packets
    /// expired in this call.
    pub fn timed_out_packets(&mut self) -> Vec<RetransmitRequest> {
        let rto = self.rto;
        let mut requests = Vec::new();
        let mut had_loss = false;
        for packet in self.in_flight.iter_mut() {
            if packet.is_timed_out(rto) {
                warn!(
                    sequence = packet.sequence,
                    retransmit_count = packet.retransmit_count,
                    rto_ms = rto.as_millis(),
                    "Packet timed out — queuing retransmission"
                );
                requests.push(RetransmitRequest {
                    sequence: packet.sequence,
                    data: packet.data.clone(),
                    retransmit_count: packet.retransmit_count,
                });
                packet.retransmit_count += 1;
                packet.sent_at = Instant::now();
                self.total_lost += 1;
                self.total_retransmits += 1;
                had_loss = true;
            }
        }
        if had_loss {
            self.on_loss_cwnd();
        }
        requests
    }

    /// Grows the congestion window after an ACK.
    ///
    /// In slow start, cwnd grows by `AIMD_INCREASE_STEP` per ACK until it reaches
    /// ssthresh. Afterwards it grows by `AIMD_INCREASE_STEP / cwnd` per ACK.
    ///
    /// The result is capped at `window_size` and floored at `CWND_MIN`. Without the
    /// floor, a zero receive window drives cwnd to 0. The next `STEP / cwnd` is then
    /// infinite, so cwnd would jump straight to the window limit and bypass AIMD.
    fn on_ack_cwnd(&mut self) {
        if self.in_slow_start {
            self.cwnd += AIMD_INCREASE_STEP;
            if self.cwnd >= self.ssthresh {
                self.in_slow_start = false;
                info!(cwnd = self.cwnd, ssthresh = self.ssthresh, "Exiting slow start");
            }
        } else {
            self.cwnd += AIMD_INCREASE_STEP / self.cwnd;
        }
        self.cwnd = self.cwnd.min(self.window_size as f64).max(CWND_MIN);
        debug!(cwnd = self.cwnd, "AIMD: cwnd increased");
    }

    /// Applies a loss (timeout) event.
    ///
    /// Sets ssthresh to `max(cwnd * AIMD_DECREASE_FACTOR, CWND_MIN)`, resets cwnd to
    /// `CWND_MIN`, re-enters slow start, and doubles the RTO (capped at `MAX_RTO`).
    fn on_loss_cwnd(&mut self) {
        self.ssthresh = (self.cwnd * AIMD_DECREASE_FACTOR).max(CWND_MIN);
        self.cwnd = CWND_MIN;
        self.in_slow_start = true;
        self.rto = (self.rto * 2).min(Self::MAX_RTO);
        warn!(
            cwnd = self.cwnd,
            ssthresh = self.ssthresh,
            rto_ms = self.rto.as_millis(),
            "AIMD: multiplicative decrease on loss"
        );
    }

    /// Folds one RTT sample into SRTT/RTTVAR and recomputes the RTO (RFC 6298 §2).
    ///
    /// The first sample sets `SRTT = R` and `RTTVAR = R / 2`. Later samples use the
    /// EWMA updates with alpha = 1/8 and beta = 1/4. In both cases the RTO becomes
    /// `SRTT + max(G, K * RTTVAR)`, clamped to `[MIN_RTO, MAX_RTO]`. Previously the
    /// first sample left the RTO untouched.
    fn update_rtt(&mut self, rtt: Duration) {
        let (srtt, rttvar) = match (self.srtt, self.rttvar) {
            (Some(srtt), Some(rttvar)) => {
                let rtt_ns = rtt.as_nanos() as i128;
                let srtt_ns = srtt.as_nanos() as i128;
                let rttvar_ns = rttvar.as_nanos() as i128;
                // RTTVAR is updated from the *previous* SRTT, before SRTT changes.
                let new_rttvar = rttvar_ns * 3 / 4 + (srtt_ns - rtt_ns).abs() / 4;
                let new_srtt = (srtt_ns * 7 / 8 + rtt_ns / 8).max(1);
                (
                    Duration::from_nanos(new_srtt as u64),
                    Duration::from_nanos(new_rttvar as u64),
                )
            }
            // First sample: SRTT and RTTVAR are always assigned together.
            _ => (rtt, rtt / 2),
        };
        self.srtt = Some(srtt);
        self.rttvar = Some(rttvar);
        self.rto = (srtt + (rttvar * Self::RTO_K).max(Self::RTO_GRANULARITY))
            .max(Self::MIN_RTO)
            .min(Self::MAX_RTO);
    }

    /// Smoothed round-trip time, or `None` before the first RTT sample.
    pub fn srtt(&self) -> Option<Duration> {
        self.srtt
    }

    /// Round-trip time variation, or `None` before the first RTT sample.
    pub fn rttvar(&self) -> Option<Duration> {
        self.rttvar
    }

    /// Current retransmission timeout.
    pub fn rto(&self) -> Duration {
        self.rto
    }

    /// Packets accepted by [`FlowController::on_send`] (retransmissions excluded).
    pub fn total_sent(&self) -> u64 {
        self.total_sent
    }

    /// Packets acknowledged.
    pub fn total_acked(&self) -> u64 {
        self.total_acked
    }

    /// Timeout events observed (one per expired packet per check).
    pub fn total_lost(&self) -> u64 {
        self.total_lost
    }

    /// Retransmissions requested.
    pub fn total_retransmits(&self) -> u64 {
        self.total_retransmits
    }

    /// Fraction of transmissions (original sends plus retransmissions) that timed out.
    ///
    /// Losses are counted per transmission, so the denominator includes
    /// retransmissions. This keeps the result in `[0, 1]`; dividing by original sends
    /// alone would exceed 1.0 once a packet timed out repeatedly. Returns 0.0 before
    /// anything has been sent.
    pub fn loss_rate(&self) -> f64 {
        let transmissions = self.total_sent.saturating_add(self.total_retransmits);
        if transmissions == 0 {
            return 0.0;
        }
        self.total_lost as f64 / transmissions as f64
    }

    /// Whether `sequence` has ever been acknowledged since construction or the last reset.
    pub fn is_acked(&self, sequence: u64) -> bool {
        self.acked.contains(&sequence)
    }

    /// Returns the controller to a freshly constructed state, keeping only `window_size`.
    ///
    /// Note: the RTO returns to `DEFAULT_RTO_MS` even if the controller was built
    /// with [`FlowController::with_rto`].
    pub fn reset(&mut self) {
        debug!("FlowController reset");
        let window_size = self.window_size;
        *self = Self::new(window_size);
    }
}
impl Default for FlowController {
fn default() -> Self {
Self::new(64)
}
}
// Unit tests for FlowController. Timeout-related tests build the controller with a
// 1 ms RTO and sleep for longer than that, since `thread::sleep` waits at least the
// requested duration.
#[cfg(test)]
mod tests {
    use super::*;
    // A fresh controller starts with cwnd = 1, so only one slot is open even though
    // the receive window is 4.
    #[test]
    fn test_new() {
        let fc = FlowController::new(4);
        assert_eq!(fc.window_size(), 4);
        assert_eq!(fc.in_flight_count(), 0);
        assert!(fc.can_send());
        assert_eq!(fc.available_slots(), 1);
        assert!(fc.in_slow_start());
    }
    // One packet in flight fills the initial (cwnd = 1) window.
    #[test]
    fn test_window_full() {
        let mut fc = FlowController::new(4);
        assert!(fc.on_send(0, vec![0]));
        assert!(!fc.can_send());
        assert_eq!(fc.in_flight_count(), 1);
    }
    // Acking the only in-flight packet both frees its slot and grows cwnd.
    #[test]
    fn test_ack_opens_window_and_grows_cwnd() {
        let mut fc = FlowController::new(4);
        assert!(fc.on_send(0, vec![0]));
        assert!(!fc.can_send());
        let cwnd_before = fc.cwnd();
        fc.on_ack(0);
        assert!(fc.cwnd() > cwnd_before);
        assert!(fc.can_send());
    }
    // An ACK for a sequence that was never sent is rejected and changes nothing.
    #[test]
    fn test_ack_unknown_sequence() {
        let mut fc = FlowController::new(4);
        fc.on_send(0, vec![0]);
        assert!(!fc.on_ack(99));
        assert_eq!(fc.in_flight_count(), 1);
    }
    #[test]
    fn test_is_acked() {
        let mut fc = FlowController::new(4);
        fc.on_send(0, vec![0]);
        assert!(!fc.is_acked(0));
        fc.on_ack(0);
        assert!(fc.is_acked(0));
    }
    // Send-then-ack in lockstep; each iteration has a free slot, so all 5 are acked.
    #[test]
    fn test_stats() {
        let mut fc = FlowController::new(10);
        for i in 0..5 {
            if fc.can_send() {
                fc.on_send(i, vec![0]);
                fc.on_ack(i);
            }
        }
        assert_eq!(fc.total_acked(), 5);
    }
    // No sends yet: loss rate is defined as 0.0 rather than dividing by zero.
    #[test]
    fn test_loss_rate_zero() {
        let fc = FlowController::new(4);
        assert_eq!(fc.loss_rate(), 0.0);
    }
    #[test]
    fn test_set_window_size() {
        let mut fc = FlowController::new(4);
        fc.set_window_size(8);
        assert_eq!(fc.window_size(), 8);
    }
    // reset() clears counters, RTT state and congestion state.
    #[test]
    fn test_reset() {
        let mut fc = FlowController::new(4);
        fc.on_send(0, vec![0]);
        fc.on_ack(0);
        fc.reset();
        assert_eq!(fc.in_flight_count(), 0);
        assert_eq!(fc.total_sent(), 0);
        assert_eq!(fc.total_acked(), 0);
        assert!(fc.srtt().is_none());
        assert!(fc.in_slow_start());
        assert_eq!(fc.cwnd(), 1.0);
    }
    // A timed-out packet yields a request carrying its payload and its pre-increment
    // retransmit count (0 on the first timeout).
    #[test]
    fn test_timed_out_packets_returns_retransmit_requests() {
        let mut fc = FlowController::with_rto(4, 1);
        fc.on_send(0, b"hello".to_vec());
        std::thread::sleep(Duration::from_millis(5));
        let requests = fc.timed_out_packets();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].sequence, 0);
        assert_eq!(requests[0].data, b"hello");
        assert_eq!(requests[0].retransmit_count, 0);
        assert_eq!(fc.total_lost(), 1);
        assert_eq!(fc.total_retransmits(), 1);
    }
    // A loss resets cwnd to CWND_MIN and re-enters slow start.
    #[test]
    fn test_aimd_multiplicative_decrease_on_loss() {
        let mut fc = FlowController::with_rto(4, 1);
        fc.on_send(0, vec![0]);
        std::thread::sleep(Duration::from_millis(5));
        let requests = fc.timed_out_packets();
        assert!(!requests.is_empty());
        assert_eq!(fc.cwnd(), 1.0);
        assert!(fc.in_slow_start());
        assert_eq!(fc.total_lost(), 1);
    }
    // Each packet is acked immediately, so cwnd climbs by one per iteration until it
    // reaches ssthresh (window / 2). Reads the private `ssthresh` field directly,
    // which this child module is allowed to do. The i > 1000 guard bounds the loop.
    #[test]
    fn test_slow_start_exits_at_ssthresh() {
        let mut fc = FlowController::new(64);
        let ssthresh = fc.ssthresh;
        let mut i = 0u64;
        loop {
            if fc.can_send() {
                fc.on_send(i, vec![0]);
                fc.on_ack(i);
                i += 1;
            }
            if !fc.in_slow_start() { break; }
            if i > 1000 { break; }
        }
        assert!(!fc.in_slow_start());
        assert!(fc.cwnd() >= ssthresh);
    }
    // The first ACK of a never-retransmitted packet provides the first RTT sample.
    #[test]
    fn test_srtt_updated_on_ack() {
        let mut fc = FlowController::new(4);
        fc.on_send(0, vec![0]);
        assert!(fc.srtt().is_none());
        fc.on_ack(0);
        assert!(fc.srtt().is_some());
        assert!(fc.rttvar().is_some());
    }
    #[test]
    fn test_default() {
        let fc = FlowController::default();
        assert_eq!(fc.window_size(), 64);
    }
    // on_send() refuses (returns false) once the effective window is full.
    #[test]
    fn test_on_send_full_window_returns_false() {
        let mut fc = FlowController::new(4);
        assert!(fc.on_send(0, vec![0]));
        assert!(!fc.on_send(1, vec![0]));
    }
    #[test]
    fn test_multiple_acks_grow_cwnd() {
        let mut fc = FlowController::new(64);
        let initial_cwnd = fc.cwnd();
        for i in 0..10u64 {
            if fc.can_send() {
                fc.on_send(i, vec![0]);
                fc.on_ack(i);
            }
        }
        assert!(fc.cwnd() > initial_cwnd);
        assert_eq!(fc.total_acked(), 10);
    }
    #[test]
    fn test_oldest_unacked_sequence() {
        let mut fc = FlowController::new(4);
        assert!(fc.oldest_unacked_sequence().is_none());
        fc.on_send(5, vec![0]);
        assert_eq!(fc.oldest_unacked_sequence(), Some(5));
    }
    // Initial cwnd of 1 bounds the effective window to 1 despite a window of 4.
    #[test]
    fn test_effective_window_bounded_by_cwnd_and_max() {
        let fc = FlowController::new(4);
        assert_eq!(fc.effective_window(), 1);
    }
    // A timeout applies exponential backoff to the RTO.
    #[test]
    fn test_rto_doubles_on_loss() {
        let mut fc = FlowController::with_rto(4, 1);
        let rto_before = fc.rto();
        fc.on_send(0, vec![0]);
        std::thread::sleep(Duration::from_millis(5));
        fc.timed_out_packets();
        assert!(fc.rto() > rto_before);
    }
    // After the first timeout the RTO doubles to 2 ms; the 10 ms sleep lets the same
    // packet expire again, giving a second retransmission.
    #[test]
    fn test_total_retransmits() {
        let mut fc = FlowController::with_rto(4, 1);
        fc.on_send(0, vec![0]);
        std::thread::sleep(Duration::from_millis(5));
        fc.timed_out_packets();
        assert_eq!(fc.total_retransmits(), 1);
        std::thread::sleep(Duration::from_millis(10));
        fc.timed_out_packets();
        assert_eq!(fc.total_retransmits(), 2);
    }
}