use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use axum::extract::FromRequestParts;
use axum::http::{Request, Response, StatusCode};
use http::header::HeaderName;
use tower::{Layer, Service};
use uuid::Uuid;
use super::config::CsrfConfig;
/// Plain-text body sent with the `403 Forbidden` response when a request
/// that requires CSRF validation does not carry a matching token.
const CSRF_FORBIDDEN_MESSAGE: &str = "CSRF token missing or invalid";
/// The CSRF token associated with the current request.
///
/// `CsrfService` stores one in the request extensions for every request it
/// handles. It is either the value of the existing CSRF cookie or a freshly
/// generated UUID v4. Handlers obtain it through the `FromRequestParts`
/// extractor and typically embed it in forms or pass it to client code.
#[derive(Clone, Debug)]
pub struct CsrfToken(String);

impl CsrfToken {
    /// Borrows the raw token string.
    #[must_use]
    pub fn token(&self) -> &str {
        self.0.as_str()
    }
}

impl std::fmt::Display for CsrfToken {
    /// Writes the bare token, so `to_string()` equals [`CsrfToken::token`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.token())
    }
}
/// Extracts the [`CsrfToken`] that `CsrfService` placed in the request
/// extensions.
///
/// # Errors
///
/// Rejects with `500 Internal Server Error` when no token is present. Most
/// likely the route is not wrapped by `CsrfLayer`, which is a server-side
/// configuration problem rather than a client error.
impl<S> FromRequestParts<S> for CsrfToken
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, &'static str);

    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        let Some(stored) = parts.extensions.get::<Self>() else {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                "CSRF token not found in request extensions. Is CsrfLayer enabled?",
            ));
        };
        Ok(stored.clone())
    }
}
/// Pre-parsed CSRF settings, built once in `CsrfLayer::from_config` and
/// shared through an `Arc` by the layer and every service it creates.
#[derive(Debug, Clone)]
struct CsrfSettings {
/// Name of the cookie that holds the CSRF token.
cookie_name: String,
/// Request header checked first for the submitted token.
token_header: HeaderName,
/// Field looked up in `application/x-www-form-urlencoded` bodies when the
/// header does not carry a matching token.
form_field: String,
/// Methods that skip validation. Only entries that parsed as valid HTTP
/// methods are kept.
safe_methods: Vec<http::Method>,
/// Path prefixes that skip validation. Matched with `str::starts_with`, so
/// `"/api"` also exempts `"/apikeys"`; use a trailing `/` to restrict the
/// match to a path segment.
exempt_paths: Vec<String>,
}
/// Tower [`Layer`] that wraps a service in [`CsrfService`], applying
/// double-submit-cookie CSRF protection. Build it with
/// [`CsrfLayer::from_config`].
#[derive(Clone, Debug)]
pub struct CsrfLayer {
/// Parsed settings, shared with every service this layer produces.
settings: Arc<CsrfSettings>,
}
impl CsrfLayer {
    /// Builds a layer from the application's [`CsrfConfig`].
    ///
    /// Parsing is lenient and never fails:
    /// - entries of `safe_methods` that are not valid HTTP method tokens are
    ///   silently dropped;
    /// - an invalid `token_header` falls back to `x-csrf-token`.
    ///
    /// NOTE(review): silently ignoring bad entries can hide configuration
    /// mistakes; consider surfacing them at startup.
    #[must_use]
    pub fn from_config(config: &CsrfConfig) -> Self {
        let mut safe_methods: Vec<http::Method> = Vec::new();
        for raw in config.safe_methods.iter() {
            if let Ok(method) = raw.parse::<http::Method>() {
                safe_methods.push(method);
            }
        }

        let token_header = match config.token_header.parse::<HeaderName>() {
            Ok(name) => name,
            Err(_) => HeaderName::from_static("x-csrf-token"),
        };

        let settings = CsrfSettings {
            cookie_name: config.cookie_name.clone(),
            token_header,
            form_field: config.form_field.clone(),
            safe_methods,
            exempt_paths: config.exempt_paths.clone(),
        };
        Self {
            settings: Arc::new(settings),
        }
    }
}
impl<S> Layer<S> for CsrfLayer {
    type Service = CsrfService<S>;

