1use alloc::vec;
64use alloc::vec::Vec;
65
66use core::fmt;
67
68use crate::ensemble::config::SGBTConfig;
69use crate::ensemble::distributional::{DistributionalSGBT, GaussianPrediction};
70use crate::ensemble::moe::{softmax, GatingMode};
71use crate::sample::{Observation, SampleRef};
72
/// Mixture-of-experts ensemble of [`DistributionalSGBT`] models with a linear
/// softmax gate, plus one "shadow" model per expert that is promoted to
/// active duty when it provably outperforms the expert (see `train_one`).
pub struct MoEDistributionalSGBT {
    // Active experts; their Gaussian predictions are mixed by the gate.
    experts: Vec<DistributionalSGBT>,
    // One shadow per expert, trained on the same stream and compared against
    // its active counterpart for drift-triggered replacement.
    shadows: Vec<DistributionalSGBT>,
    // Gate weight matrix: one row (length = n_features) per expert.
    // Empty until the first training sample fixes the feature dimension.
    gate_weights: Vec<Vec<f64>>,
    // Per-expert gate bias terms.
    gate_bias: Vec<f64>,
    // Step size for the online gate gradient update.
    gate_lr: f64,
    // Feature dimension, set lazily by the first call to train_one.
    n_features: Option<usize>,
    // Soft (probability-weighted) vs. hard (top-k) expert routing.
    gating_mode: GatingMode,
    // Base configuration shared by all experts (first config when per-expert
    // configs are supplied).
    config: SGBTConfig,
    // Per-expert configurations when built via `with_expert_configs`.
    expert_configs: Option<Vec<SGBTConfig>>,
    // Total number of training samples observed.
    samples_seen: u64,
    // Coefficient of the entropy regularizer in the gate update (0 = off).
    entropy_weight: f64,

