// datasynth_server/rest/rate_limit_backend.rs
//! Unified rate limiting backend.
//!
//! Provides a single `RateLimitBackend` enum that abstracts over in-memory
//! and Redis-backed rate limiting, allowing the middleware to work
//! transparently with either backend.

use axum::{
    body::Body,
    http::{header::HeaderValue, Request, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
};

use super::rate_limit::{RateLimitConfig, RateLimiter};
#[cfg(feature = "redis")]
use super::redis_rate_limit::RedisRateLimiter;
17
18/// Unified rate limiting backend that supports both in-memory and Redis
19/// implementations.
20///
21/// When running a single server instance, `InMemory` is sufficient.
22/// For distributed deployments with multiple server instances behind a
23/// load balancer, use `Redis` to ensure consistent rate limiting across
24/// all nodes.
25#[derive(Clone)]
26pub enum RateLimitBackend {
27    /// In-memory rate limiter (single instance only).
28    InMemory {
29        limiter: RateLimiter,
30        config: RateLimitConfig,
31    },
32    /// Redis-backed rate limiter (distributed, multi-instance).
33    #[cfg(feature = "redis")]
34    Redis {
35        limiter: Box<RedisRateLimiter>,
36        config: RateLimitConfig,
37    },
38}
39
40impl RateLimitBackend {
41    /// Create a new in-memory rate limiting backend.
42    pub fn in_memory(config: RateLimitConfig) -> Self {
43        let limiter = RateLimiter::new(config.clone());
44        Self::InMemory { limiter, config }
45    }
46
47    /// Create a new Redis-backed rate limiting backend.
48    ///
49    /// # Arguments
50    /// * `redis_url` - Redis connection URL (e.g., `redis://127.0.0.1:6379`)
51    /// * `config` - Rate limit configuration
52    #[cfg(feature = "redis")]
53    pub async fn redis(
54        redis_url: &str,
55        config: RateLimitConfig,
56    ) -> Result<Self, redis::RedisError> {
57        let limiter = RedisRateLimiter::new(redis_url, config.max_requests, config.window).await?;
58        Ok(Self::Redis {
59            limiter: Box::new(limiter),
60            config,
61        })
62    }
63
64    /// Get the rate limit configuration.
65    pub fn config(&self) -> &RateLimitConfig {
66        match self {
67            Self::InMemory { config, .. } => config,
68            #[cfg(feature = "redis")]
69            Self::Redis { config, .. } => config,
70        }
71    }
72
73    /// Check if a request from the given client should be allowed.
74    ///
75    /// Returns `true` if the request is allowed.
76    pub async fn check_rate_limit(&self, client_key: &str) -> bool {
77        match self {
78            Self::InMemory { limiter, config } => {
79                if !config.enabled {
80                    return true;
81                }
82                limiter.check_rate_limit(client_key).await
83            }
84            #[cfg(feature = "redis")]
85            Self::Redis { limiter, config } => {
86                if !config.enabled {
87                    return true;
88                }
89                limiter.check_rate_limit(client_key).await.allowed
90            }
91        }
92    }
93
94    /// Get remaining requests for a client key.
95    pub async fn remaining(&self, client_key: &str) -> u32 {
96        match self {
97            Self::InMemory { limiter, config } => {
98                if !config.enabled {
99                    return config.max_requests;
100                }
101                limiter.remaining(client_key).await
102            }
103            #[cfg(feature = "redis")]
104            Self::Redis { limiter, config } => {
105                if !config.enabled {
106                    return config.max_requests;
107                }
108                limiter.remaining(client_key).await
109            }
110        }
111    }
112
113    /// Clean up expired records (only applicable to in-memory backend).
114    ///
115    /// For Redis, TTL-based expiry is handled automatically by Redis.
116    pub async fn cleanup_expired(&self) {
117        match self {
118            Self::InMemory { limiter, .. } => {
119                limiter.cleanup_expired().await;
120            }
121            #[cfg(feature = "redis")]
122            Self::Redis { .. } => {
123                // Redis handles expiry automatically via TTL
124            }
125        }
126    }
127
128    /// Return a human-readable description of the backend type.
129    pub fn backend_name(&self) -> &'static str {
130        match self {
131            Self::InMemory { .. } => "in-memory",
132            #[cfg(feature = "redis")]
133            Self::Redis { .. } => "redis",
134        }
135    }
136}
137
138/// Rate limiting middleware that works with any `RateLimitBackend`.
139///
140/// This replaces the original `rate_limit_middleware` when using the
141/// backend abstraction. It is added to the router via
142/// `axum::middleware::from_fn(backend_rate_limit_middleware)` and expects
143/// `RateLimitBackend` to be available as an `Extension`.
144pub async fn backend_rate_limit_middleware(
145    axum::Extension(backend): axum::Extension<RateLimitBackend>,
146    request: Request<Body>,
147    next: Next,
148) -> Response {
149    let config = backend.config();
150
151    // Check if rate limiting is enabled
152    if !config.enabled {
153        return next.run(request).await;
154    }
155
156    // Check if path is exempt
157    let path = request.uri().path();
158    if config.exempt_paths.iter().any(|p| path.starts_with(p)) {
159        return next.run(request).await;
160    }
161
162    // Get client identifier (IP address or fallback)
163    let client_key = extract_client_key(&request);
164    let max_requests = config.max_requests;
165    let window_secs = config.window.as_secs();
166
167    // Check rate limit
168    if backend.check_rate_limit(&client_key).await {
169        let remaining = backend.remaining(&client_key).await;
170        let mut response = next.run(request).await;
171
172        // Add rate limit headers
173        let headers = response.headers_mut();
174        headers.insert("X-RateLimit-Limit", HeaderValue::from(max_requests));
175        headers.insert("X-RateLimit-Remaining", HeaderValue::from(remaining));
176
177        response
178    } else {
179        // Rate limited
180        (
181            StatusCode::TOO_MANY_REQUESTS,
182            [
183                ("X-RateLimit-Limit", max_requests.to_string()),
184                ("X-RateLimit-Remaining", "0".to_string()),
185                ("Retry-After", window_secs.to_string()),
186            ],
187            format!(
188                "Rate limit exceeded. Max {} requests per {} seconds.",
189                max_requests, window_secs
190            ),
191        )
192            .into_response()
193    }
194}
195
196/// Extract client identifier from request.
197fn extract_client_key(request: &Request<Body>) -> String {
198    // Try X-Forwarded-For header (for proxied requests)
199    if let Some(forwarded) = request.headers().get("X-Forwarded-For") {
200        if let Ok(s) = forwarded.to_str() {
201            if let Some(ip) = s.split(',').next() {
202                return ip.trim().to_string();
203            }
204        }
205    }
206
207    // Try X-Real-IP header
208    if let Some(real_ip) = request.headers().get("X-Real-IP") {
209        if let Ok(s) = real_ip.to_str() {
210            return s.to_string();
211        }
212    }
213
214    // Fallback to a default
215    "unknown".to_string()
216}
217
218#[cfg(test)]
219#[allow(clippy::unwrap_used)]
220mod tests {
221    use super::*;
222    use axum::{body::Body, http::Request, middleware, routing::get, Router};
223    use tower::ServiceExt;
224
225    async fn test_handler() -> &'static str {
226        "ok"
227    }
228
229    fn test_router_with_backend(config: RateLimitConfig) -> Router {
230        let backend = RateLimitBackend::in_memory(config);
231        Router::new()
232            .route("/api/test", get(test_handler))
233            .route("/health", get(test_handler))
234            .layer(middleware::from_fn(backend_rate_limit_middleware))
235            .layer(axum::Extension(backend))
236    }
237
238    #[tokio::test]
239    async fn test_backend_rate_limit_disabled() {
240        let config = RateLimitConfig::default(); // disabled by default
241        let router = test_router_with_backend(config);
242
243        let request = Request::builder()
244            .uri("/api/test")
245            .body(Body::empty())
246            .unwrap();
247
248        let response = router.oneshot(request).await.unwrap();
249        assert_eq!(response.status(), StatusCode::OK);
250    }
251
252    #[tokio::test]
253    async fn test_backend_rate_limit_allows_under_limit() {
254        let config = RateLimitConfig::new(5, 60);
255        let router = test_router_with_backend(config);
256
257        for _ in 0..3 {
258            let router = router.clone();
259            let request = Request::builder()
260                .uri("/api/test")
261                .header("X-Forwarded-For", "192.168.1.1")
262                .body(Body::empty())
263                .unwrap();
264
265            let response = router.oneshot(request).await.unwrap();
266            assert_eq!(response.status(), StatusCode::OK);
267        }
268    }
269
270    #[tokio::test]
271    async fn test_backend_rate_limit_blocks_over_limit() {
272        let config = RateLimitConfig::new(2, 60);
273        let backend = RateLimitBackend::in_memory(config.clone());
274
275        let router = Router::new()
276            .route("/api/test", get(test_handler))
277            .layer(middleware::from_fn(backend_rate_limit_middleware))
278            .layer(axum::Extension(backend));
279
280        for i in 0..3 {
281            let router = router.clone();
282            let request = Request::builder()
283                .uri("/api/test")
284                .header("X-Forwarded-For", "192.168.1.100")
285                .body(Body::empty())
286                .unwrap();
287
288            let response = router.oneshot(request).await.unwrap();
289            if i < 2 {
290                assert_eq!(response.status(), StatusCode::OK);
291            } else {
292                assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
293            }
294        }
295    }
296
297    #[tokio::test]
298    async fn test_backend_rate_limit_exempt_path() {
299        let config = RateLimitConfig::new(1, 60);
300        let backend = RateLimitBackend::in_memory(config);
301
302        let router = Router::new()
303            .route("/api/test", get(test_handler))
304            .route("/health", get(test_handler))
305            .layer(middleware::from_fn(backend_rate_limit_middleware))
306            .layer(axum::Extension(backend));
307
308        // Exhaust rate limit on /api/test
309        let request = Request::builder()
310            .uri("/api/test")
311            .header("X-Forwarded-For", "192.168.1.200")
312            .body(Body::empty())
313            .unwrap();
314        let _ = router.clone().oneshot(request).await.unwrap();
315
316        // /health should still work (exempt)
317        let request = Request::builder()
318            .uri("/health")
319            .header("X-Forwarded-For", "192.168.1.200")
320            .body(Body::empty())
321            .unwrap();
322        let response = router.oneshot(request).await.unwrap();
323        assert_eq!(response.status(), StatusCode::OK);
324    }
325
326    #[test]
327    fn test_backend_name_in_memory() {
328        let config = RateLimitConfig::default();
329        let backend = RateLimitBackend::in_memory(config);
330        assert_eq!(backend.backend_name(), "in-memory");
331    }
332
333    #[tokio::test]
334    async fn test_backend_cleanup_in_memory() {
335        let config = RateLimitConfig::new(10, 1);
336        let backend = RateLimitBackend::in_memory(config);
337
338        // Should not panic for in-memory
339        backend.cleanup_expired().await;
340    }
341}