use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::encoder::{CapturedOpKind, CommandEncoder, KernelArg, as_bytes};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use crate::DType;
pub use super::flash_attn_prefill::{
AttnMaskParamsGpu, AttnParamsGpu, FlashAttnPrefillParams,
};
/// Metal source for every D=512 prefill kernel in this module, embedded at
/// compile time from `../shaders/flash_attn_prefill_d512.metal`.
pub static FLASH_ATTN_PREFILL_D512_SHADER_SOURCE: &str =
include_str!("../shaders/flash_attn_prefill_d512.metal");
/// Kernel entry point for BF16 Q/K/V/out with head_dim 512. This is the only
/// name the dispatcher in this file selects.
pub const K_LLAMACPP_BF16_D512: &str = "flash_attn_prefill_llamacpp_bf16_d512";
/// BF16 D=512 variant whose name indicates a boolean mask. It is registered but
/// not selected by the dispatcher in this file.
/// NOTE(review): mask semantics inferred from the name only; confirm in the shader.
pub const K_LLAMACPP_BF16_D512_BOOLMASK: &str =
"flash_attn_prefill_llamacpp_bf16_d512_boolmask";
/// F16 D=512 variant. It is registered but not selected by the dispatcher in
/// this file.
pub const K_LLAMACPP_F16_D512: &str = "flash_attn_prefill_llamacpp_f16_d512";
/// F16 D=512 variant whose name indicates a boolean mask. It is registered but
/// not selected by the dispatcher in this file.
pub const K_LLAMACPP_F16_D512_BOOLMASK: &str =
"flash_attn_prefill_llamacpp_f16_d512_boolmask";
/// Every kernel name that `register` maps onto the shared D=512 shader source.
pub const ALL_KERNEL_NAMES: &[&str] = &[
K_LLAMACPP_BF16_D512,
K_LLAMACPP_BF16_D512_BOOLMASK,
K_LLAMACPP_F16_D512,
K_LLAMACPP_F16_D512_BOOLMASK,
];
/// Registers each name in [`ALL_KERNEL_NAMES`] with `registry`, pointing all
/// of them at the single embedded
/// [`FLASH_ATTN_PREFILL_D512_SHADER_SOURCE`] Metal library source.
pub fn register(registry: &mut KernelRegistry) {
    let source = FLASH_ATTN_PREFILL_D512_SHADER_SOURCE;
    for kernel in ALL_KERNEL_NAMES.iter().copied() {
        registry.register_source(kernel, source);
    }
}
/// Query rows covered by one threadgroup. The dispatcher launches
/// `ceil(seq_len_q / NQPSG_D512)` threadgroups along grid X.
pub const NQPSG_D512: u32 = 8;
/// Key/value positions per tile along the K sequence. Used for `nk`,
/// `nk_aligned`, `kl_rem` and the `blk` tile count.
pub const NCPSG_D512: u32 = 64;
/// Default simdgroups per threadgroup. A threadgroup is `32 x nsg` threads.
pub const NSG_D512: u32 = 8;

