auth_framework/utils/rate_limit.rs

1//! Rate limiting utilities for the AuthFramework.
2
3use crate::errors::{AuthError, Result};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, Instant};
7
/// Sliding-window rate limiter keyed by arbitrary strings.
///
/// For every key the limiter keeps the instants at which that key's admitted
/// requests were recorded. The log sits behind an `Arc<Mutex<..>>`, so a
/// cloned limiter is a second handle onto the very same shared state.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    /// Upper bound on recorded requests a single key may have inside one window.
    max_requests: u32,
    /// Length of the sliding interval over which requests are counted.
    window: Duration,
    /// Per-key timestamps of recorded requests, shared by all clones.
    requests: Arc<Mutex<HashMap<String, Vec<Instant>>>>,
}
15
16impl RateLimiter {
17    /// Create a new rate limiter
18    pub fn new(max_requests: u32, window: Duration) -> Self {
19        Self {
20            max_requests,
21            window,
22            requests: Arc::new(Mutex::new(HashMap::new())),
23        }
24    }
25
26    /// Check if a request is allowed for the given key
27    pub fn check_rate_limit(&self, key: &str) -> Result<bool> {
28        let mut requests = self.requests.lock().map_err(|_| {
29            AuthError::internal("Failed to acquire rate limiter lock".to_string())
30        })?;
31
32        let now = Instant::now();
33        let entry = requests.entry(key.to_string()).or_insert_with(Vec::new);
34
35        // Remove expired requests
36        entry.retain(|&request_time| now.duration_since(request_time) < self.window);
37
38        if entry.len() >= self.max_requests as usize {
39            return Ok(false); // Rate limit exceeded
40        }
41
42        // Add current request
43        entry.push(now);
44        Ok(true)
45    }
46
47    /// Alias for check_rate_limit for compatibility
48    pub fn is_allowed(&self, key: &str) -> bool {
49        self.check_rate_limit(key).unwrap_or(false)
50    }
51
52    /// Alias for get_remaining_requests for compatibility  
53    pub fn remaining_requests(&self, key: &str) -> Result<u32> {
54        self.get_remaining_requests(key)
55    }
56
57    /// Get the number of requests for a key
58    pub fn get_request_count(&self, key: &str) -> Result<usize> {
59        let requests = self.requests.lock().map_err(|_| {
60            AuthError::internal("Failed to acquire rate limiter lock".to_string())
61        })?;
62
63        let now = Instant::now();
64        if let Some(entry) = requests.get(key) {
65            let valid_requests = entry
66                .iter()
67                .filter(|&&request_time| now.duration_since(request_time) < self.window)
68                .count();
69            Ok(valid_requests)
70        } else {
71            Ok(0)
72        }
73    }
74
75    /// Clean up expired entries
76    pub fn cleanup(&self) -> Result<usize> {
77        let mut requests = self.requests.lock().map_err(|_| {
78            AuthError::internal("Failed to acquire rate limiter lock".to_string())
79        })?;
80
81        let now = Instant::now();
82        let mut removed_count = 0;
83
84        requests.retain(|_, entry| {
85            entry.retain(|&request_time| now.duration_since(request_time) < self.window);
86            if entry.is_empty() {
87                removed_count += 1;
88                false
89            } else {
90                true
91            }
92        });
93
94        Ok(removed_count)
95    }
96
97    /// Reset rate limit for a specific key
98    pub fn reset(&self, key: &str) -> Result<()> {
99        let mut requests = self.requests.lock().map_err(|_| {
100            AuthError::internal("Failed to acquire rate limiter lock".to_string())
101        })?;
102
103        requests.remove(key);
104        Ok(())
105    }
106
107    /// Get remaining requests for a key
108    pub fn get_remaining_requests(&self, key: &str) -> Result<u32> {
109        let count = self.get_request_count(key)?;
110        Ok(self.max_requests.saturating_sub(count as u32))
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Window for the time-based tests: short enough to keep the suite fast,
    /// long enough that a burst of back-to-back calls stays inside it.
    const WINDOW: Duration = Duration::from_millis(200);
    /// Sleep that comfortably outlasts `WINDOW`, so expiry is not timing-sensitive.
    const PAST_WINDOW: Duration = Duration::from_millis(300);

    #[test]
    fn test_rate_limiter() {
        let limiter = RateLimiter::new(3, WINDOW);
        let key = "test_key";

        // First 3 requests should be allowed
        for _ in 0..3 {
            assert!(limiter.check_rate_limit(key).unwrap());
        }

        // 4th request should be denied
        assert!(!limiter.check_rate_limit(key).unwrap());

        // Wait for the window to expire
        thread::sleep(PAST_WINDOW);

        // Should be allowed again
        assert!(limiter.check_rate_limit(key).unwrap());
    }

    #[test]
    fn test_keys_are_independent() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        assert!(limiter.is_allowed("a"));
        assert!(!limiter.is_allowed("a"));
        assert!(limiter.is_allowed("b"));
    }

    #[test]
    fn test_counts_and_remaining() {
        let limiter = RateLimiter::new(3, Duration::from_secs(60));
        assert_eq!(limiter.get_request_count("k").unwrap(), 0);
        assert_eq!(limiter.get_remaining_requests("k").unwrap(), 3);

        limiter.check_rate_limit("k").unwrap();
        limiter.check_rate_limit("k").unwrap();

        assert_eq!(limiter.get_request_count("k").unwrap(), 2);
        assert_eq!(limiter.get_remaining_requests("k").unwrap(), 1);
        assert_eq!(limiter.remaining_requests("k").unwrap(), 1);
    }

    #[test]
    fn test_denied_requests_are_not_recorded() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        assert!(limiter.is_allowed("k"));
        assert!(!limiter.is_allowed("k"));
        assert!(!limiter.is_allowed("k"));
        assert_eq!(limiter.get_request_count("k").unwrap(), 1);
        assert_eq!(limiter.get_remaining_requests("k").unwrap(), 0);
    }

    #[test]
    fn test_reset() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        assert!(limiter.is_allowed("k"));
        assert!(!limiter.is_allowed("k"));

        limiter.reset("k").unwrap();

        assert!(limiter.is_allowed("k"));
    }

    #[test]
    fn test_zero_limit_denies_everything() {
        let limiter = RateLimiter::new(0, Duration::from_secs(60));
        assert!(!limiter.is_allowed("k"));
        assert_eq!(limiter.get_request_count("k").unwrap(), 0);
        assert_eq!(limiter.get_remaining_requests("k").unwrap(), 0);
    }

    #[test]
    fn test_clones_share_state() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        let clone = limiter.clone();
        assert!(limiter.is_allowed("k"));
        assert!(!clone.is_allowed("k"));
    }

    #[test]
    fn test_cleanup() {
        let limiter = RateLimiter::new(10, WINDOW);

        limiter.check_rate_limit("key1").unwrap();
        limiter.check_rate_limit("key2").unwrap();

        thread::sleep(PAST_WINDOW);

        let removed = limiter.cleanup().unwrap();
        assert_eq!(removed, 2);
        assert_eq!(limiter.get_request_count("key1").unwrap(), 0);
    }

    #[test]
    fn test_cleanup_keeps_live_keys() {
        let limiter = RateLimiter::new(10, WINDOW);

        limiter.check_rate_limit("stale").unwrap();
        thread::sleep(PAST_WINDOW);
        limiter.check_rate_limit("live").unwrap();

        assert_eq!(limiter.cleanup().unwrap(), 1);
        assert_eq!(limiter.get_request_count("live").unwrap(), 1);
    }
}