use actix_http::{header::HeaderMap, StatusCode};
#[cfg(feature = "actix-session")]
use actix_session::SessionExt;
use actix_utils::future::Either;
use actix_web::{
body::{EitherBody, MessageBody},
cookie::{time, Cookie, SameSite},
dev::forward_ready,
dev::{Service, ServiceRequest, ServiceResponse, Transform},
http::{header, Method},
web::BytesMut,
Error, FromRequest, HttpMessage, HttpRequest, HttpResponse, HttpResponseBuilder, ResponseError,
};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use futures_util::{
future::{err, ok, Ready},
ready,
stream::StreamExt,
};
use hmac::{Hmac, KeyInit, Mac};
use log::{error, warn};
use pin_project_lite::pin_project;
use rand::Rng;
use sha2::Sha256;
use std::{
collections::HashMap,
error, fmt,
future::Future,
marker::PhantomData,
pin::Pin,
rc::Rc,
task::{Context, Poll},
};
use subtle::ConstantTimeEq;
use url::Url;
/// Default name of the token cookie (double-submit) or session key (synchronizer)
/// used when the request carries a session id cookie.
pub const DEFAULT_CSRF_TOKEN_KEY: &str = "CSRF";
/// Default name of the double-submit cookie holding the anonymous (pre-session) token.
pub const DEFAULT_CSRF_ANON_TOKEN_KEY: &str = "CSRF-ANON";
/// Default JSON / urlencoded body field searched for the submitted token.
pub const DEFAULT_CSRF_TOKEN_FIELD: &str = "csrf_token";
/// Default request header searched (before the body) for the submitted token.
pub const DEFAULT_CSRF_TOKEN_HEADER: &str = "X-CSRF-Token";
/// Default name of the cookie whose presence marks a request as carrying a session.
pub const DEFAULT_SESSION_ID_KEY: &str = "id";
/// Name of the signed cookie carrying the anonymous pre-session id.
pub const CSRF_PRE_SESSION_KEY: &str = "pre-session";

// Attributes shared by every pre-session cookie this module writes or expires.
const PRE_SESSION_HTTP_ONLY: bool = true;
const PRE_SESSION_SAME_SITE: SameSite = SameSite::Strict;

// Number of random bytes in a token before base64 encoding.
const TOKEN_LEN: usize = 32;

type HmacSha256 = Hmac<Sha256>;
/// Which identity a CSRF token is bound to.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TokenClass {
    /// Bound to the signed pre-session id (no session id cookie present).
    Anonymous,
    /// Bound to the value of the session id cookie.
    Authorized,
}

impl TokenClass {
    /// Short label mixed into token HMACs so tokens of one class never verify as the other.
    fn as_str(&self) -> &'static str {
        if matches!(self, Self::Anonymous) {
            "anon"
        } else {
            "auth"
        }
    }
}
/// Reasons the CSRF middleware refuses (or fails to process) a request.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CsrfError {
    TokenMissing,
    TokenInvalid,
    OriginRejected,
    MultipartNotEnabled,
    BodyTooLarge,
    BodyRead,
    Internal,
}

impl CsrfError {
    /// Stable, machine-readable identifier; also the `Display` text and the JSON
    /// `error` value of the HTTP response.
    pub fn code(self) -> &'static str {
        match self {
            Self::TokenMissing => "csrf_token_missing",
            Self::TokenInvalid => "csrf_token_invalid",
            Self::OriginRejected => "csrf_origin_rejected",
            Self::MultipartNotEnabled => "csrf_multipart_not_enabled",
            Self::BodyTooLarge => "csrf_body_too_large",
            Self::BodyRead => "csrf_body_read_error",
            Self::Internal => "csrf_internal_error",
        }
    }
}

impl fmt::Display for CsrfError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.code())
    }
}

