use axum::{middleware::Next, response::IntoResponse};
use tower_http::cors::{Any, CorsLayer};
/// Builds a fully permissive CORS layer: any origin, any method, any request
/// header, and every response header exposed.
///
/// A warning is logged on each call because this configuration is intended
/// for development; the message points production setups at
/// `cors_layer_restricted` with explicitly configured origins.
pub fn cors_layer() -> CorsLayer {
    tracing::warn!(
        "Using permissive CORS settings (allows all origins). This is suitable for \
         development only. For production, configure cors_origins in server config \
         and use cors_layer_restricted()."
    );

    // The builder calls are independent of one another, so their order is irrelevant.
    let base = CorsLayer::new();
    base.expose_headers(Any)
        .allow_headers(Any)
        .allow_methods(Any)
        .allow_origin(Any)
}
pub fn cors_layer_restricted(allowed_origins: &[String]) -> CorsLayer {
let origins: Vec<_> = allowed_origins.iter().filter_map(|origin| origin.parse().ok()).collect();
CorsLayer::new()
.allow_origin(origins)
.allow_methods([
axum::http::Method::GET,
axum::http::Method::POST,
axum::http::Method::OPTIONS,
])
.allow_headers([
axum::http::header::CONTENT_TYPE,
axum::http::header::AUTHORIZATION,
])
}
/// Axum middleware that stamps a fixed set of security headers on every response.
///
/// The request is passed to `next` unchanged. Once the inner service has
/// produced its response, the headers below are inserted; `insert` replaces
/// any value the inner service may already have set under the same name.
///
/// | Header                      | Value                                                     |
/// |-----------------------------|-----------------------------------------------------------|
/// | `X-Content-Type-Options`    | `nosniff`                                                 |
/// | `X-Frame-Options`           | `DENY`                                                    |
/// | `Strict-Transport-Security` | `max-age=31536000; includeSubDomains`                     |
/// | `Referrer-Policy`           | `strict-origin-when-cross-origin`                         |
/// | `Content-Security-Policy`   | `default-src 'self'; script-src 'self'; style-src 'self'` |
/// | `X-XSS-Protection`          | `0`                                                       |
///
/// Header names come from the typed `http::header` constants and the values
/// are built with `HeaderValue::from_static`, so no string is re-parsed per
/// request.
pub async fn security_headers_middleware(
    req: axum::extract::Request,
    next: Next,
) -> impl IntoResponse {
    let mut response = next.run(req).await;
    let headers = response.headers_mut();

    for (name, value) in [
        (axum::http::header::X_CONTENT_TYPE_OPTIONS, "nosniff"),
        (axum::http::header::X_FRAME_OPTIONS, "DENY"),
        (
            axum::http::header::STRICT_TRANSPORT_SECURITY,
            "max-age=31536000; includeSubDomains",
        ),
        (
            axum::http::header::REFERRER_POLICY,
            "strict-origin-when-cross-origin",
        ),
        (
            axum::http::header::CONTENT_SECURITY_POLICY,
            "default-src 'self'; script-src 'self'; style-src 'self'",
        ),
        // "0" switches the legacy browser XSS auditor off explicitly.
        (axum::http::header::X_XSS_PROTECTION, "0"),
    ] {
        headers.insert(name, axum::http::HeaderValue::from_static(value));
    }

    response
}
#[cfg(test)]
mod tests {
    // Unwrapping is fine in tests: a panic is exactly how a failure should surface.
    #![allow(clippy::unwrap_used)]

    use super::*;
    use axum::{Router, body::Body, http::Request, middleware, routing::get};
    use tower::ServiceExt;

    /// Trivial handler so that every request reaches a real route.
    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Router with only the security-headers middleware applied.
    fn sec_app() -> Router {
        Router::new()
            .route("/", get(ok_handler))
            .layer(middleware::from_fn(security_headers_middleware))
    }

    /// Router with only the given CORS layer applied.
    fn cors_app(layer: CorsLayer) -> Router {
        Router::new().route("/", get(ok_handler)).layer(layer)
    }

