use super::builder::{JwtSourceBuilder, ReconnectConfig, ResourceLimits};
use super::errors::{JwtSourceError, MetricsErrorKind};
use super::limits::validate_bundle_set;
use super::metrics::MetricsRecorder;
use super::supervisor::initial_sync_with_retry;
use super::types::ClientFactory;
use crate::bundle::BundleSource;
use crate::prelude::warn;
use crate::workload_api::WorkloadApiClient;
use crate::{JwtBundle, JwtBundleSet, JwtSvid, SpiffeId, TrustDomain};
use arc_swap::ArcSwap;
use std::fmt::Debug;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{watch, Mutex};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
#[cfg(test)]
use crate::WorkloadApiError;
/// Subscription to bundle-set update notifications published by a `JwtSource`.
///
/// Each notification is a `u64` sequence number; the channel starts at `0`.
#[derive(Clone, Debug)]
pub struct JwtSourceUpdates {
    rx: watch::Receiver<u64>,
}

impl JwtSourceUpdates {
    /// Waits for the next published sequence number and returns it.
    ///
    /// # Errors
    ///
    /// Returns [`JwtSourceError::Closed`] when the sending side of the channel
    /// has been dropped.
    pub async fn changed(&mut self) -> Result<u64, JwtSourceError> {
        self.rx
            .changed()
            .await
            .map_err(|watch::error::RecvError { .. }| JwtSourceError::Closed)?;
        // `borrow_and_update` marks the value we return as seen. A plain `borrow`
        // would leave a value published after `changed()` resolved unseen, so the
        // next call would wake immediately and report that value a second time.
        Ok(*self.rx.borrow_and_update())
    }

    /// Returns the most recently published sequence number without waiting.
    pub fn last(&self) -> u64 {
        *self.rx.borrow()
    }

