jwt-verify 0.1.3

JWT verification library for AWS Cognito tokens and any OIDC-compatible identity provider (IdP)
Documentation
use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, RwLock};
use std::time::Duration;

use crate::cognito::config::VerifierConfig;
use crate::common::error::JwtError;
use crate::jwk::provider::JwkProvider;
use jsonwebtoken::Validation;

/// Error types for JWK Provider Registry operations.
///
/// Every variant carries the offending identifier (provider ID or issuer URL) for its
/// `Display` message; `JwtError` wraps errors from the JWT/JWK layer and is produced
/// automatically by `?` through the derived `From<JwtError>` impl.
#[derive(Debug, thiserror::Error)]
pub enum RegistryError {
    /// Provider with the specified ID already exists (returned by registration)
    #[error("Provider with ID '{0}' already exists")]
    ProviderAlreadyExists(String),

    /// Provider with the specified ID was not found
    #[error("Provider with ID '{0}' not found")]
    ProviderNotFound(String),

    /// No provider found for the specified issuer URL
    #[error("No provider found for issuer '{0}'")]
    IssuerNotFound(String),

    /// JWT error occurred (converted via `#[from]`, so `?` on a `JwtError` yields this)
    #[error("JWT error: {0}")]
    JwtError(#[from] JwtError),
}

/// Registry for managing multiple JWK providers.
///
/// Providers are keyed by a caller-chosen ID. A secondary index maps each provider's
/// issuer URL to its ID so a provider can also be looked up by issuer. Both maps sit
/// behind `RwLock`s, so the registry is mutated through `&self`.
#[derive(Debug)]
pub struct JwkProviderRegistry {
    /// Map of provider IDs to JWK providers. Values are `Arc`s so lookups can hand out
    /// cheap clones that outlive the read guard.
    providers: RwLock<HashMap<String, Arc<JwkProvider>>>,
    /// Map of issuer URLs to provider IDs; an index over `providers` that
    /// `find_provider_id_by_issuer` repopulates from a full scan on a miss.
    issuer_to_id: RwLock<HashMap<String, String>>,
}

impl JwkProviderRegistry {
    /// Create an empty JWK provider registry.
    pub fn new() -> Self {
        Self {
            providers: RwLock::new(HashMap::new()),
            issuer_to_id: RwLock::new(HashMap::new()),
        }
    }

    /// Register `provider` under `id` and index it by its issuer URL.
    ///
    /// If another provider is already indexed under the same issuer, the issuer index is
    /// re-pointed at this registration; the older provider remains reachable by ID.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::ProviderAlreadyExists`] if `id` is already registered.
    ///
    /// # Panics
    ///
    /// Panics if an internal lock is poisoned.
    pub fn register(&self, id: &str, provider: JwkProvider) -> Result<(), RegistryError> {
        let issuer = provider.get_issuer().to_string();

        // Whenever both locks are held, `providers` is taken before `issuer_to_id`; this
        // keeps the two maps updated together and avoids lock-order deadlocks.
        let mut providers = self.providers.write().unwrap();
        if providers.contains_key(id) {
            return Err(RegistryError::ProviderAlreadyExists(id.to_string()));
        }

        let mut issuer_map = self.issuer_to_id.write().unwrap();
        issuer_map.insert(issuer, id.to_string());
        providers.insert(id.to_string(), Arc::new(provider));

        Ok(())
    }

    /// Build a provider with `JwkProvider::new(region, user_pool_id, cache_duration)` and
    /// register it under `id`.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::JwtError`] if the provider cannot be constructed, or
    /// [`RegistryError::ProviderAlreadyExists`] if `id` is already registered.
    pub fn register_new(
        &self,
        id: &str,
        region: &str,
        user_pool_id: &str,
        cache_duration: Duration,
    ) -> Result<(), RegistryError> {
        let provider = JwkProvider::new(region, user_pool_id, cache_duration)?;
        self.register(id, provider)
    }

    /// Get a JWK provider by ID.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::ProviderNotFound`] if no provider has this ID.
    pub fn get(&self, id: &str) -> Result<Arc<JwkProvider>, RegistryError> {
        let providers = self.providers.read().unwrap();
        providers
            .get(id)
            .cloned()
            .ok_or_else(|| RegistryError::ProviderNotFound(id.to_string()))
    }

    /// Get a JWK provider by issuer URL.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::IssuerNotFound`] if no provider has this issuer, or
    /// [`RegistryError::ProviderNotFound`] if the provider is removed concurrently between
    /// the issuer lookup and the ID lookup.
    pub fn get_by_issuer(&self, issuer: &str) -> Result<Arc<JwkProvider>, RegistryError> {
        let id = self.find_provider_id_by_issuer(issuer)?;
        self.get(&id)
    }

