Skip to main content

mockforge_http/middleware/
rate_limit.rs

//! Global rate limiting middleware for the HTTP server.
//!
//! This module applies a single, process-wide rate limit (shared by all clients)
//! to protect against abuse and DDoS attacks, and adds production-like
//! `X-Rate-Limit-*` headers to responses.
5
6use axum::{
7    body::Body,
8    extract::{ConnectInfo, State},
9    http::{HeaderName, HeaderValue, Request, StatusCode},
10    middleware::Next,
11    response::Response,
12};
13use governor::{
14    clock::DefaultClock,
15    state::{InMemoryState, NotKeyed},
16    Quota, RateLimiter,
17};
18use std::net::SocketAddr;
19use std::num::NonZeroU32;
20use std::sync::{Arc, Mutex};
21use std::time::{Duration, SystemTime, UNIX_EPOCH};
22use tracing::warn;
23
/// Rate limiting configuration.
///
/// `requests_per_minute` and `burst` must be non-zero to form a valid quota;
/// `GlobalRateLimiter::new` substitutes built-in fallbacks when either is `0`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Sustained number of requests allowed per minute.
    pub requests_per_minute: u32,
    /// Burst capacity: the maximum number of requests admitted in quick
    /// succession before the sustained rate applies.
    pub burst: u32,
    /// Enable per-IP rate limiting.
    ///
    /// NOTE(review): `GlobalRateLimiter` in this module is non-keyed and does
    /// not read this flag — confirm whether another component honors it.
    pub per_ip: bool,
    /// Enable per-endpoint rate limiting.
    ///
    /// NOTE(review): not read by `GlobalRateLimiter` in this module.
    pub per_endpoint: bool,
}
36
37impl Default for RateLimitConfig {
38    fn default() -> Self {
39        Self {
40            requests_per_minute: 100,
41            burst: 200,
42            per_ip: true,
43            per_endpoint: false,
44        }
45    }
46}
47
/// Rate limit quota information used to populate `X-Rate-Limit-*` headers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitQuota {
    /// Maximum requests per minute (sent as `X-Rate-Limit-Limit`).
    pub limit: u32,
    /// Approximate number of requests remaining in the current window
    /// (sent as `X-Rate-Limit-Remaining`).
    pub remaining: u32,
    /// Unix timestamp, in seconds, at which the current window ends
    /// (sent as `X-Rate-Limit-Reset`).
    pub reset: u64,
}
58
59/// Global rate limiter state
60pub struct GlobalRateLimiter {
61    limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
62    config: RateLimitConfig,
63    /// Track window start time for reset calculation
64    window_start: Arc<Mutex<SystemTime>>,
65    /// Track approximate remaining requests
66    remaining_counter: Arc<Mutex<u32>>,
67}
68
69impl GlobalRateLimiter {
70    /// Create a new global rate limiter
71    pub fn new(config: RateLimitConfig) -> Self {
72        let quota = Quota::per_minute(
73            NonZeroU32::new(config.requests_per_minute)
74                .unwrap_or(NonZeroU32::new(100).expect("constant 100 is non-zero")),
75        )
76        .allow_burst(
77            NonZeroU32::new(config.burst)
78                .unwrap_or(NonZeroU32::new(200).expect("constant 200 is non-zero")),
79        );
80
81        let limiter = Arc::new(RateLimiter::direct(quota));
82        let window_start = Arc::new(Mutex::new(SystemTime::now()));
83        let remaining_counter = Arc::new(Mutex::new(config.requests_per_minute));
84
85        Self {
86            limiter,
87            config,
88            window_start,
89            remaining_counter,
90        }
91    }
92
93    /// Check if request should be rate limited
94    pub fn check_rate_limit(&self) -> bool {
95        self.limiter.check().is_ok()
96    }
97
98    /// Get rate limit quota information for headers
99    ///
100    /// Returns information about the current rate limit state including
101    /// limit, remaining requests, and reset timestamp.
102    pub fn get_quota_info(&self) -> RateLimitQuota {
103        let now = SystemTime::now();
104        let mut window_start =
105            self.window_start.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
106        let mut remaining =
107            self.remaining_counter.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
108
109        // Check if we need to reset the window (every minute)
110        let window_duration = Duration::from_secs(60);
111        if now.duration_since(*window_start).unwrap_or(Duration::ZERO) >= window_duration {
112            // Reset window
113            *window_start = now;
114            *remaining = self.config.requests_per_minute;
115        }
116
117        // Decrement remaining if we successfully checked (approximate)
118        // Note: This is approximate because governor's token bucket
119        // may have different internal state, but it's good enough for headers
120        let current_remaining = *remaining;
121        if current_remaining > 0 {
122            *remaining = current_remaining.saturating_sub(1);
123        }
124
125        // Calculate reset timestamp (start of next window)
126        let reset_timestamp =
127            window_start.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_secs() + 60; // Add 60 seconds for next window
128
129        RateLimitQuota {
130            limit: self.config.requests_per_minute,
131            remaining: current_remaining,
132            reset: reset_timestamp,
133        }
134    }
135}
136
137/// Rate limiting middleware
138///
139/// This middleware:
140/// 1. Checks if the request should be rate limited
141/// 2. Adds rate limit headers to successful responses (for deceptive deploy)
142/// 3. Returns 429 with Retry-After header when rate limited
143pub async fn rate_limit_middleware(
144    State(state): State<crate::HttpServerState>,
145    ConnectInfo(addr): ConnectInfo<SocketAddr>,
146    req: Request<Body>,
147    next: Next,
148) -> Response {
149    // Get rate limiter from app state
150    let quota_info = if let Some(limiter) = &state.rate_limiter {
151        // Check rate limit
152        if !limiter.check_rate_limit() {
153            warn!("Rate limit exceeded for IP: {}", addr.ip());
154            // Return 429 with Retry-After header per HTTP spec
155            let mut response = Response::builder()
156                .status(StatusCode::TOO_MANY_REQUESTS)
157                .body(Body::from("Too Many Requests"))
158                .unwrap_or_else(|_| Response::new(Body::from("Too Many Requests")));
159
160            // Add Retry-After header (60 seconds = 1 minute window)
161            let retry_after = HeaderValue::from_static("60");
162            response
163                .headers_mut()
164                .insert(HeaderName::from_static("retry-after"), retry_after);
165
166            // Add rate limit headers to the 429 response
167            let quota = limiter.get_quota_info();
168            if let Ok(limit_value) = HeaderValue::from_str(&quota.limit.to_string()) {
169                response
170                    .headers_mut()
171                    .insert(HeaderName::from_static("x-rate-limit-limit"), limit_value);
172            }
173            if let Ok(remaining_value) = HeaderValue::from_str("0") {
174                response
175                    .headers_mut()
176                    .insert(HeaderName::from_static("x-rate-limit-remaining"), remaining_value);
177            }
178            if let Ok(reset_value) = HeaderValue::from_str(&quota.reset.to_string()) {
179                response
180                    .headers_mut()
181                    .insert(HeaderName::from_static("x-rate-limit-reset"), reset_value);
182            }
183
184            return response;
185        }
186
187        // Get quota information for headers
188        Some(limiter.get_quota_info())
189    } else {
190        // No rate limiter configured, allow request to proceed
191        tracing::debug!("No rate limiter configured, allowing request");
192        None
193    };
194
195    // Process request and get response
196    let mut response = next.run(req).await;
197
198    // Add rate limit headers to response if we have quota info
199    // This makes the mock API look more like production
200    if let Some(quota) = quota_info {
201        // Add X-Rate-Limit-Limit header
202        let limit_name = HeaderName::from_static("x-rate-limit-limit");
203        if let Ok(limit_value) = HeaderValue::from_str(&quota.limit.to_string()) {
204            response.headers_mut().insert(limit_name, limit_value);
205        }
206
207        // Add X-Rate-Limit-Remaining header
208        let remaining_name = HeaderName::from_static("x-rate-limit-remaining");
209        if let Ok(remaining_value) = HeaderValue::from_str(&quota.remaining.to_string()) {
210            response.headers_mut().insert(remaining_name, remaining_value);
211        }
212
213        // Add X-Rate-Limit-Reset header (Unix timestamp)
214        let reset_name = HeaderName::from_static("x-rate-limit-reset");
215        if let Ok(reset_value) = HeaderValue::from_str(&quota.reset.to_string()) {
216            response.headers_mut().insert(reset_name, reset_value);
217        }
218    }
219
220    response
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with both keying flags disabled, as used by most limiter tests.
    fn limiter_config(requests_per_minute: u32, burst: u32) -> RateLimitConfig {
        RateLimitConfig {
            requests_per_minute,
            burst,
            per_ip: false,
            per_endpoint: false,
        }
    }

    // ==================== RateLimitConfig Tests ====================

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.burst, 200);
        assert!(config.per_ip);
        assert!(!config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_custom() {
        let config = RateLimitConfig {
            requests_per_minute: 50,
            burst: 100,
            per_ip: false,
            per_endpoint: true,
        };

        assert_eq!(config.requests_per_minute, 50);
        assert_eq!(config.burst, 100);
        assert!(!config.per_ip);
        assert!(config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_clone() {
        let config = RateLimitConfig {
            requests_per_minute: 75,
            burst: 150,
            per_ip: true,
            per_endpoint: true,
        };

        let cloned = config.clone();

        assert_eq!(cloned.requests_per_minute, config.requests_per_minute);
        assert_eq!(cloned.burst, config.burst);
        assert_eq!(cloned.per_ip, config.per_ip);
        assert_eq!(cloned.per_endpoint, config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_debug() {
        let debug_str = format!("{:?}", RateLimitConfig::default());

        for field in ["requests_per_minute", "burst", "per_ip", "per_endpoint"] {
            assert!(debug_str.contains(field), "Debug output should mention {field}");
        }
    }

    // ==================== RateLimitQuota Tests ====================

    #[test]
    fn test_rate_limit_quota_creation() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        assert_eq!(quota.limit, 100);
        assert_eq!(quota.remaining, 50);
        assert_eq!(quota.reset, 1234567890);
    }

    #[test]
    fn test_rate_limit_quota_clone() {
        let quota = RateLimitQuota {
            limit: 200,
            remaining: 175,
            reset: 9876543210,
        };

        let cloned = quota.clone();

        assert_eq!(cloned.limit, quota.limit);
        assert_eq!(cloned.remaining, quota.remaining);
        assert_eq!(cloned.reset, quota.reset);
    }

    #[test]
    fn test_rate_limit_quota_debug() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        let debug_str = format!("{:?}", quota);

        for field in ["limit", "remaining", "reset"] {
            assert!(debug_str.contains(field), "Debug output should mention {field}");
        }
    }

    // ==================== GlobalRateLimiter Tests ====================

    #[test]
    fn test_rate_limiter_creation() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());

        // Should allow first request
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_with_custom_config() {
        let limiter = GlobalRateLimiter::new(limiter_config(60, 10));
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_burst() {
        let limiter = GlobalRateLimiter::new(limiter_config(10, 5));

        // The whole burst is admitted back-to-back...
        for i in 0..5 {
            assert!(limiter.check_rate_limit(), "Burst request {i} should be allowed");
        }
        // ...and the next immediate request is rejected (one cell per 6 s).
        assert!(!limiter.check_rate_limit(), "Request beyond the burst should be rejected");
    }

    #[test]
    fn test_rate_limiter_multiple_requests() {
        let limiter = GlobalRateLimiter::new(limiter_config(1000, 100));

        // Should allow many requests within burst limit
        for i in 0..50 {
            assert!(limiter.check_rate_limit(), "Request {i} should be allowed");
        }
    }

    #[test]
    fn test_get_quota_info() {
        let limiter = GlobalRateLimiter::new(limiter_config(100, 50));

        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 100);
        assert!(quota.remaining > 0);
        assert!(quota.remaining <= quota.limit);
        assert!(quota.reset > 0);
    }

    #[test]
    fn test_quota_info_limit_matches_config() {
        let limiter = GlobalRateLimiter::new(limiter_config(500, 100));

        assert_eq!(limiter.get_quota_info().limit, 500);
    }

    #[test]
    fn test_quota_decrements_remaining() {
        let limiter = GlobalRateLimiter::new(limiter_config(100, 50));

        let first_quota = limiter.get_quota_info();
        let second_quota = limiter.get_quota_info();

        // Each call accounts for exactly one request.
        assert_eq!(
            first_quota.remaining - second_quota.remaining,
            1,
            "Remaining should decrease by one per call"
        );
    }

    #[test]
    fn test_quota_reset_timestamp_is_future() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());

        let quota = limiter.get_quota_info();

        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();

        // Reset timestamp should be in the future (approximately 60 seconds from window start)
        assert!(quota.reset >= now, "Reset timestamp should be >= current time");
        assert!(quota.reset <= now + 120, "Reset timestamp should be within 2 minutes");
    }

    #[test]
    fn test_rate_limiter_high_burst() {
        // Very high burst relative to the sustained rate.
        let limiter = GlobalRateLimiter::new(limiter_config(10, 1000));

        // Should allow many requests due to high burst
        for i in 0..100 {
            assert!(limiter.check_rate_limit(), "Request {i} should be allowed");
        }
    }

    #[test]
    fn test_rate_limiter_low_limit() {
        let limiter = GlobalRateLimiter::new(limiter_config(1, 1));

        // First request should succeed, an immediate second one must not.
        assert!(limiter.check_rate_limit());
        assert!(!limiter.check_rate_limit(), "Second immediate request should be rejected");
    }

    // ==================== Edge Cases ====================

    #[test]
    fn test_config_with_zero_values_handled() {
        // NonZeroU32::new(0) is None, so the limiter falls back to
        // 100 requests/minute with a burst of 200 instead of panicking.
        let limiter = GlobalRateLimiter::new(limiter_config(0, 0));

        // The fallback burst admits well over a handful of immediate requests.
        for i in 0..150 {
            assert!(limiter.check_rate_limit(), "Request {i} should fit the fallback burst");
        }
    }

    #[test]
    fn test_multiple_quota_calls_same_limiter() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());

        let quotas: Vec<RateLimitQuota> = (0..5).map(|_| limiter.get_quota_info()).collect();

        // All should have same limit
        assert!(quotas.iter().all(|quota| quota.limit == 100));

        // Reset timestamps should be similar (within same window)
        let first_reset = quotas[0].reset;
        for quota in &quotas {
            assert!(
                quota.reset.abs_diff(first_reset) <= 1,
                "Reset timestamps should be within 1 second of each other"
            );
        }
    }

    #[test]
    fn test_quota_remaining_never_negative() {
        let limiter = GlobalRateLimiter::new(limiter_config(5, 5));

        let mut previous = u32::MAX;
        for _ in 0..20 {
            let quota = limiter.get_quota_info();
            assert!(quota.remaining <= 5, "Remaining must never exceed the limit");
            assert!(quota.remaining <= previous, "Remaining must not increase within a window");
            previous = quota.remaining;
        }

        // After far more calls than the limit, the counter saturates at zero.
        assert_eq!(previous, 0, "Remaining should saturate at zero once exhausted");
    }
}