use std::collections::VecDeque;
use std::time::{Duration, Instant};
/// Queuing-delay target in microseconds. The window grows while the measured
/// queuing delay is below this value and shrinks once it is above it.
pub const TARGET_DELAY_US: u32 = 100_000;
/// Upper clamp applied to the congestion window after every adjustment.
pub const MAX_CWND: u32 = 1_000_000;
/// Lower clamp applied to the congestion window; also the window set on loss.
pub const MIN_CWND: u32 = 150;
/// Congestion window a freshly constructed controller starts with.
pub const INITIAL_CWND: u32 = 3000;
/// Scaling factor in the congestion-avoidance window delta
/// (presumably the maximum segment size in bytes — confirm against the sender).
pub const MSS: u32 = 1400;
/// Maximum age of a sample kept in the base-delay history.
const BASE_DELAY_HISTORY_DURATION: Duration = Duration::from_secs(120);
/// Maximum number of samples kept in the base-delay history.
const BASE_DELAY_HISTORY_SIZE: usize = 13;
/// Multiplier applied to the congestion-avoidance window delta.
const GAIN: f64 = 1.0;
/// Bounded sliding window of delay samples, used to find the smallest delay
/// observed recently.
///
/// The window is limited both by sample count (`max_size`) and by sample age
/// (`window_duration`, measured back from the timestamp of the sample being
/// added).
#[derive(Debug, Clone)]
struct DelayHistory {
    /// `(timestamp, delay_us)` pairs in insertion order; the oldest is at the front.
    samples: VecDeque<(Instant, u32)>,
    /// Upper bound on the number of retained samples.
    max_size: usize,
    /// Samples older than this, relative to the newest timestamp, are dropped.
    window_duration: Duration,
}

impl DelayHistory {
    /// Creates an empty history holding at most `max_size` samples no older
    /// than `window_duration`.
    ///
    /// A `max_size` of zero still retains the most recently added sample.
    fn new(max_size: usize, window_duration: Duration) -> Self {
        Self {
            samples: VecDeque::with_capacity(max_size),
            max_size,
            window_duration,
        }
    }

    /// Records `delay_us` observed at `now`, first discarding samples that
    /// have aged out of the window and then, if the history is still full,
    /// the oldest remaining sample.
    fn add_sample(&mut self, now: Instant, delay_us: u32) {
        // `Instant - Duration` panics when the result is not representable
        // (very large window, or an `Instant` close to the platform's origin).
        // `checked_sub` yields `None` in that case, which simply means no
        // stored sample can be older than the window.
        if let Some(cutoff) = now.checked_sub(self.window_duration) {
            while matches!(self.samples.front(), Some(&(ts, _)) if ts < cutoff) {
                self.samples.pop_front();
            }
        }

        // Make room for the new sample when the count limit has been reached.
        if self.samples.len() >= self.max_size {
            self.samples.pop_front();
        }
        self.samples.push_back((now, delay_us));
    }

    /// Returns the smallest delay currently in the window, or `None` when no
    /// sample has been recorded.
    fn min(&self) -> Option<u32> {
        self.samples.iter().map(|&(_, delay_us)| delay_us).min()
    }
}
/// Delay-based congestion controller in the style of LEDBAT.
///
/// The controller keeps a congestion window (`cwnd`, in bytes) that grows
/// while the measured queuing delay stays below [`TARGET_DELAY_US`] and
/// shrinks once it exceeds it. It also maintains a smoothed round-trip-time
/// estimate and the retransmission timeout derived from it.
#[derive(Debug)]
pub struct LedbatController {
    /// Congestion window in bytes; clamped to `[MIN_CWND, MAX_CWND]` after every ACK.
    cwnd: u32,
    /// Window recorded when slow start ends or a loss occurs. It is stored
    /// only; the window logic in this type never reads it back.
    ssthresh: u32,
    /// Recent delay samples; their minimum is used as the base delay.
    base_delay_history: DelayHistory,
    /// The most recent delay samples (bounded). Maintained on every ACK but
    /// not currently consulted when computing the queuing delay.
    current_delay_filter: VecDeque<u32>,
    /// Bytes sent and not yet acknowledged.
    bytes_in_flight: u32,
    /// Smoothed round-trip time in microseconds.
    rtt_us: u32,
    /// Round-trip-time variation estimate in microseconds.
    rtt_var_us: u32,
    /// Retransmission timeout in microseconds.
    rto_us: u32,
    /// `true` until the queuing delay first reaches the target or a loss occurs.
    in_slow_start: bool,
    /// Time at which the last ACK was processed.
    last_ack_time: Option<Instant>,
    /// Whether at least one RTT sample has been folded into `rtt_us`.
    has_rtt_sample: bool,
}

