use ndarray::Array2;
use super::error::{MoeError, MoeResult};
/// Routing statistics collected from a gating network over one batch.
///
/// The score matrix has one row per token and one column per expert (see
/// [`BatchGatingStats::batch_size`] and [`BatchGatingStats::num_experts`]).
/// The two fields are not forced to agree in length: [`BatchGatingStats::empty`]
/// creates a zeroed score matrix alongside an empty routing vector.
#[derive(Debug, Clone, PartialEq)]
pub struct BatchGatingStats {
    /// Gate score of every token (rows) for every expert (columns).
    pub gate_scores_per_token: Array2<f64>,
    /// For each token, the index of the expert it was routed to.
    pub routed_expert_per_token: Vec<usize>,
}

impl BatchGatingStats {
    /// Creates stats for `batch_size` tokens and `num_experts` experts.
    ///
    /// All gate scores start at `0.0`. The routing vector starts empty, with capacity
    /// reserved for `batch_size` entries.
    pub fn empty(batch_size: usize, num_experts: usize) -> Self {
        let gate_scores_per_token = Array2::<f64>::zeros((batch_size, num_experts));
        let routed_expert_per_token = Vec::with_capacity(batch_size);
        Self {
            gate_scores_per_token,
            routed_expert_per_token,
        }
    }

    /// Number of rows in the gate-score matrix.
    ///
    /// This is read from the score matrix only, not from `routed_expert_per_token.len()`.
    pub fn batch_size(&self) -> usize {
        self.gate_scores_per_token.nrows()
    }

    /// Number of columns in the gate-score matrix, i.e. the size of the expert pool.
    pub fn num_experts(&self) -> usize {
        self.gate_scores_per_token.ncols()
    }

    /// Number of tokens routed to each expert, returned as `f64`.
    ///
    /// The result always has length [`num_experts`](Self::num_experts). Routing indices
    /// that are `>= num_experts()` are skipped silently rather than reported.
    pub fn expert_counts(&self) -> Vec<f64> {
        let num_experts = self.num_experts();
        self.routed_expert_per_token
            .iter()
            .copied()
            .filter(|&expert| expert < num_experts)
            .fold(vec![0.0_f64; num_experts], |mut tally, expert| {
                tally[expert] += 1.0;
                tally
            })
    }

