use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::encoder::{as_bytes, CapturedOpKind, CommandEncoder, KernelArg};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal Shading Language source for the peer-port f16 flash-attention vector kernels,
/// embedded into the binary at compile time via `include_str!`.
///
/// NOTE(review): the nwg32 and reduce entry points dispatched by this module are presumably
/// defined in this same `.metal` file — confirm against the shader.
pub static FLASH_ATTN_VEC_PEER_PORT_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_vec_peer_port_f16.metal");
/// Register every kernel entry point this module dispatches with `registry`.
///
/// All names map to [`FLASH_ATTN_VEC_PEER_PORT_SHADER_SOURCE`]:
/// - `flash_attn_vec_peer_port_f16_dk256_dv256` is used by [`flash_attn_vec_peer_port_f16`];
/// - `flash_attn_vec_peer_port_f16_nwg32_dk256_dv256` and
///   `flash_attn_vec_peer_port_f16_reduce_dv256_nwg32` are used by
///   [`flash_attn_vec_peer_port_f16_nwg32`].
///
/// NOTE(review): assumes the nwg32 and reduce kernels live in the same `.metal` file as the
/// single-pass kernel (this module's compile tests look them up after calling only this
/// function) — confirm against the shader.
pub fn register(registry: &mut KernelRegistry) {
    // Each dispatcher resolves its pipelines by name via `get_pipeline`, so every name used
    // anywhere in this module must be registered here, not just the single-pass kernel.
    const KERNEL_NAMES: [&str; 3] = [
        "flash_attn_vec_peer_port_f16_dk256_dv256",
        "flash_attn_vec_peer_port_f16_nwg32_dk256_dv256",
        "flash_attn_vec_peer_port_f16_reduce_dv256_nwg32",
    ];
    for name in KERNEL_NAMES {
        registry.register_source(name, FLASH_ATTN_VEC_PEER_PORT_SHADER_SOURCE);
    }
}
/// Host-side parameters shared by [`flash_attn_vec_peer_port_f16`] and
/// [`flash_attn_vec_peer_port_f16_nwg32`].
///
/// The dispatchers validate a subset of these fields and copy all of them verbatim into the
/// GPU-side parameter block.
#[derive(Clone, Debug)]
pub struct FlashAttnVecPeerPortParams {
    /// Number of query heads; one threadgroup row is dispatched per head. Must be > 0.
    pub num_heads: u32,
    /// Number of K/V heads. Must be > 0 and divide `num_heads`.
    pub num_kv_heads: u32,
    /// Per-head dimension (used for both K and V). The current kernels require exactly 256.
    pub head_dim: u32,
    /// Number of valid KV positions. Must be > 0 and not exceed `kv_capacity`.
    pub kv_seq_len: u32,
    /// Allocated KV positions in the K/V buffers. Must be >= `kv_seq_len`.
    pub kv_capacity: u32,
    /// Forwarded to the kernel unchanged (presumably the attention-score scale — confirm).
    pub scale: f32,
    /// Forwarded unvalidated; its meaning is defined by the MSL kernel.
    pub mask_type: u32,
    /// Forwarded unvalidated; presumably a window length used by some `mask_type` values.
    pub sliding_window: u32,
    /// Forwarded unvalidated; presumably the start slot of a ring-buffered KV cache — confirm.
    pub ring_start: u32,
}
/// GPU-side copy of [`FlashAttnVecPeerPortParams`], uploaded as raw bytes at kernel argument
/// index 0 by both dispatch functions in this module.
///
/// `#[repr(C)]` with nine 4-byte scalars gives a fixed 36-byte layout with no padding.
/// NOTE(review): the field order presumably mirrors the params struct declared in the MSL
/// source — do not reorder fields without updating the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecPeerPortParamsGpu {
    // Each field is copied verbatim from the same-named field of FlashAttnVecPeerPortParams.
    num_heads: u32,
    num_kv_heads: u32,
    head_dim: u32,
    kv_seq_len: u32,
    kv_capacity: u32,
    scale: f32,
    mask_type: u32,
    sliding_window: u32,
    ring_start: u32,
}
/// Round `x` up to the next multiple of `n`.
///
/// Uses the `(x + n - 1) & !(n - 1)` mask trick, which is only valid when `n` is a power of
/// two.
///
/// # Panics
///
/// Panics if `n` is not a power of two (including `n == 0`). Without this check the mask trick
/// would silently return a wrong value, or underflow for `n == 0`.
fn pad2(x: usize, n: usize) -> usize {
    assert!(
        n.is_power_of_two(),
        "pad2: alignment {} is not a power of two",
        n
    );
    (x + n - 1) & !(n - 1)
}
/// Encode the single-pass f16 flash-attention vector kernel
/// (`flash_attn_vec_peer_port_f16_dk256_dv256`) onto `encoder`.
///
/// Dispatches a `1 x num_heads x 1` grid of 32-thread threadgroups. The packed parameters go
/// at kernel argument index 0, followed by `q`, `k_f16`, `v_f16` and `output` at indices 1-4.
/// The op kind is recorded as [`CapturedOpKind::Sdpa`].
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when:
/// - `head_dim != 256`;
/// - `k_f16` or `v_f16` is not `DType::F16`;
/// - either head count is zero, or `num_heads` is not a multiple of `num_kv_heads`;
/// - `kv_seq_len == 0`, or `kv_capacity < kv_seq_len`.
///
/// Errors from [`KernelRegistry::get_pipeline`] are propagated unchanged.
// Every buffer binding is passed individually to mirror the kernel's argument list.
#[allow(clippy::too_many_arguments)]
pub fn flash_attn_vec_peer_port_f16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    q: &MlxBuffer,
    k_f16: &MlxBuffer,
    v_f16: &MlxBuffer,
    output: &MlxBuffer,
    params: &FlashAttnVecPeerPortParams,
) -> Result<()> {
    const OP: &str = "flash_attn_vec_peer_port_f16";
    // Constant `C` of the threadgroup-memory formula below; keep in sync with the MSL kernel.
    const C: usize = 32;

    let head_dim = params.head_dim;
    if head_dim != 256 {
        return Err(MlxError::InvalidArgument(format!(
            "{}: head_dim must be 256, got {}",
            OP, head_dim
        )));
    }

    for (label, buffer) in [("k_f16", k_f16), ("v_f16", v_f16)] {
        let dtype = buffer.dtype();
        if dtype != crate::DType::F16 {
            return Err(MlxError::InvalidArgument(format!(
                "{}: {} must be DType::F16, got {:?}",
                OP, label, dtype
            )));
        }
    }

    let (num_heads, num_kv_heads) = (params.num_heads, params.num_kv_heads);
    if num_heads == 0 || num_kv_heads == 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{}: num_heads and num_kv_heads must be > 0",
            OP
        )));
    }
    if num_heads % num_kv_heads != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{}: num_heads ({}) % num_kv_heads ({}) != 0",
            OP, num_heads, num_kv_heads
        )));
    }

    let (kv_seq_len, kv_capacity) = (params.kv_seq_len, params.kv_capacity);
    if kv_seq_len == 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{}: kv_seq_len must be > 0",
            OP
        )));
    }
    if kv_capacity < kv_seq_len {
        return Err(MlxError::InvalidArgument(format!(
            "{}: kv_capacity ({}) < kv_seq_len ({})",
            OP, kv_capacity, kv_seq_len
        )));
    }

    // Threadgroup memory in halfs: pad128(DK) + 4*C + 2*pad128(DV), with DK == DV == head_dim.
    // Each half is 2 bytes.
    let dim = head_dim as usize;
    let scratch_halfs = pad2(dim, 128) + 4 * C + 2 * pad2(dim, 128);
    let scratch_bytes = (scratch_halfs * 2) as u64;

    let gpu_params = FlashAttnVecPeerPortParamsGpu {
        num_heads,
        num_kv_heads,
        head_dim,
        kv_seq_len,
        kv_capacity,
        scale: params.scale,
        mask_type: params.mask_type,
        sliding_window: params.sliding_window,
        ring_start: params.ring_start,
    };

    let pipeline = registry.get_pipeline(
        "flash_attn_vec_peer_port_f16_dk256_dv256",
        device.metal_device(),
    )?;
    encoder.set_op_kind(CapturedOpKind::Sdpa);
    encoder.encode_threadgroups_with_args_and_shared(
        pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(q)),
            (2, KernelArg::Buffer(k_f16)),
            (3, KernelArg::Buffer(v_f16)),
            (4, KernelArg::Buffer(output)),
        ],
        &[(0, scratch_bytes)],
        MTLSize::new(1, u64::from(num_heads), 1),
        MTLSize::new(32, 1, 1),
    );
    Ok(())
}
/// Host-side parameters for the nwg32 reduce kernel
/// (`flash_attn_vec_peer_port_f16_reduce_dv256_nwg32`).
///
/// NOTE(review): the dispatchers in this module build the GPU-side reduce parameters directly
/// from `FlashAttnVecPeerPortParams::num_heads` rather than from this type. It may be used
/// by callers elsewhere.
#[derive(Clone, Copy, Debug)]
pub struct FlashAttnVecPeerPortReduceParams {
    /// Number of rows the reduce kernel combines; the nwg32 dispatcher uses `num_heads`.
    pub nrows: i32,
}
/// GPU-side parameter block for `flash_attn_vec_peer_port_f16_reduce_dv256_nwg32`, uploaded as
/// raw bytes at kernel argument index 0.
///
/// NOTE(review): layout presumably mirrors a struct in the MSL source — keep `#[repr(C)]` and
/// the field type in sync with the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecPeerPortReduceParamsGpu {
    // Set to num_heads by the nwg32 dispatcher, which also launches one reduce threadgroup per
    // head.
    nrows: i32,
}
/// Minimum size in bytes of the `tmp` scratch buffer for
/// [`flash_attn_vec_peer_port_f16_nwg32`].
///
/// Each of the `num_heads` rows gets 32 workgroup slots of `head_dim + 2` 4-byte values.
/// Presumably these are the f32 partial outputs plus two per-workgroup scalars — confirm
/// against the reduce kernel. The product is formed in `u64` before converting to `usize`.
pub fn flash_attn_vec_peer_port_f16_nwg32_tmp_bytes(num_heads: u32, head_dim: u32) -> usize {
    const NWG: u64 = 32;
    const BYTES_PER_VALUE: u64 = 4;
    let values_per_slot = u64::from(head_dim) + 2;
    let total_values = u64::from(num_heads) * NWG * values_per_slot;
    (total_values * BYTES_PER_VALUE) as usize
}
/// Encode the two-pass ("nwg32") f16 flash-attention vector path onto `encoder`.
///
/// Pass 1 runs `flash_attn_vec_peer_port_f16_nwg32_dk256_dv256`:
/// - a `1 x num_heads x 32` grid of 32-thread threadgroups;
/// - arguments are the packed parameters, `q`, `k_f16`, `v_f16` and `tmp`.
///
/// After a memory barrier, pass 2 runs `flash_attn_vec_peer_port_f16_reduce_dv256_nwg32`:
/// - one 1024-thread threadgroup per head;
/// - it reads `tmp` and writes `output`.
///
/// NOTE(review): the 32 pass-1 workgroups per head presumably each cover a slice of the KV
/// sequence — confirm against the MSL source.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when:
/// - `head_dim != 256`;
/// - `k_f16` or `v_f16` is not `DType::F16`;
/// - either head count is zero, or `num_heads` is not a multiple of `num_kv_heads`;
/// - `kv_seq_len == 0`, or `kv_capacity < kv_seq_len`;
/// - `tmp` is smaller than [`flash_attn_vec_peer_port_f16_nwg32_tmp_bytes`].
///
/// Errors from [`KernelRegistry::get_pipeline`] are propagated unchanged. If the reduce
/// pipeline lookup fails, pass 1 and the barrier have already been encoded.
// Every buffer binding is passed individually to mirror the kernels' argument lists.
#[allow(clippy::too_many_arguments)]
pub fn flash_attn_vec_peer_port_f16_nwg32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    q: &MlxBuffer,
    k_f16: &MlxBuffer,
    v_f16: &MlxBuffer,
    tmp: &MlxBuffer,
    output: &MlxBuffer,
    params: &FlashAttnVecPeerPortParams,
) -> Result<()> {
    const OP: &str = "flash_attn_vec_peer_port_f16_nwg32";
    // Pass-1 workgroups per head; matches the NWG factor in the tmp-size helper.
    const NWG: u64 = 32;
    // Constant `C` of the threadgroup-memory formula below; keep in sync with the MSL kernel.
    const C: usize = 32;

    let head_dim = params.head_dim;
    if head_dim != 256 {
        return Err(MlxError::InvalidArgument(format!(
            "{}: head_dim must be 256, got {}",
            OP, head_dim
        )));
    }

    for (label, buffer) in [("k_f16", k_f16), ("v_f16", v_f16)] {
        let dtype = buffer.dtype();
        if dtype != crate::DType::F16 {
            return Err(MlxError::InvalidArgument(format!(
                "{}: {} must be DType::F16, got {:?}",
                OP, label, dtype
            )));
        }
    }

    let (num_heads, num_kv_heads) = (params.num_heads, params.num_kv_heads);
    if num_heads == 0 || num_kv_heads == 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{}: num_heads/num_kv_heads must be > 0",
            OP
        )));
    }
    if num_heads % num_kv_heads != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{}: num_heads ({}) % num_kv_heads ({}) != 0",
            OP, num_heads, num_kv_heads
        )));
    }

    let (kv_seq_len, kv_capacity) = (params.kv_seq_len, params.kv_capacity);
    if kv_seq_len == 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{}: kv_seq_len must be > 0",
            OP
        )));
    }
    if kv_capacity < kv_seq_len {
        return Err(MlxError::InvalidArgument(format!(
            "{}: kv_capacity ({}) < kv_seq_len ({})",
            OP, kv_capacity, kv_seq_len
        )));
    }

    let required_tmp = flash_attn_vec_peer_port_f16_nwg32_tmp_bytes(num_heads, head_dim);
    let tmp_len = tmp.byte_len();
    if tmp_len < required_tmp {
        return Err(MlxError::InvalidArgument(format!(
            "{}: tmp buffer too small ({} < {} bytes)",
            OP, tmp_len, required_tmp
        )));
    }

    // Threadgroup memory in halfs: pad128(DK) + 4*C + 2*pad128(DV), with DK == DV == head_dim.
    // Each half is 2 bytes.
    let dim = head_dim as usize;
    let scratch_halfs = pad2(dim, 128) + 4 * C + 2 * pad2(dim, 128);
    let scratch_bytes = (scratch_halfs * 2) as u64;

    let gpu_params = FlashAttnVecPeerPortParamsGpu {
        num_heads,
        num_kv_heads,
        head_dim,
        kv_seq_len,
        kv_capacity,
        scale: params.scale,
        mask_type: params.mask_type,
        sliding_window: params.sliding_window,
        ring_start: params.ring_start,
    };

    // Pass 1: per-workgroup partial results into `tmp`.
    let vec_pipeline = registry.get_pipeline(
        "flash_attn_vec_peer_port_f16_nwg32_dk256_dv256",
        device.metal_device(),
    )?;
    encoder.set_op_kind(CapturedOpKind::Sdpa);
    encoder.encode_threadgroups_with_args_and_shared(
        vec_pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(q)),
            (2, KernelArg::Buffer(k_f16)),
            (3, KernelArg::Buffer(v_f16)),
            (4, KernelArg::Buffer(tmp)),
        ],
        &[(0, scratch_bytes)],
        MTLSize::new(1, u64::from(num_heads), NWG),
        MTLSize::new(32, 1, 1),
    );

    // Pass 2 reads what pass 1 wrote to `tmp`, so order the two dispatches.
    encoder.memory_barrier();

    let reduce_pipeline = registry.get_pipeline(
        "flash_attn_vec_peer_port_f16_reduce_dv256_nwg32",
        device.metal_device(),
    )?;
    let reduce_params = FlashAttnVecPeerPortReduceParamsGpu {
        nrows: num_heads as i32,
    };
    encoder.encode_threadgroups_with_args(
        reduce_pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&reduce_params))),
            (1, KernelArg::Buffer(tmp)),
            (2, KernelArg::Buffer(output)),
        ],
        MTLSize::new(u64::from(num_heads), 1, 1),
        MTLSize::new(32 * 32, 1, 1),
    );
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Run [`register`] on a fresh registry and require that `kernel` compiles.
    ///
    /// Returns early (the test passes) when no Metal device can be created. The panic text
    /// matches what `Result::expect` would print for the same message.
    fn assert_pipeline_compiles(kernel: &'static str) {
        let device = match crate::device::MlxDevice::new() {
            Ok(device) => device,
            Err(_) => return,
        };
        let mut registry = KernelRegistry::new();
        register(&mut registry);
        if let Err(err) = registry.get_pipeline(kernel, device.metal_device()) {
            panic!(
                "Metal compiler rejected {} — check MSL source: {:?}",
                kernel, err
            );
        }
    }

    #[test]
    fn test_gpu_params_size() {
        // Nine 4-byte scalars under #[repr(C)], no padding.
        assert_eq!(std::mem::size_of::<FlashAttnVecPeerPortParamsGpu>(), 9 * 4);
    }

    #[test]
    fn test_shmem_formula() {
        const C: usize = 32;
        let padded = pad2(256, 128);
        let halfs = padded + 4 * C + 2 * padded;
        assert_eq!(halfs * 2, 1792);
    }

    #[test]
    fn pipeline_registers_and_compiles() {
        assert_pipeline_compiles("flash_attn_vec_peer_port_f16_dk256_dv256");
    }

    #[test]
    fn reduce_pipeline_registers_and_compiles() {
        assert_pipeline_compiles("flash_attn_vec_peer_port_f16_reduce_dv256_nwg32");
    }

    #[test]
    fn nwg32_pipeline_registers_and_compiles() {
        assert_pipeline_compiles("flash_attn_vec_peer_port_f16_nwg32_dk256_dv256");
    }
}