mlx-native 0.3.2

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
// Flash attention vector kernel for TurboQuant-compressed KV cache (ADR-007 Phase 1.3).
//
// Fork of flash_attn_vec.metal that reads K and V from nibble-packed indices
// + per-position norms, instead of F16/F32 buffers.
//
// The kernel operates in the Hadamard-rotated domain:
//   1. Q is rotated (FWHT) via standalone dispatch before this kernel
//   2. K/V are dequantized inline from nibble indices + scalar codebook
//   3. Dot products are computed in the rotated domain (orthogonal invariance)
//   4. Output stays rotated; caller applies inverse FWHT via standalone dispatch
//
// Dequant: value = CODEBOOK_4BIT[nibble_idx] * inv_sqrt(head_dim) * norm
//   The 16-element codebook fits in registers — zero main-memory bandwidth for dequant.
//
// Packed KV layout: [num_kv_heads, capacity, head_dim/2] u8
//   Low nibble = even coordinate index, high nibble = odd coordinate index
//
// Norms layout: [num_kv_heads, capacity] f32

#include <metal_stdlib>
using namespace metal;

#define N_SIMDWIDTH 32   // threads per simdgroup — matches the width of the simd_* reductions used below
#define C           32   // KV positions per simdgroup iteration

// Pad x up to next multiple of n (n must be power of 2).
#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))

// Parameters — same layout as FlashAttnVecParams.
struct FlashAttnVecTqParams {
    uint  n_heads;        // number of query heads
    uint  n_kv_heads;     // number of KV heads (GQA: n_heads / n_kv_heads query heads share one KV head)
    uint  head_dim;       // runtime head dimension (compile-time counterparts are DK/DV)
    uint  kv_seq_len;     // number of valid entries currently in the KV cache
    uint  kv_capacity;    // ring-buffer capacity; also the per-head stride in the packed/norm buffers
    float scale;          // attention scale folded into the Q·K scores
    uint  mask_type;      // 2 = sliding-window masking (see window_start_logical in the kernel)
    uint  sliding_window; // window length when mask_type == 2; 0 disables windowing
    float softcap;        // NOTE(review): not referenced anywhere in this kernel — confirm whether the TQ path needs softcapping
    uint  nwg;            // number of workgroups cooperating per (head, query) row; partials combined by the reduce kernel
    uint  ring_start;  // ADR-009 Track 2: physical slot of oldest entry in ring buffer
};

// Reduce params — shared with flash_attn_vec.
struct FlashAttnVecReduceParams {
    uint nrows;  // NOTE(review): consumed by the shared reduce kernel (defined elsewhere, not in this file)
};

