// synapse_pingora/ratelimit.rs
//! Per-site rate limiting with token bucket algorithm.
//!
//! Provides hostname-aware rate limiting with configurable limits,
//! burst capacity, and sliding window tracking.

use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::{debug, info, warn};

/// Adaptive rate limiter that adjusts limits based on system health.
///
/// Watches backend latency via the metrics registry and scales every
/// managed limiter up or down through a shared multiplier. The multiplier
/// is stored as an integer scaled by 1000 (1000 = 1.0) so it fits in an
/// `AtomicU64` and needs no lock.
pub struct AdaptiveRateLimiter {
    /// Reference to the rate limit manager to adjust
    manager: Arc<RwLock<RateLimitManager>>,
    /// Reference to metrics for health monitoring
    metrics: Arc<crate::metrics::MetricsRegistry>,
    /// Last recorded backend latency (ms)
    // NOTE(review): initialized in `new` but never read or updated in this
    // file — presumably reserved for future delta-based logic; confirm
    // before removing.
    last_latency_ms: AtomicU64,
    /// Current rate multiplier (scaled by 1000, 1000 = 1.0)
    multiplier: AtomicU64,
}

impl AdaptiveRateLimiter {
    /// Creates a new adaptive rate limiter starting at full capacity
    /// (multiplier 1.0).
    pub fn new(
        manager: Arc<RwLock<RateLimitManager>>,
        metrics: Arc<crate::metrics::MetricsRegistry>,
    ) -> Self {
        Self {
            manager,
            metrics,
            last_latency_ms: AtomicU64::new(0),
            multiplier: AtomicU64::new(1000), // Start at 1.0
        }
    }

    /// Performs one adjustment cycle based on current metrics.
    ///
    /// Current logic (latency is the only signal wired in — the CPU check
    /// mentioned below is not implemented yet):
    /// - avg latency > 500ms: cut the multiplier by 10% (floored at 0.2)
    /// - avg latency > 200ms: cut the multiplier by 5% (floored at 0.2)
    /// - avg latency < 50ms while below 1.0: add a flat 0.05 (capped at 1.0)
    ///
    /// The new multiplier is pushed to the limiters only when it changed.
    pub fn adjust(&self) {
        let avg_latency = self.metrics.avg_latency_ms();

        // Simple CPU check using sysinfo (if available in metrics)
        // For now, let's just use latency as the primary signal

        let current_mult = self.multiplier.load(Ordering::Relaxed);
        let mut next_mult = current_mult;

        if avg_latency > 500.0 {
            // High latency: throttle hard (10% reduction)
            // Note: Use a floor of 20% (200/1000) to prevent permanent DoS
            next_mult = (current_mult.saturating_mul(90) / 100).max(200);
            warn!(latency = %avg_latency, multiplier = %(next_mult as f64 / 1000.0), "Adaptive RL: High latency detected, throttling fleet");
        } else if avg_latency > 200.0 {
            // Moderate latency: slight throttle (5% reduction)
            next_mult = (current_mult.saturating_mul(95) / 100).max(200);
            debug!(latency = %avg_latency, multiplier = %(next_mult as f64 / 1000.0), "Adaptive RL: Latency rising, slowing down");
        } else if avg_latency < 50.0 && current_mult < 1000 {
            // Low latency: recover in flat +0.05 steps for faster restoration
            next_mult = (current_mult.saturating_add(50)).min(1000);
            debug!(latency = %avg_latency, multiplier = %(next_mult as f64 / 1000.0), "Adaptive RL: Health recovered, restoring capacity");
        }

        if next_mult != current_mult {
            self.multiplier.store(next_mult, Ordering::Relaxed);
            self.apply_multiplier(next_mult as f64 / 1000.0);
        }
    }

