1use crate::{claims_error::ClaimsError, traits::KeyProvider};
2use arc_swap::ArcSwap;
3use async_trait::async_trait;
4use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
5use jsonwebtoken::{DecodingKey, Header, Validation, decode, decode_header};
6use serde::Deserialize;
7use serde_json::Value;
8use std::collections::{HashMap, HashSet};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::RwLock;
12use tokio::time::Instant;
13use tokio_util::sync::CancellationToken;
14
15#[derive(Debug, Clone, Deserialize)]
16struct Jwk {
17 kid: String,
18 kty: String,
19 #[serde(rename = "use")]
20 #[allow(dead_code)]
21 use_: Option<String>,
22 n: String,
23 e: String,
24 #[allow(dead_code)]
25 alg: Option<String>,
26}
27
28#[derive(Debug, Clone, Deserialize)]
29struct JwksResponse {
30 keys: Vec<Jwk>,
31}
32
33type HeaderExtrasHandler = dyn Fn(&str, &Value) -> Option<String> + Send + Sync;
35
/// Registered JOSE header parameter names.
///
/// Members with these names are passed through untouched by `decode_header_with_handler`,
/// even when their values are not strings; only other non-string members reach the
/// extras handler.
const STANDARD_HEADER_FIELDS: &[&str] = &[
    // JWS header parameters
    "typ", "alg", "cty", "jku", "jwk", "kid", "x5u", "x5c", "x5t", "x5t#S256", "crit",
    // JWE header parameters
    "enc", "zip",
    // ACME header parameters
    "url", "nonce",
    // JWA key-management parameters (ECDH-ES, AES-GCM key wrap, PBES2)
    "epk", "apu", "apv", "iv", "tag", "p2s", "p2c",
    // Unencoded-payload option
    "b64",
];
42
/// [`KeyProvider`] that validates JWTs against RSA keys fetched from a JWKS endpoint.
///
/// The key set is cached in memory, indexed by `kid`. It is refreshed periodically (see
/// `run_jwks_refresh_task`) and on demand when a token names an unknown `kid`.
#[must_use]
pub struct JwksKeyProvider {
    /// URL the JWKS document is fetched from.
    jwks_uri: String,

    /// Current key set indexed by `kid`; replaced wholesale on each successful refresh.
    keys: Arc<ArcSwap<HashMap<String, DecodingKey>>>,

    /// Refresh bookkeeping (timestamps, failure count, unknown kids).
    refresh_state: Arc<RwLock<RefreshState>>,

    /// HTTP client used to fetch the JWKS document.
    client: modkit_http::HttpClient,

    /// Time after a successful refresh before the key set is considered stale.
    refresh_interval: Duration,

    /// Upper bound on the retry delay after consecutive refresh failures.
    max_backoff: Duration,

    /// Minimum spacing between on-demand refreshes triggered by unknown `kid`s.
    on_demand_refresh_cooldown: Duration,

