1use serde::{Deserialize, Serialize};
37use std::collections::VecDeque;
38use std::time::{Duration, Instant};
39
/// Tuning parameters for a bandwidth estimator.
#[derive(Debug, Clone)]
pub struct EstimatorConfig {
    /// EWMA smoothing factor: the weight given to each new sample, while the previous
    /// estimate keeps weight `1 - alpha`.
    pub alpha: f64,
    /// Maximum number of bandwidth samples (and, separately, RTT samples) retained.
    pub max_history: usize,
    /// Age limit in milliseconds; retained bandwidth samples older than this are pruned.
    pub congestion_window_ms: u64,
    /// Packet-loss percentage above which the link is considered at least lightly congested.
    pub loss_threshold_percent: f64,
    /// RTT coefficient-of-variation percentage above which the link is considered at least
    /// lightly congested.
    pub rtt_var_threshold_percent: f64,
    /// Number of retained samples required before the estimate is reported as reliable.
    pub min_samples: usize,
}
56
57impl Default for EstimatorConfig {
58 fn default() -> Self {
59 Self {
60 alpha: 0.2, max_history: 100,
62 congestion_window_ms: 1000, loss_threshold_percent: 5.0, rtt_var_threshold_percent: 50.0, min_samples: 5,
66 }
67 }
68}
69
/// A single recorded transfer, kept in the estimator's bounded history.
#[derive(Debug, Clone)]
// Only `timestamp` and `bandwidth_mbps` are currently read by the estimator; the other
// fields are stored but never read, which is what this allowance silences.
#[allow(dead_code)]
struct BandwidthSample {
    /// `Instant::now()` at the moment the sample was recorded.
    timestamp: Instant,
    /// Number of bytes moved by the transfer.
    bytes: u64,
    /// Transfer duration in milliseconds.
    duration_ms: u64,
    /// Throughput derived from `bytes` and `duration_ms`, in Mbit/s.
    bandwidth_mbps: f64,
    /// Round-trip time reported with the transfer, in milliseconds, if any.
    rtt_ms: Option<f64>,
    /// Whether packet loss was reported for the transfer.
    packet_loss: bool,
}
87
88#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
90pub enum CongestionState {
91 Normal,
93 Light,
95 Moderate,
97 Heavy,
99}
100
101pub struct BandwidthEstimator {
103 config: EstimatorConfig,
105 samples: VecDeque<BandwidthSample>,
107 estimate_mbps: f64,
109 congestion_state: CongestionState,
111 total_bytes: u64,
113 total_transfers: u64,
115 rtt_samples: VecDeque<f64>,
117 loss_count: u64,
119 packet_count: u64,
121}
122
123impl BandwidthEstimator {
124 #[must_use]
126 #[inline]
127 pub fn new(config: EstimatorConfig) -> Self {
128 Self {
129 config,
130 samples: VecDeque::new(),
131 estimate_mbps: 0.0,
132 congestion_state: CongestionState::Normal,
133 total_bytes: 0,
134 total_transfers: 0,
135 rtt_samples: VecDeque::new(),
136 loss_count: 0,
137 packet_count: 0,
138 }
139 }
140
141 pub fn record_transfer(&mut self, bytes: u64, duration_ms: u64) {
143 self.record_transfer_with_rtt(bytes, duration_ms, None, false);
144 }
145
146 pub fn record_transfer_with_rtt(
148 &mut self,
149 bytes: u64,
150 duration_ms: u64,
151 rtt_ms: Option<f64>,
152 packet_loss: bool,
153 ) {
154 if duration_ms == 0 {
155 return;
156 }
157
158 let bandwidth_mbps = (bytes as f64 * 8.0) / (duration_ms as f64 * 1000.0);
160
161 let sample = BandwidthSample {
162 timestamp: Instant::now(),
163 bytes,
164 duration_ms,
165 bandwidth_mbps,
166 rtt_ms,
167 packet_loss,
168 };
169
170 if self.estimate_mbps == 0.0 {
172 self.estimate_mbps = bandwidth_mbps;
173 } else {
174 self.estimate_mbps =
175 self.config.alpha * bandwidth_mbps + (1.0 - self.config.alpha) * self.estimate_mbps;
176 }
177
178 self.samples.push_back(sample);
180 if self.samples.len() > self.config.max_history {
181 self.samples.pop_front();
182 }
183
184 if let Some(rtt) = rtt_ms {
186 self.rtt_samples.push_back(rtt);
187 if self.rtt_samples.len() > self.config.max_history {
188 self.rtt_samples.pop_front();
189 }
190 }
191
192 self.packet_count += 1;
194 if packet_loss {
195 self.loss_count += 1;
196 }
197
198 self.total_bytes += bytes;
200 self.total_transfers += 1;
201
202 self.update_congestion_state();
204 }
205
206 #[must_use]
208 #[inline]
209 pub fn estimate_mbps(&self) -> f64 {
210 self.estimate_mbps
211 }
212
213 #[must_use]
215 #[inline]
216 pub fn estimate_bps(&self) -> u64 {
217 (self.estimate_mbps * 125_000.0) as u64
218 }
219
220 #[must_use]
222 #[inline]
223 pub fn is_reliable(&self) -> bool {
224 self.samples.len() >= self.config.min_samples
225 }
226
227 #[must_use]
229 #[inline]
230 pub const fn congestion_state(&self) -> CongestionState {
231 self.congestion_state
232 }
233
234 #[must_use]
236 #[inline]
237 pub const fn is_congested(&self) -> bool {
238 !matches!(self.congestion_state, CongestionState::Normal)
239 }
240
241 #[must_use]
243 #[inline]
244 pub fn packet_loss_percent(&self) -> f64 {
245 if self.packet_count == 0 {
246 0.0
247 } else {
248 (self.loss_count as f64 / self.packet_count as f64) * 100.0
249 }
250 }
251
252 #[must_use]
254 #[inline]
255 pub fn rtt_variation_percent(&self) -> f64 {
256 if self.rtt_samples.len() < 2 {
257 return 0.0;
258 }
259
260 let mean = self.rtt_samples.iter().sum::<f64>() / self.rtt_samples.len() as f64;
261 let variance = self
262 .rtt_samples
263 .iter()
264 .map(|x| (x - mean).powi(2))
265 .sum::<f64>()
266 / self.rtt_samples.len() as f64;
267
268 let std_dev = variance.sqrt();
269 if mean > 0.0 {
270 (std_dev / mean) * 100.0
271 } else {
272 0.0
273 }
274 }
275
276 #[must_use]
280 #[inline]
281 pub fn recommended_rate_bps(&self) -> u64 {
282 let base_rate = self.estimate_bps();
283
284 let reduction_factor = match self.congestion_state {
286 CongestionState::Normal => 1.0,
287 CongestionState::Light => 0.8, CongestionState::Moderate => 0.5, CongestionState::Heavy => 0.25, };
291
292 (base_rate as f64 * reduction_factor) as u64
293 }
294
295 #[must_use]
297 #[inline]
298 pub fn stats(&self) -> BandwidthStats {
299 let min_bw = self
300 .samples
301 .iter()
302 .map(|s| s.bandwidth_mbps)
303 .min_by(|a, b| a.partial_cmp(b).unwrap())
304 .unwrap_or(0.0);
305
306 let max_bw = self
307 .samples
308 .iter()
309 .map(|s| s.bandwidth_mbps)
310 .max_by(|a, b| a.partial_cmp(b).unwrap())
311 .unwrap_or(0.0);
312
313 let avg_rtt = if self.rtt_samples.is_empty() {
314 None
315 } else {
316 Some(self.rtt_samples.iter().sum::<f64>() / self.rtt_samples.len() as f64)
317 };
318
319 BandwidthStats {
320 current_estimate_mbps: self.estimate_mbps,
321 min_bandwidth_mbps: min_bw,
322 max_bandwidth_mbps: max_bw,
323 avg_rtt_ms: avg_rtt,
324 rtt_variation_percent: self.rtt_variation_percent(),
325 packet_loss_percent: self.packet_loss_percent(),
326 congestion_state: self.congestion_state,
327 sample_count: self.samples.len(),
328 is_reliable: self.is_reliable(),
329 total_bytes: self.total_bytes,
330 total_transfers: self.total_transfers,
331 }
332 }
333
334 pub fn reset(&mut self) {
336 self.samples.clear();
337 self.rtt_samples.clear();
338 self.estimate_mbps = 0.0;
339 self.congestion_state = CongestionState::Normal;
340 self.total_bytes = 0;
341 self.total_transfers = 0;
342 self.loss_count = 0;
343 self.packet_count = 0;
344 }
345
346 pub fn prune_old_samples(&mut self) {
348 let cutoff = Instant::now() - Duration::from_millis(self.config.congestion_window_ms);
349
350 while let Some(sample) = self.samples.front() {
351 if sample.timestamp < cutoff {
352 self.samples.pop_front();
353 } else {
354 break;
355 }
356 }
357 }
358
359 fn update_congestion_state(&mut self) {
361 self.prune_old_samples();
363
364 let loss_percent = self.packet_loss_percent();
365 let rtt_var_percent = self.rtt_variation_percent();
366
367 let loss_congested = loss_percent > self.config.loss_threshold_percent;
369 let rtt_congested = rtt_var_percent > self.config.rtt_var_threshold_percent;
370
371 self.congestion_state = if loss_percent > 15.0 || rtt_var_percent > 100.0 {
372 CongestionState::Heavy
373 } else if loss_percent > 10.0 || rtt_var_percent > 75.0 {
374 CongestionState::Moderate
375 } else if loss_congested || rtt_congested {
376 CongestionState::Light
377 } else {
378 CongestionState::Normal
379 };
380 }
381}
382
383#[derive(Debug, Clone, Serialize, Deserialize)]
385pub struct BandwidthStats {
386 pub current_estimate_mbps: f64,
388 pub min_bandwidth_mbps: f64,
390 pub max_bandwidth_mbps: f64,
392 pub avg_rtt_ms: Option<f64>,
394 pub rtt_variation_percent: f64,
396 pub packet_loss_percent: f64,
398 pub congestion_state: CongestionState,
400 pub sample_count: usize,
402 pub is_reliable: bool,
404 pub total_bytes: u64,
406 pub total_transfers: u64,
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an estimator with `EstimatorConfig::default()`.
    fn fresh() -> BandwidthEstimator {
        BandwidthEstimator::new(EstimatorConfig::default())
    }

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    fn assert_near(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    /// Records `count` identical 1 MiB / 100 ms transfers, each with a 50 ms RTT.
    fn record_mib_batch(estimator: &mut BandwidthEstimator, count: usize, packet_loss: bool) {
        for _ in 0..count {
            estimator.record_transfer_with_rtt(1024 * 1024, 100, Some(50.0), packet_loss);
        }
    }

    #[test]
    fn test_bandwidth_estimator() {
        let mut estimator = fresh();

        // 1,000,000 bytes in 100 ms is 80 Mbit/s.
        estimator.record_transfer(1_000_000, 100);

        assert_near(estimator.estimate_mbps(), 80.0, 1.0);
    }

    #[test]
    fn test_ewma_smoothing() {
        let mut estimator = BandwidthEstimator::new(EstimatorConfig {
            alpha: 0.5,
            ..Default::default()
        });

        // The first sample (100 Mbit/s) seeds the average unchanged.
        estimator.record_transfer(12_500_000, 1000);
        assert_near(estimator.estimate_mbps(), 100.0, 1.0);

        // 50 Mbit/s blended half-and-half with 100 Mbit/s gives 75 Mbit/s.
        estimator.record_transfer(6_250_000, 1000);
        assert_near(estimator.estimate_mbps(), 75.0, 1.0);
    }

    #[test]
    fn test_congestion_detection() {
        let mut estimator = fresh();

        record_mib_batch(&mut estimator, 10, false);
        assert_eq!(estimator.congestion_state(), CongestionState::Normal);

        record_mib_batch(&mut estimator, 10, true);
        assert!(estimator.is_congested());
    }

    #[test]
    fn test_packet_loss_calculation() {
        let mut estimator = fresh();

        // Three clean transfers followed by one lossy transfer: 25 % loss.
        for &lost in &[false, false, false, true] {
            estimator.record_transfer_with_rtt(1024, 10, None, lost);
        }

        assert_near(estimator.packet_loss_percent(), 25.0, 0.1);
    }

    #[test]
    fn test_recommended_rate() {
        let mut estimator = fresh();
        estimator.record_transfer(1024 * 1024, 100);
        let baseline = estimator.recommended_rate_bps();

        record_mib_batch(&mut estimator, 10, true);

        assert!(estimator.recommended_rate_bps() < baseline);
    }

    #[test]
    fn test_reliability() {
        let mut estimator = BandwidthEstimator::new(EstimatorConfig {
            min_samples: 3,
            ..Default::default()
        });
        assert!(!estimator.is_reliable());

        for _ in 0..2 {
            estimator.record_transfer(1024, 10);
        }
        assert!(!estimator.is_reliable());

        estimator.record_transfer(1024, 10);
        assert!(estimator.is_reliable());
    }

    #[test]
    fn test_reset() {
        let mut estimator = fresh();
        estimator.record_transfer(1024 * 1024, 100);
        assert!(estimator.estimate_mbps() > 0.0);

        estimator.reset();

        assert_eq!(estimator.estimate_mbps(), 0.0);
        assert_eq!(estimator.total_bytes, 0);
        assert_eq!(estimator.congestion_state(), CongestionState::Normal);
    }
}