#![allow(clippy::similar_names)]
#![allow(clippy::too_many_lines)]
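//! Decode-time attention kernels emitted through the PTX builder:
//! single-warp, multi-warp, batched, and two-phase flash-decoding
//! variants, all built around an online (streaming) softmax.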
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl, PtxMemory};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
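/// Decode-step attention for a single query token, one warp per head.
///
/// Block x selects the query head; the 32 lanes of the warp stream over
/// that head's cached K/V rows and maintain an online softmax. Each lane
/// holds four Q/output elements, so `head_dim` values up to 128 are
/// supported. With `indirect_seq_len`, the current sequence length is
/// read through a device pointer instead of being passed as a kernel
/// argument.
///
/// A minimal usage sketch (marked `ignore` because the crate path of
/// these types is not fixed here):
///
/// ```ignore
/// let kernel = IncrementalAttentionKernel::new(4096, 128, 32)
///     .with_indirect_seq_len(true);
/// assert_eq!(kernel.name(), "incremental_attention_indirect");
/// let ptx = kernel.build_ptx();
/// ```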
#[derive(Debug, Clone)]
pub struct IncrementalAttentionKernel {
pub max_seq_len: u32,
pub head_dim: u32,
pub num_heads: u32,
pub num_kv_heads: u32,
pub scale: f32,
pub indirect_seq_len: bool,
}
impl IncrementalAttentionKernel {
#[must_use]
pub fn new(max_seq_len: u32, head_dim: u32, num_heads: u32) -> Self {
Self::with_gqa(max_seq_len, head_dim, num_heads, num_heads)
}
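    /// Grouped-query attention constructor: `num_kv_heads` may be smaller
    /// than `num_heads`, and query head `q` reads KV head
    /// `q * num_kv_heads / num_heads`. For example, with 32 query heads
    /// over 8 KV heads, query heads 0..=3 all map to KV head 0.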
#[must_use]
pub fn with_gqa(max_seq_len: u32, head_dim: u32, num_heads: u32, num_kv_heads: u32) -> Self {
Self {
max_seq_len,
head_dim,
num_heads,
num_kv_heads,
scale: 1.0 / (head_dim as f32).sqrt(),
indirect_seq_len: false,
}
}
#[must_use]
pub fn with_indirect_seq_len(mut self, indirect: bool) -> Self {
self.indirect_seq_len = indirect;
self
}
#[must_use]
pub fn is_gqa(&self) -> bool {
self.num_kv_heads != self.num_heads
}
}
impl Kernel for IncrementalAttentionKernel {
fn name(&self) -> &str {
if self.indirect_seq_len {
"incremental_attention_indirect"
} else {
"incremental_attention"
}
}
fn build_ptx(&self) -> PtxKernel {
let head_dim = self.head_dim;
let scale = self.scale;
let max_seq_len = self.max_seq_len;
let num_heads = self.num_heads;
let num_kv_heads = self.num_kv_heads;
let indirect = self.indirect_seq_len;
let kernel_name = if indirect {
"incremental_attention_indirect"
} else {
"incremental_attention"
};
let mut builder = PtxKernel::new(kernel_name)
.param(PtxType::U64, "q_ptr")
.param(PtxType::U64, "k_ptr")
.param(PtxType::U64, "v_ptr")
.param(PtxType::U64, "out_ptr");
builder = if indirect {
builder.param(PtxType::U64, "seq_len_ptr")
} else {
builder.param(PtxType::U32, "seq_len")
};
builder
            .shared_memory(0)
            .build(move |ctx| {
let q_head_idx = ctx.special_reg(PtxReg::CtaIdX);
let lane_id = ctx.special_reg(PtxReg::TidX);
let seq_len = if indirect {
let seq_len_ptr = ctx.load_param_u64("seq_len_ptr");
ctx.ld_global_u32(seq_len_ptr)
} else {
ctx.load_param_u32("seq_len")
};
let q_ptr = ctx.load_param_u64("q_ptr");
let k_ptr = ctx.load_param_u64("k_ptr");
let v_ptr = ctx.load_param_u64("v_ptr");
let out_ptr = ctx.load_param_u64("out_ptr");
let four = ctx.mov_u32_imm(4);
let head_dim_u32 = ctx.mov_u32_imm(head_dim);
let q_head_off = ctx.mul_lo_u32(q_head_idx, head_dim_u32);
let q_head_off_bytes = ctx.mul_wide_u32_reg(q_head_off, four);
let q_head_ptr = ctx.add_u64(q_ptr, q_head_off_bytes);
let out_head_ptr = ctx.add_u64(out_ptr, q_head_off_bytes);
let kv_head_idx = ctx.mul_u32(q_head_idx, num_kv_heads);
let kv_head_idx = ctx.div_u32(kv_head_idx, num_heads);
let kv_stride = ctx.mov_u32_imm(max_seq_len * head_dim);
let kv_head_off = ctx.mul_lo_u32(kv_head_idx, kv_stride);
let kv_head_off_bytes = ctx.mul_wide_u32_reg(kv_head_off, four);
let k_head_ptr = ctx.add_u64(k_ptr, kv_head_off_bytes);
let v_head_ptr = ctx.add_u64(v_ptr, kv_head_off_bytes);
let q0_off_bytes = ctx.mul_wide_u32_reg(lane_id, four);
let q0_addr = ctx.add_u64(q_head_ptr, q0_off_bytes);
let in_bounds0 = ctx.setp_lt_u32(lane_id, head_dim_u32);
let q0 = ctx.ld_global_f32_predicated(q0_addr, in_bounds0, 0.0);
let lane_plus_32 = ctx.add_u32(lane_id, 32);
let q1_off_bytes = ctx.mul_wide_u32_reg(lane_plus_32, four);
let q1_addr = ctx.add_u64(q_head_ptr, q1_off_bytes);
let in_bounds1 = ctx.setp_lt_u32(lane_plus_32, head_dim_u32);
let q1 = ctx.ld_global_f32_predicated(q1_addr, in_bounds1, 0.0);
let lane_plus_64 = ctx.add_u32(lane_id, 64);
let q2_off_bytes = ctx.mul_wide_u32_reg(lane_plus_64, four);
let q2_addr = ctx.add_u64(q_head_ptr, q2_off_bytes);
let in_bounds2 = ctx.setp_lt_u32(lane_plus_64, head_dim_u32);
let q2 = ctx.ld_global_f32_predicated(q2_addr, in_bounds2, 0.0);
let lane_plus_96 = ctx.add_u32(lane_id, 96);
let q3_off_bytes = ctx.mul_wide_u32_reg(lane_plus_96, four);
let q3_addr = ctx.add_u64(q_head_ptr, q3_off_bytes);
let in_bounds3 = ctx.setp_lt_u32(lane_plus_96, head_dim_u32);
let q3 = ctx.ld_global_f32_predicated(q3_addr, in_bounds3, 0.0);
let out0 = ctx.mov_f32_imm(0.0);
let out1 = ctx.mov_f32_imm(0.0);
let out2 = ctx.mov_f32_imm(0.0);
let out3 = ctx.mov_f32_imm(0.0);
let max_score = ctx.mov_f32_imm(f32::NEG_INFINITY);
let sum_exp = ctx.mov_f32_imm(0.0);
let log2e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
let scale_reg = ctx.mov_f32_imm(scale);
let pos = ctx.mov_u32_imm(0);
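                // Online softmax over the KV cache, one position per iteration.
                // Running state: m = max score so far, l = sum of exp(s - m),
                // out = exp-weighted V accumulator. For each new score s:
                //   m' = max(m, s); c = exp(m - m');
                //   l' = l * c + exp(s - m'); out' = out * c + exp(s - m') * v.
                // exp(x) is computed as ex2(x * log2(e)).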
ctx.label("seq_loop");
let loop_cond = ctx.setp_lt_u32(pos, seq_len);
ctx.branch_if_not(loop_cond, "seq_loop_end");
let k_pos_off = ctx.mul_lo_u32(pos, head_dim_u32);
let k0_elem_off = ctx.add_u32_reg(k_pos_off, lane_id);
let k0_off_bytes = ctx.mul_wide_u32_reg(k0_elem_off, four);
let k0_addr = ctx.add_u64(k_head_ptr, k0_off_bytes);
let k0 = ctx.ld_global_f32_predicated(k0_addr, in_bounds0, 0.0);
let k1_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_32);
let k1_off_bytes = ctx.mul_wide_u32_reg(k1_elem_off, four);
let k1_addr = ctx.add_u64(k_head_ptr, k1_off_bytes);
let k1 = ctx.ld_global_f32_predicated(k1_addr, in_bounds1, 0.0);
let k2_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_64);
let k2_off_bytes = ctx.mul_wide_u32_reg(k2_elem_off, four);
let k2_addr = ctx.add_u64(k_head_ptr, k2_off_bytes);
let k2 = ctx.ld_global_f32_predicated(k2_addr, in_bounds2, 0.0);
let k3_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_96);
let k3_off_bytes = ctx.mul_wide_u32_reg(k3_elem_off, four);
let k3_addr = ctx.add_u64(k_head_ptr, k3_off_bytes);
let k3 = ctx.ld_global_f32_predicated(k3_addr, in_bounds3, 0.0);
let dot_partial = ctx.mul_f32(q0, k0);
let dot_partial = ctx.fma_f32(q1, k1, dot_partial);
let dot_partial = ctx.fma_f32(q2, k2, dot_partial);
let dot_partial = ctx.fma_f32(q3, k3, dot_partial);
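                // Tree-reduce the 32 per-lane partial products down to lane 0
                // with shfl.down, then broadcast the full dot product back to
                // every lane.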
let dot16 = ctx.shfl_down_f32(dot_partial, 16, 0xFFFF_FFFF);
let dot_partial = ctx.add_f32(dot_partial, dot16);
let dot8 = ctx.shfl_down_f32(dot_partial, 8, 0xFFFF_FFFF);
let dot_partial = ctx.add_f32(dot_partial, dot8);
let dot4 = ctx.shfl_down_f32(dot_partial, 4, 0xFFFF_FFFF);
let dot_partial = ctx.add_f32(dot_partial, dot4);
let dot2 = ctx.shfl_down_f32(dot_partial, 2, 0xFFFF_FFFF);
let dot_partial = ctx.add_f32(dot_partial, dot2);
let dot1 = ctx.shfl_down_f32(dot_partial, 1, 0xFFFF_FFFF);
let dot_reduced = ctx.add_f32(dot_partial, dot1);
let dot_broadcast = ctx.shfl_idx_f32(dot_reduced, 0, 0xFFFF_FFFF);
let score = ctx.mul_f32(dot_broadcast, scale_reg);
let new_max = ctx.max_f32(max_score, score);
let max_diff = ctx.sub_f32(max_score, new_max);
let max_diff_scaled = ctx.mul_f32(max_diff, log2e);
let correction = ctx.ex2_f32(max_diff_scaled);
let score_diff = ctx.sub_f32(score, new_max);
let score_diff_scaled = ctx.mul_f32(score_diff, log2e);
let exp_score = ctx.ex2_f32(score_diff_scaled);
let v0_elem_off = ctx.add_u32_reg(k_pos_off, lane_id);
let v0_off_bytes = ctx.mul_wide_u32_reg(v0_elem_off, four);
let v0_addr = ctx.add_u64(v_head_ptr, v0_off_bytes);
let v0 = ctx.ld_global_f32_predicated(v0_addr, in_bounds0, 0.0);
let v1_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_32);
let v1_off_bytes = ctx.mul_wide_u32_reg(v1_elem_off, four);
let v1_addr = ctx.add_u64(v_head_ptr, v1_off_bytes);
let v1 = ctx.ld_global_f32_predicated(v1_addr, in_bounds1, 0.0);
let v2_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_64);
let v2_off_bytes = ctx.mul_wide_u32_reg(v2_elem_off, four);
let v2_addr = ctx.add_u64(v_head_ptr, v2_off_bytes);
let v2 = ctx.ld_global_f32_predicated(v2_addr, in_bounds2, 0.0);
let v3_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_96);
let v3_off_bytes = ctx.mul_wide_u32_reg(v3_elem_off, four);
let v3_addr = ctx.add_u64(v_head_ptr, v3_off_bytes);
let v3 = ctx.ld_global_f32_predicated(v3_addr, in_bounds3, 0.0);
ctx.max_f32_inplace(max_score, score);
ctx.mul_f32_inplace(sum_exp, correction);
ctx.add_f32_inplace(sum_exp, exp_score);
ctx.mul_f32_inplace(out0, correction);
ctx.fma_f32_inplace(out0, exp_score, v0);
ctx.mul_f32_inplace(out1, correction);
ctx.fma_f32_inplace(out1, exp_score, v1);
ctx.mul_f32_inplace(out2, correction);
ctx.fma_f32_inplace(out2, exp_score, v2);
ctx.mul_f32_inplace(out3, correction);
ctx.fma_f32_inplace(out3, exp_score, v3);
ctx.add_u32_inplace(pos, 1);
ctx.branch("seq_loop");
ctx.label("seq_loop_end");
let one = ctx.mov_f32_imm(1.0);
let inv_sum = ctx.div_f32(one, sum_exp);
ctx.mul_f32_inplace(out0, inv_sum);
ctx.mul_f32_inplace(out1, inv_sum);
ctx.mul_f32_inplace(out2, inv_sum);
ctx.mul_f32_inplace(out3, inv_sum);
let out0_addr = ctx.add_u64(out_head_ptr, q0_off_bytes);
ctx.branch_if_not(in_bounds0, "skip_store0");
ctx.st_global_f32(out0_addr, out0);
ctx.label("skip_store0");
let out1_addr = ctx.add_u64(out_head_ptr, q1_off_bytes);
ctx.branch_if_not(in_bounds1, "skip_store1");
ctx.st_global_f32(out1_addr, out1);
ctx.label("skip_store1");
let out2_addr = ctx.add_u64(out_head_ptr, q2_off_bytes);
ctx.branch_if_not(in_bounds2, "skip_store2");
ctx.st_global_f32(out2_addr, out2);
ctx.label("skip_store2");
let out3_addr = ctx.add_u64(out_head_ptr, q3_off_bytes);
ctx.branch_if_not(in_bounds3, "skip_store3");
ctx.st_global_f32(out3_addr, out3);
ctx.label("skip_store3");
ctx.ret();
})
}
}
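/// Multi-warp variant of [`IncrementalAttentionKernel`] for longer
/// sequences: `num_warps_per_head` warps split a head's KV positions
/// into contiguous chunks, each runs its own online softmax, and the
/// partial (max, sum, output) triples are merged through shared memory
/// with a log-sum-exp rescale.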
#[derive(Debug, Clone)]
pub struct MultiWarpIncrementalAttentionKernel {
pub max_seq_len: u32,
pub head_dim: u32,
pub num_heads: u32,
pub num_kv_heads: u32,
pub num_warps_per_head: u32,
pub scale: f32,
pub indirect_seq_len: bool,
}
impl MultiWarpIncrementalAttentionKernel {
#[must_use]
pub fn new(
max_seq_len: u32,
head_dim: u32,
num_heads: u32,
num_kv_heads: u32,
num_warps: u32,
) -> Self {
Self {
max_seq_len,
head_dim,
num_heads,
num_kv_heads,
num_warps_per_head: num_warps,
scale: 1.0 / (head_dim as f32).sqrt(),
indirect_seq_len: false,
}
}
#[must_use]
pub fn with_indirect_seq_len(mut self, indirect: bool) -> Self {
self.indirect_seq_len = indirect;
self
}
}
impl Kernel for MultiWarpIncrementalAttentionKernel {
fn name(&self) -> &str {
if self.indirect_seq_len {
"multi_warp_attention_indirect"
} else {
"multi_warp_attention"
}
}
fn build_ptx(&self) -> PtxKernel {
let head_dim = self.head_dim;
let scale = self.scale;
let max_seq_len = self.max_seq_len;
let num_heads = self.num_heads;
let num_kv_heads = self.num_kv_heads;
let num_warps = self.num_warps_per_head;
let indirect = self.indirect_seq_len;
let smem_size = (num_warps * 2 + 2 + num_warps * head_dim) * 4;
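        // Shared memory layout, in bytes (nw = num_warps):
        //   [0, nw*4)          per-warp running max
        //   [nw*4, nw*8)       per-warp running sum
        //   [nw*8, nw*8 + 8)   merged global max and global sum
        //   [nw*8 + 8, ..)     one head_dim-wide f32 output tile per warp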
let kernel_name = if indirect {
"multi_warp_attention_indirect"
} else {
"multi_warp_attention"
};
let mut builder = PtxKernel::new(kernel_name)
.param(PtxType::U64, "q_ptr")
.param(PtxType::U64, "k_ptr")
.param(PtxType::U64, "v_ptr")
.param(PtxType::U64, "out_ptr");
builder = if indirect {
builder.param(PtxType::U64, "seq_len_ptr")
} else {
builder.param(PtxType::U32, "seq_len")
};
builder.shared_memory(smem_size as usize).build(move |ctx| {
let q_head_idx = ctx.special_reg(PtxReg::CtaIdX);
let tid = ctx.special_reg(PtxReg::TidX);
let warp_idx = ctx.div_u32(tid, 32);
let lane_id = ctx.rem_u32(tid, 32);
let seq_len = if indirect {
let seq_len_ptr = ctx.load_param_u64("seq_len_ptr");
ctx.ld_global_u32(seq_len_ptr)
} else {
ctx.load_param_u32("seq_len")
};
let q_ptr = ctx.load_param_u64("q_ptr");
let k_ptr = ctx.load_param_u64("k_ptr");
let v_ptr = ctx.load_param_u64("v_ptr");
let out_ptr = ctx.load_param_u64("out_ptr");
let four = ctx.mov_u32_imm(4);
let head_dim_u32 = ctx.mov_u32_imm(head_dim);
let num_warps_u32 = ctx.mov_u32_imm(num_warps);
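            // Split the sequence into num_warps contiguous chunks of
            // ceil(seq_len / num_warps) positions; the last chunk is clamped
            // to seq_len.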
let seq_plus_nw = ctx.add_u32(seq_len, num_warps - 1);
let chunk_size = ctx.div_u32(seq_plus_nw, num_warps);
let start_pos = ctx.mul_lo_u32(warp_idx, chunk_size);
let end_pos = ctx.add_u32_reg(start_pos, chunk_size);
let end_pos = ctx.min_u32(end_pos, seq_len);
let q_head_off = ctx.mul_lo_u32(q_head_idx, head_dim_u32);
let q_head_off_bytes = ctx.mul_wide_u32_reg(q_head_off, four);
let q_head_ptr = ctx.add_u64(q_ptr, q_head_off_bytes);
let out_head_ptr = ctx.add_u64(out_ptr, q_head_off_bytes);
let kv_head_idx = ctx.mul_u32(q_head_idx, num_kv_heads);
let kv_head_idx = ctx.div_u32(kv_head_idx, num_heads);
let kv_stride = ctx.mov_u32_imm(max_seq_len * head_dim);
let kv_head_off = ctx.mul_lo_u32(kv_head_idx, kv_stride);
let kv_head_off_bytes = ctx.mul_wide_u32_reg(kv_head_off, four);
let k_head_ptr = ctx.add_u64(k_ptr, kv_head_off_bytes);
let v_head_ptr = ctx.add_u64(v_ptr, kv_head_off_bytes);
let q0_off_bytes = ctx.mul_wide_u32_reg(lane_id, four);
let q0_addr = ctx.add_u64(q_head_ptr, q0_off_bytes);
let in_bounds0 = ctx.setp_lt_u32(lane_id, head_dim_u32);
let q0 = ctx.ld_global_f32_predicated(q0_addr, in_bounds0, 0.0);
let lane_plus_32 = ctx.add_u32(lane_id, 32);
let q1_off_bytes = ctx.mul_wide_u32_reg(lane_plus_32, four);
let q1_addr = ctx.add_u64(q_head_ptr, q1_off_bytes);
let in_bounds1 = ctx.setp_lt_u32(lane_plus_32, head_dim_u32);
let q1 = ctx.ld_global_f32_predicated(q1_addr, in_bounds1, 0.0);
let lane_plus_64 = ctx.add_u32(lane_id, 64);
let q2_off_bytes = ctx.mul_wide_u32_reg(lane_plus_64, four);
let q2_addr = ctx.add_u64(q_head_ptr, q2_off_bytes);
let in_bounds2 = ctx.setp_lt_u32(lane_plus_64, head_dim_u32);
let q2 = ctx.ld_global_f32_predicated(q2_addr, in_bounds2, 0.0);
let lane_plus_96 = ctx.add_u32(lane_id, 96);
let q3_off_bytes = ctx.mul_wide_u32_reg(lane_plus_96, four);
let q3_addr = ctx.add_u64(q_head_ptr, q3_off_bytes);
let in_bounds3 = ctx.setp_lt_u32(lane_plus_96, head_dim_u32);
let q3 = ctx.ld_global_f32_predicated(q3_addr, in_bounds3, 0.0);
let out0 = ctx.mov_f32_imm(0.0);
let out1 = ctx.mov_f32_imm(0.0);
let out2 = ctx.mov_f32_imm(0.0);
let out3 = ctx.mov_f32_imm(0.0);
let max_score = ctx.mov_f32_imm(f32::NEG_INFINITY);
let sum_exp = ctx.mov_f32_imm(0.0);
let log2e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
let scale_reg = ctx.mov_f32_imm(scale);
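            // Add-with-zero makes a fresh register, so the in-place increment
            // in the loop below leaves start_pos's register untouched.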
let pos = ctx.add_u32(start_pos, 0);
ctx.label("chunk_loop");
let loop_cond = ctx.setp_lt_u32(pos, end_pos);
ctx.branch_if_not(loop_cond, "chunk_loop_end");
let k_pos_off = ctx.mul_lo_u32(pos, head_dim_u32);
let k0_elem_off = ctx.add_u32_reg(k_pos_off, lane_id);
let k0_off_bytes = ctx.mul_wide_u32_reg(k0_elem_off, four);
let k0_addr = ctx.add_u64(k_head_ptr, k0_off_bytes);
let k0 = ctx.ld_global_f32_predicated(k0_addr, in_bounds0, 0.0);
let k1_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_32);
let k1_off_bytes = ctx.mul_wide_u32_reg(k1_elem_off, four);
let k1_addr = ctx.add_u64(k_head_ptr, k1_off_bytes);
let k1 = ctx.ld_global_f32_predicated(k1_addr, in_bounds1, 0.0);
let k2_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_64);
let k2_off_bytes = ctx.mul_wide_u32_reg(k2_elem_off, four);
let k2_addr = ctx.add_u64(k_head_ptr, k2_off_bytes);
let k2 = ctx.ld_global_f32_predicated(k2_addr, in_bounds2, 0.0);
let k3_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_96);
let k3_off_bytes = ctx.mul_wide_u32_reg(k3_elem_off, four);
let k3_addr = ctx.add_u64(k_head_ptr, k3_off_bytes);
let k3 = ctx.ld_global_f32_predicated(k3_addr, in_bounds3, 0.0);
let dot = ctx.mul_f32(q0, k0);
let dot = ctx.fma_f32(q1, k1, dot);
let dot = ctx.fma_f32(q2, k2, dot);
let dot = ctx.fma_f32(q3, k3, dot);
let dot16 = ctx.shfl_down_f32(dot, 16, 0xFFFF_FFFF);
let dot = ctx.add_f32(dot, dot16);
let dot8 = ctx.shfl_down_f32(dot, 8, 0xFFFF_FFFF);
let dot = ctx.add_f32(dot, dot8);
let dot4 = ctx.shfl_down_f32(dot, 4, 0xFFFF_FFFF);
let dot = ctx.add_f32(dot, dot4);
let dot2 = ctx.shfl_down_f32(dot, 2, 0xFFFF_FFFF);
let dot = ctx.add_f32(dot, dot2);
let dot1 = ctx.shfl_down_f32(dot, 1, 0xFFFF_FFFF);
let dot = ctx.add_f32(dot, dot1);
let score = ctx.shfl_idx_f32(dot, 0, 0xFFFF_FFFF);
let score = ctx.mul_f32(score, scale_reg);
let new_max = ctx.max_f32(max_score, score);
let max_diff = ctx.sub_f32(max_score, new_max);
let max_diff_scaled = ctx.mul_f32(max_diff, log2e);
let correction = ctx.ex2_f32(max_diff_scaled);
let score_diff = ctx.sub_f32(score, new_max);
let score_diff_scaled = ctx.mul_f32(score_diff, log2e);
let exp_score = ctx.ex2_f32(score_diff_scaled);
let v0_elem_off = ctx.add_u32_reg(k_pos_off, lane_id);
let v0_off_bytes = ctx.mul_wide_u32_reg(v0_elem_off, four);
let v0_addr = ctx.add_u64(v_head_ptr, v0_off_bytes);
let v0 = ctx.ld_global_f32_predicated(v0_addr, in_bounds0, 0.0);
let v1_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_32);
let v1_off_bytes = ctx.mul_wide_u32_reg(v1_elem_off, four);
let v1_addr = ctx.add_u64(v_head_ptr, v1_off_bytes);
let v1 = ctx.ld_global_f32_predicated(v1_addr, in_bounds1, 0.0);
let v2_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_64);
let v2_off_bytes = ctx.mul_wide_u32_reg(v2_elem_off, four);
let v2_addr = ctx.add_u64(v_head_ptr, v2_off_bytes);
let v2 = ctx.ld_global_f32_predicated(v2_addr, in_bounds2, 0.0);
let v3_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_96);
let v3_off_bytes = ctx.mul_wide_u32_reg(v3_elem_off, four);
let v3_addr = ctx.add_u64(v_head_ptr, v3_off_bytes);
let v3 = ctx.ld_global_f32_predicated(v3_addr, in_bounds3, 0.0);
ctx.max_f32_inplace(max_score, score);
ctx.mul_f32_inplace(sum_exp, correction);
ctx.add_f32_inplace(sum_exp, exp_score);
ctx.mul_f32_inplace(out0, correction);
ctx.fma_f32_inplace(out0, exp_score, v0);
ctx.mul_f32_inplace(out1, correction);
ctx.fma_f32_inplace(out1, exp_score, v1);
ctx.mul_f32_inplace(out2, correction);
ctx.fma_f32_inplace(out2, exp_score, v2);
ctx.mul_f32_inplace(out3, correction);
ctx.fma_f32_inplace(out3, exp_score, v3);
ctx.add_u32_inplace(pos, 1);
ctx.branch("chunk_loop");
ctx.label("chunk_loop_end");
let warp_off = ctx.mul_u32(warp_idx, 4);
let warp_off_64 = ctx.cvt_u64_u32(warp_off);
let max_off_base = ctx.mov_u64_imm(0);
let max_addr = ctx.add_u64(max_off_base, warp_off_64);
let sum_off_base = ctx.mov_u64_imm((num_warps * 4) as u64);
let sum_addr = ctx.add_u64(sum_off_base, warp_off_64);
let zero_u32 = ctx.mov_u32_imm(0);
let is_lane0 = ctx.setp_eq_u32(lane_id, zero_u32);
ctx.branch_if_not(is_lane0, "skip_smem_write");
ctx.st_shared_f32(max_addr, max_score);
ctx.st_shared_f32(sum_addr, sum_exp);
ctx.label("skip_smem_write");
ctx.bar_sync(0);
let is_warp0 = ctx.setp_eq_u32(warp_idx, zero_u32);
let is_warp0_lane0 = ctx.and_pred(is_warp0, is_lane0);
ctx.branch_if_not(is_warp0_lane0, "skip_reduce");
let global_max = ctx.mov_f32_imm(f32::NEG_INFINITY);
let reduce_i = ctx.mov_u32_imm(0);
ctx.label("reduce_max_loop");
let reduce_cond = ctx.setp_lt_u32(reduce_i, num_warps_u32);
ctx.branch_if_not(reduce_cond, "reduce_max_done");
let i_off = ctx.mul_u32(reduce_i, 4);
let i_off_64 = ctx.cvt_u64_u32(i_off);
let max_i_addr = ctx.add_u64(max_off_base, i_off_64);
let max_i = ctx.ld_shared_f32(max_i_addr);
ctx.max_f32_inplace(global_max, max_i);
ctx.add_u32_inplace(reduce_i, 1);
ctx.branch("reduce_max_loop");
ctx.label("reduce_max_done");
let global_sum = ctx.mov_f32_imm(0.0);
let reduce_i = ctx.mov_u32_imm(0);
ctx.label("reduce_sum_loop");
let reduce_cond = ctx.setp_lt_u32(reduce_i, num_warps_u32);
ctx.branch_if_not(reduce_cond, "reduce_sum_done");
let i_off = ctx.mul_u32(reduce_i, 4);
let i_off_64 = ctx.cvt_u64_u32(i_off);
let max_i_addr = ctx.add_u64(max_off_base, i_off_64);
let max_i = ctx.ld_shared_f32(max_i_addr);
let sum_i_addr = ctx.add_u64(sum_off_base, i_off_64);
let sum_i = ctx.ld_shared_f32(sum_i_addr);
let diff = ctx.sub_f32(max_i, global_max);
let diff_scaled = ctx.mul_f32(diff, log2e);
let correction = ctx.ex2_f32(diff_scaled);
let corrected_sum = ctx.mul_f32(sum_i, correction);
ctx.add_f32_inplace(global_sum, corrected_sum);
ctx.add_u32_inplace(reduce_i, 1);
ctx.branch("reduce_sum_loop");
ctx.label("reduce_sum_done");
let global_max_off = ctx.mov_u64_imm((num_warps * 8) as u64);
let global_sum_off = ctx.mov_u64_imm((num_warps * 8 + 4) as u64);
ctx.st_shared_f32(global_max_off, global_max);
ctx.st_shared_f32(global_sum_off, global_sum);
ctx.label("skip_reduce");
ctx.bar_sync(1);
let global_max_off = ctx.mov_u64_imm((num_warps * 8) as u64);
let global_sum_off = ctx.mov_u64_imm((num_warps * 8 + 4) as u64);
let global_max = ctx.ld_shared_f32(global_max_off);
let global_sum = ctx.ld_shared_f32(global_sum_off);
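            // Every warp rescales its partial output into the global softmax
            // frame: out *= exp(m_w - M) / L.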
let my_max = ctx.ld_shared_f32(max_addr);
let my_diff = ctx.sub_f32(my_max, global_max);
let my_diff_scaled = ctx.mul_f32(my_diff, log2e);
let my_correction = ctx.ex2_f32(my_diff_scaled);
let one = ctx.mov_f32_imm(1.0);
let inv_sum = ctx.div_f32(one, global_sum);
let final_scale = ctx.mul_f32(my_correction, inv_sum);
ctx.mul_f32_inplace(out0, final_scale);
ctx.mul_f32_inplace(out1, final_scale);
ctx.mul_f32_inplace(out2, final_scale);
ctx.mul_f32_inplace(out3, final_scale);
let out_area_base = ctx.mov_u32_imm(num_warps * 8 + 8);
let warp_out_offset = ctx.mul_u32(warp_idx, head_dim * 4);
let out_base = ctx.add_u32_reg(out_area_base, warp_out_offset);
let lane_off_0 = ctx.mul_u32(lane_id, 4);
let out0_smem_off = ctx.add_u32_reg(out_base, lane_off_0);
let out0_smem_addr = ctx.cvt_u64_u32(out0_smem_off);
ctx.branch_if_not(in_bounds0, "skip_store_out0");
ctx.st_shared_f32(out0_smem_addr, out0);
ctx.label("skip_store_out0");
let lane_off_1 = ctx.mul_u32(lane_plus_32, 4);
let out1_smem_off = ctx.add_u32_reg(out_base, lane_off_1);
let out1_smem_addr = ctx.cvt_u64_u32(out1_smem_off);
ctx.branch_if_not(in_bounds1, "skip_store_out1");
ctx.st_shared_f32(out1_smem_addr, out1);
ctx.label("skip_store_out1");
let lane_off_2 = ctx.mul_u32(lane_plus_64, 4);
let out2_smem_off = ctx.add_u32_reg(out_base, lane_off_2);
let out2_smem_addr = ctx.cvt_u64_u32(out2_smem_off);
ctx.branch_if_not(in_bounds2, "skip_store_out2");
ctx.st_shared_f32(out2_smem_addr, out2);
ctx.label("skip_store_out2");
let lane_off_3 = ctx.mul_u32(lane_plus_96, 4);
let out3_smem_off = ctx.add_u32_reg(out_base, lane_off_3);
let out3_smem_addr = ctx.cvt_u64_u32(out3_smem_off);
ctx.branch_if_not(in_bounds3, "skip_store_out3");
ctx.st_shared_f32(out3_smem_addr, out3);
ctx.label("skip_store_out3");
ctx.bar_sync(2);
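            // Warp 0 sums the rescaled per-warp output tiles element-wise and
            // writes the final head output.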
ctx.branch_if_not(is_warp0, "skip_final_sum");
let final_out0 = ctx.mov_f32_imm(0.0);
let sum_w = ctx.mov_u32_imm(0);
ctx.label("sum_warps_loop0");
let sum_cond = ctx.setp_lt_u32(sum_w, num_warps_u32);
ctx.branch_if_not(sum_cond, "sum_warps_done0");
let w_out_offset = ctx.mul_u32(sum_w, head_dim * 4);
let w_out_base = ctx.add_u32_reg(out_area_base, w_out_offset);
let elem_off = ctx.mul_u32(lane_id, 4);
let elem_addr_off = ctx.add_u32_reg(w_out_base, elem_off);
let elem_addr = ctx.cvt_u64_u32(elem_addr_off);
let elem_val = ctx.ld_shared_f32(elem_addr);
ctx.add_f32_inplace(final_out0, elem_val);
ctx.add_u32_inplace(sum_w, 1);
ctx.branch("sum_warps_loop0");
ctx.label("sum_warps_done0");
let out0_addr = ctx.add_u64(out_head_ptr, q0_off_bytes);
ctx.branch_if_not(in_bounds0, "skip_final_store0");
ctx.st_global_f32(out0_addr, final_out0);
ctx.label("skip_final_store0");
let final_out1 = ctx.mov_f32_imm(0.0);
let sum_w = ctx.mov_u32_imm(0);
ctx.label("sum_warps_loop1");
let sum_cond = ctx.setp_lt_u32(sum_w, num_warps_u32);
ctx.branch_if_not(sum_cond, "sum_warps_done1");
let w_out_offset = ctx.mul_u32(sum_w, head_dim * 4);
let w_out_base = ctx.add_u32_reg(out_area_base, w_out_offset);
let elem_off = ctx.mul_u32(lane_plus_32, 4);
let elem_addr_off = ctx.add_u32_reg(w_out_base, elem_off);
let elem_addr = ctx.cvt_u64_u32(elem_addr_off);
let elem_val = ctx.ld_shared_f32(elem_addr);
ctx.add_f32_inplace(final_out1, elem_val);
ctx.add_u32_inplace(sum_w, 1);
ctx.branch("sum_warps_loop1");
ctx.label("sum_warps_done1");
let out1_addr = ctx.add_u64(out_head_ptr, q1_off_bytes);
ctx.branch_if_not(in_bounds1, "skip_final_store1");
ctx.st_global_f32(out1_addr, final_out1);
ctx.label("skip_final_store1");
let final_out2 = ctx.mov_f32_imm(0.0);
let sum_w = ctx.mov_u32_imm(0);
ctx.label("sum_warps_loop2");
let sum_cond = ctx.setp_lt_u32(sum_w, num_warps_u32);
ctx.branch_if_not(sum_cond, "sum_warps_done2");
let w_out_offset = ctx.mul_u32(sum_w, head_dim * 4);
let w_out_base = ctx.add_u32_reg(out_area_base, w_out_offset);
let elem_off = ctx.mul_u32(lane_plus_64, 4);
let elem_addr_off = ctx.add_u32_reg(w_out_base, elem_off);
let elem_addr = ctx.cvt_u64_u32(elem_addr_off);
let elem_val = ctx.ld_shared_f32(elem_addr);
ctx.add_f32_inplace(final_out2, elem_val);
ctx.add_u32_inplace(sum_w, 1);
ctx.branch("sum_warps_loop2");
ctx.label("sum_warps_done2");
let out2_addr = ctx.add_u64(out_head_ptr, q2_off_bytes);
ctx.branch_if_not(in_bounds2, "skip_final_store2");
ctx.st_global_f32(out2_addr, final_out2);
ctx.label("skip_final_store2");
let final_out3 = ctx.mov_f32_imm(0.0);
let sum_w = ctx.mov_u32_imm(0);
ctx.label("sum_warps_loop3");
let sum_cond = ctx.setp_lt_u32(sum_w, num_warps_u32);
ctx.branch_if_not(sum_cond, "sum_warps_done3");
let w_out_offset = ctx.mul_u32(sum_w, head_dim * 4);
let w_out_base = ctx.add_u32_reg(out_area_base, w_out_offset);
let elem_off = ctx.mul_u32(lane_plus_96, 4);
let elem_addr_off = ctx.add_u32_reg(w_out_base, elem_off);
let elem_addr = ctx.cvt_u64_u32(elem_addr_off);
let elem_val = ctx.ld_shared_f32(elem_addr);
ctx.add_f32_inplace(final_out3, elem_val);
ctx.add_u32_inplace(sum_w, 1);
ctx.branch("sum_warps_loop3");
ctx.label("sum_warps_done3");
let out3_addr = ctx.add_u64(out_head_ptr, q3_off_bytes);
ctx.branch_if_not(in_bounds3, "skip_final_store3");
ctx.st_global_f32(out3_addr, final_out3);
ctx.label("skip_final_store3");
ctx.label("skip_final_sum");
ctx.ret();
})
}
}
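/// Batched decode attention: grid x indexes the query head and grid y
/// the batch entry. Per-batch K/V caches are reached through device
/// arrays of base pointers (`k_ptrs_ptr` / `v_ptrs_ptr`), and each batch
/// entry carries its own sequence length in `seq_lens_ptr`.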
#[derive(Debug, Clone)]
pub struct BatchedIncrementalAttentionKernel {
pub max_seq_len: u32,
pub head_dim: u32,
pub num_heads: u32,
pub num_kv_heads: u32,
pub batch_size: u32,
pub scale: f32,
}
impl BatchedIncrementalAttentionKernel {
#[must_use]
pub fn new(
max_seq_len: u32,
head_dim: u32,
num_heads: u32,
num_kv_heads: u32,
batch_size: u32,
) -> Self {
Self {
max_seq_len,
head_dim,
num_heads,
num_kv_heads,
batch_size,
scale: 1.0 / (head_dim as f32).sqrt(),
}
}
}
impl Kernel for BatchedIncrementalAttentionKernel {
fn name(&self) -> &str {
"batched_incremental_attention"
}
fn build_ptx(&self) -> PtxKernel {
let head_dim = self.head_dim;
let scale = self.scale;
let max_seq_len = self.max_seq_len;
let num_heads = self.num_heads;
let num_kv_heads = self.num_kv_heads;
let _batch_size = self.batch_size;
PtxKernel::new("batched_incremental_attention")
.param(PtxType::U64, "q_ptr") .param(PtxType::U64, "k_ptrs_ptr") .param(PtxType::U64, "v_ptrs_ptr") .param(PtxType::U64, "out_ptr") .param(PtxType::U64, "seq_lens_ptr") .shared_memory(0)
.build(move |ctx| {
let head_idx = ctx.special_reg(PtxReg::CtaIdX);
let batch_idx = ctx.special_reg(PtxReg::CtaIdY);
let lane_id = ctx.special_reg(PtxReg::TidX);
let q_ptr = ctx.load_param_u64("q_ptr");
let k_ptrs_ptr = ctx.load_param_u64("k_ptrs_ptr");
let v_ptrs_ptr = ctx.load_param_u64("v_ptrs_ptr");
let out_ptr = ctx.load_param_u64("out_ptr");
let seq_lens_ptr = ctx.load_param_u64("seq_lens_ptr");
let four = ctx.mov_u32_imm(4);
let eight = ctx.mov_u32_imm(8);
let batch_idx_bytes = ctx.mul_wide_u32_reg(batch_idx, four);
let seq_len_addr = ctx.add_u64(seq_lens_ptr, batch_idx_bytes);
let seq_len = ctx.ld_global_u32(seq_len_addr);
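                // K and V cache base pointers are read from per-batch u64
                // pointer tables, hence the 8-byte stride.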
let batch_ptr_off = ctx.mul_wide_u32_reg(batch_idx, eight);
let k_ptr_addr = ctx.add_u64(k_ptrs_ptr, batch_ptr_off);
let v_ptr_addr = ctx.add_u64(v_ptrs_ptr, batch_ptr_off);
let k_cache_ptr = ctx.ld_global_u64(k_ptr_addr);
let v_cache_ptr = ctx.ld_global_u64(v_ptr_addr);
let head_dim_u32 = ctx.mov_u32_imm(head_dim);
let num_heads_u32 = ctx.mov_u32_imm(num_heads);
let batch_head_stride = ctx.mul_lo_u32(num_heads_u32, head_dim_u32);
let batch_off = ctx.mul_lo_u32(batch_idx, batch_head_stride);
let head_off = ctx.mul_lo_u32(head_idx, head_dim_u32);
let q_head_off = ctx.add_u32_reg(batch_off, head_off);
let q_head_off_bytes = ctx.mul_wide_u32_reg(q_head_off, four);
let q_head_ptr = ctx.add_u64(q_ptr, q_head_off_bytes);
let out_head_ptr = ctx.add_u64(out_ptr, q_head_off_bytes);
let kv_head_idx = ctx.mul_u32(head_idx, num_kv_heads);
let kv_head_idx = ctx.div_u32(kv_head_idx, num_heads);
let kv_stride = ctx.mov_u32_imm(max_seq_len * head_dim);
let kv_head_off = ctx.mul_lo_u32(kv_head_idx, kv_stride);
let kv_head_off_bytes = ctx.mul_wide_u32_reg(kv_head_off, four);
let k_head_ptr = ctx.add_u64(k_cache_ptr, kv_head_off_bytes);
let v_head_ptr = ctx.add_u64(v_cache_ptr, kv_head_off_bytes);
let q0_off_bytes = ctx.mul_wide_u32_reg(lane_id, four);
let q0_addr = ctx.add_u64(q_head_ptr, q0_off_bytes);
let in_bounds0 = ctx.setp_lt_u32(lane_id, head_dim_u32);
let q0 = ctx.ld_global_f32_predicated(q0_addr, in_bounds0, 0.0);
let lane_plus_32 = ctx.add_u32(lane_id, 32);
let q1_off_bytes = ctx.mul_wide_u32_reg(lane_plus_32, four);
let q1_addr = ctx.add_u64(q_head_ptr, q1_off_bytes);
let in_bounds1 = ctx.setp_lt_u32(lane_plus_32, head_dim_u32);
let q1 = ctx.ld_global_f32_predicated(q1_addr, in_bounds1, 0.0);
let lane_plus_64 = ctx.add_u32(lane_id, 64);
let q2_off_bytes = ctx.mul_wide_u32_reg(lane_plus_64, four);
let q2_addr = ctx.add_u64(q_head_ptr, q2_off_bytes);
let in_bounds2 = ctx.setp_lt_u32(lane_plus_64, head_dim_u32);
let q2 = ctx.ld_global_f32_predicated(q2_addr, in_bounds2, 0.0);
let lane_plus_96 = ctx.add_u32(lane_id, 96);
let q3_off_bytes = ctx.mul_wide_u32_reg(lane_plus_96, four);
let q3_addr = ctx.add_u64(q_head_ptr, q3_off_bytes);
let in_bounds3 = ctx.setp_lt_u32(lane_plus_96, head_dim_u32);
let q3 = ctx.ld_global_f32_predicated(q3_addr, in_bounds3, 0.0);
let out0 = ctx.mov_f32_imm(0.0);
let out1 = ctx.mov_f32_imm(0.0);
let out2 = ctx.mov_f32_imm(0.0);
let out3 = ctx.mov_f32_imm(0.0);
let max_score = ctx.mov_f32_imm(f32::NEG_INFINITY);
let sum_exp = ctx.mov_f32_imm(0.0);
let log2e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
let scale_reg = ctx.mov_f32_imm(scale);
let pos = ctx.mov_u32_imm(0);
ctx.label("batched_seq_loop");
let loop_cond = ctx.setp_lt_u32(pos, seq_len);
ctx.branch_if_not(loop_cond, "batched_seq_loop_end");
let k_pos_off = ctx.mul_lo_u32(pos, head_dim_u32);
let k0_elem_off = ctx.add_u32_reg(k_pos_off, lane_id);
let k0_off_bytes = ctx.mul_wide_u32_reg(k0_elem_off, four);
let k0_addr = ctx.add_u64(k_head_ptr, k0_off_bytes);
let k0 = ctx.ld_global_f32_predicated(k0_addr, in_bounds0, 0.0);
let k1_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_32);
let k1_off_bytes = ctx.mul_wide_u32_reg(k1_elem_off, four);
let k1_addr = ctx.add_u64(k_head_ptr, k1_off_bytes);
let k1 = ctx.ld_global_f32_predicated(k1_addr, in_bounds1, 0.0);
let k2_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_64);
let k2_off_bytes = ctx.mul_wide_u32_reg(k2_elem_off, four);
let k2_addr = ctx.add_u64(k_head_ptr, k2_off_bytes);
let k2 = ctx.ld_global_f32_predicated(k2_addr, in_bounds2, 0.0);
let k3_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_96);
let k3_off_bytes = ctx.mul_wide_u32_reg(k3_elem_off, four);
let k3_addr = ctx.add_u64(k_head_ptr, k3_off_bytes);
let k3 = ctx.ld_global_f32_predicated(k3_addr, in_bounds3, 0.0);
let dot = ctx.mul_f32(q0, k0);
ctx.fma_f32_inplace(dot, q1, k1);
ctx.fma_f32_inplace(dot, q2, k2);
ctx.fma_f32_inplace(dot, q3, k3);
for delta in [16, 8, 4, 2, 1] {
let other = ctx.shfl_down_f32(dot, delta, 0xFFFF_FFFF);
ctx.add_f32_inplace(dot, other);
}
let score = ctx.mul_f32(dot, scale_reg);
                // max_f32_inplace mutates max_score's register, and a plain
                // `let old_max = max_score;` copies only the register handle
                // (the single-warp kernel's ordering relies on this), so the
                // old maximum must be consumed before the update or the
                // correction factor collapses to exp(0) = 1.
                let new_max = ctx.max_f32(max_score, score);
                let score_minus_max = ctx.sub_f32(score, new_max);
                let score_log2 = ctx.mul_f32(score_minus_max, log2e);
                let exp_score = ctx.ex2_f32(score_log2);
                let old_minus_new = ctx.sub_f32(max_score, new_max);
                let log2_old = ctx.mul_f32(old_minus_new, log2e);
                let correction = ctx.ex2_f32(log2_old);
                ctx.max_f32_inplace(max_score, score);
ctx.mul_f32_inplace(sum_exp, correction);
ctx.add_f32_inplace(sum_exp, exp_score);
ctx.mul_f32_inplace(out0, correction);
ctx.mul_f32_inplace(out1, correction);
ctx.mul_f32_inplace(out2, correction);
ctx.mul_f32_inplace(out3, correction);
let v0_addr = ctx.add_u64(v_head_ptr, k0_off_bytes);
let v0 = ctx.ld_global_f32_predicated(v0_addr, in_bounds0, 0.0);
ctx.fma_f32_inplace(out0, exp_score, v0);
let v1_addr = ctx.add_u64(v_head_ptr, k1_off_bytes);
let v1 = ctx.ld_global_f32_predicated(v1_addr, in_bounds1, 0.0);
ctx.fma_f32_inplace(out1, exp_score, v1);
let v2_addr = ctx.add_u64(v_head_ptr, k2_off_bytes);
let v2 = ctx.ld_global_f32_predicated(v2_addr, in_bounds2, 0.0);
ctx.fma_f32_inplace(out2, exp_score, v2);
let v3_addr = ctx.add_u64(v_head_ptr, k3_off_bytes);
let v3 = ctx.ld_global_f32_predicated(v3_addr, in_bounds3, 0.0);
ctx.fma_f32_inplace(out3, exp_score, v3);
ctx.add_u32_inplace(pos, 1);
ctx.branch("batched_seq_loop");
ctx.label("batched_seq_loop_end");
let one = ctx.mov_f32_imm(1.0);
let inv_sum = ctx.div_f32(one, sum_exp);
ctx.mul_f32_inplace(out0, inv_sum);
ctx.mul_f32_inplace(out1, inv_sum);
ctx.mul_f32_inplace(out2, inv_sum);
ctx.mul_f32_inplace(out3, inv_sum);
let out0_addr = ctx.add_u64(out_head_ptr, q0_off_bytes);
ctx.branch_if_not(in_bounds0, "batched_skip_store0");
ctx.st_global_f32(out0_addr, out0);
ctx.label("batched_skip_store0");
let out1_addr = ctx.add_u64(out_head_ptr, q1_off_bytes);
ctx.branch_if_not(in_bounds1, "batched_skip_store1");
ctx.st_global_f32(out1_addr, out1);
ctx.label("batched_skip_store1");
let out2_addr = ctx.add_u64(out_head_ptr, q2_off_bytes);
ctx.branch_if_not(in_bounds2, "batched_skip_store2");
ctx.st_global_f32(out2_addr, out2);
ctx.label("batched_skip_store2");
let out3_addr = ctx.add_u64(out_head_ptr, q3_off_bytes);
ctx.branch_if_not(in_bounds3, "batched_skip_store3");
ctx.st_global_f32(out3_addr, out3);
ctx.label("batched_skip_store3");
ctx.ret();
})
}
}
pub const FLASH_DECODE_CHUNK_SIZE: u32 = 128;
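/// First phase of flash decoding (split-K over the sequence): grid z
/// picks a chunk of up to `chunk_size` KV positions and each block runs
/// an online softmax over just its chunk. Instead of the final output it
/// writes a partial record of `head_dim + 2` f32 values (output vector,
/// then the chunk's max and sum) for [`FlashDecodingReduceKernel`] to
/// merge.
///
/// A sizing sketch (values follow directly from the methods below):
///
/// ```ignore
/// let chunk = FlashDecodingChunkKernel::new(4096, 128, 32, 8, 4);
/// let max_chunks = chunk.num_chunks(4096);         // 32 chunks of 128
/// let per_head = chunk.partials_size_per_head(32); // 32 * 130 = 4160 f32s
/// ```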
#[derive(Debug, Clone)]
pub struct FlashDecodingChunkKernel {
pub max_seq_len: u32,
pub head_dim: u32,
pub num_heads: u32,
pub num_kv_heads: u32,
pub batch_size: u32,
pub chunk_size: u32,
pub scale: f32,
}
impl FlashDecodingChunkKernel {
#[must_use]
pub fn new(
max_seq_len: u32,
head_dim: u32,
num_heads: u32,
num_kv_heads: u32,
batch_size: u32,
) -> Self {
Self {
max_seq_len,
head_dim,
num_heads,
num_kv_heads,
batch_size,
chunk_size: FLASH_DECODE_CHUNK_SIZE,
scale: 1.0 / (head_dim as f32).sqrt(),
}
}
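    /// Number of chunks needed to cover `seq_len` (ceiling division).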
#[must_use]
    pub fn num_chunks(&self, seq_len: u32) -> u32 {
        seq_len.div_ceil(self.chunk_size)
    }
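    /// Per-(batch, head) partials buffer length in f32 elements:
    /// `max_chunks` records of `head_dim + 2` values each.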
#[must_use]
pub fn partials_size_per_head(&self, max_chunks: u32) -> u32 {
max_chunks * (self.head_dim + 2)
}
}
impl Kernel for FlashDecodingChunkKernel {
fn name(&self) -> &str {
"flash_decoding_chunk"
}
fn build_ptx(&self) -> PtxKernel {
let head_dim = self.head_dim;
let scale = self.scale;
let max_seq_len = self.max_seq_len;
let num_heads = self.num_heads;
let num_kv_heads = self.num_kv_heads;
let chunk_size = self.chunk_size;
let _batch_size = self.batch_size;
PtxKernel::new("flash_decoding_chunk")
.param(PtxType::U64, "q_ptr") .param(PtxType::U64, "k_ptrs_ptr") .param(PtxType::U64, "v_ptrs_ptr") .param(PtxType::U64, "partials_ptr") .param(PtxType::U64, "seq_lens_ptr") .param(PtxType::U32, "max_chunks") .shared_memory(0)
.build(move |ctx| {
let head_idx = ctx.special_reg(PtxReg::CtaIdX);
let batch_idx = ctx.special_reg(PtxReg::CtaIdY);
let chunk_idx = ctx.special_reg(PtxReg::CtaIdZ);
let lane_id = ctx.special_reg(PtxReg::TidX);
let q_ptr = ctx.load_param_u64("q_ptr");
let k_ptrs_ptr = ctx.load_param_u64("k_ptrs_ptr");
let v_ptrs_ptr = ctx.load_param_u64("v_ptrs_ptr");
let partials_ptr = ctx.load_param_u64("partials_ptr");
let seq_lens_ptr = ctx.load_param_u64("seq_lens_ptr");
let max_chunks_param = ctx.load_param_u32("max_chunks");
let four = ctx.mov_u32_imm(4);
let eight = ctx.mov_u32_imm(8);
let batch_idx_bytes = ctx.mul_wide_u32_reg(batch_idx, four);
let seq_len_addr = ctx.add_u64(seq_lens_ptr, batch_idx_bytes);
let seq_len = ctx.ld_global_u32(seq_len_addr);
let chunk_size_u32 = ctx.mov_u32_imm(chunk_size);
let chunk_start = ctx.mul_lo_u32(chunk_idx, chunk_size_u32);
                let chunk_end_raw = ctx.add_u32(chunk_start, chunk_size);
                let chunk_end = ctx.min_u32(chunk_end_raw, seq_len);
let has_work = ctx.setp_lt_u32(chunk_start, seq_len);
ctx.branch_if_not(has_work, "flash_decode_chunk_empty");
let batch_ptr_off = ctx.mul_wide_u32_reg(batch_idx, eight);
let k_ptr_addr = ctx.add_u64(k_ptrs_ptr, batch_ptr_off);
let v_ptr_addr = ctx.add_u64(v_ptrs_ptr, batch_ptr_off);
let k_cache_ptr = ctx.ld_global_u64(k_ptr_addr);
let v_cache_ptr = ctx.ld_global_u64(v_ptr_addr);
let head_dim_u32 = ctx.mov_u32_imm(head_dim);
let num_heads_u32 = ctx.mov_u32_imm(num_heads);
let batch_head_stride = ctx.mul_lo_u32(num_heads_u32, head_dim_u32);
let batch_off = ctx.mul_lo_u32(batch_idx, batch_head_stride);
let head_off = ctx.mul_lo_u32(head_idx, head_dim_u32);
let q_head_off = ctx.add_u32_reg(batch_off, head_off);
let q_head_off_bytes = ctx.mul_wide_u32_reg(q_head_off, four);
let q_head_ptr = ctx.add_u64(q_ptr, q_head_off_bytes);
let kv_head_idx = ctx.mul_u32(head_idx, num_kv_heads);
let kv_head_idx = ctx.div_u32(kv_head_idx, num_heads);
let kv_stride = ctx.mov_u32_imm(max_seq_len * head_dim);
let kv_head_off = ctx.mul_lo_u32(kv_head_idx, kv_stride);
let kv_head_off_bytes = ctx.mul_wide_u32_reg(kv_head_off, four);
let k_head_ptr = ctx.add_u64(k_cache_ptr, kv_head_off_bytes);
let v_head_ptr = ctx.add_u64(v_cache_ptr, kv_head_off_bytes);
let q0_off_bytes = ctx.mul_wide_u32_reg(lane_id, four);
let q0_addr = ctx.add_u64(q_head_ptr, q0_off_bytes);
let in_bounds0 = ctx.setp_lt_u32(lane_id, head_dim_u32);
let q0 = ctx.ld_global_f32_predicated(q0_addr, in_bounds0, 0.0);
let lane_plus_32 = ctx.add_u32(lane_id, 32);
let q1_off_bytes = ctx.mul_wide_u32_reg(lane_plus_32, four);
let q1_addr = ctx.add_u64(q_head_ptr, q1_off_bytes);
let in_bounds1 = ctx.setp_lt_u32(lane_plus_32, head_dim_u32);
let q1 = ctx.ld_global_f32_predicated(q1_addr, in_bounds1, 0.0);
let lane_plus_64 = ctx.add_u32(lane_id, 64);
let q2_off_bytes = ctx.mul_wide_u32_reg(lane_plus_64, four);
let q2_addr = ctx.add_u64(q_head_ptr, q2_off_bytes);
let in_bounds2 = ctx.setp_lt_u32(lane_plus_64, head_dim_u32);
let q2 = ctx.ld_global_f32_predicated(q2_addr, in_bounds2, 0.0);
let lane_plus_96 = ctx.add_u32(lane_id, 96);
let q3_off_bytes = ctx.mul_wide_u32_reg(lane_plus_96, four);
let q3_addr = ctx.add_u64(q_head_ptr, q3_off_bytes);
let in_bounds3 = ctx.setp_lt_u32(lane_plus_96, head_dim_u32);
let q3 = ctx.ld_global_f32_predicated(q3_addr, in_bounds3, 0.0);
let out0 = ctx.mov_f32_imm(0.0);
let out1 = ctx.mov_f32_imm(0.0);
let out2 = ctx.mov_f32_imm(0.0);
let out3 = ctx.mov_f32_imm(0.0);
let max_score = ctx.mov_f32_imm(f32::NEG_INFINITY);
let sum_exp = ctx.mov_f32_imm(0.0);
let log2e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
let scale_reg = ctx.mov_f32_imm(scale);
                // Fresh register via add-with-zero: the loop increments `pos`
                // in place, which must not alias chunk_start's register.
                let pos = ctx.add_u32(chunk_start, 0);
ctx.label("flash_decode_chunk_loop");
let loop_cond = ctx.setp_lt_u32(pos, chunk_end);
ctx.branch_if_not(loop_cond, "flash_decode_chunk_loop_end");
let k_pos_off = ctx.mul_lo_u32(pos, head_dim_u32);
let k0_elem_off = ctx.add_u32_reg(k_pos_off, lane_id);
let k0_off_bytes = ctx.mul_wide_u32_reg(k0_elem_off, four);
let k0_addr = ctx.add_u64(k_head_ptr, k0_off_bytes);
let k0 = ctx.ld_global_f32_predicated(k0_addr, in_bounds0, 0.0);
let k1_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_32);
let k1_off_bytes = ctx.mul_wide_u32_reg(k1_elem_off, four);
let k1_addr = ctx.add_u64(k_head_ptr, k1_off_bytes);
let k1 = ctx.ld_global_f32_predicated(k1_addr, in_bounds1, 0.0);
let k2_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_64);
let k2_off_bytes = ctx.mul_wide_u32_reg(k2_elem_off, four);
let k2_addr = ctx.add_u64(k_head_ptr, k2_off_bytes);
let k2 = ctx.ld_global_f32_predicated(k2_addr, in_bounds2, 0.0);
let k3_elem_off = ctx.add_u32_reg(k_pos_off, lane_plus_96);
let k3_off_bytes = ctx.mul_wide_u32_reg(k3_elem_off, four);
let k3_addr = ctx.add_u64(k_head_ptr, k3_off_bytes);
let k3 = ctx.ld_global_f32_predicated(k3_addr, in_bounds3, 0.0);
let dot = ctx.mul_f32(q0, k0);
ctx.fma_f32_inplace(dot, q1, k1);
ctx.fma_f32_inplace(dot, q2, k2);
ctx.fma_f32_inplace(dot, q3, k3);
for delta in [16, 8, 4, 2, 1] {
let other = ctx.shfl_down_f32(dot, delta, 0xFFFF_FFFF);
ctx.add_f32_inplace(dot, other);
}
let score = ctx.mul_f32(dot, scale_reg);
                // Same aliasing hazard as in the batched kernel: compute the
                // correction from the pre-update maximum before mutating
                // max_score in place.
                let new_max = ctx.max_f32(max_score, score);
                let score_minus_max = ctx.sub_f32(score, new_max);
                let score_log2 = ctx.mul_f32(score_minus_max, log2e);
                let exp_score = ctx.ex2_f32(score_log2);
                let old_minus_new = ctx.sub_f32(max_score, new_max);
                let log2_old = ctx.mul_f32(old_minus_new, log2e);
                let correction = ctx.ex2_f32(log2_old);
                ctx.max_f32_inplace(max_score, score);
ctx.mul_f32_inplace(sum_exp, correction);
ctx.add_f32_inplace(sum_exp, exp_score);
ctx.mul_f32_inplace(out0, correction);
ctx.mul_f32_inplace(out1, correction);
ctx.mul_f32_inplace(out2, correction);
ctx.mul_f32_inplace(out3, correction);
let v0_addr = ctx.add_u64(v_head_ptr, k0_off_bytes);
let v0 = ctx.ld_global_f32_predicated(v0_addr, in_bounds0, 0.0);
ctx.fma_f32_inplace(out0, exp_score, v0);
let v1_addr = ctx.add_u64(v_head_ptr, k1_off_bytes);
let v1 = ctx.ld_global_f32_predicated(v1_addr, in_bounds1, 0.0);
ctx.fma_f32_inplace(out1, exp_score, v1);
let v2_addr = ctx.add_u64(v_head_ptr, k2_off_bytes);
let v2 = ctx.ld_global_f32_predicated(v2_addr, in_bounds2, 0.0);
ctx.fma_f32_inplace(out2, exp_score, v2);
let v3_addr = ctx.add_u64(v_head_ptr, k3_off_bytes);
let v3 = ctx.ld_global_f32_predicated(v3_addr, in_bounds3, 0.0);
ctx.fma_f32_inplace(out3, exp_score, v3);
ctx.add_u32_inplace(pos, 1);
ctx.branch("flash_decode_chunk_loop");
ctx.label("flash_decode_chunk_loop_end");
let head_dim_plus_2 = ctx.mov_u32_imm(head_dim + 2);
let partial_stride = ctx.mul_lo_u32(max_chunks_param, head_dim_plus_2);
let batch_partial_stride = ctx.mul_lo_u32(num_heads_u32, partial_stride);
let batch_partial_off = ctx.mul_lo_u32(batch_idx, batch_partial_stride);
let head_partial_off = ctx.mul_lo_u32(head_idx, partial_stride);
let chunk_partial_off = ctx.mul_lo_u32(chunk_idx, head_dim_plus_2);
let partial_off = ctx.add_u32_reg(batch_partial_off, head_partial_off);
let partial_off = ctx.add_u32_reg(partial_off, chunk_partial_off);
let partial_off_bytes = ctx.mul_wide_u32_reg(partial_off, four);
let partial_base = ctx.add_u64(partials_ptr, partial_off_bytes);
let out0_addr = ctx.add_u64(partial_base, q0_off_bytes);
ctx.branch_if_not(in_bounds0, "flash_decode_skip_out0");
ctx.st_global_f32(out0_addr, out0);
ctx.label("flash_decode_skip_out0");
let out1_addr = ctx.add_u64(partial_base, q1_off_bytes);
ctx.branch_if_not(in_bounds1, "flash_decode_skip_out1");
ctx.st_global_f32(out1_addr, out1);
ctx.label("flash_decode_skip_out1");
let out2_addr = ctx.add_u64(partial_base, q2_off_bytes);
ctx.branch_if_not(in_bounds2, "flash_decode_skip_out2");
ctx.st_global_f32(out2_addr, out2);
ctx.label("flash_decode_skip_out2");
let out3_addr = ctx.add_u64(partial_base, q3_off_bytes);
ctx.branch_if_not(in_bounds3, "flash_decode_skip_out3");
ctx.st_global_f32(out3_addr, out3);
ctx.label("flash_decode_skip_out3");
let zero_u32 = ctx.mov_u32_imm(0);
let is_lane0 = ctx.setp_eq_u32(lane_id, zero_u32);
ctx.branch_if_not(is_lane0, "flash_decode_skip_meta");
let max_off = ctx.mov_u32_imm(head_dim);
let max_off_bytes = ctx.mul_wide_u32_reg(max_off, four);
let max_addr = ctx.add_u64(partial_base, max_off_bytes);
ctx.st_global_f32(max_addr, max_score);
let sum_off = ctx.mov_u32_imm(head_dim + 1);
let sum_off_bytes = ctx.mul_wide_u32_reg(sum_off, four);
let sum_addr = ctx.add_u64(partial_base, sum_off_bytes);
ctx.st_global_f32(sum_addr, sum_exp);
ctx.label("flash_decode_skip_meta");
ctx.ret();
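                // Blocks whose chunk starts at or past this entry's seq_len
                // still write a record so the reduce side can iterate
                // uniformly: max = -inf and sum = 0 mark the record empty.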
ctx.label("flash_decode_chunk_empty");
let head_dim_plus_2_e = ctx.mov_u32_imm(head_dim + 2);
let partial_stride_e = ctx.mul_lo_u32(max_chunks_param, head_dim_plus_2_e);
let batch_partial_stride_e = ctx.mul_lo_u32(num_heads_u32, partial_stride_e);
let batch_partial_off_e = ctx.mul_lo_u32(batch_idx, batch_partial_stride_e);
let head_partial_off_e = ctx.mul_lo_u32(head_idx, partial_stride_e);
let chunk_partial_off_e = ctx.mul_lo_u32(chunk_idx, head_dim_plus_2_e);
let partial_off_e = ctx.add_u32_reg(batch_partial_off_e, head_partial_off_e);
let partial_off_e = ctx.add_u32_reg(partial_off_e, chunk_partial_off_e);
let partial_off_bytes_e = ctx.mul_wide_u32_reg(partial_off_e, four);
let partial_base_e = ctx.add_u64(partials_ptr, partial_off_bytes_e);
let zero_u32_e = ctx.mov_u32_imm(0);
let is_lane0_e = ctx.setp_eq_u32(lane_id, zero_u32_e);
ctx.branch_if_not(is_lane0_e, "flash_decode_empty_done");
let neg_inf = ctx.mov_f32_imm(f32::NEG_INFINITY);
let max_off_e = ctx.mov_u32_imm(head_dim);
let max_off_bytes_e = ctx.mul_wide_u32_reg(max_off_e, four);
let max_addr_e = ctx.add_u64(partial_base_e, max_off_bytes_e);
ctx.st_global_f32(max_addr_e, neg_inf);
let zero = ctx.mov_f32_imm(0.0);
let sum_off_e = ctx.mov_u32_imm(head_dim + 1);
let sum_off_bytes_e = ctx.mul_wide_u32_reg(sum_off_e, four);
let sum_addr_e = ctx.add_u64(partial_base_e, sum_off_bytes_e);
ctx.st_global_f32(sum_addr_e, zero);
ctx.label("flash_decode_empty_done");
ctx.ret();
})
}
}
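/// Second phase of flash decoding: one block (one warp) per
/// (head, batch) pair walks that head's chunk partials, merges their
/// (max, sum) pairs with a log-sum-exp rescale, and writes the
/// normalized output vector.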
#[derive(Debug, Clone)]
pub struct FlashDecodingReduceKernel {
pub head_dim: u32,
pub num_heads: u32,
pub batch_size: u32,
pub chunk_size: u32,
}
impl FlashDecodingReduceKernel {
#[must_use]
pub fn new(head_dim: u32, num_heads: u32, batch_size: u32) -> Self {
Self {
head_dim,
num_heads,
batch_size,
chunk_size: FLASH_DECODE_CHUNK_SIZE,
}
}
}
impl Kernel for FlashDecodingReduceKernel {
fn name(&self) -> &str {
"flash_decoding_reduce"
}
fn build_ptx(&self) -> PtxKernel {
let head_dim = self.head_dim;
let num_heads = self.num_heads;
let chunk_size = self.chunk_size;
let _batch_size = self.batch_size;
PtxKernel::new("flash_decoding_reduce")
.param(PtxType::U64, "partials_ptr") .param(PtxType::U64, "output_ptr") .param(PtxType::U64, "seq_lens_ptr") .param(PtxType::U32, "max_chunks") .shared_memory(0)
.build(move |ctx| {
let head_idx = ctx.special_reg(PtxReg::CtaIdX);
let batch_idx = ctx.special_reg(PtxReg::CtaIdY);
let lane_id = ctx.special_reg(PtxReg::TidX);
let partials_ptr = ctx.load_param_u64("partials_ptr");
let output_ptr = ctx.load_param_u64("output_ptr");
let seq_lens_ptr = ctx.load_param_u64("seq_lens_ptr");
let max_chunks = ctx.load_param_u32("max_chunks");
let four = ctx.mov_u32_imm(4);
let batch_idx_bytes = ctx.mul_wide_u32_reg(batch_idx, four);
let seq_len_addr = ctx.add_u64(seq_lens_ptr, batch_idx_bytes);
let seq_len = ctx.ld_global_u32(seq_len_addr);
let seq_plus_chunk_m1 = ctx.add_u32(seq_len, chunk_size - 1);
let num_chunks = ctx.div_u32(seq_plus_chunk_m1, chunk_size);
let head_dim_u32 = ctx.mov_u32_imm(head_dim);
let num_heads_u32 = ctx.mov_u32_imm(num_heads);
let head_dim_plus_2 = ctx.mov_u32_imm(head_dim + 2);
let partial_stride = ctx.mul_lo_u32(max_chunks, head_dim_plus_2);
let batch_partial_stride = ctx.mul_lo_u32(num_heads_u32, partial_stride);
let batch_partial_off = ctx.mul_lo_u32(batch_idx, batch_partial_stride);
let head_partial_off = ctx.mul_lo_u32(head_idx, partial_stride);
let partial_base_off = ctx.add_u32_reg(batch_partial_off, head_partial_off);
let partial_base_off_bytes = ctx.mul_wide_u32_reg(partial_base_off, four);
let partial_base = ctx.add_u64(partials_ptr, partial_base_off_bytes);
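                // Pass 1: global max over all chunk maxima for this
                // (batch, head).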
let global_max = ctx.mov_f32_imm(f32::NEG_INFINITY);
let chunk_iter = ctx.mov_u32_imm(0);
ctx.label("reduce_max_loop");
let max_loop_cond = ctx.setp_lt_u32(chunk_iter, num_chunks);
ctx.branch_if_not(max_loop_cond, "reduce_max_loop_end");
let chunk_off = ctx.mul_lo_u32(chunk_iter, head_dim_plus_2);
let max_elem_off = ctx.add_u32(chunk_off, head_dim);
let max_elem_off_bytes = ctx.mul_wide_u32_reg(max_elem_off, four);
let chunk_max_addr = ctx.add_u64(partial_base, max_elem_off_bytes);
let zero_lane = ctx.mov_u32_imm(0);
let is_lane0 = ctx.setp_eq_u32(lane_id, zero_lane);
let chunk_max =
ctx.ld_global_f32_predicated(chunk_max_addr, is_lane0, f32::NEG_INFINITY);
let chunk_max = ctx.shfl_idx_f32(chunk_max, 0, 0xFFFF_FFFF);
ctx.max_f32_inplace(global_max, chunk_max);
ctx.add_u32_inplace(chunk_iter, 1);
ctx.branch("reduce_max_loop");
ctx.label("reduce_max_loop_end");
let log2e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
let global_sum = ctx.mov_f32_imm(0.0);
let acc0 = ctx.mov_f32_imm(0.0);
let acc1 = ctx.mov_f32_imm(0.0);
let acc2 = ctx.mov_f32_imm(0.0);
let acc3 = ctx.mov_f32_imm(0.0);
let lane_plus_32 = ctx.add_u32(lane_id, 32);
let lane_plus_64 = ctx.add_u32(lane_id, 64);
let lane_plus_96 = ctx.add_u32(lane_id, 96);
let in_bounds0 = ctx.setp_lt_u32(lane_id, head_dim_u32);
let in_bounds1 = ctx.setp_lt_u32(lane_plus_32, head_dim_u32);
let in_bounds2 = ctx.setp_lt_u32(lane_plus_64, head_dim_u32);
let in_bounds3 = ctx.setp_lt_u32(lane_plus_96, head_dim_u32);
let chunk_iter2 = ctx.mov_u32_imm(0);
ctx.label("reduce_acc_loop");
let acc_loop_cond = ctx.setp_lt_u32(chunk_iter2, num_chunks);
ctx.branch_if_not(acc_loop_cond, "reduce_acc_loop_end");
let chunk_off2 = ctx.mul_lo_u32(chunk_iter2, head_dim_plus_2);
let max_elem_off2 = ctx.add_u32(chunk_off2, head_dim);
let sum_elem_off2 = ctx.add_u32(chunk_off2, head_dim);
let sum_elem_off2 = ctx.add_u32(sum_elem_off2, 1);
let max_off_bytes2 = ctx.mul_wide_u32_reg(max_elem_off2, four);
let sum_off_bytes2 = ctx.mul_wide_u32_reg(sum_elem_off2, four);
let chunk_max_addr2 = ctx.add_u64(partial_base, max_off_bytes2);
let chunk_sum_addr2 = ctx.add_u64(partial_base, sum_off_bytes2);
let chunk_max2 =
ctx.ld_global_f32_predicated(chunk_max_addr2, is_lane0, f32::NEG_INFINITY);
let chunk_sum2 = ctx.ld_global_f32_predicated(chunk_sum_addr2, is_lane0, 0.0);
let chunk_max2 = ctx.shfl_idx_f32(chunk_max2, 0, 0xFFFF_FFFF);
let chunk_sum2 = ctx.shfl_idx_f32(chunk_sum2, 0, 0xFFFF_FFFF);
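                // Skip sentinel records from empty chunks; the chunk kernel
                // writes max = -inf there, so compare against a large negative
                // finite threshold rather than -inf itself.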
let neg_inf_check = ctx.mov_f32_imm(-1e30);
let is_valid = ctx.setp_gt_f32(chunk_max2, neg_inf_check);
ctx.branch_if_not(is_valid, "reduce_skip_chunk");
let max_diff = ctx.sub_f32(chunk_max2, global_max);
let max_diff_log2 = ctx.mul_f32(max_diff, log2e);
let scale_factor = ctx.ex2_f32(max_diff_log2);
let scaled_sum = ctx.mul_f32(chunk_sum2, scale_factor);
ctx.add_f32_inplace(global_sum, scaled_sum);
let chunk_base_off_bytes = ctx.mul_wide_u32_reg(chunk_off2, four);
let chunk_base = ctx.add_u64(partial_base, chunk_base_off_bytes);
let out0_off_bytes = ctx.mul_wide_u32_reg(lane_id, four);
let out0_addr = ctx.add_u64(chunk_base, out0_off_bytes);
let out0 = ctx.ld_global_f32_predicated(out0_addr, in_bounds0, 0.0);
let scaled_out0 = ctx.mul_f32(out0, scale_factor);
ctx.add_f32_inplace(acc0, scaled_out0);
let out1_off_bytes = ctx.mul_wide_u32_reg(lane_plus_32, four);
let out1_addr = ctx.add_u64(chunk_base, out1_off_bytes);
let out1 = ctx.ld_global_f32_predicated(out1_addr, in_bounds1, 0.0);
let scaled_out1 = ctx.mul_f32(out1, scale_factor);
ctx.add_f32_inplace(acc1, scaled_out1);
let out2_off_bytes = ctx.mul_wide_u32_reg(lane_plus_64, four);
let out2_addr = ctx.add_u64(chunk_base, out2_off_bytes);
let out2 = ctx.ld_global_f32_predicated(out2_addr, in_bounds2, 0.0);
let scaled_out2 = ctx.mul_f32(out2, scale_factor);
ctx.add_f32_inplace(acc2, scaled_out2);
let out3_off_bytes = ctx.mul_wide_u32_reg(lane_plus_96, four);
let out3_addr = ctx.add_u64(chunk_base, out3_off_bytes);
let out3 = ctx.ld_global_f32_predicated(out3_addr, in_bounds3, 0.0);
let scaled_out3 = ctx.mul_f32(out3, scale_factor);
ctx.add_f32_inplace(acc3, scaled_out3);
ctx.label("reduce_skip_chunk");
ctx.add_u32_inplace(chunk_iter2, 1);
ctx.branch("reduce_acc_loop");
ctx.label("reduce_acc_loop_end");
let one = ctx.mov_f32_imm(1.0);
let inv_sum = ctx.div_f32(one, global_sum);
ctx.mul_f32_inplace(acc0, inv_sum);
ctx.mul_f32_inplace(acc1, inv_sum);
ctx.mul_f32_inplace(acc2, inv_sum);
ctx.mul_f32_inplace(acc3, inv_sum);
let batch_head_stride = ctx.mul_lo_u32(num_heads_u32, head_dim_u32);
let batch_off = ctx.mul_lo_u32(batch_idx, batch_head_stride);
let head_off = ctx.mul_lo_u32(head_idx, head_dim_u32);
let out_base_off = ctx.add_u32_reg(batch_off, head_off);
let out_base_off_bytes = ctx.mul_wide_u32_reg(out_base_off, four);
let out_base = ctx.add_u64(output_ptr, out_base_off_bytes);
let final_out0_addr = ctx.add_u64(out_base, out0_off_bytes);
ctx.branch_if_not(in_bounds0, "reduce_skip_store0");
ctx.st_global_f32(final_out0_addr, acc0);
ctx.label("reduce_skip_store0");
let final_out1_addr = ctx.add_u64(out_base, out1_off_bytes);
ctx.branch_if_not(in_bounds1, "reduce_skip_store1");
ctx.st_global_f32(final_out1_addr, acc1);
ctx.label("reduce_skip_store1");
let final_out2_addr = ctx.add_u64(out_base, out2_off_bytes);
ctx.branch_if_not(in_bounds2, "reduce_skip_store2");
ctx.st_global_f32(final_out2_addr, acc2);
ctx.label("reduce_skip_store2");
let final_out3_addr = ctx.add_u64(out_base, out3_off_bytes);
ctx.branch_if_not(in_bounds3, "reduce_skip_store3");
ctx.st_global_f32(final_out3_addr, acc3);
ctx.label("reduce_skip_store3");
ctx.ret();
})
}
}