// codex_memory/security/rate_limit.rs

use crate::security::{RateLimitConfig, Result, SecurityError};
use axum::{
    extract::{ConnectInfo, Request, State},
    http::{HeaderMap, StatusCode},
    middleware::Next,
    response::Response,
};
use governor::{
    clock::DefaultClock,
    middleware::NoOpMiddleware,
    state::{InMemoryState, NotKeyed},
    Quota, RateLimiter as GovernorRateLimiter,
};
use std::collections::HashMap;
use std::collections::HashSet;
use std::net::{IpAddr, SocketAddr};
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};

21/// Rate limiting manager
22pub struct RateLimitManager {
23    config: RateLimitConfig,
24    ip_limiters: Arc<
25        RwLock<
26            HashMap<
27                IpAddr,
28                Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>>,
29            >,
30        >,
31    >,
32    user_limiters: Arc<
33        RwLock<
34            HashMap<
35                String,
36                Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>>,
37            >,
38        >,
39    >,
40    global_limiter:
41        Option<Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>>>,
42}
43
44impl RateLimitManager {
45    pub fn new(config: RateLimitConfig) -> Self {
46        let global_limiter = if config.enabled {
47            let quota = Quota::per_minute(
48                NonZeroU32::new(config.requests_per_minute)
49                    .unwrap_or(NonZeroU32::new(100).unwrap()),
50            );
51            Some(Arc::new(GovernorRateLimiter::direct(quota)))
52        } else {
53            None
54        };
55
56        Self {
57            config,
58            ip_limiters: Arc::new(RwLock::new(HashMap::new())),
59            user_limiters: Arc::new(RwLock::new(HashMap::new())),
60            global_limiter,
61        }
62    }
63
64    /// Check rate limit for IP address
65    pub async fn check_ip_limit(&self, ip: IpAddr) -> Result<()> {
66        if !self.config.enabled || !self.config.per_ip {
67            return Ok(());
68        }
69
70        // Check whitelist
71        let ip_str = ip.to_string();
72        if self.config.whitelist_ips.contains(&ip_str) {
73            debug!("IP {} is whitelisted, bypassing rate limit", ip);
74            return Ok(());
75        }
76
77        let mut limiters = self.ip_limiters.write().await;
78
79        let limiter = limiters.entry(ip).or_insert_with(|| {
80            let quota = Quota::per_minute(
81                NonZeroU32::new(self.config.requests_per_minute)
82                    .unwrap_or(NonZeroU32::new(100).unwrap()),
83            )
84            .allow_burst(
85                NonZeroU32::new(self.config.burst_size).unwrap_or(NonZeroU32::new(10).unwrap()),
86            );
87            Arc::new(GovernorRateLimiter::direct(quota))
88        });
89
90        let limiter = Arc::clone(limiter);
91        drop(limiters); // Release lock before checking
92
93        match limiter.check() {
94            Ok(_) => {
95                debug!("Rate limit check passed for IP: {}", ip);
96                Ok(())
97            }
98            Err(_) => {
99                warn!("Rate limit exceeded for IP: {}", ip);
100                Err(SecurityError::RateLimitExceeded)
101            }
102        }
103    }
104
105    /// Check rate limit for user
106    pub async fn check_user_limit(&self, user_id: &str) -> Result<()> {
107        if !self.config.enabled || !self.config.per_user {
108            return Ok(());
109        }
110
111        let mut limiters = self.user_limiters.write().await;
112
113        let limiter = limiters.entry(user_id.to_string()).or_insert_with(|| {
114            let quota = Quota::per_minute(
115                NonZeroU32::new(self.config.requests_per_minute)
116                    .unwrap_or(NonZeroU32::new(100).unwrap()),
117            )
118            .allow_burst(
119                NonZeroU32::new(self.config.burst_size).unwrap_or(NonZeroU32::new(10).unwrap()),
120            );
121            Arc::new(GovernorRateLimiter::direct(quota))
122        });
123
124        let limiter = Arc::clone(limiter);
125        drop(limiters); // Release lock before checking
126
127        match limiter.check() {
128            Ok(_) => {
129                debug!("Rate limit check passed for user: {}", user_id);
130                Ok(())
131            }
132            Err(_) => {
133                warn!("Rate limit exceeded for user: {}", user_id);
134                Err(SecurityError::RateLimitExceeded)
135            }
136        }
137    }
138
139    /// Check global rate limit
140    pub async fn check_global_limit(&self) -> Result<()> {
141        if !self.config.enabled {
142            return Ok(());
143        }
144
145        if let Some(limiter) = &self.global_limiter {
146            match limiter.check() {
147                Ok(_) => {
148                    debug!("Global rate limit check passed");
149                    Ok(())
150                }
151                Err(_) => {
152                    warn!("Global rate limit exceeded");
153                    Err(SecurityError::RateLimitExceeded)
154                }
155            }
156        } else {
157            Ok(())
158        }
159    }
160
161    /// Clean up old limiters to prevent memory leaks
162    pub async fn cleanup_limiters(&self) -> Result<()> {
163        let mut ip_limiters = self.ip_limiters.write().await;
164        let mut user_limiters = self.user_limiters.write().await;
165
166        let initial_ip_count = ip_limiters.len();
167        let initial_user_count = user_limiters.len();
168
169        // Remove limiters that haven't been used recently
170        // This is a simplified cleanup - in production, you might want more sophisticated logic
171        ip_limiters.retain(|_, limiter| Arc::strong_count(limiter) > 1);
172        user_limiters.retain(|_, limiter| Arc::strong_count(limiter) > 1);
173
174        let cleaned_ip = initial_ip_count - ip_limiters.len();
175        let cleaned_user = initial_user_count - user_limiters.len();
176
177        if cleaned_ip > 0 || cleaned_user > 0 {
178            info!(
179                "Cleaned up {} IP limiters and {} user limiters",
180                cleaned_ip, cleaned_user
181            );
182        }
183
184        Ok(())
185    }
186
187    /// Get rate limit statistics
188    pub async fn get_statistics(&self) -> RateLimitStatistics {
189        let ip_limiters = self.ip_limiters.read().await;
190        let user_limiters = self.user_limiters.read().await;
191
192        RateLimitStatistics {
193            enabled: self.config.enabled,
194            requests_per_minute: self.config.requests_per_minute,
195            burst_size: self.config.burst_size,
196            active_ip_limiters: ip_limiters.len(),
197            active_user_limiters: user_limiters.len(),
198            per_ip_enabled: self.config.per_ip,
199            per_user_enabled: self.config.per_user,
200            whitelist_count: self.config.whitelist_ips.len(),
201        }
202    }
203
204    pub fn is_enabled(&self) -> bool {
205        self.config.enabled
206    }
207}
208
209#[derive(Debug, Clone, serde::Serialize)]
210pub struct RateLimitStatistics {
211    pub enabled: bool,
212    pub requests_per_minute: u32,
213    pub burst_size: u32,
214    pub active_ip_limiters: usize,
215    pub active_user_limiters: usize,
216    pub per_ip_enabled: bool,
217    pub per_user_enabled: bool,
218    pub whitelist_count: usize,
219}
220
221/// Rate limiting middleware for Axum
222pub async fn rate_limit_middleware(
223    State(rate_limiter): State<Arc<RateLimitManager>>,
224    ConnectInfo(addr): ConnectInfo<SocketAddr>,
225    headers: HeaderMap,
226    request: Request,
227    next: Next,
228) -> std::result::Result<Response, StatusCode> {
229    if !rate_limiter.is_enabled() {
230        return Ok(next.run(request).await);
231    }
232
233    // Check global rate limit first
234    if let Err(_) = rate_limiter.check_global_limit().await {
235        warn!("Global rate limit exceeded");
236        return Err(StatusCode::TOO_MANY_REQUESTS);
237    }
238
239    // Check IP-based rate limit
240    let ip = addr.ip();
241    if let Err(_) = rate_limiter.check_ip_limit(ip).await {
242        warn!("IP rate limit exceeded for: {}", ip);
243        return Err(StatusCode::TOO_MANY_REQUESTS);
244    }
245
246    // Check user-based rate limit if user is authenticated
247    if rate_limiter.config.per_user {
248        if let Some(user_header) = headers.get("X-User-ID") {
249            if let Ok(user_id) = user_header.to_str() {
250                if let Err(_) = rate_limiter.check_user_limit(user_id).await {
251                    warn!("User rate limit exceeded for: {}", user_id);
252                    return Err(StatusCode::TOO_MANY_REQUESTS);
253                }
254            }
255        }
256    }
257
258    debug!("Rate limit checks passed for IP: {}", ip);
259    Ok(next.run(request).await)
260}
261
262/// Create rate limit middleware with custom configuration
263pub fn create_rate_limit_middleware(
264    requests_per_minute: u32,
265    burst_size: u32,
266    whitelist_ips: Vec<String>,
267) -> RateLimitManager {
268    let config = RateLimitConfig {
269        enabled: true,
270        requests_per_minute,
271        burst_size,
272        per_ip: true,
273        per_user: true,
274        whitelist_ips,
275    };
276
277    RateLimitManager::new(config)
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Builds an enabled `RateLimitConfig` with the given quota and switches.
    fn enabled_config(
        requests_per_minute: u32,
        burst_size: u32,
        per_ip: bool,
        per_user: bool,
        whitelist_ips: Vec<String>,
    ) -> RateLimitConfig {
        RateLimitConfig {
            enabled: true,
            requests_per_minute,
            burst_size,
            per_ip,
            per_user,
            whitelist_ips,
        }
    }

    /// Address in 192.168.1.0/24 with the given final octet.
    fn lan_ip(host: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, host))
    }

    #[tokio::test]
    async fn test_rate_limit_manager_creation() {
        let manager = RateLimitManager::new(RateLimitConfig::default());
        assert!(!manager.is_enabled());
    }

    #[tokio::test]
    async fn test_disabled_rate_limiting() {
        // `RateLimitConfig::default()` leaves rate limiting disabled.
        let manager = RateLimitManager::new(RateLimitConfig::default());
        assert!(manager.check_ip_limit(lan_ip(1)).await.is_ok());
    }

    #[tokio::test]
    async fn test_ip_whitelist() {
        // One request per minute would be restrictive, but the whitelist bypasses it.
        let whitelist = vec!["192.168.1.1".to_string()];
        let manager = RateLimitManager::new(enabled_config(1, 1, true, false, whitelist));
        assert!(manager.check_ip_limit(lan_ip(1)).await.is_ok());
    }

    #[tokio::test]
    async fn test_rate_limit_exceeded() {
        let manager = RateLimitManager::new(enabled_config(1, 1, true, false, Vec::new()));
        let client = lan_ip(2);

        assert!(manager.check_ip_limit(client).await.is_ok());

        let second = manager.check_ip_limit(client).await;
        assert!(
            matches!(second, Err(SecurityError::RateLimitExceeded)),
            "Expected RateLimitExceeded error"
        );
    }

    #[tokio::test]
    async fn test_user_rate_limiting() {
        let manager = RateLimitManager::new(enabled_config(1, 1, false, true, Vec::new()));

        assert!(manager.check_user_limit("test-user").await.is_ok());
        assert!(manager.check_user_limit("test-user").await.is_err());
    }

    #[tokio::test]
    async fn test_global_rate_limiting() {
        let manager = RateLimitManager::new(enabled_config(1, 1, false, false, Vec::new()));

        assert!(manager.check_global_limit().await.is_ok());
        assert!(manager.check_global_limit().await.is_err());
    }

    #[tokio::test]
    async fn test_statistics() {
        let whitelist = vec!["127.0.0.1".to_string()];
        let manager = RateLimitManager::new(enabled_config(100, 10, true, true, whitelist));
        let stats = manager.get_statistics().await;

        assert!(stats.enabled);
        assert_eq!(stats.requests_per_minute, 100);
        assert_eq!(stats.burst_size, 10);
        assert!(stats.per_ip_enabled);
        assert!(stats.per_user_enabled);
        assert_eq!(stats.whitelist_count, 1);
        assert_eq!(stats.active_ip_limiters, 0);
        assert_eq!(stats.active_user_limiters, 0);
    }

    #[tokio::test]
    async fn test_limiter_cleanup() {
        let manager = RateLimitManager::new(enabled_config(100, 10, true, true, Vec::new()));

        // Populate one limiter of each kind; the check outcomes don't matter here.
        let _ = manager.check_ip_limit(lan_ip(1)).await;
        let _ = manager.check_user_limit("test-user").await;

        let before = manager.get_statistics().await;
        assert!(before.active_ip_limiters > 0 || before.active_user_limiters > 0);

        assert!(manager.cleanup_limiters().await.is_ok());
    }

    #[test]
    fn test_custom_rate_limiter_creation() {
        let whitelist = vec!["127.0.0.1".to_string(), "::1".to_string()];
        let manager = create_rate_limit_middleware(200, 20, whitelist);

        assert!(manager.is_enabled());
        assert_eq!(manager.config.requests_per_minute, 200);
        assert_eq!(manager.config.burst_size, 20);
        assert_eq!(manager.config.whitelist_ips.len(), 2);
    }
}