mlx_native/ops/flash_attn_prefill.rs
//! Flash-attention-style tiled prefill kernel — host dispatch.
//!
//! mlx-native's batched-prefill SDPA kernel, the prefill counterpart to
//! `flash_attn_vec` (which handles the seq_len=1 decode case). Implements
//! online softmax + simdgroup MMA tiling on Apple GPUs.
//!
//! ## Kernel variants registered
//!
//! Twelve entry points are registered, all backed by the single
//! `flash_attn_prefill.metal` shader source:
//!
//! ### D=64 (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup, BERT family)
//!
//! | Kernel name | I/O dtype | Mask kind |
//! |---|---|---|
//! | `flash_attn_prefill_bf16_d64` | bf16 | bf16 additive |
//! | `flash_attn_prefill_bf16_d64_boolmask` | bf16 | bool |
//! | `flash_attn_prefill_f16_d64` | f16 | f16 additive |
//! | `flash_attn_prefill_f16_d64_boolmask` | f16 | bool |
//!
//! ### D=256 (BQ=32, BK=16, WM=4, WN=1 — 128 threads/threadgroup)
//!
//! | Kernel name | I/O dtype | Mask kind |
//! |---|---|---|
//! | `flash_attn_prefill_bf16_d256` | bf16 | bf16 additive (log-domain) |
//! | `flash_attn_prefill_bf16_d256_boolmask` | bf16 | bool (`is_attended`) |
//! | `flash_attn_prefill_f16_d256` | f16 | f16 additive |
//! | `flash_attn_prefill_f16_d256_boolmask` | f16 | bool |
//!
//! ### D=512 (BQ=8, BK=8, WM=1, WN=1 — 32 threads/threadgroup, 1 simdgroup)
//!
//! | Kernel name | I/O dtype | Mask kind |
//! |---|---|---|
//! | `flash_attn_prefill_bf16_d512` | bf16 | bf16 additive |
//! | `flash_attn_prefill_bf16_d512_boolmask` | bf16 | bool |
//! | `flash_attn_prefill_f16_d512` | f16 | f16 additive |
//! | `flash_attn_prefill_f16_d512_boolmask` | f16 | bool |
//!
//! ### f32 is NOT instantiated at any D
//!
//! The f32 Qs threadgroup tile alone is `BQ * BD * 4` bytes — at D=256 this
//! is 32 KB exactly, the Apple Silicon `MTLDevice.maxThreadgroupMemoryLength`
//! hard limit, before KV_smem or scratch. Verified empirically on M5 Max:
//! f32 D=256 requires ~53.7 KB and fails at library compile. bf16 and f16
//! halve the tile footprint (~29 KB) and fit within the limit. D=512 f32
//! is excluded for the same reason. D=64 inherits the same exclusion for
//! consistency (bf16/f16 only across the entire dispatcher family). f32
//! correctness is verified at the CPU reference layer in
//! `tests/test_flash_attn_prefill.rs`. See
//! `ADR-011-phase1-port-source-decision.md` §3 for the full
//! threadgroup-memory analysis.
//!
//! ## Function constants
//!
//! The kernel declares five Metal function constants that must be specialised
//! at pipeline creation time (not at dispatch time):
//!
//! - Index 200: `align_Q` (bool) — true when `qL % BQ == 0`
//! - Index 201: `align_K` (bool) — true when `kL % BK == 0`
//! - Index 300: `has_mask` (bool) — true when a mask buffer is bound
//! - Index 301: `do_causal` (bool) — true for in-kernel causal masking
//! - Index 303: `has_blk` (bool) — true when a Wave-2E blk tile-skip buffer
//!   is bound (see `dispatch_flash_attn_prefill_bf16_d256_with_blk`)
//!
//! These are plumbed via [`KernelRegistry::get_pipeline_with_bool_constants`],
//! which caches compiled pipelines keyed by `(kernel_name, align_Q, align_K,
//! has_mask, do_causal, has_blk)`. Pipeline compilation is amortised: the
//! slow path runs only once per unique `(name, booleans)` combination.
//!
//! ## Buffer layout (indices match the MSL kernel)
//!
//! - `buffer(0)` — Q `[B, H, qL, D]` device, contiguous inner dim
//! - `buffer(1)` — K `[B, H_kv, kL, D]` device, contiguous inner dim
//! - `buffer(2)` — V `[B, H_kv, kL, D]` device, contiguous inner dim
//! - `buffer(3)` — O `[B, H, qL, D]` device, written by kernel
//! - `buffer(4)` — `AttnParams` constant buffer (this module's [`AttnParamsGpu`])
//! - `buffer(5)` — `AttnMaskParams` constant buffer (only when `has_mask=true`)
//! - `buffer(6)` — mask data buffer (only when `has_mask=true`)
//! - `buffer(7)` — blk tile-classification byte buffer (only when `has_blk=true`)
//!
//! ## Grid geometry
//!
//! - Threadgroups: `(ceil(qL / BQ), H, B)`
//! - Threads per threadgroup: `(32, WM, WN)`
//! - D=64, D=256: 128 threads (4 simdgroups × 32 lanes).
//! - D=512: 32 threads (1 simdgroup × 32 lanes).
//!
//! ## Scale convention
//!
//! Pass `scale = 1.0 / sqrt(head_dim)`. The kernel multiplies internally by
//! `log2(e) ≈ 1.44269504` and uses `fast::exp2` throughout — so the host
//! MUST NOT pre-multiply by `log2(e)`.
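//!
//! A minimal host-side sketch of the convention (illustrative; `head_dim` is
//! whatever the model supplies):
//!
//! ```rust,ignore
//! // Plain 1/sqrt(D) — the kernel folds in log2(e) itself.
//! let scale = 1.0_f32 / (head_dim as f32).sqrt();
//! ```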
//!
//! ## Mask-sentinel contract (llama.cpp convention)
//!
//! The additive mask buffer (bf16/f16 for the additive dispatchers) uses
//! the llama.cpp CPU-side convention: **masked positions = `-INFINITY`**
//! (IEEE-754 f32 `0xFF800000`, cast to the I/O dtype — both `half(-inf)`
//! and `bfloat16(-inf)` have a real `-inf` encoding that the kernel
//! consumes correctly). Attended positions = `0.0`.
//!
//! This matches llama.cpp's mask-authoring sites at
//! `llama-graph.cpp:421, 436, 557` and `llama-kv-cache.cpp:1572`, where
//! the CPU writes raw `-INFINITY` and the flash-attn path's cast to f16
//! carries it through as f16 `-INFINITY`.
//!
//! The kernel does NOT require masks to use `-FLT_MAX/2` or any other
//! finite "large negative" sentinel. The `-FLT_MAX/2` value is the
//! kernel-internal `M` (running row-max) initialiser — a finite sentinel
//! that absorbs `-inf` scores via `simd_max` without ever letting `M`
//! become `-inf`, which in turn lets every `exp(score - M)` evaluate as
//! `exp(-inf) = 0.0` (IEEE-754 exact) rather than `exp(-inf - (-inf)) =
//! exp(NaN) = NaN`. See ADR-011-phase2-port-sentinel.md §1.
//!
//! Callers may pass arbitrary finite additive biases in the mask for
//! non-masked positions (e.g. ALiBi, relative-position biases); only
//! "fully block this K position" requires `-inf`.
//!
//! ## See also
//!
//! - Kernel: `/opt/mlx-native/src/shaders/flash_attn_prefill.metal`
//! - ADR-011: `/opt/hf2q/docs/ADR-011-flash-attn-prefill.md`
//! - ADR-011 phase 2 sentinel port: `/opt/hf2q/docs/ADR-011-phase2-port-sentinel.md`

use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::encoder::{CapturedOpKind, CommandEncoder, KernelArg, as_bytes};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use crate::DType;

// ─── Shader source ───────────────────────────────────────────────────────────

/// MSL source for the flash-attention prefill kernel (embedded at compile time).
pub static FLASH_ATTN_PREFILL_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_prefill.metal");

// ─── All 12 kernel entry-point names ─────────────────────────────────────────

/// D=256, bf16 I/O, bf16 additive mask.
const K_BF16_D256: &str = "flash_attn_prefill_bf16_d256";
/// D=256, bf16 I/O, bool (`is_attended`) mask.
const K_BF16_D256_BOOLMASK: &str = "flash_attn_prefill_bf16_d256_boolmask";
/// D=256, f16 I/O, f16 additive mask.
const K_F16_D256: &str = "flash_attn_prefill_f16_d256";
/// D=256, f16 I/O, bool mask.
const K_F16_D256_BOOLMASK: &str = "flash_attn_prefill_f16_d256_boolmask";
/// D=512, bf16 I/O, bf16 additive mask.
const K_BF16_D512: &str = "flash_attn_prefill_bf16_d512";
/// D=512, bf16 I/O, bool mask.
const K_BF16_D512_BOOLMASK: &str = "flash_attn_prefill_bf16_d512_boolmask";
/// D=512, f16 I/O, f16 additive mask.
const K_F16_D512: &str = "flash_attn_prefill_f16_d512";
/// D=512, f16 I/O, bool mask.
const K_F16_D512_BOOLMASK: &str = "flash_attn_prefill_f16_d512_boolmask";
/// D=64, bf16 I/O, bf16 additive mask (BERT family).
const K_BF16_D64: &str = "flash_attn_prefill_bf16_d64";
/// D=64, bf16 I/O, bool (`is_attended`) mask.
const K_BF16_D64_BOOLMASK: &str = "flash_attn_prefill_bf16_d64_boolmask";
/// D=64, f16 I/O, f16 additive mask.
const K_F16_D64: &str = "flash_attn_prefill_f16_d64";
/// D=64, f16 I/O, bool mask.
const K_F16_D64_BOOLMASK: &str = "flash_attn_prefill_f16_d64_boolmask";

/// All 12 kernel entry-point names exported by `flash_attn_prefill.metal`.
///
/// Registering all of them against the single shader source costs nothing at
/// registration time (source is a static `&str`) and ensures additional
/// dispatchers (f16 paths, D=512 exposure) can be added later without
/// touching registration here.
const ALL_KERNEL_NAMES: &[&str] = &[
    K_BF16_D256,
    K_BF16_D256_BOOLMASK,
    K_F16_D256,
    K_F16_D256_BOOLMASK,
    K_BF16_D512,
    K_BF16_D512_BOOLMASK,
    K_F16_D512,
    K_F16_D512_BOOLMASK,
    K_BF16_D64,
    K_BF16_D64_BOOLMASK,
    K_F16_D64,
    K_F16_D64_BOOLMASK,
];

// ─── Registration ─────────────────────────────────────────────────────────────

/// Register all flash-attention prefill kernel entry points with the registry.
///
/// Maps all 12 entry-point names to the single `flash_attn_prefill.metal`
/// source. This must be called before any dispatch to these kernels.
///
/// # Design note
///
/// `KernelRegistry` compiles one Metal library per kernel name. All 12 names
/// point at the same source text, so the Metal compiler sees the same ~1500-line
/// source each time — compilation is amortised in `KernelRegistry::get_pipeline`
/// (first call per name triggers compilation; subsequent calls return the cached
/// pipeline). Registering all 12 here rather than only the Phase 1a subset
/// means Phase 2/4 dispatcher functions can be added without touching this file.
pub fn register(registry: &mut KernelRegistry) {
    for &name in ALL_KERNEL_NAMES {
        registry.register_source(name, FLASH_ATTN_PREFILL_SHADER_SOURCE);
    }
}

// ─── MSL struct mirrors ───────────────────────────────────────────────────────

/// Rust mirror of the MSL `AttnParams` struct.
///
/// Field order and types match the MSL definition exactly.
/// MSL source: `flash_attn_prefill.metal` — see the `AttnParams` struct
/// definition in the kernel source for the field-by-field reference.
///
/// # Layout
///
/// All `int` fields are 32-bit (i32 in Rust). The `int64_t` stride arrays are
/// 64-bit (i64 in Rust). The compiler inserts natural alignment padding:
///
/// ```text
/// Offset 0:   B (i32, 4 bytes)
/// Offset 4:   H (i32, 4 bytes)
/// Offset 8:   D (i32, 4 bytes)
/// Offset 12:  qL (i32, 4 bytes)
/// Offset 16:  kL (i32, 4 bytes)
/// Offset 20:  gqa_factor (i32, 4 bytes)
/// Offset 24:  scale (f32, 4 bytes)
/// Offset 28:  softcapping (f32, 4 bytes)
/// Offset 32:  NQ (i32, 4 bytes)
/// Offset 36:  NK (i32, 4 bytes)
/// Offset 40:  NQ_aligned (i32, 4 bytes)
/// Offset 44:  NK_aligned (i32, 4 bytes)
/// Offset 48:  qL_rem (i32, 4 bytes)
/// Offset 52:  kL_rem (i32, 4 bytes)
/// Offset 56:  qL_off (i32, 4 bytes)
/// Offset 60:  _pad (4 bytes — alignment before i64 array)
/// Offset 64:  Q_strides[3] (3 × i64, 24 bytes)
/// Offset 88:  K_strides[3] (3 × i64, 24 bytes)
/// Offset 112: V_strides[3] (3 × i64, 24 bytes)
/// Offset 136: O_strides[3] (3 × i64, 24 bytes)
/// Total:      160 bytes
/// ```
///
/// `bytemuck::Pod` / `bytemuck::Zeroable` are derived — the struct must have
/// no uninitialized padding bytes. The explicit `_pad` field makes the padding
/// concrete so Pod can be derived safely.
///
/// `softcapping` is always set to `1.0` (disabled). The `attention<>` kernel
/// body does not read it; the field exists for ABI parity with attention
/// implementations that thread softcapping through the same param block
/// (e.g. Gemma-style logit softcap), so we don't have to redo the layout
/// when that work lands.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
pub struct AttnParamsGpu {
    /// Batch size.
    pub b: i32,
    /// Number of query heads.
    pub h: i32,
    /// Head dimension (D).
    pub d: i32,
    /// Query sequence length.
    pub ql: i32,
    /// Key/value sequence length.
    pub kl: i32,
    /// Group-query attention factor: H / H_kv.
    pub gqa_factor: i32,
    /// Attention scale (= 1.0 / sqrt(head_dim); the kernel multiplies by log2(e)).
    pub scale: f32,
    /// Softcapping value — always 1.0 (disabled) for standard SDPA.
    pub softcapping: f32,
    /// Number of Q tiles: ceil(qL / BQ).
    pub nq: i32,
    /// Number of KV tiles: ceil(kL / BK).
    pub nk: i32,
    /// Number of full (aligned) Q tiles: qL / BQ.
    pub nq_aligned: i32,
    /// Number of full (aligned) KV tiles: kL / BK.
    pub nk_aligned: i32,
    /// Remainder elements in the last Q tile: qL % BQ (0 if aligned).
    pub ql_rem: i32,
    /// Remainder elements in the last KV tile: kL % BK (0 if aligned).
    pub kl_rem: i32,
    /// Query sequence start offset (0 for standard prefill).
    pub ql_off: i32,
    /// Explicit padding to align the subsequent i64 arrays to 8-byte boundary.
    pub _pad: i32,
    /// Query strides: (batch stride, head stride, seq stride). Inner dim = 1.
    pub q_strides: [i64; 3],
    /// Key strides: (batch stride, head stride, seq stride). Inner dim = 1.
    pub k_strides: [i64; 3],
    /// Value strides: (batch stride, head stride, seq stride). Inner dim = 1.
    pub v_strides: [i64; 3],
    /// Output strides: (batch stride, head stride, seq stride). Inner dim = 1.
    pub o_strides: [i64; 3],
}
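
// Compile-time guard for the layout documented above: 160 bytes total,
// 8-byte alignment. If a field is ever added or reordered, this fails at
// build time rather than producing a misaligned GPU-side read.
const _: () = {
    assert!(std::mem::size_of::<AttnParamsGpu>() == 160);
    assert!(std::mem::align_of::<AttnParamsGpu>() == 8);
};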

/// Rust mirror of the MSL `AttnMaskParams` struct.
///
/// MSL source: `flash_attn_prefill.metal` — see the `AttnMaskParams` struct
/// definition in the kernel source.
///
/// Contains the mask buffer strides. Only sent to the kernel when
/// `has_mask = true` (buffer index 5).
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
pub struct AttnMaskParamsGpu {
    /// Mask strides: (batch stride, head stride, qL stride). Inner dim = 1.
    pub m_strides: [i64; 3],
}

// ─── Public Rust-side parameter struct ───────────────────────────────────────

/// Host-side parameters for the flash-attention prefill dispatcher.
///
/// Used only by the Rust dispatcher; does not map 1:1 to the GPU struct —
/// the GPU struct is computed from these fields inside the dispatcher.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnPrefillParams {
    /// Number of query attention heads.
    pub n_heads: u32,
    /// Number of key/value attention heads (GQA: may be < n_heads).
    pub n_kv_heads: u32,
    /// Head dimension. Must be 256 for `dispatch_flash_attn_prefill_bf16_d256`.
    pub head_dim: u32,
    /// Query sequence length.
    pub seq_len_q: u32,
    /// Key/value sequence length.
    pub seq_len_k: u32,
    /// Batch size.
    pub batch: u32,
    /// Attention scale. Typically `1.0 / sqrt(head_dim)`.
    ///
    /// The kernel internally multiplies this by `log2(e) ≈ 1.44269504089`
    /// before applying it to Q. The host MUST NOT pre-multiply by `log2(e)`.
    pub scale: f32,
    /// Whether to apply in-kernel causal masking (`do_causal` function constant).
    ///
    /// When true, positions where `row_pos < col_pos` receive a score of -inf
    /// before softmax. This can be combined with an external mask buffer.
    pub do_causal: bool,
}

// ─── Tile geometry constants (D=256) ─────────────────────────────────────────

/// Q tile size for D=256: BQ=32.
const BQ_D256: u32 = 32;

/// KV tile size for D=256: BK=16.
const BK_D256: u32 = 16;

/// Simdgroups along Q dimension for D=256: WM=4.
const WM_D256: u32 = 4;

/// Simdgroups along K dimension for D=256: WN=1.
const WN_D256: u32 = 1;

// ─── Tile geometry constants (D=64) ──────────────────────────────────────────
//
// Same simdgroup geometry as D=256: 4 simdgroups along Q, 1 along K, with
// BQ=32 / BK=16. Threadgroup-memory math at bf16:
//   Qs  = BQ × (BD + padQ) × 2 = 32 × (64 + 8) × 2 = 4608 bytes
//   KVs = max(BK*(BD+padV), (BK+padK)*BD) × 2
//       = max(16×72, 24×64) × 2 = max(1152, 1536) × 2
//       = 3072 bytes
//   total ≈ 7680 bytes — fits well under the 32 KiB Apple Silicon TG limit.
//
// Static-asserts required by the kernel (with kFragSize=8):
//   BQ % (kNWarps × kFragSize) = 32 % (4×8) = 0 ✓
//   BQ ≥ kNWarps × kFragSize = 32 ≥ 32 ✓
//   TQ = BQ / (kNWarps × kFragSize) = 1 ✓
//   TD = BD / kFragSize = 64/8 = 8 ✓
//   TK = BK / kFragSize = 16/8 = 2 ✓

/// Q tile size for D=64: BQ=32.
const BQ_D64: u32 = 32;

/// KV tile size for D=64: BK=16.
const BK_D64: u32 = 16;

/// Simdgroups along Q dimension for D=64: WM=4.
const WM_D64: u32 = 4;

/// Simdgroups along K dimension for D=64: WN=1.
const WN_D64: u32 = 1;

// ─── Validation ───────────────────────────────────────────────────────────────

fn validate_params(params: &FlashAttnPrefillParams) -> Result<()> {
    if params.n_heads == 0 {
        return Err(MlxError::InvalidArgument(
            "flash_attn_prefill: n_heads must be > 0".into(),
        ));
    }
    if params.n_kv_heads == 0 {
        return Err(MlxError::InvalidArgument(
            "flash_attn_prefill: n_kv_heads must be > 0".into(),
        ));
    }
    if params.n_heads % params.n_kv_heads != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "flash_attn_prefill: n_heads ({}) must be divisible by n_kv_heads ({})",
            params.n_heads, params.n_kv_heads
        )));
    }
    if params.seq_len_q == 0 {
        return Err(MlxError::InvalidArgument(
            "flash_attn_prefill: seq_len_q must be > 0".into(),
        ));
    }
    if params.seq_len_k == 0 {
        return Err(MlxError::InvalidArgument(
            "flash_attn_prefill: seq_len_k must be > 0".into(),
        ));
    }
    if params.batch == 0 {
        return Err(MlxError::InvalidArgument(
            "flash_attn_prefill: batch must be > 0".into(),
        ));
    }
    Ok(())
}
pub(crate) fn validate_buffer_size(
    buf: &MlxBuffer,
    name: &str,
    expected_elements: usize,
) -> Result<()> {
    let expected_bytes = expected_elements * buf.dtype().size_of();
    if buf.byte_len() < expected_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "flash_attn_prefill: {name} buffer too small: expected at least \
             {expected_bytes} bytes, got {}",
            buf.byte_len()
        )));
    }
    Ok(())
}

// ─── bf16 D=256 dispatcher ───────────────────────────────────────────────────

/// Dispatch flash-attention prefill for bf16 Q/K/V/O, head_dim=256.
///
/// Encodes a compute command into `encoder` without committing. The caller
/// controls when to call `encoder.commit_and_wait()`.
///
/// # Why bf16 and not f32
///
/// The f32 Qs threadgroup tile (BQ×BD×4 = 32 KB at D=256) consumes the entire
/// Apple Silicon threadgroup-memory budget before KV_smem, scratch, or any
/// padding — so f32 D=256 is not instantiated (see module doc). bf16/f16
/// halve the tile footprint; the MMA accumulator is still f32 internally
/// (the kernel's `T_accum` template parameter is `float` for every
/// instantiation — see `flash_attn_prefill.metal:~1504`), so prefill output
/// precision is `bf16 × bf16 → f32 → bf16` — bf16-bounded at the store, not
/// at the accumulator.
///
/// # Buffer layouts
///
/// All buffers must be contiguous (stride-1 along the innermost / head_dim
/// dimension):
///
/// - `q` — `[batch, n_heads, seq_len_q, 256]`, dtype BF16
/// - `k` — `[batch, n_kv_heads, seq_len_k, 256]`, dtype BF16
/// - `v` — `[batch, n_kv_heads, seq_len_k, 256]`, dtype BF16
/// - `mask` — `[batch, n_heads, seq_len_q, seq_len_k]`, dtype BF16
///   (additive, log-scale: 0.0 = attend, -inf = mask out — llama.cpp
///   convention, see module doc "Mask-sentinel contract"), or `None`
/// - `out` — `[batch, n_heads, seq_len_q, 256]`, dtype BF16 (output)
///
/// # Function constants
///
/// `align_Q`, `align_K` are computed from the sequence lengths and tile sizes.
/// `has_mask` reflects whether `mask` is `Some(_)`.
/// `do_causal` is taken from `params.do_causal`.
///
/// A distinct Metal pipeline is compiled for each unique combination of these
/// four booleans and cached in `registry`.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` for:
/// - `head_dim != 256`
/// - Zero or inconsistent shape fields
/// - Buffer too small for the declared shape
/// - `n_heads` not divisible by `n_kv_heads`
/// - Any buffer dtype != BF16
///
/// Returns `MlxError::ShaderCompilationError` if the Metal pipeline
/// compilation fails.
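///
/// # Example
///
/// A minimal call sketch (assumes an initialised `encoder`, `device`, and
/// `registry`, plus BF16 buffers of the shapes above; all values are
/// illustrative):
///
/// ```rust,ignore
/// let params = FlashAttnPrefillParams {
///     n_heads: 8,
///     n_kv_heads: 8,
///     head_dim: 256,
///     seq_len_q: 128,
///     seq_len_k: 128,
///     batch: 1,
///     scale: 1.0 / (256.0_f32).sqrt(),
///     do_causal: true,
/// };
/// dispatch_flash_attn_prefill_bf16_d256(
///     &mut encoder, &device, &mut registry, &q, &k, &v, None, &out, &params,
/// )?;
/// encoder.commit_and_wait();
/// ```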
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_prefill_bf16_d256(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    out: &MlxBuffer,
    params: &FlashAttnPrefillParams,
) -> Result<()> {
    // Delegate to the blk-aware dispatcher with blk=None. Because
    // `has_blk` is a function constant (index 303), the compiled pipeline
    // with has_blk=false dead-codes every blk reference — this call has
    // zero runtime overhead compared with the pre-Wave-2E code path.
    dispatch_flash_attn_prefill_bf16_d256_with_blk(
        encoder, device, registry, q, k, v, mask, None, out, params,
    )
}

/// Dispatch flash-attention prefill for bf16 Q/K/V/O, head_dim=256, with an
/// optional Wave 2E tile-skip pre-pass byte buffer.
///
/// This is the blk-aware sibling of [`dispatch_flash_attn_prefill_bf16_d256`].
/// When `blk` is `Some(buf)`, the kernel reads a classification byte per
/// `(qtile, ktile)` and:
///
/// - On `blk[qt][kt] == 0`, skips the entire KV tile (no K/V load, no
///   Q·K^T, no mask-add, no softmax update).
/// - On `blk[qt][kt] == 2`, skips the mask-add (but still computes Q·K^T
///   and softmax normally).
/// - On `blk[qt][kt] == 1`, runs the standard path.
///
/// The `blk` buffer must be produced by
/// [`crate::ops::flash_attn_prefill_blk::dispatch_flash_attn_prefill_blk`]
/// with the SAME `(BQ=32, BK=16)` tile shape as this dispatcher, and the
/// caller MUST sequence the pre-pass dispatch BEFORE this one on the same
/// command encoder (or commit the pre-pass first).
///
/// # Correctness invariant
///
/// For any valid (mask, built blk), calling this function with `blk=None`
/// vs `blk=Some(built_blk)` MUST produce bit-exact identical output. The
/// blk path is a pure skip optimisation. Exercised by
/// `test_gpu_bf16_d256_with_blk_matches_no_blk` in
/// `tests/test_flash_attn_prefill.rs`.
///
/// # Buffer layout
///
/// All buffers as in [`dispatch_flash_attn_prefill_bf16_d256`], plus:
///
/// - `blk` — `[ceil(seq_len_q / 32), ceil(seq_len_k / 16)]`, dtype U8.
///   Required iff `Some(_)`; must have at least
///   `ceil(qL/32) * ceil(kL/16)` bytes. When `None`, the dispatcher
///   compiles with `has_blk=false` and does not bind buffer index 7.
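///
///   A sizing sketch for that buffer (host side, illustrative):
///
///   ```rust,ignore
///   // One classification byte per (Q tile, KV tile) pair at BQ=32, BK=16.
///   let blk_bytes = (seq_len_q as usize).div_ceil(32)
///       * (seq_len_k as usize).div_ceil(16);
///   ```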
///
/// # Errors
///
/// Same as [`dispatch_flash_attn_prefill_bf16_d256`], plus:
/// - `blk` is `Some(b)` with `mask = None` (a blk without a mask is
///   meaningless — rejected to catch caller bugs).
/// - `blk` buffer undersized.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_prefill_bf16_d256_with_blk(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    blk: Option<&MlxBuffer>,
    out: &MlxBuffer,
    params: &FlashAttnPrefillParams,
) -> Result<()> {
    // ── Validate ──────────────────────────────────────────────────────────
    if params.head_dim != 256 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d256: head_dim must be 256, got {}",
            params.head_dim
        )));
    }
    // Reject the meaningless has_blk=true, has_mask=false case: a blk
    // buffer classifies the contents of the mask; without a mask the blk
    // bytes are computed against undefined data. This is a caller-bug
    // defence, not a shader limitation.
    if blk.is_some() && mask.is_none() {
        return Err(MlxError::InvalidArgument(
            "dispatch_flash_attn_prefill_bf16_d256_with_blk: \
             blk requires mask (a blk without a mask is meaningless)"
                .into(),
        ));
    }
    validate_params(params)?;

    // All buffers must be BF16 for this dispatcher.
    for (buf, name) in &[(q, "Q"), (k, "K"), (v, "V"), (out as &MlxBuffer, "out")] {
        if buf.dtype() != DType::BF16 {
            return Err(MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_bf16_d256: {name} buffer must be BF16, \
                 got {:?}",
                buf.dtype()
            )));
        }
    }
    if let Some(m) = mask {
        if m.dtype() != DType::BF16 {
            return Err(MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_bf16_d256: mask buffer must be BF16, \
                 got {:?}",
                m.dtype()
            )));
        }
    }

    let batch = params.batch as usize;
    let h = params.n_heads as usize;
    let h_kv = params.n_kv_heads as usize;
    let ql = params.seq_len_q as usize;
    let kl = params.seq_len_k as usize;
    let d = params.head_dim as usize; // = 256

    // Validate buffer element counts.
    validate_buffer_size(q, "Q", batch * h * ql * d)?;
    validate_buffer_size(k, "K", batch * h_kv * kl * d)?;
    validate_buffer_size(v, "V", batch * h_kv * kl * d)?;
    validate_buffer_size(out, "out", batch * h * ql * d)?;
    // A rank-2 mask `[qL, kL]` is the Wave 2D broadcast layout: one plane is
    // shared across all batches and heads (stride-0 in the batch and head dims).
    // A rank-4 mask `[B, H, qL, kL]` is the per-head layout used by callers
    // that already have a fully-expanded mask (e.g. pre-Wave-2D code paths).
    let mask_is_rank2_broadcast = mask.is_some_and(|m| m.shape().len() == 2);
    if let Some(m) = mask {
        if mask_is_rank2_broadcast {
            validate_buffer_size(m, "mask", ql * kl)?;
        } else {
            validate_buffer_size(m, "mask", batch * h * ql * kl)?;
        }
    }

    // ── Tile geometry ─────────────────────────────────────────────────────
    let bq = BQ_D256;
    let bk = BK_D256;
    let wm = WM_D256;
    let wn = WN_D256;

    let nq = params.seq_len_q.div_ceil(bq);
    let nk = params.seq_len_k.div_ceil(bk);
    let nq_aligned = params.seq_len_q / bq;
    let nk_aligned = params.seq_len_k / bk;
    let ql_rem = params.seq_len_q % bq;
    let kl_rem = params.seq_len_k % bk;

    // Function constants (specialised at pipeline creation time, not dispatch).
    let align_q = ql_rem == 0;
    let align_k = kl_rem == 0;
    let has_mask = mask.is_some();
    let has_blk = blk.is_some();
    let do_causal = params.do_causal;

    // Validate blk buffer size when present. Tile shape is fixed for D=256:
    // BQ=32, BK=16 — see ADR-011-phase2-port-tile-skip.md §5.1.
    if let Some(b) = blk {
        let nq_tiles = ql.div_ceil(BQ_D256 as usize);
        let nk_tiles = kl.div_ceil(BK_D256 as usize);
        let expected = nq_tiles * nk_tiles;
        if b.byte_len() < expected {
            return Err(MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_bf16_d256_with_blk: blk buffer \
                 too small: expected at least {expected} bytes (NQ={nq_tiles}, \
                 NK={nk_tiles}), got {}",
                b.byte_len()
            )));
        }
    }

    // ── Kernel name ───────────────────────────────────────────────────────
    // bf16 I/O, bf16 additive mask (or no mask — uses same pipeline since
    // has_mask is a function constant, not part of the name).
    let kernel_name = K_BF16_D256;

    // ── Pipeline lookup (with function constants) ─────────────────────────
    //
    // Wave 2E adds function constant 303 (has_blk). When has_blk=false
    // every blk reference in the shader is dead-coded, so pipelines with
    // the pre-Wave-2E constants {200, 201, 300, 301} + {303: false}
    // generate the same machine code as the pre-Wave-2E pipelines. The
    // only cache-key overhead is one extra "b0" suffix, which is paid
    // once per (aligned, causal, masked) combo at compile time.
    let pipeline = registry.get_pipeline_with_bool_constants(
        kernel_name,
        device.metal_device(),
        &[
            (200, align_q),
            (201, align_k),
            (300, has_mask),
            (301, do_causal),
            (303, has_blk),
        ],
    )?;

    // ── Build AttnParams GPU struct ───────────────────────────────────────
    //
    // Strides for layout [B, H, L, D] where the innermost (D) stride is 1
    // (contiguous):
    //
    //   seq stride   = D
    //   head stride  = L * D
    //   batch stride = H * L * D
    //
    // Q/O use (H, qL); K/V use (H_kv, kL).

    let q_seq_stride = d as i64;
    let q_head_stride = (ql * d) as i64;
    let q_batch_stride = (h * ql * d) as i64;

    let kv_seq_stride = d as i64;
    let kv_head_stride = (kl * d) as i64;
    let kv_batch_stride = (h_kv * kl * d) as i64;

    let gqa_factor = (params.n_heads / params.n_kv_heads) as i32;

    let attn_params = AttnParamsGpu {
        b: params.batch as i32,
        h: params.n_heads as i32,
        d: params.head_dim as i32,
        ql: params.seq_len_q as i32,
        kl: params.seq_len_k as i32,
        gqa_factor,
        scale: params.scale,
        softcapping: 1.0_f32, // always disabled; see module doc
        nq: nq as i32,
        nk: nk as i32,
        nq_aligned: nq_aligned as i32,
        nk_aligned: nk_aligned as i32,
        ql_rem: ql_rem as i32,
        kl_rem: kl_rem as i32,
        ql_off: 0, // standard prefill starts at offset 0
        _pad: 0,
        q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
        k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
        v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
        o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
    };

    // ── Grid geometry ──────────────────────────────────────────────────────
    //   grid        = (NQ, H, B) where NQ = ceil(qL / BQ)
    //   threadgroup = (32, WM, WN)
    let grid = MTLSize::new(nq as u64, params.n_heads as u64, params.batch as u64);
    let tg_size = MTLSize::new(32, wm as u64, wn as u64);

    // ── Encode ─────────────────────────────────────────────────────────────
    encoder.set_op_kind(CapturedOpKind::Sdpa);

    if has_mask {
        // SAFETY: has_mask is true iff mask.is_some() — established where the
        // function constants are computed above. The Option is therefore
        // guaranteed to be Some here. We use ok_or_else rather than
        // expect/unwrap to satisfy the no-panic policy.
        let mask_buf = mask.ok_or_else(|| {
            MlxError::InvalidArgument(
                "flash_attn_prefill: internal error — has_mask=true but mask is None".into(),
            )
        })?;

        // Strides depend on mask rank:
        //   rank-2 `[qL, kL]` — broadcast across batch + heads: set batch_stride
        //   and head_stride to 0 so the shader re-reads the same plane for every
        //   (batch, head) pair. The Metal shader already handles stride-0 correctly
        //   (no kernel changes required).
        //   rank-4 `[B, H, qL, kL]` — per-head layout (back-compat path).
        let (m_batch_stride, m_head_stride, m_ql_stride) = if mask_is_rank2_broadcast {
            (0_i64, 0_i64, kl as i64)
        } else {
            ((h * ql * kl) as i64, (ql * kl) as i64, kl as i64)
        };

        let mask_params = AttnMaskParamsGpu {
            m_strides: [m_batch_stride, m_head_stride, m_ql_stride],
        };

        if has_blk {
            // SAFETY: has_blk is true iff blk.is_some() and the earlier
            // has_blk validation rejects blk.is_some() && mask.is_none().
            let blk_buf = blk.ok_or_else(|| {
                MlxError::InvalidArgument(
                    "flash_attn_prefill: internal error — has_blk=true but blk is None".into(),
                )
            })?;

            encoder.encode_threadgroups_with_args(
                pipeline,
                &[
                    (0, KernelArg::Buffer(q)),
                    (1, KernelArg::Buffer(k)),
                    (2, KernelArg::Buffer(v)),
                    (3, KernelArg::Buffer(out)),
                    (4, KernelArg::Bytes(as_bytes(&attn_params))),
                    (5, KernelArg::Bytes(as_bytes(&mask_params))),
                    (6, KernelArg::Buffer(mask_buf)),
                    (7, KernelArg::Buffer(blk_buf)),
                ],
                grid,
                tg_size,
            );
        } else {
            encoder.encode_threadgroups_with_args(
                pipeline,
                &[
                    (0, KernelArg::Buffer(q)),
                    (1, KernelArg::Buffer(k)),
                    (2, KernelArg::Buffer(v)),
                    (3, KernelArg::Buffer(out)),
                    (4, KernelArg::Bytes(as_bytes(&attn_params))),
                    (5, KernelArg::Bytes(as_bytes(&mask_params))),
                    (6, KernelArg::Buffer(mask_buf)),
                    // buffer 7 absent — has_blk=false constant dead-codes blk refs.
                ],
                grid,
                tg_size,
            );
        }
    } else {
        encoder.encode_threadgroups_with_args(
            pipeline,
            &[
                (0, KernelArg::Buffer(q)),
                (1, KernelArg::Buffer(k)),
                (2, KernelArg::Buffer(v)),
                (3, KernelArg::Buffer(out)),
                (4, KernelArg::Bytes(as_bytes(&attn_params))),
                // buffers 5, 6, 7 intentionally absent — has_mask=false +
                // has_blk=false constants dead-code-eliminate mask + blk loads.
            ],
            grid,
            tg_size,
        );
    }

    Ok(())
}

// ─── bf16 D=256 RESUME dispatcher (qL_off > 0 + slot-capacity strides) ──────
//
// ADR-017 Phase E.a B.2-fix: extend the FA bf16 d256 fast path to support
// LCP partial-prefill resume (turn 2 of a multi-turn conversation prefilling
// only the suffix against a slot that already contains the LCP prefix).
//
// The Metal shader has supported this regime since the Phase 1 port (the
// `qL_off` field at flash_attn_prefill.metal:1045 + `K_strides[3]` /
// `V_strides[3]` at lines 1048-1049 are read at lines 1325, 1437, 1445), but
// the d256 dispatcher hardcoded `qL_off=0` and built K/V strides from
// `seq_len_k` not `kv_capacity` because no Rust caller needed
// "prefill-at-offset" semantics — Qwen3.5/3.6 production prefill is always
// from token 0 (cur_len=0) and decode (cur_len > 0, seq_len=1) routes to
// `flash_attn_vec` at gpu_full_attn.rs:1733 instead.
//
// The legacy F32 SDPA fallback at gpu_full_attn.rs:1900-1916 covered the
// "prefill-at-offset" case for structural correctness, but with an F32
// single-pass softmax that produces byte-different output to the BF16 MMA
// + log-domain online softmax fast path (proven via
// gpu_full_attn.rs::tests::phase_b2_iso_fast_path_vs_fallback_path_kernel_divergence:
// 131072/131072 elements differ, max |Δ| = 6.452e-4).
//
// This module fixes that: a sibling dispatcher that exposes `qL_off` and
// `kv_capacity` so the FA bf16 d256 fast path is reachable for cur_len > 0
// AND for slot reads (slot K/V's head stride is `kv_capacity * head_dim`,
// not `kL * head_dim`, because a slot may have allocated more positions
// than the current valid kL).

/// Host-side parameters for the flash-attention prefill **resume** dispatcher.
///
/// Differs from [`FlashAttnPrefillParams`] in two ways:
///
/// 1. `q_offset_in_k` (mapped to the kernel's `qL_off`): the absolute Q
///    position of the chunk Q within the larger K/V sequence. When > 0,
///    Q is being attended over a slot that already contains
///    `q_offset_in_k` previous tokens. The kernel's causal mask uses
///    `qL_off` to compute `row_pos = tid.x * BQ + qL_off + ...`
///    (`flash_attn_prefill.metal:1325, 1445`) so the per-chunk causal
///    pattern stays correct relative to the absolute K/V position.
///
/// 2. `kv_capacity` (slot stride): K and V buffers may live in a slot of
///    capacity ≥ `seq_len_k`. The kernel reads K/V via integer strides
///    from `K_strides[3]` / `V_strides[3]`, so we set
///    `head_stride = kv_capacity * D` to skip unused slot capacity between
///    heads. When `kv_capacity == seq_len_k` the layout is identical to
///    the non-resume dispatcher.
///
/// # Use case
///
/// ADR-017 Phase E.a LCP partial-prefill resume. Turn 1 of a multi-turn
/// conversation prefills the prompt, populates the KV slot, snapshots it.
/// Turn 2 detects an LCP overlap with turn 1, restores the slot, and
/// prefills only the suffix of turn 2 (M new tokens) against the full slot
/// (kL = N + M, qL = M, qL_off = N). The output is byte-identical to a
/// fresh full prefill of the entire turn-2 prompt — proven by the parity
/// unit test
/// `flash_attn_prefill_bf16_d256_resume_byte_identical_to_monolithic`.
///
/// The non-resume dispatcher remains the production path for prefill-from-
/// zero (`q_offset_in_k == 0` && `kv_capacity == seq_len_k`). The two
/// dispatchers compile to the same kernel pipeline (same kernel name +
/// same function constants when `q_offset_in_k == 0` && `kv_capacity ==
/// seq_len_k`); the only host-side difference is the strides written into
/// `AttnParamsGpu`.
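///
/// A worked sketch of that turn-2 scenario (illustrative values: N = 512
/// prefix tokens already in the slot, M = 64 suffix tokens, slot capacity
/// 1024):
///
/// ```rust,ignore
/// let params = FlashAttnPrefillResumeParams {
///     n_heads: 8,
///     n_kv_heads: 8,
///     head_dim: 256,
///     seq_len_q: 64,        // M — only the suffix is prefilled as Q
///     seq_len_k: 512 + 64,  // N + M — full valid extent of the slot
///     batch: 1,
///     scale: 1.0 / (256.0_f32).sqrt(),
///     do_causal: true,
///     q_offset_in_k: 512,   // N — absolute position of the chunk within K/V
///     kv_capacity: 1024,    // slot allocation, >= seq_len_k
/// };
/// ```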
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnPrefillResumeParams {
    /// Number of query attention heads.
    pub n_heads: u32,
    /// Number of key/value attention heads (GQA).
    pub n_kv_heads: u32,
    /// Head dimension. Must be 256 for `dispatch_flash_attn_prefill_bf16_d256_resume`.
    pub head_dim: u32,
    /// Query sequence length (chunk Q only). qL.
    pub seq_len_q: u32,
    /// Total key/value sequence length: `q_offset_in_k + seq_len_q`. kL.
    pub seq_len_k: u32,
    /// Batch size.
    pub batch: u32,
    /// Attention scale. Typically `1.0 / sqrt(head_dim)`. The kernel
    /// internally multiplies by `log2(e)` before applying to Q.
    pub scale: f32,
    /// Whether to apply in-kernel causal masking. For partial-prefill resume
    /// this is typically `true` — chunk Q must causally mask K positions
    /// `> q_offset_in_k + q_pos_within_chunk`.
    pub do_causal: bool,
    /// Q offset within the K/V sequence (`qL_off` in the kernel). Number of
    /// previous tokens already present in the slot. Standard append-prefill
    /// semantics: `q_offset_in_k + seq_len_q == seq_len_k`.
    pub q_offset_in_k: u32,
    /// Capacity of the K/V slot — stride between K/V heads in elements.
    /// Set to `seq_len_k` when K/V are contiguous-packed (equivalent to the
    /// non-resume dispatcher); set to the slot's allocated capacity to read
    /// from a slot with unused trailing positions. Must be `>= seq_len_k`.
    pub kv_capacity: u32,
}

/// Dispatch flash-attention prefill for bf16 Q/K/V/O, head_dim=256, with
/// caller-controlled `qL_off` and slot-capacity-aware K/V strides.
///
/// **Use this dispatcher when** `q_offset_in_k > 0` (partial-prefill resume)
/// OR when the K/V buffers have allocated capacity beyond `seq_len_k` (slot
/// reads). For the prefill-from-zero contiguous-packed case prefer
/// [`dispatch_flash_attn_prefill_bf16_d256`] (functionally equivalent at
/// `q_offset_in_k=0` && `kv_capacity==seq_len_k`; verified via the parity
/// unit test in `tests/test_flash_attn_prefill.rs`).
///
/// # Buffer layouts
///
/// All buffers must be contiguous along the innermost head_dim=256 axis.
///
/// - `q` — `[batch, n_heads, seq_len_q, 256]`, dtype BF16
///   * head stride = `seq_len_q * 256`, seq stride = 256, batch stride =
///     `n_heads * seq_len_q * 256`.
/// - `k` — `[batch, n_kv_heads, kv_capacity, 256]`, dtype BF16
///   * head stride = `kv_capacity * 256`, seq stride = 256, batch stride =
///     `n_kv_heads * kv_capacity * 256`.
///   * Only positions `[0..seq_len_k]` are read; positions
///     `[seq_len_k..kv_capacity]` may be uninitialised — the kernel will
///     not read past `seq_len_k` thanks to `kL`-aware tile bounds at
///     `flash_attn_prefill.metal:1322, 1416`.
/// - `v` — same layout as `k`.
/// - `out` — `[batch, n_heads, seq_len_q, 256]`, dtype BF16 (output)
///
/// # Function constants
///
/// Identical pipeline cache key to the non-resume dispatcher: `align_Q`,
/// `align_K` from sequence lengths; `do_causal` from `params`; `has_mask`
/// and `has_blk` are both `false` (this resume dispatcher uses pure causal
/// masking via `qL_off` — no external additive mask is supported. Add a
/// resume-with-mask sibling if a future caller needs it).
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` for:
/// - `head_dim != 256`
/// - `q_offset_in_k + seq_len_q > seq_len_k` (Q overshoots K)
/// - `seq_len_k > kv_capacity` (kL overshoots slot capacity)
/// - Zero or inconsistent shape fields
/// - `n_heads` not divisible by `n_kv_heads`
/// - Buffer too small for declared shape (Q/O against `seq_len_q`,
///   K/V against `kv_capacity`)
/// - Any buffer dtype != BF16
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_prefill_bf16_d256_resume(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    out: &MlxBuffer,
    params: &FlashAttnPrefillResumeParams,
) -> Result<()> {
    // ── Validate ──────────────────────────────────────────────────────────
    if params.head_dim != 256 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d256_resume: head_dim must be 256, got {}",
            params.head_dim
        )));
    }
    if params.n_heads == 0
        || params.n_kv_heads == 0
        || params.seq_len_q == 0
        || params.seq_len_k == 0
        || params.batch == 0
    {
        return Err(MlxError::InvalidArgument(
            "dispatch_flash_attn_prefill_bf16_d256_resume: \
             n_heads/n_kv_heads/seq_len_q/seq_len_k/batch must all be > 0"
                .into(),
        ));
    }
    if params.n_heads % params.n_kv_heads != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d256_resume: n_heads ({}) must \
             be divisible by n_kv_heads ({})",
            params.n_heads, params.n_kv_heads
        )));
    }
    if params.q_offset_in_k + params.seq_len_q > params.seq_len_k {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d256_resume: q_offset_in_k ({}) \
             + seq_len_q ({}) > seq_len_k ({}) — Q overshoots K",
            params.q_offset_in_k, params.seq_len_q, params.seq_len_k
        )));
    }
    if params.seq_len_k > params.kv_capacity {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d256_resume: seq_len_k ({}) > \
             kv_capacity ({}) — K/V overshoots slot capacity",
            params.seq_len_k, params.kv_capacity
        )));
    }

    for (buf, name) in &[(q, "Q"), (k, "K"), (v, "V"), (out as &MlxBuffer, "out")] {
        if buf.dtype() != DType::BF16 {
            return Err(MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_bf16_d256_resume: {name} buffer \
                 must be BF16, got {:?}",
                buf.dtype()
            )));
        }
    }

    let batch = params.batch as usize;
    let h = params.n_heads as usize;
    let h_kv = params.n_kv_heads as usize;
    let ql = params.seq_len_q as usize;
    let cap = params.kv_capacity as usize;
    let d = params.head_dim as usize; // = 256

    validate_buffer_size(q, "Q", batch * h * ql * d)?;
    // K/V buffers are validated against capacity (slot layout), not seq_len_k.
    validate_buffer_size(k, "K", batch * h_kv * cap * d)?;
    validate_buffer_size(v, "V", batch * h_kv * cap * d)?;
    validate_buffer_size(out, "out", batch * h * ql * d)?;

    // ── Tile geometry (D=256) ─────────────────────────────────────────────
    let bq = BQ_D256;
    let bk = BK_D256;
    let wm = WM_D256;
    let wn = WN_D256;

    let nq = params.seq_len_q.div_ceil(bq);
    let nk = params.seq_len_k.div_ceil(bk);
    let nq_aligned = params.seq_len_q / bq;
    let nk_aligned = params.seq_len_k / bk;
    let ql_rem = params.seq_len_q % bq;
    let kl_rem = params.seq_len_k % bk;

    let align_q = ql_rem == 0;
    let align_k = kl_rem == 0;
    let has_mask = false; // resume uses pure causal; no external mask
    let has_blk = false;
    let do_causal = params.do_causal;

    // Same pipeline cache key as the non-resume dispatcher: when
    // qL_off=0 && kv_capacity=seq_len_k the two dispatchers compile to the
    // exact same Metal pipeline (verified by the parity test).
    let kernel_name = K_BF16_D256;
    let pipeline = registry.get_pipeline_with_bool_constants(
        kernel_name,
        device.metal_device(),
        &[
            (200, align_q),
            (201, align_k),
            (300, has_mask),
            (301, do_causal),
            (303, has_blk),
        ],
    )?;

    // ── Strides ───────────────────────────────────────────────────────────
    //
    // Q/O are contiguous-packed (qL is the actual extent — chunk Q has no
    // tail capacity). K/V use kv_capacity for head stride (slot layout).
    // Inner stride (D) is always 1.
    let q_seq_stride = d as i64;
    let q_head_stride = (ql * d) as i64;
    let q_batch_stride = (h * ql * d) as i64;

    // K/V head stride uses kv_capacity (slot stride), NOT seq_len_k. This is
    // the only stride that differs from the non-resume dispatcher.
    let kv_seq_stride = d as i64;
    let kv_head_stride = (cap * d) as i64;
    let kv_batch_stride = (h_kv * cap * d) as i64;

    let gqa_factor = (params.n_heads / params.n_kv_heads) as i32;

    let attn_params = AttnParamsGpu {
        b: params.batch as i32,
        h: params.n_heads as i32,
        d: params.head_dim as i32,
        ql: params.seq_len_q as i32,
        kl: params.seq_len_k as i32,
        gqa_factor,
        scale: params.scale,
        softcapping: 1.0_f32,
        nq: nq as i32,
        nk: nk as i32,
        nq_aligned: nq_aligned as i32,
        nk_aligned: nk_aligned as i32,
        ql_rem: ql_rem as i32,
        kl_rem: kl_rem as i32,
        ql_off: params.q_offset_in_k as i32, // ← THE resume change
        _pad: 0,
        q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
        k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
        v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
        o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
    };

    // ── Grid geometry ─────────────────────────────────────────────────────
    let grid = MTLSize::new(nq as u64, params.n_heads as u64, params.batch as u64);
    let tg_size = MTLSize::new(32, wm as u64, wn as u64);

    // ── Encode ────────────────────────────────────────────────────────────
    encoder.set_op_kind(CapturedOpKind::Sdpa);
    encoder.encode_threadgroups_with_args(
        pipeline,
        &[
            (0, KernelArg::Buffer(q)),
            (1, KernelArg::Buffer(k)),
            (2, KernelArg::Buffer(v)),
            (3, KernelArg::Buffer(out)),
            (4, KernelArg::Bytes(as_bytes(&attn_params))),
            // buffers 5, 6, 7 intentionally absent — has_mask=false +
            // has_blk=false function constants dead-code-eliminate the
            // mask + blk loads.
        ],
        grid,
        tg_size,
    );

    Ok(())
}

// ─── f16 D=256 RESUME dispatcher (qL_off > 0 + slot-capacity strides) ──────
//
// F16 variant of `dispatch_flash_attn_prefill_bf16_d256_resume` added for
// ADR-030 iter-74 (hf2q DFlash spec-decode). The hf2q hybrid KV cache
// stores K and V in F16 (`hybrid_kv.k` shape `[H_kv, capacity, D]` F16
// after iter-348 Phase 10c). For spec-decode verify, the orchestrator
// needs cross-length attention against the F16 hybrid_kv slot without
// the F16→BF16 cast overhead.
//
// Bit-identical port of the BF16 dispatcher with:
//   * dtype check BF16 → F16
//   * kernel name K_BF16_D256 → K_F16_D256
//
// Everything else identical: same strides, same params layout, same
// causal `qL_off` semantics, same pipeline-cache key shape. The Metal
// kernel is templated on dtype; the F16 instance is registered in
// `register()` at the top of this file.

/// F16 sibling of [`dispatch_flash_attn_prefill_bf16_d256_resume`]. Same
/// semantics, F16 dtype throughout. See the BF16 doc-comment for the
/// resume-mode contract, buffer layouts, function constants, and error
/// conditions — they all apply here verbatim except every "BF16" reads
/// as "F16".
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_prefill_f16_d256_resume(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    out: &MlxBuffer,
    params: &FlashAttnPrefillResumeParams,
) -> Result<()> {
    // ── Validate ──────────────────────────────────────────────────────────
    if params.head_dim != 256 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_f16_d256_resume: head_dim must be 256, got {}",
            params.head_dim
        )));
    }
    if params.n_heads == 0
        || params.n_kv_heads == 0
        || params.seq_len_q == 0
        || params.seq_len_k == 0
        || params.batch == 0
    {
        return Err(MlxError::InvalidArgument(
            "dispatch_flash_attn_prefill_f16_d256_resume: \
             n_heads/n_kv_heads/seq_len_q/seq_len_k/batch must all be > 0"
                .into(),
        ));
    }
    if params.n_heads % params.n_kv_heads != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_f16_d256_resume: n_heads ({}) must \
             be divisible by n_kv_heads ({})",
            params.n_heads, params.n_kv_heads
        )));
    }
    if params.q_offset_in_k + params.seq_len_q > params.seq_len_k {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_f16_d256_resume: q_offset_in_k ({}) \
             + seq_len_q ({}) > seq_len_k ({}) — Q overshoots K",
            params.q_offset_in_k, params.seq_len_q, params.seq_len_k
        )));
    }
    if params.seq_len_k > params.kv_capacity {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_f16_d256_resume: seq_len_k ({}) > \
             kv_capacity ({}) — K/V overshoots slot capacity",
            params.seq_len_k, params.kv_capacity
        )));
    }

    for (buf, name) in &[(q, "Q"), (k, "K"), (v, "V"), (out as &MlxBuffer, "out")] {
        if buf.dtype() != DType::F16 {
            return Err(MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_f16_d256_resume: {name} buffer \
                 must be F16, got {:?}",
                buf.dtype()
            )));
        }
    }

    let batch = params.batch as usize;
    let h = params.n_heads as usize;
    let h_kv = params.n_kv_heads as usize;
    let ql = params.seq_len_q as usize;
    let cap = params.kv_capacity as usize;
    let d = params.head_dim as usize; // = 256

    validate_buffer_size(q, "Q", batch * h * ql * d)?;
    validate_buffer_size(k, "K", batch * h_kv * cap * d)?;
    validate_buffer_size(v, "V", batch * h_kv * cap * d)?;
    validate_buffer_size(out, "out", batch * h * ql * d)?;

    // ── Tile geometry (D=256) ─ same as BF16 path ─────────────────────────
    let bq = BQ_D256;
    let bk = BK_D256;
    let wm = WM_D256;
    let wn = WN_D256;

    let nq = params.seq_len_q.div_ceil(bq);
    let nk = params.seq_len_k.div_ceil(bk);
    let nq_aligned = params.seq_len_q / bq;
    let nk_aligned = params.seq_len_k / bk;
    let ql_rem = params.seq_len_q % bq;
    let kl_rem = params.seq_len_k % bk;

    let align_q = ql_rem == 0;
    let align_k = kl_rem == 0;
    let has_mask = false;
    let has_blk = false;
    let do_causal = params.do_causal;

    // ── Pipeline lookup — F16 kernel ──────────────────────────────────────
    let kernel_name = K_F16_D256;
    let pipeline = registry.get_pipeline_with_bool_constants(
        kernel_name,
        device.metal_device(),
        &[
            (200, align_q),
            (201, align_k),
            (300, has_mask),
            (301, do_causal),
            (303, has_blk),
        ],
    )?;

    // ── Strides ─ identical to BF16 (kernel reads via integer strides) ────
    let q_seq_stride = d as i64;
    let q_head_stride = (ql * d) as i64;
    let q_batch_stride = (h * ql * d) as i64;

    let kv_seq_stride = d as i64;
    let kv_head_stride = (cap * d) as i64;
    let kv_batch_stride = (h_kv * cap * d) as i64;

    let gqa_factor = (params.n_heads / params.n_kv_heads) as i32;

    let attn_params = AttnParamsGpu {
        b: params.batch as i32,
        h: params.n_heads as i32,
        d: params.head_dim as i32,
        ql: params.seq_len_q as i32,
        kl: params.seq_len_k as i32,
        gqa_factor,
        scale: params.scale,
        softcapping: 1.0_f32,
        nq: nq as i32,
        nk: nk as i32,
        nq_aligned: nq_aligned as i32,
        nk_aligned: nk_aligned as i32,
        ql_rem: ql_rem as i32,
        kl_rem: kl_rem as i32,
        ql_off: params.q_offset_in_k as i32,
        _pad: 0,
        q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
        k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
        v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
        o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
    };

    let grid = MTLSize::new(nq as u64, params.n_heads as u64, params.batch as u64);
    let tg_size = MTLSize::new(32, wm as u64, wn as u64);

    encoder.set_op_kind(CapturedOpKind::Sdpa);
    encoder.encode_threadgroups_with_args(
        pipeline,
        &[
            (0, KernelArg::Buffer(q)),
            (1, KernelArg::Buffer(k)),
            (2, KernelArg::Buffer(v)),
            (3, KernelArg::Buffer(out)),
            (4, KernelArg::Bytes(as_bytes(&attn_params))),
        ],
        grid,
        tg_size,
    );

    Ok(())
}

// ─── bf16 D=64 dispatcher ────────────────────────────────────────────────────

/// Layout selector for [`dispatch_flash_attn_prefill_bf16_d64`].
///
/// The kernel reads from raw device pointers via integer strides, so any
/// element layout that keeps `head_dim` (`D`) as the contiguous innermost
/// axis is valid input. The two layouts named here both satisfy that
/// constraint and cover every BERT/embedding caller in hf2q today:
///
/// * `HeadMajor` — `[B, H, L, D]`, the same layout the D=256/D=512
///   dispatchers assume. Stride math:
///   `seq = D`, `head = L * D`, `batch = H * L * D`.
///
/// * `SeqMajor` — `[B, L, H, D]`, the natural output of BERT linear
///   projections (`hidden = H * D` row-major). Stride math:
///   `seq = H * D`, `head = D`, `batch = L * H * D`.
///   Choosing this layout avoids three host-side transpose dispatches per
///   layer (Q + K + V) plus one for the output, which is the entire point
///   of the D=64 dispatcher's existence — the BERT family wins on dispatch
///   count, not raw FA perf.
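///
/// A stride sketch at illustrative BERT-ish shapes (H=12, L=128, D=64):
///
/// ```rust,ignore
/// // SeqMajor [B, L, H, D] → (batch, head, seq) element strides.
/// let [batch, head, seq] = FlashAttnPrefillLayout::SeqMajor.strides(12, 128, 64);
/// assert_eq!((batch, head, seq), (128 * 12 * 64, 64, 12 * 64));
/// ```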
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlashAttnPrefillLayout {
    /// `[B, H, L, D]` — same as the D=256/D=512 dispatchers.
    HeadMajor,
    /// `[B, L, H, D]` — natural BERT/embedding-model layout.
    SeqMajor,
}

impl FlashAttnPrefillLayout {
    /// Compute (batch_stride, head_stride, seq_stride) given the shape and
    /// number of heads at this layout. All strides are in elements (not
    /// bytes); the caller multiplies by `dtype.size_of()` if needed.
    fn strides(self, n_heads: u32, seq_len: u32, head_dim: u32) -> [i64; 3] {
        let h = n_heads as i64;
        let l = seq_len as i64;
        let d = head_dim as i64;
        match self {
            // `[B, H, L, D]`
            //   seq stride   = D
            //   head stride  = L*D
            //   batch stride = H*L*D
            FlashAttnPrefillLayout::HeadMajor => [h * l * d, l * d, d],
            // `[B, L, H, D]`
            //   seq stride   = H*D
            //   head stride  = D
            //   batch stride = L*H*D
            FlashAttnPrefillLayout::SeqMajor => [l * h * d, d, h * d],
        }
    }
}

/// Dispatch flash-attention prefill for bf16 Q/K/V/O, head_dim=64.
///
/// Encodes a compute command into `encoder` without committing. The caller
/// controls when to call `encoder.commit_and_wait()`.
///
/// Designed for the BERT/embedding family (nomic-bert, bge, mxbai, MiniLM,
/// …) where `head_dim` is 64 and the natural layout coming out of the
/// linear projections is **seq-major** `[B, L, H, D]` — each token row is a
/// contiguous hidden vector of `H * D` elements, rather than the per-head
/// `[H, L, D]` layout of decoder-style models. Pass `layout = SeqMajor` to consume
1390/// that layout directly without three host-side transpose dispatches per
1391/// layer. Pass `layout = HeadMajor` for the same `[B, H, L, D]` contract
1392/// that the D=256/D=512 dispatchers obey (e.g. unit tests, future decoder
1393/// models that happen to land on D=64).
1394///
1395/// # Buffer layouts
1396///
1397/// All buffers must be contiguous along the innermost `D=64` axis.
1398///
1399/// HeadMajor (`layout = HeadMajor`):
1400/// - `q` — `[batch, n_heads, seq_len_q, 64]`, dtype BF16
1401/// - `k` — `[batch, n_kv_heads, seq_len_k, 64]`, dtype BF16
1402/// - `v` — `[batch, n_kv_heads, seq_len_k, 64]`, dtype BF16
1403/// - `out` — `[batch, n_heads, seq_len_q, 64]`, dtype BF16
1404///
1405/// SeqMajor (`layout = SeqMajor`):
1406/// - `q` — `[batch, seq_len_q, n_heads, 64]`, dtype BF16
1407/// - `k` — `[batch, seq_len_k, n_kv_heads, 64]`, dtype BF16
1408/// - `v` — `[batch, seq_len_k, n_kv_heads, 64]`, dtype BF16
1409/// - `out` — `[batch, seq_len_q, n_heads, 64]`, dtype BF16
1410///
1411/// `mask` may be either rank-2 `[seq_len_q, seq_len_k]` (broadcast across
1412/// batch+heads — the BERT padding-mask shape) or rank-4
1413/// `[batch, n_heads, seq_len_q, seq_len_k]` (per-head). Both use the
1414/// llama.cpp additive convention: 0.0 = attend, -inf = mask out.
1415///
/// # Function constants
///
/// Same set as the D=256 dispatcher: indices 200, 201, 300, 301, and 303.
/// The `has_blk` Wave-2E byte-buffer constant (index 303) is forced to
/// `false` here because the Wave-2E tile-skip pre-pass kernel is currently
/// only instantiated for the D=256 BQ/BK tile shape. BERT models do not
/// stack a tile-skip pass on top of attention, so this is not a regression.
///
/// # Errors
///
/// Same error contract as `dispatch_flash_attn_prefill_bf16_d256`, plus
/// rejection of `head_dim != 64`.
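///
/// # Example (sketch)
///
/// A minimal call sketch, assuming already-initialised `encoder`, `device`,
/// `registry`, and BF16 buffers shaped as documented above (marked `ignore`
/// because it needs a live Metal device):
///
/// ```ignore
/// let params = FlashAttnPrefillParams {
///     n_heads: 12,
///     n_kv_heads: 12, // MHA encoder: gqa_factor = 1
///     head_dim: 64,
///     seq_len_q: 128,
///     seq_len_k: 128,
///     batch: 1,
///     scale: 1.0 / 64.0_f32.sqrt(),
///     do_causal: false, // bidirectional encoder attention
/// };
/// dispatch_flash_attn_prefill_bf16_d64(
///     &mut encoder,
///     &device,
///     &mut registry,
///     &q,
///     &k,
///     &v,
///     Some(&mask), // rank-2 [qL, kL] additive padding mask
///     &out,
///     &params,
///     FlashAttnPrefillLayout::SeqMajor,
/// )?;
/// encoder.commit_and_wait();
/// ```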
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_prefill_bf16_d64(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    out: &MlxBuffer,
    params: &FlashAttnPrefillParams,
    layout: FlashAttnPrefillLayout,
) -> Result<()> {
    // ── Validate ──────────────────────────────────────────────────────────
    if params.head_dim != 64 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d64: head_dim must be 64, got {}",
            params.head_dim
        )));
    }
    validate_params(params)?;

    // All buffers must be BF16 for this dispatcher.
    for (buf, name) in &[(q, "Q"), (k, "K"), (v, "V"), (out, "out")] {
        if buf.dtype() != DType::BF16 {
            return Err(MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_bf16_d64: {name} buffer must be BF16, got {:?}",
                buf.dtype()
            )));
        }
    }
    if let Some(m) = mask {
        if m.dtype() != DType::BF16 {
            return Err(MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_bf16_d64: mask buffer must be BF16, got {:?}",
                m.dtype()
            )));
        }
    }

    let batch = params.batch as usize;
    let h = params.n_heads as usize;
    let h_kv = params.n_kv_heads as usize;
    let ql = params.seq_len_q as usize;
    let kl = params.seq_len_k as usize;
    let d = params.head_dim as usize; // = 64

    // Validate buffer element counts (layout-independent: total elements
    // are `B * H * L * D` either way).
    validate_buffer_size(q, "Q", batch * h * ql * d)?;
    validate_buffer_size(k, "K", batch * h_kv * kl * d)?;
    validate_buffer_size(v, "V", batch * h_kv * kl * d)?;
    validate_buffer_size(out, "out", batch * h * ql * d)?;

    let mask_is_rank2_broadcast = mask.is_some_and(|m| m.shape().len() == 2);
    if let Some(m) = mask {
        if mask_is_rank2_broadcast {
            validate_buffer_size(m, "mask", ql * kl)?;
        } else {
            validate_buffer_size(m, "mask", batch * h * ql * kl)?;
        }
    }

    // ── Tile geometry ─────────────────────────────────────────────────────
    let bq = BQ_D64;
    let bk = BK_D64;
    let wm = WM_D64;
    let wn = WN_D64;

    let nq = params.seq_len_q.div_ceil(bq);
    let nk = params.seq_len_k.div_ceil(bk);
    let nq_aligned = params.seq_len_q / bq;
    let nk_aligned = params.seq_len_k / bk;
    let ql_rem = params.seq_len_q % bq;
    let kl_rem = params.seq_len_k % bk;

    // Function constants (specialised at pipeline creation time).
    let align_q = ql_rem == 0;
    let align_k = kl_rem == 0;
    let has_mask = mask.is_some();
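    // The Wave-2E tile-skip pre-pass is only instantiated for the D=256
    // tile shape, so `has_blk` (constant index 303) is pinned false at
    // D=64; see the function-constants section of the doc comment above.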
    let has_blk = false;
    let do_causal = params.do_causal;

    // ── Kernel name ───────────────────────────────────────────────────────
    let kernel_name = K_BF16_D64;

    // ── Pipeline lookup (with function constants) ─────────────────────────
    let pipeline = registry.get_pipeline_with_bool_constants(
        kernel_name,
        device.metal_device(),
        &[
            (200, align_q),
            (201, align_k),
            (300, has_mask),
            (301, do_causal),
            (303, has_blk),
        ],
    )?;

    // ── Build AttnParams GPU struct (strides depend on layout) ────────────
    let q_strides = layout.strides(params.n_heads, params.seq_len_q, params.head_dim);
    let kv_strides = layout.strides(params.n_kv_heads, params.seq_len_k, params.head_dim);
    let o_strides = layout.strides(params.n_heads, params.seq_len_q, params.head_dim);

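    // GQA broadcast factor; validate_params has already rejected any
    // n_heads that is not an exact multiple of n_kv_heads.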
    let gqa_factor = (params.n_heads / params.n_kv_heads) as i32;

    let attn_params = AttnParamsGpu {
        b: params.batch as i32,
        h: params.n_heads as i32,
        d: params.head_dim as i32,
        ql: params.seq_len_q as i32,
        kl: params.seq_len_k as i32,
        gqa_factor,
        scale: params.scale,
        softcapping: 1.0_f32,
        nq: nq as i32,
        nk: nk as i32,
        nq_aligned: nq_aligned as i32,
        nk_aligned: nk_aligned as i32,
        ql_rem: ql_rem as i32,
        kl_rem: kl_rem as i32,
        ql_off: 0,
        _pad: 0,
        q_strides,
        k_strides: kv_strides,
        v_strides: kv_strides,
        o_strides,
    };

    // ── Grid geometry ──────────────────────────────────────────────────────
    let grid = MTLSize::new(nq as u64, params.n_heads as u64, params.batch as u64);
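    // 32 SIMD lanes × WM × WN = 32 × 4 × 1 = 128 threads per threadgroup,
    // matching the D=64 tile shape documented at the top of this module.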
    let tg_size = MTLSize::new(32, wm as u64, wn as u64);

    // ── Encode ─────────────────────────────────────────────────────────────
    encoder.set_op_kind(CapturedOpKind::Sdpa);

    if has_mask {
        let mask_buf = mask.ok_or_else(|| {
            MlxError::InvalidArgument(
                "flash_attn_prefill_d64: internal error — has_mask=true but mask is None".into(),
            )
        })?;

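        // Rank-2 masks broadcast one [qL, kL] plane across batch and heads:
        // zero batch/head strides make every (b, h) pair read the same plane.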
        let (m_batch_stride, m_head_stride, m_ql_stride) = if mask_is_rank2_broadcast {
            (0_i64, 0_i64, kl as i64)
        } else {
            ((h * ql * kl) as i64, (ql * kl) as i64, kl as i64)
        };

        let mask_params = AttnMaskParamsGpu {
            m_strides: [m_batch_stride, m_head_stride, m_ql_stride],
        };

        encoder.encode_threadgroups_with_args(
            pipeline,
            &[
                (0, KernelArg::Buffer(q)),
                (1, KernelArg::Buffer(k)),
                (2, KernelArg::Buffer(v)),
                (3, KernelArg::Buffer(out)),
                (4, KernelArg::Bytes(as_bytes(&attn_params))),
                (5, KernelArg::Bytes(as_bytes(&mask_params))),
                (6, KernelArg::Buffer(mask_buf)),
            ],
            grid,
            tg_size,
        );
    } else {
        encoder.encode_threadgroups_with_args(
            pipeline,
            &[
                (0, KernelArg::Buffer(q)),
                (1, KernelArg::Buffer(k)),
                (2, KernelArg::Buffer(v)),
                (3, KernelArg::Buffer(out)),
                (4, KernelArg::Bytes(as_bytes(&attn_params))),
            ],
            grid,
            tg_size,
        );
    }

    Ok(())
}

// ─── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_attn_params_gpu_size() {
        // Verify the size of AttnParamsGpu matches the MSL struct layout.
        // B, H, D, qL, kL, gqa_factor = 6 × i32 = 24
        // scale, softcapping = 2 × f32 = 8
        // NQ, NK, NQ_aligned, NK_aligned, qL_rem, kL_rem, qL_off = 7 × i32 = 28
        // _pad = 1 × i32 = 4
        // Q_strides, K_strides, V_strides, O_strides = 4 × 3 × i64 = 96
        // Total = 24 + 8 + 28 + 4 + 96 = 160
        assert_eq!(std::mem::size_of::<AttnParamsGpu>(), 160);
    }

    #[test]
    fn test_attn_mask_params_gpu_size() {
        // 3 × i64 = 24 bytes
        assert_eq!(std::mem::size_of::<AttnMaskParamsGpu>(), 24);
    }

    #[test]
    fn test_validate_params_ok() {
        let p = FlashAttnPrefillParams {
            n_heads: 16,
            n_kv_heads: 8,
            head_dim: 256,
            seq_len_q: 2048,
            seq_len_k: 2048,
            batch: 1,
            scale: 1.0 / 256.0_f32.sqrt(),
            do_causal: true,
        };
        assert!(validate_params(&p).is_ok());
    }

    #[test]
    fn test_validate_params_zero_heads() {
        let p = FlashAttnPrefillParams {
            n_heads: 0,
            n_kv_heads: 8,
            head_dim: 256,
            seq_len_q: 128,
            seq_len_k: 128,
            batch: 1,
            scale: 1.0,
            do_causal: false,
        };
        assert!(matches!(
            validate_params(&p),
            Err(MlxError::InvalidArgument(_))
        ));
    }

    #[test]
    fn test_validate_params_bad_gqa_ratio() {
        let p = FlashAttnPrefillParams {
            n_heads: 16,
            n_kv_heads: 7,
            head_dim: 256,
            seq_len_q: 128,
            seq_len_k: 128,
            batch: 1,
            scale: 1.0,
            do_causal: false,
        };
        assert!(matches!(
            validate_params(&p),
            Err(MlxError::InvalidArgument(_))
        ));
    }

    #[test]
    fn test_wrong_head_dim_rejected() {
        // dispatch_flash_attn_prefill_bf16_d256 must reject head_dim != 256
        // via the early-return guard at the top of the dispatcher. The guard
        // runs before any device access, but calling the dispatcher still
        // requires a real device/encoder, so this host-only test can only
        // pin the fixture's pre-condition.
        let p = FlashAttnPrefillParams {
            n_heads: 16,
            n_kv_heads: 8,
            head_dim: 128, // deliberately wrong
            seq_len_q: 64,
            seq_len_k: 64,
            batch: 1,
            scale: 1.0,
            do_causal: false,
        };
        assert_ne!(p.head_dim, 256, "test pre-condition: head_dim must not be 256");
    }

    #[test]
    fn test_all_expected_kernel_names_registered() {
        // Name-pinned, not count-pinned: asserts that each expected entry
        // point is present AND that no unexpected entry points have been
        // added. When a new dispatcher lands (e.g. a future D=128
        // instantiation), update the EXPECTED list explicitly — the failure
        // message tells you which direction you drifted: a kernel you added
        // shows up as an unexpected entry, one you removed shows up as
        // missing.
        //
        // History: previously hard-coded at 8 entries (D=256 + D=512 ×
        // bf16/f16 × additive/boolmask). Commit 7e35d74 added the D=64
        // BERT-family quartet (bf16/f16 × additive/boolmask), bringing the
        // total to 12. The count-pinned shape rotted on that landing; the
        // name-pinned shape below is robust to future intentional additions
        // while still failing loudly on accidental ones.
        const EXPECTED: &[&str] = &[
            // D=256 — general-purpose decoder LLMs (Llama/Qwen-class)
            "flash_attn_prefill_bf16_d256",
            "flash_attn_prefill_bf16_d256_boolmask",
            "flash_attn_prefill_f16_d256",
            "flash_attn_prefill_f16_d256_boolmask",
            // D=512 — wide-head models
            "flash_attn_prefill_bf16_d512",
            "flash_attn_prefill_bf16_d512_boolmask",
            "flash_attn_prefill_f16_d512",
            "flash_attn_prefill_f16_d512_boolmask",
            // D=64 — BERT-family encoders (nomic-bert, bge, mxbai, MiniLM); 7e35d74
            "flash_attn_prefill_bf16_d64",
            "flash_attn_prefill_bf16_d64_boolmask",
            "flash_attn_prefill_f16_d64",
            "flash_attn_prefill_f16_d64_boolmask",
        ];

        let registered: std::collections::HashSet<&str> =
            ALL_KERNEL_NAMES.iter().copied().collect();
        let expected: std::collections::HashSet<&str> = EXPECTED.iter().copied().collect();

        // Each constant in the registry must be non-empty and unique.
        assert_eq!(
            registered.len(),
            ALL_KERNEL_NAMES.len(),
            "ALL_KERNEL_NAMES contains duplicate entries"
        );
        for &name in ALL_KERNEL_NAMES {
            assert!(!name.is_empty(), "kernel name must not be empty");
        }

        // Every expected name must be present (catches accidental removals).
        let missing: Vec<&str> = expected.difference(&registered).copied().collect();
        assert!(
            missing.is_empty(),
            "expected kernel names missing from ALL_KERNEL_NAMES: {missing:?}"
        );

        // No unexpected names may be present (catches accidental additions —
        // forces this EXPECTED list to be updated alongside any new dispatcher).
        let extra: Vec<&str> = registered.difference(&expected).copied().collect();
        assert!(
            extra.is_empty(),
            "unexpected kernel names registered (update EXPECTED in this test): {extra:?}"
        );

        // Verify no f32 entry points are registered — f32 is excluded by
        // Apple Silicon threadgroup memory limits (see module doc).
        for &name in ALL_KERNEL_NAMES {
            assert!(
                !name.contains("float32"),
                "f32 kernel {name} must not be registered — exceeds 32 KB TG mem limit"
            );
            assert!(
                !name.contains("_f32_"),
                "f32 kernel {name} must not be registered — exceeds 32 KB TG mem limit"
            );
        }
    }

    #[test]
    fn test_tile_geometry_d256() {
        // D=256 tile geometry as defined in flash_attn_prefill.metal.
        assert_eq!(BQ_D256, 32, "BQ=32 for D=256");
        assert_eq!(BK_D256, 16, "BK=16 for D=256");
        assert_eq!(WM_D256, 4, "WM=4 for D=256");
        assert_eq!(WN_D256, 1, "WN=1 for D=256");
        // Threadgroup size: 32 × WM × WN = 32 × 4 × 1 = 128 threads.
        assert_eq!(32 * WM_D256 * WN_D256, 128);
    }
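
    // A small sketch test pinning the stride math documented on
    // `FlashAttnPrefillLayout`; the shapes below are illustrative only,
    // not tied to any particular model.
    #[test]
    fn test_layout_strides() {
        let (h, l, d) = (4_u32, 96_u32, 64_u32);
        // HeadMajor `[B, H, L, D]`: batch = H*L*D, head = L*D, seq = D.
        assert_eq!(
            FlashAttnPrefillLayout::HeadMajor.strides(h, l, d),
            [4 * 96 * 64, 96 * 64, 64]
        );
        // SeqMajor `[B, L, H, D]`: batch = L*H*D, head = D, seq = H*D.
        assert_eq!(
            FlashAttnPrefillLayout::SeqMajor.strides(h, l, d),
            [96 * 4 * 64, 64, 4 * 64]
        );
    }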
}