llm_config_security/
rate_limit.rs

//! Rate limiting and throttling: shared and per-IP request limits with automatic temporary IP bans.
2
3use crate::errors::{SecurityError, SecurityResult};
4use governor::{
5    clock::DefaultClock,
6    state::{InMemoryState, NotKeyed},
7    Quota, RateLimiter as GovernorRateLimiter,
8};
9use std::collections::HashMap;
10use std::net::IpAddr;
11use std::num::NonZeroU32;
12use std::sync::{Arc, RwLock};
13use std::time::Duration;
14
15/// Rate limit configuration
/// Tunable limits for [`RateLimiter`].
///
/// Rates are expressed in requests per second. Zero values are not rejected
/// here; their interpretation is left to [`RateLimiter::new`].
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Steady-state requests per second for authenticated traffic.
    pub authenticated_rps: u32,
    /// Steady-state requests per second for unauthenticated traffic.
    pub unauthenticated_rps: u32,
    /// Maximum burst size.
    pub burst_size: u32,
    /// Time window in seconds.
    ///
    /// NOTE(review): not read by `RateLimiter` in this module — confirm whether
    /// it is used elsewhere or can be deprecated.
    pub window_seconds: u64,
    /// How long a ban lasts, in seconds.
    pub ban_duration_seconds: u64,
    /// Number of recorded violations at which an IP is banned.
    pub ban_threshold: usize,
}

