use axum::{
body::Body,
http::{header, Request, Response},
middleware::Next,
response::IntoResponse,
};
use std::fmt;
/// Selects which security-related response headers the middleware writes.
///
/// Every `Option` field left as `None` (and `content_type_options` left `false`)
/// means the corresponding header is not written at all. Start from
/// [`SecurityHeadersConfig::strict`], [`SecurityHeadersConfig::development`] or
/// [`SecurityHeadersConfig::custom`] and adjust with the `with_*` builders.
#[derive(Debug, Clone)]
pub struct SecurityHeadersConfig {
    /// `X-Frame-Options` value; `None` omits the header.
    pub frame_options: Option<FrameOptions>,
    /// When `true`, sends `X-Content-Type-Options: nosniff`.
    pub content_type_options: bool,
    /// `X-XSS-Protection`: `Some(true)` sends `1; mode=block`, `Some(false)`
    /// sends `1`, `None` omits the header.
    pub xss_protection: Option<bool>,
    /// `Strict-Transport-Security` directives; `None` omits the header.
    pub hsts: Option<HstsConfig>,
    /// Raw `Content-Security-Policy` value, sent verbatim; `None` omits the header.
    pub csp: Option<String>,
    /// Referrer policy to advertise to clients; `None` omits it.
    pub referrer_policy: Option<ReferrerPolicy>,
}
/// Allowed values of the `X-Frame-Options` response header.
///
/// [`fmt::Display`] renders the exact upper-case token placed in the header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameOptions {
    /// Rendered as `DENY`.
    Deny,
    /// Rendered as `SAMEORIGIN`.
    SameOrigin,
}

impl fmt::Display for FrameOptions {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let token = match *self {
            Self::Deny => "DENY",
            Self::SameOrigin => "SAMEORIGIN",
        };
        f.write_str(token)
    }
}
/// Directives for the `Strict-Transport-Security` response header.
///
/// [`fmt::Display`] renders `max-age=<seconds>`, followed by
/// `; includeSubDomains` and/or `; preload` when the matching flag is set.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HstsConfig {
    /// Value of the `max-age` directive, in seconds.
    pub max_age: u32,
    /// Appends `; includeSubDomains` when `true`.
    pub include_subdomains: bool,
    /// Appends `; preload` when `true`.
    pub preload: bool,
}

impl HstsConfig {
    /// 365 days expressed in seconds (365 * 86_400).
    const ONE_YEAR_SECS: u32 = 31_536_000;

    /// One-year `max-age` with both `includeSubDomains` and `preload`.
    #[must_use]
    pub const fn strict() -> Self {
        Self {
            max_age: Self::ONE_YEAR_SECS,
            include_subdomains: true,
            preload: true,
        }
    }

    /// One-year `max-age` with no additional directives.
    #[must_use]
    pub const fn moderate() -> Self {
        Self {
            max_age: Self::ONE_YEAR_SECS,
            include_subdomains: false,
            preload: false,
        }
    }
}

impl fmt::Display for HstsConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "max-age={}", self.max_age)?;
        // Optional directives, emitted in this fixed order when enabled.
        let optional_directives = [
            (self.include_subdomains, "; includeSubDomains"),
            (self.preload, "; preload"),
        ];
        for (enabled, directive) in optional_directives {
            if enabled {
                f.write_str(directive)?;
            }
        }
        Ok(())
    }
}
/// Allowed values of the `Referrer-Policy` response header.
///
/// [`fmt::Display`] renders each variant as its lowercase, hyphenated token,
/// e.g. `StrictOriginWhenCrossOrigin` becomes `strict-origin-when-cross-origin`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReferrerPolicy {
    /// `no-referrer`
    NoReferrer,
    /// `no-referrer-when-downgrade`
    NoReferrerWhenDowngrade,
    /// `origin`
    Origin,
    /// `origin-when-cross-origin`
    OriginWhenCrossOrigin,
    /// `same-origin`
    SameOrigin,
    /// `strict-origin`
    StrictOrigin,
    /// `strict-origin-when-cross-origin`
    StrictOriginWhenCrossOrigin,
    /// `unsafe-url`
    UnsafeUrl,
}

