use async_trait::async_trait;
use reinhardt_http::{Handler, Middleware, Request, Response, Result};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tracing::{debug, warn};
/// Per-request CSP nonce.
///
/// When nonce generation is enabled, the middleware inserts this value into the
/// request extensions before calling the handler, so handlers can read the same
/// nonce that is placed in the `script-src` / `style-src` directives.
#[derive(Clone, Debug)]
pub struct CspNonce(pub String);
/// Returns `true` when `nonce` is non-empty and consists solely of standard
/// base64 characters (`A-Z`, `a-z`, `0-9`, `+`, `/`, `=`).
///
/// The CSP header builder uses this to refuse embedding anything that could break
/// out of the `'nonce-…'` source expression (quotes, `;`, whitespace, CR/LF).
fn is_valid_nonce(nonce: &str) -> bool {
    if nonce.is_empty() {
        return false;
    }
    nonce.bytes().all(|byte| {
        matches!(byte, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'=')
    })
}
/// Configuration for the CSP middleware.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct CspConfig {
    /// Directive name (e.g. `"script-src"`) mapped to its source expressions.
    pub directives: HashMap<String, Vec<String>>,
    /// When `true`, the policy is sent as `Content-Security-Policy-Report-Only`
    /// instead of the enforcing `Content-Security-Policy` header.
    pub report_only: bool,
    /// When `true`, a fresh nonce is generated per request and appended to the
    /// `script-src` / `style-src` directives that are present.
    pub include_nonce: bool,
    /// Paths that receive no CSP header. A request path matches an entry when it
    /// is equal to it or continues it with `/`.
    pub exempt_paths: HashSet<String>,
}

impl Default for CspConfig {
    /// Minimal enforcing policy: `default-src 'self'`, no nonce, no exemptions.
    fn default() -> Self {
        Self {
            directives: HashMap::from([(
                "default-src".to_string(),
                vec!["'self'".to_string()],
            )]),
            report_only: false,
            include_nonce: false,
            exempt_paths: HashSet::new(),
        }
    }
}

impl CspConfig {
    /// Locked-down enforcing policy.
    ///
    /// Every fetch directive is restricted to the same origin. Images may
    /// additionally use `data:` URIs. Framing by any origin is forbidden
    /// (`frame-ancestors 'none'`), and `base-uri` / `form-action` are same-origin.
    pub fn strict() -> Self {
        let same_origin = || vec!["'self'".to_string()];
        let directives = [
            ("default-src", same_origin()),
            ("script-src", same_origin()),
            ("style-src", same_origin()),
            ("img-src", vec!["'self'".to_string(), "data:".to_string()]),
            ("font-src", same_origin()),
            ("connect-src", same_origin()),
            ("frame-ancestors", vec!["'none'".to_string()]),
            ("base-uri", same_origin()),
            ("form-action", same_origin()),
        ]
        .into_iter()
        .map(|(name, sources)| (name.to_string(), sources))
        .collect::<HashMap<String, Vec<String>>>();

        Self {
            directives,
            report_only: false,
            include_nonce: false,
            exempt_paths: HashSet::new(),
        }
    }

    /// Builder-style helper: adds `path` to `exempt_paths` and returns the config.
    pub fn add_exempt_path(self, path: String) -> Self {
        let mut config = self;
        config.exempt_paths.insert(path);
        config
    }
}
/// Middleware that attaches a Content Security Policy header to responses.
pub struct CspMiddleware {
    config: CspConfig,
}

impl CspMiddleware {
    /// Creates a middleware using `CspConfig::default()` (`default-src 'self'`).
    pub fn new() -> Self {
        Self::with_config(CspConfig::default())
    }

    /// Creates a middleware using the supplied configuration.
    pub fn with_config(config: CspConfig) -> Self {
        Self { config }
    }

    /// Creates a middleware using `CspConfig::strict()`.
    pub fn strict() -> Self {
        Self::with_config(CspConfig::strict())
    }

    /// Generates a fresh nonce: 16 random bytes (128 bits), standard-base64 encoded.
    fn generate_nonce(&self) -> String {
        use base64::Engine;
        use rand::RngCore;
        let mut bytes = [0u8; 16];
        rand::rng().fill_bytes(&mut bytes);
        base64::engine::general_purpose::STANDARD.encode(bytes)
    }

