1use dashmap::DashMap;
22use parking_lot::Mutex;
23use std::time::{Duration, Instant};
24use tracing::{debug, warn};
25
/// Configuration for a rate limiter built from token buckets.
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Sustained refill rate of each per-client bucket, in tokens per second.
    pub requests_per_second: f64,
    /// Maximum burst capacity of each per-client bucket.
    pub burst_size: u32,
    /// When `true`, each client id gets its own token bucket.
    pub per_client: bool,
    /// Optional cross-client cap; used as both capacity and per-second
    /// refill rate of the global bucket when set.
    pub global_limit: Option<u32>,
    /// How long an untouched per-client bucket survives before cleanup.
    pub idle_timeout: Duration,
}

impl RateLimiterConfig {
    /// Creates a config with per-client limiting enabled, no global limit,
    /// and a 5-minute idle timeout.
    #[must_use]
    pub fn new(requests_per_second: f64, burst_size: u32) -> Self {
        Self {
            requests_per_second,
            burst_size,
            per_client: true,
            global_limit: None,
            idle_timeout: Duration::from_secs(300),
        }
    }

    /// Enables or disables per-client bucket tracking.
    #[must_use]
    pub fn with_per_client(mut self, per_client: bool) -> Self {
        self.per_client = per_client;
        self
    }

    /// Sets a global limit shared by all clients.
    #[must_use]
    pub fn with_global_limit(mut self, limit: u32) -> Self {
        self.global_limit = Some(limit);
        self
    }

    /// Sets how long idle per-client buckets are retained.
    #[must_use]
    pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
        self.idle_timeout = timeout;
        self
    }
}

impl Default for RateLimiterConfig {
    /// 100 requests/second with a burst of 50.
    fn default() -> Self {
        Self::new(100.0, 50)
    }
}
84
/// Error returned when a request exceeds a configured rate limit.
#[derive(Debug, Clone)]
pub enum RateLimitError {
    /// The shared, cross-client limit was exhausted.
    GlobalLimitExceeded {
        retry_after_ms: u64,
    },
    /// One client's own bucket was exhausted.
    ClientLimitExceeded {
        client_id: String,
        retry_after_ms: u64,
    },
}

impl std::fmt::Display for RateLimitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each arm is a single write! expression; the match is the fn body.
        match self {
            Self::GlobalLimitExceeded { retry_after_ms } => write!(
                f,
                "Global rate limit exceeded, retry after {}ms",
                retry_after_ms
            ),
            Self::ClientLimitExceeded {
                client_id,
                retry_after_ms,
            } => write!(
                f,
                "Rate limit exceeded for client '{}', retry after {}ms",
                client_id, retry_after_ms
            ),
        }
    }
}

impl std::error::Error for RateLimitError {}
127
/// A single token bucket: holds at most `max_tokens`, refilled continuously
/// at `refill_rate` tokens per second.
#[derive(Debug)]
pub struct TokenBucket {
    // Current (fractional) token balance.
    tokens: f64,
    // Capacity the balance is clamped to.
    max_tokens: f64,
    // Tokens earned per second of wall-clock time.
    refill_rate: f64,
    // When the balance was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    pub fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Credits tokens for the time elapsed since the previous refill,
    /// clamped to capacity, and advances the refill timestamp.
    fn refill(&mut self) {
        let now = Instant::now();
        let earned = now.duration_since(self.last_refill).as_secs_f64() * self.refill_rate;
        self.tokens = self.max_tokens.min(self.tokens + earned);
        self.last_refill = now;
    }

    /// Consumes one token if available; returns whether it succeeded.
    pub fn try_acquire(&mut self) -> bool {
        self.refill();
        let granted = self.tokens >= 1.0;
        if granted {
            self.tokens -= 1.0;
        }
        granted
    }

    /// Whole tokens currently available (refills first).
    pub fn remaining(&mut self) -> u32 {
        self.refill();
        self.tokens.floor().max(0.0) as u32
    }

    /// Milliseconds until one full token is available: 0 when a token is
    /// already present, `u64::MAX` when the bucket can never refill.
    pub fn retry_after_ms(&self) -> u64 {
        if self.tokens >= 1.0 {
            return 0;
        }
        if self.refill_rate <= 0.0 {
            return u64::MAX;
        }
        let wait_secs = (1.0 - self.tokens) / self.refill_rate;
        (wait_secs * 1000.0).ceil() as u64
    }