    /// Propagates a new multiplier to the global limiter (if any) and to
    /// every site-specific limiter currently registered with the manager.
    fn apply_multiplier(&self, multiplier: f64) {
        let manager = self.manager.read();

        // Update global limiter if present
        if let Some(global) = &manager.global_limiter {
            global.set_multiplier(multiplier);
        }

        // Update all site limiters
        let sites = manager.site_limiters.read();
        for limiter in sites.values() {
            limiter.set_multiplier(multiplier);
        }
    }

    /// Returns the current adaptive multiplier as a float (1.0 = full rate).
    pub fn current_multiplier(&self) -> f64 {
        self.multiplier.load(Ordering::Relaxed) as f64 / 1000.0
    }

    /// Starts a background thread that periodically adjusts rate limits.
    ///
    /// The spawned thread is detached and loops forever (there is no
    /// shutdown signal), sleeping `interval` between adjustment cycles.
    pub fn start_background_task(self: Arc<Self>, interval: Duration) {
        info!(?interval, "Starting adaptive rate limiting background task");
        std::thread::spawn(move || loop {
            self.adjust();
            std::thread::sleep(interval);
        });
    }
}

/// Rate limit decision returned by limiter checks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitDecision {
    /// Request is allowed
    Allow,
    /// Request exceeded the configured rate and should be rejected
    Limited,
}

/// Rate limit configuration for a site.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Requests per second limit
    pub rps: u32,
    /// Burst capacity (tokens available for bursts)
    pub burst: u32,
    /// Whether rate limiting is enabled
    pub enabled: bool,
    /// Window duration for sliding window
    pub window_secs: u64,
}

impl Default for RateLimitConfig {
    /// 1000 RPS, 2000 burst, enabled, 1-second window.
    fn default() -> Self {
        Self {
            rps: 1000,
            burst: 2000,
            enabled: true,
            window_secs: 1,
        }
    }
}

impl RateLimitConfig {
    /// Creates a new rate limit config with the specified RPS.
    ///
    /// Burst capacity defaults to twice the RPS. Uses saturating
    /// arithmetic: the previous `rps * 2` panicked in debug builds (and
    /// wrapped in release) for any `rps > u32::MAX / 2`.
    pub fn new(rps: u32) -> Self {
        Self {
            rps,
            burst: rps.saturating_mul(2),
            enabled: true,
            window_secs: 1,
        }
    }

    /// Sets the burst capacity (builder-style).
    pub fn with_burst(mut self, burst: u32) -> Self {
        self.burst = burst;
        self
    }

    /// Returns a config with rate limiting disabled; all other fields
    /// keep their `Default` values.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Default::default()
        }
    }
}

/// Token bucket rate limiter.
///
/// All mutable state lives in atomics, so a bucket can be shared across
/// threads behind an `Arc` without extra locking. Timestamps are stored as
/// nanoseconds elapsed since `start_time` so they fit in an `AtomicU64`.
#[derive(Debug)]
pub struct TokenBucket {
    /// Available tokens
    tokens: AtomicU64,
    /// Maximum tokens (burst capacity)
    max_tokens: u64,
    /// Tokens added per second
    refill_rate: AtomicU64,
    /// Last refill timestamp (nanos since start)
    last_refill: AtomicU64,
    /// Start time for timestamp calculation
    start_time: Instant,
    /// Last access time (nanos since start) for LRU eviction (SP-001)
    last_access: AtomicU64,
}

impl TokenBucket {
    /// Creates a new token bucket, pre-filled to the full burst capacity.
    pub fn new(rps: u32, burst: u32) -> Self {
        let max_tokens = burst as u64;
        Self {
            tokens: AtomicU64::new(max_tokens),
            max_tokens,
            refill_rate: AtomicU64::new(rps as u64),
            last_refill: AtomicU64::new(0),
            start_time: Instant::now(),
            last_access: AtomicU64::new(0),
        }
    }

    /// Sets a new refill rate (RPS). Takes effect on the next refill;
    /// already-banked tokens are not adjusted.
    pub fn set_rate(&self, rps: u32) {
        self.refill_rate.store(rps as u64, Ordering::Relaxed);
    }

