use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Configuration for the fused causal prefill-attention PTX kernel.
///
/// `head_dim` is the per-head embedding width and must be a non-zero
/// multiple of 32 so each of the 32 warp lanes owns an equal, non-empty
/// slice of the head vector. `heads_per_kv` is the GQA group size
/// (number of query heads sharing one KV head; 1 means plain MHA).
#[derive(Debug, Clone)]
pub struct PrefillAttentionKernel {
    /// Per-head embedding dimension; non-zero multiple of 32.
    pub head_dim: u32,
    /// Query heads per KV head (grouped-query attention factor).
    pub heads_per_kv: u32,
}

impl PrefillAttentionKernel {
    /// Creates a validated kernel configuration.
    ///
    /// # Panics
    /// Panics if `head_dim` is zero or not a multiple of 32, or if
    /// `heads_per_kv` is zero.
    #[must_use]
    pub fn new(head_dim: u32, heads_per_kv: u32) -> Self {
        // A bare `head_dim % 32 == 0` check would accept head_dim == 0,
        // which produces a degenerate kernel (zero elements per lane),
        // so reject zero explicitly. The message is kept unchanged for
        // callers matching on the panic text.
        assert!(
            head_dim > 0 && head_dim % 32 == 0,
            "head_dim must be multiple of 32"
        );
        assert!(heads_per_kv > 0, "heads_per_kv must be > 0");
        Self {
            head_dim,
            heads_per_kv,
        }
    }
}
impl Kernel for PrefillAttentionKernel {
    /// Exported PTX entry-point name.
    fn name(&self) -> &str {
        "fused_prefill_attention_causal"
    }

