use super::{Q4K_SUPER_BLOCK_BYTES, Q4K_SUPER_BLOCK_SIZE};
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Scalar PTX kernel description for a Q4-K (4-bit super-block quantized)
/// weight matrix times Q8 (8-bit quantized) activation dot product.
///
/// NOTE(review): `k` and `n` are stored by the constructor but are not read
/// by `build_ptx` — the dimensions are passed as the `k_dim`/`n_dim` kernel
/// parameters at launch. Presumably launch-side code uses these fields;
/// confirm against callers.
#[derive(Debug, Clone)]
pub struct Q4KQ8DotKernel {
/// Dot-product length (elements per row); mirrors the `k_dim` kernel param.
pub k: u32,
/// Number of output rows; mirrors the `n_dim` kernel param.
pub n: u32,
}
impl Q4KQ8DotKernel {
#[must_use]
pub fn new(k: u32, n: u32) -> Self {
Self { k, n }
}
}
impl Kernel for Q4KQ8DotKernel {
fn name(&self) -> &str {
"q4k_q8_dot"
}
/// Emits PTX computing, per output row,
/// `y[row] = sum_sb (d_w * scale0) * q8_d * sum_i w_nibble[i] * q8[i]`.
///
/// Launch contract implied by the code: one CTA per output row
/// (`ctaid.x` is the row index) with 32 threads cooperating as one warp
/// (lane = `tid.x % 32`); each lane handles one byte column of every
/// 32-byte quant group.
fn build_ptx(&self) -> PtxKernel {
PtxKernel::new("q4k_q8_dot")
.param(PtxType::U64, "y_ptr") .param(PtxType::U64, "w_ptr") .param(PtxType::U64, "x_ptr") .param(PtxType::U32, "k_dim")
.param(PtxType::U32, "n_dim")
.build(|ctx| {
// Row index = CTA id; lane id within the warp.
let block_id = ctx.special_reg(PtxReg::CtaIdX);
let thread_id = ctx.special_reg(PtxReg::TidX);
let lane_id = ctx.rem_u32(thread_id, 32);
// Guard: CTAs past the last row exit immediately.
let n_dim = ctx.load_param_u32("n_dim");
let oob = ctx.setp_ge_u32(block_id, n_dim);
ctx.branch_if(oob, "exit");
let k_dim = ctx.load_param_u32("k_dim");
let y_ptr = ctx.load_param_u64("y_ptr");
let w_ptr = ctx.load_param_u64("w_ptr");
let x_ptr = ctx.load_param_u64("x_ptr");
// Per-lane f32 partial sum, reduced across the warp at the end.
let float_acc = ctx.mov_f32_imm(0.0);
// num_sb = ceil(k_dim / Q4K_SUPER_BLOCK_SIZE).
let num_sb = ctx.add_u32(k_dim, Q4K_SUPER_BLOCK_SIZE - 1);
let num_sb = ctx.div_u32(num_sb, Q4K_SUPER_BLOCK_SIZE);
// Base address of this row's super-blocks in the weight buffer.
let sb_bytes = ctx.mov_u32_imm(Q4K_SUPER_BLOCK_BYTES);
let row_bytes = ctx.mul_u32_reg(num_sb, sb_bytes);
let row_offset = ctx.mul_wide_u32_reg(block_id, row_bytes);
let row_base = ctx.add_u64(w_ptr, row_offset);
// Q8 activation blocks are 36 bytes apart: 32 quant bytes, then the
// f16 scale (read at byte offset 32 further below).
let q8_block_bytes = ctx.mov_u32_imm(36);
let mask_4bit = ctx.mov_u32_imm(0x0F);
let four_shift = ctx.mov_u32_imm(4);
// ---- per-super-block loop (emitted as a PTX label/branch loop) ----
let sb_idx = ctx.mov_u32_imm(0);
ctx.label("sb_loop");
let sb_done = ctx.setp_ge_u32(sb_idx, num_sb);
ctx.branch_if(sb_done, "sb_loop_end");
let sb_offset = ctx.mul_wide_u32(sb_idx, Q4K_SUPER_BLOCK_BYTES);
let sb_addr = ctx.add_u64(row_base, sb_offset);
// Super-block header: f16 weight scale `d` at byte 0.
let d_f16 = ctx.ld_global_f16(sb_addr);
let d_w = ctx.cvt_f32_f16(d_f16);
// One scale byte at offset 4; only its low 6 bits are used.
// NOTE(review): this single 6-bit scale is applied to all eight nibble
// groups of the super-block — confirm the packer shares one sub-scale.
let four_64 = ctx.mov_u64_imm(4);
let scales_addr = ctx.add_u64(sb_addr, four_64);
let scale_byte = ctx.ld_global_u8(scales_addr);
let scale_u32 = ctx.cvt_u32_u8(scale_byte);
let mask_6bit = ctx.mov_u32_imm(0x3F);
let scale0 = ctx.and_u32(scale_u32, mask_6bit);
let scale0_f = ctx.cvt_f32_u32(scale0);
// Combined weight-side scale for this super-block.
let ds = ctx.mul_f32(d_w, scale0_f);
// Packed 4-bit quants start at byte 16 of the super-block.
let sixteen_64 = ctx.mov_u64_imm(16);
let qs_base = ctx.add_u64(sb_addr, sixteen_64);
// Integer accumulator; the emitted `mov` sits inside the PTX loop, so
// it is re-zeroed on every iteration.
let int_acc = ctx.mov_u32_imm(0);
// Each super-block consumes 8 consecutive 36-byte Q8 blocks.
let q8_base_idx = ctx.mul_u32(sb_idx, 8);
let lane_64 = ctx.cvt_u64_u32(lane_id);
// Group 0: Q8 block base+0 vs low nibbles of qs bytes [0, 32).
let zero_imm = ctx.mov_u32_imm(0);
let q8_idx0 = ctx.add_u32_reg(q8_base_idx, zero_imm);
let q8_offset0 = ctx.mul_wide_u32_reg(q8_idx0, q8_block_bytes);
let q8_addr0 = ctx.add_u64(x_ptr, q8_offset0);
let q8_val_addr0 = ctx.add_u64(q8_addr0, lane_64);
let q8_val0 = ctx.ld_global_u8(q8_val_addr0);
// Q8 values are signed: sign-extend the loaded byte.
let q8_val0_s32 = ctx.cvt_s32_u8_sx(q8_val0);
let qs_addr0 = ctx.add_u64(qs_base, lane_64);
let packed0 = ctx.ld_global_u8(qs_addr0);
let packed0_u32 = ctx.cvt_u32_u8(packed0);
let w0 = ctx.and_u32(packed0_u32, mask_4bit);
// Nibbles are unsigned (0..=15); zero-extended reinterpret to s32.
let w0_s32 = ctx.cvt_s32_u32(w0);
let prod0 = ctx.mul_lo_s32(w0_s32, q8_val0_s32);
// u32 add of two's-complement products: same bit result as s32 add.
ctx.add_u32_reg_inplace(int_acc, prod0);
// Group 1: Q8 block base+1 vs high nibbles of the same qs bytes.
let one_imm = ctx.mov_u32_imm(1);
let q8_idx1 = ctx.add_u32_reg(q8_base_idx, one_imm);
let q8_offset1 = ctx.mul_wide_u32_reg(q8_idx1, q8_block_bytes);
let q8_addr1 = ctx.add_u64(x_ptr, q8_offset1);
let q8_val_addr1 = ctx.add_u64(q8_addr1, lane_64);
let q8_val1 = ctx.ld_global_u8(q8_val_addr1);
let q8_val1_s32 = ctx.cvt_s32_u8_sx(q8_val1);
// High nibble: the byte was zero-extended, so no post-shift mask needed.
let w1 = ctx.shr_u32(packed0_u32, four_shift);
let w1_s32 = ctx.cvt_s32_u32(w1);
let prod1 = ctx.mul_lo_s32(w1_s32, q8_val1_s32);
ctx.add_u32_reg_inplace(int_acc, prod1);
// Group 2: Q8 block base+2 vs low nibbles of qs bytes [32, 64).
let two_imm = ctx.mov_u32_imm(2);
let q8_idx2 = ctx.add_u32_reg(q8_base_idx, two_imm);
let q8_offset2 = ctx.mul_wide_u32_reg(q8_idx2, q8_block_bytes);
let q8_addr2 = ctx.add_u64(x_ptr, q8_offset2);
let q8_val_addr2 = ctx.add_u64(q8_addr2, lane_64);
let q8_val2 = ctx.ld_global_u8(q8_val_addr2);
let q8_val2_s32 = ctx.cvt_s32_u8_sx(q8_val2);
let thirty_two_64 = ctx.mov_u64_imm(32);
let qs_addr2 = ctx.add_u64(qs_base, thirty_two_64);
let qs_addr2 = ctx.add_u64(qs_addr2, lane_64);
let packed2 = ctx.ld_global_u8(qs_addr2);
let packed2_u32 = ctx.cvt_u32_u8(packed2);
let w2 = ctx.and_u32(packed2_u32, mask_4bit);
let w2_s32 = ctx.cvt_s32_u32(w2);
let prod2 = ctx.mul_lo_s32(w2_s32, q8_val2_s32);
ctx.add_u32_reg_inplace(int_acc, prod2);
// Group 3: Q8 block base+3 vs high nibbles of qs bytes [32, 64).
let three_imm = ctx.mov_u32_imm(3);
let q8_idx3 = ctx.add_u32_reg(q8_base_idx, three_imm);
let q8_offset3 = ctx.mul_wide_u32_reg(q8_idx3, q8_block_bytes);
let q8_addr3 = ctx.add_u64(x_ptr, q8_offset3);
let q8_val_addr3 = ctx.add_u64(q8_addr3, lane_64);
let q8_val3 = ctx.ld_global_u8(q8_val_addr3);
let q8_val3_s32 = ctx.cvt_s32_u8_sx(q8_val3);
let w3 = ctx.shr_u32(packed2_u32, four_shift);
let w3_s32 = ctx.cvt_s32_u32(w3);
let prod3 = ctx.mul_lo_s32(w3_s32, q8_val3_s32);
ctx.add_u32_reg_inplace(int_acc, prod3);
// Group 4: Q8 block base+4 vs low nibbles of qs bytes [64, 96).
let four_imm = ctx.mov_u32_imm(4);
let q8_idx4 = ctx.add_u32_reg(q8_base_idx, four_imm);
let q8_offset4 = ctx.mul_wide_u32_reg(q8_idx4, q8_block_bytes);
let q8_addr4 = ctx.add_u64(x_ptr, q8_offset4);
let q8_val_addr4 = ctx.add_u64(q8_addr4, lane_64);
let q8_val4 = ctx.ld_global_u8(q8_val_addr4);
let q8_val4_s32 = ctx.cvt_s32_u8_sx(q8_val4);
let sixty_four_64 = ctx.mov_u64_imm(64);
let qs_addr4 = ctx.add_u64(qs_base, sixty_four_64);
let qs_addr4 = ctx.add_u64(qs_addr4, lane_64);
let packed4 = ctx.ld_global_u8(qs_addr4);
let packed4_u32 = ctx.cvt_u32_u8(packed4);
let w4 = ctx.and_u32(packed4_u32, mask_4bit);
let w4_s32 = ctx.cvt_s32_u32(w4);
let prod4 = ctx.mul_lo_s32(w4_s32, q8_val4_s32);
ctx.add_u32_reg_inplace(int_acc, prod4);
// Group 5: Q8 block base+5 vs high nibbles of qs bytes [64, 96).
let five_imm = ctx.mov_u32_imm(5);
let q8_idx5 = ctx.add_u32_reg(q8_base_idx, five_imm);
let q8_offset5 = ctx.mul_wide_u32_reg(q8_idx5, q8_block_bytes);
let q8_addr5 = ctx.add_u64(x_ptr, q8_offset5);
let q8_val_addr5 = ctx.add_u64(q8_addr5, lane_64);
let q8_val5 = ctx.ld_global_u8(q8_val_addr5);
let q8_val5_s32 = ctx.cvt_s32_u8_sx(q8_val5);
let w5 = ctx.shr_u32(packed4_u32, four_shift);
let w5_s32 = ctx.cvt_s32_u32(w5);
let prod5 = ctx.mul_lo_s32(w5_s32, q8_val5_s32);
ctx.add_u32_reg_inplace(int_acc, prod5);
// Group 6: Q8 block base+6 vs low nibbles of qs bytes [96, 128).
let six_imm = ctx.mov_u32_imm(6);
let q8_idx6 = ctx.add_u32_reg(q8_base_idx, six_imm);
let q8_offset6 = ctx.mul_wide_u32_reg(q8_idx6, q8_block_bytes);
let q8_addr6 = ctx.add_u64(x_ptr, q8_offset6);
let q8_val_addr6 = ctx.add_u64(q8_addr6, lane_64);
let q8_val6 = ctx.ld_global_u8(q8_val_addr6);
let q8_val6_s32 = ctx.cvt_s32_u8_sx(q8_val6);
let ninety_six_64 = ctx.mov_u64_imm(96);
let qs_addr6 = ctx.add_u64(qs_base, ninety_six_64);
let qs_addr6 = ctx.add_u64(qs_addr6, lane_64);
let packed6 = ctx.ld_global_u8(qs_addr6);
let packed6_u32 = ctx.cvt_u32_u8(packed6);
let w6 = ctx.and_u32(packed6_u32, mask_4bit);
let w6_s32 = ctx.cvt_s32_u32(w6);
let prod6 = ctx.mul_lo_s32(w6_s32, q8_val6_s32);
ctx.add_u32_reg_inplace(int_acc, prod6);
// Group 7: Q8 block base+7 vs high nibbles of qs bytes [96, 128).
let seven_imm = ctx.mov_u32_imm(7);
let q8_idx7 = ctx.add_u32_reg(q8_base_idx, seven_imm);
let q8_offset7 = ctx.mul_wide_u32_reg(q8_idx7, q8_block_bytes);
let q8_addr7 = ctx.add_u64(x_ptr, q8_offset7);
let q8_val_addr7 = ctx.add_u64(q8_addr7, lane_64);
let q8_val7 = ctx.ld_global_u8(q8_val_addr7);
let q8_val7_s32 = ctx.cvt_s32_u8_sx(q8_val7);
let w7 = ctx.shr_u32(packed6_u32, four_shift);
let w7_s32 = ctx.cvt_s32_u32(w7);
let prod7 = ctx.mul_lo_s32(w7_s32, q8_val7_s32);
ctx.add_u32_reg_inplace(int_acc, prod7);
// Activation-side f16 scale `d`: stored after the 32 quant bytes of Q8
// block base+0 (byte offset 32).
// NOTE(review): only block 0's scale is read, yet products from all 8
// Q8 blocks are folded in — confirm the quantizer shares one scale
// across the 8 blocks of a super-block.
let thirty_two_64_q8 = ctx.mov_u64_imm(32);
let q8_d_addr = ctx.add_u64(q8_addr0, thirty_two_64_q8);
let q8_d_f16 = ctx.ld_global_f16(q8_d_addr);
let q8_d = ctx.cvt_f32_f16(q8_d_f16);
// Scale the integer partial sum and fold into the f32 accumulator.
let int_acc_f = ctx.cvt_f32_s32(int_acc);
let combined_scale = ctx.mul_f32(ds, q8_d);
let scaled_result = ctx.mul_f32(int_acc_f, combined_scale);
ctx.add_f32_inplace(float_acc, scaled_result);
ctx.add_u32_inplace(sb_idx, 1);
ctx.branch("sb_loop");
ctx.label("sb_loop_end");
// Warp reduction: shfl.down with full-warp mask, strides 16/8/4/2/1,
// leaves the warp total in lane 0.
let tmp16 = ctx.shfl_down_f32(float_acc, 16, 0xFFFF_FFFF);
ctx.add_f32_inplace(float_acc, tmp16);
let tmp8 = ctx.shfl_down_f32(float_acc, 8, 0xFFFF_FFFF);
ctx.add_f32_inplace(float_acc, tmp8);
let tmp4 = ctx.shfl_down_f32(float_acc, 4, 0xFFFF_FFFF);
ctx.add_f32_inplace(float_acc, tmp4);
let tmp2 = ctx.shfl_down_f32(float_acc, 2, 0xFFFF_FFFF);
ctx.add_f32_inplace(float_acc, tmp2);
let tmp1 = ctx.shfl_down_f32(float_acc, 1, 0xFFFF_FFFF);
ctx.add_f32_inplace(float_acc, tmp1);
// Only lane 0 writes the result: y[row], 4 bytes per f32 element.
let one = ctx.mov_u32_imm(1);
let is_lane0 = ctx.setp_lt_u32(lane_id, one);
ctx.branch_if_not(is_lane0, "exit");
let y_offset = ctx.mul_wide_u32(block_id, 4);
let y_addr = ctx.add_u64(y_ptr, y_offset);
ctx.st_global_f32(y_addr, float_acc);
ctx.label("exit");
ctx.ret();
})
}
}
/// DP4A-accelerated variant of the Q4-K × Q8 dot product: packs four
/// weight nibbles and four activation bytes into u32 registers and uses a
/// dp4a instruction instead of four scalar multiply-adds.
///
/// NOTE(review): as with `Q4KQ8DotKernel`, `k` and `n` are stored but not
/// read by `build_ptx` (dimensions arrive via kernel params); presumably
/// used by launch-side code — confirm against callers.
#[derive(Debug, Clone)]
pub struct PackedDp4aQ4KQ8Kernel {
/// Dot-product length (elements per row); mirrors the `k_dim` kernel param.
pub k: u32,
/// Number of output rows; mirrors the `n_dim` kernel param.
pub n: u32,
}
impl PackedDp4aQ4KQ8Kernel {
#[must_use]
pub fn new(k: u32, n: u32) -> Self {
Self { k, n }
}
}
impl Kernel for PackedDp4aQ4KQ8Kernel {
fn name(&self) -> &str {
"packed_dp4a_q4k_q8"
}
/// Emits PTX equivalent to `q4k_q8_dot` but with the eight per-lane
/// multiply-adds per super-block replaced by two dp4a operations on
/// packed u32 operands.
///
/// Launch contract implied by the code: one CTA per output row
/// (`ctaid.x` is the row index), 32 threads acting as one warp.
fn build_ptx(&self) -> PtxKernel {
PtxKernel::new("packed_dp4a_q4k_q8")
.param(PtxType::U64, "y_ptr") .param(PtxType::U64, "w_ptr") .param(PtxType::U64, "x_ptr") .param(PtxType::U32, "k_dim")
.param(PtxType::U32, "n_dim")
.build(|ctx| {
// Row index = CTA id; lane id within the warp.
let block_id = ctx.special_reg(PtxReg::CtaIdX);
let thread_id = ctx.special_reg(PtxReg::TidX);
let lane_id = ctx.rem_u32(thread_id, 32);
// Guard: CTAs past the last row exit immediately.
let n_dim = ctx.load_param_u32("n_dim");
let oob = ctx.setp_ge_u32(block_id, n_dim);
ctx.branch_if(oob, "exit");
let k_dim = ctx.load_param_u32("k_dim");
let y_ptr = ctx.load_param_u64("y_ptr");
let w_ptr = ctx.load_param_u64("w_ptr");
let x_ptr = ctx.load_param_u64("x_ptr");
// Per-lane f32 partial sum, reduced across the warp at the end.
let float_acc = ctx.mov_f32_imm(0.0);
// num_sb = ceil(k_dim / Q4K_SUPER_BLOCK_SIZE).
let num_sb = ctx.add_u32(k_dim, Q4K_SUPER_BLOCK_SIZE - 1);
let num_sb = ctx.div_u32(num_sb, Q4K_SUPER_BLOCK_SIZE);
// Base address of this row's super-blocks in the weight buffer.
let sb_bytes = ctx.mov_u32_imm(Q4K_SUPER_BLOCK_BYTES);
let row_bytes = ctx.mul_u32_reg(num_sb, sb_bytes);
let row_offset = ctx.mul_wide_u32_reg(block_id, row_bytes);
let row_base = ctx.add_u64(w_ptr, row_offset);
let mask_0f = ctx.mov_u32_imm(0x0F);
// ---- per-super-block loop (emitted as a PTX label/branch loop) ----
let sb_idx = ctx.mov_u32_imm(0);
ctx.label("sb_loop");
let sb_done = ctx.setp_ge_u32(sb_idx, num_sb);
ctx.branch_if(sb_done, "sb_loop_end");
let sb_offset = ctx.mul_wide_u32(sb_idx, Q4K_SUPER_BLOCK_BYTES);
let sb_addr = ctx.add_u64(row_base, sb_offset);
// Super-block header: f16 weight scale `d` at byte 0.
let d_f16 = ctx.ld_global_f16(sb_addr);
let d_w = ctx.cvt_f32_f16(d_f16);
// One scale byte at offset 4; only its low 6 bits are used.
// NOTE(review): this single 6-bit scale is applied to all eight nibble
// groups of the super-block — confirm the packer shares one sub-scale.
let four_64 = ctx.mov_u64_imm(4);
let scales_addr = ctx.add_u64(sb_addr, four_64);
let scale_byte = ctx.ld_global_u8(scales_addr);
let scale_u32 = ctx.cvt_u32_u8(scale_byte);
let mask_6bit = ctx.mov_u32_imm(0x3F);
let scale0 = ctx.and_u32(scale_u32, mask_6bit);
let scale0_f = ctx.cvt_f32_u32(scale0);
// Combined weight-side scale for this super-block.
let ds = ctx.mul_f32(d_w, scale0_f);
// Packed 4-bit quants start at byte 16 of the super-block.
let sixteen_64 = ctx.mov_u64_imm(16);
let qs_base = ctx.add_u64(sb_addr, sixteen_64);
// dp4a accumulator; the emitted `mov` sits inside the PTX loop, so it
// is re-zeroed on every iteration.
let dp4a_acc = ctx.mov_u32_imm(0);
// Each super-block consumes 8 consecutive 36-byte Q8 blocks
// (32 quant bytes + trailing f16 scale, read at offset 32 below).
let q8_base_idx = ctx.mul_u32(sb_idx, 8);
let lane_64 = ctx.cvt_u64_u32(lane_id);
let q8_block_bytes = ctx.mov_u32_imm(36);
// Load one activation byte per lane from each of Q8 blocks base+0..3.
let zero_imm = ctx.mov_u32_imm(0);
let q8_idx0 = ctx.add_u32_reg(q8_base_idx, zero_imm);
let q8_offset0 = ctx.mul_wide_u32_reg(q8_idx0, q8_block_bytes);
let q8_addr0 = ctx.add_u64(x_ptr, q8_offset0);
let q8_val_addr0 = ctx.add_u64(q8_addr0, lane_64);
let q8_val0 = ctx.ld_global_u8(q8_val_addr0);
let q8_val0_u32 = ctx.cvt_u32_u8(q8_val0);
let one_imm = ctx.mov_u32_imm(1);
let q8_idx1 = ctx.add_u32_reg(q8_base_idx, one_imm);
let q8_offset1 = ctx.mul_wide_u32_reg(q8_idx1, q8_block_bytes);
let q8_addr1 = ctx.add_u64(x_ptr, q8_offset1);
let q8_val_addr1 = ctx.add_u64(q8_addr1, lane_64);
let q8_val1 = ctx.ld_global_u8(q8_val_addr1);
let q8_val1_u32 = ctx.cvt_u32_u8(q8_val1);
let two_imm = ctx.mov_u32_imm(2);
let q8_idx2 = ctx.add_u32_reg(q8_base_idx, two_imm);
let q8_offset2 = ctx.mul_wide_u32_reg(q8_idx2, q8_block_bytes);
let q8_addr2 = ctx.add_u64(x_ptr, q8_offset2);
let q8_val_addr2 = ctx.add_u64(q8_addr2, lane_64);
let q8_val2 = ctx.ld_global_u8(q8_val_addr2);
let q8_val2_u32 = ctx.cvt_u32_u8(q8_val2);
let three_imm = ctx.mov_u32_imm(3);
let q8_idx3 = ctx.add_u32_reg(q8_base_idx, three_imm);
let q8_offset3 = ctx.mul_wide_u32_reg(q8_idx3, q8_block_bytes);
let q8_addr3 = ctx.add_u64(x_ptr, q8_offset3);
let q8_val_addr3 = ctx.add_u64(q8_addr3, lane_64);
let q8_val3 = ctx.ld_global_u8(q8_val_addr3);
let q8_val3_u32 = ctx.cvt_u32_u8(q8_val3);
// Pack the four activation bytes little-endian into one u32:
// b0 | b1<<8 | b2<<16 | b3<<24.
let eight = ctx.mov_u32_imm(8);
let sixteen = ctx.mov_u32_imm(16);
let twenty_four = ctx.mov_u32_imm(24);
let q8_val1_shifted = ctx.shl_u32(q8_val1_u32, eight);
let q8_val2_shifted = ctx.shl_u32(q8_val2_u32, sixteen);
let q8_val3_shifted = ctx.shl_u32(q8_val3_u32, twenty_four);
let q8_packed_01 = ctx.or_u32(q8_val0_u32, q8_val1_shifted);
let q8_packed_23 = ctx.or_u32(q8_val2_shifted, q8_val3_shifted);
let q8_packed_0123 = ctx.or_u32(q8_packed_01, q8_packed_23);
// Weight side: one qs byte per lane from [0, 32) and [32, 64); each
// byte yields a low nibble (groups 0/2) and a high nibble (groups 1/3).
let qs_addr0 = ctx.add_u64(qs_base, lane_64);
let packed_01 = ctx.ld_global_u8(qs_addr0);
let packed_01_u32 = ctx.cvt_u32_u8(packed_01);
let thirty_two_64 = ctx.mov_u64_imm(32);
let qs_addr2 = ctx.add_u64(qs_base, thirty_two_64);
let qs_addr2 = ctx.add_u64(qs_addr2, lane_64);
let packed_23 = ctx.ld_global_u8(qs_addr2);
let packed_23_u32 = ctx.cvt_u32_u8(packed_23);
let four = ctx.mov_u32_imm(4);
let nibble0 = ctx.and_u32(packed_01_u32, mask_0f);
let nibble1 = ctx.shr_u32(packed_01_u32, four);
let nibble1 = ctx.and_u32(nibble1, mask_0f);
let nibble2 = ctx.and_u32(packed_23_u32, mask_0f);
let nibble3 = ctx.shr_u32(packed_23_u32, four);
let nibble3 = ctx.and_u32(nibble3, mask_0f);
// Pack the four nibbles into byte lanes matching the q8 packing order.
let nibble1_shifted = ctx.shl_u32(nibble1, eight);
let nibble2_shifted = ctx.shl_u32(nibble2, sixteen);
let nibble3_shifted = ctx.shl_u32(nibble3, twenty_four);
let w_packed_01 = ctx.or_u32(nibble0, nibble1_shifted);
let w_packed_23 = ctx.or_u32(nibble2_shifted, nibble3_shifted);
let w_packed_0123 = ctx.or_u32(w_packed_01, w_packed_23);
// Mixed-sign dot (per the helper name): unsigned weight nibbles times
// signed activation bytes, accumulated into dp4a_acc.
ctx.dp4a_u32_s32_inplace(dp4a_acc, w_packed_0123, q8_packed_0123);
// Same again for Q8 blocks base+4..7 and qs bytes [64, 96)/[96, 128).
let four_imm = ctx.mov_u32_imm(4);
let q8_idx4 = ctx.add_u32_reg(q8_base_idx, four_imm);
let q8_offset4 = ctx.mul_wide_u32_reg(q8_idx4, q8_block_bytes);
let q8_addr4 = ctx.add_u64(x_ptr, q8_offset4);
let q8_val_addr4 = ctx.add_u64(q8_addr4, lane_64);
let q8_val4 = ctx.ld_global_u8(q8_val_addr4);
let q8_val4_u32 = ctx.cvt_u32_u8(q8_val4);
let five_imm = ctx.mov_u32_imm(5);
let q8_idx5 = ctx.add_u32_reg(q8_base_idx, five_imm);
let q8_offset5 = ctx.mul_wide_u32_reg(q8_idx5, q8_block_bytes);
let q8_addr5 = ctx.add_u64(x_ptr, q8_offset5);
let q8_val_addr5 = ctx.add_u64(q8_addr5, lane_64);
let q8_val5 = ctx.ld_global_u8(q8_val_addr5);
let q8_val5_u32 = ctx.cvt_u32_u8(q8_val5);
let six_imm = ctx.mov_u32_imm(6);
let q8_idx6 = ctx.add_u32_reg(q8_base_idx, six_imm);
let q8_offset6 = ctx.mul_wide_u32_reg(q8_idx6, q8_block_bytes);
let q8_addr6 = ctx.add_u64(x_ptr, q8_offset6);
let q8_val_addr6 = ctx.add_u64(q8_addr6, lane_64);
let q8_val6 = ctx.ld_global_u8(q8_val_addr6);
let q8_val6_u32 = ctx.cvt_u32_u8(q8_val6);
let seven_imm = ctx.mov_u32_imm(7);
let q8_idx7 = ctx.add_u32_reg(q8_base_idx, seven_imm);
let q8_offset7 = ctx.mul_wide_u32_reg(q8_idx7, q8_block_bytes);
let q8_addr7 = ctx.add_u64(x_ptr, q8_offset7);
let q8_val_addr7 = ctx.add_u64(q8_addr7, lane_64);
let q8_val7 = ctx.ld_global_u8(q8_val_addr7);
let q8_val7_u32 = ctx.cvt_u32_u8(q8_val7);
let q8_val5_shifted = ctx.shl_u32(q8_val5_u32, eight);
let q8_val6_shifted = ctx.shl_u32(q8_val6_u32, sixteen);
let q8_val7_shifted = ctx.shl_u32(q8_val7_u32, twenty_four);
let q8_packed_45 = ctx.or_u32(q8_val4_u32, q8_val5_shifted);
let q8_packed_67 = ctx.or_u32(q8_val6_shifted, q8_val7_shifted);
let q8_packed_4567 = ctx.or_u32(q8_packed_45, q8_packed_67);
let sixty_four_64 = ctx.mov_u64_imm(64);
let qs_addr4 = ctx.add_u64(qs_base, sixty_four_64);
let qs_addr4 = ctx.add_u64(qs_addr4, lane_64);
let packed_45 = ctx.ld_global_u8(qs_addr4);
let packed_45_u32 = ctx.cvt_u32_u8(packed_45);
let ninety_six_64 = ctx.mov_u64_imm(96);
let qs_addr6 = ctx.add_u64(qs_base, ninety_six_64);
let qs_addr6 = ctx.add_u64(qs_addr6, lane_64);
let packed_67 = ctx.ld_global_u8(qs_addr6);
let packed_67_u32 = ctx.cvt_u32_u8(packed_67);
let nibble4 = ctx.and_u32(packed_45_u32, mask_0f);
let nibble5 = ctx.shr_u32(packed_45_u32, four);
let nibble5 = ctx.and_u32(nibble5, mask_0f);
let nibble6 = ctx.and_u32(packed_67_u32, mask_0f);
let nibble7 = ctx.shr_u32(packed_67_u32, four);
let nibble7 = ctx.and_u32(nibble7, mask_0f);
let nibble5_shifted = ctx.shl_u32(nibble5, eight);
let nibble6_shifted = ctx.shl_u32(nibble6, sixteen);
let nibble7_shifted = ctx.shl_u32(nibble7, twenty_four);
let w_packed_45 = ctx.or_u32(nibble4, nibble5_shifted);
let w_packed_67 = ctx.or_u32(nibble6_shifted, nibble7_shifted);
let w_packed_4567 = ctx.or_u32(w_packed_45, w_packed_67);
ctx.dp4a_u32_s32_inplace(dp4a_acc, w_packed_4567, q8_packed_4567);
// Activation-side f16 scale `d` of Q8 block base+0 (byte offset 32).
// NOTE(review): only block 0's scale is read although all 8 blocks
// contribute products — confirm the quantizer shares one scale across
// the 8 blocks of a super-block.
let thirty_two_64_q8 = ctx.mov_u64_imm(32);
let q8_d_addr = ctx.add_u64(q8_addr0, thirty_two_64_q8);
let q8_d_f16 = ctx.ld_global_f16(q8_d_addr);
let q8_d = ctx.cvt_f32_f16(q8_d_f16);
// Scale the integer partial sum and fold into the f32 accumulator.
let dp4a_acc_f = ctx.cvt_f32_s32(dp4a_acc);
let combined_scale = ctx.mul_f32(ds, q8_d);
let scaled_result = ctx.mul_f32(dp4a_acc_f, combined_scale);
ctx.add_f32_inplace(float_acc, scaled_result);
// NOTE(review): dp4a_acc is already re-zeroed by the `mov` emitted at
// the top of each loop iteration, so this reset looks redundant —
// harmless, but worth confirming the builder's register semantics.
ctx.mov_u32_inplace(dp4a_acc, 0);
ctx.add_u32_inplace(sb_idx, 1);
ctx.branch("sb_loop");
ctx.label("sb_loop_end");
// Warp reduction: shfl.down with full-warp mask, strides 16/8/4/2/1,
// leaves the warp total in lane 0.
let tmp16 = ctx.shfl_down_f32(float_acc, 16, 0xFFFF_FFFF);
ctx.add_f32_inplace(float_acc, tmp16);
let tmp8 = ctx.shfl_down_f32(float_acc, 8, 0xFFFF_FFFF);
ctx.add_f32_inplace(float_acc, tmp8);
let tmp4 = ctx.shfl_down_f32(float_acc, 4, 0xFFFF_FFFF);
ctx.add_f32_inplace(float_acc, tmp4);
let tmp2 = ctx.shfl_down_f32(float_acc, 2, 0xFFFF_FFFF);
ctx.add_f32_inplace(float_acc, tmp2);
let tmp1 = ctx.shfl_down_f32(float_acc, 1, 0xFFFF_FFFF);
ctx.add_f32_inplace(float_acc, tmp1);
// Only lane 0 writes the result: y[row], 4 bytes per f32 element.
let one = ctx.mov_u32_imm(1);
let is_lane0 = ctx.setp_lt_u32(lane_id, one);
ctx.branch_if_not(is_lane0, "exit");
let y_offset = ctx.mul_wide_u32(block_id, 4);
let y_addr = ctx.add_u64(y_ptr, y_offset);
ctx.st_global_f32(y_addr, float_acc);
ctx.label("exit");
ctx.ret();
})
}
}