use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::encoder::{as_bytes, CapturedOpKind, CommandEncoder, KernelArg};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal source for the `flash_attn_vec_tq` kernels, embedded at compile time
/// from `../shaders/flash_attn_vec_tq.metal` (path relative to this file).
///
/// `register` installs this same source under both the `flash_attn_vec_tq_dk256`
/// and `flash_attn_vec_tq_dk512` kernel names.
pub static FLASH_ATTN_VEC_TQ_SHADER_SOURCE: &str =
include_str!("../shaders/flash_attn_vec_tq.metal");
/// Register the TQ flash-attention shader source under every head-dim kernel name.
///
/// Both names (`flash_attn_vec_tq_dk256`, then `flash_attn_vec_tq_dk512`) map to the
/// same `FLASH_ATTN_VEC_TQ_SHADER_SOURCE`; presumably the registry picks the entry
/// point by name when it builds the pipeline — confirm in `KernelRegistry`.
///
/// NOTE(review): `flash_attn_vec_tq` also requests `flash_attn_vec_reduce_dk256` /
/// `flash_attn_vec_reduce_dk512`, which are not registered here — presumably another
/// module registers them; confirm.
pub fn register(registry: &mut KernelRegistry) {
    const KERNELS: [&str; 2] = ["flash_attn_vec_tq_dk256", "flash_attn_vec_tq_dk512"];
    for kernel in KERNELS {
        registry.register_source(kernel, FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
    }
}
/// Host-side parameters for `flash_attn_vec_tq`.
///
/// Checked by `validate_params`, then copied field-for-field into the `#[repr(C)]`
/// `FlashAttnVecTqParamsGpu` block that is uploaded to the shader.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnVecTqParams {
/// Number of query heads; must be > 0. One grid row (y) is launched per head.
pub num_heads: u32,
/// Number of K/V heads; must be > 0 and divide `num_heads` evenly.
pub num_kv_heads: u32,
/// Per-head dimension; only 256 and 512 are accepted.
pub head_dim: u32,
/// Number of K/V positions to attend over; must be > 0 and <= `kv_capacity`.
pub kv_seq_len: u32,
/// Presumably the allocated length of the K/V cache buffers — confirm; must be
/// >= `kv_seq_len`.
pub kv_capacity: u32,
/// Score multiplier forwarded to the shader (presumably `1/sqrt(head_dim)` — confirm).
pub scale: f32,
/// Mask selector forwarded to the shader unvalidated; its encoding is defined in
/// `flash_attn_vec_tq.metal`.
pub mask_type: u32,
/// Sliding-window length forwarded to the shader unvalidated (presumably only
/// read for windowed mask types — confirm).
pub sliding_window: u32,
/// Logit soft-cap forwarded to the shader unvalidated (presumably `0.0` disables
/// it — confirm).
pub softcap: f32,
/// Ring-buffer start offset forwarded to the shader unvalidated; NOTE(review):
/// nothing here checks it against `kv_capacity`.
pub ring_start: u32,
}
/// `#[repr(C)]` argument block for the `flash_attn_vec_reduce_dk*` kernels,
/// uploaded as raw bytes at buffer index 0 of the reduce dispatch. Its layout
/// must match the shader-side struct.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecReduceParamsGpu {
/// Rows to reduce; `flash_attn_vec_tq` sets this to `num_heads` (one reduce
/// threadgroup per head).
nrows: u32,
}
/// `#[repr(C)]` argument block uploaded as raw bytes (buffer index 0) to the
/// `flash_attn_vec_tq_dk*` kernels.
///
/// Field order and types must match the shader-side struct exactly: eleven
/// 4-byte fields, 44 bytes total with no padding (pinned by the layout test).
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecTqParamsGpu {
/// From `FlashAttnVecTqParams::num_heads`.
n_heads: u32,
/// From `FlashAttnVecTqParams::num_kv_heads`.
n_kv_heads: u32,
/// 256 or 512; also selects the kernel variant on the host side.
head_dim: u32,
kv_seq_len: u32,
kv_capacity: u32,
scale: f32,
mask_type: u32,
sliding_window: u32,
softcap: f32,
/// Workgroups per head from `compute_nwg` (grid z). With 1 the kernel writes
/// the final `output` buffer; otherwise it writes partials into `tmp`.
nwg: u32,
ring_start: u32,
}
fn validate_params(params: &FlashAttnVecTqParams) -> Result<()> {
if params.head_dim != 256 && params.head_dim != 512 {
return Err(MlxError::InvalidArgument(format!(
"flash_attn_vec_tq: head_dim must be 256 or 512, got {}",
params.head_dim
)));
}
if params.num_heads == 0 || params.num_kv_heads == 0 {
return Err(MlxError::InvalidArgument(
"flash_attn_vec_tq: num_heads and num_kv_heads must be > 0".into(),
));
}
if params.num_heads % params.num_kv_heads != 0 {
return Err(MlxError::InvalidArgument(format!(
"flash_attn_vec_tq: num_heads ({}) must be divisible by num_kv_heads ({})",
params.num_heads, params.num_kv_heads
)));
}
if params.kv_seq_len == 0 {
return Err(MlxError::InvalidArgument(
"flash_attn_vec_tq: kv_seq_len must be > 0".into(),
));
}
if params.kv_capacity < params.kv_seq_len {
return Err(MlxError::InvalidArgument(format!(
"flash_attn_vec_tq: kv_capacity ({}) must be >= kv_seq_len ({})",
params.kv_capacity, params.kv_seq_len
)));
}
Ok(())
}
/// Number of workgroups launched per head (the grid's z-dimension).
///
/// Defaults to 16. The `HF2Q_TQ_NWG` environment variable overrides the default
/// when it parses as an integer in `1..=32`; any other value (or an unset/non-UTF-8
/// variable) is ignored. The variable is re-read on every call.
///
/// The upper bound of 32 matches the worst case that `tmp_buffer_bytes` sizes for.
/// It also keeps the reduce threadgroup (`32 * nwg` threads) at or below 1024 threads.
///
/// `_kv_seq_len` is currently unused (presumably kept so the split can later depend
/// on sequence length).
fn compute_nwg(_kv_seq_len: u32) -> u32 {
    const DEFAULT_NWG: u32 = 16;
    std::env::var("HF2Q_TQ_NWG")
        .ok()
        .and_then(|raw| raw.parse::<u32>().ok())
        .filter(|n| (1..=32).contains(n))
        .unwrap_or(DEFAULT_NWG)
}
/// Look up the `(main, reduce)` kernel names for a supported head dimension.
///
/// Both dispatches in `flash_attn_vec_tq` take their kernel from this single table,
/// so the set of supported head dims cannot drift between them.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` for any `head_dim` other than 256 or 512.
fn tq_kernel_names(head_dim: u32) -> Result<(&'static str, &'static str)> {
    match head_dim {
        256 => Ok(("flash_attn_vec_tq_dk256", "flash_attn_vec_reduce_dk256")),
        512 => Ok(("flash_attn_vec_tq_dk512", "flash_attn_vec_reduce_dk512")),
        _ => Err(MlxError::InvalidArgument(format!(
            "flash_attn_vec_tq: unsupported head_dim {head_dim}"
        ))),
    }
}

