ipfrs_network/
rate_limiter.rs

//! Connection rate limiting for preventing connection storms and resource exhaustion.
//!
//! This module controls the rate of connection establishment to prevent connection
//! storms, respect per-peer limits, and protect against resource exhaustion.
//!
//! # Features
//!
//! - **Token Bucket Algorithm**: Classic token bucket with configurable rate and burst
//! - **Per-Peer Limits**: Individual rate limits for each peer
//! - **Global Limits**: System-wide connection rate limits
//! - **Priority-based Limiting**: Different limits for different priority levels
//! - **Adaptive Rate Limiting**: Adjust rates based on success/failure patterns
//! - **Backpressure Support**: Queue connections when the rate limit is exceeded
//!
//! # Example
//!
//! ```rust
//! use ipfrs_network::rate_limiter::{ConnectionRateLimiter, RateLimiterConfig};
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! // Create a rate limiter allowing 10 connections per second with a burst of 20
//! let limiter = ConnectionRateLimiter::new(RateLimiterConfig {
//!     max_rate: 10.0,
//!     burst_size: 20,
//!     enable_per_peer_limits: true,
//!     ..Default::default()
//! });
//!
//! // Check if a connection is allowed
//! let peer_id = "QmExample".to_string();
//! if limiter.allow_connection(&peer_id).await {
//!     println!("Connection allowed");
//!     // Establish connection...
//! } else {
//!     println!("Rate limit exceeded, queuing...");
//! }
//! # Ok(())
//! # }
//! ```
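//!
//! # Priorities, adaptation, and queue draining
//!
//! Higher-priority connections consume fewer tokens, reported successes and
//! failures drive the optional adaptive rate, and queued peers can be drained
//! once tokens refill. A minimal sketch using the same API (the peer ID is a
//! placeholder):
//!
//! ```rust
//! use ipfrs_network::rate_limiter::{
//!     ConnectionPriority, ConnectionRateLimiter, RateLimiterConfig,
//! };
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let limiter = ConnectionRateLimiter::new(RateLimiterConfig::adaptive());
//!
//! // Critical connections cost half a token, so they are throttled last.
//! let allowed = limiter
//!     .allow_connection_with_priority("QmBootstrapPeer", ConnectionPriority::Critical)
//!     .await;
//!
//! // Report the outcome so adaptive limiting can raise or lower the rate.
//! if allowed {
//!     limiter.record_success("QmBootstrapPeer");
//! } else {
//!     limiter.record_failure("QmBootstrapPeer");
//! }
//!
//! // Periodically drain connections that were queued while rate limited.
//! let ready: Vec<String> = limiter.process_queue().await;
//! println!("{} queued peers can now connect", ready.len());
//! # Ok(())
//! # }
//! ```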

use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;

/// Errors that can occur during rate limiting operations
#[derive(Debug, Error)]
pub enum RateLimiterError {
    #[error("Rate limit exceeded")]
    RateLimitExceeded,

    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),

    #[error("Peer blocked: {0}")]
    PeerBlocked(String),
}

/// Priority level for connections
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ConnectionPriority {
    /// Critical connections (bootstrap nodes, important peers)
    Critical,
    /// High priority connections
    High,
    /// Normal priority connections
    Normal,
    /// Low priority connections
    Low,
}

impl ConnectionPriority {
    /// Get the rate multiplier for this priority level
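    ///
    /// The limiter charges `1.0 / rate_multiplier()` tokens per connection,
    /// so a higher multiplier makes a connection cheaper and therefore
    /// throttled later. For example:
    ///
    /// ```rust
    /// use ipfrs_network::rate_limiter::ConnectionPriority;
    ///
    /// assert_eq!(ConnectionPriority::Critical.rate_multiplier(), 2.0);
    /// assert_eq!(ConnectionPriority::Low.rate_multiplier(), 0.5);
    /// ```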
    pub fn rate_multiplier(&self) -> f64 {
        match self {
            Self::Critical => 2.0, // 2x the base rate
            Self::High => 1.5,     // 1.5x the base rate
            Self::Normal => 1.0,   // Base rate
            Self::Low => 0.5,      // Half the base rate
        }
    }
}

/// Configuration for the connection rate limiter
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimiterConfig {
    /// Maximum connection rate (connections per second)
    pub max_rate: f64,

    /// Maximum burst size (tokens)
    pub burst_size: usize,

    /// Enable per-peer rate limiting
    pub enable_per_peer_limits: bool,

    /// Maximum connections per peer per second
    pub max_per_peer_rate: f64,

    /// Enable adaptive rate limiting
    pub enable_adaptive: bool,

    /// Adjustment factor for adaptive limiting (0.0 to 1.0)
    pub adaptive_factor: f64,

    /// Minimum rate (connections per second) when adapting
    pub min_rate: f64,

    /// Maximum rate (connections per second) when adapting
    pub max_adaptive_rate: f64,

    /// Enable connection queuing when rate limited
    pub enable_queuing: bool,

    /// Maximum queue size for pending connections
    pub max_queue_size: usize,

    /// Time window for per-peer tracking
    pub peer_window: Duration,
}

impl Default for RateLimiterConfig {
    fn default() -> Self {
        Self {
            max_rate: 10.0, // 10 connections/sec
            burst_size: 20, // 20 token burst
            enable_per_peer_limits: true,
            max_per_peer_rate: 2.0, // 2 connections/sec per peer
            enable_adaptive: false,
            adaptive_factor: 0.1,     // 10% adjustment
            min_rate: 1.0,            // Min 1 connection/sec
            max_adaptive_rate: 100.0, // Max 100 connections/sec
            enable_queuing: true,
            max_queue_size: 100,
            peer_window: Duration::from_secs(60), // 1 minute window
        }
    }
}

impl RateLimiterConfig {
    /// Conservative configuration with strict limiting (low rates)
    pub fn conservative() -> Self {
        Self {
            max_rate: 5.0,
            burst_size: 10,
            max_per_peer_rate: 1.0,
            max_queue_size: 50,
            ..Default::default()
        }
    }

    /// Permissive configuration with relaxed limiting (high rates)
    pub fn permissive() -> Self {
        Self {
            max_rate: 50.0,
            burst_size: 100,
            max_per_peer_rate: 10.0,
            max_queue_size: 200,
            ..Default::default()
        }
    }

    /// Configuration with adaptive rate limiting enabled
    pub fn adaptive() -> Self {
        Self {
            enable_adaptive: true,
            adaptive_factor: 0.2,
            min_rate: 2.0,
            max_adaptive_rate: 50.0,
            ..Default::default()
        }
    }
}

/// Per-peer connection tracking
#[derive(Debug, Clone)]
struct PeerTracking {
    /// Connection attempts in current window
    attempts: Vec<Instant>,
    /// Successful connections
    successes: u64,
    /// Failed connections
    failures: u64,
    /// Last connection timestamp
    last_connection: Option<Instant>,
}

impl PeerTracking {
    fn new() -> Self {
        Self {
            attempts: Vec::new(),
            successes: 0,
            failures: 0,
            last_connection: None,
        }
    }

    /// Clean up old attempts outside the window
    fn cleanup(&mut self, window: Duration) {
        // Compare via `duration_since` instead of subtracting the window from
        // `Instant::now()`, which can panic when the window exceeds the
        // representable range.
        let now = Instant::now();
        self.attempts.retain(|&t| now.duration_since(t) < window);
    }

    /// Record a connection attempt
    fn record_attempt(&mut self) {
        let now = Instant::now();
        self.attempts.push(now);
        self.last_connection = Some(now);
    }

    /// Get current rate (attempts per second)
    fn current_rate(&self, window: Duration) -> f64 {
        if self.attempts.is_empty() {
            return 0.0;
        }

        let now = Instant::now();
        let recent = self
            .attempts
            .iter()
            .filter(|&&t| now.duration_since(t) < window)
            .count();

        recent as f64 / window.as_secs_f64()
    }
}

/// Token bucket for rate limiting
#[derive(Debug)]
struct TokenBucket {
    /// Current number of tokens
    tokens: f64,
    /// Maximum tokens (burst size)
    capacity: f64,
    /// Token refill rate (tokens per second)
    rate: f64,
    /// Last refill timestamp
    last_refill: Instant,
}

impl TokenBucket {
    fn new(rate: f64, capacity: usize) -> Self {
        Self {
            tokens: capacity as f64,
            capacity: capacity as f64,
            rate,
            last_refill: Instant::now(),
        }
    }

    /// Refill tokens based on elapsed time
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        let new_tokens = elapsed * self.rate;

        self.tokens = (self.tokens + new_tokens).min(self.capacity);
        self.last_refill = now;
    }

    /// Try to consume `count` tokens
    fn try_consume(&mut self, count: f64) -> bool {
        self.refill();

        if self.tokens >= count {
            self.tokens -= count;
            true
        } else {
            false
        }
    }

    /// Get current token count (after refilling)
    fn available(&mut self) -> f64 {
        self.refill();
        self.tokens
    }

    /// Update rate dynamically
    fn update_rate(&mut self, new_rate: f64) {
        self.refill(); // Refill with old rate first
        self.rate = new_rate;
    }
}

/// Statistics tracked by the rate limiter
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RateLimiterStats {
    /// Total connection attempts
    pub total_attempts: u64,

    /// Connections allowed
    pub allowed: u64,

    /// Connections rate limited
    pub rate_limited: u64,

    /// Connections queued
    pub queued: u64,

    /// Current queue size
    pub current_queue_size: usize,

    /// Average rate (connections per second)
    pub avg_rate: f64,

    /// Current rate limit
    pub current_limit: f64,

    /// Tokens available
    pub tokens_available: f64,
}

/// Connection rate limiter
pub struct ConnectionRateLimiter {
    config: RateLimiterConfig,
    bucket: Arc<RwLock<TokenBucket>>,
    peer_tracking: Arc<RwLock<HashMap<String, PeerTracking>>>,
    stats: Arc<RwLock<RateLimiterStats>>,
    queue: Arc<RwLock<Vec<(String, ConnectionPriority, Instant)>>>,
}

impl ConnectionRateLimiter {
    /// Create a new connection rate limiter
    pub fn new(config: RateLimiterConfig) -> Self {
        let bucket = TokenBucket::new(config.max_rate, config.burst_size);

        Self {
            config,
            bucket: Arc::new(RwLock::new(bucket)),
            peer_tracking: Arc::new(RwLock::new(HashMap::new())),
            stats: Arc::new(RwLock::new(RateLimiterStats::default())),
            queue: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Check if a connection is allowed
    pub async fn allow_connection(&self, peer_id: &str) -> bool {
        self.allow_connection_with_priority(peer_id, ConnectionPriority::Normal)
            .await
    }

    /// Check if a connection is allowed with specific priority
    pub async fn allow_connection_with_priority(
        &self,
        peer_id: &str,
        priority: ConnectionPriority,
    ) -> bool {
        let mut stats = self.stats.write();
        stats.total_attempts += 1;

        // Check per-peer limits
        if self.config.enable_per_peer_limits {
            let mut tracking = self.peer_tracking.write();
            let peer_track = tracking
                .entry(peer_id.to_string())
                .or_insert_with(PeerTracking::new);

            peer_track.cleanup(self.config.peer_window);

            let current_rate = peer_track.current_rate(self.config.peer_window);
            if current_rate >= self.config.max_per_peer_rate {
                stats.rate_limited += 1;
                return false;
            }
        }

        // Check global token bucket; higher priority means a cheaper token cost
        let cost = 1.0 / priority.rate_multiplier();
        let mut bucket = self.bucket.write();

        if bucket.try_consume(cost) {
            // Update tracking
            if self.config.enable_per_peer_limits {
                let mut tracking = self.peer_tracking.write();
                if let Some(peer_track) = tracking.get_mut(peer_id) {
                    peer_track.record_attempt();
                }
            }

            stats.allowed += 1;
            stats.tokens_available = bucket.available();
            true
        } else {
            stats.rate_limited += 1;

            // Release the bucket lock before taking the queue lock so the
            // lock order stays compatible with `process_queue`, which takes
            // the queue lock first.
            drop(bucket);

            // Queue if enabled
            if self.config.enable_queuing {
                let mut queue = self.queue.write();
                if queue.len() < self.config.max_queue_size {
                    queue.push((peer_id.to_string(), priority, Instant::now()));
                    stats.queued += 1;
                    stats.current_queue_size = queue.len();
                }
            }

            false
        }
    }

    /// Record a successful connection
    pub fn record_success(&self, peer_id: &str) {
        if !self.config.enable_per_peer_limits {
            return;
        }

        let tracked = {
            let mut tracking = self.peer_tracking.write();
            match tracking.get_mut(peer_id) {
                Some(peer_track) => {
                    peer_track.successes += 1;
                    true
                }
                None => false,
            }
        };

        // Adapt rate if enabled, after the tracking lock has been released so
        // that no two locks are held at the same time.
        if tracked && self.config.enable_adaptive {
            self.adapt_rate_on_success();
        }
    }

    /// Record a failed connection
    pub fn record_failure(&self, peer_id: &str) {
        if !self.config.enable_per_peer_limits {
            return;
        }

        let tracked = {
            let mut tracking = self.peer_tracking.write();
            match tracking.get_mut(peer_id) {
                Some(peer_track) => {
                    peer_track.failures += 1;
                    true
                }
                None => false,
            }
        };

        // Adapt rate if enabled, after the tracking lock has been released.
        if tracked && self.config.enable_adaptive {
            self.adapt_rate_on_failure();
        }
    }

    /// Adapt rate upward on success
    fn adapt_rate_on_success(&self) {
        let mut bucket = self.bucket.write();
        let current_rate = bucket.rate;
        let new_rate =
            (current_rate * (1.0 + self.config.adaptive_factor)).min(self.config.max_adaptive_rate);

        if new_rate != current_rate {
            bucket.update_rate(new_rate);
            // Release the bucket lock before touching stats to keep the
            // lock acquisition order consistent with the other methods.
            drop(bucket);

            let mut stats = self.stats.write();
            stats.current_limit = new_rate;
        }
    }

    /// Adapt rate downward on failure
    fn adapt_rate_on_failure(&self) {
        let mut bucket = self.bucket.write();
        let current_rate = bucket.rate;
        let new_rate =
            (current_rate * (1.0 - self.config.adaptive_factor)).max(self.config.min_rate);

        if new_rate != current_rate {
            bucket.update_rate(new_rate);
            // Release the bucket lock before touching stats (see above).
            drop(bucket);

            let mut stats = self.stats.write();
            stats.current_limit = new_rate;
        }
    }

    /// Process queued connections
    pub async fn process_queue(&self) -> Vec<String> {
        let mut queue = self.queue.write();
        let mut bucket = self.bucket.write();
        let mut allowed = Vec::new();

        // Sort by priority (highest first), breaking ties by enqueue time
        queue.sort_by(|a, b| {
            let rank = |p: ConnectionPriority| match p {
                ConnectionPriority::Critical => 0,
                ConnectionPriority::High => 1,
                ConnectionPriority::Normal => 2,
                ConnectionPriority::Low => 3,
            };
            rank(a.1).cmp(&rank(b.1)).then(a.2.cmp(&b.2))
        });

        // Process as many as possible
        queue.retain(|(peer_id, priority, _)| {
            let cost = 1.0 / priority.rate_multiplier();
            if bucket.try_consume(cost) {
                allowed.push(peer_id.clone());
                false // Remove from queue
            } else {
                true // Keep in queue
            }
        });

        // Release the queue and bucket locks before updating stats to keep
        // the lock order compatible with `allow_connection_with_priority`.
        let queue_len = queue.len();
        drop(bucket);
        drop(queue);

        self.stats.write().current_queue_size = queue_len;

        allowed
    }

    /// Get current statistics
    pub fn stats(&self) -> RateLimiterStats {
        let mut stats = self.stats.read().clone();

        // Update dynamic stats from the live bucket state
        let mut bucket = self.bucket.write();
        stats.current_limit = bucket.rate;
        stats.tokens_available = bucket.available();

        if stats.total_attempts > 0 {
            stats.avg_rate = stats.allowed as f64 / (stats.total_attempts as f64 / bucket.rate);
        }

        stats
    }

    /// Get per-peer statistics: (successes, failures, current rate)
    pub fn peer_stats(&self, peer_id: &str) -> Option<(u64, u64, f64)> {
        let tracking = self.peer_tracking.read();
        tracking.get(peer_id).map(|track| {
            (
                track.successes,
                track.failures,
                track.current_rate(self.config.peer_window),
            )
        })
    }

    /// Reset rate limiter state
    pub fn reset(&self) {
        // Take each lock on its own so reset never holds two locks at once.
        {
            let mut bucket = self.bucket.write();
            bucket.tokens = bucket.capacity;
            bucket.last_refill = Instant::now();
        }

        self.peer_tracking.write().clear();
        self.queue.write().clear();

        *self.stats.write() = RateLimiterStats::default();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    #[tokio::test]
    async fn test_rate_limiter_creation() {
        let limiter = ConnectionRateLimiter::new(RateLimiterConfig::default());
        let stats = limiter.stats();
        assert_eq!(stats.total_attempts, 0);
    }

    #[tokio::test]
    async fn test_allow_connection() {
        let limiter = ConnectionRateLimiter::new(RateLimiterConfig::default());

        let allowed = limiter.allow_connection("peer1").await;
        assert!(allowed);

        let stats = limiter.stats();
        assert_eq!(stats.allowed, 1);
    }

    #[tokio::test]
    async fn test_rate_limiting() {
        let config = RateLimiterConfig {
            max_rate: 10.0,
            burst_size: 5,
            ..Default::default()
        };
        let limiter = ConnectionRateLimiter::new(config);

        // Use up burst
        for _ in 0..5 {
            assert!(limiter.allow_connection("peer1").await);
        }

        // Next should be rate limited
        let allowed = limiter.allow_connection("peer1").await;
        assert!(!allowed);
    }

    #[tokio::test]
    async fn test_per_peer_limits() {
        let config = RateLimiterConfig {
            max_rate: 100.0,
            burst_size: 100,
            enable_per_peer_limits: true,
            max_per_peer_rate: 5.0, // Allow 5 connections per second
            peer_window: Duration::from_secs(1), // 1 second window for testing
            ..Default::default()
        };
        let limiter = ConnectionRateLimiter::new(config);

        // First 5 connections should succeed (rate = 5/sec)
        for _ in 0..5 {
            assert!(limiter.allow_connection("peer1").await);
        }

        // Sixth should be rate limited (would exceed 5 connections/sec)
        let allowed = limiter.allow_connection("peer1").await;
        assert!(!allowed);

        // Different peer should still work
        assert!(limiter.allow_connection("peer2").await);
    }

    #[tokio::test]
    async fn test_priority() {
        let config = RateLimiterConfig {
            max_rate: 10.0,
            burst_size: 2,
            ..Default::default()
        };
        let limiter = ConnectionRateLimiter::new(config);

        // Critical priority should cost less
        assert!(
            limiter
                .allow_connection_with_priority("peer1", ConnectionPriority::Critical)
                .await
        );
        assert!(
            limiter
                .allow_connection_with_priority("peer2", ConnectionPriority::Critical)
                .await
        );

        // Should still have some tokens due to lower cost
        let stats = limiter.stats();
        assert!(stats.tokens_available > 0.0);
    }

    #[tokio::test]
    async fn test_queuing() {
        let config = RateLimiterConfig {
            max_rate: 1.0,
            burst_size: 1,
            enable_queuing: true,
            max_queue_size: 10,
            ..Default::default()
        };
        let limiter = ConnectionRateLimiter::new(config);

        // First connection allowed
        assert!(limiter.allow_connection("peer1").await);

        // Next connections should be queued
        assert!(!limiter.allow_connection("peer2").await);
        assert!(!limiter.allow_connection("peer3").await);

        let stats = limiter.stats();
        assert_eq!(stats.queued, 2);
    }

    #[tokio::test]
    async fn test_process_queue() {
        let config = RateLimiterConfig {
            max_rate: 10.0,
            burst_size: 1,
            enable_queuing: true,
            ..Default::default()
        };
        let limiter = ConnectionRateLimiter::new(config);

        // Fill bucket
        limiter.allow_connection("peer1").await;

        // Queue some connections
        limiter.allow_connection("peer2").await;
        limiter.allow_connection("peer3").await;

        // Wait for refill
        sleep(Duration::from_millis(200)).await;

        // Process queue
        let allowed = limiter.process_queue().await;
        assert!(!allowed.is_empty());
    }
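
    // Sketch of an additional check (added example): peers queued with
    // Critical priority should be released before lower-priority peers once
    // tokens refill. The peer IDs here are illustrative placeholders.
    #[tokio::test]
    async fn test_queue_priority_ordering() {
        let config = RateLimiterConfig {
            max_rate: 2.0,
            burst_size: 1,
            enable_queuing: true,
            ..Default::default()
        };
        let limiter = ConnectionRateLimiter::new(config);

        // Drain the single burst token.
        assert!(limiter.allow_connection("peer-first").await);

        // Queue a normal-priority peer first, then a critical one.
        assert!(!limiter.allow_connection("peer-normal").await);
        assert!(
            !limiter
                .allow_connection_with_priority("peer-critical", ConnectionPriority::Critical)
                .await
        );

        // After a refill, the critical peer should be drained first.
        sleep(Duration::from_millis(500)).await;
        let allowed = limiter.process_queue().await;
        assert_eq!(allowed.first().map(String::as_str), Some("peer-critical"));
    }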

    #[tokio::test]
    async fn test_success_failure_recording() {
        let config = RateLimiterConfig {
            enable_per_peer_limits: true,
            ..Default::default()
        };
        let limiter = ConnectionRateLimiter::new(config);

        limiter.allow_connection("peer1").await;
        limiter.record_success("peer1");

        let (successes, failures, _) = limiter.peer_stats("peer1").unwrap();
        assert_eq!(successes, 1);
        assert_eq!(failures, 0);
    }

    #[tokio::test]
    async fn test_config_presets() {
        let conservative = RateLimiterConfig::conservative();
        assert!(conservative.max_rate < 10.0);

        let permissive = RateLimiterConfig::permissive();
        assert!(permissive.max_rate > 10.0);

        let adaptive = RateLimiterConfig::adaptive();
        assert!(adaptive.enable_adaptive);
    }

    #[tokio::test]
    async fn test_reset() {
        let limiter = ConnectionRateLimiter::new(RateLimiterConfig::default());

        limiter.allow_connection("peer1").await;
        assert_eq!(limiter.stats().allowed, 1);

        limiter.reset();
        assert_eq!(limiter.stats().allowed, 0);
    }

    #[tokio::test]
    async fn test_token_refill() {
        let config = RateLimiterConfig {
            max_rate: 10.0,
            burst_size: 5,
            ..Default::default()
        };
        let limiter = ConnectionRateLimiter::new(config);

        // Use all tokens
        for _ in 0..5 {
            limiter.allow_connection("peer1").await;
        }

        // Wait for refill
        sleep(Duration::from_millis(200)).await;

        // Should be able to connect again
        assert!(limiter.allow_connection("peer1").await);
    }
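
    // Sketch of an additional check (added example): with adaptive limiting
    // enabled, recorded successes should raise the effective rate limit and
    // recorded failures should lower it again, within the configured bounds.
    #[tokio::test]
    async fn test_adaptive_rate_adjustment() {
        let config = RateLimiterConfig {
            enable_adaptive: true,
            adaptive_factor: 0.5,
            ..Default::default()
        };
        let limiter = ConnectionRateLimiter::new(config);
        let base_limit = limiter.stats().current_limit;

        // The peer must be tracked before successes/failures are counted.
        limiter.allow_connection("peer1").await;

        limiter.record_success("peer1");
        let raised = limiter.stats().current_limit;
        assert!(raised > base_limit);

        limiter.record_failure("peer1");
        let lowered = limiter.stats().current_limit;
        assert!(lowered < raised);
    }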
}