use super::*;
use proptest::prelude::*;
proptest! {
    // Naive GEMM must emit structurally complete PTX for any supported
    // (m, n, k) triple: version/target directives, an entry point, and
    // named parameters covering all three matrix pointers.
    #[test]
    fn gemm_naive_always_valid(m in 16u32..512, n in 16u32..512, k in 16u32..512) {
        let gemm = GemmKernel::naive(m, n, k);
        let emitted = gemm.emit_ptx();
        prop_assert!(emitted.contains(".version"), "Missing PTX version");
        prop_assert!(emitted.contains(".target"), "Missing target");
        prop_assert!(emitted.contains(".entry"), "Missing entry point");
        prop_assert!(emitted.contains(".param"), "Missing parameters");
        prop_assert!(emitted.contains("a_ptr"), "Missing A matrix pointer");
        prop_assert!(emitted.contains("b_ptr"), "Missing B matrix pointer");
        prop_assert!(emitted.contains("c_ptr"), "Missing C matrix pointer");
    }

    // Tiling is only worthwhile if the generated kernel actually stages
    // operands in shared memory, so require a non-zero allocation.
    #[test]
    fn gemm_tiled_uses_shared_memory(m in 32u32..256, n in 32u32..256, k in 32u32..256, tile in 8u32..32) {
        let tiled = GemmKernel::tiled(m, n, k, tile);
        let built = tiled.build_ptx();
        prop_assert!(built.shared_memory_bytes() > 0, "Tiled GEMM should use shared memory");
    }

    // Softmax emission stays well-formed across the full sequence-length range.
    #[test]
    fn softmax_always_valid(seq_len in 64u32..8192) {
        let softmax = SoftmaxKernel::new(seq_len);
        let emitted = softmax.emit_ptx();
        prop_assert!(emitted.contains(".version"), "Missing PTX version");
        prop_assert!(emitted.contains(".entry"), "Missing entry point");
        prop_assert!(emitted.contains("softmax"), "Missing softmax kernel name");
    }

    // LayerNorm emission stays well-formed across hidden sizes.
    #[test]
    fn layernorm_always_valid(hidden_size in 64u32..4096) {
        let layernorm = LayerNormKernel::new(hidden_size);
        let emitted = layernorm.emit_ptx();
        prop_assert!(emitted.contains(".version"), "Missing PTX version");
        prop_assert!(emitted.contains(".entry"), "Missing entry point");
    }

    // Attention emission stays well-formed over typical (seq_len, head_dim) pairs.
    #[test]
    fn attention_always_valid(
        seq_len in 64u32..2048,
        head_dim in 32u32..128,
    ) {
        let attention = AttentionKernel::new(seq_len, head_dim);
        let emitted = attention.emit_ptx();
        prop_assert!(emitted.contains(".version"), "Missing PTX version");
        prop_assert!(emitted.contains(".entry"), "Missing entry point");
    }

    // Two kernels constructed from identical dimensions must report the
    // same name — kernel naming may not depend on hidden state.
    #[test]
    fn kernel_names_deterministic(m in 16u32..512, n in 16u32..512, k in 16u32..512) {
        let first = GemmKernel::naive(m, n, k);
        let second = GemmKernel::naive(m, n, k);
        prop_assert_eq!(first.name(), second.name(), "Kernel names should be deterministic");
    }

    // Every naive GEMM shares one PTX preamble (version / target / address
    // size) and terminates with an explicit return.
    #[test]
    fn ptx_emission_consistent_structure(m in 16u32..256, n in 16u32..256, k in 16u32..256) {
        let gemm = GemmKernel::naive(m, n, k);
        let emitted = gemm.emit_ptx();
        prop_assert!(emitted.contains(".version 8.0"), "Must have version 8.0");
        prop_assert!(emitted.contains(".target sm_70"), "Must target sm_70 baseline");
        prop_assert!(emitted.contains(".address_size 64"), "Must use 64-bit addresses");
        prop_assert!(emitted.contains("ret;"), "Must have return statement");
    }
}
// Degenerate 1x1x1 GEMM must still produce a kernel with an entry point.
#[test]
fn test_minimum_dimensions() {
    let ptx = GemmKernel::naive(1, 1, 1).emit_ptx();
    assert!(ptx.contains(".entry"), "Should handle 1x1x1");
}
// Large square dimensions (4096^3 problem size) must emit without issue.
#[test]
fn test_large_dimensions() {
    let ptx = GemmKernel::naive(4096, 4096, 4096).emit_ptx();
    assert!(ptx.contains(".entry"), "Should handle 4096x4096");
}
// Emission must not assume power-of-two dimensions.
#[test]
fn test_non_power_of_two() {
    let ptx = GemmKernel::naive(127, 255, 63).emit_ptx();
    assert!(ptx.contains(".entry"), "Should handle non-power-of-two");
}
proptest! {
    // Naive batched GEMM: structurally valid PTX plus an explicit batch
    // parameter, with per-batch work dispatched through the z grid index.
    #[test]
    fn batched_gemm_naive_always_valid(
        batch in 1u32..16,
        m in 16u32..256,
        n in 16u32..256,
        k in 16u32..256
    ) {
        let gemm = BatchedGemmKernel::naive(batch, m, n, k);
        let emitted = gemm.emit_ptx();
        prop_assert!(emitted.contains(".version"), "Missing PTX version");
        prop_assert!(emitted.contains(".entry"), "Missing entry point");
        prop_assert!(emitted.contains(".param .u32 batch"), "Missing batch parameter");
        prop_assert!(emitted.contains("%ctaid.z"), "Missing batch indexing via ctaid.z");
    }

    // Tiled batched GEMM must synchronize tile loads with barriers and
    // actually allocate shared memory for the staging buffers.
    #[test]
    fn batched_gemm_tiled_always_valid(
        batch in 1u32..8,
        m in 32u32..128,
        n in 32u32..128,
        k in 32u32..128,
        tile in 8u32..17
    ) {
        let gemm = BatchedGemmKernel::tiled(batch, m, n, k, tile);
        let emitted = gemm.emit_ptx();
        let built = gemm.build_ptx();
        prop_assert!(emitted.contains(".entry"), "Missing entry point");
        prop_assert!(emitted.contains("bar.sync"), "Missing barrier synchronization");
        prop_assert!(built.shared_memory_bytes() > 0, "Should use shared memory");
    }

    // 4-D batched GEMM (batch x heads) must surface both outer dimensions
    // as kernel parameters and fold them into the z grid index.
    #[test]
    fn batched_4d_gemm_always_valid(
        batch in 1u32..8,
        heads in 1u32..16,
        m in 32u32..128,
        n in 32u32..128,
        k in 16u32..64
    ) {
        let gemm = Batched4DGemmKernel::new(batch, heads, m, n, k);
        let emitted = gemm.emit_ptx();
        prop_assert!(emitted.contains(".version"), "Missing PTX version");
        prop_assert!(emitted.contains(".entry"), "Missing entry point");
        prop_assert!(emitted.contains(".param .u32 batch"), "Missing batch parameter");
        prop_assert!(emitted.contains(".param .u32 heads"), "Missing heads parameter");
        prop_assert!(emitted.contains("%ctaid.z"), "Missing batch*heads indexing");
    }
}
// A batch count of one is the degenerate case and must still emit cleanly.
#[test]
fn test_batched_gemm_minimum_batch() {
    let ptx = BatchedGemmKernel::naive(1, 32, 32, 32).emit_ptx();
    assert!(ptx.contains(".entry"), "Should handle batch=1");
}
// A transformer-attention-shaped problem (batch=2, heads=8, 512x512, d=64)
// must emit an entry point and barrier-synchronized tiled compute.
#[test]
fn test_batched_4d_gemm_attention_pattern() {
    let ptx = Batched4DGemmKernel::new(2, 8, 512, 512, 64).emit_ptx();
    assert!(ptx.contains(".entry"), "Should handle attention pattern");
    assert!(ptx.contains("bar.sync"), "Should have barriers for tiled compute");
}