mlx-native 0.1.0

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
#include <metal_stdlib>
using namespace metal;

// --------------------------------------------------------------------------
// moe_expert_ffn — Single-expert FFN: gate_proj + up_proj -> GELU -> down_proj
//
// For a single expert, computes:
//   gate_out = gate_proj(x)        [input_dim -> intermediate_dim]
//   up_out   = up_proj(x)          [input_dim -> intermediate_dim]
//   hidden   = GELU(gate_out) * up_out
//   out      = down_proj(hidden)   [intermediate_dim -> input_dim]
//
// This shader works on float (f32) dequantized weights.  The Rust host
// is responsible for providing pre-dequantized weight slices for the
// selected expert, OR this shader is called after dequantization.
//
// For Stage 1, the Rust host loops over selected experts and dispatches
// the quantized_matmul kernel for each projection, then calls this
// elementwise kernel for the GELU + multiply fusion.
//
// This shader does the fused GELU-multiply: hidden = GELU(gate_out) * up_out
// --------------------------------------------------------------------------

// GELU approximation matching PyTorch/MLX: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// Clamp threshold for tanh argument to prevent NaN from exp() overflow.
// tanh saturates at +/-1 well before |x| = 10, so clamping at 15 is safe.
inline float gelu_approx(float x) {
    const float sqrt_2_over_pi = 0.7978845608028654f;
    const float tanh_clamp = 15.0f;
    float x3 = x * x * x;
    float inner = sqrt_2_over_pi * (x + 0.044715f * x3);
    inner = clamp(inner, -tanh_clamp, tanh_clamp);
    return 0.5f * x * (1.0f + tanh(inner));
}

struct FusedGeluMulParams {
    uint n_elements;
};

