// ricecoder_providers/rate_limiter.rs

1//! Rate limiting for API calls
2//!
3//! This module provides rate limiting functionality to prevent exceeding provider limits
4//! and to implement backoff strategies for rate limit errors.
5
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9
/// Token bucket rate limiter
///
/// Implements the token bucket algorithm for rate limiting API calls.
/// Tokens accrue continuously at `tokens_per_second` up to `max_tokens`;
/// each request consumes tokens, and a request is rate limited when the
/// bucket holds fewer tokens than it asks for.
///
/// Not thread-safe on its own; wrap it in a `Mutex` for shared use
/// (as `RateLimiterRegistry` in this module does).
pub struct TokenBucketLimiter {
    /// Tokens added per second (refill rate).
    tokens_per_second: f64,
    /// Maximum tokens the bucket can hold (burst capacity).
    max_tokens: f64,
    /// Tokens currently available; the bucket starts full at `max_tokens`.
    tokens: f64,
    /// Instant of the last refill; elapsed time since then is converted
    /// into new tokens on every operation.
    last_refill: Instant,
}
25
26impl TokenBucketLimiter {
27    /// Create a new token bucket limiter
28    ///
29    /// # Arguments
30    /// * `tokens_per_second` - Rate at which tokens are added (e.g., 10 for 10 requests/sec)
31    /// * `max_tokens` - Maximum tokens in bucket (burst capacity)
32    pub fn new(tokens_per_second: f64, max_tokens: f64) -> Self {
33        Self {
34            tokens_per_second,
35            max_tokens,
36            tokens: max_tokens,
37            last_refill: Instant::now(),
38        }
39    }
40
41    /// Refill tokens based on elapsed time
42    fn refill(&mut self) {
43        let now = Instant::now();
44        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
45        let new_tokens = elapsed * self.tokens_per_second;
46        self.tokens = (self.tokens + new_tokens).min(self.max_tokens);
47        self.last_refill = now;
48    }
49
50    /// Try to acquire tokens
51    ///
52    /// Returns true if tokens were acquired, false if rate limited
53    pub fn try_acquire(&mut self, tokens: f64) -> bool {
54        self.refill();
55        if self.tokens >= tokens {
56            self.tokens -= tokens;
57            true
58        } else {
59            false
60        }
61    }
62
63    /// Wait until tokens are available
64    ///
65    /// Blocks until the specified number of tokens are available
66    pub async fn acquire(&mut self, tokens: f64) {
67        loop {
68            if self.try_acquire(tokens) {
69                return;
70            }
71            // Wait a bit before trying again
72            tokio::time::sleep(Duration::from_millis(10)).await;
73        }
74    }
75
76    /// Get current token count
77    pub fn current_tokens(&mut self) -> f64 {
78        self.refill();
79        self.tokens
80    }
81
82    /// Get time until tokens are available
83    pub fn time_until_available(&mut self, tokens: f64) -> Duration {
84        self.refill();
85        if self.tokens >= tokens {
86            Duration::from_secs(0)
87        } else {
88            let needed = tokens - self.tokens;
89            let seconds = needed / self.tokens_per_second;
90            Duration::from_secs_f64(seconds)
91        }
92    }
93}
94
/// Exponential backoff strategy
///
/// Implements exponential backoff with jitter for retrying failed requests.
/// Each call to `next_delay` grows the delay by `multiplier`, caps it at
/// `max_delay`, and applies random jitter to avoid thundering herds.
pub struct ExponentialBackoff {
    /// Delay used for the first attempt (attempt 0).
    initial_delay: Duration,
    /// Upper bound on the computed delay.
    max_delay: Duration,
    /// Growth factor applied per attempt (e.g. 2.0 doubles each retry).
    multiplier: f64,
    /// Number of delays handed out so far; cleared by `reset()`.
    attempt: u32,
}
108
109impl ExponentialBackoff {
110    /// Create a new exponential backoff strategy
111    ///
112    /// # Arguments
113    /// * `initial_delay` - Initial backoff duration (e.g., 100ms)
114    /// * `max_delay` - Maximum backoff duration (e.g., 30s)
115    /// * `multiplier` - Backoff multiplier (e.g., 2.0 for doubling)
116    pub fn new(initial_delay: Duration, max_delay: Duration, multiplier: f64) -> Self {
117        Self {
118            initial_delay,
119            max_delay,
120            multiplier,
121            attempt: 0,
122        }
123    }
124
125    /// Get the next backoff duration
126    pub fn next_delay(&mut self) -> Duration {
127        let delay = self.initial_delay.as_secs_f64()
128            * self.multiplier.powi(self.attempt as i32);
129        let delay = Duration::from_secs_f64(delay);
130        let delay = delay.min(self.max_delay);
131
132        // Add jitter (±10%)
133        let jitter = delay.as_secs_f64() * 0.1;
134        let jitter_offset = (rand::random::<f64>() - 0.5) * 2.0 * jitter;
135        let final_delay = (delay.as_secs_f64() + jitter_offset).max(0.0);
136
137        self.attempt += 1;
138        Duration::from_secs_f64(final_delay)
139    }
140
141    /// Reset backoff counter
142    pub fn reset(&mut self) {
143        self.attempt = 0;
144    }
145
146    /// Get current attempt number
147    pub fn attempt(&self) -> u32 {
148        self.attempt
149    }
150}
151
152/// Per-provider rate limiter registry
153pub struct RateLimiterRegistry {
154    limiters: Arc<Mutex<HashMap<String, TokenBucketLimiter>>>,
155}
156
157impl RateLimiterRegistry {
158    /// Create a new rate limiter registry
159    pub fn new() -> Self {
160        Self {
161            limiters: Arc::new(Mutex::new(HashMap::new())),
162        }
163    }
164
165    /// Register a rate limiter for a provider
166    pub fn register(&self, provider_id: &str, limiter: TokenBucketLimiter) {
167        let mut limiters = self.limiters.lock().unwrap();
168        limiters.insert(provider_id.to_string(), limiter);
169    }
170
171    /// Get or create a rate limiter for a provider
172    pub fn get_or_create(&self, provider_id: &str) -> Arc<Mutex<TokenBucketLimiter>> {
173        let mut limiters = self.limiters.lock().unwrap();
174
175        // Return existing limiter if available
176        if limiters.contains_key(provider_id) {
177            // We need to return a reference, but we can't hold the lock
178            // So we'll create a new Arc for each call
179            drop(limiters);
180            return Arc::new(Mutex::new(TokenBucketLimiter::new(10.0, 100.0)));
181        }
182
183        // Create default limiter (10 requests/sec, burst of 100)
184        let limiter = TokenBucketLimiter::new(10.0, 100.0);
185        limiters.insert(provider_id.to_string(), limiter);
186        drop(limiters);
187
188        Arc::new(Mutex::new(TokenBucketLimiter::new(10.0, 100.0)))
189    }
190
191    /// Try to acquire tokens for a provider
192    pub fn try_acquire(&self, provider_id: &str, tokens: f64) -> bool {
193        let mut limiters = self.limiters.lock().unwrap();
194        if let Some(limiter) = limiters.get_mut(provider_id) {
195            limiter.try_acquire(tokens)
196        } else {
197            // No limiter registered, allow request
198            true
199        }
200    }
201
202    /// Wait until tokens are available for a provider
203    pub async fn acquire(&self, provider_id: &str, tokens: f64) {
204        loop {
205            if self.try_acquire(provider_id, tokens) {
206                return;
207            }
208            tokio::time::sleep(Duration::from_millis(10)).await;
209        }
210    }
211}
212
213impl Default for RateLimiterRegistry {
214    fn default() -> Self {
215        Self::new()
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_bucket_acquire() {
        let mut bucket = TokenBucketLimiter::new(10.0, 100.0);
        assert!(bucket.try_acquire(50.0));
        // Roughly half the bucket should remain; tolerate float drift.
        let remaining = bucket.current_tokens();
        assert!((remaining - 50.0).abs() < 0.1);
    }

    #[test]
    fn test_token_bucket_rate_limited() {
        let mut bucket = TokenBucketLimiter::new(10.0, 100.0);
        // Drain the bucket completely; the very next request must fail.
        assert!(bucket.try_acquire(100.0));
        assert!(!bucket.try_acquire(1.0));
    }

    #[test]
    fn test_token_bucket_refill() {
        let mut bucket = TokenBucketLimiter::new(10.0, 100.0);
        assert!(bucket.try_acquire(100.0));
        // After sleeping, elapsed time should have restored some tokens.
        std::thread::sleep(Duration::from_millis(150));
        assert!(bucket.current_tokens() > 0.0);
    }

    #[test]
    fn test_exponential_backoff() {
        let mut backoff =
            ExponentialBackoff::new(Duration::from_millis(100), Duration::from_secs(10), 2.0);

        // Successive delays double; each bound allows the ±10% jitter.
        let expected_bounds = [(90, 110), (180, 220), (360, 440)];
        for (lo, hi) in expected_bounds {
            let millis = backoff.next_delay().as_millis();
            assert!(millis >= lo && millis <= hi);
        }
    }

    #[test]
    fn test_exponential_backoff_max_delay() {
        let mut backoff =
            ExponentialBackoff::new(Duration::from_millis(100), Duration::from_secs(1), 2.0);

        // Advance well past the point where the cap kicks in.
        for _ in 0..10 {
            backoff.next_delay();
        }

        // Delay stays capped near max_delay (tolerance for jitter).
        assert!(backoff.next_delay() <= Duration::from_millis(1100));
    }

    #[test]
    fn test_exponential_backoff_reset() {
        let mut backoff =
            ExponentialBackoff::new(Duration::from_millis(100), Duration::from_secs(10), 2.0);

        backoff.next_delay();
        backoff.next_delay();
        assert_eq!(backoff.attempt(), 2);

        // Reset returns the counter to the initial attempt.
        backoff.reset();
        assert_eq!(backoff.attempt(), 0);
    }

    #[test]
    fn test_rate_limiter_registry() {
        let registry = RateLimiterRegistry::new();
        registry.register("openai", TokenBucketLimiter::new(10.0, 100.0));

        // Two half-bucket requests drain the limiter; a third fails.
        assert!(registry.try_acquire("openai", 50.0));
        assert!(registry.try_acquire("openai", 50.0));
        assert!(!registry.try_acquire("openai", 1.0));
    }

    #[test]
    fn test_rate_limiter_registry_unknown_provider() {
        // A provider with no registered limiter is never throttled.
        let registry = RateLimiterRegistry::new();
        assert!(registry.try_acquire("unknown", 1000.0));
    }
}