1use std::{
5 collections::HashMap,
6 sync::{Arc, Mutex},
7 time::{SystemTime, UNIX_EPOCH},
8};
9
10use crate::auth::error::{AuthError, Result};
11
/// Request budget for a fixed-window rate limiter: at most `max_requests`
/// admitted requests per key within each window of `window_secs` seconds.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum number of requests admitted per key inside one window.
    pub max_requests: u32,
    /// Window length in seconds.
    pub window_secs: u64,
}

impl RateLimitConfig {
    /// Common constructor shared by the named presets below.
    fn preset(max_requests: u32, window_secs: u64) -> Self {
        Self {
            max_requests,
            window_secs,
        }
    }

    /// Preset: 100 requests per 60-second window.
    pub fn per_ip_standard() -> Self {
        Self::preset(100, 60)
    }

    /// Preset: 50 requests per 60-second window.
    pub fn per_ip_strict() -> Self {
        Self::preset(50, 60)
    }

    /// Preset: 10 requests per 60-second window.
    pub fn per_user_standard() -> Self {
        Self::preset(10, 60)
    }

    /// Preset: 5 requests per 3600-second (one hour) window.
    pub fn failed_login_attempts() -> Self {
        Self::preset(5, 3600)
    }
}
58
/// Bookkeeping for one key's current fixed window.
#[derive(Clone, Debug)]
struct RequestRecord {
    /// Requests admitted so far in the current window.
    count: u32,
    /// Unix timestamp (whole seconds) at which the current window began.
    window_start: u64,
}
67
68pub struct KeyedRateLimiter {
71 records: Arc<Mutex<HashMap<String, RequestRecord>>>,
72 config: RateLimitConfig,
73}
74
75impl KeyedRateLimiter {
76 pub fn new(config: RateLimitConfig) -> Self {
78 Self {
79 records: Arc::new(Mutex::new(HashMap::new())),
80 config,
81 }
82 }
83
84 fn current_timestamp() -> u64 {
86 SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs()
87 }
88
89 pub fn check(&self, key: &str) -> Result<()> {
93 let mut records = self.records.lock().unwrap();
94 let now = Self::current_timestamp();
95
96 let record = records.entry(key.to_string()).or_insert_with(|| RequestRecord {
97 count: 0,
98 window_start: now,
99 });
100
101 if now >= record.window_start + self.config.window_secs {
103 record.count = 1;
105 record.window_start = now;
106 Ok(())
107 } else if record.count < self.config.max_requests {
108 record.count += 1;
110 Ok(())
111 } else {
112 Err(AuthError::RateLimited {
114 retry_after_secs: self.config.window_secs,
115 })
116 }
117 }
118
119 pub fn active_limiters(&self) -> usize {
121 let records = self.records.lock().unwrap();
122 records.len()
123 }
124
125 pub fn clear(&self) {
127 let mut records = self.records.lock().unwrap();
128 records.clear();
129 }
130
131 pub fn clone_config(&self) -> RateLimitConfig {
133 self.config.clone()
134 }
135}
136
137pub struct RateLimiters {
139 pub auth_start: KeyedRateLimiter,
141 pub auth_callback: KeyedRateLimiter,
143 pub auth_refresh: KeyedRateLimiter,
145 pub auth_logout: KeyedRateLimiter,
147 pub failed_logins: KeyedRateLimiter,
149}
150
151impl RateLimiters {
152 pub fn new() -> Self {
154 Self {
155 auth_start: KeyedRateLimiter::new(RateLimitConfig::per_ip_standard()),
156 auth_callback: KeyedRateLimiter::new(RateLimitConfig::per_ip_strict()),
157 auth_refresh: KeyedRateLimiter::new(RateLimitConfig::per_user_standard()),
158 auth_logout: KeyedRateLimiter::new(RateLimitConfig::per_user_standard()),
159 failed_logins: KeyedRateLimiter::new(RateLimitConfig::failed_login_attempts()),
160 }
161 }
162
163 pub fn with_configs(
165 start_cfg: RateLimitConfig,
166 callback_cfg: RateLimitConfig,
167 refresh_cfg: RateLimitConfig,
168 logout_cfg: RateLimitConfig,
169 failed_cfg: RateLimitConfig,
170 ) -> Self {
171 Self {
172 auth_start: KeyedRateLimiter::new(start_cfg),
173 auth_callback: KeyedRateLimiter::new(callback_cfg),
174 auth_refresh: KeyedRateLimiter::new(refresh_cfg),
175 auth_logout: KeyedRateLimiter::new(logout_cfg),
176 failed_logins: KeyedRateLimiter::new(failed_cfg),
177 }
178 }
179}
180
181impl Default for RateLimiters {
182 fn default() -> Self {
183 Self::new()
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a limiter with an explicit budget to keep the individual tests short.
    fn make_limiter(max_requests: u32, window_secs: u64) -> KeyedRateLimiter {
        KeyedRateLimiter::new(RateLimitConfig {
            max_requests,
            window_secs,
        })
    }

    #[test]
    fn test_rate_limiter_allows_within_limit() {
        let limiter = make_limiter(3, 60);

        for i in 0..3 {
            assert!(limiter.check("key").is_ok(), "Request {} should be allowed", i);
        }
    }

    #[test]
    fn test_rate_limiter_rejects_over_limit() {
        let limiter = make_limiter(2, 60);

        // Setup requests must succeed, otherwise the final assertion proves nothing.
        assert!(limiter.check("key").is_ok());
        assert!(limiter.check("key").is_ok());

        assert!(limiter.check("key").is_err(), "Request over limit should fail");
    }

    #[test]
    fn test_rate_limiter_per_key() {
        let limiter = make_limiter(2, 60);

        assert!(limiter.check("key1").is_ok());
        assert!(limiter.check("key1").is_ok());
        assert!(limiter.check("key1").is_err(), "key1 should now be exhausted");

        assert!(
            limiter.check("key2").is_ok(),
            "Different key should have independent limit"
        );
    }

    #[test]
    fn test_rate_limiter_error_contains_retry_after() {
        let limiter = make_limiter(1, 60);

        assert!(limiter.check("key").is_ok());

        match limiter.check("key") {
            Err(AuthError::RateLimited { retry_after_secs }) => {
                assert_eq!(retry_after_secs, 60);
            },
            _ => panic!("Expected RateLimited error"),
        }
    }

    #[test]
    fn test_rate_limiter_active_limiters_count() {
        let limiter = make_limiter(100, 60);

        assert_eq!(limiter.active_limiters(), 0);

        assert!(limiter.check("key1").is_ok());
        assert_eq!(limiter.active_limiters(), 1);

        assert!(limiter.check("key1").is_ok());
        assert_eq!(limiter.active_limiters(), 1, "Repeat key must not add a record");

        assert!(limiter.check("key2").is_ok());
        assert_eq!(limiter.active_limiters(), 2);
    }

    #[test]
    fn test_rate_limiters_default() {
        let limiters = RateLimiters::new();

        assert!(limiters.auth_start.check("ip_1").is_ok());
        assert!(limiters.auth_refresh.check("user_1").is_ok());

        // `Default` must wire up the same presets as `new`.
        let defaults = RateLimiters::default();
        assert_eq!(defaults.auth_start.clone_config().max_requests, 100);
        assert_eq!(defaults.auth_callback.clone_config().max_requests, 50);
        assert_eq!(defaults.auth_refresh.clone_config().max_requests, 10);
        assert_eq!(defaults.auth_logout.clone_config().max_requests, 10);
        assert_eq!(defaults.failed_logins.clone_config().window_secs, 3600);
    }

    #[test]
    fn test_rate_limit_config_presets() {
        let standard_ip = RateLimitConfig::per_ip_standard();
        assert_eq!(standard_ip.max_requests, 100);
        assert_eq!(standard_ip.window_secs, 60);

        let strict_ip = RateLimitConfig::per_ip_strict();
        assert_eq!(strict_ip.max_requests, 50);
        assert_eq!(strict_ip.window_secs, 60);

        let user_limit = RateLimitConfig::per_user_standard();
        assert_eq!(user_limit.max_requests, 10);
        assert_eq!(user_limit.window_secs, 60);

        let failed = RateLimitConfig::failed_login_attempts();
        assert_eq!(failed.max_requests, 5);
        assert_eq!(failed.window_secs, 3600);
    }

    #[test]
    fn test_ip_based_rate_limiting() {
        let limiter = KeyedRateLimiter::new(RateLimitConfig::per_ip_standard());

        let ip = "203.0.113.1";

        for i in 0..100 {
            assert!(limiter.check(ip).is_ok(), "Request {} should be allowed", i);
        }

        assert!(limiter.check(ip).is_err());
    }

    #[test]
    fn test_failed_login_tracking() {
        let limiter = KeyedRateLimiter::new(RateLimitConfig::failed_login_attempts());

        let user = "alice@example.com";

        for i in 0..5 {
            assert!(limiter.check(user).is_ok(), "Attempt {} should be allowed", i);
        }

        assert!(limiter.check(user).is_err());
    }

    #[test]
    fn test_multiple_users_independent() {
        let limiter = KeyedRateLimiter::new(RateLimitConfig::failed_login_attempts());

        for _ in 0..5 {
            assert!(limiter.check("user1").is_ok());
        }

        assert!(limiter.check("user1").is_err());
        assert!(limiter.check("user2").is_ok());
    }

    #[test]
    fn test_clear_limiters() {
        let limiter = make_limiter(1, 60);

        assert!(limiter.check("key").is_ok());
        assert!(limiter.check("key").is_err());

        limiter.clear();
        assert_eq!(limiter.active_limiters(), 0);

        assert!(limiter.check("key").is_ok());
    }

    #[test]
    fn test_thread_safe_rate_limiting() {
        use std::sync::Arc;
        use std::thread;

        let limiter = Arc::new(make_limiter(100, 60));

        // 10 threads x 10 requests exactly matches the budget, so every request must be
        // admitted; a lost or double-counted update would surface as a failed assertion.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let limiter = Arc::clone(&limiter);
                thread::spawn(move || {
                    for _ in 0..10 {
                        assert!(limiter.check("concurrent").is_ok());
                    }
                })
            })
            .collect();

        // Propagate worker panics instead of silently discarding them.
        for handle in handles {
            handle.join().expect("worker thread panicked");
        }

        assert!(limiter.check("concurrent").is_err());
    }

    #[test]
    fn test_rate_limiting_many_keys() {
        let limiter = make_limiter(10, 60);

        for i in 0..1000 {
            let key = format!("192.168.{}.{}", i / 256, i % 256);
            assert!(limiter.check(&key).is_ok());
        }

        assert_eq!(limiter.active_limiters(), 1000);
    }

    #[test]
    fn test_endpoint_combinations() {
        let limiters = RateLimiters::new();

        let ip = "203.0.113.1";
        let user = "bob@example.com";

        assert!(limiters.auth_start.check(ip).is_ok());
        assert!(limiters.auth_callback.check(ip).is_ok());
        assert!(limiters.auth_refresh.check(user).is_ok());
        assert!(limiters.auth_logout.check(user).is_ok());
        assert!(limiters.failed_logins.check(user).is_ok());
    }

    #[test]
    fn test_attack_prevention_scenario() {
        let limiter = make_limiter(10, 60);

        let target = "admin@example.com";

        for _ in 0..10 {
            assert!(limiter.check(target).is_ok());
        }

        assert!(limiter.check(target).is_err());
    }
}