#![allow(clippy::similar_names, unused_assignments, unused_mut)]
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
use super::Kernel;
/// Description of a fused Q/K/V GEMV PTX kernel: one pass over the input
/// vector produces all three attention projections.
#[derive(Debug, Clone)]
pub struct FusedQKVKernel {
    /// Input dimension; used as the dot-product loop bound and as the
    /// row stride of the (row-major) weight matrices in the emitted PTX.
    pub hidden_size: usize,
    /// Number of rows for which K and V outputs are stored; Q is stored
    /// for every row (grouped-query-attention style — TODO confirm caller).
    pub kv_dim: usize,
}
impl FusedQKVKernel {
    /// Creates a kernel description for the given input dimension and
    /// K/V output dimension. Pure constructor — no PTX is emitted until
    /// `build_ptx` is called.
    // `#[must_use]` added for consistency with `FusedGemmBiasGeluKernel::new`.
    #[must_use]
    pub fn new(hidden_size: usize, kv_dim: usize) -> Self {
        Self {
            hidden_size,
            kv_dim,
        }
    }
}
impl Kernel for FusedQKVKernel {
    /// Entry-point name as it appears in the emitted PTX.
    fn name(&self) -> &str {
        "fused_qkv_gemv"
    }

    /// Emits a kernel computing three GEMVs in one pass over `x`:
    /// `q = Wq·x`, `k = Wk·x`, `v = Wv·x`, sharing each loaded `x[k]`
    /// across all three dot products.
    ///
    /// Layout: `ctaid.x` selects the output row; the 32 lanes of a warp each
    /// accumulate a stride-32 partial dot product, combine it via a
    /// shfl.down tree reduction, and lane 0 stores the results.
    /// NOTE(review): this assumes a one-warp launch per block — with more
    /// warps, lane 0 of every warp would store its own partial sum; confirm
    /// the launch configuration.
    fn build_ptx(&self) -> PtxKernel {
        // Dimensions are baked into the PTX as immediates.
        let hidden = self.hidden_size as u32;
        let kv = self.kv_dim as u32;
        PtxKernel::new("fused_qkv_gemv")
            .param(PtxType::U64, "x_ptr")
            .param(PtxType::U64, "wq_ptr")
            .param(PtxType::U64, "wk_ptr")
            .param(PtxType::U64, "wv_ptr")
            .param(PtxType::U64, "out_q_ptr")
            .param(PtxType::U64, "out_k_ptr")
            .param(PtxType::U64, "out_v_ptr")
            .build(move |ctx| {
                let tid = ctx.special_reg(PtxReg::TidX);
                // One CTA per output row.
                let row = ctx.special_reg(PtxReg::CtaIdX);
                let lane = ctx.and_u32_imm(tid, 31);
                let hidden_size = ctx.mov_u32_imm(hidden);
                let kv_dim_val = ctx.mov_u32_imm(kv);
                // Per-lane partial sums for the three dot products.
                let mut acc_q = ctx.mov_f32_imm(0.0);
                let mut acc_k = ctx.mov_f32_imm(0.0);
                let mut acc_v = ctx.mov_f32_imm(0.0);
                let x_ptr = ctx.load_param_u64("x_ptr");
                let wq_ptr = ctx.load_param_u64("wq_ptr");
                let wk_ptr = ctx.load_param_u64("wk_ptr");
                let wv_ptr = ctx.load_param_u64("wv_ptr");
                // Each lane walks k = lane, lane + 32, lane + 64, ...
                let mut k = lane;
                ctx.label("loop_start");
                let pred_exit = ctx.setp_ge_u32(k, hidden_size);
                ctx.branch_if(pred_exit, "loop_end");
                // x[k], loaded once and shared by all three FMAs below.
                let offset_k = ctx.mul_wide_u32(k, 4);
                let x_addr = ctx.add_u64(x_ptr, offset_k);
                let x_val = ctx.ld_global_f32(x_addr);
                // Row-major weight index: row * hidden_size + k.
                // NOTE(review): row * hidden_size is loop-invariant but is
                // re-emitted every iteration; could be hoisted above the loop.
                let row_offset = ctx.mul_u32_reg(row, hidden_size);
                let weight_idx = ctx.add_u32_reg(row_offset, k);
                let weight_byte_offset = ctx.mul_wide_u32(weight_idx, 4);
                let wq_addr = ctx.add_u64(wq_ptr, weight_byte_offset);
                let wq_val = ctx.ld_global_f32(wq_addr);
                acc_q = ctx.fma_f32(x_val, wq_val, acc_q);
                // NOTE(review): for row >= kv_dim these K/V loads index past a
                // kv_dim x hidden_size buffer even though the results are
                // discarded by the pred_kv guard below — confirm the K/V
                // weight buffers are sized to tolerate this, or guard loads.
                let wk_addr = ctx.add_u64(wk_ptr, weight_byte_offset);
                let wk_val = ctx.ld_global_f32(wk_addr);
                acc_k = ctx.fma_f32(x_val, wk_val, acc_k);
                let wv_addr = ctx.add_u64(wv_ptr, weight_byte_offset);
                let wv_val = ctx.ld_global_f32(wv_addr);
                acc_v = ctx.fma_f32(x_val, wv_val, acc_v);
                ctx.add_u32_inplace(k, 32);
                ctx.branch("loop_start");
                ctx.label("loop_end");
                // Warp tree reduction: fold lanes at offsets 16, 8, 4, 2, 1
                // into lane 0 for each accumulator (full 0xFFFFFFFF mask).
                let shfl_q_16 = ctx.shfl_down_f32(acc_q, 16, 0xFFFF_FFFF);
                acc_q = ctx.add_f32(acc_q, shfl_q_16);
                let shfl_q_8 = ctx.shfl_down_f32(acc_q, 8, 0xFFFF_FFFF);
                acc_q = ctx.add_f32(acc_q, shfl_q_8);
                let shfl_q_4 = ctx.shfl_down_f32(acc_q, 4, 0xFFFF_FFFF);
                acc_q = ctx.add_f32(acc_q, shfl_q_4);
                let shfl_q_2 = ctx.shfl_down_f32(acc_q, 2, 0xFFFF_FFFF);
                acc_q = ctx.add_f32(acc_q, shfl_q_2);
                let shfl_q_1 = ctx.shfl_down_f32(acc_q, 1, 0xFFFF_FFFF);
                acc_q = ctx.add_f32(acc_q, shfl_q_1);
                let shfl_k_16 = ctx.shfl_down_f32(acc_k, 16, 0xFFFF_FFFF);
                acc_k = ctx.add_f32(acc_k, shfl_k_16);
                let shfl_k_8 = ctx.shfl_down_f32(acc_k, 8, 0xFFFF_FFFF);
                acc_k = ctx.add_f32(acc_k, shfl_k_8);
                let shfl_k_4 = ctx.shfl_down_f32(acc_k, 4, 0xFFFF_FFFF);
                acc_k = ctx.add_f32(acc_k, shfl_k_4);
                let shfl_k_2 = ctx.shfl_down_f32(acc_k, 2, 0xFFFF_FFFF);
                acc_k = ctx.add_f32(acc_k, shfl_k_2);
                let shfl_k_1 = ctx.shfl_down_f32(acc_k, 1, 0xFFFF_FFFF);
                acc_k = ctx.add_f32(acc_k, shfl_k_1);
                let shfl_v_16 = ctx.shfl_down_f32(acc_v, 16, 0xFFFF_FFFF);
                acc_v = ctx.add_f32(acc_v, shfl_v_16);
                let shfl_v_8 = ctx.shfl_down_f32(acc_v, 8, 0xFFFF_FFFF);
                acc_v = ctx.add_f32(acc_v, shfl_v_8);
                let shfl_v_4 = ctx.shfl_down_f32(acc_v, 4, 0xFFFF_FFFF);
                acc_v = ctx.add_f32(acc_v, shfl_v_4);
                let shfl_v_2 = ctx.shfl_down_f32(acc_v, 2, 0xFFFF_FFFF);
                acc_v = ctx.add_f32(acc_v, shfl_v_2);
                let shfl_v_1 = ctx.shfl_down_f32(acc_v, 1, 0xFFFF_FFFF);
                acc_v = ctx.add_f32(acc_v, shfl_v_1);
                // Only lane 0 holds the complete sums; other lanes exit.
                let zero = ctx.mov_u32_imm(0);
                let is_lane0 = ctx.setp_eq_u32(lane, zero);
                ctx.branch_if_not(is_lane0, "done");
                let out_q_ptr = ctx.load_param_u64("out_q_ptr");
                let out_k_ptr = ctx.load_param_u64("out_k_ptr");
                let out_v_ptr = ctx.load_param_u64("out_v_ptr");
                let row_byte_offset = ctx.mul_wide_u32(row, 4);
                // Q has an output for every row.
                let out_q_addr = ctx.add_u64(out_q_ptr, row_byte_offset);
                ctx.st_global_f32(out_q_addr, acc_q);
                // K/V outputs exist only for the first kv_dim rows.
                let pred_kv = ctx.setp_lt_u32(row, kv_dim_val);
                ctx.branch_if_not(pred_kv, "done");
                let out_k_addr = ctx.add_u64(out_k_ptr, row_byte_offset);
                ctx.st_global_f32(out_k_addr, acc_k);
                let out_v_addr = ctx.add_u64(out_v_ptr, row_byte_offset);
                ctx.st_global_f32(out_v_addr, acc_v);
                ctx.label("done");
                ctx.ret();
            })
    }
}
/// Description of a fused gate/up GEMV + SwiGLU PTX kernel:
/// `out[row] = silu(Wg[row]·x) * (Wu[row]·x)`.
#[derive(Debug, Clone)]
pub struct FusedGateUpKernel {
    /// Input dimension; loop bound and weight-matrix row stride in the PTX.
    pub hidden_size: usize,
    /// FFN intermediate dimension. Not referenced by `build_ptx` —
    /// presumably determines the launch grid size (TODO confirm at call site).
    pub intermediate_size: usize,
}
impl FusedGateUpKernel {
    /// Creates a kernel description for the given hidden and intermediate
    /// dimensions. Pure constructor — no PTX is emitted until `build_ptx`.
    // `#[must_use]` added for consistency with `FusedGemmBiasGeluKernel::new`.
    #[must_use]
    pub fn new(hidden_size: usize, intermediate_size: usize) -> Self {
        Self {
            hidden_size,
            intermediate_size,
        }
    }
}
impl Kernel for FusedGateUpKernel {
    /// Entry-point name as it appears in the emitted PTX.
    fn name(&self) -> &str {
        "fused_gate_up_swiglu"
    }