/// Threadgroup-memory bytes bound at index 0 for every D=512 dispatch (28 KiB).
pub const TGMEM_BYTES_D512: u32 = 28 * 1024;
/// Metal function-constant index that carries the simdgroup count (`nsg`).
pub const FC_IDX_NSG: usize = 322;
/// Checks the shape fields of `params` used by the D=512 prefill dispatch.
///
/// Rejects, in this order:
/// - zero `n_heads` or `n_kv_heads`;
/// - `n_heads` that is not a multiple of `n_kv_heads`;
/// - zero `seq_len_q`, `seq_len_k` or `batch`.
///
/// `head_dim` is not inspected here; the dispatcher checks it separately.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] for the first violated constraint.
fn validate_params_d512(params: &FlashAttnPrefillParams) -> Result<()> {
    let reject = |msg: String| -> Result<()> { Err(MlxError::InvalidArgument(msg)) };

    if params.n_heads == 0 {
        return reject("flash_attn_prefill_d512: n_heads must be > 0".to_owned());
    }
    if params.n_kv_heads == 0 {
        return reject("flash_attn_prefill_d512: n_kv_heads must be > 0".to_owned());
    }
    // Safe: n_kv_heads is known non-zero at this point.
    if params.n_heads % params.n_kv_heads != 0 {
        return reject(format!(
            "flash_attn_prefill_d512: n_heads ({}) must be divisible by n_kv_heads ({})",
            params.n_heads, params.n_kv_heads
        ));
    }

    // The remaining dimensions only need to be non-zero; report the first failure.
    let nonzero_checks = [
        (params.seq_len_q == 0, "flash_attn_prefill_d512: seq_len_q must be > 0"),
        (params.seq_len_k == 0, "flash_attn_prefill_d512: seq_len_k must be > 0"),
        (params.batch == 0, "flash_attn_prefill_d512: batch must be > 0"),
    ];
    match nonzero_checks.iter().find(|(is_zero, _)| *is_zero) {
        Some((_, msg)) => reject((*msg).to_owned()),
        None => Ok(()),
    }
}
/// Checks that `buf` holds at least `expected_elements` elements of its own dtype.
///
/// `name` identifies the buffer in the error message.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when the required byte count
/// overflows `usize`, or when `buf.byte_len()` is smaller than it.
fn validate_buffer_size(buf: &MlxBuffer, name: &str, expected_elements: usize) -> Result<()> {
    // Checked on purpose: a wrapped product could under-report the requirement
    // and accept an undersized buffer that is later bound to the kernel.
    let expected_bytes = expected_elements
        .checked_mul(buf.dtype().size_of())
        .ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "flash_attn_prefill_d512: {name} buffer size overflows usize \
                 ({expected_elements} elements)"
            ))
        })?;
    if buf.byte_len() < expected_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "flash_attn_prefill_d512: {name} buffer too small: expected at least \
             {expected_bytes} bytes, got {}",
            buf.byte_len()
        )));
    }
    Ok(())
}
/// Encodes the BF16 D=512 flash-attention prefill kernel with no `blk` buffer
/// and the default simdgroup count.
///
/// Forwards to [`dispatch_flash_attn_prefill_bf16_d512_with_blk`] with
/// `blk = None`. See [`dispatch_flash_attn_prefill_bf16_d512_with_nsg_and_blk`]
/// for the full argument contract and error conditions.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_prefill_bf16_d512(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    out: &MlxBuffer,
    params: &FlashAttnPrefillParams,
) -> Result<()> {
    let no_blk: Option<&MlxBuffer> = None;
    dispatch_flash_attn_prefill_bf16_d512_with_blk(
        encoder,
        device,
        registry,
        q,
        k,
        v,
        mask,
        no_blk,
        out,
        params,
    )
}
/// Encodes the BF16 D=512 flash-attention prefill kernel with an optional
/// `blk` buffer and the default simdgroup count ([`NSG_D512`]).
///
/// See [`dispatch_flash_attn_prefill_bf16_d512_with_nsg_and_blk`] for the
/// full argument contract and error conditions.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_prefill_bf16_d512_with_blk(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    blk: Option<&MlxBuffer>,
    out: &MlxBuffer,
    params: &FlashAttnPrefillParams,
) -> Result<()> {
    let default_nsg = NSG_D512;
    dispatch_flash_attn_prefill_bf16_d512_with_nsg_and_blk(
        encoder,
        device,
        registry,
        q,
        k,
        v,
        mask,
        blk,
        out,
        params,
        default_nsg,
    )
}
/// Encodes the BF16 D=512 flash-attention prefill kernel with an explicit
/// simdgroup count `nsg` and no `blk` buffer.
///
/// See [`dispatch_flash_attn_prefill_bf16_d512_with_nsg_and_blk`] for the
/// full argument contract, including the accepted `nsg` values.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_prefill_bf16_d512_with_nsg(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    out: &MlxBuffer,
    params: &FlashAttnPrefillParams,
    nsg: u32,
) -> Result<()> {
    let no_blk: Option<&MlxBuffer> = None;
    dispatch_flash_attn_prefill_bf16_d512_with_nsg_and_blk(
        encoder,
        device,
        registry,
        q,
        k,
        v,
        mask,
        no_blk,
        out,
        params,
        nsg,
    )
}
/// Multiplies `factors` together, returning [`MlxError::InvalidArgument`]
/// instead of silently wrapping when the product does not fit in `usize`.
///
/// `what` names the tensor in the error message.
fn checked_elements(what: &str, factors: &[usize]) -> Result<usize> {
    factors
        .iter()
        .try_fold(1_usize, |acc, &factor| acc.checked_mul(factor))
        .ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "flash_attn_prefill_d512: {what} element count overflows usize"
            ))
        })
}

