1use ndarray::{Array1, Array2};
7use rand::Rng;
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone)]
12pub struct RoutingResult {
13 pub expert_indices: Vec<Vec<usize>>,
15 pub expert_weights: Vec<Vec<f32>>,
17 pub routing_probs: Array2<f32>,
19}
20
21#[derive(Debug, Clone)]
23pub struct TopKRouter {
24 pub gate_weight: Array2<f32>,
26 pub top_k: usize,
28 pub capacity_factor: f32,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct RouterConfig {
35 pub input_dim: usize,
36 pub num_experts: usize,
37 pub top_k: usize,
38 pub capacity_factor: f32,
39}
40
41impl TopKRouter {
42 pub fn new(config: &RouterConfig) -> Self {
44 let scale = (2.0 / (config.input_dim + config.num_experts) as f32).sqrt();
45 let gate_weight =
46 Array2::from_shape_fn((config.input_dim, config.num_experts), |(i, j)| {
47 ((i * config.num_experts + j) as f32 * 0.4567).sin() * scale
48 });
49
50 Self { gate_weight, top_k: config.top_k, capacity_factor: config.capacity_factor }
51 }
52
53 pub fn route(&self, input: &Array2<f32>) -> RoutingResult {
61 let batch_size = input.nrows();
62 let num_experts = self.gate_weight.ncols();
63
64 let logits = input.dot(&self.gate_weight);
66
67 let routing_probs = softmax_rows(&logits);
69
70 let capacity = capacity_limit(batch_size, self.top_k, num_experts, self.capacity_factor);
72
73 let (expert_indices, expert_weights) =
75 select_top_k_with_capacity(&routing_probs, self.top_k, capacity);
76
77 RoutingResult { expert_indices, expert_weights, routing_probs }
78 }
79}
80
81#[derive(Debug, Clone)]
86pub struct NoisyTopKRouter {
87 pub inner: TopKRouter,
89 pub noise_std: f32,
91}
92
93impl NoisyTopKRouter {
94 pub fn new(config: &RouterConfig, noise_std: f32) -> Self {
96 Self { inner: TopKRouter::new(config), noise_std }
97 }
98
99 pub fn route(&self, input: &Array2<f32>) -> RoutingResult {
101 let batch_size = input.nrows();
102 let num_experts = self.inner.gate_weight.ncols();
103
104 let mut logits = input.dot(&self.inner.gate_weight);
106
107 let mut rng = rand::rng();
109 for val in &mut logits {
110 let noise: f32 = rng.random::<f32>() * 2.0 - 1.0; *val += noise * self.noise_std;
112 }
113
114 let routing_probs = softmax_rows(&logits);
115 let capacity =
116 capacity_limit(batch_size, self.inner.top_k, num_experts, self.inner.capacity_factor);
117 let (expert_indices, expert_weights) =
118 select_top_k_with_capacity(&routing_probs, self.inner.top_k, capacity);
119
120 RoutingResult { expert_indices, expert_weights, routing_probs }
121 }
122}
123
124pub(crate) fn softmax_rows(logits: &Array2<f32>) -> Array2<f32> {
126 let mut result = logits.clone();
127 for mut row in result.rows_mut() {
128 let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
129 row.mapv_inplace(|v| (v - max_val).exp());
130 let sum: f32 = row.iter().sum();
131 if sum > 0.0 {
132 row.mapv_inplace(|v| v / sum);
133 }
134 }
135 result
136}
137
/// Maximum number of tokens a single expert may accept in one batch.
///
/// Computes `ceil(capacity_factor * batch_size * top_k / num_experts)`, i.e. the even share of
/// all token→expert assignments scaled by `capacity_factor`, and never returns less than 1.
pub(crate) fn capacity_limit(
    batch_size: usize,
    top_k: usize,
    num_experts: usize,
    capacity_factor: f32,
) -> usize {
    let total_assignments = (batch_size * top_k) as f32;
    let fair_share = capacity_factor * total_assignments / num_experts as f32;
    f32::max(fair_share.ceil(), 1.0) as usize
}
150
151fn select_top_k_with_capacity(
157 probs: &Array2<f32>,
158 top_k: usize,
159 capacity: usize,
160) -> (Vec<Vec<usize>>, Vec<Vec<f32>>) {
161 let batch_size = probs.nrows();
162 let num_experts = probs.ncols();
163 let mut expert_counts = vec![0usize; num_experts];
164 let mut all_indices = Vec::with_capacity(batch_size);
165 let mut all_weights = Vec::with_capacity(batch_size);
166
167 for i in 0..batch_size {
168 let row: Vec<f32> = probs.row(i).to_vec();
169 let (indices, weights) = assign_token_experts(&row, top_k, capacity, &mut expert_counts);
170 all_indices.push(indices);
171 all_weights.push(weights);
172 }
173
174 (all_indices, all_weights)
175}
176
177fn assign_token_experts(
179 row: &[f32],
180 top_k: usize,
181 capacity: usize,
182 expert_counts: &mut [usize],
183) -> (Vec<usize>, Vec<f32>) {
184 let mut sorted: Vec<(usize, f32)> = row.iter().copied().enumerate().collect();
185 sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
186
187 let mut indices = Vec::with_capacity(top_k);
188 let mut weights = Vec::with_capacity(top_k);
189
190 for &(expert_idx, weight) in &sorted {
191 if indices.len() >= top_k {
192 break;
193 }
194 if expert_counts[expert_idx] < capacity {
195 indices.push(expert_idx);
196 weights.push(weight);
197 expert_counts[expert_idx] += 1;
198 }
199 }
200
201 pad_assignments(&mut indices, &mut weights, top_k);
202 renormalize_weights(&mut weights);
203 (indices, weights)
204}
205
206fn pad_assignments(indices: &mut Vec<usize>, weights: &mut Vec<f32>, top_k: usize) {
208 while indices.len() < top_k {
209 if let Some(&last_idx) = indices.last() {
210 indices.push(last_idx);
211 weights.push(0.0);
212 } else {
213 indices.push(0);
214 weights.push(1.0 / top_k as f32);
215 }
216 }
217}
218
219fn renormalize_weights(weights: &mut [f32]) {
221 let sum: f32 = weights.iter().sum();
222 if sum > 0.0 {
223 for w in weights.iter_mut() {
224 *w /= sum;
225 }
226 }
227}
228
229pub(crate) fn expert_load_fractions(routing_probs: &Array2<f32>) -> Array1<f32> {
234 let num_experts = routing_probs.ncols();
235 let batch_size = routing_probs.nrows();
236 if batch_size == 0 {
237 return Array1::zeros(num_experts);
238 }
239 let col_sums = routing_probs.sum_axis(ndarray::Axis(0));
240 col_sums / batch_size as f32
241}