1use axum::{
7 body::Body,
8 extract::{ConnectInfo, State},
9 http::{HeaderName, HeaderValue, Request, StatusCode},
10 middleware::Next,
11 response::Response,
12};
13use governor::{
14 clock::DefaultClock,
15 state::{InMemoryState, NotKeyed},
16 Quota, RateLimiter,
17};
18use std::net::SocketAddr;
19use std::num::NonZeroU32;
20use std::sync::{Arc, Mutex};
21use std::time::{Duration, SystemTime, UNIX_EPOCH};
22use tracing::warn;
23
/// Reports whether rate limiting has been switched off through the environment.
///
/// Two variables are consulted, in this order:
///
/// * `MOCKFORGE_RATE_LIMIT_ENABLED` — a falsy value (`0`, `false`, `no`, `off`)
///   disables rate limiting.
/// * `MOCKFORGE_RATE_LIMIT_DISABLED` — a truthy value (`1`, `true`, `yes`, `on`)
///   disables rate limiting.
///
/// Values are trimmed and compared ASCII case-insensitively. A variable that is
/// unset, not valid Unicode, or holds any other value has no effect.
///
/// Returns `true` when rate limiting should be disabled.
pub fn is_rate_limit_disabled() -> bool {
    const FALSY: [&str; 4] = ["0", "false", "no", "off"];
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];

    /// True when `key` is set and its normalised value is one of `accepted`.
    fn env_is_one_of(key: &str, accepted: &[&str]) -> bool {
        std::env::var(key)
            .map(|raw| {
                let normalised = raw.trim().to_ascii_lowercase();
                accepted.iter().any(|word| *word == normalised)
            })
            .unwrap_or(false)
    }

    // `||` short-circuits, so the second variable is only read when the first
    // one did not already disable rate limiting.
    env_is_one_of("MOCKFORGE_RATE_LIMIT_ENABLED", &FALSY)
        || env_is_one_of("MOCKFORGE_RATE_LIMIT_DISABLED", &TRUTHY)
}
53
/// Tunable parameters for request rate limiting.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Sustained allowance, in requests per minute. `GlobalRateLimiter::new`
    /// feeds this to the governor per-minute quota; zero falls back to 100 there.
    pub requests_per_minute: u32,
    /// Burst allowance handed to the governor quota's `allow_burst`; zero falls
    /// back to 200 there.
    pub burst: u32,
    /// NOTE(review): presumably requests per-client-IP limiting; nothing in the
    /// visible code reads this flag — confirm against callers.
    pub per_ip: bool,
    /// NOTE(review): presumably requests per-endpoint limiting; nothing in the
    /// visible code reads this flag — confirm against callers.
    pub per_endpoint: bool,
}

