use std::ffi::CStr;
use crate::{
array::Array,
error::{Result, check},
stream::default_stream,
};
/// Mode string sent for both [`Mask::None`] and [`Mask::Array`]. The two cases
/// differ only in which array handle is passed alongside it.
const MASK_MODE_NONE_OR_ARRAY: &CStr = c"";
/// Mode string sent for [`Mask::Causal`].
const MASK_MODE_CAUSAL: &CStr = c"causal";
/// Selects the masking requested from [`scaled_dot_product_attention`].
///
/// The type is `Copy`; the `Array` variant only borrows the mask for `'a`.
#[derive(Debug, Clone, Copy)]
pub enum Mask<'a> {
/// No mask: the empty mode string is sent together with a placeholder
/// handle obtained from `mlx_array_new()`.
None,
/// Causal masking, requested through the `"causal"` mode string. No
/// caller-supplied array is forwarded.
Causal,
/// Caller-supplied mask array, forwarded together with the empty mode
/// string. The tests in this file exercise `f32` additive masks and `bool`
/// masks, including a `(1, 1)` mask that is expected to broadcast.
Array(&'a Array),
}
impl Mask<'_> {
    /// Picks the mode string that accompanies this mask in the attention call.
    ///
    /// Only [`Mask::Causal`] has a dedicated string; [`Mask::None`] and
    /// [`Mask::Array`] share the empty one.
    fn mode(self) -> &'static CStr {
        // Kept exhaustive on purpose: a new variant must pick its mode
        // explicitly instead of silently falling into a default.
        match self {
            Self::Causal => MASK_MODE_CAUSAL,
            Self::Array(_) | Self::None => MASK_MODE_NONE_OR_ARRAY,
        }
    }
}
/// Computes scaled dot-product attention by forwarding to
/// `mlx_fast_scaled_dot_product_attention` on the stream returned by
/// `default_stream()`.
///
/// # Arguments
///
/// * `q`, `k`, `v` - Query, key and value arrays; their raw handles are passed
///   through untouched. The tests in this file use 4-D inputs such as
///   `(1, 1, 2, 4)`; the authoritative shape rules live in the MLX backend.
/// * `scale` - Forwarded unchanged to the backend call.
/// * `mask` - Chooses the mode string and, for [`Mask::Array`], the mask
///   array handed to the backend.
///
/// The backend's trailing array argument (sinks) is not exposed by this
/// wrapper; a placeholder handle is always passed for it.
///
/// # Errors
///
/// Returns whatever error `check` produces for the status code reported by
/// the backend call (the tests cover mismatched head dimensions).
pub fn scaled_dot_product_attention(
    q: &Array,
    k: &Array,
    v: &Array,
    scale: f32,
    mask: Mask<'_>,
) -> Result<Array> {
    // Placeholder handle used wherever no array is supplied. It is wrapped in
    // `Array` so it is owned by this scope like any other array.
    // SAFETY: `mlx_array_new` takes no arguments, so there are no
    // preconditions to uphold on this side of the call.
    let absent = Array(unsafe { mlxrs_sys::mlx_array_new() });

    let mask_handle = match mask {
        Mask::None | Mask::Causal => absent.0,
        Mask::Array(user_mask) => user_mask.0,
    };

    // Destination the backend writes the attention output into.
    // SAFETY: same as above, the constructor has no inputs.
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });

    // SAFETY: every array handle comes from an `Array` that is borrowed or
    // owned for the whole call, `result.0` is a valid destination, and the
    // mode pointer comes from a `'static` NUL-terminated `CStr`.
    // NOTE(review): assumes `default_stream()` yields a valid stream handle.
    let status = unsafe {
        mlxrs_sys::mlx_fast_scaled_dot_product_attention(
            &mut result.0,
            q.0,
            k.0,
            v.0,
            scale,
            mask.mode().as_ptr(),
            mask_handle,
            absent.0, // sinks: never supplied by this wrapper
            default_stream(),
        )
    };
    check(status)?;

    Ok(result)
}
#[cfg(test)]
// Golden values are spelled with more digits than `f32` can represent.
#[allow(clippy::excessive_precision)]
mod tests {
    use super::*;

    /// Largest absolute per-element difference [`assert_close`] tolerates.
    const TOL: f32 = 1e-5;

    /// Scale handed to every attention call whose numeric output is checked.
    const SCALE: f32 = 0.5;

    /// Output row expected when a query attends to both keys of the fixture.
    const BOTH_KEYS_ROW: [f32; 4] = [39.2423431, 49.2423431, 59.2423431, 69.2423431];

    /// Output row expected when a query attends to the first key only; equals
    /// the first row of [`values_2x4`].
    const FIRST_KEY_ROW: [f32; 4] = [10.0, 20.0, 30.0, 40.0];

    /// Panics unless `got` and `want` have the same length and agree
    /// element-wise within [`TOL`].
    fn assert_close(got: &[f32], want: &[f32]) {
        assert_eq!(got.len(), want.len(), "length mismatch");
        for (i, (&g, &w)) in got.iter().zip(want.iter()).enumerate() {
            let in_tolerance = (g - w).abs() <= TOL;
            assert!(
                in_tolerance,
                "index {i}: got {g}, want {w} (|Δ|={})",
                (g - w).abs()
            );
        }
    }

