use anyhow::{anyhow, Context as _, Result};
use crossbeam_channel::{bounded, select, unbounded, Receiver, Sender};
use notify::{
event::ModifyKind, recommended_watcher, Event, EventKind, RecommendedWatcher, RecursiveMode,
Watcher as _,
};
use std::{
collections::{hash_map::Entry, HashMap},
path::{Path, PathBuf},
thread::{self, JoinHandle},
time::{Duration, Instant},
};
use tracing::{debug, error, instrument, warn};
use crate::test_hooks;
/// Canonicalizes `path`, tolerating trailing components that do not exist yet.
///
/// If `path` itself resolves, its canonical form is returned. Otherwise the
/// deepest ancestor that resolves is canonicalized and the unresolved tail of
/// `path` is re-appended verbatim. If no ancestor resolves, `path` is returned
/// unchanged.
fn canonicalize_path(path: &Path) -> PathBuf {
    path.canonicalize().unwrap_or_else(|_| {
        // `ancestors()` walks upwards, so the first hit is the deepest existing one.
        path.ancestors()
            .skip(1)
            .find_map(|ancestor| {
                let resolved = ancestor.canonicalize().ok()?;
                let tail = path.strip_prefix(ancestor).ok()?;
                Some(resolved.join(tail))
            })
            .unwrap_or_else(|| path.to_path_buf())
    })
}
/// Watches `path` non-recursively, falling back to the nearest ancestor that
/// can be watched when `path` (or some of its parents) does not exist yet.
///
/// Returns the path that is actually watched, together with the first
/// component of `path` below it (`None` when `path` itself is watched).
///
/// # Errors
/// Fails when not even the top-most ancestor could be watched; the error of
/// that last attempt is returned with added context.
///
/// # Panics
/// Only if the watched path were not a prefix of `path`, which cannot happen
/// because it is taken from `path.ancestors()`.
pub fn best_effort_watch<'a>(
    watcher: &mut RecommendedWatcher,
    path: &'a Path,
) -> Result<(&'a Path, Option<&'a Path>)> {
    let mut last_err = anyhow!("empty path");
    // Try `path` first, then each parent in turn, remembering why each failed.
    let found = path.ancestors().find(|candidate| {
        match watcher.watch(candidate, RecursiveMode::NonRecursive) {
            Ok(()) => true,
            Err(err) => {
                last_err = err.into();
                false
            }
        }
    });
    let watched_path = match found {
        Some(watched_path) => watched_path,
        None => return Err(last_err.context("adding notify watch for config file")),
    };

    let below_watched = path
        .strip_prefix(watched_path)
        .expect("watched_path was obtained as an ancestor of path, yet it is not a prefix");
    let immediate_child = below_watched.iter().next();
    debug!("Actually watching {}, ic {:?}", watched_path.display(), &immediate_child);
    Ok((watched_path, immediate_child.map(Path::new)))
}
/// Handle to a background `config-reload` thread that watches config files and
/// invokes a (debounced) reload callback when they change.
///
/// Dropping the handle asks the thread to shut down; the thread is not joined.
pub struct ConfigWatcher {
    /// Command channel to the worker thread.
    tx: Sender<Command>,
    /// Worker thread handle; kept alive alongside the watcher but never joined.
    #[allow(dead_code)]
    worker: JoinHandle<()>,
    /// Receives "worker is idle" pings from the worker (tests only).
    #[cfg(test)]
    debug_rx: Receiver<()>,
}
impl ConfigWatcher {
#[instrument(skip_all)]
pub fn new(handler: impl FnMut() + Send + 'static) -> Result<Self> {
Self::with_debounce(handler, Duration::from_millis(100))
}
#[instrument(skip_all)]
pub fn with_debounce(
handler: impl FnMut() + Send + 'static,
reload_debounce: Duration,
) -> Result<Self> {
let (notify_tx, notify_rx) = unbounded();
let (req_tx, req_rx) = unbounded();
#[cfg(test)]
let (debug_tx, debug_rx) = unbounded();
let watcher = recommended_watcher(notify_tx).context("create notify watcher")?;
let mut inner = ConfigWatcherInner {
reload_debounce,
reload_deadline: None,
handler,
watcher,
notify_rx,
req_rx,
#[cfg(test)]
debug_tx,
paths: Default::default(),
};
let worker = thread::Builder::new()
.name("config-reload".to_string())
.spawn(move || {
if let Err(err) = inner.run() {
error!("config reload thread returned error: {:?}", err);
}
})
.context("create config reload thread")?;
Ok(Self {
tx: req_tx,
worker,
#[cfg(test)]
debug_rx,
})
}
#[instrument(skip_all)]
pub fn watch(&self, path: impl AsRef<Path>) -> Result<()> {
let (tx, rx) = bounded(1);
self.tx
.send(Command::AddWatch(path.as_ref().to_owned(), tx))
.context("sending AddWatch to ConfigWatcherInner")?;
rx.recv()?
}
#[cfg(test)]
fn worker_ready(&self) {
self.debug_rx.recv().unwrap();
debug!("worker ready");
}
}
impl Drop for ConfigWatcher {
    /// Requests worker shutdown without joining; a failed send means the worker
    /// has already exited (its receiver is gone).
    fn drop(&mut self) {
        match self.tx.send(Command::Shutdown) {
            Ok(()) => {}
            Err(err) => warn!("Config watcher thread already died: {:?}", err),
        }
    }
}
/// Requests sent from a [`ConfigWatcher`] handle to its worker thread.
enum Command {
    /// Start watching a path; the outcome is sent back on the enclosed channel.
    AddWatch(PathBuf, Sender<Result<()>>),
    /// Leave the worker loop.
    Shutdown,
}
/// State owned by the `config-reload` worker thread.
struct ConfigWatcherInner<Handler> {
    /// Delay between the first relevant change and the handler call.
    reload_debounce: Duration,
    /// When set, the handler fires once this instant is reached.
    reload_deadline: Option<Instant>,
    /// User callback that performs the actual reload.
    handler: Handler,
    /// Backend that delivers events into `notify_rx`.
    watcher: RecommendedWatcher,
    /// Events (or backend errors) from `watcher`.
    notify_rx: Receiver<Result<notify::Event, notify::Error>>,
    /// Commands from the owning `ConfigWatcher`.
    req_rx: Receiver<Command>,
    /// Canonical config path -> (path actually watched, its child on the way
    /// to the config path).
    paths: HashMap<PathBuf, (PathBuf, PathBuf)>,
    /// Idle pings for tests.
    #[cfg(test)]
    debug_tx: Sender<()>,
}
/// What the worker loop woke up for.
enum Outcome {
    /// A notify callback: an event or a backend error.
    Event(notify::Result<notify::Event>),
    /// A forwarded `Command::AddWatch` request.
    AddWatch(PathBuf, Sender<Result<()>>),
    /// The debounced reload deadline has passed.
    Timeout,
    /// Shutdown was requested or an input channel disconnected.
    Shutdown,
}
impl From<Command> for Outcome {
fn from(value: Command) -> Self {
match value {
Command::AddWatch(path, sender) => Self::AddWatch(path, sender),
Command::Shutdown => Self::Shutdown,
}
}
}
impl From<notify::Result<notify::Event>> for Outcome {
fn from(value: notify::Result<notify::Event>) -> Self {
Self::Event(value)
}
}
impl<Handler> ConfigWatcherInner<Handler> {
    /// Blocks until a notify event, a command, or the reload deadline arrives.
    ///
    /// A disconnected input channel is reported as [`Outcome::Shutdown`]. Under
    /// `cfg(test)`, already-pending messages are taken first and, when no reload
    /// is pending, a ping is sent on `debug_tx` before blocking so that tests can
    /// tell the worker has caught up.
    fn select(&self) -> Outcome {
        debug!("now {:?} select with ddl {:?}", Instant::now(), &self.reload_deadline);
        let timeout = self
            .reload_deadline
            .map(crossbeam_channel::at)
            .unwrap_or_else(crossbeam_channel::never);
        #[cfg(test)]
        {
            if let Ok(res) = self.notify_rx.try_recv() {
                return Outcome::from(res);
            }
            if let Ok(res) = self.req_rx.try_recv() {
                return Outcome::from(res);
            }
            if timeout.try_recv().is_ok() {
                return Outcome::Timeout;
            }
            if self.reload_deadline.is_none() {
                self.debug_tx.send(()).unwrap();
            }
        }
        select! {
            recv(self.notify_rx) -> res => res.map(Outcome::from).unwrap_or(Outcome::Shutdown),
            recv(self.req_rx) -> res => res.map(Outcome::from).unwrap_or(Outcome::Shutdown),
            recv(timeout) -> _ => Outcome::Timeout,
        }
    }

