ferrum-kernels 0.7.7

Unified compute kernels (CUDA/Metal/CPU) and model runner for Ferrum inference
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
#include <metal_stdlib>
#include <metal_simdgroup_matrix>
using namespace metal;

// ── Fused Flash Attention for f32 ──────────────────────────────────────
//
// Single kernel: QK^T * scale + causal_mask → online softmax → attn@V
// No intermediate buffers. All accumulation in registers/threadgroup memory.
//
// Layout:
//   Q: [batch, num_heads, q_len, head_dim]  (contiguous)
//   K: [batch, num_kv_heads, kv_len, head_dim]
//   V: [batch, num_kv_heads, kv_len, head_dim]
//   O: [batch, num_heads, q_len, head_dim]
//
// Grid: (q_len, num_heads, batch) — one threadgroup per query position per head
// Each threadgroup processes one row of the attention matrix.

// Parameter block for the contiguous-KV attention kernels in this file
// (they bind it at buffer(4)).
// NOTE(review): the host-side struct presumably has to match this field
// order and the 4-byte field sizes exactly — confirm against the caller
// before adding, removing or reordering fields.
struct FlashAttnParams {
    int batch;        // batch size; the kernels visible here take the batch index from the grid (tgpig.z) and do not read this field
    int num_heads;    // number of query heads
    int num_kv_heads; // number of K/V heads; kernels map query head -> KV head as hi / (num_heads / num_kv_heads)
    int q_len;        // query positions per head
    int kv_len;       // number of KV positions that may be attended
    int head_dim;     // floats per (position, head) row
    float scale;      // multiplier applied to each Q·K dot product before softmax
    int causal;       // 0 or 1
    int pos_offset;   // added to the query index when forming the causal bound (pos_offset + qi + 1)
    int kv_seq_stride; // seq dimension stride for K/V (= kv_len for contiguous, = max_len for paged cache); values <= 0 fall back to kv_len
    int sliding_window; // 0 = full causal, >0 = attend only to last `w` KV positions (Mistral v0.1, Gemma); only honoured when causal != 0
};

// Block size for KV processing — flash_attn_f32 walks the attended KV range
// in chunks of this many positions. The chunking only shapes the loop nest;
// every position inside a chunk is still processed one at a time.
constant int BLOCK_KV = 32;

