use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use notify::{EventKind, RecommendedWatcher, RecursiveMode, Watcher};
/// Watches shader source files and collects the paths the backend reports as changed.
///
/// Register files with [`ShaderWatcher::watch`], then call
/// [`ShaderWatcher::drain_changed`] to take the set of paths reported as modified or
/// created since the previous drain.
pub struct ShaderWatcher {
    // notify backend handle. `watch`/`unwatch` register paths on it, and the callback
    // installed in `try_new` fills `changed`. Despite the leading underscore, the field
    // is actively used.
    _watcher: RecommendedWatcher,
    // Canonicalised forms of the paths passed to `watch`; backs `watched_count`/`is_empty`.
    watched_paths: HashSet<PathBuf>,
    // Paths reported as modified/created since the last `drain_changed`. Shared with
    // the notify callback, hence `Arc<Mutex<..>>`.
    changed: Arc<Mutex<HashSet<PathBuf>>>,
}
impl ShaderWatcher {
    /// Creates a watcher with no registered paths.
    ///
    /// # Panics
    ///
    /// Panics if the notify backend cannot be initialised; use
    /// [`ShaderWatcher::try_new`] to handle that case without panicking.
    pub fn new() -> Self {
        Self::try_new().unwrap_or_else(|e| {
            panic!("ShaderWatcher: notify watcher could not be initialised: {e}")
        })
    }

    /// Creates a watcher with no registered paths, reporting backend failures as errors.
    ///
    /// # Errors
    ///
    /// Returns [`ShaderWatchError::Notify`] if the recommended notify backend cannot be
    /// created.
    pub fn try_new() -> Result<Self, ShaderWatchError> {
        let changed: Arc<Mutex<HashSet<PathBuf>>> = Arc::new(Mutex::new(HashSet::new()));
        let changed_cb = Arc::clone(&changed);
        let watcher =
            notify::recommended_watcher(move |event: Result<notify::Event, notify::Error>| {
                // Backend errors are dropped: this callback has no way to surface them.
                if let Ok(ev) = event {
                    // Only modifications and creations mark a shader as changed; other
                    // event kinds (e.g. access or removal) are ignored.
                    if matches!(ev.kind, EventKind::Modify(_) | EventKind::Create(_)) {
                        // A poisoned lock still guards a valid set of paths, so recover it
                        // instead of silently discarding this and every later event.
                        changed_cb
                            .lock()
                            .unwrap_or_else(std::sync::PoisonError::into_inner)
                            .extend(ev.paths);
                    }
                }
            })
            .map_err(|e| ShaderWatchError::Notify(e.to_string()))?;
        Ok(ShaderWatcher {
            _watcher: watcher,
            watched_paths: HashSet::new(),
            changed,
        })
    }

    /// Starts watching `path` (non-recursively) and records its canonical form.
    ///
    /// Because the canonical form is what gets recorded, registering the same file
    /// through two different spellings counts once in [`ShaderWatcher::watched_count`].
    ///
    /// # Errors
    ///
    /// - [`ShaderWatchError::Notify`] if the backend refuses the watch.
    /// - [`ShaderWatchError::Io`] if `path` cannot be canonicalised after the watch was
    ///   added. In that case the backend watch is rolled back, so no untracked watch is
    ///   left behind.
    pub fn watch(&mut self, path: &Path) -> Result<(), ShaderWatchError> {
        self._watcher
            .watch(path, RecursiveMode::NonRecursive)
            .map_err(|e| ShaderWatchError::Notify(e.to_string()))?;
        match path.canonicalize() {
            Ok(canonical) => {
                self.watched_paths.insert(canonical);
                Ok(())
            }
            Err(e) => {
                // Best-effort rollback: without it the backend would keep a watch that
                // `watched_paths` (and therefore `unwatch`/`watched_count`) never sees.
                let _ = self._watcher.unwatch(path);
                Err(ShaderWatchError::Io(e.to_string()))
            }
        }
    }

    /// Stops watching `path` and forgets it.
    ///
    /// Returns `true` if a matching entry was registered via [`ShaderWatcher::watch`].
    /// Errors from the backend while un-registering are ignored (best-effort).
    pub fn unwatch(&mut self, path: &Path) -> bool {
        let _ = self._watcher.unwatch(path);
        // Registered entries are stored canonicalised. If canonicalisation fails (e.g. the
        // file was deleted after being watched), fall back to the path as given so that a
        // caller holding the canonical path can still remove the stale entry.
        let key = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
        self.watched_paths.remove(&key)
    }

    /// Takes every path reported as modified or created since the previous call.
    ///
    /// The paths are returned in no particular order and exactly as the backend reported
    /// them; they are not canonicalised by this type. The internal set is left empty.
    pub fn drain_changed(&self) -> Vec<PathBuf> {
        // Recover from poisoning rather than returning nothing forever: the set holds
        // plain paths, so it is still consistent.
        self.changed
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .drain()
            .collect()
    }

