use async_trait::async_trait;
use hyper::StatusCode;
use reinhardt_http::{Handler, Middleware, Request, Response, Result};
use std::sync::Arc;
/// Settings for [`HttpsRedirectMiddleware`].
///
/// The type is `#[non_exhaustive]`, so code outside this crate cannot build it
/// with a struct literal. It has to start from `HttpsRedirectConfig::default()`
/// and assign the public fields.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct HttpsRedirectConfig {
    /// When `false`, every request is handed straight to the inner handler.
    pub enabled: bool,
    /// Path prefixes that are never redirected (matched with `str::starts_with`).
    pub exempt_paths: Vec<String>,
    /// Status code used for the redirect response.
    pub status_code: StatusCode,
    /// Hostnames, without port, accepted from the `Host` header, compared
    /// case-insensitively. When empty, no redirect is ever issued and
    /// non-exempt plain-HTTP requests receive `400 Bad Request`.
    pub allowed_hosts: Vec<String>,
}
impl Default for HttpsRedirectConfig {
    /// Returns an enabled configuration with:
    /// - no exempt paths,
    /// - a `301 Moved Permanently` redirect status,
    /// - an empty host allow-list.
    ///
    /// With an empty allow-list every plain-HTTP request is answered with
    /// `400 Bad Request` until trusted hosts are added.
    fn default() -> Self {
        let status_code = StatusCode::MOVED_PERMANENTLY;
        Self {
            allowed_hosts: Vec::new(),
            enabled: true,
            exempt_paths: Vec::new(),
            status_code,
        }
    }
}
/// Middleware that answers plain-HTTP requests with a redirect to the same
/// host and path over HTTPS.
///
/// Build it with [`HttpsRedirectMiddleware::new`] or
/// [`HttpsRedirectMiddleware::default_config`].
pub struct HttpsRedirectMiddleware {
    /// Controls when requests are redirected and which hosts are trusted.
    config: HttpsRedirectConfig,
}
impl HttpsRedirectMiddleware {
    /// Creates a middleware that applies the given configuration.
    pub fn new(config: HttpsRedirectConfig) -> Self {
        Self { config }
    }

    /// Creates a middleware using `HttpsRedirectConfig::default()`.
    ///
    /// The default configuration has an empty `allowed_hosts` list. Until
    /// hosts are configured, every non-exempt plain-HTTP request is rejected
    /// with `400 Bad Request`.
    pub fn default_config() -> Self {
        Self {
            config: HttpsRedirectConfig::default(),
        }
    }

    /// Returns `true` when `path` starts with any configured exempt prefix.
    ///
    /// This is a plain string-prefix test: an entry of `/health` also
    /// exempts `/healthz` and `/health/deep`.
    fn is_exempt(&self, path: &str) -> bool {
        self.config
            .exempt_paths
            .iter()
            .any(|exempt| path.starts_with(exempt))
    }

