ipfrs_transport/
throttle.rs

1//! Bandwidth throttling for rate limiting
2//!
//! Provides a token bucket algorithm for:
4//! - Per-peer rate limits
5//! - Global bandwidth caps
6//! - QoS prioritization
7//! - Burst allowance
8//!
9//! # Example
10//!
11//! ```
12//! use ipfrs_transport::{TokenBucket, BandwidthThrottle, BandwidthConfig, QosPriority};
13//!
14//! // Create a token bucket with 1000 token capacity, refilling at 100 tokens/sec
15//! let bucket = TokenBucket::new(1000, 100.0);
16//!
17//! // Try to consume 50 tokens
18//! if bucket.try_consume(50) {
19//!     println!("Request allowed");
20//! } else {
21//!     println!("Rate limit exceeded");
22//! }
23//!
24//! // Create a bandwidth throttle with global limits
25//! let mut config = BandwidthConfig::default();
26//! config.global_upload_limit = 10_000_000; // 10 MB/s
27//! config.global_download_limit = 10_000_000;
28//! let throttle = BandwidthThrottle::new(config);
29//!
30//! // Register a peer with high priority
31//! let peer_addr = "127.0.0.1:8080".parse().unwrap();
32//! throttle.register_peer(peer_addr, QosPriority::High);
33//! ```
34
35use parking_lot::RwLock;
36use std::collections::HashMap;
37use std::net::SocketAddr;
38use std::sync::Arc;
39use std::time::{Duration, Instant};
40use thiserror::Error;
41use tokio::time::sleep;
42use tracing::debug;
43
44/// Throttle error types
45#[derive(Error, Debug)]
46pub enum ThrottleError {
47    #[error("Rate limit exceeded")]
48    RateLimitExceeded,
49
50    #[error("Quota exhausted")]
51    QuotaExhausted,
52
53    #[error("Peer not found: {0}")]
54    PeerNotFound(String),
55}
56
57/// Token bucket for rate limiting
58#[derive(Debug, Clone)]
59pub struct TokenBucket {
60    /// Maximum number of tokens (bucket capacity)
61    capacity: u64,
62    /// Current number of tokens
63    tokens: Arc<RwLock<f64>>,
64    /// Token refill rate (tokens per second)
65    refill_rate: f64,
66    /// Last refill time
67    last_refill: Arc<RwLock<Instant>>,
68}
69
70impl TokenBucket {
71    /// Create a new token bucket
72    pub fn new(capacity: u64, refill_rate: f64) -> Self {
73        Self {
74            capacity,
75            tokens: Arc::new(RwLock::new(capacity as f64)),
76            refill_rate,
77            last_refill: Arc::new(RwLock::new(Instant::now())),
78        }
79    }
80
81    /// Try to consume tokens (non-blocking)
82    pub fn try_consume(&self, amount: u64) -> bool {
83        self.refill();
84
85        let mut tokens = self.tokens.write();
86        if *tokens >= amount as f64 {
87            *tokens -= amount as f64;
88            true
89        } else {
90            false
91        }
92    }
93
94    /// Consume tokens (blocking until available)
95    pub async fn consume(&self, amount: u64) {
96        loop {
97            if self.try_consume(amount) {
98                return;
99            }
100
101            // Calculate wait time
102            let tokens = *self.tokens.read();
103            let needed = amount as f64 - tokens;
104            let wait_time = Duration::from_secs_f64(needed / self.refill_rate);
105
106            sleep(wait_time.min(Duration::from_millis(100))).await;
107        }
108    }
109
110    /// Refill tokens based on elapsed time
111    fn refill(&self) {
112        let now = Instant::now();
113        let mut last_refill = self.last_refill.write();
114        let elapsed = now.duration_since(*last_refill).as_secs_f64();
115
116        if elapsed > 0.0 {
117            let new_tokens = elapsed * self.refill_rate;
118            let mut tokens = self.tokens.write();
119            *tokens = (*tokens + new_tokens).min(self.capacity as f64);
120            *last_refill = now;
121        }
122    }
123
124    /// Get current token count
125    pub fn available_tokens(&self) -> u64 {
126        self.refill();
127        *self.tokens.read() as u64
128    }
129
130    /// Get capacity
131    pub fn capacity(&self) -> u64 {
132        self.capacity
133    }
134
135    /// Get refill rate
136    pub fn refill_rate(&self) -> f64 {
137        self.refill_rate
138    }
139}
140
/// Quality-of-service classes used to weight a peer's bandwidth allocation.
///
/// Variants are ordered from lowest to highest priority, so the derived
/// `Ord` can be used to compare them.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum QosPriority {
    /// Lowest class: best effort.
    BestEffort = 0,
    /// Default class.
    Normal = 1,
    /// Elevated class.
    High = 2,
    /// Highest class.
    Critical = 3,
}

