use axum::extract::Request;
use axum::http::{HeaderName, Method, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use std::sync::Arc;
use subtle::ConstantTimeEq;
use tower_cookies::Cookies;
/// Settings for the double-submit-cookie CSRF middleware.
///
/// The middleware compares the value of the cookie named [`cookie_name`](Self::cookie_name)
/// with the value of the request header named [`header_name`](Self::header_name).
#[derive(Debug, Clone)]
pub struct CsrfProtectionConfig {
    /// Name of the cookie whose value is the expected CSRF token.
    pub cookie_name: &'static str,
    /// Name of the request header in which the client echoes the token back.
    pub header_name: HeaderName,
}
impl Default for CsrfProtectionConfig {
fn default() -> Self {
Self {
cookie_name: "csrf_token",
header_name: HeaderName::from_static("x-csrf-token"),
}
}
}
/// Returns `true` when a request with method `m` must pass the CSRF check.
///
/// Deny-by-default: every method other than the safe methods GET, HEAD,
/// OPTIONS and TRACE is treated as potentially state-changing. An explicit
/// list of "unsafe" methods would let `CONNECT` and extension methods
/// (e.g. WebDAV `PROPPATCH`, `MKCOL`) bypass the check.
fn is_unsafe_method(m: &Method) -> bool {
    !matches!(
        *m,
        Method::GET | Method::HEAD | Method::OPTIONS | Method::TRACE
    )
}
/// Compares two byte slices for equality without short-circuiting on content.
///
/// Slices of different length are rejected immediately, so only the length can be
/// observed through timing. Equal-length inputs are compared with `subtle`'s
/// constant-time `ct_eq`.
fn constant_time_eq_bytes(a: &[u8], b: &[u8]) -> bool {
    let lengths_match = a.len() == b.len();
    lengths_match && bool::from(a.ct_eq(b))
}
/// Axum middleware enforcing the double-submit-cookie CSRF pattern.
///
/// Requests whose method is accepted by `is_unsafe_method` must carry, in the
/// `config.header_name` header, a non-empty token equal to the value of the
/// `config.cookie_name` cookie. Otherwise a 403 JSON body is returned and `next`
/// is never invoked. All other requests are forwarded unchanged.
///
/// The [`Cookies`] request extension must be present, which requires
/// `CookieManagerLayer`. If it is missing, unsafe requests are rejected with 403
/// (fail closed).
pub async fn csrf_double_submit_middleware(
    axum::extract::State(config): axum::extract::State<Arc<CsrfProtectionConfig>>,
    req: Request,
    next: Next,
) -> Response {
    if !is_unsafe_method(req.method()) {
        return next.run(req).await;
    }
    // The request is only inspected by reference, so it can be forwarded as-is
    // without being split into parts and reassembled.
    if let Err(message) = check_double_submit(&config, &req) {
        return csrf_forbidden(message);
    }
    next.run(req).await
}

/// Validates the cookie/header token pair on `req`.
///
/// On failure, returns the message to put in the 403 response.
fn check_double_submit(config: &CsrfProtectionConfig, req: &Request) -> Result<(), &'static str> {
    let Some(cookies) = req.extensions().get::<Cookies>() else {
        return Err("CSRF check requires CookieManagerLayer (use NestApplication::use_cookies)");
    };
    let cookie = cookies.get(config.cookie_name);
    let cookie_token = cookie.as_ref().map(|c| c.value());
    // A header value that is not visible ASCII is treated as absent.
    let header_token = req
        .headers()
        .get(&config.header_name)
        .and_then(|v| v.to_str().ok());

    match (cookie_token, header_token) {
        // An empty cookie paired with an empty header would otherwise compare equal
        // and be accepted; an empty value is never a valid token.
        (Some(expected), Some(provided))
            if !expected.is_empty()
                && constant_time_eq_bytes(expected.as_bytes(), provided.as_bytes()) =>
        {
            Ok(())
        }
        _ => Err("CSRF token missing or invalid"),
    }
}

/// Builds the 403 Forbidden JSON response used for every CSRF rejection.
fn csrf_forbidden(message: &str) -> Response {
    (
        StatusCode::FORBIDDEN,
        axum::Json(serde_json::json!({
            "statusCode": 403,
            "message": message,
            "error": "Forbidden",
        })),
    )
        .into_response()
}