allframe_core/resilience/rate_limit.rs

1//! Rate limiting primitives with adaptive and keyed variants.
2//!
3//! Provides token bucket rate limiting for controlling request throughput.
4
use std::{
    hash::Hash,
    num::NonZeroU32,
    sync::{
        atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering},
        Arc,
    },
    time::{Duration, Instant},
};

use dashmap::DashMap;
use governor::{
    clock::{Clock, DefaultClock},
    state::{InMemoryState, NotKeyed},
    Quota, RateLimiter as GovernorRateLimiter,
};
use parking_lot::RwLock;
22
/// Error returned when a request is rejected because the rate limit is exhausted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitError {
    /// How long to wait before the next request is expected to be admitted.
    pub retry_after: Duration,
}

impl std::fmt::Display for RateLimitError {
    /// Formats as `rate limit exceeded, retry after <duration>`, with the duration
    /// rendered by its `Debug` impl (e.g. `250ms`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "rate limit exceeded, retry after {:?}", self.retry_after)
    }
}

impl std::error::Error for RateLimitError {}
37
/// Point-in-time status snapshot of a rate limiter.
#[derive(Debug, Clone, PartialEq)]
pub struct RateLimiterStatus {
    /// Requests per second. The meaning depends on the limiter: `RateLimiter`
    /// reports its observed admission rate over the statistics window, while
    /// `AdaptiveRateLimiter` reports its current (possibly reduced) limit.
    pub current_rps: f64,
    /// Configured maximum requests per second.
    pub max_rps: u32,
    /// Configured burst capacity.
    pub burst_size: u32,
    /// Whether the limiter is currently rejecting requests.
    pub is_limited: bool,
    /// Requests admitted in the limiter's current statistics window
    /// (nominally the last minute).
    pub requests_last_minute: u64,
    /// Requests rejected in the limiter's current statistics window
    /// (nominally the last minute).
    pub rejections_last_minute: u64,
}
54
55/// Token bucket rate limiter.
56///
57/// Allows a sustained rate of requests with burst capacity.
58pub struct RateLimiter {
59    limiter: GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>,
60    rps: u32,
61    burst_size: u32,
62    requests: AtomicU64,
63    rejections: AtomicU64,
64    last_reset: RwLock<Instant>,
65}
66
67impl RateLimiter {
68    /// Create a new rate limiter.
69    ///
70    /// # Arguments
71    /// * `rps` - Maximum requests per second
72    /// * `burst_size` - Additional burst capacity above the sustained rate
73    pub fn new(rps: u32, burst_size: u32) -> Self {
74        let rps_nz = NonZeroU32::new(rps.max(1)).unwrap();
75        let burst_nz = NonZeroU32::new(burst_size.max(1)).unwrap();
76
77        let quota = Quota::per_second(rps_nz).allow_burst(burst_nz);
78        let limiter = GovernorRateLimiter::direct(quota);
79
80        Self {
81            limiter,
82            rps,
83            burst_size,
84            requests: AtomicU64::new(0),
85            rejections: AtomicU64::new(0),
86            last_reset: RwLock::new(Instant::now()),
87        }
88    }
89
90    /// Check if a request is allowed without blocking.
91    ///
92    /// Returns `Ok(())` if allowed, `Err(RateLimitError)` if rate limited.
93    pub fn check(&self) -> Result<(), RateLimitError> {
94        self.maybe_reset_counters();
95
96        match self.limiter.check() {
97            Ok(_) => {
98                self.requests.fetch_add(1, Ordering::Relaxed);
99                Ok(())
100            }
101            Err(not_until) => {
102                self.rejections.fetch_add(1, Ordering::Relaxed);
103                Err(RateLimitError {
104                    retry_after: not_until.wait_time_from(DefaultClock::default().now()),
105                })
106            }
107        }
108    }
109
110    /// Wait until a request is allowed.
111    ///
112    /// Blocks the current task until the rate limit permits the request.
113    pub async fn wait(&self) {
114        self.maybe_reset_counters();
115        self.limiter.until_ready().await;
116        self.requests.fetch_add(1, Ordering::Relaxed);
117    }
118
119    /// Get the current status of the rate limiter.
120    pub fn get_status(&self) -> RateLimiterStatus {
121        self.maybe_reset_counters();
122
123        let requests = self.requests.load(Ordering::Relaxed);
124        let rejections = self.rejections.load(Ordering::Relaxed);
125        let elapsed = self.last_reset.read().elapsed().as_secs_f64().max(1.0);
126
127        RateLimiterStatus {
128            current_rps: requests as f64 / elapsed.min(60.0),
129            max_rps: self.rps,
130            burst_size: self.burst_size,
131            is_limited: self.limiter.check().is_err(),
132            requests_last_minute: requests,
133            rejections_last_minute: rejections,
134        }
135    }
136
137    fn maybe_reset_counters(&self) {
138        let mut last = self.last_reset.write();
139        if last.elapsed() > Duration::from_secs(60) {
140            self.requests.store(0, Ordering::Relaxed);
141            self.rejections.store(0, Ordering::Relaxed);
142            *last = Instant::now();
143        }
144    }
145}
146
147/// Adaptive rate limiter that backs off when receiving external rate limits.
148///
149/// When external services return 429 responses, this limiter reduces its
150/// throughput to avoid hammering the service.
151pub struct AdaptiveRateLimiter {
152    /// Base rate limiter.
153    base_rps: u32,
154    burst_size: u32,
155    /// Current effective RPS (may be reduced).
156    current_rps: AtomicU32,
157    /// Underlying rate limiter.
158    limiter: RwLock<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
159    /// Consecutive rate limit responses.
160    consecutive_limits: AtomicU32,
161    /// Last rate limit time.
162    last_limit: RwLock<Option<Instant>>,
163    /// Recovery interval.
164    recovery_interval: Duration,
165    /// Minimum RPS (floor).
166    min_rps: u32,
167    /// Backoff factor when rate limited.
168    backoff_factor: f64,
169    /// Statistics.
170    requests: AtomicU64,
171    rejections: AtomicU64,
172    external_limits: AtomicU64,
173}
174
175impl AdaptiveRateLimiter {
176    /// Create a new adaptive rate limiter.
177    ///
178    /// # Arguments
179    /// * `rps` - Base requests per second
180    /// * `burst_size` - Burst capacity
181    pub fn new(rps: u32, burst_size: u32) -> Self {
182        let limiter = Self::create_limiter(rps, burst_size);
183
184        Self {
185            base_rps: rps,
186            burst_size,
187            current_rps: AtomicU32::new(rps),
188            limiter: RwLock::new(limiter),
189            consecutive_limits: AtomicU32::new(0),
190            last_limit: RwLock::new(None),
191            recovery_interval: Duration::from_secs(30),
192            min_rps: 1,
193            backoff_factor: 0.5,
194            requests: AtomicU64::new(0),
195            rejections: AtomicU64::new(0),
196            external_limits: AtomicU64::new(0),
197        }
198    }
199
200    /// Set the recovery interval.
201    pub fn with_recovery_interval(mut self, interval: Duration) -> Self {
202        self.recovery_interval = interval;
203        self
204    }
205
206    /// Set the minimum RPS floor.
207    pub fn with_min_rps(mut self, min_rps: u32) -> Self {
208        self.min_rps = min_rps.max(1);
209        self
210    }
211
212    /// Set the backoff factor (0.0-1.0).
213    pub fn with_backoff_factor(mut self, factor: f64) -> Self {
214        self.backoff_factor = factor.clamp(0.1, 0.9);
215        self
216    }
217
218    fn create_limiter(
219        rps: u32,
220        burst_size: u32,
221    ) -> GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock> {
222        let rps_nz = NonZeroU32::new(rps.max(1)).unwrap();
223        let burst_nz = NonZeroU32::new(burst_size.max(1)).unwrap();
224        let quota = Quota::per_second(rps_nz).allow_burst(burst_nz);
225        GovernorRateLimiter::direct(quota)
226    }
227
228    /// Record a successful request.
229    pub fn record_success(&self) {
230        self.consecutive_limits.store(0, Ordering::Relaxed);
231        self.maybe_recover();
232    }
233
234    /// Record a rate limit response from an external service.
235    pub fn record_rate_limit(&self) {
236        self.external_limits.fetch_add(1, Ordering::Relaxed);
237        let consecutive = self.consecutive_limits.fetch_add(1, Ordering::Relaxed) + 1;
238        *self.last_limit.write() = Some(Instant::now());
239
240        // Reduce rate based on consecutive rate limits
241        let reduction = self.backoff_factor.powi(consecutive.min(5) as i32);
242        let new_rps = ((self.base_rps as f64 * reduction) as u32).max(self.min_rps);
243
244        self.current_rps.store(new_rps, Ordering::Relaxed);
245        *self.limiter.write() = Self::create_limiter(new_rps, self.burst_size);
246    }
247
248    fn maybe_recover(&self) {
249        let last_limit = *self.last_limit.read();
250        if let Some(last) = last_limit {
251            if last.elapsed() > self.recovery_interval {
252                // Gradually recover to base rate
253                let current = self.current_rps.load(Ordering::Relaxed);
254                if current < self.base_rps {
255                    let new_rps = ((current as f64 * 1.5) as u32).min(self.base_rps);
256                    self.current_rps.store(new_rps, Ordering::Relaxed);
257                    *self.limiter.write() = Self::create_limiter(new_rps, self.burst_size);
258
259                    if new_rps >= self.base_rps {
260                        *self.last_limit.write() = None;
261                    }
262                }
263            }
264        }
265    }
266
267    /// Check if a request is allowed.
268    pub fn check(&self) -> Result<(), RateLimitError> {
269        self.maybe_recover();
270
271        match self.limiter.read().check() {
272            Ok(_) => {
273                self.requests.fetch_add(1, Ordering::Relaxed);
274                Ok(())
275            }
276            Err(not_until) => {
277                self.rejections.fetch_add(1, Ordering::Relaxed);
278                Err(RateLimitError {
279                    retry_after: not_until.wait_time_from(DefaultClock::default().now()),
280                })
281            }
282        }
283    }
284
285    /// Wait until a request is allowed.
286    pub async fn wait(&self) {
287        self.maybe_recover();
288        self.limiter.read().until_ready().await;
289        self.requests.fetch_add(1, Ordering::Relaxed);
290    }
291
292    /// Get the current status.
293    pub fn get_status(&self) -> RateLimiterStatus {
294        RateLimiterStatus {
295            current_rps: self.current_rps.load(Ordering::Relaxed) as f64,
296            max_rps: self.base_rps,
297            burst_size: self.burst_size,
298            is_limited: self.limiter.read().check().is_err(),
299            requests_last_minute: self.requests.load(Ordering::Relaxed),
300            rejections_last_minute: self.rejections.load(Ordering::Relaxed),
301        }
302    }
303
304    /// Get external rate limit count.
305    pub fn external_limit_count(&self) -> u64 {
306        self.external_limits.load(Ordering::Relaxed)
307    }
308}
309
310/// Per-key rate limiter for limiting different resources independently.
311///
312/// Useful for per-endpoint, per-user, or per-API-key rate limiting.
313pub struct KeyedRateLimiter<K: Hash + Eq + Clone + Send + Sync + 'static> {
314    limiters: DashMap<K, Arc<RateLimiter>>,
315    default_rps: u32,
316    default_burst: u32,
317}
318
319impl<K: Hash + Eq + Clone + Send + Sync + 'static> KeyedRateLimiter<K> {
320    /// Create a new keyed rate limiter with default limits.
321    ///
322    /// # Arguments
323    /// * `default_rps` - Default requests per second for new keys
324    /// * `default_burst` - Default burst capacity for new keys
325    pub fn new(default_rps: u32, default_burst: u32) -> Self {
326        Self {
327            limiters: DashMap::new(),
328            default_rps,
329            default_burst,
330        }
331    }
332
333    /// Set a specific limit for a key.
334    pub fn set_limit(&self, key: K, rps: u32, burst: u32) {
335        self.limiters
336            .insert(key, Arc::new(RateLimiter::new(rps, burst)));
337    }
338
339    /// Remove limit for a key (will use default on next access).
340    pub fn remove_limit(&self, key: &K) {
341        self.limiters.remove(key);
342    }
343
344    /// Check if a request for a key is allowed.
345    pub fn check(&self, key: &K) -> Result<(), RateLimitError> {
346        let limiter = self.get_or_create(key);
347        limiter.check()
348    }
349
350    /// Wait until a request for a key is allowed.
351    pub async fn wait(&self, key: &K) {
352        let limiter = self.get_or_create(key);
353        limiter.wait().await
354    }
355
356    /// Get status for a specific key.
357    pub fn get_status(&self, key: &K) -> Option<RateLimiterStatus> {
358        self.limiters.get(key).map(|l| l.get_status())
359    }
360
361    /// Get all keys with their status.
362    pub fn get_all_status(&self) -> Vec<(K, RateLimiterStatus)> {
363        self.limiters
364            .iter()
365            .map(|entry| (entry.key().clone(), entry.value().get_status()))
366            .collect()
367    }
368
369    /// Clear all limiters.
370    pub fn clear(&self) {
371        self.limiters.clear();
372    }
373
374    fn get_or_create(&self, key: &K) -> Arc<RateLimiter> {
375        self.limiters
376            .entry(key.clone())
377            .or_insert_with(|| Arc::new(RateLimiter::new(self.default_rps, self.default_burst)))
378            .clone()
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limiter_basic() {
        let limiter = RateLimiter::new(10, 5);

        // The whole burst capacity is available immediately.
        for _ in 0..5 {
            assert!(limiter.check().is_ok());
        }
    }

    #[test]
    fn test_rate_limiter_status() {
        let limiter = RateLimiter::new(100, 10);
        let status = limiter.get_status();

        assert_eq!(status.max_rps, 100);
        assert_eq!(status.burst_size, 10);
        // A fresh limiter has admitted nothing and is not limited.
        assert!(!status.is_limited);
        assert_eq!(status.requests_last_minute, 0);
        assert_eq!(status.rejections_last_minute, 0);
    }

    #[test]
    fn test_rate_limiter_counts_rejections() {
        // 10 rps with a burst of 1: an immediate second request must be rejected.
        let limiter = RateLimiter::new(10, 1);
        assert!(limiter.check().is_ok());
        assert!(limiter.check().is_err());

        let status = limiter.get_status();
        assert_eq!(status.requests_last_minute, 1);
        assert_eq!(status.rejections_last_minute, 1);
    }

    #[test]
    fn test_adaptive_rate_limiter_backoff() {
        let limiter = AdaptiveRateLimiter::new(100, 10).with_backoff_factor(0.5);

        // Each consecutive external rate limit lowers the effective rate further.
        limiter.record_rate_limit();
        let status1 = limiter.get_status();
        assert!(status1.current_rps < 100.0);

        limiter.record_rate_limit();
        let status2 = limiter.get_status();
        assert!(status2.current_rps < status1.current_rps);
        assert_eq!(limiter.external_limit_count(), 2);
    }

    #[test]
    fn test_adaptive_rate_limiter_recovery() {
        let limiter = AdaptiveRateLimiter::new(100, 10)
            .with_recovery_interval(Duration::from_millis(1))
            .with_backoff_factor(0.5);

        limiter.record_rate_limit();
        let reduced = limiter.get_status().current_rps;
        assert!(reduced < 100.0);

        // Once the recovery interval has elapsed, a success triggers a recovery step.
        std::thread::sleep(Duration::from_millis(10));
        limiter.record_success();
        let recovered = limiter.get_status().current_rps;
        assert!(recovered > reduced);
        assert!(recovered <= 100.0);
    }

    #[test]
    fn test_keyed_rate_limiter() {
        let limiter = KeyedRateLimiter::new(10, 5);

        // Different keys have independent limits.
        for _ in 0..5 {
            assert!(limiter.check(&"key1").is_ok());
            assert!(limiter.check(&"key2").is_ok());
        }
    }

    #[test]
    fn test_keyed_rate_limiter_custom_limits() {
        let limiter = KeyedRateLimiter::new(10, 5);

        // A custom limit overrides the defaults for that key only.
        limiter.set_limit("premium", 100, 50);

        let status = limiter.get_status(&"premium").unwrap();
        assert_eq!(status.max_rps, 100);
        assert_eq!(status.burst_size, 50);
    }

    #[test]
    fn test_keyed_rate_limiter_all_status() {
        let limiter = KeyedRateLimiter::new(10, 5);

        limiter.check(&"a").ok();
        limiter.check(&"b").ok();
        limiter.check(&"c").ok();

        let all = limiter.get_all_status();
        assert_eq!(all.len(), 3);
    }

    #[test]
    fn test_rate_limit_error_display() {
        let err = RateLimitError {
            retry_after: Duration::from_secs(5),
        };
        let msg = format!("{}", err);
        assert!(msg.contains("rate limit exceeded"));
        assert!(msg.contains("5"));
    }

    #[tokio::test]
    async fn test_rate_limiter_wait() {
        let limiter = RateLimiter::new(1000, 100);

        let start = Instant::now();
        for _ in 0..10 {
            limiter.wait().await;
        }
        // Well within the burst, so this should complete quickly.
        assert!(start.elapsed() < Duration::from_secs(1));
    }
}