axum_jwt_auth/remote.rs
1use std::sync::Arc;
2
3use dashmap::DashMap;
4use jsonwebtoken::{DecodingKey, TokenData, Validation, jwk::JwkSet};
5use serde::de::DeserializeOwned;
6use tokio_util::sync::CancellationToken;
7
8use crate::{Error, JwtDecoder};
9
/// Cache lifetime applied when none is configured: one hour.
const DEFAULT_CACHE_DURATION: std::time::Duration = std::time::Duration::from_secs(3_600);
/// Number of fetch attempts applied when none is configured.
const DEFAULT_RETRY_COUNT: usize = 3;
/// Pause between fetch attempts applied when none is configured: one second.
const DEFAULT_BACKOFF: std::time::Duration = std::time::Duration::from_secs(1);
13
/// Configuration for remote JWKS fetching and caching behavior.
///
/// All fields are plain values, so the type is cheap to clone and can be compared
/// for equality (handy when asserting on configuration in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RemoteJwksDecoderConfig {
    /// Duration to cache JWKS keys before refreshing (default: 1 hour)
    pub cache_duration: std::time::Duration,
    /// Number of attempts made when fetching JWKS (default: 3)
    pub retry_count: usize,
    /// Delay between fetch attempts (default: 1 second)
    pub backoff: std::time::Duration,
}
24
25impl Default for RemoteJwksDecoderConfig {
26 fn default() -> Self {
27 Self {
28 cache_duration: DEFAULT_CACHE_DURATION,
29 retry_count: DEFAULT_RETRY_COUNT,
30 backoff: DEFAULT_BACKOFF,
31 }
32 }
33}
34
35impl RemoteJwksDecoderConfig {
36 /// Creates a new builder for configuring JWKS fetching behavior.
37 pub fn builder() -> RemoteJwksDecoderConfigBuilder {
38 RemoteJwksDecoderConfigBuilder {
39 cache_duration: None,
40 retry_count: None,
41 backoff: None,
42 }
43 }
44}
45
/// Builder for `RemoteJwksDecoderConfig`.
///
/// Each field is `None` until the matching setter is called; `Default` therefore yields a
/// builder with nothing set.
#[derive(Debug, Clone, Default)]
pub struct RemoteJwksDecoderConfigBuilder {
    /// Requested cache lifetime, if one was set.
    cache_duration: Option<std::time::Duration>,
    /// Requested number of fetch attempts, if one was set.
    retry_count: Option<usize>,
    /// Requested delay between fetch attempts, if one was set.
    backoff: Option<std::time::Duration>,
}
52
53impl RemoteJwksDecoderConfigBuilder {
54 /// Sets the cache duration.
55 pub fn cache_duration(mut self, cache_duration: std::time::Duration) -> Self {
56 self.cache_duration = Some(cache_duration);
57 self
58 }
59
60 /// Sets the retry count.
61 pub fn retry_count(mut self, retry_count: usize) -> Self {
62 self.retry_count = Some(retry_count);
63 self
64 }
65
66 /// Sets the backoff duration.
67 pub fn backoff(mut self, backoff: std::time::Duration) -> Self {
68 self.backoff = Some(backoff);
69 self
70 }
71
72 /// Builds the `RemoteJwksDecoderConfig` with defaults for unset fields.
73 pub fn build(self) -> RemoteJwksDecoderConfig {
74 RemoteJwksDecoderConfig {
75 cache_duration: self.cache_duration.unwrap_or(DEFAULT_CACHE_DURATION),
76 retry_count: self.retry_count.unwrap_or(DEFAULT_RETRY_COUNT),
77 backoff: self.backoff.unwrap_or(DEFAULT_BACKOFF),
78 }
79 }
80}
81
82/// JWT decoder that fetches and caches keys from a remote JWKS endpoint.
83///
84/// Automatically fetches JWKS from the specified URL, caches keys by their `kid` (key ID),
85/// and periodically refreshes them in the background. Includes retry logic for robustness.
86///
87/// # Example
88///
89/// ```ignore
90/// use axum_jwt_auth::RemoteJwksDecoder;
91/// use jsonwebtoken::{Algorithm, Validation};
92///
93/// let decoder = RemoteJwksDecoder::builder()
94/// .jwks_url("https://example.com/.well-known/jwks.json".to_string())
95/// .validation(Validation::new(Algorithm::RS256))
96/// .build()
97/// .unwrap();
98///
99/// // Initialize: fetch keys and start background refresh task
100/// decoder.initialize().await.unwrap();
101/// ```
102#[derive(Clone)]
103pub struct RemoteJwksDecoder {
104 /// The JWKS endpoint URL
105 jwks_url: String,
106 /// Configuration for caching and retry behavior
107 config: RemoteJwksDecoderConfig,
108 /// Thread-safe cache mapping key IDs to decoding keys
109 keys_cache: Arc<DashMap<String, DecodingKey>>,
110 /// JWT validation settings
111 validation: Validation,
112 /// HTTP client for fetching JWKS
113 client: reqwest::Client,
114}
115
116impl RemoteJwksDecoder {
117 /// Creates a new `RemoteJwksDecoder` with the given JWKS URL and default settings.
118 ///
119 /// # Errors
120 ///
121 /// Returns `Error::Configuration` if the builder fails to construct the decoder.
122 pub fn new(jwks_url: String) -> Result<Self, Error> {
123 RemoteJwksDecoderBuilder::new().jwks_url(jwks_url).build()
124 }
125
126 /// Creates a new builder for configuring a remote JWKS decoder.
127 pub fn builder() -> RemoteJwksDecoderBuilder {
128 RemoteJwksDecoderBuilder::new()
129 }
130
131 /// Performs an initial fetch of JWKS keys and starts the background refresh task.
132 ///
133 /// This method should be called once after construction. It will:
134 /// 1. Immediately fetch keys from the JWKS endpoint
135 /// 2. Spawn a background task to periodically refresh keys
136 ///
137 /// Returns a `CancellationToken` that can be used to gracefully stop the background refresh task.
138 ///
139 /// # Errors
140 ///
141 /// Returns an error if the initial fetch fails after all retry attempts.
142 ///
143 /// # Example
144 ///
145 /// ```ignore
146 /// let decoder = RemoteJwksDecoder::builder()
147 /// .jwks_url("https://example.com/.well-known/jwks.json".to_string())
148 /// .validation(Validation::new(Algorithm::RS256))
149 /// .build()?;
150 ///
151 /// // Fetch keys and start background refresh
152 /// let shutdown_token = decoder.initialize().await?;
153 ///
154 /// // Later, during application shutdown:
155 /// shutdown_token.cancel();
156 /// ```
157 pub async fn initialize(&self) -> Result<CancellationToken, Error> {
158 // Fetch keys immediately
159 self.refresh_keys().await?;
160
161 // Create cancellation token for graceful shutdown
162 let shutdown_token = CancellationToken::new();
163
164 // Spawn background refresh task
165 let decoder_clone = self.clone();
166 let token_clone = shutdown_token.clone();
167 tokio::spawn(async move {
168 decoder_clone.refresh_keys_periodically(token_clone).await;
169 });
170
171 Ok(shutdown_token)
172 }
173
174 /// Manually triggers a JWKS refresh with retry logic.
175 ///
176 /// Useful for forcing an update outside the normal refresh cycle.
177 ///
178 /// # Errors
179 ///
180 /// Returns an error if the refresh fails after all retry attempts.
181 pub async fn refresh(&self) -> Result<(), Error> {
182 self.refresh_keys().await
183 }
184
185 /// Refreshes the JWKS cache with retry logic.
186 ///
187 /// Retries up to `config.retry_count` times, waiting `config.backoff` duration between attempts.
188 ///
189 /// # Errors
190 ///
191 /// Returns `Error::JwksRefresh` if all retry attempts fail.
192 async fn refresh_keys(&self) -> Result<(), Error> {
193 let max_attempts = self.config.retry_count;
194 let mut attempt = 0;
195 let mut err = None;
196
197 while attempt < max_attempts {
198 match self.refresh_keys_once().await {
199 Ok(_) => return Ok(()),
200 Err(e) => {
201 err = Some(e);
202 attempt += 1;
203 tokio::time::sleep(self.config.backoff).await;
204 }
205 }
206 }
207
208 Err(Error::JwksRefresh {
209 message: "Failed to refresh JWKS after multiple attempts".to_string(),
210 retry_count: max_attempts,
211 source: err.map(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>),
212 })
213 }
214
215 /// Fetches JWKS from the remote URL and updates the cache.
216 ///
217 /// Parses all keys before updating the cache to ensure atomicity.
218 async fn refresh_keys_once(&self) -> Result<(), Error> {
219 let jwks = self
220 .client
221 .get(&self.jwks_url)
222 .send()
223 .await?
224 .json::<JwkSet>()
225 .await?;
226
227 // Parse all keys first before clearing cache
228 let mut new_keys = Vec::new();
229 for jwk in jwks.keys.iter() {
230 let key_id = jwk.common.key_id.to_owned();
231 let key = DecodingKey::from_jwk(jwk).map_err(Error::Jwt)?;
232 new_keys.push((key_id.unwrap_or_default(), key));
233 }
234
235 // Only clear and update cache after all keys parsed successfully
236 self.keys_cache.clear();
237 for (kid, key) in new_keys {
238 self.keys_cache.insert(kid, key);
239 }
240
241 Ok(())
242 }
243
244 /// Runs a loop that periodically refreshes the JWKS cache until cancelled.
245 ///
246 /// This method should be spawned in a background task using `tokio::spawn`.
247 /// Refresh failures are logged, and the decoder continues using stale keys until the next
248 /// successful refresh.
249 ///
250 /// The loop will exit gracefully when the `shutdown_token` is cancelled.
251 ///
252 /// # Example
253 ///
254 /// ```ignore
255 /// use tokio_util::sync::CancellationToken;
256 ///
257 /// let decoder = RemoteJwksDecoder::builder()
258 /// .jwks_url("https://example.com/.well-known/jwks.json".to_string())
259 /// .build()
260 /// .unwrap();
261 ///
262 /// let shutdown_token = CancellationToken::new();
263 /// let decoder_clone = decoder.clone();
264 /// let token_clone = shutdown_token.clone();
265 ///
266 /// tokio::spawn(async move {
267 /// decoder_clone.refresh_keys_periodically(token_clone).await;
268 /// });
269 ///
270 /// // Later, to stop the refresh task:
271 /// shutdown_token.cancel();
272 /// ```
273 pub async fn refresh_keys_periodically(&self, shutdown_token: CancellationToken) {
274 loop {
275 tokio::select! {
276 _ = shutdown_token.cancelled() => {
277 tracing::info!("JWKS refresh task shutting down gracefully");
278 break;
279 }
280 _ = tokio::time::sleep(self.config.cache_duration) => {
281 tracing::info!("Refreshing JWKS");
282 match self.refresh_keys().await {
283 Ok(_) => {}
284 Err(err) => {
285 // log the error and continue with stale keys
286 tracing::error!(
287 "Failed to refresh JWKS after {} attempts: {:?}",
288 self.config.retry_count,
289 err
290 );
291 }
292 }
293 }
294 }
295 }
296 }
297
298 /// Checks that the key cache has been initialized.
299 ///
300 /// # Errors
301 ///
302 /// Returns `Error::Configuration` if the cache is empty, which indicates
303 /// that `initialize()` was never called.
304 fn check_initialized(&self) -> Result<(), Error> {
305 if self.keys_cache.is_empty() {
306 Err(Error::Configuration(
307 "JWKS decoder not initialized: call initialize() after building the decoder".into(),
308 ))
309 } else {
310 Ok(())
311 }
312 }
313}
314
315/// Builder for `RemoteJwksDecoder`.
316pub struct RemoteJwksDecoderBuilder {
317 jwks_url: Option<String>,
318 config: Option<RemoteJwksDecoderConfig>,
319 keys_cache: Option<Arc<DashMap<String, DecodingKey>>>,
320 validation: Option<Validation>,
321 client: Option<reqwest::Client>,
322}
323
324impl RemoteJwksDecoderBuilder {
325 /// Creates a new `RemoteJwksDecoderBuilder`.
326 pub fn new() -> Self {
327 Self {
328 jwks_url: None,
329 config: None,
330 keys_cache: None,
331 validation: None,
332 client: None,
333 }
334 }
335
336 /// Sets the JWKS URL.
337 pub fn jwks_url(mut self, jwks_url: String) -> Self {
338 self.jwks_url = Some(jwks_url);
339 self
340 }
341
342 /// Sets the configuration.
343 pub fn config(mut self, config: RemoteJwksDecoderConfig) -> Self {
344 self.config = Some(config);
345 self
346 }
347
348 /// Sets the keys cache.
349 pub fn keys_cache(mut self, keys_cache: Arc<DashMap<String, DecodingKey>>) -> Self {
350 self.keys_cache = Some(keys_cache);
351 self
352 }
353
354 /// Sets the validation settings.
355 pub fn validation(mut self, validation: Validation) -> Self {
356 self.validation = Some(validation);
357 self
358 }
359
360 /// Sets the HTTP client.
361 pub fn client(mut self, client: reqwest::Client) -> Self {
362 self.client = Some(client);
363 self
364 }
365
366 /// Builds the `RemoteJwksDecoder`.
367 ///
368 /// # Errors
369 ///
370 /// Returns `Error::Configuration` if required fields are missing.
371 pub fn build(self) -> Result<RemoteJwksDecoder, Error> {
372 let jwks_url = self
373 .jwks_url
374 .ok_or_else(|| Error::Configuration("jwks_url is required".into()))?;
375
376 let validation = self
377 .validation
378 .ok_or_else(|| Error::Configuration("validation is required".into()))?;
379
380 // Configure client with sensible timeouts if not provided
381 let client = self.client.unwrap_or_else(|| {
382 reqwest::Client::builder()
383 .timeout(std::time::Duration::from_secs(10))
384 .connect_timeout(std::time::Duration::from_secs(5))
385 .build()
386 .expect("Failed to build HTTP client")
387 });
388
389 Ok(RemoteJwksDecoder {
390 jwks_url,
391 config: self.config.unwrap_or_default(),
392 keys_cache: self.keys_cache.unwrap_or_else(|| Arc::new(DashMap::new())),
393 validation,
394 client,
395 })
396 }
397}
398
399impl Default for RemoteJwksDecoderBuilder {
400 fn default() -> Self {
401 Self::new()
402 }
403}
404
405impl<T> JwtDecoder<T> for RemoteJwksDecoder
406where
407 T: for<'de> DeserializeOwned,
408{
409 fn decode<'a>(
410 &'a self,
411 token: &'a str,
412 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<TokenData<T>, Error>> + Send + 'a>>
413 {
414 Box::pin(async move {
415 self.check_initialized()?;
416 let header = jsonwebtoken::decode_header(token)?;
417 let target_kid = header.kid;
418
419 if let Some(ref kid) = target_kid {
420 if let Some(key) = self.keys_cache.get(kid) {
421 Ok(jsonwebtoken::decode::<T>(
422 token,
423 key.value(),
424 &self.validation,
425 )?)
426 } else {
427 Err(Error::KeyNotFound(Some(kid.clone())))
428 }
429 } else {
430 Err(Error::KeyNotFound(None))
431 }
432 })
433 }
434}