Skip to main content

jwks_cache/cache/
manager.rs

1//! Cache manager handling JWKS retrieval and lifecycle.
2
3// crates.io
4use http::{
5	HeaderName, HeaderValue, Request, Response,
6	header::{ETAG, IF_NONE_MATCH, LAST_MODIFIED},
7};
8use http_cache_semantics::BeforeRequest;
9#[cfg(feature = "redis")] use http_cache_semantics::CachePolicy;
10use jsonwebtoken::jwk::JwkSet;
11use rand::Rng;
12use reqwest::{Client, redirect::Policy};
13use tokio::{
14	sync::{Mutex, RwLock},
15	time,
16};
17// self
18#[cfg(feature = "metrics")] use crate::metrics::{self, ProviderMetrics};
19#[cfg(feature = "redis")] use crate::registry::PersistentSnapshot;
20use crate::{
21	_prelude::*,
22	cache::{
23		entry::CacheEntry,
24		state::{CachePayload, CacheState},
25	},
26	http::{
27		client::fetch_jwks,
28		retry::{AttemptBudget, RetryExecutor},
29		semantics::{Freshness, base_request, evaluate_freshness, evaluate_revalidation},
30	},
31	registry::IdentityProviderRegistration,
32};
33
34/// Coordinates fetching, caching, and background refresh for a registration.
35///
36/// Instances are scoped per tenant/provider pair; the single-flight guard only
37/// serialises refresh work for that specific provider.
38#[derive(Clone, Debug)]
39pub struct CacheManager {
40	registration: Arc<IdentityProviderRegistration>,
41	client: Arc<Client>,
42	entry: Arc<RwLock<CacheEntry>>,
43	single_flight: Arc<Mutex<()>>,
44	#[cfg(feature = "metrics")]
45	metrics: Arc<ProviderMetrics>,
46}
47impl CacheManager {
48	/// Build a new cache manager with the default reqwest client.
49	pub fn new(registration: IdentityProviderRegistration) -> Result<Self> {
50		registration.validate()?;
51
52		let client = Client::builder()
53			.redirect(Policy::limited(10))
54			.user_agent(format!("jwks-cache/{}", env!("CARGO_PKG_VERSION")))
55			.connect_timeout(Duration::from_secs(5))
56			.build()?;
57
58		#[cfg(feature = "metrics")]
59		let manager = Self::with_parts(registration, client, ProviderMetrics::new());
60		#[cfg(not(feature = "metrics"))]
61		let manager = Self::with_parts(registration, client);
62
63		Ok(manager)
64	}
65
66	/// Build a cache manager using the supplied HTTP client (primarily for tests).
67	pub fn with_client(registration: IdentityProviderRegistration, client: Client) -> Self {
68		#[cfg(feature = "metrics")]
69		let manager = Self::with_parts(registration, client, ProviderMetrics::new());
70		#[cfg(not(feature = "metrics"))]
71		let manager = Self::with_parts(registration, client);
72
73		manager
74	}
75
76	#[cfg(feature = "metrics")]
77	fn with_parts(
78		registration: IdentityProviderRegistration,
79		client: Client,
80		metrics: Arc<ProviderMetrics>,
81	) -> Self {
82		let tenant = registration.tenant_id.clone();
83		let provider = registration.provider_id.clone();
84
85		Self {
86			registration: Arc::new(registration),
87			client: Arc::new(client),
88			entry: Arc::new(RwLock::new(CacheEntry::new(tenant, provider))),
89			single_flight: Arc::new(Mutex::new(())),
90			metrics,
91		}
92	}
93
94	#[cfg(not(feature = "metrics"))]
95	fn with_parts(registration: IdentityProviderRegistration, client: Client) -> Self {
96		let tenant = registration.tenant_id.clone();
97		let provider = registration.provider_id.clone();
98
99		Self {
100			registration: Arc::new(registration),
101			client: Arc::new(client),
102			entry: Arc::new(RwLock::new(CacheEntry::new(tenant, provider))),
103			single_flight: Arc::new(Mutex::new(())),
104		}
105	}
106
107	/// Access the per-provider metrics accumulator.
108	#[cfg(feature = "metrics")]
109	pub fn metrics(&self) -> Arc<ProviderMetrics> {
110		self.metrics.clone()
111	}
112
113	/// Capture the current cache state for status reporting.
114	pub async fn snapshot(&self) -> CacheSnapshot {
115		let captured_at = Instant::now();
116		let captured_at_wallclock = Utc::now();
117		let state = { self.entry.read().await.state().clone() };
118
119		CacheSnapshot { captured_at, captured_at_wallclock, state }
120	}
121
122	#[cfg(feature = "redis")]
123	/// Build a persistence payload capturing the current cache contents.
124	pub async fn persistent_snapshot(&self) -> Result<Option<PersistentSnapshot>> {
125		let snapshot = self.snapshot().await;
126		let payload = match snapshot.state {
127			CacheState::Ready(ref payload) | CacheState::Refreshing(ref payload) => payload.clone(),
128			_ => return Ok(None),
129		};
130		let expires_at = match snapshot.to_datetime(payload.expires_at) {
131			Some(dt) => dt,
132			None => return Ok(None),
133		};
134		let jwks_json = serde_json::to_string(&*payload.jwks)?;
135		let persisted_at = Utc::now();
136		let snapshot = PersistentSnapshot {
137			tenant_id: self.registration.tenant_id.clone(),
138			provider_id: self.registration.provider_id.clone(),
139			jwks_json,
140			etag: payload.etag.clone(),
141			last_modified: payload.last_modified,
142			expires_at,
143			persisted_at,
144		};
145
146		Ok(Some(snapshot))
147	}
148
149	#[cfg(feature = "redis")]
150	/// Restore cache state from a previously persisted snapshot.
151	pub async fn restore_snapshot(&self, snapshot: PersistentSnapshot) -> Result<()> {
152		snapshot.validate(&self.registration)?;
153
154		let PersistentSnapshot { jwks_json, etag, last_modified, expires_at, persisted_at, .. } =
155			snapshot;
156		let jwks: JwkSet = serde_json::from_str(&jwks_json)?;
157		let jwks = Arc::new(jwks);
158		let ttl = (expires_at - persisted_at)
159			.to_std()
160			.unwrap_or_default()
161			.max(self.registration.min_ttl)
162			.min(self.registration.max_ttl);
163		let request = base_request(&self.registration)?;
164		let mut response = Response::builder()
165			.status(200)
166			.header("cache-control", format!("public, max-age={}", ttl.as_secs()))
167			.header("content-type", "application/json")
168			.body(())
169			.map_err(Error::from)?;
170
171		if let Some(ref etag_value) = etag {
172			let value = HeaderValue::from_str(etag_value).map_err(|err| Error::Validation {
173				field: "etag",
174				reason: format!("Invalid persisted ETag: {err}."),
175			})?;
176
177			response.headers_mut().insert(ETAG, value);
178		}
179		if let Some(ref last_modified_value) = last_modified {
180			let http_date = httpdate::fmt_http_date((*last_modified_value).into());
181			let value = HeaderValue::from_str(&http_date).map_err(|err| Error::Validation {
182				field: "last_modified",
183				reason: format!("Invalid persisted Last-Modified: {err}."),
184			})?;
185
186			response.headers_mut().insert(LAST_MODIFIED, value);
187		}
188
189		let policy = CachePolicy::new(&request, &response);
190		let freshness = Freshness { ttl, policy };
191		let now = Instant::now();
192		let payload = self.build_payload(jwks, freshness, etag, last_modified, now, persisted_at);
193
194		{
195			let mut entry = self.entry.write().await;
196
197			entry.load_success(payload.clone());
198		}
199
200		tracing::debug!(
201			tenant = %self.registration.tenant_id,
202			provider = %self.registration.provider_id,
203			"restored cache entry from persistent snapshot"
204		);
205
206		Ok(())
207	}
208
209	/// Resolve JWKS for the registration, fetching upstream when necessary.
210	#[tracing::instrument(
211		skip(self, kid),
212		fields(
213			tenant = %self.registration.tenant_id,
214			provider = %self.registration.provider_id,
215			kid = kid.unwrap_or_default()
216		)
217	)]
218	pub async fn resolve(&self, kid: Option<&str>) -> Result<Arc<JwkSet>> {
219		loop {
220			let snapshot = { self.entry.read().await.snapshot() };
221			let now = Instant::now();
222
223			match snapshot {
224				None => {
225					tracing::debug!("cache empty; performing initial fetch");
226
227					match self.refresh_blocking(true).await? {
228						RefreshOutcome::Updated { jwks, from_cache } => {
229							if from_cache {
230								#[cfg(feature = "metrics")]
231								self.observe_hit(false);
232							} else {
233								#[cfg(feature = "metrics")]
234								self.observe_miss();
235							}
236
237							return Ok(jwks);
238						},
239						RefreshOutcome::Stale(jwks) => {
240							#[cfg(feature = "metrics")]
241							self.observe_hit(true);
242
243							return Ok(jwks);
244						},
245					}
246				},
247				Some(payload) => {
248					if !payload.is_expired(now) {
249						let jwks = payload.jwks.clone();
250
251						#[cfg(feature = "metrics")]
252						self.observe_hit(false);
253
254						if now >= payload.next_refresh_at {
255							self.schedule_background_refresh(now).await;
256						}
257
258						return Ok(jwks);
259					}
260
261					if payload.can_serve_stale(now) {
262						// TODO(refactor): consolidate stale fallback with perform_fetch_with_retry
263						// once the helper can orchestrate stale responses directly.
264						match self.refresh_blocking(false).await {
265							Ok(RefreshOutcome::Updated { jwks, from_cache }) => {
266								if from_cache {
267									#[cfg(feature = "metrics")]
268									self.observe_hit(false);
269								} else {
270									#[cfg(feature = "metrics")]
271									self.observe_miss();
272								}
273
274								return Ok(jwks);
275							},
276							Ok(RefreshOutcome::Stale(jwks)) => {
277								#[cfg(feature = "metrics")]
278								self.observe_hit(true);
279
280								return Ok(jwks);
281							},
282							Err(err) =>
283								if payload.can_serve_stale(Instant::now()) {
284									tracing::warn!(error = %err, "refresh failed, serving stale data");
285
286									#[cfg(feature = "metrics")]
287									self.observe_hit(true);
288
289									return Ok(payload.jwks.clone());
290								} else {
291									return Err(err);
292								},
293						}
294					} else if let RefreshOutcome::Updated { jwks, from_cache } =
295						self.refresh_blocking(true).await?
296					{
297						if from_cache {
298							#[cfg(feature = "metrics")]
299							self.observe_hit(false);
300						} else {
301							#[cfg(feature = "metrics")]
302							self.observe_miss();
303						}
304						return Ok(jwks);
305					}
306				},
307			}
308		}
309	}
310
311	/// Trigger a manual refresh asynchronously; used by the control plane.
312	#[tracing::instrument(
313		skip(self),
314		fields(tenant = %self.registration.tenant_id, provider = %self.registration.provider_id)
315	)]
316	pub async fn trigger_refresh(&self) -> Result<()> {
317		let now = Instant::now();
318		let action = {
319			let mut entry = self.entry.write().await;
320
321			match entry.state() {
322				CacheState::Empty => {
323					entry.begin_load();
324					RefreshTrigger::Blocking
325				},
326				CacheState::Loading | CacheState::Refreshing(_) => RefreshTrigger::None,
327				CacheState::Ready(_) =>
328					if entry.begin_refresh(now) {
329						RefreshTrigger::Background
330					} else {
331						RefreshTrigger::None
332					},
333			}
334		};
335
336		match action {
337			RefreshTrigger::Background => {
338				let manager = self.clone();
339
340				tokio::spawn(async move {
341					if let Err(err) = manager.refresh_blocking(true).await {
342						tracing::warn!(error = %err, "manual refresh failed");
343					}
344				});
345			},
346			RefreshTrigger::Blocking => {
347				self.refresh_blocking(true).await?;
348			},
349			RefreshTrigger::None => {},
350		}
351
352		Ok(())
353	}
354
355	#[tracing::instrument(
356		skip(self),
357		fields(tenant = %self.registration.tenant_id, provider = %self.registration.provider_id)
358	)]
359	async fn schedule_background_refresh(&self, now: Instant) {
360		let should_spawn = {
361			let mut entry = self.entry.write().await;
362
363			entry.begin_refresh(now)
364		};
365		if should_spawn {
366			let manager = self.clone();
367
368			tokio::spawn(async move {
369				if let Err(err) = manager.refresh_blocking(true).await {
370					tracing::debug!(error = %err, "background refresh failed");
371				}
372			});
373		}
374	}
375
376	#[tracing::instrument(
377		skip(self, force_revalidation),
378		fields(tenant = %self.registration.tenant_id, provider = %self.registration.provider_id, force_revalidation)
379	)]
380	async fn refresh_blocking(&self, force_revalidation: bool) -> Result<RefreshOutcome> {
381		let _guard = self.single_flight.lock().await;
382		let now = Instant::now();
383		let (existing, mode) = {
384			let mut entry = self.entry.write().await;
385			let snapshot = entry.snapshot();
386			let mode = if snapshot.is_some() {
387				entry.begin_refresh(now);
388
389				FetchMode::Refresh
390			} else {
391				entry.begin_load();
392
393				FetchMode::Initial
394			};
395
396			(snapshot, mode)
397		};
398
399		match self.prepare_request(existing.as_ref(), force_revalidation)? {
400			PreparedRequest::UseCached { jwks } =>
401				Ok(RefreshOutcome::Updated { jwks, from_cache: true }),
402			PreparedRequest::Send(request) =>
403				self.perform_fetch_with_retry(*request, existing, mode, force_revalidation).await,
404		}
405	}
406
407	fn prepare_request(
408		&self,
409		existing: Option<&CachePayload>,
410		force_revalidation: bool,
411	) -> Result<PreparedRequest> {
412		let mut request = base_request(&self.registration)?;
413
414		if let Some(payload) = existing {
415			let mut send_conditional = force_revalidation;
416
417			match payload.policy.before_request(&request, SystemTime::now()) {
418				BeforeRequest::Fresh(_) if !force_revalidation => {
419					return Ok(PreparedRequest::UseCached { jwks: payload.jwks.clone() });
420				},
421				BeforeRequest::Stale { request: parts, matches } if matches => {
422					request = Request::from_parts(parts, ());
423					send_conditional = true;
424				},
425				_ => {},
426			}
427
428			if send_conditional
429				&& let Some(etag) = &payload.etag
430				&& let Ok(value) = HeaderValue::from_str(etag)
431			{
432				request.headers_mut().insert(IF_NONE_MATCH, value);
433			}
434		}
435
436		Ok(PreparedRequest::Send(Box::new(request)))
437	}
438
439	async fn perform_fetch_with_retry(
440		&self,
441		request: Request<()>,
442		existing: Option<CachePayload>,
443		mode: FetchMode,
444		force_revalidation: bool,
445	) -> Result<RefreshOutcome> {
446		let mut executor = RetryExecutor::new(&self.registration.retry_policy);
447		let mut last_error: Option<Error> = None;
448		let mut last_backoff: Option<Duration> = None;
449		let request = request;
450
451		while let AttemptBudget::Granted { timeout } = executor.attempt_budget() {
452			#[cfg(feature = "metrics")]
453			let attempt_started = Instant::now();
454			let fetch = fetch_jwks(&self.client, &self.registration, &request, timeout).await;
455
456			match fetch {
457				Ok(fetch) => {
458					let now = Instant::now();
459					let payload = match (&fetch.jwks, existing.as_ref()) {
460						(Some(fresh_jwks), _) => {
461							let freshness =
462								evaluate_freshness(&self.registration, &fetch.exchange)?;
463
464							self.build_payload(
465								fresh_jwks.clone(),
466								freshness,
467								fetch.etag.clone(),
468								fetch.last_modified,
469								now,
470								Utc::now(),
471							)
472						},
473						(None, Some(previous)) => {
474							let revalidation = evaluate_revalidation(
475								&self.registration,
476								&previous.policy,
477								&fetch.exchange.request,
478								&fetch.exchange.response,
479							)?;
480							let updated_etag = extract_header(&revalidation.response, &ETAG)
481								.or_else(|| previous.etag.clone());
482
483							self.build_payload(
484								previous.jwks.clone(),
485								revalidation.freshness,
486								updated_etag,
487								extract_last_modified(&revalidation.response)
488									.or(previous.last_modified),
489								now,
490								Utc::now(),
491							)
492						},
493						(None, None) => {
494							return Err(Error::Cache(
495								"Received 304 status without a cached payload.".into(),
496							));
497						},
498					};
499
500					let jwks = payload.jwks.clone();
501
502					self.commit_success(mode, payload).await;
503					#[cfg(feature = "metrics")]
504					self.observe_refresh_success(attempt_started.elapsed());
505
506					return Ok(RefreshOutcome::Updated { jwks, from_cache: false });
507				},
508				Err(err) => {
509					last_error = Some(err);
510
511					if !executor.can_retry() {
512						break;
513					}
514
515					if let Some(delay) = executor.next_backoff() {
516						last_backoff = Some(delay);
517
518						if !delay.is_zero() {
519							time::sleep(delay).await;
520						}
521						continue;
522					}
523
524					break;
525				},
526			}
527		}
528
529		let now = Instant::now();
530
531		match mode {
532			FetchMode::Initial => {
533				let mut entry = self.entry.write().await;
534
535				entry.invalidate();
536			},
537			FetchMode::Refresh => {
538				let mut entry = self.entry.write().await;
539
540				entry.refresh_failure(now, last_backoff);
541			},
542		}
543
544		#[cfg(feature = "metrics")]
545		self.observe_refresh_error();
546
547		if !force_revalidation
548			&& let Some(payload) = existing
549			&& payload.can_serve_stale(now)
550		{
551			return Ok(RefreshOutcome::Stale(payload.jwks));
552		}
553
554		Err(last_error.unwrap_or_else(|| Error::Cache("Refresh attempts exhausted.".into())))
555	}
556
557	async fn commit_success(&self, mode: FetchMode, payload: CachePayload) {
558		let mut entry = self.entry.write().await;
559
560		match mode {
561			FetchMode::Initial => entry.load_success(payload),
562			FetchMode::Refresh => entry.refresh_success(payload),
563		}
564	}
565
566	fn build_payload(
567		&self,
568		jwks: Arc<JwkSet>,
569		freshness: Freshness,
570		etag: Option<String>,
571		last_modified: Option<DateTime<Utc>>,
572		now: Instant,
573		refreshed_at: DateTime<Utc>,
574	) -> CachePayload {
575		let ttl = freshness.ttl;
576		let expires_at = now + ttl;
577		let mut refresh_at = if self.registration.refresh_early >= ttl {
578			now
579		} else {
580			expires_at - self.registration.refresh_early
581		};
582
583		if !self.registration.prefetch_jitter.is_zero() {
584			let jitter = random_jitter(self.registration.prefetch_jitter);
585
586			if refresh_at > now + jitter {
587				refresh_at -= jitter;
588			}
589		}
590
591		let stale_deadline = if self.registration.stale_while_error.is_zero() {
592			None
593		} else {
594			Some(expires_at + self.registration.stale_while_error)
595		};
596
597		CachePayload {
598			jwks,
599			policy: freshness.policy,
600			etag,
601			last_modified,
602			last_refresh_at: refreshed_at,
603			expires_at,
604			next_refresh_at: refresh_at,
605			stale_deadline,
606			retry_backoff: None,
607			error_count: 0,
608		}
609	}
610
611	#[cfg(feature = "metrics")]
612	fn observe_hit(&self, stale: bool) {
613		let tenant = &self.registration.tenant_id;
614		let provider = &self.registration.provider_id;
615
616		metrics::record_resolve_hit(tenant, provider, stale);
617
618		self.metrics.record_hit(stale);
619	}
620
621	#[cfg(feature = "metrics")]
622	fn observe_miss(&self) {
623		let tenant = &self.registration.tenant_id;
624		let provider = &self.registration.provider_id;
625
626		metrics::record_resolve_miss(tenant, provider);
627
628		self.metrics.record_miss();
629	}
630
631	#[cfg(feature = "metrics")]
632	fn observe_refresh_success(&self, duration: Duration) {
633		let tenant = &self.registration.tenant_id;
634		let provider = &self.registration.provider_id;
635
636		metrics::record_refresh_success(tenant, provider, duration);
637
638		self.metrics.record_refresh_success(duration);
639	}
640
641	#[cfg(feature = "metrics")]
642	fn observe_refresh_error(&self) {
643		let tenant = &self.registration.tenant_id;
644		let provider = &self.registration.provider_id;
645
646		metrics::record_refresh_error(tenant, provider);
647
648		self.metrics.record_refresh_error();
649	}
650}
651
652/// Snapshot of cache state captured for status reporting.
653#[derive(Clone, Debug)]
654pub struct CacheSnapshot {
655	/// Monotonic instant when the snapshot was taken.
656	pub captured_at: Instant,
657	/// Wall-clock timestamp that aligns with `captured_at`.
658	pub captured_at_wallclock: DateTime<Utc>,
659	/// Cache state recorded at capture time.
660	pub state: CacheState,
661}
662impl CacheSnapshot {
663	/// Convert a monotonic instant drawn from the cached payload into UTC.
664	pub fn to_datetime(&self, instant: Instant) -> Option<DateTime<Utc>> {
665		if let Some(delta) = instant.checked_duration_since(self.captured_at) {
666			let chrono = TimeDelta::from_std(delta).ok()?;
667
668			self.captured_at_wallclock.checked_add_signed(chrono)
669		} else if let Some(delta) = self.captured_at.checked_duration_since(instant) {
670			let chrono = TimeDelta::from_std(delta).ok()?;
671
672			self.captured_at_wallclock.checked_sub_signed(chrono)
673		} else {
674			None
675		}
676	}
677}
678
/// Whether a fetch populates an empty entry or refreshes an existing payload.
///
/// The mode selects which entry transition is applied when the fetch finishes.
#[derive(Debug, Clone, Copy)]
enum FetchMode {
	/// No payload was cached when the fetch started.
	Initial,
	/// A payload was already cached when the fetch started.
	Refresh,
}
684
685#[derive(Debug)]
686enum RefreshOutcome {
687	Updated { jwks: Arc<JwkSet>, from_cache: bool },
688	Stale(Arc<JwkSet>),
689}
690
/// Action chosen by `trigger_refresh` after inspecting the entry state.
#[derive(Debug, Clone, Copy)]
enum RefreshTrigger {
	/// Refresh on a spawned task; failures are only logged.
	Background,
	/// Refresh inline; failures are returned to the caller.
	Blocking,
	/// Do not start a refresh.
	None,
}
697
698#[derive(Debug)]
699enum PreparedRequest {
700	UseCached { jwks: Arc<JwkSet> },
701	Send(Box<Request<()>>),
702}
703
704fn random_jitter(max: Duration) -> Duration {
705	if max.is_zero() {
706		return Duration::ZERO;
707	}
708
709	let mut rng = rand::rng();
710	let jitter = rng.random_range(0.0..=max.as_secs_f64());
711
712	Duration::from_secs_f64(jitter)
713}
714
715fn extract_header(response: &Response<()>, name: &HeaderName) -> Option<String> {
716	response.headers().get(name).and_then(|value| value.to_str().ok()).map(|s| s.to_string())
717}
718
719fn extract_last_modified(response: &Response<()>) -> Option<DateTime<Utc>> {
720	response
721		.headers()
722		.get(LAST_MODIFIED)
723		.and_then(|value| value.to_str().ok())
724		.and_then(|raw| httpdate::parse_http_date(raw).ok())
725		.map(<DateTime<Utc>>::from)
726}