mockforge_http/middleware/
rate_limit.rs

1//! Global rate limiting middleware for HTTP server
2//!
3//! This module provides rate limiting to protect against abuse and DDoS attacks
4
5use axum::{
6    body::Body,
7    extract::{ConnectInfo, State},
8    http::{Request, StatusCode},
9    middleware::Next,
10    response::Response,
11};
12use governor::{
13    clock::DefaultClock,
14    state::{InMemoryState, NotKeyed},
15    Quota, RateLimiter,
16};
17use std::net::SocketAddr;
18use std::num::NonZeroU32;
19use std::sync::Arc;
20use tracing::warn;
21
/// Tunable parameters for the HTTP rate limiter.
///
/// `requests_per_minute` and `burst` are the two values consumed when a
/// [`GlobalRateLimiter`] is built from this configuration.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Sustained request rate, expressed as requests per minute.
    pub requests_per_minute: u32,
    /// Maximum number of requests that may be admitted in a single burst.
    pub burst: u32,
    /// Whether limits should be tracked per client IP.
    ///
    /// NOTE(review): not read by the limiter code visible in this module —
    /// confirm where (or whether) this flag is honored.
    pub per_ip: bool,
    /// Whether limits should be tracked per endpoint.
    ///
    /// NOTE(review): not read by the limiter code visible in this module —
    /// confirm where (or whether) this flag is honored.
    pub per_endpoint: bool,
}
34
35impl Default for RateLimitConfig {
36    fn default() -> Self {
37        Self {
38            requests_per_minute: 100,
39            burst: 200,
40            per_ip: true,
41            per_endpoint: false,
42        }
43    }
44}
45
46/// Global rate limiter state
47pub struct GlobalRateLimiter {
48    limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
49    _config: RateLimitConfig,
50}
51
52impl GlobalRateLimiter {
53    /// Create a new global rate limiter
54    pub fn new(config: RateLimitConfig) -> Self {
55        let quota = Quota::per_minute(
56            NonZeroU32::new(config.requests_per_minute).unwrap_or(NonZeroU32::new(100).unwrap()),
57        )
58        .allow_burst(NonZeroU32::new(config.burst).unwrap_or(NonZeroU32::new(200).unwrap()));
59
60        let limiter = Arc::new(RateLimiter::direct(quota));
61
62        Self {
63            limiter,
64            _config: config,
65        }
66    }
67
68    /// Check if request should be rate limited
69    pub fn check_rate_limit(&self) -> bool {
70        self.limiter.check().is_ok()
71    }
72}
73
74/// Rate limiting middleware
75pub async fn rate_limit_middleware(
76    State(state): axum::extract::State<crate::HttpServerState>,
77    ConnectInfo(addr): ConnectInfo<SocketAddr>,
78    req: Request<Body>,
79    next: Next,
80) -> Result<Response, StatusCode> {
81    // Get rate limiter from app state
82    if let Some(limiter) = &state.rate_limiter {
83        if !limiter.check_rate_limit() {
84            warn!("Rate limit exceeded for IP: {}", addr.ip());
85            return Err(StatusCode::TOO_MANY_REQUESTS);
86        }
87    } else {
88        // No rate limiter configured, allow request to proceed
89        tracing::debug!("No rate limiter configured, allowing request");
90    }
91
92    Ok(next.run(req).await)
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// A limiter built from the default config admits its first request.
    #[test]
    fn test_rate_limiter_creation() {
        let config = RateLimitConfig::default();
        let limiter = GlobalRateLimiter::new(config);

        // Should allow first request
        assert!(limiter.check_rate_limit());
    }

    /// The limiter admits exactly the configured burst and then rejects.
    #[test]
    fn test_rate_limiter_burst() {
        let config = RateLimitConfig {
            requests_per_minute: 10,
            burst: 5,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // Should allow burst requests
        for _ in 0..5 {
            assert!(limiter.check_rate_limit(), "Burst request should be allowed");
        }

        // The burst is now used up. At 10 requests per minute a new slot only
        // becomes available every 6 seconds, so an immediate follow-up request
        // must be turned away. Without this check the test would pass even if
        // the limiter never rejected anything.
        assert!(
            !limiter.check_rate_limit(),
            "Request beyond burst should be rejected"
        );
    }

    /// Zero values in the config must not panic; `new` substitutes fallbacks.
    #[test]
    fn test_rate_limiter_zero_config_falls_back() {
        let config = RateLimitConfig {
            requests_per_minute: 0,
            burst: 0,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        assert!(
            limiter.check_rate_limit(),
            "Limiter built from zero config should still admit requests"
        );
    }

    /// The default config carries the documented values.
    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.burst, 200);
        assert!(config.per_ip);
        assert!(!config.per_endpoint);
    }
}