use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use arc_swap::ArcSwap;
use notify::{Config as NotifyConfig, Event, EventKind, PollWatcher, RecursiveMode, Watcher};
use serde::de::DeserializeOwned;
use tokio::sync::{broadcast, mpsc, watch as tokio_watch};
use tokio::task::JoinHandle;
use crate::error::{Error, Result};
use crate::format::Format;
use crate::loader::load_with_format;
/// Cooldown used by [`WatchOptions::new`] when the caller does not set one.
const DEFAULT_DEBOUNCE: Duration = Duration::from_millis(200);
/// Poll interval passed to the `PollWatcher` created in [`watch_with_options`].
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(100);
/// Shared, thread-safe callback that receives a configuration value.
type ChangeHook<T> = Arc<dyn Fn(Arc<T>) + Send + Sync + 'static>;
/// Shared, thread-safe callback that receives an error.
type ErrorHook = Arc<dyn Fn(Arc<Error>) + Send + Sync + 'static>;
/// Handle to a watched configuration value of type `T`.
///
/// Every field is either reference counted or a channel handle, so clones of a
/// handle share the same underlying value, hooks and watcher.
pub struct Config<T> {
/// Current configuration value; swapped by `reload_config` after a successful load.
value: Arc<ArcSwap<T>>,
/// Receiver of the change channel that `reload_config` signals after storing a new value.
changes: tokio_watch::Receiver<()>,
/// Sender of the error broadcast; `errors()` subscribes new receivers to it.
errors: broadcast::Sender<Arc<Error>>,
/// Callbacks registered through `on_change` / `on_error`.
hooks: Arc<Hooks<T>>,
/// Shared guard owning the file watcher and the spawned reload task.
_watch: Arc<WatchGuard>,
}
impl<T> Config<T> {
pub fn get(&self) -> Arc<T> {
self.value.load_full()
}
pub fn with<R>(&self, reader: impl FnOnce(&T) -> R) -> R {
let config = self.value.load();
reader(&config)
}
pub fn on_change(&self, hook: impl Fn(Arc<T>) + Send + Sync + 'static) {
self.hooks.add_change(hook);
}
pub fn on_error(&self, hook: impl Fn(Arc<Error>) + Send + Sync + 'static) {
self.hooks.add_error(hook);
}
pub async fn changed(&mut self) -> Result<Arc<T>> {
self.changes
.changed()
.await
.map_err(|_| Error::WatchClosed)?;
Ok(self.get())
}
pub fn subscribe(&self) -> Self {
self.clone()
}
pub fn errors(&self) -> broadcast::Receiver<Arc<Error>> {
self.errors.subscribe()
}
}
/// Manual `Clone` so that `T` itself does not have to be `Clone`.
impl<T> Clone for Config<T> {
    /// Produces a handle sharing the value, channels, hooks and watch guard of `self`.
    fn clone(&self) -> Self {
        let Self {
            value,
            changes,
            errors,
            hooks,
            _watch,
        } = self;
        Self {
            value: Arc::clone(value),
            changes: changes.clone(),
            errors: errors.clone(),
            hooks: Arc::clone(hooks),
            _watch: Arc::clone(_watch),
        }
    }
}
/// Alternative name for [`Config`]; both refer to the same type.
pub type WatchedConfig<T> = Config<T>;
/// Options accepted by [`watch_with_options`].
#[derive(Debug, Clone, Copy)]
pub struct WatchOptions {
/// Explicit file format; when `None`, the format is derived from the file path.
pub format: Option<Format>,
/// Length of the cooldown window the watcher task waits out after each reload.
pub cooldown: Duration,
}
impl WatchOptions {
    /// Creates options with no explicit format and the default cooldown.
    pub const fn new() -> Self {
        let format = None;
        let cooldown = DEFAULT_DEBOUNCE;
        Self { format, cooldown }
    }

    /// Returns the options with the file format fixed to `format`.
    pub const fn with_format(self, format: Format) -> Self {
        Self {
            format: Some(format),
            cooldown: self.cooldown,
        }
    }

    /// Returns the options with the cooldown replaced by `cooldown`.
    pub const fn with_cooldown(self, cooldown: Duration) -> Self {
        Self {
            format: self.format,
            cooldown,
        }
    }

