// axum_jwt_auth/remote.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use dashmap::DashMap;
5use derive_builder::Builder;
6use jsonwebtoken::{jwk::JwkSet, DecodingKey, TokenData, Validation};
7use serde::de::DeserializeOwned;
8use tokio::sync::Notify;
9
10use crate::{Error, JwtDecoder};
11
/// Default for [`RemoteJwksDecoderConfig::cache_duration`]: one hour.
const DEFAULT_CACHE_DURATION: std::time::Duration = std::time::Duration::from_secs(3_600);
/// Default for [`RemoteJwksDecoderConfig::retry_count`]: three fetch attempts.
const DEFAULT_RETRY_COUNT: usize = 3;
/// Default for [`RemoteJwksDecoderConfig::backoff`]: one second.
const DEFAULT_BACKOFF: std::time::Duration = std::time::Duration::from_millis(1_000);
15
16#[derive(Debug, Clone, Builder)]
17pub struct RemoteJwksDecoderConfig {
18    /// How long to cache the JWKS keys for
19    #[builder(default = "DEFAULT_CACHE_DURATION")]
20    pub cache_duration: std::time::Duration,
21    /// How many times to retry fetching the JWKS keys if it fails
22    #[builder(default = "DEFAULT_RETRY_COUNT")]
23    pub retry_count: usize,
24    /// How long to wait before retrying fetching the JWKS keys
25    #[builder(default = "DEFAULT_BACKOFF")]
26    pub backoff: std::time::Duration,
27}
28
29impl Default for RemoteJwksDecoderConfig {
30    fn default() -> Self {
31        Self {
32            cache_duration: DEFAULT_CACHE_DURATION,
33            retry_count: DEFAULT_RETRY_COUNT,
34            backoff: DEFAULT_BACKOFF,
35        }
36    }
37}
38
39impl RemoteJwksDecoderConfig {
40    /// Creates a new [`RemoteJwksDecoderConfigBuilder`].
41    ///
42    /// This is a convenience method to create a builder for the config.
43    pub fn builder() -> RemoteJwksDecoderConfigBuilder {
44        RemoteJwksDecoderConfigBuilder::default()
45    }
46}
47
48/// Remote JWKS decoder.
49/// It fetches the JWKS from the given URL and caches it for the given duration.
50/// It uses the cached JWKS to decode the JWT tokens.
51#[derive(Clone, Builder)]
52pub struct RemoteJwksDecoder {
53    /// The URL to fetch the JWKS from
54    jwks_url: String,
55    /// The configuration for the decoder
56    #[builder(default = "RemoteJwksDecoderConfig::default()")]
57    config: RemoteJwksDecoderConfig,
58    /// The cache for the JWKS keys
59    #[builder(default = "Arc::new(DashMap::new())")]
60    keys_cache: Arc<DashMap<String, DecodingKey>>,
61    /// The validation settings for the JWT tokens
62    validation: Validation,
63    /// The HTTP client to use for fetching the JWKS
64    #[builder(default = "reqwest::Client::new()")]
65    client: reqwest::Client,
66    /// The initialized flag
67    #[builder(default = "Arc::new(Notify::new())")]
68    initialized: Arc<Notify>,
69}
70
71impl RemoteJwksDecoder {
72    /// Creates a new [`RemoteJwksDecoder`] with the given JWKS URL.
73    pub fn new(jwks_url: String) -> Result<Self, Error> {
74        RemoteJwksDecoderBuilder::default()
75            .jwks_url(jwks_url)
76            .build()
77            .map_err(|e| Error::Configuration(e.to_string()))
78    }
79
80    /// Creates a new [`RemoteJwksDecoderBuilder`].
81    ///
82    /// This is a convenience method to create a builder for the decoder.
83    pub fn builder() -> RemoteJwksDecoderBuilder {
84        RemoteJwksDecoderBuilder::default()
85    }
86
87    /// Refreshes the JWKS cache.
88    /// It retries the refresh up to [`RemoteJwksDecoderConfig::retry_count`] times,
89    /// waiting [`RemoteJwksDecoderConfig::backoff`] seconds between attempts.
90    /// If it fails after all attempts, it returns the error.
91    async fn refresh_keys(&self) -> Result<(), Error> {
92        let max_attempts = self.config.retry_count;
93        let mut attempt = 0;
94        let mut err = None;
95
96        while attempt < max_attempts {
97            match self.refresh_keys_once().await {
98                Ok(_) => return Ok(()),
99                Err(e) => {
100                    err = Some(e);
101                    attempt += 1;
102                    tokio::time::sleep(self.config.backoff).await;
103                }
104            }
105        }
106
107        Err(Error::JwksRefresh {
108            message: "Failed to refresh JWKS after multiple attempts".to_string(),
109            retry_count: max_attempts,
110            source: err.map(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>),
111        })
112    }
113
114    /// Refreshes the JWKS cache once.
115    /// It fetches the JWKS from the given URL and caches the keys.
116    async fn refresh_keys_once(&self) -> Result<(), Error> {
117        let jwks = self
118            .client
119            .get(&self.jwks_url)
120            .send()
121            .await?
122            .json::<JwkSet>()
123            .await?;
124
125        // Parse all keys first before clearing cache
126        let mut new_keys = Vec::new();
127        for jwk in jwks.keys.iter() {
128            let key_id = jwk.common.key_id.to_owned();
129            let key = DecodingKey::from_jwk(jwk).map_err(Error::Jwt)?;
130            new_keys.push((key_id.unwrap_or_default(), key));
131        }
132
133        // Only clear and update cache after all keys parsed successfully
134        self.keys_cache.clear();
135        for (kid, key) in new_keys {
136            self.keys_cache.insert(kid, key);
137        }
138
139        // Notify waiters after the first successful fetch
140        self.initialized.notify_waiters();
141
142        Ok(())
143    }
144
145    /// Refreshes the JWKS cache periodically.
146    /// It runs in a loop and never returns, so it should be run in a separate tokio task
147    /// using [`tokio::spawn`]. If the JWKS refresh fails after multiple attemps,
148    /// it logs the error and continues. The decoder will use the stale keys until the next refresh
149    /// succeeds or the universe ends, whichever comes first.
150    pub async fn refresh_keys_periodically(&self) {
151        loop {
152            tracing::info!("Refreshing JWKS");
153            match self.refresh_keys().await {
154                Ok(_) => {}
155                Err(err) => {
156                    // log the error and continue with stale keys
157                    tracing::error!(
158                        "Failed to refresh JWKS after {} attempts: {:?}",
159                        self.config.retry_count,
160                        err
161                    );
162                }
163            }
164            tokio::time::sleep(self.config.cache_duration).await;
165        }
166    }
167
168    /// Ensures keys are available before proceeding
169    async fn ensure_initialized(&self) {
170        self.initialized.notified().await;
171    }
172}
173
174#[async_trait]
175impl<T> JwtDecoder<T> for RemoteJwksDecoder
176where
177    T: for<'de> DeserializeOwned,
178{
179    async fn decode(&self, token: &str) -> Result<TokenData<T>, Error> {
180        self.ensure_initialized().await;
181        let header = jsonwebtoken::decode_header(token)?;
182        let target_kid = header.kid;
183
184        if let Some(ref kid) = target_kid {
185            if let Some(key) = self.keys_cache.get(kid) {
186                return Ok(jsonwebtoken::decode::<T>(
187                    token,
188                    key.value(),
189                    &self.validation,
190                )?);
191            }
192            return Err(Error::KeyNotFound(Some(kid.clone())));
193        }
194        return Err(Error::KeyNotFound(None));
195    }
196}