Skip to main content

forge_runtime/gateway/
jwks.rs

1//! JWKS (JSON Web Key Set) client for RSA token validation.
2//!
3//! This module provides a client for fetching and caching public keys from
4//! JWKS endpoints, used by providers like Firebase, Clerk, Auth0, etc.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use jsonwebtoken::DecodingKey;
11use serde::Deserialize;
12use tokio::sync::RwLock;
13use tracing::{debug, warn};
14
15/// JWKS response structure from providers.
16#[derive(Debug, Deserialize)]
17pub struct JwksResponse {
18    /// List of JSON Web Keys.
19    pub keys: Vec<JsonWebKey>,
20}
21
22/// Individual JSON Web Key.
23#[derive(Debug, Deserialize)]
24pub struct JsonWebKey {
25    /// Key ID - used to match tokens to keys.
26    pub kid: Option<String>,
27
28    /// Key type (RSA, EC, etc.).
29    pub kty: String,
30
31    /// Algorithm (RS256, RS384, RS512, etc.).
32    pub alg: Option<String>,
33
34    /// Key use (sig = signature, enc = encryption).
35    #[serde(rename = "use")]
36    pub key_use: Option<String>,
37
38    /// RSA modulus (base64url encoded).
39    pub n: Option<String>,
40
41    /// RSA exponent (base64url encoded).
42    pub e: Option<String>,
43
44    /// X.509 certificate chain (used by Firebase).
45    pub x5c: Option<Vec<String>>,
46}
47
48/// Cached JWKS keys with TTL tracking.
49struct CachedJwks {
50    /// Map of key ID to decoding key.
51    keys: HashMap<String, DecodingKey>,
52    /// When the cache was last refreshed.
53    fetched_at: Instant,
54}
55
56/// JWKS client with automatic caching.
57///
58/// Fetches public keys from a JWKS endpoint and caches them for efficient
59/// token validation. Keys are automatically refreshed when the cache expires.
60///
61/// # Example
62///
63/// ```ignore
64/// let client = JwksClient::new(
65///     "https://www.googleapis.com/service_accounts/v1/jwk/securetoken@system.gserviceaccount.com".to_string(),
66///     3600, // 1 hour cache TTL
67/// );
68///
69/// // Get key by ID from token header
70/// let key = client.get_key("abc123").await?;
71/// ```
72pub struct JwksClient {
73    /// JWKS endpoint URL.
74    url: String,
75    /// HTTP client for fetching keys.
76    http_client: reqwest::Client,
77    /// Cached keys with TTL.
78    cache: Arc<RwLock<Option<CachedJwks>>>,
79    /// Cache time-to-live.
80    cache_ttl: Duration,
81}
82
83impl std::fmt::Debug for JwksClient {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        f.debug_struct("JwksClient")
86            .field("url", &self.url)
87            .field("cache_ttl", &self.cache_ttl)
88            .finish_non_exhaustive()
89    }
90}
91
92impl JwksClient {
93    /// Create a new JWKS client.
94    ///
95    /// # Arguments
96    ///
97    /// * `url` - The JWKS endpoint URL
98    /// * `cache_ttl_secs` - How long to cache keys (in seconds)
99    pub fn new(url: String, cache_ttl_secs: u64) -> Result<Self, JwksError> {
100        let http_client = reqwest::Client::builder()
101            .timeout(Duration::from_secs(10))
102            .build()
103            .map_err(|e| JwksError::HttpClientError(e.to_string()))?;
104
105        Ok(Self {
106            url,
107            http_client,
108            cache: Arc::new(RwLock::new(None)),
109            cache_ttl: Duration::from_secs(cache_ttl_secs),
110        })
111    }
112
113    /// Get a decoding key by key ID.
114    ///
115    /// This will return a cached key if available and not expired,
116    /// otherwise it will fetch fresh keys from the JWKS endpoint.
117    pub async fn get_key(&self, kid: &str) -> Result<DecodingKey, JwksError> {
118        // Try to get from cache first
119        {
120            let cache = self.cache.read().await;
121            if let Some(ref cached) = *cache
122                && cached.fetched_at.elapsed() < self.cache_ttl
123                && let Some(key) = cached.keys.get(kid)
124            {
125                debug!(kid = %kid, "Using cached JWKS key");
126                return Ok(key.clone());
127            }
128        }
129
130        // Cache miss or expired - refresh
131        debug!(kid = %kid, "JWKS cache miss, refreshing");
132        self.refresh().await?;
133
134        // Try again from refreshed cache
135        let cache = self.cache.read().await;
136        if let Some(ref cached) = *cache {
137            cached
138                .keys
139                .get(kid)
140                .cloned()
141                .ok_or_else(|| JwksError::KeyNotFound(kid.to_string()))
142        } else {
143            Err(JwksError::FetchFailed(
144                "Cache empty after refresh".to_string(),
145            ))
146        }
147    }
148
149    /// Get any available key (for tokens without kid header).
150    ///
151    /// Some providers don't include a key ID in tokens. This method
152    /// returns the first available key from the JWKS.
153    pub async fn get_any_key(&self) -> Result<DecodingKey, JwksError> {
154        // Try to get from cache first
155        {
156            let cache = self.cache.read().await;
157            if let Some(ref cached) = *cache
158                && cached.fetched_at.elapsed() < self.cache_ttl
159                && let Some(key) = cached.keys.values().next()
160            {
161                debug!("Using first cached JWKS key (no kid specified)");
162                return Ok(key.clone());
163            }
164        }
165
166        // Cache miss or expired - refresh
167        debug!("JWKS cache miss for any key, refreshing");
168        self.refresh().await?;
169
170        let cache = self.cache.read().await;
171        if let Some(ref cached) = *cache {
172            cached
173                .keys
174                .values()
175                .next()
176                .cloned()
177                .ok_or(JwksError::NoKeysAvailable)
178        } else {
179            Err(JwksError::FetchFailed("No keys in JWKS".to_string()))
180        }
181    }
182
183    /// Force refresh the key cache.
184    ///
185    /// Fetches fresh keys from the JWKS endpoint regardless of cache state.
186    pub async fn refresh(&self) -> Result<(), JwksError> {
187        debug!(url = %self.url, "Fetching JWKS");
188
189        let response = self
190            .http_client
191            .get(&self.url)
192            .send()
193            .await
194            .map_err(|e| JwksError::FetchFailed(e.to_string()))?;
195
196        if !response.status().is_success() {
197            return Err(JwksError::FetchFailed(format!(
198                "HTTP {} from JWKS endpoint",
199                response.status()
200            )));
201        }
202
203        let jwks: JwksResponse = response
204            .json()
205            .await
206            .map_err(|e| JwksError::ParseFailed(e.to_string()))?;
207
208        let mut keys = HashMap::new();
209
210        for jwk in jwks.keys {
211            // Skip non-signature keys
212            if let Some(ref key_use) = jwk.key_use
213                && key_use != "sig"
214            {
215                continue;
216            }
217
218            let kid = jwk.kid.clone().unwrap_or_else(|| "default".to_string());
219
220            match self.parse_jwk(&jwk) {
221                Ok(Some(key)) => {
222                    debug!(kid = %kid, kty = %jwk.kty, "Parsed JWKS key");
223                    keys.insert(kid, key);
224                }
225                Ok(None) => {
226                    debug!(kid = %kid, kty = %jwk.kty, "Skipping unsupported key type");
227                }
228                Err(e) => {
229                    warn!(kid = %kid, error = %e, "Failed to parse JWKS key");
230                }
231            }
232        }
233
234        if keys.is_empty() {
235            return Err(JwksError::NoKeysAvailable);
236        }
237
238        debug!(count = keys.len(), "Cached JWKS keys");
239
240        let mut cache = self.cache.write().await;
241        *cache = Some(CachedJwks {
242            keys,
243            fetched_at: Instant::now(),
244        });
245
246        Ok(())
247    }
248
249    /// Parse a JWK into a DecodingKey.
250    fn parse_jwk(&self, jwk: &JsonWebKey) -> Result<Option<DecodingKey>, JwksError> {
251        match jwk.kty.as_str() {
252            "RSA" => {
253                // Try X.509 certificate chain first (used by Firebase)
254                if let Some(ref x5c) = jwk.x5c
255                    && let Some(cert) = x5c.first()
256                {
257                    let pem = format!(
258                        "-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----",
259                        cert
260                    );
261                    return DecodingKey::from_rsa_pem(pem.as_bytes()).map(Some).map_err(
262                        |e: jsonwebtoken::errors::Error| JwksError::KeyParseFailed(e.to_string()),
263                    );
264                }
265
266                // Fall back to n/e components (used by Clerk, Auth0, etc.)
267                if let (Some(n), Some(e)) = (&jwk.n, &jwk.e) {
268                    return DecodingKey::from_rsa_components(n, e).map(Some).map_err(
269                        |e: jsonwebtoken::errors::Error| JwksError::KeyParseFailed(e.to_string()),
270                    );
271                }
272
273                // RSA key but missing required components
274                Ok(None)
275            }
276            _ => {
277                // Unsupported key type (EC, oct, etc.)
278                Ok(None)
279            }
280        }
281    }
282
283    /// Get the JWKS URL.
284    pub fn url(&self) -> &str {
285        &self.url
286    }
287}
288
289/// Errors that can occur when working with JWKS.
290#[derive(Debug, thiserror::Error)]
291pub enum JwksError {
292    /// Failed to fetch JWKS from endpoint.
293    #[error("Failed to fetch JWKS: {0}")]
294    FetchFailed(String),
295
296    /// Failed to parse JWKS response.
297    #[error("Failed to parse JWKS: {0}")]
298    ParseFailed(String),
299
300    /// Failed to parse individual key.
301    #[error("Failed to parse key: {0}")]
302    KeyParseFailed(String),
303
304    /// Requested key ID not found in JWKS.
305    #[error("Key not found: {0}")]
306    KeyNotFound(String),
307
308    /// No usable keys in JWKS.
309    #[error("No keys available in JWKS")]
310    NoKeysAvailable,
311
312    /// Failed to create HTTP client.
313    #[error("Failed to create HTTP client: {0}")]
314    HttpClientError(String),
315}
316
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Example public-key modulus used only for parsing (not a production key).
    const SAMPLE_MODULUS: &str = "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw";

    /// Client pointed at a dummy endpoint; these tests never send a request.
    fn test_client() -> JwksClient {
        JwksClient::new("http://example.com".to_string(), 3600).unwrap()
    }

    /// Signing JWK with ID `test-key` and the given type, algorithm and
    /// optional RSA components; no certificate chain.
    fn signing_jwk(kty: &str, alg: &str, n: Option<&str>, e: Option<&str>) -> JsonWebKey {
        JsonWebKey {
            kid: Some("test-key".to_string()),
            kty: kty.to_string(),
            alg: Some(alg.to_string()),
            key_use: Some("sig".to_string()),
            n: n.map(str::to_string),
            e: e.map(str::to_string),
            x5c: None,
        }
    }

    #[test]
    fn test_parse_jwk_with_n_e() {
        let jwk = signing_jwk("RSA", "RS256", Some(SAMPLE_MODULUS), Some("AQAB"));
        assert!(matches!(test_client().parse_jwk(&jwk), Ok(Some(_))));
    }

    #[test]
    fn test_parse_jwk_unsupported_type() {
        // EC keys are not supported and must be skipped, not rejected.
        let jwk = signing_jwk("EC", "ES256", None, None);
        assert!(matches!(test_client().parse_jwk(&jwk), Ok(None)));
    }

    #[test]
    fn test_parse_jwk_missing_components() {
        // An RSA key with neither n/e nor x5c yields no key rather than an error.
        let jwk = signing_jwk("RSA", "RS256", None, None);
        assert!(matches!(test_client().parse_jwk(&jwk), Ok(None)));
    }
}