1use std::{cell::RefCell, collections::HashMap, mem};
7use jsonwebtoken::jwk::JwkSet;
9use rand::{Rng, SeedableRng, rngs::SmallRng};
10#[cfg(feature = "redis")] use redis::AsyncCommands;
11use serde::{Deserialize, Serialize};
12use tokio::sync::RwLock;
13use url::Url;
14#[cfg(feature = "metrics")] use crate::metrics::{ProviderMetrics, ProviderMetricsSnapshot};
16use crate::{
17 _prelude::*,
18 cache::{
19 manager::{CacheManager, CacheSnapshot},
20 state::CacheState,
21 },
22 security::{self, SpkiFingerprint},
23};
24
// Per-thread generator used by `random_within` to jitter retry backoff delays. It is seeded
// once per thread (on first use) from `rand::rng()`. `SmallRng` is a fast, non-cryptographic
// PRNG, which is sufficient here because the values only spread out retry timing.
thread_local! {
    static SMALL_RNG: RefCell<SmallRng> = RefCell::new(SmallRng::from_rng(&mut rand::rng()));
}
28
/// Default `refresh_early` for a registration.
///
/// `Registry::register` also treats a registration still holding exactly this value as
/// "unset" and substitutes the registry-wide default.
pub const DEFAULT_REFRESH_EARLY: Duration = Duration::from_secs(30);
/// Default `stale_while_error` for a registration.
///
/// Like [`DEFAULT_REFRESH_EARLY`], `Registry::register` treats this exact value as "unset".
pub const DEFAULT_STALE_WHILE_ERROR: Duration = Duration::from_secs(60);
/// Lowest accepted `min_ttl` (enforced by `IdentityProviderRegistration::validate`); also the
/// default `min_ttl`.
pub const MIN_TTL_FLOOR: Duration = Duration::from_secs(30);
/// Default `max_ttl`: 24 hours.
pub const DEFAULT_MAX_TTL: Duration = Duration::from_secs(60 * 60 * 24);
/// Default `max_response_bytes`: 1 MiB.
pub const DEFAULT_MAX_RESPONSE_BYTES: u64 = 1_048_576;
/// Default `prefetch_jitter`.
pub const DEFAULT_PREFETCH_JITTER: Duration = Duration::from_secs(5);
/// Upper bound accepted for a registration's `max_redirects`.
pub const MAX_REDIRECTS: u8 = 10;
43
/// Randomization applied to retry backoff delays by [`RetryPolicy::compute_backoff`].
///
/// Serialized in `snake_case` (`"none"`, `"full"`, `"decorrelated"`).
#[derive(Clone, Debug, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum JitterStrategy {
    /// No randomization: the clamped exponential delay is used as-is.
    None,
    /// Default. Picks a delay between 80% and 100% of the clamped exponential delay (never
    /// below `initial_backoff`). Note: this is narrower than textbook "full jitter" (0–100%).
    #[default]
    Full,
    /// Picks a delay between `initial_backoff` and `min(max_backoff, 3 × clamped delay)`.
    Decorrelated,
}
56
/// Public lifecycle state of a provider's cache, derived one-to-one from the internal
/// `CacheState` when building a [`ProviderStatus`].
///
/// Serialized in `PascalCase` (`"Empty"`, `"Loading"`, `"Ready"`, `"Refreshing"`).
#[derive(Clone, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub enum ProviderState {
    /// The cache holds no JWKS.
    Empty,
    /// The cache is in its `Loading` state (presumably the first fetch is in flight).
    Loading,
    /// A JWKS is cached; refresh/expiry timestamps are available.
    Ready,
    /// A JWKS is cached and the cache is in its `Refreshing` state.
    Refreshing,
}
70
/// Retry and backoff settings carried by an [`IdentityProviderRegistration`].
///
/// Consistency rules are enforced by [`RetryPolicy::validate`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RetryPolicy {
    /// Maximum number of retries. Not checked by `validate`; presumably counted after the
    /// initial attempt — TODO confirm against the fetch loop.
    pub max_retries: u32,
    /// Time limit for one attempt (presumably); must be at least 100 ms.
    pub attempt_timeout: Duration,
    /// Delay before the first retry (attempt 0) and the floor for every delay; must be non-zero.
    pub initial_backoff: Duration,
    /// Ceiling for the exponential delay; must be at least `initial_backoff`.
    pub max_backoff: Duration,
    /// Must be at least `attempt_timeout`; presumably the overall budget across attempts —
    /// TODO confirm.
    pub deadline: Duration,
    /// Randomization applied to each delay; defaults to [`JitterStrategy::Full`] when omitted.
    #[serde(default)]
    pub jitter: JitterStrategy,
}
88impl RetryPolicy {
89 pub fn validate(&self) -> Result<()> {
91 if self.attempt_timeout < Duration::from_millis(100) {
92 return Err(Error::Validation {
93 field: "retry_policy.attempt_timeout",
94 reason: "Must be at least 100 ms.".into(),
95 });
96 }
97 if self.initial_backoff.is_zero() {
98 return Err(Error::Validation {
99 field: "retry_policy.initial_backoff",
100 reason: "Must be greater than zero.".into(),
101 });
102 }
103 if self.max_backoff < self.initial_backoff {
104 return Err(Error::Validation {
105 field: "retry_policy.max_backoff",
106 reason: "Must be greater than or equal to initial_backoff.".into(),
107 });
108 }
109 if self.deadline < self.attempt_timeout {
110 return Err(Error::Validation {
111 field: "retry_policy.deadline",
112 reason: "Must be greater than or equal to attempt_timeout.".into(),
113 });
114 }
115 Ok(())
116 }
117
118 pub fn compute_backoff(&self, attempt: u32) -> Duration {
120 self.default_backoff(attempt)
121 }
122
123 pub fn default_backoff(&self, attempt: u32) -> Duration {
125 let exponent = attempt.min(32);
126 let base = self.initial_backoff.mul_f64(2f64.powi(exponent as i32));
127 let bounded = base.min(self.max_backoff).max(self.initial_backoff);
128
129 self.apply_jitter(bounded, attempt)
130 }
131
132 fn apply_jitter(&self, bounded: Duration, attempt: u32) -> Duration {
133 match self.jitter {
134 JitterStrategy::None => bounded,
135 JitterStrategy::Full => {
136 let lower = bounded.mul_f64(0.8).max(self.initial_backoff);
137 let upper = bounded.min(self.max_backoff);
138
139 random_within(lower, upper)
140 },
141 JitterStrategy::Decorrelated => {
142 let prev = if attempt == 0 { self.initial_backoff } else { bounded };
143 let ceiling = self.max_backoff.min(prev.mul_f64(3.0));
144
145 random_within(self.initial_backoff, ceiling.max(self.initial_backoff))
146 },
147 }
148 }
149}
150impl Default for RetryPolicy {
151 fn default() -> Self {
152 Self {
153 max_retries: 2,
154 attempt_timeout: Duration::from_secs(3),
155 initial_backoff: Duration::from_millis(250),
156 max_backoff: Duration::from_secs(2),
157 deadline: Duration::from_secs(8),
158 jitter: JitterStrategy::Full,
159 }
160 }
161}
162
/// Configuration for one tenant's identity-provider JWKS endpoint.
///
/// Fields with `#[serde(default ...)]` fall back, when omitted, to the same values that
/// [`IdentityProviderRegistration::new`] assigns. Use
/// [`IdentityProviderRegistration::validate`] to check consistency.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct IdentityProviderRegistration {
    /// Tenant identifier: 1–64 ASCII letters, digits, or `-`.
    pub tenant_id: String,
    /// Provider identifier within the tenant: 1–64 ASCII letters, digits, `-`, or `_`.
    pub provider_id: String,
    /// Location of the JWKS document. It must have a host, and must be HTTPS when
    /// `require_https` is set.
    pub jwks_url: Url,
    /// When true, `validate` runs `security::enforce_https` on `jwks_url`. Defaults to `true`.
    #[serde(default = "default_true")]
    pub require_https: bool,
    /// Allowlist checked against `jwks_url`'s host via `security::host_is_allowed`.
    /// `validate` requires canonical entries (lowercase, no trailing dot).
    #[serde(default, deserialize_with = "crate::security::deserialize_allowed_domains")]
    pub allowed_domains: Vec<String>,
    /// Must be at least 1 s and below `max_ttl`. Presumably the lead time before expiry at
    /// which a refresh starts — TODO confirm in `CacheManager`.
    #[serde(default = "default_refresh_early")]
    pub refresh_early: Duration,
    /// Not validated here. Presumably how long stale keys may be served after a failed
    /// refresh — TODO confirm in `CacheManager`.
    #[serde(default = "default_stale_while_error")]
    pub stale_while_error: Duration,
    /// Lower TTL bound; must be at least [`MIN_TTL_FLOOR`].
    #[serde(default = "default_min_ttl")]
    pub min_ttl: Duration,
    /// Upper TTL bound; must be at least `min_ttl`.
    #[serde(default = "default_max_ttl")]
    pub max_ttl: Duration,
    /// Size limit in bytes; must be non-zero. Also caps a restored snapshot's `jwks_json`.
    #[serde(default = "default_max_response_bytes")]
    pub max_response_bytes: u64,
    /// Zero (the default) or at least one second.
    #[serde(default)]
    pub negative_cache_ttl: Duration,
    /// At most [`MAX_REDIRECTS`]; defaults to 3.
    #[serde(default = "default_max_redirects")]
    pub max_redirects: u8,
    /// SPKI fingerprints, empty by default. Presumably used for TLS certificate pinning —
    /// TODO confirm in the fetch layer.
    #[serde(default)]
    pub pinned_spki: Vec<SpkiFingerprint>,
    /// Defaults to [`DEFAULT_PREFETCH_JITTER`]; not validated here.
    #[serde(default = "default_prefetch_jitter")]
    pub prefetch_jitter: Duration,
    /// Retry/backoff settings; validated via [`RetryPolicy::validate`].
    #[serde(default)]
    pub retry_policy: RetryPolicy,
}
209impl IdentityProviderRegistration {
210 pub fn new(
212 tenant_id: impl Into<String>,
213 provider_id: impl Into<String>,
214 jwks_url: impl AsRef<str>,
215 ) -> Result<Self> {
216 let jwks_url = Url::parse(jwks_url.as_ref())?;
217
218 Ok(Self {
219 tenant_id: tenant_id.into(),
220 provider_id: provider_id.into(),
221 jwks_url,
222 require_https: true,
223 allowed_domains: Vec::new(),
224 refresh_early: DEFAULT_REFRESH_EARLY,
225 stale_while_error: DEFAULT_STALE_WHILE_ERROR,
226 min_ttl: MIN_TTL_FLOOR,
227 max_ttl: DEFAULT_MAX_TTL,
228 max_response_bytes: DEFAULT_MAX_RESPONSE_BYTES,
229 negative_cache_ttl: Duration::ZERO,
230 max_redirects: 3,
231 pinned_spki: Vec::new(),
232 prefetch_jitter: DEFAULT_PREFETCH_JITTER,
233 retry_policy: RetryPolicy::default(),
234 })
235 }
236
237 pub fn normalize_allowed_domains(&mut self) {
239 let domains = mem::take(&mut self.allowed_domains);
240
241 self.allowed_domains = security::normalize_allowlist(domains);
242 }
243
244 pub fn with_require_https(mut self, require_https: bool) -> Self {
246 self.require_https = require_https;
247
248 self
249 }
250
251 pub fn validate(&self) -> Result<()> {
253 validate_tenant_id(&self.tenant_id)?;
254 validate_provider_id(&self.provider_id)?;
255
256 if self.require_https {
257 security::enforce_https(&self.jwks_url)?;
258 }
259
260 if let Some(host) = self.jwks_url.host_str() {
261 if !security::host_is_allowed(host, &self.allowed_domains) {
262 return Err(Error::Validation {
263 field: "jwks_url",
264 reason: "Host is not within the allowed_domains allowlist.".into(),
265 });
266 }
267 } else {
268 return Err(Error::Validation {
269 field: "jwks_url",
270 reason: "Must include a host component.".into(),
271 });
272 }
273
274 if self.refresh_early < Duration::from_secs(1) {
275 return Err(Error::Validation {
276 field: "refresh_early",
277 reason: "Must be at least 1 second.".into(),
278 });
279 }
280 if self.min_ttl < MIN_TTL_FLOOR {
281 return Err(Error::Validation {
282 field: "min_ttl",
283 reason: format!("Must be at least {:?}.", MIN_TTL_FLOOR),
284 });
285 }
286 if self.max_ttl < self.min_ttl {
287 return Err(Error::Validation {
288 field: "max_ttl",
289 reason: "Must be greater than or equal to min_ttl.".into(),
290 });
291 }
292 if self.refresh_early >= self.max_ttl {
293 return Err(Error::Validation {
294 field: "refresh_early",
295 reason: "Must be less than max_ttl.".into(),
296 });
297 }
298 if self.max_response_bytes == 0 {
299 return Err(Error::Validation {
300 field: "max_response_bytes",
301 reason: "Must be greater than zero.".into(),
302 });
303 }
304 if self.max_redirects > MAX_REDIRECTS {
305 return Err(Error::Validation {
306 field: "max_redirects",
307 reason: format!("Must be less than or equal to {}.", MAX_REDIRECTS),
308 });
309 }
310 if !self.negative_cache_ttl.is_zero() && self.negative_cache_ttl < Duration::from_secs(1) {
311 return Err(Error::Validation {
312 field: "negative_cache_ttl",
313 reason: "Must be zero or at least one second.".into(),
314 });
315 }
316
317 self.retry_policy.validate()?;
318
319 for domain in &self.allowed_domains {
320 if let Some(canonical) = security::canonicalize_dns_name(domain) {
321 if canonical != *domain {
322 return Err(Error::Validation {
323 field: "allowed_domains",
324 reason: "Entries must be canonical hostnames (lowercase, no trailing dot)."
325 .into(),
326 });
327 }
328 } else {
329 return Err(Error::Validation {
330 field: "allowed_domains",
331 reason: "Entries must be non-empty hostnames.".into(),
332 });
333 }
334 }
335
336 Ok(())
337 }
338}
339
/// Serializable copy of a provider's cached JWKS, used by the persistence layer (stored as
/// JSON in Redis when the `redis` feature is enabled).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PersistentSnapshot {
    /// Tenant the snapshot belongs to; must match the registration on restore.
    pub tenant_id: String,
    /// Provider the snapshot belongs to; must match the registration on restore.
    pub provider_id: String,
    /// The JWKS document as JSON text; at most `max_response_bytes` long.
    pub jwks_json: String,
    /// Entity tag, if any; must be ASCII. Presumably the HTTP `ETag` of the fetch — confirm.
    pub etag: Option<String>,
    /// Last-Modified timestamp, if known; `None` when absent from serialized input.
    #[serde(default)]
    pub last_modified: Option<DateTime<Utc>>,
    /// Expiry of the cached keys; must not precede `persisted_at`. Also used as the Redis key
    /// expiry when persisting.
    pub expires_at: DateTime<Utc>,
    /// When the snapshot was written.
    pub persisted_at: DateTime<Utc>,
}
359impl PersistentSnapshot {
360 pub fn validate(&self, registration: &IdentityProviderRegistration) -> Result<()> {
362 if self.jwks_json.len() as u64 > registration.max_response_bytes {
363 return Err(Error::Validation {
364 field: "jwks_json",
365 reason: format!(
366 "Snapshot exceeds max_response_bytes ({} bytes).",
367 registration.max_response_bytes
368 ),
369 });
370 }
371
372 if self.tenant_id != registration.tenant_id {
373 return Err(Error::Validation {
374 field: "tenant_id",
375 reason: "Snapshot tenant does not match registration.".into(),
376 });
377 }
378 if self.provider_id != registration.provider_id {
379 return Err(Error::Validation {
380 field: "provider_id",
381 reason: "Snapshot provider does not match registration.".into(),
382 });
383 }
384
385 if let Some(etag) = &self.etag
386 && !etag.is_ascii()
387 {
388 return Err(Error::Validation { field: "etag", reason: "ETag must be ASCII.".into() });
389 }
390
391 if self.expires_at < self.persisted_at {
392 return Err(Error::Validation {
393 field: "expires_at",
394 reason: "Cannot be earlier than persisted_at.".into(),
395 });
396 }
397
398 Ok(())
399 }
400}
401
/// Map key identifying one provider registration: the `(tenant_id, provider_id)` pair.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct TenantProviderKey {
    /// Owning tenant.
    pub tenant_id: String,
    /// Provider within the tenant.
    pub provider_id: String,
}
impl TenantProviderKey {
    /// Builds a key from anything convertible into owned strings.
    pub fn new(tenant_id: impl Into<String>, provider_id: impl Into<String>) -> Self {
        let tenant: String = tenant_id.into();
        let provider: String = provider_id.into();

        TenantProviderKey { tenant_id: tenant, provider_id: provider }
    }
}
413
414#[derive(Debug, Default)]
416pub struct RegistryBuilder {
417 config: RegistryConfig,
418}
419impl RegistryBuilder {
420 pub fn new() -> Self {
422 Self::default()
423 }
424
425 pub fn require_https(mut self, require_https: bool) -> Self {
427 self.config.require_https = require_https;
428
429 self
430 }
431
432 pub fn default_refresh_early(mut self, value: Duration) -> Self {
434 self.config.default_refresh_early = value;
435
436 self
437 }
438
439 pub fn default_stale_while_error(mut self, value: Duration) -> Self {
441 self.config.default_stale_while_error = value;
442
443 self
444 }
445
446 pub fn add_allowed_domain(mut self, domain: impl Into<String>) -> Self {
448 let raw = domain.into();
449
450 if let Some(domain) = security::canonicalize_dns_name(&raw)
451 && !self.config.allowed_domains.contains(&domain)
452 {
453 self.config.allowed_domains.push(domain);
454 }
455
456 self
457 }
458
459 pub fn allowed_domains<I, S>(mut self, domains: I) -> Self
461 where
462 I: IntoIterator<Item = S>,
463 S: Into<String>,
464 {
465 self.config.allowed_domains.clear();
466
467 for domain in domains {
468 self = self.add_allowed_domain(domain);
469 }
470
471 self
472 }
473
474 #[cfg(feature = "redis")]
475 pub fn with_redis_client(mut self, client: redis::Client) -> Self {
477 self.config.persistence = Some(RedisPersistence::new(client));
478
479 self
480 }
481
482 #[cfg(feature = "redis")]
483 pub fn redis_namespace(mut self, namespace: impl Into<String>) -> Self {
485 if let Some(persistence) = self.config.persistence.as_mut() {
486 persistence.namespace = Arc::from(namespace.into());
487 } else {
488 panic!("Redis client must be configured before setting namespace.");
489 }
490
491 self
492 }
493
494 pub fn build(self) -> Registry {
496 let mut config = self.config;
497
498 config.allowed_domains = security::normalize_allowlist(config.allowed_domains);
499
500 Registry {
501 inner: Arc::new(RwLock::new(RegistryState { providers: HashMap::new() })),
502 config: Arc::new(config),
503 }
504 }
505}
506
/// Multi-tenant registry of JWKS providers keyed by `(tenant_id, provider_id)`.
///
/// Cloning is cheap. Clones share the same provider map and configuration through `Arc`.
#[derive(Clone, Debug)]
pub struct Registry {
    // Provider map behind an async RwLock; handles are cloned out before awaiting on them.
    inner: Arc<RwLock<RegistryState>>,
    // Immutable configuration captured at build time.
    config: Arc<RegistryConfig>,
}
513impl Registry {
514 pub fn new() -> Self {
516 Self::builder().build()
517 }
518
519 pub fn builder() -> RegistryBuilder {
521 RegistryBuilder::new()
522 }
523
524 pub async fn register(&self, mut registration: IdentityProviderRegistration) -> Result<()> {
526 if self.config.require_https {
527 if !registration.require_https {
528 return Err(Error::Security(
529 "Registry requires HTTPS for all provider registrations.".into(),
530 ));
531 }
532 } else {
533 registration.require_https = false;
534 }
535
536 registration.normalize_allowed_domains();
537
538 if registration.refresh_early == DEFAULT_REFRESH_EARLY {
539 registration.refresh_early = self.config.default_refresh_early;
540 }
541 if registration.stale_while_error == DEFAULT_STALE_WHILE_ERROR {
542 registration.stale_while_error = self.config.default_stale_while_error;
543 }
544 if registration.allowed_domains.is_empty() && !self.config.allowed_domains.is_empty() {
545 registration.allowed_domains = self.config.allowed_domains.clone();
546 }
547
548 if let Some(host) = registration.jwks_url.host_str()
549 && !security::host_is_allowed(host, &self.config.allowed_domains)
550 {
551 return Err(Error::Security(format!(
552 "Host '{host}' is not in the registry allowlist."
553 )));
554 }
555
556 let key = TenantProviderKey::new(®istration.tenant_id, ®istration.provider_id);
557 let manager = CacheManager::new(registration.clone())?;
558 #[cfg(feature = "metrics")]
559 let metrics = manager.metrics();
560 let handle = Arc::new(ProviderHandle {
561 registration: Arc::new(registration),
562 manager,
563 #[cfg(feature = "metrics")]
564 metrics,
565 });
566
567 {
568 let mut state = self.inner.write().await;
569
570 state.providers.insert(key.clone(), handle.clone());
571 }
572
573 #[cfg(feature = "redis")]
574 if let Some(persistence) = &self.config.persistence
575 && let Some(snapshot) = persistence.load(&key.tenant_id, &key.provider_id).await?
576 {
577 handle.manager.restore_snapshot(snapshot).await?;
578 }
579
580 Ok(())
581 }
582
583 pub async fn resolve(
585 &self,
586 tenant_id: &str,
587 provider_id: &str,
588 kid: Option<&str>,
589 ) -> Result<Arc<JwkSet>> {
590 let key = TenantProviderKey::new(tenant_id, provider_id);
591 let handle = {
592 let state = self.inner.read().await;
593
594 state.providers.get(&key).cloned()
595 };
596 let handle = handle.ok_or_else(|| Error::NotRegistered {
597 tenant: tenant_id.to_string(),
598 provider: provider_id.to_string(),
599 })?;
600
601 handle.manager.resolve(kid).await
602 }
603
604 pub async fn refresh(&self, tenant_id: &str, provider_id: &str) -> Result<()> {
606 let key = TenantProviderKey::new(tenant_id, provider_id);
607 let handle = {
608 let state = self.inner.read().await;
609 state.providers.get(&key).cloned()
610 };
611 let handle = handle.ok_or_else(|| Error::NotRegistered {
612 tenant: tenant_id.to_string(),
613 provider: provider_id.to_string(),
614 })?;
615
616 handle.manager.trigger_refresh().await
617 }
618
619 pub async fn unregister(&self, tenant_id: &str, provider_id: &str) -> Result<bool> {
621 let key = TenantProviderKey::new(tenant_id, provider_id);
622 let mut state = self.inner.write().await;
623
624 Ok(state.providers.remove(&key).is_some())
625 }
626
627 pub async fn provider_status(
629 &self,
630 tenant_id: &str,
631 provider_id: &str,
632 ) -> Result<ProviderStatus> {
633 let key = TenantProviderKey::new(tenant_id, provider_id);
634 let handle = {
635 let state = self.inner.read().await;
636
637 state.providers.get(&key).cloned()
638 };
639 let handle = handle.ok_or_else(|| Error::NotRegistered {
640 tenant: tenant_id.to_string(),
641 provider: provider_id.to_string(),
642 })?;
643
644 Ok(handle.status().await)
645 }
646
647 pub async fn all_statuses(&self) -> Vec<ProviderStatus> {
649 let handles: Vec<Arc<ProviderHandle>> = {
650 let state = self.inner.read().await;
651 state.providers.values().cloned().collect()
652 };
653 let mut statuses = Vec::with_capacity(handles.len());
654
655 for handle in handles {
656 statuses.push(handle.status().await);
657 }
658
659 statuses
660 }
661
662 pub async fn persist_all(&self) -> Result<()> {
664 #[cfg(feature = "redis")]
665 {
666 if let Some(persistence) = &self.config.persistence {
667 let handles: Vec<Arc<ProviderHandle>> = {
668 let state = self.inner.read().await;
669
670 state.providers.values().cloned().collect()
671 };
672 let mut snapshots = Vec::new();
673
674 for handle in handles {
675 if let Some(snapshot) = handle.manager.persistent_snapshot().await? {
676 snapshots.push(snapshot);
677 }
678 }
679
680 persistence.persist(&snapshots).await?;
681 }
682 }
683
684 Ok(())
685 }
686
687 pub async fn restore_from_persistence(&self) -> Result<()> {
689 #[cfg(feature = "redis")]
690 {
691 if let Some(persistence) = &self.config.persistence {
692 let handles: Vec<Arc<ProviderHandle>> = {
693 let state = self.inner.read().await;
694
695 state.providers.values().cloned().collect()
696 };
697
698 for handle in handles {
699 if let Some(snapshot) = persistence
700 .load(&handle.registration.tenant_id, &handle.registration.provider_id)
701 .await?
702 {
703 handle.manager.restore_snapshot(snapshot).await?;
704 }
705 }
706 }
707 }
708
709 Ok(())
710 }
711}
712impl Default for Registry {
713 fn default() -> Self {
714 Self::new()
715 }
716}
717
/// Point-in-time status of one registered provider, as returned by
/// `Registry::provider_status` and `Registry::all_statuses`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ProviderStatus {
    /// Owning tenant.
    pub tenant_id: String,
    /// Provider id within the tenant.
    pub provider_id: String,
    /// Cache lifecycle state.
    pub state: ProviderState,
    /// Time of the last successful refresh; `None` unless `Ready`/`Refreshing`.
    pub last_refresh: Option<DateTime<Utc>>,
    /// Scheduled next refresh, converted from the cache's clock; `None` if unavailable.
    pub next_refresh: Option<DateTime<Utc>>,
    /// Expiry of the cached keys, converted from the cache's clock; `None` if unavailable.
    pub expires_at: Option<DateTime<Utc>>,
    /// Error counter from the cache payload; zero unless `Ready`/`Refreshing`.
    pub error_count: u32,
    /// Cache hit rate as reported by the provider metrics snapshot.
    #[cfg(feature = "metrics")]
    pub hit_rate: f64,
    /// Stale-serve ratio as reported by the provider metrics snapshot.
    #[cfg(feature = "metrics")]
    pub stale_serve_ratio: f64,
    /// Individual metric samples labelled with `tenant` and `provider`.
    #[cfg(feature = "metrics")]
    pub metrics: Vec<StatusMetric>,
}
745impl ProviderStatus {
746 #[cfg(feature = "metrics")]
747 fn from_components(
748 registration: &IdentityProviderRegistration,
749 snapshot: CacheSnapshot,
750 metrics: ProviderMetricsSnapshot,
751 ) -> Self {
752 let mut last_refresh = None;
753 let mut next_refresh = None;
754 let mut expires_at = None;
755 let mut error_count = 0;
756 let state = match &snapshot.state {
757 CacheState::Empty => ProviderState::Empty,
758 CacheState::Loading => ProviderState::Loading,
759 CacheState::Ready(payload) => {
760 last_refresh = Some(payload.last_refresh_at);
761 next_refresh = snapshot.to_datetime(payload.next_refresh_at);
762 expires_at = snapshot.to_datetime(payload.expires_at);
763 error_count = payload.error_count;
764 ProviderState::Ready
765 },
766 CacheState::Refreshing(payload) => {
767 last_refresh = Some(payload.last_refresh_at);
768 next_refresh = snapshot.to_datetime(payload.next_refresh_at);
769 expires_at = snapshot.to_datetime(payload.expires_at);
770 error_count = payload.error_count;
771 ProviderState::Refreshing
772 },
773 };
774 let tenant = ®istration.tenant_id;
775 let provider = ®istration.provider_id;
776 let mut status_metrics = vec![
777 StatusMetric::new(
778 "jwks_cache_requests_total",
779 metrics.total_requests as f64,
780 tenant,
781 provider,
782 ),
783 StatusMetric::new("jwks_cache_hits_total", metrics.cache_hits as f64, tenant, provider),
784 StatusMetric::new(
785 "jwks_cache_stale_total",
786 metrics.stale_serves as f64,
787 tenant,
788 provider,
789 ),
790 StatusMetric::new(
791 "jwks_cache_refresh_errors_total",
792 metrics.refresh_errors as f64,
793 tenant,
794 provider,
795 ),
796 ];
797
798 if let Some(last_micros) = metrics.last_refresh_micros {
799 status_metrics.push(StatusMetric::new(
800 "jwks_cache_last_refresh_micros",
801 last_micros as f64,
802 tenant,
803 provider,
804 ));
805 }
806
807 Self {
808 tenant_id: tenant.clone(),
809 provider_id: provider.clone(),
810 state,
811 last_refresh,
812 next_refresh,
813 expires_at,
814 error_count,
815 hit_rate: metrics.hit_rate(),
816 stale_serve_ratio: metrics.stale_ratio(),
817 metrics: status_metrics,
818 }
819 }
820
821 #[cfg(not(feature = "metrics"))]
822 fn from_components(
823 registration: &IdentityProviderRegistration,
824 snapshot: CacheSnapshot,
825 ) -> Self {
826 let mut last_refresh = None;
827 let mut next_refresh = None;
828 let mut expires_at = None;
829 let mut error_count = 0;
830 let state = match &snapshot.state {
831 CacheState::Empty => ProviderState::Empty,
832 CacheState::Loading => ProviderState::Loading,
833 CacheState::Ready(payload) => {
834 last_refresh = Some(payload.last_refresh_at);
835 next_refresh = snapshot.to_datetime(payload.next_refresh_at);
836 expires_at = snapshot.to_datetime(payload.expires_at);
837 error_count = payload.error_count;
838 ProviderState::Ready
839 },
840 CacheState::Refreshing(payload) => {
841 last_refresh = Some(payload.last_refresh_at);
842 next_refresh = snapshot.to_datetime(payload.next_refresh_at);
843 expires_at = snapshot.to_datetime(payload.expires_at);
844 error_count = payload.error_count;
845 ProviderState::Refreshing
846 },
847 };
848
849 Self {
850 tenant_id: registration.tenant_id.clone(),
851 provider_id: registration.provider_id.clone(),
852 state,
853 last_refresh,
854 next_refresh,
855 expires_at,
856 error_count,
857 }
858 }
859}
860
/// One named metric sample in a [`ProviderStatus`] (only with the `metrics` feature).
#[cfg(feature = "metrics")]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StatusMetric {
    /// Metric name, e.g. `jwks_cache_hits_total`.
    pub name: String,
    /// Sample value.
    pub value: f64,
    /// Labels; `StatusMetric::new` sets `tenant` and `provider`. Empty when absent from input.
    #[serde(default)]
    pub labels: HashMap<String, String>,
}
#[cfg(feature = "metrics")]
impl StatusMetric {
    /// Creates a sample labelled with `tenant` and `provider`.
    fn new(name: impl Into<String>, value: f64, tenant: &str, provider: &str) -> Self {
        let labels = HashMap::from([
            ("tenant".to_string(), tenant.to_string()),
            ("provider".to_string(), provider.to_string()),
        ]);

        StatusMetric { name: name.into(), value, labels }
    }
}
884
/// Registry-wide settings assembled by [`RegistryBuilder`] and shared immutably by [`Registry`].
#[derive(Debug)]
struct RegistryConfig {
    // When true, registrations must keep `require_https`; when false, `register` clears it.
    require_https: bool,
    // Substituted for registrations whose `refresh_early` equals DEFAULT_REFRESH_EARLY.
    default_refresh_early: Duration,
    // Substituted for registrations whose `stale_while_error` equals DEFAULT_STALE_WHILE_ERROR.
    default_stale_while_error: Duration,
    // Registry allowlist; inherited by registrations with an empty list and checked on register.
    allowed_domains: Vec<String>,
    // Optional Redis snapshot store.
    #[cfg(feature = "redis")]
    persistence: Option<RedisPersistence>,
}
893}
894impl Default for RegistryConfig {
895 fn default() -> Self {
896 Self {
897 require_https: true,
898 default_refresh_early: DEFAULT_REFRESH_EARLY,
899 default_stale_while_error: DEFAULT_STALE_WHILE_ERROR,
900 allowed_domains: Vec::new(),
901 #[cfg(feature = "redis")]
902 persistence: None,
903 }
904 }
905}
906
/// Registry entry for one provider: its reconciled registration, its cache manager and, with
/// the `metrics` feature, the manager's metrics handle.
#[derive(Debug)]
struct ProviderHandle {
    // Registration after `Registry::register` applied registry defaults.
    registration: Arc<IdentityProviderRegistration>,
    // Owns the cached JWKS for this provider.
    manager: CacheManager,
    // Obtained from `manager.metrics()` at registration time.
    #[cfg(feature = "metrics")]
    metrics: Arc<ProviderMetrics>,
}
913}
914impl ProviderHandle {
915 async fn status(&self) -> ProviderStatus {
916 let snapshot = self.manager.snapshot().await;
917 #[cfg(feature = "metrics")]
918 let status = {
919 let metrics = self.metrics.snapshot();
920
921 ProviderStatus::from_components(&self.registration, snapshot, metrics)
922 };
923 #[cfg(not(feature = "metrics"))]
924 let status = ProviderStatus::from_components(&self.registration, snapshot);
925
926 status
927 }
928}
929
/// Mutable registry contents, guarded by the `RwLock` in [`Registry`].
#[derive(Debug)]
struct RegistryState {
    // One shared handle per `(tenant_id, provider_id)`.
    providers: HashMap<TenantProviderKey, Arc<ProviderHandle>>,
}
935
/// Redis-backed store for [`PersistentSnapshot`]s (only with the `redis` feature).
///
/// Snapshots are stored as JSON strings under `{namespace}:{tenant}:{provider}` keys.
#[cfg(feature = "redis")]
#[derive(Clone, Debug)]
struct RedisPersistence {
    // Client used to open a multiplexed async connection per operation.
    client: redis::Client,
    // Key prefix; `"jwks-cache"` unless overridden via `RegistryBuilder::redis_namespace`.
    namespace: Arc<str>,
}
#[cfg(feature = "redis")]
impl RedisPersistence {
    /// Wraps `client` with the default `jwks-cache` key namespace.
    fn new(client: redis::Client) -> Self {
        let namespace: Arc<str> = Arc::from("jwks-cache");

        RedisPersistence { client, namespace }
    }