    /// Optional callback for non-standard, non-string JWT header members.
    header_extras_handler: Option<Arc<HeaderExtrasHandler>>,
}
75
/// Mutable refresh bookkeeping shared by the background and on-demand refresh paths.
#[derive(Debug, Default)]
struct RefreshState {
    /// When the most recent refresh attempt (successful or not) finished.
    last_refresh: Option<Instant>,
    /// Time of the most recent on-demand (unknown-`kid`) refresh; drives the cooldown.
    last_on_demand_refresh: Option<Instant>,
    /// Refresh failures since the last success; drives the exponential backoff.
    consecutive_failures: u32,
    /// Display text of the most recent refresh error; cleared on success.
    last_error: Option<String>,
    /// `kid`s that were still unknown after an on-demand refresh, or whose refresh failed.
    failed_kids: HashSet<String>,
}
84
85impl JwksKeyProvider {
86 pub fn new(jwks_uri: impl Into<String>) -> Result<Self, modkit_http::HttpError> {
91 Self::with_http_timeout(jwks_uri, Duration::from_secs(10))
92 }
93
94 pub fn with_http_timeout(
99 jwks_uri: impl Into<String>,
100 timeout: Duration,
101 ) -> Result<Self, modkit_http::HttpError> {
102 let client = modkit_http::HttpClient::builder()
103 .timeout(timeout)
104 .retry(None) .build()?;
106
107 Ok(Self {
108 jwks_uri: jwks_uri.into(),
109 keys: Arc::new(ArcSwap::from_pointee(HashMap::new())),
110 refresh_state: Arc::new(RwLock::new(RefreshState::default())),
111 client,
112 refresh_interval: Duration::from_secs(300), max_backoff: Duration::from_secs(3600), on_demand_refresh_cooldown: Duration::from_secs(60), header_extras_handler: None,
116 })
117 }
118
119 pub fn try_new(jwks_uri: impl Into<String>) -> Result<Self, modkit_http::HttpError> {
124 Self::new(jwks_uri)
125 }
126
127 pub fn with_refresh_interval(mut self, interval: Duration) -> Self {
129 self.refresh_interval = interval;
130 self
131 }
132
133 pub fn with_max_backoff(mut self, max_backoff: Duration) -> Self {
135 self.max_backoff = max_backoff;
136 self
137 }
138
139 pub fn with_on_demand_refresh_cooldown(mut self, cooldown: Duration) -> Self {
141 self.on_demand_refresh_cooldown = cooldown;
142 self
143 }
144
145 pub fn with_header_extras_stringified(self) -> Self {
151 self.with_header_extras_handler(|_, v| Some(v.to_string()))
152 }
153
154 pub fn with_header_extras_handler(
161 mut self,
162 handler: impl Fn(&str, &Value) -> Option<String> + Send + Sync + 'static,
163 ) -> Self {
164 self.header_extras_handler = Some(Arc::new(handler));
165 self
166 }
167
168 async fn fetch_jwks(&self) -> Result<HashMap<String, DecodingKey>, ClaimsError> {
170 let jwks: JwksResponse = self
172 .client
173 .get(&self.jwks_uri)
174 .send()
175 .await
176 .map_err(|e| map_http_error(&e))?
177 .json()
178 .await
179 .map_err(|e| map_http_error(&e))?;
180
181 let mut keys = HashMap::new();
182 for jwk in jwks.keys {
183 if jwk.kty == "RSA" {
184 let key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)
185 .map_err(|e| ClaimsError::JwksFetchFailed(format!("Invalid RSA key: {e}")))?;
186 keys.insert(jwk.kid, key);
187 }
188 }
189
190 if keys.is_empty() {
191 return Err(ClaimsError::JwksFetchFailed(
192 "No valid RSA keys found in JWKS".into(),
193 ));
194 }
195
196 Ok(keys)
197 }
198
199 fn calculate_backoff(&self, failures: u32) -> Duration {
201 let base = Duration::from_secs(60); let exponential = base * 2u32.pow(failures.min(10)); exponential.min(self.max_backoff)
204 }
205
206 async fn should_refresh(&self) -> bool {
208 let state = self.refresh_state.read().await;
209
210 match state.last_refresh {
211 None => true, Some(last) => {
213 let elapsed = last.elapsed();
214 if state.consecutive_failures == 0 {
215 elapsed >= self.refresh_interval
217 } else {
218 elapsed >= self.calculate_backoff(state.consecutive_failures)
220 }
221 }
222 }
223 }
224
225 async fn perform_refresh(&self) -> Result<(), ClaimsError> {
227 match self.fetch_jwks().await {
228 Ok(new_keys) => {
229 self.keys.store(Arc::new(new_keys));
231
232 let mut state = self.refresh_state.write().await;
234 state.last_refresh = Some(Instant::now());
235 state.consecutive_failures = 0;
236 state.last_error = None;
237
238 Ok(())
239 }
240 Err(e) => {
241 let mut state = self.refresh_state.write().await;
243 state.last_refresh = Some(Instant::now());
244 state.consecutive_failures += 1;
245 state.last_error = Some(e.to_string());
246
247 Err(e)
248 }
249 }
250 }
251
252 fn key_exists(&self, kid: &str) -> bool {
254 let keys = self.keys.load();
255 keys.contains_key(kid)
256 }
257
258 async fn check_refresh_throttle(&self, kid: &str) -> Result<(), ClaimsError> {
260 let state = self.refresh_state.read().await;
261 if let Some(last_on_demand) = state.last_on_demand_refresh {
262 let elapsed = last_on_demand.elapsed();
263 if elapsed < self.on_demand_refresh_cooldown {
264 let remaining = self.on_demand_refresh_cooldown.saturating_sub(elapsed);
265 tracing::debug!(
266 kid = kid,
267 remaining_secs = remaining.as_secs(),
268 "On-demand JWKS refresh throttled (cooldown active)"
269 );
270
271 if state.failed_kids.contains(kid) {
273 tracing::warn!(
274 kid = kid,
275 "Unknown kid repeatedly requested despite recent refresh attempts"
276 );
277 }
278
279 return Err(ClaimsError::UnknownKeyId(kid.to_owned()));
280 }
281 }
282 Ok(())
283 }
284
285 async fn handle_refresh_success(&self, kid: &str) -> Result<(), ClaimsError> {
287 let mut state = self.refresh_state.write().await;
288 state.last_on_demand_refresh = Some(Instant::now());
289
290 if self.key_exists(kid) {
292 state.failed_kids.remove(kid);
294 } else {
295 state.failed_kids.insert(kid.to_owned());
297 tracing::warn!(
298 kid = kid,
299 "Kid still not found after on-demand JWKS refresh"
300 );
301 }
302
303 Ok(())
304 }
305
306 async fn handle_refresh_failure(&self, kid: &str, error: ClaimsError) -> ClaimsError {
308 let mut state = self.refresh_state.write().await;
309 state.last_on_demand_refresh = Some(Instant::now());
310 state.failed_kids.insert(kid.to_owned());
311 error
312 }
313
314 async fn on_demand_refresh(&self, kid: &str) -> Result<(), ClaimsError> {
317 if self.key_exists(kid) {
319 return Ok(());
320 }
321
322 self.check_refresh_throttle(kid).await?;
324
325 tracing::info!(
327 kid = kid,
328 "Performing on-demand JWKS refresh for unknown kid"
329 );
330
331 match self.perform_refresh().await {
332 Ok(()) => self.handle_refresh_success(kid).await,
333 Err(e) => Err(self.handle_refresh_failure(kid, e).await),
334 }
335 }
336
337 fn get_key(&self, kid: &str) -> Option<DecodingKey> {
339 let keys = self.keys.load();
340 keys.get(kid).cloned()
341 }
342
343 fn validate_token(
345 token: &str,
346 key: &DecodingKey,
347 header: &Header,
348 ) -> Result<Value, ClaimsError> {
349 let mut validation = Validation::new(header.alg);
350
351 validation.validate_exp = false;
353 validation.validate_nbf = false;
354 validation.validate_aud = false;
355
356 let empty_claims: &[&str] = &[];
358 validation.set_required_spec_claims(empty_claims);
359
360 let token_data = decode::<Value>(token, key, &validation)
361 .map_err(|e| ClaimsError::DecodeFailed(format!("JWT validation failed: {e}")))?;
362
363 Ok(token_data.claims)
364 }
365}
366
367#[async_trait]
368impl KeyProvider for JwksKeyProvider {
369 fn name(&self) -> &'static str {
370 "jwks"
371 }
372
373 async fn validate_and_decode(&self, token: &str) -> Result<(Header, Value), ClaimsError> {
374 let token = token.trim_start_matches("Bearer ").trim();
376
377 let header = match &self.header_extras_handler {
379 Some(handler) => decode_header_with_handler(token, handler.as_ref()),
380 None => decode_header(token),
381 }
382 .map_err(|e| ClaimsError::DecodeFailed(format!("Invalid JWT header: {e}")))?;
383
384 let kid = header
385 .kid
386 .as_ref()
387 .ok_or_else(|| ClaimsError::DecodeFailed("Missing kid in JWT header".into()))?;
388
389 let key = if let Some(k) = self.get_key(kid) {
391 k
392 } else {
393 self.on_demand_refresh(kid).await?;
395
396 self.get_key(kid)
398 .ok_or_else(|| ClaimsError::UnknownKeyId(kid.clone()))?
399 };
400
401 let claims = Self::validate_token(token, &key, &header)?;
403
404 Ok((header, claims))
405 }
406
407 async fn refresh_keys(&self) -> Result<(), ClaimsError> {
408 if self.should_refresh().await {
409 self.perform_refresh().await
410 } else {
411 Ok(())
412 }
413 }
414}
415
416pub async fn run_jwks_refresh_task(
439 provider: Arc<JwksKeyProvider>,
440 cancellation_token: CancellationToken,
441) {
442 let mut interval = tokio::time::interval(Duration::from_secs(60)); loop {
445 tokio::select! {
446 () = cancellation_token.cancelled() => {
447 tracing::info!("JWKS refresh task shutting down");
448 break;
449 }
450 _ = interval.tick() => {
451 if let Err(e) = provider.refresh_keys().await {
452 tracing::warn!("JWKS refresh failed: {}", e);
453 }
454 }
455 }
456 }
457}
458
459fn decode_header_with_handler(
463 token: &str,
464 handler: &dyn Fn(&str, &Value) -> Option<String>,
465) -> Result<Header, jsonwebtoken::errors::Error> {
466 let header_b64 = token
467 .split('.')
468 .next()
469 .ok_or(jsonwebtoken::errors::ErrorKind::InvalidToken)?;
470
471 let header_bytes = URL_SAFE_NO_PAD
472 .decode(header_b64.trim_end_matches('='))
473 .map_err(jsonwebtoken::errors::ErrorKind::Base64)?;
474
475 let mut json: serde_json::Map<String, Value> = serde_json::from_slice(&header_bytes)?;
476
477 json.retain(|key, value| {
478 if STANDARD_HEADER_FIELDS.contains(&key.as_str()) || value.is_string() {
479 return true;
480 }
481 match handler(key, value) {
482 Some(s) => {
483 *value = Value::String(s);
484 true
485 }
486 None => false,
487 }
488 });
489
490 Ok(serde_json::from_value(Value::Object(json))?)
491}
492
493fn map_http_error(e: &modkit_http::HttpError) -> ClaimsError {
495 ClaimsError::JwksFetchFailed(crate::http_error::format_http_error(e, "JWKS"))
496}
497
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use httpmock::prelude::*;

    /// Builds a provider around `uri` with the same defaults as `with_http_timeout`.
    ///
    /// Unlike `JwksKeyProvider::new`, its client calls `allow_insecure_http()`, so it can reach
    /// the plain-HTTP `httpmock` server.
    fn test_provider_with_http(uri: &str) -> JwksKeyProvider {
        let client = modkit_http::HttpClient::builder()
            .timeout(Duration::from_secs(5))
            .retry(None)
            .allow_insecure_http()
            .build()
            .expect("failed to create test HTTP client");

        JwksKeyProvider {
            jwks_uri: uri.to_owned(),
            keys: Arc::new(ArcSwap::from_pointee(HashMap::new())),
            refresh_state: Arc::new(RwLock::new(RefreshState::default())),
            client,
            refresh_interval: Duration::from_secs(300),
            max_backoff: Duration::from_secs(3600),
            on_demand_refresh_cooldown: Duration::from_secs(60),
            header_extras_handler: None,
        }
    }

    /// Builds a provider through the public constructor (no network access needed).
    fn test_provider(uri: &str) -> JwksKeyProvider {
        JwksKeyProvider::new(uri).expect("failed to create test provider")
    }

    /// JWKS document with a single RSA signing key whose `kid` is `test-key-1`.
    fn valid_jwks_json() -> &'static str {
        r#"{
            "keys": [{
                "kty": "RSA",
                "kid": "test-key-1",
                "use": "sig",
                "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
                "e": "AQAB",
                "alg": "RS256"
            }]
        }"#
    }

    #[tokio::test]
    async fn test_calculate_backoff() {
        let provider = test_provider("https://example.com/jwks");

        // Doubles from a 60s base ...
        assert_eq!(provider.calculate_backoff(0), Duration::from_secs(60));
        assert_eq!(provider.calculate_backoff(1), Duration::from_secs(120));
        assert_eq!(provider.calculate_backoff(2), Duration::from_secs(240));
        assert_eq!(provider.calculate_backoff(3), Duration::from_secs(480));

        // ... and is capped at max_backoff.
        assert_eq!(provider.calculate_backoff(100), provider.max_backoff);
    }

    #[tokio::test]
    async fn test_should_refresh_on_first_call() {
        let provider = test_provider("https://example.com/jwks");
        assert!(provider.should_refresh().await);
    }

    #[tokio::test]
    async fn test_key_storage() {
        let provider = test_provider("https://example.com/jwks");

        // Empty cache initially.
        assert!(provider.get_key("test-kid").is_none());

        // Store a key set directly; the key type is irrelevant for lookup.
        let mut keys = HashMap::new();
        keys.insert("test-kid".to_owned(), DecodingKey::from_secret(b"secret"));
        provider.keys.store(Arc::new(keys));

        assert!(provider.get_key("test-kid").is_some());
    }

    #[tokio::test]
    async fn test_on_demand_refresh_returns_ok_when_key_exists() {
        let provider = test_provider("https://example.com/jwks");

        // Pre-populate the cache so no HTTP request is needed.
        let mut keys = HashMap::new();
        keys.insert(
            "existing-kid".to_owned(),
            DecodingKey::from_secret(b"secret"),
        );
        provider.keys.store(Arc::new(keys));

        // A known kid short-circuits before any network access.
        let result = provider.on_demand_refresh("existing-kid").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_try_new_returns_result() {
        let result = JwksKeyProvider::try_new("https://example.com/jwks");
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_fetch_jwks_success_with_valid_json() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let result = provider.perform_refresh().await;
        assert!(result.is_ok(), "Expected success, got: {result:?}");

        // The RSA key from the document is now cached under its kid.
        assert!(
            provider.get_key("test-key-1").is_some(),
            "Expected key 'test-key-1' to be stored"
        );

        mock.assert();
    }

    #[tokio::test]
    async fn test_fetch_jwks_http_404_error_mapping() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(404).body("Not Found");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let result = provider.perform_refresh().await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("JWKS HTTP 404"),
            "Expected error to contain 'JWKS HTTP 404', got: {err_msg}"
        );
        // An HTTP status failure must not be reported as a body parse failure.
        assert!(
            !err_msg.to_lowercase().contains("parse"),
            "HTTP status error should not mention 'parse', got: {err_msg}"
        );

        mock.assert();
    }

    #[tokio::test]
    async fn test_fetch_jwks_http_500_error_mapping() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(500).body("Internal Server Error");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let result = provider.perform_refresh().await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("JWKS HTTP 500"),
            "Expected error to contain 'JWKS HTTP 500', got: {err_msg}"
        );

        mock.assert();
    }

    #[tokio::test]
    async fn test_fetch_jwks_invalid_json_error_mapping() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body("this is not valid json");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let result = provider.perform_refresh().await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("JWKS JSON parse failed"),
            "Expected error to contain 'JWKS JSON parse failed', got: {err_msg}"
        );

        mock.assert();
    }

    #[tokio::test]
    async fn test_fetch_jwks_empty_keys_error() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"keys": []}"#);
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        let result = provider.perform_refresh().await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("No valid RSA keys"),
            "Expected error about no RSA keys, got: {err_msg}"
        );

        mock.assert();
    }

    #[tokio::test]
    async fn test_on_demand_refresh_respects_cooldown() {
        let server = MockServer::start();

        // Every fetch fails, so the kid never becomes known.
        let mock = server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(404).body("Not Found");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url)
            .with_on_demand_refresh_cooldown(Duration::from_secs(60));

        // First call performs the (failing) fetch.
        let result1 = provider.on_demand_refresh("test-kid").await;
        assert!(result1.is_err());

        // Second call falls inside the cooldown window and must not fetch again.
        let result2 = provider.on_demand_refresh("test-kid").await;
        assert!(result2.is_err());

        match result2.unwrap_err() {
            ClaimsError::UnknownKeyId(_) => {}
            other => panic!("Expected UnknownKeyId during cooldown, got: {other:?}"),
        }

        // Exactly one HTTP request across both calls.
        mock.assert_calls(1);
    }

    #[tokio::test]
    async fn test_on_demand_refresh_tracks_failed_kids() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(404).body("Not Found");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url)
            .with_on_demand_refresh_cooldown(Duration::from_millis(100));

        let result = provider.on_demand_refresh("failed-kid").await;
        assert!(result.is_err());

        // A failed on-demand refresh records the kid.
        let state = provider.refresh_state.read().await;
        assert!(state.failed_kids.contains("failed-kid"));
    }

    #[tokio::test]
    async fn test_perform_refresh_updates_state_on_failure() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(500).body("Server Error");
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        {
            // Simulate three earlier failures; the guard is dropped at the end of this scope.
            let mut state = provider.refresh_state.write().await;
            state.consecutive_failures = 3;
            state.last_error = Some("Previous error".to_owned());
        }

        _ = provider.perform_refresh().await;

        // The failure counter is incremented and an error is recorded.
        let state = provider.refresh_state.read().await;
        assert_eq!(state.consecutive_failures, 4);
        assert!(state.last_error.is_some());
    }

    #[tokio::test]
    async fn test_perform_refresh_resets_state_on_success() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url);

        {
            // Simulate five earlier failures; the guard is dropped at the end of this scope.
            let mut state = provider.refresh_state.write().await;
            state.consecutive_failures = 5;
            state.last_error = Some("Previous error".to_owned());
        }

        let result = provider.perform_refresh().await;
        assert!(result.is_ok());

        // A successful refresh clears the failure bookkeeping.
        let state = provider.refresh_state.read().await;
        assert_eq!(state.consecutive_failures, 0);
        assert!(state.last_error.is_none());
    }

    // Despite its name, this test exercises an *unknown* kid. The token header decodes to
    // {"alg":"RS256","kid":"nonexistent-kid"}, and the served JWKS only contains `test-key-1`.
    #[tokio::test]
    async fn test_validate_and_decode_with_missing_kid() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url)
            .with_on_demand_refresh_cooldown(Duration::from_millis(100));

        let token = "eyJhbGciOiJSUzI1NiIsImtpZCI6Im5vbmV4aXN0ZW50LWtpZCJ9.\
            eyJzdWIiOiIxMjM0NTY3ODkwIn0.invalid";

        let result = provider.validate_and_decode(token).await;
        assert!(result.is_err());

        match result.unwrap_err() {
            ClaimsError::UnknownKeyId(kid) => {
                assert_eq!(kid, "nonexistent-kid");
            }
            other => panic!("Expected UnknownKeyId, got: {other:?}"),
        }
    }

    #[test]
    fn test_decode_header_with_handler_coerces_non_string_extras() {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};

        // `eap` (number) and `irn` (array) are non-standard and non-string.
        let header_json = r#"{"alg":"RS256","eap":1,"iri":"some-string-id","irn":["role_a"],"kid":"kid-1","typ":"at+jwt"}"#;
        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
        let payload_b64 = URL_SAFE_NO_PAD.encode(b"{}");
        let token = format!("{header_b64}.{payload_b64}.fake");

        let header = decode_header_with_handler(&token, &|_key, value| Some(value.to_string()))
            .expect("should handle non-standard header fields");

        assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
        assert_eq!(header.kid.as_deref(), Some("kid-1"));
        assert_eq!(header.typ.as_deref(), Some("at+jwt"));

        // Non-string extras are stored as their JSON text.
        assert_eq!(header.extras.get("eap").map(String::as_str), Some("1"));
        assert_eq!(
            header.extras.get("irn").map(String::as_str),
            Some(r#"["role_a"]"#)
        );
        // String extras pass through without calling the handler.
        assert_eq!(
            header.extras.get("iri").map(String::as_str),
            Some("some-string-id")
        );
    }

    #[test]
    fn test_decode_header_with_handler_can_drop_fields() {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};

        let header_json = r#"{"alg":"RS256","eap":1,"iri":"keep-me","kid":"kid-1","typ":"JWT"}"#;
        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
        // "e30" is base64url for "{}".
        let token = format!("{header_b64}.e30.fake");

        // A handler returning None removes the member.
        let header = decode_header_with_handler(&token, &|_key, _value| None)
            .expect("should succeed when handler drops non-string fields");

        assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
        assert!(header.extras.get("eap").is_none());
        assert_eq!(
            header.extras.get("iri").map(String::as_str),
            Some("keep-me")
        );
    }

    #[tokio::test]
    async fn test_with_header_extras_stringified_coerces_non_string_extras() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url).with_header_extras_stringified();

        let header_json =
            r#"{"alg":"RS256","kid":"test-key-1","typ":"JWT","eap":1,"irn":["role_a"]}"#;
        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
        let payload_b64 = URL_SAFE_NO_PAD.encode(b"{}");
        let token = format!("{header_b64}.{payload_b64}.AAAA");

        let result = provider.validate_and_decode(&token).await;

        // Header decoding must succeed; the fake signature then makes validation fail.
        let err = result.expect_err("fake signature should fail validation");
        match &err {
            ClaimsError::DecodeFailed(msg) => {
                assert!(
                    msg.contains("JWT validation failed"),
                    "Expected signature-validation error, got: {msg}"
                );
            }
            other => panic!("Expected DecodeFailed, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_validate_and_decode_uses_header_extras_handler() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });

        let jwks_url = server.url("/jwks");
        let provider = test_provider_with_http(&jwks_url)
            .with_header_extras_handler(|_key, value| Some(value.to_string()));

        // `eap` is a number; without the handler the header would be rejected.
        let header_json = r#"{"alg":"RS256","kid":"test-key-1","typ":"JWT","eap":1}"#;
        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
        let payload_b64 = URL_SAFE_NO_PAD.encode(b"{}");
        let token = format!("{header_b64}.{payload_b64}.AAAA");

        let result = provider.validate_and_decode(&token).await;

        // Header decoding must succeed; the fake signature then makes validation fail.
        let err = result.expect_err("fake signature should fail validation");
        match &err {
            ClaimsError::DecodeFailed(msg) => {
                assert!(
                    msg.contains("JWT validation failed"),
                    "Expected signature-validation error, got: {msg}"
                );
            }
            other => panic!("Expected DecodeFailed, got: {other:?}"),
        }
    }

    // Baseline: jsonwebtoken's own decode_header rejects non-string extras.
    #[test]
    fn test_decode_header_without_handler_rejects_non_string_extras() {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};

        let header_json = r#"{"alg":"RS256","eap":1,"kid":"kid-1","typ":"JWT"}"#;
        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
        let token = format!("{header_b64}.e30.fake");

        let result = decode_header(&token);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("invalid type: integer"),
            "expected type error, got: {err}"
        );
    }
}