// mlx-native 0.1.0
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon.
// (Documentation header — this file contains the KV-cache copy kernels.)
#include <metal_stdlib>
using namespace metal;

/// Copy new K or V data from a source GPU buffer into the KV cache buffer
/// at the correct write position, with optional modulo wrapping for sliding
/// window (ring buffer) caches.
///
/// Grid: 1D, total_elements = n_new * row_size
/// Each thread copies one element.
/// Append new K or V rows (half precision) into the KV cache at `write_pos`,
/// wrapping modulo `cache_cap` when the cache is a sliding-window ring buffer.
///
/// Grid: 1D over n_new * row_size elements; one element per thread.
/// Source is dense [n_new, row_size], so the flat thread id is also the
/// source index.
kernel void kv_cache_copy(
    device const half* src       [[buffer(0)]],   // source K or V [n_new, row_size]
    device half*       cache     [[buffer(1)]],   // destination cache buffer
    constant uint&     write_pos [[buffer(2)]],   // starting write position in cache
    constant uint&     row_size  [[buffer(3)]],   // n_kv_heads * head_dim
    constant uint&     n_new     [[buffer(4)]],   // number of new tokens
    constant uint&     cache_cap [[buffer(5)]],   // window size (sliding) or max_seq_len (global)
    constant uint&     is_sliding [[buffer(6)]],  // 1 = use modulo wrapping, 0 = linear
    uint tid [[thread_position_in_grid]]
) {
    // Guard the grid tail: the dispatch may be rounded up past the data.
    if (tid >= n_new * row_size) return;

    const uint token = tid / row_size;
    const uint col   = tid - token * row_size;   // == tid % row_size

    // Destination row: linear append, or ring-buffer slot when sliding.
    uint row = write_pos + token;
    if (is_sliding != 0) {
        row %= cache_cap;
    }

    // src is contiguous [n_new, row_size], so src[tid] == src[token*row_size + col].
    cache[row * row_size + col] = src[tid];
}

/// Float32 variant of kv_cache_copy for F32 KV caches.
///
/// Identical logic to the half variant but operates on float data.
/// Used when the activation pipeline is F32 throughout (no bf16 casting).
/// Float32 variant of kv_cache_copy: append new K or V rows into the cache
/// at `write_pos`, wrapping modulo `cache_cap` for sliding-window caches.
///
/// Used when the activation pipeline is F32 throughout (no bf16 casting).
/// Grid: 1D over n_new * row_size elements; one element per thread.
kernel void kv_cache_copy_f32(
    device const float* src       [[buffer(0)]],   // source K or V [n_new, row_size]
    device float*       cache     [[buffer(1)]],   // destination cache buffer
    constant uint&     write_pos [[buffer(2)]],   // starting write position in cache
    constant uint&     row_size  [[buffer(3)]],   // n_kv_heads * head_dim
    constant uint&     n_new     [[buffer(4)]],   // number of new tokens
    constant uint&     cache_cap [[buffer(5)]],   // window size (sliding) or max_seq_len (global)
    constant uint&     is_sliding [[buffer(6)]],  // 1 = use modulo wrapping, 0 = linear
    uint tid [[thread_position_in_grid]]
) {
    // Guard the grid tail: the dispatch may be rounded up past the data.
    if (tid >= n_new * row_size) return;

    const uint token = tid / row_size;
    const uint col   = tid - token * row_size;   // == tid % row_size

    // Destination row: linear append, or ring-buffer slot when sliding.
    uint row = write_pos + token;
    if (is_sliding != 0) {
        row %= cache_cap;
    }

    // src is contiguous [n_new, row_size], so src[tid] == src[token*row_size + col].
    cache[row * row_size + col] = src[tid];
}

