apr-cli 0.4.16

CLI tool for APR model inspection, debugging, and operations
//! PTX analysis and explanation command
//!
//! Bridges trueno-explain's PTX analysis into the apr CLI.
//! Runs PtxAnalyzer (register pressure, memory patterns, occupancy, roofline, muda)
//! and PtxBugAnalyzer (15+ bug detectors) on PTX source files or generated kernels.

use crate::error::Result;
use std::path::Path;

/// Run PTX analysis on a file or kernel name.
///
/// The PTX text comes from `file` when one is given. Otherwise it is generated
/// for the named `kernel`; if both are supplied, `file` wins. The result is then
/// rendered as JSON (`json == true`) or as colored human-readable text.
///
/// # Errors
///
/// Returns an error when neither `file` nor `kernel` is provided, when the file
/// cannot be read, or when kernel PTX generation fails.
pub(crate) fn run(
    file: Option<&Path>,
    kernel: Option<&str>,
    strict: bool,
    bugs_only: bool,
    json: bool,
    verbose: bool,
) -> Result<()> {
    let ptx_source = match (file, kernel) {
        (Some(path), _) => std::fs::read_to_string(path).map_err(|io_err| {
            crate::error::CliError::Aprender(format!(
                "Failed to read PTX file '{}': {}",
                path.display(),
                io_err
            ))
        })?,
        (None, Some(kernel_name)) => generate_kernel_ptx(kernel_name)?,
        (None, None) => {
            return Err(crate::error::CliError::Aprender(
                "Provide a PTX file path or --kernel <name>".to_string(),
            ));
        }
    };

    let ptx = ptx_source.as_str();
    if json {
        run_json(ptx, strict, bugs_only)
    } else {
        run_human(ptx, strict, bugs_only, verbose)
    }
}

/// Print a `PtxAnalyzer` report as colored, human-readable text on stdout.
///
/// The report has these sections:
/// - register counts per class, the total, and estimated occupancy;
/// - global/shared load and store counts and the coalescing ratio;
/// - a roofline summary: instruction count, arithmetic intensity, and whether
///   the kernel is memory- or compute-bound;
/// - muda (waste) warnings with impact and optional fix, printed only when
///   there are any.
///
/// When `verbose` is set, the `ptx` source is echoed afterwards with 1-based
/// line numbers.
///
/// `estimated_occupancy` and `coalesced_ratio` are multiplied by 100 and shown
/// as percentages. NOTE(review): this presumes both are fractions in `[0, 1]`.
fn print_ptx_analysis_report(
    report: &trueno_explain::analyzer::AnalysisReport,
    ptx: &str,
    verbose: bool,
) {
    let regs = &report.registers;
    let mem = &report.memory;
    let roofline = &report.roofline;

    println!("\x1b[1;36m=== PTX Analysis: {} ===\x1b[0m\n", report.name);

    println!("\x1b[1mRegisters:\x1b[0m");
    println!(
        "  f32: {}  f64: {}  b32: {}  b64: {}  pred: {}",
        regs.f32_regs, regs.f64_regs, regs.b32_regs, regs.b64_regs, regs.pred_regs,
    );
    println!(
        "  Total: {}  Estimated occupancy: {:.0}%\n",
        regs.total(),
        report.estimated_occupancy * 100.0
    );

    println!("\x1b[1mMemory:\x1b[0m");
    println!(
        "  Global loads: {}  Global stores: {}",
        mem.global_loads, mem.global_stores
    );
    println!(
        "  Shared loads: {}  Shared stores: {}",
        mem.shared_loads, mem.shared_stores
    );
    println!("  Coalescing ratio: {:.1}%\n", mem.coalesced_ratio * 100.0);

    let bottleneck = if roofline.memory_bound {
        "MEMORY-BOUND"
    } else {
        "COMPUTE-BOUND"
    };
    println!("\x1b[1mRoofline:\x1b[0m");
    println!("  Instructions: {}", report.instruction_count);
    println!(
        "  Arithmetic intensity: {:.2} FLOP/byte",
        roofline.arithmetic_intensity
    );
    println!("  Bottleneck: {bottleneck}\n");

    // The muda section is omitted entirely when there is nothing to report.
    if !report.warnings.is_empty() {
        println!("\x1b[1;33mMuda (Waste) Warnings:\x1b[0m");
        for w in &report.warnings {
            println!("  [{:?}] {}", w.muda_type, w.description);
            println!("    Impact: {}", w.impact);
            if let Some(suggestion) = &w.suggestion {
                println!("    Fix: {suggestion}");
            }
        }
        println!();
    }

    if verbose {
        println!("\x1b[1mPTX Source:\x1b[0m");
        for (i, line) in ptx.lines().enumerate() {
            println!("  {:4} | {line}", i + 1);
        }
        println!();
    }
}