    /// Wraps `inner`. The service gets a reference-counted handle to the
    /// layer's settings rather than a deep copy.
    fn layer(&self, inner: S) -> Self::Service {
        let settings = Arc::clone(&self.settings);
        CsrfService { inner, settings }
    }
}
/// Middleware produced by [`CsrfLayer`]. It issues the CSRF cookie, checks
/// tokens on requests that are neither safe nor exempt, and delegates to
/// `inner` when the check passes.
#[derive(Clone, Debug)]
pub struct CsrfService<S> {
/// The wrapped service, called only for requests that pass the CSRF check.
inner: S,
/// Settings shared with the originating layer.
settings: Arc<CsrfSettings>,
}
use subtle::{Choice, ConstantTimeEq};
#[inline(never)]
fn constant_time_eq(a: &str, b: &str) -> bool {
let a = a.as_bytes();
let b = b.as_bytes();
let len_eq = a.len().ct_eq(&b.len());
let mut bytes_eq = Choice::from(1u8);
for (i, &a_byte) in a.iter().enumerate() {
let b_byte = *b.get(i).unwrap_or(&0xFF);
bytes_eq &= a_byte.ct_eq(&b_byte);
}
(len_eq & bytes_eq).into()
}
/// Returns the value of the cookie named `cookie_name`.
///
/// Returns `None` when the cookie is absent or appears more than once.
///
/// Parsing rules:
/// - Every `Cookie` header is scanned, in order.
/// - Header values that are not visible ASCII are skipped.
/// - Each `;`-separated entry without an `=` is ignored.
/// - Names and values are whitespace-trimmed.
///
/// A duplicated cookie is treated as ambiguous rather than picking one, so an
/// extra cookie (for example one tossed in from a sibling subdomain) cannot
/// silently shadow the genuine value.
fn extract_cookie_token(req_headers: &http::HeaderMap, cookie_name: &str) -> Option<String> {
    let mut values = req_headers
        .get_all(http::header::COOKIE)
        .iter()
        .filter_map(|header| header.to_str().ok())
        .flat_map(|line| line.split(';'))
        .filter_map(|entry| entry.trim().split_once('='))
        .filter(|(name, _)| name.trim() == cookie_name)
        .map(|(_, value)| value.trim());

    let first = values.next()?;
    if values.next().is_some() {
        // Ambiguous: the same cookie name was sent more than once.
        return None;
    }
    Some(first.to_owned())
}
impl<S, ResBody> Service<Request<axum::body::Body>> for CsrfService<S>
where
    S: Service<Request<axum::body::Body>, Response = Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
    ResBody: From<&'static str> + Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Applies CSRF protection to one request and forwards it to the inner
    /// service.
    ///
    /// 1. Requests whose path starts with an exempt prefix, or whose method
    ///    is configured as safe, are not validated.
    /// 2. The CSRF cookie's value is reused as the request token. If the
    ///    cookie is absent, duplicated or empty, a new UUID v4 token is
    ///    generated and a `Set-Cookie` header carrying it is appended to the
    ///    inner service's response.
    /// 3. The token is inserted into the request extensions for the
    ///    [`CsrfToken`] extractor.
    /// 4. Requests that must be validated need to echo the cookie token in
    ///    the configured header or form field. Otherwise a plain-text
    ///    `403 Forbidden` is returned and the inner service is not called.
    fn call(&mut self, mut req: Request<axum::body::Body>) -> Self::Future {
        let path = req.uri().path();
        let is_exempt = self
            .settings
            .exempt_paths
            .iter()
            .any(|prefix| path.starts_with(prefix.as_str()));
        let is_safe = is_exempt || self.settings.safe_methods.contains(req.method());

        // Verification always rejects an empty token, so an empty cookie
        // (`name=`) can never be used successfully. Treat it like a missing
        // cookie and issue a fresh token. Otherwise the client keeps the
        // unusable empty value and every unsafe request fails.
        let cookie_token = extract_cookie_token(req.headers(), &self.settings.cookie_name)
            .filter(|existing| !existing.is_empty());
        let token = cookie_token
            .clone()
            .unwrap_or_else(|| Uuid::new_v4().to_string());
        req.extensions_mut().insert(CsrfToken(token.clone()));
        let set_cookie = cookie_token.is_none().then(|| {
            format!(
                "{}={}; Path=/; SameSite=Lax; HttpOnly",
                self.settings.cookie_name, token
            )
        });

        let settings = Arc::clone(&self.settings);
        // Take the instance that `poll_ready` drove to readiness and leave a
        // fresh clone in its place (the clone-and-swap pattern from tower's
        // `Service` documentation).
        let mut inner = self.inner.clone();
        std::mem::swap(&mut self.inner, &mut inner);

        Box::pin(async move {
            if !is_safe && !verify_csrf_token(&mut req, &settings, cookie_token.as_deref()).await {
                return Ok(csrf_forbidden_response::<ResBody>());
            }
            let mut response = inner.call(req).await?;
            if let Some(cookie) = set_cookie {
                // Best effort: a cookie name that is not a valid header value
                // means no cookie is set, but the response still succeeds.
                if let Ok(val) = http::header::HeaderValue::from_str(&cookie) {
                    response.headers_mut().append(http::header::SET_COOKIE, val);
                }
            }
            Ok(response)
        })
    }
}

