use error_util::error::{AppError, HttpError};
use alcoholic_jwt::{token_kid, validate, ValidJWT, Validation, ValidationError, JWKS};
use reqwest::Response;
use tracing::{event, Level};
use serde::{Deserialize, Serialize};
/// Top-level JWKS document as returned by the identity provider: `{"keys": [ ... ]}`.
///
/// Used as an intermediate representation so that individual keys can be inspected
/// and filtered before conversion into `alcoholic_jwt`'s [`JWKS`].
#[derive(Debug, Default, Clone, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Root {
    /// Key entries from the document's `keys` array, in document order.
    pub keys: Vec<Key>,
}
/// A single JSON Web Key entry from a JWKS document.
///
/// RFC 7517 only requires `kty`; every other member is optional, and non-RSA keys
/// (e.g. EC keys) carry no `n`/`e` at all. The container-level `#[serde(default)]`
/// lets such entries deserialize with empty values instead of making the whole
/// [`Root`] document fail to parse. Callers are expected to filter out keys they
/// cannot use (see the `alg == "RS256"` filter in `fetch_jwks_keycloak`).
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase", default)]
pub struct Key {
    /// Key identifier, matched against the `kid` header of incoming tokens.
    pub kid: String,
    /// Key type (e.g. `RSA`).
    pub kty: String,
    /// Algorithm the key is intended for (e.g. `RS256`).
    pub alg: String,
    /// Public key use; serialized as `use`, which is a Rust keyword.
    #[serde(rename = "use")]
    pub use_field: String,
    /// RSA modulus (base64url); empty for non-RSA keys.
    pub n: String,
    /// RSA public exponent (base64url); empty for non-RSA keys.
    pub e: String,
    /// X.509 certificate chain; empty when the entry omits it.
    pub x5c: Vec<String>,
    /// X.509 certificate SHA-1 thumbprint; empty when absent.
    pub x5t: String,
    /// X.509 certificate SHA-256 thumbprint; empty when absent.
    #[serde(rename = "x5t#S256")]
    pub x5t_s256: String,
}
/// Fetches a JWKS document from `url` and parses it directly into a [`JWKS`].
///
/// Unlike `fetch_jwks_keycloak`, the key set is not filtered first, so the whole
/// document must deserialize as `alcoholic_jwt`'s `JWKS` type.
///
/// # Errors
/// Returns an error if the request fails, the server answers with a non-success
/// status, the body cannot be read, or the body is not a valid `JWKS` document.
pub async fn fetch_jwks_idp(url: String) -> Result<JWKS, Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();
    let body = client
        .get(url)
        .send()
        .await?
        // Fail on 4xx/5xx instead of trying to parse an error page as a key set.
        .error_for_status()?
        .text()
        .await?;
    // The body is remote input: report malformed key material as an error, never panic.
    let jwks = serde_json::from_str::<JWKS>(&body)?;
    Ok(jwks)
}
/// Fetches a Keycloak JWKS document from `url` and returns only its RS256 keys.
///
/// The document is first decoded into [`Root`], keys whose `alg` is not `RS256`
/// are discarded, and the remainder is converted into `alcoholic_jwt`'s [`JWKS`].
///
/// # Errors
/// - `AppError::ReqwestAPIError` if the request fails or the body cannot be decoded
///   as a [`Root`].
/// - The error produced by [`handle_error_response`] (`ClientError` / `ServerError`)
///   when the server answers with a non-success status.
/// - `AppError::StandardError` if the filtered key set cannot be converted into a
///   [`JWKS`] (e.g. a kept key has content `alcoholic_jwt` rejects).
pub async fn fetch_jwks_keycloak(url: String) -> Result<JWKS, AppError> {
    let client = reqwest::Client::new();
    let res = client
        .get(url)
        .send()
        .await
        .map_err(AppError::ReqwestAPIError)?;
    if !res.status().is_success() {
        return Err(handle_error_response(res).await);
    }
    let root = res
        .json::<Root>()
        .await
        .map_err(AppError::ReqwestAPIError)?;

    // Keep only RS256 keys; everything else is dropped before the JWKS conversion.
    let filtered = Root {
        keys: root
            .keys
            .into_iter()
            .filter(|key| key.alg == "RS256")
            .collect(),
    };

    // Convert through an in-memory `serde_json::Value` (no String round-trip) and
    // surface conversion failures as errors: this data comes from the network.
    let value =
        serde_json::to_value(&filtered).map_err(|e| AppError::StandardError(e.to_string()))?;
    serde_json::from_value(value).map_err(|e| AppError::StandardError(e.to_string()))
}
/// Validates `token` using the key in `jwks` whose id matches the token's `kid` header.
///
/// The token, the selected key and the validations below are handed to
/// `alcoholic_jwt::validate`. The validations are: issuer equals `issuer_uri`, token
/// not expired, and subject present.
///
/// # Errors
/// Returns `AppError::JwksError` when the header cannot be read, the header carries
/// no `kid`, no key in `jwks` has that id, or `validate` rejects the token.
pub fn validate_token(token: &str, jwks: &JWKS, issuer_uri: &str) -> Result<ValidJWT, AppError> {
    let validations = vec![
        Validation::Issuer(issuer_uri.to_string()),
        Validation::NotExpired,
        Validation::SubjectPresent,
    ];

    // A missing `kid` is treated the same as a malformed token.
    let key_id = token_kid(token)
        .map_err(AppError::JwksError)?
        .ok_or(AppError::JwksError(ValidationError::InvalidComponents))?;

    let signing_key = jwks
        .find(&key_id)
        .ok_or(AppError::JwksError(ValidationError::InvalidComponents))?;

    validate(token, signing_key, validations).map_err(AppError::JwksError)
}
pub fn get_claim_from_token_as_list(valid_jwt: &ValidJWT, claim_name: &str) -> Result<Vec<String>, AppError> {
let token_claims = valid_jwt.claims.clone();
let claim = match token_claims.as_object() {
Some(role_claim) => role_claim.get(claim_name),
None => return Err(AppError::StandardError("could not decode token".to_string())),
};
match claim {
None => Err(AppError::StandardError("could not decode token".to_string())),
Some(roles) => {
let resp = serde_json::from_value(roles.clone())
.map_err(|e| AppError::StandardError(e.to_string()))?;
Ok(resp)
}
}
}
pub fn get_claim_from_token_as_string(valid_jwt: &ValidJWT, claim_name: &str) -> Result<String, AppError> {
let token_claims = valid_jwt.claims.clone();
let claim = match token_claims.as_object() {
Some(role_claim) => role_claim.get(claim_name),
None => return Err(AppError::StandardError("could not decode token".to_string())),
};
match claim {
None => Err(AppError::StandardError("could not decode token".to_string())),
Some(roles) => {
let resp = serde_json::from_value(roles.clone())
.map_err(|e| AppError::StandardError(e.to_string()))?;
Ok(resp)
}
}
}
pub async fn handle_error_response(res: Response) -> AppError {
event!(
Level::ERROR,
"Api error while calling {:?}",
&res.url().as_str()
);
if res.status().is_client_error() {
let err = transform_error(res).await;
event!(Level::ERROR, "{:?}", &err);
AppError::ClientError(err)
} else {
let err = transform_error(res).await;
event!(Level::ERROR, "{:?}", &err);
AppError::ServerError(err)
}
}
/// Builds an [`HttpError`] from an error response.
///
/// The status code is captured first because decoding the body consumes `res`. The
/// message is taken from the body's `message` field (see [`APIResponse`]). If the body
/// cannot be decoded that way, the decode error is logged and its text becomes the
/// message instead.
pub async fn transform_error(res: Response) -> HttpError {
    let status = res.status().as_u16();
    let message = res
        .json::<APIResponse>()
        .await
        .map(|body| body.message)
        .unwrap_or_else(|e| {
            event!(Level::ERROR, "see reason --> {:?}", e);
            e.to_string()
        });
    HttpError { status, message }
}
/// Error body shape decoded by [`transform_error`]: a JSON object with a `message` string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct APIResponse {
    /// Error text taken from the response body's `message` member.
    pub message: String,
}
/// Error body with `error` and `error_description` members.
///
/// NOTE(review): these member names match the OAuth 2.0 error response format. The
/// type is not used in this part of the file, so confirm where it is consumed.
#[derive(Debug, Default, Clone, PartialEq, Deserialize, Serialize)]
pub struct ApiErrorResponse {
    /// Value of the body's `error` member.
    pub error: String,
    /// Value of the body's `error_description` member.
    pub error_description: String,
}