1use crate::{claims_error::ClaimsError, plugin_traits::KeyProvider};
2use arc_swap::ArcSwap;
3use async_trait::async_trait;
4use jsonwebtoken::{DecodingKey, Header, Validation, decode, decode_header};
5use serde::Deserialize;
6use serde_json::Value;
7use std::collections::{HashMap, HashSet};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::RwLock;
11use tokio::time::Instant;
12use tokio_util::sync::CancellationToken;
13
14#[derive(Debug, Clone, Deserialize)]
15struct Jwk {
16 kid: String,
17 kty: String,
18 #[serde(rename = "use")]
19 #[allow(dead_code)]
20 use_: Option<String>,
21 n: String,
22 e: String,
23 #[allow(dead_code)]
24 alg: Option<String>,
25}
26
27#[derive(Debug, Clone, Deserialize)]
28struct JwksResponse {
29 keys: Vec<Jwk>,
30}
31
32#[must_use]
36pub struct JwksKeyProvider {
37 jwks_uri: String,
39
40 keys: Arc<ArcSwap<HashMap<String, DecodingKey>>>,
42
43 refresh_state: Arc<RwLock<RefreshState>>,
45
46 client: modkit_http::HttpClient,
49
50 refresh_interval: Duration,
52
53 max_backoff: Duration,
55
56 on_demand_refresh_cooldown: Duration,
58}
59
60#[derive(Debug, Default)]
61struct RefreshState {
62 last_refresh: Option<Instant>,
63 last_on_demand_refresh: Option<Instant>,
64 consecutive_failures: u32,
65 last_error: Option<String>,
66 failed_kids: HashSet<String>,
67}
68
69impl JwksKeyProvider {
70 pub fn new(jwks_uri: impl Into<String>) -> Result<Self, modkit_http::HttpError> {
75 Self::with_http_timeout(jwks_uri, Duration::from_secs(10))
76 }
77
78 pub fn with_http_timeout(
83 jwks_uri: impl Into<String>,
84 timeout: Duration,
85 ) -> Result<Self, modkit_http::HttpError> {
86 let client = modkit_http::HttpClient::builder()
87 .timeout(timeout)
88 .retry(None) .build()?;
90
91 Ok(Self {
92 jwks_uri: jwks_uri.into(),
93 keys: Arc::new(ArcSwap::from_pointee(HashMap::new())),
94 refresh_state: Arc::new(RwLock::new(RefreshState::default())),
95 client,
96 refresh_interval: Duration::from_secs(300), max_backoff: Duration::from_secs(3600), on_demand_refresh_cooldown: Duration::from_secs(60), })
100 }
101
102 pub fn try_new(jwks_uri: impl Into<String>) -> Result<Self, modkit_http::HttpError> {
107 Self::new(jwks_uri)
108 }
109
110 pub fn with_refresh_interval(mut self, interval: Duration) -> Self {
112 self.refresh_interval = interval;
113 self
114 }
115
116 pub fn with_max_backoff(mut self, max_backoff: Duration) -> Self {
118 self.max_backoff = max_backoff;
119 self
120 }
121
122 pub fn with_on_demand_refresh_cooldown(mut self, cooldown: Duration) -> Self {
124 self.on_demand_refresh_cooldown = cooldown;
125 self
126 }
127
128 async fn fetch_jwks(&self) -> Result<HashMap<String, DecodingKey>, ClaimsError> {
130 let jwks: JwksResponse = self
132 .client
133 .get(&self.jwks_uri)
134 .send()
135 .await
136 .map_err(|e| map_http_error(&e))?
137 .json()
138 .await
139 .map_err(|e| map_http_error(&e))?;
140
141 let mut keys = HashMap::new();
142 for jwk in jwks.keys {
143 if jwk.kty == "RSA" {
144 let key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)
145 .map_err(|e| ClaimsError::JwksFetchFailed(format!("Invalid RSA key: {e}")))?;
146 keys.insert(jwk.kid, key);
147 }
148 }
149
150 if keys.is_empty() {
151 return Err(ClaimsError::JwksFetchFailed(
152 "No valid RSA keys found in JWKS".into(),
153 ));
154 }
155
156 Ok(keys)
157 }
158
159 fn calculate_backoff(&self, failures: u32) -> Duration {
161 let base = Duration::from_secs(60); let exponential = base * 2u32.pow(failures.min(10)); exponential.min(self.max_backoff)
164 }
165
166 async fn should_refresh(&self) -> bool {
168 let state = self.refresh_state.read().await;
169
170 match state.last_refresh {
171 None => true, Some(last) => {
173 let elapsed = last.elapsed();
174 if state.consecutive_failures == 0 {
175 elapsed >= self.refresh_interval
177 } else {
178 elapsed >= self.calculate_backoff(state.consecutive_failures)
180 }
181 }
182 }
183 }
184
185 async fn perform_refresh(&self) -> Result<(), ClaimsError> {
187 match self.fetch_jwks().await {
188 Ok(new_keys) => {
189 self.keys.store(Arc::new(new_keys));
191
192 let mut state = self.refresh_state.write().await;
194 state.last_refresh = Some(Instant::now());
195 state.consecutive_failures = 0;
196 state.last_error = None;
197
198 Ok(())
199 }
200 Err(e) => {
201 let mut state = self.refresh_state.write().await;
203 state.last_refresh = Some(Instant::now());
204 state.consecutive_failures += 1;
205 state.last_error = Some(e.to_string());
206
207 Err(e)
208 }
209 }
210 }
211
212 fn key_exists(&self, kid: &str) -> bool {
214 let keys = self.keys.load();
215 keys.contains_key(kid)
216 }
217
218 async fn check_refresh_throttle(&self, kid: &str) -> Result<(), ClaimsError> {
220 let state = self.refresh_state.read().await;
221 if let Some(last_on_demand) = state.last_on_demand_refresh {
222 let elapsed = last_on_demand.elapsed();
223 if elapsed < self.on_demand_refresh_cooldown {
224 let remaining = self.on_demand_refresh_cooldown.saturating_sub(elapsed);
225 tracing::debug!(
226 kid = kid,
227 remaining_secs = remaining.as_secs(),
228 "On-demand JWKS refresh throttled (cooldown active)"
229 );
230
231 if state.failed_kids.contains(kid) {
233 tracing::warn!(
234 kid = kid,
235 "Unknown kid repeatedly requested despite recent refresh attempts"
236 );
237 }
238
239 return Err(ClaimsError::UnknownKeyId(kid.to_owned()));
240 }
241 }
242 Ok(())
243 }
244
245 async fn handle_refresh_success(&self, kid: &str) -> Result<(), ClaimsError> {
247 let mut state = self.refresh_state.write().await;
248 state.last_on_demand_refresh = Some(Instant::now());
249
250 if self.key_exists(kid) {
252 state.failed_kids.remove(kid);
254 } else {
255 state.failed_kids.insert(kid.to_owned());
257 tracing::warn!(
258 kid = kid,
259 "Kid still not found after on-demand JWKS refresh"
260 );
261 }
262
263 Ok(())
264 }
265
266 async fn handle_refresh_failure(&self, kid: &str, error: ClaimsError) -> ClaimsError {
268 let mut state = self.refresh_state.write().await;
269 state.last_on_demand_refresh = Some(Instant::now());
270 state.failed_kids.insert(kid.to_owned());
271 error
272 }
273
274 async fn on_demand_refresh(&self, kid: &str) -> Result<(), ClaimsError> {
277 if self.key_exists(kid) {
279 return Ok(());
280 }
281
282 self.check_refresh_throttle(kid).await?;
284
285 tracing::info!(
287 kid = kid,
288 "Performing on-demand JWKS refresh for unknown kid"
289 );
290
291 match self.perform_refresh().await {
292 Ok(()) => self.handle_refresh_success(kid).await,
293 Err(e) => Err(self.handle_refresh_failure(kid, e).await),
294 }
295 }
296
297 fn get_key(&self, kid: &str) -> Option<DecodingKey> {
299 let keys = self.keys.load();
300 keys.get(kid).cloned()
301 }
302
303 fn validate_token(
305 token: &str,
306 key: &DecodingKey,
307 header: &Header,
308 ) -> Result<Value, ClaimsError> {
309 let mut validation = Validation::new(header.alg);
310
311 validation.validate_exp = false;
313 validation.validate_nbf = false;
314 validation.validate_aud = false;
315
316 let empty_claims: &[&str] = &[];
318 validation.set_required_spec_claims(empty_claims);
319
320 let token_data = decode::<Value>(token, key, &validation)
321 .map_err(|e| ClaimsError::DecodeFailed(format!("JWT validation failed: {e}")))?;
322
323 Ok(token_data.claims)
324 }
325}
326
327#[async_trait]
328impl KeyProvider for JwksKeyProvider {
329 fn name(&self) -> &'static str {
330 "jwks"
331 }
332
333 async fn validate_and_decode(&self, token: &str) -> Result<(Header, Value), ClaimsError> {
334 let token = token.trim_start_matches("Bearer ").trim();
336
337 let header = decode_header(token)
339 .map_err(|e| ClaimsError::DecodeFailed(format!("Invalid JWT header: {e}")))?;
340
341 let kid = header
342 .kid
343 .as_ref()
344 .ok_or_else(|| ClaimsError::DecodeFailed("Missing kid in JWT header".into()))?;
345
346 let key = if let Some(k) = self.get_key(kid) {
348 k
349 } else {
350 self.on_demand_refresh(kid).await?;
352
353 self.get_key(kid)
355 .ok_or_else(|| ClaimsError::UnknownKeyId(kid.clone()))?
356 };
357
358 let claims = Self::validate_token(token, &key, &header)?;
360
361 Ok((header, claims))
362 }
363
364 async fn refresh_keys(&self) -> Result<(), ClaimsError> {
365 if self.should_refresh().await {
366 self.perform_refresh().await
367 } else {
368 Ok(())
369 }
370 }
371}
372
373pub async fn run_jwks_refresh_task(
396 provider: Arc<JwksKeyProvider>,
397 cancellation_token: CancellationToken,
398) {
399 let mut interval = tokio::time::interval(Duration::from_secs(60)); loop {
402 tokio::select! {
403 () = cancellation_token.cancelled() => {
404 tracing::info!("JWKS refresh task shutting down");
405 break;
406 }
407 _ = interval.tick() => {
408 if let Err(e) = provider.refresh_keys().await {
409 tracing::warn!("JWKS refresh failed: {}", e);
410 }
411 }
412 }
413 }
414}
415
416fn map_http_error(e: &modkit_http::HttpError) -> ClaimsError {
418 ClaimsError::JwksFetchFailed(crate::http_error::format_http_error(e, "JWKS"))
419}
420
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use httpmock::prelude::*;

    /// Builds a provider for `uri` whose client allows plain `http://`, as
    /// needed by the local `httpmock` server.
    ///
    /// Refresh settings are 300 s interval, 3600 s max backoff and 60 s
    /// on-demand cooldown.
    fn test_provider_with_http(uri: &str) -> JwksKeyProvider {
        let client = modkit_http::HttpClient::builder()
            .timeout(Duration::from_secs(5))
            .retry(None)
            .allow_insecure_http()
            .build()
            .expect("failed to create test HTTP client");

        JwksKeyProvider {
            jwks_uri: uri.to_owned(),
            keys: Arc::new(ArcSwap::from_pointee(HashMap::new())),
            refresh_state: Arc::new(RwLock::new(RefreshState::default())),
            client,
            refresh_interval: Duration::from_secs(300),
            max_backoff: Duration::from_secs(3600),
            on_demand_refresh_cooldown: Duration::from_secs(60),
        }
    }

    /// Builds a provider through the public constructor, without
    /// `allow_insecure_http`. Used by tests that never perform a request.
    fn test_provider(uri: &str) -> JwksKeyProvider {
        JwksKeyProvider::new(uri).expect("failed to create test provider")
    }

    /// JWKS document with a single RSA signing key whose `kid` is `test-key-1`.
    fn valid_jwks_json() -> &'static str {
        r#"{
            "keys": [{
                "kty": "RSA",
                "kid": "test-key-1",
                "use": "sig",
                "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
                "e": "AQAB",
                "alg": "RS256"
            }]
        }"#
    }

    /// Backoff doubles from a 60 s base and saturates at `max_backoff`.
    #[tokio::test]
    async fn test_calculate_backoff() {
        let provider = test_provider("https://example.com/jwks");

        assert_eq!(provider.calculate_backoff(0), Duration::from_secs(60));
        assert_eq!(provider.calculate_backoff(1), Duration::from_secs(120));
        assert_eq!(provider.calculate_backoff(2), Duration::from_secs(240));
        assert_eq!(provider.calculate_backoff(3), Duration::from_secs(480));

        assert_eq!(provider.calculate_backoff(100), provider.max_backoff);
    }

    /// A provider that has never refreshed wants to refresh immediately.
    #[tokio::test]
    async fn test_should_refresh_on_first_call() {
        let provider = test_provider("https://example.com/jwks");
        assert!(provider.should_refresh().await);
    }

    /// Keys stored in the snapshot become visible through `get_key`.
    #[tokio::test]
    async fn test_key_storage() {
        let provider = test_provider("https://example.com/jwks");

        assert!(provider.get_key("test-kid").is_none());

        let mut keys = HashMap::new();
        keys.insert("test-kid".to_owned(), DecodingKey::from_secret(b"secret"));
        provider.keys.store(Arc::new(keys));

        assert!(provider.get_key("test-kid").is_some());
    }

    /// An on-demand refresh for an already-known `kid` is a no-op success.
    #[tokio::test]
    async fn test_on_demand_refresh_returns_ok_when_key_exists() {
        let provider = test_provider("https://example.com/jwks");

        let mut keys = HashMap::new();
        keys.insert(
            "existing-kid".to_owned(),
            DecodingKey::from_secret(b"secret"),
        );
        provider.keys.store(Arc::new(keys));

        let result = provider.on_demand_refresh("existing-kid").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_try_new_returns_result() {
        let result = JwksKeyProvider::try_new("https://example.com/jwks");
        assert!(result.is_ok());
    }

    /// A well-formed JWKS response is parsed and its RSA key stored by `kid`.
    #[tokio::test]
    async fn test_fetch_jwks_success_with_valid_json() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let result = provider.perform_refresh().await;
        assert!(result.is_ok(), "Expected success, got: {result:?}");

        assert!(
            provider.get_key("test-key-1").is_some(),
            "Expected key 'test-key-1' to be stored"
        );

        mock.assert();
    }

    /// HTTP status failures are reported as status errors, not parse errors.
    #[tokio::test]
    async fn test_fetch_jwks_http_404_error_mapping() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(404).body("Not Found");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let result = provider.perform_refresh().await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("JWKS HTTP 404"),
            "Expected error to contain 'JWKS HTTP 404', got: {err_msg}"
        );
        assert!(
            !err_msg.to_lowercase().contains("parse"),
            "HTTP status error should not mention 'parse', got: {err_msg}"
        );

        mock.assert();
    }

    #[tokio::test]
    async fn test_fetch_jwks_http_500_error_mapping() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(500).body("Internal Server Error");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let result = provider.perform_refresh().await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("JWKS HTTP 500"),
            "Expected error to contain 'JWKS HTTP 500', got: {err_msg}"
        );

        mock.assert();
    }

    /// A 200 response with a non-JSON body is reported as a JSON parse failure.
    #[tokio::test]
    async fn test_fetch_jwks_invalid_json_error_mapping() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body("this is not valid json");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let result = provider.perform_refresh().await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("JWKS JSON parse failed"),
            "Expected error to contain 'JWKS JSON parse failed', got: {err_msg}"
        );

        mock.assert();
    }

    /// A JWKS document without any RSA keys is treated as a failed refresh.
    #[tokio::test]
    async fn test_fetch_jwks_empty_keys_error() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"keys": []}"#);
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let result = provider.perform_refresh().await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("No valid RSA keys"),
            "Expected error about no RSA keys, got: {err_msg}"
        );

        mock.assert();
    }

    /// A second on-demand refresh inside the cooldown is rejected without an HTTP call.
    #[tokio::test]
    async fn test_on_demand_refresh_respects_cooldown() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(404).body("Not Found");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url)
            .with_on_demand_refresh_cooldown(Duration::from_secs(60));

        let result1 = provider.on_demand_refresh("test-kid").await;
        assert!(result1.is_err());

        let result2 = provider.on_demand_refresh("test-kid").await;
        assert!(result2.is_err());

        match result2.unwrap_err() {
            ClaimsError::UnknownKeyId(_) => {}
            other => panic!("Expected UnknownKeyId during cooldown, got: {other:?}"),
        }

        mock.assert_calls(1);
    }

    /// A `kid` whose on-demand refresh failed is remembered in `failed_kids`.
    #[tokio::test]
    async fn test_on_demand_refresh_tracks_failed_kids() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(404).body("Not Found");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url)
            .with_on_demand_refresh_cooldown(Duration::from_millis(100));

        let result = provider.on_demand_refresh("failed-kid").await;
        assert!(result.is_err());

        let state = provider.refresh_state.read().await;
        assert!(state.failed_kids.contains("failed-kid"));
    }

    /// A failed refresh increments the failure counter and records an error.
    #[tokio::test]
    async fn test_perform_refresh_updates_state_on_failure() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(500).body("Server Error");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        {
            let mut state = provider.refresh_state.write().await;
            state.consecutive_failures = 3;
            state.last_error = Some("Previous error".to_owned());
        }

        _ = provider.perform_refresh().await;

        let state = provider.refresh_state.read().await;
        assert_eq!(state.consecutive_failures, 4);
        assert!(state.last_error.is_some());
    }

    /// A successful refresh clears the failure counter and the last error.
    #[tokio::test]
    async fn test_perform_refresh_resets_state_on_success() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        {
            let mut state = provider.refresh_state.write().await;
            state.consecutive_failures = 5;
            state.last_error = Some("Previous error".to_owned());
        }

        let result = provider.perform_refresh().await;
        assert!(result.is_ok());

        let state = provider.refresh_state.read().await;
        assert_eq!(state.consecutive_failures, 0);
        assert!(state.last_error.is_none());
    }

    /// Right after a successful refresh, no scheduled refresh is due
    /// (the refresh interval is 300 s).
    #[tokio::test]
    async fn test_should_refresh_false_after_successful_refresh() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        provider
            .perform_refresh()
            .await
            .expect("refresh against a valid JWKS should succeed");

        assert!(!provider.should_refresh().await);
    }

    /// A token whose `kid` is absent from the refreshed JWKS yields `UnknownKeyId`.
    #[tokio::test]
    async fn test_validate_and_decode_with_unknown_kid() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url)
            .with_on_demand_refresh_cooldown(Duration::from_millis(100));

        // Header: {"alg":"RS256","kid":"nonexistent-kid"}
        let token = "eyJhbGciOiJSUzI1NiIsImtpZCI6Im5vbmV4aXN0ZW50LWtpZCJ9.\
                     eyJzdWIiOiIxMjM0NTY3ODkwIn0.invalid";

        let result = provider.validate_and_decode(token).await;
        assert!(result.is_err());

        match result.unwrap_err() {
            ClaimsError::UnknownKeyId(kid) => {
                assert_eq!(kid, "nonexistent-kid");
            }
            other => panic!("Expected UnknownKeyId, got: {other:?}"),
        }
    }

    /// A token whose header has no `kid` is rejected before any key lookup.
    #[tokio::test]
    async fn test_validate_and_decode_with_missing_kid() {
        let provider = test_provider("https://example.com/jwks");

        // Header: {"alg":"RS256"}
        let token = "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.invalid";

        let result = provider.validate_and_decode(token).await;

        match result {
            Err(ClaimsError::DecodeFailed(msg)) => {
                assert!(
                    msg.contains("Missing kid"),
                    "Expected a missing-kid error, got: {msg}"
                );
            }
            other => panic!("Expected DecodeFailed for missing kid, got: {other:?}"),
        }
    }
}