// axum_jwt_auth/remote.rs

1use std::sync::Arc;
2
3use dashmap::DashMap;
4use jsonwebtoken::{DecodingKey, TokenData, Validation, jwk::JwkSet};
5use serde::de::DeserializeOwned;
6use tokio_util::sync::CancellationToken;
7
8use crate::{Error, JwtDecoder};
9
/// Default interval between background JWKS refreshes: one hour.
const DEFAULT_CACHE_DURATION: std::time::Duration = std::time::Duration::from_secs(3_600);
/// Default number of fetch attempts a single refresh makes.
const DEFAULT_RETRY_COUNT: usize = 3;
/// Default pause between consecutive fetch attempts: one second.
const DEFAULT_BACKOFF: std::time::Duration = std::time::Duration::from_millis(1_000);
13
/// Tuning knobs for how the remote JWKS decoder fetches and caches keys.
///
/// Build one with [`RemoteJwksDecoderConfig::builder`] or start from
/// `RemoteJwksDecoderConfig::default()`.
#[derive(Debug, Clone)]
pub struct RemoteJwksDecoderConfig {
    /// Interval the background task waits between refreshes (default: 1 hour).
    pub cache_duration: std::time::Duration,
    /// Total number of fetch attempts one refresh makes before giving up (default: 3).
    pub retry_count: usize,
    /// Pause between consecutive fetch attempts (default: 1 second).
    pub backoff: std::time::Duration,
}
24
25impl Default for RemoteJwksDecoderConfig {
26    fn default() -> Self {
27        Self {
28            cache_duration: DEFAULT_CACHE_DURATION,
29            retry_count: DEFAULT_RETRY_COUNT,
30            backoff: DEFAULT_BACKOFF,
31        }
32    }
33}
34
35impl RemoteJwksDecoderConfig {
36    /// Creates a new builder for configuring JWKS fetching behavior.
37    pub fn builder() -> RemoteJwksDecoderConfigBuilder {
38        RemoteJwksDecoderConfigBuilder {
39            cache_duration: None,
40            retry_count: None,
41            backoff: None,
42        }
43    }
44}
45
/// Builder for `RemoteJwksDecoderConfig`.
///
/// Each field is an optional override. Any field still `None` when `build` is called
/// falls back to its default.
// `Debug`/`Clone`/`Default` are derived so the public builder can be logged, copied,
// and created without going through `RemoteJwksDecoderConfig::builder()`.
#[derive(Debug, Clone, Default)]
pub struct RemoteJwksDecoderConfigBuilder {
    /// Override for the cache (refresh) interval.
    cache_duration: Option<std::time::Duration>,
    /// Override for the number of fetch attempts per refresh.
    retry_count: Option<usize>,
    /// Override for the pause between fetch attempts.
    backoff: Option<std::time::Duration>,
}
52
53impl RemoteJwksDecoderConfigBuilder {
54    /// Sets the cache duration.
55    pub fn cache_duration(mut self, cache_duration: std::time::Duration) -> Self {
56        self.cache_duration = Some(cache_duration);
57        self
58    }
59
60    /// Sets the retry count.
61    pub fn retry_count(mut self, retry_count: usize) -> Self {
62        self.retry_count = Some(retry_count);
63        self
64    }
65
66    /// Sets the backoff duration.
67    pub fn backoff(mut self, backoff: std::time::Duration) -> Self {
68        self.backoff = Some(backoff);
69        self
70    }
71
72    /// Builds the `RemoteJwksDecoderConfig` with defaults for unset fields.
73    pub fn build(self) -> RemoteJwksDecoderConfig {
74        RemoteJwksDecoderConfig {
75            cache_duration: self.cache_duration.unwrap_or(DEFAULT_CACHE_DURATION),
76            retry_count: self.retry_count.unwrap_or(DEFAULT_RETRY_COUNT),
77            backoff: self.backoff.unwrap_or(DEFAULT_BACKOFF),
78        }
79    }
80}
81
82/// JWT decoder that fetches and caches keys from a remote JWKS endpoint.
83///
84/// Automatically fetches JWKS from the specified URL, caches keys by their `kid` (key ID),
85/// and periodically refreshes them in the background. Includes retry logic for robustness.
86///
87/// # Example
88///
89/// ```ignore
90/// use axum_jwt_auth::RemoteJwksDecoder;
91/// use jsonwebtoken::{Algorithm, Validation};
92///
93/// let decoder = RemoteJwksDecoder::builder()
94///     .jwks_url("https://example.com/.well-known/jwks.json".to_string())
95///     .validation(Validation::new(Algorithm::RS256))
96///     .build()
97///     .unwrap();
98///
99/// // Initialize: fetch keys and start background refresh task
100/// decoder.initialize().await.unwrap();
101/// ```
102#[derive(Clone)]
103pub struct RemoteJwksDecoder {
104    /// The JWKS endpoint URL
105    jwks_url: String,
106    /// Configuration for caching and retry behavior
107    config: RemoteJwksDecoderConfig,
108    /// Thread-safe cache mapping key IDs to decoding keys
109    keys_cache: Arc<DashMap<String, DecodingKey>>,
110    /// JWT validation settings
111    validation: Validation,
112    /// HTTP client for fetching JWKS
113    client: reqwest::Client,
114}
115
116impl RemoteJwksDecoder {
117    /// Creates a new `RemoteJwksDecoder` with the given JWKS URL and default settings.
118    ///
119    /// # Errors
120    ///
121    /// Returns `Error::Configuration` if the builder fails to construct the decoder.
122    pub fn new(jwks_url: String) -> Result<Self, Error> {
123        RemoteJwksDecoderBuilder::new().jwks_url(jwks_url).build()
124    }
125
126    /// Creates a new builder for configuring a remote JWKS decoder.
127    pub fn builder() -> RemoteJwksDecoderBuilder {
128        RemoteJwksDecoderBuilder::new()
129    }
130
131    /// Performs an initial fetch of JWKS keys and starts the background refresh task.
132    ///
133    /// This method should be called once after construction. It will:
134    /// 1. Immediately fetch keys from the JWKS endpoint
135    /// 2. Spawn a background task to periodically refresh keys
136    ///
137    /// Returns a `CancellationToken` that can be used to gracefully stop the background refresh task.
138    ///
139    /// # Errors
140    ///
141    /// Returns an error if the initial fetch fails after all retry attempts.
142    ///
143    /// # Example
144    ///
145    /// ```ignore
146    /// let decoder = RemoteJwksDecoder::builder()
147    ///     .jwks_url("https://example.com/.well-known/jwks.json".to_string())
148    ///     .validation(Validation::new(Algorithm::RS256))
149    ///     .build()?;
150    ///
151    /// // Fetch keys and start background refresh
152    /// let shutdown_token = decoder.initialize().await?;
153    ///
154    /// // Later, during application shutdown:
155    /// shutdown_token.cancel();
156    /// ```
157    pub async fn initialize(&self) -> Result<CancellationToken, Error> {
158        // Fetch keys immediately
159        self.refresh_keys().await?;
160
161        // Create cancellation token for graceful shutdown
162        let shutdown_token = CancellationToken::new();
163
164        // Spawn background refresh task
165        let decoder_clone = self.clone();
166        let token_clone = shutdown_token.clone();
167        tokio::spawn(async move {
168            decoder_clone.refresh_keys_periodically(token_clone).await;
169        });
170
171        Ok(shutdown_token)
172    }
173
174    /// Manually triggers a JWKS refresh with retry logic.
175    ///
176    /// Useful for forcing an update outside the normal refresh cycle.
177    ///
178    /// # Errors
179    ///
180    /// Returns an error if the refresh fails after all retry attempts.
181    pub async fn refresh(&self) -> Result<(), Error> {
182        self.refresh_keys().await
183    }
184
185    /// Refreshes the JWKS cache with retry logic.
186    ///
187    /// Retries up to `config.retry_count` times, waiting `config.backoff` duration between attempts.
188    ///
189    /// # Errors
190    ///
191    /// Returns `Error::JwksRefresh` if all retry attempts fail.
192    async fn refresh_keys(&self) -> Result<(), Error> {
193        let max_attempts = self.config.retry_count;
194        let mut attempt = 0;
195        let mut err = None;
196
197        while attempt < max_attempts {
198            match self.refresh_keys_once().await {
199                Ok(_) => return Ok(()),
200                Err(e) => {
201                    err = Some(e);
202                    attempt += 1;
203                    tokio::time::sleep(self.config.backoff).await;
204                }
205            }
206        }
207
208        Err(Error::JwksRefresh {
209            message: "Failed to refresh JWKS after multiple attempts".to_string(),
210            retry_count: max_attempts,
211            source: err.map(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>),
212        })
213    }
214
215    /// Fetches JWKS from the remote URL and updates the cache.
216    ///
217    /// Parses all keys before updating the cache to ensure atomicity.
218    async fn refresh_keys_once(&self) -> Result<(), Error> {
219        let jwks = self
220            .client
221            .get(&self.jwks_url)
222            .send()
223            .await?
224            .json::<JwkSet>()
225            .await?;
226
227        // Parse all keys first before clearing cache
228        let mut new_keys = Vec::new();
229        for jwk in jwks.keys.iter() {
230            let key_id = jwk.common.key_id.to_owned();
231            let key = DecodingKey::from_jwk(jwk).map_err(Error::Jwt)?;
232            new_keys.push((key_id.unwrap_or_default(), key));
233        }
234
235        // Only clear and update cache after all keys parsed successfully
236        self.keys_cache.clear();
237        for (kid, key) in new_keys {
238            self.keys_cache.insert(kid, key);
239        }
240
241        Ok(())
242    }
243
244    /// Runs a loop that periodically refreshes the JWKS cache until cancelled.
245    ///
246    /// This method should be spawned in a background task using `tokio::spawn`.
247    /// Refresh failures are logged, and the decoder continues using stale keys until the next
248    /// successful refresh.
249    ///
250    /// The loop will exit gracefully when the `shutdown_token` is cancelled.
251    ///
252    /// # Example
253    ///
254    /// ```ignore
255    /// use tokio_util::sync::CancellationToken;
256    ///
257    /// let decoder = RemoteJwksDecoder::builder()
258    ///     .jwks_url("https://example.com/.well-known/jwks.json".to_string())
259    ///     .build()
260    ///     .unwrap();
261    ///
262    /// let shutdown_token = CancellationToken::new();
263    /// let decoder_clone = decoder.clone();
264    /// let token_clone = shutdown_token.clone();
265    ///
266    /// tokio::spawn(async move {
267    ///     decoder_clone.refresh_keys_periodically(token_clone).await;
268    /// });
269    ///
270    /// // Later, to stop the refresh task:
271    /// shutdown_token.cancel();
272    /// ```
273    pub async fn refresh_keys_periodically(&self, shutdown_token: CancellationToken) {
274        loop {
275            tokio::select! {
276                _ = shutdown_token.cancelled() => {
277                    tracing::info!("JWKS refresh task shutting down gracefully");
278                    break;
279                }
280                _ = tokio::time::sleep(self.config.cache_duration) => {
281                    tracing::info!("Refreshing JWKS");
282                    match self.refresh_keys().await {
283                        Ok(_) => {}
284                        Err(err) => {
285                            // log the error and continue with stale keys
286                            tracing::error!(
287                                "Failed to refresh JWKS after {} attempts: {:?}",
288                                self.config.retry_count,
289                                err
290                            );
291                        }
292                    }
293                }
294            }
295        }
296    }
297
298    /// Checks that the key cache has been initialized.
299    ///
300    /// # Errors
301    ///
302    /// Returns `Error::Configuration` if the cache is empty, which indicates
303    /// that `initialize()` was never called.
304    fn check_initialized(&self) -> Result<(), Error> {
305        if self.keys_cache.is_empty() {
306            Err(Error::Configuration(
307                "JWKS decoder not initialized: call initialize() after building the decoder".into(),
308            ))
309        } else {
310            Ok(())
311        }
312    }
313}
314
315/// Builder for `RemoteJwksDecoder`.
316pub struct RemoteJwksDecoderBuilder {
317    jwks_url: Option<String>,
318    config: Option<RemoteJwksDecoderConfig>,
319    keys_cache: Option<Arc<DashMap<String, DecodingKey>>>,
320    validation: Option<Validation>,
321    client: Option<reqwest::Client>,
322}
323
324impl RemoteJwksDecoderBuilder {
325    /// Creates a new `RemoteJwksDecoderBuilder`.
326    pub fn new() -> Self {
327        Self {
328            jwks_url: None,
329            config: None,
330            keys_cache: None,
331            validation: None,
332            client: None,
333        }
334    }
335
336    /// Sets the JWKS URL.
337    pub fn jwks_url(mut self, jwks_url: String) -> Self {
338        self.jwks_url = Some(jwks_url);
339        self
340    }
341
342    /// Sets the configuration.
343    pub fn config(mut self, config: RemoteJwksDecoderConfig) -> Self {
344        self.config = Some(config);
345        self
346    }
347
348    /// Sets the keys cache.
349    pub fn keys_cache(mut self, keys_cache: Arc<DashMap<String, DecodingKey>>) -> Self {
350        self.keys_cache = Some(keys_cache);
351        self
352    }
353
354    /// Sets the validation settings.
355    pub fn validation(mut self, validation: Validation) -> Self {
356        self.validation = Some(validation);
357        self
358    }
359
360    /// Sets the HTTP client.
361    pub fn client(mut self, client: reqwest::Client) -> Self {
362        self.client = Some(client);
363        self
364    }
365
366    /// Builds the `RemoteJwksDecoder`.
367    ///
368    /// # Errors
369    ///
370    /// Returns `Error::Configuration` if required fields are missing.
371    pub fn build(self) -> Result<RemoteJwksDecoder, Error> {
372        let jwks_url = self
373            .jwks_url
374            .ok_or_else(|| Error::Configuration("jwks_url is required".into()))?;
375
376        let validation = self
377            .validation
378            .ok_or_else(|| Error::Configuration("validation is required".into()))?;
379
380        // Configure client with sensible timeouts if not provided
381        let client = self.client.unwrap_or_else(|| {
382            reqwest::Client::builder()
383                .timeout(std::time::Duration::from_secs(10))
384                .connect_timeout(std::time::Duration::from_secs(5))
385                .build()
386                .expect("Failed to build HTTP client")
387        });
388
389        Ok(RemoteJwksDecoder {
390            jwks_url,
391            config: self.config.unwrap_or_default(),
392            keys_cache: self.keys_cache.unwrap_or_else(|| Arc::new(DashMap::new())),
393            validation,
394            client,
395        })
396    }
397}
398
399impl Default for RemoteJwksDecoderBuilder {
400    fn default() -> Self {
401        Self::new()
402    }
403}
404
405impl<T> JwtDecoder<T> for RemoteJwksDecoder
406where
407    T: for<'de> DeserializeOwned,
408{
409    fn decode<'a>(
410        &'a self,
411        token: &'a str,
412    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<TokenData<T>, Error>> + Send + 'a>>
413    {
414        Box::pin(async move {
415            self.check_initialized()?;
416            let header = jsonwebtoken::decode_header(token)?;
417            let target_kid = header.kid;
418
419            if let Some(ref kid) = target_kid {
420                if let Some(key) = self.keys_cache.get(kid) {
421                    Ok(jsonwebtoken::decode::<T>(
422                        token,
423                        key.value(),
424                        &self.validation,
425                    )?)
426                } else {
427                    Err(Error::KeyNotFound(Some(kid.clone())))
428                }
429            } else {
430                Err(Error::KeyNotFound(None))
431            }
432        })
433    }
434}