use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::io;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use tor_rtcompat::Runtime;
use amplify::Getters;
use notify::{EventKind, Watcher};
use postage::watch;
use futures::{Stream, StreamExt as _};
pub type Result<T> = std::result::Result<T, FileWatcherBuildError>;
cfg_if::cfg_if! {
if #[cfg(any(target_os = "linux", target_os = "android", target_os = "windows"))] {
type NotifyWatcher = notify::RecommendedWatcher;
} else {
type NotifyWatcher = notify::PollWatcher;
}
}
/// A running watch over a set of files and directories.
///
/// This owns the underlying `notify` watcher, which in turn owns the event callback (and
/// with it the [`FileEventSender`]). Dropping a `FileWatcher` therefore stops all further
/// events.
#[derive(Getters)]
pub struct FileWatcher {
    /// The `notify` watcher producing events; held only so that it stays alive.
    #[getter(skip)]
    _watcher: NotifyWatcher,
    /// The absolute directories that were registered with the watcher.
    watching_dirs: HashSet<PathBuf>,
}

impl FileWatcher {
    /// Start configuring a new [`FileWatcher`].
    ///
    /// `runtime` is handed straight to [`FileWatcherBuilder::new`].
    pub fn builder<R: Runtime>(runtime: R) -> FileWatcherBuilder<R> {
        FileWatcherBuilder::<R>::new(runtime)
    }
}
/// A notification delivered through a [`FileEventReceiver`].
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum Event {
    /// A watched file or directory was changed (for example created, modified, renamed,
    /// or removed).
    FileChanged,
    /// The `notify` backend flagged an event as needing a rescan, so everything being
    /// watched should be re-read.
    ///
    /// This is also the initial value of a channel created by [`channel`].
    Rescan,
}
/// Builder for configuring and then starting a [`FileWatcher`].
pub struct FileWatcherBuilder<R: Runtime> {
    /// Records the runtime type `R`; no runtime value is stored.
    // The field is never read (only its type matters), hence the lint allowance.
    #[allow(dead_code)]
    runtime: PhantomData<R>,
    /// Each watched directory (an absolute path), mapped to the filters that decide which
    /// paths directly inside it produce events.
    watching_dirs: HashMap<PathBuf, HashSet<DirEventFilter>>,
}
/// A rule deciding whether an event on a path inside a watched directory is relevant.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum DirEventFilter {
    /// Accept any path whose extension (without the leading `.`) equals this string.
    MatchesExtension(String),
    /// Accept exactly this path (compared component-wise).
    MatchesPath(PathBuf),
}

impl DirEventFilter {
    /// Return whether this filter accepts `path`.
    ///
    /// Paths with no extension, or whose extension is not valid UTF-8, never satisfy
    /// [`DirEventFilter::MatchesExtension`].
    fn accepts_path(&self, path: &Path) -> bool {
        match self {
            Self::MatchesPath(wanted) => wanted == path,
            Self::MatchesExtension(wanted_ext) => matches!(
                path.extension().and_then(|os_ext| os_ext.to_str()),
                Some(found) if found == wanted_ext.as_str()
            ),
        }
    }
}
impl<R: Runtime> FileWatcherBuilder<R> {
    /// Create a builder that is not yet watching anything.
    ///
    /// `_runtime` is currently unused: only its type `R` is recorded.
    pub fn new(_runtime: R) -> Self {
        FileWatcherBuilder {
            runtime: PhantomData,
            watching_dirs: HashMap::new(),
        }
    }

