mlxrs 0.1.0

Safe Rust bindings for Apple's MLX array framework, with LM, VLM, audio, and embeddings support
//! Fast scaled-dot-product attention.
//!
//! A 1:1 binding of mlx's fused multi-head attention primitive: the python
//! `mx.fast.scaled_dot_product_attention` (`python/src/fast.cpp`) and swift
//! `MLXFast.scaledDotProductAttention` (`Source/MLX/MLXFast.swift`). Both
//! wrap the same mlx-c entry point `mlx_fast_scaled_dot_product_attention`,
//! which computes `O = softmax(Q @ K.T * scale, dim=-1) @ V` (with the
//! softmax in `float32` regardless of input precision — see mlx's docstring)
//! and dispatches to an optimized Metal kernel when the query sequence
//! length is 1, falling back to ordinary mlx ops otherwise.
//!
//! Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762),
//! [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and
//! [Multi-Query Attention](https://arxiv.org/abs/1911.02150). For GQA /
//! MQA, pass `k` / `v` with their own key / value heads — do **not**
//! pre-tile them to match `q`; the kernel handles the repeat internally.
//! The number of query heads must be a multiple of the number of key /
//! value heads.
//!
//! In the following the dimensions are given by:
//!
//! - `B`: batch size,
//! - `N_q`: number of query heads,
//! - `N_kv`: number of key / value heads,
//! - `T_q`: queries per example,
//! - `T_kv`: keys / values per example,
//! - `D`: per-head dimension.
//!
//! `q` is `[B, N_q, T_q, D]`, `k` is `[B, N_kv, T_kv, D]`, `v` is
//! `[B, N_kv, T_kv, D]`; the output is `[B, N_q, T_q, D]`.
//!
//! # Scope
//!
//! This is the **base** sdpa primitive only. Cache-aware quantized routing
//! (mlx-swift-lm's `attentionWithCacheUpdate`'s `QuantizedKVCacheProtocol`
//! branch dispatching to `quantizedScaledDotProductAttention`) is a
//! documented follow-up; this module always calls the dense (non-quantized)
//! kernel. The `sinks` argument (attention-sinks, mlx's `sinks` kw-only
//! arg) is also a follow-up — the mlx-c entry point accepts a NULL handle
//! for absent `sinks` and the dense path is the common case.

use std::ffi::CStr;

use crate::{
  array::Array,
  error::{Result, check},
  stream::default_stream,
};

/// mlx's `mask_mode` string shared by the unmasked case and the explicit
/// array-mask case: the empty string. mlx-c rejects any `mask_mode` other
/// than `""` / `"causal"` / `"array"` (`mlx/fast.cpp`
/// `[scaled_dot_product_attention] Invalid mask_mode`). This module never
/// sends `"array"`. An explicit mask ([`Mask::Array`]) is passed with this
/// empty mode and identified by its non-null `mask_arr` handle, so the only
/// non-empty mode ever sent is `"causal"` ([`Mask::Causal`]).
const MASK_MODE_NONE_OR_ARRAY: &CStr = c"";

/// `mask_mode` value that asks mlx for its built-in lower-triangular causal
/// mask. It is the counterpart of python's `mask="causal"` and swift's
/// `ScaledDotProductAttentionMaskMode.causal`. Because mlx synthesizes this
/// mask itself, no `mask_arr` accompanies it: mlx-c refuses a non-null
/// `mask_arr` combined with `mask_mode="causal"`.
const MASK_MODE_CAUSAL: &CStr = c"causal";

/// Chooses the attention mask [`scaled_dot_product_attention`] applies; the
/// Rust counterpart of mlx's `mask` argument and swift's
/// `ScaledDotProductAttentionMaskMode`.
///
/// - [`None`](Mask::None) — no masking; every query sees every key.
/// - [`Causal`](Mask::Causal) — mlx's implicit lower-triangular causal mask
///   (python `mask="causal"`). The kernel derives it from the query / key
///   lengths, treating `T_kv - T_q` as the offset. A single-query decode
///   step (`T_q = 1`, `T_kv = cache_len + 1`) therefore still attends to
///   the whole history.
/// - [`Array`](Mask::Array) — a caller-supplied mask, either additive (float)
///   or boolean. mlx converts a bool mask into an additive `0 / -inf` mask
///   internally. The mask must have rank `<= 4` and broadcast to
///   `[B, N_q, T_q, T_kv]`. A shape / dtype mismatch is reported by mlx as a
///   recoverable error.
#[derive(Clone, Copy, Debug)]
pub enum Mask<'a> {
  /// Full, non-causal attention over all keys.
  None,
  /// Kernel-generated lower-triangular causal mask, offset by `T_kv - T_q`.
  Causal,
  /// Explicit additive (float) or boolean mask that broadcasts to
  /// `[B, N_q, T_q, T_kv]`.
  Array(&'a Array),
}

