Skip to main content

jwks_cache/
registry.rs

1//! Tenant/provider registry and configuration validation.
2//!
3//! The registry owns tenant registrations, cache metadata, and optional persistence wiring.
4
5// std
6use std::{cell::RefCell, collections::HashMap, mem};
7// crates.io
8use jsonwebtoken::jwk::JwkSet;
9use rand::{Rng, SeedableRng, rngs::SmallRng};
10#[cfg(feature = "redis")] use redis::AsyncCommands;
11use serde::{Deserialize, Serialize};
12use tokio::sync::RwLock;
13use url::Url;
14// self
15#[cfg(feature = "metrics")] use crate::metrics::{ProviderMetrics, ProviderMetricsSnapshot};
16use crate::{
17	_prelude::*,
18	cache::{
19		manager::{CacheManager, CacheSnapshot},
20		state::CacheState,
21	},
22	security::{self, SpkiFingerprint},
23};
24
// Per-thread small PRNG used by `random_within` to draw backoff jitter.
// Each thread gets its own `SmallRng`, seeded from `rand::rng()`; the `RefCell` provides the
// mutable access the generator needs from behind the thread-local.
thread_local! {
	static SMALL_RNG: RefCell<SmallRng> = RefCell::new(SmallRng::from_rng(&mut rand::rng()));
}
28
/// Default refresh lead time before TTL expiry (30 seconds).
pub const DEFAULT_REFRESH_EARLY: Duration = Duration::from_secs(30);
/// Default stale-while-error window (60 seconds).
pub const DEFAULT_STALE_WHILE_ERROR: Duration = Duration::from_secs(60);
/// Minimum accepted TTL for upstream responses (30 seconds); a registration whose `min_ttl`
/// is below this floor fails validation.
pub const MIN_TTL_FLOOR: Duration = Duration::from_secs(30);
/// Default maximum TTL clamp (24 hours).
pub const DEFAULT_MAX_TTL: Duration = Duration::from_secs(60 * 60 * 24);
/// Default size guard (1 MiB).
pub const DEFAULT_MAX_RESPONSE_BYTES: u64 = 1_048_576;
/// Default prefetch jitter (5 seconds).
pub const DEFAULT_PREFETCH_JITTER: Duration = Duration::from_secs(5);
/// Maximum redirect depth; a registration whose `max_redirects` exceeds this fails validation.
pub const MAX_REDIRECTS: u8 = 10;
43
/// Supported jitter strategies for retry policies.
///
/// Serialized in `snake_case`; [`JitterStrategy::Full`] is the default variant.
#[derive(Clone, Debug, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum JitterStrategy {
	/// No jitter; deterministic backoff schedule.
	None,
	/// Full jitter; randomize the delay below the current backoff.
	///
	/// NOTE(review): `RetryPolicy::apply_jitter` in this file draws from 80%–100% of the bounded
	/// backoff (never below `initial_backoff`) rather than from zero — confirm which range is
	/// intended.
	#[default]
	Full,
	/// Decorrelated jitter per AWS architecture guidance.
	Decorrelated,
}
56
/// Public representation of provider lifecycle state.
///
/// Serialized with `PascalCase` variant names (for example `"Refreshing"`).
#[derive(Clone, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub enum ProviderState {
	/// No JWKS payload has been cached yet.
	Empty,
	/// Initial fetch operation is currently running.
	Loading,
	/// Fresh JWKS payload is available for requests.
	Ready,
	/// Cache is serving while a refresh is in progress.
	Refreshing,
}
70
/// Retry configuration for HTTP fetch operations.
///
/// Use `RetryPolicy::validate` to check the invariants between the fields before relying on
/// a deserialized or hand-built policy.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RetryPolicy {
	/// Maximum number of retry attempts to perform after the initial request.
	pub max_retries: u32,
	/// Timeout applied to each individual HTTP attempt (validated to be at least 100 ms).
	pub attempt_timeout: Duration,
	/// Initial delay before retrying after a failure (validated to be non-zero).
	pub initial_backoff: Duration,
	/// Upper bound applied to exponential backoff growth (validated to be >= `initial_backoff`).
	pub max_backoff: Duration,
	/// Overall deadline that bounds the entire retry sequence (validated to be >=
	/// `attempt_timeout`).
	pub deadline: Duration,
	/// Strategy used to randomize the computed backoff.
	#[serde(default)]
	pub jitter: JitterStrategy,
}
88impl RetryPolicy {
89	/// Validate invariants for retry configuration.
90	pub fn validate(&self) -> Result<()> {
91		if self.attempt_timeout < Duration::from_millis(100) {
92			return Err(Error::Validation {
93				field: "retry_policy.attempt_timeout",
94				reason: "Must be at least 100 ms.".into(),
95			});
96		}
97		if self.initial_backoff.is_zero() {
98			return Err(Error::Validation {
99				field: "retry_policy.initial_backoff",
100				reason: "Must be greater than zero.".into(),
101			});
102		}
103		if self.max_backoff < self.initial_backoff {
104			return Err(Error::Validation {
105				field: "retry_policy.max_backoff",
106				reason: "Must be greater than or equal to initial_backoff.".into(),
107			});
108		}
109		if self.deadline < self.attempt_timeout {
110			return Err(Error::Validation {
111				field: "retry_policy.deadline",
112				reason: "Must be greater than or equal to attempt_timeout.".into(),
113			});
114		}
115		Ok(())
116	}
117
118	/// Compute backoff for a retry attempt using the selected jitter strategy.
119	pub fn compute_backoff(&self, attempt: u32) -> Duration {
120		self.default_backoff(attempt)
121	}
122
123	/// Default exponential backoff with jitter following the AWS architecture guidance.
124	pub fn default_backoff(&self, attempt: u32) -> Duration {
125		let exponent = attempt.min(32);
126		let base = self.initial_backoff.mul_f64(2f64.powi(exponent as i32));
127		let bounded = base.min(self.max_backoff).max(self.initial_backoff);
128
129		self.apply_jitter(bounded, attempt)
130	}
131
132	fn apply_jitter(&self, bounded: Duration, attempt: u32) -> Duration {
133		match self.jitter {
134			JitterStrategy::None => bounded,
135			JitterStrategy::Full => {
136				let lower = bounded.mul_f64(0.8).max(self.initial_backoff);
137				let upper = bounded.min(self.max_backoff);
138
139				random_within(lower, upper)
140			},
141			JitterStrategy::Decorrelated => {
142				let prev = if attempt == 0 { self.initial_backoff } else { bounded };
143				let ceiling = self.max_backoff.min(prev.mul_f64(3.0));
144
145				random_within(self.initial_backoff, ceiling.max(self.initial_backoff))
146			},
147		}
148	}
149}
impl Default for RetryPolicy {
	/// Default policy: 2 retries, 3 s per attempt, backoff from 250 ms up to 2 s, an 8 s overall
	/// deadline, and full jitter.
	fn default() -> Self {
		Self {
			max_retries: 2,
			attempt_timeout: Duration::from_secs(3),
			initial_backoff: Duration::from_millis(250),
			max_backoff: Duration::from_secs(2),
			deadline: Duration::from_secs(8),
			jitter: JitterStrategy::Full,
		}
	}
}
162
/// Registration describing how to fetch and maintain JWKS for a provider.
///
/// Most fields carry serde defaults, so a serialized registration only has to supply
/// `tenant_id`, `provider_id`, and `jwks_url`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct IdentityProviderRegistration {
	/// Tenant identifier used for metrics, caching, and persistence scope.
	pub tenant_id: String,
	/// Provider identifier unique within the tenant.
	pub provider_id: String,
	/// URL of the JWKS endpoint to fetch signing keys from.
	pub jwks_url: Url,
	/// Whether HTTPS is required for JWKS retrieval (defaults to `true`).
	#[serde(default = "default_true")]
	pub require_https: bool,
	/// Optional allowlist of domains permitted for redirects.
	///
	/// Validation also checks the host of `jwks_url` against this list.
	#[serde(default, deserialize_with = "crate::security::deserialize_allowed_domains")]
	pub allowed_domains: Vec<String>,
	/// Lead time before expiry to trigger proactive refresh.
	#[serde(default = "default_refresh_early")]
	pub refresh_early: Duration,
	/// Duration to continue serving stale data when refresh fails.
	#[serde(default = "default_stale_while_error")]
	pub stale_while_error: Duration,
	/// Minimum TTL applied to upstream responses.
	#[serde(default = "default_min_ttl")]
	pub min_ttl: Duration,
	/// Maximum TTL applied to upstream responses.
	#[serde(default = "default_max_ttl")]
	pub max_ttl: Duration,
	/// Maximum size allowed for JWKS payloads in bytes.
	#[serde(default = "default_max_response_bytes")]
	pub max_response_bytes: u64,
	/// TTL applied when persisting negative cache outcomes.
	///
	/// Validation accepts zero or any value of at least one second.
	#[serde(default)]
	pub negative_cache_ttl: Duration,
	/// Maximum number of redirects to follow during fetch.
	#[serde(default = "default_max_redirects")]
	pub max_redirects: u8,
	/// Optional SPKI fingerprints used for TLS pinning.
	#[serde(default)]
	pub pinned_spki: Vec<SpkiFingerprint>,
	/// Random jitter applied when scheduling proactive refreshes.
	#[serde(default = "default_prefetch_jitter")]
	pub prefetch_jitter: Duration,
	/// Retry policy configuration for JWKS fetch attempts.
	#[serde(default)]
	pub retry_policy: RetryPolicy,
}
209impl IdentityProviderRegistration {
210	/// Construct a new registration with default cache settings.
211	pub fn new(
212		tenant_id: impl Into<String>,
213		provider_id: impl Into<String>,
214		jwks_url: impl AsRef<str>,
215	) -> Result<Self> {
216		let jwks_url = Url::parse(jwks_url.as_ref())?;
217
218		Ok(Self {
219			tenant_id: tenant_id.into(),
220			provider_id: provider_id.into(),
221			jwks_url,
222			require_https: true,
223			allowed_domains: Vec::new(),
224			refresh_early: DEFAULT_REFRESH_EARLY,
225			stale_while_error: DEFAULT_STALE_WHILE_ERROR,
226			min_ttl: MIN_TTL_FLOOR,
227			max_ttl: DEFAULT_MAX_TTL,
228			max_response_bytes: DEFAULT_MAX_RESPONSE_BYTES,
229			negative_cache_ttl: Duration::ZERO,
230			max_redirects: 3,
231			pinned_spki: Vec::new(),
232			prefetch_jitter: DEFAULT_PREFETCH_JITTER,
233			retry_policy: RetryPolicy::default(),
234		})
235	}
236
237	/// Canonicalise the domain allowlist in-place.
238	pub fn normalize_allowed_domains(&mut self) {
239		let domains = mem::take(&mut self.allowed_domains);
240
241		self.allowed_domains = security::normalize_allowlist(domains);
242	}
243
244	/// Set HTTPS requirement to the desired value.
245	pub fn with_require_https(mut self, require_https: bool) -> Self {
246		self.require_https = require_https;
247
248		self
249	}
250
251	/// Validate the registration against the documented constraints.
252	pub fn validate(&self) -> Result<()> {
253		validate_tenant_id(&self.tenant_id)?;
254		validate_provider_id(&self.provider_id)?;
255
256		if self.require_https {
257			security::enforce_https(&self.jwks_url)?;
258		}
259
260		if let Some(host) = self.jwks_url.host_str() {
261			if !security::host_is_allowed(host, &self.allowed_domains) {
262				return Err(Error::Validation {
263					field: "jwks_url",
264					reason: "Host is not within the allowed_domains allowlist.".into(),
265				});
266			}
267		} else {
268			return Err(Error::Validation {
269				field: "jwks_url",
270				reason: "Must include a host component.".into(),
271			});
272		}
273
274		if self.refresh_early < Duration::from_secs(1) {
275			return Err(Error::Validation {
276				field: "refresh_early",
277				reason: "Must be at least 1 second.".into(),
278			});
279		}
280		if self.min_ttl < MIN_TTL_FLOOR {
281			return Err(Error::Validation {
282				field: "min_ttl",
283				reason: format!("Must be at least {:?}.", MIN_TTL_FLOOR),
284			});
285		}
286		if self.max_ttl < self.min_ttl {
287			return Err(Error::Validation {
288				field: "max_ttl",
289				reason: "Must be greater than or equal to min_ttl.".into(),
290			});
291		}
292		if self.refresh_early >= self.max_ttl {
293			return Err(Error::Validation {
294				field: "refresh_early",
295				reason: "Must be less than max_ttl.".into(),
296			});
297		}
298		if self.max_response_bytes == 0 {
299			return Err(Error::Validation {
300				field: "max_response_bytes",
301				reason: "Must be greater than zero.".into(),
302			});
303		}
304		if self.max_redirects > MAX_REDIRECTS {
305			return Err(Error::Validation {
306				field: "max_redirects",
307				reason: format!("Must be less than or equal to {}.", MAX_REDIRECTS),
308			});
309		}
310		if !self.negative_cache_ttl.is_zero() && self.negative_cache_ttl < Duration::from_secs(1) {
311			return Err(Error::Validation {
312				field: "negative_cache_ttl",
313				reason: "Must be zero or at least one second.".into(),
314			});
315		}
316
317		self.retry_policy.validate()?;
318
319		for domain in &self.allowed_domains {
320			if let Some(canonical) = security::canonicalize_dns_name(domain) {
321				if canonical != *domain {
322					return Err(Error::Validation {
323						field: "allowed_domains",
324						reason: "Entries must be canonical hostnames (lowercase, no trailing dot)."
325							.into(),
326					});
327				}
328			} else {
329				return Err(Error::Validation {
330					field: "allowed_domains",
331					reason: "Entries must be non-empty hostnames.".into(),
332				});
333			}
334		}
335
336		Ok(())
337	}
338}
339
/// Snapshot of cache payload persisted to external storage.
///
/// Check a loaded snapshot against its registration with `PersistentSnapshot::validate`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PersistentSnapshot {
	/// Tenant identifier associated with the snapshot.
	pub tenant_id: String,
	/// Provider identifier within the tenant scope.
	pub provider_id: String,
	/// Serialized JWKS payload captured from the cache.
	pub jwks_json: String,
	/// Entity tag returned by the JWKS endpoint, if present.
	pub etag: Option<String>,
	/// Last-Modified timestamp advertised by the JWKS endpoint.
	#[serde(default)]
	pub last_modified: Option<DateTime<Utc>>,
	/// UTC timestamp when the cached payload expires.
	pub expires_at: DateTime<Utc>,
	/// UTC timestamp when the snapshot was persisted.
	pub persisted_at: DateTime<Utc>,
}
359impl PersistentSnapshot {
360	/// Validate snapshot metadata aligns with registration expectations.
361	pub fn validate(&self, registration: &IdentityProviderRegistration) -> Result<()> {
362		if self.jwks_json.len() as u64 > registration.max_response_bytes {
363			return Err(Error::Validation {
364				field: "jwks_json",
365				reason: format!(
366					"Snapshot exceeds max_response_bytes ({} bytes).",
367					registration.max_response_bytes
368				),
369			});
370		}
371
372		if self.tenant_id != registration.tenant_id {
373			return Err(Error::Validation {
374				field: "tenant_id",
375				reason: "Snapshot tenant does not match registration.".into(),
376			});
377		}
378		if self.provider_id != registration.provider_id {
379			return Err(Error::Validation {
380				field: "provider_id",
381				reason: "Snapshot provider does not match registration.".into(),
382			});
383		}
384
385		if let Some(etag) = &self.etag
386			&& !etag.is_ascii()
387		{
388			return Err(Error::Validation { field: "etag", reason: "ETag must be ASCII.".into() });
389		}
390
391		if self.expires_at < self.persisted_at {
392			return Err(Error::Validation {
393				field: "expires_at",
394				reason: "Cannot be earlier than persisted_at.".into(),
395			});
396		}
397
398		Ok(())
399	}
400}
401
/// Internal key mapping tenants and providers.
///
/// Hashable and comparable so it can index the registry's provider map.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct TenantProviderKey {
	/// Tenant half of the key.
	pub tenant_id: String,
	/// Provider half of the key.
	pub provider_id: String,
}
408impl TenantProviderKey {
409	pub fn new(tenant_id: impl Into<String>, provider_id: impl Into<String>) -> Self {
410		Self { tenant_id: tenant_id.into(), provider_id: provider_id.into() }
411	}
412}
413
/// Builder for [`Registry`] enabling multi-tenant configuration.
///
/// Starts from the default registry configuration; each setter consumes and returns the
/// builder so calls can be chained before `build`.
#[derive(Debug, Default)]
pub struct RegistryBuilder {
	// Configuration accumulated by the setters and handed to the registry on `build`.
	config: RegistryConfig,
}
419impl RegistryBuilder {
420	/// Create a builder with default configuration.
421	pub fn new() -> Self {
422		Self::default()
423	}
424
425	/// Enforce HTTPS for registrations (enabled by default).
426	pub fn require_https(mut self, require_https: bool) -> Self {
427		self.config.require_https = require_https;
428
429		self
430	}
431
432	/// Override the default refresh-early offset applied to registrations.
433	pub fn default_refresh_early(mut self, value: Duration) -> Self {
434		self.config.default_refresh_early = value;
435
436		self
437	}
438
439	/// Override the default stale-while-error window applied to registrations.
440	pub fn default_stale_while_error(mut self, value: Duration) -> Self {
441		self.config.default_stale_while_error = value;
442
443		self
444	}
445
446	/// Add an entry to the global domain allowlist.
447	pub fn add_allowed_domain(mut self, domain: impl Into<String>) -> Self {
448		let raw = domain.into();
449
450		if let Some(domain) = security::canonicalize_dns_name(&raw)
451			&& !self.config.allowed_domains.contains(&domain)
452		{
453			self.config.allowed_domains.push(domain);
454		}
455
456		self
457	}
458
459	/// Replace the global domain allowlist.
460	pub fn allowed_domains<I, S>(mut self, domains: I) -> Self
461	where
462		I: IntoIterator<Item = S>,
463		S: Into<String>,
464	{
465		self.config.allowed_domains.clear();
466
467		for domain in domains {
468			self = self.add_allowed_domain(domain);
469		}
470
471		self
472	}
473
474	#[cfg(feature = "redis")]
475	/// Configure Redis-backed persistence for snapshots.
476	pub fn with_redis_client(mut self, client: redis::Client) -> Self {
477		self.config.persistence = Some(RedisPersistence::new(client));
478
479		self
480	}
481
482	#[cfg(feature = "redis")]
483	/// Adjust the Redis key namespace (defaults to `jwks-cache`).
484	pub fn redis_namespace(mut self, namespace: impl Into<String>) -> Self {
485		if let Some(persistence) = self.config.persistence.as_mut() {
486			persistence.namespace = Arc::from(namespace.into());
487		} else {
488			panic!("Redis client must be configured before setting namespace.");
489		}
490
491		self
492	}
493
494	/// Finalise the configuration and construct a [`Registry`].
495	pub fn build(self) -> Registry {
496		let mut config = self.config;
497
498		config.allowed_domains = security::normalize_allowlist(config.allowed_domains);
499
500		Registry {
501			inner: Arc::new(RwLock::new(RegistryState { providers: HashMap::new() })),
502			config: Arc::new(config),
503		}
504	}
505}
506
/// Registry state container.
///
/// Clones share the same provider map and configuration, since both live behind `Arc`s.
#[derive(Clone, Debug)]
pub struct Registry {
	// Provider handles keyed by tenant/provider, guarded by an async read-write lock.
	inner: Arc<RwLock<RegistryState>>,
	// Configuration captured when the builder was finalised.
	config: Arc<RegistryConfig>,
}
513impl Registry {
514	/// Create a new registry instance with defaults.
515	pub fn new() -> Self {
516		Self::builder().build()
517	}
518
519	/// Create a [`RegistryBuilder`] for advanced configuration.
520	pub fn builder() -> RegistryBuilder {
521		RegistryBuilder::new()
522	}
523
524	/// Register or update a provider configuration.
525	pub async fn register(&self, mut registration: IdentityProviderRegistration) -> Result<()> {
526		if self.config.require_https {
527			if !registration.require_https {
528				return Err(Error::Security(
529					"Registry requires HTTPS for all provider registrations.".into(),
530				));
531			}
532		} else {
533			registration.require_https = false;
534		}
535
536		registration.normalize_allowed_domains();
537
538		if registration.refresh_early == DEFAULT_REFRESH_EARLY {
539			registration.refresh_early = self.config.default_refresh_early;
540		}
541		if registration.stale_while_error == DEFAULT_STALE_WHILE_ERROR {
542			registration.stale_while_error = self.config.default_stale_while_error;
543		}
544		if registration.allowed_domains.is_empty() && !self.config.allowed_domains.is_empty() {
545			registration.allowed_domains = self.config.allowed_domains.clone();
546		}
547
548		if let Some(host) = registration.jwks_url.host_str()
549			&& !security::host_is_allowed(host, &self.config.allowed_domains)
550		{
551			return Err(Error::Security(format!(
552				"Host '{host}' is not in the registry allowlist."
553			)));
554		}
555
556		let key = TenantProviderKey::new(&registration.tenant_id, &registration.provider_id);
557		let manager = CacheManager::new(registration.clone())?;
558		#[cfg(feature = "metrics")]
559		let metrics = manager.metrics();
560		let handle = Arc::new(ProviderHandle {
561			registration: Arc::new(registration),
562			manager,
563			#[cfg(feature = "metrics")]
564			metrics,
565		});
566
567		{
568			let mut state = self.inner.write().await;
569
570			state.providers.insert(key.clone(), handle.clone());
571		}
572
573		#[cfg(feature = "redis")]
574		if let Some(persistence) = &self.config.persistence
575			&& let Some(snapshot) = persistence.load(&key.tenant_id, &key.provider_id).await?
576		{
577			handle.manager.restore_snapshot(snapshot).await?;
578		}
579
580		Ok(())
581	}
582
583	/// Resolve JWKS for a tenant/provider pair.
584	pub async fn resolve(
585		&self,
586		tenant_id: &str,
587		provider_id: &str,
588		kid: Option<&str>,
589	) -> Result<Arc<JwkSet>> {
590		let key = TenantProviderKey::new(tenant_id, provider_id);
591		let handle = {
592			let state = self.inner.read().await;
593
594			state.providers.get(&key).cloned()
595		};
596		let handle = handle.ok_or_else(|| Error::NotRegistered {
597			tenant: tenant_id.to_string(),
598			provider: provider_id.to_string(),
599		})?;
600
601		handle.manager.resolve(kid).await
602	}
603
604	/// Trigger a manual refresh for a registered provider.
605	pub async fn refresh(&self, tenant_id: &str, provider_id: &str) -> Result<()> {
606		let key = TenantProviderKey::new(tenant_id, provider_id);
607		let handle = {
608			let state = self.inner.read().await;
609			state.providers.get(&key).cloned()
610		};
611		let handle = handle.ok_or_else(|| Error::NotRegistered {
612			tenant: tenant_id.to_string(),
613			provider: provider_id.to_string(),
614		})?;
615
616		handle.manager.trigger_refresh().await
617	}
618
619	/// Remove a provider registration if present.
620	pub async fn unregister(&self, tenant_id: &str, provider_id: &str) -> Result<bool> {
621		let key = TenantProviderKey::new(tenant_id, provider_id);
622		let mut state = self.inner.write().await;
623
624		Ok(state.providers.remove(&key).is_some())
625	}
626
627	/// Fetch status information for a specific provider.
628	pub async fn provider_status(
629		&self,
630		tenant_id: &str,
631		provider_id: &str,
632	) -> Result<ProviderStatus> {
633		let key = TenantProviderKey::new(tenant_id, provider_id);
634		let handle = {
635			let state = self.inner.read().await;
636
637			state.providers.get(&key).cloned()
638		};
639		let handle = handle.ok_or_else(|| Error::NotRegistered {
640			tenant: tenant_id.to_string(),
641			provider: provider_id.to_string(),
642		})?;
643
644		Ok(handle.status().await)
645	}
646
647	/// Fetch status for every registered provider.
648	pub async fn all_statuses(&self) -> Vec<ProviderStatus> {
649		let handles: Vec<Arc<ProviderHandle>> = {
650			let state = self.inner.read().await;
651			state.providers.values().cloned().collect()
652		};
653		let mut statuses = Vec::with_capacity(handles.len());
654
655		for handle in handles {
656			statuses.push(handle.status().await);
657		}
658
659		statuses
660	}
661
662	/// Persist snapshots for every provider when persistence is configured.
663	pub async fn persist_all(&self) -> Result<()> {
664		#[cfg(feature = "redis")]
665		{
666			if let Some(persistence) = &self.config.persistence {
667				let handles: Vec<Arc<ProviderHandle>> = {
668					let state = self.inner.read().await;
669
670					state.providers.values().cloned().collect()
671				};
672				let mut snapshots = Vec::new();
673
674				for handle in handles {
675					if let Some(snapshot) = handle.manager.persistent_snapshot().await? {
676						snapshots.push(snapshot);
677					}
678				}
679
680				persistence.persist(&snapshots).await?;
681			}
682		}
683
684		Ok(())
685	}
686
687	/// Restore cached entries from persistence for all active registrations.
688	pub async fn restore_from_persistence(&self) -> Result<()> {
689		#[cfg(feature = "redis")]
690		{
691			if let Some(persistence) = &self.config.persistence {
692				let handles: Vec<Arc<ProviderHandle>> = {
693					let state = self.inner.read().await;
694
695					state.providers.values().cloned().collect()
696				};
697
698				for handle in handles {
699					if let Some(snapshot) = persistence
700						.load(&handle.registration.tenant_id, &handle.registration.provider_id)
701						.await?
702					{
703						handle.manager.restore_snapshot(snapshot).await?;
704					}
705				}
706			}
707		}
708
709		Ok(())
710	}
711}
impl Default for Registry {
	/// Equivalent to [`Registry::new`].
	fn default() -> Self {
		Self::new()
	}
}
717
/// Status projection for a provider, aligned with the OpenAPI contract.
///
/// The timestamp fields and `error_count` are only populated while the cache holds a payload
/// (`Ready` or `Refreshing`); otherwise they are `None` / `0`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ProviderStatus {
	/// Tenant identifier that owns the provider.
	pub tenant_id: String,
	/// Provider identifier unique within the tenant.
	pub provider_id: String,
	/// Lifecycle state currently reported for the provider.
	pub state: ProviderState,
	/// Timestamp of the most recent successful refresh.
	pub last_refresh: Option<DateTime<Utc>>,
	/// Scheduled timestamp for the next refresh attempt.
	pub next_refresh: Option<DateTime<Utc>>,
	/// Expiration timestamp for the active payload, if available.
	pub expires_at: Option<DateTime<Utc>>,
	/// Consecutive error count observed during refresh attempts.
	pub error_count: u32,
	/// Ratio of cache hits to total requests.
	#[cfg(feature = "metrics")]
	pub hit_rate: f64,
	/// Ratio of served responses that were stale.
	#[cfg(feature = "metrics")]
	pub stale_serve_ratio: f64,
	/// Metrics emitted to describe provider performance.
	#[cfg(feature = "metrics")]
	pub metrics: Vec<StatusMetric>,
}
745impl ProviderStatus {
746	#[cfg(feature = "metrics")]
747	fn from_components(
748		registration: &IdentityProviderRegistration,
749		snapshot: CacheSnapshot,
750		metrics: ProviderMetricsSnapshot,
751	) -> Self {
752		let mut last_refresh = None;
753		let mut next_refresh = None;
754		let mut expires_at = None;
755		let mut error_count = 0;
756		let state = match &snapshot.state {
757			CacheState::Empty => ProviderState::Empty,
758			CacheState::Loading => ProviderState::Loading,
759			CacheState::Ready(payload) => {
760				last_refresh = Some(payload.last_refresh_at);
761				next_refresh = snapshot.to_datetime(payload.next_refresh_at);
762				expires_at = snapshot.to_datetime(payload.expires_at);
763				error_count = payload.error_count;
764				ProviderState::Ready
765			},
766			CacheState::Refreshing(payload) => {
767				last_refresh = Some(payload.last_refresh_at);
768				next_refresh = snapshot.to_datetime(payload.next_refresh_at);
769				expires_at = snapshot.to_datetime(payload.expires_at);
770				error_count = payload.error_count;
771				ProviderState::Refreshing
772			},
773		};
774		let tenant = &registration.tenant_id;
775		let provider = &registration.provider_id;
776		let mut status_metrics = vec![
777			StatusMetric::new(
778				"jwks_cache_requests_total",
779				metrics.total_requests as f64,
780				tenant,
781				provider,
782			),
783			StatusMetric::new("jwks_cache_hits_total", metrics.cache_hits as f64, tenant, provider),
784			StatusMetric::new(
785				"jwks_cache_stale_total",
786				metrics.stale_serves as f64,
787				tenant,
788				provider,
789			),
790			StatusMetric::new(
791				"jwks_cache_refresh_errors_total",
792				metrics.refresh_errors as f64,
793				tenant,
794				provider,
795			),
796		];
797
798		if let Some(last_micros) = metrics.last_refresh_micros {
799			status_metrics.push(StatusMetric::new(
800				"jwks_cache_last_refresh_micros",
801				last_micros as f64,
802				tenant,
803				provider,
804			));
805		}
806
807		Self {
808			tenant_id: tenant.clone(),
809			provider_id: provider.clone(),
810			state,
811			last_refresh,
812			next_refresh,
813			expires_at,
814			error_count,
815			hit_rate: metrics.hit_rate(),
816			stale_serve_ratio: metrics.stale_ratio(),
817			metrics: status_metrics,
818		}
819	}
820
821	#[cfg(not(feature = "metrics"))]
822	fn from_components(
823		registration: &IdentityProviderRegistration,
824		snapshot: CacheSnapshot,
825	) -> Self {
826		let mut last_refresh = None;
827		let mut next_refresh = None;
828		let mut expires_at = None;
829		let mut error_count = 0;
830		let state = match &snapshot.state {
831			CacheState::Empty => ProviderState::Empty,
832			CacheState::Loading => ProviderState::Loading,
833			CacheState::Ready(payload) => {
834				last_refresh = Some(payload.last_refresh_at);
835				next_refresh = snapshot.to_datetime(payload.next_refresh_at);
836				expires_at = snapshot.to_datetime(payload.expires_at);
837				error_count = payload.error_count;
838				ProviderState::Ready
839			},
840			CacheState::Refreshing(payload) => {
841				last_refresh = Some(payload.last_refresh_at);
842				next_refresh = snapshot.to_datetime(payload.next_refresh_at);
843				expires_at = snapshot.to_datetime(payload.expires_at);
844				error_count = payload.error_count;
845				ProviderState::Refreshing
846			},
847		};
848
849		Self {
850			tenant_id: registration.tenant_id.clone(),
851			provider_id: registration.provider_id.clone(),
852			state,
853			last_refresh,
854			next_refresh,
855			expires_at,
856			error_count,
857		}
858	}
859}
860
/// Metric sample used in provider status responses.
///
/// Only compiled with the `metrics` feature.
#[cfg(feature = "metrics")]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StatusMetric {
	/// Metric name following the monitoring schema.
	pub name: String,
	/// Numeric value captured for the metric.
	pub value: f64,
	/// Additional labels enriching the metric sample (empty when absent from the input).
	#[serde(default)]
	pub labels: HashMap<String, String>,
}
#[cfg(feature = "metrics")]
impl StatusMetric {
	/// Build a sample named `name` with `value`, labelled with the owning `tenant` and
	/// `provider`.
	fn new(name: impl Into<String>, value: f64, tenant: &str, provider: &str) -> Self {
		let labels: HashMap<String, String> = [("tenant", tenant), ("provider", provider)]
			.into_iter()
			.map(|(label, text)| (label.to_owned(), text.to_owned()))
			.collect();

		Self { name: name.into(), value, labels }
	}
}
884
/// Registry-wide settings assembled by `RegistryBuilder` and consulted during registration.
#[derive(Debug)]
struct RegistryConfig {
	/// When `true`, registrations that opt out of HTTPS are rejected.
	require_https: bool,
	/// Replacement for registrations whose `refresh_early` equals the crate default.
	default_refresh_early: Duration,
	/// Replacement for registrations whose `stale_while_error` equals the crate default.
	default_stale_while_error: Duration,
	/// Global domain allowlist; also inherited by registrations with an empty allowlist.
	allowed_domains: Vec<String>,
	/// Optional Redis persistence backend (`redis` feature only).
	#[cfg(feature = "redis")]
	persistence: Option<RedisPersistence>,
}
impl Default for RegistryConfig {
	/// HTTPS required, crate-default timing values, an empty allowlist, and no persistence.
	fn default() -> Self {
		Self {
			require_https: true,
			default_refresh_early: DEFAULT_REFRESH_EARLY,
			default_stale_while_error: DEFAULT_STALE_WHILE_ERROR,
			allowed_domains: Vec::new(),
			#[cfg(feature = "redis")]
			persistence: None,
		}
	}
}
906
/// Per-provider bundle stored in the registry map: the effective registration, its cache
/// manager, and (with the `metrics` feature) the manager's metrics handle.
#[derive(Debug)]
struct ProviderHandle {
	/// Registration as adjusted by `Registry::register`.
	registration: Arc<IdentityProviderRegistration>,
	/// Cache manager built from that registration.
	manager: CacheManager,
	/// Metrics handle obtained from the manager at registration time.
	#[cfg(feature = "metrics")]
	metrics: Arc<ProviderMetrics>,
}
914impl ProviderHandle {
915	async fn status(&self) -> ProviderStatus {
916		let snapshot = self.manager.snapshot().await;
917		#[cfg(feature = "metrics")]
918		let status = {
919			let metrics = self.metrics.snapshot();
920
921			ProviderStatus::from_components(&self.registration, snapshot, metrics)
922		};
923		#[cfg(not(feature = "metrics"))]
924		let status = ProviderStatus::from_components(&self.registration, snapshot);
925
926		status
927	}
928}
929
/// Mutable registry contents guarded by the registry's read-write lock.
#[derive(Debug)]
struct RegistryState {
	// TODO: Consider replacing the RwLock<HashMap> with DashMap if contention becomes measurable.
	/// Registered providers keyed by tenant/provider pair.
	providers: HashMap<TenantProviderKey, Arc<ProviderHandle>>,
}
935
/// Redis-backed snapshot store (`redis` feature only).
#[cfg(feature = "redis")]
#[derive(Clone, Debug)]
struct RedisPersistence {
	/// Client used to open a multiplexed async connection for each operation.
	client: redis::Client,
	/// Key prefix; keys have the form `<namespace>:<tenant>:<provider>`.
	namespace: Arc<str>,
}
#[cfg(feature = "redis")]
impl RedisPersistence {
	/// Wrap `client` using the default `jwks-cache` key namespace.
	fn new(client: redis::Client) -> Self {
		let namespace: Arc<str> = Arc::from("jwks-cache");

		Self { client, namespace }
	}

