use std::path::{Path, PathBuf};
use std::time::Duration;
use notify::{RecursiveMode, Watcher};
use tracing::debug;
use crate::cube::{ClosedAxis, ClosedAxisLabel};
use crate::error::ShikumiError;
/// Coarse classification of a file-watch event, as produced by
/// `WatchEventClass::classify`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum WatchEventClass {
    /// Create events and content / write-time modifications.
    Reload,
    /// Remove events.
    Removed,
    /// Every other event kind.
    Ignored,
}
impl WatchEventClass {
    /// Every variant, listed in declaration order.
    pub const ALL: &'static [Self] = &[Self::Reload, Self::Removed, Self::Ignored];

    /// Maps a raw [`notify::EventKind`] onto a [`WatchEventClass`].
    ///
    /// * `Create(_)`, `Modify(Data(Content))` and
    ///   `Modify(Metadata(WriteTime))` are classified as [`Self::Reload`].
    /// * `Remove(_)` is classified as [`Self::Removed`].
    /// * Every other kind is classified as [`Self::Ignored`].
    #[must_use]
    pub fn classify(kind: &notify::EventKind) -> Self {
        use notify::event::{DataChange, MetadataKind, ModifyKind};
        use notify::EventKind;

        match kind {
            EventKind::Modify(
                ModifyKind::Metadata(MetadataKind::WriteTime)
                | ModifyKind::Data(DataChange::Content),
            )
            | EventKind::Create(_) => Self::Reload,
            EventKind::Remove(_) => Self::Removed,
            // `EventKind` belongs to the `notify` crate, so a catch-all is
            // needed for the kinds that are deliberately not acted upon.
            _ => Self::Ignored,
        }
    }

    /// Returns `true` only for [`Self::Reload`].
    #[must_use]
    pub const fn should_reload(self) -> bool {
        matches!(self, Self::Reload)
    }

    /// Returns the lowercase label of the variant:
    /// `"reload"`, `"removed"` or `"ignored"`.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Reload => "reload",
            Self::Removed => "removed",
            Self::Ignored => "ignored",
        }
    }
}
impl ClosedAxis for WatchEventClass {
const ALL: &'static [Self] = Self::ALL;
}
impl ClosedAxisLabel for WatchEventClass {
fn as_str(self) -> &'static str {
Self::as_str(self)
}
}
/// Resolves `path` to its canonical target when `path` itself is a symlink.
///
/// Returns `None` when the metadata of `path` cannot be read, when `path` is
/// not a symlink, or when canonicalization fails (for example because the
/// link target no longer exists).
#[must_use]
pub fn symlink_target(path: &Path) -> Option<PathBuf> {
    std::fs::symlink_metadata(path)
        .ok()
        // `symlink_metadata` does not follow the link, so this inspects
        // `path` itself rather than whatever it points at.
        .filter(|meta| meta.file_type().is_symlink())
        .and_then(|_| std::fs::canonicalize(path).ok())
}
/// Handle that owns the file-system watcher created by `ConfigWatcher::watch`.
pub struct ConfigWatcher {
// Not read anywhere in this file after construction; the handle simply owns
// the boxed watcher (either a `PollWatcher` or the `RecommendedWatcher`).
_watcher: Box<dyn Watcher + Send + Sync>,
}
impl ConfigWatcher {
    /// Starts watching `path` and invokes `on_change` for every event the
    /// underlying watcher delivers.
    ///
    /// When `path` is a symlink (see [`symlink_target`]) a
    /// `notify::PollWatcher` is used with symlink following enabled, and both
    /// the resolved target and `path` itself are registered. Otherwise
    /// `notify::RecommendedWatcher` is used and only `path` is registered.
    /// Both are configured with a 3 second poll interval and are watched
    /// non-recursively.
    ///
    /// # Errors
    ///
    /// Returns a [`ShikumiError`] when the watcher cannot be constructed or
    /// when registering a path with it fails.
    pub fn watch<F>(path: &Path, on_change: F) -> Result<Self, ShikumiError>
    where
        F: Fn(notify::Event) + Send + 'static,
    {
        let callback = CallbackHandler(Box::new(on_change));
        let base_config = notify::Config::default().with_poll_interval(Duration::from_secs(3));

        let mut backend: Box<dyn Watcher + Send + Sync> = match symlink_target(path) {
            Some(resolved) => {
                let mut poller =
                    notify::PollWatcher::new(callback, base_config.with_follow_symlinks(true))?;
                debug!("watching symlink target {} for changes", resolved.display());
                poller.watch(&resolved, RecursiveMode::NonRecursive)?;
                Box::new(poller)
            }
            None => Box::new(notify::RecommendedWatcher::new(callback, base_config)?),
        };

        // The originally requested path is registered in both cases.
        debug!("watching config file {} for changes", path.display());
        backend.watch(path, RecursiveMode::NonRecursive)?;

        Ok(Self { _watcher: backend })
    }

