use metaltile::{bench_kernel, kernel};
// Shape table shared by the `#[bench_kernel]` attributes in this file through
// `shapes=&QUANTIZED_SHAPES`.
// NOTE(review): what each pair means (presumably output rows x K) is decided
// by the bench harness, not by this file — confirm before adding entries.
// Every dimension listed is a multiple of 512, the K step used by the
// qmv / qmm / qmm_bm2 loops, which have no tail handling.
static QUANTIZED_SHAPES: &[(usize, usize)] =
&[(4096, 4096), (5120, 5120), (14336, 5120), (5120, 14336), (27648, 5120)];
// mt_qmv: 4-bit affine-quantized matrix x vector product.
//
// For each output row r the kernel accumulates, in f32,
//   out[r] = sum_i (scale[r, i/64] * q[r, i] + bias[r, i/64]) * x[i]
// evaluated per 16-element chunk as `scale * sum(q*x) + bias * sum(x)`.
//
// Arguments:
//   w          - packed weights; each u32 holds eight 4-bit values (lowest
//                nibble first), k/8 words per row.
//   scales     - per-group scales, `gs_per_row` entries per row, indexed by
//                (K index) / 64.
//   biases     - per-group biases, same indexing as `scales`.
//   x          - input vector, read at indices 0..k.
//   out        - output, one element per weight row.
//   k          - reduction length.
//   gs_per_row - row stride of `scales` / `biases`.
//
// NOTE(review): there is no tail or bounds handling — the loop assumes k is a
// multiple of 512 and that row0..row3 are valid rows; confirm against the
// dispatcher.
#[bench_kernel(
op="quantized",
subop="qmv",
class=QuantizedMatVec,
shapes=&QUANTIZED_SHAPES,
group_size=64,
// tpg=64 = 2 simdgroups × 32 lanes. Kernel processes 8 output rows
// per TG (each simdgroup handles 4 rows independently, indexed by
// simd_id). Dispatcher grid is `m/8` TGs — matches MLX qmv_fast.
tpg=64,
tol=1e-3,
mlx="affine_qmv_fast_float16_t_gs_64_b_4_batch_0",
metal_file="quantized.metal",
dtypes=&[metaltile_core::dtype::DType::F32, metaltile_core::dtype::DType::F16, metaltile_core::dtype::DType::BF16],
)]
#[kernel]
pub fn mt_qmv<T>(
w: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] gs_per_row: u32,
) {
let tg = tgid_x;
let sg = simd_id;
let lane = simd_lane;
// Four consecutive output rows per simdgroup, eight per threadgroup.
let row0 = tg * 8u32 + sg * 4u32;
let row1 = row0 + 1u32;
let row2 = row0 + 2u32;
let row3 = row0 + 3u32;
// Eight 4-bit weights per u32 word.
let packs_per_row = k / 8u32;
// Start of each row inside the packed weight buffer.
let w_base0 = row0 * packs_per_row;
let w_base1 = row1 * packs_per_row;
let w_base2 = row2 * packs_per_row;
let w_base3 = row3 * packs_per_row;
// Start of each row inside the scale / bias buffers.
let sb_base0 = row0 * gs_per_row;
let sb_base1 = row1 * gs_per_row;
let sb_base2 = row2 * gs_per_row;
let sb_base3 = row3 * gs_per_row;
// Per-lane partial sums, kept in f32 regardless of T.
let mut acc0 = 0.0f32;
let mut acc1 = 0.0f32;
let mut acc2 = 0.0f32;
let mut acc3 = 0.0f32;
// Each lane owns 16 consecutive K elements per iteration = 2 packed words.
let lane_x_off = lane * 16u32;
let lane_pack_off = lane * 2u32;
// K is walked in steps of 512 (16 elements per lane; 32 lanes per the tpg
// note in the attribute above).
for _b in range(0u32, k, 512u32) {
let xb = _b + lane_x_off;
let xi0 = xb;
let xi1 = xb + 1u32;
let xi2 = xb + 2u32;
let xi3 = xb + 3u32;
let xi4 = xb + 4u32;
let xi5 = xb + 5u32;
let xi6 = xb + 6u32;
let xi7 = xb + 7u32;
let xi8 = xb + 8u32;
let xi9 = xb + 9u32;
let xi10 = xb + 10u32;
let xi11 = xb + 11u32;
let xi12 = xb + 12u32;
let xi13 = xb + 13u32;
let xi14 = xb + 14u32;
let xi15 = xb + 15u32;
// 1/16, 1/256, 1/4096: the nibbles below are masked but not shifted, so a
// masked value equals q*16, q*256 or q*4096. Folding the inverse factor into
// x once lets all four rows reuse it.
let s_16 = 0.0625f32;
let s_256 = 0.00390625f32;
let s_4096 = 0.000244140625f32;
let x0 = load(x[xi0]).cast::<f32>();
let x1_raw = load(x[xi1]).cast::<f32>();
let x2_raw = load(x[xi2]).cast::<f32>();
let x3_raw = load(x[xi3]).cast::<f32>();
let x4 = load(x[xi4]).cast::<f32>();
let x5_raw = load(x[xi5]).cast::<f32>();
let x6_raw = load(x[xi6]).cast::<f32>();
let x7_raw = load(x[xi7]).cast::<f32>();
let x8 = load(x[xi8]).cast::<f32>();
let x9_raw = load(x[xi9]).cast::<f32>();
let x10_raw = load(x[xi10]).cast::<f32>();
let x11_raw = load(x[xi11]).cast::<f32>();
let x12 = load(x[xi12]).cast::<f32>();
let x13_raw = load(x[xi13]).cast::<f32>();
let x14_raw = load(x[xi14]).cast::<f32>();
let x15_raw = load(x[xi15]).cast::<f32>();
// Sum of the unscaled inputs: multiplies the bias term of every row.
let xs = x0
+ x1_raw
+ x2_raw
+ x3_raw
+ x4
+ x5_raw
+ x6_raw
+ x7_raw
+ x8
+ x9_raw
+ x10_raw
+ x11_raw
+ x12
+ x13_raw
+ x14_raw
+ x15_raw;
// Pre-scaled inputs matching nibble positions 1, 2, 3 of each 16-bit half.
let x1 = x1_raw * s_16;
let x2 = x2_raw * s_256;
let x3 = x3_raw * s_4096;
let x5 = x5_raw * s_16;
let x6 = x6_raw * s_256;
let x7 = x7_raw * s_4096;
let x9 = x9_raw * s_16;
let x10 = x10_raw * s_256;
let x11 = x11_raw * s_4096;
let x13 = x13_raw * s_16;
let x14 = x14_raw * s_256;
let x15 = x15_raw * s_4096;
// Quantization group of this lane's chunk (64 K elements per group; a
// 16-element chunk never straddles a group boundary).
let g = xb / 64u32;
// Word offset of this lane's two packed words within a weight row.
let pack_off = _b / 8u32 + lane_pack_off;
// Row 0: load two packed words, split each into 16-bit halves, mask out the
// eight nibbles of each word in place.
let p00 = load(w[w_base0 + pack_off]);
let p01 = load(w[w_base0 + pack_off + 1u32]);
let p00_hi = p00 >> 16u32;
let p01_hi = p01 >> 16u32;
let s0 = load(scales[sb_base0 + g]).cast::<f32>();
let bi0 = load(biases[sb_base0 + g]).cast::<f32>();
let q00 = (p00 & 15u32).cast::<f32>();
let q01 = (p00 & 240u32).cast::<f32>();
let q02 = (p00 & 3840u32).cast::<f32>();
let q03 = (p00 & 61440u32).cast::<f32>();
let q04 = (p00_hi & 15u32).cast::<f32>();
let q05 = (p00_hi & 240u32).cast::<f32>();
let q06 = (p00_hi & 3840u32).cast::<f32>();
let q07 = (p00_hi & 61440u32).cast::<f32>();
let q08 = (p01 & 15u32).cast::<f32>();
let q09 = (p01 & 240u32).cast::<f32>();
let q010 = (p01 & 3840u32).cast::<f32>();
let q011 = (p01 & 61440u32).cast::<f32>();
let q012 = (p01_hi & 15u32).cast::<f32>();
let q013 = (p01_hi & 240u32).cast::<f32>();
let q014 = (p01_hi & 3840u32).cast::<f32>();
let q015 = (p01_hi & 61440u32).cast::<f32>();
let qd0 = q00 * x0
+ q01 * x1
+ q02 * x2
+ q03 * x3
+ q04 * x4
+ q05 * x5
+ q06 * x6
+ q07 * x7
+ q08 * x8
+ q09 * x9
+ q010 * x10
+ q011 * x11
+ q012 * x12
+ q013 * x13
+ q014 * x14
+ q015 * x15;
// scale * sum(q*x) + bias * sum(x) == sum((scale*q + bias) * x).
acc0 = acc0 + s0 * qd0 + bi0 * xs;
// Row 1: same unpack-and-accumulate as row 0.
let p10 = load(w[w_base1 + pack_off]);
let p11 = load(w[w_base1 + pack_off + 1u32]);
let p10_hi = p10 >> 16u32;
let p11_hi = p11 >> 16u32;
let s1 = load(scales[sb_base1 + g]).cast::<f32>();
let bi1 = load(biases[sb_base1 + g]).cast::<f32>();
let q10 = (p10 & 15u32).cast::<f32>();
let q11 = (p10 & 240u32).cast::<f32>();
let q12 = (p10 & 3840u32).cast::<f32>();
let q13 = (p10 & 61440u32).cast::<f32>();
let q14 = (p10_hi & 15u32).cast::<f32>();
let q15 = (p10_hi & 240u32).cast::<f32>();
let q16 = (p10_hi & 3840u32).cast::<f32>();
let q17 = (p10_hi & 61440u32).cast::<f32>();
let q18 = (p11 & 15u32).cast::<f32>();
let q19 = (p11 & 240u32).cast::<f32>();
let q110 = (p11 & 3840u32).cast::<f32>();
let q111 = (p11 & 61440u32).cast::<f32>();
let q112 = (p11_hi & 15u32).cast::<f32>();
let q113 = (p11_hi & 240u32).cast::<f32>();
let q114 = (p11_hi & 3840u32).cast::<f32>();
let q115 = (p11_hi & 61440u32).cast::<f32>();
let qd1 = q10 * x0
+ q11 * x1
+ q12 * x2
+ q13 * x3
+ q14 * x4
+ q15 * x5
+ q16 * x6
+ q17 * x7
+ q18 * x8
+ q19 * x9
+ q110 * x10
+ q111 * x11
+ q112 * x12
+ q113 * x13
+ q114 * x14
+ q115 * x15;
acc1 = acc1 + s1 * qd1 + bi1 * xs;
// Row 2.
let p20 = load(w[w_base2 + pack_off]);
let p21 = load(w[w_base2 + pack_off + 1u32]);
let p20_hi = p20 >> 16u32;
let p21_hi = p21 >> 16u32;
let s2 = load(scales[sb_base2 + g]).cast::<f32>();
let bi2 = load(biases[sb_base2 + g]).cast::<f32>();
let q20 = (p20 & 15u32).cast::<f32>();
let q21 = (p20 & 240u32).cast::<f32>();
let q22 = (p20 & 3840u32).cast::<f32>();
let q23 = (p20 & 61440u32).cast::<f32>();
let q24 = (p20_hi & 15u32).cast::<f32>();
let q25 = (p20_hi & 240u32).cast::<f32>();
let q26 = (p20_hi & 3840u32).cast::<f32>();
let q27 = (p20_hi & 61440u32).cast::<f32>();
let q28 = (p21 & 15u32).cast::<f32>();
let q29 = (p21 & 240u32).cast::<f32>();
let q210 = (p21 & 3840u32).cast::<f32>();
let q211 = (p21 & 61440u32).cast::<f32>();
let q212 = (p21_hi & 15u32).cast::<f32>();
let q213 = (p21_hi & 240u32).cast::<f32>();
let q214 = (p21_hi & 3840u32).cast::<f32>();
let q215 = (p21_hi & 61440u32).cast::<f32>();
let qd2 = q20 * x0
+ q21 * x1
+ q22 * x2
+ q23 * x3
+ q24 * x4
+ q25 * x5
+ q26 * x6
+ q27 * x7
+ q28 * x8
+ q29 * x9
+ q210 * x10
+ q211 * x11
+ q212 * x12
+ q213 * x13
+ q214 * x14
+ q215 * x15;
acc2 = acc2 + s2 * qd2 + bi2 * xs;
// Row 3.
let p30 = load(w[w_base3 + pack_off]);
let p31 = load(w[w_base3 + pack_off + 1u32]);
let p30_hi = p30 >> 16u32;
let p31_hi = p31 >> 16u32;
let s3 = load(scales[sb_base3 + g]).cast::<f32>();
let bi3 = load(biases[sb_base3 + g]).cast::<f32>();
let q30 = (p30 & 15u32).cast::<f32>();
let q31 = (p30 & 240u32).cast::<f32>();
let q32 = (p30 & 3840u32).cast::<f32>();
let q33 = (p30 & 61440u32).cast::<f32>();
let q34 = (p30_hi & 15u32).cast::<f32>();
let q35 = (p30_hi & 240u32).cast::<f32>();
let q36 = (p30_hi & 3840u32).cast::<f32>();
let q37 = (p30_hi & 61440u32).cast::<f32>();
let q38 = (p31 & 15u32).cast::<f32>();
let q39 = (p31 & 240u32).cast::<f32>();
let q310 = (p31 & 3840u32).cast::<f32>();
let q311 = (p31 & 61440u32).cast::<f32>();
let q312 = (p31_hi & 15u32).cast::<f32>();
let q313 = (p31_hi & 240u32).cast::<f32>();
let q314 = (p31_hi & 3840u32).cast::<f32>();
let q315 = (p31_hi & 61440u32).cast::<f32>();
let qd3 = q30 * x0
+ q31 * x1
+ q32 * x2
+ q33 * x3
+ q34 * x4
+ q35 * x5
+ q36 * x6
+ q37 * x7
+ q38 * x8
+ q39 * x9
+ q310 * x10
+ q311 * x11
+ q312 * x12
+ q313 * x13
+ q314 * x14
+ q315 * x15;
acc3 = acc3 + s3 * qd3 + bi3 * xs;
}
// Combine the per-lane partials (simd_sum is presumably the simdgroup-wide
// add reduction — confirm in the DSL docs).
let r0 = simd_sum(acc0);
let r1 = simd_sum(acc1);
let r2 = simd_sum(acc2);
let r3 = simd_sum(acc3);
// Only lane 0 writes; results are cast from f32 back to T.
if lane == 0u32 {
store(out[row0], r0.cast::<T>());
store(out[row1], r1.cast::<T>());
store(out[row2], r2.cast::<T>());
store(out[row3], r3.cast::<T>());
}
}
// mt_qmm: 4-bit affine-quantized matmul, one row of `x` per threadgroup in y.
//
// For x row m (= tgid_y) and weight row r the kernel accumulates, in f32,
//   out[m*n + r] = sum_i (scale[r, i/64] * q[r, i] + bias[r, i/64]) * x[m*k + i]
// evaluated per 16-element chunk as `scale * sum(q*x) + bias * sum(x)`.
//
// Arguments:
//   w          - packed weights; each u32 holds eight 4-bit values (lowest
//                nibble first), k/8 words per row.
//   scales     - per-group scales, `gs_per_row` entries per row, indexed by
//                (K index) / 64.
//   biases     - per-group biases, same indexing as `scales`.
//   x          - input rows, row stride k.
//   out        - output rows, row stride n.
//   k          - reduction length.
//   n          - row stride of `out`.
//   gs_per_row - row stride of `scales` / `biases`.
//
// NOTE(review): there is no tail or bounds handling — the loop assumes k is a
// multiple of 512 and that row0..row3 / m_row are in range; confirm against
// the dispatcher.
#[bench_kernel(
op="quantized",
subop="qmm",
class=QuantizedMatMul,
shapes=&QUANTIZED_SHAPES,
// M=4 = canonical small-batch prefill token count (covers
// single-prompt prefill chunks + small batched serving). Larger
// M values exposed via the #[ignore] `mt_qmm_perf_bench_*` test.
m=4,
group_size=64,
tpg=64,
tol=1e-2,
mlx="affine_qmm_t_{tn}_gs_64_b_4_alN_true_batch_0",
metal_file="quantized.metal",
dtypes=&[metaltile_core::dtype::DType::F32, metaltile_core::dtype::DType::F16, metaltile_core::dtype::DType::BF16],
)]
#[kernel]
pub fn mt_qmm<T>(
w: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] n: u32,
#[constexpr] gs_per_row: u32,
) {
let tg = tgid_x;
// The y grid coordinate selects which row of x this threadgroup multiplies.
let m_row = tgid_y;
let sg = simd_id;
let lane = simd_lane;
// Four consecutive weight rows per simdgroup, eight per threadgroup.
let row0 = tg * 8u32 + sg * 4u32;
let row1 = row0 + 1u32;
let row2 = row0 + 2u32;
let row3 = row0 + 3u32;
// Eight 4-bit weights per u32 word.
let packs_per_row = k / 8u32;
// Start of each row inside the packed weight buffer.
let w_base0 = row0 * packs_per_row;
let w_base1 = row1 * packs_per_row;
let w_base2 = row2 * packs_per_row;
let w_base3 = row3 * packs_per_row;
// Start of each row inside the scale / bias buffers.
let sb_base0 = row0 * gs_per_row;
let sb_base1 = row1 * gs_per_row;
let sb_base2 = row2 * gs_per_row;
let sb_base3 = row3 * gs_per_row;
// Start of the selected x row.
let x_row_base = m_row * k;
// Per-lane partial sums, kept in f32 regardless of T.
let mut acc0 = 0.0f32;
let mut acc1 = 0.0f32;
let mut acc2 = 0.0f32;
let mut acc3 = 0.0f32;
// Each lane owns 16 consecutive K elements per iteration = 2 packed words.
let lane_x_off = lane * 16u32;
let lane_pack_off = lane * 2u32;
// K is walked in steps of 512 (16 elements per lane).
for _b in range(0u32, k, 512u32) {
let xb = x_row_base + _b + lane_x_off;
let xi0 = xb;
let xi1 = xb + 1u32;
let xi2 = xb + 2u32;
let xi3 = xb + 3u32;
let xi4 = xb + 4u32;
let xi5 = xb + 5u32;
let xi6 = xb + 6u32;
let xi7 = xb + 7u32;
let xi8 = xb + 8u32;
let xi9 = xb + 9u32;
let xi10 = xb + 10u32;
let xi11 = xb + 11u32;
let xi12 = xb + 12u32;
let xi13 = xb + 13u32;
let xi14 = xb + 14u32;
let xi15 = xb + 15u32;
// 1/16, 1/256, 1/4096: the nibbles below are masked but not shifted, so a
// masked value equals q*16, q*256 or q*4096. Folding the inverse factor into
// x once lets all four rows reuse it.
let s_16 = 0.0625f32;
let s_256 = 0.00390625f32;
let s_4096 = 0.000244140625f32;
let x0 = load(x[xi0]).cast::<f32>();
let x1_raw = load(x[xi1]).cast::<f32>();
let x2_raw = load(x[xi2]).cast::<f32>();
let x3_raw = load(x[xi3]).cast::<f32>();
let x4 = load(x[xi4]).cast::<f32>();
let x5_raw = load(x[xi5]).cast::<f32>();
let x6_raw = load(x[xi6]).cast::<f32>();
let x7_raw = load(x[xi7]).cast::<f32>();
let x8 = load(x[xi8]).cast::<f32>();
let x9_raw = load(x[xi9]).cast::<f32>();
let x10_raw = load(x[xi10]).cast::<f32>();
let x11_raw = load(x[xi11]).cast::<f32>();
let x12 = load(x[xi12]).cast::<f32>();
let x13_raw = load(x[xi13]).cast::<f32>();
let x14_raw = load(x[xi14]).cast::<f32>();
let x15_raw = load(x[xi15]).cast::<f32>();
// Sum of the unscaled inputs: multiplies the bias term of every row.
let xs = x0
+ x1_raw
+ x2_raw
+ x3_raw
+ x4
+ x5_raw
+ x6_raw
+ x7_raw
+ x8
+ x9_raw
+ x10_raw
+ x11_raw
+ x12
+ x13_raw
+ x14_raw
+ x15_raw;
// Pre-scaled inputs matching nibble positions 1, 2, 3 of each 16-bit half.
let x1 = x1_raw * s_16;
let x2 = x2_raw * s_256;
let x3 = x3_raw * s_4096;
let x5 = x5_raw * s_16;
let x6 = x6_raw * s_256;
let x7 = x7_raw * s_4096;
let x9 = x9_raw * s_16;
let x10 = x10_raw * s_256;
let x11 = x11_raw * s_4096;
let x13 = x13_raw * s_16;
let x14 = x14_raw * s_256;
let x15 = x15_raw * s_4096;
// Quantization group of this lane's chunk; computed from the K offset only
// (not from xb, which also includes the x row base).
let g = (_b + lane_x_off) / 64u32;
// Word offset of this lane's two packed words within a weight row.
let pack_off = _b / 8u32 + lane_pack_off;
// Row 0: load two packed words, split each into 16-bit halves, mask out the
// eight nibbles of each word in place.
let p00 = load(w[w_base0 + pack_off]);
let p01 = load(w[w_base0 + pack_off + 1u32]);
let p00_hi = p00 >> 16u32;
let p01_hi = p01 >> 16u32;
let s0 = load(scales[sb_base0 + g]).cast::<f32>();
let bi0 = load(biases[sb_base0 + g]).cast::<f32>();
let q00 = (p00 & 15u32).cast::<f32>();
let q01 = (p00 & 240u32).cast::<f32>();
let q02 = (p00 & 3840u32).cast::<f32>();
let q03 = (p00 & 61440u32).cast::<f32>();
let q04 = (p00_hi & 15u32).cast::<f32>();
let q05 = (p00_hi & 240u32).cast::<f32>();
let q06 = (p00_hi & 3840u32).cast::<f32>();
let q07 = (p00_hi & 61440u32).cast::<f32>();
let q08 = (p01 & 15u32).cast::<f32>();
let q09 = (p01 & 240u32).cast::<f32>();
let q010 = (p01 & 3840u32).cast::<f32>();
let q011 = (p01 & 61440u32).cast::<f32>();
let q012 = (p01_hi & 15u32).cast::<f32>();
let q013 = (p01_hi & 240u32).cast::<f32>();
let q014 = (p01_hi & 3840u32).cast::<f32>();
let q015 = (p01_hi & 61440u32).cast::<f32>();
let qd0 = q00 * x0
+ q01 * x1
+ q02 * x2
+ q03 * x3
+ q04 * x4
+ q05 * x5
+ q06 * x6
+ q07 * x7
+ q08 * x8
+ q09 * x9
+ q010 * x10
+ q011 * x11
+ q012 * x12
+ q013 * x13
+ q014 * x14
+ q015 * x15;
// scale * sum(q*x) + bias * sum(x) == sum((scale*q + bias) * x).
acc0 = acc0 + s0 * qd0 + bi0 * xs;
// Row 1: same unpack-and-accumulate as row 0.
let p10 = load(w[w_base1 + pack_off]);
let p11 = load(w[w_base1 + pack_off + 1u32]);
let p10_hi = p10 >> 16u32;
let p11_hi = p11 >> 16u32;
let s1 = load(scales[sb_base1 + g]).cast::<f32>();
let bi1 = load(biases[sb_base1 + g]).cast::<f32>();
let q10 = (p10 & 15u32).cast::<f32>();
let q11 = (p10 & 240u32).cast::<f32>();
let q12 = (p10 & 3840u32).cast::<f32>();
let q13 = (p10 & 61440u32).cast::<f32>();
let q14 = (p10_hi & 15u32).cast::<f32>();
let q15 = (p10_hi & 240u32).cast::<f32>();
let q16 = (p10_hi & 3840u32).cast::<f32>();
let q17 = (p10_hi & 61440u32).cast::<f32>();
let q18 = (p11 & 15u32).cast::<f32>();
let q19 = (p11 & 240u32).cast::<f32>();
let q110 = (p11 & 3840u32).cast::<f32>();
let q111 = (p11 & 61440u32).cast::<f32>();
let q112 = (p11_hi & 15u32).cast::<f32>();
let q113 = (p11_hi & 240u32).cast::<f32>();
let q114 = (p11_hi & 3840u32).cast::<f32>();
let q115 = (p11_hi & 61440u32).cast::<f32>();
let qd1 = q10 * x0
+ q11 * x1
+ q12 * x2
+ q13 * x3
+ q14 * x4
+ q15 * x5
+ q16 * x6
+ q17 * x7
+ q18 * x8
+ q19 * x9
+ q110 * x10
+ q111 * x11
+ q112 * x12
+ q113 * x13
+ q114 * x14
+ q115 * x15;
acc1 = acc1 + s1 * qd1 + bi1 * xs;
// Row 2.
let p20 = load(w[w_base2 + pack_off]);
let p21 = load(w[w_base2 + pack_off + 1u32]);
let p20_hi = p20 >> 16u32;
let p21_hi = p21 >> 16u32;
let s2 = load(scales[sb_base2 + g]).cast::<f32>();
let bi2 = load(biases[sb_base2 + g]).cast::<f32>();
let q20 = (p20 & 15u32).cast::<f32>();
let q21 = (p20 & 240u32).cast::<f32>();
let q22 = (p20 & 3840u32).cast::<f32>();
let q23 = (p20 & 61440u32).cast::<f32>();
let q24 = (p20_hi & 15u32).cast::<f32>();
let q25 = (p20_hi & 240u32).cast::<f32>();
let q26 = (p20_hi & 3840u32).cast::<f32>();
let q27 = (p20_hi & 61440u32).cast::<f32>();
let q28 = (p21 & 15u32).cast::<f32>();
let q29 = (p21 & 240u32).cast::<f32>();
let q210 = (p21 & 3840u32).cast::<f32>();
let q211 = (p21 & 61440u32).cast::<f32>();
let q212 = (p21_hi & 15u32).cast::<f32>();
let q213 = (p21_hi & 240u32).cast::<f32>();
let q214 = (p21_hi & 3840u32).cast::<f32>();
let q215 = (p21_hi & 61440u32).cast::<f32>();
let qd2 = q20 * x0
+ q21 * x1
+ q22 * x2
+ q23 * x3
+ q24 * x4
+ q25 * x5
+ q26 * x6
+ q27 * x7
+ q28 * x8
+ q29 * x9
+ q210 * x10
+ q211 * x11
+ q212 * x12
+ q213 * x13
+ q214 * x14
+ q215 * x15;
acc2 = acc2 + s2 * qd2 + bi2 * xs;
// Row 3.
let p30 = load(w[w_base3 + pack_off]);
let p31 = load(w[w_base3 + pack_off + 1u32]);
let p30_hi = p30 >> 16u32;
let p31_hi = p31 >> 16u32;
let s3 = load(scales[sb_base3 + g]).cast::<f32>();
let bi3 = load(biases[sb_base3 + g]).cast::<f32>();
let q30 = (p30 & 15u32).cast::<f32>();
let q31 = (p30 & 240u32).cast::<f32>();
let q32 = (p30 & 3840u32).cast::<f32>();
let q33 = (p30 & 61440u32).cast::<f32>();
let q34 = (p30_hi & 15u32).cast::<f32>();
let q35 = (p30_hi & 240u32).cast::<f32>();
let q36 = (p30_hi & 3840u32).cast::<f32>();
let q37 = (p30_hi & 61440u32).cast::<f32>();
let q38 = (p31 & 15u32).cast::<f32>();
let q39 = (p31 & 240u32).cast::<f32>();
let q310 = (p31 & 3840u32).cast::<f32>();
let q311 = (p31 & 61440u32).cast::<f32>();
let q312 = (p31_hi & 15u32).cast::<f32>();
let q313 = (p31_hi & 240u32).cast::<f32>();
let q314 = (p31_hi & 3840u32).cast::<f32>();
let q315 = (p31_hi & 61440u32).cast::<f32>();
let qd3 = q30 * x0
+ q31 * x1
+ q32 * x2
+ q33 * x3
+ q34 * x4
+ q35 * x5
+ q36 * x6
+ q37 * x7
+ q38 * x8
+ q39 * x9
+ q310 * x10
+ q311 * x11
+ q312 * x12
+ q313 * x13
+ q314 * x14
+ q315 * x15;
acc3 = acc3 + s3 * qd3 + bi3 * xs;
}
// Combine the per-lane partials (simd_sum is presumably the simdgroup-wide
// add reduction — confirm in the DSL docs).
let r0 = simd_sum(acc0);
let r1 = simd_sum(acc1);
let r2 = simd_sum(acc2);
let r3 = simd_sum(acc3);
// Only lane 0 writes; the output is indexed as out[m_row * n + weight row].
if lane == 0u32 {
store(out[m_row * n + row0], r0.cast::<T>());
store(out[m_row * n + row1], r1.cast::<T>());
store(out[m_row * n + row2], r2.cast::<T>());
store(out[m_row * n + row3], r3.cast::<T>());
}
}
// mt_qmm_bm2: 4-bit affine-quantized matmul that processes two rows of `x`
// (suffixes `_a` and `_b`) per threadgroup in y, so every packed weight word,
// scale and bias is loaded and unpacked once and used for both rows.
//
// For x row m and weight row r the kernel accumulates, in f32,
//   out[m*n + r] = sum_i (scale[r, i/64] * q[r, i] + bias[r, i/64]) * x[m*k + i]
// evaluated per 16-element chunk as `scale * sum(q*x) + bias * sum(x)`.
//
// Arguments:
//   w          - packed weights; each u32 holds eight 4-bit values (lowest
//                nibble first), k/8 words per row.
//   scales     - per-group scales, `gs_per_row` entries per row, indexed by
//                (K index) / 64.
//   biases     - per-group biases, same indexing as `scales`.
//   x          - input rows, row stride k.
//   out        - output rows, row stride n.
//   k          - reduction length.
//   n          - row stride of `out`.
//   gs_per_row - row stride of `scales` / `biases`.
//
// NOTE(review): row `_b` (= 2*tgid_y + 1) is read and written unconditionally,
// so the x row count must be even; there is also no K tail handling (k is
// assumed to be a multiple of 512). Confirm both against the dispatcher.
#[bench_kernel(
op="quantized",
subop="qmm_bm2",
class=QuantizedMatMul,
shapes=&QUANTIZED_SHAPES,
// M=8 = larger-batch prefill where W-reuse matters most. M=2 / 4
// also benefit (W reload halved); M=1 should keep dispatching
// mt_qmm (v2) since the BM=2 tile would burn TG slots on unused
// outputs.
// M=8 is a representative mid-M cell. Clean median-of-5 head-to-head
// bm2/v2 (25 cells per M, both rigs): bm2 wins 350/350 across
// M ∈ {2,4,6,8,12,16,32}. Speedups grow with M: 1.09× at M=2
// → 1.24× M5 / 1.30× M2 at M=32. vs MLX `affine_qmm_t`, the M=8
// bench cell measures 1.7-2.5× M5 / 1.4-1.7× f16 M2 (3-run M5
// drift ≤3pt). Selector `mt_qmm_for` routes every even M ≥ 2
// to bm2. Neither kernel beats MLX at M ≥ 16 (MLX's BM=BN=32
// simdgroup-matrix tile dominates large-M); closing that gap is
// the BM=4/BM=8 follow-up.
m=8,
group_size=64,
tpg=64,
// bf16 round-trip on int4-quantized matmul: max_q=15 × group_size=64
// × bf16's 7-bit mantissa drifts ~7-8e-3 at large K (per
// crates/metaltile-std/src/mlx/binary.rs precedent — "bf16 drifts
// ~7.8e-3 on signed"). Tighter than 1e-2 trips the bench cosine
// check at production shapes (M=4096+, K=4096+) on Apple Paravirtual
// CI. tol=1e-2 keeps f32/f16 cells tight while passing bf16.
tol=1e-2,
mlx="affine_qmm_t_{tn}_gs_64_b_4_alN_true_batch_0",
metal_file="quantized.metal",
dtypes=&[metaltile_core::dtype::DType::F32, metaltile_core::dtype::DType::F16, metaltile_core::dtype::DType::BF16],
)]
#[kernel]
pub fn mt_qmm_bm2<T>(
w: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] n: u32,
#[constexpr] gs_per_row: u32,
) {
let tg = tgid_x;
// The y grid coordinate selects a tile of two consecutive x rows.
let m_tile = tgid_y;
let sg = simd_id;
let lane = simd_lane;
// Four consecutive weight rows per simdgroup, eight per threadgroup.
let row0 = tg * 8u32 + sg * 4u32;
let row1 = row0 + 1u32;
let row2 = row0 + 2u32;
let row3 = row0 + 3u32;
// Eight 4-bit weights per u32 word.
let packs_per_row = k / 8u32;
// Start of each row inside the packed weight buffer.
let w_base0 = row0 * packs_per_row;
let w_base1 = row1 * packs_per_row;
let w_base2 = row2 * packs_per_row;
let w_base3 = row3 * packs_per_row;
// Start of each row inside the scale / bias buffers.
let sb_base0 = row0 * gs_per_row;
let sb_base1 = row1 * gs_per_row;
let sb_base2 = row2 * gs_per_row;
let sb_base3 = row3 * gs_per_row;
// The two x rows of this tile and where each starts in x.
let m_row_a = m_tile * 2u32;
let m_row_b = m_row_a + 1u32;
let x_base_a = m_row_a * k;
let x_base_b = m_row_b * k;
// One f32 partial sum per (weight row, x row) pair.
let mut acc0_a = 0.0f32;
let mut acc0_b = 0.0f32;
let mut acc1_a = 0.0f32;
let mut acc1_b = 0.0f32;
let mut acc2_a = 0.0f32;
let mut acc2_b = 0.0f32;
let mut acc3_a = 0.0f32;
let mut acc3_b = 0.0f32;
// Each lane owns 16 consecutive K elements per iteration = 2 packed words.
let lane_x_off = lane * 16u32;
let lane_pack_off = lane * 2u32;
// K is walked in steps of 512 (16 elements per lane).
for _b in range(0u32, k, 512u32) {
let xb_a = x_base_a + _b + lane_x_off;
// 1/16, 1/256, 1/4096: the nibbles below are masked but not shifted, so a
// masked value equals q*16, q*256 or q*4096. Folding the inverse factor into
// x once lets every weight row reuse it.
let s_16 = 0.0625f32;
let s_256 = 0.00390625f32;
let s_4096 = 0.000244140625f32;
// x row A: 16 inputs for this lane.
let x0_a = load(x[xb_a]).cast::<f32>();
let x1_a_raw = load(x[xb_a + 1u32]).cast::<f32>();
let x2_a_raw = load(x[xb_a + 2u32]).cast::<f32>();
let x3_a_raw = load(x[xb_a + 3u32]).cast::<f32>();
let x4_a = load(x[xb_a + 4u32]).cast::<f32>();
let x5_a_raw = load(x[xb_a + 5u32]).cast::<f32>();
let x6_a_raw = load(x[xb_a + 6u32]).cast::<f32>();
let x7_a_raw = load(x[xb_a + 7u32]).cast::<f32>();
let x8_a = load(x[xb_a + 8u32]).cast::<f32>();
let x9_a_raw = load(x[xb_a + 9u32]).cast::<f32>();
let x10_a_raw = load(x[xb_a + 10u32]).cast::<f32>();
let x11_a_raw = load(x[xb_a + 11u32]).cast::<f32>();
let x12_a = load(x[xb_a + 12u32]).cast::<f32>();
let x13_a_raw = load(x[xb_a + 13u32]).cast::<f32>();
let x14_a_raw = load(x[xb_a + 14u32]).cast::<f32>();
let x15_a_raw = load(x[xb_a + 15u32]).cast::<f32>();
// Sum of the unscaled row-A inputs: multiplies the bias term.
let xs_a = x0_a
+ x1_a_raw
+ x2_a_raw
+ x3_a_raw
+ x4_a
+ x5_a_raw
+ x6_a_raw
+ x7_a_raw
+ x8_a
+ x9_a_raw
+ x10_a_raw
+ x11_a_raw
+ x12_a
+ x13_a_raw
+ x14_a_raw
+ x15_a_raw;
let x1_a = x1_a_raw * s_16;
let x2_a = x2_a_raw * s_256;
let x3_a = x3_a_raw * s_4096;
let x5_a = x5_a_raw * s_16;
let x6_a = x6_a_raw * s_256;
let x7_a = x7_a_raw * s_4096;
let x9_a = x9_a_raw * s_16;
let x10_a = x10_a_raw * s_256;
let x11_a = x11_a_raw * s_4096;
let x13_a = x13_a_raw * s_16;
let x14_a = x14_a_raw * s_256;
let x15_a = x15_a_raw * s_4096;
// x row B: same loads and pre-scaling as row A.
let xb_b = x_base_b + _b + lane_x_off;
let x0_b = load(x[xb_b]).cast::<f32>();
let x1_b_raw = load(x[xb_b + 1u32]).cast::<f32>();
let x2_b_raw = load(x[xb_b + 2u32]).cast::<f32>();
let x3_b_raw = load(x[xb_b + 3u32]).cast::<f32>();
let x4_b = load(x[xb_b + 4u32]).cast::<f32>();
let x5_b_raw = load(x[xb_b + 5u32]).cast::<f32>();
let x6_b_raw = load(x[xb_b + 6u32]).cast::<f32>();
let x7_b_raw = load(x[xb_b + 7u32]).cast::<f32>();
let x8_b = load(x[xb_b + 8u32]).cast::<f32>();
let x9_b_raw = load(x[xb_b + 9u32]).cast::<f32>();
let x10_b_raw = load(x[xb_b + 10u32]).cast::<f32>();
let x11_b_raw = load(x[xb_b + 11u32]).cast::<f32>();
let x12_b = load(x[xb_b + 12u32]).cast::<f32>();
let x13_b_raw = load(x[xb_b + 13u32]).cast::<f32>();
let x14_b_raw = load(x[xb_b + 14u32]).cast::<f32>();
let x15_b_raw = load(x[xb_b + 15u32]).cast::<f32>();
let xs_b = x0_b
+ x1_b_raw
+ x2_b_raw
+ x3_b_raw
+ x4_b
+ x5_b_raw
+ x6_b_raw
+ x7_b_raw
+ x8_b
+ x9_b_raw
+ x10_b_raw
+ x11_b_raw
+ x12_b
+ x13_b_raw
+ x14_b_raw
+ x15_b_raw;
let x1_b = x1_b_raw * s_16;
let x2_b = x2_b_raw * s_256;
let x3_b = x3_b_raw * s_4096;
let x5_b = x5_b_raw * s_16;
let x6_b = x6_b_raw * s_256;
let x7_b = x7_b_raw * s_4096;
let x9_b = x9_b_raw * s_16;
let x10_b = x10_b_raw * s_256;
let x11_b = x11_b_raw * s_4096;
let x13_b = x13_b_raw * s_16;
let x14_b = x14_b_raw * s_256;
let x15_b = x15_b_raw * s_4096;
// Quantization group of this lane's chunk; depends on the K offset only, so
// it is shared by both x rows.
let g = (_b + lane_x_off) / 64u32;
// Word offset of this lane's two packed words within a weight row.
let pack_off = _b / 8u32 + lane_pack_off;
// Weight row 0: unpack once, then form one dot product per x row.
let p00 = load(w[w_base0 + pack_off]);
let p01 = load(w[w_base0 + pack_off + 1u32]);
let p00_hi = p00 >> 16u32;
let p01_hi = p01 >> 16u32;
let s0 = load(scales[sb_base0 + g]).cast::<f32>();
let bi0 = load(biases[sb_base0 + g]).cast::<f32>();
let q00 = (p00 & 15u32).cast::<f32>();
let q01 = (p00 & 240u32).cast::<f32>();
let q02 = (p00 & 3840u32).cast::<f32>();
let q03 = (p00 & 61440u32).cast::<f32>();
let q04 = (p00_hi & 15u32).cast::<f32>();
let q05 = (p00_hi & 240u32).cast::<f32>();
let q06 = (p00_hi & 3840u32).cast::<f32>();
let q07 = (p00_hi & 61440u32).cast::<f32>();
let q08 = (p01 & 15u32).cast::<f32>();
let q09 = (p01 & 240u32).cast::<f32>();
let q010 = (p01 & 3840u32).cast::<f32>();
let q011 = (p01 & 61440u32).cast::<f32>();
let q012 = (p01_hi & 15u32).cast::<f32>();
let q013 = (p01_hi & 240u32).cast::<f32>();
let q014 = (p01_hi & 3840u32).cast::<f32>();
let q015 = (p01_hi & 61440u32).cast::<f32>();
let qd0_a = q00 * x0_a
+ q01 * x1_a
+ q02 * x2_a
+ q03 * x3_a
+ q04 * x4_a
+ q05 * x5_a
+ q06 * x6_a
+ q07 * x7_a
+ q08 * x8_a
+ q09 * x9_a
+ q010 * x10_a
+ q011 * x11_a
+ q012 * x12_a
+ q013 * x13_a
+ q014 * x14_a
+ q015 * x15_a;
let qd0_b = q00 * x0_b
+ q01 * x1_b
+ q02 * x2_b
+ q03 * x3_b
+ q04 * x4_b
+ q05 * x5_b
+ q06 * x6_b
+ q07 * x7_b
+ q08 * x8_b
+ q09 * x9_b
+ q010 * x10_b
+ q011 * x11_b
+ q012 * x12_b
+ q013 * x13_b
+ q014 * x14_b
+ q015 * x15_b;
// scale * sum(q*x) + bias * sum(x) == sum((scale*q + bias) * x).
acc0_a = acc0_a + s0 * qd0_a + bi0 * xs_a;
acc0_b = acc0_b + s0 * qd0_b + bi0 * xs_b;
// Weight row 1.
let p10 = load(w[w_base1 + pack_off]);
let p11 = load(w[w_base1 + pack_off + 1u32]);
let p10_hi = p10 >> 16u32;
let p11_hi = p11 >> 16u32;
let s1 = load(scales[sb_base1 + g]).cast::<f32>();
let bi1 = load(biases[sb_base1 + g]).cast::<f32>();
let q10 = (p10 & 15u32).cast::<f32>();
let q11 = (p10 & 240u32).cast::<f32>();
let q12 = (p10 & 3840u32).cast::<f32>();
let q13 = (p10 & 61440u32).cast::<f32>();
let q14 = (p10_hi & 15u32).cast::<f32>();
let q15 = (p10_hi & 240u32).cast::<f32>();
let q16 = (p10_hi & 3840u32).cast::<f32>();
let q17 = (p10_hi & 61440u32).cast::<f32>();
let q18 = (p11 & 15u32).cast::<f32>();
let q19 = (p11 & 240u32).cast::<f32>();
let q110 = (p11 & 3840u32).cast::<f32>();
let q111 = (p11 & 61440u32).cast::<f32>();
let q112 = (p11_hi & 15u32).cast::<f32>();
let q113 = (p11_hi & 240u32).cast::<f32>();
let q114 = (p11_hi & 3840u32).cast::<f32>();
let q115 = (p11_hi & 61440u32).cast::<f32>();
let qd1_a = q10 * x0_a
+ q11 * x1_a
+ q12 * x2_a
+ q13 * x3_a
+ q14 * x4_a
+ q15 * x5_a
+ q16 * x6_a
+ q17 * x7_a
+ q18 * x8_a
+ q19 * x9_a
+ q110 * x10_a
+ q111 * x11_a
+ q112 * x12_a
+ q113 * x13_a
+ q114 * x14_a
+ q115 * x15_a;
let qd1_b = q10 * x0_b
+ q11 * x1_b
+ q12 * x2_b
+ q13 * x3_b
+ q14 * x4_b
+ q15 * x5_b
+ q16 * x6_b
+ q17 * x7_b
+ q18 * x8_b
+ q19 * x9_b
+ q110 * x10_b
+ q111 * x11_b
+ q112 * x12_b
+ q113 * x13_b
+ q114 * x14_b
+ q115 * x15_b;
acc1_a = acc1_a + s1 * qd1_a + bi1 * xs_a;
acc1_b = acc1_b + s1 * qd1_b + bi1 * xs_b;
// Weight row 2.
let p20 = load(w[w_base2 + pack_off]);
let p21 = load(w[w_base2 + pack_off + 1u32]);
let p20_hi = p20 >> 16u32;
let p21_hi = p21 >> 16u32;
let s2 = load(scales[sb_base2 + g]).cast::<f32>();
let bi2 = load(biases[sb_base2 + g]).cast::<f32>();
let q20 = (p20 & 15u32).cast::<f32>();
let q21 = (p20 & 240u32).cast::<f32>();
let q22 = (p20 & 3840u32).cast::<f32>();
let q23 = (p20 & 61440u32).cast::<f32>();
let q24 = (p20_hi & 15u32).cast::<f32>();
let q25 = (p20_hi & 240u32).cast::<f32>();
let q26 = (p20_hi & 3840u32).cast::<f32>();
let q27 = (p20_hi & 61440u32).cast::<f32>();
let q28 = (p21 & 15u32).cast::<f32>();
let q29 = (p21 & 240u32).cast::<f32>();
let q210 = (p21 & 3840u32).cast::<f32>();
let q211 = (p21 & 61440u32).cast::<f32>();
let q212 = (p21_hi & 15u32).cast::<f32>();
let q213 = (p21_hi & 240u32).cast::<f32>();
let q214 = (p21_hi & 3840u32).cast::<f32>();
let q215 = (p21_hi & 61440u32).cast::<f32>();
let qd2_a = q20 * x0_a
+ q21 * x1_a
+ q22 * x2_a
+ q23 * x3_a
+ q24 * x4_a
+ q25 * x5_a
+ q26 * x6_a
+ q27 * x7_a
+ q28 * x8_a
+ q29 * x9_a
+ q210 * x10_a
+ q211 * x11_a
+ q212 * x12_a
+ q213 * x13_a
+ q214 * x14_a
+ q215 * x15_a;
let qd2_b = q20 * x0_b
+ q21 * x1_b
+ q22 * x2_b
+ q23 * x3_b
+ q24 * x4_b
+ q25 * x5_b
+ q26 * x6_b
+ q27 * x7_b
+ q28 * x8_b
+ q29 * x9_b
+ q210 * x10_b
+ q211 * x11_b
+ q212 * x12_b
+ q213 * x13_b
+ q214 * x14_b
+ q215 * x15_b;
acc2_a = acc2_a + s2 * qd2_a + bi2 * xs_a;
acc2_b = acc2_b + s2 * qd2_b + bi2 * xs_b;
// Weight row 3.
let p30 = load(w[w_base3 + pack_off]);
let p31 = load(w[w_base3 + pack_off + 1u32]);
let p30_hi = p30 >> 16u32;
let p31_hi = p31 >> 16u32;
let s3 = load(scales[sb_base3 + g]).cast::<f32>();
let bi3 = load(biases[sb_base3 + g]).cast::<f32>();
let q30 = (p30 & 15u32).cast::<f32>();
let q31 = (p30 & 240u32).cast::<f32>();
let q32 = (p30 & 3840u32).cast::<f32>();
let q33 = (p30 & 61440u32).cast::<f32>();
let q34 = (p30_hi & 15u32).cast::<f32>();
let q35 = (p30_hi & 240u32).cast::<f32>();
let q36 = (p30_hi & 3840u32).cast::<f32>();
let q37 = (p30_hi & 61440u32).cast::<f32>();
let q38 = (p31 & 15u32).cast::<f32>();
let q39 = (p31 & 240u32).cast::<f32>();
let q310 = (p31 & 3840u32).cast::<f32>();
let q311 = (p31 & 61440u32).cast::<f32>();
let q312 = (p31_hi & 15u32).cast::<f32>();
let q313 = (p31_hi & 240u32).cast::<f32>();
let q314 = (p31_hi & 3840u32).cast::<f32>();
let q315 = (p31_hi & 61440u32).cast::<f32>();
let qd3_a = q30 * x0_a
+ q31 * x1_a
+ q32 * x2_a
+ q33 * x3_a
+ q34 * x4_a
+ q35 * x5_a
+ q36 * x6_a
+ q37 * x7_a
+ q38 * x8_a
+ q39 * x9_a
+ q310 * x10_a
+ q311 * x11_a
+ q312 * x12_a
+ q313 * x13_a
+ q314 * x14_a
+ q315 * x15_a;
let qd3_b = q30 * x0_b
+ q31 * x1_b
+ q32 * x2_b
+ q33 * x3_b
+ q34 * x4_b
+ q35 * x5_b
+ q36 * x6_b
+ q37 * x7_b
+ q38 * x8_b
+ q39 * x9_b
+ q310 * x10_b
+ q311 * x11_b
+ q312 * x12_b
+ q313 * x13_b
+ q314 * x14_b
+ q315 * x15_b;
acc3_a = acc3_a + s3 * qd3_a + bi3 * xs_a;
acc3_b = acc3_b + s3 * qd3_b + bi3 * xs_b;
}
// Combine the per-lane partials (simd_sum is presumably the simdgroup-wide
// add reduction — confirm in the DSL docs).
let r0_a = simd_sum(acc0_a);
let r0_b = simd_sum(acc0_b);
let r1_a = simd_sum(acc1_a);
let r1_b = simd_sum(acc1_b);
let r2_a = simd_sum(acc2_a);
let r2_b = simd_sum(acc2_b);
let r3_a = simd_sum(acc3_a);
let r3_b = simd_sum(acc3_b);
// Only lane 0 writes: four outputs for x row A, then four for x row B.
if lane == 0u32 {
store(out[m_row_a * n + row0], r0_a.cast::<T>());
store(out[m_row_a * n + row1], r1_a.cast::<T>());
store(out[m_row_a * n + row2], r2_a.cast::<T>());
store(out[m_row_a * n + row3], r3_a.cast::<T>());
store(out[m_row_b * n + row0], r0_b.cast::<T>());
store(out[m_row_b * n + row1], r1_b.cast::<T>());
store(out[m_row_b * n + row2], r2_b.cast::<T>());
store(out[m_row_b * n + row3], r3_b.cast::<T>());
}
}
#[bench_kernel(
op="quantized",
subop="qmm_bm4",
class=QuantizedMatMul,
shapes=&QUANTIZED_SHAPES,
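// M=16 is the cell where bm4's W-bw advantage compounds — bm2 at
// M=16 hits ~50% MT MLX; bm4 halves W bw so a 1.5-1.8× speedup
// over bm2 is plausible. Selector routes m % 4 == 0 to bm4.
// NOTE(review): the bench cell configured below is m=8, not the M=16
// discussed above — confirm which value is intended.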
m=8,
group_size=64,
tpg=64,
// bf16 round-trip on int4-quantized matmul: max_q=15 × group_size=64
// × bf16's 7-bit mantissa drifts ~7-8e-3 at large K (per
// crates/metaltile-std/src/mlx/binary.rs precedent — "bf16 drifts
// ~7.8e-3 on signed"). Tighter than 1e-2 trips the bench cosine
// check at production shapes (M=4096+, K=4096+) on Apple Paravirtual
// CI. tol=1e-2 keeps f32/f16 cells tight while passing bf16.
tol=1e-2,
mlx="affine_qmm_t_{tn}_gs_64_b_4_alN_true_batch_0",
metal_file="quantized.metal",
dtypes=&[metaltile_core::dtype::DType::F32, metaltile_core::dtype::DType::F16, metaltile_core::dtype::DType::BF16],
)]
#[kernel]
pub fn mt_qmm_bm4<T>(
w: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] n: u32,
#[constexpr] gs_per_row: u32,
) {
let tg = tgid_x;
let m_tile = tgid_y;
let sg = simd_id;
let lane = simd_lane;
let row0 = tg * 8u32 + sg * 4u32;
let row1 = row0 + 1u32;
let row2 = row0 + 2u32;
let row3 = row0 + 3u32;
let packs_per_row = k / 8u32;
let w_base0 = row0 * packs_per_row;
let w_base1 = row1 * packs_per_row;
let w_base2 = row2 * packs_per_row;
let w_base3 = row3 * packs_per_row;
let sb_base0 = row0 * gs_per_row;
let sb_base1 = row1 * gs_per_row;
let sb_base2 = row2 * gs_per_row;
let sb_base3 = row3 * gs_per_row;
let m_row_a = m_tile * 4u32;
let m_row_b = m_row_a + 1u32;
let m_row_c = m_row_a + 2u32;
let m_row_d = m_row_a + 3u32;
let x_base_a = m_row_a * k;
let x_base_b = m_row_b * k;
let x_base_c = m_row_c * k;
let x_base_d = m_row_d * k;
let mut acc0_a = 0.0f32;
let mut acc0_b = 0.0f32;
let mut acc0_c = 0.0f32;
let mut acc0_d = 0.0f32;
let mut acc1_a = 0.0f32;
let mut acc1_b = 0.0f32;
let mut acc1_c = 0.0f32;
let mut acc1_d = 0.0f32;
let mut acc2_a = 0.0f32;
let mut acc2_b = 0.0f32;
let mut acc2_c = 0.0f32;
let mut acc2_d = 0.0f32;
let mut acc3_a = 0.0f32;
let mut acc3_b = 0.0f32;
let mut acc3_c = 0.0f32;
let mut acc3_d = 0.0f32;
let lane_x_off = lane * 8u32;
let lane_pack_off = lane;
for _b in range(0u32, k, 256u32) {
let s_16 = 0.0625f32;
let s_256 = 0.00390625f32;
let s_4096 = 0.000244140625f32;
let xb_a = x_base_a + _b + lane_x_off;
let x0_a = load(x[xb_a]).cast::<f32>();
let x1_a_raw = load(x[xb_a + 1u32]).cast::<f32>();
let x2_a_raw = load(x[xb_a + 2u32]).cast::<f32>();
let x3_a_raw = load(x[xb_a + 3u32]).cast::<f32>();
let x4_a = load(x[xb_a + 4u32]).cast::<f32>();
let x5_a_raw = load(x[xb_a + 5u32]).cast::<f32>();
let x6_a_raw = load(x[xb_a + 6u32]).cast::<f32>();
let x7_a_raw = load(x[xb_a + 7u32]).cast::<f32>();
let xs_a = x0_a + x1_a_raw + x2_a_raw + x3_a_raw + x4_a + x5_a_raw + x6_a_raw + x7_a_raw;
let x1_a = x1_a_raw * s_16;
let x2_a = x2_a_raw * s_256;
let x3_a = x3_a_raw * s_4096;
let x5_a = x5_a_raw * s_16;
let x6_a = x6_a_raw * s_256;
let x7_a = x7_a_raw * s_4096;
let xb_b = x_base_b + _b + lane_x_off;
let x0_b = load(x[xb_b]).cast::<f32>();
let x1_b_raw = load(x[xb_b + 1u32]).cast::<f32>();
let x2_b_raw = load(x[xb_b + 2u32]).cast::<f32>();
let x3_b_raw = load(x[xb_b + 3u32]).cast::<f32>();
let x4_b = load(x[xb_b + 4u32]).cast::<f32>();
let x5_b_raw = load(x[xb_b + 5u32]).cast::<f32>();
let x6_b_raw = load(x[xb_b + 6u32]).cast::<f32>();
let x7_b_raw = load(x[xb_b + 7u32]).cast::<f32>();
let xs_b = x0_b + x1_b_raw + x2_b_raw + x3_b_raw + x4_b + x5_b_raw + x6_b_raw + x7_b_raw;
let x1_b = x1_b_raw * s_16;
let x2_b = x2_b_raw * s_256;
let x3_b = x3_b_raw * s_4096;
let x5_b = x5_b_raw * s_16;
let x6_b = x6_b_raw * s_256;
let x7_b = x7_b_raw * s_4096;
let xb_c = x_base_c + _b + lane_x_off;
let x0_c = load(x[xb_c]).cast::<f32>();
let x1_c_raw = load(x[xb_c + 1u32]).cast::<f32>();
let x2_c_raw = load(x[xb_c + 2u32]).cast::<f32>();
let x3_c_raw = load(x[xb_c + 3u32]).cast::<f32>();
let x4_c = load(x[xb_c + 4u32]).cast::<f32>();
let x5_c_raw = load(x[xb_c + 5u32]).cast::<f32>();
let x6_c_raw = load(x[xb_c + 6u32]).cast::<f32>();
let x7_c_raw = load(x[xb_c + 7u32]).cast::<f32>();
let xs_c = x0_c + x1_c_raw + x2_c_raw + x3_c_raw + x4_c + x5_c_raw + x6_c_raw + x7_c_raw;
let x1_c = x1_c_raw * s_16;
let x2_c = x2_c_raw * s_256;
let x3_c = x3_c_raw * s_4096;
let x5_c = x5_c_raw * s_16;
let x6_c = x6_c_raw * s_256;
let x7_c = x7_c_raw * s_4096;
let xb_d = x_base_d + _b + lane_x_off;
let x0_d = load(x[xb_d]).cast::<f32>();
let x1_d_raw = load(x[xb_d + 1u32]).cast::<f32>();
let x2_d_raw = load(x[xb_d + 2u32]).cast::<f32>();
let x3_d_raw = load(x[xb_d + 3u32]).cast::<f32>();
let x4_d = load(x[xb_d + 4u32]).cast::<f32>();
let x5_d_raw = load(x[xb_d + 5u32]).cast::<f32>();
let x6_d_raw = load(x[xb_d + 6u32]).cast::<f32>();
let x7_d_raw = load(x[xb_d + 7u32]).cast::<f32>();
let xs_d = x0_d + x1_d_raw + x2_d_raw + x3_d_raw + x4_d + x5_d_raw + x6_d_raw + x7_d_raw;
let x1_d = x1_d_raw * s_16;
let x2_d = x2_d_raw * s_256;
let x3_d = x3_d_raw * s_4096;
let x5_d = x5_d_raw * s_16;
let x6_d = x6_d_raw * s_256;
let x7_d = x7_d_raw * s_4096;
let g = (_b + lane_x_off) / 64u32;
let pack_off = _b / 8u32 + lane_pack_off;
let p00 = load(w[w_base0 + pack_off]);
let p00_hi = p00 >> 16u32;
let s0 = load(scales[sb_base0 + g]).cast::<f32>();
let bi0 = load(biases[sb_base0 + g]).cast::<f32>();
let q00 = (p00 & 15u32).cast::<f32>();
let q01 = (p00 & 240u32).cast::<f32>();
let q02 = (p00 & 3840u32).cast::<f32>();
let q03 = (p00 & 61440u32).cast::<f32>();
let q04 = (p00_hi & 15u32).cast::<f32>();
let q05 = (p00_hi & 240u32).cast::<f32>();
let q06 = (p00_hi & 3840u32).cast::<f32>();
let q07 = (p00_hi & 61440u32).cast::<f32>();
let qd0_a = q00 * x0_a
+ q01 * x1_a
+ q02 * x2_a
+ q03 * x3_a
+ q04 * x4_a
+ q05 * x5_a
+ q06 * x6_a
+ q07 * x7_a;
let qd0_b = q00 * x0_b
+ q01 * x1_b
+ q02 * x2_b
+ q03 * x3_b
+ q04 * x4_b
+ q05 * x5_b
+ q06 * x6_b
+ q07 * x7_b;
let qd0_c = q00 * x0_c
+ q01 * x1_c
+ q02 * x2_c
+ q03 * x3_c
+ q04 * x4_c
+ q05 * x5_c
+ q06 * x6_c
+ q07 * x7_c;
let qd0_d = q00 * x0_d
+ q01 * x1_d
+ q02 * x2_d
+ q03 * x3_d
+ q04 * x4_d
+ q05 * x5_d
+ q06 * x6_d
+ q07 * x7_d;
acc0_a = acc0_a + s0 * qd0_a + bi0 * xs_a;
acc0_b = acc0_b + s0 * qd0_b + bi0 * xs_b;
acc0_c = acc0_c + s0 * qd0_c + bi0 * xs_c;
acc0_d = acc0_d + s0 * qd0_d + bi0 * xs_d;
let p10 = load(w[w_base1 + pack_off]);
let p10_hi = p10 >> 16u32;
let s1 = load(scales[sb_base1 + g]).cast::<f32>();
let bi1 = load(biases[sb_base1 + g]).cast::<f32>();
let q10 = (p10 & 15u32).cast::<f32>();
let q11 = (p10 & 240u32).cast::<f32>();
let q12 = (p10 & 3840u32).cast::<f32>();
let q13 = (p10 & 61440u32).cast::<f32>();
let q14 = (p10_hi & 15u32).cast::<f32>();
let q15 = (p10_hi & 240u32).cast::<f32>();
let q16 = (p10_hi & 3840u32).cast::<f32>();
let q17 = (p10_hi & 61440u32).cast::<f32>();
let qd1_a = q10 * x0_a
+ q11 * x1_a
+ q12 * x2_a
+ q13 * x3_a
+ q14 * x4_a
+ q15 * x5_a
+ q16 * x6_a
+ q17 * x7_a;
let qd1_b = q10 * x0_b
+ q11 * x1_b
+ q12 * x2_b
+ q13 * x3_b
+ q14 * x4_b
+ q15 * x5_b
+ q16 * x6_b
+ q17 * x7_b;
let qd1_c = q10 * x0_c
+ q11 * x1_c
+ q12 * x2_c
+ q13 * x3_c
+ q14 * x4_c
+ q15 * x5_c
+ q16 * x6_c
+ q17 * x7_c;
let qd1_d = q10 * x0_d
+ q11 * x1_d
+ q12 * x2_d
+ q13 * x3_d
+ q14 * x4_d
+ q15 * x5_d
+ q16 * x6_d
+ q17 * x7_d;
acc1_a = acc1_a + s1 * qd1_a + bi1 * xs_a;
acc1_b = acc1_b + s1 * qd1_b + bi1 * xs_b;
acc1_c = acc1_c + s1 * qd1_c + bi1 * xs_c;
acc1_d = acc1_d + s1 * qd1_d + bi1 * xs_d;
let p20 = load(w[w_base2 + pack_off]);
let p20_hi = p20 >> 16u32;
let s2 = load(scales[sb_base2 + g]).cast::<f32>();
let bi2 = load(biases[sb_base2 + g]).cast::<f32>();
let q20 = (p20 & 15u32).cast::<f32>();
let q21 = (p20 & 240u32).cast::<f32>();
let q22 = (p20 & 3840u32).cast::<f32>();
let q23 = (p20 & 61440u32).cast::<f32>();
let q24 = (p20_hi & 15u32).cast::<f32>();
let q25 = (p20_hi & 240u32).cast::<f32>();
let q26 = (p20_hi & 3840u32).cast::<f32>();
let q27 = (p20_hi & 61440u32).cast::<f32>();
let qd2_a = q20 * x0_a
+ q21 * x1_a
+ q22 * x2_a
+ q23 * x3_a
+ q24 * x4_a
+ q25 * x5_a
+ q26 * x6_a
+ q27 * x7_a;
let qd2_b = q20 * x0_b
+ q21 * x1_b
+ q22 * x2_b
+ q23 * x3_b
+ q24 * x4_b
+ q25 * x5_b
+ q26 * x6_b
+ q27 * x7_b;
let qd2_c = q20 * x0_c
+ q21 * x1_c
+ q22 * x2_c
+ q23 * x3_c
+ q24 * x4_c
+ q25 * x5_c
+ q26 * x6_c
+ q27 * x7_c;
let qd2_d = q20 * x0_d
+ q21 * x1_d
+ q22 * x2_d
+ q23 * x3_d
+ q24 * x4_d
+ q25 * x5_d
+ q26 * x6_d
+ q27 * x7_d;
acc2_a = acc2_a + s2 * qd2_a + bi2 * xs_a;
acc2_b = acc2_b + s2 * qd2_b + bi2 * xs_b;
acc2_c = acc2_c + s2 * qd2_c + bi2 * xs_c;
acc2_d = acc2_d + s2 * qd2_d + bi2 * xs_d;
let p30 = load(w[w_base3 + pack_off]);
let p30_hi = p30 >> 16u32;
let s3 = load(scales[sb_base3 + g]).cast::<f32>();
let bi3 = load(biases[sb_base3 + g]).cast::<f32>();
let q30 = (p30 & 15u32).cast::<f32>();
let q31 = (p30 & 240u32).cast::<f32>();
let q32 = (p30 & 3840u32).cast::<f32>();
let q33 = (p30 & 61440u32).cast::<f32>();
let q34 = (p30_hi & 15u32).cast::<f32>();
let q35 = (p30_hi & 240u32).cast::<f32>();
let q36 = (p30_hi & 3840u32).cast::<f32>();
let q37 = (p30_hi & 61440u32).cast::<f32>();
let qd3_a = q30 * x0_a
+ q31 * x1_a
+ q32 * x2_a
+ q33 * x3_a
+ q34 * x4_a
+ q35 * x5_a
+ q36 * x6_a
+ q37 * x7_a;
let qd3_b = q30 * x0_b
+ q31 * x1_b
+ q32 * x2_b
+ q33 * x3_b
+ q34 * x4_b
+ q35 * x5_b
+ q36 * x6_b
+ q37 * x7_b;
let qd3_c = q30 * x0_c
+ q31 * x1_c
+ q32 * x2_c
+ q33 * x3_c
+ q34 * x4_c
+ q35 * x5_c
+ q36 * x6_c
+ q37 * x7_c;
let qd3_d = q30 * x0_d
+ q31 * x1_d
+ q32 * x2_d
+ q33 * x3_d
+ q34 * x4_d
+ q35 * x5_d
+ q36 * x6_d
+ q37 * x7_d;
acc3_a = acc3_a + s3 * qd3_a + bi3 * xs_a;
acc3_b = acc3_b + s3 * qd3_b + bi3 * xs_b;
acc3_c = acc3_c + s3 * qd3_c + bi3 * xs_c;
acc3_d = acc3_d + s3 * qd3_d + bi3 * xs_d;
}
let r0_a = simd_sum(acc0_a);
let r0_b = simd_sum(acc0_b);
let r0_c = simd_sum(acc0_c);
let r0_d = simd_sum(acc0_d);
let r1_a = simd_sum(acc1_a);
let r1_b = simd_sum(acc1_b);
let r1_c = simd_sum(acc1_c);
let r1_d = simd_sum(acc1_d);
let r2_a = simd_sum(acc2_a);
let r2_b = simd_sum(acc2_b);
let r2_c = simd_sum(acc2_c);
let r2_d = simd_sum(acc2_d);
let r3_a = simd_sum(acc3_a);
let r3_b = simd_sum(acc3_b);
let r3_c = simd_sum(acc3_c);
let r3_d = simd_sum(acc3_d);
if lane == 0u32 {
store(out[m_row_a * n + row0], r0_a.cast::<T>());
store(out[m_row_a * n + row1], r1_a.cast::<T>());
store(out[m_row_a * n + row2], r2_a.cast::<T>());
store(out[m_row_a * n + row3], r3_a.cast::<T>());
store(out[m_row_b * n + row0], r0_b.cast::<T>());
store(out[m_row_b * n + row1], r1_b.cast::<T>());
store(out[m_row_b * n + row2], r2_b.cast::<T>());
store(out[m_row_b * n + row3], r3_b.cast::<T>());
store(out[m_row_c * n + row0], r0_c.cast::<T>());
store(out[m_row_c * n + row1], r1_c.cast::<T>());
store(out[m_row_c * n + row2], r2_c.cast::<T>());
store(out[m_row_c * n + row3], r3_c.cast::<T>());
store(out[m_row_d * n + row0], r0_d.cast::<T>());
store(out[m_row_d * n + row1], r1_d.cast::<T>());
store(out[m_row_d * n + row2], r2_d.cast::<T>());
store(out[m_row_d * n + row3], r3_d.cast::<T>());
}
}
// Int8 affine-quantized matrix-vector product.
//
// Each u32 of `w` carries four 8-bit weights (lowest byte first) and a
// weight row spans `k / 4` packs. For every output row the kernel adds up
// `scale * sum(q * x) + bias * sum(x)` over 4-element slices of `x`, which
// equals `sum((scale * q + bias) * x)` with the scale/bias loaded for the
// slice.
//
// Arguments as used below:
//   w          - packed weights, indexed `row * (k / 4) + pack`
//   scales     - per-group scales, indexed `row * gs_per_row + group`
//   biases     - per-group biases, same indexing as `scales`
//   x          - input vector, indexed by position along `k`
//   out        - output, one element per weight row
//   k          - reduction length (loop bound)
//   gs_per_row - stride between rows in `scales` / `biases`
//
// Work split: row0 = tgid_x * 8 + simd_id * 4, so one simdgroup produces
// four consecutive output rows, and each lane covers four consecutive `x`
// elements of every 128-wide step.
// NOTE(review): no tail handling is visible, so `k` is presumably a
// multiple of 128 and the row count a multiple of 8 — confirm against the
// dispatcher.
#[bench_kernel(
op="quantized",
subop="qmv_int8_fast",
class=QuantizedMatVec,
shapes=&QUANTIZED_SHAPES,
group_size=64,
tpg=64,
// bits=8: drives `run_quantized_mat_vec`'s W pack-factor (4
// bytes/u32) + the bit-stream extract in the correctness oracle.
// Without this the runner defaults to bits=4 and the int8 kernel
// reads 2× the int4-sized W buffer.
bits=8,
tol=1e-3,
mlx="affine_qmv_fast_float16_t_gs_64_b_8_batch_0",
metal_file="quantized.metal",
dtypes=&[metaltile_core::dtype::DType::F32, metaltile_core::dtype::DType::F16, metaltile_core::dtype::DType::BF16],
)]
#[kernel]
pub fn mt_qmv_int8_fast<T>(
w: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] gs_per_row: u32,
) {
let tg = tgid_x;
let sg = simd_id;
let lane = simd_lane;
// Four consecutive output rows handled by this simdgroup.
let row0 = tg * 8u32 + sg * 4u32;
let row1 = row0 + 1u32;
let row2 = row0 + 2u32;
let row3 = row0 + 3u32;
// Four 8-bit weights per u32, hence k / 4 packs per weight row.
let packs_per_row = k / 4u32;
let w_base0 = row0 * packs_per_row;
let w_base1 = row1 * packs_per_row;
let w_base2 = row2 * packs_per_row;
let w_base3 = row3 * packs_per_row;
// Start of each row's scale/bias entries.
let sb_base0 = row0 * gs_per_row;
let sb_base1 = row1 * gs_per_row;
let sb_base2 = row2 * gs_per_row;
let sb_base3 = row3 * gs_per_row;
// Per-lane partial sums, one per output row, kept in f32.
let mut acc0 = 0.0f32;
let mut acc1 = 0.0f32;
let mut acc2 = 0.0f32;
let mut acc3 = 0.0f32;
// Each lane owns 4 x elements (= exactly one weight pack) per step.
let lane_x_off = lane * 4u32;
let lane_pack_off = lane;
for _b in range(0u32, k, 128u32) {
// This lane's four activations and their sum (the sum feeds the bias term).
let xb = _b + lane_x_off;
let x0 = load(x[xb]).cast::<f32>();
let x1 = load(x[xb + 1u32]).cast::<f32>();
let x2 = load(x[xb + 2u32]).cast::<f32>();
let x3 = load(x[xb + 3u32]).cast::<f32>();
let xs = x0 + x1 + x2 + x3;
// Quantization group of this slice (64 elements per group) and the
// index of the matching weight pack within a row.
let g = xb / 64u32;
let pack_off = _b / 4u32 + lane_pack_off;
// Row 0: unpack the four bytes, dot with x, then apply scale and bias.
let p0 = load(w[w_base0 + pack_off]);
let s0 = load(scales[sb_base0 + g]).cast::<f32>();
let bi0 = load(biases[sb_base0 + g]).cast::<f32>();
let q00 = (p0 & 255u32).cast::<f32>();
let q01 = ((p0 >> 8u32) & 255u32).cast::<f32>();
let q02 = ((p0 >> 16u32) & 255u32).cast::<f32>();
let q03 = ((p0 >> 24u32) & 255u32).cast::<f32>();
let qd0 = q00 * x0 + q01 * x1 + q02 * x2 + q03 * x3;
acc0 = acc0 + s0 * qd0 + bi0 * xs;
// Row 1: same steps with row 1's pack, scale and bias.
let p1 = load(w[w_base1 + pack_off]);
let s1 = load(scales[sb_base1 + g]).cast::<f32>();
let bi1 = load(biases[sb_base1 + g]).cast::<f32>();
let q10 = (p1 & 255u32).cast::<f32>();
let q11 = ((p1 >> 8u32) & 255u32).cast::<f32>();
let q12 = ((p1 >> 16u32) & 255u32).cast::<f32>();
let q13 = ((p1 >> 24u32) & 255u32).cast::<f32>();
let qd1 = q10 * x0 + q11 * x1 + q12 * x2 + q13 * x3;
acc1 = acc1 + s1 * qd1 + bi1 * xs;
// Row 2.
let p2 = load(w[w_base2 + pack_off]);
let s2 = load(scales[sb_base2 + g]).cast::<f32>();
let bi2 = load(biases[sb_base2 + g]).cast::<f32>();
let q20 = (p2 & 255u32).cast::<f32>();
let q21 = ((p2 >> 8u32) & 255u32).cast::<f32>();
let q22 = ((p2 >> 16u32) & 255u32).cast::<f32>();
let q23 = ((p2 >> 24u32) & 255u32).cast::<f32>();
let qd2 = q20 * x0 + q21 * x1 + q22 * x2 + q23 * x3;
acc2 = acc2 + s2 * qd2 + bi2 * xs;
// Row 3.
let p3 = load(w[w_base3 + pack_off]);
let s3 = load(scales[sb_base3 + g]).cast::<f32>();
let bi3 = load(biases[sb_base3 + g]).cast::<f32>();
let q30 = (p3 & 255u32).cast::<f32>();
let q31 = ((p3 >> 8u32) & 255u32).cast::<f32>();
let q32 = ((p3 >> 16u32) & 255u32).cast::<f32>();
let q33 = ((p3 >> 24u32) & 255u32).cast::<f32>();
let qd3 = q30 * x0 + q31 * x1 + q32 * x2 + q33 * x3;
acc3 = acc3 + s3 * qd3 + bi3 * xs;
}
// simd_sum folds the per-lane partials (presumably across the whole
// simdgroup — confirm in the DSL); lane 0 alone stores the four results,
// cast back to T.
let r0 = simd_sum(acc0);
let r1 = simd_sum(acc1);
let r2 = simd_sum(acc2);
let r3 = simd_sum(acc3);
if lane == 0u32 {
store(out[row0], r0.cast::<T>());
store(out[row1], r1.cast::<T>());
store(out[row2], r2.cast::<T>());
store(out[row3], r3.cast::<T>());
}
}
// Int8 affine-quantized matrix-matrix product, one `x` row per threadgroup
// in the y dimension.
//
// The inner loop is the int8 matrix-vector scheme: each u32 of `w` holds
// four 8-bit weights (lowest byte first) and every output accumulates
// `scale * sum(q * x) + bias * sum(x)` over 4-element slices. The
// differences are the `x` row offset `tgid_y * k` and the output offset
// `tgid_y * n`.
//
// Arguments as used below:
//   w          - packed weights, indexed `row * (k / 4) + pack`
//   scales     - per-group scales, indexed `row * gs_per_row + group`
//   biases     - per-group biases, same indexing as `scales`
//   x          - activations, indexed `m_row * k + position`
//   out        - output, indexed `m_row * n + row`
//   k          - reduction length and row stride of `x`
//   n          - row stride of `out`
//   gs_per_row - stride between rows in `scales` / `biases`
//
// NOTE(review): no tail handling is visible, so `k` is presumably a
// multiple of 128 and `n` a multiple of 8 — confirm against the dispatcher.
#[bench_kernel(
op="quantized",
subop="qmm_int8_fast",
class=QuantizedMatMul,
shapes=&QUANTIZED_SHAPES,
m=4,
group_size=64,
tpg=64,
bits=8,
tol=1e-2,
mlx="affine_qmm_fast_float16_t_gs_64_b_8_alN_true_batch_0",
metal_file="quantized.metal",
dtypes=&[metaltile_core::dtype::DType::F32, metaltile_core::dtype::DType::F16, metaltile_core::dtype::DType::BF16],
)]
#[kernel]
pub fn mt_qmm_int8_fast<T>(
w: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] n: u32,
#[constexpr] gs_per_row: u32,
) {
let tg = tgid_x;
// The y index of the threadgroup selects which row of `x` / `out` is processed.
let m_row = tgid_y;
let sg = simd_id;
let lane = simd_lane;
// Four consecutive weight rows (output columns) handled by this simdgroup.
let row0 = tg * 8u32 + sg * 4u32;
let row1 = row0 + 1u32;
let row2 = row0 + 2u32;
let row3 = row0 + 3u32;
// Four 8-bit weights per u32, hence k / 4 packs per weight row.
let packs_per_row = k / 4u32;
let w_base0 = row0 * packs_per_row;
let w_base1 = row1 * packs_per_row;
let w_base2 = row2 * packs_per_row;
let w_base3 = row3 * packs_per_row;
// Start of each weight row's scale/bias entries.
let sb_base0 = row0 * gs_per_row;
let sb_base1 = row1 * gs_per_row;
let sb_base2 = row2 * gs_per_row;
let sb_base3 = row3 * gs_per_row;
// First element of the selected `x` row.
let x_row_base = m_row * k;
// Per-lane partial sums, one per weight row, kept in f32.
let mut acc0 = 0.0f32;
let mut acc1 = 0.0f32;
let mut acc2 = 0.0f32;
let mut acc3 = 0.0f32;
// Each lane owns 4 x elements (= exactly one weight pack) per step.
let lane_x_off = lane * 4u32;
let lane_pack_off = lane;
for _b in range(0u32, k, 128u32) {
// This lane's four activations and their sum (the sum feeds the bias term).
let xb = x_row_base + _b + lane_x_off;
let x0 = load(x[xb]).cast::<f32>();
let x1 = load(x[xb + 1u32]).cast::<f32>();
let x2 = load(x[xb + 2u32]).cast::<f32>();
let x3 = load(x[xb + 3u32]).cast::<f32>();
let xs = x0 + x1 + x2 + x3;
// The group index uses the position along k only (no x row offset),
// because scales/biases belong to the weight rows.
let g = (_b + lane_x_off) / 64u32;
let pack_off = _b / 4u32 + lane_pack_off;
// Weight row 0: unpack the four bytes, dot with x, apply scale and bias.
let p0 = load(w[w_base0 + pack_off]);
let s0 = load(scales[sb_base0 + g]).cast::<f32>();
let bi0 = load(biases[sb_base0 + g]).cast::<f32>();
let q00 = (p0 & 255u32).cast::<f32>();
let q01 = ((p0 >> 8u32) & 255u32).cast::<f32>();
let q02 = ((p0 >> 16u32) & 255u32).cast::<f32>();
let q03 = ((p0 >> 24u32) & 255u32).cast::<f32>();
let qd0 = q00 * x0 + q01 * x1 + q02 * x2 + q03 * x3;
acc0 = acc0 + s0 * qd0 + bi0 * xs;
// Weight row 1.
let p1 = load(w[w_base1 + pack_off]);
let s1 = load(scales[sb_base1 + g]).cast::<f32>();
let bi1 = load(biases[sb_base1 + g]).cast::<f32>();
let q10 = (p1 & 255u32).cast::<f32>();
let q11 = ((p1 >> 8u32) & 255u32).cast::<f32>();
let q12 = ((p1 >> 16u32) & 255u32).cast::<f32>();
let q13 = ((p1 >> 24u32) & 255u32).cast::<f32>();
let qd1 = q10 * x0 + q11 * x1 + q12 * x2 + q13 * x3;
acc1 = acc1 + s1 * qd1 + bi1 * xs;
// Weight row 2.
let p2 = load(w[w_base2 + pack_off]);
let s2 = load(scales[sb_base2 + g]).cast::<f32>();
let bi2 = load(biases[sb_base2 + g]).cast::<f32>();
let q20 = (p2 & 255u32).cast::<f32>();
let q21 = ((p2 >> 8u32) & 255u32).cast::<f32>();
let q22 = ((p2 >> 16u32) & 255u32).cast::<f32>();
let q23 = ((p2 >> 24u32) & 255u32).cast::<f32>();
let qd2 = q20 * x0 + q21 * x1 + q22 * x2 + q23 * x3;
acc2 = acc2 + s2 * qd2 + bi2 * xs;
// Weight row 3.
let p3 = load(w[w_base3 + pack_off]);
let s3 = load(scales[sb_base3 + g]).cast::<f32>();
let bi3 = load(biases[sb_base3 + g]).cast::<f32>();
let q30 = (p3 & 255u32).cast::<f32>();
let q31 = ((p3 >> 8u32) & 255u32).cast::<f32>();
let q32 = ((p3 >> 16u32) & 255u32).cast::<f32>();
let q33 = ((p3 >> 24u32) & 255u32).cast::<f32>();
let qd3 = q30 * x0 + q31 * x1 + q32 * x2 + q33 * x3;
acc3 = acc3 + s3 * qd3 + bi3 * xs;
}
// simd_sum folds the per-lane partials (presumably across the whole
// simdgroup — confirm in the DSL); lane 0 alone stores the four results
// into row `m_row` of `out`, cast back to T.
let r0 = simd_sum(acc0);
let r1 = simd_sum(acc1);
let r2 = simd_sum(acc2);
let r3 = simd_sum(acc3);
if lane == 0u32 {
store(out[m_row * n + row0], r0.cast::<T>());
store(out[m_row * n + row1], r1.cast::<T>());
store(out[m_row * n + row2], r2.cast::<T>());
store(out[m_row * n + row3], r3.cast::<T>());
}
}
// Int8 affine-quantized matrix-matrix product that processes two `x` rows
// per threadgroup y index.
//
// Rows `tgid_y * 2` (suffix `_a`) and `tgid_y * 2 + 1` (suffix `_b`) share
// every weight pack, scale and bias load: each pack is unpacked once and
// dotted against both rows' activations. Per output the accumulation is
// `scale * sum(q * x) + bias * sum(x)` over 4-element slices, with four
// 8-bit weights per u32 of `w` (lowest byte first).
//
// Arguments as used below:
//   w          - packed weights, indexed `row * (k / 4) + pack`
//   scales     - per-group scales, indexed `row * gs_per_row + group`
//   biases     - per-group biases, same indexing as `scales`
//   x          - activations, indexed `m_row * k + position`
//   out        - output, indexed `m_row * n + row`
//   k          - reduction length and row stride of `x`
//   n          - row stride of `out`
//   gs_per_row - stride between rows in `scales` / `biases`
//
// NOTE(review): both x rows are read and written unconditionally, so the
// row count of `x` is presumably even, `k` a multiple of 128 and `n` a
// multiple of 8 — confirm against the dispatcher.
#[bench_kernel(
op="quantized",
subop="qmm_bm2_int8_fast",
class=QuantizedMatMul,
shapes=&QUANTIZED_SHAPES,
m=8,
group_size=64,
tpg=64,
bits=8,
tol=1e-2,
mlx="affine_qmm_fast_float16_t_gs_64_b_8_alN_true_batch_0",
metal_file="quantized.metal",
dtypes=&[metaltile_core::dtype::DType::F32, metaltile_core::dtype::DType::F16, metaltile_core::dtype::DType::BF16],
)]
#[kernel]
pub fn mt_qmm_bm2_int8_fast<T>(
w: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] n: u32,
#[constexpr] gs_per_row: u32,
) {
let tg = tgid_x;
// Index of the 2-row tile of `x` / `out` handled by this threadgroup.
let m_tile = tgid_y;
let sg = simd_id;
let lane = simd_lane;
// Four consecutive weight rows (output columns) handled by this simdgroup.
let row0 = tg * 8u32 + sg * 4u32;
let row1 = row0 + 1u32;
let row2 = row0 + 2u32;
let row3 = row0 + 3u32;
// Four 8-bit weights per u32, hence k / 4 packs per weight row.
let packs_per_row = k / 4u32;
let w_base0 = row0 * packs_per_row;
let w_base1 = row1 * packs_per_row;
let w_base2 = row2 * packs_per_row;
let w_base3 = row3 * packs_per_row;
// Start of each weight row's scale/bias entries.
let sb_base0 = row0 * gs_per_row;
let sb_base1 = row1 * gs_per_row;
let sb_base2 = row2 * gs_per_row;
let sb_base3 = row3 * gs_per_row;
// The two x rows of this tile and the offsets of their first elements.
let m_row_a = m_tile * 2u32;
let m_row_b = m_row_a + 1u32;
let x_base_a = m_row_a * k;
let x_base_b = m_row_b * k;
// Per-lane partial sums: accR_S is weight row R against x row S.
let mut acc0_a = 0.0f32;
let mut acc0_b = 0.0f32;
let mut acc1_a = 0.0f32;
let mut acc1_b = 0.0f32;
let mut acc2_a = 0.0f32;
let mut acc2_b = 0.0f32;
let mut acc3_a = 0.0f32;
let mut acc3_b = 0.0f32;
// Each lane owns 4 x elements (= exactly one weight pack) per step.
let lane_x_off = lane * 4u32;
let lane_pack_off = lane;
for _b in range(0u32, k, 128u32) {
// Activations of x row a and their sum (the sum feeds the bias term).
let xb_a = x_base_a + _b + lane_x_off;
let x0_a = load(x[xb_a]).cast::<f32>();
let x1_a = load(x[xb_a + 1u32]).cast::<f32>();
let x2_a = load(x[xb_a + 2u32]).cast::<f32>();
let x3_a = load(x[xb_a + 3u32]).cast::<f32>();
let xs_a = x0_a + x1_a + x2_a + x3_a;
// Same slice of x row b.
let xb_b = x_base_b + _b + lane_x_off;
let x0_b = load(x[xb_b]).cast::<f32>();
let x1_b = load(x[xb_b + 1u32]).cast::<f32>();
let x2_b = load(x[xb_b + 2u32]).cast::<f32>();
let x3_b = load(x[xb_b + 3u32]).cast::<f32>();
let xs_b = x0_b + x1_b + x2_b + x3_b;
// The group index uses the position along k only (no x row offset),
// because scales/biases belong to the weight rows.
let g = (_b + lane_x_off) / 64u32;
let pack_off = _b / 4u32 + lane_pack_off;
// Weight row 0: one unpack, reused for both x rows.
let p0 = load(w[w_base0 + pack_off]);
let s0 = load(scales[sb_base0 + g]).cast::<f32>();
let bi0 = load(biases[sb_base0 + g]).cast::<f32>();
let q00 = (p0 & 255u32).cast::<f32>();
let q01 = ((p0 >> 8u32) & 255u32).cast::<f32>();
let q02 = ((p0 >> 16u32) & 255u32).cast::<f32>();
let q03 = ((p0 >> 24u32) & 255u32).cast::<f32>();
let qd0_a = q00 * x0_a + q01 * x1_a + q02 * x2_a + q03 * x3_a;
let qd0_b = q00 * x0_b + q01 * x1_b + q02 * x2_b + q03 * x3_b;
acc0_a = acc0_a + s0 * qd0_a + bi0 * xs_a;
acc0_b = acc0_b + s0 * qd0_b + bi0 * xs_b;
// Weight row 1.
let p1 = load(w[w_base1 + pack_off]);
let s1 = load(scales[sb_base1 + g]).cast::<f32>();
let bi1 = load(biases[sb_base1 + g]).cast::<f32>();
let q10 = (p1 & 255u32).cast::<f32>();
let q11 = ((p1 >> 8u32) & 255u32).cast::<f32>();
let q12 = ((p1 >> 16u32) & 255u32).cast::<f32>();
let q13 = ((p1 >> 24u32) & 255u32).cast::<f32>();
let qd1_a = q10 * x0_a + q11 * x1_a + q12 * x2_a + q13 * x3_a;
let qd1_b = q10 * x0_b + q11 * x1_b + q12 * x2_b + q13 * x3_b;
acc1_a = acc1_a + s1 * qd1_a + bi1 * xs_a;
acc1_b = acc1_b + s1 * qd1_b + bi1 * xs_b;
// Weight row 2.
let p2 = load(w[w_base2 + pack_off]);
let s2 = load(scales[sb_base2 + g]).cast::<f32>();
let bi2 = load(biases[sb_base2 + g]).cast::<f32>();
let q20 = (p2 & 255u32).cast::<f32>();
let q21 = ((p2 >> 8u32) & 255u32).cast::<f32>();
let q22 = ((p2 >> 16u32) & 255u32).cast::<f32>();
let q23 = ((p2 >> 24u32) & 255u32).cast::<f32>();
let qd2_a = q20 * x0_a + q21 * x1_a + q22 * x2_a + q23 * x3_a;
let qd2_b = q20 * x0_b + q21 * x1_b + q22 * x2_b + q23 * x3_b;
acc2_a = acc2_a + s2 * qd2_a + bi2 * xs_a;
acc2_b = acc2_b + s2 * qd2_b + bi2 * xs_b;
// Weight row 3.
let p3 = load(w[w_base3 + pack_off]);
let s3 = load(scales[sb_base3 + g]).cast::<f32>();
let bi3 = load(biases[sb_base3 + g]).cast::<f32>();
let q30 = (p3 & 255u32).cast::<f32>();
let q31 = ((p3 >> 8u32) & 255u32).cast::<f32>();
let q32 = ((p3 >> 16u32) & 255u32).cast::<f32>();
let q33 = ((p3 >> 24u32) & 255u32).cast::<f32>();
let qd3_a = q30 * x0_a + q31 * x1_a + q32 * x2_a + q33 * x3_a;
let qd3_b = q30 * x0_b + q31 * x1_b + q32 * x2_b + q33 * x3_b;
acc3_a = acc3_a + s3 * qd3_a + bi3 * xs_a;
acc3_b = acc3_b + s3 * qd3_b + bi3 * xs_b;
}
// simd_sum folds the per-lane partials (presumably across the whole
// simdgroup — confirm in the DSL); lane 0 alone stores the 4 x 2 results,
// cast back to T.
let r0_a = simd_sum(acc0_a);
let r0_b = simd_sum(acc0_b);
let r1_a = simd_sum(acc1_a);
let r1_b = simd_sum(acc1_b);
let r2_a = simd_sum(acc2_a);
let r2_b = simd_sum(acc2_b);
let r3_a = simd_sum(acc3_a);
let r3_b = simd_sum(acc3_b);
if lane == 0u32 {
store(out[m_row_a * n + row0], r0_a.cast::<T>());
store(out[m_row_a * n + row1], r1_a.cast::<T>());
store(out[m_row_a * n + row2], r2_a.cast::<T>());
store(out[m_row_a * n + row3], r3_a.cast::<T>());
store(out[m_row_b * n + row0], r0_b.cast::<T>());
store(out[m_row_b * n + row1], r1_b.cast::<T>());
store(out[m_row_b * n + row2], r2_b.cast::<T>());
store(out[m_row_b * n + row3], r3_b.cast::<T>());
}
}
// Int8 affine-quantized matrix-matrix product that processes four `x` rows
// per threadgroup y index.
//
// Rows `tgid_y * 4 + {0, 1, 2, 3}` (suffixes `_a` .. `_d`) share every
// weight pack, scale and bias load: each pack is unpacked once and dotted
// against all four rows' activations. Per output the accumulation is
// `scale * sum(q * x) + bias * sum(x)` over 4-element slices, with four
// 8-bit weights per u32 of `w` (lowest byte first).
//
// Arguments as used below:
//   w          - packed weights, indexed `row * (k / 4) + pack`
//   scales     - per-group scales, indexed `row * gs_per_row + group`
//   biases     - per-group biases, same indexing as `scales`
//   x          - activations, indexed `m_row * k + position`
//   out        - output, indexed `m_row * n + row`
//   k          - reduction length and row stride of `x`
//   n          - row stride of `out`
//   gs_per_row - stride between rows in `scales` / `biases`
//
// NOTE(review): all four x rows are read and written unconditionally, so
// the row count of `x` is presumably a multiple of 4, `k` a multiple of
// 128 and `n` a multiple of 8 — confirm against the dispatcher.
#[bench_kernel(
op="quantized",
subop="qmm_bm4_int8_fast",
class=QuantizedMatMul,
shapes=&QUANTIZED_SHAPES,
m=8,
group_size=64,
tpg=64,
bits=8,
tol=1e-2,
mlx="affine_qmm_fast_float16_t_gs_64_b_8_alN_true_batch_0",
metal_file="quantized.metal",
dtypes=&[metaltile_core::dtype::DType::F32, metaltile_core::dtype::DType::F16, metaltile_core::dtype::DType::BF16],
)]
#[kernel]
pub fn mt_qmm_bm4_int8_fast<T>(
w: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] n: u32,
#[constexpr] gs_per_row: u32,
) {
let tg = tgid_x;
// Index of the 4-row tile of `x` / `out` handled by this threadgroup.
let m_tile = tgid_y;
let sg = simd_id;
let lane = simd_lane;
// Four consecutive weight rows (output columns) handled by this simdgroup.
let row0 = tg * 8u32 + sg * 4u32;
let row1 = row0 + 1u32;
let row2 = row0 + 2u32;
let row3 = row0 + 3u32;
// Four 8-bit weights per u32, hence k / 4 packs per weight row.
let packs_per_row = k / 4u32;
let w_base0 = row0 * packs_per_row;
let w_base1 = row1 * packs_per_row;
let w_base2 = row2 * packs_per_row;
let w_base3 = row3 * packs_per_row;
// Start of each weight row's scale/bias entries.
let sb_base0 = row0 * gs_per_row;
let sb_base1 = row1 * gs_per_row;
let sb_base2 = row2 * gs_per_row;
let sb_base3 = row3 * gs_per_row;
// The four x rows of this tile and the offsets of their first elements.
let m_row_a = m_tile * 4u32;
let m_row_b = m_row_a + 1u32;
let m_row_c = m_row_a + 2u32;
let m_row_d = m_row_a + 3u32;
let x_base_a = m_row_a * k;
let x_base_b = m_row_b * k;
let x_base_c = m_row_c * k;
let x_base_d = m_row_d * k;
// Per-lane partial sums: accR_S is weight row R against x row S.
let mut acc0_a = 0.0f32;
let mut acc0_b = 0.0f32;
let mut acc0_c = 0.0f32;
let mut acc0_d = 0.0f32;
let mut acc1_a = 0.0f32;
let mut acc1_b = 0.0f32;
let mut acc1_c = 0.0f32;
let mut acc1_d = 0.0f32;
let mut acc2_a = 0.0f32;
let mut acc2_b = 0.0f32;
let mut acc2_c = 0.0f32;
let mut acc2_d = 0.0f32;
let mut acc3_a = 0.0f32;
let mut acc3_b = 0.0f32;
let mut acc3_c = 0.0f32;
let mut acc3_d = 0.0f32;
// Each lane owns 4 x elements (= exactly one weight pack) per step.
let lane_x_off = lane * 4u32;
let lane_pack_off = lane;
for _b in range(0u32, k, 128u32) {
// Activations of x row a and their sum (the sum feeds the bias term).
let xb_a = x_base_a + _b + lane_x_off;
let x0_a = load(x[xb_a]).cast::<f32>();
let x1_a = load(x[xb_a + 1u32]).cast::<f32>();
let x2_a = load(x[xb_a + 2u32]).cast::<f32>();
let x3_a = load(x[xb_a + 3u32]).cast::<f32>();
let xs_a = x0_a + x1_a + x2_a + x3_a;
// Same slice of x row b.
let xb_b = x_base_b + _b + lane_x_off;
let x0_b = load(x[xb_b]).cast::<f32>();
let x1_b = load(x[xb_b + 1u32]).cast::<f32>();
let x2_b = load(x[xb_b + 2u32]).cast::<f32>();
let x3_b = load(x[xb_b + 3u32]).cast::<f32>();
let xs_b = x0_b + x1_b + x2_b + x3_b;
// Same slice of x row c.
let xb_c = x_base_c + _b + lane_x_off;
let x0_c = load(x[xb_c]).cast::<f32>();
let x1_c = load(x[xb_c + 1u32]).cast::<f32>();
let x2_c = load(x[xb_c + 2u32]).cast::<f32>();
let x3_c = load(x[xb_c + 3u32]).cast::<f32>();
let xs_c = x0_c + x1_c + x2_c + x3_c;
// Same slice of x row d.
let xb_d = x_base_d + _b + lane_x_off;
let x0_d = load(x[xb_d]).cast::<f32>();
let x1_d = load(x[xb_d + 1u32]).cast::<f32>();
let x2_d = load(x[xb_d + 2u32]).cast::<f32>();
let x3_d = load(x[xb_d + 3u32]).cast::<f32>();
let xs_d = x0_d + x1_d + x2_d + x3_d;
// The group index uses the position along k only (no x row offset),
// because scales/biases belong to the weight rows.
let g = (_b + lane_x_off) / 64u32;
let pack_off = _b / 4u32 + lane_pack_off;
// Weight row 0: one unpack, reused for all four x rows.
let p0 = load(w[w_base0 + pack_off]);
let s0 = load(scales[sb_base0 + g]).cast::<f32>();
let bi0 = load(biases[sb_base0 + g]).cast::<f32>();
let q00 = (p0 & 255u32).cast::<f32>();
let q01 = ((p0 >> 8u32) & 255u32).cast::<f32>();
let q02 = ((p0 >> 16u32) & 255u32).cast::<f32>();
let q03 = ((p0 >> 24u32) & 255u32).cast::<f32>();
let qd0_a = q00 * x0_a + q01 * x1_a + q02 * x2_a + q03 * x3_a;
let qd0_b = q00 * x0_b + q01 * x1_b + q02 * x2_b + q03 * x3_b;
let qd0_c = q00 * x0_c + q01 * x1_c + q02 * x2_c + q03 * x3_c;
let qd0_d = q00 * x0_d + q01 * x1_d + q02 * x2_d + q03 * x3_d;
acc0_a = acc0_a + s0 * qd0_a + bi0 * xs_a;
acc0_b = acc0_b + s0 * qd0_b + bi0 * xs_b;
acc0_c = acc0_c + s0 * qd0_c + bi0 * xs_c;
acc0_d = acc0_d + s0 * qd0_d + bi0 * xs_d;
// Weight row 1.
let p1 = load(w[w_base1 + pack_off]);
let s1 = load(scales[sb_base1 + g]).cast::<f32>();
let bi1 = load(biases[sb_base1 + g]).cast::<f32>();
let q10 = (p1 & 255u32).cast::<f32>();
let q11 = ((p1 >> 8u32) & 255u32).cast::<f32>();
let q12 = ((p1 >> 16u32) & 255u32).cast::<f32>();
let q13 = ((p1 >> 24u32) & 255u32).cast::<f32>();
let qd1_a = q10 * x0_a + q11 * x1_a + q12 * x2_a + q13 * x3_a;
let qd1_b = q10 * x0_b + q11 * x1_b + q12 * x2_b + q13 * x3_b;
let qd1_c = q10 * x0_c + q11 * x1_c + q12 * x2_c + q13 * x3_c;
let qd1_d = q10 * x0_d + q11 * x1_d + q12 * x2_d + q13 * x3_d;
acc1_a = acc1_a + s1 * qd1_a + bi1 * xs_a;
acc1_b = acc1_b + s1 * qd1_b + bi1 * xs_b;
acc1_c = acc1_c + s1 * qd1_c + bi1 * xs_c;
acc1_d = acc1_d + s1 * qd1_d + bi1 * xs_d;
// Weight row 2.
let p2 = load(w[w_base2 + pack_off]);
let s2 = load(scales[sb_base2 + g]).cast::<f32>();
let bi2 = load(biases[sb_base2 + g]).cast::<f32>();
let q20 = (p2 & 255u32).cast::<f32>();
let q21 = ((p2 >> 8u32) & 255u32).cast::<f32>();
let q22 = ((p2 >> 16u32) & 255u32).cast::<f32>();
let q23 = ((p2 >> 24u32) & 255u32).cast::<f32>();
let qd2_a = q20 * x0_a + q21 * x1_a + q22 * x2_a + q23 * x3_a;
let qd2_b = q20 * x0_b + q21 * x1_b + q22 * x2_b + q23 * x3_b;
let qd2_c = q20 * x0_c + q21 * x1_c + q22 * x2_c + q23 * x3_c;
let qd2_d = q20 * x0_d + q21 * x1_d + q22 * x2_d + q23 * x3_d;
acc2_a = acc2_a + s2 * qd2_a + bi2 * xs_a;
acc2_b = acc2_b + s2 * qd2_b + bi2 * xs_b;
acc2_c = acc2_c + s2 * qd2_c + bi2 * xs_c;
acc2_d = acc2_d + s2 * qd2_d + bi2 * xs_d;
// Weight row 3.
let p3 = load(w[w_base3 + pack_off]);
let s3 = load(scales[sb_base3 + g]).cast::<f32>();
let bi3 = load(biases[sb_base3 + g]).cast::<f32>();
let q30 = (p3 & 255u32).cast::<f32>();
let q31 = ((p3 >> 8u32) & 255u32).cast::<f32>();
let q32 = ((p3 >> 16u32) & 255u32).cast::<f32>();
let q33 = ((p3 >> 24u32) & 255u32).cast::<f32>();
let qd3_a = q30 * x0_a + q31 * x1_a + q32 * x2_a + q33 * x3_a;
let qd3_b = q30 * x0_b + q31 * x1_b + q32 * x2_b + q33 * x3_b;
let qd3_c = q30 * x0_c + q31 * x1_c + q32 * x2_c + q33 * x3_c;
let qd3_d = q30 * x0_d + q31 * x1_d + q32 * x2_d + q33 * x3_d;
acc3_a = acc3_a + s3 * qd3_a + bi3 * xs_a;
acc3_b = acc3_b + s3 * qd3_b + bi3 * xs_b;
acc3_c = acc3_c + s3 * qd3_c + bi3 * xs_c;
acc3_d = acc3_d + s3 * qd3_d + bi3 * xs_d;
}
// simd_sum folds the per-lane partials (presumably across the whole
// simdgroup — confirm in the DSL); lane 0 alone stores the 4 x 4 results,
// cast back to T.
let r0_a = simd_sum(acc0_a);
let r0_b = simd_sum(acc0_b);
let r0_c = simd_sum(acc0_c);
let r0_d = simd_sum(acc0_d);
let r1_a = simd_sum(acc1_a);
let r1_b = simd_sum(acc1_b);
let r1_c = simd_sum(acc1_c);
let r1_d = simd_sum(acc1_d);
let r2_a = simd_sum(acc2_a);
let r2_b = simd_sum(acc2_b);
let r2_c = simd_sum(acc2_c);
let r2_d = simd_sum(acc2_d);
let r3_a = simd_sum(acc3_a);
let r3_b = simd_sum(acc3_b);
let r3_c = simd_sum(acc3_c);
let r3_d = simd_sum(acc3_d);
if lane == 0u32 {
store(out[m_row_a * n + row0], r0_a.cast::<T>());
store(out[m_row_a * n + row1], r1_a.cast::<T>());
store(out[m_row_a * n + row2], r2_a.cast::<T>());
store(out[m_row_a * n + row3], r3_a.cast::<T>());
store(out[m_row_b * n + row0], r0_b.cast::<T>());
store(out[m_row_b * n + row1], r1_b.cast::<T>());
store(out[m_row_b * n + row2], r2_b.cast::<T>());
store(out[m_row_b * n + row3], r3_b.cast::<T>());
store(out[m_row_c * n + row0], r0_c.cast::<T>());
store(out[m_row_c * n + row1], r1_c.cast::<T>());
store(out[m_row_c * n + row2], r2_c.cast::<T>());
store(out[m_row_c * n + row3], r3_c.cast::<T>());
store(out[m_row_d * n + row0], r0_d.cast::<T>());
store(out[m_row_d * n + row1], r1_d.cast::<T>());
store(out[m_row_d * n + row2], r2_d.cast::<T>());
store(out[m_row_d * n + row3], r3_d.cast::<T>());
}
}
#[bench_kernel(
op="quantized",
subop="qmm_mma",
class=QuantizedMatMul,
shapes=&QUANTIZED_SHAPES,
// M=32 = the cell where the simdgroup-matrix MMA pays for itself.
// M < 32 leaves >= 50% of the 32×32 tile padded (wasted ALU); bm4
// keeps winning there. Selector routes M >= 32 && M % 32 == 0 to mma.
m=32,
group_size=64,
tpg=128,
// bf16 round-trip on int4-quantized matmul: max_q=15 × group_size=64
// × bf16's 7-bit mantissa drifts ~7-8e-3 at large K (per
// crates/metaltile-std/src/mlx/binary.rs precedent — "bf16 drifts
// ~7.8e-3 on signed"). Tighter than 1e-2 trips the bench cosine
// check at production shapes (M=4096+, K=4096+) on Apple Paravirtual
// CI. tol=1e-2 keeps f32/f16 cells tight while passing bf16.
tol=1e-2,
mlx="affine_qmm_t_{tn}_gs_64_b_4_alN_true_batch_0",
metal_file="quantized.metal",
dtypes=&[metaltile_core::dtype::DType::F32, metaltile_core::dtype::DType::F16, metaltile_core::dtype::DType::BF16],
)]
#[kernel]
pub fn mt_qmm_mma<T>(
w: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] n: u32,
#[constexpr] gs_per_row: u32,
) {
let n_tile = tgid_x;
let m_tile = tgid_y;
let lane = simd_lane;
let sg = simd_group_id();
let sm = sg / 2u32;
let sn = sg & 1u32;
let lane_in_tg = sg * 32u32 + lane;
let qid = lane / 4u32;
let fm = (qid & 4u32) + ((lane / 2u32) % 4u32);
let fn0 = (qid & 2u32) * 2u32 + (lane % 2u32) * 2u32;
let fn1 = fn0 + 1u32;
threadgroup_alloc("xs", 1152, T);
threadgroup_alloc("ws", 1152, T);
let c_f00 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f00, 0, 0.0f32);
simdgroup_elem_store(c_f00, 1, 0.0f32);
let c_f01 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f01, 0, 0.0f32);
simdgroup_elem_store(c_f01, 1, 0.0f32);
let c_f10 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f10, 0, 0.0f32);
simdgroup_elem_store(c_f10, 1, 0.0f32);
let c_f11 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f11, 0, 0.0f32);
simdgroup_elem_store(c_f11, 1, 0.0f32);
let a_f0 = simdgroup_alloc::<T, 8, 8>();
let a_f1 = simdgroup_alloc::<T, 8, 8>();
let b_f0 = simdgroup_alloc::<T, 8, 8>();
let b_f1 = simdgroup_alloc::<T, 8, 8>();
let w_row = lane_in_tg / 4u32;
let pack_in_row = lane_in_tg & 3u32;
let x_m_base = m_tile * 32u32; let w_n_base = n_tile * 32u32; let packs_per_row = k / 8u32;
let sb_base = (w_n_base + w_row) * gs_per_row;
let w_pack_row_base = (w_n_base + w_row) * packs_per_row;
let xs_ld_const = 36u32;
let ws_ld_const = 36u32;
let xs_ld = xs_ld_const;
let ws_ld = ws_ld_const;
let x_m_row = lane_in_tg / 4u32;
let x_k_quad = lane_in_tg & 3u32;
let x_k_base = x_k_quad * 8u32;
for kb in range(0u32, k, 32u32) {
let x_row_dev_base = (x_m_base + x_m_row) * k + kb + x_k_base;
let x_ws_base = x_m_row * xs_ld + x_k_base;
let xv0 = load(x[x_row_dev_base]).cast::<T>();
let xv1 = load(x[x_row_dev_base + 1u32]).cast::<T>();
let xv2 = load(x[x_row_dev_base + 2u32]).cast::<T>();
let xv3 = load(x[x_row_dev_base + 3u32]).cast::<T>();
let xv4 = load(x[x_row_dev_base + 4u32]).cast::<T>();
let xv5 = load(x[x_row_dev_base + 5u32]).cast::<T>();
let xv6 = load(x[x_row_dev_base + 6u32]).cast::<T>();
let xv7 = load(x[x_row_dev_base + 7u32]).cast::<T>();
threadgroup_store("xs", x_ws_base, xv0);
threadgroup_store("xs", x_ws_base + 1u32, xv1);
threadgroup_store("xs", x_ws_base + 2u32, xv2);
threadgroup_store("xs", x_ws_base + 3u32, xv3);
threadgroup_store("xs", x_ws_base + 4u32, xv4);
threadgroup_store("xs", x_ws_base + 5u32, xv5);
threadgroup_store("xs", x_ws_base + 6u32, xv6);
threadgroup_store("xs", x_ws_base + 7u32, xv7);
let pack_k_off = kb / 8u32 + pack_in_row;
let pack = load(w[w_pack_row_base + pack_k_off]);
let k_off = kb + pack_in_row * 8u32;
let g = k_off / 64u32; let s = load(scales[sb_base + g]).cast::<f32>();
let b = load(biases[sb_base + g]).cast::<f32>();
let s_16 = 0.0625f32;
let s_256 = 0.00390625f32;
let s_4096 = 0.000244140625f32;
let pack_hi = pack >> 16u32;
let q0 = (pack & 15u32).cast::<f32>();
let q1 = (pack & 240u32).cast::<f32>() * s_16;
let q2 = (pack & 3840u32).cast::<f32>() * s_256;
let q3 = (pack & 61440u32).cast::<f32>() * s_4096;
let q4 = (pack_hi & 15u32).cast::<f32>();
let q5 = (pack_hi & 240u32).cast::<f32>() * s_16;
let q6 = (pack_hi & 3840u32).cast::<f32>() * s_256;
let q7 = (pack_hi & 61440u32).cast::<f32>() * s_4096;
let ws_base = w_row * ws_ld + pack_in_row * 8u32;
threadgroup_store("ws", ws_base, (s * q0 + b).cast::<T>());
threadgroup_store("ws", ws_base + 1u32, (s * q1 + b).cast::<T>());
threadgroup_store("ws", ws_base + 2u32, (s * q2 + b).cast::<T>());
threadgroup_store("ws", ws_base + 3u32, (s * q3 + b).cast::<T>());
threadgroup_store("ws", ws_base + 4u32, (s * q4 + b).cast::<T>());
threadgroup_store("ws", ws_base + 5u32, (s * q5 + b).cast::<T>());
threadgroup_store("ws", ws_base + 6u32, (s * q6 + b).cast::<T>());
threadgroup_store("ws", ws_base + 7u32, (s * q7 + b).cast::<T>());
threadgroup_barrier();
let row_a0 = sm * 16u32 + fm; let row_a1 = sm * 16u32 + 8u32 + fm; let col_b0 = sn * 16u32; let col_b1 = sn * 16u32 + 8u32; simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 8u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 8u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 8u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 8u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 8u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 16u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 16u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 16u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 16u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 16u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 24u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 24u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 24u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 24u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 24u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
threadgroup_barrier();
}
let out_m_base = m_tile * 32u32 + sm * 16u32;
let out_n_base = n_tile * 32u32 + sn * 16u32;
store(out[(out_m_base + fm) * n + out_n_base + fn0], simdgroup_elem_load(c_f00, 0).cast::<T>());
store(out[(out_m_base + fm) * n + out_n_base + fn1], simdgroup_elem_load(c_f00, 1).cast::<T>());
store(
out[(out_m_base + fm) * n + out_n_base + 8u32 + fn0],
simdgroup_elem_load(c_f01, 0).cast::<T>(),
);
store(
out[(out_m_base + fm) * n + out_n_base + 8u32 + fn1],
simdgroup_elem_load(c_f01, 1).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + fn0],
simdgroup_elem_load(c_f10, 0).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + fn1],
simdgroup_elem_load(c_f10, 1).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + 8u32 + fn0],
simdgroup_elem_load(c_f11, 0).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + 8u32 + fn1],
simdgroup_elem_load(c_f11, 1).cast::<T>(),
);
}
// mt_qmm_mma_m16 — 4-bit quantized matmul on a 16 (M) x 32 (N) output tile
// per threadgroup, built from 8x8 simdgroup fragments.
//
// For the tile selected by (tgid_y, tgid_x) it writes
//   out[m * n + c] = sum over kk of x[m * k + kk] * (scale * q + bias)
// where q is the 4-bit code of weight row `c` at reduction index `kk`.
//
// Parameters:
//   w          packed weights, eight 4-bit codes per u32 (lowest nibble
//              first), indexed [row * (k / 8) + pack]
//   scales     per-group scale, indexed [row * gs_per_row + group]
//   biases     per-group bias, same indexing as `scales`
//   x          activations, indexed [m * k + kk]
//   out        result, indexed [m * n + c]
//   k          reduction length (constexpr)
//   n          row stride of `out` (constexpr)
//   gs_per_row number of scale/bias entries per weight row (constexpr)
//
// NOTE(review): nothing here is bounds-checked. The code assumes k is a
// multiple of 32 and that only full 16x32 tiles are dispatched — confirm
// against the dispatcher.
#[bench_kernel(
op="quantized",
subop="qmm_mma_m16",
class=QuantizedMatMul,
shapes=&QUANTIZED_SHAPES,
// M=16 = the bm4 weak cell. bm4 wins moderate-N M=16 cells but
// loses wide-N on M2 (76-94% MT MLX). Half-height MMA targets this
// exact gap: zero padding waste at M=16 (vs MMA's 32×32 tile which
// would be 50% empty here), MMA-class ALU, N-amortized W reuse.
m=16,
group_size=64,
tpg=64,
// bf16 round-trip on int4-quantized matmul: max_q=15 × group_size=64
// × bf16's 7-bit mantissa drifts ~7-8e-3 at large K (per
// crates/metaltile-std/src/mlx/binary.rs precedent — "bf16 drifts
// ~7.8e-3 on signed"). Tighter than 1e-2 trips the bench cosine
// check at production shapes (M=4096+, K=4096+) on Apple Paravirtual
// CI. tol=1e-2 keeps f32/f16 cells tight while passing bf16.
tol=1e-2,
mlx="affine_qmm_t_{tn}_gs_64_b_4_alN_true_batch_0",
metal_file="quantized.metal",
dtypes=&[metaltile_core::dtype::DType::F32, metaltile_core::dtype::DType::F16, metaltile_core::dtype::DType::BF16],
)]
#[kernel]
pub fn mt_qmm_mma_m16<T>(
w: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] n: u32,
#[constexpr] gs_per_row: u32,
) {
// Threadgroup id selects the tile: x picks the 32-wide N tile, y the 16-high M tile.
let n_tile = tgid_x;
let m_tile = tgid_y;
let lane = simd_lane;
let sg = simd_group_id();
// The tile is a single 16-row band, so the simdgroup's M index is fixed at 0;
// the low bit of the simdgroup id picks which 16-column half it computes.
let sm = 0u32;
let sn = sg & 1u32;
// Flat thread index inside the threadgroup (0..63 assuming tpg=64 and
// 32-lane simdgroups).
let lane_in_tg = sg * 32u32 + lane;
// Per-lane coordinates inside an 8x8 fragment: row `fm` and two adjacent
// columns `fn0` / `fn1`.
// NOTE(review): presumably this mirrors the hardware simdgroup-matrix element
// ownership (same formula as MLX's MMA fragment helpers) — confirm.
let qid = lane / 4u32;
let fm = (qid & 4u32) + ((lane / 2u32) % 4u32);
let fn0 = (qid & 2u32) * 2u32 + (lane % 2u32) * 2u32;
let fn1 = fn0 + 1u32;
// Staging buffers with a row stride of 36 elements: "xs" holds 16 rows of x
// (16 * 36 = 576), "ws" holds 32 rows of dequantized weights (32 * 36 = 1152).
// Only columns 0..31 of each row are written below.
// NOTE(review): the 4 spare elements per row look like bank-conflict padding — confirm.
threadgroup_alloc("xs", 576, T);
threadgroup_alloc("ws", 1152, T);
// Four 8x8 f32 accumulators forming this simdgroup's 16x16 output sub-tile
// (c_fRC: R = 8-row block, C = 8-column block). Each lane zeroes its two elements.
let c_f00 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f00, 0, 0.0f32);
simdgroup_elem_store(c_f00, 1, 0.0f32);
let c_f01 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f01, 0, 0.0f32);
simdgroup_elem_store(c_f01, 1, 0.0f32);
let c_f10 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f10, 0, 0.0f32);
simdgroup_elem_store(c_f10, 1, 0.0f32);
let c_f11 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f11, 0, 0.0f32);
simdgroup_elem_store(c_f11, 1, 0.0f32);
// Operand fragments in the storage type T: a_f0/a_f1 for the two 8-row blocks
// of x, b_f0/b_f1 for the two 8-column blocks of the weights.
let a_f0 = simdgroup_alloc::<T, 8, 8>();
let a_f1 = simdgroup_alloc::<T, 8, 8>();
let b_f0 = simdgroup_alloc::<T, 8, 8>();
let b_f1 = simdgroup_alloc::<T, 8, 8>();
// First x row / first weight row of this tile, and the number of u32 packs
// per weight row (8 four-bit codes per pack).
let x_m_base = m_tile * 16u32; let w_n_base = n_tile * 32u32; let packs_per_row = k / 8u32;
let xs_ld = 36u32;
let ws_ld = 36u32;
// Walk the reduction dimension in slabs of 32.
for kb in range(0u32, k, 32u32) {
// --- Stage the 16x32 slab of x into "xs" ---
// Eight flat element indices per thread, 64 apart, so 64 threads cover
// 512 = 16 * 32 elements. flat / 32 is the row, flat & 31 the column.
let flat0 = lane_in_tg;
let flat1 = 64u32 + lane_in_tg;
let flat2 = 128u32 + lane_in_tg;
let flat3 = 192u32 + lane_in_tg;
let flat4 = 256u32 + lane_in_tg;
let flat5 = 320u32 + lane_in_tg;
let flat6 = 384u32 + lane_in_tg;
let flat7 = 448u32 + lane_in_tg;
let mr0 = flat0 / 32u32;
let mr1 = flat1 / 32u32;
let mr2 = flat2 / 32u32;
let mr3 = flat3 / 32u32;
let mr4 = flat4 / 32u32;
let mr5 = flat5 / 32u32;
let mr6 = flat6 / 32u32;
let mr7 = flat7 / 32u32;
let kc0 = flat0 & 31u32;
let kc1 = flat1 & 31u32;
let kc2 = flat2 & 31u32;
let kc3 = flat3 & 31u32;
let kc4 = flat4 & 31u32;
let kc5 = flat5 & 31u32;
let kc6 = flat6 & 31u32;
let kc7 = flat7 & 31u32;
threadgroup_store(
"xs",
mr0 * xs_ld + kc0,
load(x[(x_m_base + mr0) * k + kb + kc0]).cast::<T>(),
);
threadgroup_store(
"xs",
mr1 * xs_ld + kc1,
load(x[(x_m_base + mr1) * k + kb + kc1]).cast::<T>(),
);
threadgroup_store(
"xs",
mr2 * xs_ld + kc2,
load(x[(x_m_base + mr2) * k + kb + kc2]).cast::<T>(),
);
threadgroup_store(
"xs",
mr3 * xs_ld + kc3,
load(x[(x_m_base + mr3) * k + kb + kc3]).cast::<T>(),
);
threadgroup_store(
"xs",
mr4 * xs_ld + kc4,
load(x[(x_m_base + mr4) * k + kb + kc4]).cast::<T>(),
);
threadgroup_store(
"xs",
mr5 * xs_ld + kc5,
load(x[(x_m_base + mr5) * k + kb + kc5]).cast::<T>(),
);
threadgroup_store(
"xs",
mr6 * xs_ld + kc6,
load(x[(x_m_base + mr6) * k + kb + kc6]).cast::<T>(),
);
threadgroup_store(
"xs",
mr7 * xs_ld + kc7,
load(x[(x_m_base + mr7) * k + kb + kc7]).cast::<T>(),
);
// --- Dequantize the 32x32 slab of weights into "ws" ---
// Nibbles 1..3 of each 16-bit half are masked in place (0xF0, 0xF00, 0xF000)
// and then scaled by 1/16, 1/256, 1/4096 instead of being shifted; the
// factors are powers of two, so the result is the exact code value 0..15.
let s_16 = 0.0625f32;
let s_256 = 0.00390625f32;
let s_4096 = 0.000244140625f32;
// Pack 0 of 2 for this thread: pack indices 0..63 cover weight rows 0..15
// (4 packs per row in a 32-wide slab).
let pack_idx_0 = lane_in_tg;
let w_row_0 = pack_idx_0 / 4u32;
let pack_in_row_0 = pack_idx_0 & 3u32;
let pack_0 = load(w[(w_n_base + w_row_0) * packs_per_row + kb / 8u32 + pack_in_row_0]);
// Quantization group of this pack: absolute k offset divided by the group size 64.
let k_off_0 = kb + pack_in_row_0 * 8u32;
let g_0 = k_off_0 / 64u32;
let sb_base_0 = (w_n_base + w_row_0) * gs_per_row;
let s_0 = load(scales[sb_base_0 + g_0]).cast::<f32>();
let b_0 = load(biases[sb_base_0 + g_0]).cast::<f32>();
// Codes 0..3 come from the low 16 bits, codes 4..7 from the high 16 bits.
let pack_hi_0 = pack_0 >> 16u32;
let q0_0 = (pack_0 & 15u32).cast::<f32>();
let q1_0 = (pack_0 & 240u32).cast::<f32>() * s_16;
let q2_0 = (pack_0 & 3840u32).cast::<f32>() * s_256;
let q3_0 = (pack_0 & 61440u32).cast::<f32>() * s_4096;
let q4_0 = (pack_hi_0 & 15u32).cast::<f32>();
let q5_0 = (pack_hi_0 & 240u32).cast::<f32>() * s_16;
let q6_0 = (pack_hi_0 & 3840u32).cast::<f32>() * s_256;
let q7_0 = (pack_hi_0 & 61440u32).cast::<f32>() * s_4096;
// "ws" is laid out [weight row][k], i.e. one row per output column.
let ws_base_0 = w_row_0 * ws_ld + pack_in_row_0 * 8u32;
threadgroup_store("ws", ws_base_0, (s_0 * q0_0 + b_0).cast::<T>());
threadgroup_store("ws", ws_base_0 + 1u32, (s_0 * q1_0 + b_0).cast::<T>());
threadgroup_store("ws", ws_base_0 + 2u32, (s_0 * q2_0 + b_0).cast::<T>());
threadgroup_store("ws", ws_base_0 + 3u32, (s_0 * q3_0 + b_0).cast::<T>());
threadgroup_store("ws", ws_base_0 + 4u32, (s_0 * q4_0 + b_0).cast::<T>());
threadgroup_store("ws", ws_base_0 + 5u32, (s_0 * q5_0 + b_0).cast::<T>());
threadgroup_store("ws", ws_base_0 + 6u32, (s_0 * q6_0 + b_0).cast::<T>());
threadgroup_store("ws", ws_base_0 + 7u32, (s_0 * q7_0 + b_0).cast::<T>());
// Pack 1 of 2: pack indices 64..127 cover weight rows 16..31; same steps as above.
let pack_idx_1 = 64u32 + lane_in_tg;
let w_row_1 = pack_idx_1 / 4u32;
let pack_in_row_1 = pack_idx_1 & 3u32;
let pack_1 = load(w[(w_n_base + w_row_1) * packs_per_row + kb / 8u32 + pack_in_row_1]);
let k_off_1 = kb + pack_in_row_1 * 8u32;
let g_1 = k_off_1 / 64u32;
let sb_base_1 = (w_n_base + w_row_1) * gs_per_row;
let s_1 = load(scales[sb_base_1 + g_1]).cast::<f32>();
let b_1 = load(biases[sb_base_1 + g_1]).cast::<f32>();
let pack_hi_1 = pack_1 >> 16u32;
let q0_1 = (pack_1 & 15u32).cast::<f32>();
let q1_1 = (pack_1 & 240u32).cast::<f32>() * s_16;
let q2_1 = (pack_1 & 3840u32).cast::<f32>() * s_256;
let q3_1 = (pack_1 & 61440u32).cast::<f32>() * s_4096;
let q4_1 = (pack_hi_1 & 15u32).cast::<f32>();
let q5_1 = (pack_hi_1 & 240u32).cast::<f32>() * s_16;
let q6_1 = (pack_hi_1 & 3840u32).cast::<f32>() * s_256;
let q7_1 = (pack_hi_1 & 61440u32).cast::<f32>() * s_4096;
let ws_base_1 = w_row_1 * ws_ld + pack_in_row_1 * 8u32;
threadgroup_store("ws", ws_base_1, (s_1 * q0_1 + b_1).cast::<T>());
threadgroup_store("ws", ws_base_1 + 1u32, (s_1 * q1_1 + b_1).cast::<T>());
threadgroup_store("ws", ws_base_1 + 2u32, (s_1 * q2_1 + b_1).cast::<T>());
threadgroup_store("ws", ws_base_1 + 3u32, (s_1 * q3_1 + b_1).cast::<T>());
threadgroup_store("ws", ws_base_1 + 4u32, (s_1 * q4_1 + b_1).cast::<T>());
threadgroup_store("ws", ws_base_1 + 5u32, (s_1 * q5_1 + b_1).cast::<T>());
threadgroup_store("ws", ws_base_1 + 6u32, (s_1 * q6_1 + b_1).cast::<T>());
threadgroup_store("ws", ws_base_1 + 7u32, (s_1 * q7_1 + b_1).cast::<T>());
// Both staging buffers must be fully written before any lane reads fragments
// out of them.
threadgroup_barrier();
// --- Multiply: four k-steps of 8 (offsets 0, 8, 16, 24 within the slab) ---
// A fragments read "xs" as [m][k] (row fm, columns fn0/fn1). B fragments read
// "ws" with the roles swapped, (n = fn) * ws_ld + (k = fm), which transposes
// the [n][k] staging layout into a k x n operand.
// NOTE(review): simdgroup_matmul(a, b, c) presumably accumulates a*b into c;
// the accumulators are never updated any other way — confirm in the DSL.
// k-step 0
let row_a0 = sm * 16u32 + fm; let row_a1 = sm * 16u32 + 8u32 + fm; let col_b0 = sn * 16u32; let col_b1 = sn * 16u32 + 8u32; simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + fn1));
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + fm));
// Simdgroup-level barrier between filling the fragments and consuming them.
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_matmul(a_f1, b_f1, c_f11);
// k-step 1 (k offset 8)
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 8u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 8u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 8u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 8u32 + fn1));
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 8u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_matmul(a_f1, b_f1, c_f11);
// k-step 2 (k offset 16)
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 16u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 16u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 16u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 16u32 + fn1));
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 16u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_matmul(a_f1, b_f1, c_f11);
// k-step 3 (k offset 24)
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 24u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 24u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 24u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 24u32 + fn1));
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 24u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_matmul(a_f1, b_f1, c_f11);
// All fragment reads of this slab must finish before the next iteration
// overwrites "xs" / "ws".
threadgroup_barrier();
}
// --- Write back this simdgroup's 16x16 sub-tile ---
// Each lane stores its two elements of each accumulator, cast back to T.
// The +8 offsets select the second 8-row / 8-column block.
let out_m_base = m_tile * 16u32;
let out_n_base = n_tile * 32u32 + sn * 16u32;
store(out[(out_m_base + fm) * n + out_n_base + fn0], simdgroup_elem_load(c_f00, 0).cast::<T>());
store(out[(out_m_base + fm) * n + out_n_base + fn1], simdgroup_elem_load(c_f00, 1).cast::<T>());
store(
out[(out_m_base + fm) * n + out_n_base + 8u32 + fn0],
simdgroup_elem_load(c_f01, 0).cast::<T>(),
);
store(
out[(out_m_base + fm) * n + out_n_base + 8u32 + fn1],
simdgroup_elem_load(c_f01, 1).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + fn0],
simdgroup_elem_load(c_f10, 0).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + fn1],
simdgroup_elem_load(c_f10, 1).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + 8u32 + fn0],
simdgroup_elem_load(c_f11, 0).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + 8u32 + fn1],
simdgroup_elem_load(c_f11, 1).cast::<T>(),
);
}
// mt_qmm_mma_int8 — 8-bit quantized matmul on a 32 (M) x 32 (N) output tile
// per threadgroup, built from 8x8 simdgroup fragments.
//
// For the tile selected by (tgid_y, tgid_x) it writes
//   out[m * n + c] = sum over kk of x[m * k + kk] * (scale * q + bias)
// where q is the 8-bit code of weight row `c` at reduction index `kk`.
//
// Parameters:
//   w          packed weights, four 8-bit codes per u32 (lowest byte first),
//              indexed [row * (k / 4) + pack]
//   scales     per-group scale, indexed [row * gs_per_row + group]
//   biases     per-group bias, same indexing as `scales`
//   x          activations, indexed [m * k + kk]
//   out        result, indexed [m * n + c]
//   k          reduction length (constexpr)
//   n          row stride of `out` (constexpr)
//   gs_per_row number of scale/bias entries per weight row (constexpr)
//
// NOTE(review): nothing here is bounds-checked. The code assumes k is a
// multiple of 32 and that only full 32x32 tiles are dispatched — confirm
// against the dispatcher.
#[bench_kernel(
op="quantized",
subop="qmm_mma_int8",
class=QuantizedMatMul,
shapes=&QUANTIZED_SHAPES,
m=32,
group_size=64,
tpg=128,
bits=8,
// int8 max_q=255 amplifies bf16 round-trip drift further than int4's 15.
// At production shapes (M=4096+, K=4096+) bf16 cosine drifts ~8-9e-3.
// tol=1e-2 keeps f32/f16 cells tight while passing bf16.
tol=1e-2,
mlx="affine_qmm_t_{tn}_gs_64_b_8_alN_true_batch_0",
metal_file="quantized.metal",
dtypes=&[metaltile_core::dtype::DType::F32, metaltile_core::dtype::DType::F16, metaltile_core::dtype::DType::BF16],
)]
#[kernel]
pub fn mt_qmm_mma_int8<T>(
w: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] n: u32,
#[constexpr] gs_per_row: u32,
) {
// Threadgroup id selects the tile: x picks the N tile, y the M tile.
let n_tile = tgid_x;
let m_tile = tgid_y;
let lane = simd_lane;
let sg = simd_group_id();
// Simdgroups are arranged 2 x 2 over the 32x32 tile: `sm` picks the 16-row
// band, `sn` the 16-column half (assuming 4 simdgroups, i.e. tpg=128).
let sm = sg / 2u32;
let sn = sg & 1u32;
// Flat thread index inside the threadgroup (0..127 under the same assumption).
let lane_in_tg = sg * 32u32 + lane;
// Per-lane coordinates inside an 8x8 fragment: row `fm` and two adjacent
// columns `fn0` / `fn1`.
// NOTE(review): presumably this mirrors the hardware simdgroup-matrix element
// ownership (same formula as MLX's MMA fragment helpers) — confirm.
let qid = lane / 4u32;
let fm = (qid & 4u32) + ((lane / 2u32) % 4u32);
let fn0 = (qid & 2u32) * 2u32 + (lane % 2u32) * 2u32;
let fn1 = fn0 + 1u32;
// Staging buffers, 32 rows each with a row stride of 36 elements
// (32 * 36 = 1152). Only columns 0..31 of each row are written below.
// NOTE(review): the 4 spare elements per row look like bank-conflict padding — confirm.
threadgroup_alloc("xs", 1152, T);
threadgroup_alloc("ws", 1152, T);
// Four 8x8 f32 accumulators forming this simdgroup's 16x16 output sub-tile
// (c_fRC: R = 8-row block, C = 8-column block). Each lane zeroes its two elements.
let c_f00 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f00, 0, 0.0f32);
simdgroup_elem_store(c_f00, 1, 0.0f32);
let c_f01 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f01, 0, 0.0f32);
simdgroup_elem_store(c_f01, 1, 0.0f32);
let c_f10 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f10, 0, 0.0f32);
simdgroup_elem_store(c_f10, 1, 0.0f32);
let c_f11 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f11, 0, 0.0f32);
simdgroup_elem_store(c_f11, 1, 0.0f32);
// Operand fragments in the storage type T: a_f0/a_f1 for the two 8-row blocks
// of x, b_f0/b_f1 for the two 8-column blocks of the weights.
let a_f0 = simdgroup_alloc::<T, 8, 8>();
let a_f1 = simdgroup_alloc::<T, 8, 8>();
let b_f0 = simdgroup_alloc::<T, 8, 8>();
let b_f1 = simdgroup_alloc::<T, 8, 8>();
// First x row / first weight row of this tile, and the number of u32 packs
// per weight row (4 eight-bit codes per pack).
let x_m_base = m_tile * 32u32; let w_n_base = n_tile * 32u32; let packs_per_row = k / 4u32;
let xs_ld_const = 36u32;
let ws_ld_const = 36u32;
let xs_ld = xs_ld_const;
let ws_ld = ws_ld_const;
// Loop-invariant x staging coordinates: each thread owns one row (0..31) and
// one run of 8 consecutive k columns (quad 0..3 -> columns 0, 8, 16, 24).
let x_m_row = lane_in_tg / 4u32;
let x_k_quad = lane_in_tg & 3u32;
let x_k_base = x_k_quad * 8u32;
// Walk the reduction dimension in slabs of 32.
for kb in range(0u32, k, 32u32) {
// --- Stage the 32x32 slab of x into "xs" ---
// All eight loads are issued before the eight stores; 128 threads x 8
// elements cover the 1024-element slab.
let x_row_dev_base = (x_m_base + x_m_row) * k + kb + x_k_base;
let x_ws_base = x_m_row * xs_ld + x_k_base;
let xv0 = load(x[x_row_dev_base]).cast::<T>();
let xv1 = load(x[x_row_dev_base + 1u32]).cast::<T>();
let xv2 = load(x[x_row_dev_base + 2u32]).cast::<T>();
let xv3 = load(x[x_row_dev_base + 3u32]).cast::<T>();
let xv4 = load(x[x_row_dev_base + 4u32]).cast::<T>();
let xv5 = load(x[x_row_dev_base + 5u32]).cast::<T>();
let xv6 = load(x[x_row_dev_base + 6u32]).cast::<T>();
let xv7 = load(x[x_row_dev_base + 7u32]).cast::<T>();
threadgroup_store("xs", x_ws_base, xv0);
threadgroup_store("xs", x_ws_base + 1u32, xv1);
threadgroup_store("xs", x_ws_base + 2u32, xv2);
threadgroup_store("xs", x_ws_base + 3u32, xv3);
threadgroup_store("xs", x_ws_base + 4u32, xv4);
threadgroup_store("xs", x_ws_base + 5u32, xv5);
threadgroup_store("xs", x_ws_base + 6u32, xv6);
threadgroup_store("xs", x_ws_base + 7u32, xv7);
// --- Dequantize the 32x32 slab of weights into "ws" ---
// A 32-wide slab holds 8 packs per weight row, 256 packs in total; each
// thread handles two of them.
// Pack 0 of 2: pack indices 0..127 cover weight rows 0..15.
let pack_idx_0 = lane_in_tg;
let w_row_0 = pack_idx_0 / 8u32;
let pack_in_row_0 = pack_idx_0 & 7u32;
let pack_0 = load(w[(w_n_base + w_row_0) * packs_per_row + kb / 4u32 + pack_in_row_0]);
// Quantization group of this pack: absolute k offset divided by the group size 64.
let k_off_0 = kb + pack_in_row_0 * 4u32;
let g_0 = k_off_0 / 64u32; let sb_base_0 = (w_n_base + w_row_0) * gs_per_row;
let s_0 = load(scales[sb_base_0 + g_0]).cast::<f32>();
let b_0 = load(biases[sb_base_0 + g_0]).cast::<f32>();
// Unpack the four bytes, lowest byte first.
let q0_0 = (pack_0 & 255u32).cast::<f32>();
let q1_0 = ((pack_0 >> 8u32) & 255u32).cast::<f32>();
let q2_0 = ((pack_0 >> 16u32) & 255u32).cast::<f32>();
let q3_0 = ((pack_0 >> 24u32) & 255u32).cast::<f32>();
// "ws" is laid out [weight row][k], i.e. one row per output column.
let ws_base_0 = w_row_0 * ws_ld + pack_in_row_0 * 4u32;
threadgroup_store("ws", ws_base_0, (s_0 * q0_0 + b_0).cast::<T>());
threadgroup_store("ws", ws_base_0 + 1u32, (s_0 * q1_0 + b_0).cast::<T>());
threadgroup_store("ws", ws_base_0 + 2u32, (s_0 * q2_0 + b_0).cast::<T>());
threadgroup_store("ws", ws_base_0 + 3u32, (s_0 * q3_0 + b_0).cast::<T>());
// Pack 1 of 2: pack indices 128..255 cover weight rows 16..31; same steps as above.
let pack_idx_1 = 128u32 + lane_in_tg;
let w_row_1 = pack_idx_1 / 8u32;
let pack_in_row_1 = pack_idx_1 & 7u32;
let pack_1 = load(w[(w_n_base + w_row_1) * packs_per_row + kb / 4u32 + pack_in_row_1]);
let k_off_1 = kb + pack_in_row_1 * 4u32;
let g_1 = k_off_1 / 64u32;
let sb_base_1 = (w_n_base + w_row_1) * gs_per_row;
let s_1 = load(scales[sb_base_1 + g_1]).cast::<f32>();
let b_1 = load(biases[sb_base_1 + g_1]).cast::<f32>();
let q0_1 = (pack_1 & 255u32).cast::<f32>();
let q1_1 = ((pack_1 >> 8u32) & 255u32).cast::<f32>();
let q2_1 = ((pack_1 >> 16u32) & 255u32).cast::<f32>();
let q3_1 = ((pack_1 >> 24u32) & 255u32).cast::<f32>();
let ws_base_1 = w_row_1 * ws_ld + pack_in_row_1 * 4u32;
threadgroup_store("ws", ws_base_1, (s_1 * q0_1 + b_1).cast::<T>());
threadgroup_store("ws", ws_base_1 + 1u32, (s_1 * q1_1 + b_1).cast::<T>());
threadgroup_store("ws", ws_base_1 + 2u32, (s_1 * q2_1 + b_1).cast::<T>());
threadgroup_store("ws", ws_base_1 + 3u32, (s_1 * q3_1 + b_1).cast::<T>());
// Both staging buffers must be fully written before any lane reads fragments
// out of them.
threadgroup_barrier();
// --- Multiply: four k-steps of 8 (offsets 0, 8, 16, 24 within the slab) ---
// A fragments read "xs" as [m][k] (row fm, columns fn0/fn1). B fragments read
// "ws" with the roles swapped, (n = fn) * ws_ld + (k = fm), which transposes
// the [n][k] staging layout into a k x n operand.
// The four products per step are issued in the order 00, 01, 11, 10, so
// consecutive calls share one operand fragment.
// NOTE(review): simdgroup_matmul(a, b, c) presumably accumulates a*b into c;
// the accumulators are never updated any other way — confirm in the DSL.
// k-step 0
let row_a0 = sm * 16u32 + fm; let row_a1 = sm * 16u32 + 8u32 + fm; let col_b0 = sn * 16u32; let col_b1 = sn * 16u32 + 8u32; simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + fn1));
// Simdgroup-level barriers separate the A loads, the B loads and the
// matmul group from one another in every k-step.
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// k-step 1 (k offset 8)
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 8u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 8u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 8u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 8u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 8u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// k-step 2 (k offset 16)
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 16u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 16u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 16u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 16u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 16u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// k-step 3 (k offset 24)
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 24u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 24u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 24u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 24u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 24u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// All fragment reads of this slab must finish before the next iteration
// overwrites "xs" / "ws".
threadgroup_barrier();
}
// --- Write back this simdgroup's 16x16 sub-tile ---
// Each lane stores its two elements of each accumulator, cast back to T.
// The +8 offsets select the second 8-row / 8-column block.
let out_m_base = m_tile * 32u32 + sm * 16u32;
let out_n_base = n_tile * 32u32 + sn * 16u32;
store(out[(out_m_base + fm) * n + out_n_base + fn0], simdgroup_elem_load(c_f00, 0).cast::<T>());
store(out[(out_m_base + fm) * n + out_n_base + fn1], simdgroup_elem_load(c_f00, 1).cast::<T>());
store(
out[(out_m_base + fm) * n + out_n_base + 8u32 + fn0],
simdgroup_elem_load(c_f01, 0).cast::<T>(),
);
store(
out[(out_m_base + fm) * n + out_n_base + 8u32 + fn1],
simdgroup_elem_load(c_f01, 1).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + fn0],
simdgroup_elem_load(c_f10, 0).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + fn1],
simdgroup_elem_load(c_f10, 1).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + 8u32 + fn0],
simdgroup_elem_load(c_f11, 0).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + 8u32 + fn1],
simdgroup_elem_load(c_f11, 1).cast::<T>(),
);
}
#[bench_kernel(
op="quantized",
subop="qmm_mma_m16_int8",
class=QuantizedMatMul,
shapes=&QUANTIZED_SHAPES,
m=16,
group_size=64,
tpg=64,
bits=8,
// Same bf16 tolerance rationale as mt_qmm_mma_int8.
tol=1e-2,
mlx="affine_qmm_t_{tn}_gs_64_b_8_alN_true_batch_0",
metal_file="quantized.metal",
dtypes=&[metaltile_core::dtype::DType::F32, metaltile_core::dtype::DType::F16, metaltile_core::dtype::DType::BF16],
)]
#[kernel]
pub fn mt_qmm_mma_m16_int8<T>(
w: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] n: u32,
#[constexpr] gs_per_row: u32,
) {
let n_tile = tgid_x;
let m_tile = tgid_y;
let lane = simd_lane;
let sg = simd_group_id();
let sm = 0u32;
let sn = sg & 1u32;
let lane_in_tg = sg * 32u32 + lane;
let qid = lane / 4u32;
let fm = (qid & 4u32) + ((lane / 2u32) % 4u32);
let fn0 = (qid & 2u32) * 2u32 + (lane % 2u32) * 2u32;
let fn1 = fn0 + 1u32;
threadgroup_alloc("xs", 576, T);
threadgroup_alloc("ws", 1152, T);
let c_f00 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f00, 0, 0.0f32);
simdgroup_elem_store(c_f00, 1, 0.0f32);
let c_f01 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f01, 0, 0.0f32);
simdgroup_elem_store(c_f01, 1, 0.0f32);
let c_f10 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f10, 0, 0.0f32);
simdgroup_elem_store(c_f10, 1, 0.0f32);
let c_f11 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f11, 0, 0.0f32);
simdgroup_elem_store(c_f11, 1, 0.0f32);
let a_f0 = simdgroup_alloc::<T, 8, 8>();
let a_f1 = simdgroup_alloc::<T, 8, 8>();
let b_f0 = simdgroup_alloc::<T, 8, 8>();
let b_f1 = simdgroup_alloc::<T, 8, 8>();
let x_m_base = m_tile * 16u32; let w_n_base = n_tile * 32u32; let packs_per_row = k / 4u32;
let xs_ld = 36u32;
let ws_ld = 36u32;
for kb in range(0u32, k, 32u32) {
let flat0 = lane_in_tg;
let flat1 = 64u32 + lane_in_tg;
let flat2 = 128u32 + lane_in_tg;
let flat3 = 192u32 + lane_in_tg;
let flat4 = 256u32 + lane_in_tg;
let flat5 = 320u32 + lane_in_tg;
let flat6 = 384u32 + lane_in_tg;
let flat7 = 448u32 + lane_in_tg;
let mr0 = flat0 / 32u32;
let mr1 = flat1 / 32u32;
let mr2 = flat2 / 32u32;
let mr3 = flat3 / 32u32;
let mr4 = flat4 / 32u32;
let mr5 = flat5 / 32u32;
let mr6 = flat6 / 32u32;
let mr7 = flat7 / 32u32;
let kc0 = flat0 & 31u32;
let kc1 = flat1 & 31u32;
let kc2 = flat2 & 31u32;
let kc3 = flat3 & 31u32;
let kc4 = flat4 & 31u32;
let kc5 = flat5 & 31u32;
let kc6 = flat6 & 31u32;
let kc7 = flat7 & 31u32;
threadgroup_store(
"xs",
mr0 * xs_ld + kc0,
load(x[(x_m_base + mr0) * k + kb + kc0]).cast::<T>(),
);
threadgroup_store(
"xs",
mr1 * xs_ld + kc1,
load(x[(x_m_base + mr1) * k + kb + kc1]).cast::<T>(),
);
threadgroup_store(
"xs",
mr2 * xs_ld + kc2,
load(x[(x_m_base + mr2) * k + kb + kc2]).cast::<T>(),
);
threadgroup_store(
"xs",
mr3 * xs_ld + kc3,
load(x[(x_m_base + mr3) * k + kb + kc3]).cast::<T>(),
);
threadgroup_store(
"xs",
mr4 * xs_ld + kc4,
load(x[(x_m_base + mr4) * k + kb + kc4]).cast::<T>(),
);
threadgroup_store(
"xs",
mr5 * xs_ld + kc5,
load(x[(x_m_base + mr5) * k + kb + kc5]).cast::<T>(),
);
threadgroup_store(
"xs",
mr6 * xs_ld + kc6,
load(x[(x_m_base + mr6) * k + kb + kc6]).cast::<T>(),
);
threadgroup_store(
"xs",
mr7 * xs_ld + kc7,
load(x[(x_m_base + mr7) * k + kb + kc7]).cast::<T>(),
);
let pack_idx_0 = lane_in_tg;
let w_row_0 = pack_idx_0 / 8u32;
let pack_in_row_0 = pack_idx_0 & 7u32;
let pack_0 = load(w[(w_n_base + w_row_0) * packs_per_row + kb / 4u32 + pack_in_row_0]);
let k_off_0 = kb + pack_in_row_0 * 4u32;
let g_0 = k_off_0 / 64u32;
let sb_base_0 = (w_n_base + w_row_0) * gs_per_row;
let s_0 = load(scales[sb_base_0 + g_0]).cast::<f32>();
let b_0 = load(biases[sb_base_0 + g_0]).cast::<f32>();
let q0_0 = (pack_0 & 255u32).cast::<f32>();
let q1_0 = ((pack_0 >> 8u32) & 255u32).cast::<f32>();
let q2_0 = ((pack_0 >> 16u32) & 255u32).cast::<f32>();
let q3_0 = ((pack_0 >> 24u32) & 255u32).cast::<f32>();
let ws_base_0 = w_row_0 * ws_ld + pack_in_row_0 * 4u32;
threadgroup_store("ws", ws_base_0, (s_0 * q0_0 + b_0).cast::<T>());
threadgroup_store("ws", ws_base_0 + 1u32, (s_0 * q1_0 + b_0).cast::<T>());
threadgroup_store("ws", ws_base_0 + 2u32, (s_0 * q2_0 + b_0).cast::<T>());
threadgroup_store("ws", ws_base_0 + 3u32, (s_0 * q3_0 + b_0).cast::<T>());
let pack_idx_1 = 64u32 + lane_in_tg;
let w_row_1 = pack_idx_1 / 8u32;
let pack_in_row_1 = pack_idx_1 & 7u32;
let pack_1 = load(w[(w_n_base + w_row_1) * packs_per_row + kb / 4u32 + pack_in_row_1]);
let k_off_1 = kb + pack_in_row_1 * 4u32;
let g_1 = k_off_1 / 64u32;
let sb_base_1 = (w_n_base + w_row_1) * gs_per_row;
let s_1 = load(scales[sb_base_1 + g_1]).cast::<f32>();
let b_1 = load(biases[sb_base_1 + g_1]).cast::<f32>();
let q0_1 = (pack_1 & 255u32).cast::<f32>();
let q1_1 = ((pack_1 >> 8u32) & 255u32).cast::<f32>();
let q2_1 = ((pack_1 >> 16u32) & 255u32).cast::<f32>();
let q3_1 = ((pack_1 >> 24u32) & 255u32).cast::<f32>();
let ws_base_1 = w_row_1 * ws_ld + pack_in_row_1 * 4u32;
threadgroup_store("ws", ws_base_1, (s_1 * q0_1 + b_1).cast::<T>());
threadgroup_store("ws", ws_base_1 + 1u32, (s_1 * q1_1 + b_1).cast::<T>());
threadgroup_store("ws", ws_base_1 + 2u32, (s_1 * q2_1 + b_1).cast::<T>());
threadgroup_store("ws", ws_base_1 + 3u32, (s_1 * q3_1 + b_1).cast::<T>());
let pack_idx_2 = 128u32 + lane_in_tg;
let w_row_2 = pack_idx_2 / 8u32;
let pack_in_row_2 = pack_idx_2 & 7u32;
let pack_2 = load(w[(w_n_base + w_row_2) * packs_per_row + kb / 4u32 + pack_in_row_2]);
let k_off_2 = kb + pack_in_row_2 * 4u32;
let g_2 = k_off_2 / 64u32;
let sb_base_2 = (w_n_base + w_row_2) * gs_per_row;
let s_2 = load(scales[sb_base_2 + g_2]).cast::<f32>();
let b_2 = load(biases[sb_base_2 + g_2]).cast::<f32>();
let q0_2 = (pack_2 & 255u32).cast::<f32>();
let q1_2 = ((pack_2 >> 8u32) & 255u32).cast::<f32>();
let q2_2 = ((pack_2 >> 16u32) & 255u32).cast::<f32>();
let q3_2 = ((pack_2 >> 24u32) & 255u32).cast::<f32>();
let ws_base_2 = w_row_2 * ws_ld + pack_in_row_2 * 4u32;
threadgroup_store("ws", ws_base_2, (s_2 * q0_2 + b_2).cast::<T>());
threadgroup_store("ws", ws_base_2 + 1u32, (s_2 * q1_2 + b_2).cast::<T>());
threadgroup_store("ws", ws_base_2 + 2u32, (s_2 * q2_2 + b_2).cast::<T>());
threadgroup_store("ws", ws_base_2 + 3u32, (s_2 * q3_2 + b_2).cast::<T>());
let pack_idx_3 = 192u32 + lane_in_tg;
let w_row_3 = pack_idx_3 / 8u32;
let pack_in_row_3 = pack_idx_3 & 7u32;
let pack_3 = load(w[(w_n_base + w_row_3) * packs_per_row + kb / 4u32 + pack_in_row_3]);
let k_off_3 = kb + pack_in_row_3 * 4u32;
let g_3 = k_off_3 / 64u32;
let sb_base_3 = (w_n_base + w_row_3) * gs_per_row;
let s_3 = load(scales[sb_base_3 + g_3]).cast::<f32>();
let b_3 = load(biases[sb_base_3 + g_3]).cast::<f32>();
let q0_3 = (pack_3 & 255u32).cast::<f32>();
let q1_3 = ((pack_3 >> 8u32) & 255u32).cast::<f32>();
let q2_3 = ((pack_3 >> 16u32) & 255u32).cast::<f32>();
let q3_3 = ((pack_3 >> 24u32) & 255u32).cast::<f32>();
let ws_base_3 = w_row_3 * ws_ld + pack_in_row_3 * 4u32;
threadgroup_store("ws", ws_base_3, (s_3 * q0_3 + b_3).cast::<T>());
threadgroup_store("ws", ws_base_3 + 1u32, (s_3 * q1_3 + b_3).cast::<T>());
threadgroup_store("ws", ws_base_3 + 2u32, (s_3 * q2_3 + b_3).cast::<T>());
threadgroup_store("ws", ws_base_3 + 3u32, (s_3 * q3_3 + b_3).cast::<T>());
threadgroup_barrier();
let row_a0 = sm * 16u32 + fm; let row_a1 = sm * 16u32 + 8u32 + fm; let col_b0 = sn * 16u32; let col_b1 = sn * 16u32 + 8u32; simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + fn1));
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 8u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 8u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 8u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 8u32 + fn1));
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 8u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 16u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 16u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 16u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 16u32 + fn1));
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 16u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 24u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 24u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 24u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 24u32 + fn1));
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 24u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_matmul(a_f1, b_f1, c_f11);
threadgroup_barrier();
}
let out_m_base = m_tile * 16u32;
let out_n_base = n_tile * 32u32 + sn * 16u32;
store(out[(out_m_base + fm) * n + out_n_base + fn0], simdgroup_elem_load(c_f00, 0).cast::<T>());
store(out[(out_m_base + fm) * n + out_n_base + fn1], simdgroup_elem_load(c_f00, 1).cast::<T>());
store(
out[(out_m_base + fm) * n + out_n_base + 8u32 + fn0],
simdgroup_elem_load(c_f01, 0).cast::<T>(),
);
store(
out[(out_m_base + fm) * n + out_n_base + 8u32 + fn1],
simdgroup_elem_load(c_f01, 1).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + fn0],
simdgroup_elem_load(c_f10, 0).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + fn1],
simdgroup_elem_load(c_f10, 1).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + 8u32 + fn0],
simdgroup_elem_load(c_f11, 0).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + 8u32 + fn1],
simdgroup_elem_load(c_f11, 1).cast::<T>(),
);
}
#[bench_kernel(
    op="affine",
    subop="dequantize_int4",
    class=AffineDequantize,
    bits=4,
    group_size=64,
    n_groups=4096,
    batch=1,
    tpg=32,
    // Tolerance 1e-2: bf16 round-trip error scales with max_q (15 for
    // int4); at n_groups=4096 the worst-case absolute drift is ~3e-3.
    tol=1e-2,
    metal_file="quantized.metal",
)]
#[kernel]
pub fn mt_affine_dequantize_int4<T>(
    w: Tensor<u32>,
    scales: Tensor<T>,
    biases: Tensor<T>,
    mut out: Tensor<T>,
    #[constexpr] group_size: u32,
) {
    // Each program instance expands one packed u32 into eight 4-bit codes
    // and writes `scale * code + bias` for every one of them.
    let word_idx = program_id::<0>();
    let first_out = word_idx * 8u32;
    // All eight outputs take the scale/bias of the group that contains the
    // first output element.
    let grp = first_out / group_size;
    let sc = load(scales[grp]).cast::<f32>();
    let bs = load(biases[grp]).cast::<f32>();
    let bits_word = load(w[word_idx]);
    // Code i occupies bits [4*i, 4*i + 4) of the packed word.
    let d0 = sc * ((bits_word >> 0u32) & 15u32).cast::<f32>() + bs;
    let d1 = sc * ((bits_word >> 4u32) & 15u32).cast::<f32>() + bs;
    let d2 = sc * ((bits_word >> 8u32) & 15u32).cast::<f32>() + bs;
    let d3 = sc * ((bits_word >> 12u32) & 15u32).cast::<f32>() + bs;
    let d4 = sc * ((bits_word >> 16u32) & 15u32).cast::<f32>() + bs;
    let d5 = sc * ((bits_word >> 20u32) & 15u32).cast::<f32>() + bs;
    let d6 = sc * ((bits_word >> 24u32) & 15u32).cast::<f32>() + bs;
    let d7 = sc * ((bits_word >> 28u32) & 15u32).cast::<f32>() + bs;
    store(out[first_out + 0u32], d0.cast::<T>());
    store(out[first_out + 1u32], d1.cast::<T>());
    store(out[first_out + 2u32], d2.cast::<T>());
    store(out[first_out + 3u32], d3.cast::<T>());
    store(out[first_out + 4u32], d4.cast::<T>());
    store(out[first_out + 5u32], d5.cast::<T>());
    store(out[first_out + 6u32], d6.cast::<T>());
    store(out[first_out + 7u32], d7.cast::<T>());
}
#[bench_kernel(
    op="affine",
    subop="quantize_int4",
    class=AffineQuantize,
    bits=4,
    group_size=64,
    n_groups=4096,
    batch=1,
    tpg=32,
    tol=1e-1,
    metal_file="quantized.metal",
)]
#[kernel]
pub fn mt_affine_quantize_int4<T>(
    w: Tensor<T>,
    mut out: Tensor<u32>,
    mut scales: Tensor<T>,
    mut biases: Tensor<T>,
    #[constexpr] group_size: u32,
) {
    // `tgid_x` picks the quantization group, `tid` the lane working on it.
    let grp = tgid_x;
    let lane_id = tid;
    let grp_start = grp * group_size;

    // Group min/max: every lane reads two adjacent inputs, then the SIMD
    // reductions combine the per-lane extremes.
    let pair_at = grp_start + lane_id * 2u32;
    let va = load(w[pair_at]).cast::<f32>();
    let vb = load(w[pair_at + 1u32]).cast::<f32>();
    let lane_lo = select(va < vb, va, vb);
    let lane_hi = select(va > vb, va, vb);
    let grp_lo = simd_min(lane_lo);
    let grp_hi = simd_max(lane_hi);

    // 15 is the largest 4-bit code. A range below 1e-7 per bin falls back
    // to a scale of 1 so the reciprocal stays finite.
    let scale_raw = (grp_hi - grp_lo) / 15.0f32;
    let grp_scale = select(scale_raw < 1.0e-7f32, 1.0f32, scale_raw);
    let inv_grp_scale = 1.0f32 / grp_scale;
    let grp_bias = grp_lo;
    if lane_id == 0u32 {
        store(scales[grp], grp_scale.cast::<T>());
        store(biases[grp], grp_bias.cast::<T>());
    }

    // The first `group_size / 8` lanes each pack eight codes into one u32.
    let words_per_group = group_size / 8u32;
    if lane_id < words_per_group {
        let src = grp_start + lane_id * 8u32;
        let mut packed_bits = 0u32;
        for j in range(0u32, 8u32, 1u32) {
            let elem = load(w[src + j]).cast::<f32>();
            // Adding 0.5 before the truncating cast rounds to nearest.
            let shifted = (elem - grp_bias) * inv_grp_scale + 0.5f32;
            let floored = select(shifted < 0.0f32, 0.0f32, shifted);
            let clamped = select(shifted > 15.0f32, 15.0f32, floored);
            let code = clamped.cast::<u32>();
            packed_bits = packed_bits | (code << (j * 4u32));
        }
        store(out[grp * words_per_group + lane_id], packed_bits);
    }
}
#[bench_kernel(
    op="affine",
    subop="dequantize_int8",
    class=AffineDequantize,
    bits=8,
    group_size=64,
    n_groups=4096,
    batch=1,
    tpg=32,
    // Tolerance 1e-1: int8 max_q=255 amplifies bf16 round-trip drift; the
    // worst case at n_groups=4096 is ~5e-2.
    tol=1e-1,
    metal_file="quantized.metal",
)]
#[kernel]
pub fn mt_affine_dequantize_int8<T>(
    w: Tensor<u32>,
    scales: Tensor<T>,
    biases: Tensor<T>,
    mut out: Tensor<T>,
    #[constexpr] group_size: u32,
) {
    // Each program instance expands one packed u32 into four 8-bit codes
    // and writes `scale * code + bias` for every one of them.
    let word_idx = program_id::<0>();
    let first_out = word_idx * 4u32;
    // All four outputs take the scale/bias of the group that contains the
    // first output element.
    let grp = first_out / group_size;
    let sc = load(scales[grp]).cast::<f32>();
    let bs = load(biases[grp]).cast::<f32>();
    let bits_word = load(w[word_idx]);
    // Code i occupies byte i of the packed word.
    let d0 = sc * ((bits_word >> 0u32) & 255u32).cast::<f32>() + bs;
    let d1 = sc * ((bits_word >> 8u32) & 255u32).cast::<f32>() + bs;
    let d2 = sc * ((bits_word >> 16u32) & 255u32).cast::<f32>() + bs;
    let d3 = sc * ((bits_word >> 24u32) & 255u32).cast::<f32>() + bs;
    store(out[first_out + 0u32], d0.cast::<T>());
    store(out[first_out + 1u32], d1.cast::<T>());
    store(out[first_out + 2u32], d2.cast::<T>());
    store(out[first_out + 3u32], d3.cast::<T>());
}
#[bench_kernel(
    op="affine",
    subop="quantize_int8",
    class=AffineQuantize,
    bits=8,
    group_size=64,
    n_groups=4096,
    batch=1,
    tpg=32,
    tol=1e-1,
    metal_file="quantized.metal",
)]
#[kernel]
pub fn mt_affine_quantize_int8<T>(
    w: Tensor<T>,
    mut out: Tensor<u32>,
    mut scales: Tensor<T>,
    mut biases: Tensor<T>,
    #[constexpr] group_size: u32,
) {
    // `tgid_x` picks the quantization group, `tid` the lane working on it.
    let grp = tgid_x;
    let lane_id = tid;
    let grp_start = grp * group_size;

    // Group min/max: every lane reads two adjacent inputs, then the SIMD
    // reductions combine the per-lane extremes.
    let pair_at = grp_start + lane_id * 2u32;
    let va = load(w[pair_at]).cast::<f32>();
    let vb = load(w[pair_at + 1u32]).cast::<f32>();
    let lane_lo = select(va < vb, va, vb);
    let lane_hi = select(va > vb, va, vb);
    let grp_lo = simd_min(lane_lo);
    let grp_hi = simd_max(lane_hi);

    // 255 is the largest 8-bit code. A range below 1e-7 per bin falls back
    // to a scale of 1 so the reciprocal stays finite.
    let scale_raw = (grp_hi - grp_lo) / 255.0f32;
    let grp_scale = select(scale_raw < 1.0e-7f32, 1.0f32, scale_raw);
    let inv_grp_scale = 1.0f32 / grp_scale;
    let grp_bias = grp_lo;
    if lane_id == 0u32 {
        store(scales[grp], grp_scale.cast::<T>());
        store(biases[grp], grp_bias.cast::<T>());
    }

    // The first `group_size / 4` lanes each pack four codes into one u32.
    let words_per_group = group_size / 4u32;
    if lane_id < words_per_group {
        let src = grp_start + lane_id * 4u32;
        let mut packed_bits = 0u32;
        for j in range(0u32, 4u32, 1u32) {
            let elem = load(w[src + j]).cast::<f32>();
            // Adding 0.5 before the truncating cast rounds to nearest.
            let shifted = (elem - grp_bias) * inv_grp_scale + 0.5f32;
            let floored = select(shifted < 0.0f32, 0.0f32, shifted);
            let clamped = select(shifted > 255.0f32, 255.0f32, floored);
            let code = clamped.cast::<u32>();
            packed_bits = packed_bits | (code << (j * 8u32));
        }
        store(out[grp * words_per_group + lane_id], packed_bits);
    }
}
#[bench_kernel(
    op="affine",
    subop="dequantize_int2",
    class=AffineDequantize,
    bits=2,
    group_size=64,
    n_groups=4096,
    batch=1,
    tpg=32,
    // Tolerance 5e-3: int2 max_q=3, the tightest of the dequant family;
    // worst-case bf16 round-trip drift at n_groups=4096 is ~1e-3.
    tol=5e-3,
    metal_file="quantized.metal",
)]
#[kernel]
pub fn mt_affine_dequantize_int2<T>(
    w: Tensor<u32>,
    scales: Tensor<T>,
    biases: Tensor<T>,
    mut out: Tensor<T>,
    #[constexpr] group_size: u32,
) {
    // Each program instance expands one packed u32 into sixteen 2-bit codes
    // and writes `scale * code + bias` for every one of them.
    let word_idx = program_id::<0>();
    let first_out = word_idx * 16u32;
    // All sixteen outputs take the scale/bias of the group that contains
    // the first output element.
    let grp = first_out / group_size;
    let sc = load(scales[grp]).cast::<f32>();
    let bs = load(biases[grp]).cast::<f32>();
    let bits_word = load(w[word_idx]);
    for slot in range(0u32, 16u32, 1u32) {
        // Code `slot` occupies bits [2*slot, 2*slot + 2).
        let code = (bits_word >> (slot * 2u32)) & 3u32;
        let value = sc * code.cast::<f32>() + bs;
        store(out[first_out + slot], value.cast::<T>());
    }
}
#[bench_kernel(
    op="affine",
    subop="quantize_int2",
    class=AffineQuantize,
    bits=2,
    group_size=64,
    n_groups=4096,
    batch=1,
    tpg=32,
    tol=1e-1,
    metal_file="quantized.metal",
)]
#[kernel]
pub fn mt_affine_quantize_int2<T>(
    w: Tensor<T>,
    mut out: Tensor<u32>,
    mut scales: Tensor<T>,
    mut biases: Tensor<T>,
    #[constexpr] group_size: u32,
) {
    // `tgid_x` picks the quantization group, `tid` the lane working on it.
    let grp = tgid_x;
    let lane_id = tid;
    let grp_start = grp * group_size;

    // Group min/max: every lane reads two adjacent inputs, then the SIMD
    // reductions combine the per-lane extremes.
    let pair_at = grp_start + lane_id * 2u32;
    let va = load(w[pair_at]).cast::<f32>();
    let vb = load(w[pair_at + 1u32]).cast::<f32>();
    let lane_lo = select(va < vb, va, vb);
    let lane_hi = select(va > vb, va, vb);
    let grp_lo = simd_min(lane_lo);
    let grp_hi = simd_max(lane_hi);

    // 3 is the largest 2-bit code. A range below 1e-7 per bin falls back
    // to a scale of 1 so the reciprocal stays finite.
    let scale_raw = (grp_hi - grp_lo) / 3.0f32;
    let grp_scale = select(scale_raw < 1.0e-7f32, 1.0f32, scale_raw);
    let inv_grp_scale = 1.0f32 / grp_scale;
    let grp_bias = grp_lo;
    if lane_id == 0u32 {
        store(scales[grp], grp_scale.cast::<T>());
        store(biases[grp], grp_bias.cast::<T>());
    }

    // The first `group_size / 16` lanes each pack sixteen codes into one u32.
    let words_per_group = group_size / 16u32;
    if lane_id < words_per_group {
        let src = grp_start + lane_id * 16u32;
        let mut packed_bits = 0u32;
        for j in range(0u32, 16u32, 1u32) {
            let elem = load(w[src + j]).cast::<f32>();
            // Adding 0.5 before the truncating cast rounds to nearest.
            let shifted = (elem - grp_bias) * inv_grp_scale + 0.5f32;
            let floored = select(shifted < 0.0f32, 0.0f32, shifted);
            let clamped = select(shifted > 3.0f32, 3.0f32, floored);
            let code = clamped.cast::<u32>();
            packed_bits = packed_bits | (code << (j * 2u32));
        }
        store(out[grp * words_per_group + lane_id], packed_bits);
    }
}
#[bench_kernel(
    op="affine",
    subop="quantize_int3",
    class=AffineQuantize,
    bits=3,
    group_size=32,
    n_groups=4096,
    batch=1,
    tpg=32,
    tol=1e-1,
    metal_file="quantized.metal",
)]
#[kernel]
pub fn mt_affine_quantize_int3<T>(
    w: Tensor<T>,
    mut out: Tensor<u32>,
    mut scales: Tensor<T>,
    mut biases: Tensor<T>,
    #[constexpr] group_size: u32,
) {
    // `tgid_x` picks the quantization group, `tid` the lane working on it.
    let grp = tgid_x;
    let lane_id = tid;
    let grp_start = grp * group_size;

    // One input per lane feeds the group-wide min/max reductions.
    let mine = load(w[grp_start + lane_id]).cast::<f32>();
    let grp_lo = simd_min(mine);
    let grp_hi = simd_max(mine);

    // 7 is the largest 3-bit code. A range below 1e-7 per bin falls back
    // to a scale of 1 so the reciprocal stays finite.
    let scale_raw = (grp_hi - grp_lo) / 7.0f32;
    let grp_scale = select(scale_raw < 1.0e-7f32, 1.0f32, scale_raw);
    let inv_grp_scale = 1.0f32 / grp_scale;
    let grp_bias = grp_lo;

    if lane_id == 0u32 {
        store(scales[grp], grp_scale.cast::<T>());
        store(biases[grp], grp_bias.cast::<T>());

        // Lane 0 serially packs the group's 3-bit codes as a little-endian
        // bit stream into three u32 words per group.
        let dst = grp * 3u32;
        let mut word0 = 0u32;
        let mut word1 = 0u32;
        let mut word2 = 0u32;
        for e in range(0u32, group_size, 1u32) {
            let elem = load(w[grp_start + e]).cast::<f32>();
            // Adding 0.5 before the truncating cast rounds to nearest.
            let shifted = (elem - grp_bias) * inv_grp_scale + 0.5f32;
            let floored = select(shifted < 0.0f32, 0.0f32, shifted);
            let clamped = select(shifted > 7.0f32, 7.0f32, floored);
            let code = clamped.cast::<u32>() & 7u32;

            let bit_at = e * 3u32;
            let which = bit_at / 32u32;
            let shift = bit_at & 31u32;
            // Bits shifted past bit 31 drop out here; they are re-added to
            // the following word below.
            let low_part = code << shift;
            word0 = select(which == 0u32, word0 | low_part, word0);
            word1 = select(which == 1u32, word1 | low_part, word1);
            word2 = select(which == 2u32, word2 | low_part, word2);

            // A code starting at bit 30 or 31 runs across the word boundary.
            if shift > 29u32 {
                // `32 - shift` low bits stayed in the current word; the
                // rest lands at the bottom of the next one.
                let high_part = code >> (32u32 - shift);
                word1 = select(which == 0u32, word1 | high_part, word1);
                word2 = select(which == 1u32, word2 | high_part, word2);
            }
        }
        store(out[dst + 0u32], word0);
        store(out[dst + 1u32], word1);
        store(out[dst + 2u32], word2);
    }
}
#[bench_kernel(
    op="affine",
    subop="quantize_int5",
    class=AffineQuantize,
    bits=5,
    group_size=32,
    n_groups=4096,
    batch=1,
    tpg=32,
    tol=1e-1,
    metal_file="quantized.metal",
)]
#[kernel]
pub fn mt_affine_quantize_int5<T>(
    w: Tensor<T>,
    mut out: Tensor<u32>,
    mut scales: Tensor<T>,
    mut biases: Tensor<T>,
    #[constexpr] group_size: u32,
) {
    // `tgid_x` picks the quantization group, `tid` the lane working on it.
    let grp = tgid_x;
    let lane_id = tid;
    let grp_start = grp * group_size;

    // One input per lane feeds the group-wide min/max reductions.
    let mine = load(w[grp_start + lane_id]).cast::<f32>();
    let grp_lo = simd_min(mine);
    let grp_hi = simd_max(mine);

    // 31 is the largest 5-bit code. A range below 1e-7 per bin falls back
    // to a scale of 1 so the reciprocal stays finite.
    let scale_raw = (grp_hi - grp_lo) / 31.0f32;
    let grp_scale = select(scale_raw < 1.0e-7f32, 1.0f32, scale_raw);
    let inv_grp_scale = 1.0f32 / grp_scale;
    let grp_bias = grp_lo;

    if lane_id == 0u32 {
        store(scales[grp], grp_scale.cast::<T>());
        store(biases[grp], grp_bias.cast::<T>());

        // Lane 0 serially packs the group's 5-bit codes as a little-endian
        // bit stream into five u32 words per group.
        let dst = grp * 5u32;
        let mut word0 = 0u32;
        let mut word1 = 0u32;
        let mut word2 = 0u32;
        let mut word3 = 0u32;
        let mut word4 = 0u32;
        for e in range(0u32, group_size, 1u32) {
            let elem = load(w[grp_start + e]).cast::<f32>();
            // Adding 0.5 before the truncating cast rounds to nearest.
            let shifted = (elem - grp_bias) * inv_grp_scale + 0.5f32;
            let floored = select(shifted < 0.0f32, 0.0f32, shifted);
            let clamped = select(shifted > 31.0f32, 31.0f32, floored);
            let code = clamped.cast::<u32>() & 31u32;

            let bit_at = e * 5u32;
            let which = bit_at / 32u32;
            let shift = bit_at & 31u32;
            // Bits shifted past bit 31 drop out here; they are re-added to
            // the following word below.
            let low_part = code << shift;
            word0 = select(which == 0u32, word0 | low_part, word0);
            word1 = select(which == 1u32, word1 | low_part, word1);
            word2 = select(which == 2u32, word2 | low_part, word2);
            word3 = select(which == 3u32, word3 | low_part, word3);
            word4 = select(which == 4u32, word4 | low_part, word4);

            // A code starting at bit 28 or later runs across the boundary.
            if shift > 27u32 {
                // `32 - shift` low bits stayed in the current word; the
                // rest lands at the bottom of the next one.
                let high_part = code >> (32u32 - shift);
                word1 = select(which == 0u32, word1 | high_part, word1);
                word2 = select(which == 1u32, word2 | high_part, word2);
                word3 = select(which == 2u32, word3 | high_part, word3);
                word4 = select(which == 3u32, word4 | high_part, word4);
            }
        }
        store(out[dst + 0u32], word0);
        store(out[dst + 1u32], word1);
        store(out[dst + 2u32], word2);
        store(out[dst + 3u32], word3);
        store(out[dst + 4u32], word4);
    }
}
#[bench_kernel(
    op="affine",
    subop="quantize_int6",
    class=AffineQuantize,
    bits=6,
    group_size=32,
    n_groups=4096,
    batch=1,
    tpg=32,
    tol=1e-1,
    metal_file="quantized.metal",
)]
#[kernel]
pub fn mt_affine_quantize_int6<T>(
    w: Tensor<T>,
    mut out: Tensor<u32>,
    mut scales: Tensor<T>,
    mut biases: Tensor<T>,
    #[constexpr] group_size: u32,
) {
    // `tgid_x` picks the quantization group, `tid` the lane working on it.
    let grp = tgid_x;
    let lane_id = tid;
    let grp_start = grp * group_size;

    // One input per lane feeds the group-wide min/max reductions.
    let mine = load(w[grp_start + lane_id]).cast::<f32>();
    let grp_lo = simd_min(mine);
    let grp_hi = simd_max(mine);

    // 63 is the largest 6-bit code. A range below 1e-7 per bin falls back
    // to a scale of 1 so the reciprocal stays finite.
    let scale_raw = (grp_hi - grp_lo) / 63.0f32;
    let grp_scale = select(scale_raw < 1.0e-7f32, 1.0f32, scale_raw);
    let inv_grp_scale = 1.0f32 / grp_scale;
    let grp_bias = grp_lo;

    if lane_id == 0u32 {
        store(scales[grp], grp_scale.cast::<T>());
        store(biases[grp], grp_bias.cast::<T>());

        // Lane 0 serially packs the group's 6-bit codes as a little-endian
        // bit stream into six u32 words per group.
        let dst = grp * 6u32;
        let mut word0 = 0u32;
        let mut word1 = 0u32;
        let mut word2 = 0u32;
        let mut word3 = 0u32;
        let mut word4 = 0u32;
        let mut word5 = 0u32;
        for e in range(0u32, group_size, 1u32) {
            let elem = load(w[grp_start + e]).cast::<f32>();
            // Adding 0.5 before the truncating cast rounds to nearest.
            let shifted = (elem - grp_bias) * inv_grp_scale + 0.5f32;
            let floored = select(shifted < 0.0f32, 0.0f32, shifted);
            let clamped = select(shifted > 63.0f32, 63.0f32, floored);
            let code = clamped.cast::<u32>() & 63u32;

            let bit_at = e * 6u32;
            let which = bit_at / 32u32;
            let shift = bit_at & 31u32;
            // Bits shifted past bit 31 drop out here; they are re-added to
            // the following word below.
            let low_part = code << shift;
            word0 = select(which == 0u32, word0 | low_part, word0);
            word1 = select(which == 1u32, word1 | low_part, word1);
            word2 = select(which == 2u32, word2 | low_part, word2);
            word3 = select(which == 3u32, word3 | low_part, word3);
            word4 = select(which == 4u32, word4 | low_part, word4);
            word5 = select(which == 5u32, word5 | low_part, word5);

            // A code starting at bit 27 or later runs across the boundary.
            if shift > 26u32 {
                // `32 - shift` low bits stayed in the current word; the
                // rest lands at the bottom of the next one.
                let high_part = code >> (32u32 - shift);
                word1 = select(which == 0u32, word1 | high_part, word1);
                word2 = select(which == 1u32, word2 | high_part, word2);
                word3 = select(which == 2u32, word3 | high_part, word3);
                word4 = select(which == 3u32, word4 | high_part, word4);
                word5 = select(which == 4u32, word5 | high_part, word5);
            }
        }
        store(out[dst + 0u32], word0);
        store(out[dst + 1u32], word1);
        store(out[dst + 2u32], word2);
        store(out[dst + 3u32], word3);
        store(out[dst + 4u32], word4);
        store(out[dst + 5u32], word5);
    }
}
#[bench_kernel(
    op="affine",
    subop="dequantize_int3",
    class=AffineDequantize,
    bits=3,
    group_size=32,
    n_groups=4096,
    batch=1,
    tpg=16,
    // tol=5e-3 — int3 max_q=7; worst-case bf16 drift at n_groups=4096
    // is ~1e-3.
    tol=5e-3,
    metal_file="quantized.metal",
)]
#[kernel]
pub fn mt_affine_dequantize_int3<T>(
    w: Tensor<u32>,
    scales: Tensor<T>,
    biases: Tensor<T>,
    mut out: Tensor<T>,
    #[constexpr] group_size: u32,
) {
    // Each program instance decodes one 3-byte pack (eight 3-bit codes) and
    // writes `scale * code + bias` for every one of them.
    let pack_idx = program_id::<0>();
    let pack_factor = 8u32;
    let bytes_per_pack = 3u32;
    let oindex = pack_idx * pack_factor;
    // All eight outputs take the scale/bias of the group that contains the
    // first output element.
    let g_idx = oindex / group_size;
    let scale = load(scales[g_idx]).cast::<f32>();
    let bias = load(biases[g_idx]).cast::<f32>();

    // Packs are byte-aligned but `w` is addressed in u32 words, so the three
    // bytes of a pack may straddle two consecutive words.
    let byte_off = pack_idx * bytes_per_pack;
    let u_idx0 = byte_off / 4u32;
    // Byte position of each pack byte inside its word.
    let s0 = byte_off & 3u32;
    let s1 = (byte_off + 1u32) & 3u32;
    let s2 = (byte_off + 2u32) & 3u32;
    // Whether each pack byte still belongs to the first word.
    let in0_0 = (byte_off + 0u32) / 4u32 == u_idx0;
    let in0_1 = (byte_off + 1u32) / 4u32 == u_idx0;
    let in0_2 = (byte_off + 2u32) / 4u32 == u_idx0;
    // When the last byte sits in the first word, so do the others and the
    // second word is never selected. Re-read the first word in that case
    // instead of touching `w[u_idx0 + 1]`, which for the final pack may lie
    // past the end of the packed data.
    let u_idx1 = select(in0_2, u_idx0, u_idx0 + 1u32);
    let u0 = load(w[u_idx0]);
    let u1 = load(w[u_idx1]);
    let b0 = (select(in0_0, u0, u1) >> (s0 * 8u32)) & 255u32;
    let b1 = (select(in0_1, u0, u1) >> (s1 * 8u32)) & 255u32;
    let b2 = (select(in0_2, u0, u1) >> (s2 * 8u32)) & 255u32;

    // Slice the 24-bit little-endian stream into eight 3-bit codes; q2 and
    // q5 are stitched together from two neighbouring bytes.
    let q0 = b0 & 7u32;
    let q1 = (b0 >> 3u32) & 7u32;
    let q2 = ((b0 >> 6u32) & 3u32) | ((b1 & 1u32) << 2u32);
    let q3 = (b1 >> 1u32) & 7u32;
    let q4 = (b1 >> 4u32) & 7u32;
    let q5 = ((b1 >> 7u32) & 1u32) | ((b2 & 3u32) << 1u32);
    let q6 = (b2 >> 2u32) & 7u32;
    let q7 = (b2 >> 5u32) & 7u32;
    store(out[oindex + 0u32], (scale * q0.cast::<f32>() + bias).cast::<T>());
    store(out[oindex + 1u32], (scale * q1.cast::<f32>() + bias).cast::<T>());
    store(out[oindex + 2u32], (scale * q2.cast::<f32>() + bias).cast::<T>());
    store(out[oindex + 3u32], (scale * q3.cast::<f32>() + bias).cast::<T>());
    store(out[oindex + 4u32], (scale * q4.cast::<f32>() + bias).cast::<T>());
    store(out[oindex + 5u32], (scale * q5.cast::<f32>() + bias).cast::<T>());
    store(out[oindex + 6u32], (scale * q6.cast::<f32>() + bias).cast::<T>());
    store(out[oindex + 7u32], (scale * q7.cast::<f32>() + bias).cast::<T>());
}
#[bench_kernel(
    op="affine",
    subop="dequantize_int5",
    class=AffineDequantize,
    bits=5,
    group_size=32,
    n_groups=4096,
    batch=1,
    tpg=16,
    tol=1e-2,
    metal_file="quantized.metal",
)]
#[kernel]
pub fn mt_affine_dequantize_int5<T>(
    w: Tensor<u32>,
    scales: Tensor<T>,
    biases: Tensor<T>,
    mut out: Tensor<T>,
    #[constexpr] group_size: u32,
) {
    // Each program instance decodes one 5-byte pack (eight 5-bit codes) and
    // writes `scale * code + bias` for every one of them.
    let pk = program_id::<0>();
    let first_out = pk * 8u32;
    // All eight outputs take the scale/bias of the group that contains the
    // first output element.
    let grp = first_out / group_size;
    let sc = load(scales[grp]).cast::<f32>();
    let bs = load(biases[grp]).cast::<f32>();

    // Packs are byte-aligned but `w` is addressed in u32 words; the five
    // bytes of a pack are read from a window of two consecutive words.
    let first_byte = pk * 5u32;
    let word_lo = first_byte / 4u32;
    let lo_word = load(w[word_lo]);
    let hi_word = load(w[word_lo + 1u32]);

    // p_i is the position of pack byte i inside the two-word window:
    // positions 0..=3 fall in `lo_word`, 4..=7 in `hi_word`.
    let p0 = first_byte & 3u32;
    let p1 = p0 + 1u32;
    let p2 = p0 + 2u32;
    let p3 = p0 + 3u32;
    let p4 = p0 + 4u32;
    let b0 = (select(p0 < 4u32, lo_word, hi_word) >> ((p0 & 3u32) * 8u32)) & 255u32;
    let b1 = (select(p1 < 4u32, lo_word, hi_word) >> ((p1 & 3u32) * 8u32)) & 255u32;
    let b2 = (select(p2 < 4u32, lo_word, hi_word) >> ((p2 & 3u32) * 8u32)) & 255u32;
    let b3 = (select(p3 < 4u32, lo_word, hi_word) >> ((p3 & 3u32) * 8u32)) & 255u32;
    let b4 = (select(p4 < 4u32, lo_word, hi_word) >> ((p4 & 3u32) * 8u32)) & 255u32;

    // Slice the 40-bit little-endian stream into eight 5-bit codes; codes
    // that cross a byte boundary are stitched from two neighbouring bytes.
    let c0 = b0 & 31u32;
    let c1 = ((b0 >> 5u32) & 7u32) | ((b1 & 3u32) << 3u32);
    let c2 = (b1 >> 2u32) & 31u32;
    let c3 = ((b1 >> 7u32) & 1u32) | ((b2 & 15u32) << 1u32);
    let c4 = ((b2 >> 4u32) & 15u32) | ((b3 & 1u32) << 4u32);
    let c5 = (b3 >> 1u32) & 31u32;
    let c6 = ((b3 >> 6u32) & 3u32) | ((b4 & 7u32) << 2u32);
    let c7 = (b4 >> 3u32) & 31u32;

    let d0 = sc * c0.cast::<f32>() + bs;
    let d1 = sc * c1.cast::<f32>() + bs;
    let d2 = sc * c2.cast::<f32>() + bs;
    let d3 = sc * c3.cast::<f32>() + bs;
    let d4 = sc * c4.cast::<f32>() + bs;
    let d5 = sc * c5.cast::<f32>() + bs;
    let d6 = sc * c6.cast::<f32>() + bs;
    let d7 = sc * c7.cast::<f32>() + bs;
    store(out[first_out + 0u32], d0.cast::<T>());
    store(out[first_out + 1u32], d1.cast::<T>());
    store(out[first_out + 2u32], d2.cast::<T>());
    store(out[first_out + 3u32], d3.cast::<T>());
    store(out[first_out + 4u32], d4.cast::<T>());
    store(out[first_out + 5u32], d5.cast::<T>());
    store(out[first_out + 6u32], d6.cast::<T>());
    store(out[first_out + 7u32], d7.cast::<T>());
}
#[bench_kernel(
    op="affine",
    subop="dequantize_int6",
    class=AffineDequantize,
    bits=6,
    group_size=32,
    n_groups=4096,
    batch=1,
    tpg=16,
    // tol=5e-2 — int6 max_q=63; worst-case bf16 drift at n_groups=4096
    // is ~1.3e-2.
    tol=5e-2,
    metal_file="quantized.metal",
)]
#[kernel]
pub fn mt_affine_dequantize_int6<T>(
    w: Tensor<u32>,
    scales: Tensor<T>,
    biases: Tensor<T>,
    mut out: Tensor<T>,
    #[constexpr] group_size: u32,
) {
    // Each program instance decodes one 3-byte pack (four 6-bit codes) and
    // writes `scale * code + bias` for every one of them.
    let pack_idx = program_id::<0>();
    let pack_factor = 4u32;
    let bytes_per_pack = 3u32;
    let oindex = pack_idx * pack_factor;
    // All four outputs take the scale/bias of the group that contains the
    // first output element.
    let g_idx = oindex / group_size;
    let scale = load(scales[g_idx]).cast::<f32>();
    let bias = load(biases[g_idx]).cast::<f32>();

    // Packs are byte-aligned but `w` is addressed in u32 words, so the three
    // bytes of a pack may straddle two consecutive words.
    let byte_off = pack_idx * bytes_per_pack;
    let u_idx0 = byte_off / 4u32;
    // Byte position of each pack byte inside its word.
    let s0 = byte_off & 3u32;
    let s1 = (byte_off + 1u32) & 3u32;
    let s2 = (byte_off + 2u32) & 3u32;
    // Whether each pack byte still belongs to the first word.
    let in0_0 = (byte_off + 0u32) / 4u32 == u_idx0;
    let in0_1 = (byte_off + 1u32) / 4u32 == u_idx0;
    let in0_2 = (byte_off + 2u32) / 4u32 == u_idx0;
    // When the last byte sits in the first word, so do the others and the
    // second word is never selected. Re-read the first word in that case
    // instead of touching `w[u_idx0 + 1]`, which for the final pack may lie
    // past the end of the packed data.
    let u_idx1 = select(in0_2, u_idx0, u_idx0 + 1u32);
    let u0 = load(w[u_idx0]);
    let u1 = load(w[u_idx1]);
    let b0 = (select(in0_0, u0, u1) >> (s0 * 8u32)) & 255u32;
    let b1 = (select(in0_1, u0, u1) >> (s1 * 8u32)) & 255u32;
    let b2 = (select(in0_2, u0, u1) >> (s2 * 8u32)) & 255u32;

    // Slice the 24-bit little-endian stream into four 6-bit codes; q1 and
    // q2 are stitched together from two neighbouring bytes.
    let q0 = b0 & 63u32;
    let q1 = ((b0 >> 6u32) & 3u32) | ((b1 & 15u32) << 2u32);
    let q2 = ((b1 >> 4u32) & 15u32) | ((b2 & 3u32) << 4u32);
    let q3 = (b2 >> 2u32) & 63u32;
    store(out[oindex + 0u32], (scale * q0.cast::<f32>() + bias).cast::<T>());
    store(out[oindex + 1u32], (scale * q1.cast::<f32>() + bias).cast::<T>());
    store(out[oindex + 2u32], (scale * q2.cast::<f32>() + bias).cast::<T>());
    store(out[oindex + 3u32], (scale * q3.cast::<f32>() + bias).cast::<T>());
}
// Registers a `BenchSpec` for one generated quantized kernel through
// `inventory::submit!`.
//
// `$name` is the kernel function identifier (it provides both the
// `kernel_name` string and the `kernel_ir_for` entry); `$subop` is the
// sub-operation label stored in the spec.
macro_rules! quantized_family_spec {
($name:ident, $subop:literal) => {
inventory::submit! {
crate::spec::BenchSpec {
op: "quantized",
subop: $subop,
kernel_name: stringify!($name),
kernel_ir: $name::kernel_ir_for,
// Every generated kernel is registered for all three float dtypes.
dtypes: &[
metaltile_core::dtype::DType::F32,
metaltile_core::dtype::DType::F16,
metaltile_core::dtype::DType::BF16,
],
// No MLX source/pattern and no shape list are attached to these specs.
tol: 5e-2, mlx_src: None,
mlx_pattern: None,
shapes: &[],
dispatch: crate::spec::BenchDispatch::Generic,
kernel_mode: Some(metaltile_core::ir::KernelMode::Reduction),
}
}
};
}
// Generates a quantized matrix-vector kernel `$name` for a bit width that
// divides 32 (each u32 holds `32 / $bits` whole codes), and registers it
// under `$subop` via `quantized_family_spec!`.
//
// `tgid_x` selects the weight row / output column, `tgid_y` the row of `x`
// and `out`. The 32 lanes of the simdgroup stride over `k`, and their
// partial sums are combined with `simd_sum`.
macro_rules! qmv_pow2 {
    ($name:ident, $bits:literal, $subop:literal) => {
        #[kernel]
        pub fn $name<T>(
            w: Tensor<u32>,
            scales: Tensor<T>,
            biases: Tensor<T>,
            x: Tensor<T>,
            out: Tensor<T>,
            #[constexpr] k: u32,
            #[constexpr] n: u32,
            #[constexpr] group_size: u32,
        ) {
            let out_row = tgid_x;
            let batch_row = tgid_y;
            let lane_id = simd_lane;

            // Packed-weight geometry for this bit width.
            let codes_per_word = 32u32 / $bits;
            let words_per_row = k / codes_per_word;
            let code_mask = (1u32 << $bits) - 1u32;

            // Row bases into scales/biases, `x` and the packed weights.
            let sb_row_base = out_row * (k / group_size);
            let x_base = batch_row * k;
            let w_row_base = out_row * words_per_row;

            // ceil(k / 32) passes; each lane handles element `pass * 32 + lane`.
            let n_iters = (k + 31u32) / 32u32;
            let mut partial = 0.0f32;
            for _it in range(0u32, n_iters, 1u32) {
                let idx = _it * 32u32 + lane_id;
                if idx < k {
                    let grp = idx / group_size;
                    let sc = load(scales[sb_row_base + grp]).cast::<f32>();
                    let bs = load(biases[sb_row_base + grp]).cast::<f32>();
                    // Locate the u32 holding this code and its slot inside it.
                    let word_at = idx / codes_per_word;
                    let code_at = idx - word_at * codes_per_word;
                    let bits_word = load(w[w_row_base + word_at]);
                    let code = (bits_word >> (code_at * $bits)) & code_mask;
                    let weight = code.cast::<f32>() * sc + bs;
                    let xv = load(x[x_base + idx]).cast::<f32>();
                    partial = partial + weight * xv;
                }
            }
            let row_sum = simd_sum(partial);
            if lane_id == 0u32 {
                store(out[batch_row * n + out_row], row_sum.cast::<T>());
            }
        }
        quantized_family_spec!($name, $subop);
    };
}
// Generates a quantized matrix-vector kernel `$name` for a bit width whose
// codes may straddle a u32 boundary (instantiated in this file with 3, 5
// and 6 bits), and registers it under `$subop` via `quantized_family_spec!`.
//
// `tgid_x` selects the weight row / output column, `tgid_y` the row of `x`
// and `out`. The 32 lanes stride over `k`, and their partial sums are
// combined with `simd_sum`.
macro_rules! qmv_odd {
($name:ident, $bits:literal, $subop:literal) => {
#[kernel]
pub fn $name<T>(
w: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] n: u32,
#[constexpr] group_size: u32,
) {
let row = tgid_x;
let m_row = tgid_y;
let lane = simd_lane;
let groups_per_row = k / group_size;
let scale_row_base = row * groups_per_row;
let x_row_base = m_row * k;
// Each weight row is a contiguous bit stream of `k * $bits` bits.
let u32_per_row = k * $bits / 32u32;
let row_u32_off = row * u32_per_row;
let mut acc = 0.0f32;
// ceil(k / 32) passes; each lane handles element `_it * 32 + lane`.
let n_iters = (k + 31u32) / 32u32;
for _it in range(0u32, n_iters, 1u32) {
let d = _it * 32u32 + lane;
if d < k {
let g = d / group_size;
let scale = load(scales[scale_row_base + g]).cast::<f32>();
let bias = load(biases[scale_row_base + g]).cast::<f32>();
// Bit position of code `d` inside the row's bit stream.
let bit_off = d * $bits;
let word_idx = bit_off / 32u32;
let bit_in_w = bit_off & 31u32;
// `lo_bits` of the code live in the first word; `spill` bits (possibly
// zero) continue at the bottom of the next word.
let bits_in_w0 = 32u32 - bit_in_w;
let lo_bits = select(bits_in_w0 >= $bits, $bits, bits_in_w0);
let spill = $bits - lo_bits;
let w0 = load(w[row_u32_off + word_idx]);
// Without a spill the first word is re-read, so the following word is
// never touched unless it is needed.
let w1idx = select(spill > 0u32, word_idx + 1u32, word_idx);
let w1 = load(w[row_u32_off + w1idx]);
let lo = (w0 >> bit_in_w) & ((1u32 << lo_bits) - 1u32);
// With `spill == 0` the mask is zero, so `hi` contributes nothing.
let hi = (w1 & ((1u32 << spill) - 1u32)) << lo_bits;
let q = lo | hi;
let wv = q.cast::<f32>() * scale + bias;
acc = acc + wv * load(x[x_row_base + d]).cast::<f32>();
}
}
let total = simd_sum(acc);
if lane == 0u32 {
store(out[m_row * n + row], total.cast::<T>());
}
}
quantized_family_spec!($name, $subop);
};
}
// Generates a quantized vector-matrix kernel `$name` for a bit width that
// divides 32 (each u32 holds `32 / $bits` whole codes), and registers it
// under `$subop` via `quantized_family_spec!`.
//
// `tgid_x` selects the output column, `tgid_y` the row of `x` and `out`.
// Weights are indexed as `w[d * packs_per_row + pack]`, i.e. one packed row
// of `n` codes per reduction index `d`; scales/biases are indexed as
// `[g * n + col]`. The 32 lanes stride over `k` and are combined with
// `simd_sum`.
macro_rules! qvm_pow2 {
    ($name:ident, $bits:literal, $subop:literal) => {
        #[kernel]
        pub fn $name<T>(
            w: Tensor<u32>,
            scales: Tensor<T>,
            biases: Tensor<T>,
            x: Tensor<T>,
            out: Tensor<T>,
            #[constexpr] k: u32,
            #[constexpr] n: u32,
            #[constexpr] group_size: u32,
        ) {
            let col = tgid_x;
            let m_row = tgid_y;
            let lane = simd_lane;
            let x_row_base = m_row * k;
            let vals_per_pack = 32u32 / $bits;
            let packs_per_row = n / vals_per_pack;
            let mask = (1u32 << $bits) - 1u32;
            // The packed word and the slot inside it depend only on the
            // column, so they are computed once instead of per iteration.
            let pack = col / vals_per_pack;
            let slot = col - pack * vals_per_pack;
            let shift = slot * $bits;
            let mut acc = 0.0f32;
            // ceil(k / 32) passes; each lane handles element `_it * 32 + lane`.
            let n_iters = (k + 31u32) / 32u32;
            for _it in range(0u32, n_iters, 1u32) {
                let d = _it * 32u32 + lane;
                if d < k {
                    let g = d / group_size;
                    let scale = load(scales[g * n + col]).cast::<f32>();
                    let bias = load(biases[g * n + col]).cast::<f32>();
                    let word = load(w[d * packs_per_row + pack]);
                    let q = (word >> shift) & mask;
                    let wv = q.cast::<f32>() * scale + bias;
                    acc = acc + wv * load(x[x_row_base + d]).cast::<f32>();
                }
            }
            let total = simd_sum(acc);
            if lane == 0u32 {
                store(out[m_row * n + col], total.cast::<T>());
            }
        }
        quantized_family_spec!($name, $subop);
    };
}
// Generates a quantized vector-matrix kernel `$name` for a bit width whose
// codes may straddle a u32 boundary (instantiated in this file with 3, 5
// and 6 bits), and registers it under `$subop` via `quantized_family_spec!`.
//
// `tgid_x` selects the output column, `tgid_y` the row of `x` and `out`.
// Each reduction index `d` owns a bit stream of `n * $bits` bits in `w`;
// scales/biases are indexed as `[g * n + col]`. The 32 lanes stride over
// `k` and are combined with `simd_sum`.
macro_rules! qvm_odd {
    ($name:ident, $bits:literal, $subop:literal) => {
        #[kernel]
        pub fn $name<T>(
            w: Tensor<u32>,
            scales: Tensor<T>,
            biases: Tensor<T>,
            x: Tensor<T>,
            out: Tensor<T>,
            #[constexpr] k: u32,
            #[constexpr] n: u32,
            #[constexpr] group_size: u32,
        ) {
            let col = tgid_x;
            let m_row = tgid_y;
            let lane = simd_lane;
            let x_row_base = m_row * k;
            let u32_per_row = n * $bits / 32u32;

            // Where the code for `col` sits inside a row's bit stream. This
            // depends only on the column, so it is computed once instead of
            // per iteration.
            let bit_off = col * $bits;
            let word_idx = bit_off / 32u32;
            let bit_in_w = bit_off & 31u32;
            // `lo_bits` of the code live in the first word; `spill` bits
            // (possibly zero) continue at the bottom of the next word.
            let bits_in_w0 = 32u32 - bit_in_w;
            let lo_bits = select(bits_in_w0 >= $bits, $bits, bits_in_w0);
            let spill = $bits - lo_bits;
            // Without a spill the first word is re-read, so the following
            // word is never touched unless it is needed.
            let w1idx = select(spill > 0u32, word_idx + 1u32, word_idx);
            let lo_mask = (1u32 << lo_bits) - 1u32;
            // Zero when `spill == 0`, so `hi` then contributes nothing.
            let hi_mask = (1u32 << spill) - 1u32;

            let mut acc = 0.0f32;
            // ceil(k / 32) passes; each lane handles element `_it * 32 + lane`.
            let n_iters = (k + 31u32) / 32u32;
            for _it in range(0u32, n_iters, 1u32) {
                let d = _it * 32u32 + lane;
                if d < k {
                    let g = d / group_size;
                    let scale = load(scales[g * n + col]).cast::<f32>();
                    let bias = load(biases[g * n + col]).cast::<f32>();
                    let row_u32_off = d * u32_per_row;
                    let w0 = load(w[row_u32_off + word_idx]);
                    let w1 = load(w[row_u32_off + w1idx]);
                    let lo = (w0 >> bit_in_w) & lo_mask;
                    let hi = (w1 & hi_mask) << lo_bits;
                    let q = lo | hi;
                    let wv = q.cast::<f32>() * scale + bias;
                    acc = acc + wv * load(x[x_row_base + d]).cast::<f32>();
                }
            }
            let total = simd_sum(acc);
            if lane == 0u32 {
                store(out[m_row * n + col], total.cast::<T>());
            }
        }
        quantized_family_spec!($name, $subop);
    };
}
// Kernel instantiations. Each line generates one kernel and registers its
// bench spec. Bit widths 4 and 8 use the `*_pow2` templates; 3, 5 and 6 use
// the `*_odd` templates.

// qmv family.
qmv_pow2!(mt_qmv_b4, 4u32, "qmv_b4");
qmv_pow2!(mt_qmv_b8, 8u32, "qmv_b8");
qmv_odd!(mt_qmv_b3, 3u32, "qmv_b3");
qmv_odd!(mt_qmv_b5, 5u32, "qmv_b5");
qmv_odd!(mt_qmv_b6, 6u32, "qmv_b6");
// qmm family: generated from the same qmv templates, only the kernel name
// and subop label differ.
qmv_pow2!(mt_qmm_b4, 4u32, "qmm_b4");
qmv_pow2!(mt_qmm_b8, 8u32, "qmm_b8");
qmv_odd!(mt_qmm_b3, 3u32, "qmm_b3");
qmv_odd!(mt_qmm_b5, 5u32, "qmm_b5");
qmv_odd!(mt_qmm_b6, 6u32, "qmm_b6");
// qvm family.
qvm_pow2!(mt_qvm_b4, 4u32, "qvm_b4");
qvm_pow2!(mt_qvm_b8, 8u32, "qvm_b8");
qvm_odd!(mt_qvm_b3, 3u32, "qvm_b3");
qvm_odd!(mt_qvm_b5, 5u32, "qvm_b5");
qvm_odd!(mt_qvm_b6, 6u32, "qvm_b6");
// Vector × 4-bit-quantized-matrix kernel:
//   out[col] = Σ_{d < k} x[d] · (q[d, col] · scale[g, col] + bias[g, col])
//
// Buffers, as indexed below:
//   w      — packed weights, eight 4-bit values per u32, `n / 8` words per `d` row
//   scales — read at `[g * n + col]`, where `g = d / group_size`
//   biases — same indexing as `scales`
//   x      — input vector, read at `[d]`
//   out    — one value per column, written at `[col]`
// Constexprs: `k` (reduction length), `n` (column count), `gs_per_col`
// (divides `k` to give `group_size`, the number of `d` values per scale/bias).
//
// Each simdgroup owns four consecutive columns; its lanes stride the reduction
// dimension 32 apart and the per-lane partial sums are combined with `simd_sum`.
// NOTE(review): `col0..col3` are never checked against `n`, so the dispatch
// presumably guarantees every column is in range — confirm against the caller.
#[kernel]
pub fn mt_qvm_int4_fast<T>(
w: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] n: u32,
#[constexpr] gs_per_col: u32, ) {
let tg = tgid_x;
let sg = simd_id;
let lane = simd_lane;
// 8 columns per threadgroup, 4 per simdgroup (this stride implies two
// simdgroups per threadgroup — TODO confirm with the dispatcher).
let col0 = tg * 8u32 + sg * 4u32;
let col1 = col0 + 1u32;
let col2 = col0 + 2u32;
let col3 = col0 + 3u32;
// Eight 4-bit values fit in one u32, so a `d` row of `n` columns spans n/8 words.
let packs_per_krow = n / 8u32;
// Word index of each column within its row. `col0` is a multiple of 4, so
// all four columns land in the same word and the four indices are equal.
let col0_pack = col0 / 8u32;
let col1_pack = col1 / 8u32;
let col2_pack = col2 / 8u32;
let col3_pack = col3 / 8u32;
// Nibble position (0..=7) of each column inside its word.
let col0_slot = col0 & 7u32;
let col1_slot = col1 & 7u32;
let col2_slot = col2 & 7u32;
let col3_slot = col3 & 7u32;
// Low-4-bit mask used to isolate one quantized value.
let mask4 = 15u32;
let group_size = k / gs_per_col;
// One f32 accumulator per owned column.
let mut acc0 = 0.0f32;
let mut acc1 = 0.0f32;
let mut acc2 = 0.0f32;
let mut acc3 = 0.0f32;
// ceil(k / 32): enough passes for lanes 32 apart to cover every `d`; the
// `d < k` guard below handles a `k` that is not a multiple of 32.
let n_iters = (k + 31u32) / 32u32;
for _it in range(0u32, n_iters, 1u32) {
let d = _it * 32u32 + lane; if d < k {
let xd = load(x[d]).cast::<f32>();
// Quantization group that `d` belongs to; selects the scale/bias row.
let g = d / group_size;
let s0 = load(scales[g * n + col0]).cast::<f32>();
let bi0 = load(biases[g * n + col0]).cast::<f32>();
let s1 = load(scales[g * n + col1]).cast::<f32>();
let bi1 = load(biases[g * n + col1]).cast::<f32>();
let s2 = load(scales[g * n + col2]).cast::<f32>();
let bi2 = load(biases[g * n + col2]).cast::<f32>();
let s3 = load(scales[g * n + col3]).cast::<f32>();
let bi3 = load(biases[g * n + col3]).cast::<f32>();
// First packed word of row `d`.
let row_base = d * packs_per_krow;
// Extract each column's nibble: shift its slot down to bit 0, then mask.
// NOTE(review): the four loads read the same word (see the pack indices
// above); a single load could serve all four columns.
let w0 = load(w[row_base + col0_pack]);
let q0 = ((w0 >> (col0_slot * 4u32)) & mask4).cast::<f32>();
let w1 = load(w[row_base + col1_pack]);
let q1 = ((w1 >> (col1_slot * 4u32)) & mask4).cast::<f32>();
let w2 = load(w[row_base + col2_pack]);
let q2 = ((w2 >> (col2_slot * 4u32)) & mask4).cast::<f32>();
let w3 = load(w[row_base + col3_pack]);
let q3 = ((w3 >> (col3_slot * 4u32)) & mask4).cast::<f32>();
// Dequantize (q * scale + bias) and accumulate against x[d].
acc0 = acc0 + (q0 * s0 + bi0) * xd;
acc1 = acc1 + (q1 * s1 + bi1) * xd;
acc2 = acc2 + (q2 * s2 + bi2) * xd;
acc3 = acc3 + (q3 * s3 + bi3) * xd;
}
}
// Combine the per-lane partial sums across the simdgroup.
let r0 = simd_sum(acc0);
let r1 = simd_sum(acc1);
let r2 = simd_sum(acc2);
let r3 = simd_sum(acc3);
// Only lane 0 writes the four results, cast back to the element type.
if lane == 0u32 {
store(out[col0], r0.cast::<T>());
store(out[col1], r1.cast::<T>());
store(out[col2], r2.cast::<T>());
store(out[col3], r3.cast::<T>());
}
}
// Registers the kernel under the "qvm_int4_fast" subop label.
quantized_family_spec!(mt_qvm_int4_fast, "qvm_int4_fast");
// Generates a tiled simdgroup-matrix quantized mat-mul kernel `$name` for
// weights packed at `$bits` bits per value, and registers a `BenchSpec` for it
// under the subop label `$subop`.
//
// Generated kernel, as indexed below:
//   w      — packed weights, `k * $bits / 32` u32 words per output column (row of w)
//   scales — read at `[(n_row) * gs_per_row + g]`, with `g = k_index / group_size`
//   biases — same indexing as `scales`
//   x      — activations, read at `[m_row * k + k_index]`
//   out    — results, written at `[m_row * n + n_col]`
// Each threadgroup produces a 32×32 output tile, walking `k` in steps of 32.
// NOTE(review): no load or store is bounds-checked, so the kernel presumably
// requires M, N and K to be multiples of 32 — confirm against the dispatcher.
macro_rules! qmm_mma_bitwidth {
($name:ident, $bits:literal, $subop:literal) => {
#[kernel]
#[allow(clippy::too_many_arguments)]
pub fn $name<T>(
w: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] n: u32,
#[constexpr] gs_per_row: u32,
) {
// Threadgroup grid: x indexes 32-wide tiles along n, y indexes 32-tall tiles along m.
let n_tile = tgid_x;
let m_tile = tgid_y;
let lane = simd_lane;
let sg = simd_group_id();
// Simdgroups form a 2-wide grid: `sm` is the row of that grid, `sn` the
// column; each simdgroup owns one 16×16 quadrant of the 32×32 output tile.
let sm = sg / 2u32;
let sn = sg & 1u32;
// Flat thread index within the threadgroup (32 lanes per simdgroup).
let lane_in_tg = sg * 32u32 + lane;
// Per-lane coordinates inside an 8×8 fragment: row `fm` and the two
// adjacent columns `fn0`/`fn1` this lane holds (two elements per lane).
let qid = lane / 4u32;
let fm = (qid & 4u32) + ((lane / 2u32) % 4u32);
let fn0 = (qid & 2u32) * 2u32 + (lane % 2u32) * 2u32;
let fn1 = fn0 + 1u32;
// Staging tiles for x and dequantized w: 1152 = 32 rows × 36 (the leading
// dimension `xs_ld` / `ws_ld` set below).
threadgroup_alloc("xs", 1152, T);
threadgroup_alloc("ws", 1152, T);
// Four 8×8 f32 accumulators covering this simdgroup's 16×16 quadrant
// (c_fRC = fragment row R, fragment column C), zeroed element by element.
let c_f00 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f00, 0, 0.0f32);
simdgroup_elem_store(c_f00, 1, 0.0f32);
let c_f01 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f01, 0, 0.0f32);
simdgroup_elem_store(c_f01, 1, 0.0f32);
let c_f10 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f10, 0, 0.0f32);
simdgroup_elem_store(c_f10, 1, 0.0f32);
let c_f11 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f11, 0, 0.0f32);
simdgroup_elem_store(c_f11, 1, 0.0f32);
// Operand fragments: two row-blocks of x (a) and two column-blocks of w (b).
let a_f0 = simdgroup_alloc::<T, 8, 8>();
let a_f1 = simdgroup_alloc::<T, 8, 8>();
let b_f0 = simdgroup_alloc::<T, 8, 8>();
let b_f1 = simdgroup_alloc::<T, 8, 8>();
// Cooperative-load assignment: four threads per tile row, each handling a
// run of 8 consecutive k values. Covering 32 rows this way needs 128
// threads (4 simdgroups) per threadgroup — TODO confirm with the dispatcher.
let w_row = lane_in_tg / 4u32;
let pack_in_row = lane_in_tg & 3u32;
let x_m_row = lane_in_tg / 4u32;
let x_k_quad = lane_in_tg & 3u32;
let x_k_base = x_k_quad * 8u32;
// Leading dimension 36 = 32 data columns + 4 of padding (presumably to
// spread threadgroup-memory accesses — verify).
let xs_ld = 36u32;
let ws_ld = 36u32;
let x_m_base = m_tile * 32u32;
let w_n_base = n_tile * 32u32;
// Packed words per weight row, and k values per scale/bias group.
let u32_per_row = k * $bits / 32u32;
let group_size = k / gs_per_row;
// Start of this thread's weight row in the scale/bias and packed-weight buffers.
let sb_base = (w_n_base + w_row) * gs_per_row;
let w_row_base = (w_n_base + w_row) * u32_per_row;
for kb in range(0u32, k, 32u32) {
// Stage 1a: copy this thread's 8 x values for the current k-block into "xs".
let x_row_dev_base = (x_m_base + x_m_row) * k + kb + x_k_base;
let x_ws_base = x_m_row * xs_ld + x_k_base;
let xv0 = load(x[x_row_dev_base]).cast::<T>();
let xv1 = load(x[x_row_dev_base + 1u32]).cast::<T>();
let xv2 = load(x[x_row_dev_base + 2u32]).cast::<T>();
let xv3 = load(x[x_row_dev_base + 3u32]).cast::<T>();
let xv4 = load(x[x_row_dev_base + 4u32]).cast::<T>();
let xv5 = load(x[x_row_dev_base + 5u32]).cast::<T>();
let xv6 = load(x[x_row_dev_base + 6u32]).cast::<T>();
let xv7 = load(x[x_row_dev_base + 7u32]).cast::<T>();
threadgroup_store("xs", x_ws_base, xv0);
threadgroup_store("xs", x_ws_base + 1u32, xv1);
threadgroup_store("xs", x_ws_base + 2u32, xv2);
threadgroup_store("xs", x_ws_base + 3u32, xv3);
threadgroup_store("xs", x_ws_base + 4u32, xv4);
threadgroup_store("xs", x_ws_base + 5u32, xv5);
threadgroup_store("xs", x_ws_base + 6u32, xv6);
threadgroup_store("xs", x_ws_base + 7u32, xv7);
// Stage 1b: dequantize this thread's 8 weights into "ws".
// Scale and bias are fetched once for the whole run of 8, which assumes
// the run never crosses a group boundary (group_size a multiple of 8 —
// TODO confirm).
let k0 = kb + pack_in_row * 8u32;
let g = k0 / group_size;
let s = load(scales[sb_base + g]).cast::<f32>();
let b = load(biases[sb_base + g]).cast::<f32>();
let ws_base = w_row * ws_ld + pack_in_row * 8u32;
for _ci in range(0u32, 8u32, 1u32) {
// Locate the `$bits`-wide value: containing word and bit offset within it.
let bit_off = (k0 + _ci) * $bits;
let word_idx = bit_off / 32u32;
let bit_in_w = bit_off & 31u32;
// A value may straddle two words: `lo_bits` come from the first word,
// the remaining `spill` bits from the start of the next.
let bits_in_w0 = 32u32 - bit_in_w;
let lo_bits = select(bits_in_w0 >= $bits, $bits, bits_in_w0);
let spill = $bits - lo_bits;
let w0 = load(w[w_row_base + word_idx]);
// With no spill, re-read the same word instead of touching the next
// one; its contribution is masked to zero below.
let w1idx = select(spill > 0u32, word_idx + 1u32, word_idx);
let w1 = load(w[w_row_base + w1idx]);
let lo = (w0 >> bit_in_w) & ((1u32 << lo_bits) - 1u32);
let hi = (w1 & ((1u32 << spill) - 1u32)) << lo_bits;
let q = (lo | hi).cast::<f32>();
// Affine dequantization, stored in the element type.
threadgroup_store("ws", ws_base + _ci, (s * q + b).cast::<T>());
}
// Both staging tiles must be fully written before any simdgroup reads them.
threadgroup_barrier();
// Stage 2: four k-steps of 8 (offsets 0, 8, 16, 24 within the 32-wide
// block). Each step loads two a-fragments and two b-fragments and
// accumulates the four 8×8 products.
let row_a0 = sm * 16u32 + fm;
let row_a1 = sm * 16u32 + 8u32 + fm;
let col_b0 = sn * 16u32;
let col_b1 = sn * 16u32 + 8u32;
// k-step 0. "ws" is stored as [n][k], so the b-fragment is read transposed:
// fragment row `fm` indexes k, fragment columns `fn0`/`fn1` index n.
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + fm));
simdgroup_barrier_mem_none();
// Accumulators are visited in the order 00, 01, 11, 10, so consecutive
// products share one operand fragment.
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// k-step 1 (k offset 8).
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 8u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 8u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 8u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 8u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(
b_f0,
0,
threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 8u32 + fm),
);
simdgroup_elem_store(
b_f0,
1,
threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 8u32 + fm),
);
simdgroup_elem_store(
b_f1,
0,
threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 8u32 + fm),
);
simdgroup_elem_store(
b_f1,
1,
threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 8u32 + fm),
);
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// k-step 2 (k offset 16).
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 16u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 16u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 16u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 16u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(
b_f0,
0,
threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 16u32 + fm),
);
simdgroup_elem_store(
b_f0,
1,
threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 16u32 + fm),
);
simdgroup_elem_store(
b_f1,
0,
threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 16u32 + fm),
);
simdgroup_elem_store(
b_f1,
1,
threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 16u32 + fm),
);
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// k-step 3 (k offset 24).
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 24u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 24u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 24u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 24u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(
b_f0,
0,
threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 24u32 + fm),
);
simdgroup_elem_store(
b_f0,
1,
threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 24u32 + fm),
);
simdgroup_elem_store(
b_f1,
0,
threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 24u32 + fm),
);
simdgroup_elem_store(
b_f1,
1,
threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 24u32 + fm),
);
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// Everyone must finish reading "xs"/"ws" before the next k-block overwrites them.
threadgroup_barrier();
}
// Stage 3: write this simdgroup's 16×16 quadrant. Each lane stores its two
// elements of each of the four accumulators, cast to the element type.
let out_m_base = m_tile * 32u32 + sm * 16u32;
let out_n_base = n_tile * 32u32 + sn * 16u32;
store(
out[(out_m_base + fm) * n + out_n_base + fn0],
simdgroup_elem_load(c_f00, 0).cast::<T>(),
);
store(
out[(out_m_base + fm) * n + out_n_base + fn1],
simdgroup_elem_load(c_f00, 1).cast::<T>(),
);
store(
out[(out_m_base + fm) * n + out_n_base + 8u32 + fn0],
simdgroup_elem_load(c_f01, 0).cast::<T>(),
);
store(
out[(out_m_base + fm) * n + out_n_base + 8u32 + fn1],
simdgroup_elem_load(c_f01, 1).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + fn0],
simdgroup_elem_load(c_f10, 0).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + fn1],
simdgroup_elem_load(c_f10, 1).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + 8u32 + fn0],
simdgroup_elem_load(c_f11, 0).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + 8u32 + fn1],
simdgroup_elem_load(c_f11, 1).cast::<T>(),
);
}
// Bench registration for the generated kernel: F32/F16/BF16, tolerance 5e-2,
// no MLX reference source or pattern, an empty shape list, generic dispatch,
// and Reduction kernel mode.
inventory::submit! {
crate::spec::BenchSpec {
op: "quantized",
subop: $subop,
kernel_name: stringify!($name),
kernel_ir: $name::kernel_ir_for,
dtypes: &[
metaltile_core::dtype::DType::F32,
metaltile_core::dtype::DType::F16,
metaltile_core::dtype::DType::BF16,
],
tol: 5e-2,
mlx_src: None,
mlx_pattern: None,
shapes: &[],
dispatch: crate::spec::BenchDispatch::Generic,
kernel_mode: Some(metaltile_core::ir::KernelMode::Reduction),
}
}
};
}
// Simdgroup-matrix qmm kernels for the 3-, 5- and 6-bit weight packings;
// each invocation passes (kernel name, bits per weight, bench subop label).
qmm_mma_bitwidth!(mt_qmm_mma_b3, 3u32, "qmm_mma_b3");
qmm_mma_bitwidth!(mt_qmm_mma_b5, 5u32, "qmm_mma_b5");
qmm_mma_bitwidth!(mt_qmm_mma_b6, 6u32, "qmm_mma_b6");
/// Selects the quantized mat-mul kernel IR for `dtype` based on the row count `m`.
///
/// Routing, checked in this order:
/// - `m` is a non-zero multiple of 32: `mt_qmm_mma`, with its threadgroup skew
///   adjusted for `dtype` by [`patch_qmm_mma_dtype_aware_skew`];
/// - `m == 16`: `mt_qmm_mma_m16`;
/// - `m` is a non-zero multiple of 4: `mt_qmm_bm4`;
/// - `m` is a non-zero multiple of 2: `mt_qmm_bm2`;
/// - anything else (odd `m`, or `m == 0`): `mt_qmm`.
///
/// Whichever kernel is chosen, its mode is set to `KernelMode::Reduction`
/// before it is returned.
pub fn mt_qmm_for(dtype: metaltile_core::dtype::DType, m: u32) -> metaltile_core::ir::Kernel {
    use metaltile_core::ir::KernelMode;

    // The explicit lower bounds keep `m == 0` (a multiple of everything) out
    // of the tiled variants so it falls through to the generic kernel.
    let mut selected = match m {
        rows if rows >= 32 && rows.is_multiple_of(32) => {
            let mut mma = mt_qmm_mma::kernel_ir_for(dtype);
            patch_qmm_mma_dtype_aware_skew(&mut mma, dtype);
            mma
        }
        16 => mt_qmm_mma_m16::kernel_ir_for(dtype),
        rows if rows >= 4 && rows.is_multiple_of(4) => mt_qmm_bm4::kernel_ir_for(dtype),
        rows if rows >= 2 && rows.is_multiple_of(2) => mt_qmm_bm2::kernel_ir_for(dtype),
        _ => mt_qmm::kernel_ir_for(dtype),
    };

    selected.mode = KernelMode::Reduction;
    selected
}
pub fn patch_qmm_mma_dtype_aware_skew(
kernel: &mut metaltile_core::ir::Kernel,
dtype: metaltile_core::dtype::DType,
) {
use metaltile_core::dtype::DType;
let bytes = match dtype {
DType::F32 => 4,
DType::F16 | DType::BF16 => 2,
_ => return,
};
if bytes == 4 {
return;
}
let new_ld: i64 = 32 + (16 / bytes as i64);
let new_alloc: u32 = 32 * (new_ld as u32);
let target_names: [&str; 2] = ["xs_ld_const", "ws_ld_const"];
for (vid, name) in kernel.body.names.clone().iter() {
if !target_names.iter().any(|t| t == name) {
continue;
}
for (i, r) in kernel.body.results.iter().enumerate() {
if r.map(|v| v == *vid).unwrap_or(false)
&& let Some(value) = kernel.body.ops[i].as_const_mut()
{
*value = new_ld;
}
}
}
for op in kernel.body.ops.iter_mut() {
if let Some((name, size)) = op.as_threadgroup_alloc_mut()
&& (name == "xs" || name == "ws")
{
*size = new_alloc;
}
}
}
/// Routing tests for [`mt_qmm_for`]: which kernel each row count `m` selects,
/// and that every selected kernel carries the reduction mode.
#[cfg(test)]
mod qmm_selector_tests {
    use metaltile_core::dtype::DType;

    use super::*;

    /// Runs the selector at `F32`, the dtype used by most routing checks.
    fn select_f32(m: u32) -> metaltile_core::ir::Kernel {
        mt_qmm_for(DType::F32, m)
    }

    #[test]
    fn selector_picks_mma_at_m_multiple_of_32() {
        const ROWS: [u32; 4] = [32, 64, 96, 128];
        for m in ROWS {
            let routed = select_f32(m);
            assert_eq!(routed.name, "mt_qmm_mma", "m={m}: multiple of 32 should route to mma");
        }
    }

    #[test]
    fn selector_picks_mma_m16_at_m_16() {
        let routed = select_f32(16);
        assert_eq!(routed.name, "mt_qmm_mma_m16");
    }

    #[test]
    fn selector_picks_bm4_at_m_8_12_20_24_28() {
        // Multiples of 4 that are neither 16 nor a multiple of 32.
        const ROWS: [u32; 8] = [4, 8, 12, 20, 24, 28, 36, 60];
        for m in ROWS {
            let routed = select_f32(m);
            assert_eq!(routed.name, "mt_qmm_bm4", "m={m}: m%4==0 not mma should route to bm4");
        }
    }

    #[test]
    fn selector_picks_bm2_at_even_m_not_multiple_of_4() {
        const ROWS: [u32; 8] = [2, 6, 10, 14, 18, 22, 26, 30];
        for m in ROWS {
            let routed = select_f32(m);
            assert_eq!(routed.name, "mt_qmm_bm2", "m={m}: even-not-mod-4 should route to bm2");
        }
    }

    #[test]
    fn selector_picks_v2_at_m_1() {
        let routed = select_f32(1);
        assert_eq!(routed.name, "mt_qmm");
    }

    #[test]
    fn selector_picks_v2_at_odd_m() {
        const ROWS: [u32; 6] = [3, 5, 7, 9, 15, 31];
        for m in ROWS {
            let routed = select_f32(m);
            assert_eq!(routed.name, "mt_qmm", "m={m}: odd M should route to v2");
        }
    }

    #[test]
    fn selector_picks_bm4_across_dtypes_at_m_8() {
        let dtypes = [DType::F32, DType::F16];
        for dt in dtypes {
            let routed = mt_qmm_for(dt, 8);
            assert_eq!(routed.name, "mt_qmm_bm4", "dt={dt:?}");
        }
    }

    #[test]
    fn selector_kernels_carry_reduction_mode() {
        const ROWS: [u32; 5] = [1, 4, 8, 16, 32];
        for m in ROWS {
            let routed = select_f32(m);
            assert_eq!(
                routed.mode,
                metaltile_core::ir::KernelMode::Reduction,
                "m={m}: missing Reduction mode",
            );
        }
    }
}