use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use axum::body::Body;
use axum::response::IntoResponse;
use axum_extra::extract::cookie::Key;
use cookie::{Cookie, CookieJar, SameSite};
use http::{HeaderValue, Method, Request, Response};
use serde::Deserialize;
use tower::{Layer, Service};
/// Settings for the CSRF middleware built by [`csrf`].
///
/// Any field missing from the deserialized input falls back to the value produced by
/// [`Default`] (`#[serde(default)]`).
#[non_exhaustive]
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct CsrfConfig {
    /// Name of the cookie that stores the signed token.
    pub cookie_name: String,
    /// Request header in which clients echo the token on non-exempt requests.
    pub header_name: String,
    /// Form field name reserved for the token.
    ///
    /// NOTE(review): the middleware in this file only reads the token from `header_name`;
    /// this field is not consulted there — confirm whether other code relies on it.
    pub field_name: String,
    /// Lifetime of the token cookie, emitted as its `Max-Age` (seconds).
    pub ttl_secs: u64,
    /// HTTP methods (compared case-insensitively) that skip validation and are issued a token.
    pub exempt_methods: Vec<String>,
}
impl Default for CsrfConfig {
    /// Defaults: cookie `_csrf`, header `X-CSRF-Token`, form field `_csrf_token`,
    /// a six-hour cookie lifetime, and `GET`/`HEAD`/`OPTIONS` as exempt methods.
    fn default() -> Self {
        let exempt_methods = Vec::from(["GET", "HEAD", "OPTIONS"].map(String::from));

        Self {
            cookie_name: String::from("_csrf"),
            header_name: String::from("X-CSRF-Token"),
            field_name: String::from("_csrf_token"),
            // 6 hours expressed in seconds (= 21600).
            ttl_secs: 6 * 60 * 60,
            exempt_methods,
        }
    }
}
/// The raw (unsigned) CSRF token associated with a request.
///
/// [`CsrfService`] places this value in request and/or response extensions so handlers can
/// embed it in pages or forms.
#[derive(Debug, Clone)]
pub struct CsrfToken(pub String);
/// Tower [`Layer`] that wraps an inner service in a [`CsrfService`].
///
/// Build one with [`csrf`].
#[derive(Clone)]
pub struct CsrfLayer {
    /// Configuration copied into every service this layer produces.
    config: CsrfConfig,
    /// Key used to sign and verify the token cookie.
    key: Key,
}
impl<S> Layer<S> for CsrfLayer {
    type Service = CsrfService<S>;

    /// Wraps `inner`, giving the new service its own copy of the config and signing key.
    fn layer(&self, inner: S) -> Self::Service {
        let Self { config, key } = self;
        CsrfService {
            inner,
            config: config.clone(),
            key: key.clone(),
        }
    }
}
/// Middleware service that issues CSRF tokens on exempt requests and validates them on all
/// other requests before delegating to `inner`.
#[derive(Clone)]
pub struct CsrfService<S> {
    /// The wrapped service.
    inner: S,
    /// Cookie/header names, TTL and exempt methods.
    config: CsrfConfig,
    /// Key used to sign and verify the token cookie.
    key: Key,
}
impl<S> CsrfService<S> {
    /// Signs `token` with the service key and returns the resulting cookie value.
    ///
    /// Signing is delegated to the `cookie` crate's signed jar, so the output is exactly what
    /// that jar stores for a cookie named `config.cookie_name`.
    fn sign_token(&self, token: &str) -> String {
        let mut jar = CookieJar::new();
        jar.signed_mut(&self.key).add(Cookie::new(
            self.config.cookie_name.clone(),
            token.to_string(),
        ));
        jar.get(&self.config.cookie_name)
            .expect("cookie was just added")
            .value()
            .to_string()
    }

    /// Verifies a signed cookie value and returns the embedded token.
    ///
    /// Returns `None` when the signed jar rejects the value (e.g. a missing or invalid
    /// signature for this service's key).
    fn verify_token(&self, signed_value: &str) -> Option<String> {
        let mut jar = CookieJar::new();
        jar.add_original(Cookie::new(
            self.config.cookie_name.clone(),
            signed_value.to_string(),
        ));
        jar.signed(&self.key)
            .get(&self.config.cookie_name)
            .map(|cookie| cookie.value().to_string())
    }

    /// Renders the `Set-Cookie` header value carrying `signed_value`.
    ///
    /// The cookie is `HttpOnly`, `SameSite=Lax`, scoped to `/`, and expires after
    /// `config.ttl_secs` seconds.
    fn build_set_cookie(&self, signed_value: &str) -> String {
        // `ttl_secs` is a u64; a bare `as i64` would wrap huge values into a negative
        // Max-Age, so clamp to the largest representable duration instead.
        let max_age_secs = i64::try_from(self.config.ttl_secs).unwrap_or(i64::MAX);
        Cookie::build((self.config.cookie_name.clone(), signed_value.to_string()))
            .http_only(true)
            .same_site(SameSite::Lax)
            .path("/")
            .max_age(cookie::time::Duration::seconds(max_age_secs))
            .build()
            .to_string()
    }

