datasynth_server/rest/
rate_limit.rs1use axum::{
6 body::Body,
7 http::{header::HeaderValue, Request, StatusCode},
8 middleware::Next,
9 response::{IntoResponse, Response},
10};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::RwLock;
15
/// Settings that control how requests are rate limited.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// When `false`, the limiter admits every request without counting it.
    pub enabled: bool,
    /// Number of requests one client key may make within a single window.
    pub max_requests: u32,
    /// Length of the counting window.
    pub window: Duration,
    /// Paths (matched as prefixes of the request path) that bypass rate limiting.
    pub exempt_paths: Vec<String>,
}

impl Default for RateLimitConfig {
    /// Returns a *disabled* configuration: 100 requests per 60-second window,
    /// with `/health`, `/ready` and `/live` exempt.
    fn default() -> Self {
        Self {
            enabled: false,
            max_requests: 100,
            window: Duration::from_secs(60),
            exempt_paths: Self::probe_paths(),
        }
    }
}

impl RateLimitConfig {
    /// Paths that every configuration exempts out of the box.
    fn probe_paths() -> Vec<String> {
        ["/health", "/ready", "/live"]
            .iter()
            .map(|&path| path.to_owned())
            .collect()
    }

    /// Creates an *enabled* configuration allowing `max_requests` per window of
    /// `window_secs` seconds, with the same exempt paths as [`Default`].
    pub fn new(max_requests: u32, window_secs: u64) -> Self {
        Self {
            enabled: true,
            max_requests,
            window: Duration::from_secs(window_secs),
            ..Self::default()
        }
    }

