ai_agent/services/
rate_limit.rs1use serde::{Deserialize, Serialize};
6use std::time::{Duration, Instant};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct RateLimit {
11 pub utilization: f64,
13 pub resets_at: Option<String>,
15 pub remaining: Option<u32>,
17 pub limit: Option<u32>,
19}
20
/// Settings used to construct a rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum number of requests admitted per minute.
    pub requests_per_minute: u32,
    /// Maximum number of tokens consumed per minute.
    pub tokens_per_minute: u32,
    /// Burst flag; defaults to `true`.
    pub burst: bool,
}

impl Default for RateLimitConfig {
    /// Returns the stock configuration: 60 requests and 100 000 tokens per
    /// minute, with `burst` switched on.
    fn default() -> Self {
        RateLimitConfig {
            burst: true,
            tokens_per_minute: 100_000,
            requests_per_minute: 60,
        }
    }
}
41
/// Token bucket that starts full and refills continuously over time.
#[derive(Debug)]
pub struct TokenBucket {
    /// Maximum number of tokens the bucket can hold.
    capacity: u64,
    /// Tokens currently stored (as of the last refill).
    tokens: u64,
    /// Refill speed in tokens per millisecond.
    refill_rate: f64,
    /// Point in time up to which refill credit has already been granted.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a full bucket holding `capacity` tokens that regains
    /// `refill_per_second` tokens every second.
    ///
    /// A zero, negative or NaN rate yields a bucket that never refills on its
    /// own (only [`TokenBucket::reset`] fills it again).
    pub fn new(capacity: u64, refill_per_second: f64) -> Self {
        Self {
            capacity,
            tokens: capacity,
            // Stored per millisecond.
            refill_rate: refill_per_second / 1000.0,
            last_refill: Instant::now(),
        }
    }

    /// Refills the bucket, then removes `tokens` from it if that many are
    /// available.
    ///
    /// Returns `true` when the tokens were taken and `false` (taking nothing)
    /// otherwise. Asking for more than the capacity can never succeed.
    pub fn try_consume(&mut self, tokens: u64) -> bool {
        self.refill();

        match self.tokens.checked_sub(tokens) {
            Some(left) => {
                self.tokens = left;
                true
            }
            None => false,
        }
    }

    /// Credits the tokens earned since `last_refill`, capped at `capacity`.
    ///
    /// Only whole tokens are credited. The time that paid for them is added to
    /// `last_refill`, so the fractional remainder keeps accumulating instead of
    /// being thrown away on every call (which would starve callers that poll
    /// faster than one token interval).
    fn refill(&mut self) {
        let now = Instant::now();

        // A full bucket earns no credit, and neither does a bucket whose rate
        // is zero, negative or NaN (`>` is false for NaN).
        let refills = self.refill_rate > 0.0;
        if self.tokens >= self.capacity || !refills {
            self.last_refill = now;
            return;
        }

        let elapsed_ms = now.saturating_duration_since(self.last_refill).as_secs_f64() * 1000.0;
        let earned = (elapsed_ms * self.refill_rate).floor();
        if earned.is_nan() || earned < 1.0 {
            // Not a whole token yet: leave `last_refill` alone so the partial
            // credit is still there on the next call.
            return;
        }

        // The float-to-int cast saturates, so add with saturation as well.
        let refilled = self.tokens.saturating_add(earned as u64);
        if refilled >= self.capacity {
            self.tokens = self.capacity;
            self.last_refill = now;
            return;
        }

        self.tokens = refilled;
        // `earned` is finite and at most `elapsed_ms * refill_rate` here, so the
        // quotient is a finite, non-negative number of milliseconds.
        let paid_for = Duration::from_secs_f64(earned / self.refill_rate / 1000.0);
        self.last_refill = self
            .last_refill
            .checked_add(paid_for)
            .map_or(now, |stamp| stamp.min(now));
    }

    /// Returns the stored token count.
    ///
    /// This does not trigger a refill, so the value reflects the state after
    /// the most recent [`TokenBucket::try_consume`] or [`TokenBucket::reset`].
    pub fn available(&self) -> u64 {
        self.tokens
    }

