use anyhow::Result;
use serde::de::DeserializeOwned;
use std::collections::HashMap;
use crate::claims::{OidcAccessTokenClaims, OidcIdTokenClaims, OidcJwtClaims};
use crate::common::error::{ErrorLogger, ErrorVerbosity, JwtError};
use crate::common::token::TokenParser;
use crate::jwk::provider::JwkProvider;
use crate::jwk::registry::{JwkProviderRegistry, RegistryError};
use crate::oidc::config::OidcProviderConfig;
use crate::oidc::discovery::OidcDiscovery;
use crate::verifier::{AccessTokenClaims, IdTokenClaims, JwtVerifier};
/// Verifier for JWTs issued by one or more configured OIDC providers.
///
/// Each provider is registered under a string id; its key provider (in
/// `jwk_registry`) and its configuration (in `configs`) are stored under that same id.
#[derive(Debug)]
pub struct OidcJwtVerifier {
/// JWK providers, registered by provider id.
jwk_registry: JwkProviderRegistry,
/// Provider configurations, keyed by the same provider id used for `jwk_registry`.
configs: HashMap<String, OidcProviderConfig>,
/// Discovery helper constructed from a cache duration.
/// NOTE(review): no method visible in this file reads this field — confirm it is used elsewhere.
discovery: OidcDiscovery,
/// Logger built from the currently selected `ErrorVerbosity`.
error_logger: ErrorLogger,
}
impl OidcJwtVerifier {
/// Creates a verifier and registers every supplied provider configuration.
///
/// Each configuration is registered through [`Self::add_provider`] using its own
/// `issuer` string as the provider id. The discovery helper is created with a
/// 24-hour duration and the error logger with `ErrorVerbosity::Standard`.
///
/// # Errors
/// Returns the first error produced by [`Self::add_provider`]; configurations after
/// the failing one are not processed.
pub fn new(configs: Vec<OidcProviderConfig>) -> Result<Self, JwtError> {
    let discovery_duration = std::time::Duration::from_secs(3600 * 24);
    let mut this = Self {
        jwk_registry: JwkProviderRegistry::new(),
        configs: HashMap::new(),
        discovery: OidcDiscovery::new(discovery_duration),
        error_logger: ErrorLogger::new(ErrorVerbosity::Standard),
    };
    configs.into_iter().try_for_each(|provider_config| {
        // The id must be an owned copy: `provider_config` is moved into `add_provider`.
        let provider_id = provider_config.issuer.clone();
        this.add_provider(&provider_id, provider_config)
    })?;
    Ok(this)
}
/// Convenience constructor for a verifier that knows exactly one provider.
///
/// Builds the configuration with
/// `OidcProviderConfig::new(issuer, jwks_url, client_ids, None)` and passes it to
/// [`Self::new`], which registers it under the issuer stored in that configuration.
///
/// # Errors
/// Propagates any failure from building the configuration or from [`Self::new`].
pub fn new_single_provider(
    issuer: &str,
    jwks_url: Option<&str>,
    client_ids: &[String],
) -> Result<Self, JwtError> {
    let only_provider = OidcProviderConfig::new(issuer, jwks_url, client_ids, None)?;
    Self::new(Vec::from([only_provider]))
}
/// Registers a provider under `id`.
///
/// A `JwkProvider` is created from the configuration's JWKS URL, issuer and JWK cache
/// duration and registered in the JWK registry under `id`. Only after that succeeds is
/// `config` stored under the same id, so a failed registration leaves `configs` untouched.
///
/// # Errors
/// - `JwtError::ConfigurationError` with parameter `jwks_url` when `config.jwks_url` is `None`.
/// - Any error returned by `JwkProvider::from_jwks_url`.
/// - A registry failure: a wrapped `JwtError` is returned as-is, anything else becomes a
///   `JwtError::ConfigurationError` with parameter `provider_registration`.
pub fn add_provider(&mut self, id: &str, config: OidcProviderConfig) -> Result<(), JwtError> {
    let jwks_url = config
        .jwks_url
        .as_ref()
        .ok_or_else(|| JwtError::ConfigurationError {
            parameter: Some("jwks_url".to_string()),
            error: "JWKS URL is required".to_string(),
        })?;
    // The issuer is borrowed straight from `config`; `config` is only moved into the
    // map at the end, so no owned copy of the issuer string is needed.
    let jwk_provider =
        JwkProvider::from_jwks_url(jwks_url, &config.issuer, config.jwk_cache_duration)?;
    self.jwk_registry
        .register(id, jwk_provider)
        .map_err(|e| match e {
            RegistryError::JwtError(err) => err,
            other => JwtError::ConfigurationError {
                parameter: Some("provider_registration".to_string()),
                error: other.to_string(),
            },
        })?;
    self.configs.insert(id.to_string(), config);
    Ok(())
}
/// Registers a provider whose configuration comes from
/// `OidcProviderConfig::with_discovery(issuer, client_ids)`.
///
/// NOTE(review): registration goes through [`Self::add_provider`], which rejects
/// configurations without a `jwks_url` — confirm that `with_discovery` populates it.
///
/// # Errors
/// Propagates failures from building the configuration and from [`Self::add_provider`].
pub fn add_provider_with_discovery(
    &mut self,
    id: &str,
    issuer: &str,
    client_ids: &[String],
) -> Result<(), JwtError> {
    self.add_provider(id, OidcProviderConfig::with_discovery(issuer, client_ids)?)
}
/// Returns the ids reported by the JWK registry's `list_ids()`.
///
/// The list comes from the registry, not from `configs`.
pub fn get_provider_ids(&self) -> Vec<String> {
self.jwk_registry.list_ids()
}
/// Removes the provider registered under `id`.
///
/// The JWK registry entry is removed first; the stored configuration is dropped only
/// when that succeeds.
///
/// # Errors
/// - A `JwtError` wrapped by the registry is returned unchanged.
/// - Every other registry failure becomes a `JwtError::ConfigurationError` with
///   parameter `provider_id`; an unknown id gets a "not found" message naming `id`.
pub fn remove_provider(&mut self, id: &str) -> Result<(), JwtError> {
    if let Err(failure) = self.jwk_registry.remove(id) {
        let message = match failure {
            RegistryError::JwtError(inner) => return Err(inner),
            RegistryError::ProviderNotFound(_) => format!("Provider '{}' not found", id),
            unexpected => unexpected.to_string(),
        };
        return Err(JwtError::ConfigurationError {
            parameter: Some("provider_id".to_string()),
            error: message,
        });
    }
    self.configs.remove(id);
    Ok(())
}
/// Replaces the error logger with a new `ErrorLogger` built from `verbosity`.
pub fn set_error_verbosity(&mut self, verbosity: ErrorVerbosity) {
self.error_logger = ErrorLogger::new(verbosity);
}
/// Replaces the discovery helper with a new `OidcDiscovery` built from `duration`.
///
/// The previous `OidcDiscovery` value is dropped by the assignment, so any state it
/// held is discarded.
pub fn set_discovery_cache_duration(&mut self, duration: std::time::Duration) {
self.discovery = OidcDiscovery::new(duration);
}
/// Awaits the JWK registry's `hydrate()` and returns its result list unchanged.
///
/// Each entry pairs a `String` with that entry's `Result` — presumably the provider id
/// and the outcome of loading its keys; confirm against `JwkProviderRegistry::hydrate`.
pub async fn hydrate(&self) -> Vec<(String, Result<(), JwtError>)> {
self.jwk_registry.hydrate().await
}
/// Verifies `token` against the provider whose configured issuer matches the issuer
/// extracted from the token, and converts the verified claims into `T`.
///
/// # Errors
/// Propagates failures from extracting the issuer, from the provider lookup
/// (`JwtError::InvalidIssuer` when no configured provider has that issuer), and from
/// the verification itself.
pub async fn verify<T>(&self, token: &str) -> Result<T, JwtError>
where
    T: DeserializeOwned + TryFrom<OidcJwtClaims, Error = JwtError>,
{
    let token_issuer = TokenParser::extract_issuer(token)?;
    let matched_id = self.find_provider_id_by_issuer(&token_issuer)?;
    let verified: T = self
        .verify_generic_with_provider(token, &matched_id)
        .await?;
    Ok(verified)
}
/// Looks up the id of a configured provider whose `issuer` equals `issuer`.
///
/// `configs` is a `HashMap`, so if several ids share an issuer, which one is returned
/// is unspecified.
///
/// # Errors
/// `JwtError::InvalidIssuer` when no stored configuration has that issuer.
fn find_provider_id_by_issuer(&self, issuer: &str) -> Result<String, JwtError> {
    self.configs
        .iter()
        .find(|(_, provider_config)| provider_config.issuer == issuer)
        .map(|(provider_id, _)| provider_id.clone())
        .ok_or_else(|| JwtError::InvalidIssuer {
            expected: "a registered OIDC provider".to_string(),
            actual: issuer.to_string(),
        })
}
/// Verifies `token` with the provider registered under `provider_id` and converts the
/// resulting claims into `T`.
///
/// Steps, in order: fetch the JWK provider, fetch the stored configuration, parse the
/// token header, fetch the key for the header's `kid`, build the validation settings
/// from the provider issuer / clock skew / client ids, parse the claims with that key
/// and validation, run [`Self::validate_claims`], then `T::try_from`.
///
/// # Errors
/// - `JwtError::ConfigurationError` (parameter `provider_id`) when the provider or its
///   configuration is missing, or the registry lookup fails for another reason.
/// - `JwtError::ConfigurationError` (parameter `validation`) when the validation
///   settings cannot be built.
/// - A `JwtError` wrapped by the registry is returned unchanged.
/// - Any error from header/claim parsing, key retrieval, claim validation or the final
///   conversion.
async fn verify_generic_with_provider<T>(
    &self,
    token: &str,
    provider_id: &str,
) -> Result<T, JwtError>
where
    T: DeserializeOwned + TryFrom<OidcJwtClaims, Error = JwtError>,
{
    // Shared constructor for the lookup failures that are all reported against `provider_id`.
    let provider_id_error = |message: String| JwtError::ConfigurationError {
        parameter: Some("provider_id".to_string()),
        error: message,
    };

    let jwk_provider = match self.jwk_registry.get(provider_id) {
        Ok(found) => found,
        Err(RegistryError::JwtError(inner)) => return Err(inner),
        Err(RegistryError::ProviderNotFound(_)) => {
            return Err(provider_id_error(format!(
                "Provider '{}' not found",
                provider_id
            )))
        }
        Err(unexpected) => return Err(provider_id_error(unexpected.to_string())),
    };

    let Some(config) = self.configs.get(provider_id) else {
        return Err(provider_id_error(format!(
            "Configuration for provider '{}' not found",
            provider_id
        )));
    };

    let header = TokenParser::parse_token_header(token)?;
    let key = jwk_provider.get_key(&header.kid).await?;

    let validation = match self.jwk_registry.create_validation_for_issuer(
        jwk_provider.get_issuer(),
        config.clock_skew,
        &config.client_ids,
    ) {
        Ok(settings) => settings,
        Err(RegistryError::JwtError(inner)) => return Err(inner),
        Err(unexpected) => {
            return Err(JwtError::ConfigurationError {
                parameter: Some("validation".to_string()),
                error: unexpected.to_string(),
            })
        }
    };

    let claims: OidcJwtClaims = TokenParser::parse_token_claims(token, &key, &validation)?;
    self.validate_claims(&claims, config)?;
    T::try_from(claims)
}
/// Applies the provider-specific checks to already-parsed claims.
///
/// Checks, in order:
/// 1. `claims.validate_issuer` against the configured issuer.
/// 2. `claims.validate_client_id` against the configured client ids.
/// 3. Every name in `config.required_claims`:
///    - `sub`, `iss`, `aud` must be non-empty;
///    - `auth_time`, `nonce` must be present;
///    - `exp`, `iat` are accepted without a further check;
///    - any other name must be a key of `claims.custom_claims`.
///
/// # Errors
/// `JwtError::InvalidClaim` describing the first check that fails.
fn validate_claims(
    &self,
    claims: &OidcJwtClaims,
    config: &OidcProviderConfig,
) -> Result<(), JwtError> {
    if !claims.validate_issuer(&config.issuer) {
        return Err(JwtError::InvalidClaim {
            claim: "iss".to_string(),
            reason: format!("Expected issuer {}, got {}", config.issuer, claims.iss),
            value: Some(claims.iss.clone()),
        });
    }
    if !claims.validate_client_id(&config.client_ids) {
        return Err(JwtError::InvalidClaim {
            claim: "aud/azp".to_string(),
            reason: format!("Client ID not allowed: {}", claims.aud),
            value: Some(claims.aud.clone()),
        });
    }
    for claim_name in &config.required_claims {
        // `Some(reason)` when the required claim is unusable, `None` when it is fine.
        let problem = match claim_name.as_str() {
            "sub" => claims
                .sub
                .is_empty()
                .then(|| "Subject claim is empty".to_string()),
            "iss" => claims
                .iss
                .is_empty()
                .then(|| "Issuer claim is empty".to_string()),
            // A required audience must carry a value, matching the `sub` / `iss` handling.
            "aud" => claims
                .aud
                .is_empty()
                .then(|| "Audience claim is empty".to_string()),
            // NOTE(review): no presence check is made for these two; they look like
            // non-optional numeric fields on the claims type, so they are assumed to be
            // enforced while the token is decoded — confirm.
            "exp" | "iat" => None,
            "auth_time" => claims
                .auth_time
                .is_none()
                .then(|| "Authentication time claim is missing".to_string()),
            "nonce" => claims
                .nonce
                .is_none()
                .then(|| "Nonce claim is missing".to_string()),
            _ => (!claims.custom_claims.contains_key(claim_name))
                .then(|| format!("Required claim {} is missing", claim_name)),
        };
        if let Some(reason) = problem {
            return Err(JwtError::InvalidClaim {
                claim: claim_name.clone(),
                reason,
                value: None,
            });
        }
    }
    Ok(())
}
}
/// `JwtVerifier` implementation that delegates to [`OidcJwtVerifier::verify`] and boxes
/// the typed claims behind the matching trait object.
impl JwtVerifier for OidcJwtVerifier {
    /// Verifies `token` as an ID token and returns its claims as `OidcIdTokenClaims`
    /// behind `Box<dyn IdTokenClaims>`.
    async fn verify_id_token(&self, token: &str) -> Result<Box<dyn IdTokenClaims>, JwtError> {
        self.verify::<OidcIdTokenClaims>(token)
            .await
            .map(|verified| Box::new(verified) as Box<dyn IdTokenClaims>)
    }