    /// Arms the debounce deadline unless one is already pending, so a burst of
    /// changes results in a single handler call.
    fn trigger_reload(&mut self) {
        if self.reload_deadline.is_none() {
            self.reload_deadline = Some(Instant::now() + self.reload_debounce);
        }
        debug!("defer config reloading to {:?}!", &self.reload_deadline);
    }

    /// Handles an `AddWatch` command: canonicalizes `path`, rejects duplicates,
    /// installs the watch and schedules a reload if the file itself is watched.
    fn add_watch_by_command(&mut self, path: PathBuf) -> Result<()> {
        let canonical_path = canonicalize_path(&path);
        match self.paths.entry(canonical_path) {
            Entry::Occupied(e) => Err(anyhow!("{} is already being watched", e.key().display())),
            e @ Entry::Vacant(_) => {
                let reload = watch_and_add(&mut self.watcher, e)?;
                if reload {
                    self.trigger_reload();
                }
                Ok(())
            }
        }
    }

    /// Re-establishes watches for the selected entries and reports whether any
    /// of them warrants a config reload.
    ///
    /// With [`ReWatch::All`] every tracked entry is drained and re-added.
    ///
    /// Every entry is always processed. `Iterator::any` would stop at the first
    /// entry that requests a reload. For `ReWatch::All` that silently lost the
    /// remaining entries, which had already been drained from `self.paths`.
    fn rewatch(&mut self, rewatch: ReWatch) -> bool {
        let rewatch_paths = match rewatch {
            ReWatch::Some(rewatch_paths) => rewatch_paths,
            ReWatch::All => {
                self.paths.drain().map(|(path, (watched_path, _))| (path, watched_path)).collect()
            }
        };
        let mut reload = false;
        for (path, watched_path) in rewatch_paths {
            // Several config paths can share one OS-level watch, for example two
            // files below the same watched directory. Removing it while another
            // tracked entry still relies on it would leave that entry without
            // events, so only unwatch paths no other entry uses.
            let shared = self.paths.iter().any(|(other, (other_watched, _))| {
                other != &path && other_watched == &watched_path
            });
            if shared {
                debug!("keeping shared watch on {}", watched_path.display());
            } else if let Err(err) = self.watcher.unwatch(&watched_path) {
                error!("error unwatch {:?}", err);
            } else {
                debug!("unwatched {}", watched_path.display());
            }
            // A failed re-watch is logged and treated as "reload", matching the
            // previous best-effort behaviour.
            let needs_reload = watch_and_add(&mut self.watcher, self.paths.entry(path))
                .map_err(|err| error!("Failed to add watch: {:?}", err))
                .unwrap_or(true);
            reload |= needs_reload;
        }
        reload
    }
}
impl<Handler> ConfigWatcherInner<Handler>
where
Handler: FnMut(),
{
#[instrument(skip_all)]
fn run(&mut self) -> Result<()> {
loop {
match self.select() {
Outcome::Event(res) => {
debug!("event: {:?}", res);
let (rewatch, mut reload) = match res {
Err(error) => {
error!("Error: {error:?}");
(ReWatch::All, false)
}
Ok(event) => handle_event(event, &self.paths),
};
debug!("rewatch = {rewatch:?}, reload = {reload}");
reload |= self.rewatch(rewatch);
if reload {
test_hooks::emit("daemon-config-watcher-file-change");
self.trigger_reload();
}
}
Outcome::AddWatch(path, sender) => {
debug!("addwatch: {:?}", path);
let _ = sender.send(self.add_watch_by_command(path));
}
Outcome::Timeout => {
debug!("timeout");
self.reload_deadline = None;
(self.handler)();
}
Outcome::Shutdown => {
debug!("stopping config watcher thread");
break;
}
}
}
Ok(())
}
}
/// Which tracked entries must have their OS-level watch re-established.
#[derive(Debug, PartialEq, Eq)]
enum ReWatch {
    /// Only these `(config path, currently watched path)` pairs.
    Some(Vec<(PathBuf, PathBuf)>),
    /// Every tracked entry (after a rescan request or a notify error).
    All,
}
/// Classifies a notify `event` against the tracked entries in `paths`
/// (config path -> (watched path, immediate child path)).
///
/// Returns the entries to re-watch and whether a reload is due:
/// - a rescan request: re-watch everything and reload;
/// - create / remove / rename: re-watch entries whose watched path or
///   immediate child appears in the event; reload if a config path appears;
/// - any other modification: nothing to re-watch; reload if a config path appears;
/// - everything else: ignored.
fn handle_event(event: Event, paths: &HashMap<PathBuf, (PathBuf, PathBuf)>) -> (ReWatch, bool) {
    if event.need_rescan() {
        debug!("need rescan");
        return (ReWatch::All, true);
    }

    let touches_config = event.paths.iter().any(|p| paths.contains_key(p));

    let changes_structure = matches!(
        event.kind,
        EventKind::Remove(_) | EventKind::Create(_) | EventKind::Modify(ModifyKind::Name(_))
    );
    if changes_structure {
        debug!("create/remove: {:?}", event);
        let mut rewatch = Vec::new();
        for (path, (watched_path, immediate_child_path)) in paths {
            let affected = event
                .paths
                .iter()
                .any(|p| p == watched_path || p == immediate_child_path);
            if affected {
                rewatch.push((path.to_owned(), watched_path.to_owned()));
            }
        }
        return (ReWatch::Some(rewatch), touches_config);
    }

    if let EventKind::Modify(_) = event.kind {
        debug!("modify: {:?}", event);
        return (ReWatch::Some(Vec::new()), touches_config);
    }

    debug!("ignore {:?}", event);
    (ReWatch::Some(Vec::new()), false)
}
/// (Re-)establishes the watch for `entry`'s config path and stores
/// `(watched_path, immediate_child_path)` in the entry.
///
/// Returns `Ok(true)` when the config path itself is now watched directly,
/// which callers treat as "reload now".
///
/// # Errors
/// If no ancestor can be watched, an occupied entry is removed (a vacant one
/// is left vacant) and the error is returned with context.
fn watch_and_add(
    watcher: &mut RecommendedWatcher,
    entry: Entry<PathBuf, (PathBuf, PathBuf)>,
) -> Result<bool> {
    let attempt = best_effort_watch(watcher, entry.key()).map(|(watched, child)| {
        // Resolve symlinks in the watched path; fall back to it verbatim.
        let watched = watched.canonicalize().unwrap_or_else(|_| watched.to_path_buf());
        let immediate = watched.join(child.unwrap_or(Path::new("")));
        (watched, immediate)
    });

    let (watched_path, immediate_child_path) = match attempt {
        Ok(pair) => pair,
        Err(err) => {
            let context_msg = format!("best_effort_watch on {}", entry.key().display());
            if let Entry::Occupied(stale) = entry {
                stale.remove();
            }
            return Err(err.context(context_msg));
        }
    };

    let reload = watched_path == *entry.key();
    let value = (watched_path, immediate_child_path);
    match entry {
        Entry::Occupied(mut occupied) => {
            occupied.insert(value);
        }
        Entry::Vacant(vacant) => {
            vacant.insert(value);
        }
    }
    if reload {
        debug!("Force reload since now watching on target file");
    }
    Ok(reload)
}
#[cfg(test)]
#[rustfmt::skip::attributes(test_case)]
mod test {
    use super::*;
    use ntest::timeout;
    use std::fs;
    use tempfile::TempDir;

