Skip to main content

reinhardt_middleware/
common.rs

1//! Common middleware utilities
2//!
3//! Provides URL normalization and common request processing patterns.
4
5use async_trait::async_trait;
6use hyper::header::HOST;
7use hyper::{Method, StatusCode};
8use reinhardt_http::{Handler, Middleware, Request, Response, Result};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11
12/// Common middleware configuration
13#[non_exhaustive]
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct CommonConfig {
16	/// Append trailing slash to URLs that don't have one (except URLs with file extensions)
17	pub append_slash: bool,
18	/// Prepend 'www.' to the domain if not present
19	pub prepend_www: bool,
20}
21
22impl CommonConfig {
23	/// Create a new CommonConfig with default settings
24	///
25	/// Default configuration:
26	/// - `append_slash`: true - Adds trailing slashes to URLs
27	/// - `prepend_www`: false - Does not add www prefix
28	///
29	/// # Examples
30	///
31	/// ```
32	/// use reinhardt_middleware::common::CommonConfig;
33	///
34	/// let config = CommonConfig::new();
35	/// assert!(config.append_slash);
36	/// assert!(!config.prepend_www);
37	/// ```
38	pub fn new() -> Self {
39		Self {
40			append_slash: true,
41			prepend_www: false,
42		}
43	}
44}
45
46impl Default for CommonConfig {
47	fn default() -> Self {
48		Self::new()
49	}
50}
51
52/// Common middleware for URL normalization
53///
54/// Handles common URL transformations:
55/// - Appending trailing slashes to URLs
56/// - Prepending 'www.' to domain names
57///
58/// # Examples
59///
60/// ```
61/// use std::sync::Arc;
62/// use reinhardt_middleware::{CommonMiddleware, CommonConfig};
63/// use reinhardt_http::{Handler, Middleware, Request, Response};
64/// use hyper::{StatusCode, Method, Version, HeaderMap};
65/// use bytes::Bytes;
66///
67/// struct TestHandler;
68///
69/// #[async_trait::async_trait]
70/// impl Handler for TestHandler {
71///     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
72///         Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
73///     }
74/// }
75///
76/// # tokio_test::block_on(async {
77/// let mut config = CommonConfig::new();
78/// config.append_slash = true;
79/// config.prepend_www = false;
80///
81/// let middleware = CommonMiddleware::with_config(config);
82/// let handler = Arc::new(TestHandler);
83///
84/// let request = Request::builder()
85///     .method(Method::GET)
86///     .uri("/path/to/page")
87///     .version(Version::HTTP_11)
88///     .headers(HeaderMap::new())
89///     .body(Bytes::new())
90///     .build()
91///     .unwrap();
92///
93/// let response = middleware.process(request, handler).await.unwrap();
94/// // URL without trailing slash redirects to /path/to/page/
95/// assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
96/// # });
97/// ```
98pub struct CommonMiddleware {
99	config: CommonConfig,
100}
101
102impl CommonMiddleware {
103	/// Create a new CommonMiddleware with default configuration
104	///
105	/// # Examples
106	///
107	/// ```
108	/// use reinhardt_middleware::CommonMiddleware;
109	///
110	/// let middleware = CommonMiddleware::new();
111	/// ```
112	pub fn new() -> Self {
113		Self {
114			config: CommonConfig::default(),
115		}
116	}
117
118	/// Create a new CommonMiddleware with custom configuration
119	///
120	/// # Examples
121	///
122	/// ```
123	/// use reinhardt_middleware::{CommonMiddleware, CommonConfig};
124	///
125	/// let mut config = CommonConfig::new();
126	/// config.append_slash = true;
127	/// config.prepend_www = true;
128	///
129	/// let middleware = CommonMiddleware::with_config(config);
130	/// ```
131	pub fn with_config(config: CommonConfig) -> Self {
132		Self { config }
133	}
134
135	/// Check if the URL path should have a trailing slash appended
136	fn should_append_slash(&self, path: &str) -> bool {
137		if !self.config.append_slash {
138			return false;
139		}
140
141		// Already ends with slash
142		if path.ends_with('/') {
143			return false;
144		}
145
146		// Check if path looks like a file (has extension)
147		if let Some(last_segment) = path.rsplit('/').next()
148			&& last_segment.contains('.')
149		{
150			return false;
151		}
152
153		true
154	}
155
156	/// Check if the host should have www prepended
157	fn should_prepend_www(&self, host: &str) -> bool {
158		if !self.config.prepend_www {
159			return false;
160		}
161
162		// Already has www
163		if host.starts_with("www.") {
164			return false;
165		}
166
167		// Localhost and IPs should not get www
168		if host.starts_with("localhost") || host.starts_with("127.") || host.starts_with("192.168.")
169		{
170			return false;
171		}
172
173		true
174	}
175
176	/// Build the redirect URL
177	fn build_redirect_url(&self, request: &Request) -> Option<String> {
178		let path = request.uri.path();
179		let query = request.uri.query();
180
181		let host = request
182			.headers
183			.get(HOST)
184			.and_then(|h| h.to_str().ok())
185			.unwrap_or("localhost");
186
187		let mut redirect_needed = false;
188		let mut new_path = path.to_string();
189		let mut new_host = host.to_string();
190
191		// Check if we need to append slash
192		if self.should_append_slash(path) {
193			new_path.push('/');
194			redirect_needed = true;
195		}
196
197		// Check if we need to prepend www
198		if self.should_prepend_www(host) {
199			new_host = format!("www.{}", host);
200			redirect_needed = true;
201		}
202
203		if !redirect_needed {
204			return None;
205		}
206
207		// Build the full URL using Request::scheme() which validates trusted proxies
208		// before honoring X-Forwarded-Proto headers
209		let scheme = request.scheme();
210
211		let url = if let Some(q) = query {
212			format!("{}://{}{}?{}", scheme, new_host, new_path, q)
213		} else {
214			format!("{}://{}{}", scheme, new_host, new_path)
215		};
216
217		Some(url)
218	}
219}
220
221impl Default for CommonMiddleware {
222	fn default() -> Self {
223		Self::new()
224	}
225}
226
227#[async_trait]
228impl Middleware for CommonMiddleware {
229	async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
230		// Check if we need to redirect
231		if let Some(redirect_url) = self.build_redirect_url(&request) {
232			// Use 307 Temporary Redirect for non-GET/HEAD methods to preserve
233			// the request method and body. Use 301 Moved Permanently for GET/HEAD.
234			let status = if matches!(request.method, Method::GET | Method::HEAD) {
235				StatusCode::MOVED_PERMANENTLY
236			} else {
237				StatusCode::TEMPORARY_REDIRECT
238			};
239			let mut response = Response::new(status);
240			response.headers.insert(
241				hyper::header::LOCATION,
242				redirect_url
243					.parse()
244					.unwrap_or_else(|_| hyper::header::HeaderValue::from_static("/")),
245			);
246			return Ok(response);
247		}
248
249		// No redirect needed, proceed with the handler
250		handler.handle(request).await
251	}
252}
253
#[cfg(test)]
mod tests {
	use super::*;
	use bytes::Bytes;
	use hyper::{HeaderMap, Method, Version};
	use rstest::rstest;

