aster/ratelimit/
limiter.rs1use parking_lot::RwLock;
6use std::collections::VecDeque;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::mpsc;
10
/// Static limits and retry settings for a [`RateLimiter`].
///
/// Only `max_requests_per_minute` and `max_tokens_per_minute` are read by the
/// limiter code in this file; the retry-related fields are carried along for
/// other consumers (presumably a retry layer elsewhere — confirm).
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Request count at which the limiter stops admitting requests within one window.
    pub max_requests_per_minute: u32,
    /// Token budget per window; a request is refused if it would push usage above this.
    pub max_tokens_per_minute: u32,
    /// Retry cap. Not consulted in this file — presumably attempts per request.
    pub max_retries: u32,
    /// Initial retry delay; milliseconds per the field name. Not consulted in this file.
    pub base_retry_delay_ms: u64,
    /// Upper bound on the retry delay; milliseconds per the field name. Not consulted in this file.
    pub max_retry_delay_ms: u64,
    /// Status codes considered retryable (presumably HTTP). Not consulted in this file.
    pub retryable_status_codes: Vec<u16>,
}
27
28impl Default for RateLimitConfig {
29 fn default() -> Self {
30 Self {
31 max_requests_per_minute: 50,
32 max_tokens_per_minute: 100_000,
33 max_retries: 3,
34 base_retry_delay_ms: 1000,
35 max_retry_delay_ms: 60_000,
36 retryable_status_codes: vec![429, 500, 502, 503, 504],
37 }
38 }
39}
40
/// Mutable usage counters tracked by a [`RateLimiter`] for the current window.
///
/// `RateLimiter::get_state` hands out clones of this value.
#[derive(Clone, Debug)]
pub struct RateLimitState {
    /// Requests recorded since `last_reset_time`.
    pub requests_this_minute: u32,
    /// Tokens recorded since `last_reset_time`.
    pub tokens_this_minute: u32,
    /// Instant at which the counters were last zeroed.
    pub last_reset_time: Instant,
    /// Set when a limit has been reached or an upstream rate-limit response was reported.
    pub is_rate_limited: bool,
    /// Value stored by `handle_rate_limit_response`; the unit is not
    /// established in this file (presumably seconds — confirm with callers).
    pub retry_after: Option<u64>,
}
55
56impl Default for RateLimitState {
57 fn default() -> Self {
58 Self {
59 requests_this_minute: 0,
60 tokens_this_minute: 0,
61 last_reset_time: Instant::now(),
62 is_rate_limited: false,
63 retry_after: None,
64 }
65 }
66}
67
/// Notifications a [`RateLimiter`] pushes onto its optional event channel.
#[derive(Clone, Debug)]
pub enum RateLimitEvent {
    /// The limiter entered (or is still in) the rate-limited condition.
    RateLimited {
        /// Which limit triggered; the limiter in this file sends
        /// `"requests"`, `"tokens"` or `"api"`.
        reason: String,
        /// Counter value at the time of the event (`0` for `"api"`).
        current: u32,
        /// Configured ceiling for that counter (`0` for `"api"`).
        limit: u32,
    },
    /// The window rolled over and the rate-limited flag was cleared.
    RateLimitReset,
}
80
81pub struct RateLimiter {
83 config: RateLimitConfig,
84 state: Arc<RwLock<RateLimitState>>,
85 event_tx: Option<mpsc::UnboundedSender<RateLimitEvent>>,
86 queue: Arc<RwLock<VecDeque<QueuedRequest>>>,
87}
88
/// Entry type of the limiter's pending-request queue.
///
/// Nothing visible in this file constructs or reads these entries.
struct QueuedRequest {
    /// Identifier of the entry (presumably unique per queued request — confirm).
    id: u64,
    /// Token estimate attached to the entry, when one is known.
    estimated_tokens: Option<u32>,
}
93
94impl RateLimiter {
95 pub fn new(config: RateLimitConfig) -> Self {
97 Self {
98 config,
99 state: Arc::new(RwLock::new(RateLimitState::default())),
100 event_tx: None,
101 queue: Arc::new(RwLock::new(VecDeque::new())),
102 }
103 }
104
105 pub fn with_event_channel(mut self, tx: mpsc::UnboundedSender<RateLimitEvent>) -> Self {
107 self.event_tx = Some(tx);
108 self
109 }
110
111 fn maybe_reset(&self) {
113 let mut state = self.state.write();
114 let elapsed = state.last_reset_time.elapsed();
115
116 if elapsed >= Duration::from_secs(60) {
117 state.requests_this_minute = 0;
118 state.tokens_this_minute = 0;
119 state.last_reset_time = Instant::now();
120
121 if state.is_rate_limited {
122 state.is_rate_limited = false;
123 if let Some(ref tx) = self.event_tx {
124 let _ = tx.send(RateLimitEvent::RateLimitReset);
125 }
126 }
127 }
128 }
129
130 pub fn can_make_request(&self, estimated_tokens: Option<u32>) -> bool {
132 self.maybe_reset();
133 let state = self.state.read();
134
135 if state.is_rate_limited {
136 return false;
137 }
138
139 if state.requests_this_minute >= self.config.max_requests_per_minute {
140 return false;
141 }
142
143 if let Some(tokens) = estimated_tokens {
144 if state.tokens_this_minute + tokens > self.config.max_tokens_per_minute {
145 return false;
146 }
147 }
148
149 true
150 }
151
152 pub fn record_request(&self, tokens: Option<u32>) {
154 self.maybe_reset();
155 let mut state = self.state.write();
156
157 state.requests_this_minute += 1;
158
159 if let Some(t) = tokens {
160 state.tokens_this_minute += t;
161 }
162
163 if state.requests_this_minute >= self.config.max_requests_per_minute {
165 state.is_rate_limited = true;
166 if let Some(ref tx) = self.event_tx {
167 let _ = tx.send(RateLimitEvent::RateLimited {
168 reason: "requests".to_string(),
169 current: state.requests_this_minute,
170 limit: self.config.max_requests_per_minute,
171 });
172 }
173 }
174
175 if state.tokens_this_minute >= self.config.max_tokens_per_minute {
176 state.is_rate_limited = true;
177 if let Some(ref tx) = self.event_tx {
178 let _ = tx.send(RateLimitEvent::RateLimited {
179 reason: "tokens".to_string(),
180 current: state.tokens_this_minute,
181 limit: self.config.max_tokens_per_minute,
182 });
183 }
184 }
185 }
186
187 pub fn handle_rate_limit_response(&self, retry_after: Option<u64>) {
189 let mut state = self.state.write();
190 state.is_rate_limited = true;
191 state.retry_after = retry_after;
192
193 if let Some(ref tx) = self.event_tx {
194 let _ = tx.send(RateLimitEvent::RateLimited {
195 reason: "api".to_string(),
196 current: 0,
197 limit: 0,
198 });
199 }
200 }
201
202 pub fn get_state(&self) -> RateLimitState {
204 self.maybe_reset();
205 self.state.read().clone()
206 }
207
208 pub fn get_time_until_reset(&self) -> u64 {
210 let state = self.state.read();
211 let elapsed = state.last_reset_time.elapsed().as_millis() as u64;
212 60_000u64.saturating_sub(elapsed)
213 }
214
215 pub async fn wait_for_capacity(&self, estimated_tokens: Option<u32>) {
217 while !self.can_make_request(estimated_tokens) {
218 let wait_time = self.get_time_until_reset();
219 tokio::time::sleep(Duration::from_millis(wait_time.min(1000))).await;
220 }
221 }
222
223 pub fn config(&self) -> &RateLimitConfig {
225 &self.config
226 }
227}
228
229impl Default for RateLimiter {
230 fn default() -> Self {
231 Self::new(RateLimitConfig::default())
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// A limiter with default limits admits a request straight away.
    #[test]
    fn test_rate_limiter_default() {
        assert!(RateLimiter::default().can_make_request(None));
    }

    /// Recording a request bumps both the request and the token counter.
    #[test]
    fn test_record_request() {
        let limiter = RateLimiter::default();
        limiter.record_request(Some(100));

        let RateLimitState {
            requests_this_minute,
            tokens_this_minute,
            ..
        } = limiter.get_state();
        assert_eq!(requests_this_minute, 1);
        assert_eq!(tokens_this_minute, 100);
    }

    /// Once the request ceiling is hit, further requests are refused.
    #[test]
    fn test_rate_limit_reached() {
        let limiter = RateLimiter::new(RateLimitConfig {
            max_requests_per_minute: 2,
            ..RateLimitConfig::default()
        });

        // Both requests below the ceiling are admitted and then recorded.
        for _ in 0..2 {
            assert!(limiter.can_make_request(None));
            limiter.record_request(None);
        }

        assert!(!limiter.can_make_request(None));
    }

    /// A request is admitted only if its estimate fits the remaining token budget.
    #[test]
    fn test_token_limit() {
        let limiter = RateLimiter::new(RateLimitConfig {
            max_tokens_per_minute: 1000,
            ..RateLimitConfig::default()
        });

        assert!(limiter.can_make_request(Some(500)));
        limiter.record_request(Some(500));

        // 500 used: 400 more still fits, 600 more would exceed 1000.
        let fits = limiter.can_make_request(Some(400));
        let exceeds = limiter.can_make_request(Some(600));
        assert!(fits);
        assert!(!exceeds);
    }
}