use crate::security::{RateLimitConfig, Result, SecurityError};
use axum::{
    extract::{ConnectInfo, Request, State},
    http::{HeaderMap, StatusCode},
    middleware::Next,
    response::Response,
};
use governor::{
    clock::DefaultClock,
    middleware::NoOpMiddleware,
    state::{InMemoryState, NotKeyed},
    Quota, RateLimiter as GovernorRateLimiter,
};
use std::borrow::Borrow;
use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use std::net::{IpAddr, SocketAddr};
use std::num::NonZeroU32;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
20
21pub struct RateLimitManager {
23 config: RateLimitConfig,
24 ip_limiters: Arc<
25 RwLock<
26 HashMap<
27 IpAddr,
28 Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>>,
29 >,
30 >,
31 >,
32 user_limiters: Arc<
33 RwLock<
34 HashMap<
35 String,
36 Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>>,
37 >,
38 >,
39 >,
40 global_limiter:
41 Option<Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>>>,
42}
43
44impl RateLimitManager {
45 pub fn new(config: RateLimitConfig) -> Self {
46 let global_limiter = if config.enabled {
47 let quota = Quota::per_minute(
48 NonZeroU32::new(config.requests_per_minute)
49 .unwrap_or(NonZeroU32::new(100).unwrap()),
50 );
51 Some(Arc::new(GovernorRateLimiter::direct(quota)))
52 } else {
53 None
54 };
55
56 Self {
57 config,
58 ip_limiters: Arc::new(RwLock::new(HashMap::new())),
59 user_limiters: Arc::new(RwLock::new(HashMap::new())),
60 global_limiter,
61 }
62 }
63
64 pub async fn check_ip_limit(&self, ip: IpAddr) -> Result<()> {
66 if !self.config.enabled || !self.config.per_ip {
67 return Ok(());
68 }
69
70 let ip_str = ip.to_string();
72 if self.config.whitelist_ips.contains(&ip_str) {
73 debug!("IP {} is whitelisted, bypassing rate limit", ip);
74 return Ok(());
75 }
76
77 let mut limiters = self.ip_limiters.write().await;
78
79 let limiter = limiters.entry(ip).or_insert_with(|| {
80 let quota = Quota::per_minute(
81 NonZeroU32::new(self.config.requests_per_minute)
82 .unwrap_or(NonZeroU32::new(100).unwrap()),
83 )
84 .allow_burst(
85 NonZeroU32::new(self.config.burst_size).unwrap_or(NonZeroU32::new(10).unwrap()),
86 );
87 Arc::new(GovernorRateLimiter::direct(quota))
88 });
89
90 let limiter = Arc::clone(limiter);
91 drop(limiters); match limiter.check() {
94 Ok(_) => {
95 debug!("Rate limit check passed for IP: {}", ip);
96 Ok(())
97 }
98 Err(_) => {
99 warn!("Rate limit exceeded for IP: {}", ip);
100 Err(SecurityError::RateLimitExceeded)
101 }
102 }
103 }
104
105 pub async fn check_user_limit(&self, user_id: &str) -> Result<()> {
107 if !self.config.enabled || !self.config.per_user {
108 return Ok(());
109 }
110
111 let mut limiters = self.user_limiters.write().await;
112
113 let limiter = limiters.entry(user_id.to_string()).or_insert_with(|| {
114 let quota = Quota::per_minute(
115 NonZeroU32::new(self.config.requests_per_minute)
116 .unwrap_or(NonZeroU32::new(100).unwrap()),
117 )
118 .allow_burst(
119 NonZeroU32::new(self.config.burst_size).unwrap_or(NonZeroU32::new(10).unwrap()),
120 );
121 Arc::new(GovernorRateLimiter::direct(quota))
122 });
123
124 let limiter = Arc::clone(limiter);
125 drop(limiters); match limiter.check() {
128 Ok(_) => {
129 debug!("Rate limit check passed for user: {}", user_id);
130 Ok(())
131 }
132 Err(_) => {
133 warn!("Rate limit exceeded for user: {}", user_id);
134 Err(SecurityError::RateLimitExceeded)
135 }
136 }
137 }
138
139 pub async fn check_global_limit(&self) -> Result<()> {
141 if !self.config.enabled {
142 return Ok(());
143 }
144
145 if let Some(limiter) = &self.global_limiter {
146 match limiter.check() {
147 Ok(_) => {
148 debug!("Global rate limit check passed");
149 Ok(())
150 }
151 Err(_) => {
152 warn!("Global rate limit exceeded");
153 Err(SecurityError::RateLimitExceeded)
154 }
155 }
156 } else {
157 Ok(())
158 }
159 }
160
161 pub async fn cleanup_limiters(&self) -> Result<()> {
163 let mut ip_limiters = self.ip_limiters.write().await;
164 let mut user_limiters = self.user_limiters.write().await;
165
166 let initial_ip_count = ip_limiters.len();
167 let initial_user_count = user_limiters.len();
168
169 ip_limiters.retain(|_, limiter| Arc::strong_count(limiter) > 1);
172 user_limiters.retain(|_, limiter| Arc::strong_count(limiter) > 1);
173
174 let cleaned_ip = initial_ip_count - ip_limiters.len();
175 let cleaned_user = initial_user_count - user_limiters.len();
176
177 if cleaned_ip > 0 || cleaned_user > 0 {
178 info!(
179 "Cleaned up {} IP limiters and {} user limiters",
180 cleaned_ip, cleaned_user
181 );
182 }
183
184 Ok(())
185 }
186
187 pub async fn get_statistics(&self) -> RateLimitStatistics {
189 let ip_limiters = self.ip_limiters.read().await;
190 let user_limiters = self.user_limiters.read().await;
191
192 RateLimitStatistics {
193 enabled: self.config.enabled,
194 requests_per_minute: self.config.requests_per_minute,
195 burst_size: self.config.burst_size,
196 active_ip_limiters: ip_limiters.len(),
197 active_user_limiters: user_limiters.len(),
198 per_ip_enabled: self.config.per_ip,
199 per_user_enabled: self.config.per_user,
200 whitelist_count: self.config.whitelist_ips.len(),
201 }
202 }
203
204 pub fn is_enabled(&self) -> bool {
205 self.config.enabled
206 }
207}
208
209#[derive(Debug, Clone, serde::Serialize)]
210pub struct RateLimitStatistics {
211 pub enabled: bool,
212 pub requests_per_minute: u32,
213 pub burst_size: u32,
214 pub active_ip_limiters: usize,
215 pub active_user_limiters: usize,
216 pub per_ip_enabled: bool,
217 pub per_user_enabled: bool,
218 pub whitelist_count: usize,
219}
220
221pub async fn rate_limit_middleware(
223 State(rate_limiter): State<Arc<RateLimitManager>>,
224 ConnectInfo(addr): ConnectInfo<SocketAddr>,
225 headers: HeaderMap,
226 request: Request,
227 next: Next,
228) -> std::result::Result<Response, StatusCode> {
229 if !rate_limiter.is_enabled() {
230 return Ok(next.run(request).await);
231 }
232
233 if let Err(_) = rate_limiter.check_global_limit().await {
235 warn!("Global rate limit exceeded");
236 return Err(StatusCode::TOO_MANY_REQUESTS);
237 }
238
239 let ip = addr.ip();
241 if let Err(_) = rate_limiter.check_ip_limit(ip).await {
242 warn!("IP rate limit exceeded for: {}", ip);
243 return Err(StatusCode::TOO_MANY_REQUESTS);
244 }
245
246 if rate_limiter.config.per_user {
248 if let Some(user_header) = headers.get("X-User-ID") {
249 if let Ok(user_id) = user_header.to_str() {
250 if let Err(_) = rate_limiter.check_user_limit(user_id).await {
251 warn!("User rate limit exceeded for: {}", user_id);
252 return Err(StatusCode::TOO_MANY_REQUESTS);
253 }
254 }
255 }
256 }
257
258 debug!("Rate limit checks passed for IP: {}", ip);
259 Ok(next.run(request).await)
260}
261
262pub fn create_rate_limit_middleware(
264 requests_per_minute: u32,
265 burst_size: u32,
266 whitelist_ips: Vec<String>,
267) -> RateLimitManager {
268 let config = RateLimitConfig {
269 enabled: true,
270 requests_per_minute,
271 burst_size,
272 per_ip: true,
273 per_user: true,
274 whitelist_ips,
275 };
276
277 RateLimitManager::new(config)
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Enabled configuration with the given rate, burst, scope flags and whitelist.
    fn enabled_config(
        requests_per_minute: u32,
        burst_size: u32,
        per_ip: bool,
        per_user: bool,
        whitelist_ips: Vec<String>,
    ) -> RateLimitConfig {
        RateLimitConfig {
            enabled: true,
            requests_per_minute,
            burst_size,
            per_ip,
            per_user,
            whitelist_ips,
        }
    }

    /// An address in 192.168.1.0/24 with the given final octet.
    fn lan_ip(last_octet: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, last_octet))
    }

    #[tokio::test]
    async fn test_rate_limit_manager_creation() {
        let manager = RateLimitManager::new(RateLimitConfig::default());
        assert!(!manager.is_enabled());
    }

    #[tokio::test]
    async fn test_disabled_rate_limiting() {
        let manager = RateLimitManager::new(RateLimitConfig::default());
        assert!(manager.check_ip_limit(lan_ip(1)).await.is_ok());
    }

    #[tokio::test]
    async fn test_ip_whitelist() {
        let whitelist = vec!["192.168.1.1".to_string()];
        let manager = RateLimitManager::new(enabled_config(1, 1, true, false, whitelist));
        assert!(manager.check_ip_limit(lan_ip(1)).await.is_ok());
    }

    #[tokio::test]
    async fn test_rate_limit_exceeded() {
        let manager = RateLimitManager::new(enabled_config(1, 1, true, false, Vec::new()));
        let ip = lan_ip(2);

        assert!(manager.check_ip_limit(ip).await.is_ok());

        let second = manager.check_ip_limit(ip).await;
        assert!(second.is_err());
        assert!(
            matches!(second, Err(SecurityError::RateLimitExceeded)),
            "Expected RateLimitExceeded error"
        );
    }

    #[tokio::test]
    async fn test_user_rate_limiting() {
        let manager = RateLimitManager::new(enabled_config(1, 1, false, true, Vec::new()));
        let user = "test-user";

        assert!(manager.check_user_limit(user).await.is_ok());
        assert!(manager.check_user_limit(user).await.is_err());
    }

    #[tokio::test]
    async fn test_global_rate_limiting() {
        let manager = RateLimitManager::new(enabled_config(1, 1, false, false, Vec::new()));

        assert!(manager.check_global_limit().await.is_ok());
        assert!(manager.check_global_limit().await.is_err());
    }

    #[tokio::test]
    async fn test_statistics() {
        let whitelist = vec!["127.0.0.1".to_string()];
        let manager = RateLimitManager::new(enabled_config(100, 10, true, true, whitelist));
        let stats = manager.get_statistics().await;

        assert!(stats.enabled);
        assert_eq!(stats.requests_per_minute, 100);
        assert_eq!(stats.burst_size, 10);
        assert!(stats.per_ip_enabled);
        assert!(stats.per_user_enabled);
        assert_eq!(stats.whitelist_count, 1);
        assert_eq!(stats.active_ip_limiters, 0);
        assert_eq!(stats.active_user_limiters, 0);
    }

    #[tokio::test]
    async fn test_limiter_cleanup() {
        let manager = RateLimitManager::new(enabled_config(100, 10, true, true, Vec::new()));

        let _ = manager.check_ip_limit(lan_ip(1)).await;
        let _ = manager.check_user_limit("test-user").await;

        let before = manager.get_statistics().await;
        assert!(before.active_ip_limiters > 0 || before.active_user_limiters > 0);

        assert!(manager.cleanup_limiters().await.is_ok());
    }

    #[test]
    fn test_custom_rate_limiter_creation() {
        let whitelist = vec!["127.0.0.1".to_string(), "::1".to_string()];
        let manager = create_rate_limit_middleware(200, 20, whitelist);

        assert!(manager.is_enabled());
        assert_eq!(manager.config.requests_per_minute, 200);
        assert_eq!(manager.config.burst_size, 20);
        assert_eq!(manager.config.whitelist_ips.len(), 2);
    }
}