/// Encode flash attention against a TQ-packed K/V cache into `encoder`.
///
/// Launches `flash_attn_vec_tq_dk{256,512}` on a `(1, num_heads, nwg)` grid of
/// 32-thread threadgroups, where `nwg` comes from `compute_nwg`.
///
/// - With `nwg == 1` the kernel writes `output` directly.
/// - With `nwg > 1` it writes per-workgroup partials into `tmp`. After a memory
///   barrier, a second `flash_attn_vec_reduce_dk{256,512}` dispatch merges them
///   into `output`. That dispatch uses one threadgroup of `32 * nwg` threads per head.
///
/// # Arguments
/// * `q` - query buffer; its layout is defined by the shader and not checked here.
/// * `k_packed`, `k_norms`, `v_packed`, `v_norms` - packed K/V data and their
///   companion norm buffers, bound at indices 2..=5.
/// * `output` - destination for the attention result.
/// * `tmp` - scratch for split-workgroup partials; size it with `tmp_buffer_bytes`.
///   Only bound when `nwg > 1`.
/// * `params` - shape/masking parameters; see `FlashAttnVecTqParams`.
///
/// Buffer sizes are not validated here.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` if `params` fails `validate_params`.
/// Otherwise propagates any error from `KernelRegistry::get_pipeline`.
///
/// NOTE(review): the reduce pipeline is fetched after the main dispatch has been
/// encoded. If that lookup fails (the reduce kernels are not registered by this
/// module's `register`), the encoder is left holding a main pass with no reduction.
#[allow(clippy::too_many_arguments)]
pub fn flash_attn_vec_tq(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    q: &MlxBuffer,
    k_packed: &MlxBuffer,
    k_norms: &MlxBuffer,
    v_packed: &MlxBuffer,
    v_norms: &MlxBuffer,
    output: &MlxBuffer,
    tmp: &MlxBuffer,
    params: &FlashAttnVecTqParams,
) -> Result<()> {
    validate_params(params)?;
    let head_dim = params.head_dim;
    let (kernel_name, reduce_kernel) = tq_kernel_names(head_dim)?;
    let nwg = compute_nwg(params.kv_seq_len);

    let gpu_params = FlashAttnVecTqParamsGpu {
        n_heads: params.num_heads,
        n_kv_heads: params.num_kv_heads,
        head_dim,
        kv_seq_len: params.kv_seq_len,
        kv_capacity: params.kv_capacity,
        scale: params.scale,
        mask_type: params.mask_type,
        sliding_window: params.sliding_window,
        softcap: params.softcap,
        nwg,
        ring_start: params.ring_start,
    };

    let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;

    // Threadgroup memory, counted in 2-byte halfs:
    //   one padded K-width region + 4 * 32 halfs of fixed scratch + two padded V-width regions.
    // K and V both use `head_dim`, so a single padded width serves every term.
    // NOTE(review): the role of each region is defined in flash_attn_vec_tq.metal —
    // keep this sum in sync with the shader.
    const FIXED_SCRATCH_HALFS: usize = 4 * 32;
    let padded_dim = pad2(head_dim as usize, 128);
    let shmem_halfs = padded_dim + FIXED_SCRATCH_HALFS + 2 * padded_dim;
    let shmem_bytes = shmem_halfs * 2;

    encoder.set_op_kind(CapturedOpKind::Sdpa);

    let threadgroups = MTLSize::new(1, u64::from(params.num_heads), u64::from(nwg));
    let threadgroup_size = MTLSize::new(32, 1, 1);
    // A single workgroup per head has nothing to merge, so it writes the final
    // result straight into `output` and the reduce pass below is skipped.
    let dst_buf = if nwg == 1 { output } else { tmp };
    encoder.encode_threadgroups_with_args_and_shared(
        pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(q)),
            (2, KernelArg::Buffer(k_packed)),
            (3, KernelArg::Buffer(k_norms)),
            (4, KernelArg::Buffer(v_packed)),
            (5, KernelArg::Buffer(v_norms)),
            (6, KernelArg::Buffer(dst_buf)),
        ],
        &[(0, shmem_bytes as u64)],
        threadgroups,
        threadgroup_size,
    );

    if nwg > 1 {
        // Order the partial writes into `tmp` before the reduce pass reads them.
        encoder.memory_barrier();
        let reduce_params = FlashAttnVecReduceParamsGpu { nrows: params.num_heads };
        let reduce_pipeline = registry.get_pipeline(reduce_kernel, device.metal_device())?;
        // One threadgroup per head with 32 threads per contributing workgroup;
        // compute_nwg caps nwg at 32, so this never exceeds 1024 threads.
        let reduce_tg = MTLSize::new(u64::from(params.num_heads), 1, 1);
        let reduce_tg_size = MTLSize::new(32 * u64::from(nwg), 1, 1);
        encoder.encode_threadgroups_with_args(
            reduce_pipeline,
            &[
                (0, KernelArg::Bytes(as_bytes(&reduce_params))),
                (1, KernelArg::Buffer(tmp)),
                (2, KernelArg::Buffer(output)),
                (3, KernelArg::Bytes(as_bytes(&nwg))),
            ],
            reduce_tg,
            reduce_tg_size,
        );
    }
    Ok(())
}
/// Bytes needed for the `tmp` scratch buffer passed to `flash_attn_vec_tq`.
///
/// Sized for the worst case of 32 workgroups per head, the largest value
/// `compute_nwg` can return. Each head/workgroup pair gets `head_dim + 2` `f32`
/// values (presumably the partial output row plus two softmax statistics — confirm
/// against the shader).
pub fn tmp_buffer_bytes(num_heads: u32, head_dim: u32) -> usize {
    const MAX_NWG: usize = 32;
    let floats_per_workgroup = head_dim as usize + 2;
    let total_floats = num_heads as usize * MAX_NWG * floats_per_workgroup;
    total_floats * std::mem::size_of::<f32>()
}
/// Round `x` up to the nearest multiple of `n`.
///
/// The previous bit-mask form `(x + n - 1) & !(n - 1)` was only correct when `n` is
/// a power of two (e.g. it gave `pad2(10, 12) == 20`). This remainder-based form is
/// correct for any non-zero `n`, and identical for powers of two such as the 128
/// used by `flash_attn_vec_tq`.
///
/// # Panics
/// Panics if `n == 0`, and on overflow in debug builds when the rounded value
/// exceeds `usize::MAX`.
fn pad2(x: usize, n: usize) -> usize {
    match x % n {
        0 => x,
        rem => x + (n - rem),
    }
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// A parameter set that passes validation; each test perturbs only the
    /// field(s) it exercises.
    fn valid_params() -> FlashAttnVecTqParams {
        FlashAttnVecTqParams {
            num_heads: 8,
            num_kv_heads: 4,
            head_dim: 256,
            kv_seq_len: 64,
            kv_capacity: 1024,
            scale: 1.0,
            mask_type: 1,
            sliding_window: 0,
            softcap: 0.0,
            ring_start: 0,
        }
    }

    #[test]
    fn test_validate_params_ok() {
        assert!(validate_params(&valid_params()).is_ok());
        let p = FlashAttnVecTqParams { head_dim: 512, ..valid_params() };
        assert!(validate_params(&p).is_ok());
        // kv_capacity == kv_seq_len is the tightest cache that must still be accepted.
        let p = FlashAttnVecTqParams { kv_capacity: 64, ..valid_params() };
        assert!(validate_params(&p).is_ok());
    }

    #[test]
    fn test_validate_params_bad_head_dim() {
        let p = FlashAttnVecTqParams { head_dim: 128, mask_type: 0, ..valid_params() };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_validate_params_bad_head_counts() {
        for (num_heads, num_kv_heads) in [(0, 4), (8, 0), (8, 3)] {
            let p = FlashAttnVecTqParams { num_heads, num_kv_heads, ..valid_params() };
            assert!(
                validate_params(&p).is_err(),
                "heads {num_heads}/{num_kv_heads} should be rejected"
            );
        }
    }

    #[test]
    fn test_validate_params_bad_kv_lengths() {
        let empty = FlashAttnVecTqParams { kv_seq_len: 0, ..valid_params() };
        assert!(validate_params(&empty).is_err());
        let overfull = FlashAttnVecTqParams { kv_capacity: 32, ..valid_params() };
        assert!(validate_params(&overfull).is_err());
    }

    #[test]
    fn test_gpu_params_layout() {
        // Both blocks are uploaded as raw bytes, so their sizes are pinned.
        assert_eq!(std::mem::size_of::<FlashAttnVecTqParamsGpu>(), 44);
        assert_eq!(std::mem::size_of::<FlashAttnVecReduceParamsGpu>(), 4);
    }

    #[test]
    fn test_tmp_buffer_bytes() {
        // 8 heads * 32 workgroups * (256 + 2) f32 values.
        assert_eq!(tmp_buffer_bytes(8, 256), 8 * 32 * 258 * 4);
        assert_eq!(tmp_buffer_bytes(0, 512), 0);
    }

    #[test]
    fn test_pad2_power_of_two() {
        assert_eq!(pad2(0, 128), 0);
        assert_eq!(pad2(1, 128), 128);
        assert_eq!(pad2(256, 128), 256);
        assert_eq!(pad2(257, 128), 384);
    }
}