	/// Terminal handler that always answers `200 OK`. Seeing its status means
	/// the middleware let the request through instead of redirecting.
	struct TestHandler;

	#[async_trait]
	impl Handler for TestHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Ok(Response::new(StatusCode::OK).with_body("test response".as_bytes()))
		}
	}

	/// Middleware configured with the given flags.
	fn middleware(append_slash: bool, prepend_www: bool) -> CommonMiddleware {
		CommonMiddleware::with_config(CommonConfig {
			append_slash,
			prepend_www,
		})
	}

	/// HTTP/1.1 request with an empty body and, optionally, a `Host` header.
	fn request(method: Method, uri: &'static str, host: Option<&'static str>) -> Request {
		let mut headers = HeaderMap::new();
		if let Some(value) = host {
			headers.insert(HOST, value.parse().unwrap());
		}
		Request::builder()
			.method(method)
			.uri(uri)
			.version(Version::HTTP_11)
			.headers(headers)
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	/// Sends `req` through `mw` with [`TestHandler`] behind it.
	async fn run(mw: CommonMiddleware, req: Request) -> Response {
		mw.process(req, Arc::new(TestHandler)).await.unwrap()
	}

	/// The `Location` header of a redirect response, as text.
	fn location(response: &Response) -> &str {
		response
			.headers
			.get(hyper::header::LOCATION)
			.unwrap()
			.to_str()
			.unwrap()
	}

	#[tokio::test]
	async fn test_append_slash_redirects() {
		let response = run(
			middleware(true, false),
			request(Method::GET, "/path/to/page", None),
		)
		.await;

		assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
		assert!(location(&response).contains("/path/to/page/"));
	}

	#[tokio::test]
	async fn test_no_redirect_with_trailing_slash() {
		let response = run(
			middleware(true, false),
			request(Method::GET, "/path/to/page/", None),
		)
		.await;

		assert_eq!(response.status, StatusCode::OK);
	}

	#[tokio::test]
	async fn test_no_redirect_for_file_extensions() {
		let response = run(
			middleware(true, false),
			request(Method::GET, "/static/file.css", None),
		)
		.await;

		assert_eq!(response.status, StatusCode::OK);
	}

	#[tokio::test]
	async fn test_append_slash_with_query_params() {
		let response = run(
			middleware(true, false),
			request(Method::GET, "/search?q=test", None),
		)
		.await;

		assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
		let target = location(&response);
		assert!(target.contains("/search/"));
		assert!(target.contains("?q=test"));
	}

	#[tokio::test]
	async fn test_prepend_www() {
		let response = run(
			middleware(false, true),
			request(Method::GET, "/page/", Some("example.com")),
		)
		.await;

		assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
		assert!(location(&response).contains("www.example.com"));
	}

	#[tokio::test]
	async fn test_no_prepend_www_for_localhost() {
		let response = run(
			middleware(false, true),
			request(Method::GET, "/page/", Some("localhost:8000")),
		)
		.await;

		assert_eq!(response.status, StatusCode::OK);
	}

	#[tokio::test]
	async fn test_no_prepend_www_when_already_present() {
		let response = run(
			middleware(false, true),
			request(Method::GET, "/page/", Some("www.example.com")),
		)
		.await;

		assert_eq!(response.status, StatusCode::OK);
	}

	#[tokio::test]
	async fn test_both_transformations() {
		let response = run(
			middleware(true, true),
			request(Method::GET, "/page", Some("example.com")),
		)
		.await;

		assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
		let target = location(&response);
		assert!(target.contains("www.example.com"));
		assert!(target.contains("/page/"));
	}

	#[tokio::test]
	async fn test_both_disabled() {
		let response = run(
			middleware(false, false),
			request(Method::GET, "/page", Some("example.com")),
		)
		.await;

		assert_eq!(response.status, StatusCode::OK);
	}

	#[rstest]
	#[case::get_returns_301(Method::GET, StatusCode::MOVED_PERMANENTLY)]
	#[case::head_returns_301(Method::HEAD, StatusCode::MOVED_PERMANENTLY)]
	#[case::post_returns_307(Method::POST, StatusCode::TEMPORARY_REDIRECT)]
	#[case::put_returns_307(Method::PUT, StatusCode::TEMPORARY_REDIRECT)]
	#[case::patch_returns_307(Method::PATCH, StatusCode::TEMPORARY_REDIRECT)]
	#[case::delete_returns_307(Method::DELETE, StatusCode::TEMPORARY_REDIRECT)]
	#[tokio::test]
	async fn test_redirect_status_by_method(
		#[case] method: Method,
		#[case] expected_status: StatusCode,
	) {
		// Arrange
		let mw = middleware(true, false);
		let req = request(method, "/path/to/page", None);

		// Act
		let response = run(mw, req).await;

		// Assert
		assert_eq!(response.status, expected_status);
		assert!(response.headers.contains_key(hyper::header::LOCATION));
	}
}