// Fused GELU(gate_out) * up_out
// Buffers:
//   0: gate_out  — float [n_elements]  (input, will be overwritten with result)
//   1: up_out    — float [n_elements]  (input)
//   2: output    — float [n_elements]  (output: GELU(gate_out) * up_out)
//   3: params    — { n_elements }
kernel void fused_gelu_mul(
    device const float* gate_out [[buffer(0)]],
    device const float* up_out   [[buffer(1)]],
    device float*       output   [[buffer(2)]],
    constant FusedGeluMulParams& params [[buffer(3)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid >= params.n_elements) return;
    output[gid] = gelu_approx(gate_out[gid]) * up_out[gid];
}

// --------------------------------------------------------------------------
// moe_accumulate — Weighted accumulation: result += weight * expert_output
//
// Used by the Rust host to accumulate expert outputs with routing weights.
//
// Buffers:
//   0: accumulator   — float [n_elements]  (in/out)
//   1: expert_output — float [n_elements]  (input)
//   2: params        — { n_elements, routing_weight }
// --------------------------------------------------------------------------

struct MoeAccumParams {
    uint n_elements;
    float routing_weight;
};

kernel void moe_accumulate(
    device float*       accumulator   [[buffer(0)]],
    device const float* expert_output [[buffer(1)]],
    constant MoeAccumParams& params   [[buffer(2)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid >= params.n_elements) return;
    accumulator[gid] += params.routing_weight * expert_output[gid];
}

// --------------------------------------------------------------------------
// zero_buffer — Zero-initialize a float buffer
//
// Buffers:
//   0: buffer      — float [n_elements]
//   1: params      — { n_elements }
// --------------------------------------------------------------------------

struct ZeroParams {
    uint n_elements;
};

kernel void zero_buffer(
    device float* buffer [[buffer(0)]],
    constant ZeroParams& params [[buffer(1)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid >= params.n_elements) return;
    buffer[gid] = 0.0f;
}

// --------------------------------------------------------------------------
// moe_swiglu_fused — SwiGLU on a fused [gate, up] buffer.
//
// Takes a buffer of 2*N elements where:
//   - First N elements are the gate projection output
//   - Last N elements are the up projection output
// Produces N elements: GELU(gate[i]) * up[i]
//
// This is the fused variant for models that use a single gate_up projection
// (e.g., Gemma 4 MoE experts with fused gate/up weights).
//
// Buffers:
//   0: gate_up — float [2 * N] (input: gate || up, concatenated)
//   1: output  — float [N]     (output: GELU(gate) * up)
//   2: params  — { n_elements = N }
// --------------------------------------------------------------------------

kernel void moe_swiglu_fused(
    device const float* gate_up [[buffer(0)]],
    device float*       output  [[buffer(1)]],
    constant FusedGeluMulParams& params [[buffer(2)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid >= params.n_elements) return;
    float gate_val = gate_up[gid];
    float up_val   = gate_up[params.n_elements + gid];
    output[gid] = gelu_approx(gate_val) * up_val;
}

// --------------------------------------------------------------------------
// moe_swiglu_batch — Batched SwiGLU across all top_k expert slots.
//
// Takes a [top_k, 2*intermediate] buffer where for each slot k:
//   - gate values are at [k, 0..intermediate)
//   - up values are at [k, intermediate..2*intermediate)
// Produces [top_k, intermediate] output: GELU(gate[i]) * up[i] per slot.
//
// Grid: 2D — x=element within intermediate, y=expert slot.
// Replaces top_k separate moe_swiglu_fused dispatches with 1.
// --------------------------------------------------------------------------

kernel void moe_swiglu_batch(
    device const float* gate_up_buf  [[buffer(0)]],  // [top_k, 2*intermediate]
    device float*       output_buf   [[buffer(1)]],  // [top_k, intermediate]
    constant uint&      intermediate [[buffer(2)]],
    constant uint&      top_k        [[buffer(3)]],
    uint2 tid [[thread_position_in_grid]]             // x=element, y=slot
) {
    uint i = tid.x;
    uint slot = tid.y;
    if (slot >= top_k || i >= intermediate) return;

    uint base = slot * 2 * intermediate;
    float gate = gate_up_buf[base + i];
    float up = gate_up_buf[base + intermediate + i];
    // SwiGLU = GELU(gate) * up
    float gelu = gate * 0.5f * (1.0f + precise::tanh(
        0.7978845608f * (gate + 0.044715f * gate * gate * gate)));
    output_buf[slot * intermediate + i] = gelu * up;
}

// --------------------------------------------------------------------------
// moe_weighted_sum — Weighted sum of all top_k expert outputs in one dispatch.
//
// Replaces the zero_buffer + top_k * moe_accumulate pattern (9 dispatches)
// with a single dispatch that reads all expert outputs and routing weights.
//
// Grid: 1D — each thread computes one element of the output.
// --------------------------------------------------------------------------

struct MoeWeightedSumParams {
    uint hidden_size;
    uint top_k;
};

kernel void moe_weighted_sum(
    device const float*  expert_outputs [[buffer(0)]],  // [top_k, hidden_size]
    device const float*  weights        [[buffer(1)]],  // [top_k]
    device float*        output         [[buffer(2)]],  // [hidden_size]
    constant MoeWeightedSumParams& params [[buffer(3)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= params.hidden_size) return;
    float sum = 0.0f;
    for (uint k = 0; k < params.top_k; k++) {
        sum += expert_outputs[k * params.hidden_size + tid] * weights[k];
    }
    output[tid] = sum;
}

// --------------------------------------------------------------------------
// naive_matvec_f32 — Simple matrix-vector multiply for expert projections.
//
// Computes: output[row] = dot(weight[row, :], input[:])
//
// weight is [N, K] row-major, input is [K], output is [N].
// Each thread computes one output element.
//
// This is a naive implementation for Stage 1.  For large matrices,
// the quantized_matmul kernel from Story 1.2 would be used instead.
//
// Buffers:
//   0: weight — float [N, K] row-major
//   1: input  — float [K]
//   2: output — float [N]
//   3: params — { m (unused, always 1), k, n }
// --------------------------------------------------------------------------

struct MatvecParams {
    uint m;  // unused for matvec, included for struct compatibility
    uint k;  // inner dimension
    uint n;  // output dimension (number of rows in weight)
};

kernel void naive_matvec_f32(
    device const float* weight [[buffer(0)]],
    device const float* input  [[buffer(1)]],
    device float*       output [[buffer(2)]],
    constant MatvecParams& params [[buffer(3)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid >= params.n) return;

    float sum = 0.0f;
    device const float* row = weight + gid * params.k;
    for (uint i = 0; i < params.k; i++) {
        sum += row[i] * input[i];
    }
    output[gid] = sum;
}

// --------------------------------------------------------------------------
// moe_gather_topk_weights — Gather softmax probs at top-K sorted indices,
// multiply by per_expert_scale, and renormalize.
//
// Replaces the CPU softmax+argsort+gather+scale+renorm sequence for MoE
// routing.  Runs on GPU so the session never needs to break.
//
// Inputs:
//   softmax_probs   — f32 [n_experts]  (softmax output, from dispatch_softmax)
//   sorted_indices  — u32 [n_experts]  (descending sort, from dispatch_argsort)
//   per_expert_scale— f32 [n_experts]  (per-expert learned scale)
//
// Outputs:
//   out_expert_ids  — u32 [top_k]  (selected expert indices)
//   out_weights     — f32 [top_k]  (pre-scaled, renormalized routing weights)
//
// Grid: single thread (top_k <= 8, trivial work).
// --------------------------------------------------------------------------

struct MoeGatherTopkParams {
    uint n_experts;
    uint top_k;
};

kernel void moe_gather_topk_weights(
    device const float*  softmax_probs    [[buffer(0)]],
    device const uint*   sorted_indices   [[buffer(1)]],
    device const float*  per_expert_scale [[buffer(2)]],
    device uint*         out_expert_ids   [[buffer(3)]],
    device float*        out_weights      [[buffer(4)]],
    constant MoeGatherTopkParams& params  [[buffer(5)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid != 0) return;

    // 1. Gather top-K expert ids and their softmax probabilities
    float top_probs[8];   // max top_k = 8
    float prob_sum = 0.0f;
    for (uint k = 0; k < params.top_k; k++) {
        uint eid = sorted_indices[k];
        out_expert_ids[k] = eid;
        top_probs[k] = softmax_probs[eid];
        prob_sum += top_probs[k];
    }

    // 2. Renormalize and apply per_expert_scale
    float inv_sum = (prob_sum > 0.0f) ? (1.0f / prob_sum) : 0.0f;
    for (uint k = 0; k < params.top_k; k++) {
        uint eid = out_expert_ids[k];
        out_weights[k] = (top_probs[k] * inv_sum) * per_expert_scale[eid];
    }
}