use std::collections::HashMap;
use std::str::FromStr;
use jsonwebtoken::{Algorithm, decode, decode_header, DecodingKey, Validation};
use jsonwebtoken::jwk::{AlgorithmParameters, JwkSet};
use serde_json::Value;
use crate::{ok_or_return_error, some_or_return_error};
use crate::auth::error_kind::{INVALID_TOKEN, JWKS_RETRIEVAL_FAILURE, MALFORMED_TOKEN};
use crate::auth::jwt_token::JwtToken;
use crate::auth::token_validator::TokenValidator;
use crate::error::Error;
/// Validates JWTs against a fixed JSON Web Key Set and a list of accepted
/// issuers and audiences.
pub struct JwtTokenValidator {
    /// Key set that is searched by the token header's `kid` during validation.
    jwk_set: JwkSet,
    /// Issuer values handed to `Validation::set_issuer` during validation.
    issuers: Vec<String>,
    /// Audience values handed to `Validation::set_audience` during validation.
    audience: Vec<String>,
}
impl JwtTokenValidator {
pub fn new(jwk_set: JwkSet, issuers: Vec<String>, audience: Vec<String>) -> Self {
Self {
jwk_set,
issuers,
audience,
}
}
}
impl TokenValidator for JwtTokenValidator {
    /// Checks `token` against the configured key set, issuers and audience and
    /// wraps the decoded token in a [`JwtToken`].
    ///
    /// The signing key is selected by the `kid` in the token header; only RSA
    /// keys are accepted, and the verification algorithm is taken from the
    /// key's `key_algorithm`.
    ///
    /// # Errors
    ///
    /// Every step before the final `decode` reports `MALFORMED_TOKEN`; the
    /// final `decode` reports `INVALID_TOKEN`. The error values themselves are
    /// produced by the crate's `ok_or_return_error!` / `some_or_return_error!`
    /// macros (defined elsewhere — presumably an early `return Err(..)`).
    fn validate(&self, token: &str) -> Result<JwtToken, Error> {
        // Read the header first: it names the key to use.
        let token_header = ok_or_return_error!(
            decode_header(token),
            MALFORMED_TOKEN,
            "failed to decode token's header: "
        );
        let key_id = some_or_return_error!(
            token_header.kid,
            MALFORMED_TOKEN,
            "'kid' is missing from header"
        );

        // Look the key up in the configured key set.
        let matching_jwk = some_or_return_error!(
            self.jwk_set.find(&key_id),
            MALFORMED_TOKEN,
            "could not find 'kid' within 'jwk_set'"
        );

        // Only RSA keys are supported; anything else is rejected.
        let rsa_params = match &matching_jwk.algorithm {
            AlgorithmParameters::RSA(params) => params,
            unsupported => {
                return Err(Error::new(
                    MALFORMED_TOKEN,
                    format!(
                        "got an unexpected algorithm for the 'jwk': {:?}",
                        unsupported
                    ),
                ));
            }
        };
        let rsa_key = ok_or_return_error!(
            DecodingKey::from_rsa_components(&rsa_params.n, &rsa_params.e),
            MALFORMED_TOKEN,
            "failed to get decoding key: "
        );

        // The algorithm comes from the key, not from the token header.
        let jwk_algorithm = some_or_return_error!(
            matching_jwk.common.key_algorithm,
            MALFORMED_TOKEN,
            "'jwk' is missing the 'key_algorithm'"
        );
        let signing_algorithm = ok_or_return_error!(
            Algorithm::from_str(&jwk_algorithm.to_string()),
            MALFORMED_TOKEN,
            "failed to get algorithm for 'key_algorithm': "
        );

        // Expiry, issuer and audience are all checked by `decode`.
        let mut rules = Validation::new(signing_algorithm);
        rules.validate_exp = true;
        rules.set_issuer(&self.issuers[..]);
        rules.set_audience(&self.audience[..]);

        let token_data = ok_or_return_error!(
            decode::<HashMap<String, Value>>(token, &rsa_key, &rules),
            INVALID_TOKEN,
            "failed to validate token: "
        );
        Ok(JwtToken::new(token_data))
    }
}
/// Downloads the JSON Web Key Set published at `jwks_uri` and deserializes it
/// into a [`JwkSet`].
///
/// # Errors
///
/// All failures are reported with the `JWKS_RETRIEVAL_FAILURE` kind through
/// the crate's `ok_or_return_error!` macro (defined elsewhere — presumably an
/// early `return Err(..)`):
///
/// * the request could not be sent or the server answered with a non-success
///   (4xx/5xx) status;
/// * the response body could not be deserialized as a `JwkSet`.
pub async fn try_get_jwks(jwks_uri: &str) -> Result<JwkSet, Error> {
    // `reqwest::get` only fails on transport-level problems. Turn HTTP error
    // statuses into errors as well, so an error page is reported as a failed
    // request instead of as a confusing deserialization failure below.
    let response = ok_or_return_error!(
        reqwest::get(jwks_uri)
            .await
            .and_then(reqwest::Response::error_for_status),
        JWKS_RETRIEVAL_FAILURE,
        "invalid response obtained getting the 'JwkSet'"
    );
    let jwks = ok_or_return_error!(
        response.json::<JwkSet>().await,
        JWKS_RETRIEVAL_FAILURE,
        "response cannot be deserialized as a 'JwkSet'"
    );
    Ok(jwks)
}
#[cfg(test)]
pub mod tests {
    use std::path::PathBuf;
    use crate::auth::error_kind::{INVALID_TOKEN, MALFORMED_TOKEN};
    use crate::auth::jwt_token_validator::{JwtTokenValidator, try_get_jwks};
    use crate::auth::token_validator::TokenValidator;
    use crate::config_reader::ConfigReader;
    use crate::secrets::get_secrets_manager;
    use crate::test_base::get_unit_test_data_path;

    /// Path of the `config.yaml` file inside the unit-test data directory
    /// resolved for this source file.
    fn get_config_file() -> PathBuf {
        get_unit_test_data_path(file!()).join("config.yaml")
    }

    /// Reads `TokenValidator.JwksUri` from the test config and expects
    /// `try_get_jwks` to succeed for it.
    ///
    /// `try_get_jwks` performs a real HTTP request, so this test needs the
    /// configured URI to be reachable.
    #[tokio::test]
    pub async fn try_get_jwks_succeeds_for_example_uri() {
        let reader = ConfigReader::default();
        let config = reader
            .read(get_config_file())
            .expect("expected config reader");
        let validator_section = config
            .get("TokenValidator")
            .expect("expected 'TokenValidator'");
        let uri = validator_section
            .get("JwksUri")
            .expect("expected 'JwksUri'")
            .as_str()
            .expect("expected 'JwksUri' as string");

        let outcome = try_get_jwks(uri).await;
        assert!(outcome.is_ok());
    }
}