#![allow(clippy::similar_names)]
#![allow(clippy::too_many_lines)]
use super::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
mod dot;
mod fp16_tensor;
mod fused;
mod legacy;
mod q4k;
mod q5k;
mod q6k;
pub use dot::{PackedDp4aQ4KQ8Kernel, Q4KQ8DotKernel};
pub use fp16_tensor::{Fp16Q4KGemvKernel, TensorCoreQ4KGemmKernel};
pub use fused::{FusedGateUpQ4KGemvKernel, FusedRmsNormQ4KGemvKernel};
pub use legacy::{Q4_0GemvKernel, Q4_1GemvKernel, Q5_0GemvKernel, Q8_0GemvKernel};
pub use q4k::{
BatchedQ4KGemvKernel, ChunkedTiledQ4KGemvKernel, CoalescedQ4KGemvKernel, Dp4aQ4KGemvKernel,
Q4KGemvKernel, TiledQ4KGemvKernel, TrueDp4aQ4KGemvKernel, VectorizedQ4KGemvKernel,
};
pub use q5k::{Q5KGemvKernel, Q5KKernel};
pub use q6k::{BatchedQ6KGemvKernel, CoalescedQ6KGemvKernel, Q6KGemvKernel, Q6KKernel};
// Simplified Q4 layout: 32 weights per block.
const Q4K_BLOCK_SIZE: u32 = 32;
// ggml-style K-quant super-block: 256 weights per super-block.
pub(crate) const Q4K_SUPER_BLOCK_SIZE: u32 = 256;
// Q4_K super-block bytes: d f16 (2) + dmin f16 (2) + scales (12) + packed nibbles (128) = 144.
// (Offsets 0 / 2 / 4 / 16 are what `build_fused_gemm_ggml` reads below.)
pub(crate) const Q4K_SUPER_BLOCK_BYTES: u32 = 144;
// Simplified Q4 block bytes: f16 scale (2) + 16 bytes of packed 4-bit quants = 18.
const Q4K_BLOCK_BYTES: u32 = 18;
pub(crate) const Q5K_SUPER_BLOCK_SIZE: u32 = 256;
// Q5_K super-block bytes — presumably d/dmin + scales + high bits + nibbles; defined for q5k module.
pub(crate) const Q5K_SUPER_BLOCK_BYTES: u32 = 176;
pub(crate) const Q6K_SUPER_BLOCK_SIZE: u32 = 256;
// Q6_K super-block bytes — used by the q6k module.
pub(crate) const Q6K_SUPER_BLOCK_BYTES: u32 = 210;
// Q8_0: 32 weights per block, f16 scale (2) + 32 int8 quants = 34 bytes.
const Q8_0_BLOCK_SIZE: u32 = 32;
const Q8_0_BLOCK_BYTES: u32 = 34;
// Q5_0: 32 weights per block, f16 scale (2) + 4 bytes of high bits + 16 bytes of nibbles = 22.
const Q5_0_BLOCK_SIZE: u32 = 32;
const Q5_0_BLOCK_BYTES: u32 = 22;
/// Storage layout of the quantized weight matrix consumed by [`QuantizeKernel`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Q4KFormat {
    /// 32-element blocks: one f16 scale + 16 bytes of packed 4-bit quants (18 bytes/block).
    Simplified,
    /// ggml-style 256-element super-blocks with d/dmin and per-sub-block scales (144 bytes each).
    GgmlSuperBlock,
}
/// Fused dequantize + GEMM kernel: multiplies an f32 activation matrix `A`
/// (m x k) by a Q4-quantized weight matrix `B`, writing f32 results to `C`
/// (m x n). The `B` indexing in `build_ptx` addresses one quantized row per
/// output column, i.e. `B` is stored as n rows of k quantized values.
#[derive(Debug, Clone)]
pub struct QuantizeKernel {
    pub m: u32,          // output rows (rows of A and C)
    pub n: u32,          // output columns (quantized rows of B)
    pub k: u32,          // reduction dimension; assumed a multiple of the block size
    pub tile_size: u32,  // square output-tile edge; also sets threads per CTA (tile_size^2)
    pub block_size: u32, // quantization block length (32 simplified, 256 ggml)
    pub format: Q4KFormat,
}
impl QuantizeKernel {
    /// Kernel over the simplified layout (32-element blocks, 18 bytes each).
    ///
    /// Made `const` for consistency with the rest of this impl, which is
    /// already entirely `const fn`.
    #[must_use]
    pub const fn new(m: u32, n: u32, k: u32) -> Self {
        Self {
            m,
            n,
            k,
            tile_size: 32,
            block_size: Q4K_BLOCK_SIZE,
            format: Q4KFormat::Simplified,
        }
    }