	/// Write every snapshot to Redis as JSON with an expiry matching its remaining lifetime.
	///
	/// The expiry is rounded down to whole seconds and never less than one second, so an
	/// already-expired snapshot is still written with a one-second TTL. An empty slice returns
	/// without opening a connection.
	async fn persist(&self, snapshots: &[PersistentSnapshot]) -> Result<()> {
		if snapshots.is_empty() {
			return Ok(());
		}

		let mut conn = self.client.get_multiplexed_async_connection().await?;

		for snapshot in snapshots {
			let key = self.key(&snapshot.tenant_id, &snapshot.provider_id);
			let payload = serde_json::to_string(snapshot)?;
			let remaining = snapshot.expires_at - Utc::now();
			// A negative remaining lifetime fails `to_std`; fall back to the one-second minimum.
			let ttl_secs = remaining.to_std().map_or(1, |ttl| ttl.as_secs().max(1));

			conn.set_ex::<_, _, ()>(key, payload, ttl_secs).await?;
		}

		Ok(())
	}

	/// Fetch and decode the snapshot stored for `tenant`/`provider`; `Ok(None)` when no value
	/// is stored under the key.
	async fn load(&self, tenant: &str, provider: &str) -> Result<Option<PersistentSnapshot>> {
		let mut conn = self.client.get_multiplexed_async_connection().await?;
		let stored: Option<String> = conn.get(self.key(tenant, provider)).await?;

		match stored {
			Some(json) => {
				let snapshot: PersistentSnapshot = serde_json::from_str(&json)?;

				Ok(Some(snapshot))
			},
			None => Ok(None),
		}
	}