impl error::Error for CsrfError {}
/// Maps each [`CsrfError`] to an HTTP status and a small JSON body
/// (`{"error":"<code>"}`); the error value itself is attached to the response
/// extensions so outer middleware can inspect it.
impl ResponseError for CsrfError {
    fn status_code(&self) -> StatusCode {
        match self {
            CsrfError::TokenMissing
            | CsrfError::TokenInvalid
            | CsrfError::MultipartNotEnabled
            | CsrfError::BodyRead => StatusCode::BAD_REQUEST,
            CsrfError::OriginRejected => StatusCode::FORBIDDEN,
            CsrfError::BodyTooLarge => StatusCode::PAYLOAD_TOO_LARGE,
            CsrfError::Internal => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }

    fn error_response(&self) -> HttpResponse {
        let body = format!(r#"{{"error":"{}"}}"#, self.code());
        let mut response = HttpResponse::build(self.status_code())
            .content_type("application/json")
            .body(body);
        response.extensions_mut().insert(*self);
        response
    }
}
/// Selects which CSRF defence the middleware applies.
///
/// `Debug`, `Eq` and `Copy` are derived so the (field-less) pattern can be logged,
/// compared and passed around by value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CsrfPattern {
    /// Random tokens stored server-side in the actix session.
    #[cfg(feature = "actix-session")]
    SynchronizerToken,
    /// HMAC tokens bound to the session / pre-session id and delivered in a cookie.
    DoubleSubmitCookie,
}
/// Cookie attributes for the token cookie written by the double-submit pattern.
#[derive(Clone)]
pub struct CsrfDoubleSubmitCookie {
    /// `HttpOnly` flag; when `true`, page scripts cannot read the token from the cookie.
    pub http_only: bool,
    /// `SameSite` attribute of the token cookie.
    pub same_site: SameSite,
}
/// Configuration shared by every [`CsrfMiddleware`] service instance.
///
/// Intentionally not `Debug`: it holds the HMAC secret.
#[derive(Clone)]
pub struct CsrfMiddlewareConfig {
    /// CSRF defence in use.
    pub pattern: CsrfPattern,
    /// When `true`, mutating `multipart/form-data` requests are forwarded without token
    /// validation (the handler must check the token); when `false` they are rejected.
    pub manual_multipart: bool,
    /// Cookie whose presence marks the request as carrying a session; its value is the
    /// id that double-submit tokens are bound to.
    pub session_id_cookie_name: String,
    /// `Secure` attribute for every cookie the middleware writes or expires.
    pub secure: bool,
    /// Token cookie name (double-submit) or session key (synchronizer) for requests
    /// that carry a session id cookie.
    pub token_cookie_name: String,
    /// Double-submit cookie name for tokens bound to the anonymous pre-session.
    pub anon_token_cookie_name: String,
    /// Session key for anonymous tokens under the synchronizer pattern.
    #[cfg(feature = "actix-session")]
    pub anon_session_key_name: String,
    /// Body field (JSON or urlencoded) searched for the submitted token.
    pub token_form_field: String,
    /// Header searched for the submitted token before falling back to the body.
    pub token_header_name: String,
    /// Attributes of the double-submit token cookie.
    pub token_cookie_config: Option<CsrfDoubleSubmitCookie>,
    /// HMAC key for tokens and the pre-session cookie signature.
    pub secret_key: zeroize::Zeroizing<Vec<u8>>,
    /// Path prefixes whose requests bypass CSRF handling.
    pub skip_for: Vec<String>,
    /// Whether mutating requests must present an allowed `Origin` / `Referer`.
    pub enforce_origin: bool,
    /// Allowed origins (scheme, host and port are compared).
    pub allowed_origins: Vec<String>,
    /// Maximum number of body bytes buffered while searching for the token.
    pub max_body_bytes: usize,
}
impl CsrfMiddlewareConfig {
#[cfg(feature = "actix-session")]
pub fn synchronizer_token(secret_key: &[u8]) -> Self {
check_secret_key(secret_key);
CsrfMiddlewareConfig {
pattern: CsrfPattern::SynchronizerToken,
session_id_cookie_name: DEFAULT_SESSION_ID_KEY.to_string(),
token_cookie_name: DEFAULT_CSRF_TOKEN_KEY.into(),
anon_token_cookie_name: DEFAULT_CSRF_ANON_TOKEN_KEY.into(),
#[cfg(feature = "actix-session")]
anon_session_key_name: format!("{DEFAULT_CSRF_TOKEN_KEY}-anon"),
token_form_field: DEFAULT_CSRF_TOKEN_FIELD.into(),
token_header_name: DEFAULT_CSRF_TOKEN_HEADER.into(),
token_cookie_config: None,
secret_key: zeroize::Zeroizing::new(secret_key.into()),
skip_for: vec![],
manual_multipart: false,
secure: true,
enforce_origin: false,
allowed_origins: vec![],
max_body_bytes: 2 * 1024 * 1024, }
}
pub fn double_submit_cookie(secret_key: &[u8]) -> Self {
check_secret_key(secret_key);
CsrfMiddlewareConfig {
pattern: CsrfPattern::DoubleSubmitCookie,
session_id_cookie_name: DEFAULT_SESSION_ID_KEY.to_string(),
token_cookie_name: DEFAULT_CSRF_TOKEN_KEY.into(),
anon_token_cookie_name: DEFAULT_CSRF_ANON_TOKEN_KEY.into(),
#[cfg(feature = "actix-session")]
anon_session_key_name: format!("{DEFAULT_CSRF_TOKEN_KEY}-anon"),
token_form_field: DEFAULT_CSRF_TOKEN_FIELD.into(),
token_header_name: DEFAULT_CSRF_TOKEN_HEADER.into(),
token_cookie_config: Some(CsrfDoubleSubmitCookie {
http_only: false, same_site: SameSite::Strict,
}),
secret_key: zeroize::Zeroizing::new(secret_key.into()),
skip_for: vec![],
manual_multipart: false,
secure: true,
enforce_origin: false,
allowed_origins: vec![],
max_body_bytes: 2 * 1024 * 1024,
}
}
pub fn with_multipart(mut self, multipart: bool) -> Self {
self.manual_multipart = multipart;
self
}
pub fn with_max_body_bytes(mut self, limit: usize) -> Self {
self.max_body_bytes = limit;
self
}
pub fn with_token_cookie_config(mut self, config: CsrfDoubleSubmitCookie) -> Self {
self.token_cookie_config = Some(config);
self
}
pub fn with_secure(mut self, secure: bool) -> Self {
self.secure = secure;
self
}
pub fn with_skip_for(mut self, patches: Vec<String>) -> Self {
self.skip_for = patches;
self
}
pub fn with_enforce_origin(mut self, enforce: bool, allowed: Vec<String>) -> Self {
self.enforce_origin = enforce;
self.allowed_origins = allowed;
self
}
}
/// actix-web middleware factory; register with `App::wrap(CsrfMiddleware::new(config))`.
pub struct CsrfMiddleware {
    config: Rc<CsrfMiddlewareConfig>,
}

impl CsrfMiddleware {
    /// Takes ownership of `config`; it is shared (not copied) with every service built
    /// from this factory.
    pub fn new(config: CsrfMiddlewareConfig) -> Self {
        let shared = Rc::new(config);
        CsrfMiddleware { config: shared }
    }
}
/// Wraps an inner service in a [`CsrfMiddlewareService`]; construction cannot fail.
impl<S, B> Transform<S, ServiceRequest> for CsrfMiddleware
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    B: MessageBody,
{
    type Response = ServiceResponse<EitherBody<B>>;
    type Error = Error;
    type Transform = CsrfMiddlewareService<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        let config = Rc::clone(&self.config);
        let service = Rc::new(service);
        ok(CsrfMiddlewareService { service, config })
    }
}
/// Service produced by [`CsrfMiddleware`]'s `new_transform`; performs the CSRF checks.
pub struct CsrfMiddlewareService<S> {
    // Inner service; reference-counted so in-flight validator futures hold their own handle.
    service: Rc<S>,
    // Configuration shared with the factory and every future this service creates.
    config: Rc<CsrfMiddlewareConfig>,
}
impl<S> CsrfMiddlewareService<S> {
    /// Resolves the identity the CSRF token is bound to, as `(id, regenerated, class)`.
    ///
    /// * A session id cookie yields its value, `false`, [`TokenClass::Authorized`].
    /// * Otherwise a pre-session cookie whose signature verifies yields the embedded id,
    ///   `false`, [`TokenClass::Anonymous`].
    /// * Otherwise (missing or tampered pre-session cookie) a fresh random id, `true`
    ///   (a new pre-session cookie must be issued), [`TokenClass::Anonymous`].
    fn get_session_from_cookie(&self, req: &ServiceRequest) -> (String, bool, TokenClass) {
        if let Some(session) = req.cookie(&self.config.session_id_cookie_name) {
            return (session.value().to_string(), false, TokenClass::Authorized);
        }
        let verified_pre_session = req.cookie(CSRF_PRE_SESSION_KEY).and_then(|cookie| {
            decode_pre_session_cookie(cookie.value(), self.config.secret_key.as_slice())
        });
        match verified_pre_session {
            Some(pre_id) => (pre_id, false, TokenClass::Anonymous),
            None => (generate_random_token(), true, TokenClass::Anonymous),
        }
    }

