aptu_coder_core/
traversal.rs1use ignore::WalkBuilder;
9use std::collections::HashSet;
10use std::path::{Path, PathBuf};
11use std::process::Command;
12use std::sync::Arc;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::time::{Instant, SystemTime};
15use thiserror::Error;
16use tracing::instrument;
17
/// Upper bound on the number of entries a single [`walk_directory`] call returns.
///
/// Once this many entries have been collected the parallel walk is asked to stop
/// and the result is truncated to this length.
pub const MAX_WALK_ENTRIES: usize = 50_000;
19
/// A single filesystem entry produced by [`walk_directory`].
#[derive(Debug, Clone)]
pub struct WalkEntry {
    /// Path of the entry as reported by the directory walker.
    pub path: PathBuf,
    /// Depth of the entry as reported by the directory walker.
    pub depth: usize,
    /// `true` when the walker reports the entry's file type as a directory.
    pub is_dir: bool,
    /// `true` when the entry's path is a symbolic link.
    pub is_symlink: bool,
    /// Link target read with `std::fs::read_link`; `None` for non-symlinks or
    /// when the link could not be read.
    pub symlink_target: Option<PathBuf>,
    /// Last-modification time, or `None` when metadata is unavailable.
    pub mtime: Option<SystemTime>,
    /// NOTE(review): `walk_directory` fills this with a plain copy of `path`
    /// (no `canonicalize` call is made there); confirm whether other producers
    /// store a resolved path here before relying on it being canonical.
    pub canonical_path: PathBuf,
}
31
32#[derive(Debug, Error)]
33#[non_exhaustive]
34pub enum TraversalError {
35 #[error("IO error: {0}")]
36 Io(#[from] std::io::Error),
37 #[error("internal concurrency error: {0}")]
38 Internal(String),
39 #[error("git error: {0}")]
40 GitError(String),
41}
42
43#[instrument(skip_all, fields(path = %root.display(), max_depth))]
47pub fn walk_directory(
48 root: &Path,
49 max_depth: Option<u32>,
50) -> Result<Vec<WalkEntry>, TraversalError> {
51 let start = Instant::now();
52 let mut builder = WalkBuilder::new(root);
53 builder.hidden(true).standard_filters(true);
54
55 if let Some(depth) = max_depth
57 && depth > 0
58 {
59 builder.max_depth(Some(depth as usize));
60 }
61
62 let (sender, receiver) = std::sync::mpsc::channel::<WalkEntry>();
63 let entry_count = Arc::new(AtomicUsize::new(0));
64
65 builder.build_parallel().run(move || {
66 let sender = sender.clone();
67 let entry_count = entry_count.clone();
68 Box::new(move |result| match result {
69 Ok(entry) => {
70 let path = entry.path().to_path_buf();
71 let depth = entry.depth();
72 let is_dir = entry.file_type().is_some_and(|ft| ft.is_dir());
73 let is_symlink = entry.path_is_symlink();
74
75 let symlink_target = if is_symlink {
76 std::fs::read_link(&path).ok()
77 } else {
78 None
79 };
80
81 let mtime = entry.metadata().ok().and_then(|m| m.modified().ok());
82
83 let walk_entry = WalkEntry {
84 path: path.clone(),
85 depth,
86 is_dir,
87 is_symlink,
88 symlink_target,
89 mtime,
90 canonical_path: path.clone(),
91 };
92 sender.send(walk_entry).ok();
93 let count = entry_count.fetch_add(1, Ordering::Relaxed);
94 if count >= MAX_WALK_ENTRIES {
95 return ignore::WalkState::Quit;
96 }
97 ignore::WalkState::Continue
98 }
99 Err(e) => {
100 tracing::warn!(error = %e, "skipping unreadable entry");
101 ignore::WalkState::Continue
102 }
103 })
104 });
105
106 let mut entries: Vec<WalkEntry> = receiver.try_iter().collect();
107 entries.truncate(MAX_WALK_ENTRIES);
108 if entries.len() >= MAX_WALK_ENTRIES {
109 tracing::warn!(
110 "walk truncated at {} entries (MAX_WALK_ENTRIES={}); results are partial",
111 MAX_WALK_ENTRIES,
112 MAX_WALK_ENTRIES
113 );
114 }
115
116 let dir_count = entries.iter().filter(|e| e.is_dir).count();
117 let file_count = entries.iter().filter(|e| !e.is_dir).count();
118
119 tracing::debug!(
120 entries = entries.len(),
121 dirs = dir_count,
122 files = file_count,
123 duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
124 "walk complete"
125 );
126
127 entries.sort_by(|a, b| a.path.cmp(&b.path));
129 Ok(entries)
130}
131
132pub fn changed_files_from_git_ref(
140 dir: &Path,
141 git_ref: &str,
142) -> Result<HashSet<PathBuf>, TraversalError> {
143 if git_ref.is_empty() || git_ref.starts_with('-') {
147 return Err(TraversalError::GitError(
148 "invalid git_ref: must not be empty or start with '-'".to_string(),
149 ));
150 }
151 if git_ref.chars().any(|c| {
152 c.is_whitespace()
153 || matches!(
154 c,
155 '|' | '&'
156 | ';'
157 | '>'
158 | '<'
159 | '`'
160 | '$'
161 | '('
162 | ')'
163 | '{'
164 | '}'
165 | '['
166 | ']'
167 | '*'
168 | '?'
169 | '\\'
170 | '"'
171 | '\''
172 )
173 }) {
174 return Err(TraversalError::GitError(
175 "invalid git_ref: contains forbidden characters".to_string(),
176 ));
177 }
178
179 let root_out = Command::new("git")
182 .arg("-C")
183 .arg(dir)
184 .arg("rev-parse")
185 .arg("--show-toplevel")
186 .output()
187 .map_err(|e| {
188 if e.kind() == std::io::ErrorKind::NotFound {
189 TraversalError::GitError("git not found on PATH".to_string())
190 } else {
191 TraversalError::GitError(format!("failed to run git: {e}"))
192 }
193 })?;
194
195 if !root_out.status.success() {
196 let stderr = String::from_utf8_lossy(&root_out.stderr);
197 return Err(TraversalError::GitError(format!(
198 "not a git repository: {stderr}"
199 )));
200 }
201
202 let root_raw = PathBuf::from(String::from_utf8_lossy(&root_out.stdout).trim().to_string());
203 let root = std::fs::canonicalize(&root_raw).unwrap_or(root_raw);
206
207 let diff_out = Command::new("git")
209 .arg("-C")
210 .arg(dir)
211 .arg("diff")
212 .arg("--name-only")
213 .arg(git_ref)
214 .output()
215 .map_err(|e| TraversalError::GitError(format!("failed to run git diff: {e}")))?;
216
217 if !diff_out.status.success() {
218 let stderr = String::from_utf8_lossy(&diff_out.stderr);
219 return Err(TraversalError::GitError(format!(
220 "git diff failed: {stderr}"
221 )));
222 }
223
224 let changed: HashSet<PathBuf> = String::from_utf8_lossy(&diff_out.stdout)
225 .lines()
226 .filter(|l| !l.is_empty())
227 .map(|l| root.join(l))
228 .collect();
229
230 Ok(changed)
231}
232
233#[must_use]
239pub fn filter_entries_by_git_ref(
240 entries: Vec<WalkEntry>,
241 changed: &HashSet<PathBuf>,
242 root: &Path,
243) -> Vec<WalkEntry> {
244 let canonical_root = std::fs::canonicalize(root).unwrap_or_else(|_| root.to_path_buf());
245
246 let canonical_changed: HashSet<PathBuf> = changed
248 .iter()
249 .map(|p| std::fs::canonicalize(p).unwrap_or_else(|_| p.clone()))
250 .collect();
251
252 let mut ancestor_dirs: HashSet<PathBuf> = HashSet::new();
254 ancestor_dirs.insert(canonical_root.clone());
255 for p in &canonical_changed {
256 let mut cur = p.as_path();
257 while let Some(parent) = cur.parent() {
258 if !ancestor_dirs.insert(parent.to_path_buf()) {
259 break;
261 }
262 if parent == canonical_root {
263 break;
264 }
265 cur = parent;
266 }
267 }
268
269 entries
270 .into_iter()
271 .filter(|e| {
272 let canonical = std::fs::canonicalize(&e.path).unwrap_or_else(|_| e.path.clone());
273 if e.is_dir {
274 ancestor_dirs.contains(&canonical)
275 } else {
276 canonical_changed.contains(&canonical)
277 }
278 })
279 .collect()
280}
281
282#[must_use]
287pub fn subtree_counts_from_entries(root: &Path, entries: &[WalkEntry]) -> Vec<(PathBuf, usize)> {
288 let mut counts: Vec<(PathBuf, usize)> = Vec::new();
289 for entry in entries {
290 if entry.is_dir {
291 continue;
292 }
293 if entry.path.components().any(|c| {
295 let s = c.as_os_str().to_string_lossy();
296 crate::EXCLUDED_DIRS.contains(&s.as_ref())
297 }) {
298 continue;
299 }
300 let Ok(rel) = entry.path.strip_prefix(root) else {
301 continue;
302 };
303 if let Some(first) = rel.components().next() {
304 let key = root.join(first);
305 match counts.last_mut() {
306 Some(last) if last.0 == key => last.1 += 1,
307 _ => counts.push((key, 1)),
308 }
309 }
310 }
311 counts
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Malformed refs must be rejected; well-formed refs must get past validation.
    ///
    /// For the well-formed refs the call is still allowed to fail (the temp
    /// directory is presumably not a git repository); only a validation error
    /// counts as a test failure.
    #[test]
    fn test_git_ref_injection_rejected() {
        let tmp = tempfile::tempdir().unwrap();
        let dir = tmp.path();

        let rejected = [
            (
                "--output=/tmp/evil",
                "should reject git_ref starting with '-'",
            ),
            ("--", "should reject git_ref starting with '-'"),
            ("main branch", "should reject git_ref with spaces"),
            ("", "should reject empty git_ref"),
        ];
        for (git_ref, expectation) in rejected {
            let outcome = changed_files_from_git_ref(dir, git_ref);
            assert!(outcome.is_err(), "{expectation}");
        }

        for git_ref in ["HEAD~1", "main", "abc123"] {
            if let Err(TraversalError::GitError(msg)) = changed_files_from_git_ref(dir, git_ref) {
                assert!(
                    !msg.contains("invalid git_ref"),
                    "{git_ref} should pass validation"
                );
            }
        }
    }
}