    /// Appends `paths` to the exempt list; entries already present are kept.
    pub fn with_exempt_paths(mut self, mut paths: Vec<String>) -> Self {
        self.exempt_paths.append(&mut paths);
        self
    }
}
65
66#[derive(Clone)]
68struct RequestRecord {
69 count: u32,
70 window_start: Instant,
71}
72
73#[derive(Clone)]
75pub struct RateLimiter {
76 config: RateLimitConfig,
77 records: Arc<RwLock<HashMap<String, RequestRecord>>>,
78}
79
80impl RateLimiter {
81 pub fn new(config: RateLimitConfig) -> Self {
83 Self {
84 config,
85 records: Arc::new(RwLock::new(HashMap::new())),
86 }
87 }
88
89 pub async fn check_rate_limit(&self, key: &str) -> bool {
91 if !self.config.enabled {
92 return true;
93 }
94
95 let mut records = self.records.write().await;
96 let now = Instant::now();
97
98 match records.get_mut(key) {
99 Some(record) => {
100 if now.duration_since(record.window_start) >= self.config.window {
102 record.count = 1;
104 record.window_start = now;
105 true
106 } else if record.count < self.config.max_requests {
107 record.count += 1;
109 true
110 } else {
111 false
113 }
114 }
115 None => {
116 records.insert(
118 key.to_string(),
119 RequestRecord {
120 count: 1,
121 window_start: now,
122 },
123 );
124 true
125 }
126 }
127 }
128
129 pub async fn remaining(&self, key: &str) -> u32 {
131 if !self.config.enabled {
132 return self.config.max_requests;
133 }
134
135 let records = self.records.read().await;
136 match records.get(key) {
137 Some(record) => {
138 let now = Instant::now();
139 if now.duration_since(record.window_start) >= self.config.window {
140 self.config.max_requests
141 } else {
142 self.config.max_requests.saturating_sub(record.count)
143 }
144 }
145 None => self.config.max_requests,
146 }
147 }
148
149 pub async fn cleanup_expired(&self) {
151 let mut records = self.records.write().await;
152 let now = Instant::now();
153 records.retain(|_, record| now.duration_since(record.window_start) < self.config.window);
154 }
155}
156
157pub async fn rate_limit_middleware(
159 axum::Extension(limiter): axum::Extension<RateLimiter>,
160 request: Request<Body>,
161 next: Next,
162) -> Response {
163 let path = request.uri().path();
165 if limiter
166 .config
167 .exempt_paths
168 .iter()
169 .any(|p| path.starts_with(p))
170 {
171 return next.run(request).await;
172 }
173
174 let client_key = extract_client_key(&request);
176
177 if limiter.check_rate_limit(&client_key).await {
179 let remaining = limiter.remaining(&client_key).await;
180 let mut response = next.run(request).await;
181
182 let headers = response.headers_mut();
184 if let Ok(val) = HeaderValue::try_from(limiter.config.max_requests.to_string()) {
185 headers.insert("X-RateLimit-Limit", val);
186 }
187 if let Ok(val) = HeaderValue::try_from(remaining.to_string()) {
188 headers.insert("X-RateLimit-Remaining", val);
189 }
190
191 response
192 } else {
193 let window_secs = limiter.config.window.as_secs();
195 (
196 StatusCode::TOO_MANY_REQUESTS,
197 [
198 ("X-RateLimit-Limit", limiter.config.max_requests.to_string()),
199 ("X-RateLimit-Remaining", "0".to_string()),
200 ("Retry-After", window_secs.to_string()),
201 ],
202 format!(
203 "Rate limit exceeded. Max {} requests per {} seconds.",
204 limiter.config.max_requests, window_secs
205 ),
206 )
207 .into_response()
208 }
209}
210
211fn extract_client_key(request: &Request<Body>) -> String {
213 if let Some(forwarded) = request.headers().get("X-Forwarded-For") {
215 if let Ok(s) = forwarded.to_str() {
216 if let Some(ip) = s.split(',').next() {
217 return ip.trim().to_string();
218 }
219 }
220 }
221
222 if let Some(real_ip) = request.headers().get("X-Real-IP") {
224 if let Ok(s) = real_ip.to_str() {
225 return s.to_string();
226 }
227 }
228
229 "unknown".to_string()
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, middleware, routing::get, Router};
    use tower::ServiceExt;

    /// Trivial handler; the tests only inspect status codes.
    async fn test_handler() -> &'static str {
        "ok"
    }

    /// Builds a router with one limited route (`/api/test`) and one exempt
    /// route (`/health`), both behind the rate-limit middleware.
    fn test_router(config: RateLimitConfig) -> Router {
        let limiter = RateLimiter::new(config);
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(rate_limit_middleware))
            .layer(axum::Extension(limiter))
    }

    /// Builds a body-less GET request, optionally tagged with a client IP via
    /// `X-Forwarded-For`.
    fn get_request(uri: &str, client_ip: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri(uri);
        if let Some(ip) = client_ip {
            builder = builder.header("X-Forwarded-For", ip);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn test_rate_limit_disabled() {
        let router = test_router(RateLimitConfig::default());

        let response = router
            .oneshot(get_request("/api/test", None))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_rate_limit_allows_under_limit() {
        let router = test_router(RateLimitConfig::new(5, 60));

        for _ in 0..3 {
            let response = router
                .clone()
                .oneshot(get_request("/api/test", Some("192.168.1.1")))
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn test_rate_limit_blocks_over_limit() {
        let router = test_router(RateLimitConfig::new(2, 60));

        // Two requests fit the budget; the third must be rejected.
        let expected = [
            StatusCode::OK,
            StatusCode::OK,
            StatusCode::TOO_MANY_REQUESTS,
        ];
        for status in expected {
            let response = router
                .clone()
                .oneshot(get_request("/api/test", Some("192.168.1.100")))
                .await
                .unwrap();
            assert_eq!(response.status(), status);
        }
    }

    #[tokio::test]
    async fn test_rate_limit_exempt_path() {
        let router = test_router(RateLimitConfig::new(1, 60));

        // Use up the whole budget on the limited route.
        let response = router
            .clone()
            .oneshot(get_request("/api/test", Some("192.168.1.200")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        // The exempt route still answers for the same client.
        let response = router
            .oneshot(get_request("/health", Some("192.168.1.200")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_rate_limiter_cleanup() {
        // A zero-length window makes a record expired as soon as it exists,
        // so the test needs no wall-clock sleep.
        let expiring = RateLimiter::new(RateLimitConfig {
            window: Duration::ZERO,
            ..RateLimitConfig::new(10, 1)
        });
        expiring.check_rate_limit("test-client").await;
        assert!(expiring.records.read().await.contains_key("test-client"));

        expiring.cleanup_expired().await;
        assert!(!expiring.records.read().await.contains_key("test-client"));

        // Records whose window is still open must survive cleanup.
        let long_lived = RateLimiter::new(RateLimitConfig::new(10, 60));
        long_lived.check_rate_limit("test-client").await;
        long_lived.cleanup_expired().await;
        assert!(long_lived.records.read().await.contains_key("test-client"));
    }
}