use metaltile::{bench_kernel, kernel};
use crate::bench_types::DType;
// Anonymous constant whose only effect is to name `DType::F32`.
// NOTE(review): presumably kept so the `DType` import above stays referenced — confirm.
const _: DType = DType::F32;
// Generates one `#[kernel]` function that computes, for a single
// (lane, query row, token block) program, a partial attention result over the
// tokens of that block using bit-packed, codebook-quantized keys and values.
//
// Each program keeps a running maximum (`m`), a running sum of exponentials
// (`l`) and an un-normalised weighted value accumulator (`o`), and writes them
// to `m_partials`, `l_partials` and `o_partials`.
// NOTE(review): the `p1` in the name suggests a first pass whose per-block
// partials are merged by a later step — confirm against the caller.
//
// Macro parameters:
// - `$name`          identifier of the generated function.
// - `$key_bits`      width in bits of one packed key code.
// - `$value_bits`    width in bits of one packed value code.
// - `$key_levels`    number of key codebook entries copied to the local array.
// - `$value_levels`  number of value codebook entries copied to the local array.
// - `$dims_per_lane` number of dimensions each lane handles (stride 32 apart).
// - `$causal`        `1u32` limits the token range to `q_position + 1`.
// - `$subop`         string forwarded to `#[bench_kernel(subop = ...)]`.
//
// NOTE(review): decoded codes are masked to `$key_bits` / `$value_bits` bits and
// used directly as codebook indices, so `$key_levels` / `$value_levels` should
// be at least `1 << bits` — confirm for any new instantiation.
//
// Indexing used by the generated kernel:
// - `q_rot[q_idx * dim + d]`
// - `key_packed[(kv_idx * tokens + t) * key_packed_width + word]`
// - `val_packed[(kv_idx * tokens + t) * value_packed_width + word]`
// - `key_norms[kv_idx * tokens + t]`, `val_norms[kv_idx * tokens + t]`
// - `o_partials[(q_idx * num_blocks + block_idx) * dim + d]`
// - `m_partials[q_idx * num_blocks + block_idx]` (same for `l_partials`)
macro_rules! aura_flash_p1_kernel {
    (
        $name:ident,
        $key_bits:literal,
        $value_bits:literal,
        $key_levels:literal,
        $value_levels:literal,
        $dims_per_lane:literal,
        $causal:literal,
        $subop:literal
    ) => {
        #[bench_kernel(op="aura", subop=$subop, class=GenericEmpty, tol=0.0, kernel_mode=Grid3D,)]
        #[kernel]
        pub fn $name<T>(
            q_rot: Tensor<T>,
            key_packed: Tensor<u32>,
            key_norms: Tensor<T>,
            key_codebook: Tensor<T>,
            val_packed: Tensor<u32>,
            val_norms: Tensor<T>,
            val_codebook: Tensor<T>,
            mut o_partials: Tensor<T>,
            mut m_partials: Tensor<T>,
            mut l_partials: Tensor<T>,
            #[constexpr] dim: u32,
            #[constexpr] key_packed_width: u32,
            #[constexpr] value_packed_width: u32,
            #[constexpr] tokens: u32,
            #[constexpr] repeat_count: u32,
            #[constexpr] num_blocks: u32,
            #[constexpr] block_size: u32,
            #[constexpr] q_position: u32,
        ) {
            // Program coordinates: axis 0 = lane, axis 1 = query row, axis 2 = token block.
            let lane = program_id::<0>();
            let q_idx = program_id::<1>();
            let block_idx = program_id::<2>();
            // `repeat_count` consecutive query rows read the same key/value row.
            let kv_idx = q_idx / repeat_count;
            let key_mask = (1u32 << $key_bits) - 1u32;
            let val_mask = (1u32 << $value_bits) - 1u32;

            // Token range [t_start, t_end) of this block, clamped to `tokens` and,
            // for causal kernels, additionally to `q_position + 1`.
            let raw_end = block_idx * block_size + block_size;
            let clamped_end = select(raw_end > tokens, tokens, raw_end);
            let causal_end = select($causal == 1u32, q_position + 1u32, clamped_end);
            let t_end = select(causal_end < clamped_end, causal_end, clamped_end);
            let t_start = block_idx * block_size;

            // Copy both codebooks into local f32 arrays so the inner loops index them directly.
            stack_alloc("key_cb", $key_levels, "f32");
            for i in range(0u32, $key_levels, 1u32) {
                stack_store("key_cb", i, load(key_codebook[i]).cast::<f32>());
            }
            stack_alloc("val_cb", $value_levels, "f32");
            for i in range(0u32, $value_levels, 1u32) {
                stack_store("val_cb", i, load(val_codebook[i]).cast::<f32>());
            }

            // Cache this lane's slice of the query row: dimensions lane, lane + 32, ...
            stack_alloc("q_vals", $dims_per_lane, "f32");
            for i in range(0u32, $dims_per_lane, 1u32) {
                let d = lane + i * 32u32;
                // Zero-fill first and only load under the guard: passing the load as
                // an argument of `select` would issue it even for `d >= dim`, i.e.
                // with an index beyond this query row.
                stack_store("q_vals", i, 0.0f32);
                if d < dim {
                    stack_store("q_vals", i, load(q_rot[q_idx * dim + d]).cast::<f32>());
                }
            }

            // Running softmax state: maximum score, sum of exponentials, weighted values.
            let mut m_acc = neg_infinity();
            let mut l_acc = 0.0f32;
            stack_alloc("o", $dims_per_lane, "f32");
            for i in range(0u32, $dims_per_lane, 1u32) {
                stack_store("o", i, 0.0f32);
            }

            for t in range(t_start, t_end, 1u32) {
                let k_packed_row = (kv_idx * tokens + t) * key_packed_width;
                let k_norm = load(key_norms[kv_idx * tokens + t]).cast::<f32>();

                // Per-lane part of dot(q, key centroid values) for token `t`.
                let mut dot_partial = 0.0f32;
                for i in range(0u32, $dims_per_lane, 1u32) {
                    let d = lane + i * 32u32;
                    if d < dim {
                        // The code for dimension `d` starts at bit `d * bits` of the row
                        // and may straddle two 32-bit words: `lo_bits` come from the
                        // first word, the remaining `spill` bits from the next one.
                        let bit_offset = d * $key_bits;
                        let word_idx = bit_offset / 32u32;
                        let shift = bit_offset & 31u32;
                        let bits_in_w0 = 32u32 - shift;
                        let lo_bits = select(bits_in_w0 >= $key_bits, $key_bits, bits_in_w0);
                        let spill = $key_bits - lo_bits;
                        let w0 = load(key_packed[k_packed_row + word_idx]);
                        // Without spill, re-read the same word instead of the next one;
                        // its contribution is masked to zero below.
                        let w1_idx = select(spill > 0u32, word_idx + 1u32, word_idx);
                        let w1 = load(key_packed[k_packed_row + w1_idx]);
                        let lo = (w0 >> shift) & ((1u32 << lo_bits) - 1u32);
                        let hi = (w1 & ((1u32 << spill) - 1u32)) << lo_bits;
                        let value = (lo | hi) & key_mask;
                        let centroid = stack_load("key_cb", value);
                        let qv = stack_load("q_vals", i);
                        dot_partial = dot_partial + qv * centroid;
                    }
                }

                // NOTE(review): `simd_sum` presumably adds `dot_partial` across the
                // lanes of the SIMD group (the stride of 32 suggests 32 lanes) — confirm.
                let score = simd_sum(dot_partial) * k_norm;

                // Running-max update: `exp_diff` rescales everything accumulated so
                // far, `exp_score` is the weight of the current token.
                let new_m = select(m_acc > score, m_acc, score);
                let exp_diff = exp(m_acc - new_m);
                let exp_score = exp(score - new_m);

                let v_packed_row = (kv_idx * tokens + t) * value_packed_width;
                let v_norm = load(val_norms[kv_idx * tokens + t]).cast::<f32>();
                for i in range(0u32, $dims_per_lane, 1u32) {
                    let d = lane + i * 32u32;
                    if d < dim {
                        // Same two-word code extraction as for keys, using `$value_bits`.
                        let bit_offset = d * $value_bits;
                        let word_idx = bit_offset / 32u32;
                        let shift = bit_offset & 31u32;
                        let bits_in_w0 = 32u32 - shift;
                        let lo_bits = select(bits_in_w0 >= $value_bits, $value_bits, bits_in_w0);
                        let spill = $value_bits - lo_bits;
                        let w0 = load(val_packed[v_packed_row + word_idx]);
                        let w1_idx = select(spill > 0u32, word_idx + 1u32, word_idx);
                        let w1 = load(val_packed[v_packed_row + w1_idx]);
                        let lo = (w0 >> shift) & ((1u32 << lo_bits) - 1u32);
                        let hi = (w1 & ((1u32 << spill) - 1u32)) << lo_bits;
                        let value = (lo | hi) & val_mask;
                        let prev = stack_load("o", i);
                        let centroid = stack_load("val_cb", value);
                        // Rescale the old accumulator and add the weighted value.
                        let upd = prev * exp_diff + exp_score * centroid * v_norm;
                        stack_store("o", i, upd);
                    }
                }
                l_acc = l_acc * exp_diff + exp_score;
                m_acc = new_m;
            }

            // Write this lane's dimensions of the un-normalised output partial.
            let partial_base = (q_idx * num_blocks + block_idx) * dim;
            for i in range(0u32, $dims_per_lane, 1u32) {
                let d = lane + i * 32u32;
                if d < dim {
                    store(o_partials[partial_base + d], stack_load("o", i).cast::<T>());
                }
            }
            // `m` and `l` are stored once per (query row, block), by lane 0 only.
            if lane == 0u32 {
                let ml_idx = q_idx * num_blocks + block_idx;
                store(m_partials[ml_idx], m_acc.cast::<T>());
                store(l_partials[ml_idx], l_acc.cast::<T>());
            }
        }
    };
}
// Non-causal kernel with 4-bit key codes and 2-bit value codes.
// Four dimensions per lane at a stride of 32 cover 128 dimensions.
aura_flash_p1_kernel!(
aura_flash_p1_kb4_vb2_d128,
4u32, // key_bits
2u32, // value_bits
16u32, // key_levels
4u32, // value_levels
4u32, // dims_per_lane
0u32, // causal (off)
"flash_p1_kb4_vb2_d128"
);
// Non-causal kernel with 4-bit key codes and 4-bit value codes.
// Four dimensions per lane at a stride of 32 cover 128 dimensions.
aura_flash_p1_kernel!(
aura_flash_p1_kb4_vb4_d128,
4u32, // key_bits
4u32, // value_bits
16u32, // key_levels
16u32, // value_levels
4u32, // dims_per_lane
0u32, // causal (off)
"flash_p1_kb4_vb4_d128"
);
// Non-causal kernel with 4-bit key codes and 2-bit value codes.
// Two dimensions per lane at a stride of 32 cover 64 dimensions.
aura_flash_p1_kernel!(
aura_flash_p1_kb4_vb2_d64,
4u32, // key_bits
2u32, // value_bits
16u32, // key_levels
4u32, // value_levels
2u32, // dims_per_lane
0u32, // causal (off)
"flash_p1_kb4_vb2_d64"
);
// Non-causal kernel with 4-bit key codes and 4-bit value codes.
// Two dimensions per lane at a stride of 32 cover 64 dimensions.
aura_flash_p1_kernel!(
aura_flash_p1_kb4_vb4_d64,
4u32, // key_bits
4u32, // value_bits
16u32, // key_levels
16u32, // value_levels
2u32, // dims_per_lane
0u32, // causal (off)
"flash_p1_kb4_vb4_d64"
);
// Causal kernel (token range capped at `q_position + 1`) with 4-bit key codes
// and 2-bit value codes. Four dimensions per lane at a stride of 32 cover 128 dimensions.
aura_flash_p1_kernel!(
aura_flash_p1_causal_kb4_vb2_d128,
4u32, // key_bits
2u32, // value_bits
16u32, // key_levels
4u32, // value_levels
4u32, // dims_per_lane
1u32, // causal (on)
"flash_p1_causal_kb4_vb2_d128"
);
// Causal kernel (token range capped at `q_position + 1`) with 4-bit key codes
// and 2-bit value codes. Two dimensions per lane at a stride of 32 cover 64 dimensions.
aura_flash_p1_kernel!(
aura_flash_p1_causal_kb4_vb2_d64,
4u32, // key_bits
2u32, // value_bits
16u32, // key_levels
4u32, // value_levels
2u32, // dims_per_lane
1u32, // causal (on)
"flash_p1_causal_kb4_vb2_d64"
);