oidc_util/security/
validator.rs1use error_util::error::{AppError, HttpError};
2use alcoholic_jwt::{token_kid, validate, ValidJWT, Validation, ValidationError, JWKS};
3use reqwest::Response;
4use tracing::{event, Level};
5
6use serde::{Deserialize, Serialize};
7
8#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
9#[serde(rename_all = "camelCase")]
10pub struct Root {
11 pub keys: Vec<Key>,
12}
13
14#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
15#[serde(rename_all = "camelCase")]
16pub struct Key {
17 pub kid: String,
18 pub kty: String,
19 pub alg: String,
20 #[serde(rename = "use")]
21 pub use_field: String,
22 pub n: String,
23 pub e: String,
24 pub x5c: Vec<String>,
25 pub x5t: String,
26 #[serde(rename = "x5t#S256")]
27 pub x5t_s256: String,
28}
29
30pub async fn fetch_jwks_idp(url: String) -> Result<JWKS, Box<dyn std::error::Error>> {
32 let client = reqwest::Client::new();
33 let res = client.get(url).send().await?.text().await?;
34 let jwk = serde_json::from_str(res.as_str()).unwrap();
35 Ok(jwk)
36}
37
38pub async fn fetch_jwks_keycloak(url: String) -> Result<JWKS, AppError> {
40 let client = reqwest::Client::new();
41 let res = client
42 .get(url)
43 .send()
44 .await
45 .map_err(AppError::ReqwestAPIError)?;
46
47 if !res.status().is_success() {
48 return Err(handle_error_response(res).await);
49 };
50
51 let root = res
52 .json::<Root>()
53 .await
54 .map_err(AppError::ReqwestAPIError)?;
55
56 let keys: Vec<Key> = root.keys.into_iter().filter(|x| x.alg == "RS256").collect();
58
59 let root = Root { keys };
60
61 let json = serde_json::to_string(&root).unwrap();
62 let jwks = serde_json::from_str(&json).unwrap();
63
64 Ok(jwks)
65}
66
67pub fn validate_token(token: &str, jwks: &JWKS, issuer_uri: &str) -> Result<ValidJWT, AppError> {
68 let validations = vec![
69 Validation::Issuer(issuer_uri.to_string()),
70 Validation::NotExpired,
71 Validation::SubjectPresent
72 ];
73
74 let kid = match token_kid(token) {
75 Ok(k) => match k {
76 Some(ki) => ki,
77 None => return Err(AppError::JwksError(ValidationError::InvalidComponents)),
78 },
79 Err(e) => return Err(AppError::JwksError(e)),
80 };
81
82 let jwk = match jwks.find(&kid) {
83 Some(j) => j,
84 None => return Err(AppError::JwksError(ValidationError::InvalidComponents)),
85 };
86
87 match validate(token, jwk, validations) {
88 Ok(valid) => Ok(valid),
89 Err(e) => Err(AppError::JwksError(e)),
90 }
91}
92
93pub fn get_claim_from_token_as_list(valid_jwt: &ValidJWT, claim_name: &str) -> Result<Vec<String>, AppError> {
95 let token_claims = valid_jwt.claims.clone();
96 let claim = match token_claims.as_object() {
97 Some(role_claim) => role_claim.get(claim_name),
98 None => return Err(AppError::StandardError("could not decode token".to_string())),
99 };
100
101 match claim {
102 None => Err(AppError::StandardError("could not decode token".to_string())),
103 Some(roles) => {
104 let resp = serde_json::from_value(roles.clone())
105 .map_err(|e| AppError::StandardError(e.to_string()))?;
106
107 Ok(resp)
108 }
109 }
110}
111
112pub fn get_claim_from_token_as_string(valid_jwt: &ValidJWT, claim_name: &str) -> Result<String, AppError> {
114 let token_claims = valid_jwt.claims.clone();
115 let claim = match token_claims.as_object() {
116 Some(role_claim) => role_claim.get(claim_name),
117 None => return Err(AppError::StandardError("could not decode token".to_string())),
118 };
119
120 match claim {
121 None => Err(AppError::StandardError("could not decode token".to_string())),
122 Some(roles) => {
123 let resp = serde_json::from_value(roles.clone())
124 .map_err(|e| AppError::StandardError(e.to_string()))?;
125
126 Ok(resp)
127 }
128 }
129}
130
131pub async fn handle_error_response(res: Response) -> AppError {
132 event!(
133 Level::ERROR,
134 "Api error while calling {:?}",
135 &res.url().as_str()
136 );
137 if res.status().is_client_error() {
138 let err = transform_error(res).await;
139 event!(Level::ERROR, "{:?}", &err);
140 AppError::ClientError(err)
141 } else {
142 let err = transform_error(res).await;
143 event!(Level::ERROR, "{:?}", &err);
144 AppError::ServerError(err)
145 }
146}
147
148pub async fn transform_error(res: Response) -> HttpError {
149 HttpError {
150 status: res.status().as_u16(),
151 message: match res.json::<APIResponse>().await {
152 Ok(x) => x.message,
153 Err(e) => {
154 event!(Level::ERROR, "see reason --> {:?}", e);
155 e.to_string()
156 }
157 },
158 }
159}
160
161#[derive(Deserialize, Serialize, Debug, Clone)]
162pub struct APIResponse {
163 pub message: String,
164}
165
166#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
167pub struct ApiErrorResponse {
168 pub error: String,
169 pub error_description: String,
170}