// mlx-native 0.1.1
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon.
// Documentation
#include <metal_stdlib>
using namespace metal;

/// Kernel A: gather_bench_nibble
///
/// Simulates TurboQuant SDPA reads: unpacks 4-bit nibble indices from a
/// nibble-packed buffer, then gathers the corresponding centroid row.
///
/// Layout:
///   packed    : [capacity, head_dim/2]  uint8 — low nibble = even coord, high nibble = odd coord
///   centroids : [16, head_dim]          float — pre-rotated centroid table (4-bit = 16 entries)
///   out       : [capacity, head_dim]    float — gathered output
///
/// Grid: 2D — x = coord index (head_dim), y = position index (capacity)
/// Threadgroup: [256, 1, 1]
kernel void gather_bench_nibble(
    device const uint8_t* packed        [[buffer(0)]],
    constant float*        centroids    [[buffer(1)]],
    constant uint&         capacity     [[buffer(2)]],
    constant uint&         head_dim     [[buffer(3)]],
    device float*          out          [[buffer(4)]],
    uint2 tid [[thread_position_in_grid]]
) {
    // NOTE(review): the packed row stride below is head_dim / 2 (integer
    // division), so this kernel assumes head_dim is even. With an odd
    // head_dim, the last coordinate of each row would read the first byte
    // of the following row.
    const uint coord = tid.x; // coordinate within the head
    const uint pos   = tid.y; // token position

    // Grid-tail guard in both dimensions.
    const bool in_range = (pos < capacity) && (coord < head_dim);
    if (!in_range) {
        return;
    }

    // Two 4-bit codes share each byte, so one packed row is head_dim/2 bytes.
    const uint    row_bytes = head_dim / 2u;
    const uint8_t pair      = packed[pos * row_bytes + (coord >> 1u)];

    // Even coordinates use the low nibble (shift 0), odd ones the high nibble (shift 4).
    const uint shift = (coord & 1u) * 4u;
    const uint code  = (uint(pair) >> shift) & 0xFu;

    // Copy the selected centroid's coordinate into the output row.
    out[pos * head_dim + coord] = centroids[code * head_dim + coord];
}

/// Kernel B: gather_bench_f16_seq
///
/// Baseline sequential F16 reads: reads every element of the F16 KV cache
/// and widens to F32 — this is the workload a standard F16 SDPA performs.
///
/// Layout:
///   cache : [capacity, head_dim]  half  — F16 KV cache
///   out   : [capacity, head_dim]  float — widened output
///
/// Grid: 2D — x = coord index (head_dim), y = position index (capacity)
/// Threadgroup: [256, 1, 1]
kernel void gather_bench_f16_seq(
    device const half* cache        [[buffer(0)]],
    constant uint&     capacity     [[buffer(1)]],
    constant uint&     head_dim     [[buffer(2)]],
    device float*      out          [[buffer(3)]],
    uint2 tid [[thread_position_in_grid]]
) {
    // Threads past either grid extent (x >= head_dim, y >= capacity) write nothing.
    const bool inside = (tid.y < capacity) && (tid.x < head_dim);
    if (inside) {
        // Input and output share the same row-major [capacity, head_dim] indexing.
        const uint flat = tid.y * head_dim + tid.x;
        out[flat] = static_cast<float>(cache[flat]);
    }
}