Function moe_dispatch 

pub fn moe_dispatch(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    input: &MlxBuffer,
    expert_weights: &[ExpertWeights<'_>],
    routing_weights: &[f32],
    output: &MlxBuffer,
    scratch_gate: &MlxBuffer,
    scratch_up: &MlxBuffer,
    scratch_hidden: &MlxBuffer,
    scratch_expert: &MlxBuffer,
    params: &MoeDispatchParams,
) -> Result<()>

Encode MoE dispatch: loop over selected experts, run FFN, accumulate.

This is the Stage 1 implementation: it loops over each selected expert and dispatches a separate compute pass for each projection.

§Buffer expectations

  • input — f32, [input_dim] (single token hidden state)
  • expert_weights — slice of n_selected expert weight structs, each containing gate_proj, up_proj, down_proj as f32 buffers
  • routing_weights — f32, [n_selected] (softmax routing weights from moe_gate)
  • output — f32, [input_dim] (result; zero-initialized by this call before expert outputs are accumulated)
  • scratch_gate — f32, [intermediate_dim] (scratch buffer for gate_proj output)
  • scratch_up — f32, [intermediate_dim] (scratch buffer for up_proj output)
  • scratch_hidden — f32, [intermediate_dim] (scratch buffer for GELU*up output)
  • scratch_expert — f32, [input_dim] (scratch buffer for down_proj output)

§Design Notes

The caller provides scratch buffers to avoid allocating inside the encoding loop; these can come from an MlxBufferPool.

For the matrix projections, this uses a naive matmul kernel (single token, M=1, so each projection is effectively a matvec). The quantized_matmul kernel from Story 1.2 would take over when the weights remain quantized; Stage 1 assumes f32 weights for simplicity.
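The per-expert loop can be sketched on the CPU as plain Rust. This is a minimal reference of the computation the encoded kernels perform, not the crate's API: the names (`Expert`, `moe_dispatch_ref`) and the tanh-approximation GELU variant are illustrative assumptions.

```rust
// CPU reference sketch of the Stage 1 MoE dispatch loop (assumed names).
// Weight matrices are row-major f32 slices, matching the buffer shapes above.

fn matvec(w: &[f32], x: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    // Naive matmul with M = 1: one dot product per output row.
    (0..rows)
        .map(|r| (0..cols).map(|c| w[r * cols + c] * x[c]).sum())
        .collect()
}

fn gelu(x: f32) -> f32 {
    // tanh approximation of GELU (assumed variant).
    0.5 * x
        * (1.0
            + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x * x * x)).tanh())
}

struct Expert {
    gate_proj: Vec<f32>, // [intermediate_dim, input_dim]
    up_proj: Vec<f32>,   // [intermediate_dim, input_dim]
    down_proj: Vec<f32>, // [input_dim, intermediate_dim]
}

fn moe_dispatch_ref(
    input: &[f32],
    experts: &[Expert],
    routing_weights: &[f32],
    input_dim: usize,
    intermediate_dim: usize,
) -> Vec<f32> {
    // Output is zero-initialized, as documented above.
    let mut output = vec![0.0f32; input_dim];
    for (expert, &w) in experts.iter().zip(routing_weights) {
        // gate_proj and up_proj: [intermediate_dim] each.
        let gate = matvec(&expert.gate_proj, input, intermediate_dim, input_dim);
        let up = matvec(&expert.up_proj, input, intermediate_dim, input_dim);
        // hidden = GELU(gate) * up, elementwise.
        let hidden: Vec<f32> =
            gate.iter().zip(&up).map(|(&g, &u)| gelu(g) * u).collect();
        // down_proj back to [input_dim], then accumulate by routing weight.
        let expert_out = matvec(&expert.down_proj, &hidden, input_dim, intermediate_dim);
        for (o, e) in output.iter_mut().zip(&expert_out) {
            *o += w * e;
        }
    }
    output
}
```

The GPU version replaces each `matvec`, the elementwise GELU·up step, and the weighted accumulation with individual compute passes, using the four scratch buffers in place of the intermediate `Vec`s.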

§Errors

Returns MlxError::InvalidArgument if parameters are invalid.