use std::fs;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use clap::{Parser, Subcommand, ValueEnum};
use ptx_parser::pretty_print::{TreeDisplay, TreeFormatter, print_compact_module};
use ptx_parser::{parse_ptx, run_with_large_stack};
#[derive(Parser)]
#[command(name = "ptx-parser", about = "Utilities for parsing PTX assembly")]
struct Cli {
#[command(subcommand)]
command: Command,
}
#[derive(Subcommand)]
enum Command {
ParseFile {
input_file: PathBuf,
},
PrintAst {
input_file: PathBuf,
#[arg(long)]
compact: bool,
},
OutputAst {
input_file: PathBuf,
#[arg(long, value_enum, default_value_t = AstOutputFormat::Json)]
format: AstOutputFormat,
#[arg(long)]
output: Option<PathBuf>,
},
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, ValueEnum)]
enum AstOutputFormat {
Json,
Tree,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let cli = Cli::parse();
match cli.command {
Command::ParseFile { input_file } => parse_file(&input_file)?,
Command::PrintAst {
input_file,
compact,
} => print_ast(&input_file, compact)?,
Command::OutputAst {
input_file,
format,
output,
} => output_ast(&input_file, format, output.as_deref())?,
}
Ok(())
}
fn parse_file(path: &Path) -> Result<(), Box<dyn std::error::Error>> {
let source = fs::read_to_string(path)?;
let module = parse_with_large_stack(source)?;
println!(
"✓ {}: Successfully parsed PTX module with {} directives",
path.display(),
module.directives.len()
);
Ok(())
}
fn output_ast(
path: &Path,
format: AstOutputFormat,
output: Option<&Path>,
) -> Result<(), Box<dyn std::error::Error>> {
let source = fs::read_to_string(path)?;
let module = parse_with_large_stack(source.clone())?;
match (format, output) {
(AstOutputFormat::Json, Some(out_path)) => {
let file = fs::File::create(out_path)?;
serde_json::to_writer_pretty(io::BufWriter::new(file), &module)?;
}
(AstOutputFormat::Json, None) => {
let mut stdout = io::stdout().lock();
serde_json::to_writer_pretty(&mut stdout, &module)?;
stdout.write_all(b"\n")?;
}
(AstOutputFormat::Tree, Some(out_path)) => {
let mut file = fs::File::create(out_path)?;
let mut formatter = TreeFormatter::new();
module
.tree_display(&mut formatter, &source)
.expect("Failed to display tree");
writeln!(file, "{}", formatter.finish())?;
}
(AstOutputFormat::Tree, None) => {
print_ast_with_large_stack(module, source)?;
}
}
Ok(())
}
fn print_ast(path: &Path, compact: bool) -> Result<(), Box<dyn std::error::Error>> {
let source = fs::read_to_string(path)?;
let module = parse_with_large_stack(source.clone())?;
if compact {
print_compact_module(&module);
} else {
print_ast_with_large_stack(module, source)?;
}
Ok(())
}
fn print_ast_with_large_stack(
module: ptx_parser::r#type::Module,
source: String,
) -> Result<(), Box<dyn std::error::Error>> {
run_with_large_stack(move || {
let mut formatter = TreeFormatter::new();
module
.tree_display(&mut formatter, &source)
.expect("Failed to display tree");
println!("{}", formatter.finish());
});
Ok(())
}
fn parse_with_large_stack(
source: String,
) -> Result<ptx_parser::r#type::Module, ptx_parser::PtxParseError> {
run_with_large_stack(move || parse_ptx(&source))
}