chie_core/
ratelimit.rs

1//! Bandwidth rate limiting for P2P transfers.
2//!
3//! This module provides token bucket based rate limiting for controlling
4//! upload and download bandwidth usage.
5
6use std::sync::Arc;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10use tokio::time::sleep;
11
/// Settings that control how bandwidth is throttled.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Upload ceiling in bytes per second; `0` means no ceiling.
    pub upload_rate: u64,
    /// Download ceiling in bytes per second; `0` means no ceiling.
    pub download_rate: u64,
    /// Factor applied to a rate to obtain the burst capacity of its bucket.
    pub burst_multiplier: f64,
    /// Transfers smaller than this many bytes are never throttled.
    pub min_transfer_size: u64,
    /// Master switch: when `false`, no throttling is applied at all.
    pub enabled: bool,
}

impl Default for RateLimitConfig {
    /// Limiting switched on, but with both directions uncapped.
    fn default() -> Self {
        RateLimitConfig {
            enabled: true,
            // Small transfers skip the limiter entirely.
            min_transfer_size: 1024,
            // Bursts may reach 150% of the configured rate.
            burst_multiplier: 1.5,
            // Zero is the "unlimited" marker for both directions.
            download_rate: 0,
            upload_rate: 0,
        }
    }
}

impl RateLimitConfig {
    /// Build a config from upload/download rates given in megabits per second.
    ///
    /// Each rate is converted to bytes per second; every other field keeps its
    /// default value.
    #[must_use]
    #[inline]
    pub fn with_rates(upload_mbps: f64, download_mbps: f64) -> Self {
        // megabits/s -> bits/s -> bytes/s, truncated to a whole number.
        let bytes_per_sec = |mbps: f64| (mbps * 1_000_000.0 / 8.0) as u64;

        Self {
            upload_rate: bytes_per_sec(upload_mbps),
            download_rate: bytes_per_sec(download_mbps),
            ..Self::default()
        }
    }

    /// Build a config that uses the same rate (in megabits per second) for
    /// both directions.
    #[must_use]
    #[inline]
    pub fn symmetric(rate_mbps: f64) -> Self {
        Self::with_rates(rate_mbps, rate_mbps)
    }

    /// Build a config with rate limiting switched off.
    #[must_use]
    #[inline]
    pub fn unlimited() -> Self {
        let mut config = Self::default();
        config.enabled = false;
        config
    }
}
68
/// Token bucket for rate limiting.
///
/// One token stands for one byte. Tokens accrue at `rate` per second up to
/// `max_tokens`. A request larger than the current balance empties the bucket
/// and records the shortfall as *debt* by moving the refill clock into the
/// future, so later callers queue behind it instead of all being granted the
/// same short wait.
struct TokenBucket {
    /// Current tokens (bytes allowed).
    tokens: AtomicU64,
    /// Maximum tokens (burst capacity).
    max_tokens: u64,
    /// Tokens added per second.
    rate: u64,
    /// Instant up to which accrual has been accounted for. It lies in the
    /// future while debt is outstanding.
    ///
    /// This lock also serialises refill + debit. A std mutex is used because
    /// the guard is only held for a few arithmetic operations and never across
    /// an `.await`.
    last_refill: std::sync::Mutex<Instant>,
}

impl TokenBucket {
    /// Nanoseconds per second, used for exact integer time/token conversions.
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    /// Create a bucket refilling at `rate` bytes per second whose capacity is
    /// `rate * burst_multiplier` (truncated). The bucket starts full.
    fn new(rate: u64, burst_multiplier: f64) -> Self {
        let max_tokens = (rate as f64 * burst_multiplier) as u64;
        Self {
            tokens: AtomicU64::new(max_tokens),
            max_tokens,
            rate,
            last_refill: std::sync::Mutex::new(Instant::now()),
        }
    }

    /// Time needed to accrue `tokens` at the configured rate.
    ///
    /// Integer arithmetic keeps this exact and panic-free for any `u64`
    /// amount. A zero rate is treated as one token per second to avoid a
    /// division by zero.
    fn time_to_accrue(&self, tokens: u64) -> Duration {
        let rate = u128::from(self.rate.max(1));
        let nanos = u128::from(tokens) * Self::NANOS_PER_SEC / rate;
        let secs = (nanos / Self::NANOS_PER_SEC).min(u128::from(u64::MAX)) as u64;
        // The remainder is always below one second, so it fits in a u32.
        let subsec_nanos = (nanos % Self::NANOS_PER_SEC) as u32;
        Duration::new(secs, subsec_nanos)
    }

