#[cfg(target_os = "macos")]
#[allow(clippy::too_many_arguments)]
mod flash_attn_train_tests {
use half::bf16;
use mlx_native::ops::flash_attn_train::{
self as flash_attn_train,
FlashAttnTrainParams,
dispatch_flash_attn_train_fwd_bf16_d64,
dispatch_flash_attn_train_fwd_bf16_d256,
dispatch_flash_attn_train_bwd_bf16_d64,
dispatch_flash_attn_train_bwd_bf16_d256,
};
use mlx_native::{DType, KernelRegistry, MlxDevice};
/// Scalar f32 reference for scaled-dot-product attention that also returns the per-row
/// natural-log logsumexp.
///
/// Layouts (row-major, fixed by the index arithmetic below):
/// * `q` and the returned `out`: `[batch, n_heads, ql, head_dim]`
/// * `k`, `v`: `[batch, n_kv_heads, kl, head_dim]`; query head `h` reads KV head
///   `h / (n_heads / n_kv_heads)` (grouped-query attention)
/// * `mask`: `[batch, n_heads, ql, kl]`, additive in natural-log units (`-inf` drops a key)
/// * the returned `lse`: `[batch, n_heads, ql]`
///
/// The softmax is evaluated in base 2 (scores scaled by `scale * log2(e)`, mask by `log2(e)`)
/// and `lse` is converted back to natural log. With `do_causal` the mask is bottom-right
/// aligned: query `q_pos` may attend keys `0..=kl.saturating_sub(ql) + q_pos`.
/// A fully masked row produces `out = 0` and `lse = -inf`.
///
/// Returns `(out, lse)`.
fn sdpa_reference_with_logsumexp(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    mask: Option<&[f32]>,
    batch: usize,
    n_heads: usize,
    n_kv_heads: usize,
    ql: usize,
    kl: usize,
    head_dim: usize,
    scale: f32,
    do_causal: bool,
) -> (Vec<f32>, Vec<f32>) {
    use std::f32::consts::{LN_2, LOG2_E};

    let log2_scale = scale * LOG2_E;
    let group = n_heads / n_kv_heads;
    let causal_offset = kl.saturating_sub(ql);
    let mut out = vec![0.0f32; batch * n_heads * ql * head_dim];
    let mut lse = vec![0.0f32; batch * n_heads * ql];

    for b in 0..batch {
        for h in 0..n_heads {
            // First element of the KV head shared by this query head.
            let kv_start = (b * n_kv_heads + h / group) * kl * head_dim;
            for q_pos in 0..ql {
                // Flat (batch, head, query) row index shared by q / out / mask / lse.
                let row = (b * n_heads + h) * ql + q_pos;
                let q_row = &q[row * head_dim..(row + 1) * head_dim];

                // Base-2 logits for every key.
                let mut scores: Vec<f32> = (0..kl)
                    .map(|k_pos| {
                        let k_row = &k[kv_start + k_pos * head_dim..kv_start + (k_pos + 1) * head_dim];
                        let dot = q_row
                            .iter()
                            .zip(k_row)
                            .fold(0.0f32, |acc, (&a, &c)| acc + a * c);
                        dot * log2_scale
                    })
                    .collect();

                if let Some(m) = mask {
                    let m_row = &m[row * kl..(row + 1) * kl];
                    for (score, &bias) in scores.iter_mut().zip(m_row) {
                        // Natural-log bias converted to the base-2 domain.
                        *score += LOG2_E * bias;
                    }
                }

                if do_causal {
                    let last_visible = causal_offset + q_pos;
                    scores
                        .iter_mut()
                        .skip(last_visible + 1)
                        .for_each(|s| *s = f32::NEG_INFINITY);
                }

                // Finite floor keeps `s - peak` well defined even when every score is -inf.
                let peak = scores
                    .iter()
                    .copied()
                    .fold(f32::MIN / 2.0, |best, s| if s > best { s } else { best });
                let weights: Vec<f32> = scores.iter().map(|&s| f32::exp2(s - peak)).collect();
                let total: f32 = weights.iter().sum();
                lse[row] = peak * LN_2 + total.ln();

                // Avoid 0/0 on a fully masked row; the weights are all zero there anyway.
                let denom = if total == 0.0 { 1.0 } else { total };
                let probs: Vec<f32> = weights.iter().map(|&w| w / denom).collect();

                let out_row = &mut out[row * head_dim..(row + 1) * head_dim];
                for (d, slot) in out_row.iter_mut().enumerate() {
                    *slot = probs.iter().enumerate().fold(0.0f32, |acc, (k_pos, &p)| {
                        acc + p * v[kv_start + k_pos * head_dim + d]
                    });
                }
            }
        }
    }
    (out, lse)
}
/// Fixed base seed; each test offsets it so every tensor gets an independent stream.
const SEED_VAL: u64 = 0x5452_4149_4E30_3031;

/// Deterministic pseudo-random values uniformly spread over `[-0.5, 0.5)`.
///
/// Steps a 64-bit LCG (Knuth's MMIX constants) once per element and keeps the top 24 bits
/// of the state. Those are the highest-quality LCG bits, and 24 bits map exactly onto an f32
/// mantissa, so the division and the `- 0.5` shift are exact.
///
/// The previous version used a 31-bit value (`state >> 33`) divided by `u32::MAX`, which only
/// spans `[0, 0.5)`. Every generated value therefore landed in `[-0.5, 0)`, and the data was
/// never zero-centred.
fn pseudo_random_f32(seed: u64, n: usize) -> Vec<f32> {
    const KEEP_BITS: u32 = 24;
    let denom = (1u64 << KEEP_BITS) as f32;
    let mut state = seed;
    (0..n)
        .map(|_| {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            ((state >> (64 - KEEP_BITS)) as f32) / denom - 0.5
        })
        .collect()
}
/// Rounds each f32 to the nearest bf16.
fn f32_to_bf16(xs: &[f32]) -> Vec<bf16> {
    xs.iter().copied().map(bf16::from_f32).collect()
}
fn bf16_to_f32(xs: &[bf16]) -> Vec<f32> {
xs.iter().map(|&x| x.to_f32()).collect()
}
fn alloc_bf16(device: &MlxDevice, elems: usize, name: &str) -> mlx_native::MlxBuffer {
device
.alloc_buffer(elems * 2, DType::BF16, vec![elems])
.unwrap_or_else(|e| panic!("alloc_bf16({name}, {elems}): {e:?}"))
}
/// Allocates a 1-D f32 device buffer of `elems` elements; panics (naming `name`) on failure.
fn alloc_f32(device: &MlxDevice, elems: usize, name: &str) -> mlx_native::MlxBuffer {
    let byte_len = elems * std::mem::size_of::<f32>();
    match device.alloc_buffer(byte_len, DType::F32, vec![elems]) {
        Ok(buffer) => buffer,
        Err(e) => panic!("alloc_f32({name}, {elems}): {e:?}"),
    }
}
/// Copies `data` into the start of `buf`'s CPU-visible contents.
fn fill_bf16_buf(buf: &mlx_native::MlxBuffer, data: &[bf16]) {
    let dst = buf.contents_ptr() as *mut bf16;
    assert!(!dst.is_null(), "contents_ptr is null");
    // SAFETY: `dst` is non-null (checked above). NOTE(review): this assumes `buf` holds at
    // least `data.len()` bf16 elements and does not alias `data`; every caller in this file
    // allocates a fresh buffer sized from the same element count it fills. Capacity is not
    // checked here, so confirm this when adding callers.
    unsafe { dst.copy_from_nonoverlapping(data.as_ptr(), data.len()) }
}
/// Copies the first `elems` bf16 values out of `buf`'s CPU-visible contents.
fn read_bf16_buf(buf: &mlx_native::MlxBuffer, elems: usize) -> Vec<bf16> {
    let src = buf.contents_ptr() as *const bf16;
    assert!(!src.is_null(), "contents_ptr is null");
    // SAFETY: `src` is non-null (checked above). NOTE(review): this assumes `buf` holds at least
    // `elems` initialised bf16 values and is not being written concurrently (callers read only
    // after `commit_and_wait`).
    let view = unsafe { std::slice::from_raw_parts(src, elems) };
    view.to_vec()
}
/// Copies the first `elems` f32 values out of `buf`'s CPU-visible contents.
fn read_f32_buf(buf: &mlx_native::MlxBuffer, elems: usize) -> Vec<f32> {
    let src = buf.contents_ptr() as *const f32;
    assert!(!src.is_null(), "contents_ptr is null");
    // SAFETY: `src` is non-null (checked above). NOTE(review): this assumes `buf` holds at least
    // `elems` initialised f32 values and is not being written concurrently (callers read only
    // after `commit_and_wait`).
    let view = unsafe { std::slice::from_raw_parts(src, elems) };
    view.to_vec()
}
/// Asserts that `actual` and `expected` agree element-wise within absolute tolerance `atol`.
///
/// Panics (mentioning `label`) on a length mismatch, or when the worst element error exceeds
/// `atol`. The failure message reports the first worst index together with both values.
///
/// Non-finite values are handled explicitly:
/// * Bit-equal values, including equal infinities, count as an exact match.
/// * A NaN on either side, or mismatched infinities, counts as an infinite error.
///
/// The previous plain `(a - e).abs()` produced NaN in those cases, and `NaN > max` is always
/// false, so a kernel emitting NaN passed silently.
///
/// On success, prints the worst error to stderr.
fn assert_close(actual: &[f32], expected: &[f32], atol: f32, label: &str) {
    assert_eq!(actual.len(), expected.len(), "{label}: length mismatch");
    let (worst_idx, max_diff) = actual
        .iter()
        .zip(expected)
        .map(|(&a, &e)| {
            if a == e {
                0.0
            } else if a.is_nan() || e.is_nan() {
                f32::INFINITY
            } else {
                // Finite-vs-infinite and opposite infinities both yield +inf here.
                (a - e).abs()
            }
        })
        .enumerate()
        .fold((0usize, 0.0f32), |(best_i, best_d), (i, d)| {
            if d > best_d { (i, d) } else { (best_i, best_d) }
        });
    assert!(
        max_diff <= atol,
        "{label}: max_abs_error={max_diff:.4e} at index {worst_idx} \
(actual={:.6}, expected={:.6}) exceeds atol={atol:.4e}",
        actual[worst_idx], expected[worst_idx]
    );
    eprintln!("{label}: PASS max_abs_err={max_diff:.4e} (atol={atol:.4e})");
}
/// Runs the bf16 training forward kernel once and returns `(O, L)` widened to f32.
///
/// Steps:
/// 1. Allocates fresh device buffers.
/// 2. Uploads `q_bf` / `k_bf` / `v_bf` and the optional additive mask. The mask must hold
///    exactly `batch * n_q_heads * ql * kl` elements.
/// 3. Dispatches the `head_dim`-specific kernel and waits for completion.
/// 4. Reads back `O` (`batch * n_q_heads * ql * head_dim` bf16 values) and `L`
///    (`batch * n_q_heads * ql` f32 values).
///
/// # Panics
/// On any allocation, dispatch or commit failure, on a mask of the wrong length, or when
/// `head_dim` is neither 64 nor 256.
#[allow(clippy::too_many_arguments)]
fn run_train_fwd(
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q_bf: &[bf16],
    k_bf: &[bf16],
    v_bf: &[bf16],
    mask_bf: Option<&[bf16]>,
    batch: usize,
    n_q_heads: usize,
    n_kv_heads: usize,
    ql: usize,
    kl: usize,
    head_dim: usize,
    scale: f32,
    causal: bool,
) -> (Vec<f32>, Vec<f32>) {
    // Only D=64 and D=256 kernels are dispatched below. Any other head_dim would otherwise fall
    // into the D=256 branch and silently run the wrong kernel, so reject it before touching the
    // device.
    assert!(
        head_dim == 64 || head_dim == 256,
        "run_train_fwd: unsupported head_dim={head_dim} (expected 64 or 256)"
    );
    let q_elems = batch * n_q_heads * ql * head_dim;
    let kv_elems = batch * n_kv_heads * kl * head_dim;
    let l_elems = batch * n_q_heads * ql;
    let q_buf = alloc_bf16(device, q_elems, "Q");
    let k_buf = alloc_bf16(device, kv_elems, "K");
    let v_buf = alloc_bf16(device, kv_elems, "V");
    let mut o_buf = alloc_bf16(device, q_elems, "O");
    let mut l_buf = alloc_f32(device, l_elems, "L");
    fill_bf16_buf(&q_buf, q_bf);
    fill_bf16_buf(&k_buf, k_bf);
    fill_bf16_buf(&v_buf, v_bf);
    let mask_buf = mask_bf.map(|m| {
        let mask_elems = batch * n_q_heads * ql * kl;
        assert_eq!(m.len(), mask_elems, "mask length mismatch");
        let buf = alloc_bf16(device, mask_elems, "mask");
        fill_bf16_buf(&buf, m);
        buf
    });
    let params = FlashAttnTrainParams {
        batch: batch as u32,
        n_q_heads: n_q_heads as u32,
        n_kv_heads: n_kv_heads as u32,
        head_dim: head_dim as u32,
        q_seq_len: ql as u32,
        k_seq_len: kl as u32,
        scale,
        causal,
    };
    let mut encoder = device.command_encoder().expect("encoder");
    if head_dim == 64 {
        dispatch_flash_attn_train_fwd_bf16_d64(
            &mut encoder, device, registry,
            &q_buf, &k_buf, &v_buf, mask_buf.as_ref(), &mut o_buf, &mut l_buf,
            &params,
        ).expect("dispatch d64");
    } else {
        dispatch_flash_attn_train_fwd_bf16_d256(
            &mut encoder, device, registry,
            &q_buf, &k_buf, &v_buf, mask_buf.as_ref(), &mut o_buf, &mut l_buf,
            &params,
        ).expect("dispatch d256");
    }
    encoder.commit_and_wait().expect("commit_and_wait");
    let o_f32 = bf16_to_f32(&read_bf16_buf(&o_buf, q_elems));
    let l_f32 = read_f32_buf(&l_buf, l_elems);
    (o_f32, l_f32)
}
/// Reference forward pass on bf16 inputs.
///
/// Widens every input to f32, runs `sdpa_reference_with_logsumexp`, and rounds `O` through bf16
/// so it carries the same storage rounding as the kernel's bf16 output. `L` is returned as
/// computed (f32, not rounded).
#[allow(clippy::too_many_arguments)]
fn oracle_for_bf16(
    q_bf: &[bf16],
    k_bf: &[bf16],
    v_bf: &[bf16],
    mask_bf: Option<&[bf16]>,
    batch: usize,
    n_heads: usize,
    n_kv_heads: usize,
    ql: usize,
    kl: usize,
    head_dim: usize,
    scale: f32,
    causal: bool,
) -> (Vec<f32>, Vec<f32>) {
    let widened_mask: Option<Vec<f32>> = mask_bf.map(bf16_to_f32);
    let (exact_o, lse) = sdpa_reference_with_logsumexp(
        &bf16_to_f32(q_bf),
        &bf16_to_f32(k_bf),
        &bf16_to_f32(v_bf),
        widened_mask.as_deref(),
        batch, n_heads, n_kv_heads, ql, kl, head_dim, scale, causal,
    );
    let stored_o = bf16_to_f32(&f32_to_bf16(&exact_o));
    (stored_o, lse)
}
/// Creates the Metal device and a kernel registry with the forward training kernels registered.
fn setup() -> (MlxDevice, KernelRegistry) {
    let device = MlxDevice::new().expect("MlxDevice::new");
    let registry = {
        let mut reg = KernelRegistry::new();
        flash_attn_train::register(&mut reg);
        reg
    };
    (device, registry)
}
/// Every forward kernel name must build a pipeline with function constants 200/201 set to true
/// and 300/301 set to false.
#[test]
fn test_kernel_names_and_library_compiles() {
    let (device, mut registry) = setup();
    for &name in flash_attn_train::all_kernel_names_for_test() {
        if let Err(e) = registry.get_pipeline_with_bool_constants(
            name,
            device.metal_device(),
            &[(200, true), (201, true), (300, false), (301, false)],
        ) {
            panic!("Pipeline compilation failed for {name}: {e:?}");
        }
        eprintln!("test_kernel_names_and_library_compiles: {name} OK");
    }
}
/// D=64, unmasked: kernel `O` must match the bf16-rounded f32 reference.
#[test]
fn test_forward_o_parity_d64_no_mask() {
    let (device, mut registry) = setup();
    let [batch, h, kv_h, ql, kl, d]: [usize; 6] = [1, 1, 1, 32, 32, 64];
    let scale = 1.0 / (d as f32).sqrt();
    let q_len = batch * h * ql * d;
    let kv_len = batch * kv_h * kl * d;
    let q_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL, q_len));
    let k_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 1, kv_len));
    let v_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 2, kv_len));
    let (gpu_o, _) = run_train_fwd(
        &device, &mut registry, &q_bf, &k_bf, &v_bf, None,
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    let (ref_o, _) = oracle_for_bf16(
        &q_bf, &k_bf, &v_bf, None,
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    assert_close(&gpu_o, &ref_o, 5e-3, "forward_o_parity_d64_no_mask");
}
/// D=64, unmasked: kernel `L` must match the reference logsumexp.
#[test]
fn test_forward_l_parity_d64_no_mask() {
    let (device, mut registry) = setup();
    let [batch, h, kv_h, ql, kl, d]: [usize; 6] = [1, 1, 1, 32, 32, 64];
    let scale = 1.0 / (d as f32).sqrt();
    let q_len = batch * h * ql * d;
    let kv_len = batch * kv_h * kl * d;
    let q_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 10, q_len));
    let k_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 11, kv_len));
    let v_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 12, kv_len));
    let (_, gpu_l) = run_train_fwd(
        &device, &mut registry, &q_bf, &k_bf, &v_bf, None,
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    let (_, ref_l) = oracle_for_bf16(
        &q_bf, &k_bf, &v_bf, None,
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    assert_close(&gpu_l, &ref_l, 3e-3, "forward_l_parity_d64_no_mask");
}
/// D=256, four heads, unmasked: both `O` and `L` must match the reference.
#[test]
fn test_forward_o_l_parity_d256_no_mask() {
    let (device, mut registry) = setup();
    let [batch, h, kv_h, ql, kl, d]: [usize; 6] = [1, 4, 4, 128, 128, 256];
    let scale = 1.0 / (d as f32).sqrt();
    let q_len = batch * h * ql * d;
    let kv_len = batch * kv_h * kl * d;
    let q_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 20, q_len));
    let k_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 21, kv_len));
    let v_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 22, kv_len));
    let (gpu_o, gpu_l) = run_train_fwd(
        &device, &mut registry, &q_bf, &k_bf, &v_bf, None,
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    let (ref_o, ref_l) = oracle_for_bf16(
        &q_bf, &k_bf, &v_bf, None,
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    assert_close(&gpu_o, &ref_o, 5e-3, "d256_no_mask_O");
    assert_close(&gpu_l, &ref_l, 3e-3, "d256_no_mask_L");
}
/// Causal flag (no explicit mask): `O` and `L` must match the causal reference.
#[test]
fn test_forward_causal_mask_parity() {
    let (device, mut registry) = setup();
    let [batch, h, kv_h, ql, kl, d]: [usize; 6] = [1, 2, 2, 64, 64, 64];
    let scale = 1.0 / (d as f32).sqrt();
    let q_len = batch * h * ql * d;
    let kv_len = batch * kv_h * kl * d;
    let q_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 30, q_len));
    let k_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 31, kv_len));
    let v_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 32, kv_len));
    let (gpu_o, gpu_l) = run_train_fwd(
        &device, &mut registry, &q_bf, &k_bf, &v_bf, None,
        batch, h, kv_h, ql, kl, d, scale, true,
    );
    let (ref_o, ref_l) = oracle_for_bf16(
        &q_bf, &k_bf, &v_bf, None,
        batch, h, kv_h, ql, kl, d, scale, true,
    );
    assert_close(&gpu_o, &ref_o, 5e-3, "causal_mask_O");
    assert_close(&gpu_l, &ref_l, 3e-3, "causal_mask_L");
}
/// Causal masking must fully hide future keys.
///
/// Every key row after `q_pos_target` is poisoned with a large value, and the kernel's output
/// rows `0..=q_pos_target` are compared against the reference computed on the clean keys.
#[test]
fn test_forward_causal_mask_strict_lower_triangular() {
    let (device, mut registry) = setup();
    let [batch, h, kv_h, ql, kl, d]: [usize; 6] = [1, 1, 1, 32, 32, 64];
    let scale = 1.0 / (d as f32).sqrt();
    let q_pos_target = 15usize;
    let q_f32 = pseudo_random_f32(SEED_VAL + 40, batch * h * ql * d);
    let k_clean_f32 = pseudo_random_f32(SEED_VAL + 41, batch * kv_h * kl * d);
    let v_f32 = pseudo_random_f32(SEED_VAL + 42, batch * kv_h * kl * d);

    // batch == kv_h == 1, so K rows are contiguous and the poisoned rows form one range.
    let mut k_perturbed_f32 = k_clean_f32.clone();
    k_perturbed_f32[(q_pos_target + 1) * d..kl * d].fill(128.0_f32);

    let q_bf = f32_to_bf16(&q_f32);
    let v_bf = f32_to_bf16(&v_f32);
    let (gpu_o_perturbed, _) = run_train_fwd(
        &device, &mut registry, &q_bf, &f32_to_bf16(&k_perturbed_f32), &v_bf, None,
        batch, h, kv_h, ql, kl, d, scale, true,
    );
    let (ref_o_clean, _) = oracle_for_bf16(
        &q_bf, &f32_to_bf16(&k_clean_f32), &v_bf, None,
        batch, h, kv_h, ql, kl, d, scale, true,
    );
    let visible_rows = ..(q_pos_target + 1) * d;
    assert_close(
        &gpu_o_perturbed[visible_rows],
        &ref_o_clean[visible_rows],
        5e-3,
        "causal_strict_lower_triangular (rows 0..=q_pos_target)",
    );
    eprintln!(
        "test_forward_causal_mask_strict_lower_triangular: PASS — \
         O[0..={q_pos_target}] independent of K rows [{}..]",
        q_pos_target + 1
    );
}
/// Sliding-window attention expressed as an explicit additive mask (0 inside the causal window
/// `[i - window, i]`, `-inf` elsewhere): `O` and `L` must match the reference.
#[test]
fn test_forward_sliding_window_parity() {
    let (device, mut registry) = setup();
    let [batch, h, kv_h, ql, kl, d]: [usize; 6] = [1, 2, 2, 64, 64, 64];
    let window = 16usize;
    let scale = 1.0 / (d as f32).sqrt();
    let q_len = batch * h * ql * d;
    let kv_len = batch * kv_h * kl * d;
    let q_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 50, q_len));
    let k_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 51, kv_len));
    let v_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 52, kv_len));

    // The same [ql, kl] pattern is repeated for every (batch, head) pair.
    let mut mask_f32 = Vec::with_capacity(batch * h * ql * kl);
    for _ in 0..batch * h {
        for i in 0..ql {
            mask_f32.extend((0..kl).map(|j| {
                if j <= i && j + window >= i { 0.0_f32 } else { f32::NEG_INFINITY }
            }));
        }
    }
    let mask_bf = f32_to_bf16(&mask_f32);

    let (gpu_o, gpu_l) = run_train_fwd(
        &device, &mut registry, &q_bf, &k_bf, &v_bf, Some(&mask_bf),
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    let (ref_o, ref_l) = oracle_for_bf16(
        &q_bf, &k_bf, &v_bf, Some(&mask_bf),
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    assert_close(&gpu_o, &ref_o, 5e-3, "swa_O");
    assert_close(&gpu_l, &ref_l, 3e-3, "swa_L");
}
/// Grouped-query attention (8 query heads share 2 KV heads): `O` and `L` must match the reference.
#[test]
fn test_forward_gqa_parity() {
    let (device, mut registry) = setup();
    let [batch, h, kv_h, ql, kl, d]: [usize; 6] = [1, 8, 2, 64, 64, 64];
    let scale = 1.0 / (d as f32).sqrt();
    let q_len = batch * h * ql * d;
    let kv_len = batch * kv_h * kl * d;
    let q_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 60, q_len));
    let k_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 61, kv_len));
    let v_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 62, kv_len));
    let (gpu_o, gpu_l) = run_train_fwd(
        &device, &mut registry, &q_bf, &k_bf, &v_bf, None,
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    let (ref_o, ref_l) = oracle_for_bf16(
        &q_bf, &k_bf, &v_bf, None,
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    assert_close(&gpu_o, &ref_o, 5e-3, "gqa_4x_O");
    assert_close(&gpu_l, &ref_l, 3e-3, "gqa_4x_L");
}
/// The D=256 forward kernel (which also writes `L_out`) must compile; a shader-compilation
/// failure gets a targeted diagnostic.
#[test]
fn test_d256_library_compiles_with_l_out() {
    const KERNEL: &str = "flash_attn_train_fwd_bf16_d256";
    let (device, mut registry) = setup();
    let outcome = registry.get_pipeline_with_bool_constants(
        KERNEL,
        device.metal_device(),
        &[(200, true), (201, true), (300, false), (301, false)],
    );
    match outcome {
        Err(mlx_native::MlxError::ShaderCompilationError { name, message }) => panic!(
            "D=256 bf16 train fwd failed to compile — if this mentions \
             threadgroup memory, the L_out buffer likely pushed the tile over \
             32 KB (it shouldn't: L_out is device memory, not threadgroup). \
             name={name}, message={message}"
        ),
        Err(other) => panic!("Unexpected error: {other:?}"),
        Ok(_) => eprintln!("test_d256_library_compiles_with_l_out: OK"),
    }
}
/// Scalar f32 reference backward pass for scaled-dot-product attention.
///
/// Inputs use the same layouts as `sdpa_reference_with_logsumexp`:
/// * `q` / `do_`: `[batch, n_heads, ql, head_dim]`
/// * `k` / `v`: `[batch, n_kv_heads, kl, head_dim]`
/// * `mask`: `[batch, n_heads, ql, kl]`, additive in natural-log units (`-inf` drops a key)
/// * `l_nat`: the forward natural-log logsumexp, one value per `[batch, n_heads, ql]` row
///
/// Returns `(dQ, dK, dV)`. dK and dV have the KV layout; with grouped-query attention, the
/// contributions of every query head sharing a KV head are summed into it.
///
/// Per row `i` the math is:
/// * `P_ij = exp(S_ij - L_i)`
/// * `dP_ij = dO_i · V_j`
/// * `D_i = Σ_j P_ij dP_ij`
/// * `dS_ij = P_ij (dP_ij - D_i)`
/// * `dQ_i += scale · dS_ij K_j`, `dK_j += scale · dS_ij Q_i`, `dV_j += P_ij dO_i`
///
/// Causal masking is bottom-right aligned exactly like the forward reference: query `i` sees keys
/// `0..=kl.saturating_sub(ql) + i`. When `ql == kl` this is the usual `k_j <= i`. A top-left rule
/// would disagree with the forward `L` whenever `ql != kl`.
#[allow(clippy::too_many_arguments)]
fn sdpa_backward_reference_f32(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    l_nat: &[f32],
    do_: &[f32],
    mask: Option<&[f32]>,
    batch: usize,
    n_heads: usize,
    n_kv_heads: usize,
    ql: usize,
    kl: usize,
    head_dim: usize,
    scale: f32,
    do_causal: bool,
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let heads_per_kv = n_heads / n_kv_heads;
    let causal_offset = kl.saturating_sub(ql);
    let mut dq = vec![0.0f32; batch * n_heads * ql * head_dim];
    let mut dk = vec![0.0f32; batch * n_kv_heads * kl * head_dim];
    let mut dv = vec![0.0f32; batch * n_kv_heads * kl * head_dim];
    // Per-row scratch, fully overwritten for every query row.
    let mut s = vec![0.0f32; kl];
    let mut p = vec![0.0f32; kl];
    let mut dp = vec![0.0f32; kl];
    for b in 0..batch {
        for h in 0..n_heads {
            let kv_h = h / heads_per_kv;
            let kv_head_base = b * n_kv_heads * kl * head_dim + kv_h * kl * head_dim;
            for q_i in 0..ql {
                let q_base = b * n_heads * ql * head_dim + h * ql * head_dim + q_i * head_dim;
                let l_i = l_nat[b * n_heads * ql + h * ql + q_i];
                let last_visible = causal_offset + q_i;

                // Natural-log scores with mask / causal exclusions.
                for k_j in 0..kl {
                    let kv_base = kv_head_base + k_j * head_dim;
                    let mut dot = 0.0f32;
                    for d in 0..head_dim {
                        dot += q[q_base + d] * k[kv_base + d];
                    }
                    s[k_j] = scale * dot;
                    if let Some(m) = mask {
                        let m_idx = b * n_heads * ql * kl + h * ql * kl + q_i * kl + k_j;
                        if m[m_idx] == f32::NEG_INFINITY {
                            s[k_j] = f32::NEG_INFINITY;
                        } else {
                            s[k_j] += m[m_idx];
                        }
                    }
                    if do_causal && k_j > last_visible {
                        s[k_j] = f32::NEG_INFINITY;
                    }
                }

                // Softmax probabilities recovered from the forward logsumexp.
                for k_j in 0..kl {
                    p[k_j] = if s[k_j] == f32::NEG_INFINITY {
                        0.0f32
                    } else {
                        (s[k_j] - l_i).exp()
                    };
                }

                // dP_j is computed once here and reused for both D_i and dS_j.
                let mut d_i = 0.0f32;
                for k_j in 0..kl {
                    let kv_base = kv_head_base + k_j * head_dim;
                    let mut dp_j = 0.0f32;
                    for d in 0..head_dim {
                        dp_j += do_[q_base + d] * v[kv_base + d];
                    }
                    dp[k_j] = dp_j;
                    d_i += p[k_j] * dp_j;
                }

                for k_j in 0..kl {
                    // K, V, dK and dV share one layout, so one base offset serves all four.
                    let kv_base = kv_head_base + k_j * head_dim;
                    let ds_j = p[k_j] * (dp[k_j] - d_i);
                    for d in 0..head_dim {
                        dv[kv_base + d] += p[k_j] * do_[q_base + d];
                    }
                    for d in 0..head_dim {
                        dq[q_base + d] += scale * ds_j * k[kv_base + d];
                        dk[kv_base + d] += scale * ds_j * q[q_base + d];
                    }
                }
            }
        }
    }
    (dq, dk, dv)
}
#[allow(clippy::too_many_arguments)]
fn run_train_bwd(
device: &MlxDevice,
registry: &mut KernelRegistry,
q_bf: &[bf16],
k_bf: &[bf16],
v_bf: &[bf16],
mask_bf: Option<&[bf16]>,
do_bf: &[bf16],
batch: usize,
n_q_heads: usize,
n_kv_heads: usize,
ql: usize,
kl: usize,
head_dim: usize,
scale: f32,
causal: bool,
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
let q_elems = batch * n_q_heads * ql * head_dim;
let kv_elems = batch * n_kv_heads * kl * head_dim;
let l_elems = batch * n_q_heads * ql;
let q_buf = alloc_bf16(device, q_elems, "Q");
let k_buf = alloc_bf16(device, kv_elems, "K");
let v_buf = alloc_bf16(device, kv_elems, "V");
let mut o_buf = alloc_bf16(device, q_elems, "O");
let mut l_buf = alloc_f32(device, l_elems, "L");
fill_bf16_buf(&q_buf, q_bf);
fill_bf16_buf(&k_buf, k_bf);
fill_bf16_buf(&v_buf, v_bf);
let mask_buf = mask_bf.map(|m| {
let mask_elems = batch * n_q_heads * ql * kl;
assert_eq!(m.len(), mask_elems);
let buf = alloc_bf16(device, mask_elems, "mask");
fill_bf16_buf(&buf, m);
buf
});
let do_buf = alloc_bf16(device, q_elems, "dO");
fill_bf16_buf(&do_buf, do_bf);
let mut dq_buf = alloc_bf16(device, q_elems, "dQ");
let mut dk_buf = alloc_bf16(device, kv_elems, "dK");
let mut dv_buf = alloc_bf16(device, kv_elems, "dV");
let params = FlashAttnTrainParams {
batch: batch as u32,
n_q_heads: n_q_heads as u32,
n_kv_heads: n_kv_heads as u32,
head_dim: head_dim as u32,
q_seq_len: ql as u32,
k_seq_len: kl as u32,
scale,
causal,
};
let mut encoder = device.command_encoder().expect("command_encoder");
if head_dim == 64 {
dispatch_flash_attn_train_fwd_bf16_d64(
&mut encoder, device, registry,
&q_buf, &k_buf, &v_buf, mask_buf.as_ref(), &mut o_buf, &mut l_buf,
¶ms,
).expect("fwd d64");
} else {
dispatch_flash_attn_train_fwd_bf16_d256(
&mut encoder, device, registry,
&q_buf, &k_buf, &v_buf, mask_buf.as_ref(), &mut o_buf, &mut l_buf,
¶ms,
).expect("fwd d256");
}
encoder.memory_barrier();
if head_dim == 64 {
dispatch_flash_attn_train_bwd_bf16_d64(
&mut encoder, device, registry,
&q_buf, &k_buf, &v_buf, &o_buf, &l_buf, &do_buf,
mask_buf.as_ref(), &mut dq_buf, &mut dk_buf, &mut dv_buf,
¶ms,
).expect("bwd d64");
} else {
dispatch_flash_attn_train_bwd_bf16_d256(
&mut encoder, device, registry,
&q_buf, &k_buf, &v_buf, &o_buf, &l_buf, &do_buf,
mask_buf.as_ref(), &mut dq_buf, &mut dk_buf, &mut dv_buf,
¶ms,
).expect("bwd d256");
}
encoder.commit_and_wait().expect("commit_and_wait");
let dq_f32 = bf16_to_f32(&read_bf16_buf(&dq_buf, q_elems));
let dk_f32 = bf16_to_f32(&read_bf16_buf(&dk_buf, kv_elems));
let dv_f32 = bf16_to_f32(&read_bf16_buf(&dv_buf, kv_elems));
(dq_f32, dk_f32, dv_f32)
}
/// Every backward kernel name must build a pipeline.
///
/// The two bf16 backward kernels are specialised with function constants 200/201 = true and
/// 300/301 = false; the remaining kernels are built without constants.
#[test]
fn test_backward_kernel_names_and_library_compiles() {
    let (device, mut registry) = setup();
    flash_attn_train::register_bwd(&mut registry);
    for &name in flash_attn_train::all_bwd_kernel_names_for_test() {
        let needs_bool_constants = matches!(
            name,
            "flash_attn_train_bwd_bf16_d64" | "flash_attn_train_bwd_bf16_d256"
        );
        let outcome = if needs_bool_constants {
            registry.get_pipeline_with_bool_constants(
                name,
                device.metal_device(),
                &[(200, true), (201, true), (300, false), (301, false)],
            )
        } else {
            registry.get_pipeline(name, device.metal_device())
        };
        if let Err(e) = outcome {
            panic!("BWD pipeline compilation failed for {name}: {e:?}");
        }
        eprintln!("test_backward_kernel_names_and_library_compiles: {name} OK");
    }
}
/// D=64, unmasked: kernel dQ/dK/dV must match the f32 reference evaluated on the same
/// bf16-rounded inputs.
#[test]
fn test_backward_no_mask_d64_parity() {
    let (device, mut registry) = setup();
    flash_attn_train::register_bwd(&mut registry);
    let [batch, h, kv_h, ql, kl, d]: [usize; 6] = [1, 1, 1, 32, 32, 64];
    let scale = 1.0 / (d as f32).sqrt();
    let q_len = batch * h * ql * d;
    let kv_len = batch * kv_h * kl * d;
    let q_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 100, q_len));
    let k_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 101, kv_len));
    let v_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 102, kv_len));
    let do_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 103, q_len));
    let (gpu_dq, gpu_dk, gpu_dv) = run_train_bwd(
        &device, &mut registry, &q_bf, &k_bf, &v_bf, None, &do_bf,
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    let (_, ref_l) = oracle_for_bf16(&q_bf, &k_bf, &v_bf, None, batch, h, kv_h, ql, kl, d, scale, false);
    let (ref_dq, ref_dk, ref_dv) = sdpa_backward_reference_f32(
        &bf16_to_f32(&q_bf), &bf16_to_f32(&k_bf), &bf16_to_f32(&v_bf), &ref_l,
        &bf16_to_f32(&do_bf), None,
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    for (label, got, want) in [
        ("bwd_d64_no_mask_dQ", &gpu_dq, &ref_dq),
        ("bwd_d64_no_mask_dK", &gpu_dk, &ref_dk),
        ("bwd_d64_no_mask_dV", &gpu_dv, &ref_dv),
    ] {
        assert_close(got, want, 5e-3, label);
    }
}
/// D=256, four heads, unmasked: kernel dQ/dK/dV must match the f32 reference evaluated on the
/// same bf16-rounded inputs.
#[test]
fn test_backward_no_mask_d256_parity() {
    let (device, mut registry) = setup();
    flash_attn_train::register_bwd(&mut registry);
    let [batch, h, kv_h, ql, kl, d]: [usize; 6] = [1, 4, 4, 128, 128, 256];
    let scale = 1.0 / (d as f32).sqrt();
    let q_len = batch * h * ql * d;
    let kv_len = batch * kv_h * kl * d;
    let q_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 110, q_len));
    let k_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 111, kv_len));
    let v_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 112, kv_len));
    let do_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 113, q_len));
    let (gpu_dq, gpu_dk, gpu_dv) = run_train_bwd(
        &device, &mut registry, &q_bf, &k_bf, &v_bf, None, &do_bf,
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    let (_, ref_l) = oracle_for_bf16(&q_bf, &k_bf, &v_bf, None, batch, h, kv_h, ql, kl, d, scale, false);
    let (ref_dq, ref_dk, ref_dv) = sdpa_backward_reference_f32(
        &bf16_to_f32(&q_bf), &bf16_to_f32(&k_bf), &bf16_to_f32(&v_bf), &ref_l,
        &bf16_to_f32(&do_bf), None,
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    for (label, got, want) in [
        ("bwd_d256_no_mask_dQ", &gpu_dq, &ref_dq),
        ("bwd_d256_no_mask_dK", &gpu_dk, &ref_dk),
        ("bwd_d256_no_mask_dV", &gpu_dv, &ref_dv),
    ] {
        assert_close(got, want, 5e-3, label);
    }
}
/// Independent falsifier for the backward kernel.
///
/// Compares element 0 of the kernel's dQ, dK and dV against a central finite difference of the
/// scalar loss `Σ O·dO`. The gradient of that loss with respect to each input is exactly the
/// vector-Jacobian product the kernel computes for upstream gradient `dO`.
///
/// The finite difference is evaluated at the bf16-rounded inputs the kernel actually consumed,
/// matching the other backward tests. It previously used the unrounded f32 inputs, which mixed
/// input-rounding error into the measured gradient error.
#[test]
fn test_backward_finite_diff_falsifier() {
    let (device, mut registry) = setup();
    flash_attn_train::register_bwd(&mut registry);
    let [batch, h, kv_h, ql, kl, d]: [usize; 6] = [1, 1, 1, 8, 8, 64];
    let scale = 1.0 / (d as f32).sqrt();
    let h_step = 1e-2_f32;
    let q_f32 = pseudo_random_f32(SEED_VAL + 200, batch * h * ql * d);
    let k_f32 = pseudo_random_f32(SEED_VAL + 201, batch * kv_h * kl * d);
    let v_f32 = pseudo_random_f32(SEED_VAL + 202, batch * kv_h * kl * d);
    let do_f32 = pseudo_random_f32(SEED_VAL + 203, batch * h * ql * d);
    let q_bf = f32_to_bf16(&q_f32);
    let k_bf = f32_to_bf16(&k_f32);
    let v_bf = f32_to_bf16(&v_f32);
    let do_bf = f32_to_bf16(&do_f32);
    let (gpu_dq, gpu_dk, gpu_dv) = run_train_bwd(
        &device, &mut registry, &q_bf, &k_bf, &v_bf, None, &do_bf,
        batch, h, kv_h, ql, kl, d, scale, false,
    );

    // The exact bf16-representable point the kernel differentiated at.
    let q_rt = bf16_to_f32(&q_bf);
    let k_rt = bf16_to_f32(&k_bf);
    let v_rt = bf16_to_f32(&v_bf);
    let do_rt = bf16_to_f32(&do_bf);

    let loss = |q: &[f32], k: &[f32], v: &[f32]| -> f32 {
        let (o, _) = sdpa_reference_with_logsumexp(
            q, k, v, None, batch, h, kv_h, ql, kl, d, scale, false,
        );
        o.iter().zip(&do_rt).map(|(o, g)| o * g).sum()
    };
    // Central difference of `loss` w.r.t. element 0 of input `which` (0 = Q, 1 = K, 2 = V).
    let central_diff = |which: usize| -> f32 {
        let mut plus = [q_rt.clone(), k_rt.clone(), v_rt.clone()];
        let mut minus = plus.clone();
        plus[which][0] += h_step;
        minus[which][0] -= h_step;
        (loss(&plus[0], &plus[1], &plus[2]) - loss(&minus[0], &minus[1], &minus[2]))
            / (2.0 * h_step)
    };

    let checks = [("dQ", gpu_dq[0]), ("dK", gpu_dk[0]), ("dV", gpu_dv[0])];
    for (which, &(name, analytical)) in checks.iter().enumerate() {
        let fd = central_diff(which);
        let diff = (fd - analytical).abs();
        assert!(
            diff <= 5e-2_f32,
            "finite_diff_falsifier: {name}[0] fd={fd:.5} analytical={analytical:.5} diff={diff:.5e}"
        );
        eprintln!("finite_diff_falsifier: {name}[0] fd={fd:.5} analytical={analytical:.5} diff={diff:.5e} PASS");
    }
}
/// Causal flag (ql == kl): kernel dQ/dK/dV must match the causal f32 reference.
#[test]
fn test_backward_causal_mask_parity() {
    let (device, mut registry) = setup();
    flash_attn_train::register_bwd(&mut registry);
    let [batch, h, kv_h, ql, kl, d]: [usize; 6] = [1, 1, 1, 64, 64, 64];
    let scale = 1.0 / (d as f32).sqrt();
    let q_len = batch * h * ql * d;
    let kv_len = batch * kv_h * kl * d;
    let q_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 300, q_len));
    let k_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 301, kv_len));
    let v_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 302, kv_len));
    let do_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 303, q_len));
    let (gpu_dq, gpu_dk, gpu_dv) = run_train_bwd(
        &device, &mut registry, &q_bf, &k_bf, &v_bf, None, &do_bf,
        batch, h, kv_h, ql, kl, d, scale, true,
    );
    let (_, ref_l) = oracle_for_bf16(&q_bf, &k_bf, &v_bf, None, batch, h, kv_h, ql, kl, d, scale, true);
    let (ref_dq, ref_dk, ref_dv) = sdpa_backward_reference_f32(
        &bf16_to_f32(&q_bf), &bf16_to_f32(&k_bf), &bf16_to_f32(&v_bf), &ref_l,
        &bf16_to_f32(&do_bf), None,
        batch, h, kv_h, ql, kl, d, scale, true,
    );
    for (label, got, want) in [
        ("causal_bwd_dQ", &gpu_dq, &ref_dq),
        ("causal_bwd_dK", &gpu_dk, &ref_dk),
        ("causal_bwd_dV", &gpu_dv, &ref_dv),
    ] {
        assert_close(got, want, 5e-3, label);
    }
}
/// Sliding-window additive mask (0 inside the causal window `[i - window, i]`, `-inf`
/// elsewhere): kernel dQ/dK/dV must match the masked f32 reference.
#[test]
fn test_backward_sliding_window_mask_parity() {
    let (device, mut registry) = setup();
    flash_attn_train::register_bwd(&mut registry);
    let [batch, h, kv_h, ql, kl, d]: [usize; 6] = [1, 1, 1, 32, 32, 64];
    let window = 8usize;
    let scale = 1.0 / (d as f32).sqrt();
    let q_len = batch * h * ql * d;
    let kv_len = batch * kv_h * kl * d;
    let q_f32 = pseudo_random_f32(SEED_VAL + 400, q_len);
    let k_f32 = pseudo_random_f32(SEED_VAL + 401, kv_len);
    let v_f32 = pseudo_random_f32(SEED_VAL + 402, kv_len);
    let do_f32 = pseudo_random_f32(SEED_VAL + 403, q_len);

    // The same [ql, kl] window pattern for every (batch, head) pair.
    let mut mask_f32 = Vec::with_capacity(batch * h * ql * kl);
    for _ in 0..batch * h {
        for i in 0..ql {
            mask_f32.extend((0..kl).map(|j| {
                if j <= i && j + window >= i { 0.0_f32 } else { f32::NEG_INFINITY }
            }));
        }
    }
    let mask_bf = f32_to_bf16(&mask_f32);

    let q_bf = f32_to_bf16(&q_f32);
    let k_bf = f32_to_bf16(&k_f32);
    let v_bf = f32_to_bf16(&v_f32);
    let do_bf = f32_to_bf16(&do_f32);
    let (gpu_dq, gpu_dk, gpu_dv) = run_train_bwd(
        &device, &mut registry, &q_bf, &k_bf, &v_bf, Some(&mask_bf), &do_bf,
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    let mask_rt = bf16_to_f32(&mask_bf);
    let (_, ref_l) = oracle_for_bf16(&q_bf, &k_bf, &v_bf, Some(&mask_bf), batch, h, kv_h, ql, kl, d, scale, false);
    let (ref_dq, ref_dk, ref_dv) = sdpa_backward_reference_f32(
        &bf16_to_f32(&q_bf), &bf16_to_f32(&k_bf), &bf16_to_f32(&v_bf), &ref_l,
        &bf16_to_f32(&do_bf), Some(&mask_rt),
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    for (label, got, want) in [
        ("swa_bwd_dQ", &gpu_dq, &ref_dq),
        ("swa_bwd_dK", &gpu_dk, &ref_dk),
        ("swa_bwd_dV", &gpu_dv, &ref_dv),
    ] {
        assert_close(got, want, 5e-3, label);
    }
}
/// Grouped-query attention (8 query heads share 2 KV heads): the kernel must accumulate dK/dV
/// across every query head in a group, matching the f32 reference.
#[test]
fn test_backward_gqa_accumulation() {
    let (device, mut registry) = setup();
    flash_attn_train::register_bwd(&mut registry);
    let [batch, h, kv_h, ql, kl, d]: [usize; 6] = [1, 8, 2, 32, 32, 64];
    let scale = 1.0 / (d as f32).sqrt();
    let q_len = batch * h * ql * d;
    let kv_len = batch * kv_h * kl * d;
    let q_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 500, q_len));
    let k_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 501, kv_len));
    let v_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 502, kv_len));
    let do_bf = f32_to_bf16(&pseudo_random_f32(SEED_VAL + 503, q_len));
    let (gpu_dq, gpu_dk, gpu_dv) = run_train_bwd(
        &device, &mut registry, &q_bf, &k_bf, &v_bf, None, &do_bf,
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    let (_, ref_l) = oracle_for_bf16(&q_bf, &k_bf, &v_bf, None, batch, h, kv_h, ql, kl, d, scale, false);
    let (ref_dq, ref_dk, ref_dv) = sdpa_backward_reference_f32(
        &bf16_to_f32(&q_bf), &bf16_to_f32(&k_bf), &bf16_to_f32(&v_bf), &ref_l,
        &bf16_to_f32(&do_bf), None,
        batch, h, kv_h, ql, kl, d, scale, false,
    );
    for (label, got, want) in [
        ("gqa_bwd_dQ", &gpu_dq, &ref_dq),
        ("gqa_bwd_dK", &gpu_dk, &ref_dk),
        ("gqa_bwd_dV", &gpu_dv, &ref_dv),
    ] {
        assert_close(got, want, 5e-3, label);
    }
}
}