mlx_native/ops/flash_attn_prefill.rs
1//! Flash-attention-style tiled prefill kernel — host dispatch.
2//!
3//! mlx-native's batched-prefill SDPA kernel, the prefill counterpart to
4//! `flash_attn_vec` (which handles the seq_len=1 decode case). Implements
5//! online softmax + simdgroup MMA tiling on Apple GPU.
6//!
7//! ## Kernel variants registered
8//!
9//! Twelve entry points are registered, all backed by the single
10//! `flash_attn_prefill.metal` shader source:
11//!
12//! ### D=64 (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup, BERT family)
13//!
14//! | Kernel name | I/O dtype | Mask kind |
15//! |---|---|---|
16//! | `flash_attn_prefill_bf16_d64` | bf16 | bf16 additive |
17//! | `flash_attn_prefill_bf16_d64_boolmask` | bf16 | bool |
18//! | `flash_attn_prefill_f16_d64` | f16 | f16 additive |
19//! | `flash_attn_prefill_f16_d64_boolmask` | f16 | bool |
20//!
21//! ### D=256 (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup)
22//!
23//! | Kernel name | I/O dtype | Mask kind |
24//! |---|---|---|
25//! | `flash_attn_prefill_bf16_d256` | bf16 | bf16 additive (log-domain) |
26//! | `flash_attn_prefill_bf16_d256_boolmask` | bf16 | bool (`is_attended`) |
27//! | `flash_attn_prefill_f16_d256` | f16 | f16 additive |
28//! | `flash_attn_prefill_f16_d256_boolmask` | f16 | bool |
29//!
30//! ### D=512 (BQ=8, BK=8, WM=1, WN=1 — 32 threads/threadgroup, 1 simdgroup)
31//!
32//! | Kernel name | I/O dtype | Mask kind |
33//! |---|---|---|
34//! | `flash_attn_prefill_bf16_d512` | bf16 | bf16 additive |
35//! | `flash_attn_prefill_bf16_d512_boolmask` | bf16 | bool |
36//! | `flash_attn_prefill_f16_d512` | f16 | f16 additive |
37//! | `flash_attn_prefill_f16_d512_boolmask` | f16 | bool |
38//!
39//! ### f32 is NOT instantiated at any D
40//!
41//! The f32 Qs threadgroup tile alone is `BQ * BD * 4` bytes — at D=256 this
42//! is 32 KB exactly, the Apple Silicon `MTLDevice.maxThreadgroupMemoryLength`
43//! hard limit, before KV_smem or scratch. Verified empirically on M5 Max:
44//! f32 D=256 requires ~53.7 KB and fails at library compile. bf16 and f16
45//! halve the tile footprint (~29 KB) and fit within the limit. D=512 f32
46//! is excluded for the same reason. D=64 inherits the same exclusion for
47//! consistency (bf16/f16 only across the entire dispatcher family). f32
48//! correctness is verified at the CPU reference layer in
49//! `tests/test_flash_attn_prefill.rs`. See
50//! `ADR-011-phase1-port-source-decision.md` §3 for the full
51//! threadgroup-memory analysis.
52//!
53//! ## Function constants
54//!
//! The kernel uses boolean Metal function constants that must be specialised
//! at pipeline creation time (not at dispatch time). The D=256 dispatcher
//! sets the following:
//!
//! - Index 200: `align_Q` (bool) — true when `qL % BQ == 0`
//! - Index 201: `align_K` (bool) — true when `kL % BK == 0`
//! - Index 300: `has_mask` (bool) — true when a mask buffer is bound
//! - Index 301: `do_causal` (bool) — true for in-kernel causal masking
//! - Index 303: `has_blk` (bool) — true when a Wave 2E tile-skip `blk`
//!   buffer is bound; when false, every blk reference is dead-coded
//!
//! These are plumbed via [`KernelRegistry::get_pipeline_with_bool_constants`].
//! It caches compiled pipelines keyed by the kernel name together with the
//! value of each boolean constant passed in. Pipeline compilation is
//! amortised: the slow path runs only once per unique `(name, booleans)`
//! combination.
67//!
68//! ## Buffer layout (indices match the MSL kernel)
69//!
70//! - `buffer(0)` — Q `[B, H, qL, D]` device, contiguous inner dim
71//! - `buffer(1)` — K `[B, H_kv, kL, D]` device, contiguous inner dim
72//! - `buffer(2)` — V `[B, H_kv, kL, D]` device, contiguous inner dim
73//! - `buffer(3)` — O `[B, H, qL, D]` device, written by kernel
74//! - `buffer(4)` — `AttnParams` constant buffer (this module's [`AttnParamsGpu`])
75//! - `buffer(5)` — `AttnMaskParams` constant buffer (only when `has_mask=true`)
//! - `buffer(6)` — mask data buffer (only when `has_mask=true`)
//! - `buffer(7)` — `blk` tile-classification bytes (only when `has_blk=true`)
77//!
78//! ## Grid geometry
79//!
80//! - Threadgroups: `(ceil(qL / BQ), H, B)`
81//! - Threads per threadgroup: `(32, WM, WN)`
//!   - D=64: 128 threads (4 simdgroups × 32 lanes).
//!   - D=256: 128 threads (4 simdgroups × 32 lanes).
//!   - D=512: 32 threads (1 simdgroup × 32 lanes).
84//!
85//! ## Scale convention
86//!
87//! Pass `scale = 1.0 / sqrt(head_dim)`. The kernel multiplies internally by
88//! `log2(e) ≈ 1.44269504` and uses `fast::exp2` throughout — so the host
89//! MUST NOT pre-multiply by `log2(e)`.
90//!
91//! ## Mask-sentinel contract (llama.cpp convention)
92//!
93//! The additive mask buffer (bf16/f16 for the additive dispatchers) uses
94//! the llama.cpp CPU-side convention: **masked positions = `-INFINITY`**
95//! (IEEE-754 f32 `0xFF800000`, cast to the I/O dtype — both `half(-inf)`
96//! and `bfloat16(-inf)` have a real `-inf` encoding that the kernel
97//! consumes correctly). Attended positions = `0.0`.
98//!
99//! This matches llama.cpp's mask-authoring sites at
100//! `llama-graph.cpp:421, 436, 557` and `llama-kv-cache.cpp:1572`, where
101//! the CPU writes raw `-INFINITY` and the flash-attn cast to f16 saturates
102//! to f16-`-INFINITY`.
103//!
104//! The kernel does NOT require masks to use `-FLT_MAX/2` or any other
105//! finite "large negative" sentinel. The `-FLT_MAX/2` value is the
106//! kernel-internal `M` (running row-max) initialiser — a finite sentinel
107//! that absorbs `-inf` scores via `simd_max` without ever letting `M`
108//! become `-inf`, which in turn lets every `exp(score - M)` evaluate as
109//! `exp(-inf) = 0.0` (IEEE-754 exact) rather than `exp(-inf - -inf) =
110//! exp(NaN) = NaN`. See ADR-011-phase2-port-sentinel.md §1.
111//!
112//! Callers may pass arbitrary finite additive biases in the mask for
113//! non-masked positions (e.g. ALiBi, relative-position biases); only
114//! "fully block this K position" requires `-inf`.
115//!
116//! ## See also
117//!
118//! - Kernel: `/opt/mlx-native/src/shaders/flash_attn_prefill.metal`
119//! - ADR-011: `/opt/hf2q/docs/ADR-011-flash-attn-prefill.md`
120//! - ADR-011 phase 2 sentinel port: `/opt/hf2q/docs/ADR-011-phase2-port-sentinel.md`
121
122use metal::MTLSize;
123
124use crate::buffer::MlxBuffer;
125use crate::device::MlxDevice;
126use crate::encoder::{CapturedOpKind, CommandEncoder, KernelArg, as_bytes};
127use crate::error::{MlxError, Result};
128use crate::kernel_registry::KernelRegistry;
129use crate::DType;
130
131// ─── Shader source ───────────────────────────────────────────────────────────
132
/// MSL source for the flash-attention prefill kernel.
///
/// Embedded into the binary at compile time via `include_str!`. The path is
/// relative to this file: `../shaders/flash_attn_prefill.metal`.
///
/// [`register`] maps every prefill entry-point name to this single source
/// string, so all variants are compiled from the same text.
pub static FLASH_ATTN_PREFILL_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_prefill.metal");
136
137// ─── All 12 kernel entry-point names ─────────────────────────────────────────
138
// D=64 family (BERT-style encoders).

