Skip to main content

mlx_native/ops/
flash_attn_train.rs

1//! Flash-attention training forward kernel — host dispatch.
2//!
3//! FA-2 forward pass that emits BOTH the attention output `O` AND the
4//! per-row natural-log logsumexp `L` required by the Phase 2 backward.
5//!
6//! ## Algorithm
7//!
8//! Identical to [`super::flash_attn_prefill`] (online softmax, simdgroup MMA,
9//! same tile geometry, same causal / additive-mask handling, same GQA).
10//!
11//! The only addition is the `L_out [B, H_q, qL]` f32 buffer at `buffer(8)`.
12//! After the K-tile sweep each thread with `sn == 0` writes one f32:
13//!
14//! ```text
15//! L[b, h, i] = max_score_b2 * ln(2) + ln(sum_score_b2)
16//! ```
17//!
18//! where `max_score_b2` and `sum_score_b2` are the per-row base-2
19//! running max / unnormalized exp2 sum from the K-sweep (Q is pre-scaled
20//! by `scale * log2(e)` so all accumulators live in base-2 space).
21//!
22//! This equals the FA-2 paper Algorithm 1 logsumexp:
23//! `L_i = m_i + log( sum_j exp(s_ij - m_i) )` in natural-log units.
24//!
25//! ## Buffer layout
26//!
27//! | Index | Name     | Shape               | DType |
28//! |-------|----------|---------------------|-------|
29//! | 0     | Q        | `[B, H_q, qL, D]`   | BF16  |
30//! | 1     | K        | `[B, H_kv, kL, D]`  | BF16  |
31//! | 2     | V        | `[B, H_kv, kL, D]`  | BF16  |
32//! | 3     | O (out)  | `[B, H_q, qL, D]`   | BF16  |
33//! | 4     | params   | 160-byte ABI struct  | —     |
34//! | 5     | mask_params | 24-byte struct    | — (when has_mask) |
35//! | 6     | mask     | `[B, H_q, qL, kL]`  | BF16 or bool (when has_mask) |
36//! | 8     | L_out    | `[B, H_q, qL]`      | F32   |
37//!
38//! ## Function constants
39//!
40//! Same 4 constants as `flash_attn_prefill.metal`:
41//!
42//! | Index | Name      | Semantics |
43//! |-------|-----------|-----------|
44//! | 200   | align_Q   | `qL % BQ == 0` |
45//! | 201   | align_K   | `kL % BK == 0` |
46//! | 300   | has_mask  | additive/bool mask buffer bound |
47//! | 301   | do_causal | in-kernel causal masking |
48//!
49//! ## Kernel variants
50//!
51//! | Name | D | I/O dtype | Mask kind |
52//! |------|---|-----------|-----------|
53//! | `flash_attn_train_fwd_bf16_d64`          | 64  | bf16 | bf16 additive |
54//! | `flash_attn_train_fwd_bf16_d64_boolmask` | 64  | bf16 | bool |
55//! | `flash_attn_train_fwd_bf16_d256`          | 256 | bf16 | bf16 additive |
56//! | `flash_attn_train_fwd_bf16_d256_boolmask` | 256 | bf16 | bool |
57//!
58//! ## Scale convention
59//!
60//! Pass `scale = 1.0 / sqrt(head_dim)`.  The kernel multiplies internally by
61//! `log2(e)`.  Do NOT pre-multiply by `log2(e)` on the host.
62
63use metal::MTLSize;
64
65use crate::buffer::MlxBuffer;
66use crate::device::MlxDevice;
67use crate::dtypes::DType;
68use crate::encoder::{CapturedOpKind, CommandEncoder, KernelArg, as_bytes};
69use crate::error::{MlxError, Result};
70use crate::kernel_registry::KernelRegistry;
71use crate::ops::flash_attn_prefill::{AttnMaskParamsGpu, AttnParamsGpu};
72
73// ─── Shader source ───────────────────────────────────────────────────────────
74
75/// MSL source (embedded at compile time).
76pub static FLASH_ATTN_TRAIN_FWD_SHADER_SOURCE: &str =
77    include_str!("../shaders/flash_attn_train_fwd.metal");
78
79// ─── Kernel names ────────────────────────────────────────────────────────────
80
/// Entry point: D=64, bf16 I/O, bf16 additive mask.
const K_BF16_D64: &str = "flash_attn_train_fwd_bf16_d64";
/// Entry point: D=64, bf16 I/O, bool mask.
const K_BF16_D64_BOOLMASK: &str = "flash_attn_train_fwd_bf16_d64_boolmask";
/// Entry point: D=256, bf16 I/O, bf16 additive mask.
const K_BF16_D256: &str = "flash_attn_train_fwd_bf16_d256";
/// Entry point: D=256, bf16 I/O, bool mask.
const K_BF16_D256_BOOLMASK: &str = "flash_attn_train_fwd_bf16_d256_boolmask";

/// Every forward entry point, in the order they are registered.
const ALL_KERNEL_NAMES: &[&str] = &[
    K_BF16_D64,
    K_BF16_D64_BOOLMASK,
    K_BF16_D256,
    K_BF16_D256_BOOLMASK,
];
92
93// ─── Registration ─────────────────────────────────────────────────────────────
94
95/// Register all 4 training-forward kernel entry points with the registry.
96///
97/// Must be called before any `dispatch_flash_attn_train_fwd_*` call.
98pub fn register(registry: &mut KernelRegistry) {
99    for &name in ALL_KERNEL_NAMES {
100        registry.register_source(name, FLASH_ATTN_TRAIN_FWD_SHADER_SOURCE);
101    }
102}
103
104// ─── Tile geometry ────────────────────────────────────────────────────────────
105
// One tile geometry serves both the D=64 and the D=256 kernels.