    /// Emits a kernel computing one SwiGLU output element per block:
    /// `out[row] = silu(Wg[row]·x) * (Wu[row]·x)`, with both GEMVs fused so
    /// each `x[k]` is loaded once.
    ///
    /// Layout mirrors `fused_qkv_gemv`: `ctaid.x` selects the row, each of
    /// the 32 lanes accumulates a stride-32 partial dot product, a shfl.down
    /// tree reduces into lane 0, and lane 0 applies SiLU and stores.
    /// NOTE(review): assumes a one-warp launch per block — confirm launch
    /// configuration.
    fn build_ptx(&self) -> PtxKernel {
        // Input dimension baked into the PTX as an immediate.
        let hidden = self.hidden_size as u32;
        PtxKernel::new("fused_gate_up_swiglu")
            .param(PtxType::U64, "x_ptr")
            .param(PtxType::U64, "wg_ptr")
            .param(PtxType::U64, "wu_ptr")
            .param(PtxType::U64, "out_ptr")
            .build(move |ctx| {
                let tid = ctx.special_reg(PtxReg::TidX);
                // One CTA per output row.
                let row = ctx.special_reg(PtxReg::CtaIdX);
                let lane = ctx.and_u32_imm(tid, 31);
                let hidden_size = ctx.mov_u32_imm(hidden);
                // Per-lane partial sums for the gate and up projections.
                let mut acc_gate = ctx.mov_f32_imm(0.0);
                let mut acc_up = ctx.mov_f32_imm(0.0);
                let x_ptr = ctx.load_param_u64("x_ptr");
                let wg_ptr = ctx.load_param_u64("wg_ptr");
                let wu_ptr = ctx.load_param_u64("wu_ptr");
                // Each lane walks k = lane, lane + 32, lane + 64, ...
                let mut k = lane;
                ctx.label("loop_start");
                let pred_exit = ctx.setp_ge_u32(k, hidden_size);
                ctx.branch_if(pred_exit, "loop_end");
                // x[k], loaded once and shared by both FMAs below.
                let offset_k = ctx.mul_wide_u32(k, 4);
                let x_addr = ctx.add_u64(x_ptr, offset_k);
                let x_val = ctx.ld_global_f32(x_addr);
                // Row-major weight index: row * hidden_size + k.
                // NOTE(review): row * hidden_size is loop-invariant but is
                // re-emitted every iteration; could be hoisted above the loop.
                let row_offset = ctx.mul_u32_reg(row, hidden_size);
                let weight_idx = ctx.add_u32_reg(row_offset, k);
                let weight_byte_offset = ctx.mul_wide_u32(weight_idx, 4);
                let wg_addr = ctx.add_u64(wg_ptr, weight_byte_offset);
                let wg_val = ctx.ld_global_f32(wg_addr);
                acc_gate = ctx.fma_f32(x_val, wg_val, acc_gate);
                let wu_addr = ctx.add_u64(wu_ptr, weight_byte_offset);
                let wu_val = ctx.ld_global_f32(wu_addr);
                acc_up = ctx.fma_f32(x_val, wu_val, acc_up);
                ctx.add_u32_inplace(k, 32);
                ctx.branch("loop_start");
                ctx.label("loop_end");
                // Warp tree reduction: offsets 16, 8, 4, 2, 1 into lane 0,
                // for both accumulators (full 0xFFFFFFFF member mask).
                let shfl_g_16 = ctx.shfl_down_f32(acc_gate, 16, 0xFFFF_FFFF);
                acc_gate = ctx.add_f32(acc_gate, shfl_g_16);
                let shfl_g_8 = ctx.shfl_down_f32(acc_gate, 8, 0xFFFF_FFFF);
                acc_gate = ctx.add_f32(acc_gate, shfl_g_8);
                let shfl_g_4 = ctx.shfl_down_f32(acc_gate, 4, 0xFFFF_FFFF);
                acc_gate = ctx.add_f32(acc_gate, shfl_g_4);
                let shfl_g_2 = ctx.shfl_down_f32(acc_gate, 2, 0xFFFF_FFFF);
                acc_gate = ctx.add_f32(acc_gate, shfl_g_2);
                let shfl_g_1 = ctx.shfl_down_f32(acc_gate, 1, 0xFFFF_FFFF);
                acc_gate = ctx.add_f32(acc_gate, shfl_g_1);
                let shfl_u_16 = ctx.shfl_down_f32(acc_up, 16, 0xFFFF_FFFF);
                acc_up = ctx.add_f32(acc_up, shfl_u_16);
                let shfl_u_8 = ctx.shfl_down_f32(acc_up, 8, 0xFFFF_FFFF);
                acc_up = ctx.add_f32(acc_up, shfl_u_8);
                let shfl_u_4 = ctx.shfl_down_f32(acc_up, 4, 0xFFFF_FFFF);
                acc_up = ctx.add_f32(acc_up, shfl_u_4);
                let shfl_u_2 = ctx.shfl_down_f32(acc_up, 2, 0xFFFF_FFFF);
                acc_up = ctx.add_f32(acc_up, shfl_u_2);
                let shfl_u_1 = ctx.shfl_down_f32(acc_up, 1, 0xFFFF_FFFF);
                acc_up = ctx.add_f32(acc_up, shfl_u_1);
                // Only lane 0 holds the complete sums; other lanes exit.
                let zero = ctx.mov_u32_imm(0);
                let is_lane0 = ctx.setp_eq_u32(lane, zero);
                ctx.branch_if_not(is_lane0, "done");
                // SiLU(g) = g * sigmoid(g), sigmoid(g) = 1 / (1 + exp(-g)).
                // exp(-g) is computed as 2^(-g * log2(e)) because PTX only
                // exposes a base-2 exponential (ex2).
                let neg_gate = ctx.neg_f32(acc_gate);
                let log2_e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
                let scaled = ctx.mul_f32(neg_gate, log2_e);
                let exp_val = ctx.ex2_f32(scaled);
                let one = ctx.mov_f32_imm(1.0);
                let one_plus_exp = ctx.add_f32(one, exp_val);
                let sigmoid = ctx.rcp_f32(one_plus_exp);
                let silu = ctx.mul_f32(acc_gate, sigmoid);
                let output = ctx.mul_f32(silu, acc_up);
                let out_ptr = ctx.load_param_u64("out_ptr");
                let row_byte_offset = ctx.mul_wide_u32(row, 4);
                let out_addr = ctx.add_u64(out_ptr, row_byte_offset);
                ctx.st_global_f32(out_addr, output);
                ctx.label("done");
                ctx.ret();
            })
    }
}
/// Description of a fused GEMM + bias + GELU PTX kernel:
/// `C = GELU(A·B + bias)`.
#[derive(Debug, Clone)]
pub struct FusedGemmBiasGeluKernel {
    /// Rows of A / C. Not baked into the PTX; bounds-checked at runtime
    /// via the `m` kernel parameter.
    pub m: u32,
    /// Columns of B / C; baked into the PTX as the B and C row stride.
    pub n: u32,
    /// Inner dimension; baked into the PTX as the A row stride.
    pub k: u32,
}
impl FusedGemmBiasGeluKernel {
    /// Creates a kernel description for an (m × k) · (k × n) GEMM.
    /// Pure constructor — no PTX is emitted until `build_ptx`.
    #[must_use]
    pub fn new(m: u32, n: u32, k: u32) -> Self {
        Self { m, n, k }
    }
}
impl Kernel for FusedGemmBiasGeluKernel {
    /// Entry-point name as it appears in the emitted PTX.
    fn name(&self) -> &str {
        "fused_gemm_bias_gelu"
    }

