datasynth_server/rest/
rate_limit.rs1use axum::{
6 body::Body,
7 http::{header::HeaderValue, Request, StatusCode},
8 middleware::Next,
9 response::{IntoResponse, Response},
10};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::RwLock;
15
/// Settings that control per-client request throttling.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Master switch for rate limiting.
    pub enabled: bool,
    /// Number of requests one client may make within a single `window`.
    pub max_requests: u32,
    /// Length of the counting window.
    pub window: Duration,
    /// Paths that are never rate limited.
    pub exempt_paths: Vec<String>,
}

impl Default for RateLimitConfig {
    /// Rate limiting switched off, with a nominal budget of 100 requests per
    /// 60 seconds and the `/health`, `/ready` and `/live` paths exempt.
    fn default() -> Self {
        let exempt_paths = ["/health", "/ready", "/live"]
            .iter()
            .map(|&path| String::from(path))
            .collect();

        Self {
            enabled: false,
            max_requests: 100,
            window: Duration::from_secs(60),
            exempt_paths,
        }
    }
}

impl RateLimitConfig {
    /// Builds an *enabled* configuration allowing `max_requests` per window
    /// of `window_secs` seconds, keeping the default exempt paths.
    pub fn new(max_requests: u32, window_secs: u64) -> Self {
        Self {
            enabled: true,
            max_requests,
            window: Duration::from_secs(window_secs),
            // Only the exempt paths are inherited from the defaults.
            ..Self::default()
        }
    }

