use std::convert::Infallible;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use axum::body::Body;
use axum::extract::Request;
use axum::http::{HeaderValue, Method, Response, StatusCode};
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine as _;
use rand::RngCore;
use tower::Service;
/// Default name of the cookie that carries the CSRF token (see `CsrfConfig::cookie_name`).
pub const CSRF_COOKIE: &str = "rustango_csrf";
/// Default request header compared against the cookie (see `CsrfConfig::header_name`).
const CSRF_HEADER: &str = "X-CSRF-Token";
/// Form field read from `application/x-www-form-urlencoded` bodies when the header is absent.
pub const CSRF_FORM_FIELD: &str = "_csrf";
pub fn layer() -> CsrfLayer {
CsrfLayer::new(CsrfConfig::default())
}
pub fn with_config(cfg: CsrfConfig) -> CsrfLayer {
CsrfLayer::new(cfg)
}
/// Settings for the CSRF middleware.
#[derive(Debug, Clone)]
pub struct CsrfConfig {
    /// Name of the cookie holding the token; `Default` uses `CSRF_COOKIE`.
    pub cookie_name: String,
    /// Request header whose value is compared with the cookie on unsafe methods;
    /// `Default` uses `X-CSRF-Token`.
    pub header_name: String,
    /// When `true`, the cookie minted by the middleware carries the `Secure` attribute.
    /// `Default` is `false`. This does not affect cookies built by `ensure_token`.
    pub secure: bool,
}
impl Default for CsrfConfig {
    /// Cookie `CSRF_COOKIE`, header `CSRF_HEADER`, and no `Secure` attribute.
    fn default() -> Self {
        let cookie_name = String::from(CSRF_COOKIE);
        let header_name = String::from(CSRF_HEADER);
        Self {
            cookie_name,
            header_name,
            secure: false,
        }
    }
}
/// `tower::Layer` that wraps an inner service in a [`CsrfService`].
///
/// The configuration is held in an `Arc`, so cloning the layer (or producing services from it)
/// shares one copy.
#[derive(Clone)]
pub struct CsrfLayer {
    cfg: Arc<CsrfConfig>,
}
impl CsrfLayer {
    /// Moves `cfg` behind an `Arc` so all services built by this layer share it.
    fn new(cfg: CsrfConfig) -> Self {
        let shared = Arc::new(cfg);
        Self { cfg: shared }
    }
}
impl<S> tower::Layer<S> for CsrfLayer {
    type Service = CsrfService<S>;

