/// Initial congestion window, expressed as a multiple of the MSS. The same
/// value is the floor the window is reset or clamped to on a timeout or drop.
const INITIAL_CWND_MSS: u64 = 10;
/// Growth phase of the congestion window; selects how `handle_ack` enlarges it.
enum Mode {
/// The window grows by one MSS per ack with no threshold to stop at.
UnboundedSlowStart,
/// The window grows by one MSS per ack until it reaches the contained
/// threshold (same Q16 scale as `cwnd_q16`); it is then set to that threshold
/// and the mode becomes `CongestionAvoidance`.
ThresholdSlowStart(u64),
/// The window grows by roughly `mss * mss / cwnd` per ack.
CongestionAvoidance,
}
/// Congestion window state driven by ack, drop and timeout events.
///
/// Sizes are stored as `u64` values shifted left by 16 bits (Q16 fixed point),
/// so fractional window growth accumulates across acks; `cwnd` shifts the
/// value back down before returning it.
pub struct AimdReno {
/// Current growth phase of the window.
mode: Mode,
/// The MSS passed to `new`, shifted left by 16 bits.
mss_q16: u64,
/// Current congestion window in Q16.
cwnd_q16: u64,
/// Set to the reduced window when a drop is acted on and lowered by one MSS
/// per ack; drops reported while it is nonzero are ignored.
dropcnt_q16: u64,
}
impl AimdReno {
    /// Creates a controller in unbounded slow start with a window of
    /// `INITIAL_CWND_MSS * mss`.
    ///
    /// # Panics
    ///
    /// Panics if `mss` is greater than `u16::MAX`.
    pub fn new(mss: usize) -> Self {
        assert!(mss <= u16::MAX.into());
        // Lossless: the assertion above bounds `mss` to 16 bits.
        let mss_q16 = (mss as u64) << 16;
        Self {
            mode: Mode::UnboundedSlowStart,
            mss_q16,
            cwnd_q16: INITIAL_CWND_MSS * mss_q16,
            dropcnt_q16: 0,
        }
    }

    /// Processes one acknowledgement.
    ///
    /// Lowers the drop counter by one MSS (stopping at zero) and then grows
    /// the window according to the current mode:
    ///
    /// * both slow-start modes add one MSS; the threshold variant clamps the
    ///   window to its threshold and switches to congestion avoidance once the
    ///   threshold is reached;
    /// * congestion avoidance adds `mss * mss / cwnd`, rounded to nearest.
    ///
    /// All window additions saturate instead of overflowing.
    pub fn handle_ack(&mut self) {
        self.dropcnt_q16 = self.dropcnt_q16.saturating_sub(self.mss_q16);

        match self.mode {
            Mode::UnboundedSlowStart => {
                self.cwnd_q16 = self.cwnd_q16.saturating_add(self.mss_q16);
            }
            Mode::ThresholdSlowStart(ssthresh_q16) => {
                let grown_q16 = self.cwnd_q16.saturating_add(self.mss_q16);
                if grown_q16 >= ssthresh_q16 {
                    self.cwnd_q16 = ssthresh_q16;
                    self.mode = Mode::CongestionAvoidance;
                } else {
                    self.cwnd_q16 = grown_q16;
                }
            }
            Mode::CongestionAvoidance => {
                let cwnd_q16 = self.cwnd_q16;
                // A zero window (possible when the MSS is zero) has nothing to
                // grow by; skipping it also avoids dividing by zero.
                if cwnd_q16 != 0 {
                    // Q16 * Q16 yields Q32, and dividing by a Q16 value brings
                    // the result back to Q16. The product plus the rounding
                    // term can exceed `u64`, so the intermediate is `u128`.
                    let mss = u128::from(self.mss_q16);
                    let cwnd = u128::from(cwnd_q16);
                    let increment = (mss * mss + cwnd / 2) / cwnd;
                    // Clamp before narrowing so the cast cannot truncate.
                    let increment_q16 = increment.min(u128::from(u64::MAX)) as u64;
                    self.cwnd_q16 = cwnd_q16.saturating_add(increment_q16);
                }
            }
        }
    }

    /// Processes a retransmission timeout.
    ///
    /// Resets the window to `INITIAL_CWND_MSS * mss`, clears the drop counter
    /// and enters threshold slow start with a threshold of half the previous
    /// window, but never less than the reset window.
    pub fn handle_timeout(&mut self) {
        let floor_q16 = INITIAL_CWND_MSS * self.mss_q16;
        let ssthresh_q16 = (self.cwnd_q16 / 2).max(floor_q16);

        self.cwnd_q16 = floor_q16;
        self.dropcnt_q16 = 0;
        self.mode = Mode::ThresholdSlowStart(ssthresh_q16);
    }

    /// Processes a reported drop.
    ///
    /// If the drop counter is zero, halves the window (never below
    /// `INITIAL_CWND_MSS * mss`), sets the drop counter to the new window and
    /// enters congestion avoidance. Drops reported while the counter is
    /// nonzero leave the state untouched.
    pub fn handle_drop(&mut self) {
        if self.dropcnt_q16 != 0 {
            return;
        }

        let floor_q16 = INITIAL_CWND_MSS * self.mss_q16;
        let halved_q16 = (self.cwnd_q16 / 2).max(floor_q16);

        self.cwnd_q16 = halved_q16;
        self.dropcnt_q16 = halved_q16;
        self.mode = Mode::CongestionAvoidance;
    }

    /// Returns the integer part of the congestion window, or `usize::MAX` if
    /// that value does not fit in `usize`.
    pub fn cwnd(&self) -> usize {
        (self.cwnd_q16 >> 16).try_into().unwrap_or(usize::MAX)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics when `cwnd` differs from `expected` by more than 1.0.
    fn assert_cwnd_near(cwnd: usize, expected: f64) {
        let deviation = (cwnd as f64 - expected).abs();
        if deviation > 1.0 {
            panic!("expected cwnd near {}, found {}", expected, cwnd);
        }
    }

    /// Walks the controller through slow start, a drop, an ignored second
    /// drop, congestion avoidance and a final drop, comparing the window with
    /// a floating-point model after every step.
    #[test]
    fn basic() {
        const MSS: usize = 100;
        const ROUNDS: usize = 30;
        let mss_squared = (MSS * MSS) as f64;

        let mut reno = AimdReno::new(MSS);
        let mut model = 1000.0;
        assert_cwnd_near(reno.cwnd(), model);

        // Slow start: each ack adds one MSS.
        for _ in 0..ROUNDS {
            reno.handle_ack();
            model += 100.0;
            assert_cwnd_near(reno.cwnd(), model);
        }

        // The first drop halves the window.
        reno.handle_drop();
        model /= 2.0;
        assert_cwnd_near(reno.cwnd(), model);

        // A second drop right after the first leaves the window unchanged.
        reno.handle_drop();
        assert_cwnd_near(reno.cwnd(), model);

        // Congestion avoidance: each ack adds mss^2 / cwnd.
        for _ in 0..ROUNDS {
            reno.handle_ack();
            model += mss_squared / model;
            assert_cwnd_near(reno.cwnd(), model);
        }

        // After those acks a drop halves the window again.
        reno.handle_drop();
        model /= 2.0;
        assert_cwnd_near(reno.cwnd(), model);
    }
}