use std::any::Any;
use std::cmp;
use std::sync::Arc;
use super::{BASE_DATAGRAM_SIZE, Controller, ControllerFactory};
use crate::connection::RttEstimator;
use crate::{Duration, Instant};
/// CUBIC multiplicative-decrease factor (β_cubic): after a congestion event the
/// slow-start threshold becomes `W_max * BETA_CUBIC`.
const BETA_CUBIC: f64 = 0.7_f64;
/// CUBIC scaling constant `C`, which sets how steeply the cubic window curve grows.
const C: f64 = 0.4_f64;
/// Mutable CUBIC state, kept separate from [`Cubic`] so it can be snapshotted and
/// restored as a unit.
#[derive(Clone, Debug, Default)]
pub(super) struct State {
    /// CUBIC `K`: time in seconds for the cubic curve to climb back to `w_max`.
    k: f64,
    /// Window size (bytes) just before the most recent reduction, or at the start of
    /// the current congestion-avoidance epoch.
    w_max: f64,
    /// Congestion-avoidance growth accumulated but not yet applied to `window` (bytes).
    cwnd_inc: u64,
    /// Current congestion window in bytes.
    window: u64,
    /// Slow-start threshold in bytes; below it the window grows by the bytes acknowledged.
    ssthresh: u64,
    /// Start of the current recovery period / CUBIC epoch, if any.
    recovery_start_time: Option<Instant>,
}
impl State {
    /// `w_max` converted from bytes to datagrams of `max_datagram_size` bytes.
    fn w_max_in_datagrams(&self, max_datagram_size: u64) -> f64 {
        self.w_max / max_datagram_size as f64
    }

    /// CUBIC `K = cbrt(W_max * (1 - β) / C)`, with `W_max` measured in datagrams.
    ///
    /// The result is the number of seconds the cubic curve needs to return to `w_max`.
    fn cubic_k(&self, max_datagram_size: u64) -> f64 {
        let w_max = self.w_max_in_datagrams(max_datagram_size);
        let numerator = w_max * (1.0 - BETA_CUBIC);
        (numerator / C).cbrt()
    }

    /// CUBIC window curve `W_cubic(t) = C * (t - K)^3 + W_max`, evaluated in datagrams
    /// and returned in bytes.
    fn w_cubic(&self, t: Duration, max_datagram_size: u64) -> f64 {
        let mss = max_datagram_size as f64;
        let offset = t.as_secs_f64() - self.k;
        let datagrams = C * offset.powi(3) + self.w_max_in_datagrams(max_datagram_size);
        datagrams * mss
    }

