Skip to main content

reinhardt_middleware/
https_redirect.rs

//! HTTPS Redirect Middleware
//!
//! Redirects plain-HTTP requests to the equivalent HTTPS URL.
//! Similar to Django's `SECURE_SSL_REDIRECT` setting.
5
6use async_trait::async_trait;
7use hyper::StatusCode;
8use reinhardt_http::{Handler, Middleware, Request, Response, Result};
9use std::sync::Arc;
10
11/// Configuration for HTTPS redirect middleware
12#[non_exhaustive]
13#[derive(Debug, Clone)]
14pub struct HttpsRedirectConfig {
15	/// Enable HTTPS redirect
16	pub enabled: bool,
17	/// Exempt paths from HTTPS redirect (e.g., health checks)
18	pub exempt_paths: Vec<String>,
19	/// Redirect status code (301 or 302)
20	pub status_code: StatusCode,
21	/// Allowed host names for redirect (prevents host header injection).
22	/// If empty, all requests without a valid allowed host are rejected with 400 Bad Request.
23	pub allowed_hosts: Vec<String>,
24}
25
26impl Default for HttpsRedirectConfig {
27	fn default() -> Self {
28		Self {
29			enabled: true,
30			exempt_paths: vec![],
31			status_code: StatusCode::MOVED_PERMANENTLY, // 301
32			allowed_hosts: vec![],
33		}
34	}
35}
36
37/// Middleware to redirect HTTP requests to HTTPS
38pub struct HttpsRedirectMiddleware {
39	config: HttpsRedirectConfig,
40}
41
42impl HttpsRedirectMiddleware {
43	/// Create a new HttpsRedirectMiddleware with the given configuration
44	///
45	/// # Arguments
46	///
47	/// * `config` - HTTPS redirect configuration
48	///
49	/// # Examples
50	///
51	/// ```
52	/// use std::sync::Arc;
53	/// use reinhardt_middleware::{HttpsRedirectMiddleware, HttpsRedirectConfig};
54	/// use reinhardt_http::{Handler, Middleware, Request, Response};
55	/// use hyper::{StatusCode, Method, Version, HeaderMap};
56	/// use bytes::Bytes;
57	///
58	/// struct TestHandler;
59	///
60	/// #[async_trait::async_trait]
61	/// impl Handler for TestHandler {
62	///     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
63	///         Ok(Response::new(StatusCode::OK))
64	///     }
65	/// }
66	///
67	/// # tokio_test::block_on(async {
68	/// let mut config = HttpsRedirectConfig::default();
69	/// config.enabled = true;
70	/// config.exempt_paths = vec!["/health".to_string()];
71	/// config.status_code = StatusCode::MOVED_PERMANENTLY;
72	/// config.allowed_hosts = vec!["example.com".to_string()];
73	///
74	/// let middleware = HttpsRedirectMiddleware::new(config);
75	/// let handler = Arc::new(TestHandler);
76	///
77	/// let mut headers = HeaderMap::new();
78	/// headers.insert(hyper::header::HOST, "example.com".parse().unwrap());
79	///
80	/// let request = Request::builder()
81	///     .method(Method::GET)
82	///     .uri("/api/data")
83	///     .version(Version::HTTP_11)
84	///     .headers(headers)
85	///     .body(Bytes::new())
86	///     .build()
87	///     .unwrap();
88	///
89	/// let response = middleware.process(request, handler).await.unwrap();
90	/// assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
91	/// assert_eq!(response.headers.get("Location").unwrap(), "https://example.com/api/data");
92	/// # });
93	/// ```
94	pub fn new(config: HttpsRedirectConfig) -> Self {
95		Self { config }
96	}
97	/// Create with default configuration
98	///
99	/// Default configuration enables HTTPS redirect with 301 status code and no exempt paths.
100	///
101	/// # Examples
102	///
103	/// ```
104	/// use std::sync::Arc;
105	/// use reinhardt_middleware::{HttpsRedirectConfig, HttpsRedirectMiddleware};
106	/// use reinhardt_http::{Handler, Middleware, Request, Response};
107	/// use hyper::{StatusCode, Method, Version, HeaderMap};
108	/// use bytes::Bytes;
109	///
110	/// struct TestHandler;
111	///
112	/// #[async_trait::async_trait]
113	/// impl Handler for TestHandler {
114	///     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
115	///         Ok(Response::new(StatusCode::OK))
116	///     }
117	/// }
118	///
119	/// # tokio_test::block_on(async {
120	/// let mut config = HttpsRedirectConfig::default();
121	/// config.allowed_hosts = vec!["api.example.com".to_string()];
122	/// let middleware = HttpsRedirectMiddleware::new(config);
123	/// let handler = Arc::new(TestHandler);
124	///
125	/// let mut headers = HeaderMap::new();
126	/// headers.insert(hyper::header::HOST, "api.example.com".parse().unwrap());
127	///
128	/// let request = Request::builder()
129	///     .method(Method::GET)
130	///     .uri("/users?page=1")
131	///     .version(Version::HTTP_11)
132	///     .headers(headers)
133	///     .body(Bytes::new())
134	///     .build()
135	///     .unwrap();
136	///
137	/// let response = middleware.process(request, handler).await.unwrap();
138	/// assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
139	/// assert_eq!(response.headers.get("Location").unwrap(), "https://api.example.com/users?page=1");
140	/// # });
141	/// ```
142	pub fn default_config() -> Self {
143		Self {
144			config: HttpsRedirectConfig::default(),
145		}
146	}
147
148	/// Check if a path is exempt from HTTPS redirect
149	fn is_exempt(&self, path: &str) -> bool {
150		self.config
151			.exempt_paths
152			.iter()
153			.any(|exempt| path.starts_with(exempt))
154	}
155
156	/// Validate host header against allowed hosts list.
157	/// Returns the validated host string if valid, or None if the host is not allowed.
158	fn validate_host<'a>(&self, host: Option<&'a str>) -> Option<&'a str> {
159		let host = host?;
160
161		// Reject hosts containing path separators or whitespace (injection attempts)
162		if host.contains('/') || host.contains('\\') || host.contains(char::is_whitespace) {
163			return None;
164		}
165
166		// If no allowed hosts configured, reject all (secure by default)
167		if self.config.allowed_hosts.is_empty() {
168			return None;
169		}
170
171		// Strip port for comparison (e.g., "example.com:8080" -> "example.com")
172		let host_without_port = host.split(':').next().unwrap_or(host);
173
174		// Check against allowed hosts list
175		let is_allowed = self.config.allowed_hosts.iter().any(|allowed| {
176			let allowed_lower = allowed.to_lowercase();
177			let host_lower = host_without_port.to_lowercase();
178			allowed_lower == host_lower
179		});
180
181		if is_allowed { Some(host) } else { None }
182	}
183}
184
185#[async_trait]
186impl Middleware for HttpsRedirectMiddleware {
187	async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
188		// If HTTPS redirect is disabled, just pass through
189		if !self.config.enabled {
190			return handler.handle(request).await;
191		}
192
193		// If request is already secure, pass through
194		if request.is_secure() {
195			return handler.handle(request).await;
196		}
197
198		// If path is exempt, pass through
199		if self.is_exempt(request.path()) {
200			return handler.handle(request).await;
201		}
202
203		// Validate host header against allowed hosts to prevent host header injection
204		let host_value = request
205			.headers
206			.get(hyper::header::HOST)
207			.and_then(|h| h.to_str().ok());
208
209		let validated_host = match self.validate_host(host_value) {
210			Some(host) => host,
211			None => {
212				// Reject requests with invalid or disallowed host headers
213				return Ok(Response::new(StatusCode::BAD_REQUEST));
214			}
215		};
216
217		// Build HTTPS redirect URL with validated host
218		let https_url = format!(
219			"https://{}{}",
220			validated_host,
221			request
222				.uri
223				.path_and_query()
224				.map(|pq| pq.as_str())
225				.unwrap_or("/")
226		);
227
228		// Return redirect response
229		let mut response = Response::new(self.config.status_code);
230		response.headers.insert(
231			hyper::header::LOCATION,
232			https_url
233				.parse()
234				.unwrap_or_else(|_| hyper::header::HeaderValue::from_static("/")),
235		);
236		Ok(response)
237	}
238}
239
#[cfg(test)]
mod tests {
	use super::*;
	use bytes::Bytes;
	use hyper::{HeaderMap, Method, StatusCode, Version};
	use reinhardt_http::Request;
	use rstest::rstest;

