#![allow(dead_code)]
use core::time::Duration;
use crate::quic::loss::SentPacket;
use crate::quic::pn::PnSpaceId;
/// Size of the initial congestion window, counted in maximum-size datagrams.
pub(crate) const K_INITIAL_WINDOW_PACKETS: u64 = 10;
/// Numerator of the multiplicative decrease applied to the window on a congestion event.
pub(crate) const K_LOSS_REDUCTION_FACTOR_NUM: u64 = 1;
/// Denominator of the multiplicative decrease (`NUM / DEN` = one half).
pub(crate) const K_LOSS_REDUCTION_FACTOR_DEN: u64 = 2;
/// Maximum datagram size, in bytes, that a freshly constructed controller starts with.
pub(crate) const K_DEFAULT_MAX_DATAGRAM_SIZE: u64 = 1_200;
/// Lower bound the controller applies to its congestion window: two datagrams' worth.
///
/// Takes and returns a size in bytes. Usable in `const` contexts.
#[inline]
pub(crate) const fn k_minimum_window(max_datagram_size: u64) -> u64 {
    max_datagram_size * 2
}
/// NewReno congestion-controller state.
///
/// Window and byte quantities are in bytes. Time values are `Duration`s, presumably
/// offsets from a connection-local epoch — NOTE(review): confirm with the callers.
#[derive(Debug)]
pub(crate) struct NewReno {
/// Maximum datagram size in bytes; scales the minimum window and the
/// congestion-avoidance increment.
pub(crate) max_datagram_size: u64,
/// Congestion window: sending is allowed while `bytes_in_flight` is below it.
pub(crate) cwnd: u64,
/// Slow-start threshold. While `cwnd` is below it, each ack grows `cwnd` by the acked
/// bytes; at or above it, by `max_datagram_size * acked_bytes / cwnd`.
/// `NewReno::new` initialises it to `u64::MAX`.
pub(crate) ssthresh: u64,
/// Bytes reported through `on_packet_sent` that have not yet been removed by an ack or a loss.
pub(crate) bytes_in_flight: u64,
/// Marks the current recovery period: packets sent at or before this time belong to it.
/// `None` when no recovery period is active.
pub(crate) recovery_start_time: Option<Duration>,
/// Last recorded ECN-CE count per packet number space, indexed by `PnSpaceId as usize`
/// (assumes `PnSpaceId` has three variants with discriminants 0..=2 — TODO confirm).
pub(crate) ecn_ce_counters: [u64; 3],
}
impl Default for NewReno {
fn default() -> Self {
Self::new()
}
}
impl NewReno {
    /// Creates a controller in slow start using the default datagram size.
    ///
    /// `cwnd` starts at `K_INITIAL_WINDOW_PACKETS` datagrams and `ssthresh` is unbounded,
    /// so every ack grows the window by the acked bytes until the first congestion event.
    pub(crate) fn new() -> Self {
        let mds = K_DEFAULT_MAX_DATAGRAM_SIZE;
        Self {
            max_datagram_size: mds,
            cwnd: K_INITIAL_WINDOW_PACKETS * mds,
            ssthresh: u64::MAX,
            bytes_in_flight: 0,
            recovery_start_time: None,
            ecn_ce_counters: [0; 3],
        }
    }

    /// Adds `bytes` of a newly sent packet to `bytes_in_flight`.
    ///
    /// Callers should only report packets that count as in flight.
    pub(crate) fn on_packet_sent(&mut self, bytes: u64) {
        self.bytes_in_flight = self.bytes_in_flight.saturating_add(bytes);
    }

    /// Processes newly acknowledged packets.
    ///
    /// Each in-flight packet is removed from `bytes_in_flight`. If the packet was sent
    /// outside the current recovery period, `cwnd` grows as well. Below `ssthresh`
    /// (slow start) it grows by the acked bytes; otherwise (congestion avoidance) by
    /// `max_datagram_size * acked_bytes / cwnd`, at least 1 byte.
    ///
    /// Packets that are not in flight violate the caller contract. Debug builds assert on
    /// them; release builds skip them because they never entered `bytes_in_flight`.
    pub(crate) fn on_packets_acked(&mut self, acked: &[SentPacket]) {
        for p in acked {
            debug_assert!(p.in_flight, "caller must filter to in-flight packets");
            if !p.in_flight {
                continue;
            }
            let bytes = u64::from(p.sent_bytes);
            self.bytes_in_flight = self.bytes_in_flight.saturating_sub(bytes);
            // Acks for packets from the flight that caused the last reduction must not
            // re-inflate the window.
            if self.in_congestion_recovery(p.time_sent) {
                continue;
            }
            if self.cwnd < self.ssthresh {
                self.cwnd = self.cwnd.saturating_add(bytes);
            } else {
                let inc = self.max_datagram_size.saturating_mul(bytes) / self.cwnd.max(1);
                // Integer division can round a small ack down to 0; keep growth monotonic.
                self.cwnd = self.cwnd.saturating_add(inc.max(1));
            }
        }
    }

