use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::encoder::{as_bytes, CapturedOpKind, CommandEncoder, KernelArg};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
pub use super::flash_attn_vec_tq_hb::{
compute_nsg, tmp_buffer_bytes, FlashAttnVecTqHbParams,
};
/// Metal source for the hybrid flash-attention kernels, embedded at compile time from
/// `../shaders/flash_attn_vec_hybrid.metal`.
///
/// `register` adds this same source under both the `flash_attn_vec_hybrid_dk256` and
/// `flash_attn_vec_hybrid_dk512` kernel names.
pub static FLASH_ATTN_VEC_HYBRID_SHADER_SOURCE: &str =
include_str!("../shaders/flash_attn_vec_hybrid.metal");
/// Registers the hybrid shader source under each head-dim-specialized kernel name.
///
/// Both names map to the same Metal source; the shader is expected to provide one entry
/// point per name.
///
/// NOTE(review): the `flash_attn_vec_reduce_dk256` / `flash_attn_vec_reduce_dk512`
/// kernels that `flash_attn_vec_hybrid` fetches when `nwg > 1` are not registered here.
/// Presumably another module registers them — confirm.
pub fn register(registry: &mut KernelRegistry) {
    const KERNEL_NAMES: [&str; 2] = ["flash_attn_vec_hybrid_dk256", "flash_attn_vec_hybrid_dk512"];
    for name in KERNEL_NAMES {
        registry.register_source(name, FLASH_ATTN_VEC_HYBRID_SHADER_SOURCE);
    }
}
/// Parameter block passed by value (as raw bytes) in argument slot 0 of the hybrid
/// attention kernel.
///
/// It is `#[repr(C)]` and made of fifteen 4-byte fields, so it is 60 bytes with no
/// padding; `test_gpu_params_size` pins that size. `Pod` lets it be reinterpreted as
/// bytes for `KernelArg::Bytes`.
///
/// NOTE(review): field order and types presumably must mirror the params struct declared
/// in `flash_attn_vec_hybrid.metal`. Confirm there before reordering or adding fields.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecHybridParamsGpu {
// Query head count; also the y-extent of the threadgroup grid at dispatch.
n_heads: u32,
// KV head count; validation requires n_heads to be a multiple of it.
n_kv_heads: u32,
// 256 or 512 (enforced by validate_params).
head_dim: u32,
// KV sequence length; validated to be > 0 and <= kv_capacity.
kv_seq_len: u32,
// KV buffer capacity, copied from params.kv_capacity.
kv_capacity: u32,
// Copied from params.scale.
scale: f32,
// Copied from params.mask_type; meaning is defined by the shader.
mask_type: u32,
// Copied from params.sliding_window; meaning is defined by the shader.
sliding_window: u32,
// Copied from params.softcap; meaning is defined by the shader.
softcap: f32,
// Number of workgroups along the grid's z-axis, chosen by compute_nwg.
// When > 1, partials go to `tmp` and a reduce pass follows.
nwg: u32,
// Copied from params.ring_start — presumably a ring-buffer start offset; confirm.
ring_start: u32,
// Copied from params.scale_factor_d512.
scale_factor_d512: f32,
// V codebook bit width: 5, 6, or 8. Also passed as function constant 50.
codebook_bits: u32,
// Copied from params.fuse_fwht_pre — presumably a 0/1 flag; confirm against shader.
fuse_fwht_pre: u32,
// Threadgroup y-extent: each threadgroup is 32 x nsg threads. Power of two, <= 4.
nsg: u32,
}
/// Parameter block passed by value in argument slot 0 of the
/// `flash_attn_vec_reduce_dk*` kernel.
///
/// That kernel only runs when the attention pass was split across more than one
/// workgroup (`nwg > 1`).
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecReduceParamsGpu {
// Set to the query head count. The reduce dispatch launches one threadgroup per row.
nrows: u32,
}
/// Checks that `params` describes a configuration the hybrid kernel supports.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` describing the first violated constraint, checked
/// in this order:
/// - `head_dim` must be 256 or 512;
/// - `num_heads` and `num_kv_heads` must be non-zero, with `num_heads` a multiple of
///   `num_kv_heads`;
/// - `kv_seq_len` must be non-zero and no larger than `kv_capacity`;
/// - the V `codebook_bits` must be 5, 6, or 8;
/// - `nsg` must be a power of two no larger than 4.
fn validate_params(params: &FlashAttnVecTqHbParams) -> Result<()> {
    let reject = |msg: String| -> Result<()> { Err(MlxError::InvalidArgument(msg)) };

    let head_dim = params.head_dim;
    if !matches!(head_dim, 256 | 512) {
        return reject(format!(
            "flash_attn_vec_hybrid: head_dim must be 256 or 512, got {}",
            head_dim
        ));
    }

    let (heads, kv_heads) = (params.num_heads, params.num_kv_heads);
    if heads == 0 || kv_heads == 0 {
        return reject(
            "flash_attn_vec_hybrid: num_heads and num_kv_heads must be > 0".to_string(),
        );
    }
    if heads % kv_heads != 0 {
        return reject(format!(
            "flash_attn_vec_hybrid: num_heads ({}) % num_kv_heads ({}) != 0",
            heads, kv_heads
        ));
    }

    let (seq_len, capacity) = (params.kv_seq_len, params.kv_capacity);
    if seq_len == 0 {
        return reject("flash_attn_vec_hybrid: kv_seq_len must be > 0".to_string());
    }
    if capacity < seq_len {
        return reject(format!(
            "flash_attn_vec_hybrid: kv_capacity ({}) < kv_seq_len ({})",
            capacity, seq_len
        ));
    }

    let bits = params.codebook_bits;
    if !matches!(bits, 5 | 6 | 8) {
        return reject(format!(
            "flash_attn_vec_hybrid: V codebook_bits must be 5, 6, or 8, got {}",
            bits
        ));
    }

    // `is_power_of_two` is false for 0, matching the "nsg == 0 || nsg & (nsg - 1) != 0" test.
    let nsg = params.nsg;
    if !nsg.is_power_of_two() {
        return reject(format!(
            "flash_attn_vec_hybrid: nsg must be a power of 2 (1, 2, 4, ...), got {}",
            nsg
        ));
    }
    if nsg > 4 {
        return reject(format!(
            "flash_attn_vec_hybrid: nsg must be ≤ 4 (kernel reduce cap), got {}",
            nsg
        ));
    }

    Ok(())
}
/// Chooses how many workgroups split the KV sequence for one attention dispatch.
///
/// If the `HF2Q_HYBRID_NWG` environment variable parses as a `u32` in `1..=32`, that
/// value wins. Otherwise the default is 32 when `kv_seq_len > 512` and 16 when it is not.
/// The variable is re-read on every call.
fn compute_nwg(kv_seq_len: u32) -> u32 {
    let default_nwg = if kv_seq_len > 512 { 32 } else { 16 };
    std::env::var("HF2Q_HYBRID_NWG")
        .ok()
        .and_then(|raw| raw.parse::<u32>().ok())
        .filter(|n| (1..=32).contains(n))
        .unwrap_or(default_nwg)
}
#[allow(clippy::too_many_arguments)]
pub fn flash_attn_vec_hybrid(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
q: &MlxBuffer,
k_f16: &MlxBuffer,
v_packed: &MlxBuffer,
v_norms: &MlxBuffer,
output: &MlxBuffer,
tmp: &MlxBuffer,
params: &FlashAttnVecTqHbParams,
) -> Result<()> {
validate_params(params)?;
if k_f16.dtype() != crate::DType::F16 {
return Err(MlxError::InvalidArgument(format!(
"flash_attn_vec_hybrid: k_f16 must be DType::F16, got {:?}",
k_f16.dtype()
)));
}
let head_dim = params.head_dim;
let nwg = compute_nwg(params.kv_seq_len);
let gpu_params = FlashAttnVecHybridParamsGpu {
n_heads: params.num_heads,
n_kv_heads: params.num_kv_heads,
head_dim: params.head_dim,
kv_seq_len: params.kv_seq_len,
kv_capacity: params.kv_capacity,
scale: params.scale,
mask_type: params.mask_type,
sliding_window: params.sliding_window,
softcap: params.softcap,
nwg,
ring_start: params.ring_start,
scale_factor_d512: params.scale_factor_d512,
codebook_bits: params.codebook_bits,
fuse_fwht_pre: params.fuse_fwht_pre,
nsg: params.nsg,
};
let kernel_name = match head_dim {
256 => "flash_attn_vec_hybrid_dk256",
512 => "flash_attn_vec_hybrid_dk512",
_ => return Err(MlxError::InvalidArgument(format!(
"flash_attn_vec_hybrid: unsupported head_dim {head_dim}"
))),
};
let cbits_const = (params.codebook_bits as i32, 50usize);
let v_is_f16: i32 = match v_packed.dtype() {
crate::DType::F16 => 1,
_ => 0,
};
let pipeline = registry
.get_pipeline_with_constants(
kernel_name,
device.metal_device(),
&[],
&[(cbits_const.1, cbits_const.0), (51usize, v_is_f16)],
)?;
let pk = pad2(head_dim as usize, 128);
let pv = pad2(head_dim as usize, 128);
let sh = 4 * 32;
let nsg = params.nsg as usize;
let shmem_halfs = pk + nsg * (sh + 2 * pv);
let shmem_bytes = shmem_halfs * 2;
encoder.set_op_kind(CapturedOpKind::Sdpa);
let threadgroups = MTLSize::new(1, params.num_heads as u64, nwg as u64);
let threadgroup_size = MTLSize::new(32, params.nsg as u64, 1);
let dst_buf = if nwg == 1 { output } else { tmp };
encoder.encode_threadgroups_with_args_and_shared(
pipeline,
&[
(0, KernelArg::Bytes(as_bytes(&gpu_params))),
(1, KernelArg::Buffer(q)),
(2, KernelArg::Buffer(k_f16)), (3, KernelArg::Buffer(v_packed)),
(4, KernelArg::Buffer(v_norms)),
(5, KernelArg::Buffer(dst_buf)),
],
&[(0, shmem_bytes as u64)],
threadgroups,
threadgroup_size,
);
if nwg > 1 {
encoder.memory_barrier();
let reduce_params = FlashAttnVecReduceParamsGpu { nrows: params.num_heads };
let reduce_kernel = match head_dim {
256 => "flash_attn_vec_reduce_dk256",
512 => "flash_attn_vec_reduce_dk512",
_ => unreachable!(),
};
let reduce_pipeline = registry.get_pipeline(reduce_kernel, device.metal_device())?;
let reduce_tg = MTLSize::new(params.num_heads as u64, 1, 1);
let reduce_tg_size = MTLSize::new(32 * nwg as u64, 1, 1);
encoder.encode_threadgroups_with_args(
reduce_pipeline,
&[
(0, KernelArg::Bytes(as_bytes(&reduce_params))),
(1, KernelArg::Buffer(tmp)),
(2, KernelArg::Buffer(output)),
(3, KernelArg::Bytes(as_bytes(&nwg))),
],
reduce_tg,
reduce_tg_size,
);
}
Ok(())
}
/// Rounds `x` up to the next multiple of `n` using bit masking.
///
/// `n` must be a non-zero power of two. Any other `n` gives a meaningless result, and
/// `n == 0` underflows.
fn pad2(x: usize, n: usize) -> usize {
    let bumped = x + n - 1;
    // Clearing the low bits is the same as `bumped & !(n - 1)`.
    let low_bits = bumped & (n - 1);
    bumped - low_bits
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline parameters that pass `validate_params`; each test overrides one field.
    fn valid_params() -> FlashAttnVecTqHbParams {
        FlashAttnVecTqHbParams {
            num_heads: 8,
            num_kv_heads: 4,
            head_dim: 256,
            kv_seq_len: 64,
            kv_capacity: 1024,
            scale: 1.0,
            mask_type: 0,
            sliding_window: 0,
            softcap: 0.0,
            ring_start: 0,
            scale_factor_d512: 1.0,
            codebook_bits: 8,
            fuse_fwht_pre: 0,
            nsg: 1,
        }
    }

    #[test]
    fn test_gpu_params_size() {
        // Fifteen 4-byte fields under repr(C): no padding expected.
        assert_eq!(std::mem::size_of::<FlashAttnVecHybridParamsGpu>(), 60);
    }

    #[test]
    fn test_validate_bad_bits() {
        let p = FlashAttnVecTqHbParams {
            codebook_bits: 4,
            ..valid_params()
        };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_validate_ok_8bit() {
        let p = valid_params();
        assert!(validate_params(&p).is_ok());
    }

    #[test]
    fn test_validate_bad_head_dim() {
        let p = FlashAttnVecTqHbParams {
            head_dim: 128,
            ..valid_params()
        };
        assert!(validate_params(&p).is_err());
    }
}