mlx-native 0.9.0

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
// dense_mm_bf16_tensor.metal — Dense bf16×f32 → f32 tensor-API matmul.
//
// Port of llama.cpp's `kernel_mul_mm_bf16_f32` template instantiation
// (ggml/src/ggml-metal/ggml-metal.metal:10032) with the
// `GGML_METAL_HAS_TENSOR` branch active.  Tile geometry, shared memory
// layout, and matmul2d descriptor are identical to our existing
// `quantized_matmul_mm_tensor.metal`:
//   * sa (A tile): bfloat, [NR0=64][NK=32] row-major, 4 KB
//   * sb (B tile): bfloat, [NR1=32][NK=32] row-major, 2 KB (placed at shmem + 4096)
//   * sc (partial-tile write-back): float, [NR1=32][NR0=64], 8 KB, reuses shmem base
//
// The kernel is used on hf2q's non-flash-attention prefill path for
// BOTH of the two attention mat-muls:
//   1. Q @ K^T -> scores  (K is src0 bf16 weight; Q is src1 f32 input)
//   2. scores @ V -> out  (V is src0 bf16 weight; scores is src1 f32
//      input after softmax)
//
// The non-tensor simdgroup MMA fallback is intentionally NOT included
// — mlx-native targets M3+ where tensor-ops is always available, and
// keeping the kernel single-path avoids the two-branch duplication
// llama.cpp carries for backward compatibility.  If a pre-M3 user ever
// runs this build, kernel compile will fail cleanly and the host-side
// dispatcher (dense_matmul_bf16_f32_tensor_mm) returns an error, and
// the caller must use a different attention path (flash-attn or simd
// MMA mat-mul).
//
// ne02 / r2 broadcast:  hf2q's grouped-query attention has nh heads
// attending but nkv shared KV heads.  The attention mat-muls iterate
// over nh in the z-axis (im = tgpig.z); the src0 head offset divides
// by r2 = nh/nkv so the same KV head is broadcast across all heads in
// its GQA group.  This matches llama.cpp's ggml_mul_mat r2/r3 contract.
//
// Portions of this file are derived from llama.cpp
// (https://github.com/ggml-org/llama.cpp), MIT licensed.
// Copyright the llama.cpp Authors.  See LICENSE-MIT-llamacpp.

#include <metal_stdlib>
#include <metal_tensor>
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>

using namespace metal;
using namespace mpp::tensor_ops;

// ---- Host-facing params struct ---------------------------------------
//
// Mirrors the ggml matmul args layout we use in quantized_matmul_mm_tensor.
// ne00 = contract dim K (shared between src0 and src1).
// ne0  = output N (= src0.ne[1], number of weight rows).
// ne1  = output M (= src1.ne[1], number of input rows).
// ne02 = src0 batch count (GQA: nkv).
// ne12 = src1 batch count (GQA: nh); r2 = ne12 / ne02.
// nb01 = src0 row stride (bytes) = ne00 * sizeof(bfloat).
// nb02 = src0 batch stride (bytes).
// nb11 = src1 row stride (bytes) = ne00 * sizeof(float).
// nb12 = src1 batch stride (bytes).

// Kernel argument block shared by both dense bf16 x f32 kernels in this file.
// Field order and types form the binary contract with the host-side struct;
// do not reorder.  NOTE(review): uint64_t members are presumably 8-byte
// aligned, which implies padding after `ne12` and at the tail — confirm the
// host definition reproduces the same layout.
struct DenseMmBf16F32TensorParams {
    int32_t  ne00;   // contraction length K; bound of the K-tile loop
    int32_t  ne02;   // src0 batch count; not read by the kernels visible here
    uint64_t nb01;   // src0 row stride in bytes (multiplied by the row index)
    uint64_t nb02;   // src0 stride in bytes for the (i12 / r2) batch index
    uint64_t nb03;   // src0 stride in bytes for the (i13 / r3) batch index
    int32_t  ne12;   // src1 batch count; splits tgpig.z into i12 = im % ne12, i13 = im / ne12
    uint64_t nb10;   // src1 stride in bytes along K (V1 multiplies it by the K offset iy)
    uint64_t nb11;   // src1 row stride in bytes
    uint64_t nb12;   // src1 stride in bytes for batch index i12
    uint64_t nb13;   // src1 stride in bytes for batch index i13
    int32_t  ne0;    // output extent along the contiguous dst axis (tiled by tgpig.y)
    int32_t  ne1;    // output extent along the strided dst axis (tiled by tgpig.x)
    int16_t  r2;     // divisor applied to i12 when selecting the src0 batch; must be non-zero
    int16_t  r3;     // divisor applied to i13 when selecting the src0 batch; must be non-zero
};

