use crate::kernels::quantize::{Kernel, Q4K_SUPER_BLOCK_BYTES, Q4K_SUPER_BLOCK_SIZE};
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl, PtxMemory, PtxSync};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
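
/// Batched GEMV over Q4_K-quantized weights and 8-bit quantized activations,
/// emitted as PTX around the `dp4a` instruction.
///
/// Layout assumed by the emitted kernel:
/// - `w_ptr`: `n` rows of Q4_K super blocks, `ceil(k / Q4K_SUPER_BLOCK_SIZE)`
///   per row; each super block is f16 `d`, f16 `dmin`, 12 packed scale bytes,
///   then the 4-bit quants.
/// - `q8_ptr`: `m` quantized activation vectors; per super block, 8 blocks of
///   36 bytes (32 quant bytes followed by an f16 scale), i.e. a 288-byte stride.
/// - `y_ptr`: f32 output, `m x n` row-major (`y[mi * n + row]`).
///
/// One CTA processes one output row at a time (grid-stride). Its warps are
/// split into 16-lane half-warps (the "hw" in the name); each half-warp walks
/// the row's super blocks, and the partial sums are combined with warp
/// shuffles and shared memory before lane 0 of warp 0 stores the result.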
pub struct BatchedHwDp4aQ4KGemvKernel {
    pub k: u32,
    pub n: u32,
    pub m: u32,
    pub num_warps: u32,
}

impl BatchedHwDp4aQ4KGemvKernel {
    pub fn new(k: u32, n: u32, m: u32) -> Self {
        Self {
            k,
            n,
            m,
            num_warps: 3,
        }
    }
}

impl Kernel for BatchedHwDp4aQ4KGemvKernel {
    fn name(&self) -> &str {
        "batched_hw_dp4a_q4k_gemv"
    }

    fn build_ptx(&self) -> PtxKernel {
        let num_warps = self.num_warps;
        let num_half_warps = num_warps * 2;
        let m = self.m;
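
        // Each half-warp stores one f32 partial per batch column, hence
        // num_half_warps * m * 4 bytes of shared memory.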
        let smem_size = (num_half_warps * m * 4) as usize;

        PtxKernel::new("batched_hw_dp4a_q4k_gemv")
            .param(PtxType::U64, "y_ptr")
            .param(PtxType::U64, "w_ptr")
            .param(PtxType::U64, "q8_ptr")
            .param(PtxType::U32, "k_dim")
            .param(PtxType::U32, "n_dim")
            .param(PtxType::U32, "m_dim")
            .shared_memory(smem_size)
            .max_regs(255)
            .build(move |ctx| {
                let block_id = ctx.special_reg(PtxReg::CtaIdX);
                let thread_id = ctx.special_reg(PtxReg::TidX);
                let lane_id = ctx.rem_u32(thread_id, 32);
                let warp_id = ctx.div_u32(thread_id, 32);
                let grid_dim = ctx.special_reg(PtxReg::NctaIdX);
                let n_dim = ctx.load_param_u32("n_dim");
                let k_dim = ctx.load_param_u32("k_dim");
                let y_ptr = ctx.load_param_u64("y_ptr");
                let w_ptr = ctx.load_param_u64("w_ptr");
                let q8_ptr = ctx.load_param_u64("q8_ptr");
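                // Super blocks per row (k rounded up to a multiple of
                // Q4K_SUPER_BLOCK_SIZE) and the weight row stride in bytes.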
                let k_rounded = ctx.add_u32(k_dim, Q4K_SUPER_BLOCK_SIZE - 1);
                let num_sb = ctx.div_u32(k_rounded, Q4K_SUPER_BLOCK_SIZE);
                let sb_bytes_reg = ctx.mov_u32_imm(Q4K_SUPER_BLOCK_BYTES);
                let row_bytes = ctx.mul_u32_reg(num_sb, sb_bytes_reg);
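                // Split each 32-lane warp into two 16-lane half-warps; a
                // half-warp is the unit that walks super blocks.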
                let half_lane = ctx.and_u32_imm(lane_id, 15);
                let half_warp_in_warp = ctx.shr_u32_imm(lane_id, 4);
                let warp_x2 = ctx.shl_u32_imm(warp_id, 1);
                let half_warp_id = ctx.add_u32_reg(warp_x2, half_warp_in_warp);
                let num_hw = ctx.mov_u32_imm(num_half_warps);
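                // Within a half-warp each lane owns one pair of 32-value blocks:
                // bq8_group (0..3) selects the pair, lane_in_group (0..3) selects
                // 4 bytes within it. q4_off is the lane's byte offset into the super
                // block's quants (which start at byte 16); bq8_bytes is the byte
                // offset of the pair's even Q8 block within the 288-byte activation
                // super block.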
                let bq8_group = ctx.shr_u32_imm(half_lane, 2);
                let lane_in_group = ctx.and_u32_imm(half_lane, 3);
                let bq8_offset = ctx.shl_u32_imm(bq8_group, 1);
                let t1 = ctx.shl_u32_imm(bq8_offset, 4);
                let t2 = ctx.shl_u32_imm(lane_in_group, 2);
                let q4_local = ctx.add_u32_reg(t1, t2);
                let q4_off = ctx.add_u32(q4_local, 16);
                let q4_off_64 = ctx.cvt_u64_u32(q4_off);
                let c_36_u32 = ctx.mov_u32_imm(36);
                let bq8_bytes = ctx.mul_u32_reg(bq8_offset, c_36_u32);
                let bq8_bytes_64 = ctx.cvt_u64_u32(bq8_bytes);
                let lig_x4 = ctx.shl_u32_imm(lane_in_group, 2);
                let lig_x4_64 = ctx.cvt_u64_u32(lig_x4);
                let c_2_64 = ctx.mov_u64_imm(2);
                let c_4_64 = ctx.mov_u64_imm(4);
                let c_8_64 = ctx.mov_u64_imm(8);
                let c_16_64 = ctx.mov_u64_imm(16);
                let c_32_64 = ctx.mov_u64_imm(32);
                let c_36_64 = ctx.mov_u64_imm(36);
                let c_288 = ctx.mov_u32_imm(288);
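                // Select which packed (scale, min) bytes belong to this lane's block
                // pair: p_hi chooses the upper word (blocks 4-7) vs the lower (blocks
                // 0-3), and byte_shift / byte_shift_hi pick the two adjacent bytes
                // within it.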
                let ci_mod2 = ctx.and_u32_imm(bq8_group, 1);
                let c_16_u32 = ctx.mov_u32_imm(16);
                let byte_shift = ctx.mul_u32_reg(ci_mod2, c_16_u32);
                let c_8_u32 = ctx.mov_u32_imm(8);
                let byte_shift_hi = ctx.add_u32_reg(byte_shift, c_8_u32);
                let c_2_u32 = ctx.mov_u32_imm(2);
                let p_hi = ctx.setp_ge_u32(bq8_group, c_2_u32);
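                // Masks for unpacking the 6-bit scales/mins and 4-bit quants. Each
                // activation super block spans 8 * 36 = 288 bytes, so one full
                // activation vector is num_sb * 288 bytes (q8_vec_stride).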
                let c_ones = ctx.mov_u32_imm(0x0101_0101);
                let c_mask_6bit = ctx.mov_u32_imm(0x3F3F_3F3F);
                let c_mask_4bit = ctx.mov_u32_imm(0x0F0F_0F0F);
                let c_mask_2bit = ctx.mov_u32_imm(0x0303_0303);
                let q8_vec_stride = ctx.mul_u32_reg(num_sb, c_288);
                let _q8_vec_stride_64 = ctx.cvt_u64_u32(q8_vec_stride);
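                // One f32 accumulator per batch column (activation vector).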
                let f32_zero = ctx.mov_f32_imm(0.0);
                let mut accs = Vec::with_capacity(m as usize);
                for _ in 0..m {
                    accs.push(ctx.mov_f32_imm(0.0));
                }
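                // Grid-stride loop over output rows: row_idx starts at this CTA's
                // block id and advances by the grid size at the bottom of the loop.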
                let row_idx = ctx.mov_u32_imm(0);
                ctx.add_u32_reg_inplace(row_idx, block_id);
                ctx.label("bhw_row_loop");
                let row_oob = ctx.setp_ge_u32(row_idx, n_dim);
                ctx.branch_if(row_oob, "bhw_exit");
                let row_off = ctx.mul_wide_u32_reg(row_idx, row_bytes);
                let row_base = ctx.add_u64(w_ptr, row_off);
                for acc in &accs {
                    ctx.mov_f32_reg(*acc, f32_zero);
                }
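                // Half-warps cooperatively stride over this row's super blocks:
                // start at half_warp_id, advance by num_half_warps per iteration.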
                let sb_idx = ctx.mov_u32_imm(0);
                ctx.add_u32_reg_inplace(sb_idx, half_warp_id);
                ctx.label("bhw_sb_loop");
                let sb_done = ctx.setp_ge_u32(sb_idx, num_sb);
                ctx.branch_if(sb_done, "bhw_sb_end");
                let sb_off = ctx.mul_wide_u32(sb_idx, Q4K_SUPER_BLOCK_BYTES);
                let sb_addr = ctx.add_u64(row_base, sb_off);
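                // Super block header: f16 super-block scale `d` at byte 0, f16 `dmin`
                // (scale for the block minima) at byte 2.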
                let d_f16 = ctx.ld_global_f16(sb_addr);
                let d = ctx.cvt_f32_f16(d_f16);
                let dmin_addr = ctx.add_u64(sb_addr, c_2_64);
                let dmin_f16 = ctx.ld_global_f16(dmin_addr);
                let dmin = ctx.cvt_f32_f16(dmin_f16);
                let neg_dmin = ctx.neg_f32(dmin);
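                // Unpack the 12 scale bytes (bytes 4..16): eight 6-bit (scale, min)
                // pairs, one per 32-value block. sc_lo4/mn_lo4 hold blocks 0-3 as
                // packed bytes, sc_hi4/mn_hi4 hold blocks 4-7; bfe then extracts the
                // two bytes for this lane's block pair.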
                let sc_base = ctx.add_u64(sb_addr, c_4_64);
                let sc03 = ctx.ld_global_u32(sc_base);
                let sc47_addr = ctx.add_u64(sc_base, c_4_64);
                let sc47 = ctx.ld_global_u32(sc47_addr);
                let sc811_addr = ctx.add_u64(sc_base, c_8_64);
                let sc811 = ctx.ld_global_u32(sc811_addr);
                let sc_lo4 = ctx.and_u32(sc03, c_mask_6bit);
                let mn_lo4 = ctx.and_u32(sc47, c_mask_6bit);
                let sc_hi_low = ctx.and_u32(sc811, c_mask_4bit);
                let t = ctx.shr_u32_imm(sc03, 6);
                let t = ctx.and_u32(t, c_mask_2bit);
                let sc_hi_top = ctx.shl_u32_imm(t, 4);
                let sc_hi4 = ctx.or_u32(sc_hi_low, sc_hi_top);
                let mn_hi_raw = ctx.shr_u32_imm(sc811, 4);
                let mn_hi_low = ctx.and_u32(mn_hi_raw, c_mask_4bit);
                let t = ctx.shr_u32_imm(sc47, 6);
                let t = ctx.and_u32(t, c_mask_2bit);
                let mn_hi_top = ctx.shl_u32_imm(t, 4);
                let mn_hi4 = ctx.or_u32(mn_hi_low, mn_hi_top);
                let sc_src = ctx.selp_u32(p_hi, sc_hi4, sc_lo4);
                let mn_src = ctx.selp_u32(p_hi, mn_hi4, mn_lo4);
                let sc0 = ctx.bfe_u32_reg(sc_src, byte_shift, 8);
                let sc1 = ctx.bfe_u32_reg(sc_src, byte_shift_hi, 8);
                let mn0 = ctx.bfe_u32_reg(mn_src, byte_shift, 8);
                let mn1 = ctx.bfe_u32_reg(mn_src, byte_shift_hi, 8);
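                // Load 2 x 4 bytes of packed 4-bit quants. Low nibbles pair with the
                // even Q8 block of this lane's pair, high nibbles with the odd one.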
                let q4_addr = ctx.add_u64(sb_addr, q4_off_64);
                let v0 = ctx.ld_global_u32(q4_addr);
                let v1_addr = ctx.add_u64(q4_addr, c_16_64);
                let v1 = ctx.ld_global_u32(v1_addr);
                let v0_lo = ctx.and_u32(v0, c_mask_4bit);
                let v1_lo = ctx.and_u32(v1, c_mask_4bit);
                let v0_hi = ctx.shr_u32_imm(v0, 4);
                let v0_hi = ctx.and_u32(v0_hi, c_mask_4bit);
                let v1_hi = ctx.shr_u32_imm(v1, 4);
                let v1_hi = ctx.and_u32(v1_hi, c_mask_4bit);
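                // For each of the m activation vectors, accumulate
                //   q8_d * (d * sc * sum(q4 * q8) - dmin * mn * sum(q8))
                // for both blocks of the pair using DP4A.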
                let q8_sb_off_base = ctx.mul_wide_u32_reg(sb_idx, c_288);
                for mi in 0..m {
                    let q8_m_off = if mi == 0 {
                        ctx.mov_u64_imm(0)
                    } else {
                        let mi_reg = ctx.mov_u32_imm(mi);
                        ctx.mul_wide_u32_reg(mi_reg, q8_vec_stride)
                    };
                    let q8_m_base = ctx.add_u64(q8_ptr, q8_m_off);
                    let q8_sb_base = ctx.add_u64(q8_m_base, q8_sb_off_base);
                    let q8_blk = ctx.add_u64(q8_sb_base, bq8_bytes_64);
                    let q8_data = ctx.add_u64(q8_blk, lig_x4_64);
                    let u0_lo = ctx.ld_global_u32(q8_data);
                    let u1_lo_addr = ctx.add_u64(q8_data, c_16_64);
                    let u1_lo = ctx.ld_global_u32(u1_lo_addr);
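                    // dot0 = sum(q4_lo * q8) over 8 bytes (two DP4As);
                    // sum0 = sum(q8), needed for the dmin correction.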
                    let dot0 = ctx.mov_u32_imm(0);
                    ctx.dp4a_u32_s32_inplace(dot0, v0_lo, u0_lo);
                    ctx.dp4a_u32_s32_inplace(dot0, v1_lo, u1_lo);
                    let sum0 = ctx.mov_u32_imm(0);
                    ctx.dp4a_u32_s32_inplace(sum0, c_ones, u0_lo);
                    ctx.dp4a_u32_s32_inplace(sum0, c_ones, u1_lo);
                    let q8_d0_addr = ctx.add_u64(q8_blk, c_32_64);
                    let q8_d0_f16 = ctx.ld_global_f16(q8_d0_addr);
                    let q8_d0 = ctx.cvt_f32_f16(q8_d0_f16);
                    let sdot0 = ctx.mul_lo_s32(sc0, dot0);
                    let msum0 = ctx.mul_lo_s32(mn0, sum0);
                    let sdot0_f = ctx.cvt_f32_s32(sdot0);
                    let msum0_f = ctx.cvt_f32_s32(msum0);
                    let t1 = ctx.mul_f32(d, sdot0_f);
                    let t3 = ctx.fma_f32(neg_dmin, msum0_f, t1);
                    let q8_d0_t3 = ctx.mul_f32(q8_d0, t3);
                    ctx.add_f32_inplace(accs[mi as usize], q8_d0_t3);
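                    // Same accumulation for the odd (high-nibble) Q8 block,
                    // 36 bytes further on.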
                    let q8_blk_hi = ctx.add_u64(q8_blk, c_36_64);
                    let q8_data_hi = ctx.add_u64(q8_blk_hi, lig_x4_64);
                    let u0_hi = ctx.ld_global_u32(q8_data_hi);
                    let u1_hi_addr = ctx.add_u64(q8_data_hi, c_16_64);
                    let u1_hi = ctx.ld_global_u32(u1_hi_addr);
                    let dot1 = ctx.mov_u32_imm(0);
                    ctx.dp4a_u32_s32_inplace(dot1, v0_hi, u0_hi);
                    ctx.dp4a_u32_s32_inplace(dot1, v1_hi, u1_hi);
                    let sum1 = ctx.mov_u32_imm(0);
                    ctx.dp4a_u32_s32_inplace(sum1, c_ones, u0_hi);
                    ctx.dp4a_u32_s32_inplace(sum1, c_ones, u1_hi);
                    let q8_d1_addr = ctx.add_u64(q8_blk_hi, c_32_64);
                    let q8_d1_f16 = ctx.ld_global_f16(q8_d1_addr);
                    let q8_d1 = ctx.cvt_f32_f16(q8_d1_f16);
                    let sdot1 = ctx.mul_lo_s32(sc1, dot1);
                    let msum1 = ctx.mul_lo_s32(mn1, sum1);
                    let sdot1_f = ctx.cvt_f32_s32(sdot1);
                    let msum1_f = ctx.cvt_f32_s32(msum1);
                    let t1 = ctx.mul_f32(d, sdot1_f);
                    let t3 = ctx.fma_f32(neg_dmin, msum1_f, t1);
                    let q8_d1_t3 = ctx.mul_f32(q8_d1, t3);
                    ctx.add_f32_inplace(accs[mi as usize], q8_d1_t3);
                }
                ctx.add_u32_reg_inplace(sb_idx, num_hw);
                ctx.branch("bhw_sb_loop");
                ctx.label("bhw_sb_end");
                for acc in &accs {
                    let t = ctx.shfl_down_f32(*acc, 8, 0xFFFF_FFFF);
                    ctx.add_f32_inplace(*acc, t);
                    let t = ctx.shfl_down_f32(*acc, 4, 0xFFFF_FFFF);
                    ctx.add_f32_inplace(*acc, t);
                    let t = ctx.shfl_down_f32(*acc, 2, 0xFFFF_FFFF);
                    ctx.add_f32_inplace(*acc, t);
                    let t = ctx.shfl_down_f32(*acc, 1, 0xFFFF_FFFF);
                    ctx.add_f32_inplace(*acc, t);
                }
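                // Lane 0 of each half-warp publishes its m partials to shared memory
                // at index half_warp_id * m + mi.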
                let z = ctx.mov_u32_imm(0);
                let is_hl0 = ctx.setp_eq_u32(half_lane, z);
                ctx.branch_if_not(is_hl0, "bhw_skip_sm");
                for (mi, acc) in accs.iter().enumerate() {
                    let hw_m = ctx.mul_u32(half_warp_id, m);
                    let idx = ctx.add_u32(hw_m, mi as u32);
                    let sm_off = ctx.shl_u32_imm(idx, 2);
                    let sm_addr = ctx.cvt_u64_u32(sm_off);
                    ctx.st_shared_f32(sm_addr, *acc);
                }
                ctx.label("bhw_skip_sm");
                ctx.bar_sync(0);
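                // Warp 0 sums the per-half-warp partials for every batch column and
                // its lane 0 stores y[mi * n_dim + row_idx].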
                let is_warp0 = ctx.setp_eq_u32(warp_id, z);
                ctx.branch_if_not(is_warp0, "bhw_skip_store");
                let in_range = ctx.setp_lt_u32_imm(lane_id, num_half_warps);
                let zero_f = ctx.mov_f32_imm(0.0);
                let is_l0 = ctx.setp_eq_u32(lane_id, z);
                for mi in 0..m {
                    let mi_off = ctx.mov_u32_imm(mi);
                    let hw_m = ctx.mul_u32(lane_id, m);
                    let idx = ctx.add_u32_reg(hw_m, mi_off);
                    let sm_off = ctx.shl_u32_imm(idx, 2);
                    let sm_addr = ctx.cvt_u64_u32(sm_off);
                    let loaded = ctx.ld_shared_f32(sm_addr);
                    let partial = ctx.selp_f32(in_range, loaded, zero_f);
                    let t = ctx.shfl_down_f32(partial, 4, 0xFFFF_FFFF);
                    let partial = ctx.add_f32(partial, t);
                    let t = ctx.shfl_down_f32(partial, 2, 0xFFFF_FFFF);
                    let partial = ctx.add_f32(partial, t);
                    let t = ctx.shfl_down_f32(partial, 1, 0xFFFF_FFFF);
                    let result = ctx.add_f32(partial, t);
                    let skip_label = format!("bhw_skip_mi{mi}");
                    ctx.branch_if_not(is_l0, &skip_label);
                    let mi_reg = ctx.mov_u32_imm(mi);
                    let y_mi_base = ctx.mul_u32_reg(mi_reg, n_dim);
                    let y_idx = ctx.add_u32_reg(y_mi_base, row_idx);
                    let y_off = ctx.mul_wide_u32(y_idx, 4);
                    let y_addr = ctx.add_u64(y_ptr, y_off);
                    ctx.st_global_f32(y_addr, result);
                    ctx.label(&skip_label);
                }
                ctx.label("bhw_skip_store");
                ctx.add_u32_reg_inplace(row_idx, grid_dim);
                ctx.bar_sync(0);
                ctx.branch("bhw_row_loop");
                ctx.label("bhw_exit");
                ctx.ret();
            })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ptx_emits_batched_hw_dp4a() {
        let k = BatchedHwDp4aQ4KGemvKernel::new(1536, 256, 4);
        let ptx = k.emit_ptx();
        assert!(
            ptx.contains("batched_hw_dp4a_q4k_gemv"),
            "kernel name present"
        );
        assert!(ptx.contains("dp4a.u32.s32"), "DP4A instruction present");
        assert!(ptx.contains("and.b32"), "half_lane = lane_id & 15");
    }

    #[test]
    fn test_batched_m2() {
        let k = BatchedHwDp4aQ4KGemvKernel::new(1536, 256, 2);
        let ptx = k.emit_ptx();
        assert!(!ptx.is_empty());
    }

    #[test]
    fn test_batched_m8() {
        let k = BatchedHwDp4aQ4KGemvKernel::new(1536, 256, 8);
        let ptx = k.emit_ptx();
        assert!(!ptx.is_empty());
    }
}