mlx-native 0.7.1

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
// feature_concat_f32 — Metal shader for ADR-021 K5.
//
// Strided copy of one chunk into its slice of the concatenated
// destination tensor. For a chunk of shape [T, D] f32 row-major,
// computes:
//
//   for t in 0..T:
//     for d in 0..D:
//       dst[t * dst_stride + dst_offset + d] = src[t * D + d]
//
// where `dst_stride = D_total = sum of all chunk D's`. Launching
// once per chunk (varying `dst_offset`) builds the full concatenated
// `[T, D_total]` row-major tensor with each row carrying
// `[chunk0[t]; chunk1[t]; ...; chunkN[t]]` — matching qwen3vl.cpp:186
// `ggml_concat(ctx0, embeddings, deepstack_features, 0)`.

#include <metal_stdlib>
using namespace metal;

// Per-dispatch parameters for `feature_concat_f32`, bound at buffer(0).
//
// Must match `GpuFeatureConcatParams` in src/ops/feature_concat.rs.
// Layout: four consecutive 32-bit `uint` fields (16 bytes total, 4-byte
// alignment, no padding). The Rust mirror must use the same field order and
// u32 widths. It is presumably `#[repr(C)]`; confirm on the Rust side.
//
// Invariant assumed by `feature_concat_f32` (it is NOT checked on the GPU):
//   dst_offset + src_dim <= dst_stride
// If the invariant is violated, writes for row t land in row t+1 of dst, or
// past the end of dst on the last row.
struct FeatureConcatParams {
    uint n_tokens;       // T — number of rows, shared by src and dst
    uint src_dim;        // D_i (chunk feature width) — elements per src row
    uint dst_offset;     // start column for this chunk in the [T, D_total] dst (in elements)
    uint dst_stride;     // D_total (row stride of dst), in elements, not bytes
};

// Copies one [T, src_dim] row-major f32 chunk into its column slice of the
// row-major [T, dst_stride] destination.
//
// Buffers (indices must match the host-side encoder):
//   0: params — FeatureConcatParams (constant address space)
//   1: src    — float [T * src_dim] row-major (one chunk)
//   2: dst    — float [T * dst_stride] row-major (concatenated tensor).
//               One dispatch writes only columns
//               [dst_offset, dst_offset + src_dim) of each row.
//
// Launch: 1-D grid with one thread per src element. The grid must cover at
// least n_tokens * src_dim threads; threads beyond that count do nothing.
// All index arithmetic is 32-bit `uint` and wraps on overflow, so both
// T * src_dim and the largest dst index must fit in 32 bits.
// dst_offset + src_dim <= dst_stride is assumed, not checked.
kernel void feature_concat_f32(
    constant FeatureConcatParams& params [[buffer(0)]],
    device const float*           src    [[buffer(1)]],
    device float*                 dst    [[buffer(2)]],
    uint                          gid    [[thread_position_in_grid]]
) {
    const uint width   = params.src_dim;
    const uint n_elems = params.n_tokens * width;

    // Threads from the rounded-up tail of the grid write nothing. This guard
    // also covers width == 0 (n_elems == 0), so the division below never
    // sees a zero divisor.
    if (gid < n_elems) {
        // Flat src index -> (row, col) inside the [T, src_dim] chunk.
        const uint row = gid / width;
        const uint col = gid % width;

        // Same row in dst, shifted right by this chunk's column offset.
        const uint dst_pos = row * params.dst_stride + params.dst_offset + col;
        dst[dst_pos] = src[gid];
    }
}