use axum::{
body::{Body, Bytes},
extract::{Request, State},
http::{Method, StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
};
use hmac::{Hmac, KeyInit, Mac};
use sha2::Sha256;
use tracing::warn;
use crate::web::{
AppState,
auth::SessionCookie,
crypto::{ConstantTimeEq, ToHex},
origin,
};
/// HMAC-SHA256 instantiation used by `AppState::csrf_token` to derive tokens.
type HmacSha256 = Hmac<Sha256>;
/// Upper bound in bytes (64 KiB) on how much of a request body `guard`
/// buffers when it has to look for the CSRF token inside the body.
const MAX_FORM_BODY: usize = 64 * 1024;
/// Opaque wrapper around the textual CSRF token belonging to one session.
///
/// The inner string is private; it is only reachable through the accessors
/// below.
pub(crate) struct CsrfToken(String);

impl CsrfToken {
    /// Consumes the token and hands back the owned string.
    pub(crate) fn into_string(self) -> String {
        let Self(inner) = self;
        inner
    }

    /// Borrows the token text.
    pub(crate) fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Returns `true` when `presented` equals this token.
    ///
    /// The comparison is delegated to the crate's `ConstantTimeEq::ct_eq`
    /// instead of `==` (presumably to avoid timing side channels; see
    /// `crate::web::crypto` for the actual guarantee).
    pub(crate) fn verify(&self, presented: &str) -> bool {
        let expected: &str = self.as_str();
        expected.ct_eq(presented)
    }
}
impl AppState {
    /// Derives the CSRF token bound to `session_id`.
    ///
    /// The token is the HMAC-SHA256 of the session id's bytes, keyed with
    /// `self.csrf_key` and rendered to text through the crate's `to_hex`
    /// helper. The same session id therefore always maps to the same token
    /// for a given key.
    ///
    /// # Panics
    ///
    /// Panics only if the HMAC constructor rejects the key, which it is not
    /// expected to do for any key length.
    pub(crate) fn csrf_token(&self, session_id: &str) -> CsrfToken {
        let key: &[u8] = self.csrf_key.as_ref();
        let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
        mac.update(session_id.as_bytes());
        let digest = mac.finalize().into_bytes();
        CsrfToken(digest.to_hex())
    }
}
/// Reports whether `method` is one of GET, HEAD, OPTIONS or TRACE.
///
/// `guard` forwards requests with these methods without running any CSRF
/// checks.
fn is_safe(method: &Method) -> bool {
    [Method::GET, Method::HEAD, Method::OPTIONS, Method::TRACE].contains(method)
}
/// Middleware that enforces CSRF protection on state-changing requests.
///
/// Checks run in this order:
///
/// 1. Requests whose method passes `is_safe` are forwarded untouched.
/// 2. `origin_ok` must accept the request headers; otherwise the request is
///    rejected with 403.
/// 3. Requests without a session cookie are forwarded, since there is no
///    session id to derive a token from.
/// 4. A matching `x-csrf-token` header is accepted without reading the body.
/// 5. Otherwise the body is buffered (up to `MAX_FORM_BODY` bytes) and the
///    token is taken from the JSON field `csrf` when the content type
///    mentions `application/json`, or from the form field `csrf_token` in
///    every other case.
///
/// A body that cannot be buffered within the cap, or a missing or wrong
/// token, yields 403. When the body was buffered it is re-attached to the
/// request before the request is handed to `next`.
pub async fn guard(State(state): State<AppState>, req: Request, next: Next) -> Response {
    // Safe methods are never subject to CSRF checks.
    if is_safe(req.method()) {
        return next.run(req).await;
    }

    if !origin_ok(&state, req.headers()) {
        warn!("CSRF: rejected mutation with mismatched Origin/Referer");
        return forbidden();
    }

    // No session cookie: nothing to bind a token to, so let the request on.
    let expected = match SessionCookie::from_headers(req.headers()) {
        Some(cookie) => state.csrf_token(&cookie.id),
        None => return next.run(req).await,
    };

    // Fast path: a valid token in the header spares us from reading the body.
    let header_matches = req
        .headers()
        .get("x-csrf-token")
        .and_then(|v| v.to_str().ok())
        .is_some_and(|presented| expected.verify(presented));
    if header_matches {
        return next.run(req).await;
    }

    // Decide how to parse the body before the request is taken apart.
    let is_json = req
        .headers()
        .get(header::CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .is_some_and(|ct| ct.contains("application/json"));

    let (parts, body) = req.into_parts();
    let Ok(bytes) = axum::body::to_bytes(body, MAX_FORM_BODY).await else {
        return forbidden();
    };

    let presented = if is_json {
        json_field(&bytes, "csrf")
    } else {
        form_field(&bytes, "csrf_token")
    };

    match presented.as_deref() {
        Some(candidate) if expected.verify(candidate) => {
            // The body was consumed above; rebuild the request around the
            // buffered bytes so downstream handlers can still read it.
            let rebuilt = Request::from_parts(parts, Body::from(bytes));
            next.run(rebuilt).await
        }
        _ => {
            warn!("CSRF: rejected mutation with missing/invalid token");
            forbidden()
        }
    }
}
/// Checks the request's `Origin` / `Referer` headers against the origin
/// computed by `origin::origin` for these headers.
///
/// * If no expected origin can be computed, the request is rejected.
/// * A readable `Origin` header must equal the expected origin exactly.
/// * Otherwise a readable `Referer` must be the expected origin itself or
///   the expected origin followed by a `/`.
/// * If neither header is present (or readable as a string), the request is
///   accepted.
fn origin_ok(state: &AppState, headers: &axum::http::HeaderMap) -> bool {
    let Some(expected) = origin::origin(state.cookie_policy, headers) else {
        return false;
    };

    let origin_hdr = headers.get(header::ORIGIN).and_then(|v| v.to_str().ok());
    let referer_hdr = headers.get(header::REFERER).and_then(|v| v.to_str().ok());

    match (origin_hdr, referer_hdr) {
        // Origin wins whenever it is available.
        (Some(o), _) => o == expected,
        // Referer carries a full URL, so allow a path after the origin but
        // not a longer host such as `https://example.com.evil.test`.
        (None, Some(r)) => {
            r == expected
                || r.strip_prefix(&expected)
                    .is_some_and(|rest| rest.starts_with('/'))
        }
        (None, None) => true,
    }
}
/// Parses `body` as JSON and returns the string stored under `key`.
///
/// Yields `None` when the body is not valid JSON, when `key` cannot be
/// looked up in the parsed value, or when the value found is not a JSON
/// string.
fn json_field(body: &Bytes, key: &str) -> Option<String> {
    let parsed = serde_json::from_slice::<serde_json::Value>(body).ok()?;
    match parsed.get(key) {
        Some(serde_json::Value::String(text)) => Some(text.clone()),
        _ => None,
    }
}
fn form_field(body: &Bytes, key: &str) -> Option<String> {
let body = std::str::from_utf8(body).ok()?;
for pair in body.split('&') {
if let Some((k, v)) = pair.split_once('=')
&& k == key
{
return Some(v.to_owned());
}
}
None
}
/// Builds the 403 response sent whenever a CSRF check fails.
fn forbidden() -> Response {
    let rejection = (StatusCode::FORBIDDEN, "CSRF check failed");
    rejection.into_response()
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderMap;

    /// GET and HEAD count as safe; POST and DELETE do not.
    #[test]
    fn safe_methods_detected() {
        for safe in [Method::GET, Method::HEAD] {
            assert!(is_safe(&safe));
        }
        for mutating in [Method::POST, Method::DELETE] {
            assert!(!is_safe(&mutating));
        }
    }

    /// The token is found between other pairs; unknown keys yield `None`.
    #[test]
    fn form_field_extracts_token() {
        let form = Bytes::from_static(b"foo=1&csrf_token=abc123&bar=2");
        let found = form_field(&form, "csrf_token");
        assert_eq!(found.as_deref(), Some("abc123"));
        assert_eq!(form_field(&form, "missing"), None);
    }

    /// Only string values under the requested key are returned.
    #[test]
    fn json_field_extracts_token() {
        let object = Bytes::from_static(br#"{"csrf":"abc123","f_text":"","queries":5}"#);
        assert_eq!(json_field(&object, "csrf").as_deref(), Some("abc123"));
        assert_eq!(json_field(&object, "missing"), None);

        // A body that is not JSON at all.
        let not_json = Bytes::from_static(b"not json");
        assert_eq!(json_field(&not_json, "csrf"), None);

        // The key exists but holds a number rather than a string.
        let numeric = Bytes::from_static(br#"{"csrf":5}"#);
        assert_eq!(json_field(&numeric, "csrf"), None);
    }

    /// Pins the expected origin derived from a `Host` header.
    ///
    /// This only exercises `origin::origin`; `origin_ok` itself is not
    /// called here.
    #[test]
    fn origin_ok_matches_and_rejects() {
        use crate::config::SessionCookieSecurePolicy;

        let mut headers = HeaderMap::new();
        headers.insert("host", "127.0.0.1:8080".parse().unwrap());
        let derived = origin::origin(SessionCookieSecurePolicy::Never, &headers).unwrap();
        assert_eq!(derived, "http://127.0.0.1:8080");
    }
}