mlx-native 0.4.7

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
//! SWA / causal attention-mask builder for the flash_attn_prefill kernels.
//!
//! Ported from llama.cpp's `llm_graph_input_attn_no_cache::set_input`
//! (`/opt/llama.cpp/src/llama-graph.cpp:380-444`) and the `is_masked_swa`
//! predicate at `/opt/llama.cpp/src/llama-hparams.h:316-328`.  See
//! `docs/ADR-011-phase2-port-swa-mask.md` for the full port spec and
//! `docs/ADR-011-phase2-wave2d-swa-mask-verification.md` for verification
//! notes.
//!
//! ## What this module does
//!
//! Given `(seq_len_q, seq_len_k, window_size, causal, q_abs_offset)`, it
//! dispatches a small Metal kernel that fills a device-resident bf16 buffer
//! of shape `[seq_len_q, seq_len_k]` with the additive attention mask
//! consumed by [`crate::ops::flash_attn_prefill`] and
//! [`crate::ops::flash_attn_prefill_d512`].  Cells for attended positions
//! are written as `bf16(0.0)` (bit pattern `0x0000`); cells for masked
//! positions are written as `bf16(-inf)` (bit pattern `0xFF80`).
//!
//! ## Sentinel choice
//!
//! Masked = `-INFINITY` (post-Wave-2A convention).  Attended = `0.0`.  These
//! match llama.cpp's CPU-side convention at `llama-graph.cpp:421, 436`.
//! bf16 has the same 8-bit exponent as f32 so `-inf` is an exact
//! representable value — no cast precision loss.
//!
//! ## Why GPU fill (not CPU fill)
//!
//! llama.cpp builds the mask on the CPU then relies on ggml's implicit
//! host→device upload.  We build on-GPU instead because:
//!
//! 1. **Unified memory** on Apple Silicon means there is no meaningful
//!    "upload" — CPU and GPU see the same `StorageModeShared` buffer.  The
//!    CPU vs GPU distinction collapses to "who writes the cells".
//! 2. **Dispatcher locality**: the rest of the mlx-native prefill path is
//!    GPU-native, including the kernel that reads the mask.  Staying
//!    on-device avoids a separate CPU fill path that would need its own
//!    validation and cache-coherence discipline.
//! 3. **Bandwidth-bound fill is essentially free**: at seq_len=2048 the
//!    mask is 8 MiB which writes in ~30 µs at Apple Silicon's sustained
//!    280 GB/s — negligible compared with the ~200 µs of a single prefill
//!    attention dispatch.
//!
//! This is a documented deviation from llama.cpp (see ADR-011 phase 2
//! §6.1).  The mask values are byte-identical.
//!
//! ## Broadcast semantics
//!
//! The mask this module writes has **logical shape `[qL, kL]`** with no
//! batch or head dimension.  It is broadcast across batch and heads at the
//! flash_attn_prefill call-site by passing `m_strides = [0, 0, kL]` —
//! see `AttnMaskParamsGpu` in [`crate::ops::flash_attn_prefill`].  This
//! mirrors llama.cpp's `ggml_new_tensor_4d(ctx, F32, n_tokens, n_tokens, 1, 1)`
//! layout where `ne[2] = ne[3] = 1` broadcast the single (qL, kL) plane
//! across heads and batch.
//!
//! ## Statelessness & caching
//!
//! The dispatcher is stateless — it allocates + fills the buffer and
//! returns it.  **The caller holds the buffer alive across layers** so the
//! global and sliding masks can be built once per prefill and reused 25×
//! (sliding layers) + 5× (global layers) for Gemma 4.
//!
//! ## Scope
//!
//! Phase 2 (this wave) supports only Gemma 4's requirements:
//! - Causal masking (toggle).
//! - Standard SWA (single `window_size` value, exclusive upper bound).
//! - Arbitrary `q_abs_offset` (for future prefill-with-existing-cache work;
//!   always 0 in the Phase 2 hf2q call-site).
//!
//! `LLAMA_SWA_TYPE_CHUNKED` and `LLAMA_SWA_TYPE_SYMMETRIC` are not ported —
//! Gemma 4 does not need them.  See ADR-011 phase 2 §1.4 for the algorithm
//! spec when a future model lands that does.

use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::encoder::{CommandEncoder, KernelArg, as_bytes};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use crate::DType;

/// MSL source for the SWA-mask-fill kernel (embedded at compile time).
pub static FLASH_ATTN_PREFILL_MASK_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_prefill_mask.metal");

/// Kernel entry point for the bf16 mask-fill kernel.
pub const K_FILL_BF16: &str = "flash_attn_prefill_mask_fill_bf16";

