use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
use notify_debouncer_mini::{new_debouncer, DebouncedEvent, Debouncer};
use parking_lot::Mutex;
use tokio::sync::mpsc;
use super::ipc::Frame;
use super::state::DaemonState;
use super::Result;
/// Snapshot of the watcher's counters, as returned by `FsWatcher::stats`.
#[derive(Debug, Default, Clone, serde::Serialize)]
pub struct WatcherStats {
    /// Debounced filesystem events handled, whether or not they matched a
    /// registered path.
    pub events_received: u64,
    /// Events that matched a registered path and were broadcast as
    /// `shard_updated`.
    pub events_routed: u64,
    /// Number of paths currently registered with the watcher.
    pub watched_path_count: usize,
}
/// A path registered with the watcher, together with the identifiers that are
/// attached to every `shard_updated` event routed from it.
#[derive(Debug, Clone)]
pub struct WatchedPath {
    /// Path to watch; `FsWatcher::watch_path` canonicalizes it when possible.
    pub path: PathBuf,
    /// Reported as the `shard` field of `shard_updated` events.
    pub shard_slug: String,
    /// Reported verbatim as the `source_root` field of `shard_updated` events.
    pub source_root: String,
}
/// Filesystem watcher that maps change events under registered paths to
/// `shard_updated` broadcasts.
///
/// All mutable state sits behind a single lock, so one instance can be shared
/// through an `Arc` (which `FsWatcher::start` requires).
pub struct FsWatcher {
    inner: Mutex<FsWatcherInner>,
}
struct FsWatcherInner {
registered: HashMap<PathBuf, WatchedPath>,
stats: WatcherStats,
debouncer: Option<Debouncer<RecommendedWatcher>>,
}
impl FsWatcher {
pub fn new() -> Self {
Self {
inner: Mutex::new(FsWatcherInner {
registered: HashMap::new(),
stats: WatcherStats::default(),
debouncer: None,
}),
}
}
pub fn start(self: &Arc<Self>, state: Arc<DaemonState>) -> Result<()> {
let (tx, mut rx) = mpsc::unbounded_channel::<DebouncedResult>();
let debouncer = new_debouncer(
Duration::from_millis(150),
move |res: notify_debouncer_mini::DebounceEventResult| {
let _ = tx.send(res);
},
)
.map_err(|e| super::DaemonError::other(format!("fsnotify init: {e}")))?;
{
let mut g = self.inner.lock();
g.debouncer = Some(debouncer);
}
let me = Arc::clone(self);
tokio::spawn(async move {
while let Some(res) = rx.recv().await {
me.handle_debounced(res, &state);
}
});
tracing::info!("fsnotify watcher started");
Ok(())
}
pub fn watch_path(&self, mut wp: WatchedPath, recursive: bool) -> Result<()> {
let canonical = std::fs::canonicalize(&wp.path).unwrap_or(wp.path.clone());
wp.path = canonical;
let mut g = self.inner.lock();
let key = wp.path.clone();
let mode = if recursive {
RecursiveMode::Recursive
} else {
RecursiveMode::NonRecursive
};
if let Some(deb) = g.debouncer.as_mut() {
deb.watcher()
.watch(&wp.path, mode)
.map_err(|e| super::DaemonError::other(format!("fsnotify watch: {e}")))?;
}
g.registered.insert(key, wp);
g.stats.watched_path_count = g.registered.len();
Ok(())
}
pub fn unwatch_path(&self, path: &Path) -> Result<()> {
let canonical = std::fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf());
let mut g = self.inner.lock();
if let Some(deb) = g.debouncer.as_mut() {
let _ = deb.watcher().unwatch(&canonical);
}
g.registered.remove(&canonical);
g.registered.remove(path); g.stats.watched_path_count = g.registered.len();
Ok(())
}
pub fn stats(&self) -> WatcherStats {
let g = self.inner.lock();
g.stats.clone()
}
pub fn registered_paths(&self) -> Vec<WatchedPath> {
let g = self.inner.lock();
g.registered.values().cloned().collect()
}
fn handle_debounced(&self, res: DebouncedResult, state: &Arc<DaemonState>) {
let events = match res {
Ok(events) => events,
Err(e) => {
tracing::warn!(?e, "fsnotify error");
return;
}
};
for ev in events {
self.handle_one_event(&ev, state);
}
}
fn handle_one_event(&self, ev: &DebouncedEvent, state: &Arc<DaemonState>) {
let mut g = self.inner.lock();
g.stats.events_received += 1;
let path = &ev.path;
let best_match: Option<WatchedPath> = g
.registered
.values()
.filter(|wp| path.starts_with(&wp.path))
.max_by_key(|wp| wp.path.as_os_str().len())
.cloned();
let Some(wp) = best_match else {
tracing::debug!(path = %path.display(), "fsnotify event with no matching registration");
return;
};
g.stats.events_routed += 1;
drop(g);
tracing::info!(
path = %path.display(),
shard = %wp.shard_slug,
source_root = %wp.source_root,
"fsnotify event routed"
);
let payload = serde_json::json!({
"shard": wp.shard_slug,
"source_root": wp.source_root,
"trigger_path": path.display().to_string(),
"generation": null,
});
let frame = Frame::event("shard_updated", payload);
let _ = state.broadcast(frame, &[]);
}
}
/// What the debouncer callback delivers: either a batch of debounced events
/// or an error, which `FsWatcher::handle_debounced` logs and drops.
type DebouncedResult = notify_debouncer_mini::DebounceEventResult;
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::time::Instant;
    use tempfile::TempDir;

    /// Upper bound on how long to wait for the debouncer to deliver events.
    /// This is generous on purpose: a fixed short sleep is flaky on loaded CI.
    const EVENT_TIMEOUT: Duration = Duration::from_secs(5);

    /// A started watcher over a fresh temp dir.
    ///
    /// Fields drop in declaration order: watcher first, temp dir last.
    struct Fixture {
        watcher: Arc<FsWatcher>,
        watch_dir: PathBuf,
        _state: Arc<DaemonState>,
        _tmp: TempDir,
    }

    /// Create the temp cache root and a `zsh_funcs` dir, then start a watcher.
    fn setup() -> Fixture {
        let tmp = TempDir::new().unwrap();
        let watch_dir = tmp.path().join("zsh_funcs");
        std::fs::create_dir_all(&watch_dir).unwrap();
        let paths = super::super::paths::CachePaths::with_root(tmp.path().join("zshrs"));
        paths.ensure_dirs().unwrap();
        let state = super::super::state::DaemonState::new(paths).unwrap();
        let watcher = Arc::new(FsWatcher::new());
        watcher.start(Arc::clone(&state)).unwrap();
        Fixture {
            watcher,
            watch_dir,
            _state: state,
            _tmp: tmp,
        }
    }

    /// Registration for `dir` under the shard slug `"test"`.
    fn watched(dir: &Path) -> WatchedPath {
        WatchedPath {
            path: dir.to_path_buf(),
            shard_slug: "test".to_string(),
            source_root: dir.display().to_string(),
        }
    }

    /// Poll `watcher.stats()` until `done` holds or `EVENT_TIMEOUT` elapses,
    /// returning the last snapshot either way.
    async fn wait_for_stats(
        watcher: &FsWatcher,
        done: impl Fn(&WatcherStats) -> bool,
    ) -> WatcherStats {
        let deadline = Instant::now() + EVENT_TIMEOUT;
        loop {
            let stats = watcher.stats();
            if done(&stats) || Instant::now() >= deadline {
                return stats;
            }
            tokio::time::sleep(Duration::from_millis(25)).await;
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn watch_register_unregister() {
        let fx = setup();
        fx.watcher
            .watch_path(watched(&fx.watch_dir), false)
            .unwrap();
        assert_eq!(fx.watcher.stats().watched_path_count, 1);
        let registered = fx.watcher.registered_paths();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].shard_slug, "test");
        fx.watcher.unwatch_path(&fx.watch_dir).unwrap();
        assert_eq!(fx.watcher.stats().watched_path_count, 0);
        assert!(fx.watcher.registered_paths().is_empty());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn fsnotify_routes_event_to_shard() {
        let fx = setup();
        fx.watcher
            .watch_path(watched(&fx.watch_dir), false)
            .unwrap();
        std::fs::write(fx.watch_dir.join("_git"), b"# completion file").unwrap();
        let stats = wait_for_stats(&fx.watcher, |s| s.events_routed >= 1).await;
        assert!(
            stats.events_received >= 1,
            "events_received = {}",
            stats.events_received
        );
        assert!(
            stats.events_routed >= 1,
            "events_routed = {}",
            stats.events_routed
        );
    }
}