// ---- Kernel --------------------------------------------------------
//
// Directly modelled on hf2q_mul_mm_tensor_impl from
// quantized_matmul_mm_tensor.metal.  The ONLY structural differences:
//   * src0 is bfloat instead of block_q; no dequantize is needed, so
//     the A-tile staging is a plain copy loop (bfloat -> bfloat in
//     shmem) rather than the 16-element dequantize + permuted store
//     from the quantized path.
//   * A-stage tile stride: src0 is laid out row-major bfloat, so each
//     thread loads 16 consecutive bfloats from src0 and stores them
//     into sa at the llama.cpp tile-row/tile-col positions matching
//     the tensor_ops matmul2d contract (same as the quantized path
//     after dequantize).

// Launch contract (derived from the index math in the body):
//   * grid:  tgpig.x tiles ne1 in steps of NR1=32, tgpig.y tiles ne0 in
//            steps of NR0=64, tgpig.z is the flattened batch index im
//            (i12 = im % ne12, i13 = im / ne12).
//   * threadgroup: the staging math (tiitg/NL0 -> 64 A rows, tiitg/NL1 ->
//            32 B rows) and execution_simdgroups<4> assume 128 threads.
//   * shmem: sa at byte 0, sb at byte 4096; the partial-tile write-back
//            reuses the base as a 32x64 float scratch, so 8192 bytes are
//            required.
//   * args.r2 / args.r3 are used as divisors and must be non-zero.
//   * dst is written as dense float [batch][ne1][ne0]; all dst element
//     offsets are formed in 64 bits so outputs of 2^31 floats or more do
//     not overflow a 32-bit int.
kernel void hf2q_dense_mm_bf16_f32_tensor(
        constant DenseMmBf16F32TensorParams & args,
        device const char * src0,        // bfloat [ne02, ne01, ne00]
        device const char * src1,        // float  [ne12, ne11, ne10]
        device       char * dst,         // float  [batch, ne1, ne0]
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {

    threadgroup bfloat * sa = (threadgroup bfloat *)(shmem);
    threadgroup bfloat * sb = (threadgroup bfloat *)(shmem + 4096);
    threadgroup float  * sc = (threadgroup float  *)(shmem);  // partial-tile write-back reuses shmem base

    constexpr int NR0 = 64;
    constexpr int NR1 = 32;
    constexpr int NK  = 32;
    constexpr int NL0 = NK/16;   // 2 — threads per A-tile row-block
    constexpr int NL1 = NK/8;    // 4 — threads per B-tile row-block

    const int im = tgpig.z;
    const int r0 = tgpig.y * NR0;
    const int r1 = tgpig.x * NR1;

    // Number of valid rows in this tile along each output axis (edge tiles
    // may be short).
    const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
    const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;

    // Row this thread stages, clamped to the last valid row so that edge
    // tiles never read past the end of src0 / src1.
    const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1;
    const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1;

    const short il0 = (tiitg % NL0);

    const int i12 = im % args.ne12;
    const int i13 = im / args.ne12;

    // src0 batch offset: dividing by r2 / r3 maps several src1 batches onto
    // the same src0 batch (broadcast).
    const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;

    // src0: bfloat row-major, row = (r0 + lr0), batch offset from im->i12/r2.
    // Start at the first bfloat of this thread's row, offset by the K-tile
    // column group that this thread-index owns (16 bfloats per group, so
    // il0 * 16 values = il0 * 16 bfloats in).
    device const bfloat * x = (device const bfloat *)(src0 + args.nb01*(r0 + lr0) + offset0) + il0 * 16;

    // K offset (in elements) of the 8-float group this thread stages for B.
    const short iy = 8*(tiitg % NL1);

    device const float * y = (device const float *)(src1
        + args.nb13*i13
        + args.nb12*i12
        + args.nb11*(r1 + lr1)
        + args.nb10*iy);

    auto tA = tensor<threadgroup bfloat, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK,  NR0));
    auto tB = tensor<threadgroup bfloat, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));

    matmul2d<
        matmul2d_descriptor(NR1, NR0, NK, false, true, false,
            matmul2d_descriptor::mode::multiply_accumulate),
        execution_simdgroups<4>> mm;

    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();

    for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
        // Full-tile fast path when the entire NK=32 K-block fits inside
        // ne00; gated slow path for the partial trailing tile when
        // ne00 is not a multiple of NK. Without the gate, the in-tile
        // unconditional 16-element / 8-element loads read past the end
        // of `x` / `y` buffers and the matmul accumulates garbage into
        // cT — visible at the consumer as random per-row drift (hf2q
        // ADR-005 iter 67 bisection on bge-small-en-v1.5: cosine
        // 0.99999 at K=32 → 0.75-0.92 at K=33-200, regardless of
        // softmax masking; smooth per-K cliff localized to this loop).
        const bool full_tile = (loop_k + NK <= args.ne00);

        // ---- Stage A tile (bfloat -> bfloat copy into sa [NR0][NK]).
        // No dequantize: A is already bfloat in device memory.  Same
        // destination layout as quantized_matmul_mm_tensor.metal:
        //   sa[NK*(8*sy + ly) + 8*sx + lx] = x[i]
        // with (sx, sy, lx, ly) derived from (tiitg, i, il0).
        {
            // Wait until every simdgroup has finished the previous
            // iteration's matmul before overwriting the shared tiles.
            threadgroup_barrier(mem_flags::mem_threadgroup);

            if (full_tile) {
                for (short i = 0; i < 16; i++) {
                    const short sx = 2*il0 + i/8;
                    const short sy = (tiitg/NL0)/8;
                    const short lx = i%8;
                    const short ly = (tiitg/NL0)%8;
                    *(sa + NK*(8*sy + ly) + 8*sx + lx) = x[i];
                }
            } else {
                // Partial tile: gate per-element. This thread's x[i]
                // covers absolute K = loop_k + il0*16 + i, because `x`
                // was initialised at offset il0*16 within the src0 row
                // and is advanced by NK at the end of every iteration.
                // Out-of-range elements are staged as zero so they do
                // not contribute to the accumulator.
                for (short i = 0; i < 16; i++) {
                    const short sx = 2*il0 + i/8;
                    const short sy = (tiitg/NL0)/8;
                    const short lx = i%8;
                    const short ly = (tiitg/NL0)%8;
                    const int abs_k = loop_k + il0*16 + i;
                    const bfloat v = (abs_k < args.ne00) ? x[i] : bfloat(0.0);
                    *(sa + NK*(8*sy + ly) + 8*sx + lx) = v;
                }
            }
        }

        // ---- Stage B tile (f32 -> bfloat per-element cast into sb) ----
        // Metal has `float2x4` but no `bfloat2x4` matrix type, so the
        // "single vector store" trick the quantized tensor kernel uses
        // (cast float2x4 -> bfloat2x4) is not available here.  We load
        // 8 f32 values as a float4×2 pair and store them as 8
        // individual bfloats.  The compiler is expected (not guaranteed)
        // to merge the eight adjacent 16-bit stores into a wider write.
        {
            const short sx = (tiitg%NL1);
            const short sy = (tiitg/NL1)/8;
            const short ly = (tiitg/NL1)%8;

            threadgroup bfloat * sb_ptr = sb + NK*(8*sy + ly) + 8*sx;

            if (full_tile) {
                float4 y_lo = *((device const float4 *) y);
                float4 y_hi = *((device const float4 *)(y + 4));

                sb_ptr[0] = bfloat(y_lo[0]);
                sb_ptr[1] = bfloat(y_lo[1]);
                sb_ptr[2] = bfloat(y_lo[2]);
                sb_ptr[3] = bfloat(y_lo[3]);
                sb_ptr[4] = bfloat(y_hi[0]);
                sb_ptr[5] = bfloat(y_hi[1]);
                sb_ptr[6] = bfloat(y_hi[2]);
                sb_ptr[7] = bfloat(y_hi[3]);
            } else {
                // Partial tile: y[i] for thread (tiitg%NL1) covers
                // absolute K = loop_k + iy + i, because `y` was
                // initialised at K offset iy and is advanced by NK at
                // the end of every iteration.
                for (short i = 0; i < 8; i++) {
                    const int abs_k = loop_k + iy + i;
                    sb_ptr[i] = (abs_k < args.ne00) ? bfloat(y[i]) : bfloat(0.0);
                }
            }
        }

        // Advance both source cursors to the next K tile.
        x += NK;
        y += NK;

        // Make the staged tiles visible to all simdgroups before the matmul.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        auto sA = tA.slice(0, 0);
        auto sB = tB.slice(0, 0);
        mm.run(sB, sA, cT);
    }

    // Element offset of this batch's output plane.  Computed in 64 bits:
    // im * ne1 * ne0 exceeds INT32_MAX once the output holds 2^31 floats or
    // more, and signed 32-bit overflow would silently corrupt the address.
    const uint64_t dst_batch = (uint64_t)im * (uint64_t)args.ne1 * (uint64_t)args.ne0;

    if (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1) {
        // Full output tile: store the cooperative tensor straight to device
        // memory, using ne0 as the row pitch.
        device float * C = (device float *) dst +
            r0 +
            (uint64_t)r1 * (uint64_t)args.ne0 + dst_batch;

        auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
        cT.store(tC);
    } else {
        // Edge tile: spill the whole NR0 x NR1 accumulator into threadgroup
        // scratch (aliasing sa/sb, hence the barrier), then copy only the
        // nr0 x nr1 valid region out to device memory.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
        cT.store(tC);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Only the first simdgroup copies; each of its threads handles
        // output rows j = tiitg, tiitg + NR1, ...
        if (sgitg == 0) {
            for (int j = tiitg; j < nr1; j += NR1) {
                device float  * D  = (device float  *) dst + r0 + (uint64_t)(r1 + j)*(uint64_t)args.ne0 + dst_batch;
                // NOTE(review): the float4 path below assumes D is 16-byte
                // aligned, which presumably requires ne0 to be a multiple
                // of 4 — confirm with the dispatcher.
                device float4 * D4 = (device float4 *) D;

                threadgroup float  * C  = sc + (j*NR0);
                threadgroup float4 * C4 = (threadgroup float4 *) C;

                // Vectorised copy of the first nr0/4 float4 groups ...
                int i = 0;
                for (; i < nr0/4; i++) {
                    *(D4 + i) = *(C4 + i);
                }

                // ... followed by the scalar remainder.
                i *= 4;
                for (; i < nr0; i++) {
                    *(D + i) = *(C + i);
                }
            }
        }
    }
}