/// Query rows per tile; the dispatch grid has `ceil(qL / BQ)` threadgroups in x.
const BQ: u32 = 32;
/// Key rows per tile; used for the `nk` / `kl_rem` / `align_K` arithmetic.
const BK: u32 = 16;
/// Second threadgroup dimension (threadgroup size is `(32, WM, WN)`).
const WM: u32 = 4;
/// Third threadgroup dimension (threadgroup size is `(32, WM, WN)`).
const WN: u32 = 1;
111
112// ─── Public parameter struct ──────────────────────────────────────────────────
113
/// Problem description handed to the flash-attention training dispatchers.
///
/// Field-for-field this resembles
/// [`crate::ops::flash_attn_prefill::FlashAttnPrefillParams`]; it is a
/// distinct type so the training API can evolve independently of the
/// inference API.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnTrainParams {
    /// Number of sequences in the batch (`B`).
    pub batch: u32,
    /// Query head count (`H_q`).
    pub n_q_heads: u32,
    /// Key/value head count (`H_kv`).  `n_q_heads` must be a multiple of it.
    pub n_kv_heads: u32,
    /// Per-head feature size (`D`): 64 for the D=64 dispatcher, 256 for the
    /// D=256 dispatcher.
    pub head_dim: u32,
    /// Number of query positions (`qL`).
    pub q_seq_len: u32,
    /// Number of key/value positions (`kL`).
    pub k_seq_len: u32,
    /// Softmax scale, normally `1.0 / sqrt(head_dim)`.
    ///
    /// The factor `log2(e) ≈ 1.44269504` is applied inside the kernel, so it
    /// must NOT be folded into this value on the host.
    pub scale: f32,
    /// When `true`, causal masking is performed inside the kernel.
    pub causal: bool,
}
140
141// ─── Input validation ─────────────────────────────────────────────────────────
142
143fn validate_params(p: &FlashAttnTrainParams) -> Result<()> {
144    if p.n_q_heads == 0 {
145        return Err(MlxError::InvalidArgument(
146            "flash_attn_train: n_q_heads must be > 0".into(),
147        ));
148    }
149    if p.n_kv_heads == 0 {
150        return Err(MlxError::InvalidArgument(
151            "flash_attn_train: n_kv_heads must be > 0".into(),
152        ));
153    }
154    if p.n_q_heads % p.n_kv_heads != 0 {
155        return Err(MlxError::InvalidArgument(format!(
156            "flash_attn_train: n_q_heads ({}) must be divisible by n_kv_heads ({})",
157            p.n_q_heads, p.n_kv_heads
158        )));
159    }
160    if p.q_seq_len == 0 {
161        return Err(MlxError::InvalidArgument(
162            "flash_attn_train: q_seq_len must be > 0".into(),
163        ));
164    }
165    if p.k_seq_len == 0 {
166        return Err(MlxError::InvalidArgument(
167            "flash_attn_train: k_seq_len must be > 0".into(),
168        ));
169    }
170    if p.batch == 0 {
171        return Err(MlxError::InvalidArgument(
172            "flash_attn_train: batch must be > 0".into(),
173        ));
174    }
175    Ok(())
176}
177
178fn validate_buffer_size(buf: &MlxBuffer, name: &str, expected_elements: usize) -> Result<()> {
179    let expected_bytes = expected_elements * buf.dtype().size_of();
180    if buf.byte_len() < expected_bytes {
181        return Err(MlxError::InvalidArgument(format!(
182            "flash_attn_train: {name} buffer too small: expected at least \
183             {expected_bytes} bytes, got {}",
184            buf.byte_len()
185        )));
186    }
187    Ok(())
188}
189
190// ─── Shared dispatch core ─────────────────────────────────────────────────────
191
192/// Inner dispatch used by both the D=64 and D=256 public dispatchers.
193///
194/// `kernel_name` must be one of the 4 registered names.
195/// `head_dim_expected` is checked against `params.head_dim` before dispatch.
196#[allow(clippy::too_many_arguments)]
197fn dispatch_inner(
198    encoder: &mut CommandEncoder,
199    device: &MlxDevice,
200    registry: &mut KernelRegistry,
201    q_buf: &MlxBuffer,
202    k_buf: &MlxBuffer,
203    v_buf: &MlxBuffer,
204    mask: Option<&MlxBuffer>,
205    o_buf: &MlxBuffer,
206    l_buf: &MlxBuffer,
207    params: &FlashAttnTrainParams,
208    kernel_name: &str,
209    head_dim_expected: u32,
210) -> Result<()> {
211    // ── Validate head_dim ──────────────────────────────────────────────────
212    if params.head_dim != head_dim_expected {
213        return Err(MlxError::InvalidArgument(format!(
214            "flash_attn_train ({}): head_dim must be {head_dim_expected}, got {}",
215            kernel_name, params.head_dim
216        )));
217    }
218
219    validate_params(params)?;
220
221    // ── Dtype checks ───────────────────────────────────────────────────────
222    for (buf, name) in &[(q_buf, "Q"), (k_buf, "K"), (v_buf, "V"), (o_buf as &MlxBuffer, "O")] {
223        if buf.dtype() != DType::BF16 {
224            return Err(MlxError::InvalidArgument(format!(
225                "flash_attn_train ({kernel_name}): {name} buffer must be BF16, got {:?}",
226                buf.dtype()
227            )));
228        }
229    }
230    if l_buf.dtype() != DType::F32 {
231        return Err(MlxError::InvalidArgument(format!(
232            "flash_attn_train ({kernel_name}): L_out buffer must be F32, got {:?}",
233            l_buf.dtype()
234        )));
235    }
236    if let Some(m) = mask {
237        if m.dtype() != DType::BF16 {
238            return Err(MlxError::InvalidArgument(format!(
239                "flash_attn_train ({kernel_name}): mask buffer must be BF16, got {:?}",
240                m.dtype()
241            )));
242        }
243    }
244
245    // ── Shape arithmetic ───────────────────────────────────────────────────
246    let batch = params.batch as usize;
247    let h = params.n_q_heads as usize;
248    let h_kv = params.n_kv_heads as usize;
249    let ql = params.q_seq_len as usize;
250    let kl = params.k_seq_len as usize;
251    let d = params.head_dim as usize;
252
253    validate_buffer_size(q_buf, "Q", batch * h * ql * d)?;
254    validate_buffer_size(k_buf, "K", batch * h_kv * kl * d)?;
255    validate_buffer_size(v_buf, "V", batch * h_kv * kl * d)?;
256    validate_buffer_size(o_buf, "O", batch * h * ql * d)?;
257    validate_buffer_size(l_buf, "L_out", batch * h * ql)?;
258    if let Some(m) = mask {
259        validate_buffer_size(m, "mask", batch * h * ql * kl)?;
260    }
261
262    // ── Tile geometry ──────────────────────────────────────────────────────
263    let nq = params.q_seq_len.div_ceil(BQ);
264    let nk = params.k_seq_len.div_ceil(BK);
265    let nq_aligned = params.q_seq_len / BQ;
266    let nk_aligned = params.k_seq_len / BK;
267    let ql_rem = params.q_seq_len % BQ;
268    let kl_rem = params.k_seq_len % BK;
269
270    let align_q = ql_rem == 0;
271    let align_k = kl_rem == 0;
272    let has_mask = mask.is_some();
273    let do_causal = params.causal;
274
275    // ── Pipeline ───────────────────────────────────────────────────────────
276    let pipeline = registry.get_pipeline_with_bool_constants(
277        kernel_name,
278        device.metal_device(),
279        &[
280            (200, align_q),
281            (201, align_k),
282            (300, has_mask),
283            (301, do_causal),
284        ],
285    )?;
286
287    // ── AttnParamsGpu ──────────────────────────────────────────────────────
288    let q_seq_stride = d as i64;
289    let q_head_stride = (ql * d) as i64;
290    let q_batch_stride = (h * ql * d) as i64;
291
292    let kv_seq_stride = d as i64;
293    let kv_head_stride = (kl * d) as i64;
294    let kv_batch_stride = (h_kv * kl * d) as i64;
295
296    let gqa_factor = (params.n_q_heads / params.n_kv_heads) as i32;
297
298    let attn_params = AttnParamsGpu {
299        b: params.batch as i32,
300        h: params.n_q_heads as i32,
301        d: params.head_dim as i32,
302        ql: params.q_seq_len as i32,
303        kl: params.k_seq_len as i32,
304        gqa_factor,
305        scale: params.scale,
306        softcapping: 1.0_f32,
307        nq: nq as i32,
308        nk: nk as i32,
309        nq_aligned: nq_aligned as i32,
310        nk_aligned: nk_aligned as i32,
311        ql_rem: ql_rem as i32,
312        kl_rem: kl_rem as i32,
313        ql_off: 0,
314        _pad: 0,
315        q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
316        k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
317        v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
318        o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
319    };
320
321    // ── Grid ───────────────────────────────────────────────────────────────
322    //   grid = (ceil(qL / BQ), H_q, B)
323    //   tg   = (32, WM, WN)
324    let grid = MTLSize::new(nq as u64, params.n_q_heads as u64, params.batch as u64);
325    let tg_size = MTLSize::new(32, WM as u64, WN as u64);
326
327    // ── Encode ────────────────────────────────────────────────────────────
328    encoder.set_op_kind(CapturedOpKind::Sdpa);
329
330    if let Some(mask_buf) = mask {
331        // Rank-4 mask [B, H, qL, kL] — per-head layout.
332        let m_batch_stride = (h * ql * kl) as i64;
333        let m_head_stride = (ql * kl) as i64;
334        let m_ql_stride = kl as i64;
335
336        let mask_params = AttnMaskParamsGpu {
337            m_strides: [m_batch_stride, m_head_stride, m_ql_stride],
338        };
339
340        encoder.encode_threadgroups_with_args(
341            pipeline,
342            &[
343                (0, KernelArg::Buffer(q_buf)),
344                (1, KernelArg::Buffer(k_buf)),
345                (2, KernelArg::Buffer(v_buf)),
346                (3, KernelArg::Buffer(o_buf)),
347                (4, KernelArg::Bytes(as_bytes(&attn_params))),
348                (5, KernelArg::Bytes(as_bytes(&mask_params))),
349                (6, KernelArg::Buffer(mask_buf)),
350                // buffer(7) intentionally absent (blk not used in training fwd)
351                (8, KernelArg::Buffer(l_buf)),
352            ],
353            grid,
354            tg_size,
355        );
356    } else {
357        encoder.encode_threadgroups_with_args(
358            pipeline,
359            &[
360                (0, KernelArg::Buffer(q_buf)),
361                (1, KernelArg::Buffer(k_buf)),
362                (2, KernelArg::Buffer(v_buf)),
363                (3, KernelArg::Buffer(o_buf)),
364                (4, KernelArg::Bytes(as_bytes(&attn_params))),
365                // buffers 5, 6 absent — has_mask=false constant dead-codes them
366                (8, KernelArg::Buffer(l_buf)),
367            ],
368            grid,
369            tg_size,
370        );
371    }
372
373    Ok(())
374}
375
376// ─── Public dispatchers ───────────────────────────────────────────────────────
377
378/// Dispatch the FA-2 forward pass for bf16 Q/K/V/O, head_dim=64.
379///
380/// Encodes a compute command into `encoder` without committing.
381///
382/// # Buffer shapes
383///
384/// - `q_buf`  — `[batch, n_q_heads, q_seq_len, 64]`  BF16
385/// - `k_buf`  — `[batch, n_kv_heads, k_seq_len, 64]` BF16
386/// - `v_buf`  — `[batch, n_kv_heads, k_seq_len, 64]` BF16
387/// - `mask`   — `[batch, n_q_heads, q_seq_len, k_seq_len]` BF16, or `None`
388/// - `o_buf`  — `[batch, n_q_heads, q_seq_len, 64]`  BF16 (output)
389/// - `l_buf`  — `[batch, n_q_heads, q_seq_len]`      F32  (logsumexp output)
390///
391/// # Errors
392///
393/// Returns `MlxError::InvalidArgument` for wrong head_dim, wrong dtype,
394/// bad GQA ratio, or undersized buffer.
395#[allow(clippy::too_many_arguments)]
396pub fn dispatch_flash_attn_train_fwd_bf16_d64(
397    encoder: &mut CommandEncoder,
398    device: &MlxDevice,
399    registry: &mut KernelRegistry,
400    q_buf: &MlxBuffer,
401    k_buf: &MlxBuffer,
402    v_buf: &MlxBuffer,
403    mask: Option<&MlxBuffer>,
404    o_buf: &MlxBuffer,
405    l_buf: &MlxBuffer,
406    params: &FlashAttnTrainParams,
407) -> Result<()> {
408    dispatch_inner(
409        encoder, device, registry,
410        q_buf, k_buf, v_buf, mask, o_buf, l_buf,
411        params, K_BF16_D64, 64,
412    )
413}
414
415/// Dispatch the FA-2 forward pass for bf16 Q/K/V/O, head_dim=256.
416///
417/// Same semantics as [`dispatch_flash_attn_train_fwd_bf16_d64`] but for
418/// the production Qwen3.6-35B-A3B head dimension (D=256).
419///
420/// # Errors
421///
422/// Same as `dispatch_flash_attn_train_fwd_bf16_d64`.
423#[allow(clippy::too_many_arguments)]
424pub fn dispatch_flash_attn_train_fwd_bf16_d256(
425    encoder: &mut CommandEncoder,
426    device: &MlxDevice,
427    registry: &mut KernelRegistry,
428    q_buf: &MlxBuffer,
429    k_buf: &MlxBuffer,
430    v_buf: &MlxBuffer,
431    mask: Option<&MlxBuffer>,
432    o_buf: &MlxBuffer,
433    l_buf: &MlxBuffer,
434    params: &FlashAttnTrainParams,
435) -> Result<()> {
436    dispatch_inner(
437        encoder, device, registry,
438        q_buf, k_buf, v_buf, mask, o_buf, l_buf,
439        params, K_BF16_D256, 256,
440    )
441}
442
443// ─── Kernel-name coverage test (compile-time) ─────────────────────────────────
444
445/// Returns all 4 registered kernel names.
446///
447/// Exposed for integration tests (`tests/test_flash_attn_train.rs`).
448/// `#[cfg(test)]` cannot be used here because integration tests are a
449/// separate crate and `#[cfg(test)]` is not set for them.
450#[doc(hidden)]
451pub fn all_kernel_names_for_test() -> &'static [&'static str] {
452    ALL_KERNEL_NAMES
453}
454
455// ═══════════════════════════════════════════════════════════════════════════════
456// Phase 2 — FA-2 backward kernel (dQ, dK, dV)
457// ═══════════════════════════════════════════════════════════════════════════════
458//
459// Three-kernel chain per call:
460//   1. `flash_attn_train_bwd_compute_d_bf16` — computes D[b,h,i] = rowsum(O·dO)
461//   2. `flash_attn_train_bwd_bf16_d{64,256}` — writes f32 dQ, f32 dK_scratch,
462//      f32 dV_scratch via Q-tile-outer grid + atomic f32 adds for dK/dV.
//   3. `f32_to_bf16_cast` — casts the f32 dQ, dK_scratch and dV_scratch results to the bf16 outputs.
464//
// The caller passes pre-allocated BF16 dQ/dK/dV buffers.  The dispatcher
// allocates four intermediate f32 buffers internally (D_vec, dQ_f32, dK_f32,
// dV_f32); the backward kernel writes its gradients into the f32 scratch
// buffers, which are then cast into the caller's bf16 outputs.
468//
469// Buffer sizing:
470//   dK_f32  [B, H_kv, kL, D] f32
471//   dV_f32  [B, H_kv, kL, D] f32
472//   dQ_f32  [B, H_q,  qL, D] f32
473//   D_vec   [B, H_q,  qL]    f32
474//
475// After the backward kernel, each f32 result is cast to the caller-supplied
476// bf16 output buffer via `f32_to_bf16_cast`.
477
478/// MSL source for the backward kernels (embedded at compile time).
479pub static FLASH_ATTN_TRAIN_BWD_SHADER_SOURCE: &str =
480    include_str!("../shaders/flash_attn_train_bwd.metal");
481
482/// MSL source for the compute-D pre-pass kernel.
483pub static FLASH_ATTN_TRAIN_BWD_COMPUTE_D_SHADER_SOURCE: &str =
484    include_str!("../shaders/flash_attn_train_bwd_compute_d.metal");
485
// Entry points used by the backward pass.

