use metaltile::{bench_kernel, kernel};
/// Generates a dequantization kernel for code widths that divide 32 evenly.
///
/// Macro arguments:
/// * `$name`  - identifier of the generated kernel function.
/// * `$bits`  - width of one packed code in bits (a `u32` literal). Each
///   `u32` word is treated as holding exactly `32 / $bits` codes, stored
///   least-significant code first.
/// * `$subop` - string forwarded to the `subop` field of `#[bench_kernel]`.
///
/// Generated kernel, per program instance `(axis 0, axis 1, axis 2)`:
/// * loads one `u32` from `packed` at `row * packed_width + axis0`, where
///   `row = axis2 * tokens + axis1`;
/// * loads one scale from `norms[row]`;
/// * for every code in that word whose dimension index is `< dim`, looks the
///   code up in `codebook`, multiplies by the scale, casts to `T`, and stores
///   it at `out[row * dim + dimension]`.
///
/// NOTE(review): the packed load is unconditional, so the launch grid on
/// axis 0 is assumed not to exceed `packed_width` - confirm with the launcher.
/// NOTE(review): `codebook` is assumed to hold at least `1 << $bits` entries.
macro_rules! aura_dequant_rotated_clean {
    ($name:ident, $bits:literal, $subop:literal) => {
        #[bench_kernel(op="aura", subop=$subop, class=GenericEmpty, tol=0.0, kernel_mode=Grid3D,)]
        #[kernel]
        pub fn $name<T>(
            packed: Tensor<u32>,
            norms: Tensor<f32>,
            codebook: Tensor<f32>,
            mut out: Tensor<T>,
            #[constexpr] dim: u32,
            #[constexpr] packed_width: u32,
            #[constexpr] tokens: u32,
        ) {
            // Grid coordinates: axis 0 selects the packed word within a row,
            // axis 1 the token, axis 2 the outer group (presumably a flattened
            // batch x head index - confirm with the launcher).
            let word_slot = program_id::<0>();
            let token_id = program_id::<1>();
            let outer_id = program_id::<2>();

            // Compile-time properties of the packing format.
            let codes_per_word = 32u32 / $bits;
            let code_mask = (1u32 << $bits) - 1u32;

            // Flattened (outer, token) row shared by `packed`, `norms` and `out`.
            let row_idx = outer_id * tokens + token_id;
            let scale = load(norms[row_idx]);
            let packed_word = load(packed[row_idx * packed_width + word_slot]);

            // First dimension covered by this word, and where it lands in `out`.
            let first_dim = word_slot * codes_per_word;
            let out_base = row_idx * dim + first_dim;

            for lane in range(0u32, codes_per_word, 1u32) {
                let dim_idx = first_dim + lane;
                // The last word of a row may carry padding codes past `dim`.
                if dim_idx < dim {
                    let code = (packed_word >> (lane * $bits)) & code_mask;
                    let centroid = load(codebook[code]);
                    let scaled = centroid * scale;
                    store(out[out_base + lane], scaled.cast::<T>());
                }
            }
        }
    };
}
/// Generates a dequantization kernel for code widths that do not divide 32.
///
/// Macro arguments:
/// * `$name`  - identifier of the generated kernel function.
/// * `$bits`  - width of one packed code in bits (a `u32` literal).
/// * `$subop` - string forwarded to the `subop` field of `#[bench_kernel]`.
///
/// Each row of `packed` is read as one contiguous bitstream: the code for
/// dimension `d` starts at bit `d * $bits` of the row, so a code may begin in
/// one `u32` word and finish in the next. Both pieces are loaded and stitched
/// back together, low bits first.
///
/// Generated kernel, per program instance `(axis 0, axis 1, axis 2)`:
/// * handles `ceil(32 / $bits)` consecutive dimensions starting at
///   `axis0 * ceil(32 / $bits)`; here axis 0 picks a span of dimensions, not
///   a single packed word;
/// * uses `row = axis2 * tokens + axis1` to address `packed`
///   (`row * packed_width`), `norms` (`row`) and `out` (`row * dim`);
/// * for every handled dimension `< dim`, extracts the code, looks it up in
///   `codebook`, multiplies by `norms[row]`, casts to `T`, and stores it.
///
/// NOTE(review): `codebook` is assumed to hold at least `1 << $bits` entries.
macro_rules! aura_dequant_rotated_odd {
    ($name:ident, $bits:literal, $subop:literal) => {
        #[bench_kernel(op="aura", subop=$subop, class=GenericEmpty, tol=0.0, kernel_mode=Grid3D,)]
        #[kernel]
        pub fn $name<T>(
            packed: Tensor<u32>,
            norms: Tensor<f32>,
            codebook: Tensor<f32>,
            mut out: Tensor<T>,
            #[constexpr] dim: u32,
            #[constexpr] packed_width: u32,
            #[constexpr] tokens: u32,
        ) {
            // Grid coordinates: axis 0 selects a span of dimensions, axis 1 the
            // token, axis 2 the outer group (presumably a flattened batch x head
            // index - confirm with the launcher).
            let span_id = program_id::<0>();
            let token_id = program_id::<1>();
            let outer_id = program_id::<2>();

            // Number of dimensions per program: 32 / $bits rounded up.
            let dims_per_span = (32u32 + $bits - 1u32) / $bits;
            let code_mask = (1u32 << $bits) - 1u32;

            // Row offsets into each tensor; invariant across the loop below.
            let row_idx = outer_id * tokens + token_id;
            let packed_row = row_idx * packed_width;
            let out_row = row_idx * dim;
            let scale = load(norms[row_idx]);

            let first_dim = span_id * dims_per_span;
            for step in range(0u32, dims_per_span, 1u32) {
                let dim_idx = first_dim + step;
                if dim_idx < dim {
                    // Locate the code inside the row's bitstream.
                    let bit_pos = dim_idx * $bits;
                    let lo_word_idx = bit_pos / 32u32;
                    let shift = bit_pos & 31u32;

                    // Split the code into the part that fits in the first word
                    // and the part (possibly empty) that spills into the next.
                    let avail = 32u32 - shift;
                    let lo_width = select(avail >= $bits, $bits, avail);
                    let hi_width = $bits - lo_width;

                    // Without spill the second load re-reads the first word, so
                    // it never indexes past the word that holds the code.
                    let hi_word_idx = select(hi_width > 0u32, lo_word_idx + 1u32, lo_word_idx);
                    let lo_word = load(packed[packed_row + lo_word_idx]);
                    let hi_word = load(packed[packed_row + hi_word_idx]);

                    // With no spill `hi_mask` is zero, so `hi_part` contributes nothing.
                    let lo_mask = (1u32 << lo_width) - 1u32;
                    let hi_mask = (1u32 << hi_width) - 1u32;
                    let lo_part = (lo_word >> shift) & lo_mask;
                    let hi_part = (hi_word & hi_mask) << lo_width;
                    let code = (lo_part | hi_part) & code_mask;

                    let centroid = load(codebook[code]);
                    let scaled = centroid * scale;
                    store(out[out_row + dim_idx], scaled.cast::<T>());
                }
            }
        }
    };
}
// Kernel instantiations. Each line expands to one `pub fn` named by the first
// argument, with the code width as the second argument and the `subop` string
// passed to `#[bench_kernel]` as the third.

// Widths 2, 4 and 8 divide 32 evenly, so every code sits inside a single
// `u32` word and the `clean` generator is used.
aura_dequant_rotated_clean!(aura_dequant_rotated_int2, 2u32, "dequant_rotated_int2");
aura_dequant_rotated_clean!(aura_dequant_rotated_int4, 4u32, "dequant_rotated_int4");
aura_dequant_rotated_clean!(aura_dequant_rotated_int8, 8u32, "dequant_rotated_int8");
// Widths 3 and 6 do not divide 32, so a code can cross a word boundary and
// the `odd` generator (which reads two words per code) is used.
aura_dequant_rotated_odd!(aura_dequant_rotated_int3, 3u32, "dequant_rotated_int3");
aura_dequant_rotated_odd!(aura_dequant_rotated_int6, 6u32, "dequant_rotated_int6");