bitfold_protocol/
congestion.rs1use std::time::{Duration, Instant};
2
/// Per-connection congestion state: a smoothed round-trip-time estimate,
/// packet loss counters for the current measurement window, and a throttle
/// value used to decide whether unreliable packets should be dropped.
#[derive(Debug, Clone)]
pub struct CongestionControl {
    /// Smoothed round-trip-time estimate, updated by `update_rtt`.
    rtt: Duration,
    /// Smoothed absolute deviation of RTT samples from the estimate.
    rtt_variance: Duration,
    /// Weight given to a new sample when blending it into `rtt`.
    rtt_alpha: f32,
    /// Weight given to a new deviation when blending it into `rtt_variance`.
    rtt_beta: f32,
    /// Losses recorded since the counters were last cleared.
    packets_lost: u32,
    /// Transmissions recorded since the counters were last cleared.
    packets_sent: u32,
    /// Current throttle level; `should_drop_unreliable` compares a random
    /// `f32` against it, so higher values drop more packets.
    throttle: f32,
    /// Lower bound applied when the simple throttle is relaxed.
    min_throttle: f32,
    /// Upper bound applied when the simple throttle is raised.
    max_throttle: f32,
    /// Instant of the last throttle re-evaluation.
    last_throttle_update: Instant,
    /// Minimum time between two throttle re-evaluations.
    throttle_interval: Duration,