    /// Fills the bucket to capacity and restarts the refill clock.
    pub fn reset(&mut self) {
        self.tokens = self.capacity;
        self.last_refill = Instant::now();
    }
}
94
/// Sliding-window request counter: at most `max_requests` acquisitions may
/// fall inside any trailing window of `window_ms` milliseconds.
#[derive(Debug)]
pub struct SlidingWindow {
    /// Maximum number of acquisitions allowed inside one window.
    max_requests: u32,
    /// Window length in milliseconds.
    window_ms: u64,
    /// Timestamps of granted acquisitions, oldest first.
    requests: Vec<Instant>,
}

impl SlidingWindow {
    /// Creates an empty window admitting `max_requests` per `window_duration`.
    ///
    /// The duration is stored with millisecond resolution. Durations too long
    /// to fit are clamped to `u64::MAX` milliseconds rather than wrapped, so a
    /// very long window cannot silently turn into a very short one.
    pub fn new(max_requests: u32, window_duration: Duration) -> Self {
        let window_ms = window_duration.as_millis().min(u128::from(u64::MAX)) as u64;
        Self {
            max_requests,
            window_ms,
            requests: Vec::new(),
        }
    }

    /// Window length as a `Duration`.
    fn window(&self) -> Duration {
        Duration::from_millis(self.window_ms)
    }

    /// Drops expired timestamps and, if there is room, records a new one.
    ///
    /// Returns `true` when the acquisition was recorded, `false` when the
    /// window is already full. With `max_requests == 0` this is always `false`.
    pub fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let window = self.window();

        // Compare ages instead of computing `now - window`: that subtraction is
        // not representable on every platform, and falling back to `now` would
        // expire every entry and switch the limit off.
        self.requests
            .retain(|&stamp| now.saturating_duration_since(stamp) < window);