/// D=64, bf16 I/O, bf16 additive mask (BERT family).
const K_BF16_D64: &str = "flash_attn_prefill_bf16_d64";
/// D=64, bf16 I/O, bool (`is_attended`) mask.
const K_BF16_D64_BOOLMASK: &str = "flash_attn_prefill_bf16_d64_boolmask";
/// D=64, f16 I/O, f16 additive mask.
const K_F16_D64: &str = "flash_attn_prefill_f16_d64";
/// D=64, f16 I/O, bool mask.
const K_F16_D64_BOOLMASK: &str = "flash_attn_prefill_f16_d64_boolmask";

// D=256 family.

/// D=256, bf16 I/O, bf16 additive mask.
const K_BF16_D256: &str = "flash_attn_prefill_bf16_d256";
/// D=256, bf16 I/O, bool (`is_attended`) mask.
const K_BF16_D256_BOOLMASK: &str = "flash_attn_prefill_bf16_d256_boolmask";
/// D=256, f16 I/O, f16 additive mask.
const K_F16_D256: &str = "flash_attn_prefill_f16_d256";
/// D=256, f16 I/O, bool mask.
const K_F16_D256_BOOLMASK: &str = "flash_attn_prefill_f16_d256_boolmask";

// D=512 family.

/// D=512, bf16 I/O, bf16 additive mask.
const K_BF16_D512: &str = "flash_attn_prefill_bf16_d512";
/// D=512, bf16 I/O, bool mask.
const K_BF16_D512_BOOLMASK: &str = "flash_attn_prefill_bf16_d512_boolmask";
/// D=512, f16 I/O, f16 additive mask.
const K_F16_D512: &str = "flash_attn_prefill_f16_d512";
/// D=512, f16 I/O, bool mask.
const K_F16_D512_BOOLMASK: &str = "flash_attn_prefill_f16_d512_boolmask";

/// Every kernel entry point exported by `flash_attn_prefill.metal` (12 total).
///
/// `register` walks this list and binds each name to the shared shader
/// source. Because the source is a static `&str`, registering the whole set is
/// free, and a new dispatcher (another dtype or head dim) only has to pick a
/// name from here. Registration itself does not need to change.
const ALL_KERNEL_NAMES: &[&str] = &[
    K_BF16_D256,
    K_BF16_D256_BOOLMASK,
    K_F16_D256,
    K_F16_D256_BOOLMASK,
    K_BF16_D512,
    K_BF16_D512_BOOLMASK,
    K_F16_D512,
    K_F16_D512_BOOLMASK,
    K_BF16_D64,
    K_BF16_D64_BOOLMASK,
    K_F16_D64,
    K_F16_D64_BOOLMASK,
];
184
185// ─── Registration ─────────────────────────────────────────────────────────────
186
187/// Register all flash-attention prefill kernel entry points with the registry.
188///
189/// Maps all 12 entry-point names to the single `flash_attn_prefill.metal`
190/// source. This must be called before any dispatch to these kernels.
191///
192/// # Design note
193///
194/// `KernelRegistry` compiles one Metal library per kernel name. All 12 names
195/// point at the same source text, so the Metal compiler sees the same ~1 500-line
196/// source each time — compilation is amortised in `KernelRegistry::get_pipeline`
197/// (first call per name triggers compilation; subsequent calls return the cached
198/// pipeline). Registering all 12 here rather than only the Phase 1a subset
199/// means Phase 2/4 dispatcher functions can be added without touching this file.
200pub fn register(registry: &mut KernelRegistry) {
201 for &name in ALL_KERNEL_NAMES {
202 registry.register_source(name, FLASH_ATTN_PREFILL_SHADER_SOURCE);
203 }
204}
205
206// ─── MSL struct mirrors ───────────────────────────────────────────────────────
207
208/// Rust mirror of the MSL `AttnParams` struct.
209///
210/// Field order and types match the MSL definition exactly.
211/// MSL source: `flash_attn_prefill.metal` — see the `AttnParams` struct
212/// definition in the kernel source for the field-by-field reference.
213///
214/// # Layout
215///
216/// All `int` fields are 32-bit (i32 in Rust). The `int64_t` stride arrays are
217/// 64-bit (i64 in Rust). The compiler inserts natural alignment padding:
218///
219/// ```text
220/// Offset 0: B (i32, 4 bytes)
221/// Offset 4: H (i32, 4 bytes)
222/// Offset 8: D (i32, 4 bytes)
223/// Offset 12: qL (i32, 4 bytes)
224/// Offset 16: kL (i32, 4 bytes)
225/// Offset 20: gqa_factor (i32, 4 bytes)
226/// Offset 24: scale (f32, 4 bytes)
227/// Offset 28: softcapping (f32, 4 bytes)
228/// Offset 32: NQ (i32, 4 bytes)
229/// Offset 36: NK (i32, 4 bytes)
230/// Offset 40: NQ_aligned (i32, 4 bytes)
231/// Offset 44: NK_aligned (i32, 4 bytes)
232/// Offset 48: qL_rem (i32, 4 bytes)
233/// Offset 52: kL_rem (i32, 4 bytes)
234/// Offset 56: qL_off (i32, 4 bytes)
235/// Offset 60: _pad (4 bytes — alignment before i64 array)
236/// Offset 64: Q_strides[3] (3 × i64, 24 bytes)
237/// Offset 88: K_strides[3] (3 × i64, 24 bytes)
238/// Offset 112: V_strides[3] (3 × i64, 24 bytes)
239/// Offset 136: O_strides[3] (3 × i64, 24 bytes)
240/// Total: 160 bytes
241/// ```
242///
243/// `bytemuck::Pod` / `bytemuck::Zeroable` are derived — the struct must have
244/// no uninitialized padding bytes. The explicit `_pad` field makes the padding
245/// concrete so Pod can be derived safely.
246///
247/// `softcapping` is always set to `1.0` (disabled). The `attention<>` kernel
248/// body does not read it; the field exists for ABI parity with attention
249/// implementations that thread softcapping through the same param block
250/// (e.g. Gemma-style logit softcap), so we don't have to redo the layout
251/// when that work lands.
252#[repr(C)]
253#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
254pub struct AttnParamsGpu {
255 /// Batch size.
256 pub b: i32,
257 /// Number of query heads.
258 pub h: i32,
259 /// Head dimension (D).
260 pub d: i32,
261 /// Query sequence length.
262 pub ql: i32,
263 /// Key/value sequence length.
264 pub kl: i32,
265 /// Group-query attention factor: H / H_kv.
266 pub gqa_factor: i32,
267 /// Attention scale (= 1.0 / sqrt(head_dim); kernel multiples by log2(e)).
268 pub scale: f32,
269 /// Softcapping value — always 1.0 (disabled) for standard SDPA.
270 pub softcapping: f32,
271 /// Number of Q tiles: ceil(qL / BQ).
272 pub nq: i32,
273 /// Number of KV tiles: ceil(kL / BK).
274 pub nk: i32,
275 /// Number of full (aligned) Q tiles: qL / BQ.
276 pub nq_aligned: i32,
277 /// Number of full (aligned) KV tiles: kL / BK.
278 pub nk_aligned: i32,
279 /// Remainder elements in the last Q tile: qL % BQ (0 if aligned).
280 pub ql_rem: i32,
281 /// Remainder elements in the last KV tile: kL % BK (0 if aligned).
282 pub kl_rem: i32,
283 /// Query sequence start offset (0 for standard prefill).
284 pub ql_off: i32,
285 /// Explicit padding to align the subsequent i64 arrays to 8-byte boundary.
286 pub _pad: i32,
287 /// Query strides: (batch stride, head stride, seq stride). Inner dim = 1.
288 pub q_strides: [i64; 3],
289 /// Key strides: (batch stride, head stride, seq stride). Inner dim = 1.
290 pub k_strides: [i64; 3],
291 /// Value strides: (batch stride, head stride, seq stride). Inner dim = 1.
292 pub v_strides: [i64; 3],
293 /// Output strides: (batch stride, head stride, seq stride). Inner dim = 1.
294 pub o_strides: [i64; 3],
295}
296
297/// Rust mirror of the MSL `AttnMaskParams` struct.
298///
299/// MSL source: `flash_attn_prefill.metal` — see the `AttnMaskParams` struct
300/// definition in the kernel source.
301///
302/// Contains the mask buffer strides. Only sent to the kernel when
303/// `has_mask = true` (buffer index 5).
304#[repr(C)]
305#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
306pub struct AttnMaskParamsGpu {
307 /// Mask strides: (batch stride, head stride, qL stride). Inner dim = 1.
308 pub m_strides: [i64; 3],
309}
310
311// ─── Public Rust-side parameter struct ───────────────────────────────────────
312
/// Host-side shape and scaling parameters for the flash-attention prefill
/// dispatchers.
///
/// This is a Rust-only description of the problem. It is not laid out like a
/// GPU struct. Each dispatcher derives the kernel's `AttnParamsGpu`
/// (tile counts, remainders, strides) from these fields.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnPrefillParams {
    /// Query head count (H).
    pub n_heads: u32,
    /// Key/value head count (H_kv). Under GQA it may be smaller than
    /// `n_heads`, but it must divide it.
    pub n_kv_heads: u32,
    /// Head dimension (D). Each dispatcher accepts only its own value, e.g.
    /// 256 for `dispatch_flash_attn_prefill_bf16_d256`.
    pub head_dim: u32,
    /// Number of query positions (qL).
    pub seq_len_q: u32,
    /// Number of key/value positions (kL).
    pub seq_len_k: u32,
    /// Batch size (B).
    pub batch: u32,
    /// Softmax scale applied to Q·K^T, normally `1.0 / sqrt(head_dim)`.
    ///
    /// Pass the plain scale. The kernel folds in `log2(e) = 1.44269504089`
    /// itself, so the host must NOT pre-multiply by log2(e).
    pub scale: f32,
    /// Enables in-kernel causal masking (the `do_causal` function constant).
    ///
    /// When set, scores where `row_pos < col_pos` become -inf before the
    /// softmax. This may be combined with an explicit mask buffer.
    pub do_causal: bool,
}
342
343// ─── Tile geometry constants (D=256) ─────────────────────────────────────────
344
/// Query-tile height at D=256 (BQ): query rows per threadgroup, so the grid
/// has `NQ = ceil(qL / BQ)` threadgroups along x.
const BQ_D256: u32 = 32;

