use super::super::super::*;
use proptest::prelude::*;
/// The coalesced Q6K GEMV kernel must report the name `coalesced_q6k_gemv`.
#[test]
fn test_coalesced_q6k_gemv_kernel_name() {
    let expected_name = "coalesced_q6k_gemv";
    let coalesced = CoalescedQ6KGemvKernel::new(4096, 32000);
    assert_eq!(coalesced.name(), expected_name);
}
/// The two constructor arguments must be readable back through the `k` and
/// `n` fields, in that order.
#[test]
fn test_coalesced_q6k_gemv_config() {
    let (k_arg, n_arg) = (4096, 32000);
    let coalesced = CoalescedQ6KGemvKernel::new(k_arg, n_arg);

    assert_eq!(coalesced.k, 4096);
    assert_eq!(coalesced.n, 32000);
}
/// The emitted PTX must declare the `coalesced_q6k_gemv` entry point and
/// contain at least one `ld.global` instruction.
#[test]
fn test_coalesced_q6k_gemv_ptx_generation() {
    let coalesced = CoalescedQ6KGemvKernel::new(4096, 4096);
    let ptx = coalesced.emit_ptx();

    // Entry directive first, then the global-load check.
    assert!(ptx.contains(".visible .entry coalesced_q6k_gemv"));
    assert!(ptx.contains("ld.global"));
}
/// Checks the super-block count per row for two values of `k`:
/// 4096 is expected to give 16 and 512 is expected to give 2.
#[test]
fn test_coalesced_q6k_gemv_num_super_blocks() {
    let wide = CoalescedQ6KGemvKernel::new(4096, 4096);
    assert_eq!(wide.num_super_blocks_per_row(), 16);

    let narrow = CoalescedQ6KGemvKernel::new(512, 4096);
    assert_eq!(narrow.num_super_blocks_per_row(), 2);
}
/// The emitted PTX must contain a `shfl.sync` instruction.
#[test]
fn test_coalesced_q6k_gemv_has_warp_shuffle() {
    let coalesced = CoalescedQ6KGemvKernel::new(4096, 4096);
    let uses_shuffle = coalesced.emit_ptx().contains("shfl.sync");

    assert!(uses_shuffle, "Should use warp shuffle for scale broadcast");
}
/// The emitted PTX must mention `sb_loop`, the marker this test uses to
/// detect the super-block loop.
#[test]
fn test_coalesced_q6k_gemv_has_loops() {
    let coalesced = CoalescedQ6KGemvKernel::new(4096, 4096);
    let has_sb_loop = coalesced.emit_ptx().contains("sb_loop");

    assert!(has_sb_loop, "Should have super-block loop");
}
/// Runs the barrier-safety analysis over the emitted PTX and requires it to
/// report the kernel as safe; any violations are printed on failure.
#[test]
fn test_coalesced_q6k_gemv_barrier_safety() {
    use crate::ptx::optimize::barrier_safety;

    let coalesced = CoalescedQ6KGemvKernel::new(4096, 4096);
    let ptx = coalesced.emit_ptx();
    let report = barrier_safety::analyze(&ptx);

    assert!(
        report.is_safe,
        "Coalesced Q6K GEMV should pass barrier safety: {:?}",
        report.violations
    );
}
/// For identical dimensions, the coalesced kernel's PTX must differ from the
/// PTX of `Q6KGemvKernel`, and must contain the substring `coalesced`.
#[test]
fn test_coalesced_q6k_gemv_vs_basic_different() {
    let coalesced_kernel = CoalescedQ6KGemvKernel::new(4096, 4096);
    let basic_kernel = Q6KGemvKernel::new(4096, 4096);

    let coalesced_ptx = coalesced_kernel.emit_ptx();
    let basic_ptx = basic_kernel.emit_ptx();

    assert_ne!(
        coalesced_ptx, basic_ptx,
        "Coalesced and basic kernels should produce different PTX"
    );

    let mentions_coalesced = coalesced_ptx.contains("coalesced");
    assert!(
        mentions_coalesced,
        "Coalesced PTX should contain 'coalesced' in entry name"
    );
}
// Property tests for the coalesced Q6K GEMV kernel. Both properties build `k`
// as a multiple of 256.
proptest! {
    // For every sampled (k, n), the emitted PTX is non-empty and declares the
    // `coalesced_q6k_gemv` entry point.
    #[test]
    fn prop_coalesced_q6k_generates_valid_ptx(
        k_factor in 1u32..8,
        n in 32u32..4096
    ) {
        let k_dim = k_factor * 256;
        let coalesced = CoalescedQ6KGemvKernel::new(k_dim, n);
        let ptx = coalesced.emit_ptx();

        prop_assert!(!ptx.is_empty());
        prop_assert!(ptx.contains(".visible .entry coalesced_q6k_gemv"));
    }

    // With k = k_factor * 256, the super-block count per row must equal
    // k_factor.
    #[test]
    fn prop_coalesced_q6k_super_blocks_correct(k_factor in 1u32..16) {
        let k_dim = k_factor * 256;
        let coalesced = CoalescedQ6KGemvKernel::new(k_dim, 64);

        prop_assert_eq!(coalesced.num_super_blocks_per_row(), k_factor);
    }
}
/// The batched Q6K GEMV kernel must report the name
/// `batched_q6k_gemv_warp_reduce`.
#[test]
fn test_batched_q6k_gemv_kernel_name() {
    let expected_name = "batched_q6k_gemv_warp_reduce";
    let batched = BatchedQ6KGemvKernel::new(4096, 32000, 8);
    assert_eq!(batched.name(), expected_name);
}
/// The three constructor arguments must be readable back through the `k`,
/// `n` and `m` fields, in that order.
#[test]
fn test_batched_q6k_gemv_config() {
    let (k_arg, n_arg, m_arg) = (4096, 32000, 8);
    let batched = BatchedQ6KGemvKernel::new(k_arg, n_arg, m_arg);

    assert_eq!(batched.k, 4096);
    assert_eq!(batched.n, 32000);
    assert_eq!(batched.m, 8);
}
/// Checks the super-block count per row for two values of `k`:
/// 4096 is expected to give 16 and 512 is expected to give 2.
#[test]
fn test_batched_q6k_gemv_num_super_blocks() {
    let wide = BatchedQ6KGemvKernel::new(4096, 4096, 4);
    assert_eq!(wide.num_super_blocks_per_row(), 16);

    let narrow = BatchedQ6KGemvKernel::new(512, 4096, 4);
    assert_eq!(narrow.num_super_blocks_per_row(), 2);
}
/// The emitted PTX must declare the `batched_q6k_gemv_warp_reduce` entry
/// point and contain at least one `ld.global` instruction.
#[test]
fn test_batched_q6k_gemv_ptx_generation() {
    let batched = BatchedQ6KGemvKernel::new(4096, 4096, 4);
    let generated = batched.emit_ptx();

    let declares_entry = generated.contains(".visible .entry batched_q6k_gemv_warp_reduce");
    assert!(declares_entry, "PTX should contain entry point");

    let loads_global = generated.contains("ld.global");
    assert!(loads_global, "Should have global memory loads");
}
/// The emitted PTX must declare a `.param .u32 m_dim` kernel parameter.
#[test]
fn test_batched_q6k_gemv_has_m_param() {
    let batched = BatchedQ6KGemvKernel::new(4096, 4096, 4);
    let declares_m_dim = batched.emit_ptx().contains(".param .u32 m_dim");

    assert!(
        declares_m_dim,
        "Should have m_dim parameter for batch size"
    );
}
/// The emitted PTX must contain a `shfl.sync.down` instruction.
#[test]
fn test_batched_q6k_gemv_has_warp_shuffle() {
    let batched = BatchedQ6KGemvKernel::new(4096, 4096, 4);
    let uses_shuffle_down = batched.emit_ptx().contains("shfl.sync.down");

    assert!(
        uses_shuffle_down,
        "Should use warp shuffle down for reduction"
    );
}
/// The emitted PTX must mention `sb_loop`, the marker this test uses to
/// detect the super-block loop.
#[test]
fn test_batched_q6k_gemv_has_loops() {
    let batched = BatchedQ6KGemvKernel::new(4096, 4096, 4);
    let has_sb_loop = batched.emit_ptx().contains("sb_loop");

    assert!(has_sb_loop, "Should have super-block loop");
}
/// Runs the barrier-safety analysis over the emitted PTX and requires it to
/// report the kernel as safe; any violations are printed on failure.
#[test]
fn test_batched_q6k_gemv_barrier_safety() {
    use crate::ptx::optimize::barrier_safety;

    let batched = BatchedQ6KGemvKernel::new(4096, 4096, 4);
    let ptx = batched.emit_ptx();
    let report = barrier_safety::analyze(&ptx);

    assert!(
        report.is_safe,
        "Batched Q6K GEMV should pass barrier safety: {:?}",
        report.violations
    );
}
/// For each batch size in {1, 2, 4, 8, 16, 32}, the emitted PTX must be
/// non-empty and must declare the `batched_q6k_gemv_warp_reduce` entry point.
#[test]
fn test_batched_q6k_gemv_different_batch_sizes() {
    let batch_sizes = [1, 2, 4, 8, 16, 32];

    for m in batch_sizes {
        let batched = BatchedQ6KGemvKernel::new(4096, 4096, m);
        let generated = batched.emit_ptx();

        let non_empty = !generated.is_empty();
        assert!(non_empty, "PTX should not be empty for m={m}");

        let declares_entry = generated.contains(".visible .entry batched_q6k_gemv_warp_reduce");
        assert!(declares_entry, "PTX should have entry point for m={m}");
    }
}
/// With a batch size of 4, the emitted PTX must contain at least four
/// occurrences of `st.global.f32`.
#[test]
fn test_batched_q6k_gemv_outputs_multiple_results() {
    let batched = BatchedQ6KGemvKernel::new(4096, 4096, 4);
    let generated = batched.emit_ptx();

    // `store_count` keeps its name because the failure message captures it.
    let store_count = generated.matches("st.global.f32").count();
    let has_enough_stores = store_count >= 4;

    assert!(
        has_enough_stores,
        "Should have at least 4 st.global.f32 for batch size 4, found {store_count}"
    );
}
/// For identical `k` and `n`, the batched kernel's PTX must differ from the
/// PTX of `Q6KGemvKernel`, and must contain the substring `batched`.
#[test]
fn test_batched_q6k_gemv_vs_non_batched_different() {
    let batched_kernel = BatchedQ6KGemvKernel::new(4096, 4096, 4);
    let plain_kernel = Q6KGemvKernel::new(4096, 4096);

    let batched_ptx = batched_kernel.emit_ptx();
    let plain_ptx = plain_kernel.emit_ptx();

    assert_ne!(
        batched_ptx, plain_ptx,
        "Batched and non-batched kernels should produce different PTX"
    );

    let mentions_batched = batched_ptx.contains("batched");
    assert!(
        mentions_batched,
        "Batched PTX should contain 'batched' in entry name"
    );
}
// Property tests for the batched Q6K GEMV kernel. Both properties build `k`
// as a multiple of 256.
proptest! {
    // For every sampled (k, n, m), the emitted PTX is non-empty, declares the
    // `batched_q6k_gemv_warp_reduce` entry point, and declares the `m_dim`
    // parameter.
    #[test]
    fn prop_batched_q6k_generates_valid_ptx(
        k_factor in 1u32..8,
        n in 32u32..4096,
        m in 1u32..16
    ) {
        let k_dim = k_factor * 256;
        let batched = BatchedQ6KGemvKernel::new(k_dim, n, m);
        let ptx = batched.emit_ptx();

        prop_assert!(!ptx.is_empty());
        prop_assert!(ptx.contains(".visible .entry batched_q6k_gemv_warp_reduce"));
        prop_assert!(ptx.contains(".param .u32 m_dim"));
    }

    // With k = k_factor * 256, the super-block count per row must equal
    // k_factor.
    #[test]
    fn prop_batched_q6k_super_blocks_correct(k_factor in 1u32..16) {
        let k_dim = k_factor * 256;
        let batched = BatchedQ6KGemvKernel::new(k_dim, 64, 4);

        prop_assert_eq!(batched.num_super_blocks_per_row(), k_factor);
    }
}