allframe_core/resilience/
rate_limit.rs1use std::{
6 hash::Hash,
7 num::NonZeroU32,
8 sync::{
9 atomic::{AtomicU32, AtomicU64, Ordering},
10 Arc,
11 },
12 time::{Duration, Instant},
13};
14
15use dashmap::DashMap;
16use governor::{
17 clock::{Clock, DefaultClock},
18 state::{InMemoryState, NotKeyed},
19 Quota, RateLimiter as GovernorRateLimiter,
20};
21use parking_lot::RwLock;
22
/// Error returned when a request is rejected because the rate limit has been
/// exceeded.
///
/// Implements `PartialEq`/`Eq` so callers and tests can compare results
/// directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitError {
    /// How long the caller should wait before trying again.
    pub retry_after: Duration,
}

impl std::fmt::Display for RateLimitError {
    /// Formats the error as `rate limit exceeded, retry after <duration>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Duration`'s `Debug` form is the compact human-readable one (e.g. `5s`).
        write!(f, "rate limit exceeded, retry after {:?}", self.retry_after)
    }
}

impl std::error::Error for RateLimitError {}
37
/// Point-in-time snapshot of a rate limiter's configuration and counters.
///
/// Implements `PartialEq` so snapshots can be compared in tests and change
/// detection.
#[derive(Debug, Clone, PartialEq)]
pub struct RateLimiterStatus {
    /// Rate reported by the producing limiter, in requests per second.
    ///
    /// `RateLimiter` reports the observed admission rate over its counting
    /// window; `AdaptiveRateLimiter` reports the rate it is currently
    /// enforcing.
    pub current_rps: f64,
    /// Configured maximum sustained rate, in requests per second.
    pub max_rps: u32,
    /// Configured burst capacity.
    pub burst_size: u32,
    /// Whether the limiter was rejecting requests when the snapshot was taken.
    pub is_limited: bool,
    /// Admitted-request counter of the producing limiter.
    pub requests_last_minute: u64,
    /// Rejected-request counter of the producing limiter.
    pub rejections_last_minute: u64,
}
54
55pub struct RateLimiter {
59 limiter: GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>,
60 rps: u32,
61 burst_size: u32,
62 requests: AtomicU64,
63 rejections: AtomicU64,
64 last_reset: RwLock<Instant>,
65}
66
67impl RateLimiter {
68 pub fn new(rps: u32, burst_size: u32) -> Self {
74 let rps_nz = NonZeroU32::new(rps.max(1)).unwrap();
75 let burst_nz = NonZeroU32::new(burst_size.max(1)).unwrap();
76
77 let quota = Quota::per_second(rps_nz).allow_burst(burst_nz);
78 let limiter = GovernorRateLimiter::direct(quota);
79
80 Self {
81 limiter,
82 rps,
83 burst_size,
84 requests: AtomicU64::new(0),
85 rejections: AtomicU64::new(0),
86 last_reset: RwLock::new(Instant::now()),
87 }
88 }
89
90 pub fn check(&self) -> Result<(), RateLimitError> {
94 self.maybe_reset_counters();
95
96 match self.limiter.check() {
97 Ok(_) => {
98 self.requests.fetch_add(1, Ordering::Relaxed);
99 Ok(())
100 }
101 Err(not_until) => {
102 self.rejections.fetch_add(1, Ordering::Relaxed);
103 Err(RateLimitError {
104 retry_after: not_until.wait_time_from(DefaultClock::default().now()),
105 })
106 }
107 }
108 }
109
110 pub async fn wait(&self) {
114 self.maybe_reset_counters();
115 self.limiter.until_ready().await;
116 self.requests.fetch_add(1, Ordering::Relaxed);
117 }
118
119 pub fn get_status(&self) -> RateLimiterStatus {
121 self.maybe_reset_counters();
122
123 let requests = self.requests.load(Ordering::Relaxed);
124 let rejections = self.rejections.load(Ordering::Relaxed);
125 let elapsed = self.last_reset.read().elapsed().as_secs_f64().max(1.0);
126
127 RateLimiterStatus {
128 current_rps: requests as f64 / elapsed.min(60.0),
129 max_rps: self.rps,
130 burst_size: self.burst_size,
131 is_limited: self.limiter.check().is_err(),
132 requests_last_minute: requests,
133 rejections_last_minute: rejections,
134 }
135 }
136
137 fn maybe_reset_counters(&self) {
138 let mut last = self.last_reset.write();
139 if last.elapsed() > Duration::from_secs(60) {
140 self.requests.store(0, Ordering::Relaxed);
141 self.rejections.store(0, Ordering::Relaxed);
142 *last = Instant::now();
143 }
144 }
145}
146
147pub struct AdaptiveRateLimiter {
152 base_rps: u32,
154 burst_size: u32,
155 current_rps: AtomicU32,
157 limiter: RwLock<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
159 consecutive_limits: AtomicU32,
161 last_limit: RwLock<Option<Instant>>,
163 recovery_interval: Duration,
165 min_rps: u32,
167 backoff_factor: f64,
169 requests: AtomicU64,
171 rejections: AtomicU64,
172 external_limits: AtomicU64,
173}
174
175impl AdaptiveRateLimiter {
176 pub fn new(rps: u32, burst_size: u32) -> Self {
182 let limiter = Self::create_limiter(rps, burst_size);
183
184 Self {
185 base_rps: rps,
186 burst_size,
187 current_rps: AtomicU32::new(rps),
188 limiter: RwLock::new(limiter),
189 consecutive_limits: AtomicU32::new(0),
190 last_limit: RwLock::new(None),
191 recovery_interval: Duration::from_secs(30),
192 min_rps: 1,
193 backoff_factor: 0.5,
194 requests: AtomicU64::new(0),
195 rejections: AtomicU64::new(0),
196 external_limits: AtomicU64::new(0),
197 }
198 }
199
200 pub fn with_recovery_interval(mut self, interval: Duration) -> Self {
202 self.recovery_interval = interval;
203 self
204 }
205
206 pub fn with_min_rps(mut self, min_rps: u32) -> Self {
208 self.min_rps = min_rps.max(1);
209 self
210 }
211
212 pub fn with_backoff_factor(mut self, factor: f64) -> Self {
214 self.backoff_factor = factor.clamp(0.1, 0.9);
215 self
216 }
217
218 fn create_limiter(
219 rps: u32,
220 burst_size: u32,
221 ) -> GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock> {
222 let rps_nz = NonZeroU32::new(rps.max(1)).unwrap();
223 let burst_nz = NonZeroU32::new(burst_size.max(1)).unwrap();
224 let quota = Quota::per_second(rps_nz).allow_burst(burst_nz);
225 GovernorRateLimiter::direct(quota)
226 }
227
228 pub fn record_success(&self) {
230 self.consecutive_limits.store(0, Ordering::Relaxed);
231 self.maybe_recover();
232 }
233
234 pub fn record_rate_limit(&self) {
236 self.external_limits.fetch_add(1, Ordering::Relaxed);
237 let consecutive = self.consecutive_limits.fetch_add(1, Ordering::Relaxed) + 1;
238 *self.last_limit.write() = Some(Instant::now());
239
240 let reduction = self.backoff_factor.powi(consecutive.min(5) as i32);
242 let new_rps = ((self.base_rps as f64 * reduction) as u32).max(self.min_rps);
243
244 self.current_rps.store(new_rps, Ordering::Relaxed);
245 *self.limiter.write() = Self::create_limiter(new_rps, self.burst_size);
246 }
247
248 fn maybe_recover(&self) {
249 let last_limit = *self.last_limit.read();
250 if let Some(last) = last_limit {
251 if last.elapsed() > self.recovery_interval {
252 let current = self.current_rps.load(Ordering::Relaxed);
254 if current < self.base_rps {
255 let new_rps = ((current as f64 * 1.5) as u32).min(self.base_rps);
256 self.current_rps.store(new_rps, Ordering::Relaxed);
257 *self.limiter.write() = Self::create_limiter(new_rps, self.burst_size);
258
259 if new_rps >= self.base_rps {
260 *self.last_limit.write() = None;
261 }
262 }
263 }
264 }
265 }
266
267 pub fn check(&self) -> Result<(), RateLimitError> {
269 self.maybe_recover();
270
271 match self.limiter.read().check() {
272 Ok(_) => {
273 self.requests.fetch_add(1, Ordering::Relaxed);
274 Ok(())
275 }
276 Err(not_until) => {
277 self.rejections.fetch_add(1, Ordering::Relaxed);
278 Err(RateLimitError {
279 retry_after: not_until.wait_time_from(DefaultClock::default().now()),
280 })
281 }
282 }
283 }
284
285 pub async fn wait(&self) {
287 self.maybe_recover();
288 loop {
289 let check_result = self.limiter.read().check();
290 match check_result {
291 Ok(_) => {
292 self.requests.fetch_add(1, Ordering::Relaxed);
293 return;
294 }
295 Err(not_until) => {
296 let wait_time =
297 not_until.wait_time_from(DefaultClock::default().now());
298 tokio::time::sleep(wait_time).await;
299 }
300 }
301 }
302 }
303
304 pub fn get_status(&self) -> RateLimiterStatus {
306 RateLimiterStatus {
307 current_rps: self.current_rps.load(Ordering::Relaxed) as f64,
308 max_rps: self.base_rps,
309 burst_size: self.burst_size,
310 is_limited: self.limiter.read().check().is_err(),
311 requests_last_minute: self.requests.load(Ordering::Relaxed),
312 rejections_last_minute: self.rejections.load(Ordering::Relaxed),
313 }
314 }
315
316 pub fn external_limit_count(&self) -> u64 {
318 self.external_limits.load(Ordering::Relaxed)
319 }
320}
321
322pub struct KeyedRateLimiter<K: Hash + Eq + Clone + Send + Sync + 'static> {
326 limiters: DashMap<K, Arc<RateLimiter>>,
327 default_rps: u32,
328 default_burst: u32,
329}
330
331impl<K: Hash + Eq + Clone + Send + Sync + 'static> KeyedRateLimiter<K> {
332 pub fn new(default_rps: u32, default_burst: u32) -> Self {
338 Self {
339 limiters: DashMap::new(),
340 default_rps,
341 default_burst,
342 }
343 }
344
345 pub fn set_limit(&self, key: K, rps: u32, burst: u32) {
347 self.limiters
348 .insert(key, Arc::new(RateLimiter::new(rps, burst)));
349 }
350
351 pub fn remove_limit(&self, key: &K) {
353 self.limiters.remove(key);
354 }
355
356 pub fn check(&self, key: &K) -> Result<(), RateLimitError> {
358 let limiter = self.get_or_create(key);
359 limiter.check()
360 }
361
362 pub async fn wait(&self, key: &K) {
364 let limiter = self.get_or_create(key);
365 limiter.wait().await
366 }
367
368 pub fn get_status(&self, key: &K) -> Option<RateLimiterStatus> {
370 self.limiters.get(key).map(|l| l.get_status())
371 }
372
373 pub fn get_all_status(&self) -> Vec<(K, RateLimiterStatus)> {
375 self.limiters
376 .iter()
377 .map(|entry| (entry.key().clone(), entry.value().get_status()))
378 .collect()
379 }
380
381 pub fn clear(&self) {
383 self.limiters.clear();
384 }
385
386 fn get_or_create(&self, key: &K) -> Arc<RateLimiter> {
387 self.limiters
388 .entry(key.clone())
389 .or_insert_with(|| Arc::new(RateLimiter::new(self.default_rps, self.default_burst)))
390 .clone()
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    /// The configured burst is admitted immediately.
    #[test]
    fn test_rate_limiter_basic() {
        let limiter = RateLimiter::new(10, 5);

        for _ in 0..5 {
            assert!(limiter.check().is_ok());
        }
    }

    /// The status snapshot reflects the configured quota.
    #[test]
    fn test_rate_limiter_status() {
        let limiter = RateLimiter::new(100, 10);
        let status = limiter.get_status();

        assert_eq!(status.max_rps, 100);
        assert_eq!(status.burst_size, 10);
    }

    /// Each consecutive external rate-limit report lowers the rate further.
    #[test]
    fn test_adaptive_rate_limiter_backoff() {
        let limiter = AdaptiveRateLimiter::new(100, 10).with_backoff_factor(0.5);

        limiter.record_rate_limit();
        let status1 = limiter.get_status();
        assert!(status1.current_rps < 100.0);

        limiter.record_rate_limit();
        let status2 = limiter.get_status();
        assert!(status2.current_rps < status1.current_rps);
    }

    /// After the recovery interval a reported success raises the rate again,
    /// without exceeding the base rate.
    #[test]
    fn test_adaptive_rate_limiter_recovery() {
        let limiter = AdaptiveRateLimiter::new(100, 10)
            .with_recovery_interval(Duration::from_millis(1))
            .with_backoff_factor(0.5);

        limiter.record_rate_limit();
        let reduced = limiter.get_status().current_rps;
        assert!(reduced < 100.0);

        // Let the (1 ms) recovery interval elapse before reporting success.
        std::thread::sleep(Duration::from_millis(10));
        limiter.record_success();

        let recovered = limiter.get_status().current_rps;
        assert!(recovered > reduced);
        assert!(recovered <= 100.0);
    }

    /// Keys are limited independently of each other.
    #[test]
    fn test_keyed_rate_limiter() {
        let limiter = KeyedRateLimiter::new(10, 5);

        for _ in 0..5 {
            assert!(limiter.check(&"key1").is_ok());
            assert!(limiter.check(&"key2").is_ok());
        }
    }

    /// A custom limit overrides the defaults for its key.
    #[test]
    fn test_keyed_rate_limiter_custom_limits() {
        let limiter = KeyedRateLimiter::new(10, 5);

        limiter.set_limit("premium", 100, 50);

        let status = limiter.get_status(&"premium").unwrap();
        assert_eq!(status.max_rps, 100);
        assert_eq!(status.burst_size, 50);
    }

    /// Every key that has been used shows up in the aggregate status.
    #[test]
    fn test_keyed_rate_limiter_all_status() {
        let limiter = KeyedRateLimiter::new(10, 5);

        limiter.check(&"a").ok();
        limiter.check(&"b").ok();
        limiter.check(&"c").ok();

        let all = limiter.get_all_status();
        assert_eq!(all.len(), 3);
    }

    /// The error message names the condition and the retry delay.
    #[test]
    fn test_rate_limit_error_display() {
        let err = RateLimitError {
            retry_after: Duration::from_secs(5),
        };
        let msg = err.to_string();
        assert!(msg.contains("rate limit exceeded"));
        assert!(msg.contains("5"));
    }

    /// Waiting within the burst capacity returns without noticeable delay.
    #[tokio::test]
    async fn test_rate_limiter_wait() {
        let limiter = RateLimiter::new(1000, 100);

        let start = Instant::now();
        for _ in 0..10 {
            limiter.wait().await;
        }
        assert!(start.elapsed() < Duration::from_secs(1));
    }
}