/// Key/value tile length at D=256 (BK). `NK = ceil(kL / BK)` tiles are
/// visited, and it is also the column granularity of the `blk` tile-skip
/// buffer.
const BK_D256: u32 = 16;

/// Simdgroups along the Q axis at D=256 (WM). It becomes the threadgroup's
/// `y` extent in `(32, WM, WN)`.
const WM_D256: u32 = 4;

/// Simdgroups along the K axis at D=256 (WN). It becomes the threadgroup's
/// `z` extent in `(32, WM, WN)`.
const WN_D256: u32 = 1;
// ─── Tile geometry constants (D=64) ──────────────────────────────────────────
//
// D=64 reuses the D=256 simdgroup arrangement: BQ=32, BK=16, four simdgroups
// stacked along Q and one along K (128 threads per threadgroup).
//
// Threadgroup memory at bf16 (2 bytes/element):
//   Qs  = BQ × (BD + padQ) × 2            = 32 × 72 × 2          = 4 608 B
//   KVs = max(BK × (BD + padV), (BK + padK) × BD) × 2
//       = max(16 × 72, 24 × 64) × 2       = max(1152, 1536) × 2  = 3 072 B
//   sum ≈ 7 680 B, far below the 32 KiB Apple Silicon threadgroup limit.
//
// Kernel static-assert constraints (kFragSize = 8, kNWarps = WM = 4):
//   BQ % (kNWarps × kFragSize) == 0  →  32 % 32 == 0   ✓
//   BQ ≥ kNWarps × kFragSize         →  32 ≥ 32         ✓
//   TQ = BQ / (kNWarps × kFragSize)  =  1               ✓
//   TD = BD / kFragSize              =  64 / 8 = 8      ✓
//   TK = BK / kFragSize              =  16 / 8 = 2      ✓

/// Query-tile height at D=64 (BQ).
const BQ_D64: u32 = 32;

/// Key/value tile length at D=64 (BK).
const BK_D64: u32 = 16;

/// Simdgroups along the Q axis at D=64 (WM).
const WM_D64: u32 = 4;

