1use std::{cell::RefCell, collections::HashMap, mem};
7use jsonwebtoken::jwk::JwkSet;
9use rand::{Rng, SeedableRng, rngs::SmallRng};
10#[cfg(feature = "redis")] use redis::AsyncCommands;
11use serde::{Deserialize, Serialize};
12use tokio::sync::RwLock;
13use url::Url;
14use crate::{
16 _prelude::*,
17 cache::{
18 manager::{CacheManager, CacheSnapshot},
19 state::CacheState,
20 },
21 metrics::{ProviderMetrics, ProviderMetricsSnapshot},
22 security::{self, SpkiFingerprint},
23};
24
// Per-thread `SmallRng`, seeded from `rand::rng()` the first time a thread touches it.
// The `RefCell` lets `random_within` borrow the generator mutably when drawing jitter.
thread_local! {
    static SMALL_RNG: RefCell<SmallRng> = RefCell::new(SmallRng::from_rng(&mut rand::rng()));
}
28
/// Default for [`IdentityProviderRegistration::refresh_early`] (30 seconds).
pub const DEFAULT_REFRESH_EARLY: Duration = Duration::from_secs(30);
/// Default for [`IdentityProviderRegistration::stale_while_error`] (60 seconds).
pub const DEFAULT_STALE_WHILE_ERROR: Duration = Duration::from_secs(60);
/// Smallest `min_ttl` accepted by [`IdentityProviderRegistration::validate`] (30 seconds);
/// also the default `min_ttl`.
pub const MIN_TTL_FLOOR: Duration = Duration::from_secs(30);
/// Default for [`IdentityProviderRegistration::max_ttl`] (24 hours).
pub const DEFAULT_MAX_TTL: Duration = Duration::from_secs(60 * 60 * 24);
/// Default for [`IdentityProviderRegistration::max_response_bytes`] (1 MiB).
pub const DEFAULT_MAX_RESPONSE_BYTES: u64 = 1_048_576;
/// Default for [`IdentityProviderRegistration::prefetch_jitter`] (5 seconds).
pub const DEFAULT_PREFETCH_JITTER: Duration = Duration::from_secs(5);
/// Largest `max_redirects` accepted by [`IdentityProviderRegistration::validate`].
pub const MAX_REDIRECTS: u8 = 10;
43
/// Randomisation that [`RetryPolicy`] applies to a bounded backoff.
///
/// Serialized in `snake_case`; [`JitterStrategy::Full`] is the default.
#[derive(Clone, Debug, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum JitterStrategy {
    /// Use the bounded backoff unchanged.
    None,
    /// Pick a random delay between 80% of the bounded backoff (never below
    /// `initial_backoff`) and the bounded backoff itself.
    #[default]
    Full,
    /// Pick a random delay between `initial_backoff` and three times the bounded backoff
    /// (three times `initial_backoff` on attempt 0), capped at `max_backoff`.
    Decorrelated,
}
56
/// Lifecycle state reported in [`ProviderStatus`].
///
/// Each variant is mapped one-to-one from the `CacheState` variant of the same name.
/// Serialized in `PascalCase`.
#[derive(Clone, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub enum ProviderState {
    /// The cache is in `CacheState::Empty`.
    Empty,
    /// The cache is in `CacheState::Loading`.
    Loading,
    /// The cache is in `CacheState::Ready` and carries a payload.
    Ready,
    /// The cache is in `CacheState::Refreshing` and carries a payload.
    Refreshing,
}
70
/// Retry and backoff settings attached to a provider registration.
///
/// Constraints between the fields are checked by [`RetryPolicy::validate`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RetryPolicy {
    /// Retry count limit.
    /// NOTE(review): not read anywhere in this file; presumably consumed by the fetch
    /// loop. Confirm against the caller.
    pub max_retries: u32,
    /// Timeout for a single attempt (presumably); must be at least 100 ms.
    pub attempt_timeout: Duration,
    /// Backoff used for attempt 0 and the lower bound of every computed backoff.
    /// Must be greater than zero.
    pub initial_backoff: Duration,
    /// Upper bound of every computed backoff; must be at least `initial_backoff`.
    pub max_backoff: Duration,
    /// Overall time limit (presumably for the whole retry sequence); must be at least
    /// `attempt_timeout`.
    pub deadline: Duration,
    /// Jitter applied to the bounded backoff. Falls back to the `JitterStrategy`
    /// default when missing from serialized input.
    #[serde(default)]
    pub jitter: JitterStrategy,
}
88impl RetryPolicy {
89 pub fn validate(&self) -> Result<()> {
91 if self.attempt_timeout < Duration::from_millis(100) {
92 return Err(Error::Validation {
93 field: "retry_policy.attempt_timeout",
94 reason: "Must be at least 100 ms.".into(),
95 });
96 }
97 if self.initial_backoff.is_zero() {
98 return Err(Error::Validation {
99 field: "retry_policy.initial_backoff",
100 reason: "Must be greater than zero.".into(),
101 });
102 }
103 if self.max_backoff < self.initial_backoff {
104 return Err(Error::Validation {
105 field: "retry_policy.max_backoff",
106 reason: "Must be greater than or equal to initial_backoff.".into(),
107 });
108 }
109 if self.deadline < self.attempt_timeout {
110 return Err(Error::Validation {
111 field: "retry_policy.deadline",
112 reason: "Must be greater than or equal to attempt_timeout.".into(),
113 });
114 }
115 Ok(())
116 }
117
118 pub fn compute_backoff(&self, attempt: u32) -> Duration {
120 self.default_backoff(attempt)
121 }
122
123 pub fn default_backoff(&self, attempt: u32) -> Duration {
125 let exponent = attempt.min(32);
126 let base = self.initial_backoff.mul_f64(2f64.powi(exponent as i32));
127 let bounded = base.min(self.max_backoff).max(self.initial_backoff);
128
129 self.apply_jitter(bounded, attempt)
130 }
131
132 fn apply_jitter(&self, bounded: Duration, attempt: u32) -> Duration {
133 match self.jitter {
134 JitterStrategy::None => bounded,
135 JitterStrategy::Full => {
136 let lower = bounded.mul_f64(0.8).max(self.initial_backoff);
137 let upper = bounded.min(self.max_backoff);
138
139 random_within(lower, upper)
140 },
141 JitterStrategy::Decorrelated => {
142 let prev = if attempt == 0 { self.initial_backoff } else { bounded };
143 let ceiling = self.max_backoff.min(prev.mul_f64(3.0));
144
145 random_within(self.initial_backoff, ceiling.max(self.initial_backoff))
146 },
147 }
148 }
149}
150impl Default for RetryPolicy {
151 fn default() -> Self {
152 Self {
153 max_retries: 2,
154 attempt_timeout: Duration::from_secs(3),
155 initial_backoff: Duration::from_millis(250),
156 max_backoff: Duration::from_secs(2),
157 deadline: Duration::from_secs(8),
158 jitter: JitterStrategy::Full,
159 }
160 }
161}
162
/// Configuration for one tenant/provider pair and its JWKS endpoint.
///
/// Every field after `jwks_url` has a serde default and may be omitted from serialized
/// input. Constraints are checked by [`IdentityProviderRegistration::validate`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct IdentityProviderRegistration {
    /// Tenant identifier: 1 to 64 ASCII letters, digits or `-`.
    pub tenant_id: String,
    /// Provider identifier: 1 to 64 ASCII letters, digits, `-` or `_`.
    pub provider_id: String,
    /// JWKS endpoint; must contain a host component.
    pub jwks_url: Url,
    /// When set, `validate` runs `security::enforce_https` on `jwks_url`. Defaults to `true`.
    #[serde(default = "default_true")]
    pub require_https: bool,
    /// Host allowlist checked against the `jwks_url` host. Entries must already be
    /// canonical hostnames (lowercase, no trailing dot).
    #[serde(default, deserialize_with = "crate::security::deserialize_allowed_domains")]
    pub allowed_domains: Vec<String>,
    /// Must be at least 1 second and less than `max_ttl`. Defaults to
    /// [`DEFAULT_REFRESH_EARLY`].
    #[serde(default = "default_refresh_early")]
    pub refresh_early: Duration,
    /// Defaults to [`DEFAULT_STALE_WHILE_ERROR`].
    /// NOTE(review): presumably how long stale keys may be served after refresh
    /// failures. Not validated or read in this file; confirm.
    #[serde(default = "default_stale_while_error")]
    pub stale_while_error: Duration,
    /// Must be at least [`MIN_TTL_FLOOR`], which is also the default.
    #[serde(default = "default_min_ttl")]
    pub min_ttl: Duration,
    /// Must be at least `min_ttl`. Defaults to [`DEFAULT_MAX_TTL`].
    #[serde(default = "default_max_ttl")]
    pub max_ttl: Duration,
    /// Must be greater than zero; also caps the size of a `PersistentSnapshot`'s
    /// `jwks_json`. Defaults to [`DEFAULT_MAX_RESPONSE_BYTES`].
    #[serde(default = "default_max_response_bytes")]
    pub max_response_bytes: u64,
    /// Either zero or at least one second. Defaults to zero.
    #[serde(default)]
    pub negative_cache_ttl: Duration,
    /// Must not exceed [`MAX_REDIRECTS`]. Defaults to 3.
    #[serde(default = "default_max_redirects")]
    pub max_redirects: u8,
    /// SPKI fingerprints; empty by default.
    /// NOTE(review): presumably used for certificate pinning; not read in this file.
    #[serde(default)]
    pub pinned_spki: Vec<SpkiFingerprint>,
    /// Defaults to [`DEFAULT_PREFETCH_JITTER`]. Not validated or read in this file.
    #[serde(default = "default_prefetch_jitter")]
    pub prefetch_jitter: Duration,
    /// Retry settings, validated through [`RetryPolicy::validate`].
    #[serde(default)]
    pub retry_policy: RetryPolicy,
}
209impl IdentityProviderRegistration {
210 pub fn new(
212 tenant_id: impl Into<String>,
213 provider_id: impl Into<String>,
214 jwks_url: impl AsRef<str>,
215 ) -> Result<Self> {
216 let jwks_url = Url::parse(jwks_url.as_ref())?;
217
218 Ok(Self {
219 tenant_id: tenant_id.into(),
220 provider_id: provider_id.into(),
221 jwks_url,
222 require_https: true,
223 allowed_domains: Vec::new(),
224 refresh_early: DEFAULT_REFRESH_EARLY,
225 stale_while_error: DEFAULT_STALE_WHILE_ERROR,
226 min_ttl: MIN_TTL_FLOOR,
227 max_ttl: DEFAULT_MAX_TTL,
228 max_response_bytes: DEFAULT_MAX_RESPONSE_BYTES,
229 negative_cache_ttl: Duration::ZERO,
230 max_redirects: 3,
231 pinned_spki: Vec::new(),
232 prefetch_jitter: DEFAULT_PREFETCH_JITTER,
233 retry_policy: RetryPolicy::default(),
234 })
235 }
236
237 pub fn normalize_allowed_domains(&mut self) {
239 let domains = mem::take(&mut self.allowed_domains);
240
241 self.allowed_domains = security::normalize_allowlist(domains);
242 }
243
244 pub fn with_require_https(mut self, require_https: bool) -> Self {
246 self.require_https = require_https;
247
248 self
249 }
250
251 pub fn validate(&self) -> Result<()> {
253 validate_tenant_id(&self.tenant_id)?;
254 validate_provider_id(&self.provider_id)?;
255
256 if self.require_https {
257 security::enforce_https(&self.jwks_url)?;
258 }
259
260 if let Some(host) = self.jwks_url.host_str() {
261 if !security::host_is_allowed(host, &self.allowed_domains) {
262 return Err(Error::Validation {
263 field: "jwks_url",
264 reason: "Host is not within the allowed_domains allowlist.".into(),
265 });
266 }
267 } else {
268 return Err(Error::Validation {
269 field: "jwks_url",
270 reason: "Must include a host component.".into(),
271 });
272 }
273
274 if self.refresh_early < Duration::from_secs(1) {
275 return Err(Error::Validation {
276 field: "refresh_early",
277 reason: "Must be at least 1 second.".into(),
278 });
279 }
280 if self.min_ttl < MIN_TTL_FLOOR {
281 return Err(Error::Validation {
282 field: "min_ttl",
283 reason: format!("Must be at least {:?}.", MIN_TTL_FLOOR),
284 });
285 }
286 if self.max_ttl < self.min_ttl {
287 return Err(Error::Validation {
288 field: "max_ttl",
289 reason: "Must be greater than or equal to min_ttl.".into(),
290 });
291 }
292 if self.refresh_early >= self.max_ttl {
293 return Err(Error::Validation {
294 field: "refresh_early",
295 reason: "Must be less than max_ttl.".into(),
296 });
297 }
298 if self.max_response_bytes == 0 {
299 return Err(Error::Validation {
300 field: "max_response_bytes",
301 reason: "Must be greater than zero.".into(),
302 });
303 }
304 if self.max_redirects > MAX_REDIRECTS {
305 return Err(Error::Validation {
306 field: "max_redirects",
307 reason: format!("Must be less than or equal to {}.", MAX_REDIRECTS),
308 });
309 }
310 if !self.negative_cache_ttl.is_zero() && self.negative_cache_ttl < Duration::from_secs(1) {
311 return Err(Error::Validation {
312 field: "negative_cache_ttl",
313 reason: "Must be zero or at least one second.".into(),
314 });
315 }
316
317 self.retry_policy.validate()?;
318
319 for domain in &self.allowed_domains {
320 if let Some(canonical) = security::canonicalize_dns_name(domain) {
321 if canonical != *domain {
322 return Err(Error::Validation {
323 field: "allowed_domains",
324 reason: "Entries must be canonical hostnames (lowercase, no trailing dot)."
325 .into(),
326 });
327 }
328 } else {
329 return Err(Error::Validation {
330 field: "allowed_domains",
331 reason: "Entries must be non-empty hostnames.".into(),
332 });
333 }
334 }
335
336 Ok(())
337 }
338}
339
/// Serializable copy of one provider's cached JWKS document and its metadata.
///
/// `RedisPersistence` stores and loads this type as JSON.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PersistentSnapshot {
    /// Tenant the snapshot belongs to; must match the registration on `validate`.
    pub tenant_id: String,
    /// Provider the snapshot belongs to; must match the registration on `validate`.
    pub provider_id: String,
    /// JWKS document as a JSON string; its byte length is capped by the registration's
    /// `max_response_bytes`.
    pub jwks_json: String,
    /// Optional ETag; must be ASCII when present.
    pub etag: Option<String>,
    /// Optional last-modified timestamp; `None` when missing from serialized input.
    #[serde(default)]
    pub last_modified: Option<DateTime<Utc>>,
    /// Expiry timestamp; must not be earlier than `persisted_at`. `RedisPersistence`
    /// derives the Redis TTL from it.
    pub expires_at: DateTime<Utc>,
    /// Timestamp recorded for when the snapshot was persisted.
    pub persisted_at: DateTime<Utc>,
}
359impl PersistentSnapshot {
360 pub fn validate(&self, registration: &IdentityProviderRegistration) -> Result<()> {
362 if self.jwks_json.len() as u64 > registration.max_response_bytes {
363 return Err(Error::Validation {
364 field: "jwks_json",
365 reason: format!(
366 "Snapshot exceeds max_response_bytes ({} bytes).",
367 registration.max_response_bytes
368 ),
369 });
370 }
371
372 if self.tenant_id != registration.tenant_id {
373 return Err(Error::Validation {
374 field: "tenant_id",
375 reason: "Snapshot tenant does not match registration.".into(),
376 });
377 }
378 if self.provider_id != registration.provider_id {
379 return Err(Error::Validation {
380 field: "provider_id",
381 reason: "Snapshot provider does not match registration.".into(),
382 });
383 }
384
385 if let Some(etag) = &self.etag
386 && !etag.is_ascii()
387 {
388 return Err(Error::Validation { field: "etag", reason: "ETag must be ASCII.".into() });
389 }
390
391 if self.expires_at < self.persisted_at {
392 return Err(Error::Validation {
393 field: "expires_at",
394 reason: "Cannot be earlier than persisted_at.".into(),
395 });
396 }
397
398 Ok(())
399 }
400}
401
/// Composite key identifying a provider within a tenant; used as the registry map key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct TenantProviderKey {
    /// Tenant part of the key.
    pub tenant_id: String,
    /// Provider part of the key.
    pub provider_id: String,
}
impl TenantProviderKey {
    /// Build a key from anything convertible into owned strings.
    pub fn new(tenant_id: impl Into<String>, provider_id: impl Into<String>) -> Self {
        let tenant_id: String = tenant_id.into();
        let provider_id: String = provider_id.into();

        Self { tenant_id, provider_id }
    }
}
413
/// Builder for [`Registry`]; starts from `RegistryConfig::default()`.
#[derive(Debug, Default)]
pub struct RegistryBuilder {
    // Configuration accumulated by the builder methods and handed to the registry by `build`.
    config: RegistryConfig,
}
419impl RegistryBuilder {
420 pub fn new() -> Self {
422 Self::default()
423 }
424
425 pub fn require_https(mut self, require_https: bool) -> Self {
427 self.config.require_https = require_https;
428
429 self
430 }
431
432 pub fn default_refresh_early(mut self, value: Duration) -> Self {
434 self.config.default_refresh_early = value;
435
436 self
437 }
438
439 pub fn default_stale_while_error(mut self, value: Duration) -> Self {
441 self.config.default_stale_while_error = value;
442
443 self
444 }
445
446 pub fn add_allowed_domain(mut self, domain: impl Into<String>) -> Self {
448 let raw = domain.into();
449
450 if let Some(domain) = security::canonicalize_dns_name(&raw)
451 && !self.config.allowed_domains.contains(&domain)
452 {
453 self.config.allowed_domains.push(domain);
454 }
455
456 self
457 }
458
459 pub fn allowed_domains<I, S>(mut self, domains: I) -> Self
461 where
462 I: IntoIterator<Item = S>,
463 S: Into<String>,
464 {
465 self.config.allowed_domains.clear();
466
467 for domain in domains {
468 self = self.add_allowed_domain(domain);
469 }
470
471 self
472 }
473
474 #[cfg(feature = "redis")]
475 pub fn with_redis_client(mut self, client: redis::Client) -> Self {
477 self.config.persistence = Some(RedisPersistence::new(client));
478
479 self
480 }
481
482 #[cfg(feature = "redis")]
483 pub fn redis_namespace(mut self, namespace: impl Into<String>) -> Self {
485 if let Some(persistence) = self.config.persistence.as_mut() {
486 persistence.namespace = Arc::from(namespace.into());
487 } else {
488 panic!("Redis client must be configured before setting namespace.");
489 }
490
491 self
492 }
493
494 pub fn build(self) -> Registry {
496 let mut config = self.config;
497
498 config.allowed_domains = security::normalize_allowlist(config.allowed_domains);
499
500 Registry {
501 inner: Arc::new(RwLock::new(RegistryState { providers: HashMap::new() })),
502 config: Arc::new(config),
503 }
504 }
505}
506
/// Registry of identity providers keyed by tenant and provider id.
///
/// Both fields are `Arc`s, so cloning a `Registry` yields a handle onto the same
/// provider map and configuration.
#[derive(Clone, Debug)]
pub struct Registry {
    // Provider map behind an async read/write lock.
    inner: Arc<RwLock<RegistryState>>,
    // Immutable configuration produced by `RegistryBuilder::build`.
    config: Arc<RegistryConfig>,
}
513impl Registry {
514 pub fn new() -> Self {
516 Self::builder().build()
517 }
518
519 pub fn builder() -> RegistryBuilder {
521 RegistryBuilder::new()
522 }
523
524 pub async fn register(&self, mut registration: IdentityProviderRegistration) -> Result<()> {
526 if self.config.require_https {
527 if !registration.require_https {
528 return Err(Error::Security(
529 "Registry requires HTTPS for all provider registrations.".into(),
530 ));
531 }
532 } else {
533 registration.require_https = false;
534 }
535
536 registration.normalize_allowed_domains();
537
538 if registration.refresh_early == DEFAULT_REFRESH_EARLY {
539 registration.refresh_early = self.config.default_refresh_early;
540 }
541 if registration.stale_while_error == DEFAULT_STALE_WHILE_ERROR {
542 registration.stale_while_error = self.config.default_stale_while_error;
543 }
544 if registration.allowed_domains.is_empty() && !self.config.allowed_domains.is_empty() {
545 registration.allowed_domains = self.config.allowed_domains.clone();
546 }
547
548 if let Some(host) = registration.jwks_url.host_str()
549 && !security::host_is_allowed(host, &self.config.allowed_domains)
550 {
551 return Err(Error::Security(format!(
552 "Host '{host}' is not in the registry allowlist."
553 )));
554 }
555
556 let key = TenantProviderKey::new(®istration.tenant_id, ®istration.provider_id);
557 let manager = CacheManager::new(registration.clone())?;
558 let metrics = manager.metrics();
559 let handle =
560 Arc::new(ProviderHandle { registration: Arc::new(registration), manager, metrics });
561
562 {
563 let mut state = self.inner.write().await;
564
565 state.providers.insert(key.clone(), handle.clone());
566 }
567
568 #[cfg(feature = "redis")]
569 if let Some(persistence) = &self.config.persistence {
570 if let Some(snapshot) = persistence.load(&key.tenant_id, &key.provider_id).await? {
571 handle.manager.restore_snapshot(snapshot).await?;
572 }
573 }
574
575 Ok(())
576 }
577
578 pub async fn resolve(
580 &self,
581 tenant_id: &str,
582 provider_id: &str,
583 kid: Option<&str>,
584 ) -> Result<Arc<JwkSet>> {
585 let key = TenantProviderKey::new(tenant_id, provider_id);
586 let handle = {
587 let state = self.inner.read().await;
588
589 state.providers.get(&key).cloned()
590 };
591 let handle = handle.ok_or_else(|| Error::NotRegistered {
592 tenant: tenant_id.to_string(),
593 provider: provider_id.to_string(),
594 })?;
595
596 handle.manager.resolve(kid).await
597 }
598
599 pub async fn refresh(&self, tenant_id: &str, provider_id: &str) -> Result<()> {
601 let key = TenantProviderKey::new(tenant_id, provider_id);
602 let handle = {
603 let state = self.inner.read().await;
604 state.providers.get(&key).cloned()
605 };
606 let handle = handle.ok_or_else(|| Error::NotRegistered {
607 tenant: tenant_id.to_string(),
608 provider: provider_id.to_string(),
609 })?;
610
611 handle.manager.trigger_refresh().await
612 }
613
614 pub async fn unregister(&self, tenant_id: &str, provider_id: &str) -> Result<bool> {
616 let key = TenantProviderKey::new(tenant_id, provider_id);
617 let mut state = self.inner.write().await;
618
619 Ok(state.providers.remove(&key).is_some())
620 }
621
622 pub async fn provider_status(
624 &self,
625 tenant_id: &str,
626 provider_id: &str,
627 ) -> Result<ProviderStatus> {
628 let key = TenantProviderKey::new(tenant_id, provider_id);
629 let handle = {
630 let state = self.inner.read().await;
631
632 state.providers.get(&key).cloned()
633 };
634 let handle = handle.ok_or_else(|| Error::NotRegistered {
635 tenant: tenant_id.to_string(),
636 provider: provider_id.to_string(),
637 })?;
638
639 Ok(handle.status().await)
640 }
641
642 pub async fn all_statuses(&self) -> Vec<ProviderStatus> {
644 let handles: Vec<Arc<ProviderHandle>> = {
645 let state = self.inner.read().await;
646 state.providers.values().cloned().collect()
647 };
648 let mut statuses = Vec::with_capacity(handles.len());
649
650 for handle in handles {
651 statuses.push(handle.status().await);
652 }
653
654 statuses
655 }
656
657 pub async fn persist_all(&self) -> Result<()> {
659 #[cfg(feature = "redis")]
660 {
661 if let Some(persistence) = &self.config.persistence {
662 let handles: Vec<Arc<ProviderHandle>> = {
663 let state = self.inner.read().await;
664
665 state.providers.values().cloned().collect()
666 };
667 let mut snapshots = Vec::new();
668
669 for handle in handles {
670 if let Some(snapshot) = handle.manager.persistent_snapshot().await? {
671 snapshots.push(snapshot);
672 }
673 }
674
675 persistence.persist(&snapshots).await?;
676 }
677 }
678
679 Ok(())
680 }
681
682 pub async fn restore_from_persistence(&self) -> Result<()> {
684 #[cfg(feature = "redis")]
685 {
686 if let Some(persistence) = &self.config.persistence {
687 let handles: Vec<Arc<ProviderHandle>> = {
688 let state = self.inner.read().await;
689
690 state.providers.values().cloned().collect()
691 };
692
693 for handle in handles {
694 if let Some(snapshot) = persistence
695 .load(&handle.registration.tenant_id, &handle.registration.provider_id)
696 .await?
697 {
698 handle.manager.restore_snapshot(snapshot).await?;
699 }
700 }
701 }
702 }
703
704 Ok(())
705 }
706}
707impl Default for Registry {
708 fn default() -> Self {
709 Self::new()
710 }
711}
712
/// Point-in-time status report for one provider, assembled from its registration, a
/// cache snapshot and a metrics snapshot.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ProviderStatus {
    /// Tenant id copied from the registration.
    pub tenant_id: String,
    /// Provider id copied from the registration.
    pub provider_id: String,
    /// Lifecycle state derived from the cache state.
    pub state: ProviderState,
    /// Payload's `last_refresh_at`; `None` unless the cache is `Ready` or `Refreshing`.
    pub last_refresh: Option<DateTime<Utc>>,
    /// Payload's `next_refresh_at` converted by the snapshot; `None` in payload-less states.
    pub next_refresh: Option<DateTime<Utc>>,
    /// Payload's `expires_at` converted by the snapshot; `None` in payload-less states.
    pub expires_at: Option<DateTime<Utc>>,
    /// Payload's `error_count`; zero in payload-less states.
    pub error_count: u32,
    /// Value of the metrics snapshot's `hit_rate()`.
    pub hit_rate: f64,
    /// Value of the metrics snapshot's `stale_ratio()`.
    pub stale_serve_ratio: f64,
    /// Individual counters labelled with tenant and provider.
    pub metrics: Vec<StatusMetric>,
}
737impl ProviderStatus {
738 fn from_components(
739 registration: &IdentityProviderRegistration,
740 snapshot: CacheSnapshot,
741 metrics: ProviderMetricsSnapshot,
742 ) -> Self {
743 let mut last_refresh = None;
744 let mut next_refresh = None;
745 let mut expires_at = None;
746 let mut error_count = 0;
747 let state = match &snapshot.state {
748 CacheState::Empty => ProviderState::Empty,
749 CacheState::Loading => ProviderState::Loading,
750 CacheState::Ready(payload) => {
751 last_refresh = Some(payload.last_refresh_at);
752 next_refresh = snapshot.to_datetime(payload.next_refresh_at);
753 expires_at = snapshot.to_datetime(payload.expires_at);
754 error_count = payload.error_count;
755 ProviderState::Ready
756 },
757 CacheState::Refreshing(payload) => {
758 last_refresh = Some(payload.last_refresh_at);
759 next_refresh = snapshot.to_datetime(payload.next_refresh_at);
760 expires_at = snapshot.to_datetime(payload.expires_at);
761 error_count = payload.error_count;
762 ProviderState::Refreshing
763 },
764 };
765 let tenant = ®istration.tenant_id;
766 let provider = ®istration.provider_id;
767 let mut status_metrics = vec![
768 StatusMetric::new(
769 "jwks_cache_requests_total",
770 metrics.total_requests as f64,
771 tenant,
772 provider,
773 ),
774 StatusMetric::new("jwks_cache_hits_total", metrics.cache_hits as f64, tenant, provider),
775 StatusMetric::new(
776 "jwks_cache_stale_total",
777 metrics.stale_serves as f64,
778 tenant,
779 provider,
780 ),
781 StatusMetric::new(
782 "jwks_cache_refresh_errors_total",
783 metrics.refresh_errors as f64,
784 tenant,
785 provider,
786 ),
787 ];
788
789 if let Some(last_micros) = metrics.last_refresh_micros {
790 status_metrics.push(StatusMetric::new(
791 "jwks_cache_last_refresh_micros",
792 last_micros as f64,
793 tenant,
794 provider,
795 ));
796 }
797
798 Self {
799 tenant_id: tenant.clone(),
800 provider_id: provider.clone(),
801 state,
802 last_refresh,
803 next_refresh,
804 expires_at,
805 error_count,
806 hit_rate: metrics.hit_rate(),
807 stale_serve_ratio: metrics.stale_ratio(),
808 metrics: status_metrics,
809 }
810 }
811}
812
/// A single named metric value with string labels.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StatusMetric {
    /// Metric name, e.g. `jwks_cache_requests_total`.
    pub name: String,
    /// Metric value as a float.
    pub value: f64,
    /// Labels attached to the metric; empty when missing from serialized input.
    /// `StatusMetric::new` populates `tenant` and `provider`.
    #[serde(default)]
    pub labels: HashMap<String, String>,
}
824impl StatusMetric {
825 fn new(name: impl Into<String>, value: f64, tenant: &str, provider: &str) -> Self {
826 let mut labels = HashMap::with_capacity(2);
827
828 labels.insert("tenant".into(), tenant.into());
829 labels.insert("provider".into(), provider.into());
830
831 Self { name: name.into(), value, labels }
832 }
833}
834
// Registry-wide settings collected by `RegistryBuilder` and shared by all registrations.
#[derive(Debug)]
struct RegistryConfig {
    // When true, `Registry::register` rejects registrations that opt out of HTTPS;
    // when false, it clears the flag on every registration.
    require_https: bool,
    // Substituted for `refresh_early` when a registration still carries `DEFAULT_REFRESH_EARLY`.
    default_refresh_early: Duration,
    // Substituted for `stale_while_error` when a registration still carries
    // `DEFAULT_STALE_WHILE_ERROR`.
    default_stale_while_error: Duration,
    // Registry-wide host allowlist; also inherited by registrations whose own list is empty.
    allowed_domains: Vec<String>,
    // Optional Redis-backed snapshot storage.
    #[cfg(feature = "redis")]
    persistence: Option<RedisPersistence>,
}
844impl Default for RegistryConfig {
845 fn default() -> Self {
846 Self {
847 require_https: true,
848 default_refresh_early: DEFAULT_REFRESH_EARLY,
849 default_stale_while_error: DEFAULT_STALE_WHILE_ERROR,
850 allowed_domains: Vec::new(),
851 #[cfg(feature = "redis")]
852 persistence: None,
853 }
854 }
855}
856
// Everything the registry keeps for one registered provider.
#[derive(Debug)]
struct ProviderHandle {
    // Registration as adjusted by `Registry::register`.
    registration: Arc<IdentityProviderRegistration>,
    // Cache manager created from that registration.
    manager: CacheManager,
    // Metrics handle obtained from `manager.metrics()` at registration time.
    metrics: Arc<ProviderMetrics>,
}
863impl ProviderHandle {
864 async fn status(&self) -> ProviderStatus {
865 let snapshot = self.manager.snapshot().await;
866 let metrics = self.metrics.snapshot();
867
868 ProviderStatus::from_components(&self.registration, snapshot, metrics)
869 }
870}
871
// Mutable registry state guarded by the registry's `RwLock`.
#[derive(Debug)]
struct RegistryState {
    // Registered providers keyed by tenant and provider id.
    providers: HashMap<TenantProviderKey, Arc<ProviderHandle>>,
}
877
// Redis-backed storage for `PersistentSnapshot`s (only with the `redis` feature).
#[cfg(feature = "redis")]
#[derive(Clone, Debug)]
struct RedisPersistence {
    // Client used to open a multiplexed async connection per operation.
    client: redis::Client,
    // Key prefix; keys have the form `<namespace>:<tenant>:<provider>`.
    namespace: Arc<str>,
}
#[cfg(feature = "redis")]
impl RedisPersistence {
    // Wrap `client` using the default `jwks-cache` namespace.
    fn new(client: redis::Client) -> Self {
        let namespace: Arc<str> = Arc::from("jwks-cache");

        Self { client, namespace }
    }

