use std::{
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
time::Duration,
};
use arc_swap::ArcSwap;
use rustls::{ClientConfig, client::VerifierBuilderError};
use thiserror::Error;
use tokio::time::sleep;
use tokio_stream::{Stream, StreamExt};
#[cfg(feature = "tracing")]
use tracing::{debug, error, info};
#[derive(Debug, Error)]
pub enum ClientConfigStreamError {
#[error("stream provider error")]
StreamError(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("empty stream")]
EmptyStream,
#[error("could not build stream")]
StreamBuilderError(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("cert verifier builder error")]
VerifierBuilderError(#[from] VerifierBuilderError),
#[error("missing client certified key")]
MissingCertifiedKey,
#[error("missing root certificates")]
MissingRoots,
#[error("rustls error")]
RustlsError(#[from] rustls::Error),
}
/// Factory for streams of rustls [`ClientConfig`]s consumed by [`ClientConfigProvider`].
///
/// [`ClientConfigProvider::start`] calls [`build`](Self::build) once to obtain the
/// initial stream, and its background task calls it again whenever the current stream
/// yields an error or ends.
pub trait ClientConfigStreamBuilder {
    /// Stream of configs; every `Ok` item replaces the provider's current config.
    type ConfigStream: 'static
        + Unpin
        + Send
        + Sync
        + Stream<Item = Result<Arc<ClientConfig>, ClientConfigStreamError>>;

    /// Asynchronously creates a fresh config stream.
    fn build(
        &mut self,
    ) -> impl std::future::Future<Output = Result<Self::ConfigStream, ClientConfigStreamError>> + Send;
}
/// Holds the most recent rustls [`ClientConfig`] delivered by a config stream.
///
/// Created with [`ClientConfigProvider::start`], which also spawns the task that keeps
/// the stored config up to date.
#[derive(Debug)]
pub struct ClientConfigProvider {
    /// Current config, replaced atomically whenever the stream yields a new one.
    inner: ArcSwap<ClientConfig>,
    /// `false` while the background task is replacing a stream that errored or ended.
    stream_healthy: AtomicBool,
}
impl ClientConfigProvider {
pub async fn start<B>(mut builder: B) -> Result<Arc<Self>, ClientConfigStreamError>
where
B: ClientConfigStreamBuilder + Send + 'static,
{
let mut stream = builder.build().await?;
let initial = stream
.next()
.await
.ok_or(ClientConfigStreamError::EmptyStream)??;
let this = Arc::new(Self {
inner: ArcSwap::from(initial),
stream_healthy: AtomicBool::new(true),
});
let ret = this.clone();
tokio::spawn(async move {
let initial_delay = Duration::from_millis(10);
let mut delay = initial_delay;
let max_delay = Duration::from_secs(10);
loop {
match stream.next().await {
Some(Ok(client_config)) => {
this.inner.store(client_config);
#[cfg(feature = "tracing")]
debug!("stored updated client config from stream");
}
Some(Err(_)) | None => {
this.stream_healthy.store(false, Ordering::Relaxed);
#[cfg(feature = "tracing")]
error!("config stream returned error or none, trying to build new stream");
match builder.build().await {
Ok(s) => {
this.stream_healthy.store(true, Ordering::Relaxed);
delay = initial_delay;
stream = s;
#[cfg(feature = "tracing")]
info!("reestablished client config stream");
}
Err(err) => {
#[cfg(feature = "tracing")]
error!(retry_in_ms = delay.as_millis(), error = %err, "failed to reestablish client config stream");
sleep(delay).await;
delay = (delay * 2).min(max_delay);
}
};
}
}
}
});
Ok(ret)
}
pub fn stream_healthy(&self) -> bool {
self.stream_healthy.load(Ordering::Relaxed)
}
pub fn get_config(&self) -> Arc<ClientConfig> {
self.inner.load_full()
}
}
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use rustls::{ClientConfig, RootCertStore};
    use thiserror::Error;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::ReceiverStream;

    use super::{ClientConfigProvider, ClientConfigStreamBuilder, ClientConfigStreamError};

    /// Item type carried by the mock config streams.
    type ConfigItem = Result<Arc<ClientConfig>, ClientConfigStreamError>;

    /// Minimal error type used to populate the boxed error variants.
    #[derive(Error, Debug)]
    #[error("{0}")]
    struct MockError(&'static str);

    /// Builds a fresh, distinct `ClientConfig` with no roots and no client auth.
    fn empty_client_config() -> Arc<ClientConfig> {
        Arc::new(
            ClientConfig::builder()
                .with_root_certificates(RootCertStore::empty())
                .with_no_client_auth(),
        )
    }

    /// Builds a `StreamError` wrapping a `MockError` with the given message.
    fn stream_error(msg: &'static str) -> ClientConfigStreamError {
        ClientConfigStreamError::StreamError(MockError(msg).into())
    }

    /// Builder that hands out pre-created channel receivers in order and counts every
    /// `build` call. Once the receivers are exhausted, `build` fails with
    /// `StreamBuilderError`.
    #[derive(Debug)]
    struct MockClientConfigStreamBuilder {
        streams: VecDeque<mpsc::Receiver<ConfigItem>>,
        builds: Arc<AtomicUsize>,
    }

    impl MockClientConfigStreamBuilder {
        fn new(streams: Vec<mpsc::Receiver<ConfigItem>>) -> Self {
            Self {
                streams: streams.into(),
                builds: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    impl ClientConfigStreamBuilder for MockClientConfigStreamBuilder {
        type ConfigStream = ReceiverStream<ConfigItem>;

        async fn build(&mut self) -> Result<Self::ConfigStream, ClientConfigStreamError> {
            self.builds.fetch_add(1, Ordering::SeqCst);
            // `&mut self` already gives exclusive access; no lock needed.
            let rx = self.streams.pop_front().ok_or_else(|| {
                ClientConfigStreamError::StreamBuilderError(MockError("mock stream error").into())
            })?;
            Ok(ReceiverStream::new(rx))
        }
    }

    #[tokio::test]
    async fn start_fails_given_initial_stream_build_failure() {
        let builder = MockClientConfigStreamBuilder::new(vec![]);
        let res = ClientConfigProvider::start(builder).await;
        assert!(
            matches!(res, Err(ClientConfigStreamError::StreamBuilderError(_))),
            "expected ClientConfigStreamError::StreamBuilderError"
        );
    }

    #[tokio::test]
    async fn start_fails_when_stream_is_empty() {
        let (tx, rx) = mpsc::channel(1);
        drop(tx);
        let builder = MockClientConfigStreamBuilder::new(vec![rx]);
        let res = ClientConfigProvider::start(builder).await;
        assert!(
            matches!(res, Err(ClientConfigStreamError::EmptyStream)),
            "expected ClientConfigStreamError::EmptyStream"
        );
    }

    #[tokio::test]
    async fn start_fails_when_first_result_is_err() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockClientConfigStreamBuilder::new(vec![rx]);
        tx.send(Err(stream_error("fake error"))).await.unwrap();
        match ClientConfigProvider::start(builder).await {
            Err(ClientConfigStreamError::StreamError(err)) => {
                assert_eq!(err.to_string(), "fake error");
            }
            _ => panic!("expected ClientConfigStreamError::StreamError"),
        }
    }

    #[tokio::test]
    async fn start_and_initial_config_is_loaded() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockClientConfigStreamBuilder::new(vec![rx]);
        let expected = empty_client_config();
        tx.send(Ok(Arc::clone(&expected))).await.unwrap();
        let provider = ClientConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &expected));
        assert!(provider.stream_healthy());
    }

    #[tokio::test]
    async fn single_stream_config_hot_swap() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockClientConfigStreamBuilder::new(vec![rx]);
        let initial = empty_client_config();
        tx.send(Ok(Arc::clone(&initial))).await.unwrap();
        let provider = ClientConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
        assert!(provider.stream_healthy());

        for i in 0..10 {
            let expected = empty_client_config();
            tx.send(Ok(Arc::clone(&expected))).await.unwrap();
            tokio::task::yield_now().await;
            assert!(
                Arc::ptr_eq(&provider.get_config(), &expected),
                "config not updated on iter {i}"
            );
            assert!(provider.stream_healthy());
        }
    }

