// ai_agent/services/rate_limit.rs

//! Rate limiting for API requests.
//!
//! Provides request- and token-based rate limit tracking and enforcement,
//! similar to Claude Code.

use std::time::{Duration, Instant};

use serde::{Deserialize, Serialize};

8/// Rate limit information
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct RateLimit {
11    /// Utilization percentage (0-100)
12    pub utilization: f64,
13    /// Reset timestamp (ISO 8601)
14    pub resets_at: Option<String>,
15    /// Remaining requests
16    pub remaining: Option<u32>,
17    /// Total requests allowed
18    pub limit: Option<u32>,
19}
20
/// Configuration used to construct a `RateLimiter`.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Upper bound on requests admitted per minute.
    pub requests_per_minute: u32,
    /// Upper bound on tokens consumed per minute.
    pub tokens_per_minute: u32,
    /// Whether burst handling is enabled.
    ///
    /// NOTE(review): this flag is stored (and settable via the builder), but
    /// `RateLimiter::new` does not read it. Confirm the intended semantics.
    pub burst: bool,
}

impl Default for RateLimitConfig {
    /// 60 requests and 100 000 tokens per minute, with burst enabled.
    fn default() -> Self {
        const DEFAULT_REQUESTS_PER_MINUTE: u32 = 60;
        const DEFAULT_TOKENS_PER_MINUTE: u32 = 100_000;

        Self {
            requests_per_minute: DEFAULT_REQUESTS_PER_MINUTE,
            tokens_per_minute: DEFAULT_TOKENS_PER_MINUTE,
            burst: true,
        }
    }
}

/// Token bucket rate limiter.
///
/// Holds at most `capacity` tokens and refills continuously at a fixed rate.
/// The balance is kept as `f64` so that fractional refills accumulate across
/// calls instead of being truncated away. Frequent polling therefore cannot
/// starve the bucket.
///
/// Counts above 2^53 lose precision in the `f64` balance.
#[derive(Debug)]
pub struct TokenBucket {
    capacity: u64,
    /// Current balance, including any fractional token earned so far.
    tokens: f64,
    refill_rate: f64, // tokens per millisecond
    last_refill: Instant,
}

impl TokenBucket {
    /// Create a bucket that starts full with `capacity` tokens and refills at
    /// `refill_per_second` tokens per second.
    pub fn new(capacity: u64, refill_per_second: f64) -> Self {
        Self {
            capacity,
            tokens: capacity as f64,
            refill_rate: refill_per_second / 1000.0, // per millisecond
            last_refill: Instant::now(),
        }
    }

    /// Refill the bucket, then try to take `tokens` from it.
    ///
    /// On success, returns `true` and deducts the tokens. Otherwise returns
    /// `false` and leaves the (refilled) balance untouched.
    pub fn try_consume(&mut self, tokens: u64) -> bool {
        self.refill();

        let cost = tokens as f64;
        if self.tokens >= cost {
            self.tokens -= cost;
            true
        } else {
            false
        }
    }

    /// Credit the tokens earned since the last refill, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        // Full precision: `as_millis()` would discard sub-millisecond time.
        let elapsed_ms = now.duration_since(self.last_refill).as_secs_f64() * 1000.0;
        // Never let a negative (or NaN) rate drain the bucket. `f64::max`
        // returns 0.0 for a NaN input.
        let earned = (elapsed_ms * self.refill_rate).max(0.0);
        self.tokens = (self.tokens + earned).min(self.capacity as f64);
        // Safe to advance even for a fractional gain: the fraction is kept in
        // `self.tokens`.
        self.last_refill = now;
    }

    /// Whole tokens in the bucket as of the last refill.
    ///
    /// This does not refill, so the value may lag behind real time.
    pub fn available(&self) -> u64 {
        // `tokens` is never negative, so this cast floors the balance.
        self.tokens as u64
    }

    /// Refill the bucket to capacity and restart the refill clock.
    pub fn reset(&mut self) {
        self.tokens = self.capacity as f64;
        self.last_refill = Instant::now();
    }
}