    /// Same effect as [`WatchOptions::with_cooldown`]: replaces the cooldown.
    pub const fn with_debounce(self, cooldown: Duration) -> Self {
        self.with_cooldown(cooldown)
    }
}
impl Default for WatchOptions {
fn default() -> Self {
Self::new()
}
}
/// Registry of the callbacks attached to a watched configuration.
struct Hooks<T> {
/// Callbacks invoked with the new value after a successful reload.
on_change: std::sync::Mutex<Vec<ChangeHook<T>>>,
/// Callbacks invoked with each reported error.
on_error: std::sync::Mutex<Vec<ErrorHook>>,
}
impl<T> Hooks<T> {
fn new() -> Self {
Self {
on_change: std::sync::Mutex::new(Vec::new()),
on_error: std::sync::Mutex::new(Vec::new()),
}
}
fn add_change(&self, hook: impl Fn(Arc<T>) + Send + Sync + 'static) {
if let Ok(mut hooks) = self.on_change.lock() {
hooks.push(Arc::new(hook));
}
}
fn add_error(&self, hook: impl Fn(Arc<Error>) + Send + Sync + 'static) {
if let Ok(mut hooks) = self.on_error.lock() {
hooks.push(Arc::new(hook));
}
}
fn call_change(&self, config: Arc<T>) {
let hooks = self
.on_change
.lock()
.map(|hooks| hooks.clone())
.unwrap_or_default();
for hook in hooks {
hook(Arc::clone(&config));
}
}
fn call_error(&self, error: Arc<Error>) {
let hooks = self
.on_error
.lock()
.map(|hooks| hooks.clone())
.unwrap_or_default();
for hook in hooks {
hook(Arc::clone(&error));
}
}
}
/// Owns the file watcher and the reload task belonging to one watched configuration.
struct WatchGuard {
/// The poll-based file watcher; owned here so it is dropped together with the guard.
watcher: PollWatcher,
/// Handle of the spawned reload task; aborted when the guard is dropped.
task: JoinHandle<()>,
}
impl Drop for WatchGuard {
/// Aborts the spawned reload task when the guard is dropped.
fn drop(&mut self) {
// Borrows the field and discards the borrow immediately; no runtime effect.
// NOTE(review): presumably only here so `watcher` counts as read and is not
// reported as an unused field — confirm.
let _ = &self.watcher;
// Cancel the background task; the struct's fields are dropped after this body returns.
self.task.abort();
}
}
/// Loads the configuration at `path` and watches it using the default [`WatchOptions`].
///
/// # Errors
///
/// Propagates any error returned by [`watch_with_options`].
pub async fn watch<T>(path: impl AsRef<Path>) -> Result<Config<T>>
where
    T: DeserializeOwned + Send + Sync + 'static,
{
    let options = WatchOptions::default();
    watch_with_options(path, options).await
}
/// Loads the configuration at `path` and watches it, parsing it as `format`
/// instead of deriving the format from the path.
///
/// # Errors
///
/// Propagates any error returned by [`watch_with_options`].
pub async fn watch_with_format<T>(path: impl AsRef<Path>, format: Format) -> Result<Config<T>>
where
    T: DeserializeOwned + Send + Sync + 'static,
{
    let options = WatchOptions::default().with_format(format);
    watch_with_options(path, options).await
}
/// Loads the configuration at `path` and keeps it up to date in the background.
///
/// The path is made absolute, the file is loaded once, and a `PollWatcher` is
/// registered (non-recursively) on the file's parent directory. A spawned task
/// then reloads the file whenever a relevant watcher event arrives.
///
/// After each reload the task waits out `options.cooldown`. Events that arrive
/// during that window do not trigger an immediate reload, but they are not lost:
/// if at least one of them was relevant, the file is reloaded once more when the
/// window ends, and this repeats until a whole window passes without a relevant
/// event. The published value therefore always converges on the file's latest
/// contents.
///
/// # Errors
///
/// Returns the error from resolving the path, from deriving the format (when
/// `options.format` is `None`), or from the initial load, and `Error::Watch`
/// when the watcher cannot be created or cannot watch the parent directory.
pub async fn watch_with_options<T>(
    path: impl AsRef<Path>,
    options: WatchOptions,
) -> Result<Config<T>>
where
    T: DeserializeOwned + Send + Sync + 'static,
{
    let path = absolute_path(path.as_ref())?;
    let format = match options.format {
        Some(format) => format,
        None => Format::from_path(&path)?,
    };
    // Watch the containing directory rather than the file itself.
    let parent = path
        .parent()
        .map(Path::to_path_buf)
        .unwrap_or_else(|| PathBuf::from("."));
    let initial = Arc::new(load_with_format(&path, format)?);
    let value = Arc::new(ArcSwap::from(initial));
    let (change_tx, change_rx) = tokio_watch::channel(());
    let (error_tx, _) = broadcast::channel(16);
    let (event_tx, mut event_rx) = mpsc::unbounded_channel();
    let hooks = Arc::new(Hooks::new());
    // The watcher callback only forwards events; all handling happens in the task below.
    let mut watcher = PollWatcher::new(
        move |event| {
            let _ = event_tx.send(event);
        },
        NotifyConfig::default()
            .with_poll_interval(DEFAULT_POLL_INTERVAL)
            .with_compare_contents(true),
    )
    .map_err(|source| Error::Watch {
        path: path.clone(),
        source,
    })?;
    watcher
        .watch(&parent, RecursiveMode::NonRecursive)
        .map_err(|source| Error::Watch {
            path: path.clone(),
            source,
        })?;
    let task_path = path.clone();
    let task_error_tx = error_tx.clone();
    let task_hooks = Arc::clone(&hooks);
    let task_value = Arc::clone(&value);
    let task = tokio::spawn(async move {
        while let Some(event) = event_rx.recv().await {
            if !should_reload(event, &task_path, &task_error_tx, &task_hooks) {
                continue;
            }
            // Reload immediately, then keep reloading after every cooldown window
            // in which further relevant events were seen, so a change that lands
            // inside the window is picked up instead of being dropped.
            loop {
                reload_config(
                    &task_path,
                    format,
                    &task_value,
                    &change_tx,
                    &task_error_tx,
                    &task_hooks,
                );
                let changed_during_cooldown = suppress_cooldown_window(
                    &mut event_rx,
                    &task_path,
                    &task_error_tx,
                    &task_hooks,
                    options.cooldown,
                )
                .await;
                if !changed_during_cooldown {
                    break;
                }
            }
        }
    });
    Ok(Config {
        value,
        changes: change_rx,
        errors: error_tx,
        hooks,
        _watch: Arc::new(WatchGuard { watcher, task }),
    })
}

