use std::path::{Path, PathBuf};
use std::time::Duration;
use notify::{RecursiveMode, Watcher};
use tracing::debug;
use crate::cube::{ClosedAxis, ClosedAxisLabel};
use crate::error::ShikumiError;
/// Coarse, three-way classification of a `notify` event kind for config reloading.
///
/// Produced by `WatchEventClass::classify`. Variant declaration order is
/// load-bearing: the derived `PartialOrd`/`Ord` follow it, and
/// `WatchEventClass::ALL` lists the variants in this same order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[non_exhaustive]
pub enum WatchEventClass {
/// A create, content-modify, or write-time-modify event; `should_reload`
/// returns `true` only for this variant.
Reload,
/// A remove event (`EventKind::Remove`) for a watched path.
Removed,
/// Any event kind that `classify` does not map to `Reload` or `Removed`.
Ignored,
}
impl WatchEventClass {
    /// Every variant, listed in declaration order (which is also the derived `Ord` order).
    pub const ALL: &'static [Self] = &[Self::Reload, Self::Removed, Self::Ignored];

    /// Collapse a raw `notify` event kind into a [`WatchEventClass`].
    ///
    /// * any `Remove` kind becomes [`WatchEventClass::Removed`];
    /// * any `Create` kind, `Modify(Data(Content))`, and
    ///   `Modify(Metadata(WriteTime))` become [`WatchEventClass::Reload`];
    /// * everything else, including `Modify(Data(Any))`, `Modify(Any)`,
    ///   renames, and access events, becomes [`WatchEventClass::Ignored`].
    ///
    /// NOTE(review): some `notify` backends may report ordinary content writes
    /// as `ModifyKind::Data(DataChange::Any)` or `ModifyKind::Any`. Those land
    /// in `Ignored` here, and the unit tests pin that. Confirm this is intended
    /// for the platforms you target.
    #[must_use]
    pub fn classify(kind: &notify::EventKind) -> Self {
        use notify::event::{DataChange, MetadataKind, ModifyKind};
        use notify::EventKind;

        match kind {
            EventKind::Remove(_) => Self::Removed,
            EventKind::Create(_)
            | EventKind::Modify(ModifyKind::Data(DataChange::Content))
            | EventKind::Modify(ModifyKind::Metadata(MetadataKind::WriteTime)) => Self::Reload,
            _ => Self::Ignored,
        }
    }

    /// `true` exactly when this class is [`WatchEventClass::Reload`].
    #[must_use]
    pub const fn should_reload(self) -> bool {
        match self {
            Self::Reload => true,
            Self::Removed | Self::Ignored => false,
        }
    }

    /// Canonical lowercase label for this class (`"reload"`, `"removed"`, `"ignored"`).
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Ignored => "ignored",
            Self::Removed => "removed",
            Self::Reload => "reload",
        }
    }
}
impl ClosedAxis for WatchEventClass {
const ALL: &'static [Self] = Self::ALL;
}
impl ClosedAxisLabel for WatchEventClass {
fn as_str(self) -> &'static str {
Self::as_str(self)
}
}
// Generate the string-facing surface of `WatchEventClass` from its
// `ClosedAxisLabel` labels. Judging by this module's unit tests, that surface
// includes `Display`, case-insensitive `FromStr`, `from_canonical_str`, and
// serde `Serialize`/`Deserialize` as bare lowercase scalars. Confirm this
// against the macro definition. `parse_error` and `expecting` are presumably
// the messages used when an unknown label is rejected.
closed_axis_label_string_surface! {
type = WatchEventClass,
parse_error = "unknown watch event class",
expecting = "a canonical WatchEventClass lowercase label \
(`reload`, `removed`, `ignored`; case-insensitive)",
}
/// Canonical target of `path` when `path` is itself a symbolic link.
///
/// `path` is inspected with [`std::fs::symlink_metadata`], which does not
/// follow the link. Returns `None` in these cases:
///
/// * the inspection fails (e.g. `path` does not exist);
/// * `path` is not a symlink (regular files and plain directories included);
/// * [`std::fs::canonicalize`] fails (e.g. a dangling link).
///
/// Because `canonicalize` resolves every link in a chain, a link to a link
/// yields the final target.
#[must_use]
pub fn symlink_target(path: &Path) -> Option<PathBuf> {
    let is_symlink = std::fs::symlink_metadata(path)
        .map(|meta| meta.file_type().is_symlink())
        .ok()?;
    if !is_symlink {
        return None;
    }
    std::fs::canonicalize(path).ok()
}
/// Owns the `notify` watcher that delivers events for one config path.
///
/// The watcher is held only to keep it alive. Dropping a `ConfigWatcher` drops
/// the boxed watcher, which presumably stops event delivery (confirm for the
/// backend in use).
pub struct ConfigWatcher {
// Never read after construction; stored so the watcher lives as long as this value.
_watcher: Box<dyn Watcher + Send + Sync>,
}
impl ConfigWatcher {
    /// Poll interval used by [`ConfigWatcher::watch`] and [`ConfigWatcher::rewatch`].
    ///
    /// It is set on the `notify::Config` handed to whichever backend is chosen.
    /// It primarily matters for the `notify::PollWatcher` used for symlinked paths.
    pub const DEFAULT_POLL_INTERVAL: Duration = Duration::from_secs(3);

