use super::builder::{ReconnectConfig, ResourceLimits};
use super::errors::{MetricsErrorKind, X509SourceError};
use super::limits::{select_svid, validate_context};
use super::metrics::MetricsRecorder;
use super::supervisor::initial_sync_with_retry;
use crate::bundle::BundleSource;
use crate::prelude::warn;
use crate::svid::SvidSource;
use crate::workload_api::x509_context::X509Context;
use crate::x509_source::types::{ClientFactory, SvidPicker};
use crate::{TrustDomain, X509Bundle, X509BundleSet, X509SourceBuilder, X509Svid};
use arc_swap::ArcSwap;
use std::fmt::Debug;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{watch, Mutex};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
#[cfg(test)]
use crate::WorkloadApiError;
/// Handle for observing the update sequence number of an [`X509Source`].
///
/// Returned by [`X509Source::updated`]; it wraps a clone of the `watch`
/// receiver stored in the source.
#[derive(Clone, Debug)]
pub struct X509SourceUpdates {
/// Receiving half of the `watch` channel that carries the sequence number.
rx: watch::Receiver<u64>,
}
impl X509SourceUpdates {
    /// Waits for the next change notification and returns the sequence number
    /// read from the channel afterwards.
    ///
    /// # Errors
    ///
    /// Returns [`X509SourceError::Closed`] when the underlying `watch`
    /// receiver reports a receive error.
    pub async fn changed(&mut self) -> Result<u64, X509SourceError> {
        self.rx
            .changed()
            .await
            .map_err(|watch::error::RecvError { .. }| X509SourceError::Closed)?;
        Ok(*self.rx.borrow())
    }

    /// Returns the sequence number currently held by the channel without
    /// waiting.
    pub fn last(&self) -> u64 {
        *self.rx.borrow()
    }