    // Running sum of (active NLL - shadow NLL) per expert; positive means
    // the shadow has been fitting better.
    cumulative_advantage: Vec<f64>,
    // Number of active-vs-shadow comparison observations per expert.
    shadow_n: Vec<u64>,
    // Largest |NLL difference| seen per expert (range proxy for the bound).
    max_nll_diff: Vec<f64>,
    // Confidence parameter delta of the Hoeffding-style replacement test.
    delta: f64,
    // Minimum samples a shadow must see before comparisons begin.
    shadow_min_samples: u64,
    // How many times each expert slot has been replaced by its shadow.
    shadow_replacements: Vec<u64>,
}
129
130impl Clone for MoEDistributionalSGBT {
135 fn clone(&self) -> Self {
136 Self {
137 experts: self.experts.clone(),
138 shadows: self.shadows.clone(),
139 gate_weights: self.gate_weights.clone(),
140 gate_bias: self.gate_bias.clone(),
141 gate_lr: self.gate_lr,
142 n_features: self.n_features,
143 gating_mode: self.gating_mode.clone(),
144 config: self.config.clone(),
145 expert_configs: self.expert_configs.clone(),
146 samples_seen: self.samples_seen,
147 entropy_weight: self.entropy_weight,
148 cumulative_advantage: self.cumulative_advantage.clone(),
149 shadow_n: self.shadow_n.clone(),
150 max_nll_diff: self.max_nll_diff.clone(),
151 delta: self.delta,
152 shadow_min_samples: self.shadow_min_samples,
153 shadow_replacements: self.shadow_replacements.clone(),
154 }
155 }
156}
157
158impl fmt::Debug for MoEDistributionalSGBT {
163 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164 f.debug_struct("MoEDistributionalSGBT")
165 .field("n_experts", &self.experts.len())
166 .field("samples_seen", &self.samples_seen)
167 .field("shadow_replacements", &self.shadow_replacements)
168 .finish()
169 }
170}
171
172impl MoEDistributionalSGBT {
177 pub fn new(config: SGBTConfig, n_experts: usize) -> Self {
186 Self::with_shadow_config(config, n_experts, GatingMode::Soft, 0.01, 1e-3, 500)
187 }
188
189 pub fn with_gating(
196 config: SGBTConfig,
197 n_experts: usize,
198 gating_mode: GatingMode,
199 gate_lr: f64,
200 ) -> Self {
201 Self::with_shadow_config(config, n_experts, gating_mode, gate_lr, 1e-3, 500)
202 }
203
204 pub fn with_shadow_config(
220 config: SGBTConfig,
221 n_experts: usize,
222 gating_mode: GatingMode,
223 gate_lr: f64,
224 delta: f64,
225 shadow_min_samples: u64,
226 ) -> Self {
227 assert!(
228 n_experts >= 1,
229 "MoEDistributionalSGBT requires at least 1 expert"
230 );
231
232 let experts: Vec<DistributionalSGBT> = (0..n_experts)
233 .map(|i| {
234 let mut cfg = config.clone();
235 cfg.seed = config.seed ^ (0x0E00_0000 | i as u64);
236 DistributionalSGBT::new(cfg)
237 })
238 .collect();
239
240 let shadows: Vec<DistributionalSGBT> = (0..n_experts)
241 .map(|i| {
242 let mut cfg = config.clone();
243 cfg.seed = config.seed ^ (0x5A00_0000 | i as u64);
244 DistributionalSGBT::new(cfg)
245 })
246 .collect();
247
248 let gate_bias = vec![0.0; n_experts];
249
250 Self {
251 experts,
252 shadows,
253 gate_weights: Vec::new(), gate_bias,
255 gate_lr,
256 n_features: None,
257 gating_mode,
258 config,
259 expert_configs: None,
260 samples_seen: 0,
261 entropy_weight: 0.0,
262 cumulative_advantage: vec![0.0; n_experts],
263 shadow_n: vec![0; n_experts],
264 max_nll_diff: vec![0.0; n_experts],
265 delta,
266 shadow_min_samples,
267 shadow_replacements: vec![0; n_experts],
268 }
269 }
270
271 pub fn with_expert_configs(
291 configs: Vec<SGBTConfig>,
292 gating_mode: GatingMode,
293 gate_lr: f64,
294 entropy_weight: f64,
295 delta: f64,
296 shadow_min_samples: u64,
297 ) -> Self {
298 assert!(
299 !configs.is_empty(),
300 "MoEDistributionalSGBT requires at least 1 expert config"
301 );
302
303 let n_experts = configs.len();
304
305 let experts: Vec<DistributionalSGBT> = configs
306 .iter()
307 .enumerate()
308 .map(|(i, cfg)| {
309 let mut c = cfg.clone();
310 c.seed = cfg.seed ^ (0x0E00_0000 | i as u64);
311 DistributionalSGBT::new(c)
312 })
313 .collect();
314
315 let shadows: Vec<DistributionalSGBT> = configs
316 .iter()
317 .enumerate()
318 .map(|(i, cfg)| {
319 let mut c = cfg.clone();
320 c.seed = cfg.seed ^ (0x5A00_0000 | i as u64);
321 DistributionalSGBT::new(c)
322 })
323 .collect();
324
325 let gate_bias = vec![0.0; n_experts];
326 let shared_config = configs[0].clone();
327
328 Self {
329 experts,
330 shadows,
331 gate_weights: Vec::new(),
332 gate_bias,
333 gate_lr,
334 n_features: None,
335 gating_mode,
336 config: shared_config,
337 expert_configs: Some(configs),
338 samples_seen: 0,
339 entropy_weight,
340 cumulative_advantage: vec![0.0; n_experts],
341 shadow_n: vec![0; n_experts],
342 max_nll_diff: vec![0.0; n_experts],
343 delta,
344 shadow_min_samples,
345 shadow_replacements: vec![0; n_experts],
346 }
347 }
348}
349
350impl MoEDistributionalSGBT {
355 fn ensure_gate_init(&mut self, d: usize) {
361 if self.n_features.is_none() {
362 let k = self.experts.len();
363 self.gate_weights = vec![vec![0.0; d]; k];
364 self.n_features = Some(d);
365 }
366 }
367
368 fn gate_logits(&self, features: &[f64]) -> Vec<f64> {
370 let k = self.experts.len();
371 let mut logits = Vec::with_capacity(k);
372 for i in 0..k {
373 let dot: f64 = self.gate_weights[i]
374 .iter()
375 .zip(features.iter())
376 .map(|(&w, &x)| w * x)
377 .sum();
378 logits.push(dot + self.gate_bias[i]);
379 }
380 logits
381 }
382
383 #[inline]
386 fn gaussian_nll(pred: &GaussianPrediction, target: f64) -> f64 {
387 let z = (target - pred.mu) / pred.sigma.max(1e-16);
388 pred.log_sigma + 0.5 * z * z
389 }
390
391 pub fn gating_probabilities(&self, features: &[f64]) -> Vec<f64> {
400 let k = self.experts.len();
401 if self.n_features.is_none() {
402 return vec![1.0 / k as f64; k];
403 }
404 let logits = self.gate_logits(features);
405 softmax(&logits)
406 }
407
408 pub fn train_one(&mut self, sample: &impl Observation) {
421 let features = sample.features();
422 let target = sample.target();
423 let d = features.len();
424
425 self.ensure_gate_init(d);
427
428 let logits = self.gate_logits(features);
430 let probs = softmax(&logits);
431 let k = self.experts.len();
432
433 match &self.gating_mode {
435 GatingMode::Soft => {
436 for (i, &prob) in probs.iter().enumerate() {
437 let weighted = SampleRef::weighted(features, target, prob);
438 self.experts[i].train_one(&weighted);
439 self.shadows[i].train_one(&weighted);
440 }
441 }
442 GatingMode::Hard { top_k } => {
443 let top_k = (*top_k).min(k);
444 let mut indices: Vec<usize> = (0..k).collect();
445 indices.sort_unstable_by(|&a, &b| {
446 probs[b]
447 .partial_cmp(&probs[a])
448 .unwrap_or(core::cmp::Ordering::Equal)
449 });
450 for &i in indices.iter().take(top_k) {
451 let obs = SampleRef::new(features, target);
452 self.experts[i].train_one(&obs);
453 self.shadows[i].train_one(&obs);
454 }
455 }
456 }
457
458 for i in 0..k {
460 if !self.experts[i].is_initialized() || !self.shadows[i].is_initialized() {
462 continue;
463 }
464 if self.shadows[i].n_samples_seen() < self.shadow_min_samples {
465 continue;
466 }
467
468 let pred_active = self.experts[i].predict(features);
469 let pred_shadow = self.shadows[i].predict(features);
470
471 let nll_active = Self::gaussian_nll(&pred_active, target);
472 let nll_shadow = Self::gaussian_nll(&pred_shadow, target);
473
474 let diff = nll_active - nll_shadow;
476 self.cumulative_advantage[i] += diff;
477 self.shadow_n[i] += 1;
478
479 let abs_diff = crate::math::abs(diff);
480 if abs_diff > self.max_nll_diff[i] {
481 self.max_nll_diff[i] = abs_diff;
482 }
483
484 if self.shadow_n[i] >= 10 && self.max_nll_diff[i] > 0.0 {
486 let mean_advantage = self.cumulative_advantage[i] / self.shadow_n[i] as f64;
487 if mean_advantage > 0.0 {
488 let r_squared = self.max_nll_diff[i] * self.max_nll_diff[i];
489 let ln_inv_delta = crate::math::ln(1.0 / self.delta);
490 let epsilon = crate::math::sqrt(
491 r_squared * ln_inv_delta / (2.0 * self.shadow_n[i] as f64),
492 );
493
494 if mean_advantage > epsilon {
495 self.experts[i] = self.shadows[i].clone();
497 let base_cfg = self
498 .expert_configs
499 .as_ref()
500 .map(|c| &c[i])
501 .unwrap_or(&self.config);
502 let mut fresh_cfg = base_cfg.clone();
503 fresh_cfg.seed = base_cfg.seed
504 ^ (0x5A00_0000 | i as u64)
505 ^ (self.shadow_replacements[i].wrapping_add(1) * 0x9E37_79B9);
506 self.shadows[i] = DistributionalSGBT::new(fresh_cfg);
507
508 self.cumulative_advantage[i] = 0.0;
510 self.shadow_n[i] = 0;
511 self.max_nll_diff[i] = 0.0;
512 self.shadow_replacements[i] += 1;
513 }
514 }
515 }
516 }
517
518 let mut best_idx = 0;
521 let mut best_nll = f64::INFINITY;
522 for (i, expert) in self.experts.iter().enumerate() {
523 let pred = expert.predict(features);
524 let nll = Self::gaussian_nll(&pred, target);
525 if nll < best_nll {
526 best_nll = nll;
527 best_idx = i;
528 }
529 }
530
531 let entropy_mean_log_term: f64 = if self.entropy_weight != 0.0 {
535 probs
536 .iter()
537 .map(|&p| {
538 let lp = if p > 1e-10 { crate::math::ln(p) } else { -23.0 };
539 p * (lp + 1.0)
540 })
541 .sum()
542 } else {
543 0.0
544 };
545
546 for (i, (weights_row, bias)) in self
547 .gate_weights
548 .iter_mut()
549 .zip(self.gate_bias.iter_mut())
550 .enumerate()
551 {
552 let indicator = if i == best_idx { 1.0 } else { 0.0 };
553 let ce_grad = probs[i] - indicator;
554
555 let total_grad = if self.entropy_weight != 0.0 {
556 let log_p = if probs[i] > 1e-10 {
557 crate::math::ln(probs[i])
558 } else {
559 -23.0
560 };
561 let entropy_grad = probs[i] * (log_p + 1.0) - entropy_mean_log_term;
562 ce_grad + self.entropy_weight * entropy_grad
563 } else {
564 ce_grad
565 };
566
567 let lr = self.gate_lr;
568 for (j, &xj) in features.iter().enumerate() {
569 weights_row[j] -= lr * total_grad * xj;
570 }
571 *bias -= lr * total_grad;
572 }
573
574 self.samples_seen += 1;
575 }
576
577 pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
579 for sample in samples {
580 self.train_one(sample);
581 }
582 }
583
584 pub fn predict(&self, features: &[f64]) -> GaussianPrediction {
598 let probs = self.gating_probabilities(features);
599 let preds: Vec<GaussianPrediction> =
600 self.experts.iter().map(|e| e.predict(features)).collect();
601
602 let mu_mix: f64 = probs
604 .iter()
605 .zip(preds.iter())
606 .map(|(&p, pred)| p * pred.mu)
607 .sum();
608
609 let second_moment: f64 = probs
612 .iter()
613 .zip(preds.iter())
614 .map(|(&p, pred)| p * (pred.sigma * pred.sigma + pred.mu * pred.mu))
615 .sum();
616 let var_mix = (second_moment - mu_mix * mu_mix).max(1e-16);
617 let sigma_mix = crate::math::sqrt(var_mix);
618
619 GaussianPrediction {
620 mu: mu_mix,
621 sigma: sigma_mix,
622 log_sigma: crate::math::ln(sigma_mix),
623 }
624 }
625
626 pub fn predict_with_gating(&self, features: &[f64]) -> (GaussianPrediction, Vec<f64>) {
631 let probs = self.gating_probabilities(features);
632 let preds: Vec<GaussianPrediction> =
633 self.experts.iter().map(|e| e.predict(features)).collect();
634
635 let mu_mix: f64 = probs
636 .iter()
637 .zip(preds.iter())
638 .map(|(&p, pred)| p * pred.mu)
639 .sum();
640
641 let second_moment: f64 = probs
642 .iter()
643 .zip(preds.iter())
644 .map(|(&p, pred)| p * (pred.sigma * pred.sigma + pred.mu * pred.mu))
645 .sum();
646 let var_mix = (second_moment - mu_mix * mu_mix).max(1e-16);
647 let sigma_mix = crate::math::sqrt(var_mix);
648
649 let pred = GaussianPrediction {
650 mu: mu_mix,
651 sigma: sigma_mix,
652 log_sigma: crate::math::ln(sigma_mix),
653 };
654 (pred, probs)
655 }
656
657 pub fn expert_predictions(&self, features: &[f64]) -> Vec<GaussianPrediction> {
661 self.experts.iter().map(|e| e.predict(features)).collect()
662 }
663
664 #[inline]
666 pub fn predict_mu(&self, features: &[f64]) -> f64 {
667 self.predict(features).mu
668 }
669
670 #[inline]
676 pub fn n_experts(&self) -> usize {
677 self.experts.len()
678 }
679
680 #[inline]
682 pub fn n_samples_seen(&self) -> u64 {
683 self.samples_seen
684 }
685
686 pub fn experts(&self) -> &[DistributionalSGBT] {
688 &self.experts
689 }
690
691 pub fn expert(&self, idx: usize) -> &DistributionalSGBT {
697 &self.experts[idx]
698 }
699
700 pub fn shadow_replacements(&self) -> &[u64] {
702 &self.shadow_replacements
703 }
704
705 #[inline]
707 pub fn entropy_weight(&self) -> f64 {
708 self.entropy_weight
709 }
710
711 pub fn expert_configs(&self) -> Option<&[SGBTConfig]> {
713 self.expert_configs.as_deref()
714 }
715
716 pub fn reset(&mut self) {
721 let k = self.experts.len();
722 for expert in &mut self.experts {
723 expert.reset();
724 }
725 for shadow in &mut self.shadows {
726 shadow.reset();
727 }
728 self.gate_weights.clear();
729 self.gate_bias = vec![0.0; k];
730 self.n_features = None;
731 self.samples_seen = 0;
732 self.cumulative_advantage = vec![0.0; k];
733 self.shadow_n = vec![0; k];
734 self.max_nll_diff = vec![0.0; k];
735 self.shadow_replacements = vec![0; k];
736 }
737}
738
739use crate::learner::StreamingLearner;
744
745impl StreamingLearner for MoEDistributionalSGBT {
746 fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
747 let sample = SampleRef::weighted(features, target, weight);
748 MoEDistributionalSGBT::train_one(self, &sample);
750 }
751
752 fn predict(&self, features: &[f64]) -> f64 {
754 MoEDistributionalSGBT::predict(self, features).mu
755 }
756
757 fn n_samples_seen(&self) -> u64 {
758 self.samples_seen
759 }
760
761 fn reset(&mut self) {
762 MoEDistributionalSGBT::reset(self);
763 }
764}
765
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample::Sample;
    use alloc::boxed::Box;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Small, fast configuration shared by every test in this module.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .build()
            .unwrap()
    }

    // A fresh mixture reports the requested expert count, zero samples, and
    // zeroed shadow-replacement counters.
    #[test]
    fn test_creation() {
        let moe = MoEDistributionalSGBT::new(test_config(), 3);
        assert_eq!(moe.n_experts(), 3);
        assert_eq!(moe.n_samples_seen(), 0);
        assert_eq!(moe.shadow_replacements().len(), 3);
        for &r in moe.shadow_replacements() {
            assert_eq!(r, 0);
        }
    }

    // The gate distribution must sum to 1 both before (uniform fallback)
    // and after training (softmax over the learned logits).
    #[test]
    fn test_gating_probabilities_sum_to_one() {
        let mut moe = MoEDistributionalSGBT::new(test_config(), 5);

        let probs = moe.gating_probabilities(&[1.0, 2.0]);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "pre-training sum = {}", sum);

        for i in 0..20 {
            let sample = Sample::new(vec![i as f64, (i * 2) as f64], i as f64);
            moe.train_one(&sample);
        }
        let probs = moe.gating_probabilities(&[5.0, 10.0]);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "post-training sum = {}", sum);
    }

    // After some training the mixture prediction is a well-formed Gaussian:
    // finite mean, strictly positive sigma, finite log-sigma.
    #[test]
    fn test_prediction_is_valid_gaussian() {
        let mut moe = MoEDistributionalSGBT::new(test_config(), 3);

        for i in 0..50 {
            let sample = Sample::new(vec![i as f64, (i as f64) * 0.5], i as f64 * 2.0);
            moe.train_one(&sample);
        }

        let pred = moe.predict(&[10.0, 5.0]);
        assert!(pred.mu.is_finite(), "mu should be finite: {}", pred.mu);
        assert!(pred.sigma > 0.0, "sigma should be > 0: {}", pred.sigma);
        assert!(
            pred.log_sigma.is_finite(),
            "log_sigma should be finite: {}",
            pred.log_sigma
        );
    }

    // Training on a consistent target must move the predicted mean.
    #[test]
    fn test_prediction_changes_after_training() {
        let mut moe = MoEDistributionalSGBT::new(test_config(), 3);
        let features = vec![1.0, 2.0, 3.0];

        let pred_before = moe.predict(&features);

        for i in 0..100 {
            let sample = Sample::new(features.clone(), 10.0 + i as f64 * 0.1);
            moe.train_one(&sample);
        }

        let pred_after = moe.predict(&features);
        assert!(
            (pred_after.mu - pred_before.mu).abs() > 1e-6,
            "mu should change after training: before={}, after={}",
            pred_before.mu,
            pred_after.mu
        );
    }

    // predict() must agree with the moment-matched mixture computed by hand
    // from gating_probabilities() and expert_predictions().
    #[test]
    fn test_mixture_variance() {
        let mut moe = MoEDistributionalSGBT::new(test_config(), 2);

        for i in 0..80 {
            let sample = Sample::new(vec![i as f64], i as f64 * 3.0);
            moe.train_one(&sample);
        }

        let features = &[40.0];
        let probs = moe.gating_probabilities(features);
        let expert_preds = moe.expert_predictions(features);

        // Manual moment matching: mu = sum p*mu, var = E[X^2] - mu^2.
        let mu_mix: f64 = probs
            .iter()
            .zip(expert_preds.iter())
            .map(|(&p, pred)| p * pred.mu)
            .sum();
        let second_moment: f64 = probs
            .iter()
            .zip(expert_preds.iter())
            .map(|(&p, pred)| p * (pred.sigma * pred.sigma + pred.mu * pred.mu))
            .sum();
        let var_mix = (second_moment - mu_mix * mu_mix).max(1e-16);
        let sigma_mix = var_mix.sqrt();

        let pred = moe.predict(features);
        assert!(
            (pred.mu - mu_mix).abs() < 1e-10,
            "mu mismatch: pred={}, manual={}",
            pred.mu,
            mu_mix
        );
        assert!(
            (pred.sigma - sigma_mix).abs() < 1e-10,
            "sigma mismatch: pred={}, manual={}",
            pred.sigma,
            sigma_mix
        );
    }

    // expert_predictions returns exactly one prediction per expert.
    #[test]
    fn test_expert_predictions_count() {
        let moe = MoEDistributionalSGBT::new(test_config(), 4);
        let preds = moe.expert_predictions(&[1.0, 2.0]);
        assert_eq!(preds.len(), 4, "should return one prediction per expert");
    }

    // predict_with_gating returns probabilities consistent with the mixture
    // mean it reports.
    #[test]
    fn test_predict_with_gating_consistency() {
        let mut moe = MoEDistributionalSGBT::new(test_config(), 3);

        for i in 0..50 {
            let sample = Sample::new(vec![i as f64, (i as f64) * 0.5], i as f64);
            moe.train_one(&sample);
        }

        let features = &[10.0, 5.0];
        let (pred, probs) = moe.predict_with_gating(features);
        let expert_preds = moe.expert_predictions(features);

        assert_eq!(probs.len(), 3);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);

        let expected_mu: f64 = probs
            .iter()
            .zip(expert_preds.iter())
            .map(|(&p, ep)| p * ep.mu)
            .sum();
        assert!(
            (pred.mu - expected_mu).abs() < 1e-10,
            "mu mismatch: pred={}, expected={}",
            pred.mu,
            expected_mu
        );
    }

    // The sample counter increments once per train_one call.
    #[test]
    fn test_n_samples_seen_increments() {
        let mut moe = MoEDistributionalSGBT::new(test_config(), 2);
        assert_eq!(moe.n_samples_seen(), 0);

        for i in 0..25 {
            moe.train_one(&Sample::new(vec![i as f64], i as f64));
        }
        assert_eq!(moe.n_samples_seen(), 25);
    }

    // reset() zeroes the counters, restores the uniform gate, and keeps the
    // expert count.
    #[test]
    fn test_reset_clears_state() {
        let mut moe = MoEDistributionalSGBT::new(test_config(), 3);

        for i in 0..50 {
            moe.train_one(&Sample::new(vec![i as f64, (i * 2) as f64], i as f64));
        }
        assert_eq!(moe.n_samples_seen(), 50);

        moe.reset();

        assert_eq!(moe.n_samples_seen(), 0);
        assert_eq!(moe.n_experts(), 3);
        let probs = moe.gating_probabilities(&[1.0, 2.0]);
        assert_eq!(probs.len(), 3);
        for &p in &probs {
            assert!(
                (p - 1.0 / 3.0).abs() < 1e-10,
                "expected uniform after reset, got {}",
                p
            );
        }
        for &r in moe.shadow_replacements() {
            assert_eq!(r, 0);
        }
    }

    // The mixture works behind a boxed StreamingLearner trait object.
    #[test]
    fn test_streaming_learner_trait() {
        let config = test_config();
        let model = MoEDistributionalSGBT::new(config, 3);
        let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
        for i in 0..100 {
            let x = i as f64 * 0.1;
            boxed.train(&[x], x * 2.0);
        }
        assert_eq!(boxed.n_samples_seen(), 100);
        let pred = boxed.predict(&[5.0]);
        assert!(pred.is_finite());
        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }

    // Hard top-k routing still trains and predicts sensibly.
    #[test]
    fn test_hard_gating_mode() {
        let mut moe = MoEDistributionalSGBT::with_gating(
            test_config(),
            4,
            GatingMode::Hard { top_k: 2 },
            0.01,
        );

        for i in 0..30 {
            let sample = Sample::new(vec![i as f64], i as f64);
            moe.train_one(&sample);
        }

        assert_eq!(moe.n_samples_seen(), 30);
        let pred = moe.predict(&[15.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma > 0.0);
    }

    // The predict_mu shortcut must match predict().mu exactly.
    #[test]
    fn test_predict_mu_matches_predict() {
        let mut moe = MoEDistributionalSGBT::new(test_config(), 3);

        for i in 0..50 {
            moe.train_one(&Sample::new(vec![i as f64], i as f64 * 2.0));
        }

        let features = &[25.0];
        let mu_direct = moe.predict_mu(features);
        let mu_from_predict = moe.predict(features).mu;
        assert!(
            (mu_direct - mu_from_predict).abs() < 1e-12,
            "predict_mu={} vs predict().mu={}",
            mu_direct,
            mu_from_predict
        );
    }

    // train_batch processes every sample in the slice.
    #[test]
    fn test_batch_training() {
        let mut moe = MoEDistributionalSGBT::new(test_config(), 3);

        let samples: Vec<Sample> = (0..20)
            .map(|i| Sample::new(vec![i as f64, (i * 3) as f64], i as f64))
            .collect();

        moe.train_batch(&samples);

        assert_eq!(moe.n_samples_seen(), 20);
        let pred = moe.predict(&[10.0, 30.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma > 0.0);
    }

    // Heterogeneous configs: each expert keeps its own max_depth, and the
    // mixture trains/predicts normally.
    #[test]
    fn moe_with_expert_configs_different_depths() {
        let configs: Vec<SGBTConfig> = (0..3)
            .map(|i| {
                SGBTConfig::builder()
                    .n_steps(5)
                    .learning_rate(0.1)
                    .grace_period(5)
                    .max_depth(2 + i)
                    .build()
                    .unwrap()
            })
            .collect();

        let mut moe = MoEDistributionalSGBT::with_expert_configs(
            configs.clone(),
            GatingMode::Soft,
            0.01,
            0.0, // entropy_weight
            1e-3, // delta
            500,
        );

        assert_eq!(moe.n_experts(), 3);
        assert!(moe.expert_configs().is_some());
        assert_eq!(moe.expert_configs().unwrap().len(), 3);

        for (i, cfg) in configs.iter().enumerate() {
            assert_eq!(moe.expert(i).config().max_depth, cfg.max_depth);
        }

        for i in 0..50 {
            let sample = Sample::new(vec![i as f64, (i * 2) as f64], i as f64 * 3.0);
            moe.train_one(&sample);
        }
        let pred = moe.predict(&[10.0, 20.0]);
        assert!(pred.mu.is_finite());
        assert!(pred.sigma > 0.0);
    }

    // With a non-zero entropy weight, no expert's gate probability should
    // collapse to (near) zero on a repetitive stream.
    #[test]
    fn entropy_regularization_prevents_collapse() {
        let config = test_config();
        let mut moe = MoEDistributionalSGBT::with_expert_configs(
            vec![config.clone(), config.clone(), config],
            GatingMode::Soft,
            0.01,
            0.1, // entropy_weight
            1e-3, // delta
            500,
        );

        for i in 0..200 {
            let x = (i % 10) as f64;
            let sample = Sample::new(vec![x, x * 2.0], x * 3.0);
            moe.train_one(&sample);
        }

        let probs = moe.gating_probabilities(&[5.0, 10.0]);
        for (i, &p) in probs.iter().enumerate() {
            assert!(
                p > 0.02,
                "Expert {} has probability {} -- gate collapsed despite entropy regularization",
                i,
                p
            );
        }
    }

    // Per-expert configs are threaded through construction: the stored
    // config list and each live expert report the right max_depth.
    #[test]
    fn moe_expert_configs_shadow_respawn_correct() {
        let configs: Vec<SGBTConfig> = (0..2)
            .map(|i| {
                SGBTConfig::builder()
                    .n_steps(3)
                    .learning_rate(0.1)
                    .grace_period(5)
                    .max_depth(3 + i)
                    .build()
                    .unwrap()
            })
            .collect();

        let moe = MoEDistributionalSGBT::with_expert_configs(
            configs.clone(),
            GatingMode::Soft,
            0.01,
            0.0,
            1e-3,
            500,
        );

        let ec = moe.expert_configs().unwrap();
        assert_eq!(ec[0].max_depth, 3);
        assert_eq!(ec[1].max_depth, 4);

        assert_eq!(moe.expert(0).config().max_depth, 3);
        assert_eq!(moe.expert(1).config().max_depth, 4);
    }
}