1use dashmap::DashMap;
7use serde::{Deserialize, Serialize};
8use std::time::{Duration, Instant};
9
/// Limits that a [`RateLimiter`] applies to every key.
///
/// The type derives `Serialize`/`Deserialize`, so it can be read from or
/// written to any serde-supported format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Refill rate of the per-key token bucket, in tokens per second.
    pub requests_per_second: u32,
    /// Maximum number of requests admitted within the sliding minute window.
    pub requests_per_minute: u32,
    /// Maximum number of requests admitted within the sliding hour window.
    pub requests_per_hour: u32,
    /// Capacity of the per-key token bucket; a new bucket starts full.
    pub burst_size: u32,
    /// Upper bound on the number of timestamps kept for the minute window.
    /// Falls back to `default_minute_capacity()` when the field is missing
    /// from the deserialized input.
    #[serde(default = "default_minute_capacity")]
    pub minute_window_capacity: usize,
    /// Upper bound on the number of timestamps kept for the hour window.
    /// Falls back to `default_hour_capacity()` when the field is missing
    /// from the deserialized input.
    #[serde(default = "default_hour_capacity")]
    pub hour_window_capacity: usize,
}
28
/// Fallback for `RateLimitConfig::minute_window_capacity`, referenced by name
/// from the field's `#[serde(default = "...")]` attribute.
fn default_minute_capacity() -> usize {
    const MINUTE_WINDOW_SLOTS: usize = 1_000;
    MINUTE_WINDOW_SLOTS
}
/// Fallback for `RateLimitConfig::hour_window_capacity`, referenced by name
/// from the field's `#[serde(default = "...")]` attribute.
fn default_hour_capacity() -> usize {
    const HOUR_WINDOW_SLOTS: usize = 10_000;
    HOUR_WINDOW_SLOTS
}
35
36impl Default for RateLimitConfig {
37 fn default() -> Self {
38 Self {
39 requests_per_second: 10,
40 requests_per_minute: 100,
41 requests_per_hour: 1000,
42 burst_size: 20,
43 minute_window_capacity: 1000,
44 hour_window_capacity: 10000,
45 }
46 }
47}
48
/// Token bucket that refills continuously at a fixed rate up to a cap.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available; fractional while a token is refilling.
    tokens: f64,
    /// Maximum number of tokens the bucket can hold.
    max_tokens: f64,
    /// Tokens added per second of elapsed time.
    refill_rate: f64,
    /// Instant up to which refill has already been credited.
    last_update: Instant,
}

impl TokenBucket {
    /// Creates a full bucket holding `max_tokens` that refills at
    /// `refill_rate` tokens per second.
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_update: Instant::now(),
        }
    }

    /// Credits the refill accrued since the last call, then tries to remove
    /// `tokens` from the bucket.
    ///
    /// Returns `true` and deducts the tokens when enough are available;
    /// returns `false` and leaves the (refilled) balance untouched otherwise.
    fn try_take(&mut self, tokens: f64) -> bool {
        // Read the clock exactly once. Measuring the elapsed time and then
        // calling `Instant::now()` again for `last_update` would silently
        // discard the time that passes between the two reads on every call.
        let now = Instant::now();
        let elapsed = now
            .saturating_duration_since(self.last_update)
            .as_secs_f64();

        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_update = now;

        if self.tokens >= tokens {
            self.tokens -= tokens;
            true
        } else {
            false
        }
    }
}
87
/// Per-key rate limiter combining a token bucket with sliding minute and
/// hour windows.
pub struct RateLimiter {
    /// Limits applied to every key.
    config: RateLimitConfig,
    /// One token bucket per key, created on the key's first `check`.
    buckets: DashMap<String, TokenBucket>,
    /// One sliding-window counter per key, created the first time a request
    /// for that key gets past the token bucket.
    counters: DashMap<String, SlidingWindowCounter>,
}
97
/// Timestamps of admitted requests, tracked separately for a 60-second and a
/// 3600-second sliding window.
#[derive(Debug)]
struct SlidingWindowCounter {
    /// Admission times considered for the minute window, oldest first.
    minute_requests: Vec<Instant>,
    /// Admission times considered for the hour window, oldest first.
    hour_requests: Vec<Instant>,
    /// Maximum number of timestamps kept in `minute_requests`.
    minute_capacity: usize,
    /// Maximum number of timestamps kept in `hour_requests`.
    hour_capacity: usize,
}

