Skip to main content

mockforge_http/middleware/
rate_limit.rs

//! Global rate limiting middleware for the HTTP server.
//!
//! This module provides rate limiting to protect against abuse and DDoS attacks,
//! and adds production-like rate limit headers to responses.

6use axum::{
7    body::Body,
8    extract::{ConnectInfo, State},
9    http::{HeaderName, HeaderValue, Request, StatusCode},
10    middleware::Next,
11    response::Response,
12};
13use governor::{
14    clock::DefaultClock,
15    state::{InMemoryState, NotKeyed},
16    Quota, RateLimiter,
17};
18use std::net::SocketAddr;
19use std::num::NonZeroU32;
20use std::sync::{Arc, Mutex};
21use std::time::{Duration, SystemTime, UNIX_EPOCH};
22use tracing::warn;
23
/// Reports whether the operator has switched HTTP rate limiting off.
///
/// Two environment variables are consulted; their values are trimmed and
/// compared case-insensitively:
///
/// * `MOCKFORGE_RATE_LIMIT_ENABLED` — the preferred knob, in line with the
///   `MOCKFORGE_LATENCY_ENABLED` / `MOCKFORGE_FAILURES_ENABLED` family. Any of
///   `0`, `false`, `no`, `off` disables the limiter.
/// * `MOCKFORGE_RATE_LIMIT_DISABLED` — shorthand alias. Any of `1`, `true`,
///   `yes`, `on` disables the limiter.
///
/// Unset or unrecognised values leave the limiter enabled. The
/// `--no-rate-limit` CLI flag is routed through these same variables.
///
/// A `true` result means the router does not wire a `GlobalRateLimiter` into
/// its state, and [`rate_limit_middleware`] then forwards every request.
pub fn is_rate_limit_disabled() -> bool {
    // Normalised value of an env var, or `None` when it is unset (or not valid Unicode).
    let read = |name: &str| std::env::var(name).ok().map(|raw| raw.trim().to_ascii_lowercase());

    let enabled_switched_off = matches!(
        read("MOCKFORGE_RATE_LIMIT_ENABLED").as_deref(),
        Some("0" | "false" | "no" | "off")
    );
    let disabled_switched_on = matches!(
        read("MOCKFORGE_RATE_LIMIT_DISABLED").as_deref(),
        Some("1" | "true" | "yes" | "on")
    );

    enabled_switched_off || disabled_switched_on
}

/// Rate limiting configuration.
///
/// Zero values for `requests_per_minute` / `burst` are not rejected here;
/// [`GlobalRateLimiter::new`] substitutes 100 and 200 respectively.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Sustained rate: requests allowed per minute (`0` falls back to 100).
    pub requests_per_minute: u32,
    /// Burst capacity handed to governor's `Quota::allow_burst`
    /// (`0` falls back to 200).
    pub burst: u32,
    /// Enable per-IP rate limiting.
    ///
    /// NOTE: `GlobalRateLimiter` in this module does not read this flag; it
    /// enforces a single unkeyed quota shared by all clients.
    pub per_ip: bool,
    /// Enable per-endpoint rate limiting.
    ///
    /// NOTE: like `per_ip`, this flag is not read by `GlobalRateLimiter`.
    pub per_endpoint: bool,
}

impl Default for RateLimitConfig {
    /// 100 requests per minute with a burst of 200, `per_ip` on, `per_endpoint` off.
    fn default() -> Self {
        Self {
            requests_per_minute: 100,
            burst: 200,
            per_ip: true,
            per_endpoint: false,
        }
    }
}

/// Snapshot of the rate limit state advertised in `X-Rate-Limit-*` headers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitQuota {
    /// Maximum requests per minute (`X-Rate-Limit-Limit`).
    pub limit: u32,
    /// Approximate requests left in the current window (`X-Rate-Limit-Remaining`).
    ///
    /// Tracked separately from governor's token bucket, so it can disagree
    /// with what the limiter actually allows.
    pub remaining: u32,
    /// Unix timestamp, in seconds, at which the current window ends
    /// (`X-Rate-Limit-Reset`).
    pub reset: u64,
}

