mlx-native 0.3.2

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
#include <metal_stdlib>
using namespace metal;

/// Fused RMS normalization + residual addition (float32).
///
/// Computes Gemma4's post-attention / post-FFN ordering:
///   normed[i] = rms_norm(input[i], weight[i], eps)
///   output[i] = residual[i] + normed[i]
///
/// Fuses two separate dispatches (rms_norm_f32 + elementwise_add_f32) into
/// one kernel launch per transformer sub-layer.
///
/// Buffer layout:
///   buffer(0): residual — float [rows * dim]   residual stream (unmodified)
///   buffer(1): input    — float [rows * dim]   sublayer output (to normalize)
///   buffer(2): weight   — float [dim]          RMS norm learned scale
///   buffer(3): output   — float [rows * dim]   residual + normed result
///   buffer(4): dim      — uint
///   buffer(5): rows     — uint
///   buffer(6): eps      — float
///
/// Threadgroup: (min(256, next_pow2(dim)), 1, 1) — one threadgroup per row
/// Grid       : (rows, 1, 1)
/// Shared mem : tg_size * sizeof(float) for the sum-of-squares reduction

// Launch contract as read from the body: one threadgroup per row, `shared`
// holding at least tg_size floats, and tg_size a power of two (the halving
// reduction below would skip partials for any other size).
kernel void fused_norm_add_f32(
    device const float* residual [[buffer(0)]],
    device const float* input    [[buffer(1)]],
    device const float* weight   [[buffer(2)]],
    device float*       output   [[buffer(3)]],
    constant uint&      dim      [[buffer(4)]],
    constant uint&      rows     [[buffer(5)]],
    constant float&     eps      [[buffer(6)]],
    uint row_id  [[threadgroup_position_in_grid]],
    uint tid     [[thread_index_in_threadgroup]],
    uint tg_size [[threads_per_threadgroup]],
    threadgroup float* shared [[threadgroup(0)]]
) {
    // row_id is identical for every thread of the group, so this exit is
    // taken by all of them or none; no thread is left waiting at a barrier.
    if (row_id >= rows) { return; }

    // Row-local views of the three [rows * dim] buffers.
    const uint row_offset = row_id * dim;
    device const float* res_row = residual + row_offset;
    device const float* in_row  = input    + row_offset;
    device float*       out_row = output   + row_offset;

    // Each thread squares and sums the columns tid, tid + tg_size, ...
    float sq_acc = 0.0f;
    for (uint col = tid; col < dim; col += tg_size) {
        const float x = in_row[col];
        sq_acc += x * x;
    }

    // Publish the per-thread partial, then fold the upper half of the
    // scratch array onto the lower half until slot 0 holds the row total.
    shared[tid] = sq_acc;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint half_width = tg_size / 2;
    while (half_width > 0) {
        if (tid < half_width) {
            shared[tid] += shared[tid + half_width];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        half_width >>= 1;
    }

    // Reciprocal RMS of the row: rsqrt(mean(input^2) + eps).
    const float mean_sq = shared[0] / float(dim);
    const float inv_rms = rsqrt(mean_sq + eps);

    // Scale the sublayer output, apply the learned weight, add the residual.
    for (uint col = tid; col < dim; col += tg_size) {
        const float scaled = in_row[col] * inv_rms * weight[col];
        out_row[col] = res_row[col] + scaled;
    }
}

/// Fused residual addition + RMS normalization (float32).
///
/// Computes:
///   sum[i]    = residual[i] + input[i]
///   normed[i] = sum[i] * rsqrt(mean(sum^2) + eps) * weight[i]
///
/// Optionally writes `sum` to a separate output buffer for use as the next
/// layer's residual (avoids an extra elementwise kernel).
///
/// Buffer layout:
///   buffer(0): residual      — float [rows * dim]
///   buffer(1): input         — float [rows * dim]
///   buffer(2): weight        — float [dim]
///   buffer(3): normed_output — float [rows * dim]   normalized result
///   buffer(4): sum_output    — float [rows * dim]   residual+input (may be unused)
///   buffer(5): params        — { dim, rows, eps, write_sum }
///
/// Threadgroup: (min(256, next_pow2(dim)), 1, 1)
/// Grid       : (rows, 1, 1)
///
/// Shared memory: tg_size * sizeof(float) at threadgroup(0) for the reduction.

/// Parameter block bound at buffer(5) of fused_residual_norm_f32.
/// NOTE(review): the host presumably uploads a struct with the same field
/// order and 4-byte fields — confirm against the dispatch code.
struct FusedResidualNormF32Params {
    uint  dim;         // elements per row; also the divisor for the mean of squares
    uint  rows;        // row count; not read by fused_residual_norm_f32 itself
    float eps;         // added to the mean of squares before rsqrt
    uint  write_sum;   // 0 = skip writing sum_output, nonzero = write it
};

// Threadgroup memory use is limited to shared[0 .. tg_size): the un-normed
// sum is written straight to sum_output rather than staged in shared[i] for
// i < dim, which overran a tg_size-float allocation whenever dim > tg_size.
//
// tg_size must be a power of two (the halving reduction would otherwise skip
// partials). params.rows is not consulted; the launch is expected to use
// exactly one threadgroup per row.
//
// NOTE(review): Phase 2 re-reads residual and input after sum_output has been
// written, so sum_output must not alias either of them — confirm in callers.
kernel void fused_residual_norm_f32(
    device const float*               residual      [[buffer(0)]],
    device const float*               input         [[buffer(1)]],
    device const float*               weight        [[buffer(2)]],
    device float*                     normed_output [[buffer(3)]],
    device float*                     sum_output    [[buffer(4)]],
    constant FusedResidualNormF32Params& params     [[buffer(5)]],
    uint row_id   [[threadgroup_position_in_grid]],
    uint tid      [[thread_index_in_threadgroup]],
    uint tg_size  [[threads_per_threadgroup]],
    threadgroup float* shared          [[threadgroup(0)]]
) {
    const uint  dim       = params.dim;
    const float eps       = params.eps;
    const bool  write_sum = (params.write_sum != 0u);

    const uint base = row_id * dim;

    // Phase 1: form residual + input per element, optionally emit it, and
    // accumulate this thread's share of the sum-of-squares.
    float partial_sq = 0.0f;
    for (uint i = tid; i < dim; i += tg_size) {
        const float s = residual[base + i] + input[base + i];
        if (write_sum) {
            // write_sum is uniform across the grid, so this branch does not diverge.
            sum_output[base + i] = s;
        }
        partial_sq += s * s;
    }

    // Tree reduction of the per-thread partials; shared[0] ends up holding
    // the sum-of-squares for the whole row.
    shared[tid] = partial_sq;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            shared[tid] += shared[tid + stride];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // rms_inv = rsqrt(mean(sum^2) + eps)
    const float rms_inv = rsqrt(shared[0] / float(dim) + eps);

    // Phase 2: recompute element sums, normalize, apply weight, write.
    for (uint i = tid; i < dim; i += tg_size) {
        const float s = residual[base + i] + input[base + i];
        const float w = weight[i];
        normed_output[base + i] = s * rms_inv * w;
    }
}

/// Fused post-layer: residual add + RMS norm + scalar multiply (float32).
///
/// Computes the end-of-layer sequence in one pass:
///   sum[i]    = residual[i] + input[i]
///   normed[i] = rms_norm(sum, weight, eps)[i]
///   output[i] = normed[i] * scalar
///
/// When scalar_is_vector != 0, scalar is a per-element array of shape [dim].
/// Otherwise, scalar[0] is broadcast to all elements.
///
/// Buffer layout:
///   buffer(0): residual — float [rows * dim]
///   buffer(1): input    — float [rows * dim]
///   buffer(2): weight   — float [dim]
///   buffer(3): output   — float [rows * dim]
///   buffer(4): scalar   — float [1] or [dim]
///   buffer(5): params   — FusedResidualNormScalarF32Params
///
/// Threadgroup: (min(256, next_pow2(dim)), 1, 1)
/// Grid       : (rows, 1, 1)

/// Parameter block bound at buffer(5) of fused_residual_norm_scalar_f32.
/// NOTE(review): the host presumably uploads a struct with the same field
/// order and 4-byte fields — confirm against the dispatch code.
struct FusedResidualNormScalarF32Params {
    uint  dim;              // elements per row; also the divisor for the mean of squares
    uint  rows;             // row count; not read by fused_residual_norm_scalar_f32 itself
    float eps;              // added to the mean of squares before rsqrt
    uint  scalar_is_vector; // 0 = broadcast scalar[0], nonzero = per-element
};

// Threadgroup memory use is limited to shared[0 .. tg_size). The earlier
// Phase-1 store of each element sum into shared[i] (i < dim) was never read
// back and overran the scratch buffer whenever dim > tg_size, so it has been
// removed along with the barrier that guarded it.
//
// tg_size must be a power of two (the halving reduction would otherwise skip
// partials). params.rows is not consulted; the launch is expected to use
// exactly one threadgroup per row.
kernel void fused_residual_norm_scalar_f32(
    device const float*                       residual [[buffer(0)]],
    device const float*                       input    [[buffer(1)]],
    device const float*                       weight   [[buffer(2)]],
    device float*                             output   [[buffer(3)]],
    device const float*                       scalar   [[buffer(4)]],
    constant FusedResidualNormScalarF32Params& params  [[buffer(5)]],
    uint row_id   [[threadgroup_position_in_grid]],
    uint tid      [[thread_index_in_threadgroup]],
    uint tg_size  [[threads_per_threadgroup]],
    threadgroup float* shared          [[threadgroup(0)]]
) {
    const uint  dim              = params.dim;
    const float eps              = params.eps;
    const bool  scalar_is_vector = (params.scalar_is_vector != 0u);

    const uint base = row_id * dim;

    // Phase 1: accumulate this thread's share of sum-of-squares of
    // (residual + input).
    float partial_sq = 0.0f;
    for (uint i = tid; i < dim; i += tg_size) {
        const float s = residual[base + i] + input[base + i];
        partial_sq += s * s;
    }

    // Tree reduction of the per-thread partials; shared[0] ends up holding
    // the sum-of-squares for the whole row.
    shared[tid] = partial_sq;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            shared[tid] += shared[tid + stride];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // rms_inv = rsqrt(mean(sum^2) + eps)
    const float rms_inv = rsqrt(shared[0] / float(dim) + eps);

    // In broadcast mode scalar[0] is loaded once; in vector mode the
    // placeholder 0.0f is never used.
    const float broadcast_scalar = scalar_is_vector ? 0.0f : scalar[0];

    // Phase 2: recompute sums, normalize, apply weight, scale, write.
    for (uint i = tid; i < dim; i += tg_size) {
        const float s = residual[base + i] + input[base + i];
        const float w = weight[i];
        const float sc = scalar_is_vector ? scalar[i] : broadcast_scalar;
        output[base + i] = s * rms_inv * w * sc;
    }
}

/// Fused MoE routing: softmax + argsort descending + gather top-K weights (float32).
///
/// Replaces 3 separate dispatches with one kernel. Operates on [1, num_experts]
/// logits (single token). Top-K is small (typically 2).
///
/// Buffer layout:
///   buffer(0): logits         — float [num_experts]  (input)
///   buffer(1): expert_ids     — uint  [top_k]        (output: sorted expert indices)
///   buffer(2): routing_weights— float [top_k]        (output: top-K softmax weights
///                                                    renormalized over the selected
///                                                    experts, then scaled by
///                                                    per_expert_scale)
///   buffer(3): per_expert_scale — float [num_experts] (input: per-expert scale factors)
///   buffer(4): params         — { num_experts, top_k }
///
/// Single threadgroup, tg_size threads.

/// Parameter block bound at buffer(4) of fused_moe_routing_f32 and
/// fused_moe_routing_batch_f32.
/// NOTE(review): the host presumably uploads two consecutive 32-bit unsigned
/// values in this order — confirm against the dispatch code.
struct FusedMoeRoutingParams {
    uint num_experts;  // number of logits per token (softmax width)
    uint top_k;        // number of expert ids / routing weights written per token
};

// Shared memory: at least max(tg_size, num_experts) floats at threadgroup(0).
// tg_size must be a power of two (the halving reductions would otherwise skip
// partials).
//
// Two hazards of the earlier layout are avoided here:
//   * A barrier now separates the broadcast read of shared[0] (max_val) from
//     the next write to shared[0] (thread 0's local_sum). Without it, threads
//     that run ahead could read a clobbered maximum.
//   * exp values are no longer parked at shared[num_experts + i]; that region
//     overlapped the reduction scratch shared[tid] whenever
//     tg_size > num_experts. They are recomputed in Step 3 instead.
kernel void fused_moe_routing_f32(
    device const float*               logits          [[buffer(0)]],
    device uint*                      expert_ids      [[buffer(1)]],
    device float*                     routing_weights [[buffer(2)]],
    device const float*               per_expert_scale [[buffer(3)]],
    constant FusedMoeRoutingParams&   params          [[buffer(4)]],
    uint tid      [[thread_index_in_threadgroup]],
    uint tg_size  [[threads_per_threadgroup]],
    threadgroup float* shared [[threadgroup(0)]]
) {
    const uint num_experts = params.num_experts;
    const uint top_k       = params.top_k;

    // Step 1: find max for numerical stability (softmax)
    float local_max = -INFINITY;
    for (uint i = tid; i < num_experts; i += tg_size) {
        local_max = max(local_max, logits[i]);
    }
    shared[tid] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint s = tg_size / 2; s > 0; s >>= 1) {
        if (tid < s) shared[tid] = max(shared[tid], shared[tid + s]);
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    const float max_val = shared[0];

    // Every thread must have read shared[0] before thread 0 reuses that slot
    // for its partial sum in Step 2.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: sum of exp(x - max) across all experts
    float local_sum = 0.0f;
    for (uint i = tid; i < num_experts; i += tg_size) {
        local_sum += exp(logits[i] - max_val);
    }
    shared[tid] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint s = tg_size / 2; s > 0; s >>= 1) {
        if (tid < s) shared[tid] += shared[tid + s];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    const float sum_exp = shared[0];

    // Same reasoning as above: shared[0] is about to become probability 0.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: softmax probabilities in shared[0..num_experts). The exponent
    // is evaluated again from the same inputs as in Step 2.
    for (uint i = tid; i < num_experts; i += tg_size) {
        shared[i] = exp(logits[i] - max_val) / sum_exp;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 4: find top-K (serial, only thread 0 — K is tiny, typically 2)
    if (tid == 0) {
        // Selection pass: pick the largest remaining probability top_k times.
        // Ties resolve to the lowest index because the comparison is strict.
        for (uint k = 0; k < top_k; k++) {
            float best_val = -1.0f;
            uint best_idx = 0;
            for (uint i = 0; i < num_experts; i++) {
                if (shared[i] > best_val) {
                    best_val = shared[i];
                    best_idx = i;
                }
            }
            expert_ids[k] = best_idx;
            // Ordering to match the reference path:
            //   1. softmax over all experts
            //   2. renormalize the selected top-K slice
            //   3. apply per_expert_scale
            // Do NOT renormalize again after applying per_expert_scale.
            routing_weights[k] = best_val;
            shared[best_idx] = -1.0f;  // mark as used
        }

        // Renormalize the top-K weights before applying per_expert_scale.
        float topk_sum = 0.0f;
        for (uint k = 0; k < top_k; k++) {
            topk_sum += routing_weights[k];
        }
        if (topk_sum > 0.0f) {
            for (uint k = 0; k < top_k; k++) {
                const uint eid = expert_ids[k];
                routing_weights[k] = (routing_weights[k] / topk_sum) * per_expert_scale[eid];
            }
        } else {
            // Degenerate selection (non-positive total): emit zero weights.
            for (uint k = 0; k < top_k; k++) {
                routing_weights[k] = 0.0f;
            }
        }
    }
}

/// Batched fused MoE routing for prefill (float32).
///
/// Same semantics as fused_moe_routing_f32, but processes n_tokens at once.
/// Grid: (n_tokens, 1, 1). Each threadgroup handles one token's routing.
///
/// Buffer layout:
///   buffer(0): logits          — float [n_tokens, num_experts]
///   buffer(1): expert_ids      — uint  [n_tokens, top_k]
///   buffer(2): routing_weights — float [n_tokens, top_k]
///   buffer(3): per_expert_scale — float [num_experts]
///   buffer(4): params          — { num_experts, top_k }
///
/// Shared memory: at least max(tg_size, num_experts) floats at threadgroup(0).
/// An allocation of (2 * num_experts + tg_size) floats remains sufficient but
/// is no longer required: exp values are recomputed in Step 3 rather than
/// parked at shared[num_experts + i], a region that overlapped the reduction
/// scratch shared[tid] whenever tg_size > num_experts.
///
/// tg_size must be a power of two (the halving reductions would otherwise
/// skip partials).
kernel void fused_moe_routing_batch_f32(
    device const float*               logits_all      [[buffer(0)]],
    device uint*                      expert_ids_all  [[buffer(1)]],
    device float*                     routing_weights_all [[buffer(2)]],
    device const float*               per_expert_scale [[buffer(3)]],
    constant FusedMoeRoutingParams&   params          [[buffer(4)]],
    uint tok_id   [[threadgroup_position_in_grid]],
    uint tid      [[thread_index_in_threadgroup]],
    uint tg_size  [[threads_per_threadgroup]],
    threadgroup float* shared [[threadgroup(0)]]
) {
    const uint num_experts = params.num_experts;
    const uint top_k       = params.top_k;

    // Per-token slices of the batched buffers.
    device const float* logits          = logits_all       + tok_id * num_experts;
    device uint*        expert_ids      = expert_ids_all   + tok_id * top_k;
    device float*       routing_weights = routing_weights_all + tok_id * top_k;

    // Step 1: find max for numerical stability (softmax)
    float local_max = -INFINITY;
    for (uint i = tid; i < num_experts; i += tg_size) {
        local_max = max(local_max, logits[i]);
    }
    shared[tid] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint s = tg_size / 2; s > 0; s >>= 1) {
        if (tid < s) shared[tid] = max(shared[tid], shared[tid + s]);
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    const float max_val = shared[0];

    // All threads must complete the broadcast-read of shared[0] for max_val
    // BEFORE any thread (notably tid==0) overwrites shared[0] with its
    // local_sum. Without this barrier, threads that race ahead read a
    // clobbered shared[0] and compute a corrupt max_val, which produces
    // nondeterministic routing decisions.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: sum of exp(x - max) across all experts
    float local_sum = 0.0f;
    for (uint i = tid; i < num_experts; i += tg_size) {
        local_sum += exp(logits[i] - max_val);
    }
    shared[tid] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint s = tg_size / 2; s > 0; s >>= 1) {
        if (tid < s) shared[tid] += shared[tid + s];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    const float sum_exp = shared[0];

    // shared[0] is about to be reused for probability 0; wait for every
    // thread to have read sum_exp first.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: softmax probabilities in shared[0..num_experts). The exponent
    // is evaluated again from the same inputs as in Step 2.
    for (uint i = tid; i < num_experts; i += tg_size) {
        shared[i] = exp(logits[i] - max_val) / sum_exp;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 4: find top-K (serial, only thread 0 — K is tiny, typically 2)
    if (tid == 0) {
        // Selection pass: pick the largest remaining probability top_k times.
        // Ties resolve to the lowest index because the comparison is strict.
        for (uint k = 0; k < top_k; k++) {
            float best_val = -1.0f;
            uint best_idx = 0;
            for (uint i = 0; i < num_experts; i++) {
                if (shared[i] > best_val) {
                    best_val = shared[i];
                    best_idx = i;
                }
            }
            expert_ids[k] = best_idx;
            routing_weights[k] = best_val;
            shared[best_idx] = -1.0f;  // mark as used
        }

        // Renormalize over the selected experts, then apply per_expert_scale
        // (no second renormalization afterwards).
        float topk_sum = 0.0f;
        for (uint k = 0; k < top_k; k++) {
            topk_sum += routing_weights[k];
        }
        if (topk_sum > 0.0f) {
            for (uint k = 0; k < top_k; k++) {
                const uint eid = expert_ids[k];
                routing_weights[k] = (routing_weights[k] / topk_sum) * per_expert_scale[eid];
            }
        } else {
            // Degenerate selection (non-positive total): emit zero weights.
            for (uint k = 0; k < top_k; k++) {
                routing_weights[k] = 0.0f;
            }
        }
    }
}

/// Fused RMS normalization + residual addition + scalar multiply (float32).
///
/// Computes:
///   normed[i] = rms_norm(input, weight, eps)[i]
///   output[i] = (residual[i] + normed[i]) * scalar[i or 0]
///
/// This is the correct end-of-layer operation for Gemma 4:
///   output = (residual + rms_norm(mlp_output)) * layer_scalar
///
/// Note: the norm is applied to `input` ALONE, not to the sum.
///
/// Buffer layout:
///   buffer(0): residual — float [rows * dim]
///   buffer(1): input    — float [rows * dim]   (sublayer output to normalize)
///   buffer(2): weight   — float [dim]          (RMS norm learned scale)
///   buffer(3): output   — float [rows * dim]
///   buffer(4): scalar   — float [1] or [dim]
///   buffer(5): params   — FusedNormAddScalarF32Params

/// Parameter block bound at buffer(5) of fused_norm_add_scalar_f32.
/// NOTE(review): the host presumably uploads a struct with the same field
/// order and 4-byte fields — confirm against the dispatch code.
struct FusedNormAddScalarF32Params {
    uint  dim;              // elements per row; also the divisor for the mean of squares
    uint  rows;             // row count; threadgroups with row_id >= rows exit immediately
    float eps;              // added to the mean of squares before rsqrt
    uint  scalar_is_vector; // 0 = broadcast scalar[0], nonzero = per-element
};

// Launch contract as read from the body: one threadgroup per row, `shared`
// holding at least tg_size floats, and tg_size a power of two (the halving
// reduction below would skip partials for any other size).
kernel void fused_norm_add_scalar_f32(
    device const float*                    residual [[buffer(0)]],
    device const float*                    input    [[buffer(1)]],
    device const float*                    weight   [[buffer(2)]],
    device float*                          output   [[buffer(3)]],
    device const float*                    scalar   [[buffer(4)]],
    constant FusedNormAddScalarF32Params&  params   [[buffer(5)]],
    uint row_id   [[threadgroup_position_in_grid]],
    uint tid      [[thread_index_in_threadgroup]],
    uint tg_size  [[threads_per_threadgroup]],
    threadgroup float* shared          [[threadgroup(0)]]
) {
    // row_id is identical for every thread of the group, so this exit is
    // taken by all of them or none; no thread is left waiting at a barrier.
    if (row_id >= params.rows) { return; }

    const uint  width       = params.dim;
    const float epsilon     = params.eps;
    const bool  per_element = (params.scalar_is_vector != 0u);

    // Row-local views of the three [rows * dim] buffers.
    const uint row_offset = row_id * width;
    device const float* res_row = residual + row_offset;
    device const float* in_row  = input    + row_offset;
    device float*       out_row = output   + row_offset;

    // Sum-of-squares is taken over `input` only; the residual is not part
    // of the normalization statistics.
    float sq_acc = 0.0f;
    for (uint col = tid; col < width; col += tg_size) {
        const float x = in_row[col];
        sq_acc += x * x;
    }

    // Publish the per-thread partial, then fold the upper half of the
    // scratch array onto the lower half until slot 0 holds the row total.
    shared[tid] = sq_acc;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint half_width = tg_size / 2;
    while (half_width > 0) {
        if (tid < half_width) {
            shared[tid] += shared[tid + half_width];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        half_width >>= 1;
    }

    // Reciprocal RMS of the row: rsqrt(mean(input^2) + eps).
    const float mean_sq = shared[0] / float(width);
    const float inv_rms = rsqrt(mean_sq + epsilon);

    // Broadcast mode loads scalar[0] once; vector mode leaves the
    // placeholder untouched and reads scalar[col] per element instead.
    float uniform_scale = 0.0f;
    if (!per_element) {
        uniform_scale = scalar[0];
    }

    // output = (residual + rms_norm(input) * weight) * scale
    for (uint col = tid; col < width; col += tg_size) {
        const float scaled = in_row[col] * inv_rms * weight[col];
        const float total  = res_row[col] + scaled;

        float factor = uniform_scale;
        if (per_element) {
            factor = scalar[col];
        }
        out_row[col] = total * factor;
    }
}