pub mod router;
#[cfg(test)]
mod tests;
use ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
pub use router::{NoisyTopKRouter, RoutingResult, TopKRouter};
/// Hyperparameters for a mixture-of-experts layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MoeConfig {
/// Total number of expert networks in the layer.
pub num_experts: usize,
/// Number of experts each token is routed to.
pub top_k: usize,
/// Capacity multiplier forwarded to the router (see `router::RouterConfig`).
pub capacity_factor: f32,
/// Std-dev of routing noise; a value > 0 selects the noisy router in `MoeLayer::new`.
pub noise_std: f32,
/// Dimensionality of token inputs (and expert outputs).
pub input_dim: usize,
/// Hidden width of each expert's two-layer MLP.
pub hidden_dim: usize,
}
impl Default for MoeConfig {
fn default() -> Self {
Self {
num_experts: 8,
top_k: 2,
capacity_factor: 1.25,
noise_std: 0.0,
input_dim: 64,
hidden_dim: 128,
}
}
}
/// A single feed-forward expert: two dense layers with a ReLU in between.
/// Maps `input_dim -> hidden_dim -> input_dim`, so expert outputs can be
/// recombined in the token's original dimensionality.
#[derive(Debug, Clone)]
pub struct Expert {
/// First-layer weights, shape `(input_dim, hidden_dim)`.
pub w1: Array2<f32>,
/// First-layer bias, length `hidden_dim`.
pub b1: Array1<f32>,
/// Second-layer weights, shape `(hidden_dim, input_dim)`.
pub w2: Array2<f32>,
/// Second-layer bias, length `input_dim`.
pub b2: Array1<f32>,
}
impl Expert {
    /// Construct an expert with deterministic, sinusoid-derived weights.
    ///
    /// Each weight is seeded from its flat element index via `sin`, scaled by
    /// a Glorot-style factor, so construction is reproducible without an RNG.
    /// Biases start at zero.
    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
        // Both layers have the same fan-in + fan-out sum (just swapped), so
        // they share one Glorot scale: sqrt(2 / (fan_in + fan_out)).
        let fan_sum = (input_dim + hidden_dim) as f32;
        let glorot = (2.0 / fan_sum).sqrt();
        let seeded = |flat_idx: usize, freq: f32| (flat_idx as f32 * freq).sin() * glorot;
        Self {
            w1: Array2::from_shape_fn((input_dim, hidden_dim), |(r, c)| {
                seeded(r * hidden_dim + c, 0.3141)
            }),
            b1: Array1::zeros(hidden_dim),
            w2: Array2::from_shape_fn((hidden_dim, input_dim), |(r, c)| {
                seeded(r * input_dim + c, 0.2718)
            }),
            b2: Array1::zeros(input_dim),
        }
    }

    /// Run one token through the expert: linear -> ReLU -> linear.
    pub fn forward(&self, input: &Array1<f32>) -> Array1<f32> {
        let pre_act = input.dot(&self.w1) + &self.b1;
        pre_act.mapv(|x| x.max(0.0)).dot(&self.w2) + &self.b2
    }

    /// Run a whole batch (rows = tokens) through the expert in one shot.
    pub fn forward_batch(&self, input: &Array2<f32>) -> Array2<f32> {
        let pre_act = input.dot(&self.w1) + &self.b1;
        pre_act.mapv(|x| x.max(0.0)).dot(&self.w2) + &self.b2
    }
}
/// Wrapper over the two router implementations so `MoeLayer` can hold either.
/// `MoeLayer::new` picks `Noisy` when `MoeConfig::noise_std > 0`, else `Deterministic`.
#[derive(Debug, Clone)]
pub enum Router {
Deterministic(TopKRouter),
Noisy(NoisyTopKRouter),
}
impl Router {
pub fn route(&self, input: &Array2<f32>) -> RoutingResult {
match self {
Router::Deterministic(r) => r.route(input),
Router::Noisy(r) => r.route(input),
}
}
}
/// A mixture-of-experts layer: a router that assigns tokens to experts,
/// plus the expert networks themselves.
#[derive(Debug, Clone)]
pub struct MoeLayer {
/// Hyperparameters the layer was built with.
pub config: MoeConfig,
/// Token-to-expert assignment strategy (deterministic or noisy top-k).
pub router: Router,
/// One `Expert` per `config.num_experts`.
pub experts: Vec<Expert>,
}
impl MoeLayer {
    /// Build a MoE layer from `config`: a router (noisy iff `noise_std > 0`)
    /// plus `num_experts` independently initialized experts.
    pub fn new(config: MoeConfig) -> Self {
        let router_config = router::RouterConfig {
            input_dim: config.input_dim,
            num_experts: config.num_experts,
            top_k: config.top_k,
            capacity_factor: config.capacity_factor,
        };
        // A positive noise level selects the noisy (exploration) router;
        // otherwise routing is fully deterministic.
        let router = if config.noise_std > 0.0 {
            Router::Noisy(NoisyTopKRouter::new(&router_config, config.noise_std))
        } else {
            Router::Deterministic(TopKRouter::new(&router_config))
        };
        let experts = (0..config.num_experts)
            .map(|_| Expert::new(config.input_dim, config.hidden_dim))
            .collect();
        Self { config, router, experts }
    }

