allframe_core/resilience/
rate_limit.rs1use std::{
6 hash::Hash,
7 num::NonZeroU32,
8 sync::{
9 atomic::{AtomicU32, AtomicU64, Ordering},
10 Arc,
11 },
12 time::{Duration, Instant},
13};
14
15use dashmap::DashMap;
16use governor::{
17 clock::{Clock, DefaultClock},
18 state::{InMemoryState, NotKeyed},
19 Quota, RateLimiter as GovernorRateLimiter,
20};
21use parking_lot::RwLock;
22
/// Error returned when a rate limiter rejects a request.
///
/// Implements [`std::error::Error`], so it can be boxed or propagated with `?`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitError {
    /// How long the caller should wait before the next attempt may be admitted.
    pub retry_after: Duration,
}

impl std::fmt::Display for RateLimitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Duration` has no `Display` impl; its `Debug` form (`5s`, `250ms`)
        // is compact enough for a user-facing message.
        write!(f, "rate limit exceeded, retry after {:?}", self.retry_after)
    }
}

impl std::error::Error for RateLimitError {}
37
/// Point-in-time snapshot of a rate limiter, as produced by the `get_status`
/// methods in this module.
#[derive(Debug, Clone, PartialEq)]
pub struct RateLimiterStatus {
    /// Current rate in requests per second.
    ///
    /// `RateLimiter` reports the observed rate (admitted requests divided by
    /// the age of the counter window); `AdaptiveRateLimiter` reports the rate
    /// it is currently enforcing.
    pub current_rps: f64,
    /// Configured steady-state rate in requests per second.
    pub max_rps: u32,
    /// Configured burst allowance.
    pub burst_size: u32,
    /// Whether a probe request issued while taking the snapshot was rejected.
    pub is_limited: bool,
    /// Number of admitted requests counted so far.
    ///
    /// `RateLimiter` clears this counter roughly once a minute;
    /// `AdaptiveRateLimiter` never clears it.
    pub requests_last_minute: u64,
    /// Number of rejected requests counted so far (same windowing as
    /// `requests_last_minute`).
    pub rejections_last_minute: u64,
}
54
55pub struct RateLimiter {
59 limiter: GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>,
60 rps: u32,
61 burst_size: u32,
62 requests: AtomicU64,
63 rejections: AtomicU64,
64 last_reset: RwLock<Instant>,
65}
66
67impl RateLimiter {
68 pub fn new(rps: u32, burst_size: u32) -> Self {
74 let rps_nz = NonZeroU32::new(rps.max(1)).unwrap();
75 let burst_nz = NonZeroU32::new(burst_size.max(1)).unwrap();
76
77 let quota = Quota::per_second(rps_nz).allow_burst(burst_nz);
78 let limiter = GovernorRateLimiter::direct(quota);
79
80 Self {
81 limiter,
82 rps,
83 burst_size,
84 requests: AtomicU64::new(0),
85 rejections: AtomicU64::new(0),
86 last_reset: RwLock::new(Instant::now()),
87 }
88 }
89
90 pub fn check(&self) -> Result<(), RateLimitError> {
94 self.maybe_reset_counters();
95
96 match self.limiter.check() {
97 Ok(_) => {
98 self.requests.fetch_add(1, Ordering::Relaxed);
99 Ok(())
100 }
101 Err(not_until) => {
102 self.rejections.fetch_add(1, Ordering::Relaxed);
103 Err(RateLimitError {
104 retry_after: not_until.wait_time_from(DefaultClock::default().now()),
105 })
106 }
107 }
108 }
109
110 pub async fn wait(&self) {
114 self.maybe_reset_counters();
115 self.limiter.until_ready().await;
116 self.requests.fetch_add(1, Ordering::Relaxed);
117 }
118
119 pub fn get_status(&self) -> RateLimiterStatus {
121 self.maybe_reset_counters();
122
123 let requests = self.requests.load(Ordering::Relaxed);
124 let rejections = self.rejections.load(Ordering::Relaxed);
125 let elapsed = self.last_reset.read().elapsed().as_secs_f64().max(1.0);
126
127 RateLimiterStatus {
128 current_rps: requests as f64 / elapsed.min(60.0),
129 max_rps: self.rps,
130 burst_size: self.burst_size,
131 is_limited: self.limiter.check().is_err(),
132 requests_last_minute: requests,
133 rejections_last_minute: rejections,
134 }
135 }
136
137 fn maybe_reset_counters(&self) {
138 let mut last = self.last_reset.write();
139 if last.elapsed() > Duration::from_secs(60) {
140 self.requests.store(0, Ordering::Relaxed);
141 self.rejections.store(0, Ordering::Relaxed);
142 *last = Instant::now();
143 }
144 }
145}
146
147pub struct AdaptiveRateLimiter {
152 base_rps: u32,
154 burst_size: u32,
155 current_rps: AtomicU32,
157 limiter: RwLock<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
159 consecutive_limits: AtomicU32,
161 last_limit: RwLock<Option<Instant>>,
163 recovery_interval: Duration,
165 min_rps: u32,
167 backoff_factor: f64,
169 requests: AtomicU64,
171 rejections: AtomicU64,
172 external_limits: AtomicU64,
173}
174
175impl AdaptiveRateLimiter {
176 pub fn new(rps: u32, burst_size: u32) -> Self {
182 let limiter = Self::create_limiter(rps, burst_size);
183
184 Self {
185 base_rps: rps,
186 burst_size,
187 current_rps: AtomicU32::new(rps),
188 limiter: RwLock::new(limiter),
189 consecutive_limits: AtomicU32::new(0),
190 last_limit: RwLock::new(None),
191 recovery_interval: Duration::from_secs(30),
192 min_rps: 1,
193 backoff_factor: 0.5,
194 requests: AtomicU64::new(0),
195 rejections: AtomicU64::new(0),
196 external_limits: AtomicU64::new(0),
197 }
198 }
199
200 pub fn with_recovery_interval(mut self, interval: Duration) -> Self {
202 self.recovery_interval = interval;
203 self
204 }
205
206 pub fn with_min_rps(mut self, min_rps: u32) -> Self {
208 self.min_rps = min_rps.max(1);
209 self
210 }
211
212 pub fn with_backoff_factor(mut self, factor: f64) -> Self {
214 self.backoff_factor = factor.clamp(0.1, 0.9);
215 self
216 }
217
218 fn create_limiter(
219 rps: u32,
220 burst_size: u32,
221 ) -> GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock> {
222 let rps_nz = NonZeroU32::new(rps.max(1)).unwrap();
223 let burst_nz = NonZeroU32::new(burst_size.max(1)).unwrap();
224 let quota = Quota::per_second(rps_nz).allow_burst(burst_nz);
225 GovernorRateLimiter::direct(quota)
226 }
227
228 pub fn record_success(&self) {
230 self.consecutive_limits.store(0, Ordering::Relaxed);
231 self.maybe_recover();
232 }
233
234 pub fn record_rate_limit(&self) {
236 self.external_limits.fetch_add(1, Ordering::Relaxed);
237 let consecutive = self.consecutive_limits.fetch_add(1, Ordering::Relaxed) + 1;
238 *self.last_limit.write() = Some(Instant::now());
239
240 let reduction = self.backoff_factor.powi(consecutive.min(5) as i32);
242 let new_rps = ((self.base_rps as f64 * reduction) as u32).max(self.min_rps);
243
244 self.current_rps.store(new_rps, Ordering::Relaxed);
245 *self.limiter.write() = Self::create_limiter(new_rps, self.burst_size);
246 }
247
248 fn maybe_recover(&self) {
249 let last_limit = *self.last_limit.read();
250 if let Some(last) = last_limit {
251 if last.elapsed() > self.recovery_interval {
252 let current = self.current_rps.load(Ordering::Relaxed);
254 if current < self.base_rps {
255 let new_rps = ((current as f64 * 1.5) as u32).min(self.base_rps);
256 self.current_rps.store(new_rps, Ordering::Relaxed);
257 *self.limiter.write() = Self::create_limiter(new_rps, self.burst_size);
258
259 if new_rps >= self.base_rps {
260 *self.last_limit.write() = None;
261 }
262 }
263 }
264 }
265 }
266
267 pub fn check(&self) -> Result<(), RateLimitError> {
269 self.maybe_recover();
270
271 match self.limiter.read().check() {
272 Ok(_) => {
273 self.requests.fetch_add(1, Ordering::Relaxed);
274 Ok(())
275 }
276 Err(not_until) => {
277 self.rejections.fetch_add(1, Ordering::Relaxed);
278 Err(RateLimitError {
279 retry_after: not_until.wait_time_from(DefaultClock::default().now()),
280 })
281 }
282 }
283 }
284
285 pub async fn wait(&self) {
287 self.maybe_recover();
288 self.limiter.read().until_ready().await;
289 self.requests.fetch_add(1, Ordering::Relaxed);
290 }
291
292 pub fn get_status(&self) -> RateLimiterStatus {
294 RateLimiterStatus {
295 current_rps: self.current_rps.load(Ordering::Relaxed) as f64,
296 max_rps: self.base_rps,
297 burst_size: self.burst_size,
298 is_limited: self.limiter.read().check().is_err(),
299 requests_last_minute: self.requests.load(Ordering::Relaxed),
300 rejections_last_minute: self.rejections.load(Ordering::Relaxed),
301 }
302 }
303
304 pub fn external_limit_count(&self) -> u64 {
306 self.external_limits.load(Ordering::Relaxed)
307 }
308}
309
310pub struct KeyedRateLimiter<K: Hash + Eq + Clone + Send + Sync + 'static> {
314 limiters: DashMap<K, Arc<RateLimiter>>,
315 default_rps: u32,
316 default_burst: u32,
317}
318
319impl<K: Hash + Eq + Clone + Send + Sync + 'static> KeyedRateLimiter<K> {
320 pub fn new(default_rps: u32, default_burst: u32) -> Self {
326 Self {
327 limiters: DashMap::new(),
328 default_rps,
329 default_burst,
330 }
331 }
332
333 pub fn set_limit(&self, key: K, rps: u32, burst: u32) {
335 self.limiters
336 .insert(key, Arc::new(RateLimiter::new(rps, burst)));
337 }
338
339 pub fn remove_limit(&self, key: &K) {
341 self.limiters.remove(key);
342 }
343
344 pub fn check(&self, key: &K) -> Result<(), RateLimitError> {
346 let limiter = self.get_or_create(key);
347 limiter.check()
348 }
349
350 pub async fn wait(&self, key: &K) {
352 let limiter = self.get_or_create(key);
353 limiter.wait().await
354 }
355
356 pub fn get_status(&self, key: &K) -> Option<RateLimiterStatus> {
358 self.limiters.get(key).map(|l| l.get_status())
359 }
360
361 pub fn get_all_status(&self) -> Vec<(K, RateLimiterStatus)> {
363 self.limiters
364 .iter()
365 .map(|entry| (entry.key().clone(), entry.value().get_status()))
366 .collect()
367 }
368
369 pub fn clear(&self) {
371 self.limiters.clear();
372 }
373
374 fn get_or_create(&self, key: &K) -> Arc<RateLimiter> {
375 self.limiters
376 .entry(key.clone())
377 .or_insert_with(|| Arc::new(RateLimiter::new(self.default_rps, self.default_burst)))
378 .clone()
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// The full burst is admitted immediately.
    #[test]
    fn test_rate_limiter_basic() {
        let limiter = RateLimiter::new(10, 5);

        for _ in 0..5 {
            assert!(limiter.check().is_ok());
        }
    }

    /// A request beyond the burst is rejected and counted as a rejection.
    #[test]
    fn test_rate_limiter_rejects_over_burst() {
        // 1 rps with burst 1: the second immediate request cannot be admitted.
        let limiter = RateLimiter::new(1, 1);

        assert!(limiter.check().is_ok());
        assert!(limiter.check().is_err());

        let status = limiter.get_status();
        assert_eq!(status.requests_last_minute, 1);
        assert_eq!(status.rejections_last_minute, 1);
    }

    /// The status reflects the configured quota.
    #[test]
    fn test_rate_limiter_status() {
        let limiter = RateLimiter::new(100, 10);
        let status = limiter.get_status();

        assert_eq!(status.max_rps, 100);
        assert_eq!(status.burst_size, 10);
    }

    /// Each consecutive rate-limit report lowers the rate further.
    #[test]
    fn test_adaptive_rate_limiter_backoff() {
        let limiter = AdaptiveRateLimiter::new(100, 10).with_backoff_factor(0.5);

        limiter.record_rate_limit();
        let status1 = limiter.get_status();
        assert!(status1.current_rps < 100.0);

        limiter.record_rate_limit();
        let status2 = limiter.get_status();
        assert!(status2.current_rps < status1.current_rps);
    }

    /// After the recovery interval, a reported success raises the rate again.
    #[test]
    fn test_adaptive_rate_limiter_recovery() {
        let limiter = AdaptiveRateLimiter::new(100, 10)
            .with_recovery_interval(Duration::from_millis(1))
            .with_backoff_factor(0.5);

        limiter.record_rate_limit();
        let reduced = limiter.get_status().current_rps;
        assert!(reduced < 100.0);

        // Sleep well past the 1 ms recovery interval.
        std::thread::sleep(Duration::from_millis(10));
        limiter.record_success();

        let recovered = limiter.get_status().current_rps;
        assert!(recovered > reduced);
        assert!(recovered <= 100.0);
    }

    /// Keys are limited independently of each other.
    #[test]
    fn test_keyed_rate_limiter() {
        let limiter = KeyedRateLimiter::new(10, 5);

        for _ in 0..5 {
            assert!(limiter.check(&"key1").is_ok());
            assert!(limiter.check(&"key2").is_ok());
        }
    }

    /// A custom per-key limit overrides the defaults.
    #[test]
    fn test_keyed_rate_limiter_custom_limits() {
        let limiter = KeyedRateLimiter::new(10, 5);

        limiter.set_limit("premium", 100, 50);

        let status = limiter.get_status(&"premium").unwrap();
        assert_eq!(status.max_rps, 100);
        assert_eq!(status.burst_size, 50);
    }

    /// Every key that has been checked shows up in the aggregate status.
    #[test]
    fn test_keyed_rate_limiter_all_status() {
        let limiter = KeyedRateLimiter::new(10, 5);

        limiter.check(&"a").ok();
        limiter.check(&"b").ok();
        limiter.check(&"c").ok();

        let all = limiter.get_all_status();
        assert_eq!(all.len(), 3);
    }

    /// The error message names the condition and the retry delay.
    #[test]
    fn test_rate_limit_error_display() {
        let err = RateLimitError {
            retry_after: Duration::from_secs(5),
        };
        let msg = format!("{}", err);
        assert!(msg.contains("rate limit exceeded"));
        assert!(msg.contains("5"));
    }

    /// Waiting within the burst allowance returns promptly.
    #[tokio::test]
    async fn test_rate_limiter_wait() {
        let limiter = RateLimiter::new(1000, 100);

        let start = Instant::now();
        for _ in 0..10 {
            limiter.wait().await;
        }
        assert!(start.elapsed() < Duration::from_secs(1));
    }
}