

Function dispatch_kv_cache_copy_seq_f32 

pub fn dispatch_kv_cache_copy_seq_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    src: &MlxBuffer,
    cache: &MlxBuffer,
    n_heads: u32,
    head_dim: u32,
    capacity: u32,
    seq_pos_start: u32,
    n_tokens: u32,
    src_tok_offset: u32,
) -> Result<()>

Multi-position, all-heads KV cache copy (F32 source → F32 cache) for batched prefill: copies n_tokens token positions for every head in a single dispatch.

Source layout: [n_src_tokens, n_heads, head_dim] (token-major). The kernel reads the token range [src_tok_offset, src_tok_offset + n_tokens) from it. Cache layout: [n_heads, capacity, head_dim] (head-major). Source token src_tok_offset + i, for i in [0, n_tokens), is written to absolute position dst_pos = seq_pos_start + i, which is stored in cache slot dst_pos % capacity.
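The index mapping above can be illustrated with a CPU reference loop. This is a sketch, not the GPU kernel: `kv_cache_copy_seq_ref` is a hypothetical name, plain `f32` slices stand in for `MlxBuffer`, and `usize` replaces `u32` for indexing.

```rust
/// Hypothetical CPU reference for the copy's index math (not the real kernel).
/// src:   [n_src_tokens, n_heads, head_dim]  (token-major)
/// cache: [n_heads, capacity, head_dim]      (head-major)
#[allow(clippy::too_many_arguments)]
pub fn kv_cache_copy_seq_ref(
    src: &[f32],
    cache: &mut [f32],
    n_heads: usize,
    head_dim: usize,
    capacity: usize,
    seq_pos_start: usize,
    n_tokens: usize,
    src_tok_offset: usize,
) {
    for t in 0..n_tokens {
        let src_tok = src_tok_offset + t; // token row read from src
        let dst_pos = seq_pos_start + t; // absolute sequence position
        let slot = dst_pos % capacity; // ring slot inside the cache
        for h in 0..n_heads {
            // Start of this token/head vector in each layout.
            let s = (src_tok * n_heads + h) * head_dim;
            let d = (h * capacity + slot) * head_dim;
            cache[d..d + head_dim].copy_from_slice(&src[s..s + head_dim]);
        }
    }
}
```

Each (token, head) pair moves one contiguous head_dim vector, so a GPU version can assign one thread or threadgroup per pair without any cross-thread coordination.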

Global-layer contract: the caller ensures seq_pos_start + n_tokens <= capacity, so dst_pos % capacity == dst_pos and writes are linear. Typical call: src_tok_offset = 0, n_tokens = seq_len, seq_pos_start = 0.
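A minimal sketch of the global-layer argument choice, assuming a prompt of seq_len tokens; `global_layer_args` is a hypothetical helper, not part of this API. It returns the three arguments in signature order and checks the no-wrap contract.

```rust
/// Hypothetical helper: global-layer arguments for a prompt of `seq_len` tokens.
/// Returns (seq_pos_start, n_tokens, src_tok_offset), in signature order.
pub fn global_layer_args(seq_len: u32, capacity: u32) -> (u32, u32, u32) {
    let (seq_pos_start, n_tokens, src_tok_offset) = (0, seq_len, 0);
    // Contract: no wrap-around, so slot == absolute position for every token.
    assert!(
        seq_pos_start + n_tokens <= capacity,
        "prompt of {seq_len} tokens exceeds global cache capacity {capacity}"
    );
    (seq_pos_start, n_tokens, src_tok_offset)
}
```

For example, a 10-token prompt with a 16-slot cache yields (0, 10, 0), and tokens land in slots 0..10 in order.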

Sliding-window contract: the caller sets capacity = sliding_window, n_tokens = min(seq_len, capacity), src_tok_offset = seq_len - n_tokens, and seq_pos_start = seq_len - n_tokens. This writes the last n_tokens source tokens into modular slots exactly once: because n_tokens <= capacity, the n_tokens consecutive positions map to distinct slots, so there is no intra-dispatch race. The decode side reads via ring_start = write_pos % capacity.
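The sliding-window arithmetic can be checked with a small sketch. `sliding_window_args`, `slots_written`, and `ring_read_order` are hypothetical helpers; in particular, `ring_read_order` assumes write_pos is the next absolute position to be written and that the ring is full, which this page does not state explicitly.

```rust
/// Hypothetical helper: sliding-window arguments (seq_pos_start, n_tokens, src_tok_offset).
pub fn sliding_window_args(seq_len: u32, window: u32) -> (u32, u32, u32) {
    let n_tokens = seq_len.min(window); // at most one full window
    let start = seq_len - n_tokens; // first of the last n_tokens tokens
    (start, n_tokens, start) // source offset equals absolute position
}

/// Cache slots written by one dispatch, in source order.
/// Distinct whenever n_tokens <= capacity.
pub fn slots_written(seq_pos_start: u32, n_tokens: u32, capacity: u32) -> Vec<u32> {
    (seq_pos_start..seq_pos_start + n_tokens)
        .map(|pos| pos % capacity)
        .collect()
}

/// Slot order for reading a full ring oldest-to-newest, ASSUMING `write_pos`
/// is the next absolute position to be written (its slot holds the oldest entry).
pub fn ring_read_order(write_pos: u32, capacity: u32) -> Vec<u32> {
    let ring_start = write_pos % capacity;
    (0..capacity).map(|i| (ring_start + i) % capacity).collect()
}
```

With seq_len = 10 and a window of 4, positions 6..10 land in slots [2, 3, 0, 1]; starting a read at 10 % 4 = 2 visits the same slots oldest-to-newest.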