// noq-proto 0.17.0
//
// State machine for the QUIC transport protocol
// Documentation
use std::any::Any;
use std::cmp;
use std::sync::Arc;

use super::{BASE_DATAGRAM_SIZE, Controller, ControllerFactory};
use crate::connection::RttEstimator;
use crate::{Duration, Instant};

/// CUBIC Constants.
///
/// These are the values recommended in RFC 8312.
///
/// `BETA_CUBIC` is the multiplicative decrease factor: on a congestion event the slow start
/// threshold (and window) is set to `w_max * BETA_CUBIC`, subject to the minimum window.
const BETA_CUBIC: f64 = 0.7;

/// Scaling constant `C` of the cubic growth function `W_cubic(t) = C * (t - K)^3 + w_max`.
const C: f64 = 0.4;

/// CUBIC State Variables.
///
/// These variables must be kept for the lifetime of the connection.
/// `k` and `w_max` are described in the RFC.
///
/// The whole struct is cloned before a non-ECN congestion event so that it can be restored
/// if the event later turns out to be spurious.
#[derive(Debug, Default, Clone)]
pub(super) struct State {
    /// Time period that the cubic function takes to increase the window size to W_max.
    ///
    /// Expressed in seconds; it is subtracted from an elapsed time converted with
    /// `as_secs_f64()` when evaluating the cubic function.
    k: f64,

    /// Congestion window size, in bytes, when the last congestion event occurred.
    ///
    /// The window functions divide this by the maximum datagram size before applying the
    /// RFC formulas.
    w_max: f64,

    /// Congestion window increment, in bytes, accumulated during congestion avoidance.
    ///
    /// Once it reaches the current MTU the window grows by one MTU and this counter is
    /// reset to zero.
    cwnd_inc: u64,

    /// Maximum number of bytes in flight that may be sent.
    window: u64,

    /// Slow start threshold in bytes.
    ///
    /// When the congestion window is below ssthresh, the mode is slow start
    /// and the window grows by the number of bytes acknowledged.
    /// A new controller starts with this set to `u64::MAX`.
    ssthresh: u64,

    /// The time when QUIC first detects a loss, causing it to enter recovery. When a packet sent
    /// after this time is acknowledged, QUIC exits recovery.
    ///
    /// Also set to the current time when congestion avoidance is entered without a preceding
    /// congestion event, and reset to `None` on persistent congestion.
    recovery_start_time: Option<Instant>,
}

/// CUBIC window functions.
///
/// Every window passed in or returned here is a byte count, not a packet count; the
/// formulas themselves are evaluated in units of `max_datagram_size` and scaled back.
/// Elapsed time `t` and the RTT enter the formulas as seconds (`f64`).
impl State {
    /// `K = cbrt(w_max * (1 - beta_cubic) / C)` (Eq. 2)
    ///
    /// Returns the value in seconds, computed from the stored `w_max`.
    fn cubic_k(&self, max_datagram_size: u64) -> f64 {
        let w_max_segments = self.w_max / max_datagram_size as f64;
        let scaled = w_max_segments * (1.0 - BETA_CUBIC) / C;
        scaled.cbrt()
    }

    /// `W_cubic(t) = C * (t - K)^3 + w_max` (Eq. 1)
    ///
    /// Returns the cubic window, in bytes, after `t` has elapsed.
    fn w_cubic(&self, t: Duration, max_datagram_size: u64) -> f64 {
        let mss = max_datagram_size as f64;
        let w_max_segments = self.w_max / mss;

        // Signed distance in seconds from the inflection point of the cubic curve.
        let offset = t.as_secs_f64() - self.k;
        let segments = C * offset.powi(3) + w_max_segments;
        segments * mss
    }

    /// `W_est(t) = w_max * beta_cubic + 3 * (1 - beta_cubic) / (1 + beta_cubic) * (t / RTT)`
    /// (Eq. 4)
    ///
    /// Returns the estimated window, in bytes, after `t` has elapsed with round-trip time
    /// `rtt`.
    fn w_est(&self, t: Duration, rtt: Duration, max_datagram_size: u64) -> f64 {
        let mss = max_datagram_size as f64;
        let w_max_segments = self.w_max / mss;

        // The multiplication by `t` happens before the division by the RTT, matching the
        // left-to-right evaluation of the formula.
        let slope = 3.0 * (1.0 - BETA_CUBIC) / (1.0 + BETA_CUBIC);
        let growth = slope * t.as_secs_f64() / rtt.as_secs_f64();

        (w_max_segments * BETA_CUBIC + growth) * mss
    }
}

