// amaters_net/rate_limiter.rs
1//! Token bucket rate limiter for the network/server layer
2//!
3//! Provides per-client and global rate limiting using a token bucket algorithm.
4//! Designed for concurrent access in a gRPC server environment.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use amaters_net::rate_limiter::{RateLimiter, RateLimiterConfig};
10//!
11//! let config = RateLimiterConfig::new(100.0, 50);
12//! let limiter = RateLimiter::new(config);
13//!
14//! // Check rate limit for a client
15//! match limiter.check_rate_limit("client-1") {
16//!     Ok(()) => { /* request allowed */ }
17//!     Err(e) => { /* rate limited */ }
18//! }
19//! ```
20
21use dashmap::DashMap;
22use parking_lot::Mutex;
23use std::time::{Duration, Instant};
24use tracing::{debug, warn};
25
/// Rate limiter configuration
///
/// Controls the sustained rate, burst capacity, per-client tracking,
/// an optional global cap, and how long idle client buckets are retained.
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Maximum sustained requests per second
    pub requests_per_second: f64,
    /// Maximum burst size (peak capacity above sustained rate)
    pub burst_size: u32,
    /// Whether to track per-client limits
    pub per_client: bool,
    /// Optional global limit (requests per second across all clients)
    pub global_limit: Option<u32>,
    /// Duration after which idle client buckets are eligible for cleanup
    pub idle_timeout: Duration,
}

impl RateLimiterConfig {
    /// Create a new rate limiter configuration
    ///
    /// Defaults: per-client tracking enabled, no global limit, and a
    /// 5-minute idle timeout for bucket cleanup.
    ///
    /// # Arguments
    /// * `requests_per_second` - Sustained request rate
    /// * `burst_size` - Maximum burst capacity
    #[must_use]
    pub fn new(requests_per_second: f64, burst_size: u32) -> Self {
        Self {
            requests_per_second,
            burst_size,
            per_client: true,
            global_limit: None,
            idle_timeout: Duration::from_secs(300), // 5 minutes default
        }
    }

    /// Set whether to use per-client tracking
    #[must_use]
    pub fn with_per_client(mut self, per_client: bool) -> Self {
        self.per_client = per_client;
        self
    }

    /// Set a global rate limit
    #[must_use]
    pub fn with_global_limit(mut self, limit: u32) -> Self {
        self.global_limit = Some(limit);
        self
    }

    /// Set the idle timeout for client bucket cleanup
    #[must_use]
    pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
        self.idle_timeout = timeout;
        self
    }
}

impl Default for RateLimiterConfig {
    /// 100 req/s sustained with a burst of 50 (see [`RateLimiterConfig::new`]).
    fn default() -> Self {
        Self::new(100.0, 50)
    }
}
84
/// Error returned when a rate limit is exceeded
///
/// Derives `PartialEq`/`Eq` so callers can compare errors directly instead
/// of pattern-matching with `matches!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RateLimitError {
    /// The global rate limit has been exceeded
    GlobalLimitExceeded {
        /// Milliseconds the client should wait before retrying
        retry_after_ms: u64,
    },
    /// A per-client rate limit has been exceeded
    ClientLimitExceeded {
        /// Identifier of the rate-limited client
        client_id: String,
        /// Milliseconds the client should wait before retrying
        retry_after_ms: u64,
    },
}

impl std::fmt::Display for RateLimitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RateLimitError::GlobalLimitExceeded { retry_after_ms } => {
                write!(
                    f,
                    "Global rate limit exceeded, retry after {}ms",
                    retry_after_ms
                )
            }
            RateLimitError::ClientLimitExceeded {
                client_id,
                retry_after_ms,
            } => {
                write!(
                    f,
                    "Rate limit exceeded for client '{}', retry after {}ms",
                    client_id, retry_after_ms
                )
            }
        }
    }
}

