Skip to main content

baracuda_forge/
source.rs

1//! Source file selection and filtering.
2
3use crate::error::{Error, Result};
4use std::path::{Path, PathBuf};
5use walkdir::WalkDir;
6
7/// Source file selection configuration.
8#[derive(Debug, Clone, Default)]
9pub struct SourceSelector {
10    includes: Vec<SourcePath>,
11    excludes: Vec<String>,
12    watch_paths: Vec<PathBuf>,
13}
14
/// A single include entry registered on a `SourceSelector`.
#[derive(Clone, Debug)]
enum SourcePath {
    /// One specific file, taken as-is.
    File(PathBuf),
    /// A directory searched recursively for `.cu` files.
    Directory(PathBuf),
    /// A glob pattern whose `.cu` matches are included.
    Glob(String),
}
21
22impl SourceSelector {
23    /// Create a new empty source selector.
24    pub fn new() -> Self {
25        Self::default()
26    }
27
28    /// Add a directory to search for `.cu` files (recursive).
29    pub fn add_directory<P: AsRef<Path>>(mut self, dir: P) -> Self {
30        self.includes
31            .push(SourcePath::Directory(dir.as_ref().to_path_buf()));
32        self
33    }
34
35    /// Add specific files.
36    pub fn add_files<I, P>(mut self, files: I) -> Self
37    where
38        I: IntoIterator<Item = P>,
39        P: AsRef<Path>,
40    {
41        for file in files {
42            self.includes
43                .push(SourcePath::File(file.as_ref().to_path_buf()));
44        }
45        self
46    }
47
48    /// Add files matching a glob pattern.
49    pub fn add_glob(mut self, pattern: &str) -> Self {
50        self.includes.push(SourcePath::Glob(pattern.to_string()));
51        self
52    }
53
54    /// Exclude files matching patterns.
55    ///
56    /// Patterns can be:
57    /// - `"*_test.cu"` — files ending with `_test.cu`.
58    /// - `"deprecated/*"` — files in a directory.
59    /// - `"test_*.cu"` — files starting with `test_`.
60    pub fn exclude(mut self, patterns: &[&str]) -> Self {
61        for pattern in patterns {
62            self.excludes.push(pattern.to_string());
63        }
64        self
65    }
66
67    /// Add paths to watch for changes (headers, etc.).
68    pub fn watch<I, P>(mut self, paths: I) -> Self
69    where
70        I: IntoIterator<Item = P>,
71        P: AsRef<Path>,
72    {
73        for path in paths {
74            self.watch_paths.push(path.as_ref().to_path_buf());
75        }
76        self
77    }
78
79    /// Resolve all sources to a list of kernel files.
80    pub fn resolve(&self) -> Result<Vec<PathBuf>> {
81        let mut files = Vec::new();
82
83        if self.includes.is_empty() {
84            if let Ok(entries) = glob::glob("src/**/*.cu") {
85                for entry in entries.flatten() {
86                    if !self.is_excluded(&entry) {
87                        files.push(entry);
88                    }
89                }
90            }
91        } else {
92            for source in &self.includes {
93                match source {
94                    SourcePath::File(path) => {
95                        if !path.exists() {
96                            return Err(Error::SourcePathNotFound(path.clone()));
97                        }
98                        if !self.is_excluded(path) {
99                            files.push(path.clone());
100                        }
101                    }
102                    SourcePath::Directory(dir) => {
103                        if !dir.exists() {
104                            return Err(Error::SourcePathNotFound(dir.clone()));
105                        }
106                        self.collect_from_directory(dir, &mut files)?;
107                    }
108                    SourcePath::Glob(pattern) => {
109                        if let Ok(entries) = glob::glob(pattern) {
110                            for entry in entries.flatten() {
111                                if entry.extension().is_some_and(|e| e == "cu")
112                                    && !self.is_excluded(&entry)
113                                {
114                                    files.push(entry);
115                                }
116                            }
117                        }
118                    }
119                }
120            }
121        }
122
123        files.sort();
124        files.dedup();
125        Ok(files)
126    }
127
128    /// Get watch paths.
129    pub fn watch_paths(&self) -> &[PathBuf] {
130        &self.watch_paths
131    }
132
133    fn collect_from_directory(&self, dir: &Path, files: &mut Vec<PathBuf>) -> Result<()> {
134        for entry in WalkDir::new(dir).into_iter().filter_map(|e| e.ok()) {
135            let path = entry.path();
136            if path.is_file()
137                && path.extension().is_some_and(|e| e == "cu")
138                && !self.is_excluded(path)
139            {
140                files.push(path.to_path_buf());
141            }
142        }
143        Ok(())
144    }
145
146    fn is_excluded(&self, path: &Path) -> bool {
147        let filename = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
148        let path_str = path.to_string_lossy();
149
150        for pattern in &self.excludes {
151            if matches_exclusion_pattern(filename, &path_str, pattern) {
152                return true;
153            }
154        }
155        false
156    }
157}
158
/// Decide whether a single exclusion `pattern` matches a candidate file.
///
/// * `filename` – the file's final path component (may be empty).
/// * `path_str` – the full candidate path; both `/` and `\` are treated as
///   separators so Windows-style paths are handled as well.
/// * `pattern` – an exclusion pattern in which `*` matches any (possibly
///   empty) run of characters:
///   - Without `/`, the pattern is matched against `filename` only
///     (`"*_test.cu"`, `"test_*.cu"`, `"kernel.cu"`).
///   - With `/`, the part after the last `/` is matched against `filename`,
///     and the directory segments before it must appear, consecutively, among
///     the file's ancestor directories at any depth. For example,
///     `"deprecated/*"` matches `deprecated/a.cu`, `src/deprecated/a.cu` and
///     `src/deprecated/sub/a.cu`, but not `src/deprecated.cu`.
fn matches_exclusion_pattern(filename: &str, path_str: &str, pattern: &str) -> bool {
    let Some((dir_pattern, file_pattern)) = pattern.rsplit_once('/') else {
        return wildcard_match(pattern, filename);
    };
    if !wildcard_match(file_pattern, filename) {
        return false;
    }

    // Empty and `.` segments (from leading, trailing or doubled slashes, or a
    // `./` prefix) carry no directory constraint.
    let dir_segments: Vec<&str> = dir_pattern
        .split('/')
        .filter(|segment| !segment.is_empty() && *segment != ".")
        .collect();
    if dir_segments.is_empty() {
        return true;
    }

    // Only ancestor directories take part: the last component is the file.
    let mut ancestors: Vec<&str> = path_str
        .split(['/', '\\'])
        .filter(|component| !component.is_empty())
        .collect();
    ancestors.pop();

    ancestors.windows(dir_segments.len()).any(|window| {
        window
            .iter()
            .zip(&dir_segments)
            .all(|(component, segment)| wildcard_match(segment, component))
    })
}

