mlx-native 0.1.3

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
#include <metal_stdlib>
using namespace metal;

// --------------------------------------------------------------------------
// moe_gate — Parallel top-K expert routing with softmax weights.
//
// One threadgroup per token. Written for 128 threads per group, but any
// threadgroup size works provided tg_size <= hidden_dim + n_experts (see
// "Threadgroup shared memory" below); tg_size need not be a power of two.
//
// Algorithm per token:
//   1. RMS-Norm of the hidden state (parallel tree reduction in threadgroup
//      memory), multiplied element-wise by norm_weight.
//   2. Router matmul: thread `tid` handles experts tid, tid + tg_size, ...
//      (ceil(n_experts / tg_size) experts per thread).
//      Normed hidden is stored in threadgroup shared memory (tg_hidden).
//      Logits are accumulated in threadgroup shared memory (tg_logits).
//   3. Top-K selection on a single thread by repeated argmax (K passes over
//      the n_experts logits; ties go to the lower expert index). Experts are
//      emitted in descending logit order.
//   4. Softmax over the K selected logits, multiply each probability by
//      per_expert_scale[expert], then re-normalize so the K weights sum to 1.
//
// Limits: n_experts and top_k are not bounded by any fixed-size local array.
// If top_k > n_experts, only n_experts experts are routed and the trailing
// output slots are written as (expert_id = 0, weight = 0).
//
// Buffers:
//   0: hidden_state      — bfloat [seq_len, hidden_dim]
//   1: router_weights    — float  [n_experts, hidden_dim] (row-major)
//   2: norm_weight       — float  [hidden_dim]
//   3: per_expert_scale  — float  [n_experts]
//   4: expert_ids        — uint   [seq_len, top_k]  (output)
//   5: expert_weights    — float  [seq_len, top_k]  (output)
//   6: hidden_dim        — constant uint
//   7: n_experts         — constant uint
//   8: top_k             — constant uint
//   9: rms_eps           — constant float
//
// Threadgroup shared memory (index 0), at least (hidden_dim + n_experts) floats:
//   [0 .. hidden_dim)                      — float: normed hidden state
//   [hidden_dim .. hidden_dim+n_experts)   — float: router logits
//   During step 1a, [0 .. tg_size) is also used as reduction scratch, hence
//   the requirement tg_size <= hidden_dim + n_experts.
//
// Grid:     (seq_len, 1, 1)   — one threadgroup per token
// Threads:  (128, 1, 1)       — recommended; see size requirement above
// --------------------------------------------------------------------------

