#![allow(clippy::similar_names)]
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Backward pass of ReLU: the upstream gradient is passed through where the forward
/// input was strictly positive and replaced by `0.0` everywhere else.
///
/// The PTX emitted by this type's `Kernel` impl bounds-checks against its runtime `n`
/// kernel parameter; the `n` stored here is not read while building that PTX.
#[derive(Debug, Clone)]
pub struct ReluBackwardKernel {
    /// Element count (presumably used by the caller to size the launch — not read when
    /// building PTX).
    pub n: u32,
}

impl ReluBackwardKernel {
    /// Creates a kernel description for `n` elements.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        Self { n }
    }
}
impl Kernel for ReluBackwardKernel {
    fn name(&self) -> &str {
        "relu_backward"
    }

    /// Emits the `relu_backward` entry point.
    ///
    /// Parameters:
    /// - `input_ptr`, `grad_output_ptr`, `grad_input_ptr`: `u64` base addresses of `f32` arrays.
    /// - `n`: `u32` element count.
    ///
    /// Each thread handles the single element at `ctaid.x * ntid.x + tid.x`. Threads whose
    /// index is `>= n` branch straight to `exit`. There is no loop, so the launch must
    /// supply at least `n` threads along x.
    ///
    /// Per element: `grad_input[i] = if input[i] > 0.0 { grad_output[i] } else { 0.0 }`.
    /// `setp.gt` is an ordered comparison, so a NaN input also yields `0.0`.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("relu_backward")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "grad_output_ptr")
            .param(PtxType::U64, "grad_input_ptr")
            .param(PtxType::U32, "n")
            .build(|ctx| {
                // Flat thread index along x: ctaid.x * ntid.x + tid.x.
                let idx = {
                    let thread = ctx.special_reg(PtxReg::TidX);
                    let block = ctx.special_reg(PtxReg::CtaIdX);
                    let block_dim = ctx.special_reg(PtxReg::NtidX);
                    ctx.mad_lo_u32(block, block_dim, thread)
                };
                let len = ctx.load_param_u32("n");
                let src = ctx.load_param_u64("input_ptr");
                let upstream = ctx.load_param_u64("grad_output_ptr");
                let dst = ctx.load_param_u64("grad_input_ptr");

                // Out-of-range threads do no memory access at all.
                let active = ctx.setp_lt_u32(idx, len);
                ctx.branch_if_not(active, "exit");

                // Byte offset of this thread's 4-byte f32 element, widened to 64 bits.
                let byte_offset = {
                    let elem_size = ctx.mov_u32_imm(4);
                    ctx.mul_wide_u32_reg(idx, elem_size)
                };
                let src_addr = ctx.add_u64(src, byte_offset);
                let upstream_addr = ctx.add_u64(upstream, byte_offset);
                let dst_addr = ctx.add_u64(dst, byte_offset);

                let activation = ctx.ld_global_f32(src_addr);
                let upstream_grad = ctx.ld_global_f32(upstream_addr);

                // Pass the upstream gradient through only where the forward input was positive.
                let f_zero = ctx.mov_f32_imm(0.0);
                let passes = ctx.setp_gt_f32(activation, f_zero);
                let masked = ctx.selp_f32(passes, upstream_grad, f_zero);
                ctx.st_global_f32(dst_addr, masked);

                ctx.label("exit");
                ctx.ret();
            })
    }
}
/// Backward pass of the tanh-approximated GELU activation.
///
/// The PTX emitted by this type's `Kernel` impl bounds-checks against its runtime `n`
/// kernel parameter; the `n` stored here is not read while building that PTX.
#[derive(Debug, Clone)]
pub struct GeluBackwardKernel {
    /// Element count (presumably used by the caller to size the launch — not read when
    /// building PTX).
    pub n: u32,
}

impl GeluBackwardKernel {
    /// Creates a kernel description for `n` elements.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        Self { n }
    }
}
impl Kernel for GeluBackwardKernel {
    fn name(&self) -> &str {
        "gelu_backward"
    }

