1use http::{
5 HeaderName, HeaderValue, Request, Response,
6 header::{ETAG, IF_NONE_MATCH, LAST_MODIFIED},
7};
8use http_cache_semantics::BeforeRequest;
9#[cfg(feature = "redis")] use http_cache_semantics::CachePolicy;
10use jsonwebtoken::jwk::JwkSet;
11use rand::Rng;
12use reqwest::{Client, redirect::Policy};
13use tokio::{
14 sync::{Mutex, RwLock},
15 time,
16};
17#[cfg(feature = "metrics")] use crate::metrics::{self, ProviderMetrics};
19#[cfg(feature = "redis")] use crate::registry::PersistentSnapshot;
20use crate::{
21 _prelude::*,
22 cache::{
23 entry::CacheEntry,
24 state::{CachePayload, CacheState},
25 },
26 http::{
27 client::fetch_jwks,
28 retry::{AttemptBudget, RetryExecutor},
29 semantics::{Freshness, base_request, evaluate_freshness, evaluate_revalidation},
30 },
31 registry::IdentityProviderRegistration,
32};
33
/// Per-provider JWKS cache handle.
///
/// Every field is wrapped in an `Arc`, so cloning the manager yields another handle to the same
/// registration, HTTP client, cache entry, and single-flight lock.
#[derive(Clone, Debug)]
pub struct CacheManager {
    /// Registration describing the provider this manager serves.
    registration: Arc<IdentityProviderRegistration>,
    /// HTTP client used to fetch the JWKS document.
    client: Arc<Client>,
    /// Cache entry state, guarded by an async read/write lock.
    entry: Arc<RwLock<CacheEntry>>,
    /// Mutex held for the whole of `refresh_blocking`, so only one refresh runs at a time.
    single_flight: Arc<Mutex<()>>,
    /// Per-provider metrics accumulator (only present with the `metrics` feature).
    #[cfg(feature = "metrics")]
    metrics: Arc<ProviderMetrics>,
}
47impl CacheManager {
48 pub fn new(registration: IdentityProviderRegistration) -> Result<Self> {
50 registration.validate()?;
51
52 let client = Client::builder()
53 .redirect(Policy::limited(10))
54 .user_agent(format!("jwks-cache/{}", env!("CARGO_PKG_VERSION")))
55 .connect_timeout(Duration::from_secs(5))
56 .build()?;
57
58 #[cfg(feature = "metrics")]
59 let manager = Self::with_parts(registration, client, ProviderMetrics::new());
60 #[cfg(not(feature = "metrics"))]
61 let manager = Self::with_parts(registration, client);
62
63 Ok(manager)
64 }
65
66 pub fn with_client(registration: IdentityProviderRegistration, client: Client) -> Self {
68 #[cfg(feature = "metrics")]
69 let manager = Self::with_parts(registration, client, ProviderMetrics::new());
70 #[cfg(not(feature = "metrics"))]
71 let manager = Self::with_parts(registration, client);
72
73 manager
74 }
75
/// Assembles a manager from its parts (variant compiled with the `metrics` feature).
///
/// The cache entry is created from the registration's tenant and provider identifiers.
#[cfg(feature = "metrics")]
fn with_parts(
    registration: IdentityProviderRegistration,
    client: Client,
    metrics: Arc<ProviderMetrics>,
) -> Self {
    // Build the entry first: it needs owned copies of the ids before `registration` is moved.
    let entry = CacheEntry::new(registration.tenant_id.clone(), registration.provider_id.clone());

    Self {
        registration: Arc::new(registration),
        client: Arc::new(client),
        entry: Arc::new(RwLock::new(entry)),
        single_flight: Arc::new(Mutex::new(())),
        metrics,
    }
}
93
94 #[cfg(not(feature = "metrics"))]
95 fn with_parts(registration: IdentityProviderRegistration, client: Client) -> Self {
96 let tenant = registration.tenant_id.clone();
97 let provider = registration.provider_id.clone();
98
99 Self {
100 registration: Arc::new(registration),
101 client: Arc::new(client),
102 entry: Arc::new(RwLock::new(CacheEntry::new(tenant, provider))),
103 single_flight: Arc::new(Mutex::new(())),
104 }
105 }
106
/// Returns a shared handle to this provider's metrics accumulator.
#[cfg(feature = "metrics")]
pub fn metrics(&self) -> Arc<ProviderMetrics> {
    Arc::clone(&self.metrics)
}
112
113 pub async fn snapshot(&self) -> CacheSnapshot {
115 let captured_at = Instant::now();
116 let captured_at_wallclock = Utc::now();
117 let state = { self.entry.read().await.state().clone() };
118
119 CacheSnapshot { captured_at, captured_at_wallclock, state }
120 }
121
/// Produces a serializable snapshot of the cached JWKS, if there is one.
///
/// Returns `Ok(None)` when the entry is neither `Ready` nor `Refreshing`, or when the payload's
/// expiry instant cannot be converted to a wall-clock timestamp.
///
/// # Errors
///
/// Returns an error when the JWKS cannot be serialized to JSON.
#[cfg(feature = "redis")]
pub async fn persistent_snapshot(&self) -> Result<Option<PersistentSnapshot>> {
    let captured = self.snapshot().await;
    let (CacheState::Ready(payload) | CacheState::Refreshing(payload)) = &captured.state else {
        return Ok(None);
    };
    let Some(expires_at) = captured.to_datetime(payload.expires_at) else {
        return Ok(None);
    };
    let jwks_json = serde_json::to_string(&*payload.jwks)?;
    let persisted_at = Utc::now();

    Ok(Some(PersistentSnapshot {
        tenant_id: self.registration.tenant_id.clone(),
        provider_id: self.registration.provider_id.clone(),
        jwks_json,
        etag: payload.etag.clone(),
        last_modified: payload.last_modified,
        expires_at,
        persisted_at,
    }))
}
148
/// Loads a previously persisted snapshot into the cache entry.
///
/// The snapshot is validated against this manager's registration, its JWKS JSON is parsed, and
/// a synthetic `200` response (carrying `cache-control`, plus `ETag` / `Last-Modified` when
/// present) is used to rebuild the HTTP cache policy before the payload is stored.
///
/// # Errors
///
/// Returns an error when the snapshot fails validation, the JWKS JSON cannot be parsed, the
/// base request or synthetic response cannot be built, or a persisted validator is not a valid
/// header value.
#[cfg(feature = "redis")]
pub async fn restore_snapshot(&self, snapshot: PersistentSnapshot) -> Result<()> {
    snapshot.validate(&self.registration)?;

    let PersistentSnapshot { jwks_json, etag, last_modified, expires_at, persisted_at, .. } =
        snapshot;
    let jwks = Arc::new(serde_json::from_str::<JwkSet>(&jwks_json)?);
    // NOTE(review): the TTL is the snapshot's original lifetime (expiry minus persist time),
    // clamped to the registration bounds, and it is applied from *now* below — the time the
    // snapshot spent in storage is not subtracted. Confirm this is intended.
    let ttl = (expires_at - persisted_at)
        .to_std()
        .unwrap_or_default()
        .max(self.registration.min_ttl)
        .min(self.registration.max_ttl);
    let request = base_request(&self.registration)?;
    let mut response = Response::builder()
        .status(200)
        .header("cache-control", format!("public, max-age={}", ttl.as_secs()))
        .header("content-type", "application/json")
        .body(())
        .map_err(Error::from)?;

    if let Some(etag_value) = etag.as_deref() {
        let value = HeaderValue::from_str(etag_value).map_err(|err| Error::Validation {
            field: "etag",
            reason: format!("Invalid persisted ETag: {err}."),
        })?;

        response.headers_mut().insert(ETAG, value);
    }
    if let Some(last_modified_value) = last_modified {
        let http_date = httpdate::fmt_http_date(last_modified_value.into());
        let value = HeaderValue::from_str(&http_date).map_err(|err| Error::Validation {
            field: "last_modified",
            reason: format!("Invalid persisted Last-Modified: {err}."),
        })?;

        response.headers_mut().insert(LAST_MODIFIED, value);
    }

    let policy = CachePolicy::new(&request, &response);
    let freshness = Freshness { ttl, policy };
    let now = Instant::now();
    let payload = self.build_payload(jwks, freshness, etag, last_modified, now, persisted_at);

    // The payload is not needed afterwards, so it is moved into the entry rather than cloned;
    // the write guard is a statement temporary and is released right after the call.
    self.entry.write().await.load_success(payload);

    tracing::debug!(
        tenant = %self.registration.tenant_id,
        provider = %self.registration.provider_id,
        "restored cache entry from persistent snapshot"
    );

    Ok(())
}
208
/// Returns the JWKS for this provider, refreshing the cache when needed.
///
/// Behaviour by cache condition:
/// - no payload: performs a blocking, forced refresh;
/// - payload not expired: returns it immediately and, once `next_refresh_at` has passed,
///   schedules a background refresh;
/// - payload expired but still servable as stale: attempts a non-forced blocking refresh and
///   falls back to the stale payload if that fails while stale serving is still allowed;
/// - payload expired and not servable: performs a blocking, forced refresh.
///
/// `kid` is only recorded on the tracing span; this method always returns the whole key set.
///
/// # Errors
///
/// Propagates the refresh error when no payload can be served.
#[tracing::instrument(
    skip(self, kid),
    fields(
        tenant = %self.registration.tenant_id,
        provider = %self.registration.provider_id,
        kid = kid.unwrap_or_default()
    )
)]
pub async fn resolve(&self, kid: Option<&str>) -> Result<Arc<JwkSet>> {
    loop {
        // The read guard lives only for this block, so no lock is held while refreshing.
        let snapshot = { self.entry.read().await.snapshot() };
        let now = Instant::now();

        match snapshot {
            None => {
                tracing::debug!("cache empty; performing initial fetch");

                match self.refresh_blocking(true).await? {
                    RefreshOutcome::Updated { jwks, from_cache } => {
                        // Answered from the cached policy counts as a hit, a fetch as a miss.
                        if from_cache {
                            #[cfg(feature = "metrics")]
                            self.observe_hit(false);
                        } else {
                            #[cfg(feature = "metrics")]
                            self.observe_miss();
                        }

                        return Ok(jwks);
                    },
                    RefreshOutcome::Stale(jwks) => {
                        #[cfg(feature = "metrics")]
                        self.observe_hit(true);

                        return Ok(jwks);
                    },
                }
            },
            Some(payload) => {
                // Fast path: the payload has not expired yet.
                if !payload.is_expired(now) {
                    let jwks = payload.jwks.clone();

                    #[cfg(feature = "metrics")]
                    self.observe_hit(false);

                    // Past the early-refresh point: refresh in the background.
                    if now >= payload.next_refresh_at {
                        self.schedule_background_refresh(now).await;
                    }

                    return Ok(jwks);
                }

                if payload.can_serve_stale(now) {
                    // Expired but inside the stale window: refresh without forcing
                    // revalidation so the refresh itself may hand back a stale payload.
                    match self.refresh_blocking(false).await {
                        Ok(RefreshOutcome::Updated { jwks, from_cache }) => {
                            if from_cache {
                                #[cfg(feature = "metrics")]
                                self.observe_hit(false);
                            } else {
                                #[cfg(feature = "metrics")]
                                self.observe_miss();
                            }

                            return Ok(jwks);
                        },
                        Ok(RefreshOutcome::Stale(jwks)) => {
                            #[cfg(feature = "metrics")]
                            self.observe_hit(true);

                            return Ok(jwks);
                        },
                        Err(err) =>
                        // The refresh took time, so the stale window is re-checked
                        // against the current instant.
                            if payload.can_serve_stale(Instant::now()) {
                                tracing::warn!(error = %err, "refresh failed, serving stale data");

                                #[cfg(feature = "metrics")]
                                self.observe_hit(true);

                                return Ok(payload.jwks.clone());
                            } else {
                                return Err(err);
                            },
                    }
                } else if let RefreshOutcome::Updated { jwks, from_cache } =
                    self.refresh_blocking(true).await?
                {
                    if from_cache {
                        #[cfg(feature = "metrics")]
                        self.observe_hit(false);
                    } else {
                        #[cfg(feature = "metrics")]
                        self.observe_miss();
                    }
                    return Ok(jwks);
                }
                // Any other outcome falls through and the loop re-reads the entry.
            },
        }
    }
}
310
311 #[tracing::instrument(
313 skip(self),
314 fields(tenant = %self.registration.tenant_id, provider = %self.registration.provider_id)
315 )]
316 pub async fn trigger_refresh(&self) -> Result<()> {
317 let now = Instant::now();
318 let action = {
319 let mut entry = self.entry.write().await;
320
321 match entry.state() {
322 CacheState::Empty => {
323 entry.begin_load();
324 RefreshTrigger::Blocking
325 },
326 CacheState::Loading | CacheState::Refreshing(_) => RefreshTrigger::None,
327 CacheState::Ready(_) =>
328 if entry.begin_refresh(now) {
329 RefreshTrigger::Background
330 } else {
331 RefreshTrigger::None
332 },
333 }
334 };
335
336 match action {
337 RefreshTrigger::Background => {
338 let manager = self.clone();
339
340 tokio::spawn(async move {
341 if let Err(err) = manager.refresh_blocking(true).await {
342 tracing::warn!(error = %err, "manual refresh failed");
343 }
344 });
345 },
346 RefreshTrigger::Blocking => {
347 self.refresh_blocking(true).await?;
348 },
349 RefreshTrigger::None => {},
350 }
351
352 Ok(())
353 }
354
355 #[tracing::instrument(
356 skip(self),
357 fields(tenant = %self.registration.tenant_id, provider = %self.registration.provider_id)
358 )]
359 async fn schedule_background_refresh(&self, now: Instant) {
360 let should_spawn = {
361 let mut entry = self.entry.write().await;
362
363 entry.begin_refresh(now)
364 };
365 if should_spawn {
366 let manager = self.clone();
367
368 tokio::spawn(async move {
369 if let Err(err) = manager.refresh_blocking(true).await {
370 tracing::debug!(error = %err, "background refresh failed");
371 }
372 });
373 }
374 }
375
376 #[tracing::instrument(
377 skip(self, force_revalidation),
378 fields(tenant = %self.registration.tenant_id, provider = %self.registration.provider_id, force_revalidation)
379 )]
380 async fn refresh_blocking(&self, force_revalidation: bool) -> Result<RefreshOutcome> {
381 let _guard = self.single_flight.lock().await;
382 let now = Instant::now();
383 let (existing, mode) = {
384 let mut entry = self.entry.write().await;
385 let snapshot = entry.snapshot();
386 let mode = if snapshot.is_some() {
387 entry.begin_refresh(now);
388
389 FetchMode::Refresh
390 } else {
391 entry.begin_load();
392
393 FetchMode::Initial
394 };
395
396 (snapshot, mode)
397 };
398
399 match self.prepare_request(existing.as_ref(), force_revalidation)? {
400 PreparedRequest::UseCached { jwks } =>
401 Ok(RefreshOutcome::Updated { jwks, from_cache: true }),
402 PreparedRequest::Send(request) =>
403 self.perform_fetch_with_retry(*request, existing, mode, force_revalidation).await,
404 }
405 }
406
407 fn prepare_request(
408 &self,
409 existing: Option<&CachePayload>,
410 force_revalidation: bool,
411 ) -> Result<PreparedRequest> {
412 let mut request = base_request(&self.registration)?;
413
414 if let Some(payload) = existing {
415 let mut send_conditional = force_revalidation;
416
417 match payload.policy.before_request(&request, SystemTime::now()) {
418 BeforeRequest::Fresh(_) if !force_revalidation => {
419 return Ok(PreparedRequest::UseCached { jwks: payload.jwks.clone() });
420 },
421 BeforeRequest::Stale { request: parts, matches } if matches => {
422 request = Request::from_parts(parts, ());
423 send_conditional = true;
424 },
425 _ => {},
426 }
427
428 if send_conditional
429 && let Some(etag) = &payload.etag
430 && let Ok(value) = HeaderValue::from_str(etag)
431 {
432 request.headers_mut().insert(IF_NONE_MATCH, value);
433 }
434 }
435
436 Ok(PreparedRequest::Send(Box::new(request)))
437 }
438
439 async fn perform_fetch_with_retry(
440 &self,
441 request: Request<()>,
442 existing: Option<CachePayload>,
443 mode: FetchMode,
444 force_revalidation: bool,
445 ) -> Result<RefreshOutcome> {
446 let mut executor = RetryExecutor::new(&self.registration.retry_policy);
447 let mut last_error: Option<Error> = None;
448 let mut last_backoff: Option<Duration> = None;
449 let request = request;
450
451 while let AttemptBudget::Granted { timeout } = executor.attempt_budget() {
452 #[cfg(feature = "metrics")]
453 let attempt_started = Instant::now();
454 let fetch = fetch_jwks(&self.client, &self.registration, &request, timeout).await;
455
456 match fetch {
457 Ok(fetch) => {
458 let now = Instant::now();
459 let payload = match (&fetch.jwks, existing.as_ref()) {
460 (Some(fresh_jwks), _) => {
461 let freshness =
462 evaluate_freshness(&self.registration, &fetch.exchange)?;
463
464 self.build_payload(
465 fresh_jwks.clone(),
466 freshness,
467 fetch.etag.clone(),
468 fetch.last_modified,
469 now,
470 Utc::now(),
471 )
472 },
473 (None, Some(previous)) => {
474 let revalidation = evaluate_revalidation(
475 &self.registration,
476 &previous.policy,
477 &fetch.exchange.request,
478 &fetch.exchange.response,
479 )?;
480 let updated_etag = extract_header(&revalidation.response, &ETAG)
481 .or_else(|| previous.etag.clone());
482
483 self.build_payload(
484 previous.jwks.clone(),
485 revalidation.freshness,
486 updated_etag,
487 extract_last_modified(&revalidation.response)
488 .or(previous.last_modified),
489 now,
490 Utc::now(),
491 )
492 },
493 (None, None) => {
494 return Err(Error::Cache(
495 "Received 304 status without a cached payload.".into(),
496 ));
497 },
498 };
499
500 let jwks = payload.jwks.clone();
501
502 self.commit_success(mode, payload).await;
503 #[cfg(feature = "metrics")]
504 self.observe_refresh_success(attempt_started.elapsed());
505
506 return Ok(RefreshOutcome::Updated { jwks, from_cache: false });
507 },
508 Err(err) => {
509 last_error = Some(err);
510
511 if !executor.can_retry() {
512 break;
513 }
514
515 if let Some(delay) = executor.next_backoff() {
516 last_backoff = Some(delay);
517
518 if !delay.is_zero() {
519 time::sleep(delay).await;
520 }
521 continue;
522 }
523
524 break;
525 },
526 }
527 }
528
529 let now = Instant::now();
530
531 match mode {
532 FetchMode::Initial => {
533 let mut entry = self.entry.write().await;
534
535 entry.invalidate();
536 },
537 FetchMode::Refresh => {
538 let mut entry = self.entry.write().await;
539
540 entry.refresh_failure(now, last_backoff);
541 },
542 }
543
544 #[cfg(feature = "metrics")]
545 self.observe_refresh_error();
546
547 if !force_revalidation
548 && let Some(payload) = existing
549 && payload.can_serve_stale(now)
550 {
551 return Ok(RefreshOutcome::Stale(payload.jwks));
552 }
553
554 Err(last_error.unwrap_or_else(|| Error::Cache("Refresh attempts exhausted.".into())))
555 }
556
557 async fn commit_success(&self, mode: FetchMode, payload: CachePayload) {
558 let mut entry = self.entry.write().await;
559
560 match mode {
561 FetchMode::Initial => entry.load_success(payload),
562 FetchMode::Refresh => entry.refresh_success(payload),
563 }
564 }
565
566 fn build_payload(
567 &self,
568 jwks: Arc<JwkSet>,
569 freshness: Freshness,
570 etag: Option<String>,
571 last_modified: Option<DateTime<Utc>>,
572 now: Instant,
573 refreshed_at: DateTime<Utc>,
574 ) -> CachePayload {
575 let ttl = freshness.ttl;
576 let expires_at = now + ttl;
577 let mut refresh_at = if self.registration.refresh_early >= ttl {
578 now
579 } else {
580 expires_at - self.registration.refresh_early
581 };
582
583 if !self.registration.prefetch_jitter.is_zero() {
584 let jitter = random_jitter(self.registration.prefetch_jitter);
585
586 if refresh_at > now + jitter {
587 refresh_at -= jitter;
588 }
589 }
590
591 let stale_deadline = if self.registration.stale_while_error.is_zero() {
592 None
593 } else {
594 Some(expires_at + self.registration.stale_while_error)
595 };
596
597 CachePayload {
598 jwks,
599 policy: freshness.policy,
600 etag,
601 last_modified,
602 last_refresh_at: refreshed_at,
603 expires_at,
604 next_refresh_at: refresh_at,
605 stale_deadline,
606 retry_backoff: None,
607 error_count: 0,
608 }
609 }
610
/// Records a resolve hit (`stale` marks a stale serve) in both the crate-level metrics and
/// this provider's accumulator.
#[cfg(feature = "metrics")]
fn observe_hit(&self, stale: bool) {
    let registration = &self.registration;

    metrics::record_resolve_hit(&registration.tenant_id, &registration.provider_id, stale);
    self.metrics.record_hit(stale);
}
620
/// Records a resolve miss in both the crate-level metrics and this provider's accumulator.
#[cfg(feature = "metrics")]
fn observe_miss(&self) {
    let registration = &self.registration;

    metrics::record_resolve_miss(&registration.tenant_id, &registration.provider_id);
    self.metrics.record_miss();
}
630
/// Records a successful refresh and its duration in both the crate-level metrics and this
/// provider's accumulator.
#[cfg(feature = "metrics")]
fn observe_refresh_success(&self, duration: Duration) {
    let registration = &self.registration;

    metrics::record_refresh_success(&registration.tenant_id, &registration.provider_id, duration);
    self.metrics.record_refresh_success(duration);
}
640
/// Records a failed refresh in both the crate-level metrics and this provider's accumulator.
#[cfg(feature = "metrics")]
fn observe_refresh_error(&self) {
    let registration = &self.registration;

    metrics::record_refresh_error(&registration.tenant_id, &registration.provider_id);
    self.metrics.record_refresh_error();
}
650}
651
/// Point-in-time copy of a cache entry's state, with the clocks read at capture time.
#[derive(Clone, Debug)]
pub struct CacheSnapshot {
    /// Monotonic instant at which the snapshot was taken.
    pub captured_at: Instant,
    /// Wall-clock time at which the snapshot was taken.
    pub captured_at_wallclock: DateTime<Utc>,
    /// Cache state cloned from the entry.
    pub state: CacheState,
}
662impl CacheSnapshot {
663 pub fn to_datetime(&self, instant: Instant) -> Option<DateTime<Utc>> {
665 if let Some(delta) = instant.checked_duration_since(self.captured_at) {
666 let chrono = TimeDelta::from_std(delta).ok()?;
667
668 self.captured_at_wallclock.checked_add_signed(chrono)
669 } else if let Some(delta) = self.captured_at.checked_duration_since(instant) {
670 let chrono = TimeDelta::from_std(delta).ok()?;
671
672 self.captured_at_wallclock.checked_sub_signed(chrono)
673 } else {
674 None
675 }
676 }
677}
678
/// Which entry transition a fetch belongs to; selects the success and failure handlers.
#[derive(Clone, Copy, Debug)]
enum FetchMode {
    /// No payload existed when the fetch started.
    Initial,
    /// A payload existed when the fetch started.
    Refresh,
}
684
/// Result of a blocking refresh.
#[derive(Debug)]
enum RefreshOutcome {
    /// A usable key set; `from_cache` is `true` when it was served from the cached policy
    /// without sending a request.
    Updated { jwks: Arc<JwkSet>, from_cache: bool },
    /// The fetch failed and the previous payload is served as stale.
    Stale(Arc<JwkSet>),
}
690
/// Action chosen by `trigger_refresh` after inspecting the entry state.
#[derive(Clone, Copy, Debug)]
enum RefreshTrigger {
    /// Spawn a detached refresh task.
    Background,
    /// Refresh inline and report the result to the caller.
    Blocking,
    /// Do nothing.
    None,
}
697
/// Outcome of `prepare_request`: reuse the cached key set or send a request.
#[derive(Debug)]
enum PreparedRequest {
    /// The cached policy is still fresh; no request is needed.
    UseCached { jwks: Arc<JwkSet> },
    /// The request to send (boxed).
    Send(Box<Request<()>>),
}
703
704fn random_jitter(max: Duration) -> Duration {
705 if max.is_zero() {
706 return Duration::ZERO;
707 }
708
709 let mut rng = rand::rng();
710 let jitter = rng.random_range(0.0..=max.as_secs_f64());
711
712 Duration::from_secs_f64(jitter)
713}
714
715fn extract_header(response: &Response<()>, name: &HeaderName) -> Option<String> {
716 response.headers().get(name).and_then(|value| value.to_str().ok()).map(|s| s.to_string())
717}
718
719fn extract_last_modified(response: &Response<()>) -> Option<DateTime<Utc>> {
720 response
721 .headers()
722 .get(LAST_MODIFIED)
723 .and_then(|value| value.to_str().ok())
724 .and_then(|raw| httpdate::parse_http_date(raw).ok())
725 .map(<DateTime<Utc>>::from)
726}