Skip to main content

datasynth_server/rest/
rate_limit.rs

1//! Rate limiting middleware for REST API.
2//!
3//! Provides configurable rate limiting to prevent abuse.
4
5use axum::{
6    body::Body,
7    http::{header::HeaderValue, Request, StatusCode},
8    middleware::Next,
9    response::{IntoResponse, Response},
10};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::RwLock;
15
/// Settings that control how the REST rate limiter behaves.
///
/// Each client is granted `max_requests` requests per `window`. Requests to
/// one of the `exempt_paths` are not counted against that budget.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Master switch; when `false` every request is allowed.
    pub enabled: bool,
    /// Request budget granted to each client per window.
    pub max_requests: u32,
    /// Length of a single counting window.
    pub window: Duration,
    /// Paths that bypass rate limiting (for example health probes).
    pub exempt_paths: Vec<String>,
}
28
29impl Default for RateLimitConfig {
30    fn default() -> Self {
31        Self {
32            enabled: false,
33            max_requests: 100,
34            window: Duration::from_secs(60), // 100 requests per minute
35            exempt_paths: vec![
36                "/health".to_string(),
37                "/ready".to_string(),
38                "/live".to_string(),
39            ],
40        }
41    }
42}
43
44impl RateLimitConfig {
45    /// Create a new rate limit config with custom limits.
46    pub fn new(max_requests: u32, window_secs: u64) -> Self {
47        Self {
48            enabled: true,
49            max_requests,
50            window: Duration::from_secs(window_secs),
51            exempt_paths: vec![
52                "/health".to_string(),
53                "/ready".to_string(),
54                "/live".to_string(),
55            ],
56        }
57    }
58
59    /// Add exempt paths.
60    pub fn with_exempt_paths(mut self, paths: Vec<String>) -> Self {
61        self.exempt_paths.extend(paths);
62        self
63    }
64}
65
/// Per-client bookkeeping for the current counting window.
#[derive(Clone)]
struct RequestRecord {
    /// Number of requests counted since `window_start`.
    count: u32,
    /// Instant at which the current window began.
    window_start: Instant,
}
72
73/// Shared rate limiter state.
74#[derive(Clone)]
75pub struct RateLimiter {
76    config: RateLimitConfig,
77    records: Arc<RwLock<HashMap<String, RequestRecord>>>,
78}
79
80impl RateLimiter {
81    /// Create a new rate limiter.
82    pub fn new(config: RateLimitConfig) -> Self {
83        Self {
84            config,
85            records: Arc::new(RwLock::new(HashMap::new())),
86        }
87    }
88
89    /// Check if request should be allowed.
90    pub async fn check_rate_limit(&self, key: &str) -> bool {
91        if !self.config.enabled {
92            return true;
93        }
94
95        let mut records = self.records.write().await;
96        let now = Instant::now();
97
98        match records.get_mut(key) {
99            Some(record) => {
100                // Check if we're in a new window
101                if now.duration_since(record.window_start) >= self.config.window {
102                    // Reset for new window
103                    record.count = 1;
104                    record.window_start = now;
105                    true
106                } else if record.count < self.config.max_requests {
107                    // Within window and under limit
108                    record.count += 1;
109                    true
110                } else {
111                    // Rate limited
112                    false
113                }
114            }
115            None => {
116                // First request from this client
117                records.insert(
118                    key.to_string(),
119                    RequestRecord {
120                        count: 1,
121                        window_start: now,
122                    },
123                );
124                true
125            }
126        }
127    }
128
129    /// Get remaining requests for a key.
130    pub async fn remaining(&self, key: &str) -> u32 {
131        if !self.config.enabled {
132            return self.config.max_requests;
133        }
134
135        let records = self.records.read().await;
136        match records.get(key) {
137            Some(record) => {
138                let now = Instant::now();
139                if now.duration_since(record.window_start) >= self.config.window {
140                    self.config.max_requests
141                } else {
142                    self.config.max_requests.saturating_sub(record.count)
143                }
144            }
145            None => self.config.max_requests,
146        }
147    }
148
149    /// Clean up expired records.
150    pub async fn cleanup_expired(&self) {
151        let mut records = self.records.write().await;
152        let now = Instant::now();
153        records.retain(|_, record| now.duration_since(record.window_start) < self.config.window);
154    }
155}
156
157/// Rate limiting middleware.
158pub async fn rate_limit_middleware(
159    axum::Extension(limiter): axum::Extension<RateLimiter>,
160    request: Request<Body>,
161    next: Next,
162) -> Response {
163    // Check if path is exempt
164    let path = request.uri().path();
165    if limiter
166        .config
167        .exempt_paths
168        .iter()
169        .any(|p| path.starts_with(p))
170    {
171        return next.run(request).await;
172    }
173
174    // Get client identifier (IP address or fallback)
175    let client_key = extract_client_key(&request);
176
177    // Check rate limit
178    if limiter.check_rate_limit(&client_key).await {
179        let remaining = limiter.remaining(&client_key).await;
180        let mut response = next.run(request).await;
181
182        // Add rate limit headers
183        let headers = response.headers_mut();
184        if let Ok(val) = HeaderValue::try_from(limiter.config.max_requests.to_string()) {
185            headers.insert("X-RateLimit-Limit", val);
186        }
187        if let Ok(val) = HeaderValue::try_from(remaining.to_string()) {
188            headers.insert("X-RateLimit-Remaining", val);
189        }
190
191        response
192    } else {
193        // Rate limited
194        let window_secs = limiter.config.window.as_secs();
195        (
196            StatusCode::TOO_MANY_REQUESTS,
197            [
198                ("X-RateLimit-Limit", limiter.config.max_requests.to_string()),
199                ("X-RateLimit-Remaining", "0".to_string()),
200                ("Retry-After", window_secs.to_string()),
201            ],
202            format!(
203                "Rate limit exceeded. Max {} requests per {} seconds.",
204                limiter.config.max_requests, window_secs
205            ),
206        )
207            .into_response()
208    }
209}
210
211/// Extract client identifier from request.
212fn extract_client_key(request: &Request<Body>) -> String {
213    // Try X-Forwarded-For header (for proxied requests)
214    if let Some(forwarded) = request.headers().get("X-Forwarded-For") {
215        if let Ok(s) = forwarded.to_str() {
216            if let Some(ip) = s.split(',').next() {
217                return ip.trim().to_string();
218            }
219        }
220    }
221
222    // Try X-Real-IP header
223    if let Some(real_ip) = request.headers().get("X-Real-IP") {
224        if let Ok(s) = real_ip.to_str() {
225            return s.to_string();
226        }
227    }
228
229    // Fallback to a default (in production, you'd want to extract from connection info)
230    "unknown".to_string()
231}
232
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, middleware, routing::get, Router};
    use tower::ServiceExt;

    /// Trivial handler that always answers `200 OK` with body `ok`.
    async fn test_handler() -> &'static str {
        "ok"
    }

    /// Router exposing `/api/test` and `/health` behind the rate limiter.
    fn test_router(config: RateLimitConfig) -> Router {
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(rate_limit_middleware))
            .layer(axum::Extension(RateLimiter::new(config)))
    }

    /// Build an empty-bodied request for `uri`, optionally tagged with an
    /// `X-Forwarded-For` client address.
    fn request_for(uri: &str, client_ip: Option<&str>) -> Request<Body> {
        let builder = Request::builder().uri(uri);
        let builder = match client_ip {
            Some(ip) => builder.header("X-Forwarded-For", ip),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn test_rate_limit_disabled() {
        // The default configuration has limiting switched off.
        let router = test_router(RateLimitConfig::default());

        let response = router
            .oneshot(request_for("/api/test", None))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_rate_limit_allows_under_limit() {
        let router = test_router(RateLimitConfig::new(5, 60));

        // Three requests against a budget of five must all succeed.
        for _ in 0..3 {
            let response = router
                .clone()
                .oneshot(request_for("/api/test", Some("192.168.1.1")))
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn test_rate_limit_blocks_over_limit() {
        let router = test_router(RateLimitConfig::new(2, 60));

        // Budget of two: the third request is rejected.
        let expected = [
            StatusCode::OK,
            StatusCode::OK,
            StatusCode::TOO_MANY_REQUESTS,
        ];
        for want in expected {
            let response = router
                .clone()
                .oneshot(request_for("/api/test", Some("192.168.1.100")))
                .await
                .unwrap();
            assert_eq!(response.status(), want);
        }
    }

    #[tokio::test]
    async fn test_rate_limit_exempt_path() {
        let router = test_router(RateLimitConfig::new(1, 60));
        let client = Some("192.168.1.200");

        // Use up the single allowed request on /api/test.
        let _ = router
            .clone()
            .oneshot(request_for("/api/test", client))
            .await
            .unwrap();

        // /health is exempt, so it still answers for the same client.
        let response = router
            .oneshot(request_for("/health", client))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_rate_limiter_cleanup() {
        // One-second window so the record expires quickly.
        let limiter = RateLimiter::new(RateLimitConfig::new(10, 1));
        let key = "test-client";

        limiter.check_rate_limit(key).await;
        assert!(limiter.records.read().await.contains_key(key));

        // Let the window elapse, then prune.
        tokio::time::sleep(Duration::from_millis(1100)).await;
        limiter.cleanup_expired().await;

        assert!(!limiter.records.read().await.contains_key(key));
    }
}
363}