    /// Verifies `token` as an access token and returns its claims as
    /// `OidcAccessTokenClaims` behind `Box<dyn AccessTokenClaims>`.
    async fn verify_access_token(
        &self,
        token: &str,
    ) -> Result<Box<dyn AccessTokenClaims>, JwtError> {
        self.verify::<OidcAccessTokenClaims>(token)
            .await
            .map(|verified| Box::new(verified) as Box<dyn AccessTokenClaims>)
    }
}
/// Read-only accessors exposing `OidcIdTokenClaims` through the `IdTokenClaims` trait.
/// The standard claims are read from the embedded `base` value; the profile claims
/// are read from the struct's own fields.
impl IdTokenClaims for OidcIdTokenClaims {
/// Returns `base.sub`.
fn get_sub(&self) -> &str {
&self.base.sub
}
/// Returns `base.iss`.
fn get_iss(&self) -> &str {
&self.base.iss
}
/// Returns `base.aud`.
fn get_aud(&self) -> &str {
&self.base.aud
}
/// Returns `base.exp`.
fn get_exp(&self) -> u64 {
self.base.exp
}
/// Returns `base.iat`.
fn get_iat(&self) -> u64 {
self.base.iat
}
/// Returns the `email` field as a borrowed string, or `None` when it is not set.
fn get_email(&self) -> Option<&str> {
self.email.as_deref()
}
/// Returns the `email_verified` field, treating a missing value as `false`.
fn is_email_verified(&self) -> bool {
self.email_verified.unwrap_or(false)
}
/// Returns the `name` field as a borrowed string, or `None` when it is not set.
fn get_name(&self) -> Option<&str> {
self.name.as_deref()
}
}
/// Read-only accessors exposing `OidcAccessTokenClaims` through the `AccessTokenClaims`
/// trait. The standard claims are read from the embedded `base` value; `scope` and
/// `client_id` are read from the struct's own fields.
impl AccessTokenClaims for OidcAccessTokenClaims {
    /// Returns `base.sub`.
    fn get_sub(&self) -> &str {
        &self.base.sub
    }

