

Function dispatch_kv_cache_copy_seq_bf16 

pub fn dispatch_kv_cache_copy_seq_bf16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    src: &MlxBuffer,
    cache: &MlxBuffer,
    n_heads: u32,
    head_dim: u32,
    capacity: u32,
    seq_pos_start: u32,
    n_tokens: u32,
    src_tok_offset: u32,
) -> Result<()>

Multi-position, all-heads KV cache copy (BF16 source → F32 cache, batched prefill).

Same layout and semantics as dispatch_kv_cache_copy_seq_f32 (including src_tok_offset source slicing and the dst_pos % capacity ring-wrap used by sliding-window layers), but reads bfloat16 from the source and promotes each value to float32 on write.
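The two per-element operations described above can be sketched on the CPU as follows. This is an illustration, not the kernel source: bf16_bits_to_f32 and ring_slot are hypothetical helpers, and the ring-slot formula assumes dst_pos for the t-th token of the batch is seq_pos_start + t. The bf16 to f32 promotion is exact because bfloat16 is the upper 16 bits of an IEEE-754 binary32 value.

```rust
/// Hypothetical helper: widen a raw bfloat16 bit pattern to f32.
/// bf16 keeps the sign, the 8-bit exponent and the top 7 mantissa bits of an
/// f32, so shifting left by 16 and reinterpreting is lossless.
fn bf16_bits_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Hypothetical helper: cache slot for the t-th token of the batch, assuming
/// dst_pos = seq_pos_start + t, wrapped modulo `capacity` (ring buffer for
/// sliding-window layers). The sum is computed in u64 so it cannot overflow u32.
fn ring_slot(seq_pos_start: u32, t: u32, capacity: u32) -> u32 {
    ((seq_pos_start as u64 + t as u64) % capacity as u64) as u32
}

fn main() {
    // 0x3F80 is the bf16 encoding of 1.0, and 0xC040 encodes -3.0.
    println!("{} {}", bf16_bits_to_f32(0x3F80), bf16_bits_to_f32(0xC040));
    // With capacity 8, absolute position 6 + 3 = 9 lands in slot 1.
    println!("{}", ring_slot(6, 3, 8));
}
```

Widening in u64 before taking the modulus keeps the wrap correct even when seq_pos_start is near u32::MAX.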

Used in the Phase 2 bf16 activation path, where pf_k_normed / pf_v_normed are bf16 but the KV cache (read by decode SDPA) stays f32.

Source layout: [n_src_tokens, n_heads, head_dim] bf16. Cache layout: [n_heads, capacity, head_dim] f32.
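Taken together, the layouts above imply a simple index mapping. The sketch below is a CPU reference over flat slices, not the dispatched Metal kernel. It assumes that token t of the batch is read from source row src_tok_offset + t and written to cache slot (seq_pos_start + t) % capacity. The function name kv_cache_copy_seq_bf16_ref and the demo helper are hypothetical, and bf16 values are passed as raw u16 bit patterns.

```rust
/// Hypothetical CPU reference for the documented layouts:
///   src:   [n_src_tokens, n_heads, head_dim] bf16 (raw u16 bits)
///   cache: [n_heads, capacity, head_dim]     f32
/// Assumes token t reads source row src_tok_offset + t and writes cache slot
/// (seq_pos_start + t) % capacity.
#[allow(clippy::too_many_arguments)]
fn kv_cache_copy_seq_bf16_ref(
    src: &[u16],
    cache: &mut [f32],
    n_heads: usize,
    head_dim: usize,
    capacity: usize,
    seq_pos_start: usize,
    n_tokens: usize,
    src_tok_offset: usize,
) {
    for t in 0..n_tokens {
        let src_tok = src_tok_offset + t; // source slicing
        let dst_pos = (seq_pos_start + t) % capacity; // ring-wrap
        for h in 0..n_heads {
            let src_base = (src_tok * n_heads + h) * head_dim;
            let dst_base = (h * capacity + dst_pos) * head_dim;
            for d in 0..head_dim {
                // Exact bf16 -> f32 promotion: bf16 is the top half of an f32.
                cache[dst_base + d] = f32::from_bits((src[src_base + d] as u32) << 16);
            }
        }
    }
}

/// Hypothetical demo: 3 source tokens, 2 heads, head_dim 2, capacity 4.
fn demo() -> Vec<f32> {
    let (n_heads, head_dim, capacity) = (2usize, 2usize, 4usize);
    let n_src_tokens = 3;
    // Source element i holds the bf16 encoding of i (small integers are exact in bf16).
    let src: Vec<u16> = (0..n_src_tokens * n_heads * head_dim)
        .map(|i| ((i as f32).to_bits() >> 16) as u16)
        .collect();
    // -1.0 marks cache entries that are never written.
    let mut cache = vec![-1.0f32; n_heads * capacity * head_dim];
    // Copy source tokens 1 and 2 to positions 3 and 4; position 4 wraps to slot 0.
    kv_cache_copy_seq_bf16_ref(&src, &mut cache, n_heads, head_dim, capacity, 3, 2, 1);
    cache
}

fn main() {
    println!("{:?}", demo());
}
```

Because the cache is head-major ([n_heads, capacity, head_dim]), each token's per-head rows are scattered across n_heads separate strips of the cache, whereas the token-major source keeps them contiguous.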