mlx-native 0.6.7

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
#include <metal_stdlib>
using namespace metal;

/// Slice a column range out of a 2D row-major tensor.
///
///   `output[r, c] = input[r, start_col + c]`     for 0 ≤ c < out_cols
///
/// Buffer layout:
///   buffer(0): input  — float[rows, in_cols]
///   buffer(1): output — float[rows, out_cols]
///   buffer(2): params — uint[3]: (in_cols, out_cols, start_col)
///
/// Grid: 2D (out_cols, rows); one thread per output element.
///
/// Preconditions (NOT checked by the kernel; the caller must guarantee them):
///   * `start_col + out_cols <= in_cols`. Otherwise the read runs past the
///     end of each input row.
///   * The grid height must equal `rows` exactly. `rows` is not passed in
///     `params`, so the only row bound available here is `threads_per_grid.y`.
///     If the grid height is rounded up to a threadgroup multiple (e.g. a
///     threadgroup-granular dispatch), the surplus rows would access memory
///     out of bounds.
///   * The grid width MAY be rounded up. Threads with `col >= out_cols`
///     write nothing.
///
/// Flat element offsets are computed in 64-bit (`ulong`), so tensors with
/// more than 2^32 elements do not wrap the index.
kernel void slice_2d_cols_f32(
    device const float *input  [[buffer(0)]],
    device float       *output [[buffer(1)]],
    device const uint  *params [[buffer(2)]],
    uint2 tid [[thread_position_in_grid]],
    uint2 grid_size [[threads_per_grid]]
) {
    const uint in_cols  = params[0];
    const uint out_cols = params[1];
    const uint start    = params[2];
    const uint col = tid.x;
    const uint row = tid.y;
    // Discard surplus threads from a rounded-up grid width. The row test
    // only bounds an exact-height grid (see preconditions above).
    if (col >= out_cols || row >= grid_size.y) {
        return;
    }
    // Widen before multiplying: `row * in_cols` evaluated in 32-bit `uint`
    // wraps once the tensor holds more than 2^32 elements.
    const ulong r = ulong(row);
    output[r * out_cols + col] = input[r * in_cols + start + col];
}

/// Copy a 2D source tensor into a column slab of a 2D destination
/// tensor. The destination must be pre-zeroed (or pre-populated)
/// by the caller — this kernel writes ONLY the slab
/// `dst[:, start_col_in_dst .. start_col_in_dst + src_cols]`.
///
///   `dst[r, start_col_in_dst + c] = src[r, c]`     for 0 ≤ c < src_cols
///
/// Used to implement `concat_along_cols` by calling this kernel
/// once per source slab into a single pre-zeroed output.
///
/// Buffer layout:
///   buffer(0): src    — float[rows, src_cols]
///   buffer(1): dst    — float[rows, dst_cols]
///   buffer(2): params — uint[3]: (src_cols, dst_cols, start_col_in_dst)
///
/// Grid: 2D (src_cols, rows); one thread per source element.
///
/// Preconditions (NOT checked by the kernel; the caller must guarantee them):
///   * `start_col_in_dst + src_cols <= dst_cols`. Otherwise the write runs
///     past the end of each destination row.
///   * The grid height must equal `rows` exactly. `rows` is not passed in
///     `params`, so the only row bound available here is `threads_per_grid.y`.
///     If the grid height is rounded up to a threadgroup multiple (e.g. a
///     threadgroup-granular dispatch), the surplus rows would access memory
///     out of bounds.
///   * The grid width MAY be rounded up. Threads with `col >= src_cols`
///     write nothing.
///
/// Flat element offsets are computed in 64-bit (`ulong`), so tensors with
/// more than 2^32 elements do not wrap the index.
kernel void copy_2d_cols_into_f32(
    device const float *src    [[buffer(0)]],
    device float       *dst    [[buffer(1)]],
    device const uint  *params [[buffer(2)]],
    uint2 tid [[thread_position_in_grid]],
    uint2 grid_size [[threads_per_grid]]
) {
    const uint src_cols = params[0];
    const uint dst_cols = params[1];
    const uint start    = params[2];
    const uint col = tid.x;
    const uint row = tid.y;
    // Discard surplus threads from a rounded-up grid width. The row test
    // only bounds an exact-height grid (see preconditions above).
    if (col >= src_cols || row >= grid_size.y) {
        return;
    }
    // Widen before multiplying: `row * dst_cols` evaluated in 32-bit `uint`
    // wraps once the tensor holds more than 2^32 elements.
    const ulong r = ulong(row);
    dst[r * dst_cols + start + col] = src[r * src_cols + col];
}