use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use axum::{
body::Body,
http::{
Request, StatusCode,
header::{HOST, HeaderName, HeaderValue, LOCATION, STRICT_TRANSPORT_SECURITY},
},
response::Response,
};
use tower::{Layer, Service};
/// Heap-allocated, type-erased future that resolves to an axum [`Response`] or to
/// `ServiceError`.
///
/// The `+ Send` bound allows the boxed future to be moved across threads.
type BoxResponseFuture<ServiceError> =
    Pin<Box<dyn Future<Output = Result<Response, ServiceError>> + Send>>;
/// Returns `true` when the `X-Forwarded-Proto` request header says the client used HTTPS.
///
/// The whole header value is compared against `https`, ignoring ASCII case. A missing header,
/// or a value that is not visible ASCII, counts as insecure.
///
/// NOTE(review): the header is trusted exactly as received, so this is only reliable when a
/// proxy in front of the service overwrites it. A comma-separated list appended by chained
/// proxies (e.g. `https, https`) never matches and is treated as insecure — confirm this fits
/// the deployment.
fn request_is_secure(request: &Request<Body>) -> bool {
    let forwarded_proto = HeaderName::from_static("x-forwarded-proto");
    match request.headers().get(forwarded_proto) {
        Some(value) => matches!(value.to_str(), Ok(proto) if proto.eq_ignore_ascii_case("https")),
        None => false,
    }
}
/// Builds the absolute `https://` URL that an insecure request should be redirected to.
///
/// The host is taken from the `Host` header when one is present. Otherwise it falls back to the
/// authority in the request URI (HTTP/2 clients send the host in the `:authority`
/// pseudo-header rather than in `Host`). The request's path and query are preserved; when the
/// request target has no path, or is not an origin-form path starting with `/` (e.g. `*`), the
/// redirect goes to `/`.
///
/// Returns `None` — "no safe redirect target" — when:
/// * there is neither a `Host` header nor a URI authority;
/// * the `Host` header value is not visible ASCII;
/// * the host is empty, or contains any byte other than ASCII letters, digits, `.`, `-`, `_`,
///   `:` and square brackets (IPv6 literals). This rejects `/`, `?`, `#`, `@`, `\`, and
///   whitespace, which would otherwise let the client-supplied value escape the authority
///   component and send the redirect to a different host, or yield a malformed URL.
fn redirect_target(request: &Request<Body>) -> Option<String> {
    let host = match request.headers().get(HOST) {
        Some(value) => value.to_str().ok()?,
        None => request.uri().authority()?.as_str(),
    };

    // Conservative allow-list for a `host[:port]` authority; the host is untrusted input that
    // ends up in the `Location` header.
    let is_authority_byte = |byte: u8| {
        byte.is_ascii_alphanumeric() || matches!(byte, b'.' | b'-' | b'_' | b':' | b'[' | b']')
    };
    if host.is_empty() || !host.bytes().all(is_authority_byte) {
        return None;
    }

    let path_and_query = request
        .uri()
        .path_and_query()
        .map(|value| value.as_str())
        .filter(|value| value.starts_with('/'))
        .unwrap_or("/");
    Some(format!("https://{host}{path_and_query}"))
}
/// Tower [`Layer`] that wraps a service in [`SecurityMiddleware`].
///
/// Every service produced by this layer receives its own copy of these settings.
#[derive(Clone, Debug)]
pub struct SecurityMiddlewareLayer {
    /// `max-age`, in seconds, for the `Strict-Transport-Security` response header.
    /// `0` disables the header.
    pub hsts_seconds: u64,
    /// When `true`, responses get `X-Content-Type-Options: nosniff`.
    pub content_type_nosniff: bool,
    /// When `true`, requests not marked as HTTPS via `X-Forwarded-Proto` are redirected to
    /// the `https://` equivalent of their URL instead of reaching the wrapped service.
    pub ssl_redirect: bool,
}
impl<S> Layer<S> for SecurityMiddlewareLayer {
    type Service = SecurityMiddleware<S>;