impl Default for LedbatController {
    /// Equivalent to [`LedbatController::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl LedbatController {
    /// Number of recent delay samples retained in `current_delay_filter`.
    const CURRENT_DELAY_FILTER_LEN: usize = 8;
    /// RTT reported before any sample has been received.
    const INITIAL_RTT_US: u32 = 1_000_000;
    /// RTO used before any RTT sample has been received.
    const INITIAL_RTO_US: u32 = 3_000_000;
    /// Lower bound on the RTO computed from RTT samples.
    const MIN_RTO_US: u32 = 500_000;
    /// Upper bound on the RTO after timeout back-off.
    const MAX_RTO_US: u32 = 60_000_000;

    /// Creates a controller in slow start with the initial window, nothing in
    /// flight, and the initial RTT/RTO values.
    pub fn new() -> Self {
        Self {
            cwnd: INITIAL_CWND,
            ssthresh: MAX_CWND,
            base_delay_history: DelayHistory::new(
                BASE_DELAY_HISTORY_SIZE,
                BASE_DELAY_HISTORY_DURATION,
            ),
            current_delay_filter: VecDeque::with_capacity(Self::CURRENT_DELAY_FILTER_LEN),
            bytes_in_flight: 0,
            rtt_us: Self::INITIAL_RTT_US,
            rtt_var_us: 0,
            rto_us: Self::INITIAL_RTO_US,
            in_slow_start: true,
            last_ack_time: None,
            has_rtt_sample: false,
        }
    }

    /// Current congestion window in bytes.
    pub fn cwnd(&self) -> u32 {
        self.cwnd
    }

    /// Bytes sent but not yet acknowledged.
    pub fn bytes_in_flight(&self) -> u32 {
        self.bytes_in_flight
    }

    /// Bytes that may still be sent before the window is full (zero when the
    /// bytes in flight already meet or exceed the window).
    pub fn available_window(&self) -> u32 {
        self.cwnd.saturating_sub(self.bytes_in_flight)
    }

    /// Returns `true` while fewer bytes are in flight than the window allows.
    pub fn can_send(&self) -> bool {
        self.bytes_in_flight < self.cwnd
    }

    /// Smoothed round-trip time in microseconds.
    pub fn rtt_us(&self) -> u32 {
        self.rtt_us
    }

    /// Retransmission timeout in microseconds.
    pub fn rto_us(&self) -> u32 {
        self.rto_us
    }

    /// Retransmission timeout as a [`Duration`].
    pub fn rto(&self) -> Duration {
        Duration::from_micros(u64::from(self.rto_us))
    }

    /// Accounts for `bytes` newly placed in flight.
    ///
    /// The counter saturates at `u32::MAX` instead of overflowing.
    pub fn on_send(&mut self, bytes: u32) {
        self.bytes_in_flight = self.bytes_in_flight.saturating_add(bytes);
    }