    /// Credit the whole tokens earned between `*last` and `now`.
    ///
    /// The caller must hold the `last_refill` lock; `last` is the guarded
    /// value.
    fn refill_locked(&self, last: &mut Instant, now: Instant) {
        // Saturates to zero while `*last` is in the future (outstanding debt).
        let elapsed = now.saturating_duration_since(*last);
        let earned =
            elapsed.as_nanos().saturating_mul(u128::from(self.rate)) / Self::NANOS_PER_SEC;
        let earned = earned.min(u128::from(u64::MAX)) as u64;
        if earned == 0 {
            // Leave the clock untouched so the fraction keeps accumulating.
            return;
        }

        let current = self.tokens.load(Ordering::Relaxed);
        let updated = current.saturating_add(earned).min(self.max_tokens);
        self.tokens.store(updated, Ordering::Relaxed);

        *last = if updated == self.max_tokens {
            // Bucket is full: any surplus is discarded, restart the clock.
            now
        } else {
            // Advance only by the time the credited tokens account for, so
            // the fractional remainder carries over to the next refill.
            last.checked_add(self.time_to_accrue(earned)).unwrap_or(now)
        };
    }

    /// Charge `bytes` to the bucket as of `now` and return how long the
    /// caller must wait before the transfer may proceed.
    ///
    /// Returns [`Duration::ZERO`] when enough tokens are available. Otherwise
    /// the bucket is emptied, the shortfall is booked as debt, and the
    /// returned wait covers this caller's debt plus any debt already queued.
    fn consume_at(&self, bytes: u64, now: Instant) -> Duration {
        // A poisoned lock still guards a valid `Instant`; keep going.
        let mut last = self
            .last_refill
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        self.refill_locked(&mut *last, now);

        let current = self.tokens.load(Ordering::Relaxed);
        if current >= bytes {
            // Writers are serialised by the lock, so load + store cannot race.
            self.tokens.store(current - bytes, Ordering::Relaxed);
            return Duration::ZERO;
        }

        // Spend what is there and borrow the rest from future accrual.
        self.tokens.store(0, Ordering::Relaxed);
        let debt = self.time_to_accrue(bytes - current);
        let base = (*last).max(now);
        match base.checked_add(debt) {
            Some(ready_at) => {
                *last = ready_at;
                ready_at.saturating_duration_since(now)
            }
            // Not representable as an `Instant`; report the raw debt.
            None => debt,
        }
    }

    /// Charge `bytes` to the bucket and return the required wait.
    ///
    /// Kept `async` so existing `.await` call sites keep working; the body
    /// itself never suspends.
    async fn consume(&self, bytes: u64) -> Duration {
        self.consume_at(bytes, Instant::now())
    }

    /// Token balance as of the most recent refill or debit.
    fn available(&self) -> u64 {
        self.tokens.load(Ordering::Relaxed)
    }
}
126
127/// Bandwidth rate limiter.
128pub struct BandwidthLimiter {
129    config: RateLimitConfig,
130    upload_bucket: Option<TokenBucket>,
131    download_bucket: Option<TokenBucket>,
132    stats: Arc<RwLock<BandwidthStats>>,
133}
134
/// Counters and derived rates describing bandwidth usage.
#[derive(Debug, Clone, Default)]
pub struct BandwidthStats {
    /// Total bytes uploaded.
    pub bytes_uploaded: u64,
    /// Total bytes downloaded.
    pub bytes_downloaded: u64,
    /// Average upload rate in bytes/sec since `started_at`.
    pub upload_rate: f64,
    /// Average download rate in bytes/sec since `started_at`.
    pub download_rate: f64,
    /// Accumulated time spent waiting because of rate limiting.
    pub total_wait_time: Duration,
    /// How many transfers had to wait because of rate limiting.
    pub limited_transfers: u64,
    /// When these statistics started being collected, if known.
    pub started_at: Option<Instant>,
}

impl BandwidthStats {
    /// Zeroed statistics whose measurement window starts now.
    fn new() -> Self {
        let started_at = Some(Instant::now());
        Self {
            started_at,
            ..Self::default()
        }
    }

