mlx-native 0.1.3

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
// quantized_matmul.metal — MSL shader for 4-bit, 6-bit, and 8-bit affine
// quantized matrix multiplication with on-the-fly dequantization.
//
// Computes: output[row][col] = sum_k(dequant(weight[col][k]) * input[row][k])
//
// Weight layout: (N, K) — each of N output columns is a row in the weight
// matrix, packed in quantized format.  Scales and biases are per-group
// (group_size consecutive values along K share one scale and one bias).
//
// Packing formats:
//   4-bit: 8 values per uint32, value i = (packed >> (4*i)) & 0xF
//   6-bit: 4 values per 3 bytes packed into uint32 (MLX triplet format)
//          val0 = packed & 0x3F, val1 = (packed>>6) & 0x3F, etc.
//   8-bit: 4 values per uint32, value i = (packed >> (8*i)) & 0xFF
//
// Dequantization: float_val = scale * quant_val + bias
//   where scale and bias are bf16, one per group of group_size values.
//
// Accumulation is done in f32 for numerical stability; output is written as
// f32 (or bf16 in the _bf16 kernel variants below).

#include <metal_stdlib>
using namespace metal;

// Parameters struct — must match the Rust-side QuantizedMatmulGpuParams.
// Bound as a constant buffer (buffer(5)) in every kernel below; field order
// and 4-byte `uint` sizes must stay layout-compatible with the host struct.
struct QuantizedMatmulParams {
    uint M;           // number of input rows (tokens)
    uint K;           // inner dimension
    uint N;           // number of output columns
    uint group_size;  // values per scale/bias group
    uint bits;        // 4, 6, or 8
};

// Helper: read bf16 scales/biases.  Metal's `half` type is IEEE f16, which is
// a DIFFERENT format from bf16 (bfloat16).  MLX stores scales and biases as
// bf16, so we read them as raw uint16_t and reinterpret via as_type<bfloat>,
// then cast to float for dequantization arithmetic.

// ---- Dequantization helpers ----
// All dequant functions operate in bf16 arithmetic to match MLX's precision.
// MLX dequantizes weights to bf16 (the model's native dtype), and multiplies
// bf16 weights with bf16 inputs, accumulating products in f32.
// Using bf16 dequant ensures our weight values match MLX's exactly.

// 4-bit: unpack nibble i (0..7) of a packed uint32 and apply the affine
// dequant (scale * q + bias) in bf16, matching MLX's weight precision.
inline bfloat dequant_4bit(uint packed, uint i, bfloat scale, bfloat bias) {
    const uint shift = 4u * i;
    const uint q = (packed >> shift) & 0xFu;
    return bfloat(q) * scale + bias;
}

// 6-bit: unpack the i-th 6-bit field (i in 0..3) of a 24-bit triplet that the
// caller has widened into a uint32, then apply the affine dequant in bf16.
inline bfloat dequant_6bit(uint packed, uint i, bfloat scale, bfloat bias) {
    const uint shift = 6u * i;
    const uint q = (packed >> shift) & 0x3Fu;
    return bfloat(q) * scale + bias;
}

// 8-bit: unpack byte i (0..3) of a packed uint32 and apply the affine
// dequant in bf16.
inline bfloat dequant_8bit(uint packed, uint i, bfloat scale, bfloat bias) {
    const uint shift = 8u * i;
    const uint q = (packed >> shift) & 0xFFu;
    return bfloat(q) * scale + bias;
}