    /// Returns the token currently expected for this request and whether it is new
    /// (and so must be persisted on the response).
    ///
    /// * Synchronizer: the token stored in the session under the class's key, or a new
    ///   random token.
    /// * Double-submit: the class's token cookie unless the pre-session was just
    ///   regenerated; otherwise a new HMAC token bound to `session_id`.
    ///
    /// # Panics
    ///
    /// Double-submit requires `session_id` to be `Some` whenever a new token is minted.
    fn get_true_token(
        &self,
        req: &ServiceRequest,
        session_id: Option<&str>,
        class: TokenClass,
        pre_session_regenerated: bool,
    ) -> (String, bool) {
        match self.config.pattern {
            #[cfg(feature = "actix-session")]
            CsrfPattern::SynchronizerToken => {
                let key = match class {
                    TokenClass::Authorized => &self.config.token_cookie_name,
                    TokenClass::Anonymous => &self.config.anon_session_key_name,
                };
                match req.get_session().get::<String>(key).ok().flatten() {
                    Some(tok) => (tok, false),
                    None => (generate_random_token(), true),
                }
            }
            CsrfPattern::DoubleSubmitCookie => {
                let cookie_name = match class {
                    TokenClass::Authorized => &self.config.token_cookie_name,
                    TokenClass::Anonymous => &self.config.anon_token_cookie_name,
                };
                match req.cookie(cookie_name) {
                    // A token bound to a discarded pre-session id would never verify.
                    Some(cookie) if !pre_session_regenerated => {
                        (cookie.value().to_string(), false)
                    }
                    _ => {
                        let id = session_id.expect("Session or pre-session id is passed");
                        let tok =
                            generate_hmac_token_ctx(class, id, self.config.secret_key.as_slice());
                        (tok, true)
                    }
                }
            }
        }
    }

    /// Whether the request path is covered by one of the configured `skip_for` entries.
    fn should_skip_validation(&self, req: &ServiceRequest) -> bool {
        let path = req.path();
        self.config
            .skip_for
            .iter()
            .any(|prefix| path_matches_skip_prefix(path, prefix))
    }
}

