pub fn dispatch_kv_cache_copy_seq_f32(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &DeviceRef,
src: &MlxBuffer,
cache: &MlxBuffer,
n_heads: u32,
head_dim: u32,
capacity: u32,
seq_pos_start: u32,
n_tokens: u32,
src_tok_offset: u32,
) -> Result<()>
Multi-position, all-heads KV cache copy (F32 → F32 cache, batched prefill).
Source layout: [n_src_tokens, n_heads, head_dim] (token-major). The
kernel reads [src_tok_offset, src_tok_offset + n_tokens) from it.
Cache layout: [n_heads, capacity, head_dim] (head-major).
Writes absolute positions [seq_pos_start, seq_pos_start + n_tokens) into the
cache; each absolute position dst_pos lands in ring slot dst_pos % capacity.
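The two layouts above imply a simple per-element addressing rule. A minimal sketch of that rule, with hypothetical helper names (src_index, cache_index) that mirror the documented layouts rather than the real kernel's internals:

```rust
// Token-major source: [n_src_tokens, n_heads, head_dim].
fn src_index(tok: u32, head: u32, d: u32, n_heads: u32, head_dim: u32) -> usize {
    ((tok * n_heads + head) * head_dim + d) as usize
}

// Head-major ring cache: [n_heads, capacity, head_dim]; the absolute
// position wraps into a ring slot via dst_pos % capacity.
fn cache_index(dst_pos: u32, head: u32, d: u32, capacity: u32, head_dim: u32) -> usize {
    let slot = dst_pos % capacity;
    ((head * capacity + slot) * head_dim + d) as usize
}

fn main() {
    // Absolute position 5 in a capacity-4 ring wraps to slot 1;
    // for head 1 with head_dim 8 that is element (1 * 4 + 1) * 8 = 40.
    assert_eq!(cache_index(5, 1, 0, 4, 8), 40);
    // Token 2, head 1, element 3 in a 2-head, dim-8 source: (2*2 + 1)*8 + 3 = 43.
    assert_eq!(src_index(2, 1, 3, 2, 8), 43);
}
```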
Global-layer contract: caller sets seq_pos_start + n_tokens <= capacity
so dst_pos % capacity == dst_pos and writes are linear. Typical call:
src_tok_offset = 0, n_tokens = seq_len, seq_pos_start = 0.
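A sketch of the typical global-layer call parameters, under the contract above. The helper name global_prefill_params is hypothetical, not part of this API:

```rust
// Global layer: the whole prompt fits in the cache, so dst_pos % capacity
// == dst_pos and all writes are linear.
// Returns (src_tok_offset, n_tokens, seq_pos_start).
fn global_prefill_params(seq_len: u32, capacity: u32) -> (u32, u32, u32) {
    // Caller-side contract: seq_pos_start + n_tokens <= capacity.
    assert!(seq_len <= capacity, "global contract violated");
    (0, seq_len, 0)
}

fn main() {
    // A 10-token prompt into a 4096-slot cache: copy everything from
    // source token 0 to absolute positions 0..10.
    assert_eq!(global_prefill_params(10, 4096), (0, 10, 0));
}
```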
Sliding-window contract: caller sets capacity = sliding_window,
n_tokens = min(seq_len, capacity), src_tok_offset = seq_len - n_tokens,
seq_pos_start = seq_len - n_tokens. This writes the last n_tokens
source tokens into modular slots exactly once — no intra-dispatch race.
Decode side reads via ring_start = write_pos % capacity.
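The sliding-window parameter arithmetic and the decode-side ring start can be sketched as follows; sliding_prefill_params is an illustrative name, not part of this API:

```rust
// Sliding-window layer: keep only the last `sliding_window` source tokens.
// Returns (src_tok_offset, n_tokens, seq_pos_start); per the contract,
// src_tok_offset == seq_pos_start == seq_len - n_tokens.
fn sliding_prefill_params(seq_len: u32, sliding_window: u32) -> (u32, u32, u32) {
    let n_tokens = seq_len.min(sliding_window);
    let start = seq_len - n_tokens;
    (start, n_tokens, start)
}

fn main() {
    // 10-token prompt, window of 4: copy source tokens 6..10 to absolute
    // positions 6..10, i.e. ring slots 2, 3, 0, 1 — each written once.
    assert_eq!(sliding_prefill_params(10, 4), (6, 4, 6));

    // Decode then reads the ring starting at write_pos % capacity.
    let write_pos = 10u32;
    assert_eq!(write_pos % 4, 2);
}
```

Because n_tokens never exceeds the window, each ring slot receives at most one write per dispatch, which is what rules out the intra-dispatch race.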