// Main quantized matmul kernel (f32 output).
//
// Each thread computes one element of the output: output[row][col].
// This is a simple, correct baseline — future optimization (tiling, SIMD groups)
// will come in Epic 6.
//
// Accumulation and output are both f32 to avoid f16 overflow (max ~65504) on
// projections with large intermediate values (e.g. attention output projections).
//
// Unlike the SIMD kernels below, this baseline tolerates K not divisible by
// group_size (ceil-div on num_groups) and any M/N (per-thread bounds check).
//
// Buffer layout:
//   buffer(0): input    — float32[M][K] (row-major)
//   buffer(1): weight   — packed uint32[N][packed_k] (row-major per output column)
//   buffer(2): scales   — bf16[N][num_groups_per_row] (one per group along K, stored as uint16)
//   buffer(3): biases   — bf16[N][num_groups_per_row] (stored as uint16)
//   buffer(4): output   — float32[M][N] (row-major)
//   buffer(5): params   — QuantizedMatmulParams
kernel void quantized_matmul(
    device const float*  input   [[buffer(0)]],
    device const uint*   weight  [[buffer(1)]],
    device const uint16_t* scales  [[buffer(2)]],
    device const uint16_t* biases  [[buffer(3)]],
    device float*        output  [[buffer(4)]],
    constant QuantizedMatmulParams& params [[buffer(5)]],
    uint2 tid [[thread_position_in_grid]]
) {
    uint row = tid.y;  // which input row (token)
    uint col = tid.x;  // which output column

    // Grid may be padded up to threadgroup multiples — discard excess threads.
    if (row >= params.M || col >= params.N) {
        return;
    }

    uint K = params.K;
    uint group_size = params.group_size;
    uint bits = params.bits;

    // Number of groups along K for one output column.
    // Ceil-div: handles K not divisible by group_size (the SIMD kernels below
    // use plain K / group_size and assume divisibility).
    uint num_groups = (K + group_size - 1) / group_size;

    // Scale/bias for this column: stored contiguously as scales[col * num_groups + g].
    uint sb_base = col * num_groups;

    float acc = 0.0f;

    // Determine packing parameters based on bit-width.
    // 4-bit: 8 values per uint32
    // 6-bit: 4 values per 3 bytes (MLX triplet format, read as uint8_t triples)
    // 8-bit: 4 values per uint32
    const device uint8_t* w_bytes = (const device uint8_t*)weight;

    if (bits == 6) {
        // 6-bit: 4 values per 3-byte triplet. Row stride = ceil(K/4) * 3 bytes.
        uint triplets_per_row = (K + 3) / 4;
        uint row_bytes = triplets_per_row * 3;
        const device uint8_t* w_row = w_bytes + col * row_bytes;

        for (uint k = 0; k < K; k++) {
            uint triplet_idx = k / 4;
            uint in_triplet = k % 4;
            // Read 3 bytes as uint32 (high byte is zero-padded),
            // assembled little-endian to match the packing order.
            uint byte_off = triplet_idx * 3;
            uint packed = uint(w_row[byte_off])
                        | (uint(w_row[byte_off + 1]) << 8)
                        | (uint(w_row[byte_off + 2]) << 16);

            uint g = k / group_size;
            // Scales/biases are raw bf16 bits in uint16 storage — reinterpret,
            // do NOT convert (bf16 is not IEEE f16).
            bfloat scale = as_type<bfloat>(scales[sb_base + g]);
            bfloat bias  = as_type<bfloat>(biases[sb_base + g]);
            bfloat w = dequant_6bit(packed, in_triplet, scale, bias);

            // Truncate the f32 input to bf16 before multiplying so the
            // product matches MLX's bf16 pipeline; accumulate in f32.
            bfloat x = bfloat(input[row * K + k]);
            acc += float(w) * float(x);
        }
    } else {
        // 4-bit and 8-bit: uint32 packed
        uint values_per_pack = (bits == 4) ? 8u : 4u;
        uint packs_per_row = (K + values_per_pack - 1) / values_per_pack;
        uint w_base = col * packs_per_row;

        for (uint k = 0; k < K; k++) {
            uint pack_idx = k / values_per_pack;
            uint in_pack_idx = k % values_per_pack;
            uint packed = weight[w_base + pack_idx];

            uint g = k / group_size;
            bfloat scale = as_type<bfloat>(scales[sb_base + g]);
            bfloat bias  = as_type<bfloat>(biases[sb_base + g]);

            bfloat w;
            if (bits == 4) {
                w = dequant_4bit(packed, in_pack_idx, scale, bias);
            } else {
                w = dequant_8bit(packed, in_pack_idx, scale, bias);
            }

            // bf16 operands, f32 product/accumulation — sequential over k.
            bfloat x = bfloat(input[row * K + k]);
            acc += float(w) * float(x);
        }
    }

    output[row * params.N + col] = acc;
}

// ---------------------------------------------------------------------------
// SIMD-cooperative quantized matrix-vector multiply kernel.
//
// Matches MLX's qmv_fast accumulation pattern exactly:
//   - 2 simdgroups per threadgroup, 4 output rows per simdgroup (8 total)
//   - Each thread in a simdgroup processes values_per_thread elements
//   - simd_sum() reduction across the 32 threads in a simdgroup
//
// This ensures bit-identical f32 accumulation order to MLX, which is
// critical for matching inference results.
//
// Dispatch (from Rust):
//   threadgroup_size = (32, 2, 1)   — 32 threads * 2 simdgroups = 64 threads
//   threadgroups     = (M, ceil(N / 8), 1)
//
// Buffer layout matches the original kernel exactly.
// ---------------------------------------------------------------------------