    /// Processes an acknowledgement.
    ///
    /// * `bytes_acked` – bytes newly acknowledged; removed from the in-flight count.
    /// * `delay_us` – delay sample in microseconds carried by this ACK.
    /// * `rtt_us` – optional round-trip-time sample in microseconds.
    ///
    /// The queuing delay is `delay_us` minus the smallest delay in the base
    /// history (which already includes this sample), and drives the window
    /// adjustment.
    pub fn on_ack(&mut self, bytes_acked: u32, delay_us: u32, rtt_us: Option<u32>) {
        let now = Instant::now();
        self.last_ack_time = Some(now);
        self.bytes_in_flight = self.bytes_in_flight.saturating_sub(bytes_acked);

        if let Some(sample_rtt) = rtt_us {
            self.update_rtt(sample_rtt);
        }

        self.base_delay_history.add_sample(now, delay_us);

        if self.current_delay_filter.len() >= Self::CURRENT_DELAY_FILTER_LEN {
            self.current_delay_filter.pop_front();
        }
        self.current_delay_filter.push_back(delay_us);

        // NOTE(review): the raw sample, not a filtered value from
        // `current_delay_filter`, is used as the current delay.
        let base_delay = self.base_delay_history.min().unwrap_or(delay_us);
        let queuing_delay = delay_us.saturating_sub(base_delay);
        self.adjust_window(bytes_acked, queuing_delay);
    }

    /// Grows or shrinks the window for one ACK and clamps the result to
    /// `[MIN_CWND, MAX_CWND]`.
    ///
    /// In slow start the window grows by `bytes_acked` until the queuing delay
    /// reaches the target, at which point slow start ends. Afterwards the
    /// window moves by `GAIN * off_target * bytes_acked * MSS / cwnd`, where
    /// `off_target` is the normalised distance from the target delay.
    fn adjust_window(&mut self, bytes_acked: u32, queuing_delay_us: u32) {
        if self.in_slow_start {
            if queuing_delay_us < TARGET_DELAY_US {
                // Saturate so that an oversized ACK cannot overflow before the clamp.
                self.cwnd = self.cwnd.saturating_add(bytes_acked);
            } else {
                self.in_slow_start = false;
                self.ssthresh = self.cwnd;
            }
        } else {
            let target = f64::from(TARGET_DELAY_US);
            // Positive below the target, negative above it. Every `u32` is
            // exactly representable in `f64`, so the subtraction is exact.
            let off_target = (target - f64::from(queuing_delay_us)) / target;
            // The float-to-int cast truncates toward zero and saturates at the
            // `i32` bounds.
            let cwnd_delta = (GAIN * off_target * f64::from(bytes_acked) * f64::from(MSS)
                / f64::from(self.cwnd)) as i32;
            // `unsigned_abs` is well defined for `i32::MIN`, unlike negation.
            let magnitude = cwnd_delta.unsigned_abs();
            self.cwnd = if cwnd_delta >= 0 {
                self.cwnd.saturating_add(magnitude)
            } else {
                self.cwnd.saturating_sub(magnitude)
            };
        }
        self.cwnd = self.cwnd.clamp(MIN_CWND, MAX_CWND);
    }

    /// Folds one RTT sample (microseconds) into the smoothed RTT and its
    /// variation, then recomputes the RTO as `rtt + 4 * rtt_var`, never below
    /// the minimum RTO.
    ///
    /// The first sample initialises the estimate directly; later samples are
    /// blended with weights 1/8 (RTT) and 1/4 (variation).
    fn update_rtt(&mut self, sample_rtt_us: u32) {
        if self.has_rtt_sample {
            // The variation is updated against the previous smoothed RTT.
            let diff = sample_rtt_us.abs_diff(self.rtt_us);
            // Widen to u64 so the weighting multiplications cannot overflow.
            let rtt_var = u64::from(self.rtt_var_us) * 3 / 4 + u64::from(diff) / 4;
            let rtt = u64::from(self.rtt_us) * 7 / 8 + u64::from(sample_rtt_us) / 8;
            self.rtt_var_us = Self::saturate_to_u32(rtt_var);
            self.rtt_us = Self::saturate_to_u32(rtt);
        } else {
            // An explicit flag, not a sentinel RTT value, marks the first
            // sample: a genuine estimate may equal the initial RTT.
            self.rtt_us = sample_rtt_us;
            self.rtt_var_us = sample_rtt_us / 2;
            self.has_rtt_sample = true;
        }

        let rto = u64::from(self.rtt_us) + 4 * u64::from(self.rtt_var_us);
        self.rto_us = Self::saturate_to_u32(rto).max(Self::MIN_RTO_US);
    }

