// irithyll_core/ensemble/moe.rs
//! Streaming Mixture of Experts over SGBT ensembles.
//!
//! Implements a gated mixture of K independent [`SGBT`] experts with a learned
//! linear softmax gate. Each expert is a full streaming gradient boosted tree
//! ensemble; the gate routes incoming samples to the most relevant expert(s)
//! based on the feature vector, enabling capacity specialization across
//! different regions of the input space.
//!
//! # Algorithm
//!
//! The gating network computes K logits `z_k = W_k · x + b_k` and applies
//! softmax to obtain routing probabilities `p_k = softmax(z)_k`. Prediction
//! is the probability-weighted sum of expert predictions:
//!
//! ```text
//! ŷ = Σ_k  p_k(x) · f_k(x)
//! ```
//!
//! During training, the gate is updated via online SGD on the cross-entropy
//! loss between the softmax distribution and the one-hot indicator of the
//! best expert (lowest loss on the current sample). This encourages the gate
//! to learn which expert is most competent for each region.
//!
//! Two gating modes are supported:
//!
//! - **Soft** (default): All experts receive every sample, weighted by their
//!   gating probability. This maximizes information flow but has O(K) training
//!   cost per sample.
//! - **Hard (top-k)**: Only the top-k experts (by gating probability) receive
//!   the sample. This reduces computation when K is large, at the cost of
//!   slower expert specialization.
//!
//! # References
//!
//! - Jacobs, R. A., Jordan, M. I., Nowlan, S. J., & Hinton, G. E. (1991).
//!   Adaptive Mixtures of Local Experts. *Neural Computation*, 3(1), 79–87.
//! - Shazeer, N., Mirhoseini, A., Maziarz, K., Davis, A., Le, Q., Hinton, G.,
//!   & Dean, J. (2017). Outrageously Large Neural Networks: The Sparsely-Gated
//!   Mixture-of-Experts Layer. *ICLR 2017*.
//!
//! # Example
//!
//! ```text
//! use irithyll::ensemble::moe::{MoESGBT, GatingMode};
//! use irithyll::SGBTConfig;
//!
//! let config = SGBTConfig::builder()
//!     .n_steps(10)
//!     .learning_rate(0.1)
//!     .grace_period(10)
//!     .build()
//!     .unwrap();
//!
//! let mut moe = MoESGBT::new(config, 3);
//! moe.train_one(&irithyll::Sample::new(vec![1.0, 2.0], 3.0));
//! let pred = moe.predict(&[1.0, 2.0]);
//! ```