    /// Kernel over the ggml super-block layout (256-element super-blocks).
    #[must_use]
    pub const fn ggml(m: u32, n: u32, k: u32) -> Self {
        Self {
            m,
            n,
            k,
            tile_size: 32,
            block_size: Q4K_SUPER_BLOCK_SIZE,
            format: Q4KFormat::GgmlSuperBlock,
        }
    }

    /// Builder-style override of the square output-tile edge (default 32).
    #[must_use]
    pub const fn with_tile_size(mut self, tile_size: u32) -> Self {
        self.tile_size = tile_size;
        self
    }

    /// Quantized blocks per weight row. Truncating division: assumes `k` is
    /// a multiple of `block_size`, as the generated kernels do.
    #[must_use]
    pub const fn num_blocks_per_row(&self) -> u32 {
        self.k / self.block_size
    }

    /// Super-blocks per weight row for the ggml layout (truncating; assumes
    /// `k` is a multiple of 256).
    #[must_use]
    pub const fn num_super_blocks_per_row(&self) -> u32 {
        self.k / Q4K_SUPER_BLOCK_SIZE
    }
}
impl Kernel for QuantizeKernel {
    /// Entry-point name; must match the name handed to `PtxKernel::new`
    /// inside the corresponding builder method.
    fn name(&self) -> &str {
        if self.format == Q4KFormat::Simplified {
            "q4k_gemm_fused"
        } else {
            "q4k_gemm_ggml"
        }
    }

