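//! oidc-util 0.0.1
//!
//! OIDC utility: fetches JSON Web Key Sets (JWKS) from an identity provider and
//! validates and inspects JWTs against them.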
use error_util::error::{AppError, HttpError};
use alcoholic_jwt::{token_kid, validate, ValidJWT, Validation, ValidationError, JWKS};
use reqwest::Response;
use tracing::{event, Level};

use serde::{Deserialize, Serialize};

/// A JSON Web Key Set document as returned by an identity provider.
///
/// Only the `keys` array is modelled. Each entry is read into a [`Key`] so the
/// set can be inspected or filtered before being converted into an
/// `alcoholic_jwt` JWKS.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Root {
    /// The keys listed in the set, in document order.
    pub keys: Vec<Key>,
}

/// A single JSON Web Key entry from a provider's JWKS document.
///
/// RFC 7517 marks `kid`, `alg`, `use`, `x5c`, `x5t` and `x5t#S256` as optional,
/// and keys that are not RSA keys carry no `n`/`e`. The container-level
/// `#[serde(default)]` fills any absent member with its `Default` value (empty
/// string / empty vec). Without it, a single key missing one of these members
/// would make the whole key set fail to deserialize, before any filtering runs.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase", default)]
pub struct Key {
    /// Key ID (`kid`), matched against a token's `kid` header.
    pub kid: String,
    /// Key type (`kty`).
    pub kty: String,
    /// Algorithm the key is intended for (`alg`).
    pub alg: String,
    /// Public key use (`use`); renamed because `use` is a Rust keyword.
    #[serde(rename = "use")]
    pub use_field: String,
    /// RSA modulus (`n`).
    pub n: String,
    /// RSA public exponent (`e`).
    pub e: String,
    /// X.509 certificate chain (`x5c`).
    pub x5c: Vec<String>,
    /// X.509 certificate SHA-1 thumbprint (`x5t`).
    pub x5t: String,
    /// X.509 certificate SHA-256 thumbprint (`x5t#S256`).
    #[serde(rename = "x5t#S256")]
    pub x5t_s256: String,
}

/// Gets the jwks information from the IDP provider
///
/// Performs a GET on `url` and parses the response body as an `alcoholic_jwt`
/// [`JWKS`].
///
/// # Errors
/// Returns an error if any of the following happens:
/// - the request fails;
/// - the server answers with a 4xx/5xx status;
/// - the body cannot be read;
/// - the body is not a valid JWKS document.
///
/// None of these cases panic.
pub async fn fetch_jwks_idp(url: String) -> Result<JWKS, Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();
    let body = client
        .get(url)
        .send()
        .await?
        // Report error statuses directly instead of trying to parse an error page as JWKS.
        .error_for_status()?
        .text()
        .await?;
    // The body comes from the network, so a malformed document is an error, not a bug.
    let jwks: JWKS = serde_json::from_str(&body)?;
    Ok(jwks)
}

/// Gets the jwks information from the Keycloak provider
///
/// The document is first read into the permissive [`Root`]/[`Key`] model.
/// Only keys whose `alg` is `RS256` are kept. The filtered set is then
/// converted into an `alcoholic_jwt` [`JWKS`].
///
/// # Errors
/// - [`AppError::ReqwestAPIError`] if the request fails or the body is not a
///   valid [`Root`] document.
/// - The error produced by [`handle_error_response`] for a non-success status.
/// - [`AppError::StandardError`] if the filtered key set cannot be converted
///   into a [`JWKS`].
pub async fn fetch_jwks_keycloak(url: String) -> Result<JWKS, AppError> {
    let client = reqwest::Client::new();
    let res = client
        .get(url)
        .send()
        .await
        .map_err(AppError::ReqwestAPIError)?;

    if !res.status().is_success() {
        return Err(handle_error_response(res).await);
    }

    let mut root = res
        .json::<Root>()
        .await
        .map_err(AppError::ReqwestAPIError)?;

    // filter out problematic keys for keycloak
    root.keys.retain(|key| key.alg == "RS256");

    // Convert through an in-memory JSON value (no string round-trip), and report
    // conversion failures as errors instead of panicking.
    let value = serde_json::to_value(&root).map_err(|e| AppError::StandardError(e.to_string()))?;
    serde_json::from_value::<JWKS>(value).map_err(|e| AppError::StandardError(e.to_string()))
}

/// Validates `token` with the key from `jwks` whose ID matches the token's `kid` header.
///
/// The checks applied are:
/// - `Validation::Issuer(issuer_uri)`;
/// - `Validation::NotExpired`;
/// - `Validation::SubjectPresent`.
///
/// # Errors
/// All failures are reported as [`AppError::JwksError`]:
/// - the header cannot be read: the error from `token_kid`;
/// - the header has no `kid`: `ValidationError::InvalidComponents`;
/// - `jwks` has no key with that ID: `ValidationError::InvalidComponents`;
/// - the token fails validation: the error returned by `validate`.
pub fn validate_token(token: &str, jwks: &JWKS, issuer_uri: &str) -> Result<ValidJWT, AppError> {
    let key_id = token_kid(token)
        .map_err(AppError::JwksError)?
        .ok_or_else(|| AppError::JwksError(ValidationError::InvalidComponents))?;

    let signing_key = jwks
        .find(&key_id)
        .ok_or_else(|| AppError::JwksError(ValidationError::InvalidComponents))?;

    let checks = vec![
        Validation::Issuer(issuer_uri.to_string()),
        Validation::NotExpired,
        Validation::SubjectPresent,
    ];

    validate(token, signing_key, checks).map_err(AppError::JwksError)
}

/// Get claim as Vec<String>
pub fn get_claim_from_token_as_list(valid_jwt: &ValidJWT, claim_name: &str) -> Result<Vec<String>, AppError> {
    let token_claims = valid_jwt.claims.clone();
    let claim = match token_claims.as_object() {
        Some(role_claim) => role_claim.get(claim_name),
        None => return Err(AppError::StandardError("could not decode token".to_string())),
    };

    match claim {
        None => Err(AppError::StandardError("could not decode token".to_string())),
        Some(roles) => {
            let resp = serde_json::from_value(roles.clone())
                .map_err(|e| AppError::StandardError(e.to_string()))?;

            Ok(resp)
        }
    }
}

/// Get claim as String
pub fn get_claim_from_token_as_string(valid_jwt: &ValidJWT, claim_name: &str) -> Result<String, AppError> {
    let token_claims = valid_jwt.claims.clone();
    let claim = match token_claims.as_object() {
        Some(role_claim) => role_claim.get(claim_name),
        None => return Err(AppError::StandardError("could not decode token".to_string())),
    };

    match claim {
        None => Err(AppError::StandardError("could not decode token".to_string())),
        Some(roles) => {
            let resp = serde_json::from_value(roles.clone())
                .map_err(|e| AppError::StandardError(e.to_string()))?;

            Ok(resp)
        }
    }
}

pub async fn handle_error_response(res: Response) -> AppError {
    event!(
        Level::ERROR,
        "Api error while calling {:?}",
        &res.url().as_str()
    );
    if res.status().is_client_error() {
        let err = transform_error(res).await;
        event!(Level::ERROR, "{:?}", &err);
        AppError::ClientError(err)
    } else {
        let err = transform_error(res).await;
        event!(Level::ERROR, "{:?}", &err);
        AppError::ServerError(err)
    }
}

pub async fn transform_error(res: Response) -> HttpError {
    HttpError {
        status: res.status().as_u16(),
        message: match res.json::<APIResponse>().await {
            Ok(x) => x.message,
            Err(e) => {
                event!(Level::ERROR, "see reason --> {:?}", e);
                e.to_string()
            }
        },
    }
}

/// JSON error body of the form `{"message": "..."}`.
///
/// [`transform_error`] reads this to get the message for an [`HttpError`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct APIResponse {
    /// Error message reported by the server.
    pub message: String,
}

#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ApiErrorResponse {
    pub error: String,
    pub error_description: String,
}