// mockforge_http/middleware/rate_limit.rs
use axum::{
    body::Body,
    extract::{ConnectInfo, State},
    http::{HeaderName, HeaderValue, Request, StatusCode},
    middleware::Next,
    response::Response,
};
use governor::{
    clock::DefaultClock,
    state::{InMemoryState, NotKeyed},
    Quota, RateLimiter,
};
use std::net::SocketAddr;
use std::num::NonZeroU32;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tracing::warn;
23
/// Rate limiting settings for the HTTP server.
///
/// NOTE: [`GlobalRateLimiter`] in this module only reads `requests_per_minute` and `burst`.
/// It is un-keyed, so `per_ip` and `per_endpoint` are carried but not consulted by it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Sustained number of requests allowed per minute.
    /// `GlobalRateLimiter::new` substitutes a fallback when this is `0`.
    pub requests_per_minute: u32,
    /// Maximum number of requests that may be admitted back-to-back.
    /// `GlobalRateLimiter::new` substitutes a fallback when this is `0`.
    pub burst: u32,
    /// Whether limits are intended to apply per client IP.
    pub per_ip: bool,
    /// Whether limits are intended to apply per endpoint.
    pub per_endpoint: bool,
}
36
37impl Default for RateLimitConfig {
38 fn default() -> Self {
39 Self {
40 requests_per_minute: 100,
41 burst: 200,
42 per_ip: true,
43 per_endpoint: false,
44 }
45 }
46}
47
/// Snapshot of the advisory quota that is reported to clients in `x-rate-limit-*` headers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitQuota {
    /// Requests permitted per window.
    pub limit: u32,
    /// Requests left in the current window.
    pub remaining: u32,
    /// Unix timestamp, in seconds, at which the current window ends.
    pub reset: u64,
}
58
59pub struct GlobalRateLimiter {
61 limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
62 config: RateLimitConfig,
63 window_start: Arc<Mutex<SystemTime>>,
65 remaining_counter: Arc<Mutex<u32>>,
67}
68
69impl GlobalRateLimiter {
70 pub fn new(config: RateLimitConfig) -> Self {
72 let quota = Quota::per_minute(
73 NonZeroU32::new(config.requests_per_minute).unwrap_or(NonZeroU32::new(100).unwrap()),
74 )
75 .allow_burst(NonZeroU32::new(config.burst).unwrap_or(NonZeroU32::new(200).unwrap()));
76
77 let limiter = Arc::new(RateLimiter::direct(quota));
78 let window_start = Arc::new(Mutex::new(SystemTime::now()));
79 let remaining_counter = Arc::new(Mutex::new(config.requests_per_minute));
80
81 Self {
82 limiter,
83 config,
84 window_start,
85 remaining_counter,
86 }
87 }
88
89 pub fn check_rate_limit(&self) -> bool {
91 self.limiter.check().is_ok()
92 }
93
94 pub fn get_quota_info(&self) -> RateLimitQuota {
99 let now = SystemTime::now();
100 let mut window_start = self.window_start.lock().unwrap();
101 let mut remaining = self.remaining_counter.lock().unwrap();
102
103 let window_duration = Duration::from_secs(60);
105 if now.duration_since(*window_start).unwrap_or(Duration::ZERO) >= window_duration {
106 *window_start = now;
108 *remaining = self.config.requests_per_minute;
109 }
110
111 let current_remaining = *remaining;
115 if current_remaining > 0 {
116 *remaining = current_remaining.saturating_sub(1);
117 }
118
119 let reset_timestamp =
121 window_start.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_secs() + 60; RateLimitQuota {
124 limit: self.config.requests_per_minute,
125 remaining: current_remaining,
126 reset: reset_timestamp,
127 }
128 }
129}
130
131pub async fn rate_limit_middleware(
137 State(state): axum::extract::State<crate::HttpServerState>,
138 ConnectInfo(addr): ConnectInfo<SocketAddr>,
139 req: Request<Body>,
140 next: Next,
141) -> Result<Response, StatusCode> {
142 let quota_info = if let Some(limiter) = &state.rate_limiter {
144 if !limiter.check_rate_limit() {
146 warn!("Rate limit exceeded for IP: {}", addr.ip());
147 return Err(StatusCode::TOO_MANY_REQUESTS);
148 }
149
150 Some(limiter.get_quota_info())
152 } else {
153 tracing::debug!("No rate limiter configured, allowing request");
155 None
156 };
157
158 let mut response = next.run(req).await;
160
161 if let Some(quota) = quota_info {
164 let limit_name = HeaderName::from_static("x-rate-limit-limit");
166 if let Ok(limit_value) = HeaderValue::from_str("a.limit.to_string()) {
167 response.headers_mut().insert(limit_name, limit_value);
168 }
169
170 let remaining_name = HeaderName::from_static("x-rate-limit-remaining");
172 if let Ok(remaining_value) = HeaderValue::from_str("a.remaining.to_string()) {
173 response.headers_mut().insert(remaining_name, remaining_value);
174 }
175
176 let reset_name = HeaderName::from_static("x-rate-limit-reset");
178 if let Ok(reset_value) = HeaderValue::from_str("a.reset.to_string()) {
179 response.headers_mut().insert(reset_name, reset_value);
180 }
181 }
182
183 Ok(response)
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Config for limiter tests that do not care about the keying flags.
    fn limiter_config(requests_per_minute: u32, burst: u32) -> RateLimitConfig {
        RateLimitConfig {
            requests_per_minute,
            burst,
            per_ip: false,
            per_endpoint: false,
        }
    }

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.burst, 200);
        assert!(config.per_ip);
        assert!(!config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_custom() {
        let config = RateLimitConfig {
            requests_per_minute: 50,
            burst: 100,
            per_ip: false,
            per_endpoint: true,
        };

        assert_eq!(config.requests_per_minute, 50);
        assert_eq!(config.burst, 100);
        assert!(!config.per_ip);
        assert!(config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_clone() {
        let config = RateLimitConfig {
            requests_per_minute: 75,
            burst: 150,
            per_ip: true,
            per_endpoint: true,
        };

        let cloned = config.clone();

        assert_eq!(cloned.requests_per_minute, config.requests_per_minute);
        assert_eq!(cloned.burst, config.burst);
        assert_eq!(cloned.per_ip, config.per_ip);
        assert_eq!(cloned.per_endpoint, config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_debug() {
        let config = RateLimitConfig::default();
        let debug_str = format!("{:?}", config);

        assert!(debug_str.contains("requests_per_minute"));
        assert!(debug_str.contains("burst"));
        assert!(debug_str.contains("per_ip"));
        assert!(debug_str.contains("per_endpoint"));
    }

    #[test]
    fn test_rate_limit_quota_creation() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        assert_eq!(quota.limit, 100);
        assert_eq!(quota.remaining, 50);
        assert_eq!(quota.reset, 1234567890);
    }

    #[test]
    fn test_rate_limit_quota_clone() {
        let quota = RateLimitQuota {
            limit: 200,
            remaining: 175,
            reset: 9876543210,
        };

        let cloned = quota.clone();

        assert_eq!(cloned.limit, quota.limit);
        assert_eq!(cloned.remaining, quota.remaining);
        assert_eq!(cloned.reset, quota.reset);
    }

    #[test]
    fn test_rate_limit_quota_debug() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        let debug_str = format!("{:?}", quota);

        assert!(debug_str.contains("limit"));
        assert!(debug_str.contains("remaining"));
        assert!(debug_str.contains("reset"));
    }

    #[test]
    fn test_rate_limiter_creation() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_with_custom_config() {
        let limiter = GlobalRateLimiter::new(limiter_config(60, 10));
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_burst() {
        let limiter = GlobalRateLimiter::new(limiter_config(10, 5));

        for _ in 0..5 {
            assert!(limiter.check_rate_limit(), "Burst request should be allowed");
        }
        // At 10 requests/minute a cell replenishes only every 6s, so the next request
        // right after the burst must be rejected.
        assert!(
            !limiter.check_rate_limit(),
            "Request beyond the burst should be rejected"
        );
    }

    #[test]
    fn test_rate_limiter_multiple_requests() {
        let limiter = GlobalRateLimiter::new(limiter_config(1000, 100));

        for i in 0..50 {
            assert!(limiter.check_rate_limit(), "Request {} should be allowed", i);
        }
    }

    #[test]
    fn test_get_quota_info() {
        let limiter = GlobalRateLimiter::new(limiter_config(100, 50));

        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 100);
        assert!(quota.remaining > 0);
        assert!(quota.reset > 0);
    }

    #[test]
    fn test_quota_info_limit_matches_config() {
        let limiter = GlobalRateLimiter::new(limiter_config(500, 100));
        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 500);
    }

    #[test]
    fn test_quota_decrements_remaining() {
        let limiter = GlobalRateLimiter::new(limiter_config(100, 50));

        let first_quota = limiter.get_quota_info();
        let second_quota = limiter.get_quota_info();

        assert_eq!(
            second_quota.remaining,
            first_quota.remaining - 1,
            "Each call should consume exactly one request"
        );
    }

    #[test]
    fn test_quota_reset_timestamp_is_future() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());

        let quota = limiter.get_quota_info();

        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();

        assert!(quota.reset >= now, "Reset timestamp should be >= current time");
        assert!(quota.reset <= now + 120, "Reset timestamp should be within 2 minutes");
    }

    #[test]
    fn test_rate_limiter_high_burst() {
        let limiter = GlobalRateLimiter::new(limiter_config(10, 1000));

        for _ in 0..100 {
            assert!(limiter.check_rate_limit());
        }
    }

    #[test]
    fn test_rate_limiter_low_limit() {
        let limiter = GlobalRateLimiter::new(limiter_config(1, 1));

        assert!(limiter.check_rate_limit());
        assert!(
            !limiter.check_rate_limit(),
            "Second request within the minute should be rejected"
        );
    }

    #[test]
    fn test_config_with_zero_values_handled() {
        // Zero values must fall back to non-zero quotas instead of panicking.
        let limiter = GlobalRateLimiter::new(limiter_config(0, 0));
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_multiple_quota_calls_same_limiter() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());

        let quotas: Vec<RateLimitQuota> = (0..5).map(|_| limiter.get_quota_info()).collect();

        for quota in &quotas {
            assert_eq!(quota.limit, 100);
        }

        let first_reset = quotas[0].reset;
        for quota in &quotas {
            assert!(
                (quota.reset as i64 - first_reset as i64).abs() <= 1,
                "Reset timestamps should be within 1 second of each other"
            );
        }
    }

    #[test]
    fn test_quota_remaining_never_negative() {
        let limiter = GlobalRateLimiter::new(limiter_config(5, 5));

        let remaining: Vec<u32> = (0..20).map(|_| limiter.get_quota_info().remaining).collect();

        assert!(
            remaining.iter().all(|&r| r <= 5),
            "Remaining should never exceed the limit"
        );
        assert!(
            remaining.windows(2).all(|w| w[1] <= w[0]),
            "Remaining should not increase within a window"
        );
        assert_eq!(
            remaining.last(),
            Some(&0),
            "Remaining should saturate at zero once exhausted"
        );
    }
}