use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, watch};
use tokio::task::JoinHandle;
use super::{Config, ConfigBuilder};
use crate::error::ConfigError;
/// A configuration that a background task keeps re-resolving from a
/// [`ConfigBuilder`], publishing each changed value through a
/// `tokio::sync::watch` channel.
///
/// Readers either take a snapshot with `current` or obtain a receiver with
/// `subscribe` to be notified of swaps. Dropping this value aborts the
/// background task.
pub struct ReloadableConfig {
    // Shared with the background task; also used by `reload_now`.
    builder: Arc<ConfigBuilder>,
    // Publishing side of the channel, used by `reload_now` (the background
    // task owns its own clone).
    tx: watch::Sender<Arc<Config>>,
    // Receiver kept for `current` reads and cloned by `subscribe`.
    rx: watch::Receiver<Arc<Config>>,
    // Handle of the background reload task; aborted in `Drop`.
    task: JoinHandle<()>,
}
impl ReloadableConfig {
pub async fn start(builder: ConfigBuilder, interval: Duration) -> Result<Self, ConfigError> {
let builder = Arc::new(builder.ensure_defaults()?);
let initial = Arc::new(builder.resolve().await?);
let (tx, rx) = watch::channel(initial);
let task = {
let builder = Arc::clone(&builder);
let tx = tx.clone();
tokio::spawn(async move {
let mut ticker = tokio::time::interval(interval);
ticker.tick().await; loop {
ticker.tick().await;
match builder.resolve().await {
Ok(next) => swap_if_changed(&tx, next),
Err(error) => {
tracing::warn!(%error, "config reload failed; keeping current values");
}
}
}
})
};
Ok(Self { builder, tx, rx, task })
}
#[must_use]
pub fn current(&self) -> Arc<Config> {
self.rx.borrow().clone()
}
#[must_use]
pub fn subscribe(&self) -> watch::Receiver<Arc<Config>> {
self.rx.clone()
}
pub async fn reload_now(&self) -> Result<(), ConfigError> {
let next = self.builder.resolve().await?;
swap_if_changed(&self.tx, next);
Ok(())
}
pub async fn start_with_refresh(
builder: ConfigBuilder,
interval: Duration,
) -> Result<(Self, RefreshTrigger), ConfigError> {
let builder = Arc::new(builder.ensure_defaults()?);
let initial = Arc::new(builder.resolve().await?);
let (tx, rx) = watch::channel(initial);
let (trigger_tx, mut trigger_rx) = mpsc::channel::<()>(1);
let task = {
let builder = Arc::clone(&builder);
let tx = tx.clone();
tokio::spawn(async move {
let mut ticker = tokio::time::interval(interval);
ticker.tick().await; let mut triggers_open = true;
loop {
tokio::select! {
_ = ticker.tick() => {}
signal = trigger_rx.recv(), if triggers_open => {
if signal.is_none() {
triggers_open = false;
continue;
}
}
}
match builder.resolve().await {
Ok(next) => swap_if_changed(&tx, next),
Err(error) => {
tracing::warn!(%error, "config reload failed; keeping current values");
}
}
}
})
};
Ok((Self { builder, tx, rx, task }, RefreshTrigger(trigger_tx)))
}
}
/// Handle for requesting an immediate configuration reload from the task
/// started by [`ReloadableConfig::start_with_refresh`].
///
/// All clones feed the same background task. Once every clone has been
/// dropped, the task stops listening for requests but keeps its periodic
/// timer.
#[derive(Clone, Debug)]
pub struct RefreshTrigger(mpsc::Sender<()>);