    /// Wraps `inner`, giving the new service its own copy of this layer's settings.
    fn layer(&self, inner: S) -> Self::Service {
        // Exhaustive destructuring: adding a setting to the layer without forwarding it here
        // becomes a compile error instead of a silently dropped option.
        let Self {
            hsts_seconds,
            content_type_nosniff,
            ssl_redirect,
        } = *self;

        SecurityMiddleware {
            hsts_seconds,
            content_type_nosniff,
            ssl_redirect,
            inner,
        }
    }
}
/// Service produced by [`SecurityMiddlewareLayer`], wrapping the service `S`.
///
/// It can redirect insecure requests to HTTPS and adds security headers to the wrapped
/// service's responses, according to the settings copied from the layer.
#[derive(Clone, Debug)]
pub struct SecurityMiddleware<S> {
    /// Wrapped service that handles every request that is not redirected.
    inner: S,
    /// `max-age` for `Strict-Transport-Security`; `0` means the header is not added.
    hsts_seconds: u64,
    /// Whether responses get `X-Content-Type-Options: nosniff`.
    content_type_nosniff: bool,
    /// Whether requests not identified as HTTPS are redirected.
    ssl_redirect: bool,
}
impl<S> Service<Request<Body>> for SecurityMiddleware<S>
where
    S: Service<Request<Body>, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = Response;
    type Error = S::Error;
    type Future = BoxResponseFuture<Self::Error>;

    /// Delegates readiness to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Redirects an insecure request, or forwards it and decorates the response.
    ///
    /// * If `ssl_redirect` is set and `request_is_secure` returns `false`, the wrapped service
    ///   is not called. The client receives `308 Permanent Redirect` to the URL built by
    ///   `redirect_target`, or `400 Bad Request` when no redirect URL can be built (no usable
    ///   host).
    /// * Otherwise the request is passed to the wrapped service. On success the response gets
    ///   `Strict-Transport-Security: max-age=<hsts_seconds>` (when `hsts_seconds > 0`) and
    ///   `X-Content-Type-Options: nosniff` (when `content_type_nosniff`), replacing any value the
    ///   wrapped service set. Errors from the wrapped service are returned unchanged.
    fn call(&mut self, request: Request<Body>) -> Self::Future {
        if self.ssl_redirect && !request_is_secure(&request) {
            let response = match redirect_target(&request) {
                // Cannot fail: the status and header name are constants, and the location is
                // built from a visible-ASCII host plus a URI path, both valid header-value bytes.
                Some(location) => Response::builder()
                    .status(StatusCode::PERMANENT_REDIRECT)
                    .header(LOCATION, location)
                    .body(Body::empty())
                    .expect("redirect response should build"),
                // Without a usable host there is no absolute https URL to send the client to;
                // reject the request rather than emit a `Location: https:///...` it cannot follow.
                None => {
                    let mut rejection = Response::new(Body::empty());
                    *rejection.status_mut() = StatusCode::BAD_REQUEST;
                    rejection
                }
            };
            return Box::pin(async move { Ok(response) });
        }

        // Copy the settings so the returned future does not borrow `self`.
        let hsts_seconds = self.hsts_seconds;
        let content_type_nosniff = self.content_type_nosniff;
        let future = self.inner.call(request);
        Box::pin(async move {
            let mut response = future.await?;
            // NOTE(review): RFC 6797 §7.2 says the STS header MUST NOT be sent over non-secure
            // transport, yet it is added here regardless of `request_is_secure` (which is
            // `false` whenever no `X-Forwarded-Proto: https` header is present, e.g. when TLS
            // terminates in this process) — confirm this is intended.
            if hsts_seconds > 0 {
                // `max-age=<u64>` is always ASCII digits and letters, a valid header value.
                let value = HeaderValue::from_str(&format!("max-age={hsts_seconds}"))
                    .expect("HSTS header value should be valid");
                response
                    .headers_mut()
                    .insert(STRICT_TRANSPORT_SECURITY, value);
            }
            if content_type_nosniff {
                response.headers_mut().insert(
                    HeaderName::from_static("x-content-type-options"),
                    HeaderValue::from_static("nosniff"),
                );
            }
            Ok(response)
        })
    }
}
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use super::*;
    use axum::http::{Request, StatusCode, header};
    use tower::{ServiceExt, service_fn};

    /// Sends `request` through `layer` wrapped around a handler that always answers
    /// `200 OK` with the body `ok`, and returns the resulting response.
    async fn respond(layer: SecurityMiddlewareLayer, request: Request<Body>) -> Response {
        let handler = service_fn(|_request: Request<Body>| async move {
            Ok::<_, Infallible>(Response::new(Body::from("ok")))
        });
        layer
            .layer(handler)
            .oneshot(request)
            .await
            .expect("service should respond")
    }

    #[tokio::test]
    async fn security_layer_adds_configured_headers() {
        let layer = SecurityMiddlewareLayer {
            hsts_seconds: 31_536_000,
            content_type_nosniff: true,
            ssl_redirect: false,
        };
        let request = Request::builder()
            .uri("/")
            .body(Body::empty())
            .expect("request should build");

        let response = respond(layer, request).await;

        assert_eq!(response.status(), StatusCode::OK);
        let headers = response.headers();
        let hsts = headers
            .get(STRICT_TRANSPORT_SECURITY)
            .expect("HSTS header should be present");
        assert_eq!(hsts, "max-age=31536000");
        let nosniff = headers
            .get("x-content-type-options")
            .expect("nosniff header should be present");
        assert_eq!(nosniff, "nosniff");
    }

    #[tokio::test]
    async fn security_layer_redirects_insecure_requests_when_enabled() {
        let layer = SecurityMiddlewareLayer {
            hsts_seconds: 0,
            content_type_nosniff: false,
            ssl_redirect: true,
        };
        let request = Request::builder()
            .uri("/dashboard?tab=security")
            .header(header::HOST, "example.com")
            .body(Body::empty())
            .expect("request should build");

        let response = respond(layer, request).await;

        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
        let location = response
            .headers()
            .get(header::LOCATION)
            .expect("redirect location should be present");
        assert_eq!(location, "https://example.com/dashboard?tab=security");
    }

    #[test]
    fn security_layer_is_cloneable() {
        let layer = SecurityMiddlewareLayer {
            hsts_seconds: 0,
            content_type_nosniff: false,
            ssl_redirect: true,
        };
        // Compiles only while the layer implements `Clone`.
        let _copy = layer.clone();
    }
}