// mlx_native/ops/kv_cache_copy.rs
1//! KV cache GPU copy dispatch.
2//!
3//! Copies new K or V data directly from a source GPU buffer into a
4//! pre-allocated KV cache buffer at the correct write position, with
5//! optional modulo wrapping for sliding window (ring buffer) caches.
6//!
7//! This eliminates the CPU round-trip that `append_bf16` requires:
8//! instead of GPU -> CPU (as_slice) -> CPU (copy loop) -> shared buffer,
9//! the GPU copies directly between two shared Metal buffers.
10
11use metal::MTLSize;
12
13use crate::buffer::MlxBuffer;
14use crate::encoder::CommandEncoder;
15use crate::error::{MlxError, Result};
16use crate::kernel_registry::KernelRegistry;
17
18use super::encode_helpers::{encode_with_args, KernelArg};
19
20/// MSL source for the KV cache copy kernel (embedded at compile time).
21pub static KV_CACHE_COPY_SHADER_SOURCE: &str = include_str!("../shaders/kv_cache_copy.metal");
22
23/// Register KV cache copy shader source with the given kernel registry.
24pub fn register(registry: &mut KernelRegistry) {
25    registry.register_source("kv_cache_copy", KV_CACHE_COPY_SHADER_SOURCE);
26}
27
28/// Dispatch a GPU copy from a source bf16 buffer into a KV cache buffer.
29///
30/// Both `src` and `cache` must be bf16 Metal buffers in shared memory.
31///
32/// # Arguments
33///
34/// * `encoder`   - Command encoder to record the dispatch into.
35/// * `registry`  - Kernel registry (must have kv_cache_copy registered).
36/// * `device`    - Metal device for pipeline compilation.
37/// * `src`       - Source buffer of shape `[n_new, row_size]` (bf16).
38/// * `cache`     - Destination cache buffer (bf16, pre-allocated).
39/// * `write_pos` - Starting write position in the cache (token index).
40/// * `row_size`  - Elements per token row (`n_kv_heads * head_dim`).
41/// * `n_new`     - Number of new tokens to copy.
42/// * `cache_cap` - Cache capacity (window size for sliding, max_seq_len for global).
43/// * `is_sliding`- Whether to use modulo wrapping (`true` for sliding window).
44///
45/// # Errors
46///
47/// Returns `MlxError::InvalidArgument` if parameters are inconsistent.
48#[allow(clippy::too_many_arguments)]
49pub fn dispatch_kv_cache_copy(
50    encoder: &mut CommandEncoder,
51    registry: &mut KernelRegistry,
52    device: &metal::DeviceRef,
53    src: &MlxBuffer,
54    cache: &MlxBuffer,
55    write_pos: u32,
56    row_size: u32,
57    n_new: u32,
58    cache_cap: u32,
59    is_sliding: bool,
60) -> Result<()> {
61    if n_new == 0 || row_size == 0 {
62        return Ok(()); // Nothing to copy
63    }
64
65    let total_elements = (n_new as u64) * (row_size as u64);
66    let src_elements = src.element_count() as u64;
67    if src_elements < total_elements {
68        return Err(MlxError::InvalidArgument(format!(
69            "kv_cache_copy: src has {} elements but need {} (n_new={} * row_size={})",
70            src_elements, total_elements, n_new, row_size
71        )));
72    }
73
74    // For global (non-sliding) caches, check we won't write past capacity
75    if !is_sliding && (write_pos as u64 + n_new as u64) > cache_cap as u64 {
76        return Err(MlxError::InvalidArgument(format!(
77            "kv_cache_copy: global cache overflow: write_pos({}) + n_new({}) > cache_cap({})",
78            write_pos, n_new, cache_cap
79        )));
80    }
81
82    let pipeline = registry.get_pipeline("kv_cache_copy", device)?;
83
84    let is_sliding_val: u32 = if is_sliding { 1 } else { 0 };
85
86    // Pass each scalar as individual set_bytes calls matching buffer indices 2-6
87    let write_pos_bytes = write_pos.to_ne_bytes();
88    let row_size_bytes = row_size.to_ne_bytes();
89    let n_new_bytes = n_new.to_ne_bytes();
90    let cache_cap_bytes = cache_cap.to_ne_bytes();
91    let is_sliding_bytes = is_sliding_val.to_ne_bytes();
92
93    encode_with_args(
94        encoder,
95        pipeline,
96        &[
97            (0, KernelArg::Buffer(src)),
98            (1, KernelArg::Buffer(cache)),
99            (2, KernelArg::Bytes(&write_pos_bytes)),
100            (3, KernelArg::Bytes(&row_size_bytes)),
101            (4, KernelArg::Bytes(&n_new_bytes)),
102            (5, KernelArg::Bytes(&cache_cap_bytes)),
103            (6, KernelArg::Bytes(&is_sliding_bytes)),
104        ],
105        MTLSize::new(total_elements, 1, 1),
106        MTLSize::new(std::cmp::min(256, total_elements), 1, 1),
107    );
108
109    Ok(())
110}
111
112/// Dispatch a batched GPU copy from a source f32 buffer into a f32 KV cache.
113///
114/// Copies ALL heads in one dispatch instead of one dispatch per head.
115///
116/// Source layout: `[n_heads * head_dim]` flat (one token, all heads).
117/// Cache layout: `[n_heads, capacity, head_dim]` head-major.
118///
119/// # Arguments
120///
121/// * `encoder`   - Command encoder to record the dispatch into.
122/// * `registry`  - Kernel registry (must have kv_cache_copy_batch_f32 registered).
123/// * `device`    - Metal device for pipeline compilation.
124/// * `src`       - Source buffer of shape `[n_heads * head_dim]` (f32).
125/// * `cache`     - Destination cache buffer (f32, pre-allocated).
126/// * `n_heads`   - Number of KV heads.
127/// * `head_dim`  - Elements per head.
128/// * `capacity`  - Cache capacity (window size or max_seq_len).
129/// * `seq_pos`   - Write position in cache (already wrapped for sliding).
130#[allow(clippy::too_many_arguments)]
131pub fn dispatch_kv_cache_copy_batch_f32(
132    encoder: &mut CommandEncoder,
133    registry: &mut KernelRegistry,
134    device: &metal::DeviceRef,
135    src: &MlxBuffer,
136    cache: &MlxBuffer,
137    n_heads: u32,
138    head_dim: u32,
139    capacity: u32,
140    seq_pos: u32,
141) -> Result<()> {
142    if n_heads == 0 || head_dim == 0 {
143        return Ok(());
144    }
145
146    let total_src = (n_heads as u64) * (head_dim as u64);
147    if (src.element_count() as u64) < total_src {
148        return Err(MlxError::InvalidArgument(format!(
149            "kv_cache_copy_batch_f32: src has {} elements but need {} (n_heads={} * head_dim={})",
150            src.element_count(), total_src, n_heads, head_dim
151        )));
152    }
153
154    let pipeline = registry.get_pipeline("kv_cache_copy_batch_f32", device)?;
155
156    let n_heads_bytes = n_heads.to_ne_bytes();
157    let head_dim_bytes = head_dim.to_ne_bytes();
158    let capacity_bytes = capacity.to_ne_bytes();
159    let seq_pos_bytes = seq_pos.to_ne_bytes();
160
161    use super::encode_helpers::{encode_with_args, KernelArg};
162
163    encode_with_args(
164        encoder,
165        pipeline,
166        &[
167            (0, KernelArg::Buffer(src)),
168            (1, KernelArg::Buffer(cache)),
169            (2, KernelArg::Bytes(&n_heads_bytes)),
170            (3, KernelArg::Bytes(&head_dim_bytes)),
171            (4, KernelArg::Bytes(&capacity_bytes)),
172            (5, KernelArg::Bytes(&seq_pos_bytes)),
173        ],
174        MTLSize::new(head_dim as u64, n_heads as u64, 1),
175        MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
176    );
177
178    Ok(())
179}
180
181/// Dispatch a GPU copy from a source f32 buffer into a f32 KV cache buffer.
182///
183/// Identical to `dispatch_kv_cache_copy` but for F32 data (used when the
184/// activation pipeline operates in F32 throughout).
185///
186/// Both `src` and `cache` must be f32 Metal buffers in shared memory.
187///
188/// # Arguments
189///
190/// * `encoder`   - Command encoder to record the dispatch into.
191/// * `registry`  - Kernel registry (must have kv_cache_copy_f32 registered).
192/// * `device`    - Metal device for pipeline compilation.
193/// * `src`       - Source buffer of shape `[n_new, row_size]` (f32).
194/// * `cache`     - Destination cache buffer (f32, pre-allocated).
195/// * `write_pos` - Starting write position in the cache (token index).
196/// * `row_size`  - Elements per token row (`n_kv_heads * head_dim`).
197/// * `n_new`     - Number of new tokens to copy.
198/// * `cache_cap` - Cache capacity (window size for sliding, max_seq_len for global).
199/// * `is_sliding`- Whether to use modulo wrapping (`true` for sliding window).
200///
201/// # Errors
202///
203/// Returns `MlxError::InvalidArgument` if parameters are inconsistent.
204#[allow(clippy::too_many_arguments)]
205pub fn dispatch_kv_cache_copy_f32(
206    encoder: &mut CommandEncoder,
207    registry: &mut KernelRegistry,
208    device: &metal::DeviceRef,
209    src: &MlxBuffer,
210    cache: &MlxBuffer,
211    write_pos: u32,
212    row_size: u32,
213    n_new: u32,
214    cache_cap: u32,
215    is_sliding: bool,
216) -> Result<()> {
217    if n_new == 0 || row_size == 0 {
218        return Ok(()); // Nothing to copy
219    }
220
221    let total_elements = (n_new as u64) * (row_size as u64);
222    let src_elements = src.element_count() as u64;
223    if src_elements < total_elements {
224        return Err(MlxError::InvalidArgument(format!(
225            "kv_cache_copy_f32: src has {} elements but need {} (n_new={} * row_size={})",
226            src_elements, total_elements, n_new, row_size
227        )));
228    }
229
230    // For global (non-sliding) caches, check we won't write past capacity
231    if !is_sliding && (write_pos as u64 + n_new as u64) > cache_cap as u64 {
232        return Err(MlxError::InvalidArgument(format!(
233            "kv_cache_copy_f32: global cache overflow: write_pos({}) + n_new({}) > cache_cap({})",
234            write_pos, n_new, cache_cap
235        )));
236    }
237
238    let pipeline = registry.get_pipeline("kv_cache_copy_f32", device)?;
239
240    let is_sliding_val: u32 = if is_sliding { 1 } else { 0 };
241
242    let write_pos_bytes = write_pos.to_ne_bytes();
243    let row_size_bytes = row_size.to_ne_bytes();
244    let n_new_bytes = n_new.to_ne_bytes();
245    let cache_cap_bytes = cache_cap.to_ne_bytes();
246    let is_sliding_bytes = is_sliding_val.to_ne_bytes();
247
248    encode_with_args(
249        encoder,
250        pipeline,
251        &[
252            (0, KernelArg::Buffer(src)),
253            (1, KernelArg::Buffer(cache)),
254            (2, KernelArg::Bytes(&write_pos_bytes)),
255            (3, KernelArg::Bytes(&row_size_bytes)),
256            (4, KernelArg::Bytes(&n_new_bytes)),
257            (5, KernelArg::Bytes(&cache_cap_bytes)),
258            (6, KernelArg::Bytes(&is_sliding_bytes)),
259        ],
260        MTLSize::new(total_elements, 1, 1),
261        MTLSize::new(std::cmp::min(256, total_elements), 1, 1),
262    );
263
264    Ok(())
265}
266
267/// Dispatch a batched F32→F16 copy from a source f32 buffer into an f16 KV cache.
268///
269/// Copies ALL heads in one dispatch, casting float→half on write.
270/// This halves KV cache memory bandwidth for SDPA reads (bandwidth-bound
271/// at batch=1 decode). Reference: llama.cpp stores KV cache in F16.
272///
273/// Source layout: `[n_heads * head_dim]` flat F32 (one token, all heads).
274/// Cache layout: `[n_heads, capacity, head_dim]` head-major F16.
275#[allow(clippy::too_many_arguments)]
276pub fn dispatch_kv_cache_copy_batch_f32_to_f16(
277    encoder: &mut CommandEncoder,
278    registry: &mut KernelRegistry,
279    device: &metal::DeviceRef,
280    src: &MlxBuffer,
281    cache: &MlxBuffer,
282    n_heads: u32,
283    head_dim: u32,
284    capacity: u32,
285    seq_pos: u32,
286) -> Result<()> {
287    if n_heads == 0 || head_dim == 0 {
288        return Ok(());
289    }
290
291    let total_src = (n_heads as u64) * (head_dim as u64);
292    if (src.element_count() as u64) < total_src {
293        return Err(MlxError::InvalidArgument(format!(
294            "kv_cache_copy_batch_f32_to_f16: src has {} elements but need {} (n_heads={} * head_dim={})",
295            src.element_count(), total_src, n_heads, head_dim
296        )));
297    }
298
299    let pipeline = registry.get_pipeline("kv_cache_copy_batch_f32_to_f16", device)?;
300
301    let n_heads_bytes = n_heads.to_ne_bytes();
302    let head_dim_bytes = head_dim.to_ne_bytes();
303    let capacity_bytes = capacity.to_ne_bytes();
304    let seq_pos_bytes = seq_pos.to_ne_bytes();
305
306    use super::encode_helpers::{encode_with_args, KernelArg};
307
308    encode_with_args(
309        encoder,
310        pipeline,
311        &[
312            (0, KernelArg::Buffer(src)),
313            (1, KernelArg::Buffer(cache)),
314            (2, KernelArg::Bytes(&n_heads_bytes)),
315            (3, KernelArg::Bytes(&head_dim_bytes)),
316            (4, KernelArg::Bytes(&capacity_bytes)),
317            (5, KernelArg::Bytes(&seq_pos_bytes)),
318        ],
319        MTLSize::new(head_dim as u64, n_heads as u64, 1),
320        MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
321    );
322
323    Ok(())
324}