mlx-native 0.7.1

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
#include <metal_stdlib>
using namespace metal;

/// ADR-020 iter-11h-e1 — take_along_axis (gather) + scatter-backward
/// kernels for the GpuTape autograd pipeline.
///
/// Used by MoE router (mlx-lm/qwen3_next.py:Qwen3NextSparseMoeBlock):
///   gates  = softmax(linear(x))      // [n_tokens, n_experts]
///   inds   = argpartition(gates, k)  // [n_tokens, top_k] (u32)
///   scores = take_along_axis(gates, inds, axis=-1)   // [n_tokens, top_k]
///
/// Forward (per (r, j)):
///   y[r, j] = x[r, indices[r, j]]
///
/// Backward (zero-init dx, per (r, j) scatter):
///   dx[r, indices[r, j]] += dy[r, j]
///
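/// Top-K gives distinct indices per row, so no two threads ever write
/// the same dx element and no atomics are needed; the backward kernel
/// therefore implements the `+=` above as a plain store into a
/// zeroed buffer. We rely on the caller to pre-zero dx (alloc_buffer
/// zero-fills per MlxDevice contract).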
///
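/// Layout:
///   x       : `[rows, cols]` row-major f32
///   indices : `[rows, k]` row-major u32
///   y       : `[rows, k]` row-major f32
///   dy      : `[rows, k]` row-major f32     (backward input)
///   dx      : `[rows, cols]` row-major f32  (backward output)
///   params  : `[rows, cols, k]` u32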

/// Forward gather: y[r, j] = x[r, indices[r, j]].
///
/// Dispatch: 2-D grid, `tid.x` = row r, `tid.y` = slot j. Threads with
/// r >= rows or j >= k return without writing, so the grid may be
/// rounded up past `[rows, k]`.
///
/// params[0] = rows, params[1] = cols, params[2] = k.
///
/// An index >= cols yields y[r, j] = 0 instead of an out-of-bounds read
/// of x (defensive; callers are expected to validate indices).
kernel void take_along_axis_f32(
    device const float *x       [[buffer(0)]],
    device const uint  *indices [[buffer(1)]],
    device float       *y       [[buffer(2)]],
    device const uint  *params  [[buffer(3)]],   // [rows, cols, k]
    uint2 tid [[thread_position_in_grid]]
) {
    const uint n_rows = params[0];
    const uint n_cols = params[1];
    const uint n_k    = params[2];

    const uint row  = tid.x;
    const uint slot = tid.y;

    // Threads outside the [rows, k] output have nothing to do.
    const bool outside = (row >= n_rows) || (slot >= n_k);
    if (outside) {
        return;
    }

    const uint out_at = row * n_k + slot;
    const uint src    = indices[out_at];

    // Only touch x when the index is in range; otherwise emit 0 so the
    // bad index is visible downstream instead of reading past row `row`.
    float value = 0.0f;
    if (src < n_cols) {
        value = x[row * n_cols + src];
    }
    y[out_at] = value;
}

/// Scatter `dy` into `dx` at the positions specified by `indices`.
/// `dx` must be PRE-ZEROED by caller (alloc_buffer satisfies this).
///
/// One thread per (r, j); writes `dx[r, indices[r, j]] = dy[r, j]`.
/// Top-K indices are distinct within a row → no race even without
/// atomics.
kernel void take_along_axis_backward_f32(
    device const float *dy      [[buffer(0)]],
    device const uint  *indices [[buffer(1)]],
    device float       *dx      [[buffer(2)]],
    device const uint  *params  [[buffer(3)]],   // [rows, cols, k]
    uint2 tid [[thread_position_in_grid]]
) {
    // Dispatch: same 2-D grid as the forward gather (`tid.x` = row,
    // `tid.y` = slot). params = [rows, cols, k].
    //
    // Preconditions (not checked here):
    //   * dx (`[rows, cols]`) is zero-filled before dispatch; positions
    //     not named by any index are never written.
    //   * indices are distinct within each row. The store below is a
    //     plain, non-atomic write, so duplicate indices would race and
    //     their dy contributions would not be summed.
    const uint row  = tid.x;
    const uint slot = tid.y;

    const uint n_rows = params[0];
    const uint n_cols = params[1];
    const uint n_k    = params[2];

    const bool in_grid = (row < n_rows) && (slot < n_k);
    if (!in_grid) return;

    const uint flat    = row * n_k + slot;
    const uint dst_col = indices[flat];

    // Out-of-range indices contribute no gradient, mirroring the
    // forward kernel, which emits a constant 0 for them.
    if (dst_col < n_cols) {
        dx[row * n_cols + dst_col] = dy[flat];
    }
}