    /// Dispatches PTX generation according to the configured weight layout.
    fn build_ptx(&self) -> PtxKernel {
        if self.format == Q4KFormat::Simplified {
            self.build_fused_gemm_simplified()
        } else {
            self.build_fused_gemm_ggml()
        }
    }
}
impl QuantizeKernel {
    /// Emits PTX for the simplified Q4 layout: each weight block holds 32
    /// elements as a 2-byte f16 scale followed by 16 bytes of packed 4-bit
    /// quants (`Q4K_BLOCK_BYTES` = 18). Dequantization is `scale * nibble`
    /// with no zero-point/min term.
    ///
    /// Grid mapping as built below: `ctaid.y` tiles output rows, `ctaid.x`
    /// tiles output columns; each thread walks every k-block of its
    /// (row, col) pair and accumulates a warp-shuffled partial sum.
    fn build_fused_gemm_simplified(&self) -> PtxKernel {
        let tile_size = self.tile_size;
        let block_size = self.block_size;
        // f32 scratch for a tile_size x tile_size tile.
        // NOTE(review): shared memory is reserved here but never read or
        // written in this body — confirm the reservation is still needed.
        let smem_size = tile_size * tile_size * 4;
        PtxKernel::new("q4k_gemm_fused")
            .param(PtxType::U64, "a_ptr")
            .param(PtxType::U64, "b_quant_ptr")
            .param(PtxType::U64, "c_ptr")
            .param(PtxType::U32, "m")
            .param(PtxType::U32, "n")
            .param(PtxType::U32, "k")
            .shared_memory(smem_size as usize)
            .build(|ctx| {
                // --- Thread / tile index derivation ---
                let tid = ctx.special_reg(PtxReg::TidX);
                let ctaid_x = ctx.special_reg(PtxReg::CtaIdX);
                let ctaid_y = ctx.special_reg(PtxReg::CtaIdY);
                let m_param = ctx.load_param_u32("m");
                let n_param = ctx.load_param_u32("n");
                let k_param = ctx.load_param_u32("k");
                let a_ptr = ctx.load_param_u64("a_ptr");
                let b_quant_ptr = ctx.load_param_u64("b_quant_ptr");
                let c_ptr = ctx.load_param_u64("c_ptr");
                let tile_size_reg = ctx.mov_u32_imm(tile_size);
                let out_row = ctx.mul_u32_reg(ctaid_y, tile_size_reg);
                let out_col = ctx.mul_u32_reg(ctaid_x, tile_size_reg);
                // Linear tid -> (row, col) within the tile.
                let local_row = ctx.div_u32(tid, tile_size);
                let local_col = ctx.rem_u32(tid, tile_size);
                let global_row = ctx.add_u32_reg(out_row, local_row);
                let global_col = ctx.add_u32_reg(out_col, local_col);
                // Out-of-bounds threads still run the loop (keeping the warp
                // converged for the shuffles) on clamped indices; the final
                // store is predicated on these flags instead.
                let row_oob = ctx.setp_ge_u32(global_row, m_param);
                let col_oob = ctx.setp_ge_u32(global_col, n_param);
                let one = ctx.mov_u32_imm(1);
                let m_minus_1 = ctx.sub_u32_reg(m_param, one);
                let n_minus_1 = ctx.sub_u32_reg(n_param, one);
                let clamped_row = ctx.min_u32(global_row, m_minus_1);
                let clamped_col = ctx.min_u32(global_col, n_minus_1);
                let acc = ctx.mov_f32_imm(0.0);
                let block_size_reg = ctx.mov_u32_imm(block_size);
                // Truncating: assumes k is a multiple of block_size.
                let num_k_blocks = ctx.div_u32(k_param, block_size);
                let k_block = ctx.mov_u32_imm(0);
                // --- Loop over quantized blocks along k ---
                ctx.label("k_block_loop");
                let k_done = ctx.setp_ge_u32(k_block, num_k_blocks);
                ctx.branch_if(k_done, "k_block_done");
                // Locate this column's block: B row = output column.
                let blocks_per_row = num_k_blocks;
                let block_bytes = ctx.mov_u32_imm(Q4K_BLOCK_BYTES);
                let row_offset = ctx.mul_u32_reg(clamped_col, blocks_per_row);
                let block_offset = ctx.add_u32_reg(row_offset, k_block);
                let byte_offset = ctx.mul_wide_u32_reg(block_offset, block_bytes);
                let block_addr = ctx.add_u64(b_quant_ptr, byte_offset);
                // f16 scale sits at byte 0 of the block.
                let scale_addr = block_addr;
                let scale_f16 = ctx.ld_global_f16(scale_addr);
                let scale = ctx.cvt_f32_f16(scale_f16);
                // Each lane extracts one nibble: byte = lane/2, low nibble
                // for even lanes, high nibble for odd lanes.
                let lane = ctx.rem_u32(tid, block_size);
                let byte_idx = ctx.div_u32(lane, 2);
                let nibble_idx = ctx.rem_u32(lane, 2);
                let header_size = ctx.mov_u64_imm(2);
                let data_addr = ctx.add_u64(block_addr, header_size);
                let byte_idx_64 = ctx.cvt_u64_u32(byte_idx);
                let packed_addr = ctx.add_u64(data_addr, byte_idx_64);
                let packed = ctx.ld_global_u8(packed_addr);
                let four = ctx.mov_u32_imm(4);
                let shift = ctx.mul_u32_reg(nibble_idx, four);
                let packed_32 = ctx.cvt_u32_u8(packed);
                let fifteen = ctx.mov_u32_imm(0xF);
                let shifted = ctx.shr_u32(packed_32, shift);
                let quant = ctx.and_u32(shifted, fifteen);
                let quant_f32 = ctx.cvt_f32_u32(quant);
                let dequant = ctx.mul_f32(scale, quant_f32);
                // Matching A element: A[clamped_row][k_block * 32 + lane].
                let k_offset_base = ctx.mul_u32_reg(k_block, block_size_reg);
                let k_offset = ctx.add_u32_reg(k_offset_base, lane);
                let a_row_offset = ctx.mul_wide_u32_reg(clamped_row, k_param);
                let k_offset_64 = ctx.cvt_u64_u32(k_offset);
                let a_elem_offset = ctx.add_u64(a_row_offset, k_offset_64);
                let a_elem_offset_bytes = ctx.mul_u64(a_elem_offset, 4);
                let a_addr = ctx.add_u64(a_ptr, a_elem_offset_bytes);
                let a_val = ctx.ld_global_f32(a_addr);
                let prod = ctx.mul_f32(a_val, dequant);
                // Warp tree reduction of prod via shfl.down (16,8,4,2,1),
                // then broadcast lane 0's total to all lanes.
                // NOTE(review): lanes of a warp carry *different* global_col
                // values here, yet the reduced sum is broadcast into every
                // lane's acc — this appears to mix columns; confirm the
                // intended semantics against the kernel's tests.
                let shuffled_16 = ctx.shfl_down_f32(prod, 16, 0xFFFF_FFFF);
                let prod_1 = ctx.add_f32(prod, shuffled_16);
                let shuffled_8 = ctx.shfl_down_f32(prod_1, 8, 0xFFFF_FFFF);
                let prod_2 = ctx.add_f32(prod_1, shuffled_8);
                let shuffled_4 = ctx.shfl_down_f32(prod_2, 4, 0xFFFF_FFFF);
                let prod_3 = ctx.add_f32(prod_2, shuffled_4);
                let shuffled_2 = ctx.shfl_down_f32(prod_3, 2, 0xFFFF_FFFF);
                let prod_4 = ctx.add_f32(prod_3, shuffled_2);
                let shuffled_1 = ctx.shfl_down_f32(prod_4, 1, 0xFFFF_FFFF);
                let block_sum = ctx.add_f32(prod_4, shuffled_1);
                let broadcast_sum = ctx.shfl_idx_f32(block_sum, 0, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, broadcast_sum);
                ctx.add_u32_inplace(k_block, 1);
                ctx.branch("k_block_loop");
                ctx.label("k_block_done");
                // --- Guarded store of the accumulated result ---
                ctx.branch_if(row_oob, "exit");
                ctx.branch_if(col_oob, "exit");
                let c_row_offset = ctx.mul_wide_u32_reg(global_row, n_param);
                let global_col_64 = ctx.cvt_u64_u32(global_col);
                let c_elem_offset = ctx.add_u64(c_row_offset, global_col_64);
                let c_elem_offset_bytes = ctx.mul_u64(c_elem_offset, 4);
                let c_addr = ctx.add_u64(c_ptr, c_elem_offset_bytes);
                ctx.st_global_f32(c_addr, acc);
                ctx.label("exit");
                ctx.ret();
            })
    }

