1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use jsonwebtoken::DecodingKey;
11use serde::Deserialize;
12use tokio::sync::RwLock;
13use tracing::{debug, warn};
14
15#[derive(Debug, Deserialize)]
17pub struct JwksResponse {
18 pub keys: Vec<JsonWebKey>,
20}
21
22#[derive(Debug, Deserialize)]
24pub struct JsonWebKey {
25 pub kid: Option<String>,
27
28 pub kty: String,
30
31 pub alg: Option<String>,
33
34 #[serde(rename = "use")]
36 pub key_use: Option<String>,
37
38 pub n: Option<String>,
40
41 pub e: Option<String>,
43
44 pub x5c: Option<Vec<String>>,
46}
47
48struct CachedJwks {
50 keys: HashMap<String, DecodingKey>,
52 fetched_at: Instant,
54}
55
56pub struct JwksClient {
73 url: String,
75 http_client: reqwest::Client,
77 cache: Arc<RwLock<Option<CachedJwks>>>,
79 cache_ttl: Duration,
81}
82
83impl std::fmt::Debug for JwksClient {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 f.debug_struct("JwksClient")
86 .field("url", &self.url)
87 .field("cache_ttl", &self.cache_ttl)
88 .finish_non_exhaustive()
89 }
90}
91
92impl JwksClient {
93 pub fn new(url: String, cache_ttl_secs: u64) -> Result<Self, JwksError> {
100 let http_client = reqwest::Client::builder()
101 .timeout(Duration::from_secs(10))
102 .build()
103 .map_err(|e| JwksError::HttpClientError(e.to_string()))?;
104
105 Ok(Self {
106 url,
107 http_client,
108 cache: Arc::new(RwLock::new(None)),
109 cache_ttl: Duration::from_secs(cache_ttl_secs),
110 })
111 }
112
113 pub async fn get_key(&self, kid: &str) -> Result<DecodingKey, JwksError> {
118 {
120 let cache = self.cache.read().await;
121 if let Some(ref cached) = *cache
122 && cached.fetched_at.elapsed() < self.cache_ttl
123 && let Some(key) = cached.keys.get(kid)
124 {
125 debug!(kid = %kid, "Using cached JWKS key");
126 return Ok(key.clone());
127 }
128 }
129
130 debug!(kid = %kid, "JWKS cache miss, refreshing");
132 self.refresh().await?;
133
134 let cache = self.cache.read().await;
136 if let Some(ref cached) = *cache {
137 cached
138 .keys
139 .get(kid)
140 .cloned()
141 .ok_or_else(|| JwksError::KeyNotFound(kid.to_string()))
142 } else {
143 Err(JwksError::FetchFailed(
144 "Cache empty after refresh".to_string(),
145 ))
146 }
147 }
148
149 pub async fn get_any_key(&self) -> Result<DecodingKey, JwksError> {
154 {
156 let cache = self.cache.read().await;
157 if let Some(ref cached) = *cache
158 && cached.fetched_at.elapsed() < self.cache_ttl
159 && let Some(key) = cached.keys.values().next()
160 {
161 debug!("Using first cached JWKS key (no kid specified)");
162 return Ok(key.clone());
163 }
164 }
165
166 debug!("JWKS cache miss for any key, refreshing");
168 self.refresh().await?;
169
170 let cache = self.cache.read().await;
171 if let Some(ref cached) = *cache {
172 cached
173 .keys
174 .values()
175 .next()
176 .cloned()
177 .ok_or(JwksError::NoKeysAvailable)
178 } else {
179 Err(JwksError::FetchFailed("No keys in JWKS".to_string()))
180 }
181 }
182
183 pub async fn refresh(&self) -> Result<(), JwksError> {
187 debug!(url = %self.url, "Fetching JWKS");
188
189 let response = self
190 .http_client
191 .get(&self.url)
192 .send()
193 .await
194 .map_err(|e| JwksError::FetchFailed(e.to_string()))?;
195
196 if !response.status().is_success() {
197 return Err(JwksError::FetchFailed(format!(
198 "HTTP {} from JWKS endpoint",
199 response.status()
200 )));
201 }
202
203 let jwks: JwksResponse = response
204 .json()
205 .await
206 .map_err(|e| JwksError::ParseFailed(e.to_string()))?;
207
208 let mut keys = HashMap::new();
209
210 for jwk in jwks.keys {
211 if let Some(ref key_use) = jwk.key_use
213 && key_use != "sig"
214 {
215 continue;
216 }
217
218 let kid = jwk.kid.clone().unwrap_or_else(|| "default".to_string());
219
220 match self.parse_jwk(&jwk) {
221 Ok(Some(key)) => {
222 debug!(kid = %kid, kty = %jwk.kty, "Parsed JWKS key");
223 keys.insert(kid, key);
224 }
225 Ok(None) => {
226 debug!(kid = %kid, kty = %jwk.kty, "Skipping unsupported key type");
227 }
228 Err(e) => {
229 warn!(kid = %kid, error = %e, "Failed to parse JWKS key");
230 }
231 }
232 }
233
234 if keys.is_empty() {
235 return Err(JwksError::NoKeysAvailable);
236 }
237
238 debug!(count = keys.len(), "Cached JWKS keys");
239
240 let mut cache = self.cache.write().await;
241 *cache = Some(CachedJwks {
242 keys,
243 fetched_at: Instant::now(),
244 });
245
246 Ok(())
247 }
248
249 fn parse_jwk(&self, jwk: &JsonWebKey) -> Result<Option<DecodingKey>, JwksError> {
251 match jwk.kty.as_str() {
252 "RSA" => {
253 if let Some(ref x5c) = jwk.x5c
255 && let Some(cert) = x5c.first()
256 {
257 let pem = format!(
258 "-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----",
259 cert
260 );
261 return DecodingKey::from_rsa_pem(pem.as_bytes()).map(Some).map_err(
262 |e: jsonwebtoken::errors::Error| JwksError::KeyParseFailed(e.to_string()),
263 );
264 }
265
266 if let (Some(n), Some(e)) = (&jwk.n, &jwk.e) {
268 return DecodingKey::from_rsa_components(n, e).map(Some).map_err(
269 |e: jsonwebtoken::errors::Error| JwksError::KeyParseFailed(e.to_string()),
270 );
271 }
272
273 Ok(None)
275 }
276 _ => {
277 Ok(None)
279 }
280 }
281 }
282
283 pub fn url(&self) -> &str {
285 &self.url
286 }
287}
288
289#[derive(Debug, thiserror::Error)]
291pub enum JwksError {
292 #[error("Failed to fetch JWKS: {0}")]
294 FetchFailed(String),
295
296 #[error("Failed to parse JWKS: {0}")]
298 ParseFailed(String),
299
300 #[error("Failed to parse key: {0}")]
302 KeyParseFailed(String),
303
304 #[error("Key not found: {0}")]
306 KeyNotFound(String),
307
308 #[error("No keys available in JWKS")]
310 NoKeysAvailable,
311
312 #[error("Failed to create HTTP client: {0}")]
314 HttpClientError(String),
315}
316
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Base64url RSA modulus used as a known-good fixture.
    const SAMPLE_MODULUS: &str = "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw";

    /// Client pointed at a dummy URL; `parse_jwk` never performs a request.
    fn client() -> JwksClient {
        JwksClient::new("http://example.com".to_string(), 3600).unwrap()
    }

    /// Builds a signature JWK with id `test-key` and no certificate chain.
    fn sig_key(kty: &str, alg: &str, n: Option<&str>, e: Option<&str>) -> JsonWebKey {
        JsonWebKey {
            kid: Some("test-key".to_string()),
            kty: kty.to_string(),
            alg: Some(alg.to_string()),
            key_use: Some("sig".to_string()),
            n: n.map(String::from),
            e: e.map(String::from),
            x5c: None,
        }
    }

    #[test]
    fn test_parse_jwk_with_n_e() {
        let jwk = sig_key("RSA", "RS256", Some(SAMPLE_MODULUS), Some("AQAB"));
        assert!(matches!(client().parse_jwk(&jwk), Ok(Some(_))));
    }

    #[test]
    fn test_parse_jwk_unsupported_type() {
        let jwk = sig_key("EC", "ES256", None, None);
        assert!(matches!(client().parse_jwk(&jwk), Ok(None)));
    }

    #[test]
    fn test_parse_jwk_missing_components() {
        let jwk = sig_key("RSA", "RS256", None, None);
        assert!(matches!(client().parse_jwk(&jwk), Ok(None)));
    }
}