use std::io;
use std::path::{Path, PathBuf};
use clap::{ArgGroup, Parser};
use rusty_ast::{JsonVisitor, TextVisitor, parse_rust_file, parse_rust_source};
use syn::visit::Visit;
use walkdir::WalkDir;
// Command-line arguments for the AST dumper.
//
// Exactly one input source (`--file`, `--code` or `--directory`) must be
// supplied. clap enforces this through the required, mutually exclusive
// `input` argument group.
// (Plain `//` comment on purpose: `///` on clap-derived items is turned into
// `--help` text, and the program's `about` is set explicitly below.)
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
#[command(group(ArgGroup::new("input").required(true).args(["file", "code", "directory"])))]
struct Cli {
    /// Path of a single Rust source file to parse
    #[arg(short, long, value_name = "FILE")]
    file: Option<PathBuf>,
    /// Rust source code to parse, passed inline as a string
    #[arg(short, long, value_name = "CODE")]
    code: Option<String>,
    /// Directory whose `.rs` files are parsed
    #[arg(short = 'd', long, value_name = "DIRECTORY")]
    directory: Option<PathBuf>,
    /// Output format for the printed AST
    #[arg(short = 'o', long, value_enum, default_value = "text")]
    format: OutputFormat,
    /// Also descend into subdirectories (only has an effect with --directory)
    #[arg(short = 'r', long)]
    recursive: bool,
}
/// Output formats accepted by `--format`.
///
/// The variant doc comments are shown by clap as help for each possible value.
#[derive(clap::ValueEnum, Clone, Copy, Debug, PartialEq, Eq)]
enum OutputFormat {
    /// Plain-text rendering produced by `TextVisitor`
    Text,
    /// JSON string produced by `JsonVisitor::to_json`
    Json,
}
/// Parses every `.rs` file found in `directory` and prints its AST in `format`.
///
/// Only the direct children of `directory` are scanned unless `recursive` is
/// `true`. Entries are visited in file-name order so the output is stable
/// across runs and platforms.
///
/// Unreadable directory entries (including a nonexistent `directory`) and
/// files that fail to parse are reported on stderr and skipped. They do not
/// abort the walk. A summary line is printed at the end.
///
/// # Errors
///
/// Currently always returns `Ok(())`. Per-entry failures are reported on
/// stderr rather than propagated.
fn process_directory(directory: &Path, format: &OutputFormat, recursive: bool) -> io::Result<()> {
    let mut walker = WalkDir::new(directory).sort_by(|a, b| a.file_name().cmp(b.file_name()));
    if !recursive {
        // Depth 0 is `directory` itself; depth 1 is its direct children.
        walker = walker.max_depth(1);
    }

    // `found_files` counts every `.rs` file seen, while `processed_files`
    // counts only those that parsed. The two must be distinguished so that
    // "all files failed to parse" is not reported as "no files found".
    let mut found_files = 0usize;
    let mut processed_files = 0usize;

    for entry in walker {
        let entry = match entry {
            Ok(entry) => entry,
            Err(e) => {
                // Report instead of silently dropping, so that permission
                // problems or a missing directory are visible to the user.
                eprintln!("Error reading directory entry: {}", e);
                continue;
            }
        };
        let path = entry.path();
        // `Path::is_file` follows symlinks, so links to `.rs` files count too.
        let is_rust_file = path.is_file() && matches!(path.extension(), Some(ext) if ext == "rs");
        if !is_rust_file {
            continue;
        }

        found_files += 1;
        println!("\n--- Processing file: {} ---", path.display());
        match parse_rust_file(path) {
            Ok(ast) => {
                processed_files += 1;
                match format {
                    OutputFormat::Text => {
                        println!("AST for Rust code in {}:", path.display());
                        let mut visitor = TextVisitor::new();
                        visitor.visit_file(&ast);
                    }
                    OutputFormat::Json => {
                        let mut visitor = JsonVisitor::new();
                        visitor.visit_file(&ast);
                        println!("{}", visitor.to_json());
                    }
                }
            }
            Err(e) => eprintln!("Error parsing file {}: {}", path.display(), e),
        }
    }

    if found_files == 0 {
        println!("No Rust files found in the specified directory.");
    } else if processed_files == found_files {
        println!("\nProcessed {} Rust files.", processed_files);
    } else {
        println!(
            "\nProcessed {} Rust files ({} failed to parse).",
            processed_files,
            found_files - processed_files
        );
    }
    Ok(())
}
fn main() -> io::Result<()> {
let cli = Cli::parse();
if let Some(directory) = cli.directory {
process_directory(&directory, &cli.format, cli.recursive)?;
return Ok(());
}
let ast = if let Some(file_path) = cli.file {
parse_rust_file(file_path)?
} else if let Some(code) = cli.code {
parse_rust_source(&code).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
} else {
unreachable!("clap should require one of the arguments");
};
match cli.format {
OutputFormat::Text => {
println!("AST for Rust code:");
let mut visitor = TextVisitor::new();
visitor.visit_file(&ast);
}
OutputFormat::Json => {
let mut visitor = JsonVisitor::new();
visitor.visit_file(&ast);
println!("{}", visitor.to_json());
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Writes `contents` to `dir/name`, panicking on any I/O failure.
    fn write_source(dir: &Path, name: &str, contents: &[u8]) {
        fs::write(dir.join(name), contents).unwrap();
    }

    #[test]
    fn test_directory_processing() {
        let root = TempDir::new().unwrap();
        write_source(
            root.path(),
            "test.rs",
            b"fn test() { println!(\"Hello\"); }",
        );

        // Non-recursive pass; only the top-level file exists at this point.
        process_directory(root.path(), &OutputFormat::Text, false).unwrap();

        let nested = root.path().join("nested");
        fs::create_dir(&nested).unwrap();
        write_source(&nested, "nested_test.rs", b"fn nested_test() { return 42; }");

        // Recursive pass; the walk now also descends into `nested`.
        process_directory(root.path(), &OutputFormat::Text, true).unwrap();
    }
}