use super::AttentionKernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl, PtxMemory};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
impl AttentionKernel {
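    /// Emits a fused flash-attention kernel: one CTA per (Q tile, head) pair,
    /// streaming K/V tiles through shared memory and folding the softmax into
    /// the accumulation via the online (running max / running sum) recurrence.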
    pub(super) fn build_flash_attention(&self) -> PtxKernel {
        let head_dim = self.head_dim;
        let tile_q = self.tile_q;
        let tile_kv = self.tile_kv;
        let scale = self.scale;
        let causal = self.causal;
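        // Shared-memory layout, in f32 (4-byte) elements:
        // [ Q: tile_q * head_dim | K: tile_kv * head_dim | V: tile_kv * head_dim ]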
        let smem_size = (tile_q * head_dim + tile_kv * head_dim * 2) * 4;
        let kernel_name = if causal {
            "flash_attention_causal"
        } else {
            "flash_attention"
        };
        PtxKernel::new(kernel_name)
            .param(PtxType::U64, "q_ptr")
            .param(PtxType::U64, "k_ptr")
            .param(PtxType::U64, "v_ptr")
            .param(PtxType::U64, "o_ptr")
            .param(PtxType::U32, "seq_len")
            .param(PtxType::U32, "head_dim")
            .param(PtxType::U32, "num_heads")
            .shared_memory(smem_size as usize)
            .build(|ctx| {
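                // Grid mapping: ctaid.x picks the Q tile, ctaid.y the head.
                // The kernel is written for a launch with tile_q * head_dim
                // threads per block, one thread per output-tile element.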
                let tid = ctx.special_reg(PtxReg::TidX);
                let ctaid_x = ctx.special_reg(PtxReg::CtaIdX);
                let ctaid_y = ctx.special_reg(PtxReg::CtaIdY);
                let _ntid = ctx.special_reg(PtxReg::NtidX);
                let seq_len_param = ctx.load_param_u32("seq_len");
                let head_dim_param = ctx.load_param_u32("head_dim");
                let num_heads = ctx.load_param_u32("num_heads");
                let q_ptr = ctx.load_param_u64("q_ptr");
                let k_ptr = ctx.load_param_u64("k_ptr");
                let v_ptr = ctx.load_param_u64("v_ptr");
                let o_ptr = ctx.load_param_u64("o_ptr");
                let q_block = ctaid_x;
                let head_idx = ctaid_y;
                let head_valid = ctx.setp_lt_u32(head_idx, num_heads);
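                // Q, K, V, O are assumed packed as [num_heads, seq_len, head_dim]
                // row-major f32, so a head's slab begins at
                // head_idx * seq_len * head_dim elements.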
                let head_stride = ctx.mul_u32_reg(seq_len_param, head_dim_param);
                let head_offset = ctx.mul_wide_u32_reg(head_idx, head_stride);
                let head_offset_bytes = ctx.mul_u64(head_offset, 4);
                let tile_q_imm = ctx.mov_u32_imm(tile_q);
                let q_row_start = ctx.mul_u32_reg(q_block, tile_q_imm);
                let q_tile_offset = ctx.mul_wide_u32_reg(q_row_start, head_dim_param);
                let q_tile_offset_bytes = ctx.mul_u64(q_tile_offset, 4);
                let q_base = ctx.add_u64(q_ptr, head_offset_bytes);
                let q_tile_base = ctx.add_u64(q_base, q_tile_offset_bytes);
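                // Decompose the thread index: this thread owns output element
                // (local_row, local_col) of the Q tile. The build-time head_dim
                // is used here and is assumed to equal the head_dim parameter.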
                let local_row = ctx.div_u32(tid, head_dim);
                let local_col = ctx.rem_u32(tid, head_dim);
                let thread_valid = ctx.setp_lt_u32(local_row, tile_q_imm);
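                // Per-thread online-softmax state: m = running row maximum,
                // l = running denominator (sum of exponentials), o = output
                // accumulator, kept fully normalized (divided by l) throughout.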
                let o_acc = ctx.mov_f32_imm(0.0);
                let m_prev = ctx.mov_f32_imm(f32::NEG_INFINITY);
                let l_prev = ctx.mov_f32_imm(0.0);
                let tile_kv_imm = ctx.mov_u32_imm(tile_kv);
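                // Assumes seq_len is a multiple of tile_kv (and of tile_q);
                // ragged tail tiles would need extra bounds masking here.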
                let num_kv_blocks = ctx.div_u32(seq_len_param, tile_kv);
                let local_row_64 = ctx.cvt_u64_u32(local_row);
                let local_col_64 = ctx.cvt_u64_u32(local_col);
                let head_dim_64 = ctx.cvt_u64_u32(head_dim_param);
                let q_elem_offset = ctx.mul_u64_reg(local_row_64, head_dim_64);
                let q_elem_offset_full = ctx.add_u64(q_elem_offset, local_col_64);
                let q_elem_offset_bytes = ctx.mul_u64(q_elem_offset_full, 4);
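                // Byte offsets of the K and V tiles within shared memory,
                // laid out directly after the Q tile (see layout above).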
                let k_smem_base = tile_q * head_dim * 4;
                let v_smem_base = (tile_q * head_dim + tile_kv * head_dim) * 4;
                let kv_block = ctx.mov_u32_imm(0);
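                // Outer loop: stream KV tiles through shared memory, one
                // block of tile_kv rows per iteration.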
ctx.label("kv_loop_start");
let kv_done = ctx.setp_ge_u32(kv_block, num_kv_blocks);
ctx.branch_if(kv_done, "kv_loop_end");
if causal {
let causal_skip = ctx.setp_lt_u32(q_block, kv_block);
ctx.branch_if(causal_skip, "kv_loop_end");
}
let kv_row_start = ctx.mul_u32_reg(kv_block, tile_kv_imm);
let kv_tile_offset = ctx.mul_wide_u32_reg(kv_row_start, head_dim_param);
let kv_tile_offset_bytes = ctx.mul_u64(kv_tile_offset, 4);
                let k_base = ctx.add_u64(k_ptr, head_offset_bytes);
                let k_tile_base = ctx.add_u64(k_base, kv_tile_offset_bytes);
                let v_base = ctx.add_u64(v_ptr, head_offset_bytes);
                let v_tile_base = ctx.add_u64(v_base, kv_tile_offset_bytes);
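                // Stage this thread's Q element into shared memory. Q is
                // re-read from global memory on every KV iteration; hoisting
                // the load out of the loop would be a straightforward saving.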
                let q_addr = ctx.add_u64(q_tile_base, q_elem_offset_bytes);
                let q_val = ctx.ld_global_f32(q_addr);
                let q_smem_offset = ctx.mul_u32(tid, 4);
                ctx.st_shared_f32(q_smem_offset, q_val);
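                // Cooperative K-tile load: the block strides across all
                // tile_kv * head_dim elements, advancing by the block size
                // (tile_q * head_dim threads) per step.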
                let kv_total_reg = ctx.mov_u32_imm(tile_kv * head_dim);
                let k_load_base = ctx.mov_u32_imm(0);
                ctx.label("k_coop_load");
                let k_elem_idx = ctx.add_u32_reg(k_load_base, tid);
                let k_in_bounds = ctx.setp_lt_u32(k_elem_idx, kv_total_reg);
                ctx.branch_if_not(k_in_bounds, "k_coop_load_end");
                let k_offset = ctx.mul_wide_u32(k_elem_idx, 4);
                let k_addr = ctx.add_u64(k_tile_base, k_offset);
                let k_val = ctx.ld_global_f32(k_addr);
                let k_smem_base_u32 = ctx.mov_u32_imm(k_smem_base);
                let k_elem_bytes = ctx.mul_u32(k_elem_idx, 4);
                let k_smem_off = ctx.add_u32_reg(k_smem_base_u32, k_elem_bytes);
                ctx.st_shared_f32(k_smem_off, k_val);
                ctx.add_u32_inplace(k_load_base, tile_q * head_dim);
                ctx.branch("k_coop_load");
                ctx.label("k_coop_load_end");
                let v_load_base = ctx.mov_u32_imm(0);
                ctx.label("v_coop_load");
                let v_elem_idx = ctx.add_u32_reg(v_load_base, tid);
                let v_in_bounds = ctx.setp_lt_u32(v_elem_idx, kv_total_reg);
                ctx.branch_if_not(v_in_bounds, "v_coop_load_end");
                let v_offset = ctx.mul_wide_u32(v_elem_idx, 4);
                let v_addr = ctx.add_u64(v_tile_base, v_offset);
                let v_val = ctx.ld_global_f32(v_addr);
                let v_smem_base_u32 = ctx.mov_u32_imm(v_smem_base);
                let v_elem_bytes = ctx.mul_u32(v_elem_idx, 4);
                let v_smem_off = ctx.add_u32_reg(v_smem_base_u32, v_elem_bytes);
                ctx.st_shared_f32(v_smem_off, v_val);
                ctx.add_u32_inplace(v_load_base, tile_q * head_dim);
                ctx.branch("v_coop_load");
                ctx.label("v_coop_load_end");
                ctx.bar_sync(0);
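                // Inner loop: score each K row of the tile against this
                // thread's Q row and fold the result into the running softmax.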
                let k_row = ctx.mov_u32_imm(0);
                ctx.label("k_row_loop_start");
                let k_row_done = ctx.setp_ge_u32(k_row, tile_kv_imm);
                ctx.branch_if(k_row_done, "k_row_loop_end");
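                // Per-row causal mask: a query attends only to keys at or
                // before its own sequence position.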
                if causal {
                    let q_global_row = ctx.add_u32_reg(q_row_start, local_row);
                    let k_global_row = ctx.add_u32_reg(kv_row_start, k_row);
                    let causal_mask = ctx.setp_lt_u32(q_global_row, k_global_row);
                    ctx.branch_if(causal_mask, "k_row_next");
                }
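                // Dot product of Q[local_row] and K[k_row] from shared memory.
                // All head_dim threads of a row compute the same score
                // redundantly; each then already holds p for its own column,
                // so the P * V step below needs no cross-thread reduction.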
                let s_acc = ctx.mov_f32_imm(0.0);
                let d_idx = ctx.mov_u32_imm(0);
                let head_dim_u32 = ctx.mov_u32_imm(head_dim);
                ctx.label("dot_loop_start");
                let d_done = ctx.setp_ge_u32(d_idx, head_dim_param);
                ctx.branch_if(d_done, "dot_loop_end");
                let q_row_offset = ctx.mul_u32_reg(local_row, head_dim_u32);
                let q_elem_smem = ctx.add_u32_reg(q_row_offset, d_idx);
                let q_elem_smem_bytes = ctx.mul_u32(q_elem_smem, 4);
                let q_dot_val = ctx.ld_shared_f32(q_elem_smem_bytes);
                let k_row_offset = ctx.mul_u32_reg(k_row, head_dim_u32);
                let k_elem_smem = ctx.add_u32_reg(k_row_offset, d_idx);
                let k_elem_smem_bytes = ctx.mul_u32(k_elem_smem, 4);
                let k_smem_base_loop = ctx.mov_u32_imm(k_smem_base);
                let k_elem_smem_full = ctx.add_u32_reg(k_smem_base_loop, k_elem_smem_bytes);
                let k_dot_val = ctx.ld_shared_f32(k_elem_smem_full);
                ctx.fma_f32_inplace(s_acc, q_dot_val, k_dot_val);
                ctx.add_u32_inplace(d_idx, 1);
                ctx.branch("dot_loop_start");
                ctx.label("dot_loop_end");
                let scale_reg = ctx.mov_f32_imm(scale);
                let s_scaled = ctx.mul_f32(s_acc, scale_reg);
                let m_new = ctx.max_f32(m_prev, s_scaled);
                let m_diff = ctx.sub_f32(m_prev, m_new);
                let log2_e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
                let m_diff_scaled = ctx.mul_f32(m_diff, log2_e);
                let scale_factor = ctx.ex2_f32(m_diff_scaled);
                let s_shifted = ctx.sub_f32(s_scaled, m_new);
                let s_shifted_scaled = ctx.mul_f32(s_shifted, log2_e);
                let p_val = ctx.ex2_f32(s_shifted_scaled);
                let l_scaled = ctx.mul_f32(scale_factor, l_prev);
                let l_new = ctx.add_f32(l_scaled, p_val);
                let o_scaled = ctx.mul_f32(o_acc, scale_factor);
                let o_weighted = ctx.mul_f32(o_scaled, l_prev);
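                // Gather this thread's V element for the current K row and
                // fold p * v into the accumulator, renormalizing by l_new.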
                let v_row_offset = ctx.mul_u32_reg(k_row, head_dim_u32);
                let v_elem_idx = ctx.add_u32_reg(v_row_offset, local_col);
                let v_elem_smem_bytes = ctx.mul_u32(v_elem_idx, 4);
                let v_smem_base_loop = ctx.mov_u32_imm(v_smem_base);
                let v_elem_smem_full = ctx.add_u32_reg(v_smem_base_loop, v_elem_smem_bytes);
                let v_out_val = ctx.ld_shared_f32(v_elem_smem_full);
                let pv = ctx.mul_f32(p_val, v_out_val);
                let o_sum = ctx.add_f32(o_weighted, pv);
                let o_new = ctx.div_f32(o_sum, l_new);
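                // Commit the updated running statistics for the next K row.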
                ctx.mov_f32_reg(m_prev, m_new);
                ctx.mov_f32_reg(l_prev, l_new);
                ctx.mov_f32_reg(o_acc, o_new);
ctx.label("k_row_next");
ctx.add_u32_inplace(k_row, 1);
ctx.branch("k_row_loop_start");
ctx.label("k_row_loop_end");
                ctx.bar_sync(1);
                ctx.add_u32_inplace(kv_block, 1);
                ctx.branch("kv_loop_start");
                ctx.label("kv_loop_end");
                ctx.branch_if_not(head_valid, "exit");
                ctx.branch_if_not(thread_valid, "exit");
                let o_base = ctx.add_u64(o_ptr, head_offset_bytes);
                let o_tile_offset = ctx.mul_wide_u32_reg(q_row_start, head_dim_param);
                let o_tile_offset_bytes = ctx.mul_u64(o_tile_offset, 4);
                let o_tile_base = ctx.add_u64(o_base, o_tile_offset_bytes);
                let o_addr = ctx.add_u64(o_tile_base, q_elem_offset_bytes);
                ctx.st_global_f32(o_addr, o_acc);
                ctx.label("exit");
                ctx.ret();
            })
    }
}