mockforge_http/middleware/
rate_limit.rs1use axum::{
7 body::Body,
8 extract::{ConnectInfo, State},
9 http::{HeaderName, HeaderValue, Request, StatusCode},
10 middleware::Next,
11 response::Response,
12};
13use governor::{
14 clock::DefaultClock,
15 state::{InMemoryState, NotKeyed},
16 Quota, RateLimiter,
17};
18use std::net::SocketAddr;
19use std::num::NonZeroU32;
20use std::sync::{Arc, Mutex};
21use std::time::{Duration, SystemTime, UNIX_EPOCH};
22use tracing::warn;
23
/// Settings for the HTTP rate-limiting middleware.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Sustained number of requests admitted per minute.
    pub requests_per_minute: u32,
    /// Maximum number of requests that may be admitted back-to-back before the
    /// per-minute replenish rate applies.
    pub burst: u32,
    /// Whether limits should be tracked per client IP.
    ///
    /// NOTE(review): `GlobalRateLimiter` in this module uses a single unkeyed limiter and
    /// does not read this flag — confirm whether another component honours it.
    pub per_ip: bool,
    /// Whether limits should be tracked per endpoint.
    ///
    /// NOTE(review): like `per_ip`, not consulted by `GlobalRateLimiter` in this module.
    pub per_endpoint: bool,
}
36
37impl Default for RateLimitConfig {
38 fn default() -> Self {
39 Self {
40 requests_per_minute: 100,
41 burst: 200,
42 per_ip: true,
43 per_endpoint: false,
44 }
45 }
46}
47
/// Snapshot of the advertised rate-limit quota, used to fill the
/// `x-rate-limit-*` response headers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitQuota {
    /// Requests allowed per reporting window.
    pub limit: u32,
    /// Requests left in the current reporting window (advisory; tracked separately from the
    /// enforcing limiter — see `GlobalRateLimiter::get_quota_info`).
    pub remaining: u32,
    /// Unix timestamp, in seconds, at which the current reporting window ends.
    pub reset: u64,
}
58
59pub struct GlobalRateLimiter {
61 limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
62 config: RateLimitConfig,
63 window_start: Arc<Mutex<SystemTime>>,
65 remaining_counter: Arc<Mutex<u32>>,
67}
68
69impl GlobalRateLimiter {
70 pub fn new(config: RateLimitConfig) -> Self {
72 let quota = Quota::per_minute(
73 NonZeroU32::new(config.requests_per_minute)
74 .unwrap_or(NonZeroU32::new(100).expect("constant 100 is non-zero")),
75 )
76 .allow_burst(
77 NonZeroU32::new(config.burst)
78 .unwrap_or(NonZeroU32::new(200).expect("constant 200 is non-zero")),
79 );
80
81 let limiter = Arc::new(RateLimiter::direct(quota));
82 let window_start = Arc::new(Mutex::new(SystemTime::now()));
83 let remaining_counter = Arc::new(Mutex::new(config.requests_per_minute));
84
85 Self {
86 limiter,
87 config,
88 window_start,
89 remaining_counter,
90 }
91 }
92
93 pub fn check_rate_limit(&self) -> bool {
95 self.limiter.check().is_ok()
96 }
97
98 pub fn get_quota_info(&self) -> RateLimitQuota {
103 let now = SystemTime::now();
104 let mut window_start =
105 self.window_start.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
106 let mut remaining =
107 self.remaining_counter.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
108
109 let window_duration = Duration::from_secs(60);
111 if now.duration_since(*window_start).unwrap_or(Duration::ZERO) >= window_duration {
112 *window_start = now;
114 *remaining = self.config.requests_per_minute;
115 }
116
117 let current_remaining = *remaining;
121 if current_remaining > 0 {
122 *remaining = current_remaining.saturating_sub(1);
123 }
124
125 let reset_timestamp =
127 window_start.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_secs() + 60; RateLimitQuota {
130 limit: self.config.requests_per_minute,
131 remaining: current_remaining,
132 reset: reset_timestamp,
133 }
134 }
135}
136
137pub async fn rate_limit_middleware(
144 State(state): State<crate::HttpServerState>,
145 ConnectInfo(addr): ConnectInfo<SocketAddr>,
146 req: Request<Body>,
147 next: Next,
148) -> Response {
149 let quota_info = if let Some(limiter) = &state.rate_limiter {
151 if !limiter.check_rate_limit() {
153 warn!("Rate limit exceeded for IP: {}", addr.ip());
154 let mut response = Response::builder()
156 .status(StatusCode::TOO_MANY_REQUESTS)
157 .body(Body::from("Too Many Requests"))
158 .unwrap_or_else(|_| Response::new(Body::from("Too Many Requests")));
159
160 let retry_after = HeaderValue::from_static("60");
162 response
163 .headers_mut()
164 .insert(HeaderName::from_static("retry-after"), retry_after);
165
166 let quota = limiter.get_quota_info();
168 if let Ok(limit_value) = HeaderValue::from_str("a.limit.to_string()) {
169 response
170 .headers_mut()
171 .insert(HeaderName::from_static("x-rate-limit-limit"), limit_value);
172 }
173 if let Ok(remaining_value) = HeaderValue::from_str("0") {
174 response
175 .headers_mut()
176 .insert(HeaderName::from_static("x-rate-limit-remaining"), remaining_value);
177 }
178 if let Ok(reset_value) = HeaderValue::from_str("a.reset.to_string()) {
179 response
180 .headers_mut()
181 .insert(HeaderName::from_static("x-rate-limit-reset"), reset_value);
182 }
183
184 return response;
185 }
186
187 Some(limiter.get_quota_info())
189 } else {
190 tracing::debug!("No rate limiter configured, allowing request");
192 None
193 };
194
195 let mut response = next.run(req).await;
197
198 if let Some(quota) = quota_info {
201 let limit_name = HeaderName::from_static("x-rate-limit-limit");
203 if let Ok(limit_value) = HeaderValue::from_str("a.limit.to_string()) {
204 response.headers_mut().insert(limit_name, limit_value);
205 }
206
207 let remaining_name = HeaderName::from_static("x-rate-limit-remaining");
209 if let Ok(remaining_value) = HeaderValue::from_str("a.remaining.to_string()) {
210 response.headers_mut().insert(remaining_name, remaining_value);
211 }
212
213 let reset_name = HeaderName::from_static("x-rate-limit-reset");
215 if let Ok(reset_value) = HeaderValue::from_str("a.reset.to_string()) {
216 response.headers_mut().insert(reset_name, reset_value);
217 }
218 }
219
220 response
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with both per-key flags disabled, as most limiter tests below use.
    fn unkeyed(requests_per_minute: u32, burst: u32) -> RateLimitConfig {
        RateLimitConfig { requests_per_minute, burst, per_ip: false, per_endpoint: false }
    }

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.burst, 200);
        assert!(config.per_ip);
        assert!(!config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_custom() {
        let config =
            RateLimitConfig { requests_per_minute: 50, burst: 100, per_ip: false, per_endpoint: true };

        assert_eq!(config.requests_per_minute, 50);
        assert_eq!(config.burst, 100);
        assert!(!config.per_ip);
        assert!(config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_clone() {
        let config =
            RateLimitConfig { requests_per_minute: 75, burst: 150, per_ip: true, per_endpoint: true };
        let cloned = config.clone();

        assert_eq!(cloned.requests_per_minute, config.requests_per_minute);
        assert_eq!(cloned.burst, config.burst);
        assert_eq!(cloned.per_ip, config.per_ip);
        assert_eq!(cloned.per_endpoint, config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_debug() {
        let debug_str = format!("{:?}", RateLimitConfig::default());

        for field in ["requests_per_minute", "burst", "per_ip", "per_endpoint"] {
            assert!(debug_str.contains(field), "Debug output should mention {}", field);
        }
    }

    #[test]
    fn test_rate_limit_quota_creation() {
        let quota = RateLimitQuota { limit: 100, remaining: 50, reset: 1234567890 };

        assert_eq!(quota.limit, 100);
        assert_eq!(quota.remaining, 50);
        assert_eq!(quota.reset, 1234567890);
    }

    #[test]
    fn test_rate_limit_quota_clone() {
        let quota = RateLimitQuota { limit: 200, remaining: 175, reset: 9876543210 };
        let cloned = quota.clone();

        assert_eq!(cloned.limit, quota.limit);
        assert_eq!(cloned.remaining, quota.remaining);
        assert_eq!(cloned.reset, quota.reset);
    }

    #[test]
    fn test_rate_limit_quota_debug() {
        let quota = RateLimitQuota { limit: 100, remaining: 50, reset: 1234567890 };
        let debug_str = format!("{:?}", quota);

        for field in ["limit", "remaining", "reset"] {
            assert!(debug_str.contains(field), "Debug output should mention {}", field);
        }
    }

    #[test]
    fn test_rate_limiter_creation() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_with_custom_config() {
        let limiter = GlobalRateLimiter::new(unkeyed(60, 10));
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_burst() {
        let limiter = GlobalRateLimiter::new(unkeyed(10, 5));

        for _ in 0..5 {
            assert!(limiter.check_rate_limit(), "Burst request should be allowed");
        }
    }

    #[test]
    fn test_rate_limiter_rejects_once_burst_exhausted() {
        // 10/min replenishes one cell every 6 s, far longer than this test runs.
        let limiter = GlobalRateLimiter::new(unkeyed(10, 5));

        for _ in 0..5 {
            assert!(limiter.check_rate_limit());
        }
        assert!(!limiter.check_rate_limit(), "Request beyond the burst should be rejected");
    }

    #[test]
    fn test_rate_limiter_multiple_requests() {
        let limiter = GlobalRateLimiter::new(unkeyed(1000, 100));

        for i in 0..50 {
            assert!(limiter.check_rate_limit(), "Request {} should be allowed", i);
        }
    }

    #[test]
    fn test_get_quota_info() {
        let limiter = GlobalRateLimiter::new(unkeyed(100, 50));
        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 100);
        assert!(quota.remaining > 0);
        assert!(quota.reset > 0);
    }

    #[test]
    fn test_quota_info_limit_matches_config() {
        let limiter = GlobalRateLimiter::new(unkeyed(500, 100));
        assert_eq!(limiter.get_quota_info().limit, 500);
    }

    #[test]
    fn test_quota_decrements_remaining() {
        let limiter = GlobalRateLimiter::new(unkeyed(100, 50));

        let first_quota = limiter.get_quota_info();
        let second_quota = limiter.get_quota_info();

        assert_eq!(
            second_quota.remaining,
            first_quota.remaining - 1,
            "Each call should consume exactly one unit of the advisory counter"
        );
    }

    #[test]
    fn test_quota_reset_timestamp_is_future() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());
        let quota = limiter.get_quota_info();
        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();

        assert!(quota.reset >= now, "Reset timestamp should be >= current time");
        assert!(quota.reset <= now + 120, "Reset timestamp should be within 2 minutes");
    }

    #[test]
    fn test_rate_limiter_high_burst() {
        let limiter = GlobalRateLimiter::new(unkeyed(10, 1000));

        for _ in 0..100 {
            assert!(limiter.check_rate_limit());
        }
    }

    #[test]
    fn test_rate_limiter_low_limit() {
        let limiter = GlobalRateLimiter::new(unkeyed(1, 1));

        assert!(limiter.check_rate_limit());
        assert!(!limiter.check_rate_limit(), "Second immediate request should be rejected");
    }

    #[test]
    fn test_config_with_zero_values_handled() {
        // Zero values fall back to non-zero defaults instead of panicking.
        let limiter = GlobalRateLimiter::new(unkeyed(0, 0));
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_multiple_quota_calls_same_limiter() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());
        let quotas: Vec<RateLimitQuota> = (0..5).map(|_| limiter.get_quota_info()).collect();

        assert!(quotas.iter().all(|quota| quota.limit == 100));

        let first_reset = quotas[0].reset;
        for quota in &quotas {
            assert!(
                quota.reset.abs_diff(first_reset) <= 1,
                "Reset timestamps should be within 1 second of each other"
            );
        }
    }

    #[test]
    fn test_quota_remaining_never_negative() {
        let limiter = GlobalRateLimiter::new(unkeyed(5, 5));
        let quotas: Vec<RateLimitQuota> = (0..20).map(|_| limiter.get_quota_info()).collect();

        for quota in &quotas {
            assert!(quota.remaining <= quota.limit, "Remaining should never exceed the limit");
        }
        assert_eq!(
            quotas.last().map(|quota| quota.remaining),
            Some(0),
            "Remaining should saturate at zero once the window is used up"
        );
    }
}