// mlx-native 0.3.2
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
// Documentation
#include <metal_stdlib>
using namespace metal;

/// Fused residual addition + RMS normalization (bfloat16).
///
/// Replaces two separate kernel dispatches (elementwise_add_bf16 + rms_norm_bf16)
/// with a single pass, eliminating the intermediate summed buffer and the second
/// kernel launch.
///
/// Computes:
///   sum[i]    = float(residual[i]) + float(input[i])
///   normed[i] = sum[i] * rsqrt(mean(sum^2) + eps) * float(weight[i])
///
/// Optionally writes `sum` to a separate output buffer for use as the next
/// layer's residual (avoids an extra elementwise kernel).
///
/// Buffer layout:
///   buffer(0): residual      — bfloat [rows * dim]
///   buffer(1): input         — bfloat [rows * dim]
///   buffer(2): weight        — bfloat [dim]
///   buffer(3): normed_output — bfloat [rows * dim]   normalized result
///   buffer(4): sum_output    — bfloat [rows * dim]   residual+input (may be unused)
///   buffer(5): params        — { dim, rows, eps, write_sum }
///
/// Threadgroup: (min(256, next_pow2(dim)), 1, 1)
/// Grid       : (rows, 1, 1)
///
/// Shared memory: tg_size * sizeof(float) at threadgroup(0) for the reduction.

/// Per-dispatch parameters for `fused_residual_norm_bf16`, bound at buffer(5).
/// Field order/types define the host-side layout — do not reorder.
struct FusedResidualNormParams {
    uint  dim;         // elements per row (hidden dimension); weight has `dim` entries
    uint  rows;        // number of rows dispatched (one threadgroup per row)
    float eps;         // epsilon added to the mean square before rsqrt, for stability
    uint  write_sum;   // 0 = skip writing sum_output, nonzero = write it
};

kernel void fused_residual_norm_bf16(
    device const bfloat*               residual      [[buffer(0)]],
    device const bfloat*               input         [[buffer(1)]],
    device const bfloat*               weight        [[buffer(2)]],
    device bfloat*                     normed_output [[buffer(3)]],
    device bfloat*                     sum_output    [[buffer(4)]],
    constant FusedResidualNormParams&  params        [[buffer(5)]],
    uint row_id   [[threadgroup_position_in_grid]],
    uint tid      [[thread_index_in_threadgroup]],
    uint tg_size  [[threads_per_threadgroup]],
    threadgroup float* shared          [[threadgroup(0)]]
) {
    // Guard against over-dispatch.  Every thread in a group shares row_id, so
    // the whole group returns together and the barriers below stay uniform.
    if (row_id >= params.rows) {
        return;
    }

    const uint  dim       = params.dim;
    const float eps       = params.eps;
    const bool  write_sum = (params.write_sum != 0u);

    const uint base = row_id * dim;

    // -------------------------------------------------------------------------
    // Phase 1: per-thread partial sum-of-squares of (residual + input).
    //
    // BUGFIX: the previous version cached each element's sum in shared[i] for
    // i in [0, dim), but the shared allocation holds only tg_size floats and
    // the documented launch config caps tg_size at 256 — so any dim > 256
    // wrote past the end of threadgroup memory.  We keep everything in
    // registers instead; the element sums are cheap to recompute in phase 3.
    // The optional sum_output write happens inline here, straight from the
    // register value, so no caching is needed for it either.
    // -------------------------------------------------------------------------
    float partial_sq = 0.0f;
    for (uint i = tid; i < dim; i += tg_size) {
        const float s = static_cast<float>(residual[base + i])
                      + static_cast<float>(input[base + i]);
        if (write_sum) {
            sum_output[base + i] = bfloat(s);
        }
        partial_sq += s * s;
    }

    // -------------------------------------------------------------------------
    // Phase 2: tree reduction of per-thread partials in shared[0..tg_size).
    // This is the only use of threadgroup memory and fits the declared
    // tg_size * sizeof(float) allocation.  Requires tg_size to be a power of
    // two (guaranteed by the documented launch config).
    // -------------------------------------------------------------------------
    shared[tid] = partial_sq;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            shared[tid] += shared[tid + stride];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // All threads read the total after the final barrier above.
    const float rms_inv = rsqrt(shared[0] / float(dim) + eps);

    // -------------------------------------------------------------------------
    // Phase 3: recompute each element sum from the source buffers (they are
    // still intact), scale by rms_inv and the per-channel weight, and write
    // the normalized output.
    // -------------------------------------------------------------------------
    for (uint i = tid; i < dim; i += tg_size) {
        const float s = static_cast<float>(residual[base + i])
                      + static_cast<float>(input[base + i]);
        const float w = static_cast<float>(weight[i]);
        normed_output[base + i] = bfloat(s * rms_inv * w);
    }
}