// irithyll/moe/mod.rs
//! Streaming Neural Mixture of Experts.
//!
//! Polymorphic MoE where each expert is any [`StreamingLearner`] — mix ESN,
//! Mamba, SpikeNet, SGBT, and attention models in one ensemble. A linear
//! softmax router learns online which experts work best for which inputs.
//!
//! # Architecture
//!
//! - **Experts**: `Box<dyn StreamingLearner>` — any model type
//! - **Router**: Linear softmax gate, trained via SGD on cross-entropy
//! - **Top-k routing**: Only k experts are activated per sample (sparse)
//! - **Load balancing**: Per-expert bias prevents routing collapse (DeepSeek-v3)
//! - **Warmup protection**: Neural experts with cold-start phases are given time
//! - **Dead expert reset**: Experts with near-zero utilization are automatically reset
//!
//! # References
//!
//! - Jacobs et al. (1991) "Adaptive Mixtures of Local Experts" — original MoE
//! - Shazeer et al. (2017) "Outrageously Large Neural Networks" — sparse top-k gating
//! - Wang et al. (2024) "Auxiliary-Loss-Free Load Balancing" — bias-based load balance
//! - Aspis et al. (2025) "DriftMoE" — streaming MoE with neural router

mod router;

use crate::learner::StreamingLearner;
use router::LinearRouter;

// ---------------------------------------------------------------------------
// ExpertSlot (private)
// ---------------------------------------------------------------------------

/// Internal bookkeeping wrapper around one expert model.
struct ExpertSlot {
    /// The expert itself — any boxed [`StreamingLearner`].
    model: Box<dyn StreamingLearner>,
    /// Reserved for future warmup-aware routing (Phase 2).
    #[allow(dead_code)]
    warmup_hint: usize,
    /// EWMA of this expert's routing probability; near-zero marks it "dead".
    utilization_ewma: f64,
    /// Number of samples this expert has actually been trained on (it only
    /// trains when selected by the router's top-k).
    samples_trained: u64,
}

// ---------------------------------------------------------------------------
// NeuralMoEConfig
// ---------------------------------------------------------------------------

/// Configuration for [`NeuralMoE`].
///
/// All fields have sensible defaults via [`Default`]; the builder exposes a
/// setter for each one.
#[derive(Debug, Clone)]
pub struct NeuralMoEConfig {
    /// Number of experts activated per sample (default: 2).
    pub top_k: usize,
    /// Router learning rate (default: 0.01).
    pub router_lr: f64,
    /// Load balance bias adjustment rate (default: 0.01).
    pub load_balance_rate: f64,
    /// EWMA span for utilization tracking (default: 500). Also used as the
    /// minimum sample count before an expert can be declared dead.
    pub utilization_span: usize,
    /// Utilization threshold — experts below this are "dead" (default: 0.01).
    pub utilization_threshold: f64,
    /// Whether to auto-reset dead experts (default: true).
    pub reset_dead: bool,
    /// RNG seed (default: 42).
    pub seed: u64,
}

64impl Default for NeuralMoEConfig {
65    fn default() -> Self {
66        Self {
67            top_k: 2,
68            router_lr: 0.01,
69            load_balance_rate: 0.01,
70            utilization_span: 500,
71            utilization_threshold: 0.01,
72            reset_dead: true,
73            seed: 42,
74        }
75    }
76}
77
// ---------------------------------------------------------------------------
// NeuralMoE
// ---------------------------------------------------------------------------

/// Streaming Neural Mixture of Experts.
///
/// Polymorphic MoE where each expert can be any `StreamingLearner`.
/// Implements `StreamingLearner` itself for composability.
///
/// # Example
///
/// ```no_run
/// use irithyll::moe::NeuralMoE;
/// use irithyll::{sgbt, esn, StreamingLearner};
///
/// let mut moe = NeuralMoE::builder()
///     .expert(sgbt(50, 0.01))
///     .expert(sgbt(100, 0.005))
///     .expert_with_warmup(esn(50, 0.9), 50)
///     .top_k(2)
///     .build();
///
/// moe.train(&[1.0, 2.0, 3.0], 4.0);
/// let pred = moe.predict(&[1.0, 2.0, 3.0]);
/// ```
pub struct NeuralMoE {
    /// Expert models plus per-expert bookkeeping.
    experts: Vec<ExpertSlot>,
    /// Linear softmax gate that decides which experts see each sample.
    router: LinearRouter,
    config: NeuralMoEConfig,
    /// Total samples seen by the ensemble (not per-expert).
    n_samples: u64,
    /// Cached expert disagreement (std dev of active expert predictions).
    cached_disagreement: f64,
    /// Previous prediction for residual alignment tracking.
    prev_prediction: f64,
    /// Previous prediction change for residual alignment tracking.
    prev_change: f64,
    /// Change from two steps ago, for acceleration-based alignment.
    prev_prev_change: f64,
    /// EWMA of residual alignment signal.
    alignment_ewma: f64,
    /// EWMA of normalized gating entropy for expert utilization.
    gate_entropy_ewma: f64,
}

