Skip to main content

sh_layer0/
rate_limiter.rs

1//! 速率限制模块
2//!
3//! Token Bucket 和滑动窗口算法。
4//! 使用 DashMap 实现高性能并发访问。
5
6use dashmap::DashMap;
7use serde::{Deserialize, Serialize};
8use std::time::{Duration, Instant};
9
10/// 速率限制配置
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct RateLimitConfig {
13    /// 每秒允许的请求数
14    pub requests_per_second: u32,
15    /// 每分钟允许的请求数
16    pub requests_per_minute: u32,
17    /// 每小时允许的请求数
18    pub requests_per_hour: u32,
19    /// Burst 大小(突发流量)
20    pub burst_size: u32,
21    /// 分钟窗口容量上限(防止内存无限增长)
22    #[serde(default = "default_minute_capacity")]
23    pub minute_window_capacity: usize,
24    /// 小时窗口容量上限
25    #[serde(default = "default_hour_capacity")]
26    pub hour_window_capacity: usize,
27}
28
/// Serde fallback for `RateLimitConfig::minute_window_capacity`.
fn default_minute_capacity() -> usize {
    1_000
}

/// Serde fallback for `RateLimitConfig::hour_window_capacity`.
fn default_hour_capacity() -> usize {
    10_000
}
35
36impl Default for RateLimitConfig {
37    fn default() -> Self {
38        Self {
39            requests_per_second: 10,
40            requests_per_minute: 100,
41            requests_per_hour: 1000,
42            burst_size: 20,
43            minute_window_capacity: 1000,
44            hour_window_capacity: 10000,
45        }
46    }
47}
48
/// Per-key token bucket state.
///
/// Holds up to `max_tokens` tokens and regains `refill_rate` tokens per second
/// of elapsed time; the refill is applied lazily on each [`TokenBucket::try_take`].
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available (may be fractional between refills).
    tokens: f64,
    /// Bucket capacity; a refill never raises `tokens` above this.
    max_tokens: f64,
    /// Tokens regained per second.
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_update: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full (`tokens == max_tokens`).
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_update: Instant::now(),
        }
    }

    /// Applies the refill accrued since the last update, then removes
    /// `tokens` if enough are available.
    ///
    /// Returns `true` if the tokens were taken; on `false` only the refill
    /// has been applied.
    fn try_take(&mut self, tokens: f64) -> bool {
        // Read the clock exactly once. Calling `elapsed()` and then a second
        // `Instant::now()` discards the time between the two reads, which
        // makes the bucket slowly under-refill.
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_update).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_update = now;

        if self.tokens >= tokens {
            self.tokens -= tokens;
            true
        } else {
            false
        }
    }
}
87
88/// 速率限制器(高性能版本,使用 DashMap)
89pub struct RateLimiter {
90    /// 配置
91    config: RateLimitConfig,
92    /// 每个 key 的 token bucket(并发安全)
93    buckets: DashMap<String, TokenBucket>,
94    /// 每个 key 的请求计数(滑动窗口,并发安全)
95    counters: DashMap<String, SlidingWindowCounter>,
96}
97
/// Sliding-window request log for a single key.
///
/// Stores the `Instant` of each recorded request, in recording order, for the
/// minute window and the hour window, and counts only entries still inside
/// their window.
#[derive(Debug)]
struct SlidingWindowCounter {
    /// Timestamps of recorded requests (minute window), oldest first.
    minute_requests: Vec<Instant>,
    /// Timestamps of recorded requests (hour window), oldest first.
    hour_requests: Vec<Instant>,
    /// Maximum number of entries kept in `minute_requests`.
    minute_capacity: usize,
    /// Maximum number of entries kept in `hour_requests`.
    hour_capacity: usize,
}

impl SlidingWindowCounter {
    /// Span of the short window.
    const MINUTE: Duration = Duration::from_secs(60);
    /// Span of the long window.
    const HOUR: Duration = Duration::from_secs(3600);

    /// Creates an empty counter with the given per-window entry caps.
    fn with_capacity(minute_capacity: usize, hour_capacity: usize) -> Self {
        Self {
            minute_requests: Vec::new(),
            hour_requests: Vec::new(),
            minute_capacity,
            hour_capacity,
        }
    }

