mlx_native/ops/moe_dispatch.rs

//! GPU-accelerated MoE expert dispatch (Stage 1: loop over selected experts).
//!
//! For each of the K selected experts, runs:
//!   gate_out   = gate_proj_e(x)       [input_dim -> intermediate_dim]
//!   up_out     = up_proj_e(x)         [input_dim -> intermediate_dim]
//!   hidden     = GELU(gate_out) * up_out
//!   expert_out = down_proj_e(hidden)  [intermediate_dim -> input_dim]
//!   result    += routing_weight_e * expert_out
//!
//! Stage 1 uses individual kernel dispatches per expert and per projection.
//! The projections use float matmul (caller dequantizes or provides float
//! weights).  Stage 2 optimization (Epic 6) would fuse these.
//!
//! This module provides the high-level `moe_dispatch` function that
//! orchestrates the per-expert loop, using the fused_gelu_mul and
//! moe_accumulate shaders from moe_dispatch.metal.
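//!
//! As a reference, a minimal CPU sketch of one expert's FFN (this assumes the
//! tanh-approximation GELU; the authoritative definition is whatever the
//! fused_gelu_mul shader in moe_dispatch.metal implements):
//!
//! ```
//! fn gelu(x: f32) -> f32 {
//!     let c = (2.0f32 / std::f32::consts::PI).sqrt();
//!     0.5 * x * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh())
//! }
//!
//! // `gate`, `up`, `down` are row-major weight matrices, one row per output.
//! fn expert_ffn(x: &[f32], gate: &[Vec<f32>], up: &[Vec<f32>], down: &[Vec<f32>]) -> Vec<f32> {
//!     let matvec = |w: &[Vec<f32>], v: &[f32]| -> Vec<f32> {
//!         w.iter().map(|row| row.iter().zip(v).map(|(a, b)| a * b).sum()).collect()
//!     };
//!     let hidden: Vec<f32> = matvec(gate, x)
//!         .iter()
//!         .zip(matvec(up, x))
//!         .map(|(g, u)| gelu(*g) * u)
//!         .collect();
//!     matvec(down, &hidden)
//! }
//! ```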

use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;

use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};

/// Parameters for MoE dispatch.
pub struct MoeDispatchParams {
    /// Input/output dimension (e.g. 2816 for Gemma 4 MoE layers).
    pub input_dim: usize,
    /// Intermediate FFN dimension per expert (e.g. 704 for Gemma 4).
    pub intermediate_dim: usize,
    /// Number of selected experts (top_k, e.g. 8).
    pub n_selected: usize,
}

/// A single expert's weight matrices (float32; pre-dequantized or natively float).
///
/// Each expert has three projection matrices, stored row-major with one row
/// per output element (the layout the naive matvec kernel expects):
/// * `gate_proj`: `[intermediate_dim, input_dim]`
/// * `up_proj`:   `[intermediate_dim, input_dim]`
/// * `down_proj`: `[input_dim, intermediate_dim]`
pub struct ExpertWeights<'a> {
    pub gate_proj: &'a MlxBuffer,
    pub up_proj: &'a MlxBuffer,
    pub down_proj: &'a MlxBuffer,
}

/// MSL-compatible struct for fused_gelu_mul kernel.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFusedGeluMulParams {
    n_elements: u32,
}

/// MSL-compatible struct for moe_accumulate kernel.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeAccumParams {
    n_elements: u32,
    routing_weight: f32,
}

/// MSL-compatible struct for zero_buffer kernel.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuZeroParams {
    n_elements: u32,
}

/// MSL-compatible params struct for the simple matmul kernel.
/// Used with the naive_matvec_f32 shader for the expert projections.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMatmulParams {
    m: u32,  // rows of output (1 for single-token)
    k: u32,  // inner dimension
    n: u32,  // cols of output
}

/// Encode MoE dispatch: loop over selected experts, run FFN, accumulate.
///
/// This is the Stage 1 implementation that loops over each selected expert
/// and dispatches individual compute passes for each projection.
///
/// # Buffer expectations
///
/// * `input` — f32, `[input_dim]` (single token hidden state)
/// * `expert_weights` — slice of `n_selected` expert weight structs, each
///   containing gate_proj, up_proj, down_proj as f32 buffers
/// * `routing_weights` — f32, `[n_selected]` (softmax routing weights from moe_gate)
/// * `output` — f32, `[input_dim]` (output, will be zero-initialized)
/// * `scratch_gate` — f32, `[intermediate_dim]` (scratch buffer for gate_proj output)
/// * `scratch_up` — f32, `[intermediate_dim]` (scratch buffer for up_proj output)
/// * `scratch_hidden` — f32, `[intermediate_dim]` (scratch buffer for GELU*up output)
/// * `scratch_expert` — f32, `[input_dim]` (scratch buffer for down_proj output)
///
/// # Design Notes
///
/// The caller provides scratch buffers to avoid allocating inside the
/// encoding loop.  These can come from a `MlxBufferPool`.
///
/// For the matrix projections, we use a naive matmul kernel (single-token,
/// M=1, so it's really a matvec).  The quantized_matmul from Story 1.2
/// would be used when weights remain quantized.  Stage 1 assumes float
/// weights for simplicity.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if parameters are invalid.
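///
/// # Example
///
/// A minimal usage sketch (not compiled here); it assumes the `MlxDevice`
/// helpers used by this crate's tests and that every buffer was allocated
/// with the sizes listed above:
///
/// ```ignore
/// let device = MlxDevice::new()?;
/// let mut registry = KernelRegistry::new();
/// let mut encoder = device.command_encoder()?;
/// moe_dispatch(
///     &mut encoder, &mut registry, device.metal_device(),
///     &input, &expert_weights, &routing_weights, &output,
///     &scratch_gate, &scratch_up, &scratch_hidden, &scratch_expert,
///     &MoeDispatchParams { input_dim: 2816, intermediate_dim: 704, n_selected: 8 },
/// )?;
/// encoder.commit_and_wait()?;
/// ```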
#[allow(clippy::too_many_arguments)]
pub fn moe_dispatch(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    expert_weights: &[ExpertWeights<'_>],
    routing_weights: &[f32],
    output: &MlxBuffer,
    scratch_gate: &MlxBuffer,
    scratch_up: &MlxBuffer,
    scratch_hidden: &MlxBuffer,
    scratch_expert: &MlxBuffer,
    params: &MoeDispatchParams,
) -> Result<()> {
    // --- Validation ---
    if params.input_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: input_dim must be > 0".into(),
        ));
    }
    if params.intermediate_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: intermediate_dim must be > 0".into(),
        ));
    }
    if params.n_selected == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: n_selected must be > 0".into(),
        ));
    }
    if expert_weights.len() != params.n_selected {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: expert_weights length ({}) must match n_selected ({})",
            expert_weights.len(),
            params.n_selected
        )));
    }
    if routing_weights.len() != params.n_selected {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: routing_weights length ({}) must match n_selected ({})",
            routing_weights.len(),
            params.n_selected
        )));
    }

    // Validate buffer sizes
    let input_bytes = params.input_dim * std::mem::size_of::<f32>();
    if input.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: input buffer too small: need {} bytes, have {}",
            input_bytes,
            input.byte_len()
        )));
    }
    if output.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: output buffer too small: need {} bytes, have {}",
            input_bytes,
            output.byte_len()
        )));
    }

    let intermediate_bytes = params.intermediate_dim * std::mem::size_of::<f32>();
    if scratch_gate.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_gate buffer too small".into(),
        ));
    }
    if scratch_up.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_up buffer too small".into(),
        ));
    }
    if scratch_hidden.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_hidden buffer too small".into(),
        ));
    }
    if scratch_expert.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_expert buffer too small".into(),
        ));
    }

    // Validate expert weight sizes
    let gate_up_bytes = params.input_dim * params.intermediate_dim * std::mem::size_of::<f32>();
    let down_bytes = params.intermediate_dim * params.input_dim * std::mem::size_of::<f32>();

    for (i, ew) in expert_weights.iter().enumerate() {
        if ew.gate_proj.byte_len() < gate_up_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} gate_proj too small: need {} bytes, have {}",
                i, gate_up_bytes, ew.gate_proj.byte_len()
            )));
        }
        if ew.up_proj.byte_len() < gate_up_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} up_proj too small: need {} bytes, have {}",
                i, gate_up_bytes, ew.up_proj.byte_len()
            )));
        }
        if ew.down_proj.byte_len() < down_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} down_proj too small: need {} bytes, have {}",
                i, down_bytes, ew.down_proj.byte_len()
            )));
        }
    }

    // --- Pre-warm the pipeline cache to avoid compilation during the hot loop ---
    // Each get_pipeline call borrows &mut registry, so we cannot hold multiple
    // returned references simultaneously.  Pre-warming ensures the lookups
    // below are cache hits; we then capture raw pointers to the cached
    // pipelines and re-borrow them (see SAFETY below) for use in the loop.
    {
        registry.get_pipeline("naive_matvec_f32", device)?;
        registry.get_pipeline("fused_gelu_mul", device)?;
        registry.get_pipeline("moe_accumulate", device)?;
        registry.get_pipeline("zero_buffer", device)?;
    }
    // SAFETY: After pre-warming, the pipelines are in the cache and the HashMap
    // entries will not be moved or removed for the lifetime of `registry`.
    // We convert to raw pointers to avoid the borrow checker's complaint about
    // multiple &mut borrows while holding references into the cache.
    let matvec_pipeline: *const metal::ComputePipelineState = {
        let p = registry.get_pipeline("naive_matvec_f32", device)?;
        p as *const _
    };
    let gelu_mul_pipeline: *const metal::ComputePipelineState = {
        let p = registry.get_pipeline("fused_gelu_mul", device)?;
        p as *const _
    };
    let accum_pipeline: *const metal::ComputePipelineState = {
        let p = registry.get_pipeline("moe_accumulate", device)?;
        p as *const _
    };
    let zero_pipeline: *const metal::ComputePipelineState = {
        let p = registry.get_pipeline("zero_buffer", device)?;
        p as *const _
    };
    // Re-borrow from the raw pointers.  This is safe because:
    //   1. The HashMap entries are stable (no insertions/removals below).
    //   2. The registry outlives these references.
    //   3. We only read through these pointers.
    let matvec_pipeline = unsafe { &*matvec_pipeline };
    let gelu_mul_pipeline = unsafe { &*gelu_mul_pipeline };
    let accum_pipeline = unsafe { &*accum_pipeline };
    let zero_pipeline = unsafe { &*zero_pipeline };

    // --- Zero-initialize output ---
    let zero_params = GpuZeroParams {
        n_elements: params.input_dim as u32,
    };
    encode_with_args(
        encoder,
        zero_pipeline,
        &[
            (0, KernelArg::Buffer(output)),
            (1, KernelArg::Bytes(as_bytes(&zero_params))),
        ],
        MTLSize::new(params.input_dim as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, params.input_dim as u64), 1, 1),
    );

    // --- Loop over selected experts ---
    for (i, ew) in expert_weights.iter().enumerate() {
        let w = routing_weights[i];

        // Skip experts with near-zero routing weight
        if w.abs() < 1e-10 {
            continue;
        }

        // Barrier after zero_buffer (first iteration) or after previous
        // accumulate (subsequent iterations) — the output buffer was just written.
        encoder.memory_barrier();

        // gate_out = gate_proj @ input  (matvec: [intermediate_dim, input_dim] @ [input_dim] -> [intermediate_dim])
        let gate_params = GpuMatmulParams {
            m: 1,
            k: params.input_dim as u32,
            n: params.intermediate_dim as u32,
        };
        encode_with_args(
            encoder,
            matvec_pipeline,
            &[
                (0, KernelArg::Buffer(ew.gate_proj)),
                (1, KernelArg::Buffer(input)),
                (2, KernelArg::Buffer(scratch_gate)),
                (3, KernelArg::Bytes(as_bytes(&gate_params))),
            ],
            MTLSize::new(params.intermediate_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.intermediate_dim as u64), 1, 1),
        );

        // up_out = up_proj @ input  (same shape as gate)
        let up_params = GpuMatmulParams {
            m: 1,
            k: params.input_dim as u32,
            n: params.intermediate_dim as u32,
        };
        encode_with_args(
            encoder,
            matvec_pipeline,
            &[
                (0, KernelArg::Buffer(ew.up_proj)),
                (1, KernelArg::Buffer(input)),
                (2, KernelArg::Buffer(scratch_up)),
                (3, KernelArg::Bytes(as_bytes(&up_params))),
            ],
            MTLSize::new(params.intermediate_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.intermediate_dim as u64), 1, 1),
        );

        // Barrier: gate_out and up_out must complete before gelu_mul reads them.
        encoder.memory_barrier();

        // hidden = GELU(gate_out) * up_out
        let gelu_params = GpuFusedGeluMulParams {
            n_elements: params.intermediate_dim as u32,
        };
        encode_with_args(
            encoder,
            gelu_mul_pipeline,
            &[
                (0, KernelArg::Buffer(scratch_gate)),
                (1, KernelArg::Buffer(scratch_up)),
                (2, KernelArg::Buffer(scratch_hidden)),
                (3, KernelArg::Bytes(as_bytes(&gelu_params))),
            ],
            MTLSize::new(params.intermediate_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.intermediate_dim as u64), 1, 1),
        );

        // Barrier: hidden must complete before down_proj reads it.
        encoder.memory_barrier();

        // expert_out = down_proj @ hidden  (matvec: [input_dim, intermediate_dim] @ [intermediate_dim] -> [input_dim])
        let down_params = GpuMatmulParams {
            m: 1,
            k: params.intermediate_dim as u32,
            n: params.input_dim as u32,
        };
        encode_with_args(
            encoder,
            matvec_pipeline,
            &[
                (0, KernelArg::Buffer(ew.down_proj)),
                (1, KernelArg::Buffer(scratch_hidden)),
                (2, KernelArg::Buffer(scratch_expert)),
                (3, KernelArg::Bytes(as_bytes(&down_params))),
            ],
            MTLSize::new(params.input_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.input_dim as u64), 1, 1),
        );

        // Barrier: expert_out must complete before accumulate reads it.
        encoder.memory_barrier();

        // result += w * expert_out
        let accum_params = GpuMoeAccumParams {
            n_elements: params.input_dim as u32,
            routing_weight: w,
        };
        encode_with_args(
            encoder,
            accum_pipeline,
            &[
                (0, KernelArg::Buffer(output)),
                (1, KernelArg::Buffer(scratch_expert)),
                (2, KernelArg::Bytes(as_bytes(&accum_params))),
            ],
            MTLSize::new(params.input_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.input_dim as u64), 1, 1),
        );
    }

    Ok(())
}

/// Encode a fused SwiGLU on a `[2*N]` gate_up buffer, producing `[N]` output.
///
/// Computes `output[i] = GELU(gate_up[i]) * gate_up[N + i]` for `i in 0..N`.
///
/// Uses the `moe_swiglu_fused` kernel from moe_dispatch.metal.
///
/// # Arguments
/// * `encoder`        -- Command encoder to record into.
/// * `registry`       -- Kernel registry for pipeline lookup.
/// * `device`         -- Metal device reference.
/// * `gate_up`        -- f32 buffer `[2 * n_elements]` (gate || up concatenated).
/// * `output`         -- f32 buffer `[n_elements]` (output).
/// * `n_elements`     -- Number of output elements (intermediate_dim).
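///
/// # Example
///
/// CPU equivalent of the kernel (a sketch; it assumes the tanh-approximation
/// GELU, while the authoritative variant is whatever the shader implements):
///
/// ```
/// let n = 4usize;
/// let gate_up: Vec<f32> = (0..2 * n).map(|i| i as f32 * 0.1).collect();
/// let gelu = |x: f32| {
///     let c = (2.0f32 / std::f32::consts::PI).sqrt();
///     0.5 * x * (1.0 + (c * (x + 0.044715 * x.powi(3))).tanh())
/// };
/// let out: Vec<f32> = (0..n).map(|i| gelu(gate_up[i]) * gate_up[n + i]).collect();
/// assert_eq!(out.len(), n);
/// ```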
pub fn moe_swiglu_fused_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_fused_encode: n_elements must be > 0".into(),
        ));
    }
    let gu_required = 2 * n_elements * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode: gate_up buffer too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    let out_required = n_elements * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode: output buffer too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_fused", device)?;
    let params = GpuFusedGeluMulParams {
        n_elements: n_elements as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

/// Zero-initialize an f32 GPU buffer using the `zero_buffer` kernel.
///
/// This is useful for preparing an accumulator buffer before dispatching
/// weighted accumulation passes.
///
/// # Arguments
/// * `encoder`    — Command encoder to record into.
/// * `registry`   — Kernel registry for pipeline lookup.
/// * `device`     — Metal device reference.
/// * `output`     — f32 buffer to zero, must be at least `n_elements * 4` bytes.
/// * `n_elements` — Number of f32 elements to zero.
pub fn moe_zero_buffer_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    output: &MlxBuffer,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_zero_buffer_encode: n_elements must be > 0".into(),
        ));
    }
    let required = n_elements * std::mem::size_of::<f32>();
    if output.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_zero_buffer_encode: buffer too small: need {} bytes, have {}",
            required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("zero_buffer", device)?;
    let params = GpuZeroParams { n_elements: n_elements as u32 };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(output)),
            (1, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

/// Encode a weighted accumulation: `accumulator[i] += routing_weight * expert_output[i]`.
///
/// Uses the `moe_accumulate` kernel from moe_dispatch.metal.
///
/// # Arguments
/// * `encoder`        — Command encoder to record into.
/// * `registry`       — Kernel registry for pipeline lookup.
/// * `device`         — Metal device reference.
/// * `accumulator`    — f32 buffer `[n_elements]`, in/out.
/// * `expert_output`  — f32 buffer `[n_elements]`, input.
/// * `routing_weight` — Scalar weight for this expert.
/// * `n_elements`     — Number of f32 elements.
pub fn moe_accumulate_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    accumulator: &MlxBuffer,
    expert_output: &MlxBuffer,
    routing_weight: f32,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_accumulate_encode: n_elements must be > 0".into(),
        ));
    }
    let required = n_elements * std::mem::size_of::<f32>();
    if accumulator.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode: accumulator too small: need {} bytes, have {}",
            required, accumulator.byte_len()
        )));
    }
    if expert_output.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode: expert_output too small: need {} bytes, have {}",
            required, expert_output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_accumulate", device)?;
    let params = GpuMoeAccumParams {
        n_elements: n_elements as u32,
        routing_weight,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(accumulator)),
            (1, KernelArg::Buffer(expert_output)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

/// Encode a batched SwiGLU across all top_k expert slots in one dispatch.
///
/// Takes a `[top_k, 2*intermediate]` gate_up buffer and produces
/// `[top_k, intermediate]` output: `GELU(gate[i]) * up[i]` per slot.
///
/// Replaces `top_k` separate `moe_swiglu_fused_encode_offset` dispatches with a
/// single dispatch.
///
/// # Arguments
/// * `encoder`       -- Command encoder to record into.
/// * `registry`      -- Kernel registry for pipeline lookup.
/// * `device`        -- Metal device reference.
/// * `gate_up`       -- f32 buffer `[top_k * 2 * intermediate]`.
/// * `output`        -- f32 buffer `[top_k * intermediate]`.
/// * `intermediate`  -- Intermediate dimension per expert.
/// * `top_k`         -- Number of selected expert slots.
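///
/// # Example
///
/// CPU sketch of the indexing (this assumes each slot stores its gate half
/// followed by its up half, matching `moe_swiglu_fused`, with tanh-approx GELU):
///
/// ```
/// let (top_k, intermediate) = (2usize, 3usize);
/// let gate_up = vec![0.5f32; top_k * 2 * intermediate];
/// let gelu = |x: f32| {
///     let c = (2.0f32 / std::f32::consts::PI).sqrt();
///     0.5 * x * (1.0 + (c * (x + 0.044715 * x.powi(3))).tanh())
/// };
/// let mut out = vec![0.0f32; top_k * intermediate];
/// for k in 0..top_k {
///     let base = k * 2 * intermediate;
///     for i in 0..intermediate {
///         out[k * intermediate + i] = gelu(gate_up[base + i]) * gate_up[base + intermediate + i];
///     }
/// }
/// ```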
pub fn moe_swiglu_batch_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
) -> Result<()> {
    if intermediate == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_batch_encode: intermediate and top_k must be > 0".into(),
        ));
    }
    let gu_required = top_k * 2 * intermediate * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_batch_encode: gate_up too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    let out_required = top_k * intermediate * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_batch_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_batch", device)?;
    let intermediate_bytes = (intermediate as u32).to_ne_bytes();
    let top_k_bytes = (top_k as u32).to_ne_bytes();

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(&intermediate_bytes)),
            (3, KernelArg::Bytes(&top_k_bytes)),
        ],
        MTLSize::new(intermediate as u64, top_k as u64, 1),
        MTLSize::new(std::cmp::min(256, intermediate as u64), 1, 1),
    );
    Ok(())
}

/// MSL-compatible struct for moe_weighted_sum kernel.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeWeightedSumParams {
    hidden_size: u32,
    top_k: u32,
}

/// Encode a weighted sum of all top_k expert outputs in one dispatch.
///
/// Replaces the zero_buffer + top_k * moe_accumulate pattern with a single
/// dispatch.  The weights buffer must contain pre-scaled routing weights for
/// all top_k experts (i.e. `routing_weight * per_expert_scale`).
///
/// # Arguments
/// * `encoder`        -- Command encoder to record into.
/// * `registry`       -- Kernel registry for pipeline lookup.
/// * `device`         -- Metal device reference.
/// * `expert_outputs` -- f32 buffer `[top_k * hidden_size]`.
/// * `weights`        -- f32 buffer `[top_k]` (pre-scaled routing weights).
/// * `output`         -- f32 buffer `[hidden_size]` (output weighted sum).
/// * `hidden_size`    -- Hidden dimension.
/// * `top_k`          -- Number of selected expert slots.
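///
/// # Example
///
/// CPU equivalent (a sketch of the reduction each output element performs):
///
/// ```
/// let (top_k, hidden) = (2usize, 3usize);
/// let expert_outputs = vec![1.0f32; top_k * hidden];
/// let weights = vec![0.5f32, 0.25];
/// let out: Vec<f32> = (0..hidden)
///     .map(|d| (0..top_k).map(|k| weights[k] * expert_outputs[k * hidden + d]).sum())
///     .collect();
/// assert_eq!(out, vec![0.75f32; 3]);
/// ```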
pub fn moe_weighted_sum_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_encode: hidden_size and top_k must be > 0".into(),
        ));
    }
    let expert_required = top_k * hidden_size * std::mem::size_of::<f32>();
    if expert_outputs.byte_len() < expert_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: expert_outputs too small: need {} bytes, have {}",
            expert_required, expert_outputs.byte_len()
        )));
    }
    let weights_required = top_k * std::mem::size_of::<f32>();
    if weights.byte_len() < weights_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: weights too small: need {} bytes, have {}",
            weights_required, weights.byte_len()
        )));
    }
    let out_required = hidden_size * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_weighted_sum", device)?;
    let params = GpuMoeWeightedSumParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(weights)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(hidden_size as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, hidden_size as u64), 1, 1),
    );
    Ok(())
}

/// MSL-compatible struct for moe_gather_topk_weights kernel.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeGatherTopkParams {
    n_experts: u32,
    top_k: u32,
}

/// Encode a GPU-side MoE top-K routing gather.
///
/// Reads softmax probs and sorted indices (both on GPU from prior dispatches),
/// gathers the top-K expert IDs and their weights, applies per_expert_scale,
/// and renormalizes.  This eliminates the CPU readback that previously forced
/// a session break between S1 and S4.
///
/// # Arguments
/// * `encoder`          -- Command encoder.
/// * `registry`         -- Kernel registry.
/// * `device`           -- Metal device.
/// * `softmax_probs`    -- f32 `[n_experts]` (output of dispatch_softmax).
/// * `sorted_indices`   -- u32 `[n_experts]` (output of dispatch_argsort_desc_f32).
/// * `per_expert_scale` -- f32 `[n_experts]` (learned per-expert scale).
/// * `out_expert_ids`   -- u32 `[top_k]` (output: selected expert indices).
/// * `out_weights`      -- f32 `[top_k]` (output: pre-scaled routing weights).
/// * `n_experts`        -- Total number of experts.
/// * `top_k`            -- Number of experts to select.
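///
/// # Example
///
/// CPU equivalent (a sketch; it assumes renormalization divides by the sum of
/// the scaled top-k weights, per the description above):
///
/// ```
/// let probs = vec![0.4f32, 0.3, 0.2, 0.1];
/// let sorted = vec![0u32, 1, 2, 3]; // indices sorted by descending prob
/// let scale = vec![1.0f32, 2.0, 1.0, 1.0];
/// let top_k = 2;
/// let ids: Vec<u32> = sorted[..top_k].to_vec();
/// let mut w: Vec<f32> = ids.iter().map(|&e| probs[e as usize] * scale[e as usize]).collect();
/// let s: f32 = w.iter().sum();
/// for x in &mut w {
///     *x /= s;
/// }
/// assert!((w.iter().sum::<f32>() - 1.0).abs() < 1e-6);
/// ```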
pub fn moe_gather_topk_weights_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    softmax_probs: &MlxBuffer,
    sorted_indices: &MlxBuffer,
    per_expert_scale: &MlxBuffer,
    out_expert_ids: &MlxBuffer,
    out_weights: &MlxBuffer,
    n_experts: usize,
    top_k: usize,
) -> Result<()> {
    if n_experts == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_gather_topk_weights: n_experts and top_k must be > 0".into(),
        ));
    }
    if top_k > n_experts {
        return Err(MlxError::InvalidArgument(format!(
            "moe_gather_topk_weights: top_k ({}) > n_experts ({})",
            top_k, n_experts,
        )));
    }
    if top_k > 8 {
        return Err(MlxError::InvalidArgument(format!(
            "moe_gather_topk_weights: top_k ({}) > 8 (shader fixed-size array limit)",
            top_k,
        )));
    }

    let f32_size = std::mem::size_of::<f32>();
    let u32_size = std::mem::size_of::<u32>();
    if softmax_probs.byte_len() < n_experts * f32_size {
        return Err(MlxError::InvalidArgument("softmax_probs too small".into()));
    }
    if sorted_indices.byte_len() < n_experts * u32_size {
        return Err(MlxError::InvalidArgument("sorted_indices too small".into()));
    }
    if per_expert_scale.byte_len() < n_experts * f32_size {
        return Err(MlxError::InvalidArgument("per_expert_scale too small".into()));
    }
    if out_expert_ids.byte_len() < top_k * u32_size {
        return Err(MlxError::InvalidArgument("out_expert_ids too small".into()));
    }
    if out_weights.byte_len() < top_k * f32_size {
        return Err(MlxError::InvalidArgument("out_weights too small".into()));
    }

    let pipeline = registry.get_pipeline("moe_gather_topk_weights", device)?;
    let params = GpuMoeGatherTopkParams {
        n_experts: n_experts as u32,
        top_k: top_k as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(softmax_probs)),
            (1, KernelArg::Buffer(sorted_indices)),
            (2, KernelArg::Buffer(per_expert_scale)),
            (3, KernelArg::Buffer(out_expert_ids)),
            (4, KernelArg::Buffer(out_weights)),
            (5, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(1, 1, 1),  // single thread
        MTLSize::new(1, 1, 1),
    );
    Ok(())
}

/// Like [`moe_swiglu_fused_encode`] but reads from `gate_up` at `gu_byte_offset`
/// and writes to `output` at `out_byte_offset`.
///
/// This enables operating on slices within larger buffers (e.g. the _id kernel
/// output which contains top_k rows of gate_up data).
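///
/// # Example
///
/// Byte offsets for expert slot `k` within packed `[top_k, 2*intermediate]`
/// gate_up and `[top_k, intermediate]` output buffers (an illustrative layout,
/// not the only possible one):
///
/// ```
/// let (k, intermediate) = (3usize, 704usize);
/// let gu_byte_offset = k * 2 * intermediate * std::mem::size_of::<f32>();
/// let out_byte_offset = k * intermediate * std::mem::size_of::<f32>();
/// assert_eq!(gu_byte_offset, 2 * out_byte_offset);
/// ```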
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_fused_encode_offset(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    gu_byte_offset: usize,
    output: &MlxBuffer,
    out_byte_offset: usize,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_fused_encode_offset: n_elements must be > 0".into(),
        ));
    }
    let gu_required = gu_byte_offset + 2 * n_elements * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode_offset: gate_up buffer too small: need {} bytes (offset {}), have {}",
            gu_required, gu_byte_offset, gate_up.byte_len()
        )));
    }
    let out_required = out_byte_offset + n_elements * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode_offset: output buffer too small: need {} bytes (offset {}), have {}",
            out_required, out_byte_offset, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_fused", device)?;
    let params = GpuFusedGeluMulParams {
        n_elements: n_elements as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::BufferWithOffset(gate_up, gu_byte_offset as u64)),
            (1, KernelArg::BufferWithOffset(output, out_byte_offset as u64)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

/// Like [`moe_accumulate_encode`] but reads `expert_output` from `src_byte_offset`.
///
/// This enables reading from a slice within a larger buffer (e.g. the down _id
/// kernel output which contains top_k rows of hidden data).
#[allow(clippy::too_many_arguments)]
pub fn moe_accumulate_encode_offset(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    accumulator: &MlxBuffer,
    expert_output: &MlxBuffer,
    src_byte_offset: usize,
    routing_weight: f32,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_accumulate_encode_offset: n_elements must be > 0".into(),
        ));
    }
    let required = n_elements * std::mem::size_of::<f32>();
    if accumulator.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode_offset: accumulator too small: need {} bytes, have {}",
            required, accumulator.byte_len()
        )));
    }
    let src_required = src_byte_offset + required;
    if expert_output.byte_len() < src_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode_offset: expert_output too small: need {} bytes (offset {}), have {}",
            src_required, src_byte_offset, expert_output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_accumulate", device)?;
    let params = GpuMoeAccumParams {
        n_elements: n_elements as u32,
        routing_weight,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(accumulator)),
            (1, KernelArg::BufferWithOffset(expert_output, src_byte_offset as u64)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

/// Encode a fused GELU-multiply on bf16 buffers.
///
/// Computes `output[i] = GELU(gate_out[i]) * up_out[i]` with bf16 I/O and
/// f32 accumulator.  Port of [`fused_gelu_mul`] for the bf16 activation path.
///
/// # Arguments
/// * `encoder`    — Command encoder.
/// * `registry`   — Kernel registry.
/// * `device`     — Metal device.
/// * `gate_out`   — bf16 buffer `[n_elements]`.
/// * `up_out`     — bf16 buffer `[n_elements]`.
/// * `output`     — bf16 buffer `[n_elements]`.
/// * `n_elements` — Number of elements.
pub fn fused_gelu_mul_bf16_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_out: &MlxBuffer,
    up_out: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_gelu_mul_bf16_encode: n_elements must be > 0".into(),
        ));
    }
    // bf16 is 2 bytes per element
    let required = n_elements * 2;
    if gate_out.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "fused_gelu_mul_bf16_encode: gate_out too small: need {} bytes, have {}",
            required, gate_out.byte_len()
        )));
    }
    if up_out.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "fused_gelu_mul_bf16_encode: up_out too small: need {} bytes, have {}",
            required, up_out.byte_len()
        )));
    }
    if output.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "fused_gelu_mul_bf16_encode: output too small: need {} bytes, have {}",
            required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("fused_gelu_mul_bf16", device)?;
    let params = GpuFusedGeluMulParams {
        n_elements: n_elements as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_out)),
            (1, KernelArg::Buffer(up_out)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

/// GPU params for batched SwiGLU over multiple tokens.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeSwigluSeqParams {
    intermediate: u32,
    top_k: u32,
    n_tokens: u32,
}

/// Multi-token SwiGLU for batched prefill.
///
/// Input:  `[n_tokens, top_k, 2*intermediate]`
/// Output: `[n_tokens, top_k, intermediate]`
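///
/// # Example
///
/// Flat indexing for token `t`, slot `k`, element `i` (a sketch; it assumes
/// the same per-slot gate||up packing as `moe_swiglu_batch_encode`):
///
/// ```
/// let (t, k, i) = (1usize, 2usize, 0usize);
/// let (top_k, intermediate) = (8usize, 704usize);
/// let gate_ix = (t * top_k + k) * 2 * intermediate + i;
/// let up_ix = gate_ix + intermediate;
/// let out_ix = (t * top_k + k) * intermediate + i;
/// let _ = (up_ix, out_ix);
/// ```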
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_seq_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if intermediate == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_seq_encode: all dims must be > 0".into(),
        ));
    }
    let gu_required = n_tokens * top_k * 2 * intermediate * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_encode: gate_up too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    let out_required = n_tokens * top_k * intermediate * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_seq", device)?;
    let gpu_params = GpuMoeSwigluSeqParams {
        intermediate: intermediate as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(intermediate as u64, top_k as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, intermediate as u64), 1, 1),
    );
    Ok(())
}

/// GPU params for batched weighted-sum over multiple tokens.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeWeightedSumSeqParams {
    hidden_size: u32,
    top_k: u32,
    n_tokens: u32,
}

/// Multi-token weighted sum of expert outputs for batched prefill.
///
/// * `expert_outputs` — `[n_tokens, top_k, hidden_size]`
/// * `weights`        — `[n_tokens, top_k]`
/// * `output`         — `[n_tokens, hidden_size]`
#[allow(clippy::too_many_arguments)]
pub fn moe_weighted_sum_seq_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_seq_encode: all dims must be > 0".into(),
        ));
    }
    let expert_required = n_tokens * top_k * hidden_size * std::mem::size_of::<f32>();
    if expert_outputs.byte_len() < expert_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_encode: expert_outputs too small: need {} bytes, have {}",
            expert_required, expert_outputs.byte_len()
        )));
    }
    let weights_required = n_tokens * top_k * std::mem::size_of::<f32>();
    if weights.byte_len() < weights_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_encode: weights too small: need {} bytes, have {}",
            weights_required, weights.byte_len()
        )));
    }
    let out_required = n_tokens * hidden_size * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_weighted_sum_seq", device)?;
    let gpu_params = GpuMoeWeightedSumSeqParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(weights)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(hidden_size as u64, n_tokens as u64, 1),
        MTLSize::new(std::cmp::min(256, hidden_size as u64), 1, 1),
    );
    Ok(())
}

/// Multi-token SwiGLU for batched prefill (bf16 I/O, f32 accumulator).
///
/// Port of [`moe_swiglu_seq_encode`] with bf16 buffers.
/// Input:  `[n_tokens, top_k, 2*intermediate]` bf16
/// Output: `[n_tokens, top_k, intermediate]`   bf16
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_seq_bf16_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if intermediate == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_seq_bf16_encode: all dims must be > 0".into(),
        ));
    }
    // bf16 = 2 bytes per element
    let gu_required = n_tokens * top_k * 2 * intermediate * 2;
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_bf16_encode: gate_up too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    let out_required = n_tokens * top_k * intermediate * 2;
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_bf16_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_seq_bf16", device)?;
    let gpu_params = GpuMoeSwigluSeqParams {
        intermediate: intermediate as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(intermediate as u64, top_k as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, intermediate as u64), 1, 1),
    );
    Ok(())
}

/// Multi-token weighted sum of expert outputs for batched prefill (bf16 inputs).
///
/// Port of [`moe_weighted_sum_seq_encode`] that accepts bf16 expert_outputs
/// and produces f32 output — matching the convention where expert intermediates
/// are bf16 but the weighted accumulator (pf_moe_accum) stays f32 for residual
/// precision.
///
/// * `expert_outputs` — bf16 `[n_tokens, top_k, hidden_size]`
/// * `weights`        — f32  `[n_tokens, top_k]`
/// * `output`         — f32  `[n_tokens, hidden_size]`
#[allow(clippy::too_many_arguments)]
pub fn moe_weighted_sum_seq_bf16_input_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_seq_bf16_input_encode: all dims must be > 0".into(),
        ));
    }
    // expert_outputs is bf16 (2 bytes per element)
    let expert_required = n_tokens * top_k * hidden_size * 2;
    if expert_outputs.byte_len() < expert_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_bf16_input_encode: expert_outputs too small: need {} bytes, have {}",
            expert_required, expert_outputs.byte_len()
        )));
    }
    let weights_required = n_tokens * top_k * std::mem::size_of::<f32>();
    if weights.byte_len() < weights_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_bf16_input_encode: weights too small: need {} bytes, have {}",
            weights_required, weights.byte_len()
        )));
    }
    let out_required = n_tokens * hidden_size * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_bf16_input_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_weighted_sum_seq_bf16_input", device)?;
    let gpu_params = GpuMoeWeightedSumSeqParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(weights)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(hidden_size as u64, n_tokens as u64, 1),
        MTLSize::new(std::cmp::min(256, hidden_size as u64), 1, 1),
    );
    Ok(())
}

// ============================================================================
// ADR-020 iter-11h-e3a — backward kernels for moe_weighted_sum_seq.
//
// Forward (existing): output[t, d] = sum_k expert_outputs[t, k, d] * weights[t, k]
// Backward:
//   d_expert_outputs[t, k, d] = weights[t, k] * d_output[t, d]                 (parallel)
//   d_weights[t, k]           = sum_d expert_outputs[t, k, d] * d_output[t, d] (reduction)
// ============================================================================

/// Backward of `moe_weighted_sum_seq` w.r.t. `expert_outputs`.
///
/// Computes `d_expert_outputs[t, k, d] = weights[t, k] * d_output[t, d]`.
/// Embarrassingly parallel — one thread writes one output element.
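///
/// # Example
///
/// CPU reference for a single token (a sketch of the formula above):
///
/// ```
/// let (top_k, hidden) = (2usize, 3usize);
/// let weights = vec![0.5f32, 0.25];
/// let d_output = vec![1.0f32, 2.0, 3.0];
/// let d_expert: Vec<f32> = (0..top_k * hidden)
///     .map(|i| weights[i / hidden] * d_output[i % hidden])
///     .collect();
/// assert_eq!(d_expert, vec![0.5, 1.0, 1.5, 0.25, 0.5, 0.75]);
/// ```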
#[allow(clippy::too_many_arguments)]
pub fn moe_weighted_sum_seq_backward_outputs_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    weights: &MlxBuffer,
    d_output: &MlxBuffer,
    d_expert_outputs: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_seq_backward_outputs_encode: all dims must be > 0".into(),
        ));
    }
    let f32_size = std::mem::size_of::<f32>();
    let weights_required = n_tokens * top_k * f32_size;
    if weights.byte_len() < weights_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_backward_outputs_encode: weights too small: need {} bytes, have {}",
            weights_required,
            weights.byte_len()
        )));
    }
    let dout_required = n_tokens * hidden_size * f32_size;
    if d_output.byte_len() < dout_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_backward_outputs_encode: d_output too small: need {} bytes, have {}",
            dout_required,
            d_output.byte_len()
        )));
    }
    let dexp_required = n_tokens * top_k * hidden_size * f32_size;
    if d_expert_outputs.byte_len() < dexp_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_backward_outputs_encode: d_expert_outputs too small: need {} bytes, have {}",
            dexp_required,
            d_expert_outputs.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_weighted_sum_seq_backward_outputs_f32", device)?;
    let gpu_params = GpuMoeWeightedSumSeqParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(weights)),
            (1, KernelArg::Buffer(d_output)),
            (2, KernelArg::Buffer(d_expert_outputs)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(hidden_size as u64, top_k as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, hidden_size as u64), 1, 1),
    );
    Ok(())
}

/// Backward of `moe_weighted_sum_seq` w.r.t. `weights`.
///
/// Computes `d_weights[t, k] = sum_d expert_outputs[t, k, d] * d_output[t, d]`.
/// Grid: 1D over `n_tokens * top_k`; each thread reduces serially across hidden.
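///
/// # Example
///
/// CPU reference for a single token (a sketch of the reduction above):
///
/// ```
/// let (top_k, hidden) = (2usize, 3usize);
/// let expert_outputs = vec![1.0f32; top_k * hidden];
/// let d_output = vec![1.0f32, 2.0, 3.0];
/// let d_weights: Vec<f32> = (0..top_k)
///     .map(|k| (0..hidden).map(|d| expert_outputs[k * hidden + d] * d_output[d]).sum())
///     .collect();
/// assert_eq!(d_weights, vec![6.0, 6.0]);
/// ```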
1323#[allow(clippy::too_many_arguments)]
1324pub fn moe_weighted_sum_seq_backward_weights_encode(
1325    encoder: &mut CommandEncoder,
1326    registry: &mut KernelRegistry,
1327    device: &metal::DeviceRef,
1328    expert_outputs: &MlxBuffer,
1329    d_output: &MlxBuffer,
1330    d_weights: &MlxBuffer,
1331    hidden_size: usize,
1332    top_k: usize,
1333    n_tokens: usize,
1334) -> Result<()> {
1335    if hidden_size == 0 || top_k == 0 || n_tokens == 0 {
1336        return Err(MlxError::InvalidArgument(
1337            "moe_weighted_sum_seq_backward_weights_encode: all dims must be > 0".into(),
1338        ));
1339    }
1340    let f32_size = std::mem::size_of::<f32>();
1341    let exp_required = n_tokens * top_k * hidden_size * f32_size;
1342    if expert_outputs.byte_len() < exp_required {
1343        return Err(MlxError::InvalidArgument(format!(
1344            "moe_weighted_sum_seq_backward_weights_encode: expert_outputs too small: need {} bytes, have {}",
1345            exp_required,
1346            expert_outputs.byte_len()
1347        )));
1348    }
1349    let dout_required = n_tokens * hidden_size * f32_size;
1350    if d_output.byte_len() < dout_required {
1351        return Err(MlxError::InvalidArgument(format!(
1352            "moe_weighted_sum_seq_backward_weights_encode: d_output too small: need {} bytes, have {}",
1353            dout_required,
1354            d_output.byte_len()
1355        )));
1356    }
1357    let dw_required = n_tokens * top_k * f32_size;
1358    if d_weights.byte_len() < dw_required {
1359        return Err(MlxError::InvalidArgument(format!(
1360            "moe_weighted_sum_seq_backward_weights_encode: d_weights too small: need {} bytes, have {}",
1361            dw_required,
1362            d_weights.byte_len()
1363        )));
1364    }
1365
1366    let pipeline = registry.get_pipeline("moe_weighted_sum_seq_backward_weights_f32", device)?;
1367    let gpu_params = GpuMoeWeightedSumSeqParams {
1368        hidden_size: hidden_size as u32,
1369        top_k: top_k as u32,
1370        n_tokens: n_tokens as u32,
1371    };
1372
1373    let total = (n_tokens * top_k) as u64;
1374    encode_with_args(
1375        encoder,
1376        pipeline,
1377        &[
1378            (0, KernelArg::Buffer(expert_outputs)),
1379            (1, KernelArg::Buffer(d_output)),
1380            (2, KernelArg::Buffer(d_weights)),
1381            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
1382        ],
1383        MTLSize::new(total, 1, 1),
1384        MTLSize::new(std::cmp::min(256, total), 1, 1),
1385    );
1386    Ok(())
1387}

#[cfg(test)]
mod backward_weighted_sum_seq_tests {
    use super::*;
    use crate::device::MlxDevice;
    use crate::dtypes::DType;

    fn alloc_f32(d: &MlxDevice, n: usize) -> MlxBuffer {
        let mut b = d.alloc_buffer(n * 4, DType::F32, vec![n]).unwrap();
        b.as_mut_slice::<f32>().unwrap().fill(0.0);
        b
    }
    fn fill_f32(buf: &mut MlxBuffer, vals: &[f32]) {
        buf.as_mut_slice::<f32>().unwrap()[..vals.len()].copy_from_slice(vals);
    }

    fn cpu_forward(
        expert_outs: &[f32],
        weights: &[f32],
        n_tokens: usize,
        top_k: usize,
        hidden: usize,
    ) -> Vec<f32> {
        let mut out = vec![0.0f32; n_tokens * hidden];
        for t in 0..n_tokens {
            for d in 0..hidden {
                let mut sum = 0.0f64;
                for k in 0..top_k {
                    let exp_ix = (t * top_k + k) * hidden + d;
                    let w_ix = t * top_k + k;
                    sum += (expert_outs[exp_ix] as f64) * (weights[w_ix] as f64);
                }
                out[t * hidden + d] = sum as f32;
            }
        }
        out
    }

    /// FD falsifier on d_expert_outputs.
    /// L = sum(output * d_output_seed); d_expert_outputs[t,k,d] should equal
    /// weights[t,k] * d_output_seed[t,d].
    #[test]
    fn backward_outputs_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n_tokens = 3usize;
        let top_k = 4usize;
        let hidden = 7usize;

        let expert_outs: Vec<f32> = (0..n_tokens * top_k * hidden)
            .map(|i| 0.1 + (i as f32) * 0.013 - (i as f32 * 0.004).sin())
            .collect();
        let weights: Vec<f32> = (0..n_tokens * top_k)
            .map(|i| 0.2 + (i as f32) * 0.07)
            .collect();
        let d_output_seed: Vec<f32> = (0..n_tokens * hidden)
            .map(|i| 0.3 + (i as f32) * 0.05 - (i as f32 * 0.011).cos())
            .collect();

        let mut exp_buf = alloc_f32(&device, n_tokens * top_k * hidden);
        fill_f32(&mut exp_buf, &expert_outs);
        let mut w_buf = alloc_f32(&device, n_tokens * top_k);
        fill_f32(&mut w_buf, &weights);
        let mut dout_buf = alloc_f32(&device, n_tokens * hidden);
        fill_f32(&mut dout_buf, &d_output_seed);
        let dexp_buf = alloc_f32(&device, n_tokens * top_k * hidden);

        let mut encoder = device.command_encoder().unwrap();
        moe_weighted_sum_seq_backward_outputs_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &w_buf, &dout_buf, &dexp_buf, hidden, top_k, n_tokens,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let analytic = dexp_buf.as_slice::<f32>().unwrap().to_vec();

        let h: f32 = 1e-3;
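        // Central difference (L(x+h) - L(x-h)) / 2h is O(h²)-accurate; the
        // loose 5% tolerance below absorbs both FD and f32 kernel error.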
        for idx in 0..(n_tokens * top_k * hidden) {
            let mut ep = expert_outs.clone(); ep[idx] += h;
            let mut em = expert_outs.clone(); em[idx] -= h;
            let yp = cpu_forward(&ep, &weights, n_tokens, top_k, hidden);
            let ym = cpu_forward(&em, &weights, n_tokens, top_k, hidden);
            let lp: f64 = yp.iter().zip(&d_output_seed).map(|(a, b)| (*a as f64) * (*b as f64)).sum();
            let lm: f64 = ym.iter().zip(&d_output_seed).map(|(a, b)| (*a as f64) * (*b as f64)).sum();
            let fd = (lp - lm) / (2.0 * h as f64);
            let tol = 5e-2 * fd.abs().max(1.0);
            assert!(
                (analytic[idx] as f64 - fd).abs() < tol,
                "d_expert_outputs[{}]: analytic={} fd={}",
                idx, analytic[idx], fd
            );
        }
    }

    /// FD falsifier on d_weights.
    /// d_weights[t,k] should equal sum_d expert_outputs[t,k,d] * d_output[t,d].
    #[test]
    fn backward_weights_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n_tokens = 4usize;
        let top_k = 3usize;
        let hidden = 11usize;

        let expert_outs: Vec<f32> = (0..n_tokens * top_k * hidden)
            .map(|i| 0.05 + (i as f32) * 0.017 + (i as f32 * 0.009).sin())
            .collect();
        let weights: Vec<f32> = (0..n_tokens * top_k)
            .map(|i| 0.1 + (i as f32) * 0.05)
            .collect();
        let d_output_seed: Vec<f32> = (0..n_tokens * hidden)
            .map(|i| 0.2 + (i as f32) * 0.03 - (i as f32 * 0.013).cos())
            .collect();

        let mut exp_buf = alloc_f32(&device, n_tokens * top_k * hidden);
        fill_f32(&mut exp_buf, &expert_outs);
        let mut dout_buf = alloc_f32(&device, n_tokens * hidden);
        fill_f32(&mut dout_buf, &d_output_seed);
        let dw_buf = alloc_f32(&device, n_tokens * top_k);

        let mut encoder = device.command_encoder().unwrap();
        moe_weighted_sum_seq_backward_weights_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &exp_buf, &dout_buf, &dw_buf, hidden, top_k, n_tokens,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let analytic = dw_buf.as_slice::<f32>().unwrap().to_vec();

        let h: f32 = 1e-3;
        for idx in 0..(n_tokens * top_k) {
            let mut wp = weights.clone(); wp[idx] += h;
            let mut wm = weights.clone(); wm[idx] -= h;
            let yp = cpu_forward(&expert_outs, &wp, n_tokens, top_k, hidden);
            let ym = cpu_forward(&expert_outs, &wm, n_tokens, top_k, hidden);
            let lp: f64 = yp.iter().zip(&d_output_seed).map(|(a, b)| (*a as f64) * (*b as f64)).sum();
            let lm: f64 = ym.iter().zip(&d_output_seed).map(|(a, b)| (*a as f64) * (*b as f64)).sum();
            let fd = (lp - lm) / (2.0 * h as f64);
            let tol = 5e-2 * fd.abs().max(1.0);
            assert!(
                (analytic[idx] as f64 - fd).abs() < tol,
                "d_weights[{}]: analytic={} fd={}",
                idx, analytic[idx], fd
            );
        }
    }

    /// CPU oracle round-trip: confirm both kernels match closed-form on a
    /// small fixture.
    #[test]
    fn forward_then_backward_round_trip_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n_tokens = 2usize;
        let top_k = 2usize;
        let hidden = 5usize;

        let expert_outs: Vec<f32> = (0..n_tokens * top_k * hidden)
            .map(|i| (i as f32) * 0.1 - 0.2)
            .collect();
        let weights: Vec<f32> = (0..n_tokens * top_k)
            .map(|i| 0.3 + (i as f32) * 0.1)
            .collect();
        let d_output_seed: Vec<f32> = (0..n_tokens * hidden)
            .map(|i| 1.0 - (i as f32) * 0.05)
            .collect();

        // CPU oracles
        let mut cpu_dexp = vec![0.0f32; n_tokens * top_k * hidden];
        let mut cpu_dw = vec![0.0f32; n_tokens * top_k];
        for t in 0..n_tokens {
            for k in 0..top_k {
                for d in 0..hidden {
                    let exp_ix = (t * top_k + k) * hidden + d;
                    let w_ix = t * top_k + k;
                    let dout_ix = t * hidden + d;
                    cpu_dexp[exp_ix] = weights[w_ix] * d_output_seed[dout_ix];
                    cpu_dw[w_ix] += expert_outs[exp_ix] * d_output_seed[dout_ix];
                }
            }
        }

        let mut exp_buf = alloc_f32(&device, n_tokens * top_k * hidden);
        fill_f32(&mut exp_buf, &expert_outs);
        let mut w_buf = alloc_f32(&device, n_tokens * top_k);
        fill_f32(&mut w_buf, &weights);
        let mut dout_buf = alloc_f32(&device, n_tokens * hidden);
        fill_f32(&mut dout_buf, &d_output_seed);
        let dexp_buf = alloc_f32(&device, n_tokens * top_k * hidden);
        let dw_buf = alloc_f32(&device, n_tokens * top_k);

        let mut encoder = device.command_encoder().unwrap();
        moe_weighted_sum_seq_backward_outputs_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &w_buf, &dout_buf, &dexp_buf, hidden, top_k, n_tokens,
        ).unwrap();
        moe_weighted_sum_seq_backward_weights_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &exp_buf, &dout_buf, &dw_buf, hidden, top_k, n_tokens,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu_dexp = dexp_buf.as_slice::<f32>().unwrap().to_vec();
        let gpu_dw = dw_buf.as_slice::<f32>().unwrap().to_vec();

        for i in 0..gpu_dexp.len() {
            assert!(
                (gpu_dexp[i] - cpu_dexp[i]).abs() < 1e-5,
                "d_expert_outputs[{}]: gpu={} cpu={}",
                i, gpu_dexp[i], cpu_dexp[i]
            );
        }
        for i in 0..gpu_dw.len() {
            assert!(
                (gpu_dw[i] - cpu_dw[i]).abs() < 1e-5,
                "d_weights[{}]: gpu={} cpu={}",
                i, gpu_dw[i], cpu_dw[i]
            );
        }
    }

    #[test]
    fn rejects_size_mismatch() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let too_small = alloc_f32(&device, 4);
        let any = alloc_f32(&device, 1024);
        let mut encoder = device.command_encoder().unwrap();
        let res = moe_weighted_sum_seq_backward_outputs_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &too_small, &any, &any, 7, 3, 4,
        );
        assert!(res.is_err());
        let res2 = moe_weighted_sum_seq_backward_weights_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &too_small, &any, &any, 11, 3, 4,
        );
        assert!(res2.is_err());
    }
}

// ============================================================================
// ADR-020 iter-11h-e3b — fused backward kernel for moe_swiglu_seq.
//
// Forward (existing): output[t,k,i] = gelu(gate[t,k,i]) * up[t,k,i], where
// gate_up is concatenated [t, k, 2*intermediate] (gate at offset 0..I,
// up at offset I..2I).
//
// Backward (this kernel — fused so the gelu/gelu' intermediates are computed
// once and reused for both gradients):
//   ∂L/∂gate[t,k,i] = d_output[t,k,i] · up[t,k,i] · gelu'(gate[t,k,i])
//   ∂L/∂up[t,k,i]   = d_output[t,k,i] · gelu(gate[t,k,i])
// Writes both into a single d_gate_up buffer with the SAME [t,k,2I]
// layout as the forward gate_up input (gate grad in lower half, up grad
// in upper half).
// ============================================================================
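
// Illustrative per-element CPU sketch of the fused backward above, kept as
// documentation and compiled only under test; the Metal kernel encoded below
// is the production path. For the tanh-approx GELU with
// s = sqrt(2/π)·(g + 0.044715·g³) and t = tanh(s):
//   gelu(g)  = 0.5·g·(1 + t)
//   gelu'(g) = 0.5·(1 + t) + 0.5·g·(1 − t²)·sqrt(2/π)·(1 + 3·0.044715·g²)
#[cfg(test)]
#[allow(dead_code)]
fn swiglu_backward_element(gate: f32, up: f32, d_out: f32) -> (f32, f32) {
    let g = gate as f64;
    let s = 0.7978845608 * (g + 0.044715 * g * g * g); // 0.79788… ≈ sqrt(2/π)
    let t = s.tanh();
    let gelu = 0.5 * g * (1.0 + t);
    let dgelu =
        0.5 * (1.0 + t) + 0.5 * g * (1.0 - t * t) * 0.7978845608 * (1.0 + 3.0 * 0.044715 * g * g);
    let d_gate = (d_out as f64) * (up as f64) * dgelu; // ∂L/∂gate
    let d_up = (d_out as f64) * gelu; // ∂L/∂up
    (d_gate as f32, d_up as f32)
}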

/// Backward of `moe_swiglu_seq` — single fused kernel writes both gate
/// and up gradients into the supplied `d_gate_up` buffer (same layout
/// as forward `gate_up`).
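///
/// # Buffer expectations
///
/// * `gate_up` — f32, `[n_tokens, top_k, 2 * intermediate]` (forward input;
///   gate half at `0..intermediate`, up half at `intermediate..2*intermediate`)
/// * `d_output` — f32, `[n_tokens, top_k, intermediate]` (incoming gradient)
/// * `d_gate_up` — f32, same size and layout as `gate_up`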
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_seq_backward_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    d_output: &MlxBuffer,
    d_gate_up: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if intermediate == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_seq_backward_encode: all dims must be > 0".into(),
        ));
    }
    let f32_size = std::mem::size_of::<f32>();
    let gu_required = n_tokens * top_k * 2 * intermediate * f32_size;
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_backward_encode: gate_up too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    if d_gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_backward_encode: d_gate_up too small: need {} bytes, have {}",
            gu_required, d_gate_up.byte_len()
        )));
    }
    let dout_required = n_tokens * top_k * intermediate * f32_size;
    if d_output.byte_len() < dout_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_backward_encode: d_output too small: need {} bytes, have {}",
            dout_required, d_output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_seq_backward_f32", device)?;
    let gpu_params = GpuMoeSwigluSeqParams {
        intermediate: intermediate as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

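    // 3-D grid: x = intermediate element i, y = expert slot k, z = token t;
    // each thread writes one gate gradient and one up gradient.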
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(d_output)),
            (2, KernelArg::Buffer(d_gate_up)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(intermediate as u64, top_k as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, intermediate as u64), 1, 1),
    );
    Ok(())
}

#[cfg(test)]
mod backward_swiglu_seq_tests {
    use super::*;
    use crate::device::MlxDevice;
    use crate::dtypes::DType;

    fn alloc_f32(d: &MlxDevice, n: usize) -> MlxBuffer {
        let mut b = d.alloc_buffer(n * 4, DType::F32, vec![n]).unwrap();
        b.as_mut_slice::<f32>().unwrap().fill(0.0);
        b
    }
    fn fill_f32(buf: &mut MlxBuffer, vals: &[f32]) {
        buf.as_mut_slice::<f32>().unwrap()[..vals.len()].copy_from_slice(vals);
    }

    /// FP64-precision tanh-approx GELU oracle, matches the .metal kernel.
    fn cpu_gelu(g: f64) -> f64 {
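        // 0.7978845608 ≈ sqrt(2/π); 0.044715 is the tanh-GELU cubic coefficient.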
        let s = 0.7978845608 * (g + 0.044715 * g * g * g);
        let t = s.tanh();
        0.5 * g * (1.0 + t)
    }

    /// CPU forward replicates `moe_swiglu_seq` exactly.
    fn cpu_forward(
        gate_up: &[f32],
        n_tokens: usize,
        top_k: usize,
        intermediate: usize,
    ) -> Vec<f32> {
        let mut out = vec![0.0f32; n_tokens * top_k * intermediate];
        for t in 0..n_tokens {
            for k in 0..top_k {
                let slot_base = (t * top_k + k) * 2 * intermediate;
                for i in 0..intermediate {
                    let g = gate_up[slot_base + i] as f64;
                    let u = gate_up[slot_base + intermediate + i] as f64;
                    let y = cpu_gelu(g) * u;
                    out[(t * top_k + k) * intermediate + i] = y as f32;
                }
            }
        }
        out
    }

    /// FD falsifier on ∂gate AND ∂up. Loss L = sum(forward_output * d_output_seed).
    /// Probes every element of gate_up (both halves) at h=1e-3.
    #[test]
    fn backward_finite_difference_falsifier_both_gradients() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n_tokens = 2usize;
        let top_k = 3usize;
        let intermediate = 5usize;
        let gu_n = n_tokens * top_k * 2 * intermediate;
        let dout_n = n_tokens * top_k * intermediate;

        // Deterministic non-trivial gate values (negative + small + large
        // positive) exercise GELU's full domain, including the negative lobe.
        let gate_up: Vec<f32> = (0..gu_n)
            .map(|i| -2.0 + (i as f32) * 0.12 - (i as f32 * 0.013).sin())
            .collect();
        let d_output_seed: Vec<f32> = (0..dout_n)
            .map(|i| 0.3 + (i as f32) * 0.05 - (i as f32 * 0.011).cos())
            .collect();

        let mut gu_buf = alloc_f32(&device, gu_n);
        fill_f32(&mut gu_buf, &gate_up);
        let mut dout_buf = alloc_f32(&device, dout_n);
        fill_f32(&mut dout_buf, &d_output_seed);
        let dgu_buf = alloc_f32(&device, gu_n);

        let mut encoder = device.command_encoder().unwrap();
        moe_swiglu_seq_backward_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &gu_buf, &dout_buf, &dgu_buf,
            intermediate, top_k, n_tokens,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let analytic = dgu_buf.as_slice::<f32>().unwrap().to_vec();

        // Finite difference probe over EVERY element of gate_up.
        let h: f32 = 1e-3;
        for idx in 0..gu_n {
            let mut gp = gate_up.clone();
            gp[idx] += h;
            let mut gm = gate_up.clone();
            gm[idx] -= h;
            let yp = cpu_forward(&gp, n_tokens, top_k, intermediate);
            let ym = cpu_forward(&gm, n_tokens, top_k, intermediate);
            let lp: f64 = yp.iter().zip(&d_output_seed)
                .map(|(a, b)| (*a as f64) * (*b as f64)).sum();
            let lm: f64 = ym.iter().zip(&d_output_seed)
                .map(|(a, b)| (*a as f64) * (*b as f64)).sum();
            let fd = (lp - lm) / (2.0 * h as f64);
            let tol = 5e-2 * fd.abs().max(1.0);
            assert!(
                (analytic[idx] as f64 - fd).abs() < tol,
                "d_gate_up[{}]: analytic={} fd={} (gate_up_value={})",
                idx, analytic[idx], fd, gate_up[idx]
            );
        }
    }

    /// Cross-check: at canonical points (gate=0, gate=large positive,
    /// gate=large negative) check known gradient asymptotics.
    /// - At gate=0:   gelu(0)=0,  gelu'(0)=0.5 → ∂up=0,     ∂gate=0.5·dy·up
    /// - At gate=+10: gelu≈10,    gelu'≈1      → ∂up≈10·dy, ∂gate≈up·dy
    /// - At gate=-10: gelu≈0,     gelu'≈0      → ∂up≈0,     ∂gate≈0
    #[test]
    fn backward_canonical_asymptotics_match_expected() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n_tokens = 1usize;
        let top_k = 1usize;
        let intermediate = 3usize;
        let gu_n = n_tokens * top_k * 2 * intermediate;
        let dout_n = n_tokens * top_k * intermediate;

        // gate_up layout: [gate0, gate1, gate2, up0, up1, up2]
        // gate0=0    up0=2.0  dy0=0.5 → ∂up0=0,    ∂gate0=0.5·0.5·2.0=0.5
        // gate1=10   up1=3.0  dy1=1.0 → ∂up1≈10.0, ∂gate1≈3.0
        // gate2=-10  up2=4.0  dy2=1.0 → ∂up2≈0.0,  ∂gate2≈0.0
        let gate_up = vec![0.0f32, 10.0, -10.0, 2.0, 3.0, 4.0];
        let d_output_seed = vec![0.5f32, 1.0, 1.0];

        let mut gu_buf = alloc_f32(&device, gu_n);
        fill_f32(&mut gu_buf, &gate_up);
        let mut dout_buf = alloc_f32(&device, dout_n);
        fill_f32(&mut dout_buf, &d_output_seed);
        let dgu_buf = alloc_f32(&device, gu_n);

        let mut encoder = device.command_encoder().unwrap();
        moe_swiglu_seq_backward_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &gu_buf, &dout_buf, &dgu_buf,
            intermediate, top_k, n_tokens,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let g = dgu_buf.as_slice::<f32>().unwrap();

        // ∂gate0 = 0.5·dy0·up0 = 0.5·0.5·2.0 = 0.5  (exact)
        assert!((g[0] - 0.5).abs() < 1e-5, "∂gate0={}", g[0]);
        // ∂gate1 ≈ up1·dy1 = 3.0  (asymptotic gelu'(10) ≈ 1)
        assert!((g[1] - 3.0).abs() < 0.05, "∂gate1={}", g[1]);
        // ∂gate2 ≈ 0  (asymptotic gelu'(-10) ≈ 0)
        assert!(g[2].abs() < 0.05, "∂gate2={}", g[2]);
        // ∂up0 = dy0·gelu(0) = 0.5·0 = 0
        assert!(g[3].abs() < 1e-5, "∂up0={}", g[3]);
        // ∂up1 ≈ dy1·gelu(10) ≈ 1.0·10 = 10.0
        assert!((g[4] - 10.0).abs() < 0.05, "∂up1={}", g[4]);
        // ∂up2 ≈ dy2·gelu(-10) ≈ 0
        assert!(g[5].abs() < 0.05, "∂up2={}", g[5]);
    }

    #[test]
    fn rejects_size_mismatch() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let too_small = alloc_f32(&device, 4);
        let any = alloc_f32(&device, 1024);
        let mut encoder = device.command_encoder().unwrap();
        let res = moe_swiglu_seq_backward_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &too_small, &any, &any, 5, 3, 2,
        );
        assert!(res.is_err());
    }
}