// ---------------------------------------------------------------------------
// NeuralMoEBuilder
// ---------------------------------------------------------------------------

/// Builder for [`NeuralMoE`].
///
/// Collects experts (with optional warmup hints) and config overrides, then
/// assembles the ensemble in `build()`.
pub struct NeuralMoEBuilder {
    // Each entry is (model, warmup_hint).
    experts: Vec<(Box<dyn StreamingLearner>, usize)>,
    config: NeuralMoEConfig,
}

132impl NeuralMoE {
133    /// Start building a `NeuralMoE` with the builder pattern.
134    pub fn builder() -> NeuralMoEBuilder {
135        NeuralMoEBuilder {
136            experts: Vec::new(),
137            config: NeuralMoEConfig::default(),
138        }
139    }
140}
141
142impl NeuralMoEBuilder {
143    /// Add an expert with no warmup protection.
144    pub fn expert(mut self, model: impl StreamingLearner + 'static) -> Self {
145        self.experts.push((Box::new(model), 0));
146        self
147    }
148
149    /// Add an expert with a warmup hint (protected during cold-start).
150    pub fn expert_with_warmup(
151        mut self,
152        model: impl StreamingLearner + 'static,
153        warmup: usize,
154    ) -> Self {
155        self.experts.push((Box::new(model), warmup));
156        self
157    }
158
159    /// Set the number of experts activated per sample.
160    pub fn top_k(mut self, k: usize) -> Self {
161        self.config.top_k = k;
162        self
163    }
164
165    /// Set the router learning rate.
166    pub fn router_lr(mut self, lr: f64) -> Self {
167        self.config.router_lr = lr;
168        self
169    }
170
171    /// Set the load balance bias adjustment rate.
172    pub fn load_balance_rate(mut self, r: f64) -> Self {
173        self.config.load_balance_rate = r;
174        self
175    }
176
177    /// Set the EWMA span for utilization tracking.
178    pub fn utilization_span(mut self, s: usize) -> Self {
179        self.config.utilization_span = s;
180        self
181    }
182
183    /// Set the utilization threshold below which experts are "dead".
184    pub fn utilization_threshold(mut self, t: f64) -> Self {
185        self.config.utilization_threshold = t;
186        self
187    }
188
189    /// Set whether to auto-reset dead experts.
190    pub fn reset_dead(mut self, b: bool) -> Self {
191        self.config.reset_dead = b;
192        self
193    }
194
195    /// Set the RNG seed.
196    pub fn seed(mut self, s: u64) -> Self {
197        self.config.seed = s;
198        self
199    }
200
201    /// Build the NeuralMoE.
202    ///
203    /// # Panics
204    /// Panics if fewer than 2 experts were added.
205    pub fn build(self) -> NeuralMoE {
206        assert!(
207            self.experts.len() >= 2,
208            "NeuralMoE requires at least 2 experts, got {}",
209            self.experts.len()
210        );
211
212        let k = self.experts.len();
213        let config = self.config;
214
215        let router = LinearRouter::new(
216            k,
217            config.router_lr,
218            config.load_balance_rate,
219            config.utilization_span,
220        );
221
222        let experts: Vec<ExpertSlot> = self
223            .experts
224            .into_iter()
225            .map(|(model, warmup)| ExpertSlot {
226                model,
227                warmup_hint: warmup,
228                utilization_ewma: 0.0,
229                samples_trained: 0,
230            })
231            .collect();
232
233        NeuralMoE {
234            experts,
235            router,
236            config,
237            n_samples: 0,
238            cached_disagreement: 0.0,
239            prev_prediction: 0.0,
240            prev_change: 0.0,
241            prev_prev_change: 0.0,
242            alignment_ewma: 0.0,
243            gate_entropy_ewma: 0.0,
244        }
245    }
246}
247
248// ---------------------------------------------------------------------------
249// Public methods
250// ---------------------------------------------------------------------------
251
252impl NeuralMoE {
253    /// Number of experts.
254    pub fn n_experts(&self) -> usize {
255        self.experts.len()
256    }
257
258    /// Current top-k setting.
259    pub fn top_k(&self) -> usize {
260        self.config.top_k
261    }
262
263    /// Per-expert utilization (EWMA of routing probability).
264    pub fn utilization(&self) -> Vec<f64> {
265        self.experts.iter().map(|e| e.utilization_ewma).collect()
266    }
267
268    /// Per-expert samples trained.
269    pub fn expert_samples(&self) -> Vec<u64> {
270        self.experts.iter().map(|e| e.samples_trained).collect()
271    }
272
273    /// Number of dead experts (utilization below threshold).
274    pub fn n_dead_experts(&self) -> usize {
275        self.experts
276            .iter()
277            .filter(|e| {
278                e.samples_trained > self.config.utilization_span as u64
279                    && e.utilization_ewma < self.config.utilization_threshold
280            })
281            .count()
282    }
283
284    /// Load distribution from the router.
285    pub fn load_distribution(&self) -> &[f64] {
286        self.router.load_distribution()
287    }
288
289    /// Expert disagreement: std dev of all expert predictions.
290    ///
291    /// High disagreement indicates the experts have divergent views on
292    /// this input — a real uncertainty signal. Returns 0.0 when fewer
293    /// than 2 experts exist or predictions are uniform.
294    pub fn expert_disagreement(&self, features: &[f64]) -> f64 {
295        let preds = self.expert_predictions(features);
296        if preds.len() < 2 {
297            return 0.0;
298        }
299        let n = preds.len() as f64;
300        let mean = preds.iter().sum::<f64>() / n;
301        let var = preds.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / (n - 1.0);
302        var.sqrt()
303    }
304
305    /// Cached expert disagreement from the most recent `train_one` call.
306    ///
307    /// This avoids recomputing disagreement in `config_diagnostics()`, which
308    /// only has `&self` (no features available). Updated every `train_one`.
309    #[inline]
310    pub fn cached_disagreement(&self) -> f64 {
311        self.cached_disagreement
312    }
313
314    /// Get predictions from all experts (for inspection).
315    pub fn expert_predictions(&self, features: &[f64]) -> Vec<f64> {
316        self.experts
317            .iter()
318            .map(|e| e.model.predict(features))
319            .collect()
320    }
321
322    /// Get current routing probabilities (for inspection).
323    pub fn routing_probabilities(&self, features: &[f64]) -> Vec<f64> {
324        self.router.probabilities(features)
325    }
326}
327
// ---------------------------------------------------------------------------
// StreamingLearner implementation
// ---------------------------------------------------------------------------