    /// Find a provider ID by issuer URL.
    ///
    /// Consults the issuer index first; on a miss, scans all providers and caches the
    /// match in the index.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::IssuerNotFound`] if no registered provider has this issuer.
    pub fn find_provider_id_by_issuer(&self, issuer: &str) -> Result<String, RegistryError> {
        // Fast path. Scoped so the read guard is released before the slow path asks for
        // the write lock on the same map.
        {
            let issuer_map = self.issuer_to_id.read().unwrap();
            if let Some(id) = issuer_map.get(issuer) {
                return Ok(id.clone());
            }
        }

        // Slow path: scan every provider. The `providers` guard stays alive while the
        // index is updated so a concurrent `remove` cannot leave a mapping to a dead ID.
        let providers = self.providers.read().unwrap();
        let id = providers
            .iter()
            .find(|(_, provider)| provider.get_issuer() == issuer)
            .map(|(id, _)| id.clone())
            .ok_or_else(|| RegistryError::IssuerNotFound(issuer.to_string()))?;

        self.issuer_to_id
            .write()
            .unwrap()
            .insert(issuer.to_string(), id.clone());

        Ok(id)
    }

    /// Remove a JWK provider and its issuer-index entry.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::ProviderNotFound`] if no provider has this ID.
    pub fn remove(&self, id: &str) -> Result<(), RegistryError> {
        let mut providers = self.providers.write().unwrap();
        let removed = providers
            .remove(id)
            .ok_or_else(|| RegistryError::ProviderNotFound(id.to_string()))?;
        let issuer = removed.get_issuer().to_string();

        // Only drop the index entry if it still points at this provider. When two providers
        // share an issuer, the later registration owns the entry and must keep it.
        let mut issuer_map = self.issuer_to_id.write().unwrap();
        if issuer_map.get(&issuer).map(String::as_str) == Some(id) {
            issuer_map.remove(&issuer);
        }

        Ok(())
    }

    /// Prefetch keys for all providers, returning each provider's ID with its result.
    ///
    /// Providers are fetched sequentially. The set of providers is snapshotted up front, so
    /// registrations or removals made while this runs are not reflected in the result.
    pub async fn hydrate(&self) -> Vec<(String, Result<(), JwtError>)> {
        // Clone the `Arc`s out and release the lock before awaiting. A `std::sync` guard
        // held across `.await` makes this future `!Send` and blocks `register`/`remove` for
        // the whole fetch. On a single-threaded executor that blocking can deadlock.
        let snapshot: Vec<(String, Arc<JwkProvider>)> = {
            let providers = self.providers.read().unwrap();
            providers
                .iter()
                .map(|(id, provider)| (id.clone(), Arc::clone(provider)))
                .collect()
        };

        let mut results = Vec::with_capacity(snapshot.len());
        for (id, provider) in snapshot {
            let result = provider.prefetch_keys().await;
            results.push((id, result));
        }

        results
    }

    /// Prefetch keys for a specific provider.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::ProviderNotFound`] for an unknown ID, or
    /// [`RegistryError::JwtError`] if the prefetch fails.
    pub async fn prefetch(&self, id: &str) -> Result<(), RegistryError> {
        let provider = self.get(id)?;
        provider.prefetch_keys().await?;
        Ok(())
    }

    /// Get the number of registered providers.
    pub fn count(&self) -> usize {
        self.providers.read().unwrap().len()
    }

    /// Check if a provider with the given ID exists.
    pub fn contains(&self, id: &str) -> bool {
        self.providers.read().unwrap().contains_key(id)
    }

    /// Get a list of all provider IDs (in unspecified order).
    pub fn list_ids(&self) -> Vec<String> {
        self.providers.read().unwrap().keys().cloned().collect()
    }

    /// Create a `Validation` for tokens from `issuer`.
    ///
    /// The validation accepts RS256 only. It pins the issuer to the provider's issuer and
    /// the audience to `client_ids`. It requires `exp`, `iat`, `iss` and `sub`, checks
    /// `exp`, `nbf` and `aud`, and allows `clock_skew` (whole seconds) of leeway.
    ///
    /// # Errors
    ///
    /// Propagates the errors of [`Self::get_by_issuer`].
    pub fn create_validation_for_issuer(
        &self,
        issuer: &str,
        clock_skew: Duration,
        client_ids: &[String],
    ) -> Result<Validation, RegistryError> {
        let provider = self.get_by_issuer(issuer)?;

        let mut validation = Validation::new(jsonwebtoken::Algorithm::RS256);
        validation.set_issuer(&[provider.get_issuer().to_string()]);
        validation.set_audience(client_ids);
        validation.set_required_spec_claims(&["exp", "iat", "iss", "sub"]);
        validation.validate_exp = true;
        validation.validate_nbf = true;
        validation.validate_aud = true;
        // Sub-second skew is truncated; `as_secs` already yields `u64`.
        validation.leeway = clock_skew.as_secs();

        Ok(validation)
    }
}

impl Default for JwkProviderRegistry {
    fn default() -> Self {
        Self::new()
    }
}