use crate::kernels::quantize::{Kernel, Q4K_SUPER_BLOCK_BYTES, Q4K_SUPER_BLOCK_SIZE};
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Descriptor for a GEMV (`y = W · x`) kernel over Q4_K-quantized weights.
///
/// The emitted PTX assigns one CTA per output row and reads its dimensions
/// from the `k_dim`/`n_dim` kernel parameters at launch time; the fields
/// here presumably mirror those launch parameters (they are only consumed
/// by [`Self::num_super_blocks_per_row`], not by `build_ptx` itself).
#[derive(Debug, Clone)]
pub struct TrueDp4aQ4KGemvKernel {
    /// Reduction (inner) dimension: length of `x` / columns of `W`.
    pub k: u32,
    /// Output dimension: length of `y` / rows of `W`.
    pub n: u32,
}
impl TrueDp4aQ4KGemvKernel {
    /// Creates a kernel descriptor for a `k`-long reduction over `n` output rows.
    #[must_use]
    pub const fn new(k: u32, n: u32) -> Self {
        Self { k, n }
    }

    /// Number of Q4_K super-blocks (of `Q4K_SUPER_BLOCK_SIZE` elements each)
    /// needed to cover one row of `k` weights, rounding up so a partial final
    /// block is still counted.
    #[must_use]
    pub const fn num_super_blocks_per_row(&self) -> u32 {
        // `div_ceil` avoids the overflow in `(k + SIZE - 1) / SIZE` when
        // `k` is within `SIZE - 1` of `u32::MAX` (which would panic in
        // debug builds and silently wrap in release).
        self.k.div_ceil(Q4K_SUPER_BLOCK_SIZE)
    }
}
impl Kernel for TrueDp4aQ4KGemvKernel {
    /// Stable kernel identifier; matches the entry name given to `PtxKernel::new`.
    fn name(&self) -> &str {
        "true_dp4a_q4k_gemv"
    }