    /// Route each token (row of `input`) to its selected experts and return
    /// the weighted combination of their outputs, along with the routing
    /// result (needed e.g. for `balance_loss`).
    ///
    /// Output has the same shape as `input`.
    pub fn forward(&self, input: &Array2<f32>) -> (Array2<f32>, RoutingResult) {
        let batch_size = input.nrows();
        let input_dim = input.ncols();
        let routing = self.router.route(input);
        let mut output = Array2::zeros((batch_size, input_dim));
        for i in 0..batch_size {
            let token = input.row(i).to_owned();
            let mut combined = Array1::zeros(input_dim);
            for (k, &expert_idx) in routing.expert_indices[i].iter().enumerate() {
                let weight = routing.expert_weights[i][k];
                // Assignments with zero weight (e.g. dropped by the router)
                // contribute nothing; skip the expert evaluation entirely.
                if weight > 0.0 {
                    let expert_output = self.experts[expert_idx].forward(&token);
                    // Accumulate in place: avoids the scaled temporary that
                    // `combined += &(expert_output * weight)` would allocate
                    // for every (token, expert) pair.
                    combined.scaled_add(weight, &expert_output);
                }
            }
            output.row_mut(i).assign(&combined);
        }
        (output, routing)
    }

    /// Auxiliary load-balance loss: `num_experts * sum_e(f_e * p_e)`, where
    /// `f_e` is the fraction of dispatches sent to expert `e` and `p_e` is the
    /// mean routing probability for expert `e` (from the router's softmax).
    /// The loss is minimized (value 1.0) when load is perfectly uniform.
    /// Returns 0.0 for an empty batch or when nothing was dispatched.
    pub fn balance_loss(&self, routing: &RoutingResult) -> f32 {
        let num_experts = self.config.num_experts;
        let batch_size = routing.routing_probs.nrows();
        if batch_size == 0 {
            return 0.0;
        }
        // Count how many times each expert was selected across the batch.
        let mut dispatch_counts = vec![0usize; num_experts];
        for token_experts in &routing.expert_indices {
            for &expert_idx in token_experts {
                dispatch_counts[expert_idx] += 1;
            }
        }
        let total_dispatches: usize = dispatch_counts.iter().sum();
        if total_dispatches == 0 {
            // All f_e would be zero, so the dot product — and the loss — is 0.
            return 0.0;
        }
        let f: Vec<f32> = dispatch_counts
            .iter()
            .map(|&c| c as f32 / total_dispatches as f32)
            .collect();
        let p = router::expert_load_fractions(&routing.routing_probs);
        let dot: f32 = f.iter().zip(p.iter()).map(|(fi, pi)| fi * pi).sum();
        num_experts as f32 * dot
    }
}