use metaltile::{bench_kernel, kernel};
use crate::bench_types::DType;
// Anonymous constant whose only visible effect is to reference `DType`.
// NOTE(review): presumably this keeps the `DType` import "used" (or in scope
// for code emitted by `#[bench_kernel]`) — confirm before removing.
const _: DType = DType::F32;
/// Generates one `aura` score kernel for a fixed packed-code width.
///
/// Macro arguments:
/// * `$name`  - identifier of the generated kernel function.
/// * `$bits`  - width, in bits, of each packed code (a `u32` literal).
/// * `$subop` - string literal forwarded to `#[bench_kernel]` as `subop`.
///
/// For the pair (`tgid_x`, `tgid_y`) the generated kernel accumulates, over
/// every `d < dim`, `q_rot[tgid_x * dim + d] * codebook[code(d)]`, where
/// `code(d)` is the `$bits`-wide field starting at bit `d * $bits` of the
/// selected row of `packed`. The lane partial sums are combined with
/// `simd_sum`, multiplied by the matching entry of `norms`, and lane 0 stores
/// the result to `scores[tgid_x * tokens + tgid_y]`.
macro_rules! aura_score_kernel {
    ($name:ident, $bits:literal, $subop:literal) => {
        #[bench_kernel(op="aura", subop=$subop, class=GenericEmpty, tol=0.0, kernel_mode=Reduction,)]
        #[kernel]
        pub fn $name<T>(
            q_rot: Tensor<T>,
            packed: Tensor<u32>,
            norms: Tensor<T>,
            codebook: Tensor<T>,
            mut scores: Tensor<T>,
            #[constexpr] dim: u32,
            #[constexpr] packed_width: u32,
            #[constexpr] tokens: u32,
            #[constexpr] repeat_count: u32,
        ) {
            // Work assignment: grid x picks the query row, grid y picks the
            // packed row within the group, `tid` is the lane inside the group.
            let query = tgid_x;
            let token = tgid_y;
            let lane_id = tid;

            // `repeat_count` consecutive query rows share one packed group.
            let kv_row = query / repeat_count;
            let kv_token = kv_row * tokens + token;
            let row_base = kv_token * packed_width;
            let query_base = query * dim;

            // Keeps only the low `$bits` bits of an extracted code.
            let code_mask = (1u32 << $bits) - 1u32;
            let scale = load(norms[kv_token]).cast::<f32>();

            // Lanes walk the `dim` axis with a stride of 32, rounding up so a
            // trailing partial chunk is still visited.
            // NOTE(review): assumes `tid` spans 0..32 (one SIMD group per
            // threadgroup) — confirm against the dispatch configuration.
            let chunk_count = (dim + 31u32) / 32u32;
            let mut partial = 0.0f32;
            for chunk in range(0u32, chunk_count, 1u32) {
                let elem = chunk * 32u32 + lane_id;
                if elem < dim {
                    // Locate the code inside the packed row.
                    let first_bit = elem * $bits;
                    let word_pos = first_bit / 32u32;
                    let bit_shift = first_bit & 31u32;

                    // A code may straddle two 32-bit words: `low_len` bits
                    // come from the first word and `high_len` from the next.
                    // NOTE(review): reads `select(c, a, b)` as "a if c else b",
                    // which makes `low_len` = min($bits, room) — confirm.
                    let room = 32u32 - bit_shift;
                    let low_len = select(room >= $bits, $bits, room);
                    let high_len = $bits - low_len;

                    // When nothing spills, the second load re-reads the same
                    // word instead of touching `word_pos + 1`.
                    let next_pos = select(high_len > 0u32, word_pos + 1u32, word_pos);
                    let word_lo = load(packed[row_base + word_pos]);
                    let word_hi = load(packed[row_base + next_pos]);

                    // With `high_len == 0` the high mask is zero, so the
                    // second word contributes nothing.
                    let low_mask = (1u32 << low_len) - 1u32;
                    let high_mask = (1u32 << high_len) - 1u32;
                    let low_part = (word_lo >> bit_shift) & low_mask;
                    let high_part = (word_hi & high_mask) << low_len;
                    let code = (low_part | high_part) & code_mask;

                    let centroid = load(codebook[code]).cast::<f32>();
                    let q_val = load(q_rot[query_base + elem]).cast::<f32>();
                    partial = partial + q_val * centroid;
                }
            }

            // Combine the lane partial sums; only lane 0 writes the score.
            let reduced = simd_sum(partial);
            if lane_id == 0u32 {
                store(scores[query * tokens + token], (reduced * scale).cast::<T>());
            }
        }
    };
}
// Instantiate the score kernel for each packed-code width: 2, 3, 4, 6 and 8 bits.
// Widths 2, 4 and 8 divide 32, so a code never crosses a 32-bit word boundary;
// widths 3 and 6 do not, so those variants exercise the two-word read path.
aura_score_kernel!(aura_score_int2, 2u32, "score_int2");
aura_score_kernel!(aura_score_int3, 3u32, "score_int3");
aura_score_kernel!(aura_score_int4, 4u32, "score_int4");
aura_score_kernel!(aura_score_int6, 6u32, "score_int6");
aura_score_kernel!(aura_score_int8, 8u32, "score_int8");