jwks_cache/cache/
manager.rs

1//! Cache manager handling JWKS retrieval and lifecycle.
2
3// crates.io
4use http::{
5	HeaderName, HeaderValue, Request, Response,
6	header::{ETAG, IF_NONE_MATCH, LAST_MODIFIED},
7};
8use http_cache_semantics::BeforeRequest;
9#[cfg(feature = "redis")] use http_cache_semantics::CachePolicy;
10use jsonwebtoken::jwk::JwkSet;
11use rand::Rng;
12use reqwest::{Client, redirect::Policy};
13use tokio::{
14	sync::{Mutex, RwLock},
15	time,
16};
17// self
18#[cfg(feature = "redis")] use crate::registry::PersistentSnapshot;
19use crate::{
20	_prelude::*,
21	cache::{
22		entry::CacheEntry,
23		state::{CachePayload, CacheState},
24	},
25	http::{
26		client::fetch_jwks,
27		retry::{AttemptBudget, RetryExecutor},
28		semantics::{Freshness, base_request, evaluate_freshness, evaluate_revalidation},
29	},
30	metrics::{self, ProviderMetrics},
31	registry::IdentityProviderRegistration,
32};
33
34/// Coordinates fetching, caching, and background refresh for a registration.
35///
36/// Instances are scoped per tenant/provider pair; the single-flight guard only
37/// serialises refresh work for that specific provider.
38#[derive(Clone, Debug)]
39pub struct CacheManager {
40	registration: Arc<IdentityProviderRegistration>,
41	client: Arc<Client>,
42	entry: Arc<RwLock<CacheEntry>>,
43	single_flight: Arc<Mutex<()>>,
44	metrics: Arc<ProviderMetrics>,
45}
46impl CacheManager {
47	/// Build a new cache manager with the default reqwest client.
48	pub fn new(registration: IdentityProviderRegistration) -> Result<Self> {
49		registration.validate()?;
50
51		let client = Client::builder()
52			.redirect(Policy::limited(10))
53			.user_agent(format!("jwks-cache/{}", env!("CARGO_PKG_VERSION")))
54			.connect_timeout(Duration::from_secs(5))
55			.build()?;
56
57		Ok(Self::with_parts(registration, client, ProviderMetrics::new()))
58	}
59
60	/// Build a cache manager using the supplied HTTP client (primarily for tests).
61	pub fn with_client(registration: IdentityProviderRegistration, client: Client) -> Self {
62		Self::with_parts(registration, client, ProviderMetrics::new())
63	}
64
65	fn with_parts(
66		registration: IdentityProviderRegistration,
67		client: Client,
68		metrics: Arc<ProviderMetrics>,
69	) -> Self {
70		let tenant = registration.tenant_id.clone();
71		let provider = registration.provider_id.clone();
72
73		Self {
74			registration: Arc::new(registration),
75			client: Arc::new(client),
76			entry: Arc::new(RwLock::new(CacheEntry::new(tenant, provider))),
77			single_flight: Arc::new(Mutex::new(())),
78			metrics,
79		}
80	}
81
82	/// Access the per-provider metrics accumulator.
83	pub fn metrics(&self) -> Arc<ProviderMetrics> {
84		self.metrics.clone()
85	}
86
87	/// Capture the current cache state for status reporting.
88	pub async fn snapshot(&self) -> CacheSnapshot {
89		let captured_at = Instant::now();
90		let captured_at_wallclock = Utc::now();
91		let state = { self.entry.read().await.state().clone() };
92
93		CacheSnapshot { captured_at, captured_at_wallclock, state }
94	}
95
96	#[cfg(feature = "redis")]
97	/// Build a persistence payload capturing the current cache contents.
98	pub async fn persistent_snapshot(&self) -> Result<Option<PersistentSnapshot>> {
99		let snapshot = self.snapshot().await;
100		let payload = match snapshot.state {
101			CacheState::Ready(ref payload) | CacheState::Refreshing(ref payload) => payload.clone(),
102			_ => return Ok(None),
103		};
104		let expires_at = match snapshot.to_datetime(payload.expires_at) {
105			Some(dt) => dt,
106			None => return Ok(None),
107		};
108		let jwks_json = serde_json::to_string(&*payload.jwks)?;
109		let persisted_at = Utc::now();
110		let snapshot = PersistentSnapshot {
111			tenant_id: self.registration.tenant_id.clone(),
112			provider_id: self.registration.provider_id.clone(),
113			jwks_json,
114			etag: payload.etag.clone(),
115			last_modified: payload.last_modified,
116			expires_at,
117			persisted_at,
118		};
119
120		Ok(Some(snapshot))
121	}
122
123	#[cfg(feature = "redis")]
124	/// Restore cache state from a previously persisted snapshot.
125	pub async fn restore_snapshot(&self, snapshot: PersistentSnapshot) -> Result<()> {
126		snapshot.validate(&self.registration)?;
127
128		let PersistentSnapshot { jwks_json, etag, last_modified, expires_at, persisted_at, .. } =
129			snapshot;
130		let jwks: JwkSet = serde_json::from_str(&jwks_json)?;
131		let jwks = Arc::new(jwks);
132		let ttl = (expires_at - persisted_at)
133			.to_std()
134			.unwrap_or_default()
135			.max(self.registration.min_ttl)
136			.min(self.registration.max_ttl);
137		let request = base_request(&self.registration)?;
138		let mut response = Response::builder()
139			.status(200)
140			.header("cache-control", format!("public, max-age={}", ttl.as_secs()))
141			.header("content-type", "application/json")
142			.body(())
143			.map_err(Error::from)?;
144
145		if let Some(ref etag_value) = etag {
146			let value = HeaderValue::from_str(etag_value).map_err(|err| Error::Validation {
147				field: "etag",
148				reason: format!("Invalid persisted ETag: {err}."),
149			})?;
150
151			response.headers_mut().insert(ETAG, value);
152		}
153		if let Some(ref last_modified_value) = last_modified {
154			let http_date = httpdate::fmt_http_date((*last_modified_value).into());
155			let value = HeaderValue::from_str(&http_date).map_err(|err| Error::Validation {
156				field: "last_modified",
157				reason: format!("Invalid persisted Last-Modified: {err}."),
158			})?;
159
160			response.headers_mut().insert(LAST_MODIFIED, value);
161		}
162
163		let policy = CachePolicy::new(&request, &response);
164		let freshness = Freshness { ttl, policy };
165		let now = Instant::now();
166		let payload = self.build_payload(jwks, freshness, etag, last_modified, now, persisted_at);
167
168		{
169			let mut entry = self.entry.write().await;
170
171			entry.load_success(payload.clone());
172		}
173
174		tracing::debug!(
175			tenant = %self.registration.tenant_id,
176			provider = %self.registration.provider_id,
177			"restored cache entry from persistent snapshot"
178		);
179
180		Ok(())
181	}
182
183	/// Resolve JWKS for the registration, fetching upstream when necessary.
184	#[tracing::instrument(
185		skip(self, kid),
186		fields(
187			tenant = %self.registration.tenant_id,
188			provider = %self.registration.provider_id,
189			kid = kid.unwrap_or_default()
190		)
191	)]
192	pub async fn resolve(&self, kid: Option<&str>) -> Result<Arc<JwkSet>> {
193		loop {
194			let snapshot = { self.entry.read().await.snapshot() };
195			let now = Instant::now();
196
197			match snapshot {
198				None => {
199					tracing::debug!("cache empty; performing initial fetch");
200
201					match self.refresh_blocking(true).await? {
202						RefreshOutcome::Updated { jwks, from_cache } => {
203							if from_cache {
204								self.observe_hit(false);
205							} else {
206								self.observe_miss();
207							}
208
209							return Ok(jwks);
210						},
211						RefreshOutcome::Stale(jwks) => {
212							self.observe_hit(true);
213
214							return Ok(jwks);
215						},
216					}
217				},
218				Some(payload) => {
219					if !payload.is_expired(now) {
220						let jwks = payload.jwks.clone();
221
222						self.observe_hit(false);
223
224						if now >= payload.next_refresh_at {
225							self.schedule_background_refresh(now).await;
226						}
227
228						return Ok(jwks);
229					}
230
231					if payload.can_serve_stale(now) {
232						// TODO(refactor): consolidate stale fallback with perform_fetch_with_retry
233						// once the helper can orchestrate stale responses directly.
234						match self.refresh_blocking(false).await {
235							Ok(RefreshOutcome::Updated { jwks, from_cache }) => {
236								if from_cache {
237									self.observe_hit(false);
238								} else {
239									self.observe_miss();
240								}
241
242								return Ok(jwks);
243							},
244							Ok(RefreshOutcome::Stale(jwks)) => {
245								self.observe_hit(true);
246
247								return Ok(jwks);
248							},
249							Err(err) =>
250								if payload.can_serve_stale(Instant::now()) {
251									tracing::warn!(error = %err, "refresh failed, serving stale data");
252
253									self.observe_hit(true);
254
255									return Ok(payload.jwks.clone());
256								} else {
257									return Err(err);
258								},
259						}
260					} else if let RefreshOutcome::Updated { jwks, from_cache } =
261						self.refresh_blocking(true).await?
262					{
263						if from_cache {
264							self.observe_hit(false);
265						} else {
266							self.observe_miss();
267						}
268						return Ok(jwks);
269					}
270				},
271			}
272		}
273	}
274
275	/// Trigger a manual refresh asynchronously; used by the control plane.
276	#[tracing::instrument(
277		skip(self),
278		fields(tenant = %self.registration.tenant_id, provider = %self.registration.provider_id)
279	)]
280	pub async fn trigger_refresh(&self) -> Result<()> {
281		let now = Instant::now();
282		let action = {
283			let mut entry = self.entry.write().await;
284
285			match entry.state() {
286				CacheState::Empty => {
287					entry.begin_load();
288					RefreshTrigger::Blocking
289				},
290				CacheState::Loading | CacheState::Refreshing(_) => RefreshTrigger::None,
291				CacheState::Ready(_) =>
292					if entry.begin_refresh(now) {
293						RefreshTrigger::Background
294					} else {
295						RefreshTrigger::None
296					},
297			}
298		};
299
300		match action {
301			RefreshTrigger::Background => {
302				let manager = self.clone();
303
304				tokio::spawn(async move {
305					if let Err(err) = manager.refresh_blocking(true).await {
306						tracing::warn!(error = %err, "manual refresh failed");
307					}
308				});
309			},
310			RefreshTrigger::Blocking => {
311				self.refresh_blocking(true).await?;
312			},
313			RefreshTrigger::None => {},
314		}
315
316		Ok(())
317	}
318
319	#[tracing::instrument(
320		skip(self),
321		fields(tenant = %self.registration.tenant_id, provider = %self.registration.provider_id)
322	)]
323	async fn schedule_background_refresh(&self, now: Instant) {
324		let should_spawn = {
325			let mut entry = self.entry.write().await;
326
327			entry.begin_refresh(now)
328		};
329		if should_spawn {
330			let manager = self.clone();
331
332			tokio::spawn(async move {
333				if let Err(err) = manager.refresh_blocking(true).await {
334					tracing::debug!(error = %err, "background refresh failed");
335				}
336			});
337		}
338	}
339
340	#[tracing::instrument(
341		skip(self, force_revalidation),
342		fields(tenant = %self.registration.tenant_id, provider = %self.registration.provider_id, force_revalidation)
343	)]
344	async fn refresh_blocking(&self, force_revalidation: bool) -> Result<RefreshOutcome> {
345		let _guard = self.single_flight.lock().await;
346		let now = Instant::now();
347		let (existing, mode) = {
348			let mut entry = self.entry.write().await;
349			let snapshot = entry.snapshot();
350			let mode = if snapshot.is_some() {
351				entry.begin_refresh(now);
352
353				FetchMode::Refresh
354			} else {
355				entry.begin_load();
356
357				FetchMode::Initial
358			};
359
360			(snapshot, mode)
361		};
362
363		match self.prepare_request(existing.as_ref(), force_revalidation)? {
364			PreparedRequest::UseCached { jwks } =>
365				Ok(RefreshOutcome::Updated { jwks, from_cache: true }),
366			PreparedRequest::Send(request) =>
367				self.perform_fetch_with_retry(*request, existing, mode, force_revalidation).await,
368		}
369	}
370
371	fn prepare_request(
372		&self,
373		existing: Option<&CachePayload>,
374		force_revalidation: bool,
375	) -> Result<PreparedRequest> {
376		let mut request = base_request(&self.registration)?;
377
378		if let Some(payload) = existing {
379			let mut send_conditional = force_revalidation;
380
381			match payload.policy.before_request(&request, SystemTime::now()) {
382				BeforeRequest::Fresh(_) if !force_revalidation => {
383					return Ok(PreparedRequest::UseCached { jwks: payload.jwks.clone() });
384				},
385				BeforeRequest::Stale { request: parts, matches } if matches => {
386					request = Request::from_parts(parts, ());
387					send_conditional = true;
388				},
389				_ => {},
390			}
391
392			if send_conditional
393				&& let Some(etag) = &payload.etag
394				&& let Ok(value) = HeaderValue::from_str(etag)
395			{
396				request.headers_mut().insert(IF_NONE_MATCH, value);
397			}
398		}
399
400		Ok(PreparedRequest::Send(Box::new(request)))
401	}
402
403	async fn perform_fetch_with_retry(
404		&self,
405		request: Request<()>,
406		existing: Option<CachePayload>,
407		mode: FetchMode,
408		force_revalidation: bool,
409	) -> Result<RefreshOutcome> {
410		let mut executor = RetryExecutor::new(&self.registration.retry_policy);
411		let mut last_error: Option<Error> = None;
412		let mut last_backoff: Option<Duration> = None;
413		let request = request;
414
415		while let AttemptBudget::Granted { timeout } = executor.attempt_budget() {
416			let attempt_started = Instant::now();
417			let fetch = fetch_jwks(&self.client, &self.registration, &request, timeout).await;
418
419			match fetch {
420				Ok(fetch) => {
421					let now = Instant::now();
422					let payload = match (&fetch.jwks, existing.as_ref()) {
423						(Some(fresh_jwks), _) => {
424							let freshness =
425								evaluate_freshness(&self.registration, &fetch.exchange)?;
426
427							self.build_payload(
428								fresh_jwks.clone(),
429								freshness,
430								fetch.etag.clone(),
431								fetch.last_modified,
432								now,
433								Utc::now(),
434							)
435						},
436						(None, Some(previous)) => {
437							let revalidation = evaluate_revalidation(
438								&self.registration,
439								&previous.policy,
440								&fetch.exchange.request,
441								&fetch.exchange.response,
442							)?;
443							let updated_etag = extract_header(&revalidation.response, &ETAG)
444								.or_else(|| previous.etag.clone());
445
446							self.build_payload(
447								previous.jwks.clone(),
448								revalidation.freshness,
449								updated_etag,
450								extract_last_modified(&revalidation.response)
451									.or(previous.last_modified),
452								now,
453								Utc::now(),
454							)
455						},
456						(None, None) => {
457							return Err(Error::Cache(
458								"Received 304 status without a cached payload.".into(),
459							));
460						},
461					};
462
463					let jwks = payload.jwks.clone();
464
465					self.commit_success(mode, payload).await;
466					self.observe_refresh_success(attempt_started.elapsed());
467
468					return Ok(RefreshOutcome::Updated { jwks, from_cache: false });
469				},
470				Err(err) => {
471					last_error = Some(err);
472
473					if !executor.can_retry() {
474						break;
475					}
476
477					if let Some(delay) = executor.next_backoff() {
478						last_backoff = Some(delay);
479
480						if !delay.is_zero() {
481							time::sleep(delay).await;
482						}
483						continue;
484					}
485
486					break;
487				},
488			}
489		}
490
491		let now = Instant::now();
492
493		match mode {
494			FetchMode::Initial => {
495				let mut entry = self.entry.write().await;
496
497				entry.invalidate();
498			},
499			FetchMode::Refresh => {
500				let mut entry = self.entry.write().await;
501
502				entry.refresh_failure(now, last_backoff);
503			},
504		}
505
506		self.observe_refresh_error();
507
508		if !force_revalidation
509			&& let Some(payload) = existing
510			&& payload.can_serve_stale(now)
511		{
512			return Ok(RefreshOutcome::Stale(payload.jwks));
513		}
514
515		Err(last_error.unwrap_or_else(|| Error::Cache("Refresh attempts exhausted.".into())))
516	}
517
518	async fn commit_success(&self, mode: FetchMode, payload: CachePayload) {
519		let mut entry = self.entry.write().await;
520
521		match mode {
522			FetchMode::Initial => entry.load_success(payload),
523			FetchMode::Refresh => entry.refresh_success(payload),
524		}
525	}
526
527	fn build_payload(
528		&self,
529		jwks: Arc<JwkSet>,
530		freshness: Freshness,
531		etag: Option<String>,
532		last_modified: Option<DateTime<Utc>>,
533		now: Instant,
534		refreshed_at: DateTime<Utc>,
535	) -> CachePayload {
536		let ttl = freshness.ttl;
537		let expires_at = now + ttl;
538		let mut refresh_at = if self.registration.refresh_early >= ttl {
539			now
540		} else {
541			expires_at - self.registration.refresh_early
542		};
543
544		if !self.registration.prefetch_jitter.is_zero() {
545			let jitter = random_jitter(self.registration.prefetch_jitter);
546
547			if refresh_at > now + jitter {
548				refresh_at -= jitter;
549			}
550		}
551
552		let stale_deadline = if self.registration.stale_while_error.is_zero() {
553			None
554		} else {
555			Some(expires_at + self.registration.stale_while_error)
556		};
557
558		CachePayload {
559			jwks,
560			policy: freshness.policy,
561			etag,
562			last_modified,
563			last_refresh_at: refreshed_at,
564			expires_at,
565			next_refresh_at: refresh_at,
566			stale_deadline,
567			retry_backoff: None,
568			error_count: 0,
569		}
570	}
571
572	fn observe_hit(&self, stale: bool) {
573		let tenant = &self.registration.tenant_id;
574		let provider = &self.registration.provider_id;
575
576		metrics::record_resolve_hit(tenant, provider, stale);
577
578		self.metrics.record_hit(stale);
579	}
580
581	fn observe_miss(&self) {
582		let tenant = &self.registration.tenant_id;
583		let provider = &self.registration.provider_id;
584
585		metrics::record_resolve_miss(tenant, provider);
586
587		self.metrics.record_miss();
588	}
589
590	fn observe_refresh_success(&self, duration: Duration) {
591		let tenant = &self.registration.tenant_id;
592		let provider = &self.registration.provider_id;
593
594		metrics::record_refresh_success(tenant, provider, duration);
595
596		self.metrics.record_refresh_success(duration);
597	}
598
599	fn observe_refresh_error(&self) {
600		let tenant = &self.registration.tenant_id;
601		let provider = &self.registration.provider_id;
602
603		metrics::record_refresh_error(tenant, provider);
604
605		self.metrics.record_refresh_error();
606	}
607}
608
609/// Snapshot of cache state captured for status reporting.
610#[derive(Clone, Debug)]
611pub struct CacheSnapshot {
612	/// Monotonic instant when the snapshot was taken.
613	pub captured_at: Instant,
614	/// Wall-clock timestamp that aligns with `captured_at`.
615	pub captured_at_wallclock: DateTime<Utc>,
616	/// Cache state recorded at capture time.
617	pub state: CacheState,
618}
619impl CacheSnapshot {
620	/// Convert a monotonic instant drawn from the cached payload into UTC.
621	pub fn to_datetime(&self, instant: Instant) -> Option<DateTime<Utc>> {
622		if let Some(delta) = instant.checked_duration_since(self.captured_at) {
623			let chrono = TimeDelta::from_std(delta).ok()?;
624
625			self.captured_at_wallclock.checked_add_signed(chrono)
626		} else if let Some(delta) = self.captured_at.checked_duration_since(instant) {
627			let chrono = TimeDelta::from_std(delta).ok()?;
628
629			self.captured_at_wallclock.checked_sub_signed(chrono)
630		} else {
631			None
632		}
633	}
634}
635
/// Which state transition a fetch completes.
#[derive(Debug, Clone, Copy)]
enum FetchMode {
	/// Nothing was cached: success calls `load_success`, failure invalidates the entry.
	Initial,
	/// A payload was cached: success calls `refresh_success`, failure records `refresh_failure`.
	Refresh,
}
641
642#[derive(Debug)]
643enum RefreshOutcome {
644	Updated { jwks: Arc<JwkSet>, from_cache: bool },
645	Stale(Arc<JwkSet>),
646}
647
/// Action chosen by a manual refresh request, based on the current cache state.
#[derive(Debug, Clone, Copy)]
enum RefreshTrigger {
	/// Spawn the refresh on a background task.
	Background,
	/// Run the refresh inline and propagate its error.
	Blocking,
	/// Do nothing; a load or refresh is already in progress or was declined.
	None,
}
654
655#[derive(Debug)]
656enum PreparedRequest {
657	UseCached { jwks: Arc<JwkSet> },
658	Send(Box<Request<()>>),
659}
660
661fn random_jitter(max: Duration) -> Duration {
662	if max.is_zero() {
663		return Duration::ZERO;
664	}
665
666	let mut rng = rand::rng();
667	let jitter = rng.random_range(0.0..=max.as_secs_f64());
668
669	Duration::from_secs_f64(jitter)
670}
671
672fn extract_header(response: &Response<()>, name: &HeaderName) -> Option<String> {
673	response.headers().get(name).and_then(|value| value.to_str().ok()).map(|s| s.to_string())
674}
675
676fn extract_last_modified(response: &Response<()>) -> Option<DateTime<Utc>> {
677	response
678		.headers()
679		.get(LAST_MODIFIED)
680		.and_then(|value| value.to_str().ok())
681		.and_then(|raw| httpdate::parse_http_date(raw).ok())
682		.map(<DateTime<Utc>>::from)
683}