crissy 0.1.1

CSRF protection middleware for Axum
Documentation
#![cfg_attr(docsrs, feature(doc_cfg))]

//! # crissy
//!
//! `crissy` is a middleware for [axum](https://crates.io/crates/axum) that protects
//! (browser-facing) web resources from cross-site request forgery (CSRF) attacks.
//!
//! ## Usage
//!
//! In short:
//!
//! - Pick a [`middleware`] and layer it in
//! - Extract [`CsrfToken`]
//! - Add [`CsrfToken::expected_csrf_token`] to your form
//! - Call [`CsrfToken::validate`] in your form handler
//!
//! For example:
//!
//! ```
//! use axum::{
//!     Form, Router,
//!     response::{Html, IntoResponse},
//!     routing::get,
//! };
//! use crissy::CsrfToken;
//! use serde::Deserialize;
//! use tokio::net::TcpListener;
//!
//! # #[cfg(feature = "cookie")]
//! let app = Router::<()>::new()
//!     .route("/", get(route_index).post(route_post))
//!     .layer(axum::middleware::from_fn(crissy::middleware::cookie));
//! // Run `app` as usual
//!
//! async fn route_index(csrf: CsrfToken) -> impl IntoResponse {
//!     Html(format!(
//!         r#"<form method="POST">
//!     <input type="hidden" name="csrf_token" value="{csrf}"/>
//!     <button type="submit">Submit</button>
//! </form>"#,
//!         csrf = csrf.expected_csrf_token,
//!     ))
//! }
//!
//! #[derive(Deserialize)]
//! struct Body {
//!     csrf_token: String,
//! }
//! async fn route_post(csrf: CsrfToken, body: Form<Body>) -> Result<impl IntoResponse, crissy::Error> {
//!     csrf.validate(&body.csrf_token)?;
//!     Ok("validation successful!")
//! }
//! ```
//!
//! ## Feature flags
//!
//! crissy supports the following feature flags:
//!
//! | Name        | Default | Description                |
//! |-------------|---------|----------------------------|
//! | `cookie`    | Yes     | [`middleware::cookie`]     |
//! | `client_ip` | No      | [`middleware::client_ip`]  |
//! | `full`      | No      | Enables all feature flags. |

#[cfg(feature = "client_ip")]
mod client_address;

use axum::{extract::FromRequestParts, http::StatusCode, response::IntoResponse};
use snafu::{OptionExt as _, Snafu};

#[derive(Debug, Snafu)]
pub enum Error {
    /// invalid CSRF token
    InvalidCsrf,
    /// no crissy middleware enabled
    NoMiddleware,
}

