Function dispatch_kv_cache_copy 

pub fn dispatch_kv_cache_copy(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    src: &MlxBuffer,
    cache: &MlxBuffer,
    write_pos: u32,
    row_size: u32,
    n_new: u32,
    cache_cap: u32,
    is_sliding: bool,
) -> Result<()>
Dispatch a GPU copy from a source bf16 buffer into a KV cache buffer.

Both src and cache must be bf16 Metal buffers in shared memory.

§Arguments

  • encoder - Command encoder to record the dispatch into.
  • registry - Kernel registry (must have kv_cache_copy registered).
  • device - Metal device for pipeline compilation.
  • src - Source buffer of shape [n_new, row_size] (bf16).
  • cache - Destination cache buffer (bf16, pre-allocated).
  • write_pos - Starting write position in the cache (token index).
  • row_size - Elements per token row (n_kv_heads * head_dim).
  • n_new - Number of new tokens to copy.
  • cache_cap - Cache capacity (window size for sliding, max_seq_len for global).
  • is_sliding - Whether to use modulo wrapping (true for sliding window).
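
The argument descriptions above suggest a simple row mapping, which a CPU-side reference can make concrete. The sketch below is not part of the crate: `kv_cache_copy_reference` and `sliding_window_demo` are hypothetical names, bf16 values are carried as raw `u16` bit patterns, and the wrap rule (`(write_pos + i) % cache_cap` when `is_sliding` is true) and the bounds checks are assumptions inferred from the parameter docs, not the kernel's confirmed behavior.

```rust
/// Hypothetical CPU reference for the copy; bf16 elements are stored as raw u16 bits.
fn kv_cache_copy_reference(
    src: &[u16],
    cache: &mut [u16],
    write_pos: usize,
    row_size: usize,
    n_new: usize,
    cache_cap: usize,
    is_sliding: bool,
) -> Result<(), String> {
    // Assumed consistency checks; the real function may check more or less.
    if row_size == 0 || cache_cap == 0 {
        return Err("row_size and cache_cap must be non-zero".to_string());
    }
    if src.len() < n_new * row_size {
        return Err("src holds fewer than n_new rows".to_string());
    }
    if cache.len() < cache_cap * row_size {
        return Err("cache holds fewer than cache_cap rows".to_string());
    }
    if !is_sliding && write_pos + n_new > cache_cap {
        return Err("write would run past the end of a global cache".to_string());
    }

    for i in 0..n_new {
        let pos = write_pos + i;
        // Sliding caches wrap around the window; global caches write linearly.
        let dst_row = if is_sliding { pos % cache_cap } else { pos };
        let s = i * row_size;
        let d = dst_row * row_size;
        cache[d..d + row_size].copy_from_slice(&src[s..s + row_size]);
    }
    Ok(())
}

/// Writes 3 tokens at position 3 into a 4-row sliding window with 2 elements per row.
fn sliding_window_demo() -> [u16; 8] {
    let src = [1u16, 1, 2, 2, 3, 3];
    let mut cache = [0u16; 8];
    kv_cache_copy_reference(&src, &mut cache, 3, 2, 3, 4, true).unwrap();
    cache
}
```

In the demo, token positions 3, 4, and 5 land in cache rows 3, 0, and 1 respectively, so the two newest tokens overwrite the oldest rows of the window while row 2 is left untouched.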

§Errors

Returns MlxError::InvalidArgument if parameters are inconsistent.