1use crate::error::{ConfigValidationError, ValidationResult};
31use std::sync::Arc;
32use std::time::{Duration, Instant};
33use tokio::sync::Mutex;
34use tokio::time::sleep;
35
/// Token-bucket parameters for a [`RateLimiter`].
///
/// The bucket holds at most `capacity` tokens. Every time a full `refill_period` elapses,
/// `refill_amount` tokens are added back (never exceeding `capacity`).
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Maximum number of tokens the bucket can hold; the bucket starts full.
    pub capacity: u32,
    /// Length of one refill interval.
    pub refill_period: Duration,
    /// Tokens added per completed `refill_period`.
    pub refill_amount: u32,
    /// Token cost recorded for a single request (defaults to 1).
    pub cost_per_request: u32,
}

impl RateLimiterConfig {
    /// Builds a config that refills the whole bucket (`refill_amount == capacity`) once per
    /// `refill_period`, charging one token per request.
    pub fn new(capacity: u32, refill_period: Duration) -> Self {
        let refill_amount = capacity;
        let cost_per_request = 1;
        Self {
            capacity,
            refill_period,
            refill_amount,
            cost_per_request,
        }
    }

    /// Returns this config with `refill_amount` replaced by `amount`.
    pub fn with_refill_amount(self, amount: u32) -> Self {
        Self {
            refill_amount: amount,
            ..self
        }
    }

    /// Returns this config with `cost_per_request` replaced by `cost`.
    pub fn with_cost_per_request(self, cost: u32) -> Self {
        Self {
            cost_per_request: cost,
            ..self
        }
    }
}

