// irithyll_core/ensemble/moe_distributional.rs

1//! Streaming Mixture of Experts over Distributional SGBT ensembles with
2//! shadow expert competition.
3//!
4//! Combines K independent [`DistributionalSGBT`] experts with a learned
5//! linear softmax gating network. Each expert outputs a full Gaussian
6//! predictive distribution N(mu, sigma^2). The mixture prediction applies
7//! the law of total variance to produce a single Gaussian from the K
8//! expert predictions weighted by gating probabilities.
9//!
10//! Additionally, each expert slot has a *shadow* expert that trains in
11//! parallel. When the shadow achieves statistically better Gaussian NLL
12//! than the active expert (verified via the Hoeffding bound), it replaces
13//! the active expert. This enables continuous capacity adaptation.
14//!
15//! # Algorithm
16//!
17//! The gating network computes K logits `z_k = W_k . x + b_k` and applies
18//! softmax to obtain routing probabilities `p_k = softmax(z)_k`. Prediction
19//! is a mixture of Gaussians via the law of total variance:
20//!
21//! ```text
22//! mu_mix = sum(p_k * mu_k)
23//! var_mix = sum(p_k * (sigma_k^2 + mu_k^2)) - mu_mix^2
24//! sigma_mix = sqrt(var_mix)
25//! ```
26//!
27//! The gate is updated via online SGD on the cross-entropy loss between
28//! the softmax distribution and the one-hot indicator of the best expert
29//! (lowest Gaussian NLL on the current sample).
30//!
31//! # Shadow Competition
32//!
33//! Each expert slot maintains a shadow model that trains on the same data.
34//! After `shadow_min_samples` observations, the Hoeffding bound is used to
35//! test whether the shadow's cumulative NLL advantage is statistically
36//! significant:
37//!
38//! ```text
39//! epsilon = sqrt(R^2 * ln(1/delta) / (2*n))
40//! swap if mean_advantage > epsilon
41//! ```
42//!
43//! # Example
44//!
45//! ```text
46//! use irithyll::ensemble::moe_distributional::MoEDistributionalSGBT;
47//! use irithyll::ensemble::moe::GatingMode;
48//! use irithyll::SGBTConfig;
49//!
50//! let config = SGBTConfig::builder()
51//!     .n_steps(10)
52//!     .learning_rate(0.1)
53//!     .grace_period(10)
54//!     .build()
55//!     .unwrap();
56//!
57//! let mut moe = MoEDistributionalSGBT::new(config, 3);
58//! moe.train_one(&irithyll::Sample::new(vec![1.0, 2.0], 3.0));
59//! let pred = moe.predict(&[1.0, 2.0]);
60//! assert!(pred.sigma > 0.0);
61//! ```
62
63use alloc::vec;
64use alloc::vec::Vec;
65
66use core::fmt;
67
68use crate::ensemble::config::SGBTConfig;
69use crate::ensemble::distributional::{DistributionalSGBT, GaussianPrediction};
70use crate::ensemble::moe::{softmax, GatingMode};
71use crate::sample::{Observation, SampleRef};
72
73// ---------------------------------------------------------------------------
74// MoEDistributionalSGBT
75// ---------------------------------------------------------------------------
76
/// Streaming Mixture of Experts over [`DistributionalSGBT`] ensembles with
/// shadow expert competition.
///
/// Combines K independent distributional experts with a learned linear
/// softmax gating network. Prediction is a mixture of Gaussians via the
/// law of total variance. Shadow experts compete to replace active experts
/// using the Hoeffding bound on Gaussian NLL differences.
///
/// Invariant: every per-slot vector (`experts`, `shadows`,
/// `cumulative_advantage`, `shadow_n`, `max_nll_diff`,
/// `shadow_replacements`) has length K = `experts.len()`, established by
/// the constructors and preserved by `reset`.
pub struct MoEDistributionalSGBT {
    /// The K active expert DistributionalSGBT ensembles.
    experts: Vec<DistributionalSGBT>,
    /// The K shadow experts (one per slot, trained in parallel).
    shadows: Vec<DistributionalSGBT>,
    /// Gate weight matrix [K x d], lazily initialized on first sample.
    /// Empty (`Vec::new()`) until the first call to `train_one`.
    gate_weights: Vec<Vec<f64>>,
    /// Gate bias vector [K] (zero-initialized at construction).
    gate_bias: Vec<f64>,
    /// Learning rate for the gating network SGD updates.
    gate_lr: f64,
    /// Number of features (set on first sample, `None` until then).
    n_features: Option<usize>,
    /// Gating mode (soft or hard top-k).
    gating_mode: GatingMode,
    /// Configuration used to construct each expert.
    config: SGBTConfig,
    /// Optional per-expert configurations. When `Some`, each expert uses its
    /// own `SGBTConfig` instead of the shared `config`.
    expert_configs: Option<Vec<SGBTConfig>>,
    /// Total training samples seen.
    samples_seen: u64,
    /// Entropy regularization weight for gate load balancing.
    ///
    /// Adds `entropy_weight * entropy_gradient` to the gate SGD update,
    /// encouraging the gate to spread probability mass across all experts
    /// rather than collapsing to a single expert.
    ///
    /// Default: 0.0 (disabled for backward compat in existing constructors).
    entropy_weight: f64,