// Legacy fused attention (f32): one query row per threadgroup.
//
// Dispatch: grid = (q_len, num_heads, batch). The body is written for ONE
// 32-lane simdgroup per threadgroup: lane `tiisg` strides head_dim by 32 and
// `simd_sum` produces the full dot product. `sgitg` is accepted but unused.
//
// Buffers:
//   Q, O : [batch, num_heads, q_len, head_dim]
//   K, V : [batch, num_kv_heads, kv_seq_stride (or kv_len), head_dim]
//
// Limits:
//   • head_dim <= 128 — the per-lane accumulator holds 4 values. Larger
//     head dims are clamped so the accumulator is never indexed out of
//     bounds, but output columns >= 128 are then left unwritten.
//   • head_dim does not have to be a multiple of 32 or of 4; partial lanes
//     are handled by the `idx < d` guards and the scalar dot-product path.
kernel void flash_attn_f32(
    device const float* Q       [[buffer(0)]],
    device const float* K       [[buffer(1)]],
    device const float* V       [[buffer(2)]],
    device       float* O       [[buffer(3)]],
    constant FlashAttnParams& p [[buffer(4)]],
    uint3 tgpig [[threadgroup_position_in_grid]],   // (q_pos, head, batch)
    uint  tiisg [[thread_index_in_simdgroup]],       // 0..31
    uint  sgitg [[simdgroup_index_in_threadgroup]])  // simdgroup index (unused)
{
    const int qi    = tgpig.x;  // which query position
    const int hi    = tgpig.y;  // which head
    const int bi    = tgpig.z;  // which batch

    const int kv_hi = hi / (p.num_heads / p.num_kv_heads); // GQA: KV head index
    const int d     = p.head_dim;
    const int sk    = p.kv_len;

    // Attended KV range [attend_start, attend_end): causal upper bound and
    // optional sliding-window lower bound.
    const int attend_end = p.causal ? min(p.pos_offset + qi + 1, sk) : sk;
    const int attend_start = (p.causal && p.sliding_window > 0)
        ? max(0, attend_end - p.sliding_window)
        : 0;

    // Row pointers.
    device const float* q_row = Q + ((bi * p.num_heads + hi) * p.q_len + qi) * d;
    const int kv_stride = (p.kv_seq_stride > 0) ? p.kv_seq_stride : sk;
    device const float* k_base = K + (bi * p.num_kv_heads + kv_hi) * kv_stride * d;
    device const float* v_base = V + (bi * p.num_kv_heads + kv_hi) * kv_stride * d;
    device       float* o_row = O + ((bi * p.num_heads + hi) * p.q_len + qi) * d;

    // Online softmax state. Every lane holds the same values because the
    // score comes out of simd_sum.
    float M = -INFINITY;  // running max
    float S = 0.0f;       // running sum of exp

    // Lane `tiisg` owns output columns tiisg, tiisg + 32, tiisg + 64, ...
    // Ceil-divide so a head_dim that is not a multiple of 32 still has its
    // trailing columns accumulated and written (the `idx < d` guards below
    // mask the lanes that fall off the end). Clamp to the accumulator size
    // so an oversized head_dim cannot index past `acc`.
    const int max_elems_per_thread = 4;
    const int elems_per_thread = min((d + 31) / 32, max_elems_per_thread);

    float acc[4] = {0.0f, 0.0f, 0.0f, 0.0f};

    // float4 loads need 16-byte aligned rows, which only holds for every
    // row when head_dim is a multiple of 4. Otherwise take the scalar path
    // for the whole row instead of issuing misaligned vector loads.
    const int d4 = ((d & 3) == 0) ? d : 0;

    // Process KV in blocks, starting at `attend_start` so sliding-window
    // positions that are too old don't contribute.
    for (int kv_start = attend_start; kv_start < attend_end; kv_start += BLOCK_KV) {
        const int kv_end = min(kv_start + BLOCK_KV, attend_end);

        for (int ki = kv_start; ki < kv_end; ++ki) {
            device const float* k_row = k_base + ki * d;
            device const float* v_row = v_base + ki * d;

            // Q[qi] · K[ki]: per-lane partial sums, then a simdgroup reduction.
            float dot_acc = 0.0f;
            for (int j = tiisg * 4; j < d4; j += 32 * 4) {
                float4 q_v = *((device const float4 *)(q_row + j));
                float4 k_v = *((device const float4 *)(k_row + j));
                dot_acc += metal::dot(q_v, k_v);
            }
            // Scalar path: covers the whole row when head_dim % 4 != 0 and is
            // empty otherwise.
            for (int j = d4 + tiisg; j < d; j += 32) {
                dot_acc += q_row[j] * k_row[j];
            }
            const float dot_v = simd_sum(dot_acc) * p.scale;

            // Online softmax update.
            const float old_M = M;
            M = max(M, dot_v);
            const float exp_diff = exp(old_M - M);  // rescale for what is already accumulated
            const float exp_val  = exp(dot_v - M);  // weight of this KV position

            S = S * exp_diff + exp_val;

            // O = O * exp_diff + exp_val * V[ki]
            for (int j = 0; j < elems_per_thread; ++j) {
                const int idx = tiisg + j * 32;
                if (idx < d) {
                    acc[j] = acc[j] * exp_diff + exp_val * v_row[idx];
                }
            }
        }
    }

    // Final normalization: O = O / S. An empty attended range leaves S == 0
    // and writes zeros.
    const float inv_S = (S > 0.0f) ? (1.0f / S) : 0.0f;
    for (int j = 0; j < elems_per_thread; ++j) {
        const int idx = tiisg + j * 32;
        if (idx < d) {
            o_row[idx] = acc[j] * inv_S;
        }
    }
}

// ── Q-tiled flash attention with simdgroup_matmul (head_dim=128, f32) ────
//
// Mirrors llama.cpp's kernel_flash_attn_ext_impl shape:
//   Q_TILE = 8 query rows per threadgroup
//   NSG    = 4 simdgroups per threadgroup (128 threads)
//   NQ     = Q_TILE / NSG = 2 query rows per simdgroup
//   C      = 32 KV columns per inner chunk (4 simdgroups × 8 cols each)
//   DK=DV  = 128 head dimension
//
// Each threadgroup processes one (q_tile, head, batch) and walks the
// full KV range for that head. The 4 simdgroups cooperate:
//   • QK^T   — each simdgroup computes one [8,8] tile via simdgroup_matmul
//   • softmax — each simdgroup handles its NQ rows
//   • P @ V   — 16 output [8,8] tiles split across 4 simdgroups (NO=4 each)
//
// Restrictions (caller picks the legacy kernel when violated):
//   • head_dim == 128
//   • sliding_window == 0
//   • num_heads % num_kv_heads == 0  (any GQA ratio works)
//   • q_len divisible by 8 *or* the trailing tile is padded with zero queries
//
// Total threadgroup memory: 8*128 (sq) + 8*128 (so) + 8*32 (ss) = 2304 f32
// = 9.0 KB — well within Apple7's 32 KB per-threadgroup limit.

// Compile-time shape of flash_attn_q_tiled_f32. The derived values
// (FA_NQ, FA_DK8, FA_DV4, FA_DV8, FA_NO) are written out by hand and must
// be updated together with the values they are derived from.
constant int Q_TILE_R = 8;        // query rows per threadgroup
constant int FA_NSG   = 4;        // simdgroups per threadgroup
constant int FA_NQ    = 2;        // Q_TILE_R / FA_NSG — query rows per simdgroup
constant int FA_C     = 32;       // KV positions per chunk
constant int FA_DK    = 128;      // K head dimension
constant int FA_DK8   = 16;       // FA_DK / 8
constant int FA_DV    = 128;      // V head dimension
constant int FA_DV4   = 32;       // FA_DV / 4
constant int FA_DV8   = 16;       // FA_DV / 8
constant int FA_NO    = 4;        // FA_DV8 / FA_NSG — output [8,8] tiles per simdgroup

