#![allow(clippy::similar_names)]
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Descriptor for the element-wise ReLU kernel (`y[i] = max(x[i], 0.0)` on `f32` data).
///
/// NOTE(review): `build_ptx` does not read `n`; the generated PTX takes its element
/// bound from the runtime `n` launch parameter. Presumably callers use this field for
/// launch sizing — confirm before relying on it.
#[derive(Debug, Clone)]
pub struct ReluKernel {
    /// Element count this descriptor was created for.
    pub n: u32,
}

impl ReluKernel {
    /// Returns a descriptor recording `n` elements. Usable in `const` contexts.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        ReluKernel { n }
    }
}
impl Kernel for ReluKernel {
    fn name(&self) -> &str {
        "relu"
    }

    /// Builds the `relu` entry point.
    ///
    /// Parameters, in declaration order: `input_ptr: u64`, `output_ptr: u64`, `n: u32`.
    /// Each thread handles the single element at index `ctaid.x * ntid.x + tid.x`.
    /// Threads whose index is not below `n` branch straight to `exit`.
    ///
    /// NOTE(review): per the PTX ISA, `max.f32` returns the non-NaN operand, so a NaN
    /// input presumably produces `0.0` rather than propagating — confirm this is intended.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("relu")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::U32, "n")
            .build(|asm| {
                // Flat thread index: ctaid.x * ntid.x + tid.x.
                let lane = asm.special_reg(PtxReg::TidX);
                let block = asm.special_reg(PtxReg::CtaIdX);
                let block_dim = asm.special_reg(PtxReg::NtidX);
                let idx = asm.mad_lo_u32(block, block_dim, lane);

                let len = asm.load_param_u32("n");
                let src_base = asm.load_param_u64("input_ptr");
                let dst_base = asm.load_param_u64("output_ptr");

                // Only indices strictly below `n` touch memory.
                let active = asm.setp_lt_u32(idx, len);
                asm.branch_if_not(active, "exit");

                // Byte offset = idx * 4 (size of f32), widened to 64 bits.
                let elem_bytes = asm.mov_u32_imm(4);
                let byte_off = asm.mul_wide_u32_reg(idx, elem_bytes);
                let src = asm.add_u64(src_base, byte_off);
                let dst = asm.add_u64(dst_base, byte_off);

                let value = asm.ld_global_f32(src);
                let floor = asm.mov_f32_imm(0.0);
                let clamped = asm.max_f32(value, floor);
                asm.st_global_f32(dst, clamped);

                asm.label("exit");
                asm.ret();
            })
    }
}
/// Descriptor for the element-wise SiLU kernel (`y[i] = x[i] * sigmoid(x[i])` on `f32` data).
///
/// NOTE(review): `build_ptx` does not read `n`; the generated PTX takes its element
/// bound from the runtime `n` launch parameter. Presumably callers use this field for
/// launch sizing — confirm before relying on it.
#[derive(Debug, Clone)]
pub struct SiluKernel {
    /// Element count this descriptor was created for.
    pub n: u32,
}

impl SiluKernel {
    /// Returns a descriptor recording `n` elements. Usable in `const` contexts.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        SiluKernel { n }
    }
}
impl Kernel for SiluKernel {
    fn name(&self) -> &str {
        "silu"
    }