    /// Sends `GET /` to `app`, optionally with an `Origin` request header.
    async fn get_root(app: Router, origin: Option<&str>) -> axum::response::Response {
        let mut request = Request::builder().uri("/");
        if let Some(origin) = origin {
            request = request.header("origin", origin);
        }
        app.oneshot(request.body(Body::empty()).unwrap())
            .await
            .unwrap()
    }

    /// Returns the `Access-Control-Allow-Origin` value the layer answers with
    /// for a request coming from `origin`, or `None` when the header is absent.
    async fn allow_origin_for(layer: CorsLayer, origin: &str) -> Option<String> {
        let resp = get_root(cors_app(layer), Some(origin)).await;
        resp.headers()
            .get("access-control-allow-origin")
            .map(|value| value.to_str().unwrap().to_owned())
    }

    /// Returns the value of response header `name` as set by the security middleware.
    async fn security_header(name: &'static str) -> String {
        let resp = get_root(sec_app(), None).await;
        match resp.headers().get(name) {
            Some(value) => value.to_str().unwrap().to_owned(),
            None => panic!("missing response header {}", name),
        }
    }

    #[tokio::test]
    async fn test_cors_layer_creation() {
        let allowed = allow_origin_for(cors_layer(), "https://anywhere.example").await;
        assert_eq!(allowed.as_deref(), Some("*"));
    }

    #[tokio::test]
    async fn test_cors_layer_restricted() {
        let origins = vec!["https://example.com".to_string()];

        let listed = allow_origin_for(cors_layer_restricted(&origins), "https://example.com").await;
        assert_eq!(listed.as_deref(), Some("https://example.com"));

        let unlisted = allow_origin_for(cors_layer_restricted(&origins), "https://evil.example").await;
        assert_eq!(unlisted, None);
    }

    #[tokio::test]
    async fn test_cors_layer_restricted_empty_origins() {
        let origins: Vec<String> = Vec::new();
        let allowed = allow_origin_for(cors_layer_restricted(&origins), "https://example.com").await;
        assert_eq!(allowed, None);
    }

    #[tokio::test]
    async fn test_cors_layer_restricted_invalid_origin() {
        // A line break cannot appear in a header value, so the first entry
        // genuinely fails to parse and must be skipped without affecting the
        // valid entry that follows it.
        let origins = vec![
            "bad\norigin".to_string(),
            "https://valid.com".to_string(),
        ];
        let allowed = allow_origin_for(cors_layer_restricted(&origins), "https://valid.com").await;
        assert_eq!(allowed.as_deref(), Some("https://valid.com"));
    }

    #[tokio::test]
    async fn test_security_headers_nosniff_present() {
        assert_eq!(security_header("x-content-type-options").await, "nosniff");
    }

    #[tokio::test]
    async fn test_security_headers_frame_options_deny() {
        assert_eq!(security_header("x-frame-options").await, "DENY");
    }

    #[tokio::test]
    async fn test_security_headers_xss_protection_zero() {
        assert_eq!(
            security_header("x-xss-protection").await,
            "0",
            "X-XSS-Protection must be 0 (legacy auditor disabled)"
        );
    }

    #[tokio::test]
    async fn test_security_headers_hsts_present() {
        assert_eq!(
            security_header("strict-transport-security").await,
            "max-age=31536000; includeSubDomains"
        );
    }

    #[tokio::test]
    async fn test_security_headers_referrer_policy() {
        assert_eq!(
            security_header("referrer-policy").await,
            "strict-origin-when-cross-origin"
        );
    }

    #[tokio::test]
    async fn test_security_headers_content_security_policy() {
        assert_eq!(
            security_header("content-security-policy").await,
            "default-src 'self'; script-src 'self'; style-src 'self'"
        );
    }

    #[tokio::test]
    async fn test_cors_layer_config_comprehensive() {
        let origins = vec![
            "https://example.com".to_string(),
            "https://app.example.com".to_string(),
        ];

        // Every configured origin is echoed back verbatim.
        for origin in &origins {
            let allowed = allow_origin_for(cors_layer_restricted(&origins), origin).await;
            assert_eq!(allowed.as_deref(), Some(origin.as_str()));
        }

        // An origin outside the list gets no allow-origin header at all.
        let other = allow_origin_for(cors_layer_restricted(&origins), "https://other.example").await;
        assert_eq!(other, None);
    }
}