/// The RFC8312 congestion controller, as widely used for TCP
#[derive(Debug, Clone)]
pub struct Cubic {
    /// Shared configuration; supplies the initial congestion window.
    config: Arc<CubicConfig>,
    /// Current MTU in bytes, widened to `u64`; used as the maximum datagram size in the
    /// window calculations and to derive the minimum window.
    current_mtu: u64,
    /// Live CUBIC state (window, slow start threshold, `w_max`, `k`, ...).
    state: State,
    /// Copy of the controller state to restore when a spurious congestion event is detected.
    pre_congestion_state: Option<State>,
}

impl Cubic {
    /// Construct a state using the given `config` and current time `now`
    pub fn new(config: Arc<CubicConfig>, _now: Instant, current_mtu: u16) -> Self {
        Self {
            state: State {
                window: config.initial_window,
                ssthresh: u64::MAX,
                ..Default::default()
            },
            current_mtu: current_mtu as u64,
            pre_congestion_state: None,
            config,
        }
    }

    fn minimum_window(&self) -> u64 {
        2 * self.current_mtu
    }
}

impl Controller for Cubic {
    /// Grow the congestion window in response to `bytes` being acknowledged.
    ///
    /// Nothing happens when the sender is application-limited or when the acknowledged packet
    /// was sent at or before the start of the current recovery period. Otherwise the window
    /// grows by `bytes` in slow start, or according to the CUBIC functions in congestion
    /// avoidance.
    fn on_ack(
        &mut self,
        now: Instant,
        sent: Instant,
        bytes: u64,
        app_limited: bool,
        rtt: &RttEstimator,
    ) {
        let sent_before_recovery = self
            .state
            .recovery_start_time
            .is_some_and(|start| sent <= start);
        if app_limited || sent_before_recovery {
            return;
        }

        if self.state.window < self.state.ssthresh {
            // Slow start
            self.state.window += bytes;
            return;
        }

        // Congestion avoidance.
        let ca_start_time = match self.state.recovery_start_time {
            Some(start) => start,
            None => {
                // When we come here without congestion_event() triggered,
                // initialize congestion_recovery_start_time, w_max and k.
                self.state.recovery_start_time = Some(now);
                self.state.w_max = self.state.window as f64;
                self.state.k = 0.0;
                now
            }
        };

        let elapsed = now - ca_start_time;

        // w_cubic(t + rtt)
        let w_cubic = self.state.w_cubic(elapsed + rtt.get(), self.current_mtu);

        // w_est(t)
        let w_est = self.state.w_est(elapsed, rtt.get(), self.current_mtu);

        let window = self.state.window;
        let target = if w_cubic < w_est {
            // TCP friendly region.
            cmp::max(window, w_est as u64)
        } else if window < w_cubic as u64 {
            // Concave region or convex region use same increment.
            let cubic_inc = (w_cubic - window as f64) / window as f64 * self.current_mtu as f64;
            window + cubic_inc as u64
        } else {
            window
        };

        // Accumulate the increment; the window itself only moves in whole-MSS steps below.
        self.state.cwnd_inc += target - window;

        // cwnd_inc can be more than 1 MSS in the late stage of max probing.
        // however RFC9002 §7.3.3 (Congestion Avoidance) limits
        // the increase of cwnd to 1 max_datagram_size per cwnd acknowledged.
        if self.state.cwnd_inc >= self.current_mtu {
            self.state.window += self.current_mtu;
            self.state.cwnd_inc = 0;
        }
    }