    /// Reno-friendly window estimate
    /// `W_est(t) = W_max * β + 3 * (1 - β) / (1 + β) * t / RTT`, evaluated in datagrams
    /// and returned in bytes.
    fn w_est(&self, t: Duration, rtt: Duration, max_datagram_size: u64) -> f64 {
        let mss = max_datagram_size as f64;
        let reno_slope = 3.0 * (1.0 - BETA_CUBIC) / (1.0 + BETA_CUBIC);
        let reno_growth = reno_slope * t.as_secs_f64() / rtt.as_secs_f64();
        let datagrams = self.w_max_in_datagrams(max_datagram_size) * BETA_CUBIC + reno_growth;
        datagrams * mss
    }
}
/// CUBIC congestion controller: slow start followed by cubic window growth.
#[derive(Clone, Debug)]
pub struct Cubic {
    /// Shared configuration this controller was built from.
    config: Arc<CubicConfig>,
    /// Current maximum datagram size in bytes; the unit of congestion-avoidance growth.
    current_mtu: u64,
    /// Live congestion-control state.
    state: State,
    /// Snapshot taken before the most recent loss-triggered reduction, restored if that
    /// loss is later reported as spurious.
    pre_congestion_state: Option<State>,
}
impl Cubic {
pub fn new(config: Arc<CubicConfig>, _now: Instant, current_mtu: u16) -> Self {
Self {
state: State {
window: config.initial_window,
ssthresh: u64::MAX,
..Default::default()
},
current_mtu: current_mtu as u64,
pre_congestion_state: None,
config,
}
}
fn minimum_window(&self) -> u64 {
2 * self.current_mtu
}
}
impl Controller for Cubic {
    /// Grows the congestion window in response to newly acknowledged data.
    ///
    /// The ack is ignored if the sender was application-limited, or if the acknowledged
    /// packet was sent at or before the start of the current recovery period.
    ///
    /// While `window < ssthresh` the window grows by `bytes` (slow start). Otherwise it
    /// follows CUBIC congestion avoidance and grows by at most one `current_mtu` per call.
    fn on_ack(
        &mut self,
        now: Instant,
        sent: Instant,
        bytes: u64,
        app_limited: bool,
        rtt: &RttEstimator,
    ) {
        if app_limited
            || self
                .state
                .recovery_start_time
                .is_some_and(|recovery_start_time| sent <= recovery_start_time)
        {
            return;
        }

        if self.state.window < self.state.ssthresh {
            // Slow start: grow by the number of newly acknowledged bytes.
            self.state.window += bytes;
            return;
        }

        // Congestion avoidance. The CUBIC epoch is anchored at the recovery start time.
        // If there is none (slow start ended without a congestion event), open a new epoch
        // now, with the current window as W_max and K = 0.
        let ca_start_time = match self.state.recovery_start_time {
            Some(t) => t,
            None => {
                self.state.recovery_start_time = Some(now);
                self.state.w_max = self.state.window as f64;
                self.state.k = 0.0;
                now
            }
        };

        let t = now - ca_start_time;
        // The cubic target is evaluated one RTT ahead, W_cubic(t + RTT). It is compared
        // with the Reno-friendly estimate W_est(t).
        let w_cubic = self.state.w_cubic(t + rtt.get(), self.current_mtu);
        let w_est = self.state.w_est(t, rtt.get(), self.current_mtu);

        let mut cubic_cwnd = self.state.window;
        if w_cubic < w_est {
            // Reno-friendly region: never grow more slowly than the Reno estimate.
            cubic_cwnd = cmp::max(cubic_cwnd, w_est as u64);
        } else if cubic_cwnd < w_cubic as u64 {
            // The concave and convex regions share the same per-ack increment.
            let cubic_inc =
                (w_cubic - cubic_cwnd as f64) / cubic_cwnd as f64 * self.current_mtu as f64;
            cubic_cwnd += cubic_inc as u64;
        }

        // Both branches above leave `cubic_cwnd >= window`, so this subtraction cannot
        // underflow.
        self.state.cwnd_inc += cubic_cwnd - self.state.window;
        // The accumulated increment can exceed one datagram late in max probing. The window
        // still grows by at most one `current_mtu` per call, in line with RFC 9002 §7.3.3's
        // limit of one maximum datagram per congestion window acknowledged.
        if self.state.cwnd_inc >= self.current_mtu {
            self.state.window += self.current_mtu;
            self.state.cwnd_inc = 0;
        }
    }

