axum-gate 1.1.0

Flexible authentication and authorization for Axum with JWT cookies or bearer tokens, optional OAuth2, and role/group/permission RBAC. Suitable for single-node and distributed systems.
Documentation
use crate::accounts::Account;
use crate::authz::{AccessHierarchy, AccessPolicy, AuthorizationService};
use crate::codecs::Codec;
use crate::codecs::jwt::{JwtClaims, JwtValidationResult, JwtValidationService, RegisteredClaims};

use std::convert::Infallible;
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
#[cfg(all(feature = "audit-logging", feature = "prometheus"))]
use std::time::Instant;

#[cfg(feature = "audit-logging")]
use crate::audit;
#[cfg(all(feature = "audit-logging", feature = "prometheus"))]
use crate::audit::prometheus_metrics::{JwtValidationOutcome, observe_jwt_validation_latency};
#[cfg(all(feature = "audit-logging", feature = "prometheus"))]
use crate::audit::{AuthzOutcome, observe_authz_latency};
use crate::cookie_template::CookieTemplate;
use axum::{body::Body, extract::Request, http::Response};
use axum_extra::extract::cookie::{Cookie, CookieJar};
use http::StatusCode;
use tower::Service;
use tracing::{debug, trace, warn};

/// Cookie-backed JWT gate service.
///
/// Behavior:
/// - Strict mode (default): validates the JWT from the configured cookie,
///   enforces the configured `AccessPolicy`, and on success inserts
///   `Account<R, G>` and `RegisteredClaims` into request extensions. On failure,
///   responds with 401 Unauthorized.
/// - Optional mode (`allow_anonymous_with_optional_user()` from the builder):
///   never blocks. It inserts only:
///
///   - `Option<Account<R, G>>`
///   - `Option<RegisteredClaims>`
///
///   They are `Some(..)` when a valid JWT cookie is present and `None` otherwise.
///   No concrete types are inserted in this mode. Authorization policy is not
///   evaluated; handlers must enforce any required checks explicitly.
///
/// The cookie to read is identified by the name stored in the provided
/// `CookieTemplate`. The issuer and JWT validation are configured via the
/// builder that constructs this service.
#[derive(Debug, Clone)]
pub struct CookieGateService<C, R, G, S>
where
    C: Codec,
    R: AccessHierarchy + Eq + std::fmt::Display,
    G: Eq + Clone,
{
    /// The wrapped service; requests that pass the gate are forwarded to it.
    inner: S,
    /// Evaluates the configured access policy against the account decoded
    /// from the JWT (strict mode only).
    authorization_service: AuthorizationService<R, G>,
    /// Validates the JWT carried in the auth cookie (both modes).
    jwt_validation_service: JwtValidationService<C>,
    /// Supplies the name of the cookie that carries the JWT.
    cookie_template: CookieTemplate,
    /// If `true`, the service will ALWAYS forward the request and install
    /// `Option<Account<R,G>>` and `Option<RegisteredClaims>` extensions:
    /// - `Some(..)` when a valid JWT cookie is present (policy IS NOT evaluated)
    /// - `None` when missing / invalid token
    ///
    /// In this mode no 401 response is generated by the gate; handler code
    /// must perform any required access checks.
    install_optional_extensions: bool,
}

impl<C, R, G, S> CookieGateService<C, R, G, S>
where
    C: Codec,
    R: AccessHierarchy + Eq + std::fmt::Display,
    G: Eq + Clone,
{
    /// Builds a gate in strict mode around `inner`.
    ///
    /// * `issuer` and `codec` configure the JWT validation service.
    /// * `policy` is handed to the authorization service that decides whether
    ///   an authenticated account may pass.
    /// * `cookie_template` identifies the cookie that carries the JWT.
    pub fn new(
        inner: S,
        issuer: &str,
        policy: AccessPolicy<R, G>,
        codec: Arc<C>,
        cookie_template: CookieTemplate,
    ) -> Self {
        let authorization_service = AuthorizationService::new(policy);
        let jwt_validation_service = JwtValidationService::new(codec, issuer);

        Self {
            inner,
            authorization_service,
            jwt_validation_service,
            cookie_template,
            install_optional_extensions: false,
        }
    }

    /// Builds a gate in "optional extensions" mode around `inner`.
    ///
    /// Public for advanced usage; normally this is constructed through the
    /// `Gate` builder when optional mode is requested.
    pub fn new_with_optional_extensions(
        inner: S,
        issuer: &str,
        codec: Arc<C>,
        cookie_template: CookieTemplate,
    ) -> Self {
        // Optional mode never consults the policy, but the struct still needs
        // one, so a deny-all placeholder is stored.
        let mut gate = Self::new(
            inner,
            issuer,
            AccessPolicy::deny_all(),
            codec,
            cookie_template,
        );
        gate.install_optional_extensions = true;
        gate
    }
}

impl<C, R, G, S> CookieGateService<C, R, G, S>
where
    C: Codec,
    R: AccessHierarchy + Eq + std::fmt::Display,
    G: Eq + Clone,
{
    /// Looks up the axum-gate auth cookie among the request's headers.
    ///
    /// Returns a copy of the cookie whose name matches the configured cookie
    /// template, or `None` when the request carries no such cookie.
    pub fn auth_cookie(&self, req: &Request<Body>) -> Option<Cookie<'_>> {
        let wanted_name = self.cookie_template.cookie_name_ref();
        let jar = CookieJar::from_headers(req.headers());
        jar.get(wanted_name).cloned()
    }

    /// Builds the `401 Unauthorized` response returned when the gate rejects
    /// a request. The body is the plain text `Unauthorized`.
    fn unauthorized() -> Response<Body> {
        let mut response = Response::new(Body::from("Unauthorized"));
        *response.status_mut() = StatusCode::UNAUTHORIZED;
        response
    }
}

