codex_memory/security/
rate_limit.rs

1use crate::security::{RateLimitConfig, Result, SecurityError};
2use axum::{
3    extract::{ConnectInfo, Request, State},
4    http::{HeaderMap, StatusCode},
5    middleware::Next,
6    response::Response,
7};
8use governor::{
9    clock::DefaultClock,
10    middleware::NoOpMiddleware,
11    state::{InMemoryState, NotKeyed},
12    Quota, RateLimiter as GovernorRateLimiter,
13};
14use std::collections::HashMap;
15use std::net::{IpAddr, SocketAddr};
16use std::num::NonZeroU32;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19use tracing::{debug, info, warn};
20
21// Type aliases to reduce complexity
22type GovernorLimiter = GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>;
23type IpLimiters = Arc<RwLock<HashMap<IpAddr, Arc<GovernorLimiter>>>>;
24type UserLimiters = Arc<RwLock<HashMap<String, Arc<GovernorLimiter>>>>;
25
26/// Rate limiting manager
27pub struct RateLimitManager {
28    config: RateLimitConfig,
29    ip_limiters: IpLimiters,
30    user_limiters: UserLimiters,
31    global_limiter: Option<Arc<GovernorLimiter>>,
32}
33
34impl RateLimitManager {
35    pub fn new(config: RateLimitConfig) -> Self {
36        let global_limiter = if config.enabled {
37            let quota = Quota::per_minute(
38                NonZeroU32::new(config.requests_per_minute)
39                    .unwrap_or(NonZeroU32::new(100).unwrap()),
40            );
41            Some(Arc::new(GovernorRateLimiter::direct(quota)))
42        } else {
43            None
44        };
45
46        Self {
47            config,
48            ip_limiters: Arc::new(RwLock::new(HashMap::new())),
49            user_limiters: Arc::new(RwLock::new(HashMap::new())),
50            global_limiter,
51        }
52    }
53
54    /// Check rate limit for IP address
55    pub async fn check_ip_limit(&self, ip: IpAddr) -> Result<()> {
56        if !self.config.enabled || !self.config.per_ip {
57            return Ok(());
58        }
59
60        // Check whitelist
61        let ip_str = ip.to_string();
62        if self.config.whitelist_ips.contains(&ip_str) {
63            debug!("IP {} is whitelisted, bypassing rate limit", ip);
64            return Ok(());
65        }
66
67        let mut limiters = self.ip_limiters.write().await;
68
69        let limiter = limiters.entry(ip).or_insert_with(|| {
70            let quota = Quota::per_minute(
71                NonZeroU32::new(self.config.requests_per_minute)
72                    .unwrap_or(NonZeroU32::new(100).unwrap()),
73            )
74            .allow_burst(
75                NonZeroU32::new(self.config.burst_size).unwrap_or(NonZeroU32::new(10).unwrap()),
76            );
77            Arc::new(GovernorRateLimiter::direct(quota))
78        });
79
80        let limiter = Arc::clone(limiter);
81        drop(limiters); // Release lock before checking
82
83        match limiter.check() {
84            Ok(_) => {
85                debug!("Rate limit check passed for IP: {}", ip);
86                Ok(())
87            }
88            Err(_) => {
89                warn!("Rate limit exceeded for IP: {}", ip);
90                Err(SecurityError::RateLimitExceeded)
91            }
92        }
93    }
94
95    /// Check rate limit for user
96    pub async fn check_user_limit(&self, user_id: &str) -> Result<()> {
97        if !self.config.enabled || !self.config.per_user {
98            return Ok(());
99        }
100
101        let mut limiters = self.user_limiters.write().await;
102
103        let limiter = limiters.entry(user_id.to_string()).or_insert_with(|| {
104            let quota = Quota::per_minute(
105                NonZeroU32::new(self.config.requests_per_minute)
106                    .unwrap_or(NonZeroU32::new(100).unwrap()),
107            )
108            .allow_burst(
109                NonZeroU32::new(self.config.burst_size).unwrap_or(NonZeroU32::new(10).unwrap()),
110            );
111            Arc::new(GovernorRateLimiter::direct(quota))
112        });
113
114        let limiter = Arc::clone(limiter);
115        drop(limiters); // Release lock before checking
116
117        match limiter.check() {
118            Ok(_) => {
119                debug!("Rate limit check passed for user: {}", user_id);
120                Ok(())
121            }
122            Err(_) => {
123                warn!("Rate limit exceeded for user: {}", user_id);
124                Err(SecurityError::RateLimitExceeded)
125            }
126        }
127    }
128
129    /// Check global rate limit
130    pub async fn check_global_limit(&self) -> Result<()> {
131        if !self.config.enabled {
132            return Ok(());
133        }
134
135        if let Some(limiter) = &self.global_limiter {
136            match limiter.check() {
137                Ok(_) => {
138                    debug!("Global rate limit check passed");
139                    Ok(())
140                }
141                Err(_) => {
142                    warn!("Global rate limit exceeded");
143                    Err(SecurityError::RateLimitExceeded)
144                }
145            }
146        } else {
147            Ok(())
148        }
149    }
150
151    /// Clean up old limiters to prevent memory leaks
152    pub async fn cleanup_limiters(&self) -> Result<()> {
153        let mut ip_limiters = self.ip_limiters.write().await;
154        let mut user_limiters = self.user_limiters.write().await;
155
156        let initial_ip_count = ip_limiters.len();
157        let initial_user_count = user_limiters.len();
158
159        // Remove limiters that haven't been used recently
160        // This is a simplified cleanup - in production, you might want more sophisticated logic
161        ip_limiters.retain(|_, limiter| Arc::strong_count(limiter) > 1);
162        user_limiters.retain(|_, limiter| Arc::strong_count(limiter) > 1);
163
164        let cleaned_ip = initial_ip_count - ip_limiters.len();
165        let cleaned_user = initial_user_count - user_limiters.len();
166
167        if cleaned_ip > 0 || cleaned_user > 0 {
168            info!(
169                "Cleaned up {} IP limiters and {} user limiters",
170                cleaned_ip, cleaned_user
171            );
172        }
173
174        Ok(())
175    }
176
177    /// Get rate limit statistics
178    pub async fn get_statistics(&self) -> RateLimitStatistics {
179        let ip_limiters = self.ip_limiters.read().await;
180        let user_limiters = self.user_limiters.read().await;
181
182        RateLimitStatistics {
183            enabled: self.config.enabled,
184            requests_per_minute: self.config.requests_per_minute,
185            burst_size: self.config.burst_size,
186            active_ip_limiters: ip_limiters.len(),
187            active_user_limiters: user_limiters.len(),
188            per_ip_enabled: self.config.per_ip,
189            per_user_enabled: self.config.per_user,
190            whitelist_count: self.config.whitelist_ips.len(),
191        }
192    }
193
194    pub fn is_enabled(&self) -> bool {
195        self.config.enabled
196    }
197}
198
199#[derive(Debug, Clone, serde::Serialize)]
200pub struct RateLimitStatistics {
201    pub enabled: bool,
202    pub requests_per_minute: u32,
203    pub burst_size: u32,
204    pub active_ip_limiters: usize,
205    pub active_user_limiters: usize,
206    pub per_ip_enabled: bool,
207    pub per_user_enabled: bool,
208    pub whitelist_count: usize,
209}
210
211/// Rate limiting middleware for Axum
212pub async fn rate_limit_middleware(
213    State(rate_limiter): State<Arc<RateLimitManager>>,
214    ConnectInfo(addr): ConnectInfo<SocketAddr>,
215    headers: HeaderMap,
216    request: Request,
217    next: Next,
218) -> std::result::Result<Response, StatusCode> {
219    if !rate_limiter.is_enabled() {
220        return Ok(next.run(request).await);
221    }
222
223    // Check global rate limit first
224    if rate_limiter.check_global_limit().await.is_err() {
225        warn!("Global rate limit exceeded");
226        return Err(StatusCode::TOO_MANY_REQUESTS);
227    }
228
229    // Check IP-based rate limit
230    let ip = addr.ip();
231    if rate_limiter.check_ip_limit(ip).await.is_err() {
232        warn!("IP rate limit exceeded for: {}", ip);
233        return Err(StatusCode::TOO_MANY_REQUESTS);
234    }
235
236    // Check user-based rate limit if user is authenticated
237    if rate_limiter.config.per_user {
238        if let Some(user_header) = headers.get("X-User-ID") {
239            if let Ok(user_id) = user_header.to_str() {
240                if rate_limiter.check_user_limit(user_id).await.is_err() {
241                    warn!("User rate limit exceeded for: {}", user_id);
242                    return Err(StatusCode::TOO_MANY_REQUESTS);
243                }
244            }
245        }
246    }
247
248    debug!("Rate limit checks passed for IP: {}", ip);
249    Ok(next.run(request).await)
250}
251
252/// Create rate limit middleware with custom configuration
253pub fn create_rate_limit_middleware(
254    requests_per_minute: u32,
255    burst_size: u32,
256    whitelist_ips: Vec<String>,
257) -> RateLimitManager {
258    let config = RateLimitConfig {
259        enabled: true,
260        requests_per_minute,
261        burst_size,
262        per_ip: true,
263        per_user: true,
264        whitelist_ips,
265    };
266
267    RateLimitManager::new(config)
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Builds an enabled configuration; only the knobs the tests vary are parameters.
    fn enabled_config(
        requests_per_minute: u32,
        burst_size: u32,
        per_ip: bool,
        per_user: bool,
        whitelist: &[&str],
    ) -> RateLimitConfig {
        RateLimitConfig {
            enabled: true,
            requests_per_minute,
            burst_size,
            per_ip,
            per_user,
            whitelist_ips: whitelist.iter().map(|entry| entry.to_string()).collect(),
        }
    }

    /// Returns the address `192.168.1.<last>`.
    fn lan_ip(last: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, last))
    }

    #[tokio::test]
    async fn test_rate_limit_manager_creation() {
        let manager = RateLimitManager::new(RateLimitConfig::default());
        assert!(!manager.is_enabled());
    }

    #[tokio::test]
    async fn test_disabled_rate_limiting() {
        // The default configuration has rate limiting disabled.
        let manager = RateLimitManager::new(RateLimitConfig::default());

        assert!(manager.check_ip_limit(lan_ip(1)).await.is_ok());
    }

    #[tokio::test]
    async fn test_ip_whitelist() {
        // One request per minute, but the address is whitelisted.
        let manager = RateLimitManager::new(enabled_config(1, 1, true, false, &["192.168.1.1"]));

        // Should pass even with the low limit because of the whitelist.
        assert!(manager.check_ip_limit(lan_ip(1)).await.is_ok());
    }

    #[tokio::test]
    async fn test_rate_limit_exceeded() {
        let manager = RateLimitManager::new(enabled_config(1, 1, true, false, &[]));
        let ip = lan_ip(2);

        // The first request fits in the budget.
        let first = manager.check_ip_limit(ip).await;
        assert!(first.is_ok());

        // The second one must be rejected with the dedicated error variant.
        let second = manager.check_ip_limit(ip).await;
        assert!(second.is_err());
        assert!(
            matches!(second, Err(SecurityError::RateLimitExceeded)),
            "Expected RateLimitExceeded error"
        );
    }

    #[tokio::test]
    async fn test_user_rate_limiting() {
        let manager = RateLimitManager::new(enabled_config(1, 1, false, true, &[]));
        let user_id = "test-user";

        // First request passes, second is rejected.
        assert!(manager.check_user_limit(user_id).await.is_ok());
        assert!(manager.check_user_limit(user_id).await.is_err());
    }

    #[tokio::test]
    async fn test_global_rate_limiting() {
        let manager = RateLimitManager::new(enabled_config(1, 1, false, false, &[]));

        // First request passes, second is rejected.
        assert!(manager.check_global_limit().await.is_ok());
        assert!(manager.check_global_limit().await.is_err());
    }

    #[tokio::test]
    async fn test_statistics() {
        let manager = RateLimitManager::new(enabled_config(100, 10, true, true, &["127.0.0.1"]));

        let stats = manager.get_statistics().await;

        assert!(stats.enabled);
        assert_eq!(stats.requests_per_minute, 100);
        assert_eq!(stats.burst_size, 10);
        assert!(stats.per_ip_enabled);
        assert!(stats.per_user_enabled);
        assert_eq!(stats.whitelist_count, 1);
        assert_eq!(stats.active_ip_limiters, 0);
        assert_eq!(stats.active_user_limiters, 0);
    }

    #[tokio::test]
    async fn test_limiter_cleanup() {
        let manager = RateLimitManager::new(enabled_config(100, 10, true, true, &[]));

        // Touch one IP and one user so that limiters get created.
        let _ = manager.check_ip_limit(lan_ip(1)).await;
        let _ = manager.check_user_limit("test-user").await;

        let stats_before = manager.get_statistics().await;
        assert!(stats_before.active_ip_limiters > 0 || stats_before.active_user_limiters > 0);

        // Cleanup should work without errors.
        assert!(manager.cleanup_limiters().await.is_ok());
    }

    #[test]
    fn test_custom_rate_limiter_creation() {
        let whitelist = vec!["127.0.0.1".to_string(), "::1".to_string()];
        let manager = create_rate_limit_middleware(200, 20, whitelist);

        assert!(manager.is_enabled());
        assert_eq!(manager.config.requests_per_minute, 200);
        assert_eq!(manager.config.burst_size, 20);
        assert_eq!(manager.config.whitelist_ips.len(), 2);
    }
}