    /// Stores every snapshot as JSON using `SET ... EX`.
    ///
    /// Each key expires at the snapshot's `expires_at`, truncated to whole seconds with a
    /// minimum of one second; already-expired snapshots get one second. No connection is
    /// opened for an empty slice.
    ///
    /// # Errors
    /// Connection, serialization, or Redis errors. Snapshots written before the failure stay
    /// written.
    async fn persist(&self, snapshots: &[PersistentSnapshot]) -> Result<()> {
        if snapshots.is_empty() {
            return Ok(());
        }

        let mut conn = self.client.get_multiplexed_async_connection().await?;

        for snapshot in snapshots {
            let redis_key = self.key(&snapshot.tenant_id, &snapshot.provider_id);
            let json = serde_json::to_string(snapshot)?;
            // `to_std` fails when the remaining time is negative, i.e. the snapshot has expired.
            let remaining = match (snapshot.expires_at - Utc::now()).to_std() {
                Ok(left) => left,
                Err(_) => Duration::from_secs(1),
            };
            let expiry_secs = remaining.as_secs().max(1);

            conn.set_ex::<_, _, ()>(redis_key, json, expiry_secs).await?;
        }

        Ok(())
    }

    /// Loads and deserializes the snapshot for `(tenant, provider)`; `Ok(None)` if absent.
    ///
    /// # Errors
    /// Connection, Redis, or JSON deserialization errors.
    async fn load(&self, tenant: &str, provider: &str) -> Result<Option<PersistentSnapshot>> {
        let mut conn = self.client.get_multiplexed_async_connection().await?;
        let stored: Option<String> = conn.get(self.key(tenant, provider)).await?;

        let Some(json) = stored else {
            return Ok(None);
        };
        let snapshot: PersistentSnapshot = serde_json::from_str(&json)?;

        Ok(Some(snapshot))
    }

