ferrum_models/moe/router.rs
//! MoE router (gating) — pick top-K experts per token.
//!
//! Given router logits of shape `[batch, num_experts]` (output of the small
//! gating linear), produce per-token expert indices + combine weights:
//!
//! 1. Softmax over each row (so all probs are non-negative and sum to 1).
//! 2. Take the K highest-probability experts.
//! 3. Optionally renormalise those K probs so they sum back to 1
//!    (Qwen3-MoE / Mixtral default; some legacy variants don't).
//!
//! Output layout is **flat with stride `top_k`**: `expert_ids[b*K + k]`
//! is the k-th selected expert for token b, and `expert_weights[b*K + k]`
//! is its combine weight. That matches how the dispatch loop iterates.
14
/// Result of routing one batch: parallel arrays indexed `[b * top_k + k]`.
///
/// `top_k` is not stored here, so anything that needs the per-token grouping
/// (including the token count) must be told the `top_k` that was used.
#[derive(Debug, Clone, PartialEq)]
pub struct RouterOutput {
    /// Selected expert indices. `expert_ids[b * top_k + k] ∈ [0, num_experts)`.
    pub expert_ids: Vec<u32>,
    /// Combine weights. Same shape as `expert_ids`. If
    /// `norm_topk_prob` was true, the K weights for each token sum to 1;
    /// otherwise they're the raw (post-softmax) probabilities of the
    /// selected experts.
    pub expert_weights: Vec<f32>,
}

impl RouterOutput {
    /// Legacy accessor: returns `0` for an empty output and `1` otherwise.
    ///
    /// This is **not** the number of tokens routed. The struct does not
    /// record `top_k`, so the flat length cannot be split back into
    /// `batch × top_k` here; the expression below is `len / max(len, 1)`.
    /// The return values are kept as-is for existing callers. Use
    /// [`RouterOutput::batch_with_top_k`] to get the real token count.
    pub fn batch(&self) -> usize {
        self.expert_ids.len() / self.batch_top_k_pair_count()
    }

    /// Number of tokens routed, given the `top_k` the output was built with.
    ///
    /// Computes `expert_ids.len() / top_k` (integer division, so a length
    /// that is not a multiple of `top_k` is rounded down). Returns `0` when
    /// `top_k == 0` instead of dividing by zero.
    pub fn batch_with_top_k(&self, top_k: usize) -> usize {
        if top_k == 0 {
            0
        } else {
            self.expert_ids.len() / top_k
        }
    }