    // Tests for `best_effort_watch`: which ancestor ends up watched and which
    // immediate child is reported.
    mod watch {
        use super::*;
        use std::fs;

        // Nothing below `/` exists, so the root itself is watched.
        #[test]
        #[timeout(30000)]
        fn all_non_existing() {
            let mut watcher = recommended_watcher(|_| {}).unwrap();
            let (watched_path, immediate_child) =
                best_effort_watch(&mut watcher, Path::new("/non_existing/subdir")).unwrap();
            assert_eq!(watched_path, Path::new("/"));
            assert_eq!(immediate_child, Some(Path::new("non_existing")));
        }

        // Only `sub1` exists, so it is watched and `sub2` is the child to wait for.
        #[test]
        #[timeout(30000)]
        fn non_existing_parent() {
            let tmpdir = tempfile::tempdir().unwrap();
            let target_path = tmpdir.path().join(Path::new("sub1/sub2/c.txt"));
            let parent_path = target_path.parent().unwrap().parent().unwrap();
            fs::create_dir_all(parent_path).unwrap();
            let mut watcher = recommended_watcher(|_| {}).unwrap();
            let (watched_path, immediate_child) =
                best_effort_watch(&mut watcher, &target_path).unwrap();
            assert_eq!(watched_path, parent_path);
            assert_eq!(immediate_child, Some(Path::new("sub2")));
        }

