use std::collections::HashSet;
use std::fs;
use std::io::Read;
use std::path::{Path, PathBuf};
use crate::types::Language;
use crate::TldrError;
/// Options that control how `walk_source_files` traverses a directory tree.
#[derive(Debug, Clone)]
pub struct WalkOptions {
/// When `Some`, only files whose detected language equals this value are returned.
pub lang: Option<Language>,
/// Glob patterns matched against each file's path relative to the walk root;
/// a matching file is dropped. Patterns that fail to parse are ignored.
pub exclude: Vec<String>,
/// Forwarded (negated) to the walker's hidden-file filter.
/// NOTE(review): `walk_source_files` also applies `should_skip_path`, which
/// rejects most dot-prefixed components, so this flag appears to only let
/// `.github` / `.claude` through — confirm that is intended.
pub include_hidden: bool,
/// Whether git ignore rules are honoured during the walk.
pub gitignore: bool,
/// Maximum number of files to collect; `0` disables the limit.
pub max_files: usize,
}
impl Default for WalkOptions {
fn default() -> Self {
Self {
lang: None,
exclude: Vec::new(),
include_hidden: false,
gitignore: true,
max_files: 0,
}
}
}
/// Collects the source files found at or below `path`.
///
/// * If `path` is a regular file it is returned as the single result, without
///   applying any of the filters in `options`.
/// * Otherwise the tree is traversed with `ignore::WalkBuilder` (symlinks are
///   not followed). Each non-directory entry is kept only if its path relative
///   to `path` passes `should_skip_path` and `should_exclude`, its language is
///   recognised by `Language::from_path`, and it matches `options.lang` when
///   that filter is set.
///
/// Returns `(files, warnings)`. Warnings are human-readable notes about walk
/// errors, hitting `options.max_files`, or finding entries but no supported
/// source file.
///
/// # Errors
///
/// Returns `TldrError::PathNotFound` when `path` does not exist.
pub fn walk_source_files(
    path: &Path,
    options: &WalkOptions,
) -> Result<(Vec<PathBuf>, Vec<String>), TldrError> {
    if !path.exists() {
        return Err(TldrError::PathNotFound(path.to_path_buf()));
    }
    if path.is_file() {
        return Ok((vec![path.to_path_buf()], vec![]));
    }

    let mut builder = ignore::WalkBuilder::new(path);
    builder
        .follow_links(false)
        .hidden(!options.include_hidden)
        .git_ignore(options.gitignore)
        .git_global(options.gitignore)
        // `.git/info/exclude` is part of git's ignore machinery as well and is
        // enabled by default in the walker, so it must follow the same switch;
        // otherwise `gitignore: false` would still silently drop files.
        .git_exclude(options.gitignore);

    let mut files = Vec::new();
    let mut warnings = Vec::new();
    // Tracks whether the walk yielded any non-directory entry at all, so an
    // empty directory does not produce a "no source files" warning.
    let mut had_entries = false;

    for entry in builder.build() {
        let entry = match entry {
            Ok(e) => e,
            Err(e) => {
                // A failing entry is reported but does not abort the walk.
                warnings.push(format!("Walk error: {}", e));
                continue;
            }
        };
        let entry_path = entry.path();
        if entry_path.is_dir() {
            continue;
        }
        had_entries = true;

        // The cap is checked before filtering, so the walk stops at the first
        // non-directory entry seen after the limit has been reached.
        if options.max_files > 0 && files.len() >= options.max_files {
            warnings.push(format!(
                "Stopped after {} files (max_files limit)",
                options.max_files
            ));
            break;
        }

        // Filters look at the path relative to the walk root so that the
        // location of the root itself never causes a file to be skipped.
        let relative_path = entry_path.strip_prefix(path).unwrap_or(entry_path);
        if should_skip_path(relative_path) || should_exclude(relative_path, &options.exclude) {
            continue;
        }

        let lang = match Language::from_path(entry_path) {
            Some(l) => l,
            None => continue,
        };
        if let Some(filter_lang) = options.lang {
            if lang != filter_lang {
                continue;
            }
        }
        files.push(entry_path.to_path_buf());
    }

    if files.is_empty() && had_entries {
        warnings.push(format!(
            "No supported source files found in {}",
            path.display()
        ));
    }
    Ok((files, warnings))
}
/// Reports whether `path` matches at least one of the glob `patterns`.
///
/// The path is compared in its lossy UTF-8 string form. Patterns that fail to
/// parse are skipped rather than reported, and evaluation stops at the first
/// match.
pub fn should_exclude(path: &Path, patterns: &[String]) -> bool {
    let candidate = path.to_string_lossy();
    patterns
        .iter()
        .filter_map(|raw| glob::Pattern::new(raw).ok())
        .any(|compiled| compiled.matches(&candidate))
}
/// Default per-file size limit, in megabytes (1 MB = 1024 * 1024 bytes here).
pub const DEFAULT_MAX_FILE_SIZE_MB: usize = 10;
/// Default per-file size limit, in bytes; derived from the megabyte limit so
/// the two constants cannot drift apart.
pub const DEFAULT_MAX_FILE_SIZE: usize = DEFAULT_MAX_FILE_SIZE_MB * 1024 * 1024;
/// Path component names that cause a path to be skipped.
///
/// `should_skip_path` rejects a path when any of its components equals one of
/// these names exactly; `skip_directories` exposes the same list as a set.
const SKIP_DIRS: &[&str] = &[
"node_modules",
".git",
".svn",
".hg",
"__pycache__",
".pytest_cache",
".mypy_cache",
".tox",
".venv",
"venv",
".env",
"target",
"build",
"dist",
".idea",
".vscode",
".next",
".nuxt",
"coverage",
".coverage",
];
/// Lower-case file extensions (without the leading dot) that are treated as
/// binary without inspecting file contents.
///
/// Callers lower-case the file's extension before looking it up here, so the
/// entries must stay lower-case.
const BINARY_EXTENSIONS: &[&str] = &[
// Images.
"png", "jpg", "jpeg", "gif", "bmp", "ico", "webp", "svg", "tiff", "psd",
// Audio/video, archives, and executables / compiled artifacts.
"mp3", "mp4", "avi", "mkv", "mov", "wav", "flac", "ogg", "webm", "zip", "tar", "gz", "bz2", "xz", "7z", "rar", "exe", "dll", "so", "dylib", "a", "o", "obj", "class", "pyc", "pyo",
// Office documents, databases, fonts, and miscellaneous data files.
"pdf", "doc", "docx", "xls", "xlsx", "ppt", "pptx", "db", "sqlite", "sqlite3", "ttf", "otf", "woff", "woff2", "eot", "lock", "bin", "dat", "pak",
];
/// Verifies that the file at `path` is no larger than `max_mb` megabytes
/// (1 MB = 1024 * 1024 bytes).
///
/// The comparison is done in `u64` with a saturating multiplication, so a very
/// large `max_mb` (for example `usize::MAX` used as "no limit") neither panics
/// nor wraps around, and file sizes above `usize::MAX` are not truncated on
/// 32-bit targets.
///
/// # Errors
///
/// * Propagates the I/O error if the file metadata cannot be read.
/// * Returns `TldrError::FileTooLarge` when the file exceeds the limit; the
///   reported `size_mb` is the file size rounded down to whole megabytes.
pub fn check_file_size(path: &Path, max_mb: usize) -> Result<(), TldrError> {
    const BYTES_PER_MB: u64 = 1024 * 1024;

    let size_bytes = fs::metadata(path)?.len();
    // Widening cast: `usize` is at most 64 bits on supported targets.
    let max_bytes = (max_mb as u64).saturating_mul(BYTES_PER_MB);

    if size_bytes > max_bytes {
        // Clamp before narrowing so the cast back to `usize` can never wrap.
        let size_mb = (size_bytes / BYTES_PER_MB).min(usize::MAX as u64) as usize;
        return Err(TldrError::FileTooLarge {
            path: path.to_path_buf(),
            size_mb,
            max_mb,
        });
    }
    Ok(())
}
/// Returns the size of the file at `path` in bytes.
///
/// Sizes that do not fit in `usize` (possible on 32-bit targets) saturate at
/// `usize::MAX` instead of silently wrapping.
///
/// # Errors
///
/// Propagates the I/O error if the file metadata cannot be read.
pub fn get_file_size(path: &Path) -> Result<usize, TldrError> {
    let len = fs::metadata(path)?.len();
    // Clamp first so the narrowing cast is lossless.
    Ok(len.min(usize::MAX as u64) as usize)
}
/// Heuristically decides whether `path` refers to a binary file.
///
/// A file is considered binary when its (lower-cased) extension is listed in
/// `BINARY_EXTENSIONS`, or when the first chunk read from it (up to 8192
/// bytes) contains a NUL byte. Files that cannot be opened or read are
/// reported as not binary.
pub fn is_binary_file(path: &Path) -> bool {
    let known_binary_ext = path
        .extension()
        .and_then(|ext| ext.to_str())
        .map_or(false, |ext| {
            BINARY_EXTENSIONS.contains(&ext.to_lowercase().as_str())
        });
    if known_binary_ext {
        return true;
    }

    // Content sniffing: a single read of the file head is enough for the
    // NUL-byte heuristic.
    let mut head = [0u8; 8192];
    let read_result = fs::File::open(path).and_then(|mut file| file.read(&mut head));
    match read_result {
        Ok(len) => head[..len].contains(&0),
        Err(_) => false,
    }
}
/// Reports whether the extension of `path`, compared case-insensitively via
/// lower-casing, is listed in `BINARY_EXTENSIONS`.
///
/// Paths without an extension, or whose extension is not valid UTF-8, yield
/// `false`. The file itself is never opened.
pub fn has_binary_extension(path: &Path) -> bool {
    path.extension()
        .and_then(|ext| ext.to_str())
        .map(|ext| ext.to_lowercase())
        .map_or(false, |ext| BINARY_EXTENSIONS.contains(&ext.as_str()))
}
/// Reports whether any component of `path` marks it as something to skip.
///
/// A component triggers a skip when it is either
/// * a dot-prefixed name longer than one character, other than `.github` and
///   `.claude`, or
/// * one of the names in `SKIP_DIRS`.
///
/// Only normal components with valid UTF-8 names are examined; root, prefix,
/// `.` and `..` components never trigger a skip.
pub fn should_skip_path(path: &Path) -> bool {
    path.components().any(|component| {
        let name = match component {
            std::path::Component::Normal(os_name) => match os_name.to_str() {
                Some(text) => text,
                None => return false,
            },
            _ => return false,
        };
        let is_hidden = name.len() > 1 && name.starts_with('.');
        let hidden_and_not_allowed = is_hidden && !matches!(name, ".github" | ".claude");
        hidden_and_not_allowed || SKIP_DIRS.contains(&name)
    })
}
/// Returns the entries of `SKIP_DIRS` as a set for constant-time membership
/// checks.
pub fn skip_directories() -> HashSet<&'static str> {
    let mut names = HashSet::new();
    for dir in SKIP_DIRS {
        names.insert(*dir);
    }
    names
}
/// Resolves `path` to a canonical path, following symlinks one hop at a time
/// so that cycles and escapes from the project can be detected.
///
/// Each hop inspects `current` without following it. If it is a symlink, its
/// absolute location is recorded and the link target becomes the new
/// `current` (relative targets are resolved against the link's parent
/// directory). Once a non-symlink is reached it is canonicalized and, when
/// `project_root` is given, checked to lie inside the canonicalized root.
///
/// # Errors
///
/// * `TldrError::PathNotFound` — a path along the chain does not exist.
/// * `TldrError::SymlinkCycle` — a link was visited twice, or more than
///   `MAX_DEPTH` (40) hops were needed.
/// * `TldrError::PathTraversal` — the resolved path is outside `project_root`.
/// * `TldrError::IoError` — any other I/O failure while inspecting a path.
///   A failure to read a link target is propagated with `?`.
pub fn resolve_symlink_safely(
    path: &Path,
    project_root: Option<&Path>,
) -> Result<PathBuf, TldrError> {
    const MAX_DEPTH: usize = 40;

    let mut visited_links = HashSet::new();
    let mut current = path.to_path_buf();

    for _ in 0..MAX_DEPTH {
        // `symlink_metadata` does not follow the final component, which lets
        // us see the link itself rather than its target.
        let metadata = match fs::symlink_metadata(&current) {
            Ok(m) => m,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
                return Err(TldrError::PathNotFound(current));
            }
            Err(e) => return Err(TldrError::IoError(e)),
        };

        if !metadata.file_type().is_symlink() {
            // End of the chain: canonicalize and enforce the project boundary.
            let canonical = match current.canonicalize() {
                Ok(c) => c,
                Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
                    return Err(TldrError::PathNotFound(current));
                }
                Err(e) => return Err(TldrError::IoError(e)),
            };
            if let Some(root) = project_root {
                // Fall back to the root as given if it cannot be canonicalized.
                let root_canonical = root.canonicalize().unwrap_or_else(|_| root.to_path_buf());
                if !canonical.starts_with(&root_canonical) {
                    return Err(TldrError::PathTraversal(path.to_path_buf()));
                }
            }
            return Ok(canonical);
        }

        // Record links by absolute location so the same link reached through
        // a relative and an absolute spelling is recognised as a repeat.
        let link_abs = if current.is_absolute() {
            current.clone()
        } else {
            std::env::current_dir()
                .map(|cwd| cwd.join(&current))
                .unwrap_or_else(|_| current.clone())
        };
        // `insert` returns false when the link was already seen.
        if !visited_links.insert(link_abs) {
            return Err(TldrError::SymlinkCycle(path.to_path_buf()));
        }

        let target = fs::read_link(&current)?;
        current = if target.is_relative() {
            // Relative targets are interpreted relative to the link's directory.
            current.parent().map(|p| p.join(&target)).unwrap_or(target)
        } else {
            target
        };
    }

    // Hop budget exhausted without reaching a non-symlink.
    Err(TldrError::SymlinkCycle(path.to_path_buf()))
}
/// Reports whether `path` itself is a symbolic link.
///
/// The link is not followed, so a dangling symlink still yields `true`. Any
/// error while reading the metadata (for example a missing path) yields
/// `false`.
pub fn is_symlink(path: &Path) -> bool {
    match fs::symlink_metadata(path) {
        Ok(meta) => meta.file_type().is_symlink(),
        Err(_) => false,
    }
}
/// Reports whether `path`, once canonicalized, lies inside the canonicalized
/// `project_root`.
///
/// Returns `false` if either path cannot be canonicalized (for example because
/// it does not exist). The root is only canonicalized when `path` resolved
/// successfully.
pub fn is_path_within_project(path: &Path, project_root: &Path) -> bool {
    path.canonicalize()
        .and_then(|resolved| {
            project_root
                .canonicalize()
                .map(|root| resolved.starts_with(&root))
        })
        .unwrap_or(false)
}
/// Reports whether `path` contains a `..` (parent-directory) component.
///
/// This is a purely lexical check; the filesystem is not consulted.
pub fn contains_path_traversal(path: &Path) -> bool {
    path.components()
        .any(|part| matches!(part, std::path::Component::ParentDir))
}
// Unit tests for the file-size, binary-detection, skip/exclude, symlink and
// directory-walking helpers in this module. Filesystem fixtures are created
// with `tempfile` and removed when the handles are dropped.
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::{tempdir, NamedTempFile};
// --- File size helpers ---
#[test]
fn test_check_file_size_within_limit() {
let mut file = NamedTempFile::new().unwrap();
write!(file, "small content").unwrap();
assert!(check_file_size(file.path(), 10).is_ok());
}
#[test]
fn test_check_file_size_exceeds_limit() {
let mut file = NamedTempFile::new().unwrap();
// 2 MiB of data checked against a 1 MB limit.
let data = vec![b'x'; 2 * 1024 * 1024];
file.write_all(&data).unwrap();
let result = check_file_size(file.path(), 1); assert!(matches!(result, Err(TldrError::FileTooLarge { .. })));
}
#[test]
fn test_get_file_size() {
let mut file = NamedTempFile::new().unwrap();
write!(file, "hello world").unwrap();
let size = get_file_size(file.path()).unwrap();
assert_eq!(size, 11);
}
// --- Binary detection ---
#[test]
fn test_is_binary_file_by_content() {
let mut file = NamedTempFile::new().unwrap();
// NUL bytes in the content mark the file as binary.
file.write_all(&[0x00, 0x01, 0x02, 0x00]).unwrap();
assert!(is_binary_file(file.path()));
}
#[test]
fn test_is_binary_file_text_content() {
let mut file = NamedTempFile::new().unwrap();
write!(file, "def foo():\n pass\n").unwrap();
assert!(!is_binary_file(file.path()));
}
#[test]
fn test_has_binary_extension() {
assert!(has_binary_extension(Path::new("image.png")));
assert!(has_binary_extension(Path::new("archive.zip")));
assert!(has_binary_extension(Path::new("binary.exe")));
assert!(!has_binary_extension(Path::new("code.py")));
assert!(!has_binary_extension(Path::new("script.rs")));
}
// --- Path skipping ---
#[test]
fn test_should_skip_path_node_modules() {
assert!(should_skip_path(Path::new("node_modules/package/index.js")));
assert!(should_skip_path(Path::new(
"project/node_modules/lodash/index.js"
)));
}
#[test]
fn test_should_skip_path_git() {
assert!(should_skip_path(Path::new(".git/objects/abc")));
assert!(should_skip_path(Path::new("repo/.git/HEAD")));
}
#[test]
fn test_should_skip_path_pycache() {
assert!(should_skip_path(Path::new("__pycache__/module.pyc")));
}
#[test]
fn test_should_skip_path_hidden() {
assert!(should_skip_path(Path::new(".hidden/file")));
assert!(should_skip_path(Path::new("dir/.hidden_file")));
}
#[test]
fn test_should_not_skip_regular_path() {
assert!(!should_skip_path(Path::new("src/main.rs")));
assert!(!should_skip_path(Path::new("lib/utils/helper.py")));
}
#[test]
fn test_should_not_skip_github() {
// `.github` is one of the dot-directories explicitly allowed through.
assert!(!should_skip_path(Path::new(".github/workflows/ci.yml")));
}
// --- Symlink resolution and project containment ---
#[test]
fn test_resolve_symlink_regular_file() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("regular_file.txt");
fs::write(&file_path, "content").unwrap();
let resolved = resolve_symlink_safely(&file_path, None).unwrap();
assert_eq!(resolved, file_path.canonicalize().unwrap());
}
#[cfg(unix)]
#[test]
fn test_resolve_symlink_valid_link() {
let dir = tempdir().unwrap();
let target = dir.path().join("target.txt");
let link = dir.path().join("link.txt");
fs::write(&target, "content").unwrap();
std::os::unix::fs::symlink(&target, &link).unwrap();
let resolved = resolve_symlink_safely(&link, None).unwrap();
assert_eq!(resolved, target.canonicalize().unwrap());
}
#[cfg(unix)]
#[test]
fn test_resolve_symlink_outside_project() {
// A link inside the project that points at a file in a different
// temporary directory must be rejected as path traversal.
let project_dir = tempdir().unwrap();
let outside_dir = tempdir().unwrap();
let outside_file = outside_dir.path().join("outside.txt");
let link = project_dir.path().join("link.txt");
fs::write(&outside_file, "content").unwrap();
std::os::unix::fs::symlink(&outside_file, &link).unwrap();
let result = resolve_symlink_safely(&link, Some(project_dir.path()));
assert!(matches!(result, Err(TldrError::PathTraversal(_))));
}
#[test]
fn test_is_symlink_regular_file() {
let file = NamedTempFile::new().unwrap();
assert!(!is_symlink(file.path()));
}
#[test]
fn test_is_path_within_project_valid() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("src/main.rs");
fs::create_dir_all(file_path.parent().unwrap()).unwrap();
fs::write(&file_path, "fn main() {}").unwrap();
assert!(is_path_within_project(&file_path, dir.path()));
}
#[test]
fn test_contains_path_traversal() {
assert!(contains_path_traversal(Path::new("../outside")));
assert!(contains_path_traversal(Path::new("dir/../other")));
assert!(!contains_path_traversal(Path::new("dir/subdir/file")));
}
#[test]
fn test_skip_directories() {
let dirs = skip_directories();
assert!(dirs.contains("node_modules"));
assert!(dirs.contains(".git"));
assert!(dirs.contains("__pycache__"));
assert!(!dirs.contains("src"));
}
// --- walk_source_files ---
#[test]
fn test_walk_source_files_single_file() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("main.py");
fs::write(&file_path, "def main(): pass").unwrap();
let options = WalkOptions::default();
let (files, warnings) = walk_source_files(&file_path, &options).unwrap();
assert_eq!(
files.len(),
1,
"Single file should return vec with one entry"
);
assert_eq!(files[0], file_path);
assert!(warnings.is_empty(), "No warnings for single file");
}
#[test]
fn test_walk_source_files_directory_returns_source_files() {
let dir = tempdir().unwrap();
fs::write(dir.path().join("main.py"), "def main(): pass").unwrap();
fs::write(dir.path().join("lib.rs"), "fn main() {}").unwrap();
fs::write(dir.path().join("app.js"), "function app() {}").unwrap();
let options = WalkOptions::default();
let (files, _warnings) = walk_source_files(dir.path(), &options).unwrap();
assert!(
files.len() >= 3,
"Directory walk should find at least 3 source files, found {}",
files.len()
);
}
#[test]
fn test_walk_source_files_language_filter() {
let dir = tempdir().unwrap();
fs::write(dir.path().join("main.py"), "def main(): pass").unwrap();
fs::write(dir.path().join("lib.rs"), "fn main() {}").unwrap();
fs::write(dir.path().join("app.js"), "function app() {}").unwrap();
let options = WalkOptions {
lang: Some(Language::Python),
..WalkOptions::default()
};
let (files, _warnings) = walk_source_files(dir.path(), &options).unwrap();
assert_eq!(
files.len(),
1,
"Language filter should return only Python files"
);
assert!(
files[0].extension().unwrap() == "py",
"Filtered file should be .py"
);
}
#[test]
fn test_walk_source_files_empty_directory() {
let dir = tempdir().unwrap();
let options = WalkOptions::default();
let (files, _warnings) = walk_source_files(dir.path(), &options).unwrap();
assert!(files.is_empty(), "Empty directory should return empty vec");
}
#[test]
fn test_walk_source_files_skips_non_source_files() {
let dir = tempdir().unwrap();
fs::write(dir.path().join("readme.md"), "# README").unwrap();
fs::write(dir.path().join("notes.txt"), "some notes").unwrap();
fs::write(dir.path().join("Cargo.lock"), "lock file").unwrap();
fs::write(dir.path().join("main.py"), "def main(): pass").unwrap();
let options = WalkOptions::default();
let (files, _warnings) = walk_source_files(dir.path(), &options).unwrap();
assert_eq!(
files.len(),
1,
"Should only return source files, not .md/.txt/.lock. Found: {:?}",
files
);
}
#[test]
fn test_walk_source_files_respects_gitignore() {
// NOTE(review): `.log` is presumably not a language recognised by
// `Language::from_path`, so `debug.log` would be filtered out even if
// gitignore handling were broken — consider ignoring a `.py` file
// instead so the assertion actually depends on `.gitignore`.
let dir = tempdir().unwrap();
fs::write(dir.path().join(".gitignore"), "*.log\n").unwrap();
fs::write(dir.path().join("main.py"), "def main(): pass").unwrap();
fs::write(dir.path().join("debug.log"), "log data").unwrap();
// Best effort: a failure to run `git init` is deliberately ignored.
std::process::Command::new("git")
.args(["init"])
.current_dir(dir.path())
.output()
.ok();
let options = WalkOptions {
gitignore: true,
..WalkOptions::default()
};
let (files, _warnings) = walk_source_files(dir.path(), &options).unwrap();
assert!(!files.is_empty(), "Should find at least the .py file");
for f in &files {
assert_ne!(
f.extension().and_then(|e| e.to_str()),
Some("log"),
"Should not include gitignored files"
);
}
}
#[test]
fn test_walk_source_files_nonexistent_path() {
let options = WalkOptions::default();
let result = walk_source_files(Path::new("/nonexistent/path/xyz"), &options);
assert!(result.is_err(), "Nonexistent path should return error");
}
// --- should_exclude ---
#[test]
fn test_should_exclude_matching_pattern() {
let patterns = vec!["*.test.py".to_string()];
assert!(should_exclude(Path::new("test_foo.test.py"), &patterns));
}
#[test]
fn test_should_exclude_no_match() {
let patterns = vec!["*.test.py".to_string()];
assert!(!should_exclude(Path::new("main.py"), &patterns));
}
#[test]
fn test_should_exclude_empty_patterns() {
let patterns: Vec<String> = vec![];
assert!(!should_exclude(Path::new("anything.py"), &patterns));
}
}