mlx-native 0.7.1

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
// quantized_matmul_ggml.metal — GGML block-format quantized mat-vec kernels.
//
// Portions of this file are derived from candle-metal-kernels v0.10.2
// (https://github.com/huggingface/candle), Apache-2.0 licensed.
// Original source: llama.cpp's ggml-metal.metal, vendored in candle.
// Source: candle-metal-kernels/src/metal_src/quantized.metal
//
// Block struct definitions and dequantization formulas are byte-for-byte
// compatible with GGUF on-disk format. The kernel dispatch pattern is
// adapted to mlx-native's CommandEncoder API.
//
// Copyright the candle Authors and llama.cpp Authors.
// See LICENSE-APACHE-candle in this directory.

#include <metal_stdlib>
using namespace metal;

// ---- Constants ----

// Elements per quantization block. QK4_0 is also used as the block length by
// the Q5_1 and IQ4_NL structs and kernels in this file.
#define QK4_0 32
#define QK8_0 32
#define QK_K  256

// Row-tiling geometry used by the Q4_0, Q5_1, IQ4_NL and Q8_0 kernels in this
// file: first_row = (tgpig.x * N_SIMDGROUP + sgitg) * N_DST.
// (The Q6_K / Q4_K / Q5_K kernels hard-code row = 2*r0 + sgitg instead.)
#define N_DST       4   // output rows handled by each SIMD group
#define N_SIMDGROUP 2   // number of SIMD groups per threadgroup
#define N_SIMDWIDTH 32  // Apple GPU SIMD width

// Alternative Q8_0 geometry: 4 simdgroups × 2 rows = 8 rows/tg.
// Matches llama.cpp N_SG_Q8_0=4, N_R0_Q8_0=2.
// NOTE(review): no kernel visible in this file references these two macros;
// kernel_mul_mv_q8_0_f32 uses N_DST / N_SIMDGROUP. Confirm whether they are
// still needed.
#define N_DST_Q8       2   // each SIMD group works on 2 rows
#define N_SIMDGROUP_Q8 4   // 4 SIMD groups per threadgroup (128 threads)

// Packed parameter struct — matches Rust-side GgmlMatvecGpuParams.
// Every mat-vec kernel in this file receives it at buffer(3). Because it
// mirrors a host-side struct, field order and widths must not change.
struct GgmlMatvecParams {
    int64_t ne00; // K: number of values per weight row (before quantization)
    int64_t ne01; // N: number of weight rows (output dim)
    int64_t ne02; // batch dim for weights
    int64_t ne10; // K: number of values per input row (kernels use it as the src1 row stride)
    int64_t ne12; // batch dim for input (kernels split tgpig.z into i12 = z % ne12, i13 = z / ne12)
    int64_t ne0;  // output stride (= ne01)
    int64_t ne1;  // M: number of input rows
    uint    r2;   // ne12 / ne02 (kernels divide i12 by it to pick the weight batch)
    uint    r3;   // ne13 / ne03 (always 1 for non-batched)
};

// ---- GGML block struct definitions ----
// Byte-for-byte compatible with GGUF on-disk format.

// Q4_0 block: 32 weights in 18 bytes (size enforced by the static_assert).
// As consumed by block_q4_0_dot_y, weight = d * (nibble - 8); byte j of qs
// carries element j in its low nibble and element j + 16 in its high nibble.
typedef struct {
    half    d;              // delta (scale)
    uint8_t qs[QK4_0 / 2]; // 32 nibbles packed into 16 bytes
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(half) + QK4_0 / 2, "wrong q4_0 block size");

