pub fn moe_accumulate_encode(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &DeviceRef,
accumulator: &MlxBuffer,
expert_output: &MlxBuffer,
routing_weight: f32,
n_elements: usize,
) -> Result<()>
Encode a weighted accumulation: accumulator[i] += routing_weight * expert_output[i].
Uses the moe_accumulate kernel from moe_dispatch.metal.
§Arguments
- encoder: Command encoder to record into.
- registry: Kernel registry for pipeline lookup.
- device: Metal device reference.
- accumulator: f32 buffer [n_elements], in/out.
- expert_output: f32 buffer [n_elements], input.
- routing_weight: Scalar weight for this expert.
- n_elements: Number of f32 elements.
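
The arithmetic the kernel is documented to perform can be sketched on the CPU, which may help when checking GPU results in a test. This is a minimal sketch, not the crate's implementation: `moe_accumulate_reference` is a hypothetical helper that works on plain slices rather than `MlxBuffer`, and it assumes both inputs hold at least `n_elements` values.

```rust
/// Hypothetical CPU reference for the documented update
/// `accumulator[i] += routing_weight * expert_output[i]`.
/// Not part of the crate; operates on plain slices instead of `MlxBuffer`.
pub fn moe_accumulate_reference(
    accumulator: &mut [f32],
    expert_output: &[f32],
    routing_weight: f32,
    n_elements: usize,
) {
    // Assumption: both slices hold at least `n_elements` values.
    assert!(accumulator.len() >= n_elements);
    assert!(expert_output.len() >= n_elements);

    // Update the first `n_elements` entries in place.
    for (acc, out) in accumulator[..n_elements]
        .iter_mut()
        .zip(&expert_output[..n_elements])
    {
        *acc += routing_weight * *out;
    }
}
```

Calling such a helper once per expert, each time with that expert's output and routing weight, leaves the accumulator holding the weighted sum of the expert outputs added to its initial contents.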