1use crate::errors::{SecurityError, SecurityResult};
4use governor::{
5 clock::DefaultClock,
6 state::{InMemoryState, NotKeyed},
7 Quota, RateLimiter as GovernorRateLimiter,
8};
9use std::collections::HashMap;
10use std::net::IpAddr;
11use std::num::NonZeroU32;
12use std::sync::{Arc, RwLock};
13use std::time::Duration;
14
/// Tunable parameters for [`RateLimiter`].
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Requests per second allowed by the shared limiter for authenticated
    /// traffic. The per-IP rate for authenticated clients is derived from a
    /// tenth of this value.
    pub authenticated_rps: u32,
    /// Requests per second allowed by the shared limiter for unauthenticated
    /// traffic; also used as the per-IP rate for unauthenticated clients.
    pub unauthenticated_rps: u32,
    /// Burst capacity of the shared authenticated limiter. The shared
    /// unauthenticated limiter is given a fifth of this value.
    pub burst_size: u32,
    /// Window length in seconds.
    // NOTE(review): no code visible in this file reads this field - confirm
    // whether it is consumed elsewhere or is vestigial.
    pub window_seconds: u64,
    /// Number of seconds a ban stays in effect after it is issued.
    pub ban_duration_seconds: u64,
    /// Number of recorded violations at which an IP is banned.
    pub ban_threshold: usize,
}
31
32impl Default for RateLimitConfig {
33 fn default() -> Self {
34 Self {
35 authenticated_rps: 100,
36 unauthenticated_rps: 10,
37 burst_size: 50,
38 window_seconds: 60,
39 ban_duration_seconds: 3600, ban_threshold: 10,
41 }
42 }
43}
44
45pub struct RateLimiter {
47 config: RateLimitConfig,
48 authenticated_limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
49 unauthenticated_limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
50 per_ip_limiters: Arc<RwLock<HashMap<IpAddr, IpLimiter>>>,
51 banned_ips: Arc<RwLock<HashMap<IpAddr, BanInfo>>>,
52}
53
54#[derive(Debug, Clone)]
55struct IpLimiter {
56 limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
57 violations: usize,
58 last_violation: std::time::Instant,
59}
60
/// Record describing why and when an IP address was banned.
#[derive(Clone, Debug)]
pub struct BanInfo {
    /// Moment the ban was issued.
    pub banned_at: std::time::Instant,
    /// Reason supplied when the ban was issued.
    pub reason: String,
    /// Violation count supplied when the ban was issued.
    pub violations: usize,
}
71
72impl RateLimiter {
73 pub fn new(config: RateLimitConfig) -> Self {
75 let authenticated_quota = Quota::per_second(
76 NonZeroU32::new(config.authenticated_rps).unwrap_or(NonZeroU32::new(100).unwrap())
77 ).allow_burst(
78 NonZeroU32::new(config.burst_size).unwrap_or(NonZeroU32::new(50).unwrap())
79 );
80
81 let unauthenticated_quota = Quota::per_second(
82 NonZeroU32::new(config.unauthenticated_rps).unwrap_or(NonZeroU32::new(10).unwrap())
83 ).allow_burst(
84 NonZeroU32::new(config.burst_size / 5).unwrap_or(NonZeroU32::new(10).unwrap())
85 );
86
87 Self {
88 config,
89 authenticated_limiter: Arc::new(GovernorRateLimiter::direct(authenticated_quota)),
90 unauthenticated_limiter: Arc::new(GovernorRateLimiter::direct(
91 unauthenticated_quota,
92 )),
93 per_ip_limiters: Arc::new(RwLock::new(HashMap::new())),
94 banned_ips: Arc::new(RwLock::new(HashMap::new())),
95 }
96 }
97
98 pub fn check_request(
100 &self,
101 ip: IpAddr,
102 authenticated: bool,
103 ) -> SecurityResult<()> {
104 if self.is_banned(ip) {
106 return Err(SecurityError::RateLimitExceeded(
107 "IP address is temporarily banned".to_string(),
108 ));
109 }
110
111 let limiter = if authenticated {
113 &self.authenticated_limiter
114 } else {
115 &self.unauthenticated_limiter
116 };
117
118 if limiter.check().is_err() {
119 self.record_violation(ip, "Global rate limit exceeded");
120 return Err(SecurityError::RateLimitExceeded(
121 "Too many requests. Please try again later".to_string(),
122 ));
123 }
124
125 let mut limiters = self.per_ip_limiters.write().unwrap();
127 let ip_limiter = limiters.entry(ip).or_insert_with(|| {
128 let quota = if authenticated {
129 Quota::per_second(
130 NonZeroU32::new(self.config.authenticated_rps / 10)
131 .unwrap_or(NonZeroU32::new(10).unwrap())
132 )
133 } else {
134 Quota::per_second(
135 NonZeroU32::new(self.config.unauthenticated_rps)
136 .unwrap_or(NonZeroU32::new(10).unwrap())
137 )
138 }
139 .allow_burst(NonZeroU32::new(10).unwrap());
140
141 IpLimiter {
142 limiter: Arc::new(GovernorRateLimiter::direct(quota)),
143 violations: 0,
144 last_violation: std::time::Instant::now(),
145 }
146 });
147
148 if ip_limiter.limiter.check().is_err() {
149 drop(limiters); self.record_violation(ip, "Per-IP rate limit exceeded");
151 return Err(SecurityError::RateLimitExceeded(format!(
152 "Too many requests from IP {}. Please try again later",
153 ip
154 )));
155 }
156
157 Ok(())
158 }
159
160 fn is_banned(&self, ip: IpAddr) -> bool {
162 let banned = self.banned_ips.read().unwrap();
163 if let Some(ban_info) = banned.get(&ip) {
164 let elapsed = ban_info.banned_at.elapsed();
165 let ban_duration = Duration::from_secs(self.config.ban_duration_seconds);
166
167 if elapsed < ban_duration {
168 return true;
169 }
170 }
171 false
172 }
173
174 fn record_violation(&self, ip: IpAddr, reason: &str) {
176 let mut limiters = self.per_ip_limiters.write().unwrap();
177 if let Some(ip_limiter) = limiters.get_mut(&ip) {
178 ip_limiter.violations += 1;
179 ip_limiter.last_violation = std::time::Instant::now();
180
181 if ip_limiter.violations >= self.config.ban_threshold {
183 let violations = ip_limiter.violations; drop(limiters); self.ban_ip(ip, reason.to_string(), violations);
186 }
187 }
188 }
189
190 fn ban_ip(&self, ip: IpAddr, reason: String, violations: usize) {
192 let mut banned = self.banned_ips.write().unwrap();
193 banned.insert(
194 ip,
195 BanInfo {
196 banned_at: std::time::Instant::now(),
197 reason,
198 violations,
199 },
200 );
201
202 tracing::warn!(
203 ip = %ip,
204 violations = violations,
205 "IP address banned due to rate limit violations"
206 );
207 }
208
209 pub fn ban(&self, ip: IpAddr, reason: String) {
211 self.ban_ip(ip, reason, 0);
212 }
213
214 pub fn unban(&self, ip: IpAddr) {
216 let mut banned = self.banned_ips.write().unwrap();
217 if banned.remove(&ip).is_some() {
218 tracing::info!(ip = %ip, "IP address unbanned");
219 }
220 }
221
222 pub fn get_banned_ips(&self) -> Vec<(IpAddr, BanInfo)> {
224 let banned = self.banned_ips.read().unwrap();
225 banned
226 .iter()
227 .map(|(ip, info)| (*ip, info.clone()))
228 .collect()
229 }
230
231 pub fn cleanup(&self) {
233 let mut banned = self.banned_ips.write().unwrap();
235 let ban_duration = Duration::from_secs(self.config.ban_duration_seconds);
236 banned.retain(|_, ban_info| ban_info.banned_at.elapsed() < ban_duration);
237
238 let mut limiters = self.per_ip_limiters.write().unwrap();
240 limiters.retain(|_, ip_limiter| {
241 ip_limiter.last_violation.elapsed() < Duration::from_secs(3600)
242 });
243 }
244
245 pub fn get_stats(&self) -> RateLimitStats {
247 let banned = self.banned_ips.read().unwrap();
248 let limiters = self.per_ip_limiters.read().unwrap();
249
250 RateLimitStats {
251 active_limiters: limiters.len(),
252 banned_ips: banned.len(),
253 total_violations: limiters.values().map(|l| l.violations).sum(),
254 }
255 }
256}
257
/// Point-in-time counters reported by `RateLimiter::get_stats`.
#[derive(Clone, Debug)]
pub struct RateLimitStats {
    /// Number of per-IP limiters currently tracked.
    pub active_limiters: usize,
    /// Number of entries counted from the ban table.
    pub banned_ips: usize,
    /// Sum of violation counts across all tracked per-IP limiters.
    pub total_violations: usize,
}
265
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;
    use std::thread;
    use std::time::Duration;

    /// The shared authenticated limiter allows its burst and then rejects.
    #[test]
    fn test_rate_limiter_basic() {
        let config = RateLimitConfig {
            authenticated_rps: 10,
            unauthenticated_rps: 5,
            burst_size: 10,
            ..Default::default()
        };

        let limiter = RateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        assert!(limiter.check_request(ip, true).is_ok());

        for _ in 0..9 {
            assert!(limiter.check_request(ip, true).is_ok());
        }

        // The burst of 10 is spent; the next request must be rejected.
        assert!(limiter.check_request(ip, true).is_err());
    }

    /// Exhausting one address's per-IP quota must not affect another address.
    #[test]
    fn test_per_ip_limiting() {
        // burst_size 100 gives the shared unauthenticated limiter a burst of
        // 20, so the 12 requests below never exhaust it and every rejection
        // observed here comes from the per-IP limiter (burst of 10).
        let config = RateLimitConfig {
            authenticated_rps: 100,
            unauthenticated_rps: 10,
            burst_size: 100,
            ..Default::default()
        };

        let limiter = RateLimiter::new(config);
        let ip1 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2));

        for _ in 0..10 {
            assert!(limiter.check_request(ip1, false).is_ok());
        }

        // ip1 has spent its per-IP burst and is throttled...
        assert!(limiter.check_request(ip1, false).is_err());

        // ...while ip2 still has its own, untouched quota.
        assert!(limiter.check_request(ip2, false).is_ok());
    }

    /// Repeated violations ban the address; the ban lapses after its duration.
    #[test]
    fn test_banning() {
        let config = RateLimitConfig {
            authenticated_rps: 5,
            unauthenticated_rps: 5,
            burst_size: 10,
            ban_threshold: 3,
            ban_duration_seconds: 1,
            ..Default::default()
        };

        let limiter = RateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));

        // Far more requests than the limiters admit, so at least
        // `ban_threshold` violations are recorded.
        for _ in 0..20 {
            limiter.check_request(ip, false).ok();
        }

        assert!(limiter.is_banned(ip));

        // Wait past the one-second ban duration.
        thread::sleep(Duration::from_secs(2));

        limiter.cleanup();

        assert!(!limiter.is_banned(ip));
    }

    /// Manual ban and unban take effect immediately.
    #[test]
    fn test_manual_ban() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));

        limiter.ban(ip, "Test ban".to_string());
        assert!(limiter.is_banned(ip));

        limiter.unban(ip);
        assert!(!limiter.is_banned(ip));
    }

    /// A first request from an address creates exactly one per-IP limiter.
    #[test]
    fn test_stats() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));

        limiter.check_request(ip, false).ok();

        let stats = limiter.get_stats();
        assert_eq!(stats.active_limiters, 1);
    }

    /// Cleanup keeps per-IP limiters that are still within the retention period.
    #[test]
    fn test_cleanup() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));

        limiter.check_request(ip, false).ok();
        assert_eq!(limiter.get_stats().active_limiters, 1);

        limiter.cleanup();
        assert_eq!(limiter.get_stats().active_limiters, 1);
    }
}