#![deny(unsafe_code)]
#![warn(clippy::pedantic, clippy::nursery, clippy::cargo, missing_docs)]
use std::cell::RefCell;
use std::collections::HashSet;
use std::default::Default;
use std::error::Error;
use std::fmt::Display;
use std::future::{self, Future, Ready};
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use crate::extractor::CsrfToken;
use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
use actix_web::error::InternalError;
use actix_web::http::header::{self, HeaderValue};
use actix_web::http::{Method, StatusCode};
use actix_web::{HttpMessage, HttpResponse, ResponseError};
use cookie::{Cookie, SameSite};
use extractor::CsrfCookieConfig;
use rand::SeedableRng;
use tracing::{error, warn};
pub mod extractor;
mod token_rng;
pub use crate::token_rng::TokenRng;
/// Expands to the base token name, `"Csrf-Token"`, as a string literal.
///
/// Used directly for `DEFAULT_CSRF_TOKEN_NAME` and, behind the `__Host-` prefix, for
/// `DEFAULT_CSRF_COOKIE_NAME`. It is a macro rather than a `const` so it can appear
/// inside `concat!`.
macro_rules! token_name {
    () => ("Csrf-Token");
}
/// Expands to the `__Host-` cookie-name prefix as a string literal (usable in `concat!`).
///
/// Browsers only accept `__Host-` cookies that are `Secure`, have `Path=/`, and carry no
/// `Domain` attribute.
#[macro_export]
#[doc(hidden)]
macro_rules! host_prefix {
    () => ("__Host-");
}
/// Expands to the `__Secure-` cookie-name prefix as a string literal (usable in `concat!`).
///
/// Browsers only accept `__Secure-` cookies that carry the `Secure` attribute; unlike
/// `__Host-`, a `Domain` attribute is permitted.
#[macro_export]
#[doc(hidden)]
macro_rules! secure_prefix {
    () => ("__Secure-");
}
/// Default CSRF token name, `"Csrf-Token"`. The tests in this file submit the token
/// under a request header with this name.
const DEFAULT_CSRF_TOKEN_NAME: &str = token_name!();
/// Default CSRF cookie name, `"__Host-Csrf-Token"` (the token name behind the `__Host-`
/// prefix).
const DEFAULT_CSRF_COOKIE_NAME: &str = concat!(host_prefix!(), token_name!());
/// Reasons a request can fail CSRF validation.
///
/// When turned into an HTTP response, every variant produces `422 Unprocessable Entity`
/// with an empty body and logs a warning.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
pub enum CsrfError {
    /// The submitted CSRF token does not match the token stored in the CSRF cookie.
    TokenMismatch,
    /// The request did not carry the CSRF cookie.
    MissingCookie,
    /// The request did not carry a CSRF token (its `Display` text refers to a missing header).
    MissingToken,
}
impl Display for CsrfError {
    /// Writes a short, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::TokenMismatch => "The CSRF Tokens do not match",
            Self::MissingCookie => "The CSRF Cookie is missing",
            Self::MissingToken => "The CSRF Header is missing",
        };
        f.write_str(message)
    }
}
impl ResponseError for CsrfError {
fn error_response(&self) -> HttpResponse {
warn!("Potential CSRF attack: {}", self);
HttpResponse::UnprocessableEntity().finish()
}
}
// Marker impl: the message comes from the `Display` impl, and there is no underlying
// source error to expose.
impl Error for CsrfError {}
/// Actix-web middleware that issues CSRF tokens for the double-submit cookie pattern.
///
/// Construct it with `new`, `default`, or `with_rng`, adjust it with the builder
/// methods, and register it with `App::wrap`. For every `(method, route)` registered
/// through `set_cookie`, the middleware generates a fresh token, stores it in the request
/// extensions as a `CsrfToken`, and sends it back in a `Set-Cookie` header.
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct CsrfMiddleware<Rng> {
    inner: Inner<Rng>,
}
impl<Rng: TokenRng + SeedableRng> CsrfMiddleware<Rng> {
    /// Creates a middleware with the default configuration and an RNG seeded via
    /// `SeedableRng::from_entropy`.
    ///
    /// This is the same as calling [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
impl<Rng: TokenRng> CsrfMiddleware<Rng> {
    /// Creates a middleware with the default configuration that draws its tokens from
    /// the caller-supplied `rng`.
    #[must_use]
    pub fn with_rng(rng: Rng) -> Self {
        let inner = Inner::with_rng(rng);
        Self { inner }
    }
}
impl<Rng> CsrfMiddleware<Rng> {
    /// Enables or disables CSRF token issuing.
    ///
    /// While disabled, the middleware never generates a token or sets the CSRF cookie,
    /// even for routes registered with [`CsrfMiddleware::set_cookie`]. Enabled by default.
    #[must_use]
    pub const fn enabled(mut self, enabled: bool) -> Self {
        self.inner.csrf_enabled = enabled;
        self
    }

