use crate::{LlmError, LlmProvider, LlmRequest, LlmResponse};
use async_trait::async_trait;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;

/// Rate-limit settings for a wrapped LLM provider. A value of `0` disables
/// that dimension of limiting.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum requests per minute (0 = unlimited).
    pub requests_per_minute: u32,
    /// Maximum tokens per minute (0 = unlimited).
    pub tokens_per_minute: u32,
    /// Request-bucket refill rate, in requests per second.
    refill_rate: f64,
    /// Token-bucket refill rate, in tokens per second.
    token_refill_rate: f64,
}

impl Default for RateLimitConfig {
    fn default() -> Self {
        Self::new()
    }
}

impl RateLimitConfig {
    /// Creates a config with no limits; chain the builder methods to set them.
    pub fn new() -> Self {
        Self {
            requests_per_minute: 0,
            tokens_per_minute: 0,
            refill_rate: 0.0,
            token_refill_rate: 0.0,
        }
    }

    /// Sets the request limit and derives its per-second refill rate.
    pub fn with_requests_per_minute(mut self, rpm: u32) -> Self {
        self.requests_per_minute = rpm;
        self.refill_rate = rpm as f64 / 60.0;
        self
    }

    /// Sets the token limit and derives its per-second refill rate.
    pub fn with_tokens_per_minute(mut self, tpm: u32) -> Self {
        self.tokens_per_minute = tpm;
        self.token_refill_rate = tpm as f64 / 60.0;
        self
    }
}

/// A token bucket that refills continuously at `refill_rate` tokens per
/// second, up to `capacity`.
#[derive(Debug)]
struct TokenBucket {
    capacity: f64,
    tokens: f64,
    refill_rate: f64,
    last_refill: Instant,
}

impl TokenBucket {
    fn new(capacity: u32, refill_rate: f64) -> Self {
        Self {
            capacity: capacity as f64,
            tokens: capacity as f64,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        let new_tokens = elapsed * self.refill_rate;

        // Accrue continuously (elapsed seconds * rate), capped at capacity.
        self.tokens = (self.tokens + new_tokens).min(self.capacity);
        self.last_refill = now;
    }

    fn try_acquire(&mut self, count: f64) -> bool {
        self.refill();

        if self.tokens >= count {
            self.tokens -= count;
            true
        } else {
            false
        }
    }

    /// How long until `count` tokens are available at the current rate.
    fn wait_time(&mut self, count: f64) -> Duration {
        self.refill();

        if self.tokens >= count {
            Duration::from_secs(0)
        } else {
            let deficit = count - self.tokens;
            let wait_secs = deficit / self.refill_rate;
            Duration::from_secs_f64(wait_secs)
        }
    }
}

/// Bucket state shared by the limiter; a bucket is `None` when the
/// corresponding limit is disabled.
#[derive(Debug)]
struct RateLimiterState {
    request_bucket: Option<TokenBucket>,
    token_bucket: Option<TokenBucket>,
}

impl RateLimiterState {
    fn new(config: &RateLimitConfig) -> Self {
        let request_bucket = if config.requests_per_minute > 0 {
            Some(TokenBucket::new(
                config.requests_per_minute,
                config.refill_rate,
            ))
        } else {
            None
        };

        let token_bucket = if config.tokens_per_minute > 0 {
            Some(TokenBucket::new(
                config.tokens_per_minute,
                config.token_refill_rate,
            ))
        } else {
            None
        };

        Self {
            request_bucket,
            token_bucket,
        }
    }

    /// Attempts to take one request slot plus `estimated_tokens` tokens.
    /// On failure, returns how long to wait before retrying.
    fn acquire(&mut self, estimated_tokens: u32) -> Result<(), Duration> {
        // Check both buckets before consuming from either, so a shortfall
        // in one bucket cannot leak capacity already taken from the other.
        if let Some(bucket) = &mut self.request_bucket {
            let wait = bucket.wait_time(1.0);
            if !wait.is_zero() {
                return Err(wait);
            }
        }

        if let Some(bucket) = &mut self.token_bucket {
            let wait = bucket.wait_time(estimated_tokens as f64);
            if !wait.is_zero() {
                return Err(wait);
            }
        }

        // Both checks passed; tokens only accrue between the checks and
        // these acquisitions, so neither can fail.
        if let Some(bucket) = &mut self.request_bucket {
            bucket.try_acquire(1.0);
        }
        if let Some(bucket) = &mut self.token_bucket {
            bucket.try_acquire(estimated_tokens as f64);
        }

        Ok(())
    }
}

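/// Wraps any [`LlmProvider`] and throttles `complete` calls with token
/// buckets before delegating to it. A minimal usage sketch (`inner` and
/// `request` stand in for whatever provider and request you already have):
///
/// ```text
/// let config = RateLimitConfig::new()
///     .with_requests_per_minute(60)
///     .with_tokens_per_minute(90_000);
/// let limited = RateLimitProvider::new(Box::new(inner), config);
/// let response = limited.complete(request).await?;
/// ```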
pub struct RateLimitProvider {
    /// The wrapped provider that performs the actual completions.
    provider: Box<dyn LlmProvider>,
    state: Arc<Mutex<RateLimiterState>>,
    config: RateLimitConfig,
}

impl RateLimitProvider {
    /// Wraps `provider`, building buckets from `config`.
    pub fn new(provider: Box<dyn LlmProvider>, config: RateLimitConfig) -> Self {
        let state = Arc::new(Mutex::new(RateLimiterState::new(&config)));
        Self {
            provider,
            state,
            config,
        }
    }

    /// Reports the configured limits and the capacity available right now.
    pub async fn get_stats(&self) -> RateLimitStats {
        let mut state = self.state.lock().await;

        // Refill before reading so availability reflects elapsed time, not
        // just the state left behind by the last acquisition.
        let available_requests = state
            .request_bucket
            .as_mut()
            .map(|b| {
                b.refill();
                b.tokens as u32
            })
            .unwrap_or(0);

        let available_tokens = state
            .token_bucket
            .as_mut()
            .map(|b| {
                b.refill();
                b.tokens as u32
            })
            .unwrap_or(0);

        RateLimitStats {
            requests_per_minute: self.config.requests_per_minute,
            tokens_per_minute: self.config.tokens_per_minute,
            available_requests,
            available_tokens,
        }
    }

    /// Rough token estimate for budgeting against the token bucket: about
    /// four characters per token, floored at one.
    fn estimate_tokens(request: &LlmRequest) -> u32 {
        let prompt_len = request.prompt.len();
        let system_len = request.system_prompt.as_ref().map(|s| s.len()).unwrap_or(0);
        let total_chars = prompt_len + system_len;

        ((total_chars / 4) as u32).max(1)
    }
}

/// A snapshot of the limiter's configuration and remaining capacity.
#[derive(Debug, Clone)]
pub struct RateLimitStats {
    /// Configured request limit (0 = unlimited).
    pub requests_per_minute: u32,
    /// Configured token limit (0 = unlimited).
    pub tokens_per_minute: u32,
    /// Request slots currently available.
    pub available_requests: u32,
    /// Tokens currently available.
    pub available_tokens: u32,
}

#[async_trait]
impl LlmProvider for RateLimitProvider {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
        let estimated_tokens = Self::estimate_tokens(&request);

        loop {
            // Hold the lock only while attempting the acquisition; sleeping
            // while holding it would stall every other caller.
            let result = {
                let mut state = self.state.lock().await;
                state.acquire(estimated_tokens)
            };

            match result {
                Ok(()) => break,
                Err(wait_time) => {
                    tokio::time::sleep(wait_time).await;
                }
            }
        }

        self.provider.complete(request).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Usage;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Counts calls and returns a canned response.
    struct MockProvider {
        call_count: Arc<AtomicU32>,
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(&self, _request: LlmRequest) -> Result<LlmResponse, LlmError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(LlmResponse {
                content: "Success".to_string(),
                model: "mock".to_string(),
                usage: Some(Usage {
                    prompt_tokens: 10,
                    completion_tokens: 20,
                    total_tokens: 30,
                }),
                tool_calls: Vec::new(),
            })
        }
    }

    #[tokio::test]
    async fn test_rate_limit_requests() {
        let call_count = Arc::new(AtomicU32::new(0));
        let mock = MockProvider {
            call_count: Arc::clone(&call_count),
        };

        let config = RateLimitConfig::new().with_requests_per_minute(10);
        let rate_limited = RateLimitProvider::new(Box::new(mock), config);

        // Five requests fit within the 10-request burst capacity, so none
        // of them should have to wait.
        let start = Instant::now();
        for _ in 0..5 {
            let request = LlmRequest {
                prompt: "test".to_string(),
                system_prompt: None,
                temperature: None,
                max_tokens: None,
                tools: Vec::new(),
                images: Vec::new(),
            };
            rate_limited.complete(request).await.unwrap();
        }
        let elapsed = start.elapsed();

        assert_eq!(call_count.load(Ordering::SeqCst), 5);
        assert!(elapsed < Duration::from_secs(1));
    }
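
    // Illustrative addition: with no limits configured, no buckets are
    // created and requests pass straight through to the inner provider.
    #[tokio::test]
    async fn test_unlimited_config_passes_through() {
        let call_count = Arc::new(AtomicU32::new(0));
        let mock = MockProvider {
            call_count: Arc::clone(&call_count),
        };

        let rate_limited = RateLimitProvider::new(Box::new(mock), RateLimitConfig::new());

        for _ in 0..3 {
            let request = LlmRequest {
                prompt: "test".to_string(),
                system_prompt: None,
                temperature: None,
                max_tokens: None,
                tools: Vec::new(),
                images: Vec::new(),
            };
            rate_limited.complete(request).await.unwrap();
        }

        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }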

    #[tokio::test]
    async fn test_rate_limit_stats() {
        let mock = MockProvider {
            call_count: Arc::new(AtomicU32::new(0)),
        };

        let config = RateLimitConfig::new()
            .with_requests_per_minute(60)
            .with_tokens_per_minute(100_000);

        let rate_limited = RateLimitProvider::new(Box::new(mock), config);

        let stats = rate_limited.get_stats().await;
        assert_eq!(stats.requests_per_minute, 60);
        assert_eq!(stats.tokens_per_minute, 100_000);
        assert!(stats.available_requests <= 60);
        assert!(stats.available_tokens <= 100_000);
    }

    #[tokio::test]
    async fn test_rate_limit_config() {
        let config = RateLimitConfig::new()
            .with_requests_per_minute(120)
            .with_tokens_per_minute(200_000);

        assert_eq!(config.requests_per_minute, 120);
        assert_eq!(config.tokens_per_minute, 200_000);
        assert_eq!(config.refill_rate, 2.0);
        assert_eq!(config.token_refill_rate, 200_000.0 / 60.0);
    }

    #[tokio::test]
    async fn test_token_estimation() {
        let request = LlmRequest {
            prompt: "Hello world this is a test prompt".to_string(),
            system_prompt: Some("You are a helpful assistant".to_string()),
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            images: Vec::new(),
        };

        // 33 prompt chars + 27 system chars = 60 chars, roughly 15 tokens
        // at ~4 characters per token.
        let tokens = RateLimitProvider::estimate_tokens(&request);
        assert!((10..=20).contains(&tokens));
    }
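
    // Illustrative addition: the estimate is floored at one token, so even
    // an empty request is charged against the token bucket.
    #[test]
    fn test_token_estimation_minimum() {
        let request = LlmRequest {
            prompt: String::new(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            images: Vec::new(),
        };

        assert_eq!(RateLimitProvider::estimate_tokens(&request), 1);
    }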

    #[test]
    fn test_token_bucket_refill() {
        // Capacity of 100 tokens, refilling at 10 tokens per second.
        let mut bucket = TokenBucket::new(100, 10.0);
        assert!(bucket.try_acquire(100.0));
        assert!(!bucket.try_acquire(1.0));

        std::thread::sleep(Duration::from_millis(500));
        bucket.refill();

        // ~0.5s at 10 tokens/sec restores about 5 tokens; the range allows
        // for scheduling jitter.
        assert!(bucket.tokens >= 4.0 && bucket.tokens <= 6.0);
    }
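
    // Illustrative addition: wait_time() only reports how long a deficit
    // takes to refill; it does not consume tokens. Draining a 100-token
    // bucket that refills at 10 tokens/sec leaves a 10-token request about
    // one second away.
    #[test]
    fn test_token_bucket_wait_time() {
        let mut bucket = TokenBucket::new(100, 10.0);
        assert!(bucket.try_acquire(100.0));

        let wait = bucket.wait_time(10.0);
        assert!(wait > Duration::from_millis(900));
        assert!(wait <= Duration::from_secs(1));
    }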
}