use super::*;
/// PARITY-114: in the PTX emitted for `GemmKernel::tiled(4, 8, 64, 32)`, no
/// line that starts before the first `tile_loop_end:` may contain both `@%p`
/// and `bra exit`, and the first `bar.sync` must appear before that label.
///
/// All checks are textual and operate on byte offsets into the emitted string.
#[test]
fn test_parity_114_tiled_gemm_no_early_exit_before_barrier() {
    let kernel = GemmKernel::tiled(4, 8, 64, 32);
    let ptx = kernel.emit_ptx();
    let bar_sync_pos = ptx.find("bar.sync").expect("bar.sync required");
    let tile_loop_end_pos = ptx.find("tile_loop_end:").expect("tile_loop_end required");

    // Track each line's own starting byte offset while walking the text once.
    // Searching the whole PTX for the line's text instead would re-scan the
    // string per candidate and resolve to the first textual match rather than
    // the line actually being inspected.
    let early_exit = ptx
        .split_inclusive('\n')
        .scan(0usize, |next_start, raw_line| {
            let start = *next_start;
            *next_start += raw_line.len();
            Some((start, raw_line))
        })
        // Only lines that begin before the loop-end label are relevant.
        .take_while(|&(start, _)| start < tile_loop_end_pos)
        .any(|(_, raw_line)| raw_line.contains("@%p") && raw_line.contains("bra exit"));

    assert!(!early_exit, "PARITY-114 violation");
    assert!(bar_sync_pos < tile_loop_end_pos, "bar.sync must be in loop");
}
/// Checks that the PTX emitted for `GemmKernel::tiled(4, 8, 64, 32)` contains
/// the text `, 2;` and the text `, 32;`.
///
/// Per the assertion messages these stand for `n_tiles = 2` (k = 64 with
/// tile_size = 32) and `tile_size = 32`. The match is a plain substring
/// search, so it does not pin which instruction carries the immediate.
#[test]
fn test_parity_114_ntiles_computation() {
    let gemm = GemmKernel::tiled(4, 8, 64, 32);
    let emitted = gemm.emit_ptx();

    let has_n_tiles_immediate = emitted.contains(", 2;");
    assert!(
        has_n_tiles_immediate,
        "PTX should have n_tiles=2 for k=64, tile_size=32"
    );

    let has_tile_size_immediate = emitted.contains(", 32;");
    assert!(has_tile_size_immediate, "PTX should have tile_size=32");
}
/// Checks that, in the PTX emitted for `GemmKernel::tensor_core(16, 16, 16)`,
/// the first `bar.sync` sits at a lower byte offset than the first
/// `k_tile_end:` label.
///
/// Panics with a descriptive message if either marker is absent.
#[test]
fn test_parity_114_tensor_core_no_early_exit_before_barrier() {
    let gemm = GemmKernel::tensor_core(16, 16, 16);
    let emitted = gemm.emit_ptx();

    let barrier_at = emitted.find("bar.sync").expect("PTX should have bar.sync");
    let loop_end_at = emitted
        .find("k_tile_end:")
        .expect("PTX should have k_tile_end");

    let barrier_precedes_loop_end = loop_end_at > barrier_at;
    assert!(
        barrier_precedes_loop_end,
        "bar.sync should be inside k_tile_loop (before k_tile_end)"
    );
}
/// Checks the PTX emitted for `GemmKernel::wmma_fp16(16, 16, 16)`:
///
/// * the first `bar.sync` is at a lower byte offset than the first
///   `k_tile_end:` label, and
/// * the text contains both `wmma.mma` and `wmma.load`.
#[test]
fn test_parity_114_wmma_no_early_exit_before_barrier() {
    let gemm = GemmKernel::wmma_fp16(16, 16, 16);
    let emitted = gemm.emit_ptx();

    let barrier_at = emitted.find("bar.sync").expect("PTX should have bar.sync");
    let loop_end_at = emitted
        .find("k_tile_end:")
        .expect("PTX should have k_tile_end");
    assert!(
        loop_end_at > barrier_at,
        "bar.sync should be inside k_tile_loop (before k_tile_end)"
    );

    // Required instruction mnemonics, checked in the same order as before.
    let required = [
        ("wmma.mma", "WMMA kernel should have wmma.mma"),
        ("wmma.load", "WMMA kernel should have wmma.load"),
    ];
    for (mnemonic, complaint) in required {
        assert!(emitted.contains(mnemonic), "{}", complaint);
    }
}
/// Emits PTX for `GemmKernel::tensor_core` over a table of `(m, n, k)` sizes
/// and checks, for each case, that the text contains `.entry` and that the
/// first `bar.sync` precedes the first `k_tile_end:` label.
///
/// Every failure message names the `(m, n, k)` case that triggered it, so a
/// broken case can be identified without re-running under a debugger.
#[test]
fn test_boundary_conditions_tensor_core() {
    // NOTE(review): these sizes presumably straddle the kernel's tile size;
    // confirm against the `GemmKernel::tensor_core` implementation.
    let boundary_cases = [
        (17, 17, 17),
        (31, 31, 31),
        (33, 33, 33),
        (100, 100, 100),
        (1, 16, 16),
        (16, 1, 16),
    ];
    for (m, n, k) in boundary_cases {
        let kernel = GemmKernel::tensor_core(m, n, k);
        let ptx = kernel.emit_ptx();

        assert!(
            ptx.contains(".entry"),
            "tensor_core(m={}, n={}, k={}): emitted PTX is missing `.entry`",
            m,
            n,
            k
        );

        // A single lookup both proves presence and yields the offset.
        let bar_sync_pos = ptx.find("bar.sync").unwrap_or_else(|| {
            panic!(
                "tensor_core(m={}, n={}, k={}): emitted PTX is missing `bar.sync`",
                m, n, k
            )
        });
        let k_tile_end_pos = ptx.find("k_tile_end:").unwrap_or_else(|| {
            panic!(
                "tensor_core(m={}, n={}, k={}): emitted PTX is missing `k_tile_end:`",
                m, n, k
            )
        });

        assert!(
            bar_sync_pos < k_tile_end_pos,
            "tensor_core(m={}, n={}, k={}): bar.sync (offset {}) must precede k_tile_end (offset {})",
            m,
            n,
            k,
            bar_sync_pos,
            k_tile_end_pos
        );
    }
}
/// Emits PTX for `GemmKernel::tiled` over a table of `(m, n, k, tile)` sizes
/// and checks that each result contains `.entry` and `bar.sync`.
///
/// Every failure message names the case and the missing text, so a broken
/// case can be identified directly from the test output.
#[test]
fn test_boundary_conditions_tiled_gemm() {
    let boundary_cases = [
        (17, 17, 17, 16),
        (65, 65, 65, 32),
        (100, 100, 100, 32),
        (1, 32, 32, 16),
    ];
    for (m, n, k, tile) in boundary_cases {
        let kernel = GemmKernel::tiled(m, n, k, tile);
        let ptx = kernel.emit_ptx();

        // Checked in the same order as the original bare assertions.
        for needle in [".entry", "bar.sync"] {
            assert!(
                ptx.contains(needle),
                "tiled(m={}, n={}, k={}, tile={}): emitted PTX is missing `{}`",
                m,
                n,
                k,
                tile,
                needle
            );
        }
    }
}
/// Emits PTX for `GemmKernel::wmma_fp16` over a table of `(m, n, k)` sizes and
/// checks that each result contains `.entry`, `bar.sync` and `wmma.mma`.
///
/// Every failure message names the case and the missing text, so a broken
/// case can be identified directly from the test output.
#[test]
fn test_boundary_conditions_wmma() {
    let boundary_cases = [(17, 17, 17), (32, 33, 34), (100, 100, 100)];
    for (m, n, k) in boundary_cases {
        let kernel = GemmKernel::wmma_fp16(m, n, k);
        let ptx = kernel.emit_ptx();

        // Checked in the same order as the original bare assertions.
        for needle in [".entry", "bar.sync", "wmma.mma"] {
            assert!(
                ptx.contains(needle),
                "wmma_fp16(m={}, n={}, k={}): emitted PTX is missing `{}`",
                m,
                n,
                k,
                needle
            );
        }
    }
}