chie_core/
bandwidth_estimation.rs

1//! Bandwidth estimation and congestion detection for the CHIE protocol.
2//!
3//! This module provides real-time bandwidth estimation and congestion detection
4//! capabilities to optimize content delivery and prevent network overload.
5//!
6//! # Features
7//!
8//! - Real-time bandwidth estimation using exponentially weighted moving averages
9//! - Congestion detection through packet loss and latency variation
//! - Congestion-aware recommended transfer rates derived from network conditions
11//! - Historical bandwidth tracking and statistics
12//!
13//! # Example
14//!
15//! ```rust
16//! use chie_core::bandwidth_estimation::{BandwidthEstimator, EstimatorConfig};
17//!
18//! # async fn example() {
19//! let config = EstimatorConfig::default();
20//! let mut estimator = BandwidthEstimator::new(config);
21//!
22//! // Record a data transfer
23//! estimator.record_transfer(1024 * 1024, 100); // 1 MB in 100ms
24//!
25//! // Get current estimate
26//! let bandwidth_mbps = estimator.estimate_mbps();
27//! println!("Estimated bandwidth: {:.2} Mbps", bandwidth_mbps);
28//!
29//! // Check for congestion
30//! if estimator.is_congested() {
31//!     println!("Network is congested, reducing rate");
32//! }
33//! # }
34//! ```
35
36use serde::{Deserialize, Serialize};
37use std::collections::VecDeque;
38use std::time::{Duration, Instant};
39
/// Tuning parameters for [`BandwidthEstimator`].
///
/// Start from [`EstimatorConfig::default`] and override individual fields with
/// struct-update syntax, e.g. `EstimatorConfig { alpha: 0.5, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct EstimatorConfig {
    /// Weight given to each new sample in the exponentially weighted moving
    /// average. Expected to lie in `0.0..=1.0`; this is not validated.
    /// Default: `0.2`.
    pub alpha: f64,
    /// Maximum number of samples (and RTT measurements) retained.
    /// Default: `100`.
    pub max_history: usize,
    /// Age in milliseconds beyond which samples are pruned from the history.
    /// Default: `1000`.
    pub congestion_window_ms: u64,
    /// Packet-loss percentage above which at least `Light` congestion is
    /// reported. The `Moderate` and `Heavy` levels use fixed thresholds.
    /// Default: `5.0`.
    pub loss_threshold_percent: f64,
    /// RTT variation (standard deviation as a percentage of the mean) above
    /// which at least `Light` congestion is reported. The `Moderate` and
    /// `Heavy` levels use fixed thresholds. Default: `50.0`.
    pub rtt_var_threshold_percent: f64,
    /// Number of retained samples required before the estimate is considered
    /// reliable. Default: `5`.
    pub min_samples: usize,
}