    /// Selects the integer `packet_throttle` algorithm instead of the simple
    /// fixed-step one.
    use_advanced_throttling: bool,
    /// Upper bound of `packet_throttle`; `packet_throttle == throttle_scale`
    /// corresponds to a throttle of 0.0.
    throttle_scale: u32,
    /// Advanced-mode throttle counter; lowered on loss, raised when the link
    /// is clean, and mapped to `throttle` as `1 - packet_throttle / throttle_scale`.
    packet_throttle: u32,
    /// Amount added to `packet_throttle` after a clean interval.
    throttle_acceleration: u32,
    /// Amount subtracted from `packet_throttle` after a lossy interval.
    throttle_deceleration: u32,
}
41
42impl CongestionControl {
43 pub fn new(rtt_alpha: f32, rtt_beta: f32) -> Self {
45 Self {
46 rtt: Duration::from_millis(50), rtt_variance: Duration::from_millis(25),
48 rtt_alpha,
49 rtt_beta,
50 packets_lost: 0,
51 packets_sent: 0,
52 throttle: 0.0,
53 min_throttle: 0.0,
54 max_throttle: 1.0,
55 last_throttle_update: Instant::now(),
56 throttle_interval: Duration::from_secs(1),
57 use_advanced_throttling: false,
58 throttle_scale: 32, packet_throttle: 32, throttle_acceleration: 2, throttle_deceleration: 2, }
63 }
64
65 pub fn enable_advanced_throttling(
67 &mut self,
68 scale: u32,
69 acceleration: u32,
70 deceleration: u32,
71 interval_ms: u32,
72 ) {
73 self.use_advanced_throttling = true;
74 self.throttle_scale = scale;
75 self.packet_throttle = scale; self.throttle_acceleration = acceleration;
77 self.throttle_deceleration = deceleration;
78 self.throttle_interval = Duration::from_millis(interval_ms as u64);
79 }
80
81 pub fn update_rtt(&mut self, sample: Duration) {
84 let sample_ms = sample.as_millis() as f32;
85 let rtt_ms = self.rtt.as_millis() as f32;
86
87 let new_rtt_ms = (1.0 - self.rtt_alpha) * rtt_ms + self.rtt_alpha * sample_ms;
89 self.rtt = Duration::from_millis(new_rtt_ms as u64);
90
91 let diff = (rtt_ms - sample_ms).abs();
93 let var_ms = self.rtt_variance.as_millis() as f32;
94 let new_var_ms = (1.0 - self.rtt_beta) * var_ms + self.rtt_beta * diff;
95 self.rtt_variance = Duration::from_millis(new_var_ms as u64);
96 }
97
98 pub fn rtt(&self) -> Duration {
100 self.rtt
101 }
102
103 pub fn rtt_variance(&self) -> Duration {
105 self.rtt_variance
106 }
107
108 pub fn rto(&self) -> Duration {
111 self.rtt + Duration::from_millis(4 * self.rtt_variance.as_millis() as u64)
112 }
113
114 pub fn record_loss(&mut self) {
116 self.packets_lost += 1;
117 }
118
119 pub fn record_sent(&mut self) {
121 self.packets_sent += 1;
122 }
123
124 pub fn loss_rate(&self) -> f32 {
126 if self.packets_sent == 0 {
127 return 0.0;
128 }
129 self.packets_lost as f32 / self.packets_sent as f32
130 }
131
132 pub fn update_throttle(&mut self, now: Instant) -> bool {
135 if now.duration_since(self.last_throttle_update) < self.throttle_interval {
136 return false;
137 }
138
139 let loss_rate = self.loss_rate();
140
141 if self.use_advanced_throttling {
142 self.update_advanced_throttle(loss_rate);
144 } else {
145 self.update_simple_throttle(loss_rate);
147 }
148
149 self.packets_lost = 0;
151 self.packets_sent = 0;
152 self.last_throttle_update = now;
153
154 true
155 }
156
157 fn update_simple_throttle(&mut self, loss_rate: f32) {
159 if loss_rate > 0.05 {
161 self.throttle = (self.throttle + 0.1).min(self.max_throttle);
163 } else if loss_rate < 0.01 && self.throttle > self.min_throttle {
164 self.throttle = (self.throttle - 0.05).max(self.min_throttle);
166 }
167 }
168
169 fn update_advanced_throttle(&mut self, loss_rate: f32) {
171 if loss_rate > 0.01 {
175 if self.packet_throttle > self.throttle_deceleration {
178 self.packet_throttle -= self.throttle_deceleration;
179 } else {
180 self.packet_throttle = 0;
181 }
182 } else if loss_rate < 0.005 && self.packet_throttle < self.throttle_scale {
183 self.packet_throttle =
186 (self.packet_throttle + self.throttle_acceleration).min(self.throttle_scale);
187 }
188
189 self.throttle = 1.0 - (self.packet_throttle as f32 / self.throttle_scale as f32);
193 }
194
195 pub fn should_drop_unreliable(&self) -> bool {
198 if self.throttle == 0.0 {
199 return false;
200 }
201 rand::random::<f32>() < self.throttle
202 }
203
204 pub fn throttle(&self) -> f32 {
206 self.throttle
207 }
208
209 pub fn packet_throttle(&self) -> u32 {
211 self.packet_throttle
212 }
213
214 pub fn throttle_scale(&self) -> u32 {
216 self.throttle_scale
217 }
218
219 pub fn is_advanced_throttling_enabled(&self) -> bool {
221 self.use_advanced_throttling
222 }
223
224 pub fn set_throttle_range(&mut self, min: f32, max: f32) {
226 self.min_throttle = min.clamp(0.0, 1.0);
227 self.max_throttle = max.clamp(0.0, 1.0);
228 }
229
230 pub fn configure_throttle(&mut self, interval_ms: u32, acceleration: u32, deceleration: u32) {
232 self.throttle_interval = Duration::from_millis(interval_ms as u64);
233 self.throttle_acceleration = acceleration;
234 self.throttle_deceleration = deceleration;
235 }
236
237 pub fn reset_stats(&mut self) {
239 self.packets_lost = 0;
240 self.packets_sent = 0;
241 }
242}
243
244impl Default for CongestionControl {
245 fn default() -> Self {
246 Self::new(0.1, 0.25)
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Records `sent` transmissions followed by `lost` losses on `cc`.
    fn record_traffic(cc: &mut CongestionControl, sent: u32, lost: u32) {
        for _ in 0..sent {
            cc.record_sent();
        }
        for _ in 0..lost {
            cc.record_loss();
        }
    }

    /// Returns an instant `ms` milliseconds after `base`.
    ///
    /// `update_throttle` takes the current time as an argument, so the tests
    /// advance a simulated clock instead of sleeping.
    fn after(base: Instant, ms: u64) -> Instant {
        base + Duration::from_millis(ms)
    }

    #[test]
    fn test_rtt_update() {
        let mut cc = CongestionControl::default();

        // The estimate moves towards the sample...
        cc.update_rtt(Duration::from_millis(100));
        assert!(cc.rtt() > Duration::from_millis(50));

        // ...but is smoothed, so it does not jump straight to it.
        cc.update_rtt(Duration::from_millis(100));
        assert!(cc.rtt() < Duration::from_millis(100));
    }

    #[test]
    fn test_rto_calculation() {
        let mut cc = CongestionControl::default();

        cc.update_rtt(Duration::from_millis(100));
        let rto = cc.rto();

        assert!(rto > cc.rtt());
    }

    #[test]
    fn test_loss_rate() {
        let mut cc = CongestionControl::default();

        assert_eq!(cc.loss_rate(), 0.0);

        record_traffic(&mut cc, 2, 1);

        assert!((cc.loss_rate() - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_throttle_increases_with_loss() {
        let mut cc = CongestionControl::default();

        // 10% loss.
        record_traffic(&mut cc, 100, 10);

        // The default interval is one second.
        let later = after(Instant::now(), 1100);

        let updated = cc.update_throttle(later);
        assert!(updated);
        assert!(cc.throttle() > 0.0);
    }

    #[test]
    fn test_advanced_throttling_enabled() {
        let mut cc = CongestionControl::default();

        assert!(!cc.is_advanced_throttling_enabled());

        cc.enable_advanced_throttling(32, 2, 2, 5000);

        assert!(cc.is_advanced_throttling_enabled());
        assert_eq!(cc.throttle_scale(), 32);
        assert_eq!(cc.packet_throttle(), 32);
    }

    #[test]
    fn test_advanced_throttle_decreases_with_loss() {
        let mut cc = CongestionControl::default();
        cc.enable_advanced_throttling(32, 2, 2, 100);

        let initial_throttle = cc.packet_throttle();
        assert_eq!(initial_throttle, 32);

        // 2% loss.
        record_traffic(&mut cc, 100, 2);

        let later = after(Instant::now(), 150);

        let updated = cc.update_throttle(later);
        assert!(updated);

        assert!(cc.packet_throttle() < initial_throttle);
        assert!(cc.throttle() > 0.0);
    }

    #[test]
    fn test_advanced_throttle_increases_with_good_conditions() {
        let mut cc = CongestionControl::default();
        cc.enable_advanced_throttling(32, 2, 2, 100);
        let start = Instant::now();

        // First a lossy interval to push the counter down...
        record_traffic(&mut cc, 100, 2);
        assert!(cc.update_throttle(after(start, 150)));

        let throttled_value = cc.packet_throttle();
        assert!(throttled_value < 32);

        // ...then a clean interval, which must let it recover.
        record_traffic(&mut cc, 1000, 0);
        assert!(cc.update_throttle(after(start, 300)));

        assert!(cc.packet_throttle() > throttled_value);
    }

    #[test]
    fn test_advanced_throttle_respects_scale() {
        let mut cc = CongestionControl::default();
        cc.enable_advanced_throttling(64, 5, 5, 100);

        assert_eq!(cc.throttle_scale(), 64);
        assert_eq!(cc.packet_throttle(), 64);

        let mut now = Instant::now();
        for _round in 0..20 {
            // 5% loss in every interval.
            record_traffic(&mut cc, 100, 5);
            now = after(now, 150);
            cc.update_throttle(now);
        }

        assert!(cc.packet_throttle() <= 64);
        assert!(cc.throttle() > 0.5);
    }

    #[test]
    fn test_configure_throttle_dynamically() {
        let mut cc = CongestionControl::default();
        cc.enable_advanced_throttling(32, 2, 2, 5000);

        // Shorten the interval from 5 s to 1 s.
        cc.configure_throttle(1000, 5, 3);

        record_traffic(&mut cc, 50, 1);

        assert!(cc.update_throttle(after(Instant::now(), 1100)));
    }
}