    /// Builds the PTX for fused causal prefill attention.
    ///
    /// Algorithm: for each query row `i` (outer loop), stream key/value
    /// rows `j = 0..=i` (inner loop — the causal mask is the loop bound)
    /// and maintain a numerically stable online softmax: running max
    /// `m_i`, running normalizer `l_i`, and an unnormalized output
    /// accumulator `o_regs`, rescaling the old state by
    /// `exp(m_old - m_new)` whenever the max grows. The row output is
    /// divided by `l_i` only once, at the end of the row.
    ///
    /// Launch mapping as read from the special registers below:
    /// `ctaid.x` selects the query head, `tid.x` is the lane (0..31);
    /// each lane owns `head_dim / 32` contiguous f32 elements of the
    /// head vector.
    ///
    /// NOTE(review): strides appear to be in elements (they are scaled
    /// by 4 to bytes here), and the causal bound `j <= i` implies Q and
    /// K/V index the same token range (no KV-cache offset) — confirm
    /// both against the launch code.
    fn build_ptx(&self) -> PtxKernel {
        let head_dim = self.head_dim;
        let heads_per_kv = self.heads_per_kv;
        // Elements per lane for a 32-lane warp; non-zero because
        // head_dim is asserted to be a multiple of 32 at construction.
        let epl = head_dim / 32;
        // Standard softmax temperature 1/sqrt(d), applied to every score.
        let scale = 1.0 / (head_dim as f32).sqrt();
        PtxKernel::new("fused_prefill_attention_causal")
            .param(PtxType::U64, "q_ptr")
            .param(PtxType::U64, "k_ptr")
            .param(PtxType::U64, "v_ptr")
            .param(PtxType::U64, "o_ptr")
            .param(PtxType::U32, "m_param")
            .param(PtxType::U32, "q_stride")
            .param(PtxType::U32, "kv_stride")
            .param(PtxType::U32, "num_q_heads")
            .build(|ctx| {
                let tid = ctx.special_reg(PtxReg::TidX);
                let head_idx = ctx.special_reg(PtxReg::CtaIdX);
                let q_ptr = ctx.load_param_u64("q_ptr");
                let k_ptr = ctx.load_param_u64("k_ptr");
                let v_ptr = ctx.load_param_u64("v_ptr");
                let o_ptr = ctx.load_param_u64("o_ptr");
                let m_param = ctx.load_param_u32("m_param");
                let q_stride = ctx.load_param_u32("q_stride");
                let kv_stride = ctx.load_param_u32("kv_stride");
                let num_q_heads = ctx.load_param_u32("num_q_heads");
                // Guard: CTAs past the last head exit immediately.
                let head_valid = ctx.setp_lt_u32(head_idx, num_q_heads);
                ctx.branch_if_not(head_valid, "exit");
                // GQA: several query heads share one KV head.
                let kv_head = ctx.div_u32(head_idx, heads_per_kv);
                // Byte offsets of this head's slice inside a row (f32 => *4).
                let head_dim_reg = ctx.mov_u32_imm(head_dim);
                let q_head_elem_off = ctx.mul_u32_reg(head_idx, head_dim_reg);
                let q_head_byte_off = ctx.mul_wide_u32(q_head_elem_off, 4);
                let kv_head_elem_off = ctx.mul_u32_reg(kv_head, head_dim_reg);
                let kv_head_byte_off = ctx.mul_wide_u32(kv_head_elem_off, 4);
                // First element index this lane owns within the head vector
                // (lanes hold contiguous runs of `epl` elements).
                let lane_base = ctx.mul_u32(tid, epl);
                let scale_reg = ctx.mov_f32_imm(scale);
                // exp(x) is emitted as ex2(x * log2(e)).
                let log2e_reg = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
                let zero_f32 = ctx.mov_f32_imm(0.0);
                let neg_inf = ctx.mov_f32_imm(f32::NEG_INFINITY);
                let lane_byte_off = ctx.mul_wide_u32(lane_base, 4);
                // Per-lane register files: the query slice and the output
                // accumulator, re-initialized at the top of each row.
                let mut q_regs = Vec::with_capacity(epl as usize);
                let mut o_regs = Vec::with_capacity(epl as usize);
                for _ in 0..epl {
                    q_regs.push(ctx.mov_f32_imm(0.0));
                    o_regs.push(ctx.mov_f32_imm(0.0));
                }
                // Online-softmax state: running max and running exp-sum.
                let m_i = ctx.mov_f32_imm(f32::NEG_INFINITY);
                let l_i = ctx.mov_f32_imm(0.0);
                // Outer loop over query rows [0, m_param).
                let row_i = ctx.mov_u32_imm(0);
                ctx.label("row_start");
                let row_done = ctx.setp_ge_u32(row_i, m_param);
                ctx.branch_if(row_done, "row_end");
                // Reset per-row state.
                ctx.mov_f32_reg(m_i, neg_inf);
                ctx.mov_f32_reg(l_i, zero_f32);
                for o in &o_regs {
                    ctx.mov_f32_reg(*o, zero_f32);
                }
                // Load this lane's slice of the query row into registers.
                let q_row_off = ctx.mul_wide_u32_reg(row_i, q_stride);
                let q_row_off_bytes = ctx.mul_u64(q_row_off, 4);
                let q_row_base = ctx.add_u64(q_ptr, q_row_off_bytes);
                let q_row_base = ctx.add_u64(q_row_base, q_head_byte_off);
                let q_lane_base = ctx.add_u64(q_row_base, lane_byte_off);
                for e in 0..epl as usize {
                    let addr = if e == 0 {
                        q_lane_base
                    } else {
                        let off = ctx.mov_u64_imm((e * 4) as u64);
                        ctx.add_u64(q_lane_base, off)
                    };
                    let val = ctx.ld_global_f32(addr);
                    ctx.mov_f32_reg(q_regs[e], val);
                }
                // Causal mask: row i attends only to columns 0..=i.
                let j_limit = ctx.add_u32(row_i, 1);
                let col_j = ctx.mov_u32_imm(0);
                ctx.label("col_start");
                let col_done = ctx.setp_ge_u32(col_j, j_limit);
                ctx.branch_if(col_done, "col_end");
                // Per-lane partial dot product q . k over this lane's slice.
                let k_row_off = ctx.mul_wide_u32_reg(col_j, kv_stride);
                let k_row_off_bytes = ctx.mul_u64(k_row_off, 4);
                let k_row_base = ctx.add_u64(k_ptr, k_row_off_bytes);
                let k_row_base = ctx.add_u64(k_row_base, kv_head_byte_off);
                let k_lane_base = ctx.add_u64(k_row_base, lane_byte_off);
                let dot_partial = ctx.mov_f32_imm(0.0);
                for e in 0..epl as usize {
                    let k_addr = if e == 0 {
                        k_lane_base
                    } else {
                        let off = ctx.mov_u64_imm((e * 4) as u64);
                        ctx.add_u64(k_lane_base, off)
                    };
                    let k_val = ctx.ld_global_f32(k_addr);
                    ctx.fma_f32_inplace(dot_partial, q_regs[e], k_val);
                }
                // Warp tree reduction (shfl.down by 16, 8, 4, 2, 1); lane 0
                // then holds the full sum, broadcast to all lanes below.
                let s16 = ctx.shfl_down_f32(dot_partial, 16, 31);
                ctx.add_f32_inplace(dot_partial, s16);
                let s8 = ctx.shfl_down_f32(dot_partial, 8, 31);
                ctx.add_f32_inplace(dot_partial, s8);
                let s4 = ctx.shfl_down_f32(dot_partial, 4, 31);
                ctx.add_f32_inplace(dot_partial, s4);
                let s2 = ctx.shfl_down_f32(dot_partial, 2, 31);
                ctx.add_f32_inplace(dot_partial, s2);
                let s1 = ctx.shfl_down_f32(dot_partial, 1, 31);
                ctx.add_f32_inplace(dot_partial, s1);
                let score = ctx.shfl_idx_f32(dot_partial, 0, 31);
                let s_scaled = ctx.mul_f32(score, scale_reg);
                // Online-softmax update:
                //   m_new      = max(m_i, s)
                //   correction = exp(m_i - m_new)   (rescales old state)
                //   p          = exp(s - m_new)
                //   l_i        = l_i * correction + p
                let m_new = ctx.max_f32(m_i, s_scaled);
                let m_diff = ctx.sub_f32(m_i, m_new);
                let m_diff_log2 = ctx.mul_f32(m_diff, log2e_reg);
                let correction = ctx.ex2_f32(m_diff_log2);
                let s_diff = ctx.sub_f32(s_scaled, m_new);
                let s_diff_log2 = ctx.mul_f32(s_diff, log2e_reg);
                let p_val = ctx.ex2_f32(s_diff_log2);
                ctx.mul_f32_inplace(l_i, correction);
                ctx.add_f32_inplace(l_i, p_val);
                ctx.mov_f32_reg(m_i, m_new);
                // Accumulate o = o * correction + p * v on this lane's slice.
                let v_row_off = ctx.mul_wide_u32_reg(col_j, kv_stride);
                let v_row_off_bytes = ctx.mul_u64(v_row_off, 4);
                let v_row_base = ctx.add_u64(v_ptr, v_row_off_bytes);
                let v_row_base = ctx.add_u64(v_row_base, kv_head_byte_off);
                let v_lane_base = ctx.add_u64(v_row_base, lane_byte_off);
                for e in 0..epl as usize {
                    ctx.mul_f32_inplace(o_regs[e], correction);
                    let v_addr = if e == 0 {
                        v_lane_base
                    } else {
                        let off = ctx.mov_u64_imm((e * 4) as u64);
                        ctx.add_u64(v_lane_base, off)
                    };
                    let v_val = ctx.ld_global_f32(v_addr);
                    ctx.fma_f32_inplace(o_regs[e], p_val, v_val);
                }
                ctx.add_u32_inplace(col_j, 1);
                ctx.branch("col_start");
                ctx.label("col_end");
                // Normalize by l_i and store this lane's output slice.
                let o_row_off = ctx.mul_wide_u32_reg(row_i, q_stride);
                let o_row_off_bytes = ctx.mul_u64(o_row_off, 4);
                let o_row_base = ctx.add_u64(o_ptr, o_row_off_bytes);
                let o_row_base = ctx.add_u64(o_row_base, q_head_byte_off);
                let o_lane_base = ctx.add_u64(o_row_base, lane_byte_off);
                for e in 0..epl as usize {
                    ctx.div_f32_inplace(o_regs[e], l_i);
                    let o_addr = if e == 0 {
                        o_lane_base
                    } else {
                        let off = ctx.mov_u64_imm((e * 4) as u64);
                        ctx.add_u64(o_lane_base, off)
                    };
                    ctx.st_global_f32(o_addr, o_regs[e]);
                }
                ctx.add_u32_inplace(row_i, 1);
                ctx.branch("row_start");
                ctx.label("row_end");
                ctx.label("exit");
                ctx.ret();
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor used by the tests below.
    fn kernel(head_dim: u32, heads_per_kv: u32) -> PrefillAttentionKernel {
        PrefillAttentionKernel::new(head_dim, heads_per_kv)
    }

    #[test]
    fn test_prefill_attention_kernel_name() {
        assert_eq!(kernel(64, 6).name(), "fused_prefill_attention_causal");
    }

    #[test]
    fn test_prefill_attention_ptx_generation() {
        let ptx = kernel(64, 6).emit_ptx();
        // The entry-point name and the declared parameters must appear
        // verbatim in the emitted PTX.
        for needle in [
            "fused_prefill_attention_causal",
            ".param .u64 q_ptr",
            ".param .u32 m_param",
            ".param .u32 num_q_heads",
        ] {
            assert!(ptx.contains(needle));
        }
    }

    #[test]
    fn test_prefill_attention_head_dim_128() {
        // Wider head dimension still emits the same entry point.
        let ptx = kernel(128, 6).emit_ptx();
        assert!(ptx.contains("fused_prefill_attention_causal"));
    }

    #[test]
    fn test_prefill_attention_mha() {
        // heads_per_kv == 1 is plain multi-head attention.
        let ptx = kernel(64, 1).emit_ptx();
        assert!(ptx.contains("fused_prefill_attention_causal"));
    }

    #[test]
    #[should_panic(expected = "head_dim must be multiple of 32")]
    fn test_prefill_attention_invalid_head_dim() {
        let _ = kernel(48, 6);
    }
}