use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::encoder::{CapturedOpKind, CommandEncoder, KernelArg, as_bytes};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use crate::DType;
/// Metal shader source for the flash-attention prefill kernels, embedded into
/// the binary at compile time from `../shaders/flash_attn_prefill.metal`.
///
/// [`register`] associates this single source string with every kernel name
/// listed in `ALL_KERNEL_NAMES`.
pub static FLASH_ATTN_PREFILL_SHADER_SOURCE: &str =
include_str!("../shaders/flash_attn_prefill.metal");
// Kernel names registered against the prefill shader source. The names follow
// the pattern `flash_attn_prefill_{dtype}_d{head_dim}[_boolmask]`.
// NOTE(review): presumably each name is an entry point in
// `flash_attn_prefill.metal` — confirm against the shader.
// Of these, only `K_BF16_D256` and `K_BF16_D64` are requested by the dispatch
// functions visible in this file.

// head_dim = 256 variants.
const K_BF16_D256: &str = "flash_attn_prefill_bf16_d256";
const K_BF16_D256_BOOLMASK: &str = "flash_attn_prefill_bf16_d256_boolmask";
const K_F16_D256: &str = "flash_attn_prefill_f16_d256";
const K_F16_D256_BOOLMASK: &str = "flash_attn_prefill_f16_d256_boolmask";
// head_dim = 512 variants.
const K_BF16_D512: &str = "flash_attn_prefill_bf16_d512";
const K_BF16_D512_BOOLMASK: &str = "flash_attn_prefill_bf16_d512_boolmask";
const K_F16_D512: &str = "flash_attn_prefill_f16_d512";
const K_F16_D512_BOOLMASK: &str = "flash_attn_prefill_f16_d512_boolmask";
// head_dim = 64 variants.
const K_BF16_D64: &str = "flash_attn_prefill_bf16_d64";
const K_BF16_D64_BOOLMASK: &str = "flash_attn_prefill_bf16_d64_boolmask";
const K_F16_D64: &str = "flash_attn_prefill_f16_d64";
const K_F16_D64_BOOLMASK: &str = "flash_attn_prefill_f16_d64_boolmask";
/// Every kernel name that `register` maps to the prefill shader source.
/// The unit tests check this list for duplicates and against an explicit
/// expected set, so adding a kernel requires updating both.
const ALL_KERNEL_NAMES: &[&str] = &[
K_BF16_D256,
K_BF16_D256_BOOLMASK,
K_F16_D256,
K_F16_D256_BOOLMASK,
K_BF16_D512,
K_BF16_D512_BOOLMASK,
K_F16_D512,
K_F16_D512_BOOLMASK,
K_BF16_D64,
K_BF16_D64_BOOLMASK,
K_F16_D64,
K_F16_D64_BOOLMASK,
];
/// Registers the flash-attention prefill shader source with `registry` under
/// every name in `ALL_KERNEL_NAMES`.
///
/// All names share the same source string
/// ([`FLASH_ATTN_PREFILL_SHADER_SOURCE`]); the registry is expected to pick
/// the right entry point by name when a pipeline is requested.
pub fn register(registry: &mut KernelRegistry) {
    ALL_KERNEL_NAMES.iter().for_each(|&kernel_name| {
        registry.register_source(kernel_name, FLASH_ATTN_PREFILL_SHADER_SOURCE);
    });
}
/// Attention parameter block handed to the kernel as raw bytes (argument
/// index 4 in every dispatch in this file).
///
/// The struct is `#[repr(C)]` and `bytemuck::Pod`, so it must contain no
/// implicit padding: the sixteen 4-byte fields occupy 64 bytes and the four
/// `[i64; 3]` stride arrays another 96, for 160 bytes in total (asserted by
/// the unit tests).
///
/// NOTE(review): presumably this mirrors a parameter struct declared in the
/// Metal shader — confirm field order and types there before changing it.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
pub struct AttnParamsGpu {
/// Batch size.
pub b: i32,
/// Number of query heads.
pub h: i32,
/// Head dimension.
pub d: i32,
/// Query sequence length.
pub ql: i32,
/// Key/value sequence length.
pub kl: i32,
/// `n_heads / n_kv_heads`, as computed by the dispatchers.
pub gqa_factor: i32,
/// Score scale, copied verbatim from the caller's params.
pub scale: f32,
/// Always set to `1.0` by the dispatchers in this file.
pub softcapping: f32,
/// Number of query tiles: `ceil(ql / BQ)`.
pub nq: i32,
/// Number of key tiles: `ceil(kl / BK)`.
pub nk: i32,
/// Number of full query tiles: `ql / BQ`.
pub nq_aligned: i32,
/// Number of full key tiles: `kl / BK`.
pub nk_aligned: i32,
/// Rows in the trailing partial query tile: `ql % BQ`.
pub ql_rem: i32,
/// Columns in the trailing partial key tile: `kl % BK`.
pub kl_rem: i32,
/// Query offset: `0` for the plain dispatchers, `q_offset_in_k` for the
/// resume dispatcher.
pub ql_off: i32,
/// Explicit padding (always `0`) that brings the scalar fields to 64 bytes
/// so the `i64` arrays below start 8-byte aligned without implicit padding.
pub _pad: i32,
/// Q strides as `[batch, head, seq]`, in elements as computed by the
/// dispatchers.
pub q_strides: [i64; 3],
/// K strides as `[batch, head, seq]`.
pub k_strides: [i64; 3],
/// V strides as `[batch, head, seq]`.
pub v_strides: [i64; 3],
/// Output strides as `[batch, head, seq]`.
pub o_strides: [i64; 3],
}
/// Mask parameter block handed to the kernel as raw bytes (argument index 5)
/// whenever a mask buffer is bound.
///
/// `#[repr(C)]` + `bytemuck::Pod`; 24 bytes in total (asserted by the unit
/// tests).
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
pub struct AttnMaskParamsGpu {
/// Mask strides as `[batch, head, query-row]`. The dispatchers set the batch
/// and head strides to `0` for a rank-2 mask so it is shared across both.
pub m_strides: [i64; 3],
}
/// Host-side description of one flash-attention prefill dispatch.
///
/// Checked by `validate_params`: all counts must be non-zero and `n_heads`
/// must be a multiple of `n_kv_heads`. Each dispatcher additionally requires
/// a specific `head_dim`.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnPrefillParams {
/// Number of query heads.
pub n_heads: u32,
/// Number of key/value heads; must divide `n_heads`.
pub n_kv_heads: u32,
/// Per-head dimension (256 or 64, depending on the dispatcher called).
pub head_dim: u32,
/// Query sequence length.
pub seq_len_q: u32,
/// Key/value sequence length.
pub seq_len_k: u32,
/// Batch size.
pub batch: u32,
/// Score scale, forwarded unchanged to the kernel (the unit tests use
/// `1 / sqrt(head_dim)`).
pub scale: f32,
/// Forwarded as function constant 301.
/// NOTE(review): presumably enables causal masking in the kernel — confirm
/// against the shader.
pub do_causal: bool,
}
// Tile geometry for the head_dim = 256 dispatchers.
// BQ / BK are the query / key tile sizes used to derive the tile counts
// (`nq = ceil(seq_len_q / BQ)`, `nk = ceil(seq_len_k / BK)`); the grid's x
// dimension is `nq`. WM / WN are the y / z extents of the threadgroup, whose
// x extent is fixed at 32, giving 32 * WM * WN = 128 threads (asserted by the
// unit tests).
const BQ_D256: u32 = 32;
const BK_D256: u32 = 16;
const WM_D256: u32 = 4;
const WN_D256: u32 = 1;
// Tile geometry for the head_dim = 64 dispatcher; same roles as above.
const BQ_D64: u32 = 32;
const BK_D64: u32 = 16;
const WM_D64: u32 = 4;
const WN_D64: u32 = 1;
/// Checks the shape-independent invariants of [`FlashAttnPrefillParams`].
///
/// Checks run in a fixed order and the first failure wins:
/// `n_heads > 0`, `n_kv_heads > 0`, `n_heads % n_kv_heads == 0`,
/// `seq_len_q > 0`, `seq_len_k > 0`, `batch > 0`.
///
/// `head_dim` is deliberately not checked here; each dispatcher enforces its
/// own required value.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] describing the first violated
/// invariant.
fn validate_params(params: &FlashAttnPrefillParams) -> Result<()> {
    /// Builds the `InvalidArgument` error for a fixed message.
    fn reject(msg: &str) -> Result<()> {
        Err(MlxError::InvalidArgument(msg.into()))
    }

    // Head counts first: both must be non-zero before the modulo is safe.
    match (params.n_heads, params.n_kv_heads) {
        (0, _) => return reject("flash_attn_prefill: n_heads must be > 0"),
        (_, 0) => return reject("flash_attn_prefill: n_kv_heads must be > 0"),
        (heads, kv_heads) if heads % kv_heads != 0 => {
            return Err(MlxError::InvalidArgument(format!(
                "flash_attn_prefill: n_heads ({}) must be divisible by n_kv_heads ({})",
                heads, kv_heads
            )));
        }
        _ => {}
    }

    // Remaining extents, reported in declaration order.
    let extents = [
        (params.seq_len_q, "flash_attn_prefill: seq_len_q must be > 0"),
        (params.seq_len_k, "flash_attn_prefill: seq_len_k must be > 0"),
        (params.batch, "flash_attn_prefill: batch must be > 0"),
    ];
    match extents.iter().find(|(value, _)| *value == 0) {
        Some((_, msg)) => reject(*msg),
        None => Ok(()),
    }
}
/// Verifies that `buf` is large enough to hold `expected_elements` elements
/// of its own dtype.
///
/// `name` is only used to label the buffer in the error message.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when the buffer holds fewer bytes
/// than required, or when the required byte count does not fit in `usize`.
fn validate_buffer_size(buf: &MlxBuffer, name: &str, expected_elements: usize) -> Result<()> {
    let elem_size = buf.dtype().size_of();
    // A plain `*` wraps silently in release builds, which would let an
    // undersized buffer pass validation; treat overflow as a hard error.
    let Some(expected_bytes) = expected_elements.checked_mul(elem_size) else {
        return Err(MlxError::InvalidArgument(format!(
            "flash_attn_prefill: {name} buffer size overflows usize: \
             {expected_elements} elements of {elem_size} bytes each"
        )));
    };
    if buf.byte_len() < expected_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "flash_attn_prefill: {name} buffer too small: expected at least \
             {expected_bytes} bytes, got {}",
            buf.byte_len()
        )));
    }
    Ok(())
}
/// Encodes the BF16, head_dim = 256 flash-attention prefill kernel without a
/// `blk` buffer.
///
/// This is a convenience wrapper that forwards every argument to
/// [`dispatch_flash_attn_prefill_bf16_d256_with_blk`] with `blk = None`; see
/// that function for the buffer requirements.
///
/// # Errors
///
/// Returns whatever the forwarded call returns.
// The argument list mirrors the kernel's bindings one-to-one.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_prefill_bf16_d256(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    out: &MlxBuffer,
    params: &FlashAttnPrefillParams,
) -> Result<()> {
    let no_blk: Option<&MlxBuffer> = None;
    dispatch_flash_attn_prefill_bf16_d256_with_blk(
        encoder, device, registry, q, k, v, mask, no_blk, out, params,
    )
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_prefill_bf16_d256_with_blk(
encoder: &mut CommandEncoder,
device: &MlxDevice,
registry: &mut KernelRegistry,
q: &MlxBuffer,
k: &MlxBuffer,
v: &MlxBuffer,
mask: Option<&MlxBuffer>,
blk: Option<&MlxBuffer>,
out: &MlxBuffer,
params: &FlashAttnPrefillParams,
) -> Result<()> {
if params.head_dim != 256 {
return Err(MlxError::InvalidArgument(format!(
"dispatch_flash_attn_prefill_bf16_d256: head_dim must be 256, got {}",
params.head_dim
)));
}
if blk.is_some() && mask.is_none() {
return Err(MlxError::InvalidArgument(
"dispatch_flash_attn_prefill_bf16_d256_with_blk: \
blk requires mask (a blk without a mask is meaningless)"
.into(),
));
}
validate_params(params)?;
for (buf, name) in &[(q, "Q"), (k, "K"), (v, "V"), (out as &MlxBuffer, "out")] {
if buf.dtype() != DType::BF16 {
return Err(MlxError::InvalidArgument(format!(
"dispatch_flash_attn_prefill_bf16_d256: {name} buffer must be BF16, \
got {:?}",
buf.dtype()
)));
}
}
if let Some(m) = mask {
if m.dtype() != DType::BF16 {
return Err(MlxError::InvalidArgument(format!(
"dispatch_flash_attn_prefill_bf16_d256: mask buffer must be BF16, \
got {:?}",
m.dtype()
)));
}
}
let batch = params.batch as usize;
let h = params.n_heads as usize;
let h_kv = params.n_kv_heads as usize;
let ql = params.seq_len_q as usize;
let kl = params.seq_len_k as usize;
let d = params.head_dim as usize;
validate_buffer_size(q, "Q", batch * h * ql * d)?;
validate_buffer_size(k, "K", batch * h_kv * kl * d)?;
validate_buffer_size(v, "V", batch * h_kv * kl * d)?;
validate_buffer_size(out, "out", batch * h * ql * d)?;
let mask_is_rank2_broadcast = mask.is_some_and(|m| m.shape().len() == 2);
if let Some(m) = mask {
if mask_is_rank2_broadcast {
validate_buffer_size(m, "mask", ql * kl)?;
} else {
validate_buffer_size(m, "mask", batch * h * ql * kl)?;
}
}
let bq = BQ_D256;
let bk = BK_D256;
let wm = WM_D256;
let wn = WN_D256;
let nq = params.seq_len_q.div_ceil(bq);
let nk = params.seq_len_k.div_ceil(bk);
let nq_aligned = params.seq_len_q / bq;
let nk_aligned = params.seq_len_k / bk;
let ql_rem = params.seq_len_q % bq;
let kl_rem = params.seq_len_k % bk;
let align_q = ql_rem == 0;
let align_k = kl_rem == 0;
let has_mask = mask.is_some();
let has_blk = blk.is_some();
let do_causal = params.do_causal;
if let Some(b) = blk {
let nq_tiles = ql.div_ceil(BQ_D256 as usize);
let nk_tiles = kl.div_ceil(BK_D256 as usize);
let expected = nq_tiles * nk_tiles;
if b.byte_len() < expected {
return Err(MlxError::InvalidArgument(format!(
"dispatch_flash_attn_prefill_bf16_d256_with_blk: blk buffer \
too small: expected at least {expected} bytes (NQ={nq_tiles}, \
NK={nk_tiles}), got {}",
b.byte_len()
)));
}
}
let kernel_name = K_BF16_D256;
let pipeline = registry.get_pipeline_with_bool_constants(
kernel_name,
device.metal_device(),
&[
(200, align_q),
(201, align_k),
(300, has_mask),
(301, do_causal),
(303, has_blk),
],
)?;
let q_seq_stride = d as i64;
let q_head_stride = (ql * d) as i64;
let q_batch_stride = (h * ql * d) as i64;
let kv_seq_stride = d as i64;
let kv_head_stride = (kl * d) as i64;
let kv_batch_stride = (h_kv * kl * d) as i64;
let gqa_factor = (params.n_heads / params.n_kv_heads) as i32;
let attn_params = AttnParamsGpu {
b: params.batch as i32,
h: params.n_heads as i32,
d: params.head_dim as i32,
ql: params.seq_len_q as i32,
kl: params.seq_len_k as i32,
gqa_factor,
scale: params.scale,
softcapping: 1.0_f32, nq: nq as i32,
nk: nk as i32,
nq_aligned: nq_aligned as i32,
nk_aligned: nk_aligned as i32,
ql_rem: ql_rem as i32,
kl_rem: kl_rem as i32,
ql_off: 0, _pad: 0,
q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
};
let grid = MTLSize::new(nq as u64, params.n_heads as u64, params.batch as u64);
let tg_size = MTLSize::new(32, wm as u64, wn as u64);
encoder.set_op_kind(CapturedOpKind::Sdpa);
if has_mask {
let mask_buf = mask.ok_or_else(|| {
MlxError::InvalidArgument(
"flash_attn_prefill: internal error — has_mask=true but mask is None".into(),
)
})?;
let (m_batch_stride, m_head_stride, m_ql_stride) = if mask_is_rank2_broadcast {
(0_i64, 0_i64, kl as i64)
} else {
((h * ql * kl) as i64, (ql * kl) as i64, kl as i64)
};
let mask_params = AttnMaskParamsGpu {
m_strides: [m_batch_stride, m_head_stride, m_ql_stride],
};
if has_blk {
let blk_buf = blk.ok_or_else(|| {
MlxError::InvalidArgument(
"flash_attn_prefill: internal error — has_blk=true but blk is None".into(),
)
})?;
encoder.encode_threadgroups_with_args(
pipeline,
&[
(0, KernelArg::Buffer(q)),
(1, KernelArg::Buffer(k)),
(2, KernelArg::Buffer(v)),
(3, KernelArg::Buffer(out)),
(4, KernelArg::Bytes(as_bytes(&attn_params))),
(5, KernelArg::Bytes(as_bytes(&mask_params))),
(6, KernelArg::Buffer(mask_buf)),
(7, KernelArg::Buffer(blk_buf)),
],
grid,
tg_size,
);
} else {
encoder.encode_threadgroups_with_args(
pipeline,
&[
(0, KernelArg::Buffer(q)),
(1, KernelArg::Buffer(k)),
(2, KernelArg::Buffer(v)),
(3, KernelArg::Buffer(out)),
(4, KernelArg::Bytes(as_bytes(&attn_params))),
(5, KernelArg::Bytes(as_bytes(&mask_params))),
(6, KernelArg::Buffer(mask_buf)),
],
grid,
tg_size,
);
}
} else {
encoder.encode_threadgroups_with_args(
pipeline,
&[
(0, KernelArg::Buffer(q)),
(1, KernelArg::Buffer(k)),
(2, KernelArg::Buffer(v)),
(3, KernelArg::Buffer(out)),
(4, KernelArg::Bytes(as_bytes(&attn_params))),
],
grid,
tg_size,
);
}
Ok(())
}
/// Host-side description of a "resume" prefill dispatch, where the query
/// block starts at an offset into the key/value sequence and K/V live in
/// slots of fixed capacity.
///
/// The resume dispatcher requires
/// `q_offset_in_k + seq_len_q <= seq_len_k <= kv_capacity`.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnPrefillResumeParams {
/// Number of query heads.
pub n_heads: u32,
/// Number of key/value heads; must divide `n_heads`.
pub n_kv_heads: u32,
/// Per-head dimension (must be 256 for the resume dispatcher).
pub head_dim: u32,
/// Query sequence length.
pub seq_len_q: u32,
/// Key/value sequence length.
pub seq_len_k: u32,
/// Batch size.
pub batch: u32,
/// Score scale, forwarded unchanged to the kernel.
pub scale: f32,
/// Forwarded as function constant 301.
/// NOTE(review): presumably enables causal masking in the kernel — confirm
/// against the shader.
pub do_causal: bool,
/// Offset of the query block within the key sequence; written to
/// `AttnParamsGpu::ql_off`.
pub q_offset_in_k: u32,
/// Per-head K/V slot capacity in positions; used for the K/V head and batch
/// strides and for the K/V buffer-size check.
pub kv_capacity: u32,
}
/// Encodes the BF16, head_dim = 256 flash-attention prefill kernel for a
/// query block that starts `q_offset_in_k` positions into the key sequence,
/// reading K/V from slots of `kv_capacity` positions per head.
///
/// No mask or `blk` buffer is bound; function constants 300 and 303 are
/// always `false`.
///
/// # Buffer requirements
///
/// * `q`, `out`: BF16, at least `batch * n_heads * seq_len_q * head_dim`
///   elements, contiguous `[batch, head, seq, head_dim]`.
/// * `k`, `v`: BF16, at least `batch * n_kv_heads * kv_capacity * head_dim`
///   elements; head and batch strides are derived from `kv_capacity`, not
///   `seq_len_k`.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when `head_dim != 256`, when any
/// count is zero, when `n_heads` is not a multiple of `n_kv_heads`, when
/// `q_offset_in_k + seq_len_q > seq_len_k`, when `seq_len_k > kv_capacity`,
/// when a buffer has the wrong dtype, or when a buffer is too small. Errors
/// from pipeline lookup are propagated unchanged.
// The argument list mirrors the kernel's bindings one-to-one.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_prefill_bf16_d256_resume(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    out: &MlxBuffer,
    params: &FlashAttnPrefillResumeParams,
) -> Result<()> {
    if params.head_dim != 256 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d256_resume: head_dim must be 256, got {}",
            params.head_dim
        )));
    }
    if params.n_heads == 0
        || params.n_kv_heads == 0
        || params.seq_len_q == 0
        || params.seq_len_k == 0
        || params.batch == 0
    {
        return Err(MlxError::InvalidArgument(
            "dispatch_flash_attn_prefill_bf16_d256_resume: \
             n_heads/n_kv_heads/seq_len_q/seq_len_k/batch must all be > 0"
                .into(),
        ));
    }
    if params.n_heads % params.n_kv_heads != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d256_resume: n_heads ({}) must \
             be divisible by n_kv_heads ({})",
            params.n_heads, params.n_kv_heads
        )));
    }
    // Widen to u64 before adding: a u32 sum panics in debug builds and wraps
    // in release builds, where a wrapped sum could slip past this guard and
    // let the kernel read beyond the valid K/V range.
    let q_end = u64::from(params.q_offset_in_k) + u64::from(params.seq_len_q);
    if q_end > u64::from(params.seq_len_k) {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d256_resume: q_offset_in_k ({}) \
             + seq_len_q ({}) > seq_len_k ({}) — Q overshoots K",
            params.q_offset_in_k, params.seq_len_q, params.seq_len_k
        )));
    }
    if params.seq_len_k > params.kv_capacity {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d256_resume: seq_len_k ({}) > \
             kv_capacity ({}) — K/V overshoots slot capacity",
            params.seq_len_k, params.kv_capacity
        )));
    }
    for (buf, name) in &[(q, "Q"), (k, "K"), (v, "V"), (out as &MlxBuffer, "out")] {
        if buf.dtype() != DType::BF16 {
            return Err(MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_bf16_d256_resume: {name} buffer \
                 must be BF16, got {:?}",
                buf.dtype()
            )));
        }
    }

    let batch = params.batch as usize;
    let h = params.n_heads as usize;
    let h_kv = params.n_kv_heads as usize;
    let ql = params.seq_len_q as usize;
    let cap = params.kv_capacity as usize;
    let d = params.head_dim as usize;
    // K/V are sized by slot capacity, not by the currently valid length.
    validate_buffer_size(q, "Q", batch * h * ql * d)?;
    validate_buffer_size(k, "K", batch * h_kv * cap * d)?;
    validate_buffer_size(v, "V", batch * h_kv * cap * d)?;
    validate_buffer_size(out, "out", batch * h * ql * d)?;

    let bq = BQ_D256;
    let bk = BK_D256;
    let wm = WM_D256;
    let wn = WN_D256;
    let nq = params.seq_len_q.div_ceil(bq);
    let nk = params.seq_len_k.div_ceil(bk);
    let nq_aligned = params.seq_len_q / bq;
    let nk_aligned = params.seq_len_k / bk;
    let ql_rem = params.seq_len_q % bq;
    let kl_rem = params.seq_len_k % bk;
    let align_q = ql_rem == 0;
    let align_k = kl_rem == 0;
    // The resume path never binds a mask or blk buffer.
    let has_mask = false;
    let has_blk = false;
    let do_causal = params.do_causal;

    let kernel_name = K_BF16_D256;
    let pipeline = registry.get_pipeline_with_bool_constants(
        kernel_name,
        device.metal_device(),
        &[
            (200, align_q),
            (201, align_k),
            (300, has_mask),
            (301, do_causal),
            (303, has_blk),
        ],
    )?;

    // Strides are [batch, head, seq]; K/V heads are `cap` positions apart.
    let q_seq_stride = d as i64;
    let q_head_stride = (ql * d) as i64;
    let q_batch_stride = (h * ql * d) as i64;
    let kv_seq_stride = d as i64;
    let kv_head_stride = (cap * d) as i64;
    let kv_batch_stride = (h_kv * cap * d) as i64;
    let gqa_factor = (params.n_heads / params.n_kv_heads) as i32;
    let attn_params = AttnParamsGpu {
        b: params.batch as i32,
        h: params.n_heads as i32,
        d: params.head_dim as i32,
        ql: params.seq_len_q as i32,
        kl: params.seq_len_k as i32,
        gqa_factor,
        scale: params.scale,
        softcapping: 1.0_f32,
        nq: nq as i32,
        nk: nk as i32,
        nq_aligned: nq_aligned as i32,
        nk_aligned: nk_aligned as i32,
        ql_rem: ql_rem as i32,
        kl_rem: kl_rem as i32,
        ql_off: params.q_offset_in_k as i32,
        _pad: 0,
        q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
        k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
        v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
        o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
    };

    // One threadgroup per (query tile, head, batch element).
    let grid = MTLSize::new(nq as u64, params.n_heads as u64, params.batch as u64);
    let tg_size = MTLSize::new(32, wm as u64, wn as u64);
    encoder.set_op_kind(CapturedOpKind::Sdpa);
    encoder.encode_threadgroups_with_args(
        pipeline,
        &[
            (0, KernelArg::Buffer(q)),
            (1, KernelArg::Buffer(k)),
            (2, KernelArg::Buffer(v)),
            (3, KernelArg::Buffer(out)),
            (4, KernelArg::Bytes(as_bytes(&attn_params))),
        ],
        grid,
        tg_size,
    );
    Ok(())
}
/// Memory arrangement of the Q/K/V/output tensors handed to the head_dim = 64
/// dispatcher.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlashAttnPrefillLayout {
    /// Contiguous `[batch, head, seq, head_dim]`: consecutive sequence
    /// positions of one head are `head_dim` elements apart.
    HeadMajor,
    /// Contiguous `[batch, seq, head, head_dim]`: consecutive heads of one
    /// sequence position are `head_dim` elements apart.
    SeqMajor,
}

