auth_framework/utils/
rate_limit.rs1use crate::errors::{AuthError, Result};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, Instant};
7
/// Sliding-window request limiter keyed by arbitrary strings.
///
/// The request log lives behind an [`Arc`]`<`[`Mutex`]`>`, so cloning a
/// `RateLimiter` is cheap and every clone observes and updates the *same* log.
#[derive(Clone, Debug)]
pub struct RateLimiter {
    /// Maximum number of accepted requests per key inside one window.
    max_requests: u32,
    /// Length of the sliding window over which requests are counted.
    window: Duration,
    /// Per-key timestamps of accepted requests, shared between clones.
    requests: Arc<Mutex<HashMap<String, Vec<Instant>>>>,
}
23
24impl RateLimiter {
25 pub fn new(max_requests: u32, window: Duration) -> Self {
34 Self {
35 max_requests,
36 window,
37 requests: Arc::new(Mutex::new(HashMap::new())),
38 }
39 }
40
41 pub fn check_rate_limit(&self, key: &str) -> Result<bool> {
53 let mut requests = self
54 .requests
55 .lock()
56 .map_err(|_| AuthError::internal("Failed to acquire rate limiter lock".to_string()))?;
57
58 let now = Instant::now();
59 let entry = requests.entry(key.to_string()).or_insert_with(Vec::new);
60
61 entry.retain(|&request_time| now.duration_since(request_time) < self.window);
63
64 if entry.len() >= self.max_requests as usize {
65 return Ok(false); }
67
68 entry.push(now);
70 Ok(true)
71 }
72
73 pub fn is_allowed(&self, key: &str) -> bool {
84 self.check_rate_limit(key).unwrap_or(false)
85 }
86
87 pub fn remaining_requests(&self, key: &str) -> Result<u32> {
97 self.get_remaining_requests(key)
98 }
99
100 pub fn get_request_count(&self, key: &str) -> Result<usize> {
111 let requests = self
112 .requests
113 .lock()
114 .map_err(|_| AuthError::internal("Failed to acquire rate limiter lock".to_string()))?;
115
116 let now = Instant::now();
117 if let Some(entry) = requests.get(key) {
118 let valid_requests = entry
119 .iter()
120 .filter(|&&request_time| now.duration_since(request_time) < self.window)
121 .count();
122 Ok(valid_requests)
123 } else {
124 Ok(0)
125 }
126 }
127
128 pub fn cleanup(&self) -> Result<usize> {
139 let mut requests = self
140 .requests
141 .lock()
142 .map_err(|_| AuthError::internal("Failed to acquire rate limiter lock".to_string()))?;
143
144 let now = Instant::now();
145 let mut removed_count = 0;
146
147 requests.retain(|_, entry| {
148 entry.retain(|&request_time| now.duration_since(request_time) < self.window);
149 if entry.is_empty() {
150 removed_count += 1;
151 false
152 } else {
153 true
154 }
155 });
156
157 Ok(removed_count)
158 }
159
160 pub fn reset(&self, key: &str) -> Result<()> {
172 let mut requests = self
173 .requests
174 .lock()
175 .map_err(|_| AuthError::internal("Failed to acquire rate limiter lock".to_string()))?;
176
177 requests.remove(key);
178 Ok(())
179 }
180
181 pub fn get_remaining_requests(&self, key: &str) -> Result<u32> {
192 let count = self.get_request_count(key)?;
193 Ok(self.max_requests.saturating_sub(count as u32))
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_rate_limiter() {
        // A short window keeps the suite fast; the sleep comfortably exceeds it.
        let limiter = RateLimiter::new(3, Duration::from_millis(100));
        let key = "test_key";

        assert!(limiter.check_rate_limit(key).unwrap());
        assert!(limiter.check_rate_limit(key).unwrap());
        assert!(limiter.check_rate_limit(key).unwrap());

        // The fourth request inside the window is rejected.
        assert!(!limiter.check_rate_limit(key).unwrap());

        thread::sleep(Duration::from_millis(150));

        // Once the window has elapsed, requests are accepted again.
        assert!(limiter.check_rate_limit(key).unwrap());
    }

    #[test]
    fn test_cleanup() {
        let limiter = RateLimiter::new(10, Duration::from_millis(100));

        limiter.check_rate_limit("key1").unwrap();
        limiter.check_rate_limit("key2").unwrap();

        thread::sleep(Duration::from_millis(150));

        let removed = limiter.cleanup().unwrap();
        assert_eq!(removed, 2);
    }

    #[test]
    fn test_zero_max_requests_denies_all() {
        let limiter = RateLimiter::new(0, Duration::from_secs(60));
        assert!(!limiter.check_rate_limit("key").unwrap());
        assert!(!limiter.is_allowed("key"));
    }

    #[test]
    fn test_single_request_limit() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        assert!(limiter.check_rate_limit("key").unwrap());
        assert!(!limiter.check_rate_limit("key").unwrap());
    }

    #[test]
    fn test_independent_keys() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        assert!(limiter.check_rate_limit("key1").unwrap());
        assert!(limiter.check_rate_limit("key2").unwrap());
        assert!(!limiter.check_rate_limit("key1").unwrap());
        assert!(!limiter.check_rate_limit("key2").unwrap());
    }

    #[test]
    fn test_empty_key() {
        let limiter = RateLimiter::new(2, Duration::from_secs(60));
        assert!(limiter.check_rate_limit("").unwrap());
        assert!(limiter.check_rate_limit("").unwrap());
        assert!(!limiter.check_rate_limit("").unwrap());
    }

    #[test]
    fn test_remaining_requests_decrements() {
        let limiter = RateLimiter::new(3, Duration::from_secs(60));
        assert_eq!(limiter.get_remaining_requests("k").unwrap(), 3);
        limiter.check_rate_limit("k").unwrap();
        assert_eq!(limiter.get_remaining_requests("k").unwrap(), 2);
        limiter.check_rate_limit("k").unwrap();
        assert_eq!(limiter.get_remaining_requests("k").unwrap(), 1);
        limiter.check_rate_limit("k").unwrap();
        assert_eq!(limiter.get_remaining_requests("k").unwrap(), 0);
    }

    #[test]
    fn test_remaining_requests_for_unknown_key() {
        let limiter = RateLimiter::new(5, Duration::from_secs(60));
        assert_eq!(limiter.get_remaining_requests("unknown").unwrap(), 5);
    }

    #[test]
    fn test_get_request_count_unknown_key() {
        let limiter = RateLimiter::new(5, Duration::from_secs(60));
        assert_eq!(limiter.get_request_count("unknown").unwrap(), 0);
    }

    #[test]
    fn test_reset_clears_count() {
        let limiter = RateLimiter::new(2, Duration::from_secs(60));
        limiter.check_rate_limit("k").unwrap();
        limiter.check_rate_limit("k").unwrap();
        assert!(!limiter.is_allowed("k"));

        limiter.reset("k").unwrap();
        // `is_allowed` consumes one slot, so the count afterwards is 1.
        assert!(limiter.is_allowed("k"));
        assert_eq!(limiter.get_request_count("k").unwrap(), 1);
    }

    #[test]
    fn test_reset_nonexistent_key_is_ok() {
        let limiter = RateLimiter::new(5, Duration::from_secs(60));
        assert!(limiter.reset("nonexistent").is_ok());
    }

    #[test]
    fn test_cleanup_empty_limiter() {
        let limiter = RateLimiter::new(5, Duration::from_secs(60));
        assert_eq!(limiter.cleanup().unwrap(), 0);
    }

    #[test]
    fn test_clone_shares_state() {
        let limiter = RateLimiter::new(2, Duration::from_secs(60));
        let limiter2 = limiter.clone();
        limiter.check_rate_limit("k").unwrap();
        assert_eq!(limiter2.get_request_count("k").unwrap(), 1);
    }

    #[test]
    fn test_concurrent_access() {
        let limiter = RateLimiter::new(100, Duration::from_secs(60));
        let mut handles = vec![];

        for i in 0..10 {
            let l = limiter.clone();
            handles.push(thread::spawn(move || {
                for _ in 0..10 {
                    let _ = l.check_rate_limit(&format!("thread-{}", i));
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        for i in 0..10 {
            assert_eq!(
                limiter
                    .get_request_count(&format!("thread-{}", i))
                    .unwrap(),
                10
            );
        }
    }

    #[test]
    fn test_remaining_alias_matches() {
        let limiter = RateLimiter::new(5, Duration::from_secs(60));
        limiter.check_rate_limit("k").unwrap();
        assert_eq!(
            limiter.remaining_requests("k").unwrap(),
            limiter.get_remaining_requests("k").unwrap()
        );
    }

    #[test]
    fn test_is_allowed_alias_matches() {
        // Drive two identical limiters, one through each entry point, and require
        // them to agree call-for-call, including past the limit.
        let via_alias = RateLimiter::new(2, Duration::from_secs(60));
        let via_check = RateLimiter::new(2, Duration::from_secs(60));
        for _ in 0..3 {
            assert_eq!(
                via_alias.is_allowed("k"),
                via_check.check_rate_limit("k").unwrap()
            );
        }
        assert!(!via_alias.is_allowed("k"));
    }

    #[test]
    fn test_many_keys_cleanup() {
        let limiter = RateLimiter::new(1, Duration::from_millis(50));
        for i in 0..100 {
            limiter.check_rate_limit(&format!("key-{}", i)).unwrap();
        }
        thread::sleep(Duration::from_millis(100));
        let removed = limiter.cleanup().unwrap();
        assert_eq!(removed, 100);
    }
}