/// Encodes the BF16 flash-attention prefill kernel for `head_dim == 512`
/// ([`K_LLAMACPP_BF16_D512`]) with an explicit simdgroup count and an
/// optional `blk` buffer.
///
/// Buffer layouts are implied by the strides passed to the kernel, with the
/// innermost dimension last:
/// - `q`, `out`: `[batch, n_heads, seq_len_q, 512]`, BF16.
/// - `k`, `v`: `[batch, n_kv_heads, seq_len_k, 512]`, BF16.
/// - `mask` (optional, BF16):
///   - a rank-2 shape is read as `[seq_len_q, seq_len_k]` and broadcast over
///     batch and heads through zero strides;
///   - any other rank is read as `[batch, n_heads, seq_len_q, seq_len_k]`.
/// - `blk` (optional, only valid together with `mask`): at least
///   `ceil(seq_len_q / NQPSG_D512) * ceil(seq_len_k / NCPSG_D512)` bytes.
///   NOTE(review): presumably a per-tile summary of the mask; confirm against
///   the shader.
///
/// The launch is `(ceil(seq_len_q / NQPSG_D512), n_heads, batch)`
/// threadgroups of `32 x nsg` threads. Each threadgroup gets
/// [`TGMEM_BYTES_D512`] bytes of threadgroup memory.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] in any of these cases:
/// - `head_dim != 512`, or `nsg` is not 4 or 8;
/// - `blk` is supplied without `mask`;
/// - `params` fails validation;
/// - a buffer has the wrong dtype or is too small;
/// - a size computation would overflow.
///
/// Errors from pipeline creation are propagated unchanged.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_prefill_bf16_d512_with_nsg_and_blk(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    blk: Option<&MlxBuffer>,
    out: &MlxBuffer,
    params: &FlashAttnPrefillParams,
    nsg: u32,
) -> Result<()> {
    if params.head_dim != 512 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d512: head_dim must be 512, got {}",
            params.head_dim
        )));
    }
    if nsg != 4 && nsg != 8 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_bf16_d512: nsg must be 4 or 8, got {nsg}"
        )));
    }
    if blk.is_some() && mask.is_none() {
        return Err(MlxError::InvalidArgument(
            "dispatch_flash_attn_prefill_bf16_d512: \
             blk requires mask (a blk without a mask is meaningless)"
                .into(),
        ));
    }
    validate_params_d512(params)?;

    for (buf, name) in [(q, "Q"), (k, "K"), (v, "V"), (out, "out")] {
        if buf.dtype() != DType::BF16 {
            return Err(MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_bf16_d512: {name} buffer must be BF16, \
                 got {:?}",
                buf.dtype()
            )));
        }
    }
    if let Some(m) = mask {
        if m.dtype() != DType::BF16 {
            return Err(MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_bf16_d512: mask buffer must be BF16, \
                 got {:?}",
                m.dtype()
            )));
        }
    }

    let batch = params.batch as usize;
    let h = params.n_heads as usize;
    let h_kv = params.n_kv_heads as usize;
    let ql = params.seq_len_q as usize;
    let kl = params.seq_len_k as usize;
    let d = params.head_dim as usize;

    // Checked products: a wrapped element count would let an undersized buffer
    // pass validation and be read or written out of bounds by the kernel.
    let q_elems = checked_elements("Q", &[batch, h, ql, d])?;
    let kv_elems = checked_elements("K/V", &[batch, h_kv, kl, d])?;
    validate_buffer_size(q, "Q", q_elems)?;
    validate_buffer_size(k, "K", kv_elems)?;
    validate_buffer_size(v, "V", kv_elems)?;
    validate_buffer_size(out, "out", q_elems)?;

    let mask_is_rank2_broadcast = mask.is_some_and(|m| m.shape().len() == 2);
    if let Some(m) = mask {
        let mask_elems = if mask_is_rank2_broadcast {
            checked_elements("mask", &[ql, kl])?
        } else {
            checked_elements("mask", &[batch, h, ql, kl])?
        };
        validate_buffer_size(m, "mask", mask_elems)?;
    }

    // Tile geometry: NQPSG_D512 query rows x NCPSG_D512 key positions.
    let nq = params.seq_len_q.div_ceil(NQPSG_D512);
    let nk = params.seq_len_k.div_ceil(NCPSG_D512);
    let nq_aligned = params.seq_len_q / NQPSG_D512;
    let nk_aligned = params.seq_len_k / NCPSG_D512;
    let ql_rem = params.seq_len_q % NQPSG_D512;
    let kl_rem = params.seq_len_k % NCPSG_D512;
    let align_q = ql_rem == 0;
    let align_k = kl_rem == 0;
    let has_mask = mask.is_some();
    let has_blk = blk.is_some();
    let do_causal = params.do_causal;

    if let Some(b) = blk {
        // At least one byte is required per (Q tile, K tile) pair.
        let nq_tiles = ql.div_ceil(NQPSG_D512 as usize);
        let nk_tiles = kl.div_ceil(NCPSG_D512 as usize);
        let expected = nq_tiles * nk_tiles;
        if b.byte_len() < expected {
            return Err(MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_bf16_d512: blk buffer too small: \
                 expected at least {expected} bytes (NQ={nq_tiles}, \
                 NK={nk_tiles}), got {}",
                b.byte_len()
            )));
        }
    }

    let pipeline = registry.get_pipeline_with_constants(
        K_LLAMACPP_BF16_D512,
        device.metal_device(),
        &[
            (200, align_q),
            (201, align_k),
            (300, has_mask),
            (301, do_causal),
            (303, has_blk),
        ],
        &[(FC_IDX_NSG, nsg as i32)],
    )?;

    // Element strides for contiguous [batch, head, seq, dim] tensors.
    let q_seq_stride = d as i64;
    let q_head_stride = (ql * d) as i64;
    let q_batch_stride = (h * ql * d) as i64;
    let kv_seq_stride = d as i64;
    let kv_head_stride = (kl * d) as i64;
    let kv_batch_stride = (h_kv * kl * d) as i64;
    let gqa_factor = (params.n_heads / params.n_kv_heads) as i32;

    let attn_params = AttnParamsGpu {
        b: params.batch as i32,
        h: params.n_heads as i32,
        d: params.head_dim as i32,
        ql: params.seq_len_q as i32,
        kl: params.seq_len_k as i32,
        gqa_factor,
        scale: params.scale,
        // NOTE(review): fixed at 1.0; confirm the shader treats this as "no softcap".
        softcapping: 1.0_f32,
        nq: nq as i32,
        nk: nk as i32,
        nq_aligned: nq_aligned as i32,
        nk_aligned: nk_aligned as i32,
        ql_rem: ql_rem as i32,
        kl_rem: kl_rem as i32,
        ql_off: 0,
        _pad: 0,
        q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
        k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
        v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
        o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
    };

    let grid = MTLSize::new(nq as u64, params.n_heads as u64, params.batch as u64);
    let tg_size = MTLSize::new(32, nsg as u64, 1);
    encoder.set_op_kind(CapturedOpKind::Sdpa);
    let tgmem = TGMEM_BYTES_D512 as u64;

    // `blk` without `mask` was rejected above, so matching on `mask` first
    // covers every reachable combination.
    match mask {
        None => {
            encoder.encode_threadgroups_with_args_and_shared(
                pipeline,
                &[
                    (0, KernelArg::Buffer(q)),
                    (1, KernelArg::Buffer(k)),
                    (2, KernelArg::Buffer(v)),
                    (3, KernelArg::Buffer(out)),
                    (4, KernelArg::Bytes(as_bytes(&attn_params))),
                ],
                &[(0, tgmem)],
                grid,
                tg_size,
            );
        }
        Some(mask_buf) => {
            // A rank-2 mask is broadcast across batch and heads via zero strides.
            let (m_batch_stride, m_head_stride, m_ql_stride) = if mask_is_rank2_broadcast {
                (0_i64, 0_i64, kl as i64)
            } else {
                ((h * ql * kl) as i64, (ql * kl) as i64, kl as i64)
            };
            let mask_params = AttnMaskParamsGpu {
                m_strides: [m_batch_stride, m_head_stride, m_ql_stride],
            };
            match blk {
                Some(blk_buf) => {
                    encoder.encode_threadgroups_with_args_and_shared(
                        pipeline,
                        &[
                            (0, KernelArg::Buffer(q)),
                            (1, KernelArg::Buffer(k)),
                            (2, KernelArg::Buffer(v)),
                            (3, KernelArg::Buffer(out)),
                            (4, KernelArg::Bytes(as_bytes(&attn_params))),
                            (5, KernelArg::Bytes(as_bytes(&mask_params))),
                            (6, KernelArg::Buffer(mask_buf)),
                            (7, KernelArg::Buffer(blk_buf)),
                        ],
                        &[(0, tgmem)],
                        grid,
                        tg_size,
                    );
                }
                None => {
                    encoder.encode_threadgroups_with_args_and_shared(
                        pipeline,
                        &[
                            (0, KernelArg::Buffer(q)),
                            (1, KernelArg::Buffer(k)),
                            (2, KernelArg::Buffer(v)),
                            (3, KernelArg::Buffer(out)),
                            (4, KernelArg::Bytes(as_bytes(&attn_params))),
                            (5, KernelArg::Bytes(as_bytes(&mask_params))),
                            (6, KernelArg::Buffer(mask_buf)),
                        ],
                        &[(0, tgmem)],
                        grid,
                        tg_size,
                    );
                }
            }
        }
    }
    Ok(())
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// A parameter set that passes `validate_params_d512`; tests override fields.
    fn base_params() -> FlashAttnPrefillParams {
        FlashAttnPrefillParams {
            n_heads: 4,
            n_kv_heads: 2,
            head_dim: 512,
            seq_len_q: 128,
            seq_len_k: 128,
            batch: 1,
            scale: 1.0 / 512.0_f32.sqrt(),
            do_causal: true,
        }
    }

    /// Asserts that `p` is rejected with `InvalidArgument`.
    fn assert_rejected(p: &FlashAttnPrefillParams) {
        assert!(matches!(
            validate_params_d512(p),
            Err(MlxError::InvalidArgument(_))
        ));
    }

    #[test]
    fn test_tile_geometry_d512() {
        assert_eq!(NQPSG_D512, 8, "NQPSG=8 for D=512 (llama.cpp-impl.h:93)");
        assert_eq!(NCPSG_D512, 64, "NCPSG=64 for D=512 (llama.cpp-impl.h:94)");
        assert_eq!(NSG_D512, 8, "NSG=8 for D=512 (llama.cpp-ops.cpp:2807)");
        assert_eq!(32 * NSG_D512, 256);
    }

    #[test]
    fn test_threadgroup_memory_matches_llamacpp() {
        assert_eq!(TGMEM_BYTES_D512, 28_672);
    }

    #[test]
    fn test_fc_idx_nsg_matches_llamacpp() {
        assert_eq!(FC_IDX_NSG, 322);
    }

    #[test]
    fn test_four_kernel_names_registered() {
        assert_eq!(ALL_KERNEL_NAMES.len(), 4);
        let mut seen = std::collections::HashSet::new();
        for &name in ALL_KERNEL_NAMES {
            assert!(!name.is_empty());
            assert!(seen.insert(name), "duplicate name: {name}");
            assert!(
                name.starts_with("flash_attn_prefill_llamacpp_"),
                "name must be prefixed with llamacpp marker: {name}"
            );
            assert!(name.contains("d512"), "all D=512 names must contain d512: {name}");
        }
    }

    /// `validate_params_d512` does not check `head_dim`; the dispatcher rejects
    /// `head_dim != 512` before calling it. This test previously carried the
    /// misleading name `test_validate_params_d512_wrong_head_dim`.
    #[test]
    fn test_validate_params_d512_does_not_check_head_dim() {
        let p = FlashAttnPrefillParams {
            n_heads: 2,
            n_kv_heads: 2,
            head_dim: 256,
            seq_len_q: 8,
            seq_len_k: 8,
            batch: 1,
            scale: 1.0,
            do_causal: false,
        };
        assert!(validate_params_d512(&p).is_ok());
    }

    #[test]
    fn test_validate_params_d512_ok() {
        assert!(validate_params_d512(&base_params()).is_ok());
    }

    #[test]
    fn test_validate_params_d512_zero_heads() {
        assert_rejected(&FlashAttnPrefillParams { n_heads: 0, ..base_params() });
    }

    #[test]
    fn test_validate_params_d512_zero_kv_heads() {
        assert_rejected(&FlashAttnPrefillParams { n_kv_heads: 0, ..base_params() });
    }

    #[test]
    fn test_validate_params_d512_indivisible_heads() {
        assert_rejected(&FlashAttnPrefillParams {
            n_heads: 3,
            n_kv_heads: 2,
            ..base_params()
        });
    }

    #[test]
    fn test_validate_params_d512_zero_seq_and_batch() {
        assert_rejected(&FlashAttnPrefillParams { seq_len_q: 0, ..base_params() });
        assert_rejected(&FlashAttnPrefillParams { seq_len_k: 0, ..base_params() });
        assert_rejected(&FlashAttnPrefillParams { batch: 0, ..base_params() });
    }
}