    /// Waits until the sequence number satisfies the predicate `f`.
    ///
    /// The current value is tested first, so the call returns immediately when
    /// the predicate already holds. Otherwise every subsequent change is
    /// tested until one passes.
    ///
    /// # Errors
    ///
    /// Returns [`X509SourceError::Closed`] if waiting for a change fails
    /// before the predicate is satisfied.
    pub async fn wait_for<F>(&mut self, mut f: F) -> Result<u64, X509SourceError>
    where
        F: FnMut(&u64) -> bool,
    {
        // Fast path: no need to await when the present value already matches.
        let current = self.last();
        if f(&current) {
            return Ok(current);
        }
        loop {
            let seq = self.changed().await?;
            if f(&seq) {
                return Ok(seq);
            }
        }
    }
}
/// Handle to a shared X.509 source.
///
/// Cloning the handle clones the inner `Arc`, so every clone refers to the
/// same `Inner` state.
#[derive(Clone, Debug)]
pub struct X509Source {
/// Reference-counted state shared by all clones of this handle.
inner: Arc<Inner>,
}
/// Shared state behind every [`X509Source`] handle.
pub(super) struct Inner {
/// Current X.509 context; replaced by `apply_update` once a new context
/// passes validation.
x509_context: ArcSwap<X509Context>,
/// Optional picker handed to `select_svid` and `validate_context`.
svid_picker: Option<Box<dyn SvidPicker>>,
/// Resource limits handed to `validate_context`.
limits: ResourceLimits,
/// Reconnect configuration, stored after passing through
/// `normalize_reconnect` in the constructors.
reconnect: ReconnectConfig,
/// Client factory, exposed through `make_client()`.
make_client: ClientFactory,
/// Optional metrics recorder.
metrics: Option<Arc<dyn MetricsRecorder>>,
/// Set to `true` by the first shutdown call.
closed: AtomicBool,
/// Cancellation token; cancelled on shutdown and when `Inner` is dropped.
cancel: CancellationToken,
/// Timeout applied by `shutdown_configured`, if any.
shutdown_timeout: Option<Duration>,
/// Counter incremented by `notify_update` for each published update.
update_seq: AtomicU64,
/// Sending half of the sequence-number channel.
update_tx: watch::Sender<u64>,
/// Receiving half of the sequence-number channel; cloned by `updated()`.
update_rx: watch::Receiver<u64>,
/// Join handle of the spawned supervisor task; taken on shutdown.
supervisor: Mutex<Option<JoinHandle<()>>>,
}
/// Read-only accessors for sibling modules.
impl Inner {
/// Returns a copy of the stored reconnect configuration.
pub(super) const fn reconnect(&self) -> ReconnectConfig {
self.reconnect
}
/// Returns the metrics recorder, if one was configured.
pub(super) fn metrics(&self) -> Option<&dyn MetricsRecorder> {
self.metrics.as_deref()
}
/// Returns a reference to the stored client factory.
pub(super) fn make_client(&self) -> &ClientFactory {
&self.make_client
}
}
/// Cancels the shared token when the state is dropped.
///
/// NOTE(review): the supervisor task spawned in `build_with` owns its own
/// `Arc<Inner>` clone, so this destructor cannot run while that task is still
/// alive. Dropping every `X509Source` handle therefore does not appear to stop
/// the supervisor on its own — confirm, and prefer an explicit shutdown call.
impl Drop for Inner {
fn drop(&mut self) {
self.cancel.cancel();
}
}
/// Debug output for the shared state.
///
/// The struct is rendered under the name `X509Source`; fields that are not
/// printed directly are replaced by fixed placeholder strings.
impl Debug for Inner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Optional trait objects are shown only as present/absent markers.
        let picker = self.svid_picker.is_some().then_some("<SvidPicker>");
        let recorder = self.metrics.is_some().then_some("<MetricsRecorder>");

        let mut out = f.debug_struct("X509Source");
        out.field("x509_context", &"<ArcSwap<X509Context>>");
        out.field("svid_picker", &picker);
        out.field("reconnect", &self.reconnect);
        out.field("limits", &self.limits);
        out.field("make_client", &"<ClientFactory>");
        out.field("metrics", &recorder);
        out.field("shutdown_timeout", &self.shutdown_timeout);
        out.field("closed", &self.closed.load(Ordering::Relaxed));
        out.field("cancel", &self.cancel);
        out.field("update_seq", &self.update_seq);
        out.field("update_tx", &"<watch::Sender<u64>>");
        out.field("update_rx", &"<watch::Receiver<u64>>");
        out.field("supervisor", &"<Mutex<Option<JoinHandle<()>>>>");
        out.finish()
    }
}
/// Public API of the source: construction, reads and shutdown.
impl X509Source {
    /// Builds a source from a default-configured [`X509SourceBuilder`].
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the builder's `build` step.
    pub async fn new() -> Result<Self, X509SourceError> {
        Self::builder().build().await
    }

    /// Returns a new [`X509SourceBuilder`].
    pub fn builder() -> X509SourceBuilder {
        X509SourceBuilder::new()
    }

    /// Returns a handle for observing the update sequence number.
    pub fn updated(&self) -> X509SourceUpdates {
        let rx = self.inner.update_rx.clone();
        X509SourceUpdates { rx }
    }

    /// Reports whether the source is open and an SVID can currently be
    /// selected from the stored context.
    pub fn is_healthy(&self) -> bool {
        let state = &self.inner;
        if state.closed.load(Ordering::Acquire) || state.cancel.is_cancelled() {
            return false;
        }
        let snapshot = state.x509_context.load();
        select_svid(&snapshot, state.svid_picker.as_deref()).is_some()
    }

    /// Returns the current X.509 context.
    ///
    /// # Errors
    ///
    /// Returns [`X509SourceError::Closed`] when the source is closed or
    /// cancelled.
    pub fn x509_context(&self) -> Result<Arc<X509Context>, X509SourceError> {
        self.assert_open()
            .map(|()| self.inner.x509_context.load_full())
    }

    /// Returns the SVID selected from the current context.
    ///
    /// # Errors
    ///
    /// Returns [`X509SourceError::Closed`] when the source is closed or
    /// cancelled, and [`X509SourceError::NoSuitableSvid`] (also recorded as a
    /// metric) when no SVID can be selected.
    pub fn svid(&self) -> Result<Arc<X509Svid>, X509SourceError> {
        self.assert_open()?;
        let snapshot = self.inner.x509_context.load();
        match select_svid(&snapshot, self.inner.svid_picker.as_deref()) {
            Some(selected) => Ok(selected),
            None => {
                self.inner.record_error(MetricsErrorKind::NoSuitableSvid);
                Err(X509SourceError::NoSuitableSvid)
            }
        }
    }

