use axum::{middleware::Next, response::IntoResponse};
use tower_http::cors::{Any, CorsLayer};
/// Returns a fully permissive CORS layer meant for local development.
///
/// Every origin, method and request header is accepted, and all response headers are
/// exposed to the browser. A `tracing` warning is emitted on every call so that accidental
/// use of this layer outside development shows up in the logs.
#[must_use]
pub fn cors_layer() -> CorsLayer {
    tracing::warn!(
        "Using permissive CORS settings (allows all origins). \
         This is suitable for development only. \
         For production, configure cors_origins in server config and use cors_layer_restricted()."
    );

    // Who may call us, and with which verbs.
    let layer = CorsLayer::new().allow_origin(Any).allow_methods(Any);
    // Which headers may be sent, and which the browser may read back.
    layer.allow_headers(Any).expose_headers(Any)
}
#[must_use]
pub fn cors_layer_restricted(allowed_origins: Vec<String>) -> CorsLayer {
let origins: Vec<_> = allowed_origins.iter().filter_map(|origin| origin.parse().ok()).collect();
CorsLayer::new()
.allow_origin(origins)
.allow_methods([
axum::http::Method::GET,
axum::http::Method::POST,
axum::http::Method::OPTIONS,
])
.allow_headers([
axum::http::header::CONTENT_TYPE,
axum::http::header::AUTHORIZATION,
])
}
/// Axum middleware that adds a fixed set of security headers to every response.
///
/// The inner service runs first. The headers below are then inserted, replacing any value
/// already set for the same header name:
///
/// - `X-Content-Type-Options: nosniff`
/// - `X-Frame-Options: DENY`
/// - `Strict-Transport-Security: max-age=31536000; includeSubDomains`
/// - `X-XSS-Protection: 1; mode=block`
/// - `Referrer-Policy: strict-origin-when-cross-origin`
/// - `Content-Security-Policy: default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'`
///
/// The `(Request, Next)` signature is the shape expected by `axum::middleware::from_fn`.
pub async fn security_headers_middleware(
    req: axum::extract::Request,
    next: Next,
) -> impl IntoResponse {
    use axum::http::{header, HeaderValue};

    let mut response = next.run(req).await;
    let headers = response.headers_mut();

    // `from_static` wraps the literal without a per-request parse, and the typed name constants
    // avoid re-validating header names. Nothing here needs an `expect` on the request path.
    headers.insert(header::X_CONTENT_TYPE_OPTIONS, HeaderValue::from_static("nosniff"));
    headers.insert(header::X_FRAME_OPTIONS, HeaderValue::from_static("DENY"));
    // Browsers only honour HSTS when it arrives over HTTPS; on plain HTTP it is ignored.
    headers.insert(
        header::STRICT_TRANSPORT_SECURITY,
        HeaderValue::from_static("max-age=31536000; includeSubDomains"),
    );
    // NOTE(review): current OWASP guidance recommends `X-XSS-Protection: 0` (the legacy XSS
    // auditor has been removed from modern browsers); consider changing once confirmed.
    headers.insert(header::X_XSS_PROTECTION, HeaderValue::from_static("1; mode=block"));
    headers.insert(
        header::REFERRER_POLICY,
        HeaderValue::from_static("strict-origin-when-cross-origin"),
    );
    headers.insert(
        header::CONTENT_SECURITY_POLICY,
        HeaderValue::from_static(
            "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'",
        ),
    );

    response
}
#[cfg(test)]
mod tests {
    use super::*;

    // The CORS tests below are smoke tests: they only check that building the layer does not
    // panic for the given inputs.

    #[test]
    fn test_cors_layer_creation() {
        let _layer = cors_layer();
    }

    #[test]
    fn test_cors_layer_restricted() {
        let origins = vec!["https://example.com".to_string()];
        let _layer = cors_layer_restricted(origins);
    }

    #[test]
    fn test_cors_layer_restricted_empty_origins() {
        let origins = vec![];
        let _layer = cors_layer_restricted(origins);
    }

    #[test]
    fn test_cors_layer_restricted_invalid_origin() {
        // "not-a-valid-url" is still an acceptable header value, so on its own it never reaches
        // the parse-failure branch. The embedded newline does, because `HeaderValue` rejects
        // control characters. Pin both premises so the test keeps exercising what it claims to.
        assert!("not-a-valid-url".parse::<axum::http::HeaderValue>().is_ok());
        assert!("https://bad\nexample.com"
            .parse::<axum::http::HeaderValue>()
            .is_err());

        let origins = vec![
            "not-a-valid-url".to_string(),
            "https://bad\nexample.com".to_string(),
            "https://valid.com".to_string(),
        ];
        let _layer = cors_layer_restricted(origins);
    }

    // NOTE(review): this test and the HSTS/CSP tests below assert on local copies of the header
    // literals, not on the output of `security_headers_middleware`. A change to the middleware
    // will not make them fail; consider sharing the values via constants or driving the
    // middleware through a router.
    #[test]
    fn test_security_headers_values_hardcoded() {
        let header = "nosniff";
        assert_eq!(header, "nosniff");
        let header = "DENY";
        assert_eq!(header, "DENY");
        let header = "max-age=31536000; includeSubDomains";
        assert!(header.contains("max-age=31536000"));
        assert!(header.contains("includeSubDomains"));
        let header = "1; mode=block";
        assert_eq!(header, "1; mode=block");
        let header = "strict-origin-when-cross-origin";
        assert_eq!(header, "strict-origin-when-cross-origin");
        let header = "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'";
        assert!(header.contains("default-src 'self'"));
    }

    #[test]
    fn test_security_headers_csp_structure() {
        let csp = "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'";
        let directives: Vec<&str> = csp.split(';').map(|s| s.trim()).collect();
        assert_eq!(directives.len(), 3);
        assert!(directives[0].contains("default-src"));
        assert!(directives[1].contains("script-src"));
        assert!(directives[2].contains("style-src"));
    }

    #[test]
    fn test_cors_layer_config_comprehensive() {
        let origins = vec![
            "https://example.com".to_string(),
            "https://app.example.com".to_string(),
        ];
        let _layer = cors_layer_restricted(origins);
    }

    #[test]
    fn test_security_headers_middleware_callable() {
        let _ = security_headers_middleware;
    }

    #[test]
    fn test_hsts_policy_compliance() {
        let max_age_seconds = 31_536_000;
        assert!(
            max_age_seconds >= 31_536_000,
            "HSTS max-age should be at least 1 year"
        );
    }

    #[test]
    fn test_csp_policy_compliance() {
        let csp = "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'";
        assert!(csp.contains("'self'"), "CSP should restrict to same-origin");
        assert!(!csp.contains('*'), "CSP should not allow wildcards");
    }
}