mlx-native 0.6.2

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
// flash_attn_vec_tq_hb.metal — Native TQ SDPA for 5/6/8-bit byte-packed KV cache.
//
// Variant of flash_attn_vec_tq.metal that reads K/V from byte-packed (1 byte/element)
// higher-bit codebook indices instead of nibble-packed 4-bit indices.
//
// Bit-width is selected at compile time via template parameter CODEBOOK_BITS:
//   5  → 32 centroids  (Lloyd-Max N(0,1) optimal)
//   6  → 64 centroids
//   8  → 256 centroids
//
// Packed buffer layout: [num_kv_heads, capacity, head_dim] u8 (byte-packed)
//   One byte per element. For 5-bit only 5 LSBs are used (upper 3 zero).
//
// Dequant formula (same as tq_dequantize_hb_kv, which must match exactly):
//   D=256: scale_norm = norm * inv_sqrt(256)
//   D=512: scale_norm = norm / scale_factor_d512
//
// ADR-007 iter-24: measure Gate A/B/C at 5/6/8-bit to find shippable bit-width.

#include <metal_stdlib>
using namespace metal;

#define N_SIMDWIDTH 32
#define C           32   // KV positions per simdgroup iteration
#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))

// Parameters — same layout as FlashAttnVecTqParams in flash_attn_vec_tq.metal.
struct FlashAttnVecTqHbParams {
    uint  n_heads;            // number of query attention heads
    uint  n_kv_heads;         // number of KV heads (GQA when < n_heads; must divide n_heads)
    uint  head_dim;           // per-head dimension (must match the DK/DV template args)
    uint  kv_seq_len;         // number of valid KV positions currently in the cache
    uint  kv_capacity;        // physical ring-buffer capacity (positions per KV head)
    float scale;              // softmax scale applied to the Q*K^T scores
    uint  mask_type;          // 2 selects the sliding-window mask path; other values: no window
    uint  sliding_window;     // window length when mask_type == 2 (0 disables windowing)
    float softcap;            // NOTE(review): not read anywhere in this file — confirm whether
                              // the host still needs it for struct-layout compatibility only
    uint  nwg;                // number of workgroups cooperating per (head, query) row
    uint  ring_start;         // ring-buffer start slot used to map physical -> logical index
    float scale_factor_d512;  // for D=512 norm dequant
    uint  codebook_bits;      // 5, 6, or 8 (runtime selector)
};

// Reduce params — shared with flash_attn_vec.
struct FlashAttnVecReduceParamsHb {
    uint nrows;  // row count for the cross-workgroup reduce pass
                 // (NOTE(review): the reduce kernel itself is not in this file — confirm consumer)
};

// ---------------------------------------------------------------------------
// 5-bit codebook (32 centroids, byte-packed — same as hadamard_quantize_kv_fast.metal)
// Symmetric about zero (entry i mirrors entry 31-i); indexed by the 5 LSBs of
// each packed byte. Values must stay bit-exact with the quantizer's table.
// ---------------------------------------------------------------------------
constant float CODEBOOK_HB_5BIT[32] = {
    -3.2606790f, -2.6910589f, -2.3176743f, -2.0286608f,
    -1.7871646f, -1.5761599f, -1.3862739f, -1.2117410f,
    -1.0487242f, -0.8945114f, -0.7470884f, -0.6048936f,
    -0.4666676f, -0.3313550f, -0.1980377f, -0.0658849f,
     0.0658849f,  0.1980377f,  0.3313550f,  0.4666676f,
     0.6048936f,  0.7470884f,  0.8945114f,  1.0487242f,
     1.2117410f,  1.3862739f,  1.5761599f,  1.7871646f,
     2.0286608f,  2.3176743f,  2.6910589f,  3.2606790f,
};