    /// Builds the `silu` entry point: `y[i] = x[i] * sigmoid(x[i]) = x[i] / (1 + e^(-x[i]))`.
    ///
    /// Parameters, in declaration order: `input_ptr: u64`, `output_ptr: u64`, `n: u32`.
    /// Each thread handles the single element at index `ctaid.x * ntid.x + tid.x`.
    /// Indices not below `n` branch straight to `exit`.
    ///
    /// `e^(-x)` is evaluated as `2^(-x * log2(e))` with the approximate `ex2` instruction.
    /// The output uses one division, `x / (1 + e^(-x))`, instead of
    /// `x * (1 / (1 + e^(-x)))`. That gives one rounding instead of two and one fewer
    /// instruction. It also avoids a reciprocal that drops into the subnormal range once
    /// `1 + e^(-x)` exceeds `2^126` (very negative `x`).
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("silu")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::U32, "n")
            .build(|ctx| {
                // Flat thread index: ctaid.x * ntid.x + tid.x.
                let tid = ctx.special_reg(PtxReg::TidX);
                let ctaid = ctx.special_reg(PtxReg::CtaIdX);
                let ntid = ctx.special_reg(PtxReg::NtidX);
                let gid = ctx.mad_lo_u32(ctaid, ntid, tid);
                let n = ctx.load_param_u32("n");
                let input_ptr = ctx.load_param_u64("input_ptr");
                let output_ptr = ctx.load_param_u64("output_ptr");
                let in_bounds = ctx.setp_lt_u32(gid, n);
                ctx.branch_if_not(in_bounds, "exit");

                // Byte offset = gid * 4 (size of f32), widened to 64 bits.
                let four = ctx.mov_u32_imm(4);
                let offset = ctx.mul_wide_u32_reg(gid, four);
                let in_addr = ctx.add_u64(input_ptr, offset);
                let out_addr = ctx.add_u64(output_ptr, offset);
                let x = ctx.ld_global_f32(in_addr);

                // e^(-x) = 2^(-x * log2(e)).
                let zero = ctx.mov_f32_imm(0.0);
                let neg_x = ctx.sub_f32(zero, x);
                let log2_e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
                let scaled = ctx.mul_f32(neg_x, log2_e);
                let exp_neg_x = ctx.ex2_f32(scaled);

                // silu(x) = x / (1 + e^(-x)) — single rounding, no separate reciprocal.
                let one = ctx.mov_f32_imm(1.0);
                let denom = ctx.add_f32(one, exp_neg_x);
                let result = ctx.div_f32(x, denom);
                ctx.st_global_f32(out_addr, result);
                ctx.label("exit");
                ctx.ret();
            })
    }
}
/// Descriptor for the element-wise GELU kernel (tanh approximation, `f32` data).
///
/// NOTE(review): `build_ptx` does not read `n`; the generated PTX takes its element
/// bound from the runtime `n` launch parameter. Presumably callers use this field for
/// launch sizing — confirm before relying on it.
#[derive(Debug, Clone)]
pub struct GeluKernel {
    /// Element count this descriptor was created for.
    pub n: u32,
}

impl GeluKernel {
    /// Returns a descriptor recording `n` elements. Usable in `const` contexts.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        GeluKernel { n }
    }
}
impl Kernel for GeluKernel {
    fn name(&self) -> &str {
        "gelu"
    }