impl std::error::Error for RateLimitError {}
127
/// Token bucket implementation
///
/// Tracks available tokens and refills them at a constant rate.
/// Allows bursts up to `max_tokens` followed by sustained usage at `refill_rate`.
#[derive(Debug)]
pub struct TokenBucket {
    /// Current number of available tokens (can be fractional during refill)
    tokens: f64,
    /// Maximum token capacity (determines burst size)
    max_tokens: f64,
    /// Tokens added per second
    refill_rate: f64,
    /// Timestamp of the last refill calculation
    last_refill: Instant,
}

impl TokenBucket {
    /// Create a new token bucket, starting at full capacity
    ///
    /// # Arguments
    /// * `max_tokens` - Maximum capacity (burst size)
    /// * `refill_rate` - Tokens per second refill rate
    pub fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Refill tokens based on elapsed time since last refill
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        // Cap at max_tokens so idle periods cannot bank unbounded burst credit.
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_refill = now;
    }

    /// Try to consume one token, returning true if successful
    pub fn try_acquire(&mut self) -> bool {
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Get the number of remaining whole tokens (after refill)
    pub fn remaining(&mut self) -> u32 {
        self.refill();
        // Float-to-int `as` casts saturate in Rust, so oversized values clamp
        // to u32::MAX rather than wrapping.
        self.tokens.floor().max(0.0) as u32
    }

    /// Calculate how many milliseconds until at least one token is available
    ///
    /// Projects the refill forward from `last_refill` without mutating state
    /// (this method takes `&self` and therefore cannot call `refill`).
    /// The previous implementation ignored time elapsed since the last refill
    /// and could overestimate the wait; projecting keeps the hint accurate.
    ///
    /// Returns `0` if a token is already available, or `u64::MAX` if the
    /// refill rate is non-positive (a token will never become available).
    pub fn retry_after_ms(&self) -> u64 {
        let elapsed = self.last_refill.elapsed().as_secs_f64();
        let projected = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        if projected >= 1.0 {
            return 0;
        }
        if self.refill_rate <= 0.0 {
            return u64::MAX;
        }
        let seconds = (1.0 - projected) / self.refill_rate;
        (seconds * 1000.0).ceil() as u64
    }

    /// Reset the bucket to full capacity
    pub fn reset(&mut self) {
        self.tokens = self.max_tokens;
        self.last_refill = Instant::now();
    }

    /// Return the last time this bucket was accessed (for idle detection)
    ///
    /// "Accessed" means any operation that triggered a refill or reset.
    pub fn last_access(&self) -> Instant {
        self.last_refill
    }
}
208
/// Rate limiter supporting per-client and global limits
///
/// Thread-safe: uses `parking_lot::Mutex` for the global bucket and
/// `DashMap` with per-entry `Mutex` for client buckets.
pub struct RateLimiter {
    /// Configuration
    config: RateLimiterConfig,
    /// Global token bucket (shared across all clients).
    /// Only consulted by `check_rate_limit` when `config.global_limit` is set.
    global_bucket: Mutex<TokenBucket>,
    /// Per-client token buckets, keyed by caller-supplied client id.
    /// Entries are created lazily on first use and removed by
    /// `cleanup_expired_buckets` after `config.idle_timeout` of inactivity.
    client_buckets: DashMap<String, Mutex<TokenBucket>>,
}
221
222impl RateLimiter {
223    /// Create a new rate limiter with the given configuration
224    pub fn new(config: RateLimiterConfig) -> Self {
225        let global_max = config
226            .global_limit
227            .map(f64::from)
228            .unwrap_or(config.requests_per_second * 2.0);
229        let global_rate = config
230            .global_limit
231            .map(f64::from)
232            .unwrap_or(config.requests_per_second * 2.0);
233
234        Self {
235            config: config.clone(),
236            global_bucket: Mutex::new(TokenBucket::new(global_max, global_rate)),
237            client_buckets: DashMap::new(),
238        }
239    }
240
241    /// Check rate limit for a given client, returning an error if exceeded
242    ///
243    /// Checks the per-client limit first, then the global limit (if enabled).
244    /// Tokens are only consumed from the global bucket if the per-client check passes,
245    /// ensuring that a per-client rejection does not waste global capacity.
246    pub fn check_rate_limit(&self, client_id: &str) -> Result<(), RateLimitError> {
247        // Check per-client limit first (cheaper, more common rejection reason)
248        if self.config.per_client {
249            let bucket = self
250                .client_buckets
251                .entry(client_id.to_string())
252                .or_insert_with(|| {
253                    Mutex::new(TokenBucket::new(
254                        f64::from(self.config.burst_size),
255                        self.config.requests_per_second,
256                    ))
257                });
258
259            let mut bucket_guard = bucket.lock();
260            if !bucket_guard.try_acquire() {
261                let retry_after_ms = bucket_guard.retry_after_ms();
262                debug!(
263                    client_id = %client_id,
264                    retry_after_ms = retry_after_ms,
265                    "Per-client rate limit exceeded"
266                );
267                return Err(RateLimitError::ClientLimitExceeded {
268                    client_id: client_id.to_string(),
269                    retry_after_ms,
270                });
271            }
272        }
273
274        // Check global limit (only if per-client check passed)
275        if self.config.global_limit.is_some() {
276            let mut global = self.global_bucket.lock();
277            if !global.try_acquire() {
278                let retry_after_ms = global.retry_after_ms();
279                warn!(
280                    client_id = %client_id,
281                    retry_after_ms = retry_after_ms,
282                    "Global rate limit exceeded"
283                );
284                return Err(RateLimitError::GlobalLimitExceeded { retry_after_ms });
285            }
286        }
287
288        Ok(())
289    }
290
291    /// Try to acquire a token for the given client, returning true if allowed
292    ///
293    /// This is a convenience wrapper around `check_rate_limit`.
294    pub fn try_acquire(&self, client_id: &str) -> bool {
295        self.check_rate_limit(client_id).is_ok()
296    }
297
298    /// Get the number of remaining tokens for a client
299    ///
300    /// Returns the per-client remaining tokens, or if per-client tracking is
301    /// disabled, returns the global remaining tokens.
302    pub fn remaining_tokens(&self, client_id: &str) -> u32 {
303        if self.config.per_client {
304            if let Some(bucket) = self.client_buckets.get(client_id) {
305                return bucket.lock().remaining();
306            }
307            // Client hasn't been seen yet, return full burst capacity
308            return self.config.burst_size;
309        }
310
311        // Fall back to global bucket
312        self.global_bucket.lock().remaining()
313    }
314
315    /// Remove client buckets that have been idle longer than the configured timeout
316    ///
317    /// Returns the number of buckets removed.
318    pub fn cleanup_expired_buckets(&self) -> usize {
319        let now = Instant::now();
320        let timeout = self.config.idle_timeout;
321        let mut removed = 0;
322
323        // Collect keys to remove (avoid holding DashMap shard locks during removal)
324        let expired_keys: Vec<String> = self
325            .client_buckets
326            .iter()
327            .filter_map(|entry| {
328                let bucket = entry.value().lock();
329                if now.duration_since(bucket.last_access()) > timeout {
330                    Some(entry.key().clone())
331                } else {
332                    None
333                }
334            })
335            .collect();
336
337        for key in &expired_keys {
338            // Re-check under removal to avoid race conditions
339            if let Some((_k, bucket)) = self.client_buckets.remove(key) {
340                let guard = bucket.lock();
341                if now.duration_since(guard.last_access()) > timeout {
342                    removed += 1;
343                    debug!(client_id = %key, "Cleaned up expired rate limiter bucket");
344                } else {
345                    // Bucket was accessed between our scan and removal; put it back
346                    drop(guard);
347                    self.client_buckets.insert(key.clone(), bucket);
348                }
349            }
350        }
351
352        if removed > 0 {
353            debug!(count = removed, "Cleaned up expired rate limiter buckets");
354        }
355
356        removed
357    }
358
359    /// Reset the rate limiter state for a specific client
360    pub fn reset(&self, client_id: &str) {
361        if let Some(bucket) = self.client_buckets.get(client_id) {
362            bucket.lock().reset();
363        }
364    }
365
366    /// Reset all rate limiter state (global and all clients)
367    pub fn reset_all(&self) {
368        self.global_bucket.lock().reset();
369        self.client_buckets.clear();
370    }
371
372    /// Get the current number of tracked clients
373    pub fn tracked_client_count(&self) -> usize {
374        self.client_buckets.len()
375    }
376
377    /// Get a reference to the configuration
378    pub fn config(&self) -> &RateLimiterConfig {
379        &self.config
380    }
381}
382
383impl std::fmt::Debug for RateLimiter {
384    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385        f.debug_struct("RateLimiter")
386            .field("config", &self.config)
387            .field("tracked_clients", &self.client_buckets.len())
388            .finish()
389    }
390}
391
#[cfg(test)]
mod tests {
    //! Unit tests for `TokenBucket` and `RateLimiter`.
    //!
    //! NOTE(review): several tests are timing-sensitive — they rely on
    //! `thread::sleep` and on loop bodies completing quickly relative to the
    //! configured refill rates, so they may be flaky on a heavily loaded
    //! machine. Tests that must be deterministic use a near-zero refill rate.
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_token_bucket_basic() {
        let mut bucket = TokenBucket::new(5.0, 10.0);

        // Should be able to acquire 5 tokens (burst size)
        for _ in 0..5 {
            assert!(
                bucket.try_acquire(),
                "Should acquire token from full bucket"
            );
        }

        // 6th should fail (no time for refill)
        assert!(!bucket.try_acquire(), "Should fail when bucket is depleted");
    }