// ---------------------------------------------------------------------------
// 6-bit codebook (64 centroids)
// Symmetric about zero (entry i mirrors entry 63-i); indexed by the 6 LSBs of
// each packed byte. Values must stay bit-exact with the quantizer's table.
// ---------------------------------------------------------------------------
constant float CODEBOOK_HB_6BIT[64] = {
    -3.6996161f, -3.1907215f, -2.8640626f, -2.6161277f,
    -2.4129324f, -2.2388464f, -2.0853192f, -1.9471373f,
    -1.8208742f, -1.7041502f, -1.5952401f, -1.4928497f,
    -1.3959804f, -1.3038428f, -1.2157998f, -1.1313277f,
    -1.0499889f, -0.9714118f, -0.8952766f, -0.8213046f,
    -0.7492492f, -0.6788902f, -0.6100285f, -0.5424819f,
    -0.4760822f, -0.4106724f, -0.3461048f, -0.2822386f,
    -0.2189392f, -0.1560761f, -0.0935225f, -0.0311537f,
     0.0311537f,  0.0935225f,  0.1560761f,  0.2189392f,
     0.2822386f,  0.3461048f,  0.4106724f,  0.4760822f,
     0.5424819f,  0.6100285f,  0.6788902f,  0.7492492f,
     0.8213046f,  0.8952766f,  0.9714118f,  1.0499889f,
     1.1313277f,  1.2157998f,  1.3038428f,  1.3959804f,
     1.4928497f,  1.5952401f,  1.7041502f,  1.8208742f,
     1.9471373f,  2.0853192f,  2.2388464f,  2.4129324f,
     2.6161277f,  2.8640626f,  3.1907215f,  3.6996161f,
};

