Skip to main content

fraiseql_server/auth/
rate_limiting.rs

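// Rate limiting for brute-force protection on the auth endpoints.
// State is kept in an in-memory map guarded by `Arc<Mutex<_>>`, so limits are per process:
// they reset on restart and are not shared between server instances.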
3
4use std::{
5    collections::HashMap,
6    sync::{Arc, Mutex},
7    time::{SystemTime, UNIX_EPOCH},
8};
9
10use crate::auth::error::{AuthError, Result};
11
/// Limits applied to one endpoint: at most `max_requests` requests within each
/// `window_secs`-second window.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Upper bound on requests accepted within one window
    pub max_requests: u32,
    /// Length of the counting window, in seconds
    pub window_secs:  u64,
}
20
21impl RateLimitConfig {
22    /// IP-based rate limiting for public endpoints
23    /// 100 requests per 60 seconds (typical for auth/start, auth/callback)
24    pub fn per_ip_standard() -> Self {
25        Self {
26            max_requests: 100,
27            window_secs:  60,
28        }
29    }
30
31    /// Stricter IP-based rate limiting for sensitive endpoints
32    /// 50 requests per 60 seconds
33    pub fn per_ip_strict() -> Self {
34        Self {
35            max_requests: 50,
36            window_secs:  60,
37        }
38    }
39
40    /// User-based rate limiting for authenticated endpoints
41    /// 10 requests per 60 seconds
42    pub fn per_user_standard() -> Self {
43        Self {
44            max_requests: 10,
45            window_secs:  60,
46        }
47    }
48
49    /// Failed login attempt limiting
50    /// 5 failed attempts per 3600 seconds (1 hour)
51    pub fn failed_login_attempts() -> Self {
52        Self {
53            max_requests: 5,
54            window_secs:  3600,
55        }
56    }
57}
58
/// Counter state tracked for a single rate-limited key.
#[derive(Clone, Debug)]
struct RequestRecord {
    /// Requests accepted so far in the current window
    count:        u32,
    /// Unix timestamp (seconds) at which the current window opened
    window_start: u64,
}
67
68/// Per-key rate limiter using in-memory tracking
69/// Maintains separate rate limits for each key (IP, user ID, etc.)
70pub struct KeyedRateLimiter {
71    records: Arc<Mutex<HashMap<String, RequestRecord>>>,
72    config:  RateLimitConfig,
73}
74
75impl KeyedRateLimiter {
76    /// Create a new keyed rate limiter
77    pub fn new(config: RateLimitConfig) -> Self {
78        Self {
79            records: Arc::new(Mutex::new(HashMap::new())),
80            config,
81        }
82    }
83
84    /// Get current Unix timestamp in seconds
85    fn current_timestamp() -> u64 {
86        SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs()
87    }
88
89    /// Check if a request should be allowed for the given key
90    ///
91    /// Returns Ok(()) if allowed, Err with status code if rate limited
92    pub fn check(&self, key: &str) -> Result<()> {
93        let mut records = self.records.lock().unwrap();
94        let now = Self::current_timestamp();
95
96        let record = records.entry(key.to_string()).or_insert_with(|| RequestRecord {
97            count:        0,
98            window_start: now,
99        });
100
101        // Check if window has expired
102        if now >= record.window_start + self.config.window_secs {
103            // Reset window
104            record.count = 1;
105            record.window_start = now;
106            Ok(())
107        } else if record.count < self.config.max_requests {
108            // Request allowed
109            record.count += 1;
110            Ok(())
111        } else {
112            // Rate limited
113            Err(AuthError::RateLimited {
114                retry_after_secs: self.config.window_secs,
115            })
116        }
117    }
118
119    /// Get the number of active rate limiters (for monitoring)
120    pub fn active_limiters(&self) -> usize {
121        let records = self.records.lock().unwrap();
122        records.len()
123    }
124
125    /// Clear all rate limiters (for testing or reset)
126    pub fn clear(&self) {
127        let mut records = self.records.lock().unwrap();
128        records.clear();
129    }
130
131    /// Create a copy for independent testing
132    pub fn clone_config(&self) -> RateLimitConfig {
133        self.config.clone()
134    }
135}
136
137/// Global rate limiters for different endpoints
138pub struct RateLimiters {
139    /// auth/start: per-IP, 100 req/min
140    pub auth_start:    KeyedRateLimiter,
141    /// auth/callback: per-IP, 50 req/min
142    pub auth_callback: KeyedRateLimiter,
143    /// auth/refresh: per-user, 10 req/min
144    pub auth_refresh:  KeyedRateLimiter,
145    /// auth/logout: per-user, 20 req/min
146    pub auth_logout:   KeyedRateLimiter,
147    /// Failed login tracking: per-user, 5 attempts/hour
148    pub failed_logins: KeyedRateLimiter,
149}
150
151impl RateLimiters {
152    /// Create default rate limiters for all endpoints
153    pub fn new() -> Self {
154        Self {
155            auth_start:    KeyedRateLimiter::new(RateLimitConfig::per_ip_standard()),
156            auth_callback: KeyedRateLimiter::new(RateLimitConfig::per_ip_strict()),
157            auth_refresh:  KeyedRateLimiter::new(RateLimitConfig::per_user_standard()),
158            auth_logout:   KeyedRateLimiter::new(RateLimitConfig::per_user_standard()),
159            failed_logins: KeyedRateLimiter::new(RateLimitConfig::failed_login_attempts()),
160        }
161    }
162
163    /// Create with custom configurations
164    pub fn with_configs(
165        start_cfg: RateLimitConfig,
166        callback_cfg: RateLimitConfig,
167        refresh_cfg: RateLimitConfig,
168        logout_cfg: RateLimitConfig,
169        failed_cfg: RateLimitConfig,
170    ) -> Self {
171        Self {
172            auth_start:    KeyedRateLimiter::new(start_cfg),
173            auth_callback: KeyedRateLimiter::new(callback_cfg),
174            auth_refresh:  KeyedRateLimiter::new(refresh_cfg),
175            auth_logout:   KeyedRateLimiter::new(logout_cfg),
176            failed_logins: KeyedRateLimiter::new(failed_cfg),
177        }
178    }
179}
180
181impl Default for RateLimiters {
182    fn default() -> Self {
183        Self::new()
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an ad-hoc limiter configuration.
    fn config(max_requests: u32, window_secs: u64) -> RateLimitConfig {
        RateLimitConfig {
            max_requests,
            window_secs,
        }
    }

    #[test]
    fn test_rate_limiter_allows_within_limit() {
        let limiter = KeyedRateLimiter::new(config(3, 60));

        // Should allow up to max_requests
        for i in 0..3 {
            assert!(limiter.check("key").is_ok(), "Request {} should be allowed", i);
        }
    }

    #[test]
    fn test_rate_limiter_rejects_over_limit() {
        let limiter = KeyedRateLimiter::new(config(2, 60));

        // Setup requests must succeed, otherwise the final assertion proves nothing.
        assert!(limiter.check("key").is_ok());
        assert!(limiter.check("key").is_ok());

        // Third should fail
        assert!(limiter.check("key").is_err(), "Request over limit should fail");
    }

    #[test]
    fn test_rate_limiter_per_key() {
        let limiter = KeyedRateLimiter::new(config(2, 60));

        // Key 1: use up its allowance
        assert!(limiter.check("key1").is_ok());
        assert!(limiter.check("key1").is_ok());
        assert!(limiter.check("key1").is_err());

        // Key 2: should have fresh allowance
        assert!(
            limiter.check("key2").is_ok(),
            "Different key should have independent limit"
        );
    }

    #[test]
    fn test_rate_limiter_error_contains_retry_after() {
        let limiter = KeyedRateLimiter::new(config(1, 60));

        assert!(limiter.check("key").is_ok());

        match limiter.check("key") {
            Err(AuthError::RateLimited { retry_after_secs }) => {
                assert_eq!(retry_after_secs, 60);
            },
            _ => panic!("Expected RateLimited error"),
        }
    }

    #[test]
    fn test_rate_limiter_active_limiters_count() {
        let limiter = KeyedRateLimiter::new(config(100, 60));

        assert_eq!(limiter.active_limiters(), 0);

        assert!(limiter.check("key1").is_ok());
        assert_eq!(limiter.active_limiters(), 1);

        // A repeated key reuses its existing record.
        assert!(limiter.check("key1").is_ok());
        assert_eq!(limiter.active_limiters(), 1);

        assert!(limiter.check("key2").is_ok());
        assert_eq!(limiter.active_limiters(), 2);
    }

    #[test]
    fn test_rate_limiters_default() {
        let limiters = RateLimiters::new();

        // auth/start should allow requests
        assert!(limiters.auth_start.check("ip_1").is_ok());

        // auth/refresh should track per-user
        assert!(limiters.auth_refresh.check("user_1").is_ok());

        // `Default` installs the same presets as `new`.
        let defaults = RateLimiters::default();
        assert_eq!(defaults.auth_start.clone_config().max_requests, 100);
        assert_eq!(defaults.failed_logins.clone_config().window_secs, 3600);
    }

    #[test]
    fn test_rate_limit_config_presets() {
        let standard_ip = RateLimitConfig::per_ip_standard();
        assert_eq!(standard_ip.max_requests, 100);
        assert_eq!(standard_ip.window_secs, 60);

        let strict_ip = RateLimitConfig::per_ip_strict();
        assert_eq!(strict_ip.max_requests, 50);
        assert_eq!(strict_ip.window_secs, 60);

        let user_limit = RateLimitConfig::per_user_standard();
        assert_eq!(user_limit.max_requests, 10);
        assert_eq!(user_limit.window_secs, 60);

        let failed = RateLimitConfig::failed_login_attempts();
        assert_eq!(failed.max_requests, 5);
        assert_eq!(failed.window_secs, 3600);
    }

    #[test]
    fn test_ip_based_rate_limiting() {
        let limiter = KeyedRateLimiter::new(RateLimitConfig::per_ip_standard());

        let ip = "203.0.113.1";

        // Should allow up to 100 requests
        for i in 0..100 {
            assert!(limiter.check(ip).is_ok(), "Request {} should be allowed", i);
        }

        // 101st should fail
        assert!(limiter.check(ip).is_err());
    }

    #[test]
    fn test_failed_login_tracking() {
        let limiter = KeyedRateLimiter::new(RateLimitConfig::failed_login_attempts());

        let user = "alice@example.com";

        // Should allow 5 failed attempts
        for i in 0..5 {
            assert!(limiter.check(user).is_ok(), "Attempt {} should be allowed", i);
        }

        // 6th should fail
        assert!(limiter.check(user).is_err());
    }

    #[test]
    fn test_multiple_users_independent() {
        let limiter = KeyedRateLimiter::new(RateLimitConfig::failed_login_attempts());

        // User 1 uses attempts
        for _ in 0..5 {
            assert!(limiter.check("user1").is_ok());
        }

        // User 1 blocked
        assert!(limiter.check("user1").is_err());

        // User 2 should have fresh attempts
        assert!(limiter.check("user2").is_ok());
    }

    #[test]
    fn test_clear_limiters() {
        let limiter = KeyedRateLimiter::new(config(1, 60));

        assert!(limiter.check("key").is_ok());
        assert!(limiter.check("key").is_err());

        limiter.clear();
        assert_eq!(limiter.active_limiters(), 0);

        // After clear, should allow again
        assert!(limiter.check("key").is_ok());
    }

    #[test]
    fn test_thread_safe_rate_limiting() {
        use std::sync::Arc;

        let limiter = Arc::new(KeyedRateLimiter::new(config(100, 60)));

        // 10 threads x 10 requests each, all against the same key.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let limiter = Arc::clone(&limiter);
                std::thread::spawn(move || {
                    (0..10).filter(|_| limiter.check("concurrent").is_ok()).count()
                })
            })
            .collect();

        // A panicking worker must fail the test rather than be silently ignored.
        let allowed: usize = handles
            .into_iter()
            .map(|handle| handle.join().expect("worker thread panicked"))
            .sum();

        // Exactly the budget was granted, with no lost or duplicated updates.
        assert_eq!(allowed, 100);

        // After 100 concurrent requests, next should fail
        assert!(limiter.check("concurrent").is_err());
    }

    #[test]
    fn test_rate_limiting_many_keys() {
        let limiter = KeyedRateLimiter::new(config(10, 60));

        // Simulate 1000 different IPs, one request each
        for i in 0..1000 {
            let key = format!("192.168.{}.{}", i / 256, i % 256);
            assert!(limiter.check(&key).is_ok(), "First request from {} should be allowed", key);
        }

        assert_eq!(limiter.active_limiters(), 1000);
    }

    #[test]
    fn test_endpoint_combinations() {
        let limiters = RateLimiters::new();

        let ip = "203.0.113.1";
        let user = "bob@example.com";

        // Complete flow
        assert!(limiters.auth_start.check(ip).is_ok());
        assert!(limiters.auth_callback.check(ip).is_ok());
        assert!(limiters.auth_refresh.check(user).is_ok());
        assert!(limiters.auth_logout.check(user).is_ok());
        assert!(limiters.failed_logins.check(user).is_ok());
    }

    #[test]
    fn test_attack_prevention_scenario() {
        let limiter = KeyedRateLimiter::new(config(10, 60));

        let target = "admin@example.com";

        // The attacker's first 10 attempts still fit in the budget
        for _ in 0..10 {
            assert!(limiter.check(target).is_ok());
        }

        // 11th blocked
        assert!(limiter.check(target).is_err());
    }

    #[test]
    fn test_with_configs_applies_custom_limits() {
        let limiters = RateLimiters::with_configs(
            config(1, 10),
            config(2, 20),
            config(3, 30),
            config(4, 40),
            config(5, 50),
        );

        let expected = [
            (&limiters.auth_start, 1, 10),
            (&limiters.auth_callback, 2, 20),
            (&limiters.auth_refresh, 3, 30),
            (&limiters.auth_logout, 4, 40),
            (&limiters.failed_logins, 5, 50),
        ];

        for &(limiter, max_requests, window_secs) in &expected {
            let cfg = limiter.clone_config();
            assert_eq!(cfg.max_requests, max_requests);
            assert_eq!(cfg.window_secs, window_secs);

            for _ in 0..max_requests {
                assert!(limiter.check("key").is_ok());
            }
            assert!(limiter.check("key").is_err());
        }
    }

    #[test]
    fn test_clone_config_returns_config() {
        let limiter = KeyedRateLimiter::new(config(7, 30));
        let cfg = limiter.clone_config();
        assert_eq!(cfg.max_requests, 7);
        assert_eq!(cfg.window_secs, 30);
    }

    #[test]
    fn test_zero_max_requests_rejects_immediately() {
        let limiter = KeyedRateLimiter::new(config(0, 60));
        assert!(limiter.check("key").is_err());
        assert!(limiter.check("key").is_err());
    }
}