// Constants matching MLX's qmv_fast_impl
constant constexpr int SIMD_SIZE_CONST = 32;  // threads per simdgroup (SIMD execution width used by the kernels below)

// -- load_vector for 4-bit: pre-divide inputs for the packed dot product --
// Loads `count` consecutive input values (a multiple of 4; 8 in the qmv
// configuration used by quantized_matmul_simd, 16 in the qmv_fast bf16
// variants), truncating each f32 to bfloat and back to float to match MLX's
// bf16 pipeline.  The raw values are summed into `sum` (for the bias term),
// and each stored value is divided by its nibble-position factor
// (1, 16, 256, 4096 — the shift amounts within a uint16_t) so qdot_4bit can
// multiply directly with masked, unshifted packed weights.
inline void load_vector_4bit(
    const device float* x_f32,
    thread float* x_thread,
    thread float& sum,
    int count  // values_per_thread; a multiple of 4
) {
    float total = 0.0f;
    for (int base = 0; base < count; base += 4) {
        // Truncate f32 -> bfloat -> float to match MLX's precision.
        float a = float(bfloat(x_f32[base]));
        float b = float(bfloat(x_f32[base + 1]));
        float c = float(bfloat(x_f32[base + 2]));
        float d = float(bfloat(x_f32[base + 3]));
        total += a + b + c + d;
        x_thread[base]     = a;
        x_thread[base + 1] = b / 16.0f;
        x_thread[base + 2] = c / 256.0f;
        x_thread[base + 3] = d / 4096.0f;
    }
    sum = total;
}

// -- load_vector for 8-bit --
// Truncates f32 inputs to bfloat (matching MLX's precision), stores them
// unmodified, and accumulates their sum for the bias term in qdot_8bit.
inline void load_vector_8bit(
    const device float* x_f32,
    thread float* x_thread,
    thread float& sum,
    int count  // values_per_thread, expected 8
) {
    float total = 0.0f;
    int i = 0;
    while (i < count) {
        float v = float(bfloat(x_f32[i]));
        total += v;
        x_thread[i] = v;
        i++;
    }
    sum = total;
}

// -- qdot for 4-bit: masked multiply-accumulate with packed uint16_t weights --
// Each uint16_t holds four nibbles.  They are masked (not shifted) and
// multiplied by inputs that load_vector_4bit already divided by 16^position,
// so the shift cancels.  The final scale * accum + sum * bias is
// algebraically equivalent to dequantizing each value as scale*q + bias and
// dotting with the raw inputs.
inline float qdot_4bit(
    const device uint8_t* w,
    const thread float* x_thread,
    float scale,
    float bias,
    float sum,
    int values_per_thread  // 16
) {
    const device uint16_t* packed = (const device uint16_t*)w;
    float accum = 0.0f;
    const int n_words = values_per_thread / 4;
    for (int j = 0; j < n_words; j++) {
        const uint16_t word = packed[j];
        accum +=
            (x_thread[4 * j]     * float(word & 0x000fu) +
             x_thread[4 * j + 1] * float(word & 0x00f0u) +
             x_thread[4 * j + 2] * float(word & 0x0f00u) +
             x_thread[4 * j + 3] * float(word & 0xf000u));
    }
    return scale * accum + sum * bias;
}

// -- qdot for 8-bit --
// Plain dot product: each byte is the raw quantized value, no pre-division
// trick needed.  Returns scale * <w, x> + sum * bias.
inline float qdot_8bit(
    const device uint8_t* w,
    const thread float* x_thread,
    float scale,
    float bias,
    float sum,
    int values_per_thread  // 8
) {
    float accum = 0.0f;
    int i = 0;
    while (i < values_per_thread) {
        accum += x_thread[i] * float(w[i]);
        i++;
    }
    return scale * accum + sum * bias;
}