    #[tokio::test]
    async fn stream_failure_triggers_rebuild() {
        let (tx1, rx1) = mpsc::channel(1);
        let (tx2, rx2) = mpsc::channel(1);
        let builder = MockClientConfigStreamBuilder::new(vec![rx1, rx2]);
        let builds = Arc::clone(&builder.builds);
        let initial = empty_client_config();
        tx1.send(Ok(Arc::clone(&initial))).await.unwrap();
        let provider = ClientConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
        assert!(provider.stream_healthy());

        tx1.send(Err(stream_error("fake error"))).await.unwrap();
        tokio::task::yield_now().await;
        assert_eq!(builds.load(Ordering::SeqCst), 2);

        let new = empty_client_config();
        tx2.send(Ok(Arc::clone(&new))).await.unwrap();
        tokio::task::yield_now().await;
        assert!(provider.stream_healthy());
        assert!(Arc::ptr_eq(&provider.get_config(), &new));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn stream_rebuild_goes_into_backoff() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockClientConfigStreamBuilder::new(vec![rx]);
        let builds = Arc::clone(&builder.builds);
        let initial = empty_client_config();
        tx.send(Ok(Arc::clone(&initial))).await.unwrap();
        let provider = ClientConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
        assert!(provider.stream_healthy());
        assert_eq!(builds.load(Ordering::SeqCst), 1);

        tx.send(Err(stream_error("fake error"))).await.unwrap();
        tokio::task::yield_now().await;
        // The only receiver was consumed, so the rebuild fails and the task backs off.
        assert_eq!(builds.load(Ordering::SeqCst), 2);
        assert!(!provider.stream_healthy());
    }
}