// mlx-native 0.6.2
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
// Documentation
#include <metal_stdlib>
using namespace metal;

// =============================================================================
// qkv_split_f32 — Strided split of a fused QKV tensor into Q/K/V outputs.
//
// Input layout (per token, contiguous f32):
//
//   qkv[t, :] = [ Q (q_sp) | K (k_sp) | V (v_sp) ]   (length = qkv_ch)
//
// where q_sp = n_k_heads * d_k, k_sp = n_k_heads * d_k, v_sp = n_v_heads * d_v,
// and qkv_ch = q_sp + k_sp + v_sp.
//
// Outputs:
//   q[t * q_sp + i] = qkv[t * qkv_ch + i]                       for i in 0..q_sp
//   k[t * k_sp + i] = qkv[t * qkv_ch + q_sp + i]                for i in 0..k_sp
//   v[t * v_sp + i] = qkv[t * qkv_ch + q_sp + k_sp + i]         for i in 0..v_sp
//
// Buffer bindings:
//   buffer(0): qkv    — float (input,  fused)
//   buffer(1): q      — float (output)
//   buffer(2): k      — float (output)
//   buffer(3): v      — float (output)
//   buffer(4): params — uint [5] — {seq, q_sp, k_sp, v_sp, qkv_ch}
//
// Grid:        (qkv_ch, seq, 1)   — one thread per element of the input
// Threadgroup: (min(256, qkv_ch), 1, 1)
//
// Each thread reads one input element and writes it to exactly one of {q,k,v}
// based on the column index. Perfect coalescing on the input read; the output
// write is strided (3 disjoint output regions) but each region remains
// contiguous within itself, so writes are still coalesced within a head-span.
//
// W-5b.18 worker (ADR-005, 2026-04-27): replaces the hf2q-side
// download_f32 + CPU triple-loop + 3× upload_f32 round-trip in
// `gpu_delta_net.rs::layer_qkv_deinterleave` (838 ms / 17.5 ms per layer).
// =============================================================================

// Launch parameters for qkv_split_f32, read by the kernel from buffer(4).
// Five consecutive 32-bit unsigned values; the field order here is the order
// the kernel expects in that buffer: {seq, q_sp, k_sp, v_sp, qkv_ch}.
struct QkvSplitParams {
    uint seq;     // number of rows (tokens); threads with row >= seq exit early
    uint q_sp;    // elements per row copied to q (columns [0, q_sp))
    uint k_sp;    // elements per row copied to k (columns [q_sp, q_sp + k_sp))
    uint v_sp;    // elements per row copied to v (columns from q_sp + k_sp onward)
    uint qkv_ch;  // row length of the fused input; expected to equal q_sp + k_sp + v_sp
};

// qkv_split_f32 — copies one element of the fused QKV input into q, k, or v.
//
// Each thread handles the input element at (row = pos.y, col = pos.x):
//   col in [0, q_sp)                         -> q[row * q_sp + col]
//   col in [q_sp, q_sp + k_sp)               -> k[row * k_sp + (col - q_sp)]
//   col in [q_sp + k_sp, q_sp + k_sp + v_sp) -> v[row * v_sp + (col - q_sp - k_sp)]
//
// Parameters:
//   qkv    — fused input, row-major with row length params.qkv_ch
//   q/k/v  — outputs, row-major with row lengths q_sp / k_sp / v_sp
//   params — sizes, see QkvSplitParams
//   pos    — thread position in the dispatch grid (x = column, y = row)
//
// Threads outside (qkv_ch, seq) exit without touching memory, so the grid may
// be rounded up past the tensor extent.
//
// Flat indices are computed in 64 bits: row * row_length is a product of two
// 32-bit values and would wrap silently in `uint` arithmetic once a tensor
// exceeds 2^32 elements.
kernel void qkv_split_f32(
    device const float*       qkv    [[buffer(0)]],
    device float*             q      [[buffer(1)]],
    device float*             k      [[buffer(2)]],
    device float*             v      [[buffer(3)]],
    constant QkvSplitParams&  params [[buffer(4)]],
    uint2 pos [[thread_position_in_grid]]
) {
    const uint col = pos.x;       // index within qkv_ch
    const uint row = pos.y;       // token index (0..seq)

    if (col >= params.qkv_ch || row >= params.seq) return;

    // Widen once; every product below is then evaluated in 64 bits.
    const ulong row64 = ulong(row);

    const float val = qkv[row64 * params.qkv_ch + col];

    // First column that belongs to V. Summed in 64 bits so the boundary itself
    // cannot wrap.
    const ulong qk_boundary = ulong(params.q_sp) + ulong(params.k_sp);

    if (col < params.q_sp) {
        // Q region.
        q[row64 * params.q_sp + col] = val;
    } else if (col < qk_boundary) {
        // K region.
        const ulong k_off = col - params.q_sp;
        k[row64 * params.k_sp + k_off] = val;
    } else {
        // V region. The offset is bounded by v_sp so that an inconsistent
        // params block (qkv_ch > q_sp + k_sp + v_sp) cannot spill into the
        // next row's slots or past the end of v; such columns are skipped.
        const ulong v_off = col - qk_boundary;
        if (v_off < params.v_sp) {
            v[row64 * params.v_sp + v_off] = val;
        }
    }
}