jsonwebtoken_jwks_cache/cache/
mod.rs1#[cfg(test)]
2mod test;
3
4use super::pem_set::PemMap;
5use core::future::Future;
6use jsonwebtoken::jwk::JwkSet;
7use spin::RwLock;
8use std::sync::Arc;
9use std::time::{Duration, SystemTime};
10use tokio::sync::Notify;
11use url::Url;
12
13fn get_expiration(now: SystemTime, req: &reqwest::Request, res: &reqwest::Response) -> SystemTime {
14 now + http_cache_semantics::CachePolicy::new(req, res).time_to_live(now)
15}
16
17pub trait JwksSource: Clone + Send + Sync + 'static {
18 type Error: core::fmt::Debug + Send + Sync + 'static;
19
20 fn get_jwks_within_deadline(
21 self,
22 url: Url,
23 as_pkeys: bool,
24 now: SystemTime,
25 deadline: Duration,
26 ) -> impl Future<Output = Result<(JwkSet, SystemTime), RequestError<Self::Error>>>
27 + Send
28 + Sync
29 + 'static {
30 async move {
31 let result = tokio::time::timeout(deadline, self.get_jwks(url, as_pkeys, now)).await;
32
33 match result {
34 Ok(res) => res.map_err(RequestError::Client),
35 Err(_) => Err(RequestError::Timeout),
36 }
37 }
38 }
39
40 fn get_jwks(
41 self,
42 url: Url,
43 as_pkeys: bool,
44 now: SystemTime,
45 ) -> impl Future<Output = Result<(JwkSet, SystemTime), Self::Error>> + Send + Sync + 'static;
46}
47
48impl JwksSource for reqwest::Client {
49 type Error = reqwest::Error;
50
51 async fn get_jwks(
52 self,
53 url: Url,
54 as_pkeys: bool,
55 now: SystemTime,
56 ) -> Result<(JwkSet, SystemTime), Self::Error> {
57 let req = reqwest::Request::new(http::Method::GET, url.clone());
58 let res = reqwest::Client::builder()
59 .build()?
60 .execute(
61 req.try_clone().expect("Request should be always copyable"),
63 )
64 .await?
65 .error_for_status()?;
66
67 let expiration = get_expiration(now, &req, &res);
68 let jwks = if as_pkeys {
69 res.json::<PemMap>().await?.into_rsa_jwk_set()
70 } else {
71 res.json::<JwkSet>().await?
72 };
73
74 Ok((jwks, expiration))
75 }
76}
77
78#[derive(Debug, Clone, Default)]
80enum JWKSCache {
81 #[default]
83 Empty,
84 Fetching(Arc<Notify>),
87 Refreshing { expires: SystemTime, jwks: JwkSet },
89 Fetched { expires: SystemTime, jwks: JwkSet },
91}
92
93#[derive(Debug, thiserror::Error)]
94pub enum RequestError<E: core::fmt::Debug> {
95 #[error("Client error: {0}")]
96 Client(E),
97 #[error("Timeout for request completion reached")]
98 Timeout,
99}
100
101impl<T: core::fmt::Debug> From<T> for RequestError<T> {
102 fn from(value: T) -> Self {
103 Self::Client(value)
104 }
105}
106
/// Retry and time-limit policy for fetching a key set.
#[derive(Debug, Clone, Copy)]
pub struct TimeoutSpec {
    /// Number of extra attempts made after the first one fails (`0` means one attempt).
    pub retries: u8,
    /// Time limit for each individual attempt.
    pub retry_after: Duration,
    /// Pause between a failed attempt and the next one.
    pub backoff: Duration,
    /// Overall time limit for all attempts, pauses included. [`CachedJWKS`] requires its
    /// update period to be strictly greater than this value.
    pub deadline: Duration,
}
118
119impl Default for TimeoutSpec {
120 fn default() -> Self {
121 Self {
122 retries: 0,
123 retry_after: Duration::from_secs(10),
124 backoff: Duration::ZERO,
125 deadline: Duration::from_secs(10),
126 }
127 }
128}
129
130#[derive(Clone)]
131pub struct CachedJWKS<S> {
132 jwks_url: Url,
133 pkeys: bool,
134 update_period: Duration,
135 timeout_spec: TimeoutSpec,
136 cache_state: Arc<RwLock<JWKSCache>>,
137 source: S,
138}
139
140impl CachedJWKS<reqwest::Client> {
141 pub fn new(
142 jwks_url: Url,
143 update_period: Duration,
145 timeout_spec: TimeoutSpec,
146 ) -> Result<Self, reqwest::Error> {
147 Ok(Self::from_source(
148 jwks_url,
149 false,
150 update_period,
151 timeout_spec,
152 reqwest::Client::builder().build()?,
153 ))
154 }
155
156 pub fn new_rsa_pkeys(
158 pkeys_url: Url,
159 update_period: Duration,
161 timeout_spec: TimeoutSpec,
162 ) -> Result<Self, reqwest::Error> {
163 Ok(Self::from_source(
164 pkeys_url,
165 true,
166 update_period,
167 timeout_spec,
168 reqwest::Client::builder().build()?,
169 ))
170 }
171}
172
173impl<S: JwksSource> CachedJWKS<S> {
174 pub fn from_source(
175 jwks_url: Url,
176 pkeys: bool,
177 update_period: Duration,
178 timeout_spec: TimeoutSpec,
179 source: S,
180 ) -> Self {
181 assert!(
182 update_period > timeout_spec.deadline,
183 "Update period should be greater than timeout deadline"
184 );
185
186 Self {
187 jwks_url,
188 pkeys,
189 update_period,
190 timeout_spec,
191 cache_state: Default::default(),
192 source,
193 }
194 }
195
196 async fn request(
197 source: S,
198 url: Url,
199 as_pkeys: bool,
200 now: SystemTime,
201 timeout: TimeoutSpec,
202 ) -> Result<(JwkSet, SystemTime), RequestError<S::Error>> {
203 let perform = async {
204 let mut retries = 0u8;
205 loop {
206 match source
207 .clone()
208 .get_jwks_within_deadline(url.clone(), as_pkeys, now, timeout.retry_after)
209 .await
210 {
211 Ok(res) => return Ok(res),
212 Err(err) => {
213 if retries == timeout.retries {
214 return Err(err);
215 } else {
216 retries += 1;
217 tokio::time::sleep(timeout.backoff).await;
218 continue;
219 }
220 }
221 }
222 }
223 };
224
225 tokio::time::timeout(timeout.deadline, perform)
226 .await
227 .map_err(|_| RequestError::Timeout)?
228 }
229
230 async fn update_notify(
231 &self,
232 now: SystemTime,
233 ) -> Result<Option<JwkSet>, RequestError<S::Error>> {
234 let notifier = if let Some(mut cached_state) = self.cache_state.try_write() {
235 let notifier = Arc::new(Notify::new());
236
237 *cached_state = JWKSCache::Fetching(notifier.clone());
238
239 notifier
240 } else {
241 return Ok(None);
242 };
243
244 let result = Self::request(
245 self.source.clone(),
246 self.jwks_url.clone(),
247 self.pkeys,
248 now,
249 self.timeout_spec,
250 )
251 .await;
252
253 let result = {
254 let mut cached_state = self.cache_state.write();
255
256 match result {
257 Ok((jwks, expires)) => {
258 *cached_state = JWKSCache::Fetched {
259 expires,
260 jwks: jwks.clone(),
261 };
262
263 Ok(Some(jwks))
264 }
265 Err(err) => {
267 *cached_state = JWKSCache::Empty;
268
269 Err(err)
270 }
271 }
272 };
273
274 notifier.notify_waiters();
275
276 result
277 }
278
279 fn update_in_background(&self, now: SystemTime, old_jwks: JwkSet, old_expires: SystemTime) {
282 {
283 let mut cache_state = self.cache_state.write();
284
285 *cache_state = JWKSCache::Refreshing {
286 expires: old_expires,
287 jwks: old_jwks,
288 };
289 }
290
291 let cache_state = self.cache_state.clone();
292 let jwks_url = self.jwks_url.clone();
293 let timeout_spec = self.timeout_spec;
294 let source = self.source.clone();
295 let as_pkeys = self.pkeys;
296
297 tokio::spawn(async move {
298 let result = Self::request(source, jwks_url, as_pkeys, now, timeout_spec).await;
299
300 if let Err(err) = &result {
301 log::error!("Error while refreshing JWKS in the background: {err:?}");
302 }
303
304 let mut cache_state = cache_state.write();
305
306 let new_state = match cache_state.to_owned() {
307 JWKSCache::Empty => match result {
308 Ok((jwks, expires)) => JWKSCache::Fetched { expires, jwks },
309 Err(_) => JWKSCache::Empty,
310 },
311 JWKSCache::Fetching(notify) => {
312 if let Ok((jwks, expires)) = result {
313 notify.notify_waiters();
314 JWKSCache::Fetched { expires, jwks }
315 } else {
316 JWKSCache::Fetching(notify)
317 }
318 }
319 JWKSCache::Refreshing { expires, jwks } => {
320 if let Ok((jwks, expires)) = result {
321 JWKSCache::Fetched { expires, jwks }
322 } else {
323 JWKSCache::Refreshing { expires, jwks }
324 }
325 }
326 JWKSCache::Fetched { expires, jwks } => {
327 if let Ok((jwks, expires)) = result {
328 JWKSCache::Fetched { expires, jwks }
329 } else {
330 JWKSCache::Refreshing { expires, jwks }
331 }
332 }
333 };
334
335 *cache_state = new_state;
336 });
337 }
338
339 pub async fn get(&self) -> Result<JwkSet, RequestError<S::Error>> {
340 let now = SystemTime::now();
341 loop {
342 let cached_state = self.cache_state.read().clone();
343
344 match cached_state {
345 JWKSCache::Empty => {
346 if let Some(jwks) = self.update_notify(now).await? {
347 return Ok(jwks);
348 } else {
349 continue;
351 }
352 }
353 JWKSCache::Fetching(notifier) => {
354 notifier.notified().await;
355
356 continue;
358 }
359 JWKSCache::Refreshing { expires: _, jwks } => {
360 return Ok(jwks);
362 }
363 JWKSCache::Fetched { expires, jwks } => {
364 if now >= expires {
365 if let Some(jwks) = self.update_notify(now).await? {
366 return Ok(jwks);
367 } else {
368 continue;
370 }
371 }
372
373 if now + self.update_period >= expires {
374 self.update_in_background(now, jwks.clone(), expires);
375 }
376
377 return Ok(jwks);
378 }
379 }
380 }
381 }
382}