#![cfg_attr(docsrs, feature(doc_cfg))]
#[cfg(feature = "client_ip")]
mod client_address;
use axum::{extract::FromRequestParts, http::StatusCode, response::IntoResponse};
use snafu::{OptionExt as _, Snafu};
/// Errors produced while extracting or validating a CSRF token.
///
/// Each variant maps to an HTTP status via `Error::status_code`, and the type
/// converts into an axum response through its `IntoResponse` impl.
// NOTE: the variant comments below are deliberately plain `//` comments.
// snafu uses a variant's `///` doc comment as its `Display` text when no
// `#[snafu(display(...))]` attribute is given. Without either, the Display
// text is the variant name (e.g. `InvalidCsrf`), which is what is sent as the
// response body today.
#[derive(Debug, Snafu)]
pub enum Error {
    // The submitted form token did not match the expected token
    // (returned by `CsrfToken::validate`).
    InvalidCsrf,
    // No `CsrfToken` was present in the request extensions, typically because
    // none of the CSRF middleware ran for this route (returned by the
    // `CsrfToken` extractor).
    NoMiddleware,
}
impl Error {
pub fn status_code(&self) -> StatusCode {
match self {
Self::InvalidCsrf => StatusCode::BAD_REQUEST,
Self::NoMiddleware => StatusCode::INTERNAL_SERVER_ERROR,
}
}
}
impl IntoResponse for Error {
    /// Renders the error as a plain-text response.
    ///
    /// The status comes from [`Error::status_code`] and the body is the
    /// error's `Display` text.
    fn into_response(self) -> axum::response::Response {
        let status = self.status_code();
        let body = self.to_string();
        (status, body).into_response()
    }
}
/// Middleware that installs a [`CsrfToken`](crate::CsrfToken) into each
/// request's extensions, where the `CsrfToken` extractor later finds it.
///
/// Two strategies are available, each behind a Cargo feature:
///
/// * `client_ip`: the `client_ip` function derives the token from a server
///   secret and the client's address. Nothing is stored on the client.
/// * `cookie`: the `cookie` function keeps a random token in a cookie.
pub mod middleware {
    #[cfg(feature = "client_ip")]
    use std::{net::IpAddr, sync::Arc};

    #[cfg(feature = "client_ip")]
    use axum::extract::State;
    #[cfg(any(feature = "cookie", feature = "client_ip"))]
    use axum::{extract::Request, middleware::Next, response::IntoResponse};
    #[cfg(feature = "cookie")]
    use axum_extra::extract::{CookieJar, cookie::Cookie};

    #[cfg(any(feature = "cookie", feature = "client_ip"))]
    use crate::CsrfToken;
    #[cfg(feature = "client_ip")]
    use crate::client_address::ClientAddress;

    /// Server-side configuration required by the `client_ip` middleware.
    ///
    /// The trait is object-safe. Together with the `Arc` forwarding impl
    /// below, both `Arc<MyConfig>` and `Arc<dyn ClientIpConfig + Send + Sync>`
    /// can be used as router state.
    #[cfg(feature = "client_ip")]
    pub trait ClientIpConfig {
        /// Secret bytes hashed into every token derived by `client_ip`.
        ///
        /// Changing the secret changes every derived token, so tokens issued
        /// under the old secret stop validating.
        fn csrf_secret_key(&self) -> &[u8];

        /// Reports whether `addr` is a trusted forwarder.
        // NOTE(review): not called in this module; presumably consulted by the
        // `ClientAddress` extractor to decide whether forwarding headers from
        // `addr` may be believed. Confirm in `client_address.rs`.
        fn is_trusted_forwarder(&self, addr: IpAddr) -> bool;
    }

    // Forwards to the shared configuration. The `?Sized` bound lets trait
    // objects (`Arc<dyn ClientIpConfig + ...>`) be used as well as concrete
    // types.
    #[cfg(feature = "client_ip")]
    impl<S: ClientIpConfig + ?Sized> ClientIpConfig for Arc<S> {
        fn csrf_secret_key(&self) -> &[u8] {
            S::csrf_secret_key(self)
        }

        fn is_trusted_forwarder(&self, addr: IpAddr) -> bool {
            S::is_trusted_forwarder(self, addr)
        }
    }

    /// Middleware that derives the expected CSRF token from the client address.
    ///
    /// The token is the lowercase hex SHA-256 digest of
    /// `"csrftoken " || csrf_secret_key || " " || client address`. The derived
    /// token is inserted into the request extensions for the `CsrfToken`
    /// extractor.
    ///
    /// Because the address is part of the hash, a client whose address
    /// changes between rendering a form and submitting it will fail
    /// validation.
    ///
    /// The signature (`State`, an extractor, `Request`, `Next`) is the shape
    /// expected by `axum::middleware::from_fn_with_state`.
    #[cfg(feature = "client_ip")]
    pub async fn client_ip<S: ClientIpConfig>(
        state: State<S>,
        client_addr: ClientAddress,
        mut request: Request,
        next: Next,
    ) -> impl IntoResponse {
        use sha2::{Digest, Sha256};

        // The fixed prefix tags this hash as a CSRF token. The separator keeps
        // the secret and the address apart. Changing any of these bytes
        // changes every issued token.
        let mut hash = Sha256::new_with_prefix("csrftoken ");
        hash.update(state.csrf_secret_key());
        hash.update(" ");
        hash.update(client_addr.address.to_string());

        request.extensions_mut().insert(CsrfToken {
            expected_csrf_token: format!("{:x}", hash.finalize()),
        });
        next.run(request).await
    }