// The main SIMD-cooperative kernel.
// Supports 4-bit and 8-bit (the two formats used by our model).
//
// Fix: the per-simdgroup accumulation loops now stop once out_row + r
// reaches N.  Previously only the WRITE path was guarded, so when N is not
// a multiple of 8 the final column block read weight/scale/bias rows past
// the end of their buffers (dispatch launches ceil(N/8) blocks along y).
// out_row is uniform across the simdgroup, so the break is non-divergent,
// and the trailing simd_sum() reduction is still executed by every lane.
// In-bounds results are unchanged.
kernel void quantized_matmul_simd(
    device const float*    input   [[buffer(0)]],
    device const uint*     weight  [[buffer(1)]],
    device const uint16_t* scales  [[buffer(2)]],
    device const uint16_t* biases  [[buffer(3)]],
    device float*          output  [[buffer(4)]],
    constant QuantizedMatmulParams& params [[buffer(5)]],
    uint3 tid       [[threadgroup_position_in_grid]],
    uint  simd_gid  [[simdgroup_index_in_threadgroup]],
    uint  simd_lid  [[thread_index_in_simdgroup]]
) {
    // Constants matching MLX's qmv_fast_impl template parameters
    const uint num_simdgroups = 2;
    const uint results_per_simdgroup = 4;

    const uint K = params.K;
    const uint N = params.N;
    const uint bits = params.bits;
    const uint group_size = params.group_size;

    // Packing parameters — match MLX's qmv (NOT qmv_fast) for K=2816 compatibility
    // 4-bit: pack_factor=8, packs_per_thread=1, values_per_thread=8, bytes_per_pack=4
    // 8-bit: pack_factor=4, packs_per_thread=2, values_per_thread=8, bytes_per_pack=4
    uint pack_factor;
    uint packs_per_thread;
    uint values_per_thread;
    uint bytes_per_pack;

    if (bits == 4) {
        pack_factor = 8;
        packs_per_thread = 1;
        values_per_thread = 8;     // 8 * 1 = 8, block_size = 256
        bytes_per_pack = 4;
    } else if (bits == 8) {
        pack_factor = 4;
        packs_per_thread = 2;
        values_per_thread = 8;
        bytes_per_pack = 4;
    } else {
        // Unsupported in SIMD path — should not be called for 6-bit
        return;
    }

    // NOTE(review): the main loop assumes K is a multiple of block_size and
    // group_size divides evenly — confirm the Rust dispatcher guarantees this.
    const uint block_size = values_per_thread * SIMD_SIZE_CONST;
    const uint scale_step_per_thread = group_size / values_per_thread;

    // Weight bytes pointer
    const device uint8_t* ws = (const device uint8_t*)weight;

    // Compute row in input (tid.x in MLX's grid corresponds to M dimension)
    const uint row = tid.x;  // which input row
    if (row >= params.M) return;

    // Compute output column block
    const uint out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
                         simd_gid * results_per_simdgroup;

    // Bounds check: if all 4 results are out of range, return
    if (out_row >= N) return;

    // Weight row stride in bytes: each output column row has (K / pack_factor) packs,
    // each pack is bytes_per_pack bytes.
    const uint in_vec_size_w = K * bytes_per_pack / pack_factor;  // bytes per weight row
    const uint in_vec_size_g = K / group_size;  // number of scale/bias groups per row

    // Position pointers for this thread
    // ws points to: out_row's weight data + this thread's offset
    const device uint8_t* ws_ptr = ws + out_row * in_vec_size_w +
                                   simd_lid * packs_per_thread * bytes_per_pack;
    // scales/biases: out_row's group data + this thread's group offset
    const device uint16_t* sc_ptr = scales + out_row * in_vec_size_g +
                                    simd_lid / scale_step_per_thread;
    const device uint16_t* bi_ptr = biases + out_row * in_vec_size_g +
                                    simd_lid / scale_step_per_thread;

    // Input pointer: row * K + this thread's element offset
    // MLX reads input as type T (bfloat), but our input is f32.
    // To match MLX's accumulation exactly, we must truncate to bfloat.
    // We'll read f32 and cast to bfloat in the load function.
    const device float* x_f32 = input + row * K + simd_lid * values_per_thread;

    // Output pointer
    device float* y_ptr = output + row * N + out_row;

    // Per-thread accumulation registers
    float result[4] = {0.0f, 0.0f, 0.0f, 0.0f};

    // Main loop over K in blocks of block_size
    if (bits == 4) {
        float x_thread[8];  // qmv: 8 values per thread (block_size=256)
        for (uint k = 0; k < K; k += block_size) {
            // Load input values with pre-division for packed dot product
            // f32 -> bfloat truncation happens inside load_vector_4bit
            float x_sum;
            load_vector_4bit(x_f32, x_thread, x_sum, values_per_thread);

            // Dot product with each in-bounds output row
            for (uint r = 0; r < results_per_simdgroup; r++) {
                if (out_row + r >= N) break;  // partial block: don't read past row N-1
                const device uint8_t* wl = ws_ptr + r * in_vec_size_w;
                float s = float(as_type<bfloat>(sc_ptr[r * in_vec_size_g]));
                float b = float(as_type<bfloat>(bi_ptr[r * in_vec_size_g]));
                result[r] += qdot_4bit(wl, x_thread, s, b, x_sum, values_per_thread);
            }

            // Advance pointers by block_size
            ws_ptr += block_size * bytes_per_pack / pack_factor;
            sc_ptr += block_size / group_size;
            bi_ptr += block_size / group_size;
            x_f32  += block_size;
        }
    } else {
        // 8-bit path
        float x_thread[8];
        for (uint k = 0; k < K; k += block_size) {
            float x_sum;
            load_vector_8bit(x_f32, x_thread, x_sum, 8);

            for (uint r = 0; r < results_per_simdgroup; r++) {
                if (out_row + r >= N) break;  // partial block: don't read past row N-1
                const device uint8_t* wl = ws_ptr + r * in_vec_size_w;
                float s = float(as_type<bfloat>(sc_ptr[r * in_vec_size_g]));
                float b = float(as_type<bfloat>(bi_ptr[r * in_vec_size_g]));
                result[r] += qdot_8bit(wl, x_thread, s, b, x_sum, 8);
            }

            ws_ptr += block_size * bytes_per_pack / pack_factor;
            sc_ptr += block_size / group_size;
            bi_ptr += block_size / group_size;
            x_f32  += block_size;
        }
    }

    // Reduce across simdgroup and write output.
    // simd_sum must be reached by every lane for each r, so this loop is not
    // short-circuited on out-of-range columns — only the write is guarded.
    for (uint r = 0; r < results_per_simdgroup; r++) {
        result[r] = simd_sum(result[r]);
        if (simd_lid == 0) {
            // Only write if this output column is in bounds
            if (out_row + r < N) {
                y_ptr[r] = result[r];
            }
        }
    }
}

