use super::*;
/// The naive 64x64x64 GEMM kernel must pass barrier-safety analysis;
/// any reported violations are printed on failure.
#[test]
fn test_barrier_safety_gemm_naive() {
    let gemm = GemmKernel::naive(64, 64, 64);
    let report = gemm.analyze_barrier_safety();
    assert!(
        report.is_safe,
        "GEMM naive should be barrier-safe: {:?}",
        report.violations
    );
}
/// The tiled 64x64x64 GEMM kernel (tile argument 16) must pass
/// barrier-safety analysis; any reported violations are printed on failure.
#[test]
fn test_barrier_safety_gemm_tiled() {
    let gemm = GemmKernel::tiled(64, 64, 64, 16);
    let report = gemm.analyze_barrier_safety();
    assert!(
        report.is_safe,
        "GEMM tiled should be barrier-safe: {:?}",
        report.violations
    );
}
/// The tensor-core 64x64x64 GEMM kernel must pass barrier-safety analysis;
/// any reported violations are printed on failure.
#[test]
fn test_barrier_safety_gemm_tensor_core() {
    let gemm = GemmKernel::tensor_core(64, 64, 64);
    let report = gemm.analyze_barrier_safety();
    assert!(
        report.is_safe,
        "GEMM tensor core should be barrier-safe: {:?}",
        report.violations
    );
}
/// The WMMA fp16 64x64x64 GEMM kernel must pass barrier-safety analysis;
/// any reported violations are printed on failure.
#[test]
fn test_barrier_safety_gemm_wmma() {
    let gemm = GemmKernel::wmma_fp16(64, 64, 64);
    let report = gemm.analyze_barrier_safety();
    assert!(
        report.is_safe,
        "GEMM WMMA should be barrier-safe: {:?}",
        report.violations
    );
}
/// The default attention kernel built with `new(64, 32)` must pass
/// barrier-safety analysis; any reported violations are printed on failure.
#[test]
fn test_barrier_safety_attention() {
    let attn = AttentionKernel::new(64, 32);
    let report = attn.analyze_barrier_safety();
    assert!(
        report.is_safe,
        "Attention should be barrier-safe: {:?}",
        report.violations
    );
}
/// The tensor-core attention kernel built with `tensor_core(64, 32)` must
/// pass barrier-safety analysis; any reported violations are printed on failure.
#[test]
fn test_barrier_safety_attention_tensor_core() {
    let attn = AttentionKernel::tensor_core(64, 32);
    let report = attn.analyze_barrier_safety();
    assert!(
        report.is_safe,
        "TC Attention should be barrier-safe: {:?}",
        report.violations
    );
}
/// The softmax kernel built with `new(1024)` must pass barrier-safety
/// analysis; any reported violations are printed on failure.
#[test]
fn test_barrier_safety_softmax() {
    let softmax = SoftmaxKernel::new(1024);
    let report = softmax.analyze_barrier_safety();
    assert!(
        report.is_safe,
        "Softmax should be barrier-safe: {:?}",
        report.violations
    );
}
/// The layer-norm kernel built with `new(512)` must pass barrier-safety
/// analysis; any reported violations are printed on failure.
#[test]
fn test_barrier_safety_layernorm() {
    let layernorm = LayerNormKernel::new(512);
    let report = layernorm.analyze_barrier_safety();
    assert!(
        report.is_safe,
        "LayerNorm should be barrier-safe: {:?}",
        report.violations
    );
}
/// `validate_barrier_safety` must return `Ok` for the naive 32x32x32 GEMM
/// kernel.
// NOTE(review): the error value is not printed on failure because its type
// is not visible here; confirm it implements `Debug` before switching to `expect`.
#[test]
fn test_validate_barrier_safety_ok() {
    let gemm = GemmKernel::naive(32, 32, 32);
    let validation = gemm.validate_barrier_safety();
    assert!(validation.is_ok(), "Safe kernel should pass validation");
}
/// `emit_ptx_validated` on the naive 32x32x32 GEMM kernel must produce PTX
/// text that contains an `.entry` directive.
#[test]
fn test_emit_ptx_validated_works() {
    let gemm = GemmKernel::naive(32, 32, 32);
    let ptx_text = gemm.emit_ptx_validated();
    assert!(ptx_text.contains(".entry"));
}
/// The multi-warp Q6K GEMV kernel built with `new(1536, 1536)` must pass
/// barrier-safety analysis; any reported violations are printed on failure.
#[test]
fn test_barrier_safety_mwv_q6k() {
    let gemv = MultiWarpQ6KGemvKernel::new(1536, 1536);
    let report = gemv.analyze_barrier_safety();
    assert!(
        report.is_safe,
        "MWV Q6K should be barrier-safe: {:?}",
        report.violations
    );
}
/// Every listed warp count for the 1536x1536 multi-warp Q6K GEMV kernel must
/// pass barrier-safety analysis. The failure message names the offending
/// warp count and lists the reported violations.
#[test]
fn test_barrier_safety_mwv_q6k_warp_variants() {
    let warp_counts = [1, 2, 3, 4, 6, 8];
    for warps in warp_counts {
        let report = MultiWarpQ6KGemvKernel::with_warps(1536, 1536, warps)
            .analyze_barrier_safety();
        assert!(
            report.is_safe,
            "MWV Q6K ({warps} warps) should be barrier-safe: {:?}",
            report.violations
        );
    }
}
/// Tensor-core GEMM kernels with square, non-tile-aligned dimensions
/// (17, 33, 100) must pass barrier-safety analysis. The failure message
/// reports the kernel's `name()` and its violations.
#[test]
fn test_barrier_safety_boundary_conditions() {
    let cube_dims = [17, 33, 100];
    let kernels = cube_dims.map(|dim| GemmKernel::tensor_core(dim, dim, dim));
    for gemm in kernels {
        let report = gemm.analyze_barrier_safety();
        assert!(
            report.is_safe,
            "Boundary case {} should be barrier-safe: {:?}",
            gemm.name(),
            report.violations
        );
    }
}