Skip to main content

sh_layer0/
rate_limiter.rs

//! Rate limiting module.
//!
//! Token-bucket and sliding-window algorithms.
//! Uses DashMap for high-performance concurrent access.
5
6use dashmap::DashMap;
7use serde::{Deserialize, Serialize};
8use std::time::{Duration, Instant};
9
10/// 速率限制配置
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct RateLimitConfig {
13    /// 每秒允许的请求数
14    pub requests_per_second: u32,
15    /// 每分钟允许的请求数
16    pub requests_per_minute: u32,
17    /// 每小时允许的请求数
18    pub requests_per_hour: u32,
19    /// Burst 大小(突发流量)
20    pub burst_size: u32,
21    /// 分钟窗口容量上限(防止内存无限增长)
22    #[serde(default = "default_minute_capacity")]
23    pub minute_window_capacity: usize,
24    /// 小时窗口容量上限
25    #[serde(default = "default_hour_capacity")]
26    pub hour_window_capacity: usize,
27}
28
/// Serde default for `RateLimitConfig::minute_window_capacity`.
fn default_minute_capacity() -> usize {
    1000
}
/// Serde default for `RateLimitConfig::hour_window_capacity`.
fn default_hour_capacity() -> usize {
    10000
}
35
36impl Default for RateLimitConfig {
37    fn default() -> Self {
38        Self {
39            requests_per_second: 10,
40            requests_per_minute: 100,
41            requests_per_hour: 1000,
42            burst_size: 20,
43            minute_window_capacity: 1000,
44            hour_window_capacity: 10000,
45        }
46    }
47}
48
/// Token bucket state for a single key.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available (fractional while refilling).
    tokens: f64,
    /// Capacity of the bucket; `tokens` never exceeds this.
    max_tokens: f64,
    /// Tokens added per second.
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_update: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_update: Instant::now(),
        }
    }

    /// Refills the bucket for the time elapsed since the last call and then
    /// tries to remove `tokens` from it.
    ///
    /// Returns `true` and deducts the tokens when enough are available;
    /// returns `false` and deducts nothing otherwise.
    fn try_take(&mut self, tokens: f64) -> bool {
        // Read the clock once: using one reading for both the elapsed time
        // and the new `last_update` means no interval is ever left uncredited.
        let now = Instant::now();
        let elapsed = now
            .saturating_duration_since(self.last_update)
            .as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_update = now;

        if self.tokens >= tokens {
            self.tokens -= tokens;
            true
        } else {
            false
        }
    }
}
87
88/// 速率限制器(高性能版本,使用 DashMap)
89pub struct RateLimiter {
90    /// 配置
91    config: RateLimitConfig,
92    /// 每个 key 的 token bucket(并发安全)
93    buckets: DashMap<String, TokenBucket>,
94    /// 每个 key 的请求计数(滑动窗口,并发安全)
95    counters: DashMap<String, SlidingWindowCounter>,
96}
97
/// Sliding-window request log for a single key.
#[derive(Debug)]
struct SlidingWindowCounter {
    /// Timestamps of requests recorded for the 60-second window, oldest first.
    minute_requests: Vec<Instant>,
    /// Timestamps of requests recorded for the 3600-second window, oldest first.
    hour_requests: Vec<Instant>,
    /// Maximum number of timestamps kept in `minute_requests`.
    minute_capacity: usize,
    /// Maximum number of timestamps kept in `hour_requests`.
    hour_capacity: usize,
}

impl SlidingWindowCounter {
    /// Length of the minute window.
    const MINUTE_WINDOW: Duration = Duration::from_secs(60);
    /// Length of the hour window.
    const HOUR_WINDOW: Duration = Duration::from_secs(3600);

    /// Creates an empty counter with capacities of 1000 (minute) and 10000 (hour).
    fn new() -> Self {
        Self::with_capacity(1000, 10000)
    }