    /// Builds a fresh watcher for `path`; currently identical to
    /// [`ConfigWatcher::watch`].
    ///
    /// # Errors
    ///
    /// Same as [`ConfigWatcher::watch`].
    pub fn rewatch<F>(path: &Path, on_change: F) -> Result<Self, ShikumiError>
    where
        F: Fn(notify::Event) + Send + 'static,
    {
        Self::watch(path, on_change)
    }
}
/// Newtype around a boxed `Fn(notify::Event)` callback so that it can be
/// handed to a watcher as a `notify::EventHandler`.
struct CallbackHandler(Box<dyn Fn(notify::Event) + Send>);
impl notify::EventHandler for CallbackHandler {
fn handle_event(&mut self, event: notify::Result<notify::Event>) {
match event {
Ok(event) => (self.0)(event),
Err(err) => tracing::warn!("file watcher error: {err}"),
}
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for event classification, symlink resolution and the
    //! watcher itself.
    //!
    //! Tests that create symlinks use `std::os::unix::fs::symlink` and are
    //! therefore gated on `cfg(unix)` so the module still compiles elsewhere.

    use super::*;
    use notify::event::{
        AccessKind, CreateKind, DataChange, MetadataKind, ModifyKind, RemoveKind, RenameMode,
    };
    use notify::EventKind;
    use std::fs;
    use std::sync::{Arc, Mutex};
    use std::thread;
    use tempfile::TempDir;

    #[test]
    fn classify_create_is_reload() {
        for kind in [
            EventKind::Create(CreateKind::File),
            EventKind::Create(CreateKind::Any),
            EventKind::Create(CreateKind::Other),
        ] {
            assert_eq!(WatchEventClass::classify(&kind), WatchEventClass::Reload);
        }
    }

    #[test]
    fn classify_content_and_writetime_modify_is_reload() {
        assert_eq!(
            WatchEventClass::classify(&EventKind::Modify(ModifyKind::Data(DataChange::Content))),
            WatchEventClass::Reload
        );
        assert_eq!(
            WatchEventClass::classify(&EventKind::Modify(ModifyKind::Metadata(
                MetadataKind::WriteTime
            ))),
            WatchEventClass::Reload
        );
    }

    #[test]
    fn classify_remove_is_removed() {
        for kind in [
            EventKind::Remove(RemoveKind::File),
            EventKind::Remove(RemoveKind::Any),
            EventKind::Remove(RemoveKind::Other),
        ] {
            assert_eq!(WatchEventClass::classify(&kind), WatchEventClass::Removed);
        }
    }

    #[test]
    fn classify_non_reload_modify_and_other_kinds_are_ignored() {
        // Modify kinds other than content / write-time changes.
        for kind in [
            EventKind::Modify(ModifyKind::Data(DataChange::Any)),
            EventKind::Modify(ModifyKind::Data(DataChange::Size)),
            EventKind::Modify(ModifyKind::Metadata(MetadataKind::Permissions)),
            EventKind::Modify(ModifyKind::Metadata(MetadataKind::Ownership)),
            EventKind::Modify(ModifyKind::Name(RenameMode::Both)),
            EventKind::Modify(ModifyKind::Any),
            EventKind::Modify(ModifyKind::Other),
        ] {
            assert_eq!(
                WatchEventClass::classify(&kind),
                WatchEventClass::Ignored,
                "{kind:?} should be Ignored"
            );
        }
        // Kinds that are neither create, modify nor remove.
        for kind in [
            EventKind::Access(AccessKind::Any),
            EventKind::Any,
            EventKind::Other,
        ] {
            assert_eq!(
                WatchEventClass::classify(&kind),
                WatchEventClass::Ignored,
                "{kind:?} should be Ignored"
            );
        }
    }

    #[test]
    fn should_reload_agrees_with_classify_reload() {
        for class in WatchEventClass::ALL.iter().copied() {
            assert_eq!(class.should_reload(), class == WatchEventClass::Reload);
        }
    }

    #[test]
    fn watch_event_class_all_covers_every_variant() {
        assert_eq!(WatchEventClass::ALL.len(), 3);
        let mut seen = WatchEventClass::ALL.to_vec();
        seen.sort_by_key(|c| c.as_str());
        seen.dedup();
        assert_eq!(seen.len(), 3, "ALL must have no duplicates");
        // Whatever `classify` returns must be a member of `ALL`.
        for kind in [
            EventKind::Create(CreateKind::File),
            EventKind::Modify(ModifyKind::Data(DataChange::Content)),
            EventKind::Remove(RemoveKind::File),
            EventKind::Access(AccessKind::Any),
            EventKind::Any,
        ] {
            assert!(WatchEventClass::ALL.contains(&WatchEventClass::classify(&kind)));
        }
    }

    #[test]
    fn watch_event_class_as_str_is_distinct_lowercase() {
        assert_eq!(WatchEventClass::Reload.as_str(), "reload");
        assert_eq!(WatchEventClass::Removed.as_str(), "removed");
        assert_eq!(WatchEventClass::Ignored.as_str(), "ignored");
    }

    #[test]
    fn watch_event_class_label_round_trips() {
        use crate::ClosedAxisLabel;
        for class in WatchEventClass::ALL.iter().copied() {
            assert_eq!(
                WatchEventClass::from_canonical_str(ClosedAxisLabel::as_str(class)),
                Some(class)
            );
            // The upper-cased label is expected to parse as well.
            assert_eq!(
                WatchEventClass::from_canonical_str(&class.as_str().to_uppercase()),
                Some(class)
            );
        }
        assert_eq!(WatchEventClass::from_canonical_str("nonsense"), None);
        assert_eq!(WatchEventClass::from_canonical_str(""), None);
    }

    #[test]
    fn symlink_target_regular_file_returns_none() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("regular.txt");
        fs::write(&file, "hello").unwrap();
        assert!(symlink_target(&file).is_none());
    }