    /// Shrink the congestion window in response to a congestion event.
    ///
    /// Events for packets sent at or before the start of the current recovery period are
    /// ignored. For non-ECN events the prior state is saved so it can be restored if the event
    /// proves spurious. Persistent congestion additionally collapses the window to the
    /// minimum window.
    fn on_congestion_event(
        &mut self,
        now: Instant,
        sent: Instant,
        is_persistent_congestion: bool,
        is_ecn: bool,
        _lost_bytes: u64,
    ) {
        let already_in_recovery = self
            .state
            .recovery_start_time
            .is_some_and(|start| sent <= start);
        if already_in_recovery {
            return;
        }

        // Save state in case this event ends up being spurious
        if !is_ecn {
            self.pre_congestion_state = Some(self.state.clone());
        }

        self.state.recovery_start_time = Some(now);

        // Fast convergence
        let window = self.state.window as f64;
        self.state.w_max = if window < self.state.w_max {
            window * (1.0 + BETA_CUBIC) / 2.0
        } else {
            window
        };

        // The minimum window depends only on the MTU, which does not change in this method.
        let floor = self.minimum_window();

        self.state.ssthresh = cmp::max((self.state.w_max * BETA_CUBIC) as u64, floor);
        self.state.window = self.state.ssthresh;
        self.state.k = self.state.cubic_k(self.current_mtu);

        self.state.cwnd_inc = (self.state.cwnd_inc as f64 * BETA_CUBIC) as u64;

        if !is_persistent_congestion {
            return;
        }

        self.state.recovery_start_time = None;
        self.state.w_max = self.state.window as f64;

        // 4.7 Timeout - reduce ssthresh based on BETA_CUBIC
        self.state.ssthresh = cmp::max((self.state.window as f64 * BETA_CUBIC) as u64, floor);

        self.state.cwnd_inc = 0;

        self.state.window = floor;
    }

    /// Undo the most recent saved congestion event, if any.
    ///
    /// The saved state is always consumed; it is restored only when its window is larger than
    /// the current one.
    fn on_spurious_congestion_event(&mut self) {
        let Some(prior_state) = self.pre_congestion_state.take() else {
            return;
        };

        if self.state.window < prior_state.window {
            self.state = prior_state;
        }
    }

    /// Record the new MTU and raise the window to the new minimum window if it fell below it.
    fn on_mtu_update(&mut self, new_mtu: u16) {
        self.current_mtu = u64::from(new_mtu);

        let floor = self.minimum_window();
        self.state.window = cmp::max(self.state.window, floor);
    }

    /// Current congestion window in bytes.
    fn window(&self) -> u64 {
        self.state.window
    }

    /// Snapshot of the current window and slow start threshold; no pacing rate is reported.
    fn metrics(&self) -> super::ControllerMetrics {
        let congestion_window = self.state.window;
        let ssthresh = Some(self.state.ssthresh);

        super::ControllerMetrics {
            congestion_window,
            ssthresh,
            pacing_rate: None,
        }
    }

    /// Box a clone of this controller.
    fn clone_box(&self) -> Box<dyn Controller> {
        let copy = self.clone();
        Box::new(copy)
    }

    /// Initial congestion window taken from the configuration.
    fn initial_window(&self) -> u64 {
        self.config.initial_window
    }

    /// Convert the boxed controller into a `Box<dyn Any>`.
    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}

/// Configuration for the `Cubic` congestion controller
#[derive(Debug, Clone)]
pub struct CubicConfig {
    /// Congestion window, in bytes, that a newly constructed controller starts with.
    initial_window: u64,
}

impl CubicConfig {
    /// Set the initial limit on the amount of outstanding data in bytes.
    ///
    /// This is the congestion window a newly constructed controller starts with.
    ///
    /// Recommended value: `min(10 * max_datagram_size, max(2 * max_datagram_size, 14720))`
    ///
    /// Returns `&mut Self` so that calls can be chained.
    pub fn initial_window(&mut self, value: u64) -> &mut Self {
        self.initial_window = value;
        self
    }
}

impl Default for CubicConfig {
    fn default() -> Self {
        Self {
            initial_window: 14720.clamp(2 * BASE_DATAGRAM_SIZE, 10 * BASE_DATAGRAM_SIZE),
        }
    }
}

impl ControllerFactory for CubicConfig {
    /// Create a boxed `Cubic` controller that shares this configuration.
    fn build(self: Arc<Self>, now: Instant, current_mtu: u16) -> Box<dyn Controller> {
        let controller = Cubic::new(self, now, current_mtu);
        Box::new(controller)
    }
}