use std::convert::Infallible;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use axum::body::Body;
use axum::extract::Request;
use axum::http::{HeaderValue, Method, Response, StatusCode};
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine as _;
use rand::RngCore;
use tower::Service;
/// Default name of the cookie that carries the CSRF token (see `CsrfConfig::default`).
const CSRF_COOKIE: &str = "rustango_csrf";
/// Default request header that must echo the cookie's token on non-safe methods.
const CSRF_HEADER: &str = "X-CSRF-Token";
/// Form-field name for the CSRF token.
///
/// NOTE(review): the middleware in this file only checks the header, never a
/// form body; presumably this constant is used by templates/handlers elsewhere
/// — confirm.
pub const CSRF_FORM_FIELD: &str = "_csrf";
pub fn layer() -> CsrfLayer {
CsrfLayer::new(CsrfConfig::default())
}
/// Builds a CSRF [`CsrfLayer`] from a caller-supplied [`CsrfConfig`].
///
/// The configuration is moved into an `Arc` and shared by every service the
/// layer wraps.
pub fn with_config(cfg: CsrfConfig) -> CsrfLayer {
CsrfLayer::new(cfg)
}
/// Settings for the double-submit-cookie CSRF middleware.
///
/// The middleware compares the value of the cookie named `cookie_name` with the
/// value of the request header named `header_name` on non-safe methods.
#[derive(Debug, Clone)]
pub struct CsrfConfig {
/// Name of the cookie holding the token (default: `rustango_csrf`).
pub cookie_name: String,
/// Request header that must carry the same token as the cookie on
/// non-safe methods (default: `X-CSRF-Token`).
pub header_name: String,
/// When `true`, the minted cookie gets the `Secure` attribute. Defaults to `false`.
pub secure: bool,
}
impl Default for CsrfConfig {
fn default() -> Self {
Self {
cookie_name: CSRF_COOKIE.to_owned(),
header_name: CSRF_HEADER.to_owned(),
secure: false,
}
}
}
/// Tower layer that wraps an inner service in a [`CsrfService`].
///
/// Cloning is cheap: the configuration is shared behind an [`Arc`].
#[derive(Debug, Clone)]
pub struct CsrfLayer {
    cfg: Arc<CsrfConfig>,
}
impl CsrfLayer {
    /// Wraps `cfg` in an `Arc` so all services produced by this layer share it.
    fn new(cfg: CsrfConfig) -> Self {
        let shared = Arc::new(cfg);
        Self { cfg: shared }
    }
}
impl<S> tower::Layer<S> for CsrfLayer {
    type Service = CsrfService<S>;