// Q8_0 block: 32 weights in 34 bytes (size enforced by the static_assert).
// kernel_mul_mv_q8_0_f32 multiplies the int8 dot product by d, i.e.
// weight = d * qs[i].
typedef struct {
    half   d;          // delta (scale)
    int8_t qs[QK8_0];  // 32 signed 8-bit quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(half) + QK8_0, "wrong q8_0 block size");

// Q6_K super-block: 256 weights in 210 bytes (size enforced by the
// static_assert). kernel_mul_mv_q6_K_f32 rebuilds each 6-bit value from one
// ql nibble plus two qh bits, subtracts 32, then applies a per-sub-block
// int8 scale and the super-block scale d.
typedef struct {
    uint8_t ql[QK_K/2];      // lower 4 bits of 6-bit values
    uint8_t qh[QK_K/4];      // upper 2 bits of 6-bit values
    int8_t  scales[QK_K/16]; // 8-bit sub-block scales
    half    d;                // super-block scale
} block_q6_K;
static_assert(sizeof(block_q6_K) == sizeof(half) + QK_K/16 + 3*QK_K/4, "wrong q6_K block size");

// Q4_K: 256 values per block, 144 bytes per block.
// Layout: [half d][half dmin][uint8_t scales[12]][uint8_t qs[128]]
//   d     : super-block scale for the 6-bit quantized sub-block scales
//   dmin  : super-block scale for the 6-bit quantized sub-block mins
//   scales: packed 6-bit (sub-scale, sub-min) pairs for 8 sub-blocks
//           (same K_SCALE_SIZE=12 byte layout shared with Q5_K, decoded
//            via the kmask1/kmask2/kmask3 machinery below).
//   qs    : 128 bytes of 4-bit quantized values. As read by
//           kernel_mul_mv_q4_K_f32, each run of 32 bytes covers 64 values:
//           the low nibble of byte j is value j of the run, the high nibble
//           is value j + 32.
//
// Q4_K is structurally Q5_K minus the 32-byte qh "high-bit" array.
//
// Source: ggml-common.h block_q4_K (llama.cpp).
// K_SCALE_SIZE is also used by block_q5_K.
#define K_SCALE_SIZE 12
typedef struct {
    half    d;                    // super-block scale for quantized scales
    half    dmin;                 // super-block scale for quantized mins
    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
    uint8_t qs[QK_K/2];           // quants, low 4 bits (128 bytes)
} block_q4_K;
// Guards the on-disk size (144 bytes).
static_assert(sizeof(block_q4_K) == 2*sizeof(half) + K_SCALE_SIZE + QK_K/2,
              "wrong q4_K block size");

// ADR-022 Phase 2 — Q5_K block (176 bytes, 256 values).
// Layout: [half d][half dmin][uint8_t scales[12]][uint8_t qh[32]][uint8_t qs[128]]
// Adds a 32-byte qh "high-bit" array vs Q4_K. The high bit OR'd into each
// dequantized 4-bit nibble lifts the value range from [0,15] to [0,31].
// kernel_mul_mv_q5_K_f32 applies that bit as "+16 * y" for every element
// whose qh bit is set.
typedef struct {
    half    d;                    // super-block scale for the quantized scales
    half    dmin;                 // super-block scale for the quantized mins
    uint8_t scales[K_SCALE_SIZE]; // packed 6-bit scales and mins (same packing as Q4_K)
    uint8_t qh[QK_K/8];           // one high (5th) bit per value
    uint8_t qs[QK_K/2];           // low 4 bits of each value, two per byte
} block_q5_K;
// Guards the on-disk size (176 bytes).
static_assert(sizeof(block_q5_K) == 2*sizeof(half) + K_SCALE_SIZE + QK_K/8 + QK_K/2,
              "wrong q5_K block size");

// Q5_1 (ADR-022 Phase 1). 32 values per block, 24 bytes per block.
// As consumed by block_q5_1_dot_y: weight = d * q + m, where q is the 4-bit
// nibble from qs with bit e of qh supplying the 5th bit of element e.
typedef struct {
    half    d;              // scale
    half    m;              // min (added once per element via m * sum(y))
    uint    qh;             // 32 high bits, one per element
    uint8_t qs[QK4_0 / 2];  // 32 nibbles packed into 16 bytes
} block_q5_1;
// Guards the on-disk size (24 bytes).
static_assert(sizeof(block_q5_1) == 2*sizeof(half) + 4 + QK4_0/2,
              "wrong q5_1 block size");

// IQ4_NL (ADR-022 Phase 1). 32 values per block, 18 bytes per block.
// Same byte layout as block_q4_0, but each nibble is an index into the
// kvalues_iq4nl codebook: block_iq4_nl_dot_y computes
// weight = d * kvalues_iq4nl[nibble].
typedef struct {
    half    d;              // scale applied to the codebook value
    uint8_t qs[QK4_0 / 2];  // 32 codebook indices packed into 16 bytes
} block_iq4_nl;
// Guards the on-disk size (18 bytes).
static_assert(sizeof(block_iq4_nl) == sizeof(half) + QK4_0/2,
              "wrong iq4_nl block size");

// Frozen IQ4_NL codebook (ggml-common.h:1109-1112). Lock-step with
// host-side `KVALUES_IQ4_NL` in src/gguf/mod.rs and the duplicate in
// quantized_matmul_id_ggml.metal.
// Indexed by a 4-bit code (0..15); block_iq4_nl_dot_y multiplies the
// looked-up value by the block scale. Do not edit the values independently
// of the copies named above.
constant int8_t kvalues_iq4nl[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10,
    1, 13, 25, 38, 53, 69, 89, 113
};

// Q5_1 dot helper. Formula derivation: id_ggml.metal:135-179.
//
// Multiplies one thread's 16 activations with the matching 5-bit quants of
// *qb and returns d * sum(y * q) + m * sumy.
//
//   qb   : Q5_1 block to read
//   sumy : un-scaled sum of the same 16 activations (weights the block min m)
//   yl   : the 16 activations; the callers in this file pre-divide
//          yl[odd < 8] by 256, yl[even >= 8] by 16 and yl[odd >= 8] by 4096,
//          so the masked-but-unshifted quants below need no shift
//   il   : byte offset of this thread's slice inside qs (callers pass 0 or 8)
inline float block_q5_1_dot_y(
    device const block_q5_1 * qb,
    float sumy,
    thread float * yl,
    int il
) {
    const float scale   = qb->d;
    const float minimum = qb->m;
    const uint  fifth   = qb->qh;

    // d, m and qh occupy the first 8 bytes (4 uint16_t) of the block; qs follows.
    device const uint16_t * words = ((device const uint16_t *)qb + 4 + il/2);

    float4 partial = 0.f;
    for (int k = 0; k < 4; ++k) {
        const int      i = 2 * k;
        const uint16_t w = words[k];

        // Each quant keeps its nibble where it sits in the 16-bit word and
        // ORs in the element's qh bit one position above that nibble.
        const uint q_000f = (w & 0x000F) | (((fifth >> (i + 0 + il      )) << 4 ) & 0x0010);
        const uint q_0f00 = (w & 0x0F00) | (((fifth >> (i + 1 + il      )) << 12) & 0x1000);
        const uint q_00f0 = (w & 0x00F0) | (((fifth >> (i + 0 + il + 16)) << 8 ) & 0x0100);
        const uint q_f000 = (w & 0xF000) | (((fifth >> (i + 1 + il + 16)) << 16) & 0x10000);

        partial[0] += yl[i + 0] * (float)q_000f;
        partial[1] += yl[i + 1] * (float)q_0f00;
        partial[2] += yl[i + 8] * (float)q_00f0;
        partial[3] += yl[i + 9] * (float)q_f000;
    }

    return scale * (partial[0] + partial[1] + partial[2] + partial[3]) + sumy * minimum;
}

// IQ4_NL dot helper. Rationale for the codebook lookup (raw yl[], no
// pre-scale, non-linear): id_ggml.metal:181-211.
//
//   qb     : IQ4_NL block to read
//   yl_raw : 16 un-scaled activations; entries 0..7 pair with the low nibbles
//            and entries 8..15 with the high nibbles of the 8 bytes read
//   il     : byte offset of this thread's slice inside qs (callers pass 0 or 8)
//
// Returns d * sum(y * kvalues_iq4nl[code]).
inline float block_iq4_nl_dot_y(
    device const block_iq4_nl * qb,
    thread float * yl_raw,
    int il
) {
    const float scale = qb->d;
    device const uint8_t * codes = qb->qs + il;

    float dot = 0.f;
    for (int j = 0; j < 8; ++j) {
        const uint8_t packed = codes[j];
        const float   lo     = (float)kvalues_iq4nl[packed & 0x0F];
        const float   hi     = (float)kvalues_iq4nl[(packed >> 4) & 0x0F];
        dot += yl_raw[j]     * lo;
        dot += yl_raw[j + 8] * hi;
    }
    return scale * dot;
}

// ---- Q4_0 mat-vec kernel ----
//
// Each SIMD group (32 threads) processes N_DST=4 rows.
// Two SIMD groups per threadgroup => 8 rows per threadgroup.
// Each thread processes half a Q4_0 block (16 nibbles).
//
// Dispatch: threadgroups=(ceil(N/8), M, B), threads_per_tg=(8, 8, 1)

// Q4_0 dot helper (ADR-009 Phase 3A).
//
// Keeps four independent accumulators — one per nibble position inside each
// 16-bit word — rather than two paired ones, so the floating-point rounding
// is identical to llama.cpp's block_q_n_dot_y for block_q4_0. Do not merge
// or reorder the accumulators.
//
//   qb   : Q4_0 block to read
//   sumy : un-scaled sum of this thread's 16 activations (carries the -8 offset)
//   yl   : the 16 activations, pre-divided by the caller so that masked but
//          unshifted nibbles can be multiplied directly
//   il   : byte offset of this thread's slice inside qs (callers pass 0 or 8)
//
// Returns d * (sum(y * nibble) - 8 * sumy).
inline float block_q4_0_dot_y(
    device const block_q4_0 * qb,
    float sumy,
    thread float * yl,
    int il
) {
    const float scale = qb->d;

    // qs starts one uint16_t (the half scale) into the block.
    device const uint16_t * words = ((device const uint16_t *)qb + 1 + il/2);

    float sum_000f = 0.0f;
    float sum_0f00 = 0.0f;
    float sum_00f0 = 0.0f;
    float sum_f000 = 0.0f;
    for (int k = 0; k < 4; ++k) {
        const int      i = 2 * k;
        const uint16_t w = words[k];
        sum_000f += yl[i + 0] * (w & 0x000F);
        sum_0f00 += yl[i + 1] * (w & 0x0F00);
        sum_00f0 += yl[i + 8] * (w & 0x00F0);
        sum_f000 += yl[i + 9] * (w & 0xF000);
    }

    return scale * (sumy * -8.f + sum_000f + sum_0f00 + sum_00f0 + sum_f000);
}

// Q4_0 matrix-vector product.
//
// Each simdgroup owns N_DST consecutive weight rows starting at first_row.
// Lanes 2k and 2k+1 share block k (then k+16, k+32, ...): the even lane
// takes qs bytes 0..7, the odd lane bytes 8..15.
//
//   src0  : block_q4_0 weights, ne00/QK4_0 blocks per row
//   src1  : f32 activations
//   dst   : f32 output, written by lane 0 of each simdgroup
//   p     : shape/stride parameters
//   tgpig : x = row tile, y = activation row, z = batch index
//   tiisg : lane index within the simdgroup
//   sgitg : simdgroup index within the threadgroup
kernel void kernel_mul_mv_q4_0_f32(
    device const  void  * src0   [[buffer(0)]],
    device const float  * src1   [[buffer(1)]],
    device       float  * dst    [[buffer(2)]],
    constant GgmlMatvecParams & p [[buffer(3)]],
    uint3 tgpig [[threadgroup_position_in_grid]],
    uint  tiisg [[thread_index_in_simdgroup]],
    uint  sgitg [[simdgroup_index_in_threadgroup]]
) {
    const int rows_per_sg = N_DST;
    const int sg_per_tg   = N_SIMDGROUP;
    const int simd_width  = N_SIMDWIDTH;

    const int blocks_per_row = p.ne00 / QK4_0;
    const int tile   = tgpig.x;
    const int in_row = tgpig.y;
    const int batch  = tgpig.z;

    const int first_row = (tile * sg_per_tg + sgitg) * rows_per_sg;

    // Split the flat batch index; r2 / r3 map it onto the weight batch.
    const uint b12 = batch % p.ne12;
    const uint b13 = batch / p.ne12;

    const uint w_off = first_row * blocks_per_row
                     + (b12/p.r2)*(blocks_per_row*p.ne01)
                     + (b13/p.r3)*(blocks_per_row*p.ne01*p.ne02);

    device const block_q4_0 * w_blocks = (device const block_q4_0 *) src0 + w_off;
    device const float      * act      = (device const float      *) src1 + in_row*p.ne10 + batch*p.ne00*p.ne1;

    float yl[16];
    float row_acc[rows_per_sg] = {0.f};

    const int lane_block = tiisg / 2;
    const int half_off   = (tiisg % 2) * 8;

    device const float * yb = act + lane_block * QK4_0 + half_off;

    for (int ib = lane_block; ib < blocks_per_row; ib += simd_width/2) {
        // ADR-009 Phase 3A: the activation sum is kept as two partial sums
        // (elements 0..15 side, elements 16..31 side) and combined afterwards,
        // mirroring llama.cpp so the rounding is identical.
        float sum_lo = 0.f;
        float sum_hi = 0.f;
        for (int i = 0; i < 8; i += 2) {
            const float a0 = yb[i];
            const float a1 = yb[i+1];
            const float b0 = yb[i+16];
            const float b1 = yb[i+17];

            sum_lo += a0 + a1;
            sum_hi += b0 + b1;

            // Pre-divide so the helper can use masked, unshifted nibbles.
            yl[i+0] = a0;
            yl[i+1] = a1 / 256.f;
            yl[i+8] = b0 / 16.f;
            yl[i+9] = b1 / 4096.f;
        }
        const float sumy = sum_lo + sum_hi;

        for (int r = 0; r < rows_per_sg; ++r) {
            row_acc[r] += block_q4_0_dot_y(w_blocks + ib + r*blocks_per_row, sumy, yl, half_off);
        }

        // Advance by the 16 blocks the simdgroup consumed this iteration.
        yb += QK4_0 * 16;
    }

    const int64_t out_base = batch*p.ne0*p.ne1 + in_row*p.ne0 + first_row;
    for (int r = 0; r < rows_per_sg; ++r) {
        const float total = simd_sum(row_acc[r]);
        // Only lane 0 stores, and only for rows that exist.
        const bool stores = (tiisg == 0) && (first_row + r < p.ne01);
        if (stores) {
            dst[out_base + r] = total;
        }
    }
}

// ---- Q5_1 mat-vec kernel (ADR-022 Phase 1) ----
//
// Row tiling and lane-to-block mapping are the same as the Q4_0 kernel; the
// differences are the block type walked and the dot helper. Q5_1 has a
// per-block min `m`, which enters the result as `m * sumy`, and a qh word
// that supplies the 5th bit of every quant (both handled in
// block_q5_1_dot_y).
//
//   src0  : block_q5_1 weights, ne00/QK4_0 blocks per row
//   src1  : f32 activations
//   dst   : f32 output, written by lane 0 of each simdgroup
//   p     : shape/stride parameters
//   tgpig : x = row tile, y = activation row, z = batch index

kernel void kernel_mul_mv_q5_1_f32(
    device const  void  * src0   [[buffer(0)]],
    device const float  * src1   [[buffer(1)]],
    device       float  * dst    [[buffer(2)]],
    constant GgmlMatvecParams & p [[buffer(3)]],
    uint3 tgpig [[threadgroup_position_in_grid]],
    uint  tiisg [[thread_index_in_simdgroup]],
    uint  sgitg [[simdgroup_index_in_threadgroup]]
) {
    const int rows_per_sg = N_DST;
    const int sg_per_tg   = N_SIMDGROUP;
    const int simd_width  = N_SIMDWIDTH;

    const int blocks_per_row = p.ne00 / QK4_0;
    const int tile   = tgpig.x;
    const int in_row = tgpig.y;
    const int batch  = tgpig.z;

    const int first_row = (tile * sg_per_tg + sgitg) * rows_per_sg;

    const uint b12 = batch % p.ne12;
    const uint b13 = batch / p.ne12;

    const uint w_off = first_row * blocks_per_row
                     + (b12/p.r2)*(blocks_per_row*p.ne01)
                     + (b13/p.r3)*(blocks_per_row*p.ne01*p.ne02);

    device const block_q5_1 * w_blocks = (device const block_q5_1 *) src0 + w_off;
    device const float      * act      = (device const float      *) src1 + in_row*p.ne10 + batch*p.ne00*p.ne1;

    float yl[16];
    float row_acc[rows_per_sg] = {0.f};

    const int lane_block = tiisg / 2;
    const int half_off   = (tiisg % 2) * 8;

    device const float * yb = act + lane_block * QK4_0 + half_off;

    for (int ib = lane_block; ib < blocks_per_row; ib += simd_width/2) {
        // Single running sum; the two additions per step stay in this order.
        float sumy = 0.f;
        for (int i = 0; i < 8; i += 2) {
            const float a0 = yb[i];
            const float a1 = yb[i+1];
            const float b0 = yb[i+16];
            const float b1 = yb[i+17];

            sumy += a0 + a1;
            sumy += b0 + b1;

            // Pre-divide so the helper can use masked, unshifted quants.
            yl[i+0] = a0;
            yl[i+1] = a1 / 256.f;
            yl[i+8] = b0 / 16.f;
            yl[i+9] = b1 / 4096.f;
        }

        for (int r = 0; r < rows_per_sg; ++r) {
            row_acc[r] += block_q5_1_dot_y(w_blocks + ib + r*blocks_per_row, sumy, yl, half_off);
        }

        // Advance by the 16 blocks the simdgroup consumed this iteration.
        yb += QK4_0 * 16;
    }

    const int64_t out_base = batch*p.ne0*p.ne1 + in_row*p.ne0 + first_row;
    for (int r = 0; r < rows_per_sg; ++r) {
        const float total = simd_sum(row_acc[r]);
        // Only lane 0 stores, and only for rows that exist.
        const bool stores = (tiisg == 0) && (first_row + r < p.ne01);
        if (stores) {
            dst[out_base + r] = total;
        }
    }
}

// ---- IQ4_NL mat-vec kernel (ADR-022 Phase 1) ----
//
// Row tiling and lane-to-block mapping are the same as the Q4_0 kernel.
// Because the codebook lookup is non-linear, activations are handed to
// block_iq4_nl_dot_y un-scaled (no /16, /256, /4096 pre-division).
//
//   src0  : block_iq4_nl weights, ne00/QK4_0 blocks per row
//   src1  : f32 activations
//   dst   : f32 output, written by lane 0 of each simdgroup
//   p     : shape/stride parameters
//   tgpig : x = row tile, y = activation row, z = batch index

kernel void kernel_mul_mv_iq4_nl_f32(
    device const  void  * src0   [[buffer(0)]],
    device const float  * src1   [[buffer(1)]],
    device       float  * dst    [[buffer(2)]],
    constant GgmlMatvecParams & p [[buffer(3)]],
    uint3 tgpig [[threadgroup_position_in_grid]],
    uint  tiisg [[thread_index_in_simdgroup]],
    uint  sgitg [[simdgroup_index_in_threadgroup]]
) {
    const int rows_per_sg = N_DST;
    const int sg_per_tg   = N_SIMDGROUP;
    const int simd_width  = N_SIMDWIDTH;

    const int blocks_per_row = p.ne00 / QK4_0;
    const int tile   = tgpig.x;
    const int in_row = tgpig.y;
    const int batch  = tgpig.z;

    const int first_row = (tile * sg_per_tg + sgitg) * rows_per_sg;

    const uint b12 = batch % p.ne12;
    const uint b13 = batch / p.ne12;

    const uint w_off = first_row * blocks_per_row
                     + (b12/p.r2)*(blocks_per_row*p.ne01)
                     + (b13/p.r3)*(blocks_per_row*p.ne01*p.ne02);

    device const block_iq4_nl * w_blocks = (device const block_iq4_nl *) src0 + w_off;
    device const float        * act      = (device const float        *) src1 + in_row*p.ne10 + batch*p.ne00*p.ne1;

    float yl_raw[16];
    float row_acc[rows_per_sg] = {0.f};

    const int lane_block = tiisg / 2;
    const int half_off   = (tiisg % 2) * 8;

    device const float * yb = act + lane_block * QK4_0 + half_off;

    for (int ib = lane_block; ib < blocks_per_row; ib += simd_width/2) {
        // Entries 0..7 pair with low nibbles, 8..15 with high nibbles
        // (the elements 16 positions further into the block).
        for (int j = 0; j < 8; ++j) {
            yl_raw[j] = yb[j];
        }
        for (int j = 0; j < 8; ++j) {
            yl_raw[j + 8] = yb[j + 16];
        }

        for (int r = 0; r < rows_per_sg; ++r) {
            row_acc[r] += block_iq4_nl_dot_y(w_blocks + ib + r*blocks_per_row, yl_raw, half_off);
        }

        // Advance by the 16 blocks the simdgroup consumed this iteration.
        yb += QK4_0 * 16;
    }

    const int64_t out_base = batch*p.ne0*p.ne1 + in_row*p.ne0 + first_row;
    for (int r = 0; r < rows_per_sg; ++r) {
        const float total = simd_sum(row_acc[r]);
        // Only lane 0 stores, and only for rows that exist.
        const bool stores = (tiisg == 0) && (first_row + r < p.ne01);
        if (stores) {
            dst[out_base + r] = total;
        }
    }
}

// ---- Q8_0 mat-vec kernel ----
//
// This is the stock candle kernel geometry and reduction path used by the
// old passing TQ stack. Dispatch: threadgroups=(ceil(N/8), M, B),
// threads_per_tg=(8, 8, 1). No threadgroup shared memory.

#define NB_Q8_0 8

// Q8_0 matrix-vector product.
//
// Each simdgroup owns N_DST consecutive weight rows starting at first_row.
// Four lanes share one block: lane quarter q (0..3) handles the NB_Q8_0
// quants starting at q * NB_Q8_0, and the simdgroup advances N_SIMDWIDTH/4
// blocks per iteration.
//
//   src0  : block_q8_0 weights, ne00/QK8_0 blocks per row
//   src1  : f32 activations
//   dst   : f32 output, written by lane 0 of each simdgroup
//   p     : shape/stride parameters
//   tgpig : x = row tile, y = activation row, z = batch index
kernel void kernel_mul_mv_q8_0_f32(
    device const  void  * src0   [[buffer(0)]],
    device const float  * src1   [[buffer(1)]],
    device       float  * dst    [[buffer(2)]],
    constant GgmlMatvecParams & p [[buffer(3)]],
    uint3 tgpig [[threadgroup_position_in_grid]],
    uint  tiisg [[thread_index_in_simdgroup]],
    uint  sgitg [[simdgroup_index_in_threadgroup]]
) {
    const int rows_per_sg = N_DST;
    const int sg_per_tg   = N_SIMDGROUP;
    const int simd_width  = N_SIMDWIDTH;

    const int blocks_per_row = p.ne00 / QK8_0;
    const int tile   = tgpig.x;
    const int in_row = tgpig.y;
    const int batch  = tgpig.z;

    const int first_row = (tile * sg_per_tg + sgitg) * rows_per_sg;

    const uint b12 = batch % p.ne12;
    const uint b13 = batch / p.ne12;

    const uint w_off = first_row * blocks_per_row
                     + (b12 / p.r2) * (blocks_per_row * p.ne01)
                     + (b13 / p.r3) * (blocks_per_row * p.ne01 * p.ne02);

    device const block_q8_0 * w_blocks = (device const block_q8_0 *) src0 + w_off;
    device const float      * act      = (device const float      *) src1 + in_row * p.ne10 + batch * p.ne00 * p.ne1;

    float yl[NB_Q8_0];
    float row_acc[rows_per_sg] = {0.f};

    const int lane_block = tiisg / 4;
    const int quarter    = tiisg % 4;

    device const float * yb = act + lane_block * QK8_0 + NB_Q8_0 * quarter;

    for (int ib = lane_block; ib < blocks_per_row; ib += simd_width / 4) {
        // Cache this lane's activations once; they are reused for every row.
        for (int k = 0; k < NB_Q8_0; ++k) {
            yl[k] = yb[k];
        }

        for (int r = 0; r < rows_per_sg; ++r) {
            device const block_q8_0 * blk = w_blocks + ib + r * blocks_per_row;
            device const int8_t     * q   = blk->qs + NB_Q8_0 * quarter;

            float partial = 0.f;
            for (int k = 0; k < NB_Q8_0; ++k) {
                partial += q[k] * yl[k];
            }
            row_acc[r] += partial * blk->d;
        }

        // simd_width lanes * NB_Q8_0 values = the blocks consumed per iteration.
        yb += NB_Q8_0 * simd_width;
    }

    for (int r = 0; r < rows_per_sg; ++r) {
        const float total = simd_sum(row_acc[r]);
        // Only lane 0 stores, and only for rows that exist.
        const bool stores = (tiisg == 0) && (first_row + r < p.ne01);
        if (stores) {
            dst[in_row * p.ne0 + batch * p.ne0 * p.ne1 + first_row + r] = total;
        }
    }
}

// ---- Q6_K mat-vec kernel ----
//
// Dispatch: threadgroups=(ceil(N/2), M, B), threads_per_tg=(2, 32, 1)
// Each threadgroup handles 2 rows (one per SIMD group): row = 2*r0 + sgitg.
//
// Because the grid is sized with ceil(N/2), an odd N produces one simdgroup
// whose row equals N. Its result must not be stored: with ne0 == ne01,
// dst[... + N] is row 0 of the next input row's output, or lies past the end
// of dst. The store is therefore guarded with row < ne01, the same guard the
// Q4_K and Q5_K kernels in this file use.
// NOTE(review): the weight loads for such an out-of-range row are still
// issued (as in the sibling kernels); confirm this is acceptable for the
// host-side buffer allocation.
//
//   src0  : block_q6_K weights, ne00/QK_K super-blocks per row
//   src1  : f32 activations
//   dst   : f32 output, written by lane 0 of each simdgroup
//   p     : shape/stride parameters
//   tgpig : x = row pair, y = activation row, z = batch index
//   tiisg : lane index within the simdgroup
//   sgitg : simdgroup index within the threadgroup

kernel void kernel_mul_mv_q6_K_f32(
    device const  void  * src0   [[buffer(0)]],
    device const float  * src1   [[buffer(1)]],
    device       float  * dst    [[buffer(2)]],
    constant GgmlMatvecParams & p [[buffer(3)]],
    uint3 tgpig [[threadgroup_position_in_grid]],
    uint  tiisg [[thread_index_in_simdgroup]],
    uint  sgitg [[simdgroup_index_in_threadgroup]]
) {
    // Each mask selects one 2-bit field of a qh byte.
    const uint8_t kmask1 = 0x03;
    const uint8_t kmask2 = 0x0C;
    const uint8_t kmask3 = 0x30;
    const uint8_t kmask4 = 0xC0;

    const int nb = p.ne00 / QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int     im = tgpig.z;

    const int row = 2 * r0 + sgitg;

    // Split the flat batch index; r2 / r3 map it onto the weight batch.
    const uint i12 = im % p.ne12;
    const uint i13 = im / p.ne12;

    const uint offset0 = (i12/p.r2)*(nb*p.ne01) + (i13/p.r3)*(nb*p.ne01*p.ne02);

    device const block_q6_K * x  = (device const block_q6_K *) src0 + row * nb + offset0;
    device const float      * yy = (device const float      *) src1 + r1*p.ne10 + im*p.ne00*p.ne1;

    float sumf = 0;

    // Lane partitioning:
    //   ix  = tiisg % 2  -> even/odd super-blocks (loop stride 2)
    //   tid = tiisg / 2  (0..15)
    //   ip  = tid / 8    -> which 128-value half of the super-block
    //   il  = tid % 8    -> which 4-value slice (l0 = 4*il) within each 32-run
    const int tid  = tiisg / 2;
    const int ix   = tiisg % 2;
    const int ip   = tid / 8;
    const int il   = tid % 8;
    const int n    = 4;
    const int l0   = n * il;
    const int is   = 8*ip + l0/16;

    const int y_offset   = 128*ip + l0;
    const int q_offset_l = 64*ip + l0;
    const int q_offset_h = 32*ip + l0;

    for (int i = ix; i < nb; i += 2) {
        device const uint8_t * q1 = x[i].ql + q_offset_l;
        device const uint8_t * q2 = q1 + 32;
        device const uint8_t * qh = x[i].qh + q_offset_h;
        device const int8_t  * sc = x[i].scales + is;

        device const float * y = yy + i * QK_K + y_offset;

        const float dall = x[i].d;

        // One accumulator per 32-value run (offsets 0, 32, 64, 96); each
        // quant is a ql nibble plus a 2-bit qh field, re-centred by -32.
        float4 sums = {0.f, 0.f, 0.f, 0.f};
        for (int l = 0; l < n; ++l) {
            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
            sums[2] += y[l+64] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
            sums[3] += y[l+96] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
        }

        // The four runs use every second sub-block scale starting at `is`.
        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
    }

    const float tot = simd_sum(sumf);
    // Lane 0 stores the reduced value, but only for rows that exist: with an
    // odd ne01 the last threadgroup's second simdgroup has row == ne01.
    if (tiisg == 0 && row < (int)p.ne01) {
        dst[r1*p.ne0 + im*p.ne0*p.ne1 + row] = tot;
    }
}

// ---- Q4_K mat-vec kernel ----
//
// ADR-013 P7 — port of llama.cpp `kernel_mul_mv_q4_K_f32_impl`
// (ggml-metal.metal:7715-7821). Algorithm: for each weight row, decode
// the 8 sub-block (scale, min) 6-bit pairs from the packed 12-byte
// `scales` array, dequant and dot-product against the input vector.
//
// Geometry (mirrors Q5_K mv_id pattern):
//   NSG        = 2 simdgroups per threadgroup
//   nr0_per_sg = 1 row per simdgroup
//   rows/tg    = 2  (one per simdgroup; row = 2*r0 + sgitg)
// Dispatch:    threadgroups=(ceil(N/2), M, B), threads_per_tg=(2, 32, 1)
//
// Scale-decode is identical to Q5_K's: same kmask1=0x3f3f, kmask2=0x0f0f,
// kmask3=0xc0c0, same `sc16[]` packing. Q4_K differs from Q5_K only by
// the absence of the `qh` (high-bit) accumulators — the inner loop
// reduces to (q1[l] & 0x0F) and (q1[l] & 0xF0) >> 4 paired with the
// pre-summed yl/yh/sumy.
//
// Arguments:
//   src0  : block_q4_K weights, ne00/QK_K super-blocks per row
//   src1  : f32 activations
//   dst   : f32 output, written by lane 0 of each simdgroup
//   p     : shape/stride parameters
//   tgpig : x = row pair, y = activation row, z = batch index
//   tiisg : lane index within the simdgroup
//   sgitg : simdgroup index within the threadgroup

kernel void kernel_mul_mv_q4_K_f32(
    device const  void  * src0   [[buffer(0)]],
    device const float  * src1   [[buffer(1)]],
    device       float  * dst    [[buffer(2)]],
    constant GgmlMatvecParams & p [[buffer(3)]],
    uint3 tgpig [[threadgroup_position_in_grid]],
    uint  tiisg [[thread_index_in_simdgroup]],
    uint  sgitg [[simdgroup_index_in_threadgroup]]
) {
    const int nb = p.ne00 / QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int     im = tgpig.z;

    const int row = 2 * (int)r0 + (int)sgitg;

    // Split the flat batch index; r2 / r3 map it onto the weight batch.
    const uint i12 = im % p.ne12;
    const uint i13 = im / p.ne12;

    const uint offset0 = (i12/p.r2)*(nb*p.ne01) + (i13/p.r3)*(nb*p.ne01*p.ne02);

    device const block_q4_K * x  = (device const block_q4_K *) src0 + row * nb + offset0;
    device const float      * yy = (device const float      *) src1 + r1*p.ne10 + im*p.ne00*p.ne1;

    float sumf = 0.f;

    const uint16_t kmask1 = 0x3f3f;
    const uint16_t kmask2 = 0x0f0f;
    const uint16_t kmask3 = 0xc0c0;

    // tiisg ∈ [0, 31].  Same partitioning as Q5_K mv_id:
    //   tid = tiisg/4 (0..7)
    //   ix  = tiisg%4 (0..3)  → block stride = 4
    //   iq  = tid/4    (0..1) → which half of the super-block (low/high)
    //   ir  = tid%4    (0..3) → which 8-element slice within iq's half
    const int tid = tiisg / 4;
    const int ix  = tiisg % 4;
    const int iq  = tid / 4;
    const int ir  = tid % 4;
    const int n   = 8;

    const int l0       = n * ir;
    const int q_offset = 32 * iq + l0;
    const int y_offset = 64 * iq + l0;

    // sc16 receives the decoded scales/mins as packed byte pairs; sc8 is a
    // byte-wise view of the same storage used in the final reduction.
    uint16_t sc16[4];
    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;

    device const float * y1 = yy + ix * QK_K + y_offset;

    for (int i = ix; i < nb; i += 4) {
        device const uint8_t  * q1 = x[i].qs + q_offset;
        device const uint8_t  * q2 = q1 + 64;
        // dh[0] is d, dh[1] is dmin (adjacent halves at the start of the block).
        device const half     * dh = &x[i].d;
        // Read packed 6-bit scales/mins as 6 uint16_ts; iq selects
        // which half of the super-block we're decoding.
        device const uint16_t * a  = (device const uint16_t *)x[i].scales + iq;

        // y1 addresses the 64-value run paired with q1; y2 the run 128
        // values later, paired with q2.
        device const float * y2 = y1 + 128;
        float yl[16], yh[16];
        // One activation sum per sub-block touched; these multiply the mins.
        float4 sumy = {0.f, 0.f, 0.f, 0.f};
        for (int l = 0; l < n; ++l) {
            yl[l+0] = y1[l +  0]; sumy[0] += yl[l+0];
            yl[l+8] = y1[l + 32]; sumy[1] += yl[l+8];
            yh[l+0] = y2[l +  0]; sumy[2] += yh[l+0];
            yh[l+8] = y2[l + 32]; sumy[3] += yh[l+8];
        }

        // Unpack the 6-bit values: sc16[0]/sc16[2] hold the scale byte pairs
        // (sc8[0], sc8[1], sc8[4], sc8[5]), sc16[1]/sc16[3] the min byte pairs
        // (sc8[2], sc8[3], sc8[6], sc8[7]), as used in the reduction below.
        sc16[0] = a[0] & kmask1;
        sc16[1] = a[2] & kmask1;
        sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
        sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);

        float4 acc1 = {0.f, 0.f, 0.f, 0.f};
        for (int l = 0; l < n; ++l) {
            // q1 bytes: low nibble pairs with y1[l], high nibble with
            // y1[l + 32]; q2 bytes do the same for the run 128 values later.
            // No qh: Q4_K has no high-bit array, so the Q5_K formula's
            // acc2 (high-bit) accumulators collapse to zero; only the
            // raw nibble dot-products contribute.
            // High nibbles are left unshifted here (value * 16); the / 16.f
            // in the reduction below compensates.
            acc1[0] += yl[l+0] * (float)(q1[l] & 0x0F);
            acc1[1] += yl[l+8] * (float)(q1[l] & 0xF0);
            acc1[2] += yh[l+0] * (float)(q2[l] & 0x0F);
            acc1[3] += yh[l+8] * (float)(q2[l] & 0xF0);
        }

        const float dall = (float)dh[0];
        const float dmin = (float)dh[1];
        // d * sum(scale * dot) - dmin * sum(min * sumy).
        sumf += dall * ((float)sc8[0] * (acc1[0]        ) +
                        (float)sc8[1] * (acc1[1] / 16.f ) +
                        (float)sc8[4] * (acc1[2]        ) +
                        (float)sc8[5] * (acc1[3] / 16.f )) -
               dmin * (sumy[0] * (float)sc8[2] + sumy[1] * (float)sc8[3] +
                       sumy[2] * (float)sc8[6] + sumy[3] * (float)sc8[7]);

        // This lane's next super-block is 4 blocks ahead (loop stride 4).
        y1 += 4 * QK_K;
    }

    const float tot = simd_sum(sumf);
    // Lane 0 stores; the row check drops the extra row created by ceil(N/2)
    // when ne01 is odd.
    if (tiisg == 0 && row < (int)p.ne01) {
        dst[r1*p.ne0 + im*p.ne0*p.ne1 + row] = tot;
    }
}

