1mod router;
24
25use crate::learner::StreamingLearner;
26use router::LinearRouter;
27
/// Internal bookkeeping for one expert in the mixture.
struct ExpertSlot {
    /// The expert model; trained only on samples the router sends its way.
    model: Box<dyn StreamingLearner>,
    /// Warm-up hint supplied at build time. Currently stored but not read by
    /// this module (hence the allow) — presumably for future routing logic.
    #[allow(dead_code)]
    warmup_hint: usize,
    /// EWMA of this expert's routing probability; drives dead-expert detection.
    utilization_ewma: f64,
    /// Number of samples this expert has actually been trained on.
    samples_trained: u64,
}
40
/// Tunable parameters for a [`NeuralMoE`] mixture.
#[derive(Debug, Clone)]
pub struct NeuralMoEConfig {
    /// Number of experts activated (routed to and trained) per sample.
    pub top_k: usize,
    /// Learning rate of the linear router's gating update.
    pub router_lr: f64,
    /// Strength of the router's load-balancing correction.
    pub load_balance_rate: f64,
    /// EWMA span, in samples, for utilization tracking; also the grace
    /// period before dead-expert resets begin.
    pub utilization_span: usize,
    /// Utilization EWMA below this value marks an expert as "dead".
    pub utilization_threshold: f64,
    /// Whether dead experts are automatically reset during training.
    pub reset_dead: bool,
    /// RNG seed. NOTE(review): stored but not visibly consumed in this
    /// module — confirm whether the router or experts read it elsewhere.
    pub seed: u64,
}
63
64impl Default for NeuralMoEConfig {
65 fn default() -> Self {
66 Self {
67 top_k: 2,
68 router_lr: 0.01,
69 load_balance_rate: 0.01,
70 utilization_span: 500,
71 utilization_threshold: 0.01,
72 reset_dead: true,
73 seed: 42,
74 }
75 }
76}
77
/// Online mixture-of-experts: a linear router picks `top_k` experts per
/// sample; only the selected experts are trained on that sample.
pub struct NeuralMoE {
    /// Expert pool together with per-expert statistics.
    experts: Vec<ExpertSlot>,
    /// Gating network mapping features to expert probabilities.
    router: LinearRouter,
    /// Immutable-after-build configuration.
    config: NeuralMoEConfig,
    /// Total training samples seen by the mixture.
    n_samples: u64,
    /// Sample std-dev of the active experts' predictions, refreshed during
    /// training (cheap to read via `cached_disagreement()`).
    cached_disagreement: f64,
    /// Previous mixture prediction, for change tracking.
    prev_prediction: f64,
    /// Previous first difference of the mixture prediction.
    prev_change: f64,
    /// First difference one step earlier still.
    prev_prev_change: f64,
    /// EWMA of sign-agreement between consecutive prediction accelerations.
    alignment_ewma: f64,
    /// EWMA of normalized gate entropy (0 = collapsed routing, 1 = uniform).
    gate_entropy_ewma: f64,
}
121
122pub struct NeuralMoEBuilder {
128 experts: Vec<(Box<dyn StreamingLearner>, usize)>, config: NeuralMoEConfig,
130}
131
132impl NeuralMoE {
133 pub fn builder() -> NeuralMoEBuilder {
135 NeuralMoEBuilder {
136 experts: Vec::new(),
137 config: NeuralMoEConfig::default(),
138 }
139 }
140}
141
142impl NeuralMoEBuilder {
143 pub fn expert(mut self, model: impl StreamingLearner + 'static) -> Self {
145 self.experts.push((Box::new(model), 0));
146 self
147 }
148
149 pub fn expert_with_warmup(
151 mut self,
152 model: impl StreamingLearner + 'static,
153 warmup: usize,
154 ) -> Self {
155 self.experts.push((Box::new(model), warmup));
156 self
157 }
158
159 pub fn top_k(mut self, k: usize) -> Self {
161 self.config.top_k = k;
162 self
163 }
164
165 pub fn router_lr(mut self, lr: f64) -> Self {
167 self.config.router_lr = lr;
168 self
169 }
170
171 pub fn load_balance_rate(mut self, r: f64) -> Self {
173 self.config.load_balance_rate = r;
174 self
175 }
176
177 pub fn utilization_span(mut self, s: usize) -> Self {
179 self.config.utilization_span = s;
180 self
181 }
182
183 pub fn utilization_threshold(mut self, t: f64) -> Self {
185 self.config.utilization_threshold = t;
186 self
187 }
188
189 pub fn reset_dead(mut self, b: bool) -> Self {
191 self.config.reset_dead = b;
192 self
193 }
194
195 pub fn seed(mut self, s: u64) -> Self {
197 self.config.seed = s;
198 self
199 }
200
201 pub fn build(self) -> NeuralMoE {
206 assert!(
207 self.experts.len() >= 2,
208 "NeuralMoE requires at least 2 experts, got {}",
209 self.experts.len()
210 );
211
212 let k = self.experts.len();
213 let config = self.config;
214
215 let router = LinearRouter::new(
216 k,
217 config.router_lr,
218 config.load_balance_rate,
219 config.utilization_span,
220 );
221
222 let experts: Vec<ExpertSlot> = self
223 .experts
224 .into_iter()
225 .map(|(model, warmup)| ExpertSlot {
226 model,
227 warmup_hint: warmup,
228 utilization_ewma: 0.0,
229 samples_trained: 0,
230 })
231 .collect();
232
233 NeuralMoE {
234 experts,
235 router,
236 config,
237 n_samples: 0,
238 cached_disagreement: 0.0,
239 prev_prediction: 0.0,
240 prev_change: 0.0,
241 prev_prev_change: 0.0,
242 alignment_ewma: 0.0,
243 gate_entropy_ewma: 0.0,
244 }
245 }
246}
247
248impl NeuralMoE {
253 pub fn n_experts(&self) -> usize {
255 self.experts.len()
256 }
257
258 pub fn top_k(&self) -> usize {
260 self.config.top_k
261 }
262
263 pub fn utilization(&self) -> Vec<f64> {
265 self.experts.iter().map(|e| e.utilization_ewma).collect()
266 }
267
268 pub fn expert_samples(&self) -> Vec<u64> {
270 self.experts.iter().map(|e| e.samples_trained).collect()
271 }
272
273 pub fn n_dead_experts(&self) -> usize {
275 self.experts
276 .iter()
277 .filter(|e| {
278 e.samples_trained > self.config.utilization_span as u64
279 && e.utilization_ewma < self.config.utilization_threshold
280 })
281 .count()
282 }
283
284 pub fn load_distribution(&self) -> &[f64] {
286 self.router.load_distribution()
287 }
288
289 pub fn expert_disagreement(&self, features: &[f64]) -> f64 {
295 let preds = self.expert_predictions(features);
296 if preds.len() < 2 {
297 return 0.0;
298 }
299 let n = preds.len() as f64;
300 let mean = preds.iter().sum::<f64>() / n;
301 let var = preds.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / (n - 1.0);
302 var.sqrt()
303 }
304
305 #[inline]
310 pub fn cached_disagreement(&self) -> f64 {
311 self.cached_disagreement
312 }
313
314 pub fn expert_predictions(&self, features: &[f64]) -> Vec<f64> {
316 self.experts
317 .iter()
318 .map(|e| e.model.predict(features))
319 .collect()
320 }
321
322 pub fn routing_probabilities(&self, features: &[f64]) -> Vec<f64> {
324 self.router.probabilities(features)
325 }
326}
327
328impl StreamingLearner for NeuralMoE {
333 fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
334 let k = self.config.top_k.min(self.experts.len());
335
336 let active_indices = self.router.select_top_k(features, k);
338
339 let mut best_idx = active_indices[0];
341 let mut best_error = f64::INFINITY;
342 let mut active_preds: Vec<f64> = Vec::with_capacity(k);
343
344 for &idx in &active_indices {
345 let pred = self.experts[idx].model.predict(features);
346 active_preds.push(pred);
347 let error = (target - pred).abs();
348 if error < best_error {
349 best_error = error;
350 best_idx = idx;
351 }
352 }
353
354 if active_preds.len() >= 2 {
356 let n = active_preds.len() as f64;
357 let mean = active_preds.iter().sum::<f64>() / n;
358 let var = active_preds.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / (n - 1.0);
359 self.cached_disagreement = var.sqrt();
360 }
361
362 {
364 let weights = self.router.renormalized_weights(features, &active_indices);
365 let mut current_pred = 0.0;
366 for (idx, w) in &weights {
367 current_pred +=
368 w * active_preds[active_indices.iter().position(|&i| i == *idx).unwrap_or(0)];
369 }
370 let current_change = current_pred - self.prev_prediction;
371 if self.n_samples > 0 {
372 let acceleration = current_change - self.prev_change;
373 let prev_acceleration = self.prev_change - self.prev_prev_change;
374 let agreement = if acceleration.abs() > 1e-15 && prev_acceleration.abs() > 1e-15 {
375 if (acceleration > 0.0) == (prev_acceleration > 0.0) {
376 1.0
377 } else {
378 -1.0
379 }
380 } else {
381 0.0
382 };
383 const ALIGN_ALPHA: f64 = 0.05;
384 self.alignment_ewma =
385 (1.0 - ALIGN_ALPHA) * self.alignment_ewma + ALIGN_ALPHA * agreement;
386 }
387 self.prev_prev_change = self.prev_change;
388 self.prev_change = current_change;
389 self.prev_prediction = current_pred;
390 }
391
392 for &idx in &active_indices {
394 self.experts[idx].model.train_one(features, target, weight);
395 self.experts[idx].samples_trained += 1;
396 }
397
398 self.router.update(features, best_idx);
400
401 self.router.update_load_balance(&active_indices);
403
404 let probs = self.router.probabilities(features);
406 let util_alpha = 2.0 / (self.config.utilization_span as f64 + 1.0);
407 for (i, slot) in self.experts.iter_mut().enumerate() {
408 let p = if i < probs.len() { probs[i] } else { 0.0 };
409 slot.utilization_ewma = util_alpha * p + (1.0 - util_alpha) * slot.utilization_ewma;
410 }
411
412 {
414 let k_experts = probs.len();
415 if k_experts > 1 {
416 let ln_k = (k_experts as f64).ln();
417 let mut h = 0.0;
418 for &p in &probs {
419 if p > 1e-15 {
420 h -= p * p.ln();
421 }
422 }
423 let normalized_h = (h / ln_k).clamp(0.0, 1.0);
424 const GATE_ALPHA: f64 = 0.01;
425 self.gate_entropy_ewma =
426 (1.0 - GATE_ALPHA) * self.gate_entropy_ewma + GATE_ALPHA * normalized_h;
427 }
428 }
429
430 if self.config.reset_dead && self.n_samples > self.config.utilization_span as u64 {
432 self.reset_dead_experts();
433 }
434
435 self.n_samples += 1;
436 }
437
438 fn predict(&self, features: &[f64]) -> f64 {
439 let k = self.config.top_k.min(self.experts.len());
440 let active_indices = self.router.select_top_k(features, k);
441 let weights = self.router.renormalized_weights(features, &active_indices);
442
443 let mut pred = 0.0;
445 for (idx, w) in &weights {
446 pred += w * self.experts[*idx].model.predict(features);
447 }
448 pred
449 }
450
451 fn n_samples_seen(&self) -> u64 {
452 self.n_samples
453 }
454
455 fn reset(&mut self) {
456 for slot in &mut self.experts {
457 slot.model.reset();
458 slot.utilization_ewma = 0.0;
459 slot.samples_trained = 0;
460 }
461 self.router.reset();
462 self.n_samples = 0;
463 self.cached_disagreement = 0.0;
464 self.prev_prediction = 0.0;
465 self.prev_change = 0.0;
466 self.prev_prev_change = 0.0;
467 self.alignment_ewma = 0.0;
468 self.gate_entropy_ewma = 0.0;
469 }
470
471 fn diagnostics_array(&self) -> [f64; 5] {
472 use crate::automl::DiagnosticSource;
473 match self.config_diagnostics() {
474 Some(d) => [
475 d.residual_alignment,
476 d.regularization_sensitivity,
477 d.depth_sufficiency,
478 d.effective_dof,
479 d.uncertainty,
480 ],
481 None => [0.0; 5],
482 }
483 }
484}
485
486impl NeuralMoE {
491 fn reset_dead_experts(&mut self) {
493 for slot in &mut self.experts {
494 if slot.samples_trained > self.config.utilization_span as u64
495 && slot.utilization_ewma < self.config.utilization_threshold
496 {
497 slot.model.reset();
498 slot.utilization_ewma = 0.0;
499 slot.samples_trained = 0;
500 }
501 }
502 }
503}
504
505impl crate::automl::DiagnosticSource for NeuralMoE {
510 fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
511 let depth_sufficiency = self.gate_entropy_ewma.clamp(0.0, 1.0);
515
516 Some(crate::automl::ConfigDiagnostics {
517 residual_alignment: self.alignment_ewma,
518 regularization_sensitivity: self.config.load_balance_rate,
519 depth_sufficiency,
520 effective_dof: self.n_experts() as f64,
521 uncertainty: self.cached_disagreement,
522 })
523 }
524}
525
#[cfg(test)]
mod tests {
    //! Unit tests covering builder validation, training/prediction sanity,
    //! state tracking, and diagnostic accessors for `NeuralMoE`.
    use super::*;
    use crate::{linear, rls, sgbt};