    /// Creates an empty counter with explicit capacity bounds.
    fn with_capacity(minute_capacity: usize, hour_capacity: usize) -> Self {
        Self {
            minute_requests: Vec::new(),
            hour_requests: Vec::new(),
            minute_capacity,
            hour_capacity,
        }
    }

    /// Records one request at the current instant, then evicts expired
    /// timestamps and enforces the capacity bounds on both windows.
    fn add_request(&mut self) {
        let now = Instant::now();
        self.minute_requests.push(now);
        self.hour_requests.push(now);

        Self::evict(
            &mut self.minute_requests,
            now,
            Self::MINUTE_WINDOW,
            self.minute_capacity,
        );
        Self::evict(
            &mut self.hour_requests,
            now,
            Self::HOUR_WINDOW,
            self.hour_capacity,
        );
    }

    /// Drops timestamps older than `window` (measured against a single `now`
    /// rather than one clock read per element), then drops the oldest
    /// survivors until at most `capacity` remain.
    fn evict(stamps: &mut Vec<Instant>, now: Instant, window: Duration, capacity: usize) {
        stamps.retain(|t| now.saturating_duration_since(*t) < window);
        if stamps.len() > capacity {
            let excess = stamps.len() - capacity;
            stamps.drain(..excess);
        }
    }

    /// Counts the timestamps in `stamps` that are younger than `window`.
    fn live_count(stamps: &[Instant], window: Duration) -> usize {
        let now = Instant::now();
        stamps
            .iter()
            .filter(|t| now.saturating_duration_since(**t) < window)
            .count()
    }

    /// Number of recorded requests that fall inside the last 60 seconds.
    ///
    /// Expired timestamps still held in the log are not counted, so a caller
    /// that stops recording requests does not keep seeing a stale total.
    fn minute_count(&self) -> usize {
        Self::live_count(&self.minute_requests, Self::MINUTE_WINDOW)
    }