/// Whether `path` lies under the `skip_for` entry `prefix`.
///
/// The match must end on a path-segment boundary, so `/webhook` covers `/webhook` and
/// `/webhook/github` but not `/webhooks-admin` (which would otherwise silently lose
/// CSRF protection). The empty prefix and prefixes ending in `/` match every path
/// they prefix.
fn path_matches_skip_prefix(path: &str, prefix: &str) -> bool {
    path.strip_prefix(prefix).is_some_and(|rest| {
        prefix.is_empty() || prefix.ends_with('/') || rest.is_empty() || rest.starts_with('/')
    })
}
impl<S, B> Service<ServiceRequest> for CsrfMiddlewareService<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
B: MessageBody,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type Future = Either<CsrfTokenValidator<S, B>, Ready<Result<Self::Response, Self::Error>>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
if self.should_skip_validation(&req) {
let resp = CsrfResponse {
fut: self.service.call(req),
config: Some(self.config.clone()),
set_token: None,
set_pre_session: None,
token_class: None,
remove_pre_session: false,
_phantom: PhantomData,
};
return Either::left(CsrfTokenValidator::CsrfResponse { response: resp });
}
let (true_token, should_set_token, cookie_session, token_class): (
String,
bool,
Option<(String, bool)>,
Option<TokenClass>,
) = match self.config.pattern {
CsrfPattern::DoubleSubmitCookie => {
let (session_id, set_pre_session, token_class) = self.get_session_from_cookie(&req);
let (true_token, should_set_token) =
self.get_true_token(&req, Some(&session_id), token_class, set_pre_session);
(
true_token,
should_set_token,
Some((session_id, set_pre_session)),
Some(token_class),
)
}
#[cfg(feature = "actix-session")]
CsrfPattern::SynchronizerToken => {
let (session_id, set_pre_session, token_class) = self.get_session_from_cookie(&req);
let (token, should_set_token) =
self.get_true_token(&req, None, token_class, set_pre_session);
(
token,
should_set_token,
Some((session_id, set_pre_session)),
Some(token_class),
)
}
};
req.extensions_mut().insert(CsrfToken(true_token.clone()));
req.extensions_mut().insert(self.config.clone());
let is_mutating = matches!(
*req.method(),
Method::POST | Method::PUT | Method::PATCH | Method::DELETE
);
if !is_mutating {
let mut set_token_bytes = if should_set_token {
Some(true_token.clone())
} else {
None
};
let session_id = if let Some((ref session_id, set_pre_session)) = cookie_session {
if set_pre_session {
Some(session_id.clone())
} else {
None
}
} else {
None
};
if self.config.pattern == CsrfPattern::DoubleSubmitCookie {
if let (Some(TokenClass::Authorized), Some((ref sess_id, _))) =
(token_class, cookie_session.as_ref())
{
if req.cookie(&self.config.token_cookie_name).is_none() {
let tok = generate_hmac_token_ctx(
TokenClass::Authorized,
sess_id,
self.config.secret_key.as_slice(),
);
set_token_bytes = Some(tok);
}
}
}
let remove_pre_session = matches!(token_class, Some(TokenClass::Authorized));
let resp = CsrfResponse {
fut: self.service.call(req),
config: Some(self.config.clone()),
set_token: set_token_bytes,
set_pre_session: session_id,
token_class,
remove_pre_session,
_phantom: PhantomData,
};
return Either::left(CsrfTokenValidator::CsrfResponse { response: resp });
}
if self.config.enforce_origin && !origin_allowed(req.headers(), &self.config) {
let resp = CsrfError::OriginRejected.error_response();
return Either::right(ok(req
.into_response(resp)
.map_into_boxed_body()
.map_into_right_body()));
}
if let Some(ct) = req
.headers()
.get(header::CONTENT_TYPE)
.and_then(|hv| hv.to_str().ok())
{
if ct.starts_with("multipart/form-data") {
if !self.config.manual_multipart {
let resp = CsrfError::MultipartNotEnabled.error_response();
return Either::right(ok(req
.into_response(resp)
.map_into_boxed_body()
.map_into_right_body()));
}
let resp = CsrfResponse {
fut: self.service.call(req),
config: Some(self.config.clone()),
set_token: None,
set_pre_session: None,
token_class: None,
remove_pre_session: false,
_phantom: PhantomData,
};
return Either::left(CsrfTokenValidator::CsrfResponse { response: resp });
}
}
let (session_id, token_class) = if let Some((session_id, _)) = cookie_session {
(Some(session_id), token_class)
} else {
(None, token_class)
};
let header_token = req
.headers()
.get(&self.config.token_header_name)
.and_then(|hv| hv.to_str().ok())
.map(|s| s.to_string());
if let Some(token) = header_token {
return Either::left(CsrfTokenValidator::MutatingRequest {
service: self.service.clone(),
config: self.config.clone(),
true_token,
client_token: token,
session_id,
token_class,
req: Some(req),
});
}
let mut req2 = req;
let payload = req2.take_payload();
let initial_capacity = req2
.headers()
.get(header::CONTENT_LENGTH)
.and_then(|hv| hv.to_str().ok())
.and_then(|s| s.parse::<usize>().ok())
.map(|n| n.min(self.config.max_body_bytes))
.unwrap_or(0);
let body_buf = if initial_capacity > 0 {
BytesMut::with_capacity(initial_capacity)
} else {
BytesMut::new()
};
Either::left(CsrfTokenValidator::ReadingBody {
req: Some(req2),
payload: Some(payload),
body_bytes: body_buf,
config: self.config.clone(),
service: self.service.clone(),
true_token,
session_id,
token_class,
})
}
}
// Future returned by `CsrfMiddlewareService::call`: a small state machine that
// (optionally) buffers the request body to find the submitted token, validates it,
// and finally forwards the request to the inner service.
// Only `//` comments are used here because the body is parsed by `pin_project!`.
pin_project! {
#[project = CsrfTokenValidatorProj]
pub enum CsrfTokenValidator<S, B>
where
S: Service<ServiceRequest>,
B: MessageBody,
{
// Terminal state: awaiting the inner service, then applying cookie/session updates.
CsrfResponse {
#[pin]
response: CsrfResponse<S, B>,
},
// The client token is known (from the header or the parsed body); the next poll
// validates it against `true_token` / the HMAC and, on success, calls the service.
MutatingRequest {
service: Rc<S>,
config: Rc<CsrfMiddlewareConfig>,
true_token: String,
client_token: String,
// (Pre-)session id the double-submit HMAC is bound to.
session_id: Option<String>,
token_class: Option<TokenClass>,
// Taken (left `None`) once the request moves on.
req: Option<ServiceRequest>
},
// Draining the payload into `body_bytes` (bounded by `config.max_body_bytes`)
// so the token can be extracted from a JSON or urlencoded body.
ReadingBody {
service: Rc<S>,
config: Rc<CsrfMiddlewareConfig>,
req: Option<ServiceRequest>,
payload: Option<actix_web::dev::Payload>,
body_bytes: BytesMut,
true_token: String,
session_id: Option<String>,
token_class: Option<TokenClass>,
},
}
}
impl<S, B> Future for CsrfTokenValidator<S, B>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
B: MessageBody,
{
type Output = Result<ServiceResponse<EitherBody<B>>, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().project() {
CsrfTokenValidatorProj::CsrfResponse { response } => response.poll(cx),
CsrfTokenValidatorProj::MutatingRequest {
service,
config,
true_token,
client_token,
session_id,
token_class,
req,
} => {
#[cfg(not(feature = "actix-session"))]
let _ = &true_token;
if let Some(req) = req.take() {
let session_id = if config.pattern == CsrfPattern::DoubleSubmitCookie {
if let Some(id) = session_id.take() {
Some(id)
} else {
error!("session id is empty in csrf token validator");
let resp = CsrfError::Internal.error_response();
return Poll::Ready(Ok(req
.into_response(resp)
.map_into_boxed_body()
.map_into_right_body()));
}
} else {
None
};
let valid = match &config.pattern {
#[cfg(feature = "actix-session")]
CsrfPattern::SynchronizerToken => {
if eq_tokens(true_token.as_bytes(), client_token.as_bytes()) {
true
} else {
let alt_valid = {
let session = req.get_session();
let alt_key = match token_class
.as_ref()
.copied()
.unwrap_or(TokenClass::Authorized)
{
TokenClass::Authorized => &config.anon_session_key_name,
TokenClass::Anonymous => &config.token_cookie_name,
};
let alt = session.get::<String>(alt_key).ok().flatten();
alt.map(|t| eq_tokens(t.as_bytes(), client_token.as_bytes()))
.unwrap_or(false)
};
alt_valid
}
}
CsrfPattern::DoubleSubmitCookie => {
let ctx = token_class
.as_ref()
.copied()
.unwrap_or(TokenClass::Anonymous);
validate_hmac_token_ctx(
ctx,
session_id
.as_deref()
.expect("session id cannot be empty is hmac validation"),
client_token.as_bytes(),
config.secret_key.as_slice(),
)
.unwrap_or(false)
}
};
if !valid {
let resp = CsrfError::TokenInvalid.error_response();
return Poll::Ready(Ok(req
.into_response(resp)
.map_into_boxed_body()
.map_into_right_body()));
}
let new_token = match &config.pattern {
#[cfg(feature = "actix-session")]
CsrfPattern::SynchronizerToken => generate_random_token(),
CsrfPattern::DoubleSubmitCookie => {
let ctx = token_class
.as_ref()
.copied()
.unwrap_or(TokenClass::Anonymous);
generate_hmac_token_ctx(
ctx,
session_id
.as_deref()
.expect("session id cannot be empty is hmac validation"),
config.secret_key.as_ref(),
)
}
};
let resp = CsrfResponse {
fut: service.call(req),
config: Some(config.clone()),
set_token: Some(new_token),
set_pre_session: None,
token_class: *token_class,
remove_pre_session: false,
_phantom: PhantomData,
};
self.set(CsrfTokenValidator::CsrfResponse { response: resp });
cx.waker().wake_by_ref(); Poll::Pending
} else {
error!("request already taken in csrf validator's state machine");
Poll::Ready(Err(CsrfError::Internal.into()))
}
}
CsrfTokenValidatorProj::ReadingBody {
service,
config,
req,
payload,
body_bytes,
true_token,
session_id,
token_class,
} => {
if req.is_none() {
error!("request already taken in csrf validator's state machine");
return Poll::Ready(Err(CsrfError::Internal.into()));
}
let request_mut = req.as_mut().unwrap();
let payload = match payload.as_mut() {
Some(p) => p,
None => {
error!("payload missing in reading body state");
return Poll::Ready(Err(CsrfError::Internal.into()));
}
};
match payload.poll_next_unpin(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Some(Ok(bytes))) => {
body_bytes.extend_from_slice(&bytes);
if body_bytes.len() > config.max_body_bytes {
let req_owned = req.take().unwrap();
let resp = CsrfError::BodyTooLarge.error_response();
return Poll::Ready(Ok(req_owned
.into_response(resp)
.map_into_boxed_body()
.map_into_right_body()));
}
cx.waker().wake_by_ref();
Poll::Pending
}
Poll::Ready(Some(Err(e))) => {
error!("failed to read request body for csrf extraction: {e:?}");
let req_owned = req.take().unwrap();
let resp = CsrfError::BodyRead.error_response();
Poll::Ready(Ok(req_owned
.into_response(resp)
.map_into_boxed_body()
.map_into_right_body()))
}
Poll::Ready(None) => {
let body = std::mem::take(&mut *body_bytes).freeze();
let client_token = match sync_read_token_from_body(
request_mut.headers(),
&body,
&config.token_form_field,
) {
Some(token) => token,
None => {
let req_owned = req.take().unwrap();
let res = CsrfError::TokenMissing.error_response();
return Poll::Ready(Ok(req_owned
.into_response(res)
.map_into_boxed_body()
.map_into_right_body()));
}
};
request_mut.set_payload(actix_web::dev::Payload::from(body.clone()));
let req_owned = req.take().unwrap();
let next_state = {
let service = service.clone();
let config = config.clone();
let true_token = std::mem::take(true_token);
let session_id = session_id.take();
let token_class = token_class.take();
let req = Some(req_owned);
CsrfTokenValidator::MutatingRequest {
service,
config,
true_token,
client_token,
session_id,
token_class,
req,
}
};
self.set(next_state);
cx.waker().wake_by_ref();
Poll::Pending
}
}
}
}
}
}
/// Extracts the CSRF token field `token_field` from an already-buffered request body.
///
/// Supports `application/json` (a top-level string field) and
/// `application/x-www-form-urlencoded` (the last occurrence of the field wins).
/// The `Content-Type` prefix is compared ASCII-case-insensitively, because media
/// types are case-insensitive. Trailing parameters such as `; charset=utf-8` are
/// ignored.
///
/// Returns `None` in any of these cases:
/// * the header is missing or not visible ASCII;
/// * the body does not parse;
/// * the field is absent or not a string;
/// * the content type is unsupported (this case is also logged).
fn sync_read_token_from_body(
    headers: &HeaderMap,
    body: &[u8],
    token_field: &str,
) -> Option<String> {
    let content_type = headers.get(header::CONTENT_TYPE)?.to_str().ok()?;
    let has_media_type = |media_type: &str| {
        content_type
            .get(..media_type.len())
            .is_some_and(|prefix| prefix.eq_ignore_ascii_case(media_type))
    };

    if has_media_type("application/json") {
        let json = serde_json::from_slice::<serde_json::Value>(body).ok()?;
        json.get(token_field)?.as_str().map(String::from)
    } else if has_media_type("application/x-www-form-urlencoded") {
        let mut form = serde_urlencoded::from_bytes::<HashMap<String, String>>(body).ok()?;
        form.remove(token_field)
    } else {
        warn!("unsupported request content type, unable to extract and verify csrf token");
        None
    }
}
// Future that awaits the inner service and then applies the pending CSRF side
// effects to its response (pre-session cookie, token cookie or session entry,
// cookie expiry). Only `//` comments are used because the body is parsed by
// `pin_project!`.
pin_project! {
pub struct CsrfResponse<S, B>
where
S: Service<ServiceRequest>,
B: MessageBody,
{
// Future of the wrapped (inner) service.
#[pin]
fut: S::Future,
// Middleware configuration; `None` turns the response into an internal error.
config: Option<Rc<CsrfMiddlewareConfig>>,
// Token to persist (cookie or session entry, by pattern) unless the request carries
// the `CsrfTeardown` marker.
set_token: Option<String>,
// Pre-session id to sign and send as the `CSRF_PRE_SESSION_KEY` cookie.
set_pre_session: Option<String>,
// Selects the cookie name / session key `set_token` is written under.
token_class: Option<TokenClass>,
// Expire the pre-session cookie (and, for double-submit, the anonymous token cookie).
remove_pre_session: bool,
_phantom: PhantomData<B>,
}
}
impl<S, B> Future for CsrfResponse<S, B>
where
B: MessageBody,
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
{
type Output = Result<ServiceResponse<EitherBody<B>>, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut().project();
match ready!(this.fut.poll(cx)) {
Ok(mut resp) => {
let config = match &this.config {
Some(config) => config,
None => {
error!("unable to extract csrf middleware config in csrf response");
let res = CsrfError::Internal.error_response();
return Poll::Ready(Ok(resp
.into_response(res)
.map_into_boxed_body()
.map_into_right_body()));
}
};
if let Some(pre_session_id) = this.set_pre_session {
let cookie_val =
encode_pre_session_cookie(pre_session_id, config.secret_key.as_slice());
match resp.response_mut().add_cookie(
&Cookie::build(CSRF_PRE_SESSION_KEY, cookie_val)
.http_only(PRE_SESSION_HTTP_ONLY)
.secure(config.secure)
.same_site(PRE_SESSION_SAME_SITE)
.path("/")
.finish(),
) {
Ok(_) => {}
Err(e) => {
error!("unable to set pre-session cookie in csrf response: {e:?}");
let res = CsrfError::Internal.error_response();
return Poll::Ready(Ok(resp
.into_response(res)
.map_into_boxed_body()
.map_into_right_body()));
}
}
}
if *this.remove_pre_session {
if let Err(e) = resp
.response_mut()
.add_cookie(&expired_pre_session_cookie(config.secure))
{
error!("unable to expire pre-session cookie in csrf response: {e:?}");
let res = CsrfError::Internal.error_response();
return Poll::Ready(Ok(resp
.into_response(res)
.map_into_boxed_body()
.map_into_right_body()));
}
if matches!(config.pattern, CsrfPattern::DoubleSubmitCookie) {
if let Err(e) = resp.response_mut().add_cookie(&expire_cookie(
&config.anon_token_cookie_name,
config.secure,
)) {
error!("unable to expire anon token cookie in csrf response: {e:?}");
let res = CsrfError::Internal.error_response();
return Poll::Ready(Ok(resp
.into_response(res)
.map_into_boxed_body()
.map_into_right_body()));
}
}
}
let teardown = resp.request().extensions().get::<CsrfTeardown>().is_some();
if let Some(new_token) = this.set_token.take().filter(|_| !teardown) {
match config.pattern {
#[cfg(feature = "actix-session")]
CsrfPattern::SynchronizerToken => {
if *this.remove_pre_session {
let _ = resp
.request()
.get_session()
.remove(&config.anon_session_key_name);
}
let key = match this.token_class.unwrap_or(TokenClass::Authorized) {
TokenClass::Authorized => &config.token_cookie_name,
TokenClass::Anonymous => &config.anon_session_key_name,
};
match resp.request().get_session().insert(key, new_token) {
Ok(()) => {}
Err(e) => {
error!("unable to set a csrf token with actix session in csrf response: {e:?}");
let res = CsrfError::Internal.error_response();
return Poll::Ready(Ok(resp
.into_response(res)
.map_into_boxed_body()
.map_into_right_body()));
}
}
}
CsrfPattern::DoubleSubmitCookie => {
let cookie_config = match &config.token_cookie_config {
Some(config) => config,
None => {
error!(
"unable to extract token_cookie_config in csrf response"
);
let res = CsrfError::Internal.error_response();
return Poll::Ready(Ok(resp
.into_response(res)
.map_into_boxed_body()
.map_into_right_body()));
}
};
let cookie_name =
match this.token_class.unwrap_or(TokenClass::Anonymous) {
TokenClass::Authorized => &config.token_cookie_name,
TokenClass::Anonymous => &config.anon_token_cookie_name,
};
let new_token_cookie = Cookie::build(cookie_name, new_token)
.http_only(cookie_config.http_only)
.secure(config.secure)
.same_site(cookie_config.same_site)
.path("/")
.finish();
match resp.response_mut().add_cookie(&new_token_cookie) {
Ok(_) => {}
Err(e) => {
error!("unable to set a token cookie in csrf response: {e:?}");
let res = CsrfError::Internal.error_response();
return Poll::Ready(Ok(resp
.into_response(res)
.map_into_boxed_body()
.map_into_right_body()));
}
}
}
}
}
Poll::Ready(Ok(resp.map_into_left_body()))
}
Err(err) => Poll::Ready(Err(err)),
}
}
}
/// The CSRF token the middleware attached to the current request; extract it in a
/// handler to embed it in forms or hand it to scripts.
#[derive(Clone)]
pub struct CsrfToken(pub String);