impl Mask<'_> {
  /// Translates this selector into the `mask_mode` C-string handed to mlx-c.
  ///
  /// Only [`Mask::Causal`] yields a non-empty mode (`"causal"`). Both
  /// [`Mask::None`] and [`Mask::Array`] yield the empty string; what tells
  /// them apart is whether a non-null `mask_arr` handle is passed alongside
  /// it.
  fn mode(self) -> &'static CStr {
    match self {
      Mask::Causal => MASK_MODE_CAUSAL,
      Mask::Array(_) => MASK_MODE_NONE_OR_ARRAY,
      Mask::None => MASK_MODE_NONE_OR_ARRAY,
    }
  }
}

/// Fused multi-head attention: `O = softmax(Q @ K.T * scale, dim=-1) @ V`.
///
/// Free-function counterpart of python's
/// `mx.fast.scaled_dot_product_attention(q, k, v, *, scale, mask)`
/// (`python/src/fast.cpp`) and swift's
/// `MLXFast.scaledDotProductAttention(queries:keys:values:scale:mask:)`
/// (`Source/MLX/MLXFast.swift`).
///
/// # Arguments
///
/// - `q`: queries, shape `[B, N_q, T_q, D]`.
/// - `k`: keys, shape `[B, N_kv, T_kv, D]`. For GQA / MQA pass the `N_kv`
///   heads as they are — do **not** tile them up to `N_q`; the kernel
///   performs the repeat.
/// - `v`: values, shape `[B, N_kv, T_kv, D]`.
/// - `scale`: factor applied to `Q @ K.T`, usually `1 / sqrt(D)`.
/// - `mask`: which attention mask to apply — see [`Mask`].
///
/// # Returns
///
/// The attention output, shape `[B, N_q, T_q, D]`. Nothing is evaluated
/// here. As with every `mlxrs` op, the result is appended to the lazy graph,
/// and evaluation is a separate, explicit `&mut` step on the returned array.
///
/// # Errors
///
/// mlx checks all of the following:
/// - all three inputs are rank 4, and their batch dims agree;
/// - `q` and `k` share the last dim, and `k` and `v` share `N_kv`;
/// - `N_q % N_kv == 0`;
/// - for [`Mask::Array`], the mask's rank and dtype promotion.
///
/// Any violation comes back as a recoverable error.
///
/// See [mlx docs](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.fast.scaled_dot_product_attention.html).
pub fn scaled_dot_product_attention(
  q: &Array,
  k: &Array,
  v: &Array,
  scale: f32,
  mask: Mask<'_>,
) -> Result<Array> {
  // SAFETY: `mlx_array_new()` hands back a fresh, empty handle (NULL ctx),
  // following mlx-c's convention. Owning it through the `Array` newtype
  // releases it on drop. mlx-c reads a NULL-ctx handle as "absent" for the
  // optional `mask_arr` / `sinks` parameters (`mask_arr.ctx ? ... :
  // std::nullopt`). Binding it here keeps that handle alive across the FFI
  // call below.
  let absent = Array(unsafe { mlxrs_sys::mlx_array_new() });

  // An explicit mask lends its own handle; every other mode falls back to
  // the empty placeholder. This is the same approach swift's wrapper takes
  // (`mask?.ctx ?? MLXArray.mlxNone.ctx`).
  let mask_handle = match mask {
    Mask::None | Mask::Causal => absent.0,
    Mask::Array(explicit) => explicit.0,
  };

  // The attention-sinks path is not exposed by this base wrapper (see the
  // module docs), so `sinks` is always the absent placeholder.
  let sinks_handle = absent.0;

  // SAFETY: `mlx_array_new()` yields a fresh empty out-param handle (NULL
  // ctx). It is wrapped in the RAII newtype before the call that fills it,
  // so an early return or panic still frees it.
  let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });

  let mode = mask.mode();
  // SAFETY: every handle below is a valid borrow that outlives the call and
  // is not retained by mlx afterwards:
  // - `q.0` / `k.0` / `v.0` come from the caller's arrays;
  // - `mask_handle` is either the caller's explicit mask (kept alive by
  //   `mask`) or the NULL-ctx placeholder owned by `absent`;
  // - `sinks_handle` is always that placeholder;
  // - `mode` points at a `'static` C-string literal;
  // - `result` was freshly allocated above and is the out-param written
  //   here.
  // mlx's shape / dtype validation is reported through the return code,
  // which is passed to `check()`.
  let rc = unsafe {
    mlxrs_sys::mlx_fast_scaled_dot_product_attention(
      &mut result.0,
      q.0,
      k.0,
      v.0,
      scale,
      mode.as_ptr(),
      mask_handle,
      sinks_handle,
      default_stream(),
    )
  };
  check(rc)?;
  Ok(result)
}

