use crate::csrf::CsrfTokens;
use axum::body::Body;
use axum::http::{HeaderMap, Method, Request, Response, StatusCode};
use cookie::Cookie;
use http::header;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower::{Layer, Service};
/// Settings shared by every request the CSRF middleware handles.
#[derive(Clone)]
struct CsrfConfig {
    /// Produces fresh token values and checks cookie/header tokens.
    tokens: CsrfTokens,
    /// Whether the issued token cookie carries the `Secure` attribute.
    secure: bool,
    /// `SameSite` attribute applied to the issued token cookie.
    same_site: cookie::SameSite,
    /// Name of the cookie that carries the token.
    cookie_name: String,
    /// Request header whose value is verified against the cookie.
    header_name: String,
    /// Path prefixes exempt from verification (`CsrfLayer::exclude` ensures a leading `/`).
    excludes: Vec<String>,
}
/// Tower layer that checks a token cookie against a request header on
/// state-changing requests and issues the token cookie when it is missing
/// or invalid.
#[derive(Clone)]
pub struct CsrfLayer {
    /// Configuration copied into every middleware this layer builds.
    config: CsrfConfig,
}
impl CsrfLayer {
    /// Creates a layer whose tokens are produced and checked by a
    /// [`CsrfTokens`] built from `secret`.
    ///
    /// Defaults: `Secure` cookie, `SameSite=Lax`, cookie name `XSRF-TOKEN`,
    /// request header `x-xsrf-token`, and no excluded paths.
    #[must_use]
    pub fn new(secret: impl Into<Vec<u8>>) -> Self {
        Self {
            config: CsrfConfig {
                tokens: CsrfTokens::new(secret),
                secure: true,
                same_site: cookie::SameSite::Lax,
                cookie_name: "XSRF-TOKEN".into(),
                header_name: "x-xsrf-token".into(),
                excludes: Vec::new(),
            },
        }
    }

    /// Sets whether the token cookie carries the `Secure` attribute
    /// (default `true`).
    ///
    /// Note: browsers reject `SameSite=None` cookies that are not `Secure`,
    /// so combine `secure(false)` with `SameSite::None` only with care.
    #[must_use]
    pub fn secure(mut self, yes: bool) -> Self {
        self.config.secure = yes;
        self
    }

    /// Sets the `SameSite` attribute of the token cookie (default `Lax`).
    #[must_use]
    pub fn same_site(mut self, s: cookie::SameSite) -> Self {
        self.config.same_site = s;
        self
    }

    /// Sets the name of the token cookie (default `XSRF-TOKEN`).
    #[must_use]
    pub fn cookie_name(mut self, name: impl Into<String>) -> Self {
        self.config.cookie_name = name.into();
        self
    }

    /// Sets the request header whose value is verified against the token
    /// cookie on checked requests (default `x-xsrf-token`).
    #[must_use]
    pub fn header_name(mut self, name: impl Into<String>) -> Self {
        self.config.header_name = name.into();
        self
    }

    /// Exempts `path` and everything beneath it from token verification.
    ///
    /// A leading `/` is added when missing; trailing slashes are ignored when
    /// matching. Excluding `/` itself has no effect. Excluded requests still
    /// receive a fresh token cookie when they lack a valid one.
    #[must_use]
    pub fn exclude(mut self, path: impl Into<String>) -> Self {
        let mut p = path.into();
        if !p.starts_with('/') {
            p.insert(0, '/');
        }
        self.config.excludes.push(p);
        self
    }
}
impl<S> Layer<S> for CsrfLayer {
    type Service = CsrfMiddleware<S>;

