ccxt_core/
rate_limiter.rs1use std::sync::Arc;
31use std::time::{Duration, Instant};
32use tokio::sync::Mutex;
33use tokio::time::sleep;
34
/// Parameters describing a token bucket.
///
/// Build one with [`RateLimiterConfig::new`] and adjust it with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Upper bound on the number of tokens the bucket may hold.
    pub capacity: u32,
    /// Length of one refill interval.
    pub refill_period: Duration,
    /// Number of tokens credited per elapsed refill interval.
    pub refill_amount: u32,
    /// Token cost attributed to a single request (1 unless overridden).
    pub cost_per_request: u32,
}

impl RateLimiterConfig {
    /// Creates a configuration whose refill amount equals `capacity` and
    /// whose per-request cost is 1.
    pub fn new(capacity: u32, refill_period: Duration) -> Self {
        let refill_amount = capacity;
        let cost_per_request = 1;

        Self {
            capacity,
            refill_period,
            refill_amount,
            cost_per_request,
        }
    }

    /// Returns the configuration with `refill_amount` replaced by `amount`.
    pub fn with_refill_amount(self, amount: u32) -> Self {
        Self {
            refill_amount: amount,
            ..self
        }
    }

    /// Returns the configuration with `cost_per_request` replaced by `cost`.
    pub fn with_cost_per_request(self, cost: u32) -> Self {
        Self {
            cost_per_request: cost,
            ..self
        }
    }
}

