Skip to main content

vcl_protocol/
flow.rs

1//! # VCL Flow Control & Congestion Control
2//!
3//! [`FlowController`] implements sliding window flow control with
4//! AIMD (Additive Increase Multiplicative Decrease) congestion control
5//! and retransmission support.
6//!
7//! ## How it works
8//!
9//! ```text
10//! window_size = 4
11//!
12//! Sent but unacked:  [0] [1] [2] [3]   <- window full, must wait
13//! Acked:             [0]               <- window slides, can send [4]
14//! Sent but unacked:      [1] [2] [3] [4]
15//! ```
16//!
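//! ## Congestion Control (AIMD)
//!
//! ```text
//! Slow start (cwnd < ssthresh):  cwnd += 1 per ACK
//! Congestion avoidance:          cwnd += 1/cwnd per ACK   (additive increase, ~1 per RTT)
//! Loss (timeout):                ssthresh = max(cwnd * 0.5, 1)   (multiplicative decrease)
//!                                cwnd = 1, slow start restarts, RTO doubles (capped at 60 s)
//! Min cwnd: 1
//! Max cwnd: bounded by window_size
//! ```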
25//!
26//! ## Example
27//!
28//! ```rust
29//! use vcl_protocol::flow::FlowController;
30//!
31//! let mut fc = FlowController::new(4);
32//!
33//! // Send packets
34//! assert!(fc.can_send());
35//! fc.on_send(0, vec![0]);
36//! assert!(!fc.can_send()); // window full
37//!
38//! // Acknowledge packets
39//! fc.on_ack(0);
40//! assert!(fc.can_send()); // window has space again
41//! ```
42
43use std::collections::BTreeSet;
44use std::time::{Duration, Instant};
45use tracing::{debug, info, warn};
46
/// Retransmission timeout, in milliseconds, that `FlowController::new` starts with.
const DEFAULT_RTO_MS: u64 = 200;
/// Amount added to `cwnd` per ACK in slow start; divided by `cwnd` per ACK afterwards.
const AIMD_INCREASE_STEP: f64 = 1.0;
/// Fraction of `cwnd` that becomes the new `ssthresh` when a loss is detected.
const AIMD_DECREASE_FACTOR: f64 = 0.5;
/// Value `cwnd` is reset to on loss; also the lower bound applied to `ssthresh`.
const CWND_MIN: f64 = 1.0;
51
/// A packet that has been sent but not yet acknowledged.
///
/// The payload is stored next to the bookkeeping fields so the packet can be
/// handed back to the caller for retransmission.
#[derive(Debug, Clone)]
pub struct InFlightPacket {
    /// Sequence number of the packet.
    pub sequence: u64,
    /// When this packet was last sent or retransmitted; timeouts are measured from here.
    pub sent_at: Instant,
    /// Number of times this packet has been retransmitted (0 when first sent).
    pub retransmit_count: u32,
    /// Original data payload — stored for retransmission.
    pub data: Vec<u8>,
}
64
65impl InFlightPacket {
66    fn new(sequence: u64, data: Vec<u8>) -> Self {
67        InFlightPacket {
68            sequence,
69            sent_at: Instant::now(),
70            retransmit_count: 0,
71            data,
72        }
73    }
74
75    /// Returns `true` if this packet has exceeded the retransmission timeout.
76    pub fn is_timed_out(&self, rto: Duration) -> bool {
77        self.sent_at.elapsed() > rto
78    }
79}
80
/// A packet that needs to be retransmitted.
///
/// Produced by [`FlowController::timed_out_packets`] for every in-flight
/// packet whose retransmission timeout has expired.
#[derive(Debug, Clone)]
pub struct RetransmitRequest {
    /// Sequence number of the packet to retransmit.
    pub sequence: u64,
    /// Original data payload to resend (a copy of the stored payload).
    pub data: Vec<u8>,
    /// How many times this packet has already been retransmitted, i.e. the
    /// counter value before the retransmission this request asks for.
    pub retransmit_count: u32,
}
91
/// Sliding window flow controller with AIMD congestion control
/// and retransmission support.
pub struct FlowController {
    /// Hard upper bound on the number of packets in flight.
    window_size: usize,
    /// Congestion window in packets; fractional so it can grow by `1/cwnd` per ACK.
    cwnd: f64,
    /// Slow-start threshold: slow start ends once `cwnd` reaches this value.
    ssthresh: f64,
    /// `true` while `cwnd` grows by a full step per ACK.
    in_slow_start: bool,
    /// Sent-but-unacknowledged packets, in the order they were sent.
    in_flight: Vec<InFlightPacket>,
    /// Every sequence number acknowledged since creation or the last `reset()`.
    acked: BTreeSet<u64>,
    /// Current retransmission timeout.
    rto: Duration,
    /// Smoothed RTT estimate; `None` until the first RTT sample.
    srtt: Option<Duration>,
    /// RTT variance estimate; `None` until the first RTT sample.
    rttvar: Option<Duration>,
    /// Number of packets accepted by `on_send()`.
    total_sent: u64,
    /// Number of packets removed from flight by `on_ack()`.
    total_acked: u64,
    /// Number of timeouts detected (a packet that times out repeatedly counts each time).
    total_lost: u64,
    /// Number of retransmission requests issued.
    total_retransmits: u64,
}
109
110impl FlowController {
111    /// Create a new flow controller with the given maximum window size.
112    pub fn new(window_size: usize) -> Self {
113        debug!(window_size, "FlowController created");
114        FlowController {
115            window_size,
116            cwnd: 1.0,
117            ssthresh: window_size as f64 / 2.0,
118            in_slow_start: true,
119            in_flight: Vec::new(),
120            acked: BTreeSet::new(),
121            rto: Duration::from_millis(DEFAULT_RTO_MS),
122            srtt: None,
123            rttvar: None,
124            total_sent: 0,
125            total_acked: 0,
126            total_lost: 0,
127            total_retransmits: 0,
128        }
129    }
130
131    /// Create a flow controller with a custom retransmission timeout.
132    pub fn with_rto(window_size: usize, rto_ms: u64) -> Self {
133        let mut fc = Self::new(window_size);
134        fc.rto = Duration::from_millis(rto_ms);
135        fc
136    }
137
138    // ─── Window control ───────────────────────────────────────────────────────
139
140    /// Returns `true` if the effective window has space to send another packet.
141    pub fn can_send(&self) -> bool {
142        self.in_flight.len() < self.effective_window()
143    }
144
145    /// Returns how many more packets can be sent right now.
146    pub fn available_slots(&self) -> usize {
147        self.effective_window().saturating_sub(self.in_flight.len())
148    }
149
150    /// Returns the hard maximum window size.
151    pub fn window_size(&self) -> usize {
152        self.window_size
153    }
154
155    /// Returns the current congestion window size.
156    pub fn cwnd(&self) -> f64 {
157        self.cwnd
158    }
159
160    /// Returns the effective window: min(cwnd as usize, window_size), at least 1.
161    pub fn effective_window(&self) -> usize {
162        (self.cwnd as usize).min(self.window_size).max(1)
163    }
164
165    /// Returns `true` if currently in slow start phase.
166    pub fn in_slow_start(&self) -> bool {
167        self.in_slow_start
168    }
169
170    /// Dynamically adjust the hard maximum window size.
171    pub fn set_window_size(&mut self, size: usize) {
172        debug!(old = self.window_size, new = size, "Window size updated");
173        self.window_size = size;
174    }
175
176    /// Returns the number of packets currently in flight.
177    pub fn in_flight_count(&self) -> usize {
178        self.in_flight.len()
179    }
180
181    /// Returns the sequence number of the oldest unacknowledged packet.
182    pub fn oldest_unacked_sequence(&self) -> Option<u64> {
183        self.in_flight.first().map(|p| p.sequence)
184    }
185
186    // ─── Send / Ack ───────────────────────────────────────────────────────────
187
188    /// Register a packet as sent with its data payload (for retransmission).
189    ///
190    /// Returns `false` if the effective window is full.
191    pub fn on_send(&mut self, sequence: u64, data: Vec<u8>) -> bool {
192        if !self.can_send() {
193            warn!(
194                sequence,
195                in_flight = self.in_flight.len(),
196                cwnd = self.cwnd,
197                "on_send() called but window is full"
198            );
199            return false;
200        }
201        self.in_flight.push(InFlightPacket::new(sequence, data));
202        self.total_sent += 1;
203        debug!(
204            sequence,
205            in_flight = self.in_flight.len(),
206            cwnd = self.cwnd,
207            effective_window = self.effective_window(),
208            "Packet sent"
209        );
210        true
211    }
212
213    /// Register a packet as acknowledged.
214    ///
215    /// Updates RTT estimate and advances congestion window via AIMD.
216    /// Returns `true` if the packet was found and removed.
217    pub fn on_ack(&mut self, sequence: u64) -> bool {
218        if let Some(pos) = self.in_flight.iter().position(|p| p.sequence == sequence) {
219            let packet = self.in_flight.remove(pos);
220            let rtt = packet.sent_at.elapsed();
221            self.update_rtt(rtt);
222            self.acked.insert(sequence);
223            self.total_acked += 1;
224            self.on_ack_cwnd();
225            debug!(
226                sequence,
227                rtt_ms = rtt.as_millis(),
228                in_flight = self.in_flight.len(),
229                cwnd = self.cwnd,
230                in_slow_start = self.in_slow_start,
231                "Packet acked"
232            );
233            true
234        } else {
235            warn!(sequence, "on_ack() for unknown or duplicate sequence");
236            false
237        }
238    }
239
240    /// Returns all in-flight packets that have exceeded the retransmission timeout.
241    ///
242    /// Returns [`RetransmitRequest`]s with data to resend.
243    /// Triggers AIMD multiplicative decrease on loss.
244    pub fn timed_out_packets(&mut self) -> Vec<RetransmitRequest> {
245        let rto = self.rto;
246        let mut requests = Vec::new();
247        let mut had_loss = false;
248
249        for packet in self.in_flight.iter_mut() {
250            if packet.is_timed_out(rto) {
251                warn!(
252                    sequence = packet.sequence,
253                    retransmit_count = packet.retransmit_count,
254                    rto_ms = rto.as_millis(),
255                    "Packet timed out — queuing retransmission"
256                );
257                requests.push(RetransmitRequest {
258                    sequence: packet.sequence,
259                    data: packet.data.clone(),
260                    retransmit_count: packet.retransmit_count,
261                });
262                packet.retransmit_count += 1;
263                packet.sent_at = Instant::now();
264                self.total_lost += 1;
265                self.total_retransmits += 1;
266                had_loss = true;
267            }
268        }
269
270        if had_loss {
271            self.on_loss_cwnd();
272        }
273
274        requests
275    }
276
277    // ─── AIMD internals ───────────────────────────────────────────────────────
278
279    fn on_ack_cwnd(&mut self) {
280        if self.in_slow_start {
281            self.cwnd += AIMD_INCREASE_STEP;
282            if self.cwnd >= self.ssthresh {
283                self.in_slow_start = false;
284                info!(cwnd = self.cwnd, ssthresh = self.ssthresh, "Exiting slow start");
285            }
286        } else {
287            self.cwnd += AIMD_INCREASE_STEP / self.cwnd;
288        }
289        self.cwnd = self.cwnd.min(self.window_size as f64);
290        debug!(cwnd = self.cwnd, "AIMD: cwnd increased");
291    }
292
293    fn on_loss_cwnd(&mut self) {
294        self.ssthresh = (self.cwnd * AIMD_DECREASE_FACTOR).max(CWND_MIN);
295        self.cwnd = CWND_MIN;
296        self.in_slow_start = true;
297        self.rto = (self.rto * 2).min(Duration::from_secs(60));
298        warn!(
299            cwnd = self.cwnd,
300            ssthresh = self.ssthresh,
301            rto_ms = self.rto.as_millis(),
302            "AIMD: multiplicative decrease on loss"
303        );
304    }
305
306    // ─── RTT estimation (RFC 6298) ────────────────────────────────────────────
307
308    fn update_rtt(&mut self, rtt: Duration) {
309        match (self.srtt, self.rttvar) {
310            (None, None) => {
311                self.srtt = Some(rtt);
312                self.rttvar = Some(rtt / 2);
313            }
314            (Some(srtt), Some(rttvar)) => {
315                let rtt_ns = rtt.as_nanos() as i128;
316                let srtt_ns = srtt.as_nanos() as i128;
317                let rttvar_ns = rttvar.as_nanos() as i128;
318                let new_rttvar = (rttvar_ns * 3 / 4 + (srtt_ns - rtt_ns).abs() / 4).max(0) as u64;
319                let new_srtt = (srtt_ns * 7 / 8 + rtt_ns / 8).max(1) as u64;
320                self.rttvar = Some(Duration::from_nanos(new_rttvar));
321                self.srtt = Some(Duration::from_nanos(new_srtt));
322                let rto_ns = new_srtt + (new_rttvar * 4).max(1_000_000);
323                self.rto = Duration::from_nanos(rto_ns)
324                    .max(Duration::from_millis(50))
325                    .min(Duration::from_secs(60));
326            }
327            _ => {}
328        }
329    }
330
331    // ─── Stats ────────────────────────────────────────────────────────────────
332
333    /// Returns the current smoothed RTT estimate.
334    pub fn srtt(&self) -> Option<Duration> {
335        self.srtt
336    }
337
338    /// Returns the current RTT variance estimate.
339    pub fn rttvar(&self) -> Option<Duration> {
340        self.rttvar
341    }
342
343    /// Returns the current retransmission timeout.
344    pub fn rto(&self) -> Duration {
345        self.rto
346    }
347
348    /// Returns total packets sent.
349    pub fn total_sent(&self) -> u64 {
350        self.total_sent
351    }
352
353    /// Returns total packets acknowledged.
354    pub fn total_acked(&self) -> u64 {
355        self.total_acked
356    }
357
358    /// Returns total packets detected as lost.
359    pub fn total_lost(&self) -> u64 {
360        self.total_lost
361    }
362
363    /// Returns total retransmissions performed.
364    pub fn total_retransmits(&self) -> u64 {
365        self.total_retransmits
366    }
367
368    /// Returns the packet loss rate: lost / sent.
369    pub fn loss_rate(&self) -> f64 {
370        if self.total_sent == 0 { return 0.0; }
371        self.total_lost as f64 / self.total_sent as f64
372    }
373
374    /// Returns `true` if a sequence number has been acknowledged.
375    pub fn is_acked(&self, sequence: u64) -> bool {
376        self.acked.contains(&sequence)
377    }
378
379    /// Reset all state.
380    pub fn reset(&mut self) {
381        debug!("FlowController reset");
382        let window_size = self.window_size;
383        *self = Self::new(window_size);
384    }
385}
386
387impl Default for FlowController {
388    fn default() -> Self {
389        Self::new(64)
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Controller with a 1 ms RTO, so a short pause is enough to force a timeout.
    fn short_rto_controller() -> FlowController {
        FlowController::with_rto(4, 1)
    }

    /// Block the test thread for `ms` milliseconds.
    fn pause(ms: u64) {
        std::thread::sleep(Duration::from_millis(ms));
    }

    // A fresh controller is empty, in slow start, and offers exactly one slot.
    #[test]
    fn test_new() {
        let ctl = FlowController::new(4);
        assert_eq!(ctl.window_size(), 4);
        assert_eq!(ctl.in_flight_count(), 0);
        assert!(ctl.can_send());
        assert_eq!(ctl.available_slots(), 1);
        assert!(ctl.in_slow_start());
    }

    // With cwnd = 1 a single packet fills the effective window.
    #[test]
    fn test_window_full() {
        let mut ctl = FlowController::new(4);
        let accepted = ctl.on_send(0, vec![0]);
        assert!(accepted);
        assert!(!ctl.can_send());
        assert_eq!(ctl.in_flight_count(), 1);
    }

    // An ACK frees the slot and grows the congestion window.
    #[test]
    fn test_ack_opens_window_and_grows_cwnd() {
        let mut ctl = FlowController::new(4);
        assert!(ctl.on_send(0, vec![0]));
        assert!(!ctl.can_send());
        let before = ctl.cwnd();
        ctl.on_ack(0);
        let after = ctl.cwnd();
        assert!(after > before);
        assert!(ctl.can_send());
    }

    // ACKs for sequences that were never sent are rejected and change nothing.
    #[test]
    fn test_ack_unknown_sequence() {
        let mut ctl = FlowController::new(4);
        ctl.on_send(0, vec![0]);
        let found = ctl.on_ack(99);
        assert!(!found);
        assert_eq!(ctl.in_flight_count(), 1);
    }

    // `is_acked` flips only after the ACK arrives.
    #[test]
    fn test_is_acked() {
        let mut ctl = FlowController::new(4);
        ctl.on_send(0, vec![0]);
        assert!(!ctl.is_acked(0));
        ctl.on_ack(0);
        assert!(ctl.is_acked(0));
    }

    // Five send/ack round trips are all counted.
    #[test]
    fn test_stats() {
        let mut ctl = FlowController::new(10);
        for seq in 0..5 {
            if !ctl.can_send() {
                continue;
            }
            ctl.on_send(seq, vec![0]);
            ctl.on_ack(seq);
        }
        assert_eq!(ctl.total_acked(), 5);
    }

    // No traffic means a loss rate of exactly zero.
    #[test]
    fn test_loss_rate_zero() {
        let ctl = FlowController::new(4);
        assert_eq!(ctl.loss_rate(), 0.0);
    }

    // The hard window limit can be changed after construction.
    #[test]
    fn test_set_window_size() {
        let mut ctl = FlowController::new(4);
        ctl.set_window_size(8);
        assert_eq!(ctl.window_size(), 8);
    }

    // `reset` clears counters, RTT estimates and congestion state.
    #[test]
    fn test_reset() {
        let mut ctl = FlowController::new(4);
        ctl.on_send(0, vec![0]);
        ctl.on_ack(0);
        ctl.reset();
        assert_eq!(ctl.in_flight_count(), 0);
        assert_eq!(ctl.total_sent(), 0);
        assert_eq!(ctl.total_acked(), 0);
        assert!(ctl.srtt().is_none());
        assert!(ctl.in_slow_start());
        assert_eq!(ctl.cwnd(), 1.0);
    }

    // A timed-out packet comes back as a request carrying its payload.
    #[test]
    fn test_timed_out_packets_returns_retransmit_requests() {
        let mut ctl = short_rto_controller();
        ctl.on_send(0, b"hello".to_vec());
        pause(5);
        let pending = ctl.timed_out_packets();
        assert_eq!(pending.len(), 1);
        let first = &pending[0];
        assert_eq!(first.sequence, 0);
        assert_eq!(first.data, b"hello");
        assert_eq!(first.retransmit_count, 0);
        assert_eq!(ctl.total_lost(), 1);
        assert_eq!(ctl.total_retransmits(), 1);
    }

    // A loss collapses cwnd to 1 and restarts slow start.
    #[test]
    fn test_aimd_multiplicative_decrease_on_loss() {
        let mut ctl = short_rto_controller();
        ctl.on_send(0, vec![0]);
        pause(5);
        let pending = ctl.timed_out_packets();
        assert!(!pending.is_empty());
        assert_eq!(ctl.cwnd(), 1.0);
        assert!(ctl.in_slow_start());
        assert_eq!(ctl.total_lost(), 1);
    }

    // Repeated round trips eventually push cwnd to ssthresh and end slow start.
    #[test]
    fn test_slow_start_exits_at_ssthresh() {
        let mut ctl = FlowController::new(64);
        let threshold = ctl.ssthresh;
        let mut seq = 0u64;
        while ctl.in_slow_start() && seq <= 1000 {
            if ctl.can_send() {
                ctl.on_send(seq, vec![0]);
                ctl.on_ack(seq);
                seq += 1;
            }
        }
        assert!(!ctl.in_slow_start());
        assert!(ctl.cwnd() >= threshold);
    }

    // The first ACK produces both RTT estimates.
    #[test]
    fn test_srtt_updated_on_ack() {
        let mut ctl = FlowController::new(4);
        ctl.on_send(0, vec![0]);
        assert!(ctl.srtt().is_none());
        ctl.on_ack(0);
        assert!(ctl.srtt().is_some());
        assert!(ctl.rttvar().is_some());
    }

    // `Default` uses a window of 64.
    #[test]
    fn test_default() {
        let ctl = FlowController::default();
        assert_eq!(ctl.window_size(), 64);
    }

    // A second send into a full window is refused.
    #[test]
    fn test_on_send_full_window_returns_false() {
        let mut ctl = FlowController::new(4);
        let first = ctl.on_send(0, vec![0]);
        assert!(first);
        let second = ctl.on_send(1, vec![0]);
        assert!(!second);
    }

    // Ten round trips grow cwnd and are all acknowledged.
    #[test]
    fn test_multiple_acks_grow_cwnd() {
        let mut ctl = FlowController::new(64);
        let start = ctl.cwnd();
        for seq in 0..10u64 {
            if !ctl.can_send() {
                continue;
            }
            ctl.on_send(seq, vec![0]);
            ctl.on_ack(seq);
        }
        assert!(ctl.cwnd() > start);
        assert_eq!(ctl.total_acked(), 10);
    }

    // The oldest unacked sequence is the first packet still in flight.
    #[test]
    fn test_oldest_unacked_sequence() {
        let mut ctl = FlowController::new(4);
        assert!(ctl.oldest_unacked_sequence().is_none());
        ctl.on_send(5, vec![0]);
        assert_eq!(ctl.oldest_unacked_sequence(), Some(5));
    }

    // Initially cwnd (1) is the binding limit, not the hard window (4).
    #[test]
    fn test_effective_window_bounded_by_cwnd_and_max() {
        let ctl = FlowController::new(4);
        assert_eq!(ctl.effective_window(), 1);
    }

    // A detected loss backs the RTO off.
    #[test]
    fn test_rto_doubles_on_loss() {
        let mut ctl = short_rto_controller();
        let before = ctl.rto();
        ctl.on_send(0, vec![0]);
        pause(5);
        ctl.timed_out_packets();
        let after = ctl.rto();
        assert!(after > before);
    }

    // Each expiry of the same packet counts as another retransmission.
    #[test]
    fn test_total_retransmits() {
        let mut ctl = short_rto_controller();
        ctl.on_send(0, vec![0]);
        pause(5);
        ctl.timed_out_packets();
        assert_eq!(ctl.total_retransmits(), 1);
        pause(10);
        ctl.timed_out_packets();
        assert_eq!(ctl.total_retransmits(), 2);
    }
}