/// Print PTX bug analysis report.
fn print_ptx_bug_report(bug_report: &trueno_explain::PtxBugReport) {
    let color = if bug_report.bugs.is_empty() {
        "32"
    } else {
        "31"
    };
    let name_suffix = bug_report
        .kernel_name
        .as_ref()
        .map_or(String::new(), |n| format!(": {n}"));
    println!("\x1b[1;{color}m=== PTX Bug Analysis{name_suffix} ===\x1b[0m");
    println!("  Lines analyzed: {}", bug_report.lines_analyzed);
    println!("  Strict mode: {}", bug_report.strict_mode);
    println!("  Bugs found: {}\n", bug_report.bugs.len());

    if bug_report.bugs.is_empty() {
        println!("  \x1b[32mNo bugs detected.\x1b[0m");
        return;
    }

    for bug in &bug_report.bugs {
        let severity_color = match bug.class.severity() {
            trueno_explain::BugSeverity::Critical => "31",
            trueno_explain::BugSeverity::High => "33",
            trueno_explain::BugSeverity::Medium => "35",
            trueno_explain::BugSeverity::FalsePositive => "36",
        };
        println!(
            "  \x1b[{severity_color}m[{:?}]\x1b[0m Line {}: {}",
            bug.class, bug.line, bug.message
        );
        if !bug.instruction.is_empty() {
            println!("    Instruction: {}", bug.instruction);
        }
        if let Some(fix) = &bug.fix {
            println!("    Fix: {fix}");
        }
        println!();
    }
}

fn run_human(ptx: &str, strict: bool, bugs_only: bool, verbose: bool) -> Result<()> {
    use trueno_explain::analyzer::Analyzer;

    if !bugs_only {
        let analyzer = trueno_explain::PtxAnalyzer::new();
        match analyzer.analyze(ptx) {
            Ok(report) => print_ptx_analysis_report(&report, ptx, verbose),
            Err(e) => eprintln!("PTX analysis error: {e}"),
        }
    }

    let bug_analyzer = if strict {
        trueno_explain::PtxBugAnalyzer::strict()
    } else {
        trueno_explain::PtxBugAnalyzer::with_performance_whitelist()
    };

    let bug_report = bug_analyzer.analyze(ptx);
    if bugs_only || !bug_report.bugs.is_empty() {
        print_ptx_bug_report(&bug_report);
    }

    Ok(())
}

