Skip to main content

modkit_auth/providers/
jwks.rs

1use crate::{claims_error::ClaimsError, traits::KeyProvider};
2use arc_swap::ArcSwap;
3use async_trait::async_trait;
4use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
5use jsonwebtoken::{DecodingKey, Header, Validation, decode, decode_header};
6use serde::Deserialize;
7use serde_json::Value;
8use std::collections::{HashMap, HashSet};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::RwLock;
12use tokio::time::Instant;
13use tokio_util::sync::CancellationToken;
14
15#[derive(Debug, Clone, Deserialize)]
16struct Jwk {
17    kid: String,
18    kty: String,
19    #[serde(rename = "use")]
20    #[allow(dead_code)]
21    use_: Option<String>,
22    n: String,
23    e: String,
24    #[allow(dead_code)]
25    alg: Option<String>,
26}
27
28#[derive(Debug, Clone, Deserialize)]
29struct JwksResponse {
30    keys: Vec<Jwk>,
31}
32
33/// Handler for non-string custom JWT header fields; return `Some` to keep as string, or `None` to drop.
34type HeaderExtrasHandler = dyn Fn(&str, &Value) -> Option<String> + Send + Sync;
35
/// Header parameter names registered by the JOSE specifications.
///
/// Covers RFC 7515 (JWS), RFC 7516 (JWE), RFC 7518 (JWA), RFC 7797 (`b64`)
/// and RFC 8555 (ACME). Any header field not listed here is treated as a
/// custom extra.
const STANDARD_HEADER_FIELDS: &[&str] = &[
    // RFC 7515 (JWS)
    "typ",
    "alg",
    "cty",
    "jku",
    "jwk",
    "kid",
    "x5u",
    "x5c",
    "x5t",
    "x5t#S256",
    "crit",
    // RFC 7516 (JWE)
    "enc",
    "zip",
    // RFC 8555 (ACME)
    "url",
    "nonce",
    // RFC 7518 (JWA key-management parameters)
    "epk",
    "apu",
    "apv",
    "iv",
    "tag",
    "p2s",
    "p2c",
    // RFC 7797 (unencoded payload option)
    "b64",
];
42
43/// JWKS-based key provider with lock-free reads
44///
45/// Uses `ArcSwap` for lock-free key lookups and background refresh with exponential backoff.
46#[must_use]
47pub struct JwksKeyProvider {
48    /// JWKS endpoint URL
49    jwks_uri: String,
50
51    /// Keys stored in `ArcSwap` for lock-free reads
52    keys: Arc<ArcSwap<HashMap<String, DecodingKey>>>,
53
54    /// Last refresh time and error tracking for backoff
55    refresh_state: Arc<RwLock<RefreshState>>,
56
57    /// Shared HTTP client for JWKS fetches (pooled connections)
58    /// `HttpClient` is `Clone + Send + Sync`, no external locking needed.
59    client: modkit_http::HttpClient,
60
61    /// Refresh interval (default: 5 minutes)
62    refresh_interval: Duration,
63
64    /// Maximum backoff duration (default: 1 hour)
65    max_backoff: Duration,
66
67    /// Cooldown for on-demand refresh (default: 60 seconds)
68    on_demand_refresh_cooldown: Duration,
69
70    /// Optional handler for non-string custom JWT header fields.
71    /// Called for each non-standard field whose value is not a JSON string.
72    /// Return `Some(s)` to keep, `None` to drop.
73    header_extras_handler: Option<Arc<HeaderExtrasHandler>>,
74}
75
76#[derive(Debug, Default)]
77struct RefreshState {
78    last_refresh: Option<Instant>,
79    last_on_demand_refresh: Option<Instant>,
80    consecutive_failures: u32,
81    last_error: Option<String>,
82    failed_kids: HashSet<String>,
83}
84
85impl JwksKeyProvider {
86    /// Create a new JWKS key provider
87    ///
88    /// # Errors
89    /// Returns error if HTTP client initialization fails (e.g., TLS setup)
90    pub fn new(jwks_uri: impl Into<String>) -> Result<Self, modkit_http::HttpError> {
91        Self::with_http_timeout(jwks_uri, Duration::from_secs(10))
92    }
93
94    /// Create a new JWKS key provider with custom HTTP timeout
95    ///
96    /// # Errors
97    /// Returns error if HTTP client initialization fails (e.g., TLS setup)
98    pub fn with_http_timeout(
99        jwks_uri: impl Into<String>,
100        timeout: Duration,
101    ) -> Result<Self, modkit_http::HttpError> {
102        let client = modkit_http::HttpClient::builder()
103            .timeout(timeout)
104            .retry(None) // JWKS provider handles its own retry logic
105            .build()?;
106
107        Ok(Self {
108            jwks_uri: jwks_uri.into(),
109            keys: Arc::new(ArcSwap::from_pointee(HashMap::new())),
110            refresh_state: Arc::new(RwLock::new(RefreshState::default())),
111            client,
112            refresh_interval: Duration::from_secs(300), // 5 minutes
113            max_backoff: Duration::from_secs(3600),     // 1 hour
114            on_demand_refresh_cooldown: Duration::from_secs(60), // 1 minute
115            header_extras_handler: None,
116        })
117    }
118
119    /// Create a new JWKS key provider (alias for new, kept for compatibility)
120    ///
121    /// # Errors
122    /// Returns error if HTTP client initialization fails (e.g., TLS setup)
123    pub fn try_new(jwks_uri: impl Into<String>) -> Result<Self, modkit_http::HttpError> {
124        Self::new(jwks_uri)
125    }
126
127    /// Create with custom refresh interval
128    pub fn with_refresh_interval(mut self, interval: Duration) -> Self {
129        self.refresh_interval = interval;
130        self
131    }
132
133    /// Create with custom max backoff
134    pub fn with_max_backoff(mut self, max_backoff: Duration) -> Self {
135        self.max_backoff = max_backoff;
136        self
137    }
138
139    /// Create with custom on-demand refresh cooldown
140    pub fn with_on_demand_refresh_cooldown(mut self, cooldown: Duration) -> Self {
141        self.on_demand_refresh_cooldown = cooldown;
142        self
143    }
144
145    /// Stringify all non-string custom JWT header fields.
146    ///
147    /// Convenience wrapper around [`with_header_extras_handler`](Self::with_header_extras_handler)
148    /// that converts every non-string value to its JSON representation
149    /// (e.g. `123` → `"123"`, `true` → `"true"`, `[1,2]` → `"[1,2]"`).
150    pub fn with_header_extras_stringified(self) -> Self {
151        self.with_header_extras_handler(|_, v| Some(v.to_string()))
152    }
153
154    /// Set a handler for non-string custom JWT header fields.
155    ///
156    /// `jsonwebtoken::Header::extras` is `HashMap<String, String>` and rejects
157    /// non-string values. This callback is invoked for each such field.
158    /// Return `Some(s)` to keep, `None` to drop.
159    /// Without a handler, upstream `decode_header` is used as-is.
160    pub fn with_header_extras_handler(
161        mut self,
162        handler: impl Fn(&str, &Value) -> Option<String> + Send + Sync + 'static,
163    ) -> Self {
164        self.header_extras_handler = Some(Arc::new(handler));
165        self
166    }
167
168    /// Fetch JWKS from the endpoint
169    async fn fetch_jwks(&self) -> Result<HashMap<String, DecodingKey>, ClaimsError> {
170        // HttpClient is Clone + Send + Sync, no locking needed
171        let jwks: JwksResponse = self
172            .client
173            .get(&self.jwks_uri)
174            .send()
175            .await
176            .map_err(|e| map_http_error(&e))?
177            .json()
178            .await
179            .map_err(|e| map_http_error(&e))?;
180
181        let mut keys = HashMap::new();
182        for jwk in jwks.keys {
183            if jwk.kty == "RSA" {
184                let key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)
185                    .map_err(|e| ClaimsError::JwksFetchFailed(format!("Invalid RSA key: {e}")))?;
186                keys.insert(jwk.kid, key);
187            }
188        }
189
190        if keys.is_empty() {
191            return Err(ClaimsError::JwksFetchFailed(
192                "No valid RSA keys found in JWKS".into(),
193            ));
194        }
195
196        Ok(keys)
197    }
198
199    /// Calculate backoff duration based on consecutive failures
200    fn calculate_backoff(&self, failures: u32) -> Duration {
201        let base = Duration::from_secs(60); // 1 minute base
202        let exponential = base * 2u32.pow(failures.min(10)); // Cap at 2^10
203        exponential.min(self.max_backoff)
204    }
205
206    /// Check if refresh is needed based on interval and backoff
207    async fn should_refresh(&self) -> bool {
208        let state = self.refresh_state.read().await;
209
210        match state.last_refresh {
211            None => true, // Never refreshed
212            Some(last) => {
213                let elapsed = last.elapsed();
214                if state.consecutive_failures == 0 {
215                    // Normal refresh interval
216                    elapsed >= self.refresh_interval
217                } else {
218                    // Exponential backoff
219                    elapsed >= self.calculate_backoff(state.consecutive_failures)
220                }
221            }
222        }
223    }
224
225    /// Perform key refresh with error tracking
226    async fn perform_refresh(&self) -> Result<(), ClaimsError> {
227        match self.fetch_jwks().await {
228            Ok(new_keys) => {
229                // Update keys atomically
230                self.keys.store(Arc::new(new_keys));
231
232                // Update refresh state
233                let mut state = self.refresh_state.write().await;
234                state.last_refresh = Some(Instant::now());
235                state.consecutive_failures = 0;
236                state.last_error = None;
237
238                Ok(())
239            }
240            Err(e) => {
241                // Update failure state
242                let mut state = self.refresh_state.write().await;
243                state.last_refresh = Some(Instant::now());
244                state.consecutive_failures += 1;
245                state.last_error = Some(e.to_string());
246
247                Err(e)
248            }
249        }
250    }
251
252    /// Check if a key exists in the cache
253    fn key_exists(&self, kid: &str) -> bool {
254        let keys = self.keys.load();
255        keys.contains_key(kid)
256    }
257
258    /// Check if we're in cooldown period and handle throttling logic
259    async fn check_refresh_throttle(&self, kid: &str) -> Result<(), ClaimsError> {
260        let state = self.refresh_state.read().await;
261        if let Some(last_on_demand) = state.last_on_demand_refresh {
262            let elapsed = last_on_demand.elapsed();
263            if elapsed < self.on_demand_refresh_cooldown {
264                let remaining = self.on_demand_refresh_cooldown.saturating_sub(elapsed);
265                tracing::debug!(
266                    kid = kid,
267                    remaining_secs = remaining.as_secs(),
268                    "On-demand JWKS refresh throttled (cooldown active)"
269                );
270
271                // Check if this kid has failed before
272                if state.failed_kids.contains(kid) {
273                    tracing::warn!(
274                        kid = kid,
275                        "Unknown kid repeatedly requested despite recent refresh attempts"
276                    );
277                }
278
279                return Err(ClaimsError::UnknownKeyId(kid.to_owned()));
280            }
281        }
282        Ok(())
283    }
284
285    /// Update state after successful refresh and check if kid is now available
286    async fn handle_refresh_success(&self, kid: &str) -> Result<(), ClaimsError> {
287        let mut state = self.refresh_state.write().await;
288        state.last_on_demand_refresh = Some(Instant::now());
289
290        // Check if the kid now exists
291        if self.key_exists(kid) {
292            // Kid found - remove from failed list if present
293            state.failed_kids.remove(kid);
294        } else {
295            // Kid still not found after refresh - track it
296            state.failed_kids.insert(kid.to_owned());
297            tracing::warn!(
298                kid = kid,
299                "Kid still not found after on-demand JWKS refresh"
300            );
301        }
302
303        Ok(())
304    }
305
306    /// Update state after failed refresh
307    async fn handle_refresh_failure(&self, kid: &str, error: ClaimsError) -> ClaimsError {
308        let mut state = self.refresh_state.write().await;
309        state.last_on_demand_refresh = Some(Instant::now());
310        state.failed_kids.insert(kid.to_owned());
311        error
312    }
313
314    /// Try to refresh keys if unknown kid is encountered
315    /// Implements throttling to prevent excessive refreshes
316    async fn on_demand_refresh(&self, kid: &str) -> Result<(), ClaimsError> {
317        // Check if key exists
318        if self.key_exists(kid) {
319            return Ok(());
320        }
321
322        // Check if we're in cooldown period
323        self.check_refresh_throttle(kid).await?;
324
325        // Attempt refresh and track the kid if it fails
326        tracing::info!(
327            kid = kid,
328            "Performing on-demand JWKS refresh for unknown kid"
329        );
330
331        match self.perform_refresh().await {
332            Ok(()) => self.handle_refresh_success(kid).await,
333            Err(e) => Err(self.handle_refresh_failure(kid, e).await),
334        }
335    }
336
337    /// Get a key by kid (lock-free read)
338    fn get_key(&self, kid: &str) -> Option<DecodingKey> {
339        let keys = self.keys.load();
340        keys.get(kid).cloned()
341    }
342
343    /// Validate JWT and decode into header + raw claims
344    fn validate_token(
345        token: &str,
346        key: &DecodingKey,
347        header: &Header,
348    ) -> Result<Value, ClaimsError> {
349        let mut validation = Validation::new(header.alg);
350
351        // Disable all built-in validations - we'll do them separately
352        validation.validate_exp = false;
353        validation.validate_nbf = false;
354        validation.validate_aud = false;
355
356        // Don't require any standard claims
357        let empty_claims: &[&str] = &[];
358        validation.set_required_spec_claims(empty_claims);
359
360        let token_data = decode::<Value>(token, key, &validation)
361            .map_err(|e| ClaimsError::DecodeFailed(format!("JWT validation failed: {e}")))?;
362
363        Ok(token_data.claims)
364    }
365}
366
367#[async_trait]
368impl KeyProvider for JwksKeyProvider {
369    fn name(&self) -> &'static str {
370        "jwks"
371    }
372
373    async fn validate_and_decode(&self, token: &str) -> Result<(Header, Value), ClaimsError> {
374        // Strip "Bearer " prefix if present
375        let token = token.trim_start_matches("Bearer ").trim();
376
377        // Decode header to get kid and algorithm
378        let header = match &self.header_extras_handler {
379            Some(handler) => decode_header_with_handler(token, handler.as_ref()),
380            None => decode_header(token),
381        }
382        .map_err(|e| ClaimsError::DecodeFailed(format!("Invalid JWT header: {e}")))?;
383
384        let kid = header
385            .kid
386            .as_ref()
387            .ok_or_else(|| ClaimsError::DecodeFailed("Missing kid in JWT header".into()))?;
388
389        // Try to get key from cache
390        let key = if let Some(k) = self.get_key(kid) {
391            k
392        } else {
393            // Key not in cache, try on-demand refresh
394            self.on_demand_refresh(kid).await?;
395
396            // Try again after refresh
397            self.get_key(kid)
398                .ok_or_else(|| ClaimsError::UnknownKeyId(kid.clone()))?
399        };
400
401        // Validate signature and decode claims
402        let claims = Self::validate_token(token, &key, &header)?;
403
404        Ok((header, claims))
405    }
406
407    async fn refresh_keys(&self) -> Result<(), ClaimsError> {
408        if self.should_refresh().await {
409            self.perform_refresh().await
410        } else {
411            Ok(())
412        }
413    }
414}
415
416/// Background task to periodically refresh JWKS
417///
418/// This task will run until the `cancellation_token` is cancelled, enabling
419/// graceful shutdown per `ModKit` patterns. Without cancellation support, this
420/// task would run indefinitely and potentially cause process hang on shutdown.
421///
422/// # Example
423///
424/// ```ignore
425/// use tokio_util::sync::CancellationToken;
426/// use std::sync::Arc;
427///
428/// let provider = Arc::new(JwksKeyProvider::new("https://issuer/.well-known/jwks.json")?);
429/// let cancel_token = CancellationToken::new();
430///
431/// // Spawn the refresh task
432/// let task_handle = tokio::spawn(run_jwks_refresh_task(provider.clone(), cancel_token.clone()));
433///
434/// // On shutdown:
435/// cancel_token.cancel();
436/// task_handle.await?;
437/// ```
438pub async fn run_jwks_refresh_task(
439    provider: Arc<JwksKeyProvider>,
440    cancellation_token: CancellationToken,
441) {
442    let mut interval = tokio::time::interval(Duration::from_secs(60)); // Check every minute
443
444    loop {
445        tokio::select! {
446            () = cancellation_token.cancelled() => {
447                tracing::info!("JWKS refresh task shutting down");
448                break;
449            }
450            _ = interval.tick() => {
451                if let Err(e) = provider.refresh_keys().await {
452                    tracing::warn!("JWKS refresh failed: {}", e);
453                }
454            }
455        }
456    }
457}
458
459/// Decode a JWT header, routing non-string custom fields through `handler`.
460///
461/// Returns `Some(s)` to keep the field, `None` to drop it.
462fn decode_header_with_handler(
463    token: &str,
464    handler: &dyn Fn(&str, &Value) -> Option<String>,
465) -> Result<Header, jsonwebtoken::errors::Error> {
466    let header_b64 = token
467        .split('.')
468        .next()
469        .ok_or(jsonwebtoken::errors::ErrorKind::InvalidToken)?;
470
471    let header_bytes = URL_SAFE_NO_PAD
472        .decode(header_b64.trim_end_matches('='))
473        .map_err(jsonwebtoken::errors::ErrorKind::Base64)?;
474
475    let mut json: serde_json::Map<String, Value> = serde_json::from_slice(&header_bytes)?;
476
477    json.retain(|key, value| {
478        if STANDARD_HEADER_FIELDS.contains(&key.as_str()) || value.is_string() {
479            return true;
480        }
481        match handler(key, value) {
482            Some(s) => {
483                *value = Value::String(s);
484                true
485            }
486            None => false,
487        }
488    });
489
490    Ok(serde_json::from_value(Value::Object(json))?)
491}
492
493/// Map `HttpError` variants to appropriate `ClaimsError` messages
494fn map_http_error(e: &modkit_http::HttpError) -> ClaimsError {
495    ClaimsError::JwksFetchFailed(crate::http_error::format_http_error(e, "JWKS"))
496}
497
498#[cfg(test)]
499#[cfg_attr(coverage_nightly, coverage(off))]
500mod tests {
501    use super::*;
502    use httpmock::prelude::*;
503
504    /// Create a test provider with insecure HTTP allowed (for httpmock) and no retries
505    fn test_provider_with_http(uri: &str) -> JwksKeyProvider {
506        let client = modkit_http::HttpClient::builder()
507            .timeout(Duration::from_secs(5))
508            .retry(None)
509            .allow_insecure_http()
510            .build()
511            .expect("failed to create test HTTP client");
512
513        JwksKeyProvider {
514            jwks_uri: uri.to_owned(),
515            keys: Arc::new(ArcSwap::from_pointee(HashMap::new())),
516            refresh_state: Arc::new(RwLock::new(RefreshState::default())),
517            client,
518            refresh_interval: Duration::from_secs(300),
519            max_backoff: Duration::from_secs(3600),
520            on_demand_refresh_cooldown: Duration::from_secs(60),
521            header_extras_handler: None,
522        }
523    }
524
525    /// Create a basic test provider (HTTPS only, for non-network tests)
526    fn test_provider(uri: &str) -> JwksKeyProvider {
527        JwksKeyProvider::new(uri).expect("failed to create test provider")
528    }
529
530    /// Valid JWKS JSON response with a single RSA key
531    fn valid_jwks_json() -> &'static str {
532        r#"{
533            "keys": [{
534                "kty": "RSA",
535                "kid": "test-key-1",
536                "use": "sig",
537                "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
538                "e": "AQAB",
539                "alg": "RS256"
540            }]
541        }"#
542    }
543
544    #[tokio::test]
545    async fn test_calculate_backoff() {
546        let provider = test_provider("https://example.com/jwks");
547
548        assert_eq!(provider.calculate_backoff(0), Duration::from_secs(60));
549        assert_eq!(provider.calculate_backoff(1), Duration::from_secs(120));
550        assert_eq!(provider.calculate_backoff(2), Duration::from_secs(240));
551        assert_eq!(provider.calculate_backoff(3), Duration::from_secs(480));
552
553        // Should cap at max_backoff
554        assert_eq!(provider.calculate_backoff(100), provider.max_backoff);
555    }
556
557    #[tokio::test]
558    async fn test_should_refresh_on_first_call() {
559        let provider = test_provider("https://example.com/jwks");
560        assert!(provider.should_refresh().await);
561    }
562
563    #[tokio::test]
564    async fn test_key_storage() {
565        let provider = test_provider("https://example.com/jwks");
566
567        // Initially empty
568        assert!(provider.get_key("test-kid").is_none());
569
570        // Store a dummy key
571        let mut keys = HashMap::new();
572        keys.insert("test-kid".to_owned(), DecodingKey::from_secret(b"secret"));
573        provider.keys.store(Arc::new(keys));
574
575        // Should be retrievable
576        assert!(provider.get_key("test-kid").is_some());
577    }
578
579    #[tokio::test]
580    async fn test_on_demand_refresh_returns_ok_when_key_exists() {
581        let provider = test_provider("https://example.com/jwks");
582
583        // Pre-populate with a key
584        let mut keys = HashMap::new();
585        keys.insert(
586            "existing-kid".to_owned(),
587            DecodingKey::from_secret(b"secret"),
588        );
589        provider.keys.store(Arc::new(keys));
590
591        // Should return Ok immediately without any refresh
592        let result = provider.on_demand_refresh("existing-kid").await;
593        assert!(result.is_ok());
594    }
595
596    #[tokio::test]
597    async fn test_try_new_returns_result() {
598        // Valid URL should work
599        let result = JwksKeyProvider::try_new("https://example.com/jwks");
600        assert!(result.is_ok());
601    }
602
603    // ==================== httpmock-based tests ====================
604
605    #[tokio::test]
606    async fn test_fetch_jwks_success_with_valid_json() {
607        let server = MockServer::start();
608
609        let mock = server.mock(|when, then| {
610            when.method(GET).path("/jwks");
611            then.status(200)
612                .header("content-type", "application/json")
613                .body(valid_jwks_json());
614        });
615
616        let jwks_url = server.url("/jwks");
617        let provider = test_provider_with_http(&jwks_url);
618
619        let result = provider.perform_refresh().await;
620        assert!(result.is_ok(), "Expected success, got: {result:?}");
621
622        // Verify key was stored
623        assert!(
624            provider.get_key("test-key-1").is_some(),
625            "Expected key 'test-key-1' to be stored"
626        );
627
628        mock.assert();
629    }
630
631    #[tokio::test]
632    async fn test_fetch_jwks_http_404_error_mapping() {
633        let server = MockServer::start();
634
635        let mock = server.mock(|when, then| {
636            when.method(GET).path("/jwks");
637            then.status(404).body("Not Found");
638        });
639
640        let jwks_url = server.url("/jwks");
641        let provider = test_provider_with_http(&jwks_url);
642
643        let result = provider.perform_refresh().await;
644        assert!(result.is_err());
645
646        let err = result.unwrap_err();
647        let err_msg = err.to_string();
648        assert!(
649            err_msg.contains("JWKS HTTP 404"),
650            "Expected error to contain 'JWKS HTTP 404', got: {err_msg}"
651        );
652        // Must NOT say "parse"
653        assert!(
654            !err_msg.to_lowercase().contains("parse"),
655            "HTTP status error should not mention 'parse', got: {err_msg}"
656        );
657
658        mock.assert();
659    }
660
661    #[tokio::test]
662    async fn test_fetch_jwks_http_500_error_mapping() {
663        let server = MockServer::start();
664
665        let mock = server.mock(|when, then| {
666            when.method(GET).path("/jwks");
667            then.status(500).body("Internal Server Error");
668        });
669
670        let jwks_url = server.url("/jwks");
671        let provider = test_provider_with_http(&jwks_url);
672
673        let result = provider.perform_refresh().await;
674        assert!(result.is_err());
675
676        let err = result.unwrap_err();
677        let err_msg = err.to_string();
678        assert!(
679            err_msg.contains("JWKS HTTP 500"),
680            "Expected error to contain 'JWKS HTTP 500', got: {err_msg}"
681        );
682
683        mock.assert();
684    }
685
686    #[tokio::test]
687    async fn test_fetch_jwks_invalid_json_error_mapping() {
688        let server = MockServer::start();
689
690        let mock = server.mock(|when, then| {
691            when.method(GET).path("/jwks");
692            then.status(200)
693                .header("content-type", "application/json")
694                .body("this is not valid json");
695        });
696
697        let jwks_url = server.url("/jwks");
698        let provider = test_provider_with_http(&jwks_url);
699
700        let result = provider.perform_refresh().await;
701        assert!(result.is_err());
702
703        let err = result.unwrap_err();
704        let err_msg = err.to_string();
705        assert!(
706            err_msg.contains("JWKS JSON parse failed"),
707            "Expected error to contain 'JWKS JSON parse failed', got: {err_msg}"
708        );
709
710        mock.assert();
711    }
712
713    #[tokio::test]
714    async fn test_fetch_jwks_empty_keys_error() {
715        let server = MockServer::start();
716
717        let mock = server.mock(|when, then| {
718            when.method(GET).path("/jwks");
719            then.status(200)
720                .header("content-type", "application/json")
721                .body(r#"{"keys": []}"#);
722        });
723
724        let jwks_url = server.url("/jwks");
725        let provider = test_provider_with_http(&jwks_url);
726
727        let result = provider.perform_refresh().await;
728        assert!(result.is_err());
729
730        let err = result.unwrap_err();
731        let err_msg = err.to_string();
732        assert!(
733            err_msg.contains("No valid RSA keys"),
734            "Expected error about no RSA keys, got: {err_msg}"
735        );
736
737        mock.assert();
738    }
739
740    #[tokio::test]
741    async fn test_on_demand_refresh_respects_cooldown() {
742        let server = MockServer::start();
743
744        // First request will return 404
745        let mock = server.mock(|when, then| {
746            when.method(GET).path("/jwks");
747            then.status(404).body("Not Found");
748        });
749
750        let jwks_url = server.url("/jwks");
751        let provider = test_provider_with_http(&jwks_url)
752            .with_on_demand_refresh_cooldown(Duration::from_secs(60));
753
754        // First attempt - should try to refresh and fail
755        let result1 = provider.on_demand_refresh("test-kid").await;
756        assert!(result1.is_err());
757
758        // Immediate second attempt - should be throttled (no network call)
759        let result2 = provider.on_demand_refresh("test-kid").await;
760        assert!(result2.is_err());
761
762        // Should return UnknownKeyId due to cooldown
763        match result2.unwrap_err() {
764            ClaimsError::UnknownKeyId(_) => {}
765            other => panic!("Expected UnknownKeyId during cooldown, got: {other:?}"),
766        }
767
768        // Only one request should have been made (first attempt)
769        mock.assert_calls(1);
770    }
771
772    #[tokio::test]
773    async fn test_on_demand_refresh_tracks_failed_kids() {
774        let server = MockServer::start();
775
776        server.mock(|when, then| {
777            when.method(GET).path("/jwks");
778            then.status(404).body("Not Found");
779        });
780
781        let jwks_url = server.url("/jwks");
782        let provider = test_provider_with_http(&jwks_url)
783            .with_on_demand_refresh_cooldown(Duration::from_millis(100));
784
785        // Attempt refresh - will fail and track the kid
786        let result = provider.on_demand_refresh("failed-kid").await;
787        assert!(result.is_err());
788
789        // Check that failed_kids contains the kid
790        let state = provider.refresh_state.read().await;
791        assert!(state.failed_kids.contains("failed-kid"));
792    }
793
794    #[tokio::test]
795    async fn test_perform_refresh_updates_state_on_failure() {
796        let server = MockServer::start();
797
798        server.mock(|when, then| {
799            when.method(GET).path("/jwks");
800            then.status(500).body("Server Error");
801        });
802
803        let jwks_url = server.url("/jwks");
804        let provider = test_provider_with_http(&jwks_url);
805
806        // Mark as previously failed
807        {
808            let mut state = provider.refresh_state.write().await;
809            state.consecutive_failures = 3;
810            state.last_error = Some("Previous error".to_owned());
811        }
812
813        // This will fail
814        _ = provider.perform_refresh().await;
815
816        // Check that consecutive_failures increased
817        let state = provider.refresh_state.read().await;
818        assert_eq!(state.consecutive_failures, 4);
819        assert!(state.last_error.is_some());
820    }
821
822    #[tokio::test]
823    async fn test_perform_refresh_resets_state_on_success() {
824        let server = MockServer::start();
825
826        server.mock(|when, then| {
827            when.method(GET).path("/jwks");
828            then.status(200)
829                .header("content-type", "application/json")
830                .body(valid_jwks_json());
831        });
832
833        let jwks_url = server.url("/jwks");
834        let provider = test_provider_with_http(&jwks_url);
835
836        // Mark as previously failed
837        {
838            let mut state = provider.refresh_state.write().await;
839            state.consecutive_failures = 5;
840            state.last_error = Some("Previous error".to_owned());
841        }
842
843        // This should succeed
844        let result = provider.perform_refresh().await;
845        assert!(result.is_ok());
846
847        // Check that state was reset
848        let state = provider.refresh_state.read().await;
849        assert_eq!(state.consecutive_failures, 0);
850        assert!(state.last_error.is_none());
851    }
852
853    #[tokio::test]
854    async fn test_validate_and_decode_with_missing_kid() {
855        let server = MockServer::start();
856
857        // Return valid JWKS but without the requested kid
858        server.mock(|when, then| {
859            when.method(GET).path("/jwks");
860            then.status(200)
861                .header("content-type", "application/json")
862                .body(valid_jwks_json());
863        });
864
865        let jwks_url = server.url("/jwks");
866        let provider = test_provider_with_http(&jwks_url)
867            .with_on_demand_refresh_cooldown(Duration::from_millis(100));
868
869        // Create a minimal JWT with a kid that doesn't exist in JWKS
870        // Header: {"alg":"RS256","kid":"nonexistent-kid"}
871        let token = "eyJhbGciOiJSUzI1NiIsImtpZCI6Im5vbmV4aXN0ZW50LWtpZCJ9.\
872                     eyJzdWIiOiIxMjM0NTY3ODkwIn0.invalid";
873
874        // Should attempt on-demand refresh but kid still won't exist
875        let result = provider.validate_and_decode(token).await;
876        assert!(result.is_err());
877
878        match result.unwrap_err() {
879            ClaimsError::UnknownKeyId(kid) => {
880                assert_eq!(kid, "nonexistent-kid");
881            }
882            other => panic!("Expected UnknownKeyId, got: {other:?}"),
883        }
884    }
885
886    #[test]
887    fn test_decode_header_with_handler_coerces_non_string_extras() {
888        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
889
890        // Header with non-standard fields: integer, string, and array
891        let header_json = r#"{"alg":"RS256","eap":1,"iri":"some-string-id","irn":["role_a"],"kid":"kid-1","typ":"at+jwt"}"#;
892        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
893        let payload_b64 = URL_SAFE_NO_PAD.encode(b"{}");
894        let token = format!("{header_b64}.{payload_b64}.fake");
895
896        let header = decode_header_with_handler(&token, &|_key, value| Some(value.to_string()))
897            .expect("should handle non-standard header fields");
898
899        assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
900        assert_eq!(header.kid.as_deref(), Some("kid-1"));
901        assert_eq!(header.typ.as_deref(), Some("at+jwt"));
902
903        // Non-string extras coerced to JSON text
904        assert_eq!(header.extras.get("eap").map(String::as_str), Some("1"));
905        assert_eq!(
906            header.extras.get("irn").map(String::as_str),
907            Some(r#"["role_a"]"#)
908        );
909        // String extras preserved as-is
910        assert_eq!(
911            header.extras.get("iri").map(String::as_str),
912            Some("some-string-id")
913        );
914    }
915
916    #[test]
917    fn test_decode_header_with_handler_can_drop_fields() {
918        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
919
920        let header_json = r#"{"alg":"RS256","eap":1,"iri":"keep-me","kid":"kid-1","typ":"JWT"}"#;
921        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
922        let token = format!("{header_b64}.e30.fake");
923
924        let header = decode_header_with_handler(&token, &|_key, _value| None)
925            .expect("should succeed when handler drops non-string fields");
926
927        assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
928        assert!(header.extras.get("eap").is_none());
929        assert_eq!(
930            header.extras.get("iri").map(String::as_str),
931            Some("keep-me")
932        );
933    }
934
935    #[tokio::test]
936    async fn test_with_header_extras_stringified_coerces_non_string_extras() {
937        let server = MockServer::start();
938
939        server.mock(|when, then| {
940            when.method(GET).path("/jwks");
941            then.status(200)
942                .header("content-type", "application/json")
943                .body(valid_jwks_json());
944        });
945
946        let jwks_url = server.url("/jwks");
947        let provider = test_provider_with_http(&jwks_url).with_header_extras_stringified();
948
949        // Header with non-string extras: integer and array
950        let header_json =
951            r#"{"alg":"RS256","kid":"test-key-1","typ":"JWT","eap":1,"irn":["role_a"]}"#;
952        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
953        let payload_b64 = URL_SAFE_NO_PAD.encode(b"{}");
954        let token = format!("{header_b64}.{payload_b64}.AAAA");
955
956        let result = provider.validate_and_decode(&token).await;
957
958        // The handler lets header decode succeed; error must come from signature
959        // validation, not from header parsing.
960        let err = result.expect_err("fake signature should fail validation");
961        match &err {
962            ClaimsError::DecodeFailed(msg) => {
963                assert!(
964                    msg.contains("JWT validation failed"),
965                    "Expected signature-validation error, got: {msg}"
966                );
967            }
968            other => panic!("Expected DecodeFailed, got: {other:?}"),
969        }
970    }
971
972    #[tokio::test]
973    async fn test_validate_and_decode_uses_header_extras_handler() {
974        let server = MockServer::start();
975
976        server.mock(|when, then| {
977            when.method(GET).path("/jwks");
978            then.status(200)
979                .header("content-type", "application/json")
980                .body(valid_jwks_json());
981        });
982
983        let jwks_url = server.url("/jwks");
984        let provider = test_provider_with_http(&jwks_url)
985            .with_header_extras_handler(|_key, value| Some(value.to_string()));
986
987        // Header with a non-string extra ("eap":1) that would reject without handler
988        let header_json = r#"{"alg":"RS256","kid":"test-key-1","typ":"JWT","eap":1}"#;
989        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
990        let payload_b64 = URL_SAFE_NO_PAD.encode(b"{}");
991        let token = format!("{header_b64}.{payload_b64}.AAAA");
992
993        let result = provider.validate_and_decode(&token).await;
994
995        // Handler lets header decode succeed → error must come from signature
996        // validation, not from header parsing.
997        let err = result.expect_err("fake signature should fail validation");
998        match &err {
999            ClaimsError::DecodeFailed(msg) => {
1000                assert!(
1001                    msg.contains("JWT validation failed"),
1002                    "Expected signature-validation error, got: {msg}"
1003                );
1004            }
1005            other => panic!("Expected DecodeFailed, got: {other:?}"),
1006        }
1007    }
1008
1009    #[test]
1010    fn test_decode_header_without_handler_rejects_non_string_extras() {
1011        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
1012
1013        let header_json = r#"{"alg":"RS256","eap":1,"kid":"kid-1","typ":"JWT"}"#;
1014        let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
1015        let token = format!("{header_b64}.e30.fake");
1016
1017        let result = decode_header(&token);
1018        assert!(result.is_err());
1019        let err = result.unwrap_err().to_string();
1020        assert!(
1021            err.contains("invalid type: integer"),
1022            "expected type error, got: {err}"
1023        );
1024    }
1025}