datasynth_server/rest/
rate_limit_backend.rs1use axum::{
8 body::Body,
9 http::{header::HeaderValue, Request, StatusCode},
10 middleware::Next,
11 response::{IntoResponse, Response},
12};
13
14use super::rate_limit::{RateLimitConfig, RateLimiter};
15#[cfg(feature = "redis")]
16use super::redis_rate_limit::RedisRateLimiter;
17
18#[derive(Clone)]
26pub enum RateLimitBackend {
27 InMemory {
29 limiter: RateLimiter,
30 config: RateLimitConfig,
31 },
32 #[cfg(feature = "redis")]
34 Redis {
35 limiter: Box<RedisRateLimiter>,
36 config: RateLimitConfig,
37 },
38}
39
40impl RateLimitBackend {
41 pub fn in_memory(config: RateLimitConfig) -> Self {
43 let limiter = RateLimiter::new(config.clone());
44 Self::InMemory { limiter, config }
45 }
46
47 #[cfg(feature = "redis")]
53 pub async fn redis(
54 redis_url: &str,
55 config: RateLimitConfig,
56 ) -> Result<Self, redis::RedisError> {
57 let limiter = RedisRateLimiter::new(redis_url, config.max_requests, config.window).await?;
58 Ok(Self::Redis {
59 limiter: Box::new(limiter),
60 config,
61 })
62 }
63
64 pub fn config(&self) -> &RateLimitConfig {
66 match self {
67 Self::InMemory { config, .. } => config,
68 #[cfg(feature = "redis")]
69 Self::Redis { config, .. } => config,
70 }
71 }
72
73 pub async fn check_rate_limit(&self, client_key: &str) -> bool {
77 match self {
78 Self::InMemory { limiter, config } => {
79 if !config.enabled {
80 return true;
81 }
82 limiter.check_rate_limit(client_key).await
83 }
84 #[cfg(feature = "redis")]
85 Self::Redis { limiter, config } => {
86 if !config.enabled {
87 return true;
88 }
89 limiter.check_rate_limit(client_key).await.allowed
90 }
91 }
92 }
93
94 pub async fn remaining(&self, client_key: &str) -> u32 {
96 match self {
97 Self::InMemory { limiter, config } => {
98 if !config.enabled {
99 return config.max_requests;
100 }
101 limiter.remaining(client_key).await
102 }
103 #[cfg(feature = "redis")]
104 Self::Redis { limiter, config } => {
105 if !config.enabled {
106 return config.max_requests;
107 }
108 limiter.remaining(client_key).await
109 }
110 }
111 }
112
113 pub async fn cleanup_expired(&self) {
117 match self {
118 Self::InMemory { limiter, .. } => {
119 limiter.cleanup_expired().await;
120 }
121 #[cfg(feature = "redis")]
122 Self::Redis { .. } => {
123 }
125 }
126 }
127
128 pub fn backend_name(&self) -> &'static str {
130 match self {
131 Self::InMemory { .. } => "in-memory",
132 #[cfg(feature = "redis")]
133 Self::Redis { .. } => "redis",
134 }
135 }
136}
137
138pub async fn backend_rate_limit_middleware(
145 axum::Extension(backend): axum::Extension<RateLimitBackend>,
146 request: Request<Body>,
147 next: Next,
148) -> Response {
149 let config = backend.config();
150
151 if !config.enabled {
153 return next.run(request).await;
154 }
155
156 let path = request.uri().path();
158 if config.exempt_paths.iter().any(|p| path.starts_with(p)) {
159 return next.run(request).await;
160 }
161
162 let client_key = extract_client_key(&request);
164 let max_requests = config.max_requests;
165 let window_secs = config.window.as_secs();
166
167 if backend.check_rate_limit(&client_key).await {
169 let remaining = backend.remaining(&client_key).await;
170 let mut response = next.run(request).await;
171
172 let headers = response.headers_mut();
174 headers.insert("X-RateLimit-Limit", HeaderValue::from(max_requests));
175 headers.insert("X-RateLimit-Remaining", HeaderValue::from(remaining));
176
177 response
178 } else {
179 (
181 StatusCode::TOO_MANY_REQUESTS,
182 [
183 ("X-RateLimit-Limit", max_requests.to_string()),
184 ("X-RateLimit-Remaining", "0".to_string()),
185 ("Retry-After", window_secs.to_string()),
186 ],
187 format!("Rate limit exceeded. Max {max_requests} requests per {window_secs} seconds."),
188 )
189 .into_response()
190 }
191}
192
193fn extract_client_key(request: &Request<Body>) -> String {
195 if let Some(forwarded) = request.headers().get("X-Forwarded-For") {
197 if let Ok(s) = forwarded.to_str() {
198 if let Some(ip) = s.split(',').next() {
199 return ip.trim().to_string();
200 }
201 }
202 }
203
204 if let Some(real_ip) = request.headers().get("X-Real-IP") {
206 if let Ok(s) = real_ip.to_str() {
207 return s.to_string();
208 }
209 }
210
211 "unknown".to_string()
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, middleware, routing::get, Router};
    use tower::ServiceExt;

    /// Minimal handler; the tests only inspect the status code.
    async fn test_handler() -> &'static str {
        "ok"
    }

    /// Builds a router exposing `/api/test` and `/health`, wrapped in the
    /// rate limit middleware with an in-memory backend created from `config`.
    fn test_router_with_backend(config: RateLimitConfig) -> Router {
        let backend = RateLimitBackend::in_memory(config);
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(backend_rate_limit_middleware))
            .layer(axum::Extension(backend))
    }

    /// Builds a body-less request for `uri`, optionally carrying an
    /// `X-Forwarded-For` header so the middleware sees a specific client.
    fn request_for(uri: &str, forwarded_for: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri(uri);
        if let Some(ip) = forwarded_for {
            builder = builder.header("X-Forwarded-For", ip);
        }
        builder
            .body(Body::empty())
            .expect("static test request is well-formed")
    }

    #[tokio::test]
    async fn test_backend_rate_limit_disabled() {
        // NOTE(review): relies on `RateLimitConfig::default()` being a
        // disabled configuration, as the test name implies — confirm.
        let router = test_router_with_backend(RateLimitConfig::default());

        let response = router
            .oneshot(request_for("/api/test", None))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_backend_rate_limit_allows_under_limit() {
        let router = test_router_with_backend(RateLimitConfig::new(5, 60));

        for _ in 0..3 {
            let response = router
                .clone()
                .oneshot(request_for("/api/test", Some("192.168.1.1")))
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn test_backend_rate_limit_blocks_over_limit() {
        let router = test_router_with_backend(RateLimitConfig::new(2, 60));

        for attempt in 0..3 {
            let response = router
                .clone()
                .oneshot(request_for("/api/test", Some("192.168.1.100")))
                .await
                .unwrap();

            // The first two requests fit the limit of 2; the third must not.
            let expected = if attempt < 2 {
                StatusCode::OK
            } else {
                StatusCode::TOO_MANY_REQUESTS
            };
            assert_eq!(response.status(), expected);
        }
    }

    #[tokio::test]
    async fn test_backend_rate_limit_exempt_path() {
        // NOTE(review): assumes `RateLimitConfig::new` exempts `/health` by
        // default — confirm against `RateLimitConfig`.
        let router = test_router_with_backend(RateLimitConfig::new(1, 60));

        // Use up the single allowed request for this client.
        let first = router
            .clone()
            .oneshot(request_for("/api/test", Some("192.168.1.200")))
            .await
            .unwrap();
        assert_eq!(first.status(), StatusCode::OK);

        // The same client must still reach the exempt path.
        let response = router
            .oneshot(request_for("/health", Some("192.168.1.200")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[test]
    fn test_backend_name_in_memory() {
        let backend = RateLimitBackend::in_memory(RateLimitConfig::default());
        assert_eq!(backend.backend_name(), "in-memory");
    }

    #[tokio::test]
    async fn test_backend_cleanup_in_memory() {
        let backend = RateLimitBackend::in_memory(RateLimitConfig::new(10, 1));

        // Smoke test: cleanup on a fresh backend must complete without panicking.
        backend.cleanup_expired().await;
    }
}