Skip to main content

datasynth_server/rest/
rate_limit.rs

1//! Rate limiting middleware for REST API.
2//!
3//! Provides configurable rate limiting to prevent abuse.
4
5use axum::{
6    body::Body,
7    http::{Request, StatusCode},
8    middleware::Next,
9    response::{IntoResponse, Response},
10};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::RwLock;
15
/// Rate limiting configuration.
///
/// Each client may make at most `max_requests` requests per fixed `window`.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Whether rate limiting is enabled.
    pub enabled: bool,
    /// Maximum requests per window.
    pub max_requests: u32,
    /// Time window duration.
    pub window: Duration,
    /// Paths exempt from rate limiting.
    pub exempt_paths: Vec<String>,
}

/// Paths exempt from rate limiting unless the caller changes `exempt_paths`.
///
/// Shared by [`RateLimitConfig::default`] and [`RateLimitConfig::new`] so both constructors
/// always start from the same list.
fn default_exempt_paths() -> Vec<String> {
    ["/health", "/ready", "/live"]
        .iter()
        .copied()
        .map(String::from)
        .collect()
}

impl Default for RateLimitConfig {
    /// Disabled limiter with a 100 requests per minute budget and the default exempt paths.
    fn default() -> Self {
        Self {
            enabled: false,
            max_requests: 100,
            window: Duration::from_secs(60), // 100 requests per minute
            exempt_paths: default_exempt_paths(),
        }
    }
}

impl RateLimitConfig {
    /// Create an enabled rate limit config allowing `max_requests` per `window_secs` seconds.
    ///
    /// The default exempt paths are included; use [`Self::with_exempt_paths`] to add more.
    pub fn new(max_requests: u32, window_secs: u64) -> Self {
        Self {
            enabled: true,
            max_requests,
            window: Duration::from_secs(window_secs),
            exempt_paths: default_exempt_paths(),
        }
    }