// ---------------------------------------------------------------------------
// load_vector helpers for bf16 input (used by bf16 variants below).
//
// These mirror load_vector_4bit / load_vector_8bit but read bfloat instead of
// float32.  The values are promoted to float for accumulation — identical
// arithmetic to the f32 path because both paths truncate to bfloat first.
// ---------------------------------------------------------------------------

// bf16 variant of load_vector_4bit: reads bfloat input directly (no f32
// truncation step needed), sums the raw values into `sum`, and stores each
// value pre-divided by its nibble-position factor (1, 16, 256, 4096) so
// qdot_4bit can multiply with masked, unshifted packed weights.
inline void load_vector_4bit_bf16(
    const device bfloat* x_bf16,
    thread float* x_thread,
    thread float& sum,
    int count  // values_per_thread; a multiple of 4 (16 in the qmv_fast config)
) {
    float total = 0.0f;
    for (int base = 0; base < count; base += 4) {
        float a = float(x_bf16[base]);
        float b = float(x_bf16[base + 1]);
        float c = float(x_bf16[base + 2]);
        float d = float(x_bf16[base + 3]);
        total += a + b + c + d;
        x_thread[base]     = a;
        x_thread[base + 1] = b / 16.0f;
        x_thread[base + 2] = c / 256.0f;
        x_thread[base + 3] = d / 4096.0f;
    }
    sum = total;
}

// bf16 variant of load_vector_8bit: promotes each bfloat input to float,
// stores it unmodified, and accumulates the running sum for the bias term.
inline void load_vector_8bit_bf16(
    const device bfloat* x_bf16,
    thread float* x_thread,
    thread float& sum,
    int count  // values_per_thread, expected 8
) {
    float total = 0.0f;
    int i = 0;
    while (i < count) {
        float v = float(x_bf16[i]);
        total += v;
        x_thread[i] = v;
        i++;
    }
    sum = total;
}