impl Default for RateLimiterConfig {
    /// Ten tokens, fully refilled every second.
    fn default() -> Self {
        let (capacity, period) = (10, Duration::from_secs(1));
        RateLimiterConfig::new(capacity, period)
    }
}
94
95impl RateLimiterConfig {
96 pub fn validate(&self) -> Result<ValidationResult, ConfigValidationError> {
125 let mut warnings = Vec::new();
126
127 if self.capacity == 0 {
129 return Err(ConfigValidationError::invalid(
130 "capacity",
131 "capacity cannot be zero",
132 ));
133 }
134
135 if self.refill_period < Duration::from_millis(100) {
137 warnings.push(format!(
138 "refill_period {:?} is very short, may cause high CPU usage",
139 self.refill_period
140 ));
141 }
142
143 Ok(ValidationResult::with_warnings(warnings))
144 }
145}
146
147#[derive(Debug)]
149struct RateLimiterState {
150 tokens: u32,
152 last_refill: Instant,
154 remainder_nanos: u64,
158 config: RateLimiterConfig,
160}
161
162impl RateLimiterState {
163 fn new(config: RateLimiterConfig) -> Self {
164 Self {
165 tokens: config.capacity,
166 last_refill: Instant::now(),
167 remainder_nanos: 0,
168 config,
169 }
170 }
171
172 fn refill(&mut self) {
183 let now = Instant::now();
184 let elapsed_nanos = now.duration_since(self.last_refill).as_nanos(); let period_nanos = self.config.refill_period.as_nanos(); if period_nanos == 0 {
190 return;
191 }
192
193 let total_nanos = u128::from(self.remainder_nanos).saturating_add(elapsed_nanos);
196
197 let complete_periods = total_nanos / period_nanos;
199
200 if complete_periods > 0 {
201 #[allow(clippy::cast_possible_truncation)]
204 let tokens_to_add = (complete_periods * u128::from(self.config.refill_amount))
205 .min(u128::from(u32::MAX)) as u32;
206
207 self.tokens = self
209 .tokens
210 .saturating_add(tokens_to_add)
211 .min(self.config.capacity);
212
213 #[allow(clippy::cast_possible_truncation)]
217 let remainder = (total_nanos % period_nanos) as u64;
218 self.remainder_nanos = remainder;
219 self.last_refill = now;
220 }
221 }
222
223 fn try_consume(&mut self, cost: u32) -> bool {
225 self.refill();
226
227 if self.tokens >= cost {
228 self.tokens -= cost;
229 true
230 } else {
231 false
232 }
233 }
234
235 fn wait_time(&self, cost: u32) -> Duration {
237 if self.tokens >= cost {
238 return Duration::ZERO;
239 }
240
241 let tokens_needed = cost - self.tokens;
242 let refill_rate =
243 f64::from(self.config.refill_amount) / self.config.refill_period.as_secs_f64();
244 let wait_seconds = f64::from(tokens_needed) / refill_rate;
245
246 Duration::from_secs_f64(wait_seconds)
247 }
248}
249
250#[derive(Debug, Clone)]
254pub struct RateLimiter {
255 state: Arc<Mutex<RateLimiterState>>,
256}
257
258impl Default for RateLimiter {
259 fn default() -> Self {
260 Self::new(RateLimiterConfig::default())
261 }
262}
263
264impl RateLimiter {
265 pub fn new(config: RateLimiterConfig) -> Self {
277 Self {
278 state: Arc::new(Mutex::new(RateLimiterState::new(config))),
279 }
280 }
281
282 pub async fn wait(&self) {
298 self.wait_with_cost(1).await;
299 }
300
301 pub async fn wait_with_cost(&self, cost: u32) {
307 loop {
308 let wait_duration = {
309 let mut state = self.state.lock().await;
310 if state.try_consume(cost) {
311 return;
312 }
313 state.wait_time(cost)
314 };
315
316 if wait_duration > Duration::ZERO {
317 sleep(wait_duration).await;
318 } else {
319 sleep(Duration::from_millis(10)).await;
321 }
322 }
323 }
324
325 pub async fn acquire(&self, cost: u32) {
330 self.wait_with_cost(cost).await;
331 }
332
333 pub async fn try_acquire(&self) -> bool {
352 self.try_acquire_with_cost(1).await
353 }
354
355 pub async fn try_acquire_with_cost(&self, cost: u32) -> bool {
357 let mut state = self.state.lock().await;
358 state.try_consume(cost)
359 }
360
361 pub async fn available_tokens(&self) -> u32 {
363 let mut state = self.state.lock().await;
364 state.refill();
365 state.tokens
366 }
367
368 pub async fn reset(&self) {
370 let mut state = self.state.lock().await;
371 state.tokens = state.config.capacity;
372 state.last_refill = Instant::now();
373 state.remainder_nanos = 0;
374 }
375}
376
377#[derive(Debug, Clone)]
382pub struct MultiTierRateLimiter {
383 limiters: Arc<Mutex<std::collections::HashMap<String, RateLimiter>>>,
384}
385
386impl MultiTierRateLimiter {
387 pub fn new() -> Self {
389 Self {
390 limiters: Arc::new(Mutex::new(std::collections::HashMap::new())),
391 }
392 }
393
394 pub async fn add_tier(&self, tier: String, limiter: RateLimiter) {
401 let mut limiters = self.limiters.lock().await;
402 limiters.insert(tier, limiter);
403 }
404
405 pub async fn wait(&self, tier: &str) {
407 let limiter = {
408 let limiters = self.limiters.lock().await;
409 limiters.get(tier).cloned()
410 };
411
412 if let Some(limiter) = limiter {
413 limiter.wait().await;
414 }
415 }
416
417 pub async fn try_acquire(&self, tier: &str) -> bool {
419 let limiter = {
420 let limiters = self.limiters.lock().await;
421 limiters.get(tier).cloned()
422 };
423
424 if let Some(limiter) = limiter {
425 limiter.try_acquire().await
426 } else {
427 true }
429 }
430}
431
432impl Default for MultiTierRateLimiter {
433 fn default() -> Self {
434 Self::new()
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limiter_config() {
        let config = RateLimiterConfig::new(100, Duration::from_secs(60));
        assert_eq!(config.capacity, 100);
        assert_eq!(config.refill_period, Duration::from_secs(60));
        assert_eq!(config.refill_amount, 100);
        assert_eq!(config.cost_per_request, 1);
    }

    #[test]
    fn test_rate_limiter_config_custom() {
        let config = RateLimiterConfig::new(100, Duration::from_secs(60))
            .with_refill_amount(50)
            .with_cost_per_request(2);

        assert_eq!(config.refill_amount, 50);
        assert_eq!(config.cost_per_request, 2);
    }

    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let config = RateLimiterConfig::new(5, Duration::from_secs(1));
        let limiter = RateLimiter::new(config);

        // The bucket starts full, so exactly `capacity` requests succeed immediately.
        for _ in 0..5 {
            assert!(limiter.try_acquire().await);
        }

        assert!(!limiter.try_acquire().await);
    }

    #[tokio::test]
    async fn test_rate_limiter_refill() {
        let config = RateLimiterConfig::new(2, Duration::from_millis(100));
        let limiter = RateLimiter::new(config);

        assert!(limiter.try_acquire().await);
        assert!(limiter.try_acquire().await);
        assert!(!limiter.try_acquire().await);

        // At least one full period elapses, refilling the bucket.
        sleep(Duration::from_millis(150)).await;

        assert!(limiter.try_acquire().await);
    }

    #[tokio::test]
    async fn test_rate_limiter_wait() {
        let config = RateLimiterConfig::new(2, Duration::from_millis(100));
        let limiter = RateLimiter::new(config);

        limiter.wait().await;
        limiter.wait().await;

        // The bucket is now empty, so this wait has to block until the next refill.
        let start = Instant::now();
        limiter.wait().await;
        let elapsed = start.elapsed();

        assert!(elapsed >= Duration::from_millis(80));
    }

    #[tokio::test]
    async fn test_rate_limiter_custom_cost() {
        let config = RateLimiterConfig::new(10, Duration::from_secs(1));
        let limiter = RateLimiter::new(config);

        assert!(limiter.try_acquire_with_cost(5).await);
        assert_eq!(limiter.available_tokens().await, 5);

        assert!(limiter.try_acquire_with_cost(3).await);
        assert_eq!(limiter.available_tokens().await, 2);

        assert!(!limiter.try_acquire_with_cost(3).await);
    }

    #[tokio::test]
    async fn test_rate_limiter_reset() {
        let config = RateLimiterConfig::new(5, Duration::from_secs(1));
        let limiter = RateLimiter::new(config);

        for _ in 0..5 {
            limiter.wait().await;
        }

        assert_eq!(limiter.available_tokens().await, 0);

        limiter.reset().await;

        assert_eq!(limiter.available_tokens().await, 5);
    }

    #[tokio::test]
    async fn test_multi_tier_rate_limiter() {
        let multi = MultiTierRateLimiter::new();

        let public_config = RateLimiterConfig::new(10, Duration::from_secs(1));
        let private_config = RateLimiterConfig::new(5, Duration::from_secs(1));

        multi
            .add_tier("public".to_string(), RateLimiter::new(public_config))
            .await;
        multi
            .add_tier("private".to_string(), RateLimiter::new(private_config))
            .await;

        for _ in 0..10 {
            assert!(multi.try_acquire("public").await);
        }
        assert!(!multi.try_acquire("public").await);

        for _ in 0..5 {
            assert!(multi.try_acquire("private").await);
        }
        assert!(!multi.try_acquire("private").await);

        // Unregistered tiers are not rate limited.
        assert!(multi.try_acquire("unknown").await);
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        let config = RateLimiterConfig::new(10, Duration::from_secs(1));
        let limiter = RateLimiter::new(config);

        // Clones share one bucket, so ten concurrent waits drain it exactly.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let limiter_clone = limiter.clone();
                tokio::spawn(async move {
                    limiter_clone.wait().await;
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(limiter.available_tokens().await, 0);
    }

    #[test]
    fn test_rate_limiter_config_validate_default() {
        let config = RateLimiterConfig::default();
        let result = config.validate();
        assert!(result.is_ok());
        assert!(result.unwrap().warnings.is_empty());
    }

    #[test]
    fn test_rate_limiter_config_validate_zero_capacity() {
        let config = RateLimiterConfig::new(0, Duration::from_secs(1));
        let result = config.validate();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.field_name(), "capacity");
        assert!(matches!(
            err,
            crate::error::ConfigValidationError::ValueInvalid { .. }
        ));
    }

    #[test]
    fn test_rate_limiter_config_validate_short_refill_period_warning() {
        let config = RateLimiterConfig::new(10, Duration::from_millis(50));
        let result = config.validate();
        assert!(result.is_ok());
        let validation_result = result.unwrap();
        assert!(!validation_result.warnings.is_empty());
        assert!(validation_result.warnings[0].contains("refill_period"));
        assert!(validation_result.warnings[0].contains("very short"));
    }

    #[test]
    fn test_rate_limiter_config_validate_refill_period_boundary() {
        // Exactly 100 ms is accepted without a warning ...
        let config = RateLimiterConfig::new(10, Duration::from_millis(100));
        let result = config.validate();
        assert!(result.is_ok());
        assert!(result.unwrap().warnings.is_empty());

        // ... while anything shorter warns.
        let config = RateLimiterConfig::new(10, Duration::from_millis(99));
        let result = config.validate();
        assert!(result.is_ok());
        assert!(!result.unwrap().warnings.is_empty());
    }

    #[test]
    fn test_rate_limiter_config_validate_valid_config() {
        let config = RateLimiterConfig::new(100, Duration::from_secs(60));
        let result = config.validate();
        assert!(result.is_ok());
        assert!(result.unwrap().warnings.is_empty());
    }

    /// Partial periods must accumulate across `refill` calls without losing nanoseconds to
    /// integer division: three thirds of a period (rounded down) plus the 1 ns lost to the
    /// rounding make exactly one full period.
    #[test]
    fn test_rate_limiter_integer_precision() {
        // A long period keeps the wall-clock time spent inside the test negligible.
        let period_nanos = 10_000_000_000u64;
        let config = RateLimiterConfig::new(100, Duration::from_nanos(period_nanos))
            .with_refill_amount(10);
        let mut state = RateLimiterState::new(config);
        state.tokens = 0;

        let third = period_nanos / 3;

        // Two thirds of a period: no tokens yet, and the partial period is preserved.
        state.remainder_nanos = third * 2;
        state.last_refill = Instant::now();
        state.refill();
        assert_eq!(state.tokens, 0);
        assert_eq!(state.remainder_nanos, third * 2);

        // Three (rounded-down) thirds plus the rounding nanosecond complete one period.
        state.remainder_nanos = third * 3 + 1;
        state.last_refill = Instant::now();
        state.refill();
        assert_eq!(state.tokens, 10);
        // Only the wall-clock time that elapsed inside the test is carried over.
        assert!(state.remainder_nanos < period_nanos / 10);
    }

    #[tokio::test]
    async fn test_rate_limiter_reset_clears_remainder() {
        let config = RateLimiterConfig::new(5, Duration::from_secs(1));
        let limiter = RateLimiter::new(config);

        for _ in 0..5 {
            limiter.wait().await;
        }

        // Seed a partial period so the reset has something observable to clear.
        limiter.state.lock().await.remainder_nanos = 500_000_000;

        limiter.reset().await;

        assert_eq!(limiter.state.lock().await.remainder_nanos, 0);
        assert_eq!(limiter.available_tokens().await, 5);
    }

    /// A zero refill period must not divide by zero; the bucket simply never refills.
    #[test]
    fn test_rate_limiter_refill_zero_period_protection() {
        let config = RateLimiterConfig {
            capacity: 10,
            refill_period: Duration::ZERO,
            refill_amount: 5,
            cost_per_request: 1,
        };
        let mut state = RateLimiterState::new(config);
        state.tokens = 0;

        state.refill();

        assert_eq!(state.tokens, 0);
    }

    /// A huge number of elapsed periods times a huge refill amount must saturate at the
    /// capacity instead of overflowing.
    #[test]
    fn test_rate_limiter_refill_overflow_protection() {
        let config =
            RateLimiterConfig::new(u32::MAX, Duration::from_nanos(1)).with_refill_amount(u32::MAX);
        let mut state = RateLimiterState::new(config);
        state.tokens = 0;

        // Roughly 9.2e18 complete 1 ns periods, each worth u32::MAX tokens.
        state.remainder_nanos = u64::MAX / 2;

        state.refill();

        assert_eq!(state.tokens, u32::MAX);
        // With a 1 ns period every accrued nanosecond is a complete period.
        assert_eq!(state.remainder_nanos, 0);
    }
}