mockforge_http/middleware/
rate_limit.rs

1//! Global rate limiting middleware for HTTP server
2//!
3//! This module provides rate limiting to protect against abuse and DDoS attacks
4//! and adds production-like rate limit headers to responses.
5
6use axum::{
7    body::Body,
8    extract::{ConnectInfo, State},
9    http::{HeaderName, HeaderValue, Request, StatusCode},
10    middleware::Next,
11    response::Response,
12};
13use governor::{
14    clock::DefaultClock,
15    state::{InMemoryState, NotKeyed},
16    Quota, RateLimiter,
17};
18use std::net::SocketAddr;
19use std::num::NonZeroU32;
20use std::sync::{Arc, Mutex};
21use std::time::{Duration, SystemTime, UNIX_EPOCH};
22use tracing::warn;
23
/// Rate limiting configuration.
///
/// Consumed by [`GlobalRateLimiter::new`], which replaces a zero
/// `requests_per_minute` or `burst` with its default (100 and 200 respectively)
/// because `governor` quotas must be non-zero.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Sustained rate: number of requests replenished per minute.
    /// Also used as the `X-Rate-Limit-Limit` header value.
    pub requests_per_minute: u32,
    /// Maximum number of requests that may be admitted back-to-back
    /// (token-bucket capacity).
    pub burst: u32,
    /// Enable per-IP rate limiting.
    ///
    /// NOTE: [`GlobalRateLimiter`] does not read this flag; it enforces a single
    /// shared (unkeyed) limit for all clients.
    pub per_ip: bool,
    /// Enable per-endpoint rate limiting.
    ///
    /// NOTE: [`GlobalRateLimiter`] does not read this flag; it enforces a single
    /// shared (unkeyed) limit for all endpoints.
    pub per_endpoint: bool,
}

