bitfold_protocol/
congestion.rs

1use std::time::{Duration, Instant};
2
/// Congestion control and RTT tracking with dynamic throttle.
///
/// Keeps a smoothed round-trip-time estimate plus its variance, counts packets sent
/// and lost between throttle updates, and turns the observed loss into a `throttle`
/// value that is used as the probability of dropping unreliable packets.
///
/// Two throttling modes exist: the default "simple" mode steps the `throttle` float
/// directly, while the opt-in "advanced" mode moves an integer `packet_throttle`
/// between 0 and `throttle_scale` and derives `throttle` from it.
#[derive(Clone, Debug)]
pub struct CongestionControl {
    /// Smoothed round-trip-time estimate.
    rtt: Duration,
    /// Smoothed deviation of RTT samples from the estimate.
    rtt_variance: Duration,
    /// EWMA weight given to each new RTT sample (the `Default` impl uses 0.1).
    rtt_alpha: f32,
    /// EWMA weight given to each new RTT deviation (the `Default` impl uses 0.25).
    rtt_beta: f32,
    /// Packets reported lost since the counters were last reset.
    packets_lost: u32,
    /// Packets reported sent since the counters were last reset.
    packets_sent: u32,
    /// Probability, in `0.0..=1.0`, of dropping an unreliable packet.
    throttle: f32,
    /// Lower bound applied to `throttle` by the simple throttling mode.
    min_throttle: f32,
    /// Upper bound applied to `throttle` by the simple throttling mode.
    max_throttle: f32,
    /// Time at which the throttle was last recomputed.
    last_throttle_update: Instant,
    /// Minimum time between two throttle recomputations.
    throttle_interval: Duration,

