// mlx-native 0.7.1
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon.
#include <metal_stdlib>
using namespace metal;

/// RMS Normalization — reverse-mode autograd kernels.
///
/// Forward:
///   ms[b]    = (1/D) * Σ_j x[b, j]²
///   r[b]     = rsqrt(ms[b] + eps)              (a per-row scalar)
///   y[b, i]  = x[b, i] * r[b] * w[i]
///
/// Backward — given dy[rows, dim], x[rows, dim], w[dim] →
/// produce dx[rows, dim] and dw[dim] via the analytical identities:
///
///   ∂y[b, i] / ∂w[i]      = x[b, i] * r[b]
///   ∂y[b, i] / ∂x[b, k]   = δ_{i,k} * r[b] * w[i]
///                          - x[b, i] * x[b, k] * r[b]³ * w[i] / D
///
///   dw[i] = Σ_b dy[b, i] * x[b, i] * r[b]
///   dx[b, k] = r[b] * (dy[b, k] * w[k] - x[b, k] * (s[b] * r[b]² / D))
///   where s[b] = Σ_i dy[b, i] * x[b, i] * w[i]
///
/// We split the computation across THREE kernels:
///   1. `rms_norm_compute_rms_inv_f32` — produces r[rows]
///   2. `rms_norm_backward_dx_f32`     — produces dx[rows, dim]
///   3. `rms_norm_backward_dw_f32`     — produces dw[dim]
///
/// Why three? r[b] is reused by both dx and dw; computing it once in
/// a helper avoids redundant TG-wide reductions in the larger
/// kernels, especially for dw which is dim-major (per-feature
/// threadgroups would otherwise re-reduce r[b] for every feature).

