1use parking_lot::RwLock;
44use serde::{Deserialize, Serialize};
45use std::collections::HashMap;
46use std::sync::Arc;
47use std::time::{Duration, Instant};
48use thiserror::Error;
49
50#[derive(Debug, Error)]
52pub enum RateLimiterError {
53 #[error("Rate limit exceeded")]
54 RateLimitExceeded,
55
56 #[error("Invalid configuration: {0}")]
57 InvalidConfig(String),
58
59 #[error("Peer blocked: {0}")]
60 PeerBlocked(String),
61}
62
63#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
65pub enum ConnectionPriority {
66 Critical,
68 High,
70 Normal,
72 Low,
74}
75
76impl ConnectionPriority {
77 pub fn rate_multiplier(&self) -> f64 {
79 match self {
80 Self::Critical => 2.0, Self::High => 1.5, Self::Normal => 1.0, Self::Low => 0.5, }
85 }
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct RateLimiterConfig {
91 pub max_rate: f64,
93
94 pub burst_size: usize,
96
97 pub enable_per_peer_limits: bool,
99
100 pub max_per_peer_rate: f64,
102
103 pub enable_adaptive: bool,
105
106 pub adaptive_factor: f64,
108
109 pub min_rate: f64,
111
112 pub max_adaptive_rate: f64,
114
115 pub enable_queuing: bool,
117
118 pub max_queue_size: usize,
120
121 pub peer_window: Duration,
123}
124
125impl Default for RateLimiterConfig {
126 fn default() -> Self {
127 Self {
128 max_rate: 10.0, burst_size: 20, enable_per_peer_limits: true,
131 max_per_peer_rate: 2.0, enable_adaptive: false,
133 adaptive_factor: 0.1, min_rate: 1.0, max_adaptive_rate: 100.0, enable_queuing: true,
137 max_queue_size: 100,
138 peer_window: Duration::from_secs(60), }
140 }
141}
142
143impl RateLimiterConfig {
144 pub fn conservative() -> Self {
146 Self {
147 max_rate: 5.0,
148 burst_size: 10,
149 max_per_peer_rate: 1.0,
150 max_queue_size: 50,
151 ..Default::default()
152 }
153 }
154
155 pub fn permissive() -> Self {
157 Self {
158 max_rate: 50.0,
159 burst_size: 100,
160 max_per_peer_rate: 10.0,
161 max_queue_size: 200,
162 ..Default::default()
163 }
164 }
165
166 pub fn adaptive() -> Self {
168 Self {
169 enable_adaptive: true,
170 adaptive_factor: 0.2,
171 min_rate: 2.0,
172 max_adaptive_rate: 50.0,
173 ..Default::default()
174 }
175 }
176}
177
/// Per-peer bookkeeping used when per-peer limits are enabled.
#[derive(Debug, Clone)]
struct PeerTracking {
    /// Timestamps of recent attempts; pruned to `peer_window` by `cleanup`.
    attempts: Vec<Instant>,
    /// Count of connections reported successful via `record_success`.
    successes: u64,
    /// Count of connections reported failed via `record_failure`.
    failures: u64,
    /// Time of the most recent recorded attempt, if any.
    last_connection: Option<Instant>,
}
190
191impl PeerTracking {
192 fn new() -> Self {
193 Self {
194 attempts: Vec::new(),
195 successes: 0,
196 failures: 0,
197 last_connection: None,
198 }
199 }
200
201 fn cleanup(&mut self, window: Duration) {
203 let cutoff = Instant::now() - window;
204 self.attempts.retain(|&t| t > cutoff);
205 }
206
207 fn record_attempt(&mut self) {
209 self.attempts.push(Instant::now());
210 self.last_connection = Some(Instant::now());
211 }
212
213 fn current_rate(&self, window: Duration) -> f64 {
215 if self.attempts.is_empty() {
216 return 0.0;
217 }
218
219 let now = Instant::now();
220 let recent = self
221 .attempts
222 .iter()
223 .filter(|&&t| now.duration_since(t) < window)
224 .count();
225
226 recent as f64 / window.as_secs_f64()
227 }
228}
229
/// Classic token bucket: refills continuously at `rate` tokens/second up to
/// `capacity`, and connections consume tokens to proceed.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available (fractional between refills).
    tokens: f64,
    /// Upper bound on `tokens`; also the initial fill.
    capacity: f64,
    /// Refill rate in tokens per second.
    rate: f64,
    /// When `tokens` was last brought up to date.
    last_refill: Instant,
}
242
243impl TokenBucket {
244 fn new(rate: f64, capacity: usize) -> Self {
245 Self {
246 tokens: capacity as f64,
247 capacity: capacity as f64,
248 rate,
249 last_refill: Instant::now(),
250 }
251 }
252
253 fn refill(&mut self) {
255 let now = Instant::now();
256 let elapsed = now.duration_since(self.last_refill).as_secs_f64();
257 let new_tokens = elapsed * self.rate;
258
259 self.tokens = (self.tokens + new_tokens).min(self.capacity);
260 self.last_refill = now;
261 }
262
263 fn try_consume(&mut self, count: f64) -> bool {
265 self.refill();
266
267 if self.tokens >= count {
268 self.tokens -= count;
269 true
270 } else {
271 false
272 }
273 }
274
275 fn available(&mut self) -> f64 {
277 self.refill();
278 self.tokens
279 }
280
281 fn update_rate(&mut self, new_rate: f64) {
283 self.refill(); self.rate = new_rate;
285 }
286}
287
288#[derive(Debug, Clone, Default, Serialize, Deserialize)]
290pub struct RateLimiterStats {
291 pub total_attempts: u64,
293
294 pub allowed: u64,
296
297 pub rate_limited: u64,
299
300 pub queued: u64,
302
303 pub current_queue_size: usize,
305
306 pub avg_rate: f64,
308
309 pub current_limit: f64,
311
312 pub tokens_available: f64,
314}
315
316pub struct ConnectionRateLimiter {
318 config: RateLimiterConfig,
319 bucket: Arc<RwLock<TokenBucket>>,
320 peer_tracking: Arc<RwLock<HashMap<String, PeerTracking>>>,
321 stats: Arc<RwLock<RateLimiterStats>>,
322 queue: Arc<RwLock<Vec<(String, ConnectionPriority, Instant)>>>,
323}
324
325impl ConnectionRateLimiter {
326 pub fn new(config: RateLimiterConfig) -> Self {
328 let bucket = TokenBucket::new(config.max_rate, config.burst_size);
329
330 Self {
331 config,
332 bucket: Arc::new(RwLock::new(bucket)),
333 peer_tracking: Arc::new(RwLock::new(HashMap::new())),
334 stats: Arc::new(RwLock::new(RateLimiterStats::default())),
335 queue: Arc::new(RwLock::new(Vec::new())),
336 }
337 }
338
339 pub async fn allow_connection(&self, peer_id: &str) -> bool {
341 self.allow_connection_with_priority(peer_id, ConnectionPriority::Normal)
342 .await
343 }
344
345 pub async fn allow_connection_with_priority(
347 &self,
348 peer_id: &str,
349 priority: ConnectionPriority,
350 ) -> bool {
351 let mut stats = self.stats.write();
352 stats.total_attempts += 1;
353
354 if self.config.enable_per_peer_limits {
356 let mut tracking = self.peer_tracking.write();
357 let peer_track = tracking
358 .entry(peer_id.to_string())
359 .or_insert_with(PeerTracking::new);
360
361 peer_track.cleanup(self.config.peer_window);
362
363 let current_rate = peer_track.current_rate(self.config.peer_window);
364 if current_rate >= self.config.max_per_peer_rate {
365 stats.rate_limited += 1;
366 return false;
367 }
368 }
369
370 let cost = 1.0 / priority.rate_multiplier();
372 let mut bucket = self.bucket.write();
373
374 if bucket.try_consume(cost) {
375 if self.config.enable_per_peer_limits {
377 let mut tracking = self.peer_tracking.write();
378 if let Some(peer_track) = tracking.get_mut(peer_id) {
379 peer_track.record_attempt();
380 }
381 }
382
383 stats.allowed += 1;
384 stats.tokens_available = bucket.available();
385 true
386 } else {
387 stats.rate_limited += 1;
388
389 if self.config.enable_queuing {
391 let mut queue = self.queue.write();
392 if queue.len() < self.config.max_queue_size {
393 queue.push((peer_id.to_string(), priority, Instant::now()));
394 stats.queued += 1;
395 stats.current_queue_size = queue.len();
396 }
397 }
398
399 false
400 }
401 }
402
403 pub fn record_success(&self, peer_id: &str) {
405 if !self.config.enable_per_peer_limits {
406 return;
407 }
408
409 let mut tracking = self.peer_tracking.write();
410 if let Some(peer_track) = tracking.get_mut(peer_id) {
411 peer_track.successes += 1;
412
413 if self.config.enable_adaptive {
415 self.adapt_rate_on_success();
416 }
417 }
418 }
419
420 pub fn record_failure(&self, peer_id: &str) {
422 if !self.config.enable_per_peer_limits {
423 return;
424 }
425
426 let mut tracking = self.peer_tracking.write();
427 if let Some(peer_track) = tracking.get_mut(peer_id) {
428 peer_track.failures += 1;
429
430 if self.config.enable_adaptive {
432 self.adapt_rate_on_failure();
433 }
434 }
435 }
436
437 fn adapt_rate_on_success(&self) {
439 let mut bucket = self.bucket.write();
440 let current_rate = bucket.rate;
441 let new_rate =
442 (current_rate * (1.0 + self.config.adaptive_factor)).min(self.config.max_adaptive_rate);
443
444 if new_rate != current_rate {
445 bucket.update_rate(new_rate);
446
447 let mut stats = self.stats.write();
448 stats.current_limit = new_rate;
449 }
450 }
451
452 fn adapt_rate_on_failure(&self) {
454 let mut bucket = self.bucket.write();
455 let current_rate = bucket.rate;
456 let new_rate =
457 (current_rate * (1.0 - self.config.adaptive_factor)).max(self.config.min_rate);
458
459 if new_rate != current_rate {
460 bucket.update_rate(new_rate);
461
462 let mut stats = self.stats.write();
463 stats.current_limit = new_rate;
464 }
465 }
466
467 pub async fn process_queue(&self) -> Vec<String> {
469 let mut queue = self.queue.write();
470 let mut bucket = self.bucket.write();
471 let mut allowed = Vec::new();
472
473 queue.sort_by(|a, b| match (a.1, b.1) {
475 (ConnectionPriority::Critical, ConnectionPriority::Critical) => a.2.cmp(&b.2),
476 (ConnectionPriority::Critical, _) => std::cmp::Ordering::Less,
477 (_, ConnectionPriority::Critical) => std::cmp::Ordering::Greater,
478 (ConnectionPriority::High, ConnectionPriority::High) => a.2.cmp(&b.2),
479 (ConnectionPriority::High, _) => std::cmp::Ordering::Less,
480 (_, ConnectionPriority::High) => std::cmp::Ordering::Greater,
481 _ => a.2.cmp(&b.2),
482 });
483
484 queue.retain(|(peer_id, priority, _)| {
486 let cost = 1.0 / priority.rate_multiplier();
487 if bucket.try_consume(cost) {
488 allowed.push(peer_id.clone());
489 false } else {
491 true }
493 });
494
495 let mut stats = self.stats.write();
496 stats.current_queue_size = queue.len();
497
498 allowed
499 }
500
501 pub fn stats(&self) -> RateLimiterStats {
503 let mut stats = self.stats.read().clone();
504
505 let bucket = self.bucket.write();
507 stats.current_limit = bucket.rate;
508 stats.tokens_available = bucket.tokens;
509
510 if stats.total_attempts > 0 {
511 stats.avg_rate = stats.allowed as f64 / (stats.total_attempts as f64 / bucket.rate);
512 }
513
514 stats
515 }
516
517 pub fn peer_stats(&self, peer_id: &str) -> Option<(u64, u64, f64)> {
519 let tracking = self.peer_tracking.read();
520 tracking.get(peer_id).map(|track| {
521 (
522 track.successes,
523 track.failures,
524 track.current_rate(self.config.peer_window),
525 )
526 })
527 }
528
529 pub fn reset(&self) {
531 let mut bucket = self.bucket.write();
532 bucket.tokens = bucket.capacity;
533 bucket.last_refill = Instant::now();
534
535 self.peer_tracking.write().clear();
536 self.queue.write().clear();
537
538 let mut stats = self.stats.write();
539 *stats = RateLimiterStats::default();
540 }
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    // A fresh limiter starts with zeroed counters.
    #[tokio::test]
    async fn test_rate_limiter_creation() {
        let limiter = ConnectionRateLimiter::new(RateLimiterConfig::default());
        let stats = limiter.stats();
        assert_eq!(stats.total_attempts, 0);
    }

    // A single attempt against default limits passes and is counted.
    #[tokio::test]
    async fn test_allow_connection() {
        let limiter = ConnectionRateLimiter::new(RateLimiterConfig::default());

        let allowed = limiter.allow_connection("peer1").await;
        assert!(allowed);

        let stats = limiter.stats();
        assert_eq!(stats.allowed, 1);
    }

    // Exhausting the burst makes the next attempt fail.
    // NOTE(review): relies on the 5 allowed calls finishing before the
    // 10/s refill grants another full token — normally true, but timing-
    // sensitive under heavy load.
    #[tokio::test]
    async fn test_rate_limiting() {
        let config = RateLimiterConfig {
            max_rate: 10.0,
            burst_size: 5,
            ..Default::default()
        };
        let limiter = ConnectionRateLimiter::new(config);

        // Burst of 5 should all succeed.
        for _ in 0..5 {
            assert!(limiter.allow_connection("peer1").await);
        }

        // 6th should be rate limited.
        let allowed = limiter.allow_connection("peer1").await;
        assert!(!allowed);
    }

    // The per-peer window caps one peer without affecting others; the
    // global bucket (100 burst) never interferes here.
    #[tokio::test]
    async fn test_per_peer_limits() {
        let config = RateLimiterConfig {
            max_rate: 100.0,
            burst_size: 100,
            enable_per_peer_limits: true,
            max_per_peer_rate: 5.0,
            peer_window: Duration::from_secs(1),
            ..Default::default()
        };
        let limiter = ConnectionRateLimiter::new(config);

        // 5 attempts within the 1 s window reach the per-peer cap.
        for _ in 0..5 {
            assert!(limiter.allow_connection("peer1").await);
        }

        // peer1 is now over its window rate.
        let allowed = limiter.allow_connection("peer1").await;
        assert!(!allowed);

        // A different peer has an independent window.
        assert!(limiter.allow_connection("peer2").await);
    }

    // Critical priority pays half a token (cost = 1 / 2.0), so a 2-token
    // burst funds two Critical admissions with tokens to spare.
    #[tokio::test]
    async fn test_priority() {
        let config = RateLimiterConfig {
            max_rate: 10.0,
            burst_size: 2,
            ..Default::default()
        };
        let limiter = ConnectionRateLimiter::new(config);

        assert!(
            limiter
                .allow_connection_with_priority("peer1", ConnectionPriority::Critical)
                .await
        );
        assert!(
            limiter
                .allow_connection_with_priority("peer2", ConnectionPriority::Critical)
                .await
        );

        let stats = limiter.stats();
        assert!(stats.tokens_available > 0.0);
    }

    // Rejected attempts land on the queue when queuing is enabled.
    #[tokio::test]
    async fn test_queuing() {
        let config = RateLimiterConfig {
            max_rate: 1.0,
            burst_size: 1,
            enable_queuing: true,
            max_queue_size: 10,
            ..Default::default()
        };
        let limiter = ConnectionRateLimiter::new(config);

        // Consumes the single token.
        assert!(limiter.allow_connection("peer1").await);

        // Both rejected and queued.
        assert!(!limiter.allow_connection("peer2").await);
        assert!(!limiter.allow_connection("peer3").await);

        let stats = limiter.stats();
        assert_eq!(stats.queued, 2);
    }

    // After tokens refill, process_queue admits at least one queued peer.
    #[tokio::test]
    async fn test_process_queue() {
        let config = RateLimiterConfig {
            max_rate: 10.0,
            burst_size: 1,
            enable_queuing: true,
            ..Default::default()
        };
        let limiter = ConnectionRateLimiter::new(config);

        // Drains the 1-token burst.
        limiter.allow_connection("peer1").await;

        // These are rejected and queued.
        limiter.allow_connection("peer2").await;
        limiter.allow_connection("peer3").await;

        // 200 ms at 10 tokens/s refills ~2 tokens.
        sleep(Duration::from_millis(200)).await;

        let allowed = limiter.process_queue().await;
        assert!(!allowed.is_empty());
    }

    // record_success is reflected in peer_stats.
    #[tokio::test]
    async fn test_success_failure_recording() {
        let config = RateLimiterConfig {
            enable_per_peer_limits: true,
            ..Default::default()
        };
        let limiter = ConnectionRateLimiter::new(config);

        limiter.allow_connection("peer1").await;
        limiter.record_success("peer1");

        let (successes, failures, _) = limiter.peer_stats("peer1").unwrap();
        assert_eq!(successes, 1);
        assert_eq!(failures, 0);
    }

    // Sanity-check the three config presets against the default (10.0).
    #[tokio::test]
    async fn test_config_presets() {
        let conservative = RateLimiterConfig::conservative();
        assert!(conservative.max_rate < 10.0);

        let permissive = RateLimiterConfig::permissive();
        assert!(permissive.max_rate > 10.0);

        let adaptive = RateLimiterConfig::adaptive();
        assert!(adaptive.enable_adaptive);
    }

    // reset() zeroes the counters.
    #[tokio::test]
    async fn test_reset() {
        let limiter = ConnectionRateLimiter::new(RateLimiterConfig::default());

        limiter.allow_connection("peer1").await;
        assert_eq!(limiter.stats().allowed, 1);

        limiter.reset();
        assert_eq!(limiter.stats().allowed, 0);
    }

    // Draining the bucket and waiting lets the refill admit another attempt.
    // (200 ms at 10 tokens/s refills ~2 tokens.)
    #[tokio::test]
    async fn test_token_refill() {
        let config = RateLimiterConfig {
            max_rate: 10.0,
            burst_size: 5,
            ..Default::default()
        };
        let limiter = ConnectionRateLimiter::new(config);

        // Drain all 5 burst tokens.
        for _ in 0..5 {
            limiter.allow_connection("peer1").await;
        }

        sleep(Duration::from_millis(200)).await;

        assert!(limiter.allow_connection("peer1").await);
    }
}