// Q-tiled attention (f32, head_dim fixed at FA_DK = FA_DV = 128).
//
// Dispatch: one threadgroup per (q_tile, head, batch) with FA_NSG = 4
// simdgroups. Row ownership below is `j = jj * FA_NSG + sgitg`, so the 8
// tile rows are only covered when sgitg ranges over 0..3.
//
// Does not read p.head_dim or p.sliding_window; the caller is expected to
// route other shapes to a different kernel.
//
// NOTE(review): stages 3a and 3c load a full 32-row chunk of K and V even
// when the chunk extends past kv_len. The scores of those rows are masked,
// but the loads still happen — presumably the caller guarantees the K/V
// allocation covers the next multiple of 32 rows. Confirm; also note that
// a masked probability of 0 multiplied by a non-finite V value is NaN, so
// the padding should hold finite data.
kernel void flash_attn_q_tiled_f32(
    device const float* Q       [[buffer(0)]],
    device const float* K       [[buffer(1)]],
    device const float* V       [[buffer(2)]],
    device       float* O       [[buffer(3)]],
    constant FlashAttnParams& p [[buffer(4)]],
    uint3 tgpig [[threadgroup_position_in_grid]],
    uint  tiisg [[thread_index_in_simdgroup]],
    uint  sgitg [[simdgroup_index_in_threadgroup]])
{
    const int qtile = int(tgpig.x);
    const int hi    = int(tgpig.y);
    const int bi    = int(tgpig.z);
    const int iq1   = qtile * Q_TILE_R;   // first query row of this tile

    // Whole tile is past the end of the query range. The condition is uniform
    // across the threadgroup, so no thread is left waiting at a barrier.
    if (iq1 >= p.q_len) return;

    const int kv_hi    = hi / (p.num_heads / p.num_kv_heads);
    const int kv_stride = (p.kv_seq_stride > 0) ? p.kv_seq_stride : p.kv_len;

    device const float* q_base = Q + ((bi * p.num_heads + hi) * p.q_len + iq1) * FA_DK;
    device const float* k_base = K + (bi * p.num_kv_heads + kv_hi) * kv_stride * FA_DK;
    device const float* v_base = V + (bi * p.num_kv_heads + kv_hi) * kv_stride * FA_DV;
    device       float* o_base = O + ((bi * p.num_heads + hi) * p.q_len + iq1) * FA_DV;

    // Threadgroup memory — laid out contiguously
    threadgroup float sq[Q_TILE_R * FA_DK];   // queries
    threadgroup float so[Q_TILE_R * FA_DV];   // running output (post rescale)
    threadgroup float ss[Q_TILE_R * FA_C];    // attention scores / probabilities

    // 1. Load Q tile into shared memory; pad rows beyond q_len with zero.
    //    Each simdgroup copies its own FA_NQ rows; 32 lanes × float4 cover
    //    the 128 floats of a row in one pass.
    for (int jj = 0; jj < FA_NQ; ++jj) {
        const int j = jj * FA_NSG + int(sgitg);
        const int q_pos = iq1 + j;
        if (q_pos < p.q_len) {
            device const float4* q_row4 = (device const float4 *)(q_base + j * FA_DK);
            threadgroup float4* sq4 = (threadgroup float4 *)(sq + j * FA_DK);
            for (int i = int(tiisg); i < FA_DK / 4; i += 32) {
                sq4[i] = q_row4[i];
            }
        } else {
            threadgroup float4* sq4 = (threadgroup float4 *)(sq + j * FA_DK);
            for (int i = int(tiisg); i < FA_DK / 4; i += 32) {
                sq4[i] = float4(0.0f);
            }
        }
    }

    // 2. Zero output accumulator.
    for (int jj = 0; jj < FA_NQ; ++jj) {
        const int j = jj * FA_NSG + int(sgitg);
        threadgroup float4* so4 = (threadgroup float4 *)(so + j * FA_DV);
        for (int i = int(tiisg); i < FA_DV / 4; i += 32) {
            so4[i] = float4(0.0f);
        }
    }

    // Per-simdgroup running max and sum (covers FA_NQ rows).
    // Every lane of the simdgroup keeps an identical copy, since the
    // updates below come from simd_max / simd_sum.
    float M[FA_NQ];
    float S[FA_NQ];
    for (int jj = 0; jj < FA_NQ; ++jj) {
        M[jj] = -INFINITY;
        S[jj] = 0.0f;
    }

    // sq and so must be fully written before any simdgroup reads them.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Upper bound for the chunk loop. Causal: can stop after the last row's
    // attend_end (= pos_offset + iq1 + Q_TILE_R since rows j increase).
    int attend_end_max = p.kv_len;
    if (p.causal) {
        attend_end_max = min(p.pos_offset + iq1 + Q_TILE_R, p.kv_len);
    }

    // 3. Walk KV in C=32 chunks.
    for (int ic = 0; ic < attend_end_max; ic += FA_C) {
        // ── 3a. QK^T: each simdgroup writes one [8,8] tile to ss. ──
        // Simdgroup `sgitg` handles KV rows ic + 8*sgitg .. ic + 8*sgitg + 7
        // and accumulates over the head dimension in 16 steps of 8.
        {
            device const float* pk = k_base + (ic + 8 * int(sgitg)) * FA_DK;
            threadgroup const float* pq = sq;
            threadgroup       float* ps = ss + 8 * int(sgitg);

            simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.0f);

            for (int i = 0; i < FA_DK8; ++i) {
                simdgroup_float8x8 mk;
                simdgroup_float8x8 mq;
                // K is loaded transposed so the product is Q · K^T.
                simdgroup_load(mk, pk + 8 * i, FA_DK, ulong2(0, 0), true);
                simdgroup_load(mq, pq + 8 * i, FA_DK);
                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
            }

            simdgroup_store(mqk, ps, FA_C, ulong2(0, 0), false);
        }

        // All four [8,8] score tiles must be in ss before softmax reads rows.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 3b. Online softmax: each simdgroup handles its FA_NQ rows. ──
        for (int jj = 0; jj < FA_NQ; ++jj) {
            const int j = jj * FA_NSG + int(sgitg);
            const int q_pos = iq1 + j;
            const int k_pos = ic + int(tiisg);

            // Load this lane's score (FA_C equals the 32-lane simdgroup
            // width, so one element per lane).
            float s = ss[j * FA_C + int(tiisg)];
            s *= p.scale;

            // Mask: out-of-range Q row, out-of-range K column, or causal.
            bool keep = (q_pos < p.q_len) && (k_pos < p.kv_len);
            if (p.causal) {
                const int row_end = min(p.pos_offset + q_pos + 1, p.kv_len);
                keep = keep && (k_pos < row_end);
            }
            if (!keep) {
                s = -INFINITY;
            }

            const float old_M = M[jj];
            const float row_max = simd_max(s);
            const float new_M = max(old_M, row_max);

            // Guard against the "all -INF" case (e.g. early causal rows):
            // exp(-INF - (-INF)) would be NaN, so force the factors to 0.
            const float ms = isfinite(old_M) ? exp(old_M - new_M) : 0.0f;
            const float vs = isfinite(s)     ? exp(s - new_M)     : 0.0f;

            S[jj] = ms * S[jj] + simd_sum(vs);

            // Persist post-softmax probability for the P@V stage.
            ss[j * FA_C + int(tiisg)] = vs;

            // Rescale this row's running output by ms (each lane = FA_DV/32 elems).
            threadgroup float4* so4 = (threadgroup float4 *)(so + j * FA_DV);
            for (int i = int(tiisg); i < FA_DV / 4; i += 32) {
                so4[i] = so4[i] * ms;
            }

            M[jj] = new_M;
        }

        // Probabilities in ss and the rescaled so rows are written per row
        // owner but read per column owner in 3c.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 3c. O += P @ V via simdgroup_matmul. ──
        // Each simdgroup owns FA_NO=4 output [8,8] tiles at column offsets
        // 8*sgitg, 8*(sgitg+NSG), 8*(sgitg+2*NSG), 8*(sgitg+3*NSG). Loading
        // the running O tiles into registers, accumulating, then storing
        // back avoids touching the other simdgroups' data.
        {
            simdgroup_float8x8 lo[FA_NO];
            {
                threadgroup const float* pso = so + 8 * int(sgitg);
                for (int ii = 0; ii < FA_NO; ++ii) {
                    simdgroup_load(lo[ii], pso, FA_DV, ulong2(0, 0), false);
                    pso += 8 * FA_NSG;
                }
            }

            // pv walks the chunk's 32 V rows in four groups of 8.
            device const float* pv = v_base + ic * FA_DV;
            for (int cc = 0; cc < FA_C / 8; ++cc) {
                // [8 query rows, 8 KV columns] slice of the probabilities.
                simdgroup_float8x8 vs;
                simdgroup_load(vs, ss + 8 * cc, FA_C, ulong2(0, 0), false);

                for (int ii = 0; ii < FA_NO; ++ii) {
                    simdgroup_float8x8 mv;
                    simdgroup_load(mv,
                                   pv + 8 * int(sgitg) + 8 * FA_NSG * ii,
                                   FA_DV, ulong2(0, 0), false);
                    simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]);
                }

                pv += 8 * FA_DV;
            }

            {
                threadgroup float* pso = so + 8 * int(sgitg);
                for (int ii = 0; ii < FA_NO; ++ii) {
                    simdgroup_store(lo[ii], pso, FA_DV, ulong2(0, 0), false);
                    pso += 8 * FA_NSG;
                }
            }
        }

        // ss is overwritten by the next chunk's 3a, and so is re-read by
        // row owners in the next 3b / in step 4.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // 4. Normalize and write O = so / S to global memory.
    //    Padded rows (q_pos >= q_len) are skipped; a row with S == 0 is
    //    written as zeros.
    for (int jj = 0; jj < FA_NQ; ++jj) {
        const int j = jj * FA_NSG + int(sgitg);
        const int q_pos = iq1 + j;
        if (q_pos >= p.q_len) continue;

        const float inv_S = (S[jj] > 0.0f) ? (1.0f / S[jj]) : 0.0f;
        device float4* o_row4 = (device float4 *)(o_base + j * FA_DV);
        threadgroup const float4* so4 = (threadgroup const float4 *)(so + j * FA_DV);
        for (int i = int(tiisg); i < FA_DV / 4; i += 32) {
            o_row4[i] = so4[i] * inv_S;
        }
    }
}