impl Error {
    pub fn status_code(&self) -> StatusCode {
        match self {
            Self::InvalidCsrf => StatusCode::BAD_REQUEST,
            Self::NoMiddleware => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
}
impl IntoResponse for Error {
    /// Renders the error as a plain-text body carrying [`Error::status_code`].
    fn into_response(self) -> axum::response::Response {
        let status = self.status_code();
        let message = self.to_string();
        (status, message).into_response()
    }
}

/// Users must use exactly one middleware from this module, which is responsible for
/// generating and maintaining the CSRF tokens.
///
/// Typically, you should prefer [`middleware::cookie`].
pub mod middleware {
    #[cfg(feature = "client_ip")]
    use std::{net::IpAddr, sync::Arc};

    #[cfg(feature = "client_ip")]
    use axum::extract::State;
    #[cfg(any(feature = "cookie", feature = "client_ip"))]
    use axum::{extract::Request, middleware::Next, response::IntoResponse};
    #[cfg(feature = "cookie")]
    use axum_extra::extract::{CookieJar, cookie::Cookie};

    #[cfg(any(feature = "cookie", feature = "client_ip"))]
    use crate::CsrfToken;
    #[cfg(feature = "client_ip")]
    use crate::client_address::ClientAddress;

    /// Configuration required by [`client_ip`], to be implemented over your [`State`].
    #[cfg(feature = "client_ip")]
    pub trait ClientIpConfig {
        /// Pepper used to prevent attackers from forging their own CSRF tokens.
        ///
        /// Should be configured to use a relatively static but high-quality random value.
        ///
        /// Rotating this key will invalidate all existing CSRF tokens.
        fn csrf_secret_key(&self) -> &[u8];

        /// Checks whether a client is a trusted reverse proxy that should be allowed to
        /// impersonate another IP address by sending an
        /// [`X-Forwarded-For`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/X-Forwarded-For)
        /// header.
        ///
        /// `crissy` allows nested forwarding; the whole chain will be followed up until the
        /// first untrusted forwarder.
        ///
        /// If this check is too narrow then all clients will be given the same CSRF token
        /// (based on your reverse proxy's IP address).
        /// If it is too broad then untrusted clients can forge arbitrary tokens simply
        /// by sending their own `X-Forwarded-For` header.
        ///
        /// # Reference implementations
        ///
        /// If there is no reverse proxy then this function should be configured to always return `false`.
        ///
        /// If the only reverse proxy is running on localhost then it should return [`addr.is_loopback()`](`IpAddr::is_loopback`).
        ///
        /// # Caveats
        ///
        /// Keep in mind that hostnames (including `localhost`) may often resolve to *both* IPv4 and IPv6
        /// addresses.
        ///
        /// If so, this function *must* be consistent across both.
        fn is_trusted_forwarder(&self, addr: IpAddr) -> bool;
    }
    #[cfg(feature = "client_ip")]
    impl<S: ClientIpConfig> ClientIpConfig for Arc<S> {
        fn csrf_secret_key(&self) -> &[u8] {
            S::csrf_secret_key(self)
        }
        fn is_trusted_forwarder(&self, addr: IpAddr) -> bool {
            S::is_trusted_forwarder(self, addr)
        }
    }

    /// Generates CSRF tokens based on the client's IP address.
    ///
    /// This can be used to enable CSRF protection in contexts where cookies are not available,
    /// for whatever reason. If cookies *are* available, prefer [`cookie`] instead.
    ///
    /// Note that this does not (by itself) validate the tokens!
    /// That must be done by calling [`CsrfToken::validate`].
    ///
    /// # Requirements
    ///
    /// This implementation requires that the request supports [`axum::extract::ConnectInfo`],
    /// such as by using [`axum::Router::into_make_service_with_connect_info`].
    ///
    /// # Security Caveats
    ///
    /// Before using this strategy, READ THIS WHOLE SECTION.
    /// This method is easy to misconfigure, rendering it effectively useless.
    ///
    /// ## Peppering
    ///
    /// IP addresses are hashed together with a global secret key (a "pepper"), in order to
    /// prevent third parties from generating their own tokens.
    ///
    /// See [`ClientIpConfig::csrf_secret_key`].
    ///
    /// ## Reverse Proxies
    ///
    /// This strategy requires `crissy` to know the client's "true" IP address,
    /// which is often obscured by reverse proxies.
    ///
    /// Hence, any reverse proxy *must* be configured to send the
    /// [`X-Forwarded-For`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/X-Forwarded-For)
    /// header, and `crissy` must be configured to trust it by implementing [`ClientIpConfig::is_trusted_forwarder`].
    ///
    /// Otherwise, all clients will unconditionally end up being given the same CSRF token
    /// (for the reverse proxy's IP address).
    ///
    /// ## Network Address Translation/IP Reuse
    ///
    /// This strategy can only know about IP addresses, which these days often
    /// cover whole households, and [sometimes](https://en.wikipedia.org/wiki/Carrier-grade_NAT)
    /// even unrelated blocks of users.
    ///
    /// Users with the same IP address will be given the same CSRF token,
    /// and hence be able to mount CSRF attacks against each other.
    ///
    /// # SemVer
    ///
    /// Note that the exact signature is *unstable*, and may change without a SemVer-breaking
    /// version bump!
    /// The only thing that is safe to rely on is that it can be used with [`axum::middleware::from_fn_with_state`].
    #[cfg(feature = "client_ip")]
    pub async fn client_ip<S: ClientIpConfig>(
        state: State<S>,
        client_addr: ClientAddress,
        mut request: Request,
        next: Next,
    ) -> impl IntoResponse {
        use sha2::{Digest, Sha256};

        // token = hex(SHA-256("csrftoken " || secret key || " " || client address))
        let mut hash = Sha256::new_with_prefix("csrftoken ");
        hash.update(state.csrf_secret_key());
        hash.update(" ");
        hash.update(client_addr.address.to_string());
        request.extensions_mut().insert(CsrfToken {
            expected_csrf_token: format!("{:x}", hash.finalize()),
        });
        next.run(request).await
    }

    /// Generates random CSRF tokens and stores them in user cookies.
    ///
    /// This method should be preferred whenever cookies are available.
    ///
    /// If the request already carries a non-empty `CRISSY_CSRF_TOKEN` cookie, its value is
    /// reused as the expected token. Otherwise (no cookie, or an empty one) a fresh random
    /// token is generated and sent back as a permanent, `HttpOnly`, `Secure` cookie.
    ///
    /// Note that this does not (by itself) validate the tokens!
    /// That must be done by calling [`CsrfToken::validate`].
    ///
    /// # SemVer
    ///
    /// Note that the exact signature is *unstable*, and may change without a SemVer-breaking
    /// version bump!
    /// The only thing that is safe to rely on is that it can be used with [`axum::middleware::from_fn`].
    #[cfg(feature = "cookie")]
    pub async fn cookie(
        mut cookie_jar: CookieJar,
        mut request: Request,
        next: Next,
    ) -> impl IntoResponse {
        use crate::random_csrf_token;

        const COOKIE_NAME: &str = "CRISSY_CSRF_TOKEN";
        // An empty cookie value would make the expected token "", which an empty form
        // field would match; treat it like a missing cookie and issue a fresh token.
        let existing = cookie_jar
            .get(COOKIE_NAME)
            .map(|cookie| cookie.value())
            .filter(|value| !value.is_empty())
            .map(str::to_owned);
        let csrf_token = match existing {
            Some(token) => token,
            None => {
                let token = random_csrf_token();
                cookie_jar = cookie_jar.add(
                    Cookie::build((COOKIE_NAME, token.clone()))
                        .permanent()
                        .path("/")
                        .http_only(true)
                        .secure(true)
                        .build(),
                );
                token
            }
        };
        request.extensions_mut().insert(CsrfToken {
            expected_csrf_token: csrf_token,
        });
        // Returning the jar alongside the response emits any newly added cookie.
        (cookie_jar, next.run(request).await)
    }
}

/// The client's active CSRF token.
///
/// You *must* use a [`middleware`] that is responsible for injecting the `CsrfToken` into the request.
///
/// The [`Debug`](std::fmt::Debug) output redacts the token value, so the token cannot leak
/// into logs through `{:?}`.
#[derive(Clone)]
pub struct CsrfToken {
    /// The token the client is expected to submit back, e.g. in a hidden form field.
    ///
    /// Check a submitted value with [`CsrfToken::validate`] rather than comparing it manually.
    pub expected_csrf_token: String,
}

impl std::fmt::Debug for CsrfToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The token is a per-client secret; never print it.
        f.debug_struct("CsrfToken")
            .field("expected_csrf_token", &"<redacted>")
            .finish()
    }
}

