

Function dispatch_kv_cache_copy_batch_f32 

pub fn dispatch_kv_cache_copy_batch_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    src: &MlxBuffer,
    cache: &MlxBuffer,
    n_heads: u32,
    head_dim: u32,
    capacity: u32,
    seq_pos: u32,
) -> Result<()>

Dispatch a batched GPU copy from a source f32 buffer into an f32 KV cache.

Copies ALL heads in one dispatch instead of one dispatch per head.

Source layout: [n_heads * head_dim] flat (one token, all heads).
Cache layout: [n_heads, capacity, head_dim] head-major.
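To make the two layouts concrete, here is a CPU reference sketch of the same copy. It is an illustration under assumptions, not the Metal kernel: the names kv_cache_copy_batch_ref and demo are hypothetical, it uses usize and plain slices instead of MlxBuffer, and it assumes both buffers are contiguous and row-major. Each loop iteration plays the role of one GPU thread, so a single pass covers every head, mirroring the one-dispatch design.

```rust
/// Hypothetical CPU reference for the batched KV-cache copy (not the real kernel).
/// Assumes row-major storage: `src` is [n_heads * head_dim] and
/// `cache` is [n_heads, capacity, head_dim].
fn kv_cache_copy_batch_ref(
    src: &[f32],
    cache: &mut [f32],
    n_heads: usize,
    head_dim: usize,
    capacity: usize,
    seq_pos: usize,
) {
    assert_eq!(src.len(), n_heads * head_dim);
    assert_eq!(cache.len(), n_heads * capacity * head_dim);
    assert!(seq_pos < capacity, "seq_pos must already be wrapped into 0..capacity");

    // One iteration per source element (n_heads * head_dim in total), standing in
    // for one GPU thread each: all heads are handled in the same pass.
    for tid in 0..n_heads * head_dim {
        let h = tid / head_dim; // which head
        let d = tid % head_dim; // element index within that head
        // Head h owns a [capacity, head_dim] slab; write row `seq_pos` of it.
        let dst = (h * capacity + seq_pos) * head_dim + d;
        cache[dst] = src[tid];
    }
}

/// Hypothetical usage: 2 heads, head_dim 3, capacity 4, writing at position 1.
fn demo() -> Vec<f32> {
    let (n_heads, head_dim, capacity) = (2usize, 3usize, 4usize);
    let src: Vec<f32> = (0..n_heads * head_dim).map(|i| i as f32).collect();
    let mut cache = vec![-1.0f32; n_heads * capacity * head_dim];
    kv_cache_copy_batch_ref(&src, &mut cache, n_heads, head_dim, capacity, 1);
    cache
}
```

In the demo, head 0 ([0, 1, 2]) lands at cache[3..6] and head 1 ([3, 4, 5]) at cache[15..18]; every other slot keeps its initial -1.0.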

Arguments

  • encoder - Command encoder to record the dispatch into.
  • registry - Kernel registry (must have kv_cache_copy_batch_f32 registered).
  • device - Metal device for pipeline compilation.
  • src - Source buffer of shape [n_heads * head_dim] (f32).
  • cache - Destination cache buffer of shape [n_heads, capacity, head_dim] (f32, pre-allocated).
  • n_heads - Number of KV heads.
  • head_dim - Elements per head.
  • capacity - Cache capacity (window size or max_seq_len).
  • seq_pos - Write position in the cache, i.e. the token row written within each head's [capacity, head_dim] slab (already wrapped for sliding-window caches).
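The expectations on seq_pos and on the buffer sizes in the argument list can be sketched as follows. Both helpers are hypothetical and not part of this API; the modulo wrap assumes a sliding-window cache is used as a ring buffer of capacity rows, which "already wrapped for sliding" suggests but does not spell out.

```rust
/// Hypothetical helper: turn an absolute token position into the `seq_pos`
/// argument. Assumes a sliding-window cache is a ring buffer of `capacity` rows.
fn cache_write_pos(abs_pos: u32, capacity: u32, sliding: bool) -> Option<u32> {
    if capacity == 0 {
        return None; // no rows to write into
    }
    if sliding {
        Some(abs_pos % capacity) // wrap around: overwrite the oldest row
    } else if abs_pos < capacity {
        Some(abs_pos) // full-length cache: the position is used as-is
    } else {
        None // full-length cache is exhausted
    }
}

/// Hypothetical helper: minimum byte sizes of `src` and `cache` for the f32
/// layouts [n_heads * head_dim] and [n_heads, capacity, head_dim].
/// Computed in u64 so large caches do not overflow u32.
fn required_bytes(n_heads: u32, head_dim: u32, capacity: u32) -> (u64, u64) {
    let elem = std::mem::size_of::<f32>() as u64;
    let src = n_heads as u64 * head_dim as u64 * elem;
    let cache = n_heads as u64 * capacity as u64 * head_dim as u64 * elem;
    (src, cache)
}
```

For example, with 8 KV heads, head_dim 128 and capacity 4096, src needs 4096 bytes and cache needs 16,777,216 bytes (16 MiB).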