/// Builds the plain-text `403 Forbidden` response returned when CSRF
/// validation fails.
fn csrf_forbidden_response<B: From<&'static str>>() -> Response<B> {
    let mut response = Response::new(B::from(CSRF_FORBIDDEN_MESSAGE));
    *response.status_mut() = StatusCode::FORBIDDEN;
    response.headers_mut().insert(
        http::header::CONTENT_TYPE,
        http::HeaderValue::from_static("text/plain; charset=utf-8"),
    );
    response
}
async fn verify_csrf_token(
req: &mut Request<axum::body::Body>,
settings: &CsrfSettings,
cookie_token: Option<&str>,
) -> bool {
let mut token_found = false;
let header_token = req
.headers()
.get(&settings.token_header)
.and_then(|v| v.to_str().ok());
if let (Some(c), Some(h)) = (cookie_token, header_token) {
if !c.is_empty() && !h.is_empty() && constant_time_eq(c, h) {
token_found = true;
}
}
if token_found {
return true;
}
let content_type = req
.headers()
.get(http::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or_default();
if !content_type.starts_with("application/x-www-form-urlencoded") {
return false;
}
let body = std::mem::replace(req.body_mut(), axum::body::Body::empty());
let bytes = axum::body::to_bytes(body, 2 * 1024 * 1024)
.await
.unwrap_or_else(|_| axum::body::Bytes::new());
if let Ok(body_str) = std::str::from_utf8(&bytes) {
for pair in body_str.split('&') {
if let Some((key, value)) = pair.split_once('=') {
if key == settings.form_field {
if let Some(c) = cookie_token {
if !c.is_empty() && !value.is_empty() && constant_time_eq(c, value) {
token_found = true;
}
}
break;
}
}
}
}
*req.body_mut() = axum::body::Body::from(bytes);
token_found
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::routing::{get, post};
    use tower::ServiceExt;

    /// CSRF config with protection enabled and all other options at defaults.
    fn default_csrf_config() -> CsrfConfig {
        CsrfConfig {
            enabled: true,
            ..Default::default()
        }
    }

    /// `GET /` answering `"ok"`, behind the default CSRF layer.
    fn get_app() -> Router {
        Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(CsrfLayer::from_config(&default_csrf_config()))
    }

    /// `POST /submit` answering `"created"`, behind the default CSRF layer.
    fn submit_app() -> Router {
        Router::new()
            .route("/submit", post(|| async { "created" }))
            .layer(CsrfLayer::from_config(&default_csrf_config()))
    }

    /// A bodyless `GET /` request.
    fn get_root() -> Request<Body> {
        Request::builder()
            .method("GET")
            .uri("/")
            .body(Body::empty())
            .unwrap()
    }

    /// Request builder pre-filled with `POST /submit`.
    fn post_submit() -> axum::http::request::Builder {
        Request::builder().method("POST").uri("/submit")
    }

    /// Drives a single request through `app` and returns its response.
    async fn send(app: Router, request: Request<Body>) -> Response<Body> {
        app.oneshot(request).await.unwrap()
    }

    /// Header map holding one `Cookie` header per entry of `values`.
    fn cookie_headers(values: &[&str]) -> http::HeaderMap {
        let mut headers = http::HeaderMap::new();
        for value in values {
            headers.append(http::header::COOKIE, value.parse().unwrap());
        }
        headers
    }

    #[tokio::test]
    async fn safe_method_passes_without_token() {
        let response = send(get_app(), get_root()).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn safe_method_sets_csrf_cookie() {
        let response = send(get_app(), get_root()).await;
        let set_cookie = response
            .headers()
            .get("set-cookie")
            .and_then(|v| v.to_str().ok())
            .expect("Set-Cookie header with a visible-ASCII value");
        assert!(set_cookie.starts_with("autumn-csrf="));
        assert!(set_cookie.contains("HttpOnly"));
    }

    #[tokio::test]
    async fn post_without_token_returns_403() {
        let request = post_submit().body(Body::empty()).unwrap();
        let response = send(submit_app(), request).await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn forbidden_response_has_clear_error_body() {
        let request = post_submit().body(Body::empty()).unwrap();
        let response = send(submit_app(), request).await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
        let content_type = response
            .headers()
            .get(http::header::CONTENT_TYPE)
            .map(|v| v.to_str().unwrap_or_default());
        assert_eq!(content_type, Some("text/plain; charset=utf-8"));
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let text = std::str::from_utf8(&body).unwrap();
        assert!(
            text.contains("CSRF"),
            "expected CSRF error message, got: {text:?}"
        );
    }

    #[tokio::test]
    async fn exempt_path_skips_csrf_validation() {
        let config = CsrfConfig {
            enabled: true,
            exempt_paths: vec!["/api/".to_string()],
            ..Default::default()
        };
        let app = Router::new()
            .route("/api/items", post(|| async { "created" }))
            .route("/form/submit", post(|| async { "created" }))
            .layer(CsrfLayer::from_config(&config));

        let exempt = Request::builder()
            .method("POST")
            .uri("/api/items")
            .body(Body::empty())
            .unwrap();
        assert_eq!(send(app.clone(), exempt).await.status(), StatusCode::OK);

        let protected = Request::builder()
            .method("POST")
            .uri("/form/submit")
            .body(Body::empty())
            .unwrap();
        assert_eq!(send(app, protected).await.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn post_with_valid_token_passes() {
        let token = Uuid::new_v4().to_string();
        let request = post_submit()
            .header("Cookie", format!("autumn-csrf={token}"))
            .header("X-CSRF-Token", &token)
            .body(Body::empty())
            .unwrap();
        assert_eq!(send(submit_app(), request).await.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn post_with_mismatched_token_returns_403() {
        let cookie_token = Uuid::new_v4().to_string();
        let header_token = Uuid::new_v4().to_string();
        let request = post_submit()
            .header("Cookie", format!("autumn-csrf={cookie_token}"))
            .header("X-CSRF-Token", &header_token)
            .body(Body::empty())
            .unwrap();
        assert_eq!(
            send(submit_app(), request).await.status(),
            StatusCode::FORBIDDEN
        );
    }

    #[tokio::test]
    async fn csrf_token_extractor_works() {
        async fn handler(csrf: CsrfToken) -> String {
            csrf.token().to_owned()
        }
        let app = Router::new()
            .route("/", get(handler))
            .layer(CsrfLayer::from_config(&default_csrf_config()));
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let response = send(app, request).await;
        assert_eq!(response.status(), StatusCode::OK);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let token_str = String::from_utf8(body.to_vec()).unwrap();
        assert!(Uuid::parse_str(&token_str).is_ok());
    }

    #[test]
    fn extract_cookie_from_header() {
        let headers = cookie_headers(&["autumn-csrf=abc123; other=xyz"]);
        assert_eq!(
            extract_cookie_token(&headers, "autumn-csrf"),
            Some("abc123".to_owned())
        );
    }

    #[test]
    fn missing_cookie_returns_none() {
        let headers = cookie_headers(&[]);
        assert_eq!(extract_cookie_token(&headers, "autumn-csrf"), None);
    }

    #[test]
    fn extract_cookie_rejects_multiple_cookies() {
        let same_header = cookie_headers(&["autumn-csrf=abc123; autumn-csrf=xyz456"]);
        assert_eq!(extract_cookie_token(&same_header, "autumn-csrf"), None);

        let split_headers = cookie_headers(&["autumn-csrf=abc123", "autumn-csrf=xyz456"]);
        assert_eq!(extract_cookie_token(&split_headers, "autumn-csrf"), None);
    }

    #[test]
    fn extract_cookie_ignores_malformed_cookies() {
        let no_equals = cookie_headers(&["autumn-csrf abc123"]);
        assert_eq!(extract_cookie_token(&no_equals, "autumn-csrf"), None);

        let padded = cookie_headers(&[" autumn-csrf = abc123 ; other=xyz"]);
        assert_eq!(
            extract_cookie_token(&padded, "autumn-csrf"),
            Some("abc123".to_owned())
        );
    }

    #[test]
    fn test_constant_time_eq() {
        let equal_cases = [("abc", "abc"), ("", "")];
        for (a, b) in equal_cases {
            assert!(super::constant_time_eq(a, b), "{a:?} vs {b:?}");
        }
        let unequal_cases = [("abc", "ab"), ("abc", "abd"), ("a", "b"), ("a", "A")];
        for (a, b) in unequal_cases {
            assert!(!super::constant_time_eq(a, b), "{a:?} vs {b:?}");
        }
    }

    #[tokio::test]
    async fn post_with_empty_cookie_but_valid_header() {
        let token = Uuid::new_v4().to_string();
        let request = post_submit()
            .header("Cookie", "autumn-csrf=")
            .header("X-CSRF-Token", &token)
            .body(Body::empty())
            .unwrap();
        assert_eq!(
            send(submit_app(), request).await.status(),
            StatusCode::FORBIDDEN
        );
    }

    #[tokio::test]
    async fn post_with_valid_cookie_but_empty_header() {
        let token = Uuid::new_v4().to_string();
        let request = post_submit()
            .header("Cookie", format!("autumn-csrf={token}"))
            .header("X-CSRF-Token", "")
            .body(Body::empty())
            .unwrap();
        assert_eq!(
            send(submit_app(), request).await.status(),
            StatusCode::FORBIDDEN
        );
    }

    #[tokio::test]
    async fn post_with_empty_cookie_but_valid_form_field() {
        let token = Uuid::new_v4().to_string();
        let request = post_submit()
            .header("Cookie", "autumn-csrf=")
            .header("Content-Type", "application/x-www-form-urlencoded")
            .body(Body::from(format!("_csrf={token}")))
            .unwrap();
        assert_eq!(
            send(submit_app(), request).await.status(),
            StatusCode::FORBIDDEN
        );
    }

    #[tokio::test]
    async fn post_with_valid_cookie_but_empty_form_field() {
        let token = Uuid::new_v4().to_string();
        let request = post_submit()
            .header("Cookie", format!("autumn-csrf={token}"))
            .header("Content-Type", "application/x-www-form-urlencoded")
            .body(Body::from("_csrf="))
            .unwrap();
        assert_eq!(
            send(submit_app(), request).await.status(),
            StatusCode::FORBIDDEN
        );
    }

    #[tokio::test]
    async fn post_with_large_body_fails_csrf() {
        let token = Uuid::new_v4().to_string();
        // Just over the 2 MiB buffering limit used for form bodies.
        let large_padding = "a".repeat(2 * 1024 * 1024 + 10);
        let body_content = format!("_csrf={token}&pad={large_padding}");
        let request = post_submit()
            .header("Cookie", format!("autumn-csrf={token}"))
            .header("Content-Type", "application/x-www-form-urlencoded")
            .body(Body::from(body_content))
            .unwrap();
        assert_eq!(
            send(submit_app(), request).await.status(),
            StatusCode::FORBIDDEN
        );
    }

    #[tokio::test]
    async fn post_with_empty_tokens_returns_403() {
        let request = post_submit()
            .header("Cookie", "autumn-csrf=")
            .header("X-CSRF-Token", "")
            .body(Body::empty())
            .unwrap();
        assert_eq!(
            send(submit_app(), request).await.status(),
            StatusCode::FORBIDDEN
        );
    }

    #[tokio::test]
    async fn post_with_empty_form_tokens_returns_403() {
        let request = post_submit()
            .header("Cookie", "autumn-csrf=")
            .header("Content-Type", "application/x-www-form-urlencoded")
            .body(Body::from("_csrf="))
            .unwrap();
        assert_eq!(
            send(submit_app(), request).await.status(),
            StatusCode::FORBIDDEN
        );
    }

    #[test]
    fn from_config_filters_invalid_methods() {
        let config = CsrfConfig {
            safe_methods: vec![
                "GET".to_string(),
                "INVALID METHOD".to_string(),
                "POST".to_string(),
            ],
            ..Default::default()
        };
        let layer = CsrfLayer::from_config(&config);
        let methods = &layer.settings.safe_methods;
        assert_eq!(methods.len(), 2);
        assert!(methods.contains(&http::Method::GET));
        assert!(methods.contains(&http::Method::POST));
    }

    #[test]
    fn from_config_handles_invalid_header_name() {
        let config = CsrfConfig {
            token_header: "Invalid Header Name\n".to_string(),
            ..Default::default()
        };
        let layer = CsrfLayer::from_config(&config);
        assert_eq!(layer.settings.token_header.as_str(), "x-csrf-token");
    }
}