Skip to main content

reinhardt_middleware/
redirect_fallback.rs

//! Redirect fallback middleware
//!
//! Replaces 404 Not Found responses with a redirect to a configured fallback URL,
//! so users who request a missing page are sent somewhere useful instead.
5
6use async_trait::async_trait;
7use hyper::StatusCode;
8use regex::Regex;
9use reinhardt_http::{Handler, Middleware, Request, Response, Result};
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12
/// Configuration for [`RedirectFallbackMiddleware`].
///
/// Build one with [`RedirectResponseConfig::new`] and refine it with the
/// `with_*` builder methods. The struct is `#[non_exhaustive]`, so code outside
/// this crate cannot construct it with a struct literal.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RedirectResponseConfig {
	/// URL used as the redirect target when a 404 response is redirected.
	pub fallback_url: String,
	/// Regular expressions tested against the request path (query string excluded).
	///
	/// `None` means every 404 response is eligible for redirection. Patterns are
	/// checked with `Regex::is_match`, which is unanchored: `/api/.*` also matches
	/// `/v2/api/x`. Use `^` / `$` to anchor a pattern. Patterns that fail to
	/// compile are skipped by [`RedirectFallbackMiddleware::new`].
	pub path_patterns: Option<Vec<String>>,
	/// Status code for the redirect response; `None` means 302 Found.
	///
	/// Values that are not valid HTTP status codes also fall back to 302 Found.
	pub redirect_status: Option<u16>,
}
24
25impl RedirectResponseConfig {
26	/// Create a new configuration with a fallback URL
27	///
28	/// # Examples
29	///
30	/// ```
31	/// use reinhardt_middleware::RedirectResponseConfig;
32	///
33	/// let config = RedirectResponseConfig::new("/404".to_string());
34	/// assert_eq!(config.fallback_url, "/404");
35	/// ```
36	pub fn new(fallback_url: String) -> Self {
37		Self {
38			fallback_url,
39			path_patterns: None,
40			redirect_status: None,
41		}
42	}
43
44	/// Add path patterns to match
45	///
46	/// # Examples
47	///
48	/// ```
49	/// use reinhardt_middleware::RedirectResponseConfig;
50	///
51	/// let config = RedirectResponseConfig::new("/404".to_string())
52	///     .with_patterns(vec!["/api/.*".to_string()]);
53	/// ```
54	pub fn with_patterns(mut self, patterns: Vec<String>) -> Self {
55		self.path_patterns = Some(patterns);
56		self
57	}
58
59	/// Set custom redirect status code
60	///
61	/// # Examples
62	///
63	/// ```
64	/// use reinhardt_middleware::RedirectResponseConfig;
65	///
66	/// let config = RedirectResponseConfig::new("/404".to_string())
67	///     .with_status(301);
68	/// ```
69	pub fn with_status(mut self, status: u16) -> Self {
70		self.redirect_status = Some(status);
71		self
72	}
73}
74
75/// Middleware that redirects 404 errors to a fallback URL
76///
77/// # Examples
78///
79/// ```
80/// use std::sync::Arc;
81/// use reinhardt_middleware::{RedirectFallbackMiddleware, RedirectResponseConfig};
82/// use reinhardt_http::{Handler, Middleware, Request, Response};
83/// use hyper::{StatusCode, Method, Version, HeaderMap};
84/// use bytes::Bytes;
85///
86/// struct NotFoundHandler;
87///
88/// #[async_trait::async_trait]
89/// impl Handler for NotFoundHandler {
90///     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
91///         Ok(Response::new(StatusCode::NOT_FOUND))
92///     }
93/// }
94///
95/// # tokio_test::block_on(async {
96/// let config = RedirectResponseConfig::new("/404".to_string());
97/// let middleware = RedirectFallbackMiddleware::new(config);
98/// let handler = Arc::new(NotFoundHandler);
99///
100/// let request = Request::builder()
101///     .method(Method::GET)
102///     .uri("/missing")
103///     .version(Version::HTTP_11)
104///     .headers(HeaderMap::new())
105///     .body(Bytes::new())
106///     .build()
107///     .unwrap();
108///
109/// let response = middleware.process(request, handler).await.unwrap();
110/// assert_eq!(response.status, StatusCode::FOUND);
111/// assert_eq!(
112///     response.headers.get(hyper::header::LOCATION).unwrap(),
113///     "/404"
114/// );
115/// # });
116/// ```
117pub struct RedirectFallbackMiddleware {
118	config: RedirectResponseConfig,
119	compiled_patterns: Option<Vec<Regex>>,
120}
121
122impl RedirectFallbackMiddleware {
123	/// Create a new RedirectFallbackMiddleware with the given configuration
124	///
125	/// # Examples
126	///
127	/// ```
128	/// use reinhardt_middleware::{RedirectFallbackMiddleware, RedirectResponseConfig};
129	///
130	/// let config = RedirectResponseConfig::new("/404".to_string());
131	/// let middleware = RedirectFallbackMiddleware::new(config);
132	/// ```
133	pub fn new(config: RedirectResponseConfig) -> Self {
134		let compiled_patterns = config
135			.path_patterns
136			.as_ref()
137			.map(|patterns| patterns.iter().filter_map(|p| Regex::new(p).ok()).collect());
138
139		Self {
140			config,
141			compiled_patterns,
142		}
143	}
144
145	/// Check if the path matches any configured patterns
146	fn matches_pattern(&self, path: &str) -> bool {
147		match &self.compiled_patterns {
148			None => true, // No patterns means match all
149			Some(patterns) => patterns.iter().any(|re| re.is_match(path)),
150		}
151	}
152
153	/// Get the redirect status code to use
154	fn redirect_status(&self) -> StatusCode {
155		self.config
156			.redirect_status
157			.and_then(|code| StatusCode::from_u16(code).ok())
158			.unwrap_or(StatusCode::FOUND)
159	}
160
161	/// Check if we should redirect to avoid loops
162	fn should_redirect(&self, path: &str) -> bool {
163		// Prevent redirect loop: don't redirect if already at fallback URL
164		path != self.config.fallback_url
165	}
166}
167
168#[async_trait]
169impl Middleware for RedirectFallbackMiddleware {
170	async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
171		let path = request.uri.path().to_string();
172
173		// Call the handler
174		let response = handler.handle(request).await?;
175
176		// Only redirect on 404 errors
177		if response.status != StatusCode::NOT_FOUND {
178			return Ok(response);
179		}
180
181		// Check if we should redirect (pattern match and loop prevention)
182		if !self.matches_pattern(&path) || !self.should_redirect(&path) {
183			return Ok(response);
184		}
185
186		// Create redirect response
187		let mut redirect_response = Response::new(self.redirect_status());
188		redirect_response.headers.insert(
189			hyper::header::LOCATION,
190			self.config
191				.fallback_url
192				.parse()
193				.unwrap_or_else(|_| hyper::header::HeaderValue::from_static("/")),
194		);
195
196		Ok(redirect_response)
197	}
198}
199
#[cfg(test)]
mod tests {
	use super::*;
	use bytes::Bytes;
	use hyper::{HeaderMap, Method, StatusCode, Version};

