datasynth_server/rest/
rate_limit.rs1use axum::{
6 body::Body,
7 http::{Request, StatusCode},
8 middleware::Next,
9 response::{IntoResponse, Response},
10};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::RwLock;
15
/// Settings for the per-client request rate limiter.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Master switch; when `false` every request is admitted.
    pub enabled: bool,
    /// Maximum number of requests a single client may make within one `window`.
    pub max_requests: u32,
    /// Length of the fixed window over which `max_requests` is counted.
    pub window: Duration,
    /// Request paths (and the paths beneath them) that bypass rate limiting.
    pub exempt_paths: Vec<String>,
}

/// Probe endpoints exempt from rate limiting in every preset configuration.
const DEFAULT_EXEMPT_PATHS: [&str; 3] = ["/health", "/ready", "/live"];

impl Default for RateLimitConfig {
    /// A *disabled* configuration: 100 requests per 60 seconds, probe endpoints exempt.
    fn default() -> Self {
        Self {
            enabled: false,
            max_requests: 100,
            window: Duration::from_secs(60),
            exempt_paths: DEFAULT_EXEMPT_PATHS
                .iter()
                .map(|&path| path.to_owned())
                .collect(),
        }
    }
}

impl RateLimitConfig {
    /// Creates an *enabled* configuration admitting `max_requests` per `window_secs` seconds.
    ///
    /// The default probe endpoints (`/health`, `/ready`, `/live`) remain exempt.
    pub fn new(max_requests: u32, window_secs: u64) -> Self {
        Self {
            enabled: true,
            max_requests,
            window: Duration::from_secs(window_secs),
            // Reuse the defaults so the exempt list is defined in exactly one place.
            ..Self::default()
        }
    }