    /// Restores the bucket to full capacity and restarts the refill clock.
    pub fn reset(&mut self) {
        self.tokens = self.max_tokens;
        self.last_refill = Instant::now();
    }

    /// Timestamp of the most recent refill; doubles as a last-touched
    /// marker for idle-bucket cleanup.
    pub fn last_access(&self) -> Instant {
        self.last_refill
    }
}
208
/// Token-bucket rate limiter with optional per-client and global limits.
///
/// Per-client buckets are created lazily on first use and stored in a
/// concurrent map; `cleanup_expired_buckets` evicts buckets idle longer
/// than the configured timeout.
pub struct RateLimiter {
    // Settings captured at construction time.
    config: RateLimiterConfig,
    // Shared bucket; only consulted when `config.global_limit` is set.
    global_bucket: Mutex<TokenBucket>,
    // One bucket per client id, created on demand in `check_rate_limit`.
    client_buckets: DashMap<String, Mutex<TokenBucket>>,
}
221
222impl RateLimiter {
223 pub fn new(config: RateLimiterConfig) -> Self {
225 let global_max = config
226 .global_limit
227 .map(f64::from)
228 .unwrap_or(config.requests_per_second * 2.0);
229 let global_rate = config
230 .global_limit
231 .map(f64::from)
232 .unwrap_or(config.requests_per_second * 2.0);
233
234 Self {
235 config: config.clone(),
236 global_bucket: Mutex::new(TokenBucket::new(global_max, global_rate)),
237 client_buckets: DashMap::new(),
238 }
239 }
240
241 pub fn check_rate_limit(&self, client_id: &str) -> Result<(), RateLimitError> {
247 if self.config.per_client {
249 let bucket = self
250 .client_buckets
251 .entry(client_id.to_string())
252 .or_insert_with(|| {
253 Mutex::new(TokenBucket::new(
254 f64::from(self.config.burst_size),
255 self.config.requests_per_second,
256 ))
257 });
258
259 let mut bucket_guard = bucket.lock();
260 if !bucket_guard.try_acquire() {
261 let retry_after_ms = bucket_guard.retry_after_ms();
262 debug!(
263 client_id = %client_id,
264 retry_after_ms = retry_after_ms,
265 "Per-client rate limit exceeded"
266 );
267 return Err(RateLimitError::ClientLimitExceeded {
268 client_id: client_id.to_string(),
269 retry_after_ms,
270 });
271 }
272 }
273
274 if self.config.global_limit.is_some() {
276 let mut global = self.global_bucket.lock();
277 if !global.try_acquire() {
278 let retry_after_ms = global.retry_after_ms();
279 warn!(
280 client_id = %client_id,
281 retry_after_ms = retry_after_ms,
282 "Global rate limit exceeded"
283 );
284 return Err(RateLimitError::GlobalLimitExceeded { retry_after_ms });
285 }
286 }
287
288 Ok(())
289 }
290
291 pub fn try_acquire(&self, client_id: &str) -> bool {
295 self.check_rate_limit(client_id).is_ok()
296 }
297
298 pub fn remaining_tokens(&self, client_id: &str) -> u32 {
303 if self.config.per_client {
304 if let Some(bucket) = self.client_buckets.get(client_id) {
305 return bucket.lock().remaining();
306 }
307 return self.config.burst_size;
309 }
310
311 self.global_bucket.lock().remaining()
313 }
314
315 pub fn cleanup_expired_buckets(&self) -> usize {
319 let now = Instant::now();
320 let timeout = self.config.idle_timeout;
321 let mut removed = 0;
322
323 let expired_keys: Vec<String> = self
325 .client_buckets
326 .iter()
327 .filter_map(|entry| {
328 let bucket = entry.value().lock();
329 if now.duration_since(bucket.last_access()) > timeout {
330 Some(entry.key().clone())
331 } else {
332 None
333 }
334 })
335 .collect();
336
337 for key in &expired_keys {
338 if let Some((_k, bucket)) = self.client_buckets.remove(key) {
340 let guard = bucket.lock();
341 if now.duration_since(guard.last_access()) > timeout {
342 removed += 1;
343 debug!(client_id = %key, "Cleaned up expired rate limiter bucket");
344 } else {
345 drop(guard);
347 self.client_buckets.insert(key.clone(), bucket);
348 }
349 }
350 }
351
352 if removed > 0 {
353 debug!(count = removed, "Cleaned up expired rate limiter buckets");
354 }
355
356 removed
357 }
358
359 pub fn reset(&self, client_id: &str) {
361 if let Some(bucket) = self.client_buckets.get(client_id) {
362 bucket.lock().reset();
363 }
364 }
365
366 pub fn reset_all(&self) {
368 self.global_bucket.lock().reset();
369 self.client_buckets.clear();
370 }
371
372 pub fn tracked_client_count(&self) -> usize {
374 self.client_buckets.len()
375 }
376
377 pub fn config(&self) -> &RateLimiterConfig {
379 &self.config
380 }
381}
382
383impl std::fmt::Debug for RateLimiter {
384 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385 f.debug_struct("RateLimiter")
386 .field("config", &self.config)
387 .field("tracked_clients", &self.client_buckets.len())
388 .finish()
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_token_bucket_basic() {
        let mut bucket = TokenBucket::new(5.0, 10.0);

        // Drain the entire initial burst.
        for _ in 0..5 {
            assert!(
                bucket.try_acquire(),
                "Should acquire token from full bucket"
            );
        }

        assert!(!bucket.try_acquire(), "Should fail when bucket is depleted");
    }

    #[test]
    fn test_token_bucket_refill() {
        let mut bucket = TokenBucket::new(3.0, 100.0);
        for _ in 0..3 {
            assert!(bucket.try_acquire());
        }
        assert!(!bucket.try_acquire(), "Bucket should be empty");

        // At 100 tokens/s, 25ms earns back ~2.5 tokens.
        thread::sleep(Duration::from_millis(25));

        assert!(
            bucket.try_acquire(),
            "Should have refilled at least one token after 25ms at 100/s"
        );
    }

    #[test]
    fn test_token_bucket_remaining() {
        let mut bucket = TokenBucket::new(10.0, 1.0);
        assert_eq!(bucket.remaining(), 10);

        assert!(bucket.try_acquire());
        assert_eq!(bucket.remaining(), 9);
    }

    #[test]
    fn test_token_bucket_retry_after() {
        let mut bucket = TokenBucket::new(1.0, 10.0);
        assert!(bucket.try_acquire());
        assert!(!bucket.try_acquire());

        // A refill rate of 10/s means roughly 100ms until the next token.
        let retry = bucket.retry_after_ms();
        assert!(
            retry <= 110,
            "retry_after_ms should be approximately 100ms, got {}",
            retry
        );
        assert!(retry > 0, "retry_after_ms should be > 0 when depleted");
    }

    #[test]
    fn test_token_bucket_reset() {
        let mut bucket = TokenBucket::new(5.0, 1.0);

        for _ in 0..5 {
            assert!(bucket.try_acquire());
        }
        assert!(!bucket.try_acquire());

        // Reset restores the full burst immediately.
        bucket.reset();
        assert_eq!(bucket.remaining(), 5);
        assert!(bucket.try_acquire());
    }

    #[test]
    fn test_per_client_isolation() {
        let config = RateLimiterConfig::new(1000.0, 3).with_per_client(true);
        let limiter = RateLimiter::new(config);

        for _ in 0..3 {
            assert!(limiter.check_rate_limit("client-a").is_ok());
        }
        assert!(
            limiter.check_rate_limit("client-a").is_err(),
            "Client A should be rate limited"
        );

        assert!(
            limiter.check_rate_limit("client-b").is_ok(),
            "Client B should not be affected by Client A's limit"
        );
    }

    #[test]
    fn test_global_limit() {
        let config = RateLimiterConfig::new(1000.0, 10)
            .with_per_client(false)
            .with_global_limit(3);
        let limiter = RateLimiter::new(config);

        assert!(limiter.check_rate_limit("client-a").is_ok());
        assert!(limiter.check_rate_limit("client-b").is_ok());
        assert!(limiter.check_rate_limit("client-c").is_ok());

        let result = limiter.check_rate_limit("client-d");
        assert!(result.is_err(), "Global limit should be enforced");
        match result {
            Err(RateLimitError::GlobalLimitExceeded { retry_after_ms }) => {
                assert!(retry_after_ms > 0);
            }
            other => panic!("Expected GlobalLimitExceeded, got {:?}", other),
        }
    }

    #[test]
    fn test_burst_handling() {
        let config = RateLimiterConfig::new(10.0, 20).with_per_client(true);
        let limiter = RateLimiter::new(config);

        let mut allowed = 0;
        for _ in 0..25 {
            if limiter.check_rate_limit("burst-client").is_ok() {
                allowed += 1;
            }
        }

        assert_eq!(
            allowed, 20,
            "Should allow exactly burst_size requests in a burst"
        );
    }

    #[test]
    fn test_cleanup_expired() {
        let config = RateLimiterConfig::new(100.0, 5)
            .with_per_client(true)
            .with_idle_timeout(Duration::from_millis(50));
        let limiter = RateLimiter::new(config);

        assert!(limiter.check_rate_limit("client-1").is_ok());
        assert!(limiter.check_rate_limit("client-2").is_ok());
        assert_eq!(limiter.tracked_client_count(), 2);

        // Sleep past the idle timeout so both buckets become stale.
        thread::sleep(Duration::from_millis(80));

        let removed = limiter.cleanup_expired_buckets();
        assert_eq!(removed, 2, "Both idle clients should be cleaned up");
        assert_eq!(limiter.tracked_client_count(), 0);
    }

    #[test]
    fn test_cleanup_keeps_active() {
        let config = RateLimiterConfig::new(100.0, 5)
            .with_per_client(true)
            .with_idle_timeout(Duration::from_millis(100));
        let limiter = RateLimiter::new(config);

        assert!(limiter.check_rate_limit("active-client").is_ok());

        thread::sleep(Duration::from_millis(30));

        assert!(limiter.check_rate_limit("active-client").is_ok());

        assert!(limiter.check_rate_limit("idle-client").is_ok());

        thread::sleep(Duration::from_millis(120));

        // Touch the active client again so only the idle one expires.
        assert!(limiter.check_rate_limit("active-client").is_ok());

        let removed = limiter.cleanup_expired_buckets();
        assert_eq!(removed, 1, "Only idle client should be cleaned up");
        assert_eq!(limiter.tracked_client_count(), 1);
    }

    #[test]
    fn test_rate_limit_error_display() {
        let global_err = RateLimitError::GlobalLimitExceeded { retry_after_ms: 42 };
        let msg = format!("{}", global_err);
        assert!(msg.contains("Global rate limit exceeded"));
        assert!(msg.contains("42ms"));

        let client_err = RateLimitError::ClientLimitExceeded {
            client_id: "test-client".to_string(),
            retry_after_ms: 100,
        };
        let msg = format!("{}", client_err);
        assert!(msg.contains("test-client"));
        assert!(msg.contains("100ms"));
    }

    #[test]
    fn test_rate_limit_error_details() {
        let config = RateLimiterConfig::new(10.0, 2).with_per_client(true);
        let limiter = RateLimiter::new(config);

        assert!(limiter.check_rate_limit("err-client").is_ok());
        assert!(limiter.check_rate_limit("err-client").is_ok());

        let result = limiter.check_rate_limit("err-client");
        match result {
            Err(RateLimitError::ClientLimitExceeded {
                client_id,
                retry_after_ms,
            }) => {
                assert_eq!(client_id, "err-client");
                assert!(retry_after_ms > 0);
            }
            other => panic!("Expected ClientLimitExceeded, got {:?}", other),
        }
    }

    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;

        let config = RateLimiterConfig::new(1000.0, 100).with_per_client(true);
        let limiter = Arc::new(RateLimiter::new(config));

        let mut handles = Vec::new();
        for i in 0..8 {
            let limiter = Arc::clone(&limiter);
            let handle = thread::spawn(move || {
                let client_id = format!("thread-client-{}", i);
                let mut allowed = 0u32;
                for _ in 0..50 {
                    if limiter.check_rate_limit(&client_id).is_ok() {
                        allowed += 1;
                    }
                }
                allowed
            });
            handles.push(handle);
        }

        let mut total_allowed = 0u32;
        for handle in handles {
            let count = handle.join().expect("Thread panicked");
            total_allowed += count;
        }

        // Each thread stays under its own burst of 100, so nothing is denied.
        assert_eq!(
            total_allowed, 400,
            "All requests should be allowed (50 per thread * 8 threads)"
        );
    }

    #[test]
    fn test_concurrent_same_client() {
        use std::sync::Arc;

        // Tiny refill rate so no tokens are regained mid-test.
        let config = RateLimiterConfig::new(0.001, 50).with_per_client(true);
        let limiter = Arc::new(RateLimiter::new(config));

        let mut handles = Vec::new();
        for _ in 0..4 {
            let limiter = Arc::clone(&limiter);
            let handle = thread::spawn(move || {
                let mut allowed = 0u32;
                for _ in 0..20 {
                    if limiter.check_rate_limit("shared-client").is_ok() {
                        allowed += 1;
                    }
                }
                allowed
            });
            handles.push(handle);
        }

        let mut total_allowed = 0u32;
        for handle in handles {
            let count = handle.join().expect("Thread panicked");
            total_allowed += count;
        }

        assert_eq!(
            total_allowed, 50,
            "Total allowed should equal burst_size for shared client"
        );
    }

    #[test]
    fn test_try_acquire_convenience() {
        let config = RateLimiterConfig::new(100.0, 2).with_per_client(true);
        let limiter = RateLimiter::new(config);

        assert!(limiter.try_acquire("client-x"));
        assert!(limiter.try_acquire("client-x"));
        assert!(!limiter.try_acquire("client-x"));
    }

    #[test]
    fn test_remaining_tokens() {
        let config = RateLimiterConfig::new(100.0, 5).with_per_client(true);
        let limiter = RateLimiter::new(config);

        // An unseen client reports the full burst budget.
        assert_eq!(limiter.remaining_tokens("new-client"), 5);

        assert!(limiter.check_rate_limit("new-client").is_ok());
        assert_eq!(limiter.remaining_tokens("new-client"), 4);
    }

    #[test]
    fn test_reset_client() {
        let config = RateLimiterConfig::new(100.0, 3).with_per_client(true);
        let limiter = RateLimiter::new(config);

        for _ in 0..3 {
            assert!(limiter.check_rate_limit("reset-client").is_ok());
        }
        assert!(limiter.check_rate_limit("reset-client").is_err());

        limiter.reset("reset-client");
        assert!(
            limiter.check_rate_limit("reset-client").is_ok(),
            "Should be able to make requests after reset"
        );
    }

    #[test]
    fn test_reset_all() {
        let config = RateLimiterConfig::new(100.0, 2)
            .with_per_client(true)
            .with_global_limit(5);
        let limiter = RateLimiter::new(config);

        assert!(limiter.check_rate_limit("a").is_ok());
        assert!(limiter.check_rate_limit("b").is_ok());
        assert_eq!(limiter.tracked_client_count(), 2);

        limiter.reset_all();
        assert_eq!(limiter.tracked_client_count(), 0);
    }

    #[test]
    fn test_config_default() {
        let config = RateLimiterConfig::default();
        assert!((config.requests_per_second - 100.0).abs() < f64::EPSILON);
        assert_eq!(config.burst_size, 50);
        assert!(config.per_client);
        assert!(config.global_limit.is_none());
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
    }

    #[test]
    fn test_config_builder_pattern() {
        let config = RateLimiterConfig::new(200.0, 100)
            .with_per_client(false)
            .with_global_limit(500)
            .with_idle_timeout(Duration::from_secs(60));

        assert!((config.requests_per_second - 200.0).abs() < f64::EPSILON);
        assert_eq!(config.burst_size, 100);
        assert!(!config.per_client);
        assert_eq!(config.global_limit, Some(500));
        assert_eq!(config.idle_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_debug_impl() {
        let config = RateLimiterConfig::new(50.0, 10);
        let limiter = RateLimiter::new(config);
        let debug_str = format!("{:?}", limiter);
        assert!(debug_str.contains("RateLimiter"));
        assert!(debug_str.contains("tracked_clients"));
    }

    #[test]
    fn test_global_and_per_client_combined() {
        // Tiny refill rate so tokens are not regained mid-test.
        let config = RateLimiterConfig::new(0.001, 3)
            .with_per_client(true)
            .with_global_limit(5);
        let limiter = RateLimiter::new(config);

        assert!(limiter.check_rate_limit("a").is_ok());
        assert!(limiter.check_rate_limit("a").is_ok());
        assert!(limiter.check_rate_limit("a").is_ok());
        assert!(
            limiter.check_rate_limit("a").is_err(),
            "Client A should hit per-client limit"
        );

        // Client A consumed 3 of the 5 global tokens; B's first two pass.
        assert!(limiter.check_rate_limit("b").is_ok());
        assert!(limiter.check_rate_limit("b").is_ok());
        let result = limiter.check_rate_limit("b");
        assert!(result.is_err(), "Should hit global limit");
        assert!(
            matches!(result, Err(RateLimitError::GlobalLimitExceeded { .. })),
            "Error should be GlobalLimitExceeded"
        );
    }

    #[test]
    fn test_zero_refill_rate_retry_after() {
        let bucket = TokenBucket::new(0.0, 0.0);
        assert_eq!(bucket.retry_after_ms(), u64::MAX);
    }
}