// mlx-native 0.9.0
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
// Documentation
// bilinear_resize_2d_f32 — Metal shader for ADR-021 K2.
//
// Antialiased bilinear resize (triangle-filter with support =
// max(1, 1/sf)) of a [H_src, W_src, C] row-major f32 tensor into
// a [H_dst, W_dst, C] row-major f32 tensor. Mirrors the CPU oracle
// `qwen3vl_resize_position_embeddings_bilinear` exactly:
//   - sample-coord mapping: `(dst + 0.5) / sf - 0.5` (PyTorch
//     align_corners=False / pixel_offset=0.5).
//   - antialias: triangle filter with support = max(1, 1/sf), sums
//     ALL contributions in the support window, then renormalizes by
//     total weight.
//   - For sf >= 1 (upsampling), support degenerates to 1 → 4-tap
//     bilinear; for sf < 1 (downsampling), support > 1 → wider
//     low-pass filter — load-bearing semantic match to llama.cpp's
//     `BILINEAR | ANTIALIAS` mode.
//
// Reference: /opt/llama.cpp/ggml/src/ggml-cpu/ops.cpp:7578-7637
// (the C++ source the CPU oracle ports). For the Qwen3-VL fixture,
// trained_n == target_n on both axes (sf = 1), so the general
// formula itself collapses to a pass-through: only sy = floor(y)
// receives a non-zero weight (weight = 1) — verified.
//
// Each thread emits one output element at (y_dst, x_dst, c). One
// thread per output element keeps the per-thread loop short for the
// canonical Qwen3-VL window sizes (sf ∈ [0.5, 4]) and avoids
// threadgroup-shared reduction.

#include <metal_stdlib>
using namespace metal;

// Parameter block for `bilinear_resize_2d_f32`, read from the `constant`
// address space at buffer(0).
//
// Must match `GpuBilinearResize2dParams` in src/ops/bilinear_resize_2d.rs:
// same field order and types. Every field is a 4-byte scalar (4 x uint,
// then 6 x float), so the struct is 40 bytes with no internal padding.
// Reordering or retyping a field here silently breaks the host/GPU layout.
//
// The filter constants (sf_*, support_*, invscale_*) are read by the kernel
// as-is and never re-derived on the GPU. The caller is expected to fill them
// consistently with the formulas in the field comments below.
struct BilinearResize2dParams {
    uint trained_n;       // edge length of the (square) source grid
    uint target_n_x;      // output width (number of columns)
    uint target_n_y;      // output height (number of rows)
    uint n_embd;          // channel count C
    float sf_x;           // target_n_x / trained_n
    float sf_y;           // target_n_y / trained_n
    float support_x;      // max(1.0, 1.0 / sf_x)
    float support_y;      // max(1.0, 1.0 / sf_y)
    float invscale_x;     // 1.0 / support_x
    float invscale_y;     // 1.0 / support_y
};

// Triangle (tent) filter used for both axes of the antialias window:
// 1 - |t|, clamped below at 0.
static inline float bilinear_resize_2d_tent(float t) {
    return max(1.0f - fabs(t), 0.0f);
}

// bilinear_resize_2d_f32 — one thread per output element.
//
// Bindings:
//   buffer(0) params — BilinearResize2dParams (constant address space)
//   buffer(1) src    — float[trained_n][trained_n][n_embd], row-major
//   buffer(2) dst    — float[target_n_y][target_n_x][n_embd], row-major
//
// Launch: 1-D grid indexed by thread_position_in_grid. Thread `gid` owns
// dst[gid]. Threads with gid >= target_n_y * target_n_x * n_embd return
// without writing, so the grid may be rounded up.
//
// Each output gathers every source pixel in the support window, weights it
// by the separable tent filter, and divides by the summed weight. If no
// source pixel receives positive weight, the output is 0.
kernel void bilinear_resize_2d_f32(
    constant BilinearResize2dParams& params [[buffer(0)]],
    device const float*              src    [[buffer(1)]],
    device float*                    dst    [[buffer(2)]],
    uint                             gid    [[thread_position_in_grid]]
) {
    const uint channels = params.n_embd;
    const uint out_w    = params.target_n_x;
    const uint src_n    = params.trained_n;

    // Tail guard. A zero in any dimension makes the bound 0, so every
    // thread exits here. This also keeps the divisions below off zero.
    if (gid >= params.target_n_y * out_w * channels) {
        return;
    }

    // gid == (y_dst * out_w + x_dst) * channels + c
    const uint pixel = gid / channels;
    const uint c     = gid % channels;
    const uint x_dst = pixel % out_w;
    const uint y_dst = pixel / out_w;

    // align_corners=False mapping. The +0.5 below shifts the window bounds
    // and tap offsets so the effective sample centre is (dst + 0.5) / sf - 0.5.
    const float half_px = 0.5f;
    const float cy = ((float)y_dst + half_px) / params.sf_y;
    const float cx = ((float)x_dst + half_px) / params.sf_x;

    // Half-open source window [lo, hi) on each axis, clamped to the grid.
    const int row_lo = (int)max(cy - params.support_y + half_px, 0.0f);
    const int row_hi = (int)min(cy + params.support_y + half_px, (float)src_n);
    const int col_lo = (int)max(cx - params.support_x + half_px, 0.0f);
    const int col_hi = (int)min(cx + params.support_x + half_px, (float)src_n);

    float weighted_sum = 0.0f;
    float weight_sum   = 0.0f;
    for (int sy = row_lo; sy < row_hi; ++sy) {
        // The row weight only depends on sy; the row base offset is shared
        // by every column of this row.
        const float wy = bilinear_resize_2d_tent(((float)sy - cy + half_px) * params.invscale_y);
        const uint row_base = (uint)sy * src_n * channels;
        for (int sx = col_lo; sx < col_hi; ++sx) {
            const float wx = bilinear_resize_2d_tent(((float)sx - cx + half_px) * params.invscale_x);
            const float w = wx * wy;
            // Only strictly positive weights contribute (mirrors the CPU oracle).
            if (!(w > 0.0f)) {
                continue;
            }
            weighted_sum += src[row_base + (uint)sx * channels + c] * w;
            weight_sum   += w;
        }
    }

    // Renormalize by the total weight; an empty window yields 0.
    dst[gid] = (weight_sum > 0.0f) ? weighted_sum / weight_sum : 0.0f;
}