use std::collections::HashMap;
use std::ffi::OsStr;
use std::path::{Path, PathBuf};
use std::sync::mpsc::{Receiver, TryRecvError};
use std::sync::{Arc, Mutex, MutexGuard};
use notify::{
Config as NotifyConfig, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher,
};
use super::error::{SessionError, SessionResult};
use crate::config::buffers::watch_event_queue_capacity;
/// Name of the file, directly inside a watched directory, whose changes trigger that
/// directory's registered callback.
const INDEX_FILE_NAME: &str = ".sqry-index";
/// Change callback stored per watched directory; shared via `Arc` so it can be cloned out
/// of the registry and invoked without holding the registry lock.
type Callback = Arc<dyn Fn() + Sync + Send + 'static>;
/// Live state of an enabled `FileWatcher`.
// Fields are dropped in declaration order, so `watcher` is dropped before `rx`.
struct WatcherState {
    /// The `notify` backend; the sending half of `rx` is moved into its event handler.
    watcher: RecommendedWatcher,
    /// Receiving end of the channel the backend's event handler sends events/errors into.
    rx: Receiver<notify::Result<Event>>,
    /// Registered callbacks, keyed by the directory path passed to `FileWatcher::watch`.
    // NOTE(review): this `Arc` is never cloned in this file; confirm whether shared
    // ownership is still needed or a plain `Mutex` would do.
    callbacks: Arc<Mutex<HashMap<PathBuf, Callback>>>,
}
impl WatcherState {
    /// Locks the callback registry and returns the guard.
    ///
    /// If the mutex is poisoned (a thread panicked while holding it), the guard is recovered
    /// instead of propagating the panic. The map is assumed to still be usable in that case.
    fn lock_callbacks(&self) -> MutexGuard<'_, HashMap<PathBuf, Callback>> {
        match self.callbacks.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }
}
/// Watches directories and runs a per-directory callback when the `.sqry-index` file
/// inside them changes.
///
/// Backend events are queued and only dispatched when the owner calls `process_events`
/// or `wait_and_process`, so callbacks run on the thread that calls those methods.
pub struct FileWatcher {
    /// `None` for a watcher built with `FileWatcher::disabled`; every method is then a
    /// no-op returning `Ok(())`.
    state: Option<WatcherState>,
}
impl FileWatcher {
pub fn new() -> SessionResult<Self> {
let capacity = watch_event_queue_capacity();
let (tx, rx) = std::sync::mpsc::sync_channel(capacity);
let watcher = RecommendedWatcher::new(
move |event| {
let _ = tx.send(event);
},
NotifyConfig::default(),
)
.map_err(SessionError::WatcherInit)?;
Ok(Self {
state: Some(WatcherState {
watcher,
rx,
callbacks: Arc::new(Mutex::new(HashMap::new())),
}),
})
}
#[must_use]
pub fn disabled() -> Self {
Self { state: None }
}
pub fn watch<F>(&mut self, path: PathBuf, on_change: F) -> SessionResult<()>
where
F: Fn() + Send + Sync + 'static,
{
let Some(state) = &mut self.state else {
return Ok(());
};
if state.lock_callbacks().contains_key(&path) {
return Ok(());
}
state
.watcher
.watch(&path, RecursiveMode::NonRecursive)
.map_err(|source| SessionError::WatchIndex {
path: path.clone(),
source,
})?;
state.lock_callbacks().insert(path, Arc::new(on_change));
Ok(())
}
pub fn unwatch(&mut self, path: &Path) -> SessionResult<()> {
let Some(state) = &mut self.state else {
return Ok(());
};
if state.lock_callbacks().remove(path).is_some() {
state
.watcher
.unwatch(path)
.map_err(|source| SessionError::UnwatchIndex {
path: path.to_path_buf(),
source,
})?;
}
Ok(())
}
pub fn process_events(&mut self) -> SessionResult<()> {
let Some(state) = &mut self.state else {
return Ok(());
};
loop {
match state.rx.try_recv() {
Ok(Ok(event)) => Self::handle_event(state, &event),
Ok(Err(err)) => {
log::warn!("file watcher error: {err}");
}
Err(TryRecvError::Empty | TryRecvError::Disconnected) => break,
}
}
Ok(())
}
pub fn wait_and_process(&mut self, duration: std::time::Duration) -> SessionResult<()> {
let Some(state) = &mut self.state else {
return Ok(());
};
let deadline = std::time::Instant::now() + duration;
while std::time::Instant::now() < deadline {
let remaining = deadline.saturating_duration_since(std::time::Instant::now());
let poll_interval = std::time::Duration::from_millis(10).min(remaining);
match state.rx.recv_timeout(poll_interval) {
Ok(Ok(event)) => Self::handle_event(state, &event),
Ok(Err(err)) => {
log::warn!("file watcher error: {err}");
}
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
}
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
break;
}
}
}
Ok(())
}
fn handle_event(state: &WatcherState, event: &Event) {
use EventKind::{Any, Create, Modify, Remove};
let relevant = matches!(event.kind, Modify(_) | Create(_) | Remove(_) | Any);
if !relevant {
return;
}
let mut callbacks_to_run: Vec<Callback> = Vec::new();
{
let callbacks = state.lock_callbacks();
for path in &event.paths {
if path
.file_name()
.is_some_and(|name| name == OsStr::new(INDEX_FILE_NAME))
&& let Some(parent) = path.parent()
&& let Some(callback) = callbacks.get(parent)
{
callbacks_to_run.push(Arc::clone(callback));
}
}
}
for callback in callbacks_to_run {
callback();
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;
    use tempfile::tempdir;

    /// How long to wait for the backend to report a change.
    ///
    /// macOS gets a larger base budget than other platforms. The budget is doubled whenever
    /// the `CI` environment variable is set to a valid Unicode value.
    fn event_timeout() -> Duration {
        let platform_secs: u64 = if cfg!(target_os = "macos") { 3 } else { 1 };
        let ci_factor: u64 = if std::env::var("CI").is_ok() { 2 } else { 1 };
        Duration::from_secs(platform_secs * ci_factor)
    }

    #[test]
    #[cfg_attr(target_os = "macos", ignore = "FSEvents timing flaky in CI")]
    fn detects_changes_to_index_file() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        let index_file = root.join(".sqry-index");
        std::fs::write(&index_file, b"initial").unwrap();

        let mut watcher = FileWatcher::new().unwrap();
        let fired = Arc::new(AtomicBool::new(false));
        let fired_in_callback = Arc::clone(&fired);
        watcher
            .watch(root.to_path_buf(), move || {
                fired_in_callback.store(true, Ordering::SeqCst);
            })
            .unwrap();

        // Rewrite the index only after the watch is in place.
        std::fs::write(&index_file, b"modified").unwrap();
        watcher.wait_and_process(event_timeout()).unwrap();

        assert!(fired.load(Ordering::SeqCst));
    }

    #[test]
    fn disabled_watcher_is_noop() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        std::fs::write(root.join(".sqry-index"), b"data").unwrap();

        let mut watcher = FileWatcher::disabled();
        watcher
            .watch(root.to_path_buf(), || {
                panic!("disabled watcher should not invoke callback");
            })
            .unwrap();
        watcher.process_events().unwrap();
    }
}