    /// Waits until `f` accepts a published sequence number and returns that number.
    ///
    /// The current value is tested first, so this returns immediately when it
    /// already satisfies `f`. The tested value is marked as seen either way.
    ///
    /// # Errors
    ///
    /// Returns [`JwtSourceError::Closed`] if the channel closes while waiting.
    pub async fn wait_for<F>(&mut self, mut f: F) -> Result<u64, JwtSourceError>
    where
        F: FnMut(&u64) -> bool,
    {
        // Mark the examined value as seen so the loop below only wakes on
        // genuinely new values (and a later `changed()` does not re-report it).
        let current = *self.rx.borrow_and_update();
        if f(&current) {
            return Ok(current);
        }
        loop {
            let seq = self.changed().await?;
            if f(&seq) {
                return Ok(seq);
            }
        }
    }
}
#[derive(Clone, Debug)]
pub struct JwtSource {
inner: Arc<Inner>,
}
pub(super) struct Inner {
bundle_set: ArcSwap<JwtBundleSet>,
limits: ResourceLimits,
cached_client: ArcSwap<Option<Arc<WorkloadApiClient>>>,
client_creation_mutex: Mutex<()>,
reconnect: ReconnectConfig,
make_client: ClientFactory,
metrics: Option<Arc<dyn MetricsRecorder>>,
closed: AtomicBool,
cancel: CancellationToken,
shutdown_timeout: Option<Duration>,
update_seq: AtomicU64,
update_tx: watch::Sender<u64>,
update_rx: watch::Receiver<u64>,
supervisor: Mutex<Option<JoinHandle<()>>>,
}
impl Inner {
pub(super) const fn reconnect(&self) -> ReconnectConfig {
self.reconnect
}
pub(super) fn metrics(&self) -> Option<&dyn MetricsRecorder> {
self.metrics.as_deref()
}
pub(super) fn make_client(&self) -> &ClientFactory {
&self.make_client
}
pub(super) async fn get_or_recreate_client(
&self,
) -> Result<Arc<WorkloadApiClient>, JwtSourceError> {
let cached = self.cached_client.load();
if let Some(client) = cached.as_ref() {
return Ok(Arc::clone(client));
}
let _guard = self.client_creation_mutex.lock().await;
let cached = self.cached_client.load();
if let Some(client) = cached.as_ref() {
return Ok(Arc::clone(client));
}
self.recreate_client_inner().await
}
pub(super) async fn recreate_client(&self) -> Result<Arc<WorkloadApiClient>, JwtSourceError> {
let _guard = self.client_creation_mutex.lock().await;
self.recreate_client_inner().await
}
async fn recreate_client_inner(&self) -> Result<Arc<WorkloadApiClient>, JwtSourceError> {
let client = (self.make_client)().await.map_err(JwtSourceError::Source)?;
let client_arc = Arc::new(client);
self.cached_client
.store(Arc::new(Some(Arc::clone(&client_arc))));
Ok(client_arc)
}
}
impl Debug for Inner {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("JwtSource")
.field("bundle_set", &"<ArcSwap<JwtBundleSet>>")
.field(
"cached_client",
&"<ArcSwap<Option<Arc<WorkloadApiClient>>>>",
)
.field("client_creation_mutex", &"<Mutex<()>>")
.field("reconnect", &self.reconnect)
.field("limits", &self.limits)
.field("make_client", &"<ClientFactory>")
.field(
"metrics",
&self.metrics.as_ref().map(|_| "<MetricsRecorder>"),
)
.field("shutdown_timeout", &self.shutdown_timeout)
.field("closed", &self.closed.load(Ordering::Relaxed))
.field("cancel", &self.cancel)
.field("update_seq", &self.update_seq)
.field("update_tx", &"<watch::Sender<u64>>")
.field("update_rx", &"<watch::Receiver<u64>>")
.field("supervisor", &"<Mutex<Option<JoinHandle<()>>>>")
.finish()
}
}
impl JwtSource {
pub async fn new() -> Result<Self, JwtSourceError> {
JwtSourceBuilder::new().build().await
}
pub fn builder() -> JwtSourceBuilder {
JwtSourceBuilder::new()
}
pub fn updated(&self) -> JwtSourceUpdates {
JwtSourceUpdates {
rx: self.inner.update_rx.clone(),
}
}
pub fn is_healthy(&self) -> bool {
if self.inner.closed.load(Ordering::Acquire) || self.inner.cancel.is_cancelled() {
return false;
}
let bundle_set = self.inner.bundle_set.load();
!bundle_set.is_empty()
}
pub fn bundle_set(&self) -> Result<Arc<JwtBundleSet>, JwtSourceError> {
self.assert_open()?;
Ok(self.inner.bundle_set.load_full())
}
pub fn try_bundle_for_trust_domain(&self, td: &TrustDomain) -> Option<Arc<JwtBundle>> {
self.bundle_for_trust_domain(td).ok().flatten()
}
pub async fn get_jwt_svid<I>(&self, audience: I) -> Result<JwtSvid, JwtSourceError>
where
I: IntoIterator,
I::Item: AsRef<str>,
{
self.get_jwt_svid_with_id(audience, None).await
}
pub async fn get_jwt_svid_with_id<I>(
&self,
audience: I,
spiffe_id: Option<&SpiffeId>,
) -> Result<JwtSvid, JwtSourceError>
where
I: IntoIterator,
I::Item: AsRef<str>,
{
self.assert_open()?;
let audience_vec: Vec<String> = audience
.into_iter()
.map(|a| a.as_ref().to_string())
.collect();
let client = self.inner.get_or_recreate_client().await?;
match client.fetch_jwt_svid(&audience_vec, spiffe_id).await {
Ok(svid) => Ok(svid),
Err(_e) => {
self.assert_open()?; let new_client = self.inner.recreate_client().await?;
new_client
.fetch_jwt_svid(&audience_vec, spiffe_id)
.await
.map_err(JwtSourceError::FetchJwtSvid)
}
}
}
pub async fn shutdown(&self) {
if self.inner.closed.swap(true, Ordering::AcqRel) {
return;
}
self.inner.cancel.cancel();
if let Some(handle) = self.inner.supervisor.lock().await.take() {
if let Err(e) = handle.await {
warn!("Error joining supervisor task during shutdown: error={e}");
self.inner
.record_error(MetricsErrorKind::SupervisorJoinFailed);
}
}
}
pub async fn shutdown_with_timeout(&self, timeout: Duration) -> Result<(), JwtSourceError> {
if self.inner.closed.swap(true, Ordering::AcqRel) {
return Ok(());
}
self.inner.cancel.cancel();
let Some(mut handle) = self.inner.supervisor.lock().await.take() else {
return Ok(());
};
match tokio::time::timeout(timeout, &mut handle).await {
Ok(Ok(())) => Ok(()),
Ok(Err(e)) => {
warn!("Error joining supervisor task during shutdown: error={e}");
self.inner
.record_error(MetricsErrorKind::SupervisorJoinFailed);
Ok(())
}
Err(_) => {
warn!("Shutdown timeout exceeded; aborting supervisor task");
handle.abort();
let _unused: Result<_, _> = handle.await;
Err(JwtSourceError::ShutdownTimeout)
}
}
}
pub async fn shutdown_configured(&self) -> Result<(), JwtSourceError> {
if let Some(timeout) = self.inner.shutdown_timeout {
self.shutdown_with_timeout(timeout).await
} else {
self.shutdown().await;
Ok(())
}
}
}
impl JwtSource {
pub(super) async fn build_with(
make_client: ClientFactory,
reconnect: ReconnectConfig,
limits: ResourceLimits,
metrics: Option<Arc<dyn MetricsRecorder>>,
shutdown_timeout: Option<Duration>,
) -> Result<Self, JwtSourceError> {
let reconnect = super::builder::normalize_reconnect(reconnect);
let (update_tx, update_rx) = watch::channel(0u64);
let cancel = CancellationToken::new();
let initial_bundle_set =
initial_sync_with_retry(&make_client, &cancel, reconnect, limits, metrics.as_deref())
.await?;
let initial_client = match make_client().await {
Ok(c) => c,
Err(_first_err) => {
tokio::time::sleep(reconnect.min_backoff).await;
make_client().await.map_err(JwtSourceError::Source)?
}
};
let initial_client_arc = Arc::new(initial_client);
let inner = Arc::new(Inner {
bundle_set: ArcSwap::from(initial_bundle_set),
cached_client: ArcSwap::from(Arc::new(Some(initial_client_arc))),
client_creation_mutex: Mutex::new(()),
reconnect,
make_client,
limits,
metrics,
shutdown_timeout,
closed: AtomicBool::new(false),
cancel,
update_seq: AtomicU64::new(0),
update_tx,
update_rx,
supervisor: Mutex::new(None),
});
let task_inner = Arc::clone(&inner);
let token = task_inner.cancel.clone();
let handle = tokio::spawn(async move {
task_inner.run_update_supervisor(token).await;
});
*inner.supervisor.lock().await = Some(handle);
Ok(Self { inner })
}
#[cfg(test)]
pub(super) fn new_for_test(
initial_bundle_set: Arc<JwtBundleSet>,
reconnect: ReconnectConfig,
limits: ResourceLimits,
metrics: Option<Arc<dyn MetricsRecorder>>,
) -> Self {
let reconnect = super::builder::normalize_reconnect(reconnect);
let (update_tx, update_rx) = watch::channel(0u64);
let cancel = CancellationToken::new();
let make_client: ClientFactory =
Arc::new(|| Box::pin(async move { Err(WorkloadApiError::EmptyResponse) }));
let inner = Inner {
bundle_set: ArcSwap::from(initial_bundle_set),
cached_client: ArcSwap::from(Arc::new(None)),
client_creation_mutex: Mutex::new(()),
reconnect,
make_client,
limits,
metrics,
shutdown_timeout: None,
closed: AtomicBool::new(false),
cancel,
update_seq: AtomicU64::new(0),
update_tx,
update_rx,
supervisor: Mutex::new(None),
};
Self {
inner: Arc::new(inner),
}
}
fn assert_open(&self) -> Result<(), JwtSourceError> {
if self.inner.closed.load(Ordering::Acquire) || self.inner.cancel.is_cancelled() {
return Err(JwtSourceError::Closed);
}
Ok(())
}
}
impl Inner {
pub(super) fn record_error(&self, kind: MetricsErrorKind) {
if let Some(metrics) = self.metrics.as_deref() {
metrics.record_error(kind);
}
}
pub(super) fn record_update(&self) {
if let Some(metrics) = self.metrics.as_deref() {
metrics.record_update();
}
}
pub(super) fn apply_update(
&self,
new_bundle_set: Arc<JwtBundleSet>,
) -> Result<(), JwtSourceError> {
match self.validate_bundle_set(&new_bundle_set) {
Ok(()) => {
self.bundle_set.store(new_bundle_set);
self.record_update();
self.notify_update();
Ok(())
}
Err(e) => {
self.record_error(MetricsErrorKind::UpdateRejected);
Err(e)
}
}
}
pub(super) fn notify_update(&self) {
let next = self.update_seq.fetch_add(1, Ordering::Relaxed) + 1;
let _unused: Result<_, _> = self.update_tx.send(next);
}
pub(super) fn validate_bundle_set(
&self,
bundle_set: &JwtBundleSet,
) -> Result<(), JwtSourceError> {
validate_bundle_set(bundle_set, self.limits, self.metrics.as_deref())
}
}
impl Drop for JwtSource {
fn drop(&mut self) {
self.inner.cancel.cancel();
}
}
impl BundleSource for JwtSource {
    type Item = JwtBundle;
    type Error = JwtSourceError;

