use super::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl, PtxMemory};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Configuration from which the `persistent_decoder` PTX kernel is generated.
#[derive(Debug, Clone)]
pub struct PersistentDecoderKernel {
    /// Number of elements in one hidden-state row; also sizes the shared-memory request.
    pub hidden_size: u32,
    /// Layer count recorded on the configuration.
    pub num_layers: u32,
    /// Maximum sequence length recorded on the configuration.
    pub max_seq_len: u32,
    /// Threads per block the generated code strides by; [`Self::new`] sets it to 256.
    pub block_size: u32,
}

impl PersistentDecoderKernel {
    /// Creates a configuration with the given dimensions and a block size of 256.
    #[must_use]
    pub fn new(hidden_size: u32, num_layers: u32, max_seq_len: u32) -> Self {
        const DEFAULT_BLOCK_SIZE: u32 = 256;

        Self {
            block_size: DEFAULT_BLOCK_SIZE,
            max_seq_len,
            num_layers,
            hidden_size,
        }
    }

    /// Returns the configuration with `block_size` replaced, leaving every other field untouched.
    #[must_use]
    pub fn with_block_size(self, block_size: u32) -> Self {
        Self { block_size, ..self }
    }

    /// Shared memory requested for the kernel, in bytes: a fixed 4 bytes plus 4 bytes per
    /// hidden element.
    #[must_use]
    pub fn shared_memory_bytes(&self) -> usize {
        const FIXED_BYTES: usize = 4;
        const BYTES_PER_ELEMENT: usize = 4;

        let elements = self.hidden_size as usize;
        FIXED_BYTES + elements * BYTES_PER_ELEMENT
    }
}
impl Kernel for PersistentDecoderKernel {
    /// Returns the PTX entry-point name, `"persistent_decoder"`.
    fn name(&self) -> &str {
        "persistent_decoder"
    }

