use axum::{
body::Body,
http::{header, Method, Request, Response},
response::IntoResponse,
};
use rand::{distributions::Alphanumeric, rngs::OsRng, Rng};
use tower::{Layer, Service};
use crate::config::CookieConfig;
use crate::errors::AppError;
use crate::utils::{extract_cookie, is_valid_cookie_domain};
/// Name of the cookie in which this middleware stores the CSRF token.
const CSRF_COOKIE_NAME: &str = "XSRF-TOKEN";
/// Request header in which clients echo the CSRF token back.
const CSRF_HEADER_NAME: &str = "x-csrf-token";
/// `Max-Age` of the CSRF cookie: one day, expressed in seconds.
const CSRF_COOKIE_MAX_AGE_SECS: u64 = 24 * 60 * 60;
/// Tower [`Layer`] that wraps an inner service in a [`CsrfService`], which
/// checks that the CSRF cookie matches the CSRF header on unsafe requests.
///
/// Every wrapped service gets its own copy of the cookie settings.
#[derive(Clone)]
pub struct CsrfLayer {
    cookie_config: CookieConfig,
}

impl CsrfLayer {
    /// Creates a layer whose CSRF checks and cookies follow `cookie_config`.
    pub fn new(cookie_config: CookieConfig) -> Self {
        CsrfLayer { cookie_config }
    }
}

