Skip to main content

entrenar/moe/
router.rs

1//! Gating/routing mechanisms for Mixture of Experts
2//!
3//! Provides `TopKRouter` (deterministic) and `NoisyTopKRouter` (with exploration noise)
4//! for selecting which experts process each input token.
5
6use ndarray::{Array1, Array2};
7use rand::Rng;
8use serde::{Deserialize, Serialize};
9
10/// Result of routing a batch of tokens to experts.
11#[derive(Debug, Clone)]
12pub struct RoutingResult {
13    /// Expert indices selected per token: shape [batch_size, top_k]
14    pub expert_indices: Vec<Vec<usize>>,
15    /// Gating weights per token for selected experts: shape [batch_size, top_k]
16    pub expert_weights: Vec<Vec<f32>>,
17    /// Full probability distribution over experts per token: shape [batch_size, num_experts]
18    pub routing_probs: Array2<f32>,
19}
20
21/// Deterministic top-k router: linear projection followed by softmax, then top-k selection.
22#[derive(Debug, Clone)]
23pub struct TopKRouter {
24    /// Gating weight matrix: [input_dim, num_experts]
25    pub gate_weight: Array2<f32>,
26    /// Number of experts to route each token to
27    pub top_k: usize,
28    /// Maximum fraction of tokens each expert can process
29    pub capacity_factor: f32,
30}
31
32/// Router configuration parameters.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct RouterConfig {
35    pub input_dim: usize,
36    pub num_experts: usize,
37    pub top_k: usize,
38    pub capacity_factor: f32,
39}
40
41impl TopKRouter {
42    /// Create a new top-k router with Xavier-initialized gate weights.
43    pub fn new(config: &RouterConfig) -> Self {
44        let scale = (2.0 / (config.input_dim + config.num_experts) as f32).sqrt();
45        let gate_weight =
46            Array2::from_shape_fn((config.input_dim, config.num_experts), |(i, j)| {
47                ((i * config.num_experts + j) as f32 * 0.4567).sin() * scale
48            });
49
50        Self { gate_weight, top_k: config.top_k, capacity_factor: config.capacity_factor }
51    }
52
53    /// Route a batch of input tokens to the top-k experts.
54    ///
55    /// # Arguments
56    /// * `input` - Input tensor of shape [batch_size, input_dim]
57    ///
58    /// # Returns
59    /// `RoutingResult` containing expert assignments and weights.
60    pub fn route(&self, input: &Array2<f32>) -> RoutingResult {
61        let batch_size = input.nrows();
62        let num_experts = self.gate_weight.ncols();
63
64        // Compute logits: [batch_size, num_experts] = input @ gate_weight
65        let logits = input.dot(&self.gate_weight);
66
67        // Softmax per row to get routing probabilities
68        let routing_probs = softmax_rows(&logits);
69
70        // Apply capacity factor to determine max tokens per expert
71        let capacity = capacity_limit(batch_size, self.top_k, num_experts, self.capacity_factor);
72
73        // Select top-k experts per token, respecting capacity
74        let (expert_indices, expert_weights) =
75            select_top_k_with_capacity(&routing_probs, self.top_k, capacity);
76
77        RoutingResult { expert_indices, expert_weights, routing_probs }
78    }
79}
80
81/// Noisy top-k router: adds Gaussian noise to logits before routing for exploration.
82///
83/// Based on the Switch Transformer / GShard approach where noise encourages
84/// balanced expert utilization during training.
85#[derive(Debug, Clone)]
86pub struct NoisyTopKRouter {
87    /// Underlying deterministic router
88    pub inner: TopKRouter,
89    /// Standard deviation of Gaussian noise added to logits
90    pub noise_std: f32,
91}
92
93impl NoisyTopKRouter {
94    /// Create a new noisy top-k router.
95    pub fn new(config: &RouterConfig, noise_std: f32) -> Self {
96        Self { inner: TopKRouter::new(config), noise_std }
97    }
98
99    /// Route with added Gaussian noise for exploration.
100    pub fn route(&self, input: &Array2<f32>) -> RoutingResult {
101        let batch_size = input.nrows();
102        let num_experts = self.inner.gate_weight.ncols();
103
104        // Compute logits
105        let mut logits = input.dot(&self.inner.gate_weight);
106
107        // Add Gaussian noise
108        let mut rng = rand::rng();
109        for val in &mut logits {
110            let noise: f32 = rng.random::<f32>() * 2.0 - 1.0; // Uniform approximation
111            *val += noise * self.noise_std;
112        }
113
114        let routing_probs = softmax_rows(&logits);
115        let capacity =
116            capacity_limit(batch_size, self.inner.top_k, num_experts, self.inner.capacity_factor);
117        let (expert_indices, expert_weights) =
118            select_top_k_with_capacity(&routing_probs, self.inner.top_k, capacity);
119
120        RoutingResult { expert_indices, expert_weights, routing_probs }
121    }
122}
123
124/// Compute row-wise softmax of a 2D array.
125pub(crate) fn softmax_rows(logits: &Array2<f32>) -> Array2<f32> {
126    let mut result = logits.clone();
127    for mut row in result.rows_mut() {
128        let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
129        row.mapv_inplace(|v| (v - max_val).exp());
130        let sum: f32 = row.iter().sum();
131        if sum > 0.0 {
132            row.mapv_inplace(|v| v / sum);
133        }
134    }
135    result
136}
137
/// Per-expert capacity: how many token assignments a single expert may accept in one batch.
///
/// Evaluates `ceil(capacity_factor * batch_size * top_k / num_experts)` in `f32` and never
/// returns less than 1.
pub(crate) fn capacity_limit(
    batch_size: usize,
    top_k: usize,
    num_experts: usize,
    capacity_factor: f32,
) -> usize {
    let total_assignments = (batch_size * top_k) as f32;
    // Multiply by the factor first, then divide by the expert count.
    let scaled_share = capacity_factor * total_assignments / num_experts as f32;
    let rounded_up = scaled_share.ceil();
    f32::max(rounded_up, 1.0) as usize
}
150
151/// Select top-k experts per token, enforcing a per-expert capacity limit.
152///
153/// Returns (expert_indices, expert_weights) where each inner Vec has length top_k.
154/// When an expert is at capacity, the token's assignment falls through to the next
155/// highest-scoring expert.
156fn select_top_k_with_capacity(
157    probs: &Array2<f32>,
158    top_k: usize,
159    capacity: usize,
160) -> (Vec<Vec<usize>>, Vec<Vec<f32>>) {
161    let batch_size = probs.nrows();
162    let num_experts = probs.ncols();
163    let mut expert_counts = vec![0usize; num_experts];
164    let mut all_indices = Vec::with_capacity(batch_size);
165    let mut all_weights = Vec::with_capacity(batch_size);
166
167    for i in 0..batch_size {
168        let row: Vec<f32> = probs.row(i).to_vec();
169        let (indices, weights) = assign_token_experts(&row, top_k, capacity, &mut expert_counts);
170        all_indices.push(indices);
171        all_weights.push(weights);
172    }
173
174    (all_indices, all_weights)
175}
176
177/// Assign top-k experts for a single token, respecting capacity limits.
178fn assign_token_experts(
179    row: &[f32],
180    top_k: usize,
181    capacity: usize,
182    expert_counts: &mut [usize],
183) -> (Vec<usize>, Vec<f32>) {
184    let mut sorted: Vec<(usize, f32)> = row.iter().copied().enumerate().collect();
185    sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
186
187    let mut indices = Vec::with_capacity(top_k);
188    let mut weights = Vec::with_capacity(top_k);
189
190    for &(expert_idx, weight) in &sorted {
191        if indices.len() >= top_k {
192            break;
193        }
194        if expert_counts[expert_idx] < capacity {
195            indices.push(expert_idx);
196            weights.push(weight);
197            expert_counts[expert_idx] += 1;
198        }
199    }
200
201    pad_assignments(&mut indices, &mut weights, top_k);
202    renormalize_weights(&mut weights);
203    (indices, weights)
204}
205
/// Fill `indices`/`weights` up to `top_k` slots when capacity limits left them short.
///
/// If nothing was assigned, the first slot becomes expert 0 with weight `1 / top_k`.
/// Every remaining slot repeats the last assigned expert with weight `0.0`.
fn pad_assignments(indices: &mut Vec<usize>, weights: &mut Vec<f32>, top_k: usize) {
    if indices.len() >= top_k {
        return;
    }
    if indices.is_empty() {
        indices.push(0);
        weights.push(1.0 / top_k as f32);
    }
    let filler = indices.last().copied().unwrap_or(0);
    let missing = top_k - indices.len();
    indices.resize(top_k, filler);
    weights.resize(weights.len() + missing, 0.0);
}
218
/// Rescale `weights` in place so they sum to 1.0.
///
/// If the total is zero, negative, or NaN there is no sensible scale, so the slice is
/// left unchanged.
fn renormalize_weights(weights: &mut [f32]) {
    let total = weights.iter().sum::<f32>();
    let scalable = total > 0.0;
    if !scalable {
        return;
    }
    for slot in weights {
        *slot /= total;
    }
}
228
229/// Compute the fraction of tokens routed to each expert.
230///
231/// Returns an Array1 of length num_experts with the fraction of total routing
232/// probability assigned to each expert.
233pub(crate) fn expert_load_fractions(routing_probs: &Array2<f32>) -> Array1<f32> {
234    let num_experts = routing_probs.ncols();
235    let batch_size = routing_probs.nrows();
236    if batch_size == 0 {
237        return Array1::zeros(num_experts);
238    }
239    let col_sums = routing_probs.sum_axis(ndarray::Axis(0));
240    col_sums / batch_size as f32
241}