use crate::kernels::quantize::{Kernel, Q6K_SUPER_BLOCK_BYTES, Q6K_SUPER_BLOCK_SIZE};
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl, PtxMemory};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Descriptor for the `dp4a_q6k_gemv` PTX kernel: a GEMV over Q6_K-quantized
/// weights and Q8_1-quantized activations, built via [`Kernel::build_ptx`].
///
/// Plain configuration data; all fields are `u32`, so the type is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Dp4aQ6KGemvKernel {
    /// Reduction (input) dimension; rounded up to whole Q6_K super-blocks
    /// when the kernel is emitted.
    pub k: u32,
    /// Number of output rows of the weight matrix.
    pub n: u32,
    /// Warps per thread block; expected to be one of {1, 2, 3, 4, 6, 8}
    /// (checked by `with_warps` in debug builds).
    pub num_warps: u32,
}
impl Dp4aQ6KGemvKernel {
    /// Creates a descriptor with the default of 3 warps per block.
    #[must_use]
    pub fn new(k: u32, n: u32) -> Self {
        Self::with_warps(k, n, 3)
    }

    /// Creates a descriptor with an explicit warp count.
    ///
    /// In debug builds, asserts that `num_warps` is one of the supported
    /// counts {1, 2, 3, 4, 6, 8}.
    #[must_use]
    pub fn with_warps(k: u32, n: u32, num_warps: u32) -> Self {
        debug_assert!(
            [1, 2, 3, 4, 6, 8].contains(&num_warps),
            "num_warps should be in {{1,2,3,4,6,8}}, got {num_warps}"
        );
        Self { k, n, num_warps }
    }
}
impl Kernel for Dp4aQ6KGemvKernel {
    /// Name of the generated PTX entry point.
    fn name(&self) -> &str {
        "dp4a_q6k_gemv"
    }

