Function moe_swiglu_batch_encode 

pub fn moe_swiglu_batch_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
) -> Result<()>

Encode a batched SwiGLU across all top_k expert slots in one dispatch.

Takes a gate_up buffer laid out as [top_k, 2 * intermediate] and produces an output buffer of shape [top_k, intermediate], computing GELU(gate[i]) * up[i] for each slot i.
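The per-slot computation can be sketched as a CPU reference, which may be useful for checking GPU output in a test. This is a minimal sketch under two assumptions the description above does not confirm: that each row of gate_up stores its intermediate gate values before its intermediate up values, and that GELU is the tanh approximation. The function names are hypothetical and are not part of this crate.

```rust
/// GELU, tanh approximation (ASSUMPTION: the kernel may use a different variant).
fn gelu_tanh(x: f32) -> f32 {
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    0.5 * x * (1.0 + (SQRT_2_OVER_PI * (x + 0.044_715 * x * x * x)).tanh())
}

/// Hypothetical CPU reference: for each of the `top_k` rows, computes
/// GELU(gate[j]) * up[j] for j in 0..intermediate.
/// ASSUMPTION: each row of `gate_up` is [gate (intermediate), up (intermediate)].
fn batched_gated_gelu_reference(gate_up: &[f32], intermediate: usize, top_k: usize) -> Vec<f32> {
    assert!(intermediate > 0, "intermediate must be non-zero for chunking");
    assert_eq!(gate_up.len(), top_k * 2 * intermediate, "gate_up has the wrong length");

    let mut out = vec![0.0f32; top_k * intermediate];
    // Walk input rows (2 * intermediate wide) and output rows (intermediate wide) in lockstep.
    for (row, dst) in gate_up
        .chunks_exact(2 * intermediate)
        .zip(out.chunks_exact_mut(intermediate))
    {
        let (gate, up) = row.split_at(intermediate);
        for ((d, &g), &u) in dst.iter_mut().zip(gate).zip(up) {
            *d = gelu_tanh(g) * u;
        }
    }
    out
}

fn main() {
    // Two slots (top_k = 2), intermediate = 2: rows are [g0, g1, u0, u1].
    let gate_up = [0.0, 1.0, 5.0, 7.0, 2.0, -1.0, 3.0, 4.0];
    let out = batched_gated_gelu_reference(&gate_up, 2, 2);
    println!("{:?}", out);
}
```

A single GPU dispatch can cover every element of this loop nest at once, which is what lets the batched encoder replace one dispatch per slot.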

This replaces top_k separate moe_swiglu_fused_encode_offset dispatches with a single dispatch.

Arguments

  • encoder – Command encoder to record into.
  • registry – Kernel registry for pipeline lookup.
  • device – Metal device reference.
  • gate_up – f32 buffer [top_k * 2 * intermediate].
  • output – f32 buffer [top_k * intermediate].
  • intermediate – Intermediate dimension per expert.
  • top_k – Number of selected expert slots.
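The buffer sizes implied by the arguments above can be computed up front, for example to validate buffers before encoding. This is a minimal sketch: the helper and struct names are hypothetical and not part of this crate, and it assumes tightly packed f32 elements with no padding.

```rust
/// Element and byte counts implied by the documented buffer shapes.
#[derive(Debug, PartialEq)]
struct BatchShapes {
    gate_up_elems: usize,
    output_elems: usize,
    gate_up_bytes: usize,
    output_bytes: usize,
}

/// Hypothetical helper: sizes for gate_up [top_k * 2 * intermediate] and
/// output [top_k * intermediate] as f32 buffers, rejecting usize overflow.
fn batch_shapes(intermediate: usize, top_k: usize) -> Result<BatchShapes, String> {
    let f32_bytes = std::mem::size_of::<f32>();
    let output_elems = top_k
        .checked_mul(intermediate)
        .ok_or("top_k * intermediate overflows usize")?;
    let gate_up_elems = output_elems
        .checked_mul(2)
        .ok_or("top_k * 2 * intermediate overflows usize")?;
    Ok(BatchShapes {
        gate_up_elems,
        output_elems,
        gate_up_bytes: gate_up_elems
            .checked_mul(f32_bytes)
            .ok_or("gate_up byte size overflows usize")?,
        output_bytes: output_elems
            .checked_mul(f32_bytes)
            .ok_or("output byte size overflows usize")?,
    })
}

fn main() {
    // Illustrative dimensions only.
    let shapes = batch_shapes(1024, 8).expect("sizes fit in usize");
    println!("{:?}", shapes);
}
```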