impl SlidingWindowCounter {
    /// Length of the minute window.
    const MINUTE_WINDOW: Duration = Duration::from_secs(60);
    /// Length of the hour window.
    const HOUR_WINDOW: Duration = Duration::from_secs(3600);

    /// Creates an empty counter that keeps at most `minute_capacity` and
    /// `hour_capacity` timestamps in the respective windows.
    fn with_capacity(minute_capacity: usize, hour_capacity: usize) -> Self {
        Self {
            minute_requests: Vec::new(),
            hour_requests: Vec::new(),
            minute_capacity,
            hour_capacity,
        }
    }

    /// Number of `timestamps` that are younger than `window` as seen from `now`.
    fn live_count(timestamps: &[Instant], now: Instant, window: Duration) -> usize {
        timestamps
            .iter()
            .filter(|stamp| now.saturating_duration_since(**stamp) < window)
            .count()
    }

    /// Removes the oldest entries until at most `capacity` remain.
    fn drop_oldest_beyond(timestamps: &mut Vec<Instant>, capacity: usize) {
        let excess = timestamps.len().saturating_sub(capacity);
        timestamps.drain(..excess);
    }

    /// Records a request at the current instant, discards timestamps that
    /// have left their window, and enforces the capacity bounds by dropping
    /// the oldest entries.
    fn add_request(&mut self) {
        // One clock read keeps every comparison in this call consistent and
        // avoids a syscall per stored timestamp.
        let now = Instant::now();
        self.minute_requests.push(now);
        self.hour_requests.push(now);

        self.minute_requests
            .retain(|stamp| now.saturating_duration_since(*stamp) < Self::MINUTE_WINDOW);
        self.hour_requests
            .retain(|stamp| now.saturating_duration_since(*stamp) < Self::HOUR_WINDOW);

        Self::drop_oldest_beyond(&mut self.minute_requests, self.minute_capacity);
        Self::drop_oldest_beyond(&mut self.hour_requests, self.hour_capacity);
    }

    /// Number of recorded requests that are still inside the minute window.
    ///
    /// Expired timestamps are ignored even if they have not been pruned yet.
    /// Pruning only happens in `add_request`, so returning the raw length
    /// would keep a saturated key at its limit forever.
    fn minute_count(&self) -> usize {
        Self::live_count(&self.minute_requests, Instant::now(), Self::MINUTE_WINDOW)
    }