/// Drains watcher events for `cooldown` without reloading.
///
/// Every drained event still goes through `should_reload`, so watcher errors
/// are reported to the error channel and error hooks as usual. The window ends
/// early if the event channel is closed.
///
/// Returns `true` when at least one drained event was relevant to `path`, which
/// tells the caller that the file may have changed since the last reload.
async fn suppress_cooldown_window<T>(
    event_rx: &mut mpsc::UnboundedReceiver<notify::Result<Event>>,
    path: &Path,
    error_tx: &broadcast::Sender<Arc<Error>>,
    hooks: &Hooks<T>,
    cooldown: Duration,
) -> bool {
    let window = tokio::time::sleep(cooldown);
    tokio::pin!(window);
    let mut relevant_seen = false;
    loop {
        tokio::select! {
            event = event_rx.recv() => {
                let Some(event) = event else {
                    break;
                };
                // Evaluate every event (no short-circuit) so errors are always reported.
                if should_reload(event, path, error_tx, hooks) {
                    relevant_seen = true;
                }
            }
            () = &mut window => {
                break;
            }
        }
    }
    relevant_seen
}
fn should_reload<T>(
event: notify::Result<Event>,
path: &Path,
error_tx: &broadcast::Sender<Arc<Error>>,
hooks: &Hooks<T>,
) -> bool {
match event {
Ok(event) => is_relevant_event(&event, path),
Err(source) => {
let error = Arc::new(Error::Watch {
path: path.to_path_buf(),
source,
});
let _ = error_tx.send(Arc::clone(&error));
hooks.call_error(error);
false
}
}
}
/// Re-reads `path` and publishes the outcome.
///
/// On success the new value is stored in `value`, `change_tx` is signalled and
/// the change hooks are called, in that order. On failure the previous value is
/// left in place and the error is sent on `error_tx` and passed to the error hooks.
fn reload_config<T>(
    path: &Path,
    format: Format,
    value: &Arc<ArcSwap<T>>,
    change_tx: &tokio_watch::Sender<()>,
    error_tx: &broadcast::Sender<Arc<Error>>,
    hooks: &Hooks<T>,
) where
    T: DeserializeOwned,
{
    let fresh = match load_with_format(path, format) {
        Ok(loaded) => Arc::new(loaded),
        Err(failure) => {
            let failure = Arc::new(failure);
            // Send results are ignored: having no listeners is not an error here.
            let _ = error_tx.send(Arc::clone(&failure));
            hooks.call_error(failure);
            return;
        }
    };
    value.store(Arc::clone(&fresh));
    let _ = change_tx.send(());
    hooks.call_change(fresh);
}
fn absolute_path(path: &Path) -> Result<PathBuf> {
if path.is_absolute() {
return Ok(path.to_path_buf());
}
std::env::current_dir()
.map(|cwd| cwd.join(path))
.map_err(|source| Error::ResolvePath {
path: path.to_path_buf(),
source,
})
}
/// Reports whether `event` should be treated as a change to the file at `path`.
///
/// Access events are never relevant. Any other event is relevant as soon as one
/// of its paths is the target itself, lies in the target's parent directory, is
/// that parent directory, or has the same file name as the target.
///
/// NOTE(review): the "same parent directory" rule makes every sibling path
/// relevant; this looks deliberately permissive — confirm before narrowing it.
fn is_relevant_event(event: &Event, path: &Path) -> bool {
    if let EventKind::Access(_) = event.kind {
        return false;
    }
    let target_parent = path.parent();
    let target_name = path.file_name();
    for changed in &event.paths {
        let changed = changed.as_path();
        let is_target = changed == path;
        let shares_parent = match (changed.parent(), target_parent) {
            (Some(changed_parent), Some(target_parent)) => changed_parent == target_parent,
            _ => false,
        };
        let is_parent_dir = target_parent == Some(changed);
        let same_name = match (changed.file_name(), target_name) {
            (Some(changed_name), Some(target_name)) => changed_name == target_name,
            _ => false,
        };
        if is_target || shares_parent || is_parent_dir || same_name {
            return true;
        }
    }
    false
}