mockforge_http/middleware/
rate_limit.rs1use axum::{
6 body::Body,
7 extract::{ConnectInfo, State},
8 http::{Request, StatusCode},
9 middleware::Next,
10 response::Response,
11};
12use governor::{
13 clock::DefaultClock,
14 state::{InMemoryState, NotKeyed},
15 Quota, RateLimiter,
16};
17use std::net::SocketAddr;
18use std::num::NonZeroU32;
19use std::sync::Arc;
20use tracing::warn;
21
/// Configuration for the HTTP rate-limiting middleware.
///
/// Zero values for `requests_per_minute` or `burst` do not form a valid quota; the limiter
/// constructor in this module substitutes fallback values for them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Sustained number of requests admitted per minute.
    pub requests_per_minute: u32,
    /// Maximum number of requests that may be admitted in an immediate burst.
    pub burst: u32,
    /// Whether limits should be tracked separately for each client IP.
    ///
    /// NOTE(review): `GlobalRateLimiter` in this module uses a single unkeyed bucket and
    /// never reads this flag — confirm whether per-IP limiting is enforced elsewhere.
    pub per_ip: bool,
    /// Whether limits should be tracked separately for each endpoint.
    ///
    /// NOTE(review): likewise not read by `GlobalRateLimiter` in this module.
    pub per_endpoint: bool,
}
35impl Default for RateLimitConfig {
36 fn default() -> Self {
37 Self {
38 requests_per_minute: 100,
39 burst: 200,
40 per_ip: true,
41 per_endpoint: false,
42 }
43 }
44}
45
46pub struct GlobalRateLimiter {
48 limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
49 _config: RateLimitConfig,
50}
51
52impl GlobalRateLimiter {
53 pub fn new(config: RateLimitConfig) -> Self {
55 let quota = Quota::per_minute(
56 NonZeroU32::new(config.requests_per_minute).unwrap_or(NonZeroU32::new(100).unwrap()),
57 )
58 .allow_burst(NonZeroU32::new(config.burst).unwrap_or(NonZeroU32::new(200).unwrap()));
59
60 let limiter = Arc::new(RateLimiter::direct(quota));
61
62 Self {
63 limiter,
64 _config: config,
65 }
66 }
67
68 pub fn check_rate_limit(&self) -> bool {
70 self.limiter.check().is_ok()
71 }
72}
73
74pub async fn rate_limit_middleware(
76 State(state): axum::extract::State<crate::HttpServerState>,
77 ConnectInfo(addr): ConnectInfo<SocketAddr>,
78 req: Request<Body>,
79 next: Next,
80) -> Result<Response, StatusCode> {
81 if let Some(limiter) = &state.rate_limiter {
83 if !limiter.check_rate_limit() {
84 warn!("Rate limit exceeded for IP: {}", addr.ip());
85 return Err(StatusCode::TOO_MANY_REQUESTS);
86 }
87 } else {
88 tracing::debug!("No rate limiter configured, allowing request");
90 }
91
92 Ok(next.run(req).await)
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limiter_creation() {
        let config = RateLimitConfig::default();
        let limiter = GlobalRateLimiter::new(config);

        // A freshly built limiter must admit its first request.
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_burst() {
        let config = RateLimitConfig {
            requests_per_minute: 10,
            burst: 5,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        for _ in 0..5 {
            assert!(limiter.check_rate_limit(), "Burst request should be allowed");
        }
        // At 10 requests/minute capacity replenishes only every 6 seconds, so the request
        // immediately after the burst must be rejected; otherwise nothing is enforced.
        assert!(
            !limiter.check_rate_limit(),
            "Request beyond the burst should be rejected"
        );
    }

    #[test]
    fn test_zero_config_falls_back_to_defaults() {
        let config = RateLimitConfig {
            requests_per_minute: 0,
            burst: 0,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        // Zero is not a valid quota; the constructor's fallback burst of 200 applies.
        for _ in 0..200 {
            assert!(
                limiter.check_rate_limit(),
                "Fallback burst request should be allowed"
            );
        }
        assert!(
            !limiter.check_rate_limit(),
            "Request beyond the fallback burst should be rejected"
        );
    }

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.burst, 200);
        assert!(config.per_ip);
        assert!(!config.per_endpoint);
    }
}