    /// Emits the PTX for the Q6_K × Q8_1 GEMV.
    ///
    /// Work distribution (as generated below):
    /// - Each thread block walks output rows `block_id, block_id + gridDim.x, …`
    ///   — a grid-stride loop over rows (`row_idx` advances by `gridDim.x`).
    /// - Within a row, warps split the Q6_K super-blocks round-robin:
    ///   `sb_idx` starts at `warp_id` and advances by `num_warps`.
    /// - Each warp tree-reduces its partial with `shfl.down` (16/8/4/2/1);
    ///   lane 0 stores one f32 per warp into shared memory, and thread 0
    ///   sums the per-warp partials and writes `y[row]`.
    ///
    /// Byte offsets used below assume this Q6_K super-block layout: `ql`
    /// (low 4 bits) at 0, `qh` (high 2 bits) at 128, 16 signed scale bytes
    /// at 192, f16 super-scale `d` at 208; and a 36-byte Q8_1 block of
    /// 32 quant bytes followed by its f16 scale at offset 32.
    /// NOTE(review): these offsets must match the quantizer that produced
    /// the `w_ptr`/`q8_ptr` buffers — confirm against the quantize module.
    fn build_ptx(&self) -> PtxKernel {
        let num_warps = self.num_warps;
        // One f32 partial sum per warp in shared memory.
        let smem_size = (num_warps * 4) as usize;
        PtxKernel::new("dp4a_q6k_gemv")
            .param(PtxType::U64, "y_ptr")
            .param(PtxType::U64, "w_ptr")
            .param(PtxType::U64, "q8_ptr")
            .param(PtxType::U32, "k_dim")
            .param(PtxType::U32, "n_dim")
            .shared_memory(smem_size)
            .max_regs(255)
            .build(move |ctx| {
                // Thread/block identity. Lane/warp are derived from tid.x
                // assuming a 32-wide warp.
                let block_id = ctx.special_reg(PtxReg::CtaIdX);
                let thread_id = ctx.special_reg(PtxReg::TidX);
                let lane_id = ctx.rem_u32(thread_id, 32);
                let warp_id = ctx.div_u32(thread_id, 32);
                let grid_dim = ctx.special_reg(PtxReg::NctaIdX);
                let n_dim = ctx.load_param_u32("n_dim");
                let k_dim = ctx.load_param_u32("k_dim");
                let y_ptr = ctx.load_param_u64("y_ptr");
                let w_ptr = ctx.load_param_u64("w_ptr");
                let q8_ptr = ctx.load_param_u64("q8_ptr");
                // Round k up to whole super-blocks; a row of weights is
                // num_super_blocks * Q6K_SUPER_BLOCK_BYTES bytes.
                let k_rounded = ctx.add_u32(k_dim, Q6K_SUPER_BLOCK_SIZE - 1);
                let num_super_blocks = ctx.div_u32(k_rounded, Q6K_SUPER_BLOCK_SIZE);
                let sb_bytes_c = ctx.mov_u32_imm(Q6K_SUPER_BLOCK_BYTES);
                let row_bytes = ctx.mul_u32_reg(num_super_blocks, sb_bytes_c);
                // Grid-stride loop over rows, starting at this block's id.
                let row_idx = ctx.mov_u32_imm(0);
                ctx.add_u32_reg_inplace(row_idx, block_id);
                ctx.label("dp4a_q6k_row_loop");
                let row_oob = ctx.setp_ge_u32(row_idx, n_dim);
                ctx.branch_if(row_oob, "dp4a_q6k_exit");
                let row_offset = ctx.mul_wide_u32_reg(row_idx, row_bytes);
                let row_base = ctx.add_u64(w_ptr, row_offset);
                // Hint the next row this block will process into L2 while
                // the current row is being consumed.
                let next_row = ctx.add_u32_reg(row_idx, grid_dim);
                let next_offset = ctx.mul_wide_u32_reg(next_row, row_bytes);
                let next_base = ctx.add_u64(w_ptr, next_offset);
                ctx.prefetch_global_l2(next_base);
                let acc = ctx.mov_f32_imm(0.0);
                // Small immediates reused as shift amounts / masks below.
                let one_c = ctx.mov_u32_imm(1);
                let two_c = ctx.mov_u32_imm(2);
                let four_c = ctx.mov_u32_imm(4);
                let sixteen_c = ctx.mov_u32_imm(16);
                let thirty_two_c = ctx.mov_u32_imm(32);
                // Per-lane addressing: each lane handles one 4-byte chunk.
                //   sub_block_local = lane / 4   (which 16-value sub-block, 0..7)
                //   pos             = (lane % 4) * 4  (byte offset of the chunk)
                let sub_block_local = ctx.shr_u32(lane_id, two_c);
                let three_c = ctx.mov_u32_imm(3);
                let lane_mod4 = ctx.and_u32(lane_id, three_c);
                let pos = ctx.shl_u32(lane_mod4, two_c);
                // Decompose sub_block_local into a group-of-32 index and a
                // low/high-half bit; pos_in_group = byte 0..31 within the group.
                let group_of_32 = ctx.shr_u32(sub_block_local, one_c);
                let group_is_odd = ctx.and_u32(sub_block_local, one_c);
                let gio_x16 = ctx.mul_u32_reg(group_is_odd, sixteen_c);
                let pos_in_group = ctx.add_u32_reg(gio_x16, pos);
                // ql byte offset within a 128-byte half: (group & 1) * 32 + pos_in_group.
                let group_low_bit = ctx.and_u32(group_of_32, one_c);
                let ql_group_offset = ctx.mul_u32_reg(group_low_bit, thirty_two_c);
                let ql_base_in_half = ctx.add_u32_reg(ql_group_offset, pos_in_group);
                // Low nibbles vs high nibbles of the ql byte: shift 0 or 4.
                let group_div2 = ctx.shr_u32(group_of_32, one_c);
                let nibble_shift = ctx.shl_u32(group_div2, two_c);
                // qh holds four 2-bit fields per byte: shift 0/2/4/6 by group.
                let qh_shift = ctx.shl_u32(group_of_32, one_c);
                let mask_0f = ctx.mov_u32_imm(0x0F0F_0F0F);
                let mask_03 = ctx.mov_u32_imm(0x0303_0303);
                // 0x01010101 with dp4a computes the byte-sum of an i32 lane.
                let ones_packed = ctx.mov_u32_imm(0x0101_0101);
                let five_c = ctx.mov_u32_imm(5);
                // Three bits that select this lane's scale out of the 8
                // scales of one 128-value half (used by the selp tree below):
                //   bit0 = sub_block & 1, bit1 = group & 1, bit2 = sub_block >> 2.
                let bit0 = group_is_odd;
                let bit1_val = ctx.and_u32(group_of_32, one_c);
                let bit2_val = ctx.shr_u32(sub_block_local, two_c);
                let zero_c = ctx.mov_u32_imm(0);
                let p_bit0 = ctx.setp_ne_u32(bit0, zero_c);
                let p_bit1 = ctx.setp_ne_u32(bit1_val, zero_c);
                let p_bit2 = ctx.setp_ne_u32(bit2_val, zero_c);
                // Warps interleave super-blocks: sb_idx = warp_id, warp_id + num_warps, …
                let sb_idx = ctx.mov_u32_imm(0);
                ctx.add_u32_reg_inplace(sb_idx, warp_id);
                let nw_reg = ctx.mov_u32_imm(num_warps);
                ctx.label("dp4a_q6k_sb_loop");
                let sb_done = ctx.setp_ge_u32(sb_idx, num_super_blocks);
                ctx.branch_if(sb_done, "dp4a_q6k_sb_end");
                let sb_off = ctx.mul_wide_u32(sb_idx, Q6K_SUPER_BLOCK_BYTES);
                let sb_addr = ctx.add_u64(row_base, sb_off);
                // Super-block scale d: f16 at byte offset 208.
                let d_offset = ctx.mov_u64_imm(208);
                let d_addr = ctx.add_u64(sb_addr, d_offset);
                let d_f16 = ctx.ld_global_f16(d_addr);
                let d = ctx.cvt_f32_f16(d_f16);
                // 16 per-sub-block scale bytes at offset 192: lanes 0..15
                // each load one byte, then all lanes get all 16 via shfl.
                let scales_base_offset = ctx.mov_u64_imm(192);
                let scales_base = ctx.add_u64(sb_addr, scales_base_offset);
                let lane_mod_16 = ctx.rem_u32(lane_id, 16);
                let lane_offset_64 = ctx.cvt_u64_u32(lane_mod_16);
                let scale_addr = ctx.add_u64(scales_base, lane_offset_64);
                let my_scale_byte = ctx.mov_u32_imm(0);
                let is_low_lane = ctx.setp_lt_u32(lane_id, sixteen_c);
                ctx.branch_if_not(is_low_lane, "dp4a_q6k_skip_scale");
                let scale_u8 = ctx.ld_global_u8(scale_addr);
                let scale_u32 = ctx.cvt_u32_u8(scale_u8);
                ctx.mov_u32_reg(my_scale_byte, scale_u32);
                ctx.label("dp4a_q6k_skip_scale");
                // Broadcast scale byte i from lane i to every lane.
                let mut scale_regs = Vec::with_capacity(16);
                for i in 0..16u32 {
                    scale_regs.push(ctx.shfl_idx_u32(my_scale_byte, i, 0xFFFF_FFFF));
                }
                let seven_c = ctx.mov_u32_imm(7);
                let twofiftysix_f32 = ctx.mov_f32_imm(256.0);
                // Reinterpret each scale byte as i8 (value - 256 * sign_bit)
                // and pre-multiply by d: ds[i] = d * (i8)scale[i].
                let mut ds = Vec::with_capacity(16);
                for &sr in &scale_regs {
                    let sign_bit = ctx.shr_u32(sr, seven_c);
                    let raw_f32 = ctx.cvt_f32_u32(sr);
                    let sign_f32 = ctx.cvt_f32_u32(sign_bit);
                    let correction = ctx.mul_f32(sign_f32, twofiftysix_f32);
                    let signed_f32 = ctx.sub_f32(raw_f32, correction);
                    ds.push(ctx.mul_f32(d, signed_f32));
                }
                // One Q6_K super-block spans 8 Q8_1 blocks of 36 bytes each.
                let eight_c = ctx.mov_u32_imm(8);
                let sb_q8_blocks = ctx.mul_u32_reg(sb_idx, eight_c);
                let thirty_six_c = ctx.mov_u32_imm(36);
                let q8_sb_offset = ctx.mul_wide_u32_reg(sb_q8_blocks, thirty_six_c);
                let q8_sb_base = ctx.add_u64(q8_ptr, q8_sb_offset);
                // Two halves of the super-block (n_idx 0: ql bytes 0..63,
                // n_idx 1: ql bytes 64..127; qh halves at 128 and 160).
                for n_idx in 0..2u32 {
                    let ql_full_offset =
                        if n_idx == 0 { ql_base_in_half } else { ctx.add_u32(ql_base_in_half, 64) };
                    let ql_off_64 = ctx.cvt_u64_u32(ql_full_offset);
                    let ql_addr = ctx.add_u64(sb_addr, ql_off_64);
                    let ql_int32 = ctx.ld_global_u32_unaligned(ql_addr);
                    // Select low or high nibbles -> 4 packed 4-bit values.
                    let ql_shifted = ctx.shr_u32(ql_int32, nibble_shift);
                    let ql_nibs = ctx.and_u32(ql_shifted, mask_0f);
                    let qh_base_val = 128 + 32 * n_idx;
                    let qh_full_offset = ctx.add_u32(pos_in_group, qh_base_val);
                    let qh_off_64 = ctx.cvt_u64_u32(qh_full_offset);
                    let qh_addr = ctx.add_u64(sb_addr, qh_off_64);
                    let qh_int32 = ctx.ld_global_u32_unaligned(qh_addr);
                    // Extract the 2 high bits per value and merge:
                    // combined = ql | (qh << 4), four unsigned 6-bit values.
                    let qh_shifted = ctx.shr_u32(qh_int32, qh_shift);
                    let qh_2bits = ctx.and_u32(qh_shifted, mask_03);
                    let qh_up = ctx.shl_u32(qh_2bits, four_c);
                    let combined = ctx.or_u32(ql_nibs, qh_up);
                    // Matching Q8_1 block and this lane's 4 activation bytes.
                    let sbl_div2 = ctx.shr_u32(sub_block_local, one_c);
                    let q8_block_idx = ctx.add_u32(sbl_div2, n_idx * 4);
                    let q8_block_off = ctx.mul_wide_u32_reg(q8_block_idx, thirty_six_c);
                    let q8_block_addr = ctx.add_u64(q8_sb_base, q8_block_off);
                    let pig_64 = ctx.cvt_u64_u32(pos_in_group);
                    let q8_qs_addr = ctx.add_u64(q8_block_addr, pig_64);
                    let q8_int32 = ctx.ld_global_u32(q8_qs_addr);
                    // dot = sum(combined[b] * q8[b]); sum = sum(q8[b]).
                    let dot_acc = ctx.mov_u32_imm(0);
                    ctx.dp4a_u32_s32_inplace(dot_acc, combined, q8_int32);
                    let sum_acc = ctx.mov_u32_imm(0);
                    ctx.dp4a_u32_s32_inplace(sum_acc, ones_packed, q8_int32);
                    // dot - 32*sum == sum((combined[b] - 32) * q8[b]),
                    // i.e. apply the Q6_K -32 zero-point without unpacking.
                    let sum_x32 = ctx.shl_u32(sum_acc, five_c);
                    let int_result = ctx.sub_u32(dot_acc, sum_x32);
                    let result_f32 = ctx.cvt_f32_s32(int_result);
                    // Q8_1 block scale: f16 at offset 32 within the block.
                    let q8_d_off = ctx.mov_u64_imm(32);
                    let q8_d_addr = ctx.add_u64(q8_block_addr, q8_d_off);
                    let q8_d_f16 = ctx.ld_global_f16(q8_d_addr);
                    let q8_d = ctx.cvt_f32_f16(q8_d_f16);
                    // Pick ds[base + (bit2<<2 | bit1<<1 | bit0)] via a 3-level
                    // selp tree (branchless per-lane scale selection).
                    let base = n_idx as usize * 8;
                    let t_01 = ctx.selp_f32(p_bit0, ds[base + 1], ds[base]);
                    let t_23 = ctx.selp_f32(p_bit0, ds[base + 3], ds[base + 2]);
                    let t_45 = ctx.selp_f32(p_bit0, ds[base + 5], ds[base + 4]);
                    let t_67 = ctx.selp_f32(p_bit0, ds[base + 7], ds[base + 6]);
                    let t_03 = ctx.selp_f32(p_bit1, t_23, t_01);
                    let t_47 = ctx.selp_f32(p_bit1, t_67, t_45);
                    let ds_selected = ctx.selp_f32(p_bit2, t_47, t_03);
                    let scale_product = ctx.mul_f32(ds_selected, q8_d);
                    ctx.fma_f32_inplace(acc, scale_product, result_f32);
                }
                ctx.add_u32_reg_inplace(sb_idx, nw_reg);
                ctx.branch("dp4a_q6k_sb_loop");
                ctx.label("dp4a_q6k_sb_end");
                // Warp-level tree reduction of acc via shfl.down.
                let t16 = ctx.shfl_down_f32(acc, 16, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, t16);
                let t8 = ctx.shfl_down_f32(acc, 8, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, t8);
                let t4_r = ctx.shfl_down_f32(acc, 4, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, t4_r);
                let t2 = ctx.shfl_down_f32(acc, 2, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, t2);
                let t1 = ctx.shfl_down_f32(acc, 1, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, t1);
                // Lane 0 of each warp publishes its partial to shared memory.
                let is_l0 = ctx.setp_eq_u32(lane_id, zero_c);
                ctx.branch_if_not(is_l0, "dp4a_q6k_skip_sm");
                let wo = ctx.mul_u32_reg(warp_id, four_c);
                let sa = ctx.cvt_u64_u32(wo);
                ctx.st_shared_f32(sa, acc);
                ctx.label("dp4a_q6k_skip_sm");
                ctx.bar_sync(0);
                // Thread 0 sums the per-warp partials and stores y[row].
                let is_t0 = ctx.setp_eq_u32(thread_id, zero_c);
                ctx.branch_if_not(is_t0, "dp4a_q6k_skip_store");
                let fs = ctx.mov_f32_imm(0.0);
                for w in 0..num_warps {
                    let wo = ctx.mov_u64_imm(u64::from(w * 4));
                    let pv = ctx.ld_shared_f32(wo);
                    ctx.add_f32_inplace(fs, pv);
                }
                let yo = ctx.mul_wide_u32(row_idx, 4);
                let ya = ctx.add_u64(y_ptr, yo);
                ctx.st_global_f32(ya, fs);
                ctx.label("dp4a_q6k_skip_store");
                ctx.add_u32_reg_inplace(row_idx, grid_dim);
                // Second barrier: no warp may overwrite the shared partials
                // for the next row before thread 0 has read them.
                ctx.bar_sync(0);
                ctx.branch("dp4a_q6k_row_loop");
                ctx.label("dp4a_q6k_exit");
                ctx.ret();
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shapes taken from Qwen2.5-1.5B: hidden dim 1536.
    #[test]
    fn test_dp4a_q6k_builds_qwen25() {
        let ptx = Dp4aQ6KGemvKernel::new(1536, 1536).emit_ptx();
        assert!(ptx.contains(".visible .entry dp4a_q6k_gemv"));
        assert!(ptx.contains("dp4a.u32.s32"), "Must use dp4a instructions");
        assert!(ptx.contains("bar.sync"), "Must have barrier for cross-warp safety");
        assert!(ptx.contains("prefetch.global.L2"), "Must prefetch next row data");
        assert!(ptx.contains("bfi.b32"), "Must use bfi.b32 for unaligned byte packing");
    }

    // Large-N case (vocabulary-sized output, as in an LM head projection).
    #[test]
    fn test_dp4a_q6k_builds_lm_head() {
        let ptx = Dp4aQ6KGemvKernel::new(1536, 151936).emit_ptx();
        assert!(ptx.contains(".visible .entry dp4a_q6k_gemv"));
    }

    // All five kernel parameters must appear in the emitted PTX.
    #[test]
    fn test_dp4a_q6k_parameters() {
        let ptx = Dp4aQ6KGemvKernel::new(256, 64).emit_ptx();
        for param in ["y_ptr", "w_ptr", "k_dim", "n_dim"] {
            assert!(ptx.contains(param));
        }
        assert!(ptx.contains("q8_ptr"), "Must have Q8_1 activation pointer");
    }

    // Shared memory is one f32 slot per warp.
    #[test]
    fn test_dp4a_q6k_shared_memory() {
        for warps in [1, 2, 3, 4, 6, 8] {
            let built = Dp4aQ6KGemvKernel::with_warps(256, 64, warps).build_ptx();
            assert_eq!(
                built.shared_memory_bytes(),
                (warps * 4) as usize,
                "Shared memory must be {warps} warps × 4 bytes"
            );
        }
    }

    // Every supported warp count must still yield a valid entry point.
    #[test]
    fn test_dp4a_q6k_warp_variants() {
        for warps in [1, 2, 3, 4, 6, 8] {
            let ptx = Dp4aQ6KGemvKernel::with_warps(1536, 1536, warps).emit_ptx();
            assert!(ptx.contains(".visible .entry"), "Must produce valid PTX for {warps} warps");
        }
    }

    #[test]
    fn test_dp4a_q6k_name() {
        assert_eq!(Dp4aQ6KGemvKernel::new(1536, 1536).name(), "dp4a_q6k_gemv");
    }
}