use ndarray::{Array1, Array2};
use rand::Rng;
use serde::{Deserialize, Serialize};
/// Output of one routing pass over a batch of tokens.
#[derive(Debug, Clone)]
pub struct RoutingResult {
/// Per token: the expert indices selected for it (padded to `top_k` entries).
pub expert_indices: Vec<Vec<usize>>,
/// Per token: the renormalized weight of each selected expert; parallel to
/// `expert_indices` (padding entries carry weight 0.0).
pub expert_weights: Vec<Vec<f32>>,
/// Full softmax distribution over experts, shape (batch, num_experts).
pub routing_probs: Array2<f32>,
}
/// Deterministic top-k gating router: projects inputs through `gate_weight`,
/// softmaxes over experts, then keeps the `top_k` most probable experts per
/// token subject to a per-expert capacity limit.
#[derive(Debug, Clone)]
pub struct TopKRouter {
/// Gating projection matrix, shape (input_dim, num_experts).
pub gate_weight: Array2<f32>,
/// Number of experts selected per token.
pub top_k: usize,
/// Multiplier applied to the even-split per-expert load when computing the
/// capacity cap (see `capacity_limit`).
pub capacity_factor: f32,
}
/// Serializable configuration for constructing a router.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterConfig {
/// Dimensionality of the incoming token representations.
pub input_dim: usize,
/// Total number of experts routed over.
pub num_experts: usize,
/// Experts kept per token.
pub top_k: usize,
/// Capacity multiplier; 1.0 means each expert gets exactly its even share.
pub capacity_factor: f32,
}
impl TopKRouter {
    /// Build a router from `config`.
    ///
    /// The gate is filled with a deterministic sinusoid (no RNG involved)
    /// scaled by a Xavier-style factor sqrt(2 / (fan_in + fan_out)).
    pub fn new(config: &RouterConfig) -> Self {
        let fan = (config.input_dim + config.num_experts) as f32;
        let scale = (2.0 / fan).sqrt();
        let shape = (config.input_dim, config.num_experts);
        let gate_weight = Array2::from_shape_fn(shape, |(row, col)| {
            let flat = row * config.num_experts + col;
            (flat as f32 * 0.4567).sin() * scale
        });
        Self {
            gate_weight,
            top_k: config.top_k,
            capacity_factor: config.capacity_factor,
        }
    }

    /// Route a batch of inputs, shape (batch, input_dim), to experts.
    ///
    /// Computes gating logits, softmaxes them row-wise, and greedily assigns
    /// each token its top-k experts under the shared capacity budget.
    pub fn route(&self, input: &Array2<f32>) -> RoutingResult {
        let routing_probs = softmax_rows(&input.dot(&self.gate_weight));
        let capacity = capacity_limit(
            input.nrows(),
            self.top_k,
            self.gate_weight.ncols(),
            self.capacity_factor,
        );
        let (expert_indices, expert_weights) =
            select_top_k_with_capacity(&routing_probs, self.top_k, capacity);
        RoutingResult {
            expert_indices,
            expert_weights,
            routing_probs,
        }
    }
}
/// Top-k router that perturbs the gating logits with uniform noise in
/// [-noise_std, noise_std) before selection.
#[derive(Debug, Clone)]
pub struct NoisyTopKRouter {
/// The underlying deterministic router.
pub inner: TopKRouter,
/// Half-width of the uniform noise added to each logit.
pub noise_std: f32,
}
impl NoisyTopKRouter {
    /// Wrap a fresh `TopKRouter` built from `config` with the given noise scale.
    pub fn new(config: &RouterConfig, noise_std: f32) -> Self {
        Self {
            inner: TopKRouter::new(config),
            noise_std,
        }
    }