    /// Serializes the configured directives into a CSP header value.
    ///
    /// - Directives are emitted sorted by name. `HashMap` iteration order is
    ///   randomized per process, so without sorting every server instance would
    ///   send a differently-ordered (though equivalent) header.
    /// - A directive with no source expressions is emitted as its bare name. This
    ///   avoids a trailing space, e.g. `upgrade-insecure-requests`.
    /// - When nonces are enabled and `nonce` passes `is_valid_nonce`,
    ///   `'nonce-<value>'` is appended to `script-src` and `style-src`. An invalid
    ///   nonce is dropped rather than risking header injection.
    ///
    /// Directives are joined with `"; "`; no directives yields an empty string.
    fn build_csp_header(&self, nonce: Option<&str>) -> String {
        let nonce_source = nonce
            .filter(|n| self.config.include_nonce && is_valid_nonce(n))
            .map(|n| format!("'nonce-{}'", n));

        let mut directives: Vec<_> = self.config.directives.iter().collect();
        directives.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

        directives
            .into_iter()
            .map(|(name, values)| {
                let wants_nonce = name == "script-src" || name == "style-src";
                let extra = nonce_source.as_deref().filter(|_| wants_nonce);
                let mut serialized = name.clone();
                for value in values.iter().map(String::as_str).chain(extra) {
                    serialized.push(' ');
                    serialized.push_str(value);
                }
                serialized
            })
            .collect::<Vec<_>>()
            .join("; ")
    }

    /// Header name for the configured mode: report-only or enforcing.
    fn get_header_name(&self) -> &'static str {
        if self.config.report_only {
            "Content-Security-Policy-Report-Only"
        } else {
            "Content-Security-Policy"
        }
    }
}

impl Default for CspMiddleware {
    /// Equivalent to `CspMiddleware::new()`.
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl Middleware for CspMiddleware {
async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
let path = request.uri.path();
if self
.config
.exempt_paths
.iter()
.any(|exempt| path == exempt.as_str() || path.starts_with(&format!("{}/", exempt)))
{
debug!(
path = path,
"Path is CSP-exempt, skipping CSP header insertion"
);
return match handler.handle(request).await {
Ok(resp) => Ok(resp),
Err(e) => Ok(Response::from(e)),
};
}
let nonce = if self.config.include_nonce {
let generated_nonce = self.generate_nonce();
request.extensions.insert(CspNonce(generated_nonce.clone()));
Some(generated_nonce)
} else {
None
};
let mut response = match handler.handle(request).await {
Ok(resp) => resp,
Err(e) => Response::from(e),
};
let header_name = self.get_header_name();
if response.headers.contains_key(header_name) {
debug!(
header = header_name,
"CSP header already present in response, skipping middleware insertion"
);
} else {
let csp_value = self.build_csp_header(nonce.as_deref());
match csp_value.parse() {
Ok(value) => {
response.headers.insert(header_name, value);
}
Err(e) => {
warn!(
error = %e,
"Failed to parse CSP header value, skipping header insertion"
);
}
}
}
Ok(response)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};
    use rstest::rstest;

