mlx-native 0.1.3

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
#include <metal_stdlib>
using namespace metal;

/// Generic strided copy for making tensors contiguous.
///
/// Converts a strided 2D tensor to a contiguous layout:
///   dst[row * cols + col] = src[row * stride_row + col * stride_col]
///
/// Buffer layout:
///   buffer(0): src    — float (strided layout)
///   buffer(1): dst    — float (contiguous output)
///   buffer(2): params — uint [4] — {rows, cols, stride_row, stride_col}
///
/// Grid:        (cols, rows, 1)
/// Threadgroup: (min(256, cols), 1, 1)

/// Parameters for `strided_copy_f32`, bound at buffer(2).
///
/// Layout: four consecutive 32-bit `uint`s (16 bytes) in the order
/// {rows, cols, stride_row, stride_col}. Whatever fills this buffer on the
/// host side must use exactly this order and these widths.
/// Strides are counted in float elements (they index a `float*` directly),
/// not in bytes.
struct StridedCopyParams {
    uint rows, cols;              // logical 2D extent of the view being copied
    uint stride_row, stride_col;  // element step in src per row / per column
};

/// strided_copy_f32 — gather a strided 2D float view into a dense row-major
/// buffer:
///
///   dst[row * cols + col] = src[row * stride_row + col * stride_col]
///
/// Strides are in float elements, not bytes (they index `src` directly).
///
/// Index arithmetic is done in 64-bit (`ulong`). All operands are 32-bit
/// `uint`, and a 32-bit product or sum silently wraps once the addressed
/// element index passes 2^32 - 1. The kernel would then read or write the
/// wrong element instead of failing.
///
/// Expected dispatch: grid (cols, rows, 1). Threads outside
/// [0, cols) x [0, rows) return immediately, so the grid may be rounded up
/// to a threadgroup multiple.
kernel void strided_copy_f32(
    device const float*         src    [[buffer(0)]],
    device float*               dst    [[buffer(1)]],
    constant StridedCopyParams& params [[buffer(2)]],
    uint2 pos [[thread_position_in_grid]]
) {
    const uint col = pos.x;
    const uint row = pos.y;

    // Bounds guard: the dispatched grid may be larger than the tensor.
    if (col >= params.cols || row >= params.rows) return;

    // Widen before multiplying so large tensors cannot wrap around 2^32.
    const ulong src_idx = ulong(row) * ulong(params.stride_row)
                        + ulong(col) * ulong(params.stride_col);
    const ulong dst_idx = ulong(row) * ulong(params.cols) + ulong(col);
    dst[dst_idx] = src[src_idx];
}

// --------------------------------------------------------------------------
// offset_copy_f32 — Copy `count` f32 elements with src/dst offsets.
//
// dst[dst_offset + i] = src[src_offset + i]  for i in 0..count
//
// Buffer layout:
//   buffer(0): src    — float (source buffer)
//   buffer(1): dst    — float (destination buffer)
//   buffer(2): params — uint [3] — {src_offset, dst_offset, count}
//
// Grid:        (count, 1, 1)
// Threadgroup: (min(256, count), 1, 1)
// --------------------------------------------------------------------------

/// Parameters for `offset_copy_f32`, bound at buffer(2).
///
/// Layout: three consecutive 32-bit `uint`s (12 bytes) in the order
/// {src_offset, dst_offset, count}. The host-side writer must match this
/// order and these widths. All values are in float elements, not bytes.
struct OffsetCopyParams {
    uint src_offset, dst_offset;  // first element to read / write
    uint count;                   // number of floats to copy
};

// offset_copy_f32 — copy `count` floats between buffers at element offsets:
//
//   dst[dst_offset + i] = src[src_offset + i]   for i in [0, count)
//
// Offsets and count are in float elements. Threads with tid >= count return
// immediately, so the dispatched grid may be rounded up past `count`.
//
// The offset additions are done in 64-bit (`ulong`). A 32-bit
// `offset + tid` wraps past 2^32 and would silently alias a low address.
//
// NOTE(review): if src and dst are the same buffer and the two ranges
// overlap, threads race on the overlap. Presumably callers never do this;
// confirm.
kernel void offset_copy_f32(
    device const float*       src    [[buffer(0)]],
    device float*             dst    [[buffer(1)]],
    constant OffsetCopyParams& params [[buffer(2)]],
    uint tid [[thread_position_in_grid]]
) {
    // Bounds guard: the dispatched grid may exceed the element count.
    if (tid >= params.count) return;

    const ulong src_idx = ulong(params.src_offset) + ulong(tid);
    const ulong dst_idx = ulong(params.dst_offset) + ulong(tid);
    dst[dst_idx] = src[src_idx];
}