Skip to main content

modkit_auth/providers/
jwks.rs

1use crate::{claims_error::ClaimsError, traits::KeyProvider};
2use arc_swap::ArcSwap;
3use async_trait::async_trait;
4use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
5use jsonwebtoken::{DecodingKey, Header, decode_header};
6use serde::Deserialize;
7use serde_json::Value;
8use std::collections::{HashMap, HashSet};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::RwLock;
12use tokio::time::Instant;
13use tokio_util::sync::CancellationToken;
14
15#[derive(Debug, Clone, Deserialize)]
16struct Jwk {
17    kid: String,
18    kty: String,
19    #[serde(rename = "use")]
20    #[allow(dead_code)]
21    use_: Option<String>,
22    n: String,
23    e: String,
24    #[allow(dead_code)]
25    alg: Option<String>,
26}
27
28#[derive(Debug, Clone, Deserialize)]
29struct JwksResponse {
30    keys: Vec<Jwk>,
31}
32
33/// Handler for non-string custom JWT header fields; return `Some` to keep as string, or `None` to drop.
34type HeaderExtrasHandler = dyn Fn(&str, &Value) -> Option<String> + Send + Sync;
35
/// Registered JOSE header parameter names.
///
/// When a header-extras handler is configured, fields named here are never routed
/// through it; they are passed on to `jsonwebtoken::Header` unchanged.
/// Sources: RFC 7515 (JWS), RFC 7516 (JWE), RFC 7518 (JWA), RFC 8555 (ACME)
/// and RFC 7797 (b64).
const STANDARD_HEADER_FIELDS: &[&str] = &[
    // RFC 7515 (JWS)
    "typ", "alg", "cty", "jku", "jwk", "kid", "x5u", "x5c", "x5t", "x5t#S256", "crit",
    // RFC 7516 (JWE)
    "enc", "zip",
    // RFC 8555 (ACME)
    "url", "nonce",
    // RFC 7518 (JWA key agreement / AES-GCM key wrap / PBES2)
    "epk", "apu", "apv", "iv", "tag", "p2s", "p2c",
    // RFC 7797 (unencoded payload)
    "b64",
];
42
43/// JWKS-based key provider with lock-free reads
44///
45/// Uses `ArcSwap` for lock-free key lookups and background refresh with exponential backoff.
46#[must_use]
47pub struct JwksKeyProvider {
48    /// JWKS endpoint URL
49    jwks_uri: String,
50
51    /// Keys stored in `ArcSwap` for lock-free reads
52    keys: Arc<ArcSwap<HashMap<String, DecodingKey>>>,
53
54    /// Last refresh time and error tracking for backoff
55    refresh_state: Arc<RwLock<RefreshState>>,
56
57    /// Shared HTTP client for JWKS fetches (pooled connections)
58    /// `HttpClient` is `Clone + Send + Sync`, no external locking needed.
59    client: modkit_http::HttpClient,
60
61    /// Refresh interval (default: 5 minutes)
62    refresh_interval: Duration,
63
64    /// Maximum backoff duration (default: 1 hour)
65    max_backoff: Duration,
66
67    /// Cooldown for on-demand refresh (default: 60 seconds)
68    on_demand_refresh_cooldown: Duration,
69
70    /// Optional handler for non-string custom JWT header fields.
71    /// Called for each non-standard field whose value is not a JSON string.
72    /// Return `Some(s)` to keep, `None` to drop.
73    header_extras_handler: Option<Arc<HeaderExtrasHandler>>,
74}
75
76#[derive(Debug, Default)]
77struct RefreshState {
78    last_refresh: Option<Instant>,
79    last_on_demand_refresh: Option<Instant>,
80    consecutive_failures: u32,
81    last_error: Option<String>,
82    failed_kids: HashSet<String>,
83}
84
85impl JwksKeyProvider {
86    /// Create a new JWKS key provider
87    ///
88    /// # Errors
89    /// Returns error if HTTP client initialization fails (e.g., TLS setup)
90    pub fn new(jwks_uri: impl Into<String>) -> Result<Self, modkit_http::HttpError> {
91        Self::with_http_timeout(jwks_uri, Duration::from_secs(10))
92    }
93
94    /// Create a new JWKS key provider with custom HTTP timeout
95    ///
96    /// # Errors
97    /// Returns error if HTTP client initialization fails (e.g., TLS setup)
98    pub fn with_http_timeout(
99        jwks_uri: impl Into<String>,
100        timeout: Duration,
101    ) -> Result<Self, modkit_http::HttpError> {
102        let client = modkit_http::HttpClient::builder()
103            .timeout(timeout)
104            .retry(None) // JWKS provider handles its own retry logic
105            .build()?;
106
107        Ok(Self {
108            jwks_uri: jwks_uri.into(),
109            keys: Arc::new(ArcSwap::from_pointee(HashMap::new())),
110            refresh_state: Arc::new(RwLock::new(RefreshState::default())),
111            client,
112            refresh_interval: Duration::from_secs(300), // 5 minutes
113            max_backoff: Duration::from_secs(3600),     // 1 hour
114            on_demand_refresh_cooldown: Duration::from_secs(60), // 1 minute
115            header_extras_handler: None,
116        })
117    }
118
119    /// Create a new JWKS key provider (alias for new, kept for compatibility)
120    ///
121    /// # Errors
122    /// Returns error if HTTP client initialization fails (e.g., TLS setup)
123    pub fn try_new(jwks_uri: impl Into<String>) -> Result<Self, modkit_http::HttpError> {
124        Self::new(jwks_uri)
125    }
126
127    /// Create with custom refresh interval
128    pub fn with_refresh_interval(mut self, interval: Duration) -> Self {
129        self.refresh_interval = interval;
130        self
131    }
132
133    /// Create with custom max backoff
134    pub fn with_max_backoff(mut self, max_backoff: Duration) -> Self {
135        self.max_backoff = max_backoff;
136        self
137    }
138
139    /// Create with custom on-demand refresh cooldown
140    pub fn with_on_demand_refresh_cooldown(mut self, cooldown: Duration) -> Self {
141        self.on_demand_refresh_cooldown = cooldown;
142        self
143    }
144
145    /// Stringify all non-string custom JWT header fields.
146    ///
147    /// Convenience wrapper around [`with_header_extras_handler`](Self::with_header_extras_handler)
148    /// that converts every non-string value to its JSON representation
149    /// (e.g. `123` → `"123"`, `true` → `"true"`, `[1,2]` → `"[1,2]"`).
150    pub fn with_header_extras_stringified(self) -> Self {
151        self.with_header_extras_handler(|_, v| Some(v.to_string()))
152    }
153
154    /// Set a handler for non-string custom JWT header fields.
155    ///
156    /// `jsonwebtoken::Header::extras` is `HashMap<String, String>` and rejects
157    /// non-string values. This callback is invoked for each such field.
158    /// Return `Some(s)` to keep, `None` to drop.
159    /// Without a handler, upstream `decode_header` is used as-is.
160    pub fn with_header_extras_handler(
161        mut self,
162        handler: impl Fn(&str, &Value) -> Option<String> + Send + Sync + 'static,
163    ) -> Self {
164        self.header_extras_handler = Some(Arc::new(handler));
165        self
166    }
167
168    /// Fetch JWKS from the endpoint
169    async fn fetch_jwks(&self) -> Result<HashMap<String, DecodingKey>, ClaimsError> {
170        // HttpClient is Clone + Send + Sync, no locking needed
171        let jwks: JwksResponse = self
172            .client
173            .get(&self.jwks_uri)
174            .send()
175            .await
176            .map_err(|e| map_http_error(&e))?
177            .json()
178            .await
179            .map_err(|e| map_http_error(&e))?;
180
181        let mut keys = HashMap::new();
182        for jwk in jwks.keys {
183            if jwk.kty == "RSA" {
184                let key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)
185                    .map_err(|e| ClaimsError::JwksFetchFailed(format!("Invalid RSA key: {e}")))?;
186                keys.insert(jwk.kid, key);
187            }
188        }
189
190        if keys.is_empty() {
191            return Err(ClaimsError::JwksFetchFailed(
192                "No valid RSA keys found in JWKS".into(),
193            ));
194        }
195
196        Ok(keys)
197    }
198
199    /// Calculate backoff duration based on consecutive failures
200    fn calculate_backoff(&self, failures: u32) -> Duration {
201        let base = Duration::from_secs(60); // 1 minute base
202        let exponential = base * 2u32.pow(failures.min(10)); // Cap at 2^10
203        exponential.min(self.max_backoff)
204    }
205
206    /// Check if refresh is needed based on interval and backoff
207    async fn should_refresh(&self) -> bool {
208        let state = self.refresh_state.read().await;
209
210        match state.last_refresh {
211            None => true, // Never refreshed
212            Some(last) => {
213                let elapsed = last.elapsed();
214                if state.consecutive_failures == 0 {
215                    // Normal refresh interval
216                    elapsed >= self.refresh_interval
217                } else {
218                    // Exponential backoff
219                    elapsed >= self.calculate_backoff(state.consecutive_failures)
220                }
221            }
222        }
223    }
224
225    /// Perform key refresh with error tracking
226    async fn perform_refresh(&self) -> Result<(), ClaimsError> {
227        match self.fetch_jwks().await {
228            Ok(new_keys) => {
229                // Update keys atomically
230                self.keys.store(Arc::new(new_keys));
231
232                // Update refresh state
233                let mut state = self.refresh_state.write().await;
234                state.last_refresh = Some(Instant::now());
235                state.consecutive_failures = 0;
236                state.last_error = None;
237
238                Ok(())
239            }
240            Err(e) => {
241                // Update failure state
242                let mut state = self.refresh_state.write().await;
243                state.last_refresh = Some(Instant::now());
244                state.consecutive_failures += 1;
245                state.last_error = Some(e.to_string());
246
247                Err(e)
248            }
249        }
250    }
251
252    /// Check if a key exists in the cache
253    fn key_exists(&self, kid: &str) -> bool {
254        let keys = self.keys.load();
255        keys.contains_key(kid)
256    }
257
258    /// Check if we're in cooldown period and handle throttling logic
259    async fn check_refresh_throttle(&self, kid: &str) -> Result<(), ClaimsError> {
260        let state = self.refresh_state.read().await;
261        if let Some(last_on_demand) = state.last_on_demand_refresh {
262            let elapsed = last_on_demand.elapsed();
263            if elapsed < self.on_demand_refresh_cooldown {
264                let remaining = self.on_demand_refresh_cooldown.saturating_sub(elapsed);
265                tracing::debug!(
266                    kid = kid,
267                    remaining_secs = remaining.as_secs(),
268                    "On-demand JWKS refresh throttled (cooldown active)"
269                );
270
271                // Check if this kid has failed before
272                if state.failed_kids.contains(kid) {
273                    tracing::warn!(
274                        kid = kid,
275                        "Unknown kid repeatedly requested despite recent refresh attempts"
276                    );
277                }
278
279                return Err(ClaimsError::UnknownKeyId(kid.to_owned()));
280            }
281        }
282        Ok(())
283    }
284
285    /// Update state after successful refresh and check if kid is now available
286    async fn handle_refresh_success(&self, kid: &str) -> Result<(), ClaimsError> {
287        let mut state = self.refresh_state.write().await;
288        state.last_on_demand_refresh = Some(Instant::now());
289
290        // Check if the kid now exists
291        if self.key_exists(kid) {
292            // Kid found - remove from failed list if present
293            state.failed_kids.remove(kid);
294        } else {
295            // Kid still not found after refresh - track it
296            state.failed_kids.insert(kid.to_owned());
297            tracing::warn!(
298                kid = kid,
299                "Kid still not found after on-demand JWKS refresh"
300            );
301        }
302
303        Ok(())
304    }
305
306    /// Update state after failed refresh
307    async fn handle_refresh_failure(&self, kid: &str, error: ClaimsError) -> ClaimsError {
308        let mut state = self.refresh_state.write().await;
309        state.last_on_demand_refresh = Some(Instant::now());
310        state.failed_kids.insert(kid.to_owned());
311        error
312    }
313
314    /// Try to refresh keys if unknown kid is encountered
315    /// Implements throttling to prevent excessive refreshes
316    async fn on_demand_refresh(&self, kid: &str) -> Result<(), ClaimsError> {
317        // Check if key exists
318        if self.key_exists(kid) {
319            return Ok(());
320        }
321
322        // Check if we're in cooldown period
323        self.check_refresh_throttle(kid).await?;
324
325        // Attempt refresh and track the kid if it fails
326        tracing::info!(
327            kid = kid,
328            "Performing on-demand JWKS refresh for unknown kid"
329        );
330
331        match self.perform_refresh().await {
332            Ok(()) => self.handle_refresh_success(kid).await,
333            Err(e) => Err(self.handle_refresh_failure(kid, e).await),
334        }
335    }
336
337    /// Get a key by kid (lock-free read)
338    fn get_key(&self, kid: &str) -> Option<DecodingKey> {
339        let keys = self.keys.load();
340        keys.get(kid).cloned()
341    }
342
343    /// Validate JWT signature and decode claims without re-parsing the header.
344    ///
345    /// Uses `jsonwebtoken::crypto::verify` directly instead of `decode()`,
346    /// because `decode()` internally calls `decode_header()` which fails
347    /// on non-string custom header fields (e.g. `"eap": 1`).
348    fn validate_token(
349        token: &str,
350        key: &DecodingKey,
351        header: &Header,
352    ) -> Result<Value, ClaimsError> {
353        // Enforce exactly three dot-separated segments: header.payload.signature
354        let parts: Vec<&str> = token.splitn(4, '.').collect();
355        if parts.len() != 3 {
356            return Err(ClaimsError::DecodeFailed("Invalid JWT structure".into()));
357        }
358        let signing_input = &token[..parts[0].len() + 1 + parts[1].len()];
359        let payload_b64 = parts[1];
360        let signature = parts[2];
361
362        // Verify signature over header.payload (the original signing input)
363        let valid =
364            jsonwebtoken::crypto::verify(signature, signing_input.as_bytes(), key, header.alg)
365                .map_err(|e| {
366                    ClaimsError::DecodeFailed(format!("JWT signature verification failed: {e}"))
367                })?;
368        if !valid {
369            return Err(ClaimsError::InvalidSignature);
370        }
371
372        // Decode payload
373        let payload_bytes = URL_SAFE_NO_PAD
374            .decode(payload_b64.trim_end_matches('='))
375            .map_err(|e| ClaimsError::DecodeFailed(format!("JWT payload decode failed: {e}")))?;
376        let claims: Value = serde_json::from_slice(&payload_bytes)
377            .map_err(|e| ClaimsError::DecodeFailed(format!("JWT claims parse failed: {e}")))?;
378
379        Ok(claims)
380    }
381}
382
383#[async_trait]
384impl KeyProvider for JwksKeyProvider {
385    fn name(&self) -> &'static str {
386        "jwks"
387    }
388
389    async fn validate_and_decode(&self, token: &str) -> Result<(Header, Value), ClaimsError> {
390        // Strip "Bearer " prefix if present
391        let token = token.trim_start_matches("Bearer ").trim();
392
393        // Decode header to get kid and algorithm
394        let header = match &self.header_extras_handler {
395            Some(handler) => decode_header_with_handler(token, handler.as_ref()),
396            None => decode_header(token),
397        }
398        .map_err(|e| ClaimsError::DecodeFailed(format!("Invalid JWT header: {e}")))?;
399
400        let kid = header
401            .kid
402            .as_ref()
403            .ok_or_else(|| ClaimsError::DecodeFailed("Missing kid in JWT header".into()))?;
404
405        // Try to get key from cache
406        let key = if let Some(k) = self.get_key(kid) {
407            k
408        } else {
409            // Key not in cache, try on-demand refresh
410            self.on_demand_refresh(kid).await?;
411
412            // Try again after refresh
413            self.get_key(kid)
414                .ok_or_else(|| ClaimsError::UnknownKeyId(kid.clone()))?
415        };
416
417        // Validate signature and decode claims
418        let claims = Self::validate_token(token, &key, &header)?;
419
420        Ok((header, claims))
421    }
422
423    async fn refresh_keys(&self) -> Result<(), ClaimsError> {
424        if self.should_refresh().await {
425            self.perform_refresh().await
426        } else {
427            Ok(())
428        }
429    }
430}
431
432/// Background task to periodically refresh JWKS
433///
434/// This task will run until the `cancellation_token` is cancelled, enabling
435/// graceful shutdown per `ModKit` patterns. Without cancellation support, this
436/// task would run indefinitely and potentially cause process hang on shutdown.
437///
438/// # Example
439///
440/// ```ignore
441/// use tokio_util::sync::CancellationToken;
442/// use std::sync::Arc;
443///
444/// let provider = Arc::new(JwksKeyProvider::new("https://issuer/.well-known/jwks.json")?);
445/// let cancel_token = CancellationToken::new();
446///
447/// // Spawn the refresh task
448/// let task_handle = tokio::spawn(run_jwks_refresh_task(provider.clone(), cancel_token.clone()));
449///
450/// // On shutdown:
451/// cancel_token.cancel();
452/// task_handle.await?;
453/// ```
454pub async fn run_jwks_refresh_task(
455    provider: Arc<JwksKeyProvider>,
456    cancellation_token: CancellationToken,
457) {
458    let mut interval = tokio::time::interval(Duration::from_secs(60)); // Check every minute
459
460    loop {
461        tokio::select! {
462            () = cancellation_token.cancelled() => {
463                tracing::info!("JWKS refresh task shutting down");
464                break;
465            }
466            _ = interval.tick() => {
467                if let Err(e) = provider.refresh_keys().await {
468                    tracing::warn!("JWKS refresh failed: {}", e);
469                }
470            }
471        }
472    }
473}
474
475/// Decode a JWT header, routing non-string custom fields through `handler`.
476///
477/// Returns `Some(s)` to keep the field, `None` to drop it.
478fn decode_header_with_handler(
479    token: &str,
480    handler: &dyn Fn(&str, &Value) -> Option<String>,
481) -> Result<Header, jsonwebtoken::errors::Error> {
482    let header_b64 = token
483        .split('.')
484        .next()
485        .ok_or(jsonwebtoken::errors::ErrorKind::InvalidToken)?;
486
487    let header_bytes = URL_SAFE_NO_PAD
488        .decode(header_b64.trim_end_matches('='))
489        .map_err(jsonwebtoken::errors::ErrorKind::Base64)?;
490
491    let mut json: serde_json::Map<String, Value> = serde_json::from_slice(&header_bytes)?;
492
493    json.retain(|key, value| {
494        if STANDARD_HEADER_FIELDS.contains(&key.as_str()) || value.is_string() {
495            return true;
496        }
497        match handler(key, value) {
498            Some(s) => {
499                *value = Value::String(s);
500                true
501            }
502            None => false,
503        }
504    });
505
506    Ok(serde_json::from_value(Value::Object(json))?)
507}
508
509/// Map `HttpError` variants to appropriate `ClaimsError` messages
510fn map_http_error(e: &modkit_http::HttpError) -> ClaimsError {
511    ClaimsError::JwksFetchFailed(crate::http_error::format_http_error(e, "JWKS"))
512}
513
514#[cfg(test)]
515#[cfg_attr(coverage_nightly, coverage(off))]
516mod tests {
517    use super::*;
518    use httpmock::prelude::*;
519
520    /// Create a test provider with insecure HTTP allowed (for httpmock) and no retries
521    fn test_provider_with_http(uri: &str) -> JwksKeyProvider {
522        let client = modkit_http::HttpClient::builder()
523            .timeout(Duration::from_secs(5))
524            .retry(None)
525            .build()
526            .expect("failed to create test HTTP client");
527
528        JwksKeyProvider {
529            jwks_uri: uri.to_owned(),
530            keys: Arc::new(ArcSwap::from_pointee(HashMap::new())),
531            refresh_state: Arc::new(RwLock::new(RefreshState::default())),
532            client,
533            refresh_interval: Duration::from_secs(300),
534            max_backoff: Duration::from_secs(3600),
535            on_demand_refresh_cooldown: Duration::from_secs(60),
536            header_extras_handler: None,
537        }
538    }
539
540    /// Create a basic test provider (HTTPS only, for non-network tests)
541    fn test_provider(uri: &str) -> JwksKeyProvider {
542        JwksKeyProvider::new(uri).expect("failed to create test provider")
543    }
544
545    /// Valid JWKS JSON response with a single RSA key
546    fn valid_jwks_json() -> &'static str {
547        r#"{
548            "keys": [{
549                "kty": "RSA",
550                "kid": "test-key-1",
551                "use": "sig",
552                "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
553                "e": "AQAB",
554                "alg": "RS256"
555            }]
556        }"#
557    }
558
559    #[tokio::test]
560    async fn test_calculate_backoff() {
561        let provider = test_provider("https://example.com/jwks");
562
563        assert_eq!(provider.calculate_backoff(0), Duration::from_secs(60));
564        assert_eq!(provider.calculate_backoff(1), Duration::from_secs(120));
565        assert_eq!(provider.calculate_backoff(2), Duration::from_secs(240));
566        assert_eq!(provider.calculate_backoff(3), Duration::from_secs(480));
567
568        // Should cap at max_backoff
569        assert_eq!(provider.calculate_backoff(100), provider.max_backoff);
570    }
571
572    #[tokio::test]
573    async fn test_should_refresh_on_first_call() {
574        let provider = test_provider("https://example.com/jwks");
575        assert!(provider.should_refresh().await);
576    }
577
578    #[tokio::test]
579    async fn test_key_storage() {
580        let provider = test_provider("https://example.com/jwks");
581
582        // Initially empty
583        assert!(provider.get_key("test-kid").is_none());
584
585        // Store a dummy key
586        let mut keys = HashMap::new();
587        keys.insert("test-kid".to_owned(), DecodingKey::from_secret(b"secret"));
588        provider.keys.store(Arc::new(keys));
589
590        // Should be retrievable
591        assert!(provider.get_key("test-kid").is_some());
592    }
593
594    #[tokio::test]
595    async fn test_on_demand_refresh_returns_ok_when_key_exists() {
596        let provider = test_provider("https://example.com/jwks");
597
598        // Pre-populate with a key
599        let mut keys = HashMap::new();
600        keys.insert(
601            "existing-kid".to_owned(),
602            DecodingKey::from_secret(b"secret"),
603        );
604        provider.keys.store(Arc::new(keys));
605
606        // Should return Ok immediately without any refresh
607        let result = provider.on_demand_refresh("existing-kid").await;
608        assert!(result.is_ok());
609    }
610
611    #[tokio::test]
612    async fn test_try_new_returns_result() {
613        // Valid URL should work
614        let result = JwksKeyProvider::try_new("https://example.com/jwks");
615        assert!(result.is_ok());
616    }
617
618    // ==================== httpmock-based tests ====================
619
620    #[tokio::test]
621    async fn test_fetch_jwks_success_with_valid_json() {
622        let server = MockServer::start();
623
624        let mock = server.mock(|when, then| {
625            when.method(GET).path("/jwks");
626            then.status(200)
627                .header("content-type", "application/json")
628                .body(valid_jwks_json());
629        });
630
631        let jwks_url = server.url("/jwks");
632        let provider = test_provider_with_http(&jwks_url);
633
634        let result = provider.perform_refresh().await;
635        assert!(result.is_ok(), "Expected success, got: {result:?}");
636
637        // Verify key was stored
638        assert!(
639            provider.get_key("test-key-1").is_some(),
640            "Expected key 'test-key-1' to be stored"
641        );
642
643        mock.assert();
644    }
645
646    #[tokio::test]
647    async fn test_fetch_jwks_http_404_error_mapping() {
648        let server = MockServer::start();
649
650        let mock = server.mock(|when, then| {
651            when.method(GET).path("/jwks");
652            then.status(404).body("Not Found");
653        });
654
655        let jwks_url = server.url("/jwks");
656        let provider = test_provider_with_http(&jwks_url);
657
658        let result = provider.perform_refresh().await;
659        assert!(result.is_err());
660
661        let err = result.unwrap_err();
662        let err_msg = err.to_string();
663        assert!(
664            err_msg.contains("JWKS HTTP 404"),
665            "Expected error to contain 'JWKS HTTP 404', got: {err_msg}"
666        );
667        // Must NOT say "parse"
668        assert!(
669            !err_msg.to_lowercase().contains("parse"),
670            "HTTP status error should not mention 'parse', got: {err_msg}"
671        );
672
673        mock.assert();
674    }
675
676    #[tokio::test]
677    async fn test_fetch_jwks_http_500_error_mapping() {
678        let server = MockServer::start();
679
680        let mock = server.mock(|when, then| {
681            when.method(GET).path("/jwks");
682            then.status(500).body("Internal Server Error");
683        });
684
685        let jwks_url = server.url("/jwks");
686        let provider = test_provider_with_http(&jwks_url);
687
688        let result = provider.perform_refresh().await;
689        assert!(result.is_err());
690
691        let err = result.unwrap_err();
692        let err_msg = err.to_string();
693        assert!(
694            err_msg.contains("JWKS HTTP 500"),
695            "Expected error to contain 'JWKS HTTP 500', got: {err_msg}"
696        );
697
698        mock.assert();
699    }
700
701    #[tokio::test]
702    async fn test_fetch_jwks_invalid_json_error_mapping() {
703        let server = MockServer::start();
704
705        let mock = server.mock(|when, then| {
706            when.method(GET).path("/jwks");
707            then.status(200)
708                .header("content-type", "application/json")
709                .body("this is not valid json");
710        });
711
712        let jwks_url = server.url("/jwks");
713        let provider = test_provider_with_http(&jwks_url);
714
715        let result = provider.perform_refresh().await;
716        assert!(result.is_err());
717
718        let err = result.unwrap_err();
719        let err_msg = err.to_string();
720        assert!(
721            err_msg.contains("JWKS JSON parse failed"),
722            "Expected error to contain 'JWKS JSON parse failed', got: {err_msg}"
723        );
724
725        mock.assert();
726    }
727
728    #[tokio::test]
729    async fn test_fetch_jwks_empty_keys_error() {
730        let server = MockServer::start();
731
732        let mock = server.mock(|when, then| {
733            when.method(GET).path("/jwks");
734            then.status(200)
735                .header("content-type", "application/json")
736                .body(r#"{"keys": []}"#);
737        });
738
739        let jwks_url = server.url("/jwks");
740        let provider = test_provider_with_http(&jwks_url);
741
742        let result = provider.perform_refresh().await;
743        assert!(result.is_err());
744
745        let err = result.unwrap_err();
746        let err_msg = err.to_string();
747        assert!(
748            err_msg.contains("No valid RSA keys"),
749            "Expected error about no RSA keys, got: {err_msg}"
750        );
751
752        mock.assert();
753    }
754
755    #[tokio::test]
756    async fn test_on_demand_refresh_respects_cooldown() {
757        let server = MockServer::start();
758
759        // First request will return 404
760        let mock = server.mock(|when, then| {
761            when.method(GET).path("/jwks");
762            then.status(404).body("Not Found");
763        });
764
765        let jwks_url = server.url("/jwks");
766        let provider = test_provider_with_http(&jwks_url)
767            .with_on_demand_refresh_cooldown(Duration::from_secs(60));
768
769        // First attempt - should try to refresh and fail
770        let result1 = provider.on_demand_refresh("test-kid").await;
771        assert!(result1.is_err());
772
773        // Immediate second attempt - should be throttled (no network call)
774        let result2 = provider.on_demand_refresh("test-kid").await;
775        assert!(result2.is_err());
776
777        // Should return UnknownKeyId due to cooldown
778        match result2.unwrap_err() {
779            ClaimsError::UnknownKeyId(_) => {}
780            other => panic!("Expected UnknownKeyId during cooldown, got: {other:?}"),
781        }
782
783        // Only one request should have been made (first attempt)
784        mock.assert_calls(1);
785    }
786
787    #[tokio::test]
788    async fn test_on_demand_refresh_tracks_failed_kids() {
789        let server = MockServer::start();
790
791        server.mock(|when, then| {
792            when.method(GET).path("/jwks");
793            then.status(404).body("Not Found");
794        });
795
796        let jwks_url = server.url("/jwks");
797        let provider = test_provider_with_http(&jwks_url)
798            .with_on_demand_refresh_cooldown(Duration::from_millis(100));
799
800        // Attempt refresh - will fail and track the kid
801        let result = provider.on_demand_refresh("failed-kid").await;
802        assert!(result.is_err());
803
804        // Check that failed_kids contains the kid
805        let state = provider.refresh_state.read().await;
806        assert!(state.failed_kids.contains("failed-kid"));
807    }
808
809    #[tokio::test]
810    async fn test_perform_refresh_updates_state_on_failure() {
811        let server = MockServer::start();
812
813        server.mock(|when, then| {
814            when.method(GET).path("/jwks");
815            then.status(500).body("Server Error");
816        });
817
818        let jwks_url = server.url("/jwks");
819        let provider = test_provider_with_http(&jwks_url);
820
821        // Mark as previously failed
822        {
823            let mut state = provider.refresh_state.write().await;
824            state.consecutive_failures = 3;
825            state.last_error = Some("Previous error".to_owned());
826        }
827
828        // This will fail
829        _ = provider.perform_refresh().await;
830
831        // Check that consecutive_failures increased
832        let state = provider.refresh_state.read().await;
833        assert_eq!(state.consecutive_failures, 4);
834        assert!(state.last_error.is_some());
835    }
836
837    #[tokio::test]
838    async fn test_perform_refresh_resets_state_on_success() {
839        let server = MockServer::start();
840
841        server.mock(|when, then| {
842            when.method(GET).path("/jwks");
843            then.status(200)
844                .header("content-type", "application/json")
845                .body(valid_jwks_json());
846        });
847
848        let jwks_url = server.url("/jwks");
849        let provider = test_provider_with_http(&jwks_url);
850
851        // Mark as previously failed
852        {
853            let mut state = provider.refresh_state.write().await;
854            state.consecutive_failures = 5;
855            state.last_error = Some("Previous error".to_owned());
856        }
857
858        // This should succeed
859        let result = provider.perform_refresh().await;
860        assert!(result.is_ok());
861
862        // Check that state was reset
863        let state = provider.refresh_state.read().await;
864        assert_eq!(state.consecutive_failures, 0);
865        assert!(state.last_error.is_none());
866    }
867
868    #[tokio::test]
869    async fn test_validate_and_decode_with_missing_kid() {
870        let server = MockServer::start();
871
872        // Return valid JWKS but without the requested kid
873        server.mock(|when, then| {
874            when.method(GET).path("/jwks");
875            then.status(200)
876                .header("content-type", "application/json")
877                .body(valid_jwks_json());
878        });
879
880        let jwks_url = server.url("/jwks");
881        let provider = test_provider_with_http(&jwks_url)
882            .with_on_demand_refresh_cooldown(Duration::from_millis(100));
883
884        // Create a minimal JWT with a kid that doesn't exist in JWKS
885        // Header: {"alg":"RS256","kid":"nonexistent-kid"}
886        let token = "eyJhbGciOiJSUzI1NiIsImtpZCI6Im5vbmV4aXN0ZW50LWtpZCJ9.\
887                     eyJzdWIiOiIxMjM0NTY3ODkwIn0.invalid";
888
889        // Should attempt on-demand refresh but kid still won't exist
890        let result = provider.validate_and_decode(token).await;
891        assert!(result.is_err());
892
893        match result.unwrap_err() {
894            ClaimsError::UnknownKeyId(kid) => {
895                assert_eq!(kid, "nonexistent-kid");
896            }
897            other => panic!("Expected UnknownKeyId, got: {other:?}"),
898        }
899    }
900
901    #[test]
902    fn test_decode_header_with_handler_coerces_non_string_extras() {
903        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
904
905        // Header with non-standard fields: integer, string, and array
906        let header_json = r#"{"alg":"RS256","eap":1,"iri":"some-string-id","irn":["role_a"],"kid":"kid-1","typ":"at+jwt"}"#;
907        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
908        let payload_b64 = URL_SAFE_NO_PAD.encode(b"{}");
909        let token = format!("{header_b64}.{payload_b64}.fake");
910
911        let header = decode_header_with_handler(&token, &|_key, value| Some(value.to_string()))
912            .expect("should handle non-standard header fields");
913
914        assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
915        assert_eq!(header.kid.as_deref(), Some("kid-1"));
916        assert_eq!(header.typ.as_deref(), Some("at+jwt"));
917
918        // Non-string extras coerced to JSON text
919        assert_eq!(header.extras.get("eap").map(String::as_str), Some("1"));
920        assert_eq!(
921            header.extras.get("irn").map(String::as_str),
922            Some(r#"["role_a"]"#)
923        );
924        // String extras preserved as-is
925        assert_eq!(
926            header.extras.get("iri").map(String::as_str),
927            Some("some-string-id")
928        );
929    }
930
931    #[test]
932    fn test_decode_header_with_handler_can_drop_fields() {
933        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
934
935        let header_json = r#"{"alg":"RS256","eap":1,"iri":"keep-me","kid":"kid-1","typ":"JWT"}"#;
936        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
937        let token = format!("{header_b64}.e30.fake");
938
939        let header = decode_header_with_handler(&token, &|_key, _value| None)
940            .expect("should succeed when handler drops non-string fields");
941
942        assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
943        assert!(!header.extras.contains_key("eap"));
944        assert_eq!(
945            header.extras.get("iri").map(String::as_str),
946            Some("keep-me")
947        );
948    }
949
950    #[tokio::test]
951    async fn test_with_header_extras_stringified_coerces_non_string_extras() {
952        let server = MockServer::start();
953
954        server.mock(|when, then| {
955            when.method(GET).path("/jwks");
956            then.status(200)
957                .header("content-type", "application/json")
958                .body(valid_jwks_json());
959        });
960
961        let jwks_url = server.url("/jwks");
962        let provider = test_provider_with_http(&jwks_url).with_header_extras_stringified();
963
964        // Header with non-string extras: integer and array
965        let header_json =
966            r#"{"alg":"RS256","kid":"test-key-1","typ":"JWT","eap":1,"irn":["role_a"]}"#;
967        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
968        let payload_b64 = URL_SAFE_NO_PAD.encode(b"{}");
969        let token = format!("{header_b64}.{payload_b64}.AAAA");
970
971        let result = provider.validate_and_decode(&token).await;
972
973        // The handler lets header decode succeed; error must come from signature
974        // validation, not from header parsing.
975        let err = result.expect_err("fake signature should fail validation");
976        assert!(
977            matches!(
978                &err,
979                ClaimsError::InvalidSignature | ClaimsError::DecodeFailed(_)
980            ),
981            "Expected signature-related error, got: {err:?}"
982        );
983    }
984
985    #[tokio::test]
986    async fn test_validate_and_decode_uses_header_extras_handler() {
987        let server = MockServer::start();
988
989        server.mock(|when, then| {
990            when.method(GET).path("/jwks");
991            then.status(200)
992                .header("content-type", "application/json")
993                .body(valid_jwks_json());
994        });
995
996        let jwks_url = server.url("/jwks");
997        let provider = test_provider_with_http(&jwks_url)
998            .with_header_extras_handler(|_key, value| Some(value.to_string()));
999
1000        // Header with a non-string extra ("eap":1) that would reject without handler
1001        let header_json = r#"{"alg":"RS256","kid":"test-key-1","typ":"JWT","eap":1}"#;
1002        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
1003        let payload_b64 = URL_SAFE_NO_PAD.encode(b"{}");
1004        let token = format!("{header_b64}.{payload_b64}.AAAA");
1005
1006        let result = provider.validate_and_decode(&token).await;
1007
1008        // Handler lets header decode succeed → error must come from signature
1009        // validation, not from header parsing.
1010        let err = result.expect_err("fake signature should fail validation");
1011        assert!(
1012            matches!(
1013                &err,
1014                ClaimsError::InvalidSignature | ClaimsError::DecodeFailed(_)
1015            ),
1016            "Expected signature-related error, got: {err:?}"
1017        );
1018    }
1019
    /// RSA private key (PKCS#8 PEM, `BEGIN PRIVATE KEY`) used to sign test JWTs.
    ///
    /// The matching public-key components (`n`, `e`) are served under kid
    /// `sign-key-1` by `signed_jwks_json()`; if this key is ever regenerated, that
    /// JWKS fixture must be regenerated from the same key pair.
    ///
    /// This is committed, publicly visible test material: it must never be used
    /// or trusted outside this test module.
    const TEST_RSA_PRIVATE_PEM: &[u8] = b"-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCohcw9B9YK7ULF
KgrGNJKAH0BH9CpJB03wIkQl6ECCJ/BfmBsNSWwZdnG0cWwwGhsSSSj32AKB+t6W
44/vi9hv+PHusIRCMNqM/AJ/zA7xau9mNsxS8U8J3olm74vLFtF05hTRmJuefMmz
mOt4kMP44UeVg0nyFlToa0SmhMxIeFgz2VgktHjHDe/rr/FdrjMwxesz3ezj+Y4k
YPPrQfMZJTyEd68M+pPkjyg6AkakNSUJp+dZibnRLKcj6Ehz1W3lSGkaQ4YFSXVX
UCaHWNmPsJHejwKrUA/fbkYi3sLO7cW/4h+b2laWsL9qC4P2RJMbZBzklJoL+WoH
Lo5zUvo7AgMBAAECggEACrynlBXdOcn/EI/KqvErilUzY8I3NXrtKMkOHXosLf68
bmLDCngslny45t25HmFzaxlVLmFJW52vs95gy8rVqeCrDWGas5roOcZOpHTMWO5O
vWztXLV6Ky9OAsxtVC2qf6+vEOGPvKvHsBUkn4RdsAwuYuS//9gTZdF7yL46Q72o
pJ8bLUZBpqmVNyLxyfbFn8u9j71zMUweB9vOMYAIAv1cYRa/0bVYLIZumcotY822
B0ny1fLru1gDJt2p1DL9fQTg16pBYr1V0nhoiktS8Lx5PFLMI+NhmalBerqtPN+u
qqauu9jolmXtydfOP7pTN2sqGFAKlcx55KZlVLK2YQKBgQDaiRxPXnFCPY4yYBxS
POFJe8UcvoM3d5HGwQfbJ5PHq+YN8NW0ACaox6QQkQYmE9OHriHrVmp4af6erN2K
zbjmL41E5C4MzEau2ipZWY4GA+lLXomEiHsUD0cfqfL+7Fs6ufiG2nXrWIBXggz8
8mTdP/LHMPybY0wxoZI5Xij+2wKBgQDFacPh+PhT0U8wu7nSgvQ85ozJN7TWq0KD
TgWuZ0W6L5OlAAVernYuvvRH/Uy9JqVfX4KLHbcEcdUx8t5usKMf8S3kQyMM8xK+
KaEYZNOMdA6E9PAJVD8crDQT/QD6/+oHrTTFFKxW7jWLY1ggWXVHk4CxLXBlDnKQ
xIA5DuhgIQKBgQCA5Km77loi1aeO8r0BjELcUpH52CwQhQeIEMYPbpJtDGhOBKQm
3IfwuH99/euAfeUfe4cqBPgbOXkiIZcxjRDnQ1ixL1wx1DJEYwzjUjzAM4JgH8xA
TTc6p6AtftGBpepRAusgrq0qODLKajw63MS88kDBV5VGGRURmNhj2bOYTQKBgHPr
hiVj/9Wf+6M/KH9vfCFis9rYBi1jxRu7LeTaKXyJwWXLHFwbj7QlVuYK3AvZ7JOT
TuGHoldOzISW+3v95tuz0GHP9n39Ic1ePoVHd11rLLdv6J9hw+l/SNlP4EqDCZZW
Y70yRXyKRhDCVhYw0YglGhVv/CarFCTj7fMTSOphAoGBAJcM4H4qmCFLdR9FRQgT
YJPGcyjWPmm9tlb8M6rSJGPlfpAhKjRVGWwpHPiUnvrW296QKr9+5q43HRcK3qa5
GU5n8VxYiniVFVMSEpLJgvu7hGq5fmMiRTTot1pOTSXZ1LY6rDQvjsTeGQumb/Eo
F8gvjIeiwVfp4nDnO2JFexiy
-----END PRIVATE KEY-----";
1050
    /// JWKS document advertising the public half of `TEST_RSA_PRIVATE_PEM`.
    ///
    /// It contains a single RSA key with kid `sign-key-1`, `"use": "sig"` and
    /// `"alg": "RS256"`. Tokens built with `build_signed_jwt("sign-key-1", ..)`
    /// therefore verify against it. `n` and `e` are the base64url (unpadded)
    /// modulus and exponent.
    fn signed_jwks_json() -> &'static str {
        r#"{
            "keys": [{
                "kty": "RSA",
                "kid": "sign-key-1",
                "use": "sig",
                "n": "qIXMPQfWCu1CxSoKxjSSgB9AR_QqSQdN8CJEJehAgifwX5gbDUlsGXZxtHFsMBobEkko99gCgfreluOP74vYb_jx7rCEQjDajPwCf8wO8WrvZjbMUvFPCd6JZu-LyxbRdOYU0ZibnnzJs5jreJDD-OFHlYNJ8hZU6GtEpoTMSHhYM9lYJLR4xw3v66_xXa4zMMXrM93s4_mOJGDz60HzGSU8hHevDPqT5I8oOgJGpDUlCafnWYm50SynI-hIc9Vt5UhpGkOGBUl1V1Amh1jZj7CR3o8Cq1AP325GIt7Czu3Fv-Ifm9pWlrC_aguD9kSTG2Qc5JSaC_lqBy6Oc1L6Ow",
                "e": "AQAB",
                "alg": "RS256"
            }]
        }"#
    }
