reinhardt_middleware/
redirect_fallback.rs1use async_trait::async_trait;
7use hyper::StatusCode;
8use regex::Regex;
9use reinhardt_http::{Handler, Middleware, Request, Response, Result};
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12
13#[non_exhaustive]
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct RedirectResponseConfig {
17 pub fallback_url: String,
19 pub path_patterns: Option<Vec<String>>,
21 pub redirect_status: Option<u16>,
23}
24
25impl RedirectResponseConfig {
26 pub fn new(fallback_url: String) -> Self {
37 Self {
38 fallback_url,
39 path_patterns: None,
40 redirect_status: None,
41 }
42 }
43
44 pub fn with_patterns(mut self, patterns: Vec<String>) -> Self {
55 self.path_patterns = Some(patterns);
56 self
57 }
58
59 pub fn with_status(mut self, status: u16) -> Self {
70 self.redirect_status = Some(status);
71 self
72 }
73}
74
75pub struct RedirectFallbackMiddleware {
118 config: RedirectResponseConfig,
119 compiled_patterns: Option<Vec<Regex>>,
120}
121
122impl RedirectFallbackMiddleware {
123 pub fn new(config: RedirectResponseConfig) -> Self {
134 let compiled_patterns = config
135 .path_patterns
136 .as_ref()
137 .map(|patterns| patterns.iter().filter_map(|p| Regex::new(p).ok()).collect());
138
139 Self {
140 config,
141 compiled_patterns,
142 }
143 }
144
145 fn matches_pattern(&self, path: &str) -> bool {
147 match &self.compiled_patterns {
148 None => true, Some(patterns) => patterns.iter().any(|re| re.is_match(path)),
150 }
151 }
152
153 fn redirect_status(&self) -> StatusCode {
155 self.config
156 .redirect_status
157 .and_then(|code| StatusCode::from_u16(code).ok())
158 .unwrap_or(StatusCode::FOUND)
159 }
160
161 fn should_redirect(&self, path: &str) -> bool {
163 path != self.config.fallback_url
165 }
166}
167
168#[async_trait]
169impl Middleware for RedirectFallbackMiddleware {
170 async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
171 let path = request.uri.path().to_string();
172
173 let response = handler.handle(request).await?;
175
176 if response.status != StatusCode::NOT_FOUND {
178 return Ok(response);
179 }
180
181 if !self.matches_pattern(&path) || !self.should_redirect(&path) {
183 return Ok(response);
184 }
185
186 let mut redirect_response = Response::new(self.redirect_status());
188 redirect_response.headers.insert(
189 hyper::header::LOCATION,
190 self.config
191 .fallback_url
192 .parse()
193 .unwrap_or_else(|_| hyper::header::HeaderValue::from_static("/")),
194 );
195
196 Ok(redirect_response)
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};

    /// Inner handler that always answers `404 Not Found`.
    struct NotFoundHandler;

    #[async_trait]
    impl Handler for NotFoundHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::new(StatusCode::NOT_FOUND))
        }
    }

    /// Inner handler that always answers `200 OK` with the body `"OK"`.
    struct OkHandler;

    #[async_trait]
    impl Handler for OkHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
        }
    }

    /// Builds an HTTP/1.1 request for `uri` with no headers and an empty body.
    fn build_request(method: Method, uri: &'static str) -> Request {
        Request::builder()
            .method(method)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Asserts that `response` has `status` and a `Location` equal to `location`.
    fn assert_redirect(response: &Response, status: StatusCode, location: &str) {
        assert_eq!(response.status, status);
        assert_eq!(
            response.headers.get(hyper::header::LOCATION).unwrap(),
            location
        );
    }

    /// Asserts that `response` has `status` and carries no `Location` header.
    fn assert_no_redirect(response: &Response, status: StatusCode) {
        assert_eq!(response.status, status);
        assert!(!response.headers.contains_key(hyper::header::LOCATION));
    }

    #[tokio::test]
    async fn test_redirect_on_404() {
        let middleware =
            RedirectFallbackMiddleware::new(RedirectResponseConfig::new("/404".to_string()));
        let handler = Arc::new(NotFoundHandler);

        let response = middleware
            .process(build_request(Method::GET, "/missing"), handler)
            .await
            .unwrap();

        assert_redirect(&response, StatusCode::FOUND, "/404");
    }

    #[tokio::test]
    async fn test_no_redirect_on_200() {
        let middleware =
            RedirectFallbackMiddleware::new(RedirectResponseConfig::new("/404".to_string()));
        let handler = Arc::new(OkHandler);

        let response = middleware
            .process(build_request(Method::GET, "/existing"), handler)
            .await
            .unwrap();

        assert_no_redirect(&response, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_pattern_matching_redirect() {
        let config = RedirectResponseConfig::new("/404".to_string())
            .with_patterns(vec!["/api/.*".to_string()]);
        let middleware = RedirectFallbackMiddleware::new(config);
        let handler = Arc::new(NotFoundHandler);

        let response = middleware
            .process(build_request(Method::GET, "/api/missing"), handler)
            .await
            .unwrap();

        assert_redirect(&response, StatusCode::FOUND, "/404");
    }

    #[tokio::test]
    async fn test_pattern_no_match_no_redirect() {
        let config = RedirectResponseConfig::new("/404".to_string())
            .with_patterns(vec!["/api/.*".to_string()]);
        let middleware = RedirectFallbackMiddleware::new(config);
        let handler = Arc::new(NotFoundHandler);

        let response = middleware
            .process(build_request(Method::GET, "/other/missing"), handler)
            .await
            .unwrap();

        assert_no_redirect(&response, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_custom_redirect_status() {
        let config = RedirectResponseConfig::new("/404".to_string()).with_status(301);
        let middleware = RedirectFallbackMiddleware::new(config);
        let handler = Arc::new(NotFoundHandler);

        let response = middleware
            .process(build_request(Method::GET, "/missing"), handler)
            .await
            .unwrap();

        assert_redirect(&response, StatusCode::MOVED_PERMANENTLY, "/404");
    }

    #[tokio::test]
    async fn test_prevent_redirect_loop() {
        let middleware =
            RedirectFallbackMiddleware::new(RedirectResponseConfig::new("/404".to_string()));
        let handler = Arc::new(NotFoundHandler);

        // A 404 on the fallback page itself must not be redirected again.
        let response = middleware
            .process(build_request(Method::GET, "/404"), handler)
            .await
            .unwrap();

        assert_no_redirect(&response, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_multiple_pattern_matching() {
        let config = RedirectResponseConfig::new("/error".to_string())
            .with_patterns(vec!["/api/.*".to_string(), "/v1/.*".to_string()]);
        let middleware = RedirectFallbackMiddleware::new(config);
        let handler = Arc::new(NotFoundHandler);

        for uri in ["/api/test", "/v1/test"] {
            let response = middleware
                .process(build_request(Method::GET, uri), handler.clone())
                .await
                .unwrap();
            assert_eq!(response.status, StatusCode::FOUND);
        }
    }

    #[tokio::test]
    async fn test_different_http_methods() {
        let middleware =
            RedirectFallbackMiddleware::new(RedirectResponseConfig::new("/404".to_string()));
        let handler = Arc::new(NotFoundHandler);

        let response = middleware
            .process(build_request(Method::POST, "/missing"), handler)
            .await
            .unwrap();

        assert_redirect(&response, StatusCode::FOUND, "/404");
    }

    #[tokio::test]
    async fn test_no_patterns_matches_all() {
        let middleware =
            RedirectFallbackMiddleware::new(RedirectResponseConfig::new("/fallback".to_string()));
        let handler = Arc::new(NotFoundHandler);

        for uri in ["/api/test", "/admin/test", "/any/path/here"] {
            let response = middleware
                .process(build_request(Method::GET, uri), handler.clone())
                .await
                .unwrap();
            assert_redirect(&response, StatusCode::FOUND, "/fallback");
        }
    }

    #[tokio::test]
    async fn test_complex_pattern_matching() {
        let config = RedirectResponseConfig::new("/404".to_string())
            .with_patterns(vec!["/api/v[0-9]+/.*".to_string()]);
        let middleware = RedirectFallbackMiddleware::new(config);
        let handler = Arc::new(NotFoundHandler);

        let versioned = middleware
            .process(build_request(Method::GET, "/api/v1/users"), handler.clone())
            .await
            .unwrap();
        assert_eq!(versioned.status, StatusCode::FOUND);

        let unversioned = middleware
            .process(build_request(Method::GET, "/api/version/users"), handler)
            .await
            .unwrap();
        assert_eq!(unversioned.status, StatusCode::NOT_FOUND);
    }
}