1use metal::MTLSize;
12
13use crate::buffer::MlxBuffer;
14use crate::encoder::CommandEncoder;
15use crate::error::{MlxError, Result};
16use crate::kernel_registry::KernelRegistry;
17
18use super::encode_helpers::{encode_with_args, KernelArg};
19
/// Metal shading-language source for the KV-cache copy kernels, embedded
/// into the binary at compile time from `../shaders/kv_cache_copy.metal`.
pub static KV_CACHE_COPY_SHADER_SOURCE: &str = include_str!("../shaders/kv_cache_copy.metal");
22
/// Registers the KV-cache copy shader source with `registry` under the
/// library name `"kv_cache_copy"` so its pipelines can be built on demand.
pub fn register(registry: &mut KernelRegistry) {
    registry.register_source("kv_cache_copy", KV_CACHE_COPY_SHADER_SOURCE);
}
27
28#[allow(clippy::too_many_arguments)]
49pub fn dispatch_kv_cache_copy(
50 encoder: &mut CommandEncoder,
51 registry: &mut KernelRegistry,
52 device: &metal::DeviceRef,
53 src: &MlxBuffer,
54 cache: &MlxBuffer,
55 write_pos: u32,
56 row_size: u32,
57 n_new: u32,
58 cache_cap: u32,
59 is_sliding: bool,
60) -> Result<()> {
61 if n_new == 0 || row_size == 0 {
62 return Ok(()); }
64
65 let total_elements = (n_new as u64) * (row_size as u64);
66 let src_elements = src.element_count() as u64;
67 if src_elements < total_elements {
68 return Err(MlxError::InvalidArgument(format!(
69 "kv_cache_copy: src has {} elements but need {} (n_new={} * row_size={})",
70 src_elements, total_elements, n_new, row_size
71 )));
72 }
73
74 if !is_sliding && (write_pos as u64 + n_new as u64) > cache_cap as u64 {
76 return Err(MlxError::InvalidArgument(format!(
77 "kv_cache_copy: global cache overflow: write_pos({}) + n_new({}) > cache_cap({})",
78 write_pos, n_new, cache_cap
79 )));
80 }
81
82 let pipeline = registry.get_pipeline("kv_cache_copy", device)?;
83
84 let is_sliding_val: u32 = if is_sliding { 1 } else { 0 };
85
86 let write_pos_bytes = write_pos.to_ne_bytes();
88 let row_size_bytes = row_size.to_ne_bytes();
89 let n_new_bytes = n_new.to_ne_bytes();
90 let cache_cap_bytes = cache_cap.to_ne_bytes();
91 let is_sliding_bytes = is_sliding_val.to_ne_bytes();
92
93 encode_with_args(
94 encoder,
95 pipeline,
96 &[
97 (0, KernelArg::Buffer(src)),
98 (1, KernelArg::Buffer(cache)),
99 (2, KernelArg::Bytes(&write_pos_bytes)),
100 (3, KernelArg::Bytes(&row_size_bytes)),
101 (4, KernelArg::Bytes(&n_new_bytes)),
102 (5, KernelArg::Bytes(&cache_cap_bytes)),
103 (6, KernelArg::Bytes(&is_sliding_bytes)),
104 ],
105 MTLSize::new(total_elements, 1, 1),
106 MTLSize::new(std::cmp::min(256, total_elements), 1, 1),
107 );
108
109 Ok(())
110}
111
112#[allow(clippy::too_many_arguments)]
131pub fn dispatch_kv_cache_copy_batch_f32(
132 encoder: &mut CommandEncoder,
133 registry: &mut KernelRegistry,
134 device: &metal::DeviceRef,
135 src: &MlxBuffer,
136 cache: &MlxBuffer,
137 n_heads: u32,
138 head_dim: u32,
139 capacity: u32,
140 seq_pos: u32,
141) -> Result<()> {
142 if n_heads == 0 || head_dim == 0 {
143 return Ok(());
144 }
145
146 let total_src = (n_heads as u64) * (head_dim as u64);
147 if (src.element_count() as u64) < total_src {
148 return Err(MlxError::InvalidArgument(format!(
149 "kv_cache_copy_batch_f32: src has {} elements but need {} (n_heads={} * head_dim={})",
150 src.element_count(), total_src, n_heads, head_dim
151 )));
152 }
153
154 let pipeline = registry.get_pipeline("kv_cache_copy_batch_f32", device)?;
155
156 let n_heads_bytes = n_heads.to_ne_bytes();
157 let head_dim_bytes = head_dim.to_ne_bytes();
158 let capacity_bytes = capacity.to_ne_bytes();
159 let seq_pos_bytes = seq_pos.to_ne_bytes();
160
161 use super::encode_helpers::{encode_with_args, KernelArg};
162
163 encode_with_args(
164 encoder,
165 pipeline,
166 &[
167 (0, KernelArg::Buffer(src)),
168 (1, KernelArg::Buffer(cache)),
169 (2, KernelArg::Bytes(&n_heads_bytes)),
170 (3, KernelArg::Bytes(&head_dim_bytes)),
171 (4, KernelArg::Bytes(&capacity_bytes)),
172 (5, KernelArg::Bytes(&seq_pos_bytes)),
173 ],
174 MTLSize::new(head_dim as u64, n_heads as u64, 1),
175 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
176 );
177
178 Ok(())
179}
180
181#[allow(clippy::too_many_arguments)]
205pub fn dispatch_kv_cache_copy_f32(
206 encoder: &mut CommandEncoder,
207 registry: &mut KernelRegistry,
208 device: &metal::DeviceRef,
209 src: &MlxBuffer,
210 cache: &MlxBuffer,
211 write_pos: u32,
212 row_size: u32,
213 n_new: u32,
214 cache_cap: u32,
215 is_sliding: bool,
216) -> Result<()> {
217 if n_new == 0 || row_size == 0 {
218 return Ok(()); }
220
221 let total_elements = (n_new as u64) * (row_size as u64);
222 let src_elements = src.element_count() as u64;
223 if src_elements < total_elements {
224 return Err(MlxError::InvalidArgument(format!(
225 "kv_cache_copy_f32: src has {} elements but need {} (n_new={} * row_size={})",
226 src_elements, total_elements, n_new, row_size
227 )));
228 }
229
230 if !is_sliding && (write_pos as u64 + n_new as u64) > cache_cap as u64 {
232 return Err(MlxError::InvalidArgument(format!(
233 "kv_cache_copy_f32: global cache overflow: write_pos({}) + n_new({}) > cache_cap({})",
234 write_pos, n_new, cache_cap
235 )));
236 }
237
238 let pipeline = registry.get_pipeline("kv_cache_copy_f32", device)?;
239
240 let is_sliding_val: u32 = if is_sliding { 1 } else { 0 };
241
242 let write_pos_bytes = write_pos.to_ne_bytes();
243 let row_size_bytes = row_size.to_ne_bytes();
244 let n_new_bytes = n_new.to_ne_bytes();
245 let cache_cap_bytes = cache_cap.to_ne_bytes();
246 let is_sliding_bytes = is_sliding_val.to_ne_bytes();
247
248 encode_with_args(
249 encoder,
250 pipeline,
251 &[
252 (0, KernelArg::Buffer(src)),
253 (1, KernelArg::Buffer(cache)),
254 (2, KernelArg::Bytes(&write_pos_bytes)),
255 (3, KernelArg::Bytes(&row_size_bytes)),
256 (4, KernelArg::Bytes(&n_new_bytes)),
257 (5, KernelArg::Bytes(&cache_cap_bytes)),
258 (6, KernelArg::Bytes(&is_sliding_bytes)),
259 ],
260 MTLSize::new(total_elements, 1, 1),
261 MTLSize::new(std::cmp::min(256, total_elements), 1, 1),
262 );
263
264 Ok(())
265}
266
267#[allow(clippy::too_many_arguments)]
276pub fn dispatch_kv_cache_copy_batch_f32_to_f16(
277 encoder: &mut CommandEncoder,
278 registry: &mut KernelRegistry,
279 device: &metal::DeviceRef,
280 src: &MlxBuffer,
281 cache: &MlxBuffer,
282 n_heads: u32,
283 head_dim: u32,
284 capacity: u32,
285 seq_pos: u32,
286) -> Result<()> {
287 if n_heads == 0 || head_dim == 0 {
288 return Ok(());
289 }
290
291 let total_src = (n_heads as u64) * (head_dim as u64);
292 if (src.element_count() as u64) < total_src {
293 return Err(MlxError::InvalidArgument(format!(
294 "kv_cache_copy_batch_f32_to_f16: src has {} elements but need {} (n_heads={} * head_dim={})",
295 src.element_count(), total_src, n_heads, head_dim
296 )));
297 }
298
299 let pipeline = registry.get_pipeline("kv_cache_copy_batch_f32_to_f16", device)?;
300
301 let n_heads_bytes = n_heads.to_ne_bytes();
302 let head_dim_bytes = head_dim.to_ne_bytes();
303 let capacity_bytes = capacity.to_ne_bytes();
304 let seq_pos_bytes = seq_pos.to_ne_bytes();
305
306 use super::encode_helpers::{encode_with_args, KernelArg};
307
308 encode_with_args(
309 encoder,
310 pipeline,
311 &[
312 (0, KernelArg::Buffer(src)),
313 (1, KernelArg::Buffer(cache)),
314 (2, KernelArg::Bytes(&n_heads_bytes)),
315 (3, KernelArg::Bytes(&head_dim_bytes)),
316 (4, KernelArg::Bytes(&capacity_bytes)),
317 (5, KernelArg::Bytes(&seq_pos_bytes)),
318 ],
319 MTLSize::new(head_dim as u64, n_heads as u64, 1),
320 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
321 );
322
323 Ok(())
324}