Skip to main content

datasynth_server/rest/
rate_limit.rs

1//! Rate limiting middleware for REST API.
2//!
3//! Provides configurable rate limiting to prevent abuse.
4
5use axum::{
6    body::Body,
7    http::{header::HeaderValue, Request, StatusCode},
8    middleware::Next,
9    response::{IntoResponse, Response},
10};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::RwLock;
15
/// Rate limiting configuration.
///
/// Each client may make at most `max_requests` requests per `window`; further
/// requests are rejected until the window elapses. Paths listed in
/// `exempt_paths` are never limited.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Whether rate limiting is enabled.
    pub enabled: bool,
    /// Maximum requests per window.
    pub max_requests: u32,
    /// Time window duration.
    pub window: Duration,
    /// Request paths exempt from rate limiting.
    pub exempt_paths: Vec<String>,
}

/// Paths exempt from rate limiting unless configured otherwise.
///
/// Shared by [`RateLimitConfig::default`] and [`RateLimitConfig::new`] so the two
/// constructors cannot drift apart.
fn default_exempt_paths() -> Vec<String> {
    ["/health", "/ready", "/live"]
        .iter()
        .map(|&path| path.to_string())
        .collect()
}

impl Default for RateLimitConfig {
    /// Disabled; if enabled, allows 100 requests per minute and exempts
    /// `/health`, `/ready` and `/live`.
    fn default() -> Self {
        Self {
            enabled: false,
            max_requests: 100,
            window: Duration::from_secs(60), // 100 requests per minute
            exempt_paths: default_exempt_paths(),
        }
    }
}

impl RateLimitConfig {
    /// Create an enabled config allowing `max_requests` per `window_secs` seconds.
    ///
    /// The default exempt paths (`/health`, `/ready`, `/live`) are kept; use
    /// [`with_exempt_paths`](Self::with_exempt_paths) to add more.
    ///
    /// Note: a `window_secs` of 0 makes every request start a fresh window, which
    /// effectively disables limiting.
    pub fn new(max_requests: u32, window_secs: u64) -> Self {
        Self {
            enabled: true,
            max_requests,
            window: Duration::from_secs(window_secs),
            exempt_paths: default_exempt_paths(),
        }
    }