    #[test]
    fn test_token_bucket_refill() {
        let mut bucket = TokenBucket::new(3.0, 100.0); // 100 tokens/sec

        // Drain all tokens
        for _ in 0..3 {
            assert!(bucket.try_acquire());
        }
        assert!(!bucket.try_acquire(), "Bucket should be empty");

        // Wait for refill (at 100 tokens/sec, 20ms should give ~2 tokens)
        thread::sleep(Duration::from_millis(25));

        assert!(
            bucket.try_acquire(),
            "Should have refilled at least one token after 25ms at 100/s"
        );
    }

    #[test]
    fn test_token_bucket_remaining() {
        // Refill rate of 1/sec is slow enough that no token refills mid-test.
        let mut bucket = TokenBucket::new(10.0, 1.0);
        assert_eq!(bucket.remaining(), 10);

        assert!(bucket.try_acquire());
        assert_eq!(bucket.remaining(), 9);
    }

    #[test]
    fn test_token_bucket_retry_after() {
        let mut bucket = TokenBucket::new(1.0, 10.0); // 10 tokens/sec

        // Drain
        assert!(bucket.try_acquire());
        assert!(!bucket.try_acquire());

        let retry = bucket.retry_after_ms();
        // At 10 tokens/sec, refill one token takes 100ms
        // retry_after should be <= 100ms (accounting for time elapsed during test)
        assert!(
            retry <= 110,
            "retry_after_ms should be approximately 100ms, got {}",
            retry
        );
        assert!(retry > 0, "retry_after_ms should be > 0 when depleted");
    }