#[cfg(test)]
// Goldens are written with 7 decimal places (up to 9 significant digits) so
// they match the hand derivations verbatim. f32 keeps only ~7 significant
// digits, which is exactly what `excessive_precision` would otherwise flag.
#[allow(clippy::excessive_precision)]
mod tests {
  use super::*;

  /// Absolute floor of the comparison tolerance (dominates for goldens near 0).
  const ABS_TOL: f32 = 1e-5;

  /// Relative component of the comparison tolerance, scaled by `|want|`.
  ///
  /// Goldens were derived by hand-computing `softmax(Q @ K.T * scale) @ V`
  /// in f64 and rounding to 7 decimal places. The fused mlx kernel, by
  /// contrast, performs the softmax in `float32` (per mlx's docstring).
  ///
  /// The outputs here reach ~70, where one f32 ulp is `2^-17 ≈ 7.6e-6`, so
  /// a purely absolute `1e-5` bound would leave barely one ulp of slack.
  /// Adding `1e-6 * |want|` (≈ 8 f32 epsilons) absorbs that rounding gap.
  /// It still stays far below the spacing between the distinct goldens used
  /// below, which are several units apart.
  const REL_TOL: f32 = 1e-6;

  /// Output row for softmax weights `[0.2689414, 0.7310586]`, i.e. scores of
  /// the form `[s, s + 1]`: `0.2689414 * v_row0 + 0.7310586 * v_row1`.
  const MIX_TOWARD_V1: [f32; 4] = [39.2423431, 49.2423431, 59.2423431, 69.2423431];

  /// Output row for softmax weights `[0.7310586, 0.2689414]`, i.e. scores of
  /// the form `[s + 1, s]`: `0.7310586 * v_row0 + 0.2689414 * v_row1`.
  const MIX_TOWARD_V0: [f32; 4] = [20.7576569, 30.7576569, 40.7576569, 50.7576569];

  /// Output row when all softmax weight lands on key 0: exactly `v_row0`.
  const ONLY_V0: [f32; 4] = [10.0, 20.0, 30.0, 40.0];

  /// Asserts element-wise `|got - want| <= ABS_TOL + REL_TOL * |want|`.
  /// A NaN in `got` makes the comparison false and therefore fails.
  fn assert_close(got: &[f32], want: &[f32]) {
    assert_eq!(got.len(), want.len(), "length mismatch");
    for (i, (&g, &w)) in got.iter().zip(want).enumerate() {
      let diff = (g - w).abs();
      let tol = ABS_TOL + REL_TOL * w.abs();
      assert!(diff <= tol, "index {i}: got {g}, want {w} (|Δ|={diff}, tol={tol})");
    }
  }