    /// Emits the `gelu_backward` entry point for the tanh approximation of GELU:
    ///
    /// `gelu(x) = 0.5 * x * (1 + tanh(u))`, where `u = sqrt(2/pi) * (x + 0.044715 * x^3)`.
    ///
    /// Parameters:
    /// - `input_ptr`, `grad_output_ptr`, `grad_input_ptr`: `u64` base addresses of `f32` arrays.
    /// - `n`: `u32` element count.
    ///
    /// Each thread handles the single element at `ctaid.x * ntid.x + tid.x`. Threads whose
    /// index is `>= n` branch straight to `exit`. There is no loop, so the launch must
    /// supply at least `n` threads along x.
    ///
    /// The derivative is evaluated through `s = sigmoid(2u)` using two exact identities:
    /// - `0.5 * (1 + tanh(u)) = s`
    /// - `sech^2(u) = 4 * s * (1 - s)`
    ///
    /// Together these give `gelu'(x) = s + 2 * x * s * (1 - s) * du/dx`, where
    /// `du/dx = sqrt(2/pi) * (1 + 0.134145 * x^2)`.
    ///
    /// Working from `s` avoids forming `1 + tanh(u)` and `1 - tanh(u)^2`. Those cancel
    /// to exactly zero once `tanh(u)` rounds to -1 in f32 (roughly x <= -5.3), which
    /// would flush the small negative gradient on that tail to 0. `s` itself is
    /// `1 / (1 + e^(-2u))`, which stays accurate there until `e^(-2u)` overflows f32.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("gelu_backward")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "grad_output_ptr")
            .param(PtxType::U64, "grad_input_ptr")
            .param(PtxType::U32, "n")
            .build(|ctx| {
                let tid = ctx.special_reg(PtxReg::TidX);
                let ctaid = ctx.special_reg(PtxReg::CtaIdX);
                let ntid = ctx.special_reg(PtxReg::NtidX);
                let gid = ctx.mad_lo_u32(ctaid, ntid, tid);
                let n = ctx.load_param_u32("n");
                let input_ptr = ctx.load_param_u64("input_ptr");
                let grad_output_ptr = ctx.load_param_u64("grad_output_ptr");
                let grad_input_ptr = ctx.load_param_u64("grad_input_ptr");

                // Out-of-range threads do no memory access at all.
                let in_bounds = ctx.setp_lt_u32(gid, n);
                ctx.branch_if_not(in_bounds, "exit");

                // Byte offset of this thread's 4-byte f32 element, widened to 64 bits.
                let four = ctx.mov_u32_imm(4);
                let offset = ctx.mul_wide_u32_reg(gid, four);
                let in_addr = ctx.add_u64(input_ptr, offset);
                let grad_out_addr = ctx.add_u64(grad_output_ptr, offset);
                let grad_in_addr = ctx.add_u64(grad_input_ptr, offset);
                let x = ctx.ld_global_f32(in_addr);
                let grad_out = ctx.ld_global_f32(grad_out_addr);

                // sqrt(2/pi), the cubic coefficient, and its derivative 3 * 0.044715.
                let sqrt_2_over_pi = ctx.mov_f32_imm(0.797_884_6);
                let c = ctx.mov_f32_imm(0.044_715);
                let c3 = ctx.mov_f32_imm(0.134_145);
                let one = ctx.mov_f32_imm(1.0);
                let two = ctx.mov_f32_imm(2.0);
                let zero = ctx.mov_f32_imm(0.0);
                let log2_e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);

                // u = sqrt(2/pi) * (x + 0.044715 * x^3)
                let x2 = ctx.mul_f32(x, x);
                let x3 = ctx.mul_f32(x2, x);
                let cx3 = ctx.mul_f32(c, x3);
                let inner = ctx.add_f32(x, cx3);
                let u = ctx.mul_f32(sqrt_2_over_pi, inner);

                // s = sigmoid(2u) = 1 / (1 + e^(-2u)); e^y is evaluated as 2^(y * log2 e).
                let two_u = ctx.mul_f32(two, u);
                let neg_two_u = ctx.sub_f32(zero, two_u);
                let scaled_exp = ctx.mul_f32(neg_two_u, log2_e);
                let exp_neg = ctx.ex2_f32(scaled_exp);
                let denom = ctx.add_f32(one, exp_neg);
                let s = ctx.div_f32(one, denom);

                // s * (1 - s) = sech^2(u) / 4. Accurate on the negative tail, where s is tiny.
                let one_minus_s = ctx.sub_f32(one, s);
                let s_one_minus_s = ctx.mul_f32(s, one_minus_s);

                // du/dx = sqrt(2/pi) * (1 + 0.134145 * x^2)
                let c3x2 = ctx.mul_f32(c3, x2);
                let du_inner = ctx.add_f32(one, c3x2);
                let du_dx = ctx.mul_f32(sqrt_2_over_pi, du_inner);

                // gelu'(x) = s + 0.5 * x * sech^2(u) * du/dx = s + 2 * x * s * (1 - s) * du/dx
                let x_du = ctx.mul_f32(x, du_dx);
                let two_x_du = ctx.mul_f32(two, x_du);
                let term2 = ctx.mul_f32(two_x_du, s_one_minus_s);
                let gelu_prime = ctx.add_f32(s, term2);

                let grad_in = ctx.mul_f32(grad_out, gelu_prime);
                ctx.st_global_f32(grad_in_addr, grad_in);

                ctx.label("exit");
                ctx.ret();
            })
    }
}
/// Backward pass of SiLU (`x * sigmoid(x)`).
///
/// The PTX emitted by this type's `Kernel` impl bounds-checks against its runtime `n`
/// kernel parameter; the `n` stored here is not read while building that PTX.
#[derive(Debug, Clone)]
pub struct SiluBackwardKernel {
    /// Element count (presumably used by the caller to size the launch — not read when
    /// building PTX).
    pub n: u32,
}

