Skip to main content

auth_framework/utils/
rate_limit.rs

1//! Rate limiting utilities for the AuthFramework.
2
3use crate::errors::{AuthError, Result};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, Instant};
7
/// Per-key request limiter backed by an in-memory log of request timestamps.
///
/// Each key may make at most `max_requests` requests within the trailing
/// `window`. Cloning a `RateLimiter` is cheap and yields a handle to the
/// *same* underlying state, so clones handed to other threads all enforce one
/// shared budget per key.
///
/// # Example
/// ```rust
/// use auth_framework::utils::rate_limit::RateLimiter;
/// use std::time::Duration;
/// let limiter = RateLimiter::new(5, Duration::from_secs(60));
/// assert!(limiter.is_allowed("client-1"));
/// ```
#[derive(Debug, Clone)]
pub struct RateLimiter {
    // Maximum number of requests a single key may make inside one window.
    max_requests: u32,
    // Length of the trailing time window over which requests are counted.
    window: Duration,
    // Timestamps of accepted requests, grouped by key. The `Arc` makes every
    // clone of the limiter share this one map; the `Mutex` serialises access.
    requests: Arc<Mutex<HashMap<String, Vec<Instant>>>>,
}
23
24impl RateLimiter {
25    /// Create a new rate limiter.
26    ///
27    /// # Example
28    /// ```rust
29    /// use auth_framework::utils::rate_limit::RateLimiter;
30    /// use std::time::Duration;
31    /// let limiter = RateLimiter::new(100, Duration::from_secs(60));
32    /// ```
33    pub fn new(max_requests: u32, window: Duration) -> Self {
34        Self {
35            max_requests,
36            window,
37            requests: Arc::new(Mutex::new(HashMap::new())),
38        }
39    }
40
41    /// Check if a request is allowed for the given key.
42    ///
43    /// # Example
44    /// ```rust
45    /// use auth_framework::utils::rate_limit::RateLimiter;
46    /// use std::time::Duration;
47    /// let limiter = RateLimiter::new(2, Duration::from_secs(60));
48    /// assert_eq!(limiter.check_rate_limit("k").unwrap(), true);
49    /// assert_eq!(limiter.check_rate_limit("k").unwrap(), true);
50    /// assert_eq!(limiter.check_rate_limit("k").unwrap(), false);
51    /// ```
52    pub fn check_rate_limit(&self, key: &str) -> Result<bool> {
53        let mut requests = self
54            .requests
55            .lock()
56            .map_err(|_| AuthError::internal("Failed to acquire rate limiter lock".to_string()))?;
57
58        let now = Instant::now();
59        let entry = requests.entry(key.to_string()).or_insert_with(Vec::new);
60
61        // Remove expired requests
62        entry.retain(|&request_time| now.duration_since(request_time) < self.window);
63
64        if entry.len() >= self.max_requests as usize {
65            return Ok(false); // Rate limit exceeded
66        }
67
68        // Add current request
69        entry.push(now);
70        Ok(true)
71    }
72
73    /// Alias for check_rate_limit for compatibility.
74    ///
75    /// # Example
76    /// ```rust
77    /// use auth_framework::utils::rate_limit::RateLimiter;
78    /// use std::time::Duration;
79    /// let limiter = RateLimiter::new(1, Duration::from_secs(60));
80    /// assert!(limiter.is_allowed("k"));
81    /// assert!(!limiter.is_allowed("k"));
82    /// ```
83    pub fn is_allowed(&self, key: &str) -> bool {
84        self.check_rate_limit(key).unwrap_or(false)
85    }
86
87    /// Alias for get_remaining_requests for compatibility.
88    ///
89    /// # Example
90    /// ```rust
91    /// use auth_framework::utils::rate_limit::RateLimiter;
92    /// use std::time::Duration;
93    /// let limiter = RateLimiter::new(5, Duration::from_secs(60));
94    /// assert_eq!(limiter.remaining_requests("k").unwrap(), 5);
95    /// ```
96    pub fn remaining_requests(&self, key: &str) -> Result<u32> {
97        self.get_remaining_requests(key)
98    }
99
100    /// Get the number of requests for a key.
101    ///
102    /// # Example
103    /// ```rust
104    /// use auth_framework::utils::rate_limit::RateLimiter;
105    /// use std::time::Duration;
106    /// let limiter = RateLimiter::new(10, Duration::from_secs(60));
107    /// limiter.is_allowed("k");
108    /// assert_eq!(limiter.get_request_count("k").unwrap(), 1);
109    /// ```
110    pub fn get_request_count(&self, key: &str) -> Result<usize> {
111        let requests = self
112            .requests
113            .lock()
114            .map_err(|_| AuthError::internal("Failed to acquire rate limiter lock".to_string()))?;
115
116        let now = Instant::now();
117        if let Some(entry) = requests.get(key) {
118            let valid_requests = entry
119                .iter()
120                .filter(|&&request_time| now.duration_since(request_time) < self.window)
121                .count();
122            Ok(valid_requests)
123        } else {
124            Ok(0)
125        }
126    }
127
128    /// Clean up expired entries.
129    ///
130    /// # Example
131    /// ```rust
132    /// use auth_framework::utils::rate_limit::RateLimiter;
133    /// use std::time::Duration;
134    /// let limiter = RateLimiter::new(10, Duration::from_secs(60));
135    /// let removed = limiter.cleanup().unwrap();
136    /// assert_eq!(removed, 0);
137    /// ```
138    pub fn cleanup(&self) -> Result<usize> {
139        let mut requests = self
140            .requests
141            .lock()
142            .map_err(|_| AuthError::internal("Failed to acquire rate limiter lock".to_string()))?;
143
144        let now = Instant::now();
145        let mut removed_count = 0;
146
147        requests.retain(|_, entry| {
148            entry.retain(|&request_time| now.duration_since(request_time) < self.window);
149            if entry.is_empty() {
150                removed_count += 1;
151                false
152            } else {
153                true
154            }
155        });
156
157        Ok(removed_count)
158    }
159
160    /// Reset rate limit for a specific key.
161    ///
162    /// # Example
163    /// ```rust
164    /// use auth_framework::utils::rate_limit::RateLimiter;
165    /// use std::time::Duration;
166    /// let limiter = RateLimiter::new(1, Duration::from_secs(60));
167    /// limiter.is_allowed("k");
168    /// limiter.reset("k").unwrap();
169    /// assert!(limiter.is_allowed("k"));
170    /// ```
171    pub fn reset(&self, key: &str) -> Result<()> {
172        let mut requests = self
173            .requests
174            .lock()
175            .map_err(|_| AuthError::internal("Failed to acquire rate limiter lock".to_string()))?;
176
177        requests.remove(key);
178        Ok(())
179    }
180
181    /// Get remaining requests for a key.
182    ///
183    /// # Example
184    /// ```rust
185    /// use auth_framework::utils::rate_limit::RateLimiter;
186    /// use std::time::Duration;
187    /// let limiter = RateLimiter::new(5, Duration::from_secs(60));
188    /// limiter.is_allowed("k");
189    /// assert_eq!(limiter.get_remaining_requests("k").unwrap(), 4);
190    /// ```
191    pub fn get_remaining_requests(&self, key: &str) -> Result<u32> {
192        let count = self.get_request_count(key)?;
193        Ok(self.max_requests.saturating_sub(count as u32))
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// The budget is granted in full, then refused, then restored once the
    /// window has elapsed.
    #[test]
    fn test_rate_limiter() {
        let key = "test_key";
        let limiter = RateLimiter::new(3, Duration::from_secs(1));

        // The whole budget of three requests is admitted...
        for _ in 0..3 {
            assert!(limiter.check_rate_limit(key).unwrap());
        }
        // ...and the fourth is turned away.
        assert!(!limiter.check_rate_limit(key).unwrap());

        // Sleep a little longer than the window so every timestamp expires.
        thread::sleep(Duration::from_millis(1100));
        assert!(limiter.check_rate_limit(key).unwrap());
    }

    /// `cleanup` drops keys whose timestamps have all expired.
    #[test]
    fn test_cleanup() {
        let limiter = RateLimiter::new(10, Duration::from_millis(100));

        for key in &["key1", "key2"] {
            limiter.check_rate_limit(key).unwrap();
        }

        thread::sleep(Duration::from_millis(150));

        assert_eq!(limiter.cleanup().unwrap(), 2);
    }

    /// A limit of zero refuses every request through both entry points.
    #[test]
    fn test_zero_max_requests_denies_all() {
        let limiter = RateLimiter::new(0, Duration::from_secs(60));
        let first = limiter.check_rate_limit("key").unwrap();
        assert!(!first);
        let second = limiter.is_allowed("key");
        assert!(!second);
    }

    /// A limit of one admits exactly one request.
    #[test]
    fn test_single_request_limit() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        let outcomes = [
            limiter.check_rate_limit("key").unwrap(),
            limiter.check_rate_limit("key").unwrap(),
        ];
        assert_eq!(outcomes, [true, false]);
    }

    /// Budgets are tracked separately for each key.
    #[test]
    fn test_independent_keys() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        let keys = ["key1", "key2"];

        // Each key gets its own single request...
        for &key in &keys {
            assert!(limiter.check_rate_limit(key).unwrap());
        }
        // ...after which both are exhausted.
        for &key in &keys {
            assert!(!limiter.check_rate_limit(key).unwrap());
        }
    }

    /// The empty string is an ordinary key.
    #[test]
    fn test_empty_key() {
        let limiter = RateLimiter::new(2, Duration::from_secs(60));
        for _ in 0..2 {
            assert!(limiter.check_rate_limit("").unwrap());
        }
        assert!(!limiter.check_rate_limit("").unwrap());
    }

    /// Every admitted request lowers the remaining budget by one.
    #[test]
    fn test_remaining_requests_decrements() {
        let limiter = RateLimiter::new(3, Duration::from_secs(60));
        assert_eq!(limiter.get_remaining_requests("k").unwrap(), 3);

        for expected_left in (0..3u32).rev() {
            limiter.check_rate_limit("k").unwrap();
            assert_eq!(limiter.get_remaining_requests("k").unwrap(), expected_left);
        }
    }

    /// A key that was never seen still has its full budget.
    #[test]
    fn test_remaining_requests_for_unknown_key() {
        let limiter = RateLimiter::new(5, Duration::from_secs(60));
        let remaining = limiter.get_remaining_requests("unknown").unwrap();
        assert_eq!(remaining, 5);
    }

    /// A key that was never seen has a request count of zero.
    #[test]
    fn test_get_request_count_unknown_key() {
        let limiter = RateLimiter::new(5, Duration::from_secs(60));
        let count = limiter.get_request_count("unknown").unwrap();
        assert_eq!(count, 0);
    }

    /// `reset` restores the budget of an exhausted key.
    #[test]
    fn test_reset_clears_count() {
        let limiter = RateLimiter::new(2, Duration::from_secs(60));
        for _ in 0..2 {
            limiter.check_rate_limit("k").unwrap();
        }
        assert!(!limiter.is_allowed("k"));

        limiter.reset("k").unwrap();

        // The next request is admitted and is the only one on record.
        assert!(limiter.is_allowed("k"));
        assert_eq!(limiter.get_request_count("k").unwrap(), 1);
    }

    /// Resetting a key that does not exist succeeds.
    #[test]
    fn test_reset_nonexistent_key_is_ok() {
        let limiter = RateLimiter::new(5, Duration::from_secs(60));
        let outcome = limiter.reset("nonexistent");
        assert!(outcome.is_ok());
    }

    /// Cleaning up a limiter that has seen no requests removes nothing.
    #[test]
    fn test_cleanup_empty_limiter() {
        let limiter = RateLimiter::new(5, Duration::from_secs(60));
        let removed = limiter.cleanup().unwrap();
        assert_eq!(removed, 0);
    }

    /// A clone observes requests recorded through the original.
    #[test]
    fn test_clone_shares_state() {
        let original = RateLimiter::new(2, Duration::from_secs(60));
        let copy = original.clone();

        original.check_rate_limit("k").unwrap();

        assert_eq!(copy.get_request_count("k").unwrap(), 1);
    }

    /// Ten threads, each hammering its own key, all get every request recorded.
    #[test]
    fn test_concurrent_access() {
        fn thread_key(i: i32) -> String {
            format!("thread-{}", i)
        }

        let limiter = RateLimiter::new(100, Duration::from_secs(60));

        let workers: Vec<_> = (0..10)
            .map(|i| {
                let shared = limiter.clone();
                thread::spawn(move || {
                    for _ in 0..10 {
                        let _ = shared.check_rate_limit(&thread_key(i));
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        // Each thread made 10 requests under its own key.
        for i in 0..10 {
            let recorded = limiter.get_request_count(&thread_key(i)).unwrap();
            assert_eq!(recorded, 10);
        }
    }

    /// `remaining_requests` and `get_remaining_requests` agree.
    #[test]
    fn test_remaining_alias_matches() {
        let limiter = RateLimiter::new(5, Duration::from_secs(60));
        limiter.check_rate_limit("k").unwrap();

        let via_alias = limiter.remaining_requests("k").unwrap();
        let via_getter = limiter.get_remaining_requests("k").unwrap();
        assert_eq!(via_alias, via_getter);
    }

    /// `is_allowed` and `check_rate_limit` both admit a fresh key.
    #[test]
    fn test_is_allowed_alias_matches() {
        let limiter = RateLimiter::new(2, Duration::from_secs(60));
        let via_wrapper = limiter.is_allowed("a");
        let via_checked = limiter.check_rate_limit("b").unwrap();
        assert!(via_wrapper);
        assert!(via_checked);
    }

    /// `cleanup` copes with many expired keys at once.
    #[test]
    fn test_many_keys_cleanup() {
        let limiter = RateLimiter::new(1, Duration::from_millis(50));

        (0..100).for_each(|i| {
            limiter.check_rate_limit(&format!("key-{}", i)).unwrap();
        });

        thread::sleep(Duration::from_millis(100));

        assert_eq!(limiter.cleanup().unwrap(), 100);
    }
}