/// Pre-pass computing `D[b,h,i] = rowsum(O·dO)`.
const K_BWD_COMPUTE_D: &str = "flash_attn_train_bwd_compute_d_bf16";
/// Main backward kernel, head_dim = 64.
const K_BWD_D64: &str = "flash_attn_train_bwd_bf16_d64";
/// Main backward kernel, head_dim = 256.
const K_BWD_D256: &str = "flash_attn_train_bwd_bf16_d256";
/// Elementwise f32 → bf16 cast applied to the gradient scratch buffers.
const K_F32_TO_BF16: &str = "f32_to_bf16_cast";

/// Every backward entry point, in registration order.
const ALL_BWD_KERNEL_NAMES: &[&str] = &[
    K_BWD_COMPUTE_D,
    K_BWD_D64,
    K_BWD_D256,
    K_F32_TO_BF16,
];
498
499/// Register all backward kernel entry points with the registry.
500///
501/// Must be called before any `dispatch_flash_attn_train_bwd_*` call.
502/// Safe to call alongside [`register`] (forward registration).
503pub fn register_bwd(registry: &mut KernelRegistry) {
504    registry.register_source(K_BWD_COMPUTE_D, FLASH_ATTN_TRAIN_BWD_COMPUTE_D_SHADER_SOURCE);
505    for &name in &[K_BWD_D64, K_BWD_D256, K_F32_TO_BF16] {
506        registry.register_source(name, FLASH_ATTN_TRAIN_BWD_SHADER_SOURCE);
507    }
508}
509
510/// Returns all 4 backward kernel names.  Exposed for integration tests.
511#[doc(hidden)]
512pub fn all_bwd_kernel_names_for_test() -> &'static [&'static str] {
513    ALL_BWD_KERNEL_NAMES
514}
515
516// ── compute-D pre-pass ────────────────────────────────────────────────────────
517
518/// Struct for the compute-D Metal kernel params (4 × u32 = 16 bytes).
519#[repr(C)]
520#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
521struct ComputeDParams {
522    batch: u32,
523    n_q_heads: u32,
524    q_seq_len: u32,
525    head_dim: u32,
526}
527
528/// Encode the compute-D pre-pass: `D[b,h,i] = rowsum(O[b,h,i,:] * dO[b,h,i,:])`.
529///
530/// `d_out_buf` must be `[B, H_q, qL]` f32, allocated and zero-initialized by the
531/// caller (the dispatcher allocates it internally).
532///
533/// Grid: `(qL, 1, B*H_q)`, tg_size = `(min(256, next_pow2(D)), 1, 1)`.
534fn dispatch_compute_d(
535    encoder: &mut CommandEncoder,
536    registry: &mut KernelRegistry,
537    device: &metal::DeviceRef,
538    o_buf: &MlxBuffer,
539    do_buf: &MlxBuffer,
540    d_out_buf: &MlxBuffer,
541    params: &FlashAttnTrainParams,
542) -> Result<()> {
543    let p = ComputeDParams {
544        batch: params.batch,
545        n_q_heads: params.n_q_heads,
546        q_seq_len: params.q_seq_len,
547        head_dim: params.head_dim,
548    };
549
550    let pipeline = registry.get_pipeline(K_BWD_COMPUTE_D, device)?;
551
552    let tg_x = std::cmp::min(256, params.head_dim.next_power_of_two()) as u64;
553    let grid = MTLSize::new(
554        params.q_seq_len as u64,
555        1,
556        (params.batch * params.n_q_heads) as u64,
557    );
558    let tg_size = MTLSize::new(tg_x, 1, 1);
559
560    encoder.encode_threadgroups_with_args(
561        pipeline,
562        &[
563            (0, KernelArg::Buffer(o_buf)),
564            (1, KernelArg::Buffer(do_buf)),
565            (2, KernelArg::Buffer(d_out_buf)),
566            (3, KernelArg::Bytes(as_bytes(&p))),
567        ],
568        grid,
569        tg_size,
570    );
571
572    Ok(())
573}
574
575// ── elementwise f32→bf16 cast ─────────────────────────────────────────────────
576
577/// Encode an elementwise f32→bf16 cast of `n_elems` elements.
578///
579/// `src` must be f32; `dst` must be bf16 with the same element count.
580/// The cast kernel's buffer(2) receives `n_elems` as a u32 OOB guard.
581fn dispatch_f32_to_bf16(
582    encoder: &mut CommandEncoder,
583    registry: &mut KernelRegistry,
584    device: &metal::DeviceRef,
585    src: &MlxBuffer,
586    dst: &MlxBuffer,
587    n_elems: usize,
588) -> Result<()> {
589    let pipeline = registry.get_pipeline(K_F32_TO_BF16, device)?;
590    let tg_x = std::cmp::min(256u64, n_elems as u64);
591    let n_groups = (n_elems as u64).div_ceil(tg_x);
592    let n_u32 = n_elems as u32;
593    encoder.encode_threadgroups_with_args(
594        pipeline,
595        &[
596            (0, KernelArg::Buffer(src)),
597            (1, KernelArg::Buffer(dst)),
598            (2, KernelArg::Bytes(as_bytes(&n_u32))),
599        ],
600        MTLSize::new(n_groups, 1, 1),
601        MTLSize::new(tg_x, 1, 1),
602    );
603    Ok(())
604}
605
606// ── backward inner dispatch ───────────────────────────────────────────────────
607
608/// Inner backward dispatch shared by the D=64 and D=256 public functions.
609///
610/// Runs the three-kernel chain:
611/// 1. compute_D pre-pass
612/// 2. FA-2 Algorithm 4 backward (Q-tile-outer, writes f32 dQ/dK/dV scratch)
613/// 3. f32 → bf16 cast for dQ, dK, dV
614#[allow(clippy::too_many_arguments)]
615fn dispatch_bwd_inner(
616    encoder: &mut CommandEncoder,
617    device: &MlxDevice,
618    registry: &mut KernelRegistry,
619    q_buf: &MlxBuffer,
620    k_buf: &MlxBuffer,
621    v_buf: &MlxBuffer,
622    o_buf: &MlxBuffer,
623    l_buf: &MlxBuffer,
624    do_buf: &MlxBuffer,
625    mask: Option<&MlxBuffer>,
626    dq_buf: &MlxBuffer,
627    dk_buf: &MlxBuffer,
628    dv_buf: &MlxBuffer,
629    params: &FlashAttnTrainParams,
630    bwd_kernel_name: &str,
631    head_dim_expected: u32,
632) -> Result<()> {
633    // ── head_dim check ────────────────────────────────────────────────────────
634    if params.head_dim != head_dim_expected {
635        return Err(MlxError::InvalidArgument(format!(
636            "flash_attn_train_bwd ({bwd_kernel_name}): head_dim must be \
637             {head_dim_expected}, got {}",
638            params.head_dim
639        )));
640    }
641
642    validate_params(params)?;
643
644    // ── dtype checks ──────────────────────────────────────────────────────────
645    for (buf, name) in &[
646        (q_buf, "Q"),
647        (k_buf, "K"),
648        (v_buf, "V"),
649        (o_buf, "O"),
650        (do_buf, "dO"),
651    ] {
652        if buf.dtype() != DType::BF16 {
653            return Err(MlxError::InvalidArgument(format!(
654                "flash_attn_train_bwd ({bwd_kernel_name}): {name} buffer must be BF16, \
655                 got {:?}",
656                buf.dtype()
657            )));
658        }
659    }
660    for (buf, name) in &[(l_buf, "L")] {
661        if buf.dtype() != DType::F32 {
662            return Err(MlxError::InvalidArgument(format!(
663                "flash_attn_train_bwd ({bwd_kernel_name}): {name} buffer must be F32, \
664                 got {:?}",
665                buf.dtype()
666            )));
667        }
668    }
669    for (buf, name) in &[
670        (dq_buf as &MlxBuffer, "dQ"),
671        (dk_buf as &MlxBuffer, "dK"),
672        (dv_buf as &MlxBuffer, "dV"),
673    ] {
674        if buf.dtype() != DType::BF16 {
675            return Err(MlxError::InvalidArgument(format!(
676                "flash_attn_train_bwd ({bwd_kernel_name}): {name} output buffer must be \
677                 BF16, got {:?}",
678                buf.dtype()
679            )));
680        }
681    }
682    if let Some(m) = mask {
683        if m.dtype() != DType::BF16 {
684            return Err(MlxError::InvalidArgument(format!(
685                "flash_attn_train_bwd ({bwd_kernel_name}): mask buffer must be BF16, \
686                 got {:?}",
687                m.dtype()
688            )));
689        }
690    }
691
692    // ── shape arithmetic ──────────────────────────────────────────────────────
693    let batch = params.batch as usize;
694    let h_q = params.n_q_heads as usize;
695    let h_kv = params.n_kv_heads as usize;
696    let ql = params.q_seq_len as usize;
697    let kl = params.k_seq_len as usize;
698    let d = params.head_dim as usize;
699
700    let q_elems = batch * h_q * ql * d;
701    let kv_elems = batch * h_kv * kl * d;
702    let l_elems = batch * h_q * ql;
703
704    validate_buffer_size(q_buf, "Q", q_elems)?;
705    validate_buffer_size(k_buf, "K", kv_elems)?;
706    validate_buffer_size(v_buf, "V", kv_elems)?;
707    validate_buffer_size(o_buf, "O", q_elems)?;
708    validate_buffer_size(l_buf, "L", l_elems)?;
709    validate_buffer_size(do_buf, "dO", q_elems)?;
710    validate_buffer_size(dq_buf, "dQ", q_elems)?;
711    validate_buffer_size(dk_buf, "dK", kv_elems)?;
712    validate_buffer_size(dv_buf, "dV", kv_elems)?;
713    if let Some(m) = mask {
714        validate_buffer_size(m, "mask", batch * h_q * ql * kl)?;
715    }
716
717    // ── allocate internal f32 scratch buffers ─────────────────────────────────
718    // alloc_buffer zero-initialises all bytes (ADR-015 iter61a fix in device.rs).
719    // dK_f32 and dV_f32 accumulate via f32 atomic adds in the backward kernel;
720    // dQ_f32 is written (not accumulated) by the backward kernel.
721    let d_vec_buf = device
722        .alloc_buffer(l_elems * 4, DType::F32, vec![l_elems])
723        .map_err(|e| MlxError::InvalidArgument(format!("flash_attn_train_bwd: alloc D_vec: {e}")))?;
724    let dq_f32_buf = device
725        .alloc_buffer(q_elems * 4, DType::F32, vec![q_elems])
726        .map_err(|e| MlxError::InvalidArgument(format!("flash_attn_train_bwd: alloc dQ_f32: {e}")))?;
727    let dk_f32_buf = device
728        .alloc_buffer(kv_elems * 4, DType::F32, vec![kv_elems])
729        .map_err(|e| MlxError::InvalidArgument(format!("flash_attn_train_bwd: alloc dK_f32: {e}")))?;
730    let dv_f32_buf = device
731        .alloc_buffer(kv_elems * 4, DType::F32, vec![kv_elems])
732        .map_err(|e| MlxError::InvalidArgument(format!("flash_attn_train_bwd: alloc dV_f32: {e}")))?;
733
734    // ── tile geometry (same as forward) ───────────────────────────────────────
735    let nq = params.q_seq_len.div_ceil(BQ);
736    let nk = params.k_seq_len.div_ceil(BK);
737    let nq_aligned = params.q_seq_len / BQ;
738    let nk_aligned = params.k_seq_len / BK;
739    let ql_rem = params.q_seq_len % BQ;
740    let kl_rem = params.k_seq_len % BK;
741
742    let align_q = ql_rem == 0;
743    let align_k = kl_rem == 0;
744    let has_mask = mask.is_some();
745    let do_causal = params.causal;
746
747    // ── AttnParamsGpu ─────────────────────────────────────────────────────────
748    let q_seq_stride = d as i64;
749    let q_head_stride = (ql * d) as i64;
750    let q_batch_stride = (h_q * ql * d) as i64;
751    let kv_seq_stride = d as i64;
752    let kv_head_stride = (kl * d) as i64;
753    let kv_batch_stride = (h_kv * kl * d) as i64;
754    let gqa_factor = (params.n_q_heads / params.n_kv_heads) as i32;
755
756    let attn_params = AttnParamsGpu {
757        b: params.batch as i32,
758        h: params.n_q_heads as i32,
759        d: params.head_dim as i32,
760        ql: params.q_seq_len as i32,
761        kl: params.k_seq_len as i32,
762        gqa_factor,
763        scale: params.scale,
764        softcapping: 1.0_f32,
765        nq: nq as i32,
766        nk: nk as i32,
767        nq_aligned: nq_aligned as i32,
768        nk_aligned: nk_aligned as i32,
769        ql_rem: ql_rem as i32,
770        kl_rem: kl_rem as i32,
771        ql_off: 0,
772        _pad: 0,
773        q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
774        k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
775        v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
776        o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
777    };
778
779    // ── Kernel 1: compute D ───────────────────────────────────────────────────
780    dispatch_compute_d(
781        encoder, registry, device.metal_device(),
782        o_buf, do_buf, &d_vec_buf, params,
783    )?;
784    encoder.memory_barrier();
785
786    // ── Kernel 2: FA-2 backward ───────────────────────────────────────────────
787    let bwd_pipeline = registry.get_pipeline_with_bool_constants(
788        bwd_kernel_name,
789        device.metal_device(),
790        &[
791            (200, align_q),
792            (201, align_k),
793            (300, has_mask),
794            (301, do_causal),
795        ],
796    )?;
797
798    // Grid: (ceil(qL/BQ), H_q, B), tg_size: (32, WM, WN)
799    let grid = MTLSize::new(nq as u64, params.n_q_heads as u64, params.batch as u64);
800    let tg_size = MTLSize::new(32, WM as u64, WN as u64);
801
802    encoder.set_op_kind(CapturedOpKind::Sdpa);
803
804    if let Some(mask_buf) = mask {
805        let m_batch_stride = (h_q * ql * kl) as i64;
806        let m_head_stride = (ql * kl) as i64;
807        let m_ql_stride = kl as i64;
808        let mask_params = AttnMaskParamsGpu {
809            m_strides: [m_batch_stride, m_head_stride, m_ql_stride],
810        };
811        encoder.encode_threadgroups_with_args(
812            bwd_pipeline,
813            &[
814                (0, KernelArg::Buffer(q_buf)),
815                (1, KernelArg::Buffer(k_buf)),
816                (2, KernelArg::Buffer(v_buf)),
817                // buffer(3) unused (O is not needed in backward computation)
818                (4, KernelArg::Buffer(l_buf)),
819                (5, KernelArg::Buffer(do_buf)),
820                (6, KernelArg::Buffer(&d_vec_buf)),
821                (7, KernelArg::Buffer(&dq_f32_buf)),
822                (8, KernelArg::Buffer(&dk_f32_buf)),
823                (9, KernelArg::Buffer(&dv_f32_buf)),
824                (10, KernelArg::Bytes(as_bytes(&attn_params))),
825                (11, KernelArg::Bytes(as_bytes(&mask_params))),
826                (12, KernelArg::Buffer(mask_buf)),
827            ],
828            grid,
829            tg_size,
830        );
831    } else {
832        encoder.encode_threadgroups_with_args(
833            bwd_pipeline,
834            &[
835                (0, KernelArg::Buffer(q_buf)),
836                (1, KernelArg::Buffer(k_buf)),
837                (2, KernelArg::Buffer(v_buf)),
838                // buffer(3) unused
839                (4, KernelArg::Buffer(l_buf)),
840                (5, KernelArg::Buffer(do_buf)),
841                (6, KernelArg::Buffer(&d_vec_buf)),
842                (7, KernelArg::Buffer(&dq_f32_buf)),
843                (8, KernelArg::Buffer(&dk_f32_buf)),
844                (9, KernelArg::Buffer(&dv_f32_buf)),
845                (10, KernelArg::Bytes(as_bytes(&attn_params))),
846            ],
847            grid,
848            tg_size,
849        );
850    }
851    encoder.memory_barrier();
852
853    // ── Kernel 3: f32 → bf16 cast for dQ, dK, dV ─────────────────────────────
854    dispatch_f32_to_bf16(encoder, registry, device.metal_device(), &dq_f32_buf, dq_buf, q_elems)?;
855    encoder.memory_barrier();
856    dispatch_f32_to_bf16(encoder, registry, device.metal_device(), &dk_f32_buf, dk_buf, kv_elems)?;
857    encoder.memory_barrier();
858    dispatch_f32_to_bf16(encoder, registry, device.metal_device(), &dv_f32_buf, dv_buf, kv_elems)?;
859
860    Ok(())
861}
862
863// ── Public backward dispatchers ───────────────────────────────────────────────
864
865/// Dispatch the FA-2 backward pass for bf16 I/O, head_dim=64.
866///
867/// Encodes a three-kernel sequence into `encoder`:
868/// 1. Compute D pre-pass (`D[b,h,i] = rowsum(O·dO)`).
869/// 2. FA-2 Algorithm 4 backward: writes f32 dQ/dK/dV.
870/// 3. Cast f32 dQ/dK/dV → bf16 output buffers.
871///
872/// # Buffer shapes
873///
874/// - `q_buf`  — `[batch, n_q_heads, q_seq_len, 64]`  BF16
875/// - `k_buf`  — `[batch, n_kv_heads, k_seq_len, 64]` BF16
876/// - `v_buf`  — `[batch, n_kv_heads, k_seq_len, 64]` BF16
877/// - `o_buf`  — `[batch, n_q_heads, q_seq_len, 64]`  BF16 (forward output)
878/// - `l_buf`  — `[batch, n_q_heads, q_seq_len]`      F32  (forward logsumexp)
879/// - `do_buf` — `[batch, n_q_heads, q_seq_len, 64]`  BF16 (upstream gradient)
880/// - `mask`   — `[batch, n_q_heads, q_seq_len, k_seq_len]` BF16 additive, or `None`
881/// - `dq_buf` — `[batch, n_q_heads, q_seq_len, 64]`  BF16 (output, zero-init by caller)
882/// - `dk_buf` — `[batch, n_kv_heads, k_seq_len, 64]` BF16 (output, zero-init by caller)
883/// - `dv_buf` — `[batch, n_kv_heads, k_seq_len, 64]` BF16 (output, zero-init by caller)
884///
885/// # Errors
886///
887/// Returns `MlxError::InvalidArgument` for wrong head_dim, dtype mismatch,
888/// GQA ratio error, or undersized buffer.
889#[allow(clippy::too_many_arguments)]
890pub fn dispatch_flash_attn_train_bwd_bf16_d64(
891    encoder: &mut CommandEncoder,
892    device: &MlxDevice,
893    registry: &mut KernelRegistry,
894    q_buf: &MlxBuffer,
895    k_buf: &MlxBuffer,
896    v_buf: &MlxBuffer,
897    o_buf: &MlxBuffer,
898    l_buf: &MlxBuffer,
899    do_buf: &MlxBuffer,
900    mask: Option<&MlxBuffer>,
901    dq_buf: &MlxBuffer,
902    dk_buf: &MlxBuffer,
903    dv_buf: &MlxBuffer,
904    params: &FlashAttnTrainParams,
905) -> Result<()> {
906    dispatch_bwd_inner(
907        encoder, device, registry,
908        q_buf, k_buf, v_buf, o_buf, l_buf, do_buf, mask,
909        dq_buf, dk_buf, dv_buf,
910        params, K_BWD_D64, 64,
911    )
912}
913
914/// Dispatch the FA-2 backward pass for bf16 I/O, head_dim=256.
915///
916/// Same semantics as [`dispatch_flash_attn_train_bwd_bf16_d64`] but for
917/// Qwen3.6-35B-A3B head dimension (D=256).
918///
919/// # Errors
920///
921/// Same as `dispatch_flash_attn_train_bwd_bf16_d64`.
922#[allow(clippy::too_many_arguments)]
923pub fn dispatch_flash_attn_train_bwd_bf16_d256(
924    encoder: &mut CommandEncoder,
925    device: &MlxDevice,
926    registry: &mut KernelRegistry,
927    q_buf: &MlxBuffer,
928    k_buf: &MlxBuffer,
929    v_buf: &MlxBuffer,
930    o_buf: &MlxBuffer,
931    l_buf: &MlxBuffer,
932    do_buf: &MlxBuffer,
933    mask: Option<&MlxBuffer>,
934    dq_buf: &MlxBuffer,
935    dk_buf: &MlxBuffer,
936    dv_buf: &MlxBuffer,
937    params: &FlashAttnTrainParams,
938) -> Result<()> {
939    dispatch_bwd_inner(
940        encoder, device, registry,
941        q_buf, k_buf, v_buf, o_buf, l_buf, do_buf, mask,
942        dq_buf, dk_buf, dv_buf,
943        params, K_BWD_D256, 256,
944    )
945}