1064
1065    /// Build a properly-signed RS256 JWT for testing.
1066    fn build_signed_jwt(kid: &str, claims: &serde_json::Value) -> String {
1067        let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(TEST_RSA_PRIVATE_PEM)
1068            .expect("test RSA PEM should be valid");
1069        let mut header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256);
1070        header.kid = Some(kid.to_owned());
1071        jsonwebtoken::encode(&header, claims, &encoding_key).expect("JWT signing should succeed")
1072    }
1073
1074    #[tokio::test]
1075    async fn test_validate_and_decode_happy_path() {
1076        let server = MockServer::start();
1077
1078        server.mock(|when, then| {
1079            when.method(GET).path("/jwks");
1080            then.status(200)
1081                .header("content-type", "application/json")
1082                .body(signed_jwks_json());
1083        });
1084
1085        let jwks_url = server.url("/jwks");
1086        let provider = test_provider_with_http(&jwks_url);
1087
1088        let claims = serde_json::json!({
1089            "sub": "user-42",
1090            "name": "Test User",
1091            "iat": 1_700_000_000u64
1092        });
1093        let token = build_signed_jwt("sign-key-1", &claims);
1094
1095        let (header, decoded_claims) = provider
1096            .validate_and_decode(&token)
1097            .await
1098            .expect("validate_and_decode should succeed for a properly signed token");
1099
1100        assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
1101        assert_eq!(header.kid.as_deref(), Some("sign-key-1"));
1102        assert_eq!(decoded_claims["sub"], "user-42");
1103        assert_eq!(decoded_claims["name"], "Test User");
1104    }
1105
1106    #[tokio::test]
1107    async fn test_validate_and_decode_with_bearer_prefix() {
1108        let server = MockServer::start();
1109
1110        server.mock(|when, then| {
1111            when.method(GET).path("/jwks");
1112            then.status(200)
1113                .header("content-type", "application/json")
1114                .body(signed_jwks_json());
1115        });
1116
1117        let jwks_url = server.url("/jwks");
1118        let provider = test_provider_with_http(&jwks_url);
1119
1120        let claims = serde_json::json!({"sub": "user-99"});
1121        let token = format!("Bearer {}", build_signed_jwt("sign-key-1", &claims));
1122
1123        let (_, decoded_claims) = provider
1124            .validate_and_decode(&token)
1125            .await
1126            .expect("should strip Bearer prefix and succeed");
1127
1128        assert_eq!(decoded_claims["sub"], "user-99");
1129    }
1130
1131    #[tokio::test]
1132    async fn test_validate_and_decode_rejects_tampered_payload() {
1133        let server = MockServer::start();
1134
1135        server.mock(|when, then| {
1136            when.method(GET).path("/jwks");
1137            then.status(200)
1138                .header("content-type", "application/json")
1139                .body(signed_jwks_json());
1140        });
1141
1142        let jwks_url = server.url("/jwks");
1143        let provider = test_provider_with_http(&jwks_url);
1144
1145        let claims = serde_json::json!({"sub": "legit"});
1146        let token = build_signed_jwt("sign-key-1", &claims);
1147
1148        // Tamper with the payload segment
1149        let parts: Vec<&str> = token.splitn(3, '.').collect();
1150        let tampered_payload = URL_SAFE_NO_PAD.encode(br#"{"sub":"evil"}"#);
1151        let tampered_token = format!("{}.{}.{}", parts[0], tampered_payload, parts[2]);
1152
1153        let err = provider
1154            .validate_and_decode(&tampered_token)
1155            .await
1156            .expect_err("tampered token should fail signature verification");
1157
1158        assert!(
1159            matches!(err, ClaimsError::InvalidSignature),
1160            "Expected InvalidSignature, got: {err:?}"
1161        );
1162    }
1163
1164    /// Build a JWT with a custom header JSON (for non-string extras), properly signed.
1165    fn build_signed_jwt_custom_header(header_json: &str, claims: &serde_json::Value) -> String {
1166        let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(TEST_RSA_PRIVATE_PEM)
1167            .expect("test RSA PEM should be valid");
1168        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
1169        let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(claims).unwrap());
1170        let message = format!("{header_b64}.{payload_b64}");
1171        let signature = jsonwebtoken::crypto::sign(
1172            message.as_bytes(),
1173            &encoding_key,
1174            jsonwebtoken::Algorithm::RS256,
1175        )
1176        .expect("signing should succeed");
1177        format!("{message}.{signature}")
1178    }
1179
1180    #[tokio::test]
1181    async fn test_validate_and_decode_with_non_string_header_extras() {
1182        let server = MockServer::start();
1183
1184        server.mock(|when, then| {
1185            when.method(GET).path("/jwks");
1186            then.status(200)
1187                .header("content-type", "application/json")
1188                .body(signed_jwks_json());
1189        });
1190
1191        let jwks_url = server.url("/jwks");
1192        let provider = test_provider_with_http(&jwks_url).with_header_extras_stringified();
1193
1194        let claims = serde_json::json!({"sub": "user-extras"});
1195        let header_json = r#"{"alg":"RS256","kid":"sign-key-1","typ":"JWT","eap":1}"#;
1196        let token = build_signed_jwt_custom_header(header_json, &claims);
1197
1198        let (header, decoded_claims) = provider
1199            .validate_and_decode(&token)
1200            .await
1201            .expect("should decode JWT with non-string header extras when handler is set");
1202
1203        assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
1204        assert_eq!(header.kid.as_deref(), Some("sign-key-1"));
1205        assert_eq!(header.extras.get("eap").map(String::as_str), Some("1"));
1206        assert_eq!(decoded_claims["sub"], "user-extras");
1207    }
1208
1209    #[test]
1210    fn test_decode_header_without_handler_rejects_non_string_extras() {
1211        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
1212
1213        let header_json = r#"{"alg":"RS256","eap":1,"kid":"kid-1","typ":"JWT"}"#;
1214        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
1215        let token = format!("{header_b64}.e30.fake");
1216
1217        let result = decode_header(&token);
1218        assert!(result.is_err());
1219        let err = result.unwrap_err().to_string();
1220        assert!(
1221            err.contains("invalid type: integer"),
1222            "expected type error, got: {err}"
1223        );
1224    }
1225}