        // An existing file is watched directly, with no immediate child.
        #[test]
        #[timeout(30000)]
        fn existing_file() {
            let tmpdir = tempfile::tempdir().unwrap();
            let target_path = tmpdir.path().join(Path::new("sub1/sub2/c.txt"));
            let parent_path = target_path.parent().unwrap();
            fs::create_dir_all(parent_path).unwrap();
            fs::write(&target_path, "test").unwrap();
            let mut watcher = recommended_watcher(|_| {}).unwrap();
            let (watched_path, immediate_child) =
                best_effort_watch(&mut watcher, &target_path).unwrap();
            assert_eq!(watched_path, target_path);
            assert_eq!(immediate_child, None);
        }
    }

    // Pure tests for `handle_event` using synthetic notify events.
    mod handle_event {
        use super::*;
        use assert_matches::assert_matches;
        use notify::{
            event::{CreateKind, ModifyKind, RemoveKind, RenameMode},
            Event, EventKind,
        };
        use ntest::test_case;

        // Builds a `paths` entry for `target` that is currently watched at `watched`.
        fn paths_entry(target: &str, watched: &str) -> (PathBuf, (PathBuf, PathBuf)) {
            let target = PathBuf::from(target);
            let base = PathBuf::from(watched);
            let immediate =
                base.join(target.strip_prefix(&base).unwrap().iter().next().unwrap_or_default());
            (target, (base, immediate))
        }

