axum_jwt_auth/remote.rs
1use std::sync::Arc;
2
3use async_trait::async_trait;
4use dashmap::DashMap;
5use derive_builder::Builder;
6use jsonwebtoken::{jwk::JwkSet, DecodingKey, TokenData, Validation};
7use serde::de::DeserializeOwned;
8use tokio::sync::Notify;
9
10use crate::{Error, JwtDecoder};
11
/// Default lifetime of cached JWKS keys before the next refresh: one hour.
const DEFAULT_CACHE_DURATION: std::time::Duration = std::time::Duration::from_secs(3_600);
/// Default total number of fetch attempts made in each refresh cycle.
const DEFAULT_RETRY_COUNT: usize = 3;
/// Default pause between consecutive fetch attempts: one second.
const DEFAULT_BACKOFF: std::time::Duration = std::time::Duration::from_millis(1_000);
15
16/// Configuration for remote JWKS fetching and caching behavior.
17#[derive(Debug, Clone, Builder)]
18pub struct RemoteJwksDecoderConfig {
19 /// Duration to cache JWKS keys before refreshing (default: 1 hour)
20 #[builder(default = "DEFAULT_CACHE_DURATION")]
21 pub cache_duration: std::time::Duration,
22 /// Number of retry attempts when fetching JWKS fails (default: 3)
23 #[builder(default = "DEFAULT_RETRY_COUNT")]
24 pub retry_count: usize,
25 /// Delay between retry attempts (default: 1 second)
26 #[builder(default = "DEFAULT_BACKOFF")]
27 pub backoff: std::time::Duration,
28}
29
30impl Default for RemoteJwksDecoderConfig {
31 fn default() -> Self {
32 Self {
33 cache_duration: DEFAULT_CACHE_DURATION,
34 retry_count: DEFAULT_RETRY_COUNT,
35 backoff: DEFAULT_BACKOFF,
36 }
37 }
38}
39
40impl RemoteJwksDecoderConfig {
41 /// Creates a new builder for configuring JWKS fetching behavior.
42 pub fn builder() -> RemoteJwksDecoderConfigBuilder {
43 RemoteJwksDecoderConfigBuilder::default()
44 }
45}
46
47/// JWT decoder that fetches and caches keys from a remote JWKS endpoint.
48///
49/// Automatically fetches JWKS from the specified URL, caches keys by their `kid` (key ID),
50/// and periodically refreshes them. Includes retry logic for robustness.
51///
52/// # Example
53///
54/// ```ignore
55/// use axum_jwt_auth::RemoteJwksDecoder;
56/// use jsonwebtoken::{Algorithm, Validation};
57///
58/// let decoder = RemoteJwksDecoder::builder()
59/// .jwks_url("https://example.com/.well-known/jwks.json".to_string())
60/// .validation(Validation::new(Algorithm::RS256))
61/// .build()
62/// .unwrap();
63///
64/// // Spawn background refresh task
65/// let decoder_clone = decoder.clone();
66/// tokio::spawn(async move {
67/// decoder_clone.refresh_keys_periodically().await;
68/// });
69/// ```
70#[derive(Clone, Builder)]
71pub struct RemoteJwksDecoder {
72 /// The JWKS endpoint URL
73 jwks_url: String,
74 /// Configuration for caching and retry behavior
75 #[builder(default = "RemoteJwksDecoderConfig::default()")]
76 config: RemoteJwksDecoderConfig,
77 /// Thread-safe cache mapping key IDs to decoding keys
78 #[builder(default = "Arc::new(DashMap::new())")]
79 keys_cache: Arc<DashMap<String, DecodingKey>>,
80 /// JWT validation settings
81 validation: Validation,
82 /// HTTP client for fetching JWKS
83 #[builder(default = "reqwest::Client::new()")]
84 client: reqwest::Client,
85 /// Notification for initialization completion
86 #[builder(default = "Arc::new(Notify::new())")]
87 initialized: Arc<Notify>,
88}
89
90impl RemoteJwksDecoder {
91 /// Creates a new `RemoteJwksDecoder` with the given JWKS URL and default settings.
92 ///
93 /// # Errors
94 ///
95 /// Returns `Error::Configuration` if the builder fails to construct the decoder.
96 pub fn new(jwks_url: String) -> Result<Self, Error> {
97 RemoteJwksDecoderBuilder::default()
98 .jwks_url(jwks_url)
99 .build()
100 .map_err(|e| Error::Configuration(e.to_string()))
101 }
102
103 /// Creates a new builder for configuring a remote JWKS decoder.
104 pub fn builder() -> RemoteJwksDecoderBuilder {
105 RemoteJwksDecoderBuilder::default()
106 }
107
108 /// Refreshes the JWKS cache with retry logic.
109 ///
110 /// Retries up to `config.retry_count` times, waiting `config.backoff` duration between attempts.
111 ///
112 /// # Errors
113 ///
114 /// Returns `Error::JwksRefresh` if all retry attempts fail.
115 async fn refresh_keys(&self) -> Result<(), Error> {
116 let max_attempts = self.config.retry_count;
117 let mut attempt = 0;
118 let mut err = None;
119
120 while attempt < max_attempts {
121 match self.refresh_keys_once().await {
122 Ok(_) => return Ok(()),
123 Err(e) => {
124 err = Some(e);
125 attempt += 1;
126 tokio::time::sleep(self.config.backoff).await;
127 }
128 }
129 }
130
131 Err(Error::JwksRefresh {
132 message: "Failed to refresh JWKS after multiple attempts".to_string(),
133 retry_count: max_attempts,
134 source: err.map(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>),
135 })
136 }
137
138 /// Fetches JWKS from the remote URL and updates the cache.
139 ///
140 /// Parses all keys before updating the cache to ensure atomicity.
141 async fn refresh_keys_once(&self) -> Result<(), Error> {
142 let jwks = self
143 .client
144 .get(&self.jwks_url)
145 .send()
146 .await?
147 .json::<JwkSet>()
148 .await?;
149
150 // Parse all keys first before clearing cache
151 let mut new_keys = Vec::new();
152 for jwk in jwks.keys.iter() {
153 let key_id = jwk.common.key_id.to_owned();
154 let key = DecodingKey::from_jwk(jwk).map_err(Error::Jwt)?;
155 new_keys.push((key_id.unwrap_or_default(), key));
156 }
157
158 // Only clear and update cache after all keys parsed successfully
159 self.keys_cache.clear();
160 for (kid, key) in new_keys {
161 self.keys_cache.insert(kid, key);
162 }
163
164 // Notify waiters after the first successful fetch
165 self.initialized.notify_waiters();
166
167 Ok(())
168 }
169
170 /// Runs an infinite loop that periodically refreshes the JWKS cache.
171 ///
172 /// This method never returns and should be spawned in a background task using `tokio::spawn`.
173 /// Refresh failures are logged, and the decoder continues using stale keys until the next
174 /// successful refresh.
175 ///
176 /// # Example
177 ///
178 /// ```ignore
179 /// let decoder = RemoteJwksDecoder::builder()
180 /// .jwks_url("https://example.com/.well-known/jwks.json".to_string())
181 /// .build()
182 /// .unwrap();
183 ///
184 /// let decoder_clone = decoder.clone();
185 /// tokio::spawn(async move {
186 /// decoder_clone.refresh_keys_periodically().await;
187 /// });
188 /// ```
189 pub async fn refresh_keys_periodically(&self) {
190 loop {
191 tracing::info!("Refreshing JWKS");
192 match self.refresh_keys().await {
193 Ok(_) => {}
194 Err(err) => {
195 // log the error and continue with stale keys
196 tracing::error!(
197 "Failed to refresh JWKS after {} attempts: {:?}",
198 self.config.retry_count,
199 err
200 );
201 }
202 }
203 tokio::time::sleep(self.config.cache_duration).await;
204 }
205 }
206
207 /// Ensures the key cache is initialized before attempting token validation.
208 ///
209 /// If the cache is empty, waits for the background refresh task to complete
210 /// the first successful key fetch.
211 async fn ensure_initialized(&self) {
212 // If we already have keys, we're already initialized
213 if !self.keys_cache.is_empty() {
214 tracing::trace!("Key store already initialised, continuing.");
215 return;
216 }
217
218 // If direct initialization failed, fall back to waiting for the background task
219 tracing::trace!("Waiting for background initialization to complete");
220 self.initialized.notified().await;
221 }
222}
223
224#[async_trait]
225impl<T> JwtDecoder<T> for RemoteJwksDecoder
226where
227 T: for<'de> DeserializeOwned,
228{
229 async fn decode(&self, token: &str) -> Result<TokenData<T>, Error> {
230 self.ensure_initialized().await;
231 let header = jsonwebtoken::decode_header(token)?;
232 let target_kid = header.kid;
233
234 if let Some(ref kid) = target_kid {
235 if let Some(key) = self.keys_cache.get(kid) {
236 return Ok(jsonwebtoken::decode::<T>(
237 token,
238 key.value(),
239 &self.validation,
240 )?);
241 }
242 return Err(Error::KeyNotFound(Some(kid.clone())));
243 }
244 return Err(Error::KeyNotFound(None));
245 }
246}