mlx-native 0.6.2

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
// flash_attn_prefill_d512 — mlx-native NSG=8 flash-attention prefill kernel
// ported from llama.cpp's `kernel_flash_attn_ext_impl` template.
//
// ## What this file is
//
// A SECOND, independent Metal kernel implementing scaled-dot-product attention
// prefill for head_dim=512 (DK=DV=512) using llama.cpp's per-simdgroup-Q-
// distributed decomposition.  It is the companion to, NOT a replacement for,
// `flash_attn_prefill.metal` — the candle-derived D=256 kernel there is
// retained as-is (correct, fast, well-tested).
//
// ## Why a separate file (vs extending `flash_attn_prefill.metal`)
//
// The candle-derived template in `flash_attn_prefill.metal` uses a
// per-warp-Q-stacking layout where each simdgroup owns `BQ / kNWarps`
// Q rows of work.  The `BQ >= kNWarps * kFragSize` static_assert
// (`flash_attn_prefill.metal:1250-1251`) forces a minimum BQ of 64 at NSG=8
// (`kNWarps=WM*WN=8`), which at BD=512 bf16 needs a Q threadgroup tile of
// `64 * 520 * 2 = 66,560 B` — over 2× the 32 KiB Apple Silicon threadgroup-
// memory budget, so that template cannot host NSG=8 at D=512.
//
// llama.cpp gets NSG=8 at D=512 within 28 KiB by using a FUNDAMENTALLY
// DIFFERENT per-simdgroup decomposition:
//
//   - Q rows are DISTRIBUTED across simdgroups (NQ = Q/NSG rows per simdgroup),
//     not stacked.  The Q tile lives ONCE in threadgroup memory (`sq[Q × DK]`)
//     and each simdgroup walks its owned Q rows by indexing `j = jj*NSG + sgitg`.
//   - The output `O` accumulator lives in THREADGROUP MEMORY (`so[Q × PV]`;
//     `half` in llama.cpp's bf16 variant, widened to f32 in this port — see
//     the kernel body), not in registers.  Each KV chunk pays `load + store`
//     of the per-simdgroup Otile frags (NO = PV8/NSG = 8 frags at NSG=8) to
//     trade shmem traffic for far lower register pressure (16 vs 128
//     per-thread registers), which drives occupancy up by 8× on M5 Max.
//   - K/V are loaded directly from device memory via `simdgroup_load(…,
//     NS10, 0, true)` for K and `simdgroup_load(…, NS20, 0, false)` for V,
//     with each simdgroup owning disjoint 8-column slices of the K×V tile.
//
// These three differences are architectural — they cannot be reached from
// the candle template by parameter tuning.  The new kernel is a direct
// port of `kernel_flash_attn_ext_impl` from llama.cpp's
// `ggml-metal.metal:5736-6375`, specialised for:
//
//   - DK = DV = 512 only
//   - bf16/f16 I/O (f32 excluded as in flash_attn_prefill.metal — 32 KiB TG
//     budget is the hard limit)
//   - Unquantized K/V cache (is_q=0 branch only — llama.cpp's quantized
//     branch at `:6066-6127, :6257-6322` is dropped)
//   - NQPSG=8, NCPSG=64 hard-coded (as llama.cpp does at the host thunk
//     defaults `ggml-metal.metal:6403-6404`)
//   - NSG selectable at pipeline-creation time via the int function constant
//     `nsg` (index 322, mirrors llama.cpp's `FC_flash_attn_ext_nsg` at
//     `ggml-metal.metal:5735` = `FC_FLASH_ATTN_EXT + 22`).  Supported
//     values: 4 and 8.  NSG=8 is the required configuration for the 32 KiB
//     budget headroom win (§2.3 in ADR-011-phase2-port-d512.md); NSG=4
//     is provided to mirror llama.cpp's dispatch flexibility.
//   - In-kernel causal masking + additive (bf16/f16) OR bool mask, matching
//     the D=256 kernel's API surface (same function-constant indices for
//     `align_Q, align_K, has_mask, do_causal`).
//
// ## Features excluded vs llama.cpp's kernel (intentional scope reduction)
//
// These branches exist in llama.cpp's kernel and are dropped here.  They are
// all zero-cost-when-not-used (dead-code eliminated via function constants
// at pipeline creation), so retaining them would be free at runtime but
// would add ~200 LOC of complexity.  If future Gemma / DeepSeek-style
// features land, they can be re-added in a follow-up ADR:
//
//   - Quantized K/V cache (is_q=1 path): llama.cpp's Gemma 4 runs bf16 K/V,
//     matching our pipeline.  Port `:6066-6127, :6257-6322` when needed.
//   - Sinks (`FC_flash_attn_ext_has_sinks`): attention-sinks for StreamingLLM.
//     llama.cpp `:5722, :6328-6346`.
//   - ALiBi-style bias (`FC_flash_attn_ext_has_bias`): not used by Gemma 4.
//     llama.cpp `:5723, :5896-5903, :6146-6150`.
//   - Softcap (`FC_flash_attn_ext_has_scap`): Gemma 2 used it; Gemma 4
//     doesn't.  llama.cpp `:5724, :6140-6142`.
//   - KV-pad tail handling (`FC_flash_attn_ext_has_kvpad`): zero-copy padding
//     of the trailing partial KV tile.  We handle kL % NCPSG via per-position
//     -inf mask-out in the last KV chunk instead (matches our D=256 kernel's
//     approach via `align_K=false` + `kL_rem`).  llama.cpp `:5725, :5914-5949`.
//   - Broadcast mask (`FC_flash_attn_ext_bc_mask`): mask broadcast across
//     Q-rows.  Not used by our dispatcher.  llama.cpp `:5727, :5969-5970`.
//   - Per-tile pre-pass skip (`blk`): llama.cpp's `flash_attn_ext_blk`
//     writes a `{0,1,2}` bitmap per KV chunk letting the prefill skip
//     all-masked or pass-through chunks.  Deferred to Phase 4.  We treat
//     every chunk as `blk_cur = 1` (full mask).  llama.cpp `:5775, :5951-6005`.
//
// ## Numerical regime — identical to the D=256 kernel
//
// The row-max `M` is initialised to the FINITE sentinel `-FLT_MAX/2`
// (llama.cpp convention, `ggml-metal.metal:5891`).  Masked scores arrive
// as `-inf` from the additive mask buffer; `simd_max(-FLT_MAX/2, -inf)`
// floors at `-FLT_MAX/2` so `M` stays finite.  Every `exp(score - M)`
// evaluates as `exp(-inf) = +0.0` (IEEE-754 exact), never NaN.  The ONE
// output-side guard is at the final `output / sum_score` store:
//
//     S == 0 ? 0 : 1/S
//
// mirroring `ggml-metal.metal:6358`.  Same regime as `flash_attn_prefill.metal`;
// see that file's preamble for the full rationale and
// ADR-011-phase2-port-sentinel.md §1-3 for the line-by-line trace.
//
// ## Exponential base
//
// llama.cpp uses natural-base `exp` throughout.  Our candle-derived D=256
// kernel instead uses `fast::exp2` with Q pre-scaled by `scale * log2(e)`.
// This file FOLLOWS llama.cpp: Q is stored unscaled, `args.scale` is
// applied in f32 after the Q·K^T matmul, and the online softmax uses
// natural `exp`.  Pre-scaling Q on load would add one extra bf16 round per
// Q element and break bit-parity with llama.cpp's output (see the
// Q-tile-load and softmax comments in the kernel body).  The mathematical
// result is identical either way; only the rounding behaviour differs.
//
// ## Shared-memory layout (dynamic, sized by host)
//
// The dispatcher passes `threadgroup half*` with size `smem_bytes` as
// set via `setThreadgroupMemoryLength` on the encoder:
//
//   Offset      0 B : sq[Q × DK]  — query tile, as I/O-dtype T (bf16/f16)
//   Offset  8,192 B : so[Q × PV]  — output accumulator in f32 (widened from
//                     llama.cpp's half; see the threadgroup-region comment
//                     in the kernel body)
//   Offset 24,576 B : ss[Q × SH]  — softmax/mask scratch in f32 (SH = 2*C)
//   (llama.cpp's per-simdgroup sk/sv dequant scratch is is_q=1 only and is
//   dropped in this port; its reservation is re-used by the f32 `so`.)
//   Total: 8×512×2 + 8×512×4 + 8×128×4 = 8,192 + 16,384 + 4,096 = 28,672 B
//   — exactly the dispatcher's 28,672 B budget, matching llama.cpp
//   (§2.3 in ADR-011-phase2-port-d512.md).
//
// References (kernel body is a direct port; see inline citations):
//   - /opt/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal:5736-6375
//     (kernel_flash_attn_ext_impl — single template, all DKs, all NSGs)
//   - /opt/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal:6421-6427
//     (dispatch thunk; cases 4/8 → kernel_flash_attn_ext_impl<…, NSG>)
//   - /opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h:77, 93-94
//     (FC offsets, NQPSG=8, NCPSG=64 defines)
//   - /opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp:2750-2900
//     (host dispatch: ne00>=512 → nsg=8; FATTN_SMEM size formula; grid)
//   - ADR-011-phase2-port-d512.md  (port spec; this file follows §7 checklist)
//
// SPDX-License-Identifier: MIT

#include <metal_stdlib>
#include <metal_simdgroup>

using namespace metal;

// `FOR_UNROLL` mirrors llama.cpp's macro at ggml-metal.metal:26.
#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)
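
// Compile-time mirror of the shared-memory arithmetic quoted in the
// preamble (redundant with the in-kernel layout; kept at file scope so a
// geometry change that breaks the 28,672 B total fails compilation rather
// than the dispatcher's setThreadgroupMemoryLength call).  Q=8, DK=PV=512,
// SH=128 as fixed in this kernel.
static_assert(8 * 512 * sizeof(half)        // sq — query tile, dtype T
            + 8 * 512 * sizeof(float)       // so — f32 output accumulator
            + 8 * 128 * sizeof(float)       // ss — f32 softmax/mask scratch
            == 28672,
              "D=512 threadgroup layout must total 28,672 B");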

// ──────────────────────────────────────────────────────────────────────────
// Common structs (same ABI as flash_attn_prefill.metal)
// ──────────────────────────────────────────────────────────────────────────
//
// Host mirror: /opt/mlx-native/src/ops/flash_attn_prefill.rs AttnParamsGpu.
// Field order / types / total 160 B byte-for-byte identical to the
// definition in flash_attn_prefill.metal:1025-1051.
struct AttnParams {
  int B;
  int H;
  int D;

  int qL;
  int kL;

  int gqa_factor;
  float scale;
  float softcapping;

  int NQ;
  int NK;

  int NQ_aligned;
  int NK_aligned;

  int qL_rem;
  int kL_rem;
  int qL_off;

  int64_t Q_strides[3];
  int64_t K_strides[3];
  int64_t V_strides[3];
  int64_t O_strides[3];
};
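
// Guard the 160 B ABI claim above at compile time (the Rust-side
// AttnParamsGpu mirror remains the source of truth; this assert just
// catches accidental drift in this file).
static_assert(sizeof(AttnParams) == 160, "AttnParams ABI must stay 160 B");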

struct AttnMaskParams {
  int64_t M_strides[3];
};

// ──────────────────────────────────────────────────────────────────────────
// Function constants
// ──────────────────────────────────────────────────────────────────────────
//
// Indices 200/201/300/301 MATCH flash_attn_prefill.metal so dispatcher
// helpers are uniform across D=256 and D=512.
//
// Index 322 is the int function constant controlling NSG — mirrors
// llama.cpp's `FC_flash_attn_ext_nsg` (ggml-metal.metal:5735; value is
// `FC_FLASH_ATTN_EXT + 22` = 300 + 22 = 322 in their numbering scheme,
// adopted here to keep semantic parity).  Supported values: 4, 8.

constant bool align_Q  [[function_constant(200)]];
constant bool align_K  [[function_constant(201)]];

constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];

// Wave 2E tile-skip pre-pass.  See flash_attn_prefill.metal's `has_blk`
// declaration for the full contract; same index 303 across both D=256
// and D=512 so the dispatcher can set one bool and feed either kernel.
constant bool has_blk  [[function_constant(303)]];

constant int  fc_nsg   [[function_constant(322)]];

// Provide sensible defaults when a function constant isn't set (avoids
// undefined-behaviour during preview compiles).  Actual values are always
// supplied by the dispatcher at pipeline-creation time.
constant bool align_Q_def  = is_function_constant_defined(align_Q)  ? align_Q  : true;
constant bool align_K_def  = is_function_constant_defined(align_K)  ? align_K  : true;
constant bool has_mask_def = is_function_constant_defined(has_mask) ? has_mask : false;
constant bool do_causal_def = is_function_constant_defined(do_causal) ? do_causal : false;
constant bool has_blk_def  = is_function_constant_defined(has_blk)  ? has_blk  : false;
constant int  nsg_def      = is_function_constant_defined(fc_nsg)   ? fc_nsg   : 8;
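
// For reference (illustrative, not normative — the Rust dispatcher owns
// this): a Metal host selects NSG by writing the int constant at index 322
// via MTLFunctionConstantValues (`setConstantValue:type:atIndex:` with
// MTLDataTypeInt) before `newFunctionWithName:constantValues:error:`, and
// sizes the encoder's `setThreadgroupMemoryLength:atIndex:` to the
// 28,672 B budget above.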

// ──────────────────────────────────────────────────────────────────────────
// Kernel template
// ──────────────────────────────────────────────────────────────────────────
//
// Faithful port of `kernel_flash_attn_ext_impl` from
// /opt/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal:5736-6375 for the
// (DK=DV=512, bf16/f16 K/V, is_q=0) slice, with NSG as a template parameter
// (compile-time, specialised from the int function constant via the outer
// wrapper below).
//
// T    — I/O scalar type: bfloat or half.
// MaskT — Mask scalar type: same as T (additive) or `bool` (is_attended).
// NSG  — Simdgroups per threadgroup: 4 or 8.
//
// See llama.cpp kernel_flash_attn_ext_impl template parameters at
// ggml-metal.metal:5738-5766 for the full type-plumbing model; we collapse
// the 24 type/arity parameters there to {T, MaskT, NSG} because we fix
// DK=DV=512, q_t=k_t=v_t=kd_t=vd_t=T, q8x8_t=k8x8_t=v8x8_t=simdgroup_T8x8,
// qk_t=s_t=float, o_t=T (to match our D=256 kernel's output precision
// contract — bf16/f16 final store, f32 internal accum).
// NOTE: `flash_attn_prefill_d512_impl` is a HELPER function, NOT a Metal
// kernel.  It does not carry `[[kernel]]` — a kernel cannot be called from
// another function in MSL, but a plain function with threadgroup/grid/
// lane-id arguments forwarded from the caller is legal.  This matches
// llama.cpp's `kernel_flash_attn_ext_impl` at ggml-metal.metal:5767
// (declared without `[[kernel]]`; called from the outer `kernel void
// kernel_flash_attn_ext` thunk at :6405).
template <typename T, typename MaskT, short NSG>
void flash_attn_prefill_d512_impl(
    const device char* q_base,
    const device char* k_base,
    const device char* v_base,
    device       char* o_base,
    const constant AttnParams& args,
    const constant AttnMaskParams& mask_args,
    const device MaskT* mask_base,
    const device char*  blk_base,
    threadgroup  half* shmem_f16,
    uint3  tgpig,
    ushort tiisg,
    ushort sgitg) {

  // ── Tile/workload shape constants ──────────────────────────────────────
  //
  // Mirrors ggml-metal.metal:5795-5814.
  //   Q   — queries per threadgroup (NQPSG),     = 8   (llama.cpp-impl.h:93)
  //   C   — cache items per threadgroup (NCPSG), = 64  (llama.cpp-impl.h:94)
  //   DK  — K head size,                         = 512 (fixed in this kernel)
  //   DV  — V head size,                         = 512 (fixed in this kernel)
  //   NQ  — Q rows per simdgroup,                = Q / NSG
  //   NC  — Q·K^T 8×8 frags per simdgroup,       = (C/8) / NSG
  //   NO  — output 8×8 frags per simdgroup,      = (DV/8) / NSG
  //   SH  — softmax scratch halves per Q row,    = 2 × C
  //   NW  — simd width,                          = 32
  constexpr short Q  = 8;
  constexpr short C  = 64;
  constexpr short DK = 512;
  constexpr short DV = 512;

  constexpr short DK4  = DK / 4;
  constexpr short DK8  = DK / 8;
  // PV matches llama.cpp's PAD2(DV, 64) (ggml-metal.metal:5804); for DV=512
  // PV=512 so there is no actual padding.
  constexpr short PV   = DV;
  constexpr short PV4  = PV / 4;
  constexpr short PV8  = PV / 8;

  constexpr short NW  = 32;
  constexpr short NQ  = Q / NSG;                 // Q rows per simdgroup
  constexpr short SH  = 2 * C;                   // softmax+mask scratch
  // `T_STRIDE` is the llama.cpp name `T` (ggml-metal.metal:5814).  We rename
  // it to avoid shadowing the template type parameter `T`.  Halves per Q row
  // across the sq+so regions: `DK + 2*PV` (sq takes DK halves per row; so is
  // f32, so it takes 2*PV halves per row).  The ss region sits after so, at
  // half-offset Q*T_STRIDE, and takes 2*SH halves per row (SH f32 values).
  constexpr short T_STRIDE = DK + 2 * PV;

  // Layout assertions (copied from llama.cpp ggml-metal.metal:6018, :6182).
  static_assert((C / 8) % NSG == 0, "NSG must divide C/8");
  static_assert(PV8     % NSG == 0, "NSG must divide PV8");
  static_assert(Q       % NSG == 0, "NSG must divide Q");
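
  // Resulting per-simdgroup workload at the two supported NSG values
  // (derived from the constants above; listed for orientation):
  //   NSG=8: NQ=1, NC=(C/8)/NSG=1, NO=PV8/NSG=8  — one Q row, one QK frag,
  //          eight O frags per simdgroup.
  //   NSG=4: NQ=2, NC=2,           NO=16.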

  // ── Threadgroup memory regions ─────────────────────────────────────────
  //
  // Layout mirrors ggml-metal.metal:5816-5828 STRUCTURALLY, but the `so`
  // accumulator here is stored in f32 (NOT half as llama.cpp's bf16-I/O
  // specialisation at FA_TYPES_BF).  Rationale: at runtime llama.cpp's
  // default KV-cache dtype is f16 (common/common.h:344), which dispatches
  // its *f16* FA_TYPES instantiation — and FA_TYPES uses `float` for o_t
  // (see ggml-metal.metal:6441, last line of FA_TYPES macro).  FA_TYPES_BF
  // is the bf16 I/O variant which llama.cpp defines for bf16 K/V caches
  // but its default cache types route Gemma 4 inference through the F16
  // kernel at runtime.  Matching llama.cpp's *actual inference behaviour*
  // (f32 O accumulator) — not its literal bf16 template instantiation —
  // gives byte-identical output; the half-O variant loses ~10 bits of
  // accumulator precision per KV chunk which compounds across Gemma 4's
  // 5 global layers (observed pre-fix: 1026-byte common prefix with
  // half O; post-fix: full 3094+ with f32 O).
  //
  // For (DK=DV=512, Q=8, C=64, NSG=8, bf16 I/O, f32 so, f32 ss):
  //   sq offset  0 bytes , size 8×512×2 =  8192 bytes
  //   so offset  8192     , size 8×512×4 = 16384 bytes (f32)
  //   ss offset  24576    , size 8×128×4 =  4096 bytes (f32)
  //   Total:    28672 B
  //   Exactly equals the dispatcher's 28672 B budget — we've re-used the
  //   is_q=1 dequant scratch reservation (which llama.cpp paid for but
  //   we don't need since is_q=0) for the widened O accumulator.
  //
  // Shared-memory offsets in halves (base ptr is `shmem_f16 + half index`):
  //   sq: [0, Q*DK)               = [0, 4096)
  //   so: [Q*DK, Q*DK + Q*PV*2)   = [4096, 12288)   — so floats live at
  //                                  half-index 4096..12287 as pairs
  //   ss: [Q*T_STRIDE, ...)       = [12288, 14336)  — same as before
  threadgroup T*     sq  = (threadgroup T*)     (shmem_f16 + 0 * T_STRIDE);
  threadgroup float* so  = (threadgroup float*) (shmem_f16 + 0 * T_STRIDE + Q * DK);
  // Note: ss is float-typed (s_t == float in llama.cpp FA_TYPES_BF).  Row j
  // starts at `Q*T_STRIDE halves + 2*j*SH halves` = `Q*T_STRIDE halves +
  // j*(2*SH) halves`; in float units (ss's native) row j starts at
  // `j*SH floats`.  Row stride in halves = 2*SH = 256 (for SH=128).
  //
  // The ss region is laid out as Q contiguous rows of SH floats each
  // (= 2*SH halves each).  llama.cpp's `sm2` is a half2 alias over the
  // SECOND HALF of each ss row (scores live in cols [0, C) floats; mask
  // lives in cols [C, 2*C) floats of the same row; C floats = 2*C halves
  // = C half2s).  Rather than cast `sm2` with a `+2*C halves` base offset
  // (which would require its own cast-then-index), we hold `sm2` at ss-base
  // and add `C` to the index inside the row — arithmetically identical to
  // llama.cpp ggml-metal.metal:5830 `sm2 = shmem_f16 + Q*T + 2*C` with
  // `sm2[j*SH + tiisg]` re-expressed as `sm2_at_ss_base[j*SH + C + tiisg]`.
  threadgroup float*  ss  = (threadgroup float*) (shmem_f16 + Q * T_STRIDE);
  threadgroup float2* ss2 = (threadgroup float2*)(shmem_f16 + Q * T_STRIDE);
  threadgroup half2*  sm2 = (threadgroup half2*) (shmem_f16 + Q * T_STRIDE);

  // ── Per-threadgroup (batch, head, qL-tile) indices ────────────────────
  //
  // llama.cpp uses tgpig.x for qL-tile, .y for head, .z for batch
  // (ggml-metal.metal:5781-5783).  Our host grid has the same order:
  // grid = (NQ_tiles, H, B).
  const ushort iq3 = tgpig.z;                    // batch
  const ushort iq2 = tgpig.y;                    // head
  const ushort iq1 = tgpig.x * Q;                // first Q row index (this tile)

  // ── Base pointer offsetting (applies {batch, head, qL-tile} stride) ──
  //
  // Our Rust dispatcher lays out Q / K / V / O as [B, H, L, D] contiguous
  // with element strides (B, H, L, D=1).  llama.cpp uses byte strides
  // (args.nb01, .nb02, .nb03 etc.); we substitute our i64 element strides.
  // See ggml-metal.metal:5849-5856.

  const device T* q_typed = (const device T*)(q_base);
  const device T* k_typed = (const device T*)(k_base);
  const device T* v_typed = (const device T*)(v_base);
  device       T* o_typed = (device       T*)(o_base);

  // Q / O are per-query-head (H heads).  K / V are per-KV-head (H / gqa).
  const ulong kv_head = iq2 / args.gqa_factor;

  const device T* q_head = q_typed + (ulong)iq3 * (ulong)args.Q_strides[0]
                                    + (ulong)iq2 * (ulong)args.Q_strides[1];
  const device T* k_head = k_typed + (ulong)iq3 * (ulong)args.K_strides[0]
                                    +       kv_head * (ulong)args.K_strides[1];
  const device T* v_head = v_typed + (ulong)iq3 * (ulong)args.V_strides[0]
                                    +       kv_head * (ulong)args.V_strides[1];
  device       T* o_head = o_typed + (ulong)iq3 * (ulong)args.O_strides[0]
                                    + (ulong)iq2 * (ulong)args.O_strides[1];

  // K / V element-stride between consecutive KV items.  Layout is contiguous
  // in D, stride-D between items, stride-(kL*D) between heads — matching
  // llama.cpp's `args.nb11 / sizeof(k_t)` == NS10 at runtime (we pass this
  // via args.K_strides[2]).  Similarly for V (args.V_strides[2] = NS20).
  const int NS10 = int(args.K_strides[2]);
  const int NS20 = int(args.V_strides[2]);

  // ── Per-query mask pointers ────────────────────────────────────────────
  //
  // llama.cpp loads the additive mask for each Q row once per KV chunk
  // (ggml-metal.metal:5833-5839).  Our mask layout is [B, H, qL, kL] and
  // mask_args.M_strides = (batch, head, qL-row) with inner dim = 1
  // (kL stride is 1 — element-contiguous in kL).
  //
  // We hold one device pointer per jj index into the [NQ] loop.  For our
  // additive/bool mask variants, MaskT differs (bf16/f16 vs bool); both
  // read row (iq1 + j) at columns ic + {2*tiisg, 2*tiisg + 1}.  llama.cpp's
  // `pm2[jj] += NW` pattern (NW half2s = 64 mask elements) advances one
  // C-wide slab per KV chunk; we index the slab via `ic` instead of
  // advancing the pointer — the same simdgroup-wide mask-read slab.
  //
  // Unused when has_mask=false — dead-code-eliminated.
  device const MaskT* pm[NQ];
  if (has_mask_def) {
    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
      const short j = jj * NSG + sgitg;
      // mask_args.M_strides: (batch, head, qL-row).  Additive mask uses
      // element indices; bool mask is byte-sized but kL-contiguous so the
      // same offset formula works.
      pm[jj] = mask_base + (ulong)iq3 * (ulong)mask_args.M_strides[0]
                         + (ulong)iq2 * (ulong)mask_args.M_strides[1]
                         + (ulong)(iq1 + j) * (ulong)mask_args.M_strides[2];
    }
  }

  // ── Load Q tile and zero O / SS ────────────────────────────────────────
  //
  // Direct port of ggml-metal.metal:5858-5884.  Each simdgroup loads its
  // NQ owned rows of the Q tile from device memory into `sq`, then zeros
  // its owned slots in `so` and `ss`.
  //
  // Q is stored UNSCALED, mirroring llama.cpp's ggml-metal.metal:5862-5870
  // (`sq4[…] = (q4_t) q4[i]` — pure type cast, no scale).  Scale is applied
  // AFTER the Q·K^T matmul, inside the online softmax step
  // (`float2 s2 = ss2[…] * args.scale` at :6138).  Keeping scale out of the
  // Q-tile preserves llama.cpp's bf16 rounding behaviour bit-for-bit:
  // pre-scaling Q by `scale * log2(e)` on load would round `Q*α` to bf16
  // once per element, introducing systematic per-element bias that
  // accumulates across the 512-wide dot product into measurable drift on
  // Gemma 4 global-layer outputs (observed: byte-1026 divergence from
  // llama.cpp on sourdough_gate).  Doing the multiply post-matmul keeps
  // the bf16 round on Q identical to llama.cpp's.
  //
  // sq uses element layout [Q][DK]; each row is DK elements of dtype T.
  FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
    const short j = jj * NSG + sgitg;
    const bool j_in_range = (int)(iq1 + j) < args.qL;

    const device T* q_row =
        q_head + (ulong)(iq1 + j) * (ulong)args.Q_strides[2];

    // Each lane loads ceil(DK / NW) elements; DK=512 / 32 = 16 per lane.
    for (short i = tiisg; i < DK; i += NW) {
      T val = j_in_range ? q_row[i] : T(0);
      sq[j * DK + i] = val;
    }
  }

  FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
    const short j = jj * NSG + sgitg;

    // Zero per-row O.  so is f32 here (see threadgroup-region comment above
    // for why we widen the llama.cpp FA_TYPES_BF half accumulator to f32).
    for (short i = tiisg; i < DV; i += NW) {
      so[j * PV + i] = 0.0f;
    }

    // Zero per-row SS (softmax / mask scratch).
    for (short i = tiisg; i < SH; i += NW) {
      ss[j * SH + i] = 0.0f;
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // ── Per-simdgroup softmax state ────────────────────────────────────────
  //
  // Each simdgroup owns NQ scalar (M, S) pairs, one per Q row assigned to
  // it.  M initialised to the llama.cpp finite sentinel -FLT_MAX/2
  // (ggml-metal.metal:5891); S initialised to 0 (:5888).
  float S[NQ];
  float M[NQ];
  FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
    S[jj] = 0.0f;
    M[jj] = -FLT_MAX / 2.0f;
  }

  // ── Resolve KV-chunk upper bound (causal truncation) ───────────────────
  //
  // llama.cpp's in-kernel causal mask runs on a per-element basis inside
  // the softmax step (:6132-6151 via the mask logic).  Our candle-derived
  // D=256 kernel hoists the causal cut-off to the outer loop via `kb_lim`
  // (flash_attn_prefill.metal:1313-1316) to skip whole tiles after the
  // diagonal.  We do the same here for consistency with our D=256 behaviour
  // and dispatcher contract (AttnParams::qL_off).
  const int kL_chunks_total = (args.kL + C - 1) / C;  // total KV chunks (ceil)
  int kL_chunks = kL_chunks_total;
  if (do_causal_def) {
    // Each Q row `iq1 + j` (j in [0, Q)) can attend to K positions up to
    // `args.qL_off + iq1 + j` inclusive.  For the threadgroup, the max
    // Q row is `iq1 + Q - 1`, so the kL-limit in elements is
    // `args.qL_off + iq1 + Q`.  Divide ceil by C for chunks.
    const int q_max = (int)iq1 + Q + args.qL_off;
    kL_chunks = (q_max + C - 1) / C;
    if (kL_chunks < 0) kL_chunks = 0;
  }

  // Clamp to the real chunk count — the causal bound can overshoot by one
  // chunk when iq1 + Q > qL (qL not a multiple of Q).
  if (kL_chunks > kL_chunks_total) {
    kL_chunks = kL_chunks_total;
  }
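
  // Worked example (illustrative numbers, not a contract): qL = 9, kL = 72,
  // qL_off = kL - qL = 63, C = 64.  Tile 0 (iq1 = 0): q_max = 0 + 8 + 63
  // = 71 → kL_chunks = ceil(71/64) = 2 (row j=7 attends through K position
  // 70, which lives in chunk 1).  Tile 1 (iq1 = 8): q_max = 79 →
  // ceil(79/64) = 2, within kL_chunks_total = ceil(72/64) = 2, so no clamp.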

  // ── Wave 2E tile-skip pre-pass row base (D=512) ────────────────────────
  //
  // The blk buffer shape is [args.NQ, args.NK], where args.NQ = ceil(qL/Q)
  // and args.NK = ceil(kL/C) (distinct from the per-simdgroup NQ above).
  // Layout matches the D=256 kernel: one 2D plane broadcast across batch
  // and heads.  Each threadgroup owns one Q-tile (tgpig.x); the row base
  // is `blk + tgpig.x * NK`.
  //
  // Port of llama.cpp ggml-metal.metal:5841-5846, adapted to our 2D mask
  // + single-plane blk layout.  See /opt/hf2q/docs/ADR-011-phase2-port-tile-skip.md §6.
  const device char* blk_row = nullptr;
  if (has_blk_def) {
    const int NK_blk = (args.kL + C - 1) / C;
    blk_row = blk_base + int(tgpig.x) * NK_blk;
  }

  // ── KV-cache sweep ─────────────────────────────────────────────────────
  //
  // Direct port of the ic0 loop at ggml-metal.metal:5907-6326.  Per chunk:
  //   1. Optionally load the additive mask slab (`sm2`).
  //   2. Compute Q·K^T into `ss` via NSG simdgroups × NC 8×8 frags each.
  //   3. Online softmax: update M, S per Q row; rescale existing `so`.
  //   4. Accumulate O += P·V into the threadgroup f32 `so` via NSG
  //      simdgroups × NO output frags each.
  //
  // Simplifications vs llama.cpp (see file preamble):
  //   - is_q=0 only (direct simdgroup_load from K/V device memory).
  //   - No sinks / bias / softcap / kvpad / bc_mask.
  //   - blk_cur handled via Wave 2E has_blk pre-pass when enabled.
  for (int ic0 = 0; ic0 < kL_chunks; ++ic0) {
    const int ic = ic0 * C;

    // ── Wave 2E tile-skip branch (D=512) ────────────────────────────────
    //
    // Same three-way classification as the D=256 kernel; see that file's
    // comment for the numerics argument.  When has_blk is false blk_cur is
    // forced to 1 and the entire branch is dead-coded by the compiler.
    //
    // The `continue` here skips all simdgroup work: mask load, Q·K^T,
    // softmax update, and the P·V accumulate.  Because every simdgroup
    // in the threadgroup takes the same branch (blk_cur is a single byte
    // read, identical across lanes and simdgroups), the cross-simdgroup
    // threadgroup_barrier at the end of the previous iteration has
    // already synchronised state; no barrier is needed before `continue`.
    char blk_cur = 1;
    if (has_blk_def) {
      blk_cur = blk_row[ic0];
      if (blk_cur == 0) {
        continue;
      }
    }

    // ── Mask slab load (ggml-metal.metal:5954-5981, blk_cur==1 case) ─
    //
    // Each simdgroup loads its NQ owned rows × C mask columns into `sm2`.
    // For our dispatcher: MaskT ∈ {T, bool}.  The additive path reads T
    // (bf16/f16) and writes half-promoted into sm2.  The bool path reads
    // 1 byte per element and converts to 0.0 / -inf (the `(half)(-FLT_MAX/2)`
    // cast below overflows to half -inf, which the finite-M sentinel regime
    // absorbs; see the apply step and the file preamble).
    //
    // Wave 2E: skip the mask load when blk_cur == 2 (all-attended tile —
    // the mask-add is a no-op and the sm2 region is not consumed below
    // because the `has_mask_def && blk_cur != 2` gate in the softmax step
    // matches this elision).  Port of llama.cpp ggml-metal.metal:6145.
    if (has_mask_def && blk_cur != 2) {
      constexpr bool is_bool_mask = is_same_v<MaskT, bool>;

      FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
        const short j = jj * NSG + sgitg;
        // llama.cpp uses 2 halves per lane (one half2) to cover 64 elements
        // per simdgroup (NW=32 × 2 = 64).  For C=64 that's exactly the mask
        // slab width.  We match that pattern: each lane writes one half2
        // into sm2[j*SH + C + tiisg] (the mask sub-region of row j).
        if (!is_bool_mask) {
          // Additive mask in I/O dtype — cast to half2 via two reads.
          // Mask stride for kL dim is 1 (contiguous).
          const int col0 = ic + 2 * (int)tiisg;
          const int col1 = col0 + 1;

          half v0 = (col0 < args.kL) ? (half)(float)(pm[jj][col0])
                                     : (half)(-FLT_MAX / 2.0f);
          half v1 = (col1 < args.kL) ? (half)(float)(pm[jj][col1])
                                     : (half)(-FLT_MAX / 2.0f);
          // Write into row j's mask sub-region: base ss + row-j-offset
          // + (C float2-slots past the score region) + tiisg.
          // See sm2 layout note at declaration.
          sm2[j * SH + C + tiisg] = half2(v0, v1);
        } else {
          // Boolean mask — false → -FLT_MAX/2 (finite sentinel),
          //                true  → 0.0 (additive identity).
          const int col0 = ic + 2 * (int)tiisg;
          const int col1 = col0 + 1;

          bool b0 = (col0 < args.kL) ? bool(pm[jj][col0]) : false;
          bool b1 = (col1 < args.kL) ? bool(pm[jj][col1]) : false;

          half v0 = b0 ? (half)0.0h : (half)(-FLT_MAX / 2.0f);
          half v1 = b1 ? (half)0.0h : (half)(-FLT_MAX / 2.0f);
          sm2[j * SH + C + tiisg] = half2(v0, v1);
        }
      }
    }

    // ── Q·K^T matmul (ggml-metal.metal:6009-6065, is_q=0 branch) ────────
    //
    // Each simdgroup owns NC = (C/8)/NSG of the 8×8 score frags along the
    // KV-column dimension.  For NSG=8, NC=1 (one 8×8 frag per simdgroup).
    // For NSG=4, NC=2.
    //
    // K is read directly from device memory via simdgroup_load with
    // transpose=true (K columns are contiguous, scores need K^T).  NS10
    // is the stride between consecutive KV items in elements (= D for our
    // contiguous layout).
    //
    // We write the accumulated frag cc into `ss + 8*sgitg + 8*NSG*cc` — each
    // simdgroup owns disjoint 8-wide column groups of `ss` (ss layout:
    // [Q rows × SH columns], each row SH floats wide = 128 floats for
    // Q=8, C=64).
    {
      constexpr short NC = (C / 8) / NSG;

      // Device pointer to the first K element for this chunk's K start,
      // offset to this simdgroup's starting KV-column group.
      const device T* pk = k_head + (ulong)ic * (ulong)NS10
                                  + (ulong)sgitg * (ulong)(8 * NS10);

      // Pointer into ss at this simdgroup's column origin (in halves/floats).
      threadgroup float* ps = ss + 8 * sgitg;

      FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
        // mqk accumulator — initialised to 0, unique per simdgroup.
        simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.0f);

        // DK=512 → DK8=64, DK8/2=32 matmul iterations.  Mirrors
        // ggml-metal.metal:6040-6058 (DK%16==0 path).
        #pragma unroll(4)
        for (short i = 0; i < DK8 / 2; ++i) {
          simdgroup_barrier(mem_flags::mem_none);

          simdgroup_matrix<T, 8, 8> mq[2];
          simdgroup_matrix<T, 8, 8> mk[2];

          // Load 2 × (8×8) Q frags from sq at this simdgroup's Q-row
          // owned range.  llama.cpp:6048-6049 — `pq = sq` is the shared
          // Q tile; simdgroup_load at offset `8*i` along the DK axis.
          //
          // Because sq is shared across simdgroups (all simdgroups see the
          // same Q rows), we read from the full sq buffer.  Each 8×8 Q
          // frag spans 8 Q rows × 8 DK columns, and we load ALL 8 Q rows
          // per frag (NOT just NQ) because the output score frag mqk
          // occupies the same 8 Q rows × 8 KV-col slot in ss.  This means
          // each simdgroup redundantly computes scores for all 8 Q rows
          // but owns only its NC columns — matching llama.cpp's exact
          // pattern (single `pq` pointer, not per-simdgroup).
          simdgroup_load(mq[0], sq + 0 * 8 + 16 * i, DK);
          simdgroup_load(mq[1], sq + 1 * 8 + 16 * i, DK);

          simdgroup_load(mk[0], pk + 0 * 8 + 16 * i, NS10, 0, true);
          simdgroup_load(mk[1], pk + 1 * 8 + 16 * i, NS10, 0, true);

          simdgroup_barrier(mem_flags::mem_none);

          simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk);
          simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk);
        }

        // Store the 8×8 score frag into ss at this simdgroup's column
        // origin.  SH = 128 is the row stride of ss in floats.
        simdgroup_store(mqk, ps, SH, 0, false);

        // Advance pk and ps to the next simdgroup's KV-column group.
        // llama.cpp:6063-6064.
        pk += 8 * (NSG * NS10);
        ps += 8 * NSG;
      }
    }
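
    // Column ownership after this block (derived, for orientation): at
    // NSG=8, simdgroup s wrote score columns [8s, 8s+8) of every Q row; at
    // NSG=4 it wrote [8s, 8s+8) and [8s+32, 8s+40).  Together the NSG
    // simdgroups cover all C=64 columns, which the barrier below publishes
    // to the whole threadgroup.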

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Online softmax + rescale O (ggml-metal.metal:6131-6174) ──────────
    //
    // Per simdgroup, per owned Q row jj:
    //   1. Read score vector ss2[j*SH/2 + tiisg] (2 floats per lane,
    //      covering 64 columns = full C).  Apply `args.scale` post-matmul
    //      (llama.cpp :6138), then the additive mask if present.
    //   2. Update M, S per Flash-Attention online softmax formula.
    //   3. Store P = exp(score - M) back into ss2 (replaces score).
    //   4. Rescale the existing `so` accumulator by `ms = exp(M_old - M_new)`.
    //
    // mqk has already been stored into ss (8 rows × 8 cols per frag).
    // Each simdgroup participated in all Q rows (see note in QK matmul),
    // but each simdgroup owns a different NC slice of the KV columns, so
    // after the barrier above, ss holds scores for all 8 Q rows × full C.
    //
    // Softmax idiom: line-by-line port of llama.cpp ggml-metal.metal:6131-
    // 6174, including the scale-after-matmul step at :6138.  The scale is
    // applied HERE (f32 multiply on the MMA output), not baked into Q on
    // load — this keeps the Q bf16 rounding behaviour bit-identical to
    // llama.cpp's.  We use `exp` (natural) like llama.cpp rather than
    // `exp2`: the score space is natural-log because Q was stored
    // unscaled, so natural exp is the correct inverse.  Its cost matches
    // `exp2` on Apple Silicon to within noise, and choosing natural exp
    // here means the sourdough_gate's bf16 output stream can match
    // llama.cpp's flash_attn_ext output byte-for-byte.
    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
      const short j = jj * NSG + sgitg;

      const float m_old = M[jj];

      // Read 2 floats (one float2) from ss2 covering columns
      // [2*tiisg, 2*tiisg+1].  Apply `args.scale` post-matmul, matching
      // llama.cpp :6138.
      float2 s2 = ss2[j * (SH / 2) + tiisg] * args.scale;

      // Apply additive mask in NATURAL units — matches llama.cpp
      // ggml-metal.metal:6149 `s2 += s2_t(sm2[j*SH + tiisg])` (no
      // log2(e) factor because natural-exp is used below).  Mask cells
      // are already in natural-log space (bf16 -inf for masked,
      // bf16 0.0 for attended).
      //
      // Under the finite-sentinel regime, masked positions come in as
      // bf16::NEG_INFINITY, which in float promotion is -inf.  `s2 +=
      // -inf` → -inf; finite M absorbs it via simd_max without NaN.
      if (has_mask_def && blk_cur != 2) {
        half2 m2 = sm2[j * SH + C + tiisg];
        s2 = s2 + float2((float)m2.x, (float)m2.y);
      }

      // Causal masking — the qL-relative absolute position of row
      // (iq1 + j) inside the full qL is `args.qL_off + iq1 + j`, and the
      // KV columns at this iteration are ic + {2*tiisg, 2*tiisg+1}.
      // Positions where col > q_abs are masked out.
      if (do_causal_def) {
        int q_abs = args.qL_off + (int)iq1 + (int)j;
        int c0 = ic + 2 * (int)tiisg;
        int c1 = c0 + 1;
        if (c0 > q_abs) s2.x = -INFINITY;
        if (c1 > q_abs) s2.y = -INFINITY;
      }

      // Align_K=false trailing-chunk guard — mask out columns beyond kL.
      if (!align_K_def && ic + C > args.kL) {
        int c0 = ic + 2 * (int)tiisg;
        int c1 = c0 + 1;
        if (c0 >= args.kL) s2.x = -INFINITY;
        if (c1 >= args.kL) s2.y = -INFINITY;
      }

      // Row max over this simdgroup's lanes covering all C columns.
      // Matches llama.cpp :6153.
      float my_max = max(s2.x, s2.y);
      float new_max = simd_max(max(m_old, my_max));

      M[jj] = new_max;

      // Natural exp — direct port of llama.cpp :6155-6156.  Unguarded:
      // new_max is always finite under the sentinel regime, so
      // exp(-inf - finite) = exp(-inf) = +0.0 (IEEE-754 exact).
      float2 vs2;
      vs2.x = exp(s2.x - new_max);
      vs2.y = exp(s2.y - new_max);

      // Rescale factor — matches llama.cpp :6155
      // `const float ms = exp(m - M[jj]);`.  Finite by construction.
      const float ms = exp(m_old - new_max);

      // Update S.
      S[jj] = S[jj] * ms + simd_sum(vs2.x + vs2.y);

      // Store P back into ss2 for consumption by the P·V matmul below.
      ss2[j * (SH / 2) + tiisg] = vs2;

      // Rescale O for this row by `ms`.  Only this simdgroup touches its
      // owned Q row, so no cross-simdgroup synchronisation needed here
      // (barrier at end of outer loop covers the so write-back).
      // DV = 512; 512 / NW = 16 iterations per lane.  so is f32, so no
      // precision loss on rescale (unlike llama.cpp's half so4[i] *= ms
      // which rounds through half at every chunk).
      FOR_UNROLL (short ii = 0; ii < DV / NW; ++ii) {
        const short i = ii * NW + tiisg;
        so[j * PV + i] = so[j * PV + i] * ms;
      }
    }

    // NOTE: no base conversion is needed between score and mask above —
    // both are in natural-log space.  Q is stored unscaled, `args.scale`
    // carries no log2(e) factor, and the softmax uses natural `exp`, so
    // the mask adds directly.  (Our D=256 kernel differs: it works in
    // base-2 throughout, with Q pre-scaled by `scale * log2(e)` and the
    // mask promoted by log2(e) at the add.)

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── O += P · V (ggml-metal.metal:6179-6256, is_q=0 branch) ───────────
    //
    // Each simdgroup owns NO = PV8/NSG output frags along the DV axis.
    // Load per-simdgroup output frags from `so` into registers `lo[NO]`,
    // accumulate against (P × V) across all C/8 = 8 KV-cache frags, then
    // store back to `so`.
    //
    // For DV=512 > 64 we take llama.cpp's "wide-DV" branch at :6220-6245
    // which uses NO/2 inner iterations processing 2 × 8 = 16 DV-frags per
    // s-frag.  NO=8 at NSG=8, so NO/2=4 inner iterations.
    {
      constexpr short NO = PV8 / NSG;

      // We use f32 O accumulator (simdgroup_float8x8) rather than the
      // llama.cpp FA_TYPES_BF half (simdgroup_half8x8) design.  See the
      // threadgroup-region comment at the top of the kernel for the
      // rationale — llama.cpp's runtime F16-KV-cache path uses FA_TYPES
      // which has `float` for o_t, and matching THAT behaviour (not the
      // bf16 FA_TYPES_BF variant's half o_t) gives byte-identical output
      // to llama.cpp at runtime.  The cost is 8 KiB more threadgroup
      // memory for `so` (still within the 28 KiB budget because we've
      // re-used the is_q=1 dequant scratch reservation).
      //
      // Matches D=256 kernel's `MMATile<AccumType, TQ, TD, MMAFrag_acc_t>
      // Otile` at /opt/mlx-native/src/shaders/flash_attn_prefill.metal:1277
      // where `AccumType = float`.
      simdgroup_matrix<float, 8, 8> lo[NO];

      // Load this simdgroup's NO output frags from `so` (f32).
      {
        threadgroup float* sot = so + 8 * sgitg;
        FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
          simdgroup_load(lo[ii], sot, PV, 0, false);
          sot += 8 * NSG;
        }
      }

      // Walk all C/8 = 8 KV-col frag positions.  V is read directly from
      // device memory with per-simdgroup offset along the DV axis.
      //
      // For DV > 64 (our case): load 2 score frags per cc iteration
      // (vs[2]), run 4 8×8 matmuls per ii (2 V-frag loads × 2 score-frag
      // multiplies).  This is the :6220-6245 branch in llama.cpp.
      {
        const device T* pv = v_head + (ulong)ic * (ulong)NS20
                                    + (ulong)sgitg * 8;  // per-simdgroup DV offset

        constexpr short NC = (C / 8) / 2;  // 4

        FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
          simdgroup_matrix<float, 8, 8> vs[2];
          simdgroup_load(vs[0], ss + 16 * cc + 0, SH, 0, false);
          simdgroup_load(vs[1], ss + 16 * cc + 8, SH, 0, false);

          FOR_UNROLL (short ii = 0; ii < NO / 2; ++ii) {
            simdgroup_matrix<T, 8, 8> mv[4];

            simdgroup_load(mv[0], pv + 0 * NSG + 16 * ii * NSG + 0 * 8 * NS20, NS20, 0, false);
            simdgroup_load(mv[1], pv + 8 * NSG + 16 * ii * NSG + 0 * 8 * NS20, NS20, 0, false);
            simdgroup_load(mv[2], pv + 0 * NSG + 16 * ii * NSG + 1 * 8 * NS20, NS20, 0, false);
            simdgroup_load(mv[3], pv + 8 * NSG + 16 * ii * NSG + 1 * 8 * NS20, NS20, 0, false);

            simdgroup_multiply_accumulate(lo[2 * ii + 0], vs[0], mv[0], lo[2 * ii + 0]);
            simdgroup_multiply_accumulate(lo[2 * ii + 1], vs[0], mv[1], lo[2 * ii + 1]);
            simdgroup_multiply_accumulate(lo[2 * ii + 0], vs[1], mv[2], lo[2 * ii + 0]);
            simdgroup_multiply_accumulate(lo[2 * ii + 1], vs[1], mv[3], lo[2 * ii + 1]);
          }

          pv += 2 * 8 * NS20;
        }
      }

      // Store this simdgroup's output frags back to `so` (f32).
      {
        threadgroup float* sot = so + 8 * sgitg;
        FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
          simdgroup_store(lo[ii], sot, PV, 0, false);
          sot += 8 * NSG;
        }
      }
    }
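
    // Index-arithmetic spot check (derived from the offsets above; NSG=8,
    // sgitg=2, cc=0, ii=1): vs[0]/vs[1] hold P columns (KV rows) 0..8 and
    // 8..16; mv[0] loads V rows 0..8 at DV columns 144..152 (8*sgitg +
    // 16*ii*NSG = 16 + 128), feeding lo[2], whose `so` slot is column
    // 8*sgitg + 8*NSG*2 = 144.  The frag↔column mapping of `lo` and `so`
    // stays aligned across load, accumulate, and store.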

    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // ── Final output write (ggml-metal.metal:6349-6371) ────────────────────
  //
  // Each simdgroup writes its NQ owned Q rows to device O, dividing by
  // S[jj] with the finite-sentinel guard: if S==0 (all masked), scale=0
  // so the output row is all-zero — matches llama.cpp:6358 and our D=256
  // DivOp.  Otherwise, scale = 1/S[jj].
  FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
    const short j = jj * NSG + sgitg;

    if ((int)(iq1 + j) >= args.qL) {
      continue;  // past qL tail (align_Q=false case)
    }

    device T* o_row = o_head + (ulong)(iq1 + j) * (ulong)args.O_strides[2];

    const float scale = (S[jj] == 0.0f) ? 0.0f : 1.0f / S[jj];

    // so is f32 — no pre-promotion needed.  Cast to T at the final store
    // is the ONLY bf16 round per element (llama.cpp writes f32 out here
    // but our dispatcher contract is bf16; this matches sdpa_bf16's
    // single-bf16-round output convention).  DV = 512; 512 / NW = 16
    // writes per lane.
    FOR_UNROLL (short ii = 0; ii < DV / NW; ++ii) {
      const short i = ii * NW + tiisg;
      o_row[i] = (T)(so[j * PV + i] * scale);
    }
  }
}

// ──────────────────────────────────────────────────────────────────────────
// Outer kernel wrapper — dispatches to the NSG-specialised impl via the int
// function constant `fc_nsg`.  Mirrors llama.cpp's `kernel_flash_attn_ext`
// at ggml-metal.metal:6405-6430 (switch on FC_flash_attn_ext_nsg).
// ──────────────────────────────────────────────────────────────────────────

template <typename T, typename MaskT>
[[kernel, max_total_threads_per_threadgroup(32 * 8)]]
void flash_attn_prefill_d512(
    const device char* q_base    [[buffer(0)]],
    const device char* k_base    [[buffer(1)]],
    const device char* v_base    [[buffer(2)]],
    device       char* o_base    [[buffer(3)]],
    const constant AttnParams& args [[buffer(4)]],
    const constant AttnMaskParams& mask_args [[buffer(5), function_constant(has_mask)]],
    const device MaskT* mask_base [[buffer(6), function_constant(has_mask)]],
    const device char*  blk_base  [[buffer(7), function_constant(has_blk)]],
    threadgroup  half* shmem_f16 [[threadgroup(0)]],
    uint3  tgpig [[threadgroup_position_in_grid]],
    ushort tiisg [[thread_index_in_simdgroup]],
    ushort sgitg [[simdgroup_index_in_threadgroup]]) {

  // ── NSG dispatch ──────────────────────────────────────────────────────
  //
  // The switch is resolved at pipeline-creation time via function-constant
  // specialisation, so only the selected branch is emitted in the compiled
  // pipeline.  Cases 1 and 2 are omitted (unused by our dispatcher; matches
  // llama.cpp's commented-out cases at :6423-6424).
  switch (nsg_def) {
    case 4:
      flash_attn_prefill_d512_impl<T, MaskT, 4>(
          q_base, k_base, v_base, o_base, args,
          mask_args, mask_base, blk_base, shmem_f16, tgpig, tiisg, sgitg);
      break;
    case 8:
      flash_attn_prefill_d512_impl<T, MaskT, 8>(
          q_base, k_base, v_base, o_base, args,
          mask_args, mask_base, blk_base, shmem_f16, tgpig, tiisg, sgitg);
      break;
    default:
      // No-op on unsupported NSG.  The dispatcher only ever sets 4 or 8.
      break;
  }
}

// ──────────────────────────────────────────────────────────────────────────
// Host-visible instantiations
// ──────────────────────────────────────────────────────────────────────────
//
// Four entry points — (bf16/f16) × (additive-T mask / bool mask).  NSG is
// NOT in the entry-point name; it is specialised via the int function
// constant at pipeline-creation time.  Matches llama.cpp's design
// (ggml-metal.metal:6510-6511 — one entry per (dtype, DK, DV), NSG chosen
// at FC-time).

#define instantiate_d512(name, T, MaskT) \
  template [[host_name(name)]] [[kernel]] \
  decltype(flash_attn_prefill_d512<T, MaskT>) \
  flash_attn_prefill_d512<T, MaskT>;

instantiate_d512("flash_attn_prefill_llamacpp_bf16_d512",          bfloat, bfloat)
instantiate_d512("flash_attn_prefill_llamacpp_bf16_d512_boolmask", bfloat, bool)
instantiate_d512("flash_attn_prefill_llamacpp_f16_d512",           half,   half)
instantiate_d512("flash_attn_prefill_llamacpp_f16_d512_boolmask",  half,   bool)