#![allow(clippy::unwrap_used)]
use super::ptx_analysis::{PtxBugClass, PtxValidationResult};
use std::time::{Duration, Instant};
/// Options controlling a kernel pixel-test run.
#[derive(Debug, Clone)]
pub struct KernelPixelConfig {
    /// Flag for degenerate-dimension tests; `true` by default.
    ///
    /// NOTE(review): not read by any check visible in this file — confirm where it is consumed.
    pub test_degenerate_dims: bool,
    /// Flag for boundary tests; `true` by default.
    ///
    /// NOTE(review): not read by any check visible in this file — confirm where it is consumed.
    pub test_boundaries: bool,
    /// Strictness flag handed to the loop-structure check by
    /// `GpuPixelTestSuite::run_kernel_pixels`; when `false` that check passes
    /// without inspecting the PTX.
    pub strict_ptx: bool,
    /// Time budget for a run; five seconds by default.
    ///
    /// NOTE(review): not enforced by any check visible in this file — confirm where it is consumed.
    pub timeout: Duration,
}

impl Default for KernelPixelConfig {
    /// Builds a configuration with every flag enabled and a five-second timeout.
    fn default() -> Self {
        let timeout = Duration::from_secs(5);
        KernelPixelConfig {
            test_degenerate_dims: true,
            test_boundaries: true,
            strict_ptx: true,
            timeout,
        }
    }
}
/// Outcome of a single GPU pixel test.
#[derive(Debug, Clone)]
pub struct GpuPixelResult {
    /// Name of the test that produced this result.
    pub name: String,
    /// `true` when the test passed.
    pub passed: bool,
    /// Failure message; `None` for passing results.
    pub error: Option<String>,
    /// Time attributed to the test.
    pub duration: Duration,
    /// PTX bug class attached to the failure, when one is known.
    pub bug_class: Option<PtxBugClass>,
}

impl GpuPixelResult {
    /// Creates a passing result with no error message and no bug class.
    #[must_use]
    pub fn pass(name: &str, duration: Duration) -> Self {
        Self {
            name: String::from(name),
            passed: true,
            error: None,
            duration,
            bug_class: None,
        }
    }

    /// Creates a failing result carrying `error` as its message and no bug class.
    #[must_use]
    pub fn fail(name: &str, error: &str, duration: Duration) -> Self {
        // Start from the passing shape and flip only the fields that differ.
        Self {
            passed: false,
            error: Some(String::from(error)),
            ..Self::pass(name, duration)
        }
    }

    /// Creates a failing result carrying `error` as its message, tagged with `bug`.
    #[must_use]
    pub fn fail_with_bug(name: &str, error: &str, bug: PtxBugClass, duration: Duration) -> Self {
        Self {
            bug_class: Some(bug),
            ..Self::fail(name, error, duration)
        }
    }

    /// Converts a PTX validation result into a result named `"ptx_validation"`.
    ///
    /// A valid result maps to a pass. Otherwise the first reported bug supplies
    /// both the error text (`"<class>: <message>"`) and the bug class; when the
    /// bug list is empty the error is `"Unknown PTX error"` and no class is set.
    ///
    /// The recorded duration only covers the time spent inside this function,
    /// since the timer starts on entry.
    #[must_use]
    pub fn from_ptx_validation(result: &PtxValidationResult) -> Self {
        const NAME: &str = "ptx_validation";
        let timer = Instant::now();

        if result.is_valid() {
            return Self::pass(NAME, timer.elapsed());
        }

        let (message, bug_class) = match result.bugs.first() {
            Some(bug) => (
                format!("{}: {}", bug.class, bug.message),
                Some(bug.class.clone()),
            ),
            None => ("Unknown PTX error".to_string(), None),
        };

        Self {
            name: NAME.to_string(),
            passed: false,
            error: Some(message),
            duration: timer.elapsed(),
            bug_class,
        }
    }
}
/// Descriptor of a pixel test: its name, a description, and the PTX bug class it targets.
#[derive(Debug, Clone)]
pub struct GpuPixelTest {
    /// Identifier of the test.
    pub name: String,
    /// Human-readable description of what the test verifies.
    pub description: String,
    /// Bug class this test is meant to catch.
    pub catches: PtxBugClass,
}

impl GpuPixelTest {
    /// Creates a descriptor, copying `name` and `description` into owned strings.
    #[must_use]
    pub fn new(name: &str, description: &str, catches: PtxBugClass) -> Self {
        let name = String::from(name);
        let description = String::from(description);
        Self {
            name,
            description,
            catches,
        }
    }
}
/// Results of the pixel tests collected for one kernel.
#[derive(Debug, Clone)]
pub struct GpuPixelTestSuite {
    /// Name of the kernel under test.
    pub kernel_name: String,
    /// Results in the order they were added.
    pub results: Vec<GpuPixelResult>,
    /// Total time attributed to the suite: per-result durations added through
    /// `add_result`, plus the wall-clock time of each `run_kernel_pixels` call.
    pub duration: Duration,
}

