mlx-native 0.6.2

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
#include <metal_stdlib>
using namespace metal;

// Wave 5b.1 iter 4 — chunk-local cumsum for [B, T, H] log-decay tensor.
//
// Spec source: FLA `chunk_local_cumsum_scalar` at
//   /opt/vllm/vllm/model_executor/layers/fla/ops/cumsum.py:160-195.
//
// For each (b, h) and each chunk i_t in 0..NT:
//   g_cumsum[b, t_start..t_start+BT, h] = inclusive_prefix_sum_along_t(
//       g_in[b, t_start..t_start+BT, h]
//   )
//
// Operates in f32 throughout (g is f32).
//
// # Layout
//
//   g_in     : [B, T, H] f32 — H innermost (time-stride = H)
//   g_cumsum : [B, T, H] f32 — same layout
//
// # Threading model
//
//   Grid: (1, H, B*NT). One threadgroup per (i_t, h, b). Each thread in the
//   threadgroup owns one timestep i ∈ [0, BT) but only thread 0 writes the
//   sequential per-chunk scan accumulator. (Iter 4.5 audit fix a62b111 reduced
//   grid_x from BT to 1; the prior dispatch had 64 redundant threadgroups
//   producing duplicate writes that the GPU coalesced — semantics correct
//   but pure dispatch overhead.)
//
//   Sequential per-chunk scan: thread 0 walks i=0..BT-1, accumulating.
//   This is acceptable because BT=64 is small and we have B*NT*H
//   threadgroups running in parallel (8 in the test fixture).
//
//   Iter 5 perf can replace with a Hillis-Steele scan if needed; the
//   8-threadgroup serial scan is well below noise on M5 Max.

constant constexpr uint MAX_BT = 64;

kernel void chunk_local_cumsum_g_f32(
    device const float *g_in     [[buffer(0)]],
    device       float *g_cumsum [[buffer(1)]],
    device const uint  *params   [[buffer(2)]],   // [B, T, H, BT, NT]
    uint3 tg_pos                  [[threadgroup_position_in_grid]],
    uint3 t_pos                   [[thread_position_in_threadgroup]]
) {
    const uint B  = params[0];
    const uint T  = params[1];
    const uint H  = params[2];
    const uint BT = params[3];
    const uint NT = params[4];

    const uint i_h = tg_pos.y;
    const uint b_nt = tg_pos.z;
    const uint b   = b_nt / NT;
    const uint i_t = b_nt % NT;
    if (b >= B || i_h >= H || i_t >= NT) {
        return;
    }

    // Only thread 0 walks the chunk serially. BT <= 64 so this is bounded;
    // parallelism comes from the (B * NT * H) threadgroups.
    if (t_pos.x != 0) {
        return;
    }

    const uint t_start = i_t * BT;
    const uint base = (b * T + t_start) * H + i_h;
    float acc = 0.0f;
    for (uint i = 0; i < BT; ++i) {
        acc += g_in[base + i * H];
        g_cumsum[base + i * H] = acc;
    }
}