mlx-native 0.1.3

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
#include <metal_stdlib>
using namespace metal;

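/// Numerically stable softmax along the last dimension.
///
/// Algorithm:
///   1. Find max(x) for each row and subtract it before exponentiating, so the
///      largest term is exp(0) = 1 and the exponentials cannot overflow
///   2. Compute exp(x - max) for each element
///   3. Sum the exponentials
///   4. Divide each exp by the sum
///
/// The max, exp, sum and normalization are all computed in f32 for numerical
/// stability, including in the f16 and bf16 variants.
///
/// Buffer layout (element type is float, half or bfloat, matching the kernel
/// variant):
///   buffer(0): input  — element array of shape [rows, cols], row-major
///   buffer(1): output — element array of shape [rows, cols], row-major
///   buffer(2): params — float2: (cols_f, 0); only params[0], the column count
///              stored as a float, is read (exact for counts up to 2^24)
///   threadgroup(0): scratch of at least threadgroup_size floats (each thread
///              stores its partial result at index thread_index_in_threadgroup)
///
/// Threadgroup: (threadgroup_size, 1, 1) — one threadgroup per row
/// Grid threadgroups: (rows, 1, 1)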

/// f32 variant: one threadgroup computes softmax for row `row_idx`.
///
/// Each thread handles columns tid, tid + tg_size, tid + 2*tg_size, ...
/// params[0] holds the column count as a float. `shared` must hold at least
/// tg_size floats, one partial per thread. The two threadgroup reductions
/// accept any tg_size >= 1, not only powers of two.
kernel void softmax_f32(
    device const float *input     [[buffer(0)]],
    device float       *output    [[buffer(1)]],
    device const float *params    [[buffer(2)]],
    uint row_idx   [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint tg_size   [[threads_per_threadgroup]],
    threadgroup float *shared     [[threadgroup(0)]]
) {
    const uint cols = uint(params[0]);
    const uint base = row_idx * cols;

    // Phase 1: per-thread max over this thread's strided subset of the row.
    float local_max = -INFINITY;
    for (uint i = tid; i < cols; i += tg_size) {
        local_max = max(local_max, input[base + i]);
    }

    shared[tid] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction for max. Each step folds slots [ceil(active/2), active)
    // onto [0, active - ceil(active/2)). Rounding up keeps the unpaired slot
    // when `active` is odd, so no partial is dropped for non-power-of-two
    // threadgroup sizes. For powers of two this is exactly the classic
    // `stride = tg_size / 2; stride >>= 1` loop.
    for (uint active = tg_size; active > 1; ) {
        const uint half_active = (active + 1) / 2;
        if (tid + half_active < active) {
            shared[tid] = max(shared[tid], shared[tid + half_active]);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = half_active;
    }
    const float row_max = shared[0];
    // Every thread must read shared[0] before the sum pass overwrites it.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 2: per-thread sum of exp(x - max). Nothing is written to `output`
    // yet; phase 3 recomputes exp instead of re-reading a stored intermediate,
    // which saves a full write pass over the row.
    float local_sum = 0.0f;
    for (uint i = tid; i < cols; i += tg_size) {
        local_sum += exp(input[base + i] - row_max);
    }

    shared[tid] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction for sum (same odd-size-safe folding as above).
    for (uint active = tg_size; active > 1; ) {
        const uint half_active = (active + 1) / 2;
        if (tid + half_active < active) {
            shared[tid] += shared[tid + half_active];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = half_active;
    }
    const float row_sum = shared[0];
    // No trailing barrier: threadgroup memory is not written again.

    // Phase 3: normalize.
    const float inv_sum = 1.0f / row_sum;
    for (uint i = tid; i < cols; i += tg_size) {
        output[base + i] = exp(input[base + i] - row_max) * inv_sum;
    }
}

/// f16 variant: one threadgroup computes softmax for row `row_idx`.
///
/// Inputs are promoted to f32 for the max, exp, sum and normalization. Only
/// the final probability is rounded to half, so each output is rounded once
/// and is consistent with the f32 sum it was divided by. params[0] holds the
/// column count as a float. `shared` must hold at least tg_size floats, one
/// partial per thread. The reductions accept any tg_size >= 1.
kernel void softmax_f16(
    device const half  *input     [[buffer(0)]],
    device half        *output    [[buffer(1)]],
    device const float *params    [[buffer(2)]],
    uint row_idx   [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint tg_size   [[threads_per_threadgroup]],
    threadgroup float *shared     [[threadgroup(0)]]
) {
    const uint cols = uint(params[0]);
    const uint base = row_idx * cols;

    // Phase 1: find row max (accumulate in f32).
    float local_max = -INFINITY;
    for (uint i = tid; i < cols; i += tg_size) {
        local_max = max(local_max, float(input[base + i]));
    }

    shared[tid] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction for max. ceil(active / 2) keeps the unpaired slot when
    // `active` is odd, so tg_size need not be a power of two. For powers of
    // two this matches the classic halving-stride loop.
    for (uint active = tg_size; active > 1; ) {
        const uint half_active = (active + 1) / 2;
        if (tid + half_active < active) {
            shared[tid] = max(shared[tid], shared[tid + half_active]);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = half_active;
    }
    const float row_max = shared[0];
    // Every thread must read shared[0] before the sum pass overwrites it.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 2: sum exp(x - max) in f32. No half intermediate is stored.
    // Writing exp as half and re-reading it in phase 3 would round each value
    // twice and make the outputs disagree with the f32 sum.
    float local_sum = 0.0f;
    for (uint i = tid; i < cols; i += tg_size) {
        local_sum += exp(float(input[base + i]) - row_max);
    }

    shared[tid] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction for sum (same odd-size-safe folding as above).
    for (uint active = tg_size; active > 1; ) {
        const uint half_active = (active + 1) / 2;
        if (tid + half_active < active) {
            shared[tid] += shared[tid + half_active];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = half_active;
    }
    const float row_sum = shared[0];
    // No trailing barrier: threadgroup memory is not written again.

    // Phase 3: recompute exp in f32, normalize in f32, round to half once.
    const float inv_sum = 1.0f / row_sum;
    for (uint i = tid; i < cols; i += tg_size) {
        output[base + i] = half(exp(float(input[base + i]) - row_max) * inv_sum);
    }
}

/// bf16 variant: one threadgroup computes softmax for row `row_idx`.
///
/// Inputs are promoted to f32 for the max, exp, sum and normalization. Only
/// the final probability is rounded to bfloat. bfloat keeps just an 8-bit
/// significand, so avoiding an extra rounding of a stored intermediate matters
/// even more here than for half. params[0] holds the column count as a float.
/// `shared` must hold at least tg_size floats, one partial per thread. The
/// reductions accept any tg_size >= 1.
kernel void softmax_bf16(
    device const bfloat *input     [[buffer(0)]],
    device bfloat       *output    [[buffer(1)]],
    device const float  *params    [[buffer(2)]],
    uint row_idx   [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint tg_size   [[threads_per_threadgroup]],
    threadgroup float *shared     [[threadgroup(0)]]
) {
    const uint cols = uint(params[0]);
    const uint base = row_idx * cols;

    // Phase 1: find row max (accumulate in f32 for numerical stability).
    float local_max = -INFINITY;
    for (uint i = tid; i < cols; i += tg_size) {
        local_max = max(local_max, static_cast<float>(input[base + i]));
    }

    shared[tid] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction for max. ceil(active / 2) keeps the unpaired slot when
    // `active` is odd, so tg_size need not be a power of two. For powers of
    // two this matches the classic halving-stride loop.
    for (uint active = tg_size; active > 1; ) {
        const uint half_active = (active + 1) / 2;
        if (tid + half_active < active) {
            shared[tid] = max(shared[tid], shared[tid + half_active]);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = half_active;
    }
    const float row_max = shared[0];
    // Every thread must read shared[0] before the sum pass overwrites it.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 2: sum exp(x - max) in f32. No bf16 intermediate is stored; see
    // phase 3.
    float local_sum = 0.0f;
    for (uint i = tid; i < cols; i += tg_size) {
        local_sum += exp(static_cast<float>(input[base + i]) - row_max);
    }

    shared[tid] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction for sum (same odd-size-safe folding as above).
    for (uint active = tg_size; active > 1; ) {
        const uint half_active = (active + 1) / 2;
        if (tid + half_active < active) {
            shared[tid] += shared[tid + half_active];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = half_active;
    }
    const float row_sum = shared[0];
    // No trailing barrier: threadgroup memory is not written again.

    // Phase 3: recompute exp in f32 and normalize in f32, then round to
    // bfloat once. This avoids the bf16 -> f32 -> bf16 double rounding of a
    // stored intermediate.
    const float inv_sum = 1.0f / row_sum;
    for (uint i = tid; i < cols; i += tg_size) {
        output[base + i] = bfloat(exp(static_cast<float>(input[base + i]) - row_max) * inv_sum);
    }
}