ricecoder_providers/
rate_limiter.rs1use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9
10pub struct TokenBucketLimiter {
16 tokens_per_second: f64,
18 max_tokens: f64,
20 tokens: f64,
22 last_refill: Instant,
24}
25
26impl TokenBucketLimiter {
27 pub fn new(tokens_per_second: f64, max_tokens: f64) -> Self {
33 Self {
34 tokens_per_second,
35 max_tokens,
36 tokens: max_tokens,
37 last_refill: Instant::now(),
38 }
39 }
40
41 fn refill(&mut self) {
43 let now = Instant::now();
44 let elapsed = now.duration_since(self.last_refill).as_secs_f64();
45 let new_tokens = elapsed * self.tokens_per_second;
46 self.tokens = (self.tokens + new_tokens).min(self.max_tokens);
47 self.last_refill = now;
48 }
49
50 pub fn try_acquire(&mut self, tokens: f64) -> bool {
54 self.refill();
55 if self.tokens >= tokens {
56 self.tokens -= tokens;
57 true
58 } else {
59 false
60 }
61 }
62
63 pub async fn acquire(&mut self, tokens: f64) {
67 loop {
68 if self.try_acquire(tokens) {
69 return;
70 }
71 tokio::time::sleep(Duration::from_millis(10)).await;
73 }
74 }
75
76 pub fn current_tokens(&mut self) -> f64 {
78 self.refill();
79 self.tokens
80 }
81
82 pub fn time_until_available(&mut self, tokens: f64) -> Duration {
84 self.refill();
85 if self.tokens >= tokens {
86 Duration::from_secs(0)
87 } else {
88 let needed = tokens - self.tokens;
89 let seconds = needed / self.tokens_per_second;
90 Duration::from_secs_f64(seconds)
91 }
92 }
93}
94
95pub struct ExponentialBackoff {
99 initial_delay: Duration,
101 max_delay: Duration,
103 multiplier: f64,
105 attempt: u32,
107}
108
109impl ExponentialBackoff {
110 pub fn new(initial_delay: Duration, max_delay: Duration, multiplier: f64) -> Self {
117 Self {
118 initial_delay,
119 max_delay,
120 multiplier,
121 attempt: 0,
122 }
123 }
124
125 pub fn next_delay(&mut self) -> Duration {
127 let delay = self.initial_delay.as_secs_f64()
128 * self.multiplier.powi(self.attempt as i32);
129 let delay = Duration::from_secs_f64(delay);
130 let delay = delay.min(self.max_delay);
131
132 let jitter = delay.as_secs_f64() * 0.1;
134 let jitter_offset = (rand::random::<f64>() - 0.5) * 2.0 * jitter;
135 let final_delay = (delay.as_secs_f64() + jitter_offset).max(0.0);
136
137 self.attempt += 1;
138 Duration::from_secs_f64(final_delay)
139 }
140
141 pub fn reset(&mut self) {
143 self.attempt = 0;
144 }
145
146 pub fn attempt(&self) -> u32 {
148 self.attempt
149 }
150}
151
152pub struct RateLimiterRegistry {
154 limiters: Arc<Mutex<HashMap<String, TokenBucketLimiter>>>,
155}
156
157impl RateLimiterRegistry {
158 pub fn new() -> Self {
160 Self {
161 limiters: Arc::new(Mutex::new(HashMap::new())),
162 }
163 }
164
165 pub fn register(&self, provider_id: &str, limiter: TokenBucketLimiter) {
167 let mut limiters = self.limiters.lock().unwrap();
168 limiters.insert(provider_id.to_string(), limiter);
169 }
170
171 pub fn get_or_create(&self, provider_id: &str) -> Arc<Mutex<TokenBucketLimiter>> {
173 let mut limiters = self.limiters.lock().unwrap();
174
175 if limiters.contains_key(provider_id) {
177 drop(limiters);
180 return Arc::new(Mutex::new(TokenBucketLimiter::new(10.0, 100.0)));
181 }
182
183 let limiter = TokenBucketLimiter::new(10.0, 100.0);
185 limiters.insert(provider_id.to_string(), limiter);
186 drop(limiters);
187
188 Arc::new(Mutex::new(TokenBucketLimiter::new(10.0, 100.0)))
189 }
190
191 pub fn try_acquire(&self, provider_id: &str, tokens: f64) -> bool {
193 let mut limiters = self.limiters.lock().unwrap();
194 if let Some(limiter) = limiters.get_mut(provider_id) {
195 limiter.try_acquire(tokens)
196 } else {
197 true
199 }
200 }
201
202 pub async fn acquire(&self, provider_id: &str, tokens: f64) {
204 loop {
205 if self.try_acquire(provider_id, tokens) {
206 return;
207 }
208 tokio::time::sleep(Duration::from_millis(10)).await;
209 }
210 }
211}
212
213impl Default for RateLimiterRegistry {
214 fn default() -> Self {
215 Self::new()
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// The 10 tokens/s, 100-token bucket used by the limiter tests.
    fn standard_bucket() -> TokenBucketLimiter {
        TokenBucketLimiter::new(10.0, 100.0)
    }

    /// A backoff starting at 100ms, doubling each attempt, capped at `cap`.
    fn doubling_backoff(cap: Duration) -> ExponentialBackoff {
        ExponentialBackoff::new(Duration::from_millis(100), cap, 2.0)
    }

    #[test]
    fn test_token_bucket_acquire() {
        let mut bucket = standard_bucket();
        assert!(bucket.try_acquire(50.0));
        // A tiny refill may happen between the two calls, hence the tolerance.
        let left = bucket.current_tokens();
        assert!((left - 50.0).abs() < 0.1);
    }

    #[test]
    fn test_token_bucket_rate_limited() {
        let mut bucket = standard_bucket();
        assert!(bucket.try_acquire(100.0));
        // The bucket is drained and not even one token has refilled yet.
        assert!(!bucket.try_acquire(1.0));
    }

    #[test]
    fn test_token_bucket_refill() {
        let mut bucket = standard_bucket();
        assert!(bucket.try_acquire(100.0));
        // 150ms at 10 tokens/s refills roughly 1.5 tokens.
        std::thread::sleep(Duration::from_millis(150));
        assert!(bucket.current_tokens() > 0.0);
    }

    #[test]
    fn test_exponential_backoff() {
        let mut backoff = doubling_backoff(Duration::from_secs(10));
        // Nominal delays of 100/200/400ms, each allowed ±10% jitter.
        for (low, high) in [(90, 110), (180, 220), (360, 440)] {
            let millis = backoff.next_delay().as_millis();
            assert!(low <= millis && millis <= high);
        }
    }

    #[test]
    fn test_exponential_backoff_max_delay() {
        let mut backoff = doubling_backoff(Duration::from_secs(1));
        (0..10).for_each(|_| {
            backoff.next_delay();
        });
        // Capped at 1s, plus at most 10% jitter.
        assert!(backoff.next_delay() <= Duration::from_millis(1100));
    }

    #[test]
    fn test_exponential_backoff_reset() {
        let mut backoff = doubling_backoff(Duration::from_secs(10));
        for _ in 0..2 {
            backoff.next_delay();
        }
        assert_eq!(backoff.attempt(), 2);

        backoff.reset();
        assert_eq!(backoff.attempt(), 0);
    }

    #[test]
    fn test_rate_limiter_registry() {
        let registry = RateLimiterRegistry::new();
        registry.register("openai", standard_bucket());
        // Two half-capacity requests drain the bucket, so a third must fail.
        for _ in 0..2 {
            assert!(registry.try_acquire("openai", 50.0));
        }
        assert!(!registry.try_acquire("openai", 1.0));
    }

    #[test]
    fn test_rate_limiter_registry_unknown_provider() {
        // A provider with no registered limiter is not rate limited.
        let registry = RateLimiterRegistry::new();
        assert!(registry.try_acquire("unknown", 1000.0));
    }
}