axum_jwt_auth/
remote.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use dashmap::DashMap;
5use derive_builder::Builder;
6use jsonwebtoken::{jwk::JwkSet, DecodingKey, TokenData, Validation};
7use serde::de::DeserializeOwned;
8use tokio::sync::Notify;
9
10use crate::{Error, JwtDecoder};
11
/// How long fetched JWKS keys are cached before the next refresh: one hour.
const DEFAULT_CACHE_DURATION: std::time::Duration = std::time::Duration::from_secs(3_600);
/// How many fetch attempts are made per refresh: three.
const DEFAULT_RETRY_COUNT: usize = 3;
/// How long to wait between fetch attempts: one second.
const DEFAULT_BACKOFF: std::time::Duration = std::time::Duration::from_secs(1);
15
16/// Configuration for remote JWKS fetching and caching behavior.
17#[derive(Debug, Clone, Builder)]
18pub struct RemoteJwksDecoderConfig {
19    /// Duration to cache JWKS keys before refreshing (default: 1 hour)
20    #[builder(default = "DEFAULT_CACHE_DURATION")]
21    pub cache_duration: std::time::Duration,
22    /// Number of retry attempts when fetching JWKS fails (default: 3)
23    #[builder(default = "DEFAULT_RETRY_COUNT")]
24    pub retry_count: usize,
25    /// Delay between retry attempts (default: 1 second)
26    #[builder(default = "DEFAULT_BACKOFF")]
27    pub backoff: std::time::Duration,
28}
29
30impl Default for RemoteJwksDecoderConfig {
31    fn default() -> Self {
32        Self {
33            cache_duration: DEFAULT_CACHE_DURATION,
34            retry_count: DEFAULT_RETRY_COUNT,
35            backoff: DEFAULT_BACKOFF,
36        }
37    }
38}
39
40impl RemoteJwksDecoderConfig {
41    /// Creates a new builder for configuring JWKS fetching behavior.
42    pub fn builder() -> RemoteJwksDecoderConfigBuilder {
43        RemoteJwksDecoderConfigBuilder::default()
44    }
45}
46
47/// JWT decoder that fetches and caches keys from a remote JWKS endpoint.
48///
49/// Automatically fetches JWKS from the specified URL, caches keys by their `kid` (key ID),
50/// and periodically refreshes them. Includes retry logic for robustness.
51///
52/// # Example
53///
54/// ```ignore
55/// use axum_jwt_auth::RemoteJwksDecoder;
56/// use jsonwebtoken::{Algorithm, Validation};
57///
58/// let decoder = RemoteJwksDecoder::builder()
59///     .jwks_url("https://example.com/.well-known/jwks.json".to_string())
60///     .validation(Validation::new(Algorithm::RS256))
61///     .build()
62///     .unwrap();
63///
64/// // Spawn background refresh task
65/// let decoder_clone = decoder.clone();
66/// tokio::spawn(async move {
67///     decoder_clone.refresh_keys_periodically().await;
68/// });
69/// ```
70#[derive(Clone, Builder)]
71pub struct RemoteJwksDecoder {
72    /// The JWKS endpoint URL
73    jwks_url: String,
74    /// Configuration for caching and retry behavior
75    #[builder(default = "RemoteJwksDecoderConfig::default()")]
76    config: RemoteJwksDecoderConfig,
77    /// Thread-safe cache mapping key IDs to decoding keys
78    #[builder(default = "Arc::new(DashMap::new())")]
79    keys_cache: Arc<DashMap<String, DecodingKey>>,
80    /// JWT validation settings
81    validation: Validation,
82    /// HTTP client for fetching JWKS
83    #[builder(default = "reqwest::Client::new()")]
84    client: reqwest::Client,
85    /// Notification for initialization completion
86    #[builder(default = "Arc::new(Notify::new())")]
87    initialized: Arc<Notify>,
88}
89
90impl RemoteJwksDecoder {
91    /// Creates a new `RemoteJwksDecoder` with the given JWKS URL and default settings.
92    ///
93    /// # Errors
94    ///
95    /// Returns `Error::Configuration` if the builder fails to construct the decoder.
96    pub fn new(jwks_url: String) -> Result<Self, Error> {
97        RemoteJwksDecoderBuilder::default()
98            .jwks_url(jwks_url)
99            .build()
100            .map_err(|e| Error::Configuration(e.to_string()))
101    }
102
103    /// Creates a new builder for configuring a remote JWKS decoder.
104    pub fn builder() -> RemoteJwksDecoderBuilder {
105        RemoteJwksDecoderBuilder::default()
106    }
107
108    /// Refreshes the JWKS cache with retry logic.
109    ///
110    /// Retries up to `config.retry_count` times, waiting `config.backoff` duration between attempts.
111    ///
112    /// # Errors
113    ///
114    /// Returns `Error::JwksRefresh` if all retry attempts fail.
115    async fn refresh_keys(&self) -> Result<(), Error> {
116        let max_attempts = self.config.retry_count;
117        let mut attempt = 0;
118        let mut err = None;
119
120        while attempt < max_attempts {
121            match self.refresh_keys_once().await {
122                Ok(_) => return Ok(()),
123                Err(e) => {
124                    err = Some(e);
125                    attempt += 1;
126                    tokio::time::sleep(self.config.backoff).await;
127                }
128            }
129        }
130
131        Err(Error::JwksRefresh {
132            message: "Failed to refresh JWKS after multiple attempts".to_string(),
133            retry_count: max_attempts,
134            source: err.map(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>),
135        })
136    }
137
138    /// Fetches JWKS from the remote URL and updates the cache.
139    ///
140    /// Parses all keys before updating the cache to ensure atomicity.
141    async fn refresh_keys_once(&self) -> Result<(), Error> {
142        let jwks = self
143            .client
144            .get(&self.jwks_url)
145            .send()
146            .await?
147            .json::<JwkSet>()
148            .await?;
149
150        // Parse all keys first before clearing cache
151        let mut new_keys = Vec::new();
152        for jwk in jwks.keys.iter() {
153            let key_id = jwk.common.key_id.to_owned();
154            let key = DecodingKey::from_jwk(jwk).map_err(Error::Jwt)?;
155            new_keys.push((key_id.unwrap_or_default(), key));
156        }
157
158        // Only clear and update cache after all keys parsed successfully
159        self.keys_cache.clear();
160        for (kid, key) in new_keys {
161            self.keys_cache.insert(kid, key);
162        }
163
164        // Notify waiters after the first successful fetch
165        self.initialized.notify_waiters();
166
167        Ok(())
168    }
169
170    /// Runs an infinite loop that periodically refreshes the JWKS cache.
171    ///
172    /// This method never returns and should be spawned in a background task using `tokio::spawn`.
173    /// Refresh failures are logged, and the decoder continues using stale keys until the next
174    /// successful refresh.
175    ///
176    /// # Example
177    ///
178    /// ```ignore
179    /// let decoder = RemoteJwksDecoder::builder()
180    ///     .jwks_url("https://example.com/.well-known/jwks.json".to_string())
181    ///     .build()
182    ///     .unwrap();
183    ///
184    /// let decoder_clone = decoder.clone();
185    /// tokio::spawn(async move {
186    ///     decoder_clone.refresh_keys_periodically().await;
187    /// });
188    /// ```
189    pub async fn refresh_keys_periodically(&self) {
190        loop {
191            tracing::info!("Refreshing JWKS");
192            match self.refresh_keys().await {
193                Ok(_) => {}
194                Err(err) => {
195                    // log the error and continue with stale keys
196                    tracing::error!(
197                        "Failed to refresh JWKS after {} attempts: {:?}",
198                        self.config.retry_count,
199                        err
200                    );
201                }
202            }
203            tokio::time::sleep(self.config.cache_duration).await;
204        }
205    }
206
207    /// Ensures the key cache is initialized before attempting token validation.
208    ///
209    /// If the cache is empty, waits for the background refresh task to complete
210    /// the first successful key fetch.
211    async fn ensure_initialized(&self) {
212        // If we already have keys, we're already initialized
213        if !self.keys_cache.is_empty() {
214            tracing::trace!("Key store already initialised, continuing.");
215            return;
216        }
217
218        // If direct initialization failed, fall back to waiting for the background task
219        tracing::trace!("Waiting for background initialization to complete");
220        self.initialized.notified().await;
221    }
222}
223
224#[async_trait]
225impl<T> JwtDecoder<T> for RemoteJwksDecoder
226where
227    T: for<'de> DeserializeOwned,
228{
229    async fn decode(&self, token: &str) -> Result<TokenData<T>, Error> {
230        self.ensure_initialized().await;
231        let header = jsonwebtoken::decode_header(token)?;
232        let target_kid = header.kid;
233
234        if let Some(ref kid) = target_kid {
235            if let Some(key) = self.keys_cache.get(kid) {
236                return Ok(jsonwebtoken::decode::<T>(
237                    token,
238                    key.value(),
239                    &self.validation,
240                )?);
241            }
242            return Err(Error::KeyNotFound(Some(kid.clone())));
243        }
244        return Err(Error::KeyNotFound(None));
245    }
246}