    /// Wraps `inner`, handing it a shared reference to this layer's configuration.
    fn layer(&self, inner: S) -> Self::Service {
        let cfg = Arc::clone(&self.cfg);
        CsrfService { cfg, inner }
    }
}
/// Service produced by [`CsrfLayer`]: performs the CSRF check, then delegates
/// to `inner`.
///
/// `Debug` is available whenever the wrapped service is `Debug`.
#[derive(Debug, Clone)]
pub struct CsrfService<S> {
    inner: S,
    cfg: Arc<CsrfConfig>,
}
/// Double-submit-cookie CSRF enforcement.
///
/// * Safe methods (`is_safe_method`) always pass through to the inner service.
/// * Any other method is answered with `403 Forbidden` unless the request
///   carries a non-empty CSRF cookie *and* the configured header holds exactly
///   the same value (compared with `constant_time_eq`).
/// * When the request had no (or an empty) CSRF cookie and was forwarded to the
///   inner service, a freshly minted token is appended to the response as
///   `Set-Cookie: <name>=<token>; Path=/; SameSite=Lax` (plus `; Secure` when
///   `cfg.secure` is set). The cookie is not marked `HttpOnly`, presumably so
///   page scripts can read it and echo it in the header — confirm.
impl<S> Service<Request<Body>> for CsrfService<S>
where
    S: Service<Request<Body>, Response = Response<Body>, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = Response<Body>;
    type Error = Infallible;
    type Future =
        Pin<Box<dyn std::future::Future<Output = Result<Response<Body>, Infallible>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let cfg = Arc::clone(&self.cfg);
        // `poll_ready` was driven on `self.inner`, so that exact instance must
        // serve this request. Leave a fresh clone behind for the next
        // `poll_ready` and move the ready one into the future; calling a
        // freshly cloned, never-polled service would break the tower contract.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);
        Box::pin(async move {
            // An empty cookie value (`name=`) is treated as absent: it must never
            // validate a request, and the client should be issued a real token.
            let cookie_value =
                read_csrf_cookie(&req, &cfg.cookie_name).filter(|v| !v.is_empty());
            if !is_safe_method(req.method()) {
                let header_value = req
                    .headers()
                    .get(&cfg.header_name)
                    .and_then(|v| v.to_str().ok());
                let token_match = match (cookie_value.as_deref(), header_value) {
                    (Some(c), Some(h)) => constant_time_eq(c.as_bytes(), h.as_bytes()),
                    _ => false,
                };
                if !token_match {
                    return Ok(forbid_response("CSRF token missing or mismatched"));
                }
            }
            let mut response = inner.call(req).await?;
            if cookie_value.is_none() {
                let token = mint_token();
                let cookie_str = format!(
                    "{}={token}; Path=/; SameSite=Lax{}",
                    cfg.cookie_name,
                    if cfg.secure { "; Secure" } else { "" }
                );
                // A configured cookie name with invalid header bytes cannot be
                // sent; skip the cookie rather than failing the response.
                if let Ok(hv) = HeaderValue::from_str(&cookie_str) {
                    response
                        .headers_mut()
                        .append(axum::http::header::SET_COOKIE, hv);
                }
            }
            Ok(response)
        })
    }
}
/// Returns `true` for methods exempt from the CSRF check: GET, HEAD, OPTIONS
/// and TRACE. Every other method must present a matching token.
fn is_safe_method(m: &Method) -> bool {
    const SAFE: [Method; 4] = [Method::GET, Method::HEAD, Method::OPTIONS, Method::TRACE];
    SAFE.contains(m)
}
/// Returns the value of the first cookie named `name` in the request, or
/// `None` when no such cookie is present.
///
/// Every `Cookie` header field is scanned, not just the first. HTTP/2 clients
/// may split the cookie list across several `cookie` fields (RFC 9113 §8.2.3).
/// A field that is not visible ASCII is skipped on its own instead of hiding
/// cookies carried in other fields. Pairs are split on `;`, trimmed, and split
/// at the first `=`. The name match is exact and case-sensitive, and the value
/// is returned verbatim (it may itself contain `=`).
fn read_csrf_cookie(req: &Request<Body>, name: &str) -> Option<String> {
    req.headers()
        .get_all(axum::http::header::COOKIE)
        .into_iter()
        .filter_map(|field| field.to_str().ok())
        .flat_map(|raw| raw.split(';'))
        .filter_map(|pair| pair.trim().split_once('='))
        .find(|(key, _)| *key == name)
        .map(|(_, value)| value.to_owned())
}
/// Generates a new CSRF token: 32 bytes from the OS RNG, encoded as
/// unpadded URL-safe base64 (43 characters).
fn mint_token() -> String {
    const TOKEN_BYTES: usize = 32;
    let mut raw = [0_u8; TOKEN_BYTES];
    rand::rngs::OsRng.fill_bytes(raw.as_mut_slice());
    URL_SAFE_NO_PAD.encode(raw)
}
/// Compares two byte slices for equality without an early exit on the first
/// differing byte. Slices of different length are rejected immediately, so
/// only the length (not the content) can leak through timing.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0_u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
/// Builds a `403 Forbidden` response with `detail` as a `text/plain` body.
fn forbid_response(detail: &'static str) -> Response<Body> {
    let (mut parts, body) = Response::new(Body::from(detail)).into_parts();
    parts.status = StatusCode::FORBIDDEN;
    parts
        .headers
        .insert("Content-Type", HeaderValue::from_static("text/plain"));
    Response::from_parts(parts, body)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// GET/HEAD/OPTIONS/TRACE are exempt; state-changing methods are not.
    #[test]
    fn safe_method_predicate() {
        assert!(is_safe_method(&Method::GET));
        assert!(is_safe_method(&Method::HEAD));
        assert!(is_safe_method(&Method::OPTIONS));
        assert!(is_safe_method(&Method::TRACE));
        assert!(!is_safe_method(&Method::POST));
        assert!(!is_safe_method(&Method::PUT));
        assert!(!is_safe_method(&Method::PATCH));
        assert!(!is_safe_method(&Method::DELETE));
    }

    #[test]
    fn ct_eq_matches_eq() {
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(!constant_time_eq(b"abc", b"abd"));
        assert!(!constant_time_eq(b"abc", b"abcd"));
        assert!(constant_time_eq(b"", b""));
        // An empty value must never equal a non-empty one.
        assert!(!constant_time_eq(b"", b"a"));
    }

    #[test]
    fn mint_token_is_base64url_no_pad() {
        let t = mint_token();
        assert_eq!(t.len(), 43);
        assert!(!t.contains('='));
        let decoded = URL_SAFE_NO_PAD
            .decode(t.as_bytes())
            .expect("minted token must be valid unpadded base64url");
        assert_eq!(decoded.len(), 32);
    }

    /// Two tokens colliding has probability 2^-256; equality signals a broken RNG.
    #[test]
    fn mint_token_is_unique_per_call() {
        assert_ne!(mint_token(), mint_token());
    }

    #[test]
    fn read_csrf_cookie_finds_named_pair() {
        use axum::http::Request;
        let req = Request::builder()
            .header("cookie", "session=abc; rustango_csrf=hello; theme=dark")
            .body(Body::empty())
            .unwrap();
        assert_eq!(read_csrf_cookie(&req, "rustango_csrf").as_deref(), Some("hello"));
        assert_eq!(read_csrf_cookie(&req, "missing").as_deref(), None);
    }

    /// Names must match exactly (no suffix match), and the value is kept verbatim
    /// even when it contains `=`.
    #[test]
    fn read_csrf_cookie_exact_name_and_raw_value() {
        use axum::http::Request;
        let req = Request::builder()
            .header("cookie", "xrustango_csrf=bad;  rustango_csrf=a=b")
            .body(Body::empty())
            .unwrap();
        assert_eq!(read_csrf_cookie(&req, "rustango_csrf").as_deref(), Some("a=b"));
    }

    #[test]
    fn read_csrf_cookie_returns_none_when_no_header() {
        use axum::http::Request;
        let req = Request::builder().body(Body::empty()).unwrap();
        assert_eq!(read_csrf_cookie(&req, "anything"), None);
    }
}