impl GpuPixelTestSuite {
    /// Creates an empty suite for `kernel_name` with a zero duration.
    #[must_use]
    pub fn new(kernel_name: &str) -> Self {
        Self {
            kernel_name: kernel_name.to_string(),
            results: Vec::new(),
            duration: Duration::ZERO,
        }
    }

    /// Appends `result` and adds its duration to the suite total.
    pub fn add_result(&mut self, result: GpuPixelResult) {
        self.duration += result.duration;
        self.results.push(result);
    }

    /// Returns `true` when every recorded result passed (also `true` for an empty suite).
    #[must_use]
    pub fn all_passed(&self) -> bool {
        self.results.iter().all(|r| r.passed)
    }

    /// Number of passing results.
    #[must_use]
    pub fn passed_count(&self) -> usize {
        self.results.iter().filter(|r| r.passed).count()
    }

    /// Number of failing results.
    #[must_use]
    pub fn failed_count(&self) -> usize {
        self.results.iter().filter(|r| !r.passed).count()
    }

    /// Borrows every failing result, in insertion order.
    #[must_use]
    pub fn failures(&self) -> Vec<&GpuPixelResult> {
        self.results.iter().filter(|r| !r.passed).collect()
    }

    /// Runs the static PTX checks against `ptx` and records their results.
    ///
    /// The shared-memory addressing, kernel-entry and loop-structure checks
    /// always run; the barrier check runs only when `ptx` contains `.shared`.
    /// Of `config`, only `strict_ptx` is consulted here.
    ///
    /// The suite duration grows by the wall-clock time of this call. Time that
    /// was already attributed to the suite (earlier `add_result` calls or
    /// earlier runs) is preserved rather than overwritten.
    pub fn run_kernel_pixels(&mut self, ptx: &str, config: &KernelPixelConfig) {
        // Remember what was accumulated before this run so it is not lost below.
        let prior = self.duration;
        let start = Instant::now();

        let addressing = self.pixel_shared_mem_addressing(ptx);
        self.add_result(addressing);

        let entry = self.pixel_kernel_entry_exists(ptx);
        self.add_result(entry);

        let loops = self.pixel_loop_structure(ptx, config.strict_ptx);
        self.add_result(loops);

        if ptx.contains(".shared") {
            let barrier = self.pixel_barrier_sync(ptx);
            self.add_result(barrier);
        }

        // `add_result` summed the individual check timings; replace that partial
        // sum with the wall-clock time of the whole run, on top of `prior`.
        self.duration = prior + start.elapsed();
    }

    /// Fails when a shared-memory load or store addresses memory through a
    /// 64-bit register (`%rdN`), with or without an immediate offset.
    fn pixel_shared_mem_addressing(&self, ptx: &str) -> GpuPixelResult {
        const NAME: &str = "shared_mem_u32_addressing";
        let start = Instant::now();
        // `ld`/`st` are the PTX load/store mnemonics; the optional tail accepts
        // `[%rdN+imm]` operands in addition to the plain `[%rdN]` form.
        let wide_address =
            regex::Regex::new(r"(?:ld|st)\.shared\.[^\[]+\[%rd\d+(?:\+-?\d+)?\]").unwrap();
        if wide_address.is_match(ptx) {
            GpuPixelResult::fail_with_bug(
                NAME,
                "Shared memory uses 64-bit addressing (should be 32-bit)",
                PtxBugClass::SharedMemU64Addressing,
                start.elapsed(),
            )
        } else {
            GpuPixelResult::pass(NAME, start.elapsed())
        }
    }

    /// Passes when `ptx` declares at least one `.visible .entry <name>` kernel.
    fn pixel_kernel_entry_exists(&self, ptx: &str) -> GpuPixelResult {
        const NAME: &str = "kernel_entry_exists";
        let start = Instant::now();
        let entry = regex::Regex::new(r"\.visible\s+\.entry\s+\w+").unwrap();
        if entry.is_match(ptx) {
            GpuPixelResult::pass(NAME, start.elapsed())
        } else {
            GpuPixelResult::fail_with_bug(
                NAME,
                "No kernel entry point found",
                PtxBugClass::MissingEntryPoint,
                start.elapsed(),
            )
        }
    }

    /// Fails when a line holds an unconditional branch to a label containing `_end`.
    ///
    /// With `strict == false` the check passes without inspecting `ptx`.
    fn pixel_loop_structure(&self, ptx: &str, strict: bool) -> GpuPixelResult {
        const NAME: &str = "loop_structure";
        let start = Instant::now();
        if !strict {
            return GpuPixelResult::pass(NAME, start.elapsed());
        }
        // Anchoring at `bra` (after optional indentation) already excludes
        // predicated branches such as `@%p1 bra ...`, so no separate `@` test
        // is needed. `\s*` also catches unindented lines, `(?:\.uni)?` the
        // uniform-branch form, and `$` is allowed because PTX labels may use it.
        let branch_to_end =
            regex::Regex::new(r"^\s*bra(?:\.uni)?\s+[\w$]*_end[\w$]*\s*;").unwrap();
        if ptx.lines().any(|line| branch_to_end.is_match(line)) {
            GpuPixelResult::fail_with_bug(
                NAME,
                "Unconditional branch to loop end (should branch to start)",
                PtxBugClass::LoopBranchToEnd,
                start.elapsed(),
            )
        } else {
            GpuPixelResult::pass(NAME, start.elapsed())
        }
    }

