#![allow(clippy::similar_names)]
#![allow(clippy::too_many_lines)]
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl, PtxMemory};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Builder for a scalar RMSNorm PTX kernel operating on `f32` vectors.
///
/// `hidden_size` is the element count of the input/gamma/output vectors
/// (addresses are computed as `index * 4` bytes). `epsilon` is added to the
/// mean of squares before the reciprocal square root for numerical stability.
#[derive(Debug, Clone)]
pub struct RmsNormKernel {
    /// Number of f32 elements per vector.
    pub hidden_size: u32,
    /// Stabilizer added to the mean square before `rsqrt`.
    pub epsilon: f32,
}
impl RmsNormKernel {
    /// Creates a kernel description for vectors of `hidden_size` f32 elements
    /// with the default epsilon of `1e-5`.
    ///
    /// Made `const` for consistency with [`Self::with_epsilon`]; both are now
    /// usable in constant contexts.
    #[must_use]
    pub const fn new(hidden_size: u32) -> Self {
        Self { hidden_size, epsilon: 1e-5 }
    }

    /// Returns the kernel with a caller-chosen `epsilon` stabilizer,
    /// consuming `self` builder-style.
    #[must_use]
    pub const fn with_epsilon(mut self, epsilon: f32) -> Self {
        self.epsilon = epsilon;
        self
    }
}
impl Kernel for RmsNormKernel {
    /// Name under which the generated PTX entry point is registered.
    fn name(&self) -> &str {
        "rmsnorm"
    }

