1use crate::error::{Error, Result};
4use std::path::{Path, PathBuf};
5use walkdir::WalkDir;
6
7#[derive(Debug, Clone, Default)]
9pub struct SourceSelector {
10 includes: Vec<SourcePath>,
11 excludes: Vec<String>,
12 watch_paths: Vec<PathBuf>,
13}
14
/// A single include rule registered on a [`SourceSelector`].
#[derive(Clone, Debug)]
enum SourcePath {
    /// An explicit file; it must exist when the selector is resolved.
    File(PathBuf),
    /// A directory that is walked recursively for `.cu` files.
    Directory(PathBuf),
    /// A glob pattern whose `.cu` matches are included.
    Glob(String),
}
21
22impl SourceSelector {
23 pub fn new() -> Self {
25 Self::default()
26 }
27
28 pub fn add_directory<P: AsRef<Path>>(mut self, dir: P) -> Self {
30 self.includes
31 .push(SourcePath::Directory(dir.as_ref().to_path_buf()));
32 self
33 }
34
35 pub fn add_files<I, P>(mut self, files: I) -> Self
37 where
38 I: IntoIterator<Item = P>,
39 P: AsRef<Path>,
40 {
41 for file in files {
42 self.includes
43 .push(SourcePath::File(file.as_ref().to_path_buf()));
44 }
45 self
46 }
47
48 pub fn add_glob(mut self, pattern: &str) -> Self {
50 self.includes.push(SourcePath::Glob(pattern.to_string()));
51 self
52 }
53
54 pub fn exclude(mut self, patterns: &[&str]) -> Self {
61 for pattern in patterns {
62 self.excludes.push(pattern.to_string());
63 }
64 self
65 }
66
67 pub fn watch<I, P>(mut self, paths: I) -> Self
69 where
70 I: IntoIterator<Item = P>,
71 P: AsRef<Path>,
72 {
73 for path in paths {
74 self.watch_paths.push(path.as_ref().to_path_buf());
75 }
76 self
77 }
78
79 pub fn resolve(&self) -> Result<Vec<PathBuf>> {
81 let mut files = Vec::new();
82
83 if self.includes.is_empty() {
84 if let Ok(entries) = glob::glob("src/**/*.cu") {
85 for entry in entries.flatten() {
86 if !self.is_excluded(&entry) {
87 files.push(entry);
88 }
89 }
90 }
91 } else {
92 for source in &self.includes {
93 match source {
94 SourcePath::File(path) => {
95 if !path.exists() {
96 return Err(Error::SourcePathNotFound(path.clone()));
97 }
98 if !self.is_excluded(path) {
99 files.push(path.clone());
100 }
101 }
102 SourcePath::Directory(dir) => {
103 if !dir.exists() {
104 return Err(Error::SourcePathNotFound(dir.clone()));
105 }
106 self.collect_from_directory(dir, &mut files)?;
107 }
108 SourcePath::Glob(pattern) => {
109 if let Ok(entries) = glob::glob(pattern) {
110 for entry in entries.flatten() {
111 if entry.extension().is_some_and(|e| e == "cu")
112 && !self.is_excluded(&entry)
113 {
114 files.push(entry);
115 }
116 }
117 }
118 }
119 }
120 }
121 }
122
123 files.sort();
124 files.dedup();
125 Ok(files)
126 }
127
128 pub fn watch_paths(&self) -> &[PathBuf] {
130 &self.watch_paths
131 }
132
133 fn collect_from_directory(&self, dir: &Path, files: &mut Vec<PathBuf>) -> Result<()> {
134 for entry in WalkDir::new(dir).into_iter().filter_map(|e| e.ok()) {
135 let path = entry.path();
136 if path.is_file()
137 && path.extension().is_some_and(|e| e == "cu")
138 && !self.is_excluded(path)
139 {
140 files.push(path.to_path_buf());
141 }
142 }
143 Ok(())
144 }
145
146 fn is_excluded(&self, path: &Path) -> bool {
147 let filename = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
148 let path_str = path.to_string_lossy();
149
150 for pattern in &self.excludes {
151 if matches_exclusion_pattern(filename, &path_str, pattern) {
152 return true;
153 }
154 }
155 false
156 }
157}
158
/// Returns `true` if a candidate source file matches one exclusion `pattern`.
///
/// * `filename` — the file's final path component.
/// * `path_str` — the full path as given; `/` and `\` are both accepted as separators.
/// * `pattern` — one of:
///   - `dir/*` (or `a/b/*`): matches when the directory part of `path_str`
///     contains those components consecutively. Each component may itself use
///     `*` wildcards, and the path may be relative (`tests/k.cu`) or absolute.
///   - anything else: matched against `filename`, where `*` matches any
///     (possibly empty) run of characters and every other character is literal.
///     Without a `*` this is an exact file-name comparison.
fn matches_exclusion_pattern(filename: &str, path_str: &str, pattern: &str) -> bool {
    if let Some(dir_pattern) = pattern.strip_suffix("/*") {
        return contains_directory_sequence(path_str, dir_pattern);
    }
    wildcard_match(pattern, filename)
}