// ---------------------------------------------------------------------------
// In-place FWHT on float data in shared memory.
//
// x: pointer to float array of length D in threadgroup memory.
// D: dimension (must be power of 2, known at compile time).
// NW: number of threads in the simdgroup (32).
// tx: thread index (0..31).
//
// Each butterfly stage requires a threadgroup barrier.
// For D=256: 8 stages. For D=512: 9 stages.
// ---------------------------------------------------------------------------
template<short D>
inline void fwht_shared(threadgroup float *x, ushort tx, ushort NW) {
    // In-place Walsh–Hadamard butterfly network over x[0..D): log2(D) stages.
    // Each stage pairs elements h apart; a threadgroup barrier after every
    // stage makes that stage's writes visible before the next one reads them.
    for (short h = 1; h < D; h <<= 1) {
        const short stride = h << 1;
        // The D/2 butterfly pairs of this stage are strided across NW threads.
        for (short pair = tx; pair < D / 2; pair += NW) {
            const short j = (pair / h) * stride + (pair % h);
            const float lo = x[j];
            const float hi = x[j + h];
            x[j]     = lo + hi;
            x[j + h] = lo - hi;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Fold in the 1/sqrt(D) normalization once, after all stages, instead of
    // scaling inside every butterfly.
    const float norm = rsqrt(float(D));
    for (short i = tx; i < D; i += NW) {
        x[i] *= norm;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
}


// ---------------------------------------------------------------------------
// 4-bit Lloyd-Max codebook for N(0,1): 16 reconstruction levels.
// Matches CODEBOOK_4BIT in turboquant.rs exactly.
// Fits in registers — zero main-memory bandwidth for dequant.
// ---------------------------------------------------------------------------
constant float CODEBOOK_4BIT[16] = {
    // Symmetric about zero: entry i == -entry (15 - i), so index 0 is the most
    // negative level and index 15 the most positive. Values must stay
    // byte-identical to turboquant.rs — do not reformat or "clean up".
    -2.7325896f, -2.0690172f, -1.6180464f, -1.2562312f,
    -0.9423405f, -0.6567591f, -0.3880483f, -0.1283950f,
     0.1283950f,  0.3880483f,  0.6567591f,  0.9423405f,
     1.2562312f,  1.6180464f,  2.0690172f,  2.7325896f,
};

// ---------------------------------------------------------------------------
// Reconstruct float4 from 2 adjacent packed bytes (4 nibble indices) using
// inline scalar dequant. No centroid table lookup — just the register-resident
// codebook.
//
// packed_base: pointer to start of this position's packed data [head_dim/2 bytes]
// byte_offset: byte offset into packed_base (= coord_offset / 2, always even)
// scale_norm:  pre-multiplied (1/sqrt(head_dim)) * norm
// ---------------------------------------------------------------------------
inline float4 dequant_tq_float4(
    device const uint8_t *packed_base,
    uint byte_offset,
    float scale_norm
) {
    // One 16-bit load grabs both bytes (= 4 nibbles) at once; byte_offset is
    // always even, so the ushort access stays 2-byte aligned.
    const ushort bits = *((device const ushort *)(packed_base + byte_offset));

    // Look up the four reconstruction levels, low nibble first.
    float4 levels;
    levels.x = CODEBOOK_4BIT[(bits      ) & 0xFu];
    levels.y = CODEBOOK_4BIT[(bits >> 4u) & 0xFu];
    levels.z = CODEBOOK_4BIT[(bits >> 8u) & 0xFu];
    levels.w = CODEBOOK_4BIT[(bits >> 12u) & 0xFu];

    // Apply (1/sqrt(head_dim)) * norm to all four lanes in one vector multiply.
    return levels * scale_norm;
}


// ---------------------------------------------------------------------------
// Main TQ flash attention vector kernel.
//
// Same structure as flash_attn_vec_impl but:
//   - Q arrives pre-rotated: the caller runs a standalone FWHT dispatch first
//   - K/V dequant is inline from nibble-packed buffers + register codebook
//   - Output is written still in the rotated domain; the caller applies the
//     inverse FWHT after the reduce kernel (see the shared-memory comment
//     inside the kernel for why the transforms are NOT fused in here)
// ---------------------------------------------------------------------------
template<short DK, short DV>
kernel void flash_attn_vec_tq_impl(
    constant FlashAttnVecTqParams   &params      [[buffer(0)]],
    device const float              *Q           [[buffer(1)]],
    device const uint8_t            *K_packed    [[buffer(2)]],
    device const float              *K_norms     [[buffer(3)]],
    device const uint8_t            *V_packed    [[buffer(4)]],
    device const float              *V_norms     [[buffer(5)]],
    device       float              *dst         [[buffer(6)]],
    threadgroup  half               *shmem       [[threadgroup(0)]],
    uint3  tgpig [[threadgroup_position_in_grid]],
    ushort tiisg [[thread_index_in_simdgroup]],
    ushort sgitg [[simdgroup_index_in_threadgroup]])
{
    // Compile-time constants.
    // DK/DV are the K and V head dims; DK4/DV4 the same in float4 units.
    // PK/PV pad the head dims to a multiple of 128 for the shared-memory layout.
    constexpr short DK4 = DK / 4;
    constexpr short DV4 = DV / 4;
    constexpr short NW  = N_SIMDWIDTH;
    constexpr short NL  = NW;
    constexpr short PK  = PAD2(DK, 128);
    constexpr short PK4 = PK / 4;
    constexpr short PV  = PAD2(DV, 128);
    constexpr short PV4 = PV / 4;
    constexpr short SH  = 4 * C;  // 128 halfs = 64 floats

    static_assert(DK % 32 == 0, "DK must be divisible by 32");
    static_assert(DV % 32 == 0, "DV must be divisible by 32");
    static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
    static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");

    // Number of workgroups cooperating on each (head, query) row.
    const uint NWG = params.nwg;

    const ushort iwg = tgpig[2] % NWG;  // which of the NWG cooperating workgroups this is
    const ushort iq2 = tgpig[1];  // head index
    const ushort iq1 = tgpig[0];  // query index (0 for decode)

    // GQA: map query head to KV head.
    const uint heads_per_kv = params.n_heads / params.n_kv_heads;
    const uint kv_head = iq2 / heads_per_kv;

    // Shared memory layout:
    //   [0, PK)                     — Q as half4 (pre-rotated by caller)
    //   [PK, PK + SH)               — scratch for attention scores
    //   [PK + SH, PK + SH + 2*PV)   — output accumulator as float4
    //
    // FWHT is NOT done in this kernel. With NWG=32, doing FWHT per-workgroup
    // would repeat it 32× per head. Instead:
    //   - Caller pre-rotates Q via a standalone FWHT dispatch (1× per head)
    //   - Partials are written in the rotated domain
    //   - Caller applies inverse FWHT after the reduce kernel (1× per head)

    // Pointers.
    threadgroup half4  *sq4 = (threadgroup half4  *)(shmem);
    threadgroup float  *ss  = (threadgroup float  *)(shmem + PK);
    threadgroup float4 *so4 = (threadgroup float4 *)(shmem + PK + SH);

    // Load PRE-ROTATED Q directly into shared memory as half4.
    // Padding lanes [DK4, PK4) are zero-filled so they never contribute.
    {
        for (ushort i = tiisg; i < PK4; i += NW) {
            if (i < DK4) {
                float4 qval = *((device const float4 *)(Q + iq2 * DK + i * 4));
                sq4[i] = half4(qval);
            } else {
                sq4[i] = half4(0.0h);
            }
        }
    }

    // Zero the output accumulator.
    // Each thread owns its SIMD lane.
    so4 += tiisg;
    for (short i = 0; i < DV4 / NL; ++i) {
        so4[i * NL] = float4(0.0f);
    }

    // Zero scratch buffer.
    // SH/4 = 32 floats — exactly the per-lane score slots ss[0..31] used below.
    for (ushort i = tiisg; i < SH / 4; i += NW) {
        ((threadgroup float *)(shmem + PK))[i] = 0.0f;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Online softmax state.
    // M starts at -FLT_MAX/2 (not -FLT_MAX) so that exp(m_old - M) stays finite.
    float S = 0.0f;
    float M = -FLT_MAX / 2;

    const ushort tx = tiisg;

    // Masking bounds (ADR-009 Track 2: ring-buffer-aware chronology).
    //
    // ring_start = physical slot of the oldest entry in the ring buffer.
    // Before wrap: ring_start = 0 (physical == logical).
    // After wrap: ring_start = write_pos % capacity.
    //
    // logical_idx(physical_slot) = (physical_slot - ring_start + capacity) % capacity
    // This maps physical slots to chronological order: 0 = oldest, kv_seq_len-1 = newest.
    const uint kv_seq_len = params.kv_seq_len;
    const uint kv_capacity = params.kv_capacity;
    const uint ring_start = params.ring_start;

    // For sliding window: oldest visible logical index.
    uint window_start_logical = 0;
    if (params.mask_type == 2 && params.sliding_window > 0 && kv_seq_len > params.sliding_window) {
        window_start_logical = kv_seq_len - params.sliding_window;
    }

    // Reference to Q in shared memory for dot products.
    // Each lane reads a strided slice of Q: lane tx owns sq4[tx + ii*NL].
    threadgroup const half4 *pq4 = sq4 + tx;

    // Main loop over KV cache in chunks of C=32.
    // Iterate over all physical slots (up to kv_seq_len).
    // Workgroup iwg handles chunks iwg, iwg+NWG, iwg+2*NWG, ...
    for (uint ic0 = iwg; ; ic0 += NWG) {
        uint ic = ic0 * C;
        if (ic >= kv_seq_len) {
            break;
        }

        // Compute implicit mask for this chunk using ring-buffer chronology.
        // -65504.0f is the most negative finite half — an effective -inf here.
        {
            uint k_pos = ic + tx;  // physical slot index
            float mask_val = 0.0f;
            if (k_pos >= kv_seq_len) {
                // Beyond valid range
                mask_val = -65504.0f;
            } else {
                // Map physical slot to logical (chronological) index
                uint logical_idx = (k_pos - ring_start + kv_capacity) % kv_capacity;
                if (logical_idx >= kv_seq_len || logical_idx < window_start_logical) {
                    mask_val = -65504.0f;
                }
            }
            ss[tx] = mask_val;
        }

        // Skip all-masked chunks.
        // Each lane reads its own slot (tx == tiisg), so no barrier is needed
        // between the write above and this simd_max.
        if (simd_max(ss[tiisg]) <= -65504.0f) {
            continue;
        }

        // ---- Q * K^T (in rotated domain) ----
        {
            float mqk[C];

            // Pre-compute inv_sqrt(DK) once — used for all K dequant in this chunk.
            const float inv_sqrt_dk = rsqrt(float(DK));

            for (short cc = 0; cc < C; ++cc) {
                uint kv_pos = ic + cc;
                if (kv_pos >= kv_seq_len) {
                    mqk[cc] = 0.0f;
                    continue;
                }

                // scale_norm for this K position: norm * (1/sqrt(head_dim)).
                float k_sn = K_norms[kv_head * params.kv_capacity + kv_pos] * inv_sqrt_dk;

                // Base pointer to packed K for this position.
                // Precompute the position offset once, then stride by 2 bytes per ii.
                device const uint8_t *k_base =
                    K_packed + (kv_head * params.kv_capacity + kv_pos) * (DK / 2) + tx * 2u;

                // Each lane covers 4 coords per step; simd_sum folds the lanes
                // into the full dot product, broadcast to every lane.
                float partial = 0.0f;
                for (short ii = 0; ii < DK4 / NL; ++ii) {
                    float4 k_val = dequant_tq_float4(k_base, (uint)(ii * NL) * 2u, k_sn);
                    partial += dot(k_val, float4(pq4[ii * NL]));
                }
                mqk[cc] = simd_sum(partial);
            }

            // Combine with mask and scale.
            // mqk[tx] is a dynamic index into a thread-local array; every lane
            // holds all C broadcast scores, lane tx contributes score tx.
            ss[tx] = fma(mqk[tx], params.scale, ss[tx]);
        }

        simdgroup_barrier(mem_flags::mem_threadgroup);

        // ---- Online softmax ----
        {
            const float m_old = M;
            const float s_new = ss[tiisg];

            M = simd_max(max(M, s_new));

            // ms rescales history to the new max; vs is this chunk's weight.
            const float ms = exp(m_old - M);
            const float vs = exp(s_new - M);

            S = S * ms + simd_sum(vs);

            // Stash per-position softmax weights for the V accumulation below.
            ss[tiisg] = vs;

            // Rescale previous output accumulation.
            for (short ii = 0; ii < DV4 / NL; ++ii) {
                so4[ii * NL] *= ms;
            }
        }

        simdgroup_barrier(mem_flags::mem_threadgroup);

        // ---- O = O + softmax_weights * V (in rotated domain) ----
        {
            // Accumulate the chunk in registers, flush to shared memory once.
            float4 lo[DV4 / NL];
            for (short ii = 0; ii < DV4 / NL; ++ii) {
                lo[ii] = float4(0.0f);
            }

            // Pre-compute inv_sqrt(DV) once — used for all V dequant in this chunk.
            const float inv_sqrt_dv = rsqrt(float(DV));

            for (short cc = 0; cc < C; ++cc) {
                uint kv_pos = ic + cc;
                if (kv_pos >= kv_seq_len) continue;

                // Fold weight into scale_norm: dequant returns pre-weighted values.
                float v_sw = V_norms[kv_head * params.kv_capacity + kv_pos] * inv_sqrt_dv * ss[cc];
                device const uint8_t *v_base =
                    V_packed + (kv_head * params.kv_capacity + kv_pos) * (DV / 2) + tx * 2u;

                for (short ii = 0; ii < DV4 / NL; ++ii) {
                    lo[ii] += dequant_tq_float4(v_base, (uint)(ii * NL) * 2u, v_sw);
                }
            }

            for (short ii = 0; ii < DV4 / NL; ++ii) {
                so4[ii * NL] += lo[ii];
            }
        }
    }

    // Store M and S for the reduce kernel.
    // NOTE(review): ss is threadgroup memory and not visible to a later kernel;
    // S/M are exported to device memory below. This store looks vestigial
    // (likely inherited from the multi-simdgroup variant) — confirm and remove.
    if (tiisg == 0) {
        ss[0] = S;
        ss[1] = M;
    }

    // Remove per-thread offset before cross-simdgroup access.
    so4 -= tiisg;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ---- Write output (rotated domain — caller handles inverse FWHT) ----
    //
    // so4[0..DV4) holds the accumulated partial in the rotated domain.
    // NO in-kernel FWHT — with NWG=32 that would repeat it 32× per head.
    // The caller applies inverse FWHT once after the reduce kernel.
    //
    // NWG==1: normalize by 1/S and write directly to dst.
    // NWG>1:  write raw (unnormalized) partials to tmp for the reduce kernel.
    if (sgitg == 0) {
        const int64_t nrows = params.n_heads;
        const int64_t rid = iq2 + (int64_t)iq1 * params.n_heads;
        const uint NWG_val = params.nwg;

        // When NWG==1: normalize by 1/S directly. Otherwise write raw for reduce.
        // S==0 (everything masked) yields zeros instead of a 0/0 NaN.
        const float inv_S = (NWG_val == 1) ? ((S == 0.0f) ? 0.0f : 1.0f / S) : 1.0f;

        // Partials are interleaved per component: all NWG partials of
        // component i are adjacent, which the reduce kernel presumably
        // expects — confirm against the shared reduce layout.
        device float4 *dst4 = (device float4 *)dst;
        for (ushort i = tiisg; i < DV4; i += NW) {
            dst4[rid * DV4 * NWG_val + NWG_val * i + iwg] = so4[i] * inv_S;
        }

        // Store S and M for the reduce kernel (NWG > 1 only).
        // The (S, M) pairs live after all nrows * DV * NWG partial floats.
        if (NWG_val > 1 && tiisg == 0) {
            device float *dst1 = (device float *)dst + nrows * DV * NWG_val;
            dst1[rid * (2 * NWG_val) + 2 * iwg + 0] = S;
            dst1[rid * (2 * NWG_val) + 2 * iwg + 1] = M;
        }
    }
}


// --------------------------------------------------------------------------
// Kernel instantiations — TQ variants
// --------------------------------------------------------------------------

// All instantiations share one signature type, taken from the DK=256 variant
// (the template parameters do not appear in the kernel's signature).
using flash_attn_vec_tq_t = decltype(flash_attn_vec_tq_impl<256, 256>);

// DK == DV == 256.
template [[host_name("flash_attn_vec_tq_dk256")]]
kernel flash_attn_vec_tq_t flash_attn_vec_tq_impl<256, 256>;

// DK == DV == 512.
template [[host_name("flash_attn_vec_tq_dk512")]]
kernel flash_attn_vec_tq_t flash_attn_vec_tq_impl<512, 512>;