    // -- Shadow competition state per slot --
    /// Cumulative NLL advantage of shadow over active (positive = shadow better).
    cumulative_advantage: Vec<f64>,
    /// Number of comparison samples per slot.
    shadow_n: Vec<u64>,
    /// Maximum absolute NLL difference seen per slot (for Hoeffding range R).
    max_nll_diff: Vec<f64>,
    /// Hoeffding confidence parameter (default 1e-3).
    delta: f64,
    /// Minimum samples before shadow comparison begins.
    shadow_min_samples: u64,
    /// Count of shadow replacements per slot.
    shadow_replacements: Vec<u64>,
}
129
130// ---------------------------------------------------------------------------
131// Clone
132// ---------------------------------------------------------------------------
133
// NOTE(review): this field-wise clone is identical to what `#[derive(Clone)]`
// would generate; presumably kept manual for explicitness. It could be
// replaced by a derive on the struct if all field types remain `Clone` --
// TODO confirm there is no intentional reason for the hand-written impl.
impl Clone for MoEDistributionalSGBT {
    fn clone(&self) -> Self {
        Self {
            experts: self.experts.clone(),
            shadows: self.shadows.clone(),
            gate_weights: self.gate_weights.clone(),
            gate_bias: self.gate_bias.clone(),
            gate_lr: self.gate_lr,
            n_features: self.n_features,
            gating_mode: self.gating_mode.clone(),
            config: self.config.clone(),
            expert_configs: self.expert_configs.clone(),
            samples_seen: self.samples_seen,
            entropy_weight: self.entropy_weight,
            cumulative_advantage: self.cumulative_advantage.clone(),
            shadow_n: self.shadow_n.clone(),
            max_nll_diff: self.max_nll_diff.clone(),
            delta: self.delta,
            shadow_min_samples: self.shadow_min_samples,
            shadow_replacements: self.shadow_replacements.clone(),
        }
    }
}
157
158// ---------------------------------------------------------------------------
159// Debug
160// ---------------------------------------------------------------------------
161
162impl fmt::Debug for MoEDistributionalSGBT {
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        f.debug_struct("MoEDistributionalSGBT")
165            .field("n_experts", &self.experts.len())
166            .field("samples_seen", &self.samples_seen)
167            .field("shadow_replacements", &self.shadow_replacements)
168            .finish()
169    }
170}
171
172// ---------------------------------------------------------------------------
173// Constructors
174// ---------------------------------------------------------------------------
175
176impl MoEDistributionalSGBT {
177    /// Create a new MoE distributional ensemble with soft gating and default
178    /// shadow competition parameters.
179    ///
180    /// Defaults: gate_lr = 0.01, delta = 1e-3, shadow_min_samples = 500.
181    ///
182    /// # Panics
183    ///
184    /// Panics if `n_experts < 1`.
185    pub fn new(config: SGBTConfig, n_experts: usize) -> Self {
186        Self::with_shadow_config(config, n_experts, GatingMode::Soft, 0.01, 1e-3, 500)
187    }
188
189    /// Create a new MoE distributional ensemble with custom gating mode
190    /// and gate learning rate. Uses default shadow parameters.
191    ///
192    /// # Panics
193    ///
194    /// Panics if `n_experts < 1`.
195    pub fn with_gating(
196        config: SGBTConfig,
197        n_experts: usize,
198        gating_mode: GatingMode,
199        gate_lr: f64,
200    ) -> Self {
201        Self::with_shadow_config(config, n_experts, gating_mode, gate_lr, 1e-3, 500)
202    }
203
204    /// Create a new MoE distributional ensemble with full control over
205    /// gating and shadow competition parameters.
206    ///
207    /// # Arguments
208    ///
209    /// * `config` -- SGBT configuration for each expert
210    /// * `n_experts` -- number of expert slots
211    /// * `gating_mode` -- soft or hard(top_k) gating
212    /// * `gate_lr` -- learning rate for the gating network
213    /// * `delta` -- Hoeffding confidence parameter (lower = more conservative)
214    /// * `shadow_min_samples` -- warmup before shadow comparison begins
215    ///
216    /// # Panics
217    ///
218    /// Panics if `n_experts < 1`.
219    pub fn with_shadow_config(
220        config: SGBTConfig,
221        n_experts: usize,
222        gating_mode: GatingMode,
223        gate_lr: f64,
224        delta: f64,
225        shadow_min_samples: u64,
226    ) -> Self {
227        assert!(
228            n_experts >= 1,
229            "MoEDistributionalSGBT requires at least 1 expert"
230        );
231
232        let experts: Vec<DistributionalSGBT> = (0..n_experts)
233            .map(|i| {
234                let mut cfg = config.clone();
235                cfg.seed = config.seed ^ (0x0E00_0000 | i as u64);
236                DistributionalSGBT::new(cfg)
237            })
238            .collect();
239
240        let shadows: Vec<DistributionalSGBT> = (0..n_experts)
241            .map(|i| {
242                let mut cfg = config.clone();
243                cfg.seed = config.seed ^ (0x5A00_0000 | i as u64);
244                DistributionalSGBT::new(cfg)
245            })
246            .collect();
247
248        let gate_bias = vec![0.0; n_experts];
249
250        Self {
251            experts,
252            shadows,
253            gate_weights: Vec::new(), // lazy init
254            gate_bias,
255            gate_lr,
256            n_features: None,
257            gating_mode,
258            config,
259            expert_configs: None,
260            samples_seen: 0,
261            entropy_weight: 0.0,
262            cumulative_advantage: vec![0.0; n_experts],
263            shadow_n: vec![0; n_experts],
264            max_nll_diff: vec![0.0; n_experts],
265            delta,
266            shadow_min_samples,
267            shadow_replacements: vec![0; n_experts],
268        }
269    }
270
271    /// Create a new MoE distributional ensemble where each expert uses its own
272    /// `SGBTConfig`, enabling different depths, lambda, learning rates, etc.
273    ///
274    /// The first config in `configs` is also used as the shared fallback (stored
275    /// in `self.config`) for any field-level queries.
276    ///
277    /// # Arguments
278    ///
279    /// * `configs` -- one `SGBTConfig` per expert (length determines n_experts)
280    /// * `gating_mode` -- soft or hard(top_k) gating
281    /// * `gate_lr` -- learning rate for the gating network
282    /// * `entropy_weight` -- entropy regularization weight (0.0 = disabled,
283    ///   0.1 = typical for preventing gate collapse)
284    /// * `delta` -- Hoeffding confidence parameter for shadow competition
285    /// * `shadow_min_samples` -- warmup before shadow comparison begins
286    ///
287    /// # Panics
288    ///
289    /// Panics if `configs` is empty.
290    pub fn with_expert_configs(
291        configs: Vec<SGBTConfig>,
292        gating_mode: GatingMode,
293        gate_lr: f64,
294        entropy_weight: f64,
295        delta: f64,
296        shadow_min_samples: u64,
297    ) -> Self {
298        assert!(
299            !configs.is_empty(),
300            "MoEDistributionalSGBT requires at least 1 expert config"
301        );
302
303        let n_experts = configs.len();
304
305        let experts: Vec<DistributionalSGBT> = configs
306            .iter()
307            .enumerate()
308            .map(|(i, cfg)| {
309                let mut c = cfg.clone();
310                c.seed = cfg.seed ^ (0x0E00_0000 | i as u64);
311                DistributionalSGBT::new(c)
312            })
313            .collect();
314
315        let shadows: Vec<DistributionalSGBT> = configs
316            .iter()
317            .enumerate()
318            .map(|(i, cfg)| {
319                let mut c = cfg.clone();
320                c.seed = cfg.seed ^ (0x5A00_0000 | i as u64);
321                DistributionalSGBT::new(c)
322            })
323            .collect();
324
325        let gate_bias = vec![0.0; n_experts];
326        let shared_config = configs[0].clone();
327
328        Self {
329            experts,
330            shadows,
331            gate_weights: Vec::new(),
332            gate_bias,
333            gate_lr,
334            n_features: None,
335            gating_mode,
336            config: shared_config,
337            expert_configs: Some(configs),
338            samples_seen: 0,
339            entropy_weight,
340            cumulative_advantage: vec![0.0; n_experts],
341            shadow_n: vec![0; n_experts],
342            max_nll_diff: vec![0.0; n_experts],
343            delta,
344            shadow_min_samples,
345            shadow_replacements: vec![0; n_experts],
346        }
347    }
348}
349
350// ---------------------------------------------------------------------------
351// Core impl
352// ---------------------------------------------------------------------------
353
354impl MoEDistributionalSGBT {
355    // -------------------------------------------------------------------
356    // Internal helpers
357    // -------------------------------------------------------------------
358
359    /// Ensure the gate weight matrix is initialized to the correct dimensions.
360    fn ensure_gate_init(&mut self, d: usize) {
361        if self.n_features.is_none() {
362            let k = self.experts.len();
363            self.gate_weights = vec![vec![0.0; d]; k];
364            self.n_features = Some(d);
365        }
366    }
367
368    /// Compute raw gate logits: z_k = W_k . x + b_k.
369    fn gate_logits(&self, features: &[f64]) -> Vec<f64> {
370        let k = self.experts.len();
371        let mut logits = Vec::with_capacity(k);
372        for i in 0..k {
373            let dot: f64 = self.gate_weights[i]
374                .iter()
375                .zip(features.iter())
376                .map(|(&w, &x)| w * x)
377                .sum();
378            logits.push(dot + self.gate_bias[i]);
379        }
380        logits
381    }
382
383    /// Compute Gaussian NLL for a single prediction.
384    /// nll = log_sigma + 0.5 * ((target - mu) / sigma)^2
385    #[inline]
386    fn gaussian_nll(pred: &GaussianPrediction, target: f64) -> f64 {
387        let z = (target - pred.mu) / pred.sigma.max(1e-16);
388        pred.log_sigma + 0.5 * z * z
389    }
390
391    // -------------------------------------------------------------------
392    // Public API -- gating
393    // -------------------------------------------------------------------
394
395    /// Compute gating probabilities for a feature vector.
396    ///
397    /// Returns a vector of K probabilities that sum to 1.0, one per expert.
398    /// If the gate is not yet initialized, returns uniform probabilities.
399    pub fn gating_probabilities(&self, features: &[f64]) -> Vec<f64> {
400        let k = self.experts.len();
401        if self.n_features.is_none() {
402            return vec![1.0 / k as f64; k];
403        }
404        let logits = self.gate_logits(features);
405        softmax(&logits)
406    }
407
408    // -------------------------------------------------------------------
409    // Public API -- training
410    // -------------------------------------------------------------------
411
412    /// Train on a single observation.
413    ///
414    /// 1. Lazily initializes the gate weights on the first sample.
415    /// 2. Computes gating probabilities via softmax over the linear gate.
416    /// 3. Routes the sample to experts (and their shadows) based on gating mode.
417    /// 4. Performs shadow competition via Hoeffding-bound NLL comparison.
418    /// 5. Updates gate weights via SGD on the cross-entropy gradient (best
419    ///    expert by lowest Gaussian NLL).
420    pub fn train_one(&mut self, sample: &impl Observation) {
421        let features = sample.features();
422        let target = sample.target();
423        let d = features.len();
424
425        // Step 1: lazy gate initialization
426        self.ensure_gate_init(d);
427
428        // Step 2: compute gating probabilities
429        let logits = self.gate_logits(features);
430        let probs = softmax(&logits);
431        let k = self.experts.len();
432
433        // Step 3: train experts and shadows based on gating mode
434        match &self.gating_mode {
435            GatingMode::Soft => {
436                for (i, &prob) in probs.iter().enumerate() {
437                    let weighted = SampleRef::weighted(features, target, prob);
438                    self.experts[i].train_one(&weighted);
439                    self.shadows[i].train_one(&weighted);
440                }
441            }
442            GatingMode::Hard { top_k } => {
443                let top_k = (*top_k).min(k);
444                let mut indices: Vec<usize> = (0..k).collect();
445                indices.sort_unstable_by(|&a, &b| {
446                    probs[b]
447                        .partial_cmp(&probs[a])
448                        .unwrap_or(core::cmp::Ordering::Equal)
449                });
450                for &i in indices.iter().take(top_k) {
451                    let obs = SampleRef::new(features, target);
452                    self.experts[i].train_one(&obs);
453                    self.shadows[i].train_one(&obs);
454                }
455            }
456        }
457
458        // Step 4: shadow competition per slot
459        for i in 0..k {
460            // Skip if expert or shadow not yet initialized, or below warmup
461            if !self.experts[i].is_initialized() || !self.shadows[i].is_initialized() {
462                continue;
463            }
464            if self.shadows[i].n_samples_seen() < self.shadow_min_samples {
465                continue;
466            }
467
468            let pred_active = self.experts[i].predict(features);
469            let pred_shadow = self.shadows[i].predict(features);
470
471            let nll_active = Self::gaussian_nll(&pred_active, target);
472            let nll_shadow = Self::gaussian_nll(&pred_shadow, target);
473
474            // Positive diff = shadow is better (lower NLL)
475            let diff = nll_active - nll_shadow;
476            self.cumulative_advantage[i] += diff;
477            self.shadow_n[i] += 1;
478
479            let abs_diff = crate::math::abs(diff);
480            if abs_diff > self.max_nll_diff[i] {
481                self.max_nll_diff[i] = abs_diff;
482            }
483
484            // Hoeffding bound test
485            if self.shadow_n[i] >= 10 && self.max_nll_diff[i] > 0.0 {
486                let mean_advantage = self.cumulative_advantage[i] / self.shadow_n[i] as f64;
487                if mean_advantage > 0.0 {
488                    let r_squared = self.max_nll_diff[i] * self.max_nll_diff[i];
489                    let ln_inv_delta = crate::math::ln(1.0 / self.delta);
490                    let epsilon = crate::math::sqrt(
491                        r_squared * ln_inv_delta / (2.0 * self.shadow_n[i] as f64),
492                    );
493
494                    if mean_advantage > epsilon {
495                        // Swap: shadow becomes active, create fresh shadow
496                        self.experts[i] = self.shadows[i].clone();
497                        let base_cfg = self
498                            .expert_configs
499                            .as_ref()
500                            .map(|c| &c[i])
501                            .unwrap_or(&self.config);
502                        let mut fresh_cfg = base_cfg.clone();
503                        fresh_cfg.seed = base_cfg.seed
504                            ^ (0x5A00_0000 | i as u64)
505                            ^ (self.shadow_replacements[i].wrapping_add(1) * 0x9E37_79B9);
506                        self.shadows[i] = DistributionalSGBT::new(fresh_cfg);
507
508                        // Reset comparison state
509                        self.cumulative_advantage[i] = 0.0;
510                        self.shadow_n[i] = 0;
511                        self.max_nll_diff[i] = 0.0;
512                        self.shadow_replacements[i] += 1;
513                    }
514                }
515            }
516        }
517
518        // Step 5: update gate weights via SGD on cross-entropy gradient
519        // Find best expert by lowest Gaussian NLL
520        let mut best_idx = 0;
521        let mut best_nll = f64::INFINITY;
522        for (i, expert) in self.experts.iter().enumerate() {
523            let pred = expert.predict(features);
524            let nll = Self::gaussian_nll(&pred, target);
525            if nll < best_nll {
526                best_nll = nll;
527                best_idx = i;
528            }
529        }
530
531        // Cross-entropy gradient + entropy regularization: dz_k = ce_grad + entropy_weight * entropy_grad
532        // Entropy gradient: d(-H)/dz_k = p_k * (log(p_k) + 1) - mean_term
533        // This pushes the gate toward uniform distribution, preventing collapse.
534        let entropy_mean_log_term: f64 = if self.entropy_weight != 0.0 {
535            probs
536                .iter()
537                .map(|&p| {
538                    let lp = if p > 1e-10 { crate::math::ln(p) } else { -23.0 };
539                    p * (lp + 1.0)
540                })
541                .sum()
542        } else {
543            0.0
544        };
545
546        for (i, (weights_row, bias)) in self
547            .gate_weights
548            .iter_mut()
549            .zip(self.gate_bias.iter_mut())
550            .enumerate()
551        {
552            let indicator = if i == best_idx { 1.0 } else { 0.0 };
553            let ce_grad = probs[i] - indicator;
554
555            let total_grad = if self.entropy_weight != 0.0 {
556                let log_p = if probs[i] > 1e-10 {
557                    crate::math::ln(probs[i])
558                } else {
559                    -23.0
560                };
561                let entropy_grad = probs[i] * (log_p + 1.0) - entropy_mean_log_term;
562                ce_grad + self.entropy_weight * entropy_grad
563            } else {
564                ce_grad
565            };
566
567            let lr = self.gate_lr;
568            for (j, &xj) in features.iter().enumerate() {
569                weights_row[j] -= lr * total_grad * xj;
570            }
571            *bias -= lr * total_grad;
572        }
573
574        self.samples_seen += 1;
575    }
576
577    /// Train on a batch of observations.
578    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
579        for sample in samples {
580            self.train_one(sample);
581        }
582    }
583
584    // -------------------------------------------------------------------
585    // Public API -- prediction
586    // -------------------------------------------------------------------
587
588    /// Predict the full mixture Gaussian distribution for a feature vector.
589    ///
590    /// Applies the law of total variance to combine K expert Gaussians:
591    ///
592    /// ```text
593    /// mu_mix = sum(p_k * mu_k)
594    /// var_mix = sum(p_k * (sigma_k^2 + mu_k^2)) - mu_mix^2
595    /// sigma_mix = sqrt(var_mix)
596    /// ```
597    pub fn predict(&self, features: &[f64]) -> GaussianPrediction {
598        let probs = self.gating_probabilities(features);
599        let preds: Vec<GaussianPrediction> =
600            self.experts.iter().map(|e| e.predict(features)).collect();
601
602        // Mixture mean
603        let mu_mix: f64 = probs
604            .iter()
605            .zip(preds.iter())
606            .map(|(&p, pred)| p * pred.mu)
607            .sum();
608
609        // Law of total variance: Var = E[Var(X|K)] + Var(E[X|K])
610        // = sum(p_k * sigma_k^2) + sum(p_k * mu_k^2) - mu_mix^2
611        let second_moment: f64 = probs
612            .iter()
613            .zip(preds.iter())
614            .map(|(&p, pred)| p * (pred.sigma * pred.sigma + pred.mu * pred.mu))
615            .sum();
616        let var_mix = (second_moment - mu_mix * mu_mix).max(1e-16);
617        let sigma_mix = crate::math::sqrt(var_mix);
618
619        GaussianPrediction {
620            mu: mu_mix,
621            sigma: sigma_mix,
622            log_sigma: crate::math::ln(sigma_mix),
623        }
624    }
625
626    /// Predict with gating probabilities returned alongside the prediction.
627    ///
628    /// Returns `(GaussianPrediction, probabilities)` where probabilities is
629    /// a K-length vector summing to 1.0.
630    pub fn predict_with_gating(&self, features: &[f64]) -> (GaussianPrediction, Vec<f64>) {
631        let probs = self.gating_probabilities(features);
632        let preds: Vec<GaussianPrediction> =
633            self.experts.iter().map(|e| e.predict(features)).collect();
634
635        let mu_mix: f64 = probs
636            .iter()
637            .zip(preds.iter())
638            .map(|(&p, pred)| p * pred.mu)
639            .sum();
640
641        let second_moment: f64 = probs
642            .iter()
643            .zip(preds.iter())
644            .map(|(&p, pred)| p * (pred.sigma * pred.sigma + pred.mu * pred.mu))
645            .sum();
646        let var_mix = (second_moment - mu_mix * mu_mix).max(1e-16);
647        let sigma_mix = crate::math::sqrt(var_mix);
648
649        let pred = GaussianPrediction {
650            mu: mu_mix,
651            sigma: sigma_mix,
652            log_sigma: crate::math::ln(sigma_mix),
653        };
654        (pred, probs)
655    }
656
657    /// Get each expert's individual prediction for a feature vector.
658    ///
659    /// Returns a K-length vector of Gaussian predictions, one per expert.
660    pub fn expert_predictions(&self, features: &[f64]) -> Vec<GaussianPrediction> {
661        self.experts.iter().map(|e| e.predict(features)).collect()
662    }
663
664    /// Predict the mean (location parameter) of the mixture.
665    #[inline]
666    pub fn predict_mu(&self, features: &[f64]) -> f64 {
667        self.predict(features).mu
668    }
669
670    // -------------------------------------------------------------------
671    // Public API -- inspection
672    // -------------------------------------------------------------------
673
674    /// Number of experts in the mixture.
675    #[inline]
676    pub fn n_experts(&self) -> usize {
677        self.experts.len()
678    }
679
680    /// Total training samples seen.
681    #[inline]
682    pub fn n_samples_seen(&self) -> u64 {
683        self.samples_seen
684    }
685
686    /// Immutable access to all experts.
687    pub fn experts(&self) -> &[DistributionalSGBT] {
688        &self.experts
689    }
690
691    /// Immutable access to a specific expert.
692    ///
693    /// # Panics
694    ///
695    /// Panics if `idx >= n_experts`.
696    pub fn expert(&self, idx: usize) -> &DistributionalSGBT {
697        &self.experts[idx]
698    }
699
700    /// Shadow replacement counts per expert slot.
701    pub fn shadow_replacements(&self) -> &[u64] {
702        &self.shadow_replacements
703    }
704
705    /// Entropy regularization weight for gate load balancing.
706    #[inline]
707    pub fn entropy_weight(&self) -> f64 {
708        self.entropy_weight
709    }
710
711    /// Per-expert configurations, if set via [`with_expert_configs`](Self::with_expert_configs).
712    pub fn expert_configs(&self) -> Option<&[SGBTConfig]> {
713        self.expert_configs.as_deref()
714    }
715
716    /// Reset the entire MoE to its initial state.
717    ///
718    /// Resets all experts and shadows, clears gate weights and biases back to
719    /// zeros, resets shadow competition state and the sample counter.
720    pub fn reset(&mut self) {
721        let k = self.experts.len();
722        for expert in &mut self.experts {
723            expert.reset();
724        }
725        for shadow in &mut self.shadows {
726            shadow.reset();
727        }
728        self.gate_weights.clear();
729        self.gate_bias = vec![0.0; k];
730        self.n_features = None;
731        self.samples_seen = 0;
732        self.cumulative_advantage = vec![0.0; k];
733        self.shadow_n = vec![0; k];
734        self.max_nll_diff = vec![0.0; k];
735        self.shadow_replacements = vec![0; k];
736    }
737}
738
739// ---------------------------------------------------------------------------
740// StreamingLearner impl
741// ---------------------------------------------------------------------------
742
743use crate::learner::StreamingLearner;
744
impl StreamingLearner for MoEDistributionalSGBT {
    /// Adapter: wraps the raw (features, target, weight) triple in a
    /// `SampleRef` and forwards to the inherent `train_one`.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        let sample = SampleRef::weighted(features, target, weight);
        // UFCS: call the inherent train_one(&impl Observation), not this trait method.
        MoEDistributionalSGBT::train_one(self, &sample);
    }

    /// Returns the mean (mu) of the predicted mixture Gaussian.
    fn predict(&self, features: &[f64]) -> f64 {
        // UFCS again: disambiguate from this trait's own `predict`.
        MoEDistributionalSGBT::predict(self, features).mu
    }

    /// Total training samples seen (delegates to the internal counter).
    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Full reset: experts, shadows, gate, and shadow-competition state.
    fn reset(&mut self) {
        MoEDistributionalSGBT::reset(self);
    }
}
765
766// ===========================================================================
767// Tests
768// ===========================================================================
769
770#[cfg(test)]
771mod tests {
772    use super::*;
773    use crate::sample::Sample;
774    use alloc::boxed::Box;
775    use alloc::vec;
776    use alloc::vec::Vec;
777
778    /// Helper: build a minimal config for tests.
779    fn test_config() -> SGBTConfig {
780        SGBTConfig::builder()
781            .n_steps(5)
782            .learning_rate(0.1)
783            .grace_period(5)
784            .build()
785            .unwrap()
786    }
787
788    #[test]
789    fn test_creation() {
790        let moe = MoEDistributionalSGBT::new(test_config(), 3);
791        assert_eq!(moe.n_experts(), 3);
792        assert_eq!(moe.n_samples_seen(), 0);
793        assert_eq!(moe.shadow_replacements().len(), 3);
794        for &r in moe.shadow_replacements() {
795            assert_eq!(r, 0);
796        }
797    }
798
799    #[test]
800    fn test_gating_probabilities_sum_to_one() {
801        let mut moe = MoEDistributionalSGBT::new(test_config(), 5);
802
803        // Before training: uniform probabilities
804        let probs = moe.gating_probabilities(&[1.0, 2.0]);
805        let sum: f64 = probs.iter().sum();
806        assert!((sum - 1.0).abs() < 1e-10, "pre-training sum = {}", sum);
807
808        // After training: probabilities should still sum to 1
809        for i in 0..20 {
810            let sample = Sample::new(vec![i as f64, (i * 2) as f64], i as f64);
811            moe.train_one(&sample);
812        }
813        let probs = moe.gating_probabilities(&[5.0, 10.0]);
814        let sum: f64 = probs.iter().sum();
815        assert!((sum - 1.0).abs() < 1e-10, "post-training sum = {}", sum);
816    }
817
818    #[test]
819    fn test_prediction_is_valid_gaussian() {
820        let mut moe = MoEDistributionalSGBT::new(test_config(), 3);
821
822        // Train enough for base initialization
823        for i in 0..50 {
824            let sample = Sample::new(vec![i as f64, (i as f64) * 0.5], i as f64 * 2.0);
825            moe.train_one(&sample);
826        }
827
828        let pred = moe.predict(&[10.0, 5.0]);
829        assert!(pred.mu.is_finite(), "mu should be finite: {}", pred.mu);
830        assert!(pred.sigma > 0.0, "sigma should be > 0: {}", pred.sigma);
831        assert!(
832            pred.log_sigma.is_finite(),
833            "log_sigma should be finite: {}",
834            pred.log_sigma
835        );
836    }
837
838    #[test]
839    fn test_prediction_changes_after_training() {
840        let mut moe = MoEDistributionalSGBT::new(test_config(), 3);
841        let features = vec![1.0, 2.0, 3.0];
842
843        let pred_before = moe.predict(&features);
844
845        for i in 0..100 {
846            let sample = Sample::new(features.clone(), 10.0 + i as f64 * 0.1);
847            moe.train_one(&sample);
848        }
849
850        let pred_after = moe.predict(&features);
851        assert!(
852            (pred_after.mu - pred_before.mu).abs() > 1e-6,
853            "mu should change after training: before={}, after={}",
854            pred_before.mu,
855            pred_after.mu
856        );
857    }
858
859    #[test]
860    fn test_mixture_variance() {
861        // Manual check: with uniform gating and known expert outputs,
862        // verify the law of total variance formula.
863        let mut moe = MoEDistributionalSGBT::new(test_config(), 2);
864
865        // Train so experts produce non-trivial predictions
866        for i in 0..80 {
867            let sample = Sample::new(vec![i as f64], i as f64 * 3.0);
868            moe.train_one(&sample);
869        }
870
871        let features = &[40.0];
872        let probs = moe.gating_probabilities(features);
873        let expert_preds = moe.expert_predictions(features);
874
875        // Manual mixture calculation
876        let mu_mix: f64 = probs
877            .iter()
878            .zip(expert_preds.iter())
879            .map(|(&p, pred)| p * pred.mu)
880            .sum();
881        let second_moment: f64 = probs
882            .iter()
883            .zip(expert_preds.iter())
884            .map(|(&p, pred)| p * (pred.sigma * pred.sigma + pred.mu * pred.mu))
885            .sum();
886        let var_mix = (second_moment - mu_mix * mu_mix).max(1e-16);
887        let sigma_mix = var_mix.sqrt();
888
889        let pred = moe.predict(features);
890        assert!(
891            (pred.mu - mu_mix).abs() < 1e-10,
892            "mu mismatch: pred={}, manual={}",
893            pred.mu,
894            mu_mix
895        );
896        assert!(
897            (pred.sigma - sigma_mix).abs() < 1e-10,
898            "sigma mismatch: pred={}, manual={}",
899            pred.sigma,
900            sigma_mix
901        );
902    }
903
904    #[test]
905    fn test_expert_predictions_count() {
906        let moe = MoEDistributionalSGBT::new(test_config(), 4);
907        let preds = moe.expert_predictions(&[1.0, 2.0]);
908        assert_eq!(preds.len(), 4, "should return one prediction per expert");
909    }
910
911    #[test]
912    fn test_predict_with_gating_consistency() {
913        let mut moe = MoEDistributionalSGBT::new(test_config(), 3);
914
915        for i in 0..50 {
916            let sample = Sample::new(vec![i as f64, (i as f64) * 0.5], i as f64);
917            moe.train_one(&sample);
918        }
919
920        let features = &[10.0, 5.0];
921        let (pred, probs) = moe.predict_with_gating(features);
922        let expert_preds = moe.expert_predictions(features);
923
924        assert_eq!(probs.len(), 3);
925        let sum: f64 = probs.iter().sum();
926        assert!((sum - 1.0).abs() < 1e-10);
927
928        // mu should equal weighted sum of expert means
929        let expected_mu: f64 = probs
930            .iter()
931            .zip(expert_preds.iter())
932            .map(|(&p, ep)| p * ep.mu)
933            .sum();
934        assert!(
935            (pred.mu - expected_mu).abs() < 1e-10,
936            "mu mismatch: pred={}, expected={}",
937            pred.mu,
938            expected_mu
939        );
940    }
941
942    #[test]
943    fn test_n_samples_seen_increments() {
944        let mut moe = MoEDistributionalSGBT::new(test_config(), 2);
945        assert_eq!(moe.n_samples_seen(), 0);
946
947        for i in 0..25 {
948            moe.train_one(&Sample::new(vec![i as f64], i as f64));
949        }
950        assert_eq!(moe.n_samples_seen(), 25);
951    }
952
953    #[test]
954    fn test_reset_clears_state() {
955        let mut moe = MoEDistributionalSGBT::new(test_config(), 3);
956
957        for i in 0..50 {
958            moe.train_one(&Sample::new(vec![i as f64, (i * 2) as f64], i as f64));
959        }
960        assert_eq!(moe.n_samples_seen(), 50);
961
962        moe.reset();
963
964        assert_eq!(moe.n_samples_seen(), 0);
965        assert_eq!(moe.n_experts(), 3);
966        // Gate should be re-lazily-initialized
967        let probs = moe.gating_probabilities(&[1.0, 2.0]);
968        assert_eq!(probs.len(), 3);
969        // After reset, probabilities are uniform again
970        for &p in &probs {
971            assert!(
972                (p - 1.0 / 3.0).abs() < 1e-10,
973                "expected uniform after reset, got {}",
974                p
975            );
976        }
977        // Shadow replacement counters are also reset
978        for &r in moe.shadow_replacements() {
979            assert_eq!(r, 0);
980        }
981    }
982
983    #[test]
984    fn test_streaming_learner_trait() {
985        let config = test_config();
986        let model = MoEDistributionalSGBT::new(config, 3);
987        let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
988        for i in 0..100 {
989            let x = i as f64 * 0.1;
990            boxed.train(&[x], x * 2.0);
991        }
992        assert_eq!(boxed.n_samples_seen(), 100);
993        let pred = boxed.predict(&[5.0]);
994        assert!(pred.is_finite());
995        boxed.reset();
996        assert_eq!(boxed.n_samples_seen(), 0);
997    }
998
999    #[test]
1000    fn test_hard_gating_mode() {
1001        let mut moe = MoEDistributionalSGBT::with_gating(
1002            test_config(),
1003            4,
1004            GatingMode::Hard { top_k: 2 },
1005            0.01,
1006        );
1007
1008        for i in 0..30 {
1009            let sample = Sample::new(vec![i as f64], i as f64);
1010            moe.train_one(&sample);
1011        }
1012
1013        assert_eq!(moe.n_samples_seen(), 30);
1014        let pred = moe.predict(&[15.0]);
1015        assert!(pred.mu.is_finite());
1016        assert!(pred.sigma > 0.0);
1017    }
1018
1019    #[test]
1020    fn test_predict_mu_matches_predict() {
1021        let mut moe = MoEDistributionalSGBT::new(test_config(), 3);
1022
1023        for i in 0..50 {
1024            moe.train_one(&Sample::new(vec![i as f64], i as f64 * 2.0));
1025        }
1026
1027        let features = &[25.0];
1028        let mu_direct = moe.predict_mu(features);
1029        let mu_from_predict = moe.predict(features).mu;
1030        assert!(
1031            (mu_direct - mu_from_predict).abs() < 1e-12,
1032            "predict_mu={} vs predict().mu={}",
1033            mu_direct,
1034            mu_from_predict
1035        );
1036    }
1037
1038    #[test]
1039    fn test_batch_training() {
1040        let mut moe = MoEDistributionalSGBT::new(test_config(), 3);
1041
1042        let samples: Vec<Sample> = (0..20)
1043            .map(|i| Sample::new(vec![i as f64, (i * 3) as f64], i as f64))
1044            .collect();
1045
1046        moe.train_batch(&samples);
1047
1048        assert_eq!(moe.n_samples_seen(), 20);
1049        let pred = moe.predict(&[10.0, 30.0]);
1050        assert!(pred.mu.is_finite());
1051        assert!(pred.sigma > 0.0);
1052    }
1053
1054    #[test]
1055    fn moe_with_expert_configs_different_depths() {
1056        // Each expert gets its own config with different tree depth.
1057        let configs: Vec<SGBTConfig> = (0..3)
1058            .map(|i| {
1059                SGBTConfig::builder()
1060                    .n_steps(5)
1061                    .learning_rate(0.1)
1062                    .grace_period(5)
1063                    .max_depth(2 + i) // depth 2, 3, 4
1064                    .build()
1065                    .unwrap()
1066            })
1067            .collect();
1068
1069        let mut moe = MoEDistributionalSGBT::with_expert_configs(
1070            configs.clone(),
1071            GatingMode::Soft,
1072            0.01,
1073            0.0, // no entropy
1074            1e-3,
1075            500,
1076        );
1077
1078        assert_eq!(moe.n_experts(), 3);
1079        assert!(moe.expert_configs().is_some());
1080        assert_eq!(moe.expert_configs().unwrap().len(), 3);
1081
1082        // Verify each expert got its config (via max_depth)
1083        for (i, cfg) in configs.iter().enumerate() {
1084            assert_eq!(moe.expert(i).config().max_depth, cfg.max_depth);
1085        }
1086
1087        // Train and verify it works
1088        for i in 0..50 {
1089            let sample = Sample::new(vec![i as f64, (i * 2) as f64], i as f64 * 3.0);
1090            moe.train_one(&sample);
1091        }
1092        let pred = moe.predict(&[10.0, 20.0]);
1093        assert!(pred.mu.is_finite());
1094        assert!(pred.sigma > 0.0);
1095    }
1096
1097    #[test]
1098    fn entropy_regularization_prevents_collapse() {
1099        // With entropy weight, gate probs should stay above a minimum for all experts
1100        // when data is uniform across patterns.
1101        let config = test_config();
1102        let mut moe = MoEDistributionalSGBT::with_expert_configs(
1103            vec![config.clone(), config.clone(), config],
1104            GatingMode::Soft,
1105            0.01,
1106            0.1, // entropy weight
1107            1e-3,
1108            500,
1109        );
1110
1111        // Train with uniform-ish data
1112        for i in 0..200 {
1113            let x = (i % 10) as f64;
1114            let sample = Sample::new(vec![x, x * 2.0], x * 3.0);
1115            moe.train_one(&sample);
1116        }
1117
1118        // Check that no expert is completely starved
1119        let probs = moe.gating_probabilities(&[5.0, 10.0]);
1120        for (i, &p) in probs.iter().enumerate() {
1121            assert!(
1122                p > 0.02,
1123                "Expert {} has probability {} -- gate collapsed despite entropy regularization",
1124                i,
1125                p
1126            );
1127        }
1128    }
1129
1130    #[test]
1131    fn moe_expert_configs_shadow_respawn_correct() {
1132        // After shadow swap, the fresh shadow should use the per-expert config.
1133        // We can verify the config path is correct by construction.
1134        let configs: Vec<SGBTConfig> = (0..2)
1135            .map(|i| {
1136                SGBTConfig::builder()
1137                    .n_steps(3)
1138                    .learning_rate(0.1)
1139                    .grace_period(5)
1140                    .max_depth(3 + i) // depth 3, 4
1141                    .build()
1142                    .unwrap()
1143            })
1144            .collect();
1145
1146        let moe = MoEDistributionalSGBT::with_expert_configs(
1147            configs.clone(),
1148            GatingMode::Soft,
1149            0.01,
1150            0.0,
1151            1e-3,
1152            500,
1153        );
1154
1155        // Verify expert_configs are stored
1156        let ec = moe.expert_configs().unwrap();
1157        assert_eq!(ec[0].max_depth, 3);
1158        assert_eq!(ec[1].max_depth, 4);
1159
1160        // The shadow swap path references expert_configs[i] -- verified by
1161        // the code structure. This test confirms the configs are accessible
1162        // and correctly indexed.
1163        assert_eq!(moe.expert(0).config().max_depth, 3);
1164        assert_eq!(moe.expert(1).config().max_depth, 4);
1165    }
1166}