    /// Processes packets declared lost at time `now`.
    ///
    /// In-flight packets are removed from `bytes_in_flight`. If at least one in-flight
    /// packet was lost, a congestion event is raised, keyed on the latest send time among
    /// them. Packets that are not in flight are ignored: they never contributed to
    /// `bytes_in_flight` and do not signal congestion.
    pub(crate) fn on_packets_lost(&mut self, lost: &[SentPacket], now: Duration) {
        let mut sent_time_of_last_loss: Option<Duration> = None;
        for p in lost.iter().filter(|p| p.in_flight) {
            self.bytes_in_flight = self
                .bytes_in_flight
                .saturating_sub(u64::from(p.sent_bytes));
            // `None` orders below every `Some`, so this keeps the latest send time seen.
            sent_time_of_last_loss = sent_time_of_last_loss.max(Some(p.time_sent));
        }
        if let Some(sent_time) = sent_time_of_last_loss {
            self.on_new_congestion_event(sent_time, now);
        }
    }

    /// Starts a new recovery period unless the packet sent at `sent_time` already belongs to
    /// the current one. Starting a period halves the window (clamped to the minimum window)
    /// and records that value as the new `ssthresh`.
    ///
    /// The period is anchored at `now`, the moment the event is detected, not at
    /// `sent_time`. Every packet already sent when the loss is noticed belongs to the same
    /// flight, and its ack or loss must neither grow the window nor trigger a second
    /// reduction. This matches `congestion_recovery_start_time = now()` in RFC 9002's
    /// `OnCongestionEvent`.
    fn on_new_congestion_event(&mut self, sent_time: Duration, now: Duration) {
        if self.in_congestion_recovery(sent_time) {
            return;
        }
        self.recovery_start_time = Some(now);
        let new_ssthresh =
            self.cwnd.saturating_mul(K_LOSS_REDUCTION_FACTOR_NUM) / K_LOSS_REDUCTION_FACTOR_DEN;
        self.ssthresh = new_ssthresh;
        self.cwnd = new_ssthresh.max(k_minimum_window(self.max_datagram_size));
    }

    /// Collapses `cwnd` to the minimum window and leaves any recovery period.
    ///
    /// `ssthresh` is left unchanged.
    pub(crate) fn on_persistent_congestion(&mut self) {
        self.cwnd = k_minimum_window(self.max_datagram_size);
        self.recovery_start_time = None;
    }

    /// Returns `true` if a packet sent at `time_sent` falls inside the current recovery
    /// period, i.e. it was sent at or before the period's start.
    pub(crate) fn in_congestion_recovery(&self, time_sent: Duration) -> bool {
        match self.recovery_start_time {
            Some(t) => time_sent <= t,
            None => false,
        }
    }

    /// Returns `true` while `bytes_in_flight` is strictly below `cwnd`.
    pub(crate) fn can_send(&self) -> bool {
        self.bytes_in_flight < self.cwnd
    }