// ===========================================================================
// ADR-029 iter-80 H60: V2 LARGE-TILE variant — port of
// quantized_matmul_mm_tensor.metal:hf2q_mul_mm_tensor_v2_f16_impl with
// bfloat-typed A staging (instead of half-typed shadow).
//
// Motivation: V1 above uses NR0=64, NR1=32 tile. At pp8333 scores@V
// (M=qL=8333, N=hd=512, K=kL=8333) V1 dispatches 261 × 8 = 2088 TGs/head
// × 16 heads = 33,408 TGs per layer. V2 dispatches 8 × 66 = 528 TGs/head
// × 16 = 8,448 TGs per layer — 4× threadgroup-count reduction.
//
// V2 large-tile already landed in quantized_matmul_mm_tensor.metal
// (iter-23 H28-A) for Q-typed weight matmuls, but the DENSE bf16 path
// used by HF2Q_NO_FA's Q@K^T and scores@V was never ported. This is
// the missing lever (iter-79 audit).
//
// Geometry constants (mirror peer ggml-metal-impl.h):
//   SZ_SIMDGROUP        = 16
//   N_MM_BLOCK_X        = 4
//   N_MM_BLOCK_Y        = 2
//   N_MM_SIMD_GROUP_X   = 2
//   N_MM_SIMD_GROUP_Y   = 2
//   N_MM_NK             = 2
//   NRA = SZ * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y = 64  (M tile)
//   NRB = SZ * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X = 128 (N tile)
//   NUM_THREADS = 32 * 2 * 2 = 128
//
// hf2q axis convention (mirroring quantized V2):
//   ne0 = hf2q-N = peer-M (output cols = weight rows)
//   ne1 = hf2q-M = peer-N (output rows = input rows)
//   Dispatcher should set tgpig.y = ceil(hf2q-N / NRA), tgpig.x = ceil(hf2q-M / NRB).
//
// Shared memory: NRA × N_MM_NK_TOTAL × sizeof(bfloat) = 64 × 32 × 2 = 4096 B
// (half the V1 shmem budget; B is read direct from device memory).
//
// Coherence: byte-identical-class to V1 for the same input bytes (no
// dequantize, no rounding round-trip). Stage A is a plain bfloat->bfloat
// shmem copy; matmul2d accumulates in float; final store is f32.
// ===========================================================================