    /// Builds the `gelu` entry point using the tanh approximation
    /// `gelu(x) = 0.5 * x * (1 + tanh(z))`, where `z = sqrt(2/pi) * (x + 0.044715 * x^3)`.
    ///
    /// Parameters, in declaration order: `input_ptr: u64`, `output_ptr: u64`, `n: u32`.
    /// Each thread handles the single element at index `ctaid.x * ntid.x + tid.x`.
    /// Indices not below `n` branch straight to `exit`.
    ///
    /// Since `1 + tanh(z) = 2 * sigmoid(2z)`, the formula equals
    /// `x * sigmoid(2z) = x / (1 + e^(-2z))`, and that is what is evaluated.
    ///
    /// Forming `tanh` as `2 * sigmoid(2z) - 1` and then adding 1 back cancels
    /// catastrophically for negative `z`. The intermediate rounds toward `-1` and loses
    /// the small `sigmoid` term. Evaluating the identity directly avoids that and is
    /// also four instructions shorter.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("gelu")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::U32, "n")
            .build(|ctx| {
                // Flat thread index: ctaid.x * ntid.x + tid.x.
                let tid = ctx.special_reg(PtxReg::TidX);
                let ctaid = ctx.special_reg(PtxReg::CtaIdX);
                let ntid = ctx.special_reg(PtxReg::NtidX);
                let gid = ctx.mad_lo_u32(ctaid, ntid, tid);
                let n = ctx.load_param_u32("n");
                let input_ptr = ctx.load_param_u64("input_ptr");
                let output_ptr = ctx.load_param_u64("output_ptr");
                let in_bounds = ctx.setp_lt_u32(gid, n);
                ctx.branch_if_not(in_bounds, "exit");

                // Byte offset = gid * 4 (size of f32), widened to 64 bits.
                let four = ctx.mov_u32_imm(4);
                let offset = ctx.mul_wide_u32_reg(gid, four);
                let in_addr = ctx.add_u64(input_ptr, offset);
                let out_addr = ctx.add_u64(output_ptr, offset);
                let x = ctx.ld_global_f32(in_addr);

                // z = sqrt(2/pi) * (x + 0.044715 * x^3)
                let sqrt_2_over_pi = ctx.mov_f32_imm(0.797_884_6);
                let coeff = ctx.mov_f32_imm(0.044_715);
                let x2 = ctx.mul_f32(x, x);
                let x3 = ctx.mul_f32(x2, x);
                let cx3 = ctx.mul_f32(coeff, x3);
                let inner = ctx.add_f32(x, cx3);
                let z = ctx.mul_f32(sqrt_2_over_pi, inner);

                // e^(-2z) = 2^(-z * 2*log2(e)). Doubling log2(e) is exact, so the
                // exponent argument matches the previous (-2z) * log2(e) bit for bit.
                let zero = ctx.mov_f32_imm(0.0);
                let neg_z = ctx.sub_f32(zero, z);
                let two_log2_e = ctx.mov_f32_imm(2.0 * std::f32::consts::LOG2_E);
                let exp_arg = ctx.mul_f32(neg_z, two_log2_e);
                let exp_neg_2z = ctx.ex2_f32(exp_arg);

                // gelu(x) = x * sigmoid(2z) = x / (1 + e^(-2z))
                let one = ctx.mov_f32_imm(1.0);
                let denom = ctx.add_f32(one, exp_neg_2z);
                let result = ctx.div_f32(x, denom);
                ctx.st_global_f32(out_addr, result);
                ctx.label("exit");
                ctx.ret();
            })
    }
}
/// Descriptor for the element-wise product kernel (`out[i] = a[i] * b[i]` on `f32` data).
///
/// NOTE(review): `build_ptx` does not read `n`; the generated PTX takes its element
/// bound from the runtime `n` launch parameter. Presumably callers use this field for
/// launch sizing — confirm before relying on it.
#[derive(Debug, Clone)]
pub struct ElementwiseMulKernel {
    /// Element count this descriptor was created for.
    pub n: u32,
}

impl ElementwiseMulKernel {
    /// Returns a descriptor recording `n` elements. Usable in `const` contexts.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        ElementwiseMulKernel { n }
    }
}
impl Kernel for ElementwiseMulKernel {
    fn name(&self) -> &str {
        "elementwise_mul"
    }

    /// Builds the `elementwise_mul` entry point (`out[i] = a[i] * b[i]`).
    ///
    /// Parameters, in declaration order: `input1_ptr: u64`, `input2_ptr: u64`,
    /// `output_ptr: u64`, `n: u32`. Each thread handles the single element at index
    /// `ctaid.x * ntid.x + tid.x`. Threads whose index is not below `n` branch straight
    /// to `exit`. All three buffers are addressed with the same byte offset.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("elementwise_mul")
            .param(PtxType::U64, "input1_ptr")
            .param(PtxType::U64, "input2_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::U32, "n")
            .build(|asm| {
                // Flat thread index: ctaid.x * ntid.x + tid.x.
                let lane = asm.special_reg(PtxReg::TidX);
                let block = asm.special_reg(PtxReg::CtaIdX);
                let block_dim = asm.special_reg(PtxReg::NtidX);
                let idx = asm.mad_lo_u32(block, block_dim, lane);

                let len = asm.load_param_u32("n");
                let lhs_base = asm.load_param_u64("input1_ptr");
                let rhs_base = asm.load_param_u64("input2_ptr");
                let dst_base = asm.load_param_u64("output_ptr");

                // Only indices strictly below `n` touch memory.
                let active = asm.setp_lt_u32(idx, len);
                asm.branch_if_not(active, "exit");

                // Byte offset = idx * 4 (size of f32), widened to 64 bits.
                let elem_bytes = asm.mov_u32_imm(4);
                let byte_off = asm.mul_wide_u32_reg(idx, elem_bytes);
                let lhs_addr = asm.add_u64(lhs_base, byte_off);
                let rhs_addr = asm.add_u64(rhs_base, byte_off);
                let dst_addr = asm.add_u64(dst_base, byte_off);

                let lhs = asm.ld_global_f32(lhs_addr);
                let rhs = asm.ld_global_f32(rhs_addr);
                let product = asm.mul_f32(lhs, rhs);
                asm.st_global_f32(dst_addr, product);

                asm.label("exit");
                asm.ret();
            })
    }
}
/// Descriptor for the scaling kernel (`y[i] = x[i] * scale` on `f32` data).
///
/// The scale factor is not stored here; the generated PTX receives it as the `scale`
/// launch parameter.
///
/// NOTE(review): `build_ptx` does not read `n` either; the PTX takes its element bound
/// from the runtime `n` launch parameter. Presumably callers use this field for launch
/// sizing — confirm before relying on it.
#[derive(Debug, Clone)]
pub struct ScaleKernel {
    /// Element count this descriptor was created for.
    pub n: u32,
}