/// Helper: compute r[b] = rsqrt(mean(x[b, :]²) + eps) for each row.
///
/// Buffer layout:
///   buffer(0): x       — float[rows, dim]
///   buffer(1): r_out   — float[rows]
///   buffer(2): params  — float2: (eps, dim_f)
///
/// Threadgroup: (tg_size, 1, 1); one TG per row. tg_size may be any value
/// >= 1 — the reduction below does NOT require a power of two.
/// Grid threadgroups: (rows, 1, 1).
/// Threadgroup shared memory: tg_size * sizeof(float).
///
/// Limits: `dim` is transported as a float, so it converts back to an exact
/// integer only up to 2^24. Flat offsets are 32-bit `uint`, so rows * dim
/// must fit in 32 bits.
kernel void rms_norm_compute_rms_inv_f32(
    device const float *x       [[buffer(0)]],
    device float       *r_out   [[buffer(1)]],
    device const float *params  [[buffer(2)]],
    uint  row_idx [[threadgroup_position_in_grid]],
    uint  tid     [[thread_index_in_threadgroup]],
    uint  tg_size [[threads_per_threadgroup]],
    threadgroup float *shared   [[threadgroup(0)]]
) {
    const float eps = params[0];
    const uint  dim = uint(params[1]);
    const uint  base = row_idx * dim;

    // Phase 1: each thread sums the squares of the elements
    // tid, tid + tg_size, tid + 2*tg_size, ... of this row.
    float partial = 0.0f;
    for (uint i = tid; i < dim; i += tg_size) {
        const float v = x[base + i];
        partial += v * v;
    }
    shared[tid] = partial;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 2: tree reduction over shared[0 .. tg_size).
    // Each pass folds the upper ceil-half of the live range onto the lower
    // half. This stays correct when tg_size is not a power of two; a plain
    // `stride = tg_size / 2` halving loop would silently drop trailing
    // partials in that case. For power-of-two tg_size it performs exactly the
    // same additions as the classic halving loop.
    // `live` is uniform across the threadgroup, so every thread reaches every
    // barrier.
    for (uint live = tg_size; live > 1u; ) {
        const uint upper = (live + 1u) / 2u;
        if (tid + upper < live) {
            shared[tid] += shared[tid + upper];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        live = upper;
    }

    if (tid == 0u) {
        r_out[row_idx] = rsqrt(shared[0] / float(dim) + eps);
    }
}

/// dx[b, k] = r[b] * (dy[b, k] * w[k] - x[b, k] * (s[b] * r[b]² / D))
///   where s[b] = Σ_i dy[b, i] * x[b, i] * w[i]
///
/// Buffer layout:
///   buffer(0): x       — float[rows, dim]
///   buffer(1): w       — float[dim]
///   buffer(2): dy      — float[rows, dim]
///   buffer(3): r       — float[rows]   (precomputed per `rms_norm_compute_rms_inv_f32`)
///   buffer(4): dx      — float[rows, dim]   (output)
///   buffer(5): params  — float2: (dim_f, _padding)
///
/// Threadgroup: (tg_size, 1, 1); one TG per row. tg_size may be any value
/// >= 1 — the s[b] reduction does NOT require a power of two.
/// Grid threadgroups: (rows, 1, 1).
/// Threadgroup shared memory: tg_size * sizeof(float) (used for the s[b] reduction).
///
/// Limits: `dim` is transported as a float, so it converts back to an exact
/// integer only up to 2^24. Flat offsets are 32-bit `uint`, so rows * dim
/// must fit in 32 bits.
kernel void rms_norm_backward_dx_f32(
    device const float *x       [[buffer(0)]],
    device const float *w       [[buffer(1)]],
    device const float *dy      [[buffer(2)]],
    device const float *r       [[buffer(3)]],
    device float       *dx      [[buffer(4)]],
    device const float *params  [[buffer(5)]],
    uint  row_idx [[threadgroup_position_in_grid]],
    uint  tid     [[thread_index_in_threadgroup]],
    uint  tg_size [[threads_per_threadgroup]],
    threadgroup float *shared   [[threadgroup(0)]]
) {
    const uint  dim = uint(params[0]);
    const uint  base = row_idx * dim;
    const float r_b = r[row_idx];

    // Phase 1: per-thread partial of s[b] = Σ_i dy[b, i] * x[b, i] * w[i].
    float partial = 0.0f;
    for (uint i = tid; i < dim; i += tg_size) {
        partial += dy[base + i] * x[base + i] * w[i];
    }
    shared[tid] = partial;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction over shared[0 .. tg_size).
    // Each pass folds the upper ceil-half of the live range onto the lower
    // half, so a non-power-of-two tg_size is handled correctly; a plain
    // `stride = tg_size / 2` loop would drop partials and corrupt every dx in
    // the row. For power-of-two tg_size the additions are identical to the
    // classic halving loop. `live` is uniform, so all threads hit each barrier.
    for (uint live = tg_size; live > 1u; ) {
        const uint upper = (live + 1u) / 2u;
        if (tid + upper < live) {
            shared[tid] += shared[tid + upper];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        live = upper;
    }

    // Every thread reads the reduced value. This is safe because the last
    // write to shared memory happened before the final barrier above.
    const float s_b = shared[0];
    // Coefficient on the `x[b, k]` term: s[b] * r[b]² / D.
    const float coeff = s_b * r_b * r_b / float(dim);

    // Phase 2: dx[b, k] = r[b] * (dy[b, k] * w[k] - x[b, k] * coeff)
    for (uint k = tid; k < dim; k += tg_size) {
        dx[base + k] = r_b * (dy[base + k] * w[k] - x[base + k] * coeff);
    }
}

/// dw[i] = Σ_b dy[b, i] * x[b, i] * r[b]
///
/// Buffer layout:
///   buffer(0): x       — float[rows, dim]
///   buffer(1): dy      — float[rows, dim]
///   buffer(2): r       — float[rows]
///   buffer(3): dw      — float[dim]    (output)
///   buffer(4): params  — float2: (dim_f, rows_f)
///
/// Threadgroup: (tg_size, 1, 1); one TG per FEATURE i. tg_size may be any
/// value >= 1 — the reduction does NOT require a power of two.
/// Grid threadgroups: (dim, 1, 1).
/// Threadgroup shared memory: tg_size * sizeof(float).
///
/// Each TG sums over `rows` for its feature: thread `tid` handles rows
/// tid, tid + tg_size, ..., then the TG tree-reduces the partials.
/// Adjacent threads read elements `dim` floats apart, so global loads for
/// this kernel are strided rather than contiguous.
///
/// Limits: `dim` and `rows` are transported as floats, so they convert back
/// to exact integers only up to 2^24. Flat offsets are 32-bit `uint`, so
/// rows * dim must fit in 32 bits.
kernel void rms_norm_backward_dw_f32(
    device const float *x       [[buffer(0)]],
    device const float *dy      [[buffer(1)]],
    device const float *r       [[buffer(2)]],
    device float       *dw      [[buffer(3)]],
    device const float *params  [[buffer(4)]],
    uint  feat_idx [[threadgroup_position_in_grid]],
    uint  tid      [[thread_index_in_threadgroup]],
    uint  tg_size  [[threads_per_threadgroup]],
    threadgroup float *shared   [[threadgroup(0)]]
) {
    const uint dim  = uint(params[0]);
    const uint rows = uint(params[1]);

    // Phase 1: per-thread partial sum over rows.
    //   contribution[b] = dy[b, feat_idx] * x[b, feat_idx] * r[b]
    float partial = 0.0f;
    for (uint b = tid; b < rows; b += tg_size) {
        const uint base = b * dim + feat_idx;
        partial += dy[base] * x[base] * r[b];
    }
    shared[tid] = partial;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 2: tree reduction over shared[0 .. tg_size).
    // Each pass folds the upper ceil-half of the live range onto the lower
    // half, so a non-power-of-two tg_size is handled correctly; a plain
    // `stride = tg_size / 2` loop would silently drop trailing partials.
    // For power-of-two tg_size the additions match the classic halving loop.
    // `live` is uniform, so all threads reach each barrier.
    for (uint live = tg_size; live > 1u; ) {
        const uint upper = (live + 1u) / 2u;
        if (tid + upper < live) {
            shared[tid] += shared[tid + upper];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        live = upper;
    }

    if (tid == 0u) {
        dw[feat_idx] = shared[0];
    }
}