// Large-tile dense bf16 x f32 -> f32 matmul: each threadgroup produces a
// 64 (ne0 axis) x 128 (ne1 axis) output tile.
//
// Launch contract (derived from the index math in the body):
//   * grid:  tgpig.y tiles ne0 in steps of NRA=64, tgpig.x tiles ne1 in
//            steps of NRB=128, tgpig.z is the flattened batch index im
//            (i12 = im % ne12, i13 = im / ne12).
//   * threadgroup: NUM_THREADS = 128 (4 simdgroups, matching
//            execution_simdgroups<4>).
//   * shmem: one A tile of NRA x 32 bfloats = 4096 bytes.
//   * srcB is addressed with a unit element stride along K (args.nb10 is
//     not read) and a row pitch of args.nb11 / sizeof(float) elements.
//   * args.r2 / args.r3 are used as divisors and must be non-zero.
[[host_name("hf2q_dense_mm_bf16_f32_tensor_v2")]]
kernel void hf2q_dense_mm_bf16_f32_tensor_v2(
        constant DenseMmBf16F32TensorParams & args,
        device const char * srcA,        // bfloat [ne02, M_peer, K]    nb01 = K*sizeof(bfloat)
        device const char * srcB,        // float  [ne12, N_peer, K]    nb11 = K*sizeof(float)
        device       char * dst,         // float  [batch, M_peer, N_peer]
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig [[threadgroup_position_in_grid]],
        ushort tiitg [[thread_index_in_threadgroup]],
        ushort sgitg [[simdgroup_index_in_threadgroup]]) {
    // sgitg is part of the signature but unused here; silence the warning.
    (void) sgitg;

    // Peer's ggml convention: A(M_peer, K) × B(K, N_peer) -> C(M_peer, N_peer)
    // with column-major-over-M_peer C layout. hf2q SWAPS M/N at the
    // dispatcher so ne0 = hf2q-N (peer M_peer), ne1 = hf2q-M (peer N_peer).
    const int K      = args.ne00;
    const int M_peer = args.ne0;   // hf2q ne0 = peer M_peer
    const int N_peer = args.ne1;   // hf2q ne1 = peer N_peer

    // Decompose the flattened batch index and derive the srcA batch offset;
    // dividing by r2 / r3 maps several srcB batches onto one srcA batch.
    const int im = tgpig.z;
    const int i12 = im % args.ne12;
    const int i13 = im / args.ne12;
    const uint64_t offset0 = (i12 / args.r2) * args.nb02 + (i13 / args.r3) * args.nb03;

    // Tile constants (peer's ggml-metal-impl.h).
    constexpr int SZ_SIMDGROUP        = 16;
    constexpr int N_MM_BLOCK_X        = 4;
    constexpr int N_MM_BLOCK_Y        = 2;
    constexpr int N_MM_SIMD_GROUP_X   = 2;
    constexpr int N_MM_SIMD_GROUP_Y   = 2;
    constexpr int N_MM_NK             = 2;
    constexpr int N_MM_NK_TOTAL       = SZ_SIMDGROUP * N_MM_NK;          // 32
    constexpr int N_SIMDWIDTH         = 32;

    constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y; // 64 = M tile
    constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; // 128 = N tile

    const int ra = tgpig.y * NRA;   // M_peer offset
    const int rb = tgpig.x * NRB;   // N_peer offset

    // Threadgroup staging buffer for the A tile: NRA rows of N_MM_NK_TOTAL
    // bfloats, row-major.
    threadgroup bfloat * sa = (threadgroup bfloat *)(shmem);

    // One work item = one 16-element K chunk of one A row; with 128 items
    // and 128 threads each thread stages exactly one chunk per K tile.
    constexpr int A_WORK_ITEMS = NRA * N_MM_NK;                              // 128
    constexpr int NUM_THREADS  = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y; // 128

    auto tA = tensor(sa, dextents<int32_t, 2>(N_MM_NK_TOTAL, NRA));

    // B is F32, read directly from device memory (peer ggml-metal.metal:9358-9360).
    // The cast drops the const qualifier of srcB; the kernel only reads
    // through ptrB.
    device float * ptrB = (device float *)(srcB + args.nb12 * i12 + args.nb13 * i13);
    const int strideB = (int)(args.nb11 / sizeof(float));
    auto tB = tensor(ptrB, dextents<int32_t, 2>(K, N_peer), array<int, 2>({1, strideB}));

    // NOTE(review): the sixth descriptor argument is `true` here but `false`
    // in the V1 kernel (presumably the relaxed-precision flag — confirm
    // against the MetalPerformancePrimitives header).  Together with B being
    // consumed as f32 instead of being rounded to bfloat as V1 does, V1 and
    // V2 outputs are presumably not bit-identical; verify against the
    // coherence claim in the header comment.
    matmul2d<
        matmul2d_descriptor(NRB, NRA, N_MM_NK_TOTAL, false, true, true,
            matmul2d_descriptor::mode::multiply_accumulate),
        execution_simdgroups<N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y>> mm;

    auto cT = mm.get_destination_cooperative_tensor<decltype(tB), decltype(tA), float>();

    for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) {
        // PHASE 1: stage A tile — bfloat → bfloat direct copy (no dequantize).
        for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) {
            const int row     = work / N_MM_NK;        // A-tile row handled by this work item
            const int k_chunk = work % N_MM_NK;        // which 16-wide half of the K tile
            const int k_pos   = loop_k + k_chunk * 16; // absolute K index of the chunk start
            const short k_base = k_chunk * 16;         // chunk start within the staged tile

            if (ra + row < M_peer) {
                device const bfloat * row_ptr =
                    (device const bfloat *)(srcA + args.nb01 * (ra + row) + offset0);
                // Elements past the end of K are staged as zero so the
                // trailing partial K tile adds nothing to the accumulator.
                #pragma clang loop unroll(full)
                for (short i = 0; i < 16; i++) {
                    sa[row * N_MM_NK_TOTAL + (k_base + i)] =
                        (k_pos + i < K) ? row_ptr[k_pos + i] : (bfloat)0;
                }
            } else {
                // Row lies beyond the matrix edge: zero-fill instead of
                // reading out of bounds.
                #pragma clang loop unroll(full)
                for (short i = 0; i < 16; i++) {
                    sa[row * N_MM_NK_TOTAL + (k_base + i)] = (bfloat)0;
                }
            }
        }

        // Publish the staged A tile to every simdgroup before it is consumed.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // PHASE 2: tensor matmul.
        // mB is the window of the device-resident B tensor starting at
        // (loop_k, rb); mA is the whole staged A tile.
        auto mA = tA.slice(0, 0);
        auto mB = tB.slice(loop_k, rb);
        mm.run(mB, mA, cT);

        // Keep the next iteration's staging from overwriting sa while other
        // simdgroups may still be reading it.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store cooperative tensor to device output.
    // Output (peer convention): [M_peer, N_peer], column-major-over-M_peer.
    // V1 stores `dst + r0 + r1*ne0 + im*ne1*ne0`; V2's tD with stride
    // (1, M_peer) writes the same byte layout (column-major-over-M_peer).
    // The batch offset is widened to 64 bits before multiplying.
    device float * dstBatch = (device float *)dst +
        im * (uint64_t)M_peer * (uint64_t)N_peer;
    auto tD = tensor(dstBatch, dextents<int32_t, 2>(M_peer, N_peer),
                     array<int, 2>({1, M_peer}));
    // tD carries the full (M_peer, N_peer) extents; the store targets the
    // window starting at (ra, rb).  NOTE(review): edge tiles presumably rely
    // on the tensor extents to clip out-of-range writes — confirm.
    cT.store(tD.slice(ra, rb));
}