jwks_cache/
registry.rs

1//! Tenant/provider registry and configuration validation.
2//!
3//! The registry owns tenant registrations, cache metadata, and optional persistence wiring.
4
5// std
6use std::{cell::RefCell, collections::HashMap, mem};
7// crates.io
8use jsonwebtoken::jwk::JwkSet;
9use rand::{Rng, SeedableRng, rngs::SmallRng};
10#[cfg(feature = "redis")] use redis::AsyncCommands;
11use serde::{Deserialize, Serialize};
12use tokio::sync::RwLock;
13use url::Url;
14// self
15use crate::{
16	_prelude::*,
17	cache::{
18		manager::{CacheManager, CacheSnapshot},
19		state::CacheState,
20	},
21	metrics::{ProviderMetrics, ProviderMetricsSnapshot},
22	security::{self, SpkiFingerprint},
23};
24
25thread_local! {
26	static SMALL_RNG: RefCell<SmallRng> = RefCell::new(SmallRng::from_rng(&mut rand::rng()));
27}
28
/// Default lead time before TTL expiry at which a proactive refresh is triggered.
pub const DEFAULT_REFRESH_EARLY: Duration = Duration::from_secs(30);
/// Default window during which stale data may be served after a refresh error.
pub const DEFAULT_STALE_WHILE_ERROR: Duration = Duration::from_secs(60);
/// Smallest TTL accepted for upstream responses (also the default `min_ttl`).
pub const MIN_TTL_FLOOR: Duration = Duration::from_secs(30);
/// Default upper clamp applied to upstream TTLs (24 hours).
pub const DEFAULT_MAX_TTL: Duration = Duration::from_secs(24 * 60 * 60);
/// Default upper bound on JWKS response size (1 MiB).
pub const DEFAULT_MAX_RESPONSE_BYTES: u64 = 1024 * 1024;
/// Default random jitter applied when scheduling proactive refreshes.
pub const DEFAULT_PREFETCH_JITTER: Duration = Duration::from_secs(5);
/// Largest value accepted for a registration's `max_redirects` setting.
pub const MAX_REDIRECTS: u8 = 10;
43
44/// Supported jitter strategies for retry policies.
45#[derive(Clone, Debug, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
46#[serde(rename_all = "snake_case")]
47pub enum JitterStrategy {
48	/// No jitter; deterministic backoff schedule.
49	None,
50	/// Full jitter; randomize delay between 0 and current backoff.
51	#[default]
52	Full,
53	/// Decorrelated jitter per AWS architecture guidance.
54	Decorrelated,
55}
56
57/// Public representation of provider lifecycle state.
58#[derive(Clone, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
59#[serde(rename_all = "PascalCase")]
60pub enum ProviderState {
61	/// No JWKS payload has been cached yet.
62	Empty,
63	/// Initial fetch operation is currently running.
64	Loading,
65	/// Fresh JWKS payload is available for requests.
66	Ready,
67	/// Cache is serving while a refresh is in progress.
68	Refreshing,
69}
70
71/// Retry configuration for HTTP fetch operations.
72#[derive(Clone, Debug, Serialize, Deserialize)]
73pub struct RetryPolicy {
74	/// Maximum number of retry attempts to perform after the initial request.
75	pub max_retries: u32,
76	/// Timeout applied to each individual HTTP attempt.
77	pub attempt_timeout: Duration,
78	/// Initial delay before retrying after a failure.
79	pub initial_backoff: Duration,
80	/// Upper bound applied to exponential backoff growth.
81	pub max_backoff: Duration,
82	/// Overall deadline that bounds the entire retry sequence.
83	pub deadline: Duration,
84	/// Strategy used to randomize the computed backoff.
85	#[serde(default)]
86	pub jitter: JitterStrategy,
87}
88impl RetryPolicy {
89	/// Validate invariants for retry configuration.
90	pub fn validate(&self) -> Result<()> {
91		if self.attempt_timeout < Duration::from_millis(100) {
92			return Err(Error::Validation {
93				field: "retry_policy.attempt_timeout",
94				reason: "Must be at least 100 ms.".into(),
95			});
96		}
97		if self.initial_backoff.is_zero() {
98			return Err(Error::Validation {
99				field: "retry_policy.initial_backoff",
100				reason: "Must be greater than zero.".into(),
101			});
102		}
103		if self.max_backoff < self.initial_backoff {
104			return Err(Error::Validation {
105				field: "retry_policy.max_backoff",
106				reason: "Must be greater than or equal to initial_backoff.".into(),
107			});
108		}
109		if self.deadline < self.attempt_timeout {
110			return Err(Error::Validation {
111				field: "retry_policy.deadline",
112				reason: "Must be greater than or equal to attempt_timeout.".into(),
113			});
114		}
115		Ok(())
116	}
117
118	/// Compute backoff for a retry attempt using the selected jitter strategy.
119	pub fn compute_backoff(&self, attempt: u32) -> Duration {
120		self.default_backoff(attempt)
121	}
122
123	/// Default exponential backoff with jitter following the AWS architecture guidance.
124	pub fn default_backoff(&self, attempt: u32) -> Duration {
125		let exponent = attempt.min(32);
126		let base = self.initial_backoff.mul_f64(2f64.powi(exponent as i32));
127		let bounded = base.min(self.max_backoff).max(self.initial_backoff);
128
129		self.apply_jitter(bounded, attempt)
130	}
131
132	fn apply_jitter(&self, bounded: Duration, attempt: u32) -> Duration {
133		match self.jitter {
134			JitterStrategy::None => bounded,
135			JitterStrategy::Full => {
136				let lower = bounded.mul_f64(0.8).max(self.initial_backoff);
137				let upper = bounded.min(self.max_backoff);
138
139				random_within(lower, upper)
140			},
141			JitterStrategy::Decorrelated => {
142				let prev = if attempt == 0 { self.initial_backoff } else { bounded };
143				let ceiling = self.max_backoff.min(prev.mul_f64(3.0));
144
145				random_within(self.initial_backoff, ceiling.max(self.initial_backoff))
146			},
147		}
148	}
149}
150impl Default for RetryPolicy {
151	fn default() -> Self {
152		Self {
153			max_retries: 2,
154			attempt_timeout: Duration::from_secs(3),
155			initial_backoff: Duration::from_millis(250),
156			max_backoff: Duration::from_secs(2),
157			deadline: Duration::from_secs(8),
158			jitter: JitterStrategy::Full,
159		}
160	}
161}
162
163/// Registration describing how to fetch and maintain JWKS for a provider.
164#[derive(Clone, Debug, Serialize, Deserialize)]
165pub struct IdentityProviderRegistration {
166	/// Tenant identifier used for metrics, caching, and persistence scope.
167	pub tenant_id: String,
168	/// Provider identifier unique within the tenant.
169	pub provider_id: String,
170	/// URL of the JWKS endpoint to fetch signing keys from.
171	pub jwks_url: Url,
172	/// Whether HTTPS is required for JWKS retrieval.
173	#[serde(default = "default_true")]
174	pub require_https: bool,
175	/// Optional allowlist of domains permitted for redirects.
176	#[serde(default, deserialize_with = "crate::security::deserialize_allowed_domains")]
177	pub allowed_domains: Vec<String>,
178	/// Lead time before expiry to trigger proactive refresh.
179	#[serde(default = "default_refresh_early")]
180	pub refresh_early: Duration,
181	/// Duration to continue serving stale data when refresh fails.
182	#[serde(default = "default_stale_while_error")]
183	pub stale_while_error: Duration,
184	/// Minimum TTL applied to upstream responses.
185	#[serde(default = "default_min_ttl")]
186	pub min_ttl: Duration,
187	/// Maximum TTL applied to upstream responses.
188	#[serde(default = "default_max_ttl")]
189	pub max_ttl: Duration,
190	/// Maximum size allowed for JWKS payloads in bytes.
191	#[serde(default = "default_max_response_bytes")]
192	pub max_response_bytes: u64,
193	/// TTL applied when persisting negative cache outcomes.
194	#[serde(default)]
195	pub negative_cache_ttl: Duration,
196	/// Maximum number of redirects to follow during fetch.
197	#[serde(default = "default_max_redirects")]
198	pub max_redirects: u8,
199	/// Optional SPKI fingerprints used for TLS pinning.
200	#[serde(default)]
201	pub pinned_spki: Vec<SpkiFingerprint>,
202	/// Random jitter applied when scheduling proactive refreshes.
203	#[serde(default = "default_prefetch_jitter")]
204	pub prefetch_jitter: Duration,
205	/// Retry policy configuration for JWKS fetch attempts.
206	#[serde(default)]
207	pub retry_policy: RetryPolicy,
208}
209impl IdentityProviderRegistration {
210	/// Construct a new registration with default cache settings.
211	pub fn new(
212		tenant_id: impl Into<String>,
213		provider_id: impl Into<String>,
214		jwks_url: impl AsRef<str>,
215	) -> Result<Self> {
216		let jwks_url = Url::parse(jwks_url.as_ref())?;
217
218		Ok(Self {
219			tenant_id: tenant_id.into(),
220			provider_id: provider_id.into(),
221			jwks_url,
222			require_https: true,
223			allowed_domains: Vec::new(),
224			refresh_early: DEFAULT_REFRESH_EARLY,
225			stale_while_error: DEFAULT_STALE_WHILE_ERROR,
226			min_ttl: MIN_TTL_FLOOR,
227			max_ttl: DEFAULT_MAX_TTL,
228			max_response_bytes: DEFAULT_MAX_RESPONSE_BYTES,
229			negative_cache_ttl: Duration::ZERO,
230			max_redirects: 3,
231			pinned_spki: Vec::new(),
232			prefetch_jitter: DEFAULT_PREFETCH_JITTER,
233			retry_policy: RetryPolicy::default(),
234		})
235	}
236
237	/// Canonicalise the domain allowlist in-place.
238	pub fn normalize_allowed_domains(&mut self) {
239		let domains = mem::take(&mut self.allowed_domains);
240
241		self.allowed_domains = security::normalize_allowlist(domains);
242	}
243
244	/// Set HTTPS requirement to the desired value.
245	pub fn with_require_https(mut self, require_https: bool) -> Self {
246		self.require_https = require_https;
247
248		self
249	}
250
251	/// Validate the registration against the documented constraints.
252	pub fn validate(&self) -> Result<()> {
253		validate_tenant_id(&self.tenant_id)?;
254		validate_provider_id(&self.provider_id)?;
255
256		if self.require_https {
257			security::enforce_https(&self.jwks_url)?;
258		}
259
260		if let Some(host) = self.jwks_url.host_str() {
261			if !security::host_is_allowed(host, &self.allowed_domains) {
262				return Err(Error::Validation {
263					field: "jwks_url",
264					reason: "Host is not within the allowed_domains allowlist.".into(),
265				});
266			}
267		} else {
268			return Err(Error::Validation {
269				field: "jwks_url",
270				reason: "Must include a host component.".into(),
271			});
272		}
273
274		if self.refresh_early < Duration::from_secs(1) {
275			return Err(Error::Validation {
276				field: "refresh_early",
277				reason: "Must be at least 1 second.".into(),
278			});
279		}
280		if self.min_ttl < MIN_TTL_FLOOR {
281			return Err(Error::Validation {
282				field: "min_ttl",
283				reason: format!("Must be at least {:?}.", MIN_TTL_FLOOR),
284			});
285		}
286		if self.max_ttl < self.min_ttl {
287			return Err(Error::Validation {
288				field: "max_ttl",
289				reason: "Must be greater than or equal to min_ttl.".into(),
290			});
291		}
292		if self.refresh_early >= self.max_ttl {
293			return Err(Error::Validation {
294				field: "refresh_early",
295				reason: "Must be less than max_ttl.".into(),
296			});
297		}
298		if self.max_response_bytes == 0 {
299			return Err(Error::Validation {
300				field: "max_response_bytes",
301				reason: "Must be greater than zero.".into(),
302			});
303		}
304		if self.max_redirects > MAX_REDIRECTS {
305			return Err(Error::Validation {
306				field: "max_redirects",
307				reason: format!("Must be less than or equal to {}.", MAX_REDIRECTS),
308			});
309		}
310		if !self.negative_cache_ttl.is_zero() && self.negative_cache_ttl < Duration::from_secs(1) {
311			return Err(Error::Validation {
312				field: "negative_cache_ttl",
313				reason: "Must be zero or at least one second.".into(),
314			});
315		}
316
317		self.retry_policy.validate()?;
318
319		for domain in &self.allowed_domains {
320			if let Some(canonical) = security::canonicalize_dns_name(domain) {
321				if canonical != *domain {
322					return Err(Error::Validation {
323						field: "allowed_domains",
324						reason: "Entries must be canonical hostnames (lowercase, no trailing dot)."
325							.into(),
326					});
327				}
328			} else {
329				return Err(Error::Validation {
330					field: "allowed_domains",
331					reason: "Entries must be non-empty hostnames.".into(),
332				});
333			}
334		}
335
336		Ok(())
337	}
338}
339
340/// Snapshot of cache payload persisted to external storage.
341#[derive(Clone, Debug, Serialize, Deserialize)]
342pub struct PersistentSnapshot {
343	/// Tenant identifier associated with the snapshot.
344	pub tenant_id: String,
345	/// Provider identifier within the tenant scope.
346	pub provider_id: String,
347	/// Serialized JWKS payload captured from the cache.
348	pub jwks_json: String,
349	/// Entity tag returned by the JWKS endpoint, if present.
350	pub etag: Option<String>,
351	/// Last-Modified timestamp advertised by the JWKS endpoint.
352	#[serde(default)]
353	pub last_modified: Option<DateTime<Utc>>,
354	/// UTC timestamp when the cached payload expires.
355	pub expires_at: DateTime<Utc>,
356	/// UTC timestamp when the snapshot was persisted.
357	pub persisted_at: DateTime<Utc>,
358}
359impl PersistentSnapshot {
360	/// Validate snapshot metadata aligns with registration expectations.
361	pub fn validate(&self, registration: &IdentityProviderRegistration) -> Result<()> {
362		if self.jwks_json.len() as u64 > registration.max_response_bytes {
363			return Err(Error::Validation {
364				field: "jwks_json",
365				reason: format!(
366					"Snapshot exceeds max_response_bytes ({} bytes).",
367					registration.max_response_bytes
368				),
369			});
370		}
371
372		if self.tenant_id != registration.tenant_id {
373			return Err(Error::Validation {
374				field: "tenant_id",
375				reason: "Snapshot tenant does not match registration.".into(),
376			});
377		}
378		if self.provider_id != registration.provider_id {
379			return Err(Error::Validation {
380				field: "provider_id",
381				reason: "Snapshot provider does not match registration.".into(),
382			});
383		}
384
385		if let Some(etag) = &self.etag
386			&& !etag.is_ascii()
387		{
388			return Err(Error::Validation { field: "etag", reason: "ETag must be ASCII.".into() });
389		}
390
391		if self.expires_at < self.persisted_at {
392			return Err(Error::Validation {
393				field: "expires_at",
394				reason: "Cannot be earlier than persisted_at.".into(),
395			});
396		}
397
398		Ok(())
399	}
400}
401
/// Map key identifying a provider registration by tenant and provider identifiers.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TenantProviderKey {
	/// Tenant identifier component.
	pub tenant_id: String,
	/// Provider identifier component.
	pub provider_id: String,
}
impl TenantProviderKey {
	/// Build a key from anything convertible into owned strings.
	pub fn new(tenant_id: impl Into<String>, provider_id: impl Into<String>) -> Self {
		let tenant_id = tenant_id.into();
		let provider_id = provider_id.into();

		Self { tenant_id, provider_id }
	}
}
413
414/// Builder for [`Registry`] enabling multi-tenant configuration.
415#[derive(Debug, Default)]
416pub struct RegistryBuilder {
417	config: RegistryConfig,
418}
419impl RegistryBuilder {
420	/// Create a builder with default configuration.
421	pub fn new() -> Self {
422		Self::default()
423	}
424
425	/// Enforce HTTPS for registrations (enabled by default).
426	pub fn require_https(mut self, require_https: bool) -> Self {
427		self.config.require_https = require_https;
428
429		self
430	}
431
432	/// Override the default refresh-early offset applied to registrations.
433	pub fn default_refresh_early(mut self, value: Duration) -> Self {
434		self.config.default_refresh_early = value;
435
436		self
437	}
438
439	/// Override the default stale-while-error window applied to registrations.
440	pub fn default_stale_while_error(mut self, value: Duration) -> Self {
441		self.config.default_stale_while_error = value;
442
443		self
444	}
445
446	/// Add an entry to the global domain allowlist.
447	pub fn add_allowed_domain(mut self, domain: impl Into<String>) -> Self {
448		let raw = domain.into();
449
450		if let Some(domain) = security::canonicalize_dns_name(&raw)
451			&& !self.config.allowed_domains.contains(&domain)
452		{
453			self.config.allowed_domains.push(domain);
454		}
455
456		self
457	}
458
459	/// Replace the global domain allowlist.
460	pub fn allowed_domains<I, S>(mut self, domains: I) -> Self
461	where
462		I: IntoIterator<Item = S>,
463		S: Into<String>,
464	{
465		self.config.allowed_domains.clear();
466
467		for domain in domains {
468			self = self.add_allowed_domain(domain);
469		}
470
471		self
472	}
473
474	#[cfg(feature = "redis")]
475	/// Configure Redis-backed persistence for snapshots.
476	pub fn with_redis_client(mut self, client: redis::Client) -> Self {
477		self.config.persistence = Some(RedisPersistence::new(client));
478
479		self
480	}
481
482	#[cfg(feature = "redis")]
483	/// Adjust the Redis key namespace (defaults to `jwks-cache`).
484	pub fn redis_namespace(mut self, namespace: impl Into<String>) -> Self {
485		if let Some(persistence) = self.config.persistence.as_mut() {
486			persistence.namespace = Arc::from(namespace.into());
487		} else {
488			panic!("Redis client must be configured before setting namespace.");
489		}
490
491		self
492	}
493
494	/// Finalise the configuration and construct a [`Registry`].
495	pub fn build(self) -> Registry {
496		let mut config = self.config;
497
498		config.allowed_domains = security::normalize_allowlist(config.allowed_domains);
499
500		Registry {
501			inner: Arc::new(RwLock::new(RegistryState { providers: HashMap::new() })),
502			config: Arc::new(config),
503		}
504	}
505}
506
507/// Registry state container.
508#[derive(Clone, Debug)]
509pub struct Registry {
510	inner: Arc<RwLock<RegistryState>>,
511	config: Arc<RegistryConfig>,
512}
513impl Registry {
514	/// Create a new registry instance with defaults.
515	pub fn new() -> Self {
516		Self::builder().build()
517	}
518
519	/// Create a [`RegistryBuilder`] for advanced configuration.
520	pub fn builder() -> RegistryBuilder {
521		RegistryBuilder::new()
522	}
523
524	/// Register or update a provider configuration.
525	pub async fn register(&self, mut registration: IdentityProviderRegistration) -> Result<()> {
526		if self.config.require_https {
527			if !registration.require_https {
528				return Err(Error::Security(
529					"Registry requires HTTPS for all provider registrations.".into(),
530				));
531			}
532		} else {
533			registration.require_https = false;
534		}
535
536		registration.normalize_allowed_domains();
537
538		if registration.refresh_early == DEFAULT_REFRESH_EARLY {
539			registration.refresh_early = self.config.default_refresh_early;
540		}
541		if registration.stale_while_error == DEFAULT_STALE_WHILE_ERROR {
542			registration.stale_while_error = self.config.default_stale_while_error;
543		}
544		if registration.allowed_domains.is_empty() && !self.config.allowed_domains.is_empty() {
545			registration.allowed_domains = self.config.allowed_domains.clone();
546		}
547
548		if let Some(host) = registration.jwks_url.host_str()
549			&& !security::host_is_allowed(host, &self.config.allowed_domains)
550		{
551			return Err(Error::Security(format!(
552				"Host '{host}' is not in the registry allowlist."
553			)));
554		}
555
556		let key = TenantProviderKey::new(&registration.tenant_id, &registration.provider_id);
557		let manager = CacheManager::new(registration.clone())?;
558		let metrics = manager.metrics();
559		let handle =
560			Arc::new(ProviderHandle { registration: Arc::new(registration), manager, metrics });
561
562		{
563			let mut state = self.inner.write().await;
564
565			state.providers.insert(key.clone(), handle.clone());
566		}
567
568		#[cfg(feature = "redis")]
569		if let Some(persistence) = &self.config.persistence {
570			if let Some(snapshot) = persistence.load(&key.tenant_id, &key.provider_id).await? {
571				handle.manager.restore_snapshot(snapshot).await?;
572			}
573		}
574
575		Ok(())
576	}
577
578	/// Resolve JWKS for a tenant/provider pair.
579	pub async fn resolve(
580		&self,
581		tenant_id: &str,
582		provider_id: &str,
583		kid: Option<&str>,
584	) -> Result<Arc<JwkSet>> {
585		let key = TenantProviderKey::new(tenant_id, provider_id);
586		let handle = {
587			let state = self.inner.read().await;
588
589			state.providers.get(&key).cloned()
590		};
591		let handle = handle.ok_or_else(|| Error::NotRegistered {
592			tenant: tenant_id.to_string(),
593			provider: provider_id.to_string(),
594		})?;
595
596		handle.manager.resolve(kid).await
597	}
598
599	/// Trigger a manual refresh for a registered provider.
600	pub async fn refresh(&self, tenant_id: &str, provider_id: &str) -> Result<()> {
601		let key = TenantProviderKey::new(tenant_id, provider_id);
602		let handle = {
603			let state = self.inner.read().await;
604			state.providers.get(&key).cloned()
605		};
606		let handle = handle.ok_or_else(|| Error::NotRegistered {
607			tenant: tenant_id.to_string(),
608			provider: provider_id.to_string(),
609		})?;
610
611		handle.manager.trigger_refresh().await
612	}
613
614	/// Remove a provider registration if present.
615	pub async fn unregister(&self, tenant_id: &str, provider_id: &str) -> Result<bool> {
616		let key = TenantProviderKey::new(tenant_id, provider_id);
617		let mut state = self.inner.write().await;
618
619		Ok(state.providers.remove(&key).is_some())
620	}
621
622	/// Fetch status information for a specific provider.
623	pub async fn provider_status(
624		&self,
625		tenant_id: &str,
626		provider_id: &str,
627	) -> Result<ProviderStatus> {
628		let key = TenantProviderKey::new(tenant_id, provider_id);
629		let handle = {
630			let state = self.inner.read().await;
631
632			state.providers.get(&key).cloned()
633		};
634		let handle = handle.ok_or_else(|| Error::NotRegistered {
635			tenant: tenant_id.to_string(),
636			provider: provider_id.to_string(),
637		})?;
638
639		Ok(handle.status().await)
640	}
641
642	/// Fetch status for every registered provider.
643	pub async fn all_statuses(&self) -> Vec<ProviderStatus> {
644		let handles: Vec<Arc<ProviderHandle>> = {
645			let state = self.inner.read().await;
646			state.providers.values().cloned().collect()
647		};
648		let mut statuses = Vec::with_capacity(handles.len());
649
650		for handle in handles {
651			statuses.push(handle.status().await);
652		}
653
654		statuses
655	}
656
657	/// Persist snapshots for every provider when persistence is configured.
658	pub async fn persist_all(&self) -> Result<()> {
659		#[cfg(feature = "redis")]
660		{
661			if let Some(persistence) = &self.config.persistence {
662				let handles: Vec<Arc<ProviderHandle>> = {
663					let state = self.inner.read().await;
664
665					state.providers.values().cloned().collect()
666				};
667				let mut snapshots = Vec::new();
668
669				for handle in handles {
670					if let Some(snapshot) = handle.manager.persistent_snapshot().await? {
671						snapshots.push(snapshot);
672					}
673				}
674
675				persistence.persist(&snapshots).await?;
676			}
677		}
678
679		Ok(())
680	}
681
682	/// Restore cached entries from persistence for all active registrations.
683	pub async fn restore_from_persistence(&self) -> Result<()> {
684		#[cfg(feature = "redis")]
685		{
686			if let Some(persistence) = &self.config.persistence {
687				let handles: Vec<Arc<ProviderHandle>> = {
688					let state = self.inner.read().await;
689
690					state.providers.values().cloned().collect()
691				};
692
693				for handle in handles {
694					if let Some(snapshot) = persistence
695						.load(&handle.registration.tenant_id, &handle.registration.provider_id)
696						.await?
697					{
698						handle.manager.restore_snapshot(snapshot).await?;
699					}
700				}
701			}
702		}
703
704		Ok(())
705	}
706}
707impl Default for Registry {
708	fn default() -> Self {
709		Self::new()
710	}
711}
712
713/// Status projection for a provider, aligned with the OpenAPI contract.
714#[derive(Clone, Debug, Serialize, Deserialize)]
715pub struct ProviderStatus {
716	/// Tenant identifier that owns the provider.
717	pub tenant_id: String,
718	/// Provider identifier unique within the tenant.
719	pub provider_id: String,
720	/// Lifecycle state currently reported for the provider.
721	pub state: ProviderState,
722	/// Timestamp of the most recent successful refresh.
723	pub last_refresh: Option<DateTime<Utc>>,
724	/// Scheduled timestamp for the next refresh attempt.
725	pub next_refresh: Option<DateTime<Utc>>,
726	/// Expiration timestamp for the active payload, if available.
727	pub expires_at: Option<DateTime<Utc>>,
728	/// Consecutive error count observed during refresh attempts.
729	pub error_count: u32,
730	/// Ratio of cache hits to total requests.
731	pub hit_rate: f64,
732	/// Ratio of served responses that were stale.
733	pub stale_serve_ratio: f64,
734	/// Metrics emitted to describe provider performance.
735	pub metrics: Vec<StatusMetric>,
736}
737impl ProviderStatus {
738	fn from_components(
739		registration: &IdentityProviderRegistration,
740		snapshot: CacheSnapshot,
741		metrics: ProviderMetricsSnapshot,
742	) -> Self {
743		let mut last_refresh = None;
744		let mut next_refresh = None;
745		let mut expires_at = None;
746		let mut error_count = 0;
747		let state = match &snapshot.state {
748			CacheState::Empty => ProviderState::Empty,
749			CacheState::Loading => ProviderState::Loading,
750			CacheState::Ready(payload) => {
751				last_refresh = Some(payload.last_refresh_at);
752				next_refresh = snapshot.to_datetime(payload.next_refresh_at);
753				expires_at = snapshot.to_datetime(payload.expires_at);
754				error_count = payload.error_count;
755				ProviderState::Ready
756			},
757			CacheState::Refreshing(payload) => {
758				last_refresh = Some(payload.last_refresh_at);
759				next_refresh = snapshot.to_datetime(payload.next_refresh_at);
760				expires_at = snapshot.to_datetime(payload.expires_at);
761				error_count = payload.error_count;
762				ProviderState::Refreshing
763			},
764		};
765		let tenant = &registration.tenant_id;
766		let provider = &registration.provider_id;
767		let mut status_metrics = vec![
768			StatusMetric::new(
769				"jwks_cache_requests_total",
770				metrics.total_requests as f64,
771				tenant,
772				provider,
773			),
774			StatusMetric::new("jwks_cache_hits_total", metrics.cache_hits as f64, tenant, provider),
775			StatusMetric::new(
776				"jwks_cache_stale_total",
777				metrics.stale_serves as f64,
778				tenant,
779				provider,
780			),
781			StatusMetric::new(
782				"jwks_cache_refresh_errors_total",
783				metrics.refresh_errors as f64,
784				tenant,
785				provider,
786			),
787		];
788
789		if let Some(last_micros) = metrics.last_refresh_micros {
790			status_metrics.push(StatusMetric::new(
791				"jwks_cache_last_refresh_micros",
792				last_micros as f64,
793				tenant,
794				provider,
795			));
796		}
797
798		Self {
799			tenant_id: tenant.clone(),
800			provider_id: provider.clone(),
801			state,
802			last_refresh,
803			next_refresh,
804			expires_at,
805			error_count,
806			hit_rate: metrics.hit_rate(),
807			stale_serve_ratio: metrics.stale_ratio(),
808			metrics: status_metrics,
809		}
810	}
811}
812
813/// Metric sample used in provider status responses.
814#[derive(Clone, Debug, Serialize, Deserialize)]
815pub struct StatusMetric {
816	/// Metric name following the monitoring schema.
817	pub name: String,
818	/// Numeric value captured for the metric.
819	pub value: f64,
820	/// Additional labels enriching the metric sample.
821	#[serde(default)]
822	pub labels: HashMap<String, String>,
823}
824impl StatusMetric {
825	fn new(name: impl Into<String>, value: f64, tenant: &str, provider: &str) -> Self {
826		let mut labels = HashMap::with_capacity(2);
827
828		labels.insert("tenant".into(), tenant.into());
829		labels.insert("provider".into(), provider.into());
830
831		Self { name: name.into(), value, labels }
832	}
833}
834
835#[derive(Debug)]
836struct RegistryConfig {
837	require_https: bool,
838	default_refresh_early: Duration,
839	default_stale_while_error: Duration,
840	allowed_domains: Vec<String>,
841	#[cfg(feature = "redis")]
842	persistence: Option<RedisPersistence>,
843}
844impl Default for RegistryConfig {
845	fn default() -> Self {
846		Self {
847			require_https: true,
848			default_refresh_early: DEFAULT_REFRESH_EARLY,
849			default_stale_while_error: DEFAULT_STALE_WHILE_ERROR,
850			allowed_domains: Vec::new(),
851			#[cfg(feature = "redis")]
852			persistence: None,
853		}
854	}
855}
856
857#[derive(Debug)]
858struct ProviderHandle {
859	registration: Arc<IdentityProviderRegistration>,
860	manager: CacheManager,
861	metrics: Arc<ProviderMetrics>,
862}
863impl ProviderHandle {
864	async fn status(&self) -> ProviderStatus {
865		let snapshot = self.manager.snapshot().await;
866		let metrics = self.metrics.snapshot();
867
868		ProviderStatus::from_components(&self.registration, snapshot, metrics)
869	}
870}
871
872#[derive(Debug)]
873struct RegistryState {
874	// TODO: Consider replacing the RwLock<HashMap> with DashMap if contention becomes measurable.
875	providers: HashMap<TenantProviderKey, Arc<ProviderHandle>>,
876}
877
/// Redis-backed store for [`PersistentSnapshot`]s, one JSON string per provider.
#[cfg(feature = "redis")]
#[derive(Debug, Clone)]
struct RedisPersistence {
	client: redis::Client,
	/// Key prefix; keys have the shape `{namespace}:{tenant}:{provider}`.
	namespace: Arc<str>,
}
#[cfg(feature = "redis")]
impl RedisPersistence {
	/// Create a persistence layer using the default `jwks-cache` namespace.
	fn new(client: redis::Client) -> Self {
		let namespace: Arc<str> = Arc::from("jwks-cache");

		Self { client, namespace }
	}