impl Default for RateLimitConfig {
    /// 100 rps authenticated, 10 rps unauthenticated, a burst of 50, a 60 s
    /// window, and one-hour bans after 10 violations.
    fn default() -> Self {
        const ONE_HOUR_SECS: u64 = 60 * 60;

        RateLimitConfig {
            authenticated_rps: 100,
            unauthenticated_rps: 10,
            burst_size: 50,
            window_seconds: 60,
            ban_duration_seconds: ONE_HOUR_SECS,
            ban_threshold: 10,
        }
    }
}
44
45/// Rate limiter for API endpoints
46pub struct RateLimiter {
47    config: RateLimitConfig,
48    authenticated_limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
49    unauthenticated_limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
50    per_ip_limiters: Arc<RwLock<HashMap<IpAddr, IpLimiter>>>,
51    banned_ips: Arc<RwLock<HashMap<IpAddr, BanInfo>>>,
52}
53
54#[derive(Debug, Clone)]
55struct IpLimiter {
56    limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
57    violations: usize,
58    last_violation: std::time::Instant,
59}
60
/// Record describing a banned IP address.
#[derive(Clone, Debug)]
pub struct BanInfo {
    /// Moment the ban started; it stays active for `ban_duration_seconds`.
    pub banned_at: std::time::Instant,
    /// Human-readable reason for the ban.
    pub reason: String,
    /// Violation count at the time of the ban (`0` for manual bans).
    pub violations: usize,
}
71
72impl RateLimiter {
73    /// Create a new rate limiter
74    pub fn new(config: RateLimitConfig) -> Self {
75        let authenticated_quota = Quota::per_second(
76            NonZeroU32::new(config.authenticated_rps).unwrap_or(NonZeroU32::new(100).unwrap())
77        ).allow_burst(
78            NonZeroU32::new(config.burst_size).unwrap_or(NonZeroU32::new(50).unwrap())
79        );
80
81        let unauthenticated_quota = Quota::per_second(
82            NonZeroU32::new(config.unauthenticated_rps).unwrap_or(NonZeroU32::new(10).unwrap())
83        ).allow_burst(
84            NonZeroU32::new(config.burst_size / 5).unwrap_or(NonZeroU32::new(10).unwrap())
85        );
86
87        Self {
88            config,
89            authenticated_limiter: Arc::new(GovernorRateLimiter::direct(authenticated_quota)),
90            unauthenticated_limiter: Arc::new(GovernorRateLimiter::direct(
91                unauthenticated_quota,
92            )),
93            per_ip_limiters: Arc::new(RwLock::new(HashMap::new())),
94            banned_ips: Arc::new(RwLock::new(HashMap::new())),
95        }
96    }
97
98    /// Check if a request is allowed
99    pub fn check_request(
100        &self,
101        ip: IpAddr,
102        authenticated: bool,
103    ) -> SecurityResult<()> {
104        // Check if IP is banned
105        if self.is_banned(ip) {
106            return Err(SecurityError::RateLimitExceeded(
107                "IP address is temporarily banned".to_string(),
108            ));
109        }
110
111        // Check global rate limit
112        let limiter = if authenticated {
113            &self.authenticated_limiter
114        } else {
115            &self.unauthenticated_limiter
116        };
117
118        if limiter.check().is_err() {
119            self.record_violation(ip, "Global rate limit exceeded");
120            return Err(SecurityError::RateLimitExceeded(
121                "Too many requests. Please try again later".to_string(),
122            ));
123        }
124
125        // Check per-IP rate limit
126        let mut limiters = self.per_ip_limiters.write().unwrap();
127        let ip_limiter = limiters.entry(ip).or_insert_with(|| {
128            let quota = if authenticated {
129                Quota::per_second(
130                    NonZeroU32::new(self.config.authenticated_rps / 10)
131                        .unwrap_or(NonZeroU32::new(10).unwrap())
132                )
133            } else {
134                Quota::per_second(
135                    NonZeroU32::new(self.config.unauthenticated_rps)
136                        .unwrap_or(NonZeroU32::new(10).unwrap())
137                )
138            }
139            .allow_burst(NonZeroU32::new(10).unwrap());
140
141            IpLimiter {
142                limiter: Arc::new(GovernorRateLimiter::direct(quota)),
143                violations: 0,
144                last_violation: std::time::Instant::now(),
145            }
146        });
147
148        if ip_limiter.limiter.check().is_err() {
149            drop(limiters); // Release lock before recording violation
150            self.record_violation(ip, "Per-IP rate limit exceeded");
151            return Err(SecurityError::RateLimitExceeded(format!(
152                "Too many requests from IP {}. Please try again later",
153                ip
154            )));
155        }
156
157        Ok(())
158    }
159
160    /// Check if an IP is banned
161    fn is_banned(&self, ip: IpAddr) -> bool {
162        let banned = self.banned_ips.read().unwrap();
163        if let Some(ban_info) = banned.get(&ip) {
164            let elapsed = ban_info.banned_at.elapsed();
165            let ban_duration = Duration::from_secs(self.config.ban_duration_seconds);
166
167            if elapsed < ban_duration {
168                return true;
169            }
170        }
171        false
172    }
173
174    /// Record a violation
175    fn record_violation(&self, ip: IpAddr, reason: &str) {
176        let mut limiters = self.per_ip_limiters.write().unwrap();
177        if let Some(ip_limiter) = limiters.get_mut(&ip) {
178            ip_limiter.violations += 1;
179            ip_limiter.last_violation = std::time::Instant::now();
180
181            // Ban if threshold exceeded
182            if ip_limiter.violations >= self.config.ban_threshold {
183                let violations = ip_limiter.violations; // Copy before dropping lock
184                drop(limiters); // Release lock
185                self.ban_ip(ip, reason.to_string(), violations);
186            }
187        }
188    }
189
190    /// Ban an IP address
191    fn ban_ip(&self, ip: IpAddr, reason: String, violations: usize) {
192        let mut banned = self.banned_ips.write().unwrap();
193        banned.insert(
194            ip,
195            BanInfo {
196                banned_at: std::time::Instant::now(),
197                reason,
198                violations,
199            },
200        );
201
202        tracing::warn!(
203            ip = %ip,
204            violations = violations,
205            "IP address banned due to rate limit violations"
206        );
207    }
208
209    /// Manually ban an IP
210    pub fn ban(&self, ip: IpAddr, reason: String) {
211        self.ban_ip(ip, reason, 0);
212    }
213
214    /// Unban an IP
215    pub fn unban(&self, ip: IpAddr) {
216        let mut banned = self.banned_ips.write().unwrap();
217        if banned.remove(&ip).is_some() {
218            tracing::info!(ip = %ip, "IP address unbanned");
219        }
220    }
221
222    /// Get banned IPs
223    pub fn get_banned_ips(&self) -> Vec<(IpAddr, BanInfo)> {
224        let banned = self.banned_ips.read().unwrap();
225        banned
226            .iter()
227            .map(|(ip, info)| (*ip, info.clone()))
228            .collect()
229    }
230
231    /// Clean up expired bans and old limiters
232    pub fn cleanup(&self) {
233        // Remove expired bans
234        let mut banned = self.banned_ips.write().unwrap();
235        let ban_duration = Duration::from_secs(self.config.ban_duration_seconds);
236        banned.retain(|_, ban_info| ban_info.banned_at.elapsed() < ban_duration);
237
238        // Remove old limiters (not accessed in last hour)
239        let mut limiters = self.per_ip_limiters.write().unwrap();
240        limiters.retain(|_, ip_limiter| {
241            ip_limiter.last_violation.elapsed() < Duration::from_secs(3600)
242        });
243    }
244
245    /// Get current statistics
246    pub fn get_stats(&self) -> RateLimitStats {
247        let banned = self.banned_ips.read().unwrap();
248        let limiters = self.per_ip_limiters.read().unwrap();
249
250        RateLimitStats {
251            active_limiters: limiters.len(),
252            banned_ips: banned.len(),
253            total_violations: limiters.values().map(|l| l.violations).sum(),
254        }
255    }
256}
257
/// Point-in-time snapshot of rate limiter state, as returned by
/// `RateLimiter::get_stats`.
#[derive(Clone, Debug)]
pub struct RateLimitStats {
    /// Number of IPs that currently have a per-IP limiter entry.
    pub active_limiters: usize,
    /// Number of entries in the ban table (expired bans count until cleanup).
    pub banned_ips: usize,
    /// Sum of violation counts across all tracked IPs.
    pub total_violations: usize,
}
265
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_rate_limiter_basic() {
        let config = RateLimitConfig {
            authenticated_rps: 10,
            unauthenticated_rps: 5,
            burst_size: 10,
            ..Default::default()
        };

        let limiter = RateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        // First request should succeed
        assert!(limiter.check_request(ip, true).is_ok());

        // Should allow burst
        for _ in 0..9 {
            assert!(limiter.check_request(ip, true).is_ok());
        }

        // Next request should fail
        assert!(limiter.check_request(ip, true).is_err());
    }

    #[test]
    fn test_per_ip_limiting() {
        // The shared unauthenticated burst (burst_size / 5 = 20) must exceed the
        // per-IP burst (10). Otherwise ip1 alone drains the global budget and
        // ip2 is rejected by the global limit, which says nothing about per-IP
        // isolation.
        let config = RateLimitConfig {
            authenticated_rps: 100,
            unauthenticated_rps: 10,
            burst_size: 100,
            ..Default::default()
        };

        let limiter = RateLimiter::new(config);
        let ip1 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2));

        // Push ip1 past its per-IP burst while staying inside the global burst.
        for _ in 0..15 {
            limiter.check_request(ip1, false).ok();
        }

        // ip1 is throttled...
        assert!(limiter.check_request(ip1, false).is_err());
        // ...but ip2 has its own budget and is still allowed.
        assert!(limiter.check_request(ip2, false).is_ok());
    }

    #[test]
    fn test_banning() {
        let config = RateLimitConfig {
            authenticated_rps: 5,
            unauthenticated_rps: 5,
            burst_size: 10,
            ban_threshold: 3,
            ban_duration_seconds: 1,
            ..Default::default()
        };

        let limiter = RateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));

        // Exceed rate limit multiple times to get banned
        for _ in 0..20 {
            limiter.check_request(ip, false).ok();
        }

        // Should be banned now
        assert!(limiter.is_banned(ip));

        // Wait for ban to expire
        thread::sleep(Duration::from_secs(2));

        // Clean up expired bans
        limiter.cleanup();

        // Should no longer be banned
        assert!(!limiter.is_banned(ip));
    }

    #[test]
    fn test_manual_ban() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));

        // Manually ban IP
        limiter.ban(ip, "Test ban".to_string());
        assert!(limiter.is_banned(ip));

        // Unban IP
        limiter.unban(ip);
        assert!(!limiter.is_banned(ip));
    }

    #[test]
    fn test_get_banned_ips_reports_manual_ban() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));

        limiter.ban(ip, "Test ban".to_string());

        // A banned IP is rejected outright.
        assert!(limiter.check_request(ip, false).is_err());

        let banned = limiter.get_banned_ips();
        assert_eq!(banned.len(), 1);
        assert_eq!(banned[0].0, ip);
        assert_eq!(banned[0].1.reason, "Test ban");
        assert_eq!(banned[0].1.violations, 0);
        assert_eq!(limiter.get_stats().banned_ips, 1);
    }

    #[test]
    fn test_stats() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));

        limiter.check_request(ip, false).ok();

        let stats = limiter.get_stats();
        assert_eq!(stats.active_limiters, 1);
    }

    #[test]
    fn test_cleanup() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));

        limiter.check_request(ip, false).ok();
        assert_eq!(limiter.get_stats().active_limiters, 1);

        limiter.cleanup();
        // Should still have the limiter as it was just created
        assert_eq!(limiter.get_stats().active_limiters, 1);
    }
}