    /// Wraps `inner`, handing the new middleware its own shared snapshot of
    /// this layer's configuration.
    fn layer(&self, inner: S) -> Self::Service {
        let shared = Arc::new(self.config.clone());
        CsrfMiddleware {
            config: shared,
            inner,
        }
    }
}
/// Service produced by [`CsrfLayer`]; wraps the inner service.
#[doc(hidden)]
#[derive(Clone)]
pub struct CsrfMiddleware<S> {
    /// The wrapped service that handles requests passing verification.
    inner: S,
    /// Configuration shared (via `Arc`) by all clones of this middleware.
    config: Arc<CsrfConfig>,
}
impl<S> Service<Request<Body>> for CsrfMiddleware<S>
where
S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Send + 'static,
{
type Response = Response<Body>;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let cfg = self.config.clone();
let clone = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, clone);
Box::pin(async move {
let cookie_val = read_cookie(req.headers(), &cfg.cookie_name);
let cookie_is_valid = cookie_val
.as_deref()
.map(|c| cfg.tokens.is_valid(c))
.unwrap_or(false);
let is_mutating = matches!(
*req.method(),
Method::POST | Method::PUT | Method::PATCH | Method::DELETE
);
let excluded = path_excluded(req.uri().path(), &cfg.excludes);
if is_mutating && !excluded {
let header_val = req
.headers()
.get(cfg.header_name.as_str())
.and_then(|v| v.to_str().ok());
let ok = match (cookie_val.as_deref(), header_val) {
(Some(c), Some(h)) => cfg.tokens.verify(c, h),
_ => false,
};
if !ok {
let mut resp = Response::new(Body::from("CSRF token mismatch"));
*resp.status_mut() = StatusCode::from_u16(419).unwrap();
resp.headers_mut().insert(
header::CONTENT_TYPE,
http::HeaderValue::from_static("text/plain; charset=utf-8"),
);
if !cookie_is_valid {
set_token_cookie(resp.headers_mut(), &cfg);
}
return Ok(resp);
}
}
let mut resp = inner.call(req).await?;
if !cookie_is_valid {
set_token_cookie(resp.headers_mut(), &cfg);
}
Ok(resp)
})
}
}
/// Returns `true` when `path` equals one of `excludes` or lies beneath one of
/// them (the remainder after the prefix starts with a `/` separator).
///
/// Trailing slashes on each entry are ignored; an entry that becomes empty
/// after trimming (such as `"/"`) never matches.
fn path_excluded(path: &str, excludes: &[String]) -> bool {
    for entry in excludes {
        let prefix = entry.trim_end_matches('/');
        if prefix.is_empty() {
            continue;
        }
        if let Some(rest) = path.strip_prefix(prefix) {
            if rest.is_empty() || rest.starts_with('/') {
                return true;
            }
        }
    }
    false
}
/// Returns the value of the first cookie named `name` found across all
/// `Cookie` request headers, scanning headers and `;`-separated pairs in order.
///
/// Header values that are not visible ASCII and pairs that fail to parse are
/// skipped.
fn read_cookie(headers: &HeaderMap, name: &str) -> Option<String> {
    for raw in headers.get_all(header::COOKIE).iter() {
        let text = match raw.to_str() {
            Ok(text) => text,
            Err(_) => continue,
        };
        for pair in text.split(';') {
            if let Ok(cookie) = Cookie::parse(pair.trim()) {
                if cookie.name() == name {
                    return Some(cookie.value().to_owned());
                }
            }
        }
    }
    None
}
/// Appends a `Set-Cookie` header carrying a freshly generated token.
///
/// The cookie is scoped to path `/` and uses the configured `Secure` and
/// `SameSite` settings. If the serialized cookie is not a valid header value,
/// nothing is appended.
fn set_token_cookie(headers: &mut HeaderMap, cfg: &CsrfConfig) {
    let token = cfg.tokens.generate();
    let mut token_cookie = Cookie::new(cfg.cookie_name.clone(), token);
    token_cookie.set_path("/");
    token_cookie.set_secure(cfg.secure);
    token_cookie.set_same_site(cfg.same_site);
    let serialized = token_cookie.to_string();
    if let Ok(value) = http::HeaderValue::from_str(&serialized) {
        headers.append(header::SET_COOKIE, value);
    }
}