use crate::security::{RateLimitConfig, Result, SecurityError};
use axum::{
    extract::{ConnectInfo, Request, State},
    http::{HeaderMap, StatusCode},
    middleware::Next,
    response::Response,
};
use governor::{
    clock::DefaultClock,
    middleware::NoOpMiddleware,
    state::{InMemoryState, NotKeyed},
    Quota, RateLimiter as GovernorRateLimiter,
};
use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use std::net::{IpAddr, SocketAddr};
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
20
21type GovernorLimiter = GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>;
23type IpLimiters = Arc<RwLock<HashMap<IpAddr, Arc<GovernorLimiter>>>>;
24type UserLimiters = Arc<RwLock<HashMap<String, Arc<GovernorLimiter>>>>;
25
26pub struct RateLimitManager {
28 config: RateLimitConfig,
29 ip_limiters: IpLimiters,
30 user_limiters: UserLimiters,
31 global_limiter: Option<Arc<GovernorLimiter>>,
32}
33
34impl RateLimitManager {
35 pub fn new(config: RateLimitConfig) -> Self {
36 let global_limiter = if config.enabled {
37 let quota = Quota::per_minute(
38 NonZeroU32::new(config.requests_per_minute)
39 .unwrap_or(NonZeroU32::new(100).unwrap()),
40 );
41 Some(Arc::new(GovernorRateLimiter::direct(quota)))
42 } else {
43 None
44 };
45
46 Self {
47 config,
48 ip_limiters: Arc::new(RwLock::new(HashMap::new())),
49 user_limiters: Arc::new(RwLock::new(HashMap::new())),
50 global_limiter,
51 }
52 }
53
54 pub async fn check_ip_limit(&self, ip: IpAddr) -> Result<()> {
56 if !self.config.enabled || !self.config.per_ip {
57 return Ok(());
58 }
59
60 let ip_str = ip.to_string();
62 if self.config.whitelist_ips.contains(&ip_str) {
63 debug!("IP {} is whitelisted, bypassing rate limit", ip);
64 return Ok(());
65 }
66
67 let mut limiters = self.ip_limiters.write().await;
68
69 let limiter = limiters.entry(ip).or_insert_with(|| {
70 let quota = Quota::per_minute(
71 NonZeroU32::new(self.config.requests_per_minute)
72 .unwrap_or(NonZeroU32::new(100).unwrap()),
73 )
74 .allow_burst(
75 NonZeroU32::new(self.config.burst_size).unwrap_or(NonZeroU32::new(10).unwrap()),
76 );
77 Arc::new(GovernorRateLimiter::direct(quota))
78 });
79
80 let limiter = Arc::clone(limiter);
81 drop(limiters); match limiter.check() {
84 Ok(_) => {
85 debug!("Rate limit check passed for IP: {}", ip);
86 Ok(())
87 }
88 Err(_) => {
89 warn!("Rate limit exceeded for IP: {}", ip);
90 Err(SecurityError::RateLimitExceeded)
91 }
92 }
93 }
94
95 pub async fn check_user_limit(&self, user_id: &str) -> Result<()> {
97 if !self.config.enabled || !self.config.per_user {
98 return Ok(());
99 }
100
101 let mut limiters = self.user_limiters.write().await;
102
103 let limiter = limiters.entry(user_id.to_string()).or_insert_with(|| {
104 let quota = Quota::per_minute(
105 NonZeroU32::new(self.config.requests_per_minute)
106 .unwrap_or(NonZeroU32::new(100).unwrap()),
107 )
108 .allow_burst(
109 NonZeroU32::new(self.config.burst_size).unwrap_or(NonZeroU32::new(10).unwrap()),
110 );
111 Arc::new(GovernorRateLimiter::direct(quota))
112 });
113
114 let limiter = Arc::clone(limiter);
115 drop(limiters); match limiter.check() {
118 Ok(_) => {
119 debug!("Rate limit check passed for user: {}", user_id);
120 Ok(())
121 }
122 Err(_) => {
123 warn!("Rate limit exceeded for user: {}", user_id);
124 Err(SecurityError::RateLimitExceeded)
125 }
126 }
127 }
128
129 pub async fn check_global_limit(&self) -> Result<()> {
131 if !self.config.enabled {
132 return Ok(());
133 }
134
135 if let Some(limiter) = &self.global_limiter {
136 match limiter.check() {
137 Ok(_) => {
138 debug!("Global rate limit check passed");
139 Ok(())
140 }
141 Err(_) => {
142 warn!("Global rate limit exceeded");
143 Err(SecurityError::RateLimitExceeded)
144 }
145 }
146 } else {
147 Ok(())
148 }
149 }
150
151 pub async fn cleanup_limiters(&self) -> Result<()> {
153 let mut ip_limiters = self.ip_limiters.write().await;
154 let mut user_limiters = self.user_limiters.write().await;
155
156 let initial_ip_count = ip_limiters.len();
157 let initial_user_count = user_limiters.len();
158
159 ip_limiters.retain(|_, limiter| Arc::strong_count(limiter) > 1);
162 user_limiters.retain(|_, limiter| Arc::strong_count(limiter) > 1);
163
164 let cleaned_ip = initial_ip_count - ip_limiters.len();
165 let cleaned_user = initial_user_count - user_limiters.len();
166
167 if cleaned_ip > 0 || cleaned_user > 0 {
168 info!(
169 "Cleaned up {} IP limiters and {} user limiters",
170 cleaned_ip, cleaned_user
171 );
172 }
173
174 Ok(())
175 }
176
177 pub async fn get_statistics(&self) -> RateLimitStatistics {
179 let ip_limiters = self.ip_limiters.read().await;
180 let user_limiters = self.user_limiters.read().await;
181
182 RateLimitStatistics {
183 enabled: self.config.enabled,
184 requests_per_minute: self.config.requests_per_minute,
185 burst_size: self.config.burst_size,
186 active_ip_limiters: ip_limiters.len(),
187 active_user_limiters: user_limiters.len(),
188 per_ip_enabled: self.config.per_ip,
189 per_user_enabled: self.config.per_user,
190 whitelist_count: self.config.whitelist_ips.len(),
191 }
192 }
193
194 pub fn is_enabled(&self) -> bool {
195 self.config.enabled
196 }
197}
198
199#[derive(Debug, Clone, serde::Serialize)]
200pub struct RateLimitStatistics {
201 pub enabled: bool,
202 pub requests_per_minute: u32,
203 pub burst_size: u32,
204 pub active_ip_limiters: usize,
205 pub active_user_limiters: usize,
206 pub per_ip_enabled: bool,
207 pub per_user_enabled: bool,
208 pub whitelist_count: usize,
209}
210
211pub async fn rate_limit_middleware(
213 State(rate_limiter): State<Arc<RateLimitManager>>,
214 ConnectInfo(addr): ConnectInfo<SocketAddr>,
215 headers: HeaderMap,
216 request: Request,
217 next: Next,
218) -> std::result::Result<Response, StatusCode> {
219 if !rate_limiter.is_enabled() {
220 return Ok(next.run(request).await);
221 }
222
223 if rate_limiter.check_global_limit().await.is_err() {
225 warn!("Global rate limit exceeded");
226 return Err(StatusCode::TOO_MANY_REQUESTS);
227 }
228
229 let ip = addr.ip();
231 if rate_limiter.check_ip_limit(ip).await.is_err() {
232 warn!("IP rate limit exceeded for: {}", ip);
233 return Err(StatusCode::TOO_MANY_REQUESTS);
234 }
235
236 if rate_limiter.config.per_user {
238 if let Some(user_header) = headers.get("X-User-ID") {
239 if let Ok(user_id) = user_header.to_str() {
240 if rate_limiter.check_user_limit(user_id).await.is_err() {
241 warn!("User rate limit exceeded for: {}", user_id);
242 return Err(StatusCode::TOO_MANY_REQUESTS);
243 }
244 }
245 }
246 }
247
248 debug!("Rate limit checks passed for IP: {}", ip);
249 Ok(next.run(request).await)
250}
251
252pub fn create_rate_limit_middleware(
254 requests_per_minute: u32,
255 burst_size: u32,
256 whitelist_ips: Vec<String>,
257) -> RateLimitManager {
258 let config = RateLimitConfig {
259 enabled: true,
260 requests_per_minute,
261 burst_size,
262 per_ip: true,
263 per_user: true,
264 whitelist_ips,
265 };
266
267 RateLimitManager::new(config)
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Builds an enabled configuration with the given limits and feature toggles.
    fn enabled_config(
        requests_per_minute: u32,
        burst_size: u32,
        per_ip: bool,
        per_user: bool,
        whitelist_ips: Vec<String>,
    ) -> RateLimitConfig {
        RateLimitConfig {
            enabled: true,
            requests_per_minute,
            burst_size,
            per_ip,
            per_user,
            whitelist_ips,
        }
    }

    /// Address in 192.168.1.0/24 with the given final octet.
    fn lan_ip(last_octet: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, last_octet))
    }

    #[tokio::test]
    async fn test_rate_limit_manager_creation() {
        let manager = RateLimitManager::new(RateLimitConfig::default());
        assert!(!manager.is_enabled());
    }

    #[tokio::test]
    async fn test_disabled_rate_limiting() {
        let manager = RateLimitManager::new(RateLimitConfig::default());
        assert!(manager.check_ip_limit(lan_ip(1)).await.is_ok());
    }

    #[tokio::test]
    async fn test_ip_whitelist() {
        let manager =
            RateLimitManager::new(enabled_config(1, 1, true, false, vec!["192.168.1.1".to_string()]));
        assert!(manager.check_ip_limit(lan_ip(1)).await.is_ok());
    }

    #[tokio::test]
    async fn test_rate_limit_exceeded() {
        let manager = RateLimitManager::new(enabled_config(1, 1, true, false, Vec::new()));
        let ip = lan_ip(2);

        assert!(manager.check_ip_limit(ip).await.is_ok());

        let second = manager.check_ip_limit(ip).await;
        assert!(
            matches!(second, Err(SecurityError::RateLimitExceeded)),
            "Expected RateLimitExceeded error"
        );
    }

    #[tokio::test]
    async fn test_user_rate_limiting() {
        let manager = RateLimitManager::new(enabled_config(1, 1, false, true, Vec::new()));
        let user_id = "test-user";

        assert!(manager.check_user_limit(user_id).await.is_ok());
        assert!(manager.check_user_limit(user_id).await.is_err());
    }

    #[tokio::test]
    async fn test_global_rate_limiting() {
        let manager = RateLimitManager::new(enabled_config(1, 1, false, false, Vec::new()));

        assert!(manager.check_global_limit().await.is_ok());
        assert!(manager.check_global_limit().await.is_err());
    }

    #[tokio::test]
    async fn test_statistics() {
        let manager =
            RateLimitManager::new(enabled_config(100, 10, true, true, vec!["127.0.0.1".to_string()]));
        let stats = manager.get_statistics().await;

        assert!(stats.enabled);
        assert_eq!(stats.requests_per_minute, 100);
        assert_eq!(stats.burst_size, 10);
        assert!(stats.per_ip_enabled);
        assert!(stats.per_user_enabled);
        assert_eq!(stats.whitelist_count, 1);
        assert_eq!(stats.active_ip_limiters, 0);
        assert_eq!(stats.active_user_limiters, 0);
    }

    #[tokio::test]
    async fn test_limiter_cleanup() {
        let manager = RateLimitManager::new(enabled_config(100, 10, true, true, Vec::new()));

        let _ = manager.check_ip_limit(lan_ip(1)).await;
        let _ = manager.check_user_limit("test-user").await;

        let before = manager.get_statistics().await;
        assert!(before.active_ip_limiters > 0 || before.active_user_limiters > 0);

        assert!(manager.cleanup_limiters().await.is_ok());
    }

    #[test]
    fn test_custom_rate_limiter_creation() {
        let whitelist = vec!["127.0.0.1".to_string(), "::1".to_string()];
        let manager = create_rate_limit_middleware(200, 20, whitelist);

        assert!(manager.is_enabled());
        assert_eq!(manager.config.requests_per_minute, 200);
        assert_eq!(manager.config.burst_size, 20);
        assert_eq!(manager.config.whitelist_ips.len(), 2);
    }
}