//! Rate limiting middleware for the HTTP server (`mockforge_http/middleware/rate_limit.rs`).
use axum::{
    body::Body,
    extract::{ConnectInfo, State},
    http::{HeaderName, HeaderValue, Request, StatusCode},
    middleware::Next,
    response::Response,
};
use governor::{
    clock::DefaultClock,
    state::{InMemoryState, NotKeyed},
    Quota, RateLimiter,
};
use std::net::SocketAddr;
use std::num::NonZeroU32;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tracing::warn;
23
/// Settings for the HTTP rate limiter.
///
/// Only `requests_per_minute` and `burst` are read by `GlobalRateLimiter` in this module, which
/// keeps a single un-keyed (global) bucket; `per_ip` and `per_endpoint` are carried along but
/// not consulted there.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Sustained number of requests allowed per minute (0 is replaced with 100 by
    /// `GlobalRateLimiter::new`).
    pub requests_per_minute: u32,
    /// Maximum burst size handed to the underlying limiter (0 is replaced with 200 by
    /// `GlobalRateLimiter::new`).
    pub burst: u32,
    /// Flag requesting per-client-IP limiting.
    pub per_ip: bool,
    /// Flag requesting per-endpoint limiting.
    pub per_endpoint: bool,
}

impl Default for RateLimitConfig {
    /// 100 requests per minute with bursts of up to 200; `per_ip` on, `per_endpoint` off.
    fn default() -> Self {
        let (requests_per_minute, burst) = (100, 200);
        Self {
            requests_per_minute,
            burst,
            per_ip: true,
            per_endpoint: false,
        }
    }
}
47
/// Snapshot of quota state, surfaced to clients through the `x-rate-limit-*` response headers.
#[derive(Debug, Clone)]
pub struct RateLimitQuota {
    /// Maximum number of requests per window (`x-rate-limit-limit`).
    pub limit: u32,
    /// Requests left in the current window (`x-rate-limit-remaining`).
    pub remaining: u32,
    /// Unix timestamp, in seconds, at which the current window ends (`x-rate-limit-reset`).
    pub reset: u64,
}
58
59pub struct GlobalRateLimiter {
61 limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
62 config: RateLimitConfig,
63 window_start: Arc<Mutex<SystemTime>>,
65 remaining_counter: Arc<Mutex<u32>>,
67}
68
69impl GlobalRateLimiter {
70 pub fn new(config: RateLimitConfig) -> Self {
72 let quota = Quota::per_minute(
73 NonZeroU32::new(config.requests_per_minute)
74 .unwrap_or(NonZeroU32::new(100).expect("constant 100 is non-zero")),
75 )
76 .allow_burst(
77 NonZeroU32::new(config.burst)
78 .unwrap_or(NonZeroU32::new(200).expect("constant 200 is non-zero")),
79 );
80
81 let limiter = Arc::new(RateLimiter::direct(quota));
82 let window_start = Arc::new(Mutex::new(SystemTime::now()));
83 let remaining_counter = Arc::new(Mutex::new(config.requests_per_minute));
84
85 Self {
86 limiter,
87 config,
88 window_start,
89 remaining_counter,
90 }
91 }
92
93 pub fn check_rate_limit(&self) -> bool {
95 self.limiter.check().is_ok()
96 }
97
98 pub fn get_quota_info(&self) -> RateLimitQuota {
103 let now = SystemTime::now();
104 let mut window_start =
105 self.window_start.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
106 let mut remaining =
107 self.remaining_counter.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
108
109 let window_duration = Duration::from_secs(60);
111 if now.duration_since(*window_start).unwrap_or(Duration::ZERO) >= window_duration {
112 *window_start = now;
114 *remaining = self.config.requests_per_minute;
115 }
116
117 let current_remaining = *remaining;
121 if current_remaining > 0 {
122 *remaining = current_remaining.saturating_sub(1);
123 }
124
125 let reset_timestamp =
127 window_start.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_secs() + 60; RateLimitQuota {
130 limit: self.config.requests_per_minute,
131 remaining: current_remaining,
132 reset: reset_timestamp,
133 }
134 }
135}
136
137pub async fn rate_limit_middleware(
143 State(state): axum::extract::State<crate::HttpServerState>,
144 ConnectInfo(addr): ConnectInfo<SocketAddr>,
145 req: Request<Body>,
146 next: Next,
147) -> Result<Response, StatusCode> {
148 let quota_info = if let Some(limiter) = &state.rate_limiter {
150 if !limiter.check_rate_limit() {
152 warn!("Rate limit exceeded for IP: {}", addr.ip());
153 return Err(StatusCode::TOO_MANY_REQUESTS);
154 }
155
156 Some(limiter.get_quota_info())
158 } else {
159 tracing::debug!("No rate limiter configured, allowing request");
161 None
162 };
163
164 let mut response = next.run(req).await;
166
167 if let Some(quota) = quota_info {
170 let limit_name = HeaderName::from_static("x-rate-limit-limit");
172 if let Ok(limit_value) = HeaderValue::from_str("a.limit.to_string()) {
173 response.headers_mut().insert(limit_name, limit_value);
174 }
175
176 let remaining_name = HeaderName::from_static("x-rate-limit-remaining");
178 if let Ok(remaining_value) = HeaderValue::from_str("a.remaining.to_string()) {
179 response.headers_mut().insert(remaining_name, remaining_value);
180 }
181
182 let reset_name = HeaderName::from_static("x-rate-limit-reset");
184 if let Ok(reset_value) = HeaderValue::from_str("a.reset.to_string()) {
185 response.headers_mut().insert(reset_name, reset_value);
186 }
187 }
188
189 Ok(response)
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with both keying flags disabled, the shape most limiter tests need.
    fn global_config(requests_per_minute: u32, burst: u32) -> RateLimitConfig {
        RateLimitConfig {
            requests_per_minute,
            burst,
            per_ip: false,
            per_endpoint: false,
        }
    }

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.burst, 200);
        assert!(config.per_ip);
        assert!(!config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_custom() {
        let config = RateLimitConfig {
            requests_per_minute: 50,
            burst: 100,
            per_ip: false,
            per_endpoint: true,
        };

        assert_eq!(config.requests_per_minute, 50);
        assert_eq!(config.burst, 100);
        assert!(!config.per_ip);
        assert!(config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_clone() {
        let config = RateLimitConfig {
            requests_per_minute: 75,
            burst: 150,
            per_ip: true,
            per_endpoint: true,
        };

        let cloned = config.clone();

        assert_eq!(cloned.requests_per_minute, config.requests_per_minute);
        assert_eq!(cloned.burst, config.burst);
        assert_eq!(cloned.per_ip, config.per_ip);
        assert_eq!(cloned.per_endpoint, config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_debug() {
        let debug_str = format!("{:?}", RateLimitConfig::default());

        assert!(debug_str.contains("requests_per_minute"));
        assert!(debug_str.contains("burst"));
        assert!(debug_str.contains("per_ip"));
        assert!(debug_str.contains("per_endpoint"));
    }

    #[test]
    fn test_rate_limit_quota_creation() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        assert_eq!(quota.limit, 100);
        assert_eq!(quota.remaining, 50);
        assert_eq!(quota.reset, 1234567890);
    }

    #[test]
    fn test_rate_limit_quota_clone() {
        let quota = RateLimitQuota {
            limit: 200,
            remaining: 175,
            reset: 9876543210,
        };

        let cloned = quota.clone();

        assert_eq!(cloned.limit, quota.limit);
        assert_eq!(cloned.remaining, quota.remaining);
        assert_eq!(cloned.reset, quota.reset);
    }

    #[test]
    fn test_rate_limit_quota_debug() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };

        let debug_str = format!("{:?}", quota);

        assert!(debug_str.contains("limit"));
        assert!(debug_str.contains("remaining"));
        assert!(debug_str.contains("reset"));
    }

    #[test]
    fn test_rate_limiter_creation() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_with_custom_config() {
        let limiter = GlobalRateLimiter::new(global_config(60, 10));
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_burst() {
        let limiter = GlobalRateLimiter::new(global_config(10, 5));

        for _ in 0..5 {
            assert!(limiter.check_rate_limit(), "Burst request should be allowed");
        }
        // 10/min replenishes one cell every 6s, so an immediate sixth request is refused.
        assert!(!limiter.check_rate_limit(), "Request beyond the burst should be rejected");
    }

    #[test]
    fn test_rate_limiter_multiple_requests() {
        let limiter = GlobalRateLimiter::new(global_config(1000, 100));

        for i in 0..50 {
            assert!(limiter.check_rate_limit(), "Request {} should be allowed", i);
        }
    }

    #[test]
    fn test_get_quota_info() {
        let limiter = GlobalRateLimiter::new(global_config(100, 50));

        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 100);
        assert!(quota.remaining > 0);
        assert!(quota.reset > 0);
    }

    #[test]
    fn test_quota_info_limit_matches_config() {
        let limiter = GlobalRateLimiter::new(global_config(500, 100));
        let quota = limiter.get_quota_info();

        assert_eq!(quota.limit, 500);
    }

    #[test]
    fn test_quota_decrements_remaining() {
        let limiter = GlobalRateLimiter::new(global_config(100, 50));

        let first_quota = limiter.get_quota_info();
        let second_quota = limiter.get_quota_info();

        assert!(
            second_quota.remaining < first_quota.remaining,
            "Remaining should decrease with each recorded request"
        );
    }

    #[test]
    fn test_quota_reset_timestamp_is_future() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());

        let quota = limiter.get_quota_info();

        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();

        assert!(quota.reset >= now, "Reset timestamp should be >= current time");
        assert!(quota.reset <= now + 120, "Reset timestamp should be within 2 minutes");
    }

    #[test]
    fn test_rate_limiter_high_burst() {
        let limiter = GlobalRateLimiter::new(global_config(10, 1000));

        for _ in 0..100 {
            assert!(limiter.check_rate_limit());
        }
    }

    #[test]
    fn test_rate_limiter_low_limit() {
        let limiter = GlobalRateLimiter::new(global_config(1, 1));

        assert!(limiter.check_rate_limit());
        assert!(!limiter.check_rate_limit(), "Second immediate request should be rejected");
    }

    #[test]
    fn test_config_with_zero_values_handled() {
        // Zero values fall back to the built-in defaults instead of panicking.
        let limiter = GlobalRateLimiter::new(global_config(0, 0));
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_multiple_quota_calls_same_limiter() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());

        let quotas: Vec<RateLimitQuota> = (0..5).map(|_| limiter.get_quota_info()).collect();

        for quota in &quotas {
            assert_eq!(quota.limit, 100);
        }

        let first_reset = quotas[0].reset;
        for quota in &quotas {
            assert!(
                (quota.reset as i64 - first_reset as i64).abs() <= 1,
                "Reset timestamps should be within 1 second of each other"
            );
        }
    }

    #[test]
    fn test_quota_remaining_never_negative() {
        let limiter = GlobalRateLimiter::new(global_config(5, 5));

        let mut previous = u32::MAX;
        for _ in 0..20 {
            let quota = limiter.get_quota_info();
            assert!(quota.remaining <= 5, "Remaining must never exceed the per-minute limit");
            assert!(quota.remaining <= previous, "Remaining must not increase within a window");
            previous = quota.remaining;
        }
        assert_eq!(previous, 0, "Counter should bottom out at zero once exhausted");
    }
}