/// Match `text` against `pattern`, where `*` stands for any (possibly empty)
/// run of characters and every other character matches only itself.
fn wildcard_match(pattern: &str, text: &str) -> bool {
    let mut pieces = pattern.split('*');
    // `split` always yields at least one piece: the literal prefix.
    let prefix = pieces.next().unwrap_or("");
    let Some(mut rest) = text.strip_prefix(prefix) else {
        return false;
    };
    let tail: Vec<&str> = pieces.collect();
    let Some((suffix, middle)) = tail.split_last() else {
        // No `*` at all: the whole text must equal the pattern.
        return rest.is_empty();
    };
    // Placing each middle literal at its leftmost occurrence leaves the
    // longest possible remainder for the suffix, so no match is missed.
    // Checking the suffix only against what remains also stops it from
    // overlapping the prefix (`"a*a"` must not match `"a"`).
    for piece in middle {
        match rest.find(piece) {
            Some(index) => rest = &rest[index + piece.len()..],
            None => return false,
        }
    }
    rest.ends_with(suffix)
}
184
185/// Collect header files (`.cuh`) from directories.
186pub fn collect_headers<P: AsRef<Path>>(dirs: &[P]) -> Vec<PathBuf> {
187    let mut headers = Vec::new();
188
189    for dir in dirs {
190        if let Ok(entries) = glob::glob(&format!("{}/**/*.cuh", dir.as_ref().display())) {
191            for entry in entries.flatten() {
192                headers.push(entry);
193            }
194        }
195    }
196
197    headers.sort();
198    headers.dedup();
199    headers
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Each case is `(filename, full path, pattern, expected match)`.
    const CASES: &[(&str, &str, &str, bool)] = &[
        ("test_kernel.cu", "src/test_kernel.cu", "test_*.cu", true),
        ("kernel_test.cu", "src/kernel_test.cu", "*_test.cu", true),
        ("kernel.cu", "src/kernel.cu", "*_test.cu", false),
    ];

    #[test]
    fn test_exclusion_patterns() {
        for &(filename, path, pattern, expected) in CASES {
            assert_eq!(
                matches_exclusion_pattern(filename, path, pattern),
                expected,
                "pattern {pattern:?} against {path:?}"
            );
        }
    }
}