1#![deny(unsafe_code)]
2#![warn(clippy::pedantic, clippy::nursery, clippy::cargo, missing_docs)]
3
4use std::cell::RefCell;
114use std::collections::HashSet;
115use std::default::Default;
116use std::error::Error;
117use std::fmt::Display;
118use std::future::{self, Future, Ready};
119use std::pin::Pin;
120use std::rc::Rc;
121use std::task::{Context, Poll};
122
123use crate::extractor::CsrfToken;
124
125use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
126use actix_web::error::InternalError;
127use actix_web::http::header::{self, HeaderValue};
128use actix_web::http::{Method, StatusCode};
129use actix_web::{HttpMessage, HttpResponse, ResponseError};
130use cookie::{Cookie, SameSite};
131use extractor::CsrfCookieConfig;
132use rand::SeedableRng;
133use tracing::{error, warn};
134
135pub mod extractor;
136mod token_rng;
137
138pub use crate::token_rng::TokenRng;
139
// Base name of the CSRF token. It is used on its own as the default token name and,
// behind the `__Host-` prefix, as the default cookie name. It is a macro rather than a
// `const` so it can be spliced into `concat!` string literals.
macro_rules! token_name {
    () => {
        "Csrf-Token"
    };
}
145
// The `__Host-` cookie-name prefix. Under the cookie-prefix rules, such cookies must be
// `Secure`, use `Path=/` and carry no `Domain` attribute. It is a macro (rather than a
// `const`) so it can be spliced into `concat!` / `format!(concat!(..))` literals, as this
// file does.
#[macro_export]
#[doc(hidden)]
macro_rules! host_prefix {
    () => {
        "__Host-"
    };
}
153
// The `__Secure-` cookie-name prefix. Under the cookie-prefix rules, such cookies must be
// `Secure`, but unlike `__Host-` they may carry a `Domain` attribute. It is a macro so it
// can be spliced into `concat!` / `format!(concat!(..))` literals.
#[macro_export]
#[doc(hidden)]
macro_rules! secure_prefix {
    () => {
        "__Secure-"
    };
}
161
/// Default name of the CSRF token (`Csrf-Token`); the tests in this file send it as the
/// request header name.
const DEFAULT_CSRF_TOKEN_NAME: &str = token_name!();
/// Default name of the CSRF cookie: `__Host-Csrf-Token`.
const DEFAULT_CSRF_COOKIE_NAME: &str = concat!(host_prefix!(), token_name!());
164
/// Reasons a request can fail CSRF validation.
///
/// As a `ResponseError`, every variant logs a warning and produces an empty
/// `422 Unprocessable Entity` response.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
pub enum CsrfError {
    /// The CSRF tokens do not match.
    TokenMismatch,
    /// The CSRF cookie is missing from the request.
    MissingCookie,
    /// The CSRF token (e.g. the header) is missing from the request.
    MissingToken,
}
175
176impl Display for CsrfError {
177 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178 match self {
179 Self::TokenMismatch => write!(f, "The CSRF Tokens do not match"),
180 Self::MissingCookie => write!(f, "The CSRF Cookie is missing"),
181 Self::MissingToken => write!(f, "The CSRF Header is missing"),
182 }
183 }
184}
185
186impl ResponseError for CsrfError {
187 fn error_response(&self) -> HttpResponse {
188 warn!("Potential CSRF attack: {}", self);
189 HttpResponse::UnprocessableEntity().finish()
190 }
191}
192
// Marker impl: lets `CsrfError` be used as a `std::error::Error`, building on its
// `Debug` (derived) and `Display` impls.
impl Error for CsrfError {}
194
/// Actix-web middleware that issues CSRF tokens as cookies on registered routes.
///
/// Create it with [`CsrfMiddleware::new`] or [`CsrfMiddleware::with_rng`], configure it
/// with the builder methods, and attach it with `App::wrap`.
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct CsrfMiddleware<Rng> {
    inner: Inner<Rng>,
}
200
201impl<Rng: TokenRng + SeedableRng> CsrfMiddleware<Rng> {
202 #[must_use]
215 pub fn new() -> Self {
216 Self::default()
217 }
218}
219
220impl<Rng: TokenRng> CsrfMiddleware<Rng> {
221 #[must_use]
235 pub fn with_rng(rng: Rng) -> Self {
236 Self {
237 inner: Inner::with_rng(rng),
238 }
239 }
240}
241
242impl<Rng> CsrfMiddleware<Rng> {
243 #[must_use]
245 pub const fn enabled(mut self, enabled: bool) -> Self {
246 self.inner.csrf_enabled = enabled;
247 self
248 }
249
250 #[must_use]
254 pub fn set_cookie<T: Into<String>>(mut self, method: Method, uri: T) -> Self {
255 self.inner.set_cookie.insert((method, uri.into()));
256 self
257 }
258
259 #[must_use]
266 pub fn cookie_name<T: Into<String>>(mut self, name: T) -> Self {
267 self.inner.cookie_name = Rc::new(name.into());
268 self
269 }
270
271 #[must_use]
288 pub fn host_prefixed_cookie_name<T: AsRef<str>>(mut self, name: T) -> Self {
289 let mut prefixed = host_prefix!().to_owned();
290 prefixed.push_str(name.as_ref());
291 self.inner.cookie_name = Rc::new(prefixed);
292 self
293 }
294
295 #[must_use]
317 pub fn secure_prefixed_cookie_name<T: AsRef<str>>(mut self, name: T) -> Self {
318 let mut prefixed = secure_prefix!().to_owned();
319 prefixed.push_str(name.as_ref());
320 self.inner.cookie_name = Rc::new(prefixed);
321 self
322 }
323
324 #[must_use]
326 pub const fn same_site(mut self, same_site: Option<SameSite>) -> Self {
327 self.inner.same_site = same_site;
328 self
329 }
330
331 #[must_use]
333 pub const fn http_only(mut self, enabled: bool) -> Self {
334 self.inner.http_only = enabled;
335 self
336 }
337
338 #[must_use]
340 pub const fn secure(mut self, enabled: bool) -> Self {
341 self.inner.secure = enabled;
342 self
343 }
344
345 #[must_use]
353 pub fn domain<S: Into<String>>(mut self, domain: impl Into<Option<S>>) -> Self {
354 if let Some(stripped) = self.inner.cookie_name.strip_prefix(host_prefix!()) {
355 self.inner.cookie_name = Rc::new(format!(concat!(secure_prefix!(), "{}"), stripped));
356 }
357 self.inner.domain = domain.into().map(Into::into);
358 self
359 }
360
361 #[must_use]
365 pub fn cookie_config(&self) -> CsrfCookieConfig {
366 CsrfCookieConfig::new((*self.inner.cookie_name).clone())
367 }
368}
369
370impl<Rng: TokenRng + SeedableRng> Default for CsrfMiddleware<Rng> {
371 fn default() -> Self {
372 Self {
373 inner: Inner::default(),
374 }
375 .cookie_name(DEFAULT_CSRF_COOKIE_NAME.to_string())
376 }
377}
378
379impl<S, Rng> Transform<S, ServiceRequest> for CsrfMiddleware<Rng>
380where
381 S: Service<ServiceRequest, Response = ServiceResponse>,
382 Rng: TokenRng + Clone,
383{
384 type Response = ServiceResponse;
385 type Error = S::Error;
386 type InitError = ();
387 type Transform = CsrfMiddlewareImpl<S, Rng>;
388 type Future = Ready<Result<Self::Transform, Self::InitError>>;
389
390 fn new_transform(&self, service: S) -> Self::Future {
391 future::ready(Ok(CsrfMiddlewareImpl {
392 service,
393 inner: self.inner.clone(),
394 }))
395 }
396}
397
/// The per-service middleware created by [`CsrfMiddleware`]'s `Transform` impl.
#[doc(hidden)]
pub struct CsrfMiddlewareImpl<S, Rng> {
    // The wrapped (downstream) service.
    service: S,
    // Configuration and RNG cloned from the `CsrfMiddleware` that created this service.
    inner: Inner<Rng>,
}
403
/// Configuration and token RNG shared by [`CsrfMiddleware`] and its service instances.
///
/// NOTE(review): the derived `Clone` clones the `RefCell<Rng>`, which copies the RNG
/// *state*. `new_transform` clones this for every wrapped service, so those copies start
/// from the same state and may yield the same token sequence. Confirm that is acceptable.
#[derive(Clone, Eq, PartialEq, Debug)]
struct Inner<Rng> {
    // Source of token values. It is a `RefCell` because `Service::call` only has `&self`.
    rng: RefCell<Rng>,
    // Name of the token cookie.
    cookie_name: Rc<String>,
    // Cookie attributes applied when building the `Set-Cookie` header.
    http_only: bool,
    same_site: Option<SameSite>,
    secure: bool,
    domain: Option<String>,

