Skip to main content

allframe_core/resilience/
rate_limit.rs

1//! Rate limiting primitives with adaptive and keyed variants.
2//!
3//! Provides token bucket rate limiting for controlling request throughput.
4
5use std::{
6    hash::Hash,
7    num::NonZeroU32,
8    sync::{
9        atomic::{AtomicU32, AtomicU64, Ordering},
10        Arc,
11    },
12    time::{Duration, Instant},
13};
14
15use dashmap::DashMap;
16use governor::{
17    clock::{Clock, DefaultClock},
18    state::{InMemoryState, NotKeyed},
19    Quota, RateLimiter as GovernorRateLimiter,
20};
21use parking_lot::RwLock;
22
/// Error produced when a request is refused because the rate limit has been
/// reached.
#[derive(Clone, Debug)]
pub struct RateLimitError {
    /// How long the caller should wait before trying again.
    pub retry_after: Duration,
}

impl std::fmt::Display for RateLimitError {
    /// Writes a one-line message that includes the retry delay in `Duration`'s
    /// debug notation (for example `5s`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { retry_after } = self;
        write!(f, "rate limit exceeded, retry after {:?}", retry_after)
    }
}

// No underlying cause to expose, so the default `Error` methods are sufficient.
impl std::error::Error for RateLimitError {}
37
/// Point-in-time snapshot describing a rate limiter.
///
/// Returned by the `get_status` methods of the limiters in this module.
#[derive(Clone, Debug)]
pub struct RateLimiterStatus {
    /// Request rate, in requests per second, as reported by the limiter that
    /// produced this snapshot.
    pub current_rps: f64,
    /// Configured maximum requests per second.
    pub max_rps: u32,
    /// Configured burst capacity.
    pub burst_size: u32,
    /// `true` when the limiter reports that it is currently rate limited.
    pub is_limited: bool,
    /// Count of allowed requests in the limiter's current statistics window.
    pub requests_last_minute: u64,
    /// Count of rejected requests in the limiter's current statistics window.
    pub rejections_last_minute: u64,
}
54
55/// Token bucket rate limiter.
56///
57/// Allows a sustained rate of requests with burst capacity.
58pub struct RateLimiter {
59    limiter: GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>,
60    rps: u32,
61    burst_size: u32,
62    requests: AtomicU64,
63    rejections: AtomicU64,
64    last_reset: RwLock<Instant>,
65}
66
67impl RateLimiter {
68    /// Create a new rate limiter.
69    ///
70    /// # Arguments
71    /// * `rps` - Maximum requests per second
72    /// * `burst_size` - Additional burst capacity above the sustained rate
73    pub fn new(rps: u32, burst_size: u32) -> Self {
74        let rps_nz = NonZeroU32::new(rps.max(1)).unwrap();
75        let burst_nz = NonZeroU32::new(burst_size.max(1)).unwrap();
76
77        let quota = Quota::per_second(rps_nz).allow_burst(burst_nz);
78        let limiter = GovernorRateLimiter::direct(quota);
79
80        Self {
81            limiter,
82            rps,
83            burst_size,
84            requests: AtomicU64::new(0),
85            rejections: AtomicU64::new(0),
86            last_reset: RwLock::new(Instant::now()),
87        }
88    }
89
90    /// Check if a request is allowed without blocking.
91    ///
92    /// Returns `Ok(())` if allowed, `Err(RateLimitError)` if rate limited.
93    pub fn check(&self) -> Result<(), RateLimitError> {
94        self.maybe_reset_counters();
95
96        match self.limiter.check() {
97            Ok(_) => {
98                self.requests.fetch_add(1, Ordering::Relaxed);
99                Ok(())
100            }
101            Err(not_until) => {
102                self.rejections.fetch_add(1, Ordering::Relaxed);
103                Err(RateLimitError {
104                    retry_after: not_until.wait_time_from(DefaultClock::default().now()),
105                })
106            }
107        }
108    }
109
110    /// Wait until a request is allowed.
111    ///
112    /// Blocks the current task until the rate limit permits the request.
113    pub async fn wait(&self) {
114        self.maybe_reset_counters();
115        self.limiter.until_ready().await;
116        self.requests.fetch_add(1, Ordering::Relaxed);
117    }
118
119    /// Get the current status of the rate limiter.
120    pub fn get_status(&self) -> RateLimiterStatus {
121        self.maybe_reset_counters();
122
123        let requests = self.requests.load(Ordering::Relaxed);
124        let rejections = self.rejections.load(Ordering::Relaxed);
125        let elapsed = self.last_reset.read().elapsed().as_secs_f64().max(1.0);
126
127        RateLimiterStatus {
128            current_rps: requests as f64 / elapsed.min(60.0),
129            max_rps: self.rps,
130            burst_size: self.burst_size,
131            is_limited: self.limiter.check().is_err(),
132            requests_last_minute: requests,
133            rejections_last_minute: rejections,
134        }
135    }
136
137    fn maybe_reset_counters(&self) {
138        let mut last = self.last_reset.write();
139        if last.elapsed() > Duration::from_secs(60) {
140            self.requests.store(0, Ordering::Relaxed);
141            self.rejections.store(0, Ordering::Relaxed);
142            *last = Instant::now();
143        }
144    }
145}
146
147/// Adaptive rate limiter that backs off when receiving external rate limits.
148///
149/// When external services return 429 responses, this limiter reduces its
150/// throughput to avoid hammering the service.
151pub struct AdaptiveRateLimiter {
152    /// Base rate limiter.
153    base_rps: u32,
154    burst_size: u32,
155    /// Current effective RPS (may be reduced).
156    current_rps: AtomicU32,
157    /// Underlying rate limiter.
158    limiter: RwLock<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
159    /// Consecutive rate limit responses.
160    consecutive_limits: AtomicU32,
161    /// Last rate limit time.
162    last_limit: RwLock<Option<Instant>>,
163    /// Recovery interval.
164    recovery_interval: Duration,
165    /// Minimum RPS (floor).
166    min_rps: u32,
167    /// Backoff factor when rate limited.
168    backoff_factor: f64,
169    /// Statistics.
170    requests: AtomicU64,
171    rejections: AtomicU64,
172    external_limits: AtomicU64,
173}
174
175impl AdaptiveRateLimiter {
176    /// Create a new adaptive rate limiter.
177    ///
178    /// # Arguments
179    /// * `rps` - Base requests per second
180    /// * `burst_size` - Burst capacity
181    pub fn new(rps: u32, burst_size: u32) -> Self {
182        let limiter = Self::create_limiter(rps, burst_size);
183
184        Self {
185            base_rps: rps,
186            burst_size,
187            current_rps: AtomicU32::new(rps),
188            limiter: RwLock::new(limiter),
189            consecutive_limits: AtomicU32::new(0),
190            last_limit: RwLock::new(None),
191            recovery_interval: Duration::from_secs(30),
192            min_rps: 1,
193            backoff_factor: 0.5,
194            requests: AtomicU64::new(0),
195            rejections: AtomicU64::new(0),
196            external_limits: AtomicU64::new(0),
197        }
198    }
199
200    /// Set the recovery interval.
201    pub fn with_recovery_interval(mut self, interval: Duration) -> Self {
202        self.recovery_interval = interval;
203        self
204    }
205
206    /// Set the minimum RPS floor.
207    pub fn with_min_rps(mut self, min_rps: u32) -> Self {
208        self.min_rps = min_rps.max(1);
209        self
210    }
211
212    /// Set the backoff factor (0.0-1.0).
213    pub fn with_backoff_factor(mut self, factor: f64) -> Self {
214        self.backoff_factor = factor.clamp(0.1, 0.9);
215        self
216    }
217
218    fn create_limiter(
219        rps: u32,
220        burst_size: u32,
221    ) -> GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock> {
222        let rps_nz = NonZeroU32::new(rps.max(1)).unwrap();
223        let burst_nz = NonZeroU32::new(burst_size.max(1)).unwrap();
224        let quota = Quota::per_second(rps_nz).allow_burst(burst_nz);
225        GovernorRateLimiter::direct(quota)
226    }
227
228    /// Record a successful request.
229    pub fn record_success(&self) {
230        self.consecutive_limits.store(0, Ordering::Relaxed);
231        self.maybe_recover();
232    }
233
234    /// Record a rate limit response from an external service.
235    pub fn record_rate_limit(&self) {
236        self.external_limits.fetch_add(1, Ordering::Relaxed);
237        let consecutive = self.consecutive_limits.fetch_add(1, Ordering::Relaxed) + 1;
238        *self.last_limit.write() = Some(Instant::now());
239
240        // Reduce rate based on consecutive rate limits
241        let reduction = self.backoff_factor.powi(consecutive.min(5) as i32);
242        let new_rps = ((self.base_rps as f64 * reduction) as u32).max(self.min_rps);
243
244        self.current_rps.store(new_rps, Ordering::Relaxed);
245        *self.limiter.write() = Self::create_limiter(new_rps, self.burst_size);
246    }
247
248    fn maybe_recover(&self) {
249        let last_limit = *self.last_limit.read();
250        if let Some(last) = last_limit {
251            if last.elapsed() > self.recovery_interval {
252                // Gradually recover to base rate
253                let current = self.current_rps.load(Ordering::Relaxed);
254                if current < self.base_rps {
255                    let new_rps = ((current as f64 * 1.5) as u32).min(self.base_rps);
256                    self.current_rps.store(new_rps, Ordering::Relaxed);
257                    *self.limiter.write() = Self::create_limiter(new_rps, self.burst_size);
258
259                    if new_rps >= self.base_rps {
260                        *self.last_limit.write() = None;
261                    }
262                }
263            }
264        }
265    }
266
267    /// Check if a request is allowed.
268    pub fn check(&self) -> Result<(), RateLimitError> {
269        self.maybe_recover();
270
271        match self.limiter.read().check() {
272            Ok(_) => {
273                self.requests.fetch_add(1, Ordering::Relaxed);
274                Ok(())
275            }
276            Err(not_until) => {
277                self.rejections.fetch_add(1, Ordering::Relaxed);
278                Err(RateLimitError {
279                    retry_after: not_until.wait_time_from(DefaultClock::default().now()),
280                })
281            }
282        }
283    }
284
285    /// Wait until a request is allowed.
286    pub async fn wait(&self) {
287        self.maybe_recover();
288        loop {
289            let check_result = self.limiter.read().check();
290            match check_result {
291                Ok(_) => {
292                    self.requests.fetch_add(1, Ordering::Relaxed);
293                    return;
294                }
295                Err(not_until) => {
296                    let wait_time =
297                        not_until.wait_time_from(DefaultClock::default().now());
298                    tokio::time::sleep(wait_time).await;
299                }
300            }
301        }
302    }
303
304    /// Get the current status.
305    pub fn get_status(&self) -> RateLimiterStatus {
306        RateLimiterStatus {
307            current_rps: self.current_rps.load(Ordering::Relaxed) as f64,
308            max_rps: self.base_rps,
309            burst_size: self.burst_size,
310            is_limited: self.limiter.read().check().is_err(),
311            requests_last_minute: self.requests.load(Ordering::Relaxed),
312            rejections_last_minute: self.rejections.load(Ordering::Relaxed),
313        }
314    }
315
316    /// Get external rate limit count.
317    pub fn external_limit_count(&self) -> u64 {
318        self.external_limits.load(Ordering::Relaxed)
319    }
320}
321
322/// Per-key rate limiter for limiting different resources independently.
323///
324/// Useful for per-endpoint, per-user, or per-API-key rate limiting.
325pub struct KeyedRateLimiter<K: Hash + Eq + Clone + Send + Sync + 'static> {
326    limiters: DashMap<K, Arc<RateLimiter>>,
327    default_rps: u32,
328    default_burst: u32,
329}
330
331impl<K: Hash + Eq + Clone + Send + Sync + 'static> KeyedRateLimiter<K> {
332    /// Create a new keyed rate limiter with default limits.
333    ///
334    /// # Arguments
335    /// * `default_rps` - Default requests per second for new keys
336    /// * `default_burst` - Default burst capacity for new keys
337    pub fn new(default_rps: u32, default_burst: u32) -> Self {
338        Self {
339            limiters: DashMap::new(),
340            default_rps,
341            default_burst,
342        }
343    }
344
345    /// Set a specific limit for a key.
346    pub fn set_limit(&self, key: K, rps: u32, burst: u32) {
347        self.limiters
348            .insert(key, Arc::new(RateLimiter::new(rps, burst)));
349    }
350
351    /// Remove limit for a key (will use default on next access).
352    pub fn remove_limit(&self, key: &K) {
353        self.limiters.remove(key);
354    }
355
356    /// Check if a request for a key is allowed.
357    pub fn check(&self, key: &K) -> Result<(), RateLimitError> {
358        let limiter = self.get_or_create(key);
359        limiter.check()
360    }
361
362    /// Wait until a request for a key is allowed.
363    pub async fn wait(&self, key: &K) {
364        let limiter = self.get_or_create(key);
365        limiter.wait().await
366    }
367
368    /// Get status for a specific key.
369    pub fn get_status(&self, key: &K) -> Option<RateLimiterStatus> {
370        self.limiters.get(key).map(|l| l.get_status())
371    }
372
373    /// Get all keys with their status.
374    pub fn get_all_status(&self) -> Vec<(K, RateLimiterStatus)> {
375        self.limiters
376            .iter()
377            .map(|entry| (entry.key().clone(), entry.value().get_status()))
378            .collect()
379    }
380
381    /// Clear all limiters.
382    pub fn clear(&self) {
383        self.limiters.clear();
384    }
385
386    fn get_or_create(&self, key: &K) -> Arc<RateLimiter> {
387        self.limiters
388            .entry(key.clone())
389            .or_insert_with(|| Arc::new(RateLimiter::new(self.default_rps, self.default_burst)))
390            .clone()
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh limiter admits a full burst immediately.
    #[test]
    fn test_rate_limiter_basic() {
        let limiter = RateLimiter::new(10, 5);

        // Should allow initial burst
        for _ in 0..5 {
            assert!(limiter.check().is_ok());
        }
    }

    /// The status reports the configured limits.
    #[test]
    fn test_rate_limiter_status() {
        let limiter = RateLimiter::new(100, 10);
        let status = limiter.get_status();

        assert_eq!(status.max_rps, 100);
        assert_eq!(status.burst_size, 10);
    }

    /// Each consecutive external rate limit lowers the effective rate further.
    #[test]
    fn test_adaptive_rate_limiter_backoff() {
        let limiter = AdaptiveRateLimiter::new(100, 10).with_backoff_factor(0.5);

        // Simulate rate limit responses
        limiter.record_rate_limit();
        let status1 = limiter.get_status();
        assert!(status1.current_rps < 100.0);

        limiter.record_rate_limit();
        let status2 = limiter.get_status();
        assert!(status2.current_rps < status1.current_rps);
    }

    /// After the recovery interval, a success raises the effective rate again
    /// without exceeding the base rate.
    #[test]
    fn test_adaptive_rate_limiter_recovery() {
        let limiter = AdaptiveRateLimiter::new(100, 10)
            .with_recovery_interval(Duration::from_millis(1))
            .with_backoff_factor(0.5);

        limiter.record_rate_limit();
        let reduced = limiter.get_status().current_rps;
        assert!(reduced < 100.0);

        // Record success after recovery interval
        std::thread::sleep(Duration::from_millis(10));
        limiter.record_success();

        // Recovery is gradual, so only require progress towards the base rate.
        let recovered = limiter.get_status().current_rps;
        assert!(recovered > reduced);
        assert!(recovered <= 100.0);
    }

    /// Keys do not share a token bucket.
    #[test]
    fn test_keyed_rate_limiter() {
        let limiter = KeyedRateLimiter::new(10, 5);

        // Different keys should have independent limits
        for _ in 0..5 {
            assert!(limiter.check(&"key1").is_ok());
            assert!(limiter.check(&"key2").is_ok());
        }
    }

    /// A per-key override is reflected in that key's status.
    #[test]
    fn test_keyed_rate_limiter_custom_limits() {
        let limiter = KeyedRateLimiter::new(10, 5);

        // Set custom limit for specific key
        limiter.set_limit("premium", 100, 50);

        let status = limiter.get_status(&"premium").unwrap();
        assert_eq!(status.max_rps, 100);
        assert_eq!(status.burst_size, 50);
    }

    /// Every key that has been checked shows up in the full status listing.
    #[test]
    fn test_keyed_rate_limiter_all_status() {
        let limiter = KeyedRateLimiter::new(10, 5);

        limiter.check(&"a").ok();
        limiter.check(&"b").ok();
        limiter.check(&"c").ok();

        let all = limiter.get_all_status();
        assert_eq!(all.len(), 3);
    }

    /// The error message names the condition and the retry delay.
    #[test]
    fn test_rate_limit_error_display() {
        let err = RateLimitError {
            retry_after: Duration::from_secs(5),
        };
        let msg = err.to_string();
        assert!(msg.contains("rate limit exceeded"));
        assert!(msg.contains("5"));
    }

    /// Waiting within the burst capacity returns promptly.
    #[tokio::test]
    async fn test_rate_limiter_wait() {
        let limiter = RateLimiter::new(1000, 100);

        let start = Instant::now();
        for _ in 0..10 {
            limiter.wait().await;
        }
        // Should complete quickly with high RPS
        assert!(start.elapsed() < Duration::from_secs(1));
    }
}
503}