jwt_verify/jwk/registry.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::sync::{Arc, RwLock};
4use std::time::Duration;
5
6use crate::cognito::config::VerifierConfig;
7use crate::common::error::JwtError;
8use crate::jwk::provider::JwkProvider;
9use jsonwebtoken::Validation;
10
11/// Error types for JWK Provider Registry operations
12#[derive(Debug, thiserror::Error)]
13pub enum RegistryError {
14    /// Provider with the specified ID already exists
15    #[error("Provider with ID '{0}' already exists")]
16    ProviderAlreadyExists(String),
17
18    /// Provider with the specified ID was not found
19    #[error("Provider with ID '{0}' not found")]
20    ProviderNotFound(String),
21
22    /// No provider found for the specified issuer
23    #[error("No provider found for issuer '{0}'")]
24    IssuerNotFound(String),
25
26    /// JWT error occurred
27    #[error("JWT error: {0}")]
28    JwtError(#[from] JwtError),
29}
30
31/// Registry for managing multiple JWK providers
32#[derive(Debug)]
33pub struct JwkProviderRegistry {
34    /// Map of provider IDs to JWK providers
35    providers: RwLock<HashMap<String, Arc<JwkProvider>>>,
36    /// Map of issuer URLs to provider IDs
37    issuer_to_id: RwLock<HashMap<String, String>>,
38}
39
40impl JwkProviderRegistry {
41    /// Create a new JWK provider registry
42    pub fn new() -> Self {
43        Self {
44            providers: RwLock::new(HashMap::new()),
45            issuer_to_id: RwLock::new(HashMap::new()),
46        }
47    }
48
49    /// Register a new JWK provider
50    pub fn register(&self, id: &str, provider: JwkProvider) -> Result<(), RegistryError> {
51        let issuer = provider.get_issuer().to_string();
52
53        // Update providers map
54        let mut providers = self.providers.write().unwrap();
55        if providers.contains_key(id) {
56            return Err(RegistryError::ProviderAlreadyExists(id.to_string()));
57        }
58
59        // Update issuer to ID map
60        let mut issuer_map = self.issuer_to_id.write().unwrap();
61        issuer_map.insert(issuer, id.to_string());
62
63        // Add provider to registry
64        providers.insert(id.to_string(), Arc::new(provider));
65
66        Ok(())
67    }
68
69    /// Create and register a new JWK provider
70    pub fn register_new(
71        &self,
72        id: &str,
73        region: &str,
74        user_pool_id: &str,
75        cache_duration: Duration,
76    ) -> Result<(), RegistryError> {
77        let provider = JwkProvider::new(region, user_pool_id, cache_duration)?;
78        self.register(id, provider)
79    }
80
81    /// Get a JWK provider by ID
82    pub fn get(&self, id: &str) -> Result<Arc<JwkProvider>, RegistryError> {
83        let providers = self.providers.read().unwrap();
84        providers
85            .get(id)
86            .cloned()
87            .ok_or_else(|| RegistryError::ProviderNotFound(id.to_string()))
88    }
89
90    /// Get a JWK provider by issuer URL
91    pub fn get_by_issuer(&self, issuer: &str) -> Result<Arc<JwkProvider>, RegistryError> {
92        // Look up the provider ID for this issuer
93        let id = self.find_provider_id_by_issuer(issuer)?;
94
95        // Get the provider by ID
96        self.get(&id)
97    }
98
99    /// Find a provider ID by issuer URL
100    pub fn find_provider_id_by_issuer(&self, issuer: &str) -> Result<String, RegistryError> {
101        // First, try to get the provider ID directly from the issuer map
102        {
103            let issuer_map = self.issuer_to_id.read().unwrap();
104            if let Some(id) = issuer_map.get(issuer) {
105                return Ok(id.clone());
106            }
107        }
108
109        // If not found in the map, check all providers
110        let providers = self.providers.read().unwrap();
111        for (id, provider) in providers.iter() {
112            if provider.get_issuer() == issuer {
113                // Update the issuer map for future lookups
114                let mut issuer_map = self.issuer_to_id.write().unwrap();
115                issuer_map.insert(issuer.to_string(), id.clone());
116
117                return Ok(id.clone());
118            }
119        }
120
121        // If still not found, return an error
122        Err(RegistryError::IssuerNotFound(issuer.to_string()))
123    }
124
125    // Removed default provider methods as we're using issuer-based provider selection only
126
127    /// Remove a JWK provider
128    pub fn remove(&self, id: &str) -> Result<(), RegistryError> {
129        let mut providers = self.providers.write().unwrap();
130        if !providers.contains_key(id) {
131            return Err(RegistryError::ProviderNotFound(id.to_string()));
132        }
133
134        // Get the issuer for this provider to remove from the issuer map
135        let issuer = providers.get(id).map(|p| p.get_issuer().to_string());
136
137        // Remove from providers map
138        providers.remove(id);
139
140        // Remove from issuer map if found
141        if let Some(issuer) = issuer {
142            let mut issuer_map = self.issuer_to_id.write().unwrap();
143            issuer_map.remove(&issuer);
144        }
145
146        Ok(())
147    }
148
149    /// Prefetch keys for all providers
150    pub async fn hydrate(&self) -> Vec<(String, Result<(), JwtError>)> {
151        let providers = self.providers.read().unwrap();
152        let mut results = Vec::new();
153
154        for (id, provider) in providers.iter() {
155            let result = provider.prefetch_keys().await;
156            results.push((id.clone(), result));
157        }
158
159        results
160    }
161
162    /// Prefetch keys for a specific provider
163    pub async fn prefetch(&self, id: &str) -> Result<(), RegistryError> {
164        let provider = self.get(id)?;
165        provider.prefetch_keys().await?;
166        Ok(())
167    }
168
169    /// Get the number of registered providers
170    pub fn count(&self) -> usize {
171        let providers = self.providers.read().unwrap();
172        providers.len()
173    }
174
175    /// Check if a provider with the given ID exists
176    pub fn contains(&self, id: &str) -> bool {
177        let providers = self.providers.read().unwrap();
178        providers.contains_key(id)
179    }
180
181    /// Get a list of all provider IDs
182    pub fn list_ids(&self) -> Vec<String> {
183        let providers = self.providers.read().unwrap();
184        providers.keys().cloned().collect()
185    }
186
187    /// Create a validation object for a specific issuer
188    pub fn create_validation_for_issuer(
189        &self,
190        issuer: &str,
191        clock_skew: Duration,
192        client_ids: &Vec<String>,
193    ) -> Result<Validation, RegistryError> {
194        // Find provider for issuer
195        let provider = self.get_by_issuer(issuer)?;
196
197        // Create validation
198        let mut validation = Validation::new(jsonwebtoken::Algorithm::RS256);
199
200        validation.set_issuer(&[provider.get_issuer().to_string()]);
201        validation.set_audience(client_ids);
202        validation.set_required_spec_claims(&["exp", "iat", "iss", "sub"]);
203        validation.validate_exp = true;
204        validation.validate_nbf = true;
205        validation.validate_aud = true;
206        validation.leeway = clock_skew.as_secs() as u64;
207
208        Ok(validation)
209    }
210}
211
212impl Default for JwkProviderRegistry {
213    fn default() -> Self {
214        Self::new()
215    }
216}