ant_quic/congestion/
cubic.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8use std::any::Any;
9use std::cmp;
10use std::sync::Arc;
11
12use super::{BASE_DATAGRAM_SIZE, Controller, ControllerFactory};
13use crate::connection::RttEstimator;
14use crate::{Duration, Instant};
15
/// Multiplicative window-decrease factor (`beta_cubic`) applied on a congestion event.
///
/// Both CUBIC constants in this file use the values recommended in RFC 8312.
const BETA_CUBIC: f64 = 0.7;

/// Scaling constant `C` of the cubic window-growth function (RFC 8312).
const C: f64 = 0.4;
22
23/// CUBIC State Variables.
24///
25/// We need to keep those variables across the connection.
26/// k, w_max are described in the RFC.
27#[derive(Debug, Default, Clone)]
28pub(super) struct State {
29    k: f64,
30
31    w_max: f64,
32
33    // Store cwnd increment during congestion avoidance.
34    cwnd_inc: u64,
35}
36
37/// CUBIC Functions.
38///
39/// Note that these calculations are based on a count of cwnd as bytes,
40/// not packets.
41/// Unit of t (duration) and RTT are based on seconds (f64).
42impl State {
43    // K = cbrt(w_max * (1 - beta_cubic) / C) (Eq. 2)
44    fn cubic_k(&self, max_datagram_size: u64) -> f64 {
45        let w_max = self.w_max / max_datagram_size as f64;
46        (w_max * (1.0 - BETA_CUBIC) / C).cbrt()
47    }
48
49    // W_cubic(t) = C * (t - K)^3 - w_max (Eq. 1)
50    fn w_cubic(&self, t: Duration, max_datagram_size: u64) -> f64 {
51        let w_max = self.w_max / max_datagram_size as f64;
52
53        (C * (t.as_secs_f64() - self.k).powi(3) + w_max) * max_datagram_size as f64
54    }
55
56    // W_est(t) = w_max * beta_cubic + 3 * (1 - beta_cubic) / (1 + beta_cubic) *
57    // (t / RTT) (Eq. 4)
58    fn w_est(&self, t: Duration, rtt: Duration, max_datagram_size: u64) -> f64 {
59        let w_max = self.w_max / max_datagram_size as f64;
60        (w_max * BETA_CUBIC
61            + 3.0 * (1.0 - BETA_CUBIC) / (1.0 + BETA_CUBIC) * t.as_secs_f64() / rtt.as_secs_f64())
62            * max_datagram_size as f64
63    }
64}
65
66/// The RFC8312 congestion controller, as widely used for TCP
67#[derive(Debug, Clone)]
68pub(crate) struct Cubic {
69    config: Arc<CubicConfig>,
70    /// Maximum number of bytes in flight that may be sent.
71    window: u64,
72    /// Slow start threshold in bytes. When the congestion window is below ssthresh, the mode is
73    /// slow start and the window grows by the number of bytes acknowledged.
74    ssthresh: u64,
75    /// The time when QUIC first detects a loss, causing it to enter recovery. When a packet sent
76    /// after this time is acknowledged, QUIC exits recovery.
77    recovery_start_time: Option<Instant>,
78    cubic_state: State,
79    current_mtu: u64,
80}
81
82impl Cubic {
83    /// Construct a state using the given `config` and current time `now`
84    pub(crate) fn new(config: Arc<CubicConfig>, _now: Instant, current_mtu: u16) -> Self {
85        Self {
86            window: config.initial_window,
87            ssthresh: u64::MAX,
88            recovery_start_time: None,
89            config,
90            cubic_state: Default::default(),
91            current_mtu: current_mtu as u64,
92        }
93    }
94
95    fn minimum_window(&self) -> u64 {
96        2 * self.current_mtu
97    }
98}
99
100impl Controller for Cubic {
101    fn on_ack(
102        &mut self,
103        now: Instant,
104        sent: Instant,
105        bytes: u64,
106        app_limited: bool,
107        rtt: &RttEstimator,
108    ) {
109        if app_limited
110            || self
111                .recovery_start_time
112                .map(|recovery_start_time| sent <= recovery_start_time)
113                .unwrap_or(false)
114        {
115            return;
116        }
117
118        if self.window < self.ssthresh {
119            // Slow start
120            self.window += bytes;
121        } else {
122            // Congestion avoidance.
123            let ca_start_time;
124
125            match self.recovery_start_time {
126                Some(t) => ca_start_time = t,
127                None => {
128                    // When we come here without congestion_event() triggered,
129                    // initialize congestion_recovery_start_time, w_max and k.
130                    ca_start_time = now;
131                    self.recovery_start_time = Some(now);
132
133                    self.cubic_state.w_max = self.window as f64;
134                    self.cubic_state.k = 0.0;
135                }
136            }
137
138            let t = now - ca_start_time;
139
140            // w_cubic(t + rtt)
141            let w_cubic = self.cubic_state.w_cubic(t + rtt.get(), self.current_mtu);
142
143            // w_est(t)
144            let w_est = self.cubic_state.w_est(t, rtt.get(), self.current_mtu);
145
146            let mut cubic_cwnd = self.window;
147
148            if w_cubic < w_est {
149                // TCP friendly region.
150                cubic_cwnd = cmp::max(cubic_cwnd, w_est as u64);
151            } else if cubic_cwnd < w_cubic as u64 {
152                // Concave region or convex region use same increment.
153                let cubic_inc =
154                    (w_cubic - cubic_cwnd as f64) / cubic_cwnd as f64 * self.current_mtu as f64;
155
156                cubic_cwnd += cubic_inc as u64;
157            }
158
159            // Update the increment and increase cwnd by MSS.
160            self.cubic_state.cwnd_inc += cubic_cwnd - self.window;
161
162            // cwnd_inc can be more than 1 MSS in the late stage of max probing.
163            // however RFC9002 ยง7.3.3 (Congestion Avoidance) limits
164            // the increase of cwnd to 1 max_datagram_size per cwnd acknowledged.
165            if self.cubic_state.cwnd_inc >= self.current_mtu {
166                self.window += self.current_mtu;
167                self.cubic_state.cwnd_inc = 0;
168            }
169        }
170    }
171
172    fn on_congestion_event(
173        &mut self,
174        now: Instant,
175        sent: Instant,
176        is_persistent_congestion: bool,
177        _lost_bytes: u64,
178    ) {
179        if self
180            .recovery_start_time
181            .map(|recovery_start_time| sent <= recovery_start_time)
182            .unwrap_or(false)
183        {
184            return;
185        }
186
187        self.recovery_start_time = Some(now);
188
189        // Fast convergence
190        if (self.window as f64) < self.cubic_state.w_max {
191            self.cubic_state.w_max = self.window as f64 * (1.0 + BETA_CUBIC) / 2.0;
192        } else {
193            self.cubic_state.w_max = self.window as f64;
194        }
195
196        self.ssthresh = cmp::max(
197            (self.cubic_state.w_max * BETA_CUBIC) as u64,
198            self.minimum_window(),
199        );
200        self.window = self.ssthresh;
201        self.cubic_state.k = self.cubic_state.cubic_k(self.current_mtu);
202
203        self.cubic_state.cwnd_inc = (self.cubic_state.cwnd_inc as f64 * BETA_CUBIC) as u64;
204
205        if is_persistent_congestion {
206            self.recovery_start_time = None;
207            self.cubic_state.w_max = self.window as f64;
208
209            // 4.7 Timeout - reduce ssthresh based on BETA_CUBIC
210            self.ssthresh = cmp::max(
211                (self.window as f64 * BETA_CUBIC) as u64,
212                self.minimum_window(),
213            );
214
215            self.cubic_state.cwnd_inc = 0;
216
217            self.window = self.minimum_window();
218        }
219    }
220
221    fn on_mtu_update(&mut self, new_mtu: u16) {
222        self.current_mtu = new_mtu as u64;
223        self.window = self.window.max(self.minimum_window());
224    }
225
226    fn window(&self) -> u64 {
227        self.window
228    }
229
230    fn metrics(&self) -> super::ControllerMetrics {
231        super::ControllerMetrics {
232            congestion_window: self.window(),
233            ssthresh: Some(self.ssthresh),
234            pacing_rate: None,
235        }
236    }
237
238    fn clone_box(&self) -> Box<dyn Controller> {
239        Box::new(self.clone())
240    }
241
242    fn initial_window(&self) -> u64 {
243        self.config.initial_window
244    }
245
246    fn into_any(self: Box<Self>) -> Box<dyn Any> {
247        self
248    }
249}
250
/// Tunable parameters for the [`Cubic`] congestion controller.
#[derive(Clone, Debug)]
pub(crate) struct CubicConfig {
    /// Congestion window, in bytes, that a newly built controller starts with.
    initial_window: u64,
}
256
257impl CubicConfig {
258    /// Default limit on the amount of outstanding data in bytes.
259    ///
260    /// Recommended value: `min(10 * max_datagram_size, max(2 * max_datagram_size, 14720))`
261    #[allow(dead_code)]
262    pub(crate) fn initial_window(&mut self, value: u64) -> &mut Self {
263        self.initial_window = value;
264        self
265    }
266}
267
268impl Default for CubicConfig {
269    fn default() -> Self {
270        Self {
271            initial_window: 14720.clamp(2 * BASE_DATAGRAM_SIZE, 10 * BASE_DATAGRAM_SIZE),
272        }
273    }
274}
275
276impl ControllerFactory for CubicConfig {
277    fn new_controller(
278        &self,
279        min_window: u64,
280        _max_window: u64,
281        now: Instant,
282    ) -> Box<dyn Controller + Send + Sync> {
283        let current_mtu = (min_window / 4).max(1200).min(65535) as u16; // Derive MTU from min_window
284        Box::new(Cubic::new(Arc::new(self.clone()), now, current_mtu))
285    }
286}