mockforge_http/middleware/
rate_limit.rs

1//! Global rate limiting middleware for HTTP server
2//!
3//! This module provides rate limiting to protect against abuse and DDoS attacks
4//! and adds production-like rate limit headers to responses.
5
6use axum::{
7    body::Body,
8    extract::{ConnectInfo, State},
9    http::{HeaderName, HeaderValue, Request, StatusCode},
10    middleware::Next,
11    response::Response,
12};
13use governor::{
14    clock::DefaultClock,
15    state::{InMemoryState, NotKeyed},
16    Quota, RateLimiter,
17};
18use std::net::SocketAddr;
19use std::num::NonZeroU32;
20use std::sync::{Arc, Mutex};
21use std::time::{Duration, SystemTime, UNIX_EPOCH};
22use tracing::warn;
23
/// Settings used to build a [`GlobalRateLimiter`].
///
/// A value of `0` for `requests_per_minute` or `burst` is not usable as a
/// quota; `GlobalRateLimiter::new` substitutes a built-in fallback for it.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Sustained request rate, expressed as requests per minute.
    pub requests_per_minute: u32,
    /// Maximum number of requests that may be served in a single burst.
    pub burst: u32,
    /// Whether limits should be tracked per client IP.
    ///
    /// NOTE(review): this flag is not read anywhere in this file; the limiter
    /// constructed here is a single global one. Confirm where it is honoured.
    pub per_ip: bool,
    /// Whether limits should be tracked per endpoint.
    ///
    /// NOTE(review): this flag is not read anywhere in this file either.
    pub per_endpoint: bool,
}

