// mlx-native 0.8.0
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon.
// Documentation
// im2col_2d_3ch_f32 — Metal shader for ADR-021 K1.
//
// Unfolds a [3, H, W] f32 row-major pixel buffer into a
// [num_patches, 3*p²] f32 row-major im2col matrix matching
// `patch_embed_forward_hw`'s inner-kernel iteration order
// (channel outermost, then dy, then dx innermost). The output is the
// `src1` operand of the two `dense_matmul_f32_f32_tensor`
// dispatches that replace the dual-conv patch embed.
//
// Reference: src/inference/vision/vit.rs::patch_embed_forward_hw
//   for ic in 0..3:
//     for dy in 0..p:
//       for dx in 0..p:
//         k = ic*p² + dy*p + dx
// where output row index `m = patch_y * nps_x + patch_x` and the
// column index `k` iterates the unfolded patch in (ic, dy, dx)
// order — matching how `dense_matmul_f32_f32_tensor` consumes the
// `src1` slice as `[M=num_patches, K=3*p²]` row-major.

#include <metal_stdlib>
using namespace metal;

// Uniform parameters for `im2col_2d_3ch_f32`, bound at buffer(0).
// Must match `GpuIm2col2d3chParams` in src/ops/im2col_2d_3ch.rs.
// Field order and types define the byte layout shared with the host side,
// so do not reorder, retype, add, or remove fields here without making the
// identical change to the Rust struct.
struct Im2col2d3chParams {
    uint pixel_h;        // image height; the kernel uses it only to size one channel plane (h * w)
    uint pixel_w;        // image width; row stride within a channel plane
    uint patch_size;     // p: edge length of the square patch
    uint nps_x;          // number of patches along x; used to split a patch index into (patch_y, patch_x)
    uint nps_y;          // number of patches along y; not read directly by the kernel
    uint k_total;        // = 3 * p²
    uint num_patches;    // = nps_x * nps_y
    uint _pad;           // unused by the kernel; brings the struct to 32 bytes (8 x uint) — presumably to mirror the host layout, confirm against the Rust definition
};

// Buffers:
//   0: params      — Im2col2d3chParams
//   1: pixels      — float [3 * pixel_h * pixel_w] row-major (channel-major)
//   2: output      — float [num_patches * k_total] row-major
//
// Dispatch: 1-D grid, one thread per output element. Threads with
// gid >= num_patches * k_total exit immediately, so the grid may be larger
// than the element count (e.g. rounded up to a threadgroup multiple).
// If patch_size or nps_x is 0, the element count is 0 and every thread
// exits before any division, so there is no divide-by-zero.
//
// Thread `gid` writes output[gid]:
//   row    m = gid / k_total  (patch index, m = patch_y * nps_x + patch_x)
//   column k = gid % k_total, decomposed as k = ic*p² + dy*p + dx
// Consecutive threads write consecutive output elements. Within one
// (ic, dy) run, they also read consecutive pixels.
//
// Out-of-image taps are written as 0.0f instead of being read. A tap is out
// of image when ic >= 3, src_y >= pixel_h, or src_x >= pixel_w. This can only
// happen if the host passes inconsistent params (nps_x*p > pixel_w,
// nps_y*p > pixel_h, or k_total > 3*p²). Reading such a tap would be past the
// end of `pixels`. With consistent params the guard never fires and the
// output is a plain gather.
//
// All index arithmetic is 32-bit `uint`. This assumes that both
// num_patches * k_total and 3 * pixel_h * pixel_w fit in 32 bits.
kernel void im2col_2d_3ch_f32(
    constant Im2col2d3chParams& params [[buffer(0)]],
    device const float*         pixels [[buffer(1)]],
    device float*               output [[buffer(2)]],
    uint                        gid    [[thread_position_in_grid]]
) {
    const uint total = params.num_patches * params.k_total;
    if (gid >= total) {
        return;
    }
    const uint p = params.patch_size;
    const uint w = params.pixel_w;
    const uint h = params.pixel_h;
    const uint nps_x = params.nps_x;
    const uint k_total = params.k_total;
    const uint p2 = p * p;
    const uint hw = h * w;

    // Split the flat output index into (patch row, unfolded column).
    const uint patch_idx = gid / k_total;
    const uint k = gid - patch_idx * k_total;

    // Column k enumerates the patch in (ic, dy, dx) order, dx innermost.
    const uint ic = k / p2;
    const uint within = k - ic * p2;
    const uint dy = within / p;
    const uint dx = within - dy * p;

    // Patch index is row-major over the patch grid.
    const uint patch_y = patch_idx / nps_x;
    const uint patch_x = patch_idx - patch_y * nps_x;

    const uint src_y = patch_y * p + dy;
    const uint src_x = patch_x * p + dx;

    // Defensive guard: only touch `pixels` when the tap lies inside the
    // [3, pixel_h, pixel_w] image. Otherwise emit zero padding.
    float value = 0.0f;
    if (ic < 3u && src_y < h && src_x < w) {
        value = pixels[ic * hw + src_y * w + src_x];
    }
    output[gid] = value;
}