// ── SDPA vector decode (m=1, head_dim=128) — MLX-style wide threadgroup ──
//
// Ported from MLX's sdpa_vector kernel (the same kernel candle-metal-kernels
// uses; mistral.rs reaches it via candle's `ops::sdpa`). The legacy
// `flash_attn_f32` above uses 32 threads (one simdgroup) per
// (head, query) — for Llama-3.1-8B that's 32 active threads × 32 q-heads
// = 1024 active threads total, ~3% of M1 Max's ~32K-thread concurrent
// capacity. KV positions are walked sequentially within that single
// simdgroup, so most of the GPU sits idle during decode m=1.
//
// This kernel widens the threadgroup to 32 simdgroups × 32 threads =
// 1024 threads, one TG per (head, batch). The 32 simdgroups process
// distinct KV positions in parallel; each thread within a simdgroup
// owns elem_per_thread = head_dim/32 = 4 elements of Q/K/V/O. After
// the KV loop, simdgroups merge their partial (max, sumexp, output)
// via threadgroup memory using the same online-softmax rescaling
// trick the inner loop uses.
//
// Restrictions (caller picks legacy when violated):
//   • q_len == 1                    (decode hot path; prefill stays Q-tiled)
//   • head_dim == 128               (4 elements/thread × 32 threads)
//   • sliding_window == 0           (handled later if needed)
//   • num_heads % num_kv_heads == 0 (standard GQA)
//
// Threadgroup memory:
//   outputs[BN * head_dim] = 32 * 128 * 4 = 16 KB
//   max_scores[BN]         = 128 B
//   sum_exp_scores[BN]     = 128 B
//   total ≈ 16.25 KB — within Apple7's 32 KB cap.