    /// Emits PTX for the ggml Q4_K super-block layout (144 bytes per 256
    /// weights): f16 `d` at +0, f16 `dmin` at +2, scale bytes at +4..16,
    /// packed nibbles at +16. Dequantization per element is
    /// `d * scale - dmin * min`, with 6-bit scale/min values normalized by
    /// 1/63 below.
    ///
    /// NOTE(review): the scale extraction reads a linear 12-bit
    /// (scale, min) pair per sub-block, and the 1/63 normalization is not
    /// part of upstream ggml's `get_scale_min_k4` packing — confirm this
    /// matches the packer used by this project.
    fn build_fused_gemm_ggml(&self) -> PtxKernel {
        let tile_size = self.tile_size;
        // NOTE(review): reserved but never accessed in this body, same as
        // the simplified variant — confirm it is still needed.
        let smem_size = Q4K_SUPER_BLOCK_SIZE * 4;
        PtxKernel::new("q4k_gemm_ggml")
            .param(PtxType::U64, "a_ptr")
            .param(PtxType::U64, "b_quant_ptr")
            .param(PtxType::U64, "c_ptr")
            .param(PtxType::U32, "m")
            .param(PtxType::U32, "n")
            .param(PtxType::U32, "k")
            .shared_memory(smem_size as usize)
            .build(|ctx| {
                // --- Thread / tile index derivation (same scheme as the
                // simplified kernel) ---
                let tid = ctx.special_reg(PtxReg::TidX);
                let ctaid_x = ctx.special_reg(PtxReg::CtaIdX);
                let ctaid_y = ctx.special_reg(PtxReg::CtaIdY);
                let m_param = ctx.load_param_u32("m");
                let n_param = ctx.load_param_u32("n");
                let k_param = ctx.load_param_u32("k");
                let a_ptr = ctx.load_param_u64("a_ptr");
                let b_quant_ptr = ctx.load_param_u64("b_quant_ptr");
                let c_ptr = ctx.load_param_u64("c_ptr");
                let tile_size_reg = ctx.mov_u32_imm(tile_size);
                let out_row = ctx.mul_u32_reg(ctaid_y, tile_size_reg);
                let out_col = ctx.mul_u32_reg(ctaid_x, tile_size_reg);
                let local_row = ctx.div_u32(tid, tile_size);
                let local_col = ctx.rem_u32(tid, tile_size);
                let global_row = ctx.add_u32_reg(out_row, local_row);
                let global_col = ctx.add_u32_reg(out_col, local_col);
                // OOB threads loop on clamped indices; store is guarded.
                let row_oob = ctx.setp_ge_u32(global_row, m_param);
                let col_oob = ctx.setp_ge_u32(global_col, n_param);
                let one = ctx.mov_u32_imm(1);
                let m_minus_1 = ctx.sub_u32_reg(m_param, one);
                let n_minus_1 = ctx.sub_u32_reg(n_param, one);
                let clamped_row = ctx.min_u32(global_row, m_minus_1);
                let clamped_col = ctx.min_u32(global_col, n_minus_1);
                let acc = ctx.mov_f32_imm(0.0);
                // Truncating: assumes k is a multiple of 256.
                let num_k_super_blocks = ctx.div_u32(k_param, Q4K_SUPER_BLOCK_SIZE);
                let sb_idx = ctx.mov_u32_imm(0);
                // --- Outer loop over super-blocks along k ---
                ctx.label("sb_loop");
                let sb_done = ctx.setp_ge_u32(sb_idx, num_k_super_blocks);
                ctx.branch_if(sb_done, "sb_loop_done");
                // Locate this column's super-block (B row = output column).
                let sb_per_row = num_k_super_blocks;
                let row_sb_offset = ctx.mul_u32_reg(clamped_col, sb_per_row);
                let total_sb_offset = ctx.add_u32_reg(row_sb_offset, sb_idx);
                let byte_offset = ctx.mul_wide_u32(total_sb_offset, Q4K_SUPER_BLOCK_BYTES);
                let sb_addr = ctx.add_u64(b_quant_ptr, byte_offset);
                // Super-block header: d at +0, dmin at +2 (both f16).
                let d_f16 = ctx.ld_global_f16(sb_addr);
                let d = ctx.cvt_f32_f16(d_f16);
                let two = ctx.mov_u64_imm(2);
                let dmin_addr = ctx.add_u64(sb_addr, two);
                let dmin_f16 = ctx.ld_global_f16(dmin_addr);
                let dmin = ctx.cvt_f32_f16(dmin_f16);
                let sub_block_idx = ctx.mov_u32_imm(0);
                let eight = ctx.mov_u32_imm(8);
                let thirty_two = ctx.mov_u32_imm(32);
                // --- Inner loop over the 8 sub-blocks of 32 elements ---
                ctx.label("sub_block_loop");
                let sub_done = ctx.setp_ge_u32(sub_block_idx, eight);
                ctx.branch_if(sub_done, "sub_block_done");
                // 12 bits per sub-block in the scales area at +4: read the
                // two bytes straddling the bit offset and shift into place.
                let bit_offset = ctx.mul_u32(sub_block_idx, 12);
                let byte_idx = ctx.div_u32(bit_offset, 8);
                let bit_in_byte = ctx.rem_u32(bit_offset, 8);
                let four = ctx.mov_u64_imm(4);
                let scales_base = ctx.add_u64(sb_addr, four);
                let byte_idx_64 = ctx.cvt_u64_u32(byte_idx);
                let scales_addr = ctx.add_u64(scales_base, byte_idx_64);
                let scale_b0 = ctx.ld_global_u8(scales_addr);
                let one_64 = ctx.mov_u64_imm(1);
                let scales_addr1 = ctx.add_u64(scales_addr, one_64);
                let scale_b1 = ctx.ld_global_u8(scales_addr1);
                let b0_32 = ctx.cvt_u32_u8(scale_b0);
                let b1_32 = ctx.cvt_u32_u8(scale_b1);
                let eight_shift = ctx.mov_u32_imm(8);
                let b1_shifted = ctx.shl_u32(b1_32, eight_shift);
                let combined = ctx.or_u32(b0_32, b1_shifted);
                let bits_12 = ctx.shr_u32(combined, bit_in_byte);
                // Low 6 bits = scale, next 6 bits = min.
                let mask_6bit = ctx.mov_u32_imm(0x3F);
                let scale_6bit = ctx.and_u32(bits_12, mask_6bit);
                let six_shift = ctx.mov_u32_imm(6);
                let min_shifted = ctx.shr_u32(bits_12, six_shift);
                let min_6bit = ctx.and_u32(min_shifted, mask_6bit);
                let scale_f32 = ctx.cvt_f32_u32(scale_6bit);
                let min_f32 = ctx.cvt_f32_u32(min_6bit);
                // Normalize the 6-bit values into [0, 1].
                let inv_63 = ctx.mov_f32_imm(1.0 / 63.0);
                let scale_norm = ctx.mul_f32(scale_f32, inv_63);
                let min_norm = ctx.mul_f32(min_f32, inv_63);
                // Each lane extracts one nibble of this sub-block's 16
                // packed bytes (qs area at +16).
                let lane = ctx.rem_u32(tid, 32);
                let sixteen = ctx.mov_u64_imm(16);
                let qs_base = ctx.add_u64(sb_addr, sixteen);
                let sub_block_offset = ctx.mul_u32(sub_block_idx, 16);
                let sub_block_offset_64 = ctx.cvt_u64_u32(sub_block_offset);
                let qs_sub_base = ctx.add_u64(qs_base, sub_block_offset_64);
                let byte_in_sub = ctx.div_u32(lane, 2);
                let nibble_idx = ctx.rem_u32(lane, 2);
                let byte_in_sub_64 = ctx.cvt_u64_u32(byte_in_sub);
                let qs_addr = ctx.add_u64(qs_sub_base, byte_in_sub_64);
                let packed = ctx.ld_global_u8(qs_addr);
                let shift_amt = ctx.mul_u32(nibble_idx, 4);
                let packed_32 = ctx.cvt_u32_u8(packed);
                let shifted = ctx.shr_u32(packed_32, shift_amt);
                let mask_4bit = ctx.mov_u32_imm(0xF);
                let quant = ctx.and_u32(shifted, mask_4bit);
                let quant_f32 = ctx.cvt_f32_u32(quant);
                // dequant = d * scale * q - dmin * min.
                let d_scale = ctx.mul_f32(d, scale_norm);
                let scaled = ctx.mul_f32(d_scale, quant_f32);
                let dmin_min = ctx.mul_f32(dmin, min_norm);
                let dequant = ctx.sub_f32(scaled, dmin_min);
                // Matching A element: A[row][sb*256 + sub*32 + lane].
                let two_fifty_six = ctx.mov_u32_imm(256);
                let sb_k_offset = ctx.mul_u32_reg(sb_idx, two_fifty_six);
                let sub_k_offset = ctx.mul_u32_reg(sub_block_idx, thirty_two);
                let k_offset = ctx.add_u32_reg(sb_k_offset, sub_k_offset);
                let k_offset_full = ctx.add_u32_reg(k_offset, lane);
                let a_row_offset = ctx.mul_wide_u32_reg(clamped_row, k_param);
                let k_offset_64 = ctx.cvt_u64_u32(k_offset_full);
                let a_elem_offset = ctx.add_u64(a_row_offset, k_offset_64);
                let a_elem_bytes = ctx.mul_u64(a_elem_offset, 4);
                let a_addr = ctx.add_u64(a_ptr, a_elem_bytes);
                let a_val = ctx.ld_global_f32(a_addr);
                let prod = ctx.mul_f32(a_val, dequant);
                // Warp tree reduction + broadcast, as in the simplified
                // kernel. NOTE(review): same column-mixing concern — lanes
                // hold different global_col values yet share the reduced sum.
                let shuffled_16 = ctx.shfl_down_f32(prod, 16, 0xFFFF_FFFF);
                let prod_1 = ctx.add_f32(prod, shuffled_16);
                let shuffled_8 = ctx.shfl_down_f32(prod_1, 8, 0xFFFF_FFFF);
                let prod_2 = ctx.add_f32(prod_1, shuffled_8);
                let shuffled_4 = ctx.shfl_down_f32(prod_2, 4, 0xFFFF_FFFF);
                let prod_3 = ctx.add_f32(prod_2, shuffled_4);
                let shuffled_2 = ctx.shfl_down_f32(prod_3, 2, 0xFFFF_FFFF);
                let prod_4 = ctx.add_f32(prod_3, shuffled_2);
                let shuffled_1 = ctx.shfl_down_f32(prod_4, 1, 0xFFFF_FFFF);
                let sub_block_sum = ctx.add_f32(prod_4, shuffled_1);
                let broadcast_sum = ctx.shfl_idx_f32(sub_block_sum, 0, 0xFFFF_FFFF);
                ctx.add_f32_inplace(acc, broadcast_sum);
                ctx.add_u32_inplace(sub_block_idx, 1);
                ctx.branch("sub_block_loop");
                ctx.label("sub_block_done");
                ctx.add_u32_inplace(sb_idx, 1);
                ctx.branch("sb_loop");
                ctx.label("sb_loop_done");
                // --- Guarded store of the accumulated result ---
                ctx.branch_if(row_oob, "exit");
                ctx.branch_if(col_oob, "exit");
                let c_row_offset = ctx.mul_wide_u32_reg(global_row, n_param);
                let global_col_64 = ctx.cvt_u64_u32(global_col);
                let c_elem_offset = ctx.add_u64(c_row_offset, global_col_64);
                let c_elem_bytes = ctx.mul_u64(c_elem_offset, 4);
                let c_addr = ctx.add_u64(c_ptr, c_elem_bytes);
                ctx.st_global_f32(c_addr, acc);
                ctx.label("exit");
                ctx.ret();
            })
    }
}
/// Runtime quantization of an f32 vector into 32-element Q8 blocks, one
/// warp per block (see the `Kernel` impl for the emitted PTX).
#[derive(Debug, Clone)]
pub struct Q8QuantizeKernel {
    /// Number of f32 elements to quantize.
    pub n: u32,
}

