1use metal::MTLSize;
12
13use crate::buffer::MlxBuffer;
14use crate::encoder::CommandEncoder;
15use crate::error::{MlxError, Result};
16use crate::kernel_registry::KernelRegistry;
17
18use super::encode_helpers::{encode_with_args, KernelArg};
19
20pub static KV_CACHE_COPY_SHADER_SOURCE: &str = include_str!("../shaders/kv_cache_copy.metal");
22
23pub fn register(registry: &mut KernelRegistry) {
25 registry.register_source("kv_cache_copy", KV_CACHE_COPY_SHADER_SOURCE);
26}
27
28#[allow(clippy::too_many_arguments)]
49pub fn dispatch_kv_cache_copy(
50 encoder: &mut CommandEncoder,
51 registry: &mut KernelRegistry,
52 device: &metal::DeviceRef,
53 src: &MlxBuffer,
54 cache: &MlxBuffer,
55 write_pos: u32,
56 row_size: u32,
57 n_new: u32,
58 cache_cap: u32,
59 is_sliding: bool,
60) -> Result<()> {
61 if n_new == 0 || row_size == 0 {
62 return Ok(()); }
64
65 let total_elements = (n_new as u64) * (row_size as u64);
66 let src_elements = src.element_count() as u64;
67 if src_elements < total_elements {
68 return Err(MlxError::InvalidArgument(format!(
69 "kv_cache_copy: src has {} elements but need {} (n_new={} * row_size={})",
70 src_elements, total_elements, n_new, row_size
71 )));
72 }
73
74 if !is_sliding && (write_pos as u64 + n_new as u64) > cache_cap as u64 {
76 return Err(MlxError::InvalidArgument(format!(
77 "kv_cache_copy: global cache overflow: write_pos({}) + n_new({}) > cache_cap({})",
78 write_pos, n_new, cache_cap
79 )));
80 }
81
82 let pipeline = registry.get_pipeline("kv_cache_copy", device)?;
83
84 let is_sliding_val: u32 = if is_sliding { 1 } else { 0 };
85
86 let write_pos_bytes = write_pos.to_ne_bytes();
88 let row_size_bytes = row_size.to_ne_bytes();
89 let n_new_bytes = n_new.to_ne_bytes();
90 let cache_cap_bytes = cache_cap.to_ne_bytes();
91 let is_sliding_bytes = is_sliding_val.to_ne_bytes();
92
93 encode_with_args(
94 encoder,
95 pipeline,
96 &[
97 (0, KernelArg::Buffer(src)),
98 (1, KernelArg::Buffer(cache)),
99 (2, KernelArg::Bytes(&write_pos_bytes)),
100 (3, KernelArg::Bytes(&row_size_bytes)),
101 (4, KernelArg::Bytes(&n_new_bytes)),
102 (5, KernelArg::Bytes(&cache_cap_bytes)),
103 (6, KernelArg::Bytes(&is_sliding_bytes)),
104 ],
105 MTLSize::new(total_elements, 1, 1),
106 MTLSize::new(std::cmp::min(256, total_elements), 1, 1),
107 );
108
109 Ok(())
110}
111
112#[allow(clippy::too_many_arguments)]
131pub fn dispatch_kv_cache_copy_batch_f32(
132 encoder: &mut CommandEncoder,
133 registry: &mut KernelRegistry,
134 device: &metal::DeviceRef,
135 src: &MlxBuffer,
136 cache: &MlxBuffer,
137 n_heads: u32,
138 head_dim: u32,
139 capacity: u32,
140 seq_pos: u32,
141) -> Result<()> {
142 if n_heads == 0 || head_dim == 0 {
143 return Ok(());
144 }
145
146 let total_src = (n_heads as u64) * (head_dim as u64);
147 if (src.element_count() as u64) < total_src {
148 return Err(MlxError::InvalidArgument(format!(
149 "kv_cache_copy_batch_f32: src has {} elements but need {} (n_heads={} * head_dim={})",
150 src.element_count(), total_src, n_heads, head_dim
151 )));
152 }
153
154 let pipeline = registry.get_pipeline("kv_cache_copy_batch_f32", device)?;
155
156 let n_heads_bytes = n_heads.to_ne_bytes();
157 let head_dim_bytes = head_dim.to_ne_bytes();
158 let capacity_bytes = capacity.to_ne_bytes();
159 let seq_pos_bytes = seq_pos.to_ne_bytes();
160
161 use super::encode_helpers::{encode_with_args, KernelArg};
162
163 encode_with_args(
164 encoder,
165 pipeline,
166 &[
167 (0, KernelArg::Buffer(src)),
168 (1, KernelArg::Buffer(cache)),
169 (2, KernelArg::Bytes(&n_heads_bytes)),
170 (3, KernelArg::Bytes(&head_dim_bytes)),
171 (4, KernelArg::Bytes(&capacity_bytes)),
172 (5, KernelArg::Bytes(&seq_pos_bytes)),
173 ],
174 MTLSize::new(head_dim as u64, n_heads as u64, 1),
175 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
176 );
177
178 Ok(())
179}
180
181#[allow(clippy::too_many_arguments)]
205pub fn dispatch_kv_cache_copy_f32(
206 encoder: &mut CommandEncoder,
207 registry: &mut KernelRegistry,
208 device: &metal::DeviceRef,
209 src: &MlxBuffer,
210 cache: &MlxBuffer,
211 write_pos: u32,
212 row_size: u32,
213 n_new: u32,
214 cache_cap: u32,
215 is_sliding: bool,
216) -> Result<()> {
217 if n_new == 0 || row_size == 0 {
218 return Ok(()); }
220
221 let total_elements = (n_new as u64) * (row_size as u64);
222 let src_elements = src.element_count() as u64;
223 if src_elements < total_elements {
224 return Err(MlxError::InvalidArgument(format!(
225 "kv_cache_copy_f32: src has {} elements but need {} (n_new={} * row_size={})",
226 src_elements, total_elements, n_new, row_size
227 )));
228 }
229
230 if !is_sliding && (write_pos as u64 + n_new as u64) > cache_cap as u64 {
232 return Err(MlxError::InvalidArgument(format!(
233 "kv_cache_copy_f32: global cache overflow: write_pos({}) + n_new({}) > cache_cap({})",
234 write_pos, n_new, cache_cap
235 )));
236 }
237
238 let pipeline = registry.get_pipeline("kv_cache_copy_f32", device)?;
239
240 let is_sliding_val: u32 = if is_sliding { 1 } else { 0 };
241
242 let write_pos_bytes = write_pos.to_ne_bytes();
243 let row_size_bytes = row_size.to_ne_bytes();
244 let n_new_bytes = n_new.to_ne_bytes();
245 let cache_cap_bytes = cache_cap.to_ne_bytes();
246 let is_sliding_bytes = is_sliding_val.to_ne_bytes();
247
248 encode_with_args(
249 encoder,
250 pipeline,
251 &[
252 (0, KernelArg::Buffer(src)),
253 (1, KernelArg::Buffer(cache)),
254 (2, KernelArg::Bytes(&write_pos_bytes)),
255 (3, KernelArg::Bytes(&row_size_bytes)),
256 (4, KernelArg::Bytes(&n_new_bytes)),
257 (5, KernelArg::Bytes(&cache_cap_bytes)),
258 (6, KernelArg::Bytes(&is_sliding_bytes)),
259 ],
260 MTLSize::new(total_elements, 1, 1),
261 MTLSize::new(std::cmp::min(256, total_elements), 1, 1),
262 );
263
264 Ok(())
265}
266
267#[allow(clippy::too_many_arguments)]
276pub fn dispatch_kv_cache_copy_batch_f32_to_f16(
277 encoder: &mut CommandEncoder,
278 registry: &mut KernelRegistry,
279 device: &metal::DeviceRef,
280 src: &MlxBuffer,
281 cache: &MlxBuffer,
282 n_heads: u32,
283 head_dim: u32,
284 capacity: u32,
285 seq_pos: u32,
286) -> Result<()> {
287 if n_heads == 0 || head_dim == 0 {
288 return Ok(());
289 }
290
291 let total_src = (n_heads as u64) * (head_dim as u64);
292 if (src.element_count() as u64) < total_src {
293 return Err(MlxError::InvalidArgument(format!(
294 "kv_cache_copy_batch_f32_to_f16: src has {} elements but need {} (n_heads={} * head_dim={})",
295 src.element_count(), total_src, n_heads, head_dim
296 )));
297 }
298
299 let pipeline = registry.get_pipeline("kv_cache_copy_batch_f32_to_f16", device)?;
300
301 let n_heads_bytes = n_heads.to_ne_bytes();
302 let head_dim_bytes = head_dim.to_ne_bytes();
303 let capacity_bytes = capacity.to_ne_bytes();
304 let seq_pos_bytes = seq_pos.to_ne_bytes();
305
306 use super::encode_helpers::{encode_with_args, KernelArg};
307
308 encode_with_args(
309 encoder,
310 pipeline,
311 &[
312 (0, KernelArg::Buffer(src)),
313 (1, KernelArg::Buffer(cache)),
314 (2, KernelArg::Bytes(&n_heads_bytes)),
315 (3, KernelArg::Bytes(&head_dim_bytes)),
316 (4, KernelArg::Bytes(&capacity_bytes)),
317 (5, KernelArg::Bytes(&seq_pos_bytes)),
318 ],
319 MTLSize::new(head_dim as u64, n_heads as u64, 1),
320 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
321 );
322
323 Ok(())
324}
325
326#[allow(clippy::too_many_arguments)]
338pub fn dispatch_kv_cache_copy_batch_f32_kv_dual(
339 encoder: &mut CommandEncoder,
340 registry: &mut KernelRegistry,
341 device: &metal::DeviceRef,
342 src_k: &MlxBuffer,
343 src_v: &MlxBuffer,
344 cache_k: &MlxBuffer,
345 cache_v: &MlxBuffer,
346 n_heads: u32,
347 head_dim: u32,
348 capacity: u32,
349 seq_pos: u32,
350) -> Result<()> {
351 if n_heads == 0 || head_dim == 0 {
352 return Ok(());
353 }
354
355 let total_src = (n_heads as u64) * (head_dim as u64);
356 if (src_k.element_count() as u64) < total_src {
357 return Err(MlxError::InvalidArgument(format!(
358 "kv_cache_copy_batch_f32_kv_dual: src_k has {} elements but need {} (n_heads={} * head_dim={})",
359 src_k.element_count(), total_src, n_heads, head_dim
360 )));
361 }
362 if (src_v.element_count() as u64) < total_src {
363 return Err(MlxError::InvalidArgument(format!(
364 "kv_cache_copy_batch_f32_kv_dual: src_v has {} elements but need {} (n_heads={} * head_dim={})",
365 src_v.element_count(), total_src, n_heads, head_dim
366 )));
367 }
368
369 let pipeline = registry.get_pipeline("kv_cache_copy_batch_f32_kv_dual", device)?;
370
371 let n_heads_bytes = n_heads.to_ne_bytes();
372 let head_dim_bytes = head_dim.to_ne_bytes();
373 let capacity_bytes = capacity.to_ne_bytes();
374 let seq_pos_bytes = seq_pos.to_ne_bytes();
375
376 encode_with_args(
377 encoder,
378 pipeline,
379 &[
380 (0, KernelArg::Buffer(src_k)),
381 (1, KernelArg::Buffer(src_v)),
382 (2, KernelArg::Buffer(cache_k)),
383 (3, KernelArg::Buffer(cache_v)),
384 (4, KernelArg::Bytes(&n_heads_bytes)),
385 (5, KernelArg::Bytes(&head_dim_bytes)),
386 (6, KernelArg::Bytes(&capacity_bytes)),
387 (7, KernelArg::Bytes(&seq_pos_bytes)),
388 ],
389 MTLSize::new(head_dim as u64, n_heads as u64, 1),
390 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
391 );
392
393 Ok(())
394}
395
396#[allow(clippy::too_many_arguments)]
401pub fn dispatch_kv_cache_copy_batch_f32_to_f16_kv_dual(
402 encoder: &mut CommandEncoder,
403 registry: &mut KernelRegistry,
404 device: &metal::DeviceRef,
405 src_k: &MlxBuffer,
406 src_v: &MlxBuffer,
407 cache_k: &MlxBuffer,
408 cache_v: &MlxBuffer,
409 n_heads: u32,
410 head_dim: u32,
411 capacity: u32,
412 seq_pos: u32,
413) -> Result<()> {
414 if n_heads == 0 || head_dim == 0 {
415 return Ok(());
416 }
417
418 let total_src = (n_heads as u64) * (head_dim as u64);
419 if (src_k.element_count() as u64) < total_src {
420 return Err(MlxError::InvalidArgument(format!(
421 "kv_cache_copy_batch_f32_to_f16_kv_dual: src_k has {} elements but need {} (n_heads={} * head_dim={})",
422 src_k.element_count(), total_src, n_heads, head_dim
423 )));
424 }
425 if (src_v.element_count() as u64) < total_src {
426 return Err(MlxError::InvalidArgument(format!(
427 "kv_cache_copy_batch_f32_to_f16_kv_dual: src_v has {} elements but need {} (n_heads={} * head_dim={})",
428 src_v.element_count(), total_src, n_heads, head_dim
429 )));
430 }
431
432 let pipeline = registry.get_pipeline("kv_cache_copy_batch_f32_to_f16_kv_dual", device)?;
433
434 let n_heads_bytes = n_heads.to_ne_bytes();
435 let head_dim_bytes = head_dim.to_ne_bytes();
436 let capacity_bytes = capacity.to_ne_bytes();
437 let seq_pos_bytes = seq_pos.to_ne_bytes();
438
439 encode_with_args(
440 encoder,
441 pipeline,
442 &[
443 (0, KernelArg::Buffer(src_k)),
444 (1, KernelArg::Buffer(src_v)),
445 (2, KernelArg::Buffer(cache_k)),
446 (3, KernelArg::Buffer(cache_v)),
447 (4, KernelArg::Bytes(&n_heads_bytes)),
448 (5, KernelArg::Bytes(&head_dim_bytes)),
449 (6, KernelArg::Bytes(&capacity_bytes)),
450 (7, KernelArg::Bytes(&seq_pos_bytes)),
451 ],
452 MTLSize::new(head_dim as u64, n_heads as u64, 1),
453 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
454 );
455
456 Ok(())
457}
458
459#[allow(clippy::too_many_arguments)]
477pub fn dispatch_kv_cache_copy_seq_f32(
478 encoder: &mut CommandEncoder,
479 registry: &mut KernelRegistry,
480 device: &metal::DeviceRef,
481 src: &MlxBuffer,
482 cache: &MlxBuffer,
483 n_heads: u32,
484 head_dim: u32,
485 capacity: u32,
486 seq_pos_start: u32,
487 n_tokens: u32,
488 src_tok_offset: u32,
489) -> Result<()> {
490 if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
491 return Ok(());
492 }
493 let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
494 * (n_heads as u64) * (head_dim as u64);
495 if (src.element_count() as u64) < total_src {
496 return Err(MlxError::InvalidArgument(format!(
497 "kv_cache_copy_seq_f32: src has {} elements, need {} ((src_tok_offset={} + n_tokens={}) * n_heads={} * head_dim={})",
498 src.element_count(), total_src, src_tok_offset, n_tokens, n_heads, head_dim
499 )));
500 }
501
502 let pipeline = registry.get_pipeline("kv_cache_copy_seq_f32", device)?;
503
504 let n_heads_bytes = n_heads.to_ne_bytes();
505 let head_dim_bytes = head_dim.to_ne_bytes();
506 let capacity_bytes = capacity.to_ne_bytes();
507 let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
508 let n_tokens_bytes = n_tokens.to_ne_bytes();
509 let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();
510
511 use super::encode_helpers::{encode_with_args, KernelArg};
512
513 encode_with_args(
514 encoder,
515 pipeline,
516 &[
517 (0, KernelArg::Buffer(src)),
518 (1, KernelArg::Buffer(cache)),
519 (2, KernelArg::Bytes(&n_heads_bytes)),
520 (3, KernelArg::Bytes(&head_dim_bytes)),
521 (4, KernelArg::Bytes(&capacity_bytes)),
522 (5, KernelArg::Bytes(&seq_pos_start_bytes)),
523 (6, KernelArg::Bytes(&n_tokens_bytes)),
524 (7, KernelArg::Bytes(&src_tok_offset_bytes)),
525 ],
526 MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
527 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
528 );
529
530 Ok(())
531}
532
533#[allow(clippy::too_many_arguments)]
542pub fn dispatch_kv_cache_copy_seq_f32_dual(
543 encoder: &mut CommandEncoder,
544 registry: &mut KernelRegistry,
545 device: &metal::DeviceRef,
546 src_k: &MlxBuffer,
547 src_v: &MlxBuffer,
548 cache_k: &MlxBuffer,
549 cache_v: &MlxBuffer,
550 n_heads: u32,
551 head_dim: u32,
552 capacity: u32,
553 seq_pos_start: u32,
554 n_tokens: u32,
555 src_tok_offset: u32,
556) -> Result<()> {
557 if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
558 return Ok(());
559 }
560 let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
561 * (n_heads as u64) * (head_dim as u64);
562 for (name, b) in [("src_k", src_k), ("src_v", src_v)] {
563 if (b.element_count() as u64) < total_src {
564 return Err(MlxError::InvalidArgument(format!(
565 "kv_cache_copy_seq_f32_dual: {} has {} elements, need {}",
566 name, b.element_count(), total_src
567 )));
568 }
569 }
570
571 let pipeline = registry.get_pipeline("kv_cache_copy_seq_f32_kv_dual", device)?;
572
573 let n_heads_bytes = n_heads.to_ne_bytes();
574 let head_dim_bytes = head_dim.to_ne_bytes();
575 let capacity_bytes = capacity.to_ne_bytes();
576 let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
577 let n_tokens_bytes = n_tokens.to_ne_bytes();
578 let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();
579
580 use super::encode_helpers::{encode_with_args, KernelArg};
581
582 encode_with_args(
583 encoder,
584 pipeline,
585 &[
586 (0, KernelArg::Buffer(src_k)),
587 (1, KernelArg::Buffer(src_v)),
588 (2, KernelArg::Buffer(cache_k)),
589 (3, KernelArg::Buffer(cache_v)),
590 (4, KernelArg::Bytes(&n_heads_bytes)),
591 (5, KernelArg::Bytes(&head_dim_bytes)),
592 (6, KernelArg::Bytes(&capacity_bytes)),
593 (7, KernelArg::Bytes(&seq_pos_start_bytes)),
594 (8, KernelArg::Bytes(&n_tokens_bytes)),
595 (9, KernelArg::Bytes(&src_tok_offset_bytes)),
596 ],
597 MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
598 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
599 );
600
601 Ok(())
602}
603
604#[allow(clippy::too_many_arguments)]
607pub fn dispatch_kv_cache_copy_seq_f32_to_f16_dual(
608 encoder: &mut CommandEncoder,
609 registry: &mut KernelRegistry,
610 device: &metal::DeviceRef,
611 src_k: &MlxBuffer,
612 src_v: &MlxBuffer,
613 cache_k: &MlxBuffer,
614 cache_v: &MlxBuffer,
615 n_heads: u32,
616 head_dim: u32,
617 capacity: u32,
618 seq_pos_start: u32,
619 n_tokens: u32,
620 src_tok_offset: u32,
621) -> Result<()> {
622 if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
623 return Ok(());
624 }
625 let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
626 * (n_heads as u64) * (head_dim as u64);
627 for (name, b) in [("src_k", src_k), ("src_v", src_v)] {
628 if (b.element_count() as u64) < total_src {
629 return Err(MlxError::InvalidArgument(format!(
630 "kv_cache_copy_seq_f32_to_f16_dual: {} has {} elements, need {}",
631 name, b.element_count(), total_src
632 )));
633 }
634 }
635
636 let pipeline = registry.get_pipeline("kv_cache_copy_seq_f32_to_f16_kv_dual", device)?;
637
638 let n_heads_bytes = n_heads.to_ne_bytes();
639 let head_dim_bytes = head_dim.to_ne_bytes();
640 let capacity_bytes = capacity.to_ne_bytes();
641 let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
642 let n_tokens_bytes = n_tokens.to_ne_bytes();
643 let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();
644
645 use super::encode_helpers::{encode_with_args, KernelArg};
646
647 encode_with_args(
648 encoder,
649 pipeline,
650 &[
651 (0, KernelArg::Buffer(src_k)),
652 (1, KernelArg::Buffer(src_v)),
653 (2, KernelArg::Buffer(cache_k)),
654 (3, KernelArg::Buffer(cache_v)),
655 (4, KernelArg::Bytes(&n_heads_bytes)),
656 (5, KernelArg::Bytes(&head_dim_bytes)),
657 (6, KernelArg::Bytes(&capacity_bytes)),
658 (7, KernelArg::Bytes(&seq_pos_start_bytes)),
659 (8, KernelArg::Bytes(&n_tokens_bytes)),
660 (9, KernelArg::Bytes(&src_tok_offset_bytes)),
661 ],
662 MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
663 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
664 );
665
666 Ok(())
667}
668
669#[allow(clippy::too_many_arguments)]
682pub fn dispatch_kv_cache_copy_seq_bf16(
683 encoder: &mut CommandEncoder,
684 registry: &mut KernelRegistry,
685 device: &metal::DeviceRef,
686 src: &MlxBuffer,
687 cache: &MlxBuffer,
688 n_heads: u32,
689 head_dim: u32,
690 capacity: u32,
691 seq_pos_start: u32,
692 n_tokens: u32,
693 src_tok_offset: u32,
694) -> Result<()> {
695 if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
696 return Ok(());
697 }
698 let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
700 * (n_heads as u64) * (head_dim as u64);
701 let src_bytes_needed = total_src * 2; if (src.byte_len() as u64) < src_bytes_needed {
703 return Err(MlxError::InvalidArgument(format!(
704 "kv_cache_copy_seq_bf16: src has {} bytes, need {} ((src_tok_offset={} + n_tokens={}) * n_heads={} * head_dim={} * 2)",
705 src.byte_len(), src_bytes_needed, src_tok_offset, n_tokens, n_heads, head_dim
706 )));
707 }
708
709 let pipeline = registry.get_pipeline("kv_cache_copy_seq_bf16", device)?;
710
711 let n_heads_bytes = n_heads.to_ne_bytes();
712 let head_dim_bytes = head_dim.to_ne_bytes();
713 let capacity_bytes = capacity.to_ne_bytes();
714 let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
715 let n_tokens_bytes = n_tokens.to_ne_bytes();
716 let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();
717
718 use super::encode_helpers::{encode_with_args, KernelArg};
719
720 encode_with_args(
721 encoder,
722 pipeline,
723 &[
724 (0, KernelArg::Buffer(src)),
725 (1, KernelArg::Buffer(cache)),
726 (2, KernelArg::Bytes(&n_heads_bytes)),
727 (3, KernelArg::Bytes(&head_dim_bytes)),
728 (4, KernelArg::Bytes(&capacity_bytes)),
729 (5, KernelArg::Bytes(&seq_pos_start_bytes)),
730 (6, KernelArg::Bytes(&n_tokens_bytes)),
731 (7, KernelArg::Bytes(&src_tok_offset_bytes)),
732 ],
733 MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
734 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
735 );
736
737 Ok(())
738}
739
740#[allow(clippy::too_many_arguments)]
746pub fn dispatch_kv_cache_copy_seq_f32_to_f16(
747 encoder: &mut CommandEncoder,
748 registry: &mut KernelRegistry,
749 device: &metal::DeviceRef,
750 src: &MlxBuffer,
751 cache: &MlxBuffer,
752 n_heads: u32,
753 head_dim: u32,
754 capacity: u32,
755 seq_pos_start: u32,
756 n_tokens: u32,
757 src_tok_offset: u32,
758) -> Result<()> {
759 if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
760 return Ok(());
761 }
762 let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
763 * (n_heads as u64) * (head_dim as u64);
764 if (src.element_count() as u64) < total_src {
765 return Err(MlxError::InvalidArgument(format!(
766 "kv_cache_copy_seq_f32_to_f16: src has {} elements, need {}",
767 src.element_count(), total_src
768 )));
769 }
770
771 let pipeline = registry.get_pipeline("kv_cache_copy_seq_f32_to_f16", device)?;
772
773 let n_heads_bytes = n_heads.to_ne_bytes();
774 let head_dim_bytes = head_dim.to_ne_bytes();
775 let capacity_bytes = capacity.to_ne_bytes();
776 let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
777 let n_tokens_bytes = n_tokens.to_ne_bytes();
778 let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();
779
780 use super::encode_helpers::{encode_with_args, KernelArg};
781
782 encode_with_args(
783 encoder,
784 pipeline,
785 &[
786 (0, KernelArg::Buffer(src)),
787 (1, KernelArg::Buffer(cache)),
788 (2, KernelArg::Bytes(&n_heads_bytes)),
789 (3, KernelArg::Bytes(&head_dim_bytes)),
790 (4, KernelArg::Bytes(&capacity_bytes)),
791 (5, KernelArg::Bytes(&seq_pos_start_bytes)),
792 (6, KernelArg::Bytes(&n_tokens_bytes)),
793 (7, KernelArg::Bytes(&src_tok_offset_bytes)),
794 ],
795 MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
796 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
797 );
798
799 Ok(())
800}
801
802#[allow(clippy::too_many_arguments)]
816pub fn dispatch_kv_cache_copy_seq_bf16_to_bf16_head_major(
817 encoder: &mut CommandEncoder,
818 registry: &mut KernelRegistry,
819 device: &metal::DeviceRef,
820 src: &MlxBuffer,
821 cache: &MlxBuffer,
822 n_heads: u32,
823 head_dim: u32,
824 capacity: u32,
825 seq_pos_start: u32,
826 n_tokens: u32,
827 src_tok_offset: u32,
828 src_seq_len: u32,
829) -> Result<()> {
830 if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
831 return Ok(());
832 }
833 let total_src = (n_heads as u64) * (src_seq_len as u64) * (head_dim as u64);
834 if (src.element_count() as u64) < total_src {
835 return Err(MlxError::InvalidArgument(format!(
836 "kv_cache_copy_seq_bf16_to_bf16_head_major: src has {} elements, need {} \
837 ({} heads × {} src_seq_len × {} head_dim)",
838 src.element_count(), total_src, n_heads, src_seq_len, head_dim
839 )));
840 }
841 if src.dtype() != crate::DType::BF16 {
842 return Err(MlxError::InvalidArgument(format!(
843 "kv_cache_copy_seq_bf16_to_bf16_head_major: src must be BF16, got {:?}",
844 src.dtype()
845 )));
846 }
847 if cache.dtype() != crate::DType::BF16 {
848 return Err(MlxError::InvalidArgument(format!(
849 "kv_cache_copy_seq_bf16_to_bf16_head_major: cache must be BF16, got {:?}",
850 cache.dtype()
851 )));
852 }
853
854 let pipeline = registry.get_pipeline("kv_cache_copy_seq_bf16_to_bf16_head_major", device)?;
855
856 let n_heads_bytes = n_heads.to_ne_bytes();
857 let head_dim_bytes = head_dim.to_ne_bytes();
858 let capacity_bytes = capacity.to_ne_bytes();
859 let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
860 let n_tokens_bytes = n_tokens.to_ne_bytes();
861 let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();
862 let src_seq_len_bytes = src_seq_len.to_ne_bytes();
863
864 use super::encode_helpers::{encode_with_args, KernelArg};
865
866 encode_with_args(
867 encoder,
868 pipeline,
869 &[
870 (0, KernelArg::Buffer(src)),
871 (1, KernelArg::Buffer(cache)),
872 (2, KernelArg::Bytes(&n_heads_bytes)),
873 (3, KernelArg::Bytes(&head_dim_bytes)),
874 (4, KernelArg::Bytes(&capacity_bytes)),
875 (5, KernelArg::Bytes(&seq_pos_start_bytes)),
876 (6, KernelArg::Bytes(&n_tokens_bytes)),
877 (7, KernelArg::Bytes(&src_tok_offset_bytes)),
878 (8, KernelArg::Bytes(&src_seq_len_bytes)),
879 ],
880 MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
881 MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
882 );
883
884 Ok(())
885}