mod router;
use crate::learner::StreamingLearner;
use irithyll_core::error::ConfigError;
use router::LinearRouter;
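/// Per-expert bookkeeping: the boxed learner plus the routing statistics the
/// mixture tracks for it (warmup hint, utilization EWMA, samples trained).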
struct ExpertSlot {
model: Box<dyn StreamingLearner>,
warmup_hint: usize,
utilization_ewma: f64,
samples_trained: u64,
}
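/// Configuration for [`NeuralMoE`].
///
/// - `top_k`: number of experts trained and mixed per sample.
/// - `router_lr`: learning rate of the linear router (0 freezes routing).
/// - `load_balance_rate`: strength of the load-balancing update, in [0, 1].
/// - `utilization_span`: EWMA span, in samples, for utilization tracking.
/// - `utilization_threshold`: utilization below which an expert counts as dead.
/// - `reset_dead`: whether to reset dead experts during training.
/// - `seed`: RNG seed for router initialization.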
#[derive(Debug, Clone)]
pub struct NeuralMoEConfig {
pub top_k: usize,
pub router_lr: f64,
pub load_balance_rate: f64,
pub utilization_span: usize,
pub utilization_threshold: f64,
pub reset_dead: bool,
pub seed: u64,
}
impl Default for NeuralMoEConfig {
fn default() -> Self {
Self {
top_k: 2,
router_lr: 0.01,
load_balance_rate: 0.01,
utilization_span: 500,
utilization_threshold: 0.01,
reset_dead: true,
seed: 42,
}
}
}
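/// A streaming mixture-of-experts: a linear softmax router gates a pool of
/// heterogeneous [`StreamingLearner`] experts, and only the top-k experts are
/// trained and mixed on each sample.
///
/// A minimal usage sketch, mirroring the tests below (marked `ignore` only
/// because the doc-test cannot assume this crate's import path):
///
/// ```ignore
/// // assuming the `linear` and `sgbt` expert constructors from this crate
/// let mut moe = NeuralMoE::builder()
///     .expert(linear(0.01))
///     .expert(sgbt(10, 0.01))
///     .top_k(1)
///     .build()?;
/// moe.train(&[0.5, 1.0], 2.0);
/// let pred = moe.predict(&[0.5, 1.0]);
/// ```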
pub struct NeuralMoE {
experts: Vec<ExpertSlot>,
router: LinearRouter,
config: NeuralMoEConfig,
n_samples: u64,
cached_disagreement: f64,
prev_prediction: f64,
prev_change: f64,
prev_prev_change: f64,
alignment_ewma: f64,
gate_entropy_ewma: f64,
}
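/// Builder for [`NeuralMoE`]: collects `(expert, warmup_hint)` pairs and
/// configuration overrides, then validates everything in [`NeuralMoEBuilder::build`].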
pub struct NeuralMoEBuilder {
experts: Vec<(Box<dyn StreamingLearner>, usize)>,
config: NeuralMoEConfig,
}
impl NeuralMoE {
pub fn builder() -> NeuralMoEBuilder {
NeuralMoEBuilder {
experts: Vec::new(),
config: NeuralMoEConfig::default(),
}
}
}
impl NeuralMoEBuilder {
pub fn expert(mut self, model: impl StreamingLearner + 'static) -> Self {
self.experts.push((Box::new(model), 0));
self
}
pub fn expert_with_warmup(
mut self,
model: impl StreamingLearner + 'static,
warmup: usize,
) -> Self {
self.experts.push((Box::new(model), warmup));
self
}
pub fn top_k(mut self, k: usize) -> Self {
self.config.top_k = k;
self
}
pub fn router_lr(mut self, lr: f64) -> Self {
self.config.router_lr = lr;
self
}
pub fn load_balance_rate(mut self, r: f64) -> Self {
self.config.load_balance_rate = r;
self
}
pub fn utilization_span(mut self, s: usize) -> Self {
self.config.utilization_span = s;
self
}
pub fn utilization_threshold(mut self, t: f64) -> Self {
self.config.utilization_threshold = t;
self
}
pub fn reset_dead(mut self, b: bool) -> Self {
self.config.reset_dead = b;
self
}
pub fn seed(mut self, s: u64) -> Self {
self.config.seed = s;
self
}
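/// Validates the configuration and constructs the mixture.
///
/// Errors if fewer than two experts were added, `top_k` is 0 or exceeds the
/// expert count, `router_lr` is negative, `load_balance_rate` is outside
/// [0, 1], `utilization_threshold` is outside (0, 1), or `utilization_span`
/// is 0.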
pub fn build(self) -> Result<NeuralMoE, ConfigError> {
let n_experts = self.experts.len();
if n_experts < 2 {
return Err(ConfigError::out_of_range(
"n_experts",
"must be >= 2",
n_experts,
));
}
let config = &self.config;
if config.top_k == 0 {
return Err(ConfigError::out_of_range(
"top_k",
"must be >= 1",
config.top_k,
));
}
if config.top_k > n_experts {
return Err(ConfigError::invalid(
"top_k",
format!("must be <= n_experts ({}), got {}", n_experts, config.top_k),
));
}
if config.router_lr < 0.0 {
return Err(ConfigError::out_of_range(
"router_lr",
"must be >= 0 (0 = frozen router)",
config.router_lr,
));
}
if config.load_balance_rate < 0.0 || config.load_balance_rate > 1.0 {
return Err(ConfigError::out_of_range(
"load_balance_rate",
"must be in [0, 1]",
config.load_balance_rate,
));
}
if config.utilization_threshold <= 0.0 || config.utilization_threshold >= 1.0 {
return Err(ConfigError::out_of_range(
"utilization_threshold",
"must be in (0, 1)",
config.utilization_threshold,
));
}
if config.utilization_span == 0 {
return Err(ConfigError::out_of_range(
"utilization_span",
"must be >= 1",
config.utilization_span,
));
}
let config = self.config;
let router = LinearRouter::new(
n_experts,
config.router_lr,
config.load_balance_rate,
config.utilization_span,
config.seed,
);
let experts: Vec<ExpertSlot> = self
.experts
.into_iter()
.map(|(model, warmup)| ExpertSlot {
model,
warmup_hint: warmup,
utilization_ewma: 0.0,
samples_trained: 0,
})
.collect();
Ok(NeuralMoE {
experts,
router,
config,
n_samples: 0,
cached_disagreement: 0.0,
prev_prediction: 0.0,
prev_change: 0.0,
prev_prev_change: 0.0,
alignment_ewma: 0.0,
gate_entropy_ewma: 0.0,
})
}
}
impl NeuralMoE {
pub fn n_experts(&self) -> usize {
self.experts.len()
}
pub fn top_k(&self) -> usize {
self.config.top_k
}
pub fn utilization(&self) -> Vec<f64> {
self.experts.iter().map(|e| e.utilization_ewma).collect()
}
pub fn expert_samples(&self) -> Vec<u64> {
self.experts.iter().map(|e| e.samples_trained).collect()
}
pub fn n_dead_experts(&self) -> usize {
self.experts
.iter()
.filter(|e| {
e.samples_trained > self.config.utilization_span as u64
&& e.utilization_ewma < self.config.utilization_threshold
})
.count()
}
pub fn load_distribution(&self) -> &[f64] {
self.router.load_distribution()
}
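/// Sample standard deviation of all experts' predictions at `features`, an
/// ensemble-uncertainty proxy. See [`NeuralMoE::cached_disagreement`] for the
/// cheaper top-k variant refreshed during training.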
pub fn expert_disagreement(&self, features: &[f64]) -> f64 {
let preds = self.expert_predictions(features);
if preds.len() < 2 {
return 0.0;
}
let n = preds.len() as f64;
let mean = preds.iter().sum::<f64>() / n;
let var = preds.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / (n - 1.0);
var.sqrt()
}
#[inline]
pub fn cached_disagreement(&self) -> f64 {
self.cached_disagreement
}
pub fn expert_predictions(&self, features: &[f64]) -> Vec<f64> {
self.experts
.iter()
.map(|e| e.model.predict(features))
.collect()
}
pub fn routing_probabilities(&self, features: &[f64]) -> Vec<f64> {
self.router.probabilities(features)
}
pub fn warmup_progress(&self) -> Vec<f64> {
self.experts
.iter()
.map(|e| {
if e.warmup_hint == 0 {
1.0
} else {
(e.samples_trained as f64 / e.warmup_hint as f64).min(1.0)
}
})
.collect()
}
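/// Router logit penalty per expert: `-5.0 * (1.0 - progress)`, where
/// `progress = min(samples_trained / warmup_hint, 1)`. A cold expert starts
/// at -5.0 (strongly suppressed) and decays linearly to 0 as it warms up.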
fn warmup_penalties(&self) -> Vec<f64> {
self.experts
.iter()
.map(|e| {
let progress = if e.warmup_hint == 0 {
1.0
} else {
(e.samples_trained as f64 / e.warmup_hint as f64).min(1.0)
};
-5.0 * (1.0 - progress)
})
.collect()
}
}
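// Per-sample training flow (`train_one`): select the top-k experts with
// warmup penalties applied, cache the top-k prediction spread, update the
// curvature-alignment EWMA, train each selected expert, train the router
// toward the lowest-absolute-error expert, apply the load-balance update,
// refresh the utilization and gate-entropy EWMAs, and optionally reset dead
// experts.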
impl StreamingLearner for NeuralMoE {
fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
let k = self.config.top_k.min(self.experts.len());
let penalties = self.warmup_penalties();
let active_indices = self
.router
.select_top_k_with_penalties(features, k, &penalties);
let mut best_idx = active_indices[0];
let mut best_error = f64::INFINITY;
let mut active_preds: Vec<f64> = Vec::with_capacity(k);
for &idx in &active_indices {
let pred = self.experts[idx].model.predict(features);
active_preds.push(pred);
let error = (target - pred).abs();
if error < best_error {
best_error = error;
best_idx = idx;
}
}
if active_preds.len() >= 2 {
let n = active_preds.len() as f64;
let mean = active_preds.iter().sum::<f64>() / n;
let var = active_preds.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / (n - 1.0);
self.cached_disagreement = var.sqrt();
}
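// Track how consistently the mixture prediction's second difference
// (acceleration) keeps its sign between consecutive samples; the EWMA of
// this +1/-1 agreement feeds the `residual_alignment` diagnostic.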
{
let weights = self.router.renormalized_weights_with_penalties(
features,
&active_indices,
&penalties,
);
let mut current_pred = 0.0;
for (idx, w) in &weights {
// Map the router's expert index back to its slot in `active_preds`.
let pos = active_indices.iter().position(|&i| i == *idx).unwrap_or(0);
current_pred += w * active_preds[pos];
}
let current_change = current_pred - self.prev_prediction;
if self.n_samples > 0 {
let acceleration = current_change - self.prev_change;
let prev_acceleration = self.prev_change - self.prev_prev_change;
let agreement = if acceleration.abs() > 1e-15 && prev_acceleration.abs() > 1e-15 {
if (acceleration > 0.0) == (prev_acceleration > 0.0) {
1.0
} else {
-1.0
}
} else {
0.0
};
const ALIGN_ALPHA: f64 = 0.05;
self.alignment_ewma =
(1.0 - ALIGN_ALPHA) * self.alignment_ewma + ALIGN_ALPHA * agreement;
}
self.prev_prev_change = self.prev_change;
self.prev_change = current_change;
self.prev_prediction = current_pred;
}
for &idx in &active_indices {
self.experts[idx].model.train_one(features, target, weight);
self.experts[idx].samples_trained += 1;
}
self.router.update(features, best_idx);
self.router.update_load_balance(&active_indices);
let probs = self.router.probabilities(features);
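// Standard EWMA smoothing factor for a span of N samples: alpha = 2 / (N + 1).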
let util_alpha = 2.0 / (self.config.utilization_span as f64 + 1.0);
for (i, slot) in self.experts.iter_mut().enumerate() {
let p = if i < probs.len() { probs[i] } else { 0.0 };
slot.utilization_ewma = util_alpha * p + (1.0 - util_alpha) * slot.utilization_ewma;
}
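// Normalized gate entropy H / ln(K) in [0, 1]: 1 means uniform routing, 0
// means a single expert dominates. Its slow EWMA feeds the
// `depth_sufficiency` diagnostic.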
{
let k_experts = probs.len();
if k_experts > 1 {
let ln_k = (k_experts as f64).ln();
let mut h = 0.0;
for &p in &probs {
if p > 1e-15 {
h -= p * p.ln();
}
}
let normalized_h = (h / ln_k).clamp(0.0, 1.0);
const GATE_ALPHA: f64 = 0.01;
self.gate_entropy_ewma =
(1.0 - GATE_ALPHA) * self.gate_entropy_ewma + GATE_ALPHA * normalized_h;
}
}
if self.config.reset_dead && self.n_samples > self.config.utilization_span as u64 {
self.reset_dead_experts();
}
self.n_samples += 1;
}
fn predict(&self, features: &[f64]) -> f64 {
let k = self.config.top_k.min(self.experts.len());
let penalties = self.warmup_penalties();
let active_indices = self
.router
.select_top_k_with_penalties(features, k, &penalties);
let weights =
self.router
.renormalized_weights_with_penalties(features, &active_indices, &penalties);
let mut pred = 0.0;
for (idx, w) in &weights {
pred += w * self.experts[*idx].model.predict(features);
}
pred
}
fn n_samples_seen(&self) -> u64 {
self.n_samples
}
fn reset(&mut self) {
for slot in &mut self.experts {
slot.model.reset();
slot.utilization_ewma = 0.0;
slot.samples_trained = 0;
}
self.router.reset();
self.n_samples = 0;
self.cached_disagreement = 0.0;
self.prev_prediction = 0.0;
self.prev_change = 0.0;
self.prev_prev_change = 0.0;
self.alignment_ewma = 0.0;
self.gate_entropy_ewma = 0.0;
}
#[allow(deprecated)]
fn diagnostics_array(&self) -> [f64; 5] {
<Self as crate::learner::Tunable>::diagnostics_array(self)
}
}
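// The 5-element diagnostics array mirrors `ConfigDiagnostics` field order:
// [residual_alignment, regularization_sensitivity, depth_sufficiency,
// effective_dof, uncertainty].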
impl crate::learner::Tunable for NeuralMoE {
fn diagnostics_array(&self) -> [f64; 5] {
use crate::automl::DiagnosticSource;
match self.config_diagnostics() {
Some(d) => [
d.residual_alignment,
d.regularization_sensitivity,
d.depth_sufficiency,
d.effective_dof,
d.uncertainty,
],
None => [0.0; 5],
}
}
fn adjust_config(&mut self, _lr_multiplier: f64, _lambda_delta: f64) {
// No-op: the router's hyperparameters are fixed at build time, so AutoML
// adjustments are accepted but ignored here.
}
}
impl NeuralMoE {
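/// Resets any expert that has trained past the utilization span yet whose
/// utilization EWMA sits below the configured threshold; its model, EWMA,
/// and sample count all restart from zero.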
fn reset_dead_experts(&mut self) {
for slot in &mut self.experts {
if slot.samples_trained > self.config.utilization_span as u64
&& slot.utilization_ewma < self.config.utilization_threshold
{
slot.model.reset();
slot.utilization_ewma = 0.0;
slot.samples_trained = 0;
}
}
}
}
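// Diagnostics mapping: curvature alignment -> residual_alignment, the
// load-balance rate -> regularization_sensitivity, normalized gate entropy ->
// depth_sufficiency, expert count -> effective_dof, and the cached top-k
// prediction spread -> uncertainty.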
impl crate::automl::DiagnosticSource for NeuralMoE {
fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
let depth_sufficiency = self.gate_entropy_ewma.clamp(0.0, 1.0);
Some(crate::automl::ConfigDiagnostics {
residual_alignment: self.alignment_ewma,
regularization_sensitivity: self.config.load_balance_rate,
depth_sufficiency,
effective_dof: self.n_experts() as f64,
uncertainty: self.cached_disagreement,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{linear, rls, sgbt};
#[test]
fn builder_creates_moe() {
let moe = NeuralMoE::builder()
.expert(sgbt(10, 0.01))
.expert(sgbt(20, 0.01))
.expert(linear(0.01))
.top_k(2)
.build()
.unwrap();
assert_eq!(moe.n_experts(), 3);
assert_eq!(moe.top_k(), 2);
assert_eq!(moe.n_samples_seen(), 0);
}
#[test]
fn builder_errors_with_one_expert() {
use irithyll_core::error::ConfigError;
let result = NeuralMoE::builder().expert(sgbt(10, 0.01)).build();
assert!(result.is_err(), "expected Err with one expert");
let err = result.err().unwrap();
assert!(
matches!(&err, ConfigError::OutOfRange { param, .. } if *param == "n_experts"),
"expected OutOfRange for n_experts"
);
}
#[test]
fn train_and_predict_finite() {
let mut moe = NeuralMoE::builder()
.expert(sgbt(10, 0.01))
.expert(sgbt(20, 0.01))
.expert(linear(0.01))
.top_k(2)
.build()
.unwrap();
for i in 0..100 {
let x = [i as f64 * 0.01, (i as f64).sin()];
let y = x[0] * 2.0 + 1.0;
moe.train(&x, y);
}
let pred = moe.predict(&[0.5, 0.5_f64.sin()]);
assert!(pred.is_finite(), "prediction should be finite, got {pred}");
}
#[test]
fn n_samples_tracks_correctly() {
let mut moe = NeuralMoE::builder()
.expert(linear(0.01))
.expert(linear(0.02))
.build()
.unwrap();
for i in 0..42 {
moe.train(&[i as f64], i as f64 * 2.0);
}
assert_eq!(moe.n_samples_seen(), 42);
}
#[test]
fn reset_clears_state() {
let mut moe = NeuralMoE::builder()
.expert(linear(0.01))
.expert(linear(0.02))
.build()
.unwrap();
for i in 0..50 {
moe.train(&[i as f64], i as f64);
}
assert!(moe.n_samples_seen() > 0);
moe.reset();
assert_eq!(moe.n_samples_seen(), 0);
for s in moe.expert_samples() {
assert_eq!(s, 0, "expert samples should be 0 after reset");
}
}
#[test]
fn implements_streaming_learner() {
let moe = NeuralMoE::builder()
.expert(linear(0.01))
.expert(linear(0.02))
.build()
.unwrap();
let mut boxed: Box<dyn StreamingLearner> = Box::new(moe);
boxed.train(&[1.0], 2.0);
let pred = boxed.predict(&[1.0]);
assert!(pred.is_finite(), "trait object prediction should be finite");
}
#[test]
fn expert_predictions_returns_all() {
let moe = NeuralMoE::builder()
.expert(linear(0.01))
.expert(linear(0.02))
.expert(linear(0.05))
.top_k(2)
.build()
.unwrap();
let preds = moe.expert_predictions(&[1.0]);
assert_eq!(preds.len(), 3, "should have predictions from all 3 experts");
}
#[test]
fn routing_probabilities_sum_to_one() {
let moe = NeuralMoE::builder()
.expert(sgbt(10, 0.01))
.expert(sgbt(20, 0.01))
.expert(linear(0.01))
.build()
.unwrap();
let probs = moe.routing_probabilities(&[1.0, 2.0]);
let sum: f64 = probs.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-10,
"routing probabilities should sum to 1.0, got {sum}"
);
}
#[test]
fn utilization_starts_at_zero() {
let moe = NeuralMoE::builder()
.expert(linear(0.01))
.expert(linear(0.02))
.build()
.unwrap();
for u in moe.utilization() {
assert!((u - 0.0).abs() < 1e-12, "initial utilization should be 0.0");
}
}
#[test]
fn warmup_hint_stored() {
let moe = NeuralMoE::builder()
.expert(linear(0.01))
.expert_with_warmup(linear(0.02), 50)
.build()
.unwrap();
assert_eq!(moe.experts[0].warmup_hint, 0, "first expert has no warmup");
assert_eq!(
moe.experts[1].warmup_hint, 50,
"second expert has warmup 50"
);
}
#[test]
fn heterogeneous_experts_work() {
let mut moe = NeuralMoE::builder()
.expert(sgbt(10, 0.01))
.expert(linear(0.01))
.expert(rls(0.99))
.top_k(2)
.build()
.unwrap();
for i in 0..200 {
let x = [i as f64 * 0.01, (i as f64 * 0.1).sin()];
let y = x[0] * 3.0 + x[1] * 2.0 + 1.0;
moe.train(&x, y);
}
let pred = moe.predict(&[1.0, 1.0_f64.sin()]);
assert!(
pred.is_finite(),
"heterogeneous MoE prediction should be finite, got {pred}"
);
}
#[test]
fn top_k_limits_active_experts() {
let mut moe = NeuralMoE::builder()
.expert(linear(0.01))
.expert(linear(0.02))
.expert(linear(0.03))
.expert(linear(0.04))
.top_k(1)
.build()
.unwrap();
for i in 0..100 {
moe.train(&[i as f64], i as f64 * 2.0);
}
let samples = moe.expert_samples();
let total_expert_trains: u64 = samples.iter().sum();
assert_eq!(
total_expert_trains, 100,
"with top_k=1, total expert trains should equal n_samples"
);
}
#[test]
fn load_distribution_available() {
let moe = NeuralMoE::builder()
.expert(linear(0.01))
.expert(linear(0.02))
.build()
.unwrap();
let load = moe.load_distribution();
assert_eq!(load.len(), 2, "load distribution should have 2 entries");
}
#[test]
fn custom_config() {
let moe = NeuralMoE::builder()
.expert(linear(0.01))
.expert(linear(0.02))
.top_k(1)
.router_lr(0.05)
.load_balance_rate(0.02)
.utilization_span(200)
.utilization_threshold(0.05)
.reset_dead(false)
.seed(999)
.build()
.unwrap();
assert_eq!(moe.config.top_k, 1);
assert!((moe.config.router_lr - 0.05).abs() < 1e-12);
assert!((moe.config.load_balance_rate - 0.02).abs() < 1e-12);
assert_eq!(moe.config.utilization_span, 200);
assert!((moe.config.utilization_threshold - 0.05).abs() < 1e-12);
assert!(!moe.config.reset_dead);
assert_eq!(moe.config.seed, 999);
}
#[test]
fn moe_expert_disagreement() {
let mut moe = NeuralMoE::builder()
.expert(sgbt(10, 0.01))
.expert(sgbt(20, 0.01))
.expert(linear(0.01))
.top_k(2)
.build()
.unwrap();
assert!(
moe.cached_disagreement().abs() < 1e-15,
"cached_disagreement should be 0 before training, got {}",
moe.cached_disagreement()
);
for i in 0..100 {
let x = [i as f64 * 0.01, (i as f64).sin()];
let y = x[0] * 2.0 + 1.0;
moe.train(&x, y);
}
let disagree = moe.cached_disagreement();
assert!(
disagree >= 0.0,
"expert_disagreement should be >= 0, got {}",
disagree
);
assert!(
disagree.is_finite(),
"expert_disagreement should be finite, got {}",
disagree
);
let direct = moe.expert_disagreement(&[0.5, 0.5_f64.sin()]);
assert!(
direct.is_finite(),
"expert_disagreement() should be finite, got {}",
direct
);
}
#[test]
fn warmup_progress_accessor() {
let moe = NeuralMoE::builder()
.expert(linear(0.01))
.expert_with_warmup(linear(0.02), 100)
.expert_with_warmup(linear(0.03), 0)
.build()
.unwrap();
let progress = moe.warmup_progress();
assert_eq!(progress.len(), 3, "should have 3 progress values");
assert!(
(progress[0] - 1.0).abs() < 1e-12,
"expert with no warmup (hint=0) should have progress 1.0, got {}",
progress[0]
);
assert!(
(progress[1] - 0.0).abs() < 1e-12,
"expert with warmup=100 and 0 samples should have progress 0.0, got {}",
progress[1]
);
assert!(
(progress[2] - 1.0).abs() < 1e-12,
"expert with warmup_hint=0 should have progress 1.0, got {}",
progress[2]
);
}
#[test]
fn warmup_suppresses_cold_expert() {
let mut moe = NeuralMoE::builder()
.expert(linear(0.01))
.expert_with_warmup(linear(0.02), 100)
.top_k(1)
.router_lr(0.0)
.load_balance_rate(0.0)
.reset_dead(false)
.build()
.unwrap();
for i in 0..50 {
let x = [i as f64 * 0.01];
let y = x[0] * 2.0;
moe.train(&x, y);
}
let progress = moe.warmup_progress();
assert!(
(progress[0] - 1.0).abs() < 1e-12,
"expert 0 (no warmup) should always be at progress 1.0, got {}",
progress[0]
);
assert!(
progress[1] < 1.0,
"expert 1 should not be fully warmed up yet, got progress {}",
progress[1]
);
let mut expert0_count = 0u64;
let mut expert1_count = 0u64;
for i in 50..100 {
let x = [i as f64 * 0.01];
let y = x[0] * 2.0;
let samples_before = moe.expert_samples();
moe.train(&x, y);
let samples_after = moe.expert_samples();
if samples_after[0] > samples_before[0] {
expert0_count += 1;
}
if samples_after[1] > samples_before[1] {
expert1_count += 1;
}
}
assert!(
expert0_count > expert1_count,
"non-warmup expert should be routed more often than cold expert \
(expert0={}, expert1={})",
expert0_count,
expert1_count
);
}
#[test]
fn warmup_eventually_routes_normally() {
let mut moe = NeuralMoE::builder()
.expert(linear(0.01))
.expert_with_warmup(linear(0.02), 20)
.top_k(2)
.reset_dead(false)
.build()
.unwrap();
let penalties = moe.warmup_penalties();
assert!(
(penalties[0] - 0.0).abs() < 1e-12,
"expert 0 (no warmup) should have 0 penalty, got {}",
penalties[0]
);
assert!(
(penalties[1] - (-5.0)).abs() < 1e-12,
"expert 1 (cold, 0 samples) should have -5.0 penalty, got {}",
penalties[1]
);
for i in 0..50 {
let x = [i as f64 * 0.01];
let y = x[0] * 2.0;
moe.train(&x, y);
}
let progress = moe.warmup_progress();
assert!(
(progress[1] - 1.0).abs() < 1e-12,
"expert 1 should be fully warmed up after 50 samples with top_k=2 (warmup=20), \
got progress {}",
progress[1]
);
let penalties_after = moe.warmup_penalties();
assert!(
penalties_after[0].abs() < 1e-12,
"expert 0 penalty should be 0 after warmup, got {}",
penalties_after[0]
);
assert!(
penalties_after[1].abs() < 1e-12,
"expert 1 penalty should be 0 after warmup, got {}",
penalties_after[1]
);
let progress_formula = (10.0_f64 / 20.0).min(1.0);
let expected_penalty = -5.0 * (1.0 - progress_formula);
assert!(
(expected_penalty - (-2.5)).abs() < 1e-12,
"mid-warmup penalty formula check: expected -2.5, got {}",
expected_penalty
);
}
#[test]
fn neural_moe_seed_observably_affects_initial_routing() {
let moe_seed_1 = NeuralMoE::builder()
.expert(linear(0.01))
.expert(linear(0.02))
.expert(linear(0.03))
.seed(123)
.build()
.unwrap();
let moe_seed_2 = NeuralMoE::builder()
.expert(linear(0.01))
.expert(linear(0.02))
.expert(linear(0.03))
.seed(456)
.build()
.unwrap();
assert_eq!(
moe_seed_1.config.seed, 123,
"seed should be stored in config"
);
assert_eq!(
moe_seed_2.config.seed, 456,
"different seed should be stored"
);
assert_ne!(
moe_seed_1.config.seed, moe_seed_2.config.seed,
"Seeds should be distinct when set to different values"
);
}
}