    // When `false`, no tokens are generated and no cookies are set.
    csrf_enabled: bool,
    // (method, route) pairs whose responses receive a fresh token cookie. The route is a
    // resource pattern such as `/{id}` or a literal path.
    set_cookie: HashSet<(Method, String)>,
}
418
419impl<Rng: TokenRng + SeedableRng> Default for Inner<Rng> {
420 fn default() -> Self {
421 Self::with_rng(Rng::from_entropy())
422 }
423}
424
425impl<Rng: TokenRng> Inner<Rng> {
426 fn with_rng(rng: Rng) -> Self {
427 Self {
428 rng: RefCell::new(rng),
429 cookie_name: Rc::new(DEFAULT_CSRF_COOKIE_NAME.to_owned()),
430 csrf_enabled: true,
431 http_only: true,
432 same_site: Some(SameSite::Strict),
433 secure: true,
434 domain: None,
435 set_cookie: HashSet::new(),
436 }
437 }
438
439 fn contains(&self, req: &ServiceRequest) -> bool {
440 req.match_pattern().map_or_else(
441 || {
442 self.set_cookie
443 .contains(&(req.method().clone(), req.path().to_string()))
444 },
445 |p| self.set_cookie.contains(&(req.method().clone(), p)),
446 )
447 }
448}
449
450impl<S, Rng> Service<ServiceRequest> for CsrfMiddlewareImpl<S, Rng>
451where
452 S: Service<ServiceRequest, Response = ServiceResponse>,
453 Rng: TokenRng,
454{
455 type Response = ServiceResponse;
456 type Error = S::Error;
457 type Future = CsrfMiddlewareImplFuture<S>;
458
459 fn call(&self, req: ServiceRequest) -> Self::Future {
460 let cookie = if self.inner.csrf_enabled && self.inner.contains(&req) {
461 let token =
462 match self.inner.rng.borrow_mut().generate_token() {
463 Ok(token) => token,
464 Err(e) => {
465 error!("Failed to generate CSRF token, aborting request");
466 return CsrfMiddlewareImplFuture::CsrfError(req.error_response(
467 InternalError::new(e, StatusCode::INTERNAL_SERVER_ERROR),
468 ));
469 }
470 };
471
472 let cookie = {
473 let mut cookie_builder =
474 Cookie::build(self.inner.cookie_name.as_ref(), token.clone())
475 .http_only(self.inner.http_only)
476 .secure(self.inner.secure)
477 .path("/");
478
479 if let Some(same_site) = self.inner.same_site {
480 cookie_builder = cookie_builder.same_site(same_site);
481 }
482
483 if let Some(domain) = &self.inner.domain {
484 cookie_builder = cookie_builder.domain(domain);
485 }
486
487 cookie_builder.finish()
488 };
489
490 let csrf_token = CsrfToken(token);
491 req.extensions_mut().insert(csrf_token);
492
493 let header = HeaderValue::from_str(&cookie.to_string())
497 .expect("cookie to be a valid header value");
498
499 Some(header)
500 } else {
501 None
502 };
503
504 CsrfMiddlewareImplFuture::Passthrough(Passthrough {
505 cookie,
506 service: Box::pin(self.service.call(req)),
507 })
508 }
509
510 fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
511 self.service.poll_ready(ctx)
512 }
513}
514
/// Future returned by [`CsrfMiddlewareImpl`]'s `Service::call`.
#[doc(hidden)]
#[derive(Debug)]
pub enum CsrfMiddlewareImplFuture<S: Service<ServiceRequest>> {
    /// A response produced without calling the wrapped service (e.g. token generation
    /// failed). It resolves on the first poll.
    CsrfError(ServiceResponse),
    /// The wrapped service's future, plus the `Set-Cookie` value (if any) to add to its
    /// successful response.
    Passthrough(Passthrough<S::Future>),
}
523
524impl<S> Future for CsrfMiddlewareImplFuture<S>
525where
526 S: Service<ServiceRequest, Response = ServiceResponse>,
527{
528 type Output = Result<ServiceResponse, S::Error>;
529
530 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
531 match self.get_mut() {
532 Self::CsrfError(error) => {
533 let req = error.request().clone();
535 let mut new_error = ServiceResponse::new(req, HttpResponse::NoContent().finish());
536 std::mem::swap(&mut new_error, error);
537 Poll::Ready(Ok(new_error))
538 }
539 Self::Passthrough(inner) => match inner.service.as_mut().poll(cx) {
540 Poll::Ready(Ok(mut res)) => {
541 if let Some(ref cookie) = inner.cookie {
542 res.response_mut()
543 .headers_mut()
544 .insert(header::SET_COOKIE, cookie.clone());
546 }
547
548 Poll::Ready(Ok(res))
549 }
550 other => other,
551 },
552 }
553 }
554}
555
/// State of a request that was forwarded to the wrapped service.
#[doc(hidden)]
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
pub struct Passthrough<Fut> {
    // `Set-Cookie` value to add to the wrapped service's successful response, present only
    // when a token was issued for this request.
    cookie: Option<HeaderValue>,
    // The wrapped service's in-flight response future, boxed and pinned so it can be polled.
    service: Pin<Box<Fut>>,
}
562
#[cfg(test)]
mod tests {
    use crate::extractor::{Csrf, CsrfHeader};