impl Default for RateLimiterConfig {
    /// Ten tokens, fully replenished every second.
    fn default() -> Self {
        let one_second = Duration::from_secs(1);
        Self::new(10, one_second)
    }
}
93
94#[derive(Debug)]
96struct RateLimiterState {
97 tokens: u32,
99 last_refill: Instant,
101 config: RateLimiterConfig,
103}
104
105impl RateLimiterState {
106 fn new(config: RateLimiterConfig) -> Self {
107 Self {
108 tokens: config.capacity,
109 last_refill: Instant::now(),
110 config,
111 }
112 }
113
114 fn refill(&mut self) {
116 let now = Instant::now();
117 let elapsed = now.duration_since(self.last_refill);
118
119 if elapsed >= self.config.refill_period {
120 let periods = elapsed.as_secs_f64() / self.config.refill_period.as_secs_f64();
122 let tokens_to_add = (periods * self.config.refill_amount as f64) as u32;
123
124 self.tokens = (self.tokens + tokens_to_add).min(self.config.capacity);
126 self.last_refill = now;
127 }
128 }
129
130 fn try_consume(&mut self, cost: u32) -> bool {
132 self.refill();
133
134 if self.tokens >= cost {
135 self.tokens -= cost;
136 true
137 } else {
138 false
139 }
140 }
141
142 fn wait_time(&self, cost: u32) -> Duration {
144 if self.tokens >= cost {
145 return Duration::ZERO;
146 }
147
148 let tokens_needed = cost - self.tokens;
149 let refill_rate =
150 self.config.refill_amount as f64 / self.config.refill_period.as_secs_f64();
151 let wait_seconds = tokens_needed as f64 / refill_rate;
152
153 Duration::from_secs_f64(wait_seconds)
154 }
155}
156
157#[derive(Debug, Clone)]
161pub struct RateLimiter {
162 state: Arc<Mutex<RateLimiterState>>,
163}
164
165impl RateLimiter {
166 pub fn new(config: RateLimiterConfig) -> Self {
178 Self {
179 state: Arc::new(Mutex::new(RateLimiterState::new(config))),
180 }
181 }
182
183 pub fn default() -> Self {
185 Self::new(RateLimiterConfig::default())
186 }
187
188 pub async fn wait(&self) {
204 self.wait_with_cost(1).await;
205 }
206
207 pub async fn wait_with_cost(&self, cost: u32) {
213 loop {
214 let wait_duration = {
215 let mut state = self.state.lock().await;
216 if state.try_consume(cost) {
217 return;
218 }
219 state.wait_time(cost)
220 };
221
222 if wait_duration > Duration::ZERO {
223 sleep(wait_duration).await;
224 } else {
225 sleep(Duration::from_millis(10)).await;
227 }
228 }
229 }
230
231 pub async fn acquire(&self, cost: u32) {
236 self.wait_with_cost(cost).await;
237 }
238
239 pub async fn try_acquire(&self) -> bool {
258 self.try_acquire_with_cost(1).await
259 }
260
261 pub async fn try_acquire_with_cost(&self, cost: u32) -> bool {
263 let mut state = self.state.lock().await;
264 state.try_consume(cost)
265 }
266
267 pub async fn available_tokens(&self) -> u32 {
269 let mut state = self.state.lock().await;
270 state.refill();
271 state.tokens
272 }
273
274 pub async fn reset(&self) {
276 let mut state = self.state.lock().await;
277 state.tokens = state.config.capacity;
278 state.last_refill = Instant::now();
279 }
280}
281
282#[derive(Debug, Clone)]
287pub struct MultiTierRateLimiter {
288 limiters: Arc<Mutex<std::collections::HashMap<String, RateLimiter>>>,
289}
290
291impl MultiTierRateLimiter {
292 pub fn new() -> Self {
294 Self {
295 limiters: Arc::new(Mutex::new(std::collections::HashMap::new())),
296 }
297 }
298
299 pub async fn add_tier(&self, tier: String, limiter: RateLimiter) {
306 let mut limiters = self.limiters.lock().await;
307 limiters.insert(tier, limiter);
308 }
309
310 pub async fn wait(&self, tier: &str) {
312 let limiter = {
313 let limiters = self.limiters.lock().await;
314 limiters.get(tier).cloned()
315 };
316
317 if let Some(limiter) = limiter {
318 limiter.wait().await;
319 }
320 }
321
322 pub async fn try_acquire(&self, tier: &str) -> bool {
324 let limiter = {
325 let limiters = self.limiters.lock().await;
326 limiters.get(tier).cloned()
327 };
328
329 if let Some(limiter) = limiter {
330 limiter.try_acquire().await
331 } else {
332 true }
334 }
335}
336
337impl Default for MultiTierRateLimiter {
338 fn default() -> Self {
339 Self::new()
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` mirrors the capacity into the refill amount and sets the cost to 1.
    #[test]
    fn test_rate_limiter_config() {
        let RateLimiterConfig {
            capacity,
            refill_period,
            refill_amount,
            cost_per_request,
        } = RateLimiterConfig::new(100, Duration::from_secs(60));

        assert_eq!(capacity, 100);
        assert_eq!(refill_period, Duration::from_secs(60));
        assert_eq!(refill_amount, 100);
        assert_eq!(cost_per_request, 1);
    }

    /// The builder methods override the refill amount and the request cost.
    #[test]
    fn test_rate_limiter_config_custom() {
        let base = RateLimiterConfig::new(100, Duration::from_secs(60));
        let tuned = base.with_refill_amount(50).with_cost_per_request(2);

        assert_eq!(tuned.refill_amount, 50);
        assert_eq!(tuned.cost_per_request, 2);
    }

    /// A fresh bucket grants exactly `capacity` single-token requests.
    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let limiter = RateLimiter::new(RateLimiterConfig::new(5, Duration::from_secs(1)));

        let mut granted = 0;
        while granted < 5 {
            assert!(limiter.try_acquire().await);
            granted += 1;
        }

        let sixth = limiter.try_acquire().await;
        assert!(!sixth);
    }

    /// An exhausted bucket grants requests again once a refill period passes.
    #[tokio::test]
    async fn test_rate_limiter_refill() {
        let limiter = RateLimiter::new(RateLimiterConfig::new(2, Duration::from_millis(100)));

        assert!(limiter.try_acquire().await);
        assert!(limiter.try_acquire().await);
        assert!(!limiter.try_acquire().await);

        sleep(Duration::from_millis(150)).await;

        assert!(limiter.try_acquire().await);
    }

    /// `wait` blocks for roughly one refill period when the bucket is empty.
    #[tokio::test]
    async fn test_rate_limiter_wait() {
        let limiter = RateLimiter::new(RateLimiterConfig::new(2, Duration::from_millis(100)));

        limiter.wait().await;
        limiter.wait().await;

        let started = Instant::now();
        limiter.wait().await;
        let blocked_for = started.elapsed();

        assert!(blocked_for >= Duration::from_millis(80));
    }

    /// Multi-token requests deduct their full cost or fail without deducting.
    #[tokio::test]
    async fn test_rate_limiter_custom_cost() {
        let limiter = RateLimiter::new(RateLimiterConfig::new(10, Duration::from_secs(1)));

        assert!(limiter.try_acquire_with_cost(5).await);
        assert_eq!(limiter.available_tokens().await, 5);

        assert!(limiter.try_acquire_with_cost(3).await);
        assert_eq!(limiter.available_tokens().await, 2);

        assert!(!limiter.try_acquire_with_cost(3).await);
    }

    /// `reset` restores a drained bucket to full capacity.
    #[tokio::test]
    async fn test_rate_limiter_reset() {
        let limiter = RateLimiter::new(RateLimiterConfig::new(5, Duration::from_secs(1)));

        for _request in 0..5 {
            limiter.wait().await;
        }
        assert_eq!(limiter.available_tokens().await, 0);

        limiter.reset().await;
        assert_eq!(limiter.available_tokens().await, 5);
    }

    /// Tiers are limited independently; unknown tiers are never limited.
    #[tokio::test]
    async fn test_multi_tier_rate_limiter() {
        let multi = MultiTierRateLimiter::new();

        for (tier, capacity) in [("public", 10), ("private", 5)] {
            let tier_config = RateLimiterConfig::new(capacity, Duration::from_secs(1));
            multi
                .add_tier(tier.to_string(), RateLimiter::new(tier_config))
                .await;
        }

        for _request in 0..10 {
            assert!(multi.try_acquire("public").await);
        }
        assert!(!multi.try_acquire("public").await);

        for _request in 0..5 {
            assert!(multi.try_acquire("private").await);
        }
        assert!(!multi.try_acquire("private").await);

        assert!(multi.try_acquire("unknown").await);
    }

    /// Ten concurrent waiters drain a ten-token bucket exactly.
    #[tokio::test]
    async fn test_concurrent_access() {
        let limiter = RateLimiter::new(RateLimiterConfig::new(10, Duration::from_secs(1)));

        let tasks: Vec<_> = (0..10)
            .map(|_| {
                let shared = limiter.clone();
                tokio::spawn(async move {
                    shared.wait().await;
                })
            })
            .collect();

        for task in tasks {
            task.await.unwrap();
        }

        assert_eq!(limiter.available_tokens().await, 0);
    }
}