use crate::{FileWatcher, Result, WatchError};
use arc_swap::ArcSwap;
use serde::de::DeserializeOwned;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tracing::{error, info, warn};
/// A configuration type that [`ConfigWatcher`] can load from disk and swap in
/// at runtime.
///
/// Implementors must be deserializable into an owned value and be
/// `Send + Sync + 'static`, as required by the supertrait bounds.
pub trait HotConfig: DeserializeOwned + Send + Sync + 'static {
/// Checks a freshly deserialized config before it is accepted.
///
/// `ConfigWatcher` calls this after every load; an `Err` makes it reject the
/// new value, and the string is used as the error/log message. The default
/// implementation accepts every value.
fn validate(&self) -> std::result::Result<(), String> {
Ok(())
}
}
/// Holds the current snapshot of a [`HotConfig`] loaded from a file, together
/// with the [`FileWatcher`] whose callback replaces that snapshot when the file
/// changes.
pub struct ConfigWatcher<T: HotConfig> {
/// Current config snapshot; shared with the file-change callback, which
/// stores new values into it.
config: Arc<ArcSwap<T>>,
/// Path of the config file that is loaded and watched.
path: PathBuf,
/// Never read after construction; stored so the watcher is owned by this
/// struct. NOTE(review): presumably dropping it stops the watch; confirm
/// against `FileWatcher`.
_watcher: FileWatcher,
/// Timestamp consulted by the file-change callback to debounce reloads;
/// initialized when the watcher is constructed.
last_reload: Arc<Mutex<Instant>>,
/// Debounce window passed at construction time.
debounce_duration: Duration,
}
impl<T: HotConfig> ConfigWatcher<T> {
    /// Creates a watcher for `path` with a 500 ms debounce window.
    ///
    /// # Errors
    ///
    /// Same as [`ConfigWatcher::with_debounce`].
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
        Self::with_debounce(path, Duration::from_millis(500))
    }

    /// Loads and validates the config at `path`, then registers a
    /// [`FileWatcher`] callback that reloads it on change.
    ///
    /// Change events arriving less than `debounce` after the last successful
    /// (re)load are ignored. The window starts at construction, so events
    /// inside the first `debounce` after this call are ignored as well. When a
    /// reload fails to parse or validate, the previous config stays active and
    /// the debounce window is released so the next event can retry.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read or parsed, if
    /// [`HotConfig::validate`] rejects the initial config, or if the
    /// `FileWatcher` cannot be created.
    pub fn with_debounce<P: AsRef<Path>>(path: P, debounce: Duration) -> Result<Self> {
        let path = path.as_ref().to_path_buf();

        let initial_config = Self::load_config(&path)?;
        initial_config.validate().map_err(|e| {
            WatchError::ParseFailed(format!("Initial config validation failed: {}", e))
        })?;
        info!("Initial config loaded and validated from: {:?}", path);

        let config = Arc::new(ArcSwap::new(Arc::new(initial_config)));
        let last_reload = Arc::new(Mutex::new(Instant::now()));

        // Handles moved into the callback.
        let config_clone = Arc::clone(&config);
        let last_reload_clone = Arc::clone(&last_reload);
        let path_clone = path.clone();

        let watcher = FileWatcher::new(&path, move |_| {
            // Check and claim the debounce window under a single lock so that
            // overlapping callbacks cannot both start a reload.
            let (previous, claimed) = {
                let mut last = Self::lock_last_reload(&last_reload_clone);
                let now = Instant::now();
                if now.duration_since(*last) < debounce {
                    return;
                }
                let previous = *last;
                *last = now;
                (previous, now)
            };

            info!("Config file changed, reloading: {:?}", path_clone);
            let reloaded = match Self::load_config(&path_clone) {
                Ok(new_config) => match new_config.validate() {
                    Ok(()) => {
                        config_clone.store(Arc::new(new_config));
                        info!("Config reloaded and validated successfully");
                        true
                    }
                    Err(e) => {
                        error!("Config validation failed, keeping previous config: {}", e);
                        false
                    }
                },
                Err(e) => {
                    error!("Failed to reload config (keeping previous): {}", e);
                    false
                }
            };

            if !reloaded {
                // A failed attempt (e.g. the file was only half written when
                // the first event of a save burst fired) must not consume the
                // debounce window, or the follow-up event carrying the
                // complete file would be dropped. Only roll back our own
                // claim, so a timestamp written concurrently is preserved.
                let mut last = Self::lock_last_reload(&last_reload_clone);
                if *last == claimed {
                    *last = previous;
                }
            }
        })?;

        Ok(Self {
            config,
            path,
            _watcher: watcher,
            last_reload,
            debounce_duration: debounce,
        })
    }

    /// Returns the currently active config snapshot.
    pub fn get(&self) -> Arc<T> {
        self.config.load_full()
    }

    /// Reloads the config from disk immediately, bypassing the debounce check.
    ///
    /// On success the new config becomes active and the last-reload timestamp
    /// is refreshed, so [`ConfigWatcher::time_since_last_reload`] reflects
    /// manual reloads too.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read or parsed, or
    /// `WatchError::ParseFailed` carrying the validation message if
    /// [`HotConfig::validate`] rejects it. The previous config stays active.
    pub fn reload(&self) -> Result<()> {
        let new_config = Self::load_config(&self.path)?;
        new_config.validate().map_err(WatchError::ParseFailed)?;
        self.config.store(Arc::new(new_config));
        *Self::lock_last_reload(&self.last_reload) = Instant::now();
        info!("Config manually reloaded");
        Ok(())
    }

    /// Reads `path` and deserializes it according to its file extension.
    ///
    /// The extension is matched case-insensitively (`.json`, `.toml`, `.yaml`,
    /// `.yml`); each format is only available when the corresponding cargo
    /// feature is enabled.
    ///
    /// # Errors
    ///
    /// Propagates the read error, or returns `WatchError::ParseFailed` when
    /// deserialization fails or the extension is unsupported.
    fn load_config(path: &Path) -> Result<T> {
        let content = fs::read_to_string(path)?;
        let extension = path.extension().and_then(|e| e.to_str());
        // Normalize only for matching; the error below reports the extension
        // exactly as it appears in the path.
        let normalized = extension.map(str::to_ascii_lowercase);
        match normalized.as_deref() {
            #[cfg(feature = "json")]
            Some("json") => serde_json::from_str(&content)
                .map_err(|e| WatchError::ParseFailed(e.to_string())),
            #[cfg(feature = "toml")]
            Some("toml") => toml::from_str(&content)
                .map_err(|e| WatchError::ParseFailed(e.to_string())),
            #[cfg(feature = "yaml")]
            Some("yaml") | Some("yml") => serde_yaml::from_str(&content)
                .map_err(|e| WatchError::ParseFailed(e.to_string())),
            _ => Err(WatchError::ParseFailed(format!(
                "Unsupported file extension: {:?}",
                extension
            ))),
        }
    }

    /// Locks the last-reload timestamp, recovering the guard if the mutex was
    /// poisoned: the protected value is a plain `Instant`, so it cannot be
    /// observed in a half-updated state.
    fn lock_last_reload(last: &Mutex<Instant>) -> std::sync::MutexGuard<'_, Instant> {
        last.lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns the path of the watched config file.
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Returns the debounce window this watcher was created with.
    pub fn debounce_duration(&self) -> Duration {
        self.debounce_duration
    }

    /// Returns the time elapsed since the config was last loaded successfully
    /// (initial load, automatic reload, or manual [`ConfigWatcher::reload`]).
    pub fn time_since_last_reload(&self) -> Duration {
        Self::lock_last_reload(&self.last_reload).elapsed()
    }
}
/// Convenience wrapper around [`ConfigWatcher::new`]: starts watching the
/// config file at `path` and returns the watcher.
///
/// # Errors
///
/// Returns whatever [`ConfigWatcher::new`] returns.
pub fn watch<T: HotConfig, P: AsRef<Path>>(path: P) -> Result<ConfigWatcher<T>> {
    ConfigWatcher::<T>::new(path)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use tempfile::Builder;

    /// Minimal config fixture that relies on the default `validate`.
    #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
    struct TestConfig {
        value: i32,
        name: String,
    }

    impl HotConfig for TestConfig {}

    /// The initial load of a `.json` file is visible through `get`.
    #[test]
    #[cfg(feature = "json")]
    fn test_config_watcher_json() {
        let expected = TestConfig {
            value: 42,
            name: "test".to_string(),
        };
        let file = Builder::new().suffix(".json").tempfile().unwrap();
        let serialized = serde_json::to_string(&expected).unwrap();
        std::fs::write(file.path(), serialized).unwrap();

        let watcher = ConfigWatcher::<TestConfig>::new(file.path()).unwrap();
        let snapshot = watcher.get();

        assert_eq!(snapshot.value, 42);
        assert_eq!(snapshot.name, "test");
    }
}