use std::collections::HashMap;
use trustformers_core::{errors::Result, layers::Linear, tensor::Tensor, traits::Layer};
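/// Configuration for a sparse Mixture-of-Experts (MoE) layer: number of
/// experts, top-k routing width, optional per-expert capacity, auxiliary-loss
/// coefficients, and router jitter noise.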
#[derive(Debug, Clone)]
pub struct MoEConfig {
pub hidden_size: usize,
pub num_experts: usize,
pub num_experts_per_token: usize,
    pub expert_capacity: Option<usize>,
    pub load_balancing_loss_coeff: f32,
pub router_z_loss_coeff: f32,
pub use_auxiliary_loss: bool,
    pub jitter_noise: f32,
}
impl Default for MoEConfig {
fn default() -> Self {
Self {
hidden_size: 4096,
num_experts: 8,
num_experts_per_token: 2,
expert_capacity: None,
load_balancing_loss_coeff: 0.01,
router_z_loss_coeff: 0.001,
use_auxiliary_loss: true,
jitter_noise: 1e-2,
}
}
}
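/// Diagnostics from a routing pass: the per-expert fraction of tokens and of
/// total routing weight, plus the auxiliary load-balancing and router z-losses.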
#[derive(Debug, Clone)]
pub struct RoutingStats {
pub expert_counts: Vec<f32>,
pub expert_weights: Vec<f32>,
pub load_balancing_loss: f32,
pub router_z_loss: f32,
}
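/// An expert network that a router can dispatch tokens to: a `Layer` with a
/// stable id and an optional per-expert token capacity.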
pub trait Expert: Layer<Input = Tensor, Output = Tensor> + Send + Sync {
fn expert_id(&self) -> usize;
fn capacity(&self) -> Option<usize> {
None
}
}
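/// A gated feed-forward expert (gate, up, and down projections) with a
/// configurable activation, the standard expert used in MoE transformer blocks.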
pub struct MLPExpert {
id: usize,
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
activation: String,
}
impl MLPExpert {
pub fn new(
id: usize,
hidden_size: usize,
intermediate_size: usize,
activation: String,
) -> Result<Self> {
let gate_proj = Linear::new(hidden_size, intermediate_size, false);
let up_proj = Linear::new(hidden_size, intermediate_size, false);
let down_proj = Linear::new(intermediate_size, hidden_size, false);
Ok(Self {
id,
gate_proj,
up_proj,
down_proj,
activation,
})
}
fn apply_activation(&self, x: &Tensor) -> Result<Tensor> {
match self.activation.as_str() {
"silu" | "swish" => x.silu(),
"gelu" => x.gelu(),
"relu" => x.relu(),
_ => Ok(x.clone()),
}
}
}
impl Layer for MLPExpert {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let gate = self.gate_proj.forward(input.clone())?;
let gate_activated = self.apply_activation(&gate)?;
let up = self.up_proj.forward(input)?;
let gated = gate_activated.mul(&up)?;
self.down_proj.forward(gated)
}
}
impl Expert for MLPExpert {
fn expert_id(&self) -> usize {
self.id
}
}
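/// A learned top-k router: a single linear gate maps hidden states to
/// per-expert logits, which are softmaxed and truncated to the
/// `num_experts_per_token` largest probabilities for each token.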
pub struct TopKRouter {
gate: Linear,
config: MoEConfig,
}
impl TopKRouter {
pub fn new(config: MoEConfig) -> Result<Self> {
let gate = Linear::new(config.hidden_size, config.num_experts, false);
Ok(Self { gate, config })
}
pub fn route(&self, hidden_states: &Tensor) -> Result<RouterOutput> {
let batch_size = hidden_states.shape()[0];
let seq_len = hidden_states.shape()[1];
let hidden_size = hidden_states.shape()[2];
let flattened = hidden_states.reshape(&[batch_size * seq_len, hidden_size])?;
let router_logits = self.gate.forward(flattened)?;
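        // Optionally add Gaussian noise to the routing logits to encourage
        // exploration across experts; skipped when jitter_noise is zero.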
let router_logits = if self.config.jitter_noise > 0.0 {
let noise = Tensor::randn_like(&router_logits)?.mul_scalar(self.config.jitter_noise)?;
router_logits.add(&noise)?
} else {
router_logits
};
let router_probs = router_logits.softmax(-1)?;
let (top_k_weights, top_k_indices) = self.select_top_k(&router_probs)?;
let stats = self.compute_routing_stats(&router_probs, &top_k_weights, &top_k_indices)?;
Ok(RouterOutput {
top_k_weights,
top_k_indices,
router_probs,
stats,
})
}
fn select_top_k(&self, router_probs: &Tensor) -> Result<(Tensor, Tensor)> {
let num_tokens = router_probs.shape()[0];
let num_experts = router_probs.shape()[1];
let mut all_weights = Vec::new();
let mut all_indices = Vec::new();
for token_idx in 0..num_tokens {
let mut expert_probs: Vec<(f32, usize)> = Vec::new();
for expert_idx in 0..num_experts {
let prob = router_probs.get_scalar(&[token_idx, expert_idx])?;
expert_probs.push((prob, expert_idx));
}
            expert_probs.sort_by(|a, b| {
                b.0.partial_cmp(&a.0)
                    .expect("router probabilities should be comparable (not NaN)")
            });
expert_probs.truncate(self.config.num_experts_per_token);
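            // Renormalize the retained probabilities so this token's k routing
            // weights sum to one.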
let sum: f32 = expert_probs.iter().map(|(p, _)| p).sum();
let norm_factor = if sum > 0.0 { 1.0 / sum } else { 1.0 };
for (prob, expert_idx) in expert_probs {
all_weights.push(prob * norm_factor);
all_indices.push(expert_idx as f32);
}
}
let weights_tensor = Tensor::from_vec(
all_weights,
&[num_tokens, self.config.num_experts_per_token],
)?;
let indices_tensor = Tensor::from_vec(
all_indices,
&[num_tokens, self.config.num_experts_per_token],
)?;
Ok((weights_tensor, indices_tensor))
}
fn compute_routing_stats(
&self,
router_probs: &Tensor,
top_k_weights: &Tensor,
top_k_indices: &Tensor,
) -> Result<RoutingStats> {
let num_tokens = router_probs.shape()[0];
let num_experts = self.config.num_experts;
let mut expert_counts = vec![0.0; num_experts];
let mut expert_weights = vec![0.0; num_experts];
for token_idx in 0..num_tokens {
for k in 0..self.config.num_experts_per_token {
let expert_idx = top_k_indices.get_scalar(&[token_idx, k])? as usize;
let weight = top_k_weights.get_scalar(&[token_idx, k])?;
expert_counts[expert_idx] += 1.0;
expert_weights[expert_idx] += weight;
}
}
let total_tokens = num_tokens as f32;
expert_counts.iter_mut().for_each(|c| *c /= total_tokens);
expert_weights.iter_mut().for_each(|w| *w /= total_tokens);
        // Load-balancing loss in the style of the Switch Transformer auxiliary
        // loss: num_experts * sum_i (fraction of tokens routed to expert i) *
        // (fraction of routing weight assigned to expert i), shifted by 1 so a
        // perfectly balanced top-1 router scores zero.
let load_balancing_loss: f32 = expert_counts
.iter()
.zip(expert_weights.iter())
.map(|(count, weight)| count * weight)
.sum::<f32>()
* num_experts as f32
- 1.0;
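        // Router z-loss, computed here on the softmaxed probabilities: the mean
        // over tokens of the per-token sum of squared expert probabilities.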
let router_z_loss = router_probs
.pow(2.0)?
.sum(Some(vec![router_probs.shape().len() - 1]), false)?
.mean()?
.get_scalar(&[])?;
Ok(RoutingStats {
expert_counts,
expert_weights,
load_balancing_loss,
router_z_loss,
})
}
}
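/// Result of routing a batch of tokens: renormalized top-k weights and expert
/// indices (both shaped `[num_tokens, k]`), the full router probabilities, and
/// the routing statistics used for auxiliary losses.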
pub struct RouterOutput {
pub top_k_weights: Tensor,
pub top_k_indices: Tensor,
pub router_probs: Tensor,
pub stats: RoutingStats,
}
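/// A sparse Mixture-of-Experts layer: a `TopKRouter` selects
/// `num_experts_per_token` experts per token, and their outputs are combined
/// with the renormalized routing weights.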
pub struct SparseMoE<E: Expert> {
experts: Vec<E>,
router: TopKRouter,
config: MoEConfig,
}
impl<E: Expert> SparseMoE<E> {
pub fn new(experts: Vec<E>, config: MoEConfig) -> Result<Self> {
let router = TopKRouter::new(config.clone())?;
Ok(Self {
experts,
router,
config,
})
}
pub fn num_experts(&self) -> usize {
self.experts.len()
}
    /// Routing statistics from the most recent forward pass. Stats are not
    /// currently cached (forward takes `&self` without interior mutability),
    /// so this always returns `None`.
    pub fn last_routing_stats(&self) -> Option<&RoutingStats> {
        None
    }
}
impl<E: Expert> Layer for SparseMoE<E> {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let batch_size = input.shape()[0];
let seq_len = input.shape()[1];
let hidden_size = input.shape()[2];
let router_output = self.router.route(&input)?;
let flattened_input = input.reshape(&[batch_size * seq_len, hidden_size])?;
let num_tokens = flattened_input.shape()[0];
        // Compute each token's output as the routing-weighted sum of its top-k
        // expert outputs, collecting the per-token rows for a single concat.
        // (The previous slice/concat reassembly discarded all but the first row.)
        let mut token_outputs = Vec::with_capacity(num_tokens);
        for token_idx in 0..num_tokens {
            let token_input =
                flattened_input.slice_multi(&[(token_idx, token_idx + 1), (0, hidden_size)])?;
            let mut token_output = Tensor::zeros(&[1, hidden_size])?;
            for k in 0..self.config.num_experts_per_token {
                let expert_idx =
                    router_output.top_k_indices.get_scalar(&[token_idx, k])? as usize;
                let weight = router_output.top_k_weights.get_scalar(&[token_idx, k])?;
                let expert_output = self.experts[expert_idx].forward(token_input.clone())?;
                let weighted_output = expert_output.mul_scalar(weight)?;
                token_output = token_output.add(&weighted_output)?;
            }
            token_outputs.push(token_output);
        }
        // Stack the [1, hidden_size] rows back into [num_tokens, hidden_size].
        let output = Tensor::concat(&token_outputs, 0)?;
output.reshape(&[batch_size, seq_len, hidden_size])
}
}
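/// A Switch Transformer-style MoE is a `SparseMoE` configured for top-1 routing.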
pub type SwitchMoE<E> = SparseMoE<E>;
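/// Builds a Switch Transformer-style configuration: top-1 routing (one expert per token).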
pub fn switch_config(hidden_size: usize, num_experts: usize) -> MoEConfig {
MoEConfig {
hidden_size,
num_experts,
        num_experts_per_token: 1,
        ..Default::default()
}
}
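/// Builds a GLaM-style configuration: top-2 routing (two experts per token).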
pub fn glam_config(hidden_size: usize, num_experts: usize) -> MoEConfig {
MoEConfig {
hidden_size,
num_experts,
        num_experts_per_token: 2,
        ..Default::default()
}
}
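/// Expert-parallel placement: each rank owns a subset of the experts and maps
/// global expert ids to local indices, so `forward_local` only evaluates
/// experts resident on this rank.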
pub struct ExpertParallel<E: Expert> {
local_experts: Vec<E>,
    expert_mapping: HashMap<usize, usize>,
    #[allow(dead_code)]
rank: usize,
#[allow(dead_code)]
world_size: usize,
}
impl<E: Expert> ExpertParallel<E> {
pub fn new(experts: Vec<E>, rank: usize, world_size: usize) -> Self {
let mut expert_mapping = HashMap::new();
for (local_id, expert) in experts.iter().enumerate() {
expert_mapping.insert(expert.expert_id(), local_id);
}
Self {
local_experts: experts,
expert_mapping,
rank,
world_size,
}
}
pub fn has_expert(&self, expert_id: usize) -> bool {
self.expert_mapping.contains_key(&expert_id)
}
pub fn forward_local(&self, expert_id: usize, input: &Tensor) -> Result<Option<Tensor>> {
if let Some(&local_id) = self.expert_mapping.get(&expert_id) {
let output = self.local_experts[local_id].forward(input.clone())?;
Ok(Some(output))
} else {
Ok(None)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_moe_config_default_values() {
let cfg = MoEConfig::default();
assert_eq!(cfg.hidden_size, 4096);
assert_eq!(cfg.num_experts, 8);
assert_eq!(cfg.num_experts_per_token, 2);
assert!(cfg.expert_capacity.is_none());
assert!((cfg.load_balancing_loss_coeff - 0.01).abs() < 1e-7);
assert!((cfg.router_z_loss_coeff - 0.001).abs() < 1e-7);
assert!(cfg.use_auxiliary_loss);
assert!((cfg.jitter_noise - 1e-2).abs() < 1e-7);
}
#[test]
fn test_moe_config_clone() {
let cfg = MoEConfig::default();
let c = cfg.clone();
assert_eq!(c.num_experts, cfg.num_experts);
assert_eq!(c.hidden_size, cfg.hidden_size);
}
#[test]
fn test_moe_config_custom() {
let cfg = MoEConfig {
hidden_size: 512,
num_experts: 16,
num_experts_per_token: 4,
expert_capacity: Some(32),
load_balancing_loss_coeff: 0.05,
router_z_loss_coeff: 0.005,
use_auxiliary_loss: false,
jitter_noise: 0.0,
};
assert_eq!(cfg.hidden_size, 512);
assert_eq!(cfg.num_experts, 16);
assert_eq!(cfg.expert_capacity, Some(32));
assert!(!cfg.use_auxiliary_loss);
}
#[test]
fn test_switch_config_top1() {
let cfg = switch_config(256, 4);
assert_eq!(cfg.hidden_size, 256);
assert_eq!(cfg.num_experts, 4);
assert_eq!(cfg.num_experts_per_token, 1, "Switch uses top-1");
}
#[test]
fn test_glam_config_top2() {
let cfg = glam_config(512, 8);
assert_eq!(cfg.hidden_size, 512);
assert_eq!(cfg.num_experts, 8);
assert_eq!(cfg.num_experts_per_token, 2, "GLaM uses top-2");
}
#[test]
fn test_routing_stats_construction() {
let stats = RoutingStats {
expert_counts: vec![10.0, 15.0, 8.0, 12.0],
expert_weights: vec![0.3, 0.35, 0.15, 0.2],
load_balancing_loss: 0.002,
router_z_loss: 0.0005,
};
assert_eq!(stats.expert_counts.len(), 4);
assert!((stats.load_balancing_loss - 0.002).abs() < 1e-7);
}
#[test]
fn test_routing_stats_clone() {
let stats = RoutingStats {
expert_counts: vec![1.0, 2.0],
expert_weights: vec![0.5, 0.5],
load_balancing_loss: 0.01,
router_z_loss: 0.001,
};
let c = stats.clone();
assert_eq!(c.expert_counts, stats.expert_counts);
}
#[test]
fn test_mlp_expert_construction_id() {
let expert = MLPExpert::new(3, 64, 128, "silu".to_string())
.expect("MLPExpert creation should succeed");
assert_eq!(expert.expert_id(), 3);
}
#[test]
fn test_mlp_expert_id_zero() {
let expert = MLPExpert::new(0, 32, 64, "gelu".to_string())
.expect("MLPExpert creation should succeed");
assert_eq!(expert.expert_id(), 0);
}
#[test]
fn test_mlp_expert_capacity_default_none() {
let expert = MLPExpert::new(1, 64, 128, "relu".to_string())
.expect("MLPExpert creation should succeed");
assert!(expert.capacity().is_none());
}
#[test]
fn test_mlp_expert_forward_output_shape() {
let expert = MLPExpert::new(0, 8, 16, "silu".to_string())
.expect("MLPExpert creation should succeed");
let input = Tensor::zeros(&[4, 8]).expect("tensor creation should succeed");
let out = expert.forward(input).expect("expert forward should succeed");
assert_eq!(
out.shape(),
&[4, 8],
"output shape should match hidden_size"
);
}
#[test]
fn test_mlp_expert_forward_gelu() {
let expert = MLPExpert::new(2, 8, 16, "gelu".to_string())
.expect("MLPExpert creation should succeed");
let input = Tensor::zeros(&[2, 8]).expect("tensor creation should succeed");
let out = expert.forward(input).expect("gelu expert forward should succeed");
assert_eq!(out.shape(), &[2, 8]);
}
#[test]
fn test_mlp_expert_forward_relu() {
let expert =
MLPExpert::new(0, 4, 8, "relu".to_string()).expect("MLPExpert creation should succeed");
let input = Tensor::zeros(&[1, 4]).expect("tensor creation should succeed");
let out = expert.forward(input).expect("relu expert forward should succeed");
assert_eq!(out.shape(), &[1, 4]);
}
#[test]
fn test_mlp_expert_forward_unknown_activation() {
let expert = MLPExpert::new(0, 4, 8, "unknown_act".to_string())
.expect("MLPExpert creation should succeed");
let input = Tensor::zeros(&[1, 4]).expect("tensor creation should succeed");
let out = expert.forward(input).expect("expert with unknown act should succeed");
assert_eq!(out.shape(), &[1, 4]);
}
#[test]
fn test_top_k_router_construction() {
let cfg = MoEConfig {
hidden_size: 16,
num_experts: 4,
num_experts_per_token: 2,
expert_capacity: None,
load_balancing_loss_coeff: 0.01,
router_z_loss_coeff: 0.001,
use_auxiliary_loss: true,
jitter_noise: 0.0,
};
let _router = TopKRouter::new(cfg).expect("TopKRouter construction should succeed");
}
#[test]
fn test_expert_parallel_has_expert() {
let experts: Vec<MLPExpert> = (0..4)
.map(|i| {
MLPExpert::new(i, 8, 16, "silu".to_string())
.expect("expert creation should succeed")
})
.collect();
let ep = ExpertParallel::new(experts, 0, 2);
assert!(ep.has_expert(0));
assert!(ep.has_expert(1));
assert!(ep.has_expert(2));
assert!(ep.has_expert(3));
assert!(!ep.has_expert(4));
}
#[test]
fn test_expert_parallel_forward_local_present() {
let experts: Vec<MLPExpert> = (0..2)
.map(|i| {
MLPExpert::new(i, 8, 16, "silu".to_string())
.expect("expert creation should succeed")
})
.collect();
let ep = ExpertParallel::new(experts, 0, 1);
let input = Tensor::zeros(&[1, 8]).expect("tensor creation should succeed");
let out = ep.forward_local(0, &input).expect("forward_local should succeed");
assert!(out.is_some(), "expert 0 should be local");
}
#[test]
fn test_expert_parallel_forward_local_absent() {
let experts: Vec<MLPExpert> = (0..2)
.map(|i| {
MLPExpert::new(i, 8, 16, "silu".to_string())
.expect("expert creation should succeed")
})
.collect();
let ep = ExpertParallel::new(experts, 0, 2);
let input = Tensor::zeros(&[1, 8]).expect("tensor creation should succeed");
let out = ep.forward_local(99, &input).expect("forward_local for absent should succeed");
assert!(out.is_none(), "expert 99 should not be local");
}
#[test]
fn test_expert_parallel_world_size_mapping() {
let all_experts: Vec<MLPExpert> = (0..4)
.map(|i| {
MLPExpert::new(i, 8, 16, "silu".to_string())
.expect("expert creation should succeed")
})
.collect();
let ep = ExpertParallel::new(all_experts, 1, 2);
for i in 0..4usize {
assert!(ep.has_expert(i), "expert {} should exist", i);
}
}
}