	/// Compose the Redis key `<namespace>:<tenant>:<provider>`.
	fn key(&self, tenant: &str, provider: &str) -> String {
		format!("{}:{tenant}:{provider}", self.namespace)
	}
}
987
988fn random_within(min: Duration, max: Duration) -> Duration {
989	if max <= min {
990		return max;
991	}
992	SMALL_RNG.with(|cell| {
993		let mut rng = cell.borrow_mut();
994		let nanos = max.as_nanos() - min.as_nanos();
995		let jitter = rng.random_range(0..=nanos.min(u64::MAX as u128));
996
997		min + Duration::from_nanos(jitter as u64)
998	})
999}
1000
/// Serde default for `IdentityProviderRegistration::require_https`: HTTPS is required unless
/// the input says otherwise.
fn default_true() -> bool {
	true
}
1004
/// Serde default for `IdentityProviderRegistration::refresh_early`.
fn default_refresh_early() -> Duration {
	DEFAULT_REFRESH_EARLY
}
1008
/// Serde default for `IdentityProviderRegistration::stale_while_error`.
fn default_stale_while_error() -> Duration {
	DEFAULT_STALE_WHILE_ERROR
}
1012
/// Serde default for `IdentityProviderRegistration::min_ttl` (the minimum accepted floor).
fn default_min_ttl() -> Duration {
	MIN_TTL_FLOOR
}
1016
/// Serde default for `IdentityProviderRegistration::max_ttl`.
fn default_max_ttl() -> Duration {
	DEFAULT_MAX_TTL
}
1020
/// Serde default for `IdentityProviderRegistration::max_response_bytes`.
fn default_max_response_bytes() -> u64 {
	DEFAULT_MAX_RESPONSE_BYTES
}
1024
/// Redirect budget applied when a registration does not specify `max_redirects`.
const DEFAULT_MAX_REDIRECTS: u8 = 3;