// ---------------------------------------------------------------------------
// quantized_matmul_simd_bf16
//
// Identical computation to quantized_matmul_simd but with bf16 input and
// bf16 output.  Accumulation remains in f32 for numerical stability.
// Note: the 4-bit path uses qmv_fast parameters (values_per_thread=16),
// unlike the f32 kernel's qmv parameters (values_per_thread=8).
//
// Fix: the per-simdgroup accumulation loops now stop once out_row + r
// reaches N.  Previously only writes were guarded, so a partial final
// column block (N not a multiple of 8) read weight/scale/bias rows past
// the end of their buffers.  In-bounds results are unchanged.
//
// Buffer layout:
//   buffer(0): input    — bfloat[M][K]
//   buffer(1): weight   — packed uint32[N][packed_k]
//   buffer(2): scales   — uint16_t[N][num_groups]  (bf16 bits)
//   buffer(3): biases   — uint16_t[N][num_groups]  (bf16 bits)
//   buffer(4): output   — bfloat[M][N]
//   buffer(5): params   — QuantizedMatmulParams
// ---------------------------------------------------------------------------
kernel void quantized_matmul_simd_bf16(
    device const bfloat*   input   [[buffer(0)]],
    device const uint*     weight  [[buffer(1)]],
    device const uint16_t* scales  [[buffer(2)]],
    device const uint16_t* biases  [[buffer(3)]],
    device bfloat*         output  [[buffer(4)]],
    constant QuantizedMatmulParams& params [[buffer(5)]],
    uint3 tid       [[threadgroup_position_in_grid]],
    uint  simd_gid  [[simdgroup_index_in_threadgroup]],
    uint  simd_lid  [[thread_index_in_simdgroup]]
) {
    const uint num_simdgroups = 2;
    const uint results_per_simdgroup = 4;

    const uint K = params.K;
    const uint N = params.N;
    const uint bits = params.bits;
    const uint group_size = params.group_size;

    uint pack_factor;
    uint packs_per_thread;
    uint values_per_thread;
    uint bytes_per_pack;

    // IMPORTANT: values_per_thread and packs_per_thread MUST match the f32
    // variant exactly.  A previous bug used values_per_thread=8 for 4-bit
    // which halved the block_size and corrupted all pointer arithmetic.
    if (bits == 4) {
        pack_factor = 8;
        packs_per_thread = 2;       // was incorrectly 1
        values_per_thread = 16;     // was incorrectly 8
        bytes_per_pack = 4;
    } else if (bits == 8) {
        pack_factor = 4;
        packs_per_thread = 2;
        values_per_thread = 8;
        bytes_per_pack = 4;
    } else {
        // Unsupported in SIMD path — 6-bit must use the baseline kernel.
        return;
    }

    // NOTE(review): assumes K is a multiple of block_size and group_size —
    // confirm the Rust dispatcher guarantees this.
    const uint block_size = values_per_thread * SIMD_SIZE_CONST;
    const uint scale_step_per_thread = group_size / values_per_thread;

    const device uint8_t* ws = (const device uint8_t*)weight;

    const uint row = tid.x;
    if (row >= params.M) return;

    const uint out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
                         simd_gid * results_per_simdgroup;
    if (out_row >= N) return;

    const uint in_vec_size_w = K * bytes_per_pack / pack_factor;
    const uint in_vec_size_g = K / group_size;

    const device uint8_t* ws_ptr = ws + out_row * in_vec_size_w +
                                   simd_lid * packs_per_thread * bytes_per_pack;
    const device uint16_t* sc_ptr = scales + out_row * in_vec_size_g +
                                    simd_lid / scale_step_per_thread;
    const device uint16_t* bi_ptr = biases + out_row * in_vec_size_g +
                                    simd_lid / scale_step_per_thread;

    // Input is bf16 — read directly without cast.
    const device bfloat* x_bf16 = input + row * K + simd_lid * values_per_thread;

    // Output is bf16.
    device bfloat* y_ptr = output + row * N + out_row;

    float result[4] = {0.0f, 0.0f, 0.0f, 0.0f};

    if (bits == 4) {
        float x_thread[16];  // qmv_fast: 16 values per thread
        for (uint k = 0; k < K; k += block_size) {
            float x_sum;
            load_vector_4bit_bf16(x_bf16, x_thread, x_sum, values_per_thread);

            // out_row is simdgroup-uniform, so this break is non-divergent.
            for (uint r = 0; r < results_per_simdgroup; r++) {
                if (out_row + r >= N) break;  // partial block: don't read past row N-1
                const device uint8_t* wl = ws_ptr + r * in_vec_size_w;
                float s = float(as_type<bfloat>(sc_ptr[r * in_vec_size_g]));
                float b = float(as_type<bfloat>(bi_ptr[r * in_vec_size_g]));
                result[r] += qdot_4bit(wl, x_thread, s, b, x_sum, values_per_thread);
            }

            ws_ptr += block_size * bytes_per_pack / pack_factor;
            sc_ptr += block_size / group_size;
            bi_ptr += block_size / group_size;
            x_bf16 += block_size;
        }
    } else {
        float x_thread[8];
        for (uint k = 0; k < K; k += block_size) {
            float x_sum;
            load_vector_8bit_bf16(x_bf16, x_thread, x_sum, 8);

            for (uint r = 0; r < results_per_simdgroup; r++) {
                if (out_row + r >= N) break;  // partial block: don't read past row N-1
                const device uint8_t* wl = ws_ptr + r * in_vec_size_w;
                float s = float(as_type<bfloat>(sc_ptr[r * in_vec_size_g]));
                float b = float(as_type<bfloat>(bi_ptr[r * in_vec_size_g]));
                result[r] += qdot_8bit(wl, x_thread, s, b, x_sum, 8);
            }

            ws_ptr += block_size * bytes_per_pack / pack_factor;
            sc_ptr += block_size / group_size;
            bi_ptr += block_size / group_size;
            x_bf16 += block_size;
        }
    }

    // Reduce and write bf16 output.  simd_sum must be reached by every lane,
    // so only the write is guarded here.
    for (uint r = 0; r < results_per_simdgroup; r++) {
        result[r] = simd_sum(result[r]);
        if (simd_lid == 0) {
            if (out_row + r < N) {
                y_ptr[r] = bfloat(result[r]);
            }
        }
    }
}