    /// Records a request at the current instant.
    ///
    /// Expired entries are pruned first. If a window still exceeds its
    /// capacity, its oldest entries are discarded. When a capacity is below
    /// the matching configured limit, the limit therefore cannot be reached.
    fn add_request(&mut self) {
        let now = Instant::now();
        self.prune(now);
        self.minute_requests.push(now);
        self.hour_requests.push(now);
        Self::keep_newest(&mut self.minute_requests, self.minute_capacity);
        Self::keep_newest(&mut self.hour_requests, self.hour_capacity);
    }

    /// Drops timestamps that are no longer inside their window as of `now`.
    fn prune(&mut self, now: Instant) {
        self.minute_requests
            .retain(|&t| now.duration_since(t) < Self::MINUTE);
        self.hour_requests
            .retain(|&t| now.duration_since(t) < Self::HOUR);
    }

    /// Removes entries from the front of `log` until at most `capacity` remain.
    fn keep_newest(log: &mut Vec<Instant>, capacity: usize) {
        let excess = log.len().saturating_sub(capacity);
        log.drain(..excess);
    }

    /// Number of recorded requests within the last 60 seconds.
    ///
    /// Counts live entries instead of `len()`. Pruning only happens when a
    /// request is recorded, and a rejected request is not recorded. With
    /// `len()`, a key that reached its limit would keep seeing its stale
    /// history and stay blocked forever.
    fn minute_count(&self) -> usize {
        Self::live_entries(&self.minute_requests, Self::MINUTE)
    }

    /// Number of recorded requests within the last 3600 seconds (live
    /// entries only, see [`SlidingWindowCounter::minute_count`]).
    fn hour_count(&self) -> usize {
        Self::live_entries(&self.hour_requests, Self::HOUR)
    }

