#![allow(clippy::similar_names)]
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
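/// Builds the PTX kernel for the LayerNorm backward pass with respect to the
/// input: one warp processes one row and one lane processes one element, which
/// is why `hidden_dim` is capped at the warp width of 32.
///
/// Typical use mirrors the tests below:
/// ```ignore
/// let kernel = LayerNormBackwardKernel::new(64, 32);
/// let ptx = kernel.emit_ptx();
/// ```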
#[derive(Debug, Clone)]
pub struct LayerNormBackwardKernel {
    pub num_rows: u32,
    pub hidden_dim: u32,
}
impl LayerNormBackwardKernel {
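    /// Creates a kernel description for `num_rows` rows of `hidden_dim` elements.
    ///
    /// # Panics
    ///
    /// Panics if `hidden_dim` exceeds 32, since the row reduction is done with
    /// warp shuffles and each lane owns exactly one element.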
    #[must_use]
    pub fn new(num_rows: u32, hidden_dim: u32) -> Self {
        assert!(hidden_dim <= 32, "hidden_dim must be ≤ 32 for warp reduction");
        Self { num_rows, hidden_dim }
    }
}
impl Kernel for LayerNormBackwardKernel {
    fn name(&self) -> &str {
        "layer_norm_backward"
    }
    fn build_ptx(&self) -> PtxKernel {
        let hidden_dim = self.hidden_dim;
        PtxKernel::new("layer_norm_backward")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "gamma_ptr")
            .param(PtxType::U64, "mean_ptr")
            .param(PtxType::U64, "rstd_ptr")
            .param(PtxType::U64, "grad_output_ptr")
            .param(PtxType::U64, "grad_input_ptr")
            .param(PtxType::U32, "num_rows")
            .param(PtxType::U32, "hidden_dim")
            .build(move |ctx| {
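                // Thread mapping: one warp per row, one lane per element of the row.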
                let tid = ctx.special_reg(PtxReg::TidX);
                let ctaid = ctx.special_reg(PtxReg::CtaIdX);
                let ntid = ctx.special_reg(PtxReg::NtidX);
                let global_tid = ctx.mad_lo_u32(ctaid, ntid, tid);
                let lane = ctx.and_u32_imm(global_tid, 31);
                let warp_id = ctx.shr_u32_imm(global_tid, 5);
                let num_rows_param = ctx.load_param_u32("num_rows");
                let hidden_dim_param = ctx.load_param_u32("hidden_dim");
                let input_ptr = ctx.load_param_u64("input_ptr");
                let gamma_ptr = ctx.load_param_u64("gamma_ptr");
                let mean_ptr = ctx.load_param_u64("mean_ptr");
                let rstd_ptr = ctx.load_param_u64("rstd_ptr");
                let grad_output_ptr = ctx.load_param_u64("grad_output_ptr");
                let grad_input_ptr = ctx.load_param_u64("grad_input_ptr");
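                // Warps whose row index is past num_rows exit early. warp_id is
                // uniform within a warp (assuming block sizes that are a multiple
                // of 32), so the warp exits as a unit and the shuffles below stay
                // converged.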
                let valid_row = ctx.setp_lt_u32(warp_id, num_rows_param);
                ctx.branch_if_not(valid_row, "exit");
                let valid_lane = ctx.setp_lt_u32(lane, hidden_dim_param);
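                // Byte addressing: elements are 4-byte f32 values, rows are stored
                // contiguously, and mean/rstd hold one scalar per row.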
                let row_elem_offset = ctx.mul_lo_u32(warp_id, hidden_dim_param);
                let row_byte_offset = ctx.mul_wide_u32(row_elem_offset, 4);
                let input_row_base = ctx.add_u64(input_ptr, row_byte_offset);
                let grad_out_row_base = ctx.add_u64(grad_output_ptr, row_byte_offset);
                let grad_in_row_base = ctx.add_u64(grad_input_ptr, row_byte_offset);
                let lane_offset = ctx.mul_wide_u32(lane, 4);
                let input_addr = ctx.add_u64(input_row_base, lane_offset);
                let gamma_addr = ctx.add_u64(gamma_ptr, lane_offset);
                let grad_out_addr = ctx.add_u64(grad_out_row_base, lane_offset);
                let grad_in_addr = ctx.add_u64(grad_in_row_base, lane_offset);
                let row_scalar_offset = ctx.mul_wide_u32(warp_id, 4);
                let mean_addr = ctx.add_u64(mean_ptr, row_scalar_offset);
                let rstd_addr = ctx.add_u64(rstd_ptr, row_scalar_offset);
                let mean = ctx.ld_global_f32(mean_addr);
                let rstd = ctx.ld_global_f32(rstd_addr);
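                // Lanes beyond hidden_dim load 0.0 so they add nothing to the
                // warp reductions.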
                let x_i = ctx.ld_global_f32_predicated(input_addr, valid_lane, 0.0);
                let gamma_i = ctx.ld_global_f32_predicated(gamma_addr, valid_lane, 0.0);
                let grad_y_i = ctx.ld_global_f32_predicated(grad_out_addr, valid_lane, 0.0);
                let x_centered = ctx.sub_f32(x_i, mean);
                let x_norm = ctx.mul_f32(x_centered, rstd);
                let grad_y_gamma = ctx.mul_f32(grad_y_i, gamma_i);
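                // First warp reduction: sum1 = sum_j(dy_j * gamma_j) over the row.
                // The shuffle offsets are unrolled at build time based on hidden_dim.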
                let mut sum1 = grad_y_gamma;
                let warp_mask = 0xFFFF_FFFFu32;
                for offset in [16u32, 8, 4, 2, 1] {
                    if offset < hidden_dim {
                        let shuffled = ctx.shfl_down_f32(sum1, offset, warp_mask);
                        sum1 = ctx.add_f32(sum1, shuffled);
                    }
                }
                let total_sum1 = ctx.shfl_idx_f32(sum1, 0, warp_mask);
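                // Second warp reduction: sum2 = sum_j(x_norm_j * dy_j * gamma_j).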
                let x_norm_grad_gamma = ctx.mul_f32(x_norm, grad_y_gamma);
                let mut sum2 = x_norm_grad_gamma;
                for offset in [16u32, 8, 4, 2, 1] {
                    if offset < hidden_dim {
                        let shuffled = ctx.shfl_down_f32(sum2, offset, warp_mask);
                        sum2 = ctx.add_f32(sum2, shuffled);
                    }
                }
                let total_sum2 = ctx.shfl_idx_f32(sum2, 0, warp_mask);
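                // Combine into the input gradient:
                //   dL/dx_i = rstd * (dy_i*gamma_i - mean_j(dy_j*gamma_j)
                //                     - x_norm_i * mean_j(x_norm_j*dy_j*gamma_j))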
                let hidden_dim_f32 = ctx.cvt_f32_u32(hidden_dim_param);
                let mean_grad_gamma = ctx.div_f32(total_sum1, hidden_dim_f32);
                let mean_x_norm_grad_gamma = ctx.div_f32(total_sum2, hidden_dim_f32);
                let correction1 = mean_grad_gamma;
                let correction2 = ctx.mul_f32(x_norm, mean_x_norm_grad_gamma);
                let total_correction = ctx.add_f32(correction1, correction2);
                // The correction is subtracted from dy_i * gamma_i (not from dy_i
                // alone) and the result is scaled by rstd, matching the formula above.
                let adjusted_grad = ctx.sub_f32(grad_y_gamma, total_correction);
                let grad_x_i = ctx.mul_f32(rstd, adjusted_grad);
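                // Only lanes that own a real element store a gradient; the rest
                // branch straight to the exit label.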
                ctx.branch_if_not(valid_lane, "exit");
                ctx.st_global_f32(grad_in_addr, grad_x_i);
                ctx.label("exit");
                ctx.ret();
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_layer_norm_backward_name() {
        let kernel = LayerNormBackwardKernel::new(64, 32);
        assert_eq!(kernel.name(), "layer_norm_backward");
    }
    #[test]
    fn test_layer_norm_backward_ptx_generation() {
        let kernel = LayerNormBackwardKernel::new(64, 32);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".entry layer_norm_backward"));
        assert!(ptx.contains(".param .u64 input_ptr"));
        assert!(ptx.contains(".param .u64 gamma_ptr"));
        assert!(ptx.contains(".param .u64 mean_ptr"));
        assert!(ptx.contains(".param .u64 rstd_ptr"));
        assert!(ptx.contains(".param .u64 grad_output_ptr"));
        assert!(ptx.contains(".param .u64 grad_input_ptr"));
        assert!(ptx.contains("shfl.sync.down"));
        assert!(ptx.contains("sub.f32"));
    }
    #[test]
    fn test_layer_norm_backward_small_hidden() {
        let kernel = LayerNormBackwardKernel::new(128, 16);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".entry layer_norm_backward"));
        assert!(ptx.contains("shfl.sync"));
    }
    #[test]
    fn test_layer_norm_backward_barrier_safety() {
        let kernel = LayerNormBackwardKernel::new(64, 32);
        let result = kernel.analyze_barrier_safety();
        assert!(
            result.is_safe,
            "LayerNorm backward should be barrier-safe: {:?}",
            result.violations
        );
    }
    #[test]
    #[should_panic(expected = "hidden_dim must be ≤ 32")]
    fn test_layer_norm_backward_hidden_dim_limit() {
        let _ = LayerNormBackwardKernel::new(64, 64);
    }
}