    /// Registers a `(method, uri)` pair whose requests receive a freshly generated token.
    ///
    /// For a matching request the middleware stores the token in the request extensions
    /// as a `CsrfToken` and adds it to the response as a `Set-Cookie` header.
    ///
    /// `uri` is compared with the matched route pattern (e.g. `"/{id}"`) when the request
    /// was routed to a resource, and with the literal request path otherwise.
    #[must_use]
    pub fn set_cookie<T: Into<String>>(mut self, method: Method, uri: T) -> Self {
        self.inner.set_cookie.insert((method, uri.into()));
        self
    }

    /// Sets the CSRF cookie name exactly as given, without adding any prefix.
    ///
    /// Defaults to `__Host-Csrf-Token`.
    #[must_use]
    pub fn cookie_name<T: Into<String>>(mut self, name: T) -> Self {
        self.inner.cookie_name = Rc::new(name.into());
        self
    }

    /// Sets the CSRF cookie name to `__Host-` followed by `name`.
    ///
    /// Browsers only accept `__Host-` cookies that are `Secure`, use `Path=/`, and have no
    /// `Domain` attribute. The middleware always issues the cookie with `Path=/`.
    ///
    /// NOTE: calling this *after* [`CsrfMiddleware::domain`] keeps the `__Host-` prefix
    /// even though a domain is set, which browsers will reject.
    #[must_use]
    pub fn host_prefixed_cookie_name<T: AsRef<str>>(mut self, name: T) -> Self {
        let mut prefixed = host_prefix!().to_owned();
        prefixed.push_str(name.as_ref());
        self.inner.cookie_name = Rc::new(prefixed);
        self
    }

    /// Sets the CSRF cookie name to `__Secure-` followed by `name`.
    ///
    /// Browsers only accept `__Secure-` cookies that carry the `Secure` attribute.
    #[must_use]
    pub fn secure_prefixed_cookie_name<T: AsRef<str>>(mut self, name: T) -> Self {
        let mut prefixed = secure_prefix!().to_owned();
        prefixed.push_str(name.as_ref());
        self.inner.cookie_name = Rc::new(prefixed);
        self
    }

    /// Sets the cookie's `SameSite` attribute; `None` omits the attribute entirely.
    ///
    /// Defaults to `Some(SameSite::Strict)`.
    #[must_use]
    pub const fn same_site(mut self, same_site: Option<SameSite>) -> Self {
        self.inner.same_site = same_site;
        self
    }

    /// Sets the cookie's `HttpOnly` attribute. Defaults to `true`.
    #[must_use]
    pub const fn http_only(mut self, enabled: bool) -> Self {
        self.inner.http_only = enabled;
        self
    }

    /// Sets the cookie's `Secure` attribute. Defaults to `true`.
    #[must_use]
    pub const fn secure(mut self, enabled: bool) -> Self {
        self.inner.secure = enabled;
        self
    }

    /// Sets (or, with `None`, clears) the cookie's `Domain` attribute.
    ///
    /// `__Host-` cookies must not carry a `Domain` attribute. So when a domain is set and
    /// the current cookie name starts with `__Host-`, that prefix is replaced with
    /// `__Secure-`. Clearing the domain leaves the name untouched, because a `__Host-`
    /// cookie without a domain is valid.
    #[must_use]
    pub fn domain<S: Into<String>>(mut self, domain: impl Into<Option<S>>) -> Self {
        self.inner.domain = domain.into().map(Into::into);
        if self.inner.domain.is_some() {
            if let Some(stripped) = self.inner.cookie_name.strip_prefix(host_prefix!()) {
                self.inner.cookie_name =
                    Rc::new(format!(concat!(secure_prefix!(), "{}"), stripped));
            }
        }
        self
    }