/// Returns `true` if the directory components of `path_str` contain the
/// `/`-separated components of `dir_pattern` consecutively.
///
/// The final component of `path_str` (the file name) is never treated as a
/// directory. An empty `dir_pattern` matches nothing.
fn contains_directory_sequence(path_str: &str, dir_pattern: &str) -> bool {
    let wanted: Vec<&str> = dir_pattern
        .split('/')
        .filter(|part| !part.is_empty())
        .collect();
    // Guard also prevents `windows(0)`, which would panic.
    if wanted.is_empty() {
        return false;
    }

    let mut components: Vec<&str> = path_str.split(|c: char| c == '/' || c == '\\').collect();
    // Drop the file name so only directories are compared.
    components.pop();

    components.windows(wanted.len()).any(|window| {
        window
            .iter()
            .zip(&wanted)
            .all(|(component, part)| wildcard_match(part, component))
    })
}

/// Matches `text` against `pattern`, where `*` matches any (possibly empty)
/// run of characters and every other character must match literally.
fn wildcard_match(pattern: &str, text: &str) -> bool {
    let mut segments = pattern.split('*');
    // `split` always yields at least one segment: the literal text before the first `*`.
    let head = segments.next().unwrap_or_default();
    let Some(mut rest) = text.strip_prefix(head) else {
        return false;
    };

    let tail: Vec<&str> = segments.collect();
    let Some((last, middle)) = tail.split_last() else {
        // No `*` at all: the whole text must equal the pattern.
        return rest.is_empty();
    };

    // Leftmost matching of the middle segments leaves the longest possible
    // remainder, which is optimal for `*`-only patterns.
    for segment in middle {
        match rest.find(segment) {
            Some(index) => rest = &rest[index + segment.len()..],
            None => return false,
        }
    }

    // The last segment must fit in what is left, so it cannot overlap the
    // prefix or any middle segment.
    rest.ends_with(last)
}
184
185pub fn collect_headers<P: AsRef<Path>>(dirs: &[P]) -> Vec<PathBuf> {
187 let mut headers = Vec::new();
188
189 for dir in dirs {
190 if let Ok(entries) = glob::glob(&format!("{}/**/*.cuh", dir.as_ref().display())) {
191 for entry in entries.flatten() {
192 headers.push(entry);
193 }
194 }
195 }
196
197 headers.sort();
198 headers.dedup();
199 headers
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of filename exclusion patterns:
    /// each row is (file name, full path, pattern, expected match).
    #[test]
    fn test_exclusion_patterns() {
        let cases = [
            ("test_kernel.cu", "src/test_kernel.cu", "test_*.cu", true),
            ("kernel_test.cu", "src/kernel_test.cu", "*_test.cu", true),
            ("kernel.cu", "src/kernel.cu", "*_test.cu", false),
        ];

        for (filename, path, pattern, expected) in cases {
            assert_eq!(
                matches_exclusion_pattern(filename, path, pattern),
                expected,
                "pattern {pattern:?} against {path:?}"
            );
        }
    }
}