mockforge_http/middleware/
rate_limit.rs

1//! Global rate limiting middleware for HTTP server
2//!
3//! This module provides rate limiting to protect against abuse and DDoS attacks
4//! and adds production-like rate limit headers to responses.
5
6use axum::{
7    body::Body,
8    extract::{ConnectInfo, State},
9    http::{HeaderName, HeaderValue, Request, StatusCode},
10    middleware::Next,
11    response::Response,
12};
13use governor::{
14    clock::DefaultClock,
15    state::{InMemoryState, NotKeyed},
16    Quota, RateLimiter,
17};
18use std::net::SocketAddr;
19use std::num::NonZeroU32;
20use std::sync::{Arc, Mutex};
21use std::time::{Duration, SystemTime, UNIX_EPOCH};
22use tracing::warn;
23
/// Rate limiting configuration.
///
/// Only `requests_per_minute` and `burst` are read by [`GlobalRateLimiter`];
/// a value of zero for either is replaced by a built-in fallback (100 and 200
/// respectively) when the limiter is created.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Sustained number of requests allowed per minute.
    pub requests_per_minute: u32,
    /// Maximum number of requests that may be admitted back-to-back (burst capacity).
    pub burst: u32,
    /// Enable per-IP rate limiting.
    ///
    /// NOTE: not consulted by [`GlobalRateLimiter`], which applies a single
    /// limit shared by all clients.
    pub per_ip: bool,
    /// Enable per-endpoint rate limiting.
    ///
    /// NOTE: not consulted by [`GlobalRateLimiter`], which applies a single
    /// limit shared by all endpoints.
    pub per_endpoint: bool,
}

