Skip to main content

ferrum_models/moe/
router.rs

1//! MoE router (gating) — pick top-K experts per token.
2//!
3//! Given router logits of shape `[batch, num_experts]` (output of the small
4//! gating linear), produce per-token expert indices + combine weights:
5//!
6//!   1. Softmax over each row (so all probs are non-negative and sum to 1).
7//!   2. Take the K highest-probability experts.
8//!   3. Optionally renormalise those K probs so they sum back to 1
9//!      (Qwen3-MoE / Mixtral default; some legacy variants don't).
10//!
11//! Output layout is **flat with stride `top_k`**: `expert_ids[b*K + k]`
12//! is the k-th selected expert for token b, and `expert_weights[b*K + k]`
13//! is its combine weight. That matches how the dispatch loop iterates.
14
/// Result of routing one batch: two parallel flat arrays indexed `[b * top_k + k]`.
///
/// `top_k` is not stored in this struct. Code that needs the number of
/// routed tokens must supply the `top_k` it routed with; see
/// [`RouterOutput::batch_with_top_k`].
#[derive(Debug, Clone, PartialEq)]
pub struct RouterOutput {
    /// Selected expert indices. `expert_ids[b * top_k + k] ∈ [0, num_experts)`.
    pub expert_ids: Vec<u32>,
    /// Combine weights. Same shape as `expert_ids`. If
    /// `norm_topk_prob` was true, the K weights for each token sum to 1;
    /// otherwise they're the raw (post-softmax) probabilities of the
    /// selected experts.
    pub expert_weights: Vec<f32>,
}

impl RouterOutput {
    /// Number of tokens routed, given the `top_k` this output was produced with.
    ///
    /// This is the accessor to use when the real token count is needed:
    /// the flat arrays hold `batch * top_k` entries, so the count is
    /// `expert_ids.len() / top_k`.
    ///
    /// # Panics
    ///
    /// Panics if `top_k == 0`, or if `expert_ids.len()` is not a multiple
    /// of `top_k` (the output was not produced with that `top_k`).
    pub fn batch_with_top_k(&self, top_k: usize) -> usize {
        assert!(top_k > 0, "top_k must be > 0");
        let entries = self.expert_ids.len();
        assert!(
            entries % top_k == 0,
            "router output length {entries} is not a multiple of top_k {top_k}"
        );
        entries / top_k
    }

    /// Legacy accessor, kept so existing callers keep compiling.
    ///
    /// Because `top_k` is not stored, this method cannot recover the real
    /// token count. It divides the entry count by itself (clamped to at
    /// least 1), so it returns `0` for an empty output and `1` for every
    /// non-empty output. Use [`RouterOutput::batch_with_top_k`] to get the
    /// actual number of tokens.
    pub fn batch(&self) -> usize {
        self.expert_ids.len() / self.batch_top_k_pair_count()
    }