    /// Returns `base.iss`.
    fn get_iss(&self) -> &str {
        &self.base.iss
    }

    /// Returns `base.aud`.
    fn get_aud(&self) -> &str {
        &self.base.aud
    }

    /// Returns `base.exp`.
    fn get_exp(&self) -> u64 {
        self.base.exp
    }

    /// Returns `base.iat`.
    fn get_iat(&self) -> u64 {
        self.base.iat
    }

    /// Splits the `scope` field on whitespace into owned scope names.
    ///
    /// Returns an empty vector when `scope` is `None`.
    fn get_scopes(&self) -> Vec<String> {
        self.scope
            .as_ref()
            .map(|granted| granted.split_whitespace().map(str::to_owned).collect())
            .unwrap_or_default()
    }

    /// Returns `true` when `scope` is exactly one of the whitespace-separated entries of
    /// the `scope` field.
    ///
    /// The entries are compared in place, so no intermediate `Vec` or `String` is built.
    fn has_scope(&self, scope: &str) -> bool {
        self.scope
            .as_ref()
            .is_some_and(|granted| granted.split_whitespace().any(|entry| entry == scope))
    }

    /// Returns the `client_id` field as a borrowed string, or `None` when it is not set.
    fn get_client_id(&self) -> Option<&str> {
        self.client_id.as_deref()
    }
}