pub mod extractor;
use std::default::Default;
use std::error::Error;
use std::fmt::Display;
use std::future::{self, Future, Ready};
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use extractor::CsrfToken;
use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
use actix_web::http::header::{self, HeaderValue};
use actix_web::{HttpMessage, HttpResponse, ResponseError};
use extractor::CsrfCookieConfig;
use cookie::Cookie;
use cookie::SameSite;
// Expands to the default CSRF token / cookie name as a string literal.
macro_rules! token_name {
    () => ("Csrf-Token");
}
// Expands to the `__Host-` cookie-name prefix as a string literal.
#[doc(hidden)]
#[macro_export]
macro_rules! host_prefix {
    () => ("__Host-");
}
// Expands to the `__Secure-` cookie-name prefix as a string literal.
#[doc(hidden)]
#[macro_export]
macro_rules! secure_prefix {
    () => ("__Secure-");
}
const DEFAULT_CSRF_TOKEN_NAME: &str = token_name!();
const DEFAULT_CSRF_COOKIE_NAME: &str = token_name!();
/// Reasons a request can fail CSRF validation.
///
/// The declaration order of the variants is significant: it defines the
/// derived `PartialOrd`/`Ord` ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum CsrfError {
    /// The CSRF tokens being compared do not match.
    TokenMismatch,
    /// The CSRF cookie is absent from the request.
    MissingCookie,
    /// The CSRF token is absent (reported as a missing header).
    MissingToken,
}
impl Display for CsrfError {
    /// Writes a short, human-readable description of the validation failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::TokenMismatch => "The CSRF Tokens do not match",
            Self::MissingCookie => "The CSRF Cookie is missing",
            Self::MissingToken => "The CSRF Header is missing",
        };
        f.write_str(message)
    }
}
impl ResponseError for CsrfError {
fn error_response(&self) -> HttpResponse {
log::warn!("Potential CSRF attack: {}", self);
HttpResponse::UnprocessableEntity().finish()
}
}
impl Error for CsrfError {}
/// Middleware that attaches a CSRF token cookie to each response produced by
/// the wrapped service.
///
/// Construct it with [`CsrfMiddleware::new`] (or `Default`) and adjust the cookie
/// via the builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsrfMiddleware {
    /// Cookie configuration, copied into every service instance this wraps.
    inner: Inner,
}
impl CsrfMiddleware {
    /// Creates a middleware with the default cookie configuration.
    ///
    /// Equivalent to `CsrfMiddleware::default()`.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
impl CsrfMiddleware {
    /// Sets the CSRF cookie name exactly as given, without adding any prefix.
    #[must_use]
    pub fn cookie_name<T: Into<String>>(mut self, name: T) -> Self {
        self.inner.cookie_name = Rc::new(name.into());
        self
    }

    /// Sets the CSRF cookie name to `name` prefixed with `__Host-`.
    ///
    /// A `__Host-` cookie must not carry a `Domain` attribute. If a domain has
    /// already been configured with [`Self::domain`], the `__Secure-` prefix is
    /// used instead. That is the same rewrite [`Self::domain`] applies when it is
    /// called after this method, so the outcome does not depend on call order.
    #[must_use]
    pub fn host_prefixed_cookie_name<T: AsRef<str>>(mut self, name: T) -> Self {
        let prefix = if self.inner.domain.is_some() {
            secure_prefix!()
        } else {
            host_prefix!()
        };
        let mut prefixed = prefix.to_owned();
        prefixed.push_str(name.as_ref());
        self.inner.cookie_name = Rc::new(prefixed);
        self
    }

    /// Sets the CSRF cookie name to `name` prefixed with `__Secure-`.
    #[must_use]
    pub fn secure_prefixed_cookie_name<T: AsRef<str>>(mut self, name: T) -> Self {
        let mut prefixed = secure_prefix!().to_owned();
        prefixed.push_str(name.as_ref());
        self.inner.cookie_name = Rc::new(prefixed);
        self
    }

    /// Sets the cookie's `SameSite` attribute. `None` omits the attribute
    /// (the default).
    #[must_use]
    pub const fn same_site(mut self, same_site: Option<SameSite>) -> Self {
        self.inner.same_site = same_site;
        self
    }

    /// Enables or disables the cookie's `HttpOnly` attribute (disabled by default).
    #[must_use]
    pub const fn http_only(mut self, enabled: bool) -> Self {
        self.inner.http_only = enabled;
        self
    }

    /// Enables or disables the cookie's `Secure` attribute (enabled by default).
    ///
    /// Browsers reject `__Host-`/`__Secure-` prefixed cookies that are not
    /// `Secure`, so only disable this for unprefixed cookie names.
    #[must_use]
    pub const fn secure(mut self, enabled: bool) -> Self {
        self.inner.secure = enabled;
        self
    }