/// Batched KV cache copy — copies ALL heads in one dispatch.
///
/// Source layout: [n_heads * head_dim] flat (one token, all heads).
/// Cache layout: [n_heads, capacity, head_dim] head-major.
///
/// Grid: 2D — x=element within head (head_dim), y=head index (n_heads).
/// Replaces n_heads separate kv_cache_copy_f32 dispatches with 1.
/// Batched KV cache copy — writes one token's worth of K or V for ALL heads
/// in a single dispatch, replacing n_heads separate kv_cache_copy_f32 calls.
///
/// Source layout: [n_heads * head_dim] flat (one token, all heads).
/// Cache layout: [n_heads, capacity, head_dim] head-major.
///
/// Grid: 2D — x = element within head (head_dim), y = head index (n_heads).
/// `seq_pos` must already be wrapped into [0, capacity) by the host.
kernel void kv_cache_copy_batch_f32(
    device const float* src       [[buffer(0)]],   // [n_heads * head_dim] flat
    device float*       cache     [[buffer(1)]],   // [n_heads, capacity, head_dim]
    constant uint&     n_heads   [[buffer(2)]],   // number of KV heads
    constant uint&     head_dim  [[buffer(3)]],   // elements per head
    constant uint&     capacity  [[buffer(4)]],   // cache capacity (ring buffer size)
    constant uint&     seq_pos   [[buffer(5)]],   // write position (already wrapped)
    uint2 tid [[thread_position_in_grid]]          // x=elem, y=head
) {
    const uint e = tid.x;
    const uint h = tid.y;
    if (e >= head_dim || h >= n_heads) return;

    // Base pointers for this head's source row and its slot in the cache:
    // dst row = head h, sequence slot seq_pos.
    device const float* src_row = src + h * head_dim;
    device float*       dst_row = cache + (h * capacity + seq_pos) * head_dim;

    dst_row[e] = src_row[e];
}

/// Batched KV cache copy with F32→F16 cast — copies ALL heads in one dispatch.
///
/// Source layout: [n_heads * head_dim] flat F32 (one token, all heads).
/// Cache layout: [n_heads, capacity, head_dim] head-major F16.
///
/// Casts float → half on write, halving cache memory bandwidth for SDPA reads.
/// Reference: llama.cpp stores KV cache in F16 for bandwidth-bound decode SDPA.
///
/// Grid: 2D — x=element within head (head_dim), y=head index (n_heads).
/// Batched KV cache copy with an F32 → F16 narrowing cast on write — one
/// dispatch covers all heads for one token.
///
/// Source layout: [n_heads * head_dim] flat F32 (one token, all heads).
/// Cache layout: [n_heads, capacity, head_dim] head-major F16.
///
/// Storing the cache in F16 halves memory bandwidth for SDPA reads
/// (same approach llama.cpp uses for bandwidth-bound decode SDPA).
///
/// Grid: 2D — x = element within head (head_dim), y = head index (n_heads).
/// `seq_pos` must already be wrapped into [0, capacity) by the host.
kernel void kv_cache_copy_batch_f32_to_f16(
    device const float* src       [[buffer(0)]],   // [n_heads * head_dim] flat F32
    device half*        cache     [[buffer(1)]],   // [n_heads, capacity, head_dim] F16
    constant uint&     n_heads   [[buffer(2)]],   // number of KV heads
    constant uint&     head_dim  [[buffer(3)]],   // elements per head
    constant uint&     capacity  [[buffer(4)]],   // cache capacity (ring buffer size)
    constant uint&     seq_pos   [[buffer(5)]],   // write position (already wrapped)
    uint2 tid [[thread_position_in_grid]]          // x=elem, y=head
) {
    const uint e = tid.x;
    const uint h = tid.y;
    if (e >= head_dim || h >= n_heads) return;

    // Base pointers for this head's source row and its slot in the cache:
    // dst row = head h, sequence slot seq_pos.
    device const float* src_row = src + h * head_dim;
    device half*        dst_row = cache + (h * capacity + seq_pos) * head_dim;

    // Narrowing cast on the store — values are assumed to fit half range.
    dst_row[e] = half(src_row[e]);
}