Skip to main content

par_term/
shader_watcher.rs

1//! Shader hot reload watcher
2//!
3//! Watches custom shader files for changes and triggers automatic reloading.
4//! Uses debouncing to avoid multiple reloads during rapid saves from editors.
5
6use anyhow::{Context, Result};
7use notify::{Config, Event, PollWatcher, RecursiveMode, Watcher};
8use parking_lot::Mutex;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use std::sync::mpsc::{Receiver, channel};
13use std::time::{Duration, Instant};
14
/// Identifies which shader slot a watched file belongs to.
///
/// Used as the key for per-shader debounce state and carried in every
/// [`ShaderReloadEvent`] so the receiver knows which shader to rebuild.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ShaderType {
    /// The background (custom) shader.
    Background,
    /// The cursor effect shader.
    Cursor,
}
23
24/// Event indicating a shader file has changed and needs reloading
25#[derive(Debug, Clone)]
26pub struct ShaderReloadEvent {
27    /// Type of shader that changed
28    pub shader_type: ShaderType,
29    /// Path to the shader file
30    pub path: PathBuf,
31}
32
/// Manages file watching for shader hot reload
///
/// Create one with [`ShaderWatcher::new`] (or [`ShaderWatcherBuilder`]) and call
/// [`ShaderWatcher::try_recv`] periodically to pick up pending reload events.
pub struct ShaderWatcher {
    /// The file system watcher. Never read (hence the leading underscore); it is stored
    /// so that it lives exactly as long as this struct.
    _watcher: PollWatcher,
    /// Receiver for file change events sent by the watcher's event handler
    event_receiver: Receiver<ShaderReloadEvent>,
    /// Debounce delay in milliseconds, as passed to `new` (exposed for inspection/Debug)
    debounce_delay_ms: u64,
}
42
43impl std::fmt::Debug for ShaderWatcher {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        f.debug_struct("ShaderWatcher")
46            .field("debounce_delay_ms", &self.debounce_delay_ms)
47            .finish_non_exhaustive()
48    }
49}
50
51impl ShaderWatcher {
52    /// Create a new shader watcher
53    ///
54    /// # Arguments
55    /// * `background_shader_path` - Optional path to background shader file
56    /// * `cursor_shader_path` - Optional path to cursor shader file
57    /// * `debounce_delay_ms` - Debounce delay in milliseconds
58    pub fn new(
59        background_shader_path: Option<&Path>,
60        cursor_shader_path: Option<&Path>,
61        debounce_delay_ms: u64,
62    ) -> Result<Self> {
63        let (tx, rx) = channel();
64        let debounce_state: Arc<Mutex<HashMap<ShaderType, Instant>>> =
65            Arc::new(Mutex::new(HashMap::new()));
66
67        // Build mapping of filenames to shader types and track directories to watch
68        // We watch parent directories because many editors use atomic saves (write temp + rename)
69        // which breaks direct file watching
70        let mut filename_to_type: HashMap<std::ffi::OsString, (ShaderType, PathBuf)> =
71            HashMap::new();
72        let mut dirs_to_watch: HashMap<PathBuf, ()> = HashMap::new();
73
74        if let Some(path) = background_shader_path {
75            if !path.exists() {
76                anyhow::bail!("Background shader file not found: {}", path.display());
77            }
78            let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
79            if let Some(filename) = canonical.file_name() {
80                filename_to_type.insert(
81                    filename.to_os_string(),
82                    (ShaderType::Background, canonical.clone()),
83                );
84                if let Some(parent) = canonical.parent() {
85                    dirs_to_watch.insert(parent.to_path_buf(), ());
86                }
87            }
88            log::info!(
89                "Shader hot reload: watching background shader at {}",
90                canonical.display()
91            );
92        }
93        if let Some(path) = cursor_shader_path {
94            if !path.exists() {
95                anyhow::bail!("Cursor shader file not found: {}", path.display());
96            }
97            let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
98            if let Some(filename) = canonical.file_name() {
99                filename_to_type.insert(
100                    filename.to_os_string(),
101                    (ShaderType::Cursor, canonical.clone()),
102                );
103                if let Some(parent) = canonical.parent() {
104                    dirs_to_watch.insert(parent.to_path_buf(), ());
105                }
106            }
107            log::info!(
108                "Shader hot reload: watching cursor shader at {}",
109                canonical.display()
110            );
111        }
112
113        if filename_to_type.is_empty() {
114            anyhow::bail!("No shader paths provided for hot reload");
115        }
116
117        let filename_to_type = Arc::new(filename_to_type);
118        let debounce_delay = Duration::from_millis(debounce_delay_ms);
119        let debounce_state_clone = Arc::clone(&debounce_state);
120
121        // Create the watcher with event handler
122        let mut watcher = PollWatcher::new(
123            move |result: std::result::Result<Event, notify::Error>| {
124                if let Ok(event) = result {
125                    log::debug!(
126                        "File system event: {:?} for paths: {:?}",
127                        event.kind,
128                        event.paths
129                    );
130
131                    // Process modify, create, and rename events (for atomic saves)
132                    if !matches!(
133                        event.kind,
134                        notify::EventKind::Modify(_)
135                            | notify::EventKind::Create(_)
136                            | notify::EventKind::Remove(_)
137                    ) {
138                        log::trace!("Ignoring event kind: {:?}", event.kind);
139                        return;
140                    }
141
142                    let filename_to_type = Arc::clone(&filename_to_type);
143                    let debounce_state = Arc::clone(&debounce_state_clone);
144
145                    // Process each path in the event
146                    for path in event.paths {
147                        // Match by filename (handles atomic saves where path changes)
148                        let Some(filename) = path.file_name() else {
149                            log::trace!("Skipping path with no filename: {:?}", path);
150                            continue;
151                        };
152
153                        let Some((shader_type, canonical_path)) =
154                            filename_to_type.get(filename).cloned()
155                        else {
156                            log::trace!("Filename {:?} not in watch list", filename);
157                            continue;
158                        };
159
160                        // Check debounce using parking_lot Mutex (sync-safe)
161                        let should_send = {
162                            let now = Instant::now();
163                            let mut state = debounce_state.lock();
164                            if let Some(last_event) = state.get(&shader_type) {
165                                if now.duration_since(*last_event) < debounce_delay {
166                                    log::trace!("Debouncing shader reload for {:?}", shader_type);
167                                    false
168                                } else {
169                                    state.insert(shader_type, now);
170                                    true
171                                }
172                            } else {
173                                state.insert(shader_type, now);
174                                true
175                            }
176                        };
177
178                        if should_send {
179                            let reload_event = ShaderReloadEvent {
180                                shader_type,
181                                path: canonical_path,
182                            };
183                            log::info!(
184                                "Shader file changed: {:?} at {}",
185                                shader_type,
186                                reload_event.path.display()
187                            );
188                            if let Err(e) = tx.send(reload_event) {
189                                log::error!("Failed to send shader reload event: {}", e);
190                            }
191                        }
192                    }
193                }
194            },
195            Config::default().with_poll_interval(Duration::from_millis(100)),
196        )
197        .context("Failed to create file watcher")?;
198
199        // Watch parent directories (handles atomic saves from editors like vim, VSCode)
200        for dir in dirs_to_watch.keys() {
201            watcher
202                .watch(dir, RecursiveMode::NonRecursive)
203                .with_context(|| format!("Failed to watch shader directory: {}", dir.display()))?;
204            log::debug!("Watching directory for shader changes: {}", dir.display());
205        }
206
207        Ok(Self {
208            _watcher: watcher,
209            event_receiver: rx,
210            debounce_delay_ms,
211        })
212    }
213
214    /// Check for pending shader reload events (non-blocking)
215    ///
216    /// Returns the next reload event if one is available, or None if no events are pending.
217    pub fn try_recv(&self) -> Option<ShaderReloadEvent> {
218        self.event_receiver.try_recv().ok()
219    }
220
221    /// Get the debounce delay in milliseconds
222    #[allow(dead_code)]
223    pub fn debounce_delay_ms(&self) -> u64 {
224        self.debounce_delay_ms
225    }
226}
227
/// Builder for creating ShaderWatcher with configuration options
///
/// Defaults: no shader paths and a 100 ms debounce delay (see `ShaderWatcherBuilder::new`).
/// Call `build` to validate the paths and start watching.
#[derive(Debug, Clone)]
#[must_use = "a ShaderWatcherBuilder does nothing until `build` is called"]
pub struct ShaderWatcherBuilder {
    /// Optional background shader file to watch
    background_shader_path: Option<PathBuf>,
    /// Optional cursor shader file to watch
    cursor_shader_path: Option<PathBuf>,
    /// Debounce delay passed to `ShaderWatcher::new`, in milliseconds
    debounce_delay_ms: u64,
}
234
235impl ShaderWatcherBuilder {
236    /// Create a new builder with default settings
237    pub fn new() -> Self {
238        Self {
239            background_shader_path: None,
240            cursor_shader_path: None,
241            debounce_delay_ms: 100,
242        }
243    }
244
245    /// Set the background shader path
246    #[allow(dead_code)]
247    pub fn background_shader(mut self, path: impl Into<PathBuf>) -> Self {
248        self.background_shader_path = Some(path.into());
249        self
250    }
251
252    /// Set the cursor shader path
253    #[allow(dead_code)]
254    pub fn cursor_shader(mut self, path: impl Into<PathBuf>) -> Self {
255        self.cursor_shader_path = Some(path.into());
256        self
257    }
258
259    /// Set the debounce delay in milliseconds
260    #[allow(dead_code)]
261    pub fn debounce_delay_ms(mut self, delay_ms: u64) -> Self {
262        self.debounce_delay_ms = delay_ms;
263        self
264    }
265
266    /// Build the ShaderWatcher
267    #[allow(dead_code)]
268    pub fn build(self) -> Result<ShaderWatcher> {
269        ShaderWatcher::new(
270            self.background_shader_path.as_deref(),
271            self.cursor_shader_path.as_deref(),
272            self.debounce_delay_ms,
273        )
274    }
275}
276
277impl Default for ShaderWatcherBuilder {
278    fn default() -> Self {
279        Self::new()
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Minimal shader source used by the file-based tests.
    const SHADER_SOURCE: &str =
        "void mainImage(out vec4 fragColor, in vec2 fragCoord) { fragColor = vec4(1.0); }";

    /// Write `SHADER_SOURCE` to `name` inside `dir` and return the file's path.
    fn write_shader(dir: &TempDir, name: &str) -> PathBuf {
        let path = dir.path().join(name);
        fs::write(&path, SHADER_SOURCE).expect("Failed to write shader");
        path
    }

    #[test]
    fn test_shader_type_equality() {
        assert_eq!(ShaderType::Background, ShaderType::Background);
        assert_eq!(ShaderType::Cursor, ShaderType::Cursor);
        assert_ne!(ShaderType::Background, ShaderType::Cursor);
    }

    #[test]
    fn test_shader_watcher_builder_default() {
        let builder = ShaderWatcherBuilder::default();
        assert!(builder.background_shader_path.is_none());
        assert!(builder.cursor_shader_path.is_none());
        assert_eq!(builder.debounce_delay_ms, 100);
    }

    #[test]
    fn test_shader_watcher_builder_with_paths() {
        let builder = ShaderWatcherBuilder::new()
            .background_shader("/tmp/test.glsl")
            .cursor_shader("/tmp/cursor.glsl")
            .debounce_delay_ms(200);

        assert_eq!(
            builder.background_shader_path,
            Some(PathBuf::from("/tmp/test.glsl"))
        );
        assert_eq!(
            builder.cursor_shader_path,
            Some(PathBuf::from("/tmp/cursor.glsl"))
        );
        assert_eq!(builder.debounce_delay_ms, 200);
    }

    #[test]
    fn test_builder_build_without_paths_fails() {
        assert!(ShaderWatcherBuilder::new().build().is_err());
    }

    #[test]
    fn test_watcher_creation_with_valid_path() {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let shader_path = write_shader(&temp_dir, "test.glsl");

        let result = ShaderWatcher::new(Some(&shader_path), None, 100);
        assert!(result.is_ok());
    }

    #[test]
    fn test_watcher_creation_no_paths_fails() {
        let result = ShaderWatcher::new(None, None, 100);
        assert!(result.is_err());
    }

    #[test]
    fn test_watcher_creation_missing_background_fails() {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let missing = temp_dir.path().join("missing.glsl");

        let err = ShaderWatcher::new(Some(&missing), None, 100)
            .err()
            .expect("missing background shader should be rejected");
        assert!(err.to_string().contains("Background shader file not found"));
    }

    #[test]
    fn test_watcher_creation_missing_cursor_fails() {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let missing = temp_dir.path().join("missing_cursor.glsl");

        let err = ShaderWatcher::new(None, Some(&missing), 100)
            .err()
            .expect("missing cursor shader should be rejected");
        assert!(err.to_string().contains("Cursor shader file not found"));
    }

    #[test]
    fn test_debounce_delay_getter_and_debug() {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let shader_path = write_shader(&temp_dir, "test.glsl");

        let watcher =
            ShaderWatcher::new(Some(&shader_path), None, 250).expect("Failed to create watcher");
        assert_eq!(watcher.debounce_delay_ms(), 250);
        assert!(format!("{:?}", watcher).contains("debounce_delay_ms: 250"));
    }

    #[test]
    fn test_try_recv_empty() {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let shader_path = write_shader(&temp_dir, "test.glsl");

        let watcher =
            ShaderWatcher::new(Some(&shader_path), None, 100).expect("Failed to create watcher");

        // Should return None immediately with no events
        assert!(watcher.try_recv().is_none());
    }

    #[test]
    fn test_shader_reload_event_debug() {
        let event = ShaderReloadEvent {
            shader_type: ShaderType::Background,
            path: PathBuf::from("/tmp/test.glsl"),
        };
        let debug_str = format!("{:?}", event);
        assert!(debug_str.contains("Background"));
        assert!(debug_str.contains("test.glsl"));
    }

    #[test]
    fn test_file_change_triggers_event() {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let shader_path = write_shader(&temp_dir, "test.glsl");

        let watcher =
            ShaderWatcher::new(Some(&shader_path), None, 50).expect("Failed to create watcher");

        // Give the watcher time to set up
        std::thread::sleep(std::time::Duration::from_millis(100));

        // Modify the file
        fs::write(
            &shader_path,
            "void mainImage(out vec4 fragColor, in vec2 fragCoord) { fragColor = vec4(0.5); }",
        )
        .expect("Failed to write shader");

        // Poll until an event arrives or the deadline passes, rather than a single fixed sleep.
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(1);
        let event = loop {
            if let Some(evt) = watcher.try_recv() {
                break Some(evt);
            }
            if std::time::Instant::now() >= deadline {
                break None;
            }
            std::thread::sleep(std::time::Duration::from_millis(20));
        };

        // Coarse mtime granularity on some filesystems can hide the change, so the arrival of
        // an event is not asserted. When one does arrive, it must describe this shader.
        if let Some(evt) = event {
            assert_eq!(evt.shader_type, ShaderType::Background);
            assert_eq!(
                evt.path,
                shader_path.canonicalize().expect("Failed to canonicalize")
            );
        }
    }
}