//! ptx-90-parser 0.1.0
//!
//! Parse NVIDIA PTX 9.0 assembly into a structured AST and explore modules via a CLI.
//!
//! Documentation
mod format;

use std::path::PathBuf;

use clap::Parser;
use format::module_to_tree;
use ptx_parser::{
    parse, FunctionBody, FunctionKernelDirective, FunctionStatement, ModuleDirective,
};
// Command-line arguments for the PTX summary tool.
//
// The `///` comments on the fields are kept verbatim because clap's derive
// turns them into the `--help` text; this header deliberately uses plain `//`
// so it does not feed into the generated help output.
#[derive(Parser, Debug)]
#[command(
    name = "ptx-parser-bin",
    about = "Parse PTX files and emit a summary"
)]
struct Cli {
    /// Path to the PTX source file to parse.
    #[arg(short = 'i', long = "input")]
    input: PathBuf,
    /// Optional path to write the summary; stdout when omitted.
    #[arg(short = 'o', long = "output")]
    output: Option<PathBuf>,
}

/// Entry point of the CLI.
///
/// Reads the PTX file named by `--input`, parses it, prints a summary of the
/// module and of each entry/func directive to stdout, and, when `--output` is
/// given, writes the tree representation of the module to that path.
///
/// A failure to read, parse, or write is reported on stderr and ends the
/// process with exit status 1.
fn main() {
    let cli = Cli::parse();

    let source = std::fs::read_to_string(&cli.input).unwrap_or_else(|err| {
        eprintln!("Failed to read {}: {err}", cli.input.display());
        std::process::exit(1)
    });

    let module = parse(&source).unwrap_or_else(|err| {
        eprintln!("Failed to parse {}: {err}", cli.input.display());
        std::process::exit(1)
    });

    // Only `Entry` and `Func` kernel directives are counted as functions.
    let function_count = module
        .directives
        .iter()
        .filter(|item| {
            matches!(
                item,
                ModuleDirective::FunctionKernel(
                    FunctionKernelDirective::Entry(_) | FunctionKernelDirective::Func(_)
                )
            )
        })
        .count();
    // NOTE(review): the function count is printed here and again in the
    // summary below; both lines are kept so the output stays unchanged.
    println!("functions {}", function_count);

    println!("Parsed {}", cli.input.display());
    println!("  directives: {}", module.directives.len());
    println!("  functions : {}", function_count);

    // One summary line per entry/func directive, in module order; every other
    // directive kind is skipped.
    for directive in &module.directives {
        if let ModuleDirective::FunctionKernel(FunctionKernelDirective::Entry(entry)) = directive {
            log_function_summary(entry.name.as_str(), &entry.body);
        } else if let ModuleDirective::FunctionKernel(FunctionKernelDirective::Func(func)) =
            directive
        {
            log_function_summary(func.name.as_str(), &func.body);
        }
    }

    // The tree dump is only produced when an output path was supplied.
    if let Some(destination) = cli.output.as_deref() {
        let tree = module_to_tree(&module);
        std::fs::write(destination, tree).unwrap_or_else(|err| {
            eprintln!("Failed to write {}: {err}", destination.display());
            std::process::exit(1)
        });
    }
}

/// Prints a one-line summary of a function body to stdout.
///
/// The line reports the number of entry directives attached to `body`, then
/// how many of its statements are directives and how many are instructions.
/// Label statements are not counted in either tally.
fn log_function_summary(name: &str, body: &FunctionBody) {
    let entry_directive_total = body.entry_directives.len();

    let mut directive_statements = 0usize;
    let mut instructions = 0usize;
    for statement in body.statements.iter() {
        match statement {
            FunctionStatement::Directive(_) => directive_statements += 1,
            FunctionStatement::Instruction(_) => instructions += 1,
            // Labels contribute to neither counter.
            FunctionStatement::Label(_) => {}
        }
    }

    println!(
        "    {}: {} entry directives, {} statement directives, {} instructions",
        name, entry_directive_total, directive_statements, instructions
    );
}