impl<S> Layer<S> for CsrfLayer {
    type Service = CsrfService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        let cookie_config = self.cookie_config.clone();
        CsrfService {
            inner,
            cookie_config,
        }
    }
}
#[derive(Clone)]
pub struct CsrfService<S> {
inner: S,
cookie_config: CookieConfig,
}
impl<S> Service<Request<Body>> for CsrfService<S>
where
S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
S::Future: Send,
{
type Response = Response<Body>;
type Error = S::Error;
type Future = std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let cookie_config = self.cookie_config.clone();
let mut inner = self.inner.clone();
Box::pin(async move {
if !cookie_config.enabled {
return inner.call(req).await;
}
let method = req.method().clone();
let headers = req.headers();
let _has_auth_header = headers.get(header::AUTHORIZATION).is_some();
let csrf_cookie = extract_cookie(headers, CSRF_COOKIE_NAME)
.or_else(|| extract_cookie(headers, "csrf-token"));
let csrf_header = headers
.get(CSRF_HEADER_NAME)
.and_then(|v| v.to_str().ok())
.map(str::to_string);
let has_auth_cookie = extract_cookie(headers, &cookie_config.access_cookie_name)
.is_some()
|| extract_cookie(headers, &cookie_config.refresh_cookie_name).is_some();
let is_safe_method = matches!(method, Method::GET | Method::HEAD | Method::OPTIONS);
if has_auth_cookie && !is_safe_method {
match (csrf_cookie.as_deref(), csrf_header.as_deref()) {
(Some(cookie), Some(header)) if cookie == header => {}
_ => {
let mut response =
AppError::Forbidden("Invalid or missing CSRF token".into())
.into_response();
if csrf_cookie.is_none() {
let token = generate_token();
if let Ok(value) = header::HeaderValue::from_str(&build_csrf_cookie(
&cookie_config,
&token,
)) {
response.headers_mut().append(header::SET_COOKIE, value);
}
}
return Ok(response);
}
}
}
let mut response = inner.call(req).await?;
if csrf_cookie.is_none() {
let token = generate_token();
if let Ok(value) =
header::HeaderValue::from_str(&build_csrf_cookie(&cookie_config, &token))
{
response.headers_mut().append(header::SET_COOKIE, value);
}
}
Ok(response)
})
}
}
/// Creates a new random CSRF token.
///
/// The token is 44 characters sampled with `Alphanumeric` from the OS random
/// number generator (`OsRng`). With 62 possible characters that is about
/// 44 * log2(62) ≈ 262 bits of entropy.
fn generate_token() -> String {
    const TOKEN_LEN: usize = 44;
    let chars = OsRng.sample_iter(&Alphanumeric).map(char::from);
    let mut token = String::with_capacity(TOKEN_LEN);
    token.extend(chars.take(TOKEN_LEN));
    token
}
fn build_csrf_cookie(config: &CookieConfig, token: &str) -> String {
let path = csrf_cookie_path(config);
let mut cookie = format!(
"{}={}; Path={}; Max-Age={}",
CSRF_COOKIE_NAME, token, path, CSRF_COOKIE_MAX_AGE_SECS
);
if config.secure {
cookie.push_str("; Secure");
}
match config.same_site.to_lowercase().as_str() {
"strict" => cookie.push_str("; SameSite=Strict"),
"none" => cookie.push_str("; SameSite=None"),
_ => cookie.push_str("; SameSite=Lax"),
}
if let Some(ref domain) = config.domain {
if is_valid_cookie_domain(domain) {
cookie.push_str(&format!("; Domain={}", domain));
} else {
tracing::warn!(
domain = %domain,
"Invalid CSRF cookie domain format, skipping Domain attribute"
);
}
}
cookie
}
/// Returns the `Path` attribute for the CSRF cookie, derived from
/// `config.path_prefix`.
///
/// - Trailing slashes are stripped (`"/auth/"` becomes `"/auth"`).
/// - An empty prefix maps to `"/"`.
/// - A leading `/` is added when missing. RFC 6265 §5.2.4 makes user agents
///   ignore a `Path` that does not start with `/` and use the request's default
///   path instead, which would silently scope the cookie elsewhere.
fn csrf_cookie_path(config: &CookieConfig) -> String {
    let trimmed = config.path_prefix.trim_end_matches('/');
    if trimmed.starts_with('/') {
        trimmed.to_string()
    } else {
        // Covers both the empty prefix (yielding "/") and "auth"-style prefixes.
        format!("/{}", trimmed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;
    use axum::{routing::get, Router};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    fn cookie_config() -> CookieConfig {
        CookieConfig {
            enabled: true,
            domain: None,
            secure: false,
            same_site: "lax".to_string(),
            access_cookie_name: "cedros_access".to_string(),
            refresh_cookie_name: "cedros_refresh".to_string(),
            path_prefix: "".to_string(),
        }
    }

    /// Sends `request` to a router that answers `"ok"` to GET and POST on `/`,
    /// wrapped in a `CsrfLayer` built from `config`.
    async fn send(config: CookieConfig, request: Request<Body>) -> Response<Body> {
        Router::new()
            .route("/", get(|| async { "ok" }).post(|| async { "ok" }))
            .layer(CsrfLayer::new(config))
            .oneshot(request)
            .await
            .unwrap()
    }

    /// Builds a bodiless request to `/` with the given method and headers.
    fn request(method: Method, headers: &[(&str, &str)]) -> Request<Body> {
        let mut builder = Request::builder().uri("/").method(method);
        for (name, value) in headers {
            builder = builder.header(*name, *value);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Returns the first `Set-Cookie` header of `response`, if any.
    fn set_cookie(response: &Response<Body>) -> Option<&str> {
        response
            .headers()
            .get(header::SET_COOKIE)
            .map(|value| value.to_str().unwrap())
    }

    /// Collects the response body as UTF-8 text.
    async fn body_text(response: Response<Body>) -> String {
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_sets_csrf_cookie_on_safe_request() {
        let response = send(cookie_config(), request(Method::GET, &[])).await;
        assert!(set_cookie(&response).unwrap().contains("XSRF-TOKEN="));
    }

    #[tokio::test]
    async fn test_csrf_cookie_path_respects_prefix() {
        let mut config = cookie_config();
        config.path_prefix = "/auth".to_string();
        let response = send(config, request(Method::GET, &[])).await;
        assert!(set_cookie(&response).unwrap().contains("Path=/auth"));
    }

    #[tokio::test]
    async fn test_blocks_missing_csrf_on_unsafe_request() {
        // CSRF cookie present, header missing.
        let response = send(
            cookie_config(),
            request(
                Method::POST,
                &[("cookie", "cedros_access=session123; XSRF-TOKEN=abc123")],
            ),
        )
        .await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
        assert!(body_text(response)
            .await
            .contains("Invalid or missing CSRF token"));
    }

    #[tokio::test]
    async fn test_blocks_missing_csrf_cookie_on_unsafe_request() {
        // CSRF header present, cookie missing.
        let response = send(
            cookie_config(),
            request(
                Method::POST,
                &[
                    (CSRF_HEADER_NAME, "abc123"),
                    ("cookie", "cedros_access=session123"),
                ],
            ),
        )
        .await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_missing_csrf_sets_cookie_on_forbidden() {
        let response = send(
            cookie_config(),
            request(Method::POST, &[("cookie", "cedros_access=session123")]),
        )
        .await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
        assert!(set_cookie(&response).unwrap().contains("XSRF-TOKEN="));
    }

    #[tokio::test]
    async fn test_allows_valid_csrf_on_unsafe_request() {
        let response = send(
            cookie_config(),
            request(
                Method::POST,
                &[
                    ("cookie", "cedros_access=session123; XSRF-TOKEN=abc123"),
                    (CSRF_HEADER_NAME, "abc123"),
                ],
            ),
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_rejects_mismatched_csrf_token() {
        let response = send(
            cookie_config(),
            request(
                Method::POST,
                &[
                    ("cookie", "cedros_access=session123; XSRF-TOKEN=abc123"),
                    (CSRF_HEADER_NAME, "xyz789"),
                ],
            ),
        )
        .await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
        // The client already holds a CSRF cookie, so none is reissued.
        assert!(set_cookie(&response).is_none());
    }

    #[test]
    fn test_csrf_token_length() {
        let token = generate_token();
        assert_eq!(
            token.len(),
            44,
            "CSRF token should be 44 chars for 262-bit entropy"
        );
    }

    #[test]
    fn test_csrf_token_uniqueness() {
        let mut tokens = std::collections::HashSet::new();
        for _ in 0..1000 {
            let token = generate_token();
            assert!(tokens.insert(token), "Generated duplicate CSRF token");
        }
    }

    #[tokio::test]
    async fn test_csrf_required_with_hybrid_auth() {
        let response = send(
            cookie_config(),
            request(
                Method::POST,
                &[
                    ("authorization", "Bearer token123"),
                    ("cookie", "cedros_access=session123"),
                ],
            ),
        )
        .await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_csrf_skipped_for_pure_token_auth() {
        let response = send(
            cookie_config(),
            request(Method::POST, &[("authorization", "Bearer token123")]),
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_csrf_not_required_without_auth_cookies() {
        let response = send(cookie_config(), request(Method::POST, &[])).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_csrf_success_with_hybrid_auth_and_valid_token() {
        let csrf_token = generate_token();
        let cookie = format!("cedros_access=session123; {}={}", CSRF_COOKIE_NAME, csrf_token);
        let response = send(
            cookie_config(),
            request(
                Method::POST,
                &[
                    ("authorization", "Bearer token123"),
                    ("cookie", cookie.as_str()),
                    ("X-CSRF-Token", csrf_token.as_str()),
                ],
            ),
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_csrf_cookie_only_auth_requires_token() {
        let response = send(
            cookie_config(),
            request(Method::POST, &[("cookie", "cedros_access=session123")]),
        )
        .await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_refresh_cookie_alone_requires_csrf() {
        let response = send(
            cookie_config(),
            request(Method::POST, &[("cookie", "cedros_refresh=refresh123")]),
        )
        .await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_accepts_fallback_csrf_cookie_name() {
        let response = send(
            cookie_config(),
            request(
                Method::POST,
                &[
                    ("cookie", "cedros_access=session123; csrf-token=abc123"),
                    (CSRF_HEADER_NAME, "abc123"),
                ],
            ),
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_does_not_reissue_existing_csrf_cookie() {
        let response = send(
            cookie_config(),
            request(Method::GET, &[("cookie", "XSRF-TOKEN=abc123")]),
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
        assert!(set_cookie(&response).is_none());
    }

    #[tokio::test]
    async fn test_safe_method_with_auth_cookie_skips_check() {
        let response = send(
            cookie_config(),
            request(Method::GET, &[("cookie", "cedros_access=session123")]),
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_disabled_config_skips_csrf() {
        let mut config = cookie_config();
        config.enabled = false;
        let response = send(
            config,
            request(Method::POST, &[("cookie", "cedros_access=session123")]),
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
        assert!(set_cookie(&response).is_none());
    }

    #[test]
    fn test_build_csrf_cookie_rejects_invalid_domain() {
        let mut config = cookie_config();
        config.domain = Some(".com".to_string());
        let cookie = build_csrf_cookie(&config, "test_token");
        assert!(
            !cookie.contains("Domain="),
            "Invalid domain .com should be rejected"
        );
    }

    #[test]
    fn test_build_csrf_cookie_accepts_valid_domain() {
        let mut config = cookie_config();
        config.domain = Some(".example.com".to_string());
        let cookie = build_csrf_cookie(&config, "test_token");
        assert!(
            cookie.contains("Domain=.example.com"),
            "Valid domain should be included"
        );
    }

    #[test]
    fn test_build_csrf_cookie_same_site_is_case_insensitive() {
        let mut config = cookie_config();
        config.same_site = "Strict".to_string();
        let cookie = build_csrf_cookie(&config, "test_token");
        assert!(cookie.contains("SameSite=Strict"));
    }
}