    /// Records `new_count` as the latest ECN-CE count for packet number space `space`.
    ///
    /// NOTE(review): this only stores the value. It does not check that the count actually
    /// increased, and it does not start a congestion event. RFC 9002's `ProcessECN` reacts
    /// to a CE increase with a congestion event keyed on the send time of the largest
    /// acknowledged packet; this method has no send time, so confirm the caller does that.
    pub(crate) fn on_ecn_ce_increase(&mut self, space: PnSpaceId, new_count: u64) {
        self.ecn_ce_counters[space as usize] = new_count;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;

    /// Builds an ack-eliciting packet with the given number, send time, size and in-flight flag.
    fn mk(pn: u64, time_sent: Duration, bytes: u16, in_flight: bool) -> SentPacket {
        SentPacket {
            pn,
            sent_bytes: bytes,
            ack_eliciting: true,
            in_flight,
            time_sent,
            retransmit_hint: Vec::new(),
        }
    }

    #[test]
    fn fresh_state() {
        let c = NewReno::new();
        assert_eq!(c.cwnd, 10 * 1200);
        assert_eq!(c.ssthresh, u64::MAX);
        assert_eq!(c.bytes_in_flight, 0);
        assert!(c.recovery_start_time.is_none());
        assert!(c.can_send());
    }

    #[test]
    fn slow_start_then_avoidance() {
        let mut c = NewReno::new();
        c.ssthresh = 20 * 1200;
        let start_cwnd = c.cwnd;
        let ack = alloc::vec![mk(0, Duration::ZERO, 1200, true)];
        c.on_packet_sent(1200);
        c.on_packets_acked(&ack);
        assert_eq!(c.cwnd, start_cwnd + 1200);
        for i in 1..15 {
            let a = alloc::vec![mk(i, Duration::ZERO, 1200, true)];
            c.on_packet_sent(1200);
            c.on_packets_acked(&a);
        }
        assert!(c.cwnd >= c.ssthresh, "cwnd={} ss={}", c.cwnd, c.ssthresh);
        let before = c.cwnd;
        let a = alloc::vec![mk(99, Duration::ZERO, 1200, true)];
        c.on_packet_sent(1200);
        c.on_packets_acked(&a);
        let after = c.cwnd;
        let delta = after - before;
        assert!(delta < 1200, "avoidance delta {delta}");
        assert!(delta >= 1, "monotonic growth");
    }

    #[test]
    fn ack_only_packet_not_in_flight() {
        // Only in-flight packets are reported through `on_packet_sent`, so an ACK-only
        // packet must never show up in `bytes_in_flight`.
        let mut c = NewReno::new();
        let packets = [
            mk(0, Duration::ZERO, 1200, true),
            mk(1, Duration::ZERO, 100, false),
        ];
        for p in packets.iter().filter(|p| p.in_flight) {
            c.on_packet_sent(u64::from(p.sent_bytes));
        }
        assert_eq!(c.bytes_in_flight, 1200);
        c.on_packets_acked(&packets[..1]);
        assert_eq!(c.bytes_in_flight, 0);
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "caller must filter to in-flight packets")]
    fn acking_ack_only_packet_violates_contract_in_debug() {
        let mut c = NewReno::new();
        c.on_packets_acked(&[mk(0, Duration::ZERO, 100, false)]);
    }

    #[test]
    fn spurious_loss_recovery_does_not_re_enter_cwnd() {
        let mut c = NewReno::new();
        let sent_time = Duration::from_millis(0);
        let p = mk(0, sent_time, 1200, true);
        c.on_packet_sent(1200);
        c.on_packets_lost(core::slice::from_ref(&p), Duration::from_millis(100));
        let post_loss_cwnd = c.cwnd;
        assert!(post_loss_cwnd < 10 * 1200, "cwnd should halve");
        assert!(c.recovery_start_time.is_some());
        c.on_packets_acked(&[p]);
        assert_eq!(c.cwnd, post_loss_cwnd);
    }

    #[test]
    fn on_persistent_congestion_resets_cwnd() {
        let mut c = NewReno::new();
        c.cwnd = 50 * 1200;
        c.ssthresh = 20 * 1200;
        c.on_persistent_congestion();
        assert_eq!(c.cwnd, k_minimum_window(c.max_datagram_size));
        assert_eq!(c.ssthresh, 20 * 1200, "ssthresh preserved");
        assert!(c.recovery_start_time.is_none());
    }

    #[test]
    fn loss_during_existing_recovery_does_not_halve_twice() {
        let mut c = NewReno::new();
        c.on_packet_sent(1200);
        c.on_packets_lost(
            &[mk(0, Duration::from_millis(100), 1200, true)],
            Duration::from_millis(150),
        );
        let cwnd_after_first = c.cwnd;
        c.on_packets_lost(
            &[mk(1, Duration::from_millis(50), 1200, true)],
            Duration::from_millis(200),
        );
        assert_eq!(c.cwnd, cwnd_after_first);
    }

    #[test]
    fn bytes_in_flight_tracking() {
        let mut c = NewReno::new();
        c.on_packet_sent(1200);
        c.on_packet_sent(800);
        assert_eq!(c.bytes_in_flight, 2000);
        c.on_packets_acked(&[mk(0, Duration::ZERO, 1200, true)]);
        assert_eq!(c.bytes_in_flight, 800);
    }
}