    /// Reduces the window after loss or an ECN congestion signal.
    ///
    /// At most one reduction happens per recovery period: the event is ignored if the
    /// triggering packet was sent at or before the current recovery start. When
    /// `is_persistent_congestion` is set, the window collapses to the minimum window and
    /// recovery is cleared.
    ///
    /// NOTE(review): the early return also skips the persistent-congestion collapse. RFC 9002
    /// Appendix B.8 applies that collapse independently of the recovery check, so confirm the
    /// caller accounts for this.
    fn on_congestion_event(
        &mut self,
        now: Instant,
        sent: Instant,
        is_persistent_congestion: bool,
        is_ecn: bool,
        _lost_bytes: u64,
    ) {
        if self
            .state
            .recovery_start_time
            .is_some_and(|recovery_start_time| sent <= recovery_start_time)
        {
            return;
        }

        // Snapshot the state so a loss later reported as spurious can be rolled back.
        // ECN-signalled events do not take a snapshot.
        if !is_ecn {
            self.pre_congestion_state = Some(self.state.clone());
        }

        self.state.recovery_start_time = Some(now);

        // Fast convergence: if the window never regained the previous W_max, lower W_max
        // further so competing flows can claim bandwidth sooner.
        let window = self.state.window as f64;
        self.state.w_max = if window < self.state.w_max {
            window * (1.0 + BETA_CUBIC) / 2.0
        } else {
            window
        };

        self.state.ssthresh = cmp::max(
            (self.state.w_max * BETA_CUBIC) as u64,
            self.minimum_window(),
        );
        self.state.window = self.state.ssthresh;
        self.state.k = self.state.cubic_k(self.current_mtu);

        self.state.cwnd_inc = (self.state.cwnd_inc as f64 * BETA_CUBIC) as u64;

        if is_persistent_congestion {
            // Clearing the recovery start makes the next congestion-avoidance ack open a
            // fresh CUBIC epoch (see the `None` branch in `on_ack`).
            self.state.recovery_start_time = None;
            self.state.w_max = self.state.window as f64;

            // Lower ssthresh once more by BETA_CUBIC, then collapse the window.
            self.state.ssthresh = cmp::max(
                (self.state.window as f64 * BETA_CUBIC) as u64,
                self.minimum_window(),
            );

            self.state.cwnd_inc = 0;

            self.state.window = self.minimum_window();
        }
    }

    /// Restores the state saved by the last loss-triggered congestion event, but only if
    /// that state had a larger window than the current one. The snapshot is consumed
    /// either way.
    fn on_spurious_congestion_event(&mut self) {
        if let Some(prior_state) = self.pre_congestion_state.take()
            && self.state.window < prior_state.window
        {
            self.state = prior_state;
        }
    }

    /// Records a new MTU and raises the window to at least the new minimum window.
    fn on_mtu_update(&mut self, new_mtu: u16) {
        self.current_mtu = new_mtu as u64;
        self.state.window = self.state.window.max(self.minimum_window());
    }

    /// Current congestion window in bytes.
    fn window(&self) -> u64 {
        self.state.window
    }

    /// Reports the current window and `ssthresh`. CUBIC supplies no pacing rate.
    fn metrics(&self) -> super::ControllerMetrics {
        super::ControllerMetrics {
            congestion_window: self.window(),
            ssthresh: Some(self.state.ssthresh),
            pacing_rate: None,
        }
    }

    /// Returns a boxed, independent copy of this controller and its state.
    fn clone_box(&self) -> Box<dyn Controller> {
        Box::new(self.clone())
    }

    /// Initial window from the shared configuration.
    fn initial_window(&self) -> u64 {
        self.config.initial_window
    }

    /// Converts the boxed controller into `Box<dyn Any>` (presumably so callers can
    /// downcast back to `Cubic`).
    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}
/// Configuration for the CUBIC congestion controller.
#[derive(Clone, Debug)]
pub struct CubicConfig {
    /// Congestion window (bytes) that each new controller starts with.
    initial_window: u64,
}
impl CubicConfig {
    /// Sets the congestion window, in bytes, that new controllers start with.
    ///
    /// Returns `&mut Self` so that setters can be chained.
    pub fn initial_window(&mut self, value: u64) -> &mut Self {
        let config = self;
        config.initial_window = value;
        config
    }
}
impl Default for CubicConfig {
    /// Uses an initial window of 14720 clamped to `[2, 10] * BASE_DATAGRAM_SIZE`.
    ///
    /// This matches RFC 9002 §7.2's recommended
    /// `min(10 * max_datagram_size, max(2 * max_datagram_size, 14720))`.
    fn default() -> Self {
        let lower_bound = 2 * BASE_DATAGRAM_SIZE;
        let upper_bound = 10 * BASE_DATAGRAM_SIZE;
        let initial_window = 14720.clamp(lower_bound, upper_bound);
        Self { initial_window }
    }
}
impl ControllerFactory for CubicConfig {
    /// Builds a boxed [`Cubic`] controller that shares this configuration.
    fn build(self: Arc<Self>, now: Instant, current_mtu: u16) -> Box<dyn Controller> {
        let controller = Cubic::new(self, now, current_mtu);
        Box::new(controller)
    }
}