1use crate::errors::{AuthError, Result};
11use parking_lot::RwLock;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct RateLimitConfig {
20 pub max_requests: u32,
22 pub window_duration: Duration,
24 pub strategy: RateLimitStrategy,
26 pub distributed: bool,
28 pub redis_url: Option<String>,
30 pub burst_allowance: Option<u32>,
32 pub adaptive: bool,
34 pub penalty_duration: Option<Duration>,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
40pub enum RateLimitStrategy {
41 TokenBucket,
43 FixedWindow,
45 SlidingWindow,
47 Adaptive,
49}
50
51impl Default for RateLimitConfig {
52 fn default() -> Self {
53 Self {
54 max_requests: 100,
55 window_duration: Duration::from_secs(60),
56 strategy: RateLimitStrategy::SlidingWindow,
57 distributed: false,
58 redis_url: None,
59 burst_allowance: Some(20),
60 adaptive: false,
61 penalty_duration: Some(Duration::from_secs(300)), }
63 }
64}
65
66impl RateLimitConfig {
67 pub fn strict_auth() -> Self {
69 Self {
70 max_requests: 5,
71 window_duration: Duration::from_secs(300), strategy: RateLimitStrategy::FixedWindow,
73 distributed: true,
74 redis_url: None, burst_allowance: None, adaptive: false,
77 penalty_duration: Some(Duration::from_secs(3600)), }
79 }
80
81 pub fn lenient_api() -> Self {
83 Self {
84 max_requests: 1000,
85 window_duration: Duration::from_secs(60),
86 strategy: RateLimitStrategy::TokenBucket,
87 distributed: false,
88 redis_url: None,
89 burst_allowance: Some(200),
90 adaptive: true,
91 penalty_duration: Some(Duration::from_secs(60)),
92 }
93 }
94
95 pub fn balanced() -> Self {
97 Self::default()
98 }
99}
100
/// Outcome of a single rate-limit check.
#[derive(Debug, Clone, PartialEq)]
pub enum RateLimitResult {
    /// The request may proceed.
    Allowed {
        /// Requests (or whole tokens) left for this key.
        remaining: u32,
        /// Approximate moment the allowance is replenished.
        reset_at: Instant,
    },
    /// The request exceeded the limit.
    Denied {
        /// Suggested wait before retrying.
        retry_after: Duration,
        /// Hits counted for the key; some backends report an estimate.
        total_hits: u32,
    },
    /// The key is serving a penalty for earlier violations.
    Blocked {
        /// When the penalty expires.
        unblock_at: Instant,
        /// Human-readable explanation.
        reason: String,
    },
}
114
115pub struct DistributedRateLimiter {
117 config: RateLimitConfig,
118 in_memory_limiter: Option<Arc<InMemoryRateLimiter>>,
119 #[cfg(feature = "redis-storage")]
120 redis_limiter: Option<Arc<RedisRateLimiter>>,
121 penalties: Arc<RwLock<HashMap<String, Instant>>>,
123}
124
125impl DistributedRateLimiter {
126 pub async fn new(config: RateLimitConfig) -> Result<Self> {
128 let in_memory_limiter = if config.distributed {
129 None
130 } else {
131 Some(Arc::new(InMemoryRateLimiter::new(&config)?))
132 };
133
134 #[cfg(feature = "redis-storage")]
135 let redis_limiter = if config.distributed && config.redis_url.is_some() {
136 Some(Arc::new(RedisRateLimiter::new(&config).await?))
137 } else {
138 None
139 };
140
141 #[cfg(not(feature = "redis-storage"))]
142 {
143 tracing::warn!(
145 "Redis storage not available for distributed rate limiting - using in-memory only"
146 );
147 tracing::warn!(
148 "For production deployments, enable 'redis-storage' feature for true distributed limiting"
149 );
150 }
151
152 Ok(Self {
153 config,
154 in_memory_limiter,
155 #[cfg(feature = "redis-storage")]
156 redis_limiter,
157 penalties: Arc::new(RwLock::new(HashMap::new())),
158 })
159 }
160
161 pub async fn check_rate_limit(&self, key: &str) -> Result<RateLimitResult> {
163 if let Some(unblock_at) = self.get_penalty_expiry(key) {
165 if Instant::now() < unblock_at {
166 return Ok(RateLimitResult::Blocked {
167 unblock_at,
168 reason: "Previous rate limit violations".to_string(),
169 });
170 } else {
171 self.remove_penalty(key);
173 }
174 }
175
176 let result = if self.config.distributed {
178 #[cfg(feature = "redis-storage")]
179 if let Some(ref redis_limiter) = self.redis_limiter {
180 redis_limiter.check_rate_limit(key).await?
181 } else {
182 self.fallback_check(key).await?
184 }
185 #[cfg(not(feature = "redis-storage"))]
186 self.fallback_check(key).await?
187 } else if let Some(ref in_memory_limiter) = self.in_memory_limiter {
188 in_memory_limiter.check_rate_limit(key).await?
189 } else {
190 return Err(AuthError::internal("No rate limiter configured"));
191 };
192
193 if matches!(result, RateLimitResult::Denied { .. })
195 && let Some(penalty_duration) = self.config.penalty_duration
196 {
197 self.apply_penalty(key, penalty_duration);
198 }
199
200 Ok(result)
201 }
202
203 pub async fn check_multiple_limits(
205 &self,
206 checks: &[(String, RateLimitConfig)],
207 ) -> Result<RateLimitResult> {
208 for (key, config) in checks {
209 let limiter = Self::new(config.clone()).await?;
210 let result = limiter.check_rate_limit(key).await?;
211
212 if !matches!(result, RateLimitResult::Allowed { .. }) {
214 return Ok(result);
215 }
216 }
217
218 Ok(RateLimitResult::Allowed {
220 remaining: u32::MAX, reset_at: Instant::now() + self.config.window_duration,
222 })
223 }
224
225 fn get_penalty_expiry(&self, key: &str) -> Option<Instant> {
226 let penalties = self.penalties.read();
227 penalties.get(key).copied()
228 }
229
230 fn apply_penalty(&self, key: &str, duration: Duration) {
231 let mut penalties = self.penalties.write();
232 penalties.insert(key.to_string(), Instant::now() + duration);
233 }
234
235 fn remove_penalty(&self, key: &str) {
236 let mut penalties = self.penalties.write();
237 penalties.remove(key);
238 }
239
240 async fn fallback_check(&self, key: &str) -> Result<RateLimitResult> {
241 let limiter = InMemoryRateLimiter::new(&self.config)?;
243 limiter.check_rate_limit(key).await
244 }
245}
246
247pub struct InMemoryRateLimiter {
249 config: RateLimitConfig,
250 buckets: std::sync::Arc<dashmap::DashMap<String, TokenBucket>>,
251}
252
/// Per-key counter used by [`InMemoryRateLimiter`].
///
/// Despite the name, it behaves as a fixed window. The full allowance is restored once
/// `window_duration` has elapsed since `last_refill`.
#[derive(Debug, Clone)]
struct TokenBucket {
    /// Requests still allowed in the current window.
    tokens: u32,
    /// Start of the current window.
    last_refill: Instant,
}
258
259impl TokenBucket {
260 fn new(capacity: u32) -> Self {
261 Self {
262 tokens: capacity,
263 last_refill: Instant::now(),
264 }
265 }
266
267 fn try_consume(&mut self, config: &RateLimitConfig) -> bool {
268 let now = Instant::now();
269 let elapsed = now.duration_since(self.last_refill);
270
271 if elapsed >= config.window_duration {
273 self.tokens = config.max_requests;
274 self.last_refill = now;
275 }
276
277 if self.tokens > 0 {
278 self.tokens -= 1;
279 true
280 } else {
281 false
282 }
283 }
284}
285
286impl InMemoryRateLimiter {
287 pub fn new(config: &RateLimitConfig) -> Result<Self> {
288 Ok(Self {
289 config: config.clone(),
290 buckets: std::sync::Arc::new(dashmap::DashMap::new()),
291 })
292 }
293
294 pub async fn check_rate_limit(&self, key: &str) -> Result<RateLimitResult> {
295 let mut bucket = self
296 .buckets
297 .entry(key.to_string())
298 .or_insert_with(|| TokenBucket::new(self.config.max_requests));
299
300 if bucket.try_consume(&self.config) {
301 Ok(RateLimitResult::Allowed {
302 remaining: bucket.tokens,
303 reset_at: bucket.last_refill + self.config.window_duration,
304 })
305 } else {
306 let retry_after =
307 (bucket.last_refill + self.config.window_duration).duration_since(Instant::now());
308
309 Ok(RateLimitResult::Denied {
310 retry_after,
311 total_hits: self.config.max_requests + 1, })
313 }
314 }
315}
316
/// Rate limiter that keeps its counters in Redis, so limits hold across instances.
#[cfg(feature = "redis-storage")]
pub struct RedisRateLimiter {
    /// Client created from `RateLimitConfig::redis_url`.
    client: redis::Client,
    /// Limits and the counting strategy to apply.
    config: RateLimitConfig,
}
323
#[cfg(feature = "redis-storage")]
impl RedisRateLimiter {
    /// Creates a limiter backed by the Redis server at `config.redis_url`.
    ///
    /// # Errors
    /// Returns a config error when `redis_url` is `None`. Returns an internal error when
    /// `redis::Client::open` rejects the URL.
    pub async fn new(config: &RateLimitConfig) -> Result<Self> {
        let redis_url = config
            .redis_url
            .as_ref()
            .ok_or_else(|| AuthError::config("Redis URL required for distributed rate limiting"))?;

        let client = redis::Client::open(redis_url.as_str())
            .map_err(|e| AuthError::internal(format!("Failed to connect to Redis: {}", e)))?;

        Ok(Self {
            client,
            config: config.clone(),
        })
    }

    /// Records a hit for `key` using the algorithm selected by `config.strategy`.
    ///
    /// # Errors
    /// Returns an internal error if connecting to Redis or running the script fails.
    pub async fn check_rate_limit(&self, key: &str) -> Result<RateLimitResult> {
        // NOTE(review): a new multiplexed connection is opened for every check. Caching one on
        // `self` would avoid a connection setup per request.
        let mut conn = self
            .client
            .get_multiplexed_tokio_connection()
            .await
            .map_err(|e| AuthError::internal(format!("Redis connection failed: {}", e)))?;

        match self.config.strategy {
            RateLimitStrategy::SlidingWindow => self.sliding_window_check(&mut conn, key).await,
            RateLimitStrategy::FixedWindow => self.fixed_window_check(&mut conn, key).await,
            RateLimitStrategy::TokenBucket => self.token_bucket_check(&mut conn, key).await,
            RateLimitStrategy::Adaptive => self.adaptive_check(&mut conn, key).await,
        }
    }

    /// Returns the window length in whole milliseconds, never less than 1 ms.
    ///
    /// All Redis arithmetic is done in milliseconds. With whole seconds, a sub-second window
    /// became 0. `EXPIRE key 0` then deleted the sliding-window set immediately, and the
    /// fixed-window index divided by zero.
    fn window_millis(&self) -> i64 {
        i64::try_from(self.config.window_duration.as_millis())
            .unwrap_or(i64::MAX)
            .max(1)
    }

    /// Sliding-window log: one sorted-set entry per accepted request, scored by its time in ms.
    async fn sliding_window_check(
        &self,
        conn: &mut redis::aio::MultiplexedConnection,
        key: &str,
    ) -> Result<RateLimitResult> {
        // Built once so the script's SHA1 is not recomputed on every request.
        static SCRIPT: std::sync::LazyLock<redis::Script> = std::sync::LazyLock::new(|| {
            redis::Script::new(
                r#"
                local key = KEYS[1]
                local window_start = ARGV[1]
                local now = ARGV[2]
                local max_requests = tonumber(ARGV[3])
                local window_ms = tonumber(ARGV[4])
                local member = ARGV[5]

                -- Drop entries that have left the window
                redis.call('ZREMRANGEBYSCORE', key, '-inf', window_start)

                local current_requests = redis.call('ZCARD', key)

                if current_requests < max_requests then
                    redis.call('ZADD', key, now, member)
                    redis.call('PEXPIRE', key, window_ms)
                    return {1, max_requests - current_requests - 1, 0}
                end

                -- Denied: the next slot frees up when the oldest entry leaves the window
                local retry_ms = window_ms
                local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
                if oldest[2] then
                    retry_ms = tonumber(oldest[2]) + window_ms - tonumber(now)
                end
                return {0, current_requests, retry_ms}
                "#,
            )
        });

        let now_ms = chrono::Utc::now().timestamp_millis();
        let window_ms = self.window_millis();
        let window_start = now_ms.saturating_sub(window_ms);
        let redis_key = format!("rate_limit:sliding:{}", key);
        // Sorted-set members must be unique per request. Using the timestamp alone made
        // simultaneous requests overwrite each other's entry, which undercounted hits.
        let member = format!("{}-{:016x}", now_ms, rand::random::<u64>());

        let reply: Vec<i64> = SCRIPT
            .key(&redis_key)
            .arg(window_start)
            .arg(now_ms)
            .arg(self.config.max_requests)
            .arg(window_ms)
            .arg(member.as_str())
            .invoke_async(conn)
            .await
            .map_err(|e| AuthError::internal(format!("Redis script error: {}", e)))?;

        let &[allowed, count, retry_ms] = reply.as_slice() else {
            return Err(AuthError::internal("Unexpected Redis sliding-window reply"));
        };

        if allowed == 1 {
            Ok(RateLimitResult::Allowed {
                remaining: u32::try_from(count).unwrap_or(0),
                reset_at: Instant::now() + self.config.window_duration,
            })
        } else {
            Ok(RateLimitResult::Denied {
                retry_after: Duration::from_millis(u64::try_from(retry_ms).unwrap_or(0)),
                total_hits: u32::try_from(count).unwrap_or(u32::MAX),
            })
        }
    }

    /// Fixed window: one counter per key per epoch-aligned window.
    async fn fixed_window_check(
        &self,
        conn: &mut redis::aio::MultiplexedConnection,
        key: &str,
    ) -> Result<RateLimitResult> {
        // INCR and PEXPIRE run atomically. A crash between two separate commands could
        // otherwise leave a counter with no expiry.
        static SCRIPT: std::sync::LazyLock<redis::Script> = std::sync::LazyLock::new(|| {
            redis::Script::new(
                r#"
                local count = redis.call('INCR', KEYS[1])
                if count == 1 then
                    redis.call('PEXPIRE', KEYS[1], ARGV[1])
                end
                return count
                "#,
            )
        });

        let now_ms = chrono::Utc::now().timestamp_millis();
        let window_ms = self.window_millis();
        // For whole-second windows this index equals the previous seconds-based one, so
        // existing keys stay valid.
        let current_window = now_ms.div_euclid(window_ms);
        let redis_key = format!("rate_limit:fixed:{}:{}", key, current_window);
        // Computed from the same `now_ms` as the window index, so the two cannot disagree.
        let until_reset = Duration::from_millis(
            u64::try_from(window_ms - now_ms.rem_euclid(window_ms)).unwrap_or(0),
        );

        let count: i64 = SCRIPT
            .key(&redis_key)
            .arg(window_ms)
            .invoke_async(conn)
            .await
            .map_err(|e| AuthError::internal(format!("Redis script error: {}", e)))?;
        let count = u32::try_from(count).unwrap_or(u32::MAX);

        if count <= self.config.max_requests {
            Ok(RateLimitResult::Allowed {
                remaining: self.config.max_requests - count,
                reset_at: Instant::now() + until_reset,
            })
        } else {
            Ok(RateLimitResult::Denied {
                retry_after: until_reset,
                total_hits: count,
            })
        }
    }

    /// Token bucket of `max_requests + burst_allowance` tokens.
    ///
    /// The bucket refills at `max_requests` per window and each request costs one token.
    async fn token_bucket_check(
        &self,
        conn: &mut redis::aio::MultiplexedConnection,
        key: &str,
    ) -> Result<RateLimitResult> {
        static SCRIPT: std::sync::LazyLock<redis::Script> = std::sync::LazyLock::new(|| {
            redis::Script::new(
                r#"
                local key = KEYS[1]
                local now = tonumber(ARGV[1])
                local refill_rate = tonumber(ARGV[2])
                local bucket_size = tonumber(ARGV[3])
                local cost = tonumber(ARGV[4])
                local expiry = tonumber(ARGV[5])

                local bucket = redis.call('HMGET', key, 'tokens', 'last_refill')
                local tokens = tonumber(bucket[1]) or bucket_size
                local last_refill = tonumber(bucket[2]) or now

                -- Refill in proportion to elapsed time; clock skew between instances must not
                -- drain the bucket
                local time_passed = math.max(0, now - last_refill) / 1000.0
                tokens = math.min(bucket_size, tokens + time_passed * refill_rate)

                local allowed = 0
                if tokens >= cost then
                    tokens = tokens - cost
                    allowed = 1
                end

                redis.call('HMSET', key, 'tokens', tokens, 'last_refill', now)
                redis.call('EXPIRE', key, expiry)
                return {allowed, math.floor(tokens)}
                "#,
            )
        });

        let redis_key = format!("rate_limit:bucket:{}", key);
        let now = chrono::Utc::now().timestamp_millis();
        // Clamp the window to at least 1 ms so a zero window cannot yield an infinite rate.
        let window_secs = self.config.window_duration.as_secs_f64().max(0.001);
        let refill_rate = f64::from(self.config.max_requests) / window_secs;
        let bucket_size = self
            .config
            .max_requests
            .saturating_add(self.config.burst_allowance.unwrap_or(0));

        // Keep idle state at least until the bucket would have refilled completely.
        // Otherwise long windows (e.g. daily quotas) get fully reset after the old fixed
        // one-hour expiry. The old hour remains the lower bound.
        let full_refill_secs = f64::from(bucket_size) / refill_rate;
        let expiry_secs: u64 = if full_refill_secs.is_finite() {
            (full_refill_secs.ceil() as u64).clamp(3600, u64::from(u32::MAX))
        } else {
            3600
        };

        let reply: Vec<i64> = SCRIPT
            .key(&redis_key)
            .arg(now)
            .arg(refill_rate)
            .arg(bucket_size)
            .arg(1_u32)
            .arg(expiry_secs)
            .invoke_async(conn)
            .await
            .map_err(|e| AuthError::internal(format!("Redis script error: {}", e)))?;

        let &[allowed, tokens] = reply.as_slice() else {
            return Err(AuthError::internal("Unexpected Redis token-bucket reply"));
        };

        if allowed == 1 {
            Ok(RateLimitResult::Allowed {
                remaining: u32::try_from(tokens).unwrap_or(0),
                reset_at: Instant::now() + self.config.window_duration,
            })
        } else {
            // Time for one token to refill. With `max_requests == 0` nothing refills
            // (1/0 = inf), so fall back to the window instead of panicking on an infinite
            // duration.
            let retry_after = Duration::try_from_secs_f64(refill_rate.recip())
                .unwrap_or(self.config.window_duration);
            Ok(RateLimitResult::Denied {
                retry_after,
                total_hits: self.config.max_requests.saturating_add(1),
            })
        }
    }

    /// Adaptive strategy; currently identical to the sliding-window check.
    async fn adaptive_check(
        &self,
        conn: &mut redis::aio::MultiplexedConnection,
        key: &str,
    ) -> Result<RateLimitResult> {
        self.sliding_window_check(conn, key).await
    }
}
529
530pub struct RateLimitMiddleware {
532 limiters: HashMap<String, Arc<DistributedRateLimiter>>,
533}
534
535impl Default for RateLimitMiddleware {
536 fn default() -> Self {
537 Self::new()
538 }
539}
540
541impl RateLimitMiddleware {
542 pub fn new() -> Self {
543 Self {
544 limiters: HashMap::new(),
545 }
546 }
547
548 pub async fn add_limiter(&mut self, name: &str, config: RateLimitConfig) -> Result<()> {
550 let limiter = Arc::new(DistributedRateLimiter::new(config).await?);
551 self.limiters.insert(name.to_string(), limiter);
552 Ok(())
553 }
554
555 pub async fn check_limit(&self, limiter_name: &str, key: &str) -> Result<RateLimitResult> {
557 let limiter = self.limiters.get(limiter_name).ok_or_else(|| {
558 AuthError::config(format!("No rate limiter found for '{}'", limiter_name))
559 })?;
560
561 limiter.check_rate_limit(key).await
562 }
563
564 pub async fn check_multiple(&self, checks: &[(String, String)]) -> Result<RateLimitResult> {
566 for (limiter_name, key) in checks {
567 let result = self.check_limit(limiter_name, key).await?;
568 if !matches!(result, RateLimitResult::Allowed { .. }) {
569 return Ok(result);
570 }
571 }
572
573 Ok(RateLimitResult::Allowed {
574 remaining: u32::MAX,
575 reset_at: Instant::now() + Duration::from_secs(60),
576 })
577 }
578}
579
580pub struct RateLimitUtils;
582
583impl RateLimitUtils {
584 pub fn ip_key(ip: &str) -> String {
586 format!("ip:{}", ip)
587 }
588
589 pub fn user_key(user_id: &str) -> String {
591 format!("user:{}", user_id)
592 }
593
594 pub fn endpoint_key(endpoint: &str, ip: &str) -> String {
596 format!("endpoint:{}:{}", endpoint, ip)
597 }
598
599 pub fn auth_key(ip: &str, username: Option<&str>) -> String {
601 match username {
602 Some(user) => format!("auth:{}:{}", ip, user),
603 None => format!("auth:{}", ip),
604 }
605 }
606
607 pub fn exponential_backoff(attempt: u32, base_duration: Duration) -> Duration {
609 let multiplier = 2_u64.pow(attempt.min(10)); Duration::from_millis(base_duration.as_millis() as u64 * multiplier)
611 }
612
613 pub fn add_jitter(duration: Duration, jitter_factor: f64) -> Duration {
615 use rand::Rng;
616 let jitter = rand::rng().random_range(0.0..jitter_factor);
617 let jitter_ms = (duration.as_millis() as f64 * jitter) as u64;
618 duration + Duration::from_millis(jitter_ms)
619 }
620}
621
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    #[tokio::test]
    async fn test_in_memory_rate_limiter() {
        let config = RateLimitConfig {
            max_requests: 3,
            window_duration: Duration::from_millis(100),
            strategy: RateLimitStrategy::TokenBucket,
            distributed: false,
            redis_url: None,
            burst_allowance: Some(1),
            adaptive: false,
            penalty_duration: None,
        };

        let limiter = DistributedRateLimiter::new(config).await.unwrap();

        // The first `max_requests` hits in a window are allowed.
        for i in 0..3 {
            let result = limiter.check_rate_limit("test_key").await.unwrap();
            assert!(
                matches!(result, RateLimitResult::Allowed { .. }),
                "Request {} should be allowed",
                i
            );
        }

        // The next hit exceeds the window allowance.
        let result = limiter.check_rate_limit("test_key").await.unwrap();
        assert!(
            matches!(result, RateLimitResult::Denied { .. }),
            "4th request should be denied"
        );

        // Once the window has elapsed, the allowance is restored.
        sleep(Duration::from_millis(150)).await;

        let result = limiter.check_rate_limit("test_key").await.unwrap();
        assert!(
            matches!(result, RateLimitResult::Allowed { .. }),
            "Request after window reset should be allowed"
        );
    }

    #[tokio::test]
    async fn test_penalty_system() {
        let config = RateLimitConfig {
            max_requests: 1,
            window_duration: Duration::from_millis(50),
            strategy: RateLimitStrategy::FixedWindow,
            distributed: false,
            redis_url: None,
            burst_allowance: None,
            adaptive: false,
            penalty_duration: Some(Duration::from_millis(200)),
        };

        let limiter = DistributedRateLimiter::new(config).await.unwrap();

        let result = limiter.check_rate_limit("penalty_test").await.unwrap();
        assert!(matches!(result, RateLimitResult::Allowed { .. }));

        // Exceeding the limit denies the request and starts a penalty.
        let result = limiter.check_rate_limit("penalty_test").await.unwrap();
        assert!(matches!(result, RateLimitResult::Denied { .. }));

        sleep(Duration::from_millis(10)).await;

        // While the penalty lasts, the key is blocked outright.
        let result = limiter.check_rate_limit("penalty_test").await.unwrap();
        assert!(matches!(result, RateLimitResult::Blocked { .. }));

        // After both the penalty and the window have expired, requests pass again.
        sleep(Duration::from_millis(250)).await;

        let result = limiter.check_rate_limit("penalty_test").await.unwrap();
        assert!(matches!(result, RateLimitResult::Allowed { .. }));
    }

    #[test]
    fn test_rate_limit_key_generation() {
        assert_eq!(RateLimitUtils::ip_key("192.168.1.1"), "ip:192.168.1.1");
        assert_eq!(RateLimitUtils::user_key("user123"), "user:user123");
        assert_eq!(
            RateLimitUtils::endpoint_key("/api/login", "192.168.1.1"),
            "endpoint:/api/login:192.168.1.1"
        );
        assert_eq!(
            RateLimitUtils::auth_key("192.168.1.1", Some("user123")),
            "auth:192.168.1.1:user123"
        );
        assert_eq!(
            RateLimitUtils::auth_key("192.168.1.1", None),
            "auth:192.168.1.1"
        );
    }

    #[test]
    fn test_exponential_backoff() {
        let base = Duration::from_millis(100);

        assert_eq!(
            RateLimitUtils::exponential_backoff(0, base),
            Duration::from_millis(100)
        );
        assert_eq!(
            RateLimitUtils::exponential_backoff(1, base),
            Duration::from_millis(200)
        );
        assert_eq!(
            RateLimitUtils::exponential_backoff(2, base),
            Duration::from_millis(400)
        );
        assert_eq!(
            RateLimitUtils::exponential_backoff(10, base),
            Duration::from_millis(102400)
        );

        // The exponent is capped at 10, so larger attempts stay at 1024x.
        assert_eq!(
            RateLimitUtils::exponential_backoff(15, base),
            Duration::from_millis(102400)
        );
    }

    #[test]
    fn test_add_jitter_stays_within_bounds() {
        let base = Duration::from_millis(1000);
        for _ in 0..100 {
            let jittered = RateLimitUtils::add_jitter(base, 0.5);
            assert!(jittered >= base);
            assert!(jittered < base + Duration::from_millis(500));
        }
    }

    #[test]
    fn test_rate_limit_configurations() {
        let strict = RateLimitConfig::strict_auth();
        assert_eq!(strict.max_requests, 5);
        assert_eq!(strict.window_duration, Duration::from_secs(300));
        assert!(strict.distributed);

        let lenient = RateLimitConfig::lenient_api();
        assert_eq!(lenient.max_requests, 1000);
        assert!(lenient.adaptive);

        let balanced = RateLimitConfig::balanced();
        assert_eq!(balanced.max_requests, 100);
        assert_eq!(balanced.strategy, RateLimitStrategy::SlidingWindow);
    }
}