// jsonwebtoken_jwks_cache/cache/mod.rs

// Unit tests for this module live in the `test` submodule and are only
// compiled for test builds.
#[cfg(test)]
mod test;
3
4use core::future::Future;
5use jsonwebtoken::jwk::JwkSet;
6use spin::RwLock;
7use std::sync::Arc;
8use std::time::{Duration, SystemTime};
9use tokio::sync::Notify;
10use url::Url;
11
12fn get_expiration(now: SystemTime, req: &reqwest::Request, res: &reqwest::Response) -> SystemTime {
13 now + http_cache_semantics::CachePolicy::new(req, res).time_to_live(now)
14}
15
16pub trait JwksSource: Clone + Send + Sync + 'static {
17 type Error: core::fmt::Debug + Send + Sync + 'static;
18
19 fn get_jwks_within_deadline(
20 self,
21 url: Url,
22 now: SystemTime,
23 deadline: Duration,
24 ) -> impl Future<Output = Result<(JwkSet, SystemTime), RequestError<Self::Error>>>
25 + Send
26 + Sync
27 + 'static {
28 async move {
29 let result = tokio::time::timeout(deadline, self.get_jwks(url, now)).await;
30
31 match result {
32 Ok(res) => res.map_err(RequestError::Client),
33 Err(_) => Err(RequestError::Timeout),
34 }
35 }
36 }
37
38 fn get_jwks(
39 self,
40 url: Url,
41 now: SystemTime,
42 ) -> impl Future<Output = Result<(JwkSet, SystemTime), Self::Error>> + Send + Sync + 'static;
43}
44
45impl JwksSource for reqwest::Client {
46 type Error = reqwest::Error;
47
48 async fn get_jwks(
49 self,
50 url: Url,
51 now: SystemTime,
52 ) -> Result<(JwkSet, SystemTime), Self::Error> {
53 let req = reqwest::Request::new(http::Method::GET, url.clone());
54 let res = reqwest::Client::builder()
55 .build()?
56 .execute(
57 req.try_clone().expect("Request should be always copyable"),
59 )
60 .await?
61 .error_for_status()?;
62
63 let expiration = get_expiration(now, &req, &res);
64 let jwks: JwkSet = res.json().await?;
65
66 Ok((jwks, expiration))
67 }
68}
69
70#[derive(Debug, Clone, Default)]
72enum JWKSCache {
73 #[default]
75 Empty,
76 Fetching(Arc<Notify>),
79 Refreshing { expires: SystemTime, jwks: JwkSet },
81 Fetched { expires: SystemTime, jwks: JwkSet },
83}
84
85#[derive(Debug, thiserror::Error)]
86pub enum RequestError<E: core::fmt::Debug> {
87 #[error("Client error: {0}")]
88 Client(E),
89 #[error("Timeout for request completion reached")]
90 Timeout,
91}
92
93impl<T: core::fmt::Debug> From<T> for RequestError<T> {
94 fn from(value: T) -> Self {
95 Self::Client(value)
96 }
97}
98
/// Retry and timeout policy applied when the cache fetches the key set.
#[derive(Debug, Clone, Copy)]
pub struct TimeoutSpec {
    /// Number of additional attempts made after the first attempt fails.
    pub retries: u8,
    /// Time limit for a single attempt. An attempt exceeding it fails with a
    /// timeout, which may then be retried.
    pub retry_after: Duration,
    /// Pause between a failed attempt and the next one.
    pub backoff: Duration,
    /// Overall time limit covering every attempt and pause.
    pub deadline: Duration,
}

