use crate::kernels::quantize::{Kernel, Q6K_SUPER_BLOCK_BYTES, Q6K_SUPER_BLOCK_SIZE};
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Builder for a PTX kernel that dequantizes Q6_K-quantized weights to f32.
///
/// The weight matrix is stored row-major as `n` rows of `k` values, each row
/// packed into Q6_K super-blocks; the emitted kernel is launched with one CTA
/// per (row, super-block) pair.
#[derive(Debug, Clone)]
pub struct Q6KDequantKernel {
/// Number of values per row; bound to the `k_dim` kernel parameter.
pub k: u32,
/// Number of rows; bound to the `n_dim` kernel parameter.
pub n: u32,
}
impl Q6KDequantKernel {
    /// Creates a builder for a `k`-column, `n`-row weight matrix.
    ///
    /// `const` for consistency with [`Self::num_super_blocks_per_row`].
    #[must_use]
    pub const fn new(k: u32, n: u32) -> Self {
        Self { k, n }
    }

    /// Number of Q6_K super-blocks needed to cover one row of `k` values
    /// (`k` divided by `Q6K_SUPER_BLOCK_SIZE`, rounded up).
    #[must_use]
    pub const fn num_super_blocks_per_row(&self) -> u32 {
        // div_ceil avoids the overflow the manual `(k + SIZE - 1) / SIZE`
        // form hits when k is within SIZE-1 of u32::MAX.
        self.k.div_ceil(Q6K_SUPER_BLOCK_SIZE)
    }
}
impl Kernel for Q6KDequantKernel {
    fn name(&self) -> &str {
        "q6k_dequant_to_f32"
    }

