jsonwebtoken_jwks_cache/cache/
mod.rs

1#[cfg(test)]
2mod test;
3
4use core::future::Future;
5use jsonwebtoken::jwk::JwkSet;
6use spin::RwLock;
7use std::sync::Arc;
8use std::time::{Duration, SystemTime};
9use tokio::sync::Notify;
10use url::Url;
11
12fn get_expiration(now: SystemTime, req: &reqwest::Request, res: &reqwest::Response) -> SystemTime {
13    now + http_cache_semantics::CachePolicy::new(req, res).time_to_live(now)
14}
15
16pub trait JwksSource: Clone + Send + Sync + 'static {
17    type Error: core::fmt::Debug + Send + Sync + 'static;
18
19    fn get_jwks_within_deadline(
20        self,
21        url: Url,
22        now: SystemTime,
23        deadline: Duration,
24    ) -> impl Future<Output = Result<(JwkSet, SystemTime), RequestError<Self::Error>>>
25    + Send
26    + Sync
27    + 'static {
28        async move {
29            let result = tokio::time::timeout(deadline, self.get_jwks(url, now)).await;
30
31            match result {
32                Ok(res) => res.map_err(RequestError::Client),
33                Err(_) => Err(RequestError::Timeout),
34            }
35        }
36    }
37
38    fn get_jwks(
39        self,
40        url: Url,
41        now: SystemTime,
42    ) -> impl Future<Output = Result<(JwkSet, SystemTime), Self::Error>> + Send + Sync + 'static;
43}
44
45impl JwksSource for reqwest::Client {
46    type Error = reqwest::Error;
47
48    async fn get_jwks(
49        self,
50        url: Url,
51        now: SystemTime,
52    ) -> Result<(JwkSet, SystemTime), Self::Error> {
53        let req = reqwest::Request::new(http::Method::GET, url.clone());
54        let res = reqwest::Client::builder()
55            .build()?
56            .execute(
57                // safety: because we control the request creation we can ensure its not a stateful stream and can be copied at all times
58                req.try_clone().expect("Request should be always copyable"),
59            )
60            .await?
61            .error_for_status()?;
62
63        let expiration = get_expiration(now, &req, &res);
64        let jwks: JwkSet = res.json().await?;
65
66        Ok((jwks, expiration))
67    }
68}
69
70/// State machine of the JWKS cache
71#[derive(Debug, Clone, Default)]
72enum JWKSCache {
73    /// There is no data in cache, this is initial state
74    #[default]
75    Empty,
76    /// Cache is empty or expired, fetching of new content is ongoing.
77    /// Contains handle for awaiting for fetching to conclude
78    Fetching(Arc<Notify>),
79    /// Cache is valid, but content is being refreshed in the background
80    Refreshing { expires: SystemTime, jwks: JwkSet },
81    /// Cache is populated, but needs to be revalidated before use
82    Fetched { expires: SystemTime, jwks: JwkSet },
83}
84
85#[derive(Debug, thiserror::Error)]
86pub enum RequestError<E: core::fmt::Debug> {
87    #[error("Client error: {0}")]
88    Client(E),
89    #[error("Timeout for request completion reached")]
90    Timeout,
91}
92
93impl<T: core::fmt::Debug> From<T> for RequestError<T> {
94    fn from(value: T) -> Self {
95        Self::Client(value)
96    }
97}
98
/// Time limits and retry policy applied to each JWKS fetch.
#[derive(Debug, Clone, Copy)]
pub struct TimeoutSpec {
    /// Number of extra attempts made after a failed one (timeout or client error).
    pub retries: u8,
    /// Time limit for a single attempt; an attempt exceeding it counts as failed.
    pub retry_after: Duration,
    /// Pause between a failed attempt and the next one.
    pub backoff: Duration,
    /// Overall time limit for all attempts, back-off included, before the fetch fails.
    pub deadline: Duration,
}