    /// Returns a `CsrfCookieConfig` carrying the currently configured cookie name.
    ///
    /// NOTE(review): presumably registered as app data so the extractors in the
    /// `extractor` module look up the same cookie name — confirm there.
    #[must_use]
    pub fn cookie_config(&self) -> CsrfCookieConfig {
        CsrfCookieConfig::new((*self.inner.cookie_name).clone())
    }
}
impl<Rng: TokenRng + SeedableRng> Default for CsrfMiddleware<Rng> {
    /// Builds the default configuration with an RNG seeded via `SeedableRng::from_entropy`.
    ///
    /// The cookie name is `DEFAULT_CSRF_COOKIE_NAME`, which `Inner::with_rng` already
    /// assigns, so no separate override is needed here.
    fn default() -> Self {
        let inner = Inner::default();
        Self { inner }
    }
}
impl<S, Rng> Transform<S, ServiceRequest> for CsrfMiddleware<Rng>
where
    S: Service<ServiceRequest, Response = ServiceResponse>,
    Rng: TokenRng + Clone,
{
    type Response = ServiceResponse;
    type Error = S::Error;
    type InitError = ();
    type Transform = CsrfMiddlewareImpl<S, Rng>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `service` in a `CsrfMiddlewareImpl` that owns a copy of this configuration.
    ///
    /// Construction cannot fail, so the future is immediately ready.
    ///
    /// NOTE(review): cloning `Inner` clones the RNG state too (`Rng: Clone`). If this is
    /// called more than once on the same `CsrfMiddleware`, each copy starts from identical
    /// state and yields the same token sequence — confirm this is acceptable.
    fn new_transform(&self, service: S) -> Self::Future {
        let inner = self.inner.clone();
        let middleware = CsrfMiddlewareImpl { service, inner };
        future::ready(Ok(middleware))
    }
}
/// Service produced by `CsrfMiddleware::new_transform`.
///
/// It wraps the downstream `service` and holds its own copy of the middleware
/// configuration, including the token RNG.
#[doc(hidden)]
pub struct CsrfMiddlewareImpl<S, Rng> {
    service: S,
    inner: Inner<Rng>,
}
/// Configuration shared by `CsrfMiddleware` and the services it creates.
#[derive(Clone, Eq, PartialEq, Debug)]
struct Inner<Rng> {
    /// Token source. Wrapped in `RefCell` because `Service::call` only has `&self` but
    /// must mutably borrow the RNG to generate a token.
    rng: RefCell<Rng>,
    /// Name of the issued CSRF cookie. Kept in an `Rc` so cloning `Inner` shares the string.
    cookie_name: Rc<String>,
    /// Value of the cookie's `HttpOnly` attribute.
    http_only: bool,
    /// The cookie's `SameSite` attribute; `None` omits it.
    same_site: Option<SameSite>,
    /// Value of the cookie's `Secure` attribute.
    secure: bool,
    /// The cookie's `Domain` attribute; `None` omits it.
    domain: Option<String>,
    /// When `false`, `call` never generates a token or sets the cookie.
    csrf_enabled: bool,
    /// `(method, route pattern or path)` pairs whose requests receive a new token.
    set_cookie: HashSet<(Method, String)>,
}
impl<Rng: TokenRng + SeedableRng> Default for Inner<Rng> {
    /// Builds the default configuration around an RNG seeded via
    /// `SeedableRng::from_entropy`.
    fn default() -> Self {
        let rng = Rng::from_entropy();
        Self::with_rng(rng)
    }
}
impl<Rng: TokenRng> Inner<Rng> {
    /// Creates the default configuration around `rng`.
    ///
    /// Defaults:
    /// - cookie name `DEFAULT_CSRF_COOKIE_NAME`;
    /// - middleware enabled;
    /// - cookie attributes `HttpOnly`, `Secure`, `SameSite=Strict`, no `Domain`;
    /// - no routes registered.
    fn with_rng(rng: Rng) -> Self {
        Self {
            csrf_enabled: true,
            cookie_name: Rc::new(DEFAULT_CSRF_COOKIE_NAME.to_owned()),
            domain: None,
            http_only: true,
            rng: RefCell::new(rng),
            same_site: Some(SameSite::Strict),
            secure: true,
            set_cookie: HashSet::new(),
        }
    }