89/// Global rate limiter state
90pub struct GlobalRateLimiter {
91    limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
92    config: RateLimitConfig,
93    /// Track window start time for reset calculation
94    window_start: Arc<Mutex<SystemTime>>,
95    /// Track approximate remaining requests
96    remaining_counter: Arc<Mutex<u32>>,
97}
98
99impl GlobalRateLimiter {
100    /// Create a new global rate limiter
101    pub fn new(config: RateLimitConfig) -> Self {
102        let quota = Quota::per_minute(
103            NonZeroU32::new(config.requests_per_minute)
104                .unwrap_or(NonZeroU32::new(100).expect("constant 100 is non-zero")),
105        )
106        .allow_burst(
107            NonZeroU32::new(config.burst)
108                .unwrap_or(NonZeroU32::new(200).expect("constant 200 is non-zero")),
109        );
110
111        let limiter = Arc::new(RateLimiter::direct(quota));
112        let window_start = Arc::new(Mutex::new(SystemTime::now()));
113        let remaining_counter = Arc::new(Mutex::new(config.requests_per_minute));
114
115        Self {
116            limiter,
117            config,
118            window_start,
119            remaining_counter,
120        }
121    }
122
123    /// Check if request should be rate limited
124    pub fn check_rate_limit(&self) -> bool {
125        self.limiter.check().is_ok()
126    }
127
128    /// Get rate limit quota information for headers
129    ///
130    /// Returns information about the current rate limit state including
131    /// limit, remaining requests, and reset timestamp.
132    pub fn get_quota_info(&self) -> RateLimitQuota {
133        let now = SystemTime::now();
134        let mut window_start =
135            self.window_start.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
136        let mut remaining =
137            self.remaining_counter.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
138
139        // Check if we need to reset the window (every minute)
140        let window_duration = Duration::from_secs(60);
141        if now.duration_since(*window_start).unwrap_or(Duration::ZERO) >= window_duration {
142            // Reset window
143            *window_start = now;
144            *remaining = self.config.requests_per_minute;
145        }
146
147        // Decrement remaining if we successfully checked (approximate)
148        // Note: This is approximate because governor's token bucket
149        // may have different internal state, but it's good enough for headers
150        let current_remaining = *remaining;
151        if current_remaining > 0 {
152            *remaining = current_remaining.saturating_sub(1);
153        }
154
155        // Calculate reset timestamp (start of next window)
156        let reset_timestamp =
157            window_start.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_secs() + 60; // Add 60 seconds for next window
158
159        RateLimitQuota {
160            limit: self.config.requests_per_minute,
161            remaining: current_remaining,
162            reset: reset_timestamp,
163        }
164    }
165}
166
167/// Rate limiting middleware
168///
169/// This middleware:
170/// 1. Checks if the request should be rate limited
171/// 2. Adds rate limit headers to successful responses (for deceptive deploy)
172/// 3. Returns 429 with Retry-After header when rate limited
173pub async fn rate_limit_middleware(
174    State(state): State<crate::HttpServerState>,
175    ConnectInfo(addr): ConnectInfo<SocketAddr>,
176    req: Request<Body>,
177    next: Next,
178) -> Response {
179    // Get rate limiter from app state
180    let quota_info = if let Some(limiter) = &state.rate_limiter {
181        // Check rate limit
182        if !limiter.check_rate_limit() {
183            warn!("Rate limit exceeded for IP: {}", addr.ip());
184            // Return 429 with Retry-After header per HTTP spec
185            let mut response = Response::builder()
186                .status(StatusCode::TOO_MANY_REQUESTS)
187                .body(Body::from("Too Many Requests"))
188                .unwrap_or_else(|_| Response::new(Body::from("Too Many Requests")));
189
190            // Add Retry-After header (60 seconds = 1 minute window)
191            let retry_after = HeaderValue::from_static("60");
192            response
193                .headers_mut()
194                .insert(HeaderName::from_static("retry-after"), retry_after);
195
196            // Add rate limit headers to the 429 response
197            let quota = limiter.get_quota_info();
198            if let Ok(limit_value) = HeaderValue::from_str(&quota.limit.to_string()) {
199                response
200                    .headers_mut()
201                    .insert(HeaderName::from_static("x-rate-limit-limit"), limit_value);
202            }
203            if let Ok(remaining_value) = HeaderValue::from_str("0") {
204                response
205                    .headers_mut()
206                    .insert(HeaderName::from_static("x-rate-limit-remaining"), remaining_value);
207            }
208            if let Ok(reset_value) = HeaderValue::from_str(&quota.reset.to_string()) {
209                response
210                    .headers_mut()
211                    .insert(HeaderName::from_static("x-rate-limit-reset"), reset_value);
212            }
213
214            return response;
215        }
216
217        // Get quota information for headers
218        Some(limiter.get_quota_info())
219    } else {
220        // No rate limiter configured, allow request to proceed
221        tracing::debug!("No rate limiter configured, allowing request");
222        None
223    };
224
225    // Process request and get response
226    let mut response = next.run(req).await;
227
228    // Add rate limit headers to response if we have quota info
229    // This makes the mock API look more like production
230    if let Some(quota) = quota_info {
231        // Add X-Rate-Limit-Limit header
232        let limit_name = HeaderName::from_static("x-rate-limit-limit");
233        if let Ok(limit_value) = HeaderValue::from_str(&quota.limit.to_string()) {
234            response.headers_mut().insert(limit_name, limit_value);
235        }
236
237        // Add X-Rate-Limit-Remaining header
238        let remaining_name = HeaderName::from_static("x-rate-limit-remaining");
239        if let Ok(remaining_value) = HeaderValue::from_str(&quota.remaining.to_string()) {
240            response.headers_mut().insert(remaining_name, remaining_value);
241        }
242
243        // Add X-Rate-Limit-Reset header (Unix timestamp)
244        let reset_name = HeaderName::from_static("x-rate-limit-reset");
245        if let Ok(reset_value) = HeaderValue::from_str(&quota.reset.to_string()) {
246            response.headers_mut().insert(reset_name, reset_value);
247        }
248    }
249
250    response
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    // ==================== is_rate_limit_disabled Tests ====================

    const ENABLED_VAR: &str = "MOCKFORGE_RATE_LIMIT_ENABLED";
    const DISABLED_VAR: &str = "MOCKFORGE_RATE_LIMIT_DISABLED";

    /// Captured environment variables, written back when dropped.
    ///
    /// This restores the process environment even if an assertion panics
    /// mid-test. `var_os` is used so non-Unicode values survive the round trip.
    struct EnvSnapshot(Vec<(&'static str, Option<std::ffi::OsString>)>);

    impl EnvSnapshot {
        fn capture(names: &[&'static str]) -> Self {
            Self(names.iter().map(|&name| (name, std::env::var_os(name))).collect())
        }
    }

    impl Drop for EnvSnapshot {
        fn drop(&mut self) {
            for (name, value) in &self.0 {
                match value {
                    Some(v) => std::env::set_var(name, v),
                    None => std::env::remove_var(name),
                }
            }
        }
    }

    /// Walks through both env vars and their accepted spellings.
    ///
    /// The snapshot restores the original values afterwards, including on
    /// panic. This is hygiene only. The process environment is global, so any
    /// other test that mutates these variables must not run concurrently with
    /// this one.
    #[test]
    fn test_is_rate_limit_disabled_env_vars() {
        let _restore = EnvSnapshot::capture(&[ENABLED_VAR, DISABLED_VAR]);

        // Default (unset): limiter is on
        std::env::remove_var(ENABLED_VAR);
        std::env::remove_var(DISABLED_VAR);
        assert!(!is_rate_limit_disabled());

        // ENABLED=false → disabled
        for v in ["false", "0", "no", "off", "FALSE", "  False  "] {
            std::env::set_var(ENABLED_VAR, v);
            assert!(is_rate_limit_disabled(), "ENABLED={v:?} should disable");
        }
        std::env::remove_var(ENABLED_VAR);

        // ENABLED=true → not disabled
        std::env::set_var(ENABLED_VAR, "true");
        assert!(!is_rate_limit_disabled());
        // Unrecognised values keep the limiter on
        std::env::set_var(ENABLED_VAR, "maybe");
        assert!(!is_rate_limit_disabled());
        std::env::remove_var(ENABLED_VAR);

        // DISABLED=true → disabled (alias)
        for v in ["true", "1", "yes", "on", "TRUE"] {
            std::env::set_var(DISABLED_VAR, v);
            assert!(is_rate_limit_disabled(), "DISABLED={v:?} should disable");
        }
        std::env::remove_var(DISABLED_VAR);

        // DISABLED=false → not disabled
        std::env::set_var(DISABLED_VAR, "false");
        assert!(!is_rate_limit_disabled());
        std::env::remove_var(DISABLED_VAR);
    }

    // ==================== RateLimitConfig Tests ====================

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.burst, 200);
        assert!(config.per_ip);
        assert!(!config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_custom() {
        let config = RateLimitConfig {
            requests_per_minute: 50,
            burst: 100,
            per_ip: false,
            per_endpoint: true,
        };

        assert_eq!(config.requests_per_minute, 50);
        assert_eq!(config.burst, 100);
        assert!(!config.per_ip);
        assert!(config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_clone() {
        let config = RateLimitConfig {
            requests_per_minute: 75,
            burst: 150,
            per_ip: true,
            per_endpoint: true,
        };

        let cloned = config.clone();

        assert_eq!(cloned.requests_per_minute, config.requests_per_minute);
        assert_eq!(cloned.burst, config.burst);
        assert_eq!(cloned.per_ip, config.per_ip);
        assert_eq!(cloned.per_endpoint, config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_debug() {
        let config = RateLimitConfig::default();
        let debug_str = format!("{:?}", config);

        assert!(debug_str.contains("requests_per_minute"));
        assert!(debug_str.contains("burst"));
        assert!(debug_str.contains("per_ip"));
        assert!(debug_str.contains("per_endpoint"));
    }

    // ==================== RateLimitQuota Tests ====================

    #[test]
    fn test_rate_limit_quota_creation() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        assert_eq!(quota.limit, 100);
        assert_eq!(quota.remaining, 50);
        assert_eq!(quota.reset, 1234567890);
    }

    #[test]
    fn test_rate_limit_quota_clone() {
        let quota = RateLimitQuota {
            limit: 200,
            remaining: 175,
            reset: 9876543210,
        };

        let cloned = quota.clone();

        assert_eq!(cloned.limit, quota.limit);
        assert_eq!(cloned.remaining, quota.remaining);
        assert_eq!(cloned.reset, quota.reset);
    }

    #[test]
    fn test_rate_limit_quota_debug() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        let debug_str = format!("{:?}", quota);

        assert!(debug_str.contains("limit"));
        assert!(debug_str.contains("remaining"));
        assert!(debug_str.contains("reset"));
    }

    // ==================== GlobalRateLimiter Tests ====================

    #[test]
    fn test_rate_limiter_creation() {
        let config = RateLimitConfig::default();
        let limiter = GlobalRateLimiter::new(config);

        // Should allow first request
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_with_custom_config() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst: 10,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_burst() {
        let config = RateLimitConfig {
            requests_per_minute: 10,
            burst: 5,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // Should allow burst requests
        for _ in 0..5 {
            assert!(limiter.check_rate_limit(), "Burst request should be allowed");
        }
        // At 10/min a cell replenishes every 6 s, so the next request is rejected
        assert!(!limiter.check_rate_limit(), "Request beyond burst should be rejected");
    }

    #[test]
    fn test_rate_limiter_multiple_requests() {
        let config = RateLimitConfig {
            requests_per_minute: 1000,
            burst: 100,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // Should allow many requests within burst limit
        for i in 0..50 {
            assert!(limiter.check_rate_limit(), "Request {} should be allowed", i);
        }
    }

    #[test]
    fn test_get_quota_info() {
        let config = RateLimitConfig {
            requests_per_minute: 100,
            burst: 50,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 100);
        assert!(quota.remaining > 0);
        assert!(quota.reset > 0);
    }

    #[test]
    fn test_quota_info_limit_matches_config() {
        let config = RateLimitConfig {
            requests_per_minute: 500,
            burst: 100,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);
        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 500);
    }

    #[test]
    fn test_quota_decrements_remaining() {
        let config = RateLimitConfig {
            requests_per_minute: 100,
            burst: 50,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        let first_quota = limiter.get_quota_info();
        let second_quota = limiter.get_quota_info();

        // Each call counts a request, so remaining must strictly decrease
        assert!(
            second_quota.remaining < first_quota.remaining,
            "Remaining should decrement between calls"
        );
    }

    #[test]
    fn test_quota_reset_timestamp_is_future() {
        let config = RateLimitConfig::default();
        let limiter = GlobalRateLimiter::new(config);

        let quota = limiter.get_quota_info();

        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();

        // Reset timestamp should be in the future (approximately 60 seconds from window start)
        assert!(quota.reset >= now, "Reset timestamp should be >= current time");
        assert!(quota.reset <= now + 120, "Reset timestamp should be within 2 minutes");
    }

    #[test]
    fn test_rate_limiter_high_burst() {
        let config = RateLimitConfig {
            requests_per_minute: 10,
            burst: 1000, // Very high burst
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // Should allow many requests due to high burst
        for _ in 0..100 {
            assert!(limiter.check_rate_limit());
        }
    }

    #[test]
    fn test_rate_limiter_low_limit() {
        let config = RateLimitConfig {
            requests_per_minute: 1,
            burst: 1,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // First request should succeed
        assert!(limiter.check_rate_limit());
        // With a burst of one and 1/min refill, an immediate second request is rejected
        assert!(!limiter.check_rate_limit(), "Second request should be rate limited");
    }

    // ==================== Edge Cases ====================

    #[test]
    fn test_config_with_zero_values_handled() {
        // Zero values should be handled gracefully by governor
        let config = RateLimitConfig {
            requests_per_minute: 0, // Will use fallback (100)
            burst: 0,               // Will use fallback (200)
            per_ip: false,
            per_endpoint: false,
        };

        // Should not panic - NonZeroU32::new(0) returns None, unwrap_or handles it
        let limiter = GlobalRateLimiter::new(config);
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_multiple_quota_calls_same_limiter() {
        let config = RateLimitConfig::default();
        let limiter = GlobalRateLimiter::new(config);

        // Call get_quota_info multiple times
        let quotas: Vec<RateLimitQuota> = (0..5).map(|_| limiter.get_quota_info()).collect();

        // All should have same limit
        for quota in &quotas {
            assert_eq!(quota.limit, 100);
        }

        // Reset timestamps should be similar (within same window)
        let first_reset = quotas[0].reset;
        for quota in &quotas {
            assert!(
                (quota.reset as i64 - first_reset as i64).abs() <= 1,
                "Reset timestamps should be within 1 second of each other"
            );
        }
    }

    #[test]
    fn test_quota_remaining_never_negative() {
        let config = RateLimitConfig {
            requests_per_minute: 5,
            burst: 5,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // Call many more times than the per-minute limit to exhaust the counter
        let mut previous = u32::MAX;
        for _ in 0..20 {
            let quota = limiter.get_quota_info();
            assert!(
                quota.remaining <= previous,
                "Remaining should never increase within a window"
            );
            previous = quota.remaining;
        }
        // The exhausted counter saturates at zero instead of wrapping around
        assert_eq!(previous, 0, "Remaining should bottom out at zero");
    }
}