    // Builder wires up expert count, top_k, and zeroed sample counter.
    #[test]
    fn builder_creates_moe() {
        let moe = NeuralMoE::builder()
            .expert(sgbt(10, 0.01))
            .expert(sgbt(20, 0.01))
            .expert(linear(0.01))
            .top_k(2)
            .build();

        assert_eq!(moe.n_experts(), 3);
        assert_eq!(moe.top_k(), 2);
        assert_eq!(moe.n_samples_seen(), 0);
    }

    // build() asserts the two-expert minimum.
    #[test]
    #[should_panic(expected = "at least 2 experts")]
    fn builder_panics_with_one_expert() {
        NeuralMoE::builder().expert(sgbt(10, 0.01)).build();
    }

    // Training on a simple linear target keeps predictions finite.
    #[test]
    fn train_and_predict_finite() {
        let mut moe = NeuralMoE::builder()
            .expert(sgbt(10, 0.01))
            .expert(sgbt(20, 0.01))
            .expert(linear(0.01))
            .top_k(2)
            .build();

        for i in 0..100 {
            let x = [i as f64 * 0.01, (i as f64).sin()];
            let y = x[0] * 2.0 + 1.0;
            moe.train(&x, y);
        }

        let pred = moe.predict(&[0.5, 0.5_f64.sin()]);
        assert!(pred.is_finite(), "prediction should be finite, got {pred}");
    }