    /// Tries to acquire a token, returning true if successful.
    ///
    /// Uses atomic CAS loop with proper memory ordering to prevent race conditions.
    pub fn try_acquire(&self) -> bool {
        // Update last access time for LRU eviction (SP-001)
        let now_nanos = self.start_time.elapsed().as_nanos() as u64;
        self.last_access.store(now_nanos, Ordering::Relaxed);

        // First refill based on elapsed time
        self.refill();

        // CAS loop to atomically decrement tokens
        loop {
            // Acquire ordering ensures we see all previous writes
            let current = self.tokens.load(Ordering::Acquire);
            if current == 0 {
                return false;
            }

            // AcqRel ordering on success ensures the decrement is visible to other threads
            // and that we see their updates
            match self.tokens.compare_exchange_weak(
                current,
                current - 1,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return true,
                Err(_) => {
                    // CAS failed, retry with fresh value
                    // Use core::hint::spin_loop to hint CPU we're in a spin loop
                    core::hint::spin_loop();
                    continue;
                }
            }
        }
    }

    /// Refills tokens based on elapsed time.
    ///
    /// SECURITY: Uses atomic CAS operations to prevent race conditions that could
    /// allow burst bypass. The timestamp and token updates are coordinated to ensure
    /// only one thread adds tokens for any given time period.
    fn refill(&self) {
        let now_nanos = self.start_time.elapsed().as_nanos() as u64;

        // Retry loop for the entire refill operation to handle concurrent access
        loop {
            // Acquire ordering ensures we see the latest timestamp
            let last = self.last_refill.load(Ordering::Acquire);

            // No time has passed or time went backwards (shouldn't happen with Instant)
            if now_nanos <= last {
                return;
            }

            let elapsed_nanos = now_nanos - last;
            // Only refill if meaningful time has passed (at least 1 microsecond)
            if elapsed_nanos < 1000 {
                return;
            }

            let elapsed_secs = elapsed_nanos as f64 / 1_000_000_000.0;
            let refill_rate = self.refill_rate.load(Ordering::Relaxed);
            let tokens_to_add = (elapsed_secs * refill_rate as f64) as u64;

            // Whole-token truncation: returning here WITHOUT advancing
            // last_refill preserves the fractional remainder for a later call.
            if tokens_to_add == 0 {
                return;
            }

            // Atomically claim this time window by updating last_refill
            // If another thread updated it, we'll retry with the new value
            match self.last_refill.compare_exchange(
                last,
                now_nanos,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    // We won the race to claim this time window
                    // Now atomically add the tokens
                    self.add_tokens(tokens_to_add);
                    return;
                }
                Err(actual) => {
                    // Another thread claimed this time window
                    // If the actual value is >= now_nanos, no need to retry
                    if actual >= now_nanos {
                        return;
                    }
                    // Otherwise, there may be more time to claim, retry
                    core::hint::spin_loop();
                    continue;
                }
            }
        }
    }

    /// Atomically adds tokens up to the maximum capacity.
    ///
    /// Uses a CAS loop to ensure thread-safe token addition without races.
    #[inline]
    fn add_tokens(&self, tokens_to_add: u64) {
        loop {
            let current = self.tokens.load(Ordering::Acquire);
            // Clamp to burst capacity; saturating_add guards the u64 edge.
            let new_tokens = (current.saturating_add(tokens_to_add)).min(self.max_tokens);

            // If we're already at max, nothing to do
            if new_tokens == current {
                return;
            }

            match self.tokens.compare_exchange_weak(
                current,
                new_tokens,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return,
                Err(_) => {
                    core::hint::spin_loop();
                    continue;
                }
            }
        }
    }

    /// Returns the current number of available tokens.
    ///
    /// Performs a refill first to ensure the count is up-to-date.
    pub fn available_tokens(&self) -> u64 {
        self.refill();
        self.tokens.load(Ordering::Acquire)
    }

    /// Returns the last access time in nanoseconds since bucket creation.
    /// Used by `KeyedRateLimiter` for LRU eviction (SP-001).
    pub fn last_access_nanos(&self) -> u64 {
        self.last_access.load(Ordering::Relaxed)
    }
}