    /// Emits the PTX for a scalar RMSNorm kernel.
    ///
    /// Structure (each `ctx.*` call appends one instruction, so statement
    /// order here *is* the emitted instruction order — do not reorder):
    ///
    /// 1. Sum-of-squares loop: each thread strides over the vector with a
    ///    stride of 32, accumulating `val * val` into `sq_sum` via FMA.
    ///    NOTE(review): the stride of 32 implies a single-warp (32-thread)
    ///    launch — TODO confirm the host-side launch configuration.
    /// 2. Warp reduction: a shuffle-down tree (offsets 16, 8, 4, 2, 1)
    ///    folds all 32 lanes' partial sums into lane 0, then `shfl.idx`
    ///    broadcasts lane 0's total back to every lane.
    /// 3. `rms_inv = rsqrt(total / hidden_size + epsilon)`.
    /// 4. Normalize loop: same stride-32 walk; each element is multiplied by
    ///    `rms_inv` and the per-element `gamma` weight, then stored.
    fn build_ptx(&self) -> PtxKernel {
        // Copy fields to locals so the build closure captures plain values.
        let hidden_size = self.hidden_size;
        let epsilon = self.epsilon;
        PtxKernel::new("rmsnorm")
            // Three pointer params; no shared memory needed (single warp
            // reduces entirely via register shuffles).
            .param(PtxType::U64, "input_ptr") .param(PtxType::U64, "output_ptr") .param(PtxType::U64, "gamma_ptr") .shared_memory(0) .build(|ctx| {
            let tid = ctx.special_reg(PtxReg::TidX);
            let input_ptr = ctx.load_param_u64("input_ptr");
            let output_ptr = ctx.load_param_u64("output_ptr");
            let gamma_ptr = ctx.load_param_u64("gamma_ptr");
            let hidden_u32 = ctx.mov_u32_imm(hidden_size);
            // 4 = sizeof(f32); used for byte-offset computation below.
            let four = ctx.mov_u32_imm(4);
            let sq_sum = ctx.mov_f32_imm(0.0);
            let idx = ctx.mov_u32_imm(0);

            // --- Phase 1: per-thread sum of squares, stride 32 ---
            ctx.label("sum_loop");
            let loop_idx = ctx.add_u32_reg(idx, tid);
            let in_bounds = ctx.setp_lt_u32(loop_idx, hidden_u32);
            ctx.branch_if_not(in_bounds, "sum_loop_end");
            // Byte offset = element index * 4 (wide multiply → u64).
            let elem_offset = ctx.mul_wide_u32_reg(loop_idx, four);
            let elem_addr = ctx.add_u64(input_ptr, elem_offset);
            let val = ctx.ld_global_f32(elem_addr);
            // sq_sum += val * val
            ctx.fma_f32_inplace(sq_sum, val, val);
            ctx.add_u32_inplace(idx, 32);
            ctx.branch("sum_loop");
            ctx.label("sum_loop_end");

            // --- Phase 2: warp shuffle-down reduction into lane 0 ---
            let shfl16 = ctx.shfl_down_f32(sq_sum, 16, 0xFFFF_FFFF);
            ctx.add_f32_inplace(sq_sum, shfl16);
            let shfl8 = ctx.shfl_down_f32(sq_sum, 8, 0xFFFF_FFFF);
            ctx.add_f32_inplace(sq_sum, shfl8);
            let shfl4 = ctx.shfl_down_f32(sq_sum, 4, 0xFFFF_FFFF);
            ctx.add_f32_inplace(sq_sum, shfl4);
            let shfl2 = ctx.shfl_down_f32(sq_sum, 2, 0xFFFF_FFFF);
            ctx.add_f32_inplace(sq_sum, shfl2);
            let shfl1 = ctx.shfl_down_f32(sq_sum, 1, 0xFFFF_FFFF);
            ctx.add_f32_inplace(sq_sum, shfl1);
            // Broadcast lane 0's complete sum to all lanes.
            let total_sq_sum = ctx.shfl_idx_f32(sq_sum, 0, 0xFFFF_FFFF);

            // --- Phase 3: rms_inv = rsqrt(mean_sq + epsilon) ---
            let hidden_f32 = ctx.cvt_f32_u32(hidden_u32);
            let mean_sq = ctx.div_f32(total_sq_sum, hidden_f32);
            let eps = ctx.mov_f32_imm(epsilon);
            let mean_sq_eps = ctx.add_f32(mean_sq, eps);
            let rms_inv = ctx.rsqrt_f32(mean_sq_eps);

            // --- Phase 4: normalize + gamma scale, same stride-32 walk ---
            let idx2 = ctx.mov_u32_imm(0);
            ctx.label("norm_loop");
            let loop_idx2 = ctx.add_u32_reg(idx2, tid);
            let in_bounds2 = ctx.setp_lt_u32(loop_idx2, hidden_u32);
            ctx.branch_if_not(in_bounds2, "exit");
            let elem_offset2 = ctx.mul_wide_u32_reg(loop_idx2, four);
            let in_addr = ctx.add_u64(input_ptr, elem_offset2);
            let gamma_addr = ctx.add_u64(gamma_ptr, elem_offset2);
            let out_addr = ctx.add_u64(output_ptr, elem_offset2);
            let inp = ctx.ld_global_f32(in_addr);
            let gamma = ctx.ld_global_f32(gamma_addr);
            // out = input * rms_inv * gamma
            let normalized = ctx.mul_f32(inp, rms_inv);
            let result = ctx.mul_f32(normalized, gamma);
            ctx.st_global_f32(out_addr, result);
            ctx.add_u32_inplace(idx2, 32);
            ctx.branch("norm_loop");
            ctx.label("exit");
            ctx.ret();
        })
    }
}
/// Builder for a multi-warp RMSNorm PTX kernel on `f32` vectors.
///
/// Same math as [`RmsNormKernel`] but strided for a larger thread block
/// (loop strides of 256 in the generated code) with a cross-warp reduction
/// through shared memory. `epsilon` stabilizes the `rsqrt` as in the scalar
/// variant.
#[derive(Debug, Clone)]
pub struct VectorizedRmsNormKernel {
    /// Number of f32 elements per vector.
    pub hidden_size: u32,
    /// Stabilizer added to the mean square before `rsqrt`.
    pub epsilon: f32,
}
impl VectorizedRmsNormKernel {
    /// Creates a kernel description for vectors of `hidden_size` f32 elements
    /// with the default epsilon of `1e-5`.
    ///
    /// Made `const` for consistency with [`Self::with_epsilon`]; both are now
    /// usable in constant contexts.
    #[must_use]
    pub const fn new(hidden_size: u32) -> Self {
        Self { hidden_size, epsilon: 1e-5 }
    }

    /// Returns the kernel with a caller-chosen `epsilon` stabilizer,
    /// consuming `self` builder-style.
    #[must_use]
    pub const fn with_epsilon(mut self, epsilon: f32) -> Self {
        self.epsilon = epsilon;
        self
    }
}
impl Kernel for VectorizedRmsNormKernel {
    /// Name under which the generated PTX entry point is registered.
    fn name(&self) -> &str {
        "rmsnorm_vectorized"
    }