    /// Wraps `inner`, giving the new service a cheap `Arc` clone of this layer's config.
    fn layer(&self, inner: S) -> Self::Service {
        let cfg = Arc::clone(&self.cfg);
        CsrfService { cfg, inner }
    }
}
/// Service produced by [`CsrfLayer`]: checks the CSRF token before delegating to `inner`.
#[derive(Clone)]
pub struct CsrfService<S> {
    /// Wrapped service that receives requests which pass the check.
    inner: S,
    /// Configuration shared with the layer that created this service.
    cfg: Arc<CsrfConfig>,
}
impl<S> Service<Request<Body>> for CsrfService<S>
where
    S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = Response<Body>;
    type Error = Infallible;
    type Future =
        Pin<Box<dyn std::future::Future<Output = Result<Response<Body>, Infallible>> + Send>>;

    /// Delegates readiness to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Applies the double-submit-cookie CSRF check, then forwards the request.
    ///
    /// Safe methods (see `is_safe_method`) pass straight through. Any other method must present
    /// a token equal to the CSRF cookie in one of two places:
    /// - the configured header;
    /// - only when that header is absent and the body is `application/x-www-form-urlencoded`,
    ///   the `_csrf` form field. The body is buffered up to `BODY_BUFFER_LIMIT` and re-attached.
    ///
    /// A failed check short-circuits with a 403 and never reaches the inner service.
    ///
    /// A freshly minted token cookie is appended to the response in one case: the request
    /// carried no usable CSRF cookie and the inner service did not set that cookie itself.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let cfg = Arc::clone(&self.cfg);
        // `poll_ready` drove `self.inner` to readiness, so that instance must serve this
        // request. Move it into the future and leave a fresh clone behind for the next
        // `poll_ready`. Calling a clone that was never polled breaks the `Service` contract
        // for inner services whose readiness reserves capacity.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);
        Box::pin(async move {
            // Minted tokens are never empty. Treating an empty cookie as absent stops an empty
            // submitted token from matching it, and lets a real token be minted below.
            let cookie_value =
                read_csrf_cookie(&req, &cfg.cookie_name).filter(|c| !c.is_empty());
            let req = if !is_safe_method(req.method()) {
                let header_value = req
                    .headers()
                    .get(&cfg.header_name)
                    .and_then(|v| v.to_str().ok())
                    .map(str::to_owned);
                if let Some(h) = header_value {
                    let token_match = match &cookie_value {
                        Some(c) => constant_time_eq(c.as_bytes(), h.as_bytes()),
                        None => false,
                    };
                    if !token_match {
                        return Ok(forbid_response("CSRF token missing or mismatched"));
                    }
                    req
                } else if is_form_encoded(&req) {
                    // The body stream can only be consumed once. Buffer it, look for the
                    // token, then rebuild the request so the handler still sees the form.
                    let (parts, body) = req.into_parts();
                    let bytes = match axum::body::to_bytes(body, BODY_BUFFER_LIMIT).await {
                        Ok(b) => b,
                        Err(_) => {
                            return Ok(forbid_response("CSRF: form body exceeded buffer limit"));
                        }
                    };
                    let form_token = read_form_field(&bytes, CSRF_FORM_FIELD);
                    let token_match = match (&cookie_value, &form_token) {
                        (Some(c), Some(f)) => constant_time_eq(c.as_bytes(), f.as_bytes()),
                        _ => false,
                    };
                    if !token_match {
                        return Ok(forbid_response("CSRF token missing or mismatched"));
                    }
                    Request::from_parts(parts, Body::from(bytes))
                } else {
                    // No header and no url-encoded body (e.g. JSON, multipart): fail closed.
                    return Ok(forbid_response("CSRF token missing or mismatched"));
                }
            } else {
                req
            };
            let mut response = inner.call(req).await?;
            // Skip minting when the handler already issued this cookie (e.g. via
            // `ensure_token`). A later, different Set-Cookie for the same name and path
            // typically overwrites it in the browser. That would replace the token the
            // handler just rendered, so the next submission would be rejected.
            let handler_set_cookie = response
                .headers()
                .get_all(axum::http::header::SET_COOKIE)
                .iter()
                .filter_map(|v| v.to_str().ok())
                .any(|v| {
                    matches!(
                        v.trim_start().strip_prefix(cfg.cookie_name.as_str()),
                        Some(rest) if rest.starts_with('=')
                    )
                });
            if cookie_value.is_none() && !handler_set_cookie {
                let token = mint_token();
                let cookie_str = format!(
                    "{}={token}; Path=/; SameSite=Lax{}",
                    cfg.cookie_name,
                    if cfg.secure { "; Secure" } else { "" }
                );
                if let Ok(hv) = HeaderValue::from_str(&cookie_str) {
                    response
                        .headers_mut()
                        .append(axum::http::header::SET_COOKIE, hv);
                }
            }
            Ok(response)
        })
    }
}
/// Maximum form body size (64 KiB) buffered while looking for the `_csrf` field.
/// A body that cannot be buffered within this limit is answered with 403.
const BODY_BUFFER_LIMIT: usize = 64 * 1024;
/// Returns `true` for the methods that bypass the CSRF check:
/// `GET`, `HEAD`, `OPTIONS` and `TRACE`.
fn is_safe_method(m: &Method) -> bool {
    const SAFE: [Method; 4] = [Method::GET, Method::HEAD, Method::OPTIONS, Method::TRACE];
    SAFE.contains(m)
}
/// Reads cookie `name` from the request by delegating to `read_csrf_cookie_from_headers`.
fn read_csrf_cookie(req: &Request<Body>, name: &str) -> Option<String> {
    let headers = req.headers();
    read_csrf_cookie_from_headers(headers, name)
}
/// Returns the first non-empty value of cookie `name` across all `Cookie` request headers.
///
/// HTTP/2 clients may split cookies over several `Cookie` header fields (RFC 9113), so every
/// field is scanned rather than only the first. A field that is not valid visible ASCII is
/// skipped.
///
/// Empty values are ignored. A minted token is never empty, and an empty cookie would
/// otherwise compare equal to an empty submitted token.
fn read_csrf_cookie_from_headers(headers: &axum::http::HeaderMap, name: &str) -> Option<String> {
    headers
        .get_all(axum::http::header::COOKIE)
        .iter()
        .filter_map(|raw| raw.to_str().ok())
        .flat_map(|raw| raw.split(';'))
        .filter_map(|pair| pair.trim().split_once('='))
        .find(|(key, value)| *key == name && !value.is_empty())
        .map(|(_, value)| value.to_owned())
}
fn is_form_encoded(req: &Request<Body>) -> bool {
let Some(ct) = req
.headers()
.get(axum::http::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
else {
return false;
};
let head = ct
.split(';')
.next()
.unwrap_or("")
.trim()
.to_ascii_lowercase();
head == "application/x-www-form-urlencoded"
}
/// Returns the decoded value of the first pair in an `application/x-www-form-urlencoded`
/// body whose decoded key equals `name`.
///
/// `+` is decoded as a space in both keys and values, then `%XX` escapes are expanded.
/// Pairs without `=` are skipped. Pairs whose key cannot be decoded are also skipped, so one
/// malformed unrelated field cannot hide the field being looked up.
///
/// Returns `None` in three cases: the body is not UTF-8, no key matches, or the matching
/// value is malformed.
fn read_form_field(body: &[u8], name: &str) -> Option<String> {
    let text = std::str::from_utf8(body).ok()?;
    text.split('&')
        .filter_map(|pair| pair.split_once('='))
        .find(|(key, _)| {
            percent_decode(key.replace('+', " ").as_bytes()).as_deref() == Some(name)
        })
        .and_then(|(_, value)| percent_decode(value.replace('+', " ").as_bytes()))
}
/// Expands `%XX` escapes (hex digits in either case) and returns the result as UTF-8.
///
/// Every byte other than `%` is copied unchanged; `+` is NOT translated here. Returns `None`
/// for any of: an escape truncated at end of input, a non-hex digit after `%`, or output that
/// is not valid UTF-8.
fn percent_decode(bytes: &[u8]) -> Option<String> {
    let mut decoded: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut rest = bytes;
    while let Some((&first, tail)) = rest.split_first() {
        if first != b'%' {
            decoded.push(first);
            rest = tail;
            continue;
        }
        // An escape needs two more bytes; fewer means the input was truncated.
        let &[hi, lo, ..] = tail else {
            return None;
        };
        decoded.push((hex_digit(hi)? << 4) | hex_digit(lo)?);
        rest = &tail[2..];
    }
    String::from_utf8(decoded).ok()
}
/// Maps one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value in `0..=15`.
/// Any other byte yields `None`.
fn hex_digit(b: u8) -> Option<u8> {
    // `char::to_digit` accepts only ASCII digits and letters. Bytes >= 0x80 (Latin-1 code points
    // after `char::from`) are therefore rejected like any other non-hex byte. The digit is
    // below 16, so narrowing to `u8` cannot truncate.
    char::from(b).to_digit(16).map(|value| value as u8)
}
/// Generates a new CSRF token.
///
/// The token is 32 bytes from the operating-system RNG, encoded as base64url without padding
/// (43 characters).
fn mint_token() -> String {
    let mut entropy = [0u8; 32];
    rand::rngs::OsRng.fill_bytes(&mut entropy[..]);
    URL_SAFE_NO_PAD.encode(entropy)
}
/// Returns the request's CSRF token, minting a new one when the cookie is missing.
///
/// The first element is the token to embed in the page, for example in the `_csrf` form field.
/// The second element is `Some(set_cookie_value)` only when a new token was minted. The caller
/// should send that value as a `Set-Cookie` header so the browser stores the token.
///
/// The minted cookie is `Path=/; SameSite=Lax` and never carries `Secure`, whatever
/// `CsrfConfig::secure` is set to. HTTPS deployments may want to append `; Secure` themselves.
#[must_use]
pub fn ensure_token(
    headers: &axum::http::HeaderMap,
    cookie_name: &str,
) -> (String, Option<String>) {
    match read_csrf_cookie_from_headers(headers, cookie_name) {
        Some(existing) => (existing, None),
        None => {
            let fresh = mint_token();
            let set_cookie = format!("{cookie_name}={fresh}; Path=/; SameSite=Lax");
            (fresh, Some(set_cookie))
        }
    }
}
/// Compares two byte strings without stopping at the first differing byte.
///
/// Lengths are compared first; length is not treated as secret. For equal lengths, every byte
/// pair is XORed and the differences are OR-ed together. The fold therefore visits the whole
/// input wherever a mismatch occurs.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b)
            .fold(0u8, |acc, (x, y)| acc | (x ^ y))
            == 0
}
/// Builds the `403 Forbidden` plain-text response returned when the CSRF check fails.
/// `detail` becomes the response body.
fn forbid_response(detail: &'static str) -> Response<Body> {
    let mut rejected = Response::new(Body::from(detail));
    rejected
        .headers_mut()
        .insert("Content-Type", HeaderValue::from_static("text/plain"));
    *rejected.status_mut() = StatusCode::FORBIDDEN;
    rejected
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn safe_method_predicate() {
        assert!(is_safe_method(&Method::GET));
        assert!(is_safe_method(&Method::HEAD));
        assert!(is_safe_method(&Method::OPTIONS));
        assert!(is_safe_method(&Method::TRACE));
        assert!(!is_safe_method(&Method::POST));
        assert!(!is_safe_method(&Method::PUT));
        assert!(!is_safe_method(&Method::PATCH));
        assert!(!is_safe_method(&Method::DELETE));
    }

    #[test]
    fn ct_eq_matches_eq() {
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(!constant_time_eq(b"abc", b"abd"));
        assert!(!constant_time_eq(b"abc", b"abcd"));
        assert!(!constant_time_eq(b"", b"a"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn mint_token_is_base64url_no_pad() {
        let t = mint_token();
        assert_eq!(t.len(), 43);
        assert!(!t.contains('='));
        assert!(URL_SAFE_NO_PAD.decode(t.as_bytes()).is_ok());
    }

    #[test]
    fn read_csrf_cookie_finds_named_pair() {
        use axum::http::Request;
        let req = Request::builder()
            .header("cookie", "session=abc; rustango_csrf=hello; theme=dark")
            .body(Body::empty())
            .unwrap();
        assert_eq!(
            read_csrf_cookie(&req, "rustango_csrf").as_deref(),
            Some("hello")
        );
        assert_eq!(read_csrf_cookie(&req, "missing").as_deref(), None);
    }

    #[test]
    fn read_csrf_cookie_returns_none_when_no_header() {
        use axum::http::Request;
        let req = Request::builder().body(Body::empty()).unwrap();
        assert_eq!(read_csrf_cookie(&req, "anything"), None);
    }

    #[test]
    fn ensure_token_reuses_existing_cookie_or_mints_one() {
        let mut headers = axum::http::HeaderMap::new();
        let (minted, set_cookie) = ensure_token(&headers, CSRF_COOKIE);
        assert_eq!(minted.len(), 43);
        assert_eq!(
            set_cookie,
            Some(format!("{CSRF_COOKIE}={minted}; Path=/; SameSite=Lax"))
        );

        headers.insert(
            axum::http::header::COOKIE,
            HeaderValue::from_static("rustango_csrf=abc"),
        );
        assert_eq!(
            ensure_token(&headers, CSRF_COOKIE),
            ("abc".to_owned(), None)
        );
    }

    #[test]
    fn is_form_encoded_recognizes_canonical_and_charset_variants() {
        use axum::http::Request;
        let req = Request::builder()
            .header("content-type", "application/x-www-form-urlencoded")
            .body(Body::empty())
            .unwrap();
        assert!(is_form_encoded(&req));
        let req = Request::builder()
            .header(
                "content-type",
                "application/x-www-form-urlencoded; charset=UTF-8",
            )
            .body(Body::empty())
            .unwrap();
        assert!(is_form_encoded(&req));
        let req = Request::builder()
            .header("content-type", "multipart/form-data; boundary=---")
            .body(Body::empty())
            .unwrap();
        assert!(!is_form_encoded(&req));
        let req = Request::builder()
            .header("content-type", "application/json")
            .body(Body::empty())
            .unwrap();
        assert!(!is_form_encoded(&req));
        let req = Request::builder().body(Body::empty()).unwrap();
        assert!(!is_form_encoded(&req));
    }

    #[test]
    fn read_form_field_extracts_named_value() {
        let body = b"foo=bar&_csrf=tok123&other=baz";
        assert_eq!(read_form_field(body, "_csrf").as_deref(), Some("tok123"));
        assert_eq!(read_form_field(body, "foo").as_deref(), Some("bar"));
        assert_eq!(read_form_field(body, "missing"), None);
    }

    #[test]
    fn read_form_field_percent_decodes_value() {
        let body = b"_csrf=abc%2Fxyz";
        assert_eq!(read_form_field(body, "_csrf").as_deref(), Some("abc/xyz"));
    }

    #[test]
    fn read_form_field_treats_plus_as_space() {
        let body = b"q=hello+world";
        assert_eq!(read_form_field(body, "q").as_deref(), Some("hello world"));
    }

    #[test]
    fn read_form_field_skips_pairs_without_equals() {
        let body = b"foo&_csrf=tok";
        assert_eq!(read_form_field(body, "_csrf").as_deref(), Some("tok"));
        assert_eq!(read_form_field(body, "foo"), None);
    }

    #[test]
    fn percent_decode_rejects_malformed() {
        assert!(percent_decode(b"%").is_none());
        assert!(percent_decode(b"%2").is_none());
        assert!(percent_decode(b"%ZZ").is_none());
        assert_eq!(percent_decode(b"plain").as_deref(), Some("plain"));
        assert_eq!(percent_decode(b"a%20b").as_deref(), Some("a b"));
    }

    #[test]
    fn percent_decode_accepts_either_hex_case_and_rejects_invalid_utf8() {
        assert_eq!(percent_decode(b"%2f%2F").as_deref(), Some("//"));
        assert!(percent_decode(b"%FF").is_none());
    }
}