impl Default for RateLimitConfig {
    /// 100 requests per minute with a burst of 200, per-IP enabled and
    /// per-endpoint disabled.
    fn default() -> Self {
        Self {
            requests_per_minute: 100,
            burst: 200,
            per_ip: true,
            per_endpoint: false,
        }
    }
}
47
/// Snapshot of rate-limit state that is rendered into the
/// `X-Rate-Limit-*` response headers.
#[derive(Clone, Debug)]
pub struct RateLimitQuota {
    /// Requests allowed per minute (`X-Rate-Limit-Limit`).
    pub limit: u32,
    /// Approximate requests left in the current window
    /// (`X-Rate-Limit-Remaining`).
    pub remaining: u32,
    /// Unix time, in seconds, at which the current window resets
    /// (`X-Rate-Limit-Reset`).
    pub reset: u64,
}
58
59/// Global rate limiter state
60pub struct GlobalRateLimiter {
61    limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
62    config: RateLimitConfig,
63    /// Track window start time for reset calculation
64    window_start: Arc<Mutex<SystemTime>>,
65    /// Track approximate remaining requests
66    remaining_counter: Arc<Mutex<u32>>,
67}
68
69impl GlobalRateLimiter {
70    /// Create a new global rate limiter
71    pub fn new(config: RateLimitConfig) -> Self {
72        let quota = Quota::per_minute(
73            NonZeroU32::new(config.requests_per_minute).unwrap_or(NonZeroU32::new(100).unwrap()),
74        )
75        .allow_burst(NonZeroU32::new(config.burst).unwrap_or(NonZeroU32::new(200).unwrap()));
76
77        let limiter = Arc::new(RateLimiter::direct(quota));
78        let window_start = Arc::new(Mutex::new(SystemTime::now()));
79        let remaining_counter = Arc::new(Mutex::new(config.requests_per_minute));
80
81        Self {
82            limiter,
83            config,
84            window_start,
85            remaining_counter,
86        }
87    }
88
89    /// Check if request should be rate limited
90    pub fn check_rate_limit(&self) -> bool {
91        self.limiter.check().is_ok()
92    }
93
94    /// Get rate limit quota information for headers
95    ///
96    /// Returns information about the current rate limit state including
97    /// limit, remaining requests, and reset timestamp.
98    pub fn get_quota_info(&self) -> RateLimitQuota {
99        let now = SystemTime::now();
100        let mut window_start = self.window_start.lock().unwrap();
101        let mut remaining = self.remaining_counter.lock().unwrap();
102
103        // Check if we need to reset the window (every minute)
104        let window_duration = Duration::from_secs(60);
105        if now.duration_since(*window_start).unwrap_or(Duration::ZERO) >= window_duration {
106            // Reset window
107            *window_start = now;
108            *remaining = self.config.requests_per_minute;
109        }
110
111        // Decrement remaining if we successfully checked (approximate)
112        // Note: This is approximate because governor's token bucket
113        // may have different internal state, but it's good enough for headers
114        let current_remaining = *remaining;
115        if current_remaining > 0 {
116            *remaining = current_remaining.saturating_sub(1);
117        }
118
119        // Calculate reset timestamp (start of next window)
120        let reset_timestamp =
121            window_start.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_secs() + 60; // Add 60 seconds for next window
122
123        RateLimitQuota {
124            limit: self.config.requests_per_minute,
125            remaining: current_remaining,
126            reset: reset_timestamp,
127        }
128    }
129}
130
131/// Rate limiting middleware
132///
133/// This middleware:
134/// 1. Checks if the request should be rate limited
135/// 2. Adds rate limit headers to successful responses (for deceptive deploy)
136pub async fn rate_limit_middleware(
137    State(state): axum::extract::State<crate::HttpServerState>,
138    ConnectInfo(addr): ConnectInfo<SocketAddr>,
139    req: Request<Body>,
140    next: Next,
141) -> Result<Response, StatusCode> {
142    // Get rate limiter from app state
143    let quota_info = if let Some(limiter) = &state.rate_limiter {
144        // Check rate limit
145        if !limiter.check_rate_limit() {
146            warn!("Rate limit exceeded for IP: {}", addr.ip());
147            return Err(StatusCode::TOO_MANY_REQUESTS);
148        }
149
150        // Get quota information for headers
151        Some(limiter.get_quota_info())
152    } else {
153        // No rate limiter configured, allow request to proceed
154        tracing::debug!("No rate limiter configured, allowing request");
155        None
156    };
157
158    // Process request and get response
159    let mut response = next.run(req).await;
160
161    // Add rate limit headers to response if we have quota info
162    // This makes the mock API look more like production
163    if let Some(quota) = quota_info {
164        // Add X-Rate-Limit-Limit header
165        let limit_name = HeaderName::from_static("x-rate-limit-limit");
166        if let Ok(limit_value) = HeaderValue::from_str(&quota.limit.to_string()) {
167            response.headers_mut().insert(limit_name, limit_value);
168        }
169
170        // Add X-Rate-Limit-Remaining header
171        let remaining_name = HeaderName::from_static("x-rate-limit-remaining");
172        if let Ok(remaining_value) = HeaderValue::from_str(&quota.remaining.to_string()) {
173            response.headers_mut().insert(remaining_name, remaining_value);
174        }
175
176        // Add X-Rate-Limit-Reset header (Unix timestamp)
177        let reset_name = HeaderName::from_static("x-rate-limit-reset");
178        if let Ok(reset_value) = HeaderValue::from_str(&quota.reset.to_string()) {
179            response.headers_mut().insert(reset_name, reset_value);
180        }
181    }
182
183    Ok(response)
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    // ==================== RateLimitConfig Tests ====================

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.burst, 200);
        assert!(config.per_ip);
        assert!(!config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_custom() {
        let config = RateLimitConfig {
            requests_per_minute: 50,
            burst: 100,
            per_ip: false,
            per_endpoint: true,
        };

        assert_eq!(config.requests_per_minute, 50);
        assert_eq!(config.burst, 100);
        assert!(!config.per_ip);
        assert!(config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_clone() {
        let config = RateLimitConfig {
            requests_per_minute: 75,
            burst: 150,
            per_ip: true,
            per_endpoint: true,
        };

        let cloned = config.clone();

        assert_eq!(cloned.requests_per_minute, config.requests_per_minute);
        assert_eq!(cloned.burst, config.burst);
        assert_eq!(cloned.per_ip, config.per_ip);
        assert_eq!(cloned.per_endpoint, config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_debug() {
        let config = RateLimitConfig::default();
        let debug_str = format!("{:?}", config);

        assert!(debug_str.contains("requests_per_minute"));
        assert!(debug_str.contains("burst"));
        assert!(debug_str.contains("per_ip"));
        assert!(debug_str.contains("per_endpoint"));
    }

    // ==================== RateLimitQuota Tests ====================

    #[test]
    fn test_rate_limit_quota_creation() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        assert_eq!(quota.limit, 100);
        assert_eq!(quota.remaining, 50);
        assert_eq!(quota.reset, 1234567890);
    }

    #[test]
    fn test_rate_limit_quota_clone() {
        let quota = RateLimitQuota {
            limit: 200,
            remaining: 175,
            reset: 9876543210,
        };

        let cloned = quota.clone();

        assert_eq!(cloned.limit, quota.limit);
        assert_eq!(cloned.remaining, quota.remaining);
        assert_eq!(cloned.reset, quota.reset);
    }

    #[test]
    fn test_rate_limit_quota_debug() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        let debug_str = format!("{:?}", quota);

        assert!(debug_str.contains("limit"));
        assert!(debug_str.contains("remaining"));
        assert!(debug_str.contains("reset"));
    }

    // ==================== GlobalRateLimiter Tests ====================

    #[test]
    fn test_rate_limiter_creation() {
        let config = RateLimitConfig::default();
        let limiter = GlobalRateLimiter::new(config);

        // Should allow first request
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_with_custom_config() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst: 10,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_burst() {
        let config = RateLimitConfig {
            requests_per_minute: 10,
            burst: 5,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // Should allow burst requests
        for _ in 0..5 {
            assert!(limiter.check_rate_limit(), "Burst request should be allowed");
        }
    }

    #[test]
    fn test_rate_limiter_multiple_requests() {
        let config = RateLimitConfig {
            requests_per_minute: 1000,
            burst: 100,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // Should allow many requests within burst limit
        for i in 0..50 {
            assert!(limiter.check_rate_limit(), "Request {} should be allowed", i);
        }
    }

    #[test]
    fn test_get_quota_info() {
        let config = RateLimitConfig {
            requests_per_minute: 100,
            burst: 50,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 100);
        assert!(quota.remaining > 0);
        assert!(quota.remaining <= quota.limit, "Remaining should not exceed the limit");
        assert!(quota.reset > 0);
    }

    #[test]
    fn test_quota_info_limit_matches_config() {
        let config = RateLimitConfig {
            requests_per_minute: 500,
            burst: 100,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);
        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 500);
    }

    #[test]
    fn test_quota_decrements_remaining() {
        let config = RateLimitConfig {
            requests_per_minute: 100,
            burst: 50,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        let first_quota = limiter.get_quota_info();
        let second_quota = limiter.get_quota_info();

        // Each call counts one request, so remaining drops by exactly one
        // (both calls land in the same one-minute window).
        assert_eq!(
            second_quota.remaining,
            first_quota.remaining - 1,
            "Remaining should decrement by one per call"
        );
    }

    #[test]
    fn test_quota_reset_timestamp_is_future() {
        let config = RateLimitConfig::default();
        let limiter = GlobalRateLimiter::new(config);

        let quota = limiter.get_quota_info();

        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();

        // Reset timestamp should be in the future (approximately 60 seconds from window start)
        assert!(quota.reset >= now, "Reset timestamp should be >= current time");
        assert!(quota.reset <= now + 120, "Reset timestamp should be within 2 minutes");
    }

    #[test]
    fn test_rate_limiter_high_burst() {
        let config = RateLimitConfig {
            requests_per_minute: 10,
            burst: 1000, // Very high burst
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // Should allow many requests due to high burst
        for _ in 0..100 {
            assert!(limiter.check_rate_limit());
        }
    }

    #[test]
    fn test_rate_limiter_low_limit() {
        let config = RateLimitConfig {
            requests_per_minute: 1,
            burst: 1,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // First request should succeed
        assert!(limiter.check_rate_limit());
    }

    // ==================== Edge Cases ====================

    #[test]
    fn test_config_with_zero_values_handled() {
        // `governor` quotas must be non-zero, so `GlobalRateLimiter::new`
        // substitutes defaults for zero values instead of panicking.
        let config = RateLimitConfig {
            requests_per_minute: 0, // Falls back to 100
            burst: 0,               // Falls back to 200
            per_ip: false,
            per_endpoint: false,
        };

        // Should not panic - NonZeroU32::new(0) returns None and the fallback is used
        let limiter = GlobalRateLimiter::new(config);
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_multiple_quota_calls_same_limiter() {
        let config = RateLimitConfig::default();
        let limiter = GlobalRateLimiter::new(config);

        // Call get_quota_info multiple times
        let quotas: Vec<RateLimitQuota> = (0..5).map(|_| limiter.get_quota_info()).collect();

        // All should have same limit
        for quota in &quotas {
            assert_eq!(quota.limit, 100);
        }

        // Reset timestamps should be similar (within same window)
        let first_reset = quotas[0].reset;
        for quota in &quotas {
            assert!(
                quota.reset.abs_diff(first_reset) <= 1,
                "Reset timestamps should be within 1 second of each other"
            );
        }
    }

    #[test]
    fn test_quota_remaining_never_negative() {
        let config = RateLimitConfig {
            requests_per_minute: 5,
            burst: 5,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // Call far more times than the limit to exhaust the window
        let mut previous = u32::MAX;
        for _ in 0..20 {
            let quota = limiter.get_quota_info();
            assert!(quota.remaining <= quota.limit, "Remaining should never exceed the limit");
            assert!(quota.remaining <= previous, "Remaining should never increase within a window");
            previous = quota.remaining;
        }

        // The counter saturates at zero instead of wrapping around
        assert_eq!(previous, 0, "Remaining should bottom out at zero");
    }
}