    use super::*;

    use actix_web::http::StatusCode;
    use actix_web::test::{self, TestRequest};
    use actix_web::{post, web, App, HttpResponse, Responder};
    use rand::rngs::StdRng;

    /// Returns the token value carried by the single `Set-Cookie` header of `resp`.
    fn get_token_from_resp(resp: &ServiceResponse) -> String {
        let cookie = get_cookie_from_resp(resp);
        // Split on the first `=` only. Everything after it is the value, even when the
        // value itself contains `=` (e.g. base64 padding).
        let (_name, token) = cookie
            .split_once('=')
            .expect("cookie to be a `name=value` pair");
        String::from(token)
    }

    /// Returns the `name=value` pair of the single `Set-Cookie` header of `resp`, with the
    /// attributes after the first `;` dropped. Panics unless exactly one such header exists.
    fn get_cookie_from_resp(resp: &ServiceResponse) -> String {
        let cookie_header: Vec<_> = resp
            .headers()
            .iter()
            .filter(|(header_name, _)| header_name.as_str() == "set-cookie")
            .map(|(_, value)| value.to_str().expect("header to be valid string"))
            .map(|v| v.split(';').next().expect("split to work"))
            .collect();
        assert_eq!(1, cookie_header.len());
        String::from(*cookie_header.first().expect("header to have cookie"))
    }

