Skip to main content

ferrum_models/moe/
router.rs

//! MoE router (gating) — pick top-K experts per token.
//!
//! Given router logits of shape `[batch, num_experts]` (output of the small
//! gating linear), produce per-token expert indices + combine weights:
//!
//!   1. Softmax over each row (so all probs are non-negative and sum to 1).
//!   2. Take the K highest-probability experts.
//!   3. Optionally renormalise those K probs so they sum back to 1
//!      (Qwen3-MoE / Mixtral default; some legacy variants don't).
//!
//! Output layout is **flat with stride `top_k`**: `expert_ids[b*K + k]`
//! is the k-th selected expert for token b, and `expert_weights[b*K + k]`
//! is its combine weight. That matches how the dispatch loop iterates.
14
/// Result of routing one batch: two parallel flat arrays indexed
/// `[b * top_k + k]` (token `b`, k-th selected expert).
///
/// `top_k` itself is **not** stored here; callers that need to recover the
/// token count must supply it (see [`RouterOutput::batch_with_top_k`]).
#[derive(Debug, Clone, PartialEq)]
pub struct RouterOutput {
    /// Selected expert indices. `expert_ids[b * top_k + k] ∈ [0, num_experts)`.
    pub expert_ids: Vec<u32>,
    /// Combine weights, same layout as `expert_ids`. When routing ran with
    /// `norm_topk_prob == true` the `top_k` weights of each token are
    /// rescaled to sum to 1; otherwise they are the post-softmax
    /// probabilities of the selected experts.
    pub expert_weights: Vec<f32>,
}

impl RouterOutput {
    /// Empty `RouterOutput` holding two unallocated vectors. Call
    /// [`Self::reset`] before writing into it — this is the cheap
    /// constructor for embedding in a scratch struct.
    pub fn empty() -> Self {
        Self {
            expert_ids: Vec::new(),
            expert_weights: Vec::new(),
        }
    }

    /// Resize both vectors to `batch * top_k` entries, all zeroed.
    ///
    /// Existing capacity is reused when it is large enough; otherwise the
    /// vectors grow through the normal `Vec::resize` path. Previous contents
    /// are discarded.
    ///
    /// # Panics
    ///
    /// Panics if `batch * top_k` does not fit in `usize`. (The unchecked
    /// multiplication would wrap silently in release builds and size the
    /// buffers incorrectly.)
    pub fn reset(&mut self, batch: usize, top_k: usize) {
        let n = batch
            .checked_mul(top_k)
            .expect("RouterOutput::reset: batch * top_k overflows usize");
        self.expert_ids.clear();
        self.expert_ids.resize(n, 0);
        self.expert_weights.clear();
        self.expert_weights.resize(n, 0.0);
    }

    /// Legacy accessor kept for source compatibility — it does **not**
    /// return the token count.
    ///
    /// Because `top_k` is not stored, the only quantity available is
    /// `expert_ids.len()`. The expression below divides that length by
    /// `max(len, 1)`, so the result is `0` for an empty output and `1` for
    /// any non-empty output, whatever the real batch size was. Use
    /// [`Self::batch_with_top_k`] to obtain the actual number of tokens.
    // NOTE(review): returning the true batch size from this zero-argument
    // method requires recording `top_k` (or `batch`) as a field in `reset`,
    // which changes the struct's public layout — left for the maintainers.
    pub fn batch(&self) -> usize {
        let len = self.expert_ids.len();
        // `max(1)` only guards the division against the empty case.
        len / len.max(1)
    }