    /// Looks up `trust_domain` in the current bundle-set snapshot.
    ///
    /// Yields `Ok(None)` when the snapshot has no bundle for that trust domain.
    ///
    /// # Errors
    ///
    /// Returns [`JwtSourceError::Closed`] once the source is closed or cancelled.
    fn bundle_for_trust_domain(
        &self,
        trust_domain: &TrustDomain,
    ) -> Result<Option<Arc<Self::Item>>, Self::Error> {
        self.assert_open()
            .map(|()| self.inner.bundle_set.load().get(trust_domain))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bundle::jwt::JwtAuthority;
    use std::collections::HashMap;
    // Shadows the glob-imported `tokio::sync::Mutex`; the test recorder is synchronous.
    use std::sync::Mutex;

    /// Builds a symmetric (`"oct"`) JWK authority carrying the given key id.
    fn jwk_with_kid(kid: &str) -> JwtAuthority {
        let json = format!(
            r#"{{
    "kty": "oct",
    "kid": "{kid}",
    "k": "AyM1SysPpbyDfgZld3umj1qzKObwVMkoqQ-EstJQLr_T-1qS0gZH75aKtMN3Yj0iPS4hcgUuTwjAzZr1Z9CAow"
}}"#
        );
        JwtAuthority::from_jwk_json(json.as_bytes()).expect("valid JWK JSON")
    }

    /// One bundle for `example.org` holding a single authority (`kid-1`).
    fn create_test_bundle_set() -> Arc<JwtBundleSet> {
        let trust_domain = TrustDomain::new("example.org").unwrap();
        let mut bundle = JwtBundle::new(trust_domain);
        bundle.add_jwt_authority(jwk_with_kid("kid-1"));
        let mut bundle_set = JwtBundleSet::new();
        bundle_set.add_bundle(bundle);
        Arc::new(bundle_set)
    }

    /// Limits whose `max_bundles` of zero rejects any non-empty bundle set.
    fn reject_any_bundle_limits() -> ResourceLimits {
        ResourceLimits {
            max_bundles: Some(0),
            max_bundle_jwks_bytes: Some(1000),
        }
    }

    #[tokio::test]
    async fn test_wait_for_immediate_satisfaction() {
        let (tx, rx) = watch::channel(5u64);
        let mut updates = JwtSourceUpdates { rx };

        let seq = updates
            .wait_for(|&seq| seq > 3)
            .await
            .expect("channel is open");
        assert_eq!(seq, 5);

        let _unused: Result<_, _> = tx.send(10);
        let seq = updates
            .wait_for(|&seq| seq > 8)
            .await
            .expect("channel is open");
        assert_eq!(seq, 10);
    }

    #[tokio::test]
    async fn test_wait_for_waits_when_not_satisfied() {
        let (tx, rx) = watch::channel(1u64);
        let mut updates = JwtSourceUpdates { rx };
        let tx_clone = tx.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            let _unused: Result<_, _> = tx_clone.send(5);
        });
        let seq = tokio::time::timeout(Duration::from_secs(1), updates.wait_for(|&seq| seq > 3))
            .await
            .expect("Should complete within timeout")
            .expect("channel is open");
        assert_eq!(seq, 5);
    }

    #[tokio::test]
    async fn test_updated_only_notifies_on_rotations_after_initial_sync() {
        let (tx, rx) = watch::channel(0u64);
        let mut updates = JwtSourceUpdates { rx: rx.clone() };
        let tx_clone = tx.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            let _unused: Result<_, _> = tx_clone.send(1);
        });
        let seq = tokio::time::timeout(Duration::from_secs(1), updates.changed())
            .await
            .expect("Should complete within timeout")
            .expect("channel is open");
        assert_eq!(seq, 1);
        assert_eq!(updates.last(), 1);
    }

    #[tokio::test]
    async fn test_updated_initial_sequence_is_zero() {
        let (_tx, rx) = watch::channel(0u64);
        let updates = JwtSourceUpdates { rx };
        assert_eq!(updates.last(), 0);
    }

    /// Metrics recorder that counts error events per kind.
    struct TestMetricsRecorder {
        counts: Mutex<HashMap<MetricsErrorKind, u64>>,
    }

    impl TestMetricsRecorder {
        fn new() -> Self {
            Self {
                counts: Mutex::new(HashMap::new()),
            }
        }

        fn count(&self, kind: MetricsErrorKind) -> u64 {
            self.counts
                .lock()
                .unwrap()
                .get(&kind)
                .copied()
                .unwrap_or(0)
        }
    }

    impl MetricsRecorder for TestMetricsRecorder {
        fn record_update(&self) {}
        fn record_reconnect(&self) {}
        fn record_error(&self, kind: MetricsErrorKind) {
            *self.counts.lock().unwrap().entry(kind).or_default() += 1;
        }
    }

    #[test]
    fn test_metrics_recorded_exactly_once_per_rejected_update() {
        let metrics = Arc::new(TestMetricsRecorder::new());
        let source = {
            // Bound separately so `Some(..)` can coerce it to `Arc<dyn MetricsRecorder>`.
            let metrics = Arc::clone(&metrics);
            JwtSource::new_for_test(
                Arc::new(JwtBundleSet::new()),
                ReconnectConfig::default(),
                reject_any_bundle_limits(),
                Some(metrics),
            )
        };

        let result = source.inner.apply_update(create_test_bundle_set());

        assert!(matches!(
            result,
            Err(JwtSourceError::ResourceLimitExceeded {
                kind: super::super::errors::LimitKind::MaxBundles,
                ..
            })
        ));
        assert_eq!(metrics.count(MetricsErrorKind::LimitMaxBundles), 1);
        assert_eq!(metrics.count(MetricsErrorKind::UpdateRejected), 1);
        assert_eq!(metrics.count(MetricsErrorKind::LimitMaxBundleJwksBytes), 0);
    }

    #[test]
    fn test_new_with_normalizes_reconnect_config() {
        let inverted_reconnect = ReconnectConfig {
            min_backoff: Duration::from_secs(10),
            max_backoff: Duration::from_secs(1),
        };
        let source = JwtSource::new_for_test(
            create_test_bundle_set(),
            inverted_reconnect,
            ResourceLimits::default(),
            None,
        );
        assert_eq!(source.inner.reconnect.min_backoff, Duration::from_secs(1));
        assert_eq!(source.inner.reconnect.max_backoff, Duration::from_secs(10));
    }

    #[test]
    fn test_initial_sync_validation_records_correct_metrics() {
        use super::super::limits::validate_bundle_set;

        let metrics = Arc::new(TestMetricsRecorder::new());
        let bundle_set = create_test_bundle_set();

        let result = validate_bundle_set(
            &bundle_set,
            reject_any_bundle_limits(),
            Some(metrics.as_ref()),
        );

        assert!(matches!(
            result,
            Err(JwtSourceError::ResourceLimitExceeded {
                kind: super::super::errors::LimitKind::MaxBundles,
                ..
            })
        ));
        assert_eq!(metrics.count(MetricsErrorKind::LimitMaxBundles), 1);
        // Plain validation (as used by the initial sync) must not count as a rejected update.
        assert_eq!(metrics.count(MetricsErrorKind::UpdateRejected), 0);
        assert_eq!(metrics.count(MetricsErrorKind::LimitMaxBundleJwksBytes), 0);
    }

    #[test]
    fn test_resource_limits_unlimited() {
        let unlimited = ResourceLimits::unlimited();
        assert_eq!(unlimited.max_bundles, None);
        assert_eq!(unlimited.max_bundle_jwks_bytes, None);
    }

    #[test]
    fn test_resource_limits_default_limits() {
        let limits = ResourceLimits::default_limits();
        assert_eq!(limits.max_bundles, Some(200));
        assert_eq!(limits.max_bundle_jwks_bytes, Some(4 * 1024 * 1024));
    }

    #[test]
    fn test_resource_limits_mixed() {
        let mixed = ResourceLimits {
            max_bundles: Some(50),
            max_bundle_jwks_bytes: None,
        };
        assert_eq!(mixed.max_bundles, Some(50));
        assert_eq!(mixed.max_bundle_jwks_bytes, None);
    }
}