1use metal::MTLSize;
12
13use crate::buffer::MlxBuffer;
14use crate::encoder::CommandEncoder;
15use crate::error::{MlxError, Result};
16use crate::kernel_registry::KernelRegistry;
17
18use super::encode_helpers::{encode_with_args, KernelArg};
19
20pub static KV_CACHE_COPY_SHADER_SOURCE: &str = include_str!("../shaders/kv_cache_copy.metal");
22
23pub fn register(registry: &mut KernelRegistry) {
25 registry.register_source("kv_cache_copy", KV_CACHE_COPY_SHADER_SOURCE);
26}
27
28#[allow(clippy::too_many_arguments)]
49pub fn dispatch_kv_cache_copy(
50 encoder: &mut CommandEncoder,
51 registry: &mut KernelRegistry,
52 device: &metal::DeviceRef,
53 src: &MlxBuffer,
54 cache: &MlxBuffer,
55 write_pos: u32,
56 row_size: u32,
57 n_new: u32,
58 cache_cap: u32,
59 is_sliding: bool,
60) -> Result<()> {
61 if n_new == 0 || row_size == 0 {
62 return Ok(()); }
64
65 let total_elements = (n_new as u64) * (row_size as u64);
66 let src_elements = src.element_count() as u64;
67 if src_elements < total_elements {
68 return Err(MlxError::InvalidArgument(format!(
69 "kv_cache_copy: src has {} elements but need {} (n_new={} * row_size={})",
70 src_elements, total_elements, n_new, row_size
71 )));
72 }
73
74 if !is_sliding && (write_pos as u64 + n_new as u64) > cache_cap as u64 {
76 return Err(MlxError::InvalidArgument(format!(
77 "kv_cache_copy: global cache overflow: write_pos({}) + n_new({}) > cache_cap({})",
78 write_pos, n_new, cache_cap
79 )));
80 }
81
82 let pipeline = registry.get_pipeline("kv_cache_copy", device)?;
83
84 let is_sliding_val: u32 = if is_sliding { 1 } else { 0 };
85
86 let write_pos_bytes = write_pos.to_ne_bytes();
88 let row_size_bytes = row_size.to_ne_bytes();
89 let n_new_bytes = n_new.to_ne_bytes();
90 let cache_cap_bytes = cache_cap.to_ne_bytes();
91 let is_sliding_bytes = is_sliding_val.to_ne_bytes();
92
93 encode_with_args(
94 encoder,
95 pipeline,
96 &[
97 (0, KernelArg::Buffer(src)),
98 (1, KernelArg::Buffer(cache)),
99 (2, KernelArg::Bytes(&write_pos_bytes)),
100 (3, KernelArg::Bytes(&row_size_bytes)),
101 (4, KernelArg::Bytes(&n_new_bytes)),
102 (5, KernelArg::Bytes(&cache_cap_bytes)),
103 (6, KernelArg::Bytes(&is_sliding_bytes)),
104 ],
105 MTLSize::new(total_elements, 1, 1),
106 MTLSize::new(std::cmp::min(256, total_elements), 1, 1),
107 );
108
109 Ok(())
110}
111
112#[allow(clippy::too_many_arguments)]
131pub fn dispatch_kv_cache_copy_batch_f32(
132 encoder: &mut CommandEncoder,
133 registry: &mut KernelRegistry,
134 device: &metal::DeviceRef,
135 src: &MlxBuffer,
136 cache: &MlxBuffer,
137 n_heads: u32,
138 head_dim: u32,
139 capacity: u32,
140 seq_pos: u32,
141) -> Result<()> {
142 if n_heads == 0 || head_dim == 0 {
143 return Ok(());
144 }
145
146 let total_src = (n_heads as u64) * (head_dim as u64);
147 if (src.element_count() as u64) < total_src {
148 return Err(MlxError::InvalidArgument(format!(
149 "kv_cache_copy_batch_f32: src has {} elements but need {} (n_heads={} * head_dim={})",
150 src.element_count(), total_src, n_heads, head_dim
151 )));
152 }
153
154 let pipeline = registry.get_pipeline("kv_cache_copy_batch_f32", device)?;
155
156 let n_heads_bytes = n_heads.to_ne_bytes();
157 let head_dim_bytes = head_dim.to_ne_bytes();
158 let capacity_bytes = capacity.to_ne_bytes();
159 let seq_pos_bytes = seq_pos.to_ne_bytes();
160
161 use super::encode_helpers::{encode_with_args, KernelArg};
162
163 encode_with_args(
164 encoder,
165 pipeline,
166 &[
167 (0, KernelArg::Buffer(src)),
168 (1, KernelArg::Buffer(cache)),
169 (2, KernelArg::Bytes(&n_heads_bytes)),
170 (3, KernelArg::Bytes(&head_dim_bytes)),
171 (4, KernelArg::Bytes(&capacity_bytes)),
172 (5, KernelArg::Bytes(&seq_pos_bytes)),
173 ],
174 MTLSize::new(head_dim as u64, n_heads as u64, 1),
175 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
176 );
177
178 Ok(())
179}
180
181#[allow(clippy::too_many_arguments)]
205pub fn dispatch_kv_cache_copy_f32(
206 encoder: &mut CommandEncoder,
207 registry: &mut KernelRegistry,
208 device: &metal::DeviceRef,
209 src: &MlxBuffer,
210 cache: &MlxBuffer,
211 write_pos: u32,
212 row_size: u32,
213 n_new: u32,
214 cache_cap: u32,
215 is_sliding: bool,
216) -> Result<()> {
217 if n_new == 0 || row_size == 0 {
218 return Ok(()); }
220
221 let total_elements = (n_new as u64) * (row_size as u64);
222 let src_elements = src.element_count() as u64;
223 if src_elements < total_elements {
224 return Err(MlxError::InvalidArgument(format!(
225 "kv_cache_copy_f32: src has {} elements but need {} (n_new={} * row_size={})",
226 src_elements, total_elements, n_new, row_size
227 )));
228 }
229
230 if !is_sliding && (write_pos as u64 + n_new as u64) > cache_cap as u64 {
232 return Err(MlxError::InvalidArgument(format!(
233 "kv_cache_copy_f32: global cache overflow: write_pos({}) + n_new({}) > cache_cap({})",
234 write_pos, n_new, cache_cap
235 )));
236 }
237
238 let pipeline = registry.get_pipeline("kv_cache_copy_f32", device)?;
239
240 let is_sliding_val: u32 = if is_sliding { 1 } else { 0 };
241
242 let write_pos_bytes = write_pos.to_ne_bytes();
243 let row_size_bytes = row_size.to_ne_bytes();
244 let n_new_bytes = n_new.to_ne_bytes();
245 let cache_cap_bytes = cache_cap.to_ne_bytes();
246 let is_sliding_bytes = is_sliding_val.to_ne_bytes();
247
248 encode_with_args(
249 encoder,
250 pipeline,
251 &[
252 (0, KernelArg::Buffer(src)),
253 (1, KernelArg::Buffer(cache)),
254 (2, KernelArg::Bytes(&write_pos_bytes)),
255 (3, KernelArg::Bytes(&row_size_bytes)),
256 (4, KernelArg::Bytes(&n_new_bytes)),
257 (5, KernelArg::Bytes(&cache_cap_bytes)),
258 (6, KernelArg::Bytes(&is_sliding_bytes)),
259 ],
260 MTLSize::new(total_elements, 1, 1),
261 MTLSize::new(std::cmp::min(256, total_elements), 1, 1),
262 );
263
264 Ok(())
265}
266
267#[allow(clippy::too_many_arguments)]
276pub fn dispatch_kv_cache_copy_batch_f32_to_f16(
277 encoder: &mut CommandEncoder,
278 registry: &mut KernelRegistry,
279 device: &metal::DeviceRef,
280 src: &MlxBuffer,
281 cache: &MlxBuffer,
282 n_heads: u32,
283 head_dim: u32,
284 capacity: u32,
285 seq_pos: u32,
286) -> Result<()> {
287 if n_heads == 0 || head_dim == 0 {
288 return Ok(());
289 }
290
291 let total_src = (n_heads as u64) * (head_dim as u64);
292 if (src.element_count() as u64) < total_src {
293 return Err(MlxError::InvalidArgument(format!(
294 "kv_cache_copy_batch_f32_to_f16: src has {} elements but need {} (n_heads={} * head_dim={})",
295 src.element_count(), total_src, n_heads, head_dim
296 )));
297 }
298
299 let pipeline = registry.get_pipeline("kv_cache_copy_batch_f32_to_f16", device)?;
300
301 let n_heads_bytes = n_heads.to_ne_bytes();
302 let head_dim_bytes = head_dim.to_ne_bytes();
303 let capacity_bytes = capacity.to_ne_bytes();
304 let seq_pos_bytes = seq_pos.to_ne_bytes();
305
306 use super::encode_helpers::{encode_with_args, KernelArg};
307
308 encode_with_args(
309 encoder,
310 pipeline,
311 &[
312 (0, KernelArg::Buffer(src)),
313 (1, KernelArg::Buffer(cache)),
314 (2, KernelArg::Bytes(&n_heads_bytes)),
315 (3, KernelArg::Bytes(&head_dim_bytes)),
316 (4, KernelArg::Bytes(&capacity_bytes)),
317 (5, KernelArg::Bytes(&seq_pos_bytes)),
318 ],
319 MTLSize::new(head_dim as u64, n_heads as u64, 1),
320 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
321 );
322
323 Ok(())
324}
325
326#[allow(clippy::too_many_arguments)]
344pub fn dispatch_kv_cache_copy_seq_f32(
345 encoder: &mut CommandEncoder,
346 registry: &mut KernelRegistry,
347 device: &metal::DeviceRef,
348 src: &MlxBuffer,
349 cache: &MlxBuffer,
350 n_heads: u32,
351 head_dim: u32,
352 capacity: u32,
353 seq_pos_start: u32,
354 n_tokens: u32,
355 src_tok_offset: u32,
356) -> Result<()> {
357 if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
358 return Ok(());
359 }
360 let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
361 * (n_heads as u64) * (head_dim as u64);
362 if (src.element_count() as u64) < total_src {
363 return Err(MlxError::InvalidArgument(format!(
364 "kv_cache_copy_seq_f32: src has {} elements, need {} ((src_tok_offset={} + n_tokens={}) * n_heads={} * head_dim={})",
365 src.element_count(), total_src, src_tok_offset, n_tokens, n_heads, head_dim
366 )));
367 }
368
369 let pipeline = registry.get_pipeline("kv_cache_copy_seq_f32", device)?;
370
371 let n_heads_bytes = n_heads.to_ne_bytes();
372 let head_dim_bytes = head_dim.to_ne_bytes();
373 let capacity_bytes = capacity.to_ne_bytes();
374 let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
375 let n_tokens_bytes = n_tokens.to_ne_bytes();
376 let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();
377
378 use super::encode_helpers::{encode_with_args, KernelArg};
379
380 encode_with_args(
381 encoder,
382 pipeline,
383 &[
384 (0, KernelArg::Buffer(src)),
385 (1, KernelArg::Buffer(cache)),
386 (2, KernelArg::Bytes(&n_heads_bytes)),
387 (3, KernelArg::Bytes(&head_dim_bytes)),
388 (4, KernelArg::Bytes(&capacity_bytes)),
389 (5, KernelArg::Bytes(&seq_pos_start_bytes)),
390 (6, KernelArg::Bytes(&n_tokens_bytes)),
391 (7, KernelArg::Bytes(&src_tok_offset_bytes)),
392 ],
393 MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
394 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
395 );
396
397 Ok(())
398}
399
400#[allow(clippy::too_many_arguments)]
409pub fn dispatch_kv_cache_copy_seq_f32_dual(
410 encoder: &mut CommandEncoder,
411 registry: &mut KernelRegistry,
412 device: &metal::DeviceRef,
413 src_k: &MlxBuffer,
414 src_v: &MlxBuffer,
415 cache_k: &MlxBuffer,
416 cache_v: &MlxBuffer,
417 n_heads: u32,
418 head_dim: u32,
419 capacity: u32,
420 seq_pos_start: u32,
421 n_tokens: u32,
422 src_tok_offset: u32,
423) -> Result<()> {
424 if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
425 return Ok(());
426 }
427 let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
428 * (n_heads as u64) * (head_dim as u64);
429 for (name, b) in [("src_k", src_k), ("src_v", src_v)] {
430 if (b.element_count() as u64) < total_src {
431 return Err(MlxError::InvalidArgument(format!(
432 "kv_cache_copy_seq_f32_dual: {} has {} elements, need {}",
433 name, b.element_count(), total_src
434 )));
435 }
436 }
437
438 let pipeline = registry.get_pipeline("kv_cache_copy_seq_f32_kv_dual", device)?;
439
440 let n_heads_bytes = n_heads.to_ne_bytes();
441 let head_dim_bytes = head_dim.to_ne_bytes();
442 let capacity_bytes = capacity.to_ne_bytes();
443 let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
444 let n_tokens_bytes = n_tokens.to_ne_bytes();
445 let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();
446
447 use super::encode_helpers::{encode_with_args, KernelArg};
448
449 encode_with_args(
450 encoder,
451 pipeline,
452 &[
453 (0, KernelArg::Buffer(src_k)),
454 (1, KernelArg::Buffer(src_v)),
455 (2, KernelArg::Buffer(cache_k)),
456 (3, KernelArg::Buffer(cache_v)),
457 (4, KernelArg::Bytes(&n_heads_bytes)),
458 (5, KernelArg::Bytes(&head_dim_bytes)),
459 (6, KernelArg::Bytes(&capacity_bytes)),
460 (7, KernelArg::Bytes(&seq_pos_start_bytes)),
461 (8, KernelArg::Bytes(&n_tokens_bytes)),
462 (9, KernelArg::Bytes(&src_tok_offset_bytes)),
463 ],
464 MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
465 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
466 );
467
468 Ok(())
469}
470
471#[allow(clippy::too_many_arguments)]
474pub fn dispatch_kv_cache_copy_seq_f32_to_f16_dual(
475 encoder: &mut CommandEncoder,
476 registry: &mut KernelRegistry,
477 device: &metal::DeviceRef,
478 src_k: &MlxBuffer,
479 src_v: &MlxBuffer,
480 cache_k: &MlxBuffer,
481 cache_v: &MlxBuffer,
482 n_heads: u32,
483 head_dim: u32,
484 capacity: u32,
485 seq_pos_start: u32,
486 n_tokens: u32,
487 src_tok_offset: u32,
488) -> Result<()> {
489 if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
490 return Ok(());
491 }
492 let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
493 * (n_heads as u64) * (head_dim as u64);
494 for (name, b) in [("src_k", src_k), ("src_v", src_v)] {
495 if (b.element_count() as u64) < total_src {
496 return Err(MlxError::InvalidArgument(format!(
497 "kv_cache_copy_seq_f32_to_f16_dual: {} has {} elements, need {}",
498 name, b.element_count(), total_src
499 )));
500 }
501 }
502
503 let pipeline = registry.get_pipeline("kv_cache_copy_seq_f32_to_f16_kv_dual", device)?;
504
505 let n_heads_bytes = n_heads.to_ne_bytes();
506 let head_dim_bytes = head_dim.to_ne_bytes();
507 let capacity_bytes = capacity.to_ne_bytes();
508 let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
509 let n_tokens_bytes = n_tokens.to_ne_bytes();
510 let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();
511
512 use super::encode_helpers::{encode_with_args, KernelArg};
513
514 encode_with_args(
515 encoder,
516 pipeline,
517 &[
518 (0, KernelArg::Buffer(src_k)),
519 (1, KernelArg::Buffer(src_v)),
520 (2, KernelArg::Buffer(cache_k)),
521 (3, KernelArg::Buffer(cache_v)),
522 (4, KernelArg::Bytes(&n_heads_bytes)),
523 (5, KernelArg::Bytes(&head_dim_bytes)),
524 (6, KernelArg::Bytes(&capacity_bytes)),
525 (7, KernelArg::Bytes(&seq_pos_start_bytes)),
526 (8, KernelArg::Bytes(&n_tokens_bytes)),
527 (9, KernelArg::Bytes(&src_tok_offset_bytes)),
528 ],
529 MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
530 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
531 );
532
533 Ok(())
534}
535
536#[allow(clippy::too_many_arguments)]
549pub fn dispatch_kv_cache_copy_seq_bf16(
550 encoder: &mut CommandEncoder,
551 registry: &mut KernelRegistry,
552 device: &metal::DeviceRef,
553 src: &MlxBuffer,
554 cache: &MlxBuffer,
555 n_heads: u32,
556 head_dim: u32,
557 capacity: u32,
558 seq_pos_start: u32,
559 n_tokens: u32,
560 src_tok_offset: u32,
561) -> Result<()> {
562 if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
563 return Ok(());
564 }
565 let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
567 * (n_heads as u64) * (head_dim as u64);
568 let src_bytes_needed = total_src * 2; if (src.byte_len() as u64) < src_bytes_needed {
570 return Err(MlxError::InvalidArgument(format!(
571 "kv_cache_copy_seq_bf16: src has {} bytes, need {} ((src_tok_offset={} + n_tokens={}) * n_heads={} * head_dim={} * 2)",
572 src.byte_len(), src_bytes_needed, src_tok_offset, n_tokens, n_heads, head_dim
573 )));
574 }
575
576 let pipeline = registry.get_pipeline("kv_cache_copy_seq_bf16", device)?;
577
578 let n_heads_bytes = n_heads.to_ne_bytes();
579 let head_dim_bytes = head_dim.to_ne_bytes();
580 let capacity_bytes = capacity.to_ne_bytes();
581 let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
582 let n_tokens_bytes = n_tokens.to_ne_bytes();
583 let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();
584
585 use super::encode_helpers::{encode_with_args, KernelArg};
586
587 encode_with_args(
588 encoder,
589 pipeline,
590 &[
591 (0, KernelArg::Buffer(src)),
592 (1, KernelArg::Buffer(cache)),
593 (2, KernelArg::Bytes(&n_heads_bytes)),
594 (3, KernelArg::Bytes(&head_dim_bytes)),
595 (4, KernelArg::Bytes(&capacity_bytes)),
596 (5, KernelArg::Bytes(&seq_pos_start_bytes)),
597 (6, KernelArg::Bytes(&n_tokens_bytes)),
598 (7, KernelArg::Bytes(&src_tok_offset_bytes)),
599 ],
600 MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
601 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
602 );
603
604 Ok(())
605}
606
607#[allow(clippy::too_many_arguments)]
613pub fn dispatch_kv_cache_copy_seq_f32_to_f16(
614 encoder: &mut CommandEncoder,
615 registry: &mut KernelRegistry,
616 device: &metal::DeviceRef,
617 src: &MlxBuffer,
618 cache: &MlxBuffer,
619 n_heads: u32,
620 head_dim: u32,
621 capacity: u32,
622 seq_pos_start: u32,
623 n_tokens: u32,
624 src_tok_offset: u32,
625) -> Result<()> {
626 if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
627 return Ok(());
628 }
629 let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
630 * (n_heads as u64) * (head_dim as u64);
631 if (src.element_count() as u64) < total_src {
632 return Err(MlxError::InvalidArgument(format!(
633 "kv_cache_copy_seq_f32_to_f16: src has {} elements, need {}",
634 src.element_count(), total_src
635 )));
636 }
637
638 let pipeline = registry.get_pipeline("kv_cache_copy_seq_f32_to_f16", device)?;
639
640 let n_heads_bytes = n_heads.to_ne_bytes();
641 let head_dim_bytes = head_dim.to_ne_bytes();
642 let capacity_bytes = capacity.to_ne_bytes();
643 let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
644 let n_tokens_bytes = n_tokens.to_ne_bytes();
645 let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();
646
647 use super::encode_helpers::{encode_with_args, KernelArg};
648
649 encode_with_args(
650 encoder,
651 pipeline,
652 &[
653 (0, KernelArg::Buffer(src)),
654 (1, KernelArg::Buffer(cache)),
655 (2, KernelArg::Bytes(&n_heads_bytes)),
656 (3, KernelArg::Bytes(&head_dim_bytes)),
657 (4, KernelArg::Bytes(&capacity_bytes)),
658 (5, KernelArg::Bytes(&seq_pos_start_bytes)),
659 (6, KernelArg::Bytes(&n_tokens_bytes)),
660 (7, KernelArg::Bytes(&src_tok_offset_bytes)),
661 ],
662 MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
663 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
664 );
665
666 Ok(())
667}