//! Synchronizer-token CSRF protection middleware for `axum`, backed by `axum_sessions`.
//!
//! [`CsrfLayer`] stores a random token in the session (under [`CsrfLayer::session_key`]),
//! sends it to the client in a response header ([`CsrfLayer::response_header`]), and requires
//! requests with a non-safe HTTP method to send it back in a request header
//! ([`CsrfLayer::request_header`]). A missing or mismatching token is rejected with
//! `403 Forbidden`.
//!
//! `axum_sessions::SessionLayer` must be layered *around* the [`CsrfLayer`]: the middleware
//! reads the `SessionHandle` request extension that the session layer inserts.
#![forbid(unsafe_code, future_incompatible)]
#![deny(
    missing_debug_implementations,
    nonstandard_style,
    missing_docs,
    unreachable_pub,
    missing_copy_implementations,
    unused_qualifications
)]
use std::{
convert::Infallible,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use axum::http::{self, HeaderValue, Request, StatusCode};
use axum_core::response::{IntoResponse, Response};
use axum_sessions::{async_session::Session, SessionHandle};
use rand::RngCore;
use tokio::sync::RwLockWriteGuard;
use tower::Layer;
/// Tower [`Layer`] that adds synchronizer-token CSRF protection to a service.
///
/// Build one with [`CsrfLayer::new`] (or [`Default`]) and customize it with the builder
/// methods. `axum_sessions::SessionLayer` must be layered around this layer, because the
/// resulting [`CsrfMiddleware`] reads the session from the request extensions.
#[derive(Clone, Copy, Debug)]
pub struct CsrfLayer {
    /// When the token stored in the session is regenerated; see [`RegenerateToken`].
    pub regenerate_token: RegenerateToken,
    /// Name of the request header in which clients must send the token on requests with a
    /// non-safe HTTP method. Must be a valid HTTP header name. Defaults to `X-CSRF-TOKEN`.
    pub request_header: &'static str,
    /// Name of the response header in which the current token is sent to the client.
    /// Must be a valid HTTP header name. Defaults to `X-CSRF-TOKEN`.
    pub response_header: &'static str,
    /// Session key under which the token is stored. Defaults to `_csrf_token`.
    pub session_key: &'static str,
}
impl Default for CsrfLayer {
fn default() -> Self {
Self {
regenerate_token: Default::default(),
request_header: "X-CSRF-TOKEN",
response_header: "X-CSRF-TOKEN",
session_key: "_csrf_token",
}
}
}
impl CsrfLayer {
pub fn new() -> Self {
Self::default()
}
pub fn regenerate(mut self, regenerate_token: RegenerateToken) -> Self {
self.regenerate_token = regenerate_token;
self
}
pub fn request_header(mut self, request_header: &'static str) -> Self {
self.request_header = request_header;
self
}
pub fn response_header(mut self, response_header: &'static str) -> Self {
self.response_header = response_header;
self
}
pub fn session_key(mut self, session_key: &'static str) -> Self {
self.session_key = session_key;
self
}
fn regenerate_token(
&self,
session_write: &mut RwLockWriteGuard<Session>,
) -> Result<String, Error> {
let mut buf = [0; 32];
rand::thread_rng().try_fill_bytes(&mut buf)?;
let token = base64::encode(buf);
session_write.insert(self.session_key, &token)?;
Ok(token)
}
fn response_with_token(&self, mut response: Response, server_token: &str) -> Response {
response.headers_mut().insert(
self.response_header,
match HeaderValue::from_str(server_token).map_err(Error::from) {
Ok(token_header) => token_header,
Err(error) => return error.into_response(),
},
);
response
}
}
impl<S> Layer<S> for CsrfLayer {
    type Service = CsrfMiddleware<S>;

    /// Wraps `inner` in a [`CsrfMiddleware`] that carries a copy of this configuration.
    fn layer(&self, inner: S) -> Self::Service {
        let config: CsrfLayer = *self;
        CsrfMiddleware::new(inner, config)
    }
}
/// Controls when [`CsrfLayer`] replaces the token stored in the session.
///
/// A token is always generated when the session does not contain one yet. Beyond that,
/// `PerUse` and `PerRequest` regenerate the token only after a request has passed CSRF
/// validation.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
// Every variant deliberately shares the `Per` prefix because it reads naturally at call
// sites (`RegenerateToken::PerUse`); Clippy's common-prefix lint is silenced for that reason.
#[allow(clippy::enum_variant_names)]
pub enum RegenerateToken {
    /// Keep one token for as long as it remains in the session. This is the default.
    #[default]
    PerSession,
    /// Generate a new token after every successfully validated request with a non-safe
    /// HTTP method (e.g. `POST`).
    PerUse,
    /// Generate a new token after every request that passes validation, including requests
    /// with safe methods such as `GET`.
    PerRequest,
}
#[derive(thiserror::Error, Debug)]
enum Error {
#[error("Random number generator error")]
Rng(#[from] rand::Error),
#[error("Serde JSON error")]
Serde(#[from] axum_sessions::async_session::serde_json::Error),
#[error("Session extension missing. Is `axum_sessions::SessionLayer` installed and layered around the `axum_csrf_sync_pattern::CsrfLayer`?")]
SessionLayerMissing,
#[error("Incoming CSRF token header was not valid ASCII")]
InvalidClientTokenHeader(#[from] http::header::ToStrError),
#[error("Invalid CSRF token when preparing response header")]
InvalidServerTokenHeader(#[from] http::header::InvalidHeaderValue),
}
impl IntoResponse for Error {
fn into_response(self) -> Response {
tracing::error!(?self);
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
/// Tower service produced by [`CsrfLayer`]. It validates CSRF tokens before forwarding
/// requests to the wrapped service and attaches the current token to responses.
#[derive(Debug, Clone)]
pub struct CsrfMiddleware<S> {
    /// The wrapped service; it only receives requests that passed validation.
    inner: S,
    /// Configuration copied from the layer that created this middleware.
    layer: CsrfLayer,
}
impl<S> CsrfMiddleware<S> {
    /// Wraps `inner` in CSRF middleware configured by `layer`.
    #[must_use]
    pub fn new(inner: S, layer: CsrfLayer) -> Self {
        Self { inner, layer }
    }

    /// Returns a [`CsrfLayer`] with the default configuration.
    #[must_use]
    pub fn layer() -> CsrfLayer {
        CsrfLayer::default()
    }
}
impl<S, B: Send + 'static> tower::Service<Request<B>> for CsrfMiddleware<S>
where
    S: tower::Service<Request<B>, Response = Response, Error = Infallible> + Send + Clone + 'static,
    S::Future: Send,
{
    type Response = S::Response;
    type Error = Infallible;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Validates the CSRF token of `req` and, if it passes, forwards `req` to the inner service.
    ///
    /// 1. Reads the token from the session, generating and storing one if none exists.
    /// 2. For non-safe HTTP methods, requires the configured request header to carry exactly
    ///    that token; otherwise responds with `403 Forbidden`.
    /// 3. Regenerates the token as configured by [`CsrfLayer::regenerate_token`].
    /// 4. Calls the inner service and adds the current token to its response.
    ///
    /// Responses produced after the session token is known (including rejections) carry that
    /// token in the response header.
    fn call(&mut self, req: Request<B>) -> Self::Future {
        // Move the instance that `poll_ready` was called on into the future, and leave a
        // clone in its place for the next request.
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);
        let layer = self.layer;

        Box::pin(async move {
            let session_handle = match req
                .extensions()
                .get::<SessionHandle>()
                .ok_or(Error::SessionLayerMissing)
            {
                Ok(session_handle) => session_handle,
                Err(error) => return Ok(error.into_response()),
            };

            // The write lock is only held for token bookkeeping. It is dropped before the
            // inner service runs, so the inner service can lock the session itself.
            let mut session_write = session_handle.write().await;

            let mut server_token = match session_write.get::<String>(layer.session_key) {
                Some(token) => token,
                None => match layer.regenerate_token(&mut session_write) {
                    Ok(token) => token,
                    Err(error) => return Ok(error.into_response()),
                },
            };

            if !req.method().is_safe() {
                let client_token = match req.headers().get(layer.request_header) {
                    Some(token) => token,
                    None => {
                        tracing::warn!("{} header missing!", layer.request_header);
                        return Ok(layer.response_with_token(
                            StatusCode::FORBIDDEN.into_response(),
                            &server_token,
                        ));
                    }
                };

                let client_token = match client_token.to_str().map_err(Error::from) {
                    Ok(token) => token,
                    Err(error) => {
                        return Ok(layer.response_with_token(error.into_response(), &server_token))
                    }
                };

                // `!=` on strings stops at the first differing byte. Compare in constant time
                // so response timing does not reveal how much of a guessed token was right.
                if !constant_time_eq(client_token.as_bytes(), server_token.as_bytes()) {
                    tracing::warn!("{} header mismatch!", layer.request_header);
                    return Ok(layer.response_with_token(
                        StatusCode::FORBIDDEN.into_response(),
                        &server_token,
                    ));
                }
            }

            // Rotation only happens once validation has passed: `PerRequest` on every request,
            // `PerUse` only after a non-safe request consumed the token.
            if layer.regenerate_token == RegenerateToken::PerRequest
                || (!req.method().is_safe() && layer.regenerate_token == RegenerateToken::PerUse)
            {
                server_token = match layer.regenerate_token(&mut session_write) {
                    Ok(token) => token,
                    Err(error) => {
                        return Ok(layer.response_with_token(error.into_response(), &server_token))
                    }
                };
            }

            drop(session_write);

            let response = inner.call(req).await.into_response();
            Ok(layer.response_with_token(response, &server_token))
        })
    }
}

/// Compares two byte strings without short-circuiting on the first differing byte.
///
/// Lengths are checked up front. Tokens generated by this crate (base64 of 32 bytes) all
/// have the same length, so that check reveals nothing useful. The byte comparison then
/// folds over every position. This is best-effort: the compiler gives no formal
/// constant-time guarantee.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
// End-to-end tests: requests are driven through an axum `Router` that layers the CSRF
// middleware inside an in-memory `SessionLayer`.
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use axum::{body::Body, routing::get, Router};
    use axum_core::response::{IntoResponse, Response};
    use axum_sessions::{async_session::MemoryStore, extractors::ReadableSession, SessionLayer};
    use http::{
        header::{COOKIE, SET_COOKIE},
        Method, Request, StatusCode,
    };
    use tower::{Service, ServiceExt};

    use super::*;

    // Handler that always succeeds with a non-empty body.
    async fn handler() -> Result<Response, Infallible> {
        Ok((
            StatusCode::OK,
            "The default test success response has a body",
        )
            .into_response())
    }

    // Session layer backed by a fresh in-memory store, with a random 64-byte cookie secret.
    fn session_layer() -> SessionLayer<MemoryStore> {
        let mut secret = [0; 64];
        rand::thread_rng().try_fill_bytes(&mut secret).unwrap();
        SessionLayer::new(MemoryStore::new(), &secret)
    }

    // Test app: GET and POST on `/`, with the session layer wrapped around the CSRF layer.
    fn app(csrf_layer: CsrfLayer) -> Router {
        Router::new()
            .route("/", get(handler).post(handler))
            .layer(csrf_layer)
            .layer(session_layer())
    }

    // Safe methods are not validated, but the response still carries a token that decodes
    // to 32 bytes.
    #[tokio::test]
    async fn get_without_token_succeeds() {
        let request = Request::builder()
            .method(Method::GET)
            .body(Body::empty())
            .unwrap();

        let response = app(CsrfLayer::new()).oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();

        assert_eq!(base64::decode(client_token).unwrap().len(), 32);
    }

    // A non-safe request without the token header is rejected, yet still receives a token.
    #[tokio::test]
    async fn post_without_token_fails() {
        let request = Request::builder()
            .method(Method::POST)
            .body(Body::empty())
            .unwrap();

        let response = app(CsrfLayer::new()).oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);

        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();

        assert_eq!(base64::decode(client_token).unwrap().len(), 32);
    }

    // `PerSession`: the initial token is accepted repeatedly and echoed back unchanged.
    #[tokio::test]
    async fn session_token_remains_valid() {
        let mut app = app(CsrfLayer::new().regenerate(RegenerateToken::PerSession));

        // First request obtains the session cookie and the initial token.
        let response = app
            .ready()
            .await
            .unwrap()
            .call(Request::builder().body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let session_cookie = response.headers().get(SET_COOKIE).unwrap().clone();
        let initial_client_token = response.headers().get("X-CSRF-TOKEN").unwrap();

        assert_eq!(base64::decode(initial_client_token).unwrap().len(), 32);

        // First use of the token succeeds and the token does not change.
        let response = app
            .ready()
            .await
            .unwrap()
            .call(
                Request::builder()
                    .method(Method::POST)
                    .header("X-CSRF-TOKEN", initial_client_token)
                    .header(COOKIE, session_cookie.clone())
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();

        assert_eq!(client_token, initial_client_token);

        // Reusing the same token succeeds as well.
        let response = app
            .ready()
            .await
            .unwrap()
            .call(
                Request::builder()
                    .method(Method::POST)
                    .header("X-CSRF-TOKEN", initial_client_token)
                    .header(COOKIE, session_cookie)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();

        assert_eq!(client_token, initial_client_token);
    }

    // `PerUse`: a successful POST rotates the token, so replaying the old token is rejected.
    #[tokio::test]
    async fn single_use_token_is_regenerated() {
        let mut app = app(CsrfLayer::new().regenerate(RegenerateToken::PerUse));

        // First request obtains the session cookie and the initial token.
        let response = app
            .ready()
            .await
            .unwrap()
            .call(Request::builder().body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let session_cookie = response.headers().get(SET_COOKIE).unwrap().clone();
        let initial_client_token = response.headers().get("X-CSRF-TOKEN").unwrap();

        assert_eq!(base64::decode(initial_client_token).unwrap().len(), 32);

        // Using the token succeeds and a new token is returned.
        let response = app
            .ready()
            .await
            .unwrap()
            .call(
                Request::builder()
                    .method(Method::POST)
                    .header("X-CSRF-TOKEN", initial_client_token)
                    .header(COOKIE, session_cookie.clone())
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();

        assert_ne!(client_token, initial_client_token);

        // Replaying the consumed token is forbidden.
        let response = app
            .ready()
            .await
            .unwrap()
            .call(
                Request::builder()
                    .method(Method::POST)
                    .header("X-CSRF-TOKEN", initial_client_token)
                    .header(COOKIE, session_cookie)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);

        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();

        assert_ne!(client_token, initial_client_token);
    }

    // `PerRequest`: even a GET rotates the token; the latest token is accepted by a POST.
    #[tokio::test]
    async fn single_request_token_is_regenerated() {
        let mut app = app(CsrfLayer::new().regenerate(RegenerateToken::PerRequest));

        // First request obtains the session cookie and the initial token.
        let response = app
            .ready()
            .await
            .unwrap()
            .call(Request::builder().body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let session_cookie = response.headers().get(SET_COOKIE).unwrap().clone();
        let initial_client_token = response.headers().get("X-CSRF-TOKEN").unwrap();

        assert_eq!(base64::decode(initial_client_token).unwrap().len(), 32);

        // A plain GET within the same session already yields a different token.
        let response = app
            .ready()
            .await
            .unwrap()
            .call(
                Request::builder()
                    .method(Method::GET)
                    .header(COOKIE, session_cookie.clone())
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();

        assert_ne!(client_token, initial_client_token);

        // The most recent token (not the initial one) is accepted.
        let response = app
            .ready()
            .await
            .unwrap()
            .call(
                Request::builder()
                    .method(Method::POST)
                    .header("X-CSRF-TOKEN", client_token)
                    .header(COOKIE, session_cookie)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();

        assert_ne!(client_token, initial_client_token);
    }

    // A custom request header name is read on non-safe requests; the response header keeps
    // its default name.
    #[tokio::test]
    async fn accepts_custom_request_header() {
        let mut app = app(CsrfLayer::new().request_header("X-Custom-Token-Request-Header"));

        let response = app
            .ready()
            .await
            .unwrap()
            .call(Request::builder().body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let session_cookie = response.headers().get(SET_COOKIE).unwrap().clone();
        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();

        assert_eq!(base64::decode(client_token).unwrap().len(), 32);

        let response = app
            .ready()
            .await
            .unwrap()
            .call(
                Request::builder()
                    .method(Method::POST)
                    .header("X-Custom-Token-Request-Header", client_token)
                    .header(COOKIE, session_cookie.clone())
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    // The token is sent under a custom response header name.
    #[tokio::test]
    async fn sends_custom_response_header() {
        let response = app(CsrfLayer::new().response_header("X-Custom-Token-Response-Header"))
            .oneshot(Request::builder().body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let client_token = response
            .headers()
            .get("X-Custom-Token-Response-Header")
            .unwrap();

        assert_eq!(base64::decode(client_token).unwrap().len(), 32);
    }

    // The token is stored in the session under the custom key; the handler reads it directly.
    #[tokio::test]
    async fn uses_custom_session_key() {
        // Handler that asserts a 32-byte token exists under the custom session key.
        async fn extract_session(session: ReadableSession) -> StatusCode {
            let session_csrf_token: String = session.get("custom_session_key").unwrap();

            assert_eq!(base64::decode(session_csrf_token).unwrap().len(), 32);
            StatusCode::OK
        }

        let app = Router::new()
            .route("/", get(extract_session))
            .layer(CsrfLayer::new().session_key("custom_session_key"))
            .layer(session_layer());

        let response = app
            .oneshot(Request::builder().body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    // Without `SessionLayer` the session extension is absent, which yields a server error.
    #[tokio::test]
    async fn missing_session_layer_error_response() {
        let app = Router::new()
            .route("/", get(handler))
            .layer(CsrfLayer::new());

        let response = app
            .oneshot(Request::builder().body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    // A server token that is not a valid header value (a newline) replaces the response with
    // a server error.
    #[tokio::test]
    async fn invalid_token_str_error_response() {
        let layer = CsrfLayer::new();
        let response = Response::builder()
            .status(StatusCode::OK)
            .body(axum::body::boxed(Body::empty()))
            .unwrap();

        let response = layer.response_with_token(response, "\n");

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}