1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use jsonwebtoken::DecodingKey;
11use serde::Deserialize;
12use tokio::sync::RwLock;
13use tracing::{debug, warn};
14
15#[derive(Debug, Deserialize)]
17pub struct JwksResponse {
18 pub keys: Vec<JsonWebKey>,
20}
21
22#[derive(Debug, Deserialize)]
24pub struct JsonWebKey {
25 pub kid: Option<String>,
27
28 pub kty: String,
30
31 pub alg: Option<String>,
33
34 #[serde(rename = "use")]
36 pub key_use: Option<String>,
37
38 pub n: Option<String>,
40
41 pub e: Option<String>,
43
44 pub x5c: Option<Vec<String>>,
46}
47
48struct CachedJwks {
50 keys: HashMap<String, DecodingKey>,
52 fetched_at: Instant,
54}
55
56pub struct JwksClient {
73 url: String,
75 http_client: reqwest::Client,
77 cache: Arc<RwLock<Option<CachedJwks>>>,
79 cache_ttl: Duration,
81}
82
83impl std::fmt::Debug for JwksClient {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 f.debug_struct("JwksClient")
86 .field("url", &self.url)
87 .field("cache_ttl", &self.cache_ttl)
88 .finish_non_exhaustive()
89 }
90}
91
92impl JwksClient {
93 pub fn new(url: String, cache_ttl_secs: u64) -> Self {
100 Self {
101 url,
102 http_client: reqwest::Client::builder()
103 .timeout(Duration::from_secs(10))
104 .build()
105 .expect("Failed to create HTTP client"),
106 cache: Arc::new(RwLock::new(None)),
107 cache_ttl: Duration::from_secs(cache_ttl_secs),
108 }
109 }
110
111 pub async fn get_key(&self, kid: &str) -> Result<DecodingKey, JwksError> {
116 {
118 let cache = self.cache.read().await;
119 if let Some(ref cached) = *cache
120 && cached.fetched_at.elapsed() < self.cache_ttl
121 && let Some(key) = cached.keys.get(kid)
122 {
123 debug!(kid = %kid, "Using cached JWKS key");
124 return Ok(key.clone());
125 }
126 }
127
128 debug!(kid = %kid, "JWKS cache miss, refreshing");
130 self.refresh().await?;
131
132 let cache = self.cache.read().await;
134 if let Some(ref cached) = *cache {
135 cached
136 .keys
137 .get(kid)
138 .cloned()
139 .ok_or_else(|| JwksError::KeyNotFound(kid.to_string()))
140 } else {
141 Err(JwksError::FetchFailed(
142 "Cache empty after refresh".to_string(),
143 ))
144 }
145 }
146
147 pub async fn get_any_key(&self) -> Result<DecodingKey, JwksError> {
152 {
154 let cache = self.cache.read().await;
155 if let Some(ref cached) = *cache
156 && cached.fetched_at.elapsed() < self.cache_ttl
157 && let Some(key) = cached.keys.values().next()
158 {
159 debug!("Using first cached JWKS key (no kid specified)");
160 return Ok(key.clone());
161 }
162 }
163
164 debug!("JWKS cache miss for any key, refreshing");
166 self.refresh().await?;
167
168 let cache = self.cache.read().await;
169 if let Some(ref cached) = *cache {
170 cached
171 .keys
172 .values()
173 .next()
174 .cloned()
175 .ok_or(JwksError::NoKeysAvailable)
176 } else {
177 Err(JwksError::FetchFailed("No keys in JWKS".to_string()))
178 }
179 }
180
181 pub async fn refresh(&self) -> Result<(), JwksError> {
185 debug!(url = %self.url, "Fetching JWKS");
186
187 let response = self
188 .http_client
189 .get(&self.url)
190 .send()
191 .await
192 .map_err(|e| JwksError::FetchFailed(e.to_string()))?;
193
194 if !response.status().is_success() {
195 return Err(JwksError::FetchFailed(format!(
196 "HTTP {} from JWKS endpoint",
197 response.status()
198 )));
199 }
200
201 let jwks: JwksResponse = response
202 .json()
203 .await
204 .map_err(|e| JwksError::ParseFailed(e.to_string()))?;
205
206 let mut keys = HashMap::new();
207
208 for jwk in jwks.keys {
209 if let Some(ref key_use) = jwk.key_use
211 && key_use != "sig"
212 {
213 continue;
214 }
215
216 let kid = jwk.kid.clone().unwrap_or_else(|| "default".to_string());
217
218 match self.parse_jwk(&jwk) {
219 Ok(Some(key)) => {
220 debug!(kid = %kid, kty = %jwk.kty, "Parsed JWKS key");
221 keys.insert(kid, key);
222 }
223 Ok(None) => {
224 debug!(kid = %kid, kty = %jwk.kty, "Skipping unsupported key type");
225 }
226 Err(e) => {
227 warn!(kid = %kid, error = %e, "Failed to parse JWKS key");
228 }
229 }
230 }
231
232 if keys.is_empty() {
233 return Err(JwksError::NoKeysAvailable);
234 }
235
236 debug!(count = keys.len(), "Cached JWKS keys");
237
238 let mut cache = self.cache.write().await;
239 *cache = Some(CachedJwks {
240 keys,
241 fetched_at: Instant::now(),
242 });
243
244 Ok(())
245 }
246
247 fn parse_jwk(&self, jwk: &JsonWebKey) -> Result<Option<DecodingKey>, JwksError> {
249 match jwk.kty.as_str() {
250 "RSA" => {
251 if let Some(ref x5c) = jwk.x5c
253 && let Some(cert) = x5c.first()
254 {
255 let pem = format!(
256 "-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----",
257 cert
258 );
259 return DecodingKey::from_rsa_pem(pem.as_bytes()).map(Some).map_err(
260 |e: jsonwebtoken::errors::Error| JwksError::KeyParseFailed(e.to_string()),
261 );
262 }
263
264 if let (Some(n), Some(e)) = (&jwk.n, &jwk.e) {
266 return DecodingKey::from_rsa_components(n, e).map(Some).map_err(
267 |e: jsonwebtoken::errors::Error| JwksError::KeyParseFailed(e.to_string()),
268 );
269 }
270
271 Ok(None)
273 }
274 _ => {
275 Ok(None)
277 }
278 }
279 }
280
281 pub fn url(&self) -> &str {
283 &self.url
284 }
285}
286
287#[derive(Debug, thiserror::Error)]
289pub enum JwksError {
290 #[error("Failed to fetch JWKS: {0}")]
292 FetchFailed(String),
293
294 #[error("Failed to parse JWKS: {0}")]
296 ParseFailed(String),
297
298 #[error("Failed to parse key: {0}")]
300 KeyParseFailed(String),
301
302 #[error("Key not found: {0}")]
304 KeyNotFound(String),
305
306 #[error("No keys available in JWKS")]
308 NoKeysAvailable,
309}
310
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Base64url-encoded RSA modulus used as a known-good `n` value.
    const RSA_MODULUS: &str = "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw";

    /// Client pointed at a dummy URL; `parse_jwk` itself makes no requests.
    fn client() -> JwksClient {
        JwksClient::new("http://example.com".to_string(), 3600)
    }

    /// Signing JWK with kid `test-key` and no `x5c` chain.
    fn signing_jwk(kty: &str, alg: &str, n: Option<&str>, e: Option<&str>) -> JsonWebKey {
        JsonWebKey {
            kid: Some("test-key".to_string()),
            kty: kty.to_string(),
            alg: Some(alg.to_string()),
            key_use: Some("sig".to_string()),
            n: n.map(String::from),
            e: e.map(String::from),
            x5c: None,
        }
    }

    #[test]
    fn test_parse_jwk_with_n_e() {
        let jwk = signing_jwk("RSA", "RS256", Some(RSA_MODULUS), Some("AQAB"));
        assert!(matches!(client().parse_jwk(&jwk), Ok(Some(_))));
    }

    #[test]
    fn test_parse_jwk_unsupported_type() {
        let jwk = signing_jwk("EC", "ES256", None, None);
        assert!(matches!(client().parse_jwk(&jwk), Ok(None)));
    }

    #[test]
    fn test_parse_jwk_missing_components() {
        let jwk = signing_jwk("RSA", "RS256", None, None);
        assert!(matches!(client().parse_jwk(&jwk), Ok(None)));
    }
}