#![cfg(feature = "gpu-pixels")]
use jugar_probar::gpu_pixels::{
run_kernel_pixels, validate_ptx, GpuRegressionSuite, KernelPixelConfig, PtxBugClass,
RegressionConfig,
};
use trueno_gpu::kernels::{AttentionKernel, GemmKernel, Kernel, LayerNormKernel, SoftmaxKernel};
#[cfg(feature = "gpu-pixels")]
mod tui_report {
use crossterm::{
event::{self, Event, KeyCode},
execute,
terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
};
use jugar_probar::gpu_pixels::GpuPixelTestSuite;
use ratatui::{
backend::CrosstermBackend,
layout::{Constraint, Direction, Layout},
style::{Color, Style},
text::{Line, Span},
widgets::{Block, Borders, List, ListItem, Paragraph},
Frame, Terminal,
};
use std::io;
pub struct GpuPixelTuiReport {
pub suites: Vec<GpuPixelTestSuite>,
}
impl GpuPixelTuiReport {
pub fn new() -> Self {
Self { suites: Vec::new() }
}
pub fn add_suite(&mut self, suite: GpuPixelTestSuite) {
self.suites.push(suite);
}
pub fn total_passed(&self) -> usize {
self.suites.iter().map(|s| s.passed_count()).sum()
}
pub fn total_tests(&self) -> usize {
self.suites.iter().map(|s| s.results.len()).sum()
}
pub fn all_passed(&self) -> bool {
self.suites.iter().all(|s| s.all_passed())
}
pub fn render_to_terminal(&self) -> io::Result<()> {
enable_raw_mode()?;
let mut stdout = io::stdout();
execute!(stdout, EnterAlternateScreen)?;
let backend = CrosstermBackend::new(stdout);
let mut terminal = Terminal::new(backend)?;
terminal.draw(|f| self.ui(f))?;
loop {
if let Event::Key(key) = event::read()? {
if key.code == KeyCode::Char('q') || key.code == KeyCode::Esc {
break;
}
}
}
disable_raw_mode()?;
execute!(terminal.backend_mut(), LeaveAlternateScreen)?;
Ok(())
}
fn ui(&self, f: &mut Frame<'_>) {
let chunks = Layout::default()
.direction(Direction::Vertical)
.margin(1)
.constraints([
Constraint::Length(3),
Constraint::Min(10),
Constraint::Length(5),
])
.split(f.area());
let total_passed: usize = self.suites.iter().map(|s| s.passed_count()).sum();
let total_tests: usize = self.suites.iter().map(|s| s.results.len()).sum();
let pass_rate = if total_tests > 0 {
(total_passed as f64 / total_tests as f64) * 100.0
} else {
0.0
};
let header = Paragraph::new(format!(
"GPU Pixel Tests: {}/{} passed ({:.1}%)",
total_passed, total_tests, pass_rate
))
.style(Style::default().fg(if pass_rate == 100.0 {
Color::Green
} else {
Color::Yellow
}))
.block(Block::default().borders(Borders::ALL).title("Summary"));
f.render_widget(header, chunks[0]);
let items: Vec<ListItem<'_>> = self
.suites
.iter()
.map(|suite| {
let status = if suite.all_passed() { "✓" } else { "✗" };
let color = if suite.all_passed() {
Color::Green
} else {
Color::Red
};
ListItem::new(Line::from(vec![
Span::styled(format!("[{}] ", status), Style::default().fg(color)),
Span::raw(&suite.kernel_name),
Span::styled(
format!(" ({}/{})", suite.passed_count(), suite.results.len()),
Style::default().fg(Color::Gray),
),
]))
})
.collect();
let list =
List::new(items).block(Block::default().borders(Borders::ALL).title("Kernels"));
f.render_widget(list, chunks[1]);
let legend = Paragraph::new(vec![
Line::from("Bug Classes Detected:"),
Line::from(
" shared_mem_u64 - Shared memory uses 64-bit addressing (should be 32-bit)",
),
Line::from(" loop_branch_end - Loop branches to END instead of START"),
Line::from(" missing_barrier - No bar.sync with shared memory"),
])
.style(Style::default().fg(Color::Cyan))
.block(
Block::default()
.borders(Borders::ALL)
.title("Legend (press 'q' to exit)"),
);
f.render_widget(legend, chunks[2]);
}
pub fn print_summary(&self) {
let total_passed: usize = self.suites.iter().map(|s| s.passed_count()).sum();
let total_tests: usize = self.suites.iter().map(|s| s.results.len()).sum();
let pass_rate = if total_tests > 0 {
(total_passed as f64 / total_tests as f64) * 100.0
} else {
0.0
};
let all_pass = total_passed == total_tests;
println!();
println!("┌─────────────────────────────────────────────────────────────────────┐");
println!("│ GPU PIXEL TEST REPORT │");
println!("│ trueno-gpu + probar │");
println!("├─────────────────────────────────────────────────────────────────────┤");
let bar_width = 40;
let filled = (pass_rate / 100.0 * bar_width as f64) as usize;
let bar: String = "█".repeat(filled) + &"░".repeat(bar_width - filled);
let status_color = if all_pass { "\x1b[32m" } else { "\x1b[33m" };
println!(
"│ Tests: {}{:>3}/{:<3}\x1b[0m ({:>5.1}%) [{}] │",
status_color, total_passed, total_tests, pass_rate, bar
);
println!("├─────────────────────────────────────────────────────────────────────┤");
println!("│ KERNELS │");
println!("├─────────────────────────────────────────────────────────────────────┤");
for suite in &self.suites {
let passed = suite.passed_count();
let total = suite.results.len();
let suite_pass = passed == total;
let status = if suite_pass { "✓" } else { "✗" };
let color = if suite_pass { "\x1b[32m" } else { "\x1b[31m" };
let mini_bar_width = 10;
let mini_filled = (passed as f64 / total as f64 * mini_bar_width as f64) as usize;
let mini_bar: String =
"▓".repeat(mini_filled) + &"░".repeat(mini_bar_width - mini_filled);
println!(
"│ {}[{}]\x1b[0m {:<32} {:>2}/{:<2} [{}] │",
color, status, suite.kernel_name, passed, total, mini_bar
);
for result in suite.failures() {
println!("│ └─ \x1b[31m{:<55}\x1b[0m │", result.name);
if let Some(err) = &result.error {
let truncated = if err.len() > 50 {
&err[..50]
} else {
err.as_str()
};
println!("│ └─ {:52} │", truncated);
}
}
}
println!("├─────────────────────────────────────────────────────────────────────┤");
println!("│ BUG CLASSES DETECTED │");
println!("├─────────────────────────────────────────────────────────────────────┤");
println!("│ \x1b[36mshared_mem_u64\x1b[0m - Shared memory 64-bit addressing (should be 32) │");
println!("│ \x1b[36mloop_branch_end\x1b[0m - Loop branches to END instead of START │");
println!("│ \x1b[36mmissing_barrier\x1b[0m - No bar.sync with shared memory │");
println!("│ \x1b[36mmissing_entry\x1b[0m - Kernel entry point missing │");
println!("├─────────────────────────────────────────────────────────────────────┤");
println!("│ STATISTICS │");
println!("├─────────────────────────────────────────────────────────────────────┤");
println!(
"│ Kernels Tested: {:>3} │",
self.suites.len()
);
println!(
"│ Pixel Tests: {:>3} │",
total_tests
);
println!(
"│ Bugs Found: {:>3} │",
total_tests - total_passed
);
println!("├─────────────────────────────────────────────────────────────────────┤");
if all_pass {
println!("│ \x1b[32m✓ ALL TESTS PASSED\x1b[0m │");
} else {
println!("│ \x1b[31m✗ {} TESTS FAILED\x1b[0m │", total_tests - total_passed);
}
println!("└─────────────────────────────────────────────────────────────────────┘");
println!();
}
}
}
#[cfg(feature = "gpu-pixels")]
use tui_report::GpuPixelTuiReport;
#[test]
#[cfg(feature = "gpu-pixels")]
fn pixel_gemm_tiled_shared_mem_addressing() {
let kernel = GemmKernel::tiled(32, 32, 128, 32);
let ptx = kernel.emit_ptx();
let result = validate_ptx(&ptx);
assert!(
!result.has_bug(&PtxBugClass::SharedMemU64Addressing),
"GEMM kernel uses u64 for shared memory (should be u32)"
);
}
#[test]
#[cfg(feature = "gpu-pixels")]
fn pixel_gemm_tensor_core_shared_mem_addressing() {
let kernel = GemmKernel::tensor_core(32, 32, 128);
let ptx = kernel.emit_ptx();
let result = validate_ptx(&ptx);
assert!(
!result.has_bug(&PtxBugClass::SharedMemU64Addressing),
"Tensor core GEMM uses u64 for shared memory"
);
}
#[test]
#[cfg(feature = "gpu-pixels")]
fn pixel_attention_shared_mem_addressing() {
let kernel = AttentionKernel::new(64, 64);
let ptx = kernel.emit_ptx();
let result = validate_ptx(&ptx);
assert!(
!result.has_bug(&PtxBugClass::SharedMemU64Addressing),
"Attention kernel uses u64 for shared memory"
);
}
#[test]
#[cfg(feature = "gpu-pixels")]
fn pixel_attention_causal_kernel_name() {
let kernel = AttentionKernel::new(64, 64).with_causal();
let ptx = kernel.emit_ptx();
assert!(
ptx.contains("flash_attention_causal"),
"Causal attention should have _causal suffix in kernel name"
);
}
#[test]
#[cfg(feature = "gpu-pixels")]
fn pixel_attention_barrier_sync() {
let kernel = AttentionKernel::new(64, 64);
let ptx = kernel.emit_ptx();
assert!(
ptx.contains("bar.sync"),
"Attention kernel must have barrier synchronization"
);
}
#[test]
#[cfg(feature = "gpu-pixels")]
fn pixel_gemm_barrier_sync() {
let kernel = GemmKernel::tiled(32, 32, 64, 32);
let ptx = kernel.emit_ptx();
assert!(
ptx.contains("bar.sync"),
"Tiled GEMM must have barrier synchronization"
);
}
#[test]
#[cfg(feature = "gpu-pixels")]
fn pixel_softmax_kernel_entry() {
let kernel = SoftmaxKernel::new(128);
let ptx = kernel.emit_ptx();
let result = validate_ptx(&ptx);
assert!(
!result.has_bug(&PtxBugClass::MissingEntryPoint),
"Softmax kernel must have entry point"
);
}
#[test]
#[cfg(feature = "gpu-pixels")]
fn pixel_layernorm_kernel_entry() {
let kernel = LayerNormKernel::new(256);
let ptx = kernel.emit_ptx();
let result = validate_ptx(&ptx);
assert!(
!result.has_bug(&PtxBugClass::MissingEntryPoint),
"LayerNorm kernel must have entry point"
);
}
#[test]
#[cfg(feature = "gpu-pixels")]
fn gpu_pixel_suite_all_kernels() {
let config = KernelPixelConfig::default();
let mut report = GpuPixelTuiReport::new();
let kernels: Vec<(&str, String)> = vec![
(
"gemm_tiled_32x32x64",
GemmKernel::tiled(32, 32, 64, 32).emit_ptx(),
),
(
"gemm_tiled_64x64x128",
GemmKernel::tiled(64, 64, 128, 32).emit_ptx(),
),
(
"gemm_tensor_core",
GemmKernel::tensor_core(32, 32, 64).emit_ptx(),
),
("attention_64x64", AttentionKernel::new(64, 64).emit_ptx()),
(
"attention_causal",
AttentionKernel::new(64, 64).with_causal().emit_ptx(),
),
("softmax_128", SoftmaxKernel::new(128).emit_ptx()),
("layernorm_256", LayerNormKernel::new(256).emit_ptx()),
];
for (name, ptx) in &kernels {
let suite = run_kernel_pixels(name, ptx, &config);
report.add_suite(suite);
}
report.print_summary();
use std::io::IsTerminal;
if std::io::stdout().is_terminal() {
if let Err(e) = report.render_to_terminal() {
eprintln!("TUI render error: {e}");
}
}
assert!(
report.all_passed(),
"Not all GPU pixel tests passed: {}/{}",
report.total_passed(),
report.total_tests()
);
}
#[test]
#[cfg(feature = "gpu-pixels")]
fn gpu_pixel_regression_detection() {
let config = RegressionConfig::default();
let mut suite = GpuRegressionSuite::new(config);
let gemm_ptx = GemmKernel::tiled(32, 32, 64, 32).emit_ptx();
suite.add_baseline("gemm_tiled", &gemm_ptx);
let result = suite.test_kernel("gemm_tiled", &gemm_ptx);
assert!(!result.is_regression, "Unexpected regression detected");
assert!(result.pixel_results.all_passed(), "Pixel tests failed");
}