    /// Total number of `(token, slot)` entries, clamped to at least 1 so it
    /// can be used as a divisor even when the output is empty.
    fn batch_top_k_pair_count(&self) -> usize {
        self.expert_ids.len().max(1)
    }
}
41
42/// Route a batch of tokens to top-K experts.
43///
44/// `logits`: row-major `[batch, num_experts]`. Each row is the raw output
45/// of the gating linear for one token.
46///
47/// `norm_topk_prob`: if true, the K returned weights for each token are
48/// renormalised to sum to 1 (after the masked softmax) — Qwen3-MoE and
49/// Mixtral both do this. If false, they're the raw softmax probabilities,
50/// which leaves probability mass "on the floor" for unselected experts.
51///
52/// Panics if `top_k == 0` or `top_k > num_experts` or
53/// `logits.len() != batch * num_experts` — these are programming errors,
54/// not runtime conditions.
55pub fn route(
56 logits: &[f32],
57 batch: usize,
58 num_experts: usize,
59 top_k: usize,
60 norm_topk_prob: bool,
61) -> RouterOutput {
62 assert_eq!(
63 logits.len(),
64 batch * num_experts,
65 "router logits shape mismatch: expected {batch}×{num_experts}, got {}",
66 logits.len()
67 );
68 assert!(top_k > 0, "top_k must be > 0");
69 assert!(
70 top_k <= num_experts,
71 "top_k {top_k} exceeds num_experts {num_experts}"
72 );
73
74 let mut expert_ids = Vec::with_capacity(batch * top_k);
75 let mut expert_weights = Vec::with_capacity(batch * top_k);
76
77 for b in 0..batch {
78 let row = &logits[b * num_experts..(b + 1) * num_experts];
79 let probs = softmax(row);
80 let topk = top_k_indices(&probs, top_k);
81
82 // Optionally renorm the K selected weights. If norm is off we
83 // emit the raw post-softmax probs (unselected mass discarded).
84 let combine_weights = if norm_topk_prob {
85 renormalise(&topk, &probs)
86 } else {
87 topk.iter().map(|&i| probs[i]).collect::<Vec<_>>()
88 };
89
90 for (i, &exp_id) in topk.iter().enumerate() {
91 expert_ids.push(exp_id as u32);
92 expert_weights.push(combine_weights[i]);
93 }
94 }
95
96 RouterOutput {
97 expert_ids,
98 expert_weights,
99 }
100}
101
/// Numerically-stable softmax over a single row.
///
/// The row maximum is subtracted before exponentiating so large logits do
/// not overflow. An empty row yields an empty vector.
fn softmax(row: &[f32]) -> Vec<f32> {
    let peak = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|&logit| (logit - peak).exp()).collect();
    // For a non-empty row of finite logits the largest entry contributes
    // exp(0) = 1, so `total` is at least 1 and the division is well-defined.
    // A row containing NaN or +inf, or made up only of -inf, comes out as
    // all NaN.
    let total: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / total).collect()
}
113
/// Return the indices of the K largest entries, sorted by value descending,
/// breaking ties by smaller index first (stable / reproducible).
///
/// NaN entries rank below every number (and tie with each other), so the
/// ordering is a strict total order even for malformed input. If every entry
/// is NaN the result is simply the first `top_k` indices.
///
/// Returns at most `probs.len()` indices; `top_k == 0` yields an empty vector.
fn top_k_indices(probs: &[f32], top_k: usize) -> Vec<usize> {
    /// Orders `(index, value)` pairs best-first: larger value, then smaller
    /// index. No two distinct pairs compare equal because indices are unique.
    fn best_first(a: &(usize, f32), b: &(usize, f32)) -> std::cmp::Ordering {
        let by_value = match (a.1.is_nan(), b.1.is_nan()) {
            // Both are real numbers, so `partial_cmp` cannot fail; reversed
            // operands give descending order.
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
            (true, true) => std::cmp::Ordering::Equal,
            // NaN sorts after any number.
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
        };
        by_value.then_with(|| a.0.cmp(&b.0))
    }

    if top_k == 0 {
        return Vec::new();
    }

    let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();

    // Move the K best entries to the front in O(n), then order only those K.
    // The comparator has no ties, so the unstable algorithms are deterministic.
    if top_k < indexed.len() {
        indexed.select_nth_unstable_by(top_k - 1, best_first);
        indexed.truncate(top_k);
    }
    indexed.sort_unstable_by(best_first);

    indexed.into_iter().map(|(i, _)| i).collect()
}
127
/// Renormalise the K selected probabilities so they sum to 1.
///
/// `selected` holds indices into `probs`; the returned weights are in the
/// same order as `selected`. If the selected mass is not strictly positive
/// (zero, or NaN), every slot gets the uniform weight `1/K` instead.
fn renormalise(selected: &[usize], probs: &[f32]) -> Vec<f32> {
    let picked: Vec<f32> = selected.iter().map(|&idx| probs[idx]).collect();
    let total: f32 = picked.iter().sum();

    if total > 0.0 {
        picked.into_iter().map(|p| p / total).collect()
    } else {
        // Degenerate mass (shouldn't happen with finite logits): fall back
        // to a uniform split across the K slots.
        let count = picked.len();
        vec![1.0 / count as f32; count]
    }
}