    /// Append `paths` to the exempt list, keeping the paths already present.
    pub fn with_exempt_paths(mut self, paths: Vec<String>) -> Self {
        self.exempt_paths.extend(paths);
        self
    }
}
65
/// Per-client bookkeeping for the current fixed window.
#[derive(Clone)]
struct RequestRecord {
    /// Number of requests admitted since `window_start`.
    count: u32,
    /// Moment the client's current window began.
    window_start: Instant,
}
72
73/// Shared rate limiter state.
74#[derive(Clone)]
75pub struct RateLimiter {
76    config: RateLimitConfig,
77    records: Arc<RwLock<HashMap<String, RequestRecord>>>,
78}
79
80impl RateLimiter {
81    /// Create a new rate limiter.
82    pub fn new(config: RateLimitConfig) -> Self {
83        Self {
84            config,
85            records: Arc::new(RwLock::new(HashMap::new())),
86        }
87    }
88
89    /// Check if request should be allowed.
90    pub async fn check_rate_limit(&self, key: &str) -> bool {
91        if !self.config.enabled {
92            return true;
93        }
94
95        let mut records = self.records.write().await;
96        let now = Instant::now();
97
98        match records.get_mut(key) {
99            Some(record) => {
100                // Check if we're in a new window
101                if now.duration_since(record.window_start) >= self.config.window {
102                    // Reset for new window
103                    record.count = 1;
104                    record.window_start = now;
105                    true
106                } else if record.count < self.config.max_requests {
107                    // Within window and under limit
108                    record.count += 1;
109                    true
110                } else {
111                    // Rate limited
112                    false
113                }
114            }
115            None => {
116                // First request from this client
117                records.insert(
118                    key.to_string(),
119                    RequestRecord {
120                        count: 1,
121                        window_start: now,
122                    },
123                );
124                true
125            }
126        }
127    }
128
129    /// Get remaining requests for a key.
130    pub async fn remaining(&self, key: &str) -> u32 {
131        if !self.config.enabled {
132            return self.config.max_requests;
133        }
134
135        let records = self.records.read().await;
136        match records.get(key) {
137            Some(record) => {
138                let now = Instant::now();
139                if now.duration_since(record.window_start) >= self.config.window {
140                    self.config.max_requests
141                } else {
142                    self.config.max_requests.saturating_sub(record.count)
143                }
144            }
145            None => self.config.max_requests,
146        }
147    }
148
149    /// Clean up expired records.
150    pub async fn cleanup_expired(&self) {
151        let mut records = self.records.write().await;
152        let now = Instant::now();
153        records.retain(|_, record| now.duration_since(record.window_start) < self.config.window);
154    }
155}
156
157/// Rate limiting middleware.
158pub async fn rate_limit_middleware(
159    axum::Extension(limiter): axum::Extension<RateLimiter>,
160    request: Request<Body>,
161    next: Next,
162) -> Response {
163    // Check if path is exempt
164    let path = request.uri().path();
165    if limiter
166        .config
167        .exempt_paths
168        .iter()
169        .any(|p| path.starts_with(p))
170    {
171        return next.run(request).await;
172    }
173
174    // Get client identifier (IP address or fallback)
175    let client_key = extract_client_key(&request);
176
177    // Check rate limit
178    if limiter.check_rate_limit(&client_key).await {
179        let remaining = limiter.remaining(&client_key).await;
180        let mut response = next.run(request).await;
181
182        // Add rate limit headers
183        let headers = response.headers_mut();
184        headers.insert(
185            "X-RateLimit-Limit",
186            limiter.config.max_requests.to_string().parse().unwrap(),
187        );
188        headers.insert(
189            "X-RateLimit-Remaining",
190            remaining.to_string().parse().unwrap(),
191        );
192
193        response
194    } else {
195        // Rate limited
196        let window_secs = limiter.config.window.as_secs();
197        (
198            StatusCode::TOO_MANY_REQUESTS,
199            [
200                ("X-RateLimit-Limit", limiter.config.max_requests.to_string()),
201                ("X-RateLimit-Remaining", "0".to_string()),
202                ("Retry-After", window_secs.to_string()),
203            ],
204            format!(
205                "Rate limit exceeded. Max {} requests per {} seconds.",
206                limiter.config.max_requests, window_secs
207            ),
208        )
209            .into_response()
210    }
211}
212
213/// Extract client identifier from request.
214fn extract_client_key(request: &Request<Body>) -> String {
215    // Try X-Forwarded-For header (for proxied requests)
216    if let Some(forwarded) = request.headers().get("X-Forwarded-For") {
217        if let Ok(s) = forwarded.to_str() {
218            if let Some(ip) = s.split(',').next() {
219                return ip.trim().to_string();
220            }
221        }
222    }
223
224    // Try X-Real-IP header
225    if let Some(real_ip) = request.headers().get("X-Real-IP") {
226        if let Ok(s) = real_ip.to_str() {
227            return s.to_string();
228        }
229    }
230
231    // Fallback to a default (in production, you'd want to extract from connection info)
232    "unknown".to_string()
233}
234
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, middleware, routing::get, Router};
    use tower::ServiceExt;

    /// Handler that always answers `"ok"`.
    async fn test_handler() -> &'static str {
        "ok"
    }

    /// `/api/test` and `/health` routes behind the rate limiting middleware.
    fn test_router(config: RateLimitConfig) -> Router {
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(rate_limit_middleware))
            .layer(axum::Extension(RateLimiter::new(config)))
    }

    /// Bodiless GET request, optionally carrying an `X-Forwarded-For` client address.
    fn get_request(uri: &str, forwarded_for: Option<&str>) -> Request<Body> {
        let builder = Request::builder().uri(uri);
        let builder = match forwarded_for {
            Some(ip) => builder.header("X-Forwarded-For", ip),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    /// Send one request through a clone of `router` and return the response status.
    async fn status_of(router: &Router, request: Request<Body>) -> StatusCode {
        router.clone().oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_rate_limit_disabled() {
        let router = test_router(RateLimitConfig::default());

        let status = status_of(&router, get_request("/api/test", None)).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_rate_limit_allows_under_limit() {
        let router = test_router(RateLimitConfig::new(5, 60));

        // Three requests against a limit of five all succeed.
        for _ in 0..3 {
            let status = status_of(&router, get_request("/api/test", Some("192.168.1.1"))).await;
            assert_eq!(status, StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn test_rate_limit_blocks_over_limit() {
        let router = test_router(RateLimitConfig::new(2, 60));

        // Two requests fit in the window; the third is rejected.
        let expected = [StatusCode::OK, StatusCode::OK, StatusCode::TOO_MANY_REQUESTS];
        for want in expected {
            let status =
                status_of(&router, get_request("/api/test", Some("192.168.1.100"))).await;
            assert_eq!(status, want);
        }
    }

    #[tokio::test]
    async fn test_rate_limit_exempt_path() {
        let router = test_router(RateLimitConfig::new(1, 60));
        let client = Some("192.168.1.200");

        // Use up the single allowed request on a limited route.
        status_of(&router, get_request("/api/test", client)).await;

        // The exempt health route keeps answering for the same client.
        let status = status_of(&router, get_request("/health", client)).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_rate_limiter_cleanup() {
        let limiter = RateLimiter::new(RateLimitConfig::new(10, 1)); // 1 second window
        let key = "test-client";

        limiter.check_rate_limit(key).await;
        assert!(limiter.records.read().await.contains_key(key));

        // Let the window lapse, then purge.
        tokio::time::sleep(Duration::from_millis(1100)).await;
        limiter.cleanup_expired().await;

        assert!(!limiter.records.read().await.contains_key(key));
    }
}