    /// Emits PTX computing `out[row][k] = d * scales[sub_block] * (q6 - 32)`.
    ///
    /// Launch layout: `ctaid.x` selects the output row (bounded by `n_dim`),
    /// `ctaid.y` the super-block within the row, and 32 threads each handle
    /// one lane of the eight 32-value groups of a 256-value super-block.
    ///
    /// Q6_K super-block byte layout (matches GGUF/ggml `block_q6_K`):
    /// - `0..128`:   `ql`, low 4 bits of each 6-bit quant (two per byte)
    /// - `128..192`: `qh`, high 2 bits (four per byte)
    /// - `192..208`: `scales`, sixteen signed 8-bit sub-block scales
    /// - `208..210`: `d`, f16 super-block scale
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("q6k_dequant_to_f32")
            .param(PtxType::U64, "out_ptr")
            .param(PtxType::U64, "w_ptr")
            .param(PtxType::U32, "k_dim")
            .param(PtxType::U32, "n_dim")
            .build(|ctx| {
                let row_id = ctx.special_reg(PtxReg::CtaIdX);
                let sb_idx = ctx.special_reg(PtxReg::CtaIdY);
                let thread_id = ctx.special_reg(PtxReg::TidX);
                let n_dim = ctx.load_param_u32("n_dim");
                let k_dim = ctx.load_param_u32("k_dim");
                // CTAs mapped past the last row do nothing.
                let oob = ctx.setp_ge_u32(row_id, n_dim);
                ctx.branch_if(oob, "exit");
                let out_ptr = ctx.load_param_u64("out_ptr");
                let w_ptr = ctx.load_param_u64("w_ptr");
                // num_sb = ceil(k_dim / SUPER_BLOCK_SIZE), computed on-device.
                let k_rounded = ctx.add_u32(k_dim, Q6K_SUPER_BLOCK_SIZE - 1);
                let num_sb = ctx.div_u32(k_rounded, Q6K_SUPER_BLOCK_SIZE);
                let sb_oob = ctx.setp_ge_u32(sb_idx, num_sb);
                ctx.branch_if(sb_oob, "exit");
                // Super-block byte address:
                // w_ptr + row * (num_sb * SB_BYTES) + sb * SB_BYTES.
                let sb_bytes = ctx.mov_u32_imm(Q6K_SUPER_BLOCK_BYTES);
                let row_bytes = ctx.mul_u32_reg(num_sb, sb_bytes);
                let row_offset = ctx.mul_wide_u32_reg(row_id, row_bytes);
                let row_base = ctx.add_u64(w_ptr, row_offset);
                let sb_offset = ctx.mul_wide_u32(sb_idx, Q6K_SUPER_BLOCK_BYTES);
                let sb_addr = ctx.add_u64(row_base, sb_offset);
                // Super-block scale d: f16 at byte offset 208.
                let d_offset = ctx.mov_u64_imm(208);
                let d_addr = ctx.add_u64(sb_addr, d_offset);
                let d_f16 = ctx.ld_global_f16(d_addr);
                let d = ctx.cvt_f32_f16(d_f16);
                // Pre-compute d * scales[i] for all sixteen signed sub-block
                // scales (bytes 192..208); reused by every group below.
                let scales_offset = ctx.mov_u64_imm(192);
                let scales_base = ctx.add_u64(sb_addr, scales_offset);
                let mut scale_f32s = Vec::with_capacity(16);
                for i in 0..16u64 {
                    let s_off = ctx.mov_u64_imm(i);
                    let s_addr = ctx.add_u64(scales_base, s_off);
                    let s_u8 = ctx.ld_global_u8(s_addr);
                    let s_i32 = ctx.cvt_s32_s8(s_u8);
                    let s_f32 = ctx.cvt_f32_s32(s_i32);
                    let ds = ctx.mul_f32(d, s_f32);
                    scale_f32s.push(ds);
                }
                let ql_base = sb_addr;
                let qh_offset = ctx.mov_u64_imm(128);
                let qh_base = ctx.add_u64(sb_addr, qh_offset);
                // Output base: &out[row * k_dim + sb * SUPER_BLOCK_SIZE],
                // 4 bytes per f32.
                let sb_k_base = ctx.mul_u32(sb_idx, Q6K_SUPER_BLOCK_SIZE);
                let row_k = ctx.mul_u32_reg(row_id, k_dim);
                let out_k_base = ctx.add_u32_reg(row_k, sb_k_base);
                let out_k_base_64 = ctx.cvt_u64_u32(out_k_base);
                let out_k_bytes = ctx.mul_u64(out_k_base_64, 4);
                let out_base = ctx.add_u64(out_ptr, out_k_bytes);
                let mask_0f = ctx.mov_u32_imm(0x0F);
                let four_u32 = ctx.mov_u32_imm(4);
                let const_32_f = ctx.mov_f32_imm(32.0);
                let sixteen = ctx.mov_u32_imm(16);
                // Eight unrolled 32-value groups; thread t is lane t of each.
                for step in 0..8u32 {
                    let half = step / 4; // which 128-value half of the block
                    let group = step % 4; // which 32-value group in the half
                    let offset_reg = ctx.mov_u32_imm(step * 32);
                    let val_idx = ctx.add_u32_reg(thread_id, offset_reg);
                    let global_k = ctx.add_u32_reg(sb_k_base, val_idx);
                    // Tail super-block: skip lanes past k_dim.
                    let out_of_bounds = ctx.setp_ge_u32(global_k, k_dim);
                    let skip_label = format!("skip_store_{step}");
                    ctx.branch_if(out_of_bounds, &skip_label);
                    // ggml `dequantize_row_q6_K` mapping within a half: groups
                    // 0/2 read ql[64*half + t], groups 1/3 read
                    // ql[64*half + 32 + t]; groups 0/1 take the low nibble,
                    // groups 2/3 the high nibble. (The previous code used
                    // `group / 2` for the byte offset and `group % 2` for the
                    // nibble, swapping the ql source of groups 1 and 2.)
                    let ql_byte_offset = 64 * half + 32 * (group % 2);
                    let ql_off_reg = ctx.mov_u32_imm(ql_byte_offset);
                    let ql_idx = ctx.add_u32_reg(ql_off_reg, thread_id);
                    let ql_idx_64 = ctx.cvt_u64_u32(ql_idx);
                    let ql_addr = ctx.add_u64(ql_base, ql_idx_64);
                    let ql_byte = ctx.ld_global_u8(ql_addr);
                    let ql_u32 = ctx.cvt_u32_u8(ql_byte);
                    let ql_nibble = if group < 2 {
                        ctx.and_u32(ql_u32, mask_0f)
                    } else {
                        ctx.shr_u32(ql_u32, four_u32)
                    };
                    // High 2 bits: (qh[32*half + t] >> (2 * group)) & 3.
                    let qh_byte_offset = 32 * half;
                    let qh_off_reg = ctx.mov_u32_imm(qh_byte_offset);
                    let qh_idx = ctx.add_u32_reg(qh_off_reg, thread_id);
                    let qh_idx_64 = ctx.cvt_u64_u32(qh_idx);
                    let qh_addr = ctx.add_u64(qh_base, qh_idx_64);
                    let qh_byte = ctx.ld_global_u8(qh_addr);
                    let qh_u32 = ctx.cvt_u32_u8(qh_byte);
                    let qh_shift = ctx.mov_u32_imm(2 * group);
                    let qh_shifted = ctx.shr_u32(qh_u32, qh_shift);
                    let mask_03 = ctx.mov_u32_imm(0x03);
                    let qh_2bits = ctx.and_u32(qh_shifted, mask_03);
                    let qh_hi = ctx.shl_u32(qh_2bits, four_u32);
                    // Assemble the 6-bit quant (0..=63) and center to -32..=31.
                    let q6 = ctx.or_u32(ql_nibble, qh_hi);
                    let q6_f32 = ctx.cvt_f32_u32(q6);
                    let q6_centered = ctx.sub_f32(q6_f32, const_32_f);
                    // Each 32-value group spans two 16-value sub-blocks; lanes
                    // 16..32 use the second scale.
                    let sb_base = (8 * half + 2 * group) as usize;
                    let ds_lo = scale_f32s[sb_base];
                    let ds_hi = scale_f32s[sb_base + 1];
                    let is_hi = ctx.setp_ge_u32(thread_id, sixteen);
                    let ds = ctx.selp_f32(is_hi, ds_hi, ds_lo);
                    let dequant = ctx.mul_f32(ds, q6_centered);
                    let val_idx_64 = ctx.cvt_u64_u32(val_idx);
                    let val_bytes = ctx.mul_u64(val_idx_64, 4);
                    let out_addr = ctx.add_u64(out_base, val_bytes);
                    ctx.st_global_f32(out_addr, dequant);
                    ctx.label(&skip_label);
                }
                ctx.label("exit");
                ctx.ret();
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The emitted PTX contains an `.entry` directive under the kernel name.
    #[test]
    fn test_q6k_dequant_kernel_emits_ptx() {
        let ptx = Q6KDequantKernel::new(1536, 256).emit_ptx();
        assert!(ptx.contains(".entry"));
        assert!(ptx.contains("q6k_dequant_to_f32"));
    }

    /// The kernel reports the fixed name used in the PTX entry point.
    #[test]
    fn test_q6k_dequant_kernel_name() {
        assert_eq!(Q6KDequantKernel::new(256, 16).name(), "q6k_dequant_to_f32");
    }

    /// Super-block count is k rounded up to a multiple of 256, divided by 256.
    #[test]
    fn test_num_super_blocks_per_row() {
        assert_eq!(Q6KDequantKernel::new(1536, 256).num_super_blocks_per_row(), 6);
        assert_eq!(Q6KDequantKernel::new(4096, 1536).num_super_blocks_per_row(), 16);
    }
}