impl Default for TimeoutSpec {
    /// A single attempt of at most 10 seconds: no retries, no back-off, 10-second deadline.
    fn default() -> Self {
        let ten_seconds = Duration::from_secs(10);
        TimeoutSpec {
            retries: 0,
            retry_after: ten_seconds,
            backoff: Duration::ZERO,
            deadline: ten_seconds,
        }
    }
}
121
122#[derive(Clone)]
123pub struct CachedJWKS<S> {
124    jwks_url: Url,
125    update_period: Duration,
126    timeout_spec: TimeoutSpec,
127    cache_state: Arc<RwLock<JWKSCache>>,
128    source: S,
129}
130
131impl CachedJWKS<reqwest::Client> {
132    pub fn new(
133        jwks_url: Url,
134        // Period when to refresh in the background before expiration period
135        update_period: Duration,
136        timeout_spec: TimeoutSpec,
137    ) -> Result<Self, reqwest::Error> {
138        Ok(Self::from_source(
139            jwks_url,
140            update_period,
141            timeout_spec,
142            reqwest::Client::builder().build()?,
143        ))
144    }
145}
146
147impl<S: JwksSource> CachedJWKS<S> {
148    pub fn from_source(
149        jwks_url: Url,
150        update_period: Duration,
151        timeout_spec: TimeoutSpec,
152        source: S,
153    ) -> Self {
154        assert!(
155            update_period > timeout_spec.deadline,
156            "Update period should be greater than timeout deadline"
157        );
158
159        Self {
160            jwks_url,
161            update_period,
162            timeout_spec,
163            cache_state: Default::default(),
164            source,
165        }
166    }
167
168    async fn request(
169        source: S,
170        url: Url,
171        now: SystemTime,
172        timeout: TimeoutSpec,
173    ) -> Result<(JwkSet, SystemTime), RequestError<S::Error>> {
174        let perform = async {
175            let mut retries = 0u8;
176            loop {
177                match source
178                    .clone()
179                    .get_jwks_within_deadline(url.clone(), now, timeout.retry_after)
180                    .await
181                {
182                    Ok(res) => return Ok(res),
183                    Err(err) => {
184                        if retries == timeout.retries {
185                            return Err(err);
186                        } else {
187                            retries += 1;
188                            tokio::time::sleep(timeout.backoff).await;
189                            continue;
190                        }
191                    }
192                }
193            }
194        };
195
196        tokio::time::timeout(timeout.deadline, perform)
197            .await
198            .map_err(|_| RequestError::Timeout)?
199    }
200
201    async fn update_notify(
202        &self,
203        now: SystemTime,
204    ) -> Result<Option<JwkSet>, RequestError<S::Error>> {
205        let notifier = if let Some(mut cached_state) = self.cache_state.try_write() {
206            let notifier = Arc::new(Notify::new());
207
208            *cached_state = JWKSCache::Fetching(notifier.clone());
209
210            notifier
211        } else {
212            return Ok(None);
213        };
214
215        let result = Self::request(
216            self.source.clone(),
217            self.jwks_url.clone(),
218            now,
219            self.timeout_spec,
220        )
221        .await;
222
223        let result = {
224            let mut cached_state = self.cache_state.write();
225
226            match result {
227                Ok((jwks, expires)) => {
228                    *cached_state = JWKSCache::Fetched {
229                        expires,
230                        jwks: jwks.clone(),
231                    };
232
233                    Ok(Some(jwks))
234                }
235                // Could not fetch in time, let follow up request try again later
236                Err(err) => {
237                    *cached_state = JWKSCache::Empty;
238
239                    Err(err)
240                }
241            }
242        };
243
244        notifier.notify_waiters();
245
246        result
247    }
248
249    /// Trigger refresh of JWKS in the background when cached JWKS is stil valid but about to expire,
250    /// if process dies then we do not care if this completes
251    fn update_in_background(&self, now: SystemTime, old_jwks: JwkSet, old_expires: SystemTime) {
252        {
253            let mut cache_state = self.cache_state.write();
254
255            *cache_state = JWKSCache::Refreshing {
256                expires: old_expires,
257                jwks: old_jwks,
258            };
259        }
260
261        let cache_state = self.cache_state.clone();
262        let jwks_url = self.jwks_url.clone();
263        let timeout_spec = self.timeout_spec;
264        let source = self.source.clone();
265
266        tokio::spawn(async move {
267            let result = Self::request(source, jwks_url, now, timeout_spec).await;
268
269            if let Err(err) = &result {
270                log::error!("Error while refreshing JWKS in the background: {err:?}");
271            }
272
273            let mut cache_state = cache_state.write();
274
275            let new_state = match cache_state.to_owned() {
276                JWKSCache::Empty => match result {
277                    Ok((jwks, expires)) => JWKSCache::Fetched { expires, jwks },
278                    Err(_) => JWKSCache::Empty,
279                },
280                JWKSCache::Fetching(notify) => {
281                    if let Ok((jwks, expires)) = result {
282                        notify.notify_waiters();
283                        JWKSCache::Fetched { expires, jwks }
284                    } else {
285                        JWKSCache::Fetching(notify)
286                    }
287                }
288                JWKSCache::Refreshing { expires, jwks } => {
289                    if let Ok((jwks, expires)) = result {
290                        JWKSCache::Fetched { expires, jwks }
291                    } else {
292                        JWKSCache::Refreshing { expires, jwks }
293                    }
294                }
295                JWKSCache::Fetched { expires, jwks } => {
296                    if let Ok((jwks, expires)) = result {
297                        JWKSCache::Fetched { expires, jwks }
298                    } else {
299                        JWKSCache::Refreshing { expires, jwks }
300                    }
301                }
302            };
303
304            *cache_state = new_state;
305        });
306    }
307
308    pub async fn get(&self) -> Result<JwkSet, RequestError<S::Error>> {
309        let now = SystemTime::now();
310        loop {
311            let cached_state = self.cache_state.read().clone();
312
313            match cached_state {
314                JWKSCache::Empty => {
315                    if let Some(jwks) = self.update_notify(now).await? {
316                        return Ok(jwks);
317                    } else {
318                        // state changed since reading it, reload
319                        continue;
320                    }
321                }
322                JWKSCache::Fetching(notifier) => {
323                    notifier.notified().await;
324
325                    // we got notified about change in state, reload
326                    continue;
327                }
328                JWKSCache::Refreshing { expires: _, jwks } => {
329                    // Refresh mechanism should guarantee it will change the state before cache is no longer valid
330                    return Ok(jwks);
331                }
332                JWKSCache::Fetched { expires, jwks } => {
333                    if now >= expires {
334                        if let Some(jwks) = self.update_notify(now).await? {
335                            return Ok(jwks);
336                        } else {
337                            // state changed since reading it, reload
338                            continue;
339                        }
340                    }
341
342                    if now + self.update_period >= expires {
343                        self.update_in_background(now, jwks.clone(), expires);
344                    }
345
346                    return Ok(jwks);
347                }
348            }
349        }
350    }
351}