mockforge_http/middleware/
rate_limit.rs1use axum::{
7 body::Body,
8 extract::{ConnectInfo, State},
9 http::{HeaderName, HeaderValue, Request, StatusCode},
10 middleware::Next,
11 response::Response,
12};
13use governor::{
14 clock::DefaultClock,
15 state::{InMemoryState, NotKeyed},
16 Quota, RateLimiter,
17};
18use std::net::SocketAddr;
19use std::num::NonZeroU32;
20use std::sync::{Arc, Mutex};
21use std::time::{Duration, SystemTime, UNIX_EPOCH};
22use tracing::warn;
23
/// Configuration for the HTTP rate limiter.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Sustained number of requests admitted per minute.
    /// `GlobalRateLimiter::new` substitutes a default when this is zero.
    pub requests_per_minute: u32,
    /// Burst allowance handed to `governor`'s `Quota::allow_burst`.
    /// `GlobalRateLimiter::new` substitutes a default when this is zero.
    pub burst: u32,
    /// Whether limits should be tracked per client IP.
    /// NOTE(review): not read by `GlobalRateLimiter`, which keeps a single unkeyed limit.
    pub per_ip: bool,
    /// Whether limits should be tracked per endpoint.
    /// NOTE(review): not read by `GlobalRateLimiter`, which keeps a single unkeyed limit.
    pub per_endpoint: bool,
}
36
37impl Default for RateLimitConfig {
38 fn default() -> Self {
39 Self {
40 requests_per_minute: 100,
41 burst: 200,
42 per_ip: true,
43 per_endpoint: false,
44 }
45 }
46}
47
/// Snapshot of the quota reported to clients through the `x-rate-limit-*` headers.
#[derive(Clone, Debug)]
pub struct RateLimitQuota {
    /// Number of requests allowed per window.
    pub limit: u32,
    /// Requests still available in the current window.
    pub remaining: u32,
    /// Unix timestamp (seconds) at which the current window ends.
    pub reset: u64,
}
58
59pub struct GlobalRateLimiter {
61 limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
62 config: RateLimitConfig,
63 window_start: Arc<Mutex<SystemTime>>,
65 remaining_counter: Arc<Mutex<u32>>,
67}
68
69impl GlobalRateLimiter {
70 pub fn new(config: RateLimitConfig) -> Self {
72 let quota = Quota::per_minute(
73 NonZeroU32::new(config.requests_per_minute)
74 .unwrap_or(NonZeroU32::new(100).expect("constant 100 is non-zero")),
75 )
76 .allow_burst(
77 NonZeroU32::new(config.burst)
78 .unwrap_or(NonZeroU32::new(200).expect("constant 200 is non-zero")),
79 );
80
81 let limiter = Arc::new(RateLimiter::direct(quota));
82 let window_start = Arc::new(Mutex::new(SystemTime::now()));
83 let remaining_counter = Arc::new(Mutex::new(config.requests_per_minute));
84
85 Self {
86 limiter,
87 config,
88 window_start,
89 remaining_counter,
90 }
91 }
92
93 pub fn check_rate_limit(&self) -> bool {
95 self.limiter.check().is_ok()
96 }
97
98 pub fn get_quota_info(&self) -> RateLimitQuota {
103 let now = SystemTime::now();
104 let mut window_start =
105 self.window_start.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
106 let mut remaining =
107 self.remaining_counter.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
108
109 let window_duration = Duration::from_secs(60);
111 if now.duration_since(*window_start).unwrap_or(Duration::ZERO) >= window_duration {
112 *window_start = now;
114 *remaining = self.config.requests_per_minute;
115 }
116
117 let current_remaining = *remaining;
121 if current_remaining > 0 {
122 *remaining = current_remaining.saturating_sub(1);
123 }
124
125 let reset_timestamp =
127 window_start.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_secs() + 60; RateLimitQuota {
130 limit: self.config.requests_per_minute,
131 remaining: current_remaining,
132 reset: reset_timestamp,
133 }
134 }
135}
136
137pub async fn rate_limit_middleware(
144 State(state): axum::extract::State<crate::HttpServerState>,
145 ConnectInfo(addr): ConnectInfo<SocketAddr>,
146 req: Request<Body>,
147 next: Next,
148) -> Response {
149 let quota_info = if let Some(limiter) = &state.rate_limiter {
151 if !limiter.check_rate_limit() {
153 warn!("Rate limit exceeded for IP: {}", addr.ip());
154 let mut response = Response::builder()
156 .status(StatusCode::TOO_MANY_REQUESTS)
157 .body(Body::from("Too Many Requests"))
158 .unwrap_or_else(|_| Response::new(Body::from("Too Many Requests")));
159
160 if let Ok(retry_after) = HeaderValue::from_static("60").try_into() {
162 response
163 .headers_mut()
164 .insert(HeaderName::from_static("retry-after"), retry_after);
165 }
166
167 let quota = limiter.get_quota_info();
169 if let Ok(limit_value) = HeaderValue::from_str("a.limit.to_string()) {
170 response
171 .headers_mut()
172 .insert(HeaderName::from_static("x-rate-limit-limit"), limit_value);
173 }
174 if let Ok(remaining_value) = HeaderValue::from_str("0") {
175 response
176 .headers_mut()
177 .insert(HeaderName::from_static("x-rate-limit-remaining"), remaining_value);
178 }
179 if let Ok(reset_value) = HeaderValue::from_str("a.reset.to_string()) {
180 response
181 .headers_mut()
182 .insert(HeaderName::from_static("x-rate-limit-reset"), reset_value);
183 }
184
185 return response;
186 }
187
188 Some(limiter.get_quota_info())
190 } else {
191 tracing::debug!("No rate limiter configured, allowing request");
193 None
194 };
195
196 let mut response = next.run(req).await;
198
199 if let Some(quota) = quota_info {
202 let limit_name = HeaderName::from_static("x-rate-limit-limit");
204 if let Ok(limit_value) = HeaderValue::from_str("a.limit.to_string()) {
205 response.headers_mut().insert(limit_name, limit_value);
206 }
207
208 let remaining_name = HeaderName::from_static("x-rate-limit-remaining");
210 if let Ok(remaining_value) = HeaderValue::from_str("a.remaining.to_string()) {
211 response.headers_mut().insert(remaining_name, remaining_value);
212 }
213
214 let reset_name = HeaderName::from_static("x-rate-limit-reset");
216 if let Ok(reset_value) = HeaderValue::from_str("a.reset.to_string()) {
217 response.headers_mut().insert(reset_name, reset_value);
218 }
219 }
220
221 response
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    // ---- RateLimitConfig ----

    /// The default configuration matches the documented defaults.
    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.burst, 200);
        assert!(config.per_ip);
        assert!(!config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_custom() {
        let config = RateLimitConfig {
            requests_per_minute: 50,
            burst: 100,
            per_ip: false,
            per_endpoint: true,
        };

        assert_eq!(config.requests_per_minute, 50);
        assert_eq!(config.burst, 100);
        assert!(!config.per_ip);
        assert!(config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_clone() {
        let config = RateLimitConfig {
            requests_per_minute: 75,
            burst: 150,
            per_ip: true,
            per_endpoint: true,
        };

        let cloned = config.clone();

        assert_eq!(cloned.requests_per_minute, config.requests_per_minute);
        assert_eq!(cloned.burst, config.burst);
        assert_eq!(cloned.per_ip, config.per_ip);
        assert_eq!(cloned.per_endpoint, config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_debug() {
        let config = RateLimitConfig::default();
        let debug_str = format!("{:?}", config);

        assert!(debug_str.contains("requests_per_minute"));
        assert!(debug_str.contains("burst"));
        assert!(debug_str.contains("per_ip"));
        assert!(debug_str.contains("per_endpoint"));
    }

    // ---- RateLimitQuota ----

    #[test]
    fn test_rate_limit_quota_creation() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        assert_eq!(quota.limit, 100);
        assert_eq!(quota.remaining, 50);
        assert_eq!(quota.reset, 1234567890);
    }

    #[test]
    fn test_rate_limit_quota_clone() {
        let quota = RateLimitQuota {
            limit: 200,
            remaining: 175,
            reset: 9876543210,
        };

        let cloned = quota.clone();

        assert_eq!(cloned.limit, quota.limit);
        assert_eq!(cloned.remaining, quota.remaining);
        assert_eq!(cloned.reset, quota.reset);
    }

    #[test]
    fn test_rate_limit_quota_debug() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        let debug_str = format!("{:?}", quota);

        assert!(debug_str.contains("limit"));
        assert!(debug_str.contains("remaining"));
        assert!(debug_str.contains("reset"));
    }

    // ---- GlobalRateLimiter ----

    #[test]
    fn test_rate_limiter_creation() {
        let config = RateLimitConfig::default();
        let limiter = GlobalRateLimiter::new(config);

        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_with_custom_config() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst: 10,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);
        assert!(limiter.check_rate_limit());
    }

    /// Exactly `burst` back-to-back requests are admitted; the next is rejected.
    #[test]
    fn test_rate_limiter_burst() {
        let config = RateLimitConfig {
            requests_per_minute: 10,
            burst: 5,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        for _ in 0..5 {
            assert!(limiter.check_rate_limit(), "Burst request should be allowed");
        }
        // At 10/min a new cell only becomes available after ~6s.
        assert!(!limiter.check_rate_limit(), "Request beyond the burst should be rejected");
    }

    #[test]
    fn test_rate_limiter_multiple_requests() {
        let config = RateLimitConfig {
            requests_per_minute: 1000,
            burst: 100,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        for i in 0..50 {
            assert!(limiter.check_rate_limit(), "Request {} should be allowed", i);
        }
    }

    #[test]
    fn test_get_quota_info() {
        let config = RateLimitConfig {
            requests_per_minute: 100,
            burst: 50,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 100);
        assert!(quota.remaining > 0);
        assert!(quota.reset > 0);
    }

    #[test]
    fn test_quota_info_limit_matches_config() {
        let config = RateLimitConfig {
            requests_per_minute: 500,
            burst: 100,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);
        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 500);
    }

    /// Each call within a window consumes exactly one unit.
    #[test]
    fn test_quota_decrements_remaining() {
        let config = RateLimitConfig {
            requests_per_minute: 100,
            burst: 50,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        let first_quota = limiter.get_quota_info();
        let second_quota = limiter.get_quota_info();

        assert_eq!(
            second_quota.remaining,
            first_quota.remaining - 1,
            "Remaining should decrease by one per call"
        );
    }

    #[test]
    fn test_quota_reset_timestamp_is_future() {
        let config = RateLimitConfig::default();
        let limiter = GlobalRateLimiter::new(config);

        let quota = limiter.get_quota_info();

        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();

        assert!(quota.reset >= now, "Reset timestamp should be >= current time");
        assert!(quota.reset <= now + 120, "Reset timestamp should be within 2 minutes");
    }

    #[test]
    fn test_rate_limiter_high_burst() {
        let config = RateLimitConfig {
            requests_per_minute: 10,
            burst: 1000,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        for _ in 0..100 {
            assert!(limiter.check_rate_limit());
        }
    }

    /// With a limit of 1/min and a burst of 1, only the first request gets through.
    #[test]
    fn test_rate_limiter_low_limit() {
        let config = RateLimitConfig {
            requests_per_minute: 1,
            burst: 1,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        assert!(limiter.check_rate_limit());
        assert!(!limiter.check_rate_limit(), "Second immediate request should be rejected");
    }

    /// Zero values fall back to non-zero defaults instead of panicking.
    #[test]
    fn test_config_with_zero_values_handled() {
        let config = RateLimitConfig {
            requests_per_minute: 0,
            burst: 0,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_multiple_quota_calls_same_limiter() {
        let config = RateLimitConfig::default();
        let limiter = GlobalRateLimiter::new(config);

        let quotas: Vec<RateLimitQuota> = (0..5).map(|_| limiter.get_quota_info()).collect();

        for quota in &quotas {
            assert_eq!(quota.limit, 100);
        }

        let first_reset = quotas[0].reset;
        for quota in &quotas {
            assert!(
                quota.reset.abs_diff(first_reset) <= 1,
                "Reset timestamps should be within 1 second of each other"
            );
        }
    }

    /// Remaining never exceeds the limit, never increases within a window, and
    /// saturates at zero once the window is exhausted.
    #[test]
    fn test_quota_remaining_never_negative() {
        let config = RateLimitConfig {
            requests_per_minute: 5,
            burst: 5,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        let mut previous = u32::MAX;
        for _ in 0..20 {
            let quota = limiter.get_quota_info();
            assert!(quota.remaining <= quota.limit, "Remaining must not exceed the limit");
            assert!(quota.remaining <= previous, "Remaining must not increase within a window");
            previous = quota.remaining;
        }
        assert_eq!(previous, 0, "Remaining should saturate at zero");
    }
}