Skip to main content

oidc_util/security/
validator.rs

1use error_util::error::{AppError, HttpError};
2use alcoholic_jwt::{token_kid, validate, ValidJWT, Validation, ValidationError, JWKS};
3use reqwest::Response;
4use tracing::{event, Level};
5
6use serde::{Deserialize, Serialize};
7
8#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
9#[serde(rename_all = "camelCase")]
10pub struct Root {
11    pub keys: Vec<Key>,
12}
13
14#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
15#[serde(rename_all = "camelCase")]
16pub struct Key {
17    pub kid: String,
18    pub kty: String,
19    pub alg: String,
20    #[serde(rename = "use")]
21    pub use_field: String,
22    pub n: String,
23    pub e: String,
24    pub x5c: Vec<String>,
25    pub x5t: String,
26    #[serde(rename = "x5t#S256")]
27    pub x5t_s256: String,
28}
29
30/// Gets the jwks information from the IDP provider
31pub async fn fetch_jwks_idp(url: String) -> Result<JWKS, Box<dyn std::error::Error>> {
32    let client = reqwest::Client::new();
33    let res = client.get(url).send().await?.text().await?;
34    let jwk = serde_json::from_str(res.as_str()).unwrap();
35    Ok(jwk)
36}
37
38/// Gets the jwks information from the Keycloak provider
39pub async fn fetch_jwks_keycloak(url: String) -> Result<JWKS, AppError> {
40    let client = reqwest::Client::new();
41    let res = client
42        .get(url)
43        .send()
44        .await
45        .map_err(AppError::ReqwestAPIError)?;
46
47    if !res.status().is_success() {
48        return Err(handle_error_response(res).await);
49    };
50
51    let root = res
52        .json::<Root>()
53        .await
54        .map_err(AppError::ReqwestAPIError)?;
55
56    // filter out problematic keys for keycloak
57    let keys: Vec<Key> = root.keys.into_iter().filter(|x| x.alg == "RS256").collect();
58
59    let root = Root { keys };
60
61    let json = serde_json::to_string(&root).unwrap();
62    let jwks = serde_json::from_str(&json).unwrap();
63
64    Ok(jwks)
65}
66
67pub fn validate_token(token: &str, jwks: &JWKS, issuer_uri: &str) -> Result<ValidJWT, AppError> {
68    let validations = vec![
69        Validation::Issuer(issuer_uri.to_string()),
70        Validation::NotExpired,
71        Validation::SubjectPresent
72    ];
73
74    let kid = match token_kid(token) {
75        Ok(k) => match k {
76            Some(ki) => ki,
77            None => return Err(AppError::JwksError(ValidationError::InvalidComponents)),
78        },
79        Err(e) => return Err(AppError::JwksError(e)),
80    };
81
82    let jwk = match jwks.find(&kid) {
83        Some(j) => j,
84        None => return Err(AppError::JwksError(ValidationError::InvalidComponents)),
85    };
86
87    match validate(token, jwk, validations) {
88        Ok(valid) => Ok(valid),
89        Err(e) => Err(AppError::JwksError(e)),
90    }
91}
92
93/// Get claim as Vec<String>
94pub fn get_claim_from_token_as_list(valid_jwt: &ValidJWT, claim_name: &str) -> Result<Vec<String>, AppError> {
95    let token_claims = valid_jwt.claims.clone();
96    let claim = match token_claims.as_object() {
97        Some(role_claim) => role_claim.get(claim_name),
98        None => return Err(AppError::StandardError("could not decode token".to_string())),
99    };
100
101    match claim {
102        None => Err(AppError::StandardError("could not decode token".to_string())),
103        Some(roles) => {
104            let resp = serde_json::from_value(roles.clone())
105                .map_err(|e| AppError::StandardError(e.to_string()))?;
106
107            Ok(resp)
108        }
109    }
110}
111
112/// Get claim as String
113pub fn get_claim_from_token_as_string(valid_jwt: &ValidJWT, claim_name: &str) -> Result<String, AppError> {
114    let token_claims = valid_jwt.claims.clone();
115    let claim = match token_claims.as_object() {
116        Some(role_claim) => role_claim.get(claim_name),
117        None => return Err(AppError::StandardError("could not decode token".to_string())),
118    };
119
120    match claim {
121        None => Err(AppError::StandardError("could not decode token".to_string())),
122        Some(roles) => {
123            let resp = serde_json::from_value(roles.clone())
124                .map_err(|e| AppError::StandardError(e.to_string()))?;
125
126            Ok(resp)
127        }
128    }
129}
130
131pub async fn handle_error_response(res: Response) -> AppError {
132    event!(
133        Level::ERROR,
134        "Api error while calling {:?}",
135        &res.url().as_str()
136    );
137    if res.status().is_client_error() {
138        let err = transform_error(res).await;
139        event!(Level::ERROR, "{:?}", &err);
140        AppError::ClientError(err)
141    } else {
142        let err = transform_error(res).await;
143        event!(Level::ERROR, "{:?}", &err);
144        AppError::ServerError(err)
145    }
146}
147
148pub async fn transform_error(res: Response) -> HttpError {
149    HttpError {
150        status: res.status().as_u16(),
151        message: match res.json::<APIResponse>().await {
152            Ok(x) => x.message,
153            Err(e) => {
154                event!(Level::ERROR, "see reason --> {:?}", e);
155                e.to_string()
156            }
157        },
158    }
159}
160
161#[derive(Deserialize, Serialize, Debug, Clone)]
162pub struct APIResponse {
163    pub message: String,
164}
165
166#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
167pub struct ApiErrorResponse {
168    pub error: String,
169    pub error_description: String,
170}