    /// Checks a raw `Host` header value against `allowed_hosts`.
    ///
    /// Returns the full original value, including any port, when it is safe
    /// to embed in a redirect `Location`. Returns `None` when the value:
    /// - is missing,
    /// - contains `/`, `\` or whitespace,
    /// - has a port that is not made only of ASCII digits,
    /// - or has a hostname that matches no allowed host (ASCII
    ///   case-insensitive).
    ///
    /// An empty `allowed_hosts` list rejects every value.
    fn validate_host<'a>(&self, host: Option<&'a str>) -> Option<&'a str> {
        let host = host?;
        if host.contains(|c: char| c == '/' || c == '\\' || c.is_whitespace()) {
            return None;
        }
        if self.config.allowed_hosts.is_empty() {
            return None;
        }
        // The whole value, port included, is copied into `Location`, so the
        // port must be checked too. Otherwise `example.com:@evil.com` would
        // pass the hostname check and send the client to `evil.com`.
        let hostname = Self::split_host_port(host)?;
        self.config
            .allowed_hosts
            .iter()
            .any(|allowed| allowed.eq_ignore_ascii_case(hostname))
            .then_some(host)
    }

    /// Splits a `host[:port]` or `[ipv6][:port]` authority and returns the
    /// hostname part. Brackets are kept for IPv6 literals, e.g. `[::1]`.
    ///
    /// Returns `None` in two cases:
    /// - a present port contains anything other than ASCII digits;
    /// - an IPv6 literal is unterminated or followed by something other
    ///   than `:port`.
    fn split_host_port(host: &str) -> Option<&str> {
        let (hostname, port) = if host.starts_with('[') {
            let close = host.find(']')?;
            // `]` is ASCII, so `close + 1` is always a char boundary.
            let (hostname, rest) = host.split_at(close + 1);
            let port = if rest.is_empty() {
                None
            } else {
                Some(rest.strip_prefix(':')?)
            };
            (hostname, port)
        } else {
            match host.split_once(':') {
                Some((hostname, port)) => (hostname, Some(port)),
                None => (host, None),
            }
        };
        // RFC 7230 §5.4 / RFC 3986 §3.2.3: port = *DIGIT.
        match port {
            Some(port) if !port.bytes().all(|b| b.is_ascii_digit()) => None,
            _ => Some(hostname),
        }
    }
}
#[async_trait]
impl Middleware for HttpsRedirectMiddleware {
    /// Redirects plain-HTTP requests to the same host and path over HTTPS.
    ///
    /// The request goes to `handler` untouched when the middleware is
    /// disabled, the request is already secure, or its path is exempt.
    ///
    /// Otherwise the `Host` header is checked with `validate_host`:
    /// - a missing or rejected host yields `400 Bad Request`;
    /// - an accepted host yields the configured status code with
    ///   `Location: https://{host}{path_and_query}`, where the path
    ///   defaults to `/`.
    async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
        let passthrough =
            !self.config.enabled || request.is_secure() || self.is_exempt(request.path());
        if passthrough {
            return handler.handle(request).await;
        }

        let raw_host = request
            .headers
            .get(hyper::header::HOST)
            .and_then(|value| value.to_str().ok());
        let Some(host) = self.validate_host(raw_host) else {
            return Ok(Response::new(StatusCode::BAD_REQUEST));
        };

        let target = request.uri.path_and_query().map_or("/", |pq| pq.as_str());
        let location_url = format!("https://{host}{target}");
        // If the assembled URL is not a valid header value, use a relative "/" instead.
        let location: hyper::header::HeaderValue = location_url
            .parse()
            .unwrap_or_else(|_| hyper::header::HeaderValue::from_static("/"));

