use std::fmt;
/// Errors returned by the MoE routers in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum MoeRoutingError {
    /// The hidden-state buffer passed to a router was empty.
    EmptyInput,
    /// `num_experts` was unusable; the payload explains why.
    InvalidNumExperts(String),
    /// `capacity_factor` was unusable; the payload explains why.
    InvalidCapacityFactor(String),
    /// A buffer or dimension did not have the size the router expected.
    DimensionMismatch {
        /// Size the router required.
        expected: usize,
        /// Size that was actually supplied.
        got: usize,
    },
}

impl fmt::Display for MoeRoutingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EmptyInput => f.write_str("empty input: no tokens or hidden states provided"),
            Self::InvalidNumExperts(detail) => write!(f, "invalid num_experts: {}", detail),
            Self::InvalidCapacityFactor(detail) => {
                write!(f, "invalid capacity_factor: {}", detail)
            }
            Self::DimensionMismatch { expected, got } => {
                write!(f, "dimension mismatch: expected {}, got {}", expected, got)
            }
        }
    }
}

impl std::error::Error for MoeRoutingError {}
/// Routing-strategy selector; each variant carries that strategy's configuration.
///
/// NOTE(review): nothing in this section consumes `RouterType`. The variants
/// presumably correspond to `ExpertChoiceRouter`, `SwitchTransformerRouter` and
/// `HashRouter`, plus top-k and random routers not visible here — confirm
/// against the dispatch site.
#[derive(Debug, Clone, PartialEq)]
pub enum RouterType {
    /// Top-`k` routing (presumably token-choice; no implementation in this section).
    TopK { k: usize },
    /// Expert-choice routing; `capacity_factor` presumably mirrors
    /// `ExpertChoiceRouter::capacity_factor` — confirm.
    ExpertChoice { capacity_factor: f32 },
    /// Switch-Transformer top-1 routing (no configuration payload).
    SwitchTransformer,
    /// Hash-based routing over `num_experts` experts.
    Hash { num_experts: usize },
    /// Random routing seeded by `seed` (no implementation in this section).
    RandomRouter { seed: u64 },
}
/// Softmax over each column of a row-major `rows × cols` matrix.
///
/// Every processed column of the result sums to 1. The column maximum is
/// subtracted before exponentiating for numerical stability. When the
/// normalizer is not greater than `1e-10` (this includes NaN, e.g. from an
/// all-`-inf` column), the column is filled with the uniform value `1 / rows`.
/// Elements of `matrix` beyond index `rows * cols` are copied through unchanged.
fn column_softmax(matrix: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut result = matrix.to_vec();
    for col in 0..cols {
        let at = |row: usize| row * cols + col;
        // Largest entry of this column; NaN entries never win the `>` test.
        let peak = (0..rows)
            .map(|row| matrix[at(row)])
            .fold(f32::NEG_INFINITY, |best, v| if v > best { v } else { best });
        let mut total = 0.0_f32;
        for row in 0..rows {
            let shifted = (matrix[at(row)] - peak).exp();
            result[at(row)] = shifted;
            total += shifted;
        }
        if total > 1e-10 {
            (0..rows).for_each(|row| result[at(row)] /= total);
        } else {
            let uniform = 1.0 / rows as f32;
            (0..rows).for_each(|row| result[at(row)] = uniform);
        }
    }
    result
}
/// Softmax over each row of a row-major `rows × cols` matrix.
///
/// Each output row sums to 1. The row maximum is subtracted before
/// exponentiating for numerical stability. If the normalizer is degenerate
/// (not greater than `1e-10`, which includes NaN produced by all-`-inf` or any
/// `+inf` logits), the row falls back to the uniform distribution `1 / cols`.
/// This is the same policy `column_softmax` uses, instead of emitting NaN
/// probabilities.
///
/// # Panics
/// Panics if `matrix` holds fewer than `rows * cols` elements.
fn row_softmax(matrix: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut out = vec![0.0_f32; rows * cols];
    for row in 0..rows {
        let base = row * cols;
        let src = &matrix[base..base + cols];
        let dst = &mut out[base..base + cols];
        // `f32::max` ignores NaN operands, so NaN never becomes the shift.
        let max_val = src.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0_f32;
        for (d, &x) in dst.iter_mut().zip(src) {
            *d = (x - max_val).exp();
            sum += *d;
        }
        if sum > 1e-10 {
            for d in dst.iter_mut() {
                *d /= sum;
            }
        } else {
            // Only reachable for degenerate rows (a finite row always has sum >= 1).
            let uniform = 1.0 / cols as f32;
            dst.fill(uniform);
        }
    }
    out
}
/// Dense row-major matrix product `C = A · B`.
///
/// `a` is `m × k`, `b` is `k × n`, and the returned `c` is `m × n`, all
/// row-major. Each output row is accumulated over `p = 0..k` in order.
///
/// # Panics
/// Panics (out-of-bounds slice) if `a` or `b` is shorter than its dimensions imply.
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut product = vec![0.0_f32; m * n];
    for row in 0..m {
        let lhs_row = &a[row * k..(row + 1) * k];
        let out_row = &mut product[row * n..(row + 1) * n];
        for (inner, &scale) in lhs_row.iter().enumerate() {
            let rhs_row = &b[inner * n..(inner + 1) * n];
            for (acc, &rhs) in out_row.iter_mut().zip(rhs_row) {
                *acc += scale * rhs;
            }
        }
    }
    product
}
/// Returns the indices of the `k` largest values of `vals`, largest first.
///
/// `k` is clamped to `vals.len()`. This is a partial selection sort: slot `i`
/// receives the strictly-largest remaining value, and the earliest remaining
/// position wins ties. Because each pick swaps positions, the order among
/// equal values follows that swap sequence rather than plain index order.
/// NaN values never win a comparison.
fn top_k_desc(vals: &[f32], k: usize) -> Vec<usize> {
    let take = k.min(vals.len());
    let mut order: Vec<usize> = (0..vals.len()).collect();
    for slot in 0..take {
        let winner = (slot + 1..order.len()).fold(slot, |best, cand| {
            if vals[order[cand]] > vals[order[best]] {
                cand
            } else {
                best
            }
        });
        order.swap(slot, winner);
    }
    order.truncate(take);
    order
}
/// Result of expert-choice routing, as produced by `ExpertChoiceRouter::route`.
#[derive(Debug, Clone)]
pub struct ExpertChoiceAssignment {
    /// For each expert, the token indices it selected (highest score first).
    pub expert_token_idx: Vec<Vec<usize>>,
    /// For each expert, the softmax score of each selected token, aligned with
    /// `expert_token_idx`.
    pub expert_weights: Vec<Vec<f32>>,
    /// Per-expert token budget used for the selection.
    pub capacity: usize,
}

