pub fn moe_weighted_sum_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
) -> Result<()>
Encode a weighted sum of all top_k expert outputs in a single dispatch.
Replaces the zero_buffer + top_k * moe_accumulate pattern (1 + top_k dispatches) with one dispatch.
The weights buffer must contain pre-scaled routing weights for all top_k
experts (i.e. routing_weight * per_expert_scale).
Arguments
- encoder – Command encoder to record into.
- registry – Kernel registry for pipeline lookup.
- device – Metal device reference.
- expert_outputs – f32 buffer [top_k * hidden_size].
- weights – f32 buffer [top_k] (pre-scaled routing weights).
- output – f32 buffer [hidden_size] (output weighted sum).
- hidden_size – Hidden dimension.
- top_k – Number of selected expert slots.
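
To make the buffer shapes concrete, here is a minimal CPU sketch of the result the dispatch is described as producing. It is not the crate's kernel. The helper name moe_weighted_sum_reference is hypothetical, and the sketch assumes expert_outputs is laid out slot by slot, so that slot k occupies elements k * hidden_size up to (k + 1) * hidden_size. The documentation above gives only the total length, so check that layout against the actual kernel.

```rust
/// CPU reference for the weighted sum (hypothetical helper, not part of the crate).
///
/// Assumed layout: slot `k` of `expert_outputs` occupies the contiguous range
/// `[k * hidden_size, (k + 1) * hidden_size)`.
fn moe_weighted_sum_reference(
    expert_outputs: &[f32],
    weights: &[f32],
    hidden_size: usize,
    top_k: usize,
) -> Vec<f32> {
    // Shapes mirror the documented buffers: [top_k * hidden_size] and [top_k].
    assert_eq!(expert_outputs.len(), top_k * hidden_size);
    assert_eq!(weights.len(), top_k);

    // Start from zero, which is what the separate zero_buffer pass used to do.
    let mut output = vec![0.0f32; hidden_size];
    if hidden_size == 0 {
        // `chunks_exact` panics on a chunk size of zero.
        return output;
    }

    // output[j] = sum over k of weights[k] * expert_outputs[k * hidden_size + j]
    for (slot, &w) in expert_outputs.chunks_exact(hidden_size).zip(weights) {
        for (out, &x) in output.iter_mut().zip(slot) {
            *out += w * x;
        }
    }
    output
}
```

A reference like this can be useful for checking the GPU output in a test: fill the three buffers, run the encoded dispatch, and compare against the CPU result within a small tolerance.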