    // The mixture-level sample counter increments once per train call.
    #[test]
    fn n_samples_tracks_correctly() {
        let mut moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert(linear(0.02))
            .build();

        for i in 0..42 {
            moe.train(&[i as f64], i as f64 * 2.0);
        }
        assert_eq!(moe.n_samples_seen(), 42);
    }

    // reset() zeroes both the mixture counter and per-expert counts.
    #[test]
    fn reset_clears_state() {
        let mut moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert(linear(0.02))
            .build();

        for i in 0..50 {
            moe.train(&[i as f64], i as f64);
        }
        assert!(moe.n_samples_seen() > 0);

        moe.reset();
        assert_eq!(moe.n_samples_seen(), 0);
        for s in moe.expert_samples() {
            assert_eq!(s, 0, "expert samples should be 0 after reset");
        }
    }

    // NeuralMoE is usable behind Box<dyn StreamingLearner>.
    #[test]
    fn implements_streaming_learner() {
        let moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert(linear(0.02))
            .build();

        let mut boxed: Box<dyn StreamingLearner> = Box::new(moe);
        boxed.train(&[1.0], 2.0);
        let pred = boxed.predict(&[1.0]);
        assert!(pred.is_finite(), "trait object prediction should be finite");
    }

    // expert_predictions() covers every expert, not just the top-k set.
    #[test]
    fn expert_predictions_returns_all() {
        let moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert(linear(0.02))
            .expert(linear(0.05))
            .top_k(2)
            .build();

        let preds = moe.expert_predictions(&[1.0]);
        assert_eq!(preds.len(), 3, "should have predictions from all 3 experts");
    }

