use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
/// Metal source for the `qkv_split_f32` kernel, embedded at compile time from
/// `shaders/qkv_split.metal` (the path is resolved relative to this file).
pub static QKV_SPLIT_SHADER_SOURCE: &str = include_str!("../shaders/qkv_split.metal");

/// Register [`QKV_SPLIT_SHADER_SOURCE`] with `registry` under the kernel name
/// `qkv_split_f32`, the same name `dispatch_qkv_split_f32` later looks up.
pub fn register(registry: &mut KernelRegistry) {
    const KERNEL_NAME: &str = "qkv_split_f32";
    registry.register_source(KERNEL_NAME, QKV_SPLIT_SHADER_SOURCE);
}
/// Parameter block handed to the `qkv_split_f32` kernel as raw bytes
/// (bound at buffer index 4 by `dispatch_qkv_split_f32`).
///
/// `#[repr(C)]` + `bytemuck::Pod` let it be passed through `as_bytes` as a
/// plain byte slice. Field order and types form a layout contract with the
/// kernel's parameter struct. NOTE(review): keep in sync with
/// `shaders/qkv_split.metal` — confirm against the shader when editing.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuQkvSplitParams {
    /// Number of rows (copied from `QkvSplitParams::seq`).
    seq: u32,
    /// Elements per row of the `q` output.
    q_sp: u32,
    /// Elements per row of the `k` output.
    k_sp: u32,
    /// Elements per row of the `v` output.
    v_sp: u32,
    /// Elements per row of the fused `qkv` input: `q_sp + k_sp + v_sp`,
    /// computed (overflow-checked) on the host.
    qkv_ch: u32,
}
/// Shape parameters for [`dispatch_qkv_split_f32`].
///
/// `seq` is the number of rows. `q_sp`, `k_sp` and `v_sp` are the per-row
/// widths, in `f32` elements, of the Q, K and V outputs. The fused input row
/// is `q_sp + k_sp + v_sp` elements wide. The dispatcher rejects any of the
/// four being zero, which is why this type intentionally has no `Default`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct QkvSplitParams {
    /// Number of rows to split.
    pub seq: u32,
    /// Elements per row written to the `q` buffer.
    pub q_sp: u32,
    /// Elements per row written to the `k` buffer.
    pub k_sp: u32,
    /// Elements per row written to the `v` buffer.
    pub v_sp: u32,
}
/// Validate `params` and the four buffers, then encode one `qkv_split_f32`
/// kernel dispatch onto `encoder`.
///
/// Bindings:
/// - index 0: the fused `qkv` input;
/// - indices 1–3: the `q`, `k` and `v` outputs;
/// - index 4: the packed [`GpuQkvSplitParams`].
///
/// Each buffer must hold at least `params.seq` rows of `f32` values. The row
/// width is `q_sp + k_sp + v_sp` for `qkv`, and `q_sp`, `k_sp` and `v_sp` for
/// `q`, `k` and `v` respectively.
/// NOTE(review): the element layout (presumably row-major `[seq, width]`)
/// is defined by `shaders/qkv_split.metal` — confirm there.
///
/// `device` is only passed to `registry.get_pipeline` to obtain the pipeline.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] in any of these cases:
/// - any of `seq`, `q_sp`, `k_sp` or `v_sp` is zero;
/// - `q_sp + k_sp + v_sp` overflows `u32`;
/// - a required byte size overflows `usize`;
/// - a buffer is smaller than required.
///
/// Errors from `registry.get_pipeline` are propagated unchanged.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_qkv_split_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    qkv: &MlxBuffer,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    params: &QkvSplitParams,
) -> Result<()> {
    if params.seq == 0 || params.q_sp == 0 || params.k_sp == 0 || params.v_sp == 0 {
        return Err(MlxError::InvalidArgument(
            "qkv_split_f32: seq, q_sp, k_sp, v_sp must all be > 0".into(),
        ));
    }

    // Width of one fused input row; also forwarded to the kernel as `qkv_ch`.
    let qkv_ch = params
        .q_sp
        .checked_add(params.k_sp)
        .and_then(|qk| qk.checked_add(params.v_sp))
        .ok_or_else(|| {
            MlxError::InvalidArgument(
                "qkv_split_f32: q_sp + k_sp + v_sp overflows u32".into(),
            )
        })?;

    // Capacity checks use overflow-checked arithmetic. A wrapped size would
    // otherwise let an undersized buffer through and the GPU would access
    // memory out of bounds.
    ensure_f32_capacity(qkv, "qkv", params.seq, qkv_ch)?;
    ensure_f32_capacity(q, "q", params.seq, params.q_sp)?;
    ensure_f32_capacity(k, "k", params.seq, params.k_sp)?;
    ensure_f32_capacity(v, "v", params.seq, params.v_sp)?;

    let pipeline = registry.get_pipeline("qkv_split_f32", device)?;
    let gpu_params = GpuQkvSplitParams {
        seq: params.seq,
        q_sp: params.q_sp,
        k_sp: params.k_sp,
        v_sp: params.v_sp,
        qkv_ch,
    };

    // Grid spans the fused buffer: x = element within a row, y = row.
    // NOTE(review): whether these are thread or threadgroup counts depends on
    // `encode_with_args` — presumably threads.
    let grid = MTLSize::new(qkv_ch as u64, params.seq as u64, 1);
    // At most 256 threads per threadgroup along x; fewer when rows are narrower.
    let tg_x = std::cmp::min(256u64, qkv_ch as u64);
    let tg = MTLSize::new(tg_x, 1, 1);
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(qkv)),
            (1, KernelArg::Buffer(q)),
            (2, KernelArg::Buffer(k)),
            (3, KernelArg::Buffer(v)),
            (4, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        grid,
        tg,
    );
    Ok(())
}

/// Byte length of `rows * cols` `f32` values, or `None` if it overflows `usize`.
fn f32_matrix_bytes(rows: u32, cols: u32) -> Option<usize> {
    (rows as usize)
        .checked_mul(cols as usize)?
        .checked_mul(std::mem::size_of::<f32>())
}

/// Return `InvalidArgument` unless `buf` holds at least `rows * cols` `f32`s;
/// `name` identifies the buffer in the error message.
fn ensure_f32_capacity(buf: &MlxBuffer, name: &str, rows: u32, cols: u32) -> Result<()> {
    let need = f32_matrix_bytes(rows, cols).ok_or_else(|| {
        MlxError::InvalidArgument(format!(
            "qkv_split_f32: {} buffer size overflows usize",
            name
        ))
    })?;
    if buf.byte_len() < need {
        return Err(MlxError::InvalidArgument(format!(
            "qkv_split_f32: {} buffer too small: need {} bytes, have {}",
            name,
            need,
            buf.byte_len()
        )));
    }
    Ok(())
}