mlx-native 0.6.2

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
// scale_mask_softmax.metal — fused scale-then-mask-then-softmax over
// attention scores for the non-flash-attention prefill path.
//
// Replaces three sequential dispatches (scale, mask-add, softmax) with
// one row-local pass.  Each threadgroup processes one (head, query-
// position) row of the scores tensor and reduces over the key axis.
//
// Contract:
//   input  — f32 [nh, seq_q, seq_k] = [rows = nh*seq_q, cols = seq_k]
//   output — f32 same shape (row-normalized softmax probs; may alias
//            input for in-place operation)
//   mask   — bf16 [seq_q, seq_k] row-major.  Shared across heads: for
//            row_idx = h*seq_q + q, mask[q, :] is applied.  Masked
//            positions hold -INF (matches flash_attn_prefill_mask.metal's
//            sentinel: attended = 0.0f, masked = -INF).
//
// Math (per row):
//   tmp[k] = input[row, k] * scale + float(mask[q, k])
//   row_max = max_k tmp[k]
//   exp_k   = exp(tmp[k] - row_max)
//   row_sum = sum_k exp_k
//   output[row, k] = exp_k / row_sum
//
// Layout assumptions:
//   * Dispatcher sends threadgroups=(rows, 1, 1), tgsize=(N, 1, 1) with
//     N >= 32.  Each threadgroup gets `shared[N]` floats of scratch.
//   * rows is nh*seq_q; cols = seq_k.  The caller passes seq_q in
//     params so we can derive q = row_idx % seq_q for the mask index.
//
// Used for the non-FA prefill attention path (HF2Q_NO_FA=1).  Modelled
// on llama.cpp's kernel_soft_max_f32 (ggml-metal.metal:1855-1960),
// simplified for our specific case (no ALiBi, bf16 mask, fixed scale).

#include <metal_stdlib>
using namespace metal;

// Apple GPU simdgroup width — constant across all current generations.
#define N_SIMDWIDTH 32

struct ScaleMaskSoftmaxParams {
    uint  cols;        // seq_k (size of reduction axis)
    uint  seq_q;       // number of rows per head (to compute q = row % seq_q)
    float scale;       // multiplicative scale applied to input pre-mask (= 1/sqrt(hd))
    uint  _pad;
};

// D.3 — llama.cpp-style simdgroup-reduction softmax.  Uses hardware
// simd_max / simd_sum (1-cycle intra-simdgroup reductions) instead of
// the tree-reduce + threadgroup barriers we had before.  When the
// threadgroup has more than one simdgroup (tg_size > 32), a secondary
// cross-simdgroup reduction runs through shared memory — but within
// each phase that is only 1 extra barrier + 1 simd_reduce, versus the
// log2(tg_size) barriers the tree path took.  On Apple M5 at
// tg_size=256 (8 simdgroups, cols=2455 attention row), this cuts the
// softmax kernel time by ~3x per row.
//
// Structure matches llama.cpp's kernel_soft_max (ggml-metal.metal:1855).

kernel void scale_mask_softmax_f32(
    device const float  *input   [[buffer(0)]],
    device       float  *output  [[buffer(1)]],
    device const bfloat *mask    [[buffer(2)]],
    constant ScaleMaskSoftmaxParams & params [[buffer(3)]],
    uint  row_idx [[threadgroup_position_in_grid]],
    uint  tid     [[thread_index_in_threadgroup]],
    uint  tg_size [[threads_per_threadgroup]],
    uint  sgitg   [[simdgroup_index_in_threadgroup]],
    uint  tiisg   [[thread_index_in_simdgroup]],
    threadgroup float *shared [[threadgroup(0)]]
) {
    const uint  cols       = params.cols;
    const uint  seq_q      = params.seq_q;
    const float scale      = params.scale;
    const uint  q          = row_idx % seq_q;
    const uint  scores_base = row_idx * cols;
    const uint  mask_base   = q * cols;

    // ---- Phase 1: row-max ----
    float local_max = -INFINITY;
    for (uint i = tid; i < cols; i += tg_size) {
        float v = input[scores_base + i] * scale + float(mask[mask_base + i]);
        local_max = max(local_max, v);
    }

    // Intra-simdgroup reduction (hardware, 1 cycle).
    float max_val = simd_max(local_max);
    if (tg_size > N_SIMDWIDTH) {
        // Cross-simdgroup reduction via shared memory.
        if (sgitg == 0) {
            shared[tiisg] = -INFINITY;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) {
            shared[sgitg] = max_val;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        max_val = shared[tiisg];
        max_val = simd_max(max_val);
    }

    // ---- Phase 2: exp(v - max), store to output, accumulate sum ----
    float local_sum = 0.0f;
    for (uint i = tid; i < cols; i += tg_size) {
        float v = input[scores_base + i] * scale + float(mask[mask_base + i]);
        float e = exp(v - max_val);
        output[scores_base + i] = e;
        local_sum += e;
    }

    // Barrier fixes a sporadic reduction ordering bug on Apple GPUs —
    // matches llama.cpp's comment at ggml-metal.metal:1925.
    threadgroup_barrier(mem_flags::mem_none);

    float sum = simd_sum(local_sum);
    if (tg_size > N_SIMDWIDTH) {
        if (sgitg == 0) {
            shared[tiisg] = 0.0f;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) {
            shared[sgitg] = sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sum = shared[tiisg];
        sum = simd_sum(sum);
    }

    // ---- Phase 3: normalise.  Guard against sum=0 (fully-masked row). ----
    const float inv_sum = (sum > 0.0f) ? (1.0f / sum) : 0.0f;
    for (uint i = tid; i < cols; i += tg_size) {
        output[scores_base + i] *= inv_sum;
    }
}