impl Default for TimeoutSpec {
    /// Returns a single attempt with no backoff, limited to 10 seconds both
    /// per attempt and overall.
    fn default() -> Self {
        let ten_seconds = Duration::from_secs(10);
        TimeoutSpec {
            deadline: ten_seconds,
            backoff: Duration::ZERO,
            retry_after: ten_seconds,
            retries: 0,
        }
    }
}
121
122#[derive(Clone)]
123pub struct CachedJWKS<S> {
124 jwks_url: Url,
125 update_period: Duration,
126 timeout_spec: TimeoutSpec,
127 cache_state: Arc<RwLock<JWKSCache>>,
128 source: S,
129}
130
131impl CachedJWKS<reqwest::Client> {
132 pub fn new(
133 jwks_url: Url,
134 update_period: Duration,
136 timeout_spec: TimeoutSpec,
137 ) -> Result<Self, reqwest::Error> {
138 Ok(Self::from_source(
139 jwks_url,
140 update_period,
141 timeout_spec,
142 reqwest::Client::builder().build()?,
143 ))
144 }
145}
146
147impl<S: JwksSource> CachedJWKS<S> {
148 pub fn from_source(
149 jwks_url: Url,
150 update_period: Duration,
151 timeout_spec: TimeoutSpec,
152 source: S,
153 ) -> Self {
154 assert!(
155 update_period > timeout_spec.deadline,
156 "Update period should be greater than timeout deadline"
157 );
158
159 Self {
160 jwks_url,
161 update_period,
162 timeout_spec,
163 cache_state: Default::default(),
164 source,
165 }
166 }
167
168 async fn request(
169 source: S,
170 url: Url,
171 now: SystemTime,
172 timeout: TimeoutSpec,
173 ) -> Result<(JwkSet, SystemTime), RequestError<S::Error>> {
174 let perform = async {
175 let mut retries = 0u8;
176 loop {
177 match source
178 .clone()
179 .get_jwks_within_deadline(url.clone(), now, timeout.retry_after)
180 .await
181 {
182 Ok(res) => return Ok(res),
183 Err(err) => {
184 if retries == timeout.retries {
185 return Err(err);
186 } else {
187 retries += 1;
188 tokio::time::sleep(timeout.backoff).await;
189 continue;
190 }
191 }
192 }
193 }
194 };
195
196 tokio::time::timeout(timeout.deadline, perform)
197 .await
198 .map_err(|_| RequestError::Timeout)?
199 }
200
201 async fn update_notify(
202 &self,
203 now: SystemTime,
204 ) -> Result<Option<JwkSet>, RequestError<S::Error>> {
205 let notifier = if let Some(mut cached_state) = self.cache_state.try_write() {
206 let notifier = Arc::new(Notify::new());
207
208 *cached_state = JWKSCache::Fetching(notifier.clone());
209
210 notifier
211 } else {
212 return Ok(None);
213 };
214
215 let result = Self::request(
216 self.source.clone(),
217 self.jwks_url.clone(),
218 now,
219 self.timeout_spec,
220 )
221 .await;
222
223 let result = {
224 let mut cached_state = self.cache_state.write();
225
226 match result {
227 Ok((jwks, expires)) => {
228 *cached_state = JWKSCache::Fetched {
229 expires,
230 jwks: jwks.clone(),
231 };
232
233 Ok(Some(jwks))
234 }
235 Err(err) => {
237 *cached_state = JWKSCache::Empty;
238
239 Err(err)
240 }
241 }
242 };
243
244 notifier.notify_waiters();
245
246 result
247 }
248
249 fn update_in_background(&self, now: SystemTime, old_jwks: JwkSet, old_expires: SystemTime) {
252 {
253 let mut cache_state = self.cache_state.write();
254
255 *cache_state = JWKSCache::Refreshing {
256 expires: old_expires,
257 jwks: old_jwks,
258 };
259 }
260
261 let cache_state = self.cache_state.clone();
262 let jwks_url = self.jwks_url.clone();
263 let timeout_spec = self.timeout_spec;
264 let source = self.source.clone();
265
266 tokio::spawn(async move {
267 let result = Self::request(source, jwks_url, now, timeout_spec).await;
268
269 if let Err(err) = &result {
270 log::error!("Error while refreshing JWKS in the background: {err:?}");
271 }
272
273 let mut cache_state = cache_state.write();
274
275 let new_state = match cache_state.to_owned() {
276 JWKSCache::Empty => match result {
277 Ok((jwks, expires)) => JWKSCache::Fetched { expires, jwks },
278 Err(_) => JWKSCache::Empty,
279 },
280 JWKSCache::Fetching(notify) => {
281 if let Ok((jwks, expires)) = result {
282 notify.notify_waiters();
283 JWKSCache::Fetched { expires, jwks }
284 } else {
285 JWKSCache::Fetching(notify)
286 }
287 }
288 JWKSCache::Refreshing { expires, jwks } => {
289 if let Ok((jwks, expires)) = result {
290 JWKSCache::Fetched { expires, jwks }
291 } else {
292 JWKSCache::Refreshing { expires, jwks }
293 }
294 }
295 JWKSCache::Fetched { expires, jwks } => {
296 if let Ok((jwks, expires)) = result {
297 JWKSCache::Fetched { expires, jwks }
298 } else {
299 JWKSCache::Refreshing { expires, jwks }
300 }
301 }
302 };
303
304 *cache_state = new_state;
305 });
306 }
307
308 pub async fn get(&self) -> Result<JwkSet, RequestError<S::Error>> {
309 let now = SystemTime::now();
310 loop {
311 let cached_state = self.cache_state.read().clone();
312
313 match cached_state {
314 JWKSCache::Empty => {
315 if let Some(jwks) = self.update_notify(now).await? {
316 return Ok(jwks);
317 } else {
318 continue;
320 }
321 }
322 JWKSCache::Fetching(notifier) => {
323 notifier.notified().await;
324
325 continue;
327 }
328 JWKSCache::Refreshing { expires: _, jwks } => {
329 return Ok(jwks);
331 }
332 JWKSCache::Fetched { expires, jwks } => {
333 if now >= expires {
334 if let Some(jwks) = self.update_notify(now).await? {
335 return Ok(jwks);
336 } else {
337 continue;
339 }
340 }
341
342 if now + self.update_period >= expires {
343 self.update_in_background(now, jwks.clone(), expires);
344 }
345
346 return Ok(jwks);
347 }
348 }
349 }
350 }
351}