use crate::time::Timestamp;
use core::time::Duration;
#[cfg(not(feature = "std"))]
use num_traits::Float as _;
/// Decides when to leave slow start by watching for growth in the per-round
/// minimum RTT (HyStart). When HyStart++ is enabled, a detected RTT increase
/// first moves the sender into Conservative Slow Start (CSS), and slow start
/// only ends once CSS has run for several rounds.
#[derive(Debug, Clone)]
pub struct HybridSlowStart {
    /// Number of RTT samples received so far in the current round.
    sample_count: usize,
    /// Minimum RTT of the previous round (`None` until one round has completed).
    last_min_rtt: Option<Duration>,
    /// Minimum RTT over the first `N_SAMPLING` samples of the current round.
    cur_min_rtt: Option<Duration>,
    /// Slow start threshold, compared directly against the congestion window;
    /// `f32::MAX` while no threshold has been found.
    pub(super) threshold: f32,
    /// Maximum datagram size, used to derive the lowest allowed threshold.
    max_datagram_size: u16,
    /// A sample for a packet sent at or after this time starts a new round.
    rtt_round_end_time: Option<Timestamp>,
    /// Whether HyStart++ (Conservative Slow Start) is enabled.
    use_hystart_plus_plus: bool,
    /// Divisor applied to slow start window growth: 1.0 normally, larger during CSS.
    ss_growth_divisor: f32,
    /// Number of completed CSS rounds.
    css_count: usize,
    /// Minimum RTT of the round that triggered entry into CSS.
    css_baseline_min_rtt: Duration,
    /// Congestion window at which CSS was entered; `f32::MAX` when not in CSS.
    css_threshold: f32,
}
// --- Per-round RTT sampling -------------------------------------------------

/// Number of RTT samples per round that contribute to the round's minimum RTT;
/// the slow start exit checks run when a round's sample count reaches this value.
const N_SAMPLING: usize = 8;

// --- Delay increase threshold ----------------------------------------------

/// Value the previous round's minimum RTT is divided by to derive the delay
/// increase threshold. NOTE: despite its name, this is used as the divisor.
const THRESHOLD_DIVIDEND: u32 = 8;
/// Lower bound applied to the delay increase threshold.
const MIN_DELAY_THRESHOLD: Duration = Duration::from_millis(4);
/// Upper bound applied to the delay increase threshold.
const MAX_DELAY_THRESHOLD: Duration = Duration::from_millis(16);

// --- Threshold floor -------------------------------------------------------

/// Minimum slow start threshold, expressed in multiples of the max datagram size.
const LOW_SSTHRESH: f32 = 16.0;

// --- Conservative Slow Start (HyStart++) -----------------------------------

/// Divisor applied to congestion window growth while in Conservative Slow Start.
const CSS_GROWTH_DIVISOR: f32 = 4.0;
/// Number of Conservative Slow Start rounds after which slow start ends.
const CSS_ROUNDS: usize = 5;

/// Environment variable which, when set, enables HyStart++.
#[cfg(feature = "std")]
const USE_HYSTART_PLUS_PLUS: &str = "S2N_UNSTABLE_USE_HYSTART_PP";
impl HybridSlowStart {
    /// Creates a detector that has not yet found a slow start threshold.
    ///
    /// `max_datagram_size` determines the lowest threshold that may be set (see
    /// `low_ssthresh`). Whether HyStart++ is used is decided by
    /// `use_hystart_parameter`.
    pub fn new(max_datagram_size: u16) -> Self {
        Self {
            sample_count: 0,
            last_min_rtt: None,
            cur_min_rtt: None,
            threshold: f32::MAX,
            max_datagram_size,
            rtt_round_end_time: None,
            use_hystart_plus_plus: Self::use_hystart_parameter(),
            ss_growth_divisor: 1.0,
            css_count: 0,
            css_baseline_min_rtt: Duration::ZERO,
            css_threshold: f32::MAX,
        }
    }

