// mlx-native 0.3.2
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon.
// Documentation
#include <metal_stdlib>
using namespace metal;

/// Copy new K or V rows from a source GPU buffer into the KV cache buffer
/// at the correct write position, with optional modulo wrapping for sliding
/// window (ring buffer) caches.
///
/// Layouts (row-major, F16):
///   src   — [n_new, row_size]
///   cache — indexed as [row, row_size]; rows are cache positions.
///
/// Grid: 1D, at least n_new * row_size threads. Each thread copies one
/// element; surplus threads exit immediately.
///
/// Defensive guards (no effect on well-formed dispatches):
///   * cache_cap == 0: nothing is written. This avoids `% 0`, whose result
///     is unspecified and would otherwise be used as a write index.
///   * linear mode, destination row >= cache_cap: that row is skipped rather
///     than written past the end of the cache.
///   * sliding mode, n_new > cache_cap: only the newest cache_cap rows are
///     written. Older rows would map to the same ring slots within this
///     dispatch and race with the newer ones; skipping them yields the same
///     final cache contents as copying the rows one after another.
kernel void kv_cache_copy(
    device const half* src       [[buffer(0)]],   // source K or V [n_new, row_size]
    device half*       cache     [[buffer(1)]],   // destination cache buffer
    constant uint&     write_pos [[buffer(2)]],   // starting write position in cache
    constant uint&     row_size  [[buffer(3)]],   // n_kv_heads * head_dim
    constant uint&     n_new     [[buffer(4)]],   // number of new tokens
    constant uint&     cache_cap [[buffer(5)]],   // window size (sliding) or max_seq_len (global)
    constant uint&     is_sliding [[buffer(6)]],  // 1 = use modulo wrapping, 0 = linear
    uint tid [[thread_position_in_grid]]
) {
    // Also covers row_size == 0: total_elements is 0, so every thread exits
    // before the division below.
    uint total_elements = n_new * row_size;
    if (tid >= total_elements) return;

    // An empty cache has no valid destination (and `% 0` is unspecified).
    if (cache_cap == 0) return;

    uint token_idx = tid / row_size;
    uint elem_idx  = tid % row_size;

    uint dst_pos;
    if (is_sliding != 0) {
        // More rows than ring slots: rows t and t + cache_cap target the same
        // slot. Only the newest cache_cap rows may write, one per slot.
        if (n_new > cache_cap && token_idx < n_new - cache_cap) return;
        dst_pos = (write_pos + token_idx) % cache_cap;
    } else {
        dst_pos = write_pos + token_idx;
        // Linear cache: never write beyond the last row.
        if (dst_pos >= cache_cap) return;
    }

    // token_idx * row_size + elem_idx == tid, so the source index is tid.
    cache[dst_pos * row_size + elem_idx] = src[tid];
}

/// Float32 variant of kv_cache_copy for F32 KV caches.
///
/// Copies new K or V rows from `src` into `cache` starting at `write_pos`,
/// wrapping modulo `cache_cap` when `is_sliding` is non-zero.
/// Used when the activation pipeline is F32 throughout (no bf16 casting).
///
/// Layouts (row-major, F32):
///   src   — [n_new, row_size]
///   cache — indexed as [row, row_size]; rows are cache positions.
///
/// Grid: 1D, at least n_new * row_size threads. Each thread copies one
/// element; surplus threads exit immediately.
///
/// Defensive guards (no effect on well-formed dispatches):
///   * cache_cap == 0: nothing is written. This avoids `% 0`, whose result
///     is unspecified and would otherwise be used as a write index.
///   * linear mode, destination row >= cache_cap: that row is skipped rather
///     than written past the end of the cache.
///   * sliding mode, n_new > cache_cap: only the newest cache_cap rows are
///     written. Older rows would map to the same ring slots within this
///     dispatch and race with the newer ones; skipping them yields the same
///     final cache contents as copying the rows one after another.
kernel void kv_cache_copy_f32(
    device const float* src       [[buffer(0)]],   // source K or V [n_new, row_size]
    device float*       cache     [[buffer(1)]],   // destination cache buffer
    constant uint&     write_pos [[buffer(2)]],   // starting write position in cache
    constant uint&     row_size  [[buffer(3)]],   // n_kv_heads * head_dim
    constant uint&     n_new     [[buffer(4)]],   // number of new tokens
    constant uint&     cache_cap [[buffer(5)]],   // window size (sliding) or max_seq_len (global)
    constant uint&     is_sliding [[buffer(6)]],  // 1 = use modulo wrapping, 0 = linear
    uint tid [[thread_position_in_grid]]
) {
    // Also covers row_size == 0: total_elements is 0, so every thread exits
    // before the division below.
    uint total_elements = n_new * row_size;
    if (tid >= total_elements) return;

    // An empty cache has no valid destination (and `% 0` is unspecified).
    if (cache_cap == 0) return;

    uint token_idx = tid / row_size;
    uint elem_idx  = tid % row_size;

    uint dst_pos;
    if (is_sliding != 0) {
        // More rows than ring slots: rows t and t + cache_cap target the same
        // slot. Only the newest cache_cap rows may write, one per slot.
        if (n_new > cache_cap && token_idx < n_new - cache_cap) return;
        dst_pos = (write_pos + token_idx) % cache_cap;
    } else {
        dst_pos = write_pos + token_idx;
        // Linear cache: never write beyond the last row.
        if (dst_pos >= cache_cap) return;
    }

    // token_idx * row_size + elem_idx == tid, so the source index is tid.
    cache[dst_pos * row_size + elem_idx] = src[tid];
}

