pub fn moe_dispatch(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &DeviceRef,
input: &MlxBuffer,
expert_weights: &[ExpertWeights<'_>],
routing_weights: &[f32],
output: &MlxBuffer,
scratch_gate: &MlxBuffer,
scratch_up: &MlxBuffer,
scratch_hidden: &MlxBuffer,
scratch_expert: &MlxBuffer,
params: &MoeDispatchParams,
) -> Result<()>
Encode MoE dispatch: loop over selected experts, run FFN, accumulate.
This is the Stage 1 implementation that loops over each selected expert and dispatches individual compute passes for each projection.
§Buffer expectations
- `input` — f32, `[input_dim]` (single token hidden state)
- `expert_weights` — slice of `n_selected` expert weight structs, each containing `gate_proj`, `up_proj`, `down_proj` as f32 buffers
- `routing_weights` — f32, `[n_selected]` (softmax routing weights from `moe_gate`)
- `output` — f32, `[input_dim]` (output, will be zero-initialized)
- `scratch_gate` — f32, `[intermediate_dim]` (scratch buffer for `gate_proj` output)
- `scratch_up` — f32, `[intermediate_dim]` (scratch buffer for `up_proj` output)
- `scratch_hidden` — f32, `[intermediate_dim]` (scratch buffer for GELU*up output)
- `scratch_expert` — f32, `[input_dim]` (scratch buffer for `down_proj` output)
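The per-expert math these buffers stage can be sketched as a plain CPU reference, useful for checking the GPU path. This is illustrative only: `Expert`, `moe_dispatch_ref`, the row-major layout, and the tanh GELU approximation are assumptions of this sketch, not part of the actual API.

```rust
// CPU reference of what moe_dispatch encodes on the GPU for one token.
// Hypothetical names; layouts assumed row-major.

fn matvec(w: &[f32], x: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    // w is row-major [rows, cols]; M = 1, so each projection is a matvec.
    (0..rows)
        .map(|r| (0..cols).map(|c| w[r * cols + c] * x[c]).sum())
        .collect()
}

fn gelu(x: f32) -> f32 {
    // tanh approximation of GELU (the exact variant used is an assumption).
    0.5 * x
        * (1.0
            + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x * x * x)).tanh())
}

/// One expert's float weights, mirroring the role of ExpertWeights.
struct Expert {
    gate_proj: Vec<f32>, // [intermediate_dim, input_dim]
    up_proj: Vec<f32>,   // [intermediate_dim, input_dim]
    down_proj: Vec<f32>, // [input_dim, intermediate_dim]
}

fn moe_dispatch_ref(
    input: &[f32],
    experts: &[Expert],
    routing_weights: &[f32],
    input_dim: usize,
    intermediate_dim: usize,
) -> Vec<f32> {
    let mut output = vec![0.0f32; input_dim]; // output starts zeroed
    for (expert, &w) in experts.iter().zip(routing_weights) {
        // The four scratch buffers correspond to these four intermediates.
        let gate = matvec(&expert.gate_proj, input, intermediate_dim, input_dim);
        let up = matvec(&expert.up_proj, input, intermediate_dim, input_dim);
        let hidden: Vec<f32> =
            gate.iter().zip(&up).map(|(g, u)| gelu(*g) * u).collect();
        let expert_out = matvec(&expert.down_proj, &hidden, input_dim, intermediate_dim);
        // Accumulate, weighted by the softmax routing weight.
        for (o, e) in output.iter_mut().zip(&expert_out) {
            *o += w * e;
        }
    }
    output
}
```

Because the accumulation is linear in the routing weights, splitting one expert's weight across two identical experts leaves the result unchanged, which makes a cheap sanity check for the GPU encoder.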
§Design Notes
The caller provides scratch buffers to avoid allocating inside the
encoding loop. These can come from an MlxBufferPool.
For the matrix projections we use a naive matmul kernel; with a single token (M = 1) this is effectively a matvec. The quantized_matmul kernel from Story 1.2 would be used when weights remain quantized; Stage 1 assumes float weights for simplicity.
§Errors
Returns MlxError::InvalidArgument if parameters are invalid.