    #[test]
    fn test_token_bucket_reset() {
        let mut bucket = TokenBucket::new(5.0, 1.0);

        // Drain
        for _ in 0..5 {
            assert!(bucket.try_acquire());
        }
        assert!(!bucket.try_acquire());

        // Reset
        bucket.reset();
        assert_eq!(bucket.remaining(), 5);
        assert!(bucket.try_acquire());
    }

    #[test]
    fn test_per_client_isolation() {
        let config = RateLimiterConfig::new(1000.0, 3).with_per_client(true);
        let limiter = RateLimiter::new(config);

        // Exhaust client A's tokens
        for _ in 0..3 {
            assert!(limiter.check_rate_limit("client-a").is_ok());
        }
        assert!(
            limiter.check_rate_limit("client-a").is_err(),
            "Client A should be rate limited"
        );

        // Client B should still have tokens
        assert!(
            limiter.check_rate_limit("client-b").is_ok(),
            "Client B should not be affected by Client A's limit"
        );
    }

    #[test]
    fn test_global_limit() {
        let config = RateLimiterConfig::new(1000.0, 10)
            .with_per_client(false)
            .with_global_limit(3);
        let limiter = RateLimiter::new(config);

        // Exhaust global limit across different clients
        assert!(limiter.check_rate_limit("client-a").is_ok());
        assert!(limiter.check_rate_limit("client-b").is_ok());
        assert!(limiter.check_rate_limit("client-c").is_ok());

        // 4th request from any client should be denied
        let result = limiter.check_rate_limit("client-d");
        assert!(result.is_err(), "Global limit should be enforced");
        match result {
            Err(RateLimitError::GlobalLimitExceeded { retry_after_ms }) => {
                assert!(retry_after_ms > 0);
            }
            other => panic!("Expected GlobalLimitExceeded, got {:?}", other),
        }
    }