    /// Number of recorded requests that fall inside the last 3600 seconds.
    fn hour_count(&self) -> usize {
        Self::live_count(&self.hour_requests, Self::HOUR_WINDOW)
    }
}
160
161impl RateLimiter {
162    pub fn new() -> Self {
163        Self::with_config(RateLimitConfig::default())
164    }
165
166    pub fn with_config(config: RateLimitConfig) -> Self {
167        Self {
168            config: config.clone(),
169            buckets: DashMap::new(),
170            counters: DashMap::new(),
171        }
172    }
173
174    /// 检查是否允许请求
175    pub async fn check(&self, key: &str) -> anyhow::Result<bool> {
176        // 检查 Token Bucket(秒级限制)
177        let bucket_result = {
178            let mut bucket = self.buckets.entry(key.to_string()).or_insert_with(|| {
179                TokenBucket::new(
180                    self.config.burst_size as f64,
181                    self.config.requests_per_second as f64,
182                )
183            });
184            bucket.try_take(1.0)
185        };
186
187        if !bucket_result {
188            return Ok(false);
189        }
190
191        // 检查滑动窗口(分钟和小时限制)
192        let window_result = {
193            let minute_cap = self.config.minute_window_capacity;
194            let hour_cap = self.config.hour_window_capacity;
195            let mut counter = self
196                .counters
197                .entry(key.to_string())
198                .or_insert_with(|| SlidingWindowCounter::with_capacity(minute_cap, hour_cap));
199
200            let minute_exceeded =
201                counter.minute_count() >= self.config.requests_per_minute as usize;
202            let hour_exceeded = counter.hour_count() >= self.config.requests_per_hour as usize;
203
204            if minute_exceeded || hour_exceeded {
205                false
206            } else {
207                counter.add_request();
208                true
209            }
210        };
211
212        Ok(window_result)
213    }
214
215    /// 重置指定 key 的限制
216    pub fn reset(&self, key: &str) {
217        self.buckets.remove(key);
218        self.counters.remove(key);
219    }
220
221    /// 获取指定 key 的状态
222    pub fn get_status(&self, key: &str) -> RateLimitStatus {
223        let tokens_remaining = self
224            .buckets
225            .get(key)
226            .map(|b| b.tokens as u32)
227            .unwrap_or(self.config.burst_size);
228
229        let minute_remaining = self.config.requests_per_minute
230            - self
231                .counters
232                .get(key)
233                .map(|c| c.minute_count() as u32)
234                .unwrap_or(0);
235
236        let hour_remaining = self.config.requests_per_hour
237            - self
238                .counters
239                .get(key)
240                .map(|c| c.hour_count() as u32)
241                .unwrap_or(0);
242
243        RateLimitStatus {
244            tokens_remaining,
245            minute_remaining,
246            hour_remaining,
247        }
248    }
249
250    /// 清理过期的条目(定期维护)
251    pub fn cleanup_expired(&self, max_age: Duration) {
252        let now = Instant::now();
253
254        // 清理过期的 buckets
255        self.buckets
256            .retain(|_, bucket| now.duration_since(bucket.last_update) < max_age);
257
258        // 清理空的 counters
259        self.counters.retain(|_, counter| {
260            !counter.minute_requests.is_empty() || !counter.hour_requests.is_empty()
261        });
262    }
263
264    /// 获取当前活跃的 key 数量
265    pub fn active_keys(&self) -> usize {
266        self.buckets.len()
267    }
268}
269
270/// 速率限制状态
271#[derive(Debug, Serialize, Deserialize)]
272pub struct RateLimitStatus {
273    pub tokens_remaining: u32,
274    pub minute_remaining: u32,
275    pub hour_remaining: u32,
276}
277
278impl Default for RateLimiter {
279    fn default() -> Self {
280        Self::new()
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Builds a limiter with the four rate fields set and every other field
    /// left at its default.
    fn limiter_with(per_second: u32, per_minute: u32, per_hour: u32, burst: u32) -> RateLimiter {
        RateLimiter::with_config(RateLimitConfig {
            requests_per_second: per_second,
            requests_per_minute: per_minute,
            requests_per_hour: per_hour,
            burst_size: burst,
            ..Default::default()
        })
    }

    /// Issues `attempts` back-to-back checks for `key` and returns how many
    /// were allowed.
    async fn count_allowed(limiter: &RateLimiter, key: &str, attempts: usize) -> usize {
        let mut allowed = 0;
        for _ in 0..attempts {
            if limiter.check(key).await.unwrap() {
                allowed += 1;
            }
        }
        allowed
    }

    #[tokio::test]
    async fn test_basic_rate_limit() {
        let limiter = RateLimiter::new();

        // The first 10 requests fit inside the default burst.
        for _ in 0..10 {
            assert!(limiter.check("test_key").await.unwrap());
        }
    }

    #[tokio::test]
    async fn test_rate_limit_exceeded() {
        let limiter = limiter_with(1, 2, 3, 2);

        // The burst is allowed...
        assert!(limiter.check("test_key").await.unwrap());
        assert!(limiter.check("test_key").await.unwrap());

        // ...and the request beyond it is rejected.
        assert!(!limiter.check("test_key").await.unwrap());
    }

    #[tokio::test]
    async fn test_concurrent_requests() {
        let limiter = Arc::new(limiter_with(100, 1000, 10000, 50));

        let tasks: Vec<_> = (0..100)
            .map(|_| {
                let limiter_clone = Arc::clone(&limiter);
                tokio::spawn(async move { limiter_clone.check("concurrent_key").await.unwrap() })
            })
            .collect();

        let results: Vec<bool> = futures::future::join_all(tasks)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        let success_count = results.iter().filter(|&&r| r).count();
        let fail_count = results.len() - success_count;

        // With a burst of 50 some requests pass and some are rejected.
        assert!(success_count > 0, "At least some requests should succeed");
        println!("Success: {}, Fail: {}", success_count, fail_count);
    }

    #[tokio::test]
    async fn test_burst_handling() {
        let limiter = limiter_with(5, 100, 1000, 10);

        // Rapid consecutive requests drain the burst.
        let success_count = count_allowed(&limiter, "burst_key", 20).await;

        // burst_size is 10, so roughly 10 requests should pass.
        assert!(
            success_count <= 11,
            "Burst should be limited, but got {} successes",
            success_count
        );
        assert!(
            success_count >= 8,
            "At least burst_size requests should succeed, but got {}",
            success_count
        );
    }

    #[tokio::test]
    async fn test_token_refill_accuracy() {
        let limiter = limiter_with(10, 100, 1000, 5);

        // Drain every token.
        for _ in 0..5 {
            assert!(limiter.check("refill_key").await.unwrap());
        }
        assert!(!limiter.check("refill_key").await.unwrap());

        // At 10 tokens/s, 100 ms refills about one token; wait a bit longer.
        tokio::time::sleep(tokio::time::Duration::from_millis(150)).await;

        assert!(
            limiter.check("refill_key").await.unwrap(),
            "Token should be refilled after waiting"
        );
    }

    #[tokio::test]
    async fn test_different_keys_isolated() {
        let limiter = limiter_with(1, 1, 1, 1);

        // key1 is limited to its burst.
        assert!(limiter.check("key1").await.unwrap());
        assert!(!limiter.check("key1").await.unwrap());

        // key2 is counted independently.
        assert!(limiter.check("key2").await.unwrap());
        assert!(!limiter.check("key2").await.unwrap());
    }

    #[tokio::test]
    async fn test_reset_functionality() {
        let limiter = limiter_with(1, 1, 1, 1);

        assert!(limiter.check("reset_key").await.unwrap());
        assert!(!limiter.check("reset_key").await.unwrap());

        limiter.reset("reset_key");

        // After a reset the key can be served again.
        assert!(limiter.check("reset_key").await.unwrap());
    }

    #[tokio::test]
    async fn test_status_reporting() {
        let limiter = limiter_with(10, 100, 1000, 20);

        // Consume a few tokens.
        for _ in 0..5 {
            limiter.check("status_key").await.unwrap();
        }

        let status = limiter.get_status("status_key");
        assert!(status.tokens_remaining < 20, "Tokens should be consumed");
        assert!(
            status.minute_remaining < 100,
            "Minute count should increase"
        );
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let limiter = RateLimiter::new();

        limiter.check("key1").await.unwrap();
        limiter.check("key2").await.unwrap();
        assert!(limiter.active_keys() >= 2);

        // A max_age of zero expires every bucket immediately.
        limiter.cleanup_expired(Duration::from_secs(0));

        assert_eq!(limiter.active_keys(), 0);
    }

    #[tokio::test]
    async fn test_active_keys_count() {
        let limiter = RateLimiter::new();

        limiter.check("key1").await.unwrap();
        limiter.check("key2").await.unwrap();
        limiter.check("key3").await.unwrap();

        assert_eq!(limiter.active_keys(), 3);
    }

    // ========== Boundary conditions ==========

    #[tokio::test]
    async fn test_zero_rate_limit() {
        // requests_per_second = 0 with generous minute/hour limits: only the
        // initial burst is ever available because the bucket never refills.
        let limiter = limiter_with(0, 100, 1000, 2);

        let first = limiter.check("key").await.unwrap();
        assert!(first, "First request with burst_size=2 should succeed");

        let second = limiter.check("key").await.unwrap();
        assert!(second, "Second request with burst_size=2 should succeed");

        let third = limiter.check("key").await.unwrap();
        assert!(!third, "Third request should be rate limited (no refill)");
    }

    #[tokio::test]
    async fn test_very_small_burst_size() {
        let limiter = limiter_with(1, 100, 1000, 1);

        assert!(limiter.check("key").await.unwrap());
        assert!(!limiter.check("key").await.unwrap());
    }

    #[tokio::test]
    async fn test_large_burst_size() {
        let limiter = limiter_with(1000, 100000, 1000000, 1000);

        let success_count = count_allowed(&limiter, "key", 500).await;
        assert!(
            success_count >= 400,
            "Should allow most requests with large burst"
        );
    }

    #[tokio::test]
    async fn test_empty_key() {
        let limiter = RateLimiter::new();

        // The empty key is handled like any other.
        assert!(limiter.check("").await.unwrap());
    }

    #[tokio::test]
    async fn test_special_characters_in_key() {
        let limiter = RateLimiter::new();

        let special_keys = [
            "key:with:colons",
            "key-with-dashes",
            "key_with_underscores",
            "key.with.dots",
            "key/with/slashes",
        ];
        for key in special_keys {
            assert!(
                limiter.check(key).await.unwrap(),
                "Key '{}' should work",
                key
            );
        }
    }

    #[tokio::test]
    async fn test_unicode_key() {
        let limiter = RateLimiter::new();

        assert!(limiter.check("用户_123").await.unwrap());
        assert!(limiter.check("🔑_key").await.unwrap());
    }

    #[tokio::test]
    async fn test_very_long_key() {
        let limiter = RateLimiter::new();
        let long_key = "a".repeat(10000);

        assert!(limiter.check(&long_key).await.unwrap());
    }

    #[test]
    fn test_reset_nonexistent_key() {
        let limiter = RateLimiter::new();

        // Resetting an unknown key must not panic.
        limiter.reset("nonexistent_key");
        assert_eq!(limiter.active_keys(), 0);
    }

    #[test]
    fn test_status_nonexistent_key() {
        let limiter = RateLimiter::new();
        let config = RateLimitConfig::default();

        // An unknown key reports the configured maxima.
        let status = limiter.get_status("nonexistent");
        assert_eq!(status.tokens_remaining, config.burst_size);
        assert_eq!(status.minute_remaining, config.requests_per_minute);
        assert_eq!(status.hour_remaining, config.requests_per_hour);
    }

    #[tokio::test]
    async fn test_rapid_requests() {
        let limiter = limiter_with(10, 100, 1000, 5);

        let success_count = count_allowed(&limiter, "rapid", 20).await;

        // Bounded by burst_size plus whatever refills during the loop.
        assert!(
            success_count <= 7,
            "Expected ~5 successful requests, got {}",
            success_count
        );
    }

    #[tokio::test]
    async fn test_cleanup_with_negative_duration() {
        let limiter = RateLimiter::new();

        limiter.check("key").await.unwrap();

        // `Duration` cannot be negative; this exercises the other extreme, a
        // max_age so large that nothing can have expired. It must not panic.
        limiter.cleanup_expired(Duration::from_secs(u64::MAX));

        assert!(limiter.active_keys() >= 1);
    }

    #[tokio::test]
    async fn test_status_accuracy() {
        let limiter = limiter_with(10, 100, 1000, 10);

        // Consume 3 tokens.
        for _ in 0..3 {
            limiter.check("status_test").await.unwrap();
        }

        let status = limiter.get_status("status_test");
        // Fewer than the initial amount, but not exhausted.
        assert!(status.tokens_remaining < 10);
        assert!(status.tokens_remaining > 0);
    }

    #[test]
    fn test_config_default_values() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_second, 10);
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.requests_per_hour, 1000);
        assert_eq!(config.burst_size, 20);
    }

    #[test]
    fn test_config_serialization() {
        let config = RateLimitConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RateLimitConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.requests_per_second, config.requests_per_second);
    }

    #[tokio::test]
    async fn test_token_refill_boundary() {
        // 100 tokens per second.
        let limiter = limiter_with(100, 10000, 100000, 10);

        // Drain every token.
        for _ in 0..10 {
            limiter.check("refill_boundary").await.unwrap();
        }
        assert!(!limiter.check("refill_boundary").await.unwrap());

        // 10 ms refills about one token; wait slightly longer.
        tokio::time::sleep(tokio::time::Duration::from_millis(15)).await;

        assert!(limiter.check("refill_boundary").await.unwrap());
    }
}