mockforge_http/middleware/
rate_limit.rs

1//! Global rate limiting middleware for HTTP server
2//!
3//! This module provides rate limiting to protect against abuse and DDoS attacks
4//! and adds production-like rate limit headers to responses.
5
6use axum::{
7    body::Body,
8    extract::{ConnectInfo, State},
9    http::{HeaderName, HeaderValue, Request, StatusCode},
10    middleware::Next,
11    response::Response,
12};
13use governor::{
14    clock::DefaultClock,
15    state::{InMemoryState, NotKeyed},
16    Quota, RateLimiter,
17};
18use std::net::SocketAddr;
19use std::num::NonZeroU32;
20use std::sync::{Arc, Mutex};
21use std::time::{Duration, SystemTime, UNIX_EPOCH};
22use tracing::warn;
23
/// Rate limiting configuration.
///
/// Used by [`GlobalRateLimiter::new`] to build the limiter. Note that
/// `GlobalRateLimiter` in this module enforces one process-wide quota and only
/// reads `requests_per_minute` and `burst`; `per_ip` and `per_endpoint` are not
/// consulted there (they may be honoured elsewhere — TODO confirm).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Sustained rate, in requests per minute. `0` is not a valid quota and is
    /// replaced with `100` by [`GlobalRateLimiter::new`].
    pub requests_per_minute: u32,
    /// Maximum number of requests admitted in a single burst. `0` is replaced
    /// with `200` by [`GlobalRateLimiter::new`].
    pub burst: u32,
    /// Enable per-IP rate limiting.
    pub per_ip: bool,
    /// Enable per-endpoint rate limiting.
    pub per_endpoint: bool,
}
36
37impl Default for RateLimitConfig {
38    fn default() -> Self {
39        Self {
40            requests_per_minute: 100,
41            burst: 200,
42            per_ip: true,
43            per_endpoint: false,
44        }
45    }
46}
47
/// Snapshot of the rate-limit state reported to clients via response headers.
///
/// Produced by `GlobalRateLimiter::get_quota_info` and written into the
/// `x-rate-limit-limit`, `x-rate-limit-remaining` and `x-rate-limit-reset`
/// headers by the middleware.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitQuota {
    /// Configured requests per minute.
    pub limit: u32,
    /// Approximate requests left in the current 60-second window. This is
    /// tracked separately from the governor limiter, so it can disagree with
    /// what the limiter actually admits.
    pub remaining: u32,
    /// Unix timestamp (seconds) at which the current window ends.
    pub reset: u64,
}
58
59/// Global rate limiter state
60pub struct GlobalRateLimiter {
61    limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
62    config: RateLimitConfig,
63    /// Track window start time for reset calculation
64    window_start: Arc<Mutex<SystemTime>>,
65    /// Track approximate remaining requests
66    remaining_counter: Arc<Mutex<u32>>,
67}
68
69impl GlobalRateLimiter {
70    /// Create a new global rate limiter
71    pub fn new(config: RateLimitConfig) -> Self {
72        let quota = Quota::per_minute(
73            NonZeroU32::new(config.requests_per_minute).unwrap_or(NonZeroU32::new(100).unwrap()),
74        )
75        .allow_burst(NonZeroU32::new(config.burst).unwrap_or(NonZeroU32::new(200).unwrap()));
76
77        let limiter = Arc::new(RateLimiter::direct(quota));
78        let window_start = Arc::new(Mutex::new(SystemTime::now()));
79        let remaining_counter = Arc::new(Mutex::new(config.requests_per_minute));
80
81        Self {
82            limiter,
83            config,
84            window_start,
85            remaining_counter,
86        }
87    }
88
89    /// Check if request should be rate limited
90    pub fn check_rate_limit(&self) -> bool {
91        self.limiter.check().is_ok()
92    }
93
94    /// Get rate limit quota information for headers
95    ///
96    /// Returns information about the current rate limit state including
97    /// limit, remaining requests, and reset timestamp.
98    pub fn get_quota_info(&self) -> RateLimitQuota {
99        let now = SystemTime::now();
100        let mut window_start = self.window_start.lock().unwrap();
101        let mut remaining = self.remaining_counter.lock().unwrap();
102
103        // Check if we need to reset the window (every minute)
104        let window_duration = Duration::from_secs(60);
105        if now.duration_since(*window_start).unwrap_or(Duration::ZERO) >= window_duration {
106            // Reset window
107            *window_start = now;
108            *remaining = self.config.requests_per_minute;
109        }
110
111        // Decrement remaining if we successfully checked (approximate)
112        // Note: This is approximate because governor's token bucket
113        // may have different internal state, but it's good enough for headers
114        let current_remaining = *remaining;
115        if current_remaining > 0 {
116            *remaining = current_remaining.saturating_sub(1);
117        }
118
119        // Calculate reset timestamp (start of next window)
120        let reset_timestamp =
121            window_start.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_secs() + 60; // Add 60 seconds for next window
122
123        RateLimitQuota {
124            limit: self.config.requests_per_minute,
125            remaining: current_remaining,
126            reset: reset_timestamp,
127        }
128    }
129}
130
131/// Rate limiting middleware
132///
133/// This middleware:
134/// 1. Checks if the request should be rate limited
135/// 2. Adds rate limit headers to successful responses (for deceptive deploy)
136pub async fn rate_limit_middleware(
137    State(state): axum::extract::State<crate::HttpServerState>,
138    ConnectInfo(addr): ConnectInfo<SocketAddr>,
139    req: Request<Body>,
140    next: Next,
141) -> Result<Response, StatusCode> {
142    // Get rate limiter from app state
143    let quota_info = if let Some(limiter) = &state.rate_limiter {
144        // Check rate limit
145        if !limiter.check_rate_limit() {
146            warn!("Rate limit exceeded for IP: {}", addr.ip());
147            return Err(StatusCode::TOO_MANY_REQUESTS);
148        }
149
150        // Get quota information for headers
151        Some(limiter.get_quota_info())
152    } else {
153        // No rate limiter configured, allow request to proceed
154        tracing::debug!("No rate limiter configured, allowing request");
155        None
156    };
157
158    // Process request and get response
159    let mut response = next.run(req).await;
160
161    // Add rate limit headers to response if we have quota info
162    // This makes the mock API look more like production
163    if let Some(quota) = quota_info {
164        // Add X-Rate-Limit-Limit header
165        let limit_name = HeaderName::from_static("x-rate-limit-limit");
166        if let Ok(limit_value) = HeaderValue::from_str(&quota.limit.to_string()) {
167            response.headers_mut().insert(limit_name, limit_value);
168        }
169
170        // Add X-Rate-Limit-Remaining header
171        let remaining_name = HeaderName::from_static("x-rate-limit-remaining");
172        if let Ok(remaining_value) = HeaderValue::from_str(&quota.remaining.to_string()) {
173            response.headers_mut().insert(remaining_name, remaining_value);
174        }
175
176        // Add X-Rate-Limit-Reset header (Unix timestamp)
177        let reset_name = HeaderName::from_static("x-rate-limit-reset");
178        if let Ok(reset_value) = HeaderValue::from_str(&quota.reset.to_string()) {
179            response.headers_mut().insert(reset_name, reset_value);
180        }
181    }
182
183    Ok(response)
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Small limiter shared by several tests: 10 rpm (one token every 6s), burst 5.
    fn small_config() -> RateLimitConfig {
        RateLimitConfig {
            requests_per_minute: 10,
            burst: 5,
            per_ip: false,
            per_endpoint: false,
        }
    }

    #[test]
    fn test_rate_limiter_creation() {
        let config = RateLimitConfig::default();
        let limiter = GlobalRateLimiter::new(config);

        // Should allow first request
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_burst() {
        let limiter = GlobalRateLimiter::new(small_config());

        // Should allow burst requests
        for _ in 0..5 {
            assert!(limiter.check_rate_limit(), "Burst request should be allowed");
        }
    }

    #[test]
    fn test_rate_limiter_rejects_after_burst() {
        let limiter = GlobalRateLimiter::new(small_config());

        for _ in 0..5 {
            assert!(limiter.check_rate_limit(), "Burst request should be allowed");
        }
        // At 10 rpm a token is replenished only every 6s, so a request made
        // immediately after the burst is exhausted must be rejected.
        assert!(
            !limiter.check_rate_limit(),
            "Request beyond the burst should be rejected"
        );
    }

    #[test]
    fn test_quota_info_counts_down_within_window() {
        let limiter = GlobalRateLimiter::new(small_config());

        let first = limiter.get_quota_info();
        let second = limiter.get_quota_info();

        assert_eq!(first.limit, 10);
        assert_eq!(second.limit, 10);
        assert_eq!(
            second.remaining + 1,
            first.remaining,
            "each call consumes one request"
        );
        assert_eq!(first.reset, second.reset, "both calls share one window");

        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is after the Unix epoch")
            .as_secs();
        assert!(first.reset > now && first.reset <= now + 60);
    }

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.burst, 200);
        assert!(config.per_ip);
        assert!(!config.per_endpoint);
    }
}