use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::encoder::{as_bytes, CapturedOpKind, CommandEncoder, KernelArg};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use crate::DType;
/// Metal source text of `../shaders/flash_attn_vec.metal`, embedded into the
/// binary at compile time via `include_str!`.
///
/// [`register`] associates this single source string with every kernel name
/// dispatched by this module.
pub static FLASH_ATTN_VEC_SHADER_SOURCE: &str =
include_str!("../shaders/flash_attn_vec.metal");
/// Registers the flash-attention vector shader source with `registry` under
/// every kernel name this module looks up.
///
/// All six names are mapped to the same source string,
/// [`FLASH_ATTN_VEC_SHADER_SOURCE`].
pub fn register(registry: &mut KernelRegistry) {
    // Main kernels, reduce kernels, then the f16-KV variants, each for the
    // 256 and 512 head dimensions. Names are registered in this order.
    const KERNEL_NAMES: [&str; 6] = [
        "flash_attn_vec_dk256",
        "flash_attn_vec_dk512",
        "flash_attn_vec_reduce_dk256",
        "flash_attn_vec_reduce_dk512",
        "flash_attn_vec_f16kv_dk256",
        "flash_attn_vec_f16kv_dk512",
    ];

    for kernel_name in KERNEL_NAMES {
        registry.register_source(kernel_name, FLASH_ATTN_VEC_SHADER_SOURCE);
    }
}
/// Caller-facing parameters for [`flash_attn_vec`].
///
/// The values are checked by `validate_params` and then copied field by field
/// into the GPU-side parameter struct that is handed to the kernel.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnVecParams {
/// Number of query heads. Must be non-zero and a multiple of
/// `num_kv_heads`. Also used as the Y extent of the main dispatch grid and
/// as the row count of the reduce pass.
pub num_heads: u32,
/// Number of key/value heads. Must be non-zero.
pub num_kv_heads: u32,
/// Per-head dimension. Only 256 and 512 are accepted.
pub head_dim: u32,
/// Must be > 0. Presumably the number of valid positions currently stored
/// in the KV cache — confirm against the shader.
pub kv_seq_len: u32,
/// Must be >= `kv_seq_len`. Presumably the allocated length of the KV
/// cache along the sequence axis — confirm against the shader.
pub kv_capacity: u32,
/// Forwarded unchanged to the kernel; not validated here.
/// NOTE(review): presumably the softmax scaling factor — confirm.
pub scale: f32,
/// Forwarded unchanged to the kernel; not validated here. The meaning of
/// each value is defined by the shader.
pub mask_type: u32,
/// Forwarded unchanged to the kernel; not validated here.
pub sliding_window: u32,
/// Forwarded unchanged to the kernel; not validated here.
pub softcap: f32,
}
/// Parameter block passed to the main flash-attention kernel as raw bytes
/// (binding index 0 in [`flash_attn_vec`]).
///
/// Holds the nine [`FlashAttnVecParams`] values followed by `nwg`. With
/// `#[repr(C)]` and ten 4-byte fields the struct is 40 bytes, which the unit
/// tests assert.
/// NOTE(review): field order and types presumably have to match the parameter
/// struct declared in the Metal shader — verify before changing anything here.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecParamsGpu {
n_heads: u32,
n_kv_heads: u32,
head_dim: u32,
kv_seq_len: u32,
kv_capacity: u32,
scale: f32,
mask_type: u32,
sliding_window: u32,
softcap: f32,
// Filled from the `NWG` constant by `flash_attn_vec`.
nwg: u32,
}
/// Parameter block passed to the reduce kernel as raw bytes (binding index 0
/// of the reduce dispatch in [`flash_attn_vec`]).
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecReduceParamsGpu {
// Set to `params.num_heads` by `flash_attn_vec`.
nrows: u32,
}
/// Workgroup count used by this module: it is the Z extent of the main
/// dispatch grid, is passed to both kernels as `nwg`, scales the reduce
/// threadgroup width (`32 * NWG` threads), and is a factor in
/// [`tmp_buffer_bytes`].
const NWG: u32 = 32;
/// Checks the host-side invariants of `params`.
///
/// The checks run in a fixed order and the first violated one is reported:
/// 1. `head_dim` is 256 or 512;
/// 2. `num_heads` and `num_kv_heads` are both non-zero;
/// 3. `num_heads` is a multiple of `num_kv_heads`;
/// 4. `kv_seq_len` is non-zero;
/// 5. `kv_capacity` is at least `kv_seq_len`.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] describing the first failed check.
fn validate_params(params: &FlashAttnVecParams) -> Result<()> {
    let heads = params.num_heads;
    let kv_heads = params.num_kv_heads;

    // The zero check precedes the modulo so the division can never trap.
    let problem: Option<String> = if !matches!(params.head_dim, 256 | 512) {
        Some(format!(
            "flash_attn_vec: head_dim must be 256 or 512, got {}",
            params.head_dim
        ))
    } else if heads == 0 || kv_heads == 0 {
        Some("flash_attn_vec: num_heads and num_kv_heads must be > 0".into())
    } else if heads % kv_heads != 0 {
        Some(format!(
            "flash_attn_vec: num_heads ({}) must be divisible by num_kv_heads ({})",
            heads, kv_heads
        ))
    } else if params.kv_seq_len == 0 {
        Some("flash_attn_vec: kv_seq_len must be > 0".into())
    } else if params.kv_capacity < params.kv_seq_len {
        Some(format!(
            "flash_attn_vec: kv_capacity ({}) must be >= kv_seq_len ({})",
            params.kv_capacity, params.kv_seq_len
        ))
    } else {
        None
    };

    match problem {
        Some(message) => Err(MlxError::InvalidArgument(message)),
        None => Ok(()),
    }
}
pub fn flash_attn_vec(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
q: &MlxBuffer,
k: &MlxBuffer,
v: &MlxBuffer,
output: &MlxBuffer,
tmp: &MlxBuffer,
params: &FlashAttnVecParams,
) -> Result<()> {
validate_params(params)?;
let head_dim = params.head_dim;
let nwg = NWG;
let gpu_params = FlashAttnVecParamsGpu {
n_heads: params.num_heads,
n_kv_heads: params.num_kv_heads,
head_dim: params.head_dim,
kv_seq_len: params.kv_seq_len,
kv_capacity: params.kv_capacity,
scale: params.scale,
mask_type: params.mask_type,
sliding_window: params.sliding_window,
softcap: params.softcap,
nwg,
};
let kv_is_f16 = k.dtype() == DType::F16;
let kernel_name = match (head_dim, kv_is_f16) {
(256, false) => "flash_attn_vec_dk256",
(512, false) => "flash_attn_vec_dk512",
(256, true) => "flash_attn_vec_f16kv_dk256",
(512, true) => "flash_attn_vec_f16kv_dk512",
_ => unreachable!(), };
let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;
let pk = pad2(head_dim as usize, 128);
let pv = pad2(head_dim as usize, 128);
let sh = 4 * 32; let shmem_halfs = pk + sh + 2 * pv;
let shmem_bytes = shmem_halfs * 2;
encoder.set_op_kind(CapturedOpKind::Sdpa);
let threadgroups = MTLSize::new(1, params.num_heads as u64, nwg as u64);
let threadgroup_size = MTLSize::new(32, 1, 1);
encoder.encode_threadgroups_with_args_and_shared(
pipeline,
&[
(0, KernelArg::Bytes(as_bytes(&gpu_params))),
(1, KernelArg::Buffer(q)),
(2, KernelArg::Buffer(k)),
(3, KernelArg::Buffer(v)),
(4, KernelArg::Buffer(tmp)),
],
&[(0, shmem_bytes as u64)],
threadgroups,
threadgroup_size,
);
if nwg > 1 {
encoder.memory_barrier();
let reduce_params = FlashAttnVecReduceParamsGpu {
nrows: params.num_heads,
};
let reduce_kernel = match head_dim {
256 => "flash_attn_vec_reduce_dk256",
512 => "flash_attn_vec_reduce_dk512",
_ => unreachable!(),
};
let reduce_pipeline =
registry.get_pipeline(reduce_kernel, device.metal_device())?;
let reduce_tg = MTLSize::new(params.num_heads as u64, 1, 1);
let reduce_tg_size = MTLSize::new(32 * nwg as u64, 1, 1);
{
let read_ranges = vec![
{
let s = tmp.contents_ptr() as usize;
(s, s + tmp.byte_len())
},
];
let write_ranges = vec![
{
let s = output.contents_ptr() as usize;
(s, s + output.byte_len())
},
];
encoder.set_pending_buffer_ranges(read_ranges, write_ranges);
}
encoder.encode_threadgroups_with_args(
reduce_pipeline,
&[
(0, KernelArg::Bytes(as_bytes(&reduce_params))),
(1, KernelArg::Buffer(tmp)),
(2, KernelArg::Buffer(output)),
(3, KernelArg::Bytes(as_bytes(&nwg))),
],
reduce_tg,
reduce_tg_size,
);
}
Ok(())
}
/// Returns the size in bytes of the scratch (`tmp`) buffer for the given head
/// count and head dimension.
///
/// The result is `num_heads * NWG * (head_dim + 2)` elements of `f32` size.
/// NOTE(review): the two extra elements per slot are presumably per-workgroup
/// softmax statistics — confirm against the shader.
pub fn tmp_buffer_bytes(num_heads: u32, head_dim: u32) -> usize {
    let slots = num_heads as usize * NWG as usize;
    let floats_per_slot = head_dim as usize + 2;
    slots * floats_per_slot * std::mem::size_of::<f32>()
}
/// Rounds `x` up to a multiple of `n` using bit masking.
///
/// The result is only a true multiple of `n` when `n` is a power of two; this
/// module calls it with `n = 128`.
fn pad2(x: usize, n: usize) -> usize {
    let bumped = x + n - 1;
    let low_bits = n - 1;
    bumped & !low_bits
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Parameter set that passes every check in `validate_params`.
    fn valid_params() -> FlashAttnVecParams {
        FlashAttnVecParams {
            num_heads: 16,
            num_kv_heads: 8,
            head_dim: 256,
            kv_seq_len: 100,
            kv_capacity: 1024,
            scale: 1.0,
            mask_type: 1,
            sliding_window: 0,
            softcap: 0.0,
        }
    }

    #[test]
    fn test_validate_params_ok() {
        let outcome = validate_params(&valid_params());
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_params_bad_head_dim() {
        // 128 is not one of the supported head dimensions.
        let rejected = FlashAttnVecParams {
            head_dim: 128,
            mask_type: 0,
            ..valid_params()
        };
        assert!(validate_params(&rejected).is_err());
    }

    #[test]
    fn test_gpu_params_layout() {
        // Ten 4-byte fields under `#[repr(C)]`.
        let size = std::mem::size_of::<FlashAttnVecParamsGpu>();
        assert_eq!(size, 40);
    }

    #[test]
    fn test_tmp_buffer_size() {
        let expected = 16 * 32 * 258 * 4;
        assert_eq!(tmp_buffer_bytes(16, 256), expected);
    }
}