ferrum_models/moe/router.rs
1//! MoE router (gating) — pick top-K experts per token.
2//!
3//! Given router logits of shape `[batch, num_experts]` (output of the small
4//! gating linear), produce per-token expert indices + combine weights:
5//!
6//! 1. Softmax over each row (so all probs are non-negative and sum to 1).
7//! 2. Take the K highest-probability experts.
8//! 3. Optionally renormalise those K probs so they sum back to 1
9//! (Qwen3-MoE / Mixtral default; some legacy variants don't).
10//!
11//! Output layout is **flat with stride `top_k`**: `expert_ids[b*K + k]`
12//! is the k-th selected expert for token b, and `expert_weights[b*K + k]`
13//! is its combine weight. That matches how the dispatch loop iterates.
14
/// Result of routing one batch: two parallel arrays indexed `[b * top_k + k]`.
///
/// `top_k` itself is not stored; callers are expected to remember the shape
/// they routed with. `Default` produces the same value as
/// [`RouterOutput::empty`]: both vectors empty and unallocated.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct RouterOutput {
    /// Selected expert indices. `expert_ids[b * top_k + k] ∈ [0, num_experts)`.
    pub expert_ids: Vec<u32>,
    /// Combine weights, same length and layout as `expert_ids`. If
    /// `norm_topk_prob` was true, the K weights for each token sum to 1;
    /// otherwise they are the raw (post-softmax) probabilities of the
    /// selected experts.
    pub expert_weights: Vec<f32>,
}
26
27impl RouterOutput {
28 /// Empty `RouterOutput` with no allocation. Use [`Self::reset`] before
29 /// reuse — this is the cheap constructor for putting it in a scratch
30 /// struct.
31 pub fn empty() -> Self {
32 Self {
33 expert_ids: Vec::new(),
34 expert_weights: Vec::new(),
35 }
36 }
37
38 /// Resize both vectors to `batch * top_k`. Existing capacity is reused
39 /// when sufficient; growth uses standard `Vec::resize`. Old contents
40 /// are not preserved (callers always overwrite).
41 pub fn reset(&mut self, batch: usize, top_k: usize) {
42 let n = batch * top_k;
43 self.expert_ids.clear();
44 self.expert_ids.resize(n, 0);
45 self.expert_weights.clear();
46 self.expert_weights.resize(n, 0.0);
47 }
48
49 /// Number of tokens routed.
50 pub fn batch(&self) -> usize {
51 // `top_k` is the second dimension; we don't store it explicitly,
52 // so derive it from the assumption that the caller passed
53 // consistent sizes. Length checks belong upstream.
54 self.expert_ids.len() / self.batch_top_k_pair_count()
55 }
56
57 fn batch_top_k_pair_count(&self) -> usize {
58 // Defensive: avoid divide-by-zero for the empty-router case.
59 self.expert_ids.len().max(1)
60 }
61}
62
63/// Route a batch of tokens to top-K experts.
64///
65/// `logits`: row-major `[batch, num_experts]`. Each row is the raw output
66/// of the gating linear for one token.
67///
68/// `norm_topk_prob`: if true, the K returned weights for each token are
69/// renormalised to sum to 1 (after the masked softmax) — Qwen3-MoE and
70/// Mixtral both do this. If false, they're the raw softmax probabilities,
71/// which leaves probability mass "on the floor" for unselected experts.
72///
73/// Panics if `top_k == 0` or `top_k > num_experts` or
74/// `logits.len() != batch * num_experts` — these are programming errors,
75/// not runtime conditions.
76pub fn route(
77 logits: &[f32],
78 batch: usize,
79 num_experts: usize,
80 top_k: usize,
81 norm_topk_prob: bool,
82) -> RouterOutput {
83 let mut out = RouterOutput::empty();
84 let mut scratch = Vec::new();
85 route_into(
86 logits,
87 batch,
88 num_experts,
89 top_k,
90 norm_topk_prob,
91 &mut out,
92 &mut scratch,
93 );
94 out
95}
96
97/// Allocation-free variant of [`route`].
98///
99/// The 4 per-row `Vec` allocations of [`route`] (softmax buffer, indexed
100/// pair buffer for sort, top-K index buffer, renormalised weights buffer)
101/// dominate per-token cost in MoE forward — at c=32 / num_experts=128 /
102/// top_k=8 / 48 layers that's 4 608 allocations per decode token, or
103/// ~10 ms of pure CPU per token (25% of MoE wallclock at c=32 on RTX 4090).
104///
105/// This variant takes a reusable `out: &mut RouterOutput` and a
106/// `scratch_probs: &mut Vec<f32>` softmax buffer, both of which are
107/// `clear() + resize()` reused across calls — zero allocations after warmup.
108/// Top-K is computed via argmax-mask (K passes of a linear scan) instead
109/// of a full O(N log N) sort, which is also faster for K=8 / N=128.
110///
111/// Tie-breaking: when two probs are equal, the smaller index wins (matches
112/// [`route`] / Metal `moe_router_topk_softmax_f32` for bit-exact output).
113pub fn route_into(
114 logits: &[f32],
115 batch: usize,
116 num_experts: usize,
117 top_k: usize,
118 norm_topk_prob: bool,
119 out: &mut RouterOutput,
120 scratch_probs: &mut Vec<f32>,
121) {
122 assert_eq!(
123 logits.len(),
124 batch * num_experts,
125 "router logits shape mismatch: expected {batch}×{num_experts}, got {}",
126 logits.len()
127 );
128 assert!(top_k > 0, "top_k must be > 0");
129 assert!(
130 top_k <= num_experts,
131 "top_k {top_k} exceeds num_experts {num_experts}"
132 );
133
134 out.reset(batch, top_k);
135 scratch_probs.clear();
136 scratch_probs.resize(num_experts, 0.0);
137
138 for b in 0..batch {
139 let row = &logits[b * num_experts..(b + 1) * num_experts];
140
141 // ── Softmax in-place into scratch_probs. ─────────────────────────
142 let mut max = f32::NEG_INFINITY;
143 for &v in row {
144 if v > max {
145 max = v;
146 }
147 }
148 let mut sum = 0.0f32;
149 for (i, &v) in row.iter().enumerate() {
150 let e = (v - max).exp();
151 scratch_probs[i] = e;
152 sum += e;
153 }
154 let inv_sum = 1.0 / sum;
155 for v in scratch_probs.iter_mut() {
156 *v *= inv_sum;
157 }
158
159 // ── Top-K via argmax-mask. K passes; each picks the largest
160 // remaining prob and overwrites it with -inf so the next pass
161 // sees the next-best. The strict `v > best` keeps the first
162 // (smallest-index) tied entry — matches the sort-based path.
163 let mut sel_sum = 0.0f32;
164 let dst_lo = b * top_k;
165 for k in 0..top_k {
166 let mut best = f32::NEG_INFINITY;
167 let mut best_idx = 0usize;
168 for (i, &v) in scratch_probs.iter().enumerate() {
169 if v > best {
170 best = v;
171 best_idx = i;
172 }
173 }
174 out.expert_ids[dst_lo + k] = best_idx as u32;
175 out.expert_weights[dst_lo + k] = best;
176 sel_sum += best;
177 scratch_probs[best_idx] = f32::NEG_INFINITY;
178 }
179
180 // ── Optional renorm of the K picked weights. ─────────────────
181 if norm_topk_prob {
182 if sel_sum > 0.0 {
183 let scale = 1.0 / sel_sum;
184 for w in &mut out.expert_weights[dst_lo..dst_lo + top_k] {
185 *w *= scale;
186 }
187 } else {
188 let uniform = 1.0 / top_k as f32;
189 for w in &mut out.expert_weights[dst_lo..dst_lo + top_k] {
190 *w = uniform;
191 }
192 }
193 }
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    /// Independent sort-based reference router.
    ///
    /// `route` delegates to `route_into`, so comparing those two against
    /// each other proves nothing. This reference shares only the softmax
    /// arithmetic (so probabilities are bit-identical) and selects top-K
    /// with a stable descending sort, which keeps the smaller index first
    /// among equal probabilities.
    fn route_reference(
        logits: &[f32],
        batch: usize,
        num_experts: usize,
        top_k: usize,
        norm: bool,
    ) -> RouterOutput {
        let mut expert_ids = Vec::with_capacity(batch * top_k);
        let mut expert_weights = Vec::with_capacity(batch * top_k);

        for row in logits.chunks_exact(num_experts) {
            let max = row
                .iter()
                .copied()
                .fold(f32::NEG_INFINITY, |m, v| if v > m { v } else { m });
            let exps: Vec<f32> = row.iter().map(|&v| (v - max).exp()).collect();
            let sum = exps.iter().fold(0.0f32, |acc, &e| acc + e);
            let inv_sum = 1.0 / sum;

            let mut ranked: Vec<(usize, f32)> =
                exps.iter().map(|&e| e * inv_sum).enumerate().collect();
            // Stable sort: ties keep ascending index order.
            ranked.sort_by(|a, b| {
                b.1.partial_cmp(&a.1)
                    .expect("softmax of finite logits is never NaN")
            });
            ranked.truncate(top_k);

            let picked_sum = ranked.iter().fold(0.0f32, |acc, &(_, p)| acc + p);
            for (idx, p) in ranked {
                expert_ids.push(idx as u32);
                expert_weights.push(if norm { p / picked_sum } else { p });
            }
        }

        RouterOutput {
            expert_ids,
            expert_weights,
        }
    }

    /// Check `route` and `route_into` against the sort-based reference on
    /// seeded random logits.
    fn run_parity(batch: usize, num_experts: usize, top_k: usize, norm: bool, seed: u64) {
        let mut rng = StdRng::seed_from_u64(seed);
        let logits: Vec<f32> = (0..batch * num_experts)
            .map(|_| rng.gen_range(-3.0..3.0_f32))
            .collect();

        let expected = route_reference(&logits, batch, num_experts, top_k, norm);

        let a = route(&logits, batch, num_experts, top_k, norm);
        let mut b = RouterOutput::empty();
        let mut probs = Vec::new();
        route_into(&logits, batch, num_experts, top_k, norm, &mut b, &mut probs);

        assert_eq!(a.expert_ids, expected.expert_ids, "route expert_ids mismatch");
        assert_eq!(
            b.expert_ids, expected.expert_ids,
            "route_into expert_ids mismatch"
        );
        assert_eq!(a.expert_weights.len(), expected.expert_weights.len());
        assert_eq!(b.expert_weights.len(), expected.expert_weights.len());

        for (i, &want) in expected.expert_weights.iter().enumerate() {
            // Within ulps for the renorm-divide, exact otherwise.
            for (label, got) in [("route", a.expert_weights[i]), ("route_into", b.expert_weights[i])]
            {
                let delta = (got - want).abs();
                assert!(
                    delta < 1e-6,
                    "weight[{i}] mismatch: reference={want} {label}={got} delta={delta}"
                );
            }
        }
    }

    #[test]
    fn parity_qwen3_moe_shape() {
        // Qwen3-MoE 30B-A3B production shape (norm_topk_prob=true).
        run_parity(32, 128, 8, true, 0xDEADBEEF);
        run_parity(1, 128, 8, true, 0x1234);
        run_parity(64, 128, 8, true, 0x5678);
    }

    #[test]
    fn parity_no_renorm() {
        run_parity(8, 64, 4, false, 0xC0FFEE);
    }

    #[test]
    fn parity_topk_one() {
        run_parity(4, 16, 1, true, 0x42);
        run_parity(4, 16, 1, false, 0x42);
    }

    #[test]
    fn ties_resolve_to_smallest_indices() {
        // All logits equal: every expert has the same probability, so the
        // first `top_k` indices must win for every token.
        let out = route(&vec![0.5f32; 3 * 16], 3, 16, 4, true);
        for row in out.expert_ids.chunks_exact(4) {
            assert_eq!(row, &[0u32, 1, 2, 3][..]);
        }
        for &w in &out.expert_weights {
            assert!((w - 0.25).abs() < 1e-6, "expected 0.25, got {w}");
        }
    }

    #[test]
    fn renormalised_weights_sum_to_one() {
        let logits = [1.0f32, 3.0, 2.0, 0.0, -1.0, 0.5, 0.25, 4.0];
        let normed = route(&logits, 2, 4, 2, true);
        assert_eq!(normed.expert_ids, vec![1u32, 2, 3, 1]);
        for row in normed.expert_weights.chunks_exact(2) {
            let total: f32 = row.iter().sum();
            assert!((total - 1.0).abs() < 1e-6, "row sums to {total}");
            assert!(row[0] >= row[1], "weights must be in descending order");
        }

        // Without renorm the selected mass is strictly below 1.
        let raw = route(&logits, 2, 4, 2, false);
        for row in raw.expert_weights.chunks_exact(2) {
            let total: f32 = row.iter().sum();
            assert!(total < 1.0 && total > 0.0, "raw row sums to {total}");
        }
    }

    #[test]
    fn allocation_free_after_warmup() {
        // Sanity: scratch capacity stays put across calls — we don't
        // grow / shrink the underlying Vec on each call.
        let mut out = RouterOutput::empty();
        let mut probs = Vec::new();
        let logits = vec![0.5f32; 32 * 128];
        route_into(&logits, 32, 128, 8, true, &mut out, &mut probs);
        let cap_ids = out.expert_ids.capacity();
        let cap_w = out.expert_weights.capacity();
        let cap_p = probs.capacity();
        // Repeat — capacity must not grow.
        for _ in 0..16 {
            route_into(&logits, 32, 128, 8, true, &mut out, &mut probs);
            assert_eq!(out.expert_ids.capacity(), cap_ids);
            assert_eq!(out.expert_weights.capacity(), cap_w);
            assert_eq!(probs.capacity(), cap_p);
        }
    }
}
265}