/// Single-token, all-heads KV cache write (F32 → F32) in one dispatch.
///
/// Source layout: [n_heads * head_dim] flat — one token, heads back to back.
/// Cache layout:  [n_heads, capacity, head_dim] — head-major.
///
/// Every head's `head_dim` values are stored at row `seq_pos` of that head's
/// slab. `seq_pos` is used as-is: the caller is expected to pass a position
/// that is already wrapped into the cache.
///
/// Grid: 2D — x = element within head (head_dim), y = head index (n_heads).
/// Threads outside that rectangle do nothing.
/// Replaces n_heads separate kv_cache_copy_f32 dispatches with 1.
kernel void kv_cache_copy_batch_f32(
    device const float* src       [[buffer(0)]],   // [n_heads * head_dim] flat
    device float*       cache     [[buffer(1)]],   // [n_heads, capacity, head_dim]
    constant uint&     n_heads   [[buffer(2)]],   // number of KV heads
    constant uint&     head_dim  [[buffer(3)]],   // elements per head
    constant uint&     capacity  [[buffer(4)]],   // cache capacity (ring buffer size)
    constant uint&     seq_pos   [[buffer(5)]],   // write position (already wrapped)
    uint2 tid [[thread_position_in_grid]]          // x=elem, y=head
) {
    // Discard threads that fall outside the [n_heads, head_dim] rectangle.
    if (tid.y >= n_heads || tid.x >= head_dim) return;

    // Offset of this head's slab, then of the target row inside that slab.
    const uint slab_offset = tid.y * capacity * head_dim;
    const uint row_offset  = seq_pos * head_dim;

    // Source is dense: heads back to back, head_dim elements each.
    const float value = src[tid.y * head_dim + tid.x];

    cache[slab_offset + row_offset + tid.x] = value;
}

/// Single-token, all-heads KV cache write with an F32→F16 cast, in one
/// dispatch.
///
/// Source layout: [n_heads * head_dim] flat F32 — one token, heads back to back.
/// Cache layout:  [n_heads, capacity, head_dim] head-major F16.
///
/// Each value is converted float → half as it is stored, halving cache memory
/// bandwidth for SDPA reads.
/// Reference: llama.cpp stores KV cache in F16 for bandwidth-bound decode SDPA.
///
/// `seq_pos` is used as-is: the caller is expected to pass a position that is
/// already wrapped into the cache.
///
/// Grid: 2D — x = element within head (head_dim), y = head index (n_heads).
/// Threads outside that rectangle do nothing.
kernel void kv_cache_copy_batch_f32_to_f16(
    device const float* src       [[buffer(0)]],   // [n_heads * head_dim] flat F32
    device half*        cache     [[buffer(1)]],   // [n_heads, capacity, head_dim] F16
    constant uint&     n_heads   [[buffer(2)]],   // number of KV heads
    constant uint&     head_dim  [[buffer(3)]],   // elements per head
    constant uint&     capacity  [[buffer(4)]],   // cache capacity (ring buffer size)
    constant uint&     seq_pos   [[buffer(5)]],   // write position (already wrapped)
    uint2 tid [[thread_position_in_grid]]          // x=elem, y=head
) {
    // Discard threads that fall outside the [n_heads, head_dim] rectangle.
    if (tid.y >= n_heads || tid.x >= head_dim) return;

    // Offset of this head's slab, then of the target row inside that slab.
    const uint slab_offset = tid.y * capacity * head_dim;
    const uint row_offset  = seq_pos * head_dim;

    // Source is dense: heads back to back, head_dim elements each.
    const float value = src[tid.y * head_dim + tid.x];

    // Narrow to half on store.
    cache[slab_offset + row_offset + tid.x] = half(value);
}