	/// Store each snapshot as JSON with a TTL equal to its remaining lifetime.
	///
	/// The remaining lifetime is truncated to whole seconds with a floor of 1 s (already
	/// expired snapshots are stored for 1 s). No connection is opened for an empty slice.
	async fn persist(&self, snapshots: &[PersistentSnapshot]) -> Result<()> {
		if snapshots.is_empty() {
			return Ok(());
		}

		let mut conn = self.client.get_multiplexed_async_connection().await?;

		for snapshot in snapshots {
			let key = self.key(&snapshot.tenant_id, &snapshot.provider_id);
			let payload = serde_json::to_string(snapshot)?;
			let remaining = snapshot.expires_at - Utc::now();
			let ttl_secs = remaining.to_std().map_or(1, |ttl| ttl.as_secs()).max(1);

			conn.set_ex::<_, _, ()>(key, payload, ttl_secs).await?;
		}

		Ok(())
	}

	/// Fetch and deserialize the snapshot stored for `tenant`/`provider`, if any.
	async fn load(&self, tenant: &str, provider: &str) -> Result<Option<PersistentSnapshot>> {
		let mut conn = self.client.get_multiplexed_async_connection().await?;
		let key = self.key(tenant, provider);
		let stored: Option<String> = conn.get(key).await?;

		let Some(json) = stored else {
			return Ok(None);
		};

		Ok(Some(serde_json::from_str::<PersistentSnapshot>(&json)?))
	}

