use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxType};
/// Element-wise activation applied after the bias add in [`BiasActivationKernel`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Activation {
    /// Identity: the biased value is stored unchanged. This is the default.
    #[default]
    None,
    /// Rectified linear unit, `max(x, 0.0)`.
    ReLU,
    /// Gaussian error linear unit, emitted by the kernel as the sigmoid
    /// approximation `x * sigmoid(1.702 * x)` rather than the exact erf form.
    GELU,
}
/// Configuration for the fused "add bias, then activate" PTX kernel.
///
/// Construct with [`BiasActivationKernel::new`] and pick an activation with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct BiasActivationKernel {
    /// Element count supplied at construction.
    /// NOTE(review): `build_ptx` reads the element count from its `n` kernel
    /// parameter and does not use this field; confirm whether anything relies on it.
    n: u32,
    /// Length of the bias vector; element `i` is paired with `bias[i % bias_size]`.
    bias_size: u32,
    /// Activation applied to `value + bias` before it is stored.
    activation: Activation,
}
impl BiasActivationKernel {
    /// Creates a bias-add kernel with no activation ([`Activation::None`]).
    ///
    /// * `n` — element count, stored on the kernel.
    /// * `bias_size` — length of the bias vector. The kernel picks the bias element
    ///   for flat index `i` as `i % bias_size`, so this should be non-zero.
    #[must_use]
    pub fn new(n: u32, bias_size: u32) -> Self {
        Self {
            activation: Activation::None,
            bias_size,
            n,
        }
    }

    /// Switches the activation to ReLU.
    #[must_use]
    pub fn with_relu(self) -> Self {
        self.with_activation(Activation::ReLU)
    }

    /// Switches the activation to GELU (sigmoid approximation).
    #[must_use]
    pub fn with_gelu(self) -> Self {
        self.with_activation(Activation::GELU)
    }

    /// Replaces the activation with `activation`, keeping `n` and `bias_size`.
    #[must_use]
    pub fn with_activation(self, activation: Activation) -> Self {
        Self { activation, ..self }
    }
}
impl super::Kernel for BiasActivationKernel {
    fn name(&self) -> &str {
        "bias_activation"
    }

    /// Emits PTX computing, in place, `output[i] = act(output[i] + bias[i % bias_size])`.
    ///
    /// Kernel parameters, in declaration order:
    /// - `output` (`u64`): address of the `f32` buffer that is read and then overwritten.
    /// - `bias` (`u64`): address of the `f32` bias vector.
    /// - `n` (`u32`): element count; threads whose global index is `>= n` do nothing.
    ///
    /// Each thread handles one element at global index `ctaid.x * ntid.x + tid.x`.
    /// `self.bias_size` is baked into the generated code rather than passed as a
    /// kernel parameter, so the emitted kernel is specific to that bias length.
    ///
    /// # Panics
    ///
    /// Panics if `self.bias_size` is zero. It is the divisor of the `rem.u32` that
    /// selects the bias element. Integer remainder by zero has no well-defined result
    /// in PTX, so the bias load would read an arbitrary global address.
    fn build_ptx(&self) -> PtxKernel {
        assert!(
            self.bias_size > 0,
            "BiasActivationKernel: bias_size must be non-zero (it is used as a modulus)"
        );
        let activation = self.activation;
        let bias_size = self.bias_size;
        PtxKernel::new("bias_activation")
            .param(PtxType::U64, "output")
            .param(PtxType::U64, "bias")
            .param(PtxType::U32, "n")
            .build(|ctx| {
                // Flat 1-D thread index: one thread per element.
                let ctaid_x = ctx.special_reg(crate::ptx::PtxReg::CtaIdX);
                let ntid_x = ctx.special_reg(crate::ptx::PtxReg::NtidX);
                let tid_x = ctx.special_reg(crate::ptx::PtxReg::TidX);
                let global_id = ctx.mad_lo_u32(ctaid_x, ntid_x, tid_x);

                // Threads past the end (e.g. the tail of the last block) skip all work.
                let n_param = ctx.load_param_u32("n");
                let out_of_bounds = ctx.setp_ge_u32(global_id, n_param);
                ctx.branch_if(out_of_bounds, "exit");

                // Load output[global_id]; elements are 4-byte f32s.
                let output_ptr = ctx.load_param_u64("output");
                let offset = ctx.mul_wide_u32(global_id, 4);
                let addr = ctx.add_u64(output_ptr, offset);
                let value = ctx.ld_global_f32(addr);

                // Bias broadcast: bias[global_id % bias_size] (non-zero, checked above).
                let bias_ptr = ctx.load_param_u64("bias");
                let bias_idx = ctx.rem_u32(global_id, bias_size);
                let bias_offset = ctx.mul_wide_u32(bias_idx, 4);
                let bias_addr = ctx.add_u64(bias_ptr, bias_offset);
                let bias_val = ctx.ld_global_f32(bias_addr);
                let result = ctx.add_f32(value, bias_val);

                // The activation is chosen on the host; only one branch is ever emitted.
                let activated = match activation {
                    Activation::None => result,
                    Activation::ReLU => {
                        let zero = ctx.mov_f32_imm(0.0);
                        ctx.max_f32(result, zero)
                    }
                    Activation::GELU => {
                        // Sigmoid approximation: x * sigmoid(1.702 * x), with
                        // exp(-1.702 * x) computed as 2^(-1.702 * x * log2(e)) via ex2.
                        let coeff = ctx.mov_f32_imm(1.702);
                        let scaled = ctx.mul_f32(result, coeff);
                        let zero = ctx.mov_f32_imm(0.0);
                        let neg_scaled = ctx.sub_f32(zero, scaled);
                        let log2_e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
                        let scaled_for_ex2 = ctx.mul_f32(neg_scaled, log2_e);
                        let exp_val = ctx.ex2_f32(scaled_for_ex2);
                        let one = ctx.mov_f32_imm(1.0);
                        let denom = ctx.add_f32(one, exp_val);
                        let sigmoid = ctx.div_f32(one, denom);
                        ctx.mul_f32(result, sigmoid)
                    }
                };

                // Write back to the same address the input was loaded from.
                ctx.st_global_f32(addr, activated);
                ctx.label("exit");
                ctx.ret();
            })
    }
}
#[cfg(test)]
mod tests;