    /// Emits the PTX for a block-wide RMSNorm kernel.
    ///
    /// Statement order here *is* the emitted PTX instruction order — do not
    /// reorder `ctx.*` calls. Structure:
    ///
    /// 1. Sum-of-squares loop with a stride of 256.
    ///    NOTE(review): stride 256 and the 8-slot shared buffer (8 warps of
    ///    32 = 256 threads) imply a 256-thread block — TODO confirm the
    ///    host-side launch configuration matches.
    /// 2. Two-stage reduction:
    ///    a. intra-warp shuffle-down tree (16, 8, 4, 2, 1) into each warp's
    ///       lane 0;
    ///    b. lane 0 of each warp stores its partial to shared memory slot
    ///       `warp_id * 4` bytes; after `bar.sync 0`, the first warp loads
    ///       the 8 partials (lanes 0..8) and shuffle-reduces them (4, 2, 1),
    ///       with lane 0 writing the grand total back to shared offset 0.
    /// 3. After `bar.sync 1`, every thread reads the total, computes
    ///    `rms_inv = rsqrt(total / hidden_size + epsilon)`.
    /// 4. Normalize loop, stride 256: `out = input * rms_inv * gamma`.
    fn build_ptx(&self) -> PtxKernel {
        let hidden_size = self.hidden_size;
        let epsilon = self.epsilon;
        PtxKernel::new("rmsnorm_vectorized")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::U64, "gamma_ptr")
            // 8 f32 slots: one partial sum per warp of a 256-thread block.
            .shared_memory(8 * 4) .build(move |ctx| {
            let tid = ctx.special_reg(PtxReg::TidX);
            // Warp coordinates within the block (warp size 32).
            let warp_id = ctx.div_u32(tid, 32);
            let lane_id = ctx.rem_u32(tid, 32);
            let input_ptr = ctx.load_param_u64("input_ptr");
            let output_ptr = ctx.load_param_u64("output_ptr");
            let gamma_ptr = ctx.load_param_u64("gamma_ptr");
            let hidden_u32 = ctx.mov_u32_imm(hidden_size);
            // 4 = sizeof(f32) for byte-offset computation.
            let four = ctx.mov_u32_imm(4);
            // NOTE(review): unused register, but this still emits a dead
            // `mov` instruction into the PTX — candidate for removal.
            let _thread_count = ctx.mov_u32_imm(256);
            let sq_sum = ctx.mov_f32_imm(0.0);
            let idx = ctx.mov_u32_imm(0);

            // --- Phase 1: per-thread sum of squares, stride 256 ---
            ctx.label("sum_loop");
            let loop_idx = ctx.add_u32_reg(idx, tid);
            let in_bounds = ctx.setp_lt_u32(loop_idx, hidden_u32);
            ctx.branch_if_not(in_bounds, "sum_loop_end");
            let elem_offset = ctx.mul_wide_u32_reg(loop_idx, four);
            let elem_addr = ctx.add_u64(input_ptr, elem_offset);
            let val = ctx.ld_global_f32(elem_addr);
            // sq_sum += val * val
            ctx.fma_f32_inplace(sq_sum, val, val);
            ctx.add_u32_inplace(idx, 256);
            ctx.branch("sum_loop");
            ctx.label("sum_loop_end");

            // --- Phase 2a: intra-warp shuffle reduction into lane 0 ---
            let shfl16 = ctx.shfl_down_f32(sq_sum, 16, 0xFFFF_FFFF);
            ctx.add_f32_inplace(sq_sum, shfl16);
            let shfl8 = ctx.shfl_down_f32(sq_sum, 8, 0xFFFF_FFFF);
            ctx.add_f32_inplace(sq_sum, shfl8);
            let shfl4 = ctx.shfl_down_f32(sq_sum, 4, 0xFFFF_FFFF);
            ctx.add_f32_inplace(sq_sum, shfl4);
            let shfl2 = ctx.shfl_down_f32(sq_sum, 2, 0xFFFF_FFFF);
            ctx.add_f32_inplace(sq_sum, shfl2);
            let shfl1 = ctx.shfl_down_f32(sq_sum, 1, 0xFFFF_FFFF);
            ctx.add_f32_inplace(sq_sum, shfl1);

            // --- Phase 2b: cross-warp reduction via shared memory ---
            let zero = ctx.mov_u32_imm(0);
            let eight = ctx.mov_u32_imm(8);
            let thirty_two = ctx.mov_u32_imm(32);
            let lane_zero = ctx.setp_eq_u32(lane_id, zero);
            // Each warp's partial lives at shared byte offset warp_id * 4.
            let warp_smem_off = ctx.mul_u32(warp_id, 4);
            ctx.branch_if_not(lane_zero, "skip_smem_write");
            ctx.st_shared_f32(warp_smem_off, sq_sum);
            ctx.label("skip_smem_write");
            // All warps' partials must be visible before the first warp reads.
            ctx.bar_sync(0);
            let is_first_warp = ctx.setp_lt_u32(tid, thirty_two);
            ctx.branch_if_not(is_first_warp, "skip_final_reduce");
            // Only lanes 0..8 have a warp partial to load; others contribute 0.
            let lane_valid = ctx.setp_lt_u32(lane_id, eight);
            let lane_smem_off = ctx.mul_u32(lane_id, 4);
            let warp_partial = ctx.mov_f32_imm(0.0);
            ctx.branch_if_not(lane_valid, "skip_warp_load");
            let loaded_val = ctx.ld_shared_f32(lane_smem_off);
            // NOTE(review): unused register, but emits a dead `mov` — candidate
            // for removal.
            let _zero_f32 = ctx.mov_f32_imm(0.0);
            ctx.add_f32_inplace(warp_partial, loaded_val);
            ctx.label("skip_warp_load");
            // Reduce the 8 partials (offsets 4, 2, 1) within the first warp.
            let red4 = ctx.shfl_down_f32(warp_partial, 4, 0xFFFF_FFFF);
            let partial = ctx.add_f32(warp_partial, red4);
            let red2 = ctx.shfl_down_f32(partial, 2, 0xFFFF_FFFF);
            let partial = ctx.add_f32(partial, red2);
            let red1 = ctx.shfl_down_f32(partial, 1, 0xFFFF_FFFF);
            let final_sum = ctx.add_f32(partial, red1);
            // Lane 0 publishes the grand total at shared offset 0.
            let smem_zero = ctx.mov_u32_imm(0);
            ctx.branch_if_not(lane_zero, "skip_final_write");
            ctx.st_shared_f32(smem_zero, final_sum);
            ctx.label("skip_final_write");
            ctx.label("skip_final_reduce");
            // Total must be written before every thread reads it back.
            ctx.bar_sync(1);

            // --- Phase 3: rms_inv = rsqrt(mean_sq + epsilon) ---
            let smem_read_zero = ctx.mov_u32_imm(0);
            let total = ctx.ld_shared_f32(smem_read_zero);
            let hidden_f32 = ctx.cvt_f32_u32(hidden_u32);
            let mean_sq = ctx.div_f32(total, hidden_f32);
            let eps = ctx.mov_f32_imm(epsilon);
            let mean_sq_eps = ctx.add_f32(mean_sq, eps);
            let rms_inv = ctx.rsqrt_f32(mean_sq_eps);

            // --- Phase 4: normalize + gamma scale, stride 256 ---
            let idx2 = ctx.mov_u32_imm(0);
            ctx.label("norm_loop");
            let loop_idx2 = ctx.add_u32_reg(idx2, tid);
            let in_bounds2 = ctx.setp_lt_u32(loop_idx2, hidden_u32);
            ctx.branch_if_not(in_bounds2, "exit");
            let elem_offset2 = ctx.mul_wide_u32_reg(loop_idx2, four);
            let in_addr = ctx.add_u64(input_ptr, elem_offset2);
            let gamma_addr = ctx.add_u64(gamma_ptr, elem_offset2);
            let out_addr = ctx.add_u64(output_ptr, elem_offset2);
            let inp = ctx.ld_global_f32(in_addr);
            let gamma = ctx.ld_global_f32(gamma_addr);
            // out = input * rms_inv * gamma
            let normalized = ctx.mul_f32(inp, rms_inv);
            let result = ctx.mul_f32(normalized, gamma);
            ctx.st_global_f32(out_addr, result);
            ctx.add_u32_inplace(idx2, 256);
            ctx.branch("norm_loop");
            ctx.label("exit");
            ctx.ret();
        })
    }
}