// ---------------------------------------------------------------------------
// 8-bit codebook (256 centroids, Lloyd-Max N(0,1), iter-24)
// Computed via Lloyd-Max iteration to convergence (tol=1e-12).
// Symmetry error: 3.41e-10. Range: [-5.0652659, +5.0652659].
// Indexed by the full packed byte (no masking needed).
// Must match CODEBOOK_8BIT in hadamard_quantize_kv_fast.metal exactly.
// ---------------------------------------------------------------------------
constant float CODEBOOK_HB_8BIT[256] = {
    -5.0652659f, -4.6836997f, -4.4467193f, -4.2715508f,
    -4.1311907f, -4.0132856f, -3.9111092f, -3.8205780f,
    -3.7390194f, -3.6645851f, -3.5959415f, -3.5320936f,
    -3.4722785f, -3.4158977f, -3.3624729f, -3.3116156f,
    -3.2630056f, -3.2163758f, -3.1715011f, -3.1281899f,
    -3.0862780f, -3.0456229f, -3.0061011f, -2.9676040f,
    -2.9300362f, -2.8933131f, -2.8573596f, -2.8221086f,
    -2.7874999f, -2.7534795f, -2.7199985f, -2.6870129f,
    -2.6544825f, -2.6223710f, -2.5906452f, -2.5592748f,
    -2.5282321f, -2.4974918f, -2.4670306f, -2.4368270f,
    -2.4068614f, -2.3771157f, -2.3475732f, -2.3182184f,
    -2.2890372f, -2.2600165f, -2.2311440f, -2.2024086f,
    -2.1737998f, -2.1453081f, -2.1169245f, -2.0886408f,
    -2.0604493f, -2.0323430f, -2.0043154f, -1.9763603f,
    -1.9484722f, -1.9206458f, -1.8928763f, -1.8651592f,
    -1.8374904f, -1.8098662f, -1.7822828f, -1.7547372f,
    -1.7272261f, -1.6997469f, -1.6722970f, -1.6448739f,
    -1.6174755f, -1.5900996f, -1.5627445f, -1.5354084f,
    -1.5080897f, -1.4807869f, -1.4534986f, -1.4262237f,
    -1.3989610f, -1.3717093f, -1.3444678f, -1.3172356f,
    -1.2900118f, -1.2627956f, -1.2355865f, -1.2083838f,
    -1.1811868f, -1.1539951f, -1.1268081f, -1.0996255f,
    -1.0724469f, -1.0452718f, -1.0180999f, -0.9909310f,
    -0.9637647f, -0.9366008f, -0.9094390f, -0.8822793f,
    -0.8551212f, -0.8279648f, -0.8008098f, -0.7736561f,
    -0.7465035f, -0.7193520f, -0.6922014f, -0.6650517f,
    -0.6379027f, -0.6107544f, -0.5836067f, -0.5564596f,
    -0.5293129f, -0.5021667f, -0.4750208f, -0.4478753f,
    -0.4207301f, -0.3935852f, -0.3664405f, -0.3392960f,
    -0.3121517f, -0.2850076f, -0.2578636f, -0.2307198f,
    -0.2035761f, -0.1764324f, -0.1492888f, -0.1221453f,
    -0.0950019f, -0.0678584f, -0.0407151f, -0.0135717f,
     0.0135717f,  0.0407151f,  0.0678584f,  0.0950019f,
     0.1221453f,  0.1492888f,  0.1764324f,  0.2035761f,
     0.2307198f,  0.2578636f,  0.2850076f,  0.3121517f,
     0.3392960f,  0.3664405f,  0.3935852f,  0.4207301f,
     0.4478753f,  0.4750208f,  0.5021667f,  0.5293129f,
     0.5564596f,  0.5836067f,  0.6107544f,  0.6379027f,
     0.6650517f,  0.6922014f,  0.7193520f,  0.7465035f,
     0.7736561f,  0.8008098f,  0.8279648f,  0.8551212f,
     0.8822793f,  0.9094390f,  0.9366008f,  0.9637647f,
     0.9909310f,  1.0180999f,  1.0452718f,  1.0724469f,
     1.0996255f,  1.1268081f,  1.1539951f,  1.1811868f,
     1.2083838f,  1.2355865f,  1.2627956f,  1.2900118f,
     1.3172356f,  1.3444678f,  1.3717093f,  1.3989610f,
     1.4262237f,  1.4534986f,  1.4807869f,  1.5080897f,
     1.5354084f,  1.5627445f,  1.5900996f,  1.6174755f,
     1.6448739f,  1.6722970f,  1.6997469f,  1.7272261f,
     1.7547372f,  1.7822828f,  1.8098662f,  1.8374904f,
     1.8651592f,  1.8928763f,  1.9206458f,  1.9484722f,
     1.9763603f,  2.0043154f,  2.0323430f,  2.0604493f,
     2.0886408f,  2.1169245f,  2.1453081f,  2.1737998f,
     2.2024086f,  2.2311440f,  2.2600165f,  2.2890372f,
     2.3182184f,  2.3475732f,  2.3771157f,  2.4068614f,
     2.4368270f,  2.4670306f,  2.4974918f,  2.5282321f,
     2.5592748f,  2.5906452f,  2.6223710f,  2.6544825f,
     2.6870129f,  2.7199985f,  2.7534795f,  2.7874999f,
     2.8221086f,  2.8573596f,  2.8933131f,  2.9300362f,
     2.9676040f,  3.0061011f,  3.0456229f,  3.0862780f,
     3.1281899f,  3.1715011f,  3.2163758f,  3.2630056f,
     3.3116156f,  3.3624729f,  3.4158977f,  3.4722785f,
     3.5320936f,  3.5959415f,  3.6645851f,  3.7390194f,
     3.8205780f,  3.9111092f,  4.0132856f,  4.1311907f,
     4.2715508f,  4.4467193f,  4.6836997f,  5.0652659f,
};

// ---------------------------------------------------------------------------
// Inline dequant: look up a byte index in the codebook selected by `cbits`
// and scale the centroid by the pre-combined norm factor.
//
// `cbits` is a runtime value from params (not a compile-time template), so we
// branch on it here; the Metal compiler can constant-fold the branch when the
// value is uniform per-dispatch, and runtime selection is correct regardless.
//
// packed_pos:  pointer to this position's byte-packed data [head_dim bytes]
// coord:       coordinate index (0..head_dim-1)
// scale_norm:  pre-multiplied scale (norm * inv_sqrt_dk for D=256, norm/sf for D=512)
// cbits:       codebook_bits field from params (5, 6, or 8)
// ---------------------------------------------------------------------------
inline float dequant_hb_single(
    device const uint8_t *packed_pos,
    uint coord,
    float scale_norm,
    uint cbits
) {
    const uint raw = (uint)packed_pos[coord];
    float centroid;
    switch (cbits) {
        case 5u: centroid = CODEBOOK_HB_5BIT[raw & 0x1Fu]; break;  // 5 LSBs used
        case 6u: centroid = CODEBOOK_HB_6BIT[raw & 0x3Fu]; break;  // 6 LSBs used
        default: centroid = CODEBOOK_HB_8BIT[raw];         break;  // 8-bit: full byte
    }
    return scale_norm * centroid;
}

