Skip to main content

jsonwebtoken_jwks_cache/cache/
mod.rs

1#[cfg(test)]
2mod test;
3
4use super::pem_set::PemMap;
5use core::future::Future;
6use jsonwebtoken::jwk::JwkSet;
7use spin::RwLock;
8use std::sync::Arc;
9use std::time::{Duration, SystemTime};
10use tokio::sync::Notify;
11use url::Url;
12
13fn get_expiration(now: SystemTime, req: &reqwest::Request, res: &reqwest::Response) -> SystemTime {
14    now + http_cache_semantics::CachePolicy::new(req, res).time_to_live(now)
15}
16
17pub trait JwksSource: Clone + Send + Sync + 'static {
18    type Error: core::fmt::Debug + Send + Sync + 'static;
19
20    fn get_jwks_within_deadline(
21        self,
22        url: Url,
23        as_pkeys: bool,
24        now: SystemTime,
25        deadline: Duration,
26    ) -> impl Future<Output = Result<(JwkSet, SystemTime), RequestError<Self::Error>>>
27    + Send
28    + Sync
29    + 'static {
30        async move {
31            let result = tokio::time::timeout(deadline, self.get_jwks(url, as_pkeys, now)).await;
32
33            match result {
34                Ok(res) => res.map_err(RequestError::Client),
35                Err(_) => Err(RequestError::Timeout),
36            }
37        }
38    }
39
40    fn get_jwks(
41        self,
42        url: Url,
43        as_pkeys: bool,
44        now: SystemTime,
45    ) -> impl Future<Output = Result<(JwkSet, SystemTime), Self::Error>> + Send + Sync + 'static;
46}
47
48impl JwksSource for reqwest::Client {
49    type Error = reqwest::Error;
50
51    async fn get_jwks(
52        self,
53        url: Url,
54        as_pkeys: bool,
55        now: SystemTime,
56    ) -> Result<(JwkSet, SystemTime), Self::Error> {
57        let req = reqwest::Request::new(http::Method::GET, url.clone());
58        let res = reqwest::Client::builder()
59            .build()?
60            .execute(
61                // safety: because we control the request creation we can ensure its not a stateful stream and can be copied at all times
62                req.try_clone().expect("Request should be always copyable"),
63            )
64            .await?
65            .error_for_status()?;
66
67        let expiration = get_expiration(now, &req, &res);
68        let jwks = if as_pkeys {
69            res.json::<PemMap>().await?.into_rsa_jwk_set()
70        } else {
71            res.json::<JwkSet>().await?
72        };
73
74        Ok((jwks, expiration))
75    }
76}
77
78/// State machine of the JWKS cache
79#[derive(Debug, Clone, Default)]
80enum JWKSCache {
81    /// There is no data in cache, this is initial state
82    #[default]
83    Empty,
84    /// Cache is empty or expired, fetching of new content is ongoing.
85    /// Contains handle for awaiting for fetching to conclude
86    Fetching(Arc<Notify>),
87    /// Cache is valid, but content is being refreshed in the background
88    Refreshing { expires: SystemTime, jwks: JwkSet },
89    /// Cache is populated, but needs to be revalidated before use
90    Fetched { expires: SystemTime, jwks: JwkSet },
91}
92
93#[derive(Debug, thiserror::Error)]
94pub enum RequestError<E: core::fmt::Debug> {
95    #[error("Client error: {0}")]
96    Client(E),
97    #[error("Timeout for request completion reached")]
98    Timeout,
99}
100
101impl<T: core::fmt::Debug> From<T> for RequestError<T> {
102    fn from(value: T) -> Self {
103        Self::Client(value)
104    }
105}
106
/// Retry and timing policy applied to every key set download.
#[derive(Debug, Clone, Copy)]
pub struct TimeoutSpec {
    /// Number of additional attempts made after a failed one
    /// (a failure being either a timeout or a client error).
    pub retries: u8,
    /// Time budget of a single attempt; once exceeded the attempt counts as failed.
    pub retry_after: Duration,
    /// Pause inserted between a failed attempt and the next one.
    pub backoff: Duration,
    /// Overall time budget for all attempts together; exceeding it is a failure.
    pub deadline: Duration,
}
118
119impl Default for TimeoutSpec {
120    fn default() -> Self {
121        Self {
122            retries: 0,
123            retry_after: Duration::from_secs(10),
124            backoff: Duration::ZERO,
125            deadline: Duration::from_secs(10),
126        }
127    }
128}
129
130#[derive(Clone)]
131pub struct CachedJWKS<S> {
132    jwks_url: Url,
133    pkeys: bool,
134    update_period: Duration,
135    timeout_spec: TimeoutSpec,
136    cache_state: Arc<RwLock<JWKSCache>>,
137    source: S,
138}
139
140impl CachedJWKS<reqwest::Client> {
141    pub fn new(
142        jwks_url: Url,
143        // Period when to refresh in the background before expiration period
144        update_period: Duration,
145        timeout_spec: TimeoutSpec,
146    ) -> Result<Self, reqwest::Error> {
147        Ok(Self::from_source(
148            jwks_url,
149            false,
150            update_period,
151            timeout_spec,
152            reqwest::Client::builder().build()?,
153        ))
154    }
155
156    /// Load keys as a map of RSA pub keys
157    pub fn new_rsa_pkeys(
158        pkeys_url: Url,
159        // Period when to refresh in the background before expiration period
160        update_period: Duration,
161        timeout_spec: TimeoutSpec,
162    ) -> Result<Self, reqwest::Error> {
163        Ok(Self::from_source(
164            pkeys_url,
165            true,
166            update_period,
167            timeout_spec,
168            reqwest::Client::builder().build()?,
169        ))
170    }
171}
172
173impl<S: JwksSource> CachedJWKS<S> {
174    pub fn from_source(
175        jwks_url: Url,
176        pkeys: bool,
177        update_period: Duration,
178        timeout_spec: TimeoutSpec,
179        source: S,
180    ) -> Self {
181        assert!(
182            update_period > timeout_spec.deadline,
183            "Update period should be greater than timeout deadline"
184        );
185
186        Self {
187            jwks_url,
188            pkeys,
189            update_period,
190            timeout_spec,
191            cache_state: Default::default(),
192            source,
193        }
194    }
195
196    async fn request(
197        source: S,
198        url: Url,
199        as_pkeys: bool,
200        now: SystemTime,
201        timeout: TimeoutSpec,
202    ) -> Result<(JwkSet, SystemTime), RequestError<S::Error>> {
203        let perform = async {
204            let mut retries = 0u8;
205            loop {
206                match source
207                    .clone()
208                    .get_jwks_within_deadline(url.clone(), as_pkeys, now, timeout.retry_after)
209                    .await
210                {
211                    Ok(res) => return Ok(res),
212                    Err(err) => {
213                        if retries == timeout.retries {
214                            return Err(err);
215                        } else {
216                            retries += 1;
217                            tokio::time::sleep(timeout.backoff).await;
218                            continue;
219                        }
220                    }
221                }
222            }
223        };
224
225        tokio::time::timeout(timeout.deadline, perform)
226            .await
227            .map_err(|_| RequestError::Timeout)?
228    }
229
230    async fn update_notify(
231        &self,
232        now: SystemTime,
233    ) -> Result<Option<JwkSet>, RequestError<S::Error>> {
234        let notifier = if let Some(mut cached_state) = self.cache_state.try_write() {
235            let notifier = Arc::new(Notify::new());
236
237            *cached_state = JWKSCache::Fetching(notifier.clone());
238
239            notifier
240        } else {
241            return Ok(None);
242        };
243
244        let result = Self::request(
245            self.source.clone(),
246            self.jwks_url.clone(),
247            self.pkeys,
248            now,
249            self.timeout_spec,
250        )
251        .await;
252
253        let result = {
254            let mut cached_state = self.cache_state.write();
255
256            match result {
257                Ok((jwks, expires)) => {
258                    *cached_state = JWKSCache::Fetched {
259                        expires,
260                        jwks: jwks.clone(),
261                    };
262
263                    Ok(Some(jwks))
264                }
265                // Could not fetch in time, let follow up request try again later
266                Err(err) => {
267                    *cached_state = JWKSCache::Empty;
268
269                    Err(err)
270                }
271            }
272        };
273
274        notifier.notify_waiters();
275
276        result
277    }
278
279    /// Trigger refresh of JWKS in the background when cached JWKS is stil valid but about to expire,
280    /// if process dies then we do not care if this completes
281    fn update_in_background(&self, now: SystemTime, old_jwks: JwkSet, old_expires: SystemTime) {
282        {
283            let mut cache_state = self.cache_state.write();
284
285            *cache_state = JWKSCache::Refreshing {
286                expires: old_expires,
287                jwks: old_jwks,
288            };
289        }
290
291        let cache_state = self.cache_state.clone();
292        let jwks_url = self.jwks_url.clone();
293        let timeout_spec = self.timeout_spec;
294        let source = self.source.clone();
295        let as_pkeys = self.pkeys;
296
297        tokio::spawn(async move {
298            let result = Self::request(source, jwks_url, as_pkeys, now, timeout_spec).await;
299
300            if let Err(err) = &result {
301                log::error!("Error while refreshing JWKS in the background: {err:?}");
302            }
303
304            let mut cache_state = cache_state.write();
305
306            let new_state = match cache_state.to_owned() {
307                JWKSCache::Empty => match result {
308                    Ok((jwks, expires)) => JWKSCache::Fetched { expires, jwks },
309                    Err(_) => JWKSCache::Empty,
310                },
311                JWKSCache::Fetching(notify) => {
312                    if let Ok((jwks, expires)) = result {
313                        notify.notify_waiters();
314                        JWKSCache::Fetched { expires, jwks }
315                    } else {
316                        JWKSCache::Fetching(notify)
317                    }
318                }
319                JWKSCache::Refreshing { expires, jwks } => {
320                    if let Ok((jwks, expires)) = result {
321                        JWKSCache::Fetched { expires, jwks }
322                    } else {
323                        JWKSCache::Refreshing { expires, jwks }
324                    }
325                }
326                JWKSCache::Fetched { expires, jwks } => {
327                    if let Ok((jwks, expires)) = result {
328                        JWKSCache::Fetched { expires, jwks }
329                    } else {
330                        JWKSCache::Refreshing { expires, jwks }
331                    }
332                }
333            };
334
335            *cache_state = new_state;
336        });
337    }
338
339    pub async fn get(&self) -> Result<JwkSet, RequestError<S::Error>> {
340        let now = SystemTime::now();
341        loop {
342            let cached_state = self.cache_state.read().clone();
343
344            match cached_state {
345                JWKSCache::Empty => {
346                    if let Some(jwks) = self.update_notify(now).await? {
347                        return Ok(jwks);
348                    } else {
349                        // state changed since reading it, reload
350                        continue;
351                    }
352                }
353                JWKSCache::Fetching(notifier) => {
354                    notifier.notified().await;
355
356                    // we got notified about change in state, reload
357                    continue;
358                }
359                JWKSCache::Refreshing { expires: _, jwks } => {
360                    // Refresh mechanism should guarantee it will change the state before cache is no longer valid
361                    return Ok(jwks);
362                }
363                JWKSCache::Fetched { expires, jwks } => {
364                    if now >= expires {
365                        if let Some(jwks) = self.update_notify(now).await? {
366                            return Ok(jwks);
367                        } else {
368                            // state changed since reading it, reload
369                            continue;
370                        }
371                    }
372
373                    if now + self.update_period >= expires {
374                        self.update_in_background(now, jwks.clone(), expires);
375                    }
376
377                    return Ok(jwks);
378                }
379            }
380        }
381    }
382}