/// Sliding window rate limiter.
///
/// Admits at most `max_requests` acquisitions within any trailing window of
/// the configured length.
#[derive(Debug)]
pub struct SlidingWindow {
    max_requests: u32,
    /// Window length, kept at full `Duration` precision so that short
    /// (sub-millisecond) windows are honoured.
    window: Duration,
    /// Timestamps of admitted requests. They are pushed with `Instant::now()`
    /// and pruned with `retain`, so they stay in chronological order.
    requests: Vec<Instant>,
}

impl SlidingWindow {
    /// Create an empty window admitting `max_requests` per `window_duration`.
    pub fn new(max_requests: u32, window_duration: Duration) -> Self {
        Self {
            max_requests,
            window: window_duration,
            requests: Vec::new(),
        }
    }

    /// Start of the window that ends at `now`.
    ///
    /// Falls back to `now` when the subtraction is not representable.
    fn window_start(&self, now: Instant) -> Instant {
        now.checked_sub(self.window).unwrap_or(now)
    }

    /// Drop expired timestamps, then record a request if there is room.
    ///
    /// Returns `true` when the request was recorded.
    pub fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let window_start = self.window_start(now);

        self.requests.retain(|&t| t > window_start);

        if self.requests.len() < self.max_requests as usize {
            self.requests.push(now);
            true
        } else {
            false
        }
    }

    /// Time until the oldest recorded request leaves the window.
    ///
    /// Returns `None` when fewer than `max_requests` timestamps are stored.
    /// Expired entries are not pruned here, so the result may be
    /// `Some(Duration::ZERO)` when a slot has already freed up.
    pub fn time_until_available(&self) -> Option<Duration> {
        if self.requests.len() < self.max_requests as usize {
            return None;
        }

        let oldest = *self.requests.iter().min()?;
        let window_end = oldest.checked_add(self.window).unwrap_or(oldest);

        Some(window_end.saturating_duration_since(Instant::now()))
    }

    /// Number of recorded requests still inside the window (read-only; no pruning).
    pub fn current_count(&self) -> u32 {
        let window_start = self.window_start(Instant::now());
        self.requests.iter().filter(|&&t| t > window_start).count() as u32
    }

    /// Forget every recorded request.
    pub fn reset(&mut self) {
        self.requests.clear();
    }
}

