Function moe_weighted_sum_seq_encode 

pub fn moe_weighted_sum_seq_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()>

Multi-token weighted sum of expert outputs for batched prefill.

  • expert_outputs: [n_tokens, top_k, hidden_size]
  • weights: [n_tokens, top_k]
  • output: [n_tokens, hidden_size]
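Read together, the three shapes suggest that each output row is the weighted sum of that token's top_k expert rows, i.e. output[t, h] = Σ_k weights[t, k] · expert_outputs[t, k, h]. The following is a minimal CPU sketch of that computation, not the kernel this function dispatches. It assumes contiguous row-major f32 data; the element type and memory layout of the real buffers are not stated here, and the function name moe_weighted_sum_seq_reference is hypothetical.

```rust
/// Hypothetical CPU reference for the weighted sum described above.
/// Assumes contiguous row-major f32 data (an assumption, not taken from the docs).
fn moe_weighted_sum_seq_reference(
    expert_outputs: &[f32], // [n_tokens, top_k, hidden_size]
    weights: &[f32],        // [n_tokens, top_k]
    hidden_size: usize,
    top_k: usize,
    n_tokens: usize,
) -> Vec<f32> {
    assert_eq!(expert_outputs.len(), n_tokens * top_k * hidden_size);
    assert_eq!(weights.len(), n_tokens * top_k);

    // output has shape [n_tokens, hidden_size], initialised to zero.
    let mut output = vec![0.0f32; n_tokens * hidden_size];

    for t in 0..n_tokens {
        let dst = t * hidden_size;
        for k in 0..top_k {
            // Routing weight of the k-th selected expert for token t.
            let w = weights[t * top_k + k];
            // Start of that expert's output row for token t.
            let src = (t * top_k + k) * hidden_size;
            for h in 0..hidden_size {
                output[dst + h] += w * expert_outputs[src + h];
            }
        }
    }
    output
}
```

A reference like this can be useful for checking GPU results on small inputs, for example two tokens with top_k = 2 and hidden_size = 2.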