impl StreamingLearner for NeuralMoE {
    /// One online training step: route, train the active experts, then update
    /// the router, load-balance biases, and the diagnostic EWMAs.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        // Never activate more experts than exist.
        let k = self.config.top_k.min(self.experts.len());

        // 1. Select top-k experts via router
        let active_indices = self.router.select_top_k(features, k);

        // 2. Collect predictions from active experts + find best + cache disagreement
        let mut best_idx = active_indices[0];
        let mut best_error = f64::INFINITY;
        let mut active_preds: Vec<f64> = Vec::with_capacity(k);

        for &idx in &active_indices {
            let pred = self.experts[idx].model.predict(features);
            active_preds.push(pred);
            let error = (target - pred).abs();
            if error < best_error {
                best_error = error;
                best_idx = idx;
            }
        }

        // Cache expert disagreement (std dev of active expert predictions).
        // NOTE(review): with top_k == 1 this is never refreshed, so the cache
        // keeps its previous value (0.0 initially) — confirm that is intended.
        if active_preds.len() >= 2 {
            let n = active_preds.len() as f64;
            let mean = active_preds.iter().sum::<f64>() / n;
            let var = active_preds.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / (n - 1.0);
            self.cached_disagreement = var.sqrt();
        }

        // Update residual alignment tracking using the weighted prediction.
        {
            let weights = self.router.renormalized_weights(features, &active_indices);
            let mut current_pred = 0.0;
            for (idx, w) in &weights {
                // Map the router's expert index back to its slot in
                // `active_preds`. Linear scan is fine: k is tiny.
                current_pred +=
                    w * active_preds[active_indices.iter().position(|&i| i == *idx).unwrap_or(0)];
            }
            let current_change = current_pred - self.prev_prediction;
            if self.n_samples > 0 {
                // Compare the sign of consecutive prediction "accelerations":
                // same sign → +1 (aligned), opposite → -1, near-zero → 0.
                let acceleration = current_change - self.prev_change;
                let prev_acceleration = self.prev_change - self.prev_prev_change;
                let agreement = if acceleration.abs() > 1e-15 && prev_acceleration.abs() > 1e-15 {
                    if (acceleration > 0.0) == (prev_acceleration > 0.0) {
                        1.0
                    } else {
                        -1.0
                    }
                } else {
                    0.0
                };
                const ALIGN_ALPHA: f64 = 0.05;
                self.alignment_ewma =
                    (1.0 - ALIGN_ALPHA) * self.alignment_ewma + ALIGN_ALPHA * agreement;
            }
            self.prev_prev_change = self.prev_change;
            self.prev_change = current_change;
            self.prev_prediction = current_pred;
        }