	/// Build the Redis key for a tenant/provider pair.
	fn key(&self, tenant: &str, provider: &str) -> String {
		format!("{namespace}:{tenant}:{provider}", namespace = self.namespace)
	}
}
929
930fn random_within(min: Duration, max: Duration) -> Duration {
931	if max <= min {
932		return max;
933	}
934	SMALL_RNG.with(|cell| {
935		let mut rng = cell.borrow_mut();
936		let nanos = max.as_nanos() - min.as_nanos();
937		let jitter = rng.random_range(0..=nanos.min(u64::MAX as u128));
938
939		min + Duration::from_nanos(jitter as u64)
940	})
941}
942
943fn default_true() -> bool {
944	true
945}
946
947fn default_refresh_early() -> Duration {
948	DEFAULT_REFRESH_EARLY
949}
950
951fn default_stale_while_error() -> Duration {
952	DEFAULT_STALE_WHILE_ERROR
953}
954
955fn default_min_ttl() -> Duration {
956	MIN_TTL_FLOOR
957}
958
959fn default_max_ttl() -> Duration {
960	DEFAULT_MAX_TTL
961}
962
963fn default_max_response_bytes() -> u64 {
964	DEFAULT_MAX_RESPONSE_BYTES
965}
966
967fn default_max_redirects() -> u8 {
968	3
969}
970
971fn default_prefetch_jitter() -> Duration {
972	DEFAULT_PREFETCH_JITTER
973}
974
975fn validate_tenant_id(value: &str) -> Result<()> {
976	if value.is_empty() {
977		return Err(Error::Validation { field: "tenant_id", reason: "Must not be empty.".into() });
978	}
979	if value.len() > 64 {
980		return Err(Error::Validation {
981			field: "tenant_id",
982			reason: "Must be 64 characters or fewer.".into(),
983		});
984	}
985	if !value.as_bytes().iter().all(|b| b.is_ascii_alphanumeric() || *b == b'-') {
986		return Err(Error::Validation {
987			field: "tenant_id",
988			reason: "May only contain ASCII letters, numbers, and '-'.".into(),
989		});
990	}
991
992	Ok(())
993}
994
995fn validate_provider_id(value: &str) -> Result<()> {
996	if value.is_empty() {
997		return Err(Error::Validation {
998			field: "provider_id",
999			reason: "Must not be empty.".into(),
1000		});
1001	}
1002	if value.len() > 64 {
1003		return Err(Error::Validation {
1004			field: "provider_id",
1005			reason: "Must be 64 characters or fewer.".into(),
1006		});
1007	}
1008	if !value.as_bytes().iter().all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_')) {
1009		return Err(Error::Validation {
1010			field: "provider_id",
1011			reason: "May only contain ASCII letters, numbers, '-', or '_'.".into(),
1012		});
1013	}
1014
1015	Ok(())
1016}