// jwt_verify/jwk/registry.rs
use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::time::Duration;

use jsonwebtoken::Validation;

use crate::cognito::config::VerifierConfig;
use crate::common::error::JwtError;
use crate::jwk::provider::JwkProvider;

11#[derive(Debug, thiserror::Error)]
13pub enum RegistryError {
14 #[error("Provider with ID '{0}' already exists")]
16 ProviderAlreadyExists(String),
17
18 #[error("Provider with ID '{0}' not found")]
20 ProviderNotFound(String),
21
22 #[error("No provider found for issuer '{0}'")]
24 IssuerNotFound(String),
25
26 #[error("JWT error: {0}")]
28 JwtError(#[from] JwtError),
29}
30
31#[derive(Debug)]
33pub struct JwkProviderRegistry {
34 providers: RwLock<HashMap<String, Arc<JwkProvider>>>,
36 issuer_to_id: RwLock<HashMap<String, String>>,
38}
39
40impl JwkProviderRegistry {
41 pub fn new() -> Self {
43 Self {
44 providers: RwLock::new(HashMap::new()),
45 issuer_to_id: RwLock::new(HashMap::new()),
46 }
47 }
48
49 pub fn register(&self, id: &str, provider: JwkProvider) -> Result<(), RegistryError> {
51 let issuer = provider.get_issuer().to_string();
52
53 let mut providers = self.providers.write().unwrap();
55 if providers.contains_key(id) {
56 return Err(RegistryError::ProviderAlreadyExists(id.to_string()));
57 }
58
59 let mut issuer_map = self.issuer_to_id.write().unwrap();
61 issuer_map.insert(issuer, id.to_string());
62
63 providers.insert(id.to_string(), Arc::new(provider));
65
66 Ok(())
67 }
68
69 pub fn register_new(
71 &self,
72 id: &str,
73 region: &str,
74 user_pool_id: &str,
75 cache_duration: Duration,
76 ) -> Result<(), RegistryError> {
77 let provider = JwkProvider::new(region, user_pool_id, cache_duration)?;
78 self.register(id, provider)
79 }
80
81 pub fn get(&self, id: &str) -> Result<Arc<JwkProvider>, RegistryError> {
83 let providers = self.providers.read().unwrap();
84 providers
85 .get(id)
86 .cloned()
87 .ok_or_else(|| RegistryError::ProviderNotFound(id.to_string()))
88 }
89
90 pub fn get_by_issuer(&self, issuer: &str) -> Result<Arc<JwkProvider>, RegistryError> {
92 let id = self.find_provider_id_by_issuer(issuer)?;
94
95 self.get(&id)
97 }
98
99 pub fn find_provider_id_by_issuer(&self, issuer: &str) -> Result<String, RegistryError> {
101 {
103 let issuer_map = self.issuer_to_id.read().unwrap();
104 if let Some(id) = issuer_map.get(issuer) {
105 return Ok(id.clone());
106 }
107 }
108
109 let providers = self.providers.read().unwrap();
111 for (id, provider) in providers.iter() {
112 if provider.get_issuer() == issuer {
113 let mut issuer_map = self.issuer_to_id.write().unwrap();
115 issuer_map.insert(issuer.to_string(), id.clone());
116
117 return Ok(id.clone());
118 }
119 }
120
121 Err(RegistryError::IssuerNotFound(issuer.to_string()))
123 }
124
125 pub fn remove(&self, id: &str) -> Result<(), RegistryError> {
129 let mut providers = self.providers.write().unwrap();
130 if !providers.contains_key(id) {
131 return Err(RegistryError::ProviderNotFound(id.to_string()));
132 }
133
134 let issuer = providers.get(id).map(|p| p.get_issuer().to_string());
136
137 providers.remove(id);
139
140 if let Some(issuer) = issuer {
142 let mut issuer_map = self.issuer_to_id.write().unwrap();
143 issuer_map.remove(&issuer);
144 }
145
146 Ok(())
147 }
148
149 pub async fn hydrate(&self) -> Vec<(String, Result<(), JwtError>)> {
151 let providers = self.providers.read().unwrap();
152 let mut results = Vec::new();
153
154 for (id, provider) in providers.iter() {
155 let result = provider.prefetch_keys().await;
156 results.push((id.clone(), result));
157 }
158
159 results
160 }
161
162 pub async fn prefetch(&self, id: &str) -> Result<(), RegistryError> {
164 let provider = self.get(id)?;
165 provider.prefetch_keys().await?;
166 Ok(())
167 }
168
169 pub fn count(&self) -> usize {
171 let providers = self.providers.read().unwrap();
172 providers.len()
173 }
174
175 pub fn contains(&self, id: &str) -> bool {
177 let providers = self.providers.read().unwrap();
178 providers.contains_key(id)
179 }
180
181 pub fn list_ids(&self) -> Vec<String> {
183 let providers = self.providers.read().unwrap();
184 providers.keys().cloned().collect()
185 }
186
187 pub fn create_validation_for_issuer(
189 &self,
190 issuer: &str,
191 clock_skew: Duration,
192 client_ids: &Vec<String>,
193 ) -> Result<Validation, RegistryError> {
194 let provider = self.get_by_issuer(issuer)?;
196
197 let mut validation = Validation::new(jsonwebtoken::Algorithm::RS256);
199
200 validation.set_issuer(&[provider.get_issuer().to_string()]);
201 validation.set_audience(client_ids);
202 validation.set_required_spec_claims(&["exp", "iat", "iss", "sub"]);
203 validation.validate_exp = true;
204 validation.validate_nbf = true;
205 validation.validate_aud = true;
206 validation.leeway = clock_skew.as_secs() as u64;
207
208 Ok(validation)
209 }
210}
211
212impl Default for JwkProviderRegistry {
213 fn default() -> Self {
214 Self::new()
215 }
216}