impl<S> FromRequestParts<S> for CsrfToken {
    type Rejection = Error;

    /// Reads the [`CsrfToken`] that a [`middleware`] stored in the request extensions.
    ///
    /// Fails with [`Error::NoMiddleware`] when no middleware ran for this request.
    fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        _state: &S,
    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
        // The lookup is synchronous; hand back an already-completed future so the
        // returned future does not borrow `_state`.
        let injected = parts.extensions.get::<Self>().cloned();
        std::future::ready(injected.context(NoMiddlewareSnafu))
    }
}
impl CsrfToken {
    /// Validates that the submitted CSRF token matches the client's expected CSRF token.
    ///
    /// The comparison does not stop at the first differing byte, so response timing does
    /// not reveal how much of a guessed token was correct.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidCsrf`] if `form_csrf_token` does not match the expected
    /// token, or if the expected token is empty (an empty token is never valid).
    pub fn validate(&self, form_csrf_token: &str) -> Result<(), Error> {
        let expected = self.expected_csrf_token.as_bytes();
        if expected.is_empty() || !constant_time_eq(expected, form_csrf_token.as_bytes()) {
            // Only the submitted token is logged: the expected token is a per-client
            // secret and must not end up in log files.
            tracing::debug!(csrf.form = form_csrf_token, "invalid CSRF token");
            return InvalidCsrfSnafu.fail();
        }
        Ok(())
    }
}

/// Compares two byte strings without short-circuiting on the first mismatch.
///
/// Differing lengths are rejected up front; token lengths are not secret.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // Accumulate every byte difference; `black_box` discourages the optimizer from
    // rewriting the fold back into an early-exit loop.
    let diff = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (x, y)| acc | std::hint::black_box(x ^ y));
    diff == 0
}

/// Generates a fresh random CSRF token from a [`rand::random`] `u128`, rendered as
/// exactly 32 lowercase hex digits.
///
/// Zero-padding keeps every token the same length. Plain `{:x}` drops leading zeros, so
/// roughly one token in sixteen would otherwise come out shorter.
#[cfg(feature = "cookie")]
fn random_csrf_token() -> String {
    format!("{:032x}", rand::random::<u128>())
}