/// Register the SWA-mask-fill shader source with the given kernel registry.
///
/// Must be called before any dispatch of `build_sdpa_mask_bf16`.
pub fn register(registry: &mut KernelRegistry) {
    registry.register_source(K_FILL_BF16, FLASH_ATTN_PREFILL_MASK_SHADER_SOURCE);
}

/// Host-side parameters for the SWA-mask builder.
///
/// Mirrors llama.cpp's `(n_tokens, n_kv, n_swa, swa_type, causal_attn)`
/// inputs to `llm_graph_input_attn_no_cache::set_input` simplified for
/// the batch=1, single-sequence case.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SdpaMaskParams {
    /// Query sequence length (rows of the mask).
    pub seq_len_q: u32,
    /// Key sequence length (cols of the mask).  For in-place batched prefill
    /// this equals `seq_len_q`.
    pub seq_len_k: u32,
    /// Sliding window size.  `None` means "no window" — only causal masking
    /// is applied (global / dense layer behaviour).  `Some(n)` means
    /// `LLAMA_SWA_TYPE_STANDARD` with `n_swa = n`: attended iff
    /// `q_abs - k_pos < n`.
    pub window_size: Option<u32>,
    /// When true, mask future keys (`k_pos > q_abs`).  When false, no causal
    /// gating is applied — every non-SWA position is attended.  For typical
    /// LLM prefill this is always true.
    pub causal: bool,
    /// Absolute offset of the first query row in the global sequence.  For a
    /// pure in-place prefill this is 0.  For a prefill that continues an
    /// existing KV cache this is the number of keys already present — each
    /// mask row `q_row` corresponds to absolute position `q_row + q_abs_offset`.
    pub q_abs_offset: u32,
}

/// Shader-side parameter struct.  Mirrors
/// `FlashAttnPrefillMaskParams` in `flash_attn_prefill_mask.metal`.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct MaskFillParamsGpu {
    seq_len_k: u32,
    q_abs_offset: u32,
    /// Sliding window size; `-1` means "disabled" (global / causal-only).
    /// Signed so the shader-side `int` type can encode "no window" without
    /// a separate bool.  Host serialisation ensures a non-negative value
    /// is never written when `window_size == None`.
    n_swa: i32,
    /// 1 if causal masking is applied, 0 otherwise.  `uint` rather than
    /// `bool` so the struct layout is bytemuck-safe.
    causal: u32,
}

