reinhardt_middleware/
https_redirect.rs1use async_trait::async_trait;
7use hyper::StatusCode;
8use reinhardt_http::{Handler, Middleware, Request, Response, Result};
9use std::sync::Arc;
10
11#[non_exhaustive]
13#[derive(Debug, Clone)]
14pub struct HttpsRedirectConfig {
15 pub enabled: bool,
17 pub exempt_paths: Vec<String>,
19 pub status_code: StatusCode,
21 pub allowed_hosts: Vec<String>,
24}
25
26impl Default for HttpsRedirectConfig {
27 fn default() -> Self {
28 Self {
29 enabled: true,
30 exempt_paths: vec![],
31 status_code: StatusCode::MOVED_PERMANENTLY, allowed_hosts: vec![],
33 }
34 }
35}
36
37pub struct HttpsRedirectMiddleware {
39 config: HttpsRedirectConfig,
40}
41
42impl HttpsRedirectMiddleware {
43 pub fn new(config: HttpsRedirectConfig) -> Self {
95 Self { config }
96 }
97 pub fn default_config() -> Self {
143 Self {
144 config: HttpsRedirectConfig::default(),
145 }
146 }
147
148 fn is_exempt(&self, path: &str) -> bool {
150 self.config
151 .exempt_paths
152 .iter()
153 .any(|exempt| path.starts_with(exempt))
154 }
155
156 fn validate_host<'a>(&self, host: Option<&'a str>) -> Option<&'a str> {
159 let host = host?;
160
161 if host.contains('/') || host.contains('\\') || host.contains(char::is_whitespace) {
163 return None;
164 }
165
166 if self.config.allowed_hosts.is_empty() {
168 return None;
169 }
170
171 let host_without_port = host.split(':').next().unwrap_or(host);
173
174 let is_allowed = self.config.allowed_hosts.iter().any(|allowed| {
176 let allowed_lower = allowed.to_lowercase();
177 let host_lower = host_without_port.to_lowercase();
178 allowed_lower == host_lower
179 });
180
181 if is_allowed { Some(host) } else { None }
182 }
183}
184
185#[async_trait]
186impl Middleware for HttpsRedirectMiddleware {
187 async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
188 if !self.config.enabled {
190 return handler.handle(request).await;
191 }
192
193 if request.is_secure() {
195 return handler.handle(request).await;
196 }
197
198 if self.is_exempt(request.path()) {
200 return handler.handle(request).await;
201 }
202
203 let host_value = request
205 .headers
206 .get(hyper::header::HOST)
207 .and_then(|h| h.to_str().ok());
208
209 let validated_host = match self.validate_host(host_value) {
210 Some(host) => host,
211 None => {
212 return Ok(Response::new(StatusCode::BAD_REQUEST));
214 }
215 };
216
217 let https_url = format!(
219 "https://{}{}",
220 validated_host,
221 request
222 .uri
223 .path_and_query()
224 .map(|pq| pq.as_str())
225 .unwrap_or("/")
226 );
227
228 let mut response = Response::new(self.config.status_code);
230 response.headers.insert(
231 hyper::header::LOCATION,
232 https_url
233 .parse()
234 .unwrap_or_else(|_| hyper::header::HeaderValue::from_static("/")),
235 );
236 Ok(response)
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};
    use reinhardt_http::Request;
    use rstest::rstest;

    /// Inner handler that always answers `200 OK` with the body `"test"`; a
    /// `200` status therefore means the middleware forwarded the request.
    struct TestHandler;

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::ok().with_body(Bytes::from("test")))
        }
    }

    /// Enabled config with no exempt paths, a `301` status and the given allow-list.
    fn config_with_allowed_hosts(hosts: Vec<&str>) -> HttpsRedirectConfig {
        HttpsRedirectConfig {
            enabled: true,
            exempt_paths: vec![],
            status_code: StatusCode::MOVED_PERMANENTLY,
            allowed_hosts: hosts.into_iter().map(String::from).collect(),
        }
    }

    /// Middleware built from [`config_with_allowed_hosts`].
    fn middleware_allowing(hosts: Vec<&str>) -> HttpsRedirectMiddleware {
        HttpsRedirectMiddleware::new(config_with_allowed_hosts(hosts))
    }

    /// Builds an HTTP/1.1 `GET` request for `uri`, with an optional `Host`
    /// header and the given `secure` flag.
    fn build_request(uri: &str, host: Option<&str>, secure: bool) -> Request {
        let mut headers = HeaderMap::new();
        if let Some(host) = host {
            headers.insert(hyper::header::HOST, host.parse().unwrap());
        }
        Request::builder()
            .method(Method::GET)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .secure(secure)
            .build()
            .unwrap()
    }

    /// Runs `request` through `middleware` with [`TestHandler`] as the inner handler.
    async fn respond(middleware: &HttpsRedirectMiddleware, request: Request) -> Response {
        middleware
            .process(request, Arc::new(TestHandler))
            .await
            .unwrap()
    }

    #[rstest]
    #[tokio::test]
    async fn test_redirect_http_to_https_with_allowed_host() {
        let middleware = middleware_allowing(vec!["example.com"]);
        let request = build_request("/test", Some("example.com"), false);

        let response = respond(&middleware, request).await;

        assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
        assert_eq!(
            response.headers.get("Location").unwrap(),
            "https://example.com/test"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_no_redirect_for_https() {
        let middleware = middleware_allowing(vec!["example.com"]);
        let request = build_request("/test", Some("example.com"), true);

        let response = respond(&middleware, request).await;

        assert_eq!(response.status, StatusCode::OK);
    }

    #[rstest]
    #[tokio::test]
    async fn test_exempt_paths() {
        let mut config = config_with_allowed_hosts(vec!["example.com"]);
        config.exempt_paths = vec!["/health".to_string()];
        let middleware = HttpsRedirectMiddleware::new(config);
        let request = build_request("/health", Some("example.com"), false);

        let response = respond(&middleware, request).await;

        assert_eq!(response.status, StatusCode::OK);
    }

    #[rstest]
    #[tokio::test]
    async fn test_exempt_path_covers_subpaths() {
        let mut config = config_with_allowed_hosts(vec!["example.com"]);
        config.exempt_paths = vec!["/health".to_string()];
        let middleware = HttpsRedirectMiddleware::new(config);
        let request = build_request("/health/live", Some("example.com"), false);

        let response = respond(&middleware, request).await;

        assert_eq!(response.status, StatusCode::OK);
    }

    #[rstest]
    #[tokio::test]
    async fn test_disabled_passes_through() {
        let mut config = config_with_allowed_hosts(vec![]);
        config.enabled = false;
        let middleware = HttpsRedirectMiddleware::new(config);
        let request = build_request("/test", Some("example.com"), false);

        let response = respond(&middleware, request).await;

        assert_eq!(response.status, StatusCode::OK);
        assert!(response.headers.get("Location").is_none());
    }

    #[rstest]
    #[tokio::test]
    async fn test_reject_disallowed_host() {
        let middleware = middleware_allowing(vec!["example.com"]);
        let request = build_request("/test", Some("evil.com"), false);

        let response = respond(&middleware, request).await;

        assert_eq!(response.status, StatusCode::BAD_REQUEST);
        assert!(response.headers.get("Location").is_none());
    }

    #[rstest]
    #[tokio::test]
    async fn test_reject_host_with_path_separator() {
        let middleware = middleware_allowing(vec!["example.com"]);
        let request = build_request("/test", Some("evil.com/redirect"), false);

        let response = respond(&middleware, request).await;

        assert_eq!(response.status, StatusCode::BAD_REQUEST);
    }

    #[rstest]
    #[case::backslash("example.com\\evil.com")]
    #[case::whitespace("example.com evil.com")]
    #[tokio::test]
    async fn test_reject_host_with_unsafe_characters(#[case] host: &str) {
        let middleware = middleware_allowing(vec!["example.com"]);
        let request = build_request("/test", Some(host), false);

        let response = respond(&middleware, request).await;

        assert_eq!(response.status, StatusCode::BAD_REQUEST);
        assert!(response.headers.get("Location").is_none());
    }

    #[rstest]
    #[tokio::test]
    async fn test_reject_missing_host_header() {
        let middleware = middleware_allowing(vec!["example.com"]);
        let request = build_request("/test", None, false);

        let response = respond(&middleware, request).await;

        assert_eq!(response.status, StatusCode::BAD_REQUEST);
    }

    #[rstest]
    #[tokio::test]
    async fn test_reject_empty_allowed_hosts() {
        let middleware = HttpsRedirectMiddleware::default_config();
        let request = build_request("/test", Some("example.com"), false);

        let response = respond(&middleware, request).await;

        assert_eq!(response.status, StatusCode::BAD_REQUEST);
    }

    #[rstest]
    #[tokio::test]
    async fn test_allowed_host_with_port() {
        let middleware = middleware_allowing(vec!["example.com"]);
        let request = build_request("/test", Some("example.com:8080"), false);

        let response = respond(&middleware, request).await;

        assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
        assert_eq!(
            response.headers.get("Location").unwrap(),
            "https://example.com:8080/test"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_case_insensitive_host_matching() {
        let middleware = middleware_allowing(vec!["Example.COM"]);
        let request = build_request("/test", Some("example.com"), false);

        let response = respond(&middleware, request).await;

        assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
        assert_eq!(
            response.headers.get("Location").unwrap(),
            "https://example.com/test"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_redirect_preserves_query_string() {
        let middleware = middleware_allowing(vec!["example.com"]);
        let request = build_request("/search?q=rust&page=2", Some("example.com"), false);

        let response = respond(&middleware, request).await;

        assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
        assert_eq!(
            response.headers.get("Location").unwrap(),
            "https://example.com/search?q=rust&page=2"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_custom_status_code() {
        let mut config = config_with_allowed_hosts(vec!["example.com"]);
        config.status_code = StatusCode::PERMANENT_REDIRECT;
        let middleware = HttpsRedirectMiddleware::new(config);
        let request = build_request("/test", Some("example.com"), false);

        let response = respond(&middleware, request).await;

        assert_eq!(response.status, StatusCode::PERMANENT_REDIRECT);
        assert_eq!(
            response.headers.get("Location").unwrap(),
            "https://example.com/test"
        );
    }
}