// mlx-native 0.6.2
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon.
#include <metal_stdlib>
using namespace metal;

// Inclusive prefix sum (cumulative sum) along the last axis.
//
// Computes: out[r, i] = sum(x[r, 0..=i]) for every row r independently.
//
// Spec source: ADR-013 Decision 4. Formula derived from the definition of
// an inclusive prefix scan; no code copied from llama.cpp.
//
// Algorithm: per-row Hillis-Steele scan using threadgroup shared memory.
// One threadgroup per row; each thread owns CHUNK contiguous elements, loaded
// into private memory, reduced locally, then a Hillis-Steele scan runs across
// thread-local totals. Finally each thread adds the exclusive prefix of
// preceding threads' totals to its chunk and writes outputs.
//
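// Buffer layout:
//   buffer(0):      input   - shape [rows, dim], row-major
//   buffer(1):      output  - shape [rows, dim], row-major
//   buffer(2):      params  - uint2: (dim, tg_size). The kernels read only
//                   params[0] (dim); the threadgroup width is taken from
//                   [[threads_per_threadgroup]], not from params[1].
//   threadgroup(0): scratch - at least tg_size floats
//
// Grid: one threadgroup per row (row index = threadgroup_position_in_grid).
// Threadgroup shape: (tg_size, 1, 1). Each thread owns
//   CHUNK = ceil_div(dim, tg_size) contiguous elements; CHUNK is computed
//   inside the kernel from dim and the launched threadgroup width. Threads
//   whose range starts at or beyond dim own no elements and contribute 0.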
//
// Accumulation is performed in f32 regardless of input dtype for numerical
// stability (critical for Gated DeltaNet's decay-mask which multiplies these
// sums later).

// Capacity, in elements, of the per-thread private staging array `local_buf`
// used by the cumsum kernels. Each thread's chunk is ceil_div(dim, tg_size)
// elements, so with tg_size = 256 a chunk fits for dim up to 256 x 32 = 8192.
// For larger dims, raise tg_size (Apple GPUs typically allow up to 1024
// threads per threadgroup - confirm for the target device) or tile across
// multiple kernel launches, so that chunks stay within this bound.
#define CUMSUM_MAX_CHUNK 32

kernel void cumsum_f32(
    device const float *input   [[buffer(0)]],
    device float       *output  [[buffer(1)]],
    device const uint  *params  [[buffer(2)]],
    uint row_idx [[threadgroup_position_in_grid]],
    uint tid     [[thread_index_in_threadgroup]],
    uint tg_size [[threads_per_threadgroup]],
    threadgroup float *shared   [[threadgroup(0)]]
) {
    // Inclusive prefix sum along the last axis of a row-major f32
    // [rows, dim] tensor: output[r, i] = input[r, 0] + ... + input[r, i].
    //
    //   params[0] : dim (row length); no other params entry is read.
    //   row_idx   : row handled by this threadgroup (one threadgroup per row).
    //   tg_size   : launched threadgroup width; each thread owns
    //               ceil_div(dim, tg_size) contiguous elements of the row.
    //   shared    : threadgroup scratch, must hold at least tg_size floats.
    const uint dim = params[0];

    // 64-bit row offset so that rows * dim beyond 2^32 elements cannot wrap.
    const ulong row_off = ulong(row_idx) * ulong(dim);
    device const float *row_in  = input  + row_off;
    device float       *row_out = output + row_off;

    // Per-thread chunk bounds [lo, hi) within the row.
    const uint chunk = (dim + tg_size - 1u) / tg_size;
    const uint lo  = min(tid * chunk, dim);
    const uint hi  = min(lo + chunk, dim);
    const uint len = hi - lo;

    // Chunks that fit in local_buf are staged in private memory so the input
    // is read once. Longer chunks (dim > tg_size * CUMSUM_MAX_CHUNK) are NOT
    // staged - writing them would overflow local_buf - and are re-read from
    // the input in the output pass instead.
    const bool buffered = (len <= uint(CUMSUM_MAX_CHUNK));
    thread float local_buf[CUMSUM_MAX_CHUNK];
    float local_sum = 0.0f;
    for (uint i = 0; i < len; ++i) {
        local_sum += row_in[lo + i];
        if (buffered) {
            local_buf[i] = local_sum;
        }
    }

    // This thread's total contribution to the row.
    shared[tid] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Hillis-Steele inclusive scan over the per-thread totals, in place in
    // shared[0..tg_size). Each step reads its operands into a register,
    // barriers so every thread has finished reading, then writes.
    for (uint offset = 1u; offset < tg_size; offset <<= 1u) {
        float v = shared[tid];
        if (tid >= offset) {
            v += shared[tid - offset];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        shared[tid] = v;
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // shared[tid] is now the inclusive prefix over threads 0..=tid, so the
    // sum of all preceding threads' elements is shared[tid - 1].
    const float exclusive = (tid == 0u) ? 0.0f : shared[tid - 1u];

    if (buffered) {
        for (uint i = 0; i < len; ++i) {
            row_out[lo + i] = local_buf[i] + exclusive;
        }
    } else {
        // Recompute the local running sum in the same order as the first
        // pass, then add the offset, so rounding matches the staged path.
        float running = 0.0f;
        for (uint i = 0; i < len; ++i) {
            running += row_in[lo + i];
            row_out[lo + i] = running + exclusive;
        }
    }
}

kernel void cumsum_bf16(
    device const bfloat *input   [[buffer(0)]],
    device bfloat       *output  [[buffer(1)]],
    device const uint   *params  [[buffer(2)]],
    uint row_idx [[threadgroup_position_in_grid]],
    uint tid     [[thread_index_in_threadgroup]],
    uint tg_size [[threads_per_threadgroup]],
    threadgroup float *shared    [[threadgroup(0)]]
) {
    // Inclusive prefix sum along the last axis of a row-major bf16
    // [rows, dim] tensor: output[r, i] = input[r, 0] + ... + input[r, i].
    // Loads are widened to float and all accumulation is done in float;
    // only the final result of each element is rounded back to bfloat.
    //
    //   params[0] : dim (row length); no other params entry is read.
    //   row_idx   : row handled by this threadgroup (one threadgroup per row).
    //   tg_size   : launched threadgroup width; each thread owns
    //               ceil_div(dim, tg_size) contiguous elements of the row.
    //   shared    : threadgroup scratch, must hold at least tg_size floats.
    const uint dim = params[0];

    // 64-bit row offset so that rows * dim beyond 2^32 elements cannot wrap.
    const ulong row_off = ulong(row_idx) * ulong(dim);
    device const bfloat *row_in  = input  + row_off;
    device bfloat       *row_out = output + row_off;

    // Per-thread chunk bounds [lo, hi) within the row.
    const uint chunk = (dim + tg_size - 1u) / tg_size;
    const uint lo  = min(tid * chunk, dim);
    const uint hi  = min(lo + chunk, dim);
    const uint len = hi - lo;

    // Chunks that fit in local_buf are staged in private memory so the input
    // is read once. Longer chunks (dim > tg_size * CUMSUM_MAX_CHUNK) are NOT
    // staged - writing them would overflow local_buf - and are re-read from
    // the input in the output pass instead.
    const bool buffered = (len <= uint(CUMSUM_MAX_CHUNK));
    thread float local_buf[CUMSUM_MAX_CHUNK];
    float local_sum = 0.0f;
    for (uint i = 0; i < len; ++i) {
        local_sum += float(row_in[lo + i]);
        if (buffered) {
            local_buf[i] = local_sum;
        }
    }

    // This thread's total contribution to the row.
    shared[tid] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Hillis-Steele inclusive scan over the per-thread totals, in place in
    // shared[0..tg_size). Each step reads its operands into a register,
    // barriers so every thread has finished reading, then writes.
    for (uint offset = 1u; offset < tg_size; offset <<= 1u) {
        float v = shared[tid];
        if (tid >= offset) {
            v += shared[tid - offset];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        shared[tid] = v;
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // shared[tid] is now the inclusive prefix over threads 0..=tid, so the
    // sum of all preceding threads' elements is shared[tid - 1].
    const float exclusive = (tid == 0u) ? 0.0f : shared[tid - 1u];

    if (buffered) {
        for (uint i = 0; i < len; ++i) {
            row_out[lo + i] = bfloat(local_buf[i] + exclusive);
        }
    } else {
        // Recompute the local running sum in the same order as the first
        // pass, then add the offset, so rounding matches the staged path.
        float running = 0.0f;
        for (uint i = 0; i < len; ++i) {
            running += float(row_in[lo + i]);
            row_out[lo + i] = bfloat(running + exclusive);
        }
    }
}