// Compile-time shape shared by flash_attn_decode_f32 and
// flash_attn_decode_paged_f32.
constant int SDPA_BN = 32;        // simdgroups per threadgroup (also the KV stride of each simdgroup's walk)
constant int SDPA_BD = 32;        // simdgroup width (Apple GPU)
constant int SDPA_D  = 128;       // head_dim
constant int SDPA_EPT = SDPA_D / SDPA_BD; // = 4 contiguous elements of Q/K/V/O per thread

// SDPA vector decode (f32, contiguous KV): one threadgroup per (head, batch).
//
// Dispatch contract:
//   • SDPA_BN (32) simdgroups of SDPA_BD (32) threads per threadgroup. The
//     combine stage reads one (max, sum) slot per simdgroup, so all 32
//     slots must have been written by a live simdgroup.
//   • The per-thread slices assume head_dim == SDPA_D (128).
//   • Only query row 0 of each head is read and written (tgpig.x is ignored).
//
// Simdgroup `sgitg` walks KV positions sgitg, sgitg + 32, ... keeping its
// own online-softmax state; the partials are then merged through
// threadgroup memory. If no KV position is attended (attend_end <= 0) the
// output row is written as zeros, matching the other kernels in this file.
kernel void flash_attn_decode_f32(
    device const float* Q       [[buffer(0)]],
    device const float* K       [[buffer(1)]],
    device const float* V       [[buffer(2)]],
    device       float* O       [[buffer(3)]],
    constant FlashAttnParams& p [[buffer(4)]],
    uint3  tgpig [[threadgroup_position_in_grid]],   // (1, head, batch)
    uint   sgitg [[simdgroup_index_in_threadgroup]], // 0..31 — which KV-stripe
    uint   tiisg [[thread_index_in_simdgroup]])      // 0..31 — which D slice
{
    const int hi    = int(tgpig.y);
    const int bi    = int(tgpig.z);
    const int kv_hi = hi / (p.num_heads / p.num_kv_heads);
    const int d     = p.head_dim;
    const int sk    = p.kv_len;

    // Causal upper bound for q_len=1: attend to positions [0, pos_offset+1)
    // (the new token can see itself once it's been written into the cache).
    const int attend_end = p.causal ? min(p.pos_offset + 1, sk) : sk;

    // Pointers — offset to the per-thread element slice. Each thread owns
    // SDPA_EPT contiguous elements at offset `tiisg * SDPA_EPT`.
    const int kv_stride = (p.kv_seq_stride > 0) ? p.kv_seq_stride : sk;
    device const float* q_row = Q + ((bi * p.num_heads    + hi   ) * p.q_len) * d
                                  + tiisg * SDPA_EPT;
    device const float* k_base = K + (bi * p.num_kv_heads + kv_hi) * kv_stride * d;
    device const float* v_base = V + (bi * p.num_kv_heads + kv_hi) * kv_stride * d;
    device       float* o_row = O + ((bi * p.num_heads    + hi   ) * p.q_len) * d;

    // Per-thread Q slice (pre-scaled so the dot product is already the
    // scaled score) and running output slice.
    float q[SDPA_EPT];
    float o_acc[SDPA_EPT];
    for (int i = 0; i < SDPA_EPT; ++i) {
        q[i]     = p.scale * q_row[i];
        o_acc[i] = 0.0f;
    }

    float max_score = -INFINITY;
    float sum_exp   = 0.0f;

    // KV loop — each simdgroup walks positions { sgitg, sgitg+BN, sgitg+2*BN, ... }.
    for (int ki = int(sgitg); ki < attend_end; ki += SDPA_BN) {
        device const float* k_row = k_base + ki * d + tiisg * SDPA_EPT;
        device const float* v_row = v_base + ki * d + tiisg * SDPA_EPT;

        // Per-thread partial dot over this thread's slice of K, then reduce
        // across the 32 lanes: `score` is the full scaled Q·K for `ki`.
        float dot_acc = 0.0f;
        for (int j = 0; j < SDPA_EPT; ++j) {
            dot_acc += q[j] * k_row[j];
        }
        const float score = simd_sum(dot_acc);

        // Online softmax update. All 32 lanes see the same `score`, so they
        // update identically.
        const float new_max = max(max_score, score);
        const float factor  = exp(max_score - new_max);
        const float exp_sc  = exp(score      - new_max);

        max_score = new_max;
        sum_exp   = sum_exp * factor + exp_sc;

        // Fold this thread's slice of V into the accumulator.
        for (int j = 0; j < SDPA_EPT; ++j) {
            o_acc[j] = o_acc[j] * factor + exp_sc * v_row[j];
        }
    }

    // ── Cross-simdgroup combine ───────────────────────────────────────
    // Each simdgroup has its own partial (max_score, sum_exp, o_acc) over
    // the KV positions it walked. Merge across the 32 simdgroups using
    // threadgroup memory + the same online-softmax rescaling trick.

    threadgroup float outputs[SDPA_BN * SDPA_D];   // [BN][D]
    threadgroup float max_scores[SDPA_BN];
    threadgroup float sum_exp_scores[SDPA_BN];

    // Stash this simdgroup's partial output (one row of D floats per SG).
    threadgroup float* my_row = outputs + sgitg * SDPA_D + tiisg * SDPA_EPT;
    for (int j = 0; j < SDPA_EPT; ++j) {
        my_row[j] = o_acc[j];
    }
    // The leader of each simdgroup publishes its scalar (max, sum_exp).
    if (tiisg == 0) {
        max_scores[sgitg]     = max_score;
        sum_exp_scores[sgitg] = sum_exp;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Simdgroup 0 reduces the 32 (max, sum_exp) pairs: lane `tiisg` handles
    // the pair published by simdgroup `tiisg`.
    if (sgitg == 0) {
        const float local_max  = max_scores[tiisg];
        const float global_max = simd_max(local_max);

        // A stripe that walked no KV position still has max == -INF. If
        // every stripe is empty, global_max is -INF as well and
        // exp(-INF - (-INF)) would be NaN, which would then be written to O.
        // Such a stripe contributes nothing, so give it a factor of 0.
        const float local_factor = isfinite(local_max)
            ? exp(local_max - global_max)
            : 0.0f;
        const float global_sum = simd_sum(sum_exp_scores[tiisg] * local_factor);

        // Reuse the slots: per-simdgroup rescale factor and the total sum
        // (every lane stores the same total).
        max_scores[tiisg]     = local_factor;
        sum_exp_scores[tiisg] = global_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Final write. Exactly one thread writes each output element: the 32
    // lanes of simdgroup 0 cover all D columns (SDPA_BD * SDPA_EPT = 128),
    // each summing its columns over the 32 rescaled partial rows.
    if (sgitg == 0) {
        const float total_sum = sum_exp_scores[0];
        const float inv_S = (total_sum > 0.0f) ? (1.0f / total_sum) : 0.0f;

        for (int j = 0; j < SDPA_EPT; ++j) {
            const int col = tiisg * SDPA_EPT + j;   // 0..127
            float total = 0.0f;
            for (int s = 0; s < SDPA_BN; ++s) {
                total += outputs[s * SDPA_D + col] * max_scores[s];
            }
            o_row[col] = total * inv_S;
        }
    }
}

// ── Paged-KV variant of flash_attn_decode_f32 ────────────────────────
//
// Same online-softmax + cross-simdgroup combine as the contiguous-KV
// variant above; the only change is HOW each simdgroup addresses K/V.
//
// Memory model (vLLM-style, simplified for f32-only):
//   k_cache, v_cache : [num_blocks, num_kv_heads, BLOCK_SIZE, head_dim]
//                      one shared pool; both cache buffers share the
//                      block_size grid.
//   block_tables     : [max_num_seqs, max_num_blocks_per_seq] u32
//                      block_tables[seq][i] = physical block index for
//                      the i-th logical block of sequence `seq`.
//   context_lens     : [max_num_seqs] u32 — true sequence length;
//                      simdgroups skip past it.
//
// The kernel binds a single parameter struct, PagedAttnParams, which
// carries num_heads / num_kv_heads / head_dim / scale together with
// block_size / max_num_blocks_per_seq / kv_block_stride /
// kv_head_stride and the q_len / head-stride fields. FlashAttnParams
// is not bound by this kernel.
//
// `bi` (= tgpig.z) indexes into context_lens / block_tables directly
// — one TG per (query token, head, sequence); for q_len = 1 that is one
// TG per (head, sequence), just like the contiguous variant.
// pos_offset / kv_len are NOT used here; the
// per-sequence context_len comes from the buffer.
//
// Why a separate kernel rather than runtime branching: the inner KV
// loop is the hot path (32 SGs × N positions). Apple's Metal compiler
// generates measurably tighter code when the block-table indirection
// is statically present rather than gated by a runtime flag. Cost:
// ~80 extra MSL lines, no extra dispatch overhead — both variants can
// live behind one `B::flash_attention` API.

// Parameter block for flash_attn_decode_paged_f32 (bound at buffer(6)).
// NOTE(review): the host-side struct presumably has to match this field
// order and the 4-byte field sizes exactly — confirm against the caller
// before adding, removing or reordering fields.
struct PagedAttnParams {
    int num_heads;               // number of query heads
    int num_kv_heads;            // number of K/V heads; query head hi maps to hi / (num_heads / num_kv_heads)
    int head_dim;                // floats per (position, head) row
    float scale;                 // multiplied into Q before the dot product
    int block_size;              // KV positions per physical block (16 typical)
    int max_num_blocks_per_seq;  // block_tables row stride
    int kv_block_stride;         // floats between consecutive blocks (= num_kv_heads * block_size * head_dim)
    int kv_head_stride;          // floats between consecutive kv heads within a block (= block_size * head_dim)
    // q_len > 1 support (causal prefill into a paged cache):
    int q_len;                   // 1 for decode, >1 for prefill batch
    int q_head_stride;           // floats between consecutive q heads
                                 //   q_len=1 token-major: head_dim
                                 //   q_len>1 head-major:  q_len * head_dim
    int o_head_stride;           // same shape as q_head_stride for the O buffer
};

// SDPA vector attention over a paged KV cache (f32): one threadgroup per
// (query token, head, sequence).
//
// Dispatch contract:
//   • SDPA_BN (32) simdgroups of SDPA_BD (32) threads per threadgroup. The
//     combine stage reads one (max, sum) slot per simdgroup, so all 32
//     slots must have been written by a live simdgroup.
//   • The per-thread slices assume head_dim == SDPA_D (128).
//   • block_tables rows must cover every logical block below attend_end;
//     the kernel does not bounds-check the table.
//
// A query token whose attended range is empty (context_len == 0, or
// context_len < q_len - qi) gets an all-zero output row.
kernel void flash_attn_decode_paged_f32(
    device const float*    Q              [[buffer(0)]],   // q_len=1: [num_seqs, num_heads, head_dim]
                                                            // q_len>1 single seq head-major: [num_heads, q_len, head_dim]
    device const float*    K_cache        [[buffer(1)]],   // [num_blocks, num_kv_heads, block_size, head_dim]
    device const float*    V_cache        [[buffer(2)]],   // same layout as K_cache
    device       float*    O              [[buffer(3)]],   // matches Q layout
    device const uint32_t* block_tables   [[buffer(4)]],
    device const uint32_t* context_lens   [[buffer(5)]],   // FINAL kv_len after this batch's writes
    constant PagedAttnParams& p           [[buffer(6)]],
    uint3  tgpig [[threadgroup_position_in_grid]],   // (q_token_idx, head, seq)
    uint   sgitg [[simdgroup_index_in_threadgroup]], // 0..31 — which KV-stripe
    uint   tiisg [[thread_index_in_simdgroup]])      // 0..31 — which D slice
{
    const int qi    = int(tgpig.x);    // 0..q_len-1
    const int hi    = int(tgpig.y);
    const int bi    = int(tgpig.z);
    const int kv_hi = hi / (p.num_heads / p.num_kv_heads);
    const int d     = p.head_dim;
    const int bs    = p.block_size;

    // Causal: token at q_token_idx=qi sees KV positions
    //   [0, context_len - (q_len - 1 - qi))
    // For q_len=1 this collapses to attend_end = context_len.
    const int context_len = int(context_lens[bi]);
    const int attend_end  = context_len - (p.q_len - 1 - qi);

    // Pointers — Q/O strides honour `q_head_stride` / `o_head_stride`
    // so callers can pick token-major (q_len=1) or head-major (q_len>1)
    // layouts without repacking. `bi` indexes into a per-seq slab of
    // num_heads * head_stride floats.
    device const float* q_row = Q + bi * p.num_heads * p.q_head_stride
                                  + hi * p.q_head_stride
                                  + qi * d
                                  + tiisg * SDPA_EPT;
    device       float* o_row = O + bi * p.num_heads * p.o_head_stride
                                  + hi * p.o_head_stride
                                  + qi * d;
    device const uint32_t* my_block_table = block_tables + bi * p.max_num_blocks_per_seq;

    // Per-thread Q slice (pre-scaled) and running output slice.
    float q[SDPA_EPT];
    float o_acc[SDPA_EPT];
    for (int i = 0; i < SDPA_EPT; ++i) {
        q[i]     = p.scale * q_row[i];
        o_acc[i] = 0.0f;
    }

    float max_score = -INFINITY;
    float sum_exp   = 0.0f;

    // KV loop — each simdgroup walks KV positions { sgitg, sgitg+BN, ... }.
    // For each position: resolve logical → physical block via block_tables.
    // `attend_end` (not `context_len`) bounds the loop so that q_token i
    // in a q_len > 1 prefill only sees positions ≤ i (causal).
    for (int ki = int(sgitg); ki < attend_end; ki += SDPA_BN) {
        const int logical_block = ki / bs;
        const int slot_in_block = ki % bs;
        const uint32_t physical_block = my_block_table[logical_block];

        // Offset of this position's K/V slice. The cache layout is
        //   cache[physical_block][kv_hi][slot_in_block][d]
        // The block term is computed in 64 bits: physical_block *
        // kv_block_stride can exceed the 32-bit int range for a large
        // shared pool. The within-block remainder is smaller than one
        // block stride and stays in int.
        const int  in_block_off = kv_hi * p.kv_head_stride
                                + slot_in_block * d
                                + int(tiisg) * SDPA_EPT;
        const long slice_off = long(physical_block) * long(p.kv_block_stride)
                             + long(in_block_off);
        device const float* k_row = K_cache + slice_off;
        device const float* v_row = V_cache + slice_off;

        // Same dot + online softmax body as flash_attn_decode_f32.
        float dot_acc = 0.0f;
        for (int j = 0; j < SDPA_EPT; ++j) {
            dot_acc += q[j] * k_row[j];
        }
        const float score = simd_sum(dot_acc);

        const float new_max = max(max_score, score);
        const float factor  = exp(max_score - new_max);
        const float exp_sc  = exp(score      - new_max);

        max_score = new_max;
        sum_exp   = sum_exp * factor + exp_sc;

        for (int j = 0; j < SDPA_EPT; ++j) {
            o_acc[j] = o_acc[j] * factor + exp_sc * v_row[j];
        }
    }

    // Cross-simdgroup combine — identical structure to the contiguous
    // variant, kept inline rather than factored into a helper.
    threadgroup float outputs[SDPA_BN * SDPA_D];
    threadgroup float max_scores[SDPA_BN];
    threadgroup float sum_exp_scores[SDPA_BN];

    // Stash this simdgroup's partial output row and its (max, sum_exp).
    threadgroup float* my_out = outputs + sgitg * SDPA_D + tiisg * SDPA_EPT;
    for (int j = 0; j < SDPA_EPT; ++j) {
        my_out[j] = o_acc[j];
    }
    if (tiisg == 0) {
        max_scores[sgitg]     = max_score;
        sum_exp_scores[sgitg] = sum_exp;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Simdgroup 0 reduces the 32 (max, sum_exp) pairs: lane `tiisg` handles
    // the pair published by simdgroup `tiisg`.
    if (sgitg == 0) {
        const float local_max  = max_scores[tiisg];
        const float global_max = simd_max(local_max);

        // A stripe that walked no KV position still has max == -INF. If
        // every stripe is empty (e.g. context_len == 0 for a padded
        // sequence slot), global_max is -INF as well and
        // exp(-INF - (-INF)) would be NaN, which would then be written to O.
        // Such a stripe contributes nothing, so give it a factor of 0.
        const float local_factor = isfinite(local_max)
            ? exp(local_max - global_max)
            : 0.0f;
        const float global_sum = simd_sum(sum_exp_scores[tiisg] * local_factor);

        // Reuse the slots: per-simdgroup rescale factor and the total sum.
        max_scores[tiisg]     = local_factor;
        sum_exp_scores[tiisg] = global_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Final write by simdgroup 0 only: its 32 lanes × SDPA_EPT columns cover
    // the whole output row, each summing over the 32 rescaled partials.
    if (sgitg == 0) {
        const float gs = sum_exp_scores[0];
        const float inv_S = (gs > 0.0f) ? (1.0f / gs) : 0.0f;
        for (int j = 0; j < SDPA_EPT; ++j) {
            const int col = tiisg * SDPA_EPT + j;
            float total = 0.0f;
            for (int s = 0; s < SDPA_BN; ++s) {
                total += outputs[s * SDPA_D + col] * max_scores[s];
            }
            o_row[col] = total * inv_S;
        }
    }
}