

Function moe_accumulate_encode 

pub fn moe_accumulate_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    accumulator: &MlxBuffer,
    expert_output: &MlxBuffer,
    routing_weight: f32,
    n_elements: usize,
) -> Result<()>

Encode a weighted accumulation: for each i in 0..n_elements, accumulator[i] += routing_weight * expert_output[i].

Uses the moe_accumulate kernel from moe_dispatch.metal.

§Arguments

  • encoder — Command encoder to record into.
  • registry — Kernel registry for pipeline lookup.
  • device — Metal device reference.
  • accumulator — f32 buffer [n_elements], in/out.
  • expert_output — f32 buffer [n_elements], input.
  • routing_weight — Scalar weight for this expert.
  • n_elements — Number of f32 elements.
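
As an illustration of the arithmetic only, here is a minimal CPU sketch of the same update on plain slices. The names moe_accumulate_reference and combine_experts_reference are hypothetical and not part of this crate. The sketch does not touch CommandEncoder, KernelRegistry, or MlxBuffer, and it assumes the kernel performs exactly the element-wise update stated in the description above.

```rust
/// Hypothetical CPU reference (not part of this crate) for the update
/// accumulator[i] += routing_weight * expert_output[i], for i in 0..n_elements.
pub fn moe_accumulate_reference(
    accumulator: &mut [f32],
    expert_output: &[f32],
    routing_weight: f32,
    n_elements: usize,
) {
    // Both buffers must hold at least n_elements values.
    assert!(accumulator.len() >= n_elements);
    assert!(expert_output.len() >= n_elements);

    for (acc, &out) in accumulator[..n_elements]
        .iter_mut()
        .zip(&expert_output[..n_elements])
    {
        *acc += routing_weight * out;
    }
}

/// Hypothetical helper: combine several expert outputs into one buffer by
/// applying the accumulation once per expert, starting from zeros.
pub fn combine_experts_reference(
    expert_outputs: &[Vec<f32>],
    routing_weights: &[f32],
    n_elements: usize,
) -> Vec<f32> {
    assert_eq!(expert_outputs.len(), routing_weights.len());

    let mut accumulator = vec![0.0f32; n_elements];
    for (output, &weight) in expert_outputs.iter().zip(routing_weights) {
        moe_accumulate_reference(&mut accumulator, output, weight, n_elements);
    }
    accumulator
}
```

A reference like this can be useful for checking GPU results in a test. It also shows why accumulator is in/out: calling the update once per selected expert builds up the weighted sum of their outputs, so the buffer would typically be zeroed before the first call.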