    #[test]
    fn symlink_target_nonexistent_returns_none() {
        assert!(symlink_target(Path::new("/nonexistent/path")).is_none());
    }

    #[cfg(unix)]
    #[test]
    fn symlink_target_resolves_symlink() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("target.yaml");
        fs::write(&target, "key: value").unwrap();
        let link = dir.path().join("link.yaml");
        std::os::unix::fs::symlink(&target, &link).unwrap();
        assert_eq!(
            symlink_target(&link),
            Some(fs::canonicalize(&target).unwrap())
        );
    }

    #[test]
    fn watch_regular_file_detects_change() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("config.yaml");
        fs::write(&file, "key: old").unwrap();
        let events = Arc::new(Mutex::new(Vec::new()));
        let events_clone = Arc::clone(&events);
        let _watcher = ConfigWatcher::watch(&file, move |event| {
            events_clone.lock().unwrap().push(event);
        })
        .unwrap();
        // Give the watcher a moment to start before modifying the file, then
        // wait for the event to be delivered.
        thread::sleep(Duration::from_millis(100));
        fs::write(&file, "key: new").unwrap();
        thread::sleep(Duration::from_millis(500));
        let captured = events.lock().unwrap();
        assert!(
            !captured.is_empty(),
            "expected at least one file change event"
        );
    }

    #[cfg(unix)]
    #[test]
    fn watch_symlink_creates_poll_watcher() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("target.yaml");
        fs::write(&target, "key: value").unwrap();
        let link = dir.path().join("link.yaml");
        std::os::unix::fs::symlink(&target, &link).unwrap();
        // Only checks that construction succeeds for a symlinked path; the
        // concrete watcher type is not observable from outside.
        let _watcher = ConfigWatcher::watch(&link, |_event| {}).unwrap();
    }

    #[test]
    fn watch_nonexistent_file_errors() {
        let result = ConfigWatcher::watch(Path::new("/nonexistent/config.yaml"), |_| {});
        assert!(result.is_err());
    }

    #[cfg(unix)]
    #[test]
    fn symlink_target_broken_symlink_returns_none() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("deleted_target.yaml");
        let link = dir.path().join("broken_link.yaml");
        fs::write(&target, "key: value").unwrap();
        std::os::unix::fs::symlink(&target, &link).unwrap();
        fs::remove_file(&target).unwrap();
        let result = symlink_target(&link);
        assert!(result.is_none(), "broken symlink should return None");
    }

    #[cfg(unix)]
    #[test]
    fn symlink_target_directory_symlink() {
        let dir = TempDir::new().unwrap();
        let target_dir = dir.path().join("target_dir");
        fs::create_dir_all(&target_dir).unwrap();
        let link = dir.path().join("link_dir");
        std::os::unix::fs::symlink(&target_dir, &link).unwrap();
        assert_eq!(
            symlink_target(&link),
            Some(fs::canonicalize(&target_dir).unwrap())
        );
    }

    #[test]
    fn rewatch_creates_new_watcher() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("rewatch.yaml");
        fs::write(&file, "key: value").unwrap();
        let events = Arc::new(Mutex::new(Vec::new()));
        let events_clone = Arc::clone(&events);
        let _watcher = ConfigWatcher::rewatch(&file, move |event| {
            events_clone.lock().unwrap().push(event);
        })
        .unwrap();
        thread::sleep(Duration::from_millis(100));
        fs::write(&file, "key: updated").unwrap();
        thread::sleep(Duration::from_millis(500));
        let captured = events.lock().unwrap();
        assert!(!captured.is_empty(), "rewatch should detect file changes");
    }

    #[cfg(unix)]
    #[test]
    fn watch_symlink_detects_target_change() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("target.yaml");
        fs::write(&target, "key: original").unwrap();
        let link = dir.path().join("watched_link.yaml");
        std::os::unix::fs::symlink(&target, &link).unwrap();
        let events = Arc::new(Mutex::new(Vec::new()));
        let events_clone = Arc::clone(&events);
        let _watcher = ConfigWatcher::watch(&link, move |event| {
            events_clone.lock().unwrap().push(event);
        })
        .unwrap();
        thread::sleep(Duration::from_millis(200));
        fs::write(&target, "key: modified").unwrap();
        // The watcher is configured with a 3 s poll interval, so wait a bit
        // longer than one interval.
        thread::sleep(Duration::from_millis(4000));
        let captured = events.lock().unwrap();
        // Best effort: delivery is not required here, but any delivered
        // events must include at least one that carries a path.
        if !captured.is_empty() {
            assert!(captured.iter().any(|e| !e.paths.is_empty()));
        }
    }

    #[test]
    fn watch_callback_receives_event_with_path() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("pathcheck.yaml");
        fs::write(&file, "key: value").unwrap();
        let paths = Arc::new(Mutex::new(Vec::new()));
        let paths_clone = Arc::clone(&paths);
        let _watcher = ConfigWatcher::watch(&file, move |event| {
            for p in &event.paths {
                paths_clone.lock().unwrap().push(p.clone());
            }
        })
        .unwrap();
        thread::sleep(Duration::from_millis(100));
        fs::write(&file, "key: new_value").unwrap();
        thread::sleep(Duration::from_millis(500));
        let captured = paths.lock().unwrap();
        // Best effort: only checked when at least one path was reported.
        if !captured.is_empty() {
            assert!(
                captured
                    .iter()
                    .any(|p| { p.display().to_string().contains("pathcheck") }),
                "expected event path to reference the watched file"
            );
        }
    }

    #[cfg(unix)]
    #[test]
    fn symlink_target_nested_symlink() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("real.yaml");
        fs::write(&target, "key: value").unwrap();
        let link1 = dir.path().join("link1.yaml");
        std::os::unix::fs::symlink(&target, &link1).unwrap();
        let link2 = dir.path().join("link2.yaml");
        std::os::unix::fs::symlink(&link1, &link2).unwrap();
        // A link to a link resolves all the way to the real file.
        assert_eq!(
            symlink_target(&link2),
            Some(fs::canonicalize(&target).unwrap())
        );
    }

    #[test]
    fn rewatch_nonexistent_file_errors() {
        let result = ConfigWatcher::rewatch(Path::new("/nonexistent/rewatch.yaml"), |_| {});
        assert!(result.is_err());
    }

    #[test]
    fn symlink_target_returns_none_for_plain_directory() {
        let dir = TempDir::new().unwrap();
        assert!(symlink_target(dir.path()).is_none());
    }
}