    /// Passes when `ptx` contains a `bar.sync` instruction.
    fn pixel_barrier_sync(&self, ptx: &str) -> GpuPixelResult {
        const NAME: &str = "barrier_sync";
        let start = Instant::now();
        if ptx.contains("bar.sync") {
            GpuPixelResult::pass(NAME, start.elapsed())
        } else {
            GpuPixelResult::fail_with_bug(
                NAME,
                "Shared memory used but no bar.sync found",
                PtxBugClass::MissingBarrierSync,
                start.elapsed(),
            )
        }
    }

    /// One-line report: `[PASS|FAIL] <kernel> - <passed>/<total> passed (<duration>)`.
    #[must_use]
    pub fn summary(&self) -> String {
        let status = if self.all_passed() { "PASS" } else { "FAIL" };
        format!(
            "[{}] {} - {}/{} passed ({:?})",
            status,
            self.kernel_name,
            self.passed_count(),
            self.results.len(),
            self.duration
        )
    }
}
/// Returns the descriptors of the standard pixel tests, one per PTX bug class
/// listed here: shared-memory addressing, loop branch target, barrier
/// synchronization and kernel entry point.
pub fn standard_pixel_tests() -> Vec<GpuPixelTest> {
    // (name, description, bug class caught) — kept in a table so the list is easy to scan.
    let catalog = vec![
        (
            "shared_mem_u32_addressing",
            "Verify shared memory uses 32-bit addressing",
            PtxBugClass::SharedMemU64Addressing,
        ),
        (
            "loop_branch_to_start",
            "Verify loop branches go to start label, not end",
            PtxBugClass::LoopBranchToEnd,
        ),
        (
            "barrier_sync_present",
            "Verify barrier sync exists when using shared memory",
            PtxBugClass::MissingBarrierSync,
        ),
        (
            "kernel_entry_exists",
            "Verify kernel has entry point",
            PtxBugClass::MissingEntryPoint,
        ),
    ];

    catalog
        .into_iter()
        .map(|(name, description, catches)| GpuPixelTest::new(name, description, catches))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A suite holding only passing results reports success and correct counts.
    #[test]
    fn test_suite_all_passed() {
        let first = GpuPixelResult::pass("test1", Duration::from_millis(1));
        let second = GpuPixelResult::pass("test2", Duration::from_millis(2));

        let mut suite = GpuPixelTestSuite::new("test_kernel");
        suite.add_result(first);
        suite.add_result(second);

        assert!(suite.all_passed());
        assert_eq!(suite.passed_count(), 2);
        assert_eq!(suite.failed_count(), 0);
    }

    /// A single failing result flips the suite to failed and is counted once.
    #[test]
    fn test_suite_has_failure() {
        let ok = GpuPixelResult::pass("test1", Duration::from_millis(1));
        let broken = GpuPixelResult::fail("test2", "error", Duration::from_millis(2));

        let mut suite = GpuPixelTestSuite::new("test_kernel");
        suite.add_result(ok);
        suite.add_result(broken);

        assert!(!suite.all_passed());
        assert_eq!(suite.passed_count(), 1);
        assert_eq!(suite.failed_count(), 1);
    }

    /// A shared-memory store through a `%rd` register is reported as a bug.
    #[test]
    fn test_pixel_shared_mem_u64_fails() {
        let suite = GpuPixelTestSuite::new("test");
        let outcome = suite.pixel_shared_mem_addressing("st.shared.f32 [%rd5], %f0;");
        assert!(!outcome.passed);
        assert_eq!(
            outcome.bug_class,
            Some(PtxBugClass::SharedMemU64Addressing)
        );
    }

    /// A shared-memory store through a `%r` register passes the check.
    #[test]
    fn test_pixel_shared_mem_u32_passes() {
        let suite = GpuPixelTestSuite::new("test");
        let outcome = suite.pixel_shared_mem_addressing("st.shared.f32 [%r5], %f0;");
        assert!(outcome.passed);
    }

    /// The standard catalog is non-empty and lists the addressing test.
    #[test]
    fn test_standard_pixel_tests() {
        let catalog = standard_pixel_tests();
        assert!(!catalog.is_empty());
        assert!(catalog
            .iter()
            .any(|entry| entry.name == "shared_mem_u32_addressing"));
    }

    /// The summary line carries the status and the kernel name.
    #[test]
    fn test_summary_format() {
        let mut suite = GpuPixelTestSuite::new("gemm_tiled");
        suite.add_result(GpuPixelResult::pass("test1", Duration::from_millis(1)));

        let line = suite.summary();
        assert!(line.contains("PASS"));
        assert!(line.contains("gemm_tiled"));
    }
}