Skip to main content

pitchfork_cli/
watch_files.rs

1use crate::Result;
2use crate::pitchfork_toml::WatchMode;
3use glob::glob;
4use itertools::Itertools;
5use miette::IntoDiagnostic;
6use notify::{Config, EventKind, PollWatcher, RecommendedWatcher, RecursiveMode};
7use notify_debouncer_full::{DebounceEventResult, Debouncer, FileIdMap, new_debouncer_opt};
8use std::collections::HashSet;
9use std::path::{Path, PathBuf};
10use std::time::Duration;
11
12pub struct WatchFiles {
13    pub rx: tokio::sync::mpsc::Receiver<Vec<PathBuf>>,
14    backend: WatchFilesBackend,
15}
16
17enum WatchFilesBackend {
18    Native(Debouncer<RecommendedWatcher, FileIdMap>),
19    Poll(Debouncer<PollWatcher, FileIdMap>),
20}
21
22impl WatchFiles {
23    pub fn new(duration: Duration, mode: WatchMode, poll_interval: Duration) -> Result<Self> {
24        let h = tokio::runtime::Handle::current();
25        let (tx, rx) = tokio::sync::mpsc::channel(1);
26        let make_callback = |tx: tokio::sync::mpsc::Sender<Vec<PathBuf>>,
27                             h: tokio::runtime::Handle| {
28            move |res: DebounceEventResult| {
29                let tx = tx.clone();
30                h.spawn(async move {
31                    if let Ok(ev) = res {
32                        let paths = ev
33                            .into_iter()
34                            .filter(|e| {
35                                matches!(
36                                    e.kind,
37                                    EventKind::Modify(_)
38                                        | EventKind::Create(_)
39                                        | EventKind::Remove(_)
40                                )
41                            })
42                            .flat_map(|e| e.paths.clone())
43                            .unique()
44                            .collect_vec();
45                        if !paths.is_empty() {
46                            // Ignore send errors - receiver may be dropped during shutdown
47                            let _ = tx.send(paths).await;
48                        }
49                    }
50                });
51            }
52        };
53
54        let backend = match mode {
55            WatchMode::Native => WatchFilesBackend::Native(
56                new_debouncer_opt(
57                    duration,
58                    None,
59                    make_callback(tx.clone(), h.clone()),
60                    FileIdMap::new(),
61                    Config::default(),
62                )
63                .into_diagnostic()?,
64            ),
65            WatchMode::Poll => WatchFilesBackend::Poll(
66                new_debouncer_opt(
67                    duration,
68                    None,
69                    make_callback(tx.clone(), h.clone()),
70                    FileIdMap::new(),
71                    Config::default().with_poll_interval(poll_interval),
72                )
73                .into_diagnostic()?,
74            ),
75            WatchMode::Auto => {
76                return Err(miette::miette!(
77                    "WatchMode::Auto must not be passed directly to WatchFiles::new; \
78                     the caller must resolve auto to native or poll"
79                ));
80            }
81        };
82
83        Ok(Self { backend, rx })
84    }
85
86    pub fn watch(&mut self, path: &Path, recursive_mode: RecursiveMode) -> Result<()> {
87        match &mut self.backend {
88            WatchFilesBackend::Native(debouncer) => {
89                debouncer.watch(path, recursive_mode).into_diagnostic()
90            }
91            WatchFilesBackend::Poll(debouncer) => {
92                debouncer.watch(path, recursive_mode).into_diagnostic()
93            }
94        }
95    }
96
97    pub fn unwatch(&mut self, path: &Path) -> Result<()> {
98        match &mut self.backend {
99            WatchFilesBackend::Native(debouncer) => debouncer.unwatch(path).into_diagnostic(),
100            WatchFilesBackend::Poll(debouncer) => debouncer.unwatch(path).into_diagnostic(),
101        }
102    }
103}
104
105/// Normalize a path by attempting to canonicalize it. If that fails, it attempts
106/// to resolve it as an absolute path. This helps ensure that different relative
107/// paths to the same directory are deduplicated.
108fn normalize_watch_path(path: &Path) -> PathBuf {
109    path.canonicalize().unwrap_or_else(|_| {
110        if path.is_absolute() {
111            path.to_path_buf()
112        } else {
113            crate::env::CWD.join(path)
114        }
115    })
116}
117
118/// Expand glob patterns to actual file paths.
119/// Patterns are resolved relative to base_dir.
120/// Returns unique directories that need to be watched.
121pub fn expand_watch_patterns(patterns: &[String], base_dir: &Path) -> Result<HashSet<PathBuf>> {
122    let mut dirs_to_watch = HashSet::new();
123
124    for pattern in patterns {
125        // Strip leading "./" from patterns to handle relative path prefixes
126        let normalized_pattern = pattern.strip_prefix("./").unwrap_or(pattern);
127
128        // Make the pattern absolute by joining with base_dir
129        let full_pattern = if Path::new(normalized_pattern).is_absolute() {
130            normalize_path_for_glob(normalized_pattern)
131        } else {
132            normalize_path_for_glob(&base_dir.join(normalized_pattern).to_string_lossy())
133        };
134
135        // Expand the glob pattern
136        match glob(&full_pattern) {
137            Ok(paths) => {
138                for entry in paths.flatten() {
139                    // Watch the parent directory of each matched file
140                    // This allows us to detect new files that match the pattern
141                    if let Some(parent) = entry.parent() {
142                        dirs_to_watch.insert(normalize_watch_path(parent));
143                    }
144                }
145            }
146            Err(e) => {
147                log::warn!("Invalid glob pattern '{pattern}': {e}");
148            }
149        }
150
151        // For patterns with wildcards, watch the base directory (before the wildcard)
152        // For non-wildcard patterns, watch the parent directory of the specific file
153        // This ensures we catch new files even if they don't exist at startup
154        if normalized_pattern.contains('*') {
155            // Find the first directory without wildcards
156            // Normalize to use forward slashes for cross-platform compatibility
157            let normalized_pattern_str = normalize_path_for_glob(normalized_pattern);
158            let parts: Vec<&str> = normalized_pattern_str.split('/').collect();
159            let mut base = base_dir.to_path_buf();
160            for part in parts {
161                if part.contains('*') {
162                    break;
163                }
164                base = base.join(part);
165            }
166            // Watch the base directory if it exists, otherwise fall back to base_dir
167            // This ensures we can detect when the directory is created
168            let dir_to_watch = if base.is_dir() {
169                base
170            } else {
171                base_dir.to_path_buf()
172            };
173            dirs_to_watch.insert(normalize_watch_path(&dir_to_watch));
174        } else {
175            // Non-wildcard pattern (specific file like "package.json")
176            // Always watch the parent directory, even if file doesn't exist yet
177            let full_path = if Path::new(normalized_pattern).is_absolute() {
178                PathBuf::from(normalized_pattern)
179            } else {
180                base_dir.join(normalized_pattern)
181            };
182            if let Some(parent) = full_path.parent() {
183                // Watch the parent if it exists (or base_dir as fallback)
184                let dir_to_watch = if parent.is_dir() {
185                    parent.to_path_buf()
186                } else {
187                    base_dir.to_path_buf()
188                };
189                dirs_to_watch.insert(normalize_watch_path(&dir_to_watch));
190            }
191        }
192    }
193
194    Ok(dirs_to_watch)
195}
196
/// Rewrite every backslash in `path` as a forward slash.
///
/// Glob patterns and the paths matched against them are both passed through
/// this function so that they use the same separator on Windows and Unix.
fn normalize_path_for_glob(path: &str) -> String {
    let segments: Vec<&str> = path.split('\\').collect();
    segments.join("/")
}
202
203/// Check if a changed path matches any of the watch patterns.
204/// Uses globset which properly supports ** for recursive directory matching.
205pub fn path_matches_patterns(changed_path: &Path, patterns: &[String], base_dir: &Path) -> bool {
206    // Normalize the changed path to use forward slashes for consistent matching
207    let changed_path_str = normalize_path_for_glob(&changed_path.to_string_lossy());
208
209    for pattern in patterns {
210        // Strip leading "./" from patterns to handle relative path prefixes
211        let normalized_pattern = pattern.strip_prefix("./").unwrap_or(pattern);
212
213        // Build the full pattern and normalize to use forward slashes
214        let full_pattern = if Path::new(normalized_pattern).is_absolute() {
215            normalize_path_for_glob(normalized_pattern)
216        } else {
217            normalize_path_for_glob(&base_dir.join(normalized_pattern).to_string_lossy())
218        };
219
220        // Use globset which properly supports ** for recursive matching
221        let glob = globset::GlobBuilder::new(&full_pattern)
222            .case_insensitive(cfg!(target_os = "windows"))
223            .literal_separator(true) // * doesn't match /, use ** for recursive
224            .build();
225
226        if let Ok(glob) = glob {
227            let matcher = glob.compile_matcher();
228            if matcher.is_match(&changed_path_str) {
229                return true;
230            }
231        }
232    }
233    false
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Builds the owned pattern list that the functions under test expect.
    fn patterns(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|p| p.to_string()).collect()
    }

    #[test]
    fn test_normalize_watch_path_existing_directory() {
        let temp_dir = TempDir::new().unwrap();
        let dir_path = temp_dir.path().join("test_dir");
        fs::create_dir(&dir_path).unwrap();

        // Canonicalize should work for existing directories
        let normalized = normalize_watch_path(&dir_path);
        assert!(normalized.is_absolute());
        assert!(normalized.exists());
    }

    #[test]
    fn test_normalize_watch_path_nonexistent_path() {
        // Use a missing path below a real temp dir rather than a hard-coded
        // "/nonexistent/...": a leading "/" alone is not absolute on Windows,
        // which would send the function down its relative-path branch.
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("nonexistent").join("dir");
        assert!(path.is_absolute());
        assert!(!path.exists());

        // Should return the original path when canonicalization fails
        let normalized = normalize_watch_path(&path);
        assert_eq!(normalized, path);
    }

    #[test]
    fn test_normalize_watch_path_deduplication() {
        let temp_dir = TempDir::new().unwrap();
        let dir_path = temp_dir.path().join("test_dir");
        fs::create_dir(&dir_path).unwrap();

        // Create a subdirectory to test path traversal
        let subdir = dir_path.join("subdir");
        fs::create_dir(&subdir).unwrap();

        // Two different spellings of the same directory:
        // one is direct, the other goes through parent/child traversal.
        let path1 = subdir.clone();
        let path2 = subdir.join("..").join("subdir");

        let normalized1 = normalize_watch_path(&path1);
        let normalized2 = normalize_watch_path(&path2);

        // Both should canonicalize to the same path
        assert_eq!(normalized1, normalized2);
    }

    #[test]
    fn test_expand_watch_patterns_specific_file() {
        let temp_dir = TempDir::new().unwrap();
        let base_dir = temp_dir.path();

        // Create a test file
        let test_file = base_dir.join("package.json");
        fs::write(&test_file, "{}").unwrap();

        // Expand pattern for a specific file
        let dirs = expand_watch_patterns(&patterns(&["package.json"]), base_dir).unwrap();

        // Should watch exactly the (canonical) parent directory
        assert_eq!(dirs.len(), 1);
        assert!(dirs.contains(&base_dir.canonicalize().unwrap()));
    }

    #[test]
    fn test_expand_watch_patterns_glob() {
        let temp_dir = TempDir::new().unwrap();
        let base_dir = temp_dir.path();
        let subdir = base_dir.join("src");
        fs::create_dir(&subdir).unwrap();

        // Create test files in src directory
        fs::write(subdir.join("file1.rs"), "").unwrap();
        fs::write(subdir.join("file2.rs"), "").unwrap();

        // Expand glob pattern
        let dirs = expand_watch_patterns(&patterns(&["src/**/*.rs"]), base_dir).unwrap();

        // Should watch the src directory, and only absolute paths
        assert!(dirs.contains(&subdir.canonicalize().unwrap()));
        assert!(dirs.iter().all(|dir| dir.is_absolute()));
    }

    #[test]
    fn test_expand_watch_patterns_nonexistent_file() {
        let temp_dir = TempDir::new().unwrap();
        let base_dir = temp_dir.path();

        // Pattern for a file that doesn't exist yet
        let dirs = expand_watch_patterns(&patterns(&["config.toml"]), base_dir).unwrap();

        // Should still watch the parent directory (base_dir in this case)
        assert_eq!(dirs.len(), 1);
        assert!(dirs.contains(&base_dir.canonicalize().unwrap()));
    }

    #[test]
    fn test_path_matches_patterns_simple() {
        let temp_dir = TempDir::new().unwrap();
        let base_dir = temp_dir.path();

        // Create test files
        let test_txt = base_dir.join("test.txt");
        let test_rs = base_dir.join("test.rs");
        fs::write(&test_txt, "").unwrap();
        fs::write(&test_rs, "").unwrap();

        let txt_only = patterns(&["*.txt"]);

        // Simple pattern match
        assert!(path_matches_patterns(&test_txt, &txt_only, base_dir));

        // Non-matching pattern
        assert!(!path_matches_patterns(&test_rs, &txt_only, base_dir));
    }

    #[test]
    fn test_path_matches_patterns_recursive_glob() {
        let temp_dir = TempDir::new().unwrap();
        let base_dir = temp_dir.path();
        let src_dir = base_dir.join("src");
        let deep_dir = src_dir.join("deep");
        fs::create_dir_all(&deep_dir).unwrap();

        // Create test files
        let deep_file = deep_dir.join("file.rs");
        let src_file = src_dir.join("file.rs");
        fs::write(&deep_file, "").unwrap();
        fs::write(&src_file, "").unwrap();

        let recursive = patterns(&["src/**/*.rs"]);

        // ** pattern should match any depth
        assert!(path_matches_patterns(&deep_file, &recursive, base_dir));

        // Should also match top-level
        assert!(path_matches_patterns(&src_file, &recursive, base_dir));
    }

    #[test]
    fn test_path_matches_patterns_multiple_patterns() {
        let temp_dir = TempDir::new().unwrap();
        let base_dir = temp_dir.path();

        // Create test files
        let cargo_toml = base_dir.join("Cargo.toml");
        let main_rs = base_dir.join("main.rs");
        let readme_md = base_dir.join("README.md");
        fs::write(&cargo_toml, "").unwrap();
        fs::write(&main_rs, "").unwrap();
        fs::write(&readme_md, "").unwrap();

        // Multiple patterns - should match if any pattern matches
        let either = patterns(&["*.rs", "*.toml"]);
        assert!(path_matches_patterns(&cargo_toml, &either, base_dir));
        assert!(path_matches_patterns(&main_rs, &either, base_dir));
        assert!(!path_matches_patterns(&readme_md, &either, base_dir));
    }

    #[test]
    fn test_path_matches_patterns_relative_prefix() {
        let temp_dir = TempDir::new().unwrap();
        let base_dir = temp_dir.path();

        // Create a test file
        let test_file = base_dir.join("config.json");
        fs::write(&test_file, "{}").unwrap();

        // Pattern with "./" prefix should match the file
        assert!(path_matches_patterns(
            &test_file,
            &patterns(&["./config.json"]),
            base_dir
        ));

        // Same pattern without prefix should also match
        assert!(path_matches_patterns(
            &test_file,
            &patterns(&["config.json"]),
            base_dir
        ));
    }

    #[test]
    fn test_expand_watch_patterns_relative_prefix() {
        let temp_dir = TempDir::new().unwrap();
        let base_dir = temp_dir.path();

        // Create a test file
        let test_file = base_dir.join("config.json");
        fs::write(&test_file, "{}").unwrap();

        // Pattern with "./" prefix should expand correctly
        let dirs = expand_watch_patterns(&patterns(&["./config.json"]), base_dir).unwrap();

        // Should watch exactly the (canonical) parent directory
        assert_eq!(dirs.len(), 1);
        assert!(dirs.contains(&base_dir.canonicalize().unwrap()));
    }
}