// ---- Q5_K dense mat-vec kernel (ADR-022 Phase 2) ----
//
// Port of llama.cpp `kernel_mul_mv_q5_K_f32_impl` (ggml-metal.metal:7837).
// Body is `kernel_mul_mv_q4_K_f32` (above) plus the Q5_K mv_id qh/acc2
// high-bit accumulation block — the only structural delta between Q4_K
// and Q5_K. The geometry, scale-decode (kmask1/2/3 + sc16 packing), and
// final dall/dmin reduction are byte-identical.
//
// Arguments:
//   src0  : block_q5_K weights, ne00/QK_K super-blocks per row
//   src1  : f32 activations
//   dst   : f32 output, written by lane 0 of each simdgroup
//   p     : shape/stride parameters
//   tgpig : x = row pair (row = 2*x + sgitg), y = activation row, z = batch index
//   tiisg : lane index within the simdgroup
//   sgitg : simdgroup index within the threadgroup

kernel void kernel_mul_mv_q5_K_f32(
    device const  void  * src0   [[buffer(0)]],
    device const float  * src1   [[buffer(1)]],
    device       float  * dst    [[buffer(2)]],
    constant GgmlMatvecParams & p [[buffer(3)]],
    uint3 tgpig [[threadgroup_position_in_grid]],
    uint  tiisg [[thread_index_in_simdgroup]],
    uint  sgitg [[simdgroup_index_in_threadgroup]]
) {
    const int nb = p.ne00 / QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int     im = tgpig.z;

    const int row = 2 * (int)r0 + (int)sgitg;

    // Split the flat batch index; r2 / r3 map it onto the weight batch.
    const uint i12 = im % p.ne12;
    const uint i13 = im / p.ne12;

    const uint offset0 = (i12/p.r2)*(nb*p.ne01) + (i13/p.r3)*(nb*p.ne01*p.ne02);

    device const block_q5_K * x  = (device const block_q5_K *) src0 + row * nb + offset0;
    device const float      * yy = (device const float      *) src1 + r1*p.ne10 + im*p.ne00*p.ne1;

    float sumf = 0.f;

    const uint16_t kmask1 = 0x3f3f;
    const uint16_t kmask2 = 0x0f0f;
    const uint16_t kmask3 = 0xc0c0;

    // Lane partitioning (tiisg ∈ [0, 31]):
    //   tid = tiisg/4 (0..7)
    //   ix  = tiisg%4 (0..3) → super-block loop stride of 4
    //   iq  = tid/4   (0..1) → selects the q/y offsets and the qh bit pair
    //   ir  = tid%4   (0..3) → 8-element slice (l0 = 8*ir)
    const int tid = tiisg / 4;
    const int ix  = tiisg % 4;
    const int iq  = tid / 4;
    const int ir  = tid % 4;
    const int n   = 8;

    const int l0       = n * ir;
    const int q_offset = 32 * iq + l0;
    const int y_offset = 64 * iq + l0;

    // Bit masks into a qh byte, one per accumulator: hm1/hm2 are bits
    // 2*iq and 2*iq+1; hm3/hm4 are the same bits shifted up by 4.
    const uint8_t hm1 = 1u << (2 * iq);
    const uint8_t hm2 = hm1 << 1;
    const uint8_t hm3 = hm1 << 4;
    const uint8_t hm4 = hm2 << 4;

    // sc16 receives the decoded scales/mins as packed byte pairs; sc8 is a
    // byte-wise view of the same storage used in the final reduction.
    uint16_t sc16[4];
    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;

    device const float * y1 = yy + ix * QK_K + y_offset;

    for (int i = ix; i < nb; i += 4) {
        device const uint8_t  * q1 = x[i].qs + q_offset;
        device const uint8_t  * q2 = q1 + 64;
        device const uint8_t  * qh = x[i].qh + l0;
        // dh[0] is d, dh[1] is dmin (adjacent halves at the start of the block).
        device const half     * dh = &x[i].d;
        device const uint16_t * a  = (device const uint16_t *)x[i].scales + iq;

        // y1 addresses the 64-value run paired with q1; y2 the run 128
        // values later, paired with q2.
        device const float * y2 = y1 + 128;
        float yl[16], yh[16];
        // One activation sum per sub-block touched; these multiply the mins.
        float4 sumy = {0.f, 0.f, 0.f, 0.f};
        for (int l = 0; l < n; ++l) {
            yl[l+0] = y1[l +  0]; sumy[0] += yl[l+0];
            yl[l+8] = y1[l + 32]; sumy[1] += yl[l+8];
            yh[l+0] = y2[l +  0]; sumy[2] += yh[l+0];
            yh[l+8] = y2[l + 32]; sumy[3] += yh[l+8];
        }

        // Unpack the 6-bit values: sc16[0]/sc16[2] hold the scale byte pairs
        // (sc8[0], sc8[1], sc8[4], sc8[5]), sc16[1]/sc16[3] the min byte pairs
        // (sc8[2], sc8[3], sc8[6], sc8[7]), as used in the reduction below.
        sc16[0] = a[0] & kmask1;
        sc16[1] = a[2] & kmask1;
        sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
        sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);

        // acc1: nibble dot products (high nibbles left unshifted, fixed by
        // the / 16.f below). acc2: sum of activations whose high bit is set;
        // each such element contributes an extra 16 to its quant.
        float4 acc1 = {0.f, 0.f, 0.f, 0.f};
        float4 acc2 = {0.f, 0.f, 0.f, 0.f};
        for (int l = 0; l < n; ++l) {
            uint8_t h = qh[l];
            acc1[0] += yl[l+0] * (float)(q1[l] & 0x0F);
            acc1[1] += yl[l+8] * (float)(q1[l] & 0xF0);
            acc1[2] += yh[l+0] * (float)(q2[l] & 0x0F);
            acc1[3] += yh[l+8] * (float)(q2[l] & 0xF0);
            acc2[0] += (h & hm1) ? yl[l+0] : 0.f;
            acc2[1] += (h & hm2) ? yl[l+8] : 0.f;
            acc2[2] += (h & hm3) ? yh[l+0] : 0.f;
            acc2[3] += (h & hm4) ? yh[l+8] : 0.f;
        }

        const float dall = (float)dh[0];
        const float dmin = (float)dh[1];
        // d * sum(scale * (nibble dot + 16 * high-bit sum)) - dmin * sum(min * sumy).
        sumf += dall * ((float)sc8[0] * (acc1[0]        + 16.f * acc2[0]) +
                        (float)sc8[1] * (acc1[1] / 16.f + 16.f * acc2[1]) +
                        (float)sc8[4] * (acc1[2]        + 16.f * acc2[2]) +
                        (float)sc8[5] * (acc1[3] / 16.f + 16.f * acc2[3])) -
               dmin * (sumy[0] * (float)sc8[2] + sumy[1] * (float)sc8[3] +
                       sumy[2] * (float)sc8[6] + sumy[3] * (float)sc8[7]);

        // This lane's next super-block is 4 blocks ahead (loop stride 4).
        y1 += 4 * QK_K;
    }

    const float tot = simd_sum(sumf);
    // Lane 0 stores; the row check drops the extra row created by ceil(N/2)
    // when ne01 is odd.
    if (tiisg == 0 && row < (int)p.ne01) {
        dst[r1*p.ne0 + im*p.ne0*p.ne1 + row] = tot;
    }
}