1use crate::{claims_error::ClaimsError, traits::KeyProvider};
2use arc_swap::ArcSwap;
3use async_trait::async_trait;
4use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
5use jsonwebtoken::{DecodingKey, Header, decode_header};
6use serde::Deserialize;
7use serde_json::Value;
8use std::collections::{HashMap, HashSet};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::RwLock;
12use tokio::time::Instant;
13use tokio_util::sync::CancellationToken;
14
15#[derive(Debug, Clone, Deserialize)]
16struct Jwk {
17 kid: String,
18 kty: String,
19 #[serde(rename = "use")]
20 #[allow(dead_code)]
21 use_: Option<String>,
22 n: String,
23 e: String,
24 #[allow(dead_code)]
25 alg: Option<String>,
26}
27
28#[derive(Debug, Clone, Deserialize)]
29struct JwksResponse {
30 keys: Vec<Jwk>,
31}
32
33type HeaderExtrasHandler = dyn Fn(&str, &Value) -> Option<String> + Send + Sync;
35
/// JOSE header parameter names that are always left untouched by the header-extras
/// handler, whatever their JSON type.
///
/// NOTE(review): presumably these are the registered parameter names that
/// `jsonwebtoken::Header` handles itself (or that must never be coerced). Keep in sync with
/// the `jsonwebtoken` version in use.
const STANDARD_HEADER_FIELDS: &[&str] = &[
    "typ", "alg", "cty", "jku", "jwk", "kid",
    "x5u", "x5c", "x5t", "x5t#S256", "crit", "enc",
    "zip", "url", "nonce", "epk", "apu", "apv",
    "iv", "tag", "p2s", "p2c", "b64",
];
42
43#[must_use]
47pub struct JwksKeyProvider {
48 jwks_uri: String,
50
51 keys: Arc<ArcSwap<HashMap<String, DecodingKey>>>,
53
54 refresh_state: Arc<RwLock<RefreshState>>,
56
57 client: modkit_http::HttpClient,
60
61 refresh_interval: Duration,
63
64 max_backoff: Duration,
66
67 on_demand_refresh_cooldown: Duration,
69
70 header_extras_handler: Option<Arc<HeaderExtrasHandler>>,
74}
75
/// Mutable bookkeeping shared by the periodic and on-demand JWKS refresh paths.
#[derive(Default, Debug)]
struct RefreshState {
    /// When the latest refresh attempt finished, whether it succeeded or not.
    last_refresh: Option<Instant>,
    /// When the latest refresh triggered by an unknown `kid` took place.
    last_on_demand_refresh: Option<Instant>,
    /// Refresh failures since the last success; drives the retry backoff.
    consecutive_failures: u32,
    /// Text of the latest refresh failure; cleared on success.
    last_error: Option<String>,
    /// `kid`s still unresolved after an on-demand refresh (used for diagnostics/logging).
    failed_kids: HashSet<String>,
}
84
85impl JwksKeyProvider {
86 pub fn new(jwks_uri: impl Into<String>) -> Result<Self, modkit_http::HttpError> {
91 Self::with_http_timeout(jwks_uri, Duration::from_secs(10))
92 }
93
94 pub fn with_http_timeout(
99 jwks_uri: impl Into<String>,
100 timeout: Duration,
101 ) -> Result<Self, modkit_http::HttpError> {
102 let client = modkit_http::HttpClient::builder()
103 .timeout(timeout)
104 .retry(None) .build()?;
106
107 Ok(Self {
108 jwks_uri: jwks_uri.into(),
109 keys: Arc::new(ArcSwap::from_pointee(HashMap::new())),
110 refresh_state: Arc::new(RwLock::new(RefreshState::default())),
111 client,
112 refresh_interval: Duration::from_secs(300), max_backoff: Duration::from_secs(3600), on_demand_refresh_cooldown: Duration::from_secs(60), header_extras_handler: None,
116 })
117 }
118
119 pub fn try_new(jwks_uri: impl Into<String>) -> Result<Self, modkit_http::HttpError> {
124 Self::new(jwks_uri)
125 }
126
127 pub fn with_refresh_interval(mut self, interval: Duration) -> Self {
129 self.refresh_interval = interval;
130 self
131 }
132
133 pub fn with_max_backoff(mut self, max_backoff: Duration) -> Self {
135 self.max_backoff = max_backoff;
136 self
137 }
138
139 pub fn with_on_demand_refresh_cooldown(mut self, cooldown: Duration) -> Self {
141 self.on_demand_refresh_cooldown = cooldown;
142 self
143 }
144
145 pub fn with_header_extras_stringified(self) -> Self {
151 self.with_header_extras_handler(|_, v| Some(v.to_string()))
152 }
153
154 pub fn with_header_extras_handler(
161 mut self,
162 handler: impl Fn(&str, &Value) -> Option<String> + Send + Sync + 'static,
163 ) -> Self {
164 self.header_extras_handler = Some(Arc::new(handler));
165 self
166 }
167
168 async fn fetch_jwks(&self) -> Result<HashMap<String, DecodingKey>, ClaimsError> {
170 let jwks: JwksResponse = self
172 .client
173 .get(&self.jwks_uri)
174 .send()
175 .await
176 .map_err(|e| map_http_error(&e))?
177 .json()
178 .await
179 .map_err(|e| map_http_error(&e))?;
180
181 let mut keys = HashMap::new();
182 for jwk in jwks.keys {
183 if jwk.kty == "RSA" {
184 let key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)
185 .map_err(|e| ClaimsError::JwksFetchFailed(format!("Invalid RSA key: {e}")))?;
186 keys.insert(jwk.kid, key);
187 }
188 }
189
190 if keys.is_empty() {
191 return Err(ClaimsError::JwksFetchFailed(
192 "No valid RSA keys found in JWKS".into(),
193 ));
194 }
195
196 Ok(keys)
197 }
198
199 fn calculate_backoff(&self, failures: u32) -> Duration {
201 let base = Duration::from_secs(60); let exponential = base * 2u32.pow(failures.min(10)); exponential.min(self.max_backoff)
204 }
205
206 async fn should_refresh(&self) -> bool {
208 let state = self.refresh_state.read().await;
209
210 match state.last_refresh {
211 None => true, Some(last) => {
213 let elapsed = last.elapsed();
214 if state.consecutive_failures == 0 {
215 elapsed >= self.refresh_interval
217 } else {
218 elapsed >= self.calculate_backoff(state.consecutive_failures)
220 }
221 }
222 }
223 }
224
225 async fn perform_refresh(&self) -> Result<(), ClaimsError> {
227 match self.fetch_jwks().await {
228 Ok(new_keys) => {
229 self.keys.store(Arc::new(new_keys));
231
232 let mut state = self.refresh_state.write().await;
234 state.last_refresh = Some(Instant::now());
235 state.consecutive_failures = 0;
236 state.last_error = None;
237
238 Ok(())
239 }
240 Err(e) => {
241 let mut state = self.refresh_state.write().await;
243 state.last_refresh = Some(Instant::now());
244 state.consecutive_failures += 1;
245 state.last_error = Some(e.to_string());
246
247 Err(e)
248 }
249 }
250 }
251
252 fn key_exists(&self, kid: &str) -> bool {
254 let keys = self.keys.load();
255 keys.contains_key(kid)
256 }
257
258 async fn check_refresh_throttle(&self, kid: &str) -> Result<(), ClaimsError> {
260 let state = self.refresh_state.read().await;
261 if let Some(last_on_demand) = state.last_on_demand_refresh {
262 let elapsed = last_on_demand.elapsed();
263 if elapsed < self.on_demand_refresh_cooldown {
264 let remaining = self.on_demand_refresh_cooldown.saturating_sub(elapsed);
265 tracing::debug!(
266 kid = kid,
267 remaining_secs = remaining.as_secs(),
268 "On-demand JWKS refresh throttled (cooldown active)"
269 );
270
271 if state.failed_kids.contains(kid) {
273 tracing::warn!(
274 kid = kid,
275 "Unknown kid repeatedly requested despite recent refresh attempts"
276 );
277 }
278
279 return Err(ClaimsError::UnknownKeyId(kid.to_owned()));
280 }
281 }
282 Ok(())
283 }
284
285 async fn handle_refresh_success(&self, kid: &str) -> Result<(), ClaimsError> {
287 let mut state = self.refresh_state.write().await;
288 state.last_on_demand_refresh = Some(Instant::now());
289
290 if self.key_exists(kid) {
292 state.failed_kids.remove(kid);
294 } else {
295 state.failed_kids.insert(kid.to_owned());
297 tracing::warn!(
298 kid = kid,
299 "Kid still not found after on-demand JWKS refresh"
300 );
301 }
302
303 Ok(())
304 }
305
306 async fn handle_refresh_failure(&self, kid: &str, error: ClaimsError) -> ClaimsError {
308 let mut state = self.refresh_state.write().await;
309 state.last_on_demand_refresh = Some(Instant::now());
310 state.failed_kids.insert(kid.to_owned());
311 error
312 }
313
314 async fn on_demand_refresh(&self, kid: &str) -> Result<(), ClaimsError> {
317 if self.key_exists(kid) {
319 return Ok(());
320 }
321
322 self.check_refresh_throttle(kid).await?;
324
325 tracing::info!(
327 kid = kid,
328 "Performing on-demand JWKS refresh for unknown kid"
329 );
330
331 match self.perform_refresh().await {
332 Ok(()) => self.handle_refresh_success(kid).await,
333 Err(e) => Err(self.handle_refresh_failure(kid, e).await),
334 }
335 }
336
337 fn get_key(&self, kid: &str) -> Option<DecodingKey> {
339 let keys = self.keys.load();
340 keys.get(kid).cloned()
341 }
342
343 fn validate_token(
349 token: &str,
350 key: &DecodingKey,
351 header: &Header,
352 ) -> Result<Value, ClaimsError> {
353 let parts: Vec<&str> = token.splitn(4, '.').collect();
355 if parts.len() != 3 {
356 return Err(ClaimsError::DecodeFailed("Invalid JWT structure".into()));
357 }
358 let signing_input = &token[..parts[0].len() + 1 + parts[1].len()];
359 let payload_b64 = parts[1];
360 let signature = parts[2];
361
362 let valid =
364 jsonwebtoken::crypto::verify(signature, signing_input.as_bytes(), key, header.alg)
365 .map_err(|e| {
366 ClaimsError::DecodeFailed(format!("JWT signature verification failed: {e}"))
367 })?;
368 if !valid {
369 return Err(ClaimsError::InvalidSignature);
370 }
371
372 let payload_bytes = URL_SAFE_NO_PAD
374 .decode(payload_b64.trim_end_matches('='))
375 .map_err(|e| ClaimsError::DecodeFailed(format!("JWT payload decode failed: {e}")))?;
376 let claims: Value = serde_json::from_slice(&payload_bytes)
377 .map_err(|e| ClaimsError::DecodeFailed(format!("JWT claims parse failed: {e}")))?;
378
379 Ok(claims)
380 }
381}
382
383#[async_trait]
384impl KeyProvider for JwksKeyProvider {
385 fn name(&self) -> &'static str {
386 "jwks"
387 }
388
389 async fn validate_and_decode(&self, token: &str) -> Result<(Header, Value), ClaimsError> {
390 let token = token.trim_start_matches("Bearer ").trim();
392
393 let header = match &self.header_extras_handler {
395 Some(handler) => decode_header_with_handler(token, handler.as_ref()),
396 None => decode_header(token),
397 }
398 .map_err(|e| ClaimsError::DecodeFailed(format!("Invalid JWT header: {e}")))?;
399
400 let kid = header
401 .kid
402 .as_ref()
403 .ok_or_else(|| ClaimsError::DecodeFailed("Missing kid in JWT header".into()))?;
404
405 let key = if let Some(k) = self.get_key(kid) {
407 k
408 } else {
409 self.on_demand_refresh(kid).await?;
411
412 self.get_key(kid)
414 .ok_or_else(|| ClaimsError::UnknownKeyId(kid.clone()))?
415 };
416
417 let claims = Self::validate_token(token, &key, &header)?;
419
420 Ok((header, claims))
421 }
422
423 async fn refresh_keys(&self) -> Result<(), ClaimsError> {
424 if self.should_refresh().await {
425 self.perform_refresh().await
426 } else {
427 Ok(())
428 }
429 }
430}
431
432pub async fn run_jwks_refresh_task(
455 provider: Arc<JwksKeyProvider>,
456 cancellation_token: CancellationToken,
457) {
458 let mut interval = tokio::time::interval(Duration::from_secs(60)); loop {
461 tokio::select! {
462 () = cancellation_token.cancelled() => {
463 tracing::info!("JWKS refresh task shutting down");
464 break;
465 }
466 _ = interval.tick() => {
467 if let Err(e) = provider.refresh_keys().await {
468 tracing::warn!("JWKS refresh failed: {}", e);
469 }
470 }
471 }
472 }
473}
474
475fn decode_header_with_handler(
479 token: &str,
480 handler: &dyn Fn(&str, &Value) -> Option<String>,
481) -> Result<Header, jsonwebtoken::errors::Error> {
482 let header_b64 = token
483 .split('.')
484 .next()
485 .ok_or(jsonwebtoken::errors::ErrorKind::InvalidToken)?;
486
487 let header_bytes = URL_SAFE_NO_PAD
488 .decode(header_b64.trim_end_matches('='))
489 .map_err(jsonwebtoken::errors::ErrorKind::Base64)?;
490
491 let mut json: serde_json::Map<String, Value> = serde_json::from_slice(&header_bytes)?;
492
493 json.retain(|key, value| {
494 if STANDARD_HEADER_FIELDS.contains(&key.as_str()) || value.is_string() {
495 return true;
496 }
497 match handler(key, value) {
498 Some(s) => {
499 *value = Value::String(s);
500 true
501 }
502 None => false,
503 }
504 });
505
506 Ok(serde_json::from_value(Value::Object(json))?)
507}
508
509fn map_http_error(e: &modkit_http::HttpError) -> ClaimsError {
511 ClaimsError::JwksFetchFailed(crate::http_error::format_http_error(e, "JWKS"))
512}
513
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    // Tests for `JwksKeyProvider`. HTTP interactions are served by an in-process `httpmock`
    // server; signature tests use the fixed RSA key pair defined further below.
    use super::*;
    use httpmock::prelude::*;

    /// Builds a provider for `uri` directly from its fields.
    ///
    /// Uses a 5 s HTTP timeout, no retries, tuning values 300 s / 3600 s / 60 s and no
    /// header-extras handler.
    fn test_provider_with_http(uri: &str) -> JwksKeyProvider {
        let client = modkit_http::HttpClient::builder()
            .timeout(Duration::from_secs(5))
            .retry(None)
            .build()
            .expect("failed to create test HTTP client");

        JwksKeyProvider {
            jwks_uri: uri.to_owned(),
            keys: Arc::new(ArcSwap::from_pointee(HashMap::new())),
            refresh_state: Arc::new(RwLock::new(RefreshState::default())),
            client,
            refresh_interval: Duration::from_secs(300),
            max_backoff: Duration::from_secs(3600),
            on_demand_refresh_cooldown: Duration::from_secs(60),
            header_extras_handler: None,
        }
    }

    /// Builds a provider through the public `JwksKeyProvider::new` constructor.
    fn test_provider(uri: &str) -> JwksKeyProvider {
        JwksKeyProvider::new(uri).expect("failed to create test provider")
    }

    /// JWK Set with one RSA signing key, `kid` = "test-key-1".
    ///
    /// Tests that serve it only present tokens with fake signatures; no private key for it
    /// exists in this module.
    fn valid_jwks_json() -> &'static str {
        r#"{
            "keys": [{
                "kty": "RSA",
                "kid": "test-key-1",
                "use": "sig",
                "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
                "e": "AQAB",
                "alg": "RS256"
            }]
        }"#
    }

    /// Backoff starts at 60 s, doubles per failure and is capped at `max_backoff`.
    #[tokio::test]
    async fn test_calculate_backoff() {
        let provider = test_provider("https://example.com/jwks");

        assert_eq!(provider.calculate_backoff(0), Duration::from_secs(60));
        assert_eq!(provider.calculate_backoff(1), Duration::from_secs(120));
        assert_eq!(provider.calculate_backoff(2), Duration::from_secs(240));
        assert_eq!(provider.calculate_backoff(3), Duration::from_secs(480));

        // Large failure counts saturate at the configured cap.
        assert_eq!(provider.calculate_backoff(100), provider.max_backoff);
    }

    /// A provider that has never refreshed reports that a refresh is due.
    #[tokio::test]
    async fn test_should_refresh_on_first_call() {
        let provider = test_provider("https://example.com/jwks");
        assert!(provider.should_refresh().await);
    }

    /// A key set published through `keys.store` becomes visible to `get_key`.
    #[tokio::test]
    async fn test_key_storage() {
        let provider = test_provider("https://example.com/jwks");

        assert!(provider.get_key("test-kid").is_none());

        // Any DecodingKey works as a placeholder; it is never used for verification here.
        let mut keys = HashMap::new();
        keys.insert("test-kid".to_owned(), DecodingKey::from_secret(b"secret"));
        provider.keys.store(Arc::new(keys));

        assert!(provider.get_key("test-kid").is_some());
    }

    /// On-demand refresh short-circuits with Ok when the kid is already cached.
    #[tokio::test]
    async fn test_on_demand_refresh_returns_ok_when_key_exists() {
        let provider = test_provider("https://example.com/jwks");

        // Placeholder key; only its presence under the kid matters.
        let mut keys = HashMap::new();
        keys.insert(
            "existing-kid".to_owned(),
            DecodingKey::from_secret(b"secret"),
        );
        provider.keys.store(Arc::new(keys));

        let result = provider.on_demand_refresh("existing-kid").await;
        assert!(result.is_ok());
    }

    /// `try_new` succeeds for a well-formed URI.
    #[tokio::test]
    async fn test_try_new_returns_result() {
        let result = JwksKeyProvider::try_new("https://example.com/jwks");
        assert!(result.is_ok());
    }

    /// A valid JWK Set is fetched once and its key is stored under its kid.
    #[tokio::test]
    async fn test_fetch_jwks_success_with_valid_json() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let result = provider.perform_refresh().await;
        assert!(result.is_ok(), "Expected success, got: {result:?}");

        assert!(
            provider.get_key("test-key-1").is_some(),
            "Expected key 'test-key-1' to be stored"
        );

        mock.assert();
    }

    /// An HTTP 404 is reported as a status error, not as a JSON parse failure.
    #[tokio::test]
    async fn test_fetch_jwks_http_404_error_mapping() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(404).body("Not Found");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let result = provider.perform_refresh().await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("JWKS HTTP 404"),
            "Expected error to contain 'JWKS HTTP 404', got: {err_msg}"
        );
        // The error body must not be mistaken for a JSON parse problem.
        assert!(
            !err_msg.to_lowercase().contains("parse"),
            "HTTP status error should not mention 'parse', got: {err_msg}"
        );

        mock.assert();
    }

    /// An HTTP 500 is reported with its status code.
    #[tokio::test]
    async fn test_fetch_jwks_http_500_error_mapping() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(500).body("Internal Server Error");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let result = provider.perform_refresh().await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("JWKS HTTP 500"),
            "Expected error to contain 'JWKS HTTP 500', got: {err_msg}"
        );

        mock.assert();
    }

    /// A 200 response with a non-JSON body is reported as a JSON parse failure.
    #[tokio::test]
    async fn test_fetch_jwks_invalid_json_error_mapping() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body("this is not valid json");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let result = provider.perform_refresh().await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("JWKS JSON parse failed"),
            "Expected error to contain 'JWKS JSON parse failed', got: {err_msg}"
        );

        mock.assert();
    }

    /// An empty key set is treated as a failed refresh.
    #[tokio::test]
    async fn test_fetch_jwks_empty_keys_error() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"keys": []}"#);
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let result = provider.perform_refresh().await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("No valid RSA keys"),
            "Expected error about no RSA keys, got: {err_msg}"
        );

        mock.assert();
    }

    /// A second on-demand refresh inside the cooldown is rejected locally with
    /// UnknownKeyId; the endpoint is contacted only once.
    #[tokio::test]
    async fn test_on_demand_refresh_respects_cooldown() {
        let server = MockServer::start();

        // Failing endpoint: the first attempt errors, which still starts the cooldown.
        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(404).body("Not Found");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url)
            .with_on_demand_refresh_cooldown(Duration::from_secs(60));

        // First call reaches the server and fails with the HTTP error.
        let result1 = provider.on_demand_refresh("test-kid").await;
        assert!(result1.is_err());

        // Second call is throttled before any request is made.
        let result2 = provider.on_demand_refresh("test-kid").await;
        assert!(result2.is_err());

        match result2.unwrap_err() {
            ClaimsError::UnknownKeyId(_) => {}
            other => panic!("Expected UnknownKeyId during cooldown, got: {other:?}"),
        }

        mock.assert_calls(1);
    }

    /// A kid whose on-demand refresh failed is recorded in `failed_kids`.
    #[tokio::test]
    async fn test_on_demand_refresh_tracks_failed_kids() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(404).body("Not Found");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url)
            .with_on_demand_refresh_cooldown(Duration::from_millis(100));

        let result = provider.on_demand_refresh("failed-kid").await;
        assert!(result.is_err());

        let state = provider.refresh_state.read().await;
        assert!(state.failed_kids.contains("failed-kid"));
    }

    /// A failed refresh increments `consecutive_failures` and keeps an error message.
    #[tokio::test]
    async fn test_perform_refresh_updates_state_on_failure() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(500).body("Server Error");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        // Seed earlier failures; the scope releases the write lock before refreshing.
        {
            let mut state = provider.refresh_state.write().await;
            state.consecutive_failures = 3;
            state.last_error = Some("Previous error".to_owned());
        }

        _ = provider.perform_refresh().await;

        let state = provider.refresh_state.read().await;
        assert_eq!(state.consecutive_failures, 4);
        assert!(state.last_error.is_some());
    }

    /// A successful refresh resets the failure counter and clears the last error.
    #[tokio::test]
    async fn test_perform_refresh_resets_state_on_success() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        // Seed earlier failures; the scope releases the write lock before refreshing.
        {
            let mut state = provider.refresh_state.write().await;
            state.consecutive_failures = 5;
            state.last_error = Some("Previous error".to_owned());
        }

        let result = provider.perform_refresh().await;
        assert!(result.is_ok());

        let state = provider.refresh_state.read().await;
        assert_eq!(state.consecutive_failures, 0);
        assert!(state.last_error.is_none());
    }

    /// A token whose kid is absent from the JWK Set, even after a refresh, yields
    /// UnknownKeyId carrying that kid.
    #[tokio::test]
    async fn test_validate_and_decode_with_missing_kid() {
        let server = MockServer::start();

        // The served set only contains "test-key-1".
        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url)
            .with_on_demand_refresh_cooldown(Duration::from_millis(100));

        // Header {"alg":"RS256","kid":"nonexistent-kid"}, payload {"sub":"1234567890"},
        // placeholder signature.
        let token = "eyJhbGciOiJSUzI1NiIsImtpZCI6Im5vbmV4aXN0ZW50LWtpZCJ9.\
                     eyJzdWIiOiIxMjM0NTY3ODkwIn0.invalid";

        let result = provider.validate_and_decode(token).await;
        assert!(result.is_err());

        match result.unwrap_err() {
            ClaimsError::UnknownKeyId(kid) => {
                assert_eq!(kid, "nonexistent-kid");
            }
            other => panic!("Expected UnknownKeyId, got: {other:?}"),
        }
    }

    /// A stringifying handler keeps standard fields and string extras as-is and converts
    /// other extras to their JSON text.
    #[test]
    fn test_decode_header_with_handler_coerces_non_string_extras() {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};

        let header_json = r#"{"alg":"RS256","eap":1,"iri":"some-string-id","irn":["role_a"],"kid":"kid-1","typ":"at+jwt"}"#;
        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
        let payload_b64 = URL_SAFE_NO_PAD.encode(b"{}");
        let token = format!("{header_b64}.{payload_b64}.fake");

        let header = decode_header_with_handler(&token, &|_key, value| Some(value.to_string()))
            .expect("should handle non-standard header fields");

        assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
        assert_eq!(header.kid.as_deref(), Some("kid-1"));
        assert_eq!(header.typ.as_deref(), Some("at+jwt"));

        // Number and array extras are stored as their JSON text.
        assert_eq!(header.extras.get("eap").map(String::as_str), Some("1"));
        assert_eq!(
            header.extras.get("irn").map(String::as_str),
            Some(r#"["role_a"]"#)
        );
        // String extras are never passed to the handler, so no extra quoting is added.
        assert_eq!(
            header.extras.get("iri").map(String::as_str),
            Some("some-string-id")
        );
    }

    /// A handler returning None drops non-string extras but keeps string extras.
    #[test]
    fn test_decode_header_with_handler_can_drop_fields() {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};

        let header_json = r#"{"alg":"RS256","eap":1,"iri":"keep-me","kid":"kid-1","typ":"JWT"}"#;
        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
        let token = format!("{header_b64}.e30.fake");

        let header = decode_header_with_handler(&token, &|_key, _value| None)
            .expect("should succeed when handler drops non-string fields");

        assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
        assert!(!header.extras.contains_key("eap"));
        assert_eq!(
            header.extras.get("iri").map(String::as_str),
            Some("keep-me")
        );
    }

    /// With `with_header_extras_stringified`, a header with non-string extras gets past
    /// header decoding and fails only at signature verification.
    #[tokio::test]
    async fn test_with_header_extras_stringified_coerces_non_string_extras() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url).with_header_extras_stringified();

        let header_json =
            r#"{"alg":"RS256","kid":"test-key-1","typ":"JWT","eap":1,"irn":["role_a"]}"#;
        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
        let payload_b64 = URL_SAFE_NO_PAD.encode(b"{}");
        let token = format!("{header_b64}.{payload_b64}.AAAA");

        let result = provider.validate_and_decode(&token).await;

        // Either outcome proves the header was accepted: the signature check ran and failed.
        let err = result.expect_err("fake signature should fail validation");
        assert!(
            matches!(
                &err,
                ClaimsError::InvalidSignature | ClaimsError::DecodeFailed(_)
            ),
            "Expected signature-related error, got: {err:?}"
        );
    }

    /// A custom handler installed via `with_header_extras_handler` is used by
    /// `validate_and_decode`.
    #[tokio::test]
    async fn test_validate_and_decode_uses_header_extras_handler() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url)
            .with_header_extras_handler(|_key, value| Some(value.to_string()));

        let header_json = r#"{"alg":"RS256","kid":"test-key-1","typ":"JWT","eap":1}"#;
        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
        let payload_b64 = URL_SAFE_NO_PAD.encode(b"{}");
        let token = format!("{header_b64}.{payload_b64}.AAAA");

        let result = provider.validate_and_decode(&token).await;

        // Either outcome proves the header was accepted: the signature check ran and failed.
        let err = result.expect_err("fake signature should fail validation");
        assert!(
            matches!(
                &err,
                ClaimsError::InvalidSignature | ClaimsError::DecodeFailed(_)
            ),
            "Expected signature-related error, got: {err:?}"
        );
    }

    /// PKCS#8 RSA private key used only to sign test tokens. Its public half is served
    /// by `signed_jwks_json` under kid "sign-key-1".
    const TEST_RSA_PRIVATE_PEM: &[u8] = b"-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCohcw9B9YK7ULF
