use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use ignore::WalkBuilder;
use rayon::prelude::*;
use tracing::{debug, warn};
use crate::ast::extractor::AstExtractor;
use crate::ast::types::{ClassSummary, CodeStructure, FunctionSummary, ModuleInfo};
use crate::error::{validate_path_containment, Result, BrrrError};
use crate::lang::LanguageRegistry;
pub fn code_structure(
path: &str,
lang_filter: Option<&str>,
max_results: usize,
no_ignore: bool,
) -> Result<CodeStructure> {
let input_path = Path::new(path);
let registry = LanguageRegistry::global();
if !input_path.exists() {
return Err(BrrrError::Io(std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("Path not found: {}", path),
)));
}
let root_path = input_path.canonicalize()?;
let mut walker_builder = WalkBuilder::new(&root_path);
if no_ignore {
walker_builder
.git_ignore(false)
.git_global(false)
.git_exclude(false)
.ignore(false);
} else {
walker_builder.add_custom_ignore_filename(".brrrignore");
}
walker_builder.hidden(true);
let resolved_lang_name: Option<&str> = lang_filter.map(|name| {
registry
.get_by_name(name)
.map(|lang| lang.name())
.unwrap_or(name) });
let files: Vec<_> = walker_builder
.build()
.filter_map(|entry| entry.ok())
.filter(|e| e.path().is_file())
.filter(|e| {
if let Some(target_name) = resolved_lang_name {
registry
.detect_language(e.path())
.is_some_and(|l| l.name() == target_name)
} else {
registry.detect_language(e.path()).is_some()
}
})
.collect();
let file_count = files.len();
debug!("Found {} source files to analyze", file_count);
let functions_found = AtomicUsize::new(0);
let classes_found = AtomicUsize::new(0);
let files_failed_counter = AtomicUsize::new(0);
let files_skipped_counter = AtomicUsize::new(0);
let early_termination_threshold = if max_results > 0 {
max_results.saturating_mul(2)
} else {
usize::MAX
};
let mut results: Vec<(String, ModuleInfo)> = files
.par_iter()
.filter_map(|entry| {
if max_results > 0 {
let funcs = functions_found.load(Ordering::Relaxed);
let cls = classes_found.load(Ordering::Relaxed);
if funcs >= early_termination_threshold && cls >= early_termination_threshold {
files_skipped_counter.fetch_add(1, Ordering::Relaxed);
return None;
}
}
let file_path = entry.path();
if let Err(e) = validate_path_containment(&root_path, file_path) {
match &e {
BrrrError::PathTraversal { target, base } => {
warn!(
file = %file_path.display(),
target = %target,
base = %base,
"Skipping file due to path traversal detection (security)"
);
}
BrrrError::Io(_) => {
debug!("Skipping unresolvable path: {}", file_path.display());
}
_ => {
warn!(
file = %file_path.display(),
error = %e,
"Skipping file due to path validation error"
);
}
}
files_skipped_counter.fetch_add(1, Ordering::Relaxed);
return None;
}
match AstExtractor::extract_file(file_path) {
Ok(module) => {
functions_found.fetch_add(module.functions.len(), Ordering::Relaxed);
classes_found.fetch_add(module.classes.len(), Ordering::Relaxed);
let rel_path = file_path
.strip_prefix(&root_path)
.unwrap_or(file_path)
.display()
.to_string();
Some((rel_path, module))
}
Err(e) => {
warn!(
file = %file_path.display(),
error = %e,
"Failed to extract AST from file"
);
files_failed_counter.fetch_add(1, Ordering::Relaxed);
None
}
}
})
.collect();
results.sort_by(|a, b| a.0.cmp(&b.0));
let files_processed = results.len();
let files_failed = files_failed_counter.load(Ordering::Relaxed);
let files_skipped = files_skipped_counter.load(Ordering::Relaxed);
if files_failed > 0 || files_skipped > 0 {
debug!(
"Extracted AST from {} files ({} failed, {} skipped) out of {} total",
files_processed, files_failed, files_skipped, file_count
);
} else {
debug!(
"Successfully extracted AST from all {} files",
files_processed
);
}
let mut functions = Vec::new();
let mut classes = Vec::new();
for (rel_path, module) in results {
for func in module.functions {
functions.push(FunctionSummary {
name: func.name.clone(),
file: rel_path.clone(),
line: func.line_number,
signature: func.signature(),
});
}
for class in module.classes {
classes.push(ClassSummary {
name: class.name,
file: rel_path.clone(),
line: class.line_number,
method_count: class.methods.len(),
});
}
}
functions.sort_by(|a, b| (&a.file, a.line).cmp(&(&b.file, b.line)));
classes.sort_by(|a, b| (&a.file, a.line).cmp(&(&b.file, b.line)));
if max_results > 0 {
functions.truncate(max_results);
classes.truncate(max_results);
}
Ok(CodeStructure {
path: path.to_string(),
functions,
classes,
files_processed,
files_failed,
files_skipped,
total_files: file_count,
})
}
/// Extract the AST summary of a single file after basic path-safety checks.
///
/// * `file_path` - File to parse. It is rejected with [`BrrrError::PathTraversal`] if it
///   contains a NUL byte. The error's `base` is `base_path`, or
///   `"<single file extraction>"` when there is none.
/// * `base_path` - When present, `file_path` must pass [`validate_path_containment`]
///   against it, and that check's error is returned on failure. When absent, the path's
///   components are walked with a running depth (+1 per normal component, −1 per `..`).
///   The path is rejected with [`BrrrError::PathTraversal`] once that depth drops below
///   −10.
///
/// # Errors
/// The path-safety errors above, or any error from [`AstExtractor::extract_file`].
pub fn extract_file(file_path: &str, base_path: Option<&str>) -> Result<ModuleInfo> {
    use std::path::Component;

    const NO_BASE_LABEL: &str = "<single file extraction>";
    const MAX_PARENT_DEPTH: i32 = 10;

    if file_path.contains('\0') {
        let reported_base = base_path.unwrap_or(NO_BASE_LABEL);
        return Err(BrrrError::PathTraversal {
            target: "<contains null byte>".to_string(),
            base: reported_base.to_string(),
        });
    }

    let path = Path::new(file_path);
    match base_path {
        Some(base) => {
            validate_path_containment(Path::new(base), path)?;
        }
        None => {
            // Depth can only newly fall below the limit on a `..`, so checking after every
            // component rejects exactly the same paths as checking after `..` alone.
            let mut depth: i32 = 0;
            for component in path.components() {
                depth += match component {
                    Component::ParentDir => -1,
                    Component::Normal(_) => 1,
                    _ => 0,
                };
                if depth < -MAX_PARENT_DEPTH {
                    return Err(BrrrError::PathTraversal {
                        target: file_path.to_string(),
                        base: NO_BASE_LABEL.to_string(),
                    });
                }
            }
        }
    }

    AstExtractor::extract_file(path)
}
/// Extract a single file without a containment base directory.
///
/// Despite the name, this is not entirely unchecked. It calls [`extract_file`] with no
/// base path, so the NUL-byte and `..`-depth checks still apply; only base-directory
/// containment is skipped.
#[inline]
// `dead_code` is allowed because this convenience wrapper may have no in-crate callers.
#[allow(dead_code)]
pub fn extract_file_unchecked(file_path: &str) -> Result<ModuleInfo> {
    let no_base: Option<&str> = None;
    extract_file(file_path, no_base)
}
// Unit tests for `code_structure`, `extract_file` and `extract_file_unchecked`, run
// against small projects written into temporary directories.
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::TempDir;
// Fixture: a temp dir containing main.py (defines `hello` and `Greeter`), utils.py,
// lib/core.py and app.ts (defines `greet` and `Service`). The `TempDir` is returned so the
// caller keeps the directory alive for the duration of the test.
fn create_test_project() -> TempDir {
let dir = TempDir::new().expect("Failed to create temp dir");
let py_content = r#"
def hello(name: str) -> str:
"""Say hello to someone."""
return f"Hello, {name}!"
class Greeter:
"""A greeting class."""
def greet(self, name: str) -> str:
return hello(name)
"#;
fs::write(dir.path().join("main.py"), py_content).unwrap();
fs::write(dir.path().join("utils.py"), "def helper(): pass").unwrap();
let sub_dir = dir.path().join("lib");
fs::create_dir(&sub_dir).unwrap();
fs::write(sub_dir.join("core.py"), "def core_func(): pass").unwrap();
let ts_content = r#"
function greet(name: string): string {
return "Hello, " + name;
}
class Service {
process(): void {}
}
"#;
fs::write(dir.path().join("app.ts"), ts_content).unwrap();
dir
}
// With no language filter, both Python and TypeScript files are discovered and yield
// functions and classes.
#[test]
fn test_code_structure_all_languages() {
let dir = create_test_project();
let result = code_structure(dir.path().to_str().unwrap(), None, 100, false).unwrap();
assert!(result.total_files >= 3);
assert!(!result.functions.is_empty());
assert!(!result.classes.is_empty());
}
// The "python" filter keeps only .py files; every reported function and class must come
// from one.
#[test]
fn test_code_structure_python_only() {
let dir = create_test_project();
let result = code_structure(dir.path().to_str().unwrap(), Some("python"), 100, false).unwrap();
assert!(result.total_files >= 3);
for func in &result.functions {
assert!(
func.file.ends_with(".py"),
"Expected .py file: {}",
func.file
);
}
for class in &result.classes {
assert!(
class.file.ends_with(".py"),
"Expected .py file: {}",
class.file
);
}
}
// The "typescript" filter keeps only .ts sources.
#[test]
fn test_code_structure_typescript_only() {
let dir = create_test_project();
let result = code_structure(dir.path().to_str().unwrap(), Some("typescript"), 100, false).unwrap();
assert!(result.total_files >= 1);
for func in &result.functions {
assert!(
func.file.ends_with(".ts"),
"Expected .ts file: {}",
func.file
);
}
}
// A non-zero `max_results` caps the function and class lists independently.
#[test]
fn test_code_structure_max_results() {
let dir = create_test_project();
let result = code_structure(dir.path().to_str().unwrap(), None, 2, false).unwrap();
assert!(result.functions.len() <= 2);
assert!(result.classes.len() <= 2);
}
// `max_results == 0` disables truncation.
#[test]
fn test_code_structure_unlimited_results() {
let dir = create_test_project();
let result = code_structure(dir.path().to_str().unwrap(), None, 0, false).unwrap();
assert!(result.functions.len() > 2);
}
// Reported file paths are relative to the analyzed root. NOTE(review): the "tmp" check
// is a heuristic that assumes the temp directory's absolute path contains "tmp".
#[test]
fn test_code_structure_relative_paths() {
let dir = create_test_project();
let result = code_structure(dir.path().to_str().unwrap(), Some("python"), 100, false).unwrap();
for func in &result.functions {
assert!(
!func.file.starts_with('/') && !func.file.contains("tmp"),
"Expected relative path, got: {}",
func.file
);
}
}
// Files in subdirectories (lib/core.py) are walked and analyzed.
#[test]
fn test_code_structure_nested_directory() {
let dir = create_test_project();
let result = code_structure(dir.path().to_str().unwrap(), Some("python"), 100, false).unwrap();
let has_nested = result.functions.iter().any(|f| f.file.contains("lib"));
assert!(has_nested, "Should find functions in nested directories");
}
// A missing root path is reported as an error rather than an empty result.
#[test]
fn test_code_structure_nonexistent_path() {
let result = code_structure("/nonexistent/path/that/does/not/exist", None, 100, false);
assert!(result.is_err());
}
// An empty directory yields zero files and empty summaries.
#[test]
fn test_code_structure_empty_directory() {
let dir = TempDir::new().unwrap();
let result = code_structure(dir.path().to_str().unwrap(), None, 100, false).unwrap();
assert_eq!(result.total_files, 0);
assert!(result.functions.is_empty());
assert!(result.classes.is_empty());
}
// Single-file extraction of main.py: language detection plus the `hello` return type.
#[test]
fn test_extract_file_python() {
let dir = create_test_project();
let file_path = dir.path().join("main.py");
let result = extract_file(file_path.to_str().unwrap(), None).unwrap();
assert_eq!(result.language, "python");
assert!(!result.functions.is_empty());
assert!(!result.classes.is_empty());
let hello = result.functions.iter().find(|f| f.name == "hello");
assert!(hello.is_some());
let hello = hello.unwrap();
assert_eq!(hello.return_type, Some("str".to_string()));
}
// A file inside the given base path passes the containment check.
#[test]
fn test_extract_file_with_base_path() {
let dir = create_test_project();
let file_path = dir.path().join("main.py");
let base_path = dir.path();
let result = extract_file(
file_path.to_str().unwrap(),
Some(base_path.to_str().unwrap()),
);
assert!(result.is_ok());
}
// main.py lies outside the restricted base `lib/`, so extraction must fail with
// `PathTraversal`.
#[test]
fn test_extract_file_base_path_escape_rejected() {
let dir = create_test_project();
let sub_dir = dir.path().join("lib");
let file_path = dir.path().join("main.py");
let restricted_base = sub_dir.to_str().unwrap();
let result = extract_file(file_path.to_str().unwrap(), Some(restricted_base));
assert!(result.is_err());
match result.unwrap_err() {
BrrrError::PathTraversal { .. } => {}
e => panic!("Expected PathTraversal error, got: {:?}", e),
}
}
// A missing file is an error.
#[test]
fn test_extract_file_nonexistent() {
let result = extract_file("/nonexistent/file.py", None);
assert!(result.is_err());
}
// Every function summary carries a non-empty signature; Python signatures contain "def".
#[test]
fn test_function_summary_has_signature() {
let dir = create_test_project();
let result = code_structure(dir.path().to_str().unwrap(), Some("python"), 100, false).unwrap();
for func in &result.functions {
assert!(!func.signature.is_empty());
if func.file.ends_with(".py") {
assert!(
func.signature.contains("def"),
"Python sig: {}",
func.signature
);
}
}
}
// `Greeter` is reported with at least one method.
#[test]
fn test_class_summary_has_method_count() {
let dir = create_test_project();
let result = code_structure(dir.path().to_str().unwrap(), Some("python"), 100, false).unwrap();
let greeter = result.classes.iter().find(|c| c.name == "Greeter");
assert!(greeter.is_some());
assert!(greeter.unwrap().method_count >= 1);
}
// A NUL byte in the path is rejected as `PathTraversal` before any filesystem access.
#[test]
fn test_extract_file_rejects_null_bytes() {
let result = extract_file("/tmp/test\0.py", None);
assert!(result.is_err());
match result.unwrap_err() {
BrrrError::PathTraversal { target, .. } => {
assert!(target.contains("null byte"), "Should mention null byte");
}
_ => panic!("Expected PathTraversal error"),
}
}
// Same as above, but the supplied base path must be echoed in the error.
#[test]
fn test_extract_file_rejects_null_bytes_with_base_path() {
let result = extract_file("/tmp/test\0.py", Some("/tmp"));
assert!(result.is_err());
match result.unwrap_err() {
BrrrError::PathTraversal { target, base } => {
assert!(target.contains("null byte"), "Should mention null byte");
assert_eq!(base, "/tmp", "Should include base path in error");
}
_ => panic!("Expected PathTraversal error"),
}
}
// A relative path with more than 10 leading `..` components must not succeed; either the
// depth check (`PathTraversal`) or an I/O failure is accepted.
#[test]
fn test_extract_file_rejects_excessive_traversal() {
let malicious_path = "../../../../../../../../../../../../../../../etc/passwd";
let result = extract_file(malicious_path, None);
assert!(result.is_err());
match result.unwrap_err() {
BrrrError::PathTraversal { .. } => {}
BrrrError::Io(_) => {} e => panic!("Expected PathTraversal or Io error, got: {:?}", e),
}
}
// The "unchecked" wrapper extracts a normal file just like `extract_file(.., None)`.
#[test]
fn test_extract_file_unchecked() {
let dir = create_test_project();
let file_path = dir.path().join("main.py");
let result = extract_file_unchecked(file_path.to_str().unwrap()).unwrap();
assert_eq!(result.language, "python");
assert!(!result.functions.is_empty());
}
// A symlink named escape.py pointing at /etc/passwd (outside the root) must not be
// analyzed. `let _ =` tolerates symlink creation failing.
#[cfg(unix)]
#[test]
fn test_code_structure_excludes_symlinks_outside_root() {
use std::os::unix::fs::symlink;
let dir = TempDir::new().expect("Failed to create temp dir");
let py_content = "def safe_func(): pass";
fs::write(dir.path().join("safe.py"), py_content).unwrap();
let _ = symlink("/etc/passwd", dir.path().join("escape.py"));
let result = code_structure(dir.path().to_str().unwrap(), Some("python"), 100, false).unwrap();
let has_safe = result.functions.iter().any(|f| f.name == "safe_func");
assert!(has_safe, "Should find safe_func");
let has_escape = result.functions.iter().any(|f| f.file.contains("escape"));
assert!(!has_escape, "Should NOT have analyzed escape.py symlink");
}
// A symlink whose target stays inside the root does not prevent the real file from
// being analyzed.
#[cfg(unix)]
#[test]
fn test_code_structure_allows_symlinks_inside_root() {
use std::os::unix::fs::symlink;
let dir = TempDir::new().expect("Failed to create temp dir");
let sub_dir = dir.path().join("lib");
fs::create_dir(&sub_dir).unwrap();
fs::write(sub_dir.join("utils.py"), "def util_func(): pass").unwrap();
let _ = symlink(sub_dir.join("utils.py"), dir.path().join("link_to_utils.py"));
let result = code_structure(dir.path().to_str().unwrap(), Some("python"), 100, false).unwrap();
let has_util = result.functions.iter().any(|f| f.name == "util_func");
assert!(has_util, "Should find util_func");
}
// Five runs over the same project must yield identical (file, name, line) sequences,
// despite parallel extraction.
#[test]
fn test_code_structure_deterministic_output() {
let dir = create_test_project();
let path = dir.path().to_str().unwrap();
let results: Vec<_> = (0..5)
.map(|_| code_structure(path, None, 100, false).unwrap())
.collect();
let first = &results[0];
for (i, result) in results.iter().enumerate().skip(1) {
assert_eq!(
first.functions.len(),
result.functions.len(),
"Function count mismatch on iteration {i}"
);
for (j, (f1, f2)) in first.functions.iter().zip(&result.functions).enumerate() {
assert_eq!(
(f1.file.as_str(), &f1.name, f1.line),
(f2.file.as_str(), &f2.name, f2.line),
"Function mismatch at index {j} on iteration {i}"
);
}
assert_eq!(
first.classes.len(),
result.classes.len(),
"Class count mismatch on iteration {i}"
);
for (j, (c1, c2)) in first.classes.iter().zip(&result.classes).enumerate() {
assert_eq!(
(c1.file.as_str(), &c1.name, c1.line),
(c2.file.as_str(), &c2.name, c2.line),
"Class mismatch at index {j} on iteration {i}"
);
}
}
}
// Both summary lists are sorted by (file, line), ascending.
#[test]
fn test_code_structure_sorted_by_file_and_line() {
let dir = create_test_project();
let result = code_structure(dir.path().to_str().unwrap(), None, 100, false).unwrap();
for window in result.functions.windows(2) {
let cmp = (&window[0].file, window[0].line).cmp(&(&window[1].file, window[1].line));
assert!(
cmp != std::cmp::Ordering::Greater,
"Functions not sorted: {:?} should come before {:?}",
(&window[0].file, window[0].line),
(&window[1].file, window[1].line)
);
}
for window in result.classes.windows(2) {
let cmp = (&window[0].file, window[0].line).cmp(&(&window[1].file, window[1].line));
assert!(
cmp != std::cmp::Ordering::Greater,
"Classes not sorted: {:?} should come before {:?}",
(&window[0].file, window[0].line),
(&window[1].file, window[1].line)
);
}
}
// Accounting invariant: total = processed + failed + skipped. On the fixture, nothing
// should fail or be skipped.
#[test]
fn test_file_count_invariant() {
let dir = create_test_project();
let result = code_structure(dir.path().to_str().unwrap(), None, 100, false).unwrap();
assert_eq!(
result.total_files,
result.files_processed + result.files_failed + result.files_skipped,
"Invariant violated: total_files ({}) != processed ({}) + failed ({}) + skipped ({})",
result.total_files,
result.files_processed,
result.files_failed,
result.files_skipped
);
assert_eq!(
result.files_failed, 0,
"Valid Python/TypeScript files should not have parse failures"
);
assert_eq!(
result.files_skipped, 0,
"Without max_results limit, no files should be skipped"
);
}
// A syntactically broken file still counts toward the invariant; the valid file is
// processed either way.
#[test]
fn test_file_count_with_invalid_syntax() {
let dir = TempDir::new().unwrap();
fs::write(dir.path().join("valid.py"), "def good(): pass").unwrap();
fs::write(dir.path().join("invalid.py"), "def bad( missing colon").unwrap();
let result = code_structure(dir.path().to_str().unwrap(), Some("python"), 100, false).unwrap();
assert_eq!(result.total_files, 2, "Should find 2 Python files");
assert_eq!(
result.total_files,
result.files_processed + result.files_failed + result.files_skipped,
"Invariant violated with invalid syntax file"
);
assert!(
result.files_processed >= 1,
"Should successfully parse at least the valid file"
);
}
// `files_processed` is at least the number of distinct files that contributed a
// function or class.
#[test]
fn test_files_processed_reflects_success() {
let dir = create_test_project();
let result = code_structure(dir.path().to_str().unwrap(), Some("python"), 100, false).unwrap();
let unique_files: std::collections::HashSet<&str> = result
.functions
.iter()
.map(|f| f.file.as_str())
.chain(result.classes.iter().map(|c| c.file.as_str()))
.collect();
assert!(
result.files_processed >= unique_files.len(),
"files_processed ({}) should be >= unique files with content ({})",
result.files_processed,
unique_files.len()
);
}
// "javascript" resolves through the registry to the same language as "typescript", so
// both filters select the same file set (.js and .ts).
#[test]
fn test_javascript_alias_resolves_to_typescript() {
let dir = TempDir::new().unwrap();
let js_content = r#"
function greet(name) {
return "Hello, " + name;
}
const helper = () => {
console.log("helper");
};
"#;
fs::write(dir.path().join("app.js"), js_content).unwrap();
let ts_content = r#"
function tsFunc(x: number): number {
return x * 2;
}
"#;
fs::write(dir.path().join("lib.ts"), ts_content).unwrap();
let js_result =
code_structure(dir.path().to_str().unwrap(), Some("javascript"), 100, false).unwrap();
assert!(
js_result.total_files >= 1,
"javascript alias should find .js files, got {} files",
js_result.total_files
);
assert!(
js_result
.functions
.iter()
.any(|f| f.file.ends_with(".js")),
"Should find functions in .js files"
);
let ts_result =
code_structure(dir.path().to_str().unwrap(), Some("typescript"), 100, false).unwrap();
assert_eq!(
js_result.total_files, ts_result.total_files,
"javascript and typescript should find same files (same parser)"
);
}
// A language name the registry does not know matches no files (not an error).
#[test]
fn test_unknown_language_finds_nothing() {
let dir = TempDir::new().unwrap();
fs::write(dir.path().join("app.js"), "function test() {}").unwrap();
fs::write(dir.path().join("main.py"), "def test(): pass").unwrap();
let result =
code_structure(dir.path().to_str().unwrap(), Some("brainfuck"), 100, false).unwrap();
assert_eq!(
result.total_files, 0,
"Unknown language should find no files"
);
}
}