    /// Per-expert sum of gate scores over the whole batch (the column sums of the
    /// score matrix), accumulated row by row.
    pub fn expert_importance(&self) -> Vec<f64> {
        let mut totals = vec![0.0_f64; self.num_experts()];
        for row in self.gate_scores_per_token.rows() {
            for (total, &score) in totals.iter_mut().zip(row.iter()) {
                *total += score;
            }
        }
        totals
    }
}
/// Squared coefficient of variation: population variance divided by the squared mean.
///
/// The variance divides by `n`, not `n - 1`. Returns `0.0` for an empty slice and when
/// `|mean| <= f64::EPSILON`; in the latter case the ratio would be ill-conditioned or `0/0`.
/// A NaN anywhere in `values` propagates to a NaN result.
fn cv_squared(values: &[f64]) -> f64 {
    let n = match values.len() {
        0 => return 0.0,
        len => len as f64,
    };
    let mean = values.iter().sum::<f64>() / n;
    // Written as `<=` (not `!(… > …)`) so that a NaN mean falls through and yields NaN.
    if mean.abs() <= f64::EPSILON {
        return 0.0;
    }
    let squared_deviation_total: f64 = values
        .iter()
        .map(|&x| {
            let delta = x - mean;
            delta * delta
        })
        .sum();
    (squared_deviation_total / n) / (mean * mean)
}
/// Importance loss: squared coefficient of variation of the per-expert gate-score sums.
///
/// The per-expert totals come from [`BatchGatingStats::expert_importance`]. The loss is
/// `0.0` when every expert receives the same total score.
///
/// # Errors
///
/// Returns [`MoeError::EmptyExpertPool`] if `stats` has no expert columns.
pub fn importance_loss(stats: &BatchGatingStats) -> MoeResult<f64> {
    match stats.num_experts() {
        0 => Err(MoeError::EmptyExpertPool),
        _ => Ok(cv_squared(&stats.expert_importance())),
    }
}
/// Load loss: squared coefficient of variation of how many tokens each expert received.
///
/// The per-expert counts come from [`BatchGatingStats::expert_counts`], so routing
/// indices outside the expert pool are ignored. The loss is `0.0` when tokens are
/// spread evenly across experts.
///
/// # Errors
///
/// Returns [`MoeError::EmptyExpertPool`] if `stats` has no expert columns.
pub fn load_loss(stats: &BatchGatingStats) -> MoeResult<f64> {
    if stats.num_experts() > 0 {
        let per_expert_counts = stats.expert_counts();
        Ok(cv_squared(&per_expert_counts))
    } else {
        Err(MoeError::EmptyExpertPool)
    }
}
/// Combined auxiliary balancing loss: `alpha * (importance_loss + load_loss)`.
///
/// The importance term is evaluated first. If it fails, its error is returned and the
/// load term is not computed.
///
/// # Errors
///
/// Propagates the error from [`importance_loss`] or [`load_loss`], which is
/// [`MoeError::EmptyExpertPool`] when `stats` has no expert columns.
pub fn combined_aux_loss(stats: &BatchGatingStats, alpha: f64) -> MoeResult<f64> {
    importance_loss(stats).and_then(|importance| {
        load_loss(stats).map(|load| alpha * (importance + load))
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds stats from a row-major `(tokens, experts)` score buffer and routing indices.
    fn stats_from(
        tokens: usize,
        experts: usize,
        scores: Vec<f64>,
        routed: Vec<usize>,
    ) -> BatchGatingStats {
        BatchGatingStats {
            gate_scores_per_token: Array2::from_shape_vec((tokens, experts), scores)
                .expect("score buffer length must equal tokens * experts"),
            routed_expert_per_token: routed,
        }
    }

    /// Unwraps a loss result without requiring `MoeError: Debug`.
    fn ok_value(result: MoeResult<f64>) -> f64 {
        match result {
            Ok(value) => value,
            Err(_) => panic!("expected Ok, got an error"),
        }
    }

    #[test]
    fn cv_squared_zero_for_uniform() {
        let uniform = vec![1.0, 1.0, 1.0, 1.0];
        assert!(cv_squared(&uniform).abs() < 1e-12);
    }

    #[test]
    fn cv_squared_positive_for_skew() {
        // mean = 1, population variance = (9 + 1 + 1 + 1) / 4 = 3, so CV^2 = 3.
        let skewed = vec![4.0, 0.0, 0.0, 0.0];
        let cv2 = cv_squared(&skewed);
        assert!(cv2 > 0.0);
        assert!((cv2 - 3.0).abs() < 1e-12);
    }

    #[test]
    fn cv_squared_zero_on_zero_mean() {
        let zeros = vec![0.0, 0.0, 0.0];
        assert!(cv_squared(&zeros).abs() < 1e-12);
    }

    #[test]
    fn cv_squared_zero_on_empty_input() {
        assert_eq!(cv_squared(&[]), 0.0);
    }

    #[test]
    fn losses_reject_empty_expert_pool() {
        let stats = BatchGatingStats::empty(3, 0);
        assert!(matches!(importance_loss(&stats), Err(MoeError::EmptyExpertPool)));
        assert!(matches!(load_loss(&stats), Err(MoeError::EmptyExpertPool)));
        assert!(matches!(
            combined_aux_loss(&stats, 1.0),
            Err(MoeError::EmptyExpertPool)
        ));
    }

    #[test]
    fn balanced_routing_has_zero_aux_loss() {
        let stats = stats_from(4, 2, vec![0.5; 8], vec![0, 1, 0, 1]);
        assert!(ok_value(importance_loss(&stats)).abs() < 1e-12);
        assert!(ok_value(load_loss(&stats)).abs() < 1e-12);
        assert!(ok_value(combined_aux_loss(&stats, 0.01)).abs() < 1e-12);
    }

    #[test]
    fn fully_skewed_routing_is_penalised() {
        // Every token favours and is routed to expert 0: importance = counts = [4, 0].
        // mean = 2, population variance = 4, so each CV^2 term equals 1.
        let stats = stats_from(
            4,
            2,
            vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
            vec![0, 0, 0, 0],
        );
        assert!((ok_value(importance_loss(&stats)) - 1.0).abs() < 1e-12);
        assert!((ok_value(load_loss(&stats)) - 1.0).abs() < 1e-12);
        // alpha scales the sum of both terms: 0.5 * (1 + 1) = 1.
        assert!((ok_value(combined_aux_loss(&stats, 0.5)) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn expert_counts_skip_out_of_range_indices() {
        let stats = stats_from(3, 2, vec![0.0; 6], vec![0, 1, 7]);
        assert_eq!(stats.expert_counts(), vec![1.0, 1.0]);
    }
}