    /// Handler that always succeeds with a 200 and a short body.
    struct TestHandler;

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::new(StatusCode::OK).with_body(Bytes::from("content")))
        }
    }

    /// Handler that always fails, used to check error-to-response conversion.
    struct ErrorHandler;

    #[async_trait]
    impl Handler for ErrorHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Err(reinhardt_http::Error::Http("handler error".to_string()))
        }
    }

    /// Builds a bodiless HTTP/1.1 GET request for `uri`.
    fn get_request(uri: &'static str) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Sends a GET for `uri` through `middleware` and unwraps the response.
    async fn run(
        middleware: &CspMiddleware,
        uri: &'static str,
        handler: Arc<dyn Handler>,
    ) -> Response {
        middleware.process(get_request(uri), handler).await.unwrap()
    }

    /// Returns the enforcing CSP header of `response`, panicking if it is absent.
    fn csp_header(response: &Response) -> &str {
        response
            .headers
            .get("Content-Security-Policy")
            .expect("Content-Security-Policy header should be present")
            .to_str()
            .unwrap()
    }

    /// Converts string literals into an owned source list for one directive.
    fn sources(values: &[&str]) -> Vec<String> {
        values.iter().map(|v| v.to_string()).collect()
    }

    /// Builds a config from explicit directives and flags, with no exempt paths.
    fn config_with(
        directives: HashMap<String, Vec<String>>,
        report_only: bool,
        include_nonce: bool,
    ) -> CspConfig {
        CspConfig {
            directives,
            report_only,
            include_nonce,
            exempt_paths: HashSet::new(),
        }
    }

    #[tokio::test]
    async fn test_default_csp_header() {
        let response = run(&CspMiddleware::new(), "/test", Arc::new(TestHandler)).await;
        assert_eq!(response.status, StatusCode::OK);
        assert!(csp_header(&response).contains("default-src 'self'"));
    }

    #[tokio::test]
    async fn test_custom_csp_directives() {
        let config = config_with(
            HashMap::from([
                ("default-src".to_string(), sources(&["'self'"])),
                (
                    "script-src".to_string(),
                    sources(&["'self'", "https://cdn.example.com"]),
                ),
            ]),
            false,
            false,
        );
        let middleware = CspMiddleware::with_config(config);
        let response = run(&middleware, "/test", Arc::new(TestHandler)).await;
        let csp = csp_header(&response);
        assert!(csp.contains("default-src 'self'"));
        assert!(csp.contains("script-src 'self' https://cdn.example.com"));
    }

    #[tokio::test]
    async fn test_report_only_mode() {
        let config = config_with(
            HashMap::from([("default-src".to_string(), sources(&["'self'"]))]),
            true,
            false,
        );
        let middleware = CspMiddleware::with_config(config);
        let response = run(&middleware, "/test", Arc::new(TestHandler)).await;
        assert!(
            response
                .headers
                .contains_key("Content-Security-Policy-Report-Only")
        );
        assert!(!response.headers.contains_key("Content-Security-Policy"));
    }

    #[tokio::test]
    async fn test_nonce_generation() {
        let config = config_with(
            HashMap::from([("script-src".to_string(), sources(&["'self'"]))]),
            false,
            true,
        );
        let middleware = CspMiddleware::with_config(config);
        let response = run(&middleware, "/test", Arc::new(TestHandler)).await;
        assert!(csp_header(&response).contains("'nonce-"));
    }

    #[tokio::test]
    async fn test_strict_csp() {
        let response = run(&CspMiddleware::strict(), "/test", Arc::new(TestHandler)).await;
        let csp = csp_header(&response);
        assert!(csp.contains("default-src 'self'"));
        assert!(csp.contains("script-src 'self'"));
        assert!(csp.contains("style-src 'self'"));
        assert!(csp.contains("frame-ancestors 'none'"));
        assert!(csp.contains("base-uri 'self'"));
    }

    #[tokio::test]
    async fn test_multiple_directive_values() {
        let config = config_with(
            HashMap::from([(
                "img-src".to_string(),
                sources(&["'self'", "data:", "https:"]),
            )]),
            false,
            false,
        );
        let middleware = CspMiddleware::with_config(config);
        let response = run(&middleware, "/test", Arc::new(TestHandler)).await;
        assert!(csp_header(&response).contains("img-src 'self' data: https:"));
    }

    #[tokio::test]
    async fn test_nonce_only_added_to_script_and_style() {
        let config = config_with(
            HashMap::from([
                ("script-src".to_string(), sources(&["'self'"])),
                ("style-src".to_string(), sources(&["'self'"])),
                ("img-src".to_string(), sources(&["'self'"])),
            ]),
            false,
            true,
        );
        let middleware = CspMiddleware::with_config(config);
        let response = run(&middleware, "/test", Arc::new(TestHandler)).await;
        let nonce_count = csp_header(&response).matches("'nonce-").count();
        assert_eq!(nonce_count, 2);
    }

    #[tokio::test]
    async fn test_empty_directives() {
        let middleware = CspMiddleware::with_config(config_with(HashMap::new(), false, false));
        let response = run(&middleware, "/test", Arc::new(TestHandler)).await;
        assert!(response.headers.contains_key("Content-Security-Policy"));
    }

    #[tokio::test]
    async fn test_frame_ancestors_directive() {
        let config = config_with(
            HashMap::from([(
                "frame-ancestors".to_string(),
                sources(&["'self'", "https://trusted.com"]),
            )]),
            false,
            false,
        );
        let middleware = CspMiddleware::with_config(config);
        let response = run(&middleware, "/test", Arc::new(TestHandler)).await;
        assert!(csp_header(&response).contains("frame-ancestors 'self' https://trusted.com"));
    }

    #[tokio::test]
    async fn test_nonce_uniqueness_across_requests() {
        let config = config_with(
            HashMap::from([("script-src".to_string(), sources(&["'self'"]))]),
            false,
            true,
        );
        let middleware = CspMiddleware::with_config(config);
        let handler: Arc<dyn Handler> = Arc::new(TestHandler);
        let response1 = run(&middleware, "/page1", handler.clone()).await;
        let csp1 = csp_header(&response1).to_string();
        let response2 = run(&middleware, "/page2", handler).await;
        let csp2 = csp_header(&response2).to_string();
        let extract_nonce = |csp: &str| -> Option<String> {
            csp.split("'nonce-")
                .nth(1)
                .and_then(|s| s.split('\'').next())
                .map(|s| s.to_string())
        };
        let nonce1 = extract_nonce(&csp1);
        let nonce2 = extract_nonce(&csp2);
        assert!(nonce1.is_some(), "First CSP should contain nonce");
        assert!(nonce2.is_some(), "Second CSP should contain nonce");
        assert_ne!(nonce1, nonce2, "Nonces should be unique across requests");
    }

    #[tokio::test]
    async fn test_response_body_preserved() {
        struct TestHandlerWithBody;
        #[async_trait]
        impl Handler for TestHandlerWithBody {
            async fn handle(&self, _request: Request) -> Result<Response> {
                Ok(Response::new(StatusCode::OK).with_body(Bytes::from("custom response content")))
            }
        }
        let response = run(&CspMiddleware::new(), "/page", Arc::new(TestHandlerWithBody)).await;
        assert!(response.headers.contains_key("Content-Security-Policy"));
        assert_eq!(response.body, Bytes::from("custom response content"));
    }

    #[rstest]
    fn test_nonce_is_valid_base64() {
        use base64::Engine;
        let nonce = CspMiddleware::new().generate_nonce();
        let decoded = base64::engine::general_purpose::STANDARD.decode(&nonce);
        assert!(
            decoded.is_ok(),
            "Nonce should be valid base64, got: {}",
            nonce
        );
    }

    #[rstest]
    fn test_nonce_length() {
        use base64::Engine;
        let nonce = CspMiddleware::new().generate_nonce();
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(&nonce)
            .unwrap();
        assert_eq!(
            decoded.len(),
            16,
            "Nonce should be exactly 16 bytes (128 bits)"
        );
    }

    #[rstest]
    fn test_is_valid_nonce_accepts_base64() {
        assert!(is_valid_nonce("YWJjZGVmZw=="));
        assert!(is_valid_nonce("abc123+/="));
        assert!(is_valid_nonce("ABCDEFGHIJKLMNOP"));
    }

    #[rstest]
    fn test_is_valid_nonce_rejects_invalid_chars() {
        assert!(!is_valid_nonce(""));
        assert!(!is_valid_nonce("abc\ndef"));
        assert!(!is_valid_nonce("abc;def"));
        assert!(!is_valid_nonce("abc def"));
        assert!(!is_valid_nonce("abc'def"));
        assert!(!is_valid_nonce("abc\rdef"));
    }

    #[rstest]
    fn test_build_csp_header_rejects_invalid_nonce() {
        let config = config_with(
            HashMap::from([("script-src".to_string(), sources(&["'self'"]))]),
            false,
            true,
        );
        let middleware = CspMiddleware::with_config(config);
        let csp = middleware.build_csp_header(Some("abc\r\ndef;injected"));
        assert!(
            !csp.contains("nonce-"),
            "Invalid nonce should not be embedded in header"
        );
        assert!(csp.contains("script-src 'self'"));
    }

    #[rstest]
    fn test_nonce_entropy() {
        let middleware = CspMiddleware::new();
        let nonces: HashSet<String> = (0..100).map(|_| middleware.generate_nonce()).collect();
        assert_eq!(
            nonces.len(),
            100,
            "All 100 nonces should be unique (statistical randomness)"
        );
    }

    #[tokio::test]
    async fn test_does_not_override_existing_csp_header() {
        struct HandlerWithCsp;
        #[async_trait]
        impl Handler for HandlerWithCsp {
            async fn handle(&self, _request: Request) -> Result<Response> {
                Ok(Response::new(StatusCode::OK).with_header(
                    "Content-Security-Policy",
                    "default-src 'self'; style-src 'self' 'unsafe-inline'",
                ))
            }
        }
        let response = run(&CspMiddleware::strict(), "/admin/", Arc::new(HandlerWithCsp)).await;
        let csp = csp_header(&response);
        assert!(
            csp.contains("'unsafe-inline'"),
            "Handler-set CSP should be preserved, got: {}",
            csp
        );
    }

    #[tokio::test]
    async fn test_does_not_override_existing_csp_report_only_header() {
        struct HandlerWithReportOnlyCsp;
        #[async_trait]
        impl Handler for HandlerWithReportOnlyCsp {
            async fn handle(&self, _request: Request) -> Result<Response> {
                Ok(Response::new(StatusCode::OK)
                    .with_header("Content-Security-Policy-Report-Only", "default-src 'none'"))
            }
        }
        let config = config_with(
            HashMap::from([("default-src".to_string(), sources(&["'self'"]))]),
            true,
            false,
        );
        let middleware = CspMiddleware::with_config(config);
        let response = run(&middleware, "/test", Arc::new(HandlerWithReportOnlyCsp)).await;
        let csp = response
            .headers
            .get("Content-Security-Policy-Report-Only")
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(
            csp, "default-src 'none'",
            "Handler-set report-only CSP should be preserved"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_exempt_path_skips_csp() {
        let config = CspConfig::strict().add_exempt_path("/admin".to_string());
        let middleware = CspMiddleware::with_config(config);
        let response = run(&middleware, "/admin/dashboard", Arc::new(TestHandler)).await;
        assert!(
            !response.headers.contains_key("Content-Security-Policy"),
            "CSP should not be set for exempt path"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_exempt_path_exact_match() {
        let config = CspConfig::strict().add_exempt_path("/admin".to_string());
        let middleware = CspMiddleware::with_config(config);
        let response = run(&middleware, "/admin", Arc::new(TestHandler)).await;
        assert!(
            !response.headers.contains_key("Content-Security-Policy"),
            "CSP should not be set for exact exempt path match"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_non_exempt_path_gets_csp() {
        let config = CspConfig::strict().add_exempt_path("/admin".to_string());
        let middleware = CspMiddleware::with_config(config);
        let response = run(&middleware, "/api/data", Arc::new(TestHandler)).await;
        assert!(
            response.headers.contains_key("Content-Security-Policy"),
            "CSP should be set for non-exempt path"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_exempt_path_boundary_prevents_false_match() {
        let config = CspConfig::strict().add_exempt_path("/admin".to_string());
        let middleware = CspMiddleware::with_config(config);
        let response = run(&middleware, "/administrator/panel", Arc::new(TestHandler)).await;
        assert!(
            response.headers.contains_key("Content-Security-Policy"),
            "/administrator should NOT be exempt when only /admin is in exempt_paths"
        );
    }

    #[rstest]
    fn test_csp_config_add_exempt_path() {
        let config = CspConfig::default()
            .add_exempt_path("/admin".to_string())
            .add_exempt_path("/static/admin".to_string());
        assert!(config.exempt_paths.contains("/admin"));
        assert!(config.exempt_paths.contains("/static/admin"));
        assert_eq!(config.exempt_paths.len(), 2);
    }

    #[rstest]
    #[tokio::test]
    async fn test_csp_header_applied_on_handler_error() {
        let config = config_with(
            HashMap::from([("default-src".to_string(), sources(&["'none'"]))]),
            false,
            false,
        );
        let middleware = CspMiddleware::with_config(config);
        let response = run(&middleware, "/test", Arc::new(ErrorHandler)).await;
        assert!(response.status.is_client_error() || response.status.is_server_error());
        assert!(
            response.headers.contains_key("Content-Security-Policy"),
            "CSP header should be applied even when handler returns an error"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_csp_exempt_path_error_converted_to_response() {
        let config = CspConfig::strict().add_exempt_path("/exempt".to_string());
        let middleware = CspMiddleware::with_config(config);
        let handler: Arc<dyn Handler> = Arc::new(ErrorHandler);
        let result = middleware
            .process(get_request("/exempt/resource"), handler)
            .await;
        assert!(
            result.is_ok(),
            "Handler error should be converted to response for exempt path"
        );
        let response = result.unwrap();
        assert!(response.status.is_client_error() || response.status.is_server_error());
    }

    #[rstest]
    #[tokio::test]
    async fn test_multiple_exempt_paths() {
        let config = CspConfig::strict()
            .add_exempt_path("/admin".to_string())
            .add_exempt_path("/static/admin".to_string());
        let middleware = CspMiddleware::with_config(config);
        let handler: Arc<dyn Handler> = Arc::new(TestHandler);
        for uri in ["/admin/dashboard", "/static/admin/style.css"] {
            let response = run(&middleware, uri, handler.clone()).await;
            assert!(
                !response.headers.contains_key("Content-Security-Policy"),
                "Path {} should be exempt from CSP",
                uri
            );
        }
    }
}