168/// Rate limiter that combines token bucket and sliding window
169#[derive(Debug)]
170pub struct RateLimiter {
171    request_limiter: SlidingWindow,
172    token_limiter: TokenBucket,
173}
174
175impl RateLimiter {
176    /// Create a new rate limiter
177    pub fn new(config: &RateLimitConfig) -> Self {
178        let request_limiter =
179            SlidingWindow::new(config.requests_per_minute, Duration::from_secs(60));
180        let token_limiter = TokenBucket::new(
181            config.tokens_per_minute as u64,
182            config.tokens_per_minute as f64 / 60.0,
183        );
184
185        Self {
186            request_limiter,
187            token_limiter,
188        }
189    }
190
191    /// Try to acquire rate limit slot for a request with given token count
192    pub fn try_acquire(&mut self, token_count: u64) -> bool {
193        self.request_limiter.try_acquire() && self.token_limiter.try_consume(token_count)
194    }
195
196    /// Wait until rate limit is available
197    pub async fn acquire(&mut self, token_count: u64) {
198        while !self.try_acquire(token_count) {
199            // Wait for either request slot or token refill
200            let request_wait = self.request_limiter.time_until_available();
201            let token_wait = if self.token_limiter.available() < token_count {
202                // Estimate wait time based on deficit
203                let deficit = token_count - self.token_limiter.available();
204                let refill_rate = 1000.0 / 60.0; // tokens per ms
205                Some(Duration::from_millis((deficit as f64 / refill_rate) as u64))
206            } else {
207                None
208            };
209
210            // Wait for the shorter duration
211            let wait_time = match (request_wait, token_wait) {
212                (Some(a), Some(b)) => std::cmp::min(a, b),
213                (Some(a), None) => a,
214                (None, Some(b)) => b,
215                (None, None) => Duration::from_millis(100),
216            };
217
218            tokio::time::sleep(wait_time).await;
219        }
220    }
221
222    /// Get current status
223    pub fn status(&self) -> RateLimitStatus {
224        RateLimitStatus {
225            requests_remaining: self.request_limiter.max_requests
226                - self.request_limiter.current_count(),
227            tokens_remaining: self.token_limiter.available() as u32,
228        }
229    }
230
231    /// Reset the limiter
232    pub fn reset(&mut self) {
233        self.request_limiter.reset();
234        self.token_limiter.reset();
235    }
236}
237
238/// Current rate limit status
239#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct RateLimitStatus {
241    pub requests_remaining: u32,
242    pub tokens_remaining: u32,
243}
244
245/// Builder for rate limiter
246pub struct RateLimiterBuilder {
247    config: RateLimitConfig,
248}
249
250impl RateLimiterBuilder {
251    pub fn new() -> Self {
252        Self {
253            config: RateLimitConfig::default(),
254        }
255    }
256
257    pub fn requests_per_minute(mut self, rpm: u32) -> Self {
258        self.config.requests_per_minute = rpm;
259        self
260    }
261
262    pub fn tokens_per_minute(mut self, tpm: u32) -> Self {
263        self.config.tokens_per_minute = tpm;
264        self
265    }
266
267    pub fn burst(mut self, enable: bool) -> Self {
268        self.config.burst = enable;
269        self
270    }
271
272    pub fn build(self) -> RateLimiter {
273        RateLimiter::new(&self.config)
274    }
275}
276
277impl Default for RateLimiterBuilder {
278    fn default() -> Self {
279        Self::new()
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_bucket() {
        // Capacity 10, refilling at 2 tokens per second.
        let mut bucket = TokenBucket::new(10, 2.0);

        // Two halves drain the bucket exactly; a third request must fail.
        for _ in 0..2 {
            assert!(bucket.try_consume(5));
        }
        assert!(!bucket.try_consume(1));

        // ~600 ms at 2 tokens/s is enough for at least one token.
        let refill_pause = Duration::from_millis(600);
        std::thread::sleep(refill_pause);
        assert!(bucket.try_consume(1));
    }

    #[test]
    fn test_sliding_window() {
        let window_len = Duration::from_millis(100);
        let mut window = SlidingWindow::new(3, window_len);

        // The first three acquisitions fit; the fourth does not.
        let admitted = (0..3).filter(|_| window.try_acquire()).count();
        assert_eq!(admitted, 3);
        assert!(!window.try_acquire());

        // Once the window has slid past the earlier requests, a slot frees up.
        std::thread::sleep(Duration::from_millis(150));
        assert!(window.try_acquire());
    }

    #[test]
    fn test_sliding_window_count() {
        let mut window = SlidingWindow::new(5, Duration::from_secs(1));
        assert_eq!(window.current_count(), 0);

        for _ in 0..2 {
            window.try_acquire();
        }
        assert_eq!(window.current_count(), 2);
    }

    #[test]
    fn test_rate_limiter_builder() {
        let status = RateLimiterBuilder::new()
            .requests_per_minute(100)
            .tokens_per_minute(50000)
            .build()
            .status();

        assert_eq!(status.requests_remaining, 100);
    }

    #[tokio::test]
    async fn test_rate_limiter_acquire() {
        let mut limiter = RateLimiterBuilder::new()
            .requests_per_minute(10)
            .tokens_per_minute(1000)
            .build();

        // A fresh limiter has room, so this resolves without waiting.
        limiter.acquire(100).await;

        let remaining = limiter.status().requests_remaining;
        assert!(remaining < 10);
    }
}