middleware 0.1.0

Custom Axum middleware that validates a JWT from the `access_token` cookie and exposes its claims to handlers.
Documentation
pub mod types;

use axum::{
    body::Body,
    extract::{FromRequestParts, Request, State},
    http::{self, StatusCode},
    middleware::Next,
    response::Response,
};
use axum_extra::extract::CookieJar;
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
use serde::Deserialize;
use service_error::ServiceError;

use crate::types::{Jwks, JwksTrait};

/// Contents of the token's `realm_access` claim.
#[derive(Clone, Debug, Deserialize)]
pub struct RealmAccess {
    /// Role names listed under `realm_access.roles`.
    pub roles: Vec<String>,
}

/// Role list for one entry of the token's `resource_access` map.
#[derive(Clone, Debug, Deserialize)]
pub struct ResourceRoles {
    /// Role names listed under `resource_access.<key>.roles`.
    pub roles: Vec<String>,
}

/// Payload of an access token, deserialized by `jsonwebtoken::decode`.
///
/// Every field is an `Option`, so a claim absent from the token becomes `None`
/// instead of failing deserialization; claims not listed here are ignored.
/// The field set (`typ`, `azp`, `realm_access`, `resource_access`, ...) suggests
/// Keycloak-issued tokens — NOTE(review): confirm against the actual issuer.
#[derive(Clone, Debug, Deserialize)]
pub struct Claims {
    /// Subject (`sub`).
    pub sub: Option<String>,
    /// Expiration time (`exp`, RFC 7519 NumericDate).
    pub exp: Option<u64>,
    /// Issued-at time (`iat`, RFC 7519 NumericDate).
    pub iat: Option<u64>,
    /// Issuer (`iss`).
    pub iss: Option<String>,
    /// Audience (`aud`).
    ///
    /// NOTE(review): RFC 7519 permits `aud` to be either a string or an array of
    /// strings. An array cannot deserialize into `Option<String>`, which would make
    /// decoding of such tokens fail — confirm which form the issuer emits.
    pub aud: Option<String>,
    /// Token type (`typ`).
    pub typ: Option<String>,
    /// Authorized party (`azp`).
    pub azp: Option<String>,
    /// `name` claim.
    pub name: Option<String>,
    /// `preferred_username` claim.
    pub preferred_username: Option<String>,
    /// `given_name` claim.
    pub given_name: Option<String>,
    /// `family_name` claim.
    pub family_name: Option<String>,
    /// `email` claim.
    pub email: Option<String>,
    /// `email_verified` claim.
    pub email_verified: Option<bool>,
    /// `realm_access` claim (realm role list).
    pub realm_access: Option<RealmAccess>,
    /// `resource_access` claim, keyed by resource name.
    pub resource_access: Option<std::collections::HashMap<String, ResourceRoles>>,
}

/// Axum extractor that yields the [`Claims`] stored in the request extensions.
///
/// `validate_auth_cookie` inserts the claims into the extensions after a token
/// validates, so handlers behind that middleware can take
/// `ClaimsExtension(claims)` as an argument.
#[derive(Debug)]
pub struct ClaimsExtension(pub Claims);

impl<S> FromRequestParts<S> for ClaimsExtension
where
    S: Send + Sync,
{
    type Rejection = ServiceError;

    /// Clones the [`Claims`] out of `parts.extensions`.
    ///
    /// # Errors
    /// Returns `ServiceError::Internal` when no `Claims` value is present in the
    /// extensions (presumably the auth middleware did not run for this route).
    async fn from_request_parts(
        parts: &mut http::request::Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        // Only the extensions are consulted; headers are not read here.
        parts
            .extensions
            .get::<Claims>()
            .cloned()
            .map(ClaimsExtension)
            .ok_or_else(|| {
                ServiceError::Internal("Claims not found in request extensions".to_string())
            })
    }
}

/// Axum middleware that authenticates a request via its `access_token` cookie.
///
/// The cookie's JWT is verified against the key set returned by `state.jwks()`.
/// On success the decoded [`Claims`] are inserted into the request extensions,
/// where [`ClaimsExtension`] can extract them, and the request is forwarded to
/// `next`.
///
/// # Errors
/// Returns [`StatusCode::UNAUTHORIZED`] when the cookie is missing or the token
/// fails validation for any reason.
pub async fn validate_auth_cookie<T>(
    State(state): State<std::sync::Arc<T>>,
    cookies: CookieJar,
    mut req: Request<Body>,
    next: Next,
) -> Result<Response, StatusCode>
where
    T: JwksTrait,
{
    let token = get_auth_token(&cookies)?;

    // Validate in a standalone `let` so the temporary produced by `state.jwks()`
    // is dropped at the end of this statement. Inside a `match` scrutinee it would
    // live until the end of the `match`, i.e. across `next.run(req).await` for the
    // whole downstream handler. This matters if `jwks()` returns a lock guard or
    // an owned copy — NOTE(review): confirm its return type.
    let claims =
        validate_token(token, &state.jwks()).map_err(|_| StatusCode::UNAUTHORIZED)?;

    req.extensions_mut().insert(claims);
    Ok(next.run(req).await)
}

/// Returns the raw value of the `access_token` cookie.
///
/// # Errors
/// Returns [`StatusCode::UNAUTHORIZED`] when the cookie is absent.
fn get_auth_token(cookies: &CookieJar) -> Result<&str, StatusCode> {
    cookies
        .get("access_token")
        .map(|cookie| cookie.value())
        .ok_or(StatusCode::UNAUTHORIZED)
}

/// Verifies `token` as an RS256 JWT and returns its decoded [`Claims`].
///
/// Steps:
/// 1. Read the unverified header and take its `kid`.
/// 2. Find the key in `jwks.keys` whose `kid` matches.
/// 3. Build an RSA decoding key from that key's `n` and `e` components.
/// 4. Decode and verify the token with `Validation::new(Algorithm::RS256)`.
///
/// NOTE(review): no issuer or audience is configured on the `Validation`. Confirm
/// this is intended, and check how the `jsonwebtoken` version in use treats an
/// `aud` claim when no audience is set.
///
/// # Errors
/// Every failure (malformed header, missing `kid`, unknown key, bad key
/// components, failed signature/claim validation) yields
/// [`StatusCode::UNAUTHORIZED`].
fn validate_token(token: &str, jwks: &Jwks) -> Result<Claims, StatusCode> {
    const REJECT: StatusCode = StatusCode::UNAUTHORIZED;

    let Some(kid) = decode_header(token).map_err(|_| REJECT)?.kid else {
        return Err(REJECT);
    };

    let Some(jwk) = jwks.keys.iter().find(|candidate| candidate.kid == kid) else {
        return Err(REJECT);
    };

    let key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e).map_err(|_| REJECT)?;
    let rules = Validation::new(Algorithm::RS256);

    decode::<Claims>(token, &key, &rules)
        .map(|data| data.claims)
        .map_err(|_| REJECT)
}