impl FromRequest for CsrfToken {
    type Error = Error;
    type Future = Ready<Result<Self, Self::Error>>;

    /// Fails with `csrf_internal_error` when no token was attached (middleware missing).
    fn from_request(req: &HttpRequest, _: &mut actix_web::dev::Payload) -> Self::Future {
        if let Some(token) = req.extensions().get::<CsrfToken>().cloned() {
            return ok(token);
        }
        error!("CsrfToken extracted without CsrfMiddleware installed");
        err(CsrfError::Internal.into())
    }
}
/// Convenience wrappers that look up the middleware configuration from the request.
pub trait CsrfRequestExt {
    /// Rotates CSRF state after a successful login bound to `session_id`, writing any
    /// cookies onto `resp`.
    ///
    /// # Errors
    ///
    /// Fails when the middleware configuration is not attached to the request, or when
    /// rotation itself fails.
    fn rotate_csrf_after_login(
        &self,
        session_id: &str,
        resp: &mut HttpResponseBuilder,
    ) -> Result<(), Error>;

    /// Clears CSRF state after logout, writing cookie expirations onto `resp`.
    ///
    /// # Errors
    ///
    /// Fails when the middleware configuration is not attached to the request.
    fn rotate_csrf_after_logout(&self, resp: &mut HttpResponseBuilder) -> Result<(), Error>;
}

