use http::{
header::{HeaderValue, HOST, LOCATION},
Request, Response, StatusCode,
};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tower::{Layer, Service};
/// Configuration for redirecting plain-HTTP requests to HTTPS.
///
/// Build it with [`RedirectHttps::new`] and the chained setters, then pass it to
/// [`RedirectHttpsLayer::new`].
#[derive(Clone, Debug)]
pub struct RedirectHttps {
    /// Status code of the redirect response (`308 Permanent Redirect` by default).
    status: StatusCode,
    /// Port written into the redirect target's authority; `None` omits the port.
    https_port: Option<u16>,
}
impl Default for RedirectHttps {
    /// Redirects with `308 Permanent Redirect` (which keeps the request method) and
    /// leaves the port out of the redirect target.
    fn default() -> Self {
        let status = StatusCode::PERMANENT_REDIRECT;
        let https_port = None;
        Self { status, https_port }
    }
}
impl RedirectHttps {
    /// Creates a configuration with the defaults: a `308 Permanent Redirect` and no
    /// explicit HTTPS port.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the status code used for the redirect response.
    ///
    /// The value is used as-is; callers are expected to pass a 3xx code.
    #[must_use]
    pub fn status(mut self, status: StatusCode) -> Self {
        self.status = status;
        self
    }

    /// Sets an explicit port (e.g. `8443`) to put in the redirect target.
    ///
    /// Without it, any port on the incoming host is dropped and the target has no port.
    #[must_use]
    pub fn https_port(mut self, port: u16) -> Self {
        self.https_port = Some(port);
        self
    }

    /// Returns `true` when the request arrived over plain HTTP.
    ///
    /// `X-Forwarded-Proto` takes precedence when present. Proxy chains may send a
    /// comma-separated list (e.g. `http, https`), in which the first entry is the
    /// protocol the client used. Only that entry is inspected, trimmed and compared
    /// case-insensitively. Without the header, the URI scheme decides. A URI with no
    /// scheme counts as not HTTP.
    fn is_http<B>(req: &Request<B>) -> bool {
        match req.headers().get("x-forwarded-proto") {
            Some(proto) => proto
                .as_bytes()
                .split(|&b| b == b',')
                .next()
                .is_some_and(|first| first.trim_ascii().eq_ignore_ascii_case(b"http")),
            None => req.uri().scheme() == Some(&http::uri::Scheme::HTTP),
        }
    }

    /// Builds the `Location` value `https://{host}[:{https_port}]{path_and_query}`.
    ///
    /// The host is taken from the `Host` header. When that header is absent, the
    /// request URI's host is used instead; HTTP/2 requests, for example, carry the
    /// authority in the URI rather than in a `Host` header. Any port on the incoming
    /// host is dropped.
    ///
    /// Returns `None` in these cases:
    /// - no host is available;
    /// - the `Host` header is not visible ASCII;
    /// - the assembled URL is not a valid header value.
    fn location<B>(&self, req: &Request<B>) -> Option<HeaderValue> {
        let hostname = match req.headers().get(HOST) {
            Some(host) => Self::strip_port(host.to_str().ok()?),
            None => req.uri().host()?,
        };
        let authority = match self.https_port {
            Some(port) => format!("{hostname}:{port}"),
            None => hostname.to_owned(),
        };
        let path_and_query = req
            .uri()
            .path_and_query()
            .map_or("/", |pq| pq.as_str());
        HeaderValue::from_str(&format!("https://{authority}{path_and_query}")).ok()
    }

    /// Removes a trailing numeric `:port` from a host value.
    ///
    /// Bracketed IPv6 literals such as `[::1]` are left intact, because the text after
    /// their last colon (`1]`) does not parse as a port.
    fn strip_port(host: &str) -> &str {
        host.rsplit_once(':')
            .filter(|(_, port)| port.parse::<u16>().is_ok())
            .map_or(host, |(name, _)| name)
    }
}
/// [`Layer`] that wraps services in [`RedirectHttpsService`] using a
/// [`RedirectHttps`] configuration.
///
/// `Default` uses `RedirectHttps::default()`.
#[derive(Clone, Debug, Default)]
pub struct RedirectHttpsLayer {
    /// Configuration cloned into every service this layer produces.
    config: RedirectHttps,
}
impl RedirectHttpsLayer {
pub fn new(config: RedirectHttps) -> Self {
Self { config }
}
}
impl<S> Layer<S> for RedirectHttpsLayer {
    type Service = RedirectHttpsService<S>;