    /// Number of tokens routed, given the `top_k` the output was sized with.
    ///
    /// Returns `expert_ids.len() / top_k`, or `0` when `top_k == 0`. The
    /// caller is responsible for passing the same `top_k` that was handed to
    /// [`Self::reset`]; a mismatched value yields a truncated quotient.
    pub fn batch_with_top_k(&self, top_k: usize) -> usize {
        self.expert_ids.len().checked_div(top_k).unwrap_or(0)
    }
}
62
63/// Route a batch of tokens to top-K experts.
64///
65/// `logits`: row-major `[batch, num_experts]`. Each row is the raw output
66/// of the gating linear for one token.
67///
68/// `norm_topk_prob`: if true, the K returned weights for each token are
69/// renormalised to sum to 1 (after the masked softmax) — Qwen3-MoE and
70/// Mixtral both do this. If false, they're the raw softmax probabilities,
71/// which leaves probability mass "on the floor" for unselected experts.
72///
73/// Panics if `top_k == 0` or `top_k > num_experts` or
74/// `logits.len() != batch * num_experts` — these are programming errors,
75/// not runtime conditions.
76pub fn route(
77    logits: &[f32],
78    batch: usize,
79    num_experts: usize,
80    top_k: usize,
81    norm_topk_prob: bool,
82) -> RouterOutput {
83    let mut out = RouterOutput::empty();
84    let mut scratch = Vec::new();
85    route_into(
86        logits,
87        batch,
88        num_experts,
89        top_k,
90        norm_topk_prob,
91        &mut out,
92        &mut scratch,
93    );
94    out
95}
96
97/// Allocation-free variant of [`route`].
98///
99/// The 4 per-row `Vec` allocations of [`route`] (softmax buffer, indexed
100/// pair buffer for sort, top-K index buffer, renormalised weights buffer)
101/// dominate per-token cost in MoE forward — at c=32 / num_experts=128 /
102/// top_k=8 / 48 layers that's 4 608 allocations per decode token, or
103/// ~10 ms of pure CPU per token (25% of MoE wallclock at c=32 on RTX 4090).
104///
105/// This variant takes a reusable `out: &mut RouterOutput` and a
106/// `scratch_probs: &mut Vec<f32>` softmax buffer, both of which are
107/// `clear() + resize()` reused across calls — zero allocations after warmup.
108/// Top-K is computed via argmax-mask (K passes of a linear scan) instead
109/// of a full O(N log N) sort, which is also faster for K=8 / N=128.
110///
111/// Tie-breaking: when two probs are equal, the smaller index wins (matches
112/// [`route`] / Metal `moe_router_topk_softmax_f32` for bit-exact output).
113pub fn route_into(
114    logits: &[f32],
115    batch: usize,
116    num_experts: usize,
117    top_k: usize,
118    norm_topk_prob: bool,
119    out: &mut RouterOutput,
120    scratch_probs: &mut Vec<f32>,
121) {
122    assert_eq!(
123        logits.len(),
124        batch * num_experts,
125        "router logits shape mismatch: expected {batch}×{num_experts}, got {}",
126        logits.len()
127    );
128    assert!(top_k > 0, "top_k must be > 0");
129    assert!(
130        top_k <= num_experts,
131        "top_k {top_k} exceeds num_experts {num_experts}"
132    );
133
134    out.reset(batch, top_k);
135    scratch_probs.clear();
136    scratch_probs.resize(num_experts, 0.0);
137
138    for b in 0..batch {
139        let row = &logits[b * num_experts..(b + 1) * num_experts];
140
141        // ── Softmax in-place into scratch_probs. ─────────────────────────
142        let mut max = f32::NEG_INFINITY;
143        for &v in row {
144            if v > max {
145                max = v;
146            }
147        }
148        let mut sum = 0.0f32;
149        for (i, &v) in row.iter().enumerate() {
150            let e = (v - max).exp();
151            scratch_probs[i] = e;
152            sum += e;
153        }
154        let inv_sum = 1.0 / sum;
155        for v in scratch_probs.iter_mut() {
156            *v *= inv_sum;
157        }
158
159        // ── Top-K via argmax-mask. K passes; each picks the largest
160        // remaining prob and overwrites it with -inf so the next pass
161        // sees the next-best. The strict `v > best` keeps the first
162        // (smallest-index) tied entry — matches the sort-based path.
163        let mut sel_sum = 0.0f32;
164        let dst_lo = b * top_k;
165        for k in 0..top_k {
166            let mut best = f32::NEG_INFINITY;
167            let mut best_idx = 0usize;
168            for (i, &v) in scratch_probs.iter().enumerate() {
169                if v > best {
170                    best = v;
171                    best_idx = i;
172                }
173            }
174            out.expert_ids[dst_lo + k] = best_idx as u32;
175            out.expert_weights[dst_lo + k] = best;
176            sel_sum += best;
177            scratch_probs[best_idx] = f32::NEG_INFINITY;
178        }
179
180        // ── Optional renorm of the K picked weights. ─────────────────
181        if norm_topk_prob {
182            if sel_sum > 0.0 {
183                let scale = 1.0 / sel_sum;
184                for w in &mut out.expert_weights[dst_lo..dst_lo + top_k] {
185                    *w *= scale;
186                }
187            } else {
188                let uniform = 1.0 / top_k as f32;
189                for w in &mut out.expert_weights[dst_lo..dst_lo + top_k] {
190                    *w = uniform;
191                }
192            }
193        }
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    /// Independent oracle for the router: same max-subtracted softmax
    /// arithmetic, but top-K is taken from a stable descending sort instead
    /// of the argmax-mask scan. A stable sort keeps equal probabilities in
    /// ascending index order, i.e. "smallest index wins" on ties.
    ///
    /// Returns `(expert_ids, expert_weights)` in the flat `[b * top_k + k]`
    /// layout.
    fn reference_route(
        logits: &[f32],
        num_experts: usize,
        top_k: usize,
        norm: bool,
    ) -> (Vec<u32>, Vec<f32>) {
        let mut ids = Vec::new();
        let mut weights = Vec::new();
        for row in logits.chunks_exact(num_experts) {
            let mut max = f32::NEG_INFINITY;
            for &v in row {
                if v > max {
                    max = v;
                }
            }
            let mut probs: Vec<f32> = row.iter().map(|&v| (v - max).exp()).collect();
            let mut sum = 0.0f32;
            for &e in &probs {
                sum += e;
            }
            let inv_sum = 1.0 / sum;
            for p in &mut probs {
                *p *= inv_sum;
            }

            let mut order: Vec<usize> = (0..num_experts).collect();
            order.sort_by(|&a, &b| {
                probs[b]
                    .partial_cmp(&probs[a])
                    .expect("softmax of finite logits is finite")
            });

            let picked = &order[..top_k];
            let sel_sum = picked.iter().fold(0.0f32, |acc, &i| acc + probs[i]);
            let scale = if norm { 1.0 / sel_sum } else { 1.0 };
            for &i in picked {
                ids.push(i as u32);
                weights.push(probs[i] * scale);
            }
        }
        (ids, weights)
    }

    fn assert_close(actual: f32, expected: f32, what: &str) {
        let delta = (actual - expected).abs();
        assert!(
            delta < 1e-6,
            "{what}: got {actual}, expected {expected}, delta={delta}"
        );
    }

    fn random_logits(len: usize, seed: u64) -> Vec<f32> {
        let mut rng = StdRng::seed_from_u64(seed);
        (0..len).map(|_| rng.gen_range(-3.0..3.0_f32)).collect()
    }

    fn run_parity(batch: usize, num_experts: usize, top_k: usize, norm: bool, seed: u64) {
        let logits = random_logits(batch * num_experts, seed);

        // `route` is a thin wrapper over `route_into`, so comparing the two
        // only checks the wrapper plumbing...
        let a = route(&logits, batch, num_experts, top_k, norm);
        let mut b = RouterOutput::empty();
        let mut probs = Vec::new();
        route_into(&logits, batch, num_experts, top_k, norm, &mut b, &mut probs);
        assert_eq!(a, b, "route and route_into diverged");

        // ...the real check is against the independent sort-based oracle.
        let (ref_ids, ref_weights) = reference_route(&logits, num_experts, top_k, norm);
        assert_eq!(b.expert_ids, ref_ids, "expert_ids mismatch vs reference");
        assert_eq!(b.expert_weights.len(), ref_weights.len());
        for (i, (&got, &want)) in b.expert_weights.iter().zip(&ref_weights).enumerate() {
            assert_close(got, want, &format!("weight[{i}]"));
        }
    }

    #[test]
    fn parity_qwen3_moe_shape() {
        // Qwen3-MoE 30B-A3B production shape (norm_topk_prob=true).
        run_parity(32, 128, 8, true, 0xDEADBEEF);
        run_parity(1, 128, 8, true, 0x1234);
        run_parity(64, 128, 8, true, 0x5678);
    }

    #[test]
    fn parity_no_renorm() {
        run_parity(8, 64, 4, false, 0xC0FFEE);
    }

    #[test]
    fn parity_topk_one() {
        run_parity(4, 16, 1, true, 0x42);
        run_parity(4, 16, 1, false, 0x42);
    }

    #[test]
    fn known_values_without_renorm() {
        // softmax([0, 0]) = [0.5, 0.5]; the tie resolves to expert 0.
        let out = route(&[0.0, 0.0], 1, 2, 1, false);
        assert_eq!(out.expert_ids, vec![0]);
        assert_eq!(out.expert_weights, vec![0.5]);
    }

    #[test]
    fn known_values_with_renorm() {
        // Top-2 of [1, 3, 2, 0] are experts 1 and 2; renormalised over
        // those two the weights are e/(e+1) and 1/(e+1).
        let out = route(&[1.0, 3.0, 2.0, 0.0], 1, 4, 2, true);
        assert_eq!(out.expert_ids, vec![1, 2]);
        let e = 1.0f32.exp();
        assert_close(out.expert_weights[0], e / (e + 1.0), "weight[0]");
        assert_close(out.expert_weights[1], 1.0 / (e + 1.0), "weight[1]");
    }

    #[test]
    fn ties_pick_lowest_indices_in_order() {
        let logits = vec![0.5f32; 8];
        let out = route(&logits, 1, 8, 3, true);
        assert_eq!(out.expert_ids, vec![0, 1, 2]);
        for (k, &w) in out.expert_weights.iter().enumerate() {
            assert_close(w, 1.0 / 3.0, &format!("weight[{k}]"));
        }
    }

    #[test]
    fn renormalised_rows_sum_to_one_with_distinct_in_range_ids() {
        let (batch, num_experts, top_k) = (16, 32, 4);
        let logits = random_logits(batch * num_experts, 0xABCD);
        let out = route(&logits, batch, num_experts, top_k, true);

        assert_eq!(out.expert_ids.len(), batch * top_k);
        assert_eq!(out.expert_weights.len(), batch * top_k);
        for (ids, weights) in out
            .expert_ids
            .chunks_exact(top_k)
            .zip(out.expert_weights.chunks_exact(top_k))
        {
            let total: f32 = weights.iter().sum();
            assert!((total - 1.0).abs() < 1e-5, "row weights sum to {total}");
            // Picks come out best-first, so weights never increase.
            assert!(weights.windows(2).all(|w| w[0] >= w[1]));

            assert!(ids.iter().all(|&id| (id as usize) < num_experts));
            let mut sorted = ids.to_vec();
            sorted.sort_unstable();
            sorted.dedup();
            assert_eq!(sorted.len(), top_k, "duplicate expert in {ids:?}");
        }
    }

    #[test]
    fn reused_buffers_do_not_leak_state_between_shapes() {
        let mut out = RouterOutput::empty();
        let mut probs = Vec::new();

        let first = random_logits(4 * 16, 1);
        route_into(&first, 4, 16, 2, true, &mut out, &mut probs);

        // Different shape on the same buffers must equal a fresh run.
        let second = random_logits(2 * 8, 2);
        route_into(&second, 2, 8, 3, false, &mut out, &mut probs);
        assert_eq!(out, route(&second, 2, 8, 3, false));
    }

    #[test]
    fn empty_batch_yields_empty_output() {
        let out = route(&[], 0, 4, 2, true);
        assert!(out.expert_ids.is_empty());
        assert!(out.expert_weights.is_empty());
    }

    #[test]
    #[should_panic(expected = "router logits shape mismatch")]
    fn panics_on_shape_mismatch() {
        route(&[0.0; 3], 1, 4, 1, true);
    }

    #[test]
    #[should_panic(expected = "top_k must be > 0")]
    fn panics_on_zero_top_k() {
        route(&[0.0; 4], 1, 4, 0, true);
    }

    #[test]
    #[should_panic(expected = "exceeds num_experts")]
    fn panics_when_top_k_exceeds_num_experts() {
        route(&[0.0; 2], 1, 2, 3, true);
    }

    #[test]
    fn allocation_free_after_warmup() {
        // Sanity: scratch capacity stays put across calls — we don't
        // grow / shrink the underlying Vec on each call.
        let mut out = RouterOutput::empty();
        let mut probs = Vec::new();
        let logits = vec![0.5f32; 32 * 128];
        route_into(&logits, 32, 128, 8, true, &mut out, &mut probs);
        let cap_ids = out.expert_ids.capacity();
        let cap_w = out.expert_weights.capacity();
        let cap_p = probs.capacity();
        // Repeat — capacity must not grow.
        for _ in 0..16 {
            route_into(&logits, 32, 128, 8, true, &mut out, &mut probs);
            assert_eq!(out.expert_ids.capacity(), cap_ids);
            assert_eq!(out.expert_weights.capacity(), cap_w);
            assert_eq!(probs.capacity(), cap_p);
        }
    }
}
265}