impl<C, R, G, S> Service<Request<Body>> for CookieGateService<C, R, G, S>
where
    S: Service<Request<Body>, Response = Response<Body>, Error = Infallible> + Send + 'static,
    S::Future: Send + 'static,
    Account<R, G>: Clone,
    C: Codec<Payload = JwtClaims<Account<R, G>>>,
    R: AccessHierarchy + Eq + std::fmt::Display + Sync + Send + 'static,
    G: Eq + Clone + Sync + Send + 'static,
{
    type Response = Response<Body>;
    type Error = Infallible;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
        let unauthorized_future = Box::pin(async move { Ok(Self::unauthorized()) });

        #[cfg(feature = "audit-logging")]
        let _audit_request_span =
            audit::request_span(req.method().as_str(), req.uri().path(), None);
        #[cfg(feature = "audit-logging")]
        let _audit_request_enter = _audit_request_span.enter();

        // OPTIONAL MODE: Always forward; install Option<Account> / Option<RegisteredClaims>
        if self.install_optional_extensions {
            let mut opt_account: Option<Account<R, G>> = None;
            let mut opt_reg_claims: Option<RegisteredClaims> = None;

            if let Some(auth_cookie) = self.auth_cookie(&req) {
                trace!("axum-gate (optional) cookie: {auth_cookie:#?}");
                let cookie_value = auth_cookie.value_trimmed();
                if let JwtValidationResult::Valid(jwt) =
                    self.jwt_validation_service.validate_token(cookie_value)
                {
                    // Valid JWT present; optional mode will only install Option<...> extensions

                    opt_account = Some(jwt.custom_claims.clone());
                    opt_reg_claims = Some(jwt.registered_claims.clone());
                } else {
                    debug!("Optional mode: invalid or mismatched JWT, installing None extensions");
                }
            } else {
                trace!("Optional mode: no auth cookie present; installing None extensions");
            }

            // Always insert the Option variants
            req.extensions_mut().insert(opt_account);
            req.extensions_mut().insert(opt_reg_claims);

            let inner = self.inner.call(req);
            return Box::pin(inner);
        }

        if self.authorization_service.policy_denies_all_access() {
            debug!("Denying access because roles, groups or permissions are empty.");
            #[cfg(feature = "audit-logging")]
            {
                audit::denied(None, "policy_denies_all");
            }
            return unauthorized_future;
        }

        let Some(auth_cookie) = self.auth_cookie(&req) else {
            #[cfg(feature = "audit-logging")]
            {
                audit::denied(None, "missing_cookie");
            }
            return unauthorized_future;
        };
        trace!("axum-gate cookie: {auth_cookie:#?}");

        let cookie_value = auth_cookie.value_trimmed();
        #[cfg(all(feature = "audit-logging", feature = "prometheus"))]
        let jwt_validation_start = Instant::now();
        let jwt = match self.jwt_validation_service.validate_token(cookie_value) {
            JwtValidationResult::Valid(jwt) => {
                #[cfg(all(feature = "audit-logging", feature = "prometheus"))]
                observe_jwt_validation_latency(jwt_validation_start, JwtValidationOutcome::Valid);
                jwt
            }
            JwtValidationResult::InvalidToken => {
                #[cfg(all(feature = "audit-logging", feature = "prometheus"))]
                observe_jwt_validation_latency(
                    jwt_validation_start,
                    JwtValidationOutcome::InvalidToken,
                );
                debug!("JWT token validation failed");
                #[cfg(feature = "audit-logging")]
                {
                    audit::jwt_invalid_token("validation_failed");
                }
                return unauthorized_future;
            }
            JwtValidationResult::InvalidIssuer { expected, actual } => {
                #[cfg(all(feature = "audit-logging", feature = "prometheus"))]
                observe_jwt_validation_latency(
                    jwt_validation_start,
                    JwtValidationOutcome::InvalidIssuer,
                );
                warn!(
                    "JWT issuer validation failed. Expected: '{}', Actual: '{}', Account: {}",
                    expected, actual, "unknown"
                );
                #[cfg(feature = "audit-logging")]
                {
                    audit::jwt_invalid_issuer(&expected, &actual);
                }
                return unauthorized_future;
            }
        };

        debug!("Logged in with id: {}", jwt.custom_claims.account_id);

        #[cfg(feature = "audit-logging")]
        let _authz_span = audit::authorization_span(Some(&jwt.custom_claims.account_id), None);
        #[cfg(feature = "audit-logging")]
        let _authz_enter = _authz_span.enter();

        #[cfg(all(feature = "audit-logging", feature = "prometheus"))]
        let authz_start = std::time::Instant::now();

        let account = &jwt.custom_claims;
        let is_authorized = self.authorization_service.is_authorized(account);

        if !is_authorized {
            #[cfg(feature = "audit-logging")]
            {
                audit::denied(Some(&jwt.custom_claims.account_id), "policy_denied");
            }
            #[cfg(all(feature = "audit-logging", feature = "prometheus"))]
            observe_authz_latency(authz_start, AuthzOutcome::Denied);
            return unauthorized_future;
        }

        #[cfg(feature = "audit-logging")]
        {
            audit::authorized(&jwt.custom_claims.account_id, None);
        }
        #[cfg(all(feature = "audit-logging", feature = "prometheus"))]
        observe_authz_latency(authz_start, AuthzOutcome::Authorized);

        req.extensions_mut().insert(jwt.custom_claims.clone());
        req.extensions_mut().insert(jwt.registered_claims.clone());

        let inner = self.inner.call(req);
        Box::pin(inner)
    }
}