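//! JWKS-based key provider (`modkit_auth/providers/jwks.rs`).
//!
//! Fetches RSA signing keys from a JWKS endpoint, caches them for lock-free
//! lookups, and refreshes them periodically and on demand.
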
use crate::{claims_error::ClaimsError, plugin_traits::KeyProvider};
use arc_swap::ArcSwap;
use async_trait::async_trait;
use jsonwebtoken::{DecodingKey, Header, Validation, decode, decode_header};
use serde::Deserialize;
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;

14#[derive(Debug, Clone, Deserialize)]
15struct Jwk {
16    kid: String,
17    kty: String,
18    #[serde(rename = "use")]
19    #[allow(dead_code)]
20    use_: Option<String>,
21    n: String,
22    e: String,
23    #[allow(dead_code)]
24    alg: Option<String>,
25}
26
27#[derive(Debug, Clone, Deserialize)]
28struct JwksResponse {
29    keys: Vec<Jwk>,
30}
31
32/// JWKS-based key provider with lock-free reads
33///
34/// Uses `ArcSwap` for lock-free key lookups and background refresh with exponential backoff.
35#[must_use]
36pub struct JwksKeyProvider {
37    /// JWKS endpoint URL
38    jwks_uri: String,
39
40    /// Keys stored in `ArcSwap` for lock-free reads
41    keys: Arc<ArcSwap<HashMap<String, DecodingKey>>>,
42
43    /// Last refresh time and error tracking for backoff
44    refresh_state: Arc<RwLock<RefreshState>>,
45
46    /// Shared HTTP client for JWKS fetches (pooled connections)
47    /// `HttpClient` is `Clone + Send + Sync`, no external locking needed.
48    client: modkit_http::HttpClient,
49
50    /// Refresh interval (default: 5 minutes)
51    refresh_interval: Duration,
52
53    /// Maximum backoff duration (default: 1 hour)
54    max_backoff: Duration,
55
56    /// Cooldown for on-demand refresh (default: 60 seconds)
57    on_demand_refresh_cooldown: Duration,
58}
59
60#[derive(Debug, Default)]
61struct RefreshState {
62    last_refresh: Option<Instant>,
63    last_on_demand_refresh: Option<Instant>,
64    consecutive_failures: u32,
65    last_error: Option<String>,
66    failed_kids: HashSet<String>,
67}
68
69impl JwksKeyProvider {
70    /// Create a new JWKS key provider
71    ///
72    /// # Errors
73    /// Returns error if HTTP client initialization fails (e.g., TLS setup)
74    pub fn new(jwks_uri: impl Into<String>) -> Result<Self, modkit_http::HttpError> {
75        Self::with_http_timeout(jwks_uri, Duration::from_secs(10))
76    }
77
78    /// Create a new JWKS key provider with custom HTTP timeout
79    ///
80    /// # Errors
81    /// Returns error if HTTP client initialization fails (e.g., TLS setup)
82    pub fn with_http_timeout(
83        jwks_uri: impl Into<String>,
84        timeout: Duration,
85    ) -> Result<Self, modkit_http::HttpError> {
86        let client = modkit_http::HttpClient::builder()
87            .timeout(timeout)
88            .retry(None) // JWKS provider handles its own retry logic
89            .build()?;
90
91        Ok(Self {
92            jwks_uri: jwks_uri.into(),
93            keys: Arc::new(ArcSwap::from_pointee(HashMap::new())),
94            refresh_state: Arc::new(RwLock::new(RefreshState::default())),
95            client,
96            refresh_interval: Duration::from_secs(300), // 5 minutes
97            max_backoff: Duration::from_secs(3600),     // 1 hour
98            on_demand_refresh_cooldown: Duration::from_secs(60), // 1 minute
99        })
100    }
101
102    /// Create a new JWKS key provider (alias for new, kept for compatibility)
103    ///
104    /// # Errors
105    /// Returns error if HTTP client initialization fails (e.g., TLS setup)
106    pub fn try_new(jwks_uri: impl Into<String>) -> Result<Self, modkit_http::HttpError> {
107        Self::new(jwks_uri)
108    }
109
110    /// Create with custom refresh interval
111    pub fn with_refresh_interval(mut self, interval: Duration) -> Self {
112        self.refresh_interval = interval;
113        self
114    }
115
116    /// Create with custom max backoff
117    pub fn with_max_backoff(mut self, max_backoff: Duration) -> Self {
118        self.max_backoff = max_backoff;
119        self
120    }
121
122    /// Create with custom on-demand refresh cooldown
123    pub fn with_on_demand_refresh_cooldown(mut self, cooldown: Duration) -> Self {
124        self.on_demand_refresh_cooldown = cooldown;
125        self
126    }
127
128    /// Fetch JWKS from the endpoint
129    async fn fetch_jwks(&self) -> Result<HashMap<String, DecodingKey>, ClaimsError> {
130        // HttpClient is Clone + Send + Sync, no locking needed
131        let jwks: JwksResponse = self
132            .client
133            .get(&self.jwks_uri)
134            .send()
135            .await
136            .map_err(|e| map_http_error(&e))?
137            .json()
138            .await
139            .map_err(|e| map_http_error(&e))?;
140
141        let mut keys = HashMap::new();
142        for jwk in jwks.keys {
143            if jwk.kty == "RSA" {
144                let key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)
145                    .map_err(|e| ClaimsError::JwksFetchFailed(format!("Invalid RSA key: {e}")))?;
146                keys.insert(jwk.kid, key);
147            }
148        }
149
150        if keys.is_empty() {
151            return Err(ClaimsError::JwksFetchFailed(
152                "No valid RSA keys found in JWKS".into(),
153            ));
154        }
155
156        Ok(keys)
157    }
158
159    /// Calculate backoff duration based on consecutive failures
160    fn calculate_backoff(&self, failures: u32) -> Duration {
161        let base = Duration::from_secs(60); // 1 minute base
162        let exponential = base * 2u32.pow(failures.min(10)); // Cap at 2^10
163        exponential.min(self.max_backoff)
164    }
165
166    /// Check if refresh is needed based on interval and backoff
167    async fn should_refresh(&self) -> bool {
168        let state = self.refresh_state.read().await;
169
170        match state.last_refresh {
171            None => true, // Never refreshed
172            Some(last) => {
173                let elapsed = last.elapsed();
174                if state.consecutive_failures == 0 {
175                    // Normal refresh interval
176                    elapsed >= self.refresh_interval
177                } else {
178                    // Exponential backoff
179                    elapsed >= self.calculate_backoff(state.consecutive_failures)
180                }
181            }
182        }
183    }
184
185    /// Perform key refresh with error tracking
186    async fn perform_refresh(&self) -> Result<(), ClaimsError> {
187        match self.fetch_jwks().await {
188            Ok(new_keys) => {
189                // Update keys atomically
190                self.keys.store(Arc::new(new_keys));
191
192                // Update refresh state
193                let mut state = self.refresh_state.write().await;
194                state.last_refresh = Some(Instant::now());
195                state.consecutive_failures = 0;
196                state.last_error = None;
197
198                Ok(())
199            }
200            Err(e) => {
201                // Update failure state
202                let mut state = self.refresh_state.write().await;
203                state.last_refresh = Some(Instant::now());
204                state.consecutive_failures += 1;
205                state.last_error = Some(e.to_string());
206
207                Err(e)
208            }
209        }
210    }
211
212    /// Check if a key exists in the cache
213    fn key_exists(&self, kid: &str) -> bool {
214        let keys = self.keys.load();
215        keys.contains_key(kid)
216    }
217
218    /// Check if we're in cooldown period and handle throttling logic
219    async fn check_refresh_throttle(&self, kid: &str) -> Result<(), ClaimsError> {
220        let state = self.refresh_state.read().await;
221        if let Some(last_on_demand) = state.last_on_demand_refresh {
222            let elapsed = last_on_demand.elapsed();
223            if elapsed < self.on_demand_refresh_cooldown {
224                let remaining = self.on_demand_refresh_cooldown.saturating_sub(elapsed);
225                tracing::debug!(
226                    kid = kid,
227                    remaining_secs = remaining.as_secs(),
228                    "On-demand JWKS refresh throttled (cooldown active)"
229                );
230
231                // Check if this kid has failed before
232                if state.failed_kids.contains(kid) {
233                    tracing::warn!(
234                        kid = kid,
235                        "Unknown kid repeatedly requested despite recent refresh attempts"
236                    );
237                }
238
239                return Err(ClaimsError::UnknownKeyId(kid.to_owned()));
240            }
241        }
242        Ok(())
243    }
244
245    /// Update state after successful refresh and check if kid is now available
246    async fn handle_refresh_success(&self, kid: &str) -> Result<(), ClaimsError> {
247        let mut state = self.refresh_state.write().await;
248        state.last_on_demand_refresh = Some(Instant::now());
249
250        // Check if the kid now exists
251        if self.key_exists(kid) {
252            // Kid found - remove from failed list if present
253            state.failed_kids.remove(kid);
254        } else {
255            // Kid still not found after refresh - track it
256            state.failed_kids.insert(kid.to_owned());
257            tracing::warn!(
258                kid = kid,
259                "Kid still not found after on-demand JWKS refresh"
260            );
261        }
262
263        Ok(())
264    }
265
266    /// Update state after failed refresh
267    async fn handle_refresh_failure(&self, kid: &str, error: ClaimsError) -> ClaimsError {
268        let mut state = self.refresh_state.write().await;
269        state.last_on_demand_refresh = Some(Instant::now());
270        state.failed_kids.insert(kid.to_owned());
271        error
272    }
273
274    /// Try to refresh keys if unknown kid is encountered
275    /// Implements throttling to prevent excessive refreshes
276    async fn on_demand_refresh(&self, kid: &str) -> Result<(), ClaimsError> {
277        // Check if key exists
278        if self.key_exists(kid) {
279            return Ok(());
280        }
281
282        // Check if we're in cooldown period
283        self.check_refresh_throttle(kid).await?;
284
285        // Attempt refresh and track the kid if it fails
286        tracing::info!(
287            kid = kid,
288            "Performing on-demand JWKS refresh for unknown kid"
289        );
290
291        match self.perform_refresh().await {
292            Ok(()) => self.handle_refresh_success(kid).await,
293            Err(e) => Err(self.handle_refresh_failure(kid, e).await),
294        }
295    }
296
297    /// Get a key by kid (lock-free read)
298    fn get_key(&self, kid: &str) -> Option<DecodingKey> {
299        let keys = self.keys.load();
300        keys.get(kid).cloned()
301    }
302
303    /// Validate JWT and decode into header + raw claims
304    fn validate_token(
305        token: &str,
306        key: &DecodingKey,
307        header: &Header,
308    ) -> Result<Value, ClaimsError> {
309        let mut validation = Validation::new(header.alg);
310
311        // Disable all built-in validations - we'll do them separately
312        validation.validate_exp = false;
313        validation.validate_nbf = false;
314        validation.validate_aud = false;
315
316        // Don't require any standard claims
317        let empty_claims: &[&str] = &[];
318        validation.set_required_spec_claims(empty_claims);
319
320        let token_data = decode::<Value>(token, key, &validation)
321            .map_err(|e| ClaimsError::DecodeFailed(format!("JWT validation failed: {e}")))?;
322
323        Ok(token_data.claims)
324    }
325}
326
327#[async_trait]
328impl KeyProvider for JwksKeyProvider {
329    fn name(&self) -> &'static str {
330        "jwks"
331    }
332
333    async fn validate_and_decode(&self, token: &str) -> Result<(Header, Value), ClaimsError> {
334        // Strip "Bearer " prefix if present
335        let token = token.trim_start_matches("Bearer ").trim();
336
337        // Decode header to get kid and algorithm
338        let header = decode_header(token)
339            .map_err(|e| ClaimsError::DecodeFailed(format!("Invalid JWT header: {e}")))?;
340
341        let kid = header
342            .kid
343            .as_ref()
344            .ok_or_else(|| ClaimsError::DecodeFailed("Missing kid in JWT header".into()))?;
345
346        // Try to get key from cache
347        let key = if let Some(k) = self.get_key(kid) {
348            k
349        } else {
350            // Key not in cache, try on-demand refresh
351            self.on_demand_refresh(kid).await?;
352
353            // Try again after refresh
354            self.get_key(kid)
355                .ok_or_else(|| ClaimsError::UnknownKeyId(kid.clone()))?
356        };
357
358        // Validate signature and decode claims
359        let claims = Self::validate_token(token, &key, &header)?;
360
361        Ok((header, claims))
362    }
363
364    async fn refresh_keys(&self) -> Result<(), ClaimsError> {
365        if self.should_refresh().await {
366            self.perform_refresh().await
367        } else {
368            Ok(())
369        }
370    }
371}
372
373/// Background task to periodically refresh JWKS
374///
375/// This task will run until the `cancellation_token` is cancelled, enabling
376/// graceful shutdown per `ModKit` patterns. Without cancellation support, this
377/// task would run indefinitely and potentially cause process hang on shutdown.
378///
379/// # Example
380///
381/// ```ignore
382/// use tokio_util::sync::CancellationToken;
383/// use std::sync::Arc;
384///
385/// let provider = Arc::new(JwksKeyProvider::new("https://issuer/.well-known/jwks.json")?);
386/// let cancel_token = CancellationToken::new();
387///
388/// // Spawn the refresh task
389/// let task_handle = tokio::spawn(run_jwks_refresh_task(provider.clone(), cancel_token.clone()));
390///
391/// // On shutdown:
392/// cancel_token.cancel();
393/// task_handle.await?;
394/// ```
395pub async fn run_jwks_refresh_task(
396    provider: Arc<JwksKeyProvider>,
397    cancellation_token: CancellationToken,
398) {
399    let mut interval = tokio::time::interval(Duration::from_secs(60)); // Check every minute
400
401    loop {
402        tokio::select! {
403            () = cancellation_token.cancelled() => {
404                tracing::info!("JWKS refresh task shutting down");
405                break;
406            }
407            _ = interval.tick() => {
408                if let Err(e) = provider.refresh_keys().await {
409                    tracing::warn!("JWKS refresh failed: {}", e);
410                }
411            }
412        }
413    }
414}
415
416/// Map `HttpError` variants to appropriate `ClaimsError` messages
417fn map_http_error(e: &modkit_http::HttpError) -> ClaimsError {
418    ClaimsError::JwksFetchFailed(crate::http_error::format_http_error(e, "JWKS"))
419}
420
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use httpmock::prelude::*;

    /// Provider whose client accepts plain `http://` URLs (httpmock serves
    /// plain HTTP) and never retries, so each refresh issues exactly one request.
    fn test_provider_with_http(uri: &str) -> JwksKeyProvider {
        let http_client = modkit_http::HttpClient::builder()
            .timeout(Duration::from_secs(5))
            .retry(None)
            .allow_insecure_http()
            .build()
            .expect("failed to create test HTTP client");

        JwksKeyProvider {
            jwks_uri: uri.to_owned(),
            keys: Arc::new(ArcSwap::from_pointee(HashMap::new())),
            refresh_state: Arc::new(RwLock::new(RefreshState::default())),
            client: http_client,
            refresh_interval: Duration::from_secs(300),
            max_backoff: Duration::from_secs(3600),
            on_demand_refresh_cooldown: Duration::from_secs(60),
        }
    }

    /// Provider built through the public constructor (HTTPS only); for tests
    /// that never touch the network.
    fn test_provider(uri: &str) -> JwksKeyProvider {
        JwksKeyProvider::new(uri).expect("failed to create test provider")
    }

    /// JWKS body holding a single RSA signing key with kid `test-key-1`.
    fn valid_jwks_json() -> &'static str {
        r#"{
            "keys": [{
                "kty": "RSA",
                "kid": "test-key-1",
                "use": "sig",
                "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
                "e": "AQAB",
                "alg": "RS256"
            }]
        }"#
    }

    #[tokio::test]
    async fn test_calculate_backoff() {
        let provider = test_provider("https://example.com/jwks");

        for (failures, expected_secs) in [(0, 60), (1, 120), (2, 240), (3, 480)] {
            assert_eq!(
                provider.calculate_backoff(failures),
                Duration::from_secs(expected_secs)
            );
        }

        // Far past the exponent cap the value is clamped to max_backoff.
        assert_eq!(provider.calculate_backoff(100), provider.max_backoff);
    }

    #[tokio::test]
    async fn test_should_refresh_on_first_call() {
        let fresh = test_provider("https://example.com/jwks");
        let needs_refresh = fresh.should_refresh().await;
        assert!(needs_refresh);
    }

    #[tokio::test]
    async fn test_key_storage() {
        let provider = test_provider("https://example.com/jwks");
        assert!(
            provider.get_key("test-kid").is_none(),
            "cache should start empty"
        );

        let dummy = HashMap::from([(
            "test-kid".to_owned(),
            DecodingKey::from_secret(b"secret"),
        )]);
        provider.keys.store(Arc::new(dummy));

        assert!(
            provider.get_key("test-kid").is_some(),
            "stored key should be retrievable"
        );
    }

    #[tokio::test]
    async fn test_on_demand_refresh_returns_ok_when_key_exists() {
        let provider = test_provider("https://example.com/jwks");
        let preloaded = HashMap::from([(
            "existing-kid".to_owned(),
            DecodingKey::from_secret(b"secret"),
        )]);
        provider.keys.store(Arc::new(preloaded));

        // A cached kid short-circuits before any refresh is attempted.
        provider
            .on_demand_refresh("existing-kid")
            .await
            .expect("cached kid must not require a refresh");
    }

    #[tokio::test]
    async fn test_try_new_returns_result() {
        // A well-formed HTTPS URL must be accepted.
        assert!(JwksKeyProvider::try_new("https://example.com/jwks").is_ok());
    }

    // ==================== httpmock-based tests ====================

    #[tokio::test]
    async fn test_fetch_jwks_success_with_valid_json() {
        let mock_server = MockServer::start();
        let jwks_mock = mock_server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });
        let provider = test_provider_with_http(&mock_server.url("/jwks"));

        let outcome = provider.perform_refresh().await;
        assert!(outcome.is_ok(), "Expected success, got: {outcome:?}");
        assert!(
            provider.get_key("test-key-1").is_some(),
            "Expected key 'test-key-1' to be stored"
        );

        jwks_mock.assert();
    }

    #[tokio::test]
    async fn test_fetch_jwks_http_404_error_mapping() {
        let mock_server = MockServer::start();
        let jwks_mock = mock_server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(404).body("Not Found");
        });
        let provider = test_provider_with_http(&mock_server.url("/jwks"));

        let err_msg = provider
            .perform_refresh()
            .await
            .expect_err("a 404 response must fail the refresh")
            .to_string();
        assert!(
            err_msg.contains("JWKS HTTP 404"),
            "Expected error to contain 'JWKS HTTP 404', got: {err_msg}"
        );
        // A status error must not be reported as a parse failure.
        assert!(
            !err_msg.to_lowercase().contains("parse"),
            "HTTP status error should not mention 'parse', got: {err_msg}"
        );

        jwks_mock.assert();
    }

    #[tokio::test]
    async fn test_fetch_jwks_http_500_error_mapping() {
        let mock_server = MockServer::start();
        let jwks_mock = mock_server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(500).body("Internal Server Error");
        });
        let provider = test_provider_with_http(&mock_server.url("/jwks"));

        let err_msg = provider
            .perform_refresh()
            .await
            .expect_err("a 500 response must fail the refresh")
            .to_string();
        assert!(
            err_msg.contains("JWKS HTTP 500"),
            "Expected error to contain 'JWKS HTTP 500', got: {err_msg}"
        );

        jwks_mock.assert();
    }

    #[tokio::test]
    async fn test_fetch_jwks_invalid_json_error_mapping() {
        let mock_server = MockServer::start();
        let jwks_mock = mock_server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body("this is not valid json");
        });
        let provider = test_provider_with_http(&mock_server.url("/jwks"));

        let err_msg = provider
            .perform_refresh()
            .await
            .expect_err("a non-JSON body must fail the refresh")
            .to_string();
        assert!(
            err_msg.contains("JWKS JSON parse failed"),
            "Expected error to contain 'JWKS JSON parse failed', got: {err_msg}"
        );

        jwks_mock.assert();
    }

    #[tokio::test]
    async fn test_fetch_jwks_empty_keys_error() {
        let mock_server = MockServer::start();
        let jwks_mock = mock_server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"keys": []}"#);
        });
        let provider = test_provider_with_http(&mock_server.url("/jwks"));

        let err_msg = provider
            .perform_refresh()
            .await
            .expect_err("an empty key set must fail the refresh")
            .to_string();
        assert!(
            err_msg.contains("No valid RSA keys"),
            "Expected error about no RSA keys, got: {err_msg}"
        );

        jwks_mock.assert();
    }

    #[tokio::test]
    async fn test_on_demand_refresh_respects_cooldown() {
        let mock_server = MockServer::start();
        // Every fetch answers 404.
        let jwks_mock = mock_server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(404).body("Not Found");
        });
        let provider = test_provider_with_http(&mock_server.url("/jwks"))
            .with_on_demand_refresh_cooldown(Duration::from_secs(60));

        // The first attempt goes to the network and fails.
        let first = provider.on_demand_refresh("test-kid").await;
        assert!(first.is_err());

        // An immediate retry is throttled locally and reports the unknown kid.
        let second = provider.on_demand_refresh("test-kid").await;
        assert!(
            matches!(second, Err(ClaimsError::UnknownKeyId(_))),
            "Expected UnknownKeyId during cooldown, got: {second:?}"
        );

        // Only the first attempt reached the server.
        jwks_mock.assert_calls(1);
    }

    #[tokio::test]
    async fn test_on_demand_refresh_tracks_failed_kids() {
        let mock_server = MockServer::start();
        mock_server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(404).body("Not Found");
        });
        let provider = test_provider_with_http(&mock_server.url("/jwks"))
            .with_on_demand_refresh_cooldown(Duration::from_millis(100));

        let outcome = provider.on_demand_refresh("failed-kid").await;
        assert!(outcome.is_err());

        let state = provider.refresh_state.read().await;
        assert!(
            state.failed_kids.contains("failed-kid"),
            "the failing kid should be tracked"
        );
    }

    #[tokio::test]
    async fn test_perform_refresh_updates_state_on_failure() {
        let mock_server = MockServer::start();
        mock_server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(500).body("Server Error");
        });
        let provider = test_provider_with_http(&mock_server.url("/jwks"));

        // Seed a history of earlier failures.
        let mut seeded = provider.refresh_state.write().await;
        seeded.consecutive_failures = 3;
        seeded.last_error = Some("Previous error".to_owned());
        drop(seeded);

        // The mocked 500 makes this fail; only the resulting state matters.
        let _ = provider.perform_refresh().await;

        let state = provider.refresh_state.read().await;
        assert_eq!(state.consecutive_failures, 4);
        assert!(state.last_error.is_some());
    }

    #[tokio::test]
    async fn test_perform_refresh_resets_state_on_success() {
        let mock_server = MockServer::start();
        mock_server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });
        let provider = test_provider_with_http(&mock_server.url("/jwks"));

        // Seed a history of earlier failures.
        let mut seeded = provider.refresh_state.write().await;
        seeded.consecutive_failures = 5;
        seeded.last_error = Some("Previous error".to_owned());
        drop(seeded);

        let outcome = provider.perform_refresh().await;
        assert!(outcome.is_ok());

        // A success wipes the failure history.
        let state = provider.refresh_state.read().await;
        assert_eq!(state.consecutive_failures, 0);
        assert!(state.last_error.is_none());
    }

    #[tokio::test]
    async fn test_validate_and_decode_with_missing_kid() {
        let mock_server = MockServer::start();
        // The served JWKS is valid but does not contain the requested kid.
        mock_server.mock(|when, then| {
            when.method(GET).path("/jwks");
            then.status(200)
                .header("content-type", "application/json")
                .body(valid_jwks_json());
        });
        let provider = test_provider_with_http(&mock_server.url("/jwks"))
            .with_on_demand_refresh_cooldown(Duration::from_millis(100));

        // Header: {"alg":"RS256","kid":"nonexistent-kid"}
        let token = concat!(
            "eyJhbGciOiJSUzI1NiIsImtpZCI6Im5vbmV4aXN0ZW50LWtpZCJ9.",
            "eyJzdWIiOiIxMjM0NTY3ODkwIn0.invalid"
        );

        // The on-demand refresh succeeds, but the kid is still absent.
        match provider.validate_and_decode(token).await {
            Err(ClaimsError::UnknownKeyId(kid)) => assert_eq!(kid, "nonexistent-kid"),
            other => panic!("Expected UnknownKeyId, got: {other:?}"),
        }
    }
}