impl ScaleKernel {
    /// Returns a descriptor recording `n` elements. Usable in `const` contexts.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        ScaleKernel { n }
    }
}
impl Kernel for ScaleKernel {
    fn name(&self) -> &str {
        "scale"
    }

    /// Builds the `scale` entry point (`y[i] = x[i] * scale`).
    ///
    /// Parameters, in declaration order: `input_ptr: u64`, `output_ptr: u64`,
    /// `scale: f32`, `n: u32`. Each thread handles the single element at index
    /// `ctaid.x * ntid.x + tid.x`. Threads whose index is not below `n` branch straight
    /// to `exit`.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("scale")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::F32, "scale")
            .param(PtxType::U32, "n")
            .build(|asm| {
                // Flat thread index: ctaid.x * ntid.x + tid.x.
                let lane = asm.special_reg(PtxReg::TidX);
                let block = asm.special_reg(PtxReg::CtaIdX);
                let block_dim = asm.special_reg(PtxReg::NtidX);
                let idx = asm.mad_lo_u32(block, block_dim, lane);

                let len = asm.load_param_u32("n");
                let src_base = asm.load_param_u64("input_ptr");
                let dst_base = asm.load_param_u64("output_ptr");
                let factor = asm.load_param_f32("scale");

                // Only indices strictly below `n` touch memory.
                let active = asm.setp_lt_u32(idx, len);
                asm.branch_if_not(active, "exit");

                // Byte offset = idx * 4 (size of f32), widened to 64 bits.
                let elem_bytes = asm.mov_u32_imm(4);
                let byte_off = asm.mul_wide_u32_reg(idx, elem_bytes);
                let src = asm.add_u64(src_base, byte_off);
                let dst = asm.add_u64(dst_base, byte_off);

                let value = asm.ld_global_f32(src);
                let scaled = asm.mul_f32(value, factor);
                asm.st_global_f32(dst, scaled);

                asm.label("exit");
                asm.ret();
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- ReluKernel --------------------------------------------------------

    #[test]
    fn test_relu_kernel_name() {
        let kernel = ReluKernel::new(2048);
        assert_eq!(kernel.name(), "relu");
    }

    #[test]
    fn test_relu_ptx_generation() {
        let kernel = ReluKernel::new(2048);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".entry relu"));
        assert!(ptx.contains("max.f32"));
    }

    #[test]
    fn test_relu_kernel_debug() {
        let kernel = ReluKernel::new(1024);
        let debug_str = format!("{kernel:?}");
        assert!(debug_str.contains("ReluKernel"));
        assert!(debug_str.contains("1024"));
    }

    #[test]
    fn test_relu_kernel_clone() {
        let kernel = ReluKernel::new(512);
        let cloned = kernel.clone();
        assert_eq!(cloned.n, 512);
    }