impl CsrfRequestExt for HttpRequest {
    fn rotate_csrf_after_login(
        &self,
        session_id: &str,
        resp: &mut HttpResponseBuilder,
    ) -> Result<(), Error> {
        config_from_request(self)
            .and_then(|config| rotate_csrf_after_login(session_id, self, resp, &config))
    }

    fn rotate_csrf_after_logout(&self, resp: &mut HttpResponseBuilder) -> Result<(), Error> {
        config_from_request(self).and_then(|config| rotate_csrf_after_logout(self, resp, &config))
    }
}
/// Fetches the shared middleware configuration the middleware stored in the request
/// extensions; logs and returns `csrf_internal_error` when it is absent.
fn config_from_request(req: &HttpRequest) -> Result<Rc<CsrfMiddlewareConfig>, Error> {
    let config = req.extensions().get::<Rc<CsrfMiddlewareConfig>>().cloned();
    match config {
        Some(config) => Ok(config),
        None => {
            error!("CSRF middleware config not found in request extensions");
            Err(CsrfError::Internal.into())
        }
    }
}
/// Returns `TOKEN_LEN` random bytes from the thread-local RNG, encoded as URL-safe
/// base64 without padding.
pub fn generate_random_token() -> String {
    let mut bytes = [0u8; TOKEN_LEN];
    let mut rng = rand::rng();
    rng.fill_bytes(&mut bytes);
    URL_SAFE_NO_PAD.encode(bytes)
}
/// Mints a token `"<hex hmac>.<nonce>"`, where `nonce` is a fresh random token and
/// the HMAC-SHA256 (keyed by `secret`) covers `"<class>|<id>|<nonce>"`.
pub fn generate_hmac_token_ctx(class: TokenClass, id: &str, secret: &[u8]) -> String {
    let nonce = generate_random_token();
    let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC can take key of any size");
    let parts: [&[u8]; 5] = [
        class.as_str().as_bytes(),
        b"|",
        id.as_bytes(),
        b"|",
        nonce.as_bytes(),
    ];
    for part in parts {
        mac.update(part);
    }
    let tag = hex::encode(mac.finalize().into_bytes());
    format!("{tag}.{nonce}")
}
/// Constant-time byte comparison; slices of different length compare unequal.
pub fn eq_tokens(token_a: &[u8], token_b: &[u8]) -> bool {
    bool::from(token_a.ct_eq(token_b))
}
/// Serialises a pre-session id as `"<hex hmac>.<id>"`, the HMAC covering `"pre|<id>"`.
fn encode_pre_session_cookie(id: &str, secret: &[u8]) -> String {
    let tag = {
        let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC can take key of any size");
        mac.update(b"pre|");
        mac.update(id.as_bytes());
        mac.finalize().into_bytes()
    };
    format!("{}.{id}", hex::encode(tag))
}
fn decode_pre_session_cookie(val: &str, secret: &[u8]) -> Option<String> {
let parts: Vec<&str> = val.split('.').collect();
if parts.len() != 2 {
return None;
}
let (sig_hex, id) = (parts[0], parts[1]);
let sig_bytes = hex::decode(sig_hex).ok()?;
let mut mac = Hmac::<Sha256>::new_from_slice(secret).ok()?;
mac.update(b"pre|");
mac.update(id.as_bytes());
let expected = mac.finalize().into_bytes();
if eq_tokens(&expected, &sig_bytes) {
Some(id.to_string())
} else {
None
}
}
pub fn validate_hmac_token_ctx(
class: TokenClass,
id: &str,
token: &[u8],
secret: &[u8],
) -> Result<bool, Error> {
let token_str = std::str::from_utf8(token)?;
let parts: Vec<&str> = token_str.split('.').collect();
if parts.len() != 2 {
return Ok(false);
}
let (hmac_hex, csrf_token) = (parts[0], parts[1]);
let hmac_bytes = hex::decode(hmac_hex).map_err(actix_web::error::ErrorInternalServerError)?;
let mut mac = Hmac::<Sha256>::new_from_slice(secret)
.map_err(actix_web::error::ErrorInternalServerError)?;
mac.update(class.as_str().as_bytes());
mac.update(b"|");
mac.update(id.as_bytes());
mac.update(b"|");
mac.update(csrf_token.as_bytes());
let expected_hmac = mac.finalize().into_bytes();
Ok(eq_tokens(&expected_hmac, &hmac_bytes))
}
/// Shorthand for [`validate_hmac_token_ctx`] with [`TokenClass::Authorized`], i.e. a
/// token bound to a session id.
pub fn validate_hmac_token(session_id: &str, token: &[u8], secret: &[u8]) -> Result<bool, Error> {
    let class = TokenClass::Authorized;
    validate_hmac_token_ctx(class, session_id, token, secret)
}
/// Request-extension marker: when present, `CsrfResponse` does not persist a new token.
struct CsrfTeardown;

