Skip to main content

aster/ratelimit/
limiter.rs

1//! 速率限制器
2//!
3//! 管理 API 请求速率限制
4
5use parking_lot::RwLock;
6use std::collections::VecDeque;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::mpsc;
10
/// Configuration for the rate limiter: per-minute budgets plus retry policy.
///
/// The retry-related fields are not read anywhere in this file; presumably they
/// are consumed by retry logic elsewhere in the crate.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum number of requests allowed per minute.
    pub max_requests_per_minute: u32,
    /// Maximum number of tokens allowed per minute.
    pub max_tokens_per_minute: u32,
    /// Maximum number of retry attempts.
    pub max_retries: u32,
    /// Base retry delay, in milliseconds.
    pub base_retry_delay_ms: u64,
    /// Upper bound on the retry delay, in milliseconds.
    pub max_retry_delay_ms: u64,
    /// Status codes that are treated as retryable.
    pub retryable_status_codes: Vec<u16>,
}

impl Default for RateLimitConfig {
    /// Defaults: 50 requests / 100 000 tokens per minute, 3 retries, a 1 000 ms
    /// base delay capped at 60 000 ms, and status codes 429, 500, 502, 503, 504
    /// considered retryable.
    fn default() -> Self {
        let retryable_status_codes = Vec::from([429, 500, 502, 503, 504]);
        Self {
            retryable_status_codes,
            max_retry_delay_ms: 60_000,
            base_retry_delay_ms: 1_000,
            max_retries: 3,
            max_tokens_per_minute: 100_000,
            max_requests_per_minute: 50,
        }
    }
}
40
/// Snapshot of the limiter's usage within the current one-minute window.
#[derive(Debug, Clone)]
pub struct RateLimitState {
    /// Requests counted in the current window.
    pub requests_this_minute: u32,
    /// Tokens counted in the current window.
    pub tokens_this_minute: u32,
    /// When the current window started (i.e. when counters were last reset).
    pub last_reset_time: Instant,
    /// Whether new requests are currently being refused.
    pub is_rate_limited: bool,
    /// Wait time before retrying, in seconds, as reported by an API rate-limit response.
    pub retry_after: Option<u64>,
}