/// Simdgroups along the K axis at D=64 (WN).
const WN_D64: u32 = 1;
385
386// ─── Validation ───────────────────────────────────────────────────────────────
387
388fn validate_params(params: &FlashAttnPrefillParams) -> Result<()> {
389 if params.n_heads == 0 {
390 return Err(MlxError::InvalidArgument(
391 "flash_attn_prefill: n_heads must be > 0".into(),
392 ));
393 }
394 if params.n_kv_heads == 0 {
395 return Err(MlxError::InvalidArgument(
396 "flash_attn_prefill: n_kv_heads must be > 0".into(),
397 ));
398 }
399 if params.n_heads % params.n_kv_heads != 0 {
400 return Err(MlxError::InvalidArgument(format!(
401 "flash_attn_prefill: n_heads ({}) must be divisible by n_kv_heads ({})",
402 params.n_heads, params.n_kv_heads
403 )));
404 }
405 if params.seq_len_q == 0 {
406 return Err(MlxError::InvalidArgument(
407 "flash_attn_prefill: seq_len_q must be > 0".into(),
408 ));
409 }
410 if params.seq_len_k == 0 {
411 return Err(MlxError::InvalidArgument(
412 "flash_attn_prefill: seq_len_k must be > 0".into(),
413 ));
414 }
415 if params.batch == 0 {
416 return Err(MlxError::InvalidArgument(
417 "flash_attn_prefill: batch must be > 0".into(),
418 ));
419 }
420 Ok(())
421}
422
423fn validate_buffer_size(buf: &MlxBuffer, name: &str, expected_elements: usize) -> Result<()> {
424 let expected_bytes = expected_elements * buf.dtype().size_of();
425 if buf.byte_len() < expected_bytes {
426 return Err(MlxError::InvalidArgument(format!(
427 "flash_attn_prefill: {name} buffer too small: expected at least \
428 {expected_bytes} bytes, got {}",
429 buf.byte_len()
430 )));
431 }
432 Ok(())
433}
434
435// ─── bf16 D=256 dispatcher ───────────────────────────────────────────────────
436
437/// Dispatch flash-attention prefill for bf16 Q/K/V/O, head_dim=256.
438///
439/// Encodes a compute command into `encoder` without committing. The caller
440/// controls when to call `encoder.commit_and_wait()`.
441///
442/// # Why bf16 and not f32
443///
444/// The f32 Qs threadgroup tile (BQ×BD×4 = 32 KB at D=256) consumes the entire
445/// Apple Silicon threadgroup-memory budget before KV_smem, scratch, or any
446/// padding — so f32 D=256 is not instantiated (see module doc). bf16/f16
447/// halve the tile footprint; the MMA accumulator is still f32 internally
448/// (the kernel's `T_accum` template parameter is `float` for every
449/// instantiation — see `flash_attn_prefill.metal:~1504`), so prefill output
450/// precision is `bf16 × bf16 → f32 → bf16` — bf16-bounded at the store, not
451/// at the accumulator.
452///
453/// # Buffer layouts
454///
455/// All buffers must be contiguous (stride-1 along the innermost / head_dim
456/// dimension):
457///
458/// - `q` — `[batch, n_heads, seq_len_q, 256]`, dtype BF16
459/// - `k` — `[batch, n_kv_heads, seq_len_k, 256]`, dtype BF16
460/// - `v` — `[batch, n_kv_heads, seq_len_k, 256]`, dtype BF16
461/// - `mask` — `[batch, n_heads, seq_len_q, seq_len_k]`, dtype BF16
462/// (additive, log-scale: 0.0 = attend, -inf = mask out — llama.cpp
463/// convention, see module doc "Mask-sentinel contract"), or `None`
464/// - `out` — `[batch, n_heads, seq_len_q, 256]`, dtype BF16 (output)
465///
466/// # Function constants
467///
468/// `align_Q`, `align_K` are computed from the sequence lengths and tile sizes.
469/// `has_mask` reflects whether `mask` is `Some(_)`.
470/// `do_causal` is taken from `params.do_causal`.
471///
472/// A distinct Metal pipeline is compiled for each unique combination of these
473/// four booleans and cached in `registry`.
474///
475/// # Errors
476///
477/// Returns `MlxError::InvalidArgument` for:
478/// - `head_dim != 256`
479/// - Zero or inconsistent shape fields
480/// - Buffer too small for the declared shape
481/// - `n_heads` not divisible by `n_kv_heads`
482/// - Any buffer dtype != BF16
483///
484/// Returns `MlxError::ShaderCompilationError` if the Metal pipeline
485/// compilation fails.
486#[allow(clippy::too_many_arguments)]
487pub fn dispatch_flash_attn_prefill_bf16_d256(
488 encoder: &mut CommandEncoder,
489 device: &MlxDevice,
490 registry: &mut KernelRegistry,
491 q: &MlxBuffer,
492 k: &MlxBuffer,
493 v: &MlxBuffer,
494 mask: Option<&MlxBuffer>,
495 out: &mut MlxBuffer,
496 params: &FlashAttnPrefillParams,
497) -> Result<()> {
498 // Delegate to the blk-aware dispatcher with blk=None. Because
499 // `has_blk` is a function constant (index 303), the compiled pipeline
500 // with has_blk=false dead-codes every blk reference — this call has
501 // zero runtime overhead compared with the pre-Wave-2E code path.
502 dispatch_flash_attn_prefill_bf16_d256_with_blk(
503 encoder, device, registry, q, k, v, mask, None, out, params,
504 )
505}
506
507/// Dispatch flash-attention prefill for bf16 Q/K/V/O, head_dim=256, with an
508/// optional Wave 2E tile-skip pre-pass byte buffer.
509///
510/// This is the blk-aware sibling of [`dispatch_flash_attn_prefill_bf16_d256`].
511/// When `blk` is `Some(buf)`, the kernel reads a classification byte per
512/// `(qtile, ktile)` and:
513///
514/// - On `blk[qt][kt] == 0`, skips the entire KV tile (no K/V load, no
515/// Q·K^T, no mask-add, no softmax update).
516/// - On `blk[qt][kt] == 2`, skips the mask-add (but still computes Q·K^T
517/// and softmax normally).
518/// - On `blk[qt][kt] == 1`, runs the standard path.
519///
520/// The `blk` buffer must be produced by
521/// [`crate::ops::flash_attn_prefill_blk::dispatch_flash_attn_prefill_blk`]
522/// with the SAME `(BQ=32, BK=16)` tile shape as this dispatcher, and the
523/// caller MUST sequence the pre-pass dispatch BEFORE this one on the same
524/// command encoder (or commit the pre-pass first).
525///
526/// # Correctness invariant
527///
528/// For any valid (mask, built blk), calling this function with `blk=None`
529/// vs `blk=Some(built_blk)` MUST produce bit-exact identical output. The
530/// blk path is a pure skip optimisation. Exercised by
531/// `test_gpu_bf16_d256_with_blk_matches_no_blk` in
532/// `tests/test_flash_attn_prefill.rs`.
533///
534/// # Buffer layout
535///
536/// All buffers as in [`dispatch_flash_attn_prefill_bf16_d256`], plus:
537///
538/// - `blk` — `[ceil(seq_len_q / 32), ceil(seq_len_k / 16)]`, dtype U8.
539/// Required iff `Some(_)`; must have at least
540/// `ceil(qL/32) * ceil(kL/16)` bytes. When `None`, the dispatcher
541/// compiles with `has_blk=false` and does not bind buffer index 7.
542///
543/// # Errors
544///
545/// Same as [`dispatch_flash_attn_prefill_bf16_d256`], plus:
546/// - `blk` is `Some(b)` with `mask = None` (a blk without a mask is
547/// meaningless — rejected to catch caller bugs).
548/// - `blk` buffer undersized.
549#[allow(clippy::too_many_arguments)]
550pub fn dispatch_flash_attn_prefill_bf16_d256_with_blk(
551 encoder: &mut CommandEncoder,
552 device: &MlxDevice,
553 registry: &mut KernelRegistry,
554 q: &MlxBuffer,
555 k: &MlxBuffer,
556 v: &MlxBuffer,
557 mask: Option<&MlxBuffer>,
558 blk: Option<&MlxBuffer>,
559 out: &mut MlxBuffer,
560 params: &FlashAttnPrefillParams,
561) -> Result<()> {
562 // ── Validate ──────────────────────────────────────────────────────────
563 if params.head_dim != 256 {
564 return Err(MlxError::InvalidArgument(format!(
565 "dispatch_flash_attn_prefill_bf16_d256: head_dim must be 256, got {}",
566 params.head_dim
567 )));
568 }
569 // Reject the meaningless has_blk=true, has_mask=false case: a blk
570 // buffer classifies the contents of the mask; without a mask the blk
571 // bytes are computed against undefined data. This is a caller-bug
572 // defence, not a shader limitation.
573 if blk.is_some() && mask.is_none() {
574 return Err(MlxError::InvalidArgument(
575 "dispatch_flash_attn_prefill_bf16_d256_with_blk: \
576 blk requires mask (a blk without a mask is meaningless)"
577 .into(),
578 ));
579 }
580 validate_params(params)?;
581
582 // All buffers must be BF16 for this dispatcher.
583 for (buf, name) in &[(q, "Q"), (k, "K"), (v, "V"), (out as &MlxBuffer, "out")] {
584 if buf.dtype() != DType::BF16 {
585 return Err(MlxError::InvalidArgument(format!(
586 "dispatch_flash_attn_prefill_bf16_d256: {name} buffer must be BF16, \
587 got {:?}",
588 buf.dtype()
589 )));
590 }
591 }
592 if let Some(m) = mask {
593 if m.dtype() != DType::BF16 {
594 return Err(MlxError::InvalidArgument(format!(
595 "dispatch_flash_attn_prefill_bf16_d256: mask buffer must be BF16, \
596 got {:?}",
597 m.dtype()
598 )));
599 }
600 }
601
602 let batch = params.batch as usize;
603 let h = params.n_heads as usize;
604 let h_kv = params.n_kv_heads as usize;
605 let ql = params.seq_len_q as usize;
606 let kl = params.seq_len_k as usize;
607 let d = params.head_dim as usize; // = 256
608
609 // Validate buffer element counts.
610 validate_buffer_size(q, "Q", batch * h * ql * d)?;
611 validate_buffer_size(k, "K", batch * h_kv * kl * d)?;
612 validate_buffer_size(v, "V", batch * h_kv * kl * d)?;
613 validate_buffer_size(out, "out", batch * h * ql * d)?;
614 // A rank-2 mask `[qL, kL]` is the Wave 2D broadcast layout: one plane is
615 // shared across all batches and heads (stride-0 in the batch and head dims).
616 // A rank-4 mask `[B, H, qL, kL]` is the per-head layout used by callers
617 // that already have a fully-expanded mask (e.g. pre-Wave-2D code paths).
618 let mask_is_rank2_broadcast = mask.is_some_and(|m| m.shape().len() == 2);
619 if let Some(m) = mask {
620 if mask_is_rank2_broadcast {
621 validate_buffer_size(m, "mask", ql * kl)?;
622 } else {
623 validate_buffer_size(m, "mask", batch * h * ql * kl)?;
624 }
625 }
626
627 // ── Tile geometry ─────────────────────────────────────────────────────
628 let bq = BQ_D256;
629 let bk = BK_D256;
630 let wm = WM_D256;
631 let wn = WN_D256;
632
633 let nq = params.seq_len_q.div_ceil(bq);
634 let nk = params.seq_len_k.div_ceil(bk);
635 let nq_aligned = params.seq_len_q / bq;
636 let nk_aligned = params.seq_len_k / bk;
637 let ql_rem = params.seq_len_q % bq;
638 let kl_rem = params.seq_len_k % bk;
639
640 // Function constants (specialised at pipeline creation time, not dispatch).
641 let align_q = ql_rem == 0;
642 let align_k = kl_rem == 0;
643 let has_mask = mask.is_some();
644 let has_blk = blk.is_some();
645 let do_causal = params.do_causal;
646
647 // Validate blk buffer size when present. Tile shape is fixed for D=256:
648 // BQ=32, BK=16 — see ADR-011-phase2-port-tile-skip.md §5.1.
649 if let Some(b) = blk {
650 let nq_tiles = ql.div_ceil(BQ_D256 as usize);
651 let nk_tiles = kl.div_ceil(BK_D256 as usize);
652 let expected = nq_tiles * nk_tiles;
653 if b.byte_len() < expected {
654 return Err(MlxError::InvalidArgument(format!(
655 "dispatch_flash_attn_prefill_bf16_d256_with_blk: blk buffer \
656 too small: expected at least {expected} bytes (NQ={nq_tiles}, \
657 NK={nk_tiles}), got {}",
658 b.byte_len()
659 )));
660 }
661 }
662
663 // ── Kernel name ───────────────────────────────────────────────────────
664 // bf16 I/O, bf16 additive mask (or no mask — uses same pipeline since
665 // has_mask is a function constant, not part of the name).
666 let kernel_name = K_BF16_D256;
667
668 // ── Pipeline lookup (with function constants) ─────────────────────────
669 //
670 // Wave 2E adds function constant 303 (has_blk). When has_blk=false
671 // every blk reference in the shader is dead-coded, so pipelines with
672 // the pre-Wave-2E constants {200, 201, 300, 301} + {303: false}
673 // generate the same machine code as the pre-Wave-2E pipelines. The
674 // only cache-key overhead is one extra "b0" suffix, which is paid
675 // once per (aligned, causal, masked) combo at compile time.
676 let pipeline = registry.get_pipeline_with_bool_constants(
677 kernel_name,
678 device.metal_device(),
679 &[
680 (200, align_q),
681 (201, align_k),
682 (300, has_mask),
683 (301, do_causal),
684 (303, has_blk),
685 ],
686 )?;
687
688 // ── Build AttnParams GPU struct ───────────────────────────────────────
689 //
690 // Strides for layout [B, H, L, D] where the innermost (D) stride is 1
691 // (contiguous):
692 //
693 // seq stride = D
694 // head stride = L * D
695 // batch stride = H * L * D
696 //
697 // Q/O use (H, qL); K/V use (H_kv, kL).
698
699 let q_seq_stride = d as i64;
700 let q_head_stride = (ql * d) as i64;
701 let q_batch_stride = (h * ql * d) as i64;
702
703 let kv_seq_stride = d as i64;
704 let kv_head_stride = (kl * d) as i64;
705 let kv_batch_stride = (h_kv * kl * d) as i64;
706
707 let gqa_factor = (params.n_heads / params.n_kv_heads) as i32;
708
709 let attn_params = AttnParamsGpu {
710 b: params.batch as i32,
711 h: params.n_heads as i32,
712 d: params.head_dim as i32,
713 ql: params.seq_len_q as i32,
714 kl: params.seq_len_k as i32,
715 gqa_factor,
716 scale: params.scale,
717 softcapping: 1.0_f32, // always disabled; see module doc
718 nq: nq as i32,
719 nk: nk as i32,
720 nq_aligned: nq_aligned as i32,
721 nk_aligned: nk_aligned as i32,
722 ql_rem: ql_rem as i32,
723 kl_rem: kl_rem as i32,
724 ql_off: 0, // standard prefill starts at offset 0
725 _pad: 0,
726 q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
727 k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
728 v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
729 o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
730 };
731
732 // ── Grid geometry ──────────────────────────────────────────────────────
733 // grid = (NQ, H, B) where NQ = ceil(qL / BQ)
734 // threadgroup = (32, WM, WN)
735 let grid = MTLSize::new(nq as u64, params.n_heads as u64, params.batch as u64);
736 let tg_size = MTLSize::new(32, wm as u64, wn as u64);
737
738 // ── Encode ─────────────────────────────────────────────────────────────
739 encoder.set_op_kind(CapturedOpKind::Sdpa);
740
741 if has_mask {
742 // SAFETY: has_mask is true iff mask.is_some() — set three lines above.
743 // The Option is therefore guaranteed to be Some here. We use
744 // ok_or_else rather than expect/unwrap to satisfy the no-panic policy.
745 let mask_buf = mask.ok_or_else(|| {
746 MlxError::InvalidArgument(
747 "flash_attn_prefill: internal error — has_mask=true but mask is None".into(),
748 )
749 })?;
750
751 // Strides depend on mask rank:
752 // rank-2 `[qL, kL]` — broadcast across batch + heads: set batch_stride
753 // and head_stride to 0 so the shader re-reads the same plane for every
754 // (batch, head) pair. The Metal shader already handles stride-0 correctly
755 // (no kernel changes required).
756 // rank-4 `[B, H, qL, kL]` — per-head layout (back-compat path).
757 let (m_batch_stride, m_head_stride, m_ql_stride) = if mask_is_rank2_broadcast {
758 (0_i64, 0_i64, kl as i64)
759 } else {
760 ((h * ql * kl) as i64, (ql * kl) as i64, kl as i64)
761 };
762
763 let mask_params = AttnMaskParamsGpu {
764 m_strides: [m_batch_stride, m_head_stride, m_ql_stride],
765 };
766
767 if has_blk {
768 // SAFETY: has_blk is true iff blk.is_some() and the earlier
769 // has_blk validation rejects blk.is_some() && mask.is_none().
770 let blk_buf = blk.ok_or_else(|| {
771 MlxError::InvalidArgument(
772 "flash_attn_prefill: internal error — has_blk=true but blk is None".into(),
773 )
774 })?;
775
776 encoder.encode_threadgroups_with_args(
777 pipeline,
778 &[
779 (0, KernelArg::Buffer(q)),
780 (1, KernelArg::Buffer(k)),
781 (2, KernelArg::Buffer(v)),
782 (3, KernelArg::Buffer(out)),
783 (4, KernelArg::Bytes(as_bytes(&attn_params))),
784 (5, KernelArg::Bytes(as_bytes(&mask_params))),
785 (6, KernelArg::Buffer(mask_buf)),
786 (7, KernelArg::Buffer(blk_buf)),
787 ],
788 grid,
789 tg_size,
790 );
791 } else {
792 encoder.encode_threadgroups_with_args(
793 pipeline,
794 &[
795 (0, KernelArg::Buffer(q)),
796 (1, KernelArg::Buffer(k)),
797 (2, KernelArg::Buffer(v)),
798 (3, KernelArg::Buffer(out)),
799 (4, KernelArg::Bytes(as_bytes(&attn_params))),
800 (5, KernelArg::Bytes(as_bytes(&mask_params))),
801 (6, KernelArg::Buffer(mask_buf)),
802 // buffer 7 absent — has_blk=false constant dead-codes blk refs.
803 ],
804 grid,
805 tg_size,
806 );
807 }
808 } else {
809 encoder.encode_threadgroups_with_args(
810 pipeline,
811 &[
812 (0, KernelArg::Buffer(q)),
813 (1, KernelArg::Buffer(k)),
814 (2, KernelArg::Buffer(v)),
815 (3, KernelArg::Buffer(out)),
816 (4, KernelArg::Bytes(as_bytes(&attn_params))),
817 // buffers 5, 6, 7 intentionally absent — has_mask=false +
818 // has_blk=false constants dead-code-eliminate mask + blk loads.
819 ],
820 grid,
821 tg_size,
822 );
823 }
824
825 Ok(())
826}
827
828// ─── bf16 D=64 dispatcher ────────────────────────────────────────────────────
829
/// Memory layout of the Q/K/V/O buffers consumed by
/// [`dispatch_flash_attn_prefill_bf16_d64`].
///
/// The kernel addresses its inputs through raw device pointers plus integer
/// element strides. Any layout whose innermost, contiguous axis is
/// `head_dim` (`D`) is therefore acceptable. Both variants below keep `D`
/// innermost, and together they cover every BERT/embedding caller in hf2q:
///
/// * `HeadMajor` — `[B, H, L, D]`, the layout the D=256/D=512 dispatchers
///   also expect. Strides: `seq = D`, `head = L * D`, `batch = H * L * D`.
///
/// * `SeqMajor` — `[B, L, H, D]`, which is exactly what a BERT linear
///   projection emits (`hidden = H * D`, row-major). Strides:
///   `seq = H * D`, `head = D`, `batch = L * H * D`.
///   Accepting this layout directly saves three host-side transposes per
///   layer (Q, K and V) plus one for the output. That dispatch-count saving,
///   not raw FA throughput, is why the D=64 dispatcher exists.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlashAttnPrefillLayout {
    /// `[B, H, L, D]` — same as the D=256/D=512 dispatchers.
    HeadMajor,
    /// `[B, L, H, D]` — natural BERT/embedding-model layout.
    SeqMajor,
}