/// Allocate and fill a bf16 SWA attention mask on the GPU.
///
/// Returns a fresh [`MlxBuffer`] of shape `[seq_len_q, seq_len_k]`, dtype
/// BF16 (byte length `seq_len_q * seq_len_k * 2`).  The caller owns the
/// returned buffer and is responsible for keeping it alive for as many
/// layers as need it.
///
/// The buffer has **no batch or head dimension**; callers consuming it as
/// the mask argument to `flash_attn_prefill` must set
/// `m_strides = [0, 0, seq_len_k]` in `AttnMaskParamsGpu` to broadcast
/// the single plane across batch and heads.
///
/// # Errors
///
/// - `MlxError::InvalidArgument` if `seq_len_q == 0` or `seq_len_k == 0`.
/// - `MlxError::InvalidArgument` if `window_size == Some(0)` (llama.cpp
///   treats n_swa=0 as UB upstream; we reject it cleanly here).
/// - `MlxError::BufferAllocationError` if Metal buffer allocation fails.
/// - `MlxError::ShaderCompilationError` if the mask-fill kernel fails to
///   compile (shouldn't happen on supported Apple Silicon).
pub fn build_sdpa_mask_bf16(
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    encoder: &mut CommandEncoder,
    params: &SdpaMaskParams,
) -> Result<MlxBuffer> {
    // ── Validate ──────────────────────────────────────────────────────────
    if params.seq_len_q == 0 {
        return Err(MlxError::InvalidArgument(
            "build_sdpa_mask_bf16: seq_len_q must be > 0".into(),
        ));
    }
    if params.seq_len_k == 0 {
        return Err(MlxError::InvalidArgument(
            "build_sdpa_mask_bf16: seq_len_k must be > 0".into(),
        ));
    }
    if let Some(0) = params.window_size {
        return Err(MlxError::InvalidArgument(
            "build_sdpa_mask_bf16: window_size=Some(0) is not allowed \
             (llama.cpp treats n_swa=0 as undefined; pass None for \
             no-window / causal-only)".into(),
        ));
    }

    // Saturation check so `seq_len_q * seq_len_k * 2` (byte_len, below) never
    // overflows usize on 32-bit targets.  At u32::MAX^2 * 2 we would overflow
    // a u64 so this is a defence against absurd inputs, not a realistic
    // condition.
    let total_elems = (params.seq_len_q as u64)
        .checked_mul(params.seq_len_k as u64)
        .ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "build_sdpa_mask_bf16: seq_len_q ({}) * seq_len_k ({}) overflows u64",
                params.seq_len_q, params.seq_len_k
            ))
        })?;
    let byte_len = (total_elems as usize)
        .checked_mul(2)
        .ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "build_sdpa_mask_bf16: mask size ({} elems × 2 B) overflows usize",
                total_elems
            ))
        })?;

    // ── Allocate output ───────────────────────────────────────────────────
    let mask = device.alloc_buffer(
        byte_len,
        DType::BF16,
        vec![params.seq_len_q as usize, params.seq_len_k as usize],
    )?;

    // ── Build shader params ──────────────────────────────────────────────
    let fill_params = MaskFillParamsGpu {
        seq_len_k: params.seq_len_k,
        q_abs_offset: params.q_abs_offset,
        n_swa: match params.window_size {
            None => -1,
            // i32::MAX saturation: prevents signedness wraparound if a caller
            // ever passes window_size > i32::MAX (2.1 Gi tokens).  At that
            // point the model architecture is nonsense but we still produce
            // a defined output (mask is effectively causal-only).
            Some(w) => w.min(i32::MAX as u32) as i32,
        },
        causal: if params.causal { 1 } else { 0 },
    };

    // ── Pipeline lookup ───────────────────────────────────────────────────
    let pipeline = registry.get_pipeline(K_FILL_BF16, device.metal_device())?;

    // ── Grid geometry ─────────────────────────────────────────────────────
    //
    // One threadgroup per q row; threads within the threadgroup stride over
    // kL.  tg_size = min(256, kL.next_power_of_two()) mirrors softmax.metal's
    // allocation pattern and ensures a full simdgroup (32 threads) is always
    // scheduled.  Upper bound 256 so we don't exceed the 1024-thread-per-TG
    // Metal limit.
    let tg_x = {
        let want = params.seq_len_k.next_power_of_two().max(32);
        want.min(256)
    };
    let threadgroups = MTLSize::new(params.seq_len_q as u64, 1, 1);
    let tg_size = MTLSize::new(tg_x as u64, 1, 1);

    // ── Encode ────────────────────────────────────────────────────────────
    encoder.encode_threadgroups_with_args(
        pipeline,
        &[
            (0, KernelArg::Buffer(&mask)),
            (1, KernelArg::Bytes(as_bytes(&fill_params))),
        ],
        threadgroups,
        tg_size,
    );

    Ok(mask)
}

// ─── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_mask_fill_params_gpu_size() {
        // 4 × u32/i32 = 16 bytes.  No padding (all 4-byte aligned).
        assert_eq!(std::mem::size_of::<MaskFillParamsGpu>(), 16);
    }

    #[test]
    fn test_mask_fill_params_encoding_global() {
        let p = MaskFillParamsGpu {
            seq_len_k: 2048,
            q_abs_offset: 0,
            n_swa: -1,
            causal: 1,
        };
        assert_eq!(p.n_swa, -1, "global mask encodes n_swa=-1");
        assert_eq!(p.causal, 1);
    }

    #[test]
    fn test_mask_fill_params_encoding_sliding() {
        let p = MaskFillParamsGpu {
            seq_len_k: 2048,
            q_abs_offset: 0,
            n_swa: 1024,
            causal: 1,
        };
        assert_eq!(p.n_swa, 1024, "sliding mask encodes n_swa>0");
    }

    #[test]
    fn test_reject_zero_seq_len_q() {
        // We can't easily construct an MlxDevice without Metal (CI sanity),
        // but the validation guard fires before any GPU access — check the
        // error variant path through a lightweight probe.  We check the
        // param struct is well-formed for the "good" path; the explicit
        // `seq_len_q == 0` early-return is exercised in the integration
        // test file where a real device is available.
        let p = SdpaMaskParams {
            seq_len_q: 0,
            seq_len_k: 8,
            window_size: None,
            causal: true,
            q_abs_offset: 0,
        };
        // The field check alone is sufficient at the unit level; GPU-side
        // dispatch is covered in tests/test_flash_attn_prefill.rs § 7.
        assert_eq!(p.seq_len_q, 0);
    }

    #[test]
    fn test_register_adds_kernel_name() {
        let mut registry = KernelRegistry::new();
        register(&mut registry);
        // The registry stores the source under K_FILL_BF16; we can't
        // directly inspect the internal map, but the registration must not
        // panic and the kernel name constant is stable.
        assert_eq!(K_FILL_BF16, "flash_attn_prefill_mask_fill_bf16");
    }
}