    /// Watch the single file or directory at `path`.
    ///
    /// The parent directory of `path` is watched non-recursively. An [`Event::FileChanged`]
    /// is reported for events on `path` itself, such as it being created, modified, renamed,
    /// or removed. A relative `path` is resolved against the current working directory at
    /// the time of this call.
    ///
    /// Watching the same path more than once has the same effect as watching it once.
    ///
    /// # Errors
    ///
    /// Returns [`FileWatcherBuildError::CurrentDirectory`] if the current working directory
    /// cannot be determined.
    pub fn watch_path<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
        self.watch_just_parents(path.as_ref())?;
        Ok(())
    }

    /// Watch the directory at `path` (but not its subdirectories) for files with the given
    /// `extension`.
    ///
    /// An [`Event::FileChanged`] is reported in two cases:
    /// * events on any path directly inside the directory whose extension equals
    ///   `extension`;
    /// * events on the directory path itself, as with [`watch_path`](Self::watch_path).
    ///
    /// `extension` is compared against [`Path::extension`], so it must not include the
    /// leading `.`. Paths whose extension is not valid UTF-8 never match.
    ///
    /// # Errors
    ///
    /// Returns [`FileWatcherBuildError::CurrentDirectory`] if the current working directory
    /// cannot be determined.
    pub fn watch_dir<P: AsRef<Path>, S: AsRef<str>>(
        &mut self,
        path: P,
        extension: S,
    ) -> Result<()> {
        let path = self.watch_just_parents(path.as_ref())?;
        self.watch_just_abs_dir(
            &path,
            DirEventFilter::MatchesExtension(extension.as_ref().into()),
        );
        Ok(())
    }

    /// Make `path` absolute and watch it through its parent directory.
    ///
    /// Returns the absolute form of `path`.
    fn watch_just_parents(&mut self, path: &Path) -> Result<PathBuf> {
        // Make the path absolute, without canonicalizing it. `handle_event` matches event
        // paths against the absolute directory names registered here.
        let cwd = std::env::current_dir()
            .map_err(|e| FileWatcherBuildError::CurrentDirectory(Arc::new(e)))?;
        let path = cwd.join(path);
        debug_assert!(path.is_absolute());

        // An absolute path without a parent is a filesystem root. There is nothing above it
        // to watch, so watch the root itself.
        let watch_target = path.parent().unwrap_or(path.as_path());
        self.watch_just_abs_dir(watch_target, DirEventFilter::MatchesPath(path.clone()));
        Ok(path)
    }

    /// Register `filter` for the directory `watch_target`, which should be absolute.
    ///
    /// Registering a filter that is already present for that directory has no effect.
    fn watch_just_abs_dir(&mut self, watch_target: &Path, filter: DirEventFilter) {
        match self.watching_dirs.entry(watch_target.to_path_buf()) {
            Entry::Occupied(mut o) => {
                let _: bool = o.get_mut().insert(filter);
            }
            Entry::Vacant(v) => {
                let _ = v.insert(HashSet::from([filter]));
            }
        }
    }

    /// Consume the builder, start watching every registered directory (non-recursively),
    /// and publish relevant [`Event`]s to `tx`.
    ///
    /// The callback that sends on `tx` is owned by the returned [`FileWatcher`], so events
    /// stop once it is dropped. A receiver created by [`channel`] initially holds
    /// [`Event::Rescan`].
    ///
    /// # Errors
    ///
    /// Returns [`FileWatcherBuildError::Notify`] if the `notify` watcher cannot be created,
    /// or if any registered directory cannot be watched.
    pub fn start_watching(self, tx: FileEventSender) -> Result<FileWatcher> {
        // The callback needs the full directory -> filters map, but the returned
        // `FileWatcher` only records the directory names. Collect the names first, then move
        // the map into the callback instead of cloning it.
        let watching_dirs: HashSet<PathBuf> = self.watching_dirs.keys().cloned().collect();
        let filters_by_dir = self.watching_dirs;
        let event_sender = move |event: notify::Result<notify::Event>| {
            if let Some(event) = handle_event(event, &filters_by_dir) {
                *tx.0.lock().expect("poisoned").borrow_mut() = event;
            }
        };

        cfg_if::cfg_if! {
            if #[cfg(any(target_os = "linux", target_os = "android", target_os = "windows"))] {
                let config = notify::Config::default();
            } else {
                #[cfg(not(any(test, feature = "testing")))]
                const WATCHER_POLL_INTERVAL: std::time::Duration = std::time::Duration::from_secs(5);
                #[cfg(any(test, feature = "testing"))]
                const WATCHER_POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(10);
                let config = notify::Config::default()
                    .with_poll_interval(WATCHER_POLL_INTERVAL);
                // In tests, compare file contents as well, so that rapid rewrites are noticed
                // by the poller.
                #[cfg(any(test, feature = "testing"))]
                let config = config.with_compare_contents(true);
            }
        }

        let mut watcher = NotifyWatcher::new(event_sender, config).map_err(Arc::new)?;
        for dir in &watching_dirs {
            watcher
                .watch(dir, notify::RecursiveMode::NonRecursive)
                .map_err(Arc::new)?;
        }
        Ok(FileWatcher {
            _watcher: watcher,
            watching_dirs,
        })
    }
}
/// Translate a raw `notify` callback result into an [`Event`] for the receiver, if relevant.
///
/// `watching_dirs` maps each watched directory to the filters that say which paths
/// directly inside it are of interest.
///
/// Returns:
/// * `Some(Event::Rescan)` if the backend flagged the event as needing a rescan. This is
///   checked first, so even otherwise-ignored event kinds can trigger it.
/// * `None` for event kinds rejected by [`ignore_event_kind`].
/// * `Some(Event::FileChanged)` if any path carried by the event (or by the error) is
///   accepted by one of the filters registered for its parent directory.
/// * `None` otherwise.
fn handle_event(
    event: notify::Result<notify::Event>,
    watching_dirs: &HashMap<PathBuf, HashSet<DirEventFilter>>,
) -> Option<Event> {
    // Does a filter registered for `path`'s parent directory accept `path`?
    //
    // This is a direct hash lookup rather than a scan over every watched directory.
    // `Path`'s `Hash` and `Eq` are both component-wise, so looking up `/a/b` finds a key
    // stored as `/a/b/`, exactly as the `==` comparison would.
    let watching = |path: &PathBuf| -> bool {
        let parent = path.parent().unwrap_or_else(|| path.as_path());
        match watching_dirs.get(parent) {
            Some(filters) => filters.iter().any(|filter| filter.accepts_path(path)),
            None => false,
        }
    };

    match event {
        Ok(event) if event.need_rescan() => Some(Event::Rescan),
        Ok(event) if ignore_event_kind(&event.kind) => None,
        Ok(event) => event
            .paths
            .iter()
            .any(watching)
            .then_some(Event::FileChanged),
        // Errors can still name the affected paths; report them as changes if relevant.
        Err(error) => error
            .paths
            .iter()
            .any(watching)
            .then_some(Event::FileChanged),
    }
}
fn ignore_event_kind(kind: &EventKind) -> bool {
use EventKind::*;
matches!(kind, Access(_) | Any | Other)
}
/// Sending half of a file-event channel, as created by [`channel`].
///
/// This is what gets passed to [`FileWatcherBuilder::start_watching`]. Clones share one
/// underlying `postage::watch::Sender`, kept behind a [`Mutex`] so that it can be updated
/// through a shared handle.
#[derive(Clone)]
pub struct FileEventSender(Arc<Mutex<watch::Sender<Event>>>);
/// Receiving half of a file-event channel, as created by [`channel`].
///
/// Implements `futures::Stream`, yielding the [`Event`]s published by the matching
/// [`FileEventSender`].
#[derive(Clone)]
pub struct FileEventReceiver(watch::Receiver<Event>);