    /// Number of distinct (canonicalised) paths currently registered.
    #[must_use]
    pub fn watched_count(&self) -> usize {
        self.watched_paths.len()
    }

    /// Returns `true` when no paths are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.watched_paths.is_empty()
    }
}
impl Default for ShaderWatcher {
fn default() -> Self {
Self::new()
}
}
/// Errors produced while creating a [`ShaderWatcher`] or registering paths on it.
///
/// Each payload is the rendered (`to_string`) message of the underlying error. Because
/// the payloads are plain strings, the error is cheap to clone and can be compared
/// for equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShaderWatchError {
    /// The notify backend failed, either while initialising or while adding a watch.
    Notify(String),
    /// A filesystem operation failed, such as canonicalising a watched path.
    Io(String),
}
impl std::fmt::Display for ShaderWatchError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ShaderWatchError::Notify(s) => write!(f, "notify watcher error: {s}"),
ShaderWatchError::Io(s) => write!(f, "I/O error in shader watcher: {s}"),
}
}
}
impl std::error::Error for ShaderWatchError {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write as _;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{Duration, Instant};

    /// Temporary `.wgsl` file that is deleted on drop, so a failing assertion or
    /// `expect` does not leave stray files in the temp directory.
    struct TempShader(PathBuf);

    impl TempShader {
        /// Creates a uniquely named shader file in the temp dir containing `contents`.
        fn new(contents: &str) -> Self {
            // pid + wall-clock nanos + a per-process counter keeps names unique across
            // tests running concurrently in one process and across test processes.
            static COUNTER: AtomicUsize = AtomicUsize::new(0);
            let pid = std::process::id();
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.subsec_nanos())
                .unwrap_or(0);
            let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
            let path = std::env::temp_dir()
                .join(format!("oxiui_hot_reload_test_{pid}_{nanos}_{seq}.wgsl"));
            std::fs::write(&path, contents).expect("create test file");
            TempShader(path)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempShader {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn watcher_new_does_not_panic() {
        let _w = ShaderWatcher::new();
    }

    #[test]
    fn watcher_default_is_empty() {
        let w = ShaderWatcher::default();
        assert!(w.is_empty());
        assert_eq!(w.watched_count(), 0);
    }

    #[test]
    fn drain_changed_empty_initially() {
        let w = ShaderWatcher::new();
        let changed = w.drain_changed();
        assert!(changed.is_empty(), "no paths watched yet");
    }

    #[test]
    fn watch_nonexistent_returns_error() {
        let mut w = ShaderWatcher::new();
        let result = w.watch(Path::new("/nonexistent/path/shader.wgsl"));
        assert!(
            result.is_err(),
            "watching a non-existent path must return Err"
        );
    }

    #[test]
    fn watch_existing_file_detected() {
        let mut w = ShaderWatcher::new();
        let shader = TempShader::new("@compute @workgroup_size(1) fn noop() {}\n");
        w.watch(shader.path()).expect("watch should succeed");
        assert_eq!(w.watched_count(), 1);
        {
            let mut f = std::fs::OpenOptions::new()
                .write(true)
                .truncate(true)
                .open(shader.path())
                .expect("open for write");
            writeln!(f, "@compute @workgroup_size(64) fn noop() {{}}").expect("write");
        }
        // Event delivery is asynchronous and its latency is backend-dependent, so poll
        // with a deadline instead of a fixed sleep. Detection itself is not asserted;
        // anything that *is* reported must refer to the only watched file.
        let deadline = Instant::now() + Duration::from_secs(2);
        let mut changed = Vec::new();
        while changed.is_empty() && Instant::now() < deadline {
            std::thread::sleep(Duration::from_millis(10));
            changed = w.drain_changed();
        }
        let expected_name = shader.path().file_name();
        for path in &changed {
            assert_eq!(
                path.file_name(),
                expected_name,
                "unexpected changed path {path:?}"
            );
        }
    }

    #[test]
    fn unwatch_registered_returns_true() {
        let mut w = ShaderWatcher::new();
        let shader = TempShader::new("@compute @workgroup_size(1) fn noop() {}\n");
        w.watch(shader.path()).expect("watch should succeed");
        assert!(w.unwatch(shader.path()), "registered path must return true");
        assert!(w.is_empty());
    }

    #[test]
    fn unwatch_not_registered_returns_false() {
        let mut w = ShaderWatcher::new();
        let result = w.unwatch(Path::new("/some/path.wgsl"));
        assert!(!result, "unregistered path must return false");
    }

    #[test]
    fn shader_watch_error_display() {
        let e = ShaderWatchError::Notify("backend init failed".into());
        assert!(e.to_string().contains("notify"), "{e}");
        let e2 = ShaderWatchError::Io("permission denied".into());
        assert!(e2.to_string().contains("I/O"), "{e2}");
    }

    #[test]
    fn shader_watch_error_is_std_error() {
        fn assert_error<E: std::error::Error>(_: &E) {}
        assert_error(&ShaderWatchError::Notify("x".into()));
        assert_error(&ShaderWatchError::Io("x".into()));
    }
}