use metaltile::{bench_kernel, kernel};
// Generates a dequantize-and-GEMV kernel for bit widths where one u32 word holds a whole
// number of quantized values (`32 / $bits` values per word, no value straddles two words).
//
// Macro arguments:
//   $name  - identifier of the generated kernel function
//   $bits  - bits per quantized value, as a u32 literal
//   $subop - string forwarded to `bench_kernel` as `subop`
//
// Generated kernel parameters:
//   weight     - packed quantized weights, `in_dim / vals_per_pack` u32 words per row
//   scales     - per-group scale, `in_dim / group_size` entries per row
//   biases     - per-group bias, same indexing as `scales`
//   input      - input vector, read at indices `0..in_dim`
//   output     - one element written per row
//   in_dim     - number of input elements per row (constexpr)
//   group_size - number of consecutive elements sharing one scale/bias (constexpr)
//
// Per row the kernel computes sum over d of (q[d] * scale[g] + bias[g]) * input[d],
// accumulated in f32 and cast to `T` on store.
//
// NOTE(review): `tid` and `lsize` are not declared in this function; presumably they are the
// thread index and thread count injected by `#[kernel]` - confirm against the metaltile docs.
// NOTE(review): all divisions are truncating, so this presumably expects `in_dim` to be a
// multiple of both `vals_per_pack` and `group_size`, and `group_size >= vals_per_pack`
// (otherwise `packs_per_group` is 0 and the division computing `g` divides by zero) - confirm.
macro_rules! dequant_gemv_pow2 {
($name:ident, $bits:literal, $subop:literal) => {
#[bench_kernel(op="dequant_gemv", subop=$subop, class=GenericEmpty, tol=0.0, kernel_mode=Reduction,)]
#[kernel]
pub fn $name<T>(
weight: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
input: Tensor<T>,
output: Tensor<T>,
#[constexpr] in_dim: u32,
#[constexpr] group_size: u32,
) {
// Number of quantized values stored in one u32 word, and the mask selecting one value.
let vals_per_pack = 32u32 / $bits;
let mask = (1u32 << $bits) - 1u32;
// `program_id::<0>()` is used as the output row index.
let row = program_id::<0>();
let n_packs_per_row = in_dim / vals_per_pack;
let n_groups = in_dim / group_size;
let packs_per_group = group_size / vals_per_pack;
// Offsets of this row's first packed word and first scale/bias entry.
let row_pack_off = row * n_packs_per_row;
let row_group_off = row * n_groups;
let mut acc = 0.0f32;
// Ceil-divide so that `lsize` threads together visit every packed word of the row.
let p_iters = (n_packs_per_row + lsize - 1u32) / lsize;
for p_iter in range(0u32, p_iters, 1u32) {
// This thread handles words tid, tid + lsize, tid + 2 * lsize, ...
let pack_idx = p_iter * lsize + tid;
// Guard the tail: the last iteration may run past the end of the row.
if pack_idx < n_packs_per_row {
// Every value in one word uses the same group's scale and bias.
let g = pack_idx / packs_per_group;
let scale = load(scales[row_group_off + g]).cast::<f32>();
let bias = load(biases[row_group_off + g]).cast::<f32>();
let packed = load(weight[row_pack_off + pack_idx]);
// Index of the first input element covered by this word.
let p_off = pack_idx * vals_per_pack;
for i in range(0u32, vals_per_pack, 1u32) {
// Value i occupies bits [i * $bits, (i + 1) * $bits) of the word (lowest bits first).
let q = (packed >> (i * $bits)) & mask;
acc = acc
+ (q.cast::<f32>() * scale + bias)
* load(input[p_off + i]).cast::<f32>();
}
}
}
// Presumably combines the per-thread partial sums of all threads working on this row -
// confirm `reduce_sum` semantics. Only thread 0 writes the result.
let total = reduce_sum(acc);
if tid == 0u32 {
store(output[row], total.cast::<T>());
}
}
};
}
// Generates a dequantize-and-GEMV kernel for bit widths that do not divide 32, where each
// row of weights is a contiguous bitstream of `$bits`-wide values and a value may straddle
// two consecutive u32 words.
//
// Macro arguments:
//   $name  - identifier of the generated kernel function
//   $bits  - bits per quantized value, as a u32 literal
//   $subop - string forwarded to `bench_kernel` as `subop`
//
// Generated kernel parameters:
//   weight     - packed quantized weights, `in_dim * $bits / 32` u32 words per row
//   scales     - per-group scale, `in_dim / group_size` entries per row
//   biases     - per-group bias, same indexing as `scales`
//   input      - input vector, read at indices `0..in_dim`
//   output     - one element written per row
//   in_dim     - number of input elements per row (constexpr)
//   group_size - number of consecutive elements sharing one scale/bias (constexpr)
//
// Per row the kernel computes sum over d of (q[d] * scale[g] + bias[g]) * input[d],
// accumulated in f32 and cast to `T` on store.
//
// NOTE(review): `tid` and `lsize` are not declared in this function; presumably they are the
// thread index and thread count injected by `#[kernel]` - confirm against the metaltile docs.
// NOTE(review): `u32_per_row` truncates, so rows presumably need `in_dim * $bits` to be a
// multiple of 32 to start on a word boundary - confirm against the weight packer.
macro_rules! dequant_gemv_odd {
($name:ident, $bits:literal, $subop:literal) => {
#[bench_kernel(op="dequant_gemv", subop=$subop, class=GenericEmpty, tol=0.0, kernel_mode=Reduction,)]
#[kernel]
pub fn $name<T>(
weight: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
input: Tensor<T>,
output: Tensor<T>,
#[constexpr] in_dim: u32,
#[constexpr] group_size: u32,
) {
// `program_id::<0>()` is used as the output row index.
let row = program_id::<0>();
let u32_per_row = in_dim * $bits / 32u32;
let n_groups = in_dim / group_size;
// Offsets of this row's first packed word and first scale/bias entry.
let row_u32_off = row * u32_per_row;
let row_group_off = row * n_groups;
let mut acc = 0.0f32;
// Ceil-divide so that `lsize` threads together visit every element of the row.
let n_iters = (in_dim + lsize - 1u32) / lsize;
for _iter in range(0u32, n_iters, 1u32) {
// This thread handles elements tid, tid + lsize, tid + 2 * lsize, ...
let d = _iter * lsize + tid;
// Guard the tail: the last iteration may run past the end of the row.
if d < in_dim {
let g = d / group_size;
let scale = load(scales[row_group_off + g]).cast::<f32>();
let bias = load(biases[row_group_off + g]).cast::<f32>();
// Locate element d in the row's bitstream: containing word and bit position within it.
let bit_off = d * $bits;
let word_idx = bit_off / 32u32;
let bit_in_w = bit_off & 31u32;
// Bits left in the first word from `bit_in_w` upward.
let bits_in_w0 = 32u32 - bit_in_w;
// lo_bits = min($bits, bits_in_w0), assuming `select(c, a, b)` yields `a` when `c` holds.
let lo_bits = select(bits_in_w0 >= $bits, $bits, bits_in_w0);
// Number of high bits that spill into the next word (0 when the value fits in one word).
let spill = $bits - lo_bits;
let w0 = load(weight[row_u32_off + word_idx]);
// Without spill the second load re-reads `word_idx` rather than touching `word_idx + 1`;
// its contribution is then masked to zero below because (1 << 0) - 1 == 0.
let w1idx = select(spill > 0u32, word_idx + 1u32, word_idx);
let w1 = load(weight[row_u32_off + w1idx]);
// Low part comes from the top of the first word, high part from the bottom of the second.
let lo = (w0 >> bit_in_w) & ((1u32 << lo_bits) - 1u32);
let hi = (w1 & ((1u32 << spill) - 1u32)) << lo_bits;
let q = lo | hi;
acc = acc + (q.cast::<f32>() * scale + bias) * load(input[d]).cast::<f32>();
}
}
// Presumably combines the per-thread partial sums of all threads working on this row -
// confirm `reduce_sum` semantics. Only thread 0 writes the result.
let total = reduce_sum(acc);
if tid == 0u32 {
store(output[row], total.cast::<T>());
}
}
};
}
// Kernel instantiations. Widths 4 and 8 divide 32, so they use the whole-word generator;
// widths 3, 5 and 6 do not, so they use the bitstream generator that handles values
// straddling two u32 words.
dequant_gemv_pow2!(dequant_gemv_int4, 4u32, "int4");
dequant_gemv_pow2!(dequant_gemv_int8, 8u32, "int8");
dequant_gemv_odd!(dequant_gemv_int3, 3u32, "int3");
dequant_gemv_odd!(dequant_gemv_int5, 5u32, "int5");
dequant_gemv_odd!(dequant_gemv_int6, 6u32, "int6");
// Specialised 4-bit dequantize-and-GEMV kernel. Each (`tgid_x`, `simd_id`) pair produces four
// consecutive output rows, and each lane consumes 16 input values and two packed u32 words
// per row in every 512-element step along `in_dim`.
//
// Per row it computes sum over d of (q[d] * scale[g] + bias[g]) * input[d], regrouped per
// 16-value chunk as
//     scale * sum(q * x) + bias * sum(x)
// so scale and bias are applied once per chunk instead of once per element.
//
// Parameters:
//   weight     - packed 4-bit weights, `in_dim / 8` u32 words per row, lowest nibble first
//   scales     - per-group scale, `in_dim / group_size` entries per row
//   biases     - per-group bias, same indexing as `scales`
//   input      - input vector
//   output     - one element written per row
//   in_dim     - number of input elements per row (constexpr)
//   group_size - number of consecutive elements sharing one scale/bias (constexpr)
//
// NOTE(review): `tgid_x`, `simd_id` and `simd_lane` are not declared here; presumably they are
// the threadgroup index, simdgroup index and lane index injected by `#[kernel]` - confirm.
// NOTE(review): there are no bounds checks. Every lane unconditionally loads 16 inputs and two
// weight words per step, so this presumably requires `in_dim` to be a multiple of 512
// (32 lanes x 16 values, if simdgroups are 32 wide), `group_size` to be a multiple of 16, and
// the row count to match the launch grid (8 rows per threadgroup if it holds two
// simdgroups) - confirm against the dispatch code.
#[bench_kernel(
op="dequant_gemv",
subop="int4_fast",
class=GenericEmpty,
tol=0.0,
kernel_mode=Reduction,
)]
#[kernel]
pub fn dequant_gemv_int4_fast<T>(
weight: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
input: Tensor<T>,
output: Tensor<T>,
#[constexpr] in_dim: u32,
#[constexpr] group_size: u32,
) {
let tg = tgid_x;
let sg = simd_id;
let lane = simd_lane;
// First of the four consecutive rows handled by this (tg, sg) pair.
let base_row = tg * 8u32 + sg * 4u32;
let gs_per_row = in_dim / group_size;
// packs_per_row: u32 words per weight row (8 nibbles each).
// lane_x_off: offset of this lane's 16 inputs inside the current 512-element step.
let packs_per_row = in_dim / 8u32; let lane_x_off = lane * 16u32;
// Offset of this lane's two packed words inside the current step (16 nibbles = 2 words).
let lane_pack_off = lane * 2u32;
// Four f32 accumulators named "accs", one per row handled here, zero-initialised.
// Presumably thread-private storage - confirm `stack_alloc` semantics.
stack_alloc("accs", 4, "f32");
for _r in range(0u32, 4u32, 1u32) {
stack_store("accs", _r, 0.0f32);
}
// Reciprocals 1/16, 1/256, 1/4096: they cancel the factor left on a nibble that is masked
// in place at bit offsets 4, 8 and 12 instead of being shifted down.
let s_16 = 0.0625f32;
let s_256 = 0.00390625f32;
let s_4096 = 0.000244140625f32;
for _b in range(0u32, in_dim, 512u32) {
// Load this lane's 16 consecutive input values as f32.
let xb = _b + lane_x_off;
let x0 = load(input[xb]).cast::<f32>();
let x1_raw = load(input[xb + 1u32]).cast::<f32>();
let x2_raw = load(input[xb + 2u32]).cast::<f32>();
let x3_raw = load(input[xb + 3u32]).cast::<f32>();
let x4 = load(input[xb + 4u32]).cast::<f32>();
let x5_raw = load(input[xb + 5u32]).cast::<f32>();
let x6_raw = load(input[xb + 6u32]).cast::<f32>();
let x7_raw = load(input[xb + 7u32]).cast::<f32>();
let x8 = load(input[xb + 8u32]).cast::<f32>();
let x9_raw = load(input[xb + 9u32]).cast::<f32>();
let x10_raw = load(input[xb + 10u32]).cast::<f32>();
let x11_raw = load(input[xb + 11u32]).cast::<f32>();
let x12 = load(input[xb + 12u32]).cast::<f32>();
let x13_raw = load(input[xb + 13u32]).cast::<f32>();
let x14_raw = load(input[xb + 14u32]).cast::<f32>();
let x15_raw = load(input[xb + 15u32]).cast::<f32>();
// Sum of the unscaled inputs; multiplied by the bias once per row below.
let xs = x0
+ x1_raw
+ x2_raw
+ x3_raw
+ x4
+ x5_raw
+ x6_raw
+ x7_raw
+ x8
+ x9_raw
+ x10_raw
+ x11_raw
+ x12
+ x13_raw
+ x14_raw
+ x15_raw;
// Pre-scale inputs whose nibble sits at bit offset 4, 8 or 12 of a 16-bit half-word.
// Inputs 0, 4, 8 and 12 pair with the nibble at offset 0 and need no scaling.
let x1 = x1_raw * s_16;
let x2 = x2_raw * s_256;
let x3 = x3_raw * s_4096;
let x5 = x5_raw * s_16;
let x6 = x6_raw * s_256;
let x7 = x7_raw * s_4096;
let x9 = x9_raw * s_16;
let x10 = x10_raw * s_256;
let x11 = x11_raw * s_4096;
let x13 = x13_raw * s_16;
let x14 = x14_raw * s_256;
let x15 = x15_raw * s_4096;
// One group index is used for all 16 values of this lane.
let g = xb / group_size;
// Index, within a row, of the first of this lane's two packed words for this step.
let pack_off = _b / 8u32 + lane_pack_off;
for _r in range(0u32, 4u32, 1u32) {
let row = base_row + _r;
let w_base = row * packs_per_row;
let sb_base = row * gs_per_row;
// Two words = 16 nibbles; the `>> 16` copies expose the upper half-words so the same
// four masks can be reused on each 16-bit half.
let p_lo = load(weight[w_base + pack_off]);
let p_hi_word = load(weight[w_base + pack_off + 1u32]);
let p_lo_hi = p_lo >> 16u32;
let p_hi_hi = p_hi_word >> 16u32;
let s = load(scales[sb_base + g]).cast::<f32>();
let bi = load(biases[sb_base + g]).cast::<f32>();
// Masks 15, 240, 3840, 61440 are 0xF, 0xF0, 0xF00, 0xF000. A masked nibble equals
// q * 16^k, and the matching input was pre-multiplied by 16^-k, so each term is q * x.
let qd = (p_lo & 15u32).cast::<f32>() * x0
+ (p_lo & 240u32).cast::<f32>() * x1
+ (p_lo & 3840u32).cast::<f32>() * x2
+ (p_lo & 61440u32).cast::<f32>() * x3
+ (p_lo_hi & 15u32).cast::<f32>() * x4
+ (p_lo_hi & 240u32).cast::<f32>() * x5
+ (p_lo_hi & 3840u32).cast::<f32>() * x6
+ (p_lo_hi & 61440u32).cast::<f32>() * x7
+ (p_hi_word & 15u32).cast::<f32>() * x8
+ (p_hi_word & 240u32).cast::<f32>() * x9
+ (p_hi_word & 3840u32).cast::<f32>() * x10
+ (p_hi_word & 61440u32).cast::<f32>() * x11
+ (p_hi_hi & 15u32).cast::<f32>() * x12
+ (p_hi_hi & 240u32).cast::<f32>() * x13
+ (p_hi_hi & 3840u32).cast::<f32>() * x14
+ (p_hi_hi & 61440u32).cast::<f32>() * x15;
// Accumulate scale * sum(q * x) + bias * sum(x) for this row.
let prev = stack_load("accs", _r);
stack_store("accs", _r, prev + s * qd + bi * xs);
}
}
// Combine the per-lane partial sums for each of the four rows; lane 0 writes the result.
// Presumably `simd_sum` adds the value across all lanes of the simdgroup - confirm.
for _r in range(0u32, 4u32, 1u32) {
let v = stack_load("accs", _r);
let r = simd_sum(v);
if lane == 0u32 {
store(output[base_row + _r], r.cast::<T>());
}
}
}
/// Reports whether `kernel_name` is one of the kernels flagged as wanting "indirect" handling.
///
/// The comparison is exact and case-sensitive: only `"dequant_gemv_int4_f16"` and
/// `"dequant_gemv_int4_bf16"` return `true`; every other string returns `false`.
pub fn dequant_gemv_wants_indirect(kernel_name: &str) -> bool {
    // Exhaustive list of kernel names that opt in.
    const INDIRECT_KERNELS: [&str; 2] = ["dequant_gemv_int4_f16", "dequant_gemv_int4_bf16"];
    INDIRECT_KERNELS.contains(&kernel_name)
}