use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use notify::{EventKind, RecursiveMode, Watcher, recommended_watcher};
use parking_lot::Mutex;
use tracing::{debug, info, warn};
use trusty_mpm_core::hook::{HookEvent, HookEventRecord};
use trusty_mpm_core::paths::FrameworkPaths;
use trusty_mpm_core::session::SessionId;
use crate::state::DaemonState;
/// Attributes filesystem changes under session working directories to sessions.
///
/// Each session registers a single root directory. [`FileWatcher::spawn`] watches
/// those roots and records changes beneath them as `FileChanged` hook events on
/// the shared daemon state.
pub struct FileWatcher {
    /// Session id → root directory whose changes are attributed to that session.
    roots: Mutex<HashMap<SessionId, PathBuf>>,
    /// Daemon state that receives the synthesised hook events.
    state: Arc<DaemonState>,
}
impl FileWatcher {
pub fn new(state: Arc<DaemonState>) -> Self {
Self {
state,
roots: Mutex::new(HashMap::new()),
}
}
pub fn watch_session(&self, session: SessionId, root: PathBuf) -> Option<PathBuf> {
self.roots.lock().insert(session, root)
}
#[allow(dead_code)] pub fn unwatch_session(&self, session: SessionId) -> Option<PathBuf> {
self.roots.lock().remove(&session)
}
pub fn watched_count(&self) -> usize {
self.roots.lock().len()
}
pub fn session_for_path(&self, path: &std::path::Path) -> Option<SessionId> {
let roots = self.roots.lock();
roots
.iter()
.filter(|(_, root)| path.starts_with(root))
.max_by_key(|(_, root)| root.as_os_str().len())
.map(|(session, _)| *session)
}
pub async fn spawn(self) {
for session in self.state.list_sessions() {
let root = PathBuf::from(&session.workdir);
if root.is_dir() {
self.watch_session(session.id, root);
}
}
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<PathBuf>();
let mut watcher = match recommended_watcher(move |res: notify::Result<notify::Event>| {
if let Ok(event) = res
&& matches!(
event.kind,
EventKind::Create(_) | EventKind::Modify(_) | EventKind::Remove(_)
)
{
for path in event.paths {
let _ = tx.send(path);
}
}
}) {
Ok(w) => w,
Err(e) => {
warn!("file watcher unavailable: {e}");
return;
}
};
let roots: Vec<PathBuf> = self.roots.lock().values().cloned().collect();
for root in &roots {
if let Err(e) = watcher.watch(root, RecursiveMode::Recursive) {
warn!("failed to watch {}: {e}", root.display());
} else {
debug!("watching {}", root.display());
}
}
let hooks = FrameworkPaths::default().hooks;
if hooks.is_dir() {
if let Err(e) = watcher.watch(&hooks, RecursiveMode::NonRecursive) {
warn!("failed to watch hooks dir {}: {e}", hooks.display());
} else {
debug!("watching framework hooks dir {}", hooks.display());
}
}
info!("file watcher started ({} root(s))", self.watched_count());
while let Some(path) = rx.recv().await {
if self.record_change(&path) {
debug!("recorded file change: {}", path.display());
}
}
}
fn is_optimizer_policy(path: &std::path::Path) -> bool {
path.file_name()
.is_some_and(|name| name == std::ffi::OsStr::new("optimizer.toml"))
}
pub fn record_change(&self, path: &std::path::Path) -> bool {
if Self::is_optimizer_policy(path) {
self.state.reload_optimizer_config();
debug!("reloaded optimizer config after {} changed", path.display());
return true;
}
let Some(session) = self.session_for_path(path) else {
return false;
};
let payload = serde_json::json!({ "path": path.to_string_lossy() });
self.state.push_hook_event(HookEventRecord::now(
session,
HookEvent::FileChanged,
payload,
));
true
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// Builds a watcher over a fresh shared daemon state and returns both.
    fn fresh() -> (Arc<DaemonState>, FileWatcher) {
        let state = DaemonState::shared();
        let watcher = FileWatcher::new(Arc::clone(&state));
        (state, watcher)
    }

    #[test]
    fn register_and_unregister_roots() {
        let (_, watcher) = fresh();
        let session = SessionId::new();
        let root = PathBuf::from("/tmp/proj");

        assert_eq!(watcher.watched_count(), 0);
        let previous = watcher.watch_session(session, root.clone());
        assert!(previous.is_none());
        assert_eq!(watcher.watched_count(), 1);
        assert_eq!(watcher.unwatch_session(session), Some(root));
        assert_eq!(watcher.watched_count(), 0);
    }

    #[test]
    fn attributes_path_to_longest_matching_root() {
        let (_, watcher) = fresh();
        let (outer, inner) = (SessionId::new(), SessionId::new());
        watcher.watch_session(outer, PathBuf::from("/tmp/proj"));
        watcher.watch_session(inner, PathBuf::from("/tmp/proj/sub"));

        let cases = [
            ("/tmp/proj/sub/main.rs", Some(inner)),
            ("/tmp/proj/README.md", Some(outer)),
            ("/elsewhere/x", None),
        ];
        for (raw, expected) in cases {
            assert_eq!(watcher.session_for_path(Path::new(raw)), expected, "path {raw}");
        }
    }

    #[test]
    fn detects_optimizer_toml_change() {
        let (state, watcher) = fresh();
        assert!(FileWatcher::is_optimizer_policy(Path::new(
            "/anywhere/hooks/optimizer.toml"
        )));
        assert!(!FileWatcher::is_optimizer_policy(Path::new(
            "/anywhere/hooks/other.toml"
        )));

        let policy = Path::new("/x/.trusty-mpm/framework/hooks/optimizer.toml");
        assert!(watcher.record_change(policy));
        // A policy reload must not show up as a hook event.
        assert_eq!(state.recent_hook_events().len(), 0);
    }

    #[test]
    fn synthesises_file_changed_event() {
        let (state, watcher) = fresh();
        let session = SessionId::new();
        watcher.watch_session(session, PathBuf::from("/tmp/proj"));

        assert!(watcher.record_change(Path::new("/tmp/proj/src/lib.rs")));
        let events = state.hook_events_for(session);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].event, HookEvent::FileChanged);

        // Paths outside every registered root are ignored and record nothing.
        assert!(!watcher.record_change(Path::new("/nowhere/x")));
        assert_eq!(state.recent_hook_events().len(), 1);
    }
}