1use anyhow::Result;
2use jsonwebtoken::DecodingKey;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fmt::Debug;
7use std::sync::RwLock;
8use std::time::{Duration, Instant};
9use tokio::sync::Mutex;
10
11use crate::common::error::JwtError;
12
13#[derive(Debug, Clone, Deserialize, Serialize)]
15pub struct RSAKey {
16 pub kid: String,
18 pub alg: String,
20 pub n: String,
22 pub e: String,
24 #[serde(rename = "use")]
26 pub use_for: String,
27}
28
29#[derive(Debug, Deserialize)]
31pub struct JwkSet {
32 pub keys: Vec<RSAKey>,
34}
35
36struct CachedJwk {
38 key: DecodingKey,
40 inserted_at: Instant,
42}
43
44impl std::fmt::Debug for CachedJwk {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.debug_struct("CachedJwk")
47 .field("inserted_at", &self.inserted_at)
48 .finish()
49 }
50}
51
52#[derive(Debug)]
54pub struct JwkProvider {
55 jwk_url: String,
57 issuer: String,
59 keys_cache: RwLock<HashMap<String, CachedJwk>>,
61 last_refresh: Mutex<Option<Instant>>,
63 cache_duration: Duration,
65 min_refresh_interval: Duration,
67 client: Client,
69}
70
71impl JwkProvider {
72 pub fn new(
77 region: &str,
78 user_pool_id: &str,
79 cache_duration: Duration,
80 ) -> Result<Self, JwtError> {
81 if region.is_empty() {
84 return Err(JwtError::ConfigurationError {
85 parameter: Some("region".to_string()),
86 error: "Region cannot be empty".to_string(),
87 });
88 }
89
90 if user_pool_id.is_empty() {
92 return Err(JwtError::ConfigurationError {
93 parameter: Some("user_pool_id".to_string()),
94 error: "User pool ID cannot be empty".to_string(),
95 });
96 }
97
98 if !Self::is_valid_user_pool_id(user_pool_id) {
102 return Err(JwtError::ConfigurationError {
103 parameter: Some("user_pool_id".to_string()),
104 error: format!(
105 "Invalid user pool ID format: {}. Expected format: region_number",
106 user_pool_id
107 ),
108 });
109 }
110
111 let issuer = format!(
113 "https://cognito-idp.{}.amazonaws.com/{}",
114 region, user_pool_id
115 );
116 let jwk_url = format!("{}/.well-known/jwks.json", issuer);
117
118 let client = Client::builder().use_rustls_tls().build().map_err(|e| {
120 JwtError::ConfigurationError {
121 parameter: Some("http_client".to_string()),
122 error: format!("Failed to create HTTP client: {}", e),
123 }
124 })?;
125
126 let provider = Self {
127 jwk_url,
128 issuer,
129 keys_cache: RwLock::new(HashMap::new()),
130 last_refresh: Mutex::new(None),
131 cache_duration,
132 min_refresh_interval: Duration::from_secs(60), client,
134 };
135
136 Ok(provider)
137 }
138
139 pub fn from_jwks_url(
149 jwks_url: &str,
150 issuer: &str,
151 cache_duration: Duration,
152 ) -> Result<Self, JwtError> {
153 if jwks_url.is_empty() {
155 return Err(JwtError::ConfigurationError {
156 parameter: Some("jwks_url".to_string()),
157 error: "JWKS URL cannot be empty".to_string(),
158 });
159 }
160
161 if issuer.is_empty() {
163 return Err(JwtError::ConfigurationError {
164 parameter: Some("issuer".to_string()),
165 error: "Issuer cannot be empty".to_string(),
166 });
167 }
168
169 if !jwks_url.starts_with("http://") && !jwks_url.starts_with("https://") {
171 return Err(JwtError::ConfigurationError {
172 parameter: Some("jwks_url".to_string()),
173 error: "JWKS URL must start with http:// or https://".to_string(),
174 });
175 }
176
177 let client = Client::builder().use_rustls_tls().build().map_err(|e| {
179 JwtError::ConfigurationError {
180 parameter: Some("http_client".to_string()),
181 error: format!("Failed to create HTTP client: {}", e),
182 }
183 })?;
184
185 let provider = Self {
186 jwk_url: jwks_url.to_string(),
187 issuer: issuer.to_string(),
188 keys_cache: RwLock::new(HashMap::new()),
189 last_refresh: Mutex::new(None),
190 cache_duration,
191 min_refresh_interval: Duration::from_secs(60), client,
193 };
194
195 Ok(provider)
196 }
197
198 fn is_valid_user_pool_id(user_pool_id: &str) -> bool {
200 let parts: Vec<&str> = user_pool_id.split('_').collect();
203
204 if parts.len() != 2 {
205 return false;
206 }
207
208 parts[1].chars().all(|c| c.is_alphanumeric())
210 }
211
212 #[cfg(test)]
214 pub fn new_with_base_url(
215 base_url: &str,
216 issuer: &str,
217 cache_duration: Duration,
218 ) -> Result<Self, JwtError> {
219 Self::new_with_base_url_and_refresh_interval(
221 base_url,
222 issuer,
223 cache_duration,
224 Duration::from_secs(1),
225 )
226 }
227
228 #[cfg(test)]
229 pub fn new_with_base_url_and_refresh_interval(
230 base_url: &str,
231 issuer: &str,
232 cache_duration: Duration,
233 min_refresh_interval: Duration,
234 ) -> Result<Self, JwtError> {
235 let jwk_url = format!("{}/.well-known/jwks.json", base_url);
237
238 let client = Client::builder().use_rustls_tls().build().map_err(|e| {
240 JwtError::ConfigurationError {
241 parameter: Some("http_client".to_string()),
242 error: format!("Failed to create HTTP client: {}", e),
243 }
244 })?;
245
246 let provider = Self {
247 jwk_url,
248 issuer: issuer.to_string(),
249 keys_cache: RwLock::new(HashMap::new()),
250 last_refresh: Mutex::new(None),
251 cache_duration,
252 min_refresh_interval,
253 client,
254 };
255
256 Ok(provider)
257 }
258
259 pub fn get_issuer(&self) -> &str {
261 &self.issuer
262 }
263
264 pub async fn prefetch_keys(&self) -> Result<(), JwtError> {
267 self.refresh_keys().await
268 }
269
270 pub async fn get_key(&self, kid: &str) -> Result<DecodingKey, JwtError> {
272 {
274 let cache = self.keys_cache.read().unwrap();
275 if let Some(cached_jwk) = cache.get(kid) {
276 let now = Instant::now();
277 if now.duration_since(cached_jwk.inserted_at) < self.cache_duration {
278 return Ok(cached_jwk.key.clone());
280 }
281 } else {
283 }
285 }
286
287 self.refresh_keys().await?;
289
290 {
292 let cache = self.keys_cache.read().unwrap();
293 if let Some(cached_jwk) = cache.get(kid) {
294 return Ok(cached_jwk.key.clone());
295 }
296 }
297
298 Err(JwtError::KeyNotFound(kid.to_string()))
300 }
301
302 async fn refresh_keys(&self) -> Result<(), JwtError> {
304 {
306 let mut last_refresh = self.last_refresh.lock().await;
307 if let Some(time) = *last_refresh {
308 let now = Instant::now();
309 if now.duration_since(time) < self.min_refresh_interval {
310 tracing::debug!(
312 "Skipping JWK refresh, last refresh was less than {:?} ago",
313 self.min_refresh_interval
314 );
315 return Ok(());
316 }
317 }
318
319 *last_refresh = Some(Instant::now());
321 }
322
323 tracing::debug!("Fetching JWKs from {}", self.jwk_url);
324
325 let mut retry_count = 0;
327 let max_retries = 3;
328 let mut last_error = None;
329
330 while retry_count < max_retries {
331 match self.fetch_and_parse_jwks().await {
332 Ok(()) => {
333 tracing::debug!("Successfully refreshed JWKs");
334 return Ok(());
335 }
336 Err(e) => {
337 match &e {
339 JwtError::JwksFetchError { .. } => {
340 retry_count += 1;
341 if retry_count < max_retries {
342 tracing::warn!(
343 "Failed to fetch JWKs (attempt {}/{}): {}. Retrying...",
344 retry_count,
345 max_retries,
346 e
347 );
348 tokio::time::sleep(Duration::from_millis(500 * (1 << retry_count)))
349 .await;
350 }
351 last_error = Some(e);
352 }
353 _ => {
354 return Err(e);
356 }
357 }
358 }
359 }
360 }
361
362 Err(last_error.unwrap_or_else(|| JwtError::JwksFetchError {
364 url: Some(self.jwk_url.clone()),
365 error: "Failed to fetch JWKs after multiple attempts".to_string(),
366 }))
367 }
368
369 fn prune_expired_keys(&self, cache: &mut HashMap<String, CachedJwk>, now: Instant) {
371 let expired_keys: Vec<String> = cache
372 .iter()
373 .filter(|(_, cached_jwk)| {
374 now.duration_since(cached_jwk.inserted_at) >= self.cache_duration
375 })
376 .map(|(kid, _)| kid.clone())
377 .collect();
378
379 if !expired_keys.is_empty() {
380 tracing::debug!("Pruning {} expired keys from cache", expired_keys.len());
381 for kid in expired_keys {
382 cache.remove(&kid);
383 }
384 }
385 }
386
387 async fn fetch_and_parse_jwks(&self) -> Result<(), JwtError> {
389 let response = self
391 .client
392 .get(&self.jwk_url)
393 .timeout(Duration::from_secs(5))
394 .send()
395 .await
396 .map_err(|e| {
397 let error_msg = if e.is_timeout() {
398 "Request timed out".to_string()
399 } else if e.is_connect() {
400 "Connection error".to_string()
401 } else {
402 format!("HTTP request failed: {}", e)
403 };
404
405 JwtError::JwksFetchError {
406 url: Some(self.jwk_url.clone()),
407 error: error_msg,
408 }
409 })?;
410
411 if !response.status().is_success() {
413 return Err(JwtError::JwksFetchError {
414 url: Some(self.jwk_url.clone()),
415 error: format!("Failed to fetch JWKs: HTTP {}", response.status()),
416 });
417 }
418
419 let jwk_set: JwkSet = response.json().await.map_err(|e| JwtError::ParseError {
421 part: Some("jwk_response".to_string()),
422 error: format!("Failed to parse JWK response: {}", e),
423 })?;
424
425 if jwk_set.keys.is_empty() {
427 return Err(JwtError::JwksFetchError {
428 url: Some(self.jwk_url.clone()),
429 error: "JWK set is empty".to_string(),
430 });
431 }
432
433 tracing::debug!("Fetched {} JWKs from Cognito", jwk_set.keys.len());
434
435 {
437 let mut cache = self.keys_cache.write().unwrap();
438 let now = Instant::now();
439
440 self.prune_expired_keys(&mut cache, now);
442
443 for key in jwk_set.keys {
444 if key.kid.is_empty() {
446 tracing::warn!("Skipping JWK with empty kid");
447 continue;
448 }
449
450 if key.n.is_empty() || key.e.is_empty() {
451 tracing::warn!("Skipping JWK with empty RSA components: kid={}", key.kid);
452 continue;
453 }
454
455 let decoding_key =
457 DecodingKey::from_rsa_components(&key.n, &key.e).map_err(|e| {
458 JwtError::ParseError {
459 part: Some("jwk".to_string()),
460 error: format!("Failed to create decoding key: {}", e),
461 }
462 })?;
463
464 cache.insert(
466 key.kid.clone(),
467 CachedJwk {
468 key: decoding_key,
469 inserted_at: now,
470 },
471 );
472
473 tracing::debug!("Cached JWK with kid={}", key.kid);
474 }
475 }
476
477 Ok(())
478 }
479}