use std::sync::Arc;
use axum::body::Body;
use axum::extract::Request;
use axum::http::header::{HeaderName, HeaderValue};
use axum::http::Response;
use axum::middleware::Next;
use axum::Router;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use rand::RngCore;
/// Bare marker the middleware searches for in CSP headers and replaces with the
/// per-request nonce value.
const PLACEHOLDER_TOKEN: &str = "__RUSTANGO_NONCE__";
/// Ready-to-use CSP source expression (`'nonce-<token>'`) for handlers to embed in a
/// `Content-Security-Policy` or `Content-Security-Policy-Report-Only` header; the
/// embedded token is rewritten to the request's nonce on the way out.
pub const CSP_NONCE_PLACEHOLDER: &str = "'nonce-__RUSTANGO_NONCE__'";
/// The nonce generated for the current request.
///
/// The middleware inserts one into the request extensions, so handlers can read it
/// with `Extension<Nonce>`. Clones share the same string through an [`Arc`].
#[derive(Debug, Clone)]
pub struct Nonce {
    value: Arc<String>,
}

impl Nonce {
    /// Borrows the encoded nonce value (without the surrounding `'nonce-…'` wrapper).
    #[must_use]
    pub fn value(&self) -> &str {
        self.value.as_str()
    }

    /// Consumes this handle and returns the nonce as an owned `String`.
    ///
    /// The string is moved out when this is the only remaining handle; otherwise it
    /// is cloned from the shared allocation.
    #[must_use]
    pub fn into_string(self) -> String {
        match Arc::try_unwrap(self.value) {
            Ok(owned) => owned,
            Err(shared) => String::clone(&shared),
        }
    }
}
/// Configuration for the CSP nonce middleware.
///
/// Build it with [`CspNonceLayer::new`] / [`Default`] and optionally
/// [`CspNonceLayer::bytes`].
#[derive(Clone, Debug)]
pub struct CspNonceLayer {
    /// Number of random bytes per nonce, before base64 encoding.
    ///
    /// Prefer setting this through [`CspNonceLayer::bytes`], which clamps it to
    /// [`CspNonceLayer::MIN_BYTES`]`..=`[`CspNonceLayer::MAX_BYTES`]; assigning the
    /// field directly (e.g. in a struct literal) bypasses that clamp.
    pub bytes: usize,
}

impl Default for CspNonceLayer {
    /// Uses [`CspNonceLayer::DEFAULT_BYTES`] random bytes per nonce.
    fn default() -> Self {
        Self {
            bytes: Self::DEFAULT_BYTES,
        }
    }
}

impl CspNonceLayer {
    /// Smallest nonce size, in bytes, accepted by [`CspNonceLayer::bytes`].
    pub const MIN_BYTES: usize = 8;
    /// Largest nonce size, in bytes, accepted by [`CspNonceLayer::bytes`].
    pub const MAX_BYTES: usize = 64;
    /// Nonce size, in bytes, used by [`Default`] and [`CspNonceLayer::new`].
    pub const DEFAULT_BYTES: usize = 16;

