use metaltile::{bench_kernel, kernel};
// Int4 dequantizing matrix-vector product against one expert's weight slab.
//
// For the output row given by `program_id::<0>()`, this accumulates
//     sum over k of (q[k] * scale + bias) * input[k]
// in f32 and writes the result, cast back to `T`, into `output[row]`.
//
// Indexing used by the body:
//   * `expert_index[0]` holds the expert to use.
//   * `weights_stacked` is addressed as
//     `expert * out_dim * (in_dim / 8) + row * (in_dim / 8) + word`,
//     with eight 4-bit values per `u32` word, lowest nibble first.
//   * `scales_stacked` / `biases_stacked` are addressed as
//     `expert * out_dim * (in_dim / group_size) + row * (in_dim / group_size) + group`.
//   * `input` is addressed directly by the element index `word * 8 + nibble`.
//
// NOTE(review): the integer divisions truncate, so this appears to assume
// `in_dim` is a multiple of both 8 and `group_size`, and `group_size` is a
// multiple of 8 — confirm against callers.
//
// NOTE(review): `tid` and `lsize` are not declared in this function;
// presumably the `#[kernel]` macro supplies them as the thread index within
// the group and the group size — TODO confirm.
#[bench_kernel(
    op="dequant_gemv_expert_indexed",
    subop="int4",
    class=GenericEmpty,
    tol=0.0,
    kernel_mode=Reduction,
)]
#[kernel]
pub fn dequant_gemv_int4_expert_indexed<T>(
    weights_stacked: Tensor<u32>,
    scales_stacked: Tensor<T>,
    biases_stacked: Tensor<T>,
    input: Tensor<T>,
    expert_index: Tensor<u32>,
    output: Tensor<T>,
    #[constexpr] in_dim: u32,
    #[constexpr] out_dim: u32,
    #[constexpr] group_size: u32,
) {
    // Packing constants: eight 4-bit values live in each u32 word.
    let nibble_mask = 0xFu32;
    let nibbles_per_word = 8u32;

    // Which output row and which expert this invocation works on.
    let out_row = program_id::<0>();
    let selected = load(expert_index[0u32]);

    // Per-row extents, in packed words and in quantization groups.
    let words_per_row = in_dim / nibbles_per_word;
    let groups_per_row = in_dim / group_size;
    let words_per_group = group_size / nibbles_per_word;

    // Start of this (expert, row) slice in the weight buffer and in the
    // scale/bias buffers.
    let weight_base = selected * out_dim * words_per_row + out_row * words_per_row;
    let group_base = selected * out_dim * groups_per_row + out_row * groups_per_row;

    // Per-thread partial dot product, kept in f32 regardless of `T`.
    let mut partial = 0.0f32;

    // Ceil-divide the row's words by `lsize`; each pass handles word
    // `pass * lsize + tid`.
    let passes = (words_per_row + lsize - 1u32) / lsize;
    for pass in range(0u32, passes, 1u32) {
        let word_idx = pass * lsize + tid;
        // The final pass may run past the end of the row.
        if word_idx < words_per_row {
            let word = load(weights_stacked[weight_base + word_idx]);

            // All eight values in a word share one group's scale and bias.
            let group = word_idx / words_per_group;
            let scale = load(scales_stacked[group_base + group]).cast::<f32>();
            let bias = load(biases_stacked[group_base + group]).cast::<f32>();

            // Index of the first input element covered by this word.
            let elem_base = word_idx * nibbles_per_word;
            for lane in range(0u32, nibbles_per_word, 1u32) {
                let shift = lane * 4u32;
                let nibble = (word >> shift) & nibble_mask;
                // Affine dequantization, then multiply by the matching input.
                let weight = nibble.cast::<f32>() * scale + bias;
                let activation = load(input[elem_base + lane]).cast::<f32>();
                partial = partial + weight * activation;
            }
        }
    }

    // Combine the partial sums (presumably across the threads of the group —
    // `reduce_sum` is defined outside this file) and let `tid == 0` write.
    let row_sum = reduce_sum(partial);
    if tid == 0u32 {
        store(output[out_row], row_sum.cast::<T>());
    }
}