        // Builds an event from a spec such as "create sub", "mv a b" or "rmself";
        // relative paths in the spec are joined onto `base`.
        fn event_from_spec(base: &str, evt: &str) -> notify::Event {
            let base = Path::new(base);
            let (evt, path) = evt.split_once(' ').unwrap_or((evt, ""));
            match evt {
                "create" => {
                    Event::new(EventKind::Create(CreateKind::Any)).add_path(base.join(path))
                }
                "mv" => {
                    let (src, dst) = path.split_once(' ').unwrap();
                    Event::new(EventKind::Modify(ModifyKind::Name(RenameMode::Both)))
                        .add_path(base.join(src))
                        .add_path(base.join(dst))
                }
                "mvselfother" => Event::new(EventKind::Modify(ModifyKind::Name(RenameMode::Both)))
                    .add_path(base.to_owned())
                    .add_path(PathBuf::from("/some/other/path")),
                "modify" => {
                    Event::new(EventKind::Modify(ModifyKind::Any)).add_path(base.join(path))
                }
                "modifyself" => {
                    Event::new(EventKind::Modify(ModifyKind::Any)).add_path(base.to_owned())
                }
                "rm" => Event::new(EventKind::Remove(RemoveKind::Any)).add_path(base.join(path)),
                "rmself" => {
                    Event::new(EventKind::Remove(RemoveKind::Any)).add_path(base.to_owned())
                }
                _ => panic!("malformatted event spec"),
            }
        }

