Skip to main content

modkit_auth/providers/
jwks.rs

1use crate::{claims_error::ClaimsError, plugin_traits::KeyProvider};
2use arc_swap::ArcSwap;
3use async_trait::async_trait;
4use jsonwebtoken::{DecodingKey, Header, Validation, decode, decode_header};
5use serde::Deserialize;
6use serde_json::Value;
7use std::collections::{HashMap, HashSet};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::RwLock;
11use tokio::time::Instant;
12
13#[derive(Debug, Clone, Deserialize)]
14struct Jwk {
15    kid: String,
16    kty: String,
17    #[serde(rename = "use")]
18    #[allow(dead_code)]
19    use_: Option<String>,
20    n: String,
21    e: String,
22    #[allow(dead_code)]
23    alg: Option<String>,
24}
25
/// Top-level shape of the JSON document returned by the JWKS endpoint.
#[derive(Debug, Clone, Deserialize)]
struct JwksResponse {
    /// All keys listed in the document's `keys` array.
    keys: Vec<Jwk>,
}
30
/// JWKS-based key provider with lock-free reads
///
/// Uses `ArcSwap` for lock-free key lookups and background refresh with exponential backoff.
/// Keys are fetched over HTTP from `jwks_uri` and cached by `kid`; the whole
/// cache is replaced on every successful refresh.
#[must_use]
pub struct JwksKeyProvider {
    /// JWKS endpoint URL
    jwks_uri: String,

    /// Keys stored in `ArcSwap` for lock-free reads (map from `kid` to decoding key)
    keys: Arc<ArcSwap<HashMap<String, DecodingKey>>>,

    /// Last refresh time and error tracking for backoff
    refresh_state: Arc<RwLock<RefreshState>>,

    /// HTTP client for fetching JWKS
    client: reqwest::Client,

    /// Refresh interval (default: 5 minutes)
    refresh_interval: Duration,

    /// Maximum backoff duration (default: 1 hour)
    max_backoff: Duration,