/// Multi-position, all-heads KV cache copy (F32 source → F32 cache).
///
/// Source layout: [n_src_tokens, n_heads, head_dim] — token-major, from
/// head_norm+RoPE. The caller selects the surviving slice of the source
/// via `src_tok_offset`; only tokens
/// `[src_tok_offset, src_tok_offset + n_tokens)` are read.
/// Cache layout:  [n_heads, capacity, head_dim] — head-major dense_kvs.
/// Writes absolute positions `[seq_pos_start, seq_pos_start + n_tokens)`
/// into cache slots `dst_pos % capacity`.
///
/// Global (non-sliding) layer contract: caller ensures
/// `seq_pos_start + n_tokens <= capacity` so `dst_pos % capacity == dst_pos`
/// and behaviour is linear/no-wrap.
///
/// Sliding-window contract: caller sets `capacity = sliding_window`,
/// `n_tokens = min(seq_len, capacity)`, `src_tok_offset = seq_len - n_tokens`,
/// `seq_pos_start = seq_len - n_tokens`. This writes the last `n_tokens`
/// source tokens into modular slots, each exactly once — no intra-dispatch
/// race. Decode side reads via `ring_start = write_pos % capacity`
/// (`hf2q:src/serve/forward_mlx.rs`), consistent with this layout.
///
/// Defensive guards (no effect when the contracts above are honoured):
///   * capacity == 0: nothing is written. This avoids `% 0`, whose result is
///     unspecified and would otherwise be used as a write index.
///   * n_tokens > capacity: tokens `tok` and `tok + capacity` would target
///     the same slot in one dispatch. Only the newest `capacity` tokens are
///     written, so each slot is still written exactly once and the result
///     matches copying the tokens in order.
///
/// Grid: 3D — x=elem within head, y=head, z=token.
kernel void kv_cache_copy_seq_f32(
    device const float* src       [[buffer(0)]],   // [n_src_tokens, n_heads, head_dim] F32
    device float*       cache     [[buffer(1)]],   // [n_heads, capacity, head_dim] F32
    constant uint&     n_heads   [[buffer(2)]],
    constant uint&     head_dim  [[buffer(3)]],
    constant uint&     capacity  [[buffer(4)]],
    constant uint&     seq_pos_start [[buffer(5)]],
    constant uint&     n_tokens  [[buffer(6)]],
    constant uint&     src_tok_offset [[buffer(7)]],
    uint3 tid [[thread_position_in_grid]]
) {
    uint elem = tid.x;
    uint head = tid.y;
    uint tok  = tid.z;
    if (head >= n_heads || elem >= head_dim || tok >= n_tokens) return;

    // An empty cache has no valid slot (and `% 0` is unspecified).
    if (capacity == 0) return;

    // Over-long dispatch: let only the newest `capacity` tokens write, so no
    // two threads of this dispatch store to the same slot.
    if (n_tokens > capacity && tok < n_tokens - capacity) return;

    uint src_tok = src_tok_offset + tok;
    uint src_idx = src_tok * (n_heads * head_dim) + head * head_dim + elem;
    uint dst_pos = seq_pos_start + tok;
    uint slot    = dst_pos % capacity;   // ring slot; equals dst_pos when no wrap occurs
    uint dst_idx = head * capacity * head_dim + slot * head_dim + elem;
    cache[dst_idx] = src[src_idx];
}

/// Multi-position, all-heads KV cache copy (F32 source → F16 cache).
///
/// Source layout: [n_src_tokens, n_heads, head_dim] F32 — token-major. Only
/// tokens `[src_tok_offset, src_tok_offset + n_tokens)` are read.
/// Cache layout:  [n_heads, capacity, head_dim] F16 — head-major.
/// Writes absolute positions `[seq_pos_start, seq_pos_start + n_tokens)`
/// into cache slots `dst_pos % capacity` (ring-wrap for sliding layers),
/// casting each value to half on write.
///
/// Defensive guards (no effect on well-formed dispatches, i.e. capacity > 0
/// and n_tokens <= capacity):
///   * capacity == 0: nothing is written. This avoids `% 0`, whose result is
///     unspecified and would otherwise be used as a write index.
///   * n_tokens > capacity: tokens `tok` and `tok + capacity` would target
///     the same slot in one dispatch. Only the newest `capacity` tokens are
///     written, so each slot is written exactly once and the result matches
///     copying the tokens in order.
///
/// Grid: 3D — x=elem within head, y=head, z=token.
kernel void kv_cache_copy_seq_f32_to_f16(
    device const float* src       [[buffer(0)]],   // [n_src_tokens, n_heads, head_dim] F32
    device half*        cache     [[buffer(1)]],   // [n_heads, capacity, head_dim] F16
    constant uint&     n_heads   [[buffer(2)]],
    constant uint&     head_dim  [[buffer(3)]],
    constant uint&     capacity  [[buffer(4)]],
    constant uint&     seq_pos_start [[buffer(5)]],
    constant uint&     n_tokens  [[buffer(6)]],
    constant uint&     src_tok_offset [[buffer(7)]],
    uint3 tid [[thread_position_in_grid]]
) {
    uint elem = tid.x;
    uint head = tid.y;
    uint tok  = tid.z;
    if (head >= n_heads || elem >= head_dim || tok >= n_tokens) return;

    // An empty cache has no valid slot (and `% 0` is unspecified).
    if (capacity == 0) return;

    // Over-long dispatch: let only the newest `capacity` tokens write, so no
    // two threads of this dispatch store to the same slot.
    if (n_tokens > capacity && tok < n_tokens - capacity) return;

    uint src_tok = src_tok_offset + tok;
    uint src_idx = src_tok * (n_heads * head_dim) + head * head_dim + elem;
    uint dst_pos = seq_pos_start + tok;
    uint slot    = dst_pos % capacity;   // ring slot; equals dst_pos when no wrap occurs
    uint dst_idx = head * capacity * head_dim + slot * head_dim + elem;
    cache[dst_idx] = half(src[src_idx]);   // narrow to half on store
}