use super::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Parameters for the Q8 block-quantization kernel.
///
/// The kernel quantizes `n` f32 values into 32-element blocks, each
/// emitted as 32 i8 quants followed by two f16 values (scale and sum).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Q8QuantizeKernel {
    /// Number of f32 input elements to quantize.
    pub n: u32,
}
impl Q8QuantizeKernel {
    /// Creates a kernel configuration for quantizing `n` f32 values.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        Self { n }
    }

    /// Number of CUDA blocks to launch: one 32-element quantization block
    /// per warp, rounded up.
    ///
    /// Uses `div_ceil` rather than `(n + 31) / 32`, which would overflow
    /// (panic in debug, wrap to 0 blocks in release) for `n > u32::MAX - 31`.
    #[must_use]
    pub const fn num_blocks(&self) -> u32 {
        self.n.div_ceil(32)
    }
}
impl Kernel for Q8QuantizeKernel {
fn name(&self) -> &str {
"q8_quantize"
}
fn build_ptx(&self) -> PtxKernel {
PtxKernel::new("q8_quantize")
.param(PtxType::U64, "out_ptr") .param(PtxType::U64, "in_ptr") .param(PtxType::U32, "n_dim")
.build(|ctx| {
let block_id = ctx.special_reg(PtxReg::CtaIdX);
let thread_id = ctx.special_reg(PtxReg::TidX);
let lane_id = ctx.rem_u32(thread_id, 32);
let n_dim = ctx.load_param_u32("n_dim");
let num_blocks = ctx.add_u32(n_dim, 31);
let num_blocks = ctx.div_u32(num_blocks, 32);
let oob = ctx.setp_ge_u32(block_id, num_blocks);
ctx.branch_if(oob, "exit");
let out_ptr = ctx.load_param_u64("out_ptr");
let in_ptr = ctx.load_param_u64("in_ptr");
let block_start = ctx.mul_u32(block_id, 32);
let idx = ctx.add_u32_reg(block_start, lane_id);
let idx_64 = ctx.cvt_u64_u32(idx);
let idx_bytes = ctx.mul_u64(idx_64, 4);
let in_addr = ctx.add_u64(in_ptr, idx_bytes);
let val = ctx.ld_global_f32(in_addr);
let abs_val = ctx.abs_f32(val);
let max_abs = abs_val;
let tmp16 = ctx.shfl_down_f32(max_abs, 16, 0xFFFF_FFFF);
let max_abs = ctx.max_f32(max_abs, tmp16);
let tmp8 = ctx.shfl_down_f32(max_abs, 8, 0xFFFF_FFFF);
let max_abs = ctx.max_f32(max_abs, tmp8);
let tmp4 = ctx.shfl_down_f32(max_abs, 4, 0xFFFF_FFFF);
let max_abs = ctx.max_f32(max_abs, tmp4);
let tmp2 = ctx.shfl_down_f32(max_abs, 2, 0xFFFF_FFFF);
let max_abs = ctx.max_f32(max_abs, tmp2);
let tmp1 = ctx.shfl_down_f32(max_abs, 1, 0xFFFF_FFFF);
let max_abs = ctx.max_f32(max_abs, tmp1);
let max_abs = ctx.shfl_idx_f32(max_abs, 0, 0xFFFF_FFFF);
let inv_127 = ctx.mov_f32_imm(1.0 / 127.0);
let scale = ctx.mul_f32(max_abs, inv_127);
let eps = ctx.mov_f32_imm(1e-10);
let scale_eps = ctx.add_f32(scale, eps);
let inv_scale = ctx.rcp_f32(scale_eps);
let scaled = ctx.mul_f32(val, inv_scale);
let rounded = ctx.cvt_rni_s32_f32(scaled);
let min_val = ctx.mov_u32_imm(0xFFFF_FF81); let min_s32 = ctx.mov_s32_from_u32(min_val);
let max_val = ctx.mov_s32_imm(127);
let clamped = ctx.max_s32(rounded, min_s32);
let clamped = ctx.min_s32(clamped, max_val);
let q8_val = ctx.cvt_u8_s32(clamped);
let block_bytes = ctx.mov_u32_imm(36);
let block_offset = ctx.mul_wide_u32_reg(block_id, block_bytes);
let block_base = ctx.add_u64(out_ptr, block_offset);
let lane_64 = ctx.cvt_u64_u32(lane_id);
let qs_addr = ctx.add_u64(block_base, lane_64);
ctx.st_global_u8(qs_addr, q8_val);
let one = ctx.mov_u32_imm(1);
let is_lane0 = ctx.setp_lt_u32(lane_id, one);
ctx.branch_if_not(is_lane0, "exit");
let thirty_two_64 = ctx.mov_u64_imm(32);
let d_addr = ctx.add_u64(block_base, thirty_two_64);
let scale_f16 = ctx.cvt_f16_f32(scale);
ctx.st_global_f16(d_addr, scale_f16);
let sum = val;
let sum_tmp16 = ctx.shfl_down_f32(sum, 16, 0xFFFF_FFFF);
let sum = ctx.add_f32(sum, sum_tmp16);
let sum_tmp8 = ctx.shfl_down_f32(sum, 8, 0xFFFF_FFFF);
let sum = ctx.add_f32(sum, sum_tmp8);
let sum_tmp4 = ctx.shfl_down_f32(sum, 4, 0xFFFF_FFFF);
let sum = ctx.add_f32(sum, sum_tmp4);
let sum_tmp2 = ctx.shfl_down_f32(sum, 2, 0xFFFF_FFFF);
let sum = ctx.add_f32(sum, sum_tmp2);
let sum_tmp1 = ctx.shfl_down_f32(sum, 1, 0xFFFF_FFFF);
let sum = ctx.add_f32(sum, sum_tmp1);
let thirty_four_64 = ctx.mov_u64_imm(34);
let s_addr = ctx.add_u64(block_base, thirty_four_64);
let sum_f16 = ctx.cvt_f16_f32(sum);
ctx.st_global_f16(s_addr, sum_f16);
ctx.label("exit");
ctx.ret();
})
}
}