    /// Like [`X509Source::svid`], but collapses every error into `None`.
    pub fn try_svid(&self) -> Option<Arc<X509Svid>> {
        self.svid().ok()
    }

    /// Returns the bundle set of the current context.
    ///
    /// # Errors
    ///
    /// Returns [`X509SourceError::Closed`] when the source is closed or
    /// cancelled.
    pub fn bundle_set(&self) -> Result<Arc<X509BundleSet>, X509SourceError> {
        self.assert_open()?;
        let snapshot = self.inner.x509_context.load();
        Ok(Arc::clone(snapshot.bundle_set()))
    }

    /// Looks up the bundle for `td`, collapsing every error into `None`.
    pub fn try_bundle_for_trust_domain(&self, td: &TrustDomain) -> Option<Arc<X509Bundle>> {
        self.bundle_for_trust_domain(td).unwrap_or(None)
    }

    /// Closes the source and waits for the supervisor task to finish.
    ///
    /// Only the first shutdown call does any work; later calls return
    /// immediately. A join failure is logged and recorded as a metric.
    pub async fn shutdown(&self) {
        let already_closed = self.inner.closed.swap(true, Ordering::AcqRel);
        if already_closed {
            return;
        }
        self.inner.cancel.cancel();

        if let Some(handle) = self.inner.supervisor.lock().await.take() {
            if let Err(e) = handle.await {
                warn!("Error joining supervisor task during shutdown: error={e}");
                self.inner
                    .record_error(MetricsErrorKind::SupervisorJoinFailed);
            }
        }
    }

    /// Closes the source and waits at most `timeout` for the supervisor task.
    ///
    /// Only the first shutdown call does any work; later calls return
    /// `Ok(())` immediately. A join failure within the deadline is logged,
    /// recorded as a metric and still reported as `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns [`X509SourceError::ShutdownTimeout`] when the task does not
    /// finish in time; the task is then aborted and awaited.
    pub async fn shutdown_with_timeout(&self, timeout: Duration) -> Result<(), X509SourceError> {
        let already_closed = self.inner.closed.swap(true, Ordering::AcqRel);
        if already_closed {
            return Ok(());
        }
        self.inner.cancel.cancel();

        let Some(mut handle) = self.inner.supervisor.lock().await.take() else {
            return Ok(());
        };

        let joined = tokio::time::timeout(timeout, &mut handle).await;
        let Ok(join_result) = joined else {
            warn!("Shutdown timeout exceeded; aborting supervisor task");
            handle.abort();
            let _unused: Result<_, _> = handle.await;
            return Err(X509SourceError::ShutdownTimeout);
        };

        if let Err(e) = join_result {
            warn!("Error joining supervisor task during shutdown: error={e}");
            self.inner
                .record_error(MetricsErrorKind::SupervisorJoinFailed);
        }
        Ok(())
    }

