1#[cfg(feature = "redis")]
36use anyhow::anyhow;
37use anyhow::Result;
38use chrono::{DateTime, Duration as ChronoDuration, Utc};
39use serde::{Deserialize, Serialize};
40use std::collections::{HashMap, VecDeque};
41use std::sync::Arc;
42use tokio::sync::RwLock;
43use tracing::{debug, info, warn};
44
/// Selects the rate-limiting algorithm and carries its tuning parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RateLimitAlgorithm {
    /// Classic token bucket: up to `capacity` tokens, refilled continuously
    /// at `refill_rate` tokens per second.
    TokenBucket { capacity: u64, refill_rate: u64 },

    /// Allows at most `max_requests` within any rolling `window_size` span.
    SlidingWindow {
        window_size: ChronoDuration,
        max_requests: u64,
    },

    /// Queue of at most `capacity` items draining at `leak_rate` per second.
    LeakyBucket { capacity: u64, leak_rate: u64 },

    /// Fixed window allowing `max_requests` per `window_size`.
    /// NOTE(review): `create_state` currently backs this variant with a
    /// sliding window — confirm the approximation is acceptable.
    FixedWindow {
        window_size: ChronoDuration,
        max_requests: u64,
    },

    /// Adaptive limiting seeded from `base_limit`.
    /// NOTE(review): `adjustment_factor` is not consumed by `create_state`;
    /// the variant currently degrades to a plain token bucket — confirm.
    Adaptive {
        base_limit: u64,
        adjustment_factor: f64,
    },
}
82
83impl Default for RateLimitAlgorithm {
84 fn default() -> Self {
85 Self::TokenBucket {
86 capacity: 1000,
87 refill_rate: 100,
88 }
89 }
90}
91
/// Top-level configuration for [`RateLimiter`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Algorithm (and its parameters) used to build each tenant's state.
    pub algorithm: RateLimitAlgorithm,

    /// Coordinate limits across processes via Redis (requires the `redis`
    /// feature; without it state stays per-process).
    pub distributed: bool,

    /// Redis connection URL; required when `distributed` is true.
    pub redis_url: Option<String>,

    /// Track quota overrides per tenant.
    /// NOTE(review): stored quotas are not consulted by `allow` — confirm.
    pub per_tenant_quotas: bool,

    /// Quota returned for tenants without an explicit override.
    pub default_quota: QuotaLimits,

    /// Enable adaptive limit adjustment.
    pub enable_adaptive: bool,

    /// Metrics and alerting settings.
    pub monitoring: RateLimitMonitoringConfig,

    /// What to do with requests that exceed the limit.
    pub rejection_strategy: RejectionStrategy,
}
119
120impl Default for RateLimitConfig {
121 fn default() -> Self {
122 Self {
123 algorithm: RateLimitAlgorithm::default(),
124 distributed: false,
125 redis_url: None,
126 per_tenant_quotas: true,
127 default_quota: QuotaLimits::default(),
128 enable_adaptive: true,
129 monitoring: RateLimitMonitoringConfig::default(),
130 rejection_strategy: RejectionStrategy::ImmediateReject,
131 }
132 }
133}
134
/// Per-tenant quota ceilings across several time horizons.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuotaLimits {
    /// Maximum requests per second.
    pub requests_per_second: u64,

    /// Maximum requests per minute.
    pub requests_per_minute: u64,

    /// Maximum requests per hour.
    pub requests_per_hour: u64,

    /// Maximum requests per day.
    pub requests_per_day: u64,

    /// Maximum bandwidth in bytes per second.
    pub bandwidth_bytes_per_second: u64,

    /// Maximum in-flight requests at any moment.
    pub max_concurrent_requests: u32,

    /// Maximum burst size above the steady rate.
    pub max_burst: u64,
}
159
160impl Default for QuotaLimits {
161 fn default() -> Self {
162 Self {
163 requests_per_second: 100,
164 requests_per_minute: 5000,
165 requests_per_hour: 100_000,
166 requests_per_day: 1_000_000,
167 bandwidth_bytes_per_second: 10_485_760, max_concurrent_requests: 100,
169 max_burst: 200,
170 }
171 }
172}
173
174#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
176pub enum RejectionStrategy {
177 ImmediateReject,
179
180 QueueWithTimeout(u64), ExponentialBackoff {
185 initial_delay_ms: u64,
186 max_delay_ms: u64,
187 },
188
189 BestEffort,
191}
192
/// Metrics and alerting settings for the rate limiter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitMonitoringConfig {
    /// Emit metrics at all.
    pub enable_metrics: bool,

    /// Interval between metric emissions.
    pub metrics_interval: ChronoDuration,

    /// Raise alerts at all.
    pub enable_alerts: bool,

    /// Threshold that triggers an alert
    /// (presumably a rejection-rate fraction in [0, 1] — TODO confirm).
    pub alert_threshold: f64,

    /// Minimum quiet period between consecutive alerts.
    pub alert_cooldown: ChronoDuration,
}
211
212impl Default for RateLimitMonitoringConfig {
213 fn default() -> Self {
214 Self {
215 enable_metrics: true,
216 metrics_interval: ChronoDuration::seconds(60),
217 enable_alerts: true,
218 alert_threshold: 0.9, alert_cooldown: ChronoDuration::minutes(5),
220 }
221 }
222}
223
/// In-memory token-bucket state for a single tenant.
#[derive(Debug, Clone)]
struct TokenBucketState {
    // Currently available tokens; fractional so refill stays precise.
    tokens: f64,
    // Maximum tokens the bucket can hold.
    capacity: u64,
    // Tokens credited per second.
    refill_rate: u64,
    // Instant of the last refill computation.
    last_refill: DateTime<Utc>,
}
232
233impl TokenBucketState {
234 fn new(capacity: u64, refill_rate: u64) -> Self {
235 Self {
236 tokens: capacity as f64,
237 capacity,
238 refill_rate,
239 last_refill: Utc::now(),
240 }
241 }
242
243 fn refill(&mut self) {
244 let now = Utc::now();
245 let elapsed = now.signed_duration_since(self.last_refill);
246 let seconds = elapsed.num_milliseconds() as f64 / 1000.0;
247
248 let new_tokens = seconds * self.refill_rate as f64;
249 self.tokens = (self.tokens + new_tokens).min(self.capacity as f64);
250 self.last_refill = now;
251 }
252
253 fn consume(&mut self, tokens: u64) -> bool {
254 self.refill();
255
256 if self.tokens >= tokens as f64 {
257 self.tokens -= tokens as f64;
258 true
259 } else {
260 false
261 }
262 }
263
264 fn available_tokens(&self) -> u64 {
265 self.tokens.floor() as u64
266 }
267}
268
/// In-memory sliding-window state for a single tenant.
#[derive(Debug, Clone)]
struct SlidingWindowState {
    // Timestamps of requests admitted within the current window,
    // oldest at the front.
    requests: VecDeque<DateTime<Utc>>,
    // Length of the rolling window.
    window_size: ChronoDuration,
    // Maximum requests admitted per window.
    max_requests: u64,
}
276
277impl SlidingWindowState {
278 fn new(window_size: ChronoDuration, max_requests: u64) -> Self {
279 Self {
280 requests: VecDeque::new(),
281 window_size,
282 max_requests,
283 }
284 }
285
286 fn cleanup(&mut self) {
287 let now = Utc::now();
288 let cutoff = now - self.window_size;
289
290 while let Some(&oldest) = self.requests.front() {
291 if oldest < cutoff {
292 self.requests.pop_front();
293 } else {
294 break;
295 }
296 }
297 }
298
299 fn allow(&mut self) -> bool {
300 self.cleanup();
301
302 if self.requests.len() < self.max_requests as usize {
303 self.requests.push_back(Utc::now());
304 true
305 } else {
306 false
307 }
308 }
309
310 fn current_count(&self) -> usize {
311 self.requests.len()
312 }
313}
314
/// In-memory leaky-bucket state for a single tenant.
#[derive(Debug, Clone)]
struct LeakyBucketState {
    // Units currently queued in the bucket.
    queue_size: u64,
    // Maximum units the bucket can hold.
    capacity: u64,
    // Units drained per second.
    leak_rate: u64,
    // Instant of the last drain computation.
    last_leak: DateTime<Utc>,
}
323
324impl LeakyBucketState {
325 fn new(capacity: u64, leak_rate: u64) -> Self {
326 Self {
327 queue_size: 0,
328 capacity,
329 leak_rate,
330 last_leak: Utc::now(),
331 }
332 }
333
334 fn leak(&mut self) {
335 let now = Utc::now();
336 let elapsed = now.signed_duration_since(self.last_leak);
337 let seconds = elapsed.num_milliseconds() as f64 / 1000.0;
338
339 let leaked = (seconds * self.leak_rate as f64) as u64;
340 self.queue_size = self.queue_size.saturating_sub(leaked);
341 self.last_leak = now;
342 }
343
344 fn add(&mut self, items: u64) -> bool {
345 self.leak();
346
347 if self.queue_size + items <= self.capacity {
348 self.queue_size += items;
349 true
350 } else {
351 false
352 }
353 }
354}
355
/// Per-tenant algorithm state; the variant mirrors the configured
/// [`RateLimitAlgorithm`] (FixedWindow and Adaptive map onto these too —
/// see `RateLimiter::create_state`).
#[derive(Debug)]
enum RateLimiterState {
    TokenBucket(TokenBucketState),
    SlidingWindow(SlidingWindowState),
    LeakyBucket(LeakyBucketState),
}
363
/// Multi-tenant rate limiter with lazily created per-tenant state.
pub struct RateLimiter {
    /// Configuration the limiter was built with.
    config: RateLimitConfig,
    /// Per-tenant algorithm state, keyed by tenant id.
    states: Arc<RwLock<HashMap<String, RateLimiterState>>>,
    /// Per-tenant quota overrides, keyed by tenant id.
    /// NOTE(review): stored but not consulted by `allow` — confirm intent.
    quotas: Arc<RwLock<HashMap<String, QuotaLimits>>>,
    /// Aggregate counters across all tenants.
    stats: Arc<RwLock<RateLimitStats>>,
    /// Shared Redis client for distributed mode.
    /// NOTE(review): constructed but not used elsewhere in this file.
    #[cfg(feature = "redis")]
    redis_client: Option<Arc<redis::Client>>,
}
373
374impl RateLimiter {
375 pub fn new(config: RateLimitConfig) -> Result<Self> {
377 #[cfg(feature = "redis")]
378 let redis_client = if config.distributed {
379 if let Some(ref url) = config.redis_url {
380 Some(Arc::new(redis::Client::open(url.as_str())?))
381 } else {
382 return Err(anyhow!("Redis URL required for distributed rate limiting"));
383 }
384 } else {
385 None
386 };
387
388 Ok(Self {
389 config,
390 states: Arc::new(RwLock::new(HashMap::new())),
391 quotas: Arc::new(RwLock::new(HashMap::new())),
392 stats: Arc::new(RwLock::new(RateLimitStats::default())),
393 #[cfg(feature = "redis")]
394 redis_client,
395 })
396 }
397
398 pub async fn allow(&self, tenant_id: &str, tokens: u64) -> Result<bool> {
400 let mut states = self.states.write().await;
401 let mut stats = self.stats.write().await;
402
403 stats.total_requests += 1;
404
405 let state = states
407 .entry(tenant_id.to_string())
408 .or_insert_with(|| self.create_state());
409
410 let allowed = match state {
411 RateLimiterState::TokenBucket(bucket) => bucket.consume(tokens),
412 RateLimiterState::SlidingWindow(window) => {
413 if tokens != 1 {
414 warn!("Sliding window only supports single requests");
415 }
416 window.allow()
417 }
418 RateLimiterState::LeakyBucket(bucket) => bucket.add(tokens),
419 };
420
421 if allowed {
422 stats.allowed_requests += 1;
423 debug!(
424 "Request allowed for tenant {}: {} tokens",
425 tenant_id, tokens
426 );
427 } else {
428 stats.rejected_requests += 1;
429 warn!(
430 "Request rejected for tenant {}: rate limit exceeded",
431 tenant_id
432 );
433 }
434
435 Ok(allowed)
436 }
437
438 pub async fn set_quota(&self, tenant_id: &str, quota: QuotaLimits) -> Result<()> {
440 let mut quotas = self.quotas.write().await;
441 quotas.insert(tenant_id.to_string(), quota);
442 info!("Updated quota for tenant {}", tenant_id);
443 Ok(())
444 }
445
446 pub async fn get_quota(&self, tenant_id: &str) -> Result<QuotaLimits> {
448 let quotas = self.quotas.read().await;
449 Ok(quotas
450 .get(tenant_id)
451 .cloned()
452 .unwrap_or_else(|| self.config.default_quota.clone()))
453 }
454
455 pub async fn remaining_quota(&self, tenant_id: &str) -> Result<u64> {
457 let states = self.states.read().await;
458
459 match states.get(tenant_id) {
460 Some(RateLimiterState::TokenBucket(bucket)) => Ok(bucket.available_tokens()),
461 Some(RateLimiterState::SlidingWindow(window)) => Ok(window
462 .max_requests
463 .saturating_sub(window.current_count() as u64)),
464 Some(RateLimiterState::LeakyBucket(bucket)) => {
465 Ok(bucket.capacity.saturating_sub(bucket.queue_size))
466 }
467 None => Ok(0),
468 }
469 }
470
471 pub async fn reset(&self, tenant_id: &str) -> Result<()> {
473 let mut states = self.states.write().await;
474 states.remove(tenant_id);
475 info!("Reset rate limit state for tenant {}", tenant_id);
476 Ok(())
477 }
478
479 pub async fn stats(&self) -> Result<RateLimitStats> {
481 let stats = self.stats.read().await;
482 Ok(stats.clone())
483 }
484
485 pub async fn clear(&self) -> Result<()> {
487 let mut states = self.states.write().await;
488 let mut quotas = self.quotas.write().await;
489 states.clear();
490 quotas.clear();
491 info!("Cleared all rate limiting state");
492 Ok(())
493 }
494
495 fn create_state(&self) -> RateLimiterState {
497 match &self.config.algorithm {
498 RateLimitAlgorithm::TokenBucket {
499 capacity,
500 refill_rate,
501 } => RateLimiterState::TokenBucket(TokenBucketState::new(*capacity, *refill_rate)),
502 RateLimitAlgorithm::SlidingWindow {
503 window_size,
504 max_requests,
505 } => RateLimiterState::SlidingWindow(SlidingWindowState::new(
506 *window_size,
507 *max_requests,
508 )),
509 RateLimitAlgorithm::LeakyBucket {
510 capacity,
511 leak_rate,
512 } => RateLimiterState::LeakyBucket(LeakyBucketState::new(*capacity, *leak_rate)),
513 RateLimitAlgorithm::FixedWindow {
514 window_size,
515 max_requests,
516 } => {
517 RateLimiterState::SlidingWindow(SlidingWindowState::new(
519 *window_size,
520 *max_requests,
521 ))
522 }
523 RateLimitAlgorithm::Adaptive { base_limit, .. } => {
524 RateLimiterState::TokenBucket(TokenBucketState::new(*base_limit, *base_limit / 10))
526 }
527 }
528 }
529}
530
/// Aggregate counters across all tenants of a [`RateLimiter`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RateLimitStats {
    /// Total `allow` calls observed.
    pub total_requests: u64,

    /// Calls that were admitted.
    pub allowed_requests: u64,

    /// Calls that were rejected.
    pub rejected_requests: u64,

    /// Number of tenants with live limiter state.
    /// NOTE(review): never written by `allow` in this version — always 0.
    pub active_tenants: usize,

    /// rejected / total; see `calculate_rejection_rate`.
    pub rejection_rate: f64,
}
549
550impl RateLimitStats {
551 pub fn calculate_rejection_rate(&mut self) {
553 if self.total_requests > 0 {
554 self.rejection_rate = self.rejected_requests as f64 / self.total_requests as f64;
555 }
556 }
557}
558
/// High-level quota façade over a [`RateLimiter`], translating logical
/// operations (requests, bandwidth, storage) into token costs.
pub struct QuotaManager {
    /// Underlying limiter doing the actual accounting.
    limiter: Arc<RateLimiter>,
    /// Enforcement mode.
    /// NOTE(review): set to `Strict` at construction and never read — confirm.
    enforcement_mode: QuotaEnforcementMode,
}
564
565impl QuotaManager {
566 pub fn new(config: RateLimitConfig) -> Result<Self> {
568 Ok(Self {
569 limiter: Arc::new(RateLimiter::new(config)?),
570 enforcement_mode: QuotaEnforcementMode::Strict,
571 })
572 }
573
574 pub async fn check_quota(
576 &self,
577 tenant_id: &str,
578 operation: &QuotaOperation,
579 ) -> Result<QuotaCheckResult> {
580 let tokens = match operation {
581 QuotaOperation::Request { count } => *count,
582 QuotaOperation::Bandwidth { bytes } => bytes / 1024, QuotaOperation::Storage { bytes } => bytes / (1024 * 1024), };
585
586 let allowed = self.limiter.allow(tenant_id, tokens).await?;
587 let remaining = self.limiter.remaining_quota(tenant_id).await?;
588
589 Ok(QuotaCheckResult {
590 allowed,
591 remaining,
592 reset_at: Utc::now() + ChronoDuration::seconds(60),
593 retry_after: if allowed {
594 None
595 } else {
596 Some(ChronoDuration::seconds(1))
597 },
598 })
599 }
600
601 pub async fn update_quota(&self, tenant_id: &str, quota: QuotaLimits) -> Result<()> {
603 self.limiter.set_quota(tenant_id, quota).await
604 }
605}
606
/// How strictly quota violations are enforced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuotaEnforcementMode {
    /// Reject requests that exceed the quota.
    Strict,
    /// Allow but flag violations (presumably log/metric only — TODO confirm;
    /// not consulted by `QuotaManager::check_quota` in this version).
    Soft,
    /// No enforcement.
    Disabled,
}
617
/// A metered operation submitted to [`QuotaManager::check_quota`].
#[derive(Debug, Clone)]
pub enum QuotaOperation {
    /// Plain request count; charged one token per request.
    Request { count: u64 },
    /// Bandwidth usage in bytes; charged per KiB.
    Bandwidth { bytes: u64 },
    /// Storage usage in bytes; charged per MiB.
    Storage { bytes: u64 },
}
628
/// Outcome of a quota check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuotaCheckResult {
    /// Whether the operation was admitted.
    pub allowed: bool,
    /// Estimated units still available for this tenant.
    pub remaining: u64,
    /// When the quota window resets (fixed 60 s horizon in this version).
    pub reset_at: DateTime<Utc>,
    /// Suggested wait before retrying; `None` when the request was allowed.
    pub retry_after: Option<ChronoDuration>,
}
641
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a limiter that uses `algorithm` with every other setting left
    /// at its default value.
    fn limiter_with(algorithm: RateLimitAlgorithm) -> RateLimiter {
        let config = RateLimitConfig {
            algorithm,
            ..Default::default()
        };
        RateLimiter::new(config).unwrap()
    }

    #[tokio::test]
    async fn test_token_bucket_basic() {
        let limiter = limiter_with(RateLimitAlgorithm::TokenBucket {
            capacity: 10,
            refill_rate: 1,
        });

        // The bucket starts full, so exactly `capacity` requests succeed.
        for i in 0..10 {
            assert!(
                limiter.allow("tenant-1", 1).await.unwrap(),
                "Request {} should be allowed",
                i
            );
        }

        assert!(
            !limiter.allow("tenant-1", 1).await.unwrap(),
            "Request 11 should be rejected"
        );
    }

    #[tokio::test]
    async fn test_sliding_window_basic() {
        let limiter = limiter_with(RateLimitAlgorithm::SlidingWindow {
            window_size: ChronoDuration::seconds(1),
            max_requests: 5,
        });

        for i in 0..5 {
            assert!(
                limiter.allow("tenant-1", 1).await.unwrap(),
                "Request {} should be allowed",
                i
            );
        }

        assert!(
            !limiter.allow("tenant-1", 1).await.unwrap(),
            "Request 6 should be rejected"
        );
    }

    #[tokio::test]
    async fn test_multi_tenant_isolation() {
        let config = RateLimitConfig {
            algorithm: RateLimitAlgorithm::TokenBucket {
                capacity: 5,
                refill_rate: 1,
            },
            per_tenant_quotas: true,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config).unwrap();

        // Exhaust tenant-1's bucket completely.
        for _ in 0..5 {
            assert!(limiter.allow("tenant-1", 1).await.unwrap());
        }
        assert!(!limiter.allow("tenant-1", 1).await.unwrap());

        // tenant-2 has independent state and is unaffected.
        assert!(limiter.allow("tenant-2", 1).await.unwrap());
    }

    #[tokio::test]
    async fn test_quota_manager() {
        let config = RateLimitConfig {
            algorithm: RateLimitAlgorithm::TokenBucket {
                capacity: 100,
                refill_rate: 10,
            },
            ..Default::default()
        };
        let manager = QuotaManager::new(config).unwrap();

        let op = QuotaOperation::Request { count: 50 };
        let result = manager.check_quota("tenant-1", &op).await.unwrap();
        assert!(result.allowed);
        assert!(result.remaining > 0);
    }

    #[tokio::test]
    async fn test_quota_reset() {
        let limiter = limiter_with(RateLimitAlgorithm::TokenBucket {
            capacity: 5,
            refill_rate: 1,
        });

        for _ in 0..5 {
            limiter.allow("tenant-1", 1).await.unwrap();
        }
        assert!(!limiter.allow("tenant-1", 1).await.unwrap());

        // Resetting discards the tenant's state, so the next request sees a
        // freshly filled bucket.
        limiter.reset("tenant-1").await.unwrap();
        assert!(limiter.allow("tenant-1", 1).await.unwrap());
    }

    #[tokio::test]
    async fn test_custom_quota() {
        let limiter = RateLimiter::new(RateLimitConfig::default()).unwrap();

        let quota = QuotaLimits {
            requests_per_second: 1000,
            ..Default::default()
        };
        limiter
            .set_quota("premium-tenant", quota.clone())
            .await
            .unwrap();

        let retrieved = limiter.get_quota("premium-tenant").await.unwrap();
        assert_eq!(retrieved.requests_per_second, 1000);
    }

    #[tokio::test]
    async fn test_rate_limit_stats() {
        let limiter = limiter_with(RateLimitAlgorithm::TokenBucket {
            capacity: 3,
            refill_rate: 1,
        });

        // Three requests fit the bucket; the fourth is rejected.
        for _ in 0..4 {
            let _ = limiter.allow("tenant-1", 1).await.unwrap();
        }

        let stats = limiter.stats().await.unwrap();
        assert_eq!(stats.total_requests, 4);
        assert_eq!(stats.allowed_requests, 3);
        assert_eq!(stats.rejected_requests, 1);
    }
}