impl FlashAttnPrefillLayout {
    /// Returns the `[batch, head, seq]` element strides of a tensor with
    /// `n_heads` heads, `seq_len` positions and `head_dim` elements per head
    /// stored in this layout.
    fn strides(self, n_heads: u32, seq_len: u32, head_dim: u32) -> [i64; 3] {
        let heads = i64::from(n_heads);
        let len = i64::from(seq_len);
        let dim = i64::from(head_dim);

        // One batch element spans the same number of elements in both layouts.
        let batch_stride = heads * len * dim;
        let (head_stride, seq_stride) = match self {
            Self::HeadMajor => (len * dim, dim),
            Self::SeqMajor => (dim, heads * dim),
        };
        [batch_stride, head_stride, seq_stride]
    }
}
/// Encodes the BF16, head_dim = 64 flash-attention prefill kernel, with an
/// optional mask buffer and a caller-selected Q/K/V/output layout.
///
/// # Buffer requirements
///
/// * `q`, `out`: BF16, at least `batch * n_heads * seq_len_q * head_dim`
///   elements, arranged according to `layout`.
/// * `k`, `v`: BF16, at least `batch * n_kv_heads * seq_len_k * head_dim`
///   elements, arranged according to `layout`.
/// * `mask` (optional): BF16. A buffer whose shape has rank 2 is treated as
///   `[seq_len_q, seq_len_k]` and shared across batch and heads; any other
///   rank must hold `batch * n_heads * seq_len_q * seq_len_k` elements. The
///   mask strides do not depend on `layout`.
///
/// Function constants 200, 201, 300 and 301 receive, respectively, whether
/// `seq_len_q` / `seq_len_k` are tile-aligned, whether a mask is bound, and
/// `do_causal`; constant 303 (`blk` present) is always `false` here.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when `head_dim != 64`, when
/// `validate_params` fails, when a buffer has the wrong dtype, or when a
/// buffer is too small. Errors from pipeline lookup are propagated unchanged.
// The argument list mirrors the kernel's bindings one-to-one.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_prefill_bf16_d64(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    out: &MlxBuffer,
    params: &FlashAttnPrefillParams,
    layout: FlashAttnPrefillLayout,
) -> Result<()> {
    // ---- Parameter validation -------------------------------------------
    if params.head_dim != 64 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d64: head_dim must be 64, got {}",
            params.head_dim
        )));
    }
    validate_params(params)?;

    // ---- Dtype validation (first offender is reported) -------------------
    let io_buffers = [(q, "Q"), (k, "K"), (v, "V"), (out, "out")];
    if let Some((buf, name)) = io_buffers
        .iter()
        .find(|(buf, _)| buf.dtype() != DType::BF16)
    {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d64: {name} buffer must be BF16, got {:?}",
            buf.dtype()
        )));
    }
    if let Some(bad_mask) = mask.filter(|m| m.dtype() != DType::BF16) {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d64: mask buffer must be BF16, got {:?}",
            bad_mask.dtype()
        )));
    }

    // ---- Buffer-size validation ------------------------------------------
    let batch = params.batch as usize;
    let h = params.n_heads as usize;
    let h_kv = params.n_kv_heads as usize;
    let ql = params.seq_len_q as usize;
    let kl = params.seq_len_k as usize;
    let d = params.head_dim as usize;

    let q_elems = batch * h * ql * d;
    validate_buffer_size(q, "Q", q_elems)?;
    let kv_elems = batch * h_kv * kl * d;
    validate_buffer_size(k, "K", kv_elems)?;
    validate_buffer_size(v, "V", kv_elems)?;
    validate_buffer_size(out, "out", q_elems)?;

    // A rank-2 mask is shared across batch and heads.
    let mask_is_rank2_broadcast = matches!(mask, Some(m) if m.shape().len() == 2);
    if let Some(m) = mask {
        let mask_elems = if mask_is_rank2_broadcast {
            ql * kl
        } else {
            batch * h * ql * kl
        };
        validate_buffer_size(m, "mask", mask_elems)?;
    }

    // ---- Tile geometry ---------------------------------------------------
    let (bq, bk) = (BQ_D64, BK_D64);
    let (wm, wn) = (WM_D64, WN_D64);
    let nq = params.seq_len_q.div_ceil(bq);
    let nk = params.seq_len_k.div_ceil(bk);
    let nq_aligned = params.seq_len_q / bq;
    let nk_aligned = params.seq_len_k / bk;
    let ql_rem = params.seq_len_q % bq;
    let kl_rem = params.seq_len_k % bk;

    // ---- Pipeline specialisation ----------------------------------------
    let pipeline = registry.get_pipeline_with_bool_constants(
        K_BF16_D64,
        device.metal_device(),
        &[
            (200, ql_rem == 0),
            (201, kl_rem == 0),
            (300, mask.is_some()),
            (301, params.do_causal),
            // This dispatcher never binds a blk buffer.
            (303, false),
        ],
    )?;

    // ---- Kernel parameter block -----------------------------------------
    let q_strides = layout.strides(params.n_heads, params.seq_len_q, params.head_dim);
    let kv_strides = layout.strides(params.n_kv_heads, params.seq_len_k, params.head_dim);
    let o_strides = layout.strides(params.n_heads, params.seq_len_q, params.head_dim);
    let attn_params = AttnParamsGpu {
        b: params.batch as i32,
        h: params.n_heads as i32,
        d: params.head_dim as i32,
        ql: params.seq_len_q as i32,
        kl: params.seq_len_k as i32,
        gqa_factor: (params.n_heads / params.n_kv_heads) as i32,
        scale: params.scale,
        softcapping: 1.0_f32,
        nq: nq as i32,
        nk: nk as i32,
        nq_aligned: nq_aligned as i32,
        nk_aligned: nk_aligned as i32,
        ql_rem: ql_rem as i32,
        kl_rem: kl_rem as i32,
        ql_off: 0,
        _pad: 0,
        q_strides,
        k_strides: kv_strides,
        v_strides: kv_strides,
        o_strides,
    };

    // One threadgroup per (query tile, head, batch element).
    let grid = MTLSize::new(nq as u64, params.n_heads as u64, params.batch as u64);
    let tg_size = MTLSize::new(32, wm as u64, wn as u64);
    encoder.set_op_kind(CapturedOpKind::Sdpa);

    // ---- Encode: bindings 5 and 6 exist only when a mask is present -------
    match mask {
        Some(mask_buf) => {
            // Mask strides are always [batch, head, query-row] over a
            // row-major [.., seq_len_q, seq_len_k] mask.
            let m_strides = if mask_is_rank2_broadcast {
                [0_i64, 0_i64, kl as i64]
            } else {
                [(h * ql * kl) as i64, (ql * kl) as i64, kl as i64]
            };
            let mask_params = AttnMaskParamsGpu { m_strides };
            encoder.encode_threadgroups_with_args(
                pipeline,
                &[
                    (0, KernelArg::Buffer(q)),
                    (1, KernelArg::Buffer(k)),
                    (2, KernelArg::Buffer(v)),
                    (3, KernelArg::Buffer(out)),
                    (4, KernelArg::Bytes(as_bytes(&attn_params))),
                    (5, KernelArg::Bytes(as_bytes(&mask_params))),
                    (6, KernelArg::Buffer(mask_buf)),
                ],
                grid,
                tg_size,
            );
        }
        None => {
            encoder.encode_threadgroups_with_args(
                pipeline,
                &[
                    (0, KernelArg::Buffer(q)),
                    (1, KernelArg::Buffer(k)),
                    (2, KernelArg::Buffer(v)),
                    (3, KernelArg::Buffer(out)),
                    (4, KernelArg::Bytes(as_bytes(&attn_params))),
                ],
                grid,
                tg_size,
            );
        }
    }
    Ok(())
}
/// Host-only unit tests: struct layout, parameter validation, the kernel-name
/// registry list and tile geometry. None of them need a Metal device.
#[cfg(test)]
// Tests are allowed to panic on failure.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// 16 four-byte scalars (64 bytes) + 4 × `[i64; 3]` (96 bytes).
    #[test]
    fn test_attn_params_gpu_size() {
        assert_eq!(std::mem::size_of::<AttnParamsGpu>(), 160);
    }

    /// A single `[i64; 3]`.
    #[test]
    fn test_attn_mask_params_gpu_size() {
        assert_eq!(std::mem::size_of::<AttnMaskParamsGpu>(), 24);
    }

    #[test]
    fn test_validate_params_ok() {
        let p = FlashAttnPrefillParams {
            n_heads: 16,
            n_kv_heads: 8,
            head_dim: 256,
            seq_len_q: 2048,
            seq_len_k: 2048,
            batch: 1,
            scale: 1.0 / 256.0_f32.sqrt(),
            do_causal: true,
        };
        assert!(validate_params(&p).is_ok());
    }

    #[test]
    fn test_validate_params_zero_heads() {
        let p = FlashAttnPrefillParams {
            n_heads: 0,
            n_kv_heads: 8,
            head_dim: 256,
            seq_len_q: 128,
            seq_len_k: 128,
            batch: 1,
            scale: 1.0,
            do_causal: false,
        };
        assert!(matches!(
            validate_params(&p),
            Err(MlxError::InvalidArgument(_))
        ));
    }

    #[test]
    fn test_validate_params_bad_gqa_ratio() {
        let p = FlashAttnPrefillParams {
            n_heads: 16,
            n_kv_heads: 7,
            head_dim: 256,
            seq_len_q: 128,
            seq_len_k: 128,
            batch: 1,
            scale: 1.0,
            do_causal: false,
        };
        assert!(matches!(
            validate_params(&p),
            Err(MlxError::InvalidArgument(_))
        ));
    }

    /// Only checks the pre-condition of the head_dim guard; it does not call
    /// a dispatcher, which takes an encoder, device and registry.
    #[test]
    fn test_wrong_head_dim_rejected() {
        let p = FlashAttnPrefillParams {
            n_heads: 16,
            n_kv_heads: 8,
            head_dim: 128,
            seq_len_q: 64,
            seq_len_k: 64,
            batch: 1,
            scale: 1.0,
            do_causal: false,
        };
        assert!(p.head_dim != 256, "test pre-condition: head_dim must not be 256");
    }

    /// `ALL_KERNEL_NAMES` must match the expected set exactly, contain no
    /// duplicates or empty names, and contain no f32 kernels.
    #[test]
    fn test_all_expected_kernel_names_registered() {
        const EXPECTED: &[&str] = &[
            "flash_attn_prefill_bf16_d256",
            "flash_attn_prefill_bf16_d256_boolmask",
            "flash_attn_prefill_f16_d256",
            "flash_attn_prefill_f16_d256_boolmask",
            "flash_attn_prefill_bf16_d512",
            "flash_attn_prefill_bf16_d512_boolmask",
            "flash_attn_prefill_f16_d512",
            "flash_attn_prefill_f16_d512_boolmask",
            "flash_attn_prefill_bf16_d64",
            "flash_attn_prefill_bf16_d64_boolmask",
            "flash_attn_prefill_f16_d64",
            "flash_attn_prefill_f16_d64_boolmask",
        ];
        let registered: std::collections::HashSet<&str> =
            ALL_KERNEL_NAMES.iter().copied().collect();
        let expected: std::collections::HashSet<&str> = EXPECTED.iter().copied().collect();
        // A set smaller than the list means the list has duplicates.
        assert_eq!(
            registered.len(),
            ALL_KERNEL_NAMES.len(),
            "ALL_KERNEL_NAMES contains duplicate entries"
        );
        for &name in ALL_KERNEL_NAMES {
            assert!(!name.is_empty(), "kernel name must not be empty");
        }
        let missing: Vec<&str> = expected.difference(&registered).copied().collect();
        assert!(
            missing.is_empty(),
            "expected kernel names missing from ALL_KERNEL_NAMES: {missing:?}"
        );
        let extra: Vec<&str> = registered.difference(&expected).copied().collect();
        assert!(
            extra.is_empty(),
            "unexpected kernel names registered (update EXPECTED in this test): {extra:?}"
        );
        for &name in ALL_KERNEL_NAMES {
            assert!(
                !name.contains("float32"),
                "f32 kernel {name} must not be registered — exceeds 32 KB TG mem limit"
            );
            assert!(
                !name.contains("_f32_"),
                "f32 kernel {name} must not be registered — exceeds 32 KB TG mem limit"
            );
        }
    }

    #[test]
    fn test_tile_geometry_d256() {
        assert_eq!(BQ_D256, 32, "BQ=32 for D=256");
        assert_eq!(BK_D256, 16, "BK=16 for D=256");
        assert_eq!(WM_D256, 4, "WM=4 for D=256");
        assert_eq!(WN_D256, 1, "WN=1 for D=256");
        // Threads per threadgroup: 32 × WM × WN.
        assert_eq!(32 * WM_D256 * WN_D256, 128);
    }
}