    /// Appends `paths` to the exempt list (existing entries are kept) and
    /// returns the updated configuration for chaining.
    pub fn with_exempt_paths(mut self, paths: Vec<String>) -> Self {
        let mut additional = paths;
        self.exempt_paths.append(&mut additional);
        self
    }
}
65
/// Counter state for one client key within one counting window.
#[derive(Clone)]
struct RequestRecord {
    /// Requests counted since `window_start`.
    count: u32,
    /// Instant at which the current window was opened.
    window_start: Instant,
}
72
73#[derive(Clone)]
75pub struct RateLimiter {
76 config: RateLimitConfig,
77 records: Arc<RwLock<HashMap<String, RequestRecord>>>,
78}
79
80impl RateLimiter {
81 pub fn new(config: RateLimitConfig) -> Self {
83 Self {
84 config,
85 records: Arc::new(RwLock::new(HashMap::new())),
86 }
87 }
88
89 pub async fn check_rate_limit(&self, key: &str) -> bool {
91 if !self.config.enabled {
92 return true;
93 }
94
95 let mut records = self.records.write().await;
96 let now = Instant::now();
97
98 match records.get_mut(key) {
99 Some(record) => {
100 if now.duration_since(record.window_start) >= self.config.window {
102 record.count = 1;
104 record.window_start = now;
105 true
106 } else if record.count < self.config.max_requests {
107 record.count += 1;
109 true
110 } else {
111 false
113 }
114 }
115 None => {
116 records.insert(
118 key.to_string(),
119 RequestRecord {
120 count: 1,
121 window_start: now,
122 },
123 );
124 true
125 }
126 }
127 }
128
129 pub async fn remaining(&self, key: &str) -> u32 {
131 if !self.config.enabled {
132 return self.config.max_requests;
133 }
134
135 let records = self.records.read().await;
136 match records.get(key) {
137 Some(record) => {
138 let now = Instant::now();
139 if now.duration_since(record.window_start) >= self.config.window {
140 self.config.max_requests
141 } else {
142 self.config.max_requests.saturating_sub(record.count)
143 }
144 }
145 None => self.config.max_requests,
146 }
147 }
148
149 pub async fn cleanup_expired(&self) {
151 let mut records = self.records.write().await;
152 let now = Instant::now();
153 records.retain(|_, record| now.duration_since(record.window_start) < self.config.window);
154 }
155}
156
157pub async fn rate_limit_middleware(
159 axum::Extension(limiter): axum::Extension<RateLimiter>,
160 request: Request<Body>,
161 next: Next,
162) -> Response {
163 let path = request.uri().path();
165 if limiter
166 .config
167 .exempt_paths
168 .iter()
169 .any(|p| path.starts_with(p))
170 {
171 return next.run(request).await;
172 }
173
174 let client_key = extract_client_key(&request);
176
177 if limiter.check_rate_limit(&client_key).await {
179 let remaining = limiter.remaining(&client_key).await;
180 let mut response = next.run(request).await;
181
182 let headers = response.headers_mut();
184 if let Ok(val) = HeaderValue::try_from(limiter.config.max_requests.to_string()) {
185 headers.insert("X-RateLimit-Limit", val);
186 }
187 if let Ok(val) = HeaderValue::try_from(remaining.to_string()) {
188 headers.insert("X-RateLimit-Remaining", val);
189 }
190
191 response
192 } else {
193 let window_secs = limiter.config.window.as_secs();
195 (
196 StatusCode::TOO_MANY_REQUESTS,
197 [
198 ("X-RateLimit-Limit", limiter.config.max_requests.to_string()),
199 ("X-RateLimit-Remaining", "0".to_string()),
200 ("Retry-After", window_secs.to_string()),
201 ],
202 format!(
203 "Rate limit exceeded. Max {} requests per {} seconds.",
204 limiter.config.max_requests, window_secs
205 ),
206 )
207 .into_response()
208 }
209}
210
211fn extract_client_key(request: &Request<Body>) -> String {
213 if let Some(forwarded) = request.headers().get("X-Forwarded-For") {
215 if let Ok(s) = forwarded.to_str() {
216 if let Some(ip) = s.split(',').next() {
217 return ip.trim().to_string();
218 }
219 }
220 }
221
222 if let Some(real_ip) = request.headers().get("X-Real-IP") {
224 if let Ok(s) = real_ip.to_str() {
225 return s.to_string();
226 }
227 }
228
229 "unknown".to_string()
231}
232
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, middleware, routing::get, Router};
    use tower::ServiceExt;

    /// Trivial handler; the tests only look at status codes and headers.
    async fn test_handler() -> &'static str {
        "ok"
    }

    /// Router with one limited route (`/api/test`) and one route that is
    /// exempt under the default configuration (`/health`).
    fn test_router(config: RateLimitConfig) -> Router {
        let limiter = RateLimiter::new(config);
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(rate_limit_middleware))
            .layer(axum::Extension(limiter))
    }

    /// GET request for `path` that identifies itself as `client_ip`.
    fn request_from(path: &str, client_ip: &str) -> Request<Body> {
        Request::builder()
            .uri(path)
            .header("X-Forwarded-For", client_ip)
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn test_rate_limit_disabled() {
        let router = test_router(RateLimitConfig::default());

        let request = Request::builder()
            .uri("/api/test")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_rate_limit_allows_under_limit() {
        let router = test_router(RateLimitConfig::new(5, 60));

        for _ in 0..3 {
            let response = router
                .clone()
                .oneshot(request_from("/api/test", "192.168.1.1"))
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn test_rate_limit_blocks_over_limit() {
        let router = test_router(RateLimitConfig::new(2, 60));

        for attempt in 0..3 {
            let response = router
                .clone()
                .oneshot(request_from("/api/test", "192.168.1.100"))
                .await
                .unwrap();

            if attempt < 2 {
                assert_eq!(response.status(), StatusCode::OK);
            } else {
                assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
                assert!(response.headers().contains_key("retry-after"));
            }
        }
    }

    #[tokio::test]
    async fn test_rate_limit_exempt_path() {
        let router = test_router(RateLimitConfig::new(1, 60));
        let client = "192.168.1.200";

        // Spend the whole budget on the limited route...
        let first = router
            .clone()
            .oneshot(request_from("/api/test", client))
            .await
            .unwrap();
        assert_eq!(first.status(), StatusCode::OK);

        // ...and prove it is really exhausted...
        let second = router
            .clone()
            .oneshot(request_from("/api/test", client))
            .await
            .unwrap();
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);

        // ...so that a successful probe can only be due to the exemption.
        let probe = router
            .oneshot(request_from("/health", client))
            .await
            .unwrap();
        assert_eq!(probe.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_rate_limiter_cleanup() {
        // A short window keeps the test fast. The sleep below lasts at least
        // twice the window, so the record is guaranteed to have expired.
        let window = Duration::from_millis(50);
        let config = RateLimitConfig {
            window,
            ..RateLimitConfig::new(10, 1)
        };
        let limiter = RateLimiter::new(config);

        assert!(limiter.check_rate_limit("test-client").await);
        assert!(limiter.records.read().await.contains_key("test-client"));

        tokio::time::sleep(window * 2).await;
        limiter.cleanup_expired().await;

        assert!(!limiter.records.read().await.contains_key("test-client"));
    }
}