oxify_connect_llm/
rate_limit.rs

1//! Rate limiting for LLM providers.
2//!
3//! Prevents hitting API rate limits by throttling requests based on configured limits.
4//! Supports both request-based and token-based rate limiting with a token bucket algorithm.
5//!
6//! # Example
7//!
8//! ```rust,no_run
9//! use oxify_connect_llm::{RateLimitProvider, RateLimitConfig, LlmProvider};
10//!
11//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
12//! # let provider: Box<dyn LlmProvider> = todo!();
13//! // Limit to 60 requests per minute
14//! let config = RateLimitConfig::new()
15//!     .with_requests_per_minute(60)
16//!     .with_tokens_per_minute(100_000);
17//!
18//! let rate_limited = RateLimitProvider::new(provider, config);
19//! # Ok(())
20//! # }
21//! ```
22
23use crate::{LlmError, LlmProvider, LlmRequest, LlmResponse};
24use async_trait::async_trait;
25use std::sync::Arc;
26use std::time::{Duration, Instant};
27use tokio::sync::Mutex;
28
/// Rate limit configuration
///
/// Limits are expressed per minute; the builder methods also derive a
/// per-second refill rate used by the token-bucket implementation.
/// A value of zero disables the corresponding limit.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum requests per minute (0 = unlimited)
    pub requests_per_minute: u32,
    /// Maximum tokens per minute (0 = unlimited)
    pub tokens_per_minute: u32,
    /// Refill rate for token bucket (requests per second);
    /// derived in `with_requests_per_minute` as `requests_per_minute / 60`.
    refill_rate: f64,
    /// Refill rate for tokens (tokens per second);
    /// derived in `with_tokens_per_minute` as `tokens_per_minute / 60`.
    token_refill_rate: f64,
}
41
42impl Default for RateLimitConfig {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl RateLimitConfig {
49    /// Create a new rate limit configuration with no limits
50    pub fn new() -> Self {
51        Self {
52            requests_per_minute: 0,
53            tokens_per_minute: 0,
54            refill_rate: 0.0,
55            token_refill_rate: 0.0,
56        }
57    }
58
59    /// Set the maximum requests per minute
60    pub fn with_requests_per_minute(mut self, rpm: u32) -> Self {
61        self.requests_per_minute = rpm;
62        self.refill_rate = rpm as f64 / 60.0; // Convert to per second
63        self
64    }
65
66    /// Set the maximum tokens per minute
67    pub fn with_tokens_per_minute(mut self, tpm: u32) -> Self {
68        self.tokens_per_minute = tpm;
69        self.token_refill_rate = tpm as f64 / 60.0; // Convert to per second
70        self
71    }
72}
73
/// Token bucket for rate limiting
///
/// Classic token-bucket: `capacity` caps the burst size, and tokens are
/// replenished continuously at `refill_rate` tokens per second.
#[derive(Debug)]
struct TokenBucket {
    // Maximum number of tokens the bucket can hold (burst size).
    capacity: f64,
    // Tokens currently available; starts full (see `new`).
    tokens: f64,
    // Tokens added per second of elapsed time.
    refill_rate: f64,
    // Timestamp of the last refill, used to compute elapsed time.
    last_refill: Instant,
}

impl TokenBucket {
    /// Create a bucket that starts full with `capacity` tokens and refills
    /// at `refill_rate` tokens per second.
    fn new(capacity: u32, refill_rate: f64) -> Self {
        Self {
            capacity: capacity as f64,
            tokens: capacity as f64,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Add tokens for the time elapsed since the last refill, capped at capacity.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        let new_tokens = elapsed * self.refill_rate;

        self.tokens = (self.tokens + new_tokens).min(self.capacity);
        self.last_refill = now;
    }

    /// Try to take `count` tokens; returns `true` on success, `false` when
    /// the bucket does not currently hold enough (nothing is consumed then).
    fn try_acquire(&mut self, count: f64) -> bool {
        self.refill();

        if self.tokens >= count {
            self.tokens -= count;
            true
        } else {
            false
        }
    }

    /// How long until `count` tokens will be available (zero if they already are).
    ///
    /// Returns `Duration::MAX` when the deficit can never be covered (zero or
    /// negative refill rate) instead of panicking: the previous implementation
    /// divided by `refill_rate` unconditionally, and `Duration::from_secs_f64`
    /// panics on the resulting infinite/NaN value.
    fn wait_time(&mut self, count: f64) -> Duration {
        self.refill();

        if self.tokens >= count {
            return Duration::ZERO;
        }

        let deficit = count - self.tokens;
        if self.refill_rate <= 0.0 {
            // No refill will ever happen; signal "wait forever".
            return Duration::MAX;
        }

        let wait_secs = deficit / self.refill_rate;
        // from_secs_f64 panics when the value does not fit in a Duration,
        // so saturate for absurdly large deficits as well.
        if wait_secs.is_finite() && wait_secs < u64::MAX as f64 {
            Duration::from_secs_f64(wait_secs)
        } else {
            Duration::MAX
        }
    }
}
125
/// Rate limiter state
///
/// Each bucket is `None` when the corresponding limit is disabled
/// (configured as 0), so disabled limits cost nothing per request.
#[derive(Debug)]
struct RateLimiterState {
    // Bucket limiting requests per minute, if that limit is enabled.
    request_bucket: Option<TokenBucket>,
    // Bucket limiting (estimated) tokens per minute, if that limit is enabled.
    token_bucket: Option<TokenBucket>,
}
132
133impl RateLimiterState {
134    fn new(config: &RateLimitConfig) -> Self {
135        let request_bucket = if config.requests_per_minute > 0 {
136            Some(TokenBucket::new(
137                config.requests_per_minute,
138                config.refill_rate,
139            ))
140        } else {
141            None
142        };
143
144        let token_bucket = if config.tokens_per_minute > 0 {
145            Some(TokenBucket::new(
146                config.tokens_per_minute,
147                config.token_refill_rate,
148            ))
149        } else {
150            None
151        };
152
153        Self {
154            request_bucket,
155            token_bucket,
156        }
157    }
158
159    async fn acquire(&mut self, estimated_tokens: u32) -> Result<(), Duration> {
160        // Check request limit
161        if let Some(bucket) = &mut self.request_bucket {
162            if !bucket.try_acquire(1.0) {
163                return Err(bucket.wait_time(1.0));
164            }
165        }
166
167        // Check token limit
168        if let Some(bucket) = &mut self.token_bucket {
169            if !bucket.try_acquire(estimated_tokens as f64) {
170                return Err(bucket.wait_time(estimated_tokens as f64));
171            }
172        }
173
174        Ok(())
175    }
176}
177
/// Rate limiting provider wrapper
///
/// Decorates any [`LlmProvider`], throttling `complete` calls through the
/// shared limiter state before delegating to the inner provider.
pub struct RateLimitProvider {
    // The wrapped provider that performs the real completion.
    provider: Box<dyn LlmProvider>,
    // Shared bucket state; `tokio::sync::Mutex` because it is locked from
    // async contexts.
    state: Arc<Mutex<RateLimiterState>>,
    // Retained so `get_stats` can report the configured limits.
    config: RateLimitConfig,
}
184
185impl RateLimitProvider {
186    /// Create a new rate-limited provider
187    pub fn new(provider: Box<dyn LlmProvider>, config: RateLimitConfig) -> Self {
188        let state = Arc::new(Mutex::new(RateLimiterState::new(&config)));
189        Self {
190            provider,
191            state,
192            config,
193        }
194    }
195
196    /// Get current rate limit statistics
197    pub async fn get_stats(&self) -> RateLimitStats {
198        let state = self.state.lock().await;
199
200        let available_requests = state
201            .request_bucket
202            .as_ref()
203            .map(|b| b.tokens as u32)
204            .unwrap_or(0);
205
206        let available_tokens = state
207            .token_bucket
208            .as_ref()
209            .map(|b| b.tokens as u32)
210            .unwrap_or(0);
211
212        RateLimitStats {
213            requests_per_minute: self.config.requests_per_minute,
214            tokens_per_minute: self.config.tokens_per_minute,
215            available_requests,
216            available_tokens,
217        }
218    }
219
220    /// Estimate tokens in a request (simple heuristic: 4 chars per token)
221    fn estimate_tokens(request: &LlmRequest) -> u32 {
222        let prompt_len = request.prompt.len();
223        let system_len = request.system_prompt.as_ref().map(|s| s.len()).unwrap_or(0);
224        let total_chars = prompt_len + system_len;
225
226        // Simple heuristic: ~4 characters per token
227        ((total_chars / 4) as u32).max(1)
228    }
229}
230
/// Rate limit statistics
///
/// Snapshot returned by [`RateLimitProvider::get_stats`]. The `available_*`
/// fields are 0 when the corresponding limit is disabled.
#[derive(Debug, Clone)]
pub struct RateLimitStats {
    /// Maximum requests per minute (0 = unlimited)
    pub requests_per_minute: u32,
    /// Maximum tokens per minute (0 = unlimited)
    pub tokens_per_minute: u32,
    /// Currently available requests before throttling kicks in
    pub available_requests: u32,
    /// Currently available tokens before throttling kicks in
    pub available_tokens: u32,
}
243
244#[async_trait]
245impl LlmProvider for RateLimitProvider {
246    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
247        let estimated_tokens = Self::estimate_tokens(&request);
248
249        // Try to acquire rate limit tokens
250        loop {
251            let result = {
252                let mut state = self.state.lock().await;
253                state.acquire(estimated_tokens).await
254            };
255
256            match result {
257                Ok(()) => break,
258                Err(wait_time) => {
259                    // Wait and retry
260                    tokio::time::sleep(wait_time).await;
261                }
262            }
263        }
264
265        // Make the actual request
266        self.provider.complete(request).await
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Usage;
    use std::sync::atomic::{AtomicU32, Ordering};

    // Minimal provider double: records how many times `complete` ran and
    // always succeeds with a fixed response.
    struct MockProvider {
        call_count: Arc<AtomicU32>,
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(&self, _request: LlmRequest) -> Result<LlmResponse, LlmError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(LlmResponse {
                content: "Success".to_string(),
                model: "mock".to_string(),
                usage: Some(Usage {
                    prompt_tokens: 10,
                    completion_tokens: 20,
                    total_tokens: 30,
                }),
                tool_calls: Vec::new(),
            })
        }
    }

    // Requests under the configured burst capacity must pass without delay.
    // NOTE: timing-based assertion; generous 1s bound to avoid CI flakiness.
    #[tokio::test]
    async fn test_rate_limit_requests() {
        let call_count = Arc::new(AtomicU32::new(0));
        let mock = MockProvider {
            call_count: Arc::clone(&call_count),
        };

        // Allow 10 requests per minute
        let config = RateLimitConfig::new().with_requests_per_minute(10);
        let rate_limited = RateLimitProvider::new(Box::new(mock), config);

        // Make 5 requests - should all succeed quickly
        let start = Instant::now();
        for _ in 0..5 {
            let request = LlmRequest {
                prompt: "test".to_string(),
                system_prompt: None,
                temperature: None,
                max_tokens: None,
                tools: Vec::new(),
                images: Vec::new(),
            };
            rate_limited.complete(request).await.unwrap();
        }
        let elapsed = start.elapsed();

        assert_eq!(call_count.load(Ordering::SeqCst), 5);
        // Should complete quickly (within 1 second)
        assert!(elapsed < Duration::from_secs(1));
    }

    // Stats reflect the configured limits; availability never exceeds capacity.
    #[tokio::test]
    async fn test_rate_limit_stats() {
        let mock = MockProvider {
            call_count: Arc::new(AtomicU32::new(0)),
        };

        let config = RateLimitConfig::new()
            .with_requests_per_minute(60)
            .with_tokens_per_minute(100_000);

        let rate_limited = RateLimitProvider::new(Box::new(mock), config);

        let stats = rate_limited.get_stats().await;
        assert_eq!(stats.requests_per_minute, 60);
        assert_eq!(stats.tokens_per_minute, 100_000);
        assert!(stats.available_requests <= 60);
        assert!(stats.available_tokens <= 100_000);
    }

    // Builder methods must derive exact per-second refill rates.
    // (Float equality is safe here: 120/60 and 200_000/60 are computed the
    // same way on both sides.)
    #[tokio::test]
    async fn test_rate_limit_config() {
        let config = RateLimitConfig::new()
            .with_requests_per_minute(120)
            .with_tokens_per_minute(200_000);

        assert_eq!(config.requests_per_minute, 120);
        assert_eq!(config.tokens_per_minute, 200_000);
        assert_eq!(config.refill_rate, 2.0); // 120/60
        assert_eq!(config.token_refill_rate, 200_000.0 / 60.0);
    }

    // The 4-chars-per-token heuristic over prompt + system prompt.
    #[tokio::test]
    async fn test_token_estimation() {
        let request = LlmRequest {
            prompt: "Hello world this is a test prompt".to_string(), // 33 chars
            system_prompt: Some("You are a helpful assistant".to_string()), // 27 chars
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            images: Vec::new(),
        };

        let tokens = RateLimitProvider::estimate_tokens(&request);
        // Total 60 chars / 4 = 15 tokens
        assert!((10..=20).contains(&tokens));
    }

    // Pure token-bucket mechanics: drain, then refill with wall-clock time.
    // Sync test (no tokio) since TokenBucket itself is not async.
    #[test]
    fn test_token_bucket_refill() {
        let mut bucket = TokenBucket::new(100, 10.0); // 100 capacity, 10/sec refill

        // Consume all tokens
        assert!(bucket.try_acquire(100.0));
        assert!(!bucket.try_acquire(1.0));

        // Wait a bit and tokens should refill
        std::thread::sleep(Duration::from_millis(500));
        bucket.refill();

        // Should have ~5 tokens now (0.5 sec * 10/sec)
        assert!(bucket.tokens >= 4.0 && bucket.tokens <= 6.0);
    }
}