use crate::kernels::quantize::{Kernel, Q4K_SUPER_BLOCK_BYTES, Q4K_SUPER_BLOCK_SIZE};
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl, PtxMemory};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// GEMV kernel over Q4_K-quantized weights that uses DP4A and splits
/// each warp into two 16-lane half-warps working on separate super-blocks.
pub struct HalfWarpDp4aQ4KGemvKernel {
    /// Reduction (inner) dimension of the weight matrix.
    pub k: u32,
    /// Number of output rows.
    pub n: u32,
    /// Warps launched per thread block (each contributes two half-warps).
    pub num_warps: u32,
}

impl HalfWarpDp4aQ4KGemvKernel {
    /// Creates a kernel for a `k x n` weight matrix with the default
    /// of three warps (i.e. six half-warps) per block.
    pub fn new(k: u32, n: u32) -> Self {
        let num_warps = 3;
        Self { num_warps, k, n }
    }
}
impl Kernel for HalfWarpDp4aQ4KGemvKernel {
    /// Kernel entry-point name; must match the name passed to
    /// `PtxKernel::new` in `build_ptx` below.
    fn name(&self) -> &str {
        "hw_dp4a_q4k_gemv"
    }

    /// Emits the PTX for the GEMV.
    ///
    /// Thread mapping: each warp is split into two 16-lane half-warps
    /// (`half_warp_id = warp_id * 2 + (lane_id >> 4)`). Half-warps stride
    /// over the super-blocks of one output row; thread blocks stride over
    /// rows via the grid dimension. Partial sums are reduced with
    /// `shfl.down` inside each half-warp, staged through one f32 of
    /// shared memory per half-warp, and thread 0 writes the final value.
    fn build_ptx(&self) -> PtxKernel {
        let num_warps = self.num_warps;
        // Two 16-lane half-warps per warp; one f32 shared slot each.
        let num_half_warps = num_warps * 2;
        let smem_size = (num_half_warps * 4) as usize;
        PtxKernel::new("hw_dp4a_q4k_gemv")
            .param(PtxType::U64, "y_ptr")  // f32 output vector
            .param(PtxType::U64, "w_ptr")  // Q4_K weight rows
            .param(PtxType::U64, "q8_ptr") // quantized activation vector
            .param(PtxType::U32, "k_dim")
            .param(PtxType::U32, "n_dim")
            .shared_memory(smem_size)
            .max_regs(255)
            .build(move |ctx| {
                // Thread geometry and kernel parameters.
                let block_id = ctx.special_reg(PtxReg::CtaIdX);
                let thread_id = ctx.special_reg(PtxReg::TidX);
                let lane_id = ctx.rem_u32(thread_id, 32);
                let warp_id = ctx.div_u32(thread_id, 32);
                let grid_dim = ctx.special_reg(PtxReg::NctaIdX);
                let n_dim = ctx.load_param_u32("n_dim");
                let k_dim = ctx.load_param_u32("k_dim");
                let y_ptr = ctx.load_param_u64("y_ptr");
                let w_ptr = ctx.load_param_u64("w_ptr");
                let q8_ptr = ctx.load_param_u64("q8_ptr");
                // Super-blocks per row: ceil(k / Q4K_SUPER_BLOCK_SIZE).
                let k_rounded = ctx.add_u32(k_dim, Q4K_SUPER_BLOCK_SIZE - 1);
                let num_sb = ctx.div_u32(k_rounded, Q4K_SUPER_BLOCK_SIZE);
                let sb_bytes_reg = ctx.mov_u32_imm(Q4K_SUPER_BLOCK_BYTES);
                let row_bytes = ctx.mul_u32_reg(num_sb, sb_bytes_reg);
                // Split each warp into two independent 16-lane half-warps.
                let half_lane = ctx.and_u32_imm(lane_id, 15);
                let half_warp_in_warp = ctx.shr_u32_imm(lane_id, 4);
                let warp_x2 = ctx.shl_u32_imm(warp_id, 1);
                let half_warp_id = ctx.add_u32_reg(warp_x2, half_warp_in_warp);
                let num_hw = ctx.mov_u32_imm(num_half_warps);
                // Within a half-warp: 4 groups of 4 lanes. Each group owns a
                // pair of 32-value sub-blocks; bq8_offset is the even index.
                let bq8_group = ctx.shr_u32_imm(half_lane, 2);
                let lane_in_group = ctx.and_u32_imm(half_lane, 3);
                let bq8_offset = ctx.shl_u32_imm(bq8_group, 1);
                // q4 byte offset inside the super-block: skip a 16-byte
                // header, then 16 bytes per sub-block pair, 4 bytes per lane.
                let t1 = ctx.shl_u32_imm(bq8_offset, 4);
                let t2 = ctx.shl_u32_imm(lane_in_group, 2);
                let q4_local = ctx.add_u32_reg(t1, t2);
                let q4_off = ctx.add_u32(q4_local, 16);
                let q4_off_64 = ctx.cvt_u64_u32(q4_off);
                // q8 side: each 32-value activation block occupies 36 bytes
                // (32 data bytes; its f16 scale is loaded at offset 32 below).
                let c_36_u32 = ctx.mov_u32_imm(36);
                let bq8_bytes = ctx.mul_u32_reg(bq8_offset, c_36_u32);
                let bq8_bytes_64 = ctx.cvt_u64_u32(bq8_bytes);
                let lig_x4 = ctx.shl_u32_imm(lane_in_group, 2);
                let lig_x4_64 = ctx.cvt_u64_u32(lig_x4);
                // Loop-invariant 64-bit address offsets.
                let c_2_64 = ctx.mov_u64_imm(2);
                let c_4_64 = ctx.mov_u64_imm(4);
                let c_8_64 = ctx.mov_u64_imm(8);
                let c_16_64 = ctx.mov_u64_imm(16);
                let c_32_64 = ctx.mov_u64_imm(32);
                let c_36_64 = ctx.mov_u64_imm(36);
                // 8 q8 blocks * 36 bytes = 288 bytes per super-block.
                let c_288 = ctx.mov_u32_imm(288);
                // Scale selection: byte_shift picks this group's byte within
                // the packed 32-bit scale word; p_hi selects the word that
                // holds sub-blocks 4..7.
                let ci_mod2 = ctx.and_u32_imm(bq8_group, 1);
                let c_16_u32 = ctx.mov_u32_imm(16);
                let byte_shift = ctx.mul_u32_reg(ci_mod2, c_16_u32);
                let c_8_u32 = ctx.mov_u32_imm(8);
                let byte_shift_hi = ctx.add_u32_reg(byte_shift, c_8_u32);
                let c_2_u32 = ctx.mov_u32_imm(2);
                let p_hi = ctx.setp_ge_u32(bq8_group, c_2_u32);
                // dp4a against 0x01010101 sums the four bytes of an operand.
                let c_ones = ctx.mov_u32_imm(0x0101_0101);
                // Row loop: row_idx starts at block_id, strides by grid_dim.
                let row_idx = ctx.mov_u32_imm(0);
                ctx.add_u32_reg_inplace(row_idx, block_id);
                ctx.label("hw_row_loop");
                let row_oob = ctx.setp_ge_u32(row_idx, n_dim);
                ctx.branch_if(row_oob, "hw_exit");
                let row_off = ctx.mul_wide_u32_reg(row_idx, row_bytes);
                let row_base = ctx.add_u64(w_ptr, row_off);
                let acc = ctx.mov_f32_imm(0.0);
                // Each half-warp strides over the row's super-blocks.
                let sb_idx = ctx.mov_u32_imm(0);
                ctx.add_u32_reg_inplace(sb_idx, half_warp_id);
                ctx.label("hw_sb_loop");
                let sb_done = ctx.setp_ge_u32(sb_idx, num_sb);
                ctx.branch_if(sb_done, "hw_sb_end");
                let sb_off = ctx.mul_wide_u32(sb_idx, Q4K_SUPER_BLOCK_BYTES);
                let sb_addr = ctx.add_u64(row_base, sb_off);
                // Super-block header: d (f16) at +0, dmin (f16) at +2.
                let d_f16 = ctx.ld_global_f16(sb_addr);
                let d = ctx.cvt_f32_f16(d_f16);
                let dmin_addr = ctx.add_u64(sb_addr, c_2_64);
                let dmin_f16 = ctx.ld_global_f16(dmin_addr);
                let dmin = ctx.cvt_f32_f16(dmin_f16);
                let neg_dmin = ctx.neg_f32(dmin);
                // 12 bytes of packed 6-bit scales/mins at +4, read as three
                // u32 words (assumes ggml Q4_K packing — TODO confirm).
                let sc_base = ctx.add_u64(sb_addr, c_4_64);
                let sc03 = ctx.ld_global_u32(sc_base);
                let sc47_addr = ctx.add_u64(sc_base, c_4_64);
                let sc47 = ctx.ld_global_u32(sc47_addr);
                let sc811_addr = ctx.add_u64(sc_base, c_8_64);
                let sc811 = ctx.ld_global_u32(sc811_addr);
                // Sub-blocks 0..3: scales/mins are the low 6 bits per byte.
                let sc_lo4 = ctx.and_u32_imm(sc03, 0x3F3F_3F3F);
                let mn_lo4 = ctx.and_u32_imm(sc47, 0x3F3F_3F3F);
                // Sub-blocks 4..7: low nibble from the third word, top two
                // bits recovered from the first (scales) / second (mins) word.
                let sc_hi_low = ctx.and_u32_imm(sc811, 0x0F0F_0F0F);
                let t = ctx.shr_u32_imm(sc03, 6);
                let t = ctx.and_u32_imm(t, 0x0303_0303);
                let sc_hi_top = ctx.shl_u32_imm(t, 4);
                let sc_hi4 = ctx.or_u32(sc_hi_low, sc_hi_top);
                let mn_hi_raw = ctx.shr_u32_imm(sc811, 4);
                let mn_hi_low = ctx.and_u32_imm(mn_hi_raw, 0x0F0F_0F0F);
                let t = ctx.shr_u32_imm(sc47, 6);
                let t = ctx.and_u32_imm(t, 0x0303_0303);
                let mn_hi_top = ctx.shl_u32_imm(t, 4);
                let mn_hi4 = ctx.or_u32(mn_hi_low, mn_hi_top);
                // Extract this group's pair of scales and mins as bytes.
                let sc_src = ctx.selp_u32(p_hi, sc_hi4, sc_lo4);
                let mn_src = ctx.selp_u32(p_hi, mn_hi4, mn_lo4);
                let t = ctx.shr_u32(sc_src, byte_shift);
                let sc0 = ctx.and_u32_imm(t, 0xFF);
                let t = ctx.shr_u32(sc_src, byte_shift_hi);
                let sc1 = ctx.and_u32_imm(t, 0xFF);
                let t = ctx.shr_u32(mn_src, byte_shift);
                let mn0 = ctx.and_u32_imm(t, 0xFF);
                let t = ctx.shr_u32(mn_src, byte_shift_hi);
                let mn1 = ctx.and_u32_imm(t, 0xFF);
                // Load 8 packed q4 bytes (two u32s, 16 bytes apart): low
                // nibbles feed the even sub-block, high nibbles the odd one.
                let q4_addr = ctx.add_u64(sb_addr, q4_off_64);
                let v0 = ctx.ld_global_u32(q4_addr);
                let v1_addr = ctx.add_u64(q4_addr, c_16_64);
                let v1 = ctx.ld_global_u32(v1_addr);
                // Matching q8 activation region for this super-block.
                let q8_sb_off = ctx.mul_wide_u32_reg(sb_idx, c_288);
                let q8_sb_base = ctx.add_u64(q8_ptr, q8_sb_off);
                let q8_blk = ctx.add_u64(q8_sb_base, bq8_bytes_64);
                let q8_data = ctx.add_u64(q8_blk, lig_x4_64);
                // --- even sub-block (low nibbles) ---
                let v0_lo = ctx.and_u32_imm(v0, 0x0F0F_0F0F);
                let v1_lo = ctx.and_u32_imm(v1, 0x0F0F_0F0F);
                let u0_lo = ctx.ld_global_u32(q8_data);
                let u1_lo_addr = ctx.add_u64(q8_data, c_16_64);
                let u1_lo = ctx.ld_global_u32(u1_lo_addr);
                let dot0 = ctx.mov_u32_imm(0);
                ctx.dp4a_u32_s32_inplace(dot0, v0_lo, u0_lo);
                ctx.dp4a_u32_s32_inplace(dot0, v1_lo, u1_lo);
                // sum0 = sum of q8 bytes, used for the dmin correction term.
                let sum0 = ctx.mov_u32_imm(0);
                ctx.dp4a_u32_s32_inplace(sum0, c_ones, u0_lo);
                ctx.dp4a_u32_s32_inplace(sum0, c_ones, u1_lo);
                // Per-block q8 scale: f16 at byte 32 of the 36-byte block.
                let q8_d0_addr = ctx.add_u64(q8_blk, c_32_64);
                let q8_d0_f16 = ctx.ld_global_f16(q8_d0_addr);
                let q8_d0 = ctx.cvt_f32_f16(q8_d0_f16);
                // acc += q8_d * (d * sc * dot - dmin * mn * sum)
                let sdot0 = ctx.mul_lo_s32(sc0, dot0);
                let msum0 = ctx.mul_lo_s32(mn0, sum0);
                let sdot0_f = ctx.cvt_f32_s32(sdot0);
                let msum0_f = ctx.cvt_f32_s32(msum0);
                let t1 = ctx.mul_f32(d, sdot0_f);
                let t3 = ctx.fma_f32(neg_dmin, msum0_f, t1);
                let q8_d0_t3 = ctx.mul_f32(q8_d0, t3);
                ctx.add_f32_inplace(acc, q8_d0_t3);
                // --- odd sub-block (high nibbles, next 36-byte q8 block) ---
                let v0_hi = ctx.shr_u32_imm(v0, 4);
                let v0_hi = ctx.and_u32_imm(v0_hi, 0x0F0F_0F0F);
                let v1_hi = ctx.shr_u32_imm(v1, 4);
                let v1_hi = ctx.and_u32_imm(v1_hi, 0x0F0F_0F0F);
                let q8_blk_hi = ctx.add_u64(q8_blk, c_36_64);
                let q8_data_hi = ctx.add_u64(q8_blk_hi, lig_x4_64);
                let u0_hi = ctx.ld_global_u32(q8_data_hi);
                let u1_hi_addr = ctx.add_u64(q8_data_hi, c_16_64);
                let u1_hi = ctx.ld_global_u32(u1_hi_addr);
                let dot1 = ctx.mov_u32_imm(0);
                ctx.dp4a_u32_s32_inplace(dot1, v0_hi, u0_hi);
                ctx.dp4a_u32_s32_inplace(dot1, v1_hi, u1_hi);
                let sum1 = ctx.mov_u32_imm(0);
                ctx.dp4a_u32_s32_inplace(sum1, c_ones, u0_hi);
                ctx.dp4a_u32_s32_inplace(sum1, c_ones, u1_hi);
                let q8_d1_addr = ctx.add_u64(q8_blk_hi, c_32_64);
                let q8_d1_f16 = ctx.ld_global_f16(q8_d1_addr);
                let q8_d1 = ctx.cvt_f32_f16(q8_d1_f16);
                let sdot1 = ctx.mul_lo_s32(sc1, dot1);
                let msum1 = ctx.mul_lo_s32(mn1, sum1);
                let sdot1_f = ctx.cvt_f32_s32(sdot1);
                let msum1_f = ctx.cvt_f32_s32(msum1);
                let t1 = ctx.mul_f32(d, sdot1_f);
                let t3 = ctx.fma_f32(neg_dmin, msum1_f, t1);
                let q8_d1_t3 = ctx.mul_f32(q8_d1, t3);
                ctx.add_f32_inplace(acc, q8_d1_t3);
                // Advance to this half-warp's next super-block.
                ctx.add_u32_reg_inplace(sb_idx, num_hw);
                ctx.branch("hw_sb_loop");
                ctx.label("hw_sb_end");
                // shfl.down reduction with deltas 8/4/2/1. The full-warp mask
                // is fine: for the lanes whose results are kept (half-lane 0,
                // i.e. lanes 0 and 16), no delta crosses a 16-lane boundary.
                let t = ctx.shfl_down_f32(acc, 8, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, t);
                let t = ctx.shfl_down_f32(acc, 4, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, t);
                let t = ctx.shfl_down_f32(acc, 2, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, t);
                let t = ctx.shfl_down_f32(acc, 1, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, t);
                // Lane 0 of each half-warp publishes its partial sum to its
                // 4-byte shared-memory slot.
                let z = ctx.mov_u32_imm(0);
                let is_hl0 = ctx.setp_eq_u32(half_lane, z);
                ctx.branch_if_not(is_hl0, "hw_skip_sm");
                let sm_off = ctx.shl_u32_imm(half_warp_id, 2);
                let sm_addr = ctx.cvt_u64_u32(sm_off);
                ctx.st_shared_f32(sm_addr, acc);
                ctx.label("hw_skip_sm");
                ctx.bar_sync(0);
                // Thread 0 folds the per-half-warp partials (unrolled at
                // build time) and stores y[row_idx].
                let is_t0 = ctx.setp_eq_u32(thread_id, z);
                ctx.branch_if_not(is_t0, "hw_skip_store");
                let result = ctx.mov_f32_imm(0.0);
                for hw in 0..num_half_warps {
                    let off = ctx.mov_u64_imm(u64::from(hw * 4));
                    let val = ctx.ld_shared_f32(off);
                    ctx.add_f32_inplace(result, val);
                }
                let y_off = ctx.mul_wide_u32(row_idx, 4);
                let y_addr = ctx.add_u64(y_ptr, y_off);
                ctx.st_global_f32(y_addr, result);
                ctx.label("hw_skip_store");
                ctx.add_u32_reg_inplace(row_idx, grid_dim);
                // Barrier before the next row so the shared slots are not
                // overwritten while thread 0 may still be reading them.
                ctx.bar_sync(0);
                ctx.branch("hw_row_loop");
                ctx.label("hw_exit");
                ctx.ret();
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The emitted PTX should carry the kernel's identifying markers.
    #[test]
    fn test_ptx_emits_hw_identity() {
        let kernel = HalfWarpDp4aQ4KGemvKernel::new(1536, 256);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains("hw_dp4a_q4k_gemv"), "kernel name");
        assert!(ptx.contains("dp4a.u32.s32"), "DP4A instruction");
        assert!(ptx.contains("and.b32"), "half_lane = lane_id & 15");
    }

    /// Documents the factorization of the super-block size handled per
    /// half-warp iteration.
    #[test]
    fn test_value_coverage_contract() {
        assert_eq!(4 * 4 * 2 * 4 * 2, Q4K_SUPER_BLOCK_SIZE);
        assert_eq!(16 * 16, Q4K_SUPER_BLOCK_SIZE as usize);
    }

    /// The inner super-block loop must stay under the instruction budget
    /// and beat the MWV baseline's instructions-per-value ratio.
    #[test]
    fn test_instruction_density() {
        let kernel = HalfWarpDp4aQ4KGemvKernel::new(1536, 256);
        let ptx = kernel.emit_ptx();
        let start = ptx.find("hw_sb_loop:").expect("sb_loop label");
        let end = ptx.find("hw_sb_end:").expect("sb_end label");
        // Approximate the instruction count by counting terminated
        // statements between the loop labels.
        let insn_count = ptx[start..end].matches(';').count();
        assert!(
            insn_count <= 120,
            "C3 violated: inner loop has {} instructions (limit 120)",
            insn_count
        );
        // One iteration covers 16 values; MWV spends 99 instructions per 8.
        let hw_per_value = insn_count as f64 / 16.0;
        let mwv_per_value = 99.0 / 8.0;
        assert!(
            hw_per_value < mwv_per_value,
            "C3: HW {:.1} insn/val should be < MWV {:.1} insn/val",
            hw_per_value,
            mwv_per_value
        );
        eprintln!(
            "[C3] Inner loop: {} instructions for 16 values ({:.1} insn/val vs MWV {:.1})",
            insn_count, hw_per_value, mwv_per_value
        );
    }

    /// CPU model of the shfl.down reduction: lanes 0 and 16 must end up
    /// holding their half-warp's sum.
    #[test]
    fn test_reduction_correctness() {
        // Two half-warps' worth of lane values: 1..=16 and 100..=115.
        let mut lanes: Vec<f32> = (1..=16).chain(100..=115).map(|x| x as f32).collect();
        assert_eq!(lanes.len(), 32);
        // Simulate shfl.down with a full-warp mask at deltas 8, 4, 2, 1.
        for delta in [8usize, 4, 2, 1] {
            let snapshot = lanes.clone();
            for (i, lane) in lanes.iter_mut().enumerate() {
                if let Some(&src) = snapshot.get(i + delta) {
                    *lane += src;
                }
            }
        }
        let hw0_sum: f32 = (1..=16).map(|x| x as f32).sum();
        let hw1_sum: f32 = (100..=115).map(|x| x as f32).sum();
        assert!((lanes[0] - hw0_sum).abs() < 0.01, "Lane 0: got {}, expected {}", lanes[0], hw0_sum);
        assert!(
            (lanes[16] - hw1_sum).abs() < 0.01,
            "Lane 16: got {}, expected {}",
            lanes[16],
            hw1_sum
        );
    }

    /// Convenience dump for manual inspection (`cargo test -- --nocapture`).
    #[test]
    fn dump_ptx() {
        let kernel = HalfWarpDp4aQ4KGemvKernel::new(1536, 256);
        let ptx = kernel.emit_ptx();
        eprintln!("{}", ptx);
    }
}