    /// Shuts down using the configured timeout when one was set, otherwise
    /// without a deadline.
    ///
    /// # Errors
    ///
    /// Returns [`X509SourceError::ShutdownTimeout`] when a configured timeout
    /// elapses.
    pub async fn shutdown_configured(&self) -> Result<(), X509SourceError> {
        match self.inner.shutdown_timeout {
            Some(limit) => self.shutdown_with_timeout(limit).await,
            None => {
                self.shutdown().await;
                Ok(())
            }
        }
    }
}
/// Construction helpers and internal state checks.
impl X509Source {
    /// Performs the initial sync, assembles the shared state and spawns the
    /// supervisor task.
    ///
    /// The reconnect configuration is normalised before use.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `initial_sync_with_retry`.
    pub(super) async fn build_with(
        make_client: ClientFactory,
        svid_picker: Option<Box<dyn SvidPicker>>,
        reconnect: ReconnectConfig,
        limits: ResourceLimits,
        metrics: Option<Arc<dyn MetricsRecorder>>,
        shutdown_timeout: Option<Duration>,
    ) -> Result<Self, X509SourceError> {
        let reconnect = super::builder::normalize_reconnect(reconnect);
        let (seq_tx, seq_rx) = watch::channel(0u64);
        let cancel = CancellationToken::new();

        let first_ctx = initial_sync_with_retry(
            &make_client,
            svid_picker.as_deref(),
            &cancel,
            reconnect,
            limits,
            metrics.as_deref(),
        )
        .await?;

        let source = Self {
            inner: Arc::new(Inner {
                x509_context: ArcSwap::from(first_ctx),
                svid_picker,
                reconnect,
                make_client,
                limits,
                metrics,
                shutdown_timeout,
                closed: AtomicBool::new(false),
                cancel,
                update_seq: AtomicU64::new(0),
                update_tx: seq_tx,
                update_rx: seq_rx,
                supervisor: Mutex::new(None),
            }),
        };

        // The supervisor task gets its own reference to the shared state and
        // its own clone of the cancellation token.
        let supervisor_state = Arc::clone(&source.inner);
        let supervisor_cancel = supervisor_state.cancel.clone();
        let handle = tokio::spawn(async move {
            supervisor_state
                .run_update_supervisor(supervisor_cancel)
                .await;
        });
        *source.inner.supervisor.lock().await = Some(handle);

        Ok(source)
    }

    /// Builds a source for tests around `initial_ctx` without spawning a
    /// supervisor task.
    ///
    /// The client factory always fails with `WorkloadApiError::EmptyResponse`
    /// and no shutdown timeout is configured.
    #[cfg(test)]
    pub(super) fn new_for_test(
        initial_ctx: Arc<X509Context>,
        reconnect: ReconnectConfig,
        limits: ResourceLimits,
        metrics: Option<Arc<dyn MetricsRecorder>>,
        svid_picker: Option<Box<dyn SvidPicker>>,
    ) -> Self {
        let reconnect = super::builder::normalize_reconnect(reconnect);
        let (seq_tx, seq_rx) = watch::channel(0u64);
        let cancel = CancellationToken::new();
        let make_client: ClientFactory =
            Arc::new(|| Box::pin(async move { Err(WorkloadApiError::EmptyResponse) }));

        Self {
            inner: Arc::new(Inner {
                x509_context: ArcSwap::from(initial_ctx),
                svid_picker,
                reconnect,
                make_client,
                limits,
                metrics,
                shutdown_timeout: None,
                closed: AtomicBool::new(false),
                cancel,
                update_seq: AtomicU64::new(0),
                update_tx: seq_tx,
                update_rx: seq_rx,
                supervisor: Mutex::new(None),
            }),
        }
    }

    /// Succeeds only while the source is neither closed nor cancelled.
    ///
    /// # Errors
    ///
    /// Returns [`X509SourceError::Closed`] otherwise.
    fn assert_open(&self) -> Result<(), X509SourceError> {
        let is_open =
            !self.inner.closed.load(Ordering::Acquire) && !self.inner.cancel.is_cancelled();
        if is_open {
            Ok(())
        } else {
            Err(X509SourceError::Closed)
        }
    }
}
/// Metrics helpers and the update path.
impl Inner {
    /// Forwards an error of the given kind to the metrics recorder, if any.
    pub(super) fn record_error(&self, kind: MetricsErrorKind) {
        if let Some(recorder) = &self.metrics {
            recorder.record_error(kind);
        }
    }

    /// Reports a successful update to the metrics recorder, if any.
    pub(super) fn record_update(&self) {
        if let Some(recorder) = &self.metrics {
            recorder.record_update();
        }
    }

    /// Validates `new_ctx` and, on success, stores it, records the update and
    /// publishes a new sequence number.
    ///
    /// # Errors
    ///
    /// Returns the validation error unchanged after recording an
    /// `UpdateRejected` metric; the stored context is left untouched.
    pub(super) fn apply_update(&self, new_ctx: Arc<X509Context>) -> Result<(), X509SourceError> {
        if let Err(rejection) = self.validate_and_select(&new_ctx) {
            self.record_error(MetricsErrorKind::UpdateRejected);
            return Err(rejection);
        }

        self.x509_context.store(new_ctx);
        self.record_update();
        self.notify_update();
        Ok(())
    }

