#![allow(clippy::too_many_lines)]
use super::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Number of elements per Q4_K quantization super-block; used to compute
/// how many super-blocks cover one `hidden_size`-length row.
const Q4K_SUPER_BLOCK_SIZE: u32 = 256;
/// Configuration for a fused ("mega") transformer-block PTX kernel that
/// combines normalization, attention, and FFN stages into one launch.
#[derive(Debug, Clone)]
pub struct TransformerBlockMegakernel {
/// Model hidden dimension (width of the activation vectors).
pub hidden_size: u32,
/// FFN intermediate dimension (gate/up projection width).
pub intermediate_size: u32,
/// Number of attention heads.
pub num_heads: u32,
/// Per-head dimension; computed as `hidden_size / num_heads` in `new`.
pub head_dim: u32,
/// RMSNorm epsilon, added to the mean square before `rsqrt`
/// (defaults to `1e-6`; override via `with_epsilon`).
pub epsilon: f32,
}
impl TransformerBlockMegakernel {
    /// Creates a megakernel configuration.
    ///
    /// `head_dim` is derived as `hidden_size / num_heads`, and `epsilon`
    /// defaults to `1e-6` (override with [`Self::with_epsilon`]).
    ///
    /// # Panics
    ///
    /// Panics if `num_heads` is zero (integer division by zero).
    #[must_use]
    pub fn new(hidden_size: u32, intermediate_size: u32, num_heads: u32) -> Self {
        Self {
            hidden_size,
            intermediate_size,
            num_heads,
            head_dim: hidden_size / num_heads,
            epsilon: 1e-6,
        }
    }

    /// Returns `self` with the RMSNorm epsilon replaced.
    #[must_use]
    pub fn with_epsilon(mut self, epsilon: f32) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Dynamic shared memory required by the kernel, in bytes.
    ///
    /// Budget: one f32 buffer of `hidden_size` for the norm stage, an f32
    /// QKV scratch of `3 * head_dim`, and one f32 buffer of `hidden_size`
    /// for the attention stage (4 bytes per f32 element).
    #[must_use]
    pub fn shared_memory_bytes(&self) -> usize {
        let norm_buffer = self.hidden_size as usize * 4;
        let qkv_buffer = 3 * self.head_dim as usize * 4;
        let attn_buffer = self.hidden_size as usize * 4;
        norm_buffer + qkv_buffer + attn_buffer
    }