    /// Route like `TopKRouter::route`, but add uniform noise drawn from
    /// [-noise_std, noise_std) to every gating logit before the softmax.
    pub fn route(&self, input: &Array2<f32>) -> RoutingResult {
        let router = &self.inner;
        let mut logits = input.dot(&router.gate_weight);
        let mut rng = rand::rng();
        // rng.random::<f32>() is uniform in [0, 1); rescale to [-1, 1).
        logits.mapv_inplace(|logit| {
            let jitter = rng.random::<f32>() * 2.0 - 1.0;
            logit + jitter * self.noise_std
        });
        let routing_probs = softmax_rows(&logits);
        let capacity = capacity_limit(
            input.nrows(),
            router.top_k,
            router.gate_weight.ncols(),
            router.capacity_factor,
        );
        let (expert_indices, expert_weights) =
            select_top_k_with_capacity(&routing_probs, router.top_k, capacity);
        RoutingResult {
            expert_indices,
            expert_weights,
            routing_probs,
        }
    }
}
/// Numerically stable row-wise softmax.
///
/// Each row is shifted by its maximum before exponentiation. A row whose
/// exponentials sum to zero (e.g. all -inf inputs) is left as all zeros
/// rather than producing NaNs from a 0/0 division.
pub(crate) fn softmax_rows(logits: &Array2<f32>) -> Array2<f32> {
    let mut out = logits.clone();
    for mut row in out.rows_mut() {
        let peak = row.fold(f32::NEG_INFINITY, |acc, &v| acc.max(v));
        row.mapv_inplace(|v| (v - peak).exp());
        let total = row.sum();
        if total > 0.0 {
            row.mapv_inplace(|v| v / total);
        }
    }
    out
}
/// Per-expert token budget for a routing pass.
///
/// Returns ceil(capacity_factor * batch_size * top_k / num_experts),
/// clamped to at least one slot so every expert can accept something.
pub(crate) fn capacity_limit(
    batch_size: usize,
    top_k: usize,
    num_experts: usize,
    capacity_factor: f32,
) -> usize {
    let assignments = (batch_size * top_k) as f32;
    let per_expert = capacity_factor * assignments / num_experts as f32;
    per_expert.ceil().max(1.0) as usize
}
/// Assign every token in `probs` (shape batch x experts) to its top-k
/// experts, skipping any expert that has already reached `capacity`.
///
/// Tokens are processed in batch order, so earlier tokens claim capacity
/// slots first; `expert_counts` accumulates assignments across the batch.
fn select_top_k_with_capacity(
    probs: &Array2<f32>,
    top_k: usize,
    capacity: usize,
) -> (Vec<Vec<usize>>, Vec<Vec<f32>>) {
    let mut expert_counts = vec![0usize; probs.ncols()];
    probs
        .rows()
        .into_iter()
        .map(|row| {
            let dense: Vec<f32> = row.to_vec();
            assign_token_experts(&dense, top_k, capacity, &mut expert_counts)
        })
        .unzip()
}
/// Pick up to `top_k` experts for one token from its probability row,
/// honoring the shared per-expert `capacity` budget.
///
/// Experts are considered in descending probability order; an expert that
/// already holds `capacity` tokens is skipped. If fewer than `top_k` slots
/// could be filled the result is padded (see `pad_assignments`), and the
/// surviving weights are renormalized to sum to 1.
fn assign_token_experts(
    row: &[f32],
    top_k: usize,
    capacity: usize,
    expert_counts: &mut [usize],
) -> (Vec<usize>, Vec<f32>) {
    let mut ranked: Vec<(usize, f32)> = row.iter().copied().enumerate().collect();
    // total_cmp is a true total order: the old partial_cmp/unwrap_or(Equal)
    // comparator violated sort_by's strict-weak-ordering contract whenever a
    // NaN appeared, making the resulting order arbitrary. For NaN-free input
    // the ordering is identical (stable sort, ties keep index order).
    ranked.sort_by(|a, b| b.1.total_cmp(&a.1));
    let mut indices = Vec::with_capacity(top_k);
    let mut weights = Vec::with_capacity(top_k);
    for &(expert_idx, prob) in &ranked {
        if indices.len() >= top_k {
            break;
        }
        // Skip experts whose capacity is already exhausted.
        if expert_counts[expert_idx] < capacity {
            indices.push(expert_idx);
            weights.push(prob);
            expert_counts[expert_idx] += 1;
        }
    }
    pad_assignments(&mut indices, &mut weights, top_k);
    renormalize_weights(&mut weights);
    (indices, weights)
}
/// Pad a token's assignment lists out to exactly `top_k` entries.
///
/// If nothing was assigned at all (every expert at capacity), a single
/// fallback entry for expert 0 with weight 1/top_k is seeded first; any
/// remaining slots repeat the last assigned index with weight 0.0 so the
/// parallel vectors stay a fixed length without shifting real mass.
fn pad_assignments(indices: &mut Vec<usize>, weights: &mut Vec<f32>, top_k: usize) {
    if indices.is_empty() && top_k > 0 {
        indices.push(0);
        weights.push(1.0 / top_k as f32);
    }
    if let Some(&fill) = indices.last() {
        while indices.len() < top_k {
            indices.push(fill);
            weights.push(0.0);
        }
    }
}
/// Scale `weights` in place so they sum to 1.
///
/// A non-positive (or NaN) total leaves the slice untouched, avoiding a
/// division by zero when every weight is 0.
fn renormalize_weights(weights: &mut [f32]) {
    let total: f32 = weights.iter().sum();
    if total > 0.0 {
        weights.iter_mut().for_each(|w| *w /= total);
    }
}
/// Average routing probability each expert receives across the batch.
///
/// Column-sums `routing_probs` (shape batch x experts) and divides by the
/// batch size; an empty batch yields all zeros instead of dividing by zero.
pub(crate) fn expert_load_fractions(routing_probs: &Array2<f32>) -> Array1<f32> {
    let n_tokens = routing_probs.nrows();
    if n_tokens == 0 {
        return Array1::zeros(routing_probs.ncols());
    }
    routing_probs.sum_axis(ndarray::Axis(0)) / n_tokens as f32
}