// Reconstruct a float4 from 4 consecutive byte-packed elements.
// coord_base must be a multiple of 4.
inline float4 dequant_hb_float4(
    device const uint8_t *packed_pos,
    uint coord_base,
    float scale_norm,
    uint cbits
) {
    float4 out;
    for (ushort lane = 0; lane < 4; ++lane) {
        out[lane] = dequant_hb_single(packed_pos, coord_base + lane, scale_norm, cbits);
    }
    return out;
}

// ---------------------------------------------------------------------------
// Main kernel: native HB (higher-bit) TQ flash attention vector.
//
// Same structure as flash_attn_vec_tq_impl but reads from byte-packed K/V.
// 5/6/8-bit controlled by params.codebook_bits at runtime.
//
// Norms layout:
//   D=256: [num_kv_heads, capacity]    f32 — 1 norm per position
//   D=512: [num_kv_heads, capacity, 2] f32 — 2 per-block norms per position
// ---------------------------------------------------------------------------
template<short DK, short DV>
kernel void flash_attn_vec_tq_hb_impl(
    constant FlashAttnVecTqHbParams  &params      [[buffer(0)]],
    device const float               *Q           [[buffer(1)]],
    device const uint8_t             *K_packed    [[buffer(2)]],  // byte-packed, [n_kv_heads, capacity, DK] u8
    device const float               *K_norms     [[buffer(3)]],  // [n_kv_heads, capacity] or [n_kv_heads, capacity, 2] f32
    device const uint8_t             *V_packed    [[buffer(4)]],  // byte-packed, [n_kv_heads, capacity, DV] u8
    device const float               *V_norms     [[buffer(5)]],  // same layout rule as K_norms
    device       float               *dst         [[buffer(6)]],
    threadgroup  half                *shmem       [[threadgroup(0)]],
    uint3  tgpig [[threadgroup_position_in_grid]],
    ushort tiisg [[thread_index_in_simdgroup]],
    ushort sgitg [[simdgroup_index_in_threadgroup]])
{
    // Derived compile-time sizes. PK/PV pad the head dims up to a multiple of
    // 128 halfs so the shared-memory regions keep a fixed alignment.
    constexpr short DK4 = DK / 4;   // K head dim in float4 units
    constexpr short DV4 = DV / 4;   // V head dim in float4 units
    constexpr short NW  = N_SIMDWIDTH;
    constexpr short NL  = NW;       // lane stride used for striding over head coords
    constexpr short PK  = PAD2(DK, 128);
    constexpr short PK4 = PK / 4;
    constexpr short PV  = PAD2(DV, 128);
    constexpr short PV4 = PV / 4;
    constexpr short SH  = 4 * C;  // 128 halfs = 64 floats of score/softmax scratch

    static_assert(DK % 32 == 0, "DK must be divisible by 32");
    static_assert(DV % 32 == 0, "DV must be divisible by 32");
    static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
    static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");

    // Grid mapping: x = query index, y = head index, z = cooperating workgroup.
    const uint NWG = params.nwg;
    const ushort iwg = tgpig[2] % NWG;
    const ushort iq2 = tgpig[1];  // head index
    const ushort iq1 = tgpig[0];  // query index (0 for decode)

    // GQA: map query head to KV head.
    const uint heads_per_kv = params.n_heads / params.n_kv_heads;
    const uint kv_head = iq2 / heads_per_kv;

    // Shared memory layout (same as flash_attn_vec_tq.metal):
    //   [0, PK)                     — Q as half4 (pre-rotated by caller)
    //   [PK, PK + SH)               — scratch for attention scores
    //   [PK + SH, PK + SH + 2*PV)   — output accumulator as float4
    // NOTE(review): ss and so4 are indexed by tiisg only, so the accumulation
    // phase assumes a single participating simdgroup (32 threads) per
    // threadgroup — confirm the host launch config matches.
    threadgroup half4  *sq4 = (threadgroup half4  *)(shmem);
    threadgroup float  *ss  = (threadgroup float  *)(shmem + PK);
    threadgroup float4 *so4 = (threadgroup float4 *)(shmem + PK + SH);

    // Load pre-rotated Q into shared memory as half4; zero-pad out to PK4.
    for (ushort i = tiisg; i < PK4; i += NW) {
        if (i < DK4) {
            float4 qval = *((device const float4 *)(Q + iq2 * DK + i * 4));
            sq4[i] = half4(qval);
        } else {
            sq4[i] = half4(0.0h);
        }
    }

    // Zero output accumulator. so4 is biased by tiisg so each lane owns a
    // strided slice of the accumulator; the bias is undone before write-out.
    so4 += tiisg;
    for (short i = 0; i < DV4 / NL; ++i) {
        so4[i * NL] = float4(0.0f);
    }

    // Zero scratch buffer (SH/4 = 32 floats — exactly the slots used below).
    for (ushort i = tiisg; i < SH / 4; i += NW) {
        ((threadgroup float *)(shmem + PK))[i] = 0.0f;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Online softmax state: running sum S and running max M.
    float S = 0.0f;
    float M = -FLT_MAX / 2;

    const ushort tx = tiisg;
    const uint kv_seq_len = params.kv_seq_len;
    const uint kv_capacity = params.kv_capacity;
    const uint ring_start = params.ring_start;
    const uint cbits = params.codebook_bits;
    const float sf_d512 = params.scale_factor_d512;
    const bool is_d512 = (DK > 256);  // D=512 path uses two per-block norms

    // Sliding-window: positions with logical index below this are masked out.
    uint window_start_logical = 0;
    if (params.mask_type == 2 && params.sliding_window > 0 && kv_seq_len > params.sliding_window) {
        window_start_logical = kv_seq_len - params.sliding_window;
    }

    threadgroup const half4 *pq4 = sq4 + tx;

    // Main loop over KV cache in chunks of C=32, strided across the NWG
    // cooperating workgroups.
    for (uint ic0 = iwg; ; ic0 += NWG) {
        uint ic = ic0 * C;
        if (ic >= kv_seq_len) break;

        // Compute mask for this chunk: one position per lane.
        {
            uint k_pos = ic + tx;
            float mask_val = 0.0f;
            if (k_pos >= kv_seq_len) {
                mask_val = -65504.0f;  // -inf surrogate within half range
            } else {
                // Map physical ring slot to logical sequence index.
                uint logical_idx = (k_pos - ring_start + kv_capacity) % kv_capacity;
                if (logical_idx >= kv_seq_len || logical_idx < window_start_logical) {
                    mask_val = -65504.0f;
                }
            }
            ss[tx] = mask_val;
        }

        // Each lane reads back only its own slot, so no barrier is needed here.
        if (simd_max(ss[tiisg]) <= -65504.0f) continue;

        // ---- Q * K^T ----
        {
            float mqk[C];
            const float inv_sqrt_dk = rsqrt(float(DK));

            for (short cc = 0; cc < C; ++cc) {
                uint kv_pos = ic + cc;
                if (kv_pos >= kv_seq_len) {
                    mqk[cc] = 0.0f;
                    continue;
                }

                // Dequant scale for K.
                float k_sn;
                if (is_d512) {
                    // D=512: per-block norms; block 0 = coords 0..255, block 1 = 256..511
                    // For K*Q^T we need both blocks. The dot product spans all DK coords.
                    // We compute the block-0 portion and block-1 portion separately,
                    // each with their own scale_norm.
                    // norm_base points to: [kv_head, kv_pos, 0..2] f32
                    device const float *knorm = K_norms + (kv_head * kv_capacity + kv_pos) * 2u;
                    // k_sn unused in this branch — handled in the inner loop below
                    (void)k_sn;
                    (void)inv_sqrt_dk;

                    device const uint8_t *k_base =
                        K_packed + (kv_head * kv_capacity + kv_pos) * DK;

                    float partial = 0.0f;
                    // Block 0: coords 0..255
                    // Each thread tx covers elements (tx + ii*NL)*4 .. (tx + ii*NL)*4+3
                    // for ii in [0..(DK/2)/4/NL). This mirrors the D=256 striding pattern.
                    {
                        float sn0 = knorm[0] / sf_d512;
                        for (short ii = 0; ii < (DK/2) / 4 / NL; ++ii) {
                            uint coord = (uint)(tx + ii * NL) * 4u;
                            float4 k_val = dequant_hb_float4(k_base, coord, sn0, cbits);
                            partial += dot(k_val, float4(pq4[ii * NL]));
                        }
                    }
                    // Block 1: coords 256..511
                    {
                        float sn1 = knorm[1] / sf_d512;
                        const uint blk1_start = DK / 2;
                        for (short ii = 0; ii < (DK/2) / 4 / NL; ++ii) {
                            uint coord = blk1_start + (uint)(tx + ii * NL) * 4u;
                            // pq4 index: (DK4/2/NL + ii)*NL == DK4/2 + ii*NL since NL | DK4/2,
                            // i.e. the block-1 half of Q with the same per-lane striding.
                            float4 k_val = dequant_hb_float4(k_base, coord, sn1, cbits);
                            partial += dot(k_val, float4(pq4[(DK4/2/NL + ii) * NL]));
                        }
                    }
                    // Lanes hold disjoint coord slices; simd_sum yields the full dot.
                    mqk[cc] = simd_sum(partial);
                } else {
                    // D=256: single norm per position.
                    float k_norm_val = K_norms[kv_head * kv_capacity + kv_pos];
                    k_sn = k_norm_val * inv_sqrt_dk;

                    device const uint8_t *k_base =
                        K_packed + (kv_head * kv_capacity + kv_pos) * DK + tx * 4u;

                    float partial = 0.0f;
                    for (short ii = 0; ii < DK4 / NL; ++ii) {
                        float4 k_val = dequant_hb_float4(k_base, (uint)(ii * NL) * 4u, k_sn, cbits);
                        partial += dot(k_val, float4(pq4[ii * NL]));
                    }
                    mqk[cc] = simd_sum(partial);
                }
            }

            // Every lane computed all C scores; lane tx commits score tx,
            // adding the mask value previously stored in ss[tx].
            ss[tx] = fma(mqk[tx], params.scale, ss[tx]);
        }

        simdgroup_barrier(mem_flags::mem_threadgroup);

        // ---- Online softmax ----
        // Standard flash-attention rescale: update the running max M, rescale
        // the running sum S and the output accumulator by exp(m_old - M).
        {
            const float m_old = M;
            const float s_new = ss[tiisg];
            M = simd_max(max(M, s_new));
            const float ms = exp(m_old - M);
            const float vs = exp(s_new - M);
            S = S * ms + simd_sum(vs);
            ss[tiisg] = vs;  // scratch now holds per-position softmax weights
            for (short ii = 0; ii < DV4 / NL; ++ii) {
                so4[ii * NL] *= ms;
            }
        }

        simdgroup_barrier(mem_flags::mem_threadgroup);

        // ---- O = O + softmax_weights * V ----
        {
            float4 lo[DV4 / NL];  // per-lane accumulator kept in registers for this chunk
            for (short ii = 0; ii < DV4 / NL; ++ii) lo[ii] = float4(0.0f);

            const float inv_sqrt_dv = rsqrt(float(DV));

            for (short cc = 0; cc < C; ++cc) {
                uint kv_pos = ic + cc;
                if (kv_pos >= kv_seq_len) continue;

                if (is_d512) {
                    device const float *vnorm = V_norms + (kv_head * kv_capacity + kv_pos) * 2u;
                    device const uint8_t *v_base =
                        V_packed + (kv_head * kv_capacity + kv_pos) * DV;
                    float w = ss[cc];  // softmax weight for this position (cross-lane read; barriered above)

                    // Block 0: coords 0..255
                    // Same striding pattern as D=256 and K D=512 above.
                    // Weight is folded into the dequant scale to save a multiply.
                    float sn0 = vnorm[0] / sf_d512 * w;
                    for (short ii = 0; ii < (DV/2) / 4 / NL; ++ii) {
                        uint coord = (uint)(tx + ii * NL) * 4u;
                        lo[ii] += dequant_hb_float4(v_base, coord, sn0, cbits);
                    }
                    // Block 1: coords 256..511
                    float sn1 = vnorm[1] / sf_d512 * w;
                    for (short ii = 0; ii < (DV/2) / 4 / NL; ++ii) {
                        uint coord = (uint)(DV/2) + (uint)(tx + ii * NL) * 4u;
                        lo[DV4/2/NL + ii] += dequant_hb_float4(v_base, coord, sn1, cbits);
                    }
                } else {
                    float v_norm_val = V_norms[kv_head * kv_capacity + kv_pos];
                    float v_sw = v_norm_val * inv_sqrt_dv * ss[cc];
                    device const uint8_t *v_base =
                        V_packed + (kv_head * kv_capacity + kv_pos) * DV + tx * 4u;

                    for (short ii = 0; ii < DV4 / NL; ++ii) {
                        lo[ii] += dequant_hb_float4(v_base, (uint)(ii * NL) * 4u, v_sw, cbits);
                    }
                }
            }

            // Fold the chunk's register accumulator into shared memory once.
            for (short ii = 0; ii < DV4 / NL; ++ii) {
                so4[ii * NL] += lo[ii];
            }
        }
    }

    // Store M and S for the reduce kernel.
    // NOTE(review): these threadgroup slots are not read again in this kernel,
    // and the NWG>1 path below writes S/M to device memory from registers —
    // this store looks vestigial; confirm against the reduce kernel.
    if (tiisg == 0) {
        ss[0] = S;
        ss[1] = M;
    }

    so4 -= tiisg;  // undo the per-lane bias applied before accumulation
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ---- Write output ----
    // Layout: outputs from the NWG cooperating workgroups are interleaved per
    // float4 so the reduce kernel can combine them; for NWG==1 the result is
    // normalized by 1/S here and no reduce pass is needed.
    if (sgitg == 0) {
        const int64_t nrows = params.n_heads;
        const int64_t rid = iq2 + (int64_t)iq1 * params.n_heads;
        const uint NWG_val = params.nwg;
        const float inv_S = (NWG_val == 1) ? ((S == 0.0f) ? 0.0f : 1.0f / S) : 1.0f;

        device float4 *dst4 = (device float4 *)dst;
        for (ushort i = tiisg; i < DV4; i += NW) {
            dst4[rid * DV4 * NWG_val + NWG_val * i + iwg] = so4[i] * inv_S;
        }

        // NWG>1: append (S, M) pairs after all partial outputs so the reduce
        // kernel can renormalize across workgroups.
        if (NWG_val > 1 && tiisg == 0) {
            device float *dst1 = (device float *)dst + nrows * DV * NWG_val;
            dst1[rid * (2 * NWG_val) + 2 * iwg + 0] = S;
            dst1[rid * (2 * NWG_val) + 2 * iwg + 1] = M;
        }
    }
}

// --------------------------------------------------------------------------
// Kernel instantiations
// --------------------------------------------------------------------------

// Canonical kernel signature type, used so both instantiations share one declaration.
typedef decltype(flash_attn_vec_tq_hb_impl<256, 256>) flash_attn_vec_tq_hb_t;

// DK = DV = 256: single norm per KV position.
template [[host_name("flash_attn_vec_tq_hb_dk256")]]
kernel flash_attn_vec_tq_hb_t flash_attn_vec_tq_hb_impl<256, 256>;

// DK = DV = 512: two per-block norms per KV position (see norms layout above).
template [[host_name("flash_attn_vec_tq_hb_dk512")]]
kernel flash_attn_vec_tq_hb_t flash_attn_vec_tq_hb_impl<512, 512>;