impl RefreshTrigger {
    /// Asks the background task to reload as soon as possible.
    ///
    /// Never blocks or waits. If a request is already queued, this one is
    /// dropped, because the queued request will already pick up the latest
    /// values. If the background task has exited, the request is also dropped.
    pub fn refresh(&self) {
        // `try_send` fails only when the single-slot queue is full or the
        // receiver is gone; both cases are intentionally ignored.
        let _ = self.0.try_send(());
    }
}
impl Drop for ReloadableConfig {
fn drop(&mut self) {
self.task.abort();
}
}
/// Publishes `next` to every subscriber, but only if its values differ from
/// those of the currently published configuration.
///
/// The comparison and the swap run as one critical section
/// (`send_if_modified` holds the channel's write lock). The background task
/// and `reload_now` therefore cannot both compare against the same stale
/// value and emit duplicate notifications. The log line is emitted after the
/// lock is released.
fn swap_if_changed(tx: &watch::Sender<Arc<Config>>, next: Config) {
    let swapped = tx.send_if_modified(|current| {
        if current.values() == next.values() {
            // Unchanged: leave the value alone and do not wake subscribers.
            return false;
        }
        *current = Arc::new(next);
        true
    });
    if swapped {
        tracing::info!("configuration changed; swapping in new values");
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Profile;
    use crate::config::map::ConfigMap;
    use crate::config::provider::{ConfigProvider, MemoryProvider, ProviderKind};
    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Period long enough that the timer never fires during a test.
    const NEVER: Duration = Duration::from_secs(3600);
    /// Period short enough that the timer fires almost immediately.
    const FAST: Duration = Duration::from_millis(20);
    /// Upper bound on how long a test waits for a change notification.
    const DEADLINE: Duration = Duration::from_secs(2);

    /// Provider yielding `{"version": n}`, where `n` grows by one per load.
    #[derive(Clone, Default)]
    struct CountingProvider {
        counter: Arc<AtomicU64>,
    }

    #[async_trait]
    impl ConfigProvider for CountingProvider {
        fn name(&self) -> String {
            String::from("counting")
        }

        fn kind(&self) -> ProviderKind {
            ProviderKind::Memory
        }

        async fn load(&self) -> Result<ConfigMap, ConfigError> {
            let previous = self.counter.fetch_add(1, Ordering::SeqCst);
            let entries = [(String::from("version"), json!(previous))];
            Ok(ConfigMap::from_iter(entries))
        }
    }

    /// Builder whose only source is a fresh `CountingProvider`.
    fn counting_builder() -> ConfigBuilder {
        ConfigBuilder::new(Profile::Test).with_provider(CountingProvider::default())
    }

    /// Current value of the `version` key as a `u64`.
    fn version(config: &ReloadableConfig) -> u64 {
        config.current().get_raw("version").unwrap().as_u64().unwrap()
    }

    #[tokio::test]
    async fn reload_now_swaps_in_changed_values_and_notifies() {
        let config = ReloadableConfig::start(counting_builder(), NEVER).await.unwrap();
        assert_eq!(config.current().get_raw("version"), Some(&json!(0)));

        let watcher = config.subscribe();
        config.reload_now().await.unwrap();

        assert_eq!(config.current().get_raw("version"), Some(&json!(1)));
        assert!(watcher.has_changed().unwrap(), "subscriber should see the swap");
    }

    #[tokio::test]
    async fn reload_with_unchanged_values_does_not_notify() {
        let fixed_provider = MemoryProvider::new().set("x", 1);
        let builder = ConfigBuilder::new(Profile::Test).with_provider(fixed_provider);
        let config = ReloadableConfig::start(builder, NEVER).await.unwrap();

        let watcher = config.subscribe();
        config.reload_now().await.unwrap();

        assert!(!watcher.has_changed().unwrap(), "no change → no notification");
    }

    #[tokio::test]
    async fn periodic_refresh_picks_up_changes() {
        let config = ReloadableConfig::start(counting_builder(), FAST).await.unwrap();
        let mut watcher = config.subscribe();

        tokio::time::timeout(DEADLINE, watcher.changed()).await.unwrap().unwrap();
        assert!(version(&config) >= 1);
    }

    #[tokio::test]
    async fn refresh_trigger_forces_a_reload() {
        let (config, trigger) =
            ReloadableConfig::start_with_refresh(counting_builder(), NEVER).await.unwrap();
        assert_eq!(config.current().get_raw("version"), Some(&json!(0)));

        let mut watcher = config.subscribe();
        trigger.refresh();

        tokio::time::timeout(DEADLINE, watcher.changed()).await.unwrap().unwrap();
        assert_eq!(config.current().get_raw("version"), Some(&json!(1)));
    }

    #[tokio::test]
    async fn periodic_refresh_survives_dropping_the_trigger() {
        let (config, trigger) =
            ReloadableConfig::start_with_refresh(counting_builder(), FAST).await.unwrap();
        drop(trigger);

        let mut watcher = config.subscribe();
        tokio::time::timeout(DEADLINE, watcher.changed()).await.unwrap().unwrap();
        assert!(version(&config) >= 1);
    }
}