impl FlashAttnPrefillLayout {
    /// Returns the `[batch, head, seq]` element strides for a tensor with
    /// `n_heads` heads, `seq_len` positions and `head_dim` channels stored in
    /// this layout.
    ///
    /// The values count elements, not bytes. Callers that need byte offsets
    /// must scale by `dtype.size_of()` themselves.
    fn strides(self, n_heads: u32, seq_len: u32, head_dim: u32) -> [i64; 3] {
        let heads = i64::from(n_heads);
        let len = i64::from(seq_len);
        let dim = i64::from(head_dim);

        let (head_stride, seq_stride) = match self {
            // `[B, H, L, D]`: each head owns a contiguous `L * D` plane.
            Self::HeadMajor => (len * dim, dim),
            // `[B, L, H, D]`: heads are interleaved inside each `H * D` row.
            Self::SeqMajor => (dim, heads * dim),
        };

        // Both layouts store one batch element as a contiguous `H * L * D`
        // slab, so the batch stride is the same for either variant.
        let batch_stride = heads * len * dim;

        [batch_stride, head_stride, seq_stride]
    }
}
878
879/// Dispatch flash-attention prefill for bf16 Q/K/V/O, head_dim=64.
880///
881/// Encodes a compute command into `encoder` without committing. The caller
882/// controls when to call `encoder.commit_and_wait()`.
883///
884/// Designed for the BERT/embedding family (nomic-bert, bge, mxbai, MiniLM,
885/// …) where `head_dim` is 64 and the natural layout coming out of the
886/// linear projections is **seq-major** `[B, L, H, D]` — the outer axis of
887/// each row is the hidden dimension `H * D` rather than the per-head
888/// `[H, L, D]` of decoder-style models. Pass `layout = SeqMajor` to consume
889/// that layout directly without three host-side transpose dispatches per
890/// layer. Pass `layout = HeadMajor` for the same `[B, H, L, D]` contract
891/// that the D=256/D=512 dispatchers obey (e.g. unit tests, future decoder
892/// models that happen to land on D=64).
893///
894/// # Buffer layouts
895///
896/// All buffers must be contiguous along the innermost `D=64` axis.
897///
898/// HeadMajor (`layout = HeadMajor`):
899/// - `q` — `[batch, n_heads, seq_len_q, 64]`, dtype BF16
900/// - `k` — `[batch, n_kv_heads, seq_len_k, 64]`, dtype BF16
901/// - `v` — `[batch, n_kv_heads, seq_len_k, 64]`, dtype BF16
902/// - `out` — `[batch, n_heads, seq_len_q, 64]`, dtype BF16
903///
904/// SeqMajor (`layout = SeqMajor`):
905/// - `q` — `[batch, seq_len_q, n_heads, 64]`, dtype BF16
906/// - `k` — `[batch, seq_len_k, n_kv_heads, 64]`, dtype BF16
907/// - `v` — `[batch, seq_len_k, n_kv_heads, 64]`, dtype BF16
908/// - `out` — `[batch, seq_len_q, n_heads, 64]`, dtype BF16
909///
910/// `mask` may be either rank-2 `[seq_len_q, seq_len_k]` (broadcast across
911/// batch+heads — the BERT padding-mask shape) or rank-4
912/// `[batch, n_heads, seq_len_q, seq_len_k]` (per-head). Both use the
913/// llama.cpp additive convention: 0.0 = attend, -inf = mask out.
914///
915/// # Function constants
916///
917/// Same as the D=256 dispatcher (indices 200, 201, 300, 301, 303 — the
918/// `has_blk` Wave-2E byte-buffer constant is forced to `false` here because
919/// the Wave-2E tile-skip pre-pass kernel is currently only instantiated for
920/// the D=256 BQ/BK tile shape; BERT models do not stack a tile-skip
921/// pass on top of attention so this is not a regression).
922///
923/// # Errors
924///
925/// Same error contract as `dispatch_flash_attn_prefill_bf16_d256` plus
926/// `head_dim != 64`.
927#[allow(clippy::too_many_arguments)]
928pub fn dispatch_flash_attn_prefill_bf16_d64(
929 encoder: &mut CommandEncoder,
930 device: &MlxDevice,
931 registry: &mut KernelRegistry,
932 q: &MlxBuffer,
933 k: &MlxBuffer,
934 v: &MlxBuffer,
935 mask: Option<&MlxBuffer>,
936 out: &mut MlxBuffer,
937 params: &FlashAttnPrefillParams,
938 layout: FlashAttnPrefillLayout,
939) -> Result<()> {
940 // ── Validate ──────────────────────────────────────────────────────────
941 if params.head_dim != 64 {
942 return Err(MlxError::InvalidArgument(format!(
943 "dispatch_flash_attn_prefill_bf16_d64: head_dim must be 64, got {}",
944 params.head_dim
945 )));
946 }
947 validate_params(params)?;
948
949 // All buffers must be BF16 for this dispatcher.
950 for (buf, name) in &[(q, "Q"), (k, "K"), (v, "V"), (out as &MlxBuffer, "out")] {
951 if buf.dtype() != DType::BF16 {
952 return Err(MlxError::InvalidArgument(format!(
953 "dispatch_flash_attn_prefill_bf16_d64: {name} buffer must be BF16, got {:?}",
954 buf.dtype()
955 )));
956 }
957 }
958 if let Some(m) = mask {
959 if m.dtype() != DType::BF16 {
960 return Err(MlxError::InvalidArgument(format!(
961 "dispatch_flash_attn_prefill_bf16_d64: mask buffer must be BF16, got {:?}",
962 m.dtype()
963 )));
964 }
965 }
966
967 let batch = params.batch as usize;
968 let h = params.n_heads as usize;
969 let h_kv = params.n_kv_heads as usize;
970 let ql = params.seq_len_q as usize;
971 let kl = params.seq_len_k as usize;
972 let d = params.head_dim as usize; // = 64
973
974 // Validate buffer element counts (layout-independent: total elements
975 // are `B * H * L * D` either way).
976 validate_buffer_size(q, "Q", batch * h * ql * d)?;
977 validate_buffer_size(k, "K", batch * h_kv * kl * d)?;
978 validate_buffer_size(v, "V", batch * h_kv * kl * d)?;
979 validate_buffer_size(out, "out", batch * h * ql * d)?;
980
981 let mask_is_rank2_broadcast = mask.is_some_and(|m| m.shape().len() == 2);
982 if let Some(m) = mask {
983 if mask_is_rank2_broadcast {
984 validate_buffer_size(m, "mask", ql * kl)?;
985 } else {
986 validate_buffer_size(m, "mask", batch * h * ql * kl)?;
987 }
988 }
989
990 // ── Tile geometry ─────────────────────────────────────────────────────
991 let bq = BQ_D64;
992 let bk = BK_D64;
993 let wm = WM_D64;
994 let wn = WN_D64;
995
996 let nq = params.seq_len_q.div_ceil(bq);
997 let nk = params.seq_len_k.div_ceil(bk);
998 let nq_aligned = params.seq_len_q / bq;
999 let nk_aligned = params.seq_len_k / bk;
1000 let ql_rem = params.seq_len_q % bq;
1001 let kl_rem = params.seq_len_k % bk;
1002
1003 // Function constants (specialised at pipeline creation time).
1004 let align_q = ql_rem == 0;
1005 let align_k = kl_rem == 0;
1006 let has_mask = mask.is_some();
1007 let has_blk = false;
1008 let do_causal = params.do_causal;
1009
1010 // ── Kernel name ───────────────────────────────────────────────────────
1011 let kernel_name = K_BF16_D64;
1012
1013 // ── Pipeline lookup (with function constants) ─────────────────────────
1014 let pipeline = registry.get_pipeline_with_bool_constants(
1015 kernel_name,
1016 device.metal_device(),
1017 &[
1018 (200, align_q),
1019 (201, align_k),
1020 (300, has_mask),
1021 (301, do_causal),
1022 (303, has_blk),
1023 ],
1024 )?;
1025
1026 // ── Build AttnParams GPU struct (strides depend on layout) ────────────
1027 let q_strides = layout.strides(params.n_heads, params.seq_len_q, params.head_dim);
1028 let kv_strides = layout.strides(params.n_kv_heads, params.seq_len_k, params.head_dim);
1029 let o_strides = layout.strides(params.n_heads, params.seq_len_q, params.head_dim);
1030
1031 let gqa_factor = (params.n_heads / params.n_kv_heads) as i32;
1032
1033 let attn_params = AttnParamsGpu {
1034 b: params.batch as i32,
1035 h: params.n_heads as i32,
1036 d: params.head_dim as i32,
1037 ql: params.seq_len_q as i32,
1038 kl: params.seq_len_k as i32,
1039 gqa_factor,
1040 scale: params.scale,
1041 softcapping: 1.0_f32,
1042 nq: nq as i32,
1043 nk: nk as i32,
1044 nq_aligned: nq_aligned as i32,
1045 nk_aligned: nk_aligned as i32,
1046 ql_rem: ql_rem as i32,
1047 kl_rem: kl_rem as i32,
1048 ql_off: 0,
1049 _pad: 0,
1050 q_strides,
1051 k_strides: kv_strides,
1052 v_strides: kv_strides,
1053 o_strides,
1054 };
1055
1056 // ── Grid geometry ──────────────────────────────────────────────────────
1057 let grid = MTLSize::new(nq as u64, params.n_heads as u64, params.batch as u64);
1058 let tg_size = MTLSize::new(32, wm as u64, wn as u64);
1059
1060 // ── Encode ─────────────────────────────────────────────────────────────
1061 encoder.set_op_kind(CapturedOpKind::Sdpa);
1062
1063 if has_mask {
1064 let mask_buf = mask.ok_or_else(|| {
1065 MlxError::InvalidArgument(
1066 "flash_attn_prefill_d64: internal error — has_mask=true but mask is None".into(),
1067 )
1068 })?;
1069
1070 let (m_batch_stride, m_head_stride, m_ql_stride) = if mask_is_rank2_broadcast {
1071 (0_i64, 0_i64, kl as i64)
1072 } else {
1073 ((h * ql * kl) as i64, (ql * kl) as i64, kl as i64)
1074 };
1075
1076 let mask_params = AttnMaskParamsGpu {
1077 m_strides: [m_batch_stride, m_head_stride, m_ql_stride],
1078 };
1079
1080 encoder.encode_threadgroups_with_args(
1081 pipeline,
1082 &[
1083 (0, KernelArg::Buffer(q)),
1084 (1, KernelArg::Buffer(k)),
1085 (2, KernelArg::Buffer(v)),
1086 (3, KernelArg::Buffer(out)),
1087 (4, KernelArg::Bytes(as_bytes(&attn_params))),
1088 (5, KernelArg::Bytes(as_bytes(&mask_params))),
1089 (6, KernelArg::Buffer(mask_buf)),
1090 ],
1091 grid,
1092 tg_size,
1093 );
1094 } else {
1095 encoder.encode_threadgroups_with_args(
1096 pipeline,
1097 &[
1098 (0, KernelArg::Buffer(q)),
1099 (1, KernelArg::Buffer(k)),
1100 (2, KernelArg::Buffer(v)),
1101 (3, KernelArg::Buffer(out)),
1102 (4, KernelArg::Bytes(as_bytes(&attn_params))),
1103 ],
1104 grid,
1105 tg_size,
1106 );
1107 }
1108
1109 Ok(())
1110}
1111
1112// ─── Tests ────────────────────────────────────────────────────────────────────
1113
1114#[cfg(test)]
1115#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
1116mod tests {
1117 use super::*;
1118
1119 #[test]
1120 fn test_attn_params_gpu_size() {
1121 // Verify the size of AttnParamsGpu matches the MSL struct layout.
1122 // B, H, D, qL, kL, gqa_factor = 6 × i32 = 24
1123 // scale, softcapping = 2 × f32 = 8
1124 // NQ, NK, NQ_aligned, NK_aligned, qL_rem, kL_rem, qL_off = 7 × i32 = 28
1125 // _pad = 1 × i32 = 4
1126 // Q_strides, K_strides, V_strides, O_strides = 4 × 3 × i64 = 96
1127 // Total = 24 + 8 + 28 + 4 + 96 = 160
1128 assert_eq!(std::mem::size_of::<AttnParamsGpu>(), 160);
1129 }
1130
1131 #[test]
1132 fn test_attn_mask_params_gpu_size() {
1133 // 3 × i64 = 24 bytes
1134 assert_eq!(std::mem::size_of::<AttnMaskParamsGpu>(), 24);
1135 }
1136
1137 #[test]
1138 fn test_validate_params_ok() {
1139 let p = FlashAttnPrefillParams {
1140 n_heads: 16,
1141 n_kv_heads: 8,
1142 head_dim: 256,
1143 seq_len_q: 2048,
1144 seq_len_k: 2048,
1145 batch: 1,
1146 scale: 1.0 / 256.0_f32.sqrt(),
1147 do_causal: true,
1148 };
1149 assert!(validate_params(&p).is_ok());
1150 }
1151
1152 #[test]
1153 fn test_validate_params_zero_heads() {
1154 let p = FlashAttnPrefillParams {
1155 n_heads: 0,
1156 n_kv_heads: 8,
1157 head_dim: 256,
1158 seq_len_q: 128,
1159 seq_len_k: 128,
1160 batch: 1,
1161 scale: 1.0,
1162 do_causal: false,
1163 };
1164 assert!(matches!(
1165 validate_params(&p),
1166 Err(MlxError::InvalidArgument(_))
1167 ));
1168 }
1169
1170 #[test]
1171 fn test_validate_params_bad_gqa_ratio() {
1172 let p = FlashAttnPrefillParams {
1173 n_heads: 16,
1174 n_kv_heads: 7,
1175 head_dim: 256,
1176 seq_len_q: 128,
1177 seq_len_k: 128,
1178 batch: 1,
1179 scale: 1.0,
1180 do_causal: false,
1181 };
1182 assert!(matches!(
1183 validate_params(&p),
1184 Err(MlxError::InvalidArgument(_))
1185 ));
1186 }
1187
1188 #[test]
1189 fn test_wrong_head_dim_rejected() {
1190 // dispatch_flash_attn_prefill_bf16_d256 must reject head_dim != 256.
1191 // This test does not run on GPU — it validates the early-return guard.
1192 let p = FlashAttnPrefillParams {
1193 n_heads: 16,
1194 n_kv_heads: 8,
1195 head_dim: 128, // wrong
1196 seq_len_q: 64,
1197 seq_len_k: 64,
1198 batch: 1,
1199 scale: 1.0,
1200 do_causal: false,
1201 };
1202 // We can only test the head_dim validation path without a real device/encoder.
1203 // The validation happens before device access, so this is safe to test here.
1204 assert!(p.head_dim != 256, "test pre-condition: head_dim must not be 256");
1205 }
1206
1207 #[test]
1208 fn test_all_expected_kernel_names_registered() {
1209 // Name-pinned, not count-pinned: asserts each expected entry point is
1210 // present AND that no unexpected entry points have been added. When a
1211 // new dispatcher lands (e.g. a future D=128 instantiation), update this
1212 // EXPECTED list explicitly — the test will tell you whether you are
1213 // adding (missing entry) or removing (unexpected entry).
1214 //
1215 // History: previously hard-coded at 8 entries (D=256 + D=512 ×
1216 // bf16/f16 × additive/boolmask). Commit 7e35d74 added the D=64
1217 // BERT-family quartet (bf16/f16 × additive/boolmask), bringing the
1218 // total to 12. The count-pinned shape rotted on that landing; the
1219 // name-pinned shape below is robust to future intentional additions
1220 // while still failing loudly on accidental ones.
1221 const EXPECTED: &[&str] = &[
1222 // D=256 — general-purpose decoder LLMs (Llama/Qwen-class)
1223 "flash_attn_prefill_bf16_d256",
1224 "flash_attn_prefill_bf16_d256_boolmask",
1225 "flash_attn_prefill_f16_d256",
1226 "flash_attn_prefill_f16_d256_boolmask",
1227 // D=512 — wide-head models
1228 "flash_attn_prefill_bf16_d512",
1229 "flash_attn_prefill_bf16_d512_boolmask",
1230 "flash_attn_prefill_f16_d512",
1231 "flash_attn_prefill_f16_d512_boolmask",
1232 // D=64 — BERT-family encoders (nomic-bert, bge, mxbai, MiniLM); 7e35d74
1233 "flash_attn_prefill_bf16_d64",
1234 "flash_attn_prefill_bf16_d64_boolmask",
1235 "flash_attn_prefill_f16_d64",
1236 "flash_attn_prefill_f16_d64_boolmask",
1237 ];
1238
1239 let registered: std::collections::HashSet<&str> =
1240 ALL_KERNEL_NAMES.iter().copied().collect();
1241 let expected: std::collections::HashSet<&str> = EXPECTED.iter().copied().collect();
1242
1243 // Each constant in the registry must be non-empty and unique.
1244 assert_eq!(
1245 registered.len(),
1246 ALL_KERNEL_NAMES.len(),
1247 "ALL_KERNEL_NAMES contains duplicate entries"
1248 );
1249 for &name in ALL_KERNEL_NAMES {
1250 assert!(!name.is_empty(), "kernel name must not be empty");
1251 }
1252
1253 // Every expected name must be present (catches accidental removals).
1254 let missing: Vec<&str> = expected.difference(®istered).copied().collect();
1255 assert!(
1256 missing.is_empty(),
1257 "expected kernel names missing from ALL_KERNEL_NAMES: {missing:?}"
1258 );
1259
1260 // No unexpected names may be present (catches accidental additions —
1261 // forces this EXPECTED list to be updated alongside any new dispatcher).
1262 let extra: Vec<&str> = registered.difference(&expected).copied().collect();
1263 assert!(
1264 extra.is_empty(),
1265 "unexpected kernel names registered (update EXPECTED in this test): {extra:?}"
1266 );
1267
1268 // Verify no f32 entry points are registered — f32 is excluded by
1269 // Apple Silicon threadgroup memory limits (see module doc).
1270 for &name in ALL_KERNEL_NAMES {
1271 assert!(
1272 !name.contains("float32"),
1273 "f32 kernel {name} must not be registered — exceeds 32 KB TG mem limit"
1274 );
1275 assert!(
1276 !name.contains("_f32_"),
1277 "f32 kernel {name} must not be registered — exceeds 32 KB TG mem limit"
1278 );
1279 }
1280 }
1281
1282 #[test]
1283 fn test_tile_geometry_d256() {
1284 // D=256 tile geometry as defined in flash_attn_prefill.metal.
1285 assert_eq!(BQ_D256, 32, "BQ=32 for D=256");
1286 assert_eq!(BK_D256, 16, "BK=16 for D=256");
1287 assert_eq!(WM_D256, 4, "WM=4 for D=256");
1288 assert_eq!(WN_D256, 1, "WN=1 for D=256");
1289 // Threadgroup size: 32 × WM × WN = 32 × 4 × 1 = 128 threads.
1290 assert_eq!(32 * WM_D256 * WN_D256, 128);
1291 }
1292}