        #[test]
        #[timeout(30000)]
        fn need_rescan() {
            let event = notify::Event::default().set_flag(notify::event::Flag::Rescan);
            let paths = Default::default();
            let (rewatch, reload) = handle_event(event, &paths);
            assert_eq!(rewatch, ReWatch::All);
            assert!(reload);
        }

        // Case columns: target, watched path, event spec, expect rewatch, expect reload.
        const TARGET: &str = "/base/sub/config.toml";
        #[test_case(TARGET, "/base", "create sub", true, false, name = "base_create_sub")]
        #[test_case(TARGET, "/base", "create other", false, false, name = "base_create_other")]
        #[test_case(TARGET, "/base", "mv other sub", true, false, name = "base_other_to_sub")]
        #[test_case(TARGET, "/base", "mv other another", false, false, name = "base_other_to_another")]
        #[test_case(TARGET, "/base", "mv sub other", true, false, name = "base_sub_to_other")]
        #[test_case(TARGET, "/base", "rm sub", true, false, name = "base_rm_sub")]
        #[test_case(TARGET, "/base", "rm other", false, false, name = "base_rm_other")]
        #[test_case(TARGET, "/base", "modify other.toml", false, false, name = "base_modify_other")]
        #[test_case(TARGET, "/base/sub", "create config.toml", true, true, name = "sub_create_cfg")]
        #[test_case(TARGET, "/base/sub", "mv other.toml config.toml", true, true, name = "sub_other_to_cfg")]
        #[test_case(TARGET, "/base/sub", "mv other.toml another.toml", false, false, name = "sub_other_to_another")]
        #[test_case(TARGET, "/base/sub", "modify config.toml", false, true, name = "sub_modify_cfg")]
        #[test_case(TARGET, "/base/sub", "modify other.toml", false, false, name = "sub_modify_other")]
        #[test_case(TARGET, "/base/sub", "rmself", true, false, name = "sub_rm_self")]
        #[test_case(TARGET, "/base/sub/config.toml", "rmself", true, true, name = "cfg_rm_self")]
        #[test_case(TARGET, "/base/sub/config.toml", "mvselfother", true, true, name = "cfg_self_to_other")]
        #[test_case(TARGET, "/base/sub/config.toml", "modifyself", false, true, name = "cfg_modify_self")]
        #[timeout(30000)]
        fn single_path(
            target: &str,
            watched: &str,
            evt: &str,
            expected_rewatch: bool,
            expected_reload: bool,
        ) {
            let paths = HashMap::from([paths_entry(target, watched)]);
            let event = event_from_spec(watched, evt);
            let (rewatch, reload) = handle_event(event, &paths);
            let expected_rewatch = if expected_rewatch {
                ReWatch::Some(vec![(PathBuf::from(target), PathBuf::from(watched))])
            } else {
                ReWatch::Some(vec![])
            };
            assert_eq!(rewatch, expected_rewatch);
            assert_eq!(reload, expected_reload);
        }

        // Removing the shared watched directory must mark every entry below it.
        #[test]
        #[timeout(30000)]
        fn both_paths_are_updated() {
            let paths = HashMap::from([
                paths_entry("/base/sub/config.toml", "/base"),
                paths_entry("/base/other/another.toml", "/base"),
            ]);
            let event = event_from_spec("/base", "rm /base");
            let (rewatch, reload) = handle_event(event, &paths);
            assert_matches!(rewatch, ReWatch::Some(p) if p.len() == 2);
            assert!(!reload);
        }
    }

    const DEBOUNCE_TIME: Duration = Duration::from_millis(50);

    // Everything an end-to-end test needs; `tmpdir` is kept only so the
    // directory lives as long as the state.
    struct WatcherState {
        #[allow(dead_code)]
        tmpdir: TempDir,
        base_path: PathBuf,
        target_path: PathBuf,
        rx: Receiver<()>,
        watcher: ConfigWatcher,
    }