/// Builds an already-expired, empty cookie named `name` at path `/`, used to delete
/// a cookie on the client.
fn expire_cookie(name: &str, secure: bool) -> Cookie<'static> {
    let mut cookie = Cookie::new(name.to_owned(), "");
    cookie.set_path("/");
    cookie.set_secure(secure);
    cookie.set_expires(time::OffsetDateTime::UNIX_EPOCH);
    cookie.set_max_age(time::Duration::seconds(0));
    cookie
}

/// Expired pre-session cookie carrying the same `HttpOnly` / `SameSite` attributes the
/// pre-session cookie is issued with.
fn expired_pre_session_cookie(secure: bool) -> Cookie<'static> {
    let mut cookie = expire_cookie(CSRF_PRE_SESSION_KEY, secure);
    cookie.set_same_site(PRE_SESSION_SAME_SITE);
    cookie.set_http_only(PRE_SESSION_HTTP_ONLY);
    cookie
}
#[cfg_attr(not(feature = "actix-session"), allow(unused_variables))]
pub fn rotate_csrf_after_login(
session_id: &str,
req: &HttpRequest,
resp: &mut HttpResponseBuilder,
config: &CsrfMiddlewareConfig,
) -> Result<(), Error> {
resp.cookie(expired_pre_session_cookie(config.secure));
match config.pattern {
#[cfg(feature = "actix-session")]
CsrfPattern::SynchronizerToken => {
let session = req.get_session();
let _ = session.remove(&config.anon_session_key_name);
session
.insert(&config.token_cookie_name, generate_random_token())
.map_err(|_| {
actix_web::error::ErrorInternalServerError(
"Failed to rotate CSRF token in session",
)
})?;
Ok(())
}
CsrfPattern::DoubleSubmitCookie => {
let token = generate_hmac_token_ctx(
TokenClass::Authorized,
session_id,
config.secret_key.as_slice(),
);
let (http_only, same_site) = match &config.token_cookie_config {
Some(cfg) => (cfg.http_only, cfg.same_site),
None => (true, SameSite::Lax),
};
let csrf_cookie = Cookie::build(&config.token_cookie_name, token)
.http_only(http_only)
.secure(config.secure)
.same_site(same_site)
.path("/")
.finish();
resp.cookie(csrf_cookie);
resp.cookie(expire_cookie(&config.anon_token_cookie_name, config.secure));
Ok(())
}
}
}
/// Clears CSRF state when a user logs out.
///
/// First marks the request with `CsrfTeardown`, so the middleware issues no new token
/// on the way out. Then it expires the pre-session cookie, and finally:
/// * synchronizer — purges the whole actix session;
/// * double-submit — expires the session-id, token and anonymous-token cookies.
///
/// # Errors
///
/// The visible implementation always returns `Ok(())`.
#[cfg_attr(not(feature = "actix-session"), allow(unused_variables))]
pub fn rotate_csrf_after_logout(
    req: &HttpRequest,
    resp: &mut HttpResponseBuilder,
    config: &CsrfMiddlewareConfig,
) -> Result<(), Error> {
    let secure = config.secure;
    req.extensions_mut().insert(CsrfTeardown);
    resp.cookie(expired_pre_session_cookie(secure));
    match config.pattern {
        #[cfg(feature = "actix-session")]
        CsrfPattern::SynchronizerToken => req.get_session().purge(),
        CsrfPattern::DoubleSubmitCookie => {
            let doomed = [
                &config.session_id_cookie_name,
                &config.token_cookie_name,
                &config.anon_token_cookie_name,
            ];
            for name in doomed {
                resp.cookie(expire_cookie(name, secure));
            }
        }
    }
    Ok(())
}
/// Enforces a minimum HMAC key length of 32 bytes.
///
/// # Panics
///
/// Panics with `csrf secret_key too short: require >=32 bytes` for shorter keys.
fn check_secret_key(secret_key: &[u8]) {
    const MIN_SECRET_KEY_LEN: usize = 32;
    assert!(
        secret_key.len() >= MIN_SECRET_KEY_LEN,
        "csrf secret_key too short: require >=32 bytes"
    );
}
/// Decides whether a mutating request's origin is acceptable.
///
/// Returns `true` when enforcement is off. When it is on:
/// * the `Origin` header is used if present (as visible ASCII);
/// * otherwise the `Referer` header is used;
/// * with neither header, or a value that does not parse as a URL, the request is
///   rejected.
///
/// A candidate matches an allowed entry when scheme, host and effective port
/// (`port_or_known_default`) are equal; paths are ignored. An empty or entirely
/// unparseable allow-list rejects everything.
fn origin_allowed(headers: &HeaderMap, cfg: &CsrfMiddlewareConfig) -> bool {
    if !cfg.enforce_origin {
        return true;
    }
    // Parse the allow-list once per call rather than once per comparison; entries that
    // do not parse can never match.
    let allowed: Vec<Url> = cfg
        .allowed_origins
        .iter()
        .filter_map(|raw| Url::parse(raw).ok())
        .collect();
    if allowed.is_empty() {
        return false;
    }
    let is_allowed = |candidate: &Url| {
        allowed.iter().any(|a| {
            a.scheme() == candidate.scheme()
                && a.host_str() == candidate.host_str()
                && a.port_or_known_default() == candidate.port_or_known_default()
        })
    };

    if let Some(origin) = headers.get(header::ORIGIN).and_then(|hv| hv.to_str().ok()) {
        return Url::parse(origin).is_ok_and(|u| is_allowed(&u));
    }
    if let Some(referer) = headers.get(header::REFERER).and_then(|hv| hv.to_str().ok()) {
        // Only scheme/host/port are compared, so the full Referer URL can be checked
        // directly without first reducing it to an origin string.
        return Url::parse(referer).is_ok_and(|u| is_allowed(&u));
    }
    false
}