GPU-accelerated MoE expert dispatch (Stage 1: loop over selected experts).
For each of the K selected experts, runs:

    gate_out   = gate_proj_e(x)          // [input_dim -> intermediate_dim]
    up_out     = up_proj_e(x)            // [input_dim -> intermediate_dim]
    hidden     = GELU(gate_out) * up_out
    expert_out = down_proj_e(hidden)     // [intermediate_dim -> input_dim]
    result    += routing_weight_e * expert_out
Stage 1 uses individual kernel dispatches per expert and per projection. The projections use float matmuls (the caller dequantizes the weights or provides them in float form). A Stage 2 optimization (Epic 6) would fuse these dispatches.
This module provides the high-level `moe_dispatch` function that orchestrates the per-expert loop, using the `fused_gelu_mul` and `moe_accumulate` shaders from `moe_dispatch.metal`.
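The per-expert computation above can be sketched as a plain-CPU reference in Rust. Everything here (the `matvec` helper, weight layout, dimensions) is illustrative scaffolding, not the crate's API; the real path encodes these steps as Metal kernel dispatches.

```rust
// CPU reference of the per-expert FFN loop that moe_dispatch encodes on GPU.
// Names and layouts are illustrative only.

fn gelu(x: f32) -> f32 {
    // tanh approximation of GELU
    0.5 * x * (1.0 + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
}

/// y = W x for a row-major [rows x cols] matrix.
fn matvec(w: &[f32], rows: usize, cols: usize, x: &[f32]) -> Vec<f32> {
    (0..rows)
        .map(|r| (0..cols).map(|c| w[r * cols + c] * x[c]).sum())
        .collect()
}

fn main() {
    let (input_dim, intermediate_dim) = (2, 3);
    let x = vec![1.0_f32, 2.0];
    // One expert's weights: gate/up are [intermediate_dim x input_dim],
    // down is [input_dim x intermediate_dim].
    let gate = vec![0.1; intermediate_dim * input_dim];
    let up = vec![0.2; intermediate_dim * input_dim];
    let down = vec![0.3; input_dim * intermediate_dim];
    let routing_weight = 0.5_f32;

    let mut result = vec![0.0_f32; input_dim]; // zero-initialized accumulator
    let gate_out = matvec(&gate, intermediate_dim, input_dim, &x);
    let up_out = matvec(&up, intermediate_dim, input_dim, &x);
    // fused GELU-multiply step
    let hidden: Vec<f32> = gate_out.iter().zip(&up_out).map(|(g, u)| gelu(*g) * u).collect();
    let expert_out = matvec(&down, input_dim, intermediate_dim, &hidden);
    for (r, e) in result.iter_mut().zip(&expert_out) {
        *r += routing_weight * e; // moe_accumulate step
    }
    println!("{:?}", result);
}
```

Running this loop once per selected expert, accumulating into the same `result` buffer, mirrors what the Stage 1 dispatch sequence does on the GPU.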
Structs§
- `ExpertWeights` - A single expert's weight matrices (float32, pre-dequantized or float).
- `MoeDispatchParams` - Parameters for MoE dispatch.
Functions§
- `fused_gelu_mul_bf16_encode` - Encode a fused GELU-multiply on bf16 buffers.
- `moe_accumulate_encode` - Encode a weighted accumulation: `accumulator[i] += routing_weight * expert_output[i]`.
- `moe_accumulate_encode_offset` - Like `moe_accumulate_encode` but reads `expert_output` from `src_byte_offset`.
- `moe_dispatch` - Encode MoE dispatch: loop over selected experts, run FFN, accumulate.
- `moe_gather_topk_weights_encode` - Encode a GPU-side MoE top-K routing gather.
- `moe_swiglu_batch_encode` - Encode a batched SwiGLU across all top_k expert slots in one dispatch.
- `moe_swiglu_fused_encode` - Encode a fused SwiGLU on a `[2*N]` `gate_up` buffer, producing `[N]` output.
- `moe_swiglu_fused_encode_offset` - Like `moe_swiglu_fused_encode` but reads from `gate_up` at `gu_byte_offset` and writes to `output` at `out_byte_offset`.
- `moe_swiglu_seq_backward_encode` - Backward of `moe_swiglu_seq`: a single fused kernel writes both gate and up gradients into the supplied `d_gate_up` buffer (same layout as the forward `gate_up`).
- `moe_swiglu_seq_bf16_encode` - Multi-token SwiGLU for batched prefill (bf16 I/O, f32 accumulator).
- `moe_swiglu_seq_encode` - Multi-token SwiGLU for batched prefill.
- `moe_weighted_sum_encode` - Encode a weighted sum of all top_k expert outputs in one dispatch.
- `moe_weighted_sum_seq_backward_outputs_encode` - Backward of `moe_weighted_sum_seq` w.r.t. `expert_outputs`.
- `moe_weighted_sum_seq_backward_weights_encode` - Backward of `moe_weighted_sum_seq` w.r.t. `weights`.
- `moe_weighted_sum_seq_bf16_input_encode` - Multi-token weighted sum of expert outputs for batched prefill (bf16 inputs).
- `moe_weighted_sum_seq_encode` - Multi-token weighted sum of expert outputs for batched prefill.
- `moe_zero_buffer_encode` - Zero-initialize an f32 GPU buffer using the `zero_buffer` kernel.
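The fused-SwiGLU contract (a `[2*N]` `gate_up` buffer reduced to an `[N]` output) can be sketched on the CPU as below. The half-split layout (gate activations in the first `N` elements, up activations in the second `N`) is an assumption for illustration, not taken from the kernel source.

```rust
// CPU sketch of the fused SwiGLU that moe_swiglu_fused_encode dispatches:
// output[i] = SiLU(gate[i]) * up[i], with gate and up packed into one buffer.
// ASSUMPTION: gate occupies gate_up[0..N], up occupies gate_up[N..2N].

fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

fn swiglu_fused(gate_up: &[f32]) -> Vec<f32> {
    let n = gate_up.len() / 2;
    (0..n).map(|i| silu(gate_up[i]) * gate_up[n + i]).collect()
}

fn main() {
    // N = 2: gate = [1.0, -1.0], up = [2.0, 3.0]
    let gate_up = [1.0_f32, -1.0, 2.0, 3.0];
    println!("{:?}", swiglu_fused(&gate_up)); // [N] output
}
```

The `_offset` variant would apply the same reduction starting at `gu_byte_offset` within a larger buffer, which lets several experts share one allocation.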