        let mut response = Response::new(self.config.status_code);
        response.headers.insert(hyper::header::LOCATION, location);
        Ok(response)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};
    use reinhardt_http::Request;
    use rstest::rstest;

    /// Inner handler that always answers `200 OK` with body `"test"`. This lets
    /// tests tell a pass-through apart from a redirect or a rejection.
    struct TestHandler;

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::ok().with_body(Bytes::from("test")))
        }
    }

    /// Enabled config with no exempt paths, a 301 status and the given allow-list.
    fn config_with_allowed_hosts(hosts: Vec<&str>) -> HttpsRedirectConfig {
        let allowed_hosts = hosts.into_iter().map(String::from).collect();
        HttpsRedirectConfig {
            enabled: true,
            exempt_paths: vec![],
            status_code: StatusCode::MOVED_PERMANENTLY,
            allowed_hosts,
        }
    }

    /// Middleware built from `config_with_allowed_hosts(hosts)`.
    fn middleware_for(hosts: Vec<&str>) -> HttpsRedirectMiddleware {
        HttpsRedirectMiddleware::new(config_with_allowed_hosts(hosts))
    }

    /// Header map containing only a `Host` header with the given value.
    fn host_headers(host: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(hyper::header::HOST, host.parse().unwrap());
        headers
    }

    /// Non-secure HTTP/1.1 `GET` request for `uri` carrying `headers`.
    fn plain_get(uri: &'static str, headers: HeaderMap) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Runs `request` through `middleware`, with `TestHandler` as the inner handler.
    async fn run(middleware: &HttpsRedirectMiddleware, request: Request) -> Response {
        let handler: Arc<dyn Handler> = Arc::new(TestHandler);
        middleware.process(request, handler).await.unwrap()
    }

    #[rstest]
    #[tokio::test]
    async fn test_redirect_http_to_https_with_allowed_host() {
        let middleware = middleware_for(vec!["example.com"]);
        let request = plain_get("/test", host_headers("example.com"));
        let response = run(&middleware, request).await;
        assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
        assert_eq!(
            response.headers.get("Location").unwrap(),
            "https://example.com/test"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_no_redirect_for_https() {
        let middleware = middleware_for(vec!["example.com"]);
        let request = Request::builder()
            .method(Method::GET)
            .uri("/test")
            .version(Version::HTTP_11)
            .headers(host_headers("example.com"))
            .body(Bytes::new())
            .secure(true)
            .build()
            .unwrap();
        let response = run(&middleware, request).await;
        assert_eq!(response.status, StatusCode::OK);
    }

    #[rstest]
    #[tokio::test]
    async fn test_exempt_paths() {
        let middleware = HttpsRedirectMiddleware::new(HttpsRedirectConfig {
            enabled: true,
            exempt_paths: vec!["/health".to_string()],
            status_code: StatusCode::MOVED_PERMANENTLY,
            allowed_hosts: vec!["example.com".to_string()],
        });
        let request = plain_get("/health", host_headers("example.com"));
        let response = run(&middleware, request).await;
        assert_eq!(response.status, StatusCode::OK);
    }

    #[rstest]
    #[tokio::test]
    async fn test_reject_disallowed_host() {
        let middleware = middleware_for(vec!["example.com"]);
        let request = plain_get("/test", host_headers("evil.com"));
        let response = run(&middleware, request).await;
        assert_eq!(response.status, StatusCode::BAD_REQUEST);
        assert!(response.headers.get("Location").is_none());
    }

    #[rstest]
    #[tokio::test]
    async fn test_reject_host_with_path_separator() {
        let middleware = middleware_for(vec!["example.com"]);
        let request = plain_get("/test", host_headers("evil.com/redirect"));
        let response = run(&middleware, request).await;
        assert_eq!(response.status, StatusCode::BAD_REQUEST);
    }

    #[rstest]
    #[tokio::test]
    async fn test_reject_missing_host_header() {
        let middleware = middleware_for(vec!["example.com"]);
        // Built without `.headers(..)` so the request carries no `Host` at all.
        let request = Request::builder()
            .method(Method::GET)
            .uri("/test")
            .version(Version::HTTP_11)
            .body(Bytes::new())
            .build()
            .unwrap();
        let response = run(&middleware, request).await;
        assert_eq!(response.status, StatusCode::BAD_REQUEST);
    }

    #[rstest]
    #[tokio::test]
    async fn test_reject_empty_allowed_hosts() {
        let middleware = HttpsRedirectMiddleware::default_config();
        let request = plain_get("/test", host_headers("example.com"));
        let response = run(&middleware, request).await;
        assert_eq!(response.status, StatusCode::BAD_REQUEST);
    }

    #[rstest]
    #[tokio::test]
    async fn test_allowed_host_with_port() {
        let middleware = middleware_for(vec!["example.com"]);
        let request = plain_get("/test", host_headers("example.com:8080"));
        let response = run(&middleware, request).await;
        assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
        assert_eq!(
            response.headers.get("Location").unwrap(),
            "https://example.com:8080/test"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_case_insensitive_host_matching() {
        let middleware = middleware_for(vec!["Example.COM"]);
        let request = plain_get("/test", host_headers("example.com"));
        let response = run(&middleware, request).await;
        assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
    }
}