impl Default for RateLimitConfig {
    /// 100 requests per minute with a burst of 200; per-IP enabled, per-endpoint disabled.
    fn default() -> Self {
        Self {
            requests_per_minute: 100,
            burst: 200,
            per_ip: true,
            per_endpoint: false,
        }
    }
}
47
/// Snapshot of the rate limit state, used to populate `X-Rate-Limit-*` response headers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitQuota {
    /// Maximum requests per minute (sent as `X-Rate-Limit-Limit`).
    pub limit: u32,
    /// Approximate requests remaining in the current window (sent as `X-Rate-Limit-Remaining`).
    pub remaining: u32,
    /// Unix timestamp, in seconds, at which the current window resets (sent as `X-Rate-Limit-Reset`).
    pub reset: u64,
}
58
59/// Global rate limiter state
60pub struct GlobalRateLimiter {
61    limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
62    config: RateLimitConfig,
63    /// Track window start time for reset calculation
64    window_start: Arc<Mutex<SystemTime>>,
65    /// Track approximate remaining requests
66    remaining_counter: Arc<Mutex<u32>>,
67}
68
69impl GlobalRateLimiter {
70    /// Create a new global rate limiter
71    pub fn new(config: RateLimitConfig) -> Self {
72        let quota = Quota::per_minute(
73            NonZeroU32::new(config.requests_per_minute)
74                .unwrap_or(NonZeroU32::new(100).expect("constant 100 is non-zero")),
75        )
76        .allow_burst(
77            NonZeroU32::new(config.burst)
78                .unwrap_or(NonZeroU32::new(200).expect("constant 200 is non-zero")),
79        );
80
81        let limiter = Arc::new(RateLimiter::direct(quota));
82        let window_start = Arc::new(Mutex::new(SystemTime::now()));
83        let remaining_counter = Arc::new(Mutex::new(config.requests_per_minute));
84
85        Self {
86            limiter,
87            config,
88            window_start,
89            remaining_counter,
90        }
91    }
92
93    /// Check if request should be rate limited
94    pub fn check_rate_limit(&self) -> bool {
95        self.limiter.check().is_ok()
96    }
97
98    /// Get rate limit quota information for headers
99    ///
100    /// Returns information about the current rate limit state including
101    /// limit, remaining requests, and reset timestamp.
102    pub fn get_quota_info(&self) -> RateLimitQuota {
103        let now = SystemTime::now();
104        let mut window_start =
105            self.window_start.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
106        let mut remaining =
107            self.remaining_counter.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
108
109        // Check if we need to reset the window (every minute)
110        let window_duration = Duration::from_secs(60);
111        if now.duration_since(*window_start).unwrap_or(Duration::ZERO) >= window_duration {
112            // Reset window
113            *window_start = now;
114            *remaining = self.config.requests_per_minute;
115        }
116
117        // Decrement remaining if we successfully checked (approximate)
118        // Note: This is approximate because governor's token bucket
119        // may have different internal state, but it's good enough for headers
120        let current_remaining = *remaining;
121        if current_remaining > 0 {
122            *remaining = current_remaining.saturating_sub(1);
123        }
124
125        // Calculate reset timestamp (start of next window)
126        let reset_timestamp =
127            window_start.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_secs() + 60; // Add 60 seconds for next window
128
129        RateLimitQuota {
130            limit: self.config.requests_per_minute,
131            remaining: current_remaining,
132            reset: reset_timestamp,
133        }
134    }
135}
136
137/// Rate limiting middleware
138///
139/// This middleware:
140/// 1. Checks if the request should be rate limited
141/// 2. Adds rate limit headers to successful responses (for deceptive deploy)
142pub async fn rate_limit_middleware(
143    State(state): axum::extract::State<crate::HttpServerState>,
144    ConnectInfo(addr): ConnectInfo<SocketAddr>,
145    req: Request<Body>,
146    next: Next,
147) -> Result<Response, StatusCode> {
148    // Get rate limiter from app state
149    let quota_info = if let Some(limiter) = &state.rate_limiter {
150        // Check rate limit
151        if !limiter.check_rate_limit() {
152            warn!("Rate limit exceeded for IP: {}", addr.ip());
153            return Err(StatusCode::TOO_MANY_REQUESTS);
154        }
155
156        // Get quota information for headers
157        Some(limiter.get_quota_info())
158    } else {
159        // No rate limiter configured, allow request to proceed
160        tracing::debug!("No rate limiter configured, allowing request");
161        None
162    };
163
164    // Process request and get response
165    let mut response = next.run(req).await;
166
167    // Add rate limit headers to response if we have quota info
168    // This makes the mock API look more like production
169    if let Some(quota) = quota_info {
170        // Add X-Rate-Limit-Limit header
171        let limit_name = HeaderName::from_static("x-rate-limit-limit");
172        if let Ok(limit_value) = HeaderValue::from_str(&quota.limit.to_string()) {
173            response.headers_mut().insert(limit_name, limit_value);
174        }
175
176        // Add X-Rate-Limit-Remaining header
177        let remaining_name = HeaderName::from_static("x-rate-limit-remaining");
178        if let Ok(remaining_value) = HeaderValue::from_str(&quota.remaining.to_string()) {
179            response.headers_mut().insert(remaining_name, remaining_value);
180        }
181
182        // Add X-Rate-Limit-Reset header (Unix timestamp)
183        let reset_name = HeaderName::from_static("x-rate-limit-reset");
184        if let Ok(reset_value) = HeaderValue::from_str(&quota.reset.to_string()) {
185            response.headers_mut().insert(reset_name, reset_value);
186        }
187    }
188
189    Ok(response)
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    // ==================== RateLimitConfig Tests ====================

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.burst, 200);
        assert!(config.per_ip);
        assert!(!config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_custom() {
        let config = RateLimitConfig {
            requests_per_minute: 50,
            burst: 100,
            per_ip: false,
            per_endpoint: true,
        };

        assert_eq!(config.requests_per_minute, 50);
        assert_eq!(config.burst, 100);
        assert!(!config.per_ip);
        assert!(config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_clone() {
        let config = RateLimitConfig {
            requests_per_minute: 75,
            burst: 150,
            per_ip: true,
            per_endpoint: true,
        };

        let cloned = config.clone();

        assert_eq!(cloned.requests_per_minute, config.requests_per_minute);
        assert_eq!(cloned.burst, config.burst);
        assert_eq!(cloned.per_ip, config.per_ip);
        assert_eq!(cloned.per_endpoint, config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_debug() {
        let config = RateLimitConfig::default();
        let debug_str = format!("{:?}", config);

        assert!(debug_str.contains("requests_per_minute"));
        assert!(debug_str.contains("burst"));
        assert!(debug_str.contains("per_ip"));
        assert!(debug_str.contains("per_endpoint"));
    }

    // ==================== RateLimitQuota Tests ====================

    #[test]
    fn test_rate_limit_quota_creation() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        assert_eq!(quota.limit, 100);
        assert_eq!(quota.remaining, 50);
        assert_eq!(quota.reset, 1234567890);
    }

    #[test]
    fn test_rate_limit_quota_clone() {
        let quota = RateLimitQuota {
            limit: 200,
            remaining: 175,
            reset: 9876543210,
        };

        let cloned = quota.clone();

        assert_eq!(cloned.limit, quota.limit);
        assert_eq!(cloned.remaining, quota.remaining);
        assert_eq!(cloned.reset, quota.reset);
    }

    #[test]
    fn test_rate_limit_quota_debug() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        let debug_str = format!("{:?}", quota);

        assert!(debug_str.contains("limit"));
        assert!(debug_str.contains("remaining"));
        assert!(debug_str.contains("reset"));
    }

    // ==================== GlobalRateLimiter Tests ====================

    #[test]
    fn test_rate_limiter_creation() {
        let config = RateLimitConfig::default();
        let limiter = GlobalRateLimiter::new(config);

        // Should allow first request
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_with_custom_config() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst: 10,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_burst() {
        let config = RateLimitConfig {
            requests_per_minute: 10,
            burst: 5,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // The whole burst is admitted immediately...
        for i in 0..5 {
            assert!(limiter.check_rate_limit(), "Burst request {} should be allowed", i);
        }
        // ...and the next request is rejected (10/min replenishes one slot every 6s).
        assert!(!limiter.check_rate_limit(), "Request beyond the burst should be rejected");
    }

    #[test]
    fn test_rate_limiter_multiple_requests() {
        let config = RateLimitConfig {
            requests_per_minute: 1000,
            burst: 100,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // Should allow many requests within burst limit
        for i in 0..50 {
            assert!(limiter.check_rate_limit(), "Request {} should be allowed", i);
        }
    }

    #[test]
    fn test_get_quota_info() {
        let config = RateLimitConfig {
            requests_per_minute: 100,
            burst: 50,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 100);
        assert!(quota.remaining > 0);
        assert!(quota.reset > 0);
    }

    #[test]
    fn test_quota_info_limit_matches_config() {
        let config = RateLimitConfig {
            requests_per_minute: 500,
            burst: 100,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);
        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 500);
    }

    #[test]
    fn test_quota_decrements_remaining() {
        let config = RateLimitConfig {
            requests_per_minute: 100,
            burst: 50,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        let first_quota = limiter.get_quota_info();
        let second_quota = limiter.get_quota_info();

        // Each call counts one request against the same window.
        assert_eq!(
            second_quota.remaining + 1,
            first_quota.remaining,
            "Remaining should decrement by exactly one per call"
        );
    }

    #[test]
    fn test_quota_reset_timestamp_is_future() {
        let config = RateLimitConfig::default();
        let limiter = GlobalRateLimiter::new(config);

        let quota = limiter.get_quota_info();

        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();

        // Reset timestamp should be in the future (approximately 60 seconds from window start)
        assert!(quota.reset >= now, "Reset timestamp should be >= current time");
        assert!(quota.reset <= now + 120, "Reset timestamp should be within 2 minutes");
    }

    #[test]
    fn test_rate_limiter_high_burst() {
        let config = RateLimitConfig {
            requests_per_minute: 10,
            burst: 1000, // Very high burst
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // Should allow many requests due to high burst
        for _ in 0..100 {
            assert!(limiter.check_rate_limit());
        }
    }

    #[test]
    fn test_rate_limiter_low_limit() {
        let config = RateLimitConfig {
            requests_per_minute: 1,
            burst: 1,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // First request should succeed, an immediate second one must not.
        assert!(limiter.check_rate_limit());
        assert!(!limiter.check_rate_limit(), "Second request within the minute should be rejected");
    }

    // ==================== Edge Cases ====================

    #[test]
    fn test_config_with_zero_values_handled() {
        // Zero values are replaced by fallbacks in GlobalRateLimiter::new
        let config = RateLimitConfig {
            requests_per_minute: 0, // Will use fallback (100)
            burst: 0,               // Will use fallback (200)
            per_ip: false,
            per_endpoint: false,
        };

        // Should not panic - NonZeroU32::new(0) returns None and the fallback is used
        let limiter = GlobalRateLimiter::new(config);
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_multiple_quota_calls_same_limiter() {
        let config = RateLimitConfig::default();
        let limiter = GlobalRateLimiter::new(config);

        // Call get_quota_info multiple times
        let quotas: Vec<RateLimitQuota> = (0..5).map(|_| limiter.get_quota_info()).collect();

        // All should have same limit
        for quota in &quotas {
            assert_eq!(quota.limit, 100);
        }

        // Reset timestamps should be similar (within same window)
        let first_reset = quotas[0].reset;
        for quota in &quotas {
            assert!(
                (quota.reset as i64 - first_reset as i64).abs() <= 1,
                "Reset timestamps should be within 1 second of each other"
            );
        }
    }

    #[test]
    fn test_quota_remaining_never_negative() {
        let config = RateLimitConfig {
            requests_per_minute: 5,
            burst: 5,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // Call far more often than the window allows; the counter must saturate at zero.
        let mut last_remaining = u32::MAX;
        for _ in 0..20 {
            let quota = limiter.get_quota_info();
            assert!(quota.remaining <= quota.limit, "Remaining should never exceed the limit");
            last_remaining = quota.remaining;
        }
        assert_eq!(last_remaining, 0, "Remaining should saturate at zero once exhausted");
    }
}