impl Default for RateLimitState {
    /// A fresh state: nothing counted, not limited, and a window starting now.
    fn default() -> Self {
        let window_start = Instant::now();
        Self {
            last_reset_time: window_start,
            requests_this_minute: 0,
            tokens_this_minute: 0,
            retry_after: None,
            is_rate_limited: false,
        }
    }
}
67
/// Notifications sent on the limiter's optional event channel.
#[derive(Debug, Clone)]
pub enum RateLimitEvent {
    /// A limit was hit and requests are now being refused.
    RateLimited {
        /// Which limit triggered; this module uses `"requests"`, `"tokens"` or `"api"`.
        reason: String,
        /// Usage observed when the limit was hit (0 for API-reported limits).
        current: u32,
        /// Configured limit that was reached (0 for API-reported limits).
        limit: u32,
    },
    /// The window elapsed while limited, and the counters were cleared.
    RateLimitReset,
}
80
81/// 速率限制器
82pub struct RateLimiter {
83    config: RateLimitConfig,
84    state: Arc<RwLock<RateLimitState>>,
85    event_tx: Option<mpsc::UnboundedSender<RateLimitEvent>>,
86    queue: Arc<RwLock<VecDeque<QueuedRequest>>>,
87}
88
89struct QueuedRequest {
90    id: u64,
91    estimated_tokens: Option<u32>,
92}
93
94impl RateLimiter {
95    /// 创建新的速率限制器
96    pub fn new(config: RateLimitConfig) -> Self {
97        Self {
98            config,
99            state: Arc::new(RwLock::new(RateLimitState::default())),
100            event_tx: None,
101            queue: Arc::new(RwLock::new(VecDeque::new())),
102        }
103    }
104
105    /// 设置事件通道
106    pub fn with_event_channel(mut self, tx: mpsc::UnboundedSender<RateLimitEvent>) -> Self {
107        self.event_tx = Some(tx);
108        self
109    }
110
111    /// 检查是否需要重置计数器
112    fn maybe_reset(&self) {
113        let mut state = self.state.write();
114        let elapsed = state.last_reset_time.elapsed();
115
116        if elapsed >= Duration::from_secs(60) {
117            state.requests_this_minute = 0;
118            state.tokens_this_minute = 0;
119            state.last_reset_time = Instant::now();
120
121            if state.is_rate_limited {
122                state.is_rate_limited = false;
123                if let Some(ref tx) = self.event_tx {
124                    let _ = tx.send(RateLimitEvent::RateLimitReset);
125                }
126            }
127        }
128    }
129
130    /// 检查是否可以发起请求
131    pub fn can_make_request(&self, estimated_tokens: Option<u32>) -> bool {
132        self.maybe_reset();
133        let state = self.state.read();
134
135        if state.is_rate_limited {
136            return false;
137        }
138
139        if state.requests_this_minute >= self.config.max_requests_per_minute {
140            return false;
141        }
142
143        if let Some(tokens) = estimated_tokens {
144            if state.tokens_this_minute + tokens > self.config.max_tokens_per_minute {
145                return false;
146            }
147        }
148
149        true
150    }
151
152    /// 记录请求
153    pub fn record_request(&self, tokens: Option<u32>) {
154        self.maybe_reset();
155        let mut state = self.state.write();
156
157        state.requests_this_minute += 1;
158
159        if let Some(t) = tokens {
160            state.tokens_this_minute += t;
161        }
162
163        // 检查是否达到限制
164        if state.requests_this_minute >= self.config.max_requests_per_minute {
165            state.is_rate_limited = true;
166            if let Some(ref tx) = self.event_tx {
167                let _ = tx.send(RateLimitEvent::RateLimited {
168                    reason: "requests".to_string(),
169                    current: state.requests_this_minute,
170                    limit: self.config.max_requests_per_minute,
171                });
172            }
173        }
174
175        if state.tokens_this_minute >= self.config.max_tokens_per_minute {
176            state.is_rate_limited = true;
177            if let Some(ref tx) = self.event_tx {
178                let _ = tx.send(RateLimitEvent::RateLimited {
179                    reason: "tokens".to_string(),
180                    current: state.tokens_this_minute,
181                    limit: self.config.max_tokens_per_minute,
182                });
183            }
184        }
185    }
186
187    /// 处理 API 返回的限流响应
188    pub fn handle_rate_limit_response(&self, retry_after: Option<u64>) {
189        let mut state = self.state.write();
190        state.is_rate_limited = true;
191        state.retry_after = retry_after;
192
193        if let Some(ref tx) = self.event_tx {
194            let _ = tx.send(RateLimitEvent::RateLimited {
195                reason: "api".to_string(),
196                current: 0,
197                limit: 0,
198            });
199        }
200    }
201
202    /// 获取当前状态
203    pub fn get_state(&self) -> RateLimitState {
204        self.maybe_reset();
205        self.state.read().clone()
206    }
207
208    /// 获取距离重置的时间(毫秒)
209    pub fn get_time_until_reset(&self) -> u64 {
210        let state = self.state.read();
211        let elapsed = state.last_reset_time.elapsed().as_millis() as u64;
212        60_000u64.saturating_sub(elapsed)
213    }
214
215    /// 等待直到可以发起请求
216    pub async fn wait_for_capacity(&self, estimated_tokens: Option<u32>) {
217        while !self.can_make_request(estimated_tokens) {
218            let wait_time = self.get_time_until_reset();
219            tokio::time::sleep(Duration::from_millis(wait_time.min(1000))).await;
220        }
221    }
222
223    /// 获取配置
224    pub fn config(&self) -> &RateLimitConfig {
225        &self.config
226    }
227}
228
229impl Default for RateLimiter {
230    fn default() -> Self {
231        Self::new(RateLimitConfig::default())
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limiter_default() {
        let limiter = RateLimiter::default();
        assert!(limiter.can_make_request(None));
    }

    #[test]
    fn test_record_request() {
        let limiter = RateLimiter::default();
        limiter.record_request(Some(100));

        let state = limiter.get_state();
        assert_eq!(state.requests_this_minute, 1);
        assert_eq!(state.tokens_this_minute, 100);
    }

    #[test]
    fn test_rate_limit_reached() {
        let config = RateLimitConfig {
            max_requests_per_minute: 2,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);

        assert!(limiter.can_make_request(None));
        limiter.record_request(None);
        assert!(limiter.can_make_request(None));
        limiter.record_request(None);
        assert!(!limiter.can_make_request(None));
    }

    #[test]
    fn test_token_limit() {
        let config = RateLimitConfig {
            max_tokens_per_minute: 1000,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);

        assert!(limiter.can_make_request(Some(500)));
        limiter.record_request(Some(500));
        assert!(limiter.can_make_request(Some(400)));
        assert!(!limiter.can_make_request(Some(600)));
    }

    /// An API-reported rate limit must block further requests within the window
    /// and keep the reported retry hint visible in the state snapshot.
    #[test]
    fn test_api_rate_limit_response_blocks_requests() {
        let limiter = RateLimiter::default();
        limiter.handle_rate_limit_response(Some(30));

        assert!(!limiter.can_make_request(None));
        let state = limiter.get_state();
        assert!(state.is_rate_limited);
        assert_eq!(state.retry_after, Some(30));
    }

    /// Right after construction, the remaining window time is positive and
    /// bounded by the one-minute window.
    #[test]
    fn test_time_until_reset_is_within_window() {
        let limiter = RateLimiter::default();
        let remaining = limiter.get_time_until_reset();
        assert!(remaining > 0);
        assert!(remaining <= 60_000);
    }
}