    /// Counts timestamps in `log` younger than `window`.
    fn live_entries(log: &[Instant], window: Duration) -> usize {
        let now = Instant::now();
        log.iter()
            .filter(|&&t| now.duration_since(t) < window)
            .count()
    }
}
151
152impl RateLimiter {
153    pub fn new() -> Self {
154        Self::with_config(RateLimitConfig::default())
155    }
156
157    pub fn with_config(config: RateLimitConfig) -> Self {
158        Self {
159            config: config.clone(),
160            buckets: DashMap::new(),
161            counters: DashMap::new(),
162        }
163    }
164
165    /// 检查是否允许请求
166    pub async fn check(&self, key: &str) -> anyhow::Result<bool> {
167        // 检查 Token Bucket(秒级限制)
168        let bucket_result = {
169            let mut bucket = self.buckets.entry(key.to_string()).or_insert_with(|| {
170                TokenBucket::new(
171                    self.config.burst_size as f64,
172                    self.config.requests_per_second as f64,
173                )
174            });
175            bucket.try_take(1.0)
176        };
177
178        if !bucket_result {
179            return Ok(false);
180        }
181
182        // 检查滑动窗口(分钟和小时限制)
183        let window_result = {
184            let minute_cap = self.config.minute_window_capacity;
185            let hour_cap = self.config.hour_window_capacity;
186            let mut counter = self
187                .counters
188                .entry(key.to_string())
189                .or_insert_with(|| SlidingWindowCounter::with_capacity(minute_cap, hour_cap));
190
191            let minute_exceeded =
192                counter.minute_count() >= self.config.requests_per_minute as usize;
193            let hour_exceeded = counter.hour_count() >= self.config.requests_per_hour as usize;
194
195            if minute_exceeded || hour_exceeded {
196                false
197            } else {
198                counter.add_request();
199                true
200            }
201        };
202
203        Ok(window_result)
204    }
205
206    /// 重置指定 key 的限制
207    pub fn reset(&self, key: &str) {
208        self.buckets.remove(key);
209        self.counters.remove(key);
210    }
211
212    /// 获取指定 key 的状态
213    pub fn get_status(&self, key: &str) -> RateLimitStatus {
214        let tokens_remaining = self
215            .buckets
216            .get(key)
217            .map(|b| b.tokens as u32)
218            .unwrap_or(self.config.burst_size);
219
220        let minute_remaining = self.config.requests_per_minute
221            - self
222                .counters
223                .get(key)
224                .map(|c| c.minute_count() as u32)
225                .unwrap_or(0);
226
227        let hour_remaining = self.config.requests_per_hour
228            - self
229                .counters
230                .get(key)
231                .map(|c| c.hour_count() as u32)
232                .unwrap_or(0);
233
234        RateLimitStatus {
235            tokens_remaining,
236            minute_remaining,
237            hour_remaining,
238        }
239    }
240
241    /// 清理过期的条目(定期维护)
242    pub fn cleanup_expired(&self, max_age: Duration) {
243        let now = Instant::now();
244
245        // 清理过期的 buckets
246        self.buckets
247            .retain(|_, bucket| now.duration_since(bucket.last_update) < max_age);
248
249        // 清理空的 counters
250        self.counters.retain(|_, counter| {
251            !counter.minute_requests.is_empty() || !counter.hour_requests.is_empty()
252        });
253    }
254
255    /// 获取当前活跃的 key 数量
256    pub fn active_keys(&self) -> usize {
257        self.buckets.len()
258    }
259}
260
261/// 速率限制状态
262#[derive(Debug, Serialize, Deserialize)]
263pub struct RateLimitStatus {
264    pub tokens_remaining: u32,
265    pub minute_remaining: u32,
266    pub hour_remaining: u32,
267}
268
269impl Default for RateLimiter {
270    fn default() -> Self {
271        Self::new()
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Config with the given rate/window limits and burst; window capacities
    /// keep their defaults.
    fn limits(per_second: u32, per_minute: u32, per_hour: u32, burst: u32) -> RateLimitConfig {
        RateLimitConfig {
            requests_per_second: per_second,
            requests_per_minute: per_minute,
            requests_per_hour: per_hour,
            burst_size: burst,
            ..Default::default()
        }
    }

    /// Drives `fut` to completion on a freshly built Tokio runtime (for the
    /// synchronous `#[test]` cases).
    fn run<F: std::future::Future>(fut: F) -> F::Output {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(fut)
    }

    /// Sends `attempts` back-to-back checks for `key`; returns how many passed.
    async fn count_allowed(limiter: &RateLimiter, key: &str, attempts: usize) -> usize {
        let mut allowed = 0;
        for _ in 0..attempts {
            if limiter.check(key).await.unwrap() {
                allowed += 1;
            }
        }
        allowed
    }

    #[tokio::test]
    async fn test_basic_rate_limit() {
        let limiter = RateLimiter::new();

        // The default burst is 20, so the first ten requests all pass.
        for _ in 0..10 {
            assert!(limiter.check("test_key").await.unwrap());
        }
    }

    #[tokio::test]
    async fn test_rate_limit_exceeded() {
        let limiter = RateLimiter::with_config(limits(1, 2, 3, 2));

        // The burst of two is admitted...
        for _ in 0..2 {
            assert!(limiter.check("test_key").await.unwrap());
        }
        // ...and an immediate third request is not.
        assert!(!limiter.check("test_key").await.unwrap());
    }

    #[tokio::test]
    async fn test_concurrent_requests() {
        let limiter = Arc::new(RateLimiter::with_config(limits(100, 1000, 10000, 50)));

        let handles: Vec<_> = (0..100)
            .map(|_| {
                let limiter = Arc::clone(&limiter);
                tokio::spawn(async move { limiter.check("concurrent_key").await.unwrap() })
            })
            .collect();

        let outcomes: Vec<bool> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|joined| joined.unwrap())
            .collect();

        let success_count = outcomes.iter().filter(|ok| **ok).count();
        let fail_count = outcomes.len() - success_count;

        // With a burst of 50, some requests should get through.
        assert!(success_count > 0, "At least some requests should succeed");
        println!("Success: {}, Fail: {}", success_count, fail_count);
    }

    #[tokio::test]
    async fn test_burst_handling() {
        let limiter = RateLimiter::with_config(limits(5, 100, 1000, 10));

        // Twenty rapid requests should be capped near the burst size of 10.
        let success_count = count_allowed(&limiter, "burst_key", 20).await;

        assert!(
            success_count <= 11,
            "Burst should be limited, but got {} successes",
            success_count
        );
        assert!(
            success_count >= 8,
            "At least burst_size requests should succeed, but got {}",
            success_count
        );
    }

    #[tokio::test]
    async fn test_token_refill_accuracy() {
        let limiter = RateLimiter::with_config(limits(10, 100, 1000, 5));

        // Spend the whole burst, then confirm the bucket is empty.
        for _ in 0..5 {
            assert!(limiter.check("refill_key").await.unwrap());
        }
        assert!(!limiter.check("refill_key").await.unwrap());

        // At 10 tokens/s, 150 ms restores roughly 1.5 tokens.
        tokio::time::sleep(tokio::time::Duration::from_millis(150)).await;

        assert!(
            limiter.check("refill_key").await.unwrap(),
            "Token should be refilled after waiting"
        );
    }

    #[tokio::test]
    async fn test_different_keys_isolated() {
        let limiter = RateLimiter::with_config(limits(1, 1, 1, 1));

        // Each key gets its own single-token budget.
        for key in ["key1", "key2"] {
            assert!(limiter.check(key).await.unwrap());
            assert!(!limiter.check(key).await.unwrap());
        }
    }

    #[test]
    fn test_reset_functionality() {
        let limiter = RateLimiter::with_config(limits(1, 1, 1, 1));

        run(async {
            assert!(limiter.check("reset_key").await.unwrap());
            assert!(!limiter.check("reset_key").await.unwrap());
        });

        limiter.reset("reset_key");

        // After a reset the key starts over with a full bucket.
        run(async {
            assert!(limiter.check("reset_key").await.unwrap());
        });
    }

    #[test]
    fn test_status_reporting() {
        let limiter = RateLimiter::with_config(limits(10, 100, 1000, 20));

        run(async {
            for _ in 0..5 {
                limiter.check("status_key").await.unwrap();
            }
        });

        let status = limiter.get_status("status_key");
        assert!(status.tokens_remaining < 20, "Tokens should be consumed");
        assert!(
            status.minute_remaining < 100,
            "Minute count should increase"
        );
    }

    #[test]
    fn test_cleanup_expired() {
        let limiter = RateLimiter::new();

        run(async {
            for key in ["key1", "key2"] {
                limiter.check(key).await.unwrap();
            }
        });
        assert!(limiter.active_keys() >= 2);

        // A zero max age expires every bucket immediately.
        limiter.cleanup_expired(Duration::ZERO);
        assert_eq!(limiter.active_keys(), 0);
    }

    #[test]
    fn test_active_keys_count() {
        let limiter = RateLimiter::new();

        run(async {
            for key in ["key1", "key2", "key3"] {
                limiter.check(key).await.unwrap();
            }
        });

        assert_eq!(limiter.active_keys(), 3);
    }

    // ---------- edge cases ----------

    #[test]
    fn test_zero_rate_limit() {
        // No refill (0 req/s) but generous windows: only the initial burst of
        // two tokens is available.
        let limiter = RateLimiter::with_config(limits(0, 100, 1000, 2));

        run(async {
            let outcomes = [
                limiter.check("key").await.unwrap(),
                limiter.check("key").await.unwrap(),
                limiter.check("key").await.unwrap(),
            ];
            assert!(outcomes[0], "First request with burst_size=2 should succeed");
            assert!(outcomes[1], "Second request with burst_size=2 should succeed");
            assert!(!outcomes[2], "Third request should be rate limited (no refill)");
        });
    }

    #[test]
    fn test_very_small_burst_size() {
        let limiter = RateLimiter::with_config(limits(1, 100, 1000, 1));

        run(async {
            assert!(limiter.check("key").await.unwrap());
            assert!(!limiter.check("key").await.unwrap());
        });
    }

    #[test]
    fn test_large_burst_size() {
        let limiter = RateLimiter::with_config(limits(1000, 100000, 1000000, 1000));

        run(async {
            let success_count = count_allowed(&limiter, "key", 500).await;
            assert!(
                success_count >= 400,
                "Should allow most requests with large burst"
            );
        });
    }

    #[test]
    fn test_empty_key() {
        let limiter = RateLimiter::new();

        // The empty string is an ordinary key.
        run(async {
            assert!(limiter.check("").await.unwrap());
        });
    }

    #[test]
    fn test_special_characters_in_key() {
        let limiter = RateLimiter::new();
        let special_keys = [
            "key:with:colons",
            "key-with-dashes",
            "key_with_underscores",
            "key.with.dots",
            "key/with/slashes",
        ];

        run(async {
            for key in special_keys {
                assert!(
                    limiter.check(key).await.unwrap(),
                    "Key '{}' should work",
                    key
                );
            }
        });
    }

    #[test]
    fn test_unicode_key() {
        let limiter = RateLimiter::new();

        run(async {
            for key in ["用户_123", "🔑_key"] {
                assert!(limiter.check(key).await.unwrap());
            }
        });
    }

    #[test]
    fn test_very_long_key() {
        let limiter = RateLimiter::new();
        let long_key: String = std::iter::repeat('a').take(10000).collect();

        run(async {
            assert!(limiter.check(&long_key).await.unwrap());
        });
    }

    #[test]
    fn test_reset_nonexistent_key() {
        let limiter = RateLimiter::new();

        // Resetting an unknown key is a harmless no-op.
        limiter.reset("nonexistent_key");
        assert_eq!(limiter.active_keys(), 0);
    }

    #[test]
    fn test_status_nonexistent_key() {
        let defaults = RateLimitConfig::default();
        let status = RateLimiter::new().get_status("nonexistent");

        // Unknown keys report the full configured quota.
        assert_eq!(status.tokens_remaining, defaults.burst_size);
        assert_eq!(status.minute_remaining, defaults.requests_per_minute);
        assert_eq!(status.hour_remaining, defaults.requests_per_hour);
    }

    #[tokio::test]
    async fn test_rapid_requests() {
        let limiter = RateLimiter::with_config(limits(10, 100, 1000, 5));

        let success_count = count_allowed(&limiter, "rapid", 20).await;

        // Bounded by the burst size of 5 (plus a little refill slack).
        assert!(
            success_count <= 7,
            "Expected ~5 successful requests, got {}",
            success_count
        );
    }

    #[test]
    fn test_cleanup_with_negative_duration() {
        // Despite the name, this exercises the largest whole-second Duration:
        // cleanup must not overflow or panic, and nothing should expire.
        let limiter = RateLimiter::new();

        run(async {
            limiter.check("key").await.unwrap();
        });

        limiter.cleanup_expired(Duration::from_secs(u64::MAX));

        assert!(limiter.active_keys() >= 1);
    }

    #[tokio::test]
    async fn test_status_accuracy() {
        let limiter = RateLimiter::with_config(limits(10, 100, 1000, 10));

        for _ in 0..3 {
            limiter.check("status_test").await.unwrap();
        }

        let status = limiter.get_status("status_test");
        // Some, but not all, of the 10 tokens have been spent.
        assert!(status.tokens_remaining < 10);
        assert!(status.tokens_remaining > 0);
    }

    #[test]
    fn test_config_default_values() {
        let RateLimitConfig {
            requests_per_second,
            requests_per_minute,
            requests_per_hour,
            burst_size,
            ..
        } = RateLimitConfig::default();

        assert_eq!(requests_per_second, 10);
        assert_eq!(requests_per_minute, 100);
        assert_eq!(requests_per_hour, 1000);
        assert_eq!(burst_size, 20);
    }

    #[test]
    fn test_config_serialization() {
        let original = RateLimitConfig::default();
        let encoded = serde_json::to_string(&original).unwrap();
        let round_tripped: RateLimitConfig = serde_json::from_str(&encoded).unwrap();

        assert_eq!(round_tripped.requests_per_second, original.requests_per_second);
    }

    #[tokio::test]
    async fn test_token_refill_boundary() {
        let limiter = RateLimiter::with_config(limits(100, 10000, 100000, 10));

        // Spend all ten tokens, then confirm the bucket is empty.
        for _ in 0..10 {
            limiter.check("refill_boundary").await.unwrap();
        }
        assert!(!limiter.check("refill_boundary").await.unwrap());

        // At 100 tokens/s one token takes 10 ms; wait 15 ms for margin.
        tokio::time::sleep(tokio::time::Duration::from_millis(15)).await;

        assert!(limiter.check("refill_boundary").await.unwrap());
    }
}