securitydept-token-set-context 0.2.0-beta.2

Token Set Context of SecurityDept, a layered authentication and authorization toolkit built as reusable Rust crates.
Documentation
use axum::{
    body::Body,
    http::{Request, Response},
};
use axum_reverse_proxy::ReverseProxy;
use http::header::AUTHORIZATION;
use serde::{Deserialize, Serialize};
use typed_builder::TypedBuilder;
use url::Url;

use super::{
    super::{
        DEFAULT_PROPAGATION_HEADER_NAME, PropagatedBearer, PropagationRequestTarget,
        TokenPropagator,
    },
    error::PropagationForwarderError,
};

/// Returns the proxy path used when none is configured: the root path `/`.
fn default_proxy_path() -> String {
    String::from("/")
}

/// Configuration for [`AxumReverseProxyPropagationForwarder`].
///
/// It can be deserialized with serde or assembled through the generated
/// `builder()`; both paths fall back to the same default `proxy_path`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, TypedBuilder)]
pub struct AxumReverseProxyPropagationForwarderConfig {
    /// Path passed as the first argument to `ReverseProxy::new` when forwarding.
    ///
    /// Defaults to `/` when omitted (both for serde and for the builder).
    /// `validate` rejects values that do not start with `/`.
    #[serde(default = "default_proxy_path")]
    #[builder(default = default_proxy_path())]
    pub proxy_path: String,
}

impl Default for AxumReverseProxyPropagationForwarderConfig {
    /// Builds a config whose `proxy_path` comes from `default_proxy_path`,
    /// the same source used by the serde and builder defaults.
    fn default() -> Self {
        let proxy_path = default_proxy_path();
        Self { proxy_path }
    }
}

impl AxumReverseProxyPropagationForwarderConfig {
    /// Checks that this configuration can be used to build a forwarder.
    ///
    /// # Errors
    ///
    /// Returns [`PropagationForwarderError::Config`] when `proxy_path` does not
    /// start with `/`. An empty `proxy_path` is rejected by the same check.
    pub fn validate(&self) -> Result<(), PropagationForwarderError> {
        // `starts_with('/')` is already false for an empty string, so a separate
        // `is_empty()` test would be redundant.
        if !self.proxy_path.starts_with('/') {
            return Err(PropagationForwarderError::Config {
                message: "proxy_path must start with `/`".to_string(),
            });
        }

        Ok(())
    }
}

impl super::PropagationForwarderConfigSource for AxumReverseProxyPropagationForwarderConfig {
    type Error = PropagationForwarderError;
    type Forwarder = super::AxumReverseProxyPropagationForwarder;

    /// Creates a forwarder from an owned copy of this configuration; the
    /// forwarder constructor runs `validate` on it.
    fn build_forwarder(&self) -> Result<Self::Forwarder, Self::Error> {
        let config = self.clone();
        super::AxumReverseProxyPropagationForwarder::new(config)
    }
}

/// Forwards requests through `axum_reverse_proxy` after replacing the
/// propagation header with an `Authorization` header.
///
/// Construct it with [`AxumReverseProxyPropagationForwarder::new`], which
/// validates the configuration.
#[derive(Clone, Debug)]
pub struct AxumReverseProxyPropagationForwarder {
    /// Configuration supplied at construction time.
    config: AxumReverseProxyPropagationForwarderConfig,
}

impl AxumReverseProxyPropagationForwarder {
    /// Creates a forwarder once `config` passes
    /// [`AxumReverseProxyPropagationForwarderConfig::validate`].
    ///
    /// # Errors
    ///
    /// Returns whatever error `validate` reports for an invalid `config`.
    pub fn new(
        config: AxumReverseProxyPropagationForwarderConfig,
    ) -> Result<Self, PropagationForwarderError> {
        config.validate().map(|()| Self { config })
    }

    /// Borrows the configuration this forwarder was created with.
    pub fn config(&self) -> &AxumReverseProxyPropagationForwarderConfig {
        &self.config
    }
}

/// Rewrites `request`'s headers before it is proxied upstream.
///
/// All values of the `DEFAULT_PROPAGATION_HEADER_NAME` header are removed
/// (presumably so the internal propagation header is not sent upstream; confirm),
/// and `Authorization` is set to `authorization_header_value`, replacing any
/// values it already had.
fn prepare_forward_request(
    authorization_header_value: http::HeaderValue,
    request: &mut Request<Body>,
) {
    let headers = request.headers_mut();
    headers.remove(DEFAULT_PROPAGATION_HEADER_NAME);
    headers.insert(AUTHORIZATION, authorization_header_value);
}

impl super::PropagationForwarder for AxumReverseProxyPropagationForwarder {
    type Body = Body;

    /// Proxies `request` to the origin resolved for `target`, carrying the
    /// `Authorization` header value that `propagator` derives from `bearer`.
    ///
    /// All fallible steps run first:
    /// 1. Obtain the header value.
    /// 2. Resolve the origin.
    /// 3. Parse the origin as a URL.
    ///
    /// Only then are the request headers rewritten (see `prepare_forward_request`)
    /// and the request sent through `axum_reverse_proxy` with the configured
    /// `proxy_path`.
    ///
    /// # Errors
    ///
    /// - [`PropagationForwarderError::TokenPropagator`] if `propagator` cannot
    ///   produce the header value or resolve the target origin.
    /// - [`PropagationForwarderError::InvalidOrigin`] if the resolved origin does
    ///   not parse as a URL.
    async fn forward(
        &self,
        propagator: &TokenPropagator,
        bearer: &PropagatedBearer<'_>,
        target: &PropagationRequestTarget,
        mut request: Request<Body>,
    ) -> Result<Response<Body>, PropagationForwarderError> {
        let authorization_header_value = propagator
            .authorization_header_value(bearer, target)
            .map_err(|source| PropagationForwarderError::TokenPropagator { source })?;
        let origin = propagator
            .resolve_target_origin(target)
            .map_err(|source| PropagationForwarderError::TokenPropagator { source })?;
        // Finish every fallible step before the request is modified.
        let origin_url = Url::parse(&origin)
            .map_err(|source| PropagationForwarderError::InvalidOrigin { source })?;

        prepare_forward_request(authorization_header_value, &mut request);

        // The origin is resolved per target, so a proxy is built for each call.
        let proxy = ReverseProxy::new(self.config.proxy_path.clone(), origin_url.to_string());

        // `proxy_request` reports errors as `Infallible`. Matching on the empty
        // error proves the `Err` arm unreachable at compile time. If the library
        // ever introduces a real error type, this stops compiling instead of
        // hiding a runtime panic behind `expect`.
        let response = proxy
            .proxy_request(request)
            .await
            .unwrap_or_else(|never| match never {});

        Ok(response)
    }
}