    /// Returns whether `req` matches one of the registered `(method, route)` pairs.
    fn contains(&self, req: &ServiceRequest) -> bool {
        // Prefer the matched route pattern (e.g. `/{id}`) so registrations can use
        // patterns; fall back to the literal path when no resource matched.
        let route = req
            .match_pattern()
            .unwrap_or_else(|| req.path().to_string());
        let key = (req.method().clone(), route);
        self.set_cookie.contains(&key)
    }
}
impl<S, Rng> Service<ServiceRequest> for CsrfMiddlewareImpl<S, Rng>
where
S: Service<ServiceRequest, Response = ServiceResponse>,
Rng: TokenRng,
{
type Response = ServiceResponse;
type Error = S::Error;
type Future = CsrfMiddlewareImplFuture<S>;
fn call(&self, req: ServiceRequest) -> Self::Future {
let cookie = if self.inner.csrf_enabled && self.inner.contains(&req) {
let token =
match self.inner.rng.borrow_mut().generate_token() {
Ok(token) => token,
Err(e) => {
error!("Failed to generate CSRF token, aborting request");
return CsrfMiddlewareImplFuture::CsrfError(req.error_response(
InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR),
));
}
};
let cookie = {
let mut cookie_builder =
Cookie::build(self.inner.cookie_name.as_ref(), token.clone())
.http_only(self.inner.http_only)
.secure(self.inner.secure)
.path("/");
if let Some(same_site) = self.inner.same_site {
cookie_builder = cookie_builder.same_site(same_site);
}
if let Some(domain) = &self.inner.domain {
cookie_builder = cookie_builder.domain(domain);
}
cookie_builder.finish()
};
let csrf_token = CsrfToken(token);
req.extensions_mut().insert(csrf_token);
let header = HeaderValue::from_str(&cookie.to_string())
.expect("cookie to be a valid header value");
Some(header)
} else {
None
};
CsrfMiddlewareImplFuture::Passthrough(Passthrough {
cookie,
service: Box::pin(self.service.call(req)),
})
}
fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(ctx)
}
}
/// Future returned by `CsrfMiddlewareImpl::call`.
#[doc(hidden)]
#[derive(Debug)]
pub enum CsrfMiddlewareImplFuture<S: Service<ServiceRequest>> {
    /// A ready-made error response; the inner service was never called.
    CsrfError(ServiceResponse),
    /// The inner service's future, plus the `Set-Cookie` value (if any) to add to its response.
    Passthrough(Passthrough<S::Future>),
}
impl<S> Future for CsrfMiddlewareImplFuture<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse>,
{
    type Output = Result<ServiceResponse, S::Error>;

    /// Resolves to the stored error response, or to the inner service's response with the
    /// CSRF cookie appended.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.get_mut() {
            Self::CsrfError(error) => {
                // The response can only be moved out through `&mut`, so leave an empty
                // placeholder behind; it is never observed because the future is done.
                let placeholder = ServiceResponse::new(
                    error.request().clone(),
                    HttpResponse::NoContent().finish(),
                );
                Poll::Ready(Ok(std::mem::replace(error, placeholder)))
            }
            Self::Passthrough(inner) => match inner.service.as_mut().poll(cx) {
                Poll::Ready(Ok(mut res)) => {
                    if let Some(cookie) = inner.cookie.take() {
                        // `append`, not `insert`: `insert` would discard any `Set-Cookie`
                        // headers the handler already set (e.g. session cookies).
                        res.response_mut()
                            .headers_mut()
                            .append(header::SET_COOKIE, cookie);
                    }
                    Poll::Ready(Ok(res))
                }
                other => other,
            },
        }
    }
}
/// State of a request that was forwarded to the inner service.
#[doc(hidden)]
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
pub struct Passthrough<Fut> {
    /// `Set-Cookie` value to add to the inner response, if a token was issued.
    cookie: Option<HeaderValue>,
    /// The inner service's pending response.
    service: Pin<Box<Fut>>,
}
#[cfg(test)]
mod tests {
    use crate::extractor::{Csrf, CsrfHeader};
    use super::*;
    use actix_web::http::StatusCode;
    use actix_web::test::{self, TestRequest};
    use actix_web::{post, web, App, HttpResponse, Responder};
    use rand::rngs::StdRng;

    /// Yields the raw value of every `Set-Cookie` header on `resp`.
    fn set_cookie_headers(resp: &ServiceResponse) -> impl Iterator<Item = &str> {
        resp.headers()
            .iter()
            .filter(|(header_name, _)| header_name.as_str() == "set-cookie")
            .map(|(_, value)| value.to_str().expect("header to be valid string"))
    }