    #[test]
    fn test_relu_kernel_ptx_contains_bounds_check() {
        let kernel = ReluKernel::new(100);
        let ptx = kernel.emit_ptx();
        // `gid < n` comparison followed by a negated-predicate branch to `exit`.
        assert!(ptx.contains("setp.lt.u32"));
        assert!(ptx.contains("@!"));
    }

    #[test]
    fn test_relu_kernel_edge_case_n_zero() {
        let kernel = ReluKernel::new(0);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".entry relu"));
    }

    #[test]
    fn test_relu_kernel_edge_case_n_one() {
        let kernel = ReluKernel::new(1);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".entry relu"));
        assert!(ptx.contains("max.f32"));
    }

    #[test]
    fn test_relu_kernel_large_n() {
        let kernel = ReluKernel::new(u32::MAX);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".entry relu"));
    }

    // --- SiluKernel --------------------------------------------------------

    #[test]
    fn test_silu_kernel_name() {
        let kernel = SiluKernel::new(2048);
        assert_eq!(kernel.name(), "silu");
    }

    #[test]
    fn test_silu_ptx_generation() {
        let kernel = SiluKernel::new(2048);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".entry silu"));
        assert!(ptx.contains("ex2.approx.f32"));
        assert!(ptx.contains("div.rn.f32"));
        assert!(ptx.contains("mul.f32"));
    }

    #[test]
    fn test_silu_kernel_debug() {
        let kernel = SiluKernel::new(4096);
        let debug_str = format!("{kernel:?}");
        assert!(debug_str.contains("SiluKernel"));
        assert!(debug_str.contains("4096"));
    }

    #[test]
    fn test_silu_kernel_clone() {
        let kernel = SiluKernel::new(256);
        let cloned = kernel.clone();
        assert_eq!(cloned.n, 256);
    }

    /// SiLU evaluates `e^(-x)` as `2^(-x * log2(e))` through `ex2.approx.f32`.
    /// The textual form of the log2(e) immediate depends on the builder, so only the
    /// base-2 exponential instruction is checked here.
    #[test]
    fn test_silu_kernel_uses_base2_exponential() {
        let kernel = SiluKernel::new(1000);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains("ex2.approx.f32"));
    }

    #[test]
    fn test_silu_kernel_ptx_structure() {
        let kernel = SiluKernel::new(512);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".param .u64 input_ptr"));
        assert!(ptx.contains(".param .u64 output_ptr"));
        assert!(ptx.contains(".param .u32 n"));
        assert!(ptx.contains("exit:"));
    }

    // --- GeluKernel --------------------------------------------------------

    #[test]
    fn test_gelu_kernel_name() {
        let kernel = GeluKernel::new(2048);
        assert_eq!(kernel.name(), "gelu");
    }

    #[test]
    fn test_gelu_ptx_generation() {
        let kernel = GeluKernel::new(2048);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".entry gelu"));
        assert!(ptx.contains("ex2.approx.f32"));
        assert!(ptx.contains("mul.f32"));
    }

    #[test]
    fn test_gelu_kernel_debug() {
        let kernel = GeluKernel::new(8192);
        let debug_str = format!("{kernel:?}");
        assert!(debug_str.contains("GeluKernel"));
        assert!(debug_str.contains("8192"));
    }

    #[test]
    fn test_gelu_kernel_clone() {
        let kernel = GeluKernel::new(128);
        let cloned = kernel.clone();
        assert_eq!(cloned.n, 128);
    }

    /// The tanh approximation is evaluated through an exponential: expect a division
    /// (`div.rn.f32`) and a subtraction (`sub.f32`) in the emitted code.
    #[test]
    fn test_gelu_kernel_ptx_contains_tanh_approximation() {
        let kernel = GeluKernel::new(1000);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains("div.rn.f32"));
        assert!(ptx.contains("sub.f32"));
    }

    #[test]
    fn test_gelu_kernel_edge_case_n_zero() {
        let kernel = GeluKernel::new(0);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".entry gelu"));
    }

    // --- ElementwiseMulKernel ----------------------------------------------

    #[test]
    fn test_elementwise_mul_kernel_name() {
        let kernel = ElementwiseMulKernel::new(2048);
        assert_eq!(kernel.name(), "elementwise_mul");
    }

    #[test]
    fn test_elementwise_mul_ptx_generation() {
        let kernel = ElementwiseMulKernel::new(2048);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".entry elementwise_mul"));
        assert!(ptx.contains(".param .u64 input1_ptr"));
        assert!(ptx.contains(".param .u64 input2_ptr"));
        assert!(ptx.contains(".param .u64 output_ptr"));
        assert!(ptx.contains(".param .u32 n"));
        assert!(ptx.contains("mul.f32"));
    }

    #[test]
    fn test_elementwise_mul_kernel_debug() {
        let kernel = ElementwiseMulKernel::new(1024);
        let debug_str = format!("{kernel:?}");
        assert!(debug_str.contains("ElementwiseMulKernel"));
        assert!(debug_str.contains("1024"));
    }

    #[test]
    fn test_elementwise_mul_kernel_clone() {
        let kernel = ElementwiseMulKernel::new(64);
        let cloned = kernel.clone();
        assert_eq!(cloned.n, 64);
    }

    #[test]
    fn test_elementwise_mul_kernel_ptx_contains_bounds_check() {
        let kernel = ElementwiseMulKernel::new(500);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains("setp.lt.u32"));
    }

    #[test]
    fn test_elementwise_mul_kernel_ptx_loads_two_inputs() {
        let kernel = ElementwiseMulKernel::new(100);
        let ptx = kernel.emit_ptx();
        let load_count = ptx.matches("ld.global.f32").count();
        assert_eq!(load_count, 2, "Should have exactly 2 global loads");
    }

    #[test]
    fn test_elementwise_mul_kernel_edge_case_n_one() {
        let kernel = ElementwiseMulKernel::new(1);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".entry elementwise_mul"));
        assert!(ptx.contains("mul.f32"));
    }

    #[test]
    fn test_elementwise_mul_kernel_large_n() {
        let kernel = ElementwiseMulKernel::new(1_000_000);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".entry elementwise_mul"));
    }

    // --- ScaleKernel -------------------------------------------------------

    #[test]
    fn test_scale_kernel_name() {
        let kernel = ScaleKernel::new(2048);
        assert_eq!(kernel.name(), "scale");
    }

    #[test]
    fn test_scale_ptx_generation() {
        let kernel = ScaleKernel::new(2048);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".entry scale"));
        assert!(ptx.contains(".param .f32 scale"));
        assert!(ptx.contains("mul.f32"));
    }

    #[test]
    fn test_scale_kernel_debug() {
        let kernel = ScaleKernel::new(512);
        let debug_str = format!("{kernel:?}");
        assert!(debug_str.contains("ScaleKernel"));
        assert!(debug_str.contains("512"));
    }

    #[test]
    fn test_scale_kernel_clone() {
        let kernel = ScaleKernel::new(32);
        let cloned = kernel.clone();
        assert_eq!(cloned.n, 32);
    }

    #[test]
    fn test_scale_kernel_ptx_structure() {
        let kernel = ScaleKernel::new(256);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".param .u64 input_ptr"));
        assert!(ptx.contains(".param .u64 output_ptr"));
        assert!(ptx.contains(".param .f32 scale"));
        assert!(ptx.contains(".param .u32 n"));
    }

    #[test]
    fn test_scale_kernel_edge_case_n_zero() {
        let kernel = ScaleKernel::new(0);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".entry scale"));
    }

    #[test]
    fn test_scale_kernel_ptx_uses_f32_scale_param() {
        let kernel = ScaleKernel::new(100);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".param .f32 scale"));
        assert!(ptx.contains("mul.f32"));
    }

    // --- Cross-kernel checks -----------------------------------------------

    /// Every single-input kernel reads exactly one `f32` from global memory per thread.
    #[test]
    fn test_unary_kernels_load_exactly_one_input() {
        for ptx in [
            ReluKernel::new(64).emit_ptx(),
            SiluKernel::new(64).emit_ptx(),
            GeluKernel::new(64).emit_ptx(),
            ScaleKernel::new(64).emit_ptx(),
        ] {
            assert_eq!(ptx.matches("ld.global.f32").count(), 1);
        }
    }
}