    // The router's probability vector is a valid distribution.
    #[test]
    fn routing_probabilities_sum_to_one() {
        let moe = NeuralMoE::builder()
            .expert(sgbt(10, 0.01))
            .expert(sgbt(20, 0.01))
            .expert(linear(0.01))
            .build();

        let probs = moe.routing_probabilities(&[1.0, 2.0]);
        let sum: f64 = probs.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "routing probabilities should sum to 1.0, got {sum}"
        );
    }

    // Utilization EWMAs start from zero before any training.
    #[test]
    fn utilization_starts_at_zero() {
        let moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert(linear(0.02))
            .build();

        for u in moe.utilization() {
            assert!((u - 0.0).abs() < 1e-12, "initial utilization should be 0.0");
        }
    }

    // expert_with_warmup stores the hint; plain expert() defaults to 0.
    #[test]
    fn warmup_hint_stored() {
        let moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert_with_warmup(linear(0.02), 50)
            .build();

        assert_eq!(moe.experts[0].warmup_hint, 0, "first expert has no warmup");
        assert_eq!(
            moe.experts[1].warmup_hint, 50,
            "second expert has warmup 50"
        );
    }

    // Mixing boosted-tree, linear, and RLS experts trains without blow-up.
    #[test]
    fn heterogeneous_experts_work() {
        let mut moe = NeuralMoE::builder()
            .expert(sgbt(10, 0.01))
            .expert(linear(0.01))
            .expert(rls(0.99))
            .top_k(2)
            .build();

        for i in 0..200 {
            let x = [i as f64 * 0.01, (i as f64 * 0.1).sin()];
            let y = x[0] * 3.0 + x[1] * 2.0 + 1.0;
            moe.train(&x, y);
        }

        let pred = moe.predict(&[1.0, 1.0_f64.sin()]);
        assert!(
            pred.is_finite(),
            "heterogeneous MoE prediction should be finite, got {pred}"
        );
    }

    // With top_k=1 each sample trains exactly one expert, so per-expert
    // training counts sum to the number of samples.
    #[test]
    fn top_k_limits_active_experts() {
        let mut moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert(linear(0.02))
            .expert(linear(0.03))
            .expert(linear(0.04))
            .top_k(1) .build();

        for i in 0..100 {
            moe.train(&[i as f64], i as f64 * 2.0);
        }

        let samples = moe.expert_samples();
        let total_expert_trains: u64 = samples.iter().sum();
        assert_eq!(
            total_expert_trains, 100,
            "with top_k=1, total expert trains should equal n_samples"
        );
    }

    // load_distribution() has one entry per expert.
    #[test]
    fn load_distribution_available() {
        let moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert(linear(0.02))
            .build();

        let load = moe.load_distribution();
        assert_eq!(load.len(), 2, "load distribution should have 2 entries");
    }

    // Every builder setter lands in the stored config.
    #[test]
    fn custom_config() {
        let moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert(linear(0.02))
            .top_k(1)
            .router_lr(0.05)
            .load_balance_rate(0.02)
            .utilization_span(200)
            .utilization_threshold(0.05)
            .reset_dead(false)
            .seed(999)
            .build();

        assert_eq!(moe.config.top_k, 1);
        assert!((moe.config.router_lr - 0.05).abs() < 1e-12);
        assert!((moe.config.load_balance_rate - 0.02).abs() < 1e-12);
        assert_eq!(moe.config.utilization_span, 200);
        assert!((moe.config.utilization_threshold - 0.05).abs() < 1e-12);
        assert!(!moe.config.reset_dead);
        assert_eq!(moe.config.seed, 999);
    }

    // cached_disagreement starts at 0, stays finite and non-negative after
    // training, and the direct expert_disagreement() query is finite too.
    #[test]
    fn moe_expert_disagreement() {
        let mut moe = NeuralMoE::builder()
            .expert(sgbt(10, 0.01))
            .expert(sgbt(20, 0.01))
            .expert(linear(0.01))
            .top_k(2)
            .build();

        assert!(
            moe.cached_disagreement().abs() < 1e-15,
            "cached_disagreement should be 0 before training, got {}",
            moe.cached_disagreement()
        );

        for i in 0..100 {
            let x = [i as f64 * 0.01, (i as f64).sin()];
            let y = x[0] * 2.0 + 1.0;
            moe.train(&x, y);
        }

        let disagree = moe.cached_disagreement();
        assert!(
            disagree >= 0.0,
            "expert_disagreement should be >= 0, got {}",
            disagree
        );
        assert!(
            disagree.is_finite(),
            "expert_disagreement should be finite, got {}",
            disagree
        );

        let direct = moe.expert_disagreement(&[0.5, 0.5_f64.sin()]);
        assert!(
            direct.is_finite(),
            "expert_disagreement() should be finite, got {}",
            direct
        );
    }
}