	/// Inner handler that answers every request with an empty 404.
	struct NotFoundHandler;

	#[async_trait]
	impl Handler for NotFoundHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Ok(Response::new(StatusCode::NOT_FOUND))
		}
	}

	/// Inner handler that answers every request with `200 OK` and body `"OK"`.
	struct OkHandler;

	#[async_trait]
	impl Handler for OkHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
		}
	}

	/// Build an HTTP/1.1 request for `uri` with no headers and an empty body.
	fn build_request(method: Method, uri: &'static str) -> Request {
		Request::builder()
			.method(method)
			.uri(uri)
			.version(Version::HTTP_11)
			.headers(HeaderMap::new())
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	/// Push one request through `middleware`, with `handler` as the inner handler.
	async fn dispatch(
		middleware: &RedirectFallbackMiddleware,
		handler: Arc<dyn Handler>,
		method: Method,
		uri: &'static str,
	) -> Response {
		middleware
			.process(build_request(method, uri), handler)
			.await
			.unwrap()
	}

	/// Assert that `response` redirects to `target` with `status`.
	#[track_caller]
	fn assert_redirect(response: &Response, status: StatusCode, target: &str) {
		assert_eq!(response.status, status);
		assert_eq!(response.headers.get(hyper::header::LOCATION).unwrap(), target);
	}

	/// Assert that `response` has `status` and carries no `Location` header.
	#[track_caller]
	fn assert_passthrough(response: &Response, status: StatusCode) {
		assert_eq!(response.status, status);
		assert!(!response.headers.contains_key(hyper::header::LOCATION));
	}

	#[tokio::test]
	async fn test_redirect_on_404() {
		let middleware =
			RedirectFallbackMiddleware::new(RedirectResponseConfig::new("/404".to_string()));

		let response =
			dispatch(&middleware, Arc::new(NotFoundHandler), Method::GET, "/missing").await;

		assert_redirect(&response, StatusCode::FOUND, "/404");
	}

	#[tokio::test]
	async fn test_no_redirect_on_200() {
		let middleware =
			RedirectFallbackMiddleware::new(RedirectResponseConfig::new("/404".to_string()));

		let response = dispatch(&middleware, Arc::new(OkHandler), Method::GET, "/existing").await;

		assert_passthrough(&response, StatusCode::OK);
	}

	#[tokio::test]
	async fn test_pattern_matching_redirect() {
		let middleware = RedirectFallbackMiddleware::new(
			RedirectResponseConfig::new("/404".to_string())
				.with_patterns(vec!["/api/.*".to_string()]),
		);

		// Paths under /api/ are covered by the pattern, so they redirect.
		let response =
			dispatch(&middleware, Arc::new(NotFoundHandler), Method::GET, "/api/missing").await;

		assert_redirect(&response, StatusCode::FOUND, "/404");
	}

	#[tokio::test]
	async fn test_pattern_no_match_no_redirect() {
		let middleware = RedirectFallbackMiddleware::new(
			RedirectResponseConfig::new("/404".to_string())
				.with_patterns(vec!["/api/.*".to_string()]),
		);

		// Paths outside /api/ are not covered, so the 404 passes through.
		let response =
			dispatch(&middleware, Arc::new(NotFoundHandler), Method::GET, "/other/missing").await;

		assert_passthrough(&response, StatusCode::NOT_FOUND);
	}

	#[tokio::test]
	async fn test_custom_redirect_status() {
		let middleware = RedirectFallbackMiddleware::new(
			RedirectResponseConfig::new("/404".to_string()).with_status(301),
		);

		let response =
			dispatch(&middleware, Arc::new(NotFoundHandler), Method::GET, "/missing").await;

		assert_redirect(&response, StatusCode::MOVED_PERMANENTLY, "/404");
	}

	#[tokio::test]
	async fn test_prevent_redirect_loop() {
		let middleware =
			RedirectFallbackMiddleware::new(RedirectResponseConfig::new("/404".to_string()));

		// A 404 for the fallback URL itself must not redirect back to it.
		let response = dispatch(&middleware, Arc::new(NotFoundHandler), Method::GET, "/404").await;

		assert_passthrough(&response, StatusCode::NOT_FOUND);
	}

	#[tokio::test]
	async fn test_multiple_pattern_matching() {
		let middleware = RedirectFallbackMiddleware::new(
			RedirectResponseConfig::new("/error".to_string())
				.with_patterns(vec!["/api/.*".to_string(), "/v1/.*".to_string()]),
		);
		let handler: Arc<dyn Handler> = Arc::new(NotFoundHandler);

		// Each pattern independently triggers a redirect.
		for uri in ["/api/test", "/v1/test"] {
			let response = dispatch(&middleware, Arc::clone(&handler), Method::GET, uri).await;
			assert_eq!(response.status, StatusCode::FOUND);
		}
	}

	#[tokio::test]
	async fn test_different_http_methods() {
		let middleware =
			RedirectFallbackMiddleware::new(RedirectResponseConfig::new("/404".to_string()));

		// The method does not influence the redirect decision.
		let response =
			dispatch(&middleware, Arc::new(NotFoundHandler), Method::POST, "/missing").await;

		assert_redirect(&response, StatusCode::FOUND, "/404");
	}

	#[tokio::test]
	async fn test_no_patterns_matches_all() {
		let middleware =
			RedirectFallbackMiddleware::new(RedirectResponseConfig::new("/fallback".to_string()));
		let handler: Arc<dyn Handler> = Arc::new(NotFoundHandler);

		// Without patterns, every path is eligible for redirection.
		for uri in ["/api/test", "/admin/test", "/any/path/here"] {
			let response = dispatch(&middleware, Arc::clone(&handler), Method::GET, uri).await;
			assert_redirect(&response, StatusCode::FOUND, "/fallback");
		}
	}

	#[tokio::test]
	async fn test_complex_pattern_matching() {
		let middleware = RedirectFallbackMiddleware::new(
			RedirectResponseConfig::new("/404".to_string())
				.with_patterns(vec!["/api/v[0-9]+/.*".to_string()]),
		);
		let handler: Arc<dyn Handler> = Arc::new(NotFoundHandler);

		// A numeric version segment matches the pattern.
		let versioned =
			dispatch(&middleware, Arc::clone(&handler), Method::GET, "/api/v1/users").await;
		assert_eq!(versioned.status, StatusCode::FOUND);

		// "version" is not `v` followed by digits and a slash, so no match.
		let unversioned =
			dispatch(&middleware, handler, Method::GET, "/api/version/users").await;
		assert_eq!(unversioned.status, StatusCode::NOT_FOUND);
	}
}