1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
use anyhow::{Result, bail};
use async_trait::async_trait;
use notify::{
Config, Event, EventKind, PollWatcher, RecursiveMode, Watcher,
event::{CreateKind, ModifyKind},
};
use once_cell::sync::Lazy;
use std::sync::Mutex;
use std::{
path::{Path, PathBuf},
sync::Arc,
};
use tokio::{
sync::mpsc::UnboundedReceiver,
task::JoinHandle,
time::{Duration, sleep},
};
use tracing::{error, warn};
/// Process-wide registry holding the `JoinHandle` of every watcher task spawned so far.
///
/// `DirectoryWatcher::new` adds the handles of the tasks it spawns and
/// `DirectoryWatcher::shutdown` drains the list, aborting each handle. The
/// `Mutex` guards the list because it is a shared `static`.
static GLOBAL_WATCHER_HANDLES: Lazy<Mutex<Vec<JoinHandle<()>>>> = Lazy::new(Default::default);
/// Contract for objects that are kept in sync with files on disk.
///
/// Implementations are shared with spawned background tasks, hence the
/// `Send + Sync + 'static` bounds.
#[async_trait]
pub trait WatchedType: Send + Sync + 'static {
    /// Reports whether `path` concerns this object. The directory watcher
    /// consults this in addition to its own extension filter.
    fn is_relevant(&self, path: &Path) -> bool;

    /// Handles a file at `path` that has been created or modified.
    async fn on_create_or_modify(&self, path: &Path) -> Result<()>;

    /// Handles the removal of the file at `path`.
    async fn on_remove(&self, path: &Path) -> Result<()>;

    /// Loads `path`; this is the entry point used for the startup scan.
    ///
    /// Unless overridden, it simply forwards to
    /// [`WatchedType::on_create_or_modify`].
    async fn reload(&self, path: &Path) -> Result<()> {
        self.on_create_or_modify(path).await
    }
}
/// Newtype that owns a boxed [`WatchedType`] trait object.
///
/// NOTE(review): nothing in the visible part of this file constructs or
/// consumes this type; confirm it is still used by callers elsewhere.
pub struct Watched(pub Box<dyn WatchedType>);
/// Token returned by `DirectoryWatcher::new` once the watcher tasks are running.
///
/// The struct itself carries no state: the `JoinHandle`s of the spawned tasks
/// are stored in the global `GLOBAL_WATCHER_HANDLES` list. As a consequence,
/// calling `shutdown()` aborts every handle in that global list, not only the
/// tasks that were started for one particular directory.
#[derive(Clone)]
pub struct DirectoryWatcher;
impl DirectoryWatcher {
/// Start watching `dir` for any files whose extension matches `exts` _or_ for
/// any `WatcherType::is_relevant(path)`‐true paths. Returns a `DirectoryWatcher`
/// containing every spawned task. If `_enable_retry` is true, failed reloads that
/// return Err will be retried a few times before giving up.
pub async fn new(
dir: PathBuf,
watcher_impl: Arc<dyn WatchedType>,
exts: &[&str],
enable_retry: bool,
) -> Result<DirectoryWatcher> {
// If the directory doesn’t exist, bail out immediately.
if !dir.exists() {
let msg = format!("Directory {} does not exist", dir.to_string_lossy());
warn!(%msg);
bail!(msg);
}
// 1) On startup: scan the existing entries in `dir` and “reload” each one.
// In other words, if file matches, call `try_reload(...)`.
for entry in std::fs::read_dir(&dir)? {
let entry = entry?;
let path = entry.path();
if watcher_impl.is_relevant(&path) || is_valid_extension(&path, exts) {
try_reload(&watcher_impl, &path, enable_retry).await;
}
}
// 2) Set up channels + notify‐watcher so we can receive Future<NotifyResult<Event>>s.
let (tx, mut rx): (_, UnboundedReceiver<notify::Result<Event>>) =
tokio::sync::mpsc::unbounded_channel();
// 3) Spawn a background task that runs the `notify` poll‐watcher.
// When events arrive, they get sent down `tx`. We collect that handle below.
let dir_clone = dir.clone();
let handle_watcher = tokio::spawn(async move {
let mut watcher = PollWatcher::new(
move |res| {
// we ignore send errors, since if `rx` has dropped, no one is listening.
let _ = tx.send(res);
},
Config::default().with_poll_interval(Duration::from_secs(2)),
)
.expect("failed to create PollWatcher");
watcher
.watch(&dir_clone, RecursiveMode::Recursive)
.expect("failed to watch directory");
// Keep the `watcher` alive inside this blocking context. This task
// will never exit on its own; it’s just the glue that pushes FS events
// into the `tx` channel.
futures::future::pending::<()>().await;
});
// 4) In parallel, spawn a second background task that sits on `rx.recv()`,
// examines each Event, and then, if it matches, calls either
// `on_create_or_modify` or `on_remove`.
let watcher_clone = watcher_impl.clone();
let exts_clone: Vec<String> = exts.iter().map(|s| s.to_string()).collect();
let handle_dispatch = tokio::spawn(async move {
while let Some(res) = rx.recv().await {
match res {
Ok(Event {
kind:
EventKind::Create(CreateKind::Any) | EventKind::Modify(ModifyKind::Data(_)),
paths,
..
}) => {
for path in paths {
if watcher_clone.is_relevant(&path)
|| path
.extension()
.and_then(|e| e.to_str())
.map_or(false, |e| exts_clone.contains(&e.to_string()))
{
let path_clone = path.clone();
let watcher_inner = watcher_clone.clone();
// Spawn a short‐lived task to handle this single path.
tokio::spawn(async move {
if let Err(e) =
watcher_inner.on_create_or_modify(&path_clone).await
{
warn!(?path_clone, ?e, "Failed to handle create/modify");
}
});
}
}
}
Ok(Event {
kind: EventKind::Remove(_),
paths,
..
}) => {
for path in paths {
if watcher_clone.is_relevant(&path)
|| path
.extension()
.and_then(|e| e.to_str())
.map_or(false, |e| exts_clone.contains(&e.to_string()))
{
let path_clone = path.clone();
let watcher_inner = watcher_clone.clone();
//tokio::spawn(async move {
if let Err(e) = watcher_inner.on_remove(&path_clone).await {
warn!(?path_clone, ?e, "Failed to handle removal");
}
// });
}
}
}
Err(e) => {
warn!(?e, "Watcher error");
}
_ => {}
}
}
});
let mut guard = GLOBAL_WATCHER_HANDLES.lock().unwrap();
guard.append(&mut vec![handle_dispatch, handle_watcher]);
// 5) Return a DirectoryWatcher holding both handles.
Ok(DirectoryWatcher {})
}
/// Abort all of the spawned watcher‐tasks. After this returns, no more events
/// will be dispatched.
pub fn shutdown(self) {
let mut guard = GLOBAL_WATCHER_HANDLES.lock().unwrap();
for handle in guard.drain(..) {
handle.abort();
}
}
}
/// Tells whether the extension of `path` is one of `extensions`.
///
/// The comparison is exact (case-sensitive) and the entries of `extensions`
/// are written without a leading dot. Paths without an extension, or whose
/// extension is not valid UTF-8, never match.
fn is_valid_extension(path: &Path, extensions: &[&str]) -> bool {
    match path.extension().and_then(|raw| raw.to_str()) {
        Some(found) => extensions.iter().any(|candidate| *candidate == found),
        None => false,
    }
}
/// Calls `watched.reload(path)` and logs the outcome instead of returning it.
///
/// With `retry` set, a failing reload is attempted up to `MAX_RETRIES` times
/// in total, pausing 100 ms between attempts; each intermediate failure is
/// logged as a warning. Without `retry`, or once the attempts are used up, the
/// failure is logged as an error and the function returns.
async fn try_reload(watched: &Arc<dyn WatchedType>, path: &Path, retry: bool) {
    const MAX_RETRIES: usize = 10;

    // Attempts are counted from 1 so the counter can be logged as-is.
    for attempt in 1..=MAX_RETRIES {
        let e = match watched.reload(path).await {
            Ok(_) => return,
            Err(failure) => failure,
        };

        let attempts_exhausted = attempt == MAX_RETRIES;
        if attempts_exhausted || !retry {
            error!("Failed to reload {:?}: {e:?}", path);
            return;
        }

        warn!("Retrying reload {:?} (attempt {}): {e:?}", path, attempt);
        sleep(Duration::from_millis(100)).await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that counts how often each callback is invoked and treats
    /// paths with an extension in `allowed_exts` as relevant.
    struct DummyWatcher {
        created: Arc<AtomicUsize>,
        removed: Arc<AtomicUsize>,
        allowed_exts: HashSet<String>,
    }

    #[async_trait]
    impl WatchedType for DummyWatcher {
        fn is_relevant(&self, path: &Path) -> bool {
            path.extension()
                .and_then(|e| e.to_str())
                .map(|e| self.allowed_exts.contains(e))
                .unwrap_or(false)
        }

        async fn on_create_or_modify(&self, _path: &Path) -> Result<()> {
            self.created.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn on_remove(&self, _path: &Path) -> Result<()> {
            self.removed.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Starts a watcher on a directory that already contains a matching file,
    /// checks that the startup scan reported it, then removes the file and
    /// shuts the watcher down without panicking.
    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_dummy_watch() {
        // Keep a handle on the counter: the watcher itself is moved into an `Arc<dyn _>`.
        let created = Arc::new(AtomicUsize::new(0));
        let watcher = DummyWatcher {
            created: Arc::clone(&created),
            removed: Arc::new(AtomicUsize::new(0)),
            allowed_exts: ["txt"].iter().map(|s| s.to_string()).collect(),
        };

        let dir = PathBuf::from("./tests/tmp_watcher");
        std::fs::create_dir_all(&dir).expect("failed to create the test directory");
        let path = dir.join("test_file.txt");
        std::fs::write(&path, "hello").expect("failed to write the test file");

        let watcher_arc: Arc<dyn WatchedType> = Arc::new(watcher);
        let dir_watcher = DirectoryWatcher::new(dir.clone(), watcher_arc.clone(), &["txt"], false)
            .await
            .unwrap();

        // The startup scan is awaited inside `new`, so the existing file must
        // already have been reported by the time `new` returns.
        assert!(created.load(Ordering::SeqCst) >= 1);

        // Give time for watcher to process
        sleep(Duration::from_millis(300)).await;
        let _ = std::fs::remove_file(path);
        sleep(Duration::from_millis(300)).await;

        // Whether the removal is observed depends on the watcher's poll timing,
        // so nothing is asserted about it; this part only checks for panics.
        dir_watcher.shutdown();
    }
}