use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use axum::{
body::Body,
http::{
Method, Request, StatusCode,
header::{COOKIE, HeaderMap, HeaderName},
},
response::Response,
};
use tower::{Layer, Service};
/// Heap-allocated, `Send` future produced by [`CsrfMiddleware`]'s `call`: it resolves either to
/// the wrapped service's response or to the middleware's own rejection, failing with `E`.
type BoxResponseFuture<E> = Pin<Box<dyn Future<Output = Result<Response, E>> + Send + 'static>>;
/// Tower [`Layer`] that wraps a service in [`CsrfMiddleware`].
///
/// Requests with unsafe methods must carry the same non-empty token in the cookie named
/// `cookie_name` and in the header named `header_name`; otherwise they are rejected.
#[derive(Clone, Debug)]
pub struct CsrfMiddlewareLayer {
    /// Name of the cookie holding the expected CSRF token.
    pub cookie_name: String,
    /// Name of the request header that must echo the cookie's token.
    pub header_name: String,
}
impl<S> Layer<S> for CsrfMiddlewareLayer {
    type Service = CsrfMiddleware<S>;

    /// Wraps `inner` in a [`CsrfMiddleware`] that carries its own copies of the
    /// configured cookie and header names.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            cookie_name,
            header_name,
        } = self;
        CsrfMiddleware {
            inner,
            cookie_name: cookie_name.clone(),
            header_name: header_name.clone(),
        }
    }
}
/// Service produced by [`CsrfMiddlewareLayer`]: performs the CSRF token check and
/// forwards accepted requests to `inner`.
#[derive(Clone, Debug)]
pub struct CsrfMiddleware<S> {
    /// The wrapped service that handles requests passing the check.
    inner: S,
    /// Name of the cookie holding the expected CSRF token.
    cookie_name: String,
    /// Name of the request header that must echo the cookie's token.
    header_name: String,
}
impl<S> Service<Request<Body>> for CsrfMiddleware<S>
where
    S: Service<Request<Body>, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = Response;
    type Error = S::Error;
    type Future = BoxResponseFuture<Self::Error>;

    /// Readiness is delegated unchanged to the wrapped service; the CSRF check has no
    /// readiness state of its own.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Forwards `request` to the wrapped service when its method is safe or it carries a
    /// matching cookie/header token pair. Otherwise it answers with `403 Forbidden` and the
    /// body "CSRF verification failed" without calling the wrapped service.
    fn call(&mut self, request: Request<Body>) -> Self::Future {
        // `||` short-circuits: the token lookup only runs for unsafe methods.
        let permitted = is_safe_method(request.method())
            || has_matching_token(request.headers(), &self.cookie_name, &self.header_name);

        if permitted {
            Box::pin(self.inner.call(request))
        } else {
            Box::pin(async {
                let mut rejection = Response::new(Body::from("CSRF verification failed"));
                *rejection.status_mut() = StatusCode::FORBIDDEN;
                Ok(rejection)
            })
        }
    }
}
fn is_safe_method(method: &Method) -> bool {
matches!(method.as_str(), "GET" | "HEAD" | "OPTIONS" | "TRACE")
}
/// Returns `true` when the cookie named `cookie_name` holds a non-empty token that equals the
/// value of the `header_name` header (double-submit cookie check).
///
/// Every `Cookie` header field is searched, not just the first: HTTP/2 clients may split
/// cookies across several `Cookie` fields. The check fails closed (returns `false`) in
/// these cases:
/// - `header_name` is not a valid header name;
/// - either value is not visible ASCII;
/// - the cookie or the header is missing.
fn has_matching_token(headers: &HeaderMap, cookie_name: &str, header_name: &str) -> bool {
    let cookie_token = headers
        .get_all(COOKIE)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .find_map(|value| cookie_value(value, cookie_name));
    let header_token = HeaderName::from_bytes(header_name.as_bytes())
        .ok()
        .and_then(|name| headers.get(name))
        .and_then(|value| value.to_str().ok());

    match (cookie_token, header_token) {
        // An empty cookie would let an empty (or absent-but-defaulted) header match.
        (Some(cookie), Some(header)) if !cookie.is_empty() => {
            constant_time_eq(cookie.as_bytes(), header.as_bytes())
        }
        _ => false,
    }
}

