crissy 0.1.1

CSRF protection middleware for Axum
Documentation
use std::net::{AddrParseError, IpAddr, SocketAddr};

use axum::{
    extract::{ConnectInfo, FromRequestParts, rejection::ExtensionRejection},
    http::{self, StatusCode, request},
    response::IntoResponse,
};
use snafu::{ResultExt as _, Snafu};

use crate::middleware::ClientIpConfig;

/// Reasons the [`ClientAddress`] extractor can reject a request.
#[derive(Debug, Snafu)]
pub enum Rejection {
    /// The `ConnectInfo<SocketAddr>` request extension could not be extracted.
    #[snafu(display("could not get connection info"))]
    GetConnectInfo { source: ExtensionRejection },

    /// An `X-Forwarded-For` header value could not be converted to a string
    /// (`HeaderValue::to_str` failed).
    #[snafu(display("X-Forwarded-For header has an invalid value"))]
    InvalidXForwardedFor { source: http::header::ToStrError },
    /// An `X-Forwarded-For` entry could not be parsed as an IP address.
    #[snafu(display("X-Forwarded-For header contains an invalid address"))]
    InvalidAddressInXForwardedFor { source: AddrParseError },
}

/// Converts a [`Rejection`] into an HTTP response.
///
/// A connection-info failure defers entirely to the inner
/// [`ExtensionRejection`]'s own response. Malformed `X-Forwarded-For` data
/// yields `400 Bad Request` with a plain-text body of the form
/// `"<error>: <cause>"`.
impl IntoResponse for Rejection {
    fn into_response(self) -> axum::response::Response {
        let cause = match self {
            // Reuse the extension rejection's own status and body unchanged.
            Self::GetConnectInfo { source } => return source.into_response(),
            Self::InvalidXForwardedFor { ref source } => source.to_string(),
            Self::InvalidAddressInXForwardedFor { ref source } => source.to_string(),
        };
        (StatusCode::BAD_REQUEST, format!("{self}: {cause}")).into_response()
    }
}

/// The IP address a request is attributed to, as resolved by the
/// `FromRequestParts` extractor implemented for this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ClientAddress {
    /// The resolved client IP address.
    pub address: IpAddr,
}

impl<S: Send + Sync + ClientIpConfig> FromRequestParts<S> for ClientAddress {
    type Rejection = Rejection;

    /// Resolves the client IP address of a request.
    ///
    /// The starting point is the peer address of the connection, taken from
    /// `ConnectInfo<SocketAddr>`. The `X-Forwarded-For` entries are then walked
    /// from the rightmost (most recently appended) entry towards the left. While
    /// the address reached so far satisfies `state.is_trusted_forwarder(..)`,
    /// the next entry to the left replaces it. The walk stops at the first
    /// untrusted address or when the entries run out.
    ///
    /// All `X-Forwarded-For` header lines are considered, in the order they
    /// appear, as one comma-separated list. Empty list elements (e.g. from a
    /// trailing comma or an empty header value) are ignored.
    ///
    /// # Errors
    ///
    /// - [`Rejection::GetConnectInfo`] if the `ConnectInfo<SocketAddr>`
    ///   extension cannot be extracted (e.g. the app was not served with
    ///   connect info).
    /// - [`Rejection::InvalidXForwardedFor`] if any `X-Forwarded-For` value
    ///   cannot be converted to a string (`HeaderValue::to_str`).
    /// - [`Rejection::InvalidAddressInXForwardedFor`] if an entry reached during
    ///   the walk does not parse as an IP address.
    async fn from_request_parts(
        parts: &mut request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let connect_info = ConnectInfo::<SocketAddr>::from_request_parts(parts, state)
            .await
            .context(GetConnectInfoSnafu)?;
        let mut last_addr = connect_info.0.ip();

        // A request may carry several `X-Forwarded-For` lines, e.g. when a proxy
        // adds its own line instead of appending to the client's. Reading only the
        // first line would let a client-supplied value shadow the proxy's, so every
        // line is read, in order, and treated as one list (RFC 9110 §5.3).
        let forwarded_values = parts
            .headers
            .get_all("X-Forwarded-For")
            .iter()
            .map(|value| value.to_str())
            .collect::<Result<Vec<&str>, _>>()
            .context(InvalidXForwardedForSnafu)?;

        // Each entry was appended by the hop to its right. Walk right-to-left and
        // only believe an entry when the hop that added it is a trusted forwarder.
        let hops = forwarded_values
            .into_iter()
            .flat_map(|value| value.split(','))
            .map(str::trim)
            // RFC 9110 §5.6.1: recipients must ignore empty list elements.
            .filter(|hop| !hop.is_empty())
            .rev();
        for hop in hops {
            if !state.is_trusted_forwarder(last_addr) {
                break;
            }
            last_addr = hop.parse().context(InvalidAddressInXForwardedForSnafu)?;
        }
        Ok(ClientAddress { address: last_addr })
    }
}