/// Per-key rate limiter (e.g., by IP address).
///
/// Lazily creates one `TokenBucket` per key, capped at `max_keys` entries;
/// when full, the stalest 10% of entries are evicted by last access time
/// (SP-001).
#[derive(Debug)]
pub struct KeyedRateLimiter {
    /// Key -> token bucket mapping
    buckets: RwLock<HashMap<String, Arc<TokenBucket>>>,
    /// Configuration
    config: RateLimitConfig,
    /// Maximum number of tracked keys (to prevent memory exhaustion)
    max_keys: usize,
    /// Current adaptive multiplier (scaled by 1000, 1000 = 1.0)
    multiplier: AtomicU64,
}

354impl KeyedRateLimiter {
355    /// Creates a new keyed rate limiter.
356    pub fn new(config: RateLimitConfig) -> Self {
357        Self {
358            buckets: RwLock::new(HashMap::new()),
359            config,
360            max_keys: 100_000, // Default max tracked keys
361            multiplier: AtomicU64::new(1000),
362        }
363    }
364
365    /// Sets the current adaptive multiplier.
366    pub fn set_multiplier(&self, multiplier: f64) {
367        let m = (multiplier * 1000.0) as u64;
368        self.multiplier.store(m, Ordering::Relaxed);
369
370        // Update all existing buckets
371        let new_rps = (self.config.rps as f64 * multiplier) as u32;
372        let buckets = self.buckets.read();
373        for bucket in buckets.values() {
374            bucket.set_rate(new_rps);
375        }
376    }
377
378    /// Sets the maximum number of tracked keys.
379    pub fn with_max_keys(mut self, max_keys: usize) -> Self {
380        self.max_keys = max_keys;
381        self
382    }
383
384    /// Checks if a request for the given key is allowed.
385    pub fn check(&self, key: &str) -> RateLimitDecision {
386        if !self.config.enabled {
387            return RateLimitDecision::Allow;
388        }
389
390        // Try to get existing bucket
391        {
392            let buckets = self.buckets.read();
393            if let Some(bucket) = buckets.get(key) {
394                return if bucket.try_acquire() {
395                    RateLimitDecision::Allow
396                } else {
397                    debug!("Rate limited key: {}", key);
398                    RateLimitDecision::Limited
399                };
400            }
401        }
402
403        // Create new bucket
404        {
405            let mut buckets = self.buckets.write();
406
407            // SP-001: LRU eviction — remove least-recently-accessed entries when at capacity.
408            // Previous approach used arbitrary HashMap iteration order; now we sort by
409            // last_access timestamp and evict the stalest 10%.
410            if buckets.len() >= self.max_keys {
411                warn!(
412                    "Rate limiter at capacity ({}), evicting stale entries",
413                    buckets.len()
414                );
415                let evict_count = self.max_keys / 10;
416                let mut entries: Vec<_> = buckets
417                    .iter()
418                    .map(|(k, v)| (k.clone(), v.last_access.load(Ordering::Relaxed)))
419                    .collect();
420                entries.sort_unstable_by_key(|&(_, ts)| ts);
421                for (k, _) in entries.into_iter().take(evict_count) {
422                    buckets.remove(&k);
423                }
424            }
425
426            let multiplier = self.multiplier.load(Ordering::Relaxed) as f64 / 1000.0;
427            let effective_rps = (self.config.rps as f64 * multiplier) as u32;
428            let bucket = Arc::new(TokenBucket::new(effective_rps, self.config.burst));
429            let allowed = bucket.try_acquire();
430            buckets.insert(key.to_string(), bucket);
431
432            if allowed {
433                RateLimitDecision::Allow
434            } else {
435                RateLimitDecision::Limited
436            }
437        }
438    }
439
440    /// Returns the number of tracked keys.
441    pub fn key_count(&self) -> usize {
442        self.buckets.read().len()
443    }
444
445    /// Clears all tracked keys.
446    pub fn clear(&self) {
447        self.buckets.write().clear();
448    }
449}
450
/// Per-site rate limit manager.
#[derive(Debug)]
pub struct RateLimitManager {
    /// Site hostname (stored lowercased) -> keyed limiter mapping
    site_limiters: RwLock<HashMap<String, Arc<KeyedRateLimiter>>>,
    /// Global limiter (consulted before any site limiter)
    global_limiter: Option<Arc<KeyedRateLimiter>>,
    /// Default config for new sites
    // NOTE(review): written by `set_default_config` but never read in this
    // file — presumably consumed by site-provisioning code outside this
    // view; confirm.
    default_config: RateLimitConfig,
}