    /// Watch the config file at `path`, calling `on_change` for every event the
    /// backend reports.
    ///
    /// Equivalent to [`ConfigWatcher::watch_with_poll_interval`] with
    /// [`ConfigWatcher::DEFAULT_POLL_INTERVAL`].
    ///
    /// # Errors
    ///
    /// Returns an error if the watcher backend cannot be created, or if `path`
    /// (or its symlink target) cannot be watched, for example because it does
    /// not exist.
    pub fn watch<F>(path: &Path, on_change: F) -> Result<Self, ShikumiError>
    where
        F: Fn(notify::Event) + Send + 'static,
    {
        Self::watch_with_poll_interval(path, Self::DEFAULT_POLL_INTERVAL, on_change)
    }

    /// Like [`ConfigWatcher::watch`], but with a caller-chosen poll interval.
    ///
    /// Backend selection:
    ///
    /// * If `path` is a symlink (see [`symlink_target`]), a `notify::PollWatcher`
    ///   with symlink-following enabled watches both the resolved target and
    ///   `path` itself, so `on_change` may receive events for either path.
    /// * Otherwise `notify::RecommendedWatcher` watches `path` directly.
    ///
    /// All watches are non-recursive. `on_change` receives raw `notify` events;
    /// use [`WatchEventClass::classify`] to decide whether to reload. Backend
    /// errors are logged via `tracing::warn!` and are not passed to `on_change`.
    ///
    /// # Errors
    ///
    /// Returns an error if the watcher backend cannot be created, or if `path`
    /// (or its symlink target) cannot be watched, for example because it does
    /// not exist.
    pub fn watch_with_poll_interval<F>(
        path: &Path,
        poll_interval: Duration,
        on_change: F,
    ) -> Result<Self, ShikumiError>
    where
        F: Fn(notify::Event) + Send + 'static,
    {
        let handler = CallbackHandler(Box::new(on_change));
        let setup = notify::Config::default().with_poll_interval(poll_interval);
        let mut watcher: Box<dyn Watcher + Send + Sync> =
            if let Some(target) = symlink_target(path) {
                // Symlinked configs are polled with link-following rather than
                // watched natively. Watching the resolved target here and the
                // link path below presumably also catches a retargeted link.
                let poll_setup = setup.with_follow_symlinks(true);
                let mut poller = notify::PollWatcher::new(handler, poll_setup)?;
                debug!("watching symlink target {} for changes", target.display());
                poller.watch(&target, RecursiveMode::NonRecursive)?;
                Box::new(poller)
            } else {
                Box::new(notify::RecommendedWatcher::new(handler, setup)?)
            };
        debug!("watching config file {} for changes", path.display());
        watcher.watch(path, RecursiveMode::NonRecursive)?;
        Ok(Self { _watcher: watcher })
    }

    /// Build a fresh watcher for `path`; identical to [`ConfigWatcher::watch`].
    ///
    /// This does not touch any previously returned `ConfigWatcher`; the caller
    /// decides when to drop it. It is presumably intended for re-arming after
    /// the file was removed or replaced (e.g. on [`WatchEventClass::Removed`]).
    ///
    /// # Errors
    ///
    /// Same as [`ConfigWatcher::watch`].
    pub fn rewatch<F>(path: &Path, on_change: F) -> Result<Self, ShikumiError>
    where
        F: Fn(notify::Event) + Send + 'static,
    {
        Self::watch(path, on_change)
    }
}
/// Adapts a boxed user callback to `notify::EventHandler`.
///
/// Successful events are forwarded to the callback. Backend errors are logged
/// at `warn` level and are not passed to the callback.
struct CallbackHandler(Box<dyn Fn(notify::Event) + Send>);