    /// Sets the cookie's `Domain` attribute, or clears it when given `None`.
    ///
    /// `__Host-` cookies must not have a `Domain` attribute. Setting a domain
    /// while the cookie name carries the `__Host-` prefix therefore rewrites that
    /// prefix to `__Secure-`. Clearing the domain leaves the name unchanged.
    #[must_use]
    pub fn domain<S: Into<String>>(mut self, domain: impl Into<Option<S>>) -> Self {
        let domain: Option<String> = domain.into().map(Into::into);
        // Only a *present* domain conflicts with `__Host-`; `None` must not
        // downgrade an otherwise valid `__Host-` cookie.
        if domain.is_some() {
            if let Some(stripped) = self.inner.cookie_name.strip_prefix(host_prefix!()) {
                self.inner.cookie_name =
                    Rc::new(format!(concat!(secure_prefix!(), "{}"), stripped));
            }
        }
        self.inner.domain = domain;
        self
    }

    /// Returns a [`CsrfCookieConfig`] built from the current cookie name.
    ///
    /// The name is copied at call time. Call this after all cookie-name builder
    /// methods so the config matches the cookie the middleware actually sets.
    #[must_use]
    pub fn cookie_config(&self) -> CsrfCookieConfig {
        CsrfCookieConfig::new((*self.inner.cookie_name).clone())
    }
}
impl Default for CsrfMiddleware {
fn default() -> Self {
Self {
inner: Inner::default(),
}
.cookie_name(DEFAULT_CSRF_COOKIE_NAME.to_string())
}
}
/// Service wrapper produced by [`CsrfMiddleware`]'s `Transform` implementation.
#[doc(hidden)]
pub struct CsrfMiddlewareImpl<S> {
    /// The downstream service every request is forwarded to.
    service: S,
    /// Cookie configuration copied from the originating [`CsrfMiddleware`].
    inner: Inner,
}
impl<S> CsrfMiddlewareImpl<S> {
fn get_cookie(&self, token: String, path: String) -> HeaderValue {
let mut cookie = Cookie::new(self.inner.cookie_name.as_ref(), token);
cookie.set_http_only(self.inner.http_only);
cookie.set_secure(self.inner.secure);
cookie.set_path(path);
if let Some(same_site) = self.inner.same_site {
cookie.set_same_site(same_site);
}
if let Some(ref domain) = &self.inner.domain {
cookie.set_domain(domain.clone());
}
HeaderValue::from_str(&cookie.to_string()).expect("cookie to be a valid header value")
}
}
/// Cookie configuration shared by [`CsrfMiddleware`] and its service wrapper.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Inner {
    /// Name of the CSRF cookie, held in an `Rc` so cloning the config does not
    /// copy the string.
    cookie_name: Rc<String>,
    /// Whether the cookie is marked `HttpOnly`.
    http_only: bool,
    /// `SameSite` attribute; `None` omits it.
    same_site: Option<SameSite>,
    /// Whether the cookie is marked `Secure`.
    secure: bool,
    /// `Domain` attribute; `None` omits it.
    domain: Option<String>,
}
impl Default for Inner {
fn default() -> Self {
Self::with_rng()
}
}
impl Inner {
    /// Returns the default cookie configuration: name `Csrf-Token`, `Secure`
    /// set, `HttpOnly` unset, and no `SameSite` or `Domain` attribute.
    ///
    /// NOTE(review): despite its name, this constructor involves no randomness.
    /// The name is kept because other code in the crate may call it.
    fn with_rng() -> Self {
        let cookie_name = Rc::new(String::from(DEFAULT_CSRF_COOKIE_NAME));
        Self {
            cookie_name,
            http_only: false,
            same_site: None,
            secure: true,
            domain: None,
        }
    }
}
impl<S, B> Transform<S, ServiceRequest> for CsrfMiddleware
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>>,
    S::Future: 'static,
    B: 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type InitError = ();
    type Transform = CsrfMiddlewareImpl<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `service` in a [`CsrfMiddlewareImpl`] holding its own copy of the
    /// cookie configuration. Construction cannot fail and completes immediately.
    fn new_transform(&self, service: S) -> Self::Future {
        let inner = self.inner.clone();
        let wrapped = CsrfMiddlewareImpl { service, inner };
        future::ready(Ok(wrapped))
    }
}
impl<S, B> Service<ServiceRequest> for CsrfMiddlewareImpl<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>>,
    S::Future: 'static,
    B: 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    /// Readiness is delegated unchanged to the wrapped service.
    fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(ctx)
    }

    /// Generates a fresh token for this request and forwards the request downstream.
    ///
    /// The token is stored in the request extensions as a `CsrfToken`. Once the
    /// wrapped service succeeds, the same token is appended to the response as a
    /// `Set-Cookie` header with path `/`. If the wrapped service returns an error,
    /// that error is propagated and no cookie is added.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        // NOTE(review): a UUIDv7 embeds a millisecond timestamp, so only part of
        // the token is random. A v4 UUID (or raw CSPRNG bytes) would maximise
        // unpredictability; confirm the enabled `uuid` features before switching.
        let token = uuid::Uuid::now_v7().to_string();
        let cookie = self.get_cookie(token.clone(), "/".to_string());
        // The token String moves into the extensions, so no second clone is needed.
        req.extensions_mut().insert(CsrfToken(token));
        let fut = self.service.call(req);
        Box::pin(async move {
            let mut res = fut.await?;
            res.headers_mut().append(header::SET_COOKIE, cookie);
            Ok(res)
        })
    }
}