/// JSON output for PTX analysis.
///
/// Prints one pretty-printed JSON object to stdout with these keys:
/// - `"analysis"`: the serialized `PtxAnalyzer` report. When the analysis or
///   its serialization fails, `"analysis_error"` is emitted instead, holding a
///   message. Both keys are omitted when `bugs_only` is set.
/// - `"bugs"`: the `PtxBugAnalyzer` report (strict when `strict` is set,
///   otherwise with the performance whitelist). It contains the kernel name,
///   lines analyzed, strict mode, bug count, and one object per bug.
///
/// Always returns `Ok(())`.
// serde_json::json!() macro uses infallible unwrap internally
#[allow(clippy::disallowed_methods)]
fn run_json(ptx: &str, strict: bool, bugs_only: bool) -> Result<()> {
    use trueno_explain::analyzer::Analyzer;

    let mut output = serde_json::Map::new();

    if !bugs_only {
        let analyzer = trueno_explain::PtxAnalyzer::new();
        // Report failures inside the JSON rather than silently dropping the
        // "analysis" key, so consumers can tell "failed" apart from "skipped".
        let analysis = analyzer
            .analyze(ptx)
            .map_err(|e| format!("PTX analysis error: {e}"))
            .and_then(|report| {
                serde_json::to_value(&report)
                    .map_err(|e| format!("Failed to serialize PTX analysis: {e}"))
            });
        match analysis {
            Ok(report_json) => {
                output.insert("analysis".to_string(), report_json);
            }
            Err(message) => {
                output.insert(
                    "analysis_error".to_string(),
                    serde_json::Value::String(message),
                );
            }
        }
    }

    let bug_analyzer = if strict {
        trueno_explain::PtxBugAnalyzer::strict()
    } else {
        trueno_explain::PtxBugAnalyzer::with_performance_whitelist()
    };

    let bug_report = bug_analyzer.analyze(ptx);
    let bugs_json = serde_json::json!({
        "kernel_name": bug_report.kernel_name,
        "lines_analyzed": bug_report.lines_analyzed,
        "strict_mode": bug_report.strict_mode,
        "bug_count": bug_report.bugs.len(),
        "bugs": bug_report.bugs.iter().map(|b| {
            serde_json::json!({
                "class": format!("{:?}", b.class),
                "severity": format!("{:?}", b.class.severity()),
                "line": b.line,
                "message": b.message,
                "instruction": b.instruction,
                "fix": b.fix,
            })
        }).collect::<Vec<_>>(),
    });
    output.insert("bugs".to_string(), bugs_json);

    println!(
        "{}",
        serde_json::to_string_pretty(&serde_json::Value::Object(output)).unwrap_or_default()
    );

    Ok(())
}

/// Generate PTX for a named kernel from trueno-gpu via realizar.
///
/// Kernels are always generated with fixed Qwen2.5-Coder-7B-Instruct (Qwen2 7B
/// shape) dimensions: hidden=3584, intermediate=18944, heads=28, head_dim=128,
/// rope_theta=1e6, epsilon=1e-6. The caller currently has no way to supply
/// different dimensions.
///
/// # Errors
///
/// Returns `CliError::Aprender` when realizar's `generate_named_kernel_ptx`
/// fails for `name` (presumably including unknown kernel names). The error
/// message is realizar's, followed by a debugging hint.
#[cfg(feature = "inference")]
fn generate_kernel_ptx(name: &str) -> Result<String> {
    use realizar::ptx_parity::{generate_named_kernel_ptx, KernelDimensions};

    // Fixed default: Qwen2.5-Coder-7B-Instruct dimensions.
    let dims = KernelDimensions {
        hidden_dim: 3584,
        intermediate_dim: 18944,
        num_heads: 28,
        head_dim: 128,
        rope_theta: 1_000_000.0,
        epsilon: 1e-6,
    };

    // The label returned alongside the PTX is not needed here.
    let (_label, ptx) = generate_named_kernel_ptx(name, &dims).map_err(|e| {
        crate::error::CliError::Aprender(format!(
            "{e}\nHint: Run with DP4A_Q4K=1 to dump failing PTX to /tmp/failing_ptx.txt"
        ))
    })?;
    Ok(ptx)
}

/// Fallback kernel PTX generator used when the `inference` feature is off.
///
/// Always fails with `CliError::ValidationFailed`, because PTX generation
/// depends on realizar, which is only available with that feature. The
/// kernel name is ignored.
#[cfg(not(feature = "inference"))]
fn generate_kernel_ptx(_name: &str) -> Result<String> {
    const MISSING_FEATURE: &str =
        "Kernel PTX generation requires the 'inference' feature (realizar)";
    Err(crate::error::CliError::ValidationFailed(
        MISSING_FEATURE.to_owned(),
    ))
}