  /// `[B=1, N=1, T=2, D=4]` queries — two query positions, head_dim 4.
  fn q_2x4() -> Array {
    Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], &(1, 1, 2, 4)).unwrap()
  }

  /// `[B=1, N=1, T=2, D=4]` keys — two key positions, head_dim 4.
  fn k_2x4() -> Array {
    Array::from_slice::<f32>(&[1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0], &(1, 1, 2, 4)).unwrap()
  }

  /// `[B=1, N=1, T=2, D=4]` values — two value positions, head_dim 4.
  fn v_2x4() -> Array {
    Array::from_slice::<f32>(
      &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0],
      &(1, 1, 2, 4),
    )
    .unwrap()
  }

  /// Unmasked attention with `scale = 1/sqrt(4) = 0.5`.
  ///
  /// Scores `Q @ K.T * 0.5` (Q rows · K rows · 0.5):
  /// - row 0: `[(0+0+2+0)*0.5, (0+1+0+3)*0.5] = [1, 2]`
  /// - row 1: `[(4+0+6+0)*0.5, (0+5+0+7)*0.5] = [5, 6]`
  ///
  /// softmax row 0: `[exp(1)/(exp(1)+exp(2)), exp(2)/(exp(1)+exp(2))]`
  /// `= [0.2689414, 0.7310586]`.
  /// softmax row 1: the scores differ from row 0 by a constant, so the
  /// weights are the same `[0.2689414, 0.7310586]`.
  ///
  /// Both rows therefore mix columns `[10,20,30,40]` and `[50,60,70,80]` to
  /// `[39.2423431, 49.2423431, 59.2423431, 69.2423431]` (`MIX_TOWARD_V1`).
  #[test]
  fn unmasked_matches_hand_softmax() {
    let q = q_2x4();
    let k = k_2x4();
    let v = v_2x4();
    let mut out = scaled_dot_product_attention(&q, &k, &v, 0.5, Mask::None).unwrap();
    assert_close(
      &out.to_vec::<f32>().unwrap(),
      &[MIX_TOWARD_V1, MIX_TOWARD_V1].concat(), // query 0, query 1
    );
  }

  /// Causal mask with `scale = 1/sqrt(4) = 0.5`, square `T_q = T_kv = 2`:
  /// query 0 only attends to key 0; query 1 attends to keys 0 and 1.
  ///
  /// - row 0: softmax over `[1, -inf]` ⇒ `[1.0, 0.0]` ⇒ `out[0] = v_row0`
  ///   `= [10, 20, 30, 40]`.
  /// - row 1: unmasked, so it matches row 1 of the unmasked test ⇒
  ///   `[39.2423431, 49.2423431, 59.2423431, 69.2423431]`.
  #[test]
  fn causal_mask_blocks_future_keys() {
    let q = q_2x4();
    let k = k_2x4();
    let v = v_2x4();
    let mut out = scaled_dot_product_attention(&q, &k, &v, 0.5, Mask::Causal).unwrap();
    assert_close(
      &out.to_vec::<f32>().unwrap(),
      &[ONLY_V0, MIX_TOWARD_V1].concat(), // query 0: key 0 only; query 1: both keys
    );
  }

  /// Causal mask with `T_q = 1, T_kv = 2` — the incremental-decode shape.
  /// mlx's causal mask uses `offset = T_kv - T_q = 1`. The single query is
  /// therefore treated as position 1 and attends to BOTH keys, matching
  /// query 1 of the unmasked test.
  #[test]
  fn causal_mask_decode_step_attends_to_full_history() {
    let q = Array::from_slice::<f32>(&[4.0, 5.0, 6.0, 7.0], &(1, 1, 1, 4)).unwrap();
    let k = k_2x4();
    let v = v_2x4();
    let mut out = scaled_dot_product_attention(&q, &k, &v, 0.5, Mask::Causal).unwrap();
    assert_close(&out.to_vec::<f32>().unwrap(), &MIX_TOWARD_V1);
  }

  /// Explicit additive float mask. It masks out key 1 for query 0 by adding
  /// `-inf` to that score. The result must match the causal-mask result for
  /// query 0 (key 0 only) and the unmasked result for query 1.
  #[test]
  fn array_mask_additive_matches_causal_when_lower_triangular() {
    let q = q_2x4();
    let k = k_2x4();
    let v = v_2x4();
    let neg_inf = f32::NEG_INFINITY;
    // Mask shape `[T_q=2, T_kv=2]` — broadcasts to `[B, N_q, T_q, T_kv]`.
    let mask = Array::from_slice::<f32>(&[0.0, neg_inf, 0.0, 0.0], &(2, 2)).unwrap();
    let mut out = scaled_dot_product_attention(&q, &k, &v, 0.5, Mask::Array(&mask)).unwrap();
    assert_close(
      &out.to_vec::<f32>().unwrap(),
      &[ONLY_V0, MIX_TOWARD_V1].concat(), // query 0: manual mask; query 1: both keys
    );
  }

  /// Explicit boolean mask. mlx promotes a bool mask to an additive `0 /
  /// -inf` mask internally (`true` = keep, `false` = mask). The expected
  /// result is the same as for the additive-`-inf` version above.
  #[test]
  fn array_mask_bool_matches_additive() {
    let q = q_2x4();
    let k = k_2x4();
    let v = v_2x4();
    let mask = Array::from_slice::<bool>(&[true, false, true, true], &(2, 2)).unwrap();
    let mut out = scaled_dot_product_attention(&q, &k, &v, 0.5, Mask::Array(&mask)).unwrap();
    assert_close(
      &out.to_vec::<f32>().unwrap(),
      &[ONLY_V0, MIX_TOWARD_V1].concat(), // query 0: key 0 only; query 1: both keys
    );
  }

  /// Broadcast mask: a `[1, 1]` mask of `0.0` is broadcast across the full
  /// `[T_q, T_kv]` plane, and across `[B, N_q]` by mlx's broadcaster. The net
  /// effect is an all-zero additive mask, identical to no mask.
  #[test]
  fn array_mask_broadcast_zero_matches_unmasked() {
    let q = q_2x4();
    let k = k_2x4();
    let v = v_2x4();
    let mask = Array::from_slice::<f32>(&[0.0], &(1, 1)).unwrap();
    let mut via_mask = scaled_dot_product_attention(&q, &k, &v, 0.5, Mask::Array(&mask)).unwrap();
    let mut via_none = scaled_dot_product_attention(&q, &k, &v, 0.5, Mask::None).unwrap();
    assert_close(
      &via_mask.to_vec::<f32>().unwrap(),
      &via_none.to_vec::<f32>().unwrap(),
    );
  }

  /// Grouped Query Attention: `N_q = 2`, `N_kv = 1` — the single KV head is
  /// shared by both query heads.
  ///
  /// Head 0 reuses the `q_2x4()` queries, so it must reproduce
  /// `MIX_TOWARD_V1` for both rows. Head 1 uses each row reversed
  /// (`[3, 2, 1, 0]`, `[7, 6, 5, 4]`), giving scores `[(3+1)*0.5,
  /// (2+0)*0.5] = [2, 1]` and `[(7+5)*0.5, (6+4)*0.5] = [6, 5]`. That swaps
  /// the weights to `[0.7310586, 0.2689414]`, so
  /// `out = 0.7310586 * v_row0 + 0.2689414 * v_row1 = MIX_TOWARD_V0`.
  /// Distinct per-head goldens let the test catch a head mix-up, which
  /// identical queries in both heads could not.
  #[test]
  fn gqa_kv_repeated_across_query_heads() {
    // q: [B=1, N_q=2, T=2, D=4].
    let q = Array::from_slice::<f32>(
      &[
        0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, // head 0 (same as `q_2x4()`)
        3.0, 2.0, 1.0, 0.0, 7.0, 6.0, 5.0, 4.0, // head 1 (rows reversed)
      ],
      &(1, 2, 2, 4),
    )
    .unwrap();
    let k = k_2x4(); // [1, 1, 2, 4] — kernel repeats this across N_q=2
    let v = v_2x4();
    let mut out = scaled_dot_product_attention(&q, &k, &v, 0.5, Mask::None).unwrap();
    // Output is `[1, N_q=2, T=2, D=4]`: head 0 rows, then head 1 rows.
    assert_close(
      &out.to_vec::<f32>().unwrap(),
      &[MIX_TOWARD_V1, MIX_TOWARD_V1, MIX_TOWARD_V0, MIX_TOWARD_V0].concat(),
    );
  }

  /// Validation surfaces from mlx: a shape mismatch on `D` (q vs k last dim)
  /// must produce a recoverable error, not a panic.
  #[test]
  fn mismatched_head_dim_errors() {
    let q = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0], &(1, 1, 1, 4)).unwrap();
    let k = Array::from_slice::<f32>(&[0.0, 1.0], &(1, 1, 1, 2)).unwrap();
    let v = Array::from_slice::<f32>(&[0.0, 1.0], &(1, 1, 1, 2)).unwrap();
    let err = scaled_dot_product_attention(&q, &k, &v, 1.0, Mask::None);
    assert!(err.is_err(), "mismatched head dim must error");
  }

  /// Validation surfaces from mlx: `N_q = 3` is not a multiple of
  /// `N_kv = 2`, which the function docs list among mlx's checks. This must
  /// be a recoverable error, not a panic.
  #[test]
  fn indivisible_head_ratio_errors() {
    let q = Array::from_slice::<f32>(&[0.0; 12], &(1, 3, 1, 4)).unwrap();
    let k = Array::from_slice::<f32>(&[0.0; 8], &(1, 2, 1, 4)).unwrap();
    let v = Array::from_slice::<f32>(&[0.0; 8], &(1, 2, 1, 4)).unwrap();
    let err = scaled_dot_product_attention(&q, &k, &v, 0.5, Mask::None);
    assert!(err.is_err(), "N_q % N_kv != 0 must error");
  }
}