    // Store each snapshot as JSON with an expiry derived from `expires_at`.
    //
    // The TTL is the whole seconds remaining until `expires_at`, never less than one
    // second (also used when the snapshot has already expired). No connection is opened
    // for an empty slice.
    async fn persist(&self, snapshots: &[PersistentSnapshot]) -> Result<()> {
        if snapshots.is_empty() {
            return Ok(());
        }

        let mut conn = self.client.get_multiplexed_async_connection().await?;

        for snapshot in snapshots {
            let redis_key = self.key(&snapshot.tenant_id, &snapshot.provider_id);
            let json = serde_json::to_string(snapshot)?;
            let ttl_secs = (snapshot.expires_at - Utc::now())
                .to_std()
                .map_or(1, |remaining| remaining.as_secs().max(1));

            conn.set_ex::<_, _, ()>(redis_key, json, ttl_secs).await?;
        }

        Ok(())
    }

    // Fetch and deserialize the snapshot stored for `tenant`/`provider`, if any.
    async fn load(&self, tenant: &str, provider: &str) -> Result<Option<PersistentSnapshot>> {
        let mut conn = self.client.get_multiplexed_async_connection().await?;
        let redis_key = self.key(tenant, provider);
        let stored: Option<String> = conn.get(redis_key).await?;

        match stored {
            Some(json) => Ok(Some(serde_json::from_str::<PersistentSnapshot>(&json)?)),
            None => Ok(None),
        }
    }