impl Stream for FileEventReceiver {
    type Item = Event;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // The inner receiver is `Unpin` (which `poll_next_unpin` requires), so this wrapper
        // is `Unpin` too. We can therefore unwrap the pin and poll the inner receiver directly.
        let this = self.get_mut();
        this.0.poll_next_unpin(cx)
    }
}
impl FileEventReceiver {
    /// Return the next pending [`Event`] without waiting, or `None` if none is available.
    ///
    /// `None` is also returned if the channel is closed; the two cases are not
    /// distinguished.
    pub fn try_recv(&mut self) -> Option<Event> {
        postage::prelude::Stream::try_recv(&mut self.0).ok()
    }
}
/// Create a connected [`FileEventSender`] / [`FileEventReceiver`] pair.
///
/// The channel's initial value is [`Event::Rescan`].
pub fn channel() -> (FileEventSender, FileEventReceiver) {
    let (raw_sender, raw_receiver) = watch::channel_with(Event::Rescan);
    let shared_sender = Arc::new(Mutex::new(raw_sender));
    let sender = FileEventSender(shared_sender);
    let receiver = FileEventReceiver(raw_receiver);
    (sender, receiver)
}
#[derive(Debug, Clone, thiserror::Error)]
#[non_exhaustive]
pub enum FileWatcherBuildError {
#[error("Invalid current working directory")]
CurrentDirectory(#[source] Arc<io::Error>),
#[error("Problem creating Watcher")]
Notify(#[from] Arc<notify::Error>),
}
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use notify::event::{AccessKind, ModifyKind};
    use test_temp_dir::{TestTempDir, test_temp_dir};

    // Write `data` to the file `name` inside `dir`, returning the file's path.
    fn write_file(dir: &TestTempDir, name: &str, data: &[u8]) -> PathBuf {
        let path = dir.as_path_untracked().join(name);
        std::fs::write(&path, data).unwrap();
        path
    }

    // Build a path-less `notify` event that carries the `Rescan` flag.
    fn rescan_event() -> notify::Event {
        let event = notify::Event::new(notify::EventKind::Any);
        event.set_flag(notify::event::Flag::Rescan)
    }

    // Wait for a `FileChanged` event, then drain any further queued events (which must
    // also be `FileChanged`).
    async fn assert_file_changed(rx: &mut FileEventReceiver) {
        assert_eq!(rx.next().await, Some(Event::FileChanged));
        while let Some(ev) = rx.try_recv() {
            assert_eq!(ev, Event::FileChanged);
        }
    }

    // Check that `event`, re-tagged with each ignored event kind, yields nothing, unless it
    // also carries the `Rescan` flag.
    fn assert_ignored(event: &notify::Event, watching: &HashMap<PathBuf, HashSet<DirEventFilter>>) {
        for kind in [EventKind::Access(AccessKind::Any), EventKind::Other] {
            let ignored_event = event.clone().set_kind(kind);
            assert_eq!(handle_event(Ok(ignored_event.clone()), watching), None);
            let event = ignored_event.set_flag(notify::event::Flag::Rescan);
            assert_eq!(handle_event(Ok(event), watching), Some(Event::Rescan));
        }
    }

    #[test]
    fn notify_event_handler() {
        let mut event = notify::Event::new(notify::EventKind::Modify(ModifyKind::Any));
        let mut watching_dirs = Default::default();
        assert_eq!(handle_event(Ok(event.clone()), &watching_dirs), None);
        assert_eq!(
            handle_event(Ok(rescan_event()), &watching_dirs),
            Some(Event::Rescan)
        );
        watching_dirs.insert(
            "/foo/baz".into(),
            HashSet::from([DirEventFilter::MatchesExtension("auth".into())]),
        );
        assert_eq!(handle_event(Ok(event.clone()), &watching_dirs), None);
        assert_eq!(
            handle_event(Ok(rescan_event()), &watching_dirs),
            Some(Event::Rescan)
        );
        event = event.add_path("/foo/bar/alice.authh".into());
        assert_eq!(handle_event(Ok(event.clone()), &watching_dirs), None);
        event = event.add_path("/foo/bar/alice.auth".into());
        assert_eq!(handle_event(Ok(event.clone()), &watching_dirs), None);
        event = event.add_path("/foo/baz/bob.auth".into());
        assert_eq!(
            handle_event(Ok(event.clone()), &watching_dirs),
            Some(Event::FileChanged)
        );
        assert_ignored(&event, &watching_dirs);
        watching_dirs.insert(
            "/foo/bar".into(),
            HashSet::from([DirEventFilter::MatchesPath("/foo/bar/abc".into())]),
        );
        assert_eq!(
            handle_event(Ok(event.clone()), &watching_dirs),
            Some(Event::FileChanged)
        );
        assert_eq!(
            handle_event(Ok(rescan_event()), &watching_dirs),
            Some(Event::Rescan)
        );
        assert_ignored(&event, &watching_dirs);
        let event = notify::Event::new(notify::EventKind::Modify(ModifyKind::Any))
            .add_path("/a/b/c/d".into());
        let watching_dirs = [(
            "/a/b/c/".into(),
            HashSet::from([DirEventFilter::MatchesPath("/a/b/c/d".into())]),
        )]
        .into_iter()
        .collect();
        assert_eq!(
            handle_event(Ok(event), &watching_dirs),
            Some(Event::FileChanged)
        );
        assert_eq!(
            handle_event(Ok(rescan_event()), &watching_dirs),
            Some(Event::Rescan)
        );
        let err = notify::Error::path_not_found();
        assert_eq!(handle_event(Err(err), &watching_dirs), None);
        let mut err = notify::Error::path_not_found();
        err = err.add_path("/a/b/c/d".into());
        assert_eq!(
            handle_event(Err(err), &watching_dirs),
            Some(Event::FileChanged)
        );
    }

    #[test]
    fn watch_dirs() {
        tor_rtcompat::test_with_one_runtime!(|rt| async move {
            let temp_dir = test_temp_dir!();
            let (tx, mut rx) = channel();
            let mut builder = FileWatcher::builder(rt.clone());
            builder
                .watch_dir(temp_dir.as_path_untracked(), "foo")
                .unwrap();
            let watcher = builder.start_watching(tx).unwrap();
            assert_eq!(rx.try_recv(), Some(Event::Rescan));
            assert_eq!(rx.try_recv(), None);
            write_file(&temp_dir, "bar.foo", b"hello");
            assert_eq!(rx.next().await, Some(Event::FileChanged));
            drop(watcher);
            while let Some(ev) = rx.next().await {
                assert_eq!(ev.clone(), Event::FileChanged);
            }
        });
    }

    #[test]
    fn watch_file_path() {
        tor_rtcompat::test_with_one_runtime!(|rt| async move {
            let temp_dir = test_temp_dir!();
            let (tx, mut rx) = channel();
            let path = write_file(&temp_dir, "hello.txt", b"hello");
            let mut builder = FileWatcher::builder(rt.clone());
            builder.watch_path(&path).unwrap();
            let _watcher = builder.start_watching(tx).unwrap();
            assert_eq!(rx.try_recv(), Some(Event::Rescan));
            assert_eq!(rx.try_recv(), None);
            let _: PathBuf = write_file(&temp_dir, "hello.txt", b"good-bye");
            assert_file_changed(&mut rx).await;
            std::fs::remove_file(&path).unwrap();
            assert_file_changed(&mut rx).await;
            let tmp_hello = write_file(&temp_dir, "hello.tmp", b"new hello");
            std::fs::rename(&tmp_hello, &path).unwrap();
            assert_file_changed(&mut rx).await;
        });
    }

    #[test]
    fn watch_dir_path() {
        tor_rtcompat::test_with_one_runtime!(|rt| async move {
            let temp_dir1 = tempfile::TempDir::new().unwrap();
            let (tx, mut rx) = channel();
            let mut builder = FileWatcher::builder(rt.clone());
            builder.watch_path(temp_dir1.path()).unwrap();
            let _watcher = builder.start_watching(tx).unwrap();
            assert_eq!(rx.try_recv(), Some(Event::Rescan));
            assert_eq!(rx.try_recv(), None);
            std::fs::write(temp_dir1.path().join("hello.txt"), b"hello").unwrap();
            assert_eq!(rx.try_recv(), None);
            let temp_dir2 = tempfile::TempDir::new().unwrap();
            std::fs::rename(&temp_dir1, &temp_dir2).unwrap();
            assert_file_changed(&mut rx).await;
            std::fs::rename(&temp_dir2, &temp_dir1).unwrap();
            assert_file_changed(&mut rx).await;
        });
    }
}