    /// Returns `true` if `method` appears in `config.exempt_methods` (ASCII case-insensitive).
    fn is_exempt(&self, method: &Method) -> bool {
        self.config
            .exempt_methods
            .iter()
            .any(|m| m.eq_ignore_ascii_case(method.as_str()))
    }

    /// Reads the client-submitted token from the `config.header_name` request header.
    ///
    /// Only the header is consulted; `config.field_name` is not read here. Returns `None` if
    /// the header is absent or not visible ASCII.
    fn extract_submitted_token<B>(&self, request: &Request<B>) -> Option<String> {
        request
            .headers()
            .get(&self.config.header_name)
            .and_then(|v| v.to_str().ok())
            .map(str::to_string)
    }

    /// Returns the raw (still signed) value of the first cookie named `config.cookie_name`.
    ///
    /// Every `Cookie` header is scanned, not just the first: HTTP/2 clients may split cookies
    /// across several header fields, and only reading one of them could miss the CSRF cookie.
    /// Header values that are not visible ASCII are skipped.
    fn extract_cookie_value<B>(&self, request: &Request<B>) -> Option<String> {
        request
            .headers()
            .get_all(http::header::COOKIE)
            .iter()
            .filter_map(|header| header.to_str().ok())
            .flat_map(|header| header.split(';'))
            .filter_map(|pair| pair.trim().split_once('='))
            .find(|(name, _)| name.trim() == self.config.cookie_name)
            .map(|(_, value)| value.trim().to_string())
    }
}
/// Double-submit CSRF protection.
///
/// * Exempt methods (see `CsrfConfig::exempt_methods`) pass through unconditionally. The
///   request gets a [`CsrfToken`] extension and the response gets a refreshed signed
///   `Set-Cookie` plus the same [`CsrfToken`] extension. The token from an existing, validly
///   signed cookie is reused; a new one is minted only when there is none.
/// * All other methods must carry a validly signed cookie whose token equals (constant-time
///   comparison) the value of the `header_name` header. On success the token is exposed as a
///   [`CsrfToken`] request extension and the request is forwarded. Otherwise a "forbidden"
///   error response is returned without calling the inner service.
impl<S, ReqBody> Service<Request<ReqBody>> for CsrfService<S>
where
    S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
    ReqBody: Send + 'static,
{
    type Response = Response<Body>;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
        // `poll_ready` readied `self.inner`; move that instance into the future and leave a
        // fresh clone behind for the next call.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);

        // Token carried by a validly signed CSRF cookie, if the client sent one.
        let cookie_token = self
            .extract_cookie_value(&request)
            .and_then(|signed| self.verify_token(&signed));

        if self.is_exempt(request.method()) {
            // Reuse the client's current token instead of rotating it on every safe request.
            // Rotating would let any concurrent GET (another tab, a background fetch)
            // invalidate the token already embedded in a previously rendered page.
            let token = cookie_token.unwrap_or_else(crate::id::ulid);
            let signed_value = self.sign_token(&token);
            // Re-sent even when reused, which refreshes the cookie's Max-Age.
            let set_cookie_value = self.build_set_cookie(&signed_value);
            request.extensions_mut().insert(CsrfToken(token.clone()));
            return Box::pin(async move {
                let mut response = inner.call(request).await?;
                if let Ok(header_value) = HeaderValue::from_str(&set_cookie_value) {
                    response
                        .headers_mut()
                        .append(http::header::SET_COOKIE, header_value);
                }
                response.extensions_mut().insert(CsrfToken(token));
                Ok(response)
            });
        }

        let verified_token = cookie_token
            .zip(self.extract_submitted_token(&request))
            .and_then(|(cookie_token, header_token)| {
                use subtle::ConstantTimeEq;
                // Constant-time comparison avoids leaking the token through timing.
                bool::from(cookie_token.as_bytes().ct_eq(header_token.as_bytes()))
                    .then_some(cookie_token)
            });

        match verified_token {
            Some(token) => {
                // Expose the validated token so handlers can e.g. re-render a form.
                request.extensions_mut().insert(CsrfToken(token));
                Box::pin(async move { inner.call(request).await })
            }
            None => {
                let header_name = self.config.header_name.clone();
                Box::pin(async move {
                    let error = crate::error::Error::forbidden(format!(
                        "CSRF validation failed: missing or invalid {header_name}"
                    ));
                    Ok(error.into_response())
                })
            }
        }
    }
}
/// Creates a [`CsrfLayer`] that owns its own copies of `config` and the signing `key`.
pub fn csrf(config: &CsrfConfig, key: &Key) -> CsrfLayer {
    let owned_config = config.clone();
    let owned_key = key.clone();
    CsrfLayer {
        config: owned_config,
        key: owned_key,
    }
}