use metaltile::{bench_kernel, kernel};
// Generates one quantizing "encode" kernel per code width.
//
// Macro arguments:
//   $name   - identifier of the generated `pub fn`.
//   $bits   - number of bits each quantized index occupies in the packed output.
//   $levels - number of quantization levels; the kernel compares against
//             `$levels - 1` entries of `boundaries`.
//   $subop  - string forwarded to the `subop=` field of `#[bench_kernel]`.
//
// Per-thread flow of the generated kernel (thread `d` of row `row`):
//   1. load one input element and reduce the row's sum of squares,
//   2. scale the element by 1/norm and publish it to `shared_unit`,
//   3. compute `rotated` = dot(rotation row `d`, unit vector),
//   4. count how many boundaries `rotated` exceeds -> `idx`,
//   5. OR the low `$bits` bits of `idx` into a packed u32 word stream,
//   6. reduce the squared codebook values selected by `idx`, and have thread 0
//      write `norm / recon_norm` (or `norm` when recon_norm is ~0) to `norms_out`.
//
// NOTE(review): `tid` and `tgid_x` are not declared in this file; they are
// presumably the in-threadgroup thread index and the threadgroup index injected
// by `#[kernel]` - confirm against the metaltile DSL.
// NOTE(review): there is no `d < dim` guard, so the code appears to assume the
// threadgroup is launched with exactly `dim` threads - confirm at the call site.
macro_rules! aura_encode_kernel {
($name:ident, $bits:literal, $levels:literal, $subop:literal) => {
#[bench_kernel(op="aura", subop=$subop, class=GenericEmpty, tol=0.0, kernel_mode=Reduction,)]
#[kernel]
pub fn $name<T>(
// Source values, read at `row * dim + d`.
input: Tensor<T>,
// Read at `d * dim + j` for j in 0..dim, i.e. indexed as a dim x dim table.
rotation: Tensor<f32>,
// Thresholds; entries 0 .. $levels - 1 are read.
boundaries: Tensor<f32>,
// Reconstruction values, indexed by the quantized index.
codebook: Tensor<f32>,
// Packed indices, written at `row * packed_width + d` for d < packed_width.
mut packed_out: Tensor<u32>,
// One value per row, written by thread 0 only.
mut norms_out: Tensor<f32>,
#[constexpr] dim: u32,
#[constexpr] packed_width: u32,
) {
let d = tid;
let row = tgid_x;
// ---- Step 1: norm of the row (sum of squares, then sqrt) ----
let val = load(input[row * dim + d]).cast::<f32>();
let sq = val * val;
// Presumably a sum across the threads of one SIMD group - TODO confirm `simd_sum` semantics.
let simd_norm_sq = simd_sum(sq);
// NOTE(review): if the second argument is an element count, 16 slots with
// `sg_id = d / 32` means `d` must stay below 512 to remain in bounds.
threadgroup_alloc("shared_norm", 16);
let sg_id = d / 32u32;
let lane = d & 31u32;
// Only the first thread of each 32-thread group publishes that group's partial sum.
if lane == 0u32 {
threadgroup_store("shared_norm", sg_id, simd_norm_sq);
}
threadgroup_barrier();
// Every thread sums all per-group partials, so each one holds the full total.
let mut total_norm_sq = 0.0f32;
let num_groups = (dim + 31u32) / 32u32;
for i in range(0u32, num_groups, 1u32) {
total_norm_sq = total_norm_sq + threadgroup_load("shared_norm", i);
}
let norm_val = sqrt(total_norm_sq);
// Guard against dividing by a (near-)zero norm; `select(c, a, b)` is read here as `c ? a : b`.
let inv_norm = select(norm_val > 1.0e-8f32, 1.0f32 / norm_val, 0.0f32);
let unit_val = val * inv_norm;
// ---- Step 2/3: share the normalized row and apply row `d` of `rotation` ----
threadgroup_alloc("shared_unit", 1024);
threadgroup_store("shared_unit", d, unit_val);
// All normalized values must be visible before any thread reads them below.
threadgroup_barrier();
let mut rotated = 0.0f32;
for j in range(0u32, dim, 1u32) {
rotated =
rotated + load(rotation[d * dim + j]) * threadgroup_load("shared_unit", j);
}
// ---- Step 4: quantize - idx = number of boundaries strictly below `rotated` ----
let mut idx = 0u32;
for b in range(0u32, $levels - 1u32, 1u32) {
idx = idx + (rotated > load(boundaries[b])).cast::<u32>();
}
// ---- Step 5: pack `$bits` bits per thread into consecutive u32 words ----
let bit_offset = d * $bits;
let word_idx = bit_offset / 32u32;
let shift = bit_offset & 31u32;
// Keep only the low `$bits` bits of the index.
let masked = idx & ((1u32 << $bits) - 1u32);
// NOTE(review): 128 words caps `packed_width` (and `word_idx`) at 128 - confirm callers respect this.
threadgroup_alloc("shared_packed", 128, "u32");
// Zero the words that will be written out, since they are filled by OR-ing.
if d < packed_width {
threadgroup_store("shared_packed", d, 0u32);
}
threadgroup_barrier();
// Several threads target the same word, hence the atomic OR.
atomic_or_tg("shared_packed", word_idx, masked << shift);
let total_bits = shift + $bits;
// The field straddles a word boundary: its top `spill_u` bits go to the low end of the next word.
if total_bits > 32u32 {
let spill_u = total_bits - 32u32;
atomic_or_tg("shared_packed", word_idx + 1u32, masked >> ($bits - spill_u));
}
// Wait for every OR to land before copying the words out.
threadgroup_barrier();
if d < packed_width {
store(packed_out[row * packed_width + d], threadgroup_load("shared_packed", d));
}
// ---- Step 6: norm of the reconstructed (codebook) vector and corrected norm ----
let centroid_val = load(codebook[idx]);
let recon_sq = centroid_val * centroid_val;
let simd_recon_sq = simd_sum(recon_sq);
// `shared_norm` is reused; barriers sit between the earlier reads and this write.
if lane == 0u32 {
threadgroup_store("shared_norm", sg_id, simd_recon_sq);
}
threadgroup_barrier();
let mut total_recon_sq = 0.0f32;
for i in range(0u32, num_groups, 1u32) {
total_recon_sq = total_recon_sq + threadgroup_load("shared_norm", i);
}
let recon_norm = sqrt(total_recon_sq);
// Ratio of the input norm to the reconstruction norm; falls back to the input norm when the latter is ~0.
let corrected_norm = select(recon_norm > 1.0e-8f32, norm_val / recon_norm, norm_val);
// One write per row.
if d == 0u32 {
store(norms_out[row], corrected_norm);
}
}
};
}
// Kernel instantiations: (fn name, bits per index, level count, bench subop label).
// In every case the level count equals 2^bits (2->4, 3->8, 4->16, 6->64, 8->256).
// For 2, 4 and 8 bits a field never crosses a 32-bit word boundary (32 is a
// multiple of the width), so only the 3- and 6-bit variants can reach the
// spill branch of the packing step.
aura_encode_kernel!(aura_encode_int2, 2u32, 4u32, "encode_int2");
aura_encode_kernel!(aura_encode_int3, 3u32, 8u32, "encode_int3");
aura_encode_kernel!(aura_encode_int4, 4u32, 16u32, "encode_int4");
aura_encode_kernel!(aura_encode_int6, 6u32, 64u32, "encode_int6");
aura_encode_kernel!(aura_encode_int8, 8u32, 256u32, "encode_int8");