// mlx_native/ops/moe_dispatch.rs

//! GPU-accelerated MoE expert dispatch (Stage 1: loop over selected experts).
//!
//! For each of the K selected experts, runs:
//!   gate_out   = gate_proj_e(x)       [input_dim -> intermediate_dim]
//!   up_out     = up_proj_e(x)         [input_dim -> intermediate_dim]
//!   hidden     = GELU(gate_out) * up_out
//!   expert_out = down_proj_e(hidden)  [intermediate_dim -> input_dim]
//!   result    += routing_weight_e * expert_out
//!
//! Stage 1 uses individual kernel dispatches per expert and per projection.
//! The projections use float matmul (the caller dequantizes or provides float
//! weights).  The Stage 2 optimization (Epic 6) would fuse these.
//!
//! This module provides the high-level `moe_dispatch` function that
//! orchestrates the per-expert loop, using the fused_gelu_mul and
//! moe_accumulate shaders from moe_dispatch.metal.
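//!
//! As a rough CPU reference for what one expert contributes (a sketch for
//! clarity only; `gelu` is a stand-in for the shader's GELU, and weight rows
//! are assumed indexed by output element, matching the matvec dispatches
//! below):
//!
//! ```ignore
//! fn expert_contribution_reference(
//!     x: &[f32],       // [input_dim]
//!     gate: &[&[f32]], // one row of length input_dim per hidden element
//!     up: &[&[f32]],
//!     down: &[&[f32]], // one row of length intermediate_dim per output
//!     w: f32,          // routing weight for this expert
//!     out: &mut [f32], // [input_dim], accumulated in place
//!     gelu: impl Fn(f32) -> f32,
//! ) {
//!     fn dot(a: &[f32], b: &[f32]) -> f32 {
//!         a.iter().zip(b).map(|(x, y)| x * y).sum()
//!     }
//!     let hidden: Vec<f32> = gate
//!         .iter()
//!         .zip(up)
//!         .map(|(g, u)| gelu(dot(g, x)) * dot(u, x))
//!         .collect();
//!     for (o, row) in out.iter_mut().zip(down) {
//!         *o += w * dot(row, &hidden);
//!     }
//! }
//! ```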

use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;

use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};

/// Parameters for MoE dispatch.
pub struct MoeDispatchParams {
    /// Input/output dimension (e.g. 2816 for Gemma 4 MoE layers).
    pub input_dim: usize,
    /// Intermediate FFN dimension per expert (e.g. 704 for Gemma 4).
    pub intermediate_dim: usize,
    /// Number of selected experts (top_k, e.g. 8).
    pub n_selected: usize,
}

/// A single expert's weight matrices (float32; pre-dequantized by the caller
/// or natively float).
///
/// Each expert has three projection matrices:
/// * `gate_proj`: `[input_dim, intermediate_dim]` row-major
/// * `up_proj`:   `[input_dim, intermediate_dim]` row-major
/// * `down_proj`: `[intermediate_dim, input_dim]` row-major
pub struct ExpertWeights<'a> {
    pub gate_proj: &'a MlxBuffer,
    pub up_proj: &'a MlxBuffer,
    pub down_proj: &'a MlxBuffer,
}
48
49/// MSL-compatible struct for fused_gelu_mul kernel.
50#[repr(C)]
51#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
52struct GpuFusedGeluMulParams {
53    n_elements: u32,
54}
55
56/// MSL-compatible struct for moe_accumulate kernel.
57#[repr(C)]
58#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
59struct GpuMoeAccumParams {
60    n_elements: u32,
61    routing_weight: f32,
62}
63
64/// MSL-compatible struct for zero_buffer kernel.
65#[repr(C)]
66#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
67struct GpuZeroParams {
68    n_elements: u32,
69}
70
/// MSL-compatible params struct for the naive matmul kernel, used for the
/// expert projections.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMatmulParams {
    m: u32,  // rows of output (1 for single-token)
    k: u32,  // inner dimension
    n: u32,  // cols of output
}

/// Encode MoE dispatch: loop over selected experts, run FFN, accumulate.
///
/// This is the Stage 1 implementation that loops over each selected expert
/// and dispatches individual compute passes for each projection.
///
/// # Buffer expectations
///
/// * `input` -- f32, `[input_dim]` (single-token hidden state)
/// * `expert_weights` -- slice of `n_selected` expert weight structs, each
///   containing gate_proj, up_proj, down_proj as f32 buffers
/// * `routing_weights` -- f32 host slice, `[n_selected]` (softmax routing
///   weights from moe_gate; read on the CPU at encode time, not a GPU buffer)
/// * `output` -- f32, `[input_dim]` (output, will be zero-initialized)
/// * `scratch_gate` -- f32, `[intermediate_dim]` (scratch buffer for gate_proj output)
/// * `scratch_up` -- f32, `[intermediate_dim]` (scratch buffer for up_proj output)
/// * `scratch_hidden` -- f32, `[intermediate_dim]` (scratch buffer for GELU*up output)
/// * `scratch_expert` -- f32, `[input_dim]` (scratch buffer for down_proj output)
///
/// # Design Notes
///
/// The caller provides scratch buffers to avoid allocating inside the
/// encoding loop.  These can come from a `MlxBufferPool`.
///
/// For the matrix projections, we use a naive matmul kernel (single-token,
/// M=1, so it's really a matvec).  The quantized_matmul from Story 1.2
/// would be used when weights remain quantized; Stage 1 assumes float
/// weights for simplicity.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if parameters are invalid.
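///
/// # Example
///
/// A minimal call-site sketch (`enc`, `reg`, `device`, and all the buffers
/// are assumed to have been created elsewhere, e.g. from a `MlxBufferPool`;
/// shapes follow the buffer expectations above):
///
/// ```ignore
/// let params = MoeDispatchParams {
///     input_dim: 2816,
///     intermediate_dim: 704,
///     n_selected: 8,
/// };
/// moe_dispatch(
///     &mut enc,
///     &mut reg,
///     &device,
///     &input,
///     &experts, // &[ExpertWeights<'_>] of length 8
///     &routing, // &[f32] of length 8
///     &output,
///     &scratch_gate,
///     &scratch_up,
///     &scratch_hidden,
///     &scratch_expert,
///     &params,
/// )?;
/// ```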
#[allow(clippy::too_many_arguments)]
pub fn moe_dispatch(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    expert_weights: &[ExpertWeights<'_>],
    routing_weights: &[f32],
    output: &MlxBuffer,
    scratch_gate: &MlxBuffer,
    scratch_up: &MlxBuffer,
    scratch_hidden: &MlxBuffer,
    scratch_expert: &MlxBuffer,
    params: &MoeDispatchParams,
) -> Result<()> {
    // --- Validation ---
    if params.input_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: input_dim must be > 0".into(),
        ));
    }
    if params.intermediate_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: intermediate_dim must be > 0".into(),
        ));
    }
    if params.n_selected == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: n_selected must be > 0".into(),
        ));
    }
    if expert_weights.len() != params.n_selected {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: expert_weights length ({}) must match n_selected ({})",
            expert_weights.len(),
            params.n_selected
        )));
    }
    if routing_weights.len() != params.n_selected {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: routing_weights length ({}) must match n_selected ({})",
            routing_weights.len(),
            params.n_selected
        )));
    }

    // Validate buffer sizes
    let input_bytes = params.input_dim * std::mem::size_of::<f32>();
    if input.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: input buffer too small: need {} bytes, have {}",
            input_bytes,
            input.byte_len()
        )));
    }
    if output.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: output buffer too small: need {} bytes, have {}",
            input_bytes,
            output.byte_len()
        )));
    }

    let intermediate_bytes = params.intermediate_dim * std::mem::size_of::<f32>();
    if scratch_gate.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_gate buffer too small".into(),
        ));
    }
    if scratch_up.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_up buffer too small".into(),
        ));
    }
    if scratch_hidden.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_hidden buffer too small".into(),
        ));
    }
    if scratch_expert.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_expert buffer too small".into(),
        ));
    }

    // Validate expert weight sizes
    let gate_up_bytes = params.input_dim * params.intermediate_dim * std::mem::size_of::<f32>();
    let down_bytes = params.intermediate_dim * params.input_dim * std::mem::size_of::<f32>();

    for (i, ew) in expert_weights.iter().enumerate() {
        if ew.gate_proj.byte_len() < gate_up_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} gate_proj too small: need {} bytes, have {}",
                i, gate_up_bytes, ew.gate_proj.byte_len()
            )));
        }
        if ew.up_proj.byte_len() < gate_up_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} up_proj too small: need {} bytes, have {}",
                i, gate_up_bytes, ew.up_proj.byte_len()
            )));
        }
        if ew.down_proj.byte_len() < down_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} down_proj too small: need {} bytes, have {}",
                i, down_bytes, ew.down_proj.byte_len()
            )));
        }
    }

    // --- Pre-warm the pipeline cache so no compilation happens in the hot loop ---
    // Each get_pipeline call borrows &mut registry, so we cannot hold several
    // returned references at once.  Pre-warming also guarantees that the
    // lookups below are pure cache hits: nothing is inserted into the cache
    // afterwards, so its entries stay put.
    {
        registry.get_pipeline("naive_matvec_f32", device)?;
        registry.get_pipeline("fused_gelu_mul", device)?;
        registry.get_pipeline("moe_accumulate", device)?;
        registry.get_pipeline("zero_buffer", device)?;
    }
    // SAFETY: After pre-warming, the pipelines are in the cache and the HashMap
    // entries will not be moved or removed for the lifetime of `registry`.
    // We convert to raw pointers to avoid the borrow checker's complaint about
    // multiple &mut borrows while holding references into the cache.
    let matvec_pipeline: *const metal::ComputePipelineState = {
        let p = registry.get_pipeline("naive_matvec_f32", device)?;
        p as *const _
    };
    let gelu_mul_pipeline: *const metal::ComputePipelineState = {
        let p = registry.get_pipeline("fused_gelu_mul", device)?;
        p as *const _
    };
    let accum_pipeline: *const metal::ComputePipelineState = {
        let p = registry.get_pipeline("moe_accumulate", device)?;
        p as *const _
    };
    let zero_pipeline: *const metal::ComputePipelineState = {
        let p = registry.get_pipeline("zero_buffer", device)?;
        p as *const _
    };
    // Re-borrow from the raw pointers.  This is safe because:
    //   1. The HashMap entries are stable (no insertions/removals below).
    //   2. The registry outlives these references.
    //   3. We only read through these pointers.
    let matvec_pipeline = unsafe { &*matvec_pipeline };
    let gelu_mul_pipeline = unsafe { &*gelu_mul_pipeline };
    let accum_pipeline = unsafe { &*accum_pipeline };
    let zero_pipeline = unsafe { &*zero_pipeline };

    // --- Zero-initialize output ---
    let zero_params = GpuZeroParams {
        n_elements: params.input_dim as u32,
    };
    encode_with_args(
        encoder,
        zero_pipeline,
        &[
            (0, KernelArg::Buffer(output)),
            (1, KernelArg::Bytes(as_bytes(&zero_params))),
        ],
        MTLSize::new(params.input_dim as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, params.input_dim as u64), 1, 1),
    );

    // --- Loop over selected experts ---
    for (i, ew) in expert_weights.iter().enumerate() {
        let w = routing_weights[i];

        // Skip experts with near-zero routing weight
        if w.abs() < 1e-10 {
            continue;
        }

        // Barrier: the output buffer was just written, by zero_buffer before
        // the first encoded iteration or by the previous accumulate otherwise.
        encoder.memory_barrier();

        // gate_out = gate_proj @ input  (matvec: [intermediate_dim, input_dim] @ [input_dim] -> [intermediate_dim])
        let gate_params = GpuMatmulParams {
            m: 1,
            k: params.input_dim as u32,
            n: params.intermediate_dim as u32,
        };
        encode_with_args(
            encoder,
            matvec_pipeline,
            &[
                (0, KernelArg::Buffer(ew.gate_proj)),
                (1, KernelArg::Buffer(input)),
                (2, KernelArg::Buffer(scratch_gate)),
                (3, KernelArg::Bytes(as_bytes(&gate_params))),
            ],
            MTLSize::new(params.intermediate_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.intermediate_dim as u64), 1, 1),
        );

        // up_out = up_proj @ input  (same shape as gate)
        let up_params = GpuMatmulParams {
            m: 1,
            k: params.input_dim as u32,
            n: params.intermediate_dim as u32,
        };
        encode_with_args(
            encoder,
            matvec_pipeline,
            &[
                (0, KernelArg::Buffer(ew.up_proj)),
                (1, KernelArg::Buffer(input)),
                (2, KernelArg::Buffer(scratch_up)),
                (3, KernelArg::Bytes(as_bytes(&up_params))),
            ],
            MTLSize::new(params.intermediate_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.intermediate_dim as u64), 1, 1),
        );

        // Barrier: gate_out and up_out must complete before gelu_mul reads them.
        encoder.memory_barrier();

        // hidden = GELU(gate_out) * up_out
        let gelu_params = GpuFusedGeluMulParams {
            n_elements: params.intermediate_dim as u32,
        };
        encode_with_args(
            encoder,
            gelu_mul_pipeline,
            &[
                (0, KernelArg::Buffer(scratch_gate)),
                (1, KernelArg::Buffer(scratch_up)),
                (2, KernelArg::Buffer(scratch_hidden)),
                (3, KernelArg::Bytes(as_bytes(&gelu_params))),
            ],
            MTLSize::new(params.intermediate_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.intermediate_dim as u64), 1, 1),
        );

        // Barrier: hidden must complete before down_proj reads it.
        encoder.memory_barrier();

        // expert_out = down_proj @ hidden  (matvec: [input_dim, intermediate_dim] @ [intermediate_dim] -> [input_dim])
        let down_params = GpuMatmulParams {
            m: 1,
            k: params.intermediate_dim as u32,
            n: params.input_dim as u32,
        };
        encode_with_args(
            encoder,
            matvec_pipeline,
            &[
                (0, KernelArg::Buffer(ew.down_proj)),
                (1, KernelArg::Buffer(scratch_hidden)),
                (2, KernelArg::Buffer(scratch_expert)),
                (3, KernelArg::Bytes(as_bytes(&down_params))),
            ],
            MTLSize::new(params.input_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.input_dim as u64), 1, 1),
        );

        // Barrier: expert_out must complete before accumulate reads it.
        encoder.memory_barrier();

        // result += w * expert_out
        let accum_params = GpuMoeAccumParams {
            n_elements: params.input_dim as u32,
            routing_weight: w,
        };
        encode_with_args(
            encoder,
            accum_pipeline,
            &[
                (0, KernelArg::Buffer(output)),
                (1, KernelArg::Buffer(scratch_expert)),
                (2, KernelArg::Bytes(as_bytes(&accum_params))),
            ],
            MTLSize::new(params.input_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.input_dim as u64), 1, 1),
        );
    }

    Ok(())
}

/// Encode a fused SwiGLU on a `[2*N]` gate_up buffer, producing `[N]` output.
///
/// Computes `output[i] = GELU(gate_up[i]) * gate_up[N + i]` for `i in 0..N`.
///
/// Uses the `moe_swiglu_fused` kernel from moe_dispatch.metal.
///
/// # Arguments
/// * `encoder`        -- Command encoder to record into.
/// * `registry`       -- Kernel registry for pipeline lookup.
/// * `device`         -- Metal device reference.
/// * `gate_up`        -- f32 buffer `[2 * n_elements]` (gate || up concatenated).
/// * `output`         -- f32 buffer `[n_elements]` (output).
/// * `n_elements`     -- Number of output elements (intermediate_dim).
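///
/// # Example
///
/// The CPU equivalent of the dispatch, as a sketch (`gelu` stands in for the
/// shader's GELU):
///
/// ```ignore
/// let n = n_elements;
/// for i in 0..n {
///     output[i] = gelu(gate_up[i]) * gate_up[n + i];
/// }
/// ```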
pub fn moe_swiglu_fused_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_fused_encode: n_elements must be > 0".into(),
        ));
    }
    let gu_required = 2 * n_elements * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode: gate_up buffer too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    let out_required = n_elements * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode: output buffer too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_fused", device)?;
    let params = GpuFusedGeluMulParams {
        n_elements: n_elements as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

/// Zero-initialize an f32 GPU buffer using the `zero_buffer` kernel.
///
/// This is useful for preparing an accumulator buffer before dispatching
/// weighted accumulation passes.
///
/// # Arguments
/// * `encoder`    -- Command encoder to record into.
/// * `registry`   -- Kernel registry for pipeline lookup.
/// * `device`     -- Metal device reference.
/// * `output`     -- f32 buffer to zero, must be at least `n_elements * 4` bytes.
/// * `n_elements` -- Number of f32 elements to zero.
pub fn moe_zero_buffer_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    output: &MlxBuffer,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_zero_buffer_encode: n_elements must be > 0".into(),
        ));
    }
    let required = n_elements * std::mem::size_of::<f32>();
    if output.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_zero_buffer_encode: buffer too small: need {} bytes, have {}",
            required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("zero_buffer", device)?;
    let params = GpuZeroParams { n_elements: n_elements as u32 };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(output)),
            (1, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

/// Encode a weighted accumulation: `accumulator[i] += routing_weight * expert_output[i]`.
///
/// Uses the `moe_accumulate` kernel from moe_dispatch.metal.
///
/// # Arguments
/// * `encoder`        -- Command encoder to record into.
/// * `registry`       -- Kernel registry for pipeline lookup.
/// * `device`         -- Metal device reference.
/// * `accumulator`    -- f32 buffer `[n_elements]`, in/out.
/// * `expert_output`  -- f32 buffer `[n_elements]`, input.
/// * `routing_weight` -- Scalar weight for this expert.
/// * `n_elements`     -- Number of f32 elements.
pub fn moe_accumulate_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    accumulator: &MlxBuffer,
    expert_output: &MlxBuffer,
    routing_weight: f32,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_accumulate_encode: n_elements must be > 0".into(),
        ));
    }
    let required = n_elements * std::mem::size_of::<f32>();
    if accumulator.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode: accumulator too small: need {} bytes, have {}",
            required, accumulator.byte_len()
        )));
    }
    if expert_output.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode: expert_output too small: need {} bytes, have {}",
            required, expert_output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_accumulate", device)?;
    let params = GpuMoeAccumParams {
        n_elements: n_elements as u32,
        routing_weight,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(accumulator)),
            (1, KernelArg::Buffer(expert_output)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

/// Encode a batched SwiGLU across all top_k expert slots in one dispatch.
///
/// Takes a `[top_k, 2*intermediate]` gate_up buffer and produces a
/// `[top_k, intermediate]` output: `GELU(gate[i]) * up[i]` per slot.
///
/// Replaces top_k separate `moe_swiglu_fused_encode_offset` dispatches with a
/// single one.
///
/// # Arguments
/// * `encoder`       -- Command encoder to record into.
/// * `registry`      -- Kernel registry for pipeline lookup.
/// * `device`        -- Metal device reference.
/// * `gate_up`       -- f32 buffer `[top_k * 2 * intermediate]`.
/// * `output`        -- f32 buffer `[top_k * intermediate]`.
/// * `intermediate`  -- Intermediate dimension per expert.
/// * `top_k`         -- Number of selected expert slots.
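///
/// # Example
///
/// CPU equivalent, as a sketch (assuming each slot's row is laid out as gate
/// followed by up, matching `moe_swiglu_fused`; `gelu` stands in for the
/// shader's GELU):
///
/// ```ignore
/// for e in 0..top_k {
///     let row = &gate_up[e * 2 * intermediate..(e + 1) * 2 * intermediate];
///     for i in 0..intermediate {
///         output[e * intermediate + i] = gelu(row[i]) * row[intermediate + i];
///     }
/// }
/// ```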
pub fn moe_swiglu_batch_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
) -> Result<()> {
    if intermediate == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_batch_encode: intermediate and top_k must be > 0".into(),
        ));
    }
    let gu_required = top_k * 2 * intermediate * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_batch_encode: gate_up too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    let out_required = top_k * intermediate * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_batch_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_batch", device)?;
    let intermediate_bytes = (intermediate as u32).to_ne_bytes();
    let top_k_bytes = (top_k as u32).to_ne_bytes();

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(&intermediate_bytes)),
            (3, KernelArg::Bytes(&top_k_bytes)),
        ],
        MTLSize::new(intermediate as u64, top_k as u64, 1),
        MTLSize::new(std::cmp::min(256, intermediate as u64), 1, 1),
    );
    Ok(())
}

/// MSL-compatible struct for the moe_weighted_sum kernel.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeWeightedSumParams {
    hidden_size: u32,
    top_k: u32,
}

/// Encode a weighted sum of all top_k expert outputs in one dispatch.
///
/// Replaces the zero_buffer + top_k * moe_accumulate pattern with a single
/// dispatch.  The weights buffer must contain pre-scaled routing weights for
/// all top_k experts (i.e. `routing_weight * per_expert_scale`).
///
/// # Arguments
/// * `encoder`        -- Command encoder to record into.
/// * `registry`       -- Kernel registry for pipeline lookup.
/// * `device`         -- Metal device reference.
/// * `expert_outputs` -- f32 buffer `[top_k * hidden_size]`.
/// * `weights`        -- f32 buffer `[top_k]` (pre-scaled routing weights).
/// * `output`         -- f32 buffer `[hidden_size]` (output weighted sum).
/// * `hidden_size`    -- Hidden dimension.
/// * `top_k`          -- Number of selected expert slots.
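///
/// # Example
///
/// CPU equivalent, as a sketch:
///
/// ```ignore
/// for d in 0..hidden_size {
///     output[d] = (0..top_k)
///         .map(|e| weights[e] * expert_outputs[e * hidden_size + d])
///         .sum::<f32>();
/// }
/// ```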
pub fn moe_weighted_sum_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_encode: hidden_size and top_k must be > 0".into(),
        ));
    }
    let expert_required = top_k * hidden_size * std::mem::size_of::<f32>();
    if expert_outputs.byte_len() < expert_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: expert_outputs too small: need {} bytes, have {}",
            expert_required, expert_outputs.byte_len()
        )));
    }
    let weights_required = top_k * std::mem::size_of::<f32>();
    if weights.byte_len() < weights_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: weights too small: need {} bytes, have {}",
            weights_required, weights.byte_len()
        )));
    }
    let out_required = hidden_size * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_weighted_sum", device)?;
    let params = GpuMoeWeightedSumParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(weights)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(hidden_size as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, hidden_size as u64), 1, 1),
    );
    Ok(())
}

/// MSL-compatible struct for the moe_gather_topk_weights kernel.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeGatherTopkParams {
    n_experts: u32,
    top_k: u32,
}

/// Encode a GPU-side MoE top-K routing gather.
///
/// Reads softmax probs and sorted indices (both on GPU from prior dispatches),
/// gathers the top-K expert IDs and their weights, applies per_expert_scale,
/// and renormalizes.  This eliminates the CPU readback that previously forced
/// a session break between S1 and S4.
///
/// # Arguments
/// * `encoder`          -- Command encoder.
/// * `registry`         -- Kernel registry.
/// * `device`           -- Metal device.
/// * `softmax_probs`    -- f32 `[n_experts]` (output of dispatch_softmax).
/// * `sorted_indices`   -- u32 `[n_experts]` (output of dispatch_argsort_desc_f32).
/// * `per_expert_scale` -- f32 `[n_experts]` (learned per-expert scale).
/// * `out_expert_ids`   -- u32 `[top_k]` (output: selected expert indices).
/// * `out_weights`      -- f32 `[top_k]` (output: pre-scaled routing weights).
/// * `n_experts`        -- Total number of experts.
/// * `top_k`            -- Number of experts to select.
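///
/// # Example
///
/// A CPU sketch of what the single-thread kernel computes, following the
/// order stated above (gather, scale, then renormalize); the shader in
/// moe_dispatch.metal is authoritative:
///
/// ```ignore
/// let mut sum = 0.0f32;
/// for k in 0..top_k {
///     let e = sorted_indices[k] as usize;
///     out_expert_ids[k] = e as u32;
///     out_weights[k] = softmax_probs[e] * per_expert_scale[e];
///     sum += out_weights[k];
/// }
/// for w in out_weights.iter_mut() {
///     *w /= sum; // renormalize over the selected experts
/// }
/// ```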
pub fn moe_gather_topk_weights_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    softmax_probs: &MlxBuffer,
    sorted_indices: &MlxBuffer,
    per_expert_scale: &MlxBuffer,
    out_expert_ids: &MlxBuffer,
    out_weights: &MlxBuffer,
    n_experts: usize,
    top_k: usize,
) -> Result<()> {
    if n_experts == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_gather_topk_weights: n_experts and top_k must be > 0".into(),
        ));
    }
    if top_k > n_experts {
        return Err(MlxError::InvalidArgument(format!(
            "moe_gather_topk_weights: top_k ({}) > n_experts ({})",
            top_k, n_experts,
        )));
    }
    if top_k > 8 {
        return Err(MlxError::InvalidArgument(format!(
            "moe_gather_topk_weights: top_k ({}) > 8 (shader fixed-size array limit)",
            top_k,
        )));
    }

    let f32_size = std::mem::size_of::<f32>();
    let u32_size = std::mem::size_of::<u32>();
    if softmax_probs.byte_len() < n_experts * f32_size {
        return Err(MlxError::InvalidArgument("softmax_probs too small".into()));
    }
    if sorted_indices.byte_len() < n_experts * u32_size {
        return Err(MlxError::InvalidArgument("sorted_indices too small".into()));
    }
    if per_expert_scale.byte_len() < n_experts * f32_size {
        return Err(MlxError::InvalidArgument("per_expert_scale too small".into()));
    }
    if out_expert_ids.byte_len() < top_k * u32_size {
        return Err(MlxError::InvalidArgument("out_expert_ids too small".into()));
    }
    if out_weights.byte_len() < top_k * f32_size {
        return Err(MlxError::InvalidArgument("out_weights too small".into()));
    }

    let pipeline = registry.get_pipeline("moe_gather_topk_weights", device)?;
    let params = GpuMoeGatherTopkParams {
        n_experts: n_experts as u32,
        top_k: top_k as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(softmax_probs)),
            (1, KernelArg::Buffer(sorted_indices)),
            (2, KernelArg::Buffer(per_expert_scale)),
            (3, KernelArg::Buffer(out_expert_ids)),
            (4, KernelArg::Buffer(out_weights)),
            (5, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(1, 1, 1),  // single thread
        MTLSize::new(1, 1, 1),
    );
    Ok(())
}

/// Like [`moe_swiglu_fused_encode`] but reads from `gate_up` at `gu_byte_offset`
/// and writes to `output` at `out_byte_offset`.
///
/// This enables operating on slices within larger buffers (e.g. the _id kernel
/// output which contains top_k rows of gate_up data).
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_fused_encode_offset(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    gu_byte_offset: usize,
    output: &MlxBuffer,
    out_byte_offset: usize,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_fused_encode_offset: n_elements must be > 0".into(),
        ));
    }
    let gu_required = gu_byte_offset + 2 * n_elements * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode_offset: gate_up buffer too small: need {} bytes (offset {}), have {}",
            gu_required, gu_byte_offset, gate_up.byte_len()
        )));
    }
    let out_required = out_byte_offset + n_elements * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode_offset: output buffer too small: need {} bytes (offset {}), have {}",
            out_required, out_byte_offset, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_fused", device)?;
    let params = GpuFusedGeluMulParams {
        n_elements: n_elements as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::BufferWithOffset(gate_up, gu_byte_offset as u64)),
            (1, KernelArg::BufferWithOffset(output, out_byte_offset as u64)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

/// Like [`moe_accumulate_encode`] but reads `expert_output` from `src_byte_offset`.
///
/// This enables reading from a slice within a larger buffer (e.g. the down _id
/// kernel output which contains top_k rows of hidden data).
#[allow(clippy::too_many_arguments)]
pub fn moe_accumulate_encode_offset(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    accumulator: &MlxBuffer,
    expert_output: &MlxBuffer,
    src_byte_offset: usize,
    routing_weight: f32,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_accumulate_encode_offset: n_elements must be > 0".into(),
        ));
    }
    let required = n_elements * std::mem::size_of::<f32>();
    if accumulator.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode_offset: accumulator too small: need {} bytes, have {}",
            required, accumulator.byte_len()
        )));
    }
    let src_required = src_byte_offset + required;
    if expert_output.byte_len() < src_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode_offset: expert_output too small: need {} bytes (offset {}), have {}",
            src_required, src_byte_offset, expert_output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_accumulate", device)?;
    let params = GpuMoeAccumParams {
        n_elements: n_elements as u32,
        routing_weight,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(accumulator)),
            (1, KernelArg::BufferWithOffset(expert_output, src_byte_offset as u64)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

/// GPU params for batched SwiGLU over multiple tokens.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeSwigluSeqParams {
    intermediate: u32,
    top_k: u32,
    n_tokens: u32,
}

/// Multi-token SwiGLU for batched prefill.
///
/// Input:  `[n_tokens, top_k, 2*intermediate]`
/// Output: `[n_tokens, top_k, intermediate]`
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_seq_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if intermediate == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_seq_encode: all dims must be > 0".into(),
        ));
    }
    let gu_required = n_tokens * top_k * 2 * intermediate * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_encode: gate_up too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    let out_required = n_tokens * top_k * intermediate * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_seq", device)?;
    let gpu_params = GpuMoeSwigluSeqParams {
        intermediate: intermediate as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(intermediate as u64, top_k as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, intermediate as u64), 1, 1),
    );
    Ok(())
}

/// GPU params for batched weighted-sum over multiple tokens.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeWeightedSumSeqParams {
    hidden_size: u32,
    top_k: u32,
    n_tokens: u32,
}

/// Multi-token weighted sum of expert outputs for batched prefill.
///
/// * `expert_outputs` -- `[n_tokens, top_k, hidden_size]`
/// * `weights`        -- `[n_tokens, top_k]`
/// * `output`         -- `[n_tokens, hidden_size]`
#[allow(clippy::too_many_arguments)]
pub fn moe_weighted_sum_seq_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_seq_encode: all dims must be > 0".into(),
        ));
    }
    let expert_required = n_tokens * top_k * hidden_size * std::mem::size_of::<f32>();
    if expert_outputs.byte_len() < expert_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_encode: expert_outputs too small: need {} bytes, have {}",
            expert_required, expert_outputs.byte_len()
        )));
    }
    let weights_required = n_tokens * top_k * std::mem::size_of::<f32>();
    if weights.byte_len() < weights_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_encode: weights too small: need {} bytes, have {}",
            weights_required, weights.byte_len()
        )));
    }
    let out_required = n_tokens * hidden_size * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_weighted_sum_seq", device)?;
    let gpu_params = GpuMoeWeightedSumSeqParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(weights)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(hidden_size as u64, n_tokens as u64, 1),
        MTLSize::new(std::cmp::min(256, hidden_size as u64), 1, 1),
    );
    Ok(())
}