// mlx_native/ops/moe_gate.rs
//! GPU-accelerated MoE gating: parallel top-K expert selection with softmax
//! routing.
//!
//! One threadgroup per token (grid = seq_len × 1 × 1), 128 threads per group.
//! Supports bf16 hidden state input, f32 router weights, and per-expert scale.
//!
//! Designed for Gemma 4: 128 experts, top-8 routing, hidden_dim=2816.

use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Parameters for MoE gate routing.
///
/// Plain-old-data: all fields are `Copy`, so the struct derives the standard
/// traits for cheap copying, comparison in tests, and `Debug` logging.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MoeGateParams {
    /// Hidden state dimension (e.g. 2816 for Gemma 4).
    pub hidden_dim: usize,
    /// Total number of experts (e.g. 128 for Gemma 4).
    pub n_experts: usize,
    /// Number of experts to select (e.g. 8 for Gemma 4).
    pub top_k: usize,
    /// Number of tokens in the sequence (seq_len >= 1).
    pub seq_len: usize,
    /// RMS norm epsilon (e.g. 1e-6).
    pub rms_eps: f32,
}
30/// Encode a parallel MoE gate operation.
31///
32/// Launches one threadgroup per token; each threadgroup runs 128 threads that
33/// cooperate on:
34///   1. RMS-Norm of the token's hidden state.
35///   2. Router matmul (each thread handles ⌈n_experts/128⌉ experts).
36///   3. Top-K insertion sort + softmax + per_expert_scale (single thread).
37///
38/// # Buffer expectations
39///
40/// * `hidden_state`      — bf16, `[seq_len, hidden_dim]`
41/// * `router_weights`    — f32,  `[n_experts, hidden_dim]` (row-major, pre-cached on GPU)
42/// * `norm_weight`       — f32,  `[hidden_dim]` (RMS norm learned weight)
43/// * `per_expert_scale`  — f32,  `[n_experts]`
44/// * `out_expert_ids`    — u32,  `[seq_len, top_k]`  (output)
45/// * `out_weights`       — f32,  `[seq_len, top_k]`  (output)
46///
47/// # Errors
48///
49/// Returns `MlxError::InvalidArgument` if any parameter or buffer is invalid.
50#[allow(clippy::too_many_arguments)]
51pub fn moe_gate(
52    encoder: &mut CommandEncoder,
53    registry: &mut KernelRegistry,
54    device: &metal::DeviceRef,
55    hidden_state: &MlxBuffer,
56    router_weights: &MlxBuffer,
57    norm_weight: &MlxBuffer,
58    per_expert_scale: &MlxBuffer,
59    out_expert_ids: &MlxBuffer,
60    out_weights: &MlxBuffer,
61    params: &MoeGateParams,
62) -> Result<()> {
63    // --- Validation ---
64    if params.hidden_dim == 0 {
65        return Err(MlxError::InvalidArgument(
66            "moe_gate: hidden_dim must be > 0".into(),
67        ));
68    }
69    if params.n_experts == 0 {
70        return Err(MlxError::InvalidArgument(
71            "moe_gate: n_experts must be > 0".into(),
72        ));
73    }
74    if params.top_k == 0 {
75        return Err(MlxError::InvalidArgument(
76            "moe_gate: top_k must be > 0".into(),
77        ));
78    }
79    if params.seq_len == 0 {
80        return Err(MlxError::InvalidArgument(
81            "moe_gate: seq_len must be > 0".into(),
82        ));
83    }
84    if params.top_k > params.n_experts {
85        return Err(MlxError::InvalidArgument(format!(
86            "moe_gate: top_k ({}) must be <= n_experts ({})",
87            params.top_k, params.n_experts
88        )));
89    }
90    if params.n_experts > 128 {
91        return Err(MlxError::InvalidArgument(format!(
92            "moe_gate: n_experts ({}) exceeds max 128 (shader fixed-size array limit)",
93            params.n_experts
94        )));
95    }
96
97    // bf16 elements are 2 bytes each
98    let bf16_size = 2usize;
99    let f32_size = std::mem::size_of::<f32>();
100    let u32_size = std::mem::size_of::<u32>();
101
102    let expected_hidden_bytes = params.seq_len * params.hidden_dim * bf16_size;
103    if hidden_state.byte_len() < expected_hidden_bytes {
104        return Err(MlxError::InvalidArgument(format!(
105            "moe_gate: hidden_state buffer too small: need {} bytes, have {}",
106            expected_hidden_bytes,
107            hidden_state.byte_len()
108        )));
109    }
110
111    let expected_router_bytes = params.n_experts * params.hidden_dim * f32_size;
112    if router_weights.byte_len() < expected_router_bytes {
113        return Err(MlxError::InvalidArgument(format!(
114            "moe_gate: router_weights buffer too small: need {} bytes, have {}",
115            expected_router_bytes,
116            router_weights.byte_len()
117        )));
118    }
119
120    let expected_norm_bytes = params.hidden_dim * f32_size;
121    if norm_weight.byte_len() < expected_norm_bytes {
122        return Err(MlxError::InvalidArgument(format!(
123            "moe_gate: norm_weight buffer too small: need {} bytes, have {}",
124            expected_norm_bytes,
125            norm_weight.byte_len()
126        )));
127    }
128
129    let expected_scale_bytes = params.n_experts * f32_size;
130    if per_expert_scale.byte_len() < expected_scale_bytes {
131        return Err(MlxError::InvalidArgument(format!(
132            "moe_gate: per_expert_scale buffer too small: need {} bytes, have {}",
133            expected_scale_bytes,
134            per_expert_scale.byte_len()
135        )));
136    }
137
138    let expected_ids_bytes = params.seq_len * params.top_k * u32_size;
139    if out_expert_ids.byte_len() < expected_ids_bytes {
140        return Err(MlxError::InvalidArgument(format!(
141            "moe_gate: out_expert_ids buffer too small: need {} bytes, have {}",
142            expected_ids_bytes,
143            out_expert_ids.byte_len()
144        )));
145    }
146
147    let expected_weights_bytes = params.seq_len * params.top_k * f32_size;
148    if out_weights.byte_len() < expected_weights_bytes {
149        return Err(MlxError::InvalidArgument(format!(
150            "moe_gate: out_weights buffer too small: need {} bytes, have {}",
151            expected_weights_bytes,
152            out_weights.byte_len()
153        )));
154    }
155
156    // --- Kernel dispatch ---
157    let pipeline = registry.get_pipeline("moe_gate", device)?;
158
159    // 128 threads per threadgroup — one per expert for the matmul phase.
160    // Must be a power of 2 for the tree-reduction in RMS norm.
161    let tg_threads: u64 = 128;
162
163    // One threadgroup per token.
164    let threadgroups = MTLSize::new(params.seq_len as u64, 1, 1);
165    let threadgroup_size = MTLSize::new(tg_threads, 1, 1);
166
167    // Threadgroup shared memory layout (see moe_gate.metal):
168    //   [0 .. hidden_dim)                   — f32 normed hidden state
169    //   [hidden_dim .. hidden_dim+n_experts) — f32 router logits / reduction scratch
170    //
171    // Both regions are large enough for the RMS reduction (uses logit region as
172    // scratch with tg_size=128 slots, which fits within n_experts=128 floats).
173    let shared_bytes =
174        ((params.hidden_dim + params.n_experts) * std::mem::size_of::<f32>()) as u64;
175
176    // Scalar constants passed as inline bytes via set_bytes.
177    let hidden_dim_u32 = params.hidden_dim as u32;
178    let n_experts_u32  = params.n_experts  as u32;
179    let top_k_u32      = params.top_k      as u32;
180    let rms_eps_f32    = params.rms_eps;
181
182    use crate::encoder::{KernelArg, as_bytes};
183
184    encoder.encode_threadgroups_with_args_and_shared(
185        pipeline,
186        &[
187            (0, KernelArg::Buffer(hidden_state)),
188            (1, KernelArg::Buffer(router_weights)),
189            (2, KernelArg::Buffer(norm_weight)),
190            (3, KernelArg::Buffer(per_expert_scale)),
191            (4, KernelArg::Buffer(out_expert_ids)),
192            (5, KernelArg::Buffer(out_weights)),
193            (6, KernelArg::Bytes(as_bytes(&hidden_dim_u32))),
194            (7, KernelArg::Bytes(as_bytes(&n_experts_u32))),
195            (8, KernelArg::Bytes(as_bytes(&top_k_u32))),
196            (9, KernelArg::Bytes(as_bytes(&rms_eps_f32))),
197        ],
198        &[(0, shared_bytes)],
199        threadgroups,
200        threadgroup_size,
201    );
202
203    Ok(())
204}