KgrGNJKAH0BH9CpJB03wIkQl6ECCJ/BfmBsNSWwZdnG0cWwwGhsSSSj32AKB+t6W
44/vi9hv+PHusIRCMNqM/AJ/zA7xau9mNsxS8U8J3olm74vLFtF05hTRmJuefMmz
mOt4kMP44UeVg0nyFlToa0SmhMxIeFgz2VgktHjHDe/rr/FdrjMwxesz3ezj+Y4k
YPPrQfMZJTyEd68M+pPkjyg6AkakNSUJp+dZibnRLKcj6Ehz1W3lSGkaQ4YFSXVX
UCaHWNmPsJHejwKrUA/fbkYi3sLO7cW/4h+b2laWsL9qC4P2RJMbZBzklJoL+WoH
Lo5zUvo7AgMBAAECggEACrynlBXdOcn/EI/KqvErilUzY8I3NXrtKMkOHXosLf68
bmLDCngslny45t25HmFzaxlVLmFJW52vs95gy8rVqeCrDWGas5roOcZOpHTMWO5O
vWztXLV6Ky9OAsxtVC2qf6+vEOGPvKvHsBUkn4RdsAwuYuS//9gTZdF7yL46Q72o
pJ8bLUZBpqmVNyLxyfbFn8u9j71zMUweB9vOMYAIAv1cYRa/0bVYLIZumcotY822
B0ny1fLru1gDJt2p1DL9fQTg16pBYr1V0nhoiktS8Lx5PFLMI+NhmalBerqtPN+u
qqauu9jolmXtydfOP7pTN2sqGFAKlcx55KZlVLK2YQKBgQDaiRxPXnFCPY4yYBxS
POFJe8UcvoM3d5HGwQfbJ5PHq+YN8NW0ACaox6QQkQYmE9OHriHrVmp4af6erN2K
zbjmL41E5C4MzEau2ipZWY4GA+lLXomEiHsUD0cfqfL+7Fs6ufiG2nXrWIBXggz8
8mTdP/LHMPybY0wxoZI5Xij+2wKBgQDFacPh+PhT0U8wu7nSgvQ85ozJN7TWq0KD
TgWuZ0W6L5OlAAVernYuvvRH/Uy9JqVfX4KLHbcEcdUx8t5usKMf8S3kQyMM8xK+
KaEYZNOMdA6E9PAJVD8crDQT/QD6/+oHrTTFFKxW7jWLY1ggWXVHk4CxLXBlDnKQ
xIA5DuhgIQKBgQCA5Km77loi1aeO8r0BjELcUpH52CwQhQeIEMYPbpJtDGhOBKQm
3IfwuH99/euAfeUfe4cqBPgbOXkiIZcxjRDnQ1ixL1wx1DJEYwzjUjzAM4JgH8xA
TTc6p6AtftGBpepRAusgrq0qODLKajw63MS88kDBV5VGGRURmNhj2bOYTQKBgHPr
hiVj/9Wf+6M/KH9vfCFis9rYBi1jxRu7LeTaKXyJwWXLHFwbj7QlVuYK3AvZ7JOT
TuGHoldOzISW+3v95tuz0GHP9n39Ic1ePoVHd11rLLdv6J9hw+l/SNlP4EqDCZZW
Y70yRXyKRhDCVhYw0YglGhVv/CarFCTj7fMTSOphAoGBAJcM4H4qmCFLdR9FRQgT
YJPGcyjWPmm9tlb8M6rSJGPlfpAhKjRVGWwpHPiUnvrW296QKr9+5q43HRcK3qa5
GU5n8VxYiniVFVMSEpLJgvu7hGq5fmMiRTTot1pOTSXZ1LY6rDQvjsTeGQumb/Eo
F8gvjIeiwVfp4nDnO2JFexiy
-----END PRIVATE KEY-----";

    /// JWK Set containing the public half of `TEST_RSA_PRIVATE_PEM` under kid "sign-key-1".
    fn signed_jwks_json() -> &'static str {
        r#"{
            "keys": [{
                "kty": "RSA",
                "kid": "sign-key-1",
                "use": "sig",
                "n": "qIXMPQfWCu1CxSoKxjSSgB9AR_QqSQdN8CJEJehAgifwX5gbDUlsGXZxtHFsMBobEkko99gCgfreluOP74vYb_jx7rCEQjDajPwCf8wO8WrvZjbMUvFPCd6JZu-LyxbRdOYU0ZibnnzJs5jreJDD-OFHlYNJ8hZU6GtEpoTMSHhYM9lYJLR4xw3v66_xXa4zMMXrM93s4_mOJGDz60HzGSU8hHevDPqT5I8oOgJGpDUlCafnWYm50SynI-hIc9Vt5UhpGkOGBUl1V1Amh1jZj7CR3o8Cq1AP325GIt7Czu3Fv-Ifm9pWlrC_aguD9kSTG2Qc5JSaC_lqBy6Oc1L6Ow",
                "e": "AQAB",
                "alg": "RS256"
            }]
        }"#
    }

    /// Signs `claims` as an RS256 JWT with `TEST_RSA_PRIVATE_PEM` and the given kid.
    fn build_signed_jwt(kid: &str, claims: &serde_json::Value) -> String {
        let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(TEST_RSA_PRIVATE_PEM)
            .expect("test RSA PEM should be valid");
        let mut header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256);
        header.kid = Some(kid.to_owned());
        jsonwebtoken::encode(&header, claims, &encoding_key).expect("JWT signing should succeed")
    }

    /// A correctly signed token verifies and its header and claims are returned.
    #[tokio::test]
    async fn test_validate_and_decode_happy_path() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(signed_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let claims = serde_json::json!({
            "sub": "user-42",
            "name": "Test User",
            "iat": 1_700_000_000u64
        });
        let token = build_signed_jwt("sign-key-1", &claims);

        let (header, decoded_claims) = provider
            .validate_and_decode(&token)
            .await
            .expect("validate_and_decode should succeed for a properly signed token");

        assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
        assert_eq!(header.kid.as_deref(), Some("sign-key-1"));
        assert_eq!(decoded_claims["sub"], "user-42");
        assert_eq!(decoded_claims["name"], "Test User");
    }

    /// A "Bearer " prefix is stripped before the token is validated.
    #[tokio::test]
    async fn test_validate_and_decode_with_bearer_prefix() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(signed_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let claims = serde_json::json!({"sub": "user-99"});
        let token = format!("Bearer {}", build_signed_jwt("sign-key-1", &claims));

        let (_, decoded_claims) = provider
            .validate_and_decode(&token)
            .await
            .expect("should strip Bearer prefix and succeed");

        assert_eq!(decoded_claims["sub"], "user-99");
    }

    /// Swapping the payload of a signed token makes verification fail with
    /// InvalidSignature.
    #[tokio::test]
    async fn test_validate_and_decode_rejects_tampered_payload() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(signed_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let claims = serde_json::json!({"sub": "legit"});
        let token = build_signed_jwt("sign-key-1", &claims);

        // Keep the original header and signature, replace only the payload.
        let parts: Vec<&str> = token.splitn(3, '.').collect();
        let tampered_payload = URL_SAFE_NO_PAD.encode(br#"{"sub":"evil"}"#);
        let tampered_token = format!("{}.{}.{}", parts[0], tampered_payload, parts[2]);

        let err = provider
            .validate_and_decode(&tampered_token)
            .await
            .expect_err("tampered token should fail signature verification");

        assert!(
            matches!(err, ClaimsError::InvalidSignature),
            "Expected InvalidSignature, got: {err:?}"
        );
    }

    /// Signs `claims` with `TEST_RSA_PRIVATE_PEM` (RS256) under an arbitrary raw header
    /// JSON. This allows header fields that `jsonwebtoken::Header` cannot serialize itself.
    fn build_signed_jwt_custom_header(header_json: &str, claims: &serde_json::Value) -> String {
        let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(TEST_RSA_PRIVATE_PEM)
            .expect("test RSA PEM should be valid");
        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
        let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(claims).unwrap());
        let message = format!("{header_b64}.{payload_b64}");
        let signature = jsonwebtoken::crypto::sign(
            message.as_bytes(),
            &encoding_key,
            jsonwebtoken::Algorithm::RS256,
        )
        .expect("signing should succeed");
        format!("{message}.{signature}")
    }

    /// End to end: a genuinely signed token with a numeric header extra validates when the
    /// stringifying handler is installed, and the extra is exposed as text.
    #[tokio::test]
    async fn test_validate_and_decode_with_non_string_header_extras() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(signed_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url).with_header_extras_stringified();

        let claims = serde_json::json!({"sub": "user-extras"});
        let header_json = r#"{"alg":"RS256","kid":"sign-key-1","typ":"JWT","eap":1}"#;
        let token = build_signed_jwt_custom_header(header_json, &claims);

        let (header, decoded_claims) = provider
            .validate_and_decode(&token)
            .await
            .expect("should decode JWT with non-string header extras when handler is set");

        assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
        assert_eq!(header.kid.as_deref(), Some("sign-key-1"));
        assert_eq!(header.extras.get("eap").map(String::as_str), Some("1"));
        assert_eq!(decoded_claims["sub"], "user-extras");
    }

    /// Baseline: jsonwebtoken's stock `decode_header` rejects a numeric extra. This is the
    /// failure the handler mechanism exists to avoid.
    #[test]
    fn test_decode_header_without_handler_rejects_non_string_extras() {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};

        let header_json = r#"{"alg":"RS256","eap":1,"kid":"kid-1","typ":"JWT"}"#;
        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
        let token = format!("{header_b64}.e30.fake");

        let result = decode_header(&token);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("invalid type: integer"),
            "expected type error, got: {err}"
        );
    }
}
1225}