use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use crate::cognito::config::VerifierConfig;
use crate::common::error::JwtError;
use crate::jwk::provider::JwkProvider;
use jsonwebtoken::Validation;
/// Errors returned by [`JwkProviderRegistry`] operations.
#[derive(Debug, thiserror::Error)]
pub enum RegistryError {
/// A provider is already registered under the given ID (payload: the ID).
#[error("Provider with ID '{0}' already exists")]
ProviderAlreadyExists(String),
/// No provider is registered under the given ID (payload: the ID).
#[error("Provider with ID '{0}' not found")]
ProviderNotFound(String),
/// No registered provider reports the given issuer (payload: the issuer).
#[error("No provider found for issuer '{0}'")]
IssuerNotFound(String),
/// An underlying [`JwtError`]; `#[from]` lets `?` convert it automatically.
#[error("JWT error: {0}")]
JwtError(#[from] JwtError),
}
/// Registry of [`JwkProvider`]s addressable by a caller-chosen ID or by issuer.
///
/// Each map is guarded by its own `RwLock`, so the registry's methods take
/// `&self` rather than `&mut self`.
#[derive(Debug)]
pub struct JwkProviderRegistry {
// Provider ID -> shared provider handle.
providers: RwLock<HashMap<String, Arc<JwkProvider>>>,
// Issuer string -> provider ID; a lookup index over `providers`.
issuer_to_id: RwLock<HashMap<String, String>>,
}
impl JwkProviderRegistry {
    /// Creates an empty registry with no providers and no issuer mappings.
    pub fn new() -> Self {
        Self {
            providers: RwLock::new(HashMap::new()),
            issuer_to_id: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `provider` under `id` and indexes it by its issuer.
    ///
    /// If another provider with the same issuer is already registered, the
    /// issuer index is overwritten to point at `id`.
    ///
    /// # Errors
    /// Returns [`RegistryError::ProviderAlreadyExists`] if `id` is already in use;
    /// in that case neither map is modified.
    ///
    /// # Panics
    /// Panics if an internal lock is poisoned.
    pub fn register(&self, id: &str, provider: JwkProvider) -> Result<(), RegistryError> {
        let issuer = provider.get_issuer().to_string();
        let mut providers = self.providers.write().unwrap();
        if providers.contains_key(id) {
            return Err(RegistryError::ProviderAlreadyExists(id.to_string()));
        }
        // Lock order throughout this impl is `providers` first, then
        // `issuer_to_id`; keep it that way to avoid lock-order inversions.
        let mut issuer_map = self.issuer_to_id.write().unwrap();
        issuer_map.insert(issuer, id.to_string());
        providers.insert(id.to_string(), Arc::new(provider));
        Ok(())
    }

    /// Builds a new [`JwkProvider`] from the given settings and registers it
    /// under `id`.
    ///
    /// # Errors
    /// Propagates the error from `JwkProvider::new` (as
    /// [`RegistryError::JwtError`]) and any error from [`Self::register`].
    pub fn register_new(
        &self,
        id: &str,
        region: &str,
        user_pool_id: &str,
        cache_duration: Duration,
    ) -> Result<(), RegistryError> {
        let provider = JwkProvider::new(region, user_pool_id, cache_duration)?;
        self.register(id, provider)
    }

    /// Returns a shared handle to the provider registered under `id`.
    ///
    /// # Errors
    /// Returns [`RegistryError::ProviderNotFound`] if no such provider exists.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn get(&self, id: &str) -> Result<Arc<JwkProvider>, RegistryError> {
        let providers = self.providers.read().unwrap();
        providers
            .get(id)
            .cloned()
            .ok_or_else(|| RegistryError::ProviderNotFound(id.to_string()))
    }

    /// Returns a shared handle to the provider whose issuer equals `issuer`.
    ///
    /// # Errors
    /// Returns [`RegistryError::IssuerNotFound`] if no provider matches, or
    /// [`RegistryError::ProviderNotFound`] if the matching provider is removed
    /// between the issuer lookup and the ID lookup.
    pub fn get_by_issuer(&self, issuer: &str) -> Result<Arc<JwkProvider>, RegistryError> {
        let id = self.find_provider_id_by_issuer(issuer)?;
        self.get(&id)
    }

    /// Resolves `issuer` to the ID of a registered provider.
    ///
    /// The issuer index is consulted first. On a miss, all providers are
    /// scanned and a hit is written back to the index.
    ///
    /// # Errors
    /// Returns [`RegistryError::IssuerNotFound`] if no provider reports `issuer`.
    ///
    /// # Panics
    /// Panics if an internal lock is poisoned.
    pub fn find_provider_id_by_issuer(&self, issuer: &str) -> Result<String, RegistryError> {
        // Fast path. The read guard is confined to this block so it is
        // released before the slow path needs the write lock on the same map.
        {
            let issuer_map = self.issuer_to_id.read().unwrap();
            if let Some(id) = issuer_map.get(issuer) {
                return Ok(id.clone());
            }
        }

        // Slow path: the index can lack an entry when several providers share
        // an issuer and the indexed one has been removed.
        let providers = self.providers.read().unwrap();
        let found = providers
            .iter()
            .find(|(_, provider)| provider.get_issuer() == issuer)
            .map(|(id, _)| id.clone());

        match found {
            Some(id) => {
                // `providers` is still read-locked here, so the provider cannot
                // be removed while its mapping is being recorded.
                self.issuer_to_id
                    .write()
                    .unwrap()
                    .insert(issuer.to_string(), id.clone());
                Ok(id)
            }
            None => Err(RegistryError::IssuerNotFound(issuer.to_string())),
        }
    }

    /// Removes the provider registered under `id`.
    ///
    /// The issuer index entry is dropped only if it still points at `id`, so a
    /// different provider sharing the same issuer keeps its mapping.
    ///
    /// # Errors
    /// Returns [`RegistryError::ProviderNotFound`] if `id` is not registered.
    ///
    /// # Panics
    /// Panics if an internal lock is poisoned.
    pub fn remove(&self, id: &str) -> Result<(), RegistryError> {
        let mut providers = self.providers.write().unwrap();
        // A single `remove` both checks existence and yields the provider.
        let removed = providers
            .remove(id)
            .ok_or_else(|| RegistryError::ProviderNotFound(id.to_string()))?;
        let issuer = removed.get_issuer().to_string();

        let mut issuer_map = self.issuer_to_id.write().unwrap();
        if issuer_map.get(&issuer).map(String::as_str) == Some(id) {
            issuer_map.remove(&issuer);
        }
        Ok(())
    }

    /// Prefetches keys for every provider registered at the time of the call,
    /// one after another, and returns each provider's ID with its outcome.
    ///
    /// The provider set is snapshotted up front: the `std::sync::RwLock` guard
    /// must not be held across `.await`, since that would make this future
    /// `!Send` and could block or deadlock writers on the same runtime thread.
    /// Providers registered or removed while hydration is running are therefore
    /// not reflected in the result.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub async fn hydrate(&self) -> Vec<(String, Result<(), JwtError>)> {
        let snapshot: Vec<(String, Arc<JwkProvider>)> = {
            let providers = self.providers.read().unwrap();
            providers
                .iter()
                .map(|(id, provider)| (id.clone(), Arc::clone(provider)))
                .collect()
        }; // read guard dropped here, before any await point

        let mut results = Vec::with_capacity(snapshot.len());
        for (id, provider) in snapshot {
            let result = provider.prefetch_keys().await;
            results.push((id, result));
        }
        results
    }