impl SiluBackwardKernel {
    /// Creates a kernel description for `n` elements.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        Self { n }
    }
}
impl Kernel for SiluBackwardKernel {
    fn name(&self) -> &str {
        "silu_backward"
    }

    /// Emits the `silu_backward` entry point.
    ///
    /// Parameters:
    /// - `input_ptr`, `grad_output_ptr`, `grad_input_ptr`: `u64` base addresses of `f32` arrays.
    /// - `n`: `u32` element count.
    ///
    /// Each thread handles the single element at `ctaid.x * ntid.x + tid.x`. Threads whose
    /// index is `>= n` branch straight to `exit`. There is no loop, so the launch must
    /// supply at least `n` threads along x.
    ///
    /// Per element, with `x = input[i]`:
    /// `grad_input[i] = grad_output[i] * sigmoid(x) * (1 + x - x * sigmoid(x))`,
    /// which is `d/dx [x * sigmoid(x)]`.
    ///
    /// `e^(-x)` is evaluated as `2^(-x * log2 e)`.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("silu_backward")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "grad_output_ptr")
            .param(PtxType::U64, "grad_input_ptr")
            .param(PtxType::U32, "n")
            .build(|ctx| {
                // Flat thread index along x: ctaid.x * ntid.x + tid.x.
                let idx = {
                    let thread = ctx.special_reg(PtxReg::TidX);
                    let block = ctx.special_reg(PtxReg::CtaIdX);
                    let block_dim = ctx.special_reg(PtxReg::NtidX);
                    ctx.mad_lo_u32(block, block_dim, thread)
                };
                let len = ctx.load_param_u32("n");
                let src = ctx.load_param_u64("input_ptr");
                let upstream = ctx.load_param_u64("grad_output_ptr");
                let dst = ctx.load_param_u64("grad_input_ptr");

                // Out-of-range threads do no memory access at all.
                let active = ctx.setp_lt_u32(idx, len);
                ctx.branch_if_not(active, "exit");

                // Byte offset of this thread's 4-byte f32 element, widened to 64 bits.
                let byte_offset = {
                    let elem_size = ctx.mov_u32_imm(4);
                    ctx.mul_wide_u32_reg(idx, elem_size)
                };
                let src_addr = ctx.add_u64(src, byte_offset);
                let upstream_addr = ctx.add_u64(upstream, byte_offset);
                let dst_addr = ctx.add_u64(dst, byte_offset);

                let xv = ctx.ld_global_f32(src_addr);
                let dy = ctx.ld_global_f32(upstream_addr);

                let f_one = ctx.mov_f32_imm(1.0);
                let f_zero = ctx.mov_f32_imm(0.0);
                let f_log2e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);

                // sigmoid(x) = 1 / (1 + e^(-x))
                let sig = {
                    let minus_x = ctx.sub_f32(f_zero, xv);
                    let exponent = ctx.mul_f32(minus_x, f_log2e);
                    let e_minus_x = ctx.ex2_f32(exponent);
                    let one_plus_e = ctx.add_f32(f_one, e_minus_x);
                    ctx.div_f32(f_one, one_plus_e)
                };

                // silu'(x) = sigmoid(x) * (1 + x - silu(x)), where silu(x) = x * sigmoid(x)
                let derivative = {
                    let silu = ctx.mul_f32(xv, sig);
                    let one_plus_x = ctx.add_f32(f_one, xv);
                    let bracket = ctx.sub_f32(one_plus_x, silu);
                    ctx.mul_f32(sig, bracket)
                };

                let dx = ctx.mul_f32(dy, derivative);
                ctx.st_global_f32(dst_addr, dx);

                ctx.label("exit");
                ctx.ret();
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the entry name and the three `u64` pointer parameters that every
    /// backward kernel in this module declares.
    fn assert_backward_signature(ptx: &str, entry: &str) {
        assert!(ptx.contains(&format!(".entry {entry}")));
        for param in ["input_ptr", "grad_output_ptr", "grad_input_ptr"] {
            assert!(ptx.contains(&format!(".param .u64 {param}")));
        }
    }

    /// Runs the barrier-safety analysis and fails, listing the violations, if it is unsafe.
    fn assert_barrier_safe<K: Kernel>(kernel: &K, label: &str) {
        let report = kernel.analyze_barrier_safety();
        assert!(
            report.is_safe,
            "{label} backward should be barrier-safe: {:?}",
            report.violations
        );
    }

    #[test]
    fn test_relu_backward_name() {
        let kernel = ReluBackwardKernel::new(2048);
        assert_eq!(kernel.name(), "relu_backward");
    }

    #[test]
    fn test_relu_backward_ptx_generation() {
        let kernel = ReluBackwardKernel::new(2048);
        let ptx = kernel.emit_ptx();
        assert_backward_signature(&ptx, "relu_backward");
        assert!(ptx.contains("setp.gt.f32"));
        assert!(ptx.contains("selp.f32"));
    }

    #[test]
    fn test_gelu_backward_name() {
        let kernel = GeluBackwardKernel::new(2048);
        assert_eq!(kernel.name(), "gelu_backward");
    }

    #[test]
    fn test_gelu_backward_ptx_generation() {
        let kernel = GeluBackwardKernel::new(2048);
        let ptx = kernel.emit_ptx();
        assert_backward_signature(&ptx, "gelu_backward");
        assert!(ptx.contains("ex2.approx.f32"));
        assert!(ptx.contains("mul.f32"));
    }

    #[test]
    fn test_silu_backward_name() {
        let kernel = SiluBackwardKernel::new(2048);
        assert_eq!(kernel.name(), "silu_backward");
    }

    #[test]
    fn test_silu_backward_ptx_generation() {
        let kernel = SiluBackwardKernel::new(2048);
        let ptx = kernel.emit_ptx();
        assert_backward_signature(&ptx, "silu_backward");
        assert!(ptx.contains("ex2.approx.f32"));
        assert!(ptx.contains("div.rn.f32"));
    }

    #[test]
    fn test_relu_backward_barrier_safety() {
        assert_barrier_safe(&ReluBackwardKernel::new(1024), "ReLU");
    }

    #[test]
    fn test_gelu_backward_barrier_safety() {
        assert_barrier_safe(&GeluBackwardKernel::new(1024), "GELU");
    }

    #[test]
    fn test_silu_backward_barrier_safety() {
        assert_barrier_safe(&SiluBackwardKernel::new(1024), "SiLU");
    }
}