    /// Returns the `name=value` pair of the single `Set-Cookie` header on `resp`.
    fn get_cookie_from_resp(resp: &ServiceResponse) -> String {
        let cookies: Vec<&str> = set_cookie_headers(resp)
            .map(|v| v.split(';').next().expect("split to work"))
            .collect();
        assert_eq!(1, cookies.len());
        cookies[0].to_owned()
    }

    /// Returns the CSRF cookie's value: everything after the first `=`.
    fn get_token_from_resp(resp: &ServiceResponse) -> String {
        let cookie = get_cookie_from_resp(resp);
        // `split_once` keeps any `=` inside the value (e.g. base64 padding) intact.
        let (_, token) = cookie
            .split_once('=')
            .expect("cookie to be a name=value pair");
        token.to_owned()
    }

    /// Returns the `Domain` attribute of the first `Set-Cookie` header that has one.
    fn get_cookie_domain_from_resp(resp: &ServiceResponse) -> String {
        set_cookie_headers(resp)
            .flat_map(|v| v.split(';'))
            .find_map(|attr| attr.trim().strip_prefix("Domain="))
            .expect("header to have cookie")
            .to_owned()
    }

    #[tokio::test]
    async fn attaches_token() {
        let mut srv = test::init_service(
            App::new()
                .wrap(CsrfMiddleware::<StdRng>::new().set_cookie(Method::GET, "/"))
                .service(web::resource("/").to(|| HttpResponse::Ok())),
        )
        .await;
        let resp = test::call_service(&mut srv, TestRequest::with_uri("/").to_request()).await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(get_cookie_from_resp(&resp).contains(DEFAULT_CSRF_COOKIE_NAME));
    }

    #[tokio::test]
    async fn post_request_rejected_without_header() {
        #[post("/")]
        async fn test_route(_: Csrf<CsrfHeader>) -> impl Responder {
            HttpResponse::Ok()
        }
        let mut srv = test::init_service(
            App::new()
                .wrap(CsrfMiddleware::<StdRng>::new())
                .service(test_route),
        )
        .await;
        let resp = test::call_service(&mut srv, TestRequest::post().uri("/").to_request()).await;
        assert_eq!(resp.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[tokio::test]
    async fn double_submit_correct_token() {
        let mut srv = test::init_service(
            App::new()
                .wrap(CsrfMiddleware::<StdRng>::new().set_cookie(Method::GET, "/"))
                .service(
                    web::resource("/")
                        .route(web::get().to(|| HttpResponse::Ok()))
                        .route(web::post().to(|| HttpResponse::Ok())),
                ),
        )
        .await;
        let resp = test::call_service(&mut srv, TestRequest::with_uri("/").to_request()).await;
        let token = get_token_from_resp(&resp);
        let cookie = get_cookie_from_resp(&resp);
        let req = TestRequest::post()
            .uri("/")
            .insert_header(("Cookie", cookie))
            .insert_header((DEFAULT_CSRF_TOKEN_NAME, token))
            .to_request();
        let resp = test::call_service(&mut srv, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn domain_attribute_is_set() {
        let mut srv = test::init_service(
            App::new()
                .wrap(
                    CsrfMiddleware::<StdRng>::new()
                        .set_cookie(Method::GET, "/")
                        .domain("example.com"),
                )
                .service(web::resource("/").to(|| HttpResponse::Ok())),
        )
        .await;
        let resp = test::call_service(&mut srv, TestRequest::with_uri("/").to_request()).await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(get_cookie_domain_from_resp(&resp), "example.com");
    }

    #[tokio::test]
    async fn path_info_is_set() {
        let mut srv = test::init_service(
            App::new()
                .wrap(CsrfMiddleware::<StdRng>::new().set_cookie(Method::GET, "/{id}"))
                .service(
                    web::resource("/{id}")
                        .route(web::get().to(|| HttpResponse::Ok()))
                        .route(web::post().to(|| HttpResponse::Ok())),
                ),
        )
        .await;
        let resp = test::call_service(&mut srv, TestRequest::with_uri("/1").to_request()).await;
        let token = get_token_from_resp(&resp);
        let cookie = get_cookie_from_resp(&resp);
        let req = TestRequest::post()
            .uri("/1")
            .insert_header(("Cookie", cookie))
            .insert_header((DEFAULT_CSRF_TOKEN_NAME, token))
            .to_request();
        let resp = test::call_service(&mut srv, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }
}