    /// Increments the sequence counter and publishes the new value on the
    /// watch channel. A failed send is deliberately ignored.
    pub(super) fn notify_update(&self) {
        let previous = self.update_seq.fetch_add(1, Ordering::Relaxed);
        let _unused: Result<_, _> = self.update_tx.send(previous + 1);
    }

    /// Runs `validate_context` on `ctx` with this source's picker, limits and
    /// metrics recorder.
    ///
    /// # Errors
    ///
    /// Returns whatever error `validate_context` reports.
    pub(super) fn validate_and_select(&self, ctx: &X509Context) -> Result<(), X509SourceError> {
        let picker = self.svid_picker.as_deref();
        let recorder = self.metrics.as_deref();
        validate_context(ctx, picker, self.limits, recorder)
    }
}
/// Exposes the source's SVID through the [`SvidSource`] trait.
impl SvidSource for X509Source {
type Item = X509Svid;
type Error = X509SourceError;
/// Delegates to the inherent `X509Source::svid` method.
fn svid(&self) -> Result<Arc<Self::Item>, Self::Error> {
// `Self::svid` resolves to the inherent method (inherent items take
// precedence over trait items in path lookup), so this does not recurse.
Self::svid(self)
}
}
/// Exposes the source's bundles through the [`BundleSource`] trait.
impl BundleSource for X509Source {
    type Item = X509Bundle;
    type Error = X509SourceError;

    /// Looks up the bundle for `trust_domain` in the current context.
    ///
    /// # Errors
    ///
    /// Returns [`X509SourceError::Closed`] when the source is closed or
    /// cancelled.
    fn bundle_for_trust_domain(
        &self,
        trust_domain: &TrustDomain,
    ) -> Result<Option<Arc<Self::Item>>, Self::Error> {
        self.assert_open()?;
        let snapshot = self.inner.x509_context.load();
        let bundle = snapshot.bundle_set().get(trust_domain);
        Ok(bundle)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
use std::sync::Mutex;
/// `wait_for` returns at once when the current value already satisfies the
/// predicate, both for the initial value and for a value sent beforehand.
#[tokio::test]
async fn test_wait_for_immediate_satisfaction() {
    let (tx, rx) = watch::channel(5u64);
    let mut updates = X509SourceUpdates { rx };

    // The initial value (5) already passes the predicate.
    let first = updates.wait_for(|&seq| seq > 3).await;
    assert!(matches!(first, Ok(5)));

    // A value sent before the call is visible as the current value.
    let _unused: Result<_, _> = tx.send(10);
    let second = updates.wait_for(|&seq| seq > 8).await;
    assert!(matches!(second, Ok(10)));
}
/// `wait_for` keeps waiting until a later value satisfies the predicate.
#[tokio::test]
async fn test_wait_for_waits_when_not_satisfied() {
    let (tx, rx) = watch::channel(1u64);
    let mut updates = X509SourceUpdates { rx };

    // Publish a satisfying value shortly after the wait has started.
    let delayed_tx = tx.clone();
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(50)).await;
        let _unused: Result<_, _> = delayed_tx.send(5);
    });

    let pending = updates.wait_for(|&seq| seq > 3);
    let outcome = tokio::time::timeout(Duration::from_secs(1), pending)
        .await
        .expect("Should complete within timeout");
    assert!(matches!(outcome, Ok(5)));
}
/// `changed` resolves once a new sequence number is sent, and `last` then
/// reports that same value.
#[tokio::test]
async fn test_updated_only_notifies_on_rotations_after_initial_sync() {
    let (tx, rx) = watch::channel(0u64);
    let mut updates = X509SourceUpdates { rx: rx.clone() };

    let delayed_tx = tx.clone();
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(50)).await;
        let _unused: Result<_, _> = delayed_tx.send(1);
    });

    let outcome = tokio::time::timeout(Duration::from_secs(1), updates.changed())
        .await
        .expect("Should complete within timeout");
    assert!(matches!(outcome, Ok(1)));
    assert_eq!(updates.last(), 1);
}
/// A handle over a channel seeded with zero reports zero from `last`.
#[tokio::test]
async fn test_updated_initial_sequence_is_zero() {
    // The sender is kept alive for the duration of the test.
    let (_tx, rx) = watch::channel(0u64);
    let initial = X509SourceUpdates { rx }.last();
    assert_eq!(initial, 0);
}
/// Metrics recorder for tests that counts reported errors per kind.
struct TestMetricsRecorder {
    /// Number of `record_error` calls seen for each error kind.
    counts: Arc<Mutex<HashMap<MetricsErrorKind, u64>>>,
}

