// auth_framework/utils/rate_limit.rs
use crate::errors::{AuthError, Result};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, MutexGuard};
use std::time::{Duration, Instant};
7
/// Per-key rate limiter that counts requests over a sliding time window.
///
/// The request log lives behind an `Arc<Mutex<..>>`, so a `RateLimiter` can be
/// shared across threads, and every clone is a handle onto the *same* log.
#[derive(Clone, Debug)]
pub struct RateLimiter {
    /// Upper bound on requests admitted per key inside one `window`.
    max_requests: u32,
    /// Length of the sliding window over which requests are counted.
    window: Duration,
    /// Timestamps of admitted requests, keyed by the caller-supplied key.
    requests: Arc<Mutex<HashMap<String, Vec<Instant>>>>,
}
15
16impl RateLimiter {
17 pub fn new(max_requests: u32, window: Duration) -> Self {
19 Self {
20 max_requests,
21 window,
22 requests: Arc::new(Mutex::new(HashMap::new())),
23 }
24 }
25
26 pub fn check_rate_limit(&self, key: &str) -> Result<bool> {
28 let mut requests = self.requests.lock().map_err(|_| {
29 AuthError::internal("Failed to acquire rate limiter lock".to_string())
30 })?;
31
32 let now = Instant::now();
33 let entry = requests.entry(key.to_string()).or_insert_with(Vec::new);
34
35 entry.retain(|&request_time| now.duration_since(request_time) < self.window);
37
38 if entry.len() >= self.max_requests as usize {
39 return Ok(false); }
41
42 entry.push(now);
44 Ok(true)
45 }
46
47 pub fn is_allowed(&self, key: &str) -> bool {
49 self.check_rate_limit(key).unwrap_or(false)
50 }
51
52 pub fn remaining_requests(&self, key: &str) -> Result<u32> {
54 self.get_remaining_requests(key)
55 }
56
57 pub fn get_request_count(&self, key: &str) -> Result<usize> {
59 let requests = self.requests.lock().map_err(|_| {
60 AuthError::internal("Failed to acquire rate limiter lock".to_string())
61 })?;
62
63 let now = Instant::now();
64 if let Some(entry) = requests.get(key) {
65 let valid_requests = entry
66 .iter()
67 .filter(|&&request_time| now.duration_since(request_time) < self.window)
68 .count();
69 Ok(valid_requests)
70 } else {
71 Ok(0)
72 }
73 }
74
75 pub fn cleanup(&self) -> Result<usize> {
77 let mut requests = self.requests.lock().map_err(|_| {
78 AuthError::internal("Failed to acquire rate limiter lock".to_string())
79 })?;
80
81 let now = Instant::now();
82 let mut removed_count = 0;
83
84 requests.retain(|_, entry| {
85 entry.retain(|&request_time| now.duration_since(request_time) < self.window);
86 if entry.is_empty() {
87 removed_count += 1;
88 false
89 } else {
90 true
91 }
92 });
93
94 Ok(removed_count)
95 }
96
97 pub fn reset(&self, key: &str) -> Result<()> {
99 let mut requests = self.requests.lock().map_err(|_| {
100 AuthError::internal("Failed to acquire rate limiter lock".to_string())
101 })?;
102
103 requests.remove(key);
104 Ok(())
105 }
106
107 pub fn get_remaining_requests(&self, key: &str) -> Result<u32> {
109 let count = self.get_request_count(key)?;
110 Ok(self.max_requests.saturating_sub(count as u32))
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Requests are admitted up to the limit, rejected beyond it, and admitted
    /// again once the window has elapsed.
    #[test]
    fn test_rate_limiter() {
        // A short window keeps the suite fast. 250 ms still leaves ample slack
        // for the four back-to-back calls below to land inside one window.
        let window = Duration::from_millis(250);
        let limiter = RateLimiter::new(3, window);
        let key = "test_key";

        assert!(limiter.check_rate_limit(key).unwrap());
        assert!(limiter.check_rate_limit(key).unwrap());
        assert!(limiter.check_rate_limit(key).unwrap());

        // The fourth request inside the window is rejected.
        assert!(!limiter.check_rate_limit(key).unwrap());

        // Sleeping past the window lets the earlier requests expire.
        thread::sleep(window + Duration::from_millis(50));

        assert!(limiter.check_rate_limit(key).unwrap());
    }

    /// The counting helpers track admissions per key, and `reset` clears a key.
    #[test]
    fn test_remaining_and_reset() {
        let limiter = RateLimiter::new(2, Duration::from_secs(60));

        assert_eq!(limiter.remaining_requests("k").unwrap(), 2);
        assert!(limiter.is_allowed("k"));
        assert_eq!(limiter.get_request_count("k").unwrap(), 1);
        assert_eq!(limiter.get_remaining_requests("k").unwrap(), 1);
        assert!(limiter.is_allowed("k"));
        assert!(!limiter.is_allowed("k"));
        assert_eq!(limiter.remaining_requests("k").unwrap(), 0);

        // Keys are limited independently of one another.
        assert!(limiter.is_allowed("other"));

        limiter.reset("k").unwrap();
        assert_eq!(limiter.get_request_count("k").unwrap(), 0);
        assert!(limiter.is_allowed("k"));
    }

    /// Clones share one request log, so a limit applies across all handles.
    #[test]
    fn test_clones_share_state() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        let handle = limiter.clone();

        assert!(limiter.is_allowed("k"));
        assert!(!handle.is_allowed("k"));
    }

    /// `cleanup` removes keys whose requests have all expired, and is a no-op
    /// when there is nothing left to remove.
    #[test]
    fn test_cleanup() {
        let limiter = RateLimiter::new(10, Duration::from_millis(100));

        limiter.check_rate_limit("key1").unwrap();
        limiter.check_rate_limit("key2").unwrap();

        thread::sleep(Duration::from_millis(150));

        let removed = limiter.cleanup().unwrap();
        assert_eq!(removed, 2);

        assert_eq!(limiter.cleanup().unwrap(), 0);
        assert_eq!(limiter.get_request_count("key1").unwrap(), 0);
    }
}