use realizar::ptx_parity::{KernelDimensions, PtxParityReport};
fn main() {
println!("═══════════════════════════════════════════════════════════════════");
println!(" GH-219: PTX Parity Validation — Poka-Yoke for GPU Kernels");
println!("═══════════════════════════════════════════════════════════════════\n");
demo_validate_1_5b();
demo_validate_7b();
demo_dispatch_strategies();
println!("\n═══════════════════════════════════════════════════════════════════");
println!(" Demo Complete");
println!("═══════════════════════════════════════════════════════════════════");
}
/// Demo 1: validates all kernel pairs for the dimensions labelled
/// "Qwen2.5-Coder-1.5B (Q4K)" in the banner and prints the resulting report.
fn demo_validate_1_5b() {
    // Boxed section header; the last line carries the trailing blank line.
    let header = [
        "┌─────────────────────────────────────────────────────────────────┐",
        "│ Demo 1: Qwen2.5-Coder-1.5B (Q4K) — 6 Kernel Pairs │",
        "└─────────────────────────────────────────────────────────────────┘\n",
    ];
    for line in header.iter() {
        println!("{}", line);
    }

    let model_dims = KernelDimensions {
        hidden_dim: 1536,
        intermediate_dim: 8960,
        num_heads: 12,
        head_dim: 128,
        rope_theta: 1_000_000.0,
        epsilon: 1e-6,
    };
    print_report(
        &realizar::ptx_parity::validate_all_kernel_pairs(&model_dims),
        &model_dims,
    );
}
/// Demo 2: validates all kernel pairs for the dimensions labelled
/// "Qwen2.5-Coder-7B (Q4K)" in the banner and prints the resulting report.
fn demo_validate_7b() {
    // Boxed section header; the last line carries the trailing blank line.
    let header = [
        "┌─────────────────────────────────────────────────────────────────┐",
        "│ Demo 2: Qwen2.5-Coder-7B (Q4K) — 6 Kernel Pairs │",
        "└─────────────────────────────────────────────────────────────────┘\n",
    ];
    for line in header.iter() {
        println!("{}", line);
    }

    let model_dims = KernelDimensions {
        hidden_dim: 3584,
        intermediate_dim: 18944,
        num_heads: 28,
        head_dim: 128,
        rope_theta: 1_000_000.0,
        epsilon: 1e-6,
    };
    print_report(
        &realizar::ptx_parity::validate_all_kernel_pairs(&model_dims),
        &model_dims,
    );
}
/// Demo 3: prints a static explanation of the two batch dispatch strategies
/// (`grid_y` and `register_unroll`). Performs no validation.
fn demo_dispatch_strategies() {
    // Each entry is emitted as one output line; empty entries are blank lines.
    let lines = [
        "┌─────────────────────────────────────────────────────────────────┐",
        "│ Demo 3: Batch Dispatch Strategies │",
        "└─────────────────────────────────────────────────────────────────┘\n",
        " Two strategies for extending single-vector kernels to batched:",
        "",
        " 1. grid_y (ctaid.y) — Elementwise kernels",
        " Each batch element gets a separate grid Y index.",
        " Used by: RmsNorm, ResidualAdd, RoPE, SwiGLU",
        " PTX check: presence of %ctaid.y register",
        "",
        " 2. register_unroll (m_dim) — Quantized GEMV kernels",
        " Batch dimension folded into the M (output rows) dimension.",
        " Each warp processes one output row across all batch elements.",
        " Used by: Q4K GEMV, Q6K GEMV",
        " PTX check: m_dim parameter in kernel signature",
        "",
        " Why two strategies?",
        " - Elementwise ops are embarrassingly parallel per-element.",
        " grid_y maps naturally to independent vectors.",
        " - GEMV is memory-bandwidth-bound. Unrolling across batch",
        " elements within the same warp improves memory coalescing",
        " and shared memory reuse across the batch.",
        "",
    ];
    for line in lines.iter() {
        println!("{}", line);
    }
}
/// Prints the model dimensions, then a table with one row per kernel pair
/// (name, PASS/FAIL status, dispatch strategy) and a coloured summary line.
///
/// Each violation of a kernel pair is printed in red beneath its row,
/// shortened to 72 characters via `truncate`. When `report.total` is zero,
/// only the dimensions and a hint about the `cuda` feature are printed.
fn print_report(report: &PtxParityReport, dims: &KernelDimensions) {
    // ANSI escape sequences used for the status column, violations and summary.
    const GREEN: &str = "\x1b[32m";
    const RED: &str = "\x1b[31m";
    const RESET: &str = "\x1b[0m";

    println!(" Model dimensions:");
    println!(" hidden_dim: {}", dims.hidden_dim);
    println!(" intermediate_dim: {}", dims.intermediate_dim);
    println!(" num_heads: {}", dims.num_heads);
    println!(" head_dim: {}", dims.head_dim);
    println!(" rope_theta: {}", dims.rope_theta);
    println!(" epsilon: {}\n", dims.epsilon);

    // An empty report is reported as "built without CUDA"; nothing to tabulate.
    if report.total == 0 {
        println!(" (No CUDA feature — PTX validation requires --features cuda)\n");
        return;
    }

    println!(" ┌──────────────────────────────────┬──────────┬──────────────────┐");
    println!(" │ Kernel Pair │ Status │ Dispatch │");
    println!(" ├──────────────────────────────────┼──────────┼──────────────────┤");

    for entry in &report.results {
        let (colour, label) = if entry.passed {
            (GREEN, "PASS")
        } else {
            (RED, "FAIL")
        };
        println!(
            " │ {:<32} │ {}{}{} │ {:<16} │",
            entry.name, colour, label, RESET, entry.dispatch_strategy
        );
        for problem in &entry.violations {
            println!(" │ {}{:<72}{} │", RED, truncate(problem, 72), RESET);
        }
    }

    // Table footer followed by one blank line.
    println!(" └──────────────────────────────────┴──────────┴──────────────────┘\n");

    // Summary is green only when every kernel pair passed; a blank line follows.
    let summary_colour = if report.all_passed() { GREEN } else { RED };
    println!(" {}{}{}\n", summary_colour, report.summary(), RESET);
}
/// Shortens `s` to at most `max` characters, ending with `...` when text was
/// cut off.
///
/// Length is measured in `char`s rather than bytes, so the result lines up
/// with `{:<N}` padding (which also counts characters) and the cut never
/// falls inside a multi-byte UTF-8 sequence.
///
/// * If `s` has at most `max` characters it is returned unchanged.
/// * Otherwise the first `max - 3` characters are kept and `...` is appended.
/// * If `max` is smaller than the ellipsis itself, the first `max` characters
///   are returned with no ellipsis, so the result never exceeds `max`.
fn truncate(s: &str, max: usize) -> String {
    const ELLIPSIS: &str = "...";

    // `nth(max)` yields a value only when there are more than `max` characters.
    if s.chars().nth(max).is_none() {
        return s.to_string();
    }

    // No room for the ellipsis: hard-cut instead of underflowing `max - 3`.
    if max < ELLIPSIS.len() {
        return s.chars().take(max).collect();
    }

    let mut shortened: String = s.chars().take(max - ELLIPSIS.len()).collect();
    shortened.push_str(ELLIPSIS);
    shortened
}