    /// Recompute both average rates as total bytes divided by the seconds
    /// elapsed since `started_at`.
    ///
    /// Does nothing when no start time is set or no time has elapsed.
    fn update_rates(&mut self) {
        let secs = match self.started_at {
            Some(origin) => origin.elapsed().as_secs_f64(),
            None => return,
        };
        if secs <= 0.0 {
            return;
        }

        self.upload_rate = self.bytes_uploaded as f64 / secs;
        self.download_rate = self.bytes_downloaded as f64 / secs;
    }
}
172
173impl BandwidthLimiter {
174    /// Create a new bandwidth limiter.
175    #[must_use]
176    #[inline]
177    pub fn new(config: RateLimitConfig) -> Self {
178        let upload_bucket = if config.enabled && config.upload_rate > 0 {
179            Some(TokenBucket::new(
180                config.upload_rate,
181                config.burst_multiplier,
182            ))
183        } else {
184            None
185        };
186
187        let download_bucket = if config.enabled && config.download_rate > 0 {
188            Some(TokenBucket::new(
189                config.download_rate,
190                config.burst_multiplier,
191            ))
192        } else {
193            None
194        };
195
196        Self {
197            config,
198            upload_bucket,
199            download_bucket,
200            stats: Arc::new(RwLock::new(BandwidthStats::new())),
201        }
202    }
203
204    /// Rate limit an upload operation.
205    ///
206    /// Returns when the transfer is allowed to proceed.
207    pub async fn limit_upload(&self, bytes: u64) {
208        if !self.config.enabled || bytes < self.config.min_transfer_size {
209            return;
210        }
211
212        if let Some(ref bucket) = self.upload_bucket {
213            let wait = bucket.consume(bytes).await;
214            if !wait.is_zero() {
215                let mut stats = self.stats.write().await;
216                stats.total_wait_time += wait;
217                stats.limited_transfers += 1;
218                drop(stats);
219
220                sleep(wait).await;
221            }
222
223            let mut stats = self.stats.write().await;
224            stats.bytes_uploaded += bytes;
225            stats.update_rates();
226        }
227    }
228
229    /// Rate limit a download operation.
230    pub async fn limit_download(&self, bytes: u64) {
231        if !self.config.enabled || bytes < self.config.min_transfer_size {
232            return;
233        }
234
235        if let Some(ref bucket) = self.download_bucket {
236            let wait = bucket.consume(bytes).await;
237            if !wait.is_zero() {
238                let mut stats = self.stats.write().await;
239                stats.total_wait_time += wait;
240                stats.limited_transfers += 1;
241                drop(stats);
242
243                sleep(wait).await;
244            }
245
246            let mut stats = self.stats.write().await;
247            stats.bytes_downloaded += bytes;
248            stats.update_rates();
249        }
250    }
251
252    /// Record bytes transferred without rate limiting (for stats only).
253    pub async fn record_upload(&self, bytes: u64) {
254        let mut stats = self.stats.write().await;
255        stats.bytes_uploaded += bytes;
256        stats.update_rates();
257    }
258
259    /// Record bytes transferred without rate limiting (for stats only).
260    pub async fn record_download(&self, bytes: u64) {
261        let mut stats = self.stats.write().await;
262        stats.bytes_downloaded += bytes;
263        stats.update_rates();
264    }
265
266    /// Get current bandwidth statistics.
267    #[must_use]
268    pub async fn stats(&self) -> BandwidthStats {
269        self.stats.read().await.clone()
270    }
271
272    /// Get available upload tokens (for flow control decisions).
273    #[must_use]
274    #[inline]
275    pub fn available_upload(&self) -> Option<u64> {
276        self.upload_bucket.as_ref().map(|b| b.available())
277    }
278
279    /// Get available download tokens.
280    #[must_use]
281    #[inline]
282    pub fn available_download(&self) -> Option<u64> {
283        self.download_bucket.as_ref().map(|b| b.available())
284    }
285
286    /// Check if rate limiting is enabled.
287    #[must_use]
288    #[inline]
289    pub fn is_enabled(&self) -> bool {
290        self.config.enabled
291    }
292
293    /// Get configured upload rate.
294    #[must_use]
295    #[inline]
296    pub fn upload_rate(&self) -> u64 {
297        self.config.upload_rate
298    }
299
300    /// Get configured download rate.
301    #[must_use]
302    #[inline]
303    pub fn download_rate(&self) -> u64 {
304        self.config.download_rate
305    }
306}
307
308/// Per-peer rate limiter for fair bandwidth distribution.
309pub struct PeerRateLimiter {
310    /// Global limiter.
311    global: Arc<BandwidthLimiter>,
312    /// Per-peer limiters.
313    peer_limiters: RwLock<std::collections::HashMap<String, Arc<BandwidthLimiter>>>,
314    /// Per-peer rate (fraction of global).
315    peer_rate_fraction: f64,
316}
317
318impl PeerRateLimiter {
319    /// Create a new per-peer rate limiter.
320    #[must_use]
321    #[inline]
322    pub fn new(global_config: RateLimitConfig, peer_rate_fraction: f64) -> Self {
323        Self {
324            global: Arc::new(BandwidthLimiter::new(global_config)),
325            peer_limiters: RwLock::new(std::collections::HashMap::new()),
326            peer_rate_fraction,
327        }
328    }
329
330    /// Get or create a rate limiter for a peer.
331    #[must_use]
332    pub async fn get_peer_limiter(&self, peer_id: &str) -> Arc<BandwidthLimiter> {
333        {
334            let limiters = self.peer_limiters.read().await;
335            if let Some(limiter) = limiters.get(peer_id) {
336                return Arc::clone(limiter);
337            }
338        }
339
340        let peer_config = RateLimitConfig {
341            upload_rate: (self.global.upload_rate() as f64 * self.peer_rate_fraction) as u64,
342            download_rate: (self.global.download_rate() as f64 * self.peer_rate_fraction) as u64,
343            burst_multiplier: 2.0, // Allow more burst per-peer
344            min_transfer_size: 512,
345            enabled: self.global.is_enabled(),
346        };
347
348        let limiter = Arc::new(BandwidthLimiter::new(peer_config));
349
350        let mut limiters = self.peer_limiters.write().await;
351        limiters.insert(peer_id.to_string(), Arc::clone(&limiter));
352
353        limiter
354    }
355
356    /// Rate limit an upload to a specific peer.
357    pub async fn limit_upload(&self, peer_id: &str, bytes: u64) {
358        // Apply both global and per-peer limits
359        self.global.limit_upload(bytes).await;
360
361        let peer_limiter = self.get_peer_limiter(peer_id).await;
362        peer_limiter.limit_upload(bytes).await;
363    }
364
365    /// Rate limit a download from a specific peer.
366    pub async fn limit_download(&self, peer_id: &str, bytes: u64) {
367        self.global.limit_download(bytes).await;
368
369        let peer_limiter = self.get_peer_limiter(peer_id).await;
370        peer_limiter.limit_download(bytes).await;
371    }
372
373    /// Get global statistics.
374    #[must_use]
375    pub async fn global_stats(&self) -> BandwidthStats {
376        self.global.stats().await
377    }
378
379    /// Get statistics for a specific peer.
380    #[must_use]
381    pub async fn peer_stats(&self, peer_id: &str) -> Option<BandwidthStats> {
382        let limiters = self.peer_limiters.read().await;
383        if let Some(limiter) = limiters.get(peer_id) {
384            Some(limiter.stats().await)
385        } else {
386            None
387        }
388    }
389
390    /// Remove a peer's limiter (when disconnected).
391    pub async fn remove_peer(&self, peer_id: &str) {
392        let mut limiters = self.peer_limiters.write().await;
393        limiters.remove(peer_id);
394    }
395
396    /// Get number of tracked peers.
397    #[must_use]
398    pub async fn peer_count(&self) -> usize {
399        self.peer_limiters.read().await.len()
400    }
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is enabled but caps neither direction.
    #[test]
    fn test_config_default() {
        let RateLimitConfig {
            upload_rate,
            download_rate,
            enabled,
            ..
        } = RateLimitConfig::default();

        assert!(enabled);
        assert_eq!(upload_rate, 0);
        assert_eq!(download_rate, 0);
    }

    /// Megabit-per-second inputs are stored as bytes per second.
    #[test]
    fn test_config_with_rates() {
        // 100 Mbps up, 50 Mbps down
        let config = RateLimitConfig::with_rates(100.0, 50.0);

        // 100 Mbps = 12.5 MB/s
        assert_eq!(config.upload_rate, 12_500_000);
        // 50 Mbps = 6.25 MB/s
        assert_eq!(config.download_rate, 6_250_000);
    }

    /// A disabled limiter lets large transfers through without delay.
    #[tokio::test]
    async fn test_unlimited_limiter() {
        let limiter = BandwidthLimiter::new(RateLimitConfig::unlimited());

        let started = Instant::now();
        limiter.limit_upload(10_000_000).await; // 10 MB
        limiter.limit_download(10_000_000).await;
        let took = started.elapsed();

        // Should not block at all
        assert!(took < Duration::from_millis(10));
    }

    /// Explicitly recorded bytes show up in the statistics snapshot.
    #[tokio::test]
    async fn test_stats_recording() {
        let limiter = BandwidthLimiter::new(RateLimitConfig::unlimited());

        limiter.record_upload(1000).await;
        limiter.record_download(2000).await;

        let snapshot = limiter.stats().await;
        assert_eq!(snapshot.bytes_uploaded, 1000);
        assert_eq!(snapshot.bytes_downloaded, 2000);
    }

    /// Peers are tracked on first use and forgotten on removal.
    #[tokio::test]
    async fn test_peer_rate_limiter() {
        let peers = PeerRateLimiter::new(RateLimitConfig::unlimited(), 0.25);

        // Asking for a peer's limiter registers the peer.
        let _limiter = peers.get_peer_limiter("peer1").await;
        assert_eq!(peers.peer_count().await, 1);

        // Removing the peer drops it from the table again.
        peers.remove_peer("peer1").await;
        assert_eq!(peers.peer_count().await, 0);
    }
}