    /// Appends `paths` to the current exempt list; existing entries (including the defaults)
    /// are kept, not replaced.
    #[must_use]
    pub fn with_exempt_paths(mut self, paths: Vec<String>) -> Self {
        self.exempt_paths.extend(paths);
        self
    }
}
65
66#[derive(Clone)]
68struct RequestRecord {
69 count: u32,
70 window_start: Instant,
71}
72
73#[derive(Clone)]
75pub struct RateLimiter {
76 config: RateLimitConfig,
77 records: Arc<RwLock<HashMap<String, RequestRecord>>>,
78}
79
80impl RateLimiter {
81 pub fn new(config: RateLimitConfig) -> Self {
83 Self {
84 config,
85 records: Arc::new(RwLock::new(HashMap::new())),
86 }
87 }
88
89 pub async fn check_rate_limit(&self, key: &str) -> bool {
91 if !self.config.enabled {
92 return true;
93 }
94
95 let mut records = self.records.write().await;
96 let now = Instant::now();
97
98 match records.get_mut(key) {
99 Some(record) => {
100 if now.duration_since(record.window_start) >= self.config.window {
102 record.count = 1;
104 record.window_start = now;
105 true
106 } else if record.count < self.config.max_requests {
107 record.count += 1;
109 true
110 } else {
111 false
113 }
114 }
115 None => {
116 records.insert(
118 key.to_string(),
119 RequestRecord {
120 count: 1,
121 window_start: now,
122 },
123 );
124 true
125 }
126 }
127 }
128
129 pub async fn remaining(&self, key: &str) -> u32 {
131 if !self.config.enabled {
132 return self.config.max_requests;
133 }
134
135 let records = self.records.read().await;
136 match records.get(key) {
137 Some(record) => {
138 let now = Instant::now();
139 if now.duration_since(record.window_start) >= self.config.window {
140 self.config.max_requests
141 } else {
142 self.config.max_requests.saturating_sub(record.count)
143 }
144 }
145 None => self.config.max_requests,
146 }
147 }
148
149 pub async fn cleanup_expired(&self) {
151 let mut records = self.records.write().await;
152 let now = Instant::now();
153 records.retain(|_, record| now.duration_since(record.window_start) < self.config.window);
154 }
155}
156
157pub async fn rate_limit_middleware(
159 axum::Extension(limiter): axum::Extension<RateLimiter>,
160 request: Request<Body>,
161 next: Next,
162) -> Response {
163 let path = request.uri().path();
165 if limiter
166 .config
167 .exempt_paths
168 .iter()
169 .any(|p| path.starts_with(p))
170 {
171 return next.run(request).await;
172 }
173
174 let client_key = extract_client_key(&request);
176
177 if limiter.check_rate_limit(&client_key).await {
179 let remaining = limiter.remaining(&client_key).await;
180 let mut response = next.run(request).await;
181
182 let headers = response.headers_mut();
184 headers.insert(
185 "X-RateLimit-Limit",
186 limiter.config.max_requests.to_string().parse().unwrap(),
187 );
188 headers.insert(
189 "X-RateLimit-Remaining",
190 remaining.to_string().parse().unwrap(),
191 );
192
193 response
194 } else {
195 let window_secs = limiter.config.window.as_secs();
197 (
198 StatusCode::TOO_MANY_REQUESTS,
199 [
200 ("X-RateLimit-Limit", limiter.config.max_requests.to_string()),
201 ("X-RateLimit-Remaining", "0".to_string()),
202 ("Retry-After", window_secs.to_string()),
203 ],
204 format!(
205 "Rate limit exceeded. Max {} requests per {} seconds.",
206 limiter.config.max_requests, window_secs
207 ),
208 )
209 .into_response()
210 }
211}
212
213fn extract_client_key(request: &Request<Body>) -> String {
215 if let Some(forwarded) = request.headers().get("X-Forwarded-For") {
217 if let Ok(s) = forwarded.to_str() {
218 if let Some(ip) = s.split(',').next() {
219 return ip.trim().to_string();
220 }
221 }
222 }
223
224 if let Some(real_ip) = request.headers().get("X-Real-IP") {
226 if let Ok(s) = real_ip.to_str() {
227 return s.to_string();
228 }
229 }
230
231 "unknown".to_string()
233}
234
#[cfg(test)]
// Unwrapping is acceptable here: a panic is reported as a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, middleware, routing::get, Router};
    use tower::ServiceExt;

    async fn test_handler() -> &'static str {
        "ok"
    }

    /// Router with `/api/test` (rate limited) and `/health` (exempt by default) behind the limiter.
    fn test_router(config: RateLimitConfig) -> Router {
        let limiter = RateLimiter::new(config);
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(rate_limit_middleware))
            .layer(axum::Extension(limiter))
    }

    /// Builds a GET request for `uri` attributed to `client_ip` via `X-Forwarded-For`.
    fn request_from(uri: &str, client_ip: &str) -> Request<Body> {
        Request::builder()
            .uri(uri)
            .header("X-Forwarded-For", client_ip)
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn test_rate_limit_disabled() {
        let router = test_router(RateLimitConfig::default());

        let request = Request::builder()
            .uri("/api/test")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_rate_limit_allows_under_limit() {
        let router = test_router(RateLimitConfig::new(5, 60));

        for _ in 0..3 {
            let response = router
                .clone()
                .oneshot(request_from("/api/test", "192.168.1.1"))
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn test_rate_limit_blocks_over_limit() {
        let router = test_router(RateLimitConfig::new(2, 60));

        for i in 0..3 {
            let response = router
                .clone()
                .oneshot(request_from("/api/test", "192.168.1.100"))
                .await
                .unwrap();
            if i < 2 {
                assert_eq!(response.status(), StatusCode::OK);
            } else {
                assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(response.headers()["retry-after"], "60");
                assert_eq!(response.headers()["x-ratelimit-remaining"], "0");
            }
        }
    }

    #[tokio::test]
    async fn test_rate_limit_headers() {
        let router = test_router(RateLimitConfig::new(5, 60));

        let response = router
            .oneshot(request_from("/api/test", "10.0.0.1"))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.headers()["x-ratelimit-limit"], "5");
        assert_eq!(response.headers()["x-ratelimit-remaining"], "4");
    }

    #[tokio::test]
    async fn test_rate_limit_exempt_path() {
        let router = test_router(RateLimitConfig::new(1, 60));

        // Use up the single allowed request...
        let response = router
            .clone()
            .oneshot(request_from("/api/test", "192.168.1.200"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        // ...and confirm the client is now limited on non-exempt routes...
        let response = router
            .clone()
            .oneshot(request_from("/api/test", "192.168.1.200"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);

        // ...while the exempt health endpoint still answers.
        let response = router
            .oneshot(request_from("/health", "192.168.1.200"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_rate_limiter_remaining() {
        let limiter = RateLimiter::new(RateLimitConfig::new(3, 60));

        assert_eq!(limiter.remaining("client").await, 3);
        assert!(limiter.check_rate_limit("client").await);
        assert!(limiter.check_rate_limit("client").await);
        assert_eq!(limiter.remaining("client").await, 1);
    }

    #[tokio::test]
    async fn test_rate_limiter_cleanup() {
        // A short window keeps the test fast while still exercising real expiry.
        let config = RateLimitConfig {
            window: Duration::from_millis(50),
            ..RateLimitConfig::new(10, 1)
        };
        let limiter = RateLimiter::new(config);

        limiter.check_rate_limit("test-client").await;
        assert!(limiter.records.read().await.contains_key("test-client"));

        tokio::time::sleep(Duration::from_millis(60)).await;

        limiter.cleanup_expired().await;
        assert!(!limiter.records.read().await.contains_key("test-client"));
    }
}