    /// Wraps `inner`, giving the new service its own copy of the configuration.
    fn layer(&self, inner: S) -> Self::Service {
        let config = self.config.clone();
        RedirectHttpsService { inner, config }
    }
}
/// Middleware service that answers plain-HTTP requests with a redirect to HTTPS
/// when a target URL can be built, and forwards everything else to `inner`.
#[derive(Clone, Debug)]
pub struct RedirectHttpsService<S> {
    /// Wrapped service that handles requests which are not redirected.
    inner: S,
    /// Redirect settings (status code and optional HTTPS port).
    config: RedirectHttps,
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for RedirectHttpsService<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
    ResBody: Default + Send + 'static,
{
    type Response = Response<ResBody>;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Reports the inner service's readiness, even though redirected requests never
    /// reach it.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Answers plain-HTTP requests with an empty-bodied redirect carrying a `Location`
    /// header whenever a redirect target can be built. Every other request is
    /// delegated to `inner`.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        let target = if RedirectHttps::is_http(&req) {
            self.config.location(&req)
        } else {
            None
        };

        let Some(location) = target else {
            return Box::pin(self.inner.call(req));
        };

        // Copy the status out so the future does not borrow `self`.
        let status = self.config.status;
        Box::pin(async move {
            let mut response = Response::new(ResBody::default());
            *response.status_mut() = status;
            response.headers_mut().insert(LOCATION, location);
            Ok(response)
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, routing::get, Router};
    use http::StatusCode;
    use tower::ServiceExt;

    /// Router serving `GET /` with a plain "ok", wrapped in the redirect layer.
    fn build_app(config: RedirectHttps) -> Router {
        let router = Router::new().route("/", get(|| async { "ok" }));
        router.layer(RedirectHttpsLayer::new(config))
    }

    /// Sends one request through `app` and returns the response.
    async fn send(app: Router, req: http::Request<Body>) -> http::Response<Body> {
        let outcome = app.oneshot(req).await;
        outcome.unwrap()
    }

    /// Empty-bodied request to `uri` with optional `Host` and `X-Forwarded-Proto` headers.
    fn request(uri: &str, host: Option<&str>, proto: Option<&str>) -> http::Request<Body> {
        let mut builder = http::Request::builder().uri(uri);
        if let Some(host) = host {
            builder = builder.header(HOST, host);
        }
        if let Some(proto) = proto {
            builder = builder.header("x-forwarded-proto", proto);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Request for `example.com` that a proxy tagged with protocol `proto`.
    fn forwarded_request(proto: &str, uri: &str) -> http::Request<Body> {
        request(uri, Some("example.com"), Some(proto))
    }

    #[tokio::test]
    async fn redirects_on_x_forwarded_proto_http() {
        let app = build_app(RedirectHttps::new());
        let response = send(app, forwarded_request("http", "/path?q=1")).await;
        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(response.headers()[LOCATION], "https://example.com/path?q=1");
    }

    #[tokio::test]
    async fn passes_through_on_x_forwarded_proto_https() {
        let app = build_app(RedirectHttps::new());
        let response = send(app, forwarded_request("https", "/")).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn redirects_on_http_uri_scheme() {
        let req = request("http://example.com/page", Some("example.com"), None);
        let response = send(build_app(RedirectHttps::new()), req).await;
        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(response.headers()[LOCATION], "https://example.com/page");
    }

    #[tokio::test]
    async fn passes_through_when_no_scheme_indicator() {
        let req = request("/", Some("example.com"), None);
        let response = send(build_app(RedirectHttps::new()), req).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn passes_through_when_no_host_header() {
        let req = request("/", None, Some("http"));
        let response = send(build_app(RedirectHttps::new()), req).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn custom_status_301() {
        let config = RedirectHttps::new().status(StatusCode::MOVED_PERMANENTLY);
        let response = send(build_app(config), forwarded_request("http", "/")).await;
        assert_eq!(response.status(), StatusCode::MOVED_PERMANENTLY);
    }

    #[tokio::test]
    async fn strips_http_port_from_host() {
        let req = request("/path", Some("example.com:80"), Some("http"));
        let response = send(build_app(RedirectHttps::new()), req).await;
        assert_eq!(response.headers()[LOCATION], "https://example.com/path");
    }

    #[tokio::test]
    async fn custom_https_port() {
        let config = RedirectHttps::new().https_port(8443);
        let req = request("/path", Some("example.com:8080"), Some("http"));
        let response = send(build_app(config), req).await;
        assert_eq!(response.headers()[LOCATION], "https://example.com:8443/path");
    }

    #[tokio::test]
    async fn default_layer_uses_308() {
        let app = Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(RedirectHttpsLayer::default());
        let response = send(app, forwarded_request("http", "/")).await;
        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
    }
}