impl ExpertChoiceAssignment {
    /// Number of tokens each expert actually selected, indexed by expert.
    pub fn tokens_per_expert(&self) -> Vec<usize> {
        self.expert_token_idx.iter().map(Vec::len).collect()
    }
}
/// Expert-choice router: instead of each token picking experts, each expert
/// picks its `capacity` highest-scoring tokens.
///
/// `router_weights` is a row-major `hidden_size × num_experts` matrix that is
/// multiplied on the right of the row-major `seq_len × hidden_size` hidden states.
pub struct ExpertChoiceRouter {
    /// Number of experts; each expert selects its own tokens.
    pub num_experts: usize,
    /// Scales the per-expert budget `seq_len / num_experts` (must be > 0).
    pub capacity_factor: f32,
    /// Row-major `hidden_size × num_experts` projection from hidden states to affinities.
    pub router_weights: Vec<f32>,
    /// Hidden dimension that `router_weights` was built for.
    pub hidden_size: usize,
}

impl ExpertChoiceRouter {
    /// Creates a router with deterministic weights `0.1 * sin((i + 1) / total)`
    /// for flat weight index `i` (no randomness involved).
    pub fn new(num_experts: usize, hidden_size: usize, capacity_factor: f32) -> Self {
        let total = hidden_size * num_experts;
        let router_weights: Vec<f32> = (0..total)
            .map(|i| ((i as f32 + 1.0) / total as f32).sin() * 0.1)
            .collect();
        Self {
            num_experts,
            capacity_factor,
            router_weights,
            hidden_size,
        }
    }

    /// Computes the raw row-major `seq_len × num_experts` affinity matrix
    /// `hidden · weights`.
    ///
    /// # Panics
    /// Panics (out-of-bounds index) if `hidden` or `weights` is shorter than the
    /// given dimensions imply.
    pub fn compute_affinity(
        hidden: &[f32],
        weights: &[f32],
        seq_len: usize,
        hidden_size: usize,
        num_experts: usize,
    ) -> Vec<f32> {
        matmul(hidden, weights, seq_len, hidden_size, num_experts)
    }