    /// Generates the PTX for the persistent kernel.
    ///
    /// Each block processes tokens `ctaid.x`, `ctaid.x + nctaid.x`, `ctaid.x + 2 * nctaid.x`, …
    /// until the index reaches `num_tokens`. For every token the block reads `hidden_size`
    /// f16 values from `input_ptr`, accumulates their squares in f32, derives
    /// `rsqrt(sum / hidden_size + 1e-6)` and writes each input scaled by that factor to the
    /// same offset under `output_ptr` as f16.
    ///
    /// Elements are distributed with a stride of `block_size`, so the kernel is presumably
    /// meant to be launched with exactly `block_size` threads per block — confirm against
    /// the host launcher. A `hidden_size` that is not a multiple of `block_size` is handled
    /// by a bounds guard in both element loops.
    ///
    /// The `work_queue_ptr`, `work_counter_ptr` and `stop_flag_ptr` parameters are declared
    /// but not read by the generated body.
    ///
    /// NOTE(review): the sum of squares is combined with warp shuffles only. When
    /// `block_size` exceeds one warp, each warp therefore normalises by its own partial sum
    /// rather than the sum over the whole row. A block-wide reduction through shared memory
    /// is needed to make the result exact.
    ///
    /// # Panics
    ///
    /// Panics if `block_size` is zero.
    fn build_ptx(&self) -> PtxKernel {
        let hidden_size = self.hidden_size;
        let block_size = self.block_size;
        assert!(
            block_size > 0,
            "PersistentDecoderKernel::block_size must be non-zero"
        );
        // Ceiling division: the last pass may be partial, which the in-loop guard handles.
        let elements_per_thread =
            hidden_size / block_size + u32::from(hidden_size % block_size != 0);
        let smem_bytes = self.shared_memory_bytes();

        PtxKernel::new("persistent_decoder")
            .param(PtxType::U64, "work_queue_ptr")
            .param(PtxType::U64, "work_counter_ptr")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::U32, "num_tokens")
            // NOTE(review): declared U32 although the name suggests a pointer; left unchanged
            // to keep the existing parameter layout — confirm against the host launcher.
            .param(PtxType::U32, "stop_flag_ptr")
            .shared_memory(smem_bytes)
            .build(move |ctx| {
                let thread_id = ctx.special_reg(PtxReg::TidX);
                let block_id = ctx.special_reg(PtxReg::CtaIdX);
                let num_blocks = ctx.special_reg(PtxReg::NctaIdX);
                let input_ptr = ctx.load_param_u64("input_ptr");
                let output_ptr = ctx.load_param_u64("output_ptr");
                let num_tokens = ctx.load_param_u32("num_tokens");
                let smem_base = ctx.shared_base_addr();
                let iteration = ctx.mov_u32_imm(0);

                // ---- per-token work loop -------------------------------------------------
                ctx.label("work_loop");
                let iter_offset = ctx.mul_u32_reg(iteration, num_blocks);
                let token_idx = ctx.add_u32_reg(block_id, iter_offset);
                let work_done = ctx.setp_ge_u32(token_idx, num_tokens);
                ctx.branch_if(work_done, "exit");

                // Thread 0 publishes the token index through shared memory; the barrier
                // makes it visible to the rest of the block before anyone reads it.
                let zero = ctx.mov_u32_imm(0);
                let is_leader = ctx.setp_eq_u32(thread_id, zero);
                ctx.branch_if_not(is_leader, "skip_store");
                ctx.st_shared_u32(smem_base, token_idx);
                ctx.label("skip_store");
                ctx.bar_sync(0);
                let current_token = ctx.ld_shared_u32(smem_base);

                // Byte offset of this token's row (2 bytes per f16 element).
                let token_offset = ctx.mul_u32(current_token, hidden_size);
                let token_offset_64 = ctx.cvt_u64_u32(token_offset);
                let token_bytes = ctx.mul_u64(token_offset_64, 2);
                let input_addr = ctx.add_u64(input_ptr, token_bytes);

                // Values shared by both element loops, emitted once per token.
                let elements_per_thread_reg = ctx.mov_u32_imm(elements_per_thread);
                let stride = ctx.mov_u32_imm(block_size);
                let hidden_size_reg = ctx.mov_u32_imm(hidden_size);

                // ---- pass 1: per-thread sum of squares -----------------------------------
                let thread_sum = ctx.mov_f32_imm(0.0);
                let i = ctx.mov_u32_imm(0);
                ctx.label("sum_loop");
                let sum_done = ctx.setp_ge_u32(i, elements_per_thread_reg);
                ctx.branch_if(sum_done, "sum_loop_end");
                let elem_base = ctx.mul_u32_reg(i, stride);
                let elem_idx = ctx.add_u32_reg(elem_base, thread_id);
                // Indices only grow with `i`, so the first out-of-range one ends this
                // thread's share of the row.
                let sum_past_end = ctx.setp_ge_u32(elem_idx, hidden_size_reg);
                ctx.branch_if(sum_past_end, "sum_loop_end");
                let elem_idx_64 = ctx.cvt_u64_u32(elem_idx);
                let elem_bytes = ctx.mul_u64(elem_idx_64, 2);
                let elem_addr = ctx.add_u64(input_addr, elem_bytes);
                let val_f16 = ctx.ld_global_f16(elem_addr);
                let val = ctx.cvt_f32_f16(val_f16);
                let sq = ctx.mul_f32(val, val);
                ctx.add_f32_inplace(thread_sum, sq);
                ctx.add_u32_inplace(i, 1);
                ctx.branch("sum_loop");
                ctx.label("sum_loop_end");

                // Shuffle-down tree (offsets 16, 8, 4, 2, 1), then broadcast lane 0's total.
                // NOTE(review): this combines lanes of a single warp only; see the method docs.
                let tmp16 = ctx.shfl_down_f32(thread_sum, 16, 0xFFFF_FFFF);
                ctx.add_f32_inplace(thread_sum, tmp16);
                let tmp8 = ctx.shfl_down_f32(thread_sum, 8, 0xFFFF_FFFF);
                ctx.add_f32_inplace(thread_sum, tmp8);
                let tmp4 = ctx.shfl_down_f32(thread_sum, 4, 0xFFFF_FFFF);
                ctx.add_f32_inplace(thread_sum, tmp4);
                let tmp2 = ctx.shfl_down_f32(thread_sum, 2, 0xFFFF_FFFF);
                ctx.add_f32_inplace(thread_sum, tmp2);
                let tmp1 = ctx.shfl_down_f32(thread_sum, 1, 0xFFFF_FFFF);
                ctx.add_f32_inplace(thread_sum, tmp1);
                let warp_sum = ctx.shfl_idx_f32(thread_sum, 0, 0xFFFF_FFFF);
                ctx.bar_sync(1);

                // inv_rms = rsqrt(sum / hidden_size + eps)
                let hidden_size_float = ctx.cvt_f32_u32(hidden_size_reg);
                let mean = ctx.div_f32(warp_sum, hidden_size_float);
                let eps = ctx.mov_f32_imm(1e-6);
                let mean_eps = ctx.add_f32(mean, eps);
                let inv_rms = ctx.rsqrt_f32(mean_eps);

                // ---- pass 2: scale and store ---------------------------------------------
                let output_row = ctx.add_u64(output_ptr, token_bytes);
                let j = ctx.mov_u32_imm(0);
                ctx.label("norm_loop");
                let norm_done = ctx.setp_ge_u32(j, elements_per_thread_reg);
                ctx.branch_if(norm_done, "norm_loop_end");
                let norm_base = ctx.mul_u32_reg(j, stride);
                let norm_idx = ctx.add_u32_reg(norm_base, thread_id);
                // Same tail guard as the first pass.
                let norm_past_end = ctx.setp_ge_u32(norm_idx, hidden_size_reg);
                ctx.branch_if(norm_past_end, "norm_loop_end");
                let norm_idx_64 = ctx.cvt_u64_u32(norm_idx);
                let norm_bytes = ctx.mul_u64(norm_idx_64, 2);
                let norm_in_addr = ctx.add_u64(input_addr, norm_bytes);
                let in_val_f16 = ctx.ld_global_f16(norm_in_addr);
                let in_val = ctx.cvt_f32_f16(in_val_f16);
                let normed = ctx.mul_f32(in_val, inv_rms);
                let out_val_f16 = ctx.cvt_f16_f32(normed);
                let output_final = ctx.add_u64(output_row, norm_bytes);
                ctx.st_global_f16(output_final, out_val_f16);
                ctx.add_u32_inplace(j, 1);
                ctx.branch("norm_loop");
                ctx.label("norm_loop_end");

                // Keep the block together before thread 0 overwrites the shared token slot.
                ctx.bar_sync(2);
                ctx.add_u32_inplace(iteration, 1);
                ctx.branch("work_loop");

                ctx.label("exit");
                ctx.ret();
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the configuration every test in this module exercises
    /// (hidden size 3584, 28 layers, maximum sequence length 2048).
    fn sample_kernel() -> PersistentDecoderKernel {
        PersistentDecoderKernel::new(3584, 28, 2048)
    }

    /// The kernel reports its fixed entry-point name.
    #[test]
    fn test_persistent_decoder_name() {
        assert_eq!(sample_kernel().name(), "persistent_decoder");
    }

    /// Emitted PTX is non-empty and declares the visible entry point.
    #[test]
    fn test_persistent_decoder_generates_ptx() {
        let emitted = sample_kernel().emit_ptx();
        assert!(!emitted.is_empty());
        assert!(emitted.contains(".visible .entry persistent_decoder"));
    }

    /// The outer work loop label appears in the emitted PTX.
    #[test]
    fn test_persistent_decoder_has_work_loop() {
        let emitted = sample_kernel().emit_ptx();
        assert!(emitted.contains("work_loop"));
    }

    /// Both the block index and the block count registers are referenced.
    #[test]
    fn test_persistent_decoder_has_block_distribution() {
        let emitted = sample_kernel().emit_ptx();
        for special_reg in ["%ctaid", "%nctaid"] {
            assert!(emitted.contains(special_reg));
        }
    }

    /// At least one block barrier is emitted.
    #[test]
    fn test_persistent_decoder_has_barriers() {
        let emitted = sample_kernel().emit_ptx();
        assert!(emitted.contains("bar.sync"));
    }

    /// The constructor stores its three arguments in the matching fields.
    #[test]
    fn test_persistent_decoder_qwen3b_config() {
        let PersistentDecoderKernel {
            hidden_size,
            num_layers,
            max_seq_len,
            ..
        } = sample_kernel();
        assert_eq!(hidden_size, 3584);
        assert_eq!(num_layers, 28);
        assert_eq!(max_seq_len, 2048);
    }

    /// Shared memory is 4 bytes plus 4 bytes per hidden element.
    #[test]
    fn test_persistent_decoder_shared_memory() {
        let requested = sample_kernel().shared_memory_bytes();
        assert_eq!(requested, 4 + 3584 * 4);
    }

    /// The work loop needs at least two barriers in the emitted PTX.
    #[test]
    fn test_persistent_decoder_barrier_structure() {
        let emitted = sample_kernel().emit_ptx();
        let barrier_count = emitted.matches("bar.sync").count();
        assert!(
            barrier_count >= 2,
            "Expected at least 2 barriers for work loop sync, found: {}",
            barrier_count
        );
    }
}