    /// Add exempt paths on top of the ones already configured.
    #[must_use]
    pub fn with_exempt_paths(mut self, paths: Vec<String>) -> Self {
        self.exempt_paths.extend(paths);
        self
    }
}
65
66/// Request record for rate limiting.
67#[derive(Clone)]
68struct RequestRecord {
69    count: u32,
70    window_start: Instant,
71}
72
73/// Shared rate limiter state.
74#[derive(Clone)]
75pub struct RateLimiter {
76    config: RateLimitConfig,
77    records: Arc<RwLock<HashMap<String, RequestRecord>>>,
78}
79
80impl RateLimiter {
81    /// Create a new rate limiter.
82    pub fn new(config: RateLimitConfig) -> Self {
83        Self {
84            config,
85            records: Arc::new(RwLock::new(HashMap::new())),
86        }
87    }
88
89    /// Check if request should be allowed.
90    pub async fn check_rate_limit(&self, key: &str) -> bool {
91        if !self.config.enabled {
92            return true;
93        }
94
95        let mut records = self.records.write().await;
96        let now = Instant::now();
97
98        match records.get_mut(key) {
99            Some(record) => {
100                // Check if we're in a new window
101                if now.duration_since(record.window_start) >= self.config.window {
102                    // Reset for new window
103                    record.count = 1;
104                    record.window_start = now;
105                    true
106                } else if record.count < self.config.max_requests {
107                    // Within window and under limit
108                    record.count += 1;
109                    true
110                } else {
111                    // Rate limited
112                    false
113                }
114            }
115            None => {
116                // First request from this client
117                records.insert(
118                    key.to_string(),
119                    RequestRecord {
120                        count: 1,
121                        window_start: now,
122                    },
123                );
124                true
125            }
126        }
127    }
128
129    /// Get remaining requests for a key.
130    pub async fn remaining(&self, key: &str) -> u32 {
131        if !self.config.enabled {
132            return self.config.max_requests;
133        }
134
135        let records = self.records.read().await;
136        match records.get(key) {
137            Some(record) => {
138                let now = Instant::now();
139                if now.duration_since(record.window_start) >= self.config.window {
140                    self.config.max_requests
141                } else {
142                    self.config.max_requests.saturating_sub(record.count)
143                }
144            }
145            None => self.config.max_requests,
146        }
147    }
148
149    /// Clean up expired records.
150    pub async fn cleanup_expired(&self) {
151        let mut records = self.records.write().await;
152        let now = Instant::now();
153        records.retain(|_, record| now.duration_since(record.window_start) < self.config.window);
154    }
155}
156
157/// Rate limiting middleware.
158pub async fn rate_limit_middleware(
159    axum::Extension(limiter): axum::Extension<RateLimiter>,
160    request: Request<Body>,
161    next: Next,
162) -> Response {
163    // Check if path is exempt
164    let path = request.uri().path();
165    if limiter
166        .config
167        .exempt_paths
168        .iter()
169        .any(|p| path.starts_with(p))
170    {
171        return next.run(request).await;
172    }
173
174    // Get client identifier (IP address or fallback)
175    let client_key = extract_client_key(&request);
176
177    // Check rate limit
178    if limiter.check_rate_limit(&client_key).await {
179        let remaining = limiter.remaining(&client_key).await;
180        let mut response = next.run(request).await;
181
182        // Add rate limit headers
183        let headers = response.headers_mut();
184        if let Ok(val) = HeaderValue::try_from(limiter.config.max_requests.to_string()) {
185            headers.insert("X-RateLimit-Limit", val);
186        }
187        if let Ok(val) = HeaderValue::try_from(remaining.to_string()) {
188            headers.insert("X-RateLimit-Remaining", val);
189        }
190
191        response
192    } else {
193        // Rate limited
194        let window_secs = limiter.config.window.as_secs();
195        (
196            StatusCode::TOO_MANY_REQUESTS,
197            [
198                ("X-RateLimit-Limit", limiter.config.max_requests.to_string()),
199                ("X-RateLimit-Remaining", "0".to_string()),
200                ("Retry-After", window_secs.to_string()),
201            ],
202            format!(
203                "Rate limit exceeded. Max {} requests per {} seconds.",
204                limiter.config.max_requests, window_secs
205            ),
206        )
207            .into_response()
208    }
209}
210
211/// Extract client identifier from request.
212fn extract_client_key(request: &Request<Body>) -> String {
213    // Try X-Forwarded-For header (for proxied requests)
214    if let Some(forwarded) = request.headers().get("X-Forwarded-For") {
215        if let Ok(s) = forwarded.to_str() {
216            if let Some(ip) = s.split(',').next() {
217                return ip.trim().to_string();
218            }
219        }
220    }
221
222    // Try X-Real-IP header
223    if let Some(real_ip) = request.headers().get("X-Real-IP") {
224        if let Ok(s) = real_ip.to_str() {
225            return s.to_string();
226        }
227    }
228
229    // Fallback to a default (in production, you'd want to extract from connection info)
230    "unknown".to_string()
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, middleware, routing::get, Router};
    use tower::ServiceExt;

    /// Trivial handler shared by every test route.
    async fn test_handler() -> &'static str {
        "ok"
    }

    /// Router exposing `/api/test` and `/health` behind the rate limiting middleware,
    /// backed by a fresh limiter built from `config`.
    fn test_router(config: RateLimitConfig) -> Router {
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(rate_limit_middleware))
            .layer(axum::Extension(RateLimiter::new(config)))
    }

    /// Empty-bodied request for `uri`, optionally carrying an `X-Forwarded-For` address.
    fn build_request(uri: &str, forwarded_for: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri(uri);
        if let Some(ip) = forwarded_for {
            builder = builder.header("X-Forwarded-For", ip);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Send `request` through a clone of `router` and return the response status.
    async fn status_of(router: &Router, request: Request<Body>) -> StatusCode {
        router.clone().oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_rate_limit_disabled() {
        let router = test_router(RateLimitConfig::default());
        let status = status_of(&router, build_request("/api/test", None)).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_rate_limit_allows_under_limit() {
        let router = test_router(RateLimitConfig::new(5, 60));

        // Three requests against a limit of five must all pass.
        for _ in 0..3 {
            let request = build_request("/api/test", Some("192.168.1.1"));
            assert_eq!(status_of(&router, request).await, StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn test_rate_limit_blocks_over_limit() {
        let router = test_router(RateLimitConfig::new(2, 60));

        // With a limit of two, the third request from the same client is rejected.
        let expected = [StatusCode::OK, StatusCode::OK, StatusCode::TOO_MANY_REQUESTS];
        for want in expected {
            let request = build_request("/api/test", Some("192.168.1.100"));
            assert_eq!(status_of(&router, request).await, want);
        }
    }

    #[tokio::test]
    async fn test_rate_limit_exempt_path() {
        let router = test_router(RateLimitConfig::new(1, 60));
        let client = Some("192.168.1.200");

        // The single allowed request is spent on a limited route...
        let _ = status_of(&router, build_request("/api/test", client)).await;

        // ...yet the exempt /health route keeps answering.
        let status = status_of(&router, build_request("/health", client)).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_rate_limiter_cleanup() {
        // One-second window so the record expires quickly.
        let limiter = RateLimiter::new(RateLimitConfig::new(10, 1));
        let key = "test-client";

        limiter.check_rate_limit(key).await;
        assert!(limiter.records.read().await.contains_key(key));

        // Outlive the window, then purge expired records.
        tokio::time::sleep(Duration::from_millis(1100)).await;
        limiter.cleanup_expired().await;

        assert!(!limiter.records.read().await.contains_key(key));
    }
}