mockforge_http/middleware/
rate_limit.rs1use axum::{
7 body::Body,
8 extract::{ConnectInfo, State},
9 http::{HeaderName, HeaderValue, Request, StatusCode},
10 middleware::Next,
11 response::Response,
12};
13use governor::{
14 clock::DefaultClock,
15 state::{InMemoryState, NotKeyed},
16 Quota, RateLimiter,
17};
18use std::net::SocketAddr;
19use std::num::NonZeroU32;
20use std::sync::{Arc, Mutex};
21use std::time::{Duration, SystemTime, UNIX_EPOCH};
22use tracing::warn;
23
/// Settings for the HTTP rate-limiting middleware.
///
/// [`RateLimitConfig::default`] yields 100 requests/minute, a burst of 200,
/// `per_ip` enabled and `per_endpoint` disabled; override individual fields with
/// struct-update syntax, e.g. `RateLimitConfig { burst: 50, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Sustained number of requests admitted per minute.
    ///
    /// [`GlobalRateLimiter::new`] substitutes 100 when this is 0.
    pub requests_per_minute: u32,
    /// Maximum number of requests that may be admitted back-to-back before the
    /// sustained per-minute rate takes over.
    ///
    /// [`GlobalRateLimiter::new`] substitutes 200 when this is 0.
    pub burst: u32,
    /// Whether limits are intended to be tracked per client IP.
    ///
    /// NOTE(review): [`GlobalRateLimiter`] does not read this flag; it keeps a single
    /// shared budget for all clients. Confirm whether another component honours it.
    pub per_ip: bool,
    /// Whether limits are intended to be tracked per endpoint.
    ///
    /// NOTE(review): not read by [`GlobalRateLimiter`]; confirm the intended use.
    pub per_endpoint: bool,
}

impl Default for RateLimitConfig {
    /// 100 requests/minute with a burst of 200, limited per IP but not per endpoint.
    fn default() -> Self {
        Self {
            requests_per_minute: 100,
            burst: 200,
            per_ip: true,
            per_endpoint: false,
        }
    }
}
47
/// Snapshot of the rate-limit budget that is reported to clients in response headers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitQuota {
    /// Requests allowed per window (sent as `x-rate-limit-limit`).
    pub limit: u32,
    /// Requests still available in the current window (sent as `x-rate-limit-remaining`).
    pub remaining: u32,
    /// Unix timestamp, in whole seconds, at which the current window ends
    /// (sent as `x-rate-limit-reset`).
    pub reset: u64,
}
58
59pub struct GlobalRateLimiter {
61 limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
62 config: RateLimitConfig,
63 window_start: Arc<Mutex<SystemTime>>,
65 remaining_counter: Arc<Mutex<u32>>,
67}
68
69impl GlobalRateLimiter {
70 pub fn new(config: RateLimitConfig) -> Self {
72 let quota = Quota::per_minute(
73 NonZeroU32::new(config.requests_per_minute).unwrap_or(NonZeroU32::new(100).unwrap()),
74 )
75 .allow_burst(NonZeroU32::new(config.burst).unwrap_or(NonZeroU32::new(200).unwrap()));
76
77 let limiter = Arc::new(RateLimiter::direct(quota));
78 let window_start = Arc::new(Mutex::new(SystemTime::now()));
79 let remaining_counter = Arc::new(Mutex::new(config.requests_per_minute));
80
81 Self {
82 limiter,
83 config,
84 window_start,
85 remaining_counter,
86 }
87 }
88
89 pub fn check_rate_limit(&self) -> bool {
91 self.limiter.check().is_ok()
92 }
93
94 pub fn get_quota_info(&self) -> RateLimitQuota {
99 let now = SystemTime::now();
100 let mut window_start = self.window_start.lock().unwrap();
101 let mut remaining = self.remaining_counter.lock().unwrap();
102
103 let window_duration = Duration::from_secs(60);
105 if now.duration_since(*window_start).unwrap_or(Duration::ZERO) >= window_duration {
106 *window_start = now;
108 *remaining = self.config.requests_per_minute;
109 }
110
111 let current_remaining = *remaining;
115 if current_remaining > 0 {
116 *remaining = current_remaining.saturating_sub(1);
117 }
118
119 let reset_timestamp =
121 window_start.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_secs() + 60; RateLimitQuota {
124 limit: self.config.requests_per_minute,
125 remaining: current_remaining,
126 reset: reset_timestamp,
127 }
128 }
129}
130
131pub async fn rate_limit_middleware(
137 State(state): axum::extract::State<crate::HttpServerState>,
138 ConnectInfo(addr): ConnectInfo<SocketAddr>,
139 req: Request<Body>,
140 next: Next,
141) -> Result<Response, StatusCode> {
142 let quota_info = if let Some(limiter) = &state.rate_limiter {
144 if !limiter.check_rate_limit() {
146 warn!("Rate limit exceeded for IP: {}", addr.ip());
147 return Err(StatusCode::TOO_MANY_REQUESTS);
148 }
149
150 Some(limiter.get_quota_info())
152 } else {
153 tracing::debug!("No rate limiter configured, allowing request");
155 None
156 };
157
158 let mut response = next.run(req).await;
160
161 if let Some(quota) = quota_info {
164 let limit_name = HeaderName::from_static("x-rate-limit-limit");
166 if let Ok(limit_value) = HeaderValue::from_str("a.limit.to_string()) {
167 response.headers_mut().insert(limit_name, limit_value);
168 }
169
170 let remaining_name = HeaderName::from_static("x-rate-limit-remaining");
172 if let Ok(remaining_value) = HeaderValue::from_str("a.remaining.to_string()) {
173 response.headers_mut().insert(remaining_name, remaining_value);
174 }
175
176 let reset_name = HeaderName::from_static("x-rate-limit-reset");
178 if let Ok(reset_value) = HeaderValue::from_str("a.reset.to_string()) {
179 response.headers_mut().insert(reset_name, reset_value);
180 }
181 }
182
183 Ok(response)
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limiter_creation() {
        let config = RateLimitConfig::default();
        let limiter = GlobalRateLimiter::new(config);

        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_burst() {
        let config = RateLimitConfig {
            requests_per_minute: 10,
            burst: 5,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        for _ in 0..5 {
            assert!(limiter.check_rate_limit(), "Burst request should be allowed");
        }

        // At 10 requests/minute a new slot only frees up every ~6 seconds, so a request
        // made immediately after exhausting the burst must be rejected.
        assert!(!limiter.check_rate_limit(), "Request beyond the burst should be rejected");
    }

    #[test]
    fn test_zero_config_falls_back_to_defaults() {
        // Zero values cannot form a governor quota; construction must not panic and
        // the limiter must still admit requests.
        let config = RateLimitConfig {
            requests_per_minute: 0,
            burst: 0,
            per_ip: false,
            per_endpoint: false,
        };

        let limiter = GlobalRateLimiter::new(config);

        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_quota_info_reports_configured_limit() {
        let config = RateLimitConfig {
            requests_per_minute: 10,
            ..RateLimitConfig::default()
        };
        let limiter = GlobalRateLimiter::new(config);

        let quota = limiter.get_quota_info();
        assert_eq!(quota.limit, 10);
        assert!(quota.remaining <= 10);
        assert!(quota.reset > 0);
    }

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.burst, 200);
        assert!(config.per_ip);
        assert!(!config.per_endpoint);
    }
}