impl TestMetricsRecorder {
    /// Creates a recorder with no recorded errors.
    fn new() -> Self {
        let counts = Arc::new(Mutex::new(HashMap::new()));
        Self { counts }
    }

    /// Returns how many times `kind` was recorded (zero if never).
    fn count(&self, kind: MetricsErrorKind) -> u64 {
        let table = self.counts.lock().unwrap();
        table.get(&kind).copied().unwrap_or(0)
    }
}

impl MetricsRecorder for TestMetricsRecorder {
    /// Updates are not tracked by this recorder.
    fn record_update(&self) {}

    /// Reconnects are not tracked by this recorder.
    fn record_reconnect(&self) {}

    /// Increments the counter for `kind`.
    fn record_error(&self, kind: MetricsErrorKind) {
        let mut table = self.counts.lock().unwrap();
        *table.entry(kind).or_default() += 1;
    }
}
/// A context rejected by `apply_update` for exceeding the bundle limit is
/// counted once as `LimitMaxBundles` and once as `UpdateRejected`.
#[test]
fn test_metrics_recorded_exactly_once_per_rejected_update() {
    use super::super::builder::{ReconnectConfig, ResourceLimits};
    use crate::workload_api::x509_context::X509Context;
    use crate::{TrustDomain, X509Bundle, X509BundleSet, X509Svid};
    use std::sync::Arc;

    let metrics = Arc::new(TestMetricsRecorder::new());

    // A bundle limit of zero makes any context with a bundle fail validation.
    let limits = ResourceLimits {
        max_svids: Some(100),
        max_bundles: Some(0),
        max_bundle_der_bytes: Some(1000),
    };

    let svid = Arc::new(
        X509Svid::parse_from_der(
            include_bytes!("../../tests/testdata/svid/x509/1-svid-chain.der"),
            include_bytes!("../../tests/testdata/svid/x509/1-key.der"),
        )
        .unwrap(),
    );

    let mut bundle_set = X509BundleSet::new();
    bundle_set.add_bundle(X509Bundle::new(TrustDomain::new("example.org").unwrap()));
    let oversized_ctx = X509Context::new([svid], Arc::new(bundle_set));

    let recorder = Arc::clone(&metrics);
    let source = X509Source::new_for_test(
        Arc::new(X509Context::new([], Arc::new(X509BundleSet::new()))),
        ReconnectConfig::default(),
        limits,
        Some(recorder),
        None,
    );

    let outcome = source.inner.apply_update(Arc::new(oversized_ctx));
    assert!(matches!(
        outcome,
        Err(X509SourceError::ResourceLimitExceeded {
            kind: super::super::errors::LimitKind::MaxBundles,
            ..
        })
    ));

    assert_eq!(metrics.count(MetricsErrorKind::LimitMaxBundles), 1);
    assert_eq!(metrics.count(MetricsErrorKind::UpdateRejected), 1);
    assert_eq!(metrics.count(MetricsErrorKind::LimitMaxSvids), 0);
    assert_eq!(metrics.count(MetricsErrorKind::LimitMaxBundleDerBytes), 0);
}
/// A reconnect configuration whose minimum exceeds its maximum is stored with
/// the two values swapped.
#[test]
fn test_new_with_normalizes_reconnect_config() {
    use super::super::builder::{ReconnectConfig, ResourceLimits};
    use crate::workload_api::x509_context::X509Context;
    use crate::{X509BundleSet, X509Svid};
    use std::sync::Arc;
    use std::time::Duration;

    let svid = Arc::new(
        X509Svid::parse_from_der(
            include_bytes!("../../tests/testdata/svid/x509/1-svid-chain.der"),
            include_bytes!("../../tests/testdata/svid/x509/1-key.der"),
        )
        .unwrap(),
    );
    let initial_ctx = Arc::new(X509Context::new([svid], Arc::new(X509BundleSet::new())));

    // Deliberately inverted: the minimum is larger than the maximum.
    let inverted = ReconnectConfig {
        min_backoff: Duration::from_secs(10),
        max_backoff: Duration::from_secs(1),
    };

    let source =
        X509Source::new_for_test(initial_ctx, inverted, ResourceLimits::default(), None, None);

    let stored = source.inner.reconnect;
    assert_eq!(stored.min_backoff, Duration::from_secs(1));
    assert_eq!(stored.max_backoff, Duration::from_secs(10));
}
/// Calling `validate_context` directly records only the limit-specific
/// metric; `UpdateRejected` stays at zero.
#[test]
fn test_initial_sync_validation_records_correct_metrics() {
    use super::super::builder::ResourceLimits;
    use super::super::limits::validate_context;
    use crate::workload_api::x509_context::X509Context;
    use crate::{TrustDomain, X509Bundle, X509BundleSet, X509Svid};
    use std::sync::Arc;

    let metrics = Arc::new(TestMetricsRecorder::new());

    // A bundle limit of zero makes any context with a bundle fail validation.
    let limits = ResourceLimits {
        max_svids: Some(100),
        max_bundles: Some(0),
        max_bundle_der_bytes: Some(1000),
    };

    let svid = Arc::new(
        X509Svid::parse_from_der(
            include_bytes!("../../tests/testdata/svid/x509/1-svid-chain.der"),
            include_bytes!("../../tests/testdata/svid/x509/1-key.der"),
        )
        .unwrap(),
    );

    let mut bundle_set = X509BundleSet::new();
    bundle_set.add_bundle(X509Bundle::new(TrustDomain::new("example.org").unwrap()));
    let oversized_ctx = X509Context::new([svid], Arc::new(bundle_set));

    let outcome = validate_context(&oversized_ctx, None, limits, Some(metrics.as_ref()));
    assert!(matches!(
        outcome,
        Err(X509SourceError::ResourceLimitExceeded {
            kind: super::super::errors::LimitKind::MaxBundles,
            ..
        })
    ));

    assert_eq!(metrics.count(MetricsErrorKind::LimitMaxBundles), 1);
    assert_eq!(metrics.count(MetricsErrorKind::UpdateRejected), 0);
    assert_eq!(metrics.count(MetricsErrorKind::LimitMaxSvids), 0);
    assert_eq!(metrics.count(MetricsErrorKind::LimitMaxBundleDerBytes), 0);
}
/// `ResourceLimits::unlimited` leaves every limit unset.
#[test]
fn test_resource_limits_unlimited() {
    use super::super::builder::ResourceLimits;

    let limits = ResourceLimits::unlimited();
    assert_eq!(limits.max_svids, None);
    assert_eq!(limits.max_bundles, None);
    assert_eq!(limits.max_bundle_der_bytes, None);
}
/// `ResourceLimits::default_limits` yields 100 SVIDs, 200 bundles and 4 MiB of
/// bundle DER bytes.
#[test]
fn test_resource_limits_default_limits() {
    use super::super::builder::ResourceLimits;

    let defaults = ResourceLimits::default_limits();
    assert_eq!(defaults.max_svids, Some(100));
    assert_eq!(defaults.max_bundles, Some(200));
    assert_eq!(defaults.max_bundle_der_bytes, Some(4 * 1024 * 1024));
}
/// Limits can be set and left unset independently of each other.
#[test]
fn test_resource_limits_mixed() {
    use super::super::builder::ResourceLimits;

    let limits = ResourceLimits {
        max_svids: Some(50),
        max_bundles: None,
        max_bundle_der_bytes: Some(1024 * 1024),
    };

    assert_eq!(limits.max_svids, Some(50));
    assert_eq!(limits.max_bundles, None);
    assert_eq!(limits.max_bundle_der_bytes, Some(1024 * 1024));
}
}