    /// Redis key `{namespace}:{tenant}:{provider}`. The id validators forbid `:`, so keys
    /// cannot collide.
    fn key(&self, tenant: &str, provider: &str) -> String {
        format!("{}:{}:{}", self.namespace, tenant, provider)
    }
}
987
988fn random_within(min: Duration, max: Duration) -> Duration {
989 if max <= min {
990 return max;
991 }
992 SMALL_RNG.with(|cell| {
993 let mut rng = cell.borrow_mut();
994 let nanos = max.as_nanos() - min.as_nanos();
995 let jitter = rng.random_range(0..=nanos.min(u64::MAX as u128));
996
997 min + Duration::from_nanos(jitter as u64)
998 })
999}
1000
/// Serde default for `IdentityProviderRegistration::require_https`.
fn default_true() -> bool {
    true
}

/// Serde default for `refresh_early`: [`DEFAULT_REFRESH_EARLY`].
fn default_refresh_early() -> Duration {
    DEFAULT_REFRESH_EARLY
}

/// Serde default for `stale_while_error`: [`DEFAULT_STALE_WHILE_ERROR`].
fn default_stale_while_error() -> Duration {
    DEFAULT_STALE_WHILE_ERROR
}

/// Serde default for `min_ttl`: [`MIN_TTL_FLOOR`].
fn default_min_ttl() -> Duration {
    MIN_TTL_FLOOR
}