    // --- Advanced throttling state ---
    /// Whether advanced (integer-step) throttling replaces the simple mode.
    use_advanced_throttling: bool,
    /// Upper end of the `packet_throttle` range (32 unless configured otherwise).
    throttle_scale: u32,
    /// Current advanced throttle in `0..=throttle_scale`; higher means fewer drops.
    packet_throttle: u32,
    /// Step by which `packet_throttle` rises while loss is low.
    throttle_acceleration: u32,
    /// Step by which `packet_throttle` falls while loss is high.
    throttle_deceleration: u32,
}
41
42impl CongestionControl {
43    /// Creates a new congestion control instance.
44    pub fn new(rtt_alpha: f32, rtt_beta: f32) -> Self {
45        Self {
46            rtt: Duration::from_millis(50), // Initial estimate
47            rtt_variance: Duration::from_millis(25),
48            rtt_alpha,
49            rtt_beta,
50            packets_lost: 0,
51            packets_sent: 0,
52            throttle: 0.0,
53            min_throttle: 0.0,
54            max_throttle: 1.0,
55            last_throttle_update: Instant::now(),
56            throttle_interval: Duration::from_secs(1),
57            use_advanced_throttling: false,
58            throttle_scale: 32,       // Default scale
59            packet_throttle: 32,      // Start at maximum (no throttling)
60            throttle_acceleration: 2, // Default acceleration
61            throttle_deceleration: 2, // Default deceleration
62        }
63    }
64
65    /// Enables advanced throttling with custom parameters.
66    pub fn enable_advanced_throttling(
67        &mut self,
68        scale: u32,
69        acceleration: u32,
70        deceleration: u32,
71        interval_ms: u32,
72    ) {
73        self.use_advanced_throttling = true;
74        self.throttle_scale = scale;
75        self.packet_throttle = scale; // Start at maximum (no throttling)
76        self.throttle_acceleration = acceleration;
77        self.throttle_deceleration = deceleration;
78        self.throttle_interval = Duration::from_millis(interval_ms as u64);
79    }
80
81    /// Updates RTT measurement with a new sample.
82    /// Uses exponential weighted moving average (EWMA).
83    pub fn update_rtt(&mut self, sample: Duration) {
84        let sample_ms = sample.as_millis() as f32;
85        let rtt_ms = self.rtt.as_millis() as f32;
86
87        // EWMA for RTT: RTT = (1 - α) * RTT + α * sample
88        let new_rtt_ms = (1.0 - self.rtt_alpha) * rtt_ms + self.rtt_alpha * sample_ms;
89        self.rtt = Duration::from_millis(new_rtt_ms as u64);
90
91        // Update variance: Var = (1 - β) * Var + β * |RTT - sample|
92        let diff = (rtt_ms - sample_ms).abs();
93        let var_ms = self.rtt_variance.as_millis() as f32;
94        let new_var_ms = (1.0 - self.rtt_beta) * var_ms + self.rtt_beta * diff;
95        self.rtt_variance = Duration::from_millis(new_var_ms as u64);
96    }
97
98    /// Returns the current smoothed RTT.
99    pub fn rtt(&self) -> Duration {
100        self.rtt
101    }
102
103    /// Returns the RTT variance.
104    pub fn rtt_variance(&self) -> Duration {
105        self.rtt_variance
106    }
107
108    /// Returns the retransmission timeout (RTO) based on RTT.
109    /// Uses the standard RTO = RTT + 4 * variance formula.
110    pub fn rto(&self) -> Duration {
111        self.rtt + Duration::from_millis(4 * self.rtt_variance.as_millis() as u64)
112    }
113
114    /// Records a packet loss event.
115    pub fn record_loss(&mut self) {
116        self.packets_lost += 1;
117    }
118
119    /// Records a packet send event.
120    pub fn record_sent(&mut self) {
121        self.packets_sent += 1;
122    }
123
124    /// Returns the packet loss rate (0.0 to 1.0).
125    pub fn loss_rate(&self) -> f32 {
126        if self.packets_sent == 0 {
127            return 0.0;
128        }
129        self.packets_lost as f32 / self.packets_sent as f32
130    }
131
132    /// Updates the dynamic throttle based on current network conditions.
133    /// Returns true if throttle was updated.
134    pub fn update_throttle(&mut self, now: Instant) -> bool {
135        if now.duration_since(self.last_throttle_update) < self.throttle_interval {
136            return false;
137        }
138
139        let loss_rate = self.loss_rate();
140
141        if self.use_advanced_throttling {
142            // Advanced throttling with acceleration/deceleration
143            self.update_advanced_throttle(loss_rate);
144        } else {
145            // Simple throttling (backward compatible)
146            self.update_simple_throttle(loss_rate);
147        }
148
149        // Reset counters
150        self.packets_lost = 0;
151        self.packets_sent = 0;
152        self.last_throttle_update = now;
153
154        true
155    }
156
157    /// Simple throttle update (original implementation).
158    fn update_simple_throttle(&mut self, loss_rate: f32) {
159        // Increase throttle if packet loss is high
160        if loss_rate > 0.05 {
161            // More than 5% loss
162            self.throttle = (self.throttle + 0.1).min(self.max_throttle);
163        } else if loss_rate < 0.01 && self.throttle > self.min_throttle {
164            // Less than 1% loss, decrease throttle
165            self.throttle = (self.throttle - 0.05).max(self.min_throttle);
166        }
167    }
168
169    /// Advanced throttle update with acceleration/deceleration.
170    fn update_advanced_throttle(&mut self, loss_rate: f32) {
171        // packet_throttle ranges from 0 (drop everything) to throttle_scale (drop nothing)
172        // Higher packet_throttle = less throttling = better conditions
173
174        if loss_rate > 0.01 {
175            // Packet loss detected - decrease packet_throttle (increase throttling)
176            // Use deceleration to control how fast we throttle
177            if self.packet_throttle > self.throttle_deceleration {
178                self.packet_throttle -= self.throttle_deceleration;
179            } else {
180                self.packet_throttle = 0;
181            }
182        } else if loss_rate < 0.005 && self.packet_throttle < self.throttle_scale {
183            // Low/no packet loss - increase packet_throttle (decrease throttling)
184            // Use acceleration to control how fast we recover
185            self.packet_throttle =
186                (self.packet_throttle + self.throttle_acceleration).min(self.throttle_scale);
187        }
188
189        // Convert packet_throttle to 0.0-1.0 throttle value
190        // packet_throttle=scale means no throttling (throttle=0.0)
191        // packet_throttle=0 means maximum throttling (throttle=1.0)
192        self.throttle = 1.0 - (self.packet_throttle as f32 / self.throttle_scale as f32);
193    }
194
195    /// Returns whether an unreliable packet should be dropped based on throttle.
196    /// Uses throttle as drop probability.
197    pub fn should_drop_unreliable(&self) -> bool {
198        if self.throttle == 0.0 {
199            return false;
200        }
201        rand::random::<f32>() < self.throttle
202    }
203
204    /// Returns the current throttle value.
205    pub fn throttle(&self) -> f32 {
206        self.throttle
207    }
208
209    /// Returns the current packet throttle value (0 to throttle_scale).
210    pub fn packet_throttle(&self) -> u32 {
211        self.packet_throttle
212    }
213
214    /// Returns the throttle scale.
215    pub fn throttle_scale(&self) -> u32 {
216        self.throttle_scale
217    }
218
219    /// Returns whether advanced throttling is enabled.
220    pub fn is_advanced_throttling_enabled(&self) -> bool {
221        self.use_advanced_throttling
222    }
223
224    /// Sets the throttle range.
225    pub fn set_throttle_range(&mut self, min: f32, max: f32) {
226        self.min_throttle = min.clamp(0.0, 1.0);
227        self.max_throttle = max.clamp(0.0, 1.0);
228    }
229
230    /// Configures throttle parameters dynamically (for ThrottleConfigure command).
231    pub fn configure_throttle(&mut self, interval_ms: u32, acceleration: u32, deceleration: u32) {
232        self.throttle_interval = Duration::from_millis(interval_ms as u64);
233        self.throttle_acceleration = acceleration;
234        self.throttle_deceleration = deceleration;
235    }
236
237    /// Resets all statistics.
238    pub fn reset_stats(&mut self) {
239        self.packets_lost = 0;
240        self.packets_sent = 0;
241    }
242}
243
244impl Default for CongestionControl {
245    fn default() -> Self {
246        Self::new(0.1, 0.25)
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    // `update_throttle` takes the current time as a parameter, so these tests drive
    // a synthetic clock (`start + Duration`) instead of sleeping in real time.

    #[test]
    fn test_rtt_update() {
        let mut cc = CongestionControl::default();

        cc.update_rtt(Duration::from_millis(100));
        assert!(cc.rtt() > Duration::from_millis(50)); // Should increase from initial

        cc.update_rtt(Duration::from_millis(100));
        assert!(cc.rtt() < Duration::from_millis(100)); // Smoothed, so less than sample
    }

    #[test]
    fn test_rto_calculation() {
        let mut cc = CongestionControl::default();

        cc.update_rtt(Duration::from_millis(100));
        let rto = cc.rto();

        // RTO is RTT + 4 * variance, and the variance is non-zero here.
        assert!(rto > cc.rtt());
    }

    #[test]
    fn test_loss_rate() {
        let mut cc = CongestionControl::default();

        assert_eq!(cc.loss_rate(), 0.0);

        cc.record_sent();
        cc.record_sent();
        cc.record_loss();

        assert!((cc.loss_rate() - 0.5).abs() < 0.01); // 1 lost out of 2 sent = 50%
    }

    #[test]
    fn test_throttle_not_updated_before_interval() {
        let mut cc = CongestionControl::default();
        cc.record_sent();
        cc.record_loss();

        // Well under the default 1 s interval: nothing changes, counters are kept.
        assert!(!cc.update_throttle(Instant::now()));
        assert_eq!(cc.throttle(), 0.0);
        assert!((cc.loss_rate() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_throttle_increases_with_loss() {
        let mut cc = CongestionControl::default();
        let start = Instant::now();

        // Simulate high packet loss: 10 of 100 packets lost (10%).
        for _ in 0..100 {
            cc.record_sent();
        }
        for _ in 0..10 {
            cc.record_loss();
        }

        assert!(cc.update_throttle(start + Duration::from_millis(1100)));
        // Above the 5% threshold the simple throttle rises by one 0.1 step.
        assert!((cc.throttle() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_advanced_throttling_enabled() {
        let mut cc = CongestionControl::default();

        assert!(!cc.is_advanced_throttling_enabled());

        cc.enable_advanced_throttling(32, 2, 2, 5000);

        assert!(cc.is_advanced_throttling_enabled());
        assert_eq!(cc.throttle_scale(), 32);
        assert_eq!(cc.packet_throttle(), 32); // Starts at max (no throttling)
    }

    #[test]
    fn test_advanced_throttle_decreases_with_loss() {
        let mut cc = CongestionControl::default();
        cc.enable_advanced_throttling(32, 2, 2, 100);
        let start = Instant::now();

        assert_eq!(cc.packet_throttle(), 32); // Start at max

        // Simulate moderate packet loss (2%), above the 1% advanced threshold.
        for _ in 0..100 {
            cc.record_sent();
        }
        for _ in 0..2 {
            cc.record_loss();
        }

        assert!(cc.update_throttle(start + Duration::from_millis(150)));

        // packet_throttle drops by the deceleration step (more throttling) ...
        assert_eq!(cc.packet_throttle(), 30);
        // ... and the 0.0-1.0 throttle becomes 1 - 30/32.
        assert!((cc.throttle() - 0.0625).abs() < 1e-6);
    }

    #[test]
    fn test_advanced_throttle_increases_with_good_conditions() {
        let mut cc = CongestionControl::default();
        cc.enable_advanced_throttling(32, 2, 2, 100);
        let start = Instant::now();

        // First, decrease throttle by simulating 2% loss.
        for _ in 0..100 {
            cc.record_sent();
        }
        for _ in 0..2 {
            cc.record_loss();
        }
        assert!(cc.update_throttle(start + Duration::from_millis(150)));

        let throttled_value = cc.packet_throttle();
        assert!(throttled_value < 32);

        // Now simulate good conditions. The previous update reset the counters, so
        // this interval has 1000 sends and no losses (0% loss rate).
        for _ in 0..1000 {
            cc.record_sent();
        }
        assert!(cc.update_throttle(start + Duration::from_millis(300)));

        // packet_throttle recovers by the acceleration step (less throttling).
        assert!(cc.packet_throttle() > throttled_value);
        assert_eq!(cc.packet_throttle(), 32);
    }

    #[test]
    fn test_advanced_throttle_respects_scale() {
        let mut cc = CongestionControl::default();
        cc.enable_advanced_throttling(64, 5, 5, 100); // Higher scale and rates
        let mut now = Instant::now();

        assert_eq!(cc.throttle_scale(), 64);
        assert_eq!(cc.packet_throttle(), 64);

        // Simulate sustained 5% packet loss for 20 intervals.
        for _round in 0..20 {
            for _ in 0..100 {
                cc.record_sent();
            }
            for _ in 0..5 {
                cc.record_loss();
            }

            now += Duration::from_millis(150);
            assert!(cc.update_throttle(now));
        }

        // 64 drops by 5 per interval and bottoms out at 0 (no underflow) ...
        assert_eq!(cc.packet_throttle(), 0);
        // ... which means full throttling.
        assert!((cc.throttle() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_configure_throttle_dynamically() {
        let mut cc = CongestionControl::default();
        cc.enable_advanced_throttling(32, 2, 2, 5000);
        cc.configure_throttle(1000, 5, 3);
        let start = Instant::now();

        // 1 of 50 packets lost (2%), above the 1% advanced threshold.
        for _ in 0..50 {
            cc.record_sent();
        }
        cc.record_loss();

        // The new 1 s interval applies (the old 5 s one would still be waiting) ...
        let first = start + Duration::from_millis(1000);
        assert!(cc.update_throttle(first));
        // ... and the new deceleration (3) is used.
        assert_eq!(cc.packet_throttle(), 29);

        // Just short of another full interval, nothing happens.
        assert!(!cc.update_throttle(first + Duration::from_millis(999)));

        // Once it elapses, the loss-free interval recovers by the new acceleration
        // (5), capped at the scale.
        assert!(cc.update_throttle(first + Duration::from_millis(1000)));
        assert_eq!(cc.packet_throttle(), 32);
    }
}