    // Creates `<tmp>/<base>`, starts a watcher whose handler sends on `rx`, and
    // watches `<tmp>/<base>/<target>`, which need not exist yet.
    fn setup(base: &str, target: &str) -> Result<WatcherState> {
        let tmpdir = tempfile::tempdir()?;
        let base_path = tmpdir.path().join(base);
        let target_path = base_path.join(target);
        assert!(target_path.strip_prefix(&base_path).is_ok());
        fs::create_dir_all(&base_path)?;
        let (tx, rx) = unbounded();
        let watcher = ConfigWatcher::with_debounce(move || tx.send(()).unwrap(), DEBOUNCE_TIME)?;
        watcher.watch(&target_path)?;
        Ok(WatcherState { tmpdir, base_path, target_path, rx, watcher })
    }

    // Sleeps past the debounce window and waits until the worker is idle. The
    // watcher is then dropped, which sends `Shutdown`. Once the worker exits, the
    // handler's sender goes away and the test's `rx` iterator terminates.
    fn drop_watcher(watcher: ConfigWatcher) {
        thread::sleep(DEBOUNCE_TIME * 2);
        watcher.worker_ready();
    }

    // Two quick writes within one debounce window yield a single reload.
    #[test]
    #[timeout(30000)]
    #[cfg_attr(target_os = "macos", ignore)]
    fn debounce() {
        let state = setup("base", "sub/config.toml").unwrap();
        fs::create_dir_all(state.target_path.parent().unwrap()).unwrap();
        state.watcher.worker_ready();
        fs::write(&state.target_path, "test").unwrap();
        fs::write(&state.target_path, "another").unwrap();
        drop_watcher(state.watcher);
        let reloads: Vec<_> = state.rx.into_iter().collect();
        assert_eq!(reloads.len(), 1);
    }

    // Writes separated by more than the debounce window yield one reload each.
    #[test]
    #[timeout(30000)]
    fn writes_larger_than_debounce() {
        let state = setup("base", "sub/config.toml").unwrap();
        fs::create_dir_all(state.target_path.parent().unwrap()).unwrap();
        state.watcher.worker_ready();
        fs::write(&state.target_path, "test").unwrap();
        thread::sleep(DEBOUNCE_TIME * 2);
        state.watcher.worker_ready();
        fs::write(&state.target_path, "another").unwrap();
        drop_watcher(state.watcher);
        let reloads: Vec<_> = state.rx.into_iter().collect();
        assert_eq!(reloads.len(), 2);
    }

    // Renaming a directory that already contains the config into place reloads once.
    #[test]
    #[timeout(30000)]
    fn move_multiple_levels_in_place() {
        let state = setup("base", "sub/config.toml").unwrap();
        fs::create_dir_all(state.base_path.join("other")).unwrap();
        fs::write(state.base_path.join("other/config.toml"), "test").unwrap();
        state.watcher.worker_ready();
        fs::rename(state.base_path.join("other"), state.base_path.join("sub")).unwrap();
        drop_watcher(state.watcher);
        let reloads: Vec<_> = state.rx.into_iter().collect();
        assert_eq!(reloads.len(), 1, "expected 1 reload, got {}", reloads.len());
    }

    // A config path reached through a symlinked directory is still observed.
    // Gated on unix because `std::os::unix::fs::symlink` does not exist
    // elsewhere; without the gate the whole test module fails to compile there.
    #[cfg(unix)]
    #[test]
    #[timeout(30000)]
    #[cfg_attr(target_os = "macos", ignore)]
    fn symlink_path_is_canonicalized() {
        use std::os::unix::fs::symlink;
        let tmpdir = tempfile::tempdir().unwrap();
        let real_dir = tmpdir.path().join("real");
        fs::create_dir_all(&real_dir).unwrap();
        let link_dir = tmpdir.path().join("link");
        symlink(&real_dir, &link_dir).unwrap();
        let symlinked_target = link_dir.join("config.toml");
        let (tx, rx) = unbounded();
        let watcher =
            ConfigWatcher::with_debounce(move || tx.send(()).unwrap(), DEBOUNCE_TIME).unwrap();
        watcher.watch(&symlinked_target).unwrap();
        watcher.worker_ready();
        fs::write(&symlinked_target, "test content").unwrap();
        drop_watcher(watcher);
        let reloads: Vec<_> = rx.into_iter().collect();
        assert_eq!(reloads.len(), 1, "expected 1 reload, got {}", reloads.len());
    }
}