#![allow(clippy::similar_names)]
#![allow(clippy::too_many_lines)]
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
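/// Single-query (incremental, decode-step) attention over a cached K/V
/// sequence. One CTA processes one query head, with 32 lanes cooperating
/// on a head_dim-wide row (head_dim up to 128) using an online-softmax
/// accumulation over the sequence.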
#[derive(Debug, Clone)]
pub struct IncrementalAttentionKernel {
pub max_seq_len: u32,
pub head_dim: u32,
pub num_heads: u32,
pub num_kv_heads: u32,
pub scale: f32,
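/// When true, the sequence length is loaded from a device buffer through a
/// `seq_len_ptr` parameter rather than passed as a `seq_len` kernel parameter.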
pub indirect_seq_len: bool,
}
impl IncrementalAttentionKernel {
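/// Builds a standard multi-head attention kernel (one K/V head per query head).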
#[must_use]
pub fn new(max_seq_len: u32, head_dim: u32, num_heads: u32) -> Self {
Self::with_gqa(max_seq_len, head_dim, num_heads, num_heads)
}
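/// Builds a grouped-query attention (GQA) kernel in which `num_kv_heads`
/// K/V heads are shared across `num_heads` query heads.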
#[must_use]
pub fn with_gqa(max_seq_len: u32, head_dim: u32, num_heads: u32, num_kv_heads: u32) -> Self {
Self {
max_seq_len,
head_dim,
num_heads,
num_kv_heads,
scale: 1.0 / (head_dim as f32).sqrt(),
indirect_seq_len: false,
}
}
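/// Enables indirect sequence-length mode (see `indirect_seq_len`); this also
/// switches the kernel name to `incremental_attention_indirect`.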
#[must_use]
pub fn with_indirect_seq_len(mut self, indirect: bool) -> Self {
self.indirect_seq_len = indirect;
self
}
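/// Returns `true` when the number of K/V heads differs from the number of query heads.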
#[must_use]
pub fn is_gqa(&self) -> bool {
self.num_kv_heads != self.num_heads
}
}
impl Kernel for IncrementalAttentionKernel {
fn name(&self) -> &str {
if self.indirect_seq_len {
"incremental_attention_indirect"
} else {
"incremental_attention"
}
}
fn build_ptx(&self) -> PtxKernel {
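// Copy configuration into locals so the `move` closure below captures
// plain values rather than `self`.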
let head_dim = self.head_dim;
let scale = self.scale;
let max_seq_len = self.max_seq_len;
let num_heads = self.num_heads;
let num_kv_heads = self.num_kv_heads;
let indirect = self.indirect_seq_len;
let kernel_name = if indirect {
"incremental_attention_indirect"
} else {
"incremental_attention"
};
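// Parameters: Q/K/V/out pointers, plus either a `seq_len` value or a
// pointer it is loaded from.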
let mut builder = PtxKernel::new(kernel_name)
.param(PtxType::U64, "q_ptr")
.param(PtxType::U64, "k_ptr")
.param(PtxType::U64, "v_ptr")
.param(PtxType::U64, "out_ptr");
builder = if indirect {
builder.param(PtxType::U64, "seq_len_ptr")
} else {
builder.param(PtxType::U32, "seq_len")
};
builder
.shared_memory(0)
.build(move |ctx| {
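// ctaid.x selects the query head; tid.x is the lane within a 32-thread warp.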
let q_head_idx = ctx.special_reg(PtxReg::CtaIdX);
let lane_id = ctx.special_reg(PtxReg::TidX);
let seq_len = if indirect {
let seq_len_ptr = ctx.load_param_u64("seq_len_ptr");
ctx.ld_global_u32(seq_len_ptr)
} else {
ctx.load_param_u32("seq_len")
};
let q_ptr = ctx.load_param_u64("q_ptr");
let k_ptr = ctx.load_param_u64("k_ptr");
let v_ptr = ctx.load_param_u64("v_ptr");
let out_ptr = ctx.load_param_u64("out_ptr");
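// Per-head base pointers: Q/out rows are head_dim floats apart, K/V heads
// are max_seq_len * head_dim floats apart, and the GQA mapping
// q_head_idx * num_kv_heads / num_heads picks the shared K/V head.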
let four = ctx.mov_u32_imm(4);
let head_dim_u32 = ctx.mov_u32_imm(head_dim);
let q_head_off = ctx.mul_lo_u32(q_head_idx, head_dim_u32);
let q_head_off_bytes = ctx.mul_wide_u32_reg(q_head_off, four);
let q_head_ptr = ctx.add_u64(q_ptr, q_head_off_bytes);
let out_head_ptr = ctx.add_u64(out_ptr, q_head_off_bytes);
let kv_head_idx = ctx.mul_u32(q_head_idx, num_kv_heads);
let kv_head_idx = ctx.div_u32(kv_head_idx, num_heads);
let kv_stride = ctx.mov_u32_imm(max_seq_len * head_dim);
let kv_head_off = ctx.mul_lo_u32(kv_head_idx, kv_stride);
let kv_head_off_bytes = ctx.mul_wide_u32_reg(kv_head_off, four);
let k_head_ptr = ctx.add_u64(k_ptr, kv_head_off_bytes);
let v_head_ptr = ctx.add_u64(v_ptr, kv_head_off_bytes);
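// Each lane loads up to four Q elements (lane, lane+32, lane+64, lane+96);
// slots past head_dim read 0.0 through the predicated loads.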
let q0_off_bytes = ctx.mul_wide_u32_reg(lane_id, four);
let q0_addr = ctx.add_u64(q_head_ptr, q0_off_bytes);
let in_bounds0 = ctx.setp_lt_u32(lane_id, head_dim_u32);
let q0 = ctx.ld_global_f32_predicated(q0_addr, in_bounds0, 0.0);
let lane_plus_32 = ctx.add_u32(lane_id, 32);
let q1_off_bytes = ctx.mul_wide_u32_reg(lane_plus_32, four);
let q1_addr = ctx.add_u64(q_head_ptr, q1_off_bytes);
let in_bounds1 = ctx.setp_lt_u32(lane_plus_32, head_dim_u32);
let q1 = ctx.ld_global_f32_predicated(q1_addr, in_bounds1, 0.0);
let lane_plus_64 = ctx.add_u32(lane_id, 64);
let q2_off_bytes = ctx.mul_wide_u32_reg(lane_plus_64, four);
let q2_addr = ctx.add_u64(q_head_ptr, q2_off_bytes);
let in_bounds2 = ctx.setp_lt_u32(lane_plus_64, head_dim_u32);
let q2 = ctx.ld_global_f32_predicated(q2_addr, in_bounds2, 0.0);
let lane_plus_96 = ctx.add_u32(lane_id, 96);
let q3_off_bytes = ctx.mul_wide_u32_reg(lane_plus_96, four);
let q3_addr = ctx.add_u64(q_head_ptr, q3_off_bytes);
let in_bounds3 = ctx.setp_lt_u32(lane_plus_96, head_dim_u32);
let q3 = ctx.ld_global_f32_predicated(q3_addr, in_bounds3, 0.0);
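// Online-softmax running state: output accumulators, running max score,
// and running sum of exponentials.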
let out0 = ctx.mov_f32_imm(0.0);
let out1 = ctx.mov_f32_imm(0.0);
let out2 = ctx.mov_f32_imm(0.0);
let out3 = ctx.mov_f32_imm(0.0);
let max_score = ctx.mov_f32_imm(f32::NEG_INFINITY);
let sum_exp = ctx.mov_f32_imm(0.0);
let log2e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
let scale_reg = ctx.mov_f32_imm(scale);
let pos = ctx.mov_u32_imm(0);
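// Walk the cached sequence one position at a time.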
ctx.label("seq_loop");
let loop_cond = ctx.setp_lt_u32(pos, seq_len);
ctx.branch_if_not(loop_cond, "seq_loop_end");
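// Load this position's K row with the same four-slot-per-lane layout as Q.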
let k_pos_off = ctx.mul_lo_u32(pos, head_dim_u32);
let k0_elem_off = ctx.add_u32_reg(k_pos_off, lane_id);
let k0_off_bytes = ctx.mul_wide_u32_reg(k0_elem_off, four);
let k0_addr = ctx.add_u64(k_head_ptr, k0_off_bytes);
let k0 = ctx.ld_global_f32_predicated(k0_addr, in_bounds0, 0.0);
let k1_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_32);
let k1_off_bytes = ctx.mul_wide_u32_reg(k1_elem_off, four);
let k1_addr = ctx.add_u64(k_head_ptr, k1_off_bytes);
let k1 = ctx.ld_global_f32_predicated(k1_addr, in_bounds1, 0.0);
let k2_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_64);
let k2_off_bytes = ctx.mul_wide_u32_reg(k2_elem_off, four);
let k2_addr = ctx.add_u64(k_head_ptr, k2_off_bytes);
let k2 = ctx.ld_global_f32_predicated(k2_addr, in_bounds2, 0.0);
let k3_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_96);
let k3_off_bytes = ctx.mul_wide_u32_reg(k3_elem_off, four);
let k3_addr = ctx.add_u64(k_head_ptr, k3_off_bytes);
let k3 = ctx.ld_global_f32_predicated(k3_addr, in_bounds3, 0.0);
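// Per-lane partial dot product across the four element slots.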
let dot_partial = ctx.mul_f32(q0, k0);
let dot_partial = ctx.fma_f32(q1, k1, dot_partial);
let dot_partial = ctx.fma_f32(q2, k2, dot_partial);
let dot_partial = ctx.fma_f32(q3, k3, dot_partial);
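// Warp-shuffle tree reduction of the 32 partial sums, then broadcast the
// full Q.K dot product from lane 0 to every lane.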
let dot16 = ctx.shfl_down_f32(dot_partial, 16, 0xFFFF_FFFF);
let dot_partial = ctx.add_f32(dot_partial, dot16);
let dot8 = ctx.shfl_down_f32(dot_partial, 8, 0xFFFF_FFFF);
let dot_partial = ctx.add_f32(dot_partial, dot8);
let dot4 = ctx.shfl_down_f32(dot_partial, 4, 0xFFFF_FFFF);
let dot_partial = ctx.add_f32(dot_partial, dot4);
let dot2 = ctx.shfl_down_f32(dot_partial, 2, 0xFFFF_FFFF);
let dot_partial = ctx.add_f32(dot_partial, dot2);
let dot1 = ctx.shfl_down_f32(dot_partial, 1, 0xFFFF_FFFF);
let dot_reduced = ctx.add_f32(dot_partial, dot1);
let dot_broadcast = ctx.shfl_idx_f32(dot_reduced, 0, 0xFFFF_FFFF);
let score = ctx.mul_f32(dot_broadcast, scale_reg);
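// Online softmax update: correction = exp(old_max - new_max) rescales
// previous accumulators, exp_score = exp(score - new_max); both are
// evaluated as ex2(x * log2(e)).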
let new_max = ctx.max_f32(max_score, score);
let max_diff = ctx.sub_f32(max_score, new_max);
let max_diff_scaled = ctx.mul_f32(max_diff, log2e);
let correction = ctx.ex2_f32(max_diff_scaled);
let score_diff = ctx.sub_f32(score, new_max);
let score_diff_scaled = ctx.mul_f32(score_diff, log2e);
let exp_score = ctx.ex2_f32(score_diff_scaled);
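// Load this position's V row.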
let v0_elem_off = ctx.add_u32_reg(k_pos_off, lane_id);
let v0_off_bytes = ctx.mul_wide_u32_reg(v0_elem_off, four);
let v0_addr = ctx.add_u64(v_head_ptr, v0_off_bytes);
let v0 = ctx.ld_global_f32_predicated(v0_addr, in_bounds0, 0.0);
let v1_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_32);
let v1_off_bytes = ctx.mul_wide_u32_reg(v1_elem_off, four);
let v1_addr = ctx.add_u64(v_head_ptr, v1_off_bytes);
let v1 = ctx.ld_global_f32_predicated(v1_addr, in_bounds1, 0.0);
let v2_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_64);
let v2_off_bytes = ctx.mul_wide_u32_reg(v2_elem_off, four);
let v2_addr = ctx.add_u64(v_head_ptr, v2_off_bytes);
let v2 = ctx.ld_global_f32_predicated(v2_addr, in_bounds2, 0.0);
let v3_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_96);
let v3_off_bytes = ctx.mul_wide_u32_reg(v3_elem_off, four);
let v3_addr = ctx.add_u64(v_head_ptr, v3_off_bytes);
let v3 = ctx.ld_global_f32_predicated(v3_addr, in_bounds3, 0.0);
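// Commit the new running max, rescale the running sum and output
// accumulators by the correction factor, then add this position's
// softmax-weighted V contribution.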
ctx.max_f32_inplace(max_score, score);
ctx.mul_f32_inplace(sum_exp, correction);
ctx.add_f32_inplace(sum_exp, exp_score);
ctx.mul_f32_inplace(out0, correction);
ctx.fma_f32_inplace(out0, exp_score, v0);
ctx.mul_f32_inplace(out1, correction);
ctx.fma_f32_inplace(out1, exp_score, v1);
ctx.mul_f32_inplace(out2, correction);
ctx.fma_f32_inplace(out2, exp_score, v2);
ctx.mul_f32_inplace(out3, correction);
ctx.fma_f32_inplace(out3, exp_score, v3);
ctx.add_u32_inplace(pos, 1);
ctx.branch("seq_loop");
ctx.label("seq_loop_end");
let one = ctx.mov_f32_imm(1.0);
let inv_sum = ctx.div_f32(one, sum_exp);
ctx.mul_f32_inplace(out0, inv_sum);
ctx.mul_f32_inplace(out1, inv_sum);
ctx.mul_f32_inplace(out2, inv_sum);
ctx.mul_f32_inplace(out3, inv_sum);
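// Store the four output elements per lane, skipping slots beyond head_dim.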
let out0_addr = ctx.add_u64(out_head_ptr, q0_off_bytes);
ctx.branch_if_not(in_bounds0, "skip_store0");
ctx.st_global_f32(out0_addr, out0);
ctx.label("skip_store0");
let out1_addr = ctx.add_u64(out_head_ptr, q1_off_bytes);
ctx.branch_if_not(in_bounds1, "skip_store1");
ctx.st_global_f32(out1_addr, out1);
ctx.label("skip_store1");
let out2_addr = ctx.add_u64(out_head_ptr, q2_off_bytes);
ctx.branch_if_not(in_bounds2, "skip_store2");
ctx.st_global_f32(out2_addr, out2);
ctx.label("skip_store2");
let out3_addr = ctx.add_u64(out_head_ptr, q3_off_bytes);
ctx.branch_if_not(in_bounds3, "skip_store3");
ctx.st_global_f32(out3_addr, out3);
ctx.label("skip_store3");
ctx.ret();
})
}
}