	/// Inner handler that always answers `200 OK` with the body `test`.
	struct TestHandler;

	#[async_trait]
	impl Handler for TestHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Ok(Response::ok().with_body(Bytes::from("test")))
		}
	}

	/// Enabled config with a 301 status, no exemptions, and the given host allow-list.
	fn config_with_allowed_hosts(hosts: Vec<&str>) -> HttpsRedirectConfig {
		let allowed_hosts = hosts.into_iter().map(String::from).collect();
		HttpsRedirectConfig {
			enabled: true,
			exempt_paths: vec![],
			status_code: StatusCode::MOVED_PERMANENTLY,
			allowed_hosts,
		}
	}

	/// Plain (non-secure) HTTP/1.1 `GET` request for `uri` carrying `Host: host`.
	fn get_with_host(host: &'static str, uri: &'static str) -> Request {
		let mut headers = HeaderMap::new();
		headers.insert(hyper::header::HOST, host.parse().unwrap());
		Request::builder()
			.method(Method::GET)
			.uri(uri)
			.version(Version::HTTP_11)
			.headers(headers)
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	/// Sends `request` through `middleware` with [`TestHandler`] behind it.
	async fn run(middleware: &HttpsRedirectMiddleware, request: Request) -> Response {
		middleware
			.process(request, Arc::new(TestHandler))
			.await
			.unwrap()
	}

	#[rstest]
	#[tokio::test]
	async fn test_redirect_http_to_https_with_allowed_host() {
		// Arrange
		let config = config_with_allowed_hosts(vec!["example.com"]);
		let middleware = HttpsRedirectMiddleware::new(config);
		let request = get_with_host("example.com", "/test");

		// Act
		let response = run(&middleware, request).await;

		// Assert
		assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
		assert_eq!(
			response.headers.get("Location").unwrap(),
			"https://example.com/test"
		);
	}

	#[rstest]
	#[tokio::test]
	async fn test_no_redirect_for_https() {
		// Arrange - the request is flagged secure, so it must pass through untouched
		let config = config_with_allowed_hosts(vec!["example.com"]);
		let middleware = HttpsRedirectMiddleware::new(config);

		let mut headers = HeaderMap::new();
		headers.insert(hyper::header::HOST, "example.com".parse().unwrap());
		let request = Request::builder()
			.method(Method::GET)
			.uri("/test")
			.version(Version::HTTP_11)
			.headers(headers)
			.body(Bytes::new())
			.secure(true)
			.build()
			.unwrap();

		// Act
		let response = run(&middleware, request).await;

		// Assert
		assert_eq!(response.status, StatusCode::OK);
	}

	#[rstest]
	#[tokio::test]
	async fn test_exempt_paths() {
		// Arrange
		let config = HttpsRedirectConfig {
			enabled: true,
			exempt_paths: vec!["/health".to_string()],
			status_code: StatusCode::MOVED_PERMANENTLY,
			allowed_hosts: vec!["example.com".to_string()],
		};
		let middleware = HttpsRedirectMiddleware::new(config);
		let request = get_with_host("example.com", "/health");

		// Act
		let response = run(&middleware, request).await;

		// Assert - exempt paths are served without a redirect
		assert_eq!(response.status, StatusCode::OK);
	}

	#[rstest]
	#[tokio::test]
	async fn test_reject_disallowed_host() {
		// Arrange
		let config = config_with_allowed_hosts(vec!["example.com"]);
		let middleware = HttpsRedirectMiddleware::new(config);
		let request = get_with_host("evil.com", "/test");

		// Act
		let response = run(&middleware, request).await;

		// Assert - rejected with 400 and no redirect target leaked
		assert_eq!(response.status, StatusCode::BAD_REQUEST);
		assert!(response.headers.get("Location").is_none());
	}

	#[rstest]
	#[tokio::test]
	async fn test_reject_host_with_path_separator() {
		// Arrange - Host header injection attempt using a path separator
		let config = config_with_allowed_hosts(vec!["example.com"]);
		let middleware = HttpsRedirectMiddleware::new(config);
		let request = get_with_host("evil.com/redirect", "/test");

		// Act
		let response = run(&middleware, request).await;

		// Assert
		assert_eq!(response.status, StatusCode::BAD_REQUEST);
	}

	#[rstest]
	#[tokio::test]
	async fn test_reject_missing_host_header() {
		// Arrange - request built without any Host header
		let config = config_with_allowed_hosts(vec!["example.com"]);
		let middleware = HttpsRedirectMiddleware::new(config);
		let request = Request::builder()
			.method(Method::GET)
			.uri("/test")
			.version(Version::HTTP_11)
			.body(Bytes::new())
			.build()
			.unwrap();

		// Act
		let response = run(&middleware, request).await;

		// Assert
		assert_eq!(response.status, StatusCode::BAD_REQUEST);
	}

	#[rstest]
	#[tokio::test]
	async fn test_reject_empty_allowed_hosts() {
		// Arrange - the default configuration has an empty allow-list
		let middleware = HttpsRedirectMiddleware::default_config();
		let request = get_with_host("example.com", "/test");

		// Act
		let response = run(&middleware, request).await;

		// Assert - secure by default: nothing is redirected
		assert_eq!(response.status, StatusCode::BAD_REQUEST);
	}

	#[rstest]
	#[tokio::test]
	async fn test_allowed_host_with_port() {
		// Arrange - Host header carries an explicit port
		let config = config_with_allowed_hosts(vec!["example.com"]);
		let middleware = HttpsRedirectMiddleware::new(config);
		let request = get_with_host("example.com:8080", "/test");

		// Act
		let response = run(&middleware, request).await;

		// Assert - hostname matches, port is preserved in the redirect
		assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
		assert_eq!(
			response.headers.get("Location").unwrap(),
			"https://example.com:8080/test"
		);
	}

	#[rstest]
	#[tokio::test]
	async fn test_case_insensitive_host_matching() {
		// Arrange
		let config = config_with_allowed_hosts(vec!["Example.COM"]);
		let middleware = HttpsRedirectMiddleware::new(config);
		let request = get_with_host("example.com", "/test");

		// Act
		let response = run(&middleware, request).await;

		// Assert - host comparison ignores case
		assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
	}
}
513		let response = middleware.process(request, handler).await.unwrap();
514
515		// Assert - host matching should be case-insensitive
516		assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
517	}
518}