use crate::kernels::quantize::{Kernel, Q4K_SUPER_BLOCK_BYTES, Q4K_SUPER_BLOCK_SIZE};
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl, PtxMemory, PtxSync};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
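/// Fused GEMV kernel: multiplies a Q4_K-quantized weight matrix against a
/// small batch of fp32 activation vectors in a single pass.
///
/// For each output row it computes `y[row, col] = sum_k w[row, k] * x[col, k]`
/// with `row in 0..n` and `col in 0..m`, where `w` is stored row-major as
/// Q4_K super-blocks and `x`/`y` are row-major fp32.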
pub struct FusedFp32Q4KGemvKernel {
pub k: u32,
pub n: u32,
pub m: u32,
pub num_warps: u32,
}
impl FusedFp32Q4KGemvKernel {
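/// `k` is the reduction dimension, `n` the number of weight rows (outputs),
/// and `m` the number of fp32 activation vectors. The kernel assumes a
/// launch of `num_warps * 32` threads per block, with the grid striding
/// over the `n` output rows.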
pub fn new(k: u32, n: u32, m: u32) -> Self {
Self {
k,
n,
m,
num_warps: 3,
}
}
}
impl Kernel for FusedFp32Q4KGemvKernel {
fn name(&self) -> &str {
"fused_fp32_q4k_gemv"
}
fn build_ptx(&self) -> PtxKernel {
let num_warps = self.num_warps;
let num_half_warps = num_warps * 2;
let m = self.m;
let smem_size = (num_half_warps * m * 4) as usize;
PtxKernel::new("fused_fp32_q4k_gemv")
.param(PtxType::U64, "y_ptr")
.param(PtxType::U64, "w_ptr")
.param(PtxType::U64, "x_ptr")
.param(PtxType::U32, "k_dim")
.param(PtxType::U32, "n_dim")
.param(PtxType::U32, "m_dim")
.shared_memory(smem_size)
.max_regs(255)
.build(move |ctx| {
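// Thread layout: every warp is split into two half-warps of 16 lanes.
// Half-warps stride over the super-blocks of a row; thread blocks stride
// over output rows (grid-stride loop).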
let block_id = ctx.special_reg(PtxReg::CtaIdX);
let thread_id = ctx.special_reg(PtxReg::TidX);
let lane_id = ctx.rem_u32(thread_id, 32);
let warp_id = ctx.div_u32(thread_id, 32);
let grid_dim = ctx.special_reg(PtxReg::NctaIdX);
let n_dim = ctx.load_param_u32("n_dim");
let k_dim = ctx.load_param_u32("k_dim");
let y_ptr = ctx.load_param_u64("y_ptr");
let w_ptr = ctx.load_param_u64("w_ptr");
let x_ptr = ctx.load_param_u64("x_ptr");
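// Round k up to a whole number of super-blocks and derive the byte size
// of one quantized weight row.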
let k_rounded = ctx.add_u32(k_dim, Q4K_SUPER_BLOCK_SIZE - 1);
let num_sb = ctx.div_u32(k_rounded, Q4K_SUPER_BLOCK_SIZE);
let sb_bytes_reg = ctx.mov_u32_imm(Q4K_SUPER_BLOCK_BYTES);
let row_bytes = ctx.mul_u32_reg(num_sb, sb_bytes_reg);
let half_lane = ctx.and_u32_imm(lane_id, 15);
let half_warp_in_warp = ctx.shr_u32_imm(lane_id, 4);
let warp_x2 = ctx.shl_u32_imm(warp_id, 1);
let half_warp_id = ctx.add_u32_reg(warp_x2, half_warp_in_warp);
let num_hw = ctx.mov_u32_imm(num_half_warps);
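// Per-lane addressing inside a super-block: the 16 lanes form 4 groups of
// 4, each lane loading two 32-bit words of packed nibbles. The 16-byte
// header (f16 d, f16 dmin, 12 bytes of packed scales) precedes the nibble
// data, hence the +16 byte offset.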
let bq8_group = ctx.shr_u32_imm(half_lane, 2);
let lane_in_group = ctx.and_u32_imm(half_lane, 3);
let bq8_offset = ctx.shl_u32_imm(bq8_group, 1);
let t1 = ctx.shl_u32_imm(bq8_offset, 4);
let t2 = ctx.shl_u32_imm(lane_in_group, 2);
let q4_local = ctx.add_u32_reg(t1, t2);
let q4_off = ctx.add_u32(q4_local, 16);
let q4_off_64 = ctx.cvt_u64_u32(q4_off);
let c_256_u32 = ctx.mov_u32_imm(256);
let c_4_u64 = ctx.mov_u64_imm(4);
let c_16_u64 = ctx.mov_u64_imm(16);
let c_4_stride = ctx.mov_u32_imm(4);
let x_vec_stride = ctx.mul_u32_reg(k_dim, c_4_stride);
let c_2_64 = ctx.mov_u64_imm(2);
let c_4_64 = ctx.mov_u64_imm(4);
let c_8_64 = ctx.mov_u64_imm(8);
let c_16_64 = ctx.mov_u64_imm(16);
// Byte offset of the fourth f32 in a group of four consecutive activations.
let c_12_64 = ctx.mov_u64_imm(12);
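// Each group of 4 lanes covers two consecutive 32-weight sub-blocks
// (2*bq8_group and 2*bq8_group + 1). p_hi chooses between the unpacked
// scale words for sub-blocks 0-3 and 4-7; byte_shift/byte_shift_hi then
// select the two scale (and min) bytes within that word.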
let ci_mod2 = ctx.and_u32_imm(bq8_group, 1);
let c_16_u32 = ctx.mov_u32_imm(16);
let byte_shift = ctx.mul_u32_reg(ci_mod2, c_16_u32);
let c_8_u32 = ctx.mov_u32_imm(8);
let byte_shift_hi = ctx.add_u32_reg(byte_shift, c_8_u32);
let c_2_u32 = ctx.mov_u32_imm(2);
let p_hi = ctx.setp_ge_u32(bq8_group, c_2_u32);
let c_mask_6bit = ctx.mov_u32_imm(0x3F3F_3F3F);
let c_mask_4bit = ctx.mov_u32_imm(0x0F0F_0F0F);
let c_mask_2bit = ctx.mov_u32_imm(0x0303_0303);
let f32_zero = ctx.mov_f32_imm(0.0);
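// One running dot-product accumulator per activation vector (output column).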
let mut accs = Vec::with_capacity(m as usize);
for _ in 0..m {
accs.push(ctx.mov_f32_imm(0.0));
}
let row_idx = ctx.mov_u32_imm(0);
ctx.add_u32_reg_inplace(row_idx, block_id);
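// Grid-stride loop over output rows.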
ctx.label("fused_row_loop");
let row_oob = ctx.setp_ge_u32(row_idx, n_dim);
ctx.branch_if(row_oob, "fused_exit");
let row_off = ctx.mul_wide_u32_reg(row_idx, row_bytes);
let row_base = ctx.add_u64(w_ptr, row_off);
for acc in &accs {
ctx.mov_f32_reg(*acc, f32_zero);
}
let sb_idx = ctx.mov_u32_imm(0);
ctx.add_u32_reg_inplace(sb_idx, half_warp_id);
ctx.label("fused_sb_loop");
let sb_done = ctx.setp_ge_u32(sb_idx, num_sb);
ctx.branch_if(sb_done, "fused_sb_end");
let sb_off = ctx.mul_wide_u32(sb_idx, Q4K_SUPER_BLOCK_BYTES);
let sb_addr = ctx.add_u64(row_base, sb_off);
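// Super-block header: f16 scale d at byte 0, f16 min scale dmin at byte 2,
// then 12 bytes of packed 6-bit scales/mins starting at byte 4.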
let d_f16 = ctx.ld_global_f16(sb_addr);
let d = ctx.cvt_f32_f16(d_f16);
let dmin_addr = ctx.add_u64(sb_addr, c_2_64);
let dmin_f16 = ctx.ld_global_f16(dmin_addr);
let dmin = ctx.cvt_f32_f16(dmin_f16);
let sc_base = ctx.add_u64(sb_addr, c_4_64);
let sc03 = ctx.ld_global_u32(sc_base);
let sc47_addr = ctx.add_u64(sc_base, c_4_64);
let sc47 = ctx.ld_global_u32(sc47_addr);
let sc811_addr = ctx.add_u64(sc_base, c_8_64);
let sc811 = ctx.ld_global_u32(sc811_addr);
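// Unpack the 6-bit scales and mins: sub-blocks 0-3 take the low 6 bits of
// scale bytes 0-3 (scales) and 4-7 (mins); sub-blocks 4-7 combine the
// nibbles of bytes 8-11 with the top 2 bits of bytes 0-3 / 4-7.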
let sc_lo4 = ctx.and_u32(sc03, c_mask_6bit);
let mn_lo4 = ctx.and_u32(sc47, c_mask_6bit);
let sc_hi_low = ctx.and_u32(sc811, c_mask_4bit);
let t = ctx.shr_u32_imm(sc03, 6);
let t = ctx.and_u32(t, c_mask_2bit);
let sc_hi_top = ctx.shl_u32_imm(t, 4);
let sc_hi4 = ctx.or_u32(sc_hi_low, sc_hi_top);
let mn_hi_raw = ctx.shr_u32_imm(sc811, 4);
let mn_hi_low = ctx.and_u32(mn_hi_raw, c_mask_4bit);
let t = ctx.shr_u32_imm(sc47, 6);
let t = ctx.and_u32(t, c_mask_2bit);
let mn_hi_top = ctx.shl_u32_imm(t, 4);
let mn_hi4 = ctx.or_u32(mn_hi_low, mn_hi_top);
let sc_src = ctx.selp_u32(p_hi, sc_hi4, sc_lo4);
let mn_src = ctx.selp_u32(p_hi, mn_hi4, mn_lo4);
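// sc0/mn0 belong to this lane's low-nibble sub-block, sc1/mn1 to the
// immediately following high-nibble sub-block.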
let sc0 = ctx.bfe_u32_reg(sc_src, byte_shift, 8);
let sc1 = ctx.bfe_u32_reg(sc_src, byte_shift_hi, 8);
let mn0 = ctx.bfe_u32_reg(mn_src, byte_shift, 8);
let mn1 = ctx.bfe_u32_reg(mn_src, byte_shift_hi, 8);
let q4_addr = ctx.add_u64(sb_addr, q4_off_64);
let v0 = ctx.ld_global_u32(q4_addr);
let v1_addr = ctx.add_u64(q4_addr, c_16_64);
let v1 = ctx.ld_global_u32(v1_addr);
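// Split each 32-bit word into low and high nibbles; each masked word then
// holds four 4-bit weights, one per byte lane.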
let v0_lo = ctx.and_u32(v0, c_mask_4bit);
let v1_lo = ctx.and_u32(v1, c_mask_4bit);
let v0_hi = ctx.shr_u32_imm(v0, 4);
let v0_hi = ctx.and_u32(v0_hi, c_mask_4bit);
let v1_hi = ctx.shr_u32_imm(v1, 4);
let v1_hi = ctx.and_u32(v1_hi, c_mask_4bit);
let sb_k_base = ctx.mul_u32_reg(sb_idx, c_256_u32);
let bq8_k_off = ctx.shl_u32_imm(bq8_offset, 5);
let lig_k_off = ctx.shl_u32_imm(lane_in_group, 2);
let k_off_base = ctx.add_u32_reg(sb_k_base, bq8_k_off);
let k_off_lo = ctx.add_u32_reg(k_off_base, lig_k_off);
// High nibbles of the same bytes belong to the next 32-weight sub-block,
// so their activations sit 32 elements further along k.
let k_off_hi = ctx.add_u32(k_off_lo, 32);
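// Accumulate this lane's 16 quantized weights against each of the m
// activation vectors.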
for mi in 0..m {
let x_m_off = if mi == 0 {
ctx.mov_u64_imm(0)
} else {
let mi_reg = ctx.mov_u32_imm(mi);
ctx.mul_wide_u32_reg(mi_reg, x_vec_stride)
};
let x_m_base = ctx.add_u64(x_ptr, x_m_off);
let lo_byte_off = ctx.mul_wide_u32(k_off_lo, 4);
let x_lo_addr = ctx.add_u64(x_m_base, lo_byte_off);
let x0 = ctx.ld_global_f32(x_lo_addr);
let x1_addr = ctx.add_u64(x_lo_addr, c_4_u64);
let x1 = ctx.ld_global_f32(x1_addr);
let x2_addr = ctx.add_u64(x_lo_addr, c_8_64);
let x2 = ctx.ld_global_f32(x2_addr);
let x3_addr = ctx.add_u64(x_lo_addr, c_12_64);
let x3 = ctx.ld_global_f32(x3_addr);
let nib0 = ctx.and_u32_imm(v0_lo, 0x0F);
let nib1_sh = ctx.shr_u32_imm(v0_lo, 8);
let nib1 = ctx.and_u32_imm(nib1_sh, 0x0F);
let nib2_sh = ctx.shr_u32_imm(v0_lo, 16);
let nib2 = ctx.and_u32_imm(nib2_sh, 0x0F);
let nib3_sh = ctx.shr_u32_imm(v0_lo, 24);
let nib3 = ctx.and_u32_imm(nib3_sh, 0x0F);
let nib0_f = ctx.cvt_f32_u32(nib0);
let nib1_f = ctx.cvt_f32_u32(nib1);
let nib2_f = ctx.cvt_f32_u32(nib2);
let nib3_f = ctx.cvt_f32_u32(nib3);
let wx_lo = ctx.mul_f32(nib0_f, x0);
ctx.fma_f32_inplace(wx_lo, nib1_f, x1);
ctx.fma_f32_inplace(wx_lo, nib2_f, x2);
ctx.fma_f32_inplace(wx_lo, nib3_f, x3);
let sx_lo = ctx.add_f32(x0, x1);
let t = ctx.add_f32(x2, x3);
ctx.add_f32_inplace(sx_lo, t);
let v1_k_off = ctx.add_u32(k_off_lo, 16);
let v1_byte_off = ctx.mul_wide_u32(v1_k_off, 4);
let x_v1_addr = ctx.add_u64(x_m_base, v1_byte_off);
let x4 = ctx.ld_global_f32(x_v1_addr);
let x5_addr = ctx.add_u64(x_v1_addr, c_4_u64);
let x5 = ctx.ld_global_f32(x5_addr);
let x6_addr = ctx.add_u64(x_v1_addr, c_8_64);
let x6 = ctx.ld_global_f32(x6_addr);
let x7_addr = ctx.add_u64(x_v1_addr, c_12_64);
let x7 = ctx.ld_global_f32(x7_addr);
let v1_nib0 = ctx.and_u32_imm(v1_lo, 0x0F);
let v1_nib1_sh = ctx.shr_u32_imm(v1_lo, 8);
let v1_nib1 = ctx.and_u32_imm(v1_nib1_sh, 0x0F);
let v1_nib2_sh = ctx.shr_u32_imm(v1_lo, 16);
let v1_nib2 = ctx.and_u32_imm(v1_nib2_sh, 0x0F);
let v1_nib3_sh = ctx.shr_u32_imm(v1_lo, 24);
let v1_nib3 = ctx.and_u32_imm(v1_nib3_sh, 0x0F);
let v1_nib0_f = ctx.cvt_f32_u32(v1_nib0);
let v1_nib1_f = ctx.cvt_f32_u32(v1_nib1);
let v1_nib2_f = ctx.cvt_f32_u32(v1_nib2);
let v1_nib3_f = ctx.cvt_f32_u32(v1_nib3);
let wx_lo2 = ctx.mul_f32(v1_nib0_f, x4);
ctx.fma_f32_inplace(wx_lo2, v1_nib1_f, x5);
ctx.fma_f32_inplace(wx_lo2, v1_nib2_f, x6);
ctx.fma_f32_inplace(wx_lo2, v1_nib3_f, x7);
ctx.add_f32_inplace(wx_lo, wx_lo2);
let sx_lo2 = ctx.add_f32(x4, x5);
let t = ctx.add_f32(x6, x7);
ctx.add_f32_inplace(sx_lo2, t);
ctx.add_f32_inplace(sx_lo, sx_lo2);
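// Q4_K dequantization is affine (w = d*sc*q - dmin*mn), so the sub-block
// contribution is d*sc times the nibble-weighted sum minus dmin*mn times
// the plain activation sum.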
let sc0_f = ctx.cvt_f32_u32(sc0);
let mn0_f = ctx.cvt_f32_u32(mn0);
let d_sc0 = ctx.mul_f32(d, sc0_f);
let dmin_mn0 = ctx.mul_f32(dmin, mn0_f);
let contrib_lo = ctx.mul_f32(d_sc0, wx_lo);
let neg_min_lo = ctx.mul_f32(dmin_mn0, sx_lo);
let result_lo = ctx.sub_f32(contrib_lo, neg_min_lo);
ctx.add_f32_inplace(accs[mi as usize], result_lo);
let hi_byte_off = ctx.mul_wide_u32(k_off_hi, 4);
let x_hi_addr = ctx.add_u64(x_m_base, hi_byte_off);
let xh0 = ctx.ld_global_f32(x_hi_addr);
let xh1_addr = ctx.add_u64(x_hi_addr, c_4_u64);
let xh1 = ctx.ld_global_f32(xh1_addr);
let xh2_addr = ctx.add_u64(x_hi_addr, c_8_64);
let xh2 = ctx.ld_global_f32(xh2_addr);
let xh3_addr = ctx.add_u64(x_hi_addr, c_12_64);
let xh3 = ctx.ld_global_f32(xh3_addr);
let hnib0 = ctx.and_u32_imm(v0_hi, 0x0F);
let hnib1_sh = ctx.shr_u32_imm(v0_hi, 8);
let hnib1 = ctx.and_u32_imm(hnib1_sh, 0x0F);
let hnib2_sh = ctx.shr_u32_imm(v0_hi, 16);
let hnib2 = ctx.and_u32_imm(hnib2_sh, 0x0F);
let hnib3_sh = ctx.shr_u32_imm(v0_hi, 24);
let hnib3 = ctx.and_u32_imm(hnib3_sh, 0x0F);
let hnib0_f = ctx.cvt_f32_u32(hnib0);
let hnib1_f = ctx.cvt_f32_u32(hnib1);
let hnib2_f = ctx.cvt_f32_u32(hnib2);
let hnib3_f = ctx.cvt_f32_u32(hnib3);
let wx_hi = ctx.mul_f32(hnib0_f, xh0);
ctx.fma_f32_inplace(wx_hi, hnib1_f, xh1);
ctx.fma_f32_inplace(wx_hi, hnib2_f, xh2);
ctx.fma_f32_inplace(wx_hi, hnib3_f, xh3);
let sx_hi = ctx.add_f32(xh0, xh1);
let t = ctx.add_f32(xh2, xh3);
ctx.add_f32_inplace(sx_hi, t);
let v1_hi_k_off = ctx.add_u32(k_off_hi, 16);
let v1_hi_byte_off = ctx.mul_wide_u32(v1_hi_k_off, 4);
let x_v1hi_addr = ctx.add_u64(x_m_base, v1_hi_byte_off);
let xh4 = ctx.ld_global_f32(x_v1hi_addr);
let xh5_addr = ctx.add_u64(x_v1hi_addr, c_4_u64);
let xh5 = ctx.ld_global_f32(xh5_addr);
let xh6_addr = ctx.add_u64(x_v1hi_addr, c_8_64);
let xh6 = ctx.ld_global_f32(xh6_addr);
let xh7_addr = ctx.add_u64(x_v1hi_addr, c_12_64);
let xh7 = ctx.ld_global_f32(xh7_addr);
let v1_hnib0 = ctx.and_u32_imm(v1_hi, 0x0F);
let v1_hnib1_sh = ctx.shr_u32_imm(v1_hi, 8);
let v1_hnib1 = ctx.and_u32_imm(v1_hnib1_sh, 0x0F);
let v1_hnib2_sh = ctx.shr_u32_imm(v1_hi, 16);
let v1_hnib2 = ctx.and_u32_imm(v1_hnib2_sh, 0x0F);
let v1_hnib3_sh = ctx.shr_u32_imm(v1_hi, 24);
let v1_hnib3 = ctx.and_u32_imm(v1_hnib3_sh, 0x0F);
let v1_hnib0_f = ctx.cvt_f32_u32(v1_hnib0);
let v1_hnib1_f = ctx.cvt_f32_u32(v1_hnib1);
let v1_hnib2_f = ctx.cvt_f32_u32(v1_hnib2);
let v1_hnib3_f = ctx.cvt_f32_u32(v1_hnib3);
let wx_hi2 = ctx.mul_f32(v1_hnib0_f, xh4);
ctx.fma_f32_inplace(wx_hi2, v1_hnib1_f, xh5);
ctx.fma_f32_inplace(wx_hi2, v1_hnib2_f, xh6);
ctx.fma_f32_inplace(wx_hi2, v1_hnib3_f, xh7);
ctx.add_f32_inplace(wx_hi, wx_hi2);
let sx_hi2 = ctx.add_f32(xh4, xh5);
let t = ctx.add_f32(xh6, xh7);
ctx.add_f32_inplace(sx_hi2, t);
ctx.add_f32_inplace(sx_hi, sx_hi2);
let sc1_f = ctx.cvt_f32_u32(sc1);
let mn1_f = ctx.cvt_f32_u32(mn1);
let d_sc1 = ctx.mul_f32(d, sc1_f);
let dmin_mn1 = ctx.mul_f32(dmin, mn1_f);
let contrib_hi = ctx.mul_f32(d_sc1, wx_hi);
let neg_min_hi = ctx.mul_f32(dmin_mn1, sx_hi);
let result_hi = ctx.sub_f32(contrib_hi, neg_min_hi);
ctx.add_f32_inplace(accs[mi as usize], result_hi);
}
ctx.add_u32_reg_inplace(sb_idx, num_hw);
ctx.branch("fused_sb_loop");
ctx.label("fused_sb_end");
for acc in &accs {
let t = ctx.shfl_down_f32(*acc, 8, 0xFFFF_FFFF);
ctx.add_f32_inplace(*acc, t);
let t = ctx.shfl_down_f32(*acc, 4, 0xFFFF_FFFF);
ctx.add_f32_inplace(*acc, t);
let t = ctx.shfl_down_f32(*acc, 2, 0xFFFF_FFFF);
ctx.add_f32_inplace(*acc, t);
let t = ctx.shfl_down_f32(*acc, 1, 0xFFFF_FFFF);
ctx.add_f32_inplace(*acc, t);
}
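// Lane 0 of every half-warp stages its m partial sums in shared memory at
// index half_warp_id * m + mi.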
let c_zero_u32 = ctx.mov_u32_imm(0);
let is_lane0 = ctx.setp_eq_u32(half_lane, c_zero_u32);
let smem_base = ctx.mov_u64_imm(0);
ctx.branch_if_not(is_lane0, "fused_skip_smem_write");
let c_m_u32 = ctx.mov_u32_imm(m);
for (mi, acc) in accs.iter().enumerate() {
let mi_reg = ctx.mov_u32_imm(mi as u32);
let hw_m = ctx.mul_u32_reg(half_warp_id, c_m_u32);
let smem_idx = ctx.add_u32_reg(hw_m, mi_reg);
let smem_off = ctx.mul_wide_u32(smem_idx, 4);
let smem_addr = ctx.add_u64(smem_base, smem_off);
ctx.st_shared_f32(smem_addr, *acc);
}
ctx.label("fused_skip_smem_write");
ctx.bar_sync(0);
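// Thread 0 (first half-warp, half_lane == 0) sums the per-half-warp
// partials and writes the m outputs for this row.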
let c_zero2 = ctx.mov_u32_imm(0);
let is_first_hw = ctx.setp_eq_u32(half_warp_id, c_zero2);
ctx.branch_if_not(is_first_hw, "fused_skip_output");
let c_zero3 = ctx.mov_u32_imm(0);
let is_lane0_first = ctx.setp_eq_u32(half_lane, c_zero3);
ctx.branch_if_not(is_lane0_first, "fused_skip_output");
let c_m_red = ctx.mov_u32_imm(m);
for mi in 0..m {
let total = ctx.mov_f32_imm(0.0);
for hw in 0..num_half_warps {
let mi_reg = ctx.mov_u32_imm(mi);
let hw_reg = ctx.mov_u32_imm(hw);
let hw_m = ctx.mul_u32_reg(hw_reg, c_m_red);
let idx = ctx.add_u32_reg(hw_m, mi_reg);
let off = ctx.mul_wide_u32(idx, 4);
let addr = ctx.add_u64(smem_base, off);
let val = ctx.ld_shared_f32(addr);
ctx.add_f32_inplace(total, val);
}
let mi_reg = ctx.mov_u32_imm(mi);
let row_m = ctx.mul_u32_reg(row_idx, c_m_red);
let out_idx = ctx.add_u32_reg(row_m, mi_reg);
let out_off = ctx.mul_wide_u32(out_idx, 4);
let out_addr = ctx.add_u64(y_ptr, out_off);
ctx.st_global_f32(out_addr, total);
}
ctx.label("fused_skip_output");
ctx.bar_sync(1);
ctx.add_u32_reg_inplace(row_idx, grid_dim);
ctx.branch("fused_row_loop");
ctx.label("fused_exit");
ctx.ret();
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_fused_kernel_builds_ptx() {
let kernel = FusedFp32Q4KGemvKernel::new(1536, 1536, 4);
let ptx = kernel.emit_ptx();
assert!(ptx.contains("fused_fp32_q4k_gemv"));
assert!(ptx.contains(".entry"));
assert!(ptx.is_ascii(), "PTX must be pure ASCII");
}
#[test]
fn test_fused_kernel_m1() {
let kernel = FusedFp32Q4KGemvKernel::new(1536, 1536, 1);
let ptx = kernel.emit_ptx();
assert!(ptx.is_ascii());
assert!(ptx.contains(".entry"));
}
}