// mlx-native 0.6.2 — Pure-Rust Metal GPU compute library for MLX-compatible
// inference on Apple Silicon.

#include <metal_stdlib>
using namespace metal;

// =============================================================================
// repeat_tiled_f32 — Tiled GQA broadcast of a [T, Hg, K] F32 tensor into a
//                    [T, H, K] F32 output, where dst[t, h, k] = src[t, h % Hg, k].
//
// This is the GPU analog of llama.cpp's `ggml_repeat_4d` op as used in
// `delta-net-base.cpp:357 k = ggml_repeat(ctx0, k, s)` for the GQA broadcast
// of K (and the matching Q expansion in hf2q's `apply_gated_delta_net_chunk`
// wrapper, which uses TILED replication to match Qwen3.6's GGUF tensor
// layout — see `gpu_delta_net.rs:834-866` for the convention rationale).
//
// Layout (row-major contiguous F32, no strides):
//   src:  [T, Hg, K]   — element count = T * Hg * K
//   dst:  [T, H,  K]   — element count = T * H  * K   (H % Hg == 0)
//
// Mapping (TILED, per W-5b.4 / `project_qwen36_gqa_tiled_vs_block`):
//   dst[t, h, k] = src[t, h % Hg, k]
//
// Buffer bindings:
//   buffer(0): src    — float (input,  [T, Hg, K])
//   buffer(1): dst    — float (output, [T, H, K])
//   buffer(2): params — RepeatTiledParams (4 x uint) — {seq=T, hg=Hg, h=H, k=K}
//
// Grid:        (K, H, T)            — one thread per OUTPUT element
// Threadgroup: (min(256, K), 1, 1)
//
// Each thread reads from src[t * Hg * K + (h % Hg) * K + k] and writes to
// dst[t * H * K + h * K + k]. The src reads collapse to Hg distinct
// addresses per (t,k) pair — well-cached on Apple unified memory.
//
// W-5b.19 worker (ADR-005, 2026-04-27): replaces the hf2q-side CPU triple-
// loop tiled-replicate at `gpu_delta_net.rs:893-940` (`q_expanded` /
// `k_expanded` fill, ~497 ms / 10.4 ms-per-layer at PP4106). Production
// caller: `hf2q::inference::models::qwen35::gpu_delta_net::
// apply_gated_delta_net_chunk` (chunk-prefill GQA pre-expansion).
// =============================================================================

// Parameter block read by `repeat_tiled_f32` through its `params` argument
// (buffer index 2). It consists of four consecutive 32-bit unsigned fields.
// NOTE(review): the host side presumably writes these four values in exactly
// this order — confirm against the caller before reordering anything here.
struct RepeatTiledParams {
    uint seq;  // T  — token count; upper bound for the grid z coordinate
    uint hg;   // Hg (n_k_heads) — source head count; divisor in `h % hg`
    uint h;    // H  (n_v_heads) — destination head count; upper bound for grid y
    uint k;    // K  (head dim)  — innermost extent; upper bound for grid x
};

// repeat_tiled_f32 — each thread copies exactly one output element using the
// TILED head mapping:
//
//     dst[t, h, k] = src[t, h % Hg, k]
//
// Arguments:
//   src    — input floats, indexed as a contiguous [T, Hg, K] array.
//   dst    — output floats, indexed as a contiguous [T, H, K] array.
//   params — extents {seq=T, hg=Hg, h=H, k=K}.
//   pos    — thread position in the grid: x -> k, y -> h, z -> t.
//
// Threads whose position lies outside (K, H, T) do nothing, so the grid may
// be rounded up past the tensor extents. If `hg` is zero the kernel writes
// nothing at all (there is no source head to read from).
kernel void repeat_tiled_f32(
    device const float*           src    [[buffer(0)]],
    device float*                 dst    [[buffer(1)]],
    constant RepeatTiledParams&   params [[buffer(2)]],
    uint3 pos [[thread_position_in_grid]]
) {
    const uint k = pos.x;        // index within head dim
    const uint h = pos.y;        // dst head index (0..H)
    const uint t = pos.z;        // token index    (0..T)

    // Tail guard: skip threads that fall outside the output tensor.
    if (k >= params.k || h >= params.h || t >= params.seq) return;

    // `h % 0` yields an undefined value, which would turn into an arbitrary
    // src index below. Bail out instead of reading out of bounds.
    if (params.hg == 0u) return;

    // TILED mapping — `kh = h % Hg` (matches Qwen3.6 GGUF layout, llama's
    // ggml_repeat_4d, and the wrapper-side fix at gpu_delta_net.rs:927).
    const uint kh = h % params.hg;

    // Flatten the indices in 64-bit arithmetic: with 32-bit math the product
    // t * H * K silently wraps once the tensor reaches 2^32 elements.
    const ulong head_dim = (ulong)params.k;
    const ulong src_idx  = ((ulong)t * (ulong)params.hg + (ulong)kh) * head_dim + (ulong)k;
    const ulong dst_idx  = ((ulong)t * (ulong)params.h  + (ulong)h)  * head_dim + (ulong)k;

    dst[dst_idx] = src[src_idx];
}