ccxt_core/
rate_limiter.rs1use std::sync::Arc;
31use std::time::{Duration, Instant};
32use tokio::sync::Mutex;
33use tokio::time::sleep;
34
/// Token-bucket parameters used to build a [`RateLimiter`].
///
/// The bucket starts with `capacity` tokens. Once at least `refill_period`
/// has passed since the previous refill, tokens are credited in proportion to
/// the elapsed time (`refill_amount` per period), never exceeding `capacity`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimiterConfig {
    /// Maximum number of tokens the bucket can hold; also the initial token count.
    pub capacity: u32,
    /// Minimum time that must elapse since the last refill before tokens are added again.
    pub refill_period: Duration,
    /// Number of tokens credited per elapsed `refill_period`.
    pub refill_amount: u32,
    /// Nominal token cost of a single request.
    ///
    /// NOTE(review): nothing in this file reads this field; the limiter's
    /// `wait`/`try_acquire` shortcuts always charge a cost of 1.
    pub cost_per_request: u32,
}
47
48impl RateLimiterConfig {
49 pub fn new(capacity: u32, refill_period: Duration) -> Self {
66 Self {
67 capacity,
68 refill_period,
69 refill_amount: capacity,
70 cost_per_request: 1,
71 }
72 }
73
74 pub fn with_refill_amount(mut self, amount: u32) -> Self {
76 self.refill_amount = amount;
77 self
78 }
79
80 pub fn with_cost_per_request(mut self, cost: u32) -> Self {
82 self.cost_per_request = cost;
83 self
84 }
85}
86
87impl Default for RateLimiterConfig {
88 fn default() -> Self {
89 Self::new(10, Duration::from_secs(1))
91 }
92}
93
94#[derive(Debug)]
96struct RateLimiterState {
97 tokens: u32,
99 last_refill: Instant,
101 config: RateLimiterConfig,
103}
104
105impl RateLimiterState {
106 fn new(config: RateLimiterConfig) -> Self {
107 Self {
108 tokens: config.capacity,
109 last_refill: Instant::now(),
110 config,
111 }
112 }
113
114 fn refill(&mut self) {
116 let now = Instant::now();
117 let elapsed = now.duration_since(self.last_refill);
118
119 if elapsed >= self.config.refill_period {
120 let periods = elapsed.as_secs_f64() / self.config.refill_period.as_secs_f64();
122 #[allow(clippy::cast_possible_truncation)]
123 let tokens_to_add = (periods * f64::from(self.config.refill_amount)) as u32;
124
125 self.tokens = (self.tokens + tokens_to_add).min(self.config.capacity);
127 self.last_refill = now;
128 }
129 }
130
131 fn try_consume(&mut self, cost: u32) -> bool {
133 self.refill();
134
135 if self.tokens >= cost {
136 self.tokens -= cost;
137 true
138 } else {
139 false
140 }
141 }
142
143 fn wait_time(&self, cost: u32) -> Duration {
145 if self.tokens >= cost {
146 return Duration::ZERO;
147 }
148
149 let tokens_needed = cost - self.tokens;
150 let refill_rate =
151 f64::from(self.config.refill_amount) / self.config.refill_period.as_secs_f64();
152 let wait_seconds = f64::from(tokens_needed) / refill_rate;
153
154 Duration::from_secs_f64(wait_seconds)
155 }
156}
157
158#[derive(Debug, Clone)]
162pub struct RateLimiter {
163 state: Arc<Mutex<RateLimiterState>>,
164}
165
166impl Default for RateLimiter {
167 fn default() -> Self {
168 Self::new(RateLimiterConfig::default())
169 }
170}
171
172impl RateLimiter {
173 pub fn new(config: RateLimiterConfig) -> Self {
185 Self {
186 state: Arc::new(Mutex::new(RateLimiterState::new(config))),
187 }
188 }
189
190 pub async fn wait(&self) {
206 self.wait_with_cost(1).await;
207 }
208
209 pub async fn wait_with_cost(&self, cost: u32) {
215 loop {
216 let wait_duration = {
217 let mut state = self.state.lock().await;
218 if state.try_consume(cost) {
219 return;
220 }
221 state.wait_time(cost)
222 };
223
224 if wait_duration > Duration::ZERO {
225 sleep(wait_duration).await;
226 } else {
227 sleep(Duration::from_millis(10)).await;
229 }
230 }
231 }
232
233 pub async fn acquire(&self, cost: u32) {
238 self.wait_with_cost(cost).await;
239 }
240
241 pub async fn try_acquire(&self) -> bool {
260 self.try_acquire_with_cost(1).await
261 }
262
263 pub async fn try_acquire_with_cost(&self, cost: u32) -> bool {
265 let mut state = self.state.lock().await;
266 state.try_consume(cost)
267 }
268
269 pub async fn available_tokens(&self) -> u32 {
271 let mut state = self.state.lock().await;
272 state.refill();
273 state.tokens
274 }
275
276 pub async fn reset(&self) {
278 let mut state = self.state.lock().await;
279 state.tokens = state.config.capacity;
280 state.last_refill = Instant::now();
281 }
282}
283
284#[derive(Debug, Clone)]
289pub struct MultiTierRateLimiter {
290 limiters: Arc<Mutex<std::collections::HashMap<String, RateLimiter>>>,
291}
292
293impl MultiTierRateLimiter {
294 pub fn new() -> Self {
296 Self {
297 limiters: Arc::new(Mutex::new(std::collections::HashMap::new())),
298 }
299 }
300
301 pub async fn add_tier(&self, tier: String, limiter: RateLimiter) {
308 let mut limiters = self.limiters.lock().await;
309 limiters.insert(tier, limiter);
310 }
311
312 pub async fn wait(&self, tier: &str) {
314 let limiter = {
315 let limiters = self.limiters.lock().await;
316 limiters.get(tier).cloned()
317 };
318
319 if let Some(limiter) = limiter {
320 limiter.wait().await;
321 }
322 }
323
324 pub async fn try_acquire(&self, tier: &str) -> bool {
326 let limiter = {
327 let limiters = self.limiters.lock().await;
328 limiters.get(tier).cloned()
329 };
330
331 if let Some(limiter) = limiter {
332 limiter.try_acquire().await
333 } else {
334 true }
336 }
337}
338
339impl Default for MultiTierRateLimiter {
340 fn default() -> Self {
341 Self::new()
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a limiter with the given capacity and refill period.
    fn limiter_with(capacity: u32, refill_period: Duration) -> RateLimiter {
        RateLimiter::new(RateLimiterConfig::new(capacity, refill_period))
    }

    /// `new` mirrors the capacity into `refill_amount` and charges 1 per request.
    #[test]
    fn test_rate_limiter_config() {
        let one_minute = Duration::from_secs(60);
        let cfg = RateLimiterConfig::new(100, one_minute);

        assert_eq!(cfg.capacity, 100);
        assert_eq!(cfg.refill_period, one_minute);
        assert_eq!(cfg.refill_amount, 100);
        assert_eq!(cfg.cost_per_request, 1);
    }

    /// Builder methods override the defaults chosen by `new`.
    #[test]
    fn test_rate_limiter_config_custom() {
        let base = RateLimiterConfig::new(100, Duration::from_secs(60));
        let cfg = base.with_refill_amount(50).with_cost_per_request(2);

        assert_eq!(cfg.refill_amount, 50);
        assert_eq!(cfg.cost_per_request, 2);
    }

    /// A fresh bucket yields exactly `capacity` tokens, then refuses.
    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let bucket = limiter_with(5, Duration::from_secs(1));

        for _ in 0..5 {
            assert!(bucket.try_acquire().await);
        }

        assert!(!bucket.try_acquire().await);
    }

    /// An exhausted bucket hands out tokens again once the period has passed.
    #[tokio::test]
    async fn test_rate_limiter_refill() {
        let bucket = limiter_with(2, Duration::from_millis(100));

        assert!(bucket.try_acquire().await);
        assert!(bucket.try_acquire().await);
        assert!(!bucket.try_acquire().await);

        sleep(Duration::from_millis(150)).await;

        assert!(bucket.try_acquire().await);
    }

    /// `wait` blocks for roughly one refill period when the bucket is empty.
    #[tokio::test]
    async fn test_rate_limiter_wait() {
        let bucket = limiter_with(2, Duration::from_millis(100));

        bucket.wait().await;
        bucket.wait().await;

        let started = Instant::now();
        bucket.wait().await;
        let waited = started.elapsed();

        assert!(waited >= Duration::from_millis(80));
    }

    /// Weighted acquisitions subtract their cost and fail when too expensive.
    #[tokio::test]
    async fn test_rate_limiter_custom_cost() {
        let bucket = limiter_with(10, Duration::from_secs(1));

        assert!(bucket.try_acquire_with_cost(5).await);
        assert_eq!(bucket.available_tokens().await, 5);

        assert!(bucket.try_acquire_with_cost(3).await);
        assert_eq!(bucket.available_tokens().await, 2);

        assert!(!bucket.try_acquire_with_cost(3).await);
    }

    /// `reset` restores a drained bucket to full capacity.
    #[tokio::test]
    async fn test_rate_limiter_reset() {
        let bucket = limiter_with(5, Duration::from_secs(1));

        for _ in 0..5 {
            bucket.wait().await;
        }

        assert_eq!(bucket.available_tokens().await, 0);

        bucket.reset().await;

        assert_eq!(bucket.available_tokens().await, 5);
    }

    /// Tiers are limited independently; unknown tiers are never limited.
    #[tokio::test]
    async fn test_multi_tier_rate_limiter() {
        let multi = MultiTierRateLimiter::new();
        let tiers: [(&str, u32); 2] = [("public", 10), ("private", 5)];

        for (name, capacity) in tiers {
            let tier_limiter = limiter_with(capacity, Duration::from_secs(1));
            multi.add_tier(name.to_string(), tier_limiter).await;
        }

        for (name, capacity) in tiers {
            for _ in 0..capacity {
                assert!(multi.try_acquire(name).await);
            }
            assert!(!multi.try_acquire(name).await);
        }

        assert!(multi.try_acquire("unknown").await);
    }

    /// Ten concurrent waiters on a ten-token bucket drain it exactly.
    #[tokio::test]
    async fn test_concurrent_access() {
        let bucket = limiter_with(10, Duration::from_secs(1));

        let workers: Vec<_> = (0..10)
            .map(|_| {
                let shared = bucket.clone();
                tokio::spawn(async move {
                    shared.wait().await;
                })
            })
            .collect();

        for worker in workers {
            worker.await.unwrap();
        }

        assert_eq!(bucket.available_tokens().await, 0);
    }
}