    /// Routes `seq_len` tokens by letting every expert choose its top tokens.
    ///
    /// `hidden_states` is row-major `seq_len × hidden_size`. Affinities are
    /// softmax-normalised over tokens for each expert (column softmax). Each
    /// expert then takes its `capacity = max(1, floor(capacity_factor * seq_len / num_experts))`
    /// best tokens, or all tokens if `capacity` exceeds `seq_len`. A token may
    /// be chosen by several experts or by none.
    ///
    /// # Errors
    /// - [`MoeRoutingError::EmptyInput`] if `hidden_states` is empty.
    /// - [`MoeRoutingError::InvalidNumExperts`] if `num_experts == 0`.
    /// - [`MoeRoutingError::InvalidCapacityFactor`] if `capacity_factor` is NaN or ≤ 0.
    /// - [`MoeRoutingError::DimensionMismatch`] in three cases: when
    ///   `hidden_states.len() != seq_len * hidden_size`; when `hidden_size`
    ///   differs from `self.hidden_size`; or when `router_weights` does not hold
    ///   `hidden_size * num_experts` values.
    pub fn route(
        &self,
        hidden_states: &[f32],
        seq_len: usize,
        hidden_size: usize,
    ) -> Result<ExpertChoiceAssignment, MoeRoutingError> {
        if hidden_states.is_empty() {
            return Err(MoeRoutingError::EmptyInput);
        }
        if self.num_experts == 0 {
            return Err(MoeRoutingError::InvalidNumExperts(
                "num_experts must be ≥ 1".to_string(),
            ));
        }
        // A bare `<= 0.0` test lets NaN through, since every comparison with NaN is false.
        if self.capacity_factor.is_nan() || self.capacity_factor <= 0.0 {
            return Err(MoeRoutingError::InvalidCapacityFactor(
                "capacity_factor must be > 0".to_string(),
            ));
        }
        // Saturate so an overflowing product cannot wrap to a value that happens
        // to equal `hidden_states.len()` (a slice of f32 can never be usize::MAX long).
        let expected_len = seq_len.saturating_mul(hidden_size);
        if hidden_states.len() != expected_len {
            return Err(MoeRoutingError::DimensionMismatch {
                expected: expected_len,
                got: hidden_states.len(),
            });
        }
        // The weights were built for `self.hidden_size`. Any other width would make
        // `matmul` index past the end (panic) or silently read the wrong weights.
        if hidden_size != self.hidden_size {
            return Err(MoeRoutingError::DimensionMismatch {
                expected: self.hidden_size,
                got: hidden_size,
            });
        }
        // The fields are public, so the weight buffer may have been replaced.
        let expected_weights = hidden_size.saturating_mul(self.num_experts);
        if self.router_weights.len() != expected_weights {
            return Err(MoeRoutingError::DimensionMismatch {
                expected: expected_weights,
                got: self.router_weights.len(),
            });
        }

        let capacity =
            ((self.capacity_factor * seq_len as f32 / self.num_experts as f32).floor() as usize)
                .max(1);
        let raw_affinity = Self::compute_affinity(
            hidden_states,
            &self.router_weights,
            seq_len,
            hidden_size,
            self.num_experts,
        );
        let soft_affinity = column_softmax(&raw_affinity, seq_len, self.num_experts);

        let mut expert_token_idx: Vec<Vec<usize>> = Vec::with_capacity(self.num_experts);
        let mut expert_weights: Vec<Vec<f32>> = Vec::with_capacity(self.num_experts);
        for e in 0..self.num_experts {
            // Column `e`: this expert's normalised score for every token.
            let scores: Vec<f32> = (0..seq_len)
                .map(|t| soft_affinity[t * self.num_experts + e])
                .collect();
            let chosen = top_k_desc(&scores, capacity);
            expert_weights.push(chosen.iter().map(|&t| scores[t]).collect());
            expert_token_idx.push(chosen);
        }
        Ok(ExpertChoiceAssignment {
            expert_token_idx,
            expert_weights,
            capacity,
        })
    }
}
/// Stateless router that assigns each token id to an expert by hashing it.
///
/// The mapping depends only on the token id and the expert count, so it is
/// deterministic and needs no learned weights.
pub struct HashRouter {
    /// Number of experts the token ids are spread across.
    pub num_experts: usize,
}

