use super::super::super::*;
/// The coalesced Q4_K GEMV kernel reports its registered name.
#[test]
fn test_coalesced_q4k_gemv_kernel_name() {
    let expected = "coalesced_q4k_gemv";
    assert_eq!(CoalescedQ4KGemvKernel::new(3584, 4096).name(), expected);
}
/// Constructor arguments are stored as the `k` and `n` fields.
#[test]
fn test_coalesced_q4k_gemv_config() {
    let kernel = CoalescedQ4KGemvKernel::new(3584, 4096);
    assert_eq!((kernel.k, kernel.n), (3584, 4096));
}
/// Generated PTX declares the entry point and performs global-memory loads.
#[test]
fn test_coalesced_q4k_gemv_ptx_generation() {
    let ptx = CoalescedQ4KGemvKernel::new(3584, 4096).emit_ptx();
    for &needle in &[".visible .entry coalesced_q4k_gemv", "ld.global"] {
        assert!(ptx.contains(needle), "PTX is missing {:?}", needle);
    }
}
/// The dp4a Q4_K GEMV kernel reports its registered name.
#[test]
fn test_dp4a_q4k_gemv_kernel_name() {
    let expected = "dp4a_q4k_gemv";
    assert_eq!(Dp4aQ4KGemvKernel::new(3584, 4096).name(), expected);
}
/// Constructor arguments are stored as the `k` and `n` fields.
#[test]
fn test_dp4a_q4k_gemv_config() {
    let kernel = Dp4aQ4KGemvKernel::new(3584, 4096);
    assert_eq!((kernel.k, kernel.n), (3584, 4096));
}
/// Generated PTX declares the `dp4a_q4k_gemv` entry point.
#[test]
fn test_dp4a_q4k_gemv_ptx_generation() {
    let ptx = Dp4aQ4KGemvKernel::new(3584, 4096).emit_ptx();
    // NOTE(review): the entry name itself contains "dp4a", so the second needle is
    // always satisfied and does not prove a dp4a instruction is emitted. Consider
    // checking for "dp4a." once it is confirmed this kernel uses the instruction.
    for &needle in &[".visible .entry dp4a_q4k_gemv", "dp4a"] {
        assert!(ptx.contains(needle), "PTX is missing {:?}", needle);
    }
}
/// The "true dp4a" Q4_K GEMV kernel reports its registered name.
#[test]
fn test_true_dp4a_q4k_gemv_kernel_name() {
    let expected = "true_dp4a_q4k_gemv";
    assert_eq!(TrueDp4aQ4KGemvKernel::new(3584, 4096).name(), expected);
}
/// Constructor arguments are stored as the `k` and `n` fields.
#[test]
fn test_true_dp4a_q4k_gemv_config() {
    let kernel = TrueDp4aQ4KGemvKernel::new(3584, 4096);
    assert_eq!((kernel.k, kernel.n), (3584, 4096));
}
/// Generated PTX declares the entry point and emits actual `dp4a` instructions.
#[test]
fn test_true_dp4a_q4k_gemv_ptx_generation() {
    let kernel = TrueDp4aQ4KGemvKernel::new(3584, 4096);
    let ptx = kernel.emit_ptx();
    assert!(ptx.contains(".visible .entry true_dp4a_q4k_gemv"));
    // A bare "dp4a" substring is already satisfied by the entry name above, so it
    // says nothing about code generation. PTX spells the instruction with type
    // suffixes (`dp4a.<atype>.<btype>`), so match the dotted mnemonic instead.
    assert!(
        ptx.contains("dp4a."),
        "PTX contains no dp4a instruction (only the entry name mentions dp4a)"
    );
}
/// The multi-warp vectorized Q4_K GEMV kernel reports its registered name.
#[test]
fn test_mwv_q4k_gemv_kernel_name() {
    let expected = "mwv_q4k_gemv";
    assert_eq!(MultiWarpVectorizedQ4KGemvKernel::new(3584, 4096).name(), expected);
}
/// Constructor arguments are stored as `k`/`n`, and the default warp count is 4.
#[test]
fn test_mwv_q4k_gemv_config() {
    let kernel = MultiWarpVectorizedQ4KGemvKernel::new(3584, 4096);
    assert_eq!((kernel.k, kernel.n, kernel.num_warps), (3584, 4096, 4));
}
/// Generated PTX declares the entry point, loads from global memory, and uses
/// warp shuffles.
#[test]
fn test_mwv_q4k_gemv_ptx_generation() {
    let ptx = MultiWarpVectorizedQ4KGemvKernel::new(3584, 4096).emit_ptx();
    for &needle in &[".visible .entry mwv_q4k_gemv", "ld.global", "shfl"] {
        assert!(ptx.contains(needle), "PTX is missing {:?}", needle);
    }
}
/// Generated PTX references the `.shared` state space.
#[test]
fn test_mwv_q4k_gemv_shared_memory() {
    let uses_shared = MultiWarpVectorizedQ4KGemvKernel::new(3584, 4096)
        .emit_ptx()
        .contains(".shared");
    assert!(uses_shared);
}
/// The multi-warp dp4a Q4_K GEMV kernel reports its registered name.
#[test]
fn test_mwv_dp4a_q4k_gemv_kernel_name() {
    let expected = "mwv_dp4a_q4k_gemv";
    assert_eq!(MwvDp4aQ4KGemvKernel::new(3584, 4096).name(), expected);
}
/// Constructor arguments are stored as `k`/`n`, and the default warp count is 3.
#[test]
fn test_mwv_dp4a_q4k_gemv_config() {
    let kernel = MwvDp4aQ4KGemvKernel::new(3584, 4096);
    assert_eq!((kernel.k, kernel.n, kernel.num_warps), (3584, 4096, 3));
}
/// Generated PTX declares the entry point, emits real `dp4a` instructions, loads
/// from global memory, and raises the register cap with `.maxnreg 255`.
#[test]
fn test_mwv_dp4a_q4k_gemv_ptx_generation() {
    let kernel = MwvDp4aQ4KGemvKernel::new(3584, 4096);
    let ptx = kernel.emit_ptx();
    assert!(ptx.contains(".visible .entry mwv_dp4a_q4k_gemv"));
    // The entry name already contains "dp4a", so a bare substring check would
    // always pass. PTX spells the instruction as `dp4a.<atype>.<btype>`.
    assert!(
        ptx.contains("dp4a."),
        "PTX contains no dp4a instruction (only the entry name mentions dp4a)"
    );
    assert!(ptx.contains("ld.global"));
    assert!(
        ptx.contains(".maxnreg 255"),
        "Must emit .maxnreg 255 for max ILP"
    );
}
/// Generated PTX references the `.shared` state space.
#[test]
fn test_mwv_dp4a_q4k_gemv_shared_memory() {
    let uses_shared = MwvDp4aQ4KGemvKernel::new(3584, 4096)
        .emit_ptx()
        .contains(".shared");
    assert!(uses_shared);
}
/// The wide Q4_K GEMV kernel reports its registered name.
#[test]
fn test_wide_q4k_gemv_kernel_name() {
    let expected = "wide_q4k_gemv";
    assert_eq!(WideQ4KGemvKernel::new(3584, 4096).name(), expected);
}
/// Constructor arguments are stored as `k`/`n`, and the default warp count is 8.
#[test]
fn test_wide_q4k_gemv_config() {
    let kernel = WideQ4KGemvKernel::new(3584, 4096);
    assert_eq!((kernel.k, kernel.n, kernel.num_warps), (3584, 4096, 8));
}
/// Generated PTX declares the entry point, loads from global memory, and uses
/// warp shuffles.
#[test]
fn test_wide_q4k_gemv_ptx_generation() {
    let ptx = WideQ4KGemvKernel::new(3584, 4096).emit_ptx();
    for &needle in &[".visible .entry wide_q4k_gemv", "ld.global", "shfl"] {
        assert!(ptx.contains(needle), "PTX is missing {:?}", needle);
    }
}
/// Generated PTX references the `.shared` state space.
#[test]
fn test_wide_q4k_gemv_shared_memory() {
    let uses_shared = WideQ4KGemvKernel::new(3584, 4096)
        .emit_ptx()
        .contains(".shared");
    assert!(uses_shared);
}
/// The Q8 quantization kernel reports its registered name.
#[test]
fn test_q8_quantize_kernel_name() {
    let expected = "q8_quantize";
    assert_eq!(Q8QuantizeKernel::new(3584).name(), expected);
}
/// The single constructor argument is stored as the `n` field.
#[test]
fn test_q8_quantize_config() {
    assert_eq!(Q8QuantizeKernel::new(3584).n, 3584);
}
/// Generated PTX declares the entry point and both loads from and stores to
/// global memory.
#[test]
fn test_q8_quantize_ptx_generation() {
    let ptx = Q8QuantizeKernel::new(3584).emit_ptx();
    for &needle in &[".visible .entry q8_quantize", "ld.global", "st.global"] {
        assert!(ptx.contains(needle), "PTX is missing {:?}", needle);
    }
}
/// The q8_quantize PTX keeps `.u32` (`%r`) and `.s32` (`%ri`) registers in separate
/// declarations, and exactly one `.reg` line declares the `%r` bank.
#[test]
fn test_q8_quantize_separate_signed_unsigned_registers() {
    let ptx = Q8QuantizeKernel::new(3584).emit_ptx();
    assert!(
        ptx.contains(".reg .u32 %r<"),
        "Missing .u32 register declaration in PTX"
    );
    assert!(
        ptx.contains(".reg .s32 %ri<"),
        "Missing .s32 register declaration in PTX"
    );
    // `%ri<` does not contain the substring ` %r<`, so only `%r` bank declarations count.
    let r_decls = ptx
        .lines()
        .map(str::trim)
        .filter(|line| line.starts_with(".reg") && line.contains(" %r<"))
        .count();
    assert_eq!(
        r_decls, 1,
        "Expected exactly 1 %r register declaration, got {}. PTX:\n{}",
        r_decls, ptx
    );
}
/// The Q4_K × Q8 dot-product kernel reports its registered name.
#[test]
fn test_q4k_q8_dot_kernel_name() {
    let expected = "q4k_q8_dot";
    assert_eq!(Q4KQ8DotKernel::new(3584, 4096).name(), expected);
}
/// Constructor arguments are stored as the `k` and `n` fields.
#[test]
fn test_q4k_q8_dot_config() {
    let kernel = Q4KQ8DotKernel::new(3584, 4096);
    assert_eq!((kernel.k, kernel.n), (3584, 4096));
}
/// Generated PTX declares the entry point, loads from global memory, and uses
/// warp shuffles.
#[test]
fn test_q4k_q8_dot_ptx_generation() {
    let ptx = Q4KQ8DotKernel::new(3584, 4096).emit_ptx();
    for &needle in &[".visible .entry q4k_q8_dot", "ld.global", "shfl"] {
        assert!(ptx.contains(needle), "PTX is missing {:?}", needle);
    }
}
/// The packed dp4a Q4_K × Q8 kernel reports its registered name.
#[test]
fn test_packed_dp4a_q4k_q8_kernel_name() {
    let expected = "packed_dp4a_q4k_q8";
    assert_eq!(PackedDp4aQ4KQ8Kernel::new(3584, 4096).name(), expected);
}
/// Constructor arguments are stored as the `k` and `n` fields.
#[test]
fn test_packed_dp4a_q4k_q8_config() {
    let kernel = PackedDp4aQ4KQ8Kernel::new(3584, 4096);
    assert_eq!((kernel.k, kernel.n), (3584, 4096));
}
/// Generated PTX declares the entry point and emits actual `dp4a` instructions.
#[test]
fn test_packed_dp4a_q4k_q8_ptx_generation() {
    let kernel = PackedDp4aQ4KQ8Kernel::new(3584, 4096);
    let ptx = kernel.emit_ptx();
    assert!(ptx.contains(".visible .entry packed_dp4a_q4k_q8"));
    // The entry name already contains "dp4a", so a bare substring check would
    // always pass. PTX spells the instruction as `dp4a.<atype>.<btype>`.
    assert!(
        ptx.contains("dp4a."),
        "PTX contains no dp4a instruction (only the entry name mentions dp4a)"
    );
}
/// Generated PTX contains warp shuffle instructions.
#[test]
fn test_packed_dp4a_q4k_q8_has_warp_shuffle() {
    let uses_shuffle = PackedDp4aQ4KQ8Kernel::new(3584, 4096)
        .emit_ptx()
        .contains("shfl");
    assert!(uses_shuffle);
}
/// The batched Q4_K GEMV kernel reports its warp-reduce variant name.
#[test]
fn test_batched_q4k_gemv_kernel_name() {
    let expected = "batched_q4k_gemv_warp_reduce";
    assert_eq!(BatchedQ4KGemvKernel::new(4096, 32000, 4).name(), expected);
}
/// Constructor arguments are stored, in order, as the `k`, `n` and `m` fields.
#[test]
fn test_batched_q4k_gemv_config() {
    let kernel = BatchedQ4KGemvKernel::new(4096, 32000, 4);
    assert_eq!((kernel.k, kernel.n, kernel.m), (4096, 32000, 4));
}
/// Generated PTX declares an entry point whose name begins with
/// `batched_q4k_gemv`, and the kernel loads from global memory.
#[test]
fn test_batched_q4k_gemv_ptx_generation() {
    let ptx = BatchedQ4KGemvKernel::new(4096, 4096, 4).emit_ptx();
    // Prefix match only: any suffix on the entry name is accepted.
    for &needle in &[".visible .entry batched_q4k_gemv", "ld.global"] {
        assert!(ptx.contains(needle), "PTX is missing {:?}", needle);
    }
}
/// Generated PTX for a batch of 2 contains warp shuffle instructions.
#[test]
fn test_batched_q4k_gemv_has_warp_shuffle() {
    let uses_shuffle = BatchedQ4KGemvKernel::new(4096, 4096, 2)
        .emit_ptx()
        .contains("shfl");
    assert!(uses_shuffle);
}
/// A 4096-wide row is reported as spanning 16 super-blocks.
#[test]
fn test_batched_q4k_gemv_num_super_blocks() {
    // Presumably 256 values per Q4_K super-block (4096 / 256 = 16); confirm against
    // the kernel's super-block size constant.
    let blocks = BatchedQ4KGemvKernel::new(4096, 4096, 4).num_super_blocks_per_row();
    assert_eq!(blocks, 16);
}