mlx-native 0.7.0

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
#include <metal_stdlib>
using namespace metal;

// --------------------------------------------------------------------------
// moe_expert_ffn — Single-expert FFN: gate_proj + up_proj -> GELU -> down_proj
//
// For a single expert, computes:
//   gate_out = gate_proj(x)        [input_dim -> intermediate_dim]
//   up_out   = up_proj(x)          [input_dim -> intermediate_dim]
//   hidden   = GELU(gate_out) * up_out
//   out      = down_proj(hidden)   [intermediate_dim -> input_dim]
//
// This shader works on float (f32) dequantized weights.  The Rust host
// is responsible for providing pre-dequantized weight slices for the
// selected expert, OR this shader is called after dequantization.
//
// For Stage 1, the Rust host loops over selected experts and dispatches
// the quantized_matmul kernel for each projection, then calls this
// elementwise kernel for the GELU + multiply fusion.
//
// This shader does the fused GELU-multiply: hidden = GELU(gate_out) * up_out
// --------------------------------------------------------------------------

// gelu_approx — tanh-form GELU, the same formula PyTorch/MLX use:
//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
//
// The tanh argument is clamped to [-15, 15] before the call.  tanh(15) is
// within f32 rounding of 1, so the clamp leaves ordinary results unchanged.
// It exists to keep tanh from returning NaN (exp() overflow) on very large
// arguments.  Note: this helper uses plain `tanh`, whereas several kernels
// below inline the same polynomial with `precise::tanh` and no clamp.
inline float gelu_approx(float x) {
    const float k_sqrt_2_over_pi = 0.7978845608028654f;
    const float k_cubic          = 0.044715f;
    const float k_tanh_limit     = 15.0f;

    const float cube = x * x * x;
    const float arg  = clamp(k_sqrt_2_over_pi * (x + k_cubic * cube),
                             -k_tanh_limit, k_tanh_limit);
    return 0.5f * x * (1.0f + tanh(arg));
}

// Parameter block shared by the element-wise GELU-gate kernels in this file
// (fused_gelu_mul, moe_swiglu_fused, fused_gelu_mul_bf16).
struct FusedGeluMulParams {
    uint n_elements;
};

// fused_gelu_mul — output[i] = GELU(gate_out[i]) * up_out[i]
//
// Both inputs are bound read-only (`device const`); the result is written
// to the separate `output` buffer, so gate_out is NOT modified by this kernel.
//
// Buffers:
//   0: gate_out  — float [n_elements]  (input, read-only)
//   1: up_out    — float [n_elements]  (input, read-only)
//   2: output    — float [n_elements]  (output: GELU(gate_out) * up_out)
//   3: params    — { n_elements }
// Grid: 1D; one thread per element, threads with gid >= n_elements exit.
kernel void fused_gelu_mul(
    device const float* gate_out [[buffer(0)]],
    device const float* up_out   [[buffer(1)]],
    device float*       output   [[buffer(2)]],
    constant FusedGeluMulParams& params [[buffer(3)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid >= params.n_elements) return;
    output[gid] = gelu_approx(gate_out[gid]) * up_out[gid];
}

// --------------------------------------------------------------------------
// moe_accumulate — accumulator[i] += routing_weight * expert_output[i]
//
// Folds one routing-weighted expert contribution into a running sum.  The
// accumulator is read and updated in place.
//
// Buffers:
//   0: accumulator   — float [n_elements]  (read-modify-write)
//   1: expert_output — float [n_elements]  (read-only)
//   2: params        — { n_elements, routing_weight }
// Grid: 1D; threads at or beyond n_elements do nothing.
// --------------------------------------------------------------------------

struct MoeAccumParams {
    uint n_elements;       // number of elements to update
    float routing_weight;  // scalar applied to every expert_output element
};

kernel void moe_accumulate(
    device float*       accumulator   [[buffer(0)]],
    device const float* expert_output [[buffer(1)]],
    constant MoeAccumParams& params   [[buffer(2)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid < params.n_elements) {
        const float scaled = params.routing_weight * expert_output[gid];
        accumulator[gid] = accumulator[gid] + scaled;
    }
}

// --------------------------------------------------------------------------
// zero_buffer — writes 0.0f into the first n_elements floats of a buffer.
//
// Buffers:
//   0: buffer      — float [n_elements]  (written)
//   1: params      — { n_elements }
// Grid: 1D; threads at or beyond n_elements do nothing.
// --------------------------------------------------------------------------

struct ZeroParams {
    uint n_elements;
};

kernel void zero_buffer(
    device float* buffer [[buffer(0)]],
    constant ZeroParams& params [[buffer(1)]],
    uint gid [[thread_position_in_grid]]
) {
    const bool in_range = gid < params.n_elements;
    if (in_range) {
        buffer[gid] = 0.0f;
    }
}

// --------------------------------------------------------------------------
// moe_swiglu_fused — gated activation over a concatenated [gate || up] buffer.
//
// gate_up holds 2*N floats: elements [0, N) are the gate projection output
// and elements [N, 2N) are the up projection output.  For each i < N:
//   output[i] = gelu_approx(gate_up[i]) * gate_up[N + i]
// Despite the "swiglu" name, the activation applied is tanh-approx GELU.
//
// Intended for models whose experts use a single fused gate/up projection
// (e.g. Gemma 4 MoE experts).
//
// Buffers:
//   0: gate_up — float [2 * N] (input: gate || up, concatenated)
//   1: output  — float [N]     (output: GELU(gate) * up)
//   2: params  — { n_elements = N }
// Grid: 1D; threads with gid >= N exit.
// --------------------------------------------------------------------------

kernel void moe_swiglu_fused(
    device const float* gate_up [[buffer(0)]],
    device float*       output  [[buffer(1)]],
    constant FusedGeluMulParams& params [[buffer(2)]],
    uint gid [[thread_position_in_grid]]
) {
    const uint n = params.n_elements;
    if (gid >= n) return;
    device const float* up_half = gate_up + n;  // second half of the buffer
    output[gid] = gelu_approx(gate_up[gid]) * up_half[gid];
}

// --------------------------------------------------------------------------
// moe_swiglu_batch — gated activation for every top_k expert slot at once.
//
// Input row k of gate_up_buf ([top_k, 2*intermediate]) holds gate values in
// columns [0, intermediate) and up values in [intermediate, 2*intermediate).
// Output row k ([top_k, intermediate]) receives GELU(gate) * up.
// The GELU is the tanh form computed inline with precise::tanh (unclamped);
// despite the name, no SiLU is involved.
//
// Grid: 2D — x = element within intermediate, y = expert slot.
// One dispatch replaces top_k separate moe_swiglu_fused dispatches.
// --------------------------------------------------------------------------

kernel void moe_swiglu_batch(
    device const float* gate_up_buf  [[buffer(0)]],  // [top_k, 2*intermediate]
    device float*       output_buf   [[buffer(1)]],  // [top_k, intermediate]
    constant uint&      intermediate [[buffer(2)]],
    constant uint&      top_k        [[buffer(3)]],
    uint2 tid [[thread_position_in_grid]]             // x=element, y=slot
) {
    const uint elem  = tid.x;
    const uint k     = tid.y;
    const uint width = intermediate;
    if (k >= top_k || elem >= width) return;

    device const float* in_row  = gate_up_buf + k * 2 * width;
    device float*       out_row = output_buf + k * width;

    const float g = in_row[elem];
    const float u = in_row[width + elem];
    const float inner = 0.7978845608f * (g + 0.044715f * g * g * g);
    const float act   = g * 0.5f * (1.0f + precise::tanh(inner));
    out_row[elem] = act * u;
}

// --------------------------------------------------------------------------
// moe_weighted_sum — output[d] = sum_k expert_outputs[k, d] * weights[k]
//
// A single dispatch that replaces zero_buffer followed by top_k
// moe_accumulate dispatches (1 + top_k dispatches in total).
//
// Buffers:
//   0: expert_outputs — float [top_k, hidden_size] (row-major)
//   1: weights        — float [top_k]
//   2: output         — float [hidden_size]
//   3: params         — { hidden_size, top_k }
// Grid: 1D — one thread per output element.
// --------------------------------------------------------------------------

struct MoeWeightedSumParams {
    uint hidden_size;
    uint top_k;
};

kernel void moe_weighted_sum(
    device const float*  expert_outputs [[buffer(0)]],  // [top_k, hidden_size]
    device const float*  weights        [[buffer(1)]],  // [top_k]
    device float*        output         [[buffer(2)]],  // [hidden_size]
    constant MoeWeightedSumParams& params [[buffer(3)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= params.hidden_size) return;

    const uint stride = params.hidden_size;
    device const float* column = expert_outputs + tid;  // walks down column tid
    float total = 0.0f;
    for (uint slot = 0; slot < params.top_k; ++slot) {
        total += column[slot * stride] * weights[slot];
    }
    output[tid] = total;
}

// --------------------------------------------------------------------------
// fused_gelu_mul_bf16 — bfloat16 counterpart of fused_gelu_mul.
//
// Both inputs are read as bfloat and widened to float.  GELU and the product
// are computed in float, and only the final value is narrowed back to bfloat.
//
// Buffers:
//   0: gate_out  — bfloat [n_elements]  (input, read-only)
//   1: up_out    — bfloat [n_elements]  (input, read-only)
//   2: output    — bfloat [n_elements]  (GELU(gate_out) * up_out)
//   3: params    — { n_elements }
// Grid: 1D; threads with gid >= n_elements do nothing.
// --------------------------------------------------------------------------

kernel void fused_gelu_mul_bf16(
    device const bfloat* gate_out [[buffer(0)]],
    device const bfloat* up_out   [[buffer(1)]],
    device bfloat*       output   [[buffer(2)]],
    constant FusedGeluMulParams& params [[buffer(3)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid < params.n_elements) {
        const float gate_f = static_cast<float>(gate_out[gid]);
        const float up_f   = static_cast<float>(up_out[gid]);
        const float prod   = gelu_approx(gate_f) * up_f;  // all-f32 math
        output[gid] = bfloat(prod);
    }
}

/// moe_swiglu_seq — multi-token (batched prefill) version of moe_swiglu_batch.
/// Input:  gate_up_buf [n_tokens, top_k, 2*intermediate]  (gate || up per row)
/// Output: output_buf  [n_tokens, top_k, intermediate]
/// Grid:   3D (intermediate, top_k, n_tokens)
/// The activation is tanh-approx GELU via precise::tanh (unclamped), the same
/// inline formula moe_swiglu_batch uses.
struct MoeSwigluSeqParams {
    uint intermediate;
    uint top_k;
    uint n_tokens;
};
kernel void moe_swiglu_seq(
    device const float* gate_up_buf  [[buffer(0)]],
    device float*       output_buf   [[buffer(1)]],
    constant MoeSwigluSeqParams& params [[buffer(2)]],
    uint3 tid [[thread_position_in_grid]]
) {
    const uint elem  = tid.x;
    const uint k     = tid.y;
    const uint t     = tid.z;
    const uint width = params.intermediate;
    if (t >= params.n_tokens || k >= params.top_k || elem >= width) return;

    const uint row = t * params.top_k + k;  // flattened (token, slot) index
    device const float* in_row = gate_up_buf + row * 2 * width;

    const float g = in_row[elem];
    const float u = in_row[width + elem];
    const float inner = 0.7978845608f * (g + 0.044715f * g * g * g);
    const float act   = g * 0.5f * (1.0f + precise::tanh(inner));
    output_buf[row * width + elem] = act * u;
}

/// moe_weighted_sum_seq — per-token weighted sum over the top_k expert slots.
///   output[t, d] = sum_k expert_outputs[t, k, d] * weights[t, k]
/// Input:  expert_outputs [n_tokens, top_k, hidden_size]
///         weights        [n_tokens, top_k]
/// Output: [n_tokens, hidden_size]
/// Grid:   2D (hidden_size, n_tokens)
struct MoeWeightedSumSeqParams {
    uint hidden_size;
    uint top_k;
    uint n_tokens;
};
kernel void moe_weighted_sum_seq(
    device const float*  expert_outputs [[buffer(0)]],
    device const float*  weights        [[buffer(1)]],
    device float*        output         [[buffer(2)]],
    constant MoeWeightedSumSeqParams& params [[buffer(3)]],
    uint2 tid [[thread_position_in_grid]]
) {
    const uint d  = tid.x;
    const uint t  = tid.y;
    const uint hs = params.hidden_size;
    const uint kk = params.top_k;
    if (t >= params.n_tokens || d >= hs) return;

    const uint first = t * kk;  // this token's first flattened (token, slot) index
    float acc = 0.0f;
    for (uint k = 0; k < kk; ++k) {
        const uint pair = first + k;
        acc += expert_outputs[pair * hs + d] * weights[pair];
    }
    output[t * hs + d] = acc;
}

// --------------------------------------------------------------------------
// moe_swiglu_seq_bf16 — bfloat16 storage variant of moe_swiglu_seq.
//
// Input:  [n_tokens, top_k, 2*intermediate] bfloat16  (gate || up per row)
// Output: [n_tokens, top_k, intermediate]   bfloat16
// Grid:   3D (intermediate, top_k, n_tokens)
//
// Values are widened to float on load, the tanh-GELU gate is evaluated in
// float (precise::tanh), and the product is narrowed to bfloat on store.
// --------------------------------------------------------------------------

kernel void moe_swiglu_seq_bf16(
    device const bfloat* gate_up_buf  [[buffer(0)]],
    device bfloat*       output_buf   [[buffer(1)]],
    constant MoeSwigluSeqParams& params [[buffer(2)]],
    uint3 tid [[thread_position_in_grid]]
) {
    const uint elem  = tid.x;
    const uint k     = tid.y;
    const uint t     = tid.z;
    const uint width = params.intermediate;
    if (t >= params.n_tokens || k >= params.top_k || elem >= width) return;

    const uint row = t * params.top_k + k;  // flattened (token, slot) index
    device const bfloat* in_row = gate_up_buf + row * 2u * width;

    const float g = static_cast<float>(in_row[elem]);
    const float u = static_cast<float>(in_row[width + elem]);
    const float inner = 0.7978845608f * (g + 0.044715f * g * g * g);
    const float act   = g * 0.5f * (1.0f + precise::tanh(inner));
    output_buf[row * width + elem] = bfloat(act * u);
}

// --------------------------------------------------------------------------
// moe_weighted_sum_seq_bf16_input — moe_weighted_sum_seq with bfloat16
// expert outputs and float32 weights/output.
//
// Expert intermediates are stored as bf16 while the weighted accumulator
// stays f32 (the original note names it pf_moe_accum) to preserve
// residual precision.  Each bf16 value is widened before the multiply-add.
//
// Input:  expert_outputs [n_tokens, top_k, hidden_size] bfloat16
//         weights        [n_tokens, top_k]              float32
// Output: [n_tokens, hidden_size]                       float32
// Grid:   2D (hidden_size, n_tokens)
// --------------------------------------------------------------------------

kernel void moe_weighted_sum_seq_bf16_input(
    device const bfloat* expert_outputs [[buffer(0)]],
    device const float*  weights        [[buffer(1)]],
    device float*        output         [[buffer(2)]],
    constant MoeWeightedSumSeqParams& params [[buffer(3)]],
    uint2 tid [[thread_position_in_grid]]
) {
    const uint d  = tid.x;
    const uint t  = tid.y;
    const uint hs = params.hidden_size;
    const uint kk = params.top_k;
    if (t >= params.n_tokens || d >= hs) return;

    const uint first = t * kk;  // this token's first flattened (token, slot) index
    float acc = 0.0f;
    for (uint k = 0u; k < kk; ++k) {
        const uint pair = first + k;
        const float expert_val = static_cast<float>(expert_outputs[pair * hs + d]);
        acc += expert_val * weights[pair];
    }
    output[t * hs + d] = acc;
}

// --------------------------------------------------------------------------
// naive_matvec_f32 — output[r] = dot(weight[r, :], input[:])
//
// weight is [N, K] row-major, input is [K], output is [N]; one thread per
// output row, each doing a serial K-length dot product.
//
// This is the simple Stage-1 path.  For large matrices the quantized_matmul
// kernel is intended to be used instead.
//
// Buffers:
//   0: weight — float [N, K] row-major
//   1: input  — float [K]
//   2: output — float [N]
//   3: params — { m (unused, always 1), k, n }
// Grid: 1D; threads with gid >= n exit.
// --------------------------------------------------------------------------

struct MatvecParams {
    uint m;  // unused for matvec, included for struct compatibility
    uint k;  // inner dimension
    uint n;  // output dimension (number of rows in weight)
};

kernel void naive_matvec_f32(
    device const float* weight [[buffer(0)]],
    device const float* input  [[buffer(1)]],
    device float*       output [[buffer(2)]],
    constant MatvecParams& params [[buffer(3)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid >= params.n) return;

    const uint inner = params.k;
    device const float* w_row = weight + gid * inner;
    float dot = 0.0f;
    for (uint j = 0; j < inner; ++j) {
        dot += w_row[j] * input[j];
    }
    output[gid] = dot;
}

// --------------------------------------------------------------------------
// moe_gather_topk_weights — Gather softmax probs at the top-K sorted indices,
// renormalize them over the selected experts, and multiply by
// per_expert_scale.
//
// Replaces the CPU softmax+argsort+gather+scale+renorm sequence for MoE
// routing, so the GPU session does not need to break.
//
// Inputs:
//   softmax_probs   — f32 [n_experts]  (softmax output, from dispatch_softmax)
//   sorted_indices  — u32 [n_experts]  (descending sort, from dispatch_argsort)
//   per_expert_scale— f32 [n_experts]  (per-expert learned scale)
//
// Outputs:
//   out_expert_ids  — u32 [top_k]  (selected expert indices)
//   out_weights     — f32 [top_k]  (renormalized, then scaled routing weights)
//
// out_weights[k] = (p[eid_k] / sum_j p[eid_j]) * per_expert_scale[eid_k]
// If the selected probabilities sum to <= 0 (or NaN), every weight is 0.
//
// Grid: only thread 0 does work; all other threads return immediately.
// No fixed-size scratch storage is used, so top_k is not capped.
// NOTE(review): assumes top_k <= n_experts and every selected index is
// < n_experts; neither is checked here — confirm on the host side.
// --------------------------------------------------------------------------

struct MoeGatherTopkParams {
    uint n_experts;
    uint top_k;
};

kernel void moe_gather_topk_weights(
    device const float*  softmax_probs    [[buffer(0)]],
    device const uint*   sorted_indices   [[buffer(1)]],
    device const float*  per_expert_scale [[buffer(2)]],
    device uint*         out_expert_ids   [[buffer(3)]],
    device float*        out_weights      [[buffer(4)]],
    constant MoeGatherTopkParams& params  [[buffer(5)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid != 0) return;

    const uint top_k = params.top_k;

    // Pass 1: emit the top-K expert ids and sum their softmax probabilities.
    // Probabilities are re-read in pass 2 rather than cached in a local array,
    // which keeps this safe for any top_k.
    float prob_sum = 0.0f;
    for (uint k = 0; k < top_k; k++) {
        const uint eid = sorted_indices[k];
        out_expert_ids[k] = eid;
        prob_sum += softmax_probs[eid];
    }

    // Pass 2: renormalize over the selected experts, then apply the
    // per-expert scale.
    const float inv_sum = (prob_sum > 0.0f) ? (1.0f / prob_sum) : 0.0f;
    for (uint k = 0; k < top_k; k++) {
        const uint eid = sorted_indices[k];
        out_weights[k] = (softmax_probs[eid] * inv_sum) * per_expert_scale[eid];
    }
}

// ============================================================================
// ADR-020 iter-11h-e3a — backward kernels for moe_weighted_sum_seq.
//
// Forward (existing kernel above):
//   output[t, d] = sum_k expert_outputs[t, k, d] * weights[t, k]
//
// Backward (the two kernels below):
//   d_expert_outputs[t, k, d] = weights[t, k] * d_output[t, d]      (parallel)
//   d_weights[t, k]           = sum_d expert_outputs[t, k, d]
//                                       * d_output[t, d]            (reduction)
//
// For DWQ training of MoE with frozen FP16 router, only d_expert_outputs is
// strictly required (it carries gradient back to per-expert SwiGLU and on into
// the quantized expert Linears).  d_weights is provided for completeness and
// for future use cases that train the router.
// ============================================================================

/// Gradient w.r.t. expert outputs of moe_weighted_sum_seq:
///   d_expert_outs[t, k, d] = weights[t, k] * d_output[t, d]
/// Grid: 3D (hidden_size, top_k, n_tokens); each thread writes one element
/// and no reduction is needed.
kernel void moe_weighted_sum_seq_backward_outputs_f32(
    device const float*  weights         [[buffer(0)]],  // [n_tokens, top_k]
    device const float*  d_output        [[buffer(1)]],  // [n_tokens, hidden_size]
    device       float*  d_expert_outs   [[buffer(2)]],  // [n_tokens, top_k, hidden_size]
    constant MoeWeightedSumSeqParams& params [[buffer(3)]],
    uint3 tid [[thread_position_in_grid]]
) {
    const uint d   = tid.x;
    const uint k   = tid.y;
    const uint tok = tid.z;
    const bool outside = tok >= params.n_tokens || k >= params.top_k || d >= params.hidden_size;
    if (outside) return;

    const uint pair  = tok * params.top_k + k;  // flattened (token, slot)
    const float w    = weights[pair];
    const float grad = d_output[tok * params.hidden_size + d];
    d_expert_outs[pair * params.hidden_size + d] = w * grad;
}

/// Gradient w.r.t. routing weights of moe_weighted_sum_seq:
///   d_weights[t, k] = sum_d expert_outs[t, k, d] * d_output[t, d]
/// Grid: 1D (n_tokens * top_k). Each thread owns one flattened (t, k) pair
/// and reduces serially over hidden_size.  This favors simplicity over a
/// threadgroup reduction, which could replace it if profiling shows it on
/// the critical path.
kernel void moe_weighted_sum_seq_backward_weights_f32(
    device const float*  expert_outs   [[buffer(0)]],  // [n_tokens, top_k, hidden_size]
    device const float*  d_output      [[buffer(1)]],  // [n_tokens, hidden_size]
    device       float*  d_weights     [[buffer(2)]],  // [n_tokens, top_k]
    constant MoeWeightedSumSeqParams& params [[buffer(3)]],
    uint tid [[thread_position_in_grid]]
) {
    const uint pairs = params.n_tokens * params.top_k;
    if (tid >= pairs) return;

    const uint tok = tid / params.top_k;
    const uint hs  = params.hidden_size;

    // tid == tok * top_k + k, so row tid of expert_outs is this (t, k) pair.
    device const float* e_row = expert_outs + tid * hs;
    device const float* g_row = d_output + tok * hs;

    float dot = 0.0f;
    for (uint j = 0; j < hs; ++j) {
        dot += e_row[j] * g_row[j];
    }
    d_weights[tid] = dot;
}

// ============================================================================
// ADR-020 iter-11h-e3b — fused backward kernel for moe_swiglu_seq.
//
// Forward (existing kernel above):
//   gate = gate_up[t, k, 0..I)
//   up   = gate_up[t, k, I..2I)
//   T    = tanh(0.7978845608 * (gate + 0.044715 * gate^3))
//   gelu = 0.5 * gate * (1 + T)
//   output[t, k, i] = gelu * up
//
// Backward (this kernel — fused so gate_prime intermediates are reused):
//   ∂L/∂up[t,k,i]   = ∂output[t,k,i] · gelu(gate)
//   ∂L/∂gate[t,k,i] = ∂output[t,k,i] · up · gelu_prime(gate)
//
// gelu_prime via tanh-approx chain rule:
//   s          = 0.7978845608 · (gate + 0.044715·gate^3)
//   ds/dgate   = 0.7978845608 · (1 + 0.134145·gate^2)        // 0.134145 = 3·0.044715
//   T          = tanh(s)
//   dT/dgate   = (1 - T^2) · ds/dgate
//   dgelu/dgate = 0.5·(1 + T) + 0.5·gate·dT/dgate
//
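// Output: writes BOTH ∂gate (first half of each d_gate_up row) and ∂up
// (second half) in a single dispatch.  The kernel overwrites (does not add
// into) d_gate_up, so pre-zeroing the buffer has no effect; callers that
// need to accumulate gradients across calls must sum the result into a
// separate buffer themselves.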
// ============================================================================

// One thread per (element, slot, token).  The thread recomputes the forward
// tanh-GELU from gate_up_buf (no saved activation buffer) and writes both
// gradient halves for its element.
kernel void moe_swiglu_seq_backward_f32(
    device const float* gate_up_buf  [[buffer(0)]],   // [n_tokens, top_k, 2*intermediate] forward input
    device const float* d_output     [[buffer(1)]],   // [n_tokens, top_k, intermediate]
    device       float* d_gate_up    [[buffer(2)]],   // [n_tokens, top_k, 2*intermediate] output gradient
    constant MoeSwigluSeqParams& params [[buffer(3)]],
    uint3 tid [[thread_position_in_grid]]
) {
    const uint elem  = tid.x;
    const uint k     = tid.y;
    const uint t     = tid.z;
    const uint width = params.intermediate;
    if (t >= params.n_tokens || k >= params.top_k || elem >= width) return;

    const uint row  = t * params.top_k + k;       // flattened (token, slot)
    const uint g_ix = row * 2u * width + elem;    // gate half of the row
    const uint u_ix = g_ix + width;               // matching up-half element

    const float g    = gate_up_buf[g_ix];
    const float u    = gate_up_buf[u_ix];
    const float grad = d_output[row * width + elem];

    // Forward recomputation: th = tanh(s), act = gelu(g).
    const float g_sq = g * g;
    const float th   = precise::tanh(0.7978845608f * (g + 0.044715f * g * g_sq));
    const float act  = 0.5f * g * (1.0f + th);

    // d(act)/dg through the tanh; 0.134145 = 3 * 0.044715.
    const float d_inner = 0.7978845608f * (1.0f + 0.134145f * g_sq);
    const float d_act   = 0.5f * (1.0f + th) + 0.5f * g * ((1.0f - th * th) * d_inner);

    d_gate_up[g_ix] = grad * u * d_act;
    d_gate_up[u_ix] = grad * act;
}