    /// Records an RTT sample and, once per round, checks whether slow start should end.
    ///
    /// * `congestion_window` - the current congestion window, compared against the
    ///   slow start threshold and the CSS entry window.
    /// * `time_sent` - send time of the sampled packet (presumably the packet whose
    ///   acknowledgement produced `rtt`); a value at or after the end of the current
    ///   round starts a new round.
    /// * `time_of_last_sent_packet` - recorded as the end of the round whenever a new
    ///   round starts.
    /// * `rtt` - the RTT sample.
    ///
    /// Only the first `N_SAMPLING` samples of a round contribute to its minimum RTT,
    /// and the exit checks run on exactly the `N_SAMPLING`-th sample, provided a
    /// previous round's minimum exists. Without HyStart++, a large enough minimum RTT
    /// increase sets `threshold` to `congestion_window` (if it is at least
    /// `low_ssthresh()`). With HyStart++, the increase instead enters Conservative
    /// Slow Start (CSS); CSS returns to slow start if the RTT drops below the CSS
    /// baseline, and sets `threshold` after `CSS_ROUNDS` rounds.
    pub fn on_rtt_update(
        &mut self,
        congestion_window: f32,
        time_sent: Timestamp,
        time_of_last_sent_packet: Timestamp,
        rtt: Duration,
    ) {
        // Slow start has already ended. With HyStart++, any known threshold (found
        // here or imposed by a congestion event) stops further detection.
        let ss_threshold_found = self.threshold < f32::MAX;
        if congestion_window >= self.threshold
            || (self.use_hystart_plus_plus && ss_threshold_found)
        {
            return;
        }

        let rtt_round_is_over = self
            .rtt_round_end_time
            .is_none_or(|end_time| time_sent >= end_time);
        if rtt_round_is_over {
            self.last_min_rtt = self.cur_min_rtt;
            self.cur_min_rtt = None;
            self.sample_count = 0;
            self.rtt_round_end_time = Some(time_of_last_sent_packet);
        }

        // Only the first N_SAMPLING samples of a round contribute to its minimum.
        if self.sample_count < N_SAMPLING {
            self.cur_min_rtt = Some(rtt.min(self.cur_min_rtt.unwrap_or(rtt)));
        }
        // Saturate so an unusually long round can never overflow the counter.
        self.sample_count = self.sample_count.saturating_add(1);

        // Run the exit checks exactly once per round, on the N_SAMPLING-th sample,
        // and only when the previous round's minimum is available for comparison.
        if let (N_SAMPLING, Some(last_min_rtt), Some(cur_min_rtt)) =
            (self.sample_count, self.last_min_rtt, self.cur_min_rtt)
        {
            if congestion_window >= self.css_threshold {
                self.on_css_round_end(congestion_window, cur_min_rtt);
            } else {
                self.on_slow_start_round_end(congestion_window, last_min_rtt, cur_min_rtt);
            }
        }
    }

    /// Evaluates a completed Conservative Slow Start round (HyStart++ only).
    fn on_css_round_end(&mut self, congestion_window: f32, cur_min_rtt: Duration) {
        self.css_count += 1;

        if cur_min_rtt < self.css_baseline_min_rtt {
            // The RTT increase that triggered CSS did not persist: leave CSS and
            // resume regular slow start.
            self.css_threshold = f32::MAX;
            self.ss_growth_divisor = 1.0;
            self.css_count = 0;
        }

        if self.css_count >= CSS_ROUNDS {
            // CSS lasted long enough: end slow start at the current window.
            self.threshold = congestion_window;
            self.css_threshold = f32::MAX;
            self.ss_growth_divisor = 1.0;
        }
    }

    /// Evaluates a completed regular slow start round, checking whether the minimum
    /// RTT grew by more than the delay threshold relative to the previous round.
    fn on_slow_start_round_end(
        &mut self,
        congestion_window: f32,
        last_min_rtt: Duration,
        cur_min_rtt: Duration,
    ) {
        // last_min_rtt / THRESHOLD_DIVIDEND, bounded to [MIN, MAX]_DELAY_THRESHOLD.
        let delay_threshold =
            (last_min_rtt / THRESHOLD_DIVIDEND).clamp(MIN_DELAY_THRESHOLD, MAX_DELAY_THRESHOLD);
        if cur_min_rtt < last_min_rtt + delay_threshold {
            return;
        }

        if self.use_hystart_plus_plus {
            // Enter CSS: remember where it started and slow down window growth.
            self.css_threshold = congestion_window;
            self.css_baseline_min_rtt = cur_min_rtt;
            self.ss_growth_divisor = CSS_GROWTH_DIVISOR;
            self.css_count = 0;
        } else if congestion_window >= self.low_ssthresh() {
            self.threshold = congestion_window;
        }
    }