/// Serde default for `max_ttl`: [`DEFAULT_MAX_TTL`].
fn default_max_ttl() -> Duration {
    DEFAULT_MAX_TTL
}

/// Serde default for `max_response_bytes`: [`DEFAULT_MAX_RESPONSE_BYTES`].
fn default_max_response_bytes() -> u64 {
    DEFAULT_MAX_RESPONSE_BYTES
}

/// Serde default for `max_redirects` (must stay within [`MAX_REDIRECTS`]).
fn default_max_redirects() -> u8 {
    3
}

/// Serde default for `prefetch_jitter`: [`DEFAULT_PREFETCH_JITTER`].
fn default_prefetch_jitter() -> Duration {
    DEFAULT_PREFETCH_JITTER
}
1032
1033fn validate_tenant_id(value: &str) -> Result<()> {
1034 if value.is_empty() {
1035 return Err(Error::Validation { field: "tenant_id", reason: "Must not be empty.".into() });
1036 }
1037 if value.len() > 64 {
1038 return Err(Error::Validation {
1039 field: "tenant_id",
1040 reason: "Must be 64 characters or fewer.".into(),
1041 });
1042 }
1043 if !value.as_bytes().iter().all(|b| b.is_ascii_alphanumeric() || *b == b'-') {
1044 return Err(Error::Validation {
1045 field: "tenant_id",
1046 reason: "May only contain ASCII letters, numbers, and '-'.".into(),
1047 });
1048 }
1049
1050 Ok(())
1051}
1052
1053fn validate_provider_id(value: &str) -> Result<()> {
1054 if value.is_empty() {
1055 return Err(Error::Validation {
1056 field: "provider_id",
1057 reason: "Must not be empty.".into(),
1058 });
1059 }
1060 if value.len() > 64 {
1061 return Err(Error::Validation {
1062 field: "provider_id",
1063 reason: "Must be 64 characters or fewer.".into(),
1064 });
1065 }
1066 if !value.as_bytes().iter().all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_')) {
1067 return Err(Error::Validation {
1068 field: "provider_id",
1069 reason: "May only contain ASCII letters, numbers, '-', or '_'.".into(),
1070 });
1071 }
1072
1073 Ok(())
1074}