mlx_native/ops/
qkv_split.rs1use metal::MTLSize;
22
23use crate::buffer::MlxBuffer;
24use crate::encoder::CommandEncoder;
25use crate::error::{MlxError, Result};
26use crate::kernel_registry::KernelRegistry;
27
28use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
29
30pub static QKV_SPLIT_SHADER_SOURCE: &str = include_str!("../shaders/qkv_split.metal");
32
33pub fn register(registry: &mut KernelRegistry) {
39 registry.register_source("qkv_split_f32", QKV_SPLIT_SHADER_SOURCE);
40}
41
42#[repr(C)]
46#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
47struct GpuQkvSplitParams {
48 seq: u32,
49 q_sp: u32,
50 k_sp: u32,
51 v_sp: u32,
52 qkv_ch: u32,
53}
54
/// Host-side shape description for `dispatch_qkv_split_f32`.
///
/// Every field counts `f32` elements and must be non-zero. The fused input row
/// width is `q_sp + k_sp + v_sp`, and that sum must fit in a `u32`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct QkvSplitParams {
    /// Number of rows in every buffer (NOTE(review): presumably sequence
    /// positions — confirm with callers).
    pub seq: u32,
    /// Number of `f32` values per row written to the `q` output.
    pub q_sp: u32,
    /// Number of `f32` values per row written to the `k` output.
    pub k_sp: u32,
    /// Number of `f32` values per row written to the `v` output.
    pub v_sp: u32,
}

68#[allow(clippy::too_many_arguments)]
90pub fn dispatch_qkv_split_f32(
91 encoder: &mut CommandEncoder,
92 registry: &mut KernelRegistry,
93 device: &metal::DeviceRef,
94 qkv: &MlxBuffer,
95 q: &MlxBuffer,
96 k: &MlxBuffer,
97 v: &MlxBuffer,
98 params: &QkvSplitParams,
99) -> Result<()> {
100 if params.seq == 0 || params.q_sp == 0 || params.k_sp == 0 || params.v_sp == 0 {
101 return Err(MlxError::InvalidArgument(
102 "qkv_split_f32: seq, q_sp, k_sp, v_sp must all be > 0".into(),
103 ));
104 }
105
106 let qkv_ch = params
107 .q_sp
108 .checked_add(params.k_sp)
109 .and_then(|qk| qk.checked_add(params.v_sp))
110 .ok_or_else(|| {
111 MlxError::InvalidArgument(
112 "qkv_split_f32: q_sp + k_sp + v_sp overflows u32".into(),
113 )
114 })?;
115
116 let in_bytes = (params.seq as usize) * (qkv_ch as usize) * 4;
118 if qkv.byte_len() < in_bytes {
119 return Err(MlxError::InvalidArgument(format!(
120 "qkv_split_f32: qkv buffer too small: need {} bytes, have {}",
121 in_bytes,
122 qkv.byte_len()
123 )));
124 }
125 let q_bytes = (params.seq as usize) * (params.q_sp as usize) * 4;
126 if q.byte_len() < q_bytes {
127 return Err(MlxError::InvalidArgument(format!(
128 "qkv_split_f32: q buffer too small: need {} bytes, have {}",
129 q_bytes,
130 q.byte_len()
131 )));
132 }
133 let k_bytes = (params.seq as usize) * (params.k_sp as usize) * 4;
134 if k.byte_len() < k_bytes {
135 return Err(MlxError::InvalidArgument(format!(
136 "qkv_split_f32: k buffer too small: need {} bytes, have {}",
137 k_bytes,
138 k.byte_len()
139 )));
140 }
141 let v_bytes = (params.seq as usize) * (params.v_sp as usize) * 4;
142 if v.byte_len() < v_bytes {
143 return Err(MlxError::InvalidArgument(format!(
144 "qkv_split_f32: v buffer too small: need {} bytes, have {}",
145 v_bytes,
146 v.byte_len()
147 )));
148 }
149
150 let pipeline = registry.get_pipeline("qkv_split_f32", device)?;
151
152 let gpu_params = GpuQkvSplitParams {
153 seq: params.seq,
154 q_sp: params.q_sp,
155 k_sp: params.k_sp,
156 v_sp: params.v_sp,
157 qkv_ch,
158 };
159
160 let grid = MTLSize::new(qkv_ch as u64, params.seq as u64, 1);
161 let tg_x = std::cmp::min(256u64, qkv_ch as u64);
162 let tg = MTLSize::new(tg_x, 1, 1);
163
164 encode_with_args(
165 encoder,
166 pipeline,
167 &[
168 (0, KernelArg::Buffer(qkv)),
169 (1, KernelArg::Buffer(q)),
170 (2, KernelArg::Buffer(k)),
171 (3, KernelArg::Buffer(v)),
172 (4, KernelArg::Bytes(as_bytes(&gpu_params))),
173 ],
174 grid,
175 tg,
176 );
177
178 Ok(())
179}