impl notify::EventHandler for CallbackHandler {
    fn handle_event(&mut self, result: notify::Result<notify::Event>) {
        let callback = &self.0;
        match result {
            Err(err) => tracing::warn!("file watcher error: {err}"),
            Ok(event) => callback(event),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::sync::{Arc, Mutex};
    use std::thread;
    use tempfile::TempDir;
    use notify::EventKind;
    use notify::event::{
        AccessKind, CreateKind, DataChange, MetadataKind, ModifyKind, RemoveKind, RenameMode,
    };

    /// Upper bound for waiting on filesystem events in tests that require one.
    const EVENT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Poll `cond` every 25 ms until it holds or `timeout` elapses.
    ///
    /// Returns the final value of `cond`. Using this instead of a fixed sleep
    /// lets event tests finish as soon as the event arrives, while tolerating
    /// slow machines.
    fn wait_for(timeout: Duration, mut cond: impl FnMut() -> bool) -> bool {
        let deadline = std::time::Instant::now() + timeout;
        while std::time::Instant::now() < deadline {
            if cond() {
                return true;
            }
            thread::sleep(Duration::from_millis(25));
        }
        cond()
    }

    #[test]
    fn classify_create_is_reload() {
        for kind in [
            EventKind::Create(CreateKind::File),
            EventKind::Create(CreateKind::Any),
            EventKind::Create(CreateKind::Other),
        ] {
            assert_eq!(WatchEventClass::classify(&kind), WatchEventClass::Reload);
        }
    }

    #[test]
    fn classify_content_and_writetime_modify_is_reload() {
        assert_eq!(
            WatchEventClass::classify(&EventKind::Modify(ModifyKind::Data(DataChange::Content))),
            WatchEventClass::Reload
        );
        assert_eq!(
            WatchEventClass::classify(&EventKind::Modify(ModifyKind::Metadata(
                MetadataKind::WriteTime
            ))),
            WatchEventClass::Reload
        );
    }

    #[test]
    fn classify_remove_is_removed() {
        for kind in [
            EventKind::Remove(RemoveKind::File),
            EventKind::Remove(RemoveKind::Any),
            EventKind::Remove(RemoveKind::Other),
        ] {
            assert_eq!(WatchEventClass::classify(&kind), WatchEventClass::Removed);
        }
    }

    #[test]
    fn classify_non_reload_modify_and_other_kinds_are_ignored() {
        for kind in [
            EventKind::Modify(ModifyKind::Data(DataChange::Any)),
            EventKind::Modify(ModifyKind::Data(DataChange::Size)),
            EventKind::Modify(ModifyKind::Metadata(MetadataKind::Permissions)),
            EventKind::Modify(ModifyKind::Metadata(MetadataKind::Ownership)),
            EventKind::Modify(ModifyKind::Name(RenameMode::Both)),
            EventKind::Modify(ModifyKind::Any),
            EventKind::Modify(ModifyKind::Other),
        ] {
            assert_eq!(
                WatchEventClass::classify(&kind),
                WatchEventClass::Ignored,
                "{kind:?} should be Ignored"
            );
        }
        for kind in [
            EventKind::Access(AccessKind::Any),
            EventKind::Any,
            EventKind::Other,
        ] {
            assert_eq!(
                WatchEventClass::classify(&kind),
                WatchEventClass::Ignored,
                "{kind:?} should be Ignored"
            );
        }
    }

    #[test]
    fn should_reload_agrees_with_classify_reload() {
        for class in WatchEventClass::ALL.iter().copied() {
            assert_eq!(class.should_reload(), class == WatchEventClass::Reload);
        }
    }

    #[test]
    fn watch_event_class_all_covers_every_variant() {
        assert_eq!(WatchEventClass::ALL.len(), 3);
        let mut seen = WatchEventClass::ALL.to_vec();
        seen.sort_by_key(|c| c.as_str());
        seen.dedup();
        assert_eq!(seen.len(), 3, "ALL must have no duplicates");
        for kind in [
            EventKind::Create(CreateKind::File),
            EventKind::Modify(ModifyKind::Data(DataChange::Content)),
            EventKind::Remove(RemoveKind::File),
            EventKind::Access(AccessKind::Any),
            EventKind::Any,
        ] {
            assert!(WatchEventClass::ALL.contains(&WatchEventClass::classify(&kind)));
        }
    }

    #[test]
    fn watch_event_class_as_str_is_distinct_lowercase() {
        assert_eq!(WatchEventClass::Reload.as_str(), "reload");
        assert_eq!(WatchEventClass::Removed.as_str(), "removed");
        assert_eq!(WatchEventClass::Ignored.as_str(), "ignored");
    }

    #[test]
    fn watch_event_class_label_round_trips() {
        use crate::ClosedAxisLabel;
        for class in WatchEventClass::ALL.iter().copied() {
            assert_eq!(
                WatchEventClass::from_canonical_str(ClosedAxisLabel::as_str(class)),
                Some(class)
            );
            assert_eq!(
                WatchEventClass::from_canonical_str(&class.as_str().to_uppercase()),
                Some(class)
            );
        }
        assert_eq!(WatchEventClass::from_canonical_str("nonsense"), None);
        assert_eq!(WatchEventClass::from_canonical_str(""), None);
    }

    #[test]
    fn watch_event_class_ord_matches_all_declaration_order() {
        use std::cmp::Ordering;
        for window in WatchEventClass::ALL.windows(2) {
            assert!(
                window[0] < window[1],
                "WatchEventClass::ALL must be strictly increasing under Ord, \
                 but {:?} >= {:?}",
                window[0],
                window[1],
            );
        }
        for (i, &a) in WatchEventClass::ALL.iter().enumerate() {
            for (j, &b) in WatchEventClass::ALL.iter().enumerate() {
                let expected = i.cmp(&j);
                assert_eq!(
                    a.cmp(&b),
                    expected,
                    "WatchEventClass::cmp must match ALL-index lex for ({a:?}, {b:?})",
                );
                assert_eq!(
                    a.partial_cmp(&b),
                    Some(expected),
                    "WatchEventClass::partial_cmp must agree with cmp for ({a:?}, {b:?})",
                );
                if i == j {
                    assert_eq!(a.cmp(&b), Ordering::Equal, "Ord must be reflexive on {a:?}",);
                }
            }
        }
    }

    #[test]
    fn watch_event_class_btreemap_emits_in_declaration_order() {
        use std::collections::BTreeMap;
        let mut counts: BTreeMap<WatchEventClass, u32> = BTreeMap::new();
        counts.insert(WatchEventClass::Ignored, 3);
        counts.insert(WatchEventClass::Reload, 1);
        counts.insert(WatchEventClass::Removed, 2);
        let observed: Vec<WatchEventClass> = counts.keys().copied().collect();
        assert_eq!(
            observed,
            WatchEventClass::ALL.to_vec(),
            "BTreeMap<WatchEventClass, _> must emit keys in ALL declaration order",
        );
    }

    #[test]
    fn watch_event_class_display_matches_as_str() {
        for c in WatchEventClass::ALL.iter().copied() {
            assert_eq!(
                format!("{c}"),
                c.as_str(),
                "Display must agree with as_str for {c:?}",
            );
        }
    }

    #[test]
    fn watch_event_class_from_str_round_trips_over_every_variant() {
        for c in WatchEventClass::ALL {
            let rendered = c.to_string();
            let parsed: WatchEventClass = rendered
                .parse()
                .expect("FromStr must round-trip Display output");
            assert_eq!(parsed, *c, "FromStr must round-trip {c:?}");
        }
    }

    #[test]
    fn watch_event_class_from_str_is_case_insensitive() {
        assert_eq!(
            "RELOAD".parse::<WatchEventClass>().unwrap(),
            WatchEventClass::Reload,
        );
        assert_eq!(
            "Removed".parse::<WatchEventClass>().unwrap(),
            WatchEventClass::Removed,
        );
        assert_eq!(
            "iGnOrEd".parse::<WatchEventClass>().unwrap(),
            WatchEventClass::Ignored,
        );
        assert_eq!(
            "rElOaD".parse::<WatchEventClass>().unwrap(),
            WatchEventClass::Reload,
        );
    }

    #[test]
    fn watch_event_class_from_str_unknown_class_error_carries_label_verbatim() {
        for bad in &["modify", "create", "rename", "", " reload"] {
            let err = bad
                .parse::<WatchEventClass>()
                .expect_err("non-canonical label must reject");
            let rendered = err.to_string();
            assert!(
                rendered.contains(bad),
                "rendered error must contain the offending label verbatim: \
                 input={bad:?}, rendered={rendered:?}",
            );
        }
    }

    #[test]
    fn watch_event_class_serde_yaml_round_trips_over_every_variant() {
        for c in WatchEventClass::ALL {
            let yaml = serde_yaml::to_string(c).expect("Serialize must succeed");
            let parsed: WatchEventClass =
                serde_yaml::from_str(&yaml).expect("Deserialize must accept Serialize output");
            assert_eq!(parsed, *c, "serde_yaml round-trip must preserve {c:?}");
        }
    }

    #[test]
    fn watch_event_class_serde_json_round_trips_over_every_variant() {
        for c in WatchEventClass::ALL {
            let json = serde_json::to_string(c).expect("Serialize must succeed");
            let parsed: WatchEventClass =
                serde_json::from_str(&json).expect("Deserialize must accept Serialize output");
            assert_eq!(parsed, *c, "serde_json round-trip must preserve {c:?}");
        }
    }

    #[test]
    fn watch_event_class_serde_yaml_is_case_insensitive() {
        let cases: &[(&str, WatchEventClass)] = &[
            ("Reload", WatchEventClass::Reload),
            ("REMOVED", WatchEventClass::Removed),
            ("IgNoReD", WatchEventClass::Ignored),
            ("rElOaD", WatchEventClass::Reload),
        ];
        for (input, expected) in cases {
            let parsed: WatchEventClass =
                serde_yaml::from_str(input).expect("case-insensitive Deserialize must succeed");
            assert_eq!(
                parsed, *expected,
                "serde_yaml must parse case-insensitively for input {input:?}",
            );
        }
    }

    #[test]
    fn watch_event_class_serde_yaml_unknown_class_error_carries_label_verbatim() {
        for bad in &["modify", "create", "rename", "noop"] {
            let err = serde_yaml::from_str::<WatchEventClass>(bad)
                .expect_err("non-canonical label must reject");
            let rendered = err.to_string();
            assert!(
                rendered.contains(bad),
                "rendered serde error must contain the offending label verbatim: \
                 input={bad:?}, rendered={rendered:?}",
            );
        }
    }

    #[test]
    fn watch_event_class_serde_yaml_emission_is_bare_scalar() {
        assert_eq!(
            serde_yaml::to_string(&WatchEventClass::Reload).unwrap(),
            "reload\n",
        );
        assert_eq!(
            serde_yaml::to_string(&WatchEventClass::Removed).unwrap(),
            "removed\n",
        );
        assert_eq!(
            serde_yaml::to_string(&WatchEventClass::Ignored).unwrap(),
            "ignored\n",
        );
    }

    #[test]
    fn symlink_target_regular_file_returns_none() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("regular.txt");
        fs::write(&file, "hello").unwrap();
        assert!(symlink_target(&file).is_none());
    }

    #[test]
    fn symlink_target_nonexistent_returns_none() {
        assert!(symlink_target(Path::new("/nonexistent/path")).is_none());
    }

    // Symlink fixtures use `std::os::unix::fs::symlink`, which only exists on
    // unix targets; gate them so the test module still compiles elsewhere.
    #[test]
    #[cfg(unix)]
    fn symlink_target_resolves_symlink() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("target.yaml");
        fs::write(&target, "key: value").unwrap();
        let link = dir.path().join("link.yaml");
        std::os::unix::fs::symlink(&target, &link).unwrap();
        let resolved = symlink_target(&link);
        assert!(resolved.is_some());
        assert_eq!(resolved.unwrap(), fs::canonicalize(&target).unwrap());
    }