    #[test]
    fn test_burst_handling() {
        // NOTE(review): assumes the 25-iteration loop completes in well under
        // 100ms (at 10 tokens/sec, one extra token refills every 100ms).
        let config = RateLimiterConfig::new(10.0, 20).with_per_client(true);
        let limiter = RateLimiter::new(config);

        // Should allow burst of 20 requests quickly
        let mut allowed = 0;
        for _ in 0..25 {
            if limiter.check_rate_limit("burst-client").is_ok() {
                allowed += 1;
            }
        }

        assert_eq!(
            allowed, 20,
            "Should allow exactly burst_size requests in a burst"
        );
    }

    #[test]
    fn test_cleanup_expired() {
        let config = RateLimiterConfig::new(100.0, 5)
            .with_per_client(true)
            .with_idle_timeout(Duration::from_millis(50));
        let limiter = RateLimiter::new(config);

        // Create some client buckets
        assert!(limiter.check_rate_limit("client-1").is_ok());
        assert!(limiter.check_rate_limit("client-2").is_ok());
        assert_eq!(limiter.tracked_client_count(), 2);

        // Wait for idle timeout
        thread::sleep(Duration::from_millis(80));

        let removed = limiter.cleanup_expired_buckets();
        assert_eq!(removed, 2, "Both idle clients should be cleaned up");
        assert_eq!(limiter.tracked_client_count(), 0);
    }

    #[test]
    fn test_cleanup_keeps_active() {
        let config = RateLimiterConfig::new(100.0, 5)
            .with_per_client(true)
            .with_idle_timeout(Duration::from_millis(100));
        let limiter = RateLimiter::new(config);

        // Create client bucket
        assert!(limiter.check_rate_limit("active-client").is_ok());

        // Wait less than idle timeout
        thread::sleep(Duration::from_millis(30));

        // Touch the active client again
        assert!(limiter.check_rate_limit("active-client").is_ok());

        // Create another client that will be idle
        assert!(limiter.check_rate_limit("idle-client").is_ok());

        // Wait enough for idle-client to expire but not active-client
        thread::sleep(Duration::from_millis(120));

        // Touch active client right before cleanup
        assert!(limiter.check_rate_limit("active-client").is_ok());

        let removed = limiter.cleanup_expired_buckets();
        assert_eq!(removed, 1, "Only idle client should be cleaned up");
        assert_eq!(limiter.tracked_client_count(), 1);
    }

