Function moe_weighted_sum_encode 

pub fn moe_weighted_sum_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
) -> Result<()>

Encode a weighted sum of all top_k expert outputs in one dispatch.

Replaces the zero_buffer + top_k * moe_accumulate pattern (one zero-fill followed by top_k accumulate dispatches) with a single dispatch. The weights buffer must contain pre-scaled routing weights for all top_k experts, i.e. each entry is routing_weight * per_expert_scale.
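
As a rough illustration of what "pre-scaled" could mean on the host side, the sketch below multiplies each routing weight by the scale of the expert chosen for that slot. The function name `prescale_weights`, the `expert_ids` slice, and the idea that scales are stored per expert and gathered per slot are all assumptions made for this example; they are not part of this crate's API.

```rust
/// Hypothetical helper (not part of this crate): builds the `weights`
/// contents expected by the kernel, one entry per selected expert slot.
///
/// Assumption: `per_expert_scale` is indexed by expert id, and
/// `expert_ids[k]` is the expert selected for slot `k`.
fn prescale_weights(
    routing_weights: &[f32],
    per_expert_scale: &[f32],
    expert_ids: &[usize],
) -> Vec<f32> {
    assert_eq!(routing_weights.len(), expert_ids.len());
    routing_weights
        .iter()
        .zip(expert_ids)
        // weights[k] = routing_weight[k] * per_expert_scale[expert_ids[k]]
        .map(|(&w, &id)| w * per_expert_scale[id])
        .collect()
}
```

The resulting vector has length top_k and would then be uploaded into the weights buffer by whatever mechanism the caller already uses.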

§Arguments

  • encoder – Command encoder to record into.
  • registry – Kernel registry for pipeline lookup.
  • device – Metal device reference.
  • expert_outputs – f32 buffer [top_k * hidden_size].
  • weights – f32 buffer [top_k] (pre-scaled routing weights).
  • output – f32 buffer [hidden_size] (output weighted sum).
  • hidden_size – Hidden dimension.
  • top_k – Number of selected expert slots.
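
The computation described by the arguments above can be sketched as a CPU reference, which may be useful for checking the GPU result in tests. This is a minimal sketch, not the kernel itself: the name `moe_weighted_sum_reference` is hypothetical, and it assumes that expert_outputs is laid out slot-major, so that slot k occupies elements k * hidden_size through (k + 1) * hidden_size - 1.

```rust
/// Hypothetical CPU reference (not part of this crate):
/// output[i] = sum over k of weights[k] * expert_outputs[k * hidden_size + i].
///
/// Assumption: `expert_outputs` is slot-major, one contiguous row of
/// `hidden_size` values per selected expert slot.
fn moe_weighted_sum_reference(
    expert_outputs: &[f32],
    weights: &[f32],
    hidden_size: usize,
    top_k: usize,
) -> Vec<f32> {
    assert_eq!(expert_outputs.len(), top_k * hidden_size);
    assert_eq!(weights.len(), top_k);

    // Starting from zeros plays the role of the separate zero-fill step.
    let mut output = vec![0.0f32; hidden_size];
    for (k, &w) in weights.iter().enumerate() {
        let row = &expert_outputs[k * hidden_size..(k + 1) * hidden_size];
        // Accumulate this slot's contribution, element by element.
        for (o, &x) in output.iter_mut().zip(row) {
            *o += w * x;
        }
    }
    output
}
```

With two slots, a hidden size of 3, rows [1, 2, 3] and [10, 20, 30], and weights [0.5, 0.25], this returns [3.0, 6.0, 9.0].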