    #[test]
    fn watch_regular_file_detects_change() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("config.yaml");
        fs::write(&file, "key: old").unwrap();
        let events = Arc::new(Mutex::new(Vec::new()));
        let events_clone = events.clone();
        let _watcher = ConfigWatcher::watch(&file, move |event| {
            events_clone.lock().unwrap().push(event);
        })
        .unwrap();
        thread::sleep(Duration::from_millis(100));
        fs::write(&file, "key: new").unwrap();
        let saw_event = wait_for(EVENT_TIMEOUT, || !events.lock().unwrap().is_empty());
        assert!(saw_event, "expected at least one file change event");
    }

    #[test]
    #[cfg(unix)]
    fn watch_symlink_creates_poll_watcher() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("target.yaml");
        fs::write(&target, "key: value").unwrap();
        let link = dir.path().join("link.yaml");
        std::os::unix::fs::symlink(&target, &link).unwrap();
        let _watcher = ConfigWatcher::watch(&link, |_event| {}).unwrap();
    }

    #[test]
    fn watch_nonexistent_file_errors() {
        let result = ConfigWatcher::watch(Path::new("/nonexistent/config.yaml"), |_| {});
        assert!(result.is_err());
    }

    #[test]
    #[cfg(unix)]
    fn symlink_target_broken_symlink_returns_none() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("deleted_target.yaml");
        let link = dir.path().join("broken_link.yaml");
        fs::write(&target, "key: value").unwrap();
        std::os::unix::fs::symlink(&target, &link).unwrap();
        fs::remove_file(&target).unwrap();
        let result = symlink_target(&link);
        assert!(result.is_none(), "broken symlink should return None");
    }

    #[test]
    #[cfg(unix)]
    fn symlink_target_directory_symlink() {
        let dir = TempDir::new().unwrap();
        let target_dir = dir.path().join("target_dir");
        fs::create_dir_all(&target_dir).unwrap();
        let link = dir.path().join("link_dir");
        std::os::unix::fs::symlink(&target_dir, &link).unwrap();
        let result = symlink_target(&link);
        assert!(result.is_some());
        assert_eq!(result.unwrap(), fs::canonicalize(&target_dir).unwrap());
    }

    #[test]
    fn rewatch_creates_new_watcher() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("rewatch.yaml");
        fs::write(&file, "key: value").unwrap();
        let events = Arc::new(Mutex::new(Vec::new()));
        let events_clone = events.clone();
        let _watcher = ConfigWatcher::rewatch(&file, move |event| {
            events_clone.lock().unwrap().push(event);
        })
        .unwrap();
        thread::sleep(Duration::from_millis(100));
        fs::write(&file, "key: updated").unwrap();
        let saw_event = wait_for(EVENT_TIMEOUT, || !events.lock().unwrap().is_empty());
        assert!(saw_event, "rewatch should detect file changes");
    }

    #[test]
    #[cfg(unix)]
    fn watch_symlink_detects_target_change() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("target.yaml");
        fs::write(&target, "key: original").unwrap();
        let link = dir.path().join("watched_link.yaml");
        std::os::unix::fs::symlink(&target, &link).unwrap();
        let events = Arc::new(Mutex::new(Vec::new()));
        let events_clone = events.clone();
        let _watcher = ConfigWatcher::watch(&link, move |event| {
            events_clone.lock().unwrap().push(event);
        })
        .unwrap();
        thread::sleep(Duration::from_millis(200));
        fs::write(&target, "key: modified").unwrap();
        // The poll watcher only rescans every few seconds. Wait up to the original
        // 4 s budget, but return as soon as something arrives.
        wait_for(Duration::from_millis(4000), || {
            !events.lock().unwrap().is_empty()
        });
        let captured = events.lock().unwrap();
        if !captured.is_empty() {
            assert!(captured.iter().any(|e| !e.paths.is_empty()));
        }
    }

    #[test]
    fn watch_callback_receives_event_with_path() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("pathcheck.yaml");
        fs::write(&file, "key: value").unwrap();
        let paths = Arc::new(Mutex::new(Vec::new()));
        let paths_clone = paths.clone();
        let _watcher = ConfigWatcher::watch(&file, move |event| {
            for p in &event.paths {
                paths_clone.lock().unwrap().push(p.clone());
            }
        })
        .unwrap();
        thread::sleep(Duration::from_millis(100));
        fs::write(&file, "key: new_value").unwrap();
        thread::sleep(Duration::from_millis(500));
        let captured = paths.lock().unwrap();
        if !captured.is_empty() {
            assert!(
                captured
                    .iter()
                    .any(|p| { p.display().to_string().contains("pathcheck") }),
                "expected event path to reference the watched file"
            );
        }
    }

    #[test]
    #[cfg(unix)]
    fn symlink_target_nested_symlink() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("real.yaml");
        fs::write(&target, "key: value").unwrap();
        let link1 = dir.path().join("link1.yaml");
        std::os::unix::fs::symlink(&target, &link1).unwrap();
        let link2 = dir.path().join("link2.yaml");
        std::os::unix::fs::symlink(&link1, &link2).unwrap();
        let resolved = symlink_target(&link2);
        assert!(resolved.is_some());
        assert_eq!(resolved.unwrap(), fs::canonicalize(&target).unwrap());
    }

    #[test]
    fn rewatch_nonexistent_file_errors() {
        let result = ConfigWatcher::rewatch(Path::new("/nonexistent/rewatch.yaml"), |_| {});
        assert!(result.is_err());
    }

    #[test]
    fn symlink_target_returns_none_for_plain_directory() {
        let dir = TempDir::new().unwrap();
        assert!(symlink_target(dir.path()).is_none());
    }
}