impl HashRouter {
    /// Creates a hash router over `num_experts` experts.
    pub fn new(num_experts: usize) -> Self {
        HashRouter { num_experts }
    }

    /// Maps `token_id` to an expert index.
    ///
    /// The id is hashed with 32-bit FNV-1a over its four little-endian bytes,
    /// then reduced modulo `num_experts`. A `num_experts` of 0 is treated as 1,
    /// so the result is then always 0.
    /// NOTE(review): 0 is not a valid expert when there are no experts;
    /// callers presumably never pass 0 — confirm.
    pub fn route_token(token_id: u32, num_experts: usize) -> usize {
        const FNV_OFFSET: u32 = 2_166_136_261;
        const FNV_PRIME: u32 = 16_777_619;
        let digest = token_id
            .to_le_bytes()
            .iter()
            .fold(FNV_OFFSET, |acc, &byte| {
                (acc ^ u32::from(byte)).wrapping_mul(FNV_PRIME)
            });
        let buckets = num_experts.max(1);
        (digest as usize) % buckets
    }

    /// Routes every id in `token_ids` with [`HashRouter::route_token`], preserving order.
    pub fn route_batch(&self, token_ids: &[u32]) -> Vec<usize> {
        let mut routed = Vec::with_capacity(token_ids.len());
        for &id in token_ids {
            routed.push(Self::route_token(id, self.num_experts));
        }
        routed
    }
}
/// Result of top-1 routing, as produced by `SwitchTransformerRouter::route`.
#[derive(Debug, Clone)]
pub struct SwitchAssignment {
    /// Per token: `Some(expert)` if accepted, `None` if dropped for lack of capacity.
    pub expert_assignments: Vec<Option<usize>>,
    /// Number of dropped tokens.
    pub num_dropped: usize,
    /// Number of tokens accepted by each expert, indexed by expert.
    pub expert_load: Vec<usize>,
}

impl SwitchAssignment {
    /// Fraction of tokens dropped: `num_dropped / expert_assignments.len()`.
    ///
    /// Uses the stored `num_dropped` counter rather than counting `None`
    /// entries, and returns 0.0 when there are no tokens.
    pub fn drop_rate(&self) -> f32 {
        match self.expert_assignments.len() {
            0 => 0.0,
            tokens => self.num_dropped as f32 / tokens as f32,
        }
    }
}
/// Switch-Transformer-style top-1 router with a per-expert capacity limit.
///
/// `router_weights` is a row-major `hidden_size × num_experts` matrix that is
/// multiplied on the right of the row-major `seq_len × hidden_size` hidden states.
pub struct SwitchTransformerRouter {
    /// Number of experts a token can be sent to.
    pub num_experts: usize,
    /// Scales the per-expert capacity `seq_len / num_experts` (must be > 0).
    pub capacity_factor: f32,
    /// Row-major `hidden_size × num_experts` projection from hidden states to logits.
    pub router_weights: Vec<f32>,
    /// Hidden dimension that `router_weights` was built for.
    pub hidden_size: usize,
}

impl SwitchTransformerRouter {
    /// Creates a router with deterministic weights `0.1 * cos((0.3 * i + 0.7) / total)`
    /// for flat weight index `i` (no randomness involved).
    pub fn new(num_experts: usize, hidden_size: usize, capacity_factor: f32) -> Self {
        let total = hidden_size * num_experts;
        let router_weights: Vec<f32> = (0..total)
            .map(|i| ((i as f32 * 0.3 + 0.7) / total as f32).cos() * 0.1)
            .collect();
        Self {
            num_experts,
            capacity_factor,
            router_weights,
            hidden_size,
        }
    }