impl Default for EstimatorConfig {
    /// Returns the defaults documented on each field.
    fn default() -> Self {
        Self {
            min_samples: 5,
            max_history: 100,
            // New samples carry 20% of the weight in the moving average.
            alpha: 0.2,
            // Congestion is judged over roughly the last second of samples.
            congestion_window_ms: 1_000,
            // 5% loss or 50% RTT variation starts to count as congestion.
            loss_threshold_percent: 5.0,
            rtt_var_threshold_percent: 50.0,
        }
    }
}
69
/// One recorded transfer together with the throughput derived from it.
// Not every field is read by the estimator (e.g. `bytes` and `duration_ms`);
// the allow keeps the complete measurement without dead-code warnings.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct BandwidthSample {
    /// Moment the sample was recorded.
    timestamp: Instant,
    /// Number of bytes moved by the transfer.
    bytes: u64,
    /// How long the transfer took, in milliseconds.
    duration_ms: u64,
    /// Throughput of this single transfer, in megabits per second.
    bandwidth_mbps: f64,
    /// Round-trip time in milliseconds, when the caller supplied one.
    rtt_ms: Option<f64>,
    /// `true` if the caller reported packet loss for this transfer.
    packet_loss: bool,
}
87
88/// Congestion state.
89#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
90pub enum CongestionState {
91    /// No congestion detected.
92    Normal,
93    /// Light congestion detected.
94    Light,
95    /// Moderate congestion detected.
96    Moderate,
97    /// Heavy congestion detected.
98    Heavy,
99}
100
101/// Bandwidth estimator with congestion detection.
102pub struct BandwidthEstimator {
103    /// Configuration.
104    config: EstimatorConfig,
105    /// Sample history.
106    samples: VecDeque<BandwidthSample>,
107    /// Current bandwidth estimate (EWMA).
108    estimate_mbps: f64,
109    /// Current congestion state.
110    congestion_state: CongestionState,
111    /// Total bytes transferred.
112    total_bytes: u64,
113    /// Total transfers.
114    total_transfers: u64,
115    /// Recent RTT measurements.
116    rtt_samples: VecDeque<f64>,
117    /// Packet loss count in current window.
118    loss_count: u64,
119    /// Total packet count in current window.
120    packet_count: u64,
121}
122
123impl BandwidthEstimator {
124    /// Create a new bandwidth estimator.
125    #[must_use]
126    #[inline]
127    pub fn new(config: EstimatorConfig) -> Self {
128        Self {
129            config,
130            samples: VecDeque::new(),
131            estimate_mbps: 0.0,
132            congestion_state: CongestionState::Normal,
133            total_bytes: 0,
134            total_transfers: 0,
135            rtt_samples: VecDeque::new(),
136            loss_count: 0,
137            packet_count: 0,
138        }
139    }
140
141    /// Record a data transfer.
142    pub fn record_transfer(&mut self, bytes: u64, duration_ms: u64) {
143        self.record_transfer_with_rtt(bytes, duration_ms, None, false);
144    }
145
146    /// Record a data transfer with RTT information.
147    pub fn record_transfer_with_rtt(
148        &mut self,
149        bytes: u64,
150        duration_ms: u64,
151        rtt_ms: Option<f64>,
152        packet_loss: bool,
153    ) {
154        if duration_ms == 0 {
155            return;
156        }
157
158        // Calculate instantaneous bandwidth
159        let bandwidth_mbps = (bytes as f64 * 8.0) / (duration_ms as f64 * 1000.0);
160
161        let sample = BandwidthSample {
162            timestamp: Instant::now(),
163            bytes,
164            duration_ms,
165            bandwidth_mbps,
166            rtt_ms,
167            packet_loss,
168        };
169
170        // Update EWMA
171        if self.estimate_mbps == 0.0 {
172            self.estimate_mbps = bandwidth_mbps;
173        } else {
174            self.estimate_mbps =
175                self.config.alpha * bandwidth_mbps + (1.0 - self.config.alpha) * self.estimate_mbps;
176        }
177
178        // Add to history
179        self.samples.push_back(sample);
180        if self.samples.len() > self.config.max_history {
181            self.samples.pop_front();
182        }
183
184        // Track RTT
185        if let Some(rtt) = rtt_ms {
186            self.rtt_samples.push_back(rtt);
187            if self.rtt_samples.len() > self.config.max_history {
188                self.rtt_samples.pop_front();
189            }
190        }
191
192        // Track packets
193        self.packet_count += 1;
194        if packet_loss {
195            self.loss_count += 1;
196        }
197
198        // Update totals
199        self.total_bytes += bytes;
200        self.total_transfers += 1;
201
202        // Update congestion state
203        self.update_congestion_state();
204    }
205
206    /// Get current bandwidth estimate in Mbps.
207    #[must_use]
208    #[inline]
209    pub fn estimate_mbps(&self) -> f64 {
210        self.estimate_mbps
211    }
212
213    /// Get current bandwidth estimate in bytes per second.
214    #[must_use]
215    #[inline]
216    pub fn estimate_bps(&self) -> u64 {
217        (self.estimate_mbps * 125_000.0) as u64
218    }
219
220    /// Check if bandwidth estimate is reliable.
221    #[must_use]
222    #[inline]
223    pub fn is_reliable(&self) -> bool {
224        self.samples.len() >= self.config.min_samples
225    }
226
227    /// Get current congestion state.
228    #[must_use]
229    #[inline]
230    pub const fn congestion_state(&self) -> CongestionState {
231        self.congestion_state
232    }
233
234    /// Check if network is currently congested.
235    #[must_use]
236    #[inline]
237    pub const fn is_congested(&self) -> bool {
238        !matches!(self.congestion_state, CongestionState::Normal)
239    }
240
241    /// Get packet loss percentage in recent window.
242    #[must_use]
243    #[inline]
244    pub fn packet_loss_percent(&self) -> f64 {
245        if self.packet_count == 0 {
246            0.0
247        } else {
248            (self.loss_count as f64 / self.packet_count as f64) * 100.0
249        }
250    }
251
252    /// Get RTT variation (standard deviation / mean).
253    #[must_use]
254    #[inline]
255    pub fn rtt_variation_percent(&self) -> f64 {
256        if self.rtt_samples.len() < 2 {
257            return 0.0;
258        }
259
260        let mean = self.rtt_samples.iter().sum::<f64>() / self.rtt_samples.len() as f64;
261        let variance = self
262            .rtt_samples
263            .iter()
264            .map(|x| (x - mean).powi(2))
265            .sum::<f64>()
266            / self.rtt_samples.len() as f64;
267
268        let std_dev = variance.sqrt();
269        if mean > 0.0 {
270            (std_dev / mean) * 100.0
271        } else {
272            0.0
273        }
274    }
275
276    /// Get recommended transfer rate in bytes per second.
277    ///
278    /// Returns a conservative rate based on current conditions.
279    #[must_use]
280    #[inline]
281    pub fn recommended_rate_bps(&self) -> u64 {
282        let base_rate = self.estimate_bps();
283
284        // Apply congestion-based reduction
285        let reduction_factor = match self.congestion_state {
286            CongestionState::Normal => 1.0,
287            CongestionState::Light => 0.8,    // 20% reduction
288            CongestionState::Moderate => 0.5, // 50% reduction
289            CongestionState::Heavy => 0.25,   // 75% reduction
290        };
291
292        (base_rate as f64 * reduction_factor) as u64
293    }
294
295    /// Get statistics about bandwidth estimation.
296    #[must_use]
297    #[inline]
298    pub fn stats(&self) -> BandwidthStats {
299        let min_bw = self
300            .samples
301            .iter()
302            .map(|s| s.bandwidth_mbps)
303            .min_by(|a, b| a.partial_cmp(b).unwrap())
304            .unwrap_or(0.0);
305
306        let max_bw = self
307            .samples
308            .iter()
309            .map(|s| s.bandwidth_mbps)
310            .max_by(|a, b| a.partial_cmp(b).unwrap())
311            .unwrap_or(0.0);
312
313        let avg_rtt = if self.rtt_samples.is_empty() {
314            None
315        } else {
316            Some(self.rtt_samples.iter().sum::<f64>() / self.rtt_samples.len() as f64)
317        };
318
319        BandwidthStats {
320            current_estimate_mbps: self.estimate_mbps,
321            min_bandwidth_mbps: min_bw,
322            max_bandwidth_mbps: max_bw,
323            avg_rtt_ms: avg_rtt,
324            rtt_variation_percent: self.rtt_variation_percent(),
325            packet_loss_percent: self.packet_loss_percent(),
326            congestion_state: self.congestion_state,
327            sample_count: self.samples.len(),
328            is_reliable: self.is_reliable(),
329            total_bytes: self.total_bytes,
330            total_transfers: self.total_transfers,
331        }
332    }
333
334    /// Reset the estimator.
335    pub fn reset(&mut self) {
336        self.samples.clear();
337        self.rtt_samples.clear();
338        self.estimate_mbps = 0.0;
339        self.congestion_state = CongestionState::Normal;
340        self.total_bytes = 0;
341        self.total_transfers = 0;
342        self.loss_count = 0;
343        self.packet_count = 0;
344    }
345
346    /// Prune old samples outside the congestion window.
347    pub fn prune_old_samples(&mut self) {
348        let cutoff = Instant::now() - Duration::from_millis(self.config.congestion_window_ms);
349
350        while let Some(sample) = self.samples.front() {
351            if sample.timestamp < cutoff {
352                self.samples.pop_front();
353            } else {
354                break;
355            }
356        }
357    }
358
359    /// Update congestion state based on recent measurements.
360    fn update_congestion_state(&mut self) {
361        // Prune old samples first
362        self.prune_old_samples();
363
364        let loss_percent = self.packet_loss_percent();
365        let rtt_var_percent = self.rtt_variation_percent();
366
367        // Determine congestion state
368        let loss_congested = loss_percent > self.config.loss_threshold_percent;
369        let rtt_congested = rtt_var_percent > self.config.rtt_var_threshold_percent;
370
371        self.congestion_state = if loss_percent > 15.0 || rtt_var_percent > 100.0 {
372            CongestionState::Heavy
373        } else if loss_percent > 10.0 || rtt_var_percent > 75.0 {
374            CongestionState::Moderate
375        } else if loss_congested || rtt_congested {
376            CongestionState::Light
377        } else {
378            CongestionState::Normal
379        };
380    }
381}
382
383/// Bandwidth estimation statistics.
384#[derive(Debug, Clone, Serialize, Deserialize)]
385pub struct BandwidthStats {
386    /// Current bandwidth estimate in Mbps.
387    pub current_estimate_mbps: f64,
388    /// Minimum observed bandwidth in Mbps.
389    pub min_bandwidth_mbps: f64,
390    /// Maximum observed bandwidth in Mbps.
391    pub max_bandwidth_mbps: f64,
392    /// Average RTT in milliseconds.
393    pub avg_rtt_ms: Option<f64>,
394    /// RTT variation percentage.
395    pub rtt_variation_percent: f64,
396    /// Packet loss percentage.
397    pub packet_loss_percent: f64,
398    /// Current congestion state.
399    pub congestion_state: CongestionState,
400    /// Number of samples collected.
401    pub sample_count: usize,
402    /// Whether estimate is reliable.
403    pub is_reliable: bool,
404    /// Total bytes transferred.
405    pub total_bytes: u64,
406    /// Total number of transfers.
407    pub total_transfers: u64,
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an estimator built from the default configuration.
    fn estimator_with_defaults() -> BandwidthEstimator {
        BandwidthEstimator::new(EstimatorConfig::default())
    }

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {} +/- {}, got {}",
            expected,
            tolerance,
            actual
        );
    }

    /// Records `count` identical 1 MiB / 100 ms transfers with a 50 ms RTT.
    fn record_mib_transfers(est: &mut BandwidthEstimator, count: usize, lost: bool) {
        for _ in 0..count {
            est.record_transfer_with_rtt(1024 * 1024, 100, Some(50.0), lost);
        }
    }

    #[test]
    fn test_bandwidth_estimator() {
        let mut est = estimator_with_defaults();

        // 1,000,000 bytes (decimal MB, not MiB) in 100 ms -> 10 MB/s -> 80 Mbps.
        est.record_transfer(1_000_000, 100);

        assert_close(est.estimate_mbps(), 80.0, 1.0);
    }

    #[test]
    fn test_ewma_smoothing() {
        let mut est = BandwidthEstimator::new(EstimatorConfig {
            alpha: 0.5,
            ..Default::default()
        });

        // 12.5 MB in one second -> 100 Mbps; the first sample seeds the average.
        est.record_transfer(12_500_000, 1000);
        assert_close(est.estimate_mbps(), 100.0, 1.0);

        // 6.25 MB in one second -> 50 Mbps; 0.5 * 50 + 0.5 * 100 = 75.
        est.record_transfer(6_250_000, 1000);
        assert_close(est.estimate_mbps(), 75.0, 1.0);
    }

    #[test]
    fn test_congestion_detection() {
        let mut est = estimator_with_defaults();

        record_mib_transfers(&mut est, 10, false);
        assert_eq!(est.congestion_state(), CongestionState::Normal);

        record_mib_transfers(&mut est, 10, true);
        assert!(est.is_congested());
    }

    #[test]
    fn test_packet_loss_calculation() {
        let mut est = estimator_with_defaults();

        // Three clean transfers followed by one lossy one: 1 of 4 = 25%.
        for &lost in &[false, false, false, true] {
            est.record_transfer_with_rtt(1024, 10, None, lost);
        }

        assert_close(est.packet_loss_percent(), 25.0, 0.1);
    }

    #[test]
    fn test_recommended_rate() {
        let mut est = estimator_with_defaults();

        // Baseline: one clean 1 MiB / 100 ms transfer (~83.9 Mbps).
        est.record_transfer(1024 * 1024, 100);
        let baseline = est.recommended_rate_bps();

        // Ten lossy transfers push the estimator into congestion.
        record_mib_transfers(&mut est, 10, true);

        assert!(est.recommended_rate_bps() < baseline);
    }

    #[test]
    fn test_reliability() {
        let mut est = BandwidthEstimator::new(EstimatorConfig {
            min_samples: 3,
            ..Default::default()
        });
        assert!(!est.is_reliable());

        for _ in 0..2 {
            est.record_transfer(1024, 10);
        }
        assert!(!est.is_reliable());

        est.record_transfer(1024, 10);
        assert!(est.is_reliable());
    }

    #[test]
    fn test_reset() {
        let mut est = estimator_with_defaults();
        est.record_transfer(1024 * 1024, 100);
        assert!(est.estimate_mbps() > 0.0);

        est.reset();

        assert_eq!(est.estimate_mbps(), 0.0);
        assert_eq!(est.total_bytes, 0);
        assert_eq!(est.congestion_state(), CongestionState::Normal);
    }
}