impl fmt::Display for ReferrerPolicy {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let token = match *self {
            Self::NoReferrer => "no-referrer",
            Self::NoReferrerWhenDowngrade => "no-referrer-when-downgrade",
            Self::Origin => "origin",
            Self::OriginWhenCrossOrigin => "origin-when-cross-origin",
            Self::SameOrigin => "same-origin",
            Self::StrictOrigin => "strict-origin",
            Self::StrictOriginWhenCrossOrigin => "strict-origin-when-cross-origin",
            Self::UnsafeUrl => "unsafe-url",
        };
        f.write_str(token)
    }
}
impl SecurityHeadersConfig {
    /// Preset with every header enabled.
    ///
    /// Sends `X-Frame-Options: DENY`, `X-Content-Type-Options: nosniff`,
    /// `X-XSS-Protection: 1; mode=block`, HSTS from [`HstsConfig::strict`],
    /// CSP `default-src 'self'`, and the `strict-origin-when-cross-origin`
    /// referrer policy.
    #[must_use]
    pub fn strict() -> Self {
        let hsts = HstsConfig::strict();
        let policy = String::from("default-src 'self'");
        Self {
            frame_options: Some(FrameOptions::Deny),
            content_type_options: true,
            xss_protection: Some(true),
            hsts: Some(hsts),
            csp: Some(policy),
            referrer_policy: Some(ReferrerPolicy::StrictOriginWhenCrossOrigin),
        }
    }

    /// Preset intended for local development.
    ///
    /// Allows same-origin framing and omits both HSTS and `X-XSS-Protection`.
    /// Its CSP is permissive: `'unsafe-inline'` and `'unsafe-eval'` are allowed,
    /// plus `data:` images.
    #[must_use]
    pub fn development() -> Self {
        Self {
            frame_options: Some(FrameOptions::SameOrigin),
            content_type_options: true,
            csp: Some(String::from(
                "default-src 'self' 'unsafe-inline' 'unsafe-eval'; img-src 'self' data:",
            )),
            referrer_policy: Some(ReferrerPolicy::StrictOriginWhenCrossOrigin),
            // `xss_protection` and `hsts` stay disabled.
            ..Self::custom()
        }
    }

    /// Empty starting point for the `with_*` builders: nothing is emitted.
    #[must_use]
    pub const fn custom() -> Self {
        Self {
            frame_options: None,
            content_type_options: false,
            xss_protection: None,
            hsts: None,
            csp: None,
            referrer_policy: None,
        }
    }

    /// Sets the `X-Frame-Options` value.
    #[must_use]
    pub const fn with_frame_options(mut self, options: FrameOptions) -> Self {
        self.frame_options = Some(options);
        self
    }

    /// Turns on `X-Content-Type-Options: nosniff`.
    #[must_use]
    pub const fn with_content_type_options(mut self) -> Self {
        self.content_type_options = true;
        self
    }

    /// Turns on `X-XSS-Protection`.
    ///
    /// `block_mode == true` yields `1; mode=block`; otherwise the value is `1`.
    #[must_use]
    pub const fn with_xss_protection(mut self, block_mode: bool) -> Self {
        self.xss_protection = Some(block_mode);
        self
    }

    /// Sets the `Strict-Transport-Security` directives.
    #[must_use]
    pub const fn with_hsts(mut self, config: HstsConfig) -> Self {
        self.hsts = Some(config);
        self
    }

    /// Sets the raw `Content-Security-Policy` value, which is sent verbatim.
    #[must_use]
    pub fn with_csp(self, policy: String) -> Self {
        Self {
            csp: Some(policy),
            ..self
        }
    }

    /// Sets the referrer policy.
    #[must_use]
    pub const fn with_referrer_policy(mut self, policy: ReferrerPolicy) -> Self {
        self.referrer_policy = Some(policy);
        self
    }
}
/// Tower layer that wraps an inner service in [`SecurityHeadersMiddleware`].
///
/// Every service produced by [`tower::Layer::layer`] receives its own clone of
/// the configuration.
#[derive(Debug, Clone)]
pub struct SecurityHeadersLayer {
    config: SecurityHeadersConfig,
}