    /// Routes each token to its single highest-probability expert.
    ///
    /// `hidden` is row-major `seq_len × hidden_size`. Logits are softmax-normalised
    /// per token (row softmax). Tokens are processed in order. Once an expert has
    /// accepted `capacity = max(1, ceil(capacity_factor * seq_len / num_experts))`
    /// tokens, later tokens that prefer it are dropped (`None`). Probability ties
    /// go to the highest-index expert, because `Iterator::max_by` returns the
    /// last maximum.
    ///
    /// # Errors
    /// - [`MoeRoutingError::EmptyInput`] if `hidden` is empty.
    /// - [`MoeRoutingError::InvalidNumExperts`] if `num_experts == 0`.
    /// - [`MoeRoutingError::InvalidCapacityFactor`] if `capacity_factor` is NaN or ≤ 0.
    /// - [`MoeRoutingError::DimensionMismatch`] in three cases: when
    ///   `hidden.len() != seq_len * hidden_size`; when `hidden_size` differs from
    ///   `self.hidden_size`; or when `router_weights` does not hold
    ///   `hidden_size * num_experts` values.
    pub fn route(
        &self,
        hidden: &[f32],
        seq_len: usize,
        hidden_size: usize,
    ) -> Result<SwitchAssignment, MoeRoutingError> {
        if hidden.is_empty() {
            return Err(MoeRoutingError::EmptyInput);
        }
        if self.num_experts == 0 {
            return Err(MoeRoutingError::InvalidNumExperts(
                "num_experts must be ≥ 1".to_string(),
            ));
        }
        // A bare `<= 0.0` test lets NaN through, since every comparison with NaN is false.
        if self.capacity_factor.is_nan() || self.capacity_factor <= 0.0 {
            return Err(MoeRoutingError::InvalidCapacityFactor(
                "capacity_factor must be > 0".to_string(),
            ));
        }
        // Saturate so an overflowing product cannot wrap to a value that happens
        // to equal `hidden.len()` (a slice of f32 can never be usize::MAX long).
        let expected = seq_len.saturating_mul(hidden_size);
        if hidden.len() != expected {
            return Err(MoeRoutingError::DimensionMismatch {
                expected,
                got: hidden.len(),
            });
        }
        // The weights were built for `self.hidden_size`. Any other width would make
        // `matmul` index past the end (panic) or silently read the wrong weights.
        if hidden_size != self.hidden_size {
            return Err(MoeRoutingError::DimensionMismatch {
                expected: self.hidden_size,
                got: hidden_size,
            });
        }
        // The fields are public, so the weight buffer may have been replaced.
        let expected_weights = hidden_size.saturating_mul(self.num_experts);
        if self.router_weights.len() != expected_weights {
            return Err(MoeRoutingError::DimensionMismatch {
                expected: expected_weights,
                got: self.router_weights.len(),
            });
        }

        let capacity = ((self.capacity_factor * seq_len as f32 / self.num_experts as f32).ceil()
            as usize)
            .max(1);
        let logits = matmul(hidden, &self.router_weights, seq_len, hidden_size, self.num_experts);
        let probs = row_softmax(&logits, seq_len, self.num_experts);

        let mut expert_load = vec![0usize; self.num_experts];
        let mut expert_assignments: Vec<Option<usize>> = Vec::with_capacity(seq_len);
        let mut num_dropped = 0usize;
        // `num_experts > 0` was checked above, so `chunks_exact` is safe; one row per token.
        for row in probs.chunks_exact(self.num_experts) {
            // NaN compares as Equal; `max_by` keeps the last of equal maxima.
            let best_expert = row
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(i, _)| i)
                .unwrap_or(0);
            if expert_load[best_expert] < capacity {
                expert_load[best_expert] += 1;
                expert_assignments.push(Some(best_expert));
            } else {
                num_dropped += 1;
                expert_assignments.push(None);
            }
        }
        Ok(SwitchAssignment {
            expert_assignments,
            num_dropped,
            expert_load,
        })
    }

    /// Switch-Transformer-style auxiliary load-balancing loss
    /// `num_experts * Σ_e f_e * P_e` (without the scaling coefficient).
    ///
    /// `f_e` is the number of entries in `expert_assignments` equal to `e`,
    /// divided by `seq_len`; entries `>= num_experts` are ignored. `P_e` is the
    /// mean of column `e` of the row-major `seq_len × num_experts`
    /// `router_probs`. Perfectly uniform routing gives 1.0. Returns 0.0 when
    /// `seq_len` or `num_experts` is 0.
    ///
    /// # Panics
    /// Panics if `router_probs` holds fewer than `seq_len * num_experts` values.
    pub fn switch_load_balance_loss(
        router_probs: &[f32],
        expert_assignments: &[usize],
        seq_len: usize,
        num_experts: usize,
    ) -> f32 {
        if seq_len == 0 || num_experts == 0 {
            return 0.0;
        }
        let mut counts = vec![0usize; num_experts];
        for &e in expert_assignments {
            if let Some(count) = counts.get_mut(e) {
                *count += 1;
            }
        }
        // One up-front bounds check instead of a panic deep inside the loop.
        let probs = &router_probs[..seq_len * num_experts];
        let mut prob_sums = vec![0.0_f32; num_experts];
        for row in probs.chunks_exact(num_experts) {
            for (acc, &p) in prob_sums.iter_mut().zip(row) {
                *acc += p;
            }
        }
        let tokens = seq_len as f32;
        let dot: f32 = counts
            .iter()
            .zip(&prob_sums)
            .map(|(&c, &s)| (c as f32 / tokens) * (s / tokens))
            .sum();
        num_experts as f32 * dot
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_expert_choice_capacity_utilization() {
        let num_experts = 4;
        let hidden_size = 8;
        let seq_len = 16;
        let capacity_factor = 1.0_f32;
        let router = ExpertChoiceRouter::new(num_experts, hidden_size, capacity_factor);
        let hidden: Vec<f32> = (0..seq_len * hidden_size)
            .map(|i| (i as f32) * 0.01)
            .collect();
        let assignment = router
            .route(&hidden, seq_len, hidden_size)
            .expect("route should succeed");
        let expected_capacity =
            ((capacity_factor * seq_len as f32 / num_experts as f32).floor() as usize).max(1);
        assert_eq!(assignment.capacity, expected_capacity);
        let tpe = assignment.tokens_per_expert();
        for (e, &cnt) in tpe.iter().enumerate() {
            assert_eq!(
                cnt, expected_capacity,
                "expert {e} should process exactly {expected_capacity} tokens, got {cnt}"
            );
        }
    }
    #[test]
    fn test_expert_choice_no_dropped_tokens() {
        let num_experts = 2;
        let hidden_size = 4;
        let seq_len = 8;
        let router = ExpertChoiceRouter::new(num_experts, hidden_size, 1.0);
        let hidden: Vec<f32> = (0..seq_len * hidden_size).map(|i| i as f32 * 0.1).collect();
        let assignment = router
            .route(&hidden, seq_len, hidden_size)
            .expect("route ok");
        for tokens in &assignment.expert_token_idx {
            assert!(!tokens.is_empty(), "each expert must have at least one token");
        }
    }
    #[test]
    fn test_expert_choice_token_indices_in_range() {
        let num_experts = 3;
        let hidden_size = 6;
        let seq_len = 12;
        let router = ExpertChoiceRouter::new(num_experts, hidden_size, 1.0);
        let hidden: Vec<f32> = (0..seq_len * hidden_size).map(|i| i as f32 * 0.05).collect();
        let assignment = router
            .route(&hidden, seq_len, hidden_size)
            .expect("route ok");
        for tokens in &assignment.expert_token_idx {
            for &t in tokens {
                assert!(t < seq_len, "token index {t} out of range [0, {seq_len})");
            }
        }
    }
    #[test]
    fn test_expert_choice_affinity_matrix_dimensions() {
        let num_experts = 4;
        let hidden_size = 8;
        let seq_len = 10;
        let router = ExpertChoiceRouter::new(num_experts, hidden_size, 1.0);
        let hidden: Vec<f32> = vec![0.1; seq_len * hidden_size];
        let affinity = ExpertChoiceRouter::compute_affinity(
            &hidden,
            &router.router_weights,
            seq_len,
            hidden_size,
            num_experts,
        );
        assert_eq!(
            affinity.len(),
            seq_len * num_experts,
            "affinity matrix should have {} elements, got {}",
            seq_len * num_experts,
            affinity.len()
        );
    }
    #[test]
    fn test_expert_choice_error_empty_input() {
        let router = ExpertChoiceRouter::new(4, 8, 1.0);
        let err = router.route(&[], 0, 8).unwrap_err();
        assert_eq!(err, MoeRoutingError::EmptyInput);
    }
    #[test]
    fn test_expert_choice_error_dimension_mismatch() {
        let router = ExpertChoiceRouter::new(4, 8, 1.0);
        let hidden = vec![0.0_f32; 10];
        let err = router.route(&hidden, 2, 8).unwrap_err();
        // `matches!` only yields a bool; it must be asserted to actually test anything.
        assert!(
            matches!(err, MoeRoutingError::DimensionMismatch { expected: 16, got: 10 }),
            "expected DimensionMismatch {{ expected: 16, got: 10 }}, got {err:?}"
        );
    }
    #[test]
    fn test_hash_router_consistency() {
        let router = HashRouter::new(8);
        let token_id: u32 = 42;
        let first = HashRouter::route_token(token_id, 8);
        for _ in 0..100 {
            assert_eq!(
                HashRouter::route_token(token_id, 8),
                first,
                "hash routing must be deterministic"
            );
        }
        assert_eq!(router.route_batch(&[token_id]), vec![first]);
    }
    #[test]
    fn test_hash_router_batch_coverage() {
        let num_experts = 8;
        let router = HashRouter::new(num_experts);
        let token_ids: Vec<u32> = (0..256).collect();
        let assignments = router.route_batch(&token_ids);
        assert_eq!(assignments.len(), token_ids.len());
        let mut seen = vec![false; num_experts];
        for &e in &assignments {
            assert!(e < num_experts, "expert index out of range");
            seen[e] = true;
        }
        let covered = seen.iter().filter(|&&b| b).count();
        assert!(
            covered >= num_experts,
            "only {covered}/{num_experts} experts were assigned at least one token"
        );
    }
    #[test]
    fn test_hash_router_in_range() {
        for num_experts in 1..=16 {
            for token_id in 0..64_u32 {
                let e = HashRouter::route_token(token_id, num_experts);
                assert!(
                    e < num_experts,
                    "expert {e} out of range for num_experts={num_experts}"
                );
            }
        }
    }
    #[test]
    fn test_switch_drop_rate_within_capacity() {
        let num_experts = 4;
        let hidden_size = 8;
        let seq_len = 8;
        let router = SwitchTransformerRouter::new(num_experts, hidden_size, num_experts as f32);
        let hidden: Vec<f32> = (0..seq_len * hidden_size).map(|i| i as f32 * 0.01).collect();
        let assignment = router
            .route(&hidden, seq_len, hidden_size)
            .expect("route ok");
        assert_eq!(
            assignment.drop_rate(),
            0.0,
            "no tokens should be dropped when capacity equals seq_len"
        );
    }
    #[test]
    fn test_switch_drop_rate_with_tight_capacity() {
        let num_experts = 8;
        let hidden_size = 4;
        let seq_len = 16;
        let router = SwitchTransformerRouter::new(num_experts, hidden_size, 0.5);
        let hidden: Vec<f32> = (0..seq_len * hidden_size).map(|i| i as f32 * 0.05).collect();
        let assignment = router
            .route(&hidden, seq_len, hidden_size)
            .expect("route ok");
        let total_capacity: usize = assignment.expert_load.iter().sum();
        assert!(
            total_capacity <= seq_len,
            "total accepted must not exceed seq_len"
        );
    }
    #[test]
    fn test_switch_load_balance_loss_positive() {
        let num_experts = 2;
        let seq_len = 4;
        let router_probs = vec![0.5_f32, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5];
        let expert_assignments = vec![0, 1, 0, 1];
        let loss = SwitchTransformerRouter::switch_load_balance_loss(
            &router_probs,
            &expert_assignments,
            seq_len,
            num_experts,
        );
        assert!(
            (loss - 1.0).abs() < 1e-5,
            "expected loss ≈ 1.0, got {loss}"
        );
    }
    #[test]
    fn test_switch_assignment_drop_rate_calculation() {
        let assignment = SwitchAssignment {
            expert_assignments: vec![Some(0), None, Some(1), None, Some(0)],
            num_dropped: 2,
            expert_load: vec![2, 1],
        };
        let dr = assignment.drop_rate();
        assert!(
            (dr - 0.4).abs() < 1e-5,
            "expected drop rate 0.4, got {dr}"
        );
    }
    #[test]
    fn test_routing_error_display() {
        let e1 = MoeRoutingError::EmptyInput;
        assert!(e1.to_string().contains("empty"));
        let e2 = MoeRoutingError::InvalidNumExperts("must be ≥ 1".to_string());
        assert!(e2.to_string().contains("num_experts"));
        let e3 = MoeRoutingError::InvalidCapacityFactor("must be > 0".to_string());
        assert!(e3.to_string().contains("capacity_factor"));
        let e4 = MoeRoutingError::DimensionMismatch {
            expected: 64,
            got: 32,
        };
        let s4 = e4.to_string();
        assert!(s4.contains("64") && s4.contains("32"));
    }
    #[test]
    fn test_switch_error_empty_input() {
        let router = SwitchTransformerRouter::new(4, 8, 1.25);
        let err = router.route(&[], 0, 8).unwrap_err();
        assert_eq!(err, MoeRoutingError::EmptyInput);
    }
}