    /// Builds the PTX for a warp-per-row GEMV over Q4_K-quantized weights:
    /// `y[row] = Σ_k dequant(w[row][k]) · x[k]`, one CTA per output row.
    ///
    /// PTX parameters: `y_ptr` (f32 output), `w_ptr` (Q4_K weight rows),
    /// `x_ptr` (f32 input vector), `k_dim` (reduction length), `n_dim` (rows).
    ///
    /// NOTE(review): despite the "dp4a" name, the integer accumulator and the
    /// packed 4-byte weight word assembled below are never consumed — the
    /// kernel accumulates in f32 via scalar FMAs. Confirm whether the dp4a
    /// integer path was intentionally dropped or is unfinished.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("true_dp4a_q4k_gemv")
            .param(PtxType::U64, "y_ptr")
            .param(PtxType::U64, "w_ptr")
            .param(PtxType::U64, "x_ptr")
            .param(PtxType::U32, "k_dim")
            .param(PtxType::U32, "n_dim")
            .build(|ctx| {
                // One CTA per output row; lanes of a warp split the row's elements.
                let block_id = ctx.special_reg(PtxReg::CtaIdX);
                let thread_id = ctx.special_reg(PtxReg::TidX);
                // Lane index within the warp (tid mod 32).
                let lane_id = ctx.rem_u32(thread_id, 32);
                // Guard: CTAs mapped past the last row do nothing.
                let n_dim = ctx.load_param_u32("n_dim");
                let oob = ctx.setp_ge_u32(block_id, n_dim);
                ctx.branch_if(oob, "exit");
                let k_dim = ctx.load_param_u32("k_dim");
                let y_ptr = ctx.load_param_u64("y_ptr");
                let w_ptr = ctx.load_param_u64("w_ptr");
                let x_ptr = ctx.load_param_u64("x_ptr");
                // Dead register: leftover integer accumulator from the dp4a design.
                let _int_acc = ctx.mov_u32_imm(0);
                // Per-lane f32 partial sum, tree-reduced across the warp at the end.
                let float_acc = ctx.mov_f32_imm(0.0);
                // ceil(k_dim / Q4K_SUPER_BLOCK_SIZE) super-blocks per row.
                let k_rounded = ctx.add_u32(k_dim, Q4K_SUPER_BLOCK_SIZE - 1);
                let num_super_blocks = ctx.div_u32(k_rounded, Q4K_SUPER_BLOCK_SIZE);
                let sb_bytes = ctx.mov_u32_imm(Q4K_SUPER_BLOCK_BYTES);
                let row_bytes = ctx.mul_u32_reg(num_super_blocks, sb_bytes);
                // Wide (u32×u32 → u64) multiply: byte offset of this row in w.
                let row_offset = ctx.mul_wide_u32_reg(block_id, row_bytes);
                let row_base = ctx.add_u64(w_ptr, row_offset);
                let sb_idx = ctx.mov_u32_imm(0);
                // ---- per-super-block loop ----
                ctx.label("sb_loop");
                let sb_done = ctx.setp_ge_u32(sb_idx, num_super_blocks);
                ctx.branch_if(sb_done, "sb_loop_end");
                let sb_offset = ctx.mul_wide_u32(sb_idx, Q4K_SUPER_BLOCK_BYTES);
                let sb_addr = ctx.add_u64(row_base, sb_offset);
                // Super-block header: f16 scale `d` at +0, f16 min-scale `dmin` at +2
                // (presumably the ggml block_q4_K layout: d, dmin, scales[12], qs[]
                // — TODO confirm against the quantizer that wrote these blocks).
                let d_f16 = ctx.ld_global_f16(sb_addr);
                let d = ctx.cvt_f32_f16(d_f16);
                let two = ctx.mov_u64_imm(2);
                let dmin_addr = ctx.add_u64(sb_addr, two);
                let dmin_f16 = ctx.ld_global_f16(dmin_addr);
                let dmin = ctx.cvt_f32_f16(dmin_f16);
                // Twelve packed scale/min bytes start at +4.
                let four_64 = ctx.mov_u64_imm(4);
                let scales_base = ctx.add_u64(sb_addr, four_64);
                let one = ctx.mov_u32_imm(1);
                let is_lane0 = ctx.setp_lt_u32(lane_id, one);
                let scales_0_3 = ctx.mov_u32_imm(0);
                let scales_4_7 = ctx.mov_u32_imm(0);
                let scales_8_11 = ctx.mov_u32_imm(0);
                // Only lane 0 loads the three scale words; the other lanes keep
                // zeros here and pick up lane 0's values via the shuffles below.
                ctx.branch_if_not(is_lane0, "skip_scale_load_true");
                ctx.ld_global_u32_into(scales_0_3, scales_base);
                let four_64b = ctx.mov_u64_imm(4);
                let scales_4_addr = ctx.add_u64(scales_base, four_64b);
                ctx.ld_global_u32_into(scales_4_7, scales_4_addr);
                let eight_64 = ctx.mov_u64_imm(8);
                let scales_8_addr = ctx.add_u64(scales_base, eight_64);
                ctx.ld_global_u32_into(scales_8_11, scales_8_addr);
                ctx.label("skip_scale_load_true");
                // Broadcast lane 0's scale words to every lane of the warp.
                let scales_0_3_bcast = ctx.shfl_idx_u32(scales_0_3, 0, 0xFFFF_FFFF);
                let scales_4_7_bcast = ctx.shfl_idx_u32(scales_4_7, 0, 0xFFFF_FFFF);
                // Dead register: bytes 8..11 are broadcast but never decoded.
                let _scales_8_11_bcast = ctx.shfl_idx_u32(scales_8_11, 0, 0xFFFF_FFFF);
                let mask_6bit = ctx.mov_u32_imm(0x3F);
                let mask_4bit = ctx.mov_u32_imm(0x0F);
                let four_shift = ctx.mov_u32_imm(4);
                // NOTE(review): only sub-block 0's 6-bit scale/min are extracted and
                // then applied to ALL eight 32-element sub-blocks below. Standard
                // Q4_K carries a distinct scale/min per sub-block (high bits packed
                // in bytes 8..11, which are unused here) — confirm this
                // simplification is intentional; it would mis-dequantize general
                // Q4_K data.
                let scale0 = ctx.and_u32(scales_0_3_bcast, mask_6bit);
                let min0 = ctx.and_u32(scales_4_7_bcast, mask_6bit);
                let scale0_f = ctx.cvt_f32_u32(scale0);
                let min0_f = ctx.cvt_f32_u32(min0);
                let inv_256 = ctx.mov_f32_imm(1.0 / 256.0);
                // Effective per-element factors: ds0 = d·scale0, dm0 = dmin·min0.
                let ds0 = ctx.mul_f32(d, scale0_f);
                // Dead register: a /256 rescale that never reaches the accumulation.
                let _ds0_scaled = ctx.mul_f32(ds0, inv_256);
                let dm0 = ctx.mul_f32(dmin, min0_f);
                // Quantized nibbles start at +16 (after 2+2+12 header bytes);
                // lane L handles bytes L, L+32, L+64, L+96 of the qs array.
                let sixteen_64 = ctx.mov_u64_imm(16);
                let qs_base = ctx.add_u64(sb_addr, sixteen_64);
                let qs_offset_64 = ctx.cvt_u64_u32(lane_id);
                let qs_addr = ctx.add_u64(qs_base, qs_offset_64);
                let packed_byte = ctx.ld_global_u8(qs_addr);
                let packed = ctx.cvt_u32_u8(packed_byte);
                // Byte L: low nibble → element L, high nibble → element L+32.
                let nibble0 = ctx.and_u32(packed, mask_4bit);
                let nibble0_expanded = ctx.shl_u32(nibble0, four_shift);
                let nibble1 = ctx.shr_u32(packed, four_shift);
                let nibble1_expanded = ctx.shl_u32(nibble1, four_shift);
                // The following assembles a 4-byte word for a dp4a instruction,
                // but `_weights_packed` is never consumed — dead code today.
                let eight_shift = ctx.mov_u32_imm(8);
                let nibble1_shifted = ctx.shl_u32(nibble1_expanded, eight_shift);
                let weights_lo = ctx.or_u32(nibble0_expanded, nibble1_shifted);
                let one_64 = ctx.mov_u64_imm(1);
                let qs_addr_hi = ctx.add_u64(qs_addr, one_64);
                let packed_byte_hi = ctx.ld_global_u8(qs_addr_hi);
                let packed_hi = ctx.cvt_u32_u8(packed_byte_hi);
                let nibble2 = ctx.and_u32(packed_hi, mask_4bit);
                let nibble2_expanded = ctx.shl_u32(nibble2, four_shift);
                let nibble3 = ctx.shr_u32(packed_hi, four_shift);
                let nibble3_expanded = ctx.shl_u32(nibble3, four_shift);
                let sixteen_shift = ctx.mov_u32_imm(16);
                let twenty_four_shift = ctx.mov_u32_imm(24);
                let nibble2_shifted = ctx.shl_u32(nibble2_expanded, sixteen_shift);
                let nibble3_shifted = ctx.shl_u32(nibble3_expanded, twenty_four_shift);
                let weights_mid = ctx.or_u32(weights_lo, nibble2_shifted);
                let _weights_packed = ctx.or_u32(weights_mid, nibble3_shifted);
                // ---- element L (low nibble of byte L) ----
                let sb_k_base = ctx.mul_u32(sb_idx, Q4K_SUPER_BLOCK_SIZE);
                let x_idx0 = ctx.add_u32_reg(sb_k_base, lane_id);
                let x_idx0_64 = ctx.cvt_u64_u32(x_idx0);
                // f32 elements: byte offset = index * 4.
                let x_bytes0 = ctx.mul_u64(x_idx0_64, 4);
                let x_addr0 = ctx.add_u64(x_ptr, x_bytes0);
                // NOTE(review): no bound check against k_dim — when k is not a
                // multiple of the super-block size these x loads read past k.
                // Presumably x is padded to a super-block multiple; confirm.
                let x_val0 = ctx.ld_global_f32(x_addr0);
                let nibble0_f = ctx.cvt_f32_u32(nibble0);
                let nibble1_f = ctx.cvt_f32_u32(nibble1);
                // dequant = d·scale·q − dmin·min; acc += x · dequant.
                let scaled0 = ctx.mul_f32(ds0, nibble0_f);
                let dequant0 = ctx.sub_f32(scaled0, dm0);
                ctx.fma_f32_inplace(float_acc, x_val0, dequant0);
                // ---- element L+32 (high nibble of byte L) ----
                let thirty_two = ctx.mov_u32_imm(32);
                let x_idx1 = ctx.add_u32_reg(x_idx0, thirty_two);
                let x_idx1_64 = ctx.cvt_u64_u32(x_idx1);
                let x_bytes1 = ctx.mul_u64(x_idx1_64, 4);
                let x_addr1 = ctx.add_u64(x_ptr, x_bytes1);
                let x_val1 = ctx.ld_global_f32(x_addr1);
                let scaled1 = ctx.mul_f32(ds0, nibble1_f);
                let dequant1 = ctx.sub_f32(scaled1, dm0);
                ctx.fma_f32_inplace(float_acc, x_val1, dequant1);
                // ---- element L+64 (low nibble of byte L+32) ----
                let sixty_four = ctx.mov_u32_imm(64);
                let x_idx2 = ctx.add_u32_reg(x_idx0, sixty_four);
                let x_idx2_64 = ctx.cvt_u64_u32(x_idx2);
                let x_bytes2 = ctx.mul_u64(x_idx2_64, 4);
                let x_addr2 = ctx.add_u64(x_ptr, x_bytes2);
                let x_val2 = ctx.ld_global_f32(x_addr2);
                let qs_offset2 = ctx.add_u32_reg(lane_id, thirty_two);
                let qs_offset2_64 = ctx.cvt_u64_u32(qs_offset2);
                let qs_addr2 = ctx.add_u64(qs_base, qs_offset2_64);
                let packed_byte2 = ctx.ld_global_u8(qs_addr2);
                let packed2 = ctx.cvt_u32_u8(packed_byte2);
                let nibble2_val = ctx.and_u32(packed2, mask_4bit);
                let nibble2_f_val = ctx.cvt_f32_u32(nibble2_val);
                let scaled2 = ctx.mul_f32(ds0, nibble2_f_val);
                let dequant2 = ctx.sub_f32(scaled2, dm0);
                ctx.fma_f32_inplace(float_acc, x_val2, dequant2);
                // ---- element L+96 (high nibble of byte L+32) ----
                let ninety_six = ctx.mov_u32_imm(96);
                let x_idx3 = ctx.add_u32_reg(x_idx0, ninety_six);
                let x_idx3_64 = ctx.cvt_u64_u32(x_idx3);
                let x_bytes3 = ctx.mul_u64(x_idx3_64, 4);
                let x_addr3 = ctx.add_u64(x_ptr, x_bytes3);
                let x_val3 = ctx.ld_global_f32(x_addr3);
                let nibble3_val = ctx.shr_u32(packed2, four_shift);
                let nibble3_f_val = ctx.cvt_f32_u32(nibble3_val);
                let scaled3 = ctx.mul_f32(ds0, nibble3_f_val);
                let dequant3 = ctx.sub_f32(scaled3, dm0);
                ctx.fma_f32_inplace(float_acc, x_val3, dequant3);
                // ---- element L+128 (low nibble of byte L+64) ----
                let one_twenty_eight = ctx.mov_u32_imm(128);
                let x_idx4 = ctx.add_u32_reg(x_idx0, one_twenty_eight);
                let x_idx4_64 = ctx.cvt_u64_u32(x_idx4);
                let x_bytes4 = ctx.mul_u64(x_idx4_64, 4);
                let x_addr4 = ctx.add_u64(x_ptr, x_bytes4);
                let x_val4 = ctx.ld_global_f32(x_addr4);
                let qs_offset4 = ctx.add_u32_reg(lane_id, sixty_four);
                let qs_offset4_64 = ctx.cvt_u64_u32(qs_offset4);
                let qs_addr4 = ctx.add_u64(qs_base, qs_offset4_64);
                let packed_byte4 = ctx.ld_global_u8(qs_addr4);
                let packed4 = ctx.cvt_u32_u8(packed_byte4);
                let nibble4_val = ctx.and_u32(packed4, mask_4bit);
                let nibble4_f_val = ctx.cvt_f32_u32(nibble4_val);
                let scaled4 = ctx.mul_f32(ds0, nibble4_f_val);
                let dequant4 = ctx.sub_f32(scaled4, dm0);
                ctx.fma_f32_inplace(float_acc, x_val4, dequant4);
                // ---- element L+160 (high nibble of byte L+64) ----
                let one_sixty = ctx.mov_u32_imm(160);
                let x_idx5 = ctx.add_u32_reg(x_idx0, one_sixty);
                let x_idx5_64 = ctx.cvt_u64_u32(x_idx5);
                let x_bytes5 = ctx.mul_u64(x_idx5_64, 4);
                let x_addr5 = ctx.add_u64(x_ptr, x_bytes5);
                let x_val5 = ctx.ld_global_f32(x_addr5);
                let nibble5_val = ctx.shr_u32(packed4, four_shift);
                let nibble5_f_val = ctx.cvt_f32_u32(nibble5_val);
                let scaled5 = ctx.mul_f32(ds0, nibble5_f_val);
                let dequant5 = ctx.sub_f32(scaled5, dm0);
                ctx.fma_f32_inplace(float_acc, x_val5, dequant5);
                // ---- element L+192 (low nibble of byte L+96) ----
                let one_ninety_two = ctx.mov_u32_imm(192);
                let x_idx6 = ctx.add_u32_reg(x_idx0, one_ninety_two);
                let x_idx6_64 = ctx.cvt_u64_u32(x_idx6);
                let x_bytes6 = ctx.mul_u64(x_idx6_64, 4);
                let x_addr6 = ctx.add_u64(x_ptr, x_bytes6);
                let x_val6 = ctx.ld_global_f32(x_addr6);
                let qs_offset6 = ctx.add_u32_reg(lane_id, ninety_six);
                let qs_offset6_64 = ctx.cvt_u64_u32(qs_offset6);
                let qs_addr6 = ctx.add_u64(qs_base, qs_offset6_64);
                let packed_byte6 = ctx.ld_global_u8(qs_addr6);
                let packed6 = ctx.cvt_u32_u8(packed_byte6);
                let nibble6_val = ctx.and_u32(packed6, mask_4bit);
                let nibble6_f_val = ctx.cvt_f32_u32(nibble6_val);
                let scaled6 = ctx.mul_f32(ds0, nibble6_f_val);
                let dequant6 = ctx.sub_f32(scaled6, dm0);
                ctx.fma_f32_inplace(float_acc, x_val6, dequant6);
                // ---- element L+224 (high nibble of byte L+96) ----
                let two_twenty_four = ctx.mov_u32_imm(224);
                let x_idx7 = ctx.add_u32_reg(x_idx0, two_twenty_four);
                let x_idx7_64 = ctx.cvt_u64_u32(x_idx7);
                let x_bytes7 = ctx.mul_u64(x_idx7_64, 4);
                let x_addr7 = ctx.add_u64(x_ptr, x_bytes7);
                let x_val7 = ctx.ld_global_f32(x_addr7);
                let nibble7_val = ctx.shr_u32(packed6, four_shift);
                let nibble7_f_val = ctx.cvt_f32_u32(nibble7_val);
                let scaled7 = ctx.mul_f32(ds0, nibble7_f_val);
                let dequant7 = ctx.sub_f32(scaled7, dm0);
                ctx.fma_f32_inplace(float_acc, x_val7, dequant7);
                // Next super-block.
                ctx.add_u32_inplace(sb_idx, 1);
                ctx.branch("sb_loop");
                ctx.label("sb_loop_end");
                // Warp tree-reduction of the per-lane partial sums via
                // shfl.down with offsets 16/8/4/2/1; lane 0 ends up with
                // the full dot product.
                let tmp16 = ctx.shfl_down_f32(float_acc, 16, 0xFFFF_FFFF);
                ctx.add_f32_inplace(float_acc, tmp16);
                let tmp8 = ctx.shfl_down_f32(float_acc, 8, 0xFFFF_FFFF);
                ctx.add_f32_inplace(float_acc, tmp8);
                let tmp4 = ctx.shfl_down_f32(float_acc, 4, 0xFFFF_FFFF);
                ctx.add_f32_inplace(float_acc, tmp4);
                let tmp2 = ctx.shfl_down_f32(float_acc, 2, 0xFFFF_FFFF);
                ctx.add_f32_inplace(float_acc, tmp2);
                let tmp1 = ctx.shfl_down_f32(float_acc, 1, 0xFFFF_FFFF);
                ctx.add_f32_inplace(float_acc, tmp1);
                // Only lane 0 writes the row's result.
                let one_u32 = ctx.mov_u32_imm(1);
                let is_thread0 = ctx.setp_lt_u32(lane_id, one_u32);
                ctx.branch_if_not(is_thread0, "exit");
                // y[row] at byte offset row * 4 (f32).
                let y_offset = ctx.mul_wide_u32(block_id, 4);
                let y_addr = ctx.add_u64(y_ptr, y_offset);
                ctx.st_global_f32(y_addr, float_acc);
                ctx.label("exit");
                ctx.ret();
            })
    }
}