    /// Middleware that keeps a random CSRF token in a cookie.
    ///
    /// If the request carries a non-empty `CRISSY_CSRF_TOKEN` cookie, its
    /// value becomes the expected token. Otherwise a fresh random token is
    /// generated and returned to the client through the cookie jar. The new
    /// cookie has `Path=/`, `HttpOnly` and `Secure` set and is made permanent.
    ///
    /// In both cases the expected token is inserted into the request
    /// extensions for the `CsrfToken` extractor.
    #[cfg(feature = "cookie")]
    pub async fn cookie(
        mut cookie_jar: CookieJar,
        mut request: Request,
        next: Next,
    ) -> impl IntoResponse {
        use crate::random_csrf_token;
        const COOKIE_NAME: &str = "CRISSY_CSRF_TOKEN";

        // An empty cookie was never issued by this middleware, because
        // generated tokens always contain at least one hex digit. Accepting it
        // would make the expected token `""`, which an empty form field would
        // satisfy. Treat it as missing and issue a fresh token instead.
        let existing = cookie_jar
            .get(COOKIE_NAME)
            .map(|cookie| cookie.value().to_owned())
            .filter(|value| !value.is_empty());

        let csrf_token = match existing {
            Some(token) => token,
            None => {
                let token = random_csrf_token();
                // Adding to the jar (which replaces any same-named cookie)
                // makes the response carry the new cookie.
                cookie_jar = cookie_jar.add(
                    Cookie::build((COOKIE_NAME, token.clone()))
                        .permanent()
                        .path("/")
                        .http_only(true)
                        .secure(true)
                        .build(),
                );
                token
            }
        };

        request.extensions_mut().insert(CsrfToken {
            expected_csrf_token: csrf_token,
        });
        // Returning the jar alongside the response emits any cookie added above.
        (cookie_jar, next.run(request).await)
    }
}
/// The CSRF token expected for the current request.
///
/// One of the middleware in `middleware` inserts this into the request
/// extensions. Handlers obtain it through its `FromRequestParts` impl and
/// check submitted forms with `CsrfToken::validate`.
#[derive(Clone)]
pub struct CsrfToken {
    /// The token a submitted form must echo back to pass validation.
    pub expected_csrf_token: String,
}

// Hand-written rather than derived: the token is a secret, so `{:?}` output
// (e.g. in logs or panic messages) must not include its value.
impl std::fmt::Debug for CsrfToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CsrfToken").finish_non_exhaustive()
    }
}
impl<S> FromRequestParts<S> for CsrfToken {
type Rejection = Error;
fn from_request_parts(
parts: &mut axum::http::request::Parts,
_state: &S,
) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
std::future::ready(
parts
.extensions
.get::<CsrfToken>()
.cloned()
.context(NoMiddlewareSnafu),
)
}
}
impl CsrfToken {
    /// Checks a token submitted with a form against the expected token.
    ///
    /// The comparison takes the same time wherever the first mismatching byte
    /// is. This keeps the timing of a failed check from revealing how much of
    /// the secret an attacker-supplied value matched. Only the length
    /// comparison can short-circuit.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidCsrf`] when `form_csrf_token` differs from
    /// `expected_csrf_token`.
    pub fn validate(&self, form_csrf_token: &str) -> Result<(), Error> {
        if !Self::constant_time_eq(
            form_csrf_token.as_bytes(),
            self.expected_csrf_token.as_bytes(),
        ) {
            // NOTE(review): this logs the expected (secret) token at debug
            // level; consider redacting it if debug logs leave the host.
            tracing::debug!(
                csrf.session = self.expected_csrf_token,
                csrf.form = form_csrf_token,
                "invalid CSRF token"
            );
            return InvalidCsrfSnafu.fail();
        }
        Ok(())
    }

    /// Compares two byte strings for equality.
    ///
    /// For equal-length inputs, every byte is examined and no early exit
    /// depends on the contents. Lengths are compared up front.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
        // Best-effort hint that discourages the optimizer from turning the
        // fold back into an early-exit comparison.
        std::hint::black_box(diff) == 0
    }
}
/// Generates a fresh random token for the `cookie` middleware.
///
/// The token is a random `u128` rendered as exactly 32 lowercase hex digits.
/// Zero-padding gives every token the same length. Previously, leading zeros
/// were dropped, so the length revealed the leading zero digits of the random
/// value. Existing unpadded cookie tokens still validate, because validation
/// compares the cookie value verbatim.
#[cfg(feature = "cookie")]
fn random_csrf_token() -> String {
    format!("{:032x}", rand::random::<u128>())
}