    // Redis key for a tenant/provider pair: `<namespace>:<tenant>:<provider>`.
    fn key(&self, tenant: &str, provider: &str) -> String {
        let namespace = &self.namespace;

        format!("{namespace}:{tenant}:{provider}")
    }
}
929
930fn random_within(min: Duration, max: Duration) -> Duration {
931 if max <= min {
932 return max;
933 }
934 SMALL_RNG.with(|cell| {
935 let mut rng = cell.borrow_mut();
936 let nanos = max.as_nanos() - min.as_nanos();
937 let jitter = rng.random_range(0..=nanos.min(u64::MAX as u128));
938
939 min + Duration::from_nanos(jitter as u64)
940 })
941}
942
// Serde default for `IdentityProviderRegistration::require_https`.
fn default_true() -> bool {
    true
}

// Serde default for `IdentityProviderRegistration::refresh_early`.
fn default_refresh_early() -> Duration {
    DEFAULT_REFRESH_EARLY
}

// Serde default for `IdentityProviderRegistration::stale_while_error`.
fn default_stale_while_error() -> Duration {
    DEFAULT_STALE_WHILE_ERROR
}

// Serde default for `IdentityProviderRegistration::min_ttl`.
fn default_min_ttl() -> Duration {
    MIN_TTL_FLOOR
}