impl Default for RateLimitConfig {
    /// 100 requests per minute with a burst of 200, per-IP tracking switched
    /// on and per-endpoint tracking switched off.
    fn default() -> Self {
        RateLimitConfig {
            per_ip: true,
            per_endpoint: false,
            requests_per_minute: 100,
            burst: 200,
        }
    }
}
47
/// Snapshot of the rate limit state, shaped for the `X-Rate-Limit-*`
/// response headers.
#[derive(Debug, Clone)]
pub struct RateLimitQuota {
    /// Request limit per minute advertised to the client.
    pub limit: u32,
    /// Approximate number of requests still available in the current window.
    pub remaining: u32,
    /// Unix timestamp (seconds) at which the current window ends.
    pub reset: u64,
}
58
59/// Global rate limiter state
60pub struct GlobalRateLimiter {
61    limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
62    config: RateLimitConfig,
63    /// Track window start time for reset calculation
64    window_start: Arc<Mutex<SystemTime>>,
65    /// Track approximate remaining requests
66    remaining_counter: Arc<Mutex<u32>>,
67}
68
69impl GlobalRateLimiter {
70    /// Create a new global rate limiter
71    pub fn new(config: RateLimitConfig) -> Self {
72        let quota = Quota::per_minute(
73            NonZeroU32::new(config.requests_per_minute)
74                .unwrap_or(NonZeroU32::new(100).expect("constant 100 is non-zero")),
75        )
76        .allow_burst(
77            NonZeroU32::new(config.burst)
78                .unwrap_or(NonZeroU32::new(200).expect("constant 200 is non-zero")),
79        );
80
81        let limiter = Arc::new(RateLimiter::direct(quota));
82        let window_start = Arc::new(Mutex::new(SystemTime::now()));
83        let remaining_counter = Arc::new(Mutex::new(config.requests_per_minute));
84
85        Self {
86            limiter,
87            config,
88            window_start,
89            remaining_counter,
90        }
91    }
92
93    /// Check if request should be rate limited
94    pub fn check_rate_limit(&self) -> bool {
95        self.limiter.check().is_ok()
96    }
97
98    /// Get rate limit quota information for headers
99    ///
100    /// Returns information about the current rate limit state including
101    /// limit, remaining requests, and reset timestamp.
102    pub fn get_quota_info(&self) -> RateLimitQuota {
103        let now = SystemTime::now();
104        let mut window_start =
105            self.window_start.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
106        let mut remaining =
107            self.remaining_counter.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
108
109        // Check if we need to reset the window (every minute)
110        let window_duration = Duration::from_secs(60);
111        if now.duration_since(*window_start).unwrap_or(Duration::ZERO) >= window_duration {
112            // Reset window
113            *window_start = now;
114            *remaining = self.config.requests_per_minute;
115        }
116
117        // Decrement remaining if we successfully checked (approximate)
118        // Note: This is approximate because governor's token bucket
119        // may have different internal state, but it's good enough for headers
120        let current_remaining = *remaining;
121        if current_remaining > 0 {
122            *remaining = current_remaining.saturating_sub(1);
123        }
124
125        // Calculate reset timestamp (start of next window)
126        let reset_timestamp =
127            window_start.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_secs() + 60; // Add 60 seconds for next window
128
129        RateLimitQuota {
130            limit: self.config.requests_per_minute,
131            remaining: current_remaining,
132            reset: reset_timestamp,
133        }
134    }
135}
136
137/// Rate limiting middleware
138///
139/// This middleware:
140/// 1. Checks if the request should be rate limited
141/// 2. Adds rate limit headers to successful responses (for deceptive deploy)
142/// 3. Returns 429 with Retry-After header when rate limited
143pub async fn rate_limit_middleware(
144    State(state): axum::extract::State<crate::HttpServerState>,
145    ConnectInfo(addr): ConnectInfo<SocketAddr>,
146    req: Request<Body>,
147    next: Next,
148) -> Response {
149    // Get rate limiter from app state
150    let quota_info = if let Some(limiter) = &state.rate_limiter {
151        // Check rate limit
152        if !limiter.check_rate_limit() {
153            warn!("Rate limit exceeded for IP: {}", addr.ip());
154            // Return 429 with Retry-After header per HTTP spec
155            let mut response = Response::builder()
156                .status(StatusCode::TOO_MANY_REQUESTS)
157                .body(Body::from("Too Many Requests"))
158                .unwrap_or_else(|_| Response::new(Body::from("Too Many Requests")));
159
160            // Add Retry-After header (60 seconds = 1 minute window)
161            if let Ok(retry_after) = HeaderValue::from_static("60").try_into() {
162                response
163                    .headers_mut()
164                    .insert(HeaderName::from_static("retry-after"), retry_after);
165            }
166
167            // Add rate limit headers to the 429 response
168            let quota = limiter.get_quota_info();
169            if let Ok(limit_value) = HeaderValue::from_str(&quota.limit.to_string()) {
170                response
171                    .headers_mut()
172                    .insert(HeaderName::from_static("x-rate-limit-limit"), limit_value);
173            }
174            if let Ok(remaining_value) = HeaderValue::from_str("0") {
175                response
176                    .headers_mut()
177                    .insert(HeaderName::from_static("x-rate-limit-remaining"), remaining_value);
178            }
179            if let Ok(reset_value) = HeaderValue::from_str(&quota.reset.to_string()) {
180                response
181                    .headers_mut()
182                    .insert(HeaderName::from_static("x-rate-limit-reset"), reset_value);
183            }
184
185            return response;
186        }
187
188        // Get quota information for headers
189        Some(limiter.get_quota_info())
190    } else {
191        // No rate limiter configured, allow request to proceed
192        tracing::debug!("No rate limiter configured, allowing request");
193        None
194    };
195
196    // Process request and get response
197    let mut response = next.run(req).await;
198
199    // Add rate limit headers to response if we have quota info
200    // This makes the mock API look more like production
201    if let Some(quota) = quota_info {
202        // Add X-Rate-Limit-Limit header
203        let limit_name = HeaderName::from_static("x-rate-limit-limit");
204        if let Ok(limit_value) = HeaderValue::from_str(&quota.limit.to_string()) {
205            response.headers_mut().insert(limit_name, limit_value);
206        }
207
208        // Add X-Rate-Limit-Remaining header
209        let remaining_name = HeaderName::from_static("x-rate-limit-remaining");
210        if let Ok(remaining_value) = HeaderValue::from_str(&quota.remaining.to_string()) {
211            response.headers_mut().insert(remaining_name, remaining_value);
212        }
213
214        // Add X-Rate-Limit-Reset header (Unix timestamp)
215        let reset_name = HeaderName::from_static("x-rate-limit-reset");
216        if let Ok(reset_value) = HeaderValue::from_str(&quota.reset.to_string()) {
217            response.headers_mut().insert(reset_name, reset_value);
218        }
219    }
220
221    response
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a config with the given rate and burst and both scoping flags off.
    fn config(requests_per_minute: u32, burst: u32) -> RateLimitConfig {
        RateLimitConfig {
            requests_per_minute,
            burst,
            per_ip: false,
            per_endpoint: false,
        }
    }

    // ==================== RateLimitConfig Tests ====================

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.burst, 200);
        assert!(config.per_ip);
        assert!(!config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_custom() {
        let config = RateLimitConfig {
            requests_per_minute: 50,
            burst: 100,
            per_ip: false,
            per_endpoint: true,
        };

        assert_eq!(config.requests_per_minute, 50);
        assert_eq!(config.burst, 100);
        assert!(!config.per_ip);
        assert!(config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_clone() {
        let config = RateLimitConfig {
            requests_per_minute: 75,
            burst: 150,
            per_ip: true,
            per_endpoint: true,
        };

        let cloned = config.clone();

        assert_eq!(cloned.requests_per_minute, config.requests_per_minute);
        assert_eq!(cloned.burst, config.burst);
        assert_eq!(cloned.per_ip, config.per_ip);
        assert_eq!(cloned.per_endpoint, config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_debug() {
        let debug_str = format!("{:?}", RateLimitConfig::default());

        for field in ["requests_per_minute", "burst", "per_ip", "per_endpoint"] {
            assert!(debug_str.contains(field), "Debug output should mention {}", field);
        }
    }

    // ==================== RateLimitQuota Tests ====================

    #[test]
    fn test_rate_limit_quota_creation() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        assert_eq!(quota.limit, 100);
        assert_eq!(quota.remaining, 50);
        assert_eq!(quota.reset, 1234567890);
    }

    #[test]
    fn test_rate_limit_quota_clone() {
        let quota = RateLimitQuota {
            limit: 200,
            remaining: 175,
            reset: 9876543210,
        };

        let cloned = quota.clone();

        assert_eq!(cloned.limit, quota.limit);
        assert_eq!(cloned.remaining, quota.remaining);
        assert_eq!(cloned.reset, quota.reset);
    }

    #[test]
    fn test_rate_limit_quota_debug() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        let debug_str = format!("{:?}", quota);

        for field in ["limit", "remaining", "reset"] {
            assert!(debug_str.contains(field), "Debug output should mention {}", field);
        }
    }

    // ==================== GlobalRateLimiter Tests ====================

    #[test]
    fn test_rate_limiter_creation() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());

        // Should allow first request
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_with_custom_config() {
        let limiter = GlobalRateLimiter::new(config(60, 10));
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_burst() {
        let limiter = GlobalRateLimiter::new(config(10, 5));

        // Should allow burst requests
        for _ in 0..5 {
            assert!(limiter.check_rate_limit(), "Burst request should be allowed");
        }

        // At 10 requests per minute the next cell is six seconds away, so an
        // immediate sixth request must be turned down.
        assert!(!limiter.check_rate_limit(), "Request beyond the burst should be rejected");
    }

    #[test]
    fn test_rate_limiter_multiple_requests() {
        let limiter = GlobalRateLimiter::new(config(1000, 100));

        // Should allow many requests within burst limit
        for i in 0..50 {
            assert!(limiter.check_rate_limit(), "Request {} should be allowed", i);
        }
    }

    #[test]
    fn test_get_quota_info() {
        let limiter = GlobalRateLimiter::new(config(100, 50));

        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 100);
        assert!(quota.remaining > 0);
        assert!(quota.reset > 0);
    }

    #[test]
    fn test_quota_info_limit_matches_config() {
        let limiter = GlobalRateLimiter::new(config(500, 100));
        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 500);
    }

    #[test]
    fn test_quota_decrements_remaining() {
        let limiter = GlobalRateLimiter::new(config(100, 50));

        let first_quota = limiter.get_quota_info();
        let second_quota = limiter.get_quota_info();

        // Each call accounts for exactly one request within the same window
        assert_eq!(
            second_quota.remaining + 1,
            first_quota.remaining,
            "Remaining should drop by one per call"
        );
    }

    #[test]
    fn test_quota_reset_timestamp_is_future() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());

        let quota = limiter.get_quota_info();

        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is after the Unix epoch")
            .as_secs();

        // Reset timestamp should be in the future (approximately 60 seconds from window start)
        assert!(quota.reset >= now, "Reset timestamp should be >= current time");
        assert!(quota.reset <= now + 120, "Reset timestamp should be within 2 minutes");
    }

    #[test]
    fn test_rate_limiter_high_burst() {
        // Very high burst compared with the sustained rate
        let limiter = GlobalRateLimiter::new(config(10, 1000));

        // Should allow many requests due to high burst
        for _ in 0..100 {
            assert!(limiter.check_rate_limit());
        }
    }

    #[test]
    fn test_rate_limiter_low_limit() {
        let limiter = GlobalRateLimiter::new(config(1, 1));

        // First request should succeed
        assert!(limiter.check_rate_limit());
        // With a burst of one, an immediate second request must be rejected
        assert!(!limiter.check_rate_limit());
    }

    // ==================== Edge Cases ====================

    #[test]
    fn test_config_with_zero_values_handled() {
        // Zero cannot be used as a quota; `GlobalRateLimiter::new` substitutes
        // its fallbacks (100 requests per minute, burst of 200) instead.
        let limiter = GlobalRateLimiter::new(config(0, 0));

        // Should not panic, and the fallback quota admits the first request
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_multiple_quota_calls_same_limiter() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());

        // Call get_quota_info multiple times
        let quotas: Vec<RateLimitQuota> = (0..5).map(|_| limiter.get_quota_info()).collect();

        // All should have same limit
        assert!(quotas.iter().all(|quota| quota.limit == 100));

        // Reset timestamps should be similar (within same window)
        let first_reset = quotas[0].reset;
        for quota in &quotas {
            assert!(
                quota.reset.abs_diff(first_reset) <= 1,
                "Reset timestamps should be within 1 second of each other"
            );
        }
    }

    #[test]
    fn test_quota_remaining_never_negative() {
        let limiter = GlobalRateLimiter::new(config(5, 5));

        // Call many more times than the limit allows
        let observed: Vec<u32> = (0..20).map(|_| limiter.get_quota_info().remaining).collect();

        // The counter starts at the limit, never exceeds it, and bottoms out
        // at zero instead of wrapping around.
        assert!(
            observed.iter().all(|&remaining| remaining <= 5),
            "Remaining should never exceed the limit"
        );
        assert_eq!(observed.last().copied(), Some(0), "Remaining should settle at zero");
    }
}