    #[test]
    fn test_rate_limit_error_display() {
        let global_err = RateLimitError::GlobalLimitExceeded { retry_after_ms: 42 };
        let msg = format!("{}", global_err);
        assert!(msg.contains("Global rate limit exceeded"));
        assert!(msg.contains("42ms"));

        let client_err = RateLimitError::ClientLimitExceeded {
            client_id: "test-client".to_string(),
            retry_after_ms: 100,
        };
        let msg = format!("{}", client_err);
        assert!(msg.contains("test-client"));
        assert!(msg.contains("100ms"));
    }

    #[test]
    fn test_rate_limit_error_details() {
        let config = RateLimiterConfig::new(10.0, 2).with_per_client(true);
        let limiter = RateLimiter::new(config);

        // Exhaust tokens
        assert!(limiter.check_rate_limit("err-client").is_ok());
        assert!(limiter.check_rate_limit("err-client").is_ok());

        let result = limiter.check_rate_limit("err-client");
        match result {
            Err(RateLimitError::ClientLimitExceeded {
                client_id,
                retry_after_ms,
            }) => {
                assert_eq!(client_id, "err-client");
                assert!(retry_after_ms > 0);
            }
            other => panic!("Expected ClientLimitExceeded, got {:?}", other),
        }
    }

    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;

        let config = RateLimiterConfig::new(1000.0, 100).with_per_client(true);
        let limiter = Arc::new(RateLimiter::new(config));

        let mut handles = Vec::new();
        for i in 0..8 {
            let limiter = Arc::clone(&limiter);
            let handle = thread::spawn(move || {
                let client_id = format!("thread-client-{}", i);
                let mut allowed = 0u32;
                for _ in 0..50 {
                    if limiter.check_rate_limit(&client_id).is_ok() {
                        allowed += 1;
                    }
                }
                allowed
            });
            handles.push(handle);
        }

        let mut total_allowed = 0u32;
        for handle in handles {
            let count = handle.join().expect("Thread panicked");
            total_allowed += count;
        }

