use super::super::nf4_cpu::{NF4_BLOCK_SIZE, NF4_LUT};
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl, PtxMemory};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
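// NF4 packs two 4-bit codes per byte, so each block of NF4_BLOCK_SIZE weights
// occupies NF4_BLOCK_SIZE / 2 bytes of quantized data.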
const NF4_BLOCK_SIZE_U32: u32 = NF4_BLOCK_SIZE as u32;
const NF4_BLOCK_DATA_BYTES: u32 = (NF4_BLOCK_SIZE / 2) as u32;
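/// Fused RMSNorm + NF4 GEMV kernel: normalizes the input vector with RMSNorm
/// (weighted by `gamma`), then multiplies it by an NF4-quantized weight matrix,
/// producing one output element per thread block.
///
/// Each block of 256 threads handles one output column in four stages:
/// 1. load `x` into shared memory while accumulating the sum of squares,
/// 2. reduce to `1 / rms` and normalize the vector in place with `gamma`,
/// 3. dequantize one NF4 weight column (per-block scale times codebook value)
///    and accumulate the dot product in f64,
/// 4. reduce across warps and store `y[column]`.
///
/// Illustrative usage sketch (the launch configuration is an assumption drawn
/// from the kernel's internal 256-thread stride; `emit_ptx` is the `Kernel`
/// trait helper used in the tests below):
///
/// ```ignore
/// let kernel = FusedRmsNormNf4GemvKernel::new(1536, 1536).with_epsilon(1e-6);
/// let ptx = kernel.emit_ptx();
/// // Launch with grid = (n, 1, 1) and block = (256, 1, 1).
/// ```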
#[derive(Debug, Clone)]
pub struct FusedRmsNormNf4GemvKernel {
pub k: u32,
pub n: u32,
pub epsilon: f32,
}
impl FusedRmsNormNf4GemvKernel {
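/// Creates a kernel for a `k`-element input and `n` output columns, with the
/// default RMSNorm epsilon of `1e-5`.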
#[must_use]
pub fn new(k: u32, n: u32) -> Self {
Self {
k,
n,
epsilon: 1e-5,
}
}
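/// Overrides the RMSNorm epsilon (defaults to `1e-5`).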
#[must_use]
pub const fn with_epsilon(mut self, epsilon: f32) -> Self {
self.epsilon = epsilon;
self
}
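/// Number of NF4 quantization blocks along one weight column (`k / NF4_BLOCK_SIZE`).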
#[must_use]
pub const fn num_blocks_per_col(&self) -> u32 {
self.k / NF4_BLOCK_SIZE_U32
}
}
impl Kernel for FusedRmsNormNf4GemvKernel {
fn name(&self) -> &str {
"fused_rmsnorm_nf4_gemv"
}
fn build_ptx(&self) -> PtxKernel {
let k = self.k;
let epsilon = self.epsilon;
let num_k_blocks = k / NF4_BLOCK_SIZE_U32;
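// Shared memory layout: `k * 4` bytes for the (normalized) input vector,
// followed by 32 bytes for the 8 per-warp partial sums (256 threads / 32 lanes).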
let smem_size = (k * 4 + 32) as usize;
PtxKernel::new("fused_rmsnorm_nf4_gemv")
.param(PtxType::U64, "y_ptr") .param(PtxType::U64, "scales_ptr") .param(PtxType::U64, "data_ptr") .param(PtxType::U64, "x_ptr") .param(PtxType::U64, "gamma_ptr") .param(PtxType::U32, "k_dim")
.param(PtxType::U32, "n_dim")
.shared_memory(smem_size)
.build(move |ctx| {
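// One thread block per output column; blocks past `n_dim` exit immediately.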
let block_id = ctx.special_reg(PtxReg::CtaIdX);
let thread_id = ctx.special_reg(PtxReg::TidX);
let n_dim = ctx.load_param_u32("n_dim");
let oob = ctx.setp_ge_u32(block_id, n_dim);
ctx.branch_if(oob, "exit");
let k_dim = ctx.load_param_u32("k_dim");
let y_ptr = ctx.load_param_u64("y_ptr");
let scales_ptr = ctx.load_param_u64("scales_ptr");
let data_ptr = ctx.load_param_u64("data_ptr");
let x_ptr = ctx.load_param_u64("x_ptr");
let gamma_ptr = ctx.load_param_u64("gamma_ptr");
let four = ctx.mov_u32_imm(4);
let one = ctx.mov_u32_imm(1);
let sq_sum = ctx.mov_f32_imm(0.0);
let idx = ctx.mov_u32_imm(0);
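// Stage 1: strided (stride 256) load of `x` into shared memory, with each
// thread accumulating its partial sum of squares.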
ctx.label("load_loop");
let loop_idx = ctx.add_u32_reg(idx, thread_id);
let in_bounds = ctx.setp_lt_u32(loop_idx, k_dim);
ctx.branch_if_not(in_bounds, "load_loop_end");
let elem_offset = ctx.mul_wide_u32_reg(loop_idx, four);
let x_addr = ctx.add_u64(x_ptr, elem_offset);
let x_val = ctx.ld_global_f32(x_addr);
ctx.st_shared_f32(elem_offset, x_val);
ctx.fma_f32_inplace(sq_sum, x_val, x_val);
ctx.add_u32_inplace(idx, 256);
ctx.branch("load_loop");
ctx.label("load_loop_end");
let shfl16 = ctx.shfl_down_f32(sq_sum, 16, 0xFFFF_FFFF);
ctx.add_f32_inplace(sq_sum, shfl16);
let shfl8 = ctx.shfl_down_f32(sq_sum, 8, 0xFFFF_FFFF);
ctx.add_f32_inplace(sq_sum, shfl8);
let shfl4 = ctx.shfl_down_f32(sq_sum, 4, 0xFFFF_FFFF);
ctx.add_f32_inplace(sq_sum, shfl4);
let shfl2 = ctx.shfl_down_f32(sq_sum, 2, 0xFFFF_FFFF);
ctx.add_f32_inplace(sq_sum, shfl2);
let shfl1 = ctx.shfl_down_f32(sq_sum, 1, 0xFFFF_FFFF);
ctx.add_f32_inplace(sq_sum, shfl1);
let lane_id = ctx.rem_u32(thread_id, 32);
let warp_id = ctx.div_u32(thread_id, 32);
let is_lane0 = ctx.setp_lt_u32(lane_id, one);
let k_bytes = ctx.mul_u32_reg(k_dim, four);
let warp_sum_offset = ctx.mul_wide_u32_reg(warp_id, four);
let k_bytes_64 = ctx.cvt_u64_u32(k_bytes);
let warp_sum_addr = ctx.add_u64(k_bytes_64, warp_sum_offset);
ctx.branch_if_not(is_lane0, "skip_warp_write");
ctx.st_shared_f32(warp_sum_addr, sq_sum);
ctx.label("skip_warp_write");
ctx.bar_sync(0);
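// Thread 0 sums the 8 per-warp partials, adds epsilon to the mean square,
// and stores `1 / rms` at shared offset 0 for all threads to read.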
let is_thread0 = ctx.setp_lt_u32(thread_id, one);
let total_sq = ctx.mov_f32_imm(0.0);
ctx.branch_if_not(is_thread0, "skip_rms_reduce");
for warp in 0..8u32 {
let w_off = ctx.mov_u64_imm((warp * 4) as u64);
let w_addr = ctx.add_u64(k_bytes_64, w_off);
let w_sum = ctx.ld_shared_f32(w_addr);
ctx.add_f32_inplace(total_sq, w_sum);
}
let k_f32 = ctx.cvt_f32_u32(k_dim);
let mean_sq = ctx.div_f32(total_sq, k_f32);
let eps = ctx.mov_f32_imm(epsilon);
let mean_sq_eps = ctx.add_f32(mean_sq, eps);
let rms_inv = ctx.rsqrt_f32(mean_sq_eps);
let zero_off = ctx.mov_u64_imm(0);
ctx.st_shared_f32(zero_off, rms_inv);
ctx.label("skip_rms_reduce");
ctx.bar_sync(1);
let rms_inv_off = ctx.mov_u64_imm(0);
let rms_inv_val = ctx.ld_shared_f32(rms_inv_off);
let idx2 = ctx.mov_u32_imm(0);
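// Stage 2: normalize in place: smem[i] = x[i] * rms_inv * gamma[i].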
ctx.label("norm_loop");
let loop_idx2 = ctx.add_u32_reg(idx2, thread_id);
let in_bounds2 = ctx.setp_lt_u32(loop_idx2, k_dim);
ctx.branch_if_not(in_bounds2, "norm_loop_end");
let elem_off2 = ctx.mul_wide_u32_reg(loop_idx2, four);
let x_smem = ctx.ld_shared_f32(elem_off2);
let gamma_addr = ctx.add_u64(gamma_ptr, elem_off2);
let gamma_val = ctx.ld_global_f32(gamma_addr);
let normalized = ctx.mul_f32(x_smem, rms_inv_val);
let scaled = ctx.mul_f32(normalized, gamma_val);
ctx.st_shared_f32(elem_off2, scaled);
ctx.add_u32_inplace(idx2, 256);
ctx.branch("norm_loop");
ctx.label("norm_loop_end");
ctx.bar_sync(2);
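// Stage 3: NF4 GEMV. Hold the 16-entry NF4 codebook in registers and
// accumulate this column's dot product in f64.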
let lut_regs: [_; 16] = std::array::from_fn(|i| ctx.mov_f32_imm(NF4_LUT[i]));
let acc = ctx.mov_f64_imm_zero();
let num_k_blocks_reg = ctx.mov_u32_imm(num_k_blocks);
let scale_row_elems = ctx.mul_u32_reg(block_id, num_k_blocks_reg);
let scale_row_offset = ctx.mul_wide_u32_reg(scale_row_elems, four);
let scale_row_base = ctx.add_u64(scales_ptr, scale_row_offset);
let block_data_bytes = ctx.mov_u32_imm(NF4_BLOCK_DATA_BYTES);
let data_row_blocks = ctx.mul_u32_reg(block_id, num_k_blocks_reg);
let data_row_bytes = ctx.mul_u32_reg(data_row_blocks, block_data_bytes);
let data_row_offset = ctx.cvt_u64_u32(data_row_bytes);
let data_row_base = ctx.add_u64(data_ptr, data_row_offset);
let elem_idx = ctx.mov_u32_imm(0);
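// For each element: find its NF4 block, load the per-block scale, unpack the
// 4-bit code (even index -> low nibble, odd -> high nibble), dequantize via
// the register LUT, and FMA against the normalized activation.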
ctx.label("nf4_loop");
let e_idx = ctx.add_u32_reg(elem_idx, thread_id);
let in_bounds3 = ctx.setp_lt_u32(e_idx, k_dim);
ctx.branch_if_not(in_bounds3, "nf4_loop_end");
let block_idx = ctx.div_u32(e_idx, NF4_BLOCK_SIZE_U32);
let elem_in_block = ctx.rem_u32(e_idx, NF4_BLOCK_SIZE_U32);
let scale_off = ctx.mul_wide_u32_reg(block_idx, four);
let scale_addr = ctx.add_u64(scale_row_base, scale_off);
let scale = ctx.ld_global_f32(scale_addr);
let byte_idx = ctx.div_u32(elem_in_block, 2);
let block_data_off = ctx.mul_u32_reg(block_idx, block_data_bytes);
let byte_off_in_data = ctx.add_u32_reg(block_data_off, byte_idx);
let byte_off_64 = ctx.cvt_u64_u32(byte_off_in_data);
let byte_addr = ctx.add_u64(data_row_base, byte_off_64);
let packed_byte = ctx.ld_global_u8(byte_addr);
let packed_u32 = ctx.cvt_u32_u8(packed_byte);
let is_high = ctx.rem_u32(elem_in_block, 2);
let shift_amt = ctx.mul_u32_reg(is_high, four);
let shifted = ctx.shr_u32(packed_u32, shift_amt);
let mask_4bit = ctx.mov_u32_imm(0xF);
let nibble = ctx.and_u32(shifted, mask_4bit);
let codebook_val =
super::super::nf4::nf4_register_lut_lookup(ctx, nibble, &lut_regs);
let weight = ctx.mul_f32(scale, codebook_val);
let x_smem_off = ctx.mul_wide_u32_reg(e_idx, four);
let x_norm = ctx.ld_shared_f32(x_smem_off);
ctx.fma_f64_acc_inplace(acc, x_norm, weight);
ctx.add_u32_inplace(elem_idx, 256);
ctx.branch("nf4_loop");
ctx.label("nf4_loop_end");
let acc_f32 = ctx.cvt_f32_f64_rn(acc);
let w_shfl16 = ctx.shfl_down_f32(acc_f32, 16, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_f32, w_shfl16);
let w_shfl8 = ctx.shfl_down_f32(acc_f32, 8, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_f32, w_shfl8);
let w_shfl4 = ctx.shfl_down_f32(acc_f32, 4, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_f32, w_shfl4);
let w_shfl2 = ctx.shfl_down_f32(acc_f32, 2, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_f32, w_shfl2);
let w_shfl1 = ctx.shfl_down_f32(acc_f32, 1, 0xFFFF_FFFF);
ctx.add_f32_inplace(acc_f32, w_shfl1);
let warp_acc_off = ctx.mul_wide_u32_reg(warp_id, four);
let warp_acc_addr = ctx.add_u64(k_bytes_64, warp_acc_off);
ctx.branch_if_not(is_lane0, "skip_warp_acc");
ctx.st_shared_f32(warp_acc_addr, acc_f32);
ctx.label("skip_warp_acc");
ctx.bar_sync(3);
ctx.branch_if_not(is_thread0, "exit");
let final_val = ctx.mov_f32_imm(0.0);
for warp in 0..8u32 {
let w_off = ctx.mov_u64_imm((warp * 4) as u64);
let w_addr = ctx.add_u64(k_bytes_64, w_off);
let w_acc = ctx.ld_shared_f32(w_addr);
ctx.add_f32_inplace(final_val, w_acc);
}
let y_off = ctx.mul_wide_u32(block_id, 4);
let y_addr = ctx.add_u64(y_ptr, y_off);
ctx.st_global_f32(y_addr, final_val);
ctx.label("exit");
ctx.ret();
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_fused_rmsnorm_nf4_gemv_name() {
let kernel = FusedRmsNormNf4GemvKernel::new(1536, 1536);
assert_eq!(kernel.name(), "fused_rmsnorm_nf4_gemv");
}
#[test]
fn test_fused_rmsnorm_nf4_gemv_ptx_emits() {
let kernel = FusedRmsNormNf4GemvKernel::new(1536, 1536);
let ptx = kernel.emit_ptx();
assert!(ptx.contains("fused_rmsnorm_nf4_gemv"));
assert!(ptx.contains("scales_ptr"));
assert!(ptx.contains("gamma_ptr"));
assert!(ptx.contains("selp")); assert!(ptx.contains("rsqrt")); }
#[test]
fn test_fused_rmsnorm_nf4_gemv_num_blocks() {
let kernel = FusedRmsNormNf4GemvKernel::new(1536, 256);
assert_eq!(kernel.num_blocks_per_col(), 24);
}
}