/// Serde default for `IdentityProviderRegistration::max_redirects`.
fn default_max_redirects() -> u8 {
	DEFAULT_MAX_REDIRECTS
}
1028
/// Serde default for `IdentityProviderRegistration::prefetch_jitter`.
fn default_prefetch_jitter() -> Duration {
	DEFAULT_PREFETCH_JITTER
}
1032
1033fn validate_tenant_id(value: &str) -> Result<()> {
1034	if value.is_empty() {
1035		return Err(Error::Validation { field: "tenant_id", reason: "Must not be empty.".into() });
1036	}
1037	if value.len() > 64 {
1038		return Err(Error::Validation {
1039			field: "tenant_id",
1040			reason: "Must be 64 characters or fewer.".into(),
1041		});
1042	}
1043	if !value.as_bytes().iter().all(|b| b.is_ascii_alphanumeric() || *b == b'-') {
1044		return Err(Error::Validation {
1045			field: "tenant_id",
1046			reason: "May only contain ASCII letters, numbers, and '-'.".into(),
1047		});
1048	}
1049
1050	Ok(())
1051}
1052
1053fn validate_provider_id(value: &str) -> Result<()> {
1054	if value.is_empty() {
1055		return Err(Error::Validation {
1056			field: "provider_id",
1057			reason: "Must not be empty.".into(),
1058		});
1059	}
1060	if value.len() > 64 {
1061		return Err(Error::Validation {
1062			field: "provider_id",
1063			reason: "Must be 64 characters or fewer.".into(),
1064		});
1065	}
1066	if !value.as_bytes().iter().all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_')) {
1067		return Err(Error::Validation {
1068			field: "provider_id",
1069			reason: "May only contain ASCII letters, numbers, '-', or '_'.".into(),
1070		});
1071	}
1072
1073	Ok(())
1074}