    /// Prefetches keys for the single provider registered under `id`.
    ///
    /// # Errors
    /// Returns [`RegistryError::ProviderNotFound`] if `id` is not registered, or
    /// [`RegistryError::JwtError`] if the prefetch itself fails.
    pub async fn prefetch(&self, id: &str) -> Result<(), RegistryError> {
        // `get` returns an owned `Arc`, so no lock is held across the await.
        let provider = self.get(id)?;
        provider.prefetch_keys().await?;
        Ok(())
    }

    /// Returns the number of registered providers.
    pub fn count(&self) -> usize {
        self.providers.read().unwrap().len()
    }

    /// Returns `true` if a provider is registered under `id`.
    pub fn contains(&self, id: &str) -> bool {
        self.providers.read().unwrap().contains_key(id)
    }

    /// Returns the IDs of all registered providers, in unspecified order.
    pub fn list_ids(&self) -> Vec<String> {
        self.providers.read().unwrap().keys().cloned().collect()
    }

    /// Builds an RS256 [`Validation`] for tokens issued by `issuer`.
    ///
    /// The validation pins the issuer to the matching provider's issuer, sets
    /// the audience to `client_ids`, enables `exp`/`nbf`/`aud` validation, and
    /// uses `clock_skew` (whole seconds) as leeway.
    ///
    /// # Errors
    /// Returns the error from [`Self::get_by_issuer`] if no provider matches.
    pub fn create_validation_for_issuer(
        &self,
        issuer: &str,
        clock_skew: Duration,
        client_ids: &Vec<String>,
    ) -> Result<Validation, RegistryError> {
        let provider = self.get_by_issuer(issuer)?;
        let mut validation = Validation::new(jsonwebtoken::Algorithm::RS256);
        validation.set_issuer(&[provider.get_issuer().to_string()]);
        validation.set_audience(client_ids);
        // NOTE(review): jsonwebtoken's docs list only exp/nbf/aud/iss/sub as
        // honored required spec claims, so "iat" here is presumably ignored —
        // confirm against the crate version in use.
        validation.set_required_spec_claims(&["exp", "iat", "iss", "sub"]);
        validation.validate_exp = true;
        validation.validate_nbf = true;
        validation.validate_aud = true;
        // `as_secs` already yields `u64`; sub-second skew is truncated.
        validation.leeway = clock_skew.as_secs();
        Ok(validation)
    }
}
impl Default for JwkProviderRegistry {
fn default() -> Self {
Self::new()
}
}