1use anyhow::{Context, Result};
7use notify::{Config, Event, PollWatcher, RecursiveMode, Watcher};
8use parking_lot::Mutex;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use std::sync::mpsc::{Receiver, channel};
13use std::time::{Duration, Instant};
14
/// Identifies which configured shader a file-change notification refers to.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ShaderType {
    /// The shader whose path is supplied as `background_shader_path`.
    Background,
    /// The shader whose path is supplied as `cursor_shader_path`.
    Cursor,
}
23
24#[derive(Debug, Clone)]
26pub struct ShaderReloadEvent {
27 pub shader_type: ShaderType,
29 pub path: PathBuf,
31}
32
33pub struct ShaderWatcher {
35 _watcher: PollWatcher,
37 event_receiver: Receiver<ShaderReloadEvent>,
39 debounce_delay_ms: u64,
41}
42
43impl std::fmt::Debug for ShaderWatcher {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 f.debug_struct("ShaderWatcher")
46 .field("debounce_delay_ms", &self.debounce_delay_ms)
47 .finish_non_exhaustive()
48 }
49}
50
51impl ShaderWatcher {
52 pub fn new(
59 background_shader_path: Option<&Path>,
60 cursor_shader_path: Option<&Path>,
61 debounce_delay_ms: u64,
62 ) -> Result<Self> {
63 let (tx, rx) = channel();
64 let debounce_state: Arc<Mutex<HashMap<ShaderType, Instant>>> =
65 Arc::new(Mutex::new(HashMap::new()));
66
67 let mut filename_to_type: HashMap<std::ffi::OsString, (ShaderType, PathBuf)> =
71 HashMap::new();
72 let mut dirs_to_watch: HashMap<PathBuf, ()> = HashMap::new();
73
74 if let Some(path) = background_shader_path {
75 if !path.exists() {
76 anyhow::bail!("Background shader file not found: {}", path.display());
77 }
78 let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
79 if let Some(filename) = canonical.file_name() {
80 filename_to_type.insert(
81 filename.to_os_string(),
82 (ShaderType::Background, canonical.clone()),
83 );
84 if let Some(parent) = canonical.parent() {
85 dirs_to_watch.insert(parent.to_path_buf(), ());
86 }
87 }
88 log::info!(
89 "Shader hot reload: watching background shader at {}",
90 canonical.display()
91 );
92 }
93 if let Some(path) = cursor_shader_path {
94 if !path.exists() {
95 anyhow::bail!("Cursor shader file not found: {}", path.display());
96 }
97 let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
98 if let Some(filename) = canonical.file_name() {
99 filename_to_type.insert(
100 filename.to_os_string(),
101 (ShaderType::Cursor, canonical.clone()),
102 );
103 if let Some(parent) = canonical.parent() {
104 dirs_to_watch.insert(parent.to_path_buf(), ());
105 }
106 }
107 log::info!(
108 "Shader hot reload: watching cursor shader at {}",
109 canonical.display()
110 );
111 }
112
113 if filename_to_type.is_empty() {
114 anyhow::bail!("No shader paths provided for hot reload");
115 }
116
117 let filename_to_type = Arc::new(filename_to_type);
118 let debounce_delay = Duration::from_millis(debounce_delay_ms);
119 let debounce_state_clone = Arc::clone(&debounce_state);
120
121 let mut watcher = PollWatcher::new(
123 move |result: std::result::Result<Event, notify::Error>| {
124 if let Ok(event) = result {
125 log::debug!(
126 "File system event: {:?} for paths: {:?}",
127 event.kind,
128 event.paths
129 );
130
131 if !matches!(
133 event.kind,
134 notify::EventKind::Modify(_)
135 | notify::EventKind::Create(_)
136 | notify::EventKind::Remove(_)
137 ) {
138 log::trace!("Ignoring event kind: {:?}", event.kind);
139 return;
140 }
141
142 let filename_to_type = Arc::clone(&filename_to_type);
143 let debounce_state = Arc::clone(&debounce_state_clone);
144
145 for path in event.paths {
147 let Some(filename) = path.file_name() else {
149 log::trace!("Skipping path with no filename: {:?}", path);
150 continue;
151 };
152
153 let Some((shader_type, canonical_path)) =
154 filename_to_type.get(filename).cloned()
155 else {
156 log::trace!("Filename {:?} not in watch list", filename);
157 continue;
158 };
159
160 let should_send = {
162 let now = Instant::now();
163 let mut state = debounce_state.lock();
164 if let Some(last_event) = state.get(&shader_type) {
165 if now.duration_since(*last_event) < debounce_delay {
166 log::trace!("Debouncing shader reload for {:?}", shader_type);
167 false
168 } else {
169 state.insert(shader_type, now);
170 true
171 }
172 } else {
173 state.insert(shader_type, now);
174 true
175 }
176 };
177
178 if should_send {
179 let reload_event = ShaderReloadEvent {
180 shader_type,
181 path: canonical_path,
182 };
183 log::info!(
184 "Shader file changed: {:?} at {}",
185 shader_type,
186 reload_event.path.display()
187 );
188 if let Err(e) = tx.send(reload_event) {
189 log::error!("Failed to send shader reload event: {}", e);
190 }
191 }
192 }
193 }
194 },
195 Config::default().with_poll_interval(Duration::from_millis(100)),
196 )
197 .context("Failed to create file watcher")?;
198
199 for dir in dirs_to_watch.keys() {
201 watcher
202 .watch(dir, RecursiveMode::NonRecursive)
203 .with_context(|| format!("Failed to watch shader directory: {}", dir.display()))?;
204 log::debug!("Watching directory for shader changes: {}", dir.display());
205 }
206
207 Ok(Self {
208 _watcher: watcher,
209 event_receiver: rx,
210 debounce_delay_ms,
211 })
212 }
213
214 pub fn try_recv(&self) -> Option<ShaderReloadEvent> {
218 self.event_receiver.try_recv().ok()
219 }
220
221 #[allow(dead_code)]
223 pub fn debounce_delay_ms(&self) -> u64 {
224 self.debounce_delay_ms
225 }
226}
227
/// Step-by-step configuration for a [`ShaderWatcher`].
///
/// A builder from `ShaderWatcherBuilder::new` has no shader paths and a
/// 100 ms debounce delay.
#[derive(Debug, Clone)]
pub struct ShaderWatcherBuilder {
    /// Background shader file to watch, if any.
    background_shader_path: Option<PathBuf>,
    /// Cursor shader file to watch, if any.
    cursor_shader_path: Option<PathBuf>,
    /// Debounce window handed to [`ShaderWatcher::new`], in milliseconds.
    debounce_delay_ms: u64,
}
234
235impl ShaderWatcherBuilder {
236 pub fn new() -> Self {
238 Self {
239 background_shader_path: None,
240 cursor_shader_path: None,
241 debounce_delay_ms: 100,
242 }
243 }
244
245 #[allow(dead_code)]
247 pub fn background_shader(mut self, path: impl Into<PathBuf>) -> Self {
248 self.background_shader_path = Some(path.into());
249 self
250 }
251
252 #[allow(dead_code)]
254 pub fn cursor_shader(mut self, path: impl Into<PathBuf>) -> Self {
255 self.cursor_shader_path = Some(path.into());
256 self
257 }
258
259 #[allow(dead_code)]
261 pub fn debounce_delay_ms(mut self, delay_ms: u64) -> Self {
262 self.debounce_delay_ms = delay_ms;
263 self
264 }
265
266 #[allow(dead_code)]
268 pub fn build(self) -> Result<ShaderWatcher> {
269 ShaderWatcher::new(
270 self.background_shader_path.as_deref(),
271 self.cursor_shader_path.as_deref(),
272 self.debounce_delay_ms,
273 )
274 }
275}
276
277impl Default for ShaderWatcherBuilder {
278 fn default() -> Self {
279 Self::new()
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Shader body written when a test shader file is first created.
    const ORIGINAL_SOURCE: &str =
        "void mainImage(out vec4 fragColor, in vec2 fragCoord) { fragColor = vec4(1.0); }";
    /// Shader body written to simulate an edit.
    const EDITED_SOURCE: &str =
        "void mainImage(out vec4 fragColor, in vec2 fragCoord) { fragColor = vec4(0.5); }";

    /// Writes `test.glsl` containing `ORIGINAL_SOURCE` into a fresh temp dir.
    /// The returned `TempDir` must be kept alive while the file is in use.
    fn shader_in_temp_dir() -> (TempDir, PathBuf) {
        let dir = TempDir::new().expect("Failed to create temp dir");
        let shader = dir.path().join("test.glsl");
        fs::write(&shader, ORIGINAL_SOURCE).expect("Failed to write shader");
        (dir, shader)
    }

    #[test]
    fn test_shader_type_equality() {
        let (bg, cursor) = (ShaderType::Background, ShaderType::Cursor);
        assert_eq!(bg, ShaderType::Background);
        assert_eq!(cursor, ShaderType::Cursor);
        assert_ne!(bg, cursor);
    }

    #[test]
    fn test_shader_watcher_builder_default() {
        let ShaderWatcherBuilder {
            background_shader_path,
            cursor_shader_path,
            debounce_delay_ms,
        } = ShaderWatcherBuilder::default();
        assert!(background_shader_path.is_none());
        assert!(cursor_shader_path.is_none());
        assert_eq!(debounce_delay_ms, 100);
    }

    #[test]
    fn test_shader_watcher_builder_with_paths() {
        let configured = ShaderWatcherBuilder::new()
            .debounce_delay_ms(200)
            .cursor_shader("/tmp/cursor.glsl")
            .background_shader("/tmp/test.glsl");

        assert_eq!(
            configured.background_shader_path.as_deref(),
            Some(Path::new("/tmp/test.glsl"))
        );
        assert_eq!(
            configured.cursor_shader_path.as_deref(),
            Some(Path::new("/tmp/cursor.glsl"))
        );
        assert_eq!(configured.debounce_delay_ms, 200);
    }

    #[test]
    fn test_watcher_creation_with_valid_path() {
        let (_dir, shader) = shader_in_temp_dir();
        assert!(ShaderWatcher::new(Some(&shader), None, 100).is_ok());
    }

    #[test]
    fn test_watcher_creation_no_paths_fails() {
        assert!(ShaderWatcher::new(None, None, 100).is_err());
    }

    #[test]
    fn test_try_recv_empty() {
        let (_dir, shader) = shader_in_temp_dir();
        let watcher =
            ShaderWatcher::new(Some(&shader), None, 100).expect("Failed to create watcher");
        assert!(watcher.try_recv().is_none());
    }

    #[test]
    fn test_shader_reload_event_debug() {
        let rendered = format!(
            "{:?}",
            ShaderReloadEvent {
                shader_type: ShaderType::Background,
                path: PathBuf::from("/tmp/test.glsl"),
            }
        );
        assert!(rendered.contains("Background"));
        assert!(rendered.contains("test.glsl"));
    }

    #[test]
    fn test_file_change_triggers_event() {
        let (_dir, shader) = shader_in_temp_dir();
        let watcher =
            ShaderWatcher::new(Some(&shader), None, 50).expect("Failed to create watcher");

        // Give the 100 ms poller a chance to run before the edit...
        std::thread::sleep(Duration::from_millis(100));
        fs::write(&shader, EDITED_SOURCE).expect("Failed to write shader");
        // ...and after it, so the change can be noticed.
        std::thread::sleep(Duration::from_millis(200));

        // Deliberately lenient: delivery within this window is not asserted,
        // only the type of an event if one arrived.
        if let Some(evt) = watcher.try_recv() {
            assert_eq!(evt.shader_type, ShaderType::Background);
        }
    }
}