    /// Cooldown for on-demand refresh (default: 60 seconds)
    on_demand_refresh_cooldown: Duration,
}
57
/// Mutable bookkeeping shared by the periodic and on-demand refresh paths.
#[derive(Debug, Default)]
struct RefreshState {
    /// Time of the most recent refresh attempt (set on success and on
    /// failure); `None` until the first attempt.
    last_refresh: Option<Instant>,
    /// Time the most recent on-demand refresh finished; drives the cooldown.
    last_on_demand_refresh: Option<Instant>,
    /// Number of failed refreshes since the last success; reset to 0 on success.
    consecutive_failures: u32,
    /// String form of the most recent refresh error; cleared on success.
    last_error: Option<String>,
    /// Key ids that were still unknown after an on-demand refresh attempt.
    failed_kids: HashSet<String>,
}
66
67impl JwksKeyProvider {
68    /// Create a new JWKS key provider
69    ///
70    /// # Errors
71    /// Returns an error if the HTTP client fails to build.
72    pub fn new(jwks_uri: impl Into<String>) -> Result<Self, reqwest::Error> {
73        Ok(Self {
74            jwks_uri: jwks_uri.into(),
75            keys: Arc::new(ArcSwap::from_pointee(HashMap::new())),
76            refresh_state: Arc::new(RwLock::new(RefreshState::default())),
77            client: reqwest::Client::builder()
78                .timeout(Duration::from_secs(10))
79                .build()?,
80            refresh_interval: Duration::from_secs(300), // 5 minutes
81            max_backoff: Duration::from_secs(3600),     // 1 hour
82            on_demand_refresh_cooldown: Duration::from_secs(60), // 1 minute
83        })
84    }
85
86    /// Create with custom refresh interval
87    pub fn with_refresh_interval(mut self, interval: Duration) -> Self {
88        self.refresh_interval = interval;
89        self
90    }
91
92    /// Create with custom max backoff
93    pub fn with_max_backoff(mut self, max_backoff: Duration) -> Self {
94        self.max_backoff = max_backoff;
95        self
96    }
97
98    /// Create with custom on-demand refresh cooldown
99    pub fn with_on_demand_refresh_cooldown(mut self, cooldown: Duration) -> Self {
100        self.on_demand_refresh_cooldown = cooldown;
101        self
102    }
103
104    /// Fetch JWKS from the endpoint
105    async fn fetch_jwks(&self) -> Result<HashMap<String, DecodingKey>, ClaimsError> {
106        let response = self
107            .client
108            .get(&self.jwks_uri)
109            .send()
110            .await
111            .map_err(|e| ClaimsError::JwksFetchFailed(format!("HTTP request failed: {e}")))?;
112
113        if !response.status().is_success() {
114            return Err(ClaimsError::JwksFetchFailed(format!(
115                "HTTP error: {}",
116                response.status()
117            )));
118        }
119
120        let jwks: JwksResponse = response
121            .json()
122            .await
123            .map_err(|e| ClaimsError::JwksFetchFailed(format!("Failed to parse JWKS: {e}")))?;
124
125        let mut keys = HashMap::new();
126        for jwk in jwks.keys {
127            if jwk.kty == "RSA" {
128                let key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)
129                    .map_err(|e| ClaimsError::JwksFetchFailed(format!("Invalid RSA key: {e}")))?;
130                keys.insert(jwk.kid, key);
131            }
132        }
133
134        if keys.is_empty() {
135            return Err(ClaimsError::JwksFetchFailed(
136                "No valid RSA keys found in JWKS".into(),
137            ));
138        }
139
140        Ok(keys)
141    }
142
143    /// Calculate backoff duration based on consecutive failures
144    fn calculate_backoff(&self, failures: u32) -> Duration {
145        let base = Duration::from_secs(60); // 1 minute base
146        let exponential = base * 2u32.pow(failures.min(10)); // Cap at 2^10
147        exponential.min(self.max_backoff)
148    }
149
150    /// Check if refresh is needed based on interval and backoff
151    async fn should_refresh(&self) -> bool {
152        let state = self.refresh_state.read().await;
153
154        match state.last_refresh {
155            None => true, // Never refreshed
156            Some(last) => {
157                let elapsed = last.elapsed();
158                if state.consecutive_failures == 0 {
159                    // Normal refresh interval
160                    elapsed >= self.refresh_interval
161                } else {
162                    // Exponential backoff
163                    elapsed >= self.calculate_backoff(state.consecutive_failures)
164                }
165            }
166        }
167    }
168
169    /// Perform key refresh with error tracking
170    async fn perform_refresh(&self) -> Result<(), ClaimsError> {
171        match self.fetch_jwks().await {
172            Ok(new_keys) => {
173                // Update keys atomically
174                self.keys.store(Arc::new(new_keys));
175
176                // Update refresh state
177                let mut state = self.refresh_state.write().await;
178                state.last_refresh = Some(Instant::now());
179                state.consecutive_failures = 0;
180                state.last_error = None;
181
182                Ok(())
183            }
184            Err(e) => {
185                // Update failure state
186                let mut state = self.refresh_state.write().await;
187                state.last_refresh = Some(Instant::now());
188                state.consecutive_failures += 1;
189                state.last_error = Some(e.to_string());
190
191                Err(e)
192            }
193        }
194    }
195
196    /// Check if a key exists in the cache
197    fn key_exists(&self, kid: &str) -> bool {
198        let keys = self.keys.load();
199        keys.contains_key(kid)
200    }
201
202    /// Check if we're in cooldown period and handle throttling logic
203    async fn check_refresh_throttle(&self, kid: &str) -> Result<(), ClaimsError> {
204        let state = self.refresh_state.read().await;
205        if let Some(last_on_demand) = state.last_on_demand_refresh {
206            let elapsed = last_on_demand.elapsed();
207            if elapsed < self.on_demand_refresh_cooldown {
208                let remaining = self.on_demand_refresh_cooldown.saturating_sub(elapsed);
209                tracing::debug!(
210                    kid = kid,
211                    remaining_secs = remaining.as_secs(),
212                    "On-demand JWKS refresh throttled (cooldown active)"
213                );
214
215                // Check if this kid has failed before
216                if state.failed_kids.contains(kid) {
217                    tracing::warn!(
218                        kid = kid,
219                        "Unknown kid repeatedly requested despite recent refresh attempts"
220                    );
221                }
222
223                return Err(ClaimsError::UnknownKeyId(kid.to_owned()));
224            }
225        }
226        Ok(())
227    }
228
229    /// Update state after successful refresh and check if kid is now available
230    async fn handle_refresh_success(&self, kid: &str) -> Result<(), ClaimsError> {
231        let mut state = self.refresh_state.write().await;
232        state.last_on_demand_refresh = Some(Instant::now());
233
234        // Check if the kid now exists
235        if self.key_exists(kid) {
236            // Kid found - remove from failed list if present
237            state.failed_kids.remove(kid);
238        } else {
239            // Kid still not found after refresh - track it
240            state.failed_kids.insert(kid.to_owned());
241            tracing::warn!(
242                kid = kid,
243                "Kid still not found after on-demand JWKS refresh"
244            );
245        }
246
247        Ok(())
248    }
249
250    /// Update state after failed refresh
251    async fn handle_refresh_failure(&self, kid: &str, error: ClaimsError) -> ClaimsError {
252        let mut state = self.refresh_state.write().await;
253        state.last_on_demand_refresh = Some(Instant::now());
254        state.failed_kids.insert(kid.to_owned());
255        error
256    }
257
258    /// Try to refresh keys if unknown kid is encountered
259    /// Implements throttling to prevent excessive refreshes
260    async fn on_demand_refresh(&self, kid: &str) -> Result<(), ClaimsError> {
261        // Check if key exists
262        if self.key_exists(kid) {
263            return Ok(());
264        }
265
266        // Check if we're in cooldown period
267        self.check_refresh_throttle(kid).await?;
268
269        // Attempt refresh and track the kid if it fails
270        tracing::info!(
271            kid = kid,
272            "Performing on-demand JWKS refresh for unknown kid"
273        );
274
275        match self.perform_refresh().await {
276            Ok(()) => self.handle_refresh_success(kid).await,
277            Err(e) => Err(self.handle_refresh_failure(kid, e).await),
278        }
279    }
280
281    /// Get a key by kid (lock-free read)
282    fn get_key(&self, kid: &str) -> Option<DecodingKey> {
283        let keys = self.keys.load();
284        keys.get(kid).cloned()
285    }
286
287    /// Validate JWT and decode into header + raw claims
288    fn validate_token(
289        token: &str,
290        key: &DecodingKey,
291        header: &Header,
292    ) -> Result<Value, ClaimsError> {
293        let mut validation = Validation::new(header.alg);
294
295        // Disable all built-in validations - we'll do them separately
296        validation.validate_exp = false;
297        validation.validate_nbf = false;
298        validation.validate_aud = false;
299
300        // Don't require any standard claims
301        let empty_claims: &[&str] = &[];
302        validation.set_required_spec_claims(empty_claims);
303
304        let token_data = decode::<Value>(token, key, &validation)
305            .map_err(|e| ClaimsError::DecodeFailed(format!("JWT validation failed: {e}")))?;
306
307        Ok(token_data.claims)
308    }
309}
310
311#[async_trait]
312impl KeyProvider for JwksKeyProvider {
313    fn name(&self) -> &'static str {
314        "jwks"
315    }
316
317    async fn validate_and_decode(&self, token: &str) -> Result<(Header, Value), ClaimsError> {
318        // Strip "Bearer " prefix if present
319        let token = token.trim_start_matches("Bearer ").trim();
320
321        // Decode header to get kid and algorithm
322        let header = decode_header(token)
323            .map_err(|e| ClaimsError::DecodeFailed(format!("Invalid JWT header: {e}")))?;
324
325        let kid = header
326            .kid
327            .as_ref()
328            .ok_or_else(|| ClaimsError::DecodeFailed("Missing kid in JWT header".into()))?;
329
330        // Try to get key from cache
331        let key = if let Some(k) = self.get_key(kid) {
332            k
333        } else {
334            // Key not in cache, try on-demand refresh
335            self.on_demand_refresh(kid).await?;
336
337            // Try again after refresh
338            self.get_key(kid)
339                .ok_or_else(|| ClaimsError::UnknownKeyId(kid.clone()))?
340        };
341
342        // Validate signature and decode claims
343        let claims = Self::validate_token(token, &key, &header)?;
344
345        Ok((header, claims))
346    }
347
348    async fn refresh_keys(&self) -> Result<(), ClaimsError> {
349        if self.should_refresh().await {
350            self.perform_refresh().await
351        } else {
352            Ok(())
353        }
354    }
355}
356
357/// Background task to periodically refresh JWKS
358pub async fn run_jwks_refresh_task(provider: Arc<JwksKeyProvider>) {
359    let mut interval = tokio::time::interval(Duration::from_secs(60)); // Check every minute
360
361    loop {
362        interval.tick().await;
363
364        if let Err(e) = provider.refresh_keys().await {
365            tracing::warn!("JWKS refresh failed: {}", e);
366        }
367    }
368}
369
/// Unit tests for `JwksKeyProvider`.
///
/// Tests that need a failing fetch point the provider at a host name that is
/// not expected to resolve, so they exercise the error paths only.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    #[test]
    fn test_calculate_backoff() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://example.com/jwks")?;

        assert_eq!(provider.calculate_backoff(0), Duration::from_secs(60));
        assert_eq!(provider.calculate_backoff(1), Duration::from_secs(120));
        assert_eq!(provider.calculate_backoff(2), Duration::from_secs(240));
        assert_eq!(provider.calculate_backoff(3), Duration::from_secs(480));

        // Should cap at max_backoff
        assert_eq!(provider.calculate_backoff(100), provider.max_backoff);
        Ok(())
    }

    #[tokio::test]
    async fn test_should_refresh_on_first_call() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://example.com/jwks")?;
        assert!(provider.should_refresh().await);
        Ok(())
    }

    #[tokio::test]
    async fn test_key_storage() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://example.com/jwks")?;

        // Initially empty
        assert!(provider.get_key("test-kid").is_none());

        // Store a dummy key
        let mut keys = HashMap::new();
        keys.insert("test-kid".to_owned(), DecodingKey::from_secret(b"secret"));
        provider.keys.store(Arc::new(keys));

        // Should be retrievable
        assert!(provider.get_key("test-kid").is_some());
        Ok(())
    }

    #[tokio::test]
    async fn test_on_demand_refresh_returns_ok_when_key_exists() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://example.com/jwks")?;

        // Pre-populate with a key
        let mut keys = HashMap::new();
        keys.insert(
            "existing-kid".to_owned(),
            DecodingKey::from_secret(b"secret"),
        );
        provider.keys.store(Arc::new(keys));

        // Should return Ok immediately without any refresh
        let result = provider.on_demand_refresh("existing-kid").await;
        assert!(result.is_ok());
        Ok(())
    }

    #[tokio::test]
    async fn test_on_demand_refresh_returns_error_for_missing_key_on_failed_fetch()
    -> Result<(), reqwest::Error> {
        let provider =
            JwksKeyProvider::new("https://invalid-domain-that-does-not-exist.local/jwks")?;

        // Attempting to refresh a missing key should fail (network error)
        let result = provider.on_demand_refresh("missing-kid").await;

        // The error should be related to fetch failure; `expect_err` already
        // fails the test if the call unexpectedly succeeded.
        match result.expect_err("expected error for missing key") {
            ClaimsError::JwksFetchFailed(_) | ClaimsError::UnknownKeyId(_) => {}
            other => panic!("Expected JwksFetchFailed or UnknownKeyId, got: {other:?}"),
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_on_demand_refresh_respects_cooldown() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://invalid-domain.local/jwks")?
            .with_on_demand_refresh_cooldown(Duration::from_secs(5));

        // First attempt - should try to refresh
        let result1 = provider.on_demand_refresh("test-kid").await;
        assert!(result1.is_err()); // Will fail due to invalid domain

        // Immediate second attempt - should be throttled
        let result2 = provider.on_demand_refresh("test-kid").await;

        // Should return UnknownKeyId due to cooldown
        match result2.expect_err("expected throttle error") {
            ClaimsError::UnknownKeyId(_) => {}
            other => panic!("Expected UnknownKeyId during cooldown, got: {other:?}"),
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_on_demand_refresh_tracks_failed_kids() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://invalid-domain.local/jwks")?
            .with_on_demand_refresh_cooldown(Duration::from_millis(100));

        // Attempt refresh - will fail and track the kid
        let result = provider.on_demand_refresh("failed-kid").await;
        assert!(result.is_err());

        // Check that failed_kids contains the kid
        let state = provider.refresh_state.read().await;
        assert!(state.failed_kids.contains("failed-kid"));
        Ok(())
    }

    #[tokio::test]
    async fn test_validate_and_decode_with_missing_kid() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://invalid-domain.local/jwks")?
            .with_on_demand_refresh_cooldown(Duration::from_millis(100));

        // Create a minimal JWT with a kid header but invalid signature
        let token =
            "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIn0.eyJzdWIiOiIxMjM0NTY3ODkwIn0.invalid";

        // Should attempt on-demand refresh and fail
        let result = provider.validate_and_decode(token).await;
        assert!(result.is_err());
        Ok(())
    }

    // Renamed from `..._on_success`: the endpoint is unreachable, so this
    // exercises (and asserts) the failure bookkeeping of `perform_refresh`.
    #[tokio::test]
    async fn test_perform_refresh_updates_state_on_failure() -> Result<(), reqwest::Error> {
        let provider = JwksKeyProvider::new("https://invalid-domain.local/jwks")?;

        // Mark as previously failed
        {
            let mut state = provider.refresh_state.write().await;
            state.consecutive_failures = 3;
            state.last_error = Some("Previous error".to_owned());
        }

        // The fetch is expected to fail; only the recorded state matters here
        let _ = provider.perform_refresh().await;

        // The failure counter is incremented and an error is recorded
        let state = provider.refresh_state.read().await;
        assert_eq!(state.consecutive_failures, 4);
        assert!(state.last_error.is_some());
        Ok(())
    }
}
520}