    /// Converts to `u32`, saturating at `u32::MAX`.
    fn saturate_to_u32(value: u64) -> u32 {
        // The preceding `min` guarantees the cast is lossless.
        value.min(u64::from(u32::MAX)) as u32
    }

    /// Reacts to a detected loss: records half the window (at least
    /// `MIN_CWND`) as `ssthresh`, collapses the window to `MIN_CWND`, and
    /// leaves slow start.
    pub fn on_loss(&mut self) {
        self.ssthresh = (self.cwnd / 2).max(MIN_CWND);
        self.cwnd = MIN_CWND;
        self.in_slow_start = false;
    }

    /// Reacts to a retransmission timeout: applies the loss response and
    /// doubles the RTO, capped at the maximum RTO.
    pub fn on_timeout(&mut self) {
        self.on_loss();
        self.rto_us = self.rto_us.saturating_mul(2).min(Self::MAX_RTO_US);
    }

    /// Restores the controller to its freshly constructed state.
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new controller starts at the initial window, may send, and is in
    /// slow start.
    #[test]
    fn test_initial_state() {
        let controller = LedbatController::new();

        assert_eq!(controller.cwnd(), INITIAL_CWND);
        assert!(controller.can_send());
        assert!(controller.in_slow_start);
    }

    /// An ACK with a low delay grows the window and keeps slow start active.
    #[test]
    fn test_slow_start() {
        let mut controller = LedbatController::new();

        controller.on_ack(1000, 10_000, Some(50_000));

        assert!(controller.cwnd() > INITIAL_CWND);
        assert!(controller.in_slow_start);
    }

    /// A delay sample that puts the queuing delay above the target ends slow
    /// start.
    #[test]
    fn test_exit_slow_start() {
        let mut controller = LedbatController::new();

        // First sample establishes a low base delay.
        controller.on_ack(1000, 10_000, Some(50_000));
        assert!(controller.in_slow_start);

        // Second sample is well above base + target.
        controller.on_ack(1000, TARGET_DELAY_US + 20_000, Some(200_000));
        assert!(!controller.in_slow_start);
    }

    /// Loss halves the window into `ssthresh`, collapses the window to the
    /// minimum, and leaves slow start.
    #[test]
    fn test_on_loss() {
        let mut controller = LedbatController::new();
        controller.cwnd = 100_000;

        controller.on_loss();

        assert_eq!(controller.ssthresh, 50_000);
        assert_eq!(controller.cwnd, MIN_CWND);
        assert!(!controller.in_slow_start);
    }

    /// The window never leaves `[MIN_CWND, MAX_CWND]`.
    #[test]
    fn test_window_bounds() {
        let mut controller = LedbatController::new();
        controller.cwnd = MAX_CWND;

        controller.on_ack(10_000, 1000, Some(10_000));
        assert!(controller.cwnd() <= MAX_CWND);

        controller.on_loss();
        assert!(controller.cwnd() >= MIN_CWND);
    }

    /// The first RTT sample is adopted directly; a later one is blended in.
    #[test]
    fn test_rtt_update() {
        let mut controller = LedbatController::new();

        controller.update_rtt(100_000);
        assert_eq!(controller.rtt_us, 100_000);

        controller.update_rtt(120_000);
        assert!(controller.rtt_us > 100_000 && controller.rtt_us < 120_000);
    }
}