#![allow(clippy::similar_names)]
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Configuration for a per-head RMSNorm PTX kernel.
///
/// `build_ptx` emits one CUDA block per head (`ctaid.x` selects the head),
/// normalizing each `head_dim`-element slice independently.
#[derive(Debug, Clone)]
pub struct PerHeadRmsNormKernel {
    /// Number of f32 elements in each head's slice.
    pub head_dim: u32,
    /// Total number of heads. Not read by `build_ptx` itself — presumably
    /// used by the caller for the launch grid size; TODO confirm.
    pub num_heads: u32,
    /// Numerical-stability term added to the mean square before `rsqrt`.
    pub epsilon: f32,
}
impl PerHeadRmsNormKernel {
    /// Creates a kernel config for `num_heads` heads of `head_dim` elements
    /// each, with the default epsilon of `1e-6`.
    ///
    /// Made `const` for consistency with [`Self::with_epsilon`]; a plain
    /// struct literal is always valid in a `const fn`.
    #[must_use]
    pub const fn new(head_dim: u32, num_heads: u32) -> Self {
        Self { head_dim, num_heads, epsilon: 1e-6 }
    }

    /// Returns this config with a custom numerical-stability epsilon.
    #[must_use]
    pub const fn with_epsilon(mut self, epsilon: f32) -> Self {
        self.epsilon = epsilon;
        self
    }
}
impl Kernel for PerHeadRmsNormKernel {
    fn name(&self) -> &str {
        "per_head_rmsnorm"
    }

    /// Emits the PTX for per-head RMSNorm: `out = gamma * x / rms(x)` over
    /// each head's `head_dim`-element slice, one CTA per head.
    ///
    /// NOTE(review): the element stride of 32 and the shuffle-reduction
    /// widths assume exactly one 32-thread warp per block — confirm the
    /// launch configuration matches.
    fn build_ptx(&self) -> PtxKernel {
        let dim = self.head_dim;
        let eps_val = self.epsilon;
        PtxKernel::new("per_head_rmsnorm")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::U64, "gamma_ptr")
            .shared_memory(0)
            .build(|ctx| {
                // Lane id within the block; CTA index selects the head.
                let lane = ctx.special_reg(PtxReg::TidX);
                let head = ctx.special_reg(PtxReg::CtaIdX);

                let in_ptr = ctx.load_param_u64("input_ptr");
                let out_ptr = ctx.load_param_u64("output_ptr");
                let gamma_base = ctx.load_param_u64("gamma_ptr");

                let dim_reg = ctx.mov_u32_imm(dim);
                let elem_size = ctx.mov_u32_imm(4);

                // Byte offset of this head's slice: head * head_dim * 4.
                let head_elems = ctx.mul_u32_reg(head, dim_reg);
                let head_bytes = ctx.mul_wide_u32_reg(head_elems, elem_size);
                let in_base = ctx.add_u64(in_ptr, head_bytes);
                let out_base = ctx.add_u64(out_ptr, head_bytes);

                // Pass 1: each lane accumulates x*x over a 32-strided slice.
                let acc = ctx.mov_f32_imm(0.0);
                let i = ctx.mov_u32_imm(0);
                ctx.label("sum_loop");
                let pos = ctx.add_u32_reg(i, lane);
                let ok = ctx.setp_lt_u32(pos, dim_reg);
                ctx.branch_if_not(ok, "sum_loop_end");
                let off = ctx.mul_wide_u32_reg(pos, elem_size);
                let addr = ctx.add_u64(in_base, off);
                let x = ctx.ld_global_f32(addr);
                ctx.fma_f32_inplace(acc, x, x);
                ctx.add_u32_inplace(i, 32);
                ctx.branch("sum_loop");
                ctx.label("sum_loop_end");

                // Shuffle-down tree reduction: fold lane partials into lane 0.
                // Emits the identical 16/8/4/2/1 call sequence as a loop.
                for delta in [16, 8, 4, 2, 1] {
                    let other = ctx.shfl_down_f32(acc, delta, 0xFFFF_FFFF);
                    ctx.add_f32_inplace(acc, other);
                }
                // Broadcast lane 0's total back to every lane.
                let total = ctx.shfl_idx_f32(acc, 0, 0xFFFF_FFFF);

                // inv_rms = 1 / sqrt(mean(x^2) + epsilon)
                let dim_f = ctx.cvt_f32_u32(dim_reg);
                let mean_sq = ctx.div_f32(total, dim_f);
                let eps = ctx.mov_f32_imm(eps_val);
                let stabilized = ctx.add_f32(mean_sq, eps);
                let inv_rms = ctx.rsqrt_f32(stabilized);

                // Pass 2: store gamma * x * inv_rms over the same slice.
                let j = ctx.mov_u32_imm(0);
                ctx.label("norm_loop");
                let pos2 = ctx.add_u32_reg(j, lane);
                let ok2 = ctx.setp_lt_u32(pos2, dim_reg);
                ctx.branch_if_not(ok2, "exit");
                let off2 = ctx.mul_wide_u32_reg(pos2, elem_size);
                let src = ctx.add_u64(in_base, off2);
                let gamma_addr = ctx.add_u64(gamma_base, off2);
                let dst = ctx.add_u64(out_base, off2);
                let x2 = ctx.ld_global_f32(src);
                let gamma = ctx.ld_global_f32(gamma_addr);
                let scaled = ctx.mul_f32(x2, inv_rms);
                let y = ctx.mul_f32(scaled, gamma);
                ctx.st_global_f32(dst, y);
                ctx.add_u32_inplace(j, 32);
                ctx.branch("norm_loop");
                ctx.label("exit");
                ctx.ret();
            })
    }
}