1#[cfg(feature = "redis")]
46use deadpool_redis::Pool as RedisPool;
47
48use axum::{
49 extract::{ConnectInfo, Request, State},
50 http::StatusCode,
51 middleware::Next,
52 response::{IntoResponse, Response},
53};
54use std::collections::HashMap;
55use std::net::SocketAddr;
56use std::sync::Arc;
57use std::time::{Duration, Instant};
58use tokio::sync::RwLock;
59use tracing::{debug, warn};
60
61use crate::config::RateLimitConfig;
62
/// Counter state for one key of the in-memory (fixed-window) limiter.
#[derive(Clone, Debug)]
struct RateLimitEntry {
    /// Number of requests counted for the key in the current window.
    count: u32,
    /// When the current window for the key started.
    window_start: Instant,
}
71
72type InMemoryStore = Arc<RwLock<HashMap<String, RateLimitEntry>>>;
74
75#[derive(Clone)]
80pub struct RateLimit {
81 config: RateLimitConfig,
82 #[cfg(feature = "redis")]
83 redis_pool: Option<RedisPool>,
84 in_memory_store: InMemoryStore,
85}
86
87impl RateLimit {
88 #[must_use]
116 #[cfg(feature = "redis")]
117 pub fn new(config: RateLimitConfig, redis_pool: Option<RedisPool>) -> Self {
118 Self {
119 config,
120 redis_pool,
121 in_memory_store: Arc::new(RwLock::new(HashMap::new())),
122 }
123 }
124
125 #[must_use]
127 #[cfg(not(feature = "redis"))]
128 pub fn new(config: RateLimitConfig, _redis_pool: Option<()>) -> Self {
129 Self {
130 config,
131 in_memory_store: Arc::new(RwLock::new(HashMap::new())),
132 }
133 }
134
135 pub async fn middleware(
148 State(rate_limit): State<Self>,
149 request: Request,
150 next: Next,
151 ) -> Result<Response, RateLimitError> {
152 if !rate_limit.config.enabled {
154 return Ok(next.run(request).await);
155 }
156
157 let user_id: Option<i64> = request.extensions().get::<i64>().copied();
159
160 let ip_addr = request
162 .extensions()
163 .get::<ConnectInfo<SocketAddr>>()
164 .map(|ConnectInfo(addr)| addr.ip().to_string());
165
166 let path = request.uri().path();
168 let (key, limit) = rate_limit.determine_key_and_limit(user_id, ip_addr.as_deref(), path);
169
170 debug!(
171 key = %key,
172 limit = limit,
173 path = %path,
174 user_id = ?user_id,
175 "Checking rate limit"
176 );
177
178 rate_limit.check_rate_limit(&key, limit).await?;
180
181 Ok(next.run(request).await)
182 }
183
184 fn determine_key_and_limit(
186 &self,
187 user_id: Option<i64>,
188 ip_addr: Option<&str>,
189 path: &str,
190 ) -> (String, u32) {
191 let is_strict_route = self
193 .config
194 .strict_routes
195 .iter()
196 .any(|route| path.starts_with(route));
197
198 if is_strict_route {
199 let key = user_id.map_or_else(|| {
201 ip_addr.map_or_else(|| "ratelimit:route:unknown".to_string(), |ip| format!("ratelimit:route:ip:{ip}"))
202 }, |uid| format!("ratelimit:route:user:{uid}"));
203 (key, self.config.per_route_rpm)
204 } else if let Some(uid) = user_id {
205 (
207 format!("ratelimit:user:{uid}"),
208 self.config.per_user_rpm,
209 )
210 } else if let Some(ip) = ip_addr {
211 (format!("ratelimit:ip:{ip}"), self.config.per_ip_rpm)
213 } else {
214 ("ratelimit:unknown".to_string(), self.config.per_ip_rpm)
216 }
217 }
218
219 async fn check_rate_limit(&self, key: &str, limit: u32) -> Result<(), RateLimitError> {
221 #[cfg(feature = "redis")]
223 if self.config.redis_enabled {
224 if let Some(ref redis_pool) = self.redis_pool {
225 match self.check_rate_limit_redis(redis_pool, key, limit).await {
226 Ok(()) => return Ok(()),
227 Err(e) => {
228 warn!(
229 error = %e,
230 key = %key,
231 "Redis rate limit check failed, falling back to in-memory"
232 );
233 }
235 }
236 }
237 }
238
239 self.check_rate_limit_memory(key, limit).await
241 }
242
243 #[cfg(feature = "redis")]
245 async fn check_rate_limit_redis(
246 &self,
247 redis_pool: &RedisPool,
248 key: &str,
249 limit: u32,
250 ) -> Result<(), RateLimitError> {
251 let mut conn = redis_pool.get().await.map_err(|e| {
252 RateLimitError::Backend(format!("Failed to get Redis connection: {e}"))
253 })?;
254
255 let count: u32 = redis::cmd("INCR")
257 .arg(key)
258 .query_async(&mut *conn)
259 .await
260 .map_err(|e| RateLimitError::Backend(format!("Redis INCR failed: {e}")))?;
261
262 if count == 1 {
264 let expire_secs = i64::try_from(self.config.window_secs).unwrap_or(i64::MAX);
266 let _: () = redis::cmd("EXPIRE")
267 .arg(key)
268 .arg(expire_secs)
269 .query_async(&mut *conn)
270 .await
271 .map_err(|e| RateLimitError::Backend(format!("Redis EXPIRE failed: {e}")))?;
272 }
273
274 if count > limit {
276 warn!(
277 key = %key,
278 count = count,
279 limit = limit,
280 window_secs = self.config.window_secs,
281 "Rate limit exceeded"
282 );
283 return Err(RateLimitError::Exceeded {
284 limit,
285 window: Duration::from_secs(self.config.window_secs),
286 });
287 }
288
289 debug!(
290 key = %key,
291 count = count,
292 limit = limit,
293 "Rate limit check passed (Redis)"
294 );
295
296 Ok(())
297 }
298
299 async fn check_rate_limit_memory(&self, key: &str, limit: u32) -> Result<(), RateLimitError> {
301 let now = Instant::now();
302 let window_duration = Duration::from_secs(self.config.window_secs);
303
304 let mut store = self.in_memory_store.write().await;
306
307 let entry = store.entry(key.to_string()).or_insert_with(|| RateLimitEntry {
309 count: 0,
310 window_start: now,
311 });
312
313 if now.duration_since(entry.window_start) >= window_duration {
315 entry.count = 1;
317 entry.window_start = now;
318 } else {
319 entry.count += 1;
321 }
322
323 let count = entry.count;
324 drop(store); if count > limit {
328 warn!(
329 key = %key,
330 count = count,
331 limit = limit,
332 window_secs = self.config.window_secs,
333 "Rate limit exceeded"
334 );
335 return Err(RateLimitError::Exceeded {
336 limit,
337 window: window_duration,
338 });
339 }
340
341 debug!(
342 key = %key,
343 count = count,
344 limit = limit,
345 "Rate limit check passed (in-memory)"
346 );
347
348 Ok(())
349 }
350
351 pub async fn cleanup_expired(&self) -> usize {
356 let now = Instant::now();
357 let window_duration = Duration::from_secs(self.config.window_secs);
358
359 let removed = {
360 let mut store = self.in_memory_store.write().await;
361 let before_count = store.len();
362
363 store.retain(|_, entry| now.duration_since(entry.window_start) < window_duration);
364
365 before_count - store.len()
366 }; if removed > 0 {
369 debug!(removed = removed, "Cleaned up expired rate limit entries");
370 }
371
372 removed
373 }
374}
375
376#[derive(Debug, thiserror::Error)]
378pub enum RateLimitError {
379 #[error("Rate limit exceeded: {limit} requests per {window:?}")]
381 Exceeded {
382 limit: u32,
384 window: Duration,
386 },
387
388 #[error("Rate limit backend error: {0}")]
390 Backend(String),
391}
392
393impl IntoResponse for RateLimitError {
394 fn into_response(self) -> Response {
395 match self {
396 Self::Exceeded { limit, window } => {
397 let retry_after = window.as_secs();
398 (
399 StatusCode::TOO_MANY_REQUESTS,
400 [
401 ("Retry-After", retry_after.to_string()),
402 (
403 "X-RateLimit-Limit",
404 limit.to_string(),
405 ),
406 ],
407 format!(
408 "Rate limit exceeded. Maximum {} requests per {} seconds.",
409 limit,
410 window.as_secs()
411 ),
412 )
413 .into_response()
414 }
415 Self::Backend(msg) => {
416 warn!(error = %msg, "Rate limit backend error");
417 (
420 StatusCode::INTERNAL_SERVER_ERROR,
421 "Rate limiting temporarily unavailable",
422 )
423 .into_response()
424 }
425 }
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::RateLimitFailureMode;

    /// Enabled, in-memory-only config with small limits, no strict routes and the
    /// given window length.
    fn test_config(window_secs: u64) -> RateLimitConfig {
        RateLimitConfig {
            enabled: true,
            per_user_rpm: 5,
            per_ip_rpm: 3,
            per_route_rpm: 2,
            window_secs,
            redis_enabled: false,
            failure_mode: RateLimitFailureMode::Closed,
            strict_routes: vec![],
        }
    }

    /// Limiter built from `RateLimitConfig::default()`. The key tests below rely
    /// on its limits (120/60/30) and on `/login` and `/register` being strict routes.
    fn default_rate_limit() -> RateLimit {
        RateLimit::new(RateLimitConfig::default(), None)
    }

    #[test]
    fn test_rate_limit_creation() {
        let rate_limit = default_rate_limit();

        assert!(rate_limit.config.enabled);
        assert_eq!(rate_limit.config.per_user_rpm, 120);
        assert_eq!(rate_limit.config.per_ip_rpm, 60);
        assert_eq!(rate_limit.config.per_route_rpm, 30);
    }

    #[test]
    fn test_determine_key_and_limit_authenticated() {
        let rate_limit = default_rate_limit();

        let (key, limit) =
            rate_limit.determine_key_and_limit(Some(123), Some("192.168.1.1"), "/posts");
        assert_eq!(key, "ratelimit:user:123");
        assert_eq!(limit, 120);
    }

    #[test]
    fn test_determine_key_and_limit_anonymous() {
        let rate_limit = default_rate_limit();

        let (key, limit) = rate_limit.determine_key_and_limit(None, Some("192.168.1.1"), "/posts");
        assert_eq!(key, "ratelimit:ip:192.168.1.1");
        assert_eq!(limit, 60);
    }

    #[test]
    fn test_determine_key_and_limit_strict_route_authenticated() {
        let rate_limit = default_rate_limit();

        let (key, limit) =
            rate_limit.determine_key_and_limit(Some(123), Some("192.168.1.1"), "/login");
        assert_eq!(key, "ratelimit:route:user:123");
        assert_eq!(limit, 30);
    }

    #[test]
    fn test_determine_key_and_limit_strict_route_anonymous() {
        let rate_limit = default_rate_limit();

        let (key, limit) =
            rate_limit.determine_key_and_limit(None, Some("192.168.1.1"), "/register");
        assert_eq!(key, "ratelimit:route:ip:192.168.1.1");
        assert_eq!(limit, 30);
    }

    #[test]
    fn test_determine_key_and_limit_unknown_client() {
        let rate_limit = default_rate_limit();

        let (key, limit) = rate_limit.determine_key_and_limit(None, None, "/posts");
        assert_eq!(key, "ratelimit:unknown");
        assert_eq!(limit, 60);

        let (key, limit) = rate_limit.determine_key_and_limit(None, None, "/login");
        assert_eq!(key, "ratelimit:route:unknown");
        assert_eq!(limit, 30);
    }

    #[tokio::test]
    async fn test_in_memory_rate_limit_within_limit() {
        let rate_limit = RateLimit::new(test_config(60), None);

        for _ in 0..3 {
            let result = rate_limit.check_rate_limit_memory("test_key", 5).await;
            assert!(result.is_ok());
        }
    }

    #[tokio::test]
    async fn test_in_memory_rate_limit_exceeded() {
        let rate_limit = RateLimit::new(test_config(60), None);

        for _ in 0..3 {
            let result = rate_limit.check_rate_limit_memory("test_key", 3).await;
            assert!(result.is_ok());
        }

        let result = rate_limit.check_rate_limit_memory("test_key", 3).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), RateLimitError::Exceeded { .. }));
    }

    #[tokio::test]
    async fn test_in_memory_rate_limit_keys_are_independent() {
        let rate_limit = RateLimit::new(test_config(60), None);

        assert!(rate_limit.check_rate_limit_memory("key_a", 1).await.is_ok());
        assert!(rate_limit.check_rate_limit_memory("key_a", 1).await.is_err());
        // Exhausting one key must not affect another.
        assert!(rate_limit.check_rate_limit_memory("key_b", 1).await.is_ok());
    }

    #[tokio::test]
    async fn test_in_memory_rate_limit_window_reset() {
        let rate_limit = RateLimit::new(test_config(1), None);

        for _ in 0..3 {
            let result = rate_limit.check_rate_limit_memory("test_key", 3).await;
            assert!(result.is_ok());
        }

        let result = rate_limit.check_rate_limit_memory("test_key", 3).await;
        assert!(result.is_err());

        // Sleep just past the 1 s window; sleeps never return early.
        tokio::time::sleep(Duration::from_millis(1100)).await;

        let result = rate_limit.check_rate_limit_memory("test_key", 3).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let rate_limit = RateLimit::new(test_config(1), None);

        for i in 0..5 {
            let key = format!("test_key_{i}");
            let _ = rate_limit.check_rate_limit_memory(&key, 10).await;
        }

        let len = rate_limit.in_memory_store.read().await.len();
        assert_eq!(len, 5);

        tokio::time::sleep(Duration::from_millis(1100)).await;

        let removed = rate_limit.cleanup_expired().await;
        assert_eq!(removed, 5);

        let len = rate_limit.in_memory_store.read().await.len();
        assert_eq!(len, 0);
    }

    #[test]
    fn test_rate_limit_error_display() {
        let error = RateLimitError::Exceeded {
            limit: 100,
            window: Duration::from_secs(60),
        };
        assert!(error.to_string().contains("100"));
        assert!(error.to_string().contains("60"));

        let error = RateLimitError::Backend("Redis connection failed".to_string());
        assert!(error.to_string().contains("Redis connection failed"));
    }
}