use mlx_native::ops::flash_attn_vec::{self, FlashAttnVecParams};
use mlx_native::{DType, KernelRegistry, MlxDevice};
fn cpu_flash_attn(
q: &[f32],
k: &[f32],
v: &[f32],
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
kv_seq_len: usize,
kv_capacity: usize,
scale: f32,
mask_type: u32,
sliding_window: usize,
) -> Vec<f32> {
let heads_per_kv = num_heads / num_kv_heads;
let mut output = vec![0.0f32; num_heads * head_dim];
let abs_pos = kv_seq_len - 1;
let window_start = if mask_type == 2 && sliding_window > 0 {
abs_pos.saturating_sub(sliding_window - 1)
} else {
0
};
let causal_max_k = std::cmp::min(abs_pos + 1, kv_seq_len);
for h in 0..num_heads {
let kv_h = h / heads_per_kv;
let q_offset = h * head_dim;
let k_head_base = kv_h * kv_capacity * head_dim;
let mut scores = Vec::new();
for k_pos in window_start..causal_max_k {
let k_offset = k_head_base + k_pos * head_dim;
let mut dot = 0.0f64; for d in 0..head_dim {
dot += q[q_offset + d] as f64 * k[k_offset + d] as f64;
}
scores.push(dot as f32 * scale);
}
if scores.is_empty() {
continue;
}
let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
let sum_exp: f32 = exp_scores.iter().sum();
let weights: Vec<f32> = exp_scores.iter().map(|e| e / sum_exp).collect();
let v_head_base = kv_h * kv_capacity * head_dim;
let o_offset = h * head_dim;
for d in 0..head_dim {
let mut acc = 0.0f32;
for (i, &w) in weights.iter().enumerate() {
let k_pos = window_start + i;
acc += w * v[v_head_base + k_pos * head_dim + d];
}
output[o_offset + d] = acc;
}
}
output
}
fn pseudo_random(seed: u64) -> f32 {
let x = seed
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
let bits = ((x >> 33) as u32) & 0x7FFFFF;
(bits as f32 / 0x7FFFFF as f32) * 2.0 - 1.0
}
fn fill_random(buf: &mut [f32], seed: u64) {
for (i, val) in buf.iter_mut().enumerate() {
*val = pseudo_random(seed + i as u64);
}
}
fn run_test(
num_heads: u32,
num_kv_heads: u32,
head_dim: u32,
kv_seq_len: u32,
kv_capacity: u32,
scale: f32,
mask_type: u32,
sliding_window: u32,
seed: u64,
label: &str,
epsilon: f32,
) {
let device = MlxDevice::new().expect("Metal device");
let mut registry = KernelRegistry::new();
let q_elems = num_heads as usize * head_dim as usize;
let kv_elems = num_kv_heads as usize * kv_capacity as usize * head_dim as usize;
let mut q_data = vec![0.0f32; q_elems];
let mut k_data = vec![0.0f32; kv_elems];
let mut v_data = vec![0.0f32; kv_elems];
fill_random(&mut q_data, seed);
fill_random(&mut k_data, seed + 10000);
fill_random(&mut v_data, seed + 20000);
let expected = cpu_flash_attn(
&q_data,
&k_data,
&v_data,
num_heads as usize,
num_kv_heads as usize,
head_dim as usize,
kv_seq_len as usize,
kv_capacity as usize,
scale,
mask_type,
sliding_window as usize,
);
let q_bytes = q_elems * 4;
let kv_bytes = kv_elems * 4;
let out_bytes = q_elems * 4;
let mut q_buf = device
.alloc_buffer(q_bytes, DType::F32, vec![q_elems])
.expect("alloc Q");
let mut k_buf = device
.alloc_buffer(kv_bytes, DType::F32, vec![kv_elems])
.expect("alloc K");
let mut v_buf = device
.alloc_buffer(kv_bytes, DType::F32, vec![kv_elems])
.expect("alloc V");
let output_buf = device
.alloc_buffer(out_bytes, DType::F32, vec![q_elems])
.expect("alloc output");
q_buf
.as_mut_slice::<f32>()
.expect("q slice")
.copy_from_slice(&q_data);
k_buf
.as_mut_slice::<f32>()
.expect("k slice")
.copy_from_slice(&k_data);
v_buf
.as_mut_slice::<f32>()
.expect("v slice")
.copy_from_slice(&v_data);
let tmp_bytes = flash_attn_vec::tmp_buffer_bytes(num_heads, head_dim);
let tmp_elems = tmp_bytes / 4;
let tmp_buf = device
.alloc_buffer(tmp_bytes, DType::F32, vec![tmp_elems])
.expect("alloc tmp");
let params = FlashAttnVecParams {
num_heads,
num_kv_heads,
head_dim,
kv_seq_len,
kv_capacity,
scale,
mask_type,
sliding_window,
softcap: 0.0,
};
let mut encoder = device.command_encoder().expect("encoder");
flash_attn_vec::flash_attn_vec(
&mut encoder,
&mut registry,
&device,
&q_buf,
&k_buf,
&v_buf,
&output_buf,
&tmp_buf,
¶ms,
)
.expect("flash_attn_vec dispatch");
encoder.commit_and_wait().expect("commit_and_wait");
let gpu_output: &[f32] = output_buf.as_slice::<f32>().expect("output slice");
let mut max_diff = 0.0f32;
for (i, (&gpu, &cpu)) in gpu_output.iter().zip(expected.iter()).enumerate() {
let diff = (gpu - cpu).abs();
if diff > max_diff {
max_diff = diff;
}
if diff > epsilon {
panic!(
"{label}: mismatch at index {i}: gpu={gpu:.6}, cpu={cpu:.6}, diff={diff:.6e}"
);
}
}
eprintln!("{label}: max_diff = {max_diff:.6e}");
}
#[test]
fn test_flash_attn_vec_dk256_basic() {
run_test(
4, 4, 256, 32, 64, 1.0 / (256.0_f32).sqrt(), 1, 0, 42, "dk256 basic",
1e-2,
);
}
#[test]
fn test_flash_attn_vec_dk256_gqa() {
run_test(
16, 8, 256, 48, 64, 1.0, 1, 0, 100, "dk256 GQA",
1e-2,
);
}
#[test]
fn test_flash_attn_vec_dk256_sliding_window() {
run_test(
4, 4, 256, 64, 128, 1.0, 2, 16, 777, "dk256 sliding window",
1e-2,
);
}
#[test]
fn test_flash_attn_vec_dk512_basic() {
run_test(
4, 4, 512, 32, 64, 1.0 / (512.0_f32).sqrt(), 1, 0, 1111, "dk512 basic",
1e-2,
);
}
#[test]
fn test_flash_attn_vec_dk512_gqa() {
run_test(
16, 8, 512, 48, 64, 1.0, 1, 0, 2222, "dk512 GQA",
1e-2,
);
}