mlx-native 0.6.2

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
#include <metal_stdlib>
using namespace metal;

// --------------------------------------------------------------------------
// moe_expert_ffn — Single-expert FFN: gate_proj + up_proj -> GELU -> down_proj
//
// For a single expert, computes:
//   gate_out = gate_proj(x)        [input_dim -> intermediate_dim]
//   up_out   = up_proj(x)          [input_dim -> intermediate_dim]
//   hidden   = GELU(gate_out) * up_out
//   out      = down_proj(hidden)   [intermediate_dim -> input_dim]
//
// This shader works on float (f32) dequantized weights.  The Rust host
// is responsible for providing pre-dequantized weight slices for the
// selected expert, OR this shader is called after dequantization.
//
// For Stage 1, the Rust host loops over selected experts and dispatches
// the quantized_matmul kernel for each projection, then calls this
// elementwise kernel for the GELU + multiply fusion.
//
// This shader does the fused GELU-multiply: hidden = GELU(gate_out) * up_out
// --------------------------------------------------------------------------

// GELU approximation matching PyTorch/MLX: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// Clamp threshold for tanh argument to prevent NaN from exp() overflow.
// tanh saturates at +/-1 well before |x| = 10, so clamping at 15 is safe.
// Tanh-approximation GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3))).
// The tanh argument is clamped so a fast-math tanh cannot produce NaN via
// exp() overflow; tanh already saturates at +/-1 long before |arg| = 15,
// so the clamp is numerically a no-op.
inline float gelu_approx(float x) {
    const float kSqrt2OverPi = 0.7978845608028654f;
    const float kTanhClamp   = 15.0f;
    const float cubic = x * x * x;
    float t = kSqrt2OverPi * (x + 0.044715f * cubic);
    t = clamp(t, -kTanhClamp, kTanhClamp);
    return 0.5f * x * (1.0f + tanh(t));
}

// Uniform parameters shared by fused_gelu_mul, fused_gelu_mul_bf16 and
// moe_swiglu_fused dispatches.
struct FusedGeluMulParams {
    uint n_elements;  // number of output elements; threads at/past this index are no-ops
};