    /// Total number of `(token, expert)` entries, clamped to at least 1 so
    /// it can be used as a divisor for an empty output.
    fn batch_top_k_pair_count(&self) -> usize {
        self.expert_ids.len().max(1)
    }
}
41
42/// Route a batch of tokens to top-K experts.
43///
44/// `logits`: row-major `[batch, num_experts]`. Each row is the raw output
45/// of the gating linear for one token.
46///
47/// `norm_topk_prob`: if true, the K returned weights for each token are
48/// renormalised to sum to 1 (after the masked softmax) — Qwen3-MoE and
49/// Mixtral both do this. If false, they're the raw softmax probabilities,
50/// which leaves probability mass "on the floor" for unselected experts.
51///
52/// Panics if `top_k == 0` or `top_k > num_experts` or
53/// `logits.len() != batch * num_experts` — these are programming errors,
54/// not runtime conditions.
55pub fn route(
56    logits: &[f32],
57    batch: usize,
58    num_experts: usize,
59    top_k: usize,
60    norm_topk_prob: bool,
61) -> RouterOutput {
62    assert_eq!(
63        logits.len(),
64        batch * num_experts,
65        "router logits shape mismatch: expected {batch}×{num_experts}, got {}",
66        logits.len()
67    );
68    assert!(top_k > 0, "top_k must be > 0");
69    assert!(
70        top_k <= num_experts,
71        "top_k {top_k} exceeds num_experts {num_experts}"
72    );
73
74    let mut expert_ids = Vec::with_capacity(batch * top_k);
75    let mut expert_weights = Vec::with_capacity(batch * top_k);
76
77    for b in 0..batch {
78        let row = &logits[b * num_experts..(b + 1) * num_experts];
79        let probs = softmax(row);
80        let topk = top_k_indices(&probs, top_k);
81
82        // Optionally renorm the K selected weights. If norm is off we
83        // emit the raw post-softmax probs (unselected mass discarded).
84        let combine_weights = if norm_topk_prob {
85            renormalise(&topk, &probs)
86        } else {
87            topk.iter().map(|&i| probs[i]).collect::<Vec<_>>()
88        };
89
90        for (i, &exp_id) in topk.iter().enumerate() {
91            expert_ids.push(exp_id as u32);
92            expert_weights.push(combine_weights[i]);
93        }
94    }
95
96    RouterOutput {
97        expert_ids,
98        expert_weights,
99    }
100}
101
/// Numerically-stable softmax over a single row.
///
/// The row maximum is subtracted before exponentiating so large logits do
/// not overflow. Returns one probability per input element; an empty row
/// yields an empty vector.
///
/// Non-finite inputs:
/// - If the maximum is `+inf`, the mass is split evenly across the `+inf`
///   entries and every other entry gets 0 (the limit of the softmax).
/// - If every entry is `-inf`, the result is uniform.
/// - If any entry is NaN, every output is NaN.
fn softmax(row: &[f32]) -> Vec<f32> {
    // `f32::max` returns the non-NaN operand, so NaN entries are skipped here.
    let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    // With an infinite maximum, `x - max` is NaN for the entries equal to
    // it (inf - inf), which would turn the whole row into NaN. Handle that
    // case directly. Rows containing NaN fall through and stay NaN.
    if max.is_infinite() && !row.iter().any(|x| x.is_nan()) {
        let ties = row.iter().filter(|&&x| x == max).count();
        let share = 1.0 / ties as f32;
        return row
            .iter()
            .map(|&x| if x == max { share } else { 0.0 })
            .collect();
    }

    let mut exp: Vec<f32> = row.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exp.iter().sum();
    // For a finite, NaN-free row the maximum contributes exp(0) = 1, so
    // `sum >= 1` and the division is well defined.
    for v in &mut exp {
        *v /= sum;
    }
    exp
}
113
/// Pick the positions of the `top_k` largest values in `probs`.
///
/// The result is ordered from largest value to smallest. Entries that
/// compare equal (or cannot be compared) are ordered by ascending index,
/// which keeps the selection reproducible. If `top_k` exceeds
/// `probs.len()`, every index is returned.
fn top_k_indices(probs: &[f32], top_k: usize) -> Vec<usize> {
    use std::cmp::Ordering;

    // Carry each value alongside its original position through the sort.
    let mut ranked: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();

    // Descending by value; fall back to the index when the values tie or
    // the comparison is undefined.
    ranked.sort_by(|&(idx_a, p_a), &(idx_b, p_b)| match p_b.partial_cmp(&p_a) {
        Some(Ordering::Equal) | None => idx_a.cmp(&idx_b),
        Some(order) => order,
    });

    ranked
        .into_iter()
        .take(top_k)
        .map(|(idx, _)| idx)
        .collect()
}
127
/// Rescale the probabilities at `selected` so that they add up to 1.
///
/// `selected` holds indices into `probs`; the returned weights are in the
/// same order as `selected`. When the selected mass is not strictly
/// positive (zero or NaN), every selected entry gets the uniform weight
/// `1 / selected.len()` instead.
fn renormalise(selected: &[usize], probs: &[f32]) -> Vec<f32> {
    // Gather the chosen probabilities once; both branches need them.
    let picked: Vec<f32> = selected.iter().map(|&idx| probs[idx]).collect();
    let total: f32 = picked.iter().sum();

    // Degenerate mass: fall back to uniform 1/K.
    if total.is_nan() || total <= 0.0 {
        let count = picked.len();
        return vec![1.0 / count as f32; count];
    }

    picked.into_iter().map(|p| p / total).collect()
}