#[cfg(test)]
mod tests {
    use super::*;

    const SAMPLE_PTX: &str = r#"
.version 8.0
.target sm_89
.address_size 64

.visible .entry vector_add(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 c_ptr,
    .param .u32 n
) {
    .reg .f32 %f<8>;
    .reg .u32 %r<6>;
    .reg .u64 %rd<8>;
    .reg .pred %p<2>;

    mov.u32 %r0, %tid.x;
    mov.u32 %r1, %ctaid.x;
    mov.u32 %r2, %ntid.x;
    mad.lo.u32 %r3, %r1, %r2, %r0;
    ld.param.u32 %r4, [n];
    setp.ge.u32 %p0, %r3, %r4;
    @%p0 bra exit;

    ld.param.u64 %rd0, [a_ptr];
    ld.param.u64 %rd1, [b_ptr];
    ld.param.u64 %rd2, [c_ptr];
    mul.wide.u32 %rd3, %r3, 4;
    add.u64 %rd4, %rd0, %rd3;
    add.u64 %rd5, %rd1, %rd3;
    add.u64 %rd6, %rd2, %rd3;
    ld.global.f32 %f0, [%rd4];
    ld.global.f32 %f1, [%rd5];
    add.f32 %f2, %f0, %f1;
    st.global.f32 [%rd6], %f2;
exit:
    ret;
}
"#;

    /// Write `SAMPLE_PTX` into a fresh temporary directory.
    ///
    /// Returns the directory guard together with the file path. Keep the
    /// guard alive while the path is in use, because dropping it deletes the
    /// directory.
    fn write_sample_ptx() -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().expect("create temp dir");
        let ptx_path = dir.path().join("test.ptx");
        std::fs::write(&ptx_path, SAMPLE_PTX).expect("write PTX");
        (dir, ptx_path)
    }

    #[test]
    fn test_ptx_explain_inline() {
        let (_dir, ptx_path) = write_sample_ptx();
        let result = run(Some(ptx_path.as_path()), None, false, false, false, false);
        assert!(result.is_ok());
    }

    #[test]
    fn test_ptx_explain_json() {
        let (_dir, ptx_path) = write_sample_ptx();
        let result = run(Some(ptx_path.as_path()), None, false, false, true, false);
        assert!(result.is_ok());
    }

    #[test]
    fn test_ptx_explain_json_bugs_only() {
        let (_dir, ptx_path) = write_sample_ptx();
        let result = run(Some(ptx_path.as_path()), None, false, true, true, false);
        assert!(result.is_ok());
    }

    #[test]
    fn test_ptx_explain_strict() {
        let (_dir, ptx_path) = write_sample_ptx();
        let result = run(Some(ptx_path.as_path()), None, true, false, false, false);
        assert!(result.is_ok());
    }

    #[test]
    fn test_ptx_explain_bugs_only() {
        let (_dir, ptx_path) = write_sample_ptx();
        let result = run(Some(ptx_path.as_path()), None, false, true, false, false);
        assert!(result.is_ok());
    }

    #[test]
    fn test_ptx_explain_verbose() {
        let (_dir, ptx_path) = write_sample_ptx();
        let result = run(Some(ptx_path.as_path()), None, false, false, false, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_ptx_explain_no_args_errors() {
        let result = run(None, None, false, false, false, false);
        assert!(result.is_err());
    }

    #[test]
    fn test_ptx_explain_unknown_kernel() {
        let result = run(None, Some("NonexistentKernel"), false, false, false, false);
        assert!(result.is_err());
    }

    #[test]
    fn test_ptx_explain_missing_file() {
        // Use a path inside a fresh temp dir so the file is guaranteed not to
        // exist (a fixed /tmp path could be created by something else).
        let dir = tempfile::tempdir().expect("create temp dir");
        let missing = dir.path().join("nonexistent_ptx_file.ptx");
        let result = run(Some(missing.as_path()), None, false, false, false, false);
        assert!(result.is_err());
    }
}