59use alloc::vec;
60use alloc::vec::Vec;
61
62use core::fmt;
63
64use crate::ensemble::config::SGBTConfig;
65use crate::ensemble::SGBT;
66use crate::loss::squared::SquaredLoss;
67use crate::loss::Loss;
68use crate::sample::{Observation, SampleRef};
69
70// ---------------------------------------------------------------------------
71// GatingMode
72// ---------------------------------------------------------------------------
73
/// Controls how the gate routes samples to experts.
///
/// - [`Soft`](GatingMode::Soft): every expert sees every sample, weighted by
///   gating probability. Maximizes information flow.
/// - [`Hard`](GatingMode::Hard): only the `top_k` experts with highest gating
///   probability receive the sample. Reduces cost when K is large.
// Plain-data enum (two variants, one `usize` payload): eagerly derive the
// standard traits. `Copy` makes the `.clone()` calls elsewhere free, and
// `PartialEq`/`Eq`/`Hash` let callers compare/configure modes directly.
// Adding derives is backward compatible for all existing users.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GatingMode {
    /// All experts receive every sample, weighted by gating probability.
    Soft,
    /// Only the top-k experts receive the sample (sparse routing).
    Hard {
        /// Number of experts to route each sample to.
        top_k: usize,
    },
}
90
91// ---------------------------------------------------------------------------
92// Softmax (numerically stable)
93// ---------------------------------------------------------------------------
94
95/// Numerically stable softmax: subtract max logit before exponentiating to
96/// prevent overflow, then normalize.
97pub(crate) fn softmax(logits: &[f64]) -> Vec<f64> {
98    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
99    let exps: Vec<f64> = logits.iter().map(|&z| crate::math::exp(z - max)).collect();
100    let sum: f64 = exps.iter().sum();
101    exps.iter().map(|&e| e / sum).collect()
102}
103
104// ---------------------------------------------------------------------------
105// MoESGBT
106// ---------------------------------------------------------------------------
107
/// Streaming Mixture of Experts over SGBT ensembles.
///
/// Combines K independent [`SGBT<L>`] experts with a learned linear softmax
/// gating network. The gate is trained online via SGD to route samples to the
/// expert with the lowest loss, while all experts (or the top-k in hard gating
/// mode) are trained on each incoming sample.
///
/// Generic over `L: Loss` so the expert loss function is monomorphized. The
/// default is [`SquaredLoss`] for regression tasks.
///
/// # Gate Architecture
///
/// The gate is a single linear layer: `z_k = W_k · x + b_k` followed by
/// softmax. Weights are lazily initialized to zeros on the first sample
/// (since the feature dimensionality is not known at construction time).
/// The gate learns via cross-entropy gradient descent against the one-hot
/// indicator of the best expert per sample.
pub struct MoESGBT<L: Loss = SquaredLoss> {
    /// The K expert SGBT ensembles.
    /// Invariant: `experts.len() >= 1` (enforced by the constructors).
    experts: Vec<SGBT<L>>,
    /// Gate weight matrix [K x d], lazily initialized on first sample.
    /// Empty (`Vec::new()`) until the feature dimensionality is known.
    gate_weights: Vec<Vec<f64>>,
    /// Gate bias vector [K]. All zeros at construction, giving a uniform
    /// initial gating distribution.
    gate_bias: Vec<f64>,
    /// Learning rate for the gating network SGD updates.
    gate_lr: f64,
    /// Number of features (set on first sample, `None` until then).
    /// `Some(d)` implies `gate_weights` is a K x d matrix.
    n_features: Option<usize>,
    /// Gating mode (soft or hard top-k).
    gating_mode: GatingMode,
    /// Configuration used to construct each expert.
    config: SGBTConfig,
    /// Loss function (shared type with experts, used for best-expert selection).
    loss: L,
    /// Total training samples seen.
    samples_seen: u64,
}
145
146// ---------------------------------------------------------------------------
147// Clone
148// ---------------------------------------------------------------------------
149
150impl<L: Loss + Clone> Clone for MoESGBT<L> {
151    fn clone(&self) -> Self {
152        Self {
153            experts: self.experts.clone(),
154            gate_weights: self.gate_weights.clone(),
155            gate_bias: self.gate_bias.clone(),
156            gate_lr: self.gate_lr,
157            n_features: self.n_features,
158            gating_mode: self.gating_mode.clone(),
159            config: self.config.clone(),
160            loss: self.loss.clone(),
161            samples_seen: self.samples_seen,
162        }
163    }
164}
165
166// ---------------------------------------------------------------------------
167// Debug
168// ---------------------------------------------------------------------------
169
170impl<L: Loss> fmt::Debug for MoESGBT<L> {
171    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172        f.debug_struct("MoESGBT")
173            .field("n_experts", &self.experts.len())
174            .field("gating_mode", &self.gating_mode)
175            .field("samples_seen", &self.samples_seen)
176            .finish()
177    }
178}
179
180// ---------------------------------------------------------------------------
181// Default loss constructor (SquaredLoss)
182// ---------------------------------------------------------------------------
183
184impl MoESGBT<SquaredLoss> {
185    /// Create a new MoE ensemble with squared loss (regression) and soft gating.
186    ///
187    /// Each expert is seeded uniquely via `config.seed ^ (0x0000_0E00_0000_0000 | i)`.
188    /// The gating learning rate defaults to 0.01.
189    ///
190    /// # Panics
191    ///
192    /// Panics if `n_experts < 1`.
193    pub fn new(config: SGBTConfig, n_experts: usize) -> Self {
194        Self::with_loss(config, SquaredLoss, n_experts)
195    }
196}
197
198// ---------------------------------------------------------------------------
199// General impl
200// ---------------------------------------------------------------------------
201
202impl<L: Loss + Clone> MoESGBT<L> {
203    /// Create a new MoE ensemble with a custom loss and soft gating.
204    ///
205    /// # Panics
206    ///
207    /// Panics if `n_experts < 1`.
208    pub fn with_loss(config: SGBTConfig, loss: L, n_experts: usize) -> Self {
209        Self::with_gating(config, loss, n_experts, GatingMode::Soft, 0.01)
210    }
211
212    /// Create a new MoE ensemble with full control over gating mode and gate
213    /// learning rate.
214    ///
215    /// # Panics
216    ///
217    /// Panics if `n_experts < 1`.
218    pub fn with_gating(
219        config: SGBTConfig,
220        loss: L,
221        n_experts: usize,
222        gating_mode: GatingMode,
223        gate_lr: f64,
224    ) -> Self {
225        assert!(n_experts >= 1, "MoESGBT requires at least 1 expert");
226
227        let experts = (0..n_experts)
228            .map(|i| {
229                let mut cfg = config.clone();
230                cfg.seed = config.seed ^ (0x0000_0E00_0000_0000 | i as u64);
231                SGBT::with_loss(cfg, loss.clone())
232            })
233            .collect();
234
235        let gate_bias = vec![0.0; n_experts];
236
237        Self {
238            experts,
239            gate_weights: Vec::new(), // lazy init
240            gate_bias,
241            gate_lr,
242            n_features: None,
243            gating_mode,
244            config,
245            loss,
246            samples_seen: 0,
247        }
248    }
249}
250
251impl<L: Loss> MoESGBT<L> {
252    // -------------------------------------------------------------------
253    // Internal helpers
254    // -------------------------------------------------------------------
255
256    /// Ensure the gate weight matrix is initialized to the correct dimensions.
257    /// Called lazily on the first sample when `n_features` is discovered.
258    fn ensure_gate_init(&mut self, d: usize) {
259        if self.n_features.is_none() {
260            let k = self.experts.len();
261            self.gate_weights = vec![vec![0.0; d]; k];
262            self.n_features = Some(d);
263        }
264    }
265
266    /// Compute raw gate logits: z_k = W_k · x + b_k.
267    fn gate_logits(&self, features: &[f64]) -> Vec<f64> {
268        let k = self.experts.len();
269        let mut logits = Vec::with_capacity(k);
270        for i in 0..k {
271            let dot: f64 = self.gate_weights[i]
272                .iter()
273                .zip(features.iter())
274                .map(|(&w, &x)| w * x)
275                .sum();
276            logits.push(dot + self.gate_bias[i]);
277        }
278        logits
279    }
280
281    // -------------------------------------------------------------------
282    // Public API -- gating
283    // -------------------------------------------------------------------
284
285    /// Compute gating probabilities for a feature vector.
286    ///
287    /// Returns a vector of K probabilities that sum to 1.0, one per expert.
288    /// The gate must be initialized (at least one training sample seen),
289    /// otherwise returns uniform probabilities.
290    pub fn gating_probabilities(&self, features: &[f64]) -> Vec<f64> {
291        let k = self.experts.len();
292        if self.n_features.is_none() {
293            // Gate not initialized yet -- return uniform
294            return vec![1.0 / k as f64; k];
295        }
296        let logits = self.gate_logits(features);
297        softmax(&logits)
298    }
299
300    // -------------------------------------------------------------------
301    // Public API -- training
302    // -------------------------------------------------------------------
303
304    /// Train on a single observation.
305    ///
306    /// 1. Lazily initializes the gate weights if this is the first sample.
307    /// 2. Computes gating probabilities via softmax over the linear gate.
308    /// 3. Routes the sample to experts according to the gating mode:
309    ///    - **Soft**: all experts receive the sample, each weighted by its
310    ///      gating probability (via `SampleRef::weighted`).
311    ///    - **Hard(top_k)**: only the top-k experts by probability receive
312    ///      the sample (with unit weight).
313    /// 4. Updates gate weights via SGD on the cross-entropy gradient:
314    ///    find the best expert (lowest loss), compute `dz_k = p_k - 1{k==best}`,
315    ///    and apply `W_k -= gate_lr * dz_k * x`, `b_k -= gate_lr * dz_k`.
316    pub fn train_one(&mut self, sample: &impl Observation) {
317        let features = sample.features();
318        let target = sample.target();
319        let d = features.len();
320
321        // Step 1: lazy gate initialization
322        self.ensure_gate_init(d);
323
324        // Step 2: compute gating probabilities
325        let logits = self.gate_logits(features);
326        let probs = softmax(&logits);
327        let k = self.experts.len();
328
329        // Step 3: train experts based on gating mode
330        match &self.gating_mode {
331            GatingMode::Soft => {
332                // Every expert gets the sample, weighted by gating probability
333                for (expert, &prob) in self.experts.iter_mut().zip(probs.iter()) {
334                    let weighted = SampleRef::weighted(features, target, prob);
335                    expert.train_one(&weighted);
336                }
337            }
338            GatingMode::Hard { top_k } => {
339                // Only the top-k experts get the sample
340                let top_k = (*top_k).min(k);
341                let mut indices: Vec<usize> = (0..k).collect();
342                indices.sort_unstable_by(|&a, &b| {
343                    probs[b]
344                        .partial_cmp(&probs[a])
345                        .unwrap_or(core::cmp::Ordering::Equal)
346                });
347                for &i in indices.iter().take(top_k) {
348                    let obs = SampleRef::new(features, target);
349                    self.experts[i].train_one(&obs);
350                }
351            }
352        }
353
354        // Step 4: update gate weights via SGD on cross-entropy gradient
355        // Find best expert (lowest loss on this sample)
356        let mut best_idx = 0;
357        let mut best_loss = f64::INFINITY;
358        for (i, expert) in self.experts.iter().enumerate() {
359            let pred = expert.predict(features);
360            let l = self.loss.loss(target, pred);
361            if l < best_loss {
362                best_loss = l;
363                best_idx = i;
364            }
365        }
366
367        // Cross-entropy gradient: dz_k = p_k - 1{k == best}
368        // SGD update: W_k -= lr * dz_k * x,  b_k -= lr * dz_k
369        for (i, (weights_row, bias)) in self
370            .gate_weights
371            .iter_mut()
372            .zip(self.gate_bias.iter_mut())
373            .enumerate()
374        {
375            let indicator = if i == best_idx { 1.0 } else { 0.0 };
376            let grad = probs[i] - indicator;
377            let lr = self.gate_lr;
378
379            for (j, &xj) in features.iter().enumerate() {
380                weights_row[j] -= lr * grad * xj;
381            }
382            *bias -= lr * grad;
383        }
384
385        self.samples_seen += 1;
386    }
387
388    /// Train on a batch of observations.
389    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
390        for sample in samples {
391            self.train_one(sample);
392        }
393    }
394
395    // -------------------------------------------------------------------
396    // Public API -- prediction
397    // -------------------------------------------------------------------
398
399    /// Predict the output for a feature vector.
400    ///
401    /// Computes the probability-weighted sum of expert predictions:
402    /// `ŷ = Σ_k p_k(x) · f_k(x)`.
403    pub fn predict(&self, features: &[f64]) -> f64 {
404        let probs = self.gating_probabilities(features);
405        let mut pred = 0.0;
406        for (i, &p) in probs.iter().enumerate() {
407            pred += p * self.experts[i].predict(features);
408        }
409        pred
410    }
411
412    /// Predict with gating probabilities returned alongside the prediction.
413    ///
414    /// Returns `(prediction, probabilities)` where probabilities is a K-length
415    /// vector summing to 1.0.
416    pub fn predict_with_gating(&self, features: &[f64]) -> (f64, Vec<f64>) {
417        let probs = self.gating_probabilities(features);
418        let mut pred = 0.0;
419        for (i, &p) in probs.iter().enumerate() {
420            pred += p * self.experts[i].predict(features);
421        }
422        (pred, probs)
423    }
424
425    /// Get each expert's individual prediction for a feature vector.
426    ///
427    /// Returns a K-length vector of raw predictions, one per expert.
428    pub fn expert_predictions(&self, features: &[f64]) -> Vec<f64> {
429        self.experts.iter().map(|e| e.predict(features)).collect()
430    }
431
432    // -------------------------------------------------------------------
433    // Public API -- inspection
434    // -------------------------------------------------------------------
435
436    /// Number of experts in the mixture.
437    #[inline]
438    pub fn n_experts(&self) -> usize {
439        self.experts.len()
440    }
441
442    /// Total training samples seen.
443    #[inline]
444    pub fn n_samples_seen(&self) -> u64 {
445        self.samples_seen
446    }
447
448    /// Immutable access to all experts.
449    pub fn experts(&self) -> &[SGBT<L>] {
450        &self.experts
451    }
452
453    /// Immutable access to a specific expert.
454    ///
455    /// # Panics
456    ///
457    /// Panics if `idx >= n_experts`.
458    pub fn expert(&self, idx: usize) -> &SGBT<L> {
459        &self.experts[idx]
460    }
461
462    /// Reset the entire MoE to its initial state.
463    ///
464    /// Resets all experts, clears gate weights and biases back to zeros,
465    /// and resets the sample counter.
466    pub fn reset(&mut self) {
467        for expert in &mut self.experts {
468            expert.reset();
469        }
470        let k = self.experts.len();
471        self.gate_weights.clear();
472        self.gate_bias = vec![0.0; k];
473        self.n_features = None;
474        self.samples_seen = 0;
475    }
476}
477
478// ---------------------------------------------------------------------------
479// StreamingLearner impl
480// ---------------------------------------------------------------------------
481
482use crate::learner::StreamingLearner;
483
impl<L: Loss> StreamingLearner for MoESGBT<L> {
    /// Trait adapter: wrap the raw `(features, target, weight)` triple in a
    /// [`SampleRef`] and hand it to the inherent training method.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        let sample = SampleRef::weighted(features, target, weight);
        // UFCS: call the inherent train_one(&impl Observation), not this trait method.
        MoESGBT::train_one(self, &sample);
    }

    /// Delegates to the inherent `predict` via UFCS to target it unambiguously.
    fn predict(&self, features: &[f64]) -> f64 {
        MoESGBT::predict(self, features)
    }

    /// Reads the same counter the inherent `n_samples_seen` exposes.
    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Delegates to the inherent `reset` (experts, gate, and counter).
    fn reset(&mut self) {
        MoESGBT::reset(self);
    }
}
503
504// ===========================================================================
505// Tests
506// ===========================================================================
507
#[cfg(test)]
mod tests {
    // Tests exercise construction, both gating modes, gate learning,
    // prediction consistency, reset, and the StreamingLearner trait object.
    use super::*;
    use crate::loss::huber::HuberLoss;
    use crate::sample::Sample;
    use alloc::boxed::Box;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Helper: build a minimal config for tests.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .build()
            .unwrap()
    }

    #[test]
    fn test_creation() {
        let moe = MoESGBT::new(test_config(), 3);
        assert_eq!(moe.n_experts(), 3);
        assert_eq!(moe.n_samples_seen(), 0);
    }

    #[test]
    fn test_with_loss() {
        // Non-default loss type (Huber) must also construct cleanly.
        let moe = MoESGBT::with_loss(test_config(), HuberLoss { delta: 1.0 }, 4);
        assert_eq!(moe.n_experts(), 4);
        assert_eq!(moe.n_samples_seen(), 0);
    }

    #[test]
    fn test_soft_gating_trains_all() {
        let mut moe = MoESGBT::new(test_config(), 3);
        let sample = Sample::new(vec![1.0, 2.0], 5.0);

        moe.train_one(&sample);

        // In soft mode, every expert should have seen the sample
        for i in 0..3 {
            assert_eq!(moe.expert(i).n_samples_seen(), 1);
        }
    }

    #[test]
    fn test_hard_gating_top_k() {
        let mut moe = MoESGBT::with_gating(
            test_config(),
            SquaredLoss,
            4,
            GatingMode::Hard { top_k: 2 },
            0.01,
        );
        let sample = Sample::new(vec![1.0, 2.0], 5.0);

        moe.train_one(&sample);

        // Exactly top_k=2 experts should have received the sample
        let trained_count = (0..4)
            .filter(|&i| moe.expert(i).n_samples_seen() > 0)
            .count();
        assert_eq!(trained_count, 2);
    }

    #[test]
    fn test_gating_probabilities_sum_to_one() {
        let mut moe = MoESGBT::new(test_config(), 5);

        // Before training: uniform probabilities
        let probs = moe.gating_probabilities(&[1.0, 2.0]);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "pre-training sum = {}", sum);

        // After training: probabilities should still sum to 1
        for i in 0..20 {
            let sample = Sample::new(vec![i as f64, (i * 2) as f64], i as f64);
            moe.train_one(&sample);
        }
        let probs = moe.gating_probabilities(&[5.0, 10.0]);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "post-training sum = {}", sum);
    }

    #[test]
    fn test_prediction_changes_after_training() {
        let mut moe = MoESGBT::new(test_config(), 3);
        let features = vec![1.0, 2.0, 3.0];

        let pred_before = moe.predict(&features);

        for i in 0..50 {
            let sample = Sample::new(features.clone(), 10.0 + i as f64 * 0.1);
            moe.train_one(&sample);
        }

        let pred_after = moe.predict(&features);
        assert!(
            (pred_after - pred_before).abs() > 1e-6,
            "prediction should change after training: before={}, after={}",
            pred_before,
            pred_after
        );
    }

    #[test]
    fn test_expert_specialization() {
        // Two regions: x < 0 targets ~-10, x >= 0 targets ~+10
        let mut moe = MoESGBT::with_gating(test_config(), SquaredLoss, 2, GatingMode::Soft, 0.05);

        // Train with separable data (alternating negative/positive inputs)
        for i in 0..200 {
            let x = if i % 2 == 0 {
                -(i as f64 + 1.0)
            } else {
                i as f64 + 1.0
            };
            let target = if x < 0.0 { -10.0 } else { 10.0 };
            let sample = Sample::new(vec![x], target);
            moe.train_one(&sample);
        }

        // After training, the gating probabilities should differ for
        // negative vs positive inputs
        let probs_neg = moe.gating_probabilities(&[-5.0]);
        let probs_pos = moe.gating_probabilities(&[5.0]);

        // The dominant expert should be different (or at least the distributions
        // should be noticeably different)
        let diff: f64 = probs_neg
            .iter()
            .zip(probs_pos.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 0.01,
            "gate should route differently: neg={:?}, pos={:?}",
            probs_neg,
            probs_pos
        );
    }

    #[test]
    fn test_predict_with_gating() {
        let mut moe = MoESGBT::new(test_config(), 3);
        let sample = Sample::new(vec![1.0, 2.0], 5.0);
        moe.train_one(&sample);

        let (pred, probs) = moe.predict_with_gating(&[1.0, 2.0]);
        assert_eq!(probs.len(), 3);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);

        // Prediction should equal weighted sum of expert predictions
        let expert_preds = moe.expert_predictions(&[1.0, 2.0]);
        let expected: f64 = probs
            .iter()
            .zip(expert_preds.iter())
            .map(|(p, e)| p * e)
            .sum();
        assert!(
            (pred - expected).abs() < 1e-10,
            "pred={} expected={}",
            pred,
            expected
        );
    }

    #[test]
    fn test_expert_predictions() {
        let mut moe = MoESGBT::new(test_config(), 3);
        for i in 0..10 {
            let sample = Sample::new(vec![i as f64], i as f64);
            moe.train_one(&sample);
        }

        let preds = moe.expert_predictions(&[5.0]);
        assert_eq!(preds.len(), 3);
        // Each expert should produce a finite prediction
        for &p in &preds {
            assert!(p.is_finite(), "expert prediction should be finite: {}", p);
        }
    }

    #[test]
    fn test_n_experts() {
        let moe = MoESGBT::new(test_config(), 7);
        assert_eq!(moe.n_experts(), 7);
        assert_eq!(moe.experts().len(), 7);
    }

    #[test]
    fn test_n_samples_seen() {
        let mut moe = MoESGBT::new(test_config(), 2);
        assert_eq!(moe.n_samples_seen(), 0);

        for i in 0..25 {
            moe.train_one(&Sample::new(vec![i as f64], i as f64));
        }
        assert_eq!(moe.n_samples_seen(), 25);
    }

    #[test]
    fn test_reset() {
        let mut moe = MoESGBT::new(test_config(), 3);

        for i in 0..50 {
            moe.train_one(&Sample::new(vec![i as f64, (i * 2) as f64], i as f64));
        }
        assert_eq!(moe.n_samples_seen(), 50);

        moe.reset();

        assert_eq!(moe.n_samples_seen(), 0);
        assert_eq!(moe.n_experts(), 3);
        // Gate should be re-lazily-initialized
        let probs = moe.gating_probabilities(&[1.0, 2.0]);
        assert_eq!(probs.len(), 3);
        // After reset, probabilities are uniform again
        for &p in &probs {
            assert!(
                (p - 1.0 / 3.0).abs() < 1e-10,
                "expected uniform after reset, got {}",
                p
            );
        }
    }

    #[test]
    fn test_single_expert() {
        // With a single expert, MoE should behave like a plain SGBT
        let config = test_config();
        let mut moe = MoESGBT::new(config.clone(), 1);

        // Mirror the MoE's expert seeding (i = 0) so both models match.
        let mut plain = SGBT::new({
            let mut cfg = config.clone();
            cfg.seed = config.seed ^ 0x0000_0E00_0000_0000;
            cfg
        });

        // The single expert gets weight=1.0 always, so predictions should
        // be very close (both see same data, same seed)
        for i in 0..30 {
            let sample = Sample::new(vec![i as f64], i as f64 * 2.0);
            moe.train_one(&sample);
            // For the plain SGBT, we need to replicate the soft-gating weight.
            // With one expert, p=1.0, so SampleRef::weighted(features, target, 1.0)
            // is equivalent to a normal sample (weight=1.0).
            let weighted = SampleRef::weighted(&sample.features, sample.target, 1.0);
            plain.train_one(&weighted);
        }

        let moe_pred = moe.predict(&[15.0]);
        let plain_pred = plain.predict(&[15.0]);
        assert!(
            (moe_pred - plain_pred).abs() < 1e-6,
            "single expert MoE should match plain SGBT: moe={}, plain={}",
            moe_pred,
            plain_pred
        );
    }

    #[test]
    fn test_gate_lr_effect() {
        // A higher gate learning rate should cause the gate to diverge from
        // uniform faster than a lower one.
        let config = test_config();

        let mut moe_low =
            MoESGBT::with_gating(config.clone(), SquaredLoss, 3, GatingMode::Soft, 0.001);
        let mut moe_high = MoESGBT::with_gating(config, SquaredLoss, 3, GatingMode::Soft, 0.1);

        // Train both on the same data
        for i in 0..50 {
            let sample = Sample::new(vec![i as f64], i as f64);
            moe_low.train_one(&sample);
            moe_high.train_one(&sample);
        }

        // Measure deviation from uniform for both
        let uniform = 1.0 / 3.0;
        let probs_low = moe_low.gating_probabilities(&[25.0]);
        let probs_high = moe_high.gating_probabilities(&[25.0]);

        let dev_low: f64 = probs_low.iter().map(|p| (p - uniform).abs()).sum();
        let dev_high: f64 = probs_high.iter().map(|p| (p - uniform).abs()).sum();

        assert!(
            dev_high > dev_low,
            "higher gate_lr should cause more deviation from uniform: low={}, high={}",
            dev_low,
            dev_high
        );
    }

    #[test]
    fn test_batch_training() {
        let mut moe = MoESGBT::new(test_config(), 3);

        let samples: Vec<Sample> = (0..20)
            .map(|i| Sample::new(vec![i as f64, (i * 3) as f64], i as f64))
            .collect();

        moe.train_batch(&samples);

        assert_eq!(moe.n_samples_seen(), 20);

        // Should produce non-zero predictions after batch training
        let pred = moe.predict(&[10.0, 30.0]);
        assert!(pred.is_finite());
    }

    #[test]
    fn streaming_learner_trait_object() {
        let config = test_config();
        let model = MoESGBT::new(config, 3);
        let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
        for i in 0..100 {
            let x = i as f64 * 0.1;
            // `train` comes from the StreamingLearner trait (presumably a
            // convenience wrapper over train_one — not visible in this file).
            boxed.train(&[x], x * 2.0);
        }
        assert_eq!(boxed.n_samples_seen(), 100);
        let pred = boxed.predict(&[5.0]);
        assert!(pred.is_finite());
        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }
}