datasynth_server/rest/
rate_limit_backend.rs1use axum::{
8 body::Body,
9 http::{header::HeaderValue, Request, StatusCode},
10 middleware::Next,
11 response::{IntoResponse, Response},
12};
13
14use super::rate_limit::{RateLimitConfig, RateLimiter};
15#[cfg(feature = "redis")]
16use super::redis_rate_limit::RedisRateLimiter;
17
18#[derive(Clone)]
26pub enum RateLimitBackend {
27 InMemory {
29 limiter: RateLimiter,
30 config: RateLimitConfig,
31 },
32 #[cfg(feature = "redis")]
34 Redis {
35 limiter: Box<RedisRateLimiter>,
36 config: RateLimitConfig,
37 },
38}
39
40impl RateLimitBackend {
41 pub fn in_memory(config: RateLimitConfig) -> Self {
43 let limiter = RateLimiter::new(config.clone());
44 Self::InMemory { limiter, config }
45 }
46
47 #[cfg(feature = "redis")]
53 pub async fn redis(
54 redis_url: &str,
55 config: RateLimitConfig,
56 ) -> Result<Self, redis::RedisError> {
57 let limiter = RedisRateLimiter::new(redis_url, config.max_requests, config.window).await?;
58 Ok(Self::Redis {
59 limiter: Box::new(limiter),
60 config,
61 })
62 }
63
64 pub fn config(&self) -> &RateLimitConfig {
66 match self {
67 Self::InMemory { config, .. } => config,
68 #[cfg(feature = "redis")]
69 Self::Redis { config, .. } => config,
70 }
71 }
72
73 pub async fn check_rate_limit(&self, client_key: &str) -> bool {
77 match self {
78 Self::InMemory { limiter, config } => {
79 if !config.enabled {
80 return true;
81 }
82 limiter.check_rate_limit(client_key).await
83 }
84 #[cfg(feature = "redis")]
85 Self::Redis { limiter, config } => {
86 if !config.enabled {
87 return true;
88 }
89 limiter.check_rate_limit(client_key).await.allowed
90 }
91 }
92 }
93
94 pub async fn remaining(&self, client_key: &str) -> u32 {
96 match self {
97 Self::InMemory { limiter, config } => {
98 if !config.enabled {
99 return config.max_requests;
100 }
101 limiter.remaining(client_key).await
102 }
103 #[cfg(feature = "redis")]
104 Self::Redis { limiter, config } => {
105 if !config.enabled {
106 return config.max_requests;
107 }
108 limiter.remaining(client_key).await
109 }
110 }
111 }
112
113 pub async fn cleanup_expired(&self) {
117 match self {
118 Self::InMemory { limiter, .. } => {
119 limiter.cleanup_expired().await;
120 }
121 #[cfg(feature = "redis")]
122 Self::Redis { .. } => {
123 }
125 }
126 }
127
128 pub fn backend_name(&self) -> &'static str {
130 match self {
131 Self::InMemory { .. } => "in-memory",
132 #[cfg(feature = "redis")]
133 Self::Redis { .. } => "redis",
134 }
135 }
136}
137
138pub async fn backend_rate_limit_middleware(
145 axum::Extension(backend): axum::Extension<RateLimitBackend>,
146 request: Request<Body>,
147 next: Next,
148) -> Response {
149 let config = backend.config();
150
151 if !config.enabled {
153 return next.run(request).await;
154 }
155
156 let path = request.uri().path();
158 if config.exempt_paths.iter().any(|p| path.starts_with(p)) {
159 return next.run(request).await;
160 }
161
162 let client_key = extract_client_key(&request);
164 let max_requests = config.max_requests;
165 let window_secs = config.window.as_secs();
166
167 if backend.check_rate_limit(&client_key).await {
169 let remaining = backend.remaining(&client_key).await;
170 let mut response = next.run(request).await;
171
172 let headers = response.headers_mut();
174 headers.insert("X-RateLimit-Limit", HeaderValue::from(max_requests));
175 headers.insert("X-RateLimit-Remaining", HeaderValue::from(remaining));
176
177 response
178 } else {
179 (
181 StatusCode::TOO_MANY_REQUESTS,
182 [
183 ("X-RateLimit-Limit", max_requests.to_string()),
184 ("X-RateLimit-Remaining", "0".to_string()),
185 ("Retry-After", window_secs.to_string()),
186 ],
187 format!("Rate limit exceeded. Max {max_requests} requests per {window_secs} seconds."),
188 )
189 .into_response()
190 }
191}
192
193fn extract_client_key(request: &Request<Body>) -> String {
195 if let Some(forwarded) = request.headers().get("X-Forwarded-For") {
197 if let Ok(s) = forwarded.to_str() {
198 if let Some(ip) = s.split(',').next() {
199 return ip.trim().to_string();
200 }
201 }
202 }
203
204 if let Some(real_ip) = request.headers().get("X-Real-IP") {
206 if let Ok(s) = real_ip.to_str() {
207 return s.to_string();
208 }
209 }
210
211 "unknown".to_string()
213}
214
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, middleware, routing::get, Router};
    use tower::ServiceExt;

    /// Trivial handler used by every test route.
    async fn test_handler() -> &'static str {
        "ok"
    }

    /// Builds a router with `/api/test` and `/health`, wrapped in the
    /// rate-limit middleware backed by an in-memory limiter for `config`.
    fn test_router_with_backend(config: RateLimitConfig) -> Router {
        let backend = RateLimitBackend::in_memory(config);
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(backend_rate_limit_middleware))
            .layer(axum::Extension(backend))
    }

    /// Builds an empty-body request for `uri`, optionally tagged with an
    /// `X-Forwarded-For` header identifying the client.
    fn request_for(uri: &str, client_ip: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri(uri);
        if let Some(ip) = client_ip {
            builder = builder.header("X-Forwarded-For", ip);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Sends `request` through `router` and returns the response status.
    async fn status_of(router: Router, request: Request<Body>) -> StatusCode {
        router.oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_backend_rate_limit_disabled() {
        // NOTE(review): relies on `RateLimitConfig::default()` having rate
        // limiting disabled — confirm against `rate_limit.rs`.
        let router = test_router_with_backend(RateLimitConfig::default());

        let status = status_of(router, request_for("/api/test", None)).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_backend_rate_limit_allows_under_limit() {
        let router = test_router_with_backend(RateLimitConfig::new(5, 60));

        // Three requests against a limit of five must all pass.
        for _ in 0..3 {
            let request = request_for("/api/test", Some("192.168.1.1"));
            let status = status_of(router.clone(), request).await;
            assert_eq!(status, StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn test_backend_rate_limit_blocks_over_limit() {
        let router = test_router_with_backend(RateLimitConfig::new(2, 60));

        // With a limit of two, the third request from the same client is rejected.
        let expected = [
            StatusCode::OK,
            StatusCode::OK,
            StatusCode::TOO_MANY_REQUESTS,
        ];
        for want in expected {
            let request = request_for("/api/test", Some("192.168.1.100"));
            let status = status_of(router.clone(), request).await;
            assert_eq!(status, want);
        }
    }

    #[tokio::test]
    async fn test_backend_rate_limit_exempt_path() {
        // NOTE(review): assumes `/health` is among the exempt paths that
        // `RateLimitConfig::new` sets up — confirm against `rate_limit.rs`.
        let router = test_router_with_backend(RateLimitConfig::new(1, 60));

        // Use up the single allowed request for this client.
        let first = request_for("/api/test", Some("192.168.1.200"));
        let _ = router.clone().oneshot(first).await.unwrap();

        // The exempt path must still be served for the same client.
        let health = request_for("/health", Some("192.168.1.200"));
        assert_eq!(status_of(router, health).await, StatusCode::OK);
    }

    #[test]
    fn test_backend_name_in_memory() {
        let backend = RateLimitBackend::in_memory(RateLimitConfig::default());
        assert_eq!(backend.backend_name(), "in-memory");
    }

    #[tokio::test]
    async fn test_backend_cleanup_in_memory() {
        let backend = RateLimitBackend::in_memory(RateLimitConfig::new(10, 1));

        // Passes as long as cleanup completes without panicking.
        backend.cleanup_expired().await;
    }
}