    /// Returns the value of the `Domain` attribute from the `Set-Cookie` header(s) of `resp`.
    fn get_cookie_domain_from_resp(resp: &ServiceResponse) -> String {
        let cookie_header: Vec<_> = resp
            .headers()
            .iter()
            .filter(|(header_name, _)| header_name.as_str() == "set-cookie")
            .map(|(_, value)| value.to_str().expect("header to be valid string"))
            .flat_map(|v| v.split(';'))
            .collect();
        String::from(
            cookie_header
                .into_iter()
                .find_map(|s| s.trim().strip_prefix("Domain="))
                .expect("header to have cookie"),
        )
    }

    #[tokio::test]
    async fn attaches_token() {
        let mut srv = test::init_service(
            App::new()
                .wrap(CsrfMiddleware::<StdRng>::new().set_cookie(Method::GET, "/"))
                .service(web::resource("/").to(|| HttpResponse::Ok())),
        )
        .await;
        let resp = test::call_service(&mut srv, TestRequest::with_uri("/").to_request()).await;
        assert_eq!(resp.status(), StatusCode::OK);

        assert!(get_cookie_from_resp(&resp).contains(DEFAULT_CSRF_COOKIE_NAME));
    }

    #[tokio::test]
    async fn post_request_rejected_without_header() {
        #[post("/")]
        async fn test_route(_: Csrf<CsrfHeader>) -> impl Responder {
            HttpResponse::Ok()
        }

        let mut srv = test::init_service(
            App::new()
                .wrap(CsrfMiddleware::<StdRng>::new())
                .service(test_route),
        )
        .await;

        let resp = test::call_service(&mut srv, TestRequest::post().uri("/").to_request()).await;
        assert_eq!(resp.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[tokio::test]
    async fn double_submit_correct_token() {
        let mut srv = test::init_service(
            App::new()
                .wrap(CsrfMiddleware::<StdRng>::new().set_cookie(Method::GET, "/"))
                .service(
                    web::resource("/")
                        .route(web::get().to(|| HttpResponse::Ok()))
                        .route(web::post().to(|| HttpResponse::Ok())),
                ),
        )
        .await;

        let resp = test::call_service(&mut srv, TestRequest::with_uri("/").to_request()).await;

        let token = get_token_from_resp(&resp);
        let cookie = get_cookie_from_resp(&resp);

        let req = TestRequest::post()
            .uri("/")
            .insert_header(("Cookie", cookie))
            .insert_header((DEFAULT_CSRF_TOKEN_NAME, token))
            .to_request();

        let resp = test::call_service(&mut srv, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn domain_attribute_is_set() {
        let mut srv = test::init_service(
            App::new()
                .wrap(
                    CsrfMiddleware::<StdRng>::new()
                        .set_cookie(Method::GET, "/")
                        .domain("example.com"),
                )
                .service(web::resource("/").to(|| HttpResponse::Ok())),
        )
        .await;
        let resp = test::call_service(&mut srv, TestRequest::with_uri("/").to_request()).await;
        assert_eq!(resp.status(), StatusCode::OK);

        assert_eq!(get_cookie_domain_from_resp(&resp), "example.com");
    }

    #[tokio::test]
    async fn path_info_is_set() {
        let mut srv = test::init_service(
            App::new()
                .wrap(CsrfMiddleware::<StdRng>::new().set_cookie(Method::GET, "/{id}"))
                .service(
                    web::resource("/{id}")
                        .route(web::get().to(|| HttpResponse::Ok()))
                        .route(web::post().to(|| HttpResponse::Ok())),
                ),
        )
        .await;

        let resp = test::call_service(&mut srv, TestRequest::with_uri("/1").to_request()).await;

        let token = get_token_from_resp(&resp);
        let cookie = get_cookie_from_resp(&resp);

        let req = TestRequest::post()
            .uri("/1")
            .insert_header(("Cookie", cookie))
            .insert_header((DEFAULT_CSRF_TOKEN_NAME, token))
            .to_request();

        let resp = test::call_service(&mut srv, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }
}