impl Q8QuantizeKernel {
    #[must_use]
    pub const fn new(n: u32) -> Self {
        Self { n }
    }

    /// Number of 32-element blocks needed to cover `n` (ceiling division).
    ///
    /// Uses `u32::div_ceil` instead of the manual `(n + 31) / 32`, which
    /// overflowed for `n` within 31 of `u32::MAX` (clippy: manual_div_ceil).
    #[must_use]
    pub const fn num_blocks(&self) -> u32 {
        self.n.div_ceil(32)
    }
}
impl Kernel for Q8QuantizeKernel {
    fn name(&self) -> &str {
        "q8_quantize"
    }

    /// One warp (32 lanes) quantizes one 32-element block:
    /// 1. each lane loads one f32 and the warp max-reduces |x| via shuffles;
    /// 2. scale = max|x| / 127; each lane stores its rounded value clamped
    ///    to [-127, 127] as one byte;
    /// 3. the warp sum-reduces the raw values and lane 0 appends the f16
    ///    scale `d` and f16 sum `s`.
    ///
    /// Output stride is 36 bytes per block: qs[32] at +0, d at +32, s at
    /// +34. NOTE(review): that is 36, not `Q8_0_BLOCK_BYTES` (34) —
    /// presumably a Q8_1-style layout carrying the extra sum; confirm the
    /// consuming kernels use the same 36-byte stride.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("q8_quantize")
            .param(PtxType::U64, "out_ptr")
            .param(PtxType::U64, "in_ptr")
            .param(PtxType::U32, "n_dim")
            .build(|ctx| {
                let block_id = ctx.special_reg(PtxReg::CtaIdX);
                let thread_id = ctx.special_reg(PtxReg::TidX);
                let lane_id = ctx.rem_u32(thread_id, 32);
                let n_dim = ctx.load_param_u32("n_dim");
                // ceil(n / 32): one CTA per block; excess CTAs exit as a
                // whole, so no intra-warp divergence is introduced here.
                let num_blocks = ctx.add_u32(n_dim, 31);
                let num_blocks = ctx.div_u32(num_blocks, 32);
                let oob = ctx.setp_ge_u32(block_id, num_blocks);
                ctx.branch_if(oob, "exit");
                let out_ptr = ctx.load_param_u64("out_ptr");
                let in_ptr = ctx.load_param_u64("in_ptr");
                let block_start = ctx.mul_u32(block_id, 32);
                let idx = ctx.add_u32_reg(block_start, lane_id);
                // NOTE(review): in the final block, lanes with idx >= n_dim
                // read past the end of the input when n is not a multiple
                // of 32 — confirm callers pad the input buffer.
                let idx_64 = ctx.cvt_u64_u32(idx);
                let idx_bytes = ctx.mul_u64(idx_64, 4);
                let in_addr = ctx.add_u64(in_ptr, idx_bytes);
                let val = ctx.ld_global_f32(in_addr);
                // Warp max-reduction of |val| (all 32 lanes active).
                let abs_val = ctx.abs_f32(val);
                let max_abs = abs_val;
                let tmp16 = ctx.shfl_down_f32(max_abs, 16, 0xFFFF_FFFF);
                let max_abs = ctx.max_f32(max_abs, tmp16);
                let tmp8 = ctx.shfl_down_f32(max_abs, 8, 0xFFFF_FFFF);
                let max_abs = ctx.max_f32(max_abs, tmp8);
                let tmp4 = ctx.shfl_down_f32(max_abs, 4, 0xFFFF_FFFF);
                let max_abs = ctx.max_f32(max_abs, tmp4);
                let tmp2 = ctx.shfl_down_f32(max_abs, 2, 0xFFFF_FFFF);
                let max_abs = ctx.max_f32(max_abs, tmp2);
                let tmp1 = ctx.shfl_down_f32(max_abs, 1, 0xFFFF_FFFF);
                let max_abs = ctx.max_f32(max_abs, tmp1);
                // Broadcast lane 0's maximum to the whole warp.
                let max_abs = ctx.shfl_idx_f32(max_abs, 0, 0xFFFF_FFFF);
                let inv_127 = ctx.mov_f32_imm(1.0 / 127.0);
                let scale = ctx.mul_f32(max_abs, inv_127);
                // Epsilon guards the reciprocal on an all-zero block.
                let eps = ctx.mov_f32_imm(1e-10);
                let scale_eps = ctx.add_f32(scale, eps);
                let inv_scale = ctx.rcp_f32(scale_eps);
                let scaled = ctx.mul_f32(val, inv_scale);
                let rounded = ctx.cvt_rni_s32_f32(scaled);
                // Clamp to [-127, 127]; 0xFFFF_FF81 is -127 in two's
                // complement.
                let min_val = ctx.mov_u32_imm(0xFFFF_FF81);
                let min_s32 = ctx.mov_s32_from_u32(min_val);
                let max_val = ctx.mov_s32_imm(127);
                let clamped = ctx.max_s32(rounded, min_s32);
                let clamped = ctx.min_s32(clamped, max_val);
                let q8_val = ctx.cvt_u8_s32(clamped);
                // 36-byte output blocks: qs at +0, d at +32, s at +34.
                let block_bytes = ctx.mov_u32_imm(36);
                let block_offset = ctx.mul_wide_u32_reg(block_id, block_bytes);
                let block_base = ctx.add_u64(out_ptr, block_offset);
                let lane_64 = ctx.cvt_u64_u32(lane_id);
                let qs_addr = ctx.add_u64(block_base, lane_64);
                ctx.st_global_u8(qs_addr, q8_val);
                // BUGFIX: perform the warp sum-reduction while all 32 lanes
                // are still active. It previously ran after the lane-0
                // branch below, so lanes 1-31 never executed these
                // shfl.sync instructions even though the member mask names
                // the full warp — undefined behavior per the PTX ISA, and
                // the reduced sum was garbage.
                let sum = val;
                let sum_tmp16 = ctx.shfl_down_f32(sum, 16, 0xFFFF_FFFF);
                let sum = ctx.add_f32(sum, sum_tmp16);
                let sum_tmp8 = ctx.shfl_down_f32(sum, 8, 0xFFFF_FFFF);
                let sum = ctx.add_f32(sum, sum_tmp8);
                let sum_tmp4 = ctx.shfl_down_f32(sum, 4, 0xFFFF_FFFF);
                let sum = ctx.add_f32(sum, sum_tmp4);
                let sum_tmp2 = ctx.shfl_down_f32(sum, 2, 0xFFFF_FFFF);
                let sum = ctx.add_f32(sum, sum_tmp2);
                let sum_tmp1 = ctx.shfl_down_f32(sum, 1, 0xFFFF_FFFF);
                let sum = ctx.add_f32(sum, sum_tmp1);
                // Only lane 0 writes the per-block header fields.
                let one = ctx.mov_u32_imm(1);
                let is_lane0 = ctx.setp_lt_u32(lane_id, one);
                ctx.branch_if_not(is_lane0, "exit");
                let thirty_two_64 = ctx.mov_u64_imm(32);
                let d_addr = ctx.add_u64(block_base, thirty_two_64);
                let scale_f16 = ctx.cvt_f16_f32(scale);
                ctx.st_global_f16(d_addr, scale_f16);
                let thirty_four_64 = ctx.mov_u64_imm(34);
                let s_addr = ctx.add_u64(block_base, thirty_four_64);
                let sum_f16 = ctx.cvt_f16_f32(sum);
                ctx.st_global_f16(s_addr, sum_f16);
                ctx.label("exit");
                ctx.ret();
            })
    }
}
#[cfg(test)]
mod tests;