/// Byte-wise equality that does not stop at the first differing byte. This avoids leaking
/// how much of the secret cookie token an attacker-supplied header got right.
///
/// Lengths are compared up front, so only the token length can leak. This is best-effort:
/// the optimizer is not formally prevented from reintroducing an early exit.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}
/// Looks up `cookie_name` in a raw `Cookie` header value such as `"a=1; b=2"`.
///
/// Returns the value of the first `name=value` pair whose name matches exactly. The value
/// is everything after the first `=`, and surrounding whitespace is trimmed from each pair.
/// Pairs without an `=` are skipped.
fn cookie_value<'a>(cookie_header: &'a str, cookie_name: &str) -> Option<&'a str> {
    for pair in cookie_header.split(';').map(str::trim) {
        if let Some((key, val)) = pair.split_once('=') {
            if key == cookie_name {
                return Some(val);
            }
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use axum::http::HeaderValue;
    use tower::{ServiceExt, service_fn};

    use super::*;

    /// Layer configured with the cookie/header names used throughout these tests.
    fn test_layer() -> CsrfMiddlewareLayer {
        CsrfMiddlewareLayer {
            cookie_name: "csrftoken".to_string(),
            header_name: "x-csrftoken".to_string(),
        }
    }

    /// Builds a request to `/` with the given method and `(name, value)` headers.
    fn request(method: Method, headers: &[(&str, &str)]) -> Request<Body> {
        let mut builder = Request::builder().method(method).uri("/");
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        builder.body(Body::empty()).expect("request should build")
    }

    /// Sends `request` through the CSRF layer wrapping a service that always answers
    /// `200 OK`, and returns the resulting status.
    async fn status_for(request: Request<Body>) -> StatusCode {
        let service = test_layer().layer(service_fn(|_request: Request<Body>| async move {
            Ok::<_, Infallible>(Response::new(Body::from("ok")))
        }));
        service
            .oneshot(request)
            .await
            .expect("service should respond")
            .status()
    }

    #[tokio::test]
    async fn csrf_layer_allows_safe_methods_without_tokens() {
        assert_eq!(status_for(request(Method::GET, &[])).await, StatusCode::OK);
        assert_eq!(status_for(request(Method::HEAD, &[])).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn csrf_layer_rejects_unsafe_methods_without_tokens() {
        assert_eq!(status_for(request(Method::POST, &[])).await, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn csrf_layer_rejects_unsafe_methods_with_mismatched_tokens() {
        let request = request(
            Method::POST,
            &[("cookie", "csrftoken=expected"), ("x-csrftoken", "actual")],
        );
        assert_eq!(status_for(request).await, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn csrf_layer_allows_unsafe_methods_with_matching_tokens() {
        let request = request(
            Method::POST,
            &[("cookie", "csrftoken=match"), ("x-csrftoken", "match")],
        );
        assert_eq!(status_for(request).await, StatusCode::OK);
    }

    #[test]
    fn cookie_value_extracts_named_cookie_from_multi_cookie_header() {
        assert_eq!(
            cookie_value("sessionid=abc; csrftoken=expected; theme=dark", "csrftoken"),
            Some("expected")
        );
    }

    #[test]
    fn has_matching_token_rejects_empty_header_value() {
        let headers = HeaderMap::from_iter([
            (COOKIE, HeaderValue::from_static("csrftoken=expected")),
            (
                HeaderName::from_static("x-csrftoken"),
                HeaderValue::from_static(""),
            ),
        ]);
        assert!(!has_matching_token(&headers, "csrftoken", "x-csrftoken"));
    }

    #[test]
    fn has_matching_token_rejects_empty_cookie_even_when_header_is_empty_too() {
        // Both sides are empty and therefore "equal"; only the non-empty guard rejects this.
        let headers = HeaderMap::from_iter([
            (COOKIE, HeaderValue::from_static("csrftoken=")),
            (
                HeaderName::from_static("x-csrftoken"),
                HeaderValue::from_static(""),
            ),
        ]);
        assert!(!has_matching_token(&headers, "csrftoken", "x-csrftoken"));
    }

    #[test]
    fn csrf_layer_is_cloneable() {
        let layer = test_layer();
        let copy = layer.clone();
        assert_eq!(copy.cookie_name, layer.cookie_name);
        assert_eq!(copy.header_name, layer.header_name);
    }
}