    /// Number of Q4_K super-blocks covering one `hidden_size`-length row,
    /// rounding up so a partial trailing block is counted.
    #[must_use]
    pub fn num_hidden_super_blocks(&self) -> u32 {
        // `div_ceil` replaces the manual `(a + b - 1) / b` idiom (clippy
        // `manual_div_ceil`) and cannot overflow near `u32::MAX` the way
        // the manual form can.
        self.hidden_size.div_ceil(Q4K_SUPER_BLOCK_SIZE)
    }
}
impl Kernel for TransformerBlockMegakernel {
/// Name of the PTX entry point emitted by `build_ptx`.
fn name(&self) -> &str {
"transformer_block_megakernel"
}
/// Emits the PTX for the fused transformer block.
///
/// Only the first stage — the attention-input RMSNorm — is generated so
/// far: each thread accumulates squared f16 inputs with a stride-256
/// access pattern, partial sums are reduced with warp shuffles, and the
/// normalized, gamma-scaled values are written back as f16. Parameters
/// for the attention/FFN stages are declared and loaded but not yet used
/// (see the `let _ = ...` discards at the end of the body).
fn build_ptx(&self) -> PtxKernel {
// Copy config fields into locals so the `move` closure owns plain values.
let hidden_size = self.hidden_size;
let intermediate_size = self.intermediate_size;
let num_heads = self.num_heads;
let head_dim = self.head_dim;
let epsilon = self.epsilon;
let smem_bytes = self.shared_memory_bytes();
// Parameter order: activation in/out, Q/K/V/O and gate/up/down projection
// weight pointers, the two norm-weight pointers, KV-cache pointers, and
// the current sequence position.
PtxKernel::new("transformer_block_megakernel")
.param(PtxType::U64, "input_ptr") .param(PtxType::U64, "output_ptr") .param(PtxType::U64, "q_proj_ptr") .param(PtxType::U64, "k_proj_ptr") .param(PtxType::U64, "v_proj_ptr") .param(PtxType::U64, "o_proj_ptr") .param(PtxType::U64, "gate_proj_ptr") .param(PtxType::U64, "up_proj_ptr") .param(PtxType::U64, "down_proj_ptr") .param(PtxType::U64, "attn_norm_ptr") .param(PtxType::U64, "ffn_norm_ptr") .param(PtxType::U64, "k_cache_ptr") .param(PtxType::U64, "v_cache_ptr") .param(PtxType::U32, "seq_pos") .shared_memory(smem_bytes)
.build(move |ctx| {
let thread_id = ctx.special_reg(PtxReg::TidX);
// Warp/lane decomposition of the thread id (32 threads per warp).
let warp_id = ctx.div_u32(thread_id, 32);
let lane_id = ctx.rem_u32(thread_id, 32);
let input_ptr = ctx.load_param_u64("input_ptr");
let output_ptr = ctx.load_param_u64("output_ptr");
let attn_norm_ptr = ctx.load_param_u64("attn_norm_ptr");
let ffn_norm_ptr = ctx.load_param_u64("ffn_norm_ptr");
// Loaded but unused until the attention/KV-cache stage is emitted.
let _seq_pos = ctx.load_param_u32("seq_pos");
// --- RMSNorm pass 1: per-thread sum of squares over the input row ---
let thread_sum = ctx.mov_f32_imm(0.0);
// NOTE(review): truncating division — if hidden_size is not a multiple
// of 256 the trailing elements are never read (and the loop also bakes
// in a 256-thread launch). Confirm both invariants at the call sites.
let num_per_thread = hidden_size / 256;
let num_per_thread_reg = ctx.mov_u32_imm(num_per_thread);
let i = ctx.mov_u32_imm(0);
ctx.label("norm_loop");
let loop_done = ctx.setp_ge_u32(i, num_per_thread_reg);
ctx.branch_if(loop_done, "norm_loop_end");
// Element index for iteration i: i * 256 + thread_id (stride-256 access).
let stride = ctx.mov_u32_imm(256);
let base_idx = ctx.mul_u32_reg(i, stride);
let global_idx = ctx.add_u32_reg(base_idx, thread_id);
let global_idx_64 = ctx.cvt_u64_u32(global_idx);
// f16 input: byte offset = index * 2.
let input_bytes = ctx.mul_u64(global_idx_64, 2); let input_addr = ctx.add_u64(input_ptr, input_bytes);
let val_f16 = ctx.ld_global_f16(input_addr);
let val = ctx.cvt_f32_f16(val_f16);
// Accumulate val^2 in f32.
let sq = ctx.mul_f32(val, val);
ctx.add_f32_inplace(thread_sum, sq);
ctx.add_u32_inplace(i, 1);
ctx.branch("norm_loop");
ctx.label("norm_loop_end");
// --- Warp-level tree reduction via shfl.down (offsets 16,8,4,2,1). ---
// After this sequence lane 0 of each warp holds that warp's total.
let tmp16 = ctx.shfl_down_f32(thread_sum, 16, 0xFFFF_FFFF);
ctx.add_f32_inplace(thread_sum, tmp16);
let tmp8 = ctx.shfl_down_f32(thread_sum, 8, 0xFFFF_FFFF);
ctx.add_f32_inplace(thread_sum, tmp8);
let tmp4 = ctx.shfl_down_f32(thread_sum, 4, 0xFFFF_FFFF);
ctx.add_f32_inplace(thread_sum, tmp4);
let tmp2 = ctx.shfl_down_f32(thread_sum, 2, 0xFFFF_FFFF);
ctx.add_f32_inplace(thread_sum, tmp2);
let tmp1 = ctx.shfl_down_f32(thread_sum, 1, 0xFFFF_FFFF);
ctx.add_f32_inplace(thread_sum, tmp1);
// Broadcast lane 0's total to every lane of the same warp.
let warp_sum = ctx.shfl_idx_f32(thread_sum, 0, 0xFFFF_FFFF);
ctx.bar_sync(0);
// NOTE(review): the reduction above is per-warp only. With a 256-thread
// launch there are 8 warps, and nothing here combines the 8 warp partials
// (e.g. via shared memory) before the mean below — each warp would
// normalize with its own partial sum. Looks like a correctness gap in the
// not-yet-complete kernel; confirm before relying on the output.
// --- RMSNorm scale: inv_rms = rsqrt(mean(x^2) + epsilon). ---
let hidden_size_u32 = ctx.mov_u32_imm(hidden_size);
let hidden_size_float = ctx.cvt_f32_u32(hidden_size_u32);
let mean = ctx.div_f32(warp_sum, hidden_size_float);
let eps_reg = ctx.mov_f32_imm(epsilon);
let mean_eps = ctx.add_f32(mean, eps_reg);
let inv_rms = ctx.rsqrt_f32(mean_eps);
// --- RMSNorm pass 2: reload input, scale by inv_rms * gamma, store. ---
let j = ctx.mov_u32_imm(0);
ctx.label("store_norm_loop");
let store_done = ctx.setp_ge_u32(j, num_per_thread_reg);
ctx.branch_if(store_done, "store_norm_loop_end");
// Same j * 256 + thread_id indexing as pass 1.
let store_stride = ctx.mov_u32_imm(256);
let store_base_idx = ctx.mul_u32_reg(j, store_stride);
let store_global_idx = ctx.add_u32_reg(store_base_idx, thread_id);
let store_global_idx_64 = ctx.cvt_u64_u32(store_global_idx);
let input_load_bytes = ctx.mul_u64(store_global_idx_64, 2);
let input_load_addr = ctx.add_u64(input_ptr, input_load_bytes);
let input_val_f16 = ctx.ld_global_f16(input_load_addr);
let input_val = ctx.cvt_f32_f16(input_val_f16);
// Norm weights (gamma) are stored as f32: byte offset = index * 4.
let gamma_bytes = ctx.mul_u64(store_global_idx_64, 4);
let gamma_addr = ctx.add_u64(attn_norm_ptr, gamma_bytes);
let gamma = ctx.ld_global_f32(gamma_addr);
let normalized = ctx.mul_f32(input_val, inv_rms);
let scaled = ctx.mul_f32(normalized, gamma);
// Output is written back as f16 (offset = index * 2).
let scaled_f16 = ctx.cvt_f16_f32(scaled);
let output_bytes = ctx.mul_u64(store_global_idx_64, 2);
let output_addr = ctx.add_u64(output_ptr, output_bytes);
ctx.st_global_f16(output_addr, scaled_f16);
ctx.add_u32_inplace(j, 1);
ctx.branch("store_norm_loop");
ctx.label("store_norm_loop_end");
// Barrier (id 1) before the next — not yet emitted — stage.
ctx.bar_sync(1);
// Explicitly discard values reserved for the attention/FFN stages so the
// builder/compiler does not warn about unused registers/captures.
let _ = warp_id;
let _ = lane_id;
let _ = intermediate_size;
let _ = num_heads;
let _ = head_dim;
let _ = ffn_norm_ptr;
ctx.ret();
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: the configuration exercised by every test below.
    fn fixture() -> TransformerBlockMegakernel {
        TransformerBlockMegakernel::new(3584, 18944, 28)
    }

    #[test]
    fn test_megakernel_name() {
        assert_eq!(fixture().name(), "transformer_block_megakernel");
    }

    #[test]
    fn test_megakernel_generates_ptx() {
        let ptx = fixture().emit_ptx();
        assert!(!ptx.is_empty());
        assert!(ptx.contains(".visible .entry transformer_block_megakernel"));
    }

    #[test]
    fn test_megakernel_qwen3b_dimensions() {
        let k = fixture();
        assert_eq!(k.hidden_size, 3584);
        assert_eq!(k.intermediate_size, 18944);
        assert_eq!(k.num_heads, 28);
        assert_eq!(k.head_dim, 128);
    }

    #[test]
    fn test_megakernel_shared_memory() {
        // At minimum one f32 buffer of hidden_size must fit.
        assert!(fixture().shared_memory_bytes() >= 3584 * 4);
    }

    #[test]
    fn test_megakernel_has_barriers() {
        assert!(fixture().emit_ptx().contains("bar.sync"));
    }

    #[test]
    fn test_megakernel_has_shuffle_reduction() {
        assert!(fixture().emit_ptx().contains("shfl"));
    }

    #[test]
    fn test_megakernel_with_epsilon() {
        let k = fixture().with_epsilon(1e-5);
        assert!((k.epsilon - 1e-5).abs() < 1e-10);
    }

    #[test]
    fn test_megakernel_barrier_safety() {
        use crate::ptx::optimize::barrier_safety;
        let report = barrier_safety::analyze(&fixture().emit_ptx());
        assert!(report.is_safe, "Megakernel should be barrier-safe: {:?}", report.violations);
    }
}