// mlx-native 0.1.0
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on
// Apple Silicon.
#include <metal_stdlib>
using namespace metal;

/// Temperature-scaled softmax + categorical sample, entirely on GPU.
///
/// Eliminates the need to transfer the full logits array (up to 1MB for
/// vocab_size=262144) to the CPU for stochastic sampling.
///
/// Algorithm (single threadgroup, 4 passes):
///
///   Pass 1 — find max (subtract for numerical stability after temp scaling):
///     Each thread computes local max over its chunk, tree-reduces to global max.
///
///   Pass 2 — compute exp(logit/T - max) and accumulate partition sum:
///     Each thread computes exp values, writes to a temporary scratch buffer
///     (reusing the input layout), accumulates a local sum, tree-reduces.
///
///   Pass 3 — normalize to probabilities (divide by partition sum):
///     Each thread divides its scratch values by the global sum.
///
///   Pass 4 — compute CDF prefix sum + linear-scan sample:
///     We use a simple sequential scan by thread 0 over the scratch buffer.
///     For large vocab sizes (262144) this adds well under a millisecond vs a
///     full parallel prefix sum, which is acceptable.  A parallel version can
///     be added later.
///
/// Output:
///   out_token  — u32 sampled token index
///   out_logprob — f32 log-probability of the sampled token (ln(prob))
///
/// Buffer layout:
///   buffer(0): logits      — float [n_elements]  input logits (read-only)
///   buffer(1): scratch     — float [n_elements]  temp buffer (probabilities)
///   buffer(2): out_token   — uint  [1]           sampled token
///   buffer(3): out_logprob — float [1]           log-prob of sampled token
///   buffer(4): params      — float [3]           [n_elements_f, temperature, random_val]
///
/// params encoding:
///   params[0] — n_elements cast to float (recast to uint inside kernel)
///   params[1] — temperature  (> 0.0)
///   params[2] — uniform random value in [0, 1) for categorical sampling
///
/// Threadgroup: (min(1024, next_power_of_two(n_elements)), 1, 1)
/// Grid:        (1, 1, 1)
/// Shared mem index 0: tg_size * sizeof(float) for float reduction
/// Shared mem index 1: (unused — reserved for future index tracking)
///
/// NOTE: tg_size must be a power of 2.

kernel void softmax_sample_f32(
    device const float* logits      [[buffer(0)]],
    device float*       scratch     [[buffer(1)]],
    device uint*        out_token   [[buffer(2)]],
    device float*       out_logprob [[buffer(3)]],
    device const float* params      [[buffer(4)]],
    uint tid     [[thread_index_in_threadgroup]],
    uint tg_size [[threads_per_threadgroup]],
    threadgroup float* shared       [[threadgroup(0)]]
) {
    // params[0] carries n_elements as a float; the recast is exact for any
    // vocab size below 2^24 (float integer-precision limit), which covers
    // the documented 262144 case.
    const uint  n_elements  = uint(params[0]);
    const float temperature = params[1];
    const float random_val  = params[2];

    // Multiply by the reciprocal instead of dividing inside the hot loops.
    const float inv_temp = 1.0f / temperature;

    // -----------------------------------------------------------------------
    // Pass 1: find max of (logit / temperature) for numerical stability.
    // Threads with no elements (n_elements < tg_size) keep -INFINITY, which
    // is the identity for the max reduction.
    // -----------------------------------------------------------------------
    float local_max = -INFINITY;
    for (uint i = tid; i < n_elements; i += tg_size) {
        local_max = max(local_max, logits[i] * inv_temp);
    }

    shared[tid] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction over threadgroup memory; requires tg_size to be a
    // power of 2 (see header NOTE).
    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            shared[tid] = max(shared[tid], shared[tid + stride]);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    const float global_max = shared[0];
    // All threads must have read shared[0] before it is overwritten by the
    // next reduction.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // -----------------------------------------------------------------------
    // Pass 2: exp(logit/T - max), write to scratch, accumulate local sum.
    // Each thread touches only its own stride-tg_size slice of scratch, so
    // no cross-thread scratch ordering is needed until pass 4.
    // -----------------------------------------------------------------------
    float local_sum = 0.0f;
    for (uint i = tid; i < n_elements; i += tg_size) {
        float e = exp(logits[i] * inv_temp - global_max);
        scratch[i] = e;
        local_sum += e;
    }

    shared[tid] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            shared[tid] += shared[tid + stride];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    const float partition = shared[0];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // -----------------------------------------------------------------------
    // Pass 3: normalize scratch to probabilities.
    // -----------------------------------------------------------------------
    const float inv_partition = 1.0f / partition;
    for (uint i = tid; i < n_elements; i += tg_size) {
        scratch[i] *= inv_partition;
    }
    // scratch lives in the DEVICE address space and thread 0 is about to
    // read values written by every other thread, so the fence must cover
    // device memory: mem_flags::mem_threadgroup alone does NOT order
    // device-memory accesses, and using it here let thread 0 observe
    // stale/unnormalized probabilities.
    threadgroup_barrier(mem_flags::mem_device);

    // -----------------------------------------------------------------------
    // Pass 4: thread 0 scans the CDF and finds the sampled token.
    //
    // This is O(n_elements) sequential work on a single thread — roughly
    // 262K dependent load-adds for vocab_size=262144, sub-millisecond in
    // practice and much cheaper than a GPU->CPU transfer of the full 1MB
    // logits array.  A parallel prefix-sum + binary search can replace it
    // later if profiling shows it matters.
    // -----------------------------------------------------------------------
    if (tid == 0) {
        float cdf = 0.0f;
        // Fallback: if float rounding leaves the accumulated CDF just below
        // random_val at the end of the scan, pick the last token.
        uint tok = n_elements - 1;
        for (uint i = 0; i < n_elements; i++) {
            cdf += scratch[i];
            if (cdf >= random_val) {
                tok = i;
                break;
            }
        }
        out_token[0] = tok;
        // Clamp to FLT_MIN so a zero-probability fallback token yields a
        // large-but-finite log-prob instead of log(0) = -inf.
        out_logprob[0] = log(max(scratch[tok], FLT_MIN));  // ln(prob)
    }
}