// datasynth_server/rest/rate_limit_backend.rs
1//! Unified rate limiting backend.
2//!
3//! Provides a single `RateLimitBackend` enum that abstracts over in-memory
4//! and Redis-backed rate limiting, allowing the middleware to work
5//! transparently with either backend.
6
7use axum::{
8    body::Body,
9    http::{header::HeaderValue, Request, StatusCode},
10    middleware::Next,
11    response::{IntoResponse, Response},
12};
13
14use super::rate_limit::{RateLimitConfig, RateLimiter};
15#[cfg(feature = "redis")]
16use super::redis_rate_limit::RedisRateLimiter;
17
/// Unified rate limiting backend that supports both in-memory and Redis
/// implementations.
///
/// When running a single server instance, `InMemory` is sufficient.
/// For distributed deployments with multiple server instances behind a
/// load balancer, use `Redis` to ensure consistent rate limiting across
/// all nodes.
///
/// NOTE(review): this type is shared via an axum `Extension`, which clones
/// it per request — limiting only works if clones share limiter state;
/// confirm `RateLimiter`'s `Clone` semantics.
#[derive(Clone)]
pub enum RateLimitBackend {
    /// In-memory rate limiter (single instance only).
    InMemory {
        /// Limiter holding the per-client request counters.
        limiter: RateLimiter,
        /// Rate limit settings (max requests, window, exempt paths, enabled flag).
        config: RateLimitConfig,
    },
    /// Redis-backed rate limiter (distributed, multi-instance).
    ///
    /// Boxed to keep the enum's size close to the `InMemory` variant.
    /// Only compiled when the `redis` cargo feature is enabled.
    #[cfg(feature = "redis")]
    Redis {
        limiter: Box<RedisRateLimiter>,
        config: RateLimitConfig,
    },
}
39
40impl RateLimitBackend {
41    /// Create a new in-memory rate limiting backend.
42    pub fn in_memory(config: RateLimitConfig) -> Self {
43        let limiter = RateLimiter::new(config.clone());
44        Self::InMemory { limiter, config }
45    }
46
47    /// Create a new Redis-backed rate limiting backend.
48    ///
49    /// # Arguments
50    /// * `redis_url` - Redis connection URL (e.g., `redis://127.0.0.1:6379`)
51    /// * `config` - Rate limit configuration
52    #[cfg(feature = "redis")]
53    pub async fn redis(
54        redis_url: &str,
55        config: RateLimitConfig,
56    ) -> Result<Self, redis::RedisError> {
57        let limiter = RedisRateLimiter::new(redis_url, config.max_requests, config.window).await?;
58        Ok(Self::Redis {
59            limiter: Box::new(limiter),
60            config,
61        })
62    }
63
64    /// Get the rate limit configuration.
65    pub fn config(&self) -> &RateLimitConfig {
66        match self {
67            Self::InMemory { config, .. } => config,
68            #[cfg(feature = "redis")]
69            Self::Redis { config, .. } => config,
70        }
71    }
72
73    /// Check if a request from the given client should be allowed.
74    ///
75    /// Returns `true` if the request is allowed.
76    pub async fn check_rate_limit(&self, client_key: &str) -> bool {
77        match self {
78            Self::InMemory { limiter, config } => {
79                if !config.enabled {
80                    return true;
81                }
82                limiter.check_rate_limit(client_key).await
83            }
84            #[cfg(feature = "redis")]
85            Self::Redis { limiter, config } => {
86                if !config.enabled {
87                    return true;
88                }
89                limiter.check_rate_limit(client_key).await.allowed
90            }
91        }
92    }
93
94    /// Get remaining requests for a client key.
95    pub async fn remaining(&self, client_key: &str) -> u32 {
96        match self {
97            Self::InMemory { limiter, config } => {
98                if !config.enabled {
99                    return config.max_requests;
100                }
101                limiter.remaining(client_key).await
102            }
103            #[cfg(feature = "redis")]
104            Self::Redis { limiter, config } => {
105                if !config.enabled {
106                    return config.max_requests;
107                }
108                limiter.remaining(client_key).await
109            }
110        }
111    }
112
113    /// Clean up expired records (only applicable to in-memory backend).
114    ///
115    /// For Redis, TTL-based expiry is handled automatically by Redis.
116    pub async fn cleanup_expired(&self) {
117        match self {
118            Self::InMemory { limiter, .. } => {
119                limiter.cleanup_expired().await;
120            }
121            #[cfg(feature = "redis")]
122            Self::Redis { .. } => {
123                // Redis handles expiry automatically via TTL
124            }
125        }
126    }
127
128    /// Return a human-readable description of the backend type.
129    pub fn backend_name(&self) -> &'static str {
130        match self {
131            Self::InMemory { .. } => "in-memory",
132            #[cfg(feature = "redis")]
133            Self::Redis { .. } => "redis",
134        }
135    }
136}
137
138/// Rate limiting middleware that works with any `RateLimitBackend`.
139///
140/// This replaces the original `rate_limit_middleware` when using the
141/// backend abstraction. It is added to the router via
142/// `axum::middleware::from_fn(backend_rate_limit_middleware)` and expects
143/// `RateLimitBackend` to be available as an `Extension`.
144pub async fn backend_rate_limit_middleware(
145    axum::Extension(backend): axum::Extension<RateLimitBackend>,
146    request: Request<Body>,
147    next: Next,
148) -> Response {
149    let config = backend.config();
150
151    // Check if rate limiting is enabled
152    if !config.enabled {
153        return next.run(request).await;
154    }
155
156    // Check if path is exempt
157    let path = request.uri().path();
158    if config.exempt_paths.iter().any(|p| path.starts_with(p)) {
159        return next.run(request).await;
160    }
161
162    // Get client identifier (IP address or fallback)
163    let client_key = extract_client_key(&request);
164    let max_requests = config.max_requests;
165    let window_secs = config.window.as_secs();
166
167    // Check rate limit
168    if backend.check_rate_limit(&client_key).await {
169        let remaining = backend.remaining(&client_key).await;
170        let mut response = next.run(request).await;
171
172        // Add rate limit headers
173        let headers = response.headers_mut();
174        headers.insert("X-RateLimit-Limit", HeaderValue::from(max_requests));
175        headers.insert("X-RateLimit-Remaining", HeaderValue::from(remaining));
176
177        response
178    } else {
179        // Rate limited
180        (
181            StatusCode::TOO_MANY_REQUESTS,
182            [
183                ("X-RateLimit-Limit", max_requests.to_string()),
184                ("X-RateLimit-Remaining", "0".to_string()),
185                ("Retry-After", window_secs.to_string()),
186            ],
187            format!("Rate limit exceeded. Max {max_requests} requests per {window_secs} seconds."),
188        )
189            .into_response()
190    }
191}
192
193/// Extract client identifier from request.
194fn extract_client_key(request: &Request<Body>) -> String {
195    // Try X-Forwarded-For header (for proxied requests)
196    if let Some(forwarded) = request.headers().get("X-Forwarded-For") {
197        if let Ok(s) = forwarded.to_str() {
198            if let Some(ip) = s.split(',').next() {
199                return ip.trim().to_string();
200            }
201        }
202    }
203
204    // Try X-Real-IP header
205    if let Some(real_ip) = request.headers().get("X-Real-IP") {
206        if let Ok(s) = real_ip.to_str() {
207            return s.to_string();
208        }
209    }
210
211    // Fallback to a default
212    "unknown".to_string()
213}
214
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, middleware, routing::get, Router};
    use tower::ServiceExt;

    async fn test_handler() -> &'static str {
        "ok"
    }

    /// Router with `/api/test` and `/health`, guarded by an in-memory backend.
    fn test_router_with_backend(config: RateLimitConfig) -> Router {
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(backend_rate_limit_middleware))
            .layer(axum::Extension(RateLimitBackend::in_memory(config)))
    }

    /// GET request to `uri` carrying the given `X-Forwarded-For` address.
    fn get_request(uri: &str, ip: &str) -> Request<Body> {
        Request::builder()
            .uri(uri)
            .header("X-Forwarded-For", ip)
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn test_backend_rate_limit_disabled() {
        // Default config has rate limiting disabled, so the request passes.
        let router = test_router_with_backend(RateLimitConfig::default());

        let response = router
            .oneshot(
                Request::builder()
                    .uri("/api/test")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_backend_rate_limit_allows_under_limit() {
        let router = test_router_with_backend(RateLimitConfig::new(5, 60));

        // Three requests against a budget of five must all succeed.
        for _ in 0..3 {
            let response = router
                .clone()
                .oneshot(get_request("/api/test", "192.168.1.1"))
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn test_backend_rate_limit_blocks_over_limit() {
        let backend = RateLimitBackend::in_memory(RateLimitConfig::new(2, 60));
        let router = Router::new()
            .route("/api/test", get(test_handler))
            .layer(middleware::from_fn(backend_rate_limit_middleware))
            .layer(axum::Extension(backend));

        // The first two requests pass; the third must be rejected.
        for i in 0..3 {
            let response = router
                .clone()
                .oneshot(get_request("/api/test", "192.168.1.100"))
                .await
                .unwrap();
            let expected = if i < 2 {
                StatusCode::OK
            } else {
                StatusCode::TOO_MANY_REQUESTS
            };
            assert_eq!(response.status(), expected);
        }
    }

    #[tokio::test]
    async fn test_backend_rate_limit_exempt_path() {
        let backend = RateLimitBackend::in_memory(RateLimitConfig::new(1, 60));
        let router = Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(backend_rate_limit_middleware))
            .layer(axum::Extension(backend));

        // Consume the single allowed request on the limited path.
        let _ = router
            .clone()
            .oneshot(get_request("/api/test", "192.168.1.200"))
            .await
            .unwrap();

        // The exempt /health path must still respond OK.
        let response = router
            .oneshot(get_request("/health", "192.168.1.200"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[test]
    fn test_backend_name_in_memory() {
        let backend = RateLimitBackend::in_memory(RateLimitConfig::default());
        assert_eq!(backend.backend_name(), "in-memory");
    }

    #[tokio::test]
    async fn test_backend_cleanup_in_memory() {
        let backend = RateLimitBackend::in_memory(RateLimitConfig::new(10, 1));

        // Cleanup must complete without panicking for the in-memory backend.
        backend.cleanup_expired().await;
    }
}