use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
use super::super::Kernel;
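
/// Fused GEMM + bias + GELU kernel: computes `C = GELU(A * B + bias)` for
/// row-major `f32` matrices, where `A` is `m x k`, `B` is `k x n`, `bias` has
/// one entry per output column, and `C` is `m x n`. Each thread computes one
/// element of `C`, so the fusion avoids materializing the intermediate
/// `A * B + bias` in global memory.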
#[derive(Debug, Clone)]
pub struct FusedGemmBiasGeluKernel {
    pub m: u32,
    pub n: u32,
    pub k: u32,
}

impl FusedGemmBiasGeluKernel {
    /// Creates a kernel description for `C = GELU(A * B + bias)` with
    /// `A: m x k`, `B: k x n`, and `C: m x n`.
    #[must_use]
    pub fn new(m: u32, n: u32, k: u32) -> Self {
        Self { m, n, k }
    }
}

impl Kernel for FusedGemmBiasGeluKernel {
    fn name(&self) -> &str {
        "fused_gemm_bias_gelu"
    }

    fn build_ptx(&self) -> PtxKernel {
        // The row strides of A (`self.k` elements) and of B/C (`self.n`
        // elements) are baked into the emitted PTX at build time; the
        // m/n/k params are still read at run time for bounds and loop checks.
        let k_val = self.k;
        let n_val = self.n;
        PtxKernel::new("fused_gemm_bias_gelu")
            .param(PtxType::U64, "a_ptr")
            .param(PtxType::U64, "b_ptr")
            .param(PtxType::U64, "bias_ptr")
            .param(PtxType::U64, "c_ptr")
            .param(PtxType::U32, "m")
            .param(PtxType::U32, "n")
            .param(PtxType::U32, "k")
            .build(move |ctx| {
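                // Map each thread to one output element:
                // row = ctaid.y * ntid.y + tid.y, col = ctaid.x * ntid.x + tid.x.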
                let ctaid_y = ctx.special_reg(PtxReg::CtaIdY);
                let ntid_y = ctx.special_reg(PtxReg::NtidY);
                let tid_y = ctx.special_reg(PtxReg::TidY);
                let ctaid_x = ctx.special_reg(PtxReg::CtaIdX);
                let ntid_x = ctx.special_reg(PtxReg::NtidX);
                let tid_x = ctx.special_reg(PtxReg::TidX);
                let row = ctx.mad_lo_u32(ctaid_y, ntid_y, tid_y);
                let col = ctx.mad_lo_u32(ctaid_x, ntid_x, tid_x);
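                // Load the problem sizes and guard threads that fall outside
                // the m x n output tile.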
let m_param = ctx.load_param_u32("m");
let n_param = ctx.load_param_u32("n");
let k_param = ctx.load_param_u32("k");
let pred_m = ctx.setp_ge_u32(row, m_param);
ctx.branch_if(pred_m, "exit");
let pred_n = ctx.setp_ge_u32(col, n_param);
ctx.branch_if(pred_n, "exit");
let a_ptr = ctx.load_param_u64("a_ptr");
let b_ptr = ctx.load_param_u64("b_ptr");
let bias_ptr = ctx.load_param_u64("bias_ptr");
let c_ptr = ctx.load_param_u64("c_ptr");
                let acc = ctx.mov_f32_imm(0.0);
                let row_offset = ctx.mul_wide_u32(row, k_val * 4);
                let a_row_ptr = ctx.add_u64(a_ptr, row_offset);
                let col_offset = ctx.mul_wide_u32(col, 4);
                let b_col_base = ctx.add_u64(b_ptr, col_offset);
                let i = ctx.mov_u32_imm(0);
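                // k-loop: one FMA per iteration. A is walked along a row
                // (4-byte stride); B is walked down a column (n * 4-byte stride).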
ctx.label("loop_k");
let pred_k = ctx.setp_ge_u32(i, k_param);
ctx.branch_if(pred_k, "loop_end");
let i_offset = ctx.mul_wide_u32(i, 4);
let a_addr = ctx.add_u64(a_row_ptr, i_offset);
let a_val = ctx.ld_global_f32(a_addr);
let b_row_offset = ctx.mul_wide_u32(i, n_val * 4);
let b_addr = ctx.add_u64(b_col_base, b_row_offset);
let b_val = ctx.ld_global_f32(b_addr);
ctx.fma_f32_inplace(acc, a_val, b_val);
ctx.add_u32_inplace(i, 1);
ctx.branch("loop_k");
ctx.label("loop_end");
                let bias_offset = ctx.mul_wide_u32(col, 4);
                let bias_addr = ctx.add_u64(bias_ptr, bias_offset);
                let bias_val = ctx.ld_global_f32(bias_addr);
                let x = ctx.add_f32(acc, bias_val); // x = dot product + bias
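                // Tanh-approximation GELU:
                //   GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
                // tanh(y) comes from the sigmoid identity tanh(y) = 2*sigmoid(2y) - 1,
                // and e^(-2y) is evaluated with PTX `ex2` as 2^(-2y * log2(e)).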
                let sqrt_2_pi = ctx.mov_f32_imm(0.797_884_6); // sqrt(2/pi)
                let c = ctx.mov_f32_imm(0.044_715); // GELU tanh-approximation coefficient
                let half = ctx.mov_f32_imm(0.5);
                let one = ctx.mov_f32_imm(1.0);
                let two = ctx.mov_f32_imm(2.0);
                let zero = ctx.mov_f32_imm(0.0);
                let log2_e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
                let x2 = ctx.mul_f32(x, x);
                let x3 = ctx.mul_f32(x2, x);
                let cx3 = ctx.mul_f32(c, x3);
                let inner = ctx.add_f32(x, cx3); // x + 0.044715 * x^3
                let y = ctx.mul_f32(sqrt_2_pi, inner); // y = sqrt(2/pi) * inner
                let two_y = ctx.mul_f32(two, y);
                let neg_two_y = ctx.sub_f32(zero, two_y);
                let scaled_exp = ctx.mul_f32(neg_two_y, log2_e);
                let exp_neg = ctx.ex2_f32(scaled_exp); // 2^(-2y * log2(e)) = e^(-2y)
                let denom = ctx.add_f32(one, exp_neg);
                let sigmoid = ctx.div_f32(one, denom); // sigmoid(2y)
                let two_sigmoid = ctx.mul_f32(two, sigmoid);
                let tanh = ctx.sub_f32(two_sigmoid, one); // tanh(y) = 2*sigmoid(2y) - 1
                let one_plus_tanh = ctx.add_f32(one, tanh);
                let half_x = ctx.mul_f32(half, x);
                let result = ctx.mul_f32(half_x, one_plus_tanh); // 0.5 * x * (1 + tanh(y))
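                // Store to C[row, col] in row-major layout:
                // byte offset row * n * 4 + col * 4.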
                let c_row_offset = ctx.mul_wide_u32(row, n_val * 4);
                let c_row_ptr = ctx.add_u64(c_ptr, c_row_offset);
                let c_col_offset = ctx.mul_wide_u32(col, 4);
                let c_addr = ctx.add_u64(c_row_ptr, c_col_offset);
                ctx.st_global_f32(c_addr, result);
                ctx.label("exit");
                ctx.ret();
            })
    }
}
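
// A minimal usage sketch (host-side launch machinery is out of scope here;
// only `new`, `name`, and `build_ptx` are defined in this module):
//
//     let kernel = FusedGemmBiasGeluKernel::new(128, 256, 64);
//     assert_eq!(kernel.name(), "fused_gemm_bias_gelu");
//     let ptx = kernel.build_ptx(); // PTX for a 128 x 256 output with K = 64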