// irithyll_core/ensemble/moe.rs
1//! Streaming Mixture of Experts over SGBT ensembles.
2//!
3//! Implements a gated mixture of K independent [`SGBT`] experts with a learned
4//! linear softmax gate. Each expert is a full streaming gradient boosted tree
5//! ensemble; the gate routes incoming samples to the most relevant expert(s)
6//! based on the feature vector, enabling capacity specialization across
7//! different regions of the input space.
8//!
9//! # Algorithm
10//!
11//! The gating network computes K logits `z_k = W_k · x + b_k` and applies
12//! softmax to obtain routing probabilities `p_k = softmax(z)_k`. Prediction
13//! is the probability-weighted sum of expert predictions:
14//!
15//! ```text
16//! ŷ = Σ_k p_k(x) · f_k(x)
17//! ```
18//!
19//! During training, the gate is updated via online SGD on the cross-entropy
20//! loss between the softmax distribution and the one-hot indicator of the
21//! best expert (lowest loss on the current sample). This encourages the gate
22//! to learn which expert is most competent for each region.
23//!
24//! Two gating modes are supported:
25//!
26//! - **Soft** (default): All experts receive every sample, weighted by their
27//! gating probability. This maximizes information flow but has O(K) training
28//! cost per sample.
29//! - **Hard (top-k)**: Only the top-k experts (by gating probability) receive
30//! the sample. This reduces computation when K is large, at the cost of
31//! slower expert specialization.
32//!
33//! # References
34//!
35//! - Jacobs, R. A., Jordan, M. I., Nowlan, S. J., & Hinton, G. E. (1991).
36//! Adaptive Mixtures of Local Experts. *Neural Computation*, 3(1), 79–87.
37//! - Shazeer, N., Mirhoseini, A., Maziarz, K., Davis, A., Le, Q., Hinton, G.,
38//! & Dean, J. (2017). Outrageously Large Neural Networks: The Sparsely-Gated
39//! Mixture-of-Experts Layer. *ICLR 2017*.
40//!
41//! # Example
42//!
43//! ```text
44//! use irithyll::ensemble::moe::{MoESGBT, GatingMode};
45//! use irithyll::SGBTConfig;
46//!
47//! let config = SGBTConfig::builder()
48//! .n_steps(10)
49//! .learning_rate(0.1)
50//! .grace_period(10)
51//! .build()
52//! .unwrap();
53//!
54//! let mut moe = MoESGBT::new(config, 3);
55//! moe.train_one(&irithyll::Sample::new(vec![1.0, 2.0], 3.0));
56//! let pred = moe.predict(&[1.0, 2.0]);
57//! ```
58
59use alloc::vec;
60use alloc::vec::Vec;
61
62use core::fmt;
63
64use crate::ensemble::config::SGBTConfig;
65use crate::ensemble::SGBT;
66use crate::loss::squared::SquaredLoss;
67use crate::loss::Loss;
68use crate::sample::{Observation, SampleRef};
69
70// ---------------------------------------------------------------------------
71// GatingMode
72// ---------------------------------------------------------------------------
73
/// Controls how the gate routes samples to experts.
///
/// - [`Soft`](GatingMode::Soft): every expert sees every sample, weighted by
///   gating probability. Maximizes information flow.
/// - [`Hard`](GatingMode::Hard): only the `top_k` experts with highest gating
///   probability receive the sample. Reduces cost when K is large.
//
// Copy/PartialEq/Eq are derived in addition to Debug/Clone: the enum is a
// small plain value, and equality lets callers compare configured modes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GatingMode {
    /// All experts receive every sample, weighted by gating probability.
    Soft,
    /// Only the top-k experts receive the sample (sparse routing).
    Hard {
        /// Number of experts to route each sample to.
        top_k: usize,
    },
}
90
91// ---------------------------------------------------------------------------
92// Softmax (numerically stable)
93// ---------------------------------------------------------------------------
94
95/// Numerically stable softmax: subtract max logit before exponentiating to
96/// prevent overflow, then normalize.
97pub(crate) fn softmax(logits: &[f64]) -> Vec<f64> {
98 let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
99 let exps: Vec<f64> = logits.iter().map(|&z| crate::math::exp(z - max)).collect();
100 let sum: f64 = exps.iter().sum();
101 exps.iter().map(|&e| e / sum).collect()
102}
103
104// ---------------------------------------------------------------------------
105// MoESGBT
106// ---------------------------------------------------------------------------
107
/// Streaming Mixture of Experts over SGBT ensembles.
///
/// Combines K independent [`SGBT<L>`] experts with a learned linear softmax
/// gating network. The gate is trained online via SGD to route samples to the
/// expert with the lowest loss, while all experts (or the top-k in hard gating
/// mode) are trained on each incoming sample.
///
/// Generic over `L: Loss` so the expert loss function is monomorphized. The
/// default is [`SquaredLoss`] for regression tasks.
///
/// # Gate Architecture
///
/// The gate is a single linear layer: `z_k = W_k · x + b_k` followed by
/// softmax. Weights are lazily initialized to zeros on the first sample
/// (since the feature dimensionality is not known at construction time).
/// The gate learns via cross-entropy gradient descent against the one-hot
/// indicator of the best expert per sample.
pub struct MoESGBT<L: Loss = SquaredLoss> {
    /// The K expert SGBT ensembles.
    experts: Vec<SGBT<L>>,
    /// Gate weight matrix [K x d], lazily initialized on first sample.
    /// Empty until the first call to `train_one` discovers `d`.
    gate_weights: Vec<Vec<f64>>,
    /// Gate bias vector [K]. Sized at construction (unlike the weights,
    /// which need the feature count).
    gate_bias: Vec<f64>,
    /// Learning rate for the gating network SGD updates.
    gate_lr: f64,
    /// Number of features (set on first sample, `None` until then).
    /// Doubles as the "gate initialized" flag.
    n_features: Option<usize>,
    /// Gating mode (soft or hard top-k).
    gating_mode: GatingMode,
    /// Configuration used to construct each expert.
    /// NOTE(review): within this file the field is only read by `Clone`
    /// after construction — confirm whether external code relies on it.
    config: SGBTConfig,
    /// Loss function (shared type with experts, used for best-expert selection).
    loss: L,
    /// Total training samples seen.
    samples_seen: u64,
}
145
146// ---------------------------------------------------------------------------
147// Clone
148// ---------------------------------------------------------------------------
149
// Manual field-by-field clone. NOTE(review): `#[derive(Clone)]` would impose
// the same `L: Clone` bound — confirm the derive was avoided deliberately
// (e.g. for bound hygiene on `SGBT<L>`) before simplifying.
impl<L: Loss + Clone> Clone for MoESGBT<L> {
    fn clone(&self) -> Self {
        Self {
            experts: self.experts.clone(),
            gate_weights: self.gate_weights.clone(),
            gate_bias: self.gate_bias.clone(),
            gate_lr: self.gate_lr,
            n_features: self.n_features,
            gating_mode: self.gating_mode.clone(),
            config: self.config.clone(),
            loss: self.loss.clone(),
            samples_seen: self.samples_seen,
        }
    }
}
165
166// ---------------------------------------------------------------------------
167// Debug
168// ---------------------------------------------------------------------------
169
170impl<L: Loss> fmt::Debug for MoESGBT<L> {
171 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172 f.debug_struct("MoESGBT")
173 .field("n_experts", &self.experts.len())
174 .field("gating_mode", &self.gating_mode)
175 .field("samples_seen", &self.samples_seen)
176 .finish()
177 }
178}
179
180// ---------------------------------------------------------------------------
181// Default loss constructor (SquaredLoss)
182// ---------------------------------------------------------------------------
183
impl MoESGBT<SquaredLoss> {
    /// Create a new MoE ensemble with squared loss (regression) and soft gating.
    ///
    /// Thin convenience wrapper around [`MoESGBT::with_loss`].
    /// Each expert is seeded uniquely via `config.seed ^ (0x0000_0E00_0000_0000 | i)`.
    /// The gating learning rate defaults to 0.01.
    ///
    /// # Panics
    ///
    /// Panics if `n_experts < 1`.
    pub fn new(config: SGBTConfig, n_experts: usize) -> Self {
        Self::with_loss(config, SquaredLoss, n_experts)
    }
}
197
198// ---------------------------------------------------------------------------
199// General impl
200// ---------------------------------------------------------------------------
201
202impl<L: Loss + Clone> MoESGBT<L> {
203 /// Create a new MoE ensemble with a custom loss and soft gating.
204 ///
205 /// # Panics
206 ///
207 /// Panics if `n_experts < 1`.
208 pub fn with_loss(config: SGBTConfig, loss: L, n_experts: usize) -> Self {
209 Self::with_gating(config, loss, n_experts, GatingMode::Soft, 0.01)
210 }
211
212 /// Create a new MoE ensemble with full control over gating mode and gate
213 /// learning rate.
214 ///
215 /// # Panics
216 ///
217 /// Panics if `n_experts < 1`.
218 pub fn with_gating(
219 config: SGBTConfig,
220 loss: L,
221 n_experts: usize,
222 gating_mode: GatingMode,
223 gate_lr: f64,
224 ) -> Self {
225 assert!(n_experts >= 1, "MoESGBT requires at least 1 expert");
226
227 let experts = (0..n_experts)
228 .map(|i| {
229 let mut cfg = config.clone();
230 cfg.seed = config.seed ^ (0x0000_0E00_0000_0000 | i as u64);
231 SGBT::with_loss(cfg, loss.clone())
232 })
233 .collect();
234
235 let gate_bias = vec![0.0; n_experts];
236
237 Self {
238 experts,
239 gate_weights: Vec::new(), // lazy init
240 gate_bias,
241 gate_lr,
242 n_features: None,
243 gating_mode,
244 config,
245 loss,
246 samples_seen: 0,
247 }
248 }
249}
250
251impl<L: Loss> MoESGBT<L> {
252 // -------------------------------------------------------------------
253 // Internal helpers
254 // -------------------------------------------------------------------
255
256 /// Ensure the gate weight matrix is initialized to the correct dimensions.
257 /// Called lazily on the first sample when `n_features` is discovered.
258 fn ensure_gate_init(&mut self, d: usize) {
259 if self.n_features.is_none() {
260 let k = self.experts.len();
261 self.gate_weights = vec![vec![0.0; d]; k];
262 self.n_features = Some(d);
263 }
264 }
265
266 /// Compute raw gate logits: z_k = W_k · x + b_k.
267 fn gate_logits(&self, features: &[f64]) -> Vec<f64> {
268 let k = self.experts.len();
269 let mut logits = Vec::with_capacity(k);
270 for i in 0..k {
271 let dot: f64 = self.gate_weights[i]
272 .iter()
273 .zip(features.iter())
274 .map(|(&w, &x)| w * x)
275 .sum();
276 logits.push(dot + self.gate_bias[i]);
277 }
278 logits
279 }
280
281 // -------------------------------------------------------------------
282 // Public API -- gating
283 // -------------------------------------------------------------------
284
285 /// Compute gating probabilities for a feature vector.
286 ///
287 /// Returns a vector of K probabilities that sum to 1.0, one per expert.
288 /// The gate must be initialized (at least one training sample seen),
289 /// otherwise returns uniform probabilities.
290 pub fn gating_probabilities(&self, features: &[f64]) -> Vec<f64> {
291 let k = self.experts.len();
292 if self.n_features.is_none() {
293 // Gate not initialized yet -- return uniform
294 return vec![1.0 / k as f64; k];
295 }
296 let logits = self.gate_logits(features);
297 softmax(&logits)
298 }
299
300 // -------------------------------------------------------------------
301 // Public API -- training
302 // -------------------------------------------------------------------
303
304 /// Train on a single observation.
305 ///
306 /// 1. Lazily initializes the gate weights if this is the first sample.
307 /// 2. Computes gating probabilities via softmax over the linear gate.
308 /// 3. Routes the sample to experts according to the gating mode:
309 /// - **Soft**: all experts receive the sample, each weighted by its
310 /// gating probability (via `SampleRef::weighted`).
311 /// - **Hard(top_k)**: only the top-k experts by probability receive
312 /// the sample (with unit weight).
313 /// 4. Updates gate weights via SGD on the cross-entropy gradient:
314 /// find the best expert (lowest loss), compute `dz_k = p_k - 1{k==best}`,
315 /// and apply `W_k -= gate_lr * dz_k * x`, `b_k -= gate_lr * dz_k`.
316 pub fn train_one(&mut self, sample: &impl Observation) {
317 let features = sample.features();
318 let target = sample.target();
319 let d = features.len();
320
321 // Step 1: lazy gate initialization
322 self.ensure_gate_init(d);
323
324 // Step 2: compute gating probabilities
325 let logits = self.gate_logits(features);
326 let probs = softmax(&logits);
327 let k = self.experts.len();
328
329 // Step 3: train experts based on gating mode
330 match &self.gating_mode {
331 GatingMode::Soft => {
332 // Every expert gets the sample, weighted by gating probability
333 for (expert, &prob) in self.experts.iter_mut().zip(probs.iter()) {
334 let weighted = SampleRef::weighted(features, target, prob);
335 expert.train_one(&weighted);
336 }
337 }
338 GatingMode::Hard { top_k } => {
339 // Only the top-k experts get the sample
340 let top_k = (*top_k).min(k);
341 let mut indices: Vec<usize> = (0..k).collect();
342 indices.sort_unstable_by(|&a, &b| {
343 probs[b]
344 .partial_cmp(&probs[a])
345 .unwrap_or(core::cmp::Ordering::Equal)
346 });
347 for &i in indices.iter().take(top_k) {
348 let obs = SampleRef::new(features, target);
349 self.experts[i].train_one(&obs);
350 }
351 }
352 }
353
354 // Step 4: update gate weights via SGD on cross-entropy gradient
355 // Find best expert (lowest loss on this sample)
356 let mut best_idx = 0;
357 let mut best_loss = f64::INFINITY;
358 for (i, expert) in self.experts.iter().enumerate() {
359 let pred = expert.predict(features);
360 let l = self.loss.loss(target, pred);
361 if l < best_loss {
362 best_loss = l;
363 best_idx = i;
364 }
365 }
366
367 // Cross-entropy gradient: dz_k = p_k - 1{k == best}
368 // SGD update: W_k -= lr * dz_k * x, b_k -= lr * dz_k
369 for (i, (weights_row, bias)) in self
370 .gate_weights
371 .iter_mut()
372 .zip(self.gate_bias.iter_mut())
373 .enumerate()
374 {
375 let indicator = if i == best_idx { 1.0 } else { 0.0 };
376 let grad = probs[i] - indicator;
377 let lr = self.gate_lr;
378
379 for (j, &xj) in features.iter().enumerate() {
380 weights_row[j] -= lr * grad * xj;
381 }
382 *bias -= lr * grad;
383 }
384
385 self.samples_seen += 1;
386 }
387
388 /// Train on a batch of observations.
389 pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
390 for sample in samples {
391 self.train_one(sample);
392 }
393 }
394
395 // -------------------------------------------------------------------
396 // Public API -- prediction
397 // -------------------------------------------------------------------
398
399 /// Predict the output for a feature vector.
400 ///
401 /// Computes the probability-weighted sum of expert predictions:
402 /// `ŷ = Σ_k p_k(x) · f_k(x)`.
403 pub fn predict(&self, features: &[f64]) -> f64 {
404 let probs = self.gating_probabilities(features);
405 let mut pred = 0.0;
406 for (i, &p) in probs.iter().enumerate() {
407 pred += p * self.experts[i].predict(features);
408 }
409 pred
410 }
411
412 /// Predict with gating probabilities returned alongside the prediction.
413 ///
414 /// Returns `(prediction, probabilities)` where probabilities is a K-length
415 /// vector summing to 1.0.
416 pub fn predict_with_gating(&self, features: &[f64]) -> (f64, Vec<f64>) {
417 let probs = self.gating_probabilities(features);
418 let mut pred = 0.0;
419 for (i, &p) in probs.iter().enumerate() {
420 pred += p * self.experts[i].predict(features);
421 }
422 (pred, probs)
423 }
424
425 /// Get each expert's individual prediction for a feature vector.
426 ///
427 /// Returns a K-length vector of raw predictions, one per expert.
428 pub fn expert_predictions(&self, features: &[f64]) -> Vec<f64> {
429 self.experts.iter().map(|e| e.predict(features)).collect()
430 }
431
432 // -------------------------------------------------------------------
433 // Public API -- inspection
434 // -------------------------------------------------------------------
435
436 /// Number of experts in the mixture.
437 #[inline]
438 pub fn n_experts(&self) -> usize {
439 self.experts.len()
440 }
441
442 /// Total training samples seen.
443 #[inline]
444 pub fn n_samples_seen(&self) -> u64 {
445 self.samples_seen
446 }
447
448 /// Immutable access to all experts.
449 pub fn experts(&self) -> &[SGBT<L>] {
450 &self.experts
451 }
452
453 /// Immutable access to a specific expert.
454 ///
455 /// # Panics
456 ///
457 /// Panics if `idx >= n_experts`.
458 pub fn expert(&self, idx: usize) -> &SGBT<L> {
459 &self.experts[idx]
460 }
461
462 /// Reset the entire MoE to its initial state.
463 ///
464 /// Resets all experts, clears gate weights and biases back to zeros,
465 /// and resets the sample counter.
466 pub fn reset(&mut self) {
467 for expert in &mut self.experts {
468 expert.reset();
469 }
470 let k = self.experts.len();
471 self.gate_weights.clear();
472 self.gate_bias = vec![0.0; k];
473 self.n_features = None;
474 self.samples_seen = 0;
475 }
476}
477
478// ---------------------------------------------------------------------------
479// StreamingLearner impl
480// ---------------------------------------------------------------------------
481
482use crate::learner::StreamingLearner;
483
impl<L: Loss> StreamingLearner for MoESGBT<L> {
    /// Adapter: wraps the raw slice/target/weight triple in a [`SampleRef`]
    /// and forwards to the inherent trainer.
    ///
    /// NOTE(review): the inherent `train_one` only reads `features()` and
    /// `target()` from the observation, so `weight` is currently dropped --
    /// confirm whether per-sample weights should scale the expert updates.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        let sample = SampleRef::weighted(features, target, weight);
        // UFCS: call the inherent train_one(&impl Observation), not this trait method.
        MoESGBT::train_one(self, &sample);
    }

    /// Forward to the inherent mixture prediction.
    fn predict(&self, features: &[f64]) -> f64 {
        MoESGBT::predict(self, features)
    }

    /// Total samples seen across all training calls.
    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Forward to the inherent reset (experts, gate, and counter).
    fn reset(&mut self) {
        MoESGBT::reset(self);
    }
}
503
504// ===========================================================================
505// Tests
506// ===========================================================================
507
#[cfg(test)]
mod tests {
    use super::*;
    use crate::loss::huber::HuberLoss;
    use crate::sample::Sample;
    use alloc::boxed::Box;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Helper: build a minimal config for tests.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .build()
            .unwrap()
    }

    #[test]
    fn test_creation() {
        let moe = MoESGBT::new(test_config(), 3);
        assert_eq!(moe.n_experts(), 3);
        assert_eq!(moe.n_samples_seen(), 0);
    }

    #[test]
    fn test_with_loss() {
        let moe = MoESGBT::with_loss(test_config(), HuberLoss { delta: 1.0 }, 4);
        assert_eq!(moe.n_experts(), 4);
        assert_eq!(moe.n_samples_seen(), 0);
    }

    #[test]
    fn test_soft_gating_trains_all() {
        let mut moe = MoESGBT::new(test_config(), 3);
        let sample = Sample::new(vec![1.0, 2.0], 5.0);

        moe.train_one(&sample);

        // In soft mode, every expert should have seen the sample
        for i in 0..3 {
            assert_eq!(moe.expert(i).n_samples_seen(), 1);
        }
    }

    #[test]
    fn test_hard_gating_top_k() {
        let mut moe = MoESGBT::with_gating(
            test_config(),
            SquaredLoss,
            4,
            GatingMode::Hard { top_k: 2 },
            0.01,
        );
        let sample = Sample::new(vec![1.0, 2.0], 5.0);

        moe.train_one(&sample);

        // Exactly top_k=2 experts should have received the sample.
        // (With a zero-initialized gate the probabilities are uniform, so
        // which 2 get picked is an arbitrary tie-break -- the count is what
        // matters.)
        let trained_count = (0..4)
            .filter(|&i| moe.expert(i).n_samples_seen() > 0)
            .count();
        assert_eq!(trained_count, 2);
    }

    #[test]
    fn test_gating_probabilities_sum_to_one() {
        let mut moe = MoESGBT::new(test_config(), 5);

        // Before training: uniform probabilities
        let probs = moe.gating_probabilities(&[1.0, 2.0]);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "pre-training sum = {}", sum);

        // After training: probabilities should still sum to 1
        for i in 0..20 {
            let sample = Sample::new(vec![i as f64, (i * 2) as f64], i as f64);
            moe.train_one(&sample);
        }
        let probs = moe.gating_probabilities(&[5.0, 10.0]);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "post-training sum = {}", sum);
    }

    #[test]
    fn test_prediction_changes_after_training() {
        let mut moe = MoESGBT::new(test_config(), 3);
        let features = vec![1.0, 2.0, 3.0];

        let pred_before = moe.predict(&features);

        for i in 0..50 {
            let sample = Sample::new(features.clone(), 10.0 + i as f64 * 0.1);
            moe.train_one(&sample);
        }

        let pred_after = moe.predict(&features);
        assert!(
            (pred_after - pred_before).abs() > 1e-6,
            "prediction should change after training: before={}, after={}",
            pred_before,
            pred_after
        );
    }

    #[test]
    fn test_expert_specialization() {
        // Two regions: x < 0 targets ~-10, x >= 0 targets ~+10
        let mut moe = MoESGBT::with_gating(test_config(), SquaredLoss, 2, GatingMode::Soft, 0.05);

        // Train with separable data (alternating negative/positive inputs)
        for i in 0..200 {
            let x = if i % 2 == 0 {
                -(i as f64 + 1.0)
            } else {
                i as f64 + 1.0
            };
            let target = if x < 0.0 { -10.0 } else { 10.0 };
            let sample = Sample::new(vec![x], target);
            moe.train_one(&sample);
        }

        // After training, the gating probabilities should differ for
        // negative vs positive inputs
        let probs_neg = moe.gating_probabilities(&[-5.0]);
        let probs_pos = moe.gating_probabilities(&[5.0]);

        // The dominant expert should be different (or at least the distributions
        // should be noticeably different)
        let diff: f64 = probs_neg
            .iter()
            .zip(probs_pos.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 0.01,
            "gate should route differently: neg={:?}, pos={:?}",
            probs_neg,
            probs_pos
        );
    }

    #[test]
    fn test_predict_with_gating() {
        let mut moe = MoESGBT::new(test_config(), 3);
        let sample = Sample::new(vec![1.0, 2.0], 5.0);
        moe.train_one(&sample);

        let (pred, probs) = moe.predict_with_gating(&[1.0, 2.0]);
        assert_eq!(probs.len(), 3);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);

        // Prediction should equal weighted sum of expert predictions
        let expert_preds = moe.expert_predictions(&[1.0, 2.0]);
        let expected: f64 = probs
            .iter()
            .zip(expert_preds.iter())
            .map(|(p, e)| p * e)
            .sum();
        assert!(
            (pred - expected).abs() < 1e-10,
            "pred={} expected={}",
            pred,
            expected
        );
    }

    #[test]
    fn test_expert_predictions() {
        let mut moe = MoESGBT::new(test_config(), 3);
        for i in 0..10 {
            let sample = Sample::new(vec![i as f64], i as f64);
            moe.train_one(&sample);
        }

        let preds = moe.expert_predictions(&[5.0]);
        assert_eq!(preds.len(), 3);
        // Each expert should produce a finite prediction
        for &p in &preds {
            assert!(p.is_finite(), "expert prediction should be finite: {}", p);
        }
    }

    #[test]
    fn test_n_experts() {
        let moe = MoESGBT::new(test_config(), 7);
        assert_eq!(moe.n_experts(), 7);
        assert_eq!(moe.experts().len(), 7);
    }

    #[test]
    fn test_n_samples_seen() {
        let mut moe = MoESGBT::new(test_config(), 2);
        assert_eq!(moe.n_samples_seen(), 0);

        for i in 0..25 {
            moe.train_one(&Sample::new(vec![i as f64], i as f64));
        }
        assert_eq!(moe.n_samples_seen(), 25);
    }

    #[test]
    fn test_reset() {
        let mut moe = MoESGBT::new(test_config(), 3);

        for i in 0..50 {
            moe.train_one(&Sample::new(vec![i as f64, (i * 2) as f64], i as f64));
        }
        assert_eq!(moe.n_samples_seen(), 50);

        moe.reset();

        assert_eq!(moe.n_samples_seen(), 0);
        assert_eq!(moe.n_experts(), 3);
        // Gate should be re-lazily-initialized
        let probs = moe.gating_probabilities(&[1.0, 2.0]);
        assert_eq!(probs.len(), 3);
        // After reset, probabilities are uniform again
        for &p in &probs {
            assert!(
                (p - 1.0 / 3.0).abs() < 1e-10,
                "expected uniform after reset, got {}",
                p
            );
        }
    }

    #[test]
    fn test_single_expert() {
        // With a single expert, MoE should behave like a plain SGBT
        let config = test_config();
        let mut moe = MoESGBT::new(config.clone(), 1);

        // Replicate the MoE's expert seeding (expert index 0) so both models
        // share identical RNG state.
        let mut plain = SGBT::new({
            let mut cfg = config.clone();
            cfg.seed = config.seed ^ 0x0000_0E00_0000_0000;
            cfg
        });

        // The single expert gets weight=1.0 always, so predictions should
        // be very close (both see same data, same seed)
        for i in 0..30 {
            let sample = Sample::new(vec![i as f64], i as f64 * 2.0);
            moe.train_one(&sample);
            // For the plain SGBT, we need to replicate the soft-gating weight.
            // With one expert, p=1.0, so SampleRef::weighted(features, target, 1.0)
            // is equivalent to a normal sample (weight=1.0).
            let weighted = SampleRef::weighted(&sample.features, sample.target, 1.0);
            plain.train_one(&weighted);
        }

        let moe_pred = moe.predict(&[15.0]);
        let plain_pred = plain.predict(&[15.0]);
        assert!(
            (moe_pred - plain_pred).abs() < 1e-6,
            "single expert MoE should match plain SGBT: moe={}, plain={}",
            moe_pred,
            plain_pred
        );
    }

    #[test]
    fn test_gate_lr_effect() {
        // A higher gate learning rate should cause the gate to diverge from
        // uniform faster than a lower one.
        let config = test_config();

        let mut moe_low =
            MoESGBT::with_gating(config.clone(), SquaredLoss, 3, GatingMode::Soft, 0.001);
        let mut moe_high = MoESGBT::with_gating(config, SquaredLoss, 3, GatingMode::Soft, 0.1);

        // Train both on the same data
        for i in 0..50 {
            let sample = Sample::new(vec![i as f64], i as f64);
            moe_low.train_one(&sample);
            moe_high.train_one(&sample);
        }

        // Measure deviation from uniform for both
        let uniform = 1.0 / 3.0;
        let probs_low = moe_low.gating_probabilities(&[25.0]);
        let probs_high = moe_high.gating_probabilities(&[25.0]);

        let dev_low: f64 = probs_low.iter().map(|p| (p - uniform).abs()).sum();
        let dev_high: f64 = probs_high.iter().map(|p| (p - uniform).abs()).sum();

        assert!(
            dev_high > dev_low,
            "higher gate_lr should cause more deviation from uniform: low={}, high={}",
            dev_low,
            dev_high
        );
    }

    #[test]
    fn test_batch_training() {
        let mut moe = MoESGBT::new(test_config(), 3);

        let samples: Vec<Sample> = (0..20)
            .map(|i| Sample::new(vec![i as f64, (i * 3) as f64], i as f64))
            .collect();

        moe.train_batch(&samples);

        assert_eq!(moe.n_samples_seen(), 20);

        // Should produce non-zero predictions after batch training
        let pred = moe.predict(&[10.0, 30.0]);
        assert!(pred.is_finite());
    }

    #[test]
    fn streaming_learner_trait_object() {
        // Exercise MoESGBT behind a `dyn StreamingLearner` to confirm the
        // trait impl is object-safe and delegates correctly.
        let config = test_config();
        let model = MoESGBT::new(config, 3);
        let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
        for i in 0..100 {
            let x = i as f64 * 0.1;
            boxed.train(&[x], x * 2.0);
        }
        assert_eq!(boxed.n_samples_seen(), 100);
        let pred = boxed.predict(&[5.0]);
        assert!(pred.is_finite());
        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }
}
836}