use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use tokio::task::JoinHandle;
use super::{Config, ConfigBuilder};
use crate::error::ConfigError;
/// A configuration snapshot that is re-resolved periodically in a background task.
///
/// Readers obtain `Arc<Config>` snapshots through a `tokio::sync::watch` channel.
/// A newly resolved config is only published when its values differ from the
/// current ones, and the background task is aborted when this value is dropped.
pub struct ReloadableConfig {
// Shared with the background task so both it and `reload_now` can re-resolve.
builder: Arc<ConfigBuilder>,
// Publishes new snapshots to every receiver.
tx: watch::Sender<Arc<Config>>,
// Read by `current()`; holding it also keeps at least one receiver alive for
// as long as this struct exists.
rx: watch::Receiver<Arc<Config>>,
// Handle to the periodic reload loop; aborted in `Drop`.
task: JoinHandle<()>,
}
impl ReloadableConfig {
pub async fn start(builder: ConfigBuilder, interval: Duration) -> Result<Self, ConfigError> {
let builder = Arc::new(builder.ensure_defaults()?);
let initial = Arc::new(builder.resolve().await?);
let (tx, rx) = watch::channel(initial);
let task = {
let builder = Arc::clone(&builder);
let tx = tx.clone();
tokio::spawn(async move {
let mut ticker = tokio::time::interval(interval);
ticker.tick().await; loop {
ticker.tick().await;
match builder.resolve().await {
Ok(next) => swap_if_changed(&tx, next),
Err(error) => {
tracing::warn!(%error, "config reload failed; keeping current values");
}
}
}
})
};
Ok(Self { builder, tx, rx, task })
}
#[must_use]
pub fn current(&self) -> Arc<Config> {
self.rx.borrow().clone()
}
#[must_use]
pub fn subscribe(&self) -> watch::Receiver<Arc<Config>> {
self.rx.clone()
}
pub async fn reload_now(&self) -> Result<(), ConfigError> {
let next = self.builder.resolve().await?;
swap_if_changed(&self.tx, next);
Ok(())
}
}
impl Drop for ReloadableConfig {
fn drop(&mut self) {
self.task.abort();
}
}
/// Publishes `next` on `tx` only if its values differ from the current snapshot.
///
/// The comparison and the replacement happen under a single write lock
/// (`send_if_modified`). Concurrent callers (the background loop and
/// `reload_now`) therefore cannot both observe "changed" for the same new
/// values and each send a notification. Receivers are notified only when a
/// swap actually happens.
fn swap_if_changed(tx: &watch::Sender<Arc<Config>>, next: Config) {
    let swapped = tx.send_if_modified(|current| {
        let changed = current.values() != next.values();
        if changed {
            *current = Arc::new(next);
        }
        changed
    });
    // Log outside the closure so the channel's write lock is not held while
    // formatting and emitting the event.
    if swapped {
        tracing::info!("configuration changed; swapping in new values");
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Profile;
    use crate::config::map::ConfigMap;
    use crate::config::provider::{ConfigProvider, MemoryProvider, ProviderKind};
    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Period long enough that the background loop never fires during a test,
    /// so only explicit `reload_now` calls cause swaps.
    const NEVER: Duration = Duration::from_secs(3600);

    /// Provider whose `version` key increments on every load, so consecutive
    /// resolves always yield differing values.
    #[derive(Clone, Default)]
    struct CountingProvider {
        counter: Arc<AtomicU64>,
    }

    #[async_trait]
    impl ConfigProvider for CountingProvider {
        fn name(&self) -> String {
            String::from("counting")
        }

        fn kind(&self) -> ProviderKind {
            ProviderKind::Memory
        }

        async fn load(&self) -> Result<ConfigMap, ConfigError> {
            let seen = self.counter.fetch_add(1, Ordering::SeqCst);
            let entry = ("version".to_string(), json!(seen));
            Ok(ConfigMap::from_iter([entry]))
        }
    }

    #[tokio::test]
    async fn reload_now_swaps_in_changed_values_and_notifies() {
        let builder = ConfigBuilder::new(Profile::Test).with_provider(CountingProvider::default());
        let reloadable = ReloadableConfig::start(builder, NEVER).await.unwrap();
        let before = reloadable.current();
        assert_eq!(before.get_raw("version"), Some(&json!(0)));

        let watcher = reloadable.subscribe();
        reloadable.reload_now().await.unwrap();

        let after = reloadable.current();
        assert_eq!(after.get_raw("version"), Some(&json!(1)));
        assert!(watcher.has_changed().unwrap(), "subscriber should see the swap");
    }

    #[tokio::test]
    async fn reload_with_unchanged_values_does_not_notify() {
        let builder =
            ConfigBuilder::new(Profile::Test).with_provider(MemoryProvider::new().set("x", 1));
        let reloadable = ReloadableConfig::start(builder, NEVER).await.unwrap();
        let watcher = reloadable.subscribe();
        reloadable.reload_now().await.unwrap();
        assert!(!watcher.has_changed().unwrap(), "no change → no notification");
    }

    #[tokio::test]
    async fn periodic_refresh_picks_up_changes() {
        let builder = ConfigBuilder::new(Profile::Test).with_provider(CountingProvider::default());
        let reloadable = ReloadableConfig::start(builder, Duration::from_millis(20)).await.unwrap();
        let mut watcher = reloadable.subscribe();
        let notified = tokio::time::timeout(Duration::from_secs(2), watcher.changed()).await;
        notified.unwrap().unwrap();
        let version = reloadable.current().get_raw("version").and_then(|v| v.as_u64());
        assert!(version.unwrap() >= 1);
    }
}