datasynth_server/rest/
rate_limit_backend.rs1use axum::{
8 body::Body,
9 http::{Request, StatusCode},
10 middleware::Next,
11 response::{IntoResponse, Response},
12};
13
14use super::rate_limit::{RateLimitConfig, RateLimiter};
15#[cfg(feature = "redis")]
16use super::redis_rate_limit::RedisRateLimiter;
17
18#[derive(Clone)]
26pub enum RateLimitBackend {
27 InMemory {
29 limiter: RateLimiter,
30 config: RateLimitConfig,
31 },
32 #[cfg(feature = "redis")]
34 Redis {
35 limiter: Box<RedisRateLimiter>,
36 config: RateLimitConfig,
37 },
38}
39
40impl RateLimitBackend {
41 pub fn in_memory(config: RateLimitConfig) -> Self {
43 let limiter = RateLimiter::new(config.clone());
44 Self::InMemory { limiter, config }
45 }
46
47 #[cfg(feature = "redis")]
53 pub async fn redis(
54 redis_url: &str,
55 config: RateLimitConfig,
56 ) -> Result<Self, redis::RedisError> {
57 let limiter = RedisRateLimiter::new(redis_url, config.max_requests, config.window).await?;
58 Ok(Self::Redis {
59 limiter: Box::new(limiter),
60 config,
61 })
62 }
63
64 pub fn config(&self) -> &RateLimitConfig {
66 match self {
67 Self::InMemory { config, .. } => config,
68 #[cfg(feature = "redis")]
69 Self::Redis { config, .. } => config,
70 }
71 }
72
73 pub async fn check_rate_limit(&self, client_key: &str) -> bool {
77 match self {
78 Self::InMemory { limiter, config } => {
79 if !config.enabled {
80 return true;
81 }
82 limiter.check_rate_limit(client_key).await
83 }
84 #[cfg(feature = "redis")]
85 Self::Redis { limiter, config } => {
86 if !config.enabled {
87 return true;
88 }
89 limiter.check_rate_limit(client_key).await.allowed
90 }
91 }
92 }
93
94 pub async fn remaining(&self, client_key: &str) -> u32 {
96 match self {
97 Self::InMemory { limiter, config } => {
98 if !config.enabled {
99 return config.max_requests;
100 }
101 limiter.remaining(client_key).await
102 }
103 #[cfg(feature = "redis")]
104 Self::Redis { limiter, config } => {
105 if !config.enabled {
106 return config.max_requests;
107 }
108 limiter.remaining(client_key).await
109 }
110 }
111 }
112
113 pub async fn cleanup_expired(&self) {
117 match self {
118 Self::InMemory { limiter, .. } => {
119 limiter.cleanup_expired().await;
120 }
121 #[cfg(feature = "redis")]
122 Self::Redis { .. } => {
123 }
125 }
126 }
127
128 pub fn backend_name(&self) -> &'static str {
130 match self {
131 Self::InMemory { .. } => "in-memory",
132 #[cfg(feature = "redis")]
133 Self::Redis { .. } => "redis",
134 }
135 }
136}
137
138pub async fn backend_rate_limit_middleware(
145 axum::Extension(backend): axum::Extension<RateLimitBackend>,
146 request: Request<Body>,
147 next: Next,
148) -> Response {
149 let config = backend.config();
150
151 if !config.enabled {
153 return next.run(request).await;
154 }
155
156 let path = request.uri().path();
158 if config.exempt_paths.iter().any(|p| path.starts_with(p)) {
159 return next.run(request).await;
160 }
161
162 let client_key = extract_client_key(&request);
164 let max_requests = config.max_requests;
165 let window_secs = config.window.as_secs();
166
167 if backend.check_rate_limit(&client_key).await {
169 let remaining = backend.remaining(&client_key).await;
170 let mut response = next.run(request).await;
171
172 let headers = response.headers_mut();
174 headers.insert(
175 "X-RateLimit-Limit",
176 max_requests.to_string().parse().unwrap(),
177 );
178 headers.insert(
179 "X-RateLimit-Remaining",
180 remaining.to_string().parse().unwrap(),
181 );
182
183 response
184 } else {
185 (
187 StatusCode::TOO_MANY_REQUESTS,
188 [
189 ("X-RateLimit-Limit", max_requests.to_string()),
190 ("X-RateLimit-Remaining", "0".to_string()),
191 ("Retry-After", window_secs.to_string()),
192 ],
193 format!(
194 "Rate limit exceeded. Max {} requests per {} seconds.",
195 max_requests, window_secs
196 ),
197 )
198 .into_response()
199 }
200}
201
202fn extract_client_key(request: &Request<Body>) -> String {
204 if let Some(forwarded) = request.headers().get("X-Forwarded-For") {
206 if let Ok(s) = forwarded.to_str() {
207 if let Some(ip) = s.split(',').next() {
208 return ip.trim().to_string();
209 }
210 }
211 }
212
213 if let Some(real_ip) = request.headers().get("X-Real-IP") {
215 if let Ok(s) = real_ip.to_str() {
216 return s.to_string();
217 }
218 }
219
220 "unknown".to_string()
222}
223
#[cfg(test)]
// Tests unwrap freely: a failed unwrap is simply a failed test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, middleware, routing::get, Router};
    use tower::ServiceExt;

    /// Trivial handler so that every routed request answers `200 OK`.
    async fn test_handler() -> &'static str {
        "ok"
    }

    /// Builds a router with `/api/test` and `/health` behind the rate-limit
    /// middleware, using an in-memory backend created from `config`.
    fn test_router_with_backend(config: RateLimitConfig) -> Router {
        let backend = RateLimitBackend::in_memory(config);
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(backend_rate_limit_middleware))
            .layer(axum::Extension(backend))
    }

    /// Builds an empty-bodied request for `uri`, optionally tagged with an
    /// `X-Forwarded-For` header identifying the client.
    fn get_request(uri: &str, client_ip: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri(uri);
        if let Some(ip) = client_ip {
            builder = builder.header("X-Forwarded-For", ip);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn test_backend_rate_limit_disabled() {
        // NOTE(review): relies on `RateLimitConfig::default()` leaving rate
        // limiting disabled, as the test name implies — confirm.
        let config = RateLimitConfig::default();
        let router = test_router_with_backend(config);

        let response = router
            .oneshot(get_request("/api/test", None))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_backend_rate_limit_allows_under_limit() {
        let router = test_router_with_backend(RateLimitConfig::new(5, 60));

        for _ in 0..3 {
            let response = router
                .clone()
                .oneshot(get_request("/api/test", Some("192.168.1.1")))
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn test_backend_rate_limit_blocks_over_limit() {
        let router = test_router_with_backend(RateLimitConfig::new(2, 60));

        for attempt in 0..3 {
            let response = router
                .clone()
                .oneshot(get_request("/api/test", Some("192.168.1.100")))
                .await
                .unwrap();

            // The first two requests fit the limit of 2; the third must be rejected.
            let expected = if attempt < 2 {
                StatusCode::OK
            } else {
                StatusCode::TOO_MANY_REQUESTS
            };
            assert_eq!(response.status(), expected);
        }
    }

    #[tokio::test]
    async fn test_backend_rate_limit_exempt_path() {
        // NOTE(review): relies on `/health` being among the exempt paths that
        // `RateLimitConfig::new` sets up — confirm.
        let router = test_router_with_backend(RateLimitConfig::new(1, 60));

        // Use up the single allowed request for this client.
        let response = router
            .clone()
            .oneshot(get_request("/api/test", Some("192.168.1.200")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        // The exempt path must still answer for the same client.
        let response = router
            .oneshot(get_request("/health", Some("192.168.1.200")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[test]
    fn test_backend_name_in_memory() {
        let backend = RateLimitBackend::in_memory(RateLimitConfig::default());
        assert_eq!(backend.backend_name(), "in-memory");
    }

    #[tokio::test]
    async fn test_backend_cleanup_in_memory() {
        let backend = RateLimitBackend::in_memory(RateLimitConfig::new(10, 1));

        // Smoke test: passes as long as cleanup on a fresh backend does not panic.
        backend.cleanup_expired().await;
    }
}