    /// Creates a layer with the default nonce size.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the nonce size in bytes.
    ///
    /// Values outside `MIN_BYTES..=MAX_BYTES` are clamped into that range rather than
    /// rejected.
    #[must_use]
    pub fn bytes(mut self, n: usize) -> Self {
        self.bytes = n.clamp(Self::MIN_BYTES, Self::MAX_BYTES);
        self
    }
}
/// Extension trait for installing the CSP nonce middleware.
pub trait CspNonceRouterExt {
    /// Returns `self` wrapped in middleware configured by `layer`.
    ///
    /// The `Router` implementation in this module generates a fresh nonce per
    /// request, exposes it to handlers as an `Extension<Nonce>`, and substitutes it
    /// for the placeholder token in the response's CSP headers.
    #[must_use = "csp_nonce consumes the router and returns the wrapped one"]
    fn csp_nonce(self, layer: CspNonceLayer) -> Self;
}
impl<S: Clone + Send + Sync + 'static> CspNonceRouterExt for Router<S> {
    /// Adds an `axum::middleware::from_fn` layer that, for every request:
    ///
    /// 1. generates a fresh random nonce,
    /// 2. inserts it into the request extensions as [`Nonce`] (readable by handlers
    ///    via `Extension<Nonce>`),
    /// 3. after the inner service responds, substitutes it for the placeholder token
    ///    in the `Content-Security-Policy` / `-Report-Only` response headers.
    fn csp_nonce(self, layer: CspNonceLayer) -> Self {
        // `CspNonceLayer::bytes` is a public field, so a struct literal can skip the
        // builder's clamp (e.g. `bytes: 0` would render an empty `'nonce-'` token).
        // Route the configured size back through the builder so the clamp always holds.
        let requested = layer.bytes;
        let byte_len = layer.bytes(requested).bytes;

        // Only the `Copy` byte length is captured, so the closure is trivially
        // cloneable and no shared config needs to be cloned per request.
        self.layer(axum::middleware::from_fn(
            move |mut req: Request<Body>, next: Next| async move {
                let nonce = Nonce {
                    value: Arc::new(generate_nonce(byte_len)),
                };
                // The handler and the header rewrite must see the same value, so the
                // extension gets a clone of the handle kept for substitution below.
                req.extensions_mut().insert(nonce.clone());
                let mut response = next.run(req).await;
                substitute_nonce(&mut response, nonce.value());
                response
            },
        ))
    }
}
/// Produces `byte_len` random bytes from `rand::thread_rng()` and encodes them as
/// unpadded URL-safe base64, so the result contains only `A-Z`, `a-z`, `0-9`, `-`
/// and `_`.
fn generate_nonce(byte_len: usize) -> String {
    let mut raw = vec![0_u8; byte_len];
    rand::thread_rng().fill_bytes(raw.as_mut_slice());
    URL_SAFE_NO_PAD.encode(raw)
}
fn substitute_nonce(response: &mut Response<Body>, nonce: &str) {
for name in [
"content-security-policy",
"content-security-policy-report-only",
] {
let Ok(name) = HeaderName::try_from(name) else {
continue;
};
let Some(existing) = response.headers().get(&name) else {
continue;
};
let Ok(s) = existing.to_str() else { continue };
if !s.contains(PLACEHOLDER_TOKEN) {
continue;
}
let replaced = s.replace(PLACEHOLDER_TOKEN, nonce);
if let Ok(hv) = HeaderValue::from_str(&replaced) {
response.headers_mut().insert(name, hv);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::header::HeaderValue;
    use axum::http::Request;
    use axum::routing::get;
    use axum::Extension;
    use tower::ServiceExt;

    /// Drives a single `GET /` request through `app`.
    async fn get_root(app: Router) -> Response<Body> {
        app.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
    }

    /// Returns the first value of header `name`, panicking if it is missing or not ASCII.
    fn header_str(resp: &Response<Body>, name: &str) -> String {
        resp.headers()
            .get(name)
            .unwrap_or_else(|| panic!("missing `{name}` header"))
            .to_str()
            .unwrap()
            .to_owned()
    }

    /// Buffers the response body (up to 64 KiB) into a UTF-8 string.
    async fn body_string(resp: Response<Body>) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), 1 << 16)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    /// Handler that echoes back the nonce the middleware put in the request extensions.
    async fn echo_nonce(Extension(nonce): Extension<Nonce>) -> String {
        nonce.value().to_owned()
    }

    #[test]
    fn placeholder_format_includes_nonce_prefix_and_quotes() {
        assert_eq!(CSP_NONCE_PLACEHOLDER, "'nonce-__RUSTANGO_NONCE__'");
        assert!(CSP_NONCE_PLACEHOLDER.contains(PLACEHOLDER_TOKEN));
    }

    #[test]
    fn nonce_bytes_is_clamped() {
        assert_eq!(CspNonceLayer::new().bytes(0).bytes, 8);
        assert_eq!(CspNonceLayer::new().bytes(1000).bytes, 64);
        assert_eq!(CspNonceLayer::new().bytes(20).bytes, 20);
    }

    #[test]
    fn generated_nonce_is_url_safe_base64() {
        let n = generate_nonce(16);
        assert!(n
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
        // Unpadded base64 length is ceil(4 * bytes / 3).
        assert_eq!(n.len(), 22);
        assert_eq!(generate_nonce(8).len(), 11);
        assert_eq!(generate_nonce(64).len(), 86);
    }

    #[test]
    fn each_nonce_is_unique() {
        let a = generate_nonce(16);
        let b = generate_nonce(16);
        assert_ne!(a, b);
    }

    #[test]
    fn nonce_into_string_works_for_shared_and_unique_handles() {
        let nonce = Nonce {
            value: Arc::new("abc".to_owned()),
        };
        let shared = nonce.clone();
        // `shared` is not the last handle here, so this exercises the clone path...
        assert_eq!(shared.into_string(), "abc");
        // ...and `nonce` is now the last handle, exercising the move-out path.
        assert_eq!(nonce.into_string(), "abc");
    }

    #[tokio::test]
    async fn handler_can_read_nonce_via_extension() {
        let app = Router::new()
            .route("/", get(echo_nonce))
            .csp_nonce(CspNonceLayer::default());
        let body = body_string(get_root(app).await).await;
        // 16 random bytes encode to 22 unpadded base64 characters.
        assert_eq!(body.len(), 22);
    }

    #[tokio::test]
    async fn configured_byte_length_controls_nonce_length() {
        let app = Router::new()
            .route("/", get(echo_nonce))
            .csp_nonce(CspNonceLayer::new().bytes(32));
        let body = body_string(get_root(app).await).await;
        // 32 random bytes encode to ceil(128 / 3) = 43 unpadded base64 characters.
        assert_eq!(body.len(), 43);
    }

    #[tokio::test]
    async fn nonce_substituted_into_csp_header() {
        async fn h() -> ([(&'static str, String); 1], &'static str) {
            (
                [(
                    "content-security-policy",
                    format!("script-src 'self' {CSP_NONCE_PLACEHOLDER}"),
                )],
                "ok",
            )
        }
        let app = Router::new()
            .route("/", get(h))
            .csp_nonce(CspNonceLayer::default());
        let csp = header_str(&get_root(app).await, "content-security-policy");
        assert!(
            !csp.contains(PLACEHOLDER_TOKEN),
            "placeholder should be replaced"
        );
        assert!(csp.contains("'nonce-"), "rendered nonce should be present");
    }

    #[tokio::test]
    async fn nonce_substituted_into_report_only_csp_too() {
        async fn h() -> ([(&'static str, String); 1], &'static str) {
            (
                [(
                    "content-security-policy-report-only",
                    format!("script-src {CSP_NONCE_PLACEHOLDER}"),
                )],
                "ok",
            )
        }
        let app = Router::new()
            .route("/", get(h))
            .csp_nonce(CspNonceLayer::default());
        let csp = header_str(&get_root(app).await, "content-security-policy-report-only");
        assert!(!csp.contains(PLACEHOLDER_TOKEN));
        assert!(csp.contains("'nonce-"), "rendered nonce should be present");
    }

    #[tokio::test]
    async fn csp_without_placeholder_is_untouched() {
        async fn h() -> ([(&'static str, &'static str); 1], &'static str) {
            ([("content-security-policy", "script-src 'self'")], "ok")
        }
        let app = Router::new()
            .route("/", get(h))
            .csp_nonce(CspNonceLayer::default());
        let csp = header_str(&get_root(app).await, "content-security-policy");
        assert_eq!(csp, "script-src 'self'");
    }

    #[tokio::test]
    async fn csp_substitutes_consistently_with_handler_nonce() {
        async fn h(
            Extension(nonce): Extension<Nonce>,
        ) -> ([(HeaderName, HeaderValue); 1], String) {
            let csp = format!("script-src {CSP_NONCE_PLACEHOLDER}");
            (
                [(
                    HeaderName::from_static("content-security-policy"),
                    HeaderValue::from_str(&csp).unwrap(),
                )],
                format!("nonce={}", nonce.value()),
            )
        }
        let app = Router::new()
            .route("/", get(h))
            .csp_nonce(CspNonceLayer::default());
        let resp = get_root(app).await;
        let csp = header_str(&resp, "content-security-policy");
        let body = body_string(resp).await;
        let body_nonce = body.strip_prefix("nonce=").unwrap();
        assert!(
            csp.contains(&format!("'nonce-{body_nonce}'")),
            "header CSP should embed the same nonce the handler saw\nCSP: {csp}\nbody: {body}"
        );
    }
}