Skip to main content

mlx_native/ops/
flash_attn_prefill.rs

1//! Flash-attention-style tiled prefill kernel — host dispatch.
2//!
3//! mlx-native's batched-prefill SDPA kernel, the prefill counterpart to
4//! `flash_attn_vec` (which handles the seq_len=1 decode case).  Implements
5//! online softmax + simdgroup MMA tiling on Apple GPU.
6//!
7//! ## Kernel variants registered
8//!
9//! Twelve entry points are registered, all backed by the single
10//! `flash_attn_prefill.metal` shader source:
11//!
12//! ### D=64 (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup, BERT family)
13//!
14//! | Kernel name | I/O dtype | Mask kind |
15//! |---|---|---|
16//! | `flash_attn_prefill_bf16_d64`            | bf16 | bf16 additive |
17//! | `flash_attn_prefill_bf16_d64_boolmask`   | bf16 | bool |
18//! | `flash_attn_prefill_f16_d64`             | f16  | f16 additive |
19//! | `flash_attn_prefill_f16_d64_boolmask`    | f16  | bool |
20//!
21//! ### D=256 (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup)
22//!
23//! | Kernel name | I/O dtype | Mask kind |
24//! |---|---|---|
25//! | `flash_attn_prefill_bf16_d256`           | bf16 | bf16 additive (log-domain) |
26//! | `flash_attn_prefill_bf16_d256_boolmask`  | bf16 | bool (`is_attended`) |
27//! | `flash_attn_prefill_f16_d256`            | f16  | f16 additive |
28//! | `flash_attn_prefill_f16_d256_boolmask`   | f16  | bool |
29//!
30//! ### D=512 (BQ=8, BK=8, WM=1, WN=1 — 32 threads/threadgroup, 1 simdgroup)
31//!
32//! | Kernel name | I/O dtype | Mask kind |
33//! |---|---|---|
34//! | `flash_attn_prefill_bf16_d512`           | bf16 | bf16 additive |
35//! | `flash_attn_prefill_bf16_d512_boolmask`  | bf16 | bool |
36//! | `flash_attn_prefill_f16_d512`            | f16  | f16 additive |
37//! | `flash_attn_prefill_f16_d512_boolmask`   | f16  | bool |
38//!
39//! ### f32 is NOT instantiated at any D
40//!
41//! The f32 Qs threadgroup tile alone is `BQ * BD * 4` bytes — at D=256 this
42//! is 32 KB exactly, the Apple Silicon `MTLDevice.maxThreadgroupMemoryLength`
43//! hard limit, before KV_smem or scratch.  Verified empirically on M5 Max:
44//! f32 D=256 requires ~53.7 KB and fails at library compile.  bf16 and f16
45//! halve the tile footprint (~29 KB) and fit within the limit.  D=512 f32
46//! is excluded for the same reason.  D=64 inherits the same exclusion for
47//! consistency (bf16/f16 only across the entire dispatcher family).  f32
48//! correctness is verified at the CPU reference layer in
49//! `tests/test_flash_attn_prefill.rs`.  See
50//! `ADR-011-phase1-port-source-decision.md` §3 for the full
51//! threadgroup-memory analysis.
52//!
53//! ## Function constants
54//!
55//! The kernel declares four Metal function constants that must be specialised
56//! at pipeline creation time (not at dispatch time):
57//!
58//! - Index 200: `align_Q` (bool) — true when `qL % BQ == 0`
59//! - Index 201: `align_K` (bool) — true when `kL % BK == 0`
60//! - Index 300: `has_mask` (bool) — true when a mask buffer is bound
61//! - Index 301: `do_causal` (bool) — true for in-kernel causal masking
62//!
63//! These are plumbed via [`KernelRegistry::get_pipeline_with_bool_constants`],
64//! which caches compiled pipelines keyed by `(kernel_name, align_Q, align_K,
65//! has_mask, do_causal)`.  Pipeline compilation is amortised: the slow path
66//! runs only once per unique `(name, booleans)` combination.
67//!
68//! ## Buffer layout (indices match the MSL kernel)
69//!
70//! - `buffer(0)` — Q `[B, H,    qL, D]`  device, contiguous inner dim
71//! - `buffer(1)` — K `[B, H_kv, kL, D]`  device, contiguous inner dim
72//! - `buffer(2)` — V `[B, H_kv, kL, D]`  device, contiguous inner dim
73//! - `buffer(3)` — O `[B, H,    qL, D]`  device, written by kernel
74//! - `buffer(4)` — `AttnParams` constant buffer (this module's [`AttnParamsGpu`])
75//! - `buffer(5)` — `AttnMaskParams` constant buffer (only when `has_mask=true`)
76//! - `buffer(6)` — mask data buffer (only when `has_mask=true`)
77//!
78//! ## Grid geometry
79//!
80//! - Threadgroups: `(ceil(qL / BQ), H, B)`
81//! - Threads per threadgroup: `(32, WM, WN)`
82//! - D=256: 128 threads (4 simdgroups × 32 lanes).
83//! - D=512:  32 threads (1 simdgroup  × 32 lanes).
84//!
85//! ## Scale convention
86//!
87//! Pass `scale = 1.0 / sqrt(head_dim)`.  The kernel multiplies internally by
88//! `log2(e) ≈ 1.44269504` and uses `fast::exp2` throughout — so the host
89//! MUST NOT pre-multiply by `log2(e)`.
90//!
91//! ## Mask-sentinel contract (llama.cpp convention)
92//!
93//! The additive mask buffer (bf16/f16 for the additive dispatchers) uses
94//! the llama.cpp CPU-side convention: **masked positions = `-INFINITY`**
95//! (IEEE-754 f32 `0xFF800000`, cast to the I/O dtype — both `half(-inf)`
96//! and `bfloat16(-inf)` have a real `-inf` encoding that the kernel
97//! consumes correctly).  Attended positions = `0.0`.
98//!
99//! This matches llama.cpp's mask-authoring sites at
100//! `llama-graph.cpp:421, 436, 557` and `llama-kv-cache.cpp:1572`, where
101//! the CPU writes raw `-INFINITY` and the flash-attn cast to f16 saturates
102//! to f16-`-INFINITY`.
103//!
104//! The kernel does NOT require masks to use `-FLT_MAX/2` or any other
105//! finite "large negative" sentinel.  The `-FLT_MAX/2` value is the
106//! kernel-internal `M` (running row-max) initialiser — a finite sentinel
107//! that absorbs `-inf` scores via `simd_max` without ever letting `M`
108//! become `-inf`, which in turn lets every `exp(score - M)` evaluate as
109//! `exp(-inf) = 0.0` (IEEE-754 exact) rather than `exp(-inf - -inf) =
110//! exp(NaN) = NaN`.  See ADR-011-phase2-port-sentinel.md §1.
111//!
112//! Callers may pass arbitrary finite additive biases in the mask for
113//! non-masked positions (e.g. ALiBi, relative-position biases); only
114//! "fully block this K position" requires `-inf`.
115//!
116//! ## See also
117//!
118//! - Kernel: `/opt/mlx-native/src/shaders/flash_attn_prefill.metal`
119//! - ADR-011: `/opt/hf2q/docs/ADR-011-flash-attn-prefill.md`
120//! - ADR-011 phase 2 sentinel port: `/opt/hf2q/docs/ADR-011-phase2-port-sentinel.md`
121
122use metal::MTLSize;
123
124use crate::buffer::MlxBuffer;
125use crate::device::MlxDevice;
126use crate::encoder::{CapturedOpKind, CommandEncoder, KernelArg, as_bytes};
127use crate::error::{MlxError, Result};
128use crate::kernel_registry::KernelRegistry;
129use crate::DType;
130
131// ─── Shader source ───────────────────────────────────────────────────────────
132
133/// MSL source for the flash-attention prefill kernel (embedded at compile time).
134pub static FLASH_ATTN_PREFILL_SHADER_SOURCE: &str =
135    include_str!("../shaders/flash_attn_prefill.metal");
136
137// ─── All 12 kernel entry-point names ─────────────────────────────────────────
138
139/// D=256, bf16 I/O, bf16 additive mask.
140const K_BF16_D256: &str = "flash_attn_prefill_bf16_d256";
141/// D=256, bf16 I/O, bool (`is_attended`) mask.
142const K_BF16_D256_BOOLMASK: &str = "flash_attn_prefill_bf16_d256_boolmask";
143/// D=256, f16 I/O, f16 additive mask.
144const K_F16_D256: &str = "flash_attn_prefill_f16_d256";
145/// D=256, f16 I/O, bool mask.
146const K_F16_D256_BOOLMASK: &str = "flash_attn_prefill_f16_d256_boolmask";
147/// D=512, bf16 I/O, bf16 additive mask.
148const K_BF16_D512: &str = "flash_attn_prefill_bf16_d512";
149/// D=512, bf16 I/O, bool mask.
150const K_BF16_D512_BOOLMASK: &str = "flash_attn_prefill_bf16_d512_boolmask";
151/// D=512, f16 I/O, f16 additive mask.
152const K_F16_D512: &str = "flash_attn_prefill_f16_d512";
153/// D=512, f16 I/O, bool mask.
154const K_F16_D512_BOOLMASK: &str = "flash_attn_prefill_f16_d512_boolmask";
155/// D=64, bf16 I/O, bf16 additive mask (BERT family).
156const K_BF16_D64: &str = "flash_attn_prefill_bf16_d64";
157/// D=64, bf16 I/O, bool (`is_attended`) mask.
158const K_BF16_D64_BOOLMASK: &str = "flash_attn_prefill_bf16_d64_boolmask";
159/// D=64, f16 I/O, f16 additive mask.
160const K_F16_D64: &str = "flash_attn_prefill_f16_d64";
161/// D=64, f16 I/O, bool mask.
162const K_F16_D64_BOOLMASK: &str = "flash_attn_prefill_f16_d64_boolmask";
163
164/// All 12 kernel entry-point names exported by `flash_attn_prefill.metal`.
165///
166/// Registering all of them against the single shader source costs nothing at
167/// registration time (source is a static `&str`) and ensures additional
168/// dispatchers (f16 paths, D=512 exposure) can be added later without
169/// touching registration here.
170const ALL_KERNEL_NAMES: &[&str] = &[
171    K_BF16_D256,
172    K_BF16_D256_BOOLMASK,
173    K_F16_D256,
174    K_F16_D256_BOOLMASK,
175    K_BF16_D512,
176    K_BF16_D512_BOOLMASK,
177    K_F16_D512,
178    K_F16_D512_BOOLMASK,
179    K_BF16_D64,
180    K_BF16_D64_BOOLMASK,
181    K_F16_D64,
182    K_F16_D64_BOOLMASK,
183];
184
185// ─── Registration ─────────────────────────────────────────────────────────────
186
187/// Register all flash-attention prefill kernel entry points with the registry.
188///
189/// Maps all 12 entry-point names to the single `flash_attn_prefill.metal`
190/// source.  This must be called before any dispatch to these kernels.
191///
192/// # Design note
193///
194/// `KernelRegistry` compiles one Metal library per kernel name.  All 12 names
195/// point at the same source text, so the Metal compiler sees the same ~1 500-line
196/// source each time — compilation is amortised in `KernelRegistry::get_pipeline`
197/// (first call per name triggers compilation; subsequent calls return the cached
198/// pipeline).  Registering all 12 here rather than only the Phase 1a subset
199/// means Phase 2/4 dispatcher functions can be added without touching this file.
200pub fn register(registry: &mut KernelRegistry) {
201    for &name in ALL_KERNEL_NAMES {
202        registry.register_source(name, FLASH_ATTN_PREFILL_SHADER_SOURCE);
203    }
204}
205
206// ─── MSL struct mirrors ───────────────────────────────────────────────────────
207
208/// Rust mirror of the MSL `AttnParams` struct.
209///
210/// Field order and types match the MSL definition exactly.
211/// MSL source: `flash_attn_prefill.metal` — see the `AttnParams` struct
212/// definition in the kernel source for the field-by-field reference.
213///
214/// # Layout
215///
216/// All `int` fields are 32-bit (i32 in Rust).  The `int64_t` stride arrays are
217/// 64-bit (i64 in Rust).  The compiler inserts natural alignment padding:
218///
219/// ```text
220/// Offset  0:  B            (i32,  4 bytes)
221/// Offset  4:  H            (i32,  4 bytes)
222/// Offset  8:  D            (i32,  4 bytes)
223/// Offset 12:  qL           (i32,  4 bytes)
224/// Offset 16:  kL           (i32,  4 bytes)
225/// Offset 20:  gqa_factor   (i32,  4 bytes)
226/// Offset 24:  scale        (f32,  4 bytes)
227/// Offset 28:  softcapping  (f32,  4 bytes)
228/// Offset 32:  NQ           (i32,  4 bytes)
229/// Offset 36:  NK           (i32,  4 bytes)
230/// Offset 40:  NQ_aligned   (i32,  4 bytes)
231/// Offset 44:  NK_aligned   (i32,  4 bytes)
232/// Offset 48:  qL_rem       (i32,  4 bytes)
233/// Offset 52:  kL_rem       (i32,  4 bytes)
234/// Offset 56:  qL_off       (i32,  4 bytes)
235/// Offset 60:  _pad         (4 bytes — alignment before i64 array)
236/// Offset 64:  Q_strides[3] (3 × i64, 24 bytes)
237/// Offset 88:  K_strides[3] (3 × i64, 24 bytes)
238/// Offset 112: V_strides[3] (3 × i64, 24 bytes)
239/// Offset 136: O_strides[3] (3 × i64, 24 bytes)
240/// Total: 160 bytes
241/// ```
242///
243/// `bytemuck::Pod` / `bytemuck::Zeroable` are derived — the struct must have
244/// no uninitialized padding bytes.  The explicit `_pad` field makes the padding
245/// concrete so Pod can be derived safely.
246///
247/// `softcapping` is always set to `1.0` (disabled).  The `attention<>` kernel
248/// body does not read it; the field exists for ABI parity with attention
249/// implementations that thread softcapping through the same param block
250/// (e.g. Gemma-style logit softcap), so we don't have to redo the layout
251/// when that work lands.
252#[repr(C)]
253#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
254pub struct AttnParamsGpu {
255    /// Batch size.
256    pub b: i32,
257    /// Number of query heads.
258    pub h: i32,
259    /// Head dimension (D).
260    pub d: i32,
261    /// Query sequence length.
262    pub ql: i32,
263    /// Key/value sequence length.
264    pub kl: i32,
265    /// Group-query attention factor: H / H_kv.
266    pub gqa_factor: i32,
267    /// Attention scale (= 1.0 / sqrt(head_dim); kernel multiples by log2(e)).
268    pub scale: f32,
269    /// Softcapping value — always 1.0 (disabled) for standard SDPA.
270    pub softcapping: f32,
271    /// Number of Q tiles: ceil(qL / BQ).
272    pub nq: i32,
273    /// Number of KV tiles: ceil(kL / BK).
274    pub nk: i32,
275    /// Number of full (aligned) Q tiles: qL / BQ.
276    pub nq_aligned: i32,
277    /// Number of full (aligned) KV tiles: kL / BK.
278    pub nk_aligned: i32,
279    /// Remainder elements in the last Q tile: qL % BQ (0 if aligned).
280    pub ql_rem: i32,
281    /// Remainder elements in the last KV tile: kL % BK (0 if aligned).
282    pub kl_rem: i32,
283    /// Query sequence start offset (0 for standard prefill).
284    pub ql_off: i32,
285    /// Explicit padding to align the subsequent i64 arrays to 8-byte boundary.
286    pub _pad: i32,
287    /// Query strides: (batch stride, head stride, seq stride).  Inner dim = 1.
288    pub q_strides: [i64; 3],
289    /// Key strides: (batch stride, head stride, seq stride).  Inner dim = 1.
290    pub k_strides: [i64; 3],
291    /// Value strides: (batch stride, head stride, seq stride).  Inner dim = 1.
292    pub v_strides: [i64; 3],
293    /// Output strides: (batch stride, head stride, seq stride).  Inner dim = 1.
294    pub o_strides: [i64; 3],
295}
296
297/// Rust mirror of the MSL `AttnMaskParams` struct.
298///
299/// MSL source: `flash_attn_prefill.metal` — see the `AttnMaskParams` struct
300/// definition in the kernel source.
301///
302/// Contains the mask buffer strides.  Only sent to the kernel when
303/// `has_mask = true` (buffer index 5).
304#[repr(C)]
305#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
306pub struct AttnMaskParamsGpu {
307    /// Mask strides: (batch stride, head stride, qL stride).  Inner dim = 1.
308    pub m_strides: [i64; 3],
309}
310
311// ─── Public Rust-side parameter struct ───────────────────────────────────────
312
313/// Host-side parameters for the flash-attention prefill dispatcher.
314///
315/// Used only by the Rust dispatcher; does not map 1:1 to the GPU struct —
316/// the GPU struct is computed from these fields inside the dispatcher.
317#[derive(Debug, Clone, Copy)]
318pub struct FlashAttnPrefillParams {
319    /// Number of query attention heads.
320    pub n_heads: u32,
321    /// Number of key/value attention heads (GQA: may be < n_heads).
322    pub n_kv_heads: u32,
323    /// Head dimension.  Must be 256 for `dispatch_flash_attn_prefill_bf16_d256`.
324    pub head_dim: u32,
325    /// Query sequence length.
326    pub seq_len_q: u32,
327    /// Key/value sequence length.
328    pub seq_len_k: u32,
329    /// Batch size.
330    pub batch: u32,
331    /// Attention scale.  Typically `1.0 / sqrt(head_dim)`.
332    ///
333    /// The kernel internally multiplies this by `log2(e) = 1.44269504089`
334    /// before applying it to Q.  The host MUST NOT pre-multiply by log2(e).
335    pub scale: f32,
336    /// Whether to apply in-kernel causal masking (`do_causal` function constant).
337    ///
338    /// When true, positions where `row_pos < col_pos` receive a score of -inf
339    /// before softmax.  This can be combined with an external mask buffer.
340    pub do_causal: bool,
341}
342
343// ─── Tile geometry constants (D=256) ─────────────────────────────────────────
344
345/// Q tile size for D=256: BQ=32.
346const BQ_D256: u32 = 32;
347
348/// KV tile size for D=256: BK=16.
349const BK_D256: u32 = 16;
350
351/// Simdgroups along Q dimension for D=256: WM=4.
352const WM_D256: u32 = 4;
353
354/// Simdgroups along K dimension for D=256: WN=1.
355const WN_D256: u32 = 1;
356
357// ─── Tile geometry constants (D=64) ──────────────────────────────────────────
358//
359// Same simdgroup geometry as D=256: 4 simdgroups along Q, 1 along K, with
360// BQ=32 / BK=16.  Threadgroup-memory math at bf16:
361//   Qs  = BQ × (BD + padQ) × 2   = 32 × (64 + 8) × 2  = 4 608 bytes
362//   KVs = max(BK*(BD+padV), (BK+padK)*BD) × 2
363//                                 = max(16×72, 24×64) × 2 = max(1152, 1536) × 2
364//                                 = 3 072 bytes
365//   total ≈ 7 680 bytes — fits well under the 32 KiB Apple Silicon TG limit.
366//
367// Static-asserts required by the kernel (with kFragSize=8):
368//   BQ % (kNWarps × kFragSize) = 32 % (4×8) = 0 ✓
369//   BQ ≥ kNWarps × kFragSize   = 32 ≥ 32     ✓
370//   TQ = BQ / (kNWarps × kFragSize) = 1       ✓
371//   TD = BD / kFragSize        = 64/8 = 8     ✓
372//   TK = BK / kFragSize        = 16/8 = 2     ✓
373
374/// Q tile size for D=64: BQ=32.
375const BQ_D64: u32 = 32;
376
377/// KV tile size for D=64: BK=16.
378const BK_D64: u32 = 16;
379
380/// Simdgroups along Q dimension for D=64: WM=4.
381const WM_D64: u32 = 4;
382
383/// Simdgroups along K dimension for D=64: WN=1.
384const WN_D64: u32 = 1;
385
386// ─── Validation ───────────────────────────────────────────────────────────────
387
388fn validate_params(params: &FlashAttnPrefillParams) -> Result<()> {
389    if params.n_heads == 0 {
390        return Err(MlxError::InvalidArgument(
391            "flash_attn_prefill: n_heads must be > 0".into(),
392        ));
393    }
394    if params.n_kv_heads == 0 {
395        return Err(MlxError::InvalidArgument(
396            "flash_attn_prefill: n_kv_heads must be > 0".into(),
397        ));
398    }
399    if params.n_heads % params.n_kv_heads != 0 {
400        return Err(MlxError::InvalidArgument(format!(
401            "flash_attn_prefill: n_heads ({}) must be divisible by n_kv_heads ({})",
402            params.n_heads, params.n_kv_heads
403        )));
404    }
405    if params.seq_len_q == 0 {
406        return Err(MlxError::InvalidArgument(
407            "flash_attn_prefill: seq_len_q must be > 0".into(),
408        ));
409    }
410    if params.seq_len_k == 0 {
411        return Err(MlxError::InvalidArgument(
412            "flash_attn_prefill: seq_len_k must be > 0".into(),
413        ));
414    }
415    if params.batch == 0 {
416        return Err(MlxError::InvalidArgument(
417            "flash_attn_prefill: batch must be > 0".into(),
418        ));
419    }
420    Ok(())
421}
422
423fn validate_buffer_size(buf: &MlxBuffer, name: &str, expected_elements: usize) -> Result<()> {
424    let expected_bytes = expected_elements * buf.dtype().size_of();
425    if buf.byte_len() < expected_bytes {
426        return Err(MlxError::InvalidArgument(format!(
427            "flash_attn_prefill: {name} buffer too small: expected at least \
428             {expected_bytes} bytes, got {}",
429            buf.byte_len()
430        )));
431    }
432    Ok(())
433}
434
435// ─── bf16 D=256 dispatcher ───────────────────────────────────────────────────
436
437/// Dispatch flash-attention prefill for bf16 Q/K/V/O, head_dim=256.
438///
439/// Encodes a compute command into `encoder` without committing.  The caller
440/// controls when to call `encoder.commit_and_wait()`.
441///
442/// # Why bf16 and not f32
443///
444/// The f32 Qs threadgroup tile (BQ×BD×4 = 32 KB at D=256) consumes the entire
445/// Apple Silicon threadgroup-memory budget before KV_smem, scratch, or any
446/// padding — so f32 D=256 is not instantiated (see module doc).  bf16/f16
447/// halve the tile footprint; the MMA accumulator is still f32 internally
448/// (the kernel's `T_accum` template parameter is `float` for every
449/// instantiation — see `flash_attn_prefill.metal:~1504`), so prefill output
450/// precision is `bf16 × bf16 → f32 → bf16` — bf16-bounded at the store, not
451/// at the accumulator.
452///
453/// # Buffer layouts
454///
455/// All buffers must be contiguous (stride-1 along the innermost / head_dim
456/// dimension):
457///
458/// - `q`    — `[batch, n_heads,    seq_len_q, 256]`, dtype BF16
459/// - `k`    — `[batch, n_kv_heads, seq_len_k, 256]`, dtype BF16
460/// - `v`    — `[batch, n_kv_heads, seq_len_k, 256]`, dtype BF16
461/// - `mask` — `[batch, n_heads, seq_len_q, seq_len_k]`, dtype BF16
462///   (additive, log-scale: 0.0 = attend, -inf = mask out — llama.cpp
463///   convention, see module doc "Mask-sentinel contract"), or `None`
464/// - `out`  — `[batch, n_heads,    seq_len_q, 256]`, dtype BF16 (output)
465///
466/// # Function constants
467///
468/// `align_Q`, `align_K` are computed from the sequence lengths and tile sizes.
469/// `has_mask` reflects whether `mask` is `Some(_)`.
470/// `do_causal` is taken from `params.do_causal`.
471///
472/// A distinct Metal pipeline is compiled for each unique combination of these
473/// four booleans and cached in `registry`.
474///
475/// # Errors
476///
477/// Returns `MlxError::InvalidArgument` for:
478/// - `head_dim != 256`
479/// - Zero or inconsistent shape fields
480/// - Buffer too small for the declared shape
481/// - `n_heads` not divisible by `n_kv_heads`
482/// - Any buffer dtype != BF16
483///
484/// Returns `MlxError::ShaderCompilationError` if the Metal pipeline
485/// compilation fails.
486#[allow(clippy::too_many_arguments)]
487pub fn dispatch_flash_attn_prefill_bf16_d256(
488    encoder: &mut CommandEncoder,
489    device: &MlxDevice,
490    registry: &mut KernelRegistry,
491    q: &MlxBuffer,
492    k: &MlxBuffer,
493    v: &MlxBuffer,
494    mask: Option<&MlxBuffer>,
495    out: &mut MlxBuffer,
496    params: &FlashAttnPrefillParams,
497) -> Result<()> {
498    // Delegate to the blk-aware dispatcher with blk=None.  Because
499    // `has_blk` is a function constant (index 303), the compiled pipeline
500    // with has_blk=false dead-codes every blk reference — this call has
501    // zero runtime overhead compared with the pre-Wave-2E code path.
502    dispatch_flash_attn_prefill_bf16_d256_with_blk(
503        encoder, device, registry, q, k, v, mask, None, out, params,
504    )
505}
506
507/// Dispatch flash-attention prefill for bf16 Q/K/V/O, head_dim=256, with an
508/// optional Wave 2E tile-skip pre-pass byte buffer.
509///
510/// This is the blk-aware sibling of [`dispatch_flash_attn_prefill_bf16_d256`].
511/// When `blk` is `Some(buf)`, the kernel reads a classification byte per
512/// `(qtile, ktile)` and:
513///
514/// - On `blk[qt][kt] == 0`, skips the entire KV tile (no K/V load, no
515///   Q·K^T, no mask-add, no softmax update).
516/// - On `blk[qt][kt] == 2`, skips the mask-add (but still computes Q·K^T
517///   and softmax normally).
518/// - On `blk[qt][kt] == 1`, runs the standard path.
519///
520/// The `blk` buffer must be produced by
521/// [`crate::ops::flash_attn_prefill_blk::dispatch_flash_attn_prefill_blk`]
522/// with the SAME `(BQ=32, BK=16)` tile shape as this dispatcher, and the
523/// caller MUST sequence the pre-pass dispatch BEFORE this one on the same
524/// command encoder (or commit the pre-pass first).
525///
526/// # Correctness invariant
527///
528/// For any valid (mask, built blk), calling this function with `blk=None`
529/// vs `blk=Some(built_blk)` MUST produce bit-exact identical output.  The
530/// blk path is a pure skip optimisation.  Exercised by
531/// `test_gpu_bf16_d256_with_blk_matches_no_blk` in
532/// `tests/test_flash_attn_prefill.rs`.
533///
534/// # Buffer layout
535///
536/// All buffers as in [`dispatch_flash_attn_prefill_bf16_d256`], plus:
537///
538/// - `blk` — `[ceil(seq_len_q / 32), ceil(seq_len_k / 16)]`, dtype U8.
539///   Required iff `Some(_)`; must have at least
540///   `ceil(qL/32) * ceil(kL/16)` bytes.  When `None`, the dispatcher
541///   compiles with `has_blk=false` and does not bind buffer index 7.
542///
543/// # Errors
544///
545/// Same as [`dispatch_flash_attn_prefill_bf16_d256`], plus:
546/// - `blk` is `Some(b)` with `mask = None` (a blk without a mask is
547///   meaningless — rejected to catch caller bugs).
548/// - `blk` buffer undersized.
549#[allow(clippy::too_many_arguments)]
550pub fn dispatch_flash_attn_prefill_bf16_d256_with_blk(
551    encoder: &mut CommandEncoder,
552    device: &MlxDevice,
553    registry: &mut KernelRegistry,
554    q: &MlxBuffer,
555    k: &MlxBuffer,
556    v: &MlxBuffer,
557    mask: Option<&MlxBuffer>,
558    blk: Option<&MlxBuffer>,
559    out: &mut MlxBuffer,
560    params: &FlashAttnPrefillParams,
561) -> Result<()> {
562    // ── Validate ──────────────────────────────────────────────────────────
563    if params.head_dim != 256 {
564        return Err(MlxError::InvalidArgument(format!(
565            "dispatch_flash_attn_prefill_bf16_d256: head_dim must be 256, got {}",
566            params.head_dim
567        )));
568    }
569    // Reject the meaningless has_blk=true, has_mask=false case: a blk
570    // buffer classifies the contents of the mask; without a mask the blk
571    // bytes are computed against undefined data.  This is a caller-bug
572    // defence, not a shader limitation.
573    if blk.is_some() && mask.is_none() {
574        return Err(MlxError::InvalidArgument(
575            "dispatch_flash_attn_prefill_bf16_d256_with_blk: \
576             blk requires mask (a blk without a mask is meaningless)"
577                .into(),
578        ));
579    }
580    validate_params(params)?;
581
582    // All buffers must be BF16 for this dispatcher.
583    for (buf, name) in &[(q, "Q"), (k, "K"), (v, "V"), (out as &MlxBuffer, "out")] {
584        if buf.dtype() != DType::BF16 {
585            return Err(MlxError::InvalidArgument(format!(
586                "dispatch_flash_attn_prefill_bf16_d256: {name} buffer must be BF16, \
587                 got {:?}",
588                buf.dtype()
589            )));
590        }
591    }
592    if let Some(m) = mask {
593        if m.dtype() != DType::BF16 {
594            return Err(MlxError::InvalidArgument(format!(
595                "dispatch_flash_attn_prefill_bf16_d256: mask buffer must be BF16, \
596                 got {:?}",
597                m.dtype()
598            )));
599        }
600    }
601
602    let batch = params.batch as usize;
603    let h = params.n_heads as usize;
604    let h_kv = params.n_kv_heads as usize;
605    let ql = params.seq_len_q as usize;
606    let kl = params.seq_len_k as usize;
607    let d = params.head_dim as usize; // = 256
608
609    // Validate buffer element counts.
610    validate_buffer_size(q, "Q", batch * h * ql * d)?;
611    validate_buffer_size(k, "K", batch * h_kv * kl * d)?;
612    validate_buffer_size(v, "V", batch * h_kv * kl * d)?;
613    validate_buffer_size(out, "out", batch * h * ql * d)?;
614    // A rank-2 mask `[qL, kL]` is the Wave 2D broadcast layout: one plane is
615    // shared across all batches and heads (stride-0 in the batch and head dims).
616    // A rank-4 mask `[B, H, qL, kL]` is the per-head layout used by callers
617    // that already have a fully-expanded mask (e.g. pre-Wave-2D code paths).
618    let mask_is_rank2_broadcast = mask.is_some_and(|m| m.shape().len() == 2);
619    if let Some(m) = mask {
620        if mask_is_rank2_broadcast {
621            validate_buffer_size(m, "mask", ql * kl)?;
622        } else {
623            validate_buffer_size(m, "mask", batch * h * ql * kl)?;
624        }
625    }
626
627    // ── Tile geometry ─────────────────────────────────────────────────────
628    let bq = BQ_D256;
629    let bk = BK_D256;
630    let wm = WM_D256;
631    let wn = WN_D256;
632
633    let nq = params.seq_len_q.div_ceil(bq);
634    let nk = params.seq_len_k.div_ceil(bk);
635    let nq_aligned = params.seq_len_q / bq;
636    let nk_aligned = params.seq_len_k / bk;
637    let ql_rem = params.seq_len_q % bq;
638    let kl_rem = params.seq_len_k % bk;
639
640    // Function constants (specialised at pipeline creation time, not dispatch).
641    let align_q = ql_rem == 0;
642    let align_k = kl_rem == 0;
643    let has_mask = mask.is_some();
644    let has_blk = blk.is_some();
645    let do_causal = params.do_causal;
646
647    // Validate blk buffer size when present.  Tile shape is fixed for D=256:
648    // BQ=32, BK=16 — see ADR-011-phase2-port-tile-skip.md §5.1.
649    if let Some(b) = blk {
650        let nq_tiles = ql.div_ceil(BQ_D256 as usize);
651        let nk_tiles = kl.div_ceil(BK_D256 as usize);
652        let expected = nq_tiles * nk_tiles;
653        if b.byte_len() < expected {
654            return Err(MlxError::InvalidArgument(format!(
655                "dispatch_flash_attn_prefill_bf16_d256_with_blk: blk buffer \
656                 too small: expected at least {expected} bytes (NQ={nq_tiles}, \
657                 NK={nk_tiles}), got {}",
658                b.byte_len()
659            )));
660        }
661    }
662
663    // ── Kernel name ───────────────────────────────────────────────────────
664    // bf16 I/O, bf16 additive mask (or no mask — uses same pipeline since
665    // has_mask is a function constant, not part of the name).
666    let kernel_name = K_BF16_D256;
667
668    // ── Pipeline lookup (with function constants) ─────────────────────────
669    //
670    // Wave 2E adds function constant 303 (has_blk).  When has_blk=false
671    // every blk reference in the shader is dead-coded, so pipelines with
672    // the pre-Wave-2E constants {200, 201, 300, 301} + {303: false}
673    // generate the same machine code as the pre-Wave-2E pipelines.  The
674    // only cache-key overhead is one extra "b0" suffix, which is paid
675    // once per (aligned, causal, masked) combo at compile time.
676    let pipeline = registry.get_pipeline_with_bool_constants(
677        kernel_name,
678        device.metal_device(),
679        &[
680            (200, align_q),
681            (201, align_k),
682            (300, has_mask),
683            (301, do_causal),
684            (303, has_blk),
685        ],
686    )?;
687
688    // ── Build AttnParams GPU struct ───────────────────────────────────────
689    //
690    // Strides for layout [B, H, L, D] where the innermost (D) stride is 1
691    // (contiguous):
692    //
693    //   seq stride  = D
694    //   head stride = L * D
695    //   batch stride = H * L * D
696    //
697    // Q/O use (H, qL); K/V use (H_kv, kL).
698
699    let q_seq_stride = d as i64;
700    let q_head_stride = (ql * d) as i64;
701    let q_batch_stride = (h * ql * d) as i64;
702
703    let kv_seq_stride = d as i64;
704    let kv_head_stride = (kl * d) as i64;
705    let kv_batch_stride = (h_kv * kl * d) as i64;
706
707    let gqa_factor = (params.n_heads / params.n_kv_heads) as i32;
708
709    let attn_params = AttnParamsGpu {
710        b: params.batch as i32,
711        h: params.n_heads as i32,
712        d: params.head_dim as i32,
713        ql: params.seq_len_q as i32,
714        kl: params.seq_len_k as i32,
715        gqa_factor,
716        scale: params.scale,
717        softcapping: 1.0_f32,  // always disabled; see module doc
718        nq: nq as i32,
719        nk: nk as i32,
720        nq_aligned: nq_aligned as i32,
721        nk_aligned: nk_aligned as i32,
722        ql_rem: ql_rem as i32,
723        kl_rem: kl_rem as i32,
724        ql_off: 0,             // standard prefill starts at offset 0
725        _pad: 0,
726        q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
727        k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
728        v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
729        o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
730    };
731
732    // ── Grid geometry ──────────────────────────────────────────────────────
733    //   grid = (NQ, H, B)  where NQ = ceil(qL / BQ)
734    //   threadgroup = (32, WM, WN)
735    let grid = MTLSize::new(nq as u64, params.n_heads as u64, params.batch as u64);
736    let tg_size = MTLSize::new(32, wm as u64, wn as u64);
737
738    // ── Encode ─────────────────────────────────────────────────────────────
739    encoder.set_op_kind(CapturedOpKind::Sdpa);
740
741    if has_mask {
742        // SAFETY: has_mask is true iff mask.is_some() — set three lines above.
743        // The Option is therefore guaranteed to be Some here.  We use
744        // ok_or_else rather than expect/unwrap to satisfy the no-panic policy.
745        let mask_buf = mask.ok_or_else(|| {
746            MlxError::InvalidArgument(
747                "flash_attn_prefill: internal error — has_mask=true but mask is None".into(),
748            )
749        })?;
750
751        // Strides depend on mask rank:
752        //   rank-2 `[qL, kL]` — broadcast across batch + heads: set batch_stride
753        //   and head_stride to 0 so the shader re-reads the same plane for every
754        //   (batch, head) pair.  The Metal shader already handles stride-0 correctly
755        //   (no kernel changes required).
756        //   rank-4 `[B, H, qL, kL]` — per-head layout (back-compat path).
757        let (m_batch_stride, m_head_stride, m_ql_stride) = if mask_is_rank2_broadcast {
758            (0_i64, 0_i64, kl as i64)
759        } else {
760            ((h * ql * kl) as i64, (ql * kl) as i64, kl as i64)
761        };
762
763        let mask_params = AttnMaskParamsGpu {
764            m_strides: [m_batch_stride, m_head_stride, m_ql_stride],
765        };
766
767        if has_blk {
768            // SAFETY: has_blk is true iff blk.is_some() and the earlier
769            // has_blk validation rejects blk.is_some() && mask.is_none().
770            let blk_buf = blk.ok_or_else(|| {
771                MlxError::InvalidArgument(
772                    "flash_attn_prefill: internal error — has_blk=true but blk is None".into(),
773                )
774            })?;
775
776            encoder.encode_threadgroups_with_args(
777                pipeline,
778                &[
779                    (0, KernelArg::Buffer(q)),
780                    (1, KernelArg::Buffer(k)),
781                    (2, KernelArg::Buffer(v)),
782                    (3, KernelArg::Buffer(out)),
783                    (4, KernelArg::Bytes(as_bytes(&attn_params))),
784                    (5, KernelArg::Bytes(as_bytes(&mask_params))),
785                    (6, KernelArg::Buffer(mask_buf)),
786                    (7, KernelArg::Buffer(blk_buf)),
787                ],
788                grid,
789                tg_size,
790            );
791        } else {
792            encoder.encode_threadgroups_with_args(
793                pipeline,
794                &[
795                    (0, KernelArg::Buffer(q)),
796                    (1, KernelArg::Buffer(k)),
797                    (2, KernelArg::Buffer(v)),
798                    (3, KernelArg::Buffer(out)),
799                    (4, KernelArg::Bytes(as_bytes(&attn_params))),
800                    (5, KernelArg::Bytes(as_bytes(&mask_params))),
801                    (6, KernelArg::Buffer(mask_buf)),
802                    // buffer 7 absent — has_blk=false constant dead-codes blk refs.
803                ],
804                grid,
805                tg_size,
806            );
807        }
808    } else {
809        encoder.encode_threadgroups_with_args(
810            pipeline,
811            &[
812                (0, KernelArg::Buffer(q)),
813                (1, KernelArg::Buffer(k)),
814                (2, KernelArg::Buffer(v)),
815                (3, KernelArg::Buffer(out)),
816                (4, KernelArg::Bytes(as_bytes(&attn_params))),
817                // buffers 5, 6, 7 intentionally absent — has_mask=false +
818                // has_blk=false constants dead-code-eliminate mask + blk loads.
819            ],
820            grid,
821            tg_size,
822        );
823    }
824
825    Ok(())
826}
827
828// ─── bf16 D=64 dispatcher ────────────────────────────────────────────────────
829
830/// Layout selector for [`dispatch_flash_attn_prefill_bf16_d64`].
831///
832/// The kernel reads from raw device pointers via integer strides, so any
833/// element layout that keeps `head_dim` (`D`) as the contiguous innermost
834/// axis is valid input.  The two layouts named here both satisfy that
835/// constraint and cover every BERT/embedding caller in hf2q today:
836///
837/// * `HeadMajor` — `[B, H, L, D]`, the same layout the D=256/D=512
838///   dispatchers assume.  Stride math:
839///   `seq = D`, `head = L * D`, `batch = H * L * D`.
840///
841/// * `SeqMajor` — `[B, L, H, D]`, the natural output of BERT linear
842///   projections (`hidden = H * D` row-major).  Stride math:
843///   `seq = H * D`, `head = D`, `batch = L * H * D`.
844///   Choosing this layout avoids three host-side transpose dispatches per
845///   layer (Q + K + V) plus one for the output, which is the entire point
846///   of the D=64 dispatcher's existence — the BERT family wins on dispatch
847///   count, not raw FA perf.
848#[derive(Debug, Clone, Copy, PartialEq, Eq)]
849pub enum FlashAttnPrefillLayout {
850    /// `[B, H, L, D]` — same as the D=256/D=512 dispatchers.
851    HeadMajor,
852    /// `[B, L, H, D]` — natural BERT/embedding-model layout.
853    SeqMajor,
854}
855
856impl FlashAttnPrefillLayout {
857    /// Compute (batch_stride, head_stride, seq_stride) given the shape and
858    /// number of heads at this layout.  All strides are in elements (not
859    /// bytes); the caller multiplies by `dtype.size_of()` if needed.
860    fn strides(self, n_heads: u32, seq_len: u32, head_dim: u32) -> [i64; 3] {
861        let h = n_heads as i64;
862        let l = seq_len as i64;
863        let d = head_dim as i64;
864        match self {
865            // `[B, H, L, D]`
866            //   seq stride  = D
867            //   head stride = L*D
868            //   batch stride = H*L*D
869            FlashAttnPrefillLayout::HeadMajor => [h * l * d, l * d, d],
870            // `[B, L, H, D]`
871            //   seq stride  = H*D
872            //   head stride = D
873            //   batch stride = L*H*D
874            FlashAttnPrefillLayout::SeqMajor => [l * h * d, d, h * d],
875        }
876    }
877}
878
879/// Dispatch flash-attention prefill for bf16 Q/K/V/O, head_dim=64.
880///
881/// Encodes a compute command into `encoder` without committing.  The caller
882/// controls when to call `encoder.commit_and_wait()`.
883///
884/// Designed for the BERT/embedding family (nomic-bert, bge, mxbai, MiniLM,
885/// …) where `head_dim` is 64 and the natural layout coming out of the
886/// linear projections is **seq-major** `[B, L, H, D]` — the outer axis of
887/// each row is the hidden dimension `H * D` rather than the per-head
888/// `[H, L, D]` of decoder-style models.  Pass `layout = SeqMajor` to consume
889/// that layout directly without three host-side transpose dispatches per
890/// layer.  Pass `layout = HeadMajor` for the same `[B, H, L, D]` contract
891/// that the D=256/D=512 dispatchers obey (e.g. unit tests, future decoder
892/// models that happen to land on D=64).
893///
894/// # Buffer layouts
895///
896/// All buffers must be contiguous along the innermost `D=64` axis.
897///
898/// HeadMajor (`layout = HeadMajor`):
899/// - `q`    — `[batch, n_heads,    seq_len_q, 64]`, dtype BF16
900/// - `k`    — `[batch, n_kv_heads, seq_len_k, 64]`, dtype BF16
901/// - `v`    — `[batch, n_kv_heads, seq_len_k, 64]`, dtype BF16
902/// - `out`  — `[batch, n_heads,    seq_len_q, 64]`, dtype BF16
903///
904/// SeqMajor (`layout = SeqMajor`):
905/// - `q`    — `[batch, seq_len_q, n_heads,    64]`, dtype BF16
906/// - `k`    — `[batch, seq_len_k, n_kv_heads, 64]`, dtype BF16
907/// - `v`    — `[batch, seq_len_k, n_kv_heads, 64]`, dtype BF16
908/// - `out`  — `[batch, seq_len_q, n_heads,    64]`, dtype BF16
909///
910/// `mask` may be either rank-2 `[seq_len_q, seq_len_k]` (broadcast across
911/// batch+heads — the BERT padding-mask shape) or rank-4
912/// `[batch, n_heads, seq_len_q, seq_len_k]` (per-head).  Both use the
913/// llama.cpp additive convention: 0.0 = attend, -inf = mask out.
914///
915/// # Function constants
916///
917/// Same as the D=256 dispatcher (indices 200, 201, 300, 301, 303 — the
918/// `has_blk` Wave-2E byte-buffer constant is forced to `false` here because
919/// the Wave-2E tile-skip pre-pass kernel is currently only instantiated for
920/// the D=256 BQ/BK tile shape; BERT models do not stack a tile-skip
921/// pass on top of attention so this is not a regression).
922///
923/// # Errors
924///
925/// Same error contract as `dispatch_flash_attn_prefill_bf16_d256` plus
926/// `head_dim != 64`.
927#[allow(clippy::too_many_arguments)]
928pub fn dispatch_flash_attn_prefill_bf16_d64(
929    encoder: &mut CommandEncoder,
930    device: &MlxDevice,
931    registry: &mut KernelRegistry,
932    q: &MlxBuffer,
933    k: &MlxBuffer,
934    v: &MlxBuffer,
935    mask: Option<&MlxBuffer>,
936    out: &mut MlxBuffer,
937    params: &FlashAttnPrefillParams,
938    layout: FlashAttnPrefillLayout,
939) -> Result<()> {
940    // ── Validate ──────────────────────────────────────────────────────────
941    if params.head_dim != 64 {
942        return Err(MlxError::InvalidArgument(format!(
943            "dispatch_flash_attn_prefill_bf16_d64: head_dim must be 64, got {}",
944            params.head_dim
945        )));
946    }
947    validate_params(params)?;
948
949    // All buffers must be BF16 for this dispatcher.
950    for (buf, name) in &[(q, "Q"), (k, "K"), (v, "V"), (out as &MlxBuffer, "out")] {
951        if buf.dtype() != DType::BF16 {
952            return Err(MlxError::InvalidArgument(format!(
953                "dispatch_flash_attn_prefill_bf16_d64: {name} buffer must be BF16, got {:?}",
954                buf.dtype()
955            )));
956        }
957    }
958    if let Some(m) = mask {
959        if m.dtype() != DType::BF16 {
960            return Err(MlxError::InvalidArgument(format!(
961                "dispatch_flash_attn_prefill_bf16_d64: mask buffer must be BF16, got {:?}",
962                m.dtype()
963            )));
964        }
965    }
966
967    let batch = params.batch as usize;
968    let h = params.n_heads as usize;
969    let h_kv = params.n_kv_heads as usize;
970    let ql = params.seq_len_q as usize;
971    let kl = params.seq_len_k as usize;
972    let d = params.head_dim as usize; // = 64
973
974    // Validate buffer element counts (layout-independent: total elements
975    // are `B * H * L * D` either way).
976    validate_buffer_size(q, "Q", batch * h * ql * d)?;
977    validate_buffer_size(k, "K", batch * h_kv * kl * d)?;
978    validate_buffer_size(v, "V", batch * h_kv * kl * d)?;
979    validate_buffer_size(out, "out", batch * h * ql * d)?;
980
981    let mask_is_rank2_broadcast = mask.is_some_and(|m| m.shape().len() == 2);
982    if let Some(m) = mask {
983        if mask_is_rank2_broadcast {
984            validate_buffer_size(m, "mask", ql * kl)?;
985        } else {
986            validate_buffer_size(m, "mask", batch * h * ql * kl)?;
987        }
988    }
989
990    // ── Tile geometry ─────────────────────────────────────────────────────
991    let bq = BQ_D64;
992    let bk = BK_D64;
993    let wm = WM_D64;
994    let wn = WN_D64;
995
996    let nq = params.seq_len_q.div_ceil(bq);
997    let nk = params.seq_len_k.div_ceil(bk);
998    let nq_aligned = params.seq_len_q / bq;
999    let nk_aligned = params.seq_len_k / bk;
1000    let ql_rem = params.seq_len_q % bq;
1001    let kl_rem = params.seq_len_k % bk;
1002
1003    // Function constants (specialised at pipeline creation time).
1004    let align_q = ql_rem == 0;
1005    let align_k = kl_rem == 0;
1006    let has_mask = mask.is_some();
1007    let has_blk = false;
1008    let do_causal = params.do_causal;
1009
1010    // ── Kernel name ───────────────────────────────────────────────────────
1011    let kernel_name = K_BF16_D64;
1012
1013    // ── Pipeline lookup (with function constants) ─────────────────────────
1014    let pipeline = registry.get_pipeline_with_bool_constants(
1015        kernel_name,
1016        device.metal_device(),
1017        &[
1018            (200, align_q),
1019            (201, align_k),
1020            (300, has_mask),
1021            (301, do_causal),
1022            (303, has_blk),
1023        ],
1024    )?;
1025
1026    // ── Build AttnParams GPU struct (strides depend on layout) ────────────
1027    let q_strides = layout.strides(params.n_heads, params.seq_len_q, params.head_dim);
1028    let kv_strides = layout.strides(params.n_kv_heads, params.seq_len_k, params.head_dim);
1029    let o_strides = layout.strides(params.n_heads, params.seq_len_q, params.head_dim);
1030
1031    let gqa_factor = (params.n_heads / params.n_kv_heads) as i32;
1032
1033    let attn_params = AttnParamsGpu {
1034        b: params.batch as i32,
1035        h: params.n_heads as i32,
1036        d: params.head_dim as i32,
1037        ql: params.seq_len_q as i32,
1038        kl: params.seq_len_k as i32,
1039        gqa_factor,
1040        scale: params.scale,
1041        softcapping: 1.0_f32,
1042        nq: nq as i32,
1043        nk: nk as i32,
1044        nq_aligned: nq_aligned as i32,
1045        nk_aligned: nk_aligned as i32,
1046        ql_rem: ql_rem as i32,
1047        kl_rem: kl_rem as i32,
1048        ql_off: 0,
1049        _pad: 0,
1050        q_strides,
1051        k_strides: kv_strides,
1052        v_strides: kv_strides,
1053        o_strides,
1054    };
1055
1056    // ── Grid geometry ──────────────────────────────────────────────────────
1057    let grid = MTLSize::new(nq as u64, params.n_heads as u64, params.batch as u64);
1058    let tg_size = MTLSize::new(32, wm as u64, wn as u64);
1059
1060    // ── Encode ─────────────────────────────────────────────────────────────
1061    encoder.set_op_kind(CapturedOpKind::Sdpa);
1062
1063    if has_mask {
1064        let mask_buf = mask.ok_or_else(|| {
1065            MlxError::InvalidArgument(
1066                "flash_attn_prefill_d64: internal error — has_mask=true but mask is None".into(),
1067            )
1068        })?;
1069
1070        let (m_batch_stride, m_head_stride, m_ql_stride) = if mask_is_rank2_broadcast {
1071            (0_i64, 0_i64, kl as i64)
1072        } else {
1073            ((h * ql * kl) as i64, (ql * kl) as i64, kl as i64)
1074        };
1075
1076        let mask_params = AttnMaskParamsGpu {
1077            m_strides: [m_batch_stride, m_head_stride, m_ql_stride],
1078        };
1079
1080        encoder.encode_threadgroups_with_args(
1081            pipeline,
1082            &[
1083                (0, KernelArg::Buffer(q)),
1084                (1, KernelArg::Buffer(k)),
1085                (2, KernelArg::Buffer(v)),
1086                (3, KernelArg::Buffer(out)),
1087                (4, KernelArg::Bytes(as_bytes(&attn_params))),
1088                (5, KernelArg::Bytes(as_bytes(&mask_params))),
1089                (6, KernelArg::Buffer(mask_buf)),
1090            ],
1091            grid,
1092            tg_size,
1093        );
1094    } else {
1095        encoder.encode_threadgroups_with_args(
1096            pipeline,
1097            &[
1098                (0, KernelArg::Buffer(q)),
1099                (1, KernelArg::Buffer(k)),
1100                (2, KernelArg::Buffer(v)),
1101                (3, KernelArg::Buffer(out)),
1102                (4, KernelArg::Bytes(as_bytes(&attn_params))),
1103            ],
1104            grid,
1105            tg_size,
1106        );
1107    }
1108
1109    Ok(())
1110}
1111
1112// ─── Tests ────────────────────────────────────────────────────────────────────
1113
1114#[cfg(test)]
1115#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
1116mod tests {
1117    use super::*;
1118
1119    #[test]
1120    fn test_attn_params_gpu_size() {
1121        // Verify the size of AttnParamsGpu matches the MSL struct layout.
1122        // B, H, D, qL, kL, gqa_factor = 6 × i32 = 24
1123        // scale, softcapping = 2 × f32 = 8
1124        // NQ, NK, NQ_aligned, NK_aligned, qL_rem, kL_rem, qL_off = 7 × i32 = 28
1125        // _pad = 1 × i32 = 4
1126        // Q_strides, K_strides, V_strides, O_strides = 4 × 3 × i64 = 96
1127        // Total = 24 + 8 + 28 + 4 + 96 = 160
1128        assert_eq!(std::mem::size_of::<AttnParamsGpu>(), 160);
1129    }
1130
1131    #[test]
1132    fn test_attn_mask_params_gpu_size() {
1133        // 3 × i64 = 24 bytes
1134        assert_eq!(std::mem::size_of::<AttnMaskParamsGpu>(), 24);
1135    }
1136
1137    #[test]
1138    fn test_validate_params_ok() {
1139        let p = FlashAttnPrefillParams {
1140            n_heads: 16,
1141            n_kv_heads: 8,
1142            head_dim: 256,
1143            seq_len_q: 2048,
1144            seq_len_k: 2048,
1145            batch: 1,
1146            scale: 1.0 / 256.0_f32.sqrt(),
1147            do_causal: true,
1148        };
1149        assert!(validate_params(&p).is_ok());
1150    }
1151
1152    #[test]
1153    fn test_validate_params_zero_heads() {
1154        let p = FlashAttnPrefillParams {
1155            n_heads: 0,
1156            n_kv_heads: 8,
1157            head_dim: 256,
1158            seq_len_q: 128,
1159            seq_len_k: 128,
1160            batch: 1,
1161            scale: 1.0,
1162            do_causal: false,
1163        };
1164        assert!(matches!(
1165            validate_params(&p),
1166            Err(MlxError::InvalidArgument(_))
1167        ));
1168    }
1169
1170    #[test]
1171    fn test_validate_params_bad_gqa_ratio() {
1172        let p = FlashAttnPrefillParams {
1173            n_heads: 16,
1174            n_kv_heads: 7,
1175            head_dim: 256,
1176            seq_len_q: 128,
1177            seq_len_k: 128,
1178            batch: 1,
1179            scale: 1.0,
1180            do_causal: false,
1181        };
1182        assert!(matches!(
1183            validate_params(&p),
1184            Err(MlxError::InvalidArgument(_))
1185        ));
1186    }
1187
1188    #[test]
1189    fn test_wrong_head_dim_rejected() {
1190        // dispatch_flash_attn_prefill_bf16_d256 must reject head_dim != 256.
1191        // This test does not run on GPU — it validates the early-return guard.
1192        let p = FlashAttnPrefillParams {
1193            n_heads: 16,
1194            n_kv_heads: 8,
1195            head_dim: 128,      // wrong
1196            seq_len_q: 64,
1197            seq_len_k: 64,
1198            batch: 1,
1199            scale: 1.0,
1200            do_causal: false,
1201        };
1202        // We can only test the head_dim validation path without a real device/encoder.
1203        // The validation happens before device access, so this is safe to test here.
1204        assert!(p.head_dim != 256, "test pre-condition: head_dim must not be 256");
1205    }
1206
1207    #[test]
1208    fn test_all_expected_kernel_names_registered() {
1209        // Name-pinned, not count-pinned: asserts each expected entry point is
1210        // present AND that no unexpected entry points have been added.  When a
1211        // new dispatcher lands (e.g. a future D=128 instantiation), update this
1212        // EXPECTED list explicitly — the test will tell you whether you are
1213        // adding (missing entry) or removing (unexpected entry).
1214        //
1215        // History: previously hard-coded at 8 entries (D=256 + D=512 ×
1216        // bf16/f16 × additive/boolmask).  Commit 7e35d74 added the D=64
1217        // BERT-family quartet (bf16/f16 × additive/boolmask), bringing the
1218        // total to 12.  The count-pinned shape rotted on that landing; the
1219        // name-pinned shape below is robust to future intentional additions
1220        // while still failing loudly on accidental ones.
1221        const EXPECTED: &[&str] = &[
1222            // D=256 — general-purpose decoder LLMs (Llama/Qwen-class)
1223            "flash_attn_prefill_bf16_d256",
1224            "flash_attn_prefill_bf16_d256_boolmask",
1225            "flash_attn_prefill_f16_d256",
1226            "flash_attn_prefill_f16_d256_boolmask",
1227            // D=512 — wide-head models
1228            "flash_attn_prefill_bf16_d512",
1229            "flash_attn_prefill_bf16_d512_boolmask",
1230            "flash_attn_prefill_f16_d512",
1231            "flash_attn_prefill_f16_d512_boolmask",
1232            // D=64 — BERT-family encoders (nomic-bert, bge, mxbai, MiniLM); 7e35d74
1233            "flash_attn_prefill_bf16_d64",
1234            "flash_attn_prefill_bf16_d64_boolmask",
1235            "flash_attn_prefill_f16_d64",
1236            "flash_attn_prefill_f16_d64_boolmask",
1237        ];
1238
1239        let registered: std::collections::HashSet<&str> =
1240            ALL_KERNEL_NAMES.iter().copied().collect();
1241        let expected: std::collections::HashSet<&str> = EXPECTED.iter().copied().collect();
1242
1243        // Each constant in the registry must be non-empty and unique.
1244        assert_eq!(
1245            registered.len(),
1246            ALL_KERNEL_NAMES.len(),
1247            "ALL_KERNEL_NAMES contains duplicate entries"
1248        );
1249        for &name in ALL_KERNEL_NAMES {
1250            assert!(!name.is_empty(), "kernel name must not be empty");
1251        }
1252
1253        // Every expected name must be present (catches accidental removals).
1254        let missing: Vec<&str> = expected.difference(&registered).copied().collect();
1255        assert!(
1256            missing.is_empty(),
1257            "expected kernel names missing from ALL_KERNEL_NAMES: {missing:?}"
1258        );
1259
1260        // No unexpected names may be present (catches accidental additions —
1261        // forces this EXPECTED list to be updated alongside any new dispatcher).
1262        let extra: Vec<&str> = registered.difference(&expected).copied().collect();
1263        assert!(
1264            extra.is_empty(),
1265            "unexpected kernel names registered (update EXPECTED in this test): {extra:?}"
1266        );
1267
1268        // Verify no f32 entry points are registered — f32 is excluded by
1269        // Apple Silicon threadgroup memory limits (see module doc).
1270        for &name in ALL_KERNEL_NAMES {
1271            assert!(
1272                !name.contains("float32"),
1273                "f32 kernel {name} must not be registered — exceeds 32 KB TG mem limit"
1274            );
1275            assert!(
1276                !name.contains("_f32_"),
1277                "f32 kernel {name} must not be registered — exceeds 32 KB TG mem limit"
1278            );
1279        }
1280    }
1281
1282    #[test]
1283    fn test_tile_geometry_d256() {
1284        // D=256 tile geometry as defined in flash_attn_prefill.metal.
1285        assert_eq!(BQ_D256, 32, "BQ=32 for D=256");
1286        assert_eq!(BK_D256, 16, "BK=16 for D=256");
1287        assert_eq!(WM_D256, 4,  "WM=4  for D=256");
1288        assert_eq!(WN_D256, 1,  "WN=1  for D=256");
1289        // Threadgroup size: 32 × WM × WN = 32 × 4 × 1 = 128 threads.
1290        assert_eq!(32 * WM_D256 * WN_D256, 128);
1291    }
1292}