    /// Query fixture: the values 0 through 7 with shape `(1, 1, 2, 4)`.
    fn queries_2x4() -> Array {
        Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], &(1, 1, 2, 4)).unwrap()
    }

    /// Key fixture: two alternating 0/1 rows with shape `(1, 1, 2, 4)`.
    fn keys_2x4() -> Array {
        Array::from_slice::<f32>(&[1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0], &(1, 1, 2, 4)).unwrap()
    }

    /// Value fixture: 10, 20, ..., 80 with shape `(1, 1, 2, 4)`.
    fn values_2x4() -> Array {
        Array::from_slice::<f32>(
            &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0],
            &(1, 1, 2, 4),
        )
        .unwrap()
    }

    /// Runs attention with [`SCALE`] and unwraps the result.
    fn attend(q: &Array, k: &Array, v: &Array, mask: Mask<'_>) -> Array {
        scaled_dot_product_attention(q, k, v, SCALE, mask).unwrap()
    }

    #[test]
    fn unmasked_matches_hand_softmax() {
        let (q, k, v) = (queries_2x4(), keys_2x4(), values_2x4());
        let mut out = attend(&q, &k, &v, Mask::None);
        let expected = [BOTH_KEYS_ROW, BOTH_KEYS_ROW].concat();
        assert_close(&out.to_vec::<f32>().unwrap(), &expected);
    }

    #[test]
    fn causal_mask_blocks_future_keys() {
        let (q, k, v) = (queries_2x4(), keys_2x4(), values_2x4());
        let mut out = attend(&q, &k, &v, Mask::Causal);
        // First query row sees key 0 only; second row sees both keys.
        let expected = [FIRST_KEY_ROW, BOTH_KEYS_ROW].concat();
        assert_close(&out.to_vec::<f32>().unwrap(), &expected);
    }

    #[test]
    fn causal_mask_decode_step_attends_to_full_history() {
        // Single query row against two keys.
        let q = Array::from_slice::<f32>(&[4.0, 5.0, 6.0, 7.0], &(1, 1, 1, 4)).unwrap();
        let (k, v) = (keys_2x4(), values_2x4());
        let mut out = attend(&q, &k, &v, Mask::Causal);
        assert_close(&out.to_vec::<f32>().unwrap(), &BOTH_KEYS_ROW);
    }

    #[test]
    fn array_mask_additive_matches_causal_when_lower_triangular() {
        let (q, k, v) = (queries_2x4(), keys_2x4(), values_2x4());
        // Only the (query 0, key 1) entry is masked out.
        let mask =
            Array::from_slice::<f32>(&[0.0, f32::NEG_INFINITY, 0.0, 0.0], &(2, 2)).unwrap();
        let mut out = attend(&q, &k, &v, Mask::Array(&mask));
        let expected = [FIRST_KEY_ROW, BOTH_KEYS_ROW].concat();
        assert_close(&out.to_vec::<f32>().unwrap(), &expected);
    }

    #[test]
    fn array_mask_bool_matches_additive() {
        let (q, k, v) = (queries_2x4(), keys_2x4(), values_2x4());
        // Same pattern as the additive test, expressed with booleans.
        let mask = Array::from_slice::<bool>(&[true, false, true, true], &(2, 2)).unwrap();
        let mut out = attend(&q, &k, &v, Mask::Array(&mask));
        let expected = [FIRST_KEY_ROW, BOTH_KEYS_ROW].concat();
        assert_close(&out.to_vec::<f32>().unwrap(), &expected);
    }

    #[test]
    fn array_mask_broadcast_zero_matches_unmasked() {
        let (q, k, v) = (queries_2x4(), keys_2x4(), values_2x4());
        let mask = Array::from_slice::<f32>(&[0.0], &(1, 1)).unwrap();
        let mut via_mask = attend(&q, &k, &v, Mask::Array(&mask));
        let mut via_none = attend(&q, &k, &v, Mask::None);
        assert_close(
            &via_mask.to_vec::<f32>().unwrap(),
            &via_none.to_vec::<f32>().unwrap(),
        );
    }

    #[test]
    fn gqa_kv_repeated_across_query_heads() {
        // Two query heads holding the same data as the single-head fixture,
        // attending over one key/value head.
        let q = Array::from_slice::<f32>(
            &[
                0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, // head 0
                0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, // head 1
            ],
            &(1, 2, 2, 4),
        )
        .unwrap();
        let (k, v) = (keys_2x4(), values_2x4());
        let mut out = attend(&q, &k, &v, Mask::None);
        let expected = [BOTH_KEYS_ROW; 4].concat();
        assert_close(&out.to_vec::<f32>().unwrap(), &expected);
    }

    #[test]
    fn mismatched_head_dim_errors() {
        // Last dimension is 4 for the query but 2 for the key and value.
        let q = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0], &(1, 1, 1, 4)).unwrap();
        let k = Array::from_slice::<f32>(&[0.0, 1.0], &(1, 1, 1, 2)).unwrap();
        let v = Array::from_slice::<f32>(&[0.0, 1.0], &(1, 1, 1, 2)).unwrap();
        let outcome = scaled_dot_product_attention(&q, &k, &v, 1.0, Mask::None);
        assert!(outcome.is_err(), "mismatched head dim must error");
    }
}