// ---------------------------------------------------------------------------
// quantized_matmul_simd_bf16_expert
//
// Same as quantized_matmul_simd_bf16 but accepts byte-offset parameters that
// allow indexing into a 3D packed weight tensor [n_experts, rows, packed_cols]
// without CPU memcpy.  The offsets are in bytes.
//
// Fix: the per-simdgroup accumulation loops now stop once out_row + r
// reaches N.  Previously only writes were guarded, so a partial final
// column block (N not a multiple of 8) read weight/scale/bias rows past
// this expert's slice — bleeding into the NEXT expert's data (or past the
// buffer for the last expert).  In-bounds results are unchanged.
//
// Buffer layout:
//   buffer(0): input              — bfloat[M][K]
//   buffer(1): packed_weights     — packed uint32 for ALL experts (3D layout)
//   buffer(2): scales             — uint16_t scales for ALL experts
//   buffer(3): biases             — uint16_t biases for ALL experts
//   buffer(4): output             — bfloat[M][N]
//   buffer(5): params             — QuantizedMatmulParams
//   buffer(6): expert_offset      — byte offset into packed_weights for this expert
//   buffer(7): scales_offset      — byte offset into scales for this expert
//   buffer(8): biases_offset      — byte offset into biases for this expert
// ---------------------------------------------------------------------------
kernel void quantized_matmul_simd_bf16_expert(
    device const bfloat*   input             [[buffer(0)]],
    device const uint8_t*  packed_weights    [[buffer(1)]],
    device const uint8_t*  scales_raw        [[buffer(2)]],
    device const uint8_t*  biases_raw        [[buffer(3)]],
    device bfloat*         output            [[buffer(4)]],
    constant QuantizedMatmulParams& params   [[buffer(5)]],
    constant uint& expert_offset             [[buffer(6)]],
    constant uint& scales_offset             [[buffer(7)]],
    constant uint& biases_offset             [[buffer(8)]],
    uint3 tid       [[threadgroup_position_in_grid]],
    uint  simd_gid  [[simdgroup_index_in_threadgroup]],
    uint  simd_lid  [[thread_index_in_simdgroup]]
) {
    const uint num_simdgroups = 2;
    const uint results_per_simdgroup = 4;

    const uint K = params.K;
    const uint N = params.N;
    const uint bits = params.bits;
    const uint group_size = params.group_size;

    uint pack_factor;
    uint packs_per_thread;
    uint values_per_thread;
    uint bytes_per_pack;

    // ADR-002 fix: 4-bit params must match non-expert bf16 kernel (qmv_fast).
    // Previously had packs_per_thread=1, values_per_thread=8 which halved
    // block_size and corrupted all pointer arithmetic for 4-bit MoE layers.
    if (bits == 4) {
        pack_factor = 8;
        packs_per_thread = 2;       // qmv_fast: 2 (was incorrectly 1)
        values_per_thread = 16;     // qmv_fast: 8*2=16 (was incorrectly 8)
        bytes_per_pack = 4;
    } else if (bits == 8) {
        pack_factor = 4;
        packs_per_thread = 2;
        values_per_thread = 8;
        bytes_per_pack = 4;
    } else {
        // Unsupported in SIMD path — 6-bit must use the baseline kernel.
        return;
    }

    const uint block_size = values_per_thread * SIMD_SIZE_CONST;
    const uint scale_step_per_thread = group_size / values_per_thread;

    // Apply byte offsets to reach this expert's slice in the 3D buffers.
    const device uint8_t* ws = packed_weights + expert_offset;
    const device uint16_t* scales = (const device uint16_t*)(scales_raw + scales_offset);
    const device uint16_t* biases = (const device uint16_t*)(biases_raw + biases_offset);

    const uint row = tid.x;
    if (row >= params.M) return;

    const uint out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
                         simd_gid * results_per_simdgroup;
    if (out_row >= N) return;

    const uint in_vec_size_w = K * bytes_per_pack / pack_factor;
    const uint in_vec_size_g = K / group_size;

    const device uint8_t* ws_ptr = ws + out_row * in_vec_size_w +
                                   simd_lid * packs_per_thread * bytes_per_pack;
    const device uint16_t* sc_ptr = scales + out_row * in_vec_size_g +
                                    simd_lid / scale_step_per_thread;
    const device uint16_t* bi_ptr = biases + out_row * in_vec_size_g +
                                    simd_lid / scale_step_per_thread;

    const device bfloat* x_bf16 = input + row * K + simd_lid * values_per_thread;
    device bfloat* y_ptr = output + row * N + out_row;

    float result[4] = {0.0f, 0.0f, 0.0f, 0.0f};

    if (bits == 4) {
        float x_thread[16];  // ADR-002 fix: was [8], must be [16] for qmv_fast
        for (uint k = 0; k < K; k += block_size) {
            float x_sum;
            load_vector_4bit_bf16(x_bf16, x_thread, x_sum, values_per_thread);

            // out_row is simdgroup-uniform, so this break is non-divergent.
            for (uint r = 0; r < results_per_simdgroup; r++) {
                if (out_row + r >= N) break;  // partial block: stay inside this expert's slice
                const device uint8_t* wl = ws_ptr + r * in_vec_size_w;
                float s = float(as_type<bfloat>(sc_ptr[r * in_vec_size_g]));
                float b = float(as_type<bfloat>(bi_ptr[r * in_vec_size_g]));
                result[r] += qdot_4bit(wl, x_thread, s, b, x_sum, values_per_thread);
            }

            ws_ptr += block_size * bytes_per_pack / pack_factor;
            sc_ptr += block_size / group_size;
            bi_ptr += block_size / group_size;
            x_bf16 += block_size;
        }
    } else {
        float x_thread[8];
        for (uint k = 0; k < K; k += block_size) {
            float x_sum;
            load_vector_8bit_bf16(x_bf16, x_thread, x_sum, 8);

            for (uint r = 0; r < results_per_simdgroup; r++) {
                if (out_row + r >= N) break;  // partial block: stay inside this expert's slice
                const device uint8_t* wl = ws_ptr + r * in_vec_size_w;
                float s = float(as_type<bfloat>(sc_ptr[r * in_vec_size_g]));
                float b = float(as_type<bfloat>(bi_ptr[r * in_vec_size_g]));
                result[r] += qdot_8bit(wl, x_thread, s, b, x_sum, 8);
            }

            ws_ptr += block_size * bytes_per_pack / pack_factor;
            sc_ptr += block_size / group_size;
            bi_ptr += block_size / group_size;
            x_bf16 += block_size;
        }
    }

    // Reduce and write bf16 output.  simd_sum must be reached by every lane,
    // so only the write is guarded here.
    for (uint r = 0; r < results_per_simdgroup; r++) {
        result[r] = simd_sum(result[r]);
        if (simd_lid == 0) {
            if (out_row + r < N) {
                y_ptr[r] = bfloat(result[r]);
            }
        }
    }
}