// Serde default for `IdentityProviderRegistration::max_ttl`.
fn default_max_ttl() -> Duration {
    DEFAULT_MAX_TTL
}

// Serde default for `IdentityProviderRegistration::max_response_bytes`.
fn default_max_response_bytes() -> u64 {
    DEFAULT_MAX_RESPONSE_BYTES
}

// Serde default for `IdentityProviderRegistration::max_redirects`; matches the value
// that `IdentityProviderRegistration::new` uses.
fn default_max_redirects() -> u8 {
    3
}

// Serde default for `IdentityProviderRegistration::prefetch_jitter`.
fn default_prefetch_jitter() -> Duration {
    DEFAULT_PREFETCH_JITTER
}
974
975fn validate_tenant_id(value: &str) -> Result<()> {
976 if value.is_empty() {
977 return Err(Error::Validation { field: "tenant_id", reason: "Must not be empty.".into() });
978 }
979 if value.len() > 64 {
980 return Err(Error::Validation {
981 field: "tenant_id",
982 reason: "Must be 64 characters or fewer.".into(),
983 });
984 }
985 if !value.as_bytes().iter().all(|b| b.is_ascii_alphanumeric() || *b == b'-') {
986 return Err(Error::Validation {
987 field: "tenant_id",
988 reason: "May only contain ASCII letters, numbers, and '-'.".into(),
989 });
990 }
991
992 Ok(())
993}
994
995fn validate_provider_id(value: &str) -> Result<()> {
996 if value.is_empty() {
997 return Err(Error::Validation {
998 field: "provider_id",
999 reason: "Must not be empty.".into(),
1000 });
1001 }
1002 if value.len() > 64 {
1003 return Err(Error::Validation {
1004 field: "provider_id",
1005 reason: "Must be 64 characters or fewer.".into(),
1006 });
1007 }
1008 if !value.as_bytes().iter().all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_')) {
1009 return Err(Error::Validation {
1010 field: "provider_id",
1011 reason: "May only contain ASCII letters, numbers, '-', or '_'.".into(),
1012 });
1013 }
1014
1015 Ok(())
1016}