datasynth_server/rest/
rate_limit_backend.rs1use axum::{
8 body::Body,
9 http::{header::HeaderValue, Request, StatusCode},
10 middleware::Next,
11 response::{IntoResponse, Response},
12};
13
14use super::rate_limit::{RateLimitConfig, RateLimiter};
15#[cfg(feature = "redis")]
16use super::redis_rate_limit::RedisRateLimiter;
17
/// Rate-limiting backend handed to [`backend_rate_limit_middleware`] through an
/// `axum::Extension`.
///
/// Every variant pairs a limiter implementation with the [`RateLimitConfig`] it was
/// built from, so callers can read `enabled`, `max_requests`, `window` and
/// `exempt_paths` without caring which backend is active.
#[derive(Clone)]
pub enum RateLimitBackend {
    /// Limiter whose counters live in this process (presumably not shared between
    /// server instances — confirm in `RateLimiter`).
    InMemory {
        /// The in-process limiter holding per-client state.
        limiter: RateLimiter,
        /// Configuration the limiter was constructed from.
        config: RateLimitConfig,
    },
    /// Redis-backed limiter; only compiled with the `redis` feature.
    #[cfg(feature = "redis")]
    Redis {
        /// The Redis limiter (boxed, presumably to keep the enum's size down).
        limiter: Box<RedisRateLimiter>,
        /// Configuration the limiter was constructed from.
        config: RateLimitConfig,
    },
}
39
40impl RateLimitBackend {
41 pub fn in_memory(config: RateLimitConfig) -> Self {
43 let limiter = RateLimiter::new(config.clone());
44 Self::InMemory { limiter, config }
45 }
46
47 #[cfg(feature = "redis")]
53 pub async fn redis(
54 redis_url: &str,
55 config: RateLimitConfig,
56 ) -> Result<Self, redis::RedisError> {
57 let limiter = RedisRateLimiter::new(redis_url, config.max_requests, config.window).await?;
58 Ok(Self::Redis {
59 limiter: Box::new(limiter),
60 config,
61 })
62 }
63
64 pub fn config(&self) -> &RateLimitConfig {
66 match self {
67 Self::InMemory { config, .. } => config,
68 #[cfg(feature = "redis")]
69 Self::Redis { config, .. } => config,
70 }
71 }
72
73 pub async fn check_rate_limit(&self, client_key: &str) -> bool {
77 match self {
78 Self::InMemory { limiter, config } => {
79 if !config.enabled {
80 return true;
81 }
82 limiter.check_rate_limit(client_key).await
83 }
84 #[cfg(feature = "redis")]
85 Self::Redis { limiter, config } => {
86 if !config.enabled {
87 return true;
88 }
89 limiter.check_rate_limit(client_key).await.allowed
90 }
91 }
92 }
93
94 pub async fn remaining(&self, client_key: &str) -> u32 {
96 match self {
97 Self::InMemory { limiter, config } => {
98 if !config.enabled {
99 return config.max_requests;
100 }
101 limiter.remaining(client_key).await
102 }
103 #[cfg(feature = "redis")]
104 Self::Redis { limiter, config } => {
105 if !config.enabled {
106 return config.max_requests;
107 }
108 limiter.remaining(client_key).await
109 }
110 }
111 }
112
113 pub async fn cleanup_expired(&self) {
117 match self {
118 Self::InMemory { limiter, .. } => {
119 limiter.cleanup_expired().await;
120 }
121 #[cfg(feature = "redis")]
122 Self::Redis { .. } => {
123 }
125 }
126 }
127
128 pub fn backend_name(&self) -> &'static str {
130 match self {
131 Self::InMemory { .. } => "in-memory",
132 #[cfg(feature = "redis")]
133 Self::Redis { .. } => "redis",
134 }
135 }
136}
137
138pub async fn backend_rate_limit_middleware(
145 axum::Extension(backend): axum::Extension<RateLimitBackend>,
146 request: Request<Body>,
147 next: Next,
148) -> Response {
149 let config = backend.config();
150
151 if !config.enabled {
153 return next.run(request).await;
154 }
155
156 let path = request.uri().path();
158 if config.exempt_paths.iter().any(|p| path.starts_with(p)) {
159 return next.run(request).await;
160 }
161
162 let client_key = extract_client_key(&request);
164 let max_requests = config.max_requests;
165 let window_secs = config.window.as_secs();
166
167 if backend.check_rate_limit(&client_key).await {
169 let remaining = backend.remaining(&client_key).await;
170 let mut response = next.run(request).await;
171
172 let headers = response.headers_mut();
174 headers.insert("X-RateLimit-Limit", HeaderValue::from(max_requests));
175 headers.insert("X-RateLimit-Remaining", HeaderValue::from(remaining));
176
177 response
178 } else {
179 (
181 StatusCode::TOO_MANY_REQUESTS,
182 [
183 ("X-RateLimit-Limit", max_requests.to_string()),
184 ("X-RateLimit-Remaining", "0".to_string()),
185 ("Retry-After", window_secs.to_string()),
186 ],
187 format!(
188 "Rate limit exceeded. Max {} requests per {} seconds.",
189 max_requests, window_secs
190 ),
191 )
192 .into_response()
193 }
194}
195
196fn extract_client_key(request: &Request<Body>) -> String {
198 if let Some(forwarded) = request.headers().get("X-Forwarded-For") {
200 if let Ok(s) = forwarded.to_str() {
201 if let Some(ip) = s.split(',').next() {
202 return ip.trim().to_string();
203 }
204 }
205 }
206
207 if let Some(real_ip) = request.headers().get("X-Real-IP") {
209 if let Ok(s) = real_ip.to_str() {
210 return s.to_string();
211 }
212 }
213
214 "unknown".to_string()
216}
217
#[cfg(test)]
// Tests are expected to panic on unexpected failures, so `unwrap` is acceptable here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, middleware, routing::get, Router};
    use tower::ServiceExt;

    /// Trivial handler mounted on every test route.
    async fn test_handler() -> &'static str {
        "ok"
    }

    /// Router exposing `/api/test` and `/health` behind an in-memory backend built
    /// from `config`.
    fn test_router_with_backend(config: RateLimitConfig) -> Router {
        let backend = RateLimitBackend::in_memory(config);
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(backend_rate_limit_middleware))
            .layer(axum::Extension(backend))
    }

    /// GET request for `uri` attributed to `client_ip` via `X-Forwarded-For`.
    fn request_from(uri: &str, client_ip: &str) -> Request<Body> {
        Request::builder()
            .uri(uri)
            .header("X-Forwarded-For", client_ip)
            .body(Body::empty())
            .unwrap()
    }

    /// Value of header `name` as a string; panics if it is missing or not ASCII.
    fn header<'a, B>(response: &'a axum::http::Response<B>, name: &str) -> &'a str {
        response.headers().get(name).unwrap().to_str().unwrap()
    }

    #[tokio::test]
    async fn test_backend_rate_limit_disabled() {
        // Disable explicitly instead of relying on whatever `Default` chooses.
        let mut config = RateLimitConfig::new(1, 60);
        config.enabled = false;
        let router = test_router_with_backend(config);

        // Well past the limit of 1, yet every request passes and no headers are added.
        for _ in 0..3 {
            let response = router
                .clone()
                .oneshot(request_from("/api/test", "192.168.1.50"))
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
            assert!(response.headers().get("x-ratelimit-limit").is_none());
        }
    }

    #[tokio::test]
    async fn test_backend_rate_limit_allows_under_limit() {
        let router = test_router_with_backend(RateLimitConfig::new(5, 60));

        for _ in 0..3 {
            let response = router
                .clone()
                .oneshot(request_from("/api/test", "192.168.1.1"))
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
            assert_eq!(header(&response, "x-ratelimit-limit"), "5");
            let remaining: u32 = header(&response, "x-ratelimit-remaining").parse().unwrap();
            assert!(remaining <= 5);
        }
    }

    #[tokio::test]
    async fn test_backend_rate_limit_blocks_over_limit() {
        let router = test_router_with_backend(RateLimitConfig::new(2, 60));

        for i in 0..3 {
            let response = router
                .clone()
                .oneshot(request_from("/api/test", "192.168.1.100"))
                .await
                .unwrap();
            if i < 2 {
                assert_eq!(response.status(), StatusCode::OK);
            } else {
                assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(header(&response, "x-ratelimit-limit"), "2");
                assert_eq!(header(&response, "x-ratelimit-remaining"), "0");
                assert_eq!(header(&response, "retry-after"), "60");
            }
        }
    }

    #[tokio::test]
    async fn test_backend_rate_limit_is_per_client() {
        let router = test_router_with_backend(RateLimitConfig::new(1, 60));

        let first = router
            .clone()
            .oneshot(request_from("/api/test", "10.0.0.1"))
            .await
            .unwrap();
        assert_eq!(first.status(), StatusCode::OK);

        let blocked = router
            .clone()
            .oneshot(request_from("/api/test", "10.0.0.1"))
            .await
            .unwrap();
        assert_eq!(blocked.status(), StatusCode::TOO_MANY_REQUESTS);

        // A different client has its own, untouched quota.
        let other = router
            .oneshot(request_from("/api/test", "10.0.0.2"))
            .await
            .unwrap();
        assert_eq!(other.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_backend_rate_limit_exempt_path() {
        // Assumes `/health` is among the default `exempt_paths` of `RateLimitConfig::new`.
        let router = test_router_with_backend(RateLimitConfig::new(1, 60));

        // Exhaust the single allowed request and prove the limit is now enforced...
        let first = router
            .clone()
            .oneshot(request_from("/api/test", "192.168.1.200"))
            .await
            .unwrap();
        assert_eq!(first.status(), StatusCode::OK);
        let blocked = router
            .clone()
            .oneshot(request_from("/api/test", "192.168.1.200"))
            .await
            .unwrap();
        assert_eq!(blocked.status(), StatusCode::TOO_MANY_REQUESTS);

        // ...yet the exempt path still goes through for the same client.
        let response = router
            .oneshot(request_from("/health", "192.168.1.200"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_backend_disabled_reports_full_quota() {
        let mut config = RateLimitConfig::new(3, 60);
        config.enabled = false;
        let backend = RateLimitBackend::in_memory(config);

        for _ in 0..5 {
            assert!(backend.check_rate_limit("client").await);
        }
        assert_eq!(backend.remaining("client").await, 3);
    }

    #[test]
    fn test_backend_name_in_memory() {
        let config = RateLimitConfig::default();
        let backend = RateLimitBackend::in_memory(config);
        assert_eq!(backend.backend_name(), "in-memory");
    }

    #[tokio::test]
    async fn test_backend_cleanup_in_memory() {
        let config = RateLimitConfig::new(10, 1);
        let backend = RateLimitBackend::in_memory(config);

        backend.cleanup_expired().await;

        // Cleanup on a fresh backend must leave it usable.
        assert!(backend.check_rate_limit("192.168.1.1").await);
    }
}