#![allow(clippy::similar_names, unused_assignments, unused_mut)]
mod gemm_bias_gelu;
pub use gemm_bias_gelu::FusedGemmBiasGeluKernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
use super::Kernel;
/// Parameters for the fused Q/K/V GEMV kernel (`fused_qkv_gemv`).
///
/// One launch computes the three projections q = Wq·x, k = Wk·x, v = Wv·x
/// in a single pass, sharing each load of the input vector `x` across all
/// three dot products (see `build_ptx` below).
#[derive(Debug, Clone)]
pub struct FusedQKVKernel {
// Length of the input vector, and the row stride (in elements) of the
// weight matrices: element k of row `row` lives at index row*hidden_size + k.
pub hidden_size: usize,
// Number of leading rows for which K and V results are stored; rows with
// index >= kv_dim only produce a Q output (guarded by `setp_lt` in the PTX).
pub kv_dim: usize,
}
impl FusedQKVKernel {
pub fn new(hidden_size: usize, kv_dim: usize) -> Self {
Self { hidden_size, kv_dim }
}
}
impl Kernel for FusedQKVKernel {
    fn name(&self) -> &str {
        "fused_qkv_gemv"
    }

    /// Emits PTX for the fused Q/K/V GEMV.
    ///
    /// Work decomposition (as implied by the address arithmetic below):
    /// * one CTA per output row — `%ctaid.x` selects the row,
    /// * a single warp strides over the row: lane `l` handles elements
    ///   `l, l+32, l+64, ...` of the dot product,
    /// * a `shfl.down`-based tree reduction folds the 32 partial sums into
    ///   lane 0, which performs the stores.
    ///
    /// All three accumulators share each load of `x`, which is the point of
    /// fusing the three GEMVs into one kernel.
    fn build_ptx(&self) -> PtxKernel {
        let hidden = self.hidden_size as u32;
        let kv = self.kv_dim as u32;
        PtxKernel::new("fused_qkv_gemv")
            .param(PtxType::U64, "x_ptr")
            .param(PtxType::U64, "wq_ptr")
            .param(PtxType::U64, "wk_ptr")
            .param(PtxType::U64, "wv_ptr")
            .param(PtxType::U64, "out_q_ptr")
            .param(PtxType::U64, "out_k_ptr")
            .param(PtxType::U64, "out_v_ptr")
            .build(move |ctx| {
                let tid = ctx.special_reg(PtxReg::TidX);
                let row = ctx.special_reg(PtxReg::CtaIdX);
                let lane = ctx.and_u32_imm(tid, 31);
                let hidden_size = ctx.mov_u32_imm(hidden);
                let kv_dim_val = ctx.mov_u32_imm(kv);
                let mut acc_q = ctx.mov_f32_imm(0.0);
                let mut acc_k = ctx.mov_f32_imm(0.0);
                let mut acc_v = ctx.mov_f32_imm(0.0);
                let x_ptr = ctx.load_param_u64("x_ptr");
                let wq_ptr = ctx.load_param_u64("wq_ptr");
                let wk_ptr = ctx.load_param_u64("wk_ptr");
                let wv_ptr = ctx.load_param_u64("wv_ptr");
                // Hoisted loop invariant: the byte offset of this row's
                // weights, (row * hidden_size) * 4, is folded into a base
                // pointer per matrix ONCE, instead of re-emitting a mul, an
                // add and a widening mul on every iteration of the k-loop.
                let row_offset = ctx.mul_u32_reg(row, hidden_size);
                let row_bytes = ctx.mul_wide_u32(row_offset, 4);
                let wq_row = ctx.add_u64(wq_ptr, row_bytes);
                // NOTE(review): for rows >= kv_dim these K/V addresses index
                // past a kv_dim x hidden_size matrix; the loaded values are
                // discarded (the store below is predicated on row < kv_dim)
                // but the reads themselves still happen. Confirm the K/V
                // weight buffers are sized/padded to tolerate this.
                let wk_row = ctx.add_u64(wk_ptr, row_bytes);
                let wv_row = ctx.add_u64(wv_ptr, row_bytes);
                // Warp-strided dot-product loop: k = lane, lane+32, ...
                let mut k = lane;
                ctx.label("loop_start");
                let pred_exit = ctx.setp_ge_u32(k, hidden_size);
                ctx.branch_if(pred_exit, "loop_end");
                let offset_k = ctx.mul_wide_u32(k, 4);
                let x_addr = ctx.add_u64(x_ptr, offset_k);
                let x_val = ctx.ld_global_f32(x_addr);
                // Same k*4 byte offset serves all three weight matrices.
                let wq_addr = ctx.add_u64(wq_row, offset_k);
                let wq_val = ctx.ld_global_f32(wq_addr);
                acc_q = ctx.fma_f32(x_val, wq_val, acc_q);
                let wk_addr = ctx.add_u64(wk_row, offset_k);
                let wk_val = ctx.ld_global_f32(wk_addr);
                acc_k = ctx.fma_f32(x_val, wk_val, acc_k);
                let wv_addr = ctx.add_u64(wv_row, offset_k);
                let wv_val = ctx.ld_global_f32(wv_addr);
                acc_v = ctx.fma_f32(x_val, wv_val, acc_v);
                ctx.add_u32_inplace(k, 32);
                ctx.branch("loop_start");
                ctx.label("loop_end");
                // Warp tree reduction: each shfl.down step halves the number
                // of live partial sums until lane 0 holds the full result.
                // Emits the same q-then-k-then-v instruction sequence as the
                // previous hand-unrolled version.
                for acc in [&mut acc_q, &mut acc_k, &mut acc_v] {
                    for delta in [16, 8, 4, 2, 1] {
                        let peer = ctx.shfl_down_f32(*acc, delta, 0xFFFF_FFFF);
                        *acc = ctx.add_f32(*acc, peer);
                    }
                }
                // Only lane 0 holds the complete sums; everyone else exits.
                let zero = ctx.mov_u32_imm(0);
                let is_lane0 = ctx.setp_eq_u32(lane, zero);
                ctx.branch_if_not(is_lane0, "done");
                let out_q_ptr = ctx.load_param_u64("out_q_ptr");
                let out_k_ptr = ctx.load_param_u64("out_k_ptr");
                let out_v_ptr = ctx.load_param_u64("out_v_ptr");
                let row_byte_offset = ctx.mul_wide_u32(row, 4);
                let out_q_addr = ctx.add_u64(out_q_ptr, row_byte_offset);
                ctx.st_global_f32(out_q_addr, acc_q);
                // K and V outputs exist only for the first kv_dim rows
                // (presumably grouped-query attention, where kv_dim <=
                // hidden_size — confirm against the model config).
                let pred_kv = ctx.setp_lt_u32(row, kv_dim_val);
                ctx.branch_if_not(pred_kv, "done");
                let out_k_addr = ctx.add_u64(out_k_ptr, row_byte_offset);
                ctx.st_global_f32(out_k_addr, acc_k);
                let out_v_addr = ctx.add_u64(out_v_ptr, row_byte_offset);
                ctx.st_global_f32(out_v_addr, acc_v);
                ctx.label("done");
                ctx.ret();
            })
    }
}
/// Parameters for the fused gate/up SwiGLU kernel (`fused_gate_up_swiglu`).
///
/// One launch computes out[row] = silu(x · Wg[row]) * (x · Wu[row]),
/// sharing each load of the input vector `x` between the gate and up
/// projections (see `build_ptx` below).
#[derive(Debug, Clone)]
pub struct FusedGateUpKernel {
// Length of the input vector, and the row stride (in elements) of both
// weight matrices.
pub hidden_size: usize,
// Number of output rows. Not referenced by `build_ptx` itself — presumably
// used to size the launch grid (one CTA per row); verify at the call site.
pub intermediate_size: usize,
}
impl FusedGateUpKernel {
pub fn new(hidden_size: usize, intermediate_size: usize) -> Self {
Self { hidden_size, intermediate_size }
}
}
impl Kernel for FusedGateUpKernel {
    fn name(&self) -> &str {
        "fused_gate_up_swiglu"
    }

