// jwt_verify/jwk/provider.rs

1use anyhow::Result;
2use jsonwebtoken::DecodingKey;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fmt::Debug;
7use std::sync::RwLock;
8use std::time::{Duration, Instant};
9use tokio::sync::Mutex;
10
11use crate::common::error::JwtError;
12
13/// RSA key from JWK set
14#[derive(Debug, Clone, Deserialize, Serialize)]
15pub struct RSAKey {
16    /// Key ID
17    pub kid: String,
18    /// Algorithm
19    pub alg: String,
20    /// Modulus
21    pub n: String,
22    /// Exponent
23    pub e: String,
24    /// Key use
25    #[serde(rename = "use")]
26    pub use_for: String,
27}
28
29/// JWK set response
30#[derive(Debug, Deserialize)]
31pub struct JwkSet {
32    /// Keys in the set
33    pub keys: Vec<RSAKey>,
34}
35
36/// Cached JWK
37struct CachedJwk {
38    /// Decoding key
39    key: DecodingKey,
40    /// Time when the key was inserted into the cache
41    inserted_at: Instant,
42}
43
44impl std::fmt::Debug for CachedJwk {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.debug_struct("CachedJwk")
47            .field("inserted_at", &self.inserted_at)
48            .finish()
49    }
50}
51
52/// JWK provider for fetching and caching JWKs
53#[derive(Debug)]
54pub struct JwkProvider {
55    /// JWK URL
56    jwk_url: String,
57    /// Issuer URL
58    issuer: String,
59    /// Keys cache
60    keys_cache: RwLock<HashMap<String, CachedJwk>>,
61    /// Last refresh time
62    last_refresh: Mutex<Option<Instant>>,
63    /// Cache duration
64    cache_duration: Duration,
65    /// Minimum refresh interval
66    min_refresh_interval: Duration,
67    /// HTTP client
68    client: Client,
69}
70
71impl JwkProvider {
72    /// Create a new JWK provider
73    ///
74    /// This constructor creates a JwkProvider without prefetching keys.
75    /// Keys will be fetched on the first request.
76    pub fn new(
77        region: &str,
78        user_pool_id: &str,
79        cache_duration: Duration,
80    ) -> Result<Self, JwtError> {
81        // Default minimum refresh interval is 60 seconds for production use
82        // Validate region
83        if region.is_empty() {
84            return Err(JwtError::ConfigurationError {
85                parameter: Some("region".to_string()),
86                error: "Region cannot be empty".to_string(),
87            });
88        }
89
90        // Validate user pool ID
91        if user_pool_id.is_empty() {
92            return Err(JwtError::ConfigurationError {
93                parameter: Some("user_pool_id".to_string()),
94                error: "User pool ID cannot be empty".to_string(),
95            });
96        }
97
98        // Validate user pool ID format (region_code-region-number)
99        // AWS Cognito user pool IDs follow the format: region_code_region-number
100        // For example: us-east-1_abcdefghi
101        if !Self::is_valid_user_pool_id(user_pool_id) {
102            return Err(JwtError::ConfigurationError {
103                parameter: Some("user_pool_id".to_string()),
104                error: format!(
105                    "Invalid user pool ID format: {}. Expected format: region_number",
106                    user_pool_id
107                ),
108            });
109        }
110
111        // Create JWK URL
112        let issuer = format!(
113            "https://cognito-idp.{}.amazonaws.com/{}",
114            region, user_pool_id
115        );
116        let jwk_url = format!("{}/.well-known/jwks.json", issuer);
117
118        // Create HTTP client
119        let client = Client::builder().use_rustls_tls().build().map_err(|e| {
120            JwtError::ConfigurationError {
121                parameter: Some("http_client".to_string()),
122                error: format!("Failed to create HTTP client: {}", e),
123            }
124        })?;
125
126        let provider = Self {
127            jwk_url,
128            issuer,
129            keys_cache: RwLock::new(HashMap::new()),
130            last_refresh: Mutex::new(None),
131            cache_duration,
132            min_refresh_interval: Duration::from_secs(60), // Default to 60 seconds for production
133            client,
134        };
135
136        Ok(provider)
137    }
138
139    /// Create a new JWK provider from a JWKS URL (for OIDC providers)
140    ///
141    /// This constructor creates a JwkProvider for generic OIDC providers.
142    /// Keys will be fetched on the first request.
143    ///
144    /// # Arguments
145    /// * `jwks_url` - The full URL to the JWKS endpoint (e.g., "https://example.com/.well-known/jwks.json")
146    /// * `issuer` - The issuer URL (e.g., "https://example.com")
147    /// * `cache_duration` - How long to cache keys before refreshing
148    pub fn from_jwks_url(
149        jwks_url: &str,
150        issuer: &str,
151        cache_duration: Duration,
152    ) -> Result<Self, JwtError> {
153        // Validate JWKS URL
154        if jwks_url.is_empty() {
155            return Err(JwtError::ConfigurationError {
156                parameter: Some("jwks_url".to_string()),
157                error: "JWKS URL cannot be empty".to_string(),
158            });
159        }
160
161        // Validate issuer
162        if issuer.is_empty() {
163            return Err(JwtError::ConfigurationError {
164                parameter: Some("issuer".to_string()),
165                error: "Issuer cannot be empty".to_string(),
166            });
167        }
168
169        // Validate URL format
170        if !jwks_url.starts_with("http://") && !jwks_url.starts_with("https://") {
171            return Err(JwtError::ConfigurationError {
172                parameter: Some("jwks_url".to_string()),
173                error: "JWKS URL must start with http:// or https://".to_string(),
174            });
175        }
176
177        // Create HTTP client
178        let client = Client::builder().use_rustls_tls().build().map_err(|e| {
179            JwtError::ConfigurationError {
180                parameter: Some("http_client".to_string()),
181                error: format!("Failed to create HTTP client: {}", e),
182            }
183        })?;
184
185        let provider = Self {
186            jwk_url: jwks_url.to_string(),
187            issuer: issuer.to_string(),
188            keys_cache: RwLock::new(HashMap::new()),
189            last_refresh: Mutex::new(None),
190            cache_duration,
191            min_refresh_interval: Duration::from_secs(60), // Default to 60 seconds for production
192            client,
193        };
194
195        Ok(provider)
196    }
197
198    /// Validate user pool ID format
199    fn is_valid_user_pool_id(user_pool_id: &str) -> bool {
200        // AWS Cognito user pool IDs follow the format: region_code_region-number
201        // For example: us-east-1_abcdefghi
202        let parts: Vec<&str> = user_pool_id.split('_').collect();
203
204        if parts.len() != 2 {
205            return false;
206        }
207
208        // The second part should be alphanumeric
209        parts[1].chars().all(|c| c.is_alphanumeric())
210    }
211
212    /// Create a new JWK provider with a custom base URL (for testing)
213    #[cfg(test)]
214    pub fn new_with_base_url(
215        base_url: &str,
216        issuer: &str,
217        cache_duration: Duration,
218    ) -> Result<Self, JwtError> {
219        // For tests, use a shorter minimum refresh interval (1 second)
220        Self::new_with_base_url_and_refresh_interval(
221            base_url,
222            issuer,
223            cache_duration,
224            Duration::from_secs(1),
225        )
226    }
227
228    #[cfg(test)]
229    pub fn new_with_base_url_and_refresh_interval(
230        base_url: &str,
231        issuer: &str,
232        cache_duration: Duration,
233        min_refresh_interval: Duration,
234    ) -> Result<Self, JwtError> {
235        // Create JWK URL
236        let jwk_url = format!("{}/.well-known/jwks.json", base_url);
237
238        // Create HTTP client
239        let client = Client::builder().use_rustls_tls().build().map_err(|e| {
240            JwtError::ConfigurationError {
241                parameter: Some("http_client".to_string()),
242                error: format!("Failed to create HTTP client: {}", e),
243            }
244        })?;
245
246        let provider = Self {
247            jwk_url,
248            issuer: issuer.to_string(),
249            keys_cache: RwLock::new(HashMap::new()),
250            last_refresh: Mutex::new(None),
251            cache_duration,
252            min_refresh_interval,
253            client,
254        };
255
256        Ok(provider)
257    }
258
259    /// Get the issuer URL
260    pub fn get_issuer(&self) -> &str {
261        &self.issuer
262    }
263
264    /// Prefetch JWKs from the Cognito user pool
265    /// This method should be called when the system starts to ensure JWKs are available
266    pub async fn prefetch_keys(&self) -> Result<(), JwtError> {
267        self.refresh_keys().await
268    }
269
270    /// Get a JWK by key ID
271    pub async fn get_key(&self, kid: &str) -> Result<DecodingKey, JwtError> {
272        // Check if key is in cache
273        {
274            let cache = self.keys_cache.read().unwrap();
275            if let Some(cached_jwk) = cache.get(kid) {
276                let now = Instant::now();
277                if now.duration_since(cached_jwk.inserted_at) < self.cache_duration {
278                    // Cache hit
279                    return Ok(cached_jwk.key.clone());
280                }
281                // Cache entry expired
282            } else {
283                // Cache miss
284            }
285        }
286
287        // Key not found or expired, refresh keys
288        self.refresh_keys().await?;
289
290        // Try to get key from cache again
291        {
292            let cache = self.keys_cache.read().unwrap();
293            if let Some(cached_jwk) = cache.get(kid) {
294                return Ok(cached_jwk.key.clone());
295            }
296        }
297
298        // Key still not found
299        Err(JwtError::KeyNotFound(kid.to_string()))
300    }
301
302    /// Refresh JWKs
303    async fn refresh_keys(&self) -> Result<(), JwtError> {
304        // Check if we need to refresh
305        {
306            let mut last_refresh = self.last_refresh.lock().await;
307            if let Some(time) = *last_refresh {
308                let now = Instant::now();
309                if now.duration_since(time) < self.min_refresh_interval {
310                    // Don't refresh more than once per configured interval
311                    tracing::debug!(
312                        "Skipping JWK refresh, last refresh was less than {:?} ago",
313                        self.min_refresh_interval
314                    );
315                    return Ok(());
316                }
317            }
318
319            // Update last refresh time
320            *last_refresh = Some(Instant::now());
321        }
322
323        tracing::debug!("Fetching JWKs from {}", self.jwk_url);
324
325        // Fetch JWKs with retry for transient failures
326        let mut retry_count = 0;
327        let max_retries = 3;
328        let mut last_error = None;
329
330        while retry_count < max_retries {
331            match self.fetch_and_parse_jwks().await {
332                Ok(()) => {
333                    tracing::debug!("Successfully refreshed JWKs");
334                    return Ok(());
335                }
336                Err(e) => {
337                    // Only retry on network errors, not on parsing errors
338                    match &e {
339                        JwtError::JwksFetchError { .. } => {
340                            retry_count += 1;
341                            if retry_count < max_retries {
342                                tracing::warn!(
343                                    "Failed to fetch JWKs (attempt {}/{}): {}. Retrying...",
344                                    retry_count,
345                                    max_retries,
346                                    e
347                                );
348                                tokio::time::sleep(Duration::from_millis(500 * (1 << retry_count)))
349                                    .await;
350                            }
351                            last_error = Some(e);
352                        }
353                        _ => {
354                            // Don't retry on parsing errors
355                            return Err(e);
356                        }
357                    }
358                }
359            }
360        }
361
362        // All retries failed
363        Err(last_error.unwrap_or_else(|| JwtError::JwksFetchError {
364            url: Some(self.jwk_url.clone()),
365            error: "Failed to fetch JWKs after multiple attempts".to_string(),
366        }))
367    }
368
369    /// Remove expired keys from the cache
370    fn prune_expired_keys(&self, cache: &mut HashMap<String, CachedJwk>, now: Instant) {
371        let expired_keys: Vec<String> = cache
372            .iter()
373            .filter(|(_, cached_jwk)| {
374                now.duration_since(cached_jwk.inserted_at) >= self.cache_duration
375            })
376            .map(|(kid, _)| kid.clone())
377            .collect();
378
379        if !expired_keys.is_empty() {
380            tracing::debug!("Pruning {} expired keys from cache", expired_keys.len());
381            for kid in expired_keys {
382                cache.remove(&kid);
383            }
384        }
385    }
386
387    /// Fetch and parse JWKs from the Cognito endpoint
388    async fn fetch_and_parse_jwks(&self) -> Result<(), JwtError> {
389        // Fetch JWKs
390        let response = self
391            .client
392            .get(&self.jwk_url)
393            .timeout(Duration::from_secs(5))
394            .send()
395            .await
396            .map_err(|e| {
397                let error_msg = if e.is_timeout() {
398                    "Request timed out".to_string()
399                } else if e.is_connect() {
400                    "Connection error".to_string()
401                } else {
402                    format!("HTTP request failed: {}", e)
403                };
404
405                JwtError::JwksFetchError {
406                    url: Some(self.jwk_url.clone()),
407                    error: error_msg,
408                }
409            })?;
410
411        // Check response status
412        if !response.status().is_success() {
413            return Err(JwtError::JwksFetchError {
414                url: Some(self.jwk_url.clone()),
415                error: format!("Failed to fetch JWKs: HTTP {}", response.status()),
416            });
417        }
418
419        // Parse response
420        let jwk_set: JwkSet = response.json().await.map_err(|e| JwtError::ParseError {
421            part: Some("jwk_response".to_string()),
422            error: format!("Failed to parse JWK response: {}", e),
423        })?;
424
425        // Check if we got any keys
426        if jwk_set.keys.is_empty() {
427            return Err(JwtError::JwksFetchError {
428                url: Some(self.jwk_url.clone()),
429                error: "JWK set is empty".to_string(),
430            });
431        }
432
433        tracing::debug!("Fetched {} JWKs from Cognito", jwk_set.keys.len());
434
435        // Update cache
436        {
437            let mut cache = self.keys_cache.write().unwrap();
438            let now = Instant::now();
439
440            // Remove expired keys before adding new ones
441            self.prune_expired_keys(&mut cache, now);
442
443            for key in jwk_set.keys {
444                // Validate key fields
445                if key.kid.is_empty() {
446                    tracing::warn!("Skipping JWK with empty kid");
447                    continue;
448                }
449
450                if key.n.is_empty() || key.e.is_empty() {
451                    tracing::warn!("Skipping JWK with empty RSA components: kid={}", key.kid);
452                    continue;
453                }
454
455                // Create decoding key
456                let decoding_key =
457                    DecodingKey::from_rsa_components(&key.n, &key.e).map_err(|e| {
458                        JwtError::ParseError {
459                            part: Some("jwk".to_string()),
460                            error: format!("Failed to create decoding key: {}", e),
461                        }
462                    })?;
463
464                // Add to cache
465                cache.insert(
466                    key.kid.clone(),
467                    CachedJwk {
468                        key: decoding_key,
469                        inserted_at: now,
470                    },
471                );
472
473                tracing::debug!("Cached JWK with kid={}", key.kid);
474            }
475        }
476
477        Ok(())
478    }
479}