    /// Returns the slow start window increase for `sent_bytes`, i.e. `sent_bytes`
    /// divided by the current growth divisor (1.0 outside CSS, `CSS_GROWTH_DIVISOR`
    /// while in CSS).
    pub fn cwnd_increment(&self, sent_bytes: usize) -> f32 {
        // Without HyStart++ the divisor never leaves its initial value of 1.0.
        debug_assert!(
            self.use_hystart_plus_plus || (self.ss_growth_divisor - 1.0).abs() < f32::EPSILON
        );
        (sent_bytes as f32) / self.ss_growth_divisor
    }

    /// Lowers the slow start threshold to `ssthresh` on a congestion event. The threshold
    /// is never raised and never drops below `low_ssthresh()`. This also leaves CSS.
    pub fn on_congestion_event(&mut self, ssthresh: f32) {
        self.threshold = self.threshold.min(ssthresh).max(self.low_ssthresh());
        self.ss_growth_divisor = 1.0;
        self.css_threshold = f32::MAX;
    }

    /// Lowest allowed slow start threshold: `LOW_SSTHRESH` datagrams of
    /// `max_datagram_size`.
    fn low_ssthresh(&self) -> f32 {
        LOW_SSTHRESH * self.max_datagram_size as f32
    }

    /// Returns whether the HyStart++ environment variable is set (to any valid
    /// Unicode value). The lookup is performed once and cached for the process.
    #[cfg(feature = "std")]
    fn use_hystart_parameter() -> bool {
        use once_cell::sync::OnceCell;
        static USE_HYSTART_PP: OnceCell<bool> = OnceCell::new();
        *USE_HYSTART_PP.get_or_init(|| std::env::var(USE_HYSTART_PLUS_PLUS).is_ok())
    }

    /// HyStart++ is never enabled without `std`, since there is no environment to query.
    #[cfg(not(feature = "std"))]
    fn use_hystart_parameter() -> bool {
        false
    }
}
// Unit tests for `HybridSlowStart`. They drive `on_rtt_update` through hand-built
// sequences of RTT rounds and inspect the detector's private state directly.
#[cfg(test)]
mod test {
    use crate::{
        assert_delta,
        recovery::hybrid_slow_start::HybridSlowStart,
        time::{Clock, NoopClock},
    };
    use core::time::Duration;

    // A congestion event lowers the threshold to `ssthresh`, never raises it, and
    // never lets it drop below `low_ssthresh()`.
    #[test]
    fn on_congestion_event() {
        let mut slow_start = HybridSlowStart::new(10);
        slow_start.threshold = 501.0;
        slow_start.on_congestion_event(500.0);
        assert_delta!(slow_start.threshold, 500.0, 0.001);
        slow_start.threshold = 501.0;
        slow_start.on_congestion_event(502.0);
        assert_delta!(slow_start.threshold, 501.0, 0.001);
        slow_start.threshold = 501.0;
        slow_start.on_congestion_event(slow_start.low_ssthresh() - 1.0);
        assert_delta!(slow_start.threshold, slow_start.low_ssthresh(), 0.001);
    }

    // Samples are ignored entirely once the congestion window is at or above the
    // threshold: no sample is counted.
    #[test]
    fn on_rtt_update_above_threshold() {
        let mut slow_start = HybridSlowStart::new(10);
        let time_zero = NoopClock.get_time();
        slow_start.threshold = 500.0;
        assert_eq!(slow_start.sample_count, 0);
        slow_start.on_rtt_update(750.0, time_zero, time_zero, Duration::from_secs(1));
        assert_delta!(slow_start.threshold, 500.0, 0.001);
        assert_eq!(slow_start.sample_count, 0);
    }