        if self.requests.len() < self.max_requests as usize {
            self.requests.push(now);
            true
        } else {
            false
        }
    }

    /// Returns how long until the oldest recorded acquisition leaves the
    /// window.
    ///
    /// * `None` when fewer than `max_requests` timestamps are stored (this
    ///   includes the `max_requests == 0` case, where nothing is ever stored).
    /// * `Some(Duration::ZERO)` when the oldest stored timestamp has already
    ///   expired but has not been pruned yet.
    pub fn time_until_available(&self) -> Option<Duration> {
        if self.requests.len() < self.max_requests as usize {
            return None;
        }

        // Timestamps are pushed in clock order and `Instant` never goes
        // backwards, so the first entry is the oldest.
        let oldest = *self.requests.first()?;
        let age = Instant::now().saturating_duration_since(oldest);
        Some(self.window().saturating_sub(age))
    }

    /// Counts the stored timestamps that are still inside the window, without
    /// pruning.
    pub fn current_count(&self) -> u32 {
        let now = Instant::now();
        let window = self.window();

        let live = self
            .requests
            .iter()
            .filter(|&&stamp| now.saturating_duration_since(stamp) < window)
            .count();
        live.min(u32::MAX as usize) as u32
    }

    /// Forgets every recorded acquisition.
    pub fn reset(&mut self) {
        self.requests.clear();
    }
}
167
168#[derive(Debug)]
170pub struct RateLimiter {
171 request_limiter: SlidingWindow,
172 token_limiter: TokenBucket,
173}
174
175impl RateLimiter {
176 pub fn new(config: &RateLimitConfig) -> Self {
178 let request_limiter =
179 SlidingWindow::new(config.requests_per_minute, Duration::from_secs(60));
180 let token_limiter = TokenBucket::new(
181 config.tokens_per_minute as u64,
182 config.tokens_per_minute as f64 / 60.0,
183 );
184
185 Self {
186 request_limiter,
187 token_limiter,
188 }
189 }
190
191 pub fn try_acquire(&mut self, token_count: u64) -> bool {
193 self.request_limiter.try_acquire() && self.token_limiter.try_consume(token_count)
194 }
195
196 pub async fn acquire(&mut self, token_count: u64) {
198 while !self.try_acquire(token_count) {
199 let request_wait = self.request_limiter.time_until_available();
201 let token_wait = if self.token_limiter.available() < token_count {
202 let deficit = token_count - self.token_limiter.available();
204 let refill_rate = 1000.0 / 60.0; Some(Duration::from_millis((deficit as f64 / refill_rate) as u64))
206 } else {
207 None
208 };
209
210 let wait_time = match (request_wait, token_wait) {
212 (Some(a), Some(b)) => std::cmp::min(a, b),
213 (Some(a), None) => a,
214 (None, Some(b)) => b,
215 (None, None) => Duration::from_millis(100),
216 };
217
218 tokio::time::sleep(wait_time).await;
219 }
220 }
221
222 pub fn status(&self) -> RateLimitStatus {
224 RateLimitStatus {
225 requests_remaining: self.request_limiter.max_requests
226 - self.request_limiter.current_count(),
227 tokens_remaining: self.token_limiter.available() as u32,
228 }
229 }
230
231 pub fn reset(&mut self) {
233 self.request_limiter.reset();
234 self.token_limiter.reset();
235 }
236}
237
238#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct RateLimitStatus {
241 pub requests_remaining: u32,
242 pub tokens_remaining: u32,
243}
244
245pub struct RateLimiterBuilder {
247 config: RateLimitConfig,
248}
249
250impl RateLimiterBuilder {
251 pub fn new() -> Self {
252 Self {
253 config: RateLimitConfig::default(),
254 }
255 }
256
257 pub fn requests_per_minute(mut self, rpm: u32) -> Self {
258 self.config.requests_per_minute = rpm;
259 self
260 }
261
262 pub fn tokens_per_minute(mut self, tpm: u32) -> Self {
263 self.config.tokens_per_minute = tpm;
264 self
265 }
266
267 pub fn burst(mut self, enable: bool) -> Self {
268 self.config.burst = enable;
269 self
270 }
271
272 pub fn build(self) -> RateLimiter {
273 RateLimiter::new(&self.config)
274 }
275}
276
277impl Default for RateLimiterBuilder {
278 fn default() -> Self {
279 Self::new()
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// A drained bucket rejects consumption until enough time has passed.
    #[test]
    fn test_token_bucket() {
        let mut bucket = TokenBucket::new(10, 2.0);

        // Two draws of five empty the ten-token bucket.
        for _ in 0..2 {
            assert!(bucket.try_consume(5));
        }
        assert!(!bucket.try_consume(1));

        // 600 ms at two tokens per second is enough for one whole token.
        let pause = Duration::from_millis(600);
        std::thread::sleep(pause);
        assert!(bucket.try_consume(1));
    }

    /// The window admits exactly three requests, then frees up after expiry.
    #[test]
    fn test_sliding_window() {
        let mut window = SlidingWindow::new(3, Duration::from_millis(100));

        for _ in 0..3 {
            assert!(window.try_acquire());
        }
        assert!(!window.try_acquire());

        // Sleep past the 100 ms window so the earlier entries expire.
        let pause = Duration::from_millis(150);
        std::thread::sleep(pause);
        assert!(window.try_acquire());
    }

    /// `current_count` follows the number of acquisitions made.
    #[test]
    fn test_sliding_window_count() {
        let mut window = SlidingWindow::new(5, Duration::from_secs(1));
        assert_eq!(window.current_count(), 0);

        for _ in 0..2 {
            window.try_acquire();
        }
        assert_eq!(window.current_count(), 2);
    }

    /// A freshly built limiter reports its whole request quota as remaining.
    #[test]
    fn test_rate_limiter_builder() {
        let builder = RateLimiterBuilder::new()
            .requests_per_minute(100)
            .tokens_per_minute(50000);
        let limiter = builder.build();

        let status = limiter.status();
        assert_eq!(status.requests_remaining, 100);
    }

    /// A successful `acquire` uses up one request slot.
    #[tokio::test]
    async fn test_rate_limiter_acquire() {
        let builder = RateLimiterBuilder::new()
            .requests_per_minute(10)
            .tokens_per_minute(1000);
        let mut limiter = builder.build();

        limiter.acquire(100).await;

        let status = limiter.status();
        assert!(status.requests_remaining < 10);
    }
}