impl QosPriority {
    /// Returns the factor by which a base rate is scaled for this class.
    pub fn multiplier(&self) -> f64 {
        match self {
            Self::BestEffort => 0.5,
            Self::Normal => 1.0,
            Self::High => 2.0,
            Self::Critical => 4.0,
        }
    }
}
165
/// Settings that control how [`BandwidthThrottle`] limits traffic.
///
/// All limits are expressed in bytes per second; a value of `0` disables the
/// corresponding limit.
#[derive(Clone, Debug)]
pub struct BandwidthConfig {
    /// Cap on total upload traffic (bytes/sec, `0` = unlimited).
    pub global_upload_limit: u64,
    /// Cap on total download traffic (bytes/sec, `0` = unlimited).
    pub global_download_limit: u64,
    /// Cap on upload traffic per peer (bytes/sec, `0` = unlimited).
    pub peer_upload_limit: u64,
    /// Cap on download traffic per peer (bytes/sec, `0` = unlimited).
    pub peer_download_limit: u64,
    /// Whether buckets may hold more than one second's worth of tokens.
    pub allow_burst: bool,
    /// Factor applied to the rate to obtain the bucket capacity when
    /// `allow_burst` is set.
    pub burst_multiplier: f64,
}

impl Default for BandwidthConfig {
    /// No limits at all, with bursting enabled at twice the rate.
    fn default() -> Self {
        const UNLIMITED: u64 = 0;

        Self {
            global_upload_limit: UNLIMITED,
            global_download_limit: UNLIMITED,
            peer_upload_limit: UNLIMITED,
            peer_download_limit: UNLIMITED,
            allow_burst: true,
            burst_multiplier: 2.0,
        }
    }
}
195
/// Counters describing the activity of a bandwidth throttle.
///
/// `Default` yields all-zero counters.
#[derive(Clone, Debug, Default)]
pub struct ThrottleStats {
    /// Cumulative number of bytes uploaded.
    pub bytes_uploaded: u64,
    /// Cumulative number of bytes downloaded.
    pub bytes_downloaded: u64,
    /// How many times a transfer was throttled.
    pub throttle_count: u64,
    /// Accumulated time spent waiting for bandwidth.
    pub total_wait_time: Duration,
    /// Present upload rate in bytes per second.
    pub current_upload_rate: f64,
    /// Present download rate in bytes per second.
    pub current_download_rate: f64,
}
212
213/// Per-peer throttle state
214struct PeerThrottle {
215    upload: Option<TokenBucket>,
216    download: Option<TokenBucket>,
217    priority: QosPriority,
218}
219
220/// Bandwidth throttle manager
221pub struct BandwidthThrottle {
222    config: BandwidthConfig,
223    global_upload: Option<Arc<TokenBucket>>,
224    global_download: Option<Arc<TokenBucket>>,
225    peer_throttles: Arc<RwLock<HashMap<SocketAddr, PeerThrottle>>>,
226    stats: Arc<RwLock<ThrottleStats>>,
227}
228
229impl BandwidthThrottle {
230    /// Create a new bandwidth throttle
231    pub fn new(config: BandwidthConfig) -> Self {
232        let burst_capacity_multiplier = if config.allow_burst {
233            config.burst_multiplier
234        } else {
235            1.0
236        };
237
238        let global_upload = if config.global_upload_limit > 0 {
239            Some(Arc::new(TokenBucket::new(
240                (config.global_upload_limit as f64 * burst_capacity_multiplier) as u64,
241                config.global_upload_limit as f64,
242            )))
243        } else {
244            None
245        };
246
247        let global_download = if config.global_download_limit > 0 {
248            Some(Arc::new(TokenBucket::new(
249                (config.global_download_limit as f64 * burst_capacity_multiplier) as u64,
250                config.global_download_limit as f64,
251            )))
252        } else {
253            None
254        };
255
256        Self {
257            config,
258            global_upload,
259            global_download,
260            peer_throttles: Arc::new(RwLock::new(HashMap::new())),
261            stats: Arc::new(RwLock::new(ThrottleStats::default())),
262        }
263    }
264
265    /// Register a peer with optional priority
266    pub fn register_peer(&self, addr: SocketAddr, priority: QosPriority) {
267        let burst_multiplier = if self.config.allow_burst {
268            self.config.burst_multiplier
269        } else {
270            1.0
271        };
272
273        let upload = if self.config.peer_upload_limit > 0 {
274            let rate = self.config.peer_upload_limit as f64 * priority.multiplier();
275            Some(TokenBucket::new((rate * burst_multiplier) as u64, rate))
276        } else {
277            None
278        };
279
280        let download = if self.config.peer_download_limit > 0 {
281            let rate = self.config.peer_download_limit as f64 * priority.multiplier();
282            Some(TokenBucket::new((rate * burst_multiplier) as u64, rate))
283        } else {
284            None
285        };
286
287        let throttle = PeerThrottle {
288            upload,
289            download,
290            priority,
291        };
292
293        self.peer_throttles.write().insert(addr, throttle);
294        debug!("Registered peer {} with priority {:?}", addr, priority);
295    }
296
297    /// Unregister a peer
298    pub fn unregister_peer(&self, addr: &SocketAddr) {
299        self.peer_throttles.write().remove(addr);
300        debug!("Unregistered peer {}", addr);
301    }
302
303    /// Throttle upload (wait until bandwidth available)
304    pub async fn throttle_upload(&self, addr: &SocketAddr, bytes: u64) {
305        let start = Instant::now();
306
307        // Global throttle
308        if let Some(global) = &self.global_upload {
309            global.consume(bytes).await;
310        }
311
312        // Per-peer throttle
313        let upload_bucket = {
314            let peer_throttles = self.peer_throttles.read();
315            peer_throttles
316                .get(addr)
317                .and_then(|peer| peer.upload.clone())
318        };
319
320        if let Some(upload) = upload_bucket {
321            upload.consume(bytes).await;
322        }
323
324        // Update stats
325        {
326            let mut stats = self.stats.write();
327            stats.bytes_uploaded += bytes;
328            let wait_time = start.elapsed();
329            if wait_time > Duration::from_millis(1) {
330                stats.throttle_count += 1;
331                stats.total_wait_time += wait_time;
332            }
333        }
334    }
335
336    /// Throttle download (wait until bandwidth available)
337    pub async fn throttle_download(&self, addr: &SocketAddr, bytes: u64) {
338        let start = Instant::now();
339
340        // Global throttle
341        if let Some(global) = &self.global_download {
342            global.consume(bytes).await;
343        }
344
345        // Per-peer throttle
346        let download_bucket = {
347            let peer_throttles = self.peer_throttles.read();
348            peer_throttles
349                .get(addr)
350                .and_then(|peer| peer.download.clone())
351        };
352
353        if let Some(download) = download_bucket {
354            download.consume(bytes).await;
355        }
356
357        // Update stats
358        {
359            let mut stats = self.stats.write();
360            stats.bytes_downloaded += bytes;
361            let wait_time = start.elapsed();
362            if wait_time > Duration::from_millis(1) {
363                stats.throttle_count += 1;
364                stats.total_wait_time += wait_time;
365            }
366        }
367    }
368
369    /// Try to throttle upload (non-blocking, returns false if would block)
370    pub fn try_throttle_upload(&self, addr: &SocketAddr, bytes: u64) -> bool {
371        // Global throttle
372        if let Some(global) = &self.global_upload {
373            if !global.try_consume(bytes) {
374                return false;
375            }
376        }
377
378        // Per-peer throttle
379        let peer_throttles = self.peer_throttles.read();
380        if let Some(peer) = peer_throttles.get(addr) {
381            if let Some(upload) = &peer.upload {
382                if !upload.try_consume(bytes) {
383                    return false;
384                }
385            }
386        }
387
388        // Update stats
389        self.stats.write().bytes_uploaded += bytes;
390        true
391    }
392
393    /// Try to throttle download (non-blocking)
394    pub fn try_throttle_download(&self, addr: &SocketAddr, bytes: u64) -> bool {
395        // Global throttle
396        if let Some(global) = &self.global_download {
397            if !global.try_consume(bytes) {
398                return false;
399            }
400        }
401
402        // Per-peer throttle
403        let peer_throttles = self.peer_throttles.read();
404        if let Some(peer) = peer_throttles.get(addr) {
405            if let Some(download) = &peer.download {
406                if !download.try_consume(bytes) {
407                    return false;
408                }
409            }
410        }
411
412        // Update stats
413        self.stats.write().bytes_downloaded += bytes;
414        true
415    }
416
417    /// Update peer priority
418    pub fn update_peer_priority(&self, addr: &SocketAddr, priority: QosPriority) {
419        let mut peer_throttles = self.peer_throttles.write();
420        if let Some(peer) = peer_throttles.get_mut(addr) {
421            peer.priority = priority;
422            debug!("Updated peer {} priority to {:?}", addr, priority);
423        }
424    }
425
426    /// Get statistics
427    pub fn stats(&self) -> ThrottleStats {
428        self.stats.read().clone()
429    }
430
431    /// Reset statistics
432    pub fn reset_stats(&self) {
433        *self.stats.write() = ThrottleStats::default();
434    }
435
436    /// Get available upload bandwidth
437    pub fn available_upload_bandwidth(&self) -> Option<u64> {
438        self.global_upload.as_ref().map(|b| b.available_tokens())
439    }
440
441    /// Get available download bandwidth
442    pub fn available_download_bandwidth(&self) -> Option<u64> {
443        self.global_download.as_ref().map(|b| b.available_tokens())
444    }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback address used by the throttle tests.
    fn peer_addr() -> SocketAddr {
        "127.0.0.1:8080".parse().unwrap()
    }

    /// A full bucket can be drained in two halves and then refuses more.
    #[test]
    fn test_token_bucket() {
        let bucket = TokenBucket::new(100, 10.0);
        assert_eq!(bucket.available_tokens(), 100);

        for expected_left in [50, 0] {
            assert!(bucket.try_consume(50));
            assert_eq!(bucket.available_tokens(), expected_left);
        }

        assert!(!bucket.try_consume(1));
    }

    /// An emptied bucket regains roughly `rate * elapsed` tokens.
    #[tokio::test]
    async fn test_token_bucket_refill() {
        // 100 tokens/sec
        let bucket = TokenBucket::new(100, 100.0);

        bucket.try_consume(100);
        assert_eq!(bucket.available_tokens(), 0);

        tokio::time::sleep(Duration::from_millis(500)).await;

        // Half a second at 100 tokens/sec should yield about 50 tokens.
        let available = bucket.available_tokens();
        assert!((45..=55).contains(&available), "Got {} tokens", available);
    }

    /// Each priority class maps to its documented multiplier.
    #[test]
    fn test_qos_priority() {
        let expected = [
            (QosPriority::BestEffort, 0.5),
            (QosPriority::Normal, 1.0),
            (QosPriority::High, 2.0),
            (QosPriority::Critical, 4.0),
        ];
        for (priority, multiplier) in expected {
            assert_eq!(priority.multiplier(), multiplier);
        }
    }

    /// The default configuration is unlimited with 2x bursting.
    #[test]
    fn test_bandwidth_config_default() {
        let defaults = BandwidthConfig::default();
        assert_eq!(defaults.global_upload_limit, 0);
        assert!(defaults.allow_burst);
        assert_eq!(defaults.burst_multiplier, 2.0);
    }

    /// A request within the global limit is granted and counted in the stats.
    #[tokio::test]
    async fn test_bandwidth_throttle() {
        let config = BandwidthConfig {
            global_upload_limit: 1000,   // 1000 bytes/sec
            global_download_limit: 2000, // 2000 bytes/sec
            allow_burst: false,
            burst_multiplier: 1.0,
            // Per-peer limits stay at their default of 0 (unlimited).
            ..BandwidthConfig::default()
        };
        let throttle = BandwidthThrottle::new(config);
        let addr = peer_addr();

        throttle.register_peer(addr, QosPriority::Normal);

        // Half the budget is available immediately.
        assert!(throttle.try_throttle_upload(&addr, 500));

        // The granted bytes show up in the statistics.
        assert_eq!(throttle.stats().bytes_uploaded, 500);
    }

    /// A high-priority peer can upload its full base limit at once.
    #[tokio::test]
    async fn test_peer_priority() {
        let config = BandwidthConfig {
            peer_upload_limit: 1000, // 1000 bytes/sec
            allow_burst: false,
            burst_multiplier: 1.0,
            // Global limits and the peer download limit stay at 0 (unlimited).
            ..BandwidthConfig::default()
        };
        let throttle = BandwidthThrottle::new(config);
        let addr = peer_addr();

        // High priority doubles the per-peer bandwidth.
        throttle.register_peer(addr, QosPriority::High);

        assert!(throttle.try_throttle_upload(&addr, 1000));
    }
}