pub fn moe_weighted_sum_seq_encode(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &DeviceRef,
expert_outputs: &MlxBuffer,
weights: &MlxBuffer,
output: &MlxBuffer,
hidden_size: usize,
top_k: usize,
n_tokens: usize,
) -> Result<()>Expand description
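Multi-token weighted sum of expert outputs for batched prefill.

Buffer shapes:
- `expert_outputs` — `[n_tokens, top_k, hidden_size]`
- `weights` — `[n_tokens, top_k]`
- `output` — `[n_tokens, hidden_size]`

For each token `t` and hidden index `h`: `output[t, h] = Σ_k weights[t, k] · expert_outputs[t, k, h]`, with `k` ranging over `0..top_k`.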