462impl RateLimitManager {
463    /// Creates a new rate limit manager.
464    pub fn new() -> Self {
465        Self {
466            site_limiters: RwLock::new(HashMap::new()),
467            global_limiter: None,
468            default_config: RateLimitConfig::default(),
469        }
470    }
471
472    /// Creates a manager with a global rate limit.
473    pub fn with_global(config: RateLimitConfig) -> Self {
474        Self {
475            site_limiters: RwLock::new(HashMap::new()),
476            global_limiter: Some(Arc::new(KeyedRateLimiter::new(config.clone()))),
477            default_config: config,
478        }
479    }
480
481    /// Sets the default configuration for new sites.
482    pub fn set_default_config(&mut self, config: RateLimitConfig) {
483        self.default_config = config;
484    }
485
486    /// Adds a site-specific rate limiter.
487    pub fn add_site(&self, hostname: &str, config: RateLimitConfig) {
488        let limiter = Arc::new(KeyedRateLimiter::new(config));
489        self.site_limiters
490            .write()
491            .insert(hostname.to_lowercase(), limiter);
492    }
493
494    /// Removes a site-specific rate limiter.
495    pub fn remove_site(&self, hostname: &str) {
496        self.site_limiters.write().remove(&hostname.to_lowercase());
497    }
498
499    /// Checks if a request is allowed.
500    ///
501    /// # Arguments
502    /// * `hostname` - The site hostname
503    /// * `key` - The rate limit key (usually client IP)
504    pub fn check(&self, hostname: &str, key: &str) -> RateLimitDecision {
505        // Check global limiter first
506        if let Some(global) = &self.global_limiter {
507            if matches!(global.check(key), RateLimitDecision::Limited) {
508                return RateLimitDecision::Limited;
509            }
510        }
511
512        // Check site-specific limiter
513        let normalized = hostname.to_lowercase();
514        let limiters = self.site_limiters.read();
515
516        if let Some(limiter) = limiters.get(&normalized) {
517            return limiter.check(key);
518        }
519
520        // No site-specific limiter, allow
521        RateLimitDecision::Allow
522    }
523
524    /// Returns true if the request is allowed.
525    pub fn is_allowed(&self, hostname: &str, key: &str) -> bool {
526        matches!(self.check(hostname, key), RateLimitDecision::Allow)
527    }
528
529    /// Returns rate limit statistics.
530    pub fn stats(&self) -> RateLimitStats {
531        let limiters = self.site_limiters.read();
532        let total_keys: usize = limiters.values().map(|l| l.key_count()).sum();
533        let global_keys = self
534            .global_limiter
535            .as_ref()
536            .map(|l| l.key_count())
537            .unwrap_or(0);
538
539        RateLimitStats {
540            site_count: limiters.len(),
541            total_tracked_keys: total_keys + global_keys,
542            global_enabled: self.global_limiter.is_some(),
543        }
544    }
545}
546
impl Default for RateLimitManager {
    /// Equivalent to [`RateLimitManager::new`]: no site or global limiters.
    fn default() -> Self {
        Self::new()
    }
}

