// mlx_native/ops/kv_cache_copy.rs

//! KV cache GPU copy dispatch.
//!
//! Copies new K or V data directly from a source GPU buffer into a
//! pre-allocated KV cache buffer at the correct write position, with
//! optional modulo wrapping for sliding-window (ring buffer) caches.
//!
//! This eliminates the CPU round-trip that `append_bf16` requires:
//! instead of GPU -> CPU (`as_slice`) -> CPU (copy loop) -> shared buffer,
//! the GPU copies directly between two shared Metal buffers.

use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;

use super::encode_helpers::{encode_with_args, KernelArg};

/// MSL source for the KV cache copy kernels (embedded at compile time).
pub static KV_CACHE_COPY_SHADER_SOURCE: &str = include_str!("../shaders/kv_cache_copy.metal");

/// Register KV cache copy shader source with the given kernel registry.
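///
/// A minimal flow sketch (the `KernelRegistry::new()` constructor is assumed
/// for illustration; actual construction lives elsewhere in this crate):
///
/// ```ignore
/// let mut registry = KernelRegistry::new();
/// register(&mut registry);
/// // Kernels from kv_cache_copy.metal can now be compiled on demand via
/// // registry.get_pipeline("kv_cache_copy", device).
/// ```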
pub fn register(registry: &mut KernelRegistry) {
    registry.register_source("kv_cache_copy", KV_CACHE_COPY_SHADER_SOURCE);
}

/// Dispatch a GPU copy from a source bf16 buffer into a KV cache buffer.
///
/// Both `src` and `cache` must be bf16 Metal buffers in shared memory.
///
/// # Arguments
///
/// * `encoder`   - Command encoder to record the dispatch into.
/// * `registry`  - Kernel registry (must have kv_cache_copy registered).
/// * `device`    - Metal device for pipeline compilation.
/// * `src`       - Source buffer of shape `[n_new, row_size]` (bf16).
/// * `cache`     - Destination cache buffer (bf16, pre-allocated).
/// * `write_pos` - Starting write position in the cache (token index).
/// * `row_size`  - Elements per token row (`n_kv_heads * head_dim`).
/// * `n_new`     - Number of new tokens to copy.
/// * `cache_cap` - Cache capacity (window size for sliding, max_seq_len for global).
/// * `is_sliding` - Whether to use modulo wrapping (`true` for sliding window).
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if parameters are inconsistent.
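///
/// A CPU sketch of the addressing this dispatch asks the kernel to perform
/// (illustrative only; the authoritative logic lives in
/// `shaders/kv_cache_copy.metal`, and bf16 elements are treated as opaque
/// `u16` words here):
///
/// ```ignore
/// for i in 0..(n_new as usize * row_size as usize) {
///     let token = i / row_size as usize;
///     let elem = i % row_size as usize;
///     let dst_token = if is_sliding {
///         (write_pos as usize + token) % cache_cap as usize
///     } else {
///         write_pos as usize + token
///     };
///     cache[dst_token * row_size as usize + elem] = src[i];
/// }
/// ```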
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    cache: &MlxBuffer,
    write_pos: u32,
    row_size: u32,
    n_new: u32,
    cache_cap: u32,
    is_sliding: bool,
) -> Result<()> {
    if n_new == 0 || row_size == 0 {
        return Ok(()); // Nothing to copy
    }

    let total_elements = (n_new as u64) * (row_size as u64);
    let src_elements = src.element_count() as u64;
    if src_elements < total_elements {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy: src has {} elements but need {} (n_new={} * row_size={})",
            src_elements, total_elements, n_new, row_size
        )));
    }

    // For global (non-sliding) caches, check we won't write past capacity
    if !is_sliding && (write_pos as u64 + n_new as u64) > cache_cap as u64 {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy: global cache overflow: write_pos({}) + n_new({}) > cache_cap({})",
            write_pos, n_new, cache_cap
        )));
    }

    let pipeline = registry.get_pipeline("kv_cache_copy", device)?;

    let is_sliding_val: u32 = if is_sliding { 1 } else { 0 };

    // Pass each scalar via its own set_bytes call, matching buffer indices 2-6.
    let write_pos_bytes = write_pos.to_ne_bytes();
    let row_size_bytes = row_size.to_ne_bytes();
    let n_new_bytes = n_new.to_ne_bytes();
    let cache_cap_bytes = cache_cap.to_ne_bytes();
    let is_sliding_bytes = is_sliding_val.to_ne_bytes();

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(src)),
            (1, KernelArg::Buffer(cache)),
            (2, KernelArg::Bytes(&write_pos_bytes)),
            (3, KernelArg::Bytes(&row_size_bytes)),
            (4, KernelArg::Bytes(&n_new_bytes)),
            (5, KernelArg::Bytes(&cache_cap_bytes)),
            (6, KernelArg::Bytes(&is_sliding_bytes)),
        ],
        MTLSize::new(total_elements, 1, 1),
        MTLSize::new(std::cmp::min(256, total_elements), 1, 1),
    );

    Ok(())
}

/// Dispatch a batched GPU copy from a source f32 buffer into an f32 KV cache.
///
/// Copies ALL heads in one dispatch instead of one dispatch per head.
///
/// Source layout: `[n_heads * head_dim]` flat (one token, all heads).
/// Cache layout: `[n_heads, capacity, head_dim]` head-major.
///
/// # Arguments
///
/// * `encoder`   - Command encoder to record the dispatch into.
/// * `registry`  - Kernel registry (must have kv_cache_copy_batch_f32 registered).
/// * `device`    - Metal device for pipeline compilation.
/// * `src`       - Source buffer of shape `[n_heads * head_dim]` (f32).
/// * `cache`     - Destination cache buffer (f32, pre-allocated).
/// * `n_heads`   - Number of KV heads.
/// * `head_dim`  - Elements per head.
/// * `capacity`  - Cache capacity (window size or max_seq_len).
/// * `seq_pos`   - Write position in cache (already wrapped for sliding).
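///
/// A CPU sketch of the head-major layout mapping (illustrative; the kernel
/// in `shaders/kv_cache_copy.metal` is authoritative):
///
/// ```ignore
/// for h in 0..n_heads as usize {
///     for d in 0..head_dim as usize {
///         let src_idx = h * head_dim as usize + d;
///         let dst_idx = (h * capacity as usize + seq_pos as usize)
///             * head_dim as usize + d;
///         cache[dst_idx] = src[src_idx];
///     }
/// }
/// ```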
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy_batch_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    cache: &MlxBuffer,
    n_heads: u32,
    head_dim: u32,
    capacity: u32,
    seq_pos: u32,
) -> Result<()> {
    if n_heads == 0 || head_dim == 0 {
        return Ok(());
    }

    let total_src = (n_heads as u64) * (head_dim as u64);
    if (src.element_count() as u64) < total_src {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy_batch_f32: src has {} elements but need {} (n_heads={} * head_dim={})",
            src.element_count(), total_src, n_heads, head_dim
        )));
    }

    let pipeline = registry.get_pipeline("kv_cache_copy_batch_f32", device)?;

    let n_heads_bytes = n_heads.to_ne_bytes();
    let head_dim_bytes = head_dim.to_ne_bytes();
    let capacity_bytes = capacity.to_ne_bytes();
    let seq_pos_bytes = seq_pos.to_ne_bytes();

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(src)),
            (1, KernelArg::Buffer(cache)),
            (2, KernelArg::Bytes(&n_heads_bytes)),
            (3, KernelArg::Bytes(&head_dim_bytes)),
            (4, KernelArg::Bytes(&capacity_bytes)),
            (5, KernelArg::Bytes(&seq_pos_bytes)),
        ],
        MTLSize::new(head_dim as u64, n_heads as u64, 1),
        MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
    );

    Ok(())
}

/// Dispatch a GPU copy from a source f32 buffer into an f32 KV cache buffer.
///
/// Identical to `dispatch_kv_cache_copy` but for F32 data (used when the
/// activation pipeline operates in F32 throughout).
///
/// Both `src` and `cache` must be f32 Metal buffers in shared memory.
///
/// # Arguments
///
/// * `encoder`   - Command encoder to record the dispatch into.
/// * `registry`  - Kernel registry (must have kv_cache_copy_f32 registered).
/// * `device`    - Metal device for pipeline compilation.
/// * `src`       - Source buffer of shape `[n_new, row_size]` (f32).
/// * `cache`     - Destination cache buffer (f32, pre-allocated).
/// * `write_pos` - Starting write position in the cache (token index).
/// * `row_size`  - Elements per token row (`n_kv_heads * head_dim`).
/// * `n_new`     - Number of new tokens to copy.
/// * `cache_cap` - Cache capacity (window size for sliding, max_seq_len for global).
/// * `is_sliding` - Whether to use modulo wrapping (`true` for sliding window).
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if parameters are inconsistent.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    cache: &MlxBuffer,
    write_pos: u32,
    row_size: u32,
    n_new: u32,
    cache_cap: u32,
    is_sliding: bool,
) -> Result<()> {
    if n_new == 0 || row_size == 0 {
        return Ok(()); // Nothing to copy
    }

    let total_elements = (n_new as u64) * (row_size as u64);
    let src_elements = src.element_count() as u64;
    if src_elements < total_elements {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy_f32: src has {} elements but need {} (n_new={} * row_size={})",
            src_elements, total_elements, n_new, row_size
        )));
    }

    // For global (non-sliding) caches, check we won't write past capacity
    if !is_sliding && (write_pos as u64 + n_new as u64) > cache_cap as u64 {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy_f32: global cache overflow: write_pos({}) + n_new({}) > cache_cap({})",
            write_pos, n_new, cache_cap
        )));
    }

    let pipeline = registry.get_pipeline("kv_cache_copy_f32", device)?;

    let is_sliding_val: u32 = if is_sliding { 1 } else { 0 };

    let write_pos_bytes = write_pos.to_ne_bytes();
    let row_size_bytes = row_size.to_ne_bytes();
    let n_new_bytes = n_new.to_ne_bytes();
    let cache_cap_bytes = cache_cap.to_ne_bytes();
    let is_sliding_bytes = is_sliding_val.to_ne_bytes();

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(src)),
            (1, KernelArg::Buffer(cache)),
            (2, KernelArg::Bytes(&write_pos_bytes)),
            (3, KernelArg::Bytes(&row_size_bytes)),
            (4, KernelArg::Bytes(&n_new_bytes)),
            (5, KernelArg::Bytes(&cache_cap_bytes)),
            (6, KernelArg::Bytes(&is_sliding_bytes)),
        ],
        MTLSize::new(total_elements, 1, 1),
        MTLSize::new(std::cmp::min(256, total_elements), 1, 1),
    );

    Ok(())
}

/// Dispatch a batched F32→F16 copy from a source f32 buffer into an f16 KV cache.
///
/// Copies ALL heads in one dispatch, casting float→half on write.
/// This halves KV cache memory bandwidth for SDPA reads (bandwidth-bound
/// at batch=1 decode). Reference: llama.cpp stores its KV cache in F16.
///
/// Source layout: `[n_heads * head_dim]` flat F32 (one token, all heads).
/// Cache layout: `[n_heads, capacity, head_dim]` head-major F16.
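///
/// Worked example (illustrative shape, not taken from this crate): with
/// `n_heads = 8` and `head_dim = 256`, one token's KV row is
/// `8 * 256 * 4 B = 8 KiB` in F32 but `8 * 256 * 2 B = 4 KiB` in F16, so
/// every later SDPA read of that row moves half the bytes.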
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy_batch_f32_to_f16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    cache: &MlxBuffer,
    n_heads: u32,
    head_dim: u32,
    capacity: u32,
    seq_pos: u32,
) -> Result<()> {
    if n_heads == 0 || head_dim == 0 {
        return Ok(());
    }

    let total_src = (n_heads as u64) * (head_dim as u64);
    if (src.element_count() as u64) < total_src {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy_batch_f32_to_f16: src has {} elements but need {} (n_heads={} * head_dim={})",
            src.element_count(), total_src, n_heads, head_dim
        )));
    }

    let pipeline = registry.get_pipeline("kv_cache_copy_batch_f32_to_f16", device)?;

    let n_heads_bytes = n_heads.to_ne_bytes();
    let head_dim_bytes = head_dim.to_ne_bytes();
    let capacity_bytes = capacity.to_ne_bytes();
    let seq_pos_bytes = seq_pos.to_ne_bytes();

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(src)),
            (1, KernelArg::Buffer(cache)),
            (2, KernelArg::Bytes(&n_heads_bytes)),
            (3, KernelArg::Bytes(&head_dim_bytes)),
            (4, KernelArg::Bytes(&capacity_bytes)),
            (5, KernelArg::Bytes(&seq_pos_bytes)),
        ],
        MTLSize::new(head_dim as u64, n_heads as u64, 1),
        MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
    );

    Ok(())
}

/// Fused single-position K + V cache copy (F32 source → F32 cache) — DECODE shape.
///
/// ADR-028 iter-145: collapses the 2-dispatch pattern (1× K, 1× V) into a single
/// dispatch, saving one kernel-launch floor cost (~14 µs on Apple GPUs) per layer
/// per token. With gemma4's 30 layers, KV-copy dispatches drop from 60 to 30 per
/// decode token.
///
/// Source layouts: `[n_heads * head_dim]` flat F32 each (one token, all heads).
/// Cache layouts:  `[n_heads, capacity, head_dim]` head-major F32 each.
///
/// Each thread copies one (K, V) element pair at the same coordinates; results
/// are byte-identical to two `dispatch_kv_cache_copy_batch_f32` calls.
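///
/// Usage sketch of the replacement (variable names illustrative):
///
/// ```ignore
/// // Before: two launches per layer per decode token.
/// dispatch_kv_cache_copy_batch_f32(enc, reg, dev, k_src, k_cache,
///     n_heads, head_dim, capacity, seq_pos)?;
/// dispatch_kv_cache_copy_batch_f32(enc, reg, dev, v_src, v_cache,
///     n_heads, head_dim, capacity, seq_pos)?;
///
/// // After: one fused launch, byte-identical cache contents.
/// dispatch_kv_cache_copy_batch_f32_kv_dual(enc, reg, dev, k_src, v_src,
///     k_cache, v_cache, n_heads, head_dim, capacity, seq_pos)?;
/// ```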
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy_batch_f32_kv_dual(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src_k: &MlxBuffer,
    src_v: &MlxBuffer,
    cache_k: &MlxBuffer,
    cache_v: &MlxBuffer,
    n_heads: u32,
    head_dim: u32,
    capacity: u32,
    seq_pos: u32,
) -> Result<()> {
    if n_heads == 0 || head_dim == 0 {
        return Ok(());
    }

    let total_src = (n_heads as u64) * (head_dim as u64);
    if (src_k.element_count() as u64) < total_src {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy_batch_f32_kv_dual: src_k has {} elements but need {} (n_heads={} * head_dim={})",
            src_k.element_count(), total_src, n_heads, head_dim
        )));
    }
    if (src_v.element_count() as u64) < total_src {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy_batch_f32_kv_dual: src_v has {} elements but need {} (n_heads={} * head_dim={})",
            src_v.element_count(), total_src, n_heads, head_dim
        )));
    }

    let pipeline = registry.get_pipeline("kv_cache_copy_batch_f32_kv_dual", device)?;

    let n_heads_bytes = n_heads.to_ne_bytes();
    let head_dim_bytes = head_dim.to_ne_bytes();
    let capacity_bytes = capacity.to_ne_bytes();
    let seq_pos_bytes = seq_pos.to_ne_bytes();

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(src_k)),
            (1, KernelArg::Buffer(src_v)),
            (2, KernelArg::Buffer(cache_k)),
            (3, KernelArg::Buffer(cache_v)),
            (4, KernelArg::Bytes(&n_heads_bytes)),
            (5, KernelArg::Bytes(&head_dim_bytes)),
            (6, KernelArg::Bytes(&capacity_bytes)),
            (7, KernelArg::Bytes(&seq_pos_bytes)),
        ],
        MTLSize::new(head_dim as u64, n_heads as u64, 1),
        MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
    );

    Ok(())
}

/// Fused single-position K + V cache copy (F32 source → F16 cache) — DECODE shape.
///
/// Same as `dispatch_kv_cache_copy_batch_f32_kv_dual` but casts F32→F16 on write
/// for the `use_f16_kv` branch. Halves SDPA-read bandwidth post-write.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy_batch_f32_to_f16_kv_dual(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src_k: &MlxBuffer,
    src_v: &MlxBuffer,
    cache_k: &MlxBuffer,
    cache_v: &MlxBuffer,
    n_heads: u32,
    head_dim: u32,
    capacity: u32,
    seq_pos: u32,
) -> Result<()> {
    if n_heads == 0 || head_dim == 0 {
        return Ok(());
    }

    let total_src = (n_heads as u64) * (head_dim as u64);
    if (src_k.element_count() as u64) < total_src {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy_batch_f32_to_f16_kv_dual: src_k has {} elements but need {} (n_heads={} * head_dim={})",
            src_k.element_count(), total_src, n_heads, head_dim
        )));
    }
    if (src_v.element_count() as u64) < total_src {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy_batch_f32_to_f16_kv_dual: src_v has {} elements but need {} (n_heads={} * head_dim={})",
            src_v.element_count(), total_src, n_heads, head_dim
        )));
    }

    let pipeline = registry.get_pipeline("kv_cache_copy_batch_f32_to_f16_kv_dual", device)?;

    let n_heads_bytes = n_heads.to_ne_bytes();
    let head_dim_bytes = head_dim.to_ne_bytes();
    let capacity_bytes = capacity.to_ne_bytes();
    let seq_pos_bytes = seq_pos.to_ne_bytes();

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(src_k)),
            (1, KernelArg::Buffer(src_v)),
            (2, KernelArg::Buffer(cache_k)),
            (3, KernelArg::Buffer(cache_v)),
            (4, KernelArg::Bytes(&n_heads_bytes)),
            (5, KernelArg::Bytes(&head_dim_bytes)),
            (6, KernelArg::Bytes(&capacity_bytes)),
            (7, KernelArg::Bytes(&seq_pos_bytes)),
        ],
        MTLSize::new(head_dim as u64, n_heads as u64, 1),
        MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
    );

    Ok(())
}

/// Multi-position, all-heads KV cache copy (F32 → F32 cache, batched prefill).
///
/// Source layout: `[n_src_tokens, n_heads, head_dim]` (token-major). The
/// kernel reads tokens `[src_tok_offset, src_tok_offset + n_tokens)` from it.
/// Cache layout:  `[n_heads, capacity, head_dim]` (head-major).
/// Writes absolute positions `[seq_pos_start, seq_pos_start + n_tokens)` into
/// cache slots `dst_pos % capacity`.
///
/// Global-layer contract: the caller ensures `seq_pos_start + n_tokens <= capacity`,
/// so `dst_pos % capacity == dst_pos` and writes are linear. Typical call:
/// `src_tok_offset = 0`, `n_tokens = seq_len`, `seq_pos_start = 0`.
///
/// Sliding-window contract: the caller sets `capacity = sliding_window`,
/// `n_tokens = min(seq_len, capacity)`, `src_tok_offset = seq_len - n_tokens`,
/// `seq_pos_start = seq_len - n_tokens`. This writes the last `n_tokens`
/// source tokens into modular slots exactly once — no intra-dispatch race.
/// The decode side reads via `ring_start = write_pos % capacity`.
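///
/// Sketch of the sliding-window parameter derivation described above
/// (illustrative; `seq_len` and `sliding_window` are caller-side values):
///
/// ```ignore
/// let capacity = sliding_window;
/// let n_tokens = seq_len.min(capacity);
/// let src_tok_offset = seq_len - n_tokens;
/// let seq_pos_start = seq_len - n_tokens;
/// dispatch_kv_cache_copy_seq_f32(enc, reg, dev, src, cache,
///     n_heads, head_dim, capacity, seq_pos_start, n_tokens, src_tok_offset)?;
/// ```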
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy_seq_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    cache: &MlxBuffer,
    n_heads: u32,
    head_dim: u32,
    capacity: u32,
    seq_pos_start: u32,
    n_tokens: u32,
    src_tok_offset: u32,
) -> Result<()> {
    if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
        return Ok(());
    }
    let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
        * (n_heads as u64) * (head_dim as u64);
    if (src.element_count() as u64) < total_src {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy_seq_f32: src has {} elements, need {} ((src_tok_offset={} + n_tokens={}) * n_heads={} * head_dim={})",
            src.element_count(), total_src, src_tok_offset, n_tokens, n_heads, head_dim
        )));
    }

    let pipeline = registry.get_pipeline("kv_cache_copy_seq_f32", device)?;

    let n_heads_bytes = n_heads.to_ne_bytes();
    let head_dim_bytes = head_dim.to_ne_bytes();
    let capacity_bytes = capacity.to_ne_bytes();
    let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
    let n_tokens_bytes = n_tokens.to_ne_bytes();
    let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(src)),
            (1, KernelArg::Buffer(cache)),
            (2, KernelArg::Bytes(&n_heads_bytes)),
            (3, KernelArg::Bytes(&head_dim_bytes)),
            (4, KernelArg::Bytes(&capacity_bytes)),
            (5, KernelArg::Bytes(&seq_pos_start_bytes)),
            (6, KernelArg::Bytes(&n_tokens_bytes)),
            (7, KernelArg::Bytes(&src_tok_offset_bytes)),
        ],
        MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
    );

    Ok(())
}

/// Fused K + V cache copy (F32 source → F32 cache).  Wave P4.11.
///
/// Combines two `dispatch_kv_cache_copy_seq_f32` calls (one for K, one
/// for V) into one dispatch.  Both streams share identical metadata
/// (n_heads, head_dim, capacity, seq_pos_start, n_tokens,
/// src_tok_offset) and are independently addressed in src/cache, so a
/// single thread can copy one (K, V) element pair at the same
/// coordinates.  Saves one dispatch per layer (30 per prefill on Gemma 4).
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy_seq_f32_dual(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src_k: &MlxBuffer,
    src_v: &MlxBuffer,
    cache_k: &MlxBuffer,
    cache_v: &MlxBuffer,
    n_heads: u32,
    head_dim: u32,
    capacity: u32,
    seq_pos_start: u32,
    n_tokens: u32,
    src_tok_offset: u32,
) -> Result<()> {
    if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
        return Ok(());
    }
    let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
        * (n_heads as u64) * (head_dim as u64);
    for (name, b) in [("src_k", src_k), ("src_v", src_v)] {
        if (b.element_count() as u64) < total_src {
            return Err(MlxError::InvalidArgument(format!(
                "kv_cache_copy_seq_f32_dual: {} has {} elements, need {}",
                name, b.element_count(), total_src
            )));
        }
    }

    let pipeline = registry.get_pipeline("kv_cache_copy_seq_f32_kv_dual", device)?;

    let n_heads_bytes = n_heads.to_ne_bytes();
    let head_dim_bytes = head_dim.to_ne_bytes();
    let capacity_bytes = capacity.to_ne_bytes();
    let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
    let n_tokens_bytes = n_tokens.to_ne_bytes();
    let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(src_k)),
            (1, KernelArg::Buffer(src_v)),
            (2, KernelArg::Buffer(cache_k)),
            (3, KernelArg::Buffer(cache_v)),
            (4, KernelArg::Bytes(&n_heads_bytes)),
            (5, KernelArg::Bytes(&head_dim_bytes)),
            (6, KernelArg::Bytes(&capacity_bytes)),
            (7, KernelArg::Bytes(&seq_pos_start_bytes)),
            (8, KernelArg::Bytes(&n_tokens_bytes)),
            (9, KernelArg::Bytes(&src_tok_offset_bytes)),
        ],
        MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
    );

    Ok(())
}

/// Fused K + V cache copy (F32 source → F16 cache).  Wave P4.11
/// f16-cache variant of `dispatch_kv_cache_copy_seq_f32_dual`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy_seq_f32_to_f16_dual(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src_k: &MlxBuffer,
    src_v: &MlxBuffer,
    cache_k: &MlxBuffer,
    cache_v: &MlxBuffer,
    n_heads: u32,
    head_dim: u32,
    capacity: u32,
    seq_pos_start: u32,
    n_tokens: u32,
    src_tok_offset: u32,
) -> Result<()> {
    if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
        return Ok(());
    }
    let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
        * (n_heads as u64) * (head_dim as u64);
    for (name, b) in [("src_k", src_k), ("src_v", src_v)] {
        if (b.element_count() as u64) < total_src {
            return Err(MlxError::InvalidArgument(format!(
                "kv_cache_copy_seq_f32_to_f16_dual: {} has {} elements, need {}",
                name, b.element_count(), total_src
            )));
        }
    }

    let pipeline = registry.get_pipeline("kv_cache_copy_seq_f32_to_f16_kv_dual", device)?;

    let n_heads_bytes = n_heads.to_ne_bytes();
    let head_dim_bytes = head_dim.to_ne_bytes();
    let capacity_bytes = capacity.to_ne_bytes();
    let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
    let n_tokens_bytes = n_tokens.to_ne_bytes();
    let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(src_k)),
            (1, KernelArg::Buffer(src_v)),
            (2, KernelArg::Buffer(cache_k)),
            (3, KernelArg::Buffer(cache_v)),
            (4, KernelArg::Bytes(&n_heads_bytes)),
            (5, KernelArg::Bytes(&head_dim_bytes)),
            (6, KernelArg::Bytes(&capacity_bytes)),
            (7, KernelArg::Bytes(&seq_pos_start_bytes)),
            (8, KernelArg::Bytes(&n_tokens_bytes)),
            (9, KernelArg::Bytes(&src_tok_offset_bytes)),
        ],
        MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
    );

    Ok(())
}

/// Multi-position, all-heads KV cache copy (BF16 source → F32 cache, batched prefill).
///
/// Same layout and semantics as [`dispatch_kv_cache_copy_seq_f32`] — including
/// `src_tok_offset` source slicing and `dst_pos % capacity` ring-wrap for
/// sliding-window layers — but reads bfloat16 from the source and promotes to
/// float32 on write.
///
/// Used in the Phase 2 bf16 activation path, where `pf_k_normed` / `pf_v_normed`
/// become bf16 but the KV cache (used by decode SDPA) stays f32.
///
/// Source layout: `[n_src_tokens, n_heads, head_dim]` bf16.
/// Cache layout:  `[n_heads, capacity, head_dim]`     f32.
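///
/// The bf16 → f32 promotion is a pure bit extension (bf16 is the top 16 bits
/// of an f32), so it is lossless. Equivalent Rust sketch:
///
/// ```ignore
/// fn bf16_bits_to_f32(bits: u16) -> f32 {
///     f32::from_bits((bits as u32) << 16)
/// }
/// ```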
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy_seq_bf16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    cache: &MlxBuffer,
    n_heads: u32,
    head_dim: u32,
    capacity: u32,
    seq_pos_start: u32,
    n_tokens: u32,
    src_tok_offset: u32,
) -> Result<()> {
    if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
        return Ok(());
    }
    // src is bf16 (2 bytes per element)
    let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
        * (n_heads as u64) * (head_dim as u64);
    let src_bytes_needed = total_src * 2; // bf16 = 2 bytes
    if (src.byte_len() as u64) < src_bytes_needed {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy_seq_bf16: src has {} bytes, need {} ((src_tok_offset={} + n_tokens={}) * n_heads={} * head_dim={} * 2)",
            src.byte_len(), src_bytes_needed, src_tok_offset, n_tokens, n_heads, head_dim
        )));
    }

    let pipeline = registry.get_pipeline("kv_cache_copy_seq_bf16", device)?;

    let n_heads_bytes = n_heads.to_ne_bytes();
    let head_dim_bytes = head_dim.to_ne_bytes();
    let capacity_bytes = capacity.to_ne_bytes();
    let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
    let n_tokens_bytes = n_tokens.to_ne_bytes();
    let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(src)),
            (1, KernelArg::Buffer(cache)),
            (2, KernelArg::Bytes(&n_heads_bytes)),
            (3, KernelArg::Bytes(&head_dim_bytes)),
            (4, KernelArg::Bytes(&capacity_bytes)),
            (5, KernelArg::Bytes(&seq_pos_start_bytes)),
            (6, KernelArg::Bytes(&n_tokens_bytes)),
            (7, KernelArg::Bytes(&src_tok_offset_bytes)),
        ],
        MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
    );

    Ok(())
}

/// Multi-position, all-heads KV cache copy (F32 source → F16 cache, batched prefill).
///
/// Same semantics as [`dispatch_kv_cache_copy_seq_f32`] (including
/// `src_tok_offset` source slicing and `dst_pos % capacity` ring-wrap for
/// sliding-window layers) but writes half-precision values into the cache.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy_seq_f32_to_f16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    cache: &MlxBuffer,
    n_heads: u32,
    head_dim: u32,
    capacity: u32,
    seq_pos_start: u32,
    n_tokens: u32,
    src_tok_offset: u32,
) -> Result<()> {
    if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
        return Ok(());
    }
    let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
        * (n_heads as u64) * (head_dim as u64);
    if (src.element_count() as u64) < total_src {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy_seq_f32_to_f16: src has {} elements, need {}",
            src.element_count(), total_src
        )));
    }

    let pipeline = registry.get_pipeline("kv_cache_copy_seq_f32_to_f16", device)?;

    let n_heads_bytes = n_heads.to_ne_bytes();
    let head_dim_bytes = head_dim.to_ne_bytes();
    let capacity_bytes = capacity.to_ne_bytes();
    let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
    let n_tokens_bytes = n_tokens.to_ne_bytes();
    let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(src)),
            (1, KernelArg::Buffer(cache)),
            (2, KernelArg::Bytes(&n_heads_bytes)),
            (3, KernelArg::Bytes(&head_dim_bytes)),
            (4, KernelArg::Bytes(&capacity_bytes)),
            (5, KernelArg::Bytes(&seq_pos_start_bytes)),
            (6, KernelArg::Bytes(&n_tokens_bytes)),
            (7, KernelArg::Bytes(&src_tok_offset_bytes)),
        ],
        MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
    );

    Ok(())
}

/// ADR-030 iter-95: bit-exact BF16→BF16 strided cache copy from `pf_k_perm`
/// (head-major BF16) to `bf16_xlen_cache` (head-major BF16).
///
/// Used by the DFlash spec-decode xlen verify path to persist BF16 K/V
/// across rounds without the F16-intermediate precision drift that
/// iter-92/93 root-caused as the source of Option A's non-toy coherence
/// failures. The copy is bit-identical to what Option C reads from
/// `pf_k_perm` in the same call (a single rounding at head_norm_rope's
/// F32→BF16 output).
///
/// `src` layout: `[n_heads, src_seq_len, head_dim]` BF16 head-major
/// (matches fused_head_norm_rope's bf16 permuted output `pf_k_perm` /
/// `pf_v_perm`).
/// `cache` layout: `[n_heads, capacity, head_dim]` BF16 head-major.
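///
/// A CPU sketch of the strided index mapping (illustrative, assuming the same
/// `dst_pos % capacity` ring-wrap as the other seq_* variants; the kernel is
/// authoritative):
///
/// ```ignore
/// for h in 0..n_heads as usize {
///     for t in 0..n_tokens as usize {
///         for d in 0..head_dim as usize {
///             // src is head-major with per-head stride src_seq_len * head_dim.
///             let src_idx = (h * src_seq_len as usize
///                 + (src_tok_offset as usize + t)) * head_dim as usize + d;
///             // cache is head-major with per-head stride capacity * head_dim.
///             let dst_pos = (seq_pos_start as usize + t) % capacity as usize;
///             let dst_idx = (h * capacity as usize + dst_pos) * head_dim as usize + d;
///             cache[dst_idx] = src[src_idx];
///         }
///     }
/// }
/// ```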
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy_seq_bf16_to_bf16_head_major(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    cache: &MlxBuffer,
    n_heads: u32,
    head_dim: u32,
    capacity: u32,
    seq_pos_start: u32,
    n_tokens: u32,
    src_tok_offset: u32,
    src_seq_len: u32,
) -> Result<()> {
    if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
        return Ok(());
    }
    let total_src = (n_heads as u64) * (src_seq_len as u64) * (head_dim as u64);
    if (src.element_count() as u64) < total_src {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy_seq_bf16_to_bf16_head_major: src has {} elements, need {} \
             ({} heads × {} src_seq_len × {} head_dim)",
            src.element_count(), total_src, n_heads, src_seq_len, head_dim
        )));
    }
    if src.dtype() != crate::DType::BF16 {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy_seq_bf16_to_bf16_head_major: src must be BF16, got {:?}",
            src.dtype()
        )));
    }
    if cache.dtype() != crate::DType::BF16 {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy_seq_bf16_to_bf16_head_major: cache must be BF16, got {:?}",
            cache.dtype()
        )));
    }

    let pipeline = registry.get_pipeline("kv_cache_copy_seq_bf16_to_bf16_head_major", device)?;

    let n_heads_bytes = n_heads.to_ne_bytes();
    let head_dim_bytes = head_dim.to_ne_bytes();
    let capacity_bytes = capacity.to_ne_bytes();
    let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
    let n_tokens_bytes = n_tokens.to_ne_bytes();
    let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();
    let src_seq_len_bytes = src_seq_len.to_ne_bytes();

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(src)),
            (1, KernelArg::Buffer(cache)),
            (2, KernelArg::Bytes(&n_heads_bytes)),
            (3, KernelArg::Bytes(&head_dim_bytes)),
            (4, KernelArg::Bytes(&capacity_bytes)),
            (5, KernelArg::Bytes(&seq_pos_start_bytes)),
            (6, KernelArg::Bytes(&n_tokens_bytes)),
            (7, KernelArg::Bytes(&src_tok_offset_bytes)),
            (8, KernelArg::Bytes(&src_seq_len_bytes)),
        ],
        MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
    );

    Ok(())
}