kernel void moe_gate(
    device const bfloat* hidden_state       [[buffer(0)]],
    device const float*  router_weights     [[buffer(1)]],
    device const float*  norm_weight        [[buffer(2)]],
    device const float*  per_expert_scale   [[buffer(3)]],
    device uint*         expert_ids         [[buffer(4)]],
    device float*        expert_weights     [[buffer(5)]],
    constant uint&       hidden_dim         [[buffer(6)]],
    constant uint&       n_experts          [[buffer(7)]],
    constant uint&       top_k              [[buffer(8)]],
    constant float&      rms_eps            [[buffer(9)]],
    uint  tid    [[thread_index_in_threadgroup]],
    uint  token  [[threadgroup_position_in_grid]],
    uint  tg_size [[threads_per_threadgroup]],
    threadgroup float* shared               [[threadgroup(0)]]
) {
    // shared layout:
    //   shared[0 .. hidden_dim)                    — normed hidden (f32)
    //   shared[hidden_dim .. hidden_dim+n_experts) — logits (f32)
    threadgroup float* tg_hidden = shared;
    threadgroup float* tg_logits = shared + hidden_dim;

    const uint token_base = token * hidden_dim;

    // -----------------------------------------------------------------------
    // Phase 1: RMS Norm
    //   Compute rms_inv in parallel, store normed*weight in tg_hidden.
    // -----------------------------------------------------------------------

    // Step 1a: per-thread partial sum of squares.
    float partial_sq = 0.0f;
    for (uint i = tid; i < hidden_dim; i += tg_size) {
        const float v = static_cast<float>(hidden_state[token_base + i]);
        partial_sq += v * v;
    }

    // Reduction scratch starts at the base of the shared buffer: tg_hidden is
    // not written until step 1b, so the whole (hidden_dim + n_experts)-float
    // region is free here and only tg_size <= hidden_dim + n_experts is needed.
    threadgroup float* reduce_scratch = shared;
    reduce_scratch[tid] = partial_sq;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction valid for any tg_size: each round folds the upper part
    // [half_n, active) onto [0, active - half_n). For power-of-two sizes this
    // is the classic stride-halving reduction. `active` is uniform across the
    // threadgroup, so every thread reaches every barrier.
    for (uint active = tg_size; active > 1; ) {
        const uint half_n = (active + 1) / 2;
        if (tid + half_n < active) {
            reduce_scratch[tid] += reduce_scratch[tid + half_n];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = half_n;
    }

    const float rms_inv = rsqrt(reduce_scratch[0] / float(hidden_dim) + rms_eps);
    // Every thread must read reduce_scratch[0] before step 1b overwrites it.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 1b: normalize and multiply by norm_weight -> store in tg_hidden.
    for (uint i = tid; i < hidden_dim; i += tg_size) {
        const float v = static_cast<float>(hidden_state[token_base + i]);
        tg_hidden[i] = v * rms_inv * norm_weight[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // -----------------------------------------------------------------------
    // Phase 2: Router matmul
    //   Each thread computes dot(tg_hidden, router_weights[e]) for its experts.
    //   NOTE(review): adjacent threads read rows hidden_dim floats apart, so
    //   these device loads are not contiguous across a SIMD group; a
    //   simdgroup-per-expert layout may be faster — profile before changing.
    // -----------------------------------------------------------------------
    for (uint e = tid; e < n_experts; e += tg_size) {
        float dot = 0.0f;
        device const float* w_row = router_weights + e * hidden_dim;
        for (uint d = 0; d < hidden_dim; d++) {
            dot += tg_hidden[d] * w_row[d];
        }
        tg_logits[e] = dot;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // -----------------------------------------------------------------------
    // Phase 3: Top-K + softmax + per_expert_scale (single thread, tid == 0)
    // -----------------------------------------------------------------------
    if (tid == 0) {
        const uint out_base = token * top_k;
        device uint*  ids = expert_ids + out_base;
        device float* wts = expert_weights + out_base;

        // Number of experts actually routed; guards top_k > n_experts, which
        // would otherwise re-select an already chosen expert.
        const uint n_sel = min(top_k, n_experts);

        // 3a. Top-K by repeated argmax. Instead of a fixed-size "selected"
        // mask, each pass only considers logits ranking strictly after the
        // previous pick in (logit descending, index ascending) order, which
        // excludes exactly the experts already chosen. wts[] temporarily
        // holds the selected logits (this thread owns the token's out slice).
        float prev_val  = 0.0f;
        uint  prev_idx  = 0;
        float max_logit = -INFINITY;
        for (uint k = 0; k < n_sel; k++) {
            bool  found    = false;
            float best_val = -INFINITY;
            uint  best_idx = 0;
            for (uint e = 0; e < n_experts; e++) {
                const float v = tg_logits[e];
                const bool unpicked = (k == 0) || (v < prev_val) ||
                                      (v == prev_val && e > prev_idx);
                // Strict '>' keeps the lowest index among equal logits.
                if (unpicked && (!found || v > best_val)) {
                    found    = true;
                    best_val = v;
                    best_idx = e;
                }
            }
            // Because n_sel <= n_experts, a candidate always exists unless
            // the logits contain NaN (NaN compares false against everything).
            // Fall back to a valid expert id with a -inf logit so the emitted
            // id is never out of range.
            if (!found) {
                best_val = -INFINITY;
                best_idx = 0;
            }
            ids[k]    = best_idx;
            wts[k]    = best_val;
            prev_val  = best_val;
            prev_idx  = best_idx;
            max_logit = max(max_logit, best_val);
        }

        // 3b. Softmax over the selected logits, stabilized by subtracting
        // the largest selected logit.
        float sum_exp = 0.0f;
        for (uint k = 0; k < n_sel; k++) {
            const float ex = exp(wts[k] - max_logit);
            wts[k]   = ex;
            sum_exp += ex;
        }
        const float inv_sum = 1.0f / sum_exp;

        // 3c. Scale each softmax probability by its expert's
        // per_expert_scale, then re-normalize so the weights sum to 1.
        float scale_sum = 0.0f;
        for (uint k = 0; k < n_sel; k++) {
            const float scaled = (wts[k] * inv_sum) * per_expert_scale[ids[k]];
            wts[k]     = scaled;
            scale_sum += scaled;
        }
        const float inv_scale_sum = 1.0f / scale_sum;
        for (uint k = 0; k < n_sel; k++) {
            wts[k] *= inv_scale_sum;
        }

        // 3d. Slots beyond n_sel (only when top_k > n_experts) get a valid
        // expert id and zero weight so every output slot is written.
        for (uint k = n_sel; k < top_k; k++) {
            ids[k] = 0;
            wts[k] = 0.0f;
        }
    }
}