1use http::{
5 HeaderName, HeaderValue, Request, Response,
6 header::{ETAG, IF_NONE_MATCH, LAST_MODIFIED},
7};
8use http_cache_semantics::BeforeRequest;
9#[cfg(feature = "redis")] use http_cache_semantics::CachePolicy;
10use jsonwebtoken::jwk::JwkSet;
11use rand::Rng;
12use reqwest::{Client, redirect::Policy};
13use tokio::{
14 sync::{Mutex, RwLock},
15 time,
16};
17#[cfg(feature = "redis")] use crate::registry::PersistentSnapshot;
19use crate::{
20 _prelude::*,
21 cache::{
22 entry::CacheEntry,
23 state::{CachePayload, CacheState},
24 },
25 http::{
26 client::fetch_jwks,
27 retry::{AttemptBudget, RetryExecutor},
28 semantics::{Freshness, base_request, evaluate_freshness, evaluate_revalidation},
29 },
30 metrics::{self, ProviderMetrics},
31 registry::IdentityProviderRegistration,
32};
33
/// Cache coordinator for the JWKS of a single identity-provider registration.
///
/// Every field is held behind an `Arc`, so the derived `Clone` only bumps reference counts and
/// all clones share the same cache entry, HTTP client, single-flight lock and metrics.
#[derive(Clone, Debug)]
pub struct CacheManager {
    /// Registration this manager serves (tenant/provider ids, TTL bounds, retry policy, ...).
    registration: Arc<IdentityProviderRegistration>,
    /// HTTP client used to fetch the JWKS document.
    client: Arc<Client>,
    /// Cache entry state, guarded by an async read/write lock.
    entry: Arc<RwLock<CacheEntry>>,
    /// Async mutex held for the whole of `refresh_blocking` so refreshes run one at a time.
    single_flight: Arc<Mutex<()>>,
    /// Per-provider metrics accumulator updated by the `observe_*` helpers.
    metrics: Arc<ProviderMetrics>,
}
46impl CacheManager {
47 pub fn new(registration: IdentityProviderRegistration) -> Result<Self> {
49 registration.validate()?;
50
51 let client = Client::builder()
52 .redirect(Policy::limited(10))
53 .user_agent(format!("jwks-cache/{}", env!("CARGO_PKG_VERSION")))
54 .connect_timeout(Duration::from_secs(5))
55 .build()?;
56
57 Ok(Self::with_parts(registration, client, ProviderMetrics::new()))
58 }
59
60 pub fn with_client(registration: IdentityProviderRegistration, client: Client) -> Self {
62 Self::with_parts(registration, client, ProviderMetrics::new())
63 }
64
65 fn with_parts(
66 registration: IdentityProviderRegistration,
67 client: Client,
68 metrics: Arc<ProviderMetrics>,
69 ) -> Self {
70 let tenant = registration.tenant_id.clone();
71 let provider = registration.provider_id.clone();
72
73 Self {
74 registration: Arc::new(registration),
75 client: Arc::new(client),
76 entry: Arc::new(RwLock::new(CacheEntry::new(tenant, provider))),
77 single_flight: Arc::new(Mutex::new(())),
78 metrics,
79 }
80 }
81
82 pub fn metrics(&self) -> Arc<ProviderMetrics> {
84 self.metrics.clone()
85 }
86
87 pub async fn snapshot(&self) -> CacheSnapshot {
89 let captured_at = Instant::now();
90 let captured_at_wallclock = Utc::now();
91 let state = { self.entry.read().await.state().clone() };
92
93 CacheSnapshot { captured_at, captured_at_wallclock, state }
94 }
95
/// Produce a serialisable snapshot of the cached JWKS, if there is one.
///
/// Returns `Ok(None)` when the entry is neither `Ready` nor `Refreshing`, or when the
/// payload's expiry instant cannot be converted to a wall-clock time.
///
/// # Errors
///
/// Fails when the JWKS cannot be serialised to JSON.
#[cfg(feature = "redis")]
pub async fn persistent_snapshot(&self) -> Result<Option<PersistentSnapshot>> {
    let snapshot = self.snapshot().await;
    let (CacheState::Ready(payload) | CacheState::Refreshing(payload)) = &snapshot.state else {
        return Ok(None);
    };
    let Some(expires_at) = snapshot.to_datetime(payload.expires_at) else {
        return Ok(None);
    };
    let jwks_json = serde_json::to_string(&*payload.jwks)?;
    let persisted_at = Utc::now();

    Ok(Some(PersistentSnapshot {
        tenant_id: self.registration.tenant_id.clone(),
        provider_id: self.registration.provider_id.clone(),
        jwks_json,
        etag: payload.etag.clone(),
        last_modified: payload.last_modified,
        expires_at,
        persisted_at,
    }))
}
122
/// Rehydrate the cache entry from a previously persisted snapshot.
///
/// The snapshot is validated against this manager's registration and its JWKS JSON is parsed.
/// A synthetic `200` response carrying `cache-control: public, max-age=<ttl>` plus the
/// persisted validators is then used to rebuild the HTTP cache policy, and the resulting
/// payload is installed through `CacheEntry::load_success`.
///
/// # Errors
///
/// Fails when snapshot validation fails, the JWKS JSON cannot be deserialised, the base
/// request or synthetic response cannot be built, or the persisted ETag / Last-Modified is not
/// a valid header value.
#[cfg(feature = "redis")]
pub async fn restore_snapshot(&self, snapshot: PersistentSnapshot) -> Result<()> {
    snapshot.validate(&self.registration)?;

    let PersistentSnapshot { jwks_json, etag, last_modified, expires_at, persisted_at, .. } =
        snapshot;
    let jwks: JwkSet = serde_json::from_str(&jwks_json)?;
    let jwks = Arc::new(jwks);
    // Lifetime the entry had left when it was persisted; a negative span collapses to zero and
    // the result is clamped into the registration's `[min_ttl, max_ttl]` range.
    // NOTE(review): this TTL is applied from "now", so time spent in storage is not deducted
    // from the restored entry's lifetime — confirm that is intended.
    let ttl = (expires_at - persisted_at)
        .to_std()
        .unwrap_or_default()
        .max(self.registration.min_ttl)
        .min(self.registration.max_ttl);
    let request = base_request(&self.registration)?;
    let mut response = Response::builder()
        .status(200)
        .header("cache-control", format!("public, max-age={}", ttl.as_secs()))
        .header("content-type", "application/json")
        .body(())
        .map_err(Error::from)?;

    if let Some(ref etag_value) = etag {
        let value = HeaderValue::from_str(etag_value).map_err(|err| Error::Validation {
            field: "etag",
            reason: format!("Invalid persisted ETag: {err}."),
        })?;

        response.headers_mut().insert(ETAG, value);
    }
    if let Some(ref last_modified_value) = last_modified {
        let http_date = httpdate::fmt_http_date((*last_modified_value).into());
        let value = HeaderValue::from_str(&http_date).map_err(|err| Error::Validation {
            field: "last_modified",
            reason: format!("Invalid persisted Last-Modified: {err}."),
        })?;

        response.headers_mut().insert(LAST_MODIFIED, value);
    }

    let policy = CachePolicy::new(&request, &response);
    let freshness = Freshness { ttl, policy };
    let now = Instant::now();
    let payload = self.build_payload(jwks, freshness, etag, last_modified, now, persisted_at);

    {
        let mut entry = self.entry.write().await;

        // The payload is not used afterwards, so it is moved in rather than cloned.
        entry.load_success(payload);
    }

    tracing::debug!(
        tenant = %self.registration.tenant_id,
        provider = %self.registration.provider_id,
        "restored cache entry from persistent snapshot"
    );

    Ok(())
}
182
/// Return the JWKS for this provider, fetching or refreshing it as needed.
///
/// `kid` is only recorded as a tracing field here; it does not influence the lookup.
///
/// Behaviour by cache condition:
/// - no payload: perform a forced fetch and return its result;
/// - payload not expired: return it, and start a background refresh once `next_refresh_at`
///   has been reached;
/// - payload expired but still servable as stale: try a non-forced refresh, falling back to
///   the stale payload if the refresh fails and the stale window is still open;
/// - otherwise: perform a forced refresh.
///
/// Each return path records a hit (fresh or stale) or a miss.
///
/// # Errors
///
/// Propagates the refresh error when no payload can be served.
#[tracing::instrument(
    skip(self, kid),
    fields(
        tenant = %self.registration.tenant_id,
        provider = %self.registration.provider_id,
        kid = kid.unwrap_or_default()
    )
)]
pub async fn resolve(&self, kid: Option<&str>) -> Result<Arc<JwkSet>> {
    loop {
        // The read guard is dropped at the end of this block, before any refresh is attempted.
        let snapshot = { self.entry.read().await.snapshot() };
        let now = Instant::now();

        match snapshot {
            None => {
                tracing::debug!("cache empty; performing initial fetch");

                match self.refresh_blocking(true).await? {
                    RefreshOutcome::Updated { jwks, from_cache } => {
                        if from_cache {
                            self.observe_hit(false);
                        } else {
                            self.observe_miss();
                        }

                        return Ok(jwks);
                    },
                    RefreshOutcome::Stale(jwks) => {
                        self.observe_hit(true);

                        return Ok(jwks);
                    },
                }
            },
            Some(payload) => {
                // Fresh payload: serve it and, if due, refresh in the background.
                if !payload.is_expired(now) {
                    let jwks = payload.jwks.clone();

                    self.observe_hit(false);

                    if now >= payload.next_refresh_at {
                        self.schedule_background_refresh(now).await;
                    }

                    return Ok(jwks);
                }

                if payload.can_serve_stale(now) {
                    // Expired but inside the stale window: a failed refresh is tolerated.
                    match self.refresh_blocking(false).await {
                        Ok(RefreshOutcome::Updated { jwks, from_cache }) => {
                            if from_cache {
                                self.observe_hit(false);
                            } else {
                                self.observe_miss();
                            }

                            return Ok(jwks);
                        },
                        Ok(RefreshOutcome::Stale(jwks)) => {
                            self.observe_hit(true);

                            return Ok(jwks);
                        },
                        Err(err) =>
                            // Re-check against the current time: the refresh attempt may have
                            // outlasted the stale window.
                            if payload.can_serve_stale(Instant::now()) {
                                tracing::warn!(error = %err, "refresh failed, serving stale data");

                                self.observe_hit(true);

                                return Ok(payload.jwks.clone());
                            } else {
                                return Err(err);
                            },
                    }
                } else if let RefreshOutcome::Updated { jwks, from_cache } =
                    self.refresh_blocking(true).await?
                {
                    if from_cache {
                        self.observe_hit(false);
                    } else {
                        self.observe_miss();
                    }
                    return Ok(jwks);
                }
                // Any other outcome falls through and the loop re-reads the entry.
            },
        }
    }
}
274
275 #[tracing::instrument(
277 skip(self),
278 fields(tenant = %self.registration.tenant_id, provider = %self.registration.provider_id)
279 )]
280 pub async fn trigger_refresh(&self) -> Result<()> {
281 let now = Instant::now();
282 let action = {
283 let mut entry = self.entry.write().await;
284
285 match entry.state() {
286 CacheState::Empty => {
287 entry.begin_load();
288 RefreshTrigger::Blocking
289 },
290 CacheState::Loading | CacheState::Refreshing(_) => RefreshTrigger::None,
291 CacheState::Ready(_) =>
292 if entry.begin_refresh(now) {
293 RefreshTrigger::Background
294 } else {
295 RefreshTrigger::None
296 },
297 }
298 };
299
300 match action {
301 RefreshTrigger::Background => {
302 let manager = self.clone();
303
304 tokio::spawn(async move {
305 if let Err(err) = manager.refresh_blocking(true).await {
306 tracing::warn!(error = %err, "manual refresh failed");
307 }
308 });
309 },
310 RefreshTrigger::Blocking => {
311 self.refresh_blocking(true).await?;
312 },
313 RefreshTrigger::None => {},
314 }
315
316 Ok(())
317 }
318
319 #[tracing::instrument(
320 skip(self),
321 fields(tenant = %self.registration.tenant_id, provider = %self.registration.provider_id)
322 )]
323 async fn schedule_background_refresh(&self, now: Instant) {
324 let should_spawn = {
325 let mut entry = self.entry.write().await;
326
327 entry.begin_refresh(now)
328 };
329 if should_spawn {
330 let manager = self.clone();
331
332 tokio::spawn(async move {
333 if let Err(err) = manager.refresh_blocking(true).await {
334 tracing::debug!(error = %err, "background refresh failed");
335 }
336 });
337 }
338 }
339
340 #[tracing::instrument(
341 skip(self, force_revalidation),
342 fields(tenant = %self.registration.tenant_id, provider = %self.registration.provider_id, force_revalidation)
343 )]
344 async fn refresh_blocking(&self, force_revalidation: bool) -> Result<RefreshOutcome> {
345 let _guard = self.single_flight.lock().await;
346 let now = Instant::now();
347 let (existing, mode) = {
348 let mut entry = self.entry.write().await;
349 let snapshot = entry.snapshot();
350 let mode = if snapshot.is_some() {
351 entry.begin_refresh(now);
352
353 FetchMode::Refresh
354 } else {
355 entry.begin_load();
356
357 FetchMode::Initial
358 };
359
360 (snapshot, mode)
361 };
362
363 match self.prepare_request(existing.as_ref(), force_revalidation)? {
364 PreparedRequest::UseCached { jwks } =>
365 Ok(RefreshOutcome::Updated { jwks, from_cache: true }),
366 PreparedRequest::Send(request) =>
367 self.perform_fetch_with_retry(*request, existing, mode, force_revalidation).await,
368 }
369 }
370
371 fn prepare_request(
372 &self,
373 existing: Option<&CachePayload>,
374 force_revalidation: bool,
375 ) -> Result<PreparedRequest> {
376 let mut request = base_request(&self.registration)?;
377
378 if let Some(payload) = existing {
379 let mut send_conditional = force_revalidation;
380
381 match payload.policy.before_request(&request, SystemTime::now()) {
382 BeforeRequest::Fresh(_) if !force_revalidation => {
383 return Ok(PreparedRequest::UseCached { jwks: payload.jwks.clone() });
384 },
385 BeforeRequest::Stale { request: parts, matches } if matches => {
386 request = Request::from_parts(parts, ());
387 send_conditional = true;
388 },
389 _ => {},
390 }
391
392 if send_conditional
393 && let Some(etag) = &payload.etag
394 && let Ok(value) = HeaderValue::from_str(etag)
395 {
396 request.headers_mut().insert(IF_NONE_MATCH, value);
397 }
398 }
399
400 Ok(PreparedRequest::Send(Box::new(request)))
401 }
402
403 async fn perform_fetch_with_retry(
404 &self,
405 request: Request<()>,
406 existing: Option<CachePayload>,
407 mode: FetchMode,
408 force_revalidation: bool,
409 ) -> Result<RefreshOutcome> {
410 let mut executor = RetryExecutor::new(&self.registration.retry_policy);
411 let mut last_error: Option<Error> = None;
412 let mut last_backoff: Option<Duration> = None;
413 let request = request;
414
415 while let AttemptBudget::Granted { timeout } = executor.attempt_budget() {
416 let attempt_started = Instant::now();
417 let fetch = fetch_jwks(&self.client, &self.registration, &request, timeout).await;
418
419 match fetch {
420 Ok(fetch) => {
421 let now = Instant::now();
422 let payload = match (&fetch.jwks, existing.as_ref()) {
423 (Some(fresh_jwks), _) => {
424 let freshness =
425 evaluate_freshness(&self.registration, &fetch.exchange)?;
426
427 self.build_payload(
428 fresh_jwks.clone(),
429 freshness,
430 fetch.etag.clone(),
431 fetch.last_modified,
432 now,
433 Utc::now(),
434 )
435 },
436 (None, Some(previous)) => {
437 let revalidation = evaluate_revalidation(
438 &self.registration,
439 &previous.policy,
440 &fetch.exchange.request,
441 &fetch.exchange.response,
442 )?;
443 let updated_etag = extract_header(&revalidation.response, &ETAG)
444 .or_else(|| previous.etag.clone());
445
446 self.build_payload(
447 previous.jwks.clone(),
448 revalidation.freshness,
449 updated_etag,
450 extract_last_modified(&revalidation.response)
451 .or(previous.last_modified),
452 now,
453 Utc::now(),
454 )
455 },
456 (None, None) => {
457 return Err(Error::Cache(
458 "Received 304 status without a cached payload.".into(),
459 ));
460 },
461 };
462
463 let jwks = payload.jwks.clone();
464
465 self.commit_success(mode, payload).await;
466 self.observe_refresh_success(attempt_started.elapsed());
467
468 return Ok(RefreshOutcome::Updated { jwks, from_cache: false });
469 },
470 Err(err) => {
471 last_error = Some(err);
472
473 if !executor.can_retry() {
474 break;
475 }
476
477 if let Some(delay) = executor.next_backoff() {
478 last_backoff = Some(delay);
479
480 if !delay.is_zero() {
481 time::sleep(delay).await;
482 }
483 continue;
484 }
485
486 break;
487 },
488 }
489 }
490
491 let now = Instant::now();
492
493 match mode {
494 FetchMode::Initial => {
495 let mut entry = self.entry.write().await;
496
497 entry.invalidate();
498 },
499 FetchMode::Refresh => {
500 let mut entry = self.entry.write().await;
501
502 entry.refresh_failure(now, last_backoff);
503 },
504 }
505
506 self.observe_refresh_error();
507
508 if !force_revalidation
509 && let Some(payload) = existing
510 && payload.can_serve_stale(now)
511 {
512 return Ok(RefreshOutcome::Stale(payload.jwks));
513 }
514
515 Err(last_error.unwrap_or_else(|| Error::Cache("Refresh attempts exhausted.".into())))
516 }
517
518 async fn commit_success(&self, mode: FetchMode, payload: CachePayload) {
519 let mut entry = self.entry.write().await;
520
521 match mode {
522 FetchMode::Initial => entry.load_success(payload),
523 FetchMode::Refresh => entry.refresh_success(payload),
524 }
525 }
526
527 fn build_payload(
528 &self,
529 jwks: Arc<JwkSet>,
530 freshness: Freshness,
531 etag: Option<String>,
532 last_modified: Option<DateTime<Utc>>,
533 now: Instant,
534 refreshed_at: DateTime<Utc>,
535 ) -> CachePayload {
536 let ttl = freshness.ttl;
537 let expires_at = now + ttl;
538 let mut refresh_at = if self.registration.refresh_early >= ttl {
539 now
540 } else {
541 expires_at - self.registration.refresh_early
542 };
543
544 if !self.registration.prefetch_jitter.is_zero() {
545 let jitter = random_jitter(self.registration.prefetch_jitter);
546
547 if refresh_at > now + jitter {
548 refresh_at -= jitter;
549 }
550 }
551
552 let stale_deadline = if self.registration.stale_while_error.is_zero() {
553 None
554 } else {
555 Some(expires_at + self.registration.stale_while_error)
556 };
557
558 CachePayload {
559 jwks,
560 policy: freshness.policy,
561 etag,
562 last_modified,
563 last_refresh_at: refreshed_at,
564 expires_at,
565 next_refresh_at: refresh_at,
566 stale_deadline,
567 retry_backoff: None,
568 error_count: 0,
569 }
570 }
571
572 fn observe_hit(&self, stale: bool) {
573 let tenant = &self.registration.tenant_id;
574 let provider = &self.registration.provider_id;
575
576 metrics::record_resolve_hit(tenant, provider, stale);
577
578 self.metrics.record_hit(stale);
579 }
580
581 fn observe_miss(&self) {
582 let tenant = &self.registration.tenant_id;
583 let provider = &self.registration.provider_id;
584
585 metrics::record_resolve_miss(tenant, provider);
586
587 self.metrics.record_miss();
588 }
589
590 fn observe_refresh_success(&self, duration: Duration) {
591 let tenant = &self.registration.tenant_id;
592 let provider = &self.registration.provider_id;
593
594 metrics::record_refresh_success(tenant, provider, duration);
595
596 self.metrics.record_refresh_success(duration);
597 }
598
599 fn observe_refresh_error(&self) {
600 let tenant = &self.registration.tenant_id;
601 let provider = &self.registration.provider_id;
602
603 metrics::record_refresh_error(tenant, provider);
604
605 self.metrics.record_refresh_error();
606 }
607}
608
/// Point-in-time copy of the cache state, stamped with both a monotonic and a wall-clock time
/// so that `Instant`s stored in the state can be translated to UTC timestamps.
#[derive(Clone, Debug)]
pub struct CacheSnapshot {
    /// Monotonic instant at which the snapshot was taken.
    pub captured_at: Instant,
    /// UTC wall-clock time at which the snapshot was taken.
    pub captured_at_wallclock: DateTime<Utc>,
    /// Clone of the cache entry's state at capture time.
    pub state: CacheState,
}
619impl CacheSnapshot {
620 pub fn to_datetime(&self, instant: Instant) -> Option<DateTime<Utc>> {
622 if let Some(delta) = instant.checked_duration_since(self.captured_at) {
623 let chrono = TimeDelta::from_std(delta).ok()?;
624
625 self.captured_at_wallclock.checked_add_signed(chrono)
626 } else if let Some(delta) = self.captured_at.checked_duration_since(instant) {
627 let chrono = TimeDelta::from_std(delta).ok()?;
628
629 self.captured_at_wallclock.checked_sub_signed(chrono)
630 } else {
631 None
632 }
633 }
634}
635
/// How a refresh cycle started; selects which entry transition is applied on success or
/// failure.
#[derive(Clone, Copy, Debug)]
enum FetchMode {
    /// No cached payload existed when the cycle began.
    Initial,
    /// A cached payload existed when the cycle began.
    Refresh,
}
641
/// Result of a refresh cycle that produced a JWKS to serve.
#[derive(Debug)]
enum RefreshOutcome {
    /// A usable JWKS is available. `from_cache` is `true` when the cached copy was reused
    /// without sending a request and `false` when it came from a completed fetch.
    Updated { jwks: Arc<JwkSet>, from_cache: bool },
    /// The refresh failed, but the previous payload may still be served stale.
    Stale(Arc<JwkSet>),
}
647
/// Action chosen by `trigger_refresh` after inspecting the entry state.
#[derive(Clone, Copy, Debug)]
enum RefreshTrigger {
    /// Spawn a detached refresh task.
    Background,
    /// Refresh inline and surface any error to the caller.
    Blocking,
    /// Do nothing.
    None,
}
654
/// Decision produced by `prepare_request`.
#[derive(Debug)]
enum PreparedRequest {
    /// The cached JWKS can be returned without sending a request.
    UseCached { jwks: Arc<JwkSet> },
    /// This request must be sent (boxed to keep the enum small).
    Send(Box<Request<()>>),
}
660
661fn random_jitter(max: Duration) -> Duration {
662 if max.is_zero() {
663 return Duration::ZERO;
664 }
665
666 let mut rng = rand::rng();
667 let jitter = rng.random_range(0.0..=max.as_secs_f64());
668
669 Duration::from_secs_f64(jitter)
670}
671
672fn extract_header(response: &Response<()>, name: &HeaderName) -> Option<String> {
673 response.headers().get(name).and_then(|value| value.to_str().ok()).map(|s| s.to_string())
674}
675
676fn extract_last_modified(response: &Response<()>) -> Option<DateTime<Utc>> {
677 response
678 .headers()
679 .get(LAST_MODIFIED)
680 .and_then(|value| value.to_str().ok())
681 .and_then(|raw| httpdate::parse_http_date(raw).ok())
682 .map(<DateTime<Utc>>::from)
683}