impl SecurityHeadersLayer {
    /// Creates a layer that applies `config` to every response.
    #[must_use]
    pub const fn new(config: SecurityHeadersConfig) -> Self {
        Self { config }
    }
}

impl<S> tower::Layer<S> for SecurityHeadersLayer {
    type Service = SecurityHeadersMiddleware<S>;

    fn layer(&self, inner: S) -> Self::Service {
        SecurityHeadersMiddleware {
            inner,
            config: self.config.clone(),
        }
    }
}
/// Service produced by [`SecurityHeadersLayer`].
///
/// It forwards each request to `inner` and then adds the configured security
/// headers to the successful response. Errors from `inner` are propagated
/// unchanged, and no headers are added to them.
#[derive(Debug, Clone)]
pub struct SecurityHeadersMiddleware<S> {
    inner: S,
    config: SecurityHeadersConfig,
}

impl<S> tower::Service<Request<Body>> for SecurityHeadersMiddleware<S>
where
    S: tower::Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        // Readiness is entirely the inner service's.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request<Body>) -> Self::Future {
        // The boxed future must not borrow `self`, so it owns a copy of the config.
        let config = self.config.clone();
        // Call the same `inner` that `poll_ready` was forwarded to (not a clone),
        // so the readiness reserved there is the one consumed here.
        let future = self.inner.call(request);
        Box::pin(async move {
            let mut response = future.await?;
            add_security_headers(&mut response, &config);
            Ok(response)
        })
    }
}
/// Writes every header enabled in `config` into `response`.
///
/// Headers are set with `HeaderMap::insert`, so a value the inner handler
/// already set under one of these names is replaced.
///
/// # Panics
///
/// Panics if `config.csp` is not a valid HTTP header value, for example when it
/// contains a newline or another control character. All other values are built
/// from fixed ASCII tokens and always convert successfully.
fn add_security_headers(response: &mut Response<Body>, config: &SecurityHeadersConfig) {
    let headers = response.headers_mut();

    if let Some(frame_options) = config.frame_options {
        headers.insert(
            header::X_FRAME_OPTIONS,
            header::HeaderValue::from_str(&frame_options.to_string())
                .expect("X-Frame-Options token is always a valid header value"),
        );
    }

    if config.content_type_options {
        headers.insert(
            header::X_CONTENT_TYPE_OPTIONS,
            header::HeaderValue::from_static("nosniff"),
        );
    }

    if let Some(block_mode) = config.xss_protection {
        let value = if block_mode { "1; mode=block" } else { "1" };
        headers.insert(
            header::X_XSS_PROTECTION,
            header::HeaderValue::from_static(value),
        );
    }

    if let Some(hsts) = config.hsts {
        headers.insert(
            header::STRICT_TRANSPORT_SECURITY,
            header::HeaderValue::from_str(&hsts.to_string())
                .expect("HSTS directives are always a valid header value"),
        );
    }

    if let Some(csp) = &config.csp {
        // The CSP is a caller-supplied string, so this is the one conversion that can fail.
        headers.insert(
            header::CONTENT_SECURITY_POLICY,
            header::HeaderValue::from_str(csp)
                .expect("configured Content-Security-Policy must be a valid HTTP header value"),
        );
    }

    if let Some(referrer_policy) = config.referrer_policy {
        // The policy goes in the `Referrer-Policy` response header. `Referer`
        // is a request header, and browsers ignore it on responses.
        headers.insert(
            header::REFERRER_POLICY,
            header::HeaderValue::from_str(&referrer_policy.to_string())
                .expect("Referrer-Policy token is always a valid header value"),
        );
    }
}
/// Function-style counterpart of [`SecurityHeadersMiddleware`].
///
/// It runs the rest of the stack through `next` and then applies `config` to
/// the response produced there.
///
/// NOTE(review): `config` comes after `next`, so this is presumably wrapped in a
/// closure that supplies the config before being registered as middleware.
/// Confirm this against the callers.
pub async fn security_headers(
    request: Request<Body>,
    next: Next,
    config: SecurityHeadersConfig,
) -> impl IntoResponse {
    let downstream = next.run(request);
    let mut outgoing = downstream.await;
    add_security_headers(&mut outgoing, &config);
    outgoing
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        response::IntoResponse,
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    /// Trivial inner handler; every integration test wraps it with the layer.
    async fn test_handler() -> impl IntoResponse {
        (StatusCode::OK, "Hello, World!")
    }

    /// Sends `GET /` through a one-route router layered with `config` and
    /// returns the resulting response.
    async fn fetch_root(config: SecurityHeadersConfig) -> axum::http::Response<Body> {
        let router = Router::new()
            .route("/", get(test_handler))
            .layer(SecurityHeadersLayer::new(config));
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        router.oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn test_strict_config_headers() {
        let response = fetch_root(SecurityHeadersConfig::strict()).await;
        let headers = response.headers();
        assert_eq!(headers.get("x-frame-options").unwrap(), "DENY");
        assert_eq!(headers.get("x-content-type-options").unwrap(), "nosniff");
        assert_eq!(headers.get("x-xss-protection").unwrap(), "1; mode=block");
        assert!(headers.contains_key("strict-transport-security"));
        assert!(headers.contains_key("content-security-policy"));
    }

    #[tokio::test]
    async fn test_development_config_headers() {
        let response = fetch_root(SecurityHeadersConfig::development()).await;
        let headers = response.headers();
        assert_eq!(headers.get("x-frame-options").unwrap(), "SAMEORIGIN");
        assert_eq!(headers.get("x-content-type-options").unwrap(), "nosniff");
        assert!(!headers.contains_key("x-xss-protection"));
        assert!(!headers.contains_key("strict-transport-security"));
        assert!(headers.contains_key("content-security-policy"));
    }

    #[tokio::test]
    async fn test_custom_config() {
        let config = SecurityHeadersConfig::custom()
            .with_frame_options(FrameOptions::SameOrigin)
            .with_content_type_options()
            .with_referrer_policy(ReferrerPolicy::NoReferrer);
        let response = fetch_root(config).await;
        let headers = response.headers();
        assert_eq!(headers.get("x-frame-options").unwrap(), "SAMEORIGIN");
        assert_eq!(headers.get("x-content-type-options").unwrap(), "nosniff");
        assert!(!headers.contains_key("x-xss-protection"));
        assert!(!headers.contains_key("strict-transport-security"));
        assert!(!headers.contains_key("content-security-policy"));
    }

    #[test]
    fn test_hsts_config_display() {
        assert_eq!(
            HstsConfig::strict().to_string(),
            "max-age=31536000; includeSubDomains; preload"
        );
        assert_eq!(HstsConfig::moderate().to_string(), "max-age=31536000");
    }

    #[test]
    fn test_frame_options_display() {
        assert_eq!(FrameOptions::Deny.to_string(), "DENY");
        assert_eq!(FrameOptions::SameOrigin.to_string(), "SAMEORIGIN");
    }

    #[test]
    fn test_referrer_policy_display() {
        assert_eq!(ReferrerPolicy::NoReferrer.to_string(), "no-referrer");
        assert_eq!(
            ReferrerPolicy::StrictOriginWhenCrossOrigin.to_string(),
            "strict-origin-when-cross-origin"
        );
    }

    #[test]
    fn test_config_builder() {
        let config = SecurityHeadersConfig::custom()
            .with_frame_options(FrameOptions::Deny)
            .with_content_type_options()
            .with_xss_protection(true)
            .with_hsts(HstsConfig::strict())
            .with_csp("default-src 'self'".to_string())
            .with_referrer_policy(ReferrerPolicy::StrictOriginWhenCrossOrigin);

        assert_eq!(config.frame_options, Some(FrameOptions::Deny));
        assert!(config.content_type_options);
        assert_eq!(config.xss_protection, Some(true));
        assert!(config.hsts.is_some());
        assert!(config.csp.is_some());
        assert_eq!(
            config.referrer_policy,
            Some(ReferrerPolicy::StrictOriginWhenCrossOrigin)
        );
    }
}