/// Rate limit statistics snapshot (serde-serializable).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitStats {
    /// Number of sites with rate limiting
    pub site_count: usize,
    /// Total number of tracked keys across all limiters (site + global)
    pub total_tracked_keys: usize,
    /// Whether global rate limiting is enabled
    pub global_enabled: bool,
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_token_bucket_basic() {
        let bucket = TokenBucket::new(10, 10); // 10 RPS, 10 burst

        // Should allow 10 requests immediately
        for _ in 0..10 {
            assert!(bucket.try_acquire());
        }

        // 11th should fail
        assert!(!bucket.try_acquire());
    }

    #[test]
    fn test_token_bucket_refill() {
        let bucket = TokenBucket::new(1000, 10); // 1000 RPS, 10 burst

        // Drain the bucket
        for _ in 0..10 {
            bucket.try_acquire();
        }

        // Wait a bit for refill (20ms at 1000 RPS ~ 20 tokens)
        thread::sleep(Duration::from_millis(20));

        // Should have some tokens now
        assert!(bucket.try_acquire());
    }

    #[test]
    fn test_rate_limit_config() {
        let config = RateLimitConfig::new(100).with_burst(200);
        assert_eq!(config.rps, 100);
        assert_eq!(config.burst, 200);
        assert!(config.enabled);
    }

    #[test]
    fn test_rate_limit_disabled() {
        let config = RateLimitConfig::disabled();
        let limiter = KeyedRateLimiter::new(config);

        // Should always allow when disabled
        for _ in 0..1000 {
            assert!(matches!(limiter.check("key"), RateLimitDecision::Allow));
        }
    }

    #[test]
    fn test_keyed_rate_limiter() {
        let config = RateLimitConfig::new(5).with_burst(5);
        let limiter = KeyedRateLimiter::new(config);

        // Different keys have separate buckets
        for _ in 0..5 {
            assert!(matches!(limiter.check("key1"), RateLimitDecision::Allow));
            assert!(matches!(limiter.check("key2"), RateLimitDecision::Allow));
        }

        // Both should now be limited
        assert!(matches!(limiter.check("key1"), RateLimitDecision::Limited));
        assert!(matches!(limiter.check("key2"), RateLimitDecision::Limited));
    }

    #[test]
    fn test_keyed_limiter_key_count() {
        let config = RateLimitConfig::new(10);
        let limiter = KeyedRateLimiter::new(config);

        limiter.check("key1");
        limiter.check("key2");
        limiter.check("key3");

        assert_eq!(limiter.key_count(), 3);
    }

    #[test]
    fn test_rate_limit_manager() {
        let manager = RateLimitManager::new();

        // Add site-specific limiter
        manager.add_site("api.example.com", RateLimitConfig::new(2).with_burst(2));

        // Should limit api.example.com
        assert!(manager.is_allowed("api.example.com", "client1"));
        assert!(manager.is_allowed("api.example.com", "client1"));
        assert!(!manager.is_allowed("api.example.com", "client1"));

        // Other sites should be allowed (no limiter)
        assert!(manager.is_allowed("other.example.com", "client1"));
    }

    #[test]
    fn test_global_rate_limit() {
        let manager = RateLimitManager::with_global(RateLimitConfig::new(3).with_burst(3));

        // Global limit applies to all
        assert!(manager.is_allowed("any.com", "client1"));
        assert!(manager.is_allowed("any.com", "client1"));
        assert!(manager.is_allowed("any.com", "client1"));
        assert!(!manager.is_allowed("any.com", "client1"));
    }

    #[test]
    fn test_manager_case_insensitive() {
        let manager = RateLimitManager::new();
        manager.add_site("Example.COM", RateLimitConfig::new(1).with_burst(1));

        // Same host regardless of case, so the second check hits the
        // already-drained bucket.
        assert!(manager.is_allowed("example.com", "client"));
        assert!(!manager.is_allowed("EXAMPLE.COM", "client"));
    }

    #[test]
    fn test_keyed_limiter_clear() {
        let config = RateLimitConfig::new(10);
        let limiter = KeyedRateLimiter::new(config);

        limiter.check("key1");
        limiter.check("key2");
        assert_eq!(limiter.key_count(), 2);

        limiter.clear();
        assert_eq!(limiter.key_count(), 0);
    }

    #[test]
    fn test_stats() {
        let manager = RateLimitManager::with_global(RateLimitConfig::new(100));
        manager.add_site("site1.com", RateLimitConfig::new(50));
        manager.add_site("site2.com", RateLimitConfig::new(50));

        // Generate some traffic
        manager.check("site1.com", "ip1");
        manager.check("site2.com", "ip2");

        let stats = manager.stats();
        assert_eq!(stats.site_count, 2);
        assert!(stats.global_enabled);
    }

    #[test]
    fn test_available_tokens() {
        let bucket = TokenBucket::new(100, 50);
        assert_eq!(bucket.available_tokens(), 50); // Starts at burst capacity
    }

    /// Concurrent stress test to verify no race condition in token bucket.
    ///
    /// SECURITY TEST: Verifies that under high concurrent load, the token bucket
    /// doesn't allow more requests than the burst capacity (which would indicate
    /// a race condition allowing burst bypass).
    #[test]
    fn test_concurrent_token_bucket_no_burst_bypass() {
        use std::sync::atomic::AtomicUsize;

        let bucket = Arc::new(TokenBucket::new(10, 100)); // 10 RPS, 100 burst
        let successful_acquires = Arc::new(AtomicUsize::new(0));

        // Spawn multiple threads to hammer the bucket concurrently
        // (10 threads x 50 attempts = 500 attempts against 100 tokens).
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let bucket = Arc::clone(&bucket);
                let counter = Arc::clone(&successful_acquires);

                thread::spawn(move || {
                    for _ in 0..50 {
                        if bucket.try_acquire() {
                            counter.fetch_add(1, Ordering::Relaxed);
                        }
                    }
                })
            })
            .collect();

        // Wait for all threads
        for handle in handles {
            handle.join().unwrap();
        }

        let total = successful_acquires.load(Ordering::Relaxed);

        // Should never exceed burst capacity (100).
        // With the race condition fix, this should be exactly 100.
        // Before the fix, it could be 110-120 (10-20% bypass).
        // (Refill at 10 RPS contributes ~0 tokens over the few ms this runs.)
        assert!(
            total <= 100,
            "Race condition detected! Got {} successful acquires, expected <= 100",
            total
        );

        // Should get close to the burst capacity
        assert!(
            total >= 95,
            "Token bucket may have performance issue: only {} acquires, expected ~100",
            total
        );
    }

    /// Test concurrent refill doesn't double-add tokens.
    #[test]
    fn test_concurrent_refill_no_double_add() {
        let bucket = Arc::new(TokenBucket::new(1000, 10)); // 1000 RPS, 10 burst

        // Drain the bucket
        for _ in 0..10 {
            bucket.try_acquire();
        }

        // Wait a bit for refill opportunity
        thread::sleep(Duration::from_millis(50)); // Should add ~50 tokens worth

        let tokens_before = bucket.available_tokens();

        // Spawn threads to trigger concurrent refills
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let bucket = Arc::clone(&bucket);
                thread::spawn(move || {
                    // Just read available_tokens which triggers refill
                    bucket.available_tokens()
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        let tokens_after = bucket.available_tokens();

        // Tokens should not have increased dramatically due to race
        // (at most a few more tokens from the small time elapsed)
        assert!(
            tokens_after <= tokens_before + 10,
            "Possible double-add race: before={}, after={}",
            tokens_before,
            tokens_after
        );
    }
}