// Fused GELU(gate_out) * up_out
// Buffers:
//   0: gate_out  — float [n_elements]  (input)
//   1: up_out    — float [n_elements]  (input)
//   2: output    — float [n_elements]  (output: GELU(gate_out) * up_out)
//   3: params    — { n_elements }
// Elementwise hidden[i] = GELU(gate_out[i]) * up_out[i].
// Grid: 1D over n_elements; tail threads past n_elements do nothing.
kernel void fused_gelu_mul(
    device const float* gate_out [[buffer(0)]],
    device const float* up_out   [[buffer(1)]],
    device float*       output   [[buffer(2)]],
    constant FusedGeluMulParams& params [[buffer(3)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid < params.n_elements) {
        const float g = gate_out[gid];
        const float u = up_out[gid];
        output[gid] = gelu_approx(g) * u;
    }
}

// --------------------------------------------------------------------------
// moe_accumulate — Weighted accumulation: result += weight * expert_output
//
// Used by the Rust host to accumulate expert outputs with routing weights.
//
// Buffers:
//   0: accumulator   — float [n_elements]  (in/out)
//   1: expert_output — float [n_elements]  (input)
//   2: params        — { n_elements, routing_weight }
// --------------------------------------------------------------------------

// Uniform parameters for moe_accumulate.
struct MoeAccumParams {
    uint n_elements;       // elements in accumulator / expert_output
    float routing_weight;  // scalar weight applied to expert_output before accumulation
};

// In-place weighted accumulation: accumulator[i] += routing_weight * expert_output[i].
// Grid: 1D over n_elements.
kernel void moe_accumulate(
    device float*       accumulator   [[buffer(0)]],
    device const float* expert_output [[buffer(1)]],
    constant MoeAccumParams& params   [[buffer(2)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid < params.n_elements) {
        accumulator[gid] += params.routing_weight * expert_output[gid];
    }
}

// --------------------------------------------------------------------------
// zero_buffer — Zero-initialize a float buffer
//
// Buffers:
//   0: buffer      — float [n_elements]
//   1: params      — { n_elements }
// --------------------------------------------------------------------------

// Uniform parameters for zero_buffer.
struct ZeroParams {
    uint n_elements;  // number of floats to clear
};

// Zero-fill a float buffer.  Grid: 1D over n_elements; each thread
// clears exactly one element, tail threads do nothing.
kernel void zero_buffer(
    device float* buffer [[buffer(0)]],
    constant ZeroParams& params [[buffer(1)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid < params.n_elements) {
        buffer[gid] = 0.0f;
    }
}

// --------------------------------------------------------------------------
// moe_swiglu_fused — SwiGLU on a fused [gate, up] buffer.
//
// Takes a buffer of 2*N elements where:
//   - First N elements are the gate projection output
//   - Last N elements are the up projection output
// Produces N elements: GELU(gate[i]) * up[i]
//
// This is the fused variant for models that use a single gate_up projection
// (e.g., Gemma 4 MoE experts with fused gate/up weights).
//
// Buffers:
//   0: gate_up — float [2 * N] (input: gate || up, concatenated)
//   1: output  — float [N]     (output: GELU(gate) * up)
//   2: params  — { n_elements = N }
// --------------------------------------------------------------------------

// SwiGLU over a concatenated [gate || up] buffer of 2*N floats:
// output[i] = GELU(gate_up[i]) * gate_up[N + i].
// Grid: 1D over N (= params.n_elements).
kernel void moe_swiglu_fused(
    device const float* gate_up [[buffer(0)]],
    device float*       output  [[buffer(1)]],
    constant FusedGeluMulParams& params [[buffer(2)]],
    uint gid [[thread_position_in_grid]]
) {
    const uint n = params.n_elements;
    if (gid >= n) return;
    // Layout: gate values occupy [0, n), up values occupy [n, 2n).
    const float g = gate_up[gid];
    const float u = gate_up[n + gid];
    output[gid] = gelu_approx(g) * u;
}

// --------------------------------------------------------------------------
// moe_swiglu_batch — Batched SwiGLU across all top_k expert slots.
//
// Takes a [top_k, 2*intermediate] buffer where for each slot k:
//   - gate values are at [k, 0..intermediate)
//   - up values are at [k, intermediate..2*intermediate)
// Produces [top_k, intermediate] output: GELU(gate[i]) * up[i] per slot.
//
// Grid: 2D — x=element within intermediate, y=expert slot.
// Replaces top_k separate moe_swiglu_fused dispatches with 1.
// --------------------------------------------------------------------------

// Batched SwiGLU over all top_k expert slots in one dispatch.
// Input rows are [gate || up] pairs of width 2*intermediate; output rows
// have width intermediate.  Grid: 2D (x = element, y = slot).
// NOTE: uses precise::tanh inline (not gelu_approx) so saturation is
// handled by the precise intrinsic rather than an explicit clamp.
kernel void moe_swiglu_batch(
    device const float* gate_up_buf  [[buffer(0)]],  // [top_k, 2*intermediate]
    device float*       output_buf   [[buffer(1)]],  // [top_k, intermediate]
    constant uint&      intermediate [[buffer(2)]],
    constant uint&      top_k        [[buffer(3)]],
    uint2 tid [[thread_position_in_grid]]             // x=element, y=slot
) {
    const uint elem = tid.x;
    const uint k    = tid.y;
    if (k >= top_k || elem >= intermediate) return;

    const uint row = k * 2 * intermediate;
    const float g = gate_up_buf[row + elem];
    const float u = gate_up_buf[row + intermediate + elem];
    // SwiGLU = GELU(g) * u, tanh-approximation form
    const float th = precise::tanh(
        0.7978845608f * (g + 0.044715f * g * g * g));
    output_buf[k * intermediate + elem] = (g * 0.5f * (1.0f + th)) * u;
}

// --------------------------------------------------------------------------
// moe_weighted_sum — Weighted sum of all top_k expert outputs in one dispatch.
//
// Replaces the zero_buffer + top_k * moe_accumulate pattern (1 + top_k
// dispatches, e.g. 9 at top_k = 8) with a single dispatch that reads all
// expert outputs and routing weights.
//
// Grid: 1D — each thread computes one element of the output.
// --------------------------------------------------------------------------

// Uniform parameters for moe_weighted_sum.
struct MoeWeightedSumParams {
    uint hidden_size;  // elements per expert output row
    uint top_k;        // number of expert rows to sum
};

// Single-dispatch weighted sum: output[i] = sum_k weights[k] * expert_outputs[k][i].
// Grid: 1D over hidden_size; each thread owns one output element and
// walks all top_k expert rows.
kernel void moe_weighted_sum(
    device const float*  expert_outputs [[buffer(0)]],  // [top_k, hidden_size]
    device const float*  weights        [[buffer(1)]],  // [top_k]
    device float*        output         [[buffer(2)]],  // [hidden_size]
    constant MoeWeightedSumParams& params [[buffer(3)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= params.hidden_size) return;

    float acc = 0.0f;
    for (uint slot = 0u; slot < params.top_k; ++slot) {
        acc += expert_outputs[slot * params.hidden_size + tid] * weights[slot];
    }
    output[tid] = acc;
}

// --------------------------------------------------------------------------
// fused_gelu_mul_bf16 — bf16 variant of fused_gelu_mul.
//
// Inputs gate_out and up_out are bfloat16; output is bfloat16.
// All arithmetic is promoted to f32 (accumulate in f32, store in bf16).
//
// Buffers:
//   0: gate_out  — bfloat [n_elements]
//   1: up_out    — bfloat [n_elements]
//   2: output    — bfloat [n_elements]  (GELU(gate_out) * up_out)
//   3: params    — { n_elements }
// --------------------------------------------------------------------------

// bf16 variant of fused_gelu_mul: inputs and output are bfloat16, all
// arithmetic is widened to f32 for accuracy and narrowed only on store.
// Grid: 1D over n_elements.
kernel void fused_gelu_mul_bf16(
    device const bfloat* gate_out [[buffer(0)]],
    device const bfloat* up_out   [[buffer(1)]],
    device bfloat*       output   [[buffer(2)]],
    constant FusedGeluMulParams& params [[buffer(3)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid < params.n_elements) {
        // Widen to f32 so GELU runs at full precision; narrow on store.
        const float gate_f = static_cast<float>(gate_out[gid]);
        const float up_f   = static_cast<float>(up_out[gid]);
        output[gid] = bfloat(gelu_approx(gate_f) * up_f);
    }
}

/// Multi-token SwiGLU for batched prefill.
/// Input:  [n_tokens, top_k, 2*intermediate]
/// Output: [n_tokens, top_k, intermediate]
/// Grid:   3D (intermediate, top_k, n_tokens)
// Uniform parameters for moe_swiglu_seq / moe_swiglu_seq_bf16.
struct MoeSwigluSeqParams {
    uint intermediate;  // per-slot output width (input rows are 2x this)
    uint top_k;         // expert slots per token
    uint n_tokens;      // tokens in the batch
};
// Multi-token SwiGLU for batched prefill.
// Input:  [n_tokens, top_k, 2*intermediate] ([gate || up] per slot)
// Output: [n_tokens, top_k, intermediate]
// Grid:   3D (x = element, y = slot, z = token)
kernel void moe_swiglu_seq(
    device const float* gate_up_buf  [[buffer(0)]],
    device float*       output_buf   [[buffer(1)]],
    constant MoeSwigluSeqParams& params [[buffer(2)]],
    uint3 tid [[thread_position_in_grid]]
) {
    const uint elem = tid.x;
    const uint k    = tid.y;
    const uint t    = tid.z;
    if (t >= params.n_tokens || k >= params.top_k || elem >= params.intermediate) return;

    // Flattened (token, slot) row index into both buffers.
    const uint row     = t * params.top_k + k;
    const uint in_base = row * 2 * params.intermediate;
    const float g = gate_up_buf[in_base + elem];
    const float u = gate_up_buf[in_base + params.intermediate + elem];
    // SwiGLU = GELU(g) * u, tanh-approximation form
    const float th = precise::tanh(
        0.7978845608f * (g + 0.044715f * g * g * g));
    output_buf[row * params.intermediate + elem] = (g * 0.5f * (1.0f + th)) * u;
}

/// Multi-token weighted sum of expert outputs.
/// Input:  expert_outputs [n_tokens, top_k, hidden_size]
///         weights        [n_tokens, top_k]
/// Output: [n_tokens, hidden_size]
// Uniform parameters for moe_weighted_sum_seq and the bf16-input variant.
struct MoeWeightedSumSeqParams {
    uint hidden_size;  // elements per expert output row
    uint top_k;        // expert slots per token
    uint n_tokens;     // tokens in the batch
};
// Multi-token weighted sum of expert outputs.
// Input:  expert_outputs [n_tokens, top_k, hidden_size], weights [n_tokens, top_k]
// Output: [n_tokens, hidden_size]
// Grid:   2D (x = hidden dim element, y = token)
kernel void moe_weighted_sum_seq(
    device const float*  expert_outputs [[buffer(0)]],
    device const float*  weights        [[buffer(1)]],
    device float*        output         [[buffer(2)]],
    constant MoeWeightedSumSeqParams& params [[buffer(3)]],
    uint2 tid [[thread_position_in_grid]]
) {
    const uint dim = tid.x;
    const uint t   = tid.y;
    if (t >= params.n_tokens || dim >= params.hidden_size) return;

    float acc = 0.0f;
    for (uint s = 0u; s < params.top_k; ++s) {
        // Flattened (token, slot) index shared by both input buffers.
        const uint slot = t * params.top_k + s;
        acc += expert_outputs[slot * params.hidden_size + dim] * weights[slot];
    }
    output[t * params.hidden_size + dim] = acc;
}

// --------------------------------------------------------------------------
// moe_swiglu_seq_bf16 — bf16 variant of moe_swiglu_seq.
//
// Input:  [n_tokens, top_k, 2*intermediate] bfloat16
// Output: [n_tokens, top_k, intermediate]   bfloat16
// Grid:   3D (intermediate, top_k, n_tokens)
//
// All arithmetic in f32; inputs read as bf16, output stored as bf16.
// --------------------------------------------------------------------------

// bf16 variant of moe_swiglu_seq: bf16 in, bf16 out, f32 arithmetic.
// Input:  [n_tokens, top_k, 2*intermediate] bfloat16
// Output: [n_tokens, top_k, intermediate]   bfloat16
// Grid:   3D (x = element, y = slot, z = token)
kernel void moe_swiglu_seq_bf16(
    device const bfloat* gate_up_buf  [[buffer(0)]],
    device bfloat*       output_buf   [[buffer(1)]],
    constant MoeSwigluSeqParams& params [[buffer(2)]],
    uint3 tid [[thread_position_in_grid]]
) {
    const uint elem = tid.x;
    const uint k    = tid.y;
    const uint t    = tid.z;
    if (t >= params.n_tokens || k >= params.top_k || elem >= params.intermediate) return;

    const uint row     = t * params.top_k + k;
    const uint in_base = row * 2u * params.intermediate;
    // Widen both operands to f32 before any arithmetic.
    const float g = static_cast<float>(gate_up_buf[in_base + elem]);
    const float u = static_cast<float>(gate_up_buf[in_base + params.intermediate + elem]);
    // SwiGLU = GELU(g) * u; compute in f32, store as bf16
    const float th = precise::tanh(
        0.7978845608f * (g + 0.044715f * g * g * g));
    output_buf[row * params.intermediate + elem] = bfloat((g * 0.5f * (1.0f + th)) * u);
}

// --------------------------------------------------------------------------
// moe_weighted_sum_seq_bf16_input — bf16 expert_outputs, f32 weights/output.
//
// Matches the MoE convention: expert intermediates are bf16, but the
// weighted accumulator (pf_moe_accum) stays f32 for residual precision.
//
// Input:  expert_outputs [n_tokens, top_k, hidden_size] bfloat16
//         weights        [n_tokens, top_k]              float32
// Output: [n_tokens, hidden_size]                       float32
// --------------------------------------------------------------------------

// Weighted sum with bf16 expert outputs and f32 weights/accumulator.
// Input:  expert_outputs [n_tokens, top_k, hidden_size] bfloat16
//         weights        [n_tokens, top_k]              float32
// Output: [n_tokens, hidden_size]                       float32
// Grid:   2D (x = hidden dim element, y = token)
kernel void moe_weighted_sum_seq_bf16_input(
    device const bfloat* expert_outputs [[buffer(0)]],
    device const float*  weights        [[buffer(1)]],
    device float*        output         [[buffer(2)]],
    constant MoeWeightedSumSeqParams& params [[buffer(3)]],
    uint2 tid [[thread_position_in_grid]]
) {
    const uint dim = tid.x;
    const uint t   = tid.y;
    if (t >= params.n_tokens || dim >= params.hidden_size) return;

    float acc = 0.0f;
    for (uint s = 0u; s < params.top_k; ++s) {
        const uint slot = t * params.top_k + s;
        // Promote the bf16 expert value to f32 before the multiply-accumulate.
        acc += static_cast<float>(expert_outputs[slot * params.hidden_size + dim])
             * weights[slot];
    }
    output[t * params.hidden_size + dim] = acc;
}

// --------------------------------------------------------------------------
// naive_matvec_f32 — Simple matrix-vector multiply for expert projections.
//
// Computes: output[row] = dot(weight[row, :], input[:])
//
// weight is [N, K] row-major, input is [K], output is [N].
// Each thread computes one output element.
//
// This is a naive implementation for Stage 1.  For large matrices,
// the quantized_matmul kernel from Story 1.2 would be used instead.
//
// Buffers:
//   0: weight — float [N, K] row-major
//   1: input  — float [K]
//   2: output — float [N]
//   3: params — { m (unused, always 1), k, n }
// --------------------------------------------------------------------------

// Uniform parameters for naive_matvec_f32.
struct MatvecParams {
    uint m;  // unused for matvec, included for struct compatibility
    uint k;  // inner dimension
    uint n;  // output dimension (number of rows in weight)
};

// Naive matrix-vector multiply: output[row] = dot(weight[row, :], input).
// weight is [n, k] row-major.  Grid: 1D, one thread per output row.
kernel void naive_matvec_f32(
    device const float* weight [[buffer(0)]],
    device const float* input  [[buffer(1)]],
    device float*       output [[buffer(2)]],
    constant MatvecParams& params [[buffer(3)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid >= params.n) return;

    // Pointer to the start of this thread's row of the weight matrix.
    device const float* w_row = weight + gid * params.k;
    float acc = 0.0f;
    for (uint j = 0u; j < params.k; ++j) {
        acc += w_row[j] * input[j];
    }
    output[gid] = acc;
}

// --------------------------------------------------------------------------
// moe_gather_topk_weights — Gather softmax probs at top-K sorted indices,
// multiply by per_expert_scale, and renormalize.
//
// Replaces the CPU softmax+argsort+gather+scale+renorm sequence for MoE
// routing.  Runs on GPU so the session never needs to break.
//
// Inputs:
//   softmax_probs   — f32 [n_experts]  (softmax output, from dispatch_softmax)
//   sorted_indices  — u32 [n_experts]  (descending sort, from dispatch_argsort)
//   per_expert_scale— f32 [n_experts]  (per-expert learned scale)
//
// Outputs:
//   out_expert_ids  — u32 [top_k]  (selected expert indices)
//   out_weights     — f32 [top_k]  (pre-scaled, renormalized routing weights)
//
// Grid: single thread (top_k <= 8, trivial work).
// --------------------------------------------------------------------------

// Uniform parameters for moe_gather_topk_weights.
struct MoeGatherTopkParams {
    uint n_experts;  // total experts (length of softmax_probs / sorted_indices)
    uint top_k;      // experts to select; the kernel assumes top_k <= 8
};

// Gather softmax probabilities at the top-K sorted expert indices, apply
// per_expert_scale, and renormalize so the selected weights sum to the
// scaled distribution.  Runs on a single thread (top_k <= 8, trivial work).
kernel void moe_gather_topk_weights(
    device const float*  softmax_probs    [[buffer(0)]],
    device const uint*   sorted_indices   [[buffer(1)]],
    device const float*  per_expert_scale [[buffer(2)]],
    device uint*         out_expert_ids   [[buffer(3)]],
    device float*        out_weights      [[buffer(4)]],
    constant MoeGatherTopkParams& params  [[buffer(5)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid != 0) return;

    // top_probs is a fixed 8-slot register array.  Clamp top_k so a
    // misconfigured host value can never index past it; for every valid
    // configuration (top_k <= 8) this is a no-op.
    const uint top_k = min(params.top_k, 8u);

    // 1. Gather top-K expert ids and their softmax probabilities.
    float top_probs[8];   // max top_k = 8
    float prob_sum = 0.0f;
    for (uint k = 0; k < top_k; k++) {
        uint eid = sorted_indices[k];
        out_expert_ids[k] = eid;
        top_probs[k] = softmax_probs[eid];
        prob_sum += top_probs[k];
    }

    // 2. Renormalize over the selected experts, then apply the learned
    //    per-expert scale.  Guard the zero-sum case (all-zero probs) so we
    //    never emit inf/NaN weights.
    float inv_sum = (prob_sum > 0.0f) ? (1.0f / prob_sum) : 0.0f;
    for (uint k = 0; k < top_k; k++) {
        uint eid = out_expert_ids[k];
        out_weights[k] = (top_probs[k] * inv_sum) * per_expert_scale[eid];
    }
}