irithyll_core/ensemble/moe.rs

//! Streaming Mixture of Experts over SGBT ensembles.
//!
//! Implements a gated mixture of K independent [`SGBT`] experts with a learned
//! linear softmax gate. Each expert is a full streaming gradient boosted tree
//! ensemble; the gate routes incoming samples to the most relevant expert(s)
//! based on the feature vector, enabling capacity specialization across
//! different regions of the input space.
//!
//! # Algorithm
//!
//! The gating network computes K logits `z_k = W_k · x + b_k` and applies
//! softmax to obtain routing probabilities `p_k = softmax(z)_k`. Prediction
//! is the probability-weighted sum of expert predictions:
//!
//! ```text
//! ŷ = Σ_k  p_k(x) · f_k(x)
//! ```
//!
//! During training, the gate is updated via online SGD on the cross-entropy
//! loss between the softmax distribution and the one-hot indicator of the
//! best expert (lowest loss on the current sample). This encourages the gate
//! to learn which expert is most competent for each region.
//!
//! Two gating modes are supported:
//!
//! - **Soft** (default): All experts receive every sample, weighted by their
//!   gating probability. This maximizes information flow but has O(K) training
//!   cost per sample.
//! - **Hard (top-k)**: Only the top-k experts (by gating probability) receive
//!   the sample. This reduces computation when K is large, at the cost of
//!   slower expert specialization.
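//!
//! For example, with `K = 3` experts and gate logits `z = [0.2, 1.0, -0.5]`
//! (values rounded):
//!
//! ```text
//! softmax(z) ≈ [0.269, 0.598, 0.133]
//! ```
//!
//! Soft gating trains all three experts with these weights; hard gating with
//! `top_k = 1` routes the sample to the second expert (index 1) alone.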
//!
//! # References
//!
//! - Jacobs, R. A., Jordan, M. I., Nowlan, S. J., & Hinton, G. E. (1991).
//!   Adaptive Mixtures of Local Experts. *Neural Computation*, 3(1), 79–87.
//! - Shazeer, N., Mirhoseini, A., Maziarz, K., Davis, A., Le, Q., Hinton, G.,
//!   & Dean, J. (2017). Outrageously Large Neural Networks: The Sparsely-Gated
//!   Mixture-of-Experts Layer. *ICLR 2017*.
//!
//! # Example
//!
//! ```text
//! use irithyll::ensemble::moe::{MoESGBT, GatingMode};
//! use irithyll::SGBTConfig;
//!
//! let config = SGBTConfig::builder()
//!     .n_steps(10)
//!     .learning_rate(0.1)
//!     .grace_period(10)
//!     .build()
//!     .unwrap();
//!
//! let mut moe = MoESGBT::new(config, 3);
//! moe.train_one(&irithyll::Sample::new(vec![1.0, 2.0], 3.0));
//! let pred = moe.predict(&[1.0, 2.0]);
//! ```

use alloc::vec;
use alloc::vec::Vec;

use core::fmt;

use crate::ensemble::config::SGBTConfig;
use crate::ensemble::SGBT;
use crate::loss::squared::SquaredLoss;
use crate::loss::Loss;
use crate::sample::{Observation, SampleRef};

// ---------------------------------------------------------------------------
// GatingMode
// ---------------------------------------------------------------------------

/// Controls how the gate routes samples to experts.
///
/// - [`Soft`](GatingMode::Soft): every expert sees every sample, weighted by
///   gating probability. Maximizes information flow.
/// - [`Hard`](GatingMode::Hard): only the `top_k` experts with highest gating
///   probability receive the sample. Reduces cost when K is large.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum GatingMode {
    /// All experts receive every sample, weighted by gating probability.
    Soft,
    /// Only the top-k experts receive the sample (sparse routing).
    Hard {
        /// Number of experts to route each sample to.
        top_k: usize,
    },
}

// ---------------------------------------------------------------------------
// Softmax (numerically stable)
// ---------------------------------------------------------------------------

/// Numerically stable softmax: subtract max logit before exponentiating to
/// prevent overflow, then normalize.
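///
/// Thanks to the max subtraction, even huge logits stay finite. For example
/// (values rounded):
///
/// ```text
/// softmax(&[1000.0, 1001.0]) ≈ [0.269, 0.731]   // a naive exp(1000.0) would overflow
/// ```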
pub(crate) fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&z| crate::math::exp(z - max)).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

// ---------------------------------------------------------------------------
// MoESGBT
// ---------------------------------------------------------------------------

/// Streaming Mixture of Experts over SGBT ensembles.
///
/// Combines K independent [`SGBT<L>`] experts with a learned linear softmax
/// gating network. The gate is trained online via SGD to route samples to the
/// expert with the lowest loss, while all experts (or the top-k in hard gating
/// mode) are trained on each incoming sample.
///
/// Generic over `L: Loss` so the expert loss function is monomorphized. The
/// default is [`SquaredLoss`] for regression tasks.
///
/// # Gate Architecture
///
/// The gate is a single linear layer: `z_k = W_k · x + b_k` followed by
/// softmax. Weights are lazily initialized to zeros on the first sample
/// (since the feature dimensionality is not known at construction time).
/// The gate learns via cross-entropy gradient descent against the one-hot
/// indicator of the best expert per sample.
pub struct MoESGBT<L: Loss = SquaredLoss> {
    /// The K expert SGBT ensembles.
    experts: Vec<SGBT<L>>,
    /// Gate weight matrix [K x d], lazily initialized on first sample.
    gate_weights: Vec<Vec<f64>>,
    /// Gate bias vector [K].
    gate_bias: Vec<f64>,
    /// Learning rate for the gating network SGD updates.
    gate_lr: f64,
    /// Number of features (set on first sample, `None` until then).
    n_features: Option<usize>,
    /// Gating mode (soft or hard top-k).
    gating_mode: GatingMode,
    /// Configuration used to construct each expert.
    config: SGBTConfig,
    /// Loss function (shared type with experts, used for best-expert selection).
    loss: L,
    /// Total training samples seen.
    samples_seen: u64,
}

// ---------------------------------------------------------------------------
// Clone
// ---------------------------------------------------------------------------

impl<L: Loss + Clone> Clone for MoESGBT<L> {
    fn clone(&self) -> Self {
        Self {
            experts: self.experts.clone(),
            gate_weights: self.gate_weights.clone(),
            gate_bias: self.gate_bias.clone(),
            gate_lr: self.gate_lr,
            n_features: self.n_features,
            gating_mode: self.gating_mode.clone(),
            config: self.config.clone(),
            loss: self.loss.clone(),
            samples_seen: self.samples_seen,
        }
    }
}

// ---------------------------------------------------------------------------
// Debug
// ---------------------------------------------------------------------------

impl<L: Loss> fmt::Debug for MoESGBT<L> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MoESGBT")
            .field("n_experts", &self.experts.len())
            .field("gating_mode", &self.gating_mode)
            .field("samples_seen", &self.samples_seen)
            .finish()
    }
}

// ---------------------------------------------------------------------------
// Default loss constructor (SquaredLoss)
// ---------------------------------------------------------------------------

impl MoESGBT<SquaredLoss> {
    /// Create a new MoE ensemble with squared loss (regression) and soft gating.
    ///
    /// Each expert is seeded uniquely via `config.seed ^ (0x0000_0E00_0000_0000 | i)`.
    /// The gating learning rate defaults to 0.01.
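    /// For instance, `seed = 42` and `i = 2` give
    /// `42 ^ 0x0000_0E00_0000_0002 = 0x0000_0E00_0000_0028`.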
    ///
    /// # Panics
    ///
    /// Panics if `n_experts < 1`.
    pub fn new(config: SGBTConfig, n_experts: usize) -> Self {
        Self::with_loss(config, SquaredLoss, n_experts)
    }
}

// ---------------------------------------------------------------------------
// General impl
// ---------------------------------------------------------------------------

impl<L: Loss + Clone> MoESGBT<L> {
    /// Create a new MoE ensemble with a custom loss and soft gating.
    ///
    /// # Panics
    ///
    /// Panics if `n_experts < 1`.
    pub fn with_loss(config: SGBTConfig, loss: L, n_experts: usize) -> Self {
        Self::with_gating(config, loss, n_experts, GatingMode::Soft, 0.01)
    }

    /// Create a new MoE ensemble with full control over gating mode and gate
    /// learning rate.
    ///
    /// # Panics
    ///
    /// Panics if `n_experts < 1`.
    pub fn with_gating(
        config: SGBTConfig,
        loss: L,
        n_experts: usize,
        gating_mode: GatingMode,
        gate_lr: f64,
    ) -> Self {
        assert!(n_experts >= 1, "MoESGBT requires at least 1 expert");

        let experts = (0..n_experts)
            .map(|i| {
                let mut cfg = config.clone();
                cfg.seed = config.seed ^ (0x0000_0E00_0000_0000 | i as u64);
                SGBT::with_loss(cfg, loss.clone())
            })
            .collect();

        let gate_bias = vec![0.0; n_experts];

        Self {
            experts,
            gate_weights: Vec::new(), // lazy init
            gate_bias,
            gate_lr,
            n_features: None,
            gating_mode,
            config,
            loss,
            samples_seen: 0,
        }
    }
}

impl<L: Loss> MoESGBT<L> {
    // -------------------------------------------------------------------
    // Internal helpers
    // -------------------------------------------------------------------

    /// Ensure the gate weight matrix is initialized to the correct dimensions.
    /// Called lazily on the first sample when `n_features` is discovered.
    fn ensure_gate_init(&mut self, d: usize) {
        if self.n_features.is_none() {
            let k = self.experts.len();
            self.gate_weights = vec![vec![0.0; d]; k];
            self.n_features = Some(d);
        }
    }

    /// Compute raw gate logits: z_k = W_k · x + b_k.
    fn gate_logits(&self, features: &[f64]) -> Vec<f64> {
        let k = self.experts.len();
        let mut logits = Vec::with_capacity(k);
        for i in 0..k {
            let dot: f64 = self.gate_weights[i]
                .iter()
                .zip(features.iter())
                .map(|(&w, &x)| w * x)
                .sum();
            logits.push(dot + self.gate_bias[i]);
        }
        logits
    }

    // -------------------------------------------------------------------
    // Public API -- gating
    // -------------------------------------------------------------------

    /// Compute gating probabilities for a feature vector.
    ///
    /// Returns a vector of K probabilities that sum to 1.0, one per expert.
    /// If the gate has not been initialized yet (no training sample has been
    /// seen), uniform probabilities are returned.
    pub fn gating_probabilities(&self, features: &[f64]) -> Vec<f64> {
        let k = self.experts.len();
        if self.n_features.is_none() {
            // Gate not initialized yet -- return uniform
            return vec![1.0 / k as f64; k];
        }
        let logits = self.gate_logits(features);
        softmax(&logits)
    }

    // -------------------------------------------------------------------
    // Public API -- training
    // -------------------------------------------------------------------

    /// Train on a single observation.
    ///
    /// 1. Lazily initializes the gate weights if this is the first sample.
    /// 2. Computes gating probabilities via softmax over the linear gate.
    /// 3. Routes the sample to experts according to the gating mode:
    ///    - **Soft**: all experts receive the sample, each weighted by its
    ///      gating probability (via `SampleRef::weighted`).
    ///    - **Hard(top_k)**: only the top-k experts by probability receive
    ///      the sample (with unit weight).
    /// 4. Updates gate weights via SGD on the cross-entropy gradient:
    ///    find the best expert (lowest loss), compute `dz_k = p_k - 1{k==best}`,
    ///    and apply `W_k -= gate_lr * dz_k * x`, `b_k -= gate_lr * dz_k`.
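    ///
    /// A small worked instance of step 4 (illustrative numbers only), with
    /// `gate_lr = 0.01` and a single feature `x = [2.0]`:
    ///
    /// ```text
    /// K = 2,  p = [0.6, 0.4],  best expert = 1
    /// dz = p - onehot(1) = [0.6, -0.6]
    /// W_0[0] -= 0.01 * 0.6 * 2.0    // W_0 decreases by 0.012
    /// W_1[0] -= 0.01 * -0.6 * 2.0   // W_1 increases by 0.012
    /// ```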
    pub fn train_one(&mut self, sample: &impl Observation) {
        let features = sample.features();
        let target = sample.target();
        let d = features.len();

        // Step 1: lazy gate initialization
        self.ensure_gate_init(d);

        // Step 2: compute gating probabilities
        let logits = self.gate_logits(features);
        let probs = softmax(&logits);
        let k = self.experts.len();

        // Step 3: train experts based on gating mode
        match &self.gating_mode {
            GatingMode::Soft => {
                // Every expert gets the sample, weighted by gating probability
                for (expert, &prob) in self.experts.iter_mut().zip(probs.iter()) {
                    let weighted = SampleRef::weighted(features, target, prob);
                    expert.train_one(&weighted);
                }
            }
            GatingMode::Hard { top_k } => {
                // Only the top-k experts get the sample
                let top_k = (*top_k).min(k);
                let mut indices: Vec<usize> = (0..k).collect();
                indices.sort_unstable_by(|&a, &b| {
                    probs[b]
                        .partial_cmp(&probs[a])
                        .unwrap_or(core::cmp::Ordering::Equal)
                });
                for &i in indices.iter().take(top_k) {
                    let obs = SampleRef::new(features, target);
                    self.experts[i].train_one(&obs);
                }
            }
        }

        // Step 4: update gate weights via SGD on cross-entropy gradient
        // Find best expert (lowest loss on this sample)
        let mut best_idx = 0;
        let mut best_loss = f64::INFINITY;
        for (i, expert) in self.experts.iter().enumerate() {
            let pred = expert.predict(features);
            let l = self.loss.loss(target, pred);
            if l < best_loss {
                best_loss = l;
                best_idx = i;
            }
        }

        // Cross-entropy gradient: dz_k = p_k - 1{k == best}
        // SGD update: W_k -= lr * dz_k * x,  b_k -= lr * dz_k
        for (i, (weights_row, bias)) in self
            .gate_weights
            .iter_mut()
            .zip(self.gate_bias.iter_mut())
            .enumerate()
        {
            let indicator = if i == best_idx { 1.0 } else { 0.0 };
            let grad = probs[i] - indicator;
            let lr = self.gate_lr;

            for (j, &xj) in features.iter().enumerate() {
                weights_row[j] -= lr * grad * xj;
            }
            *bias -= lr * grad;
        }

        self.samples_seen += 1;
    }

    /// Train on a batch of observations.
    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
        for sample in samples {
            self.train_one(sample);
        }
    }

    // -------------------------------------------------------------------
    // Public API -- prediction
    // -------------------------------------------------------------------

    /// Predict the output for a feature vector.
    ///
    /// Computes the probability-weighted sum of expert predictions:
    /// `ŷ = Σ_k p_k(x) · f_k(x)`.
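    ///
    /// For example, with `p = [0.5, 0.3, 0.2]` and expert predictions
    /// `[1.0, 2.0, 3.0]`, the result is `0.5 + 0.6 + 0.6 = 1.7`.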
    pub fn predict(&self, features: &[f64]) -> f64 {
        let probs = self.gating_probabilities(features);
        let mut pred = 0.0;
        for (i, &p) in probs.iter().enumerate() {
            pred += p * self.experts[i].predict(features);
        }
        pred
    }

    /// Predict with gating probabilities returned alongside the prediction.
    ///
    /// Returns `(prediction, probabilities)` where probabilities is a K-length
    /// vector summing to 1.0.
    pub fn predict_with_gating(&self, features: &[f64]) -> (f64, Vec<f64>) {
        let probs = self.gating_probabilities(features);
        let mut pred = 0.0;
        for (i, &p) in probs.iter().enumerate() {
            pred += p * self.experts[i].predict(features);
        }
        (pred, probs)
    }

    /// Get each expert's individual prediction for a feature vector.
    ///
    /// Returns a K-length vector of raw predictions, one per expert.
    pub fn expert_predictions(&self, features: &[f64]) -> Vec<f64> {
        self.experts.iter().map(|e| e.predict(features)).collect()
    }

    // -------------------------------------------------------------------
    // Public API -- inspection
    // -------------------------------------------------------------------

    /// Number of experts in the mixture.
    #[inline]
    pub fn n_experts(&self) -> usize {
        self.experts.len()
    }

    /// Total training samples seen.
    #[inline]
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Immutable access to all experts.
    pub fn experts(&self) -> &[SGBT<L>] {
        &self.experts
    }

    /// Immutable access to a specific expert.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= n_experts`.
    pub fn expert(&self, idx: usize) -> &SGBT<L> {
        &self.experts[idx]
    }

    /// Reset the entire MoE to its initial state.
    ///
    /// Resets all experts, clears gate weights and biases back to zeros,
    /// and resets the sample counter.
    pub fn reset(&mut self) {
        for expert in &mut self.experts {
            expert.reset();
        }
        let k = self.experts.len();
        self.gate_weights.clear();
        self.gate_bias = vec![0.0; k];
        self.n_features = None;
        self.samples_seen = 0;
    }
}

// ---------------------------------------------------------------------------
// StreamingLearner impl
// ---------------------------------------------------------------------------

use crate::learner::StreamingLearner;

impl<L: Loss> StreamingLearner for MoESGBT<L> {
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        let sample = SampleRef::weighted(features, target, weight);
        // UFCS: call the inherent train_one(&impl Observation), not this trait method.
        MoESGBT::train_one(self, &sample);
    }

    fn predict(&self, features: &[f64]) -> f64 {
        MoESGBT::predict(self, features)
    }

    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    fn reset(&mut self) {
        MoESGBT::reset(self);
    }
}

// ===========================================================================
// Tests
// ===========================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::loss::huber::HuberLoss;
    use crate::sample::Sample;
    use alloc::boxed::Box;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Helper: build a minimal config for tests.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .build()
            .unwrap()
    }
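
    /// Sanity check for the stable softmax helper: equal logits give a
    /// uniform distribution, and very large logits stay finite thanks to the
    /// max-subtraction trick.
    #[test]
    fn test_softmax_stable_and_normalized() {
        let p = softmax(&[0.0, 0.0, 0.0, 0.0]);
        for &pi in &p {
            assert!((pi - 0.25).abs() < 1e-12, "expected uniform, got {}", pi);
        }

        // A naive exp(1000.0) would overflow to infinity.
        let p = softmax(&[1000.0, 1001.0]);
        assert!(p.iter().all(|x| x.is_finite()), "large logits overflowed");
        let sum: f64 = p.iter().sum();
        assert!((sum - 1.0).abs() < 1e-12, "probabilities must sum to 1");
    }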

    #[test]
    fn test_creation() {
        let moe = MoESGBT::new(test_config(), 3);
        assert_eq!(moe.n_experts(), 3);
        assert_eq!(moe.n_samples_seen(), 0);
    }

    #[test]
    fn test_with_loss() {
        let moe = MoESGBT::with_loss(test_config(), HuberLoss { delta: 1.0 }, 4);
        assert_eq!(moe.n_experts(), 4);
        assert_eq!(moe.n_samples_seen(), 0);
    }

    #[test]
    fn test_soft_gating_trains_all() {
        let mut moe = MoESGBT::new(test_config(), 3);
        let sample = Sample::new(vec![1.0, 2.0], 5.0);

        moe.train_one(&sample);

        // In soft mode, every expert should have seen the sample
        for i in 0..3 {
            assert_eq!(moe.expert(i).n_samples_seen(), 1);
        }
    }

    #[test]
    fn test_hard_gating_top_k() {
        let mut moe = MoESGBT::with_gating(
            test_config(),
            SquaredLoss,
            4,
            GatingMode::Hard { top_k: 2 },
            0.01,
        );
        let sample = Sample::new(vec![1.0, 2.0], 5.0);

        moe.train_one(&sample);

        // Exactly top_k=2 experts should have received the sample
        let trained_count = (0..4)
            .filter(|&i| moe.expert(i).n_samples_seen() > 0)
            .count();
        assert_eq!(trained_count, 2);
    }

    #[test]
    fn test_gating_probabilities_sum_to_one() {
        let mut moe = MoESGBT::new(test_config(), 5);

        // Before training: uniform probabilities
        let probs = moe.gating_probabilities(&[1.0, 2.0]);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "pre-training sum = {}", sum);

        // After training: probabilities should still sum to 1
        for i in 0..20 {
            let sample = Sample::new(vec![i as f64, (i * 2) as f64], i as f64);
            moe.train_one(&sample);
        }
        let probs = moe.gating_probabilities(&[5.0, 10.0]);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "post-training sum = {}", sum);
    }

    #[test]
    fn test_prediction_changes_after_training() {
        let mut moe = MoESGBT::new(test_config(), 3);
        let features = vec![1.0, 2.0, 3.0];

        let pred_before = moe.predict(&features);

        for i in 0..50 {
            let sample = Sample::new(features.clone(), 10.0 + i as f64 * 0.1);
            moe.train_one(&sample);
        }

        let pred_after = moe.predict(&features);
        assert!(
            (pred_after - pred_before).abs() > 1e-6,
            "prediction should change after training: before={}, after={}",
            pred_before,
            pred_after
        );
    }

    #[test]
    fn test_expert_specialization() {
        // Two regions: x < 0 targets ~-10, x >= 0 targets ~+10
        let mut moe = MoESGBT::with_gating(test_config(), SquaredLoss, 2, GatingMode::Soft, 0.05);

        // Train with separable data
        for i in 0..200 {
            let x = if i % 2 == 0 {
                -(i as f64 + 1.0)
            } else {
                i as f64 + 1.0
            };
            let target = if x < 0.0 { -10.0 } else { 10.0 };
            let sample = Sample::new(vec![x], target);
            moe.train_one(&sample);
        }

        // After training, the gating probabilities should differ for
        // negative vs positive inputs
        let probs_neg = moe.gating_probabilities(&[-5.0]);
        let probs_pos = moe.gating_probabilities(&[5.0]);

        // The dominant expert should be different (or at least the distributions
        // should be noticeably different)
        let diff: f64 = probs_neg
            .iter()
            .zip(probs_pos.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 0.01,
            "gate should route differently: neg={:?}, pos={:?}",
            probs_neg,
            probs_pos
        );
    }

    #[test]
    fn test_predict_with_gating() {
        let mut moe = MoESGBT::new(test_config(), 3);
        let sample = Sample::new(vec![1.0, 2.0], 5.0);
        moe.train_one(&sample);

        let (pred, probs) = moe.predict_with_gating(&[1.0, 2.0]);
        assert_eq!(probs.len(), 3);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);

        // Prediction should equal weighted sum of expert predictions
        let expert_preds = moe.expert_predictions(&[1.0, 2.0]);
        let expected: f64 = probs
            .iter()
            .zip(expert_preds.iter())
            .map(|(p, e)| p * e)
            .sum();
        assert!(
            (pred - expected).abs() < 1e-10,
            "pred={} expected={}",
            pred,
            expected
        );
    }

    #[test]
    fn test_expert_predictions() {
        let mut moe = MoESGBT::new(test_config(), 3);
        for i in 0..10 {
            let sample = Sample::new(vec![i as f64], i as f64);
            moe.train_one(&sample);
        }

        let preds = moe.expert_predictions(&[5.0]);
        assert_eq!(preds.len(), 3);
        // Each expert should produce a finite prediction
        for &p in &preds {
            assert!(p.is_finite(), "expert prediction should be finite: {}", p);
        }
    }

    #[test]
    fn test_n_experts() {
        let moe = MoESGBT::new(test_config(), 7);
        assert_eq!(moe.n_experts(), 7);
        assert_eq!(moe.experts().len(), 7);
    }

    #[test]
    fn test_n_samples_seen() {
        let mut moe = MoESGBT::new(test_config(), 2);
        assert_eq!(moe.n_samples_seen(), 0);

        for i in 0..25 {
            moe.train_one(&Sample::new(vec![i as f64], i as f64));
        }
        assert_eq!(moe.n_samples_seen(), 25);
    }

    #[test]
    fn test_reset() {
        let mut moe = MoESGBT::new(test_config(), 3);

        for i in 0..50 {
            moe.train_one(&Sample::new(vec![i as f64, (i * 2) as f64], i as f64));
        }
        assert_eq!(moe.n_samples_seen(), 50);

        moe.reset();

        assert_eq!(moe.n_samples_seen(), 0);
        assert_eq!(moe.n_experts(), 3);
        // Gate should be re-lazily-initialized
        let probs = moe.gating_probabilities(&[1.0, 2.0]);
        assert_eq!(probs.len(), 3);
        // After reset, probabilities are uniform again
        for &p in &probs {
            assert!(
                (p - 1.0 / 3.0).abs() < 1e-10,
                "expected uniform after reset, got {}",
                p
            );
        }
    }

    #[test]
    fn test_single_expert() {
        // With a single expert, MoE should behave like a plain SGBT
        let config = test_config();
        let mut moe = MoESGBT::new(config.clone(), 1);

        let mut plain = SGBT::new({
            let mut cfg = config.clone();
            cfg.seed = config.seed ^ 0x0000_0E00_0000_0000;
            cfg
        });

        // The single expert gets weight=1.0 always, so predictions should
        // be very close (both see same data, same seed)
        for i in 0..30 {
            let sample = Sample::new(vec![i as f64], i as f64 * 2.0);
            moe.train_one(&sample);
            // For the plain SGBT, we need to replicate the soft-gating weight.
            // With one expert, p=1.0, so SampleRef::weighted(features, target, 1.0)
            // is equivalent to a normal sample (weight=1.0).
            let weighted = SampleRef::weighted(&sample.features, sample.target, 1.0);
            plain.train_one(&weighted);
        }

        let moe_pred = moe.predict(&[15.0]);
        let plain_pred = plain.predict(&[15.0]);
        assert!(
            (moe_pred - plain_pred).abs() < 1e-6,
            "single expert MoE should match plain SGBT: moe={}, plain={}",
            moe_pred,
            plain_pred
        );
    }

    #[test]
    fn test_gate_lr_effect() {
        // A higher gate learning rate should cause the gate to diverge from
        // uniform faster than a lower one.
        let config = test_config();

        let mut moe_low =
            MoESGBT::with_gating(config.clone(), SquaredLoss, 3, GatingMode::Soft, 0.001);
        let mut moe_high = MoESGBT::with_gating(config, SquaredLoss, 3, GatingMode::Soft, 0.1);

        // Train both on the same data
        for i in 0..50 {
            let sample = Sample::new(vec![i as f64], i as f64);
            moe_low.train_one(&sample);
            moe_high.train_one(&sample);
        }

        // Measure deviation from uniform for both
        let uniform = 1.0 / 3.0;
        let probs_low = moe_low.gating_probabilities(&[25.0]);
        let probs_high = moe_high.gating_probabilities(&[25.0]);

        let dev_low: f64 = probs_low.iter().map(|p| (p - uniform).abs()).sum();
        let dev_high: f64 = probs_high.iter().map(|p| (p - uniform).abs()).sum();

        assert!(
            dev_high > dev_low,
            "higher gate_lr should cause more deviation from uniform: low={}, high={}",
            dev_low,
            dev_high
        );
    }

    #[test]
    fn test_batch_training() {
        let mut moe = MoESGBT::new(test_config(), 3);

        let samples: Vec<Sample> = (0..20)
            .map(|i| Sample::new(vec![i as f64, (i * 3) as f64], i as f64))
            .collect();

        moe.train_batch(&samples);

        assert_eq!(moe.n_samples_seen(), 20);

        // Should produce a finite prediction after batch training
        let pred = moe.predict(&[10.0, 30.0]);
        assert!(pred.is_finite());
    }

    #[test]
    fn streaming_learner_trait_object() {
        let config = test_config();
        let model = MoESGBT::new(config, 3);
        let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
        for i in 0..100 {
            let x = i as f64 * 0.1;
            boxed.train(&[x], x * 2.0);
        }
        assert_eq!(boxed.n_samples_seen(), 100);
        let pred = boxed.predict(&[5.0]);
        assert!(pred.is_finite());
        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }
}