        // 3. Train active experts
        for &idx in &active_indices {
            self.experts[idx].model.train_one(features, target, weight);
            self.experts[idx].samples_trained += 1;
        }

        // 4. Update router (cross-entropy on best expert)
        self.router.update(features, best_idx);

        // 5. Update load balancing
        self.router.update_load_balance(&active_indices);

        // 6. Update utilization EWMA for all experts (uses post-update probs)
        let probs = self.router.probabilities(features);
        let util_alpha = 2.0 / (self.config.utilization_span as f64 + 1.0);
        for (i, slot) in self.experts.iter_mut().enumerate() {
            let p = if i < probs.len() { probs[i] } else { 0.0 };
            slot.utilization_ewma = util_alpha * p + (1.0 - util_alpha) * slot.utilization_ewma;
        }

        // Track gating entropy: H = -sum(p_k * ln(p_k)) / ln(K).
        // Normalized so 1.0 = uniform routing, 0.0 = single-expert collapse.
        {
            let k_experts = probs.len();
            if k_experts > 1 {
                let ln_k = (k_experts as f64).ln();
                let mut h = 0.0;
                for &p in &probs {
                    if p > 1e-15 {
                        h -= p * p.ln();
                    }
                }
                let normalized_h = (h / ln_k).clamp(0.0, 1.0);
                const GATE_ALPHA: f64 = 0.01;
                self.gate_entropy_ewma =
                    (1.0 - GATE_ALPHA) * self.gate_entropy_ewma + GATE_ALPHA * normalized_h;
            }
        }

        // 7. Check for dead experts (only after enough samples)
        if self.config.reset_dead && self.n_samples > self.config.utilization_span as u64 {
            self.reset_dead_experts();
        }

        self.n_samples += 1;
    }

    /// Weighted prediction: sum(w_k * f_k(x)) over the router's top-k experts.
    fn predict(&self, features: &[f64]) -> f64 {
        let k = self.config.top_k.min(self.experts.len());
        let active_indices = self.router.select_top_k(features, k);
        let weights = self.router.renormalized_weights(features, &active_indices);

        // Weighted prediction: sum(w_k * f_k(x)) for active experts
        let mut pred = 0.0;
        for (idx, w) in &weights {
            pred += w * self.experts[*idx].model.predict(features);
        }
        pred
    }

    fn n_samples_seen(&self) -> u64 {
        self.n_samples
    }

    /// Reset every expert, the router, and all diagnostic state to the
    /// freshly-built condition.
    fn reset(&mut self) {
        for slot in &mut self.experts {
            slot.model.reset();
            slot.utilization_ewma = 0.0;
            slot.samples_trained = 0;
        }
        self.router.reset();
        self.n_samples = 0;
        self.cached_disagreement = 0.0;
        self.prev_prediction = 0.0;
        self.prev_change = 0.0;
        self.prev_prev_change = 0.0;
        self.alignment_ewma = 0.0;
        self.gate_entropy_ewma = 0.0;
    }

    /// Flatten `config_diagnostics()` into the fixed 5-slot array expected by
    /// the automl layer; all zeros when no diagnostics are available.
    fn diagnostics_array(&self) -> [f64; 5] {
        use crate::automl::DiagnosticSource;
        match self.config_diagnostics() {
            Some(d) => [
                d.residual_alignment,
                d.regularization_sensitivity,
                d.depth_sufficiency,
                d.effective_dof,
                d.uncertainty,
            ],
            None => [0.0; 5],
        }
    }
}

