// datasynth_server/rest/rate_limit_backend.rs

1//! Unified rate limiting backend.
2//!
3//! Provides a single `RateLimitBackend` enum that abstracts over in-memory
4//! and Redis-backed rate limiting, allowing the middleware to work
5//! transparently with either backend.
6
use axum::{
    body::Body,
    http::{HeaderValue, Request, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
};

use super::rate_limit::{RateLimitConfig, RateLimiter};
#[cfg(feature = "redis")]
use super::redis_rate_limit::RedisRateLimiter;
17
/// Unified rate limiting backend that supports both in-memory and Redis
/// implementations.
///
/// When running a single server instance, `InMemory` is sufficient.
/// For distributed deployments with multiple server instances behind a
/// load balancer, use `Redis` to ensure consistent rate limiting across
/// all nodes.
#[derive(Clone)]
pub enum RateLimitBackend {
    /// In-memory rate limiter (single instance only).
    InMemory {
        /// Per-client request tracking held in process memory.
        limiter: RateLimiter,
        /// Configuration (limit, window, exemptions) the limiter was built from.
        config: RateLimitConfig,
    },
    /// Redis-backed rate limiter (distributed, multi-instance).
    #[cfg(feature = "redis")]
    Redis {
        /// Redis-backed limiter; boxed, presumably to keep the enum's size
        /// down when this larger variant is compiled in — TODO confirm.
        limiter: Box<RedisRateLimiter>,
        /// Configuration (limit, window, exemptions) the limiter was built from.
        config: RateLimitConfig,
    },
}
39
40impl RateLimitBackend {
41    /// Create a new in-memory rate limiting backend.
42    pub fn in_memory(config: RateLimitConfig) -> Self {
43        let limiter = RateLimiter::new(config.clone());
44        Self::InMemory { limiter, config }
45    }
46
47    /// Create a new Redis-backed rate limiting backend.
48    ///
49    /// # Arguments
50    /// * `redis_url` - Redis connection URL (e.g., `redis://127.0.0.1:6379`)
51    /// * `config` - Rate limit configuration
52    #[cfg(feature = "redis")]
53    pub async fn redis(
54        redis_url: &str,
55        config: RateLimitConfig,
56    ) -> Result<Self, redis::RedisError> {
57        let limiter = RedisRateLimiter::new(redis_url, config.max_requests, config.window).await?;
58        Ok(Self::Redis {
59            limiter: Box::new(limiter),
60            config,
61        })
62    }
63
64    /// Get the rate limit configuration.
65    pub fn config(&self) -> &RateLimitConfig {
66        match self {
67            Self::InMemory { config, .. } => config,
68            #[cfg(feature = "redis")]
69            Self::Redis { config, .. } => config,
70        }
71    }
72
73    /// Check if a request from the given client should be allowed.
74    ///
75    /// Returns `true` if the request is allowed.
76    pub async fn check_rate_limit(&self, client_key: &str) -> bool {
77        match self {
78            Self::InMemory { limiter, config } => {
79                if !config.enabled {
80                    return true;
81                }
82                limiter.check_rate_limit(client_key).await
83            }
84            #[cfg(feature = "redis")]
85            Self::Redis { limiter, config } => {
86                if !config.enabled {
87                    return true;
88                }
89                limiter.check_rate_limit(client_key).await.allowed
90            }
91        }
92    }
93
94    /// Get remaining requests for a client key.
95    pub async fn remaining(&self, client_key: &str) -> u32 {
96        match self {
97            Self::InMemory { limiter, config } => {
98                if !config.enabled {
99                    return config.max_requests;
100                }
101                limiter.remaining(client_key).await
102            }
103            #[cfg(feature = "redis")]
104            Self::Redis { limiter, config } => {
105                if !config.enabled {
106                    return config.max_requests;
107                }
108                limiter.remaining(client_key).await
109            }
110        }
111    }
112
113    /// Clean up expired records (only applicable to in-memory backend).
114    ///
115    /// For Redis, TTL-based expiry is handled automatically by Redis.
116    pub async fn cleanup_expired(&self) {
117        match self {
118            Self::InMemory { limiter, .. } => {
119                limiter.cleanup_expired().await;
120            }
121            #[cfg(feature = "redis")]
122            Self::Redis { .. } => {
123                // Redis handles expiry automatically via TTL
124            }
125        }
126    }
127
128    /// Return a human-readable description of the backend type.
129    pub fn backend_name(&self) -> &'static str {
130        match self {
131            Self::InMemory { .. } => "in-memory",
132            #[cfg(feature = "redis")]
133            Self::Redis { .. } => "redis",
134        }
135    }
136}
137
138/// Rate limiting middleware that works with any `RateLimitBackend`.
139///
140/// This replaces the original `rate_limit_middleware` when using the
141/// backend abstraction. It is added to the router via
142/// `axum::middleware::from_fn(backend_rate_limit_middleware)` and expects
143/// `RateLimitBackend` to be available as an `Extension`.
144pub async fn backend_rate_limit_middleware(
145    axum::Extension(backend): axum::Extension<RateLimitBackend>,
146    request: Request<Body>,
147    next: Next,
148) -> Response {
149    let config = backend.config();
150
151    // Check if rate limiting is enabled
152    if !config.enabled {
153        return next.run(request).await;
154    }
155
156    // Check if path is exempt
157    let path = request.uri().path();
158    if config.exempt_paths.iter().any(|p| path.starts_with(p)) {
159        return next.run(request).await;
160    }
161
162    // Get client identifier (IP address or fallback)
163    let client_key = extract_client_key(&request);
164    let max_requests = config.max_requests;
165    let window_secs = config.window.as_secs();
166
167    // Check rate limit
168    if backend.check_rate_limit(&client_key).await {
169        let remaining = backend.remaining(&client_key).await;
170        let mut response = next.run(request).await;
171
172        // Add rate limit headers
173        let headers = response.headers_mut();
174        headers.insert(
175            "X-RateLimit-Limit",
176            max_requests.to_string().parse().unwrap(),
177        );
178        headers.insert(
179            "X-RateLimit-Remaining",
180            remaining.to_string().parse().unwrap(),
181        );
182
183        response
184    } else {
185        // Rate limited
186        (
187            StatusCode::TOO_MANY_REQUESTS,
188            [
189                ("X-RateLimit-Limit", max_requests.to_string()),
190                ("X-RateLimit-Remaining", "0".to_string()),
191                ("Retry-After", window_secs.to_string()),
192            ],
193            format!(
194                "Rate limit exceeded. Max {} requests per {} seconds.",
195                max_requests, window_secs
196            ),
197        )
198            .into_response()
199    }
200}
201
202/// Extract client identifier from request.
203fn extract_client_key(request: &Request<Body>) -> String {
204    // Try X-Forwarded-For header (for proxied requests)
205    if let Some(forwarded) = request.headers().get("X-Forwarded-For") {
206        if let Ok(s) = forwarded.to_str() {
207            if let Some(ip) = s.split(',').next() {
208                return ip.trim().to_string();
209            }
210        }
211    }
212
213    // Try X-Real-IP header
214    if let Some(real_ip) = request.headers().get("X-Real-IP") {
215        if let Ok(s) = real_ip.to_str() {
216            return s.to_string();
217        }
218    }
219
220    // Fallback to a default
221    "unknown".to_string()
222}
223
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, middleware, routing::get, Router};
    use tower::ServiceExt;

    async fn test_handler() -> &'static str {
        "ok"
    }

    /// Build a GET request for `uri`, optionally tagged with a client IP
    /// via `X-Forwarded-For` so the limiter can distinguish clients.
    fn make_request(uri: &str, client_ip: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri(uri);
        if let Some(ip) = client_ip {
            builder = builder.header("X-Forwarded-For", ip);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Router exposing `/api/test` and `/health` behind an in-memory backend.
    fn test_router_with_backend(config: RateLimitConfig) -> Router {
        let backend = RateLimitBackend::in_memory(config);
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(backend_rate_limit_middleware))
            .layer(axum::Extension(backend))
    }

    #[tokio::test]
    async fn test_backend_rate_limit_disabled() {
        // Default config has limiting disabled, so the request passes through.
        let router = test_router_with_backend(RateLimitConfig::default());

        let response = router
            .oneshot(make_request("/api/test", None))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_backend_rate_limit_allows_under_limit() {
        let router = test_router_with_backend(RateLimitConfig::new(5, 60));

        // Three requests against a budget of five should all succeed.
        for _ in 0..3 {
            let response = router
                .clone()
                .oneshot(make_request("/api/test", Some("192.168.1.1")))
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn test_backend_rate_limit_blocks_over_limit() {
        let backend = RateLimitBackend::in_memory(RateLimitConfig::new(2, 60));
        let router = Router::new()
            .route("/api/test", get(test_handler))
            .layer(middleware::from_fn(backend_rate_limit_middleware))
            .layer(axum::Extension(backend));

        // First two requests fit the budget; the third must be rejected.
        for i in 0..3 {
            let response = router
                .clone()
                .oneshot(make_request("/api/test", Some("192.168.1.100")))
                .await
                .unwrap();
            let expected = if i < 2 {
                StatusCode::OK
            } else {
                StatusCode::TOO_MANY_REQUESTS
            };
            assert_eq!(response.status(), expected);
        }
    }

    #[tokio::test]
    async fn test_backend_rate_limit_exempt_path() {
        let backend = RateLimitBackend::in_memory(RateLimitConfig::new(1, 60));
        let router = Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(backend_rate_limit_middleware))
            .layer(axum::Extension(backend));

        // Exhaust the single-request budget on a limited path.
        let _ = router
            .clone()
            .oneshot(make_request("/api/test", Some("192.168.1.200")))
            .await
            .unwrap();

        // The exempt /health path must still answer 200 for the same client.
        let response = router
            .oneshot(make_request("/health", Some("192.168.1.200")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[test]
    fn test_backend_name_in_memory() {
        let backend = RateLimitBackend::in_memory(RateLimitConfig::default());
        assert_eq!(backend.backend_name(), "in-memory");
    }

    #[tokio::test]
    async fn test_backend_cleanup_in_memory() {
        let backend = RateLimitBackend::in_memory(RateLimitConfig::new(10, 1));

        // Must complete without panicking for the in-memory backend.
        backend.cleanup_expired().await;
    }
}