    /// Emits a naive GEMM kernel computing `C = GELU(A·B + bias)` with one
    /// thread per output element: `row` from the y grid/block indices,
    /// `col` from the x indices. A, B, C are row-major f32.
    ///
    /// NOTE(review): strides use the compile-time `self.k`/`self.n` while the
    /// loop bound and bounds checks read the runtime `k`/`n`/`m` params — the
    /// caller must pass values matching the struct or addressing silently
    /// diverges; confirm at the launch site.
    fn build_ptx(&self) -> PtxKernel {
        // Strides baked into the PTX as immediates (bytes computed below).
        let k_val = self.k;
        let n_val = self.n;
        PtxKernel::new("fused_gemm_bias_gelu")
            .param(PtxType::U64, "a_ptr")
            .param(PtxType::U64, "b_ptr")
            .param(PtxType::U64, "bias_ptr")
            .param(PtxType::U64, "c_ptr")
            .param(PtxType::U32, "m")
            .param(PtxType::U32, "n")
            .param(PtxType::U32, "k")
            .build(move |ctx| {
                // Global output coordinates: row = ctaid.y*ntid.y + tid.y,
                // col = ctaid.x*ntid.x + tid.x.
                let ctaid_y = ctx.special_reg(PtxReg::CtaIdY);
                let ntid_y = ctx.special_reg(PtxReg::NtidY);
                let tid_y = ctx.special_reg(PtxReg::TidY);
                let ctaid_x = ctx.special_reg(PtxReg::CtaIdX);
                let ntid_x = ctx.special_reg(PtxReg::NtidX);
                let tid_x = ctx.special_reg(PtxReg::TidX);
                let row = ctx.mad_lo_u32(ctaid_y, ntid_y, tid_y);
                let col = ctx.mad_lo_u32(ctaid_x, ntid_x, tid_x);
                let m_param = ctx.load_param_u32("m");
                let n_param = ctx.load_param_u32("n");
                let k_param = ctx.load_param_u32("k");
                // Threads past the matrix edges exit without touching memory.
                let pred_m = ctx.setp_ge_u32(row, m_param);
                ctx.branch_if(pred_m, "exit");
                let pred_n = ctx.setp_ge_u32(col, n_param);
                ctx.branch_if(pred_n, "exit");
                let a_ptr = ctx.load_param_u64("a_ptr");
                let b_ptr = ctx.load_param_u64("b_ptr");
                let bias_ptr = ctx.load_param_u64("bias_ptr");
                let c_ptr = ctx.load_param_u64("c_ptr");
                let acc = ctx.mov_f32_imm(0.0);
                // Start of A row `row` (stride k floats, 4 bytes each).
                let row_offset = ctx.mul_wide_u32(row, k_val * 4);
                let a_row_ptr = ctx.add_u64(a_ptr, row_offset);
                // Byte offset of column `col`; reused for the bias load and
                // the C store instead of re-emitting the same mul_wide.
                let col_offset = ctx.mul_wide_u32(col, 4);
                let b_col_base = ctx.add_u64(b_ptr, col_offset);
                let i = ctx.mov_u32_imm(0);
                ctx.label("loop_k");
                let pred_k = ctx.setp_ge_u32(i, k_param);
                ctx.branch_if(pred_k, "loop_end");
                // acc += A[row][i] * B[i][col] (B row stride = n floats).
                let i_offset = ctx.mul_wide_u32(i, 4);
                let a_addr = ctx.add_u64(a_row_ptr, i_offset);
                let a_val = ctx.ld_global_f32(a_addr);
                let b_row_offset = ctx.mul_wide_u32(i, n_val * 4);
                let b_addr = ctx.add_u64(b_col_base, b_row_offset);
                let b_val = ctx.ld_global_f32(b_addr);
                ctx.fma_f32_inplace(acc, a_val, b_val);
                ctx.add_u32_inplace(i, 1);
                ctx.branch("loop_k");
                ctx.label("loop_end");
                // Per-column bias, reusing col_offset computed above.
                let bias_addr = ctx.add_u64(bias_ptr, col_offset);
                let bias_val = ctx.ld_global_f32(bias_addr);
                let acc_biased = ctx.add_f32(acc, bias_val);
                let x = acc_biased;
                // tanh-approximation GELU:
                //   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
                // sqrt(2/pi) ≈ 0.7978846 and 0.044715 are the standard
                // constants (checked by the hex-literal tests below).
                let sqrt_2_pi = ctx.mov_f32_imm(0.797_884_6);
                let c = ctx.mov_f32_imm(0.044_715);
                let half = ctx.mov_f32_imm(0.5);
                let one = ctx.mov_f32_imm(1.0);
                let two = ctx.mov_f32_imm(2.0);
                let log2_e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
                let x2 = ctx.mul_f32(x, x);
                let x3 = ctx.mul_f32(x2, x);
                let cx3 = ctx.mul_f32(c, x3);
                let inner = ctx.add_f32(x, cx3);
                let scaled = ctx.mul_f32(sqrt_2_pi, inner);
                // tanh(t) = 2 * sigmoid(2t) - 1, with exp(-2t) computed via
                // ex2 as 2^(-2t * log2(e)) since PTX only has base-2 exp.
                // neg_f32 replaces the previous mov 0.0 + sub_f32 pair,
                // matching the sibling kernels' negation idiom.
                let two_x = ctx.mul_f32(two, scaled);
                let neg_two_x = ctx.neg_f32(two_x);
                let scaled_exp = ctx.mul_f32(neg_two_x, log2_e);
                let exp_neg = ctx.ex2_f32(scaled_exp);
                let denom = ctx.add_f32(one, exp_neg);
                let sigmoid = ctx.div_f32(one, denom);
                let two_sigmoid = ctx.mul_f32(two, sigmoid);
                let tanh = ctx.sub_f32(two_sigmoid, one);
                let one_plus_tanh = ctx.add_f32(one, tanh);
                let half_x = ctx.mul_f32(half, x);
                let result = ctx.mul_f32(half_x, one_plus_tanh);
                // C[row][col], row-major with stride n floats; the column
                // byte offset is the col_offset register from above.
                let c_row_offset = ctx.mul_wide_u32(row, n_val * 4);
                let c_row_ptr = ctx.add_u64(c_ptr, c_row_offset);
                let c_addr = ctx.add_u64(c_row_ptr, col_offset);
                ctx.st_global_f32(c_addr, result);
                ctx.label("exit");
                ctx.ret();
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Emission smoke tests: the kernel name and a PTX `.entry` directive
    // must appear in the generated text; the GELU kernel must also embed
    // its two magic constants as f32 hex literals.
    #[test]
    fn test_fused_gemm_bias_gelu_kernel_builds() {
        let emitted = FusedGemmBiasGeluKernel::new(1500, 1536, 384).emit_ptx();
        assert!(emitted.contains("fused_gemm_bias_gelu"));
        assert!(emitted.contains(".entry"));
        assert!(emitted.contains("0F3F4C422A"), "Missing sqrt(2/π) constant");
        assert!(emitted.contains("0F3D372713"), "Missing 0.044715 constant");
    }

    #[test]
    fn test_fused_qkv_kernel_builds() {
        let emitted = FusedQKVKernel::new(3584, 512).emit_ptx();
        assert!(emitted.contains("fused_qkv_gemv"));
        assert!(emitted.contains(".entry"));
    }

    #[test]
    fn test_fused_gate_up_kernel_builds() {
        let emitted = FusedGateUpKernel::new(3584, 18944).emit_ptx();
        assert!(emitted.contains("fused_gate_up_swiglu"));
        assert!(emitted.contains(".entry"));
    }

    // Kernel::name must match the PTX entry-point names above.
    #[test]
    fn test_fused_qkv_kernel_name() {
        assert_eq!(FusedQKVKernel::new(1024, 256).name(), "fused_qkv_gemv");
    }

    #[test]
    fn test_fused_gate_up_kernel_name() {
        assert_eq!(
            FusedGateUpKernel::new(1024, 4096).name(),
            "fused_gate_up_swiglu"
        );
    }

    // Clone must preserve every field.
    #[test]
    fn test_fused_qkv_kernel_clone() {
        let original = FusedQKVKernel::new(1024, 256);
        let copy = original.clone();
        assert_eq!(copy.hidden_size, original.hidden_size);
        assert_eq!(copy.kv_dim, original.kv_dim);
    }

    #[test]
    fn test_fused_gate_up_kernel_clone() {
        let original = FusedGateUpKernel::new(1024, 4096);
        let copy = original.clone();
        assert_eq!(copy.hidden_size, original.hidden_size);
        assert_eq!(copy.intermediate_size, original.intermediate_size);
    }

    // Debug output should carry the type name and field values.
    #[test]
    fn test_fused_qkv_kernel_debug() {
        let repr = format!("{:?}", FusedQKVKernel::new(1024, 256));
        assert!(repr.contains("FusedQKVKernel"));
        assert!(repr.contains("1024"));
    }

    #[test]
    fn test_fused_gate_up_kernel_debug() {
        let repr = format!("{:?}", FusedGateUpKernel::new(1024, 4096));
        assert!(repr.contains("FusedGateUpKernel"));
        assert!(repr.contains("4096"));
    }
}