    /// Number of recorded requests that are still inside the hour window.
    ///
    /// Expired timestamps are ignored even if they have not been pruned yet.
    fn hour_count(&self) -> usize {
        Self::live_count(&self.hour_requests, Instant::now(), Self::HOUR_WINDOW)
    }
}
151
152impl RateLimiter {
153 pub fn new() -> Self {
154 Self::with_config(RateLimitConfig::default())
155 }
156
157 pub fn with_config(config: RateLimitConfig) -> Self {
158 Self {
159 config: config.clone(),
160 buckets: DashMap::new(),
161 counters: DashMap::new(),
162 }
163 }
164
165 pub async fn check(&self, key: &str) -> anyhow::Result<bool> {
167 let bucket_result = {
169 let mut bucket = self.buckets.entry(key.to_string()).or_insert_with(|| {
170 TokenBucket::new(
171 self.config.burst_size as f64,
172 self.config.requests_per_second as f64,
173 )
174 });
175 bucket.try_take(1.0)
176 };
177
178 if !bucket_result {
179 return Ok(false);
180 }
181
182 let window_result = {
184 let minute_cap = self.config.minute_window_capacity;
185 let hour_cap = self.config.hour_window_capacity;
186 let mut counter = self
187 .counters
188 .entry(key.to_string())
189 .or_insert_with(|| SlidingWindowCounter::with_capacity(minute_cap, hour_cap));
190
191 let minute_exceeded =
192 counter.minute_count() >= self.config.requests_per_minute as usize;
193 let hour_exceeded = counter.hour_count() >= self.config.requests_per_hour as usize;
194
195 if minute_exceeded || hour_exceeded {
196 false
197 } else {
198 counter.add_request();
199 true
200 }
201 };
202
203 Ok(window_result)
204 }
205
206 pub fn reset(&self, key: &str) {
208 self.buckets.remove(key);
209 self.counters.remove(key);
210 }
211
212 pub fn get_status(&self, key: &str) -> RateLimitStatus {
214 let tokens_remaining = self
215 .buckets
216 .get(key)
217 .map(|b| b.tokens as u32)
218 .unwrap_or(self.config.burst_size);
219
220 let minute_remaining = self.config.requests_per_minute
221 - self
222 .counters
223 .get(key)
224 .map(|c| c.minute_count() as u32)
225 .unwrap_or(0);
226
227 let hour_remaining = self.config.requests_per_hour
228 - self
229 .counters
230 .get(key)
231 .map(|c| c.hour_count() as u32)
232 .unwrap_or(0);
233
234 RateLimitStatus {
235 tokens_remaining,
236 minute_remaining,
237 hour_remaining,
238 }
239 }
240
241 pub fn cleanup_expired(&self, max_age: Duration) {
243 let now = Instant::now();
244
245 self.buckets
247 .retain(|_, bucket| now.duration_since(bucket.last_update) < max_age);
248
249 self.counters.retain(|_, counter| {
251 !counter.minute_requests.is_empty() || !counter.hour_requests.is_empty()
252 });
253 }
254
255 pub fn active_keys(&self) -> usize {
257 self.buckets.len()
258 }
259}
260
/// Snapshot of the quota a key has left, as returned by
/// `RateLimiter::get_status`.
#[derive(Debug, Serialize, Deserialize)]
pub struct RateLimitStatus {
    /// Tokens available in the key's bucket, truncated to a whole number;
    /// equals the configured `burst_size` for a key with no bucket.
    pub tokens_remaining: u32,
    /// Requests still allowed under the per-minute limit.
    pub minute_remaining: u32,
    /// Requests still allowed under the per-hour limit.
    pub hour_remaining: u32,
}
268
269impl Default for RateLimiter {
270 fn default() -> Self {
271 Self::new()
272 }
273}
274
// Unit tests for the rate limiter. Several of them depend on wall-clock
// timing (token refill while the test runs), so their bounds are deliberately
// loose.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    // The default config (burst of 20) admits ten back-to-back requests.
    #[tokio::test]
    async fn test_basic_rate_limit() {
        let limiter = RateLimiter::new();

        for _ in 0..10 {
            assert!(limiter.check("test_key").await.unwrap());
        }
    }

    // With a burst of 2 and a per-minute limit of 2, the third immediate
    // request is rejected.
    #[tokio::test]
    async fn test_rate_limit_exceeded() {
        let config = RateLimitConfig {
            requests_per_second: 1,
            requests_per_minute: 2,
            requests_per_hour: 3,
            burst_size: 2,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        assert!(limiter.check("test_key").await.unwrap());
        assert!(limiter.check("test_key").await.unwrap());

        assert!(!limiter.check("test_key").await.unwrap());
    }

    // Spawns 100 tasks hitting the same key; only checks that the limiter
    // stays usable under concurrency and admits at least one request.
    #[tokio::test]
    async fn test_concurrent_requests() {
        let config = RateLimitConfig {
            requests_per_second: 100,
            requests_per_minute: 1000,
            requests_per_hour: 10000,
            burst_size: 50,
            ..Default::default()
        };
        let limiter = Arc::new(RateLimiter::with_config(config));

        let mut tasks = vec![];

        for _ in 0..100 {
            let limiter_clone = Arc::clone(&limiter);
            tasks.push(tokio::spawn(async move {
                limiter_clone.check("concurrent_key").await.unwrap()
            }));
        }

        let results: Vec<bool> = futures::future::join_all(tasks)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        let success_count = results.iter().filter(|&&r| r).count();
        let fail_count = results.iter().filter(|&&r| !r).count();

        assert!(success_count > 0, "At least some requests should succeed");
        println!("Success: {}, Fail: {}", success_count, fail_count);
    }

    // Fires 20 immediate requests against a burst of 10; the upper bound of
    // 11 leaves room for one token refilled while the loop runs.
    #[tokio::test]
    async fn test_burst_handling() {
        let config = RateLimitConfig {
            requests_per_second: 5,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 10,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        let mut success_count = 0;
        for _ in 0..20 {
            if limiter.check("burst_key").await.unwrap() {
                success_count += 1;
            }
        }

        assert!(
            success_count <= 11,
            "Burst should be limited, but got {} successes",
            success_count
        );
        // NOTE(review): the message says "at least burst_size" (10) but the
        // asserted threshold is 8 — confirm which one is intended.
        assert!(
            success_count >= 8,
            "At least burst_size requests should succeed, but got {}",
            success_count
        );
    }

    // Drains a burst of 5, then waits 150 ms at 10 tokens/second so that at
    // least one token has been refilled.
    #[tokio::test]
    async fn test_token_refill_accuracy() {
        let config = RateLimitConfig {
            requests_per_second: 10,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 5,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        for _ in 0..5 {
            assert!(limiter.check("refill_key").await.unwrap());
        }

        assert!(!limiter.check("refill_key").await.unwrap());

        tokio::time::sleep(tokio::time::Duration::from_millis(150)).await;

        assert!(
            limiter.check("refill_key").await.unwrap(),
            "Token should be refilled after waiting"
        );
    }

    // Exhausting one key must not affect another key.
    #[tokio::test]
    async fn test_different_keys_isolated() {
        let config = RateLimitConfig {
            requests_per_second: 1,
            requests_per_minute: 1,
            requests_per_hour: 1,
            burst_size: 1,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        assert!(limiter.check("key1").await.unwrap());
        assert!(!limiter.check("key1").await.unwrap());

        assert!(limiter.check("key2").await.unwrap());
        assert!(!limiter.check("key2").await.unwrap());
    }

    // After `reset`, an exhausted key is admitted again.
    #[test]
    fn test_reset_functionality() {
        let config = RateLimitConfig {
            requests_per_second: 1,
            requests_per_minute: 1,
            requests_per_hour: 1,
            burst_size: 1,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            assert!(limiter.check("reset_key").await.unwrap());
            assert!(!limiter.check("reset_key").await.unwrap());
        });

        limiter.reset("reset_key");

        rt.block_on(async {
            assert!(limiter.check("reset_key").await.unwrap());
        });
    }

    // After five requests, the status shows consumed tokens and a reduced
    // per-minute allowance.
    #[test]
    fn test_status_reporting() {
        let config = RateLimitConfig {
            requests_per_second: 10,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 20,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            for _ in 0..5 {
                limiter.check("status_key").await.unwrap();
            }
        });

        let status = limiter.get_status("status_key");
        assert!(status.tokens_remaining < 20, "Tokens should be consumed");
        assert!(
            status.minute_remaining < 100,
            "Minute count should increase"
        );
    }

    // A `max_age` of zero removes every bucket, because no bucket age is
    // strictly less than zero. `active_keys` only counts buckets.
    #[test]
    fn test_cleanup_expired() {
        let limiter = RateLimiter::new();

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            limiter.check("key1").await.unwrap();
            limiter.check("key2").await.unwrap();
        });

        assert!(limiter.active_keys() >= 2);

        limiter.cleanup_expired(Duration::from_secs(0));

        assert_eq!(limiter.active_keys(), 0);
    }

    // Each distinct key that is checked gets its own bucket.
    #[test]
    fn test_active_keys_count() {
        let limiter = RateLimiter::new();

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            limiter.check("key1").await.unwrap();
            limiter.check("key2").await.unwrap();
            limiter.check("key3").await.unwrap();
        });

        assert_eq!(limiter.active_keys(), 3);
    }

    // With a refill rate of zero, only the initial burst is ever available.
    #[test]
    fn test_zero_rate_limit() {
        let config = RateLimitConfig {
            requests_per_second: 0,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 2,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let first = limiter.check("key").await.unwrap();
            assert!(first, "First request with burst_size=2 should succeed");

            let second = limiter.check("key").await.unwrap();
            assert!(second, "Second request with burst_size=2 should succeed");

            let third = limiter.check("key").await.unwrap();
            assert!(!third, "Third request should be rate limited (no refill)");
        });
    }

    // A burst of one admits exactly one immediate request.
    #[test]
    fn test_very_small_burst_size() {
        let config = RateLimitConfig {
            requests_per_second: 1,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 1,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            assert!(limiter.check("key").await.unwrap());
            assert!(!limiter.check("key").await.unwrap());
        });
    }

    // 500 immediate requests fit inside a burst of 1000; the test only
    // requires that at least 400 are admitted.
    #[test]
    fn test_large_burst_size() {
        let config = RateLimitConfig {
            requests_per_second: 1000,
            requests_per_minute: 100000,
            requests_per_hour: 1000000,
            burst_size: 1000,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let mut success_count = 0;
            for _ in 0..500 {
                if limiter.check("key").await.unwrap() {
                    success_count += 1;
                }
            }
            assert!(
                success_count >= 400,
                "Should allow most requests with large burst"
            );
        });
    }

    // The empty string is accepted as a key.
    #[test]
    fn test_empty_key() {
        let limiter = RateLimiter::new();

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            assert!(limiter.check("").await.unwrap());
        });
    }

    // Keys containing punctuation are accepted as-is.
    #[test]
    fn test_special_characters_in_key() {
        let limiter = RateLimiter::new();

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let special_keys = vec![
                "key:with:colons",
                "key-with-dashes",
                "key_with_underscores",
                "key.with.dots",
                "key/with/slashes",
            ];
            for key in special_keys {
                assert!(
                    limiter.check(key).await.unwrap(),
                    "Key '{}' should work",
                    key
                );
            }
        });
    }

    // Non-ASCII keys (CJK text, emoji) are accepted.
    #[test]
    fn test_unicode_key() {
        let limiter = RateLimiter::new();

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            assert!(limiter.check("用户_123").await.unwrap());
            assert!(limiter.check("🔑_key").await.unwrap());
        });
    }

    // A 10,000-character key is accepted.
    #[test]
    fn test_very_long_key() {
        let limiter = RateLimiter::new();
        let long_key = "a".repeat(10000);

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            assert!(limiter.check(&long_key).await.unwrap());
        });
    }

    // Resetting an unknown key is a no-op and creates no state.
    #[test]
    fn test_reset_nonexistent_key() {
        let limiter = RateLimiter::new();

        limiter.reset("nonexistent_key");
        assert_eq!(limiter.active_keys(), 0);
    }

    // An unknown key reports the full configured limits.
    #[test]
    fn test_status_nonexistent_key() {
        let limiter = RateLimiter::new();
        let config = RateLimitConfig::default();

        let status = limiter.get_status("nonexistent");
        assert_eq!(status.tokens_remaining, config.burst_size);
        assert_eq!(status.minute_remaining, config.requests_per_minute);
        assert_eq!(status.hour_remaining, config.requests_per_hour);
    }

    // 20 immediate requests against a burst of 5; the bound of 7 leaves room
    // for tokens refilled at 10/second while the loop runs.
    #[tokio::test]
    async fn test_rapid_requests() {
        let config = RateLimitConfig {
            requests_per_second: 10,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 5,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        let mut success_count = 0;
        for _ in 0..20 {
            if limiter.check("rapid").await.unwrap() {
                success_count += 1;
            }
        }

        assert!(
            success_count <= 7,
            "Expected ~5 successful requests, got {}",
            success_count
        );
    }

    // NOTE(review): despite the name, `Duration` cannot be negative; this
    // exercises the largest whole-second `max_age`, under which no bucket is
    // old enough to be removed. Consider renaming the test.
    #[test]
    fn test_cleanup_with_negative_duration() {
        let limiter = RateLimiter::new();

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            limiter.check("key").await.unwrap();
        });

        limiter.cleanup_expired(Duration::from_secs(u64::MAX));

        assert!(limiter.active_keys() >= 1);
    }

    // After three requests against a burst of 10, some but not all tokens
    // remain.
    #[tokio::test]
    async fn test_status_accuracy() {
        let config = RateLimitConfig {
            requests_per_second: 10,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 10,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        for _ in 0..3 {
            limiter.check("status_test").await.unwrap();
        }

        let status = limiter.get_status("status_test");
        assert!(status.tokens_remaining < 10);
        assert!(status.tokens_remaining > 0);
    }

    // Pins the documented default limits.
    #[test]
    fn test_config_default_values() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_second, 10);
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.requests_per_hour, 1000);
        assert_eq!(config.burst_size, 20);
    }

    // The config survives a JSON round trip.
    #[test]
    fn test_config_serialization() {
        let config = RateLimitConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RateLimitConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.requests_per_second, config.requests_per_second);
    }

    // At 100 tokens/second, a 15 ms pause refills at least one token after
    // the burst of 10 has been drained.
    #[tokio::test]
    async fn test_token_refill_boundary() {
        let config = RateLimitConfig {
            requests_per_second: 100, requests_per_minute: 10000,
            requests_per_hour: 100000,
            burst_size: 10,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        for _ in 0..10 {
            limiter.check("refill_boundary").await.unwrap();
        }

        assert!(!limiter.check("refill_boundary").await.unwrap());

        tokio::time::sleep(tokio::time::Duration::from_millis(15)).await;

        assert!(limiter.check("refill_boundary").await.unwrap());
    }
}