use super::super::*;
/// The Q5_K kernel reports the GGML-style entry-point name used in its PTX.
#[test]
fn test_q5k_kernel_name() {
    let q5k = Q5KKernel::new(1024, 1024, 4096);
    let expected = "q5k_gemm_ggml";
    assert_eq!(q5k.name(), expected);
}
/// `Q5KKernel::new` keeps the requested M/N/K and starts with a tile size of 32.
#[test]
fn test_q5k_kernel_config() {
    let q5k = Q5KKernel::new(1024, 1024, 4096);
    let dims = (q5k.m, q5k.n, q5k.k);
    assert_eq!(dims, (1024, 1024, 4096));
    assert_eq!(q5k.tile_size, 32);
}
/// Pins the Q5_K super-block layout: 256 values packed into 176 bytes.
#[test]
fn test_q5k_super_block_constants() {
    let (values, bytes) = (Q5K_SUPER_BLOCK_SIZE, Q5K_SUPER_BLOCK_BYTES);
    assert_eq!(values, 256, "Q5_K super-block should have 256 values");
    assert_eq!(
        bytes, 176,
        "Q5_K super-block should be 176 bytes (2+2+12+128+32)"
    );
}
/// A K dimension of 4096 spans 4096 / 256 = 16 Q5_K super-blocks per row.
#[test]
fn test_q5k_num_super_blocks() {
    let k = 4096;
    let kernel = Q5KKernel::new(1024, 1024, k);
    assert_eq!(kernel.num_super_blocks_per_row(), k / 256);
}
/// Generated Q5_K PTX names the kernel and declares every launch parameter.
#[test]
fn test_q5k_ptx_generation() {
    let kernel = Q5KKernel::new(1024, 1024, 4096);
    let ptx = kernel.emit_ptx();
    assert!(ptx.contains("q5k_gemm_ggml"), "Should contain Q5_K kernel name");

    let expected_params = [
        ".param .u64 a_ptr",
        ".param .u64 b_quant_ptr",
        ".param .u64 c_ptr",
        ".param .u32 m",
        ".param .u32 n",
        ".param .u32 k",
    ];
    for decl in expected_params.iter() {
        assert!(ptx.contains(*decl));
    }
}
/// `with_tile_size` overrides only the tile size and leaves the GEMM shape untouched.
#[test]
fn test_q5k_with_tile_size() {
    let tiled = Q5KKernel::new(1024, 1024, 4096).with_tile_size(64);
    assert_eq!(tiled.tile_size, 64);
    assert_eq!((tiled.m, tiled.n, tiled.k), (1024, 1024, 4096));
}
/// Kernels with the default tile size and a 64-wide tile both emit PTX under the
/// `q5k_gemm_ggml` entry name.
///
/// NOTE(review): despite its name, this test does not check that the two PTX
/// strings differ. Confirm whether `tile_size` is meant to change the emitted code.
#[test]
fn test_q5k_with_tile_size_affects_ptx() {
    let default_tile = Q5KKernel::new(1024, 1024, 4096);
    let wide_tile = Q5KKernel::new(1024, 1024, 4096).with_tile_size(64);
    let emitted = [default_tile.emit_ptx(), wide_tile.emit_ptx()];
    for ptx in emitted.iter() {
        assert!(ptx.contains("q5k_gemm_ggml"));
    }
}
/// Q5_K PTX contains both the outer super-block loop label and the inner sub-block loop label.
#[test]
fn test_q5k_ptx_contains_nested_loops() {
    let kernel = Q5KKernel::new(1024, 1024, 4096);
    let ptx = kernel.emit_ptx();
    let has_label = |label: &str| ptx.contains(label);
    assert!(has_label("sb_loop"), "Should have super-block loop");
    assert!(has_label("sub_block_loop"), "Should have sub-block loop");
}
/// Q5_K decoding reads its scales, low bits (ql) and high bits (qh) byte-wise.
/// The PTX must therefore contain at least four `ld.global.u8` loads.
#[test]
fn test_q5k_ptx_contains_high_bit_load() {
    let kernel = Q5KKernel::new(1024, 1024, 4096);
    let byte_loads = kernel.emit_ptx().matches("ld.global.u8").count();
    assert!(
        byte_loads >= 4,
        "Q5_K should have multiple u8 loads for scales, ql, and qh. Found {}",
        byte_loads
    );
}
/// Each Q5_K loop label should occur at least twice: once where it is defined
/// and once as the target of the backward branch.
#[test]
fn test_q5k_both_loops_branch_back() {
    let kernel = Q5KKernel::new(1024, 1024, 4096);
    let ptx = kernel.emit_ptx();
    let occurrences = |needle: &str| ptx.matches(needle).count();

    let outer = occurrences("sb_loop");
    assert!(
        outer >= 2,
        "sb_loop should appear at least twice (label + branch back), found {}",
        outer
    );

    let inner = occurrences("sub_block_loop");
    assert!(
        inner >= 2,
        "sub_block_loop should appear at least twice (label + branch back), found {}",
        inner
    );
}
/// The Q6_K kernel reports the GGML-style entry-point name used in its PTX.
#[test]
fn test_q6k_kernel_name() {
    let q6k = Q6KKernel::new(1024, 1024, 4096);
    let expected = "q6k_gemm_ggml";
    assert_eq!(q6k.name(), expected);
}
/// `Q6KKernel::new` keeps the requested M/N/K and starts with a tile size of 32.
#[test]
fn test_q6k_kernel_config() {
    let q6k = Q6KKernel::new(1024, 1024, 4096);
    let dims = (q6k.m, q6k.n, q6k.k);
    assert_eq!(dims, (1024, 1024, 4096));
    assert_eq!(q6k.tile_size, 32);
}
/// Pins the Q6_K super-block layout: 256 values packed into 210 bytes.
#[test]
fn test_q6k_super_block_constants() {
    let (values, bytes) = (Q6K_SUPER_BLOCK_SIZE, Q6K_SUPER_BLOCK_BYTES);
    assert_eq!(values, 256, "Q6_K super-block should have 256 values");
    assert_eq!(
        bytes, 210,
        "Q6_K super-block should be 210 bytes (128+64+16+2)"
    );
}
/// A K dimension of 4096 spans 4096 / 256 = 16 Q6_K super-blocks per row.
#[test]
fn test_q6k_num_super_blocks() {
    let k = 4096;
    let kernel = Q6KKernel::new(1024, 1024, k);
    assert_eq!(kernel.num_super_blocks_per_row(), k / 256);
}
/// Generated Q6_K PTX names the kernel and declares every launch parameter.
#[test]
fn test_q6k_ptx_generation() {
    let kernel = Q6KKernel::new(1024, 1024, 4096);
    let ptx = kernel.emit_ptx();
    assert!(ptx.contains("q6k_gemm_ggml"), "Should contain Q6_K kernel name");

    let expected_params = [
        ".param .u64 a_ptr",
        ".param .u64 b_quant_ptr",
        ".param .u64 c_ptr",
        ".param .u32 m",
        ".param .u32 n",
        ".param .u32 k",
    ];
    for decl in expected_params.iter() {
        assert!(ptx.contains(*decl));
    }
}
/// `with_tile_size` overrides only the tile size and leaves the GEMM shape untouched.
#[test]
fn test_q6k_with_tile_size() {
    let tiled = Q6KKernel::new(1024, 1024, 4096).with_tile_size(64);
    assert_eq!(tiled.tile_size, 64);
    assert_eq!((tiled.m, tiled.n, tiled.k), (1024, 1024, 4096));
}
/// Kernels with the default tile size and a 64-wide tile both emit PTX under the
/// `q6k_gemm_ggml` entry name.
///
/// NOTE(review): despite its name, this test does not check that the two PTX
/// strings differ. Confirm whether `tile_size` is meant to change the emitted code.
#[test]
fn test_q6k_with_tile_size_affects_ptx() {
    let default_tile = Q6KKernel::new(1024, 1024, 4096);
    let wide_tile = Q6KKernel::new(1024, 1024, 4096).with_tile_size(64);
    let emitted = [default_tile.emit_ptx(), wide_tile.emit_ptx()];
    for ptx in emitted.iter() {
        assert!(ptx.contains("q6k_gemm_ggml"));
    }
}
/// Q6_K PTX contains both the outer super-block loop label and the inner sub-block loop label.
#[test]
fn test_q6k_ptx_contains_nested_loops() {
    let kernel = Q6KKernel::new(1024, 1024, 4096);
    let ptx = kernel.emit_ptx();
    let has_label = |label: &str| ptx.contains(label);
    assert!(has_label("sb_loop"), "Should have super-block loop");
    assert!(has_label("sub_block_loop"), "Should have sub-block loop");
}
/// Q6_K extracts the 2-bit high part of each quant with a bitwise mask.
/// The generated PTX must therefore contain an AND instruction.
#[test]
fn test_q6k_ptx_contains_2bit_high_extraction() {
    let kernel = Q6KKernel::new(1024, 1024, 4096);
    let ptx = kernel.emit_ptx();
    // A PTX bitwise AND always carries a type suffix (`and.b16`/`and.b32`/`and.b64`).
    // Matching the bare substring "and" would also accept unrelated text such as
    // "operand" or a label name, which makes the check close to vacuous.
    assert!(ptx.contains("and.b"), "Should have AND for bit masking");
}
/// Q6_K PTX must emit an f32 subtraction (either `sub.f32` or `sub.rn.f32`)
/// to apply the signed offset to decoded values.
#[test]
fn test_q6k_ptx_contains_signed_offset() {
    let kernel = Q6KKernel::new(1024, 1024, 4096);
    let ptx = kernel.emit_ptx();
    let has_fsub = ["sub.f32", "sub.rn.f32"]
        .iter()
        .any(|op| ptx.contains(*op));
    assert!(has_fsub, "Should have subtraction for signed offset");
}
/// Each Q6_K loop label should occur at least twice: once where it is defined
/// and once as the target of the backward branch.
#[test]
fn test_q6k_both_loops_branch_back() {
    let kernel = Q6KKernel::new(1024, 1024, 4096);
    let ptx = kernel.emit_ptx();
    let occurrences = |needle: &str| ptx.matches(needle).count();

    let outer = occurrences("sb_loop");
    assert!(
        outer >= 2,
        "sb_loop should appear at least twice (label + branch back), found {}",
        outer
    );

    let inner = occurrences("sub_block_loop");
    assert!(
        inner >= 2,
        "sub_block_loop should appear at least twice (label + branch back), found {}",
        inner
    );
}
/// Q4_K, Q5_K and Q6_K generators must each produce distinct PTX for the same GEMM shape.
#[test]
fn test_all_quant_kernels_different() {
    let q4k_kernel = QuantizeKernel::ggml(1024, 1024, 4096);
    let ptx_q4k = q4k_kernel.emit_ptx();

    let q5k_kernel = Q5KKernel::new(1024, 1024, 4096);
    let ptx_q5k = q5k_kernel.emit_ptx();

    let q6k_kernel = Q6KKernel::new(1024, 1024, 4096);
    let ptx_q6k = q6k_kernel.emit_ptx();

    assert_ne!(ptx_q4k, ptx_q5k, "Q4_K and Q5_K should produce different PTX");
    assert_ne!(ptx_q4k, ptx_q6k, "Q4_K and Q6_K should produce different PTX");
    assert_ne!(ptx_q5k, ptx_q6k, "Q5_K and Q6_K should produce different PTX");
}
use proptest::prelude::*;
proptest! {
    #![proptest_config(ProptestConfig::with_cases(32))]

    // For any M/N and any K that is a multiple of the 256-value super-block,
    // Q5_K emits a complete kernel containing both nested loop labels.
    #[test]
    fn prop_q5k_valid_ptx_for_any_size(
        m in 32u32..512,
        n in 32u32..512,
        k_factor in 1u32..8
    ) {
        let k = k_factor * 256;
        let kernel = Q5KKernel::new(m, n, k);
        let ptx = kernel.emit_ptx();
        prop_assert!(!ptx.is_empty());
        prop_assert!(ptx.contains("q5k_gemm_ggml"));
        prop_assert!(ptx.contains(".entry"));
        prop_assert!(ptx.contains("ret;"));
        prop_assert!(ptx.contains("sb_loop"));
        prop_assert!(ptx.contains("sub_block_loop"));
    }

    // Q6_K counterpart of the property above. It also requires the f32 subtraction
    // used for the signed offset, and checks the same nested loop labels that
    // `test_q6k_ptx_contains_nested_loops` requires at the fixed 1024x1024x4096 size.
    #[test]
    fn prop_q6k_valid_ptx_for_any_size(
        m in 32u32..512,
        n in 32u32..512,
        k_factor in 1u32..8
    ) {
        let k = k_factor * 256;
        let kernel = Q6KKernel::new(m, n, k);
        let ptx = kernel.emit_ptx();
        prop_assert!(!ptx.is_empty());
        prop_assert!(ptx.contains("q6k_gemm_ggml"));
        prop_assert!(ptx.contains(".entry"));
        prop_assert!(ptx.contains("ret;"));
        prop_assert!(ptx.contains("sub.f32") || ptx.contains("sub.rn.f32"));
        prop_assert!(ptx.contains("sb_loop"));
        prop_assert!(ptx.contains("sub_block_loop"));
    }

    // The super-block count per row is exactly K / 256 for Q5_K.
    #[test]
    fn prop_q5k_super_blocks_correct(k_factor in 1u32..16) {
        let k = k_factor * 256;
        let kernel = Q5KKernel::new(64, 64, k);
        prop_assert_eq!(kernel.num_super_blocks_per_row(), k_factor);
    }

    // The super-block count per row is exactly K / 256 for Q6_K.
    #[test]
    fn prop_q6k_super_blocks_correct(k_factor in 1u32..16) {
        let k = k_factor * 256;
        let kernel = Q6KKernel::new(64, 64, k);
        prop_assert_eq!(kernel.num_super_blocks_per_row(), k_factor);
    }

    // The matrix-vector case (N = 1) still yields a named entry point for both formats.
    #[test]
    fn prop_q5k_q6k_matvec_n1(m in 32u32..512, k_factor in 1u32..8) {
        let k = k_factor * 256;
        let q5k = Q5KKernel::new(m, 1, k);
        let ptx_q5k = q5k.emit_ptx();
        prop_assert!(ptx_q5k.contains("q5k_gemm_ggml"));
        prop_assert!(ptx_q5k.contains(".entry"));
        let q6k = Q6KKernel::new(m, 1, k);
        let ptx_q6k = q6k.emit_ptx();
        prop_assert!(ptx_q6k.contains("q6k_gemm_ggml"));
        prop_assert!(ptx_q6k.contains(".entry"));
    }

    // The three quantization formats never collapse to identical PTX.
    // `prop_assert_ne!` reports both values on failure, unlike `prop_assert!(a != b)`.
    // References are passed so the owned PTX strings are not moved into the macro.
    #[test]
    fn prop_all_quant_kernels_distinct(
        m in 64u32..256,
        n in 64u32..256,
        k_factor in 1u32..4
    ) {
        let k = k_factor * 256;
        let q4k = QuantizeKernel::ggml(m, n, k);
        let q5k = Q5KKernel::new(m, n, k);
        let q6k = Q6KKernel::new(m, n, k);
        let ptx_q4k = q4k.emit_ptx();
        let ptx_q5k = q5k.emit_ptx();
        let ptx_q6k = q6k.emit_ptx();
        prop_assert_ne!(&ptx_q4k, &ptx_q5k);
        prop_assert_ne!(&ptx_q4k, &ptx_q6k);
        prop_assert_ne!(&ptx_q5k, &ptx_q6k);
    }
}