    // Classic HyStart: round minimums are tracked over the first 8 samples of each
    // round, and a large enough increase sets the threshold.
    #[test]
    fn on_rtt_update() {
        let mut slow_start = HybridSlowStart::new(10);
        // `new` enables HyStart++ when S2N_UNSTABLE_USE_HYSTART_PP is set, which would
        // change the expected outcome below; pin the classic behavior explicitly.
        slow_start.use_hystart_plus_plus = false;
        assert_eq!(slow_start.sample_count, 0);
        let time_zero = NoopClock.get_time() + Duration::from_secs(10);
        // Round 1 ends with the packet sent at 9ms.
        let time_of_last_sent_packet = time_zero + Duration::from_millis(9);
        for i in 0..=6 {
            slow_start.on_rtt_update(
                1000.0,
                time_zero + Duration::from_millis(i),
                time_of_last_sent_packet,
                Duration::from_millis(200),
            );
            assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(200)));
        }
        assert_eq!(
            slow_start.rtt_round_end_time,
            Some(time_of_last_sent_packet)
        );
        // 8th sample of round 1 still lowers the minimum.
        slow_start.on_rtt_update(
            1000.0,
            time_zero + Duration::from_millis(7),
            time_of_last_sent_packet,
            Duration::from_millis(100),
        );
        assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(100)));
        // 9th sample is beyond N_SAMPLING and does not affect the minimum.
        slow_start.on_rtt_update(
            1000.0,
            time_zero + Duration::from_millis(8),
            time_of_last_sent_packet,
            Duration::from_millis(50),
        );
        assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(100)));
        // A packet sent at 9ms starts round 2, which ends at 29ms.
        let time_of_last_sent_packet = time_zero + Duration::from_millis(29);
        slow_start.on_rtt_update(
            1000.0,
            time_zero + Duration::from_millis(9),
            time_of_last_sent_packet,
            Duration::from_millis(400),
        );
        assert_eq!(slow_start.last_min_rtt, Some(Duration::from_millis(100)));
        assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(400)));
        assert_eq!(
            slow_start.rtt_round_end_time,
            Some(time_of_last_sent_packet)
        );
        for i in 20..=25 {
            slow_start.on_rtt_update(
                1000.0,
                time_zero + Duration::from_millis(i),
                time_of_last_sent_packet,
                Duration::from_millis(500),
            );
            assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(400)));
        }
        // 8th sample of round 2: 112ms < 100ms + 12.5ms, so no exit yet.
        slow_start.on_rtt_update(
            2000.0,
            time_zero + Duration::from_millis(27),
            time_of_last_sent_packet,
            Duration::from_millis(112),
        );
        assert_eq!(slow_start.last_min_rtt, Some(Duration::from_millis(100)));
        assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(112)));
        assert_delta!(slow_start.threshold, f32::MAX, 0.001);
        // Round 3 starts at 40ms and ends at 49ms.
        let time_of_last_sent_packet = time_zero + Duration::from_millis(49);
        for i in 40..=46 {
            slow_start.on_rtt_update(
                1000.0,
                time_zero + Duration::from_millis(i),
                time_of_last_sent_packet,
                Duration::from_millis(500),
            );
            assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(500)));
        }
        // 8th sample of round 3: 126ms >= 112ms + 14ms, so the threshold is set.
        slow_start.on_rtt_update(
            5000.0,
            time_zero + Duration::from_millis(38),
            time_of_last_sent_packet,
            Duration::from_millis(126),
        );
        assert_eq!(slow_start.last_min_rtt, Some(Duration::from_millis(112)));
        assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(126)));
        assert_delta!(slow_start.threshold, 5000.0, 0.001);
    }

    // Same sequence as `on_rtt_update`, but with HyStart++ the RTT increase enters
    // Conservative Slow Start instead of setting the threshold. The next round then
    // counts as one CSS round.
    #[test]
    fn on_rtt_update_with_hystartplus_1() {
        let mut slow_start = HybridSlowStart::new(10);
        slow_start.use_hystart_plus_plus = true;
        assert_eq!(slow_start.sample_count, 0);
        let time_zero = NoopClock.get_time() + Duration::from_secs(10);
        let time_of_last_sent_packet = time_zero + Duration::from_millis(9);
        for i in 0..=6 {
            slow_start.on_rtt_update(
                1000.0,
                time_zero + Duration::from_millis(i),
                time_of_last_sent_packet,
                Duration::from_millis(200),
            );
            assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(200)));
        }
        assert_eq!(
            slow_start.rtt_round_end_time,
            Some(time_of_last_sent_packet)
        );
        slow_start.on_rtt_update(
            1000.0,
            time_zero + Duration::from_millis(7),
            time_of_last_sent_packet,
            Duration::from_millis(100),
        );
        assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(100)));
        slow_start.on_rtt_update(
            1000.0,
            time_zero + Duration::from_millis(8),
            time_of_last_sent_packet,
            Duration::from_millis(50),
        );
        assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(100)));
        let time_of_last_sent_packet = time_zero + Duration::from_millis(29);
        slow_start.on_rtt_update(
            1000.0,
            time_zero + Duration::from_millis(9),
            time_of_last_sent_packet,
            Duration::from_millis(400),
        );
        assert_eq!(slow_start.last_min_rtt, Some(Duration::from_millis(100)));
        assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(400)));
        assert_eq!(
            slow_start.rtt_round_end_time,
            Some(time_of_last_sent_packet)
        );
        for i in 20..=25 {
            slow_start.on_rtt_update(
                1000.0,
                time_zero + Duration::from_millis(i),
                time_of_last_sent_packet,
                Duration::from_millis(500),
            );
            assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(400)));
        }
        slow_start.on_rtt_update(
            2000.0,
            time_zero + Duration::from_millis(27),
            time_of_last_sent_packet,
            Duration::from_millis(112),
        );
        assert_eq!(slow_start.last_min_rtt, Some(Duration::from_millis(100)));
        assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(112)));
        assert_delta!(slow_start.threshold, f32::MAX, 0.001);
        let time_of_last_sent_packet = time_zero + Duration::from_millis(49);
        for i in 40..=46 {
            slow_start.on_rtt_update(
                1000.0,
                time_zero + Duration::from_millis(i),
                time_of_last_sent_packet,
                Duration::from_millis(500),
            );
            assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(500)));
        }
        // RTT increase detected: enter CSS at cwnd 5000 with a 126ms baseline.
        slow_start.on_rtt_update(
            5000.0,
            time_zero + Duration::from_millis(38),
            time_of_last_sent_packet,
            Duration::from_millis(126),
        );
        assert_eq!(slow_start.last_min_rtt, Some(Duration::from_millis(112)));
        assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(126)));
        assert_delta!(slow_start.css_threshold, 5000.0, 0.001);
        assert_eq!(slow_start.css_baseline_min_rtt, Duration::from_millis(126));
        assert_delta!(slow_start.ss_growth_divisor, 4.0, 0.001);
        let time_of_last_sent_packet = time_zero + Duration::from_millis(69);
        for i in 60..=66 {
            slow_start.on_rtt_update(
                1000.0,
                time_zero + Duration::from_millis(i),
                time_of_last_sent_packet,
                Duration::from_millis(500),
            );
            assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(500)));
        }
        // First full CSS round: 130ms is not below the 126ms baseline, so CSS continues.
        slow_start.on_rtt_update(
            5000.0,
            time_zero + Duration::from_millis(38),
            time_of_last_sent_packet,
            Duration::from_millis(130),
        );
        let cwnd_increment = slow_start.cwnd_increment(1000);
        assert_delta!(cwnd_increment, 250.0, 0.001);
        assert_eq!(slow_start.last_min_rtt, Some(Duration::from_millis(126)));
        assert_eq!(slow_start.css_baseline_min_rtt, Duration::from_millis(126));
        assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(130)));
        assert_eq!(slow_start.css_count, 1);
    }

    // A CSS round whose minimum RTT falls below the CSS baseline resumes regular
    // slow start.
    #[test]
    fn on_rtt_update_with_hystartplus_2() {
        let mut slow_start = HybridSlowStart::new(10);
        slow_start.use_hystart_plus_plus = true;
        let time_zero = NoopClock.get_time() + Duration::from_secs(10);
        // Seed the "previous round" minimum so the first round can be compared.
        slow_start.cur_min_rtt = Some(Duration::from_millis(112));
        let time_of_last_sent_packet = time_zero + Duration::from_millis(49);
        for i in 40..=46 {
            slow_start.on_rtt_update(
                1000.0,
                time_zero + Duration::from_millis(i),
                time_of_last_sent_packet,
                Duration::from_millis(500),
            );
            assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(500)));
        }
        slow_start.on_rtt_update(
            5000.0,
            time_zero + Duration::from_millis(38),
            time_of_last_sent_packet,
            Duration::from_millis(126),
        );
        assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(126)));
        assert_delta!(slow_start.css_threshold, 5000.0, 0.001);
        assert_eq!(slow_start.css_baseline_min_rtt, Duration::from_millis(126));
        assert_delta!(slow_start.ss_growth_divisor, 4.0, 0.001);
        let time_of_last_sent_packet = time_zero + Duration::from_millis(69);
        for i in 60..=66 {
            slow_start.on_rtt_update(
                1000.0,
                time_zero + Duration::from_millis(i),
                time_of_last_sent_packet,
                Duration::from_millis(500),
            );
            assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(500)));
        }
        // 125ms < 126ms baseline: leave CSS and resume slow start.
        slow_start.on_rtt_update(
            5000.0,
            time_zero + Duration::from_millis(38),
            time_of_last_sent_packet,
            Duration::from_millis(125),
        );
        assert_eq!(slow_start.last_min_rtt, Some(Duration::from_millis(126)));
        assert_eq!(slow_start.css_baseline_min_rtt, Duration::from_millis(126));
        assert_eq!(slow_start.cur_min_rtt, Some(Duration::from_millis(125)));
        assert_delta!(slow_start.ss_growth_divisor, 1.0, 0.001);
        assert_delta!(slow_start.css_threshold, f32::MAX, 0.001);
        assert_eq!(slow_start.css_count, 0);
    }

    // After CSS_ROUNDS (5) CSS rounds without the RTT dropping below the baseline,
    // slow start ends with the threshold set to the current congestion window.
    #[test]
    fn on_rtt_update_with_hystartplus_exits_after_css_rounds() {
        let mut slow_start = HybridSlowStart::new(10);
        slow_start.use_hystart_plus_plus = true;
        let time_zero = NoopClock.get_time() + Duration::from_secs(10);

        // Feeds one full round of 8 samples. Round `round` starts at `round * 100`ms
        // and ends 50ms later, so each round's first sample starts a new round.
        let run_round = |slow_start: &mut HybridSlowStart, round: u64, cwnd: f32, rtt_ms: u64| {
            let round_start = time_zero + Duration::from_millis(round * 100);
            let round_end = round_start + Duration::from_millis(50);
            for i in 0..8 {
                slow_start.on_rtt_update(
                    cwnd,
                    round_start + Duration::from_millis(i),
                    round_end,
                    Duration::from_millis(rtt_ms),
                );
            }
        };

        // Round 0 has no previous minimum to compare against.
        run_round(&mut slow_start, 0, 1000.0, 100);
        // 120ms >= 100ms + 12.5ms: enter CSS at cwnd 1000.
        run_round(&mut slow_start, 1, 1000.0, 120);
        assert_delta!(slow_start.css_threshold, 1000.0, 0.001);
        assert_delta!(slow_start.ss_growth_divisor, 4.0, 0.001);
        // Four CSS rounds that stay above the 120ms baseline.
        for round in 2..=5 {
            run_round(&mut slow_start, round, 1000.0, 130);
            assert_delta!(slow_start.threshold, f32::MAX, 0.001);
        }
        assert_eq!(slow_start.css_count, 4);
        // Fifth CSS round: slow start ends at the current window.
        run_round(&mut slow_start, 6, 2000.0, 130);
        assert_delta!(slow_start.threshold, 2000.0, 0.001);
        assert_delta!(slow_start.css_threshold, f32::MAX, 0.001);
        assert_delta!(slow_start.cwnd_increment(1000), 1000.0, 0.001);
    }
}