1use crate::{claims_error::ClaimsError, plugin_traits::KeyProvider};
2use arc_swap::ArcSwap;
3use async_trait::async_trait;
4use jsonwebtoken::{DecodingKey, Header, Validation, decode, decode_header};
5use serde::Deserialize;
6use serde_json::Value;
7use std::collections::{HashMap, HashSet};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::RwLock;
11use tokio::time::Instant;
12
13#[derive(Debug, Clone, Deserialize)]
14struct Jwk {
15 kid: String,
16 kty: String,
17 #[serde(rename = "use")]
18 #[allow(dead_code)]
19 use_: Option<String>,
20 n: String,
21 e: String,
22 #[allow(dead_code)]
23 alg: Option<String>,
24}
25
26#[derive(Debug, Clone, Deserialize)]
27struct JwksResponse {
28 keys: Vec<Jwk>,
29}
30
31#[must_use]
35pub struct JwksKeyProvider {
36 jwks_uri: String,
38
39 keys: Arc<ArcSwap<HashMap<String, DecodingKey>>>,
41
42 refresh_state: Arc<RwLock<RefreshState>>,
44
45 client: reqwest::Client,
47
48 refresh_interval: Duration,
50
51 max_backoff: Duration,
53
54 on_demand_refresh_cooldown: Duration,
56}
57
/// Mutable bookkeeping shared by the scheduled and on-demand JWKS refresh paths.
#[derive(Default, Debug)]
struct RefreshState {
    /// When the most recent refresh attempt (successful or not) completed.
    last_refresh: Option<Instant>,
    /// When the most recent on-demand refresh happened; drives the cooldown.
    last_on_demand_refresh: Option<Instant>,
    /// Refresh failures in a row since the last success.
    consecutive_failures: u32,
    /// Message of the most recent refresh failure; cleared on success.
    last_error: Option<String>,
    /// `kid`s still unknown after an on-demand refresh attempt.
    failed_kids: HashSet<String>,
}
66
67impl JwksKeyProvider {
68 pub fn new(jwks_uri: impl Into<String>) -> Result<Self, reqwest::Error> {
73 Ok(Self {
74 jwks_uri: jwks_uri.into(),
75 keys: Arc::new(ArcSwap::from_pointee(HashMap::new())),
76 refresh_state: Arc::new(RwLock::new(RefreshState::default())),
77 client: reqwest::Client::builder()
78 .timeout(Duration::from_secs(10))
79 .build()?,
80 refresh_interval: Duration::from_secs(300), max_backoff: Duration::from_secs(3600), on_demand_refresh_cooldown: Duration::from_secs(60), })
84 }
85
86 pub fn with_refresh_interval(mut self, interval: Duration) -> Self {
88 self.refresh_interval = interval;
89 self
90 }
91
92 pub fn with_max_backoff(mut self, max_backoff: Duration) -> Self {
94 self.max_backoff = max_backoff;
95 self
96 }
97
98 pub fn with_on_demand_refresh_cooldown(mut self, cooldown: Duration) -> Self {
100 self.on_demand_refresh_cooldown = cooldown;
101 self
102 }
103
104 async fn fetch_jwks(&self) -> Result<HashMap<String, DecodingKey>, ClaimsError> {
106 let response = self
107 .client
108 .get(&self.jwks_uri)
109 .send()
110 .await
111 .map_err(|e| ClaimsError::JwksFetchFailed(format!("HTTP request failed: {e}")))?;
112
113 if !response.status().is_success() {
114 return Err(ClaimsError::JwksFetchFailed(format!(
115 "HTTP error: {}",
116 response.status()
117 )));
118 }
119
120 let jwks: JwksResponse = response
121 .json()
122 .await
123 .map_err(|e| ClaimsError::JwksFetchFailed(format!("Failed to parse JWKS: {e}")))?;
124
125 let mut keys = HashMap::new();
126 for jwk in jwks.keys {
127 if jwk.kty == "RSA" {
128 let key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)
129 .map_err(|e| ClaimsError::JwksFetchFailed(format!("Invalid RSA key: {e}")))?;
130 keys.insert(jwk.kid, key);
131 }
132 }
133
134 if keys.is_empty() {
135 return Err(ClaimsError::JwksFetchFailed(
136 "No valid RSA keys found in JWKS".into(),
137 ));
138 }
139
140 Ok(keys)
141 }
142
143 fn calculate_backoff(&self, failures: u32) -> Duration {
145 let base = Duration::from_secs(60); let exponential = base * 2u32.pow(failures.min(10)); exponential.min(self.max_backoff)
148 }
149
150 async fn should_refresh(&self) -> bool {
152 let state = self.refresh_state.read().await;
153
154 match state.last_refresh {
155 None => true, Some(last) => {
157 let elapsed = last.elapsed();
158 if state.consecutive_failures == 0 {
159 elapsed >= self.refresh_interval
161 } else {
162 elapsed >= self.calculate_backoff(state.consecutive_failures)
164 }
165 }
166 }
167 }
168
169 async fn perform_refresh(&self) -> Result<(), ClaimsError> {
171 match self.fetch_jwks().await {
172 Ok(new_keys) => {
173 self.keys.store(Arc::new(new_keys));
175
176 let mut state = self.refresh_state.write().await;
178 state.last_refresh = Some(Instant::now());
179 state.consecutive_failures = 0;
180 state.last_error = None;
181
182 Ok(())
183 }
184 Err(e) => {
185 let mut state = self.refresh_state.write().await;
187 state.last_refresh = Some(Instant::now());
188 state.consecutive_failures += 1;
189 state.last_error = Some(e.to_string());
190
191 Err(e)
192 }
193 }
194 }
195
196 fn key_exists(&self, kid: &str) -> bool {
198 let keys = self.keys.load();
199 keys.contains_key(kid)
200 }
201
202 async fn check_refresh_throttle(&self, kid: &str) -> Result<(), ClaimsError> {
204 let state = self.refresh_state.read().await;
205 if let Some(last_on_demand) = state.last_on_demand_refresh {
206 let elapsed = last_on_demand.elapsed();
207 if elapsed < self.on_demand_refresh_cooldown {
208 let remaining = self.on_demand_refresh_cooldown.saturating_sub(elapsed);
209 tracing::debug!(
210 kid = kid,
211 remaining_secs = remaining.as_secs(),
212 "On-demand JWKS refresh throttled (cooldown active)"
213 );
214
215 if state.failed_kids.contains(kid) {
217 tracing::warn!(
218 kid = kid,
219 "Unknown kid repeatedly requested despite recent refresh attempts"
220 );
221 }
222
223 return Err(ClaimsError::UnknownKeyId(kid.to_owned()));
224 }
225 }
226 Ok(())
227 }
228
229 async fn handle_refresh_success(&self, kid: &str) -> Result<(), ClaimsError> {
231 let mut state = self.refresh_state.write().await;
232 state.last_on_demand_refresh = Some(Instant::now());
233
234 if self.key_exists(kid) {
236 state.failed_kids.remove(kid);
238 } else {
239 state.failed_kids.insert(kid.to_owned());
241 tracing::warn!(
242 kid = kid,
243 "Kid still not found after on-demand JWKS refresh"
244 );
245 }
246
247 Ok(())
248 }
249
250 async fn handle_refresh_failure(&self, kid: &str, error: ClaimsError) -> ClaimsError {
252 let mut state = self.refresh_state.write().await;
253 state.last_on_demand_refresh = Some(Instant::now());
254 state.failed_kids.insert(kid.to_owned());
255 error
256 }
257
258 async fn on_demand_refresh(&self, kid: &str) -> Result<(), ClaimsError> {
261 if self.key_exists(kid) {
263 return Ok(());
264 }
265
266 self.check_refresh_throttle(kid).await?;
268
269 tracing::info!(
271 kid = kid,
272 "Performing on-demand JWKS refresh for unknown kid"
273 );
274
275 match self.perform_refresh().await {
276 Ok(()) => self.handle_refresh_success(kid).await,
277 Err(e) => Err(self.handle_refresh_failure(kid, e).await),
278 }
279 }
280
281 fn get_key(&self, kid: &str) -> Option<DecodingKey> {
283 let keys = self.keys.load();
284 keys.get(kid).cloned()
285 }
286
287 fn validate_token(
289 token: &str,
290 key: &DecodingKey,
291 header: &Header,
292 ) -> Result<Value, ClaimsError> {
293 let mut validation = Validation::new(header.alg);
294
295 validation.validate_exp = false;
297 validation.validate_nbf = false;
298 validation.validate_aud = false;
299
300 let empty_claims: &[&str] = &[];
302 validation.set_required_spec_claims(empty_claims);
303
304 let token_data = decode::<Value>(token, key, &validation)
305 .map_err(|e| ClaimsError::DecodeFailed(format!("JWT validation failed: {e}")))?;
306
307 Ok(token_data.claims)
308 }
309}
310
311#[async_trait]
312impl KeyProvider for JwksKeyProvider {
313 fn name(&self) -> &'static str {
314 "jwks"
315 }
316
317 async fn validate_and_decode(&self, token: &str) -> Result<(Header, Value), ClaimsError> {
318 let token = token.trim_start_matches("Bearer ").trim();
320
321 let header = decode_header(token)
323 .map_err(|e| ClaimsError::DecodeFailed(format!("Invalid JWT header: {e}")))?;
324
325 let kid = header
326 .kid
327 .as_ref()
328 .ok_or_else(|| ClaimsError::DecodeFailed("Missing kid in JWT header".into()))?;
329
330 let key = if let Some(k) = self.get_key(kid) {
332 k
333 } else {
334 self.on_demand_refresh(kid).await?;
336
337 self.get_key(kid)
339 .ok_or_else(|| ClaimsError::UnknownKeyId(kid.clone()))?
340 };
341
342 let claims = Self::validate_token(token, &key, &header)?;
344
345 Ok((header, claims))
346 }
347
348 async fn refresh_keys(&self) -> Result<(), ClaimsError> {
349 if self.should_refresh().await {
350 self.perform_refresh().await
351 } else {
352 Ok(())
353 }
354 }
355}
356
357pub async fn run_jwks_refresh_task(provider: Arc<JwksKeyProvider>) {
359 let mut interval = tokio::time::interval(Duration::from_secs(60)); loop {
362 interval.tick().await;
363
364 if let Err(e) = provider.refresh_keys().await {
365 tracing::warn!("JWKS refresh failed: {}", e);
366 }
367 }
368}
369
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    // Tests that point at `*.local` hosts expect the fetch to fail.
    // NOTE(review): they rely on the resolver rejecting those names promptly.

    #[test]
    fn test_calculate_backoff() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://example.com/jwks")?;

        assert_eq!(provider.calculate_backoff(0), Duration::from_secs(60));
        assert_eq!(provider.calculate_backoff(1), Duration::from_secs(120));
        assert_eq!(provider.calculate_backoff(2), Duration::from_secs(240));
        assert_eq!(provider.calculate_backoff(3), Duration::from_secs(480));

        // Large failure counts are capped by `max_backoff`.
        assert_eq!(provider.calculate_backoff(100), provider.max_backoff);
        Ok(())
    }

    #[tokio::test]
    async fn test_should_refresh_on_first_call() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://example.com/jwks")?;
        assert!(provider.should_refresh().await);
        Ok(())
    }

    #[tokio::test]
    async fn test_key_storage() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://example.com/jwks")?;

        assert!(provider.get_key("test-kid").is_none());

        let mut keys = HashMap::new();
        keys.insert("test-kid".to_owned(), DecodingKey::from_secret(b"secret"));
        provider.keys.store(Arc::new(keys));

        assert!(provider.get_key("test-kid").is_some());
        Ok(())
    }

    #[tokio::test]
    async fn test_on_demand_refresh_returns_ok_when_key_exists() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://example.com/jwks")?;

        let mut keys = HashMap::new();
        keys.insert(
            "existing-kid".to_owned(),
            DecodingKey::from_secret(b"secret"),
        );
        provider.keys.store(Arc::new(keys));

        let result = provider.on_demand_refresh("existing-kid").await;
        assert!(result.is_ok());
        Ok(())
    }

    #[tokio::test]
    async fn test_on_demand_refresh_returns_error_for_missing_key_on_failed_fetch()
    -> Result<(), reqwest::Error> {
        let provider =
            JwksKeyProvider::new("https://invalid-domain-that-does-not-exist.local/jwks")?;

        let result = provider.on_demand_refresh("missing-kid").await;
        assert!(result.is_err());

        match result.expect_err("expected error for missing key") {
            ClaimsError::JwksFetchFailed(_) | ClaimsError::UnknownKeyId(_) => {}
            other => panic!("Expected JwksFetchFailed or UnknownKeyId, got: {other:?}"),
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_on_demand_refresh_respects_cooldown() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://invalid-domain.local/jwks")?
            .with_on_demand_refresh_cooldown(Duration::from_secs(5));

        let result1 = provider.on_demand_refresh("test-kid").await;
        assert!(result1.is_err());

        // Second attempt lands inside the cooldown and must be throttled.
        let result2 = provider.on_demand_refresh("test-kid").await;
        assert!(result2.is_err());

        match result2.expect_err("expected throttle error") {
            ClaimsError::UnknownKeyId(_) => {}
            other => panic!("Expected UnknownKeyId during cooldown, got: {other:?}"),
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_on_demand_refresh_tracks_failed_kids() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://invalid-domain.local/jwks")?
            .with_on_demand_refresh_cooldown(Duration::from_millis(100));

        let result = provider.on_demand_refresh("failed-kid").await;
        assert!(result.is_err());

        let state = provider.refresh_state.read().await;
        assert!(state.failed_kids.contains("failed-kid"));
        Ok(())
    }

    #[tokio::test]
    async fn test_validate_and_decode_with_unknown_kid() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://invalid-domain.local/jwks")?
            .with_on_demand_refresh_cooldown(Duration::from_millis(100));

        // Header: {"alg":"RS256","kid":"test-kid"}; the kid is not cached.
        let token =
            "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIn0.eyJzdWIiOiIxMjM0NTY3ODkwIn0.invalid";

        let result = provider.validate_and_decode(token).await;
        assert!(result.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_validate_and_decode_rejects_token_without_kid() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://invalid-domain.local/jwks")?;

        // Header: {"alg":"RS256"}; there is no kid, so decoding stops before any fetch.
        let token = "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxIn0.invalid";

        match provider.validate_and_decode(token).await {
            Err(ClaimsError::DecodeFailed(_)) => {}
            other => panic!("Expected DecodeFailed for token without kid, got: {other:?}"),
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_perform_refresh_updates_state_on_failure() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://invalid-domain.local/jwks")?;

        {
            let mut state = provider.refresh_state.write().await;
            state.consecutive_failures = 3;
            state.last_error = Some("Previous error".to_owned());
        }

        // The fetch fails, so the failure counter grows and an error is recorded.
        let _ = provider.perform_refresh().await;

        let state = provider.refresh_state.read().await;
        assert_eq!(state.consecutive_failures, 4);
        assert!(state.last_error.is_some());
        Ok(())
    }
}