486// ---------------------------------------------------------------------------
487// Private methods
488// ---------------------------------------------------------------------------
489
490impl NeuralMoE {
491    /// Reset experts with near-zero utilization.
492    fn reset_dead_experts(&mut self) {
493        for slot in &mut self.experts {
494            if slot.samples_trained > self.config.utilization_span as u64
495                && slot.utilization_ewma < self.config.utilization_threshold
496            {
497                slot.model.reset();
498                slot.utilization_ewma = 0.0;
499                slot.samples_trained = 0;
500            }
501        }
502    }
503}
504
505// ---------------------------------------------------------------------------
506// DiagnosticSource impl
507// ---------------------------------------------------------------------------
508
509impl crate::automl::DiagnosticSource for NeuralMoE {
510    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
511        // Gating entropy EWMA as depth_sufficiency: measures how evenly the
512        // router distributes load across experts. 1.0 = uniform utilization,
513        // 0.0 = single expert dominates.
514        let depth_sufficiency = self.gate_entropy_ewma.clamp(0.0, 1.0);
515
516        Some(crate::automl::ConfigDiagnostics {
517            residual_alignment: self.alignment_ewma,
518            regularization_sensitivity: self.config.load_balance_rate,
519            depth_sufficiency,
520            effective_dof: self.n_experts() as f64,
521            uncertainty: self.cached_disagreement,
522        })
523    }
524}
525
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    // Import factory functions for creating test experts
    use crate::{linear, rls, sgbt};

    #[test]
    fn builder_creates_moe() {
        let moe = NeuralMoE::builder()
            .expert(sgbt(10, 0.01))
            .expert(sgbt(20, 0.01))
            .expert(linear(0.01))
            .top_k(2)
            .build();

        assert_eq!(moe.n_experts(), 3);
        assert_eq!(moe.top_k(), 2);
        assert_eq!(moe.n_samples_seen(), 0);
    }

    #[test]
    #[should_panic(expected = "at least 2 experts")]
    fn builder_panics_with_one_expert() {
        NeuralMoE::builder().expert(sgbt(10, 0.01)).build();
    }

    #[test]
    fn train_and_predict_finite() {
        let mut moe = NeuralMoE::builder()
            .expert(sgbt(10, 0.01))
            .expert(sgbt(20, 0.01))
            .expert(linear(0.01))
            .top_k(2)
            .build();

        for i in 0..100 {
            let x = [i as f64 * 0.01, (i as f64).sin()];
            let y = x[0] * 2.0 + 1.0;
            moe.train(&x, y);
        }

        let pred = moe.predict(&[0.5, 0.5_f64.sin()]);
        assert!(pred.is_finite(), "prediction should be finite, got {pred}");
    }

    #[test]
    fn n_samples_tracks_correctly() {
        let mut moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert(linear(0.02))
            .build();

        for i in 0..42 {
            moe.train(&[i as f64], i as f64 * 2.0);
        }
        assert_eq!(moe.n_samples_seen(), 42);
    }

    #[test]
    fn reset_clears_state() {
        let mut moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert(linear(0.02))
            .build();

        for i in 0..50 {
            moe.train(&[i as f64], i as f64);
        }
        assert!(moe.n_samples_seen() > 0);

        moe.reset();
        assert_eq!(moe.n_samples_seen(), 0);
        for s in moe.expert_samples() {
            assert_eq!(s, 0, "expert samples should be 0 after reset");
        }
    }

    // NeuralMoE must itself be usable as a boxed StreamingLearner.
    #[test]
    fn implements_streaming_learner() {
        let moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert(linear(0.02))
            .build();

        let mut boxed: Box<dyn StreamingLearner> = Box::new(moe);
        boxed.train(&[1.0], 2.0);
        let pred = boxed.predict(&[1.0]);
        assert!(pred.is_finite(), "trait object prediction should be finite");
    }

    #[test]
    fn expert_predictions_returns_all() {
        let moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert(linear(0.02))
            .expert(linear(0.05))
            .top_k(2)
            .build();

        let preds = moe.expert_predictions(&[1.0]);
        assert_eq!(preds.len(), 3, "should have predictions from all 3 experts");
    }

    #[test]
    fn routing_probabilities_sum_to_one() {
        let moe = NeuralMoE::builder()
            .expert(sgbt(10, 0.01))
            .expert(sgbt(20, 0.01))
            .expert(linear(0.01))
            .build();

        let probs = moe.routing_probabilities(&[1.0, 2.0]);
        let sum: f64 = probs.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "routing probabilities should sum to 1.0, got {sum}"
        );
    }

    #[test]
    fn utilization_starts_at_zero() {
        let moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert(linear(0.02))
            .build();

        for u in moe.utilization() {
            assert!((u - 0.0).abs() < 1e-12, "initial utilization should be 0.0");
        }
    }

    #[test]
    fn warmup_hint_stored() {
        let moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert_with_warmup(linear(0.02), 50)
            .build();

        assert_eq!(moe.experts[0].warmup_hint, 0, "first expert has no warmup");
        assert_eq!(
            moe.experts[1].warmup_hint, 50,
            "second expert has warmup 50"
        );
    }

    // Mixing SGBT, linear, and RLS experts in one ensemble must work.
    #[test]
    fn heterogeneous_experts_work() {
        let mut moe = NeuralMoE::builder()
            .expert(sgbt(10, 0.01))
            .expert(linear(0.01))
            .expert(rls(0.99))
            .top_k(2)
            .build();

        for i in 0..200 {
            let x = [i as f64 * 0.01, (i as f64 * 0.1).sin()];
            let y = x[0] * 3.0 + x[1] * 2.0 + 1.0;
            moe.train(&x, y);
        }

        let pred = moe.predict(&[1.0, 1.0_f64.sin()]);
        assert!(
            pred.is_finite(),
            "heterogeneous MoE prediction should be finite, got {pred}"
        );
    }

    #[test]
    fn top_k_limits_active_experts() {
        let mut moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert(linear(0.02))
            .expert(linear(0.03))
            .expert(linear(0.04))
            .top_k(1) // only 1 expert active per sample
            .build();

        // Train some data
        for i in 0..100 {
            moe.train(&[i as f64], i as f64 * 2.0);
        }

        // With top_k=1, exactly 1 expert trains per sample = 100 total expert trains
        let samples = moe.expert_samples();
        let total_expert_trains: u64 = samples.iter().sum();
        assert_eq!(
            total_expert_trains, 100,
            "with top_k=1, total expert trains should equal n_samples"
        );
    }

    #[test]
    fn load_distribution_available() {
        let moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert(linear(0.02))
            .build();

        let load = moe.load_distribution();
        assert_eq!(load.len(), 2, "load distribution should have 2 entries");
    }

    // Every builder setter must land in the stored config.
    #[test]
    fn custom_config() {
        let moe = NeuralMoE::builder()
            .expert(linear(0.01))
            .expert(linear(0.02))
            .top_k(1)
            .router_lr(0.05)
            .load_balance_rate(0.02)
            .utilization_span(200)
            .utilization_threshold(0.05)
            .reset_dead(false)
            .seed(999)
            .build();

        assert_eq!(moe.config.top_k, 1);
        assert!((moe.config.router_lr - 0.05).abs() < 1e-12);
        assert!((moe.config.load_balance_rate - 0.02).abs() < 1e-12);
        assert_eq!(moe.config.utilization_span, 200);
        assert!((moe.config.utilization_threshold - 0.05).abs() < 1e-12);
        assert!(!moe.config.reset_dead);
        assert_eq!(moe.config.seed, 999);
    }

    #[test]
    fn moe_expert_disagreement() {
        let mut moe = NeuralMoE::builder()
            .expert(sgbt(10, 0.01))
            .expert(sgbt(20, 0.01))
            .expert(linear(0.01))
            .top_k(2)
            .build();

        // Before training, cached_disagreement is 0
        assert!(
            moe.cached_disagreement().abs() < 1e-15,
            "cached_disagreement should be 0 before training, got {}",
            moe.cached_disagreement()
        );

        // Train on 100 samples
        for i in 0..100 {
            let x = [i as f64 * 0.01, (i as f64).sin()];
            let y = x[0] * 2.0 + 1.0;
            moe.train(&x, y);
        }

        // After training, experts should have diverged enough for disagreement > 0
        let disagree = moe.cached_disagreement();
        assert!(
            disagree >= 0.0,
            "expert_disagreement should be >= 0, got {}",
            disagree
        );
        assert!(
            disagree.is_finite(),
            "expert_disagreement should be finite, got {}",
            disagree
        );

        // Also test the direct method
        let direct = moe.expert_disagreement(&[0.5, 0.5_f64.sin()]);
        assert!(
            direct.is_finite(),
            "expert_disagreement() should be finite, got {}",
            direct
        );
    }
}