        // Each of 8 threads has burst_size=100 tokens and makes 50 requests
        // All 50 from each thread should succeed (50 < 100)
        assert_eq!(
            total_allowed, 400,
            "All requests should be allowed (50 per thread * 8 threads)"
        );
    }

    #[test]
    fn test_concurrent_same_client() {
        use std::sync::Arc;

        // Use a very low refill rate so tokens don't replenish during the test
        let config = RateLimiterConfig::new(0.001, 50).with_per_client(true);
        let limiter = Arc::new(RateLimiter::new(config));

        let mut handles = Vec::new();
        for _ in 0..4 {
            let limiter = Arc::clone(&limiter);
            let handle = thread::spawn(move || {
                let mut allowed = 0u32;
                for _ in 0..20 {
                    if limiter.check_rate_limit("shared-client").is_ok() {
                        allowed += 1;
                    }
                }
                allowed
            });
            handles.push(handle);
        }

        let mut total_allowed = 0u32;
        for handle in handles {
            let count = handle.join().expect("Thread panicked");
            total_allowed += count;
        }

        // 4 threads * 20 requests = 80 total attempts, but burst_size=50
        // With near-zero refill rate, total allowed should be exactly 50
        assert_eq!(
            total_allowed, 50,
            "Total allowed should equal burst_size for shared client"
        );
    }

    #[test]
    fn test_try_acquire_convenience() {
        let config = RateLimiterConfig::new(100.0, 2).with_per_client(true);
        let limiter = RateLimiter::new(config);

        assert!(limiter.try_acquire("client-x"));
        assert!(limiter.try_acquire("client-x"));
        assert!(!limiter.try_acquire("client-x"));
    }

    #[test]
    fn test_remaining_tokens() {
        let config = RateLimiterConfig::new(100.0, 5).with_per_client(true);
        let limiter = RateLimiter::new(config);

        // Before any requests, should return full burst capacity
        assert_eq!(limiter.remaining_tokens("new-client"), 5);

        // After one request
        assert!(limiter.check_rate_limit("new-client").is_ok());
        assert_eq!(limiter.remaining_tokens("new-client"), 4);
    }

    #[test]
    fn test_reset_client() {
        let config = RateLimiterConfig::new(100.0, 3).with_per_client(true);
        let limiter = RateLimiter::new(config);

        // Exhaust
        for _ in 0..3 {
            assert!(limiter.check_rate_limit("reset-client").is_ok());
        }
        assert!(limiter.check_rate_limit("reset-client").is_err());

        // Reset
        limiter.reset("reset-client");
        assert!(
            limiter.check_rate_limit("reset-client").is_ok(),
            "Should be able to make requests after reset"
        );
    }

    #[test]
    fn test_reset_all() {
        let config = RateLimiterConfig::new(100.0, 2)
            .with_per_client(true)
            .with_global_limit(5);
        let limiter = RateLimiter::new(config);

        assert!(limiter.check_rate_limit("a").is_ok());
        assert!(limiter.check_rate_limit("b").is_ok());
        assert_eq!(limiter.tracked_client_count(), 2);

        limiter.reset_all();
        assert_eq!(limiter.tracked_client_count(), 0);
    }

    #[test]
    fn test_config_default() {
        let config = RateLimiterConfig::default();
        assert!((config.requests_per_second - 100.0).abs() < f64::EPSILON);
        assert_eq!(config.burst_size, 50);
        assert!(config.per_client);
        assert!(config.global_limit.is_none());
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
    }

    #[test]
    fn test_config_builder_pattern() {
        let config = RateLimiterConfig::new(200.0, 100)
            .with_per_client(false)
            .with_global_limit(500)
            .with_idle_timeout(Duration::from_secs(60));

        assert!((config.requests_per_second - 200.0).abs() < f64::EPSILON);
        assert_eq!(config.burst_size, 100);
        assert!(!config.per_client);
        assert_eq!(config.global_limit, Some(500));
        assert_eq!(config.idle_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_debug_impl() {
        let config = RateLimiterConfig::new(50.0, 10);
        let limiter = RateLimiter::new(config);
        let debug_str = format!("{:?}", limiter);
        assert!(debug_str.contains("RateLimiter"));
        assert!(debug_str.contains("tracked_clients"));
    }

    #[test]
    fn test_global_and_per_client_combined() {
        // Global limit of 5, per-client burst of 3, near-zero refill
        let config = RateLimiterConfig::new(0.001, 3)
            .with_per_client(true)
            .with_global_limit(5);
        let limiter = RateLimiter::new(config);

        // Client A uses 3 (hits per-client limit, also consumes 3 global tokens)
        assert!(limiter.check_rate_limit("a").is_ok());
        assert!(limiter.check_rate_limit("a").is_ok());
        assert!(limiter.check_rate_limit("a").is_ok());
        // 4th request: per-client exhausted, fails before touching global
        assert!(
            limiter.check_rate_limit("a").is_err(),
            "Client A should hit per-client limit"
        );

        // Client B: per-client has 3 tokens, but global only has 2 remaining (5-3)
        assert!(limiter.check_rate_limit("b").is_ok()); // global: 1 remaining
        assert!(limiter.check_rate_limit("b").is_ok()); // global: 0 remaining

        // Next request: per-client still has 1 token, but global is exhausted
        let result = limiter.check_rate_limit("b");
        assert!(result.is_err(), "Should hit global limit");
        assert!(
            matches!(result, Err(RateLimitError::GlobalLimitExceeded { .. })),
            "Error should be GlobalLimitExceeded"
        );
    }

    #[test]
    fn test_zero_refill_rate_retry_after() {
        // An empty bucket that never refills reports an infinite wait.
        let bucket = TokenBucket::new(0.0, 0.0);
        assert_eq!(bucket.retry_after_ms(), u64::MAX);
    }
}