// mlx_native/ops/flash_attn_prefill.rs

//! Flash-attention-style tiled prefill kernel — host dispatch.
//!
//! mlx-native's batched-prefill SDPA kernel, the prefill counterpart to
//! `flash_attn_vec` (which handles the seq_len=1 decode case).  Implements
//! online softmax + simdgroup MMA tiling on Apple GPUs.
//!
//! ## Kernel variants registered
//!
//! Twelve entry points are registered, all backed by the single
//! `flash_attn_prefill.metal` shader source:
//!
//! ### D=64 (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup, BERT family)
//!
//! | Kernel name | I/O dtype | Mask kind |
//! |---|---|---|
//! | `flash_attn_prefill_bf16_d64`            | bf16 | bf16 additive |
//! | `flash_attn_prefill_bf16_d64_boolmask`   | bf16 | bool |
//! | `flash_attn_prefill_f16_d64`             | f16  | f16 additive |
//! | `flash_attn_prefill_f16_d64_boolmask`    | f16  | bool |
//!
//! ### D=256 (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup)
//!
//! | Kernel name | I/O dtype | Mask kind |
//! |---|---|---|
//! | `flash_attn_prefill_bf16_d256`           | bf16 | bf16 additive (log-domain) |
//! | `flash_attn_prefill_bf16_d256_boolmask`  | bf16 | bool (`is_attended`) |
//! | `flash_attn_prefill_f16_d256`            | f16  | f16 additive |
//! | `flash_attn_prefill_f16_d256_boolmask`   | f16  | bool |
//!
//! ### D=512 (BQ=8, BK=8, WM=1, WN=1 — 32 threads/threadgroup, 1 simdgroup)
//!
//! | Kernel name | I/O dtype | Mask kind |
//! |---|---|---|
//! | `flash_attn_prefill_bf16_d512`           | bf16 | bf16 additive |
//! | `flash_attn_prefill_bf16_d512_boolmask`  | bf16 | bool |
//! | `flash_attn_prefill_f16_d512`            | f16  | f16 additive |
//! | `flash_attn_prefill_f16_d512_boolmask`   | f16  | bool |
//!
//! ### f32 is NOT instantiated at any D
//!
//! The f32 Qs threadgroup tile alone is `BQ * BD * 4` bytes — at D=256 this
//! is exactly 32 KB, the Apple Silicon `MTLDevice.maxThreadgroupMemoryLength`
//! hard limit, before KV_smem or scratch.  Verified empirically on M5 Max:
//! f32 D=256 requires ~53.7 KB and fails at library compile.  bf16 and f16
//! halve the tile footprint (~29 KB) and fit within the limit.  D=512 f32
//! is excluded for the same reason.  D=64 inherits the same exclusion for
//! consistency (bf16/f16 only across the entire dispatcher family).  f32
//! correctness is verified at the CPU reference layer in
//! `tests/test_flash_attn_prefill.rs`.  See
//! `ADR-011-phase1-port-source-decision.md` §3 for the full
//! threadgroup-memory analysis.
//!
//! ## Function constants
//!
//! The kernel declares five Metal function constants that must be specialised
//! at pipeline creation time (not at dispatch time):
//!
//! - Index 200: `align_Q` (bool) — true when `qL % BQ == 0`
//! - Index 201: `align_K` (bool) — true when `kL % BK == 0`
//! - Index 300: `has_mask` (bool) — true when a mask buffer is bound
//! - Index 301: `do_causal` (bool) — true for in-kernel causal masking
//! - Index 303: `has_blk` (bool) — true when a Wave 2E tile-skip `blk`
//!   buffer is bound (see the blk-aware dispatcher below)
//!
//! These are plumbed via [`KernelRegistry::get_pipeline_with_bool_constants`],
//! which caches compiled pipelines keyed by `(kernel_name, align_Q, align_K,
//! has_mask, do_causal, has_blk)`.  Pipeline compilation is amortised: the
//! slow path runs only once per unique `(name, booleans)` combination.
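//!
//! A pipeline-lookup sketch mirroring the dispatchers below (variable names
//! are illustrative):
//!
//! ```ignore
//! let pipeline = registry.get_pipeline_with_bool_constants(
//!     kernel_name,
//!     device.metal_device(),
//!     &[(200, align_q), (201, align_k), (300, has_mask), (301, do_causal), (303, has_blk)],
//! )?;
//! ```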
//!
//! ## Buffer layout (indices match the MSL kernel)
//!
//! - `buffer(0)` — Q `[B, H,    qL, D]`  device, contiguous inner dim
//! - `buffer(1)` — K `[B, H_kv, kL, D]`  device, contiguous inner dim
//! - `buffer(2)` — V `[B, H_kv, kL, D]`  device, contiguous inner dim
//! - `buffer(3)` — O `[B, H,    qL, D]`  device, written by kernel
//! - `buffer(4)` — `AttnParams` constant buffer (this module's [`AttnParamsGpu`])
//! - `buffer(5)` — `AttnMaskParams` constant buffer (only when `has_mask=true`)
//! - `buffer(6)` — mask data buffer (only when `has_mask=true`)
//! - `buffer(7)` — `blk` tile-skip byte buffer (only when `has_blk=true`)
//!
//! ## Grid geometry
//!
//! - Threadgroups: `(ceil(qL / BQ), H, B)`
//! - Threads per threadgroup: `(32, WM, WN)`
//! - D=256: 128 threads (4 simdgroups × 32 lanes).
//! - D=512:  32 threads (1 simdgroup  × 32 lanes).
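//!
//! A dispatch-geometry sketch for the D=256 case (mirrors the dispatchers
//! below; variable names are illustrative):
//!
//! ```ignore
//! let nq = seq_len_q.div_ceil(32);                     // NQ = ceil(qL / BQ)
//! let grid = MTLSize::new(nq as u64, n_heads as u64, batch as u64);
//! let tg_size = MTLSize::new(32, 4, 1);                // (32, WM, WN) = 128 threads
//! ```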
//!
//! ## Scale convention
//!
//! Pass `scale = 1.0 / sqrt(head_dim)`.  The kernel multiplies internally by
//! `log2(e) ≈ 1.44269504` and uses `fast::exp2` throughout — so the host
//! MUST NOT pre-multiply by `log2(e)`.
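//!
//! A host-side sketch of the convention (the `log2(e)` factor stays in the
//! kernel):
//!
//! ```ignore
//! let scale = 1.0_f32 / (head_dim as f32).sqrt(); // plain 1/sqrt(D)
//! // Wrong: scale * std::f32::consts::LOG2_E — the kernel already does this.
//! ```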
//!
//! ## Mask-sentinel contract (llama.cpp convention)
//!
//! The additive mask buffer (bf16/f16 for the additive dispatchers) uses
//! the llama.cpp CPU-side convention: **masked positions = `-INFINITY`**
//! (IEEE-754 f32 `0xFF800000`, cast to the I/O dtype — both `half(-inf)`
//! and `bfloat16(-inf)` have a real `-inf` encoding that the kernel
//! consumes correctly).  Attended positions = `0.0`.
//!
//! This matches llama.cpp's mask-authoring sites at
//! `llama-graph.cpp:421, 436, 557` and `llama-kv-cache.cpp:1572`, where
//! the CPU writes raw `-INFINITY` and the flash-attn path's cast to f16
//! maps it to f16 `-INFINITY`.
//!
//! The kernel does NOT require masks to use `-FLT_MAX/2` or any other
//! finite "large negative" sentinel.  The `-FLT_MAX/2` value is the
//! kernel-internal `M` (running row-max) initialiser — a finite sentinel
//! that absorbs `-inf` scores via `simd_max` without ever letting `M`
//! become `-inf`, which in turn lets every `exp(score - M)` evaluate as
//! `exp(-inf) = 0.0` (IEEE-754 exact) rather than `exp(-inf - (-inf)) =
//! exp(NaN) = NaN`.  See ADR-011-phase2-port-sentinel.md §1.
//!
//! Callers may pass arbitrary finite additive biases in the mask for
//! non-masked positions (e.g. ALiBi, relative-position biases); only
//! "fully block this K position" requires `-inf`.
//!
//! ## See also
//!
//! - Kernel: `/opt/mlx-native/src/shaders/flash_attn_prefill.metal`
//! - ADR-011: `/opt/hf2q/docs/ADR-011-flash-attn-prefill.md`
//! - ADR-011 phase 2 sentinel port: `/opt/hf2q/docs/ADR-011-phase2-port-sentinel.md`

use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::encoder::{CapturedOpKind, CommandEncoder, KernelArg, as_bytes};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use crate::DType;

// ─── Shader source ───────────────────────────────────────────────────────────

/// MSL source for the flash-attention prefill kernel (embedded at compile time).
pub static FLASH_ATTN_PREFILL_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_prefill.metal");

// ─── All 12 kernel entry-point names ─────────────────────────────────────────

/// D=256, bf16 I/O, bf16 additive mask.
const K_BF16_D256: &str = "flash_attn_prefill_bf16_d256";
/// D=256, bf16 I/O, bool (`is_attended`) mask.
const K_BF16_D256_BOOLMASK: &str = "flash_attn_prefill_bf16_d256_boolmask";
/// D=256, f16 I/O, f16 additive mask.
const K_F16_D256: &str = "flash_attn_prefill_f16_d256";
/// D=256, f16 I/O, bool mask.
const K_F16_D256_BOOLMASK: &str = "flash_attn_prefill_f16_d256_boolmask";
/// D=512, bf16 I/O, bf16 additive mask.
const K_BF16_D512: &str = "flash_attn_prefill_bf16_d512";
/// D=512, bf16 I/O, bool mask.
const K_BF16_D512_BOOLMASK: &str = "flash_attn_prefill_bf16_d512_boolmask";
/// D=512, f16 I/O, f16 additive mask.
const K_F16_D512: &str = "flash_attn_prefill_f16_d512";
/// D=512, f16 I/O, bool mask.
const K_F16_D512_BOOLMASK: &str = "flash_attn_prefill_f16_d512_boolmask";
/// D=64, bf16 I/O, bf16 additive mask (BERT family).
const K_BF16_D64: &str = "flash_attn_prefill_bf16_d64";
/// D=64, bf16 I/O, bool (`is_attended`) mask.
const K_BF16_D64_BOOLMASK: &str = "flash_attn_prefill_bf16_d64_boolmask";
/// D=64, f16 I/O, f16 additive mask.
const K_F16_D64: &str = "flash_attn_prefill_f16_d64";
/// D=64, f16 I/O, bool mask.
const K_F16_D64_BOOLMASK: &str = "flash_attn_prefill_f16_d64_boolmask";

/// All 12 kernel entry-point names exported by `flash_attn_prefill.metal`.
///
/// Registering all of them against the single shader source costs nothing at
/// registration time (the source is a static `&str`) and ensures additional
/// dispatchers (f16 paths, D=512 exposure) can be added later without
/// touching registration here.
const ALL_KERNEL_NAMES: &[&str] = &[
    K_BF16_D256,
    K_BF16_D256_BOOLMASK,
    K_F16_D256,
    K_F16_D256_BOOLMASK,
    K_BF16_D512,
    K_BF16_D512_BOOLMASK,
    K_F16_D512,
    K_F16_D512_BOOLMASK,
    K_BF16_D64,
    K_BF16_D64_BOOLMASK,
    K_F16_D64,
    K_F16_D64_BOOLMASK,
];

// ─── Registration ─────────────────────────────────────────────────────────────

/// Register all flash-attention prefill kernel entry points with the registry.
///
/// Maps all 12 entry-point names to the single `flash_attn_prefill.metal`
/// source.  This must be called before any dispatch to these kernels.
///
/// # Design note
///
/// `KernelRegistry` compiles one Metal library per kernel name.  All 12 names
/// point at the same source text, so the Metal compiler sees the same
/// ~1500-line source each time — compilation is amortised in
/// `KernelRegistry::get_pipeline` (the first call per name triggers
/// compilation; subsequent calls return the cached pipeline).  Registering all
/// 12 here rather than only the Phase 1a subset means Phase 2/4 dispatcher
/// functions can be added without touching this file.
pub fn register(registry: &mut KernelRegistry) {
    for &name in ALL_KERNEL_NAMES {
        registry.register_source(name, FLASH_ATTN_PREFILL_SHADER_SOURCE);
    }
}

// ─── MSL struct mirrors ───────────────────────────────────────────────────────

/// Rust mirror of the MSL `AttnParams` struct.
///
/// Field order and types match the MSL definition exactly.
/// MSL source: `flash_attn_prefill.metal` — see the `AttnParams` struct
/// definition in the kernel source for the field-by-field reference.
///
/// # Layout
///
/// All `int` fields are 32-bit (i32 in Rust).  The `int64_t` stride arrays are
/// 64-bit (i64 in Rust).  The compiler inserts natural alignment padding:
///
/// ```text
/// Offset  0:  B            (i32,  4 bytes)
/// Offset  4:  H            (i32,  4 bytes)
/// Offset  8:  D            (i32,  4 bytes)
/// Offset 12:  qL           (i32,  4 bytes)
/// Offset 16:  kL           (i32,  4 bytes)
/// Offset 20:  gqa_factor   (i32,  4 bytes)
/// Offset 24:  scale        (f32,  4 bytes)
/// Offset 28:  softcapping  (f32,  4 bytes)
/// Offset 32:  NQ           (i32,  4 bytes)
/// Offset 36:  NK           (i32,  4 bytes)
/// Offset 40:  NQ_aligned   (i32,  4 bytes)
/// Offset 44:  NK_aligned   (i32,  4 bytes)
/// Offset 48:  qL_rem       (i32,  4 bytes)
/// Offset 52:  kL_rem       (i32,  4 bytes)
/// Offset 56:  qL_off       (i32,  4 bytes)
/// Offset 60:  _pad         (4 bytes — alignment before i64 array)
/// Offset 64:  Q_strides[3] (3 × i64, 24 bytes)
/// Offset 88:  K_strides[3] (3 × i64, 24 bytes)
/// Offset 112: V_strides[3] (3 × i64, 24 bytes)
/// Offset 136: O_strides[3] (3 × i64, 24 bytes)
/// Total: 160 bytes
/// ```
///
/// `bytemuck::Pod` / `bytemuck::Zeroable` are derived — the struct must have
/// no uninitialized padding bytes.  The explicit `_pad` field makes the padding
/// concrete so Pod can be derived safely.
///
/// `softcapping` is always set to `1.0` (disabled).  The `attention<>` kernel
/// body does not read it; the field exists for ABI parity with attention
/// implementations that thread softcapping through the same param block
/// (e.g. Gemma-style logit softcap), so we don't have to redo the layout
/// when that work lands.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
pub struct AttnParamsGpu {
    /// Batch size.
    pub b: i32,
    /// Number of query heads.
    pub h: i32,
    /// Head dimension (D).
    pub d: i32,
    /// Query sequence length.
    pub ql: i32,
    /// Key/value sequence length.
    pub kl: i32,
    /// Group-query attention factor: H / H_kv.
    pub gqa_factor: i32,
    /// Attention scale (= 1.0 / sqrt(head_dim); the kernel multiplies by log2(e)).
    pub scale: f32,
    /// Softcapping value — always 1.0 (disabled) for standard SDPA.
    pub softcapping: f32,
    /// Number of Q tiles: ceil(qL / BQ).
    pub nq: i32,
    /// Number of KV tiles: ceil(kL / BK).
    pub nk: i32,
    /// Number of full (aligned) Q tiles: qL / BQ.
    pub nq_aligned: i32,
    /// Number of full (aligned) KV tiles: kL / BK.
    pub nk_aligned: i32,
    /// Remainder elements in the last Q tile: qL % BQ (0 if aligned).
    pub ql_rem: i32,
    /// Remainder elements in the last KV tile: kL % BK (0 if aligned).
    pub kl_rem: i32,
    /// Query sequence start offset (0 for standard prefill).
    pub ql_off: i32,
    /// Explicit padding to align the subsequent i64 arrays to an 8-byte boundary.
    pub _pad: i32,
    /// Query strides: (batch stride, head stride, seq stride).  Inner dim = 1.
    pub q_strides: [i64; 3],
    /// Key strides: (batch stride, head stride, seq stride).  Inner dim = 1.
    pub k_strides: [i64; 3],
    /// Value strides: (batch stride, head stride, seq stride).  Inner dim = 1.
    pub v_strides: [i64; 3],
    /// Output strides: (batch stride, head stride, seq stride).  Inner dim = 1.
    pub o_strides: [i64; 3],
}
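
// Compile-time guard for the 160-byte layout documented above — a hedged
// host-side check only; the authoritative layout is the MSL `AttnParams`:
const _: () = assert!(core::mem::size_of::<AttnParamsGpu>() == 160);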

/// Rust mirror of the MSL `AttnMaskParams` struct.
///
/// MSL source: `flash_attn_prefill.metal` — see the `AttnMaskParams` struct
/// definition in the kernel source.
///
/// Contains the mask buffer strides.  Only sent to the kernel when
/// `has_mask = true` (buffer index 5).
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
pub struct AttnMaskParamsGpu {
    /// Mask strides: (batch stride, head stride, qL stride).  Inner dim = 1.
    pub m_strides: [i64; 3],
}

// ─── Public Rust-side parameter struct ───────────────────────────────────────

/// Host-side parameters for the flash-attention prefill dispatcher.
///
/// Used only by the Rust dispatcher; does not map 1:1 to the GPU struct —
/// the GPU struct is computed from these fields inside the dispatcher.
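///
/// # Example
///
/// A minimal construction sketch (shape values are illustrative):
///
/// ```ignore
/// let params = FlashAttnPrefillParams {
///     n_heads: 8,
///     n_kv_heads: 8,
///     head_dim: 256,
///     seq_len_q: 128,
///     seq_len_k: 128,
///     batch: 1,
///     scale: 1.0 / (256.0_f32).sqrt(),
///     do_causal: true,
/// };
/// ```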
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnPrefillParams {
    /// Number of query attention heads.
    pub n_heads: u32,
    /// Number of key/value attention heads (GQA: may be < n_heads).
    pub n_kv_heads: u32,
    /// Head dimension.  Must be 256 for `dispatch_flash_attn_prefill_bf16_d256`.
    pub head_dim: u32,
    /// Query sequence length.
    pub seq_len_q: u32,
    /// Key/value sequence length.
    pub seq_len_k: u32,
    /// Batch size.
    pub batch: u32,
    /// Attention scale.  Typically `1.0 / sqrt(head_dim)`.
    ///
    /// The kernel internally multiplies this by `log2(e) = 1.44269504089`
    /// before applying it to Q.  The host MUST NOT pre-multiply by log2(e).
    pub scale: f32,
    /// Whether to apply in-kernel causal masking (`do_causal` function constant).
    ///
    /// When true, positions where `row_pos < col_pos` receive a score of -inf
    /// before softmax.  This can be combined with an external mask buffer.
    pub do_causal: bool,
}

// ─── Tile geometry constants (D=256) ─────────────────────────────────────────

/// Q tile size for D=256: BQ=32.
const BQ_D256: u32 = 32;

/// KV tile size for D=256: BK=16.
const BK_D256: u32 = 16;

/// Simdgroups along Q dimension for D=256: WM=4.
const WM_D256: u32 = 4;

/// Simdgroups along K dimension for D=256: WN=1.
const WN_D256: u32 = 1;

// ─── Tile geometry constants (D=64) ──────────────────────────────────────────
//
// Same simdgroup geometry as D=256: 4 simdgroups along Q, 1 along K, with
// BQ=32 / BK=16.  Threadgroup-memory math at bf16:
//   Qs  = BQ × (BD + padQ) × 2   = 32 × (64 + 8) × 2  = 4608 bytes
//   KVs = max(BK*(BD+padV), (BK+padK)*BD) × 2
//                                 = max(16×72, 24×64) × 2 = max(1152, 1536) × 2
//                                 = 3072 bytes
//   total ≈ 7680 bytes — fits well under the 32 KiB Apple Silicon TG limit.
//
// Static-asserts required by the kernel (with kFragSize=8):
//   BQ % (kNWarps × kFragSize) = 32 % (4×8) = 0 ✓
//   BQ ≥ kNWarps × kFragSize   = 32 ≥ 32     ✓
//   TQ = BQ / (kNWarps × kFragSize) = 1       ✓
//   TD = BD / kFragSize        = 64/8 = 8     ✓
//   TK = BK / kFragSize        = 16/8 = 2     ✓

/// Q tile size for D=64: BQ=32.
const BQ_D64: u32 = 32;

/// KV tile size for D=64: BK=16.
const BK_D64: u32 = 16;

/// Simdgroups along Q dimension for D=64: WM=4.
const WM_D64: u32 = 4;

/// Simdgroups along K dimension for D=64: WN=1.
const WN_D64: u32 = 1;
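
// Hedged compile-time mirrors of the shader's D=64 static-asserts listed in
// the comment block above (kFragSize = 8 and kNWarps = WM_D64 are taken from
// that comment, not read back from the shader source):
const _: () = assert!(BQ_D64 % (WM_D64 * 8) == 0); // BQ % (kNWarps × kFragSize) == 0
const _: () = assert!(BQ_D64 >= WM_D64 * 8); // BQ ≥ kNWarps × kFragSize
const _: () = assert!(BK_D64 % 8 == 0); // TK = BK / kFragSize is integral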

// ─── Validation ───────────────────────────────────────────────────────────────

/// Shared shape-field validation for the prefill dispatchers.
fn validate_params(params: &FlashAttnPrefillParams) -> Result<()> {
    if params.n_heads == 0 {
        return Err(MlxError::InvalidArgument(
            "flash_attn_prefill: n_heads must be > 0".into(),
        ));
    }
    if params.n_kv_heads == 0 {
        return Err(MlxError::InvalidArgument(
            "flash_attn_prefill: n_kv_heads must be > 0".into(),
        ));
    }
    if params.n_heads % params.n_kv_heads != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "flash_attn_prefill: n_heads ({}) must be divisible by n_kv_heads ({})",
            params.n_heads, params.n_kv_heads
        )));
    }
    if params.seq_len_q == 0 {
        return Err(MlxError::InvalidArgument(
            "flash_attn_prefill: seq_len_q must be > 0".into(),
        ));
    }
    if params.seq_len_k == 0 {
        return Err(MlxError::InvalidArgument(
            "flash_attn_prefill: seq_len_k must be > 0".into(),
        ));
    }
    if params.batch == 0 {
        return Err(MlxError::InvalidArgument(
            "flash_attn_prefill: batch must be > 0".into(),
        ));
    }
    Ok(())
}

/// Check that `buf` is large enough to hold `expected_elements` elements of
/// its own dtype.
pub(crate) fn validate_buffer_size(
    buf: &MlxBuffer,
    name: &str,
    expected_elements: usize,
) -> Result<()> {
    let expected_bytes = expected_elements * buf.dtype().size_of();
    if buf.byte_len() < expected_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "flash_attn_prefill: {name} buffer too small: expected at least \
             {expected_bytes} bytes, got {}",
            buf.byte_len()
        )));
    }
    Ok(())
}

// ─── bf16 D=256 dispatcher ───────────────────────────────────────────────────

/// Dispatch flash-attention prefill for bf16 Q/K/V/O, head_dim=256.
///
/// Encodes a compute command into `encoder` without committing.  The caller
/// controls when to call `encoder.commit_and_wait()`.
///
/// # Why bf16 and not f32
///
/// The f32 Qs threadgroup tile (BQ×BD×4 = 32 KB at D=256) consumes the entire
/// Apple Silicon threadgroup-memory budget before KV_smem, scratch, or any
/// padding — so f32 D=256 is not instantiated (see the module doc).  bf16/f16
/// halve the tile footprint; the MMA accumulator is still f32 internally
/// (the kernel's `T_accum` template parameter is `float` for every
/// instantiation — see `flash_attn_prefill.metal:~1504`), so prefill output
/// precision is `bf16 × bf16 → f32 → bf16` — bf16-bounded at the store, not
/// at the accumulator.
///
/// # Buffer layouts
///
/// All buffers must be contiguous (stride-1 along the innermost / head_dim
/// dimension):
///
/// - `q`    — `[batch, n_heads,    seq_len_q, 256]`, dtype BF16
/// - `k`    — `[batch, n_kv_heads, seq_len_k, 256]`, dtype BF16
/// - `v`    — `[batch, n_kv_heads, seq_len_k, 256]`, dtype BF16
/// - `mask` — `[batch, n_heads, seq_len_q, seq_len_k]`, dtype BF16
///   (additive, log-scale: 0.0 = attend, -inf = mask out — llama.cpp
///   convention, see the module doc "Mask-sentinel contract"), or `None`
/// - `out`  — `[batch, n_heads,    seq_len_q, 256]`, dtype BF16 (output)
///
/// # Function constants
///
/// `align_Q`, `align_K` are computed from the sequence lengths and tile sizes.
/// `has_mask` reflects whether `mask` is `Some(_)`.
/// `do_causal` is taken from `params.do_causal`.
///
/// A distinct Metal pipeline is compiled for each unique combination of these
/// four booleans and cached in `registry`.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` for:
/// - `head_dim != 256`
/// - Zero or inconsistent shape fields
/// - Buffer too small for the declared shape
/// - `n_heads` not divisible by `n_kv_heads`
/// - Any buffer dtype != BF16
///
/// Returns `MlxError::ShaderCompilationError` if the Metal pipeline
/// compilation fails.
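///
/// # Example
///
/// A hedged call sketch (encoder/device/registry/buffer setup elided; see
/// `tests/test_flash_attn_prefill.rs` for a complete harness):
///
/// ```ignore
/// dispatch_flash_attn_prefill_bf16_d256(
///     &mut encoder, &device, &mut registry,
///     &q, &k, &v, Some(&mask), &out, &params,
/// )?;
/// encoder.commit_and_wait(); // the caller controls when this runs
/// ```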
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_prefill_bf16_d256(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    out: &MlxBuffer,
    params: &FlashAttnPrefillParams,
) -> Result<()> {
    // Delegate to the blk-aware dispatcher with blk=None.  Because
    // `has_blk` is a function constant (index 303), the compiled pipeline
    // with has_blk=false dead-codes every blk reference — this call has
    // zero runtime overhead compared with the pre-Wave-2E code path.
    dispatch_flash_attn_prefill_bf16_d256_with_blk(
        encoder, device, registry, q, k, v, mask, None, out, params,
    )
}

/// Dispatch flash-attention prefill for bf16 Q/K/V/O, head_dim=256, with an
/// optional Wave 2E tile-skip pre-pass byte buffer.
///
/// This is the blk-aware sibling of [`dispatch_flash_attn_prefill_bf16_d256`].
/// When `blk` is `Some(buf)`, the kernel reads a classification byte per
/// `(qtile, ktile)` and:
///
/// - On `blk[qt][kt] == 0`, skips the entire KV tile (no K/V load, no
///   Q·K^T, no mask-add, no softmax update).
/// - On `blk[qt][kt] == 2`, skips the mask-add (but still computes Q·K^T
///   and softmax normally).
/// - On `blk[qt][kt] == 1`, runs the standard path.
///
/// The `blk` buffer must be produced by
/// [`crate::ops::flash_attn_prefill_blk::dispatch_flash_attn_prefill_blk`]
/// with the SAME `(BQ=32, BK=16)` tile shape as this dispatcher, and the
/// caller MUST sequence the pre-pass dispatch BEFORE this one on the same
/// command encoder (or commit the pre-pass first).
///
/// # Correctness invariant
///
/// For any valid (mask, built blk), calling this function with `blk=None`
/// vs `blk=Some(built_blk)` MUST produce bit-exact identical output.  The
/// blk path is a pure skip optimisation.  Exercised by
/// `test_gpu_bf16_d256_with_blk_matches_no_blk` in
/// `tests/test_flash_attn_prefill.rs`.
///
/// # Buffer layout
///
/// All buffers as in [`dispatch_flash_attn_prefill_bf16_d256`], plus:
///
/// - `blk` — `[ceil(seq_len_q / 32), ceil(seq_len_k / 16)]`, dtype U8.
///   Required iff `Some(_)`; must have at least
///   `ceil(qL/32) * ceil(kL/16)` bytes.  When `None`, the dispatcher
///   compiles with `has_blk=false` and does not bind buffer index 7.
///
/// # Errors
///
/// Same as [`dispatch_flash_attn_prefill_bf16_d256`], plus:
/// - `blk` is `Some(b)` with `mask = None` (a blk without a mask is
///   meaningless — rejected to catch caller bugs).
/// - `blk` buffer undersized.
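///
/// # Example
///
/// A sketch of the per-tile classification bytes described above (the
/// constant names are illustrative, not part of the API):
///
/// ```ignore
/// const BLK_SKIP_TILE: u8 = 0; // skip the whole KV tile
/// const BLK_STANDARD: u8 = 1;  // full path: Q·Kᵀ + mask-add + softmax
/// const BLK_SKIP_MASK: u8 = 2; // compute Q·Kᵀ + softmax, skip the mask-add
/// let byte = blk_bytes[qt * nk_tiles + kt]; // one byte per (qtile, ktile)
/// ```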
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_prefill_bf16_d256_with_blk(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    blk: Option<&MlxBuffer>,
    out: &MlxBuffer,
    params: &FlashAttnPrefillParams,
) -> Result<()> {
    // ── Validate ──────────────────────────────────────────────────────────
    if params.head_dim != 256 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d256: head_dim must be 256, got {}",
            params.head_dim
        )));
    }
    // Reject the meaningless has_blk=true, has_mask=false case: a blk
    // buffer classifies the contents of the mask; without a mask the blk
    // bytes are computed against undefined data.  This is a caller-bug
    // defence, not a shader limitation.
    if blk.is_some() && mask.is_none() {
        return Err(MlxError::InvalidArgument(
            "dispatch_flash_attn_prefill_bf16_d256_with_blk: \
             blk requires mask (a blk without a mask is meaningless)"
                .into(),
        ));
    }
    validate_params(params)?;

    // All buffers must be BF16 for this dispatcher.
    for (buf, name) in &[(q, "Q"), (k, "K"), (v, "V"), (out as &MlxBuffer, "out")] {
        if buf.dtype() != DType::BF16 {
            return Err(MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_bf16_d256: {name} buffer must be BF16, \
                 got {:?}",
                buf.dtype()
            )));
        }
    }
    if let Some(m) = mask {
        if m.dtype() != DType::BF16 {
            return Err(MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_bf16_d256: mask buffer must be BF16, \
                 got {:?}",
                m.dtype()
            )));
        }
    }

    let batch = params.batch as usize;
    let h = params.n_heads as usize;
    let h_kv = params.n_kv_heads as usize;
    let ql = params.seq_len_q as usize;
    let kl = params.seq_len_k as usize;
    let d = params.head_dim as usize; // = 256

    // Validate buffer element counts.
    validate_buffer_size(q, "Q", batch * h * ql * d)?;
    validate_buffer_size(k, "K", batch * h_kv * kl * d)?;
    validate_buffer_size(v, "V", batch * h_kv * kl * d)?;
    validate_buffer_size(out, "out", batch * h * ql * d)?;
    // A rank-2 mask `[qL, kL]` is the Wave 2D broadcast layout: one plane is
    // shared across all batches and heads (stride-0 in the batch and head dims).
    // A rank-4 mask `[B, H, qL, kL]` is the per-head layout used by callers
    // that already have a fully-expanded mask (e.g. pre-Wave-2D code paths).
    let mask_is_rank2_broadcast = mask.is_some_and(|m| m.shape().len() == 2);
    if let Some(m) = mask {
        if mask_is_rank2_broadcast {
            validate_buffer_size(m, "mask", ql * kl)?;
        } else {
            validate_buffer_size(m, "mask", batch * h * ql * kl)?;
        }
    }

    // ── Tile geometry ─────────────────────────────────────────────────────
    let bq = BQ_D256;
    let bk = BK_D256;
    let wm = WM_D256;
    let wn = WN_D256;

    let nq = params.seq_len_q.div_ceil(bq);
    let nk = params.seq_len_k.div_ceil(bk);
    let nq_aligned = params.seq_len_q / bq;
    let nk_aligned = params.seq_len_k / bk;
    let ql_rem = params.seq_len_q % bq;
    let kl_rem = params.seq_len_k % bk;

    // Function constants (specialised at pipeline creation time, not dispatch).
    let align_q = ql_rem == 0;
    let align_k = kl_rem == 0;
    let has_mask = mask.is_some();
    let has_blk = blk.is_some();
    let do_causal = params.do_causal;

    // Validate blk buffer size when present.  Tile shape is fixed for D=256:
    // BQ=32, BK=16 — see ADR-011-phase2-port-tile-skip.md §5.1.
    if let Some(b) = blk {
        let nq_tiles = ql.div_ceil(BQ_D256 as usize);
        let nk_tiles = kl.div_ceil(BK_D256 as usize);
        let expected = nq_tiles * nk_tiles;
        if b.byte_len() < expected {
            return Err(MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_bf16_d256_with_blk: blk buffer \
                 too small: expected at least {expected} bytes (NQ={nq_tiles}, \
                 NK={nk_tiles}), got {}",
                b.byte_len()
            )));
        }
    }

    // ── Kernel name ───────────────────────────────────────────────────────
    // bf16 I/O, bf16 additive mask (or no mask — uses the same pipeline since
    // has_mask is a function constant, not part of the name).
    let kernel_name = K_BF16_D256;

    // ── Pipeline lookup (with function constants) ─────────────────────────
    //
    // Wave 2E adds function constant 303 (has_blk).  When has_blk=false
    // every blk reference in the shader is dead-coded, so pipelines with
    // the pre-Wave-2E constants {200, 201, 300, 301} + {303: false}
    // generate the same machine code as the pre-Wave-2E pipelines.  The
    // only cache-key overhead is one extra "b0" suffix, which is paid
    // once per (aligned, causal, masked) combo at compile time.
    let pipeline = registry.get_pipeline_with_bool_constants(
        kernel_name,
        device.metal_device(),
        &[
            (200, align_q),
            (201, align_k),
            (300, has_mask),
            (301, do_causal),
            (303, has_blk),
        ],
    )?;

    // ── Build AttnParams GPU struct ───────────────────────────────────────
    //
    // Strides for layout [B, H, L, D] where the innermost (D) stride is 1
    // (contiguous):
    //
    //   seq stride   = D
    //   head stride  = L * D
    //   batch stride = H * L * D
    //
    // Q/O use (H, qL); K/V use (H_kv, kL).

    let q_seq_stride = d as i64;
    let q_head_stride = (ql * d) as i64;
    let q_batch_stride = (h * ql * d) as i64;

    let kv_seq_stride = d as i64;
    let kv_head_stride = (kl * d) as i64;
    let kv_batch_stride = (h_kv * kl * d) as i64;

    let gqa_factor = (params.n_heads / params.n_kv_heads) as i32;

    let attn_params = AttnParamsGpu {
        b: params.batch as i32,
        h: params.n_heads as i32,
        d: params.head_dim as i32,
        ql: params.seq_len_q as i32,
        kl: params.seq_len_k as i32,
        gqa_factor,
        scale: params.scale,
        softcapping: 1.0_f32, // always disabled; see the module doc
        nq: nq as i32,
        nk: nk as i32,
        nq_aligned: nq_aligned as i32,
        nk_aligned: nk_aligned as i32,
        ql_rem: ql_rem as i32,
        kl_rem: kl_rem as i32,
        ql_off: 0, // standard prefill starts at offset 0
        _pad: 0,
        q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
        k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
        v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
        o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
    };

    // ── Grid geometry ──────────────────────────────────────────────────────
    //   grid = (NQ, H, B)  where NQ = ceil(qL / BQ)
    //   threadgroup = (32, WM, WN)
    let grid = MTLSize::new(nq as u64, params.n_heads as u64, params.batch as u64);
    let tg_size = MTLSize::new(32, wm as u64, wn as u64);

    // ── Encode ─────────────────────────────────────────────────────────────
    encoder.set_op_kind(CapturedOpKind::Sdpa);

    if has_mask {
        // SAFETY: has_mask is true iff mask.is_some() — set in the
        // function-constant block above, so the Option is guaranteed to be
        // Some here.  We use ok_or_else rather than expect/unwrap to satisfy
        // the no-panic policy.
        let mask_buf = mask.ok_or_else(|| {
            MlxError::InvalidArgument(
                "flash_attn_prefill: internal error — has_mask=true but mask is None".into(),
            )
        })?;

        // Strides depend on mask rank:
        //   rank-2 `[qL, kL]` — broadcast across batch + heads: set batch_stride
        //   and head_stride to 0 so the shader re-reads the same plane for every
        //   (batch, head) pair.  The Metal shader already handles stride-0 correctly
        //   (no kernel changes required).
        //   rank-4 `[B, H, qL, kL]` — per-head layout (back-compat path).
        let (m_batch_stride, m_head_stride, m_ql_stride) = if mask_is_rank2_broadcast {
            (0_i64, 0_i64, kl as i64)
        } else {
            ((h * ql * kl) as i64, (ql * kl) as i64, kl as i64)
        };

        let mask_params = AttnMaskParamsGpu {
            m_strides: [m_batch_stride, m_head_stride, m_ql_stride],
        };

        if has_blk {
            // SAFETY: has_blk is true iff blk.is_some(), and the earlier
            // validation rejects blk.is_some() && mask.is_none().
            let blk_buf = blk.ok_or_else(|| {
                MlxError::InvalidArgument(
                    "flash_attn_prefill: internal error — has_blk=true but blk is None".into(),
                )
            })?;

            encoder.encode_threadgroups_with_args(
                pipeline,
                &[
                    (0, KernelArg::Buffer(q)),
                    (1, KernelArg::Buffer(k)),
                    (2, KernelArg::Buffer(v)),
                    (3, KernelArg::Buffer(out)),
                    (4, KernelArg::Bytes(as_bytes(&attn_params))),
                    (5, KernelArg::Bytes(as_bytes(&mask_params))),
                    (6, KernelArg::Buffer(mask_buf)),
                    (7, KernelArg::Buffer(blk_buf)),
                ],
                grid,
                tg_size,
            );
        } else {
            encoder.encode_threadgroups_with_args(
                pipeline,
                &[
                    (0, KernelArg::Buffer(q)),
                    (1, KernelArg::Buffer(k)),
                    (2, KernelArg::Buffer(v)),
                    (3, KernelArg::Buffer(out)),
                    (4, KernelArg::Bytes(as_bytes(&attn_params))),
                    (5, KernelArg::Bytes(as_bytes(&mask_params))),
                    (6, KernelArg::Buffer(mask_buf)),
                    // buffer 7 absent — the has_blk=false constant dead-codes blk refs.
                ],
                grid,
                tg_size,
            );
        }
    } else {
        encoder.encode_threadgroups_with_args(
            pipeline,
            &[
                (0, KernelArg::Buffer(q)),
                (1, KernelArg::Buffer(k)),
                (2, KernelArg::Buffer(v)),
                (3, KernelArg::Buffer(out)),
                (4, KernelArg::Bytes(as_bytes(&attn_params))),
                // buffers 5, 6, 7 intentionally absent — the has_mask=false +
                // has_blk=false constants dead-code-eliminate mask + blk loads.
            ],
            grid,
            tg_size,
        );
    }

    Ok(())
}

// ─── bf16 D=256 RESUME dispatcher (qL_off > 0 + slot-capacity strides) ──────
//
// ADR-017 Phase E.a B.2-fix: extend the FA bf16 d256 fast path to support
// LCP partial-prefill resume (turn 2 of a multi-turn conversation prefilling
// only the suffix against a slot that already contains the LCP prefix).
//
// The Metal shader has supported this regime since the Phase 1 port (the
// `qL_off` field at flash_attn_prefill.metal:1045 + `K_strides[3]` /
// `V_strides[3]` at lines 1048-1049 are read at lines 1325, 1437, 1445), but
// the d256 dispatcher hardcoded `qL_off=0` and built K/V strides from
// `seq_len_k`, not `kv_capacity`, because no Rust caller needed
// "prefill-at-offset" semantics — Qwen3.5/3.6 production prefill is always
// from token 0 (cur_len=0) and decode (cur_len > 0, seq_len=1) routes to
// `flash_attn_vec` at gpu_full_attn.rs:1733 instead.
//
// The legacy F32 SDPA fallback at gpu_full_attn.rs:1900-1916 covered the
// "prefill-at-offset" case for structural correctness, but with an F32
// single-pass softmax that produces byte-different output to the BF16 MMA
// + log-domain online softmax fast path (proven via
// gpu_full_attn.rs::tests::phase_b2_iso_fast_path_vs_fallback_path_kernel_divergence:
// 131072/131072 elements differ, max |Δ| = 6.452e-4).
//
// This module fixes that: a sibling dispatcher that exposes `qL_off` and
// `kv_capacity` so the FA bf16 d256 fast path is reachable for cur_len > 0
// AND for slot reads (a slot's K/V head stride is `kv_capacity * head_dim`,
// not `kL * head_dim`, because a slot may have allocated more positions
// than the current valid kL).

/// Host-side parameters for the flash-attention prefill **resume** dispatcher.
///
/// Differs from [`FlashAttnPrefillParams`] in two ways:
///
/// 1. `q_offset_in_k` (mapped to the kernel's `qL_off`): the absolute Q
///    position of the chunk Q within the larger K/V sequence.  When > 0,
///    Q is being attended over a slot that already contains
///    `q_offset_in_k` previous tokens.  The kernel's causal mask uses
///    `qL_off` to compute `row_pos = tid.x * BQ + qL_off + ...`
///    (`flash_attn_prefill.metal:1325, 1445`) so the per-chunk causal
///    pattern stays correct relative to the absolute K/V position.
///
/// 2. `kv_capacity` (slot stride): K and V buffers may live in a slot of
///    capacity ≥ `seq_len_k`.  The kernel reads K/V via integer strides
///    from `K_strides[3]` / `V_strides[3]`, so we set
///    `head_stride = kv_capacity * D` to skip unused slot capacity between
///    heads.  When `kv_capacity == seq_len_k` the layout is identical to
///    the non-resume dispatcher.
///
/// # Use case
///
/// ADR-017 Phase E.a LCP partial-prefill resume.  Turn 1 of a multi-turn
/// conversation prefills the prompt, populates the KV slot, and snapshots
/// it.  Turn 2 detects an LCP overlap with turn 1, restores the slot, and
/// prefills only the suffix of turn 2 (M new tokens) against the full slot
/// (kL = N + M, qL = M, qL_off = N).  The output is byte-identical to a
/// fresh full prefill of the entire turn-2 prompt — proven by the parity
/// unit test
/// `flash_attn_prefill_bf16_d256_resume_byte_identical_to_monolithic`.
///
/// The non-resume dispatcher remains the production path for prefill-from-
/// zero (`q_offset_in_k == 0` && `kv_capacity == seq_len_k`).  The two
/// dispatchers compile to the same kernel pipeline (same kernel name +
/// same function constants when `q_offset_in_k == 0` && `kv_capacity ==
/// seq_len_k`); the only host-side difference is the strides written into
/// `AttnParamsGpu`.
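///
/// # Example
///
/// A turn-2 resume sketch (`n` cached LCP tokens, `m` new suffix tokens;
/// all values illustrative):
///
/// ```ignore
/// let params = FlashAttnPrefillResumeParams {
///     n_heads: 8,
///     n_kv_heads: 8,
///     head_dim: 256,
///     seq_len_q: m,                 // qL = M (suffix only)
///     seq_len_k: n + m,             // kL = N + M
///     batch: 1,
///     scale: 1.0 / (256.0_f32).sqrt(),
///     do_causal: true,
///     q_offset_in_k: n,             // qL_off = N
///     kv_capacity: slot_capacity,   // >= seq_len_k (slot layout)
/// };
/// ```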
892#[derive(Debug, Clone, Copy)]
893pub struct FlashAttnPrefillResumeParams {
894    /// Number of query attention heads.
895    pub n_heads: u32,
896    /// Number of key/value attention heads (GQA).
897    pub n_kv_heads: u32,
898    /// Head dimension.  Must be 256 for `dispatch_flash_attn_prefill_bf16_d256_resume`.
899    pub head_dim: u32,
900    /// Query sequence length (chunk Q only).  qL.
901    pub seq_len_q: u32,
902    /// Total key/value sequence length: `q_offset_in_k + seq_len_q`.  kL.
903    pub seq_len_k: u32,
904    /// Batch size.
905    pub batch: u32,
906    /// Attention scale.  Typically `1.0 / sqrt(head_dim)`.  The kernel
907    /// internally multiplies by `log2(e)` before applying to Q.
908    pub scale: f32,
909    /// Whether to apply in-kernel causal masking.  For partial-prefill resume
910    /// this is typically `true` — chunk Q must causally mask K positions
911    /// `> q_offset_in_k + q_pos_within_chunk`.
912    pub do_causal: bool,
913    /// Q offset within the K/V sequence (`qL_off` in the kernel).  Number of
914    /// previous tokens already present in the slot.  Standard append-prefill
915    /// semantics: `q_offset_in_k + seq_len_q == seq_len_k`.
916    pub q_offset_in_k: u32,
917    /// Capacity of the K/V slot — stride between K/V heads in elements.
918    /// Set to `seq_len_k` when K/V are contiguous-packed (equivalent to the
919    /// non-resume dispatcher); set to the slot's allocated capacity to read
920    /// from a slot with unused trailing positions.  Must be `>= seq_len_k`.
921    pub kv_capacity: u32,
922}
923
924/// Dispatch flash-attention prefill for bf16 Q/K/V/O, head_dim=256, with
925/// caller-controlled `qL_off` and slot-capacity-aware K/V strides.
926///
927/// **Use this dispatcher when** `q_offset_in_k > 0` (partial-prefill resume)
928/// OR when the K/V buffers have allocated capacity beyond `seq_len_k` (slot
929/// reads).  For the prefill-from-zero contiguous-packed case prefer
930/// [`dispatch_flash_attn_prefill_bf16_d256`] (functionally equivalent at
931/// `q_offset_in_k=0` && `kv_capacity==seq_len_k`; verified via the parity
932/// unit test in `tests/test_flash_attn_prefill.rs`).
933///
934/// # Buffer layouts
935///
936/// All buffers must be contiguous along the innermost head_dim=256 axis.
937///
938/// - `q`    — `[batch, n_heads,    seq_len_q,    256]`, dtype BF16
939///     * head stride = `seq_len_q * 256`, seq stride = 256, batch stride =
940///       `n_heads * seq_len_q * 256`.
941/// - `k`    — `[batch, n_kv_heads, kv_capacity, 256]`, dtype BF16
942///     * head stride = `kv_capacity * 256`, seq stride = 256, batch stride =
943///       `n_kv_heads * kv_capacity * 256`.
944///     * Only positions `[0..seq_len_k]` are read; positions
945///       `[seq_len_k..kv_capacity]` may be uninitialised — the kernel will
946///       not read past `seq_len_k` thanks to `kL`-aware tile bounds at
947///       `flash_attn_prefill.metal:1322, 1416`.
948/// - `v`    — same layout as `k`.
949/// - `out`  — `[batch, n_heads,    seq_len_q,    256]`, dtype BF16 (output)
950///
951/// # Function constants
952///
953/// Identical pipeline cache key to the non-resume dispatcher: `align_Q`,
954/// `align_K` from sequence lengths; `do_causal` from `params`; `has_mask`
955/// and `has_blk` are both `false` (this resume dispatcher uses pure causal
956/// masking via `qL_off` — no external additive mask is supported.  Add a
957/// resume-with-mask sibling if a future caller needs it).
958///
959/// # Errors
960///
961/// Returns `MlxError::InvalidArgument` for:
962/// - `head_dim != 256`
963/// - `q_offset_in_k + seq_len_q > seq_len_k` (Q overshoots K)
964/// - `seq_len_k > kv_capacity`               (kL overshoots slot capacity)
965/// - Zero or inconsistent shape fields
966/// - `n_heads` not divisible by `n_kv_heads`
967/// - Buffer too small for declared shape (Q/O against `seq_len_q`,
968///   K/V against `kv_capacity`)
969/// - Any buffer dtype != BF16
970#[allow(clippy::too_many_arguments)]
971pub fn dispatch_flash_attn_prefill_bf16_d256_resume(
972    encoder: &mut CommandEncoder,
973    device: &MlxDevice,
974    registry: &mut KernelRegistry,
975    q: &MlxBuffer,
976    k: &MlxBuffer,
977    v: &MlxBuffer,
978    out: &MlxBuffer,
979    params: &FlashAttnPrefillResumeParams,
980) -> Result<()> {
981    // ── Validate ──────────────────────────────────────────────────────────
982    if params.head_dim != 256 {
983        return Err(MlxError::InvalidArgument(format!(
984            "dispatch_flash_attn_prefill_bf16_d256_resume: head_dim must be 256, got {}",
985            params.head_dim
986        )));
987    }
988    if params.n_heads == 0
989        || params.n_kv_heads == 0
990        || params.seq_len_q == 0
991        || params.seq_len_k == 0
992        || params.batch == 0
993    {
994        return Err(MlxError::InvalidArgument(
995            "dispatch_flash_attn_prefill_bf16_d256_resume: \
996             n_heads/n_kv_heads/seq_len_q/seq_len_k/batch must all be > 0"
997                .into(),
998        ));
999    }
1000    if params.n_heads % params.n_kv_heads != 0 {
1001        return Err(MlxError::InvalidArgument(format!(
1002            "dispatch_flash_attn_prefill_bf16_d256_resume: n_heads ({}) must \
1003             be divisible by n_kv_heads ({})",
1004            params.n_heads, params.n_kv_heads
1005        )));
1006    }
1007    if params.q_offset_in_k + params.seq_len_q > params.seq_len_k {
1008        return Err(MlxError::InvalidArgument(format!(
1009            "dispatch_flash_attn_prefill_bf16_d256_resume: q_offset_in_k ({}) \
1010             + seq_len_q ({}) > seq_len_k ({}) — Q overshoots K",
1011            params.q_offset_in_k, params.seq_len_q, params.seq_len_k
1012        )));
1013    }
1014    if params.seq_len_k > params.kv_capacity {
1015        return Err(MlxError::InvalidArgument(format!(
1016            "dispatch_flash_attn_prefill_bf16_d256_resume: seq_len_k ({}) > \
1017             kv_capacity ({}) — K/V overshoots slot capacity",
1018            params.seq_len_k, params.kv_capacity
1019        )));
1020    }
1021
1022    for (buf, name) in &[(q, "Q"), (k, "K"), (v, "V"), (out as &MlxBuffer, "out")] {
1023        if buf.dtype() != DType::BF16 {
1024            return Err(MlxError::InvalidArgument(format!(
1025                "dispatch_flash_attn_prefill_bf16_d256_resume: {name} buffer \
1026                 must be BF16, got {:?}",
1027                buf.dtype()
1028            )));
1029        }
1030    }
1031
1032    let batch = params.batch as usize;
1033    let h = params.n_heads as usize;
1034    let h_kv = params.n_kv_heads as usize;
1035    let ql = params.seq_len_q as usize;
1036    let cap = params.kv_capacity as usize;
1037    let d = params.head_dim as usize; // = 256
1038
1039    validate_buffer_size(q, "Q", batch * h * ql * d)?;
1040    // K/V buffers are validated against capacity (slot layout), not seq_len_k.
1041    validate_buffer_size(k, "K", batch * h_kv * cap * d)?;
1042    validate_buffer_size(v, "V", batch * h_kv * cap * d)?;
1043    validate_buffer_size(out, "out", batch * h * ql * d)?;
1044
1045    // ── Tile geometry (D=256) ─────────────────────────────────────────────
1046    let bq = BQ_D256;
1047    let bk = BK_D256;
1048    let wm = WM_D256;
1049    let wn = WN_D256;
1050
1051    let nq = params.seq_len_q.div_ceil(bq);
1052    let nk = params.seq_len_k.div_ceil(bk);
1053    let nq_aligned = params.seq_len_q / bq;
1054    let nk_aligned = params.seq_len_k / bk;
1055    let ql_rem = params.seq_len_q % bq;
1056    let kl_rem = params.seq_len_k % bk;
1057
1058    let align_q = ql_rem == 0;
1059    let align_k = kl_rem == 0;
1060    let has_mask = false; // resume uses pure causal; no external mask
1061    let has_blk = false;
1062    let do_causal = params.do_causal;
1063
1064    // Same pipeline cache key as the non-resume dispatcher: when
1065    // qL_off=0 && kv_capacity=seq_len_k the two dispatchers compile to the
1066    // exact same Metal pipeline (verified at the parity test).
1067    let kernel_name = K_BF16_D256;
1068    let pipeline = registry.get_pipeline_with_bool_constants(
1069        kernel_name,
1070        device.metal_device(),
1071        &[
1072            (200, align_q),
1073            (201, align_k),
1074            (300, has_mask),
1075            (301, do_causal),
1076            (303, has_blk),
1077        ],
1078    )?;
1079
1080    // ── Strides ───────────────────────────────────────────────────────────
1081    //
1082    // Q/O are contiguous-packed (qL is the actual extent — chunk Q has no
1083    // tail capacity).  K/V use kv_capacity for head stride (slot layout).
1084    // Inner stride (D) is always 1.
1085    let q_seq_stride = d as i64;
1086    let q_head_stride = (ql * d) as i64;
1087    let q_batch_stride = (h * ql * d) as i64;
1088
1089    // K/V head stride uses kv_capacity (slot stride), NOT seq_len_k.  This is
1090    // the only stride that differs from the non-resume dispatcher.
1091    let kv_seq_stride = d as i64;
1092    let kv_head_stride = (cap * d) as i64;
1093    let kv_batch_stride = (h_kv * cap * d) as i64;
1094
1095    let gqa_factor = (params.n_heads / params.n_kv_heads) as i32;
1096
1097    let attn_params = AttnParamsGpu {
1098        b: params.batch as i32,
1099        h: params.n_heads as i32,
1100        d: params.head_dim as i32,
1101        ql: params.seq_len_q as i32,
1102        kl: params.seq_len_k as i32,
1103        gqa_factor,
1104        scale: params.scale,
1105        softcapping: 1.0_f32,
1106        nq: nq as i32,
1107        nk: nk as i32,
1108        nq_aligned: nq_aligned as i32,
1109        nk_aligned: nk_aligned as i32,
1110        ql_rem: ql_rem as i32,
1111        kl_rem: kl_rem as i32,
1112        ql_off: params.q_offset_in_k as i32, // ← THE resume change
1113        _pad: 0,
1114        q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
1115        k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
1116        v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
1117        o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
1118    };
1119
1120    // ── Grid geometry ─────────────────────────────────────────────────────
1121    let grid = MTLSize::new(nq as u64, params.n_heads as u64, params.batch as u64);
1122    let tg_size = MTLSize::new(32, wm as u64, wn as u64);
1123
1124    // ── Encode ────────────────────────────────────────────────────────────
1125    encoder.set_op_kind(CapturedOpKind::Sdpa);
1126    encoder.encode_threadgroups_with_args(
1127        pipeline,
1128        &[
1129            (0, KernelArg::Buffer(q)),
1130            (1, KernelArg::Buffer(k)),
1131            (2, KernelArg::Buffer(v)),
1132            (3, KernelArg::Buffer(out)),
1133            (4, KernelArg::Bytes(as_bytes(&attn_params))),
1134            // buffers 5, 6, 7 intentionally absent — has_mask=false +
1135            // has_blk=false function constants dead-code-eliminate the
1136            // mask + blk loads.
1137        ],
1138        grid,
1139        tg_size,
1140    );
1141
1142    Ok(())
1143}
1144
1145// ─── f16 D=256 RESUME dispatcher (qL_off > 0 + slot-capacity strides) ──────
1146//
1147// F16 variant of `dispatch_flash_attn_prefill_bf16_d256_resume` added for
1148// ADR-030 iter-74 (hf2q DFlash spec-decode).  The hf2q hybrid KV cache
1149// stores K and V in F16 (`hybrid_kv.k` shape `[H_kv, capacity, D]` F16
1150// after iter-348 Phase 10c).  For spec-decode verify, the orchestrator
1151// needs cross-length attention against the F16 hybrid_kv slot without
1152// the F16→BF16 cast overhead.
1153//
1154// Bit-identical port of the BF16 dispatcher with:
1155//   * dtype check BF16 → F16
1156//   * kernel name K_BF16_D256 → K_F16_D256
1157//
1158// Everything else identical: same strides, same params layout, same
1159// causal `qL_off` semantics, same pipeline-cache key shape.  The Metal
1160// kernel is templated on dtype; the F16 instance is registered in
1161// `register()` at the top of this file.
1162
1163/// F16 sibling of [`dispatch_flash_attn_prefill_bf16_d256_resume`].  Same
1164/// semantics, F16 dtype throughout.  See the BF16 doc-comment for the
1165/// resume-mode contract, buffer layouts, function constants, and error
1166/// conditions — they all apply here verbatim except every "BF16" reads
1167/// as "F16".
1168#[allow(clippy::too_many_arguments)]
1169pub fn dispatch_flash_attn_prefill_f16_d256_resume(
1170    encoder: &mut CommandEncoder,
1171    device: &MlxDevice,
1172    registry: &mut KernelRegistry,
1173    q: &MlxBuffer,
1174    k: &MlxBuffer,
1175    v: &MlxBuffer,
1176    out: &MlxBuffer,
1177    params: &FlashAttnPrefillResumeParams,
1178) -> Result<()> {
1179    // ── Validate ──────────────────────────────────────────────────────────
1180    if params.head_dim != 256 {
1181        return Err(MlxError::InvalidArgument(format!(
1182            "dispatch_flash_attn_prefill_f16_d256_resume: head_dim must be 256, got {}",
1183            params.head_dim
1184        )));
1185    }
1186    if params.n_heads == 0
1187        || params.n_kv_heads == 0
1188        || params.seq_len_q == 0
1189        || params.seq_len_k == 0
1190        || params.batch == 0
1191    {
1192        return Err(MlxError::InvalidArgument(
1193            "dispatch_flash_attn_prefill_f16_d256_resume: \
1194             n_heads/n_kv_heads/seq_len_q/seq_len_k/batch must all be > 0"
1195                .into(),
1196        ));
1197    }
1198    if params.n_heads % params.n_kv_heads != 0 {
1199        return Err(MlxError::InvalidArgument(format!(
1200            "dispatch_flash_attn_prefill_f16_d256_resume: n_heads ({}) must \
1201             be divisible by n_kv_heads ({})",
1202            params.n_heads, params.n_kv_heads
1203        )));
1204    }
1205    if params.q_offset_in_k + params.seq_len_q > params.seq_len_k {
1206        return Err(MlxError::InvalidArgument(format!(
1207            "dispatch_flash_attn_prefill_f16_d256_resume: q_offset_in_k ({}) \
1208             + seq_len_q ({}) > seq_len_k ({}) — Q overshoots K",
1209            params.q_offset_in_k, params.seq_len_q, params.seq_len_k
1210        )));
1211    }
1212    if params.seq_len_k > params.kv_capacity {
1213        return Err(MlxError::InvalidArgument(format!(
1214            "dispatch_flash_attn_prefill_f16_d256_resume: seq_len_k ({}) > \
1215             kv_capacity ({}) — K/V overshoots slot capacity",
1216            params.seq_len_k, params.kv_capacity
1217        )));
1218    }
1219
1220    for (buf, name) in &[(q, "Q"), (k, "K"), (v, "V"), (out as &MlxBuffer, "out")] {
1221        if buf.dtype() != DType::F16 {
1222            return Err(MlxError::InvalidArgument(format!(
1223                "dispatch_flash_attn_prefill_f16_d256_resume: {name} buffer \
1224                 must be F16, got {:?}",
1225                buf.dtype()
1226            )));
1227        }
1228    }
1229
1230    let batch = params.batch as usize;
1231    let h = params.n_heads as usize;
1232    let h_kv = params.n_kv_heads as usize;
1233    let ql = params.seq_len_q as usize;
1234    let cap = params.kv_capacity as usize;
1235    let d = params.head_dim as usize; // = 256
1236
1237    validate_buffer_size(q, "Q", batch * h * ql * d)?;
1238    validate_buffer_size(k, "K", batch * h_kv * cap * d)?;
1239    validate_buffer_size(v, "V", batch * h_kv * cap * d)?;
1240    validate_buffer_size(out, "out", batch * h * ql * d)?;
1241
1242    // ── Tile geometry (D=256) ─ same as BF16 path ─────────────────────────
1243    let bq = BQ_D256;
1244    let bk = BK_D256;
1245    let wm = WM_D256;
1246    let wn = WN_D256;
1247
1248    let nq = params.seq_len_q.div_ceil(bq);
1249    let nk = params.seq_len_k.div_ceil(bk);
1250    let nq_aligned = params.seq_len_q / bq;
1251    let nk_aligned = params.seq_len_k / bk;
1252    let ql_rem = params.seq_len_q % bq;
1253    let kl_rem = params.seq_len_k % bk;
1254
1255    let align_q = ql_rem == 0;
1256    let align_k = kl_rem == 0;
1257    let has_mask = false;
1258    let has_blk = false;
1259    let do_causal = params.do_causal;
1260
1261    // ── Pipeline lookup — F16 kernel ──────────────────────────────────────
1262    let kernel_name = K_F16_D256;
1263    let pipeline = registry.get_pipeline_with_bool_constants(
1264        kernel_name,
1265        device.metal_device(),
1266        &[
1267            (200, align_q),
1268            (201, align_k),
1269            (300, has_mask),
1270            (301, do_causal),
1271            (303, has_blk),
1272        ],
1273    )?;
1274
1275    // ── Strides ─ identical to BF16 (kernel reads via integer strides) ────
1276    let q_seq_stride = d as i64;
1277    let q_head_stride = (ql * d) as i64;
1278    let q_batch_stride = (h * ql * d) as i64;
1279
1280    let kv_seq_stride = d as i64;
1281    let kv_head_stride = (cap * d) as i64;
1282    let kv_batch_stride = (h_kv * cap * d) as i64;
1283
1284    let gqa_factor = (params.n_heads / params.n_kv_heads) as i32;

    let attn_params = AttnParamsGpu {
        b: params.batch as i32,
        h: params.n_heads as i32,
        d: params.head_dim as i32,
        ql: params.seq_len_q as i32,
        kl: params.seq_len_k as i32,
        gqa_factor,
        scale: params.scale,
        softcapping: 1.0_f32,
        nq: nq as i32,
        nk: nk as i32,
        nq_aligned: nq_aligned as i32,
        nk_aligned: nk_aligned as i32,
        ql_rem: ql_rem as i32,
        kl_rem: kl_rem as i32,
        ql_off: params.q_offset_in_k as i32,
        _pad: 0,
        q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
        k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
        v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
        o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
    };

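    // One threadgroup per (Q tile, head, batch) triple; at the D=256 shape
    // the threadgroup is 32 × WM × WN = 32 × 4 × 1 = 128 threads.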
    let grid = MTLSize::new(nq as u64, params.n_heads as u64, params.batch as u64);
    let tg_size = MTLSize::new(32, wm as u64, wn as u64);

    encoder.set_op_kind(CapturedOpKind::Sdpa);
    encoder.encode_threadgroups_with_args(
        pipeline,
        &[
            (0, KernelArg::Buffer(q)),
            (1, KernelArg::Buffer(k)),
            (2, KernelArg::Buffer(v)),
            (3, KernelArg::Buffer(out)),
            (4, KernelArg::Bytes(as_bytes(&attn_params))),
        ],
        grid,
        tg_size,
    );

    Ok(())
}

// ─── bf16 D=64 dispatcher ────────────────────────────────────────────────────

/// Layout selector for [`dispatch_flash_attn_prefill_bf16_d64`].
///
/// The kernel reads from raw device pointers via integer strides, so any
/// element layout that keeps `head_dim` (`D`) as the contiguous innermost
/// axis is valid input.  The two layouts named here both satisfy that
/// constraint and cover every BERT/embedding caller in hf2q today:
///
/// * `HeadMajor` — `[B, H, L, D]`, the same layout the D=256/D=512
///   dispatchers assume.  Stride math:
///   `seq = D`, `head = L * D`, `batch = H * L * D`.
///
/// * `SeqMajor` — `[B, L, H, D]`, the natural output of BERT linear
///   projections (`hidden = H * D` row-major).  Stride math:
///   `seq = H * D`, `head = D`, `batch = L * H * D`.
///   Choosing this layout avoids three host-side transpose dispatches per
///   layer (Q + K + V) plus one for the output, which is the entire point
///   of the D=64 dispatcher's existence — the BERT family wins on dispatch
///   count, not raw FA perf.
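///
/// A worked stride example (arithmetic only, not a doctest; the numbers
/// assume `H = 12`, `L = 128`, `D = 64`):
///
/// ```text
/// HeadMajor [B, H, L, D]: batch = 12*128*64 = 98304, head = 128*64 = 8192, seq = 64
/// SeqMajor  [B, L, H, D]: batch = 128*12*64 = 98304, head = 64,            seq = 12*64 = 768
/// ```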
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlashAttnPrefillLayout {
    /// `[B, H, L, D]` — same as the D=256/D=512 dispatchers.
    HeadMajor,
    /// `[B, L, H, D]` — natural BERT/embedding-model layout.
    SeqMajor,
}

impl FlashAttnPrefillLayout {
    /// Compute `[batch_stride, head_stride, seq_stride]` for this layout
    /// from the head count, sequence length, and head dim.  All strides are
    /// in elements (not bytes); the caller multiplies by `dtype.size_of()`
    /// if needed.
    fn strides(self, n_heads: u32, seq_len: u32, head_dim: u32) -> [i64; 3] {
        let h = n_heads as i64;
        let l = seq_len as i64;
        let d = head_dim as i64;
        match self {
            // `[B, H, L, D]`
            //   seq stride  = D
            //   head stride = L*D
            //   batch stride = H*L*D
            FlashAttnPrefillLayout::HeadMajor => [h * l * d, l * d, d],
            // `[B, L, H, D]`
            //   seq stride  = H*D
            //   head stride = D
            //   batch stride = L*H*D
            FlashAttnPrefillLayout::SeqMajor => [l * h * d, d, h * d],
        }
    }
}

/// Dispatch flash-attention prefill for bf16 Q/K/V/O, head_dim=64.
///
/// Encodes a compute command into `encoder` without committing.  The caller
/// controls when to call `encoder.commit_and_wait()`.
///
/// Designed for the BERT/embedding family (nomic-bert, bge, mxbai, MiniLM,
/// …) where `head_dim` is 64 and the natural layout coming out of the
/// linear projections is **seq-major** `[B, L, H, D]` — each sequence
/// position owns one contiguous `H * D` hidden-dimension row, rather than
/// the per-head `[H, L, D]` layout of decoder-style models.  Pass
/// `layout = SeqMajor` to consume that layout directly without three
/// host-side transpose dispatches per layer.  Pass `layout = HeadMajor`
/// for the same `[B, H, L, D]` contract that the D=256/D=512 dispatchers
/// obey (e.g. unit tests, future decoder models that happen to land on
/// D=64).
///
/// # Buffer layouts
///
/// All buffers must be contiguous along the innermost `D=64` axis.
///
/// HeadMajor (`layout = HeadMajor`):
/// - `q`    — `[batch, n_heads,    seq_len_q, 64]`, dtype BF16
/// - `k`    — `[batch, n_kv_heads, seq_len_k, 64]`, dtype BF16
/// - `v`    — `[batch, n_kv_heads, seq_len_k, 64]`, dtype BF16
/// - `out`  — `[batch, n_heads,    seq_len_q, 64]`, dtype BF16
///
/// SeqMajor (`layout = SeqMajor`):
/// - `q`    — `[batch, seq_len_q, n_heads,    64]`, dtype BF16
/// - `k`    — `[batch, seq_len_k, n_kv_heads, 64]`, dtype BF16
/// - `v`    — `[batch, seq_len_k, n_kv_heads, 64]`, dtype BF16
/// - `out`  — `[batch, seq_len_q, n_heads,    64]`, dtype BF16
///
/// `mask` may be either rank-2 `[seq_len_q, seq_len_k]` (broadcast across
/// batch+heads — the BERT padding-mask shape) or rank-4
/// `[batch, n_heads, seq_len_q, seq_len_k]` (per-head).  Both use the
/// llama.cpp additive convention: 0.0 = attend, -inf = mask out.
///
/// # Function constants
///
/// Same as the D=256 dispatcher (indices 200, 201, 300, 301, 303).  The
/// `has_blk` Wave-2E byte-buffer constant is forced to `false` here: the
/// Wave-2E tile-skip pre-pass kernel is currently instantiated only for
/// the D=256 BQ/BK tile shape, and BERT models do not stack a tile-skip
/// pass on top of attention, so this is not a regression.
///
/// # Errors
///
/// Same error contract as `dispatch_flash_attn_prefill_bf16_d256`, plus an
/// `InvalidArgument` error when `head_dim != 64`.
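///
/// # Example
///
/// A minimal call sketch (an `ignore`d doctest: it assumes an already
/// initialised `device`, `registry`, and `encoder`, plus BF16 buffers
/// `q`/`k`/`v`/`out`/`padding_mask` that are not constructed here):
///
/// ```ignore
/// let params = FlashAttnPrefillParams {
///     n_heads: 12,
///     n_kv_heads: 12, // BERT: no GQA
///     head_dim: 64,
///     seq_len_q: 128,
///     seq_len_k: 128,
///     batch: 1,
///     scale: 1.0 / 64.0_f32.sqrt(),
///     do_causal: false, // encoders attend bidirectionally
/// };
/// dispatch_flash_attn_prefill_bf16_d64(
///     &mut encoder,
///     &device,
///     &mut registry,
///     &q, &k, &v,
///     Some(&padding_mask), // rank-2 [seq_len_q, seq_len_k], BF16 additive
///     &out,
///     &params,
///     FlashAttnPrefillLayout::SeqMajor,
/// )?;
/// encoder.commit_and_wait();
/// ```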
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_prefill_bf16_d64(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    out: &MlxBuffer,
    params: &FlashAttnPrefillParams,
    layout: FlashAttnPrefillLayout,
) -> Result<()> {
    // ── Validate ──────────────────────────────────────────────────────────
    if params.head_dim != 64 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d64: head_dim must be 64, got {}",
            params.head_dim
        )));
    }
    validate_params(params)?;

    // All buffers must be BF16 for this dispatcher.
    for (buf, name) in &[(q, "Q"), (k, "K"), (v, "V"), (out as &MlxBuffer, "out")] {
        if buf.dtype() != DType::BF16 {
            return Err(MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_bf16_d64: {name} buffer must be BF16, got {:?}",
                buf.dtype()
            )));
        }
    }
    if let Some(m) = mask {
        if m.dtype() != DType::BF16 {
            return Err(MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_bf16_d64: mask buffer must be BF16, got {:?}",
                m.dtype()
            )));
        }
    }

    let batch = params.batch as usize;
    let h = params.n_heads as usize;
    let h_kv = params.n_kv_heads as usize;
    let ql = params.seq_len_q as usize;
    let kl = params.seq_len_k as usize;
    let d = params.head_dim as usize; // = 64

    // Validate buffer element counts (layout-independent: total elements
    // are `B * H * L * D` either way).
    validate_buffer_size(q, "Q", batch * h * ql * d)?;
    validate_buffer_size(k, "K", batch * h_kv * kl * d)?;
    validate_buffer_size(v, "V", batch * h_kv * kl * d)?;
    validate_buffer_size(out, "out", batch * h * ql * d)?;

    let mask_is_rank2_broadcast = mask.is_some_and(|m| m.shape().len() == 2);
    if let Some(m) = mask {
        if mask_is_rank2_broadcast {
            validate_buffer_size(m, "mask", ql * kl)?;
        } else {
            validate_buffer_size(m, "mask", batch * h * ql * kl)?;
        }
    }

    // ── Tile geometry ─────────────────────────────────────────────────────
    let bq = BQ_D64;
    let bk = BK_D64;
    let wm = WM_D64;
    let wn = WN_D64;

    let nq = params.seq_len_q.div_ceil(bq);
    let nk = params.seq_len_k.div_ceil(bk);
    let nq_aligned = params.seq_len_q / bq;
    let nk_aligned = params.seq_len_k / bk;
    let ql_rem = params.seq_len_q % bq;
    let kl_rem = params.seq_len_k % bk;

    // Function constants (specialised at pipeline creation time).
    let align_q = ql_rem == 0;
    let align_k = kl_rem == 0;
    let has_mask = mask.is_some();
    let has_blk = false;
    let do_causal = params.do_causal;

    // ── Kernel name ───────────────────────────────────────────────────────
    let kernel_name = K_BF16_D64;

    // ── Pipeline lookup (with function constants) ─────────────────────────
    let pipeline = registry.get_pipeline_with_bool_constants(
        kernel_name,
        device.metal_device(),
        &[
            (200, align_q),
            (201, align_k),
            (300, has_mask),
            (301, do_causal),
            (303, has_blk),
        ],
    )?;

    // ── Build AttnParams GPU struct (strides depend on layout) ────────────
    let q_strides = layout.strides(params.n_heads, params.seq_len_q, params.head_dim);
    let kv_strides = layout.strides(params.n_kv_heads, params.seq_len_k, params.head_dim);
    let o_strides = layout.strides(params.n_heads, params.seq_len_q, params.head_dim);

    let gqa_factor = (params.n_heads / params.n_kv_heads) as i32;

    let attn_params = AttnParamsGpu {
        b: params.batch as i32,
        h: params.n_heads as i32,
        d: params.head_dim as i32,
        ql: params.seq_len_q as i32,
        kl: params.seq_len_k as i32,
        gqa_factor,
        scale: params.scale,
        softcapping: 1.0_f32,
        nq: nq as i32,
        nk: nk as i32,
        nq_aligned: nq_aligned as i32,
        nk_aligned: nk_aligned as i32,
        ql_rem: ql_rem as i32,
        kl_rem: kl_rem as i32,
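        // ql_off = 0: the D=64 path has no chunked-prefill resume offset,
        // so Q row 0 corresponds to K row 0.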
        ql_off: 0,
        _pad: 0,
        q_strides,
        k_strides: kv_strides,
        v_strides: kv_strides,
        o_strides,
    };

    // ── Grid geometry ──────────────────────────────────────────────────────
    let grid = MTLSize::new(nq as u64, params.n_heads as u64, params.batch as u64);
    let tg_size = MTLSize::new(32, wm as u64, wn as u64);

    // ── Encode ─────────────────────────────────────────────────────────────
    encoder.set_op_kind(CapturedOpKind::Sdpa);

    if has_mask {
        let mask_buf = mask.ok_or_else(|| {
            MlxError::InvalidArgument(
                "flash_attn_prefill_d64: internal error — has_mask=true but mask is None".into(),
            )
        })?;

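        // Rank-2 masks broadcast across batch and heads: zero batch/head
        // strides make every (b, h) pair read the same [qL, kL] plate.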
        let (m_batch_stride, m_head_stride, m_ql_stride) = if mask_is_rank2_broadcast {
            (0_i64, 0_i64, kl as i64)
        } else {
            ((h * ql * kl) as i64, (ql * kl) as i64, kl as i64)
        };

        let mask_params = AttnMaskParamsGpu {
            m_strides: [m_batch_stride, m_head_stride, m_ql_stride],
        };

        encoder.encode_threadgroups_with_args(
            pipeline,
            &[
                (0, KernelArg::Buffer(q)),
                (1, KernelArg::Buffer(k)),
                (2, KernelArg::Buffer(v)),
                (3, KernelArg::Buffer(out)),
                (4, KernelArg::Bytes(as_bytes(&attn_params))),
                (5, KernelArg::Bytes(as_bytes(&mask_params))),
                (6, KernelArg::Buffer(mask_buf)),
            ],
            grid,
            tg_size,
        );
    } else {
        encoder.encode_threadgroups_with_args(
            pipeline,
            &[
                (0, KernelArg::Buffer(q)),
                (1, KernelArg::Buffer(k)),
                (2, KernelArg::Buffer(v)),
                (3, KernelArg::Buffer(out)),
                (4, KernelArg::Bytes(as_bytes(&attn_params))),
            ],
            grid,
            tg_size,
        );
    }

    Ok(())
}

// ─── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_attn_params_gpu_size() {
        // Verify the size of AttnParamsGpu matches the MSL struct layout.
        // B, H, D, qL, kL, gqa_factor = 6 × i32 = 24
        // scale, softcapping = 2 × f32 = 8
        // NQ, NK, NQ_aligned, NK_aligned, qL_rem, kL_rem, qL_off = 7 × i32 = 28
        // _pad = 1 × i32 = 4
        // Q_strides, K_strides, V_strides, O_strides = 4 × 3 × i64 = 96
        // Total = 24 + 8 + 28 + 4 + 96 = 160
        assert_eq!(std::mem::size_of::<AttnParamsGpu>(), 160);
    }

    #[test]
    fn test_attn_mask_params_gpu_size() {
        // 3 × i64 = 24 bytes
        assert_eq!(std::mem::size_of::<AttnMaskParamsGpu>(), 24);
    }

    #[test]
    fn test_validate_params_ok() {
        let p = FlashAttnPrefillParams {
            n_heads: 16,
            n_kv_heads: 8,
            head_dim: 256,
            seq_len_q: 2048,
            seq_len_k: 2048,
            batch: 1,
            scale: 1.0 / 256.0_f32.sqrt(),
            do_causal: true,
        };
        assert!(validate_params(&p).is_ok());
    }

    #[test]
    fn test_validate_params_zero_heads() {
        let p = FlashAttnPrefillParams {
            n_heads: 0,
            n_kv_heads: 8,
            head_dim: 256,
            seq_len_q: 128,
            seq_len_k: 128,
            batch: 1,
            scale: 1.0,
            do_causal: false,
        };
        assert!(matches!(
            validate_params(&p),
            Err(MlxError::InvalidArgument(_))
        ));
    }

    #[test]
    fn test_validate_params_bad_gqa_ratio() {
        let p = FlashAttnPrefillParams {
            n_heads: 16,
            n_kv_heads: 7,
            head_dim: 256,
            seq_len_q: 128,
            seq_len_k: 128,
            batch: 1,
            scale: 1.0,
            do_causal: false,
        };
        assert!(matches!(
            validate_params(&p),
            Err(MlxError::InvalidArgument(_))
        ));
    }

    #[test]
    fn test_wrong_head_dim_rejected() {
        // dispatch_flash_attn_prefill_bf16_d256 must reject head_dim != 256;
        // the guard runs before any device access.  Invoking the dispatcher
        // still requires a live device/encoder/registry, so this unit test
        // cannot exercise the guard directly — it pins the non-256 fixture
        // so the rejection intent survives refactors.
        let p = FlashAttnPrefillParams {
            n_heads: 16,
            n_kv_heads: 8,
            head_dim: 128, // deliberately wrong
            seq_len_q: 64,
            seq_len_k: 64,
            batch: 1,
            scale: 1.0,
            do_causal: false,
        };
        assert!(p.head_dim != 256, "test pre-condition: head_dim must not be 256");
    }

    #[test]
    fn test_all_expected_kernel_names_registered() {
        // Name-pinned, not count-pinned: asserts each expected entry point is
        // present AND that no unexpected entry points have been added.  When a
        // new dispatcher lands (e.g. a future D=128 instantiation), update this
        // EXPECTED list explicitly — the test will tell you whether you are
        // adding (missing entry) or removing (unexpected entry).
        //
        // History: previously hard-coded at 8 entries (D=256 + D=512 ×
        // bf16/f16 × additive/boolmask).  Commit 7e35d74 added the D=64
        // BERT-family quartet (bf16/f16 × additive/boolmask), bringing the
        // total to 12.  The count-pinned shape rotted on that landing; the
        // name-pinned shape below is robust to future intentional additions
        // while still failing loudly on accidental ones.
        const EXPECTED: &[&str] = &[
            // D=256 — general-purpose decoder LLMs (Llama/Qwen-class)
            "flash_attn_prefill_bf16_d256",
            "flash_attn_prefill_bf16_d256_boolmask",
            "flash_attn_prefill_f16_d256",
            "flash_attn_prefill_f16_d256_boolmask",
            // D=512 — wide-head models
            "flash_attn_prefill_bf16_d512",
            "flash_attn_prefill_bf16_d512_boolmask",
            "flash_attn_prefill_f16_d512",
            "flash_attn_prefill_f16_d512_boolmask",
            // D=64 — BERT-family encoders (nomic-bert, bge, mxbai, MiniLM); 7e35d74
            "flash_attn_prefill_bf16_d64",
            "flash_attn_prefill_bf16_d64_boolmask",
            "flash_attn_prefill_f16_d64",
            "flash_attn_prefill_f16_d64_boolmask",
        ];

        let registered: std::collections::HashSet<&str> =
            ALL_KERNEL_NAMES.iter().copied().collect();
        let expected: std::collections::HashSet<&str> = EXPECTED.iter().copied().collect();

        // Each constant in the registry must be non-empty and unique.
        assert_eq!(
            registered.len(),
            ALL_KERNEL_NAMES.len(),
            "ALL_KERNEL_NAMES contains duplicate entries"
        );
        for &name in ALL_KERNEL_NAMES {
            assert!(!name.is_empty(), "kernel name must not be empty");
        }

        // Every expected name must be present (catches accidental removals).
        let missing: Vec<&str> = expected.difference(&registered).copied().collect();
        assert!(
            missing.is_empty(),
            "expected kernel names missing from ALL_KERNEL_NAMES: {missing:?}"
        );

        // No unexpected names may be present (catches accidental additions —
        // forces this EXPECTED list to be updated alongside any new dispatcher).
        let extra: Vec<&str> = registered.difference(&expected).copied().collect();
        assert!(
            extra.is_empty(),
            "unexpected kernel names registered (update EXPECTED in this test): {extra:?}"
        );

        // Verify no f32 entry points are registered — f32 is excluded by
        // Apple Silicon threadgroup memory limits (see module doc).
        for &name in ALL_KERNEL_NAMES {
            assert!(
                !name.contains("float32"),
                "f32 kernel {name} must not be registered — exceeds 32 KB TG mem limit"
            );
            assert!(
                !name.contains("_f32_"),
                "f32 kernel {name} must not be registered — exceeds 32 KB TG mem limit"
            );
        }
    }

    #[test]
    fn test_tile_geometry_d256() {
        // D=256 tile geometry as defined in flash_attn_prefill.metal.
        assert_eq!(BQ_D256, 32, "BQ=32 for D=256");
        assert_eq!(BK_D256, 16, "BK=16 for D=256");
        assert_eq!(WM_D256, 4,  "WM=4  for D=256");
        assert_eq!(WN_D256, 1,  "WN=1  for D=256");
        // Threadgroup size: 32 × WM × WN = 32 × 4 × 1 = 128 threads.
        assert_eq!(32 * WM_D256 * WN_D256, 128);
    }
}