    /// Emits PTX for the fused gate/up GEMV with SwiGLU activation:
    /// out[row] = silu(x · Wg[row]) * (x · Wu[row]).
    ///
    /// Work decomposition (as implied by the address arithmetic below):
    /// * one CTA per output row — `%ctaid.x` selects the row,
    /// * a single warp strides over the row: lane `l` handles elements
    ///   `l, l+32, l+64, ...` of both dot products,
    /// * a `shfl.down` tree reduction folds the partial sums into lane 0,
    ///   which applies the activation and stores the result.
    fn build_ptx(&self) -> PtxKernel {
        let hidden = self.hidden_size as u32;
        PtxKernel::new("fused_gate_up_swiglu")
            .param(PtxType::U64, "x_ptr")
            .param(PtxType::U64, "wg_ptr")
            .param(PtxType::U64, "wu_ptr")
            .param(PtxType::U64, "out_ptr")
            .build(move |ctx| {
                let tid = ctx.special_reg(PtxReg::TidX);
                let row = ctx.special_reg(PtxReg::CtaIdX);
                let lane = ctx.and_u32_imm(tid, 31);
                let hidden_size = ctx.mov_u32_imm(hidden);
                let mut acc_gate = ctx.mov_f32_imm(0.0);
                let mut acc_up = ctx.mov_f32_imm(0.0);
                let x_ptr = ctx.load_param_u64("x_ptr");
                let wg_ptr = ctx.load_param_u64("wg_ptr");
                let wu_ptr = ctx.load_param_u64("wu_ptr");
                // Hoisted loop invariant: the byte offset of this row's
                // weights, (row * hidden_size) * 4, is folded into a base
                // pointer per matrix ONCE, instead of re-emitting a mul, an
                // add and a widening mul on every iteration of the k-loop.
                let row_offset = ctx.mul_u32_reg(row, hidden_size);
                let row_bytes = ctx.mul_wide_u32(row_offset, 4);
                let wg_row = ctx.add_u64(wg_ptr, row_bytes);
                let wu_row = ctx.add_u64(wu_ptr, row_bytes);
                // Warp-strided dot-product loop: k = lane, lane+32, ...
                let mut k = lane;
                ctx.label("loop_start");
                let pred_exit = ctx.setp_ge_u32(k, hidden_size);
                ctx.branch_if(pred_exit, "loop_end");
                let offset_k = ctx.mul_wide_u32(k, 4);
                let x_addr = ctx.add_u64(x_ptr, offset_k);
                let x_val = ctx.ld_global_f32(x_addr);
                // Same k*4 byte offset serves both weight matrices.
                let wg_addr = ctx.add_u64(wg_row, offset_k);
                let wg_val = ctx.ld_global_f32(wg_addr);
                acc_gate = ctx.fma_f32(x_val, wg_val, acc_gate);
                let wu_addr = ctx.add_u64(wu_row, offset_k);
                let wu_val = ctx.ld_global_f32(wu_addr);
                acc_up = ctx.fma_f32(x_val, wu_val, acc_up);
                ctx.add_u32_inplace(k, 32);
                ctx.branch("loop_start");
                ctx.label("loop_end");
                // Warp tree reduction: each shfl.down step halves the number
                // of live partial sums until lane 0 holds the full result.
                // Emits the same gate-then-up instruction sequence as the
                // previous hand-unrolled version.
                for acc in [&mut acc_gate, &mut acc_up] {
                    for delta in [16, 8, 4, 2, 1] {
                        let peer = ctx.shfl_down_f32(*acc, delta, 0xFFFF_FFFF);
                        *acc = ctx.add_f32(*acc, peer);
                    }
                }
                // Only lane 0 holds the complete sums; everyone else exits.
                let zero = ctx.mov_u32_imm(0);
                let is_lane0 = ctx.setp_eq_u32(lane, zero);
                ctx.branch_if_not(is_lane0, "done");
                // sigmoid(g) = 1 / (1 + exp(-g)), with exp(-g) computed as
                // 2^(-g * log2(e)) because PTX exposes ex2 rather than exp.
                let neg_gate = ctx.neg_f32(acc_gate);
                let log2_e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
                let scaled = ctx.mul_f32(neg_gate, log2_e);
                let exp_val = ctx.ex2_f32(scaled);
                let one = ctx.mov_f32_imm(1.0);
                let one_plus_exp = ctx.add_f32(one, exp_val);
                // NOTE(review): rcp in PTX is typically the approximate
                // reciprocal (rcp.approx.f32) — confirm the builder's
                // emission if full-precision SwiGLU is required.
                let sigmoid = ctx.rcp_f32(one_plus_exp);
                // silu(g) = g * sigmoid(g); SwiGLU multiplies by the up proj.
                let silu = ctx.mul_f32(acc_gate, sigmoid);
                let output = ctx.mul_f32(silu, acc_up);
                let out_ptr = ctx.load_param_u64("out_ptr");
                let row_byte_offset = ctx.mul_wide_u32(row, 4);
                let out_addr = ctx.add_u64(out_ptr, row_byte_offset);
                ctx.st_global_f32(out_addr, output);
                ctx.label("done");
                ctx.ret();
            })
    }
}
#[cfg(test)]
mod tests;