impl Default for RateLimitConfig {
    /// 100 requests per minute, a burst of 200, `per_ip` on, `per_endpoint` off.
    fn default() -> Self {
        let (requests_per_minute, burst) = (100, 200);
        RateLimitConfig {
            requests_per_minute,
            burst,
            per_ip: true,
            per_endpoint: false,
        }
    }
}
77
/// Snapshot of quota state, used to populate the `x-rate-limit-*` response headers.
#[derive(Clone, Debug)]
pub struct RateLimitQuota {
    /// Request allowance for one window (reported as `x-rate-limit-limit`).
    pub limit: u32,
    /// Requests still available in the current window
    /// (reported as `x-rate-limit-remaining`).
    pub remaining: u32,
    /// Unix timestamp, in seconds, at which the current window ends
    /// (reported as `x-rate-limit-reset`).
    pub reset: u64,
}
88
89pub struct GlobalRateLimiter {
91 limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
92 config: RateLimitConfig,
93 window_start: Arc<Mutex<SystemTime>>,
95 remaining_counter: Arc<Mutex<u32>>,
97}
98
99impl GlobalRateLimiter {
100 pub fn new(config: RateLimitConfig) -> Self {
102 let quota = Quota::per_minute(
103 NonZeroU32::new(config.requests_per_minute)
104 .unwrap_or(NonZeroU32::new(100).expect("constant 100 is non-zero")),
105 )
106 .allow_burst(
107 NonZeroU32::new(config.burst)
108 .unwrap_or(NonZeroU32::new(200).expect("constant 200 is non-zero")),
109 );
110
111 let limiter = Arc::new(RateLimiter::direct(quota));
112 let window_start = Arc::new(Mutex::new(SystemTime::now()));
113 let remaining_counter = Arc::new(Mutex::new(config.requests_per_minute));
114
115 Self {
116 limiter,
117 config,
118 window_start,
119 remaining_counter,
120 }
121 }
122
123 pub fn check_rate_limit(&self) -> bool {
125 self.limiter.check().is_ok()
126 }
127
128 pub fn get_quota_info(&self) -> RateLimitQuota {
133 let now = SystemTime::now();
134 let mut window_start =
135 self.window_start.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
136 let mut remaining =
137 self.remaining_counter.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
138
139 let window_duration = Duration::from_secs(60);
141 if now.duration_since(*window_start).unwrap_or(Duration::ZERO) >= window_duration {
142 *window_start = now;
144 *remaining = self.config.requests_per_minute;
145 }
146
147 let current_remaining = *remaining;
151 if current_remaining > 0 {
152 *remaining = current_remaining.saturating_sub(1);
153 }
154
155 let reset_timestamp =
157 window_start.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_secs() + 60; RateLimitQuota {
160 limit: self.config.requests_per_minute,
161 remaining: current_remaining,
162 reset: reset_timestamp,
163 }
164 }
165}
166
167pub async fn rate_limit_middleware(
174 State(state): State<crate::HttpServerState>,
175 ConnectInfo(addr): ConnectInfo<SocketAddr>,
176 req: Request<Body>,
177 next: Next,
178) -> Response {
179 let quota_info = if let Some(limiter) = &state.rate_limiter {
181 if !limiter.check_rate_limit() {
183 warn!("Rate limit exceeded for IP: {}", addr.ip());
184 let mut response = Response::builder()
186 .status(StatusCode::TOO_MANY_REQUESTS)
187 .body(Body::from("Too Many Requests"))
188 .unwrap_or_else(|_| Response::new(Body::from("Too Many Requests")));
189
190 let retry_after = HeaderValue::from_static("60");
192 response
193 .headers_mut()
194 .insert(HeaderName::from_static("retry-after"), retry_after);
195
196 let quota = limiter.get_quota_info();
198 if let Ok(limit_value) = HeaderValue::from_str("a.limit.to_string()) {
199 response
200 .headers_mut()
201 .insert(HeaderName::from_static("x-rate-limit-limit"), limit_value);
202 }
203 if let Ok(remaining_value) = HeaderValue::from_str("0") {
204 response
205 .headers_mut()
206 .insert(HeaderName::from_static("x-rate-limit-remaining"), remaining_value);
207 }
208 if let Ok(reset_value) = HeaderValue::from_str("a.reset.to_string()) {
209 response
210 .headers_mut()
211 .insert(HeaderName::from_static("x-rate-limit-reset"), reset_value);
212 }
213
214 return response;
215 }
216
217 Some(limiter.get_quota_info())
219 } else {
220 tracing::debug!("No rate limiter configured, allowing request");
222 None
223 };
224
225 let mut response = next.run(req).await;
227
228 if let Some(quota) = quota_info {
231 let limit_name = HeaderName::from_static("x-rate-limit-limit");
233 if let Ok(limit_value) = HeaderValue::from_str("a.limit.to_string()) {
234 response.headers_mut().insert(limit_name, limit_value);
235 }
236
237 let remaining_name = HeaderName::from_static("x-rate-limit-remaining");
239 if let Ok(remaining_value) = HeaderValue::from_str("a.remaining.to_string()) {
240 response.headers_mut().insert(remaining_name, remaining_value);
241 }
242
243 let reset_name = HeaderName::from_static("x-rate-limit-reset");
245 if let Ok(reset_value) = HeaderValue::from_str("a.reset.to_string()) {
246 response.headers_mut().insert(reset_name, reset_value);
247 }
248 }
249
250 response
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    const ENABLED_VAR: &str = "MOCKFORGE_RATE_LIMIT_ENABLED";
    const DISABLED_VAR: &str = "MOCKFORGE_RATE_LIMIT_DISABLED";

    /// Restores an environment variable to its pre-test value when dropped,
    /// so a failing assertion cannot leak test state into the process.
    struct EnvVarGuard {
        key: &'static str,
        saved: Option<String>,
    }

    impl EnvVarGuard {
        fn capture(key: &'static str) -> Self {
            Self {
                key,
                saved: std::env::var(key).ok(),
            }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.saved.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    /// Builds a limiter with the given rates and both scoping flags off.
    fn limiter_with(requests_per_minute: u32, burst: u32) -> GlobalRateLimiter {
        GlobalRateLimiter::new(RateLimitConfig {
            requests_per_minute,
            burst,
            per_ip: false,
            per_endpoint: false,
        })
    }

    // ---- is_rate_limit_disabled -------------------------------------------

    #[test]
    fn test_is_rate_limit_disabled_env_vars() {
        let _restore_enabled = EnvVarGuard::capture(ENABLED_VAR);
        let _restore_disabled = EnvVarGuard::capture(DISABLED_VAR);

        // Neither variable set: rate limiting stays on.
        std::env::remove_var(ENABLED_VAR);
        std::env::remove_var(DISABLED_VAR);
        assert!(!is_rate_limit_disabled());

        // Falsy ENABLED disables.
        for v in ["false", "0", "no", "off", "FALSE", " False "] {
            std::env::set_var(ENABLED_VAR, v);
            assert!(is_rate_limit_disabled(), "ENABLED={v:?} should disable");
        }
        std::env::remove_var(ENABLED_VAR);

        // Truthy ENABLED does not disable.
        std::env::set_var(ENABLED_VAR, "true");
        assert!(!is_rate_limit_disabled());
        std::env::remove_var(ENABLED_VAR);

        // Truthy DISABLED disables.
        for v in ["true", "1", "yes", "on", "TRUE"] {
            std::env::set_var(DISABLED_VAR, v);
            assert!(is_rate_limit_disabled(), "DISABLED={v:?} should disable");
        }
        std::env::remove_var(DISABLED_VAR);

        // Falsy DISABLED does not disable.
        std::env::set_var(DISABLED_VAR, "false");
        assert!(!is_rate_limit_disabled());
        std::env::remove_var(DISABLED_VAR);
    }

    // ---- RateLimitConfig --------------------------------------------------

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.burst, 200);
        assert!(config.per_ip);
        assert!(!config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_custom() {
        let config = RateLimitConfig {
            requests_per_minute: 50,
            burst: 100,
            per_ip: false,
            per_endpoint: true,
        };

        assert_eq!(config.requests_per_minute, 50);
        assert_eq!(config.burst, 100);
        assert!(!config.per_ip);
        assert!(config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_clone() {
        let config = RateLimitConfig {
            requests_per_minute: 75,
            burst: 150,
            per_ip: true,
            per_endpoint: true,
        };

        let cloned = config.clone();

        assert_eq!(cloned.requests_per_minute, config.requests_per_minute);
        assert_eq!(cloned.burst, config.burst);
        assert_eq!(cloned.per_ip, config.per_ip);
        assert_eq!(cloned.per_endpoint, config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_debug() {
        let debug_str = format!("{:?}", RateLimitConfig::default());

        for field in ["requests_per_minute", "burst", "per_ip", "per_endpoint"] {
            assert!(debug_str.contains(field), "Debug output should mention {field}");
        }
    }

    // ---- RateLimitQuota ---------------------------------------------------

    #[test]
    fn test_rate_limit_quota_creation() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        assert_eq!(quota.limit, 100);
        assert_eq!(quota.remaining, 50);
        assert_eq!(quota.reset, 1234567890);
    }

    #[test]
    fn test_rate_limit_quota_clone() {
        let quota = RateLimitQuota {
            limit: 200,
            remaining: 175,
            reset: 9876543210,
        };

        let cloned = quota.clone();

        assert_eq!(cloned.limit, quota.limit);
        assert_eq!(cloned.remaining, quota.remaining);
        assert_eq!(cloned.reset, quota.reset);
    }

    #[test]
    fn test_rate_limit_quota_debug() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        let debug_str = format!("{:?}", quota);

        for field in ["limit", "remaining", "reset"] {
            assert!(debug_str.contains(field), "Debug output should mention {field}");
        }
    }

    // ---- GlobalRateLimiter: admission -------------------------------------

    #[test]
    fn test_rate_limiter_creation() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());

        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_with_custom_config() {
        let limiter = limiter_with(60, 10);

        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_burst() {
        let limiter = limiter_with(10, 5);

        for _ in 0..5 {
            assert!(limiter.check_rate_limit(), "Burst request should be allowed");
        }
    }

    #[test]
    fn test_rate_limiter_multiple_requests() {
        let limiter = limiter_with(1000, 100);

        for i in 0..50 {
            assert!(limiter.check_rate_limit(), "Request {} should be allowed", i);
        }
    }

    #[test]
    fn test_rate_limiter_high_burst() {
        let limiter = limiter_with(10, 1000);

        for _ in 0..100 {
            assert!(limiter.check_rate_limit());
        }
    }

    #[test]
    fn test_rate_limiter_low_limit() {
        let limiter = limiter_with(1, 1);

        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_config_with_zero_values_handled() {
        // Zero values must fall back to usable defaults instead of panicking.
        let limiter = limiter_with(0, 0);

        assert!(limiter.check_rate_limit());
    }

    // ---- GlobalRateLimiter: quota reporting -------------------------------

    #[test]
    fn test_get_quota_info() {
        let limiter = limiter_with(100, 50);

        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 100);
        assert!(quota.remaining > 0);
        assert!(quota.reset > 0);
    }

    #[test]
    fn test_quota_info_limit_matches_config() {
        let limiter = limiter_with(500, 100);

        assert_eq!(limiter.get_quota_info().limit, 500);
    }

    #[test]
    fn test_quota_decrements_remaining() {
        let limiter = limiter_with(100, 50);

        let first_quota = limiter.get_quota_info();
        let second_quota = limiter.get_quota_info();

        assert!(
            second_quota.remaining < first_quota.remaining,
            "Remaining should decrease between consecutive calls"
        );
    }

    #[test]
    fn test_quota_reset_timestamp_is_future() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());

        let quota = limiter.get_quota_info();

        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is after the Unix epoch")
            .as_secs();

        assert!(quota.reset >= now, "Reset timestamp should be >= current time");
        assert!(quota.reset <= now + 120, "Reset timestamp should be within 2 minutes");
    }

    #[test]
    fn test_multiple_quota_calls_same_limiter() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());

        let quotas: Vec<RateLimitQuota> = (0..5).map(|_| limiter.get_quota_info()).collect();

        for quota in &quotas {
            assert_eq!(quota.limit, 100);
        }

        let first_reset = quotas[0].reset;
        for quota in &quotas {
            assert!(
                quota.reset.abs_diff(first_reset) <= 1,
                "Reset timestamps should be within 1 second of each other"
            );
        }
    }

    #[test]
    fn test_quota_remaining_never_negative() {
        let limiter = limiter_with(5, 5);

        // Ask for far more snapshots than the window allows.
        let mut last_remaining = u32::MAX;
        for _ in 0..20 {
            let quota = limiter.get_quota_info();
            assert!(
                quota.remaining <= quota.limit,
                "Remaining should never exceed the configured limit"
            );
            last_remaining = quota.remaining;
        }

        assert_eq!(last_remaining, 0, "Remaining should bottom out at zero");
    }
}