irithyll_core/ensemble/moe.rs
//! Streaming Mixture of Experts over SGBT ensembles.
//!
//! Implements a gated mixture of K independent [`SGBT`] experts with a learned
//! linear softmax gate. Each expert is a full streaming gradient boosted tree
//! ensemble; the gate routes incoming samples to the most relevant expert(s)
//! based on the feature vector, enabling capacity specialization across
//! different regions of the input space.
//!
//! # Algorithm
//!
//! The gating network computes K logits `z_k = W_k · x + b_k` and applies
//! softmax to obtain routing probabilities `p_k = softmax(z)_k`. Prediction
//! is the probability-weighted sum of expert predictions:
//!
//! ```text
//! ŷ = Σ_k p_k(x) · f_k(x)
//! ```
//!
//! During training, the gate is updated via online SGD on the cross-entropy
//! loss between the softmax distribution and the one-hot indicator of the
//! best expert (lowest loss on the current sample). This encourages the gate
//! to learn which expert is most competent for each region.
//!
//! Two gating modes are supported:
//!
//! - **Soft** (default): All experts receive every sample, weighted by their
//!   gating probability. This maximizes information flow but has O(K) training
//!   cost per sample.
//! - **Hard (top-k)**: Only the top-k experts (by gating probability) receive
//!   the sample. This reduces computation when K is large, at the cost of
//!   slower expert specialization.
//!
//! # References
//!
//! - Jacobs, R. A., Jordan, M. I., Nowlan, S. J., & Hinton, G. E. (1991).
//!   Adaptive Mixtures of Local Experts. *Neural Computation*, 3(1), 79–87.
//! - Shazeer, N., Mirhoseini, A., Maziarz, K., Davis, A., Le, Q., Hinton, G.,
//!   & Dean, J. (2017). Outrageously Large Neural Networks: The Sparsely-Gated
//!   Mixture-of-Experts Layer. *ICLR 2017*.
//!
//! # Example
//!
//! ```text
//! use irithyll::ensemble::moe::{MoESGBT, GatingMode};
//! use irithyll::SGBTConfig;
//!
//! let config = SGBTConfig::builder()
//!     .n_steps(10)
//!     .learning_rate(0.1)
//!     .grace_period(10)
//!     .build()
//!     .unwrap();
//!
//! let mut moe = MoESGBT::new(config, 3);
//! moe.train_one(&irithyll::Sample::new(vec![1.0, 2.0], 3.0));
//! let pred = moe.predict(&[1.0, 2.0]);
//! ```

use alloc::vec;
use alloc::vec::Vec;

use core::fmt;

use crate::ensemble::config::SGBTConfig;
use crate::ensemble::SGBT;
use crate::loss::squared::SquaredLoss;
use crate::loss::Loss;
use crate::sample::{Observation, SampleRef};

// ---------------------------------------------------------------------------
// GatingMode
// ---------------------------------------------------------------------------

/// Controls how the gate routes samples to experts.
///
/// - [`Soft`](GatingMode::Soft): every expert sees every sample, weighted by
///   gating probability. Maximizes information flow.
/// - [`Hard`](GatingMode::Hard): only the `top_k` experts with highest gating
///   probability receive the sample. Reduces cost when K is large.
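///
/// # Example
///
/// A minimal sketch of choosing a routing mode:
///
/// ```text
/// // Dense routing: every expert is trained on every sample.
/// let dense = GatingMode::Soft;
///
/// // Sparse routing: only the 2 most probable experts are trained per sample.
/// let sparse = GatingMode::Hard { top_k: 2 };
/// ```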
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum GatingMode {
    /// All experts receive every sample, weighted by gating probability.
    Soft,
    /// Only the top-k experts receive the sample (sparse routing).
    Hard {
        /// Number of experts to route each sample to.
        top_k: usize,
    },
}

// ---------------------------------------------------------------------------
// Softmax (numerically stable)
// ---------------------------------------------------------------------------

/// Numerically stable softmax: subtract the max logit before exponentiating to
/// prevent overflow, then normalize.
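///
/// For example, `softmax(&[0.0, 0.0, 0.0])` is the uniform vector
/// `[1/3, 1/3, 1/3]`, and shifting every logit by the same constant leaves the
/// result unchanged.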
pub(crate) fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&z| crate::math::exp(z - max)).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

// ---------------------------------------------------------------------------
// MoESGBT
// ---------------------------------------------------------------------------

/// Streaming Mixture of Experts over SGBT ensembles.
///
/// Combines K independent [`SGBT<L>`] experts with a learned linear softmax
/// gating network. The gate is trained online via SGD to route samples to the
/// expert with the lowest loss, while all experts (or the top-k in hard gating
/// mode) are trained on each incoming sample.
///
/// Generic over `L: Loss` so the expert loss function is monomorphized. The
/// default is [`SquaredLoss`] for regression tasks.
///
/// # Gate Architecture
///
/// The gate is a single linear layer: `z_k = W_k · x + b_k` followed by
/// softmax. Weights are lazily initialized to zeros on the first sample
/// (since the feature dimensionality is not known at construction time).
/// The gate learns via cross-entropy gradient descent against the one-hot
/// indicator of the best expert per sample.
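///
/// With zero-initialized weights every logit is zero, so routing starts
/// uniform at `1/K` and only drifts as the gate observes which expert wins on
/// which inputs. A worked sketch for K = 2 and `x = [1.0, 2.0]`:
///
/// ```text
/// z = [W_0 · x + b_0, W_1 · x + b_1] = [0.0, 0.0]
/// p = softmax(z) = [0.5, 0.5]
/// ŷ = 0.5 · f_0(x) + 0.5 · f_1(x)
/// ```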
pub struct MoESGBT<L: Loss = SquaredLoss> {
    /// The K expert SGBT ensembles.
    experts: Vec<SGBT<L>>,
    /// Gate weight matrix [K x d], lazily initialized on first sample.
    gate_weights: Vec<Vec<f64>>,
    /// Gate bias vector [K].
    gate_bias: Vec<f64>,
    /// Learning rate for the gating network SGD updates.
    gate_lr: f64,
    /// Number of features (set on first sample, `None` until then).
    n_features: Option<usize>,
    /// Gating mode (soft or hard top-k).
    gating_mode: GatingMode,
    /// Configuration used to construct each expert.
    config: SGBTConfig,
    /// Loss function (shared type with experts, used for best-expert selection).
    loss: L,
    /// Total training samples seen.
    samples_seen: u64,
}

// ---------------------------------------------------------------------------
// Clone
// ---------------------------------------------------------------------------

impl<L: Loss + Clone> Clone for MoESGBT<L> {
    fn clone(&self) -> Self {
        Self {
            experts: self.experts.clone(),
            gate_weights: self.gate_weights.clone(),
            gate_bias: self.gate_bias.clone(),
            gate_lr: self.gate_lr,
            n_features: self.n_features,
            gating_mode: self.gating_mode.clone(),
            config: self.config.clone(),
            loss: self.loss.clone(),
            samples_seen: self.samples_seen,
        }
    }
}

// ---------------------------------------------------------------------------
// Debug
// ---------------------------------------------------------------------------

impl<L: Loss> fmt::Debug for MoESGBT<L> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MoESGBT")
            .field("n_experts", &self.experts.len())
            .field("gating_mode", &self.gating_mode)
            .field("samples_seen", &self.samples_seen)
            .finish()
    }
}

// ---------------------------------------------------------------------------
// Default loss constructor (SquaredLoss)
// ---------------------------------------------------------------------------

impl MoESGBT<SquaredLoss> {
    /// Create a new MoE ensemble with squared loss (regression) and soft gating.
    ///
    /// Each expert is seeded uniquely via `config.seed ^ (0x0000_0E00_0000_0000 | i)`.
    /// The gating learning rate defaults to 0.01.
    ///
    /// # Panics
    ///
    /// Panics if `n_experts < 1`.
    pub fn new(config: SGBTConfig, n_experts: usize) -> Self {
        Self::with_loss(config, SquaredLoss, n_experts)
    }
}

// ---------------------------------------------------------------------------
// General impl
// ---------------------------------------------------------------------------

impl<L: Loss + Clone> MoESGBT<L> {
    /// Create a new MoE ensemble with a custom loss and soft gating.
    ///
    /// # Panics
    ///
    /// Panics if `n_experts < 1`.
    pub fn with_loss(config: SGBTConfig, loss: L, n_experts: usize) -> Self {
        Self::with_gating(config, loss, n_experts, GatingMode::Soft, 0.01)
    }

    /// Create a new MoE ensemble with full control over gating mode and gate
    /// learning rate.
    ///
    /// # Panics
    ///
    /// Panics if `n_experts < 1`.
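    ///
    /// # Example
    ///
    /// A minimal sketch using hard top-k routing; `config` is built as in the
    /// module-level example:
    ///
    /// ```text
    /// let moe = MoESGBT::with_gating(
    ///     config,
    ///     SquaredLoss,
    ///     8,                              // K experts
    ///     GatingMode::Hard { top_k: 2 },  // train only the 2 most probable
    ///     0.01,                           // gate learning rate
    /// );
    /// ```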
    pub fn with_gating(
        config: SGBTConfig,
        loss: L,
        n_experts: usize,
        gating_mode: GatingMode,
        gate_lr: f64,
    ) -> Self {
        assert!(n_experts >= 1, "MoESGBT requires at least 1 expert");

        let experts = (0..n_experts)
            .map(|i| {
                let mut cfg = config.clone();
                cfg.seed = config.seed ^ (0x0000_0E00_0000_0000 | i as u64);
                SGBT::with_loss(cfg, loss.clone())
            })
            .collect();

        let gate_bias = vec![0.0; n_experts];

        Self {
            experts,
            gate_weights: Vec::new(), // lazy init
            gate_bias,
            gate_lr,
            n_features: None,
            gating_mode,
            config,
            loss,
            samples_seen: 0,
        }
    }
}

impl<L: Loss> MoESGBT<L> {
    // -------------------------------------------------------------------
    // Internal helpers
    // -------------------------------------------------------------------

    /// Ensure the gate weight matrix is initialized to the correct dimensions.
    /// Called lazily on the first sample when `n_features` is discovered.
    fn ensure_gate_init(&mut self, d: usize) {
        if self.n_features.is_none() {
            let k = self.experts.len();
            self.gate_weights = vec![vec![0.0; d]; k];
            self.n_features = Some(d);
        }
    }

    /// Compute raw gate logits: `z_k = W_k · x + b_k`.
    fn gate_logits(&self, features: &[f64]) -> Vec<f64> {
        let k = self.experts.len();
        let mut logits = Vec::with_capacity(k);
        for i in 0..k {
            let dot: f64 = self.gate_weights[i]
                .iter()
                .zip(features.iter())
                .map(|(&w, &x)| w * x)
                .sum();
            logits.push(dot + self.gate_bias[i]);
        }
        logits
    }

    // -------------------------------------------------------------------
    // Public API -- gating
    // -------------------------------------------------------------------

    /// Compute gating probabilities for a feature vector.
    ///
    /// Returns a vector of K probabilities that sum to 1.0, one per expert.
    /// If the gate has not been initialized yet (no training samples seen),
    /// uniform probabilities are returned.
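    ///
    /// # Example
    ///
    /// A quick sketch; before any training the gate is uninitialized, so the
    /// distribution is uniform (`config` as in the module-level example):
    ///
    /// ```text
    /// let moe = MoESGBT::new(config, 4);
    /// assert_eq!(moe.gating_probabilities(&[1.0, 2.0]), vec![0.25; 4]);
    /// ```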
    pub fn gating_probabilities(&self, features: &[f64]) -> Vec<f64> {
        let k = self.experts.len();
        if self.n_features.is_none() {
            // Gate not initialized yet -- return uniform
            return vec![1.0 / k as f64; k];
        }
        let logits = self.gate_logits(features);
        softmax(&logits)
    }

    // -------------------------------------------------------------------
    // Public API -- training
    // -------------------------------------------------------------------

    /// Train on a single observation.
    ///
    /// 1. Lazily initializes the gate weights if this is the first sample.
    /// 2. Computes gating probabilities via softmax over the linear gate.
    /// 3. Routes the sample to experts according to the gating mode:
    ///    - **Soft**: all experts receive the sample, each weighted by its
    ///      gating probability (via `SampleRef::weighted`).
    ///    - **Hard(top_k)**: only the top-k experts by probability receive
    ///      the sample (with unit weight).
    /// 4. Updates gate weights via SGD on the cross-entropy gradient:
    ///    find the best expert (lowest loss), compute `dz_k = p_k - 1{k==best}`,
    ///    and apply `W_k -= gate_lr * dz_k * x`, `b_k -= gate_lr * dz_k`.
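    ///
    /// # Example
    ///
    /// A minimal training-loop sketch (`config` as in the module-level example):
    ///
    /// ```text
    /// let mut moe = MoESGBT::new(config, 3);
    /// for (x, y) in [(vec![0.0, 1.0], 1.0), (vec![2.0, 3.0], 5.0)] {
    ///     moe.train_one(&irithyll::Sample::new(x, y));
    /// }
    /// assert_eq!(moe.n_samples_seen(), 2);
    /// ```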
    pub fn train_one(&mut self, sample: &impl Observation) {
        let features = sample.features();
        let target = sample.target();
        let d = features.len();

        // Step 1: lazy gate initialization
        self.ensure_gate_init(d);

        // Step 2: compute gating probabilities
        let logits = self.gate_logits(features);
        let probs = softmax(&logits);
        let k = self.experts.len();

        // Step 3: train experts based on gating mode
        match &self.gating_mode {
            GatingMode::Soft => {
                // Every expert gets the sample, weighted by gating probability
                for (expert, &prob) in self.experts.iter_mut().zip(probs.iter()) {
                    let weighted = SampleRef::weighted(features, target, prob);
                    expert.train_one(&weighted);
                }
            }
            GatingMode::Hard { top_k } => {
                // Only the top-k experts get the sample
                let top_k = (*top_k).min(k);
                let mut indices: Vec<usize> = (0..k).collect();
                indices.sort_unstable_by(|&a, &b| {
                    probs[b]
                        .partial_cmp(&probs[a])
                        .unwrap_or(core::cmp::Ordering::Equal)
                });
                for &i in indices.iter().take(top_k) {
                    let obs = SampleRef::new(features, target);
                    self.experts[i].train_one(&obs);
                }
            }
        }

        // Step 4: update gate weights via SGD on the cross-entropy gradient.
        // Find the best expert (lowest loss on this sample).
        let mut best_idx = 0;
        let mut best_loss = f64::INFINITY;
        for (i, expert) in self.experts.iter().enumerate() {
            let pred = expert.predict(features);
            let l = self.loss.loss(target, pred);
            if l < best_loss {
                best_loss = l;
                best_idx = i;
            }
        }

        // Cross-entropy gradient: dz_k = p_k - 1{k == best}
        // SGD update: W_k -= lr * dz_k * x, b_k -= lr * dz_k
        for (i, (weights_row, bias)) in self
            .gate_weights
            .iter_mut()
            .zip(self.gate_bias.iter_mut())
            .enumerate()
        {
            let indicator = if i == best_idx { 1.0 } else { 0.0 };
            let grad = probs[i] - indicator;
            let lr = self.gate_lr;

            for (j, &xj) in features.iter().enumerate() {
                weights_row[j] -= lr * grad * xj;
            }
            *bias -= lr * grad;
        }

        self.samples_seen += 1;
    }

    /// Train on a batch of observations.
    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
        for sample in samples {
            self.train_one(sample);
        }
    }

    // -------------------------------------------------------------------
    // Public API -- prediction
    // -------------------------------------------------------------------

    /// Predict the output for a feature vector.
    ///
    /// Computes the probability-weighted sum of expert predictions:
    /// `ŷ = Σ_k p_k(x) · f_k(x)`.
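    ///
    /// # Example
    ///
    /// The prediction decomposes into per-expert outputs and gate weights; a
    /// sketch of that identity for some trained `moe` and input `x`:
    ///
    /// ```text
    /// let probs = moe.gating_probabilities(&x);
    /// let preds = moe.expert_predictions(&x);
    /// let y_hat: f64 = probs.iter().zip(&preds).map(|(p, f)| p * f).sum();
    /// // y_hat equals moe.predict(&x) up to floating-point rounding.
    /// ```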
    pub fn predict(&self, features: &[f64]) -> f64 {
        let probs = self.gating_probabilities(features);
        let mut pred = 0.0;
        for (i, &p) in probs.iter().enumerate() {
            pred += p * self.experts[i].predict(features);
        }
        pred
    }

    /// Predict with gating probabilities returned alongside the prediction.
    ///
    /// Returns `(prediction, probabilities)` where `probabilities` is a K-length
    /// vector summing to 1.0.
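    ///
    /// # Example
    ///
    /// A sketch of inspecting the routing for some trained `moe` and input `x`:
    ///
    /// ```text
    /// let (y_hat, probs) = moe.predict_with_gating(&x);
    /// // probs[i] is the gate's routing weight for expert i on this input.
    /// assert!((probs.iter().sum::<f64>() - 1.0).abs() < 1e-9);
    /// ```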
    pub fn predict_with_gating(&self, features: &[f64]) -> (f64, Vec<f64>) {
        let probs = self.gating_probabilities(features);
        let mut pred = 0.0;
        for (i, &p) in probs.iter().enumerate() {
            pred += p * self.experts[i].predict(features);
        }
        (pred, probs)
    }

    /// Get each expert's individual prediction for a feature vector.
    ///
    /// Returns a K-length vector of raw predictions, one per expert.
    pub fn expert_predictions(&self, features: &[f64]) -> Vec<f64> {
        self.experts.iter().map(|e| e.predict(features)).collect()
    }

    // -------------------------------------------------------------------
    // Public API -- inspection
    // -------------------------------------------------------------------

    /// Number of experts in the mixture.
    #[inline]
    pub fn n_experts(&self) -> usize {
        self.experts.len()
    }

    /// Total training samples seen.
    #[inline]
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Immutable access to all experts.
    pub fn experts(&self) -> &[SGBT<L>] {
        &self.experts
    }

    /// Immutable access to a specific expert.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= n_experts`.
    pub fn expert(&self, idx: usize) -> &SGBT<L> {
        &self.experts[idx]
    }

    /// Reset the entire MoE to its initial state.
    ///
    /// Resets all experts, clears gate weights and biases back to zeros,
    /// and resets the sample counter.
    pub fn reset(&mut self) {
        for expert in &mut self.experts {
            expert.reset();
        }
        let k = self.experts.len();
        self.gate_weights.clear();
        self.gate_bias = vec![0.0; k];
        self.n_features = None;
        self.samples_seen = 0;
    }
}

// ---------------------------------------------------------------------------
// StreamingLearner impl
// ---------------------------------------------------------------------------

use crate::learner::StreamingLearner;

impl<L: Loss> StreamingLearner for MoESGBT<L> {
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        let sample = SampleRef::weighted(features, target, weight);
        // UFCS: call the inherent train_one(&impl Observation), not this trait method.
        MoESGBT::train_one(self, &sample);
    }

    fn predict(&self, features: &[f64]) -> f64 {
        MoESGBT::predict(self, features)
    }

    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    fn reset(&mut self) {
        MoESGBT::reset(self);
    }
}

// ===========================================================================
// Tests
// ===========================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::loss::huber::HuberLoss;
    use crate::sample::Sample;
    use alloc::boxed::Box;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Helper: build a minimal config for tests.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .build()
            .unwrap()
    }
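
    /// Sketch of a property check for the module-local `softmax` helper: the
    /// output should be a normalized distribution, uniform for equal logits,
    /// and invariant to shifting every logit by the same constant (the
    /// max-subtraction trick).
    #[test]
    fn test_softmax_properties() {
        // Equal logits -> uniform distribution.
        let probs = softmax(&[0.0, 0.0, 0.0, 0.0]);
        for &p in &probs {
            assert!((p - 0.25).abs() < 1e-12, "expected uniform, got {}", p);
        }

        // Probabilities are positive and sum to 1.
        let probs = softmax(&[1.0, -2.0, 0.5]);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-12);
        assert!(probs.iter().all(|&p| p > 0.0));

        // Shift invariance: softmax(z + c) matches softmax(z).
        let shifted = softmax(&[101.0, 98.0, 100.5]);
        for (a, b) in probs.iter().zip(shifted.iter()) {
            assert!((a - b).abs() < 1e-9);
        }
    }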

    #[test]
    fn test_creation() {
        let moe = MoESGBT::new(test_config(), 3);
        assert_eq!(moe.n_experts(), 3);
        assert_eq!(moe.n_samples_seen(), 0);
    }

    #[test]
    fn test_with_loss() {
        let moe = MoESGBT::with_loss(test_config(), HuberLoss { delta: 1.0 }, 4);
        assert_eq!(moe.n_experts(), 4);
        assert_eq!(moe.n_samples_seen(), 0);
    }

    #[test]
    fn test_soft_gating_trains_all() {
        let mut moe = MoESGBT::new(test_config(), 3);
        let sample = Sample::new(vec![1.0, 2.0], 5.0);

        moe.train_one(&sample);

        // In soft mode, every expert should have seen the sample
        for i in 0..3 {
            assert_eq!(moe.expert(i).n_samples_seen(), 1);
        }
    }

    #[test]
    fn test_hard_gating_top_k() {
        let mut moe = MoESGBT::with_gating(
            test_config(),
            SquaredLoss,
            4,
            GatingMode::Hard { top_k: 2 },
            0.01,
        );
        let sample = Sample::new(vec![1.0, 2.0], 5.0);

        moe.train_one(&sample);

        // Exactly top_k=2 experts should have received the sample
        let trained_count = (0..4)
            .filter(|&i| moe.expert(i).n_samples_seen() > 0)
            .count();
        assert_eq!(trained_count, 2);
    }
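
    /// Sketch of the clamping behavior: `train_one` clamps `top_k` to the
    /// number of experts, so an oversized `top_k` degenerates to training
    /// every expert.
    #[test]
    fn test_hard_gating_top_k_clamped() {
        let mut moe = MoESGBT::with_gating(
            test_config(),
            SquaredLoss,
            3,
            GatingMode::Hard { top_k: 10 },
            0.01,
        );
        moe.train_one(&Sample::new(vec![1.0, 2.0], 5.0));

        // top_k = 10 > K = 3, so all three experts receive the sample.
        for i in 0..3 {
            assert_eq!(moe.expert(i).n_samples_seen(), 1);
        }
    }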

    #[test]
    fn test_gating_probabilities_sum_to_one() {
        let mut moe = MoESGBT::new(test_config(), 5);

        // Before training: uniform probabilities
        let probs = moe.gating_probabilities(&[1.0, 2.0]);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "pre-training sum = {}", sum);

        // After training: probabilities should still sum to 1
        for i in 0..20 {
            let sample = Sample::new(vec![i as f64, (i * 2) as f64], i as f64);
            moe.train_one(&sample);
        }
        let probs = moe.gating_probabilities(&[5.0, 10.0]);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "post-training sum = {}", sum);
    }

    #[test]
    fn test_prediction_changes_after_training() {
        let mut moe = MoESGBT::new(test_config(), 3);
        let features = vec![1.0, 2.0, 3.0];

        let pred_before = moe.predict(&features);

        for i in 0..50 {
            let sample = Sample::new(features.clone(), 10.0 + i as f64 * 0.1);
            moe.train_one(&sample);
        }

        let pred_after = moe.predict(&features);
        assert!(
            (pred_after - pred_before).abs() > 1e-6,
            "prediction should change after training: before={}, after={}",
            pred_before,
            pred_after
        );
    }

    #[test]
    fn test_expert_specialization() {
        // Two regions: x < 0 targets ~-10, x >= 0 targets ~+10
        let mut moe = MoESGBT::with_gating(test_config(), SquaredLoss, 2, GatingMode::Soft, 0.05);

        // Train with separable data
        for i in 0..200 {
            let x = if i % 2 == 0 {
                -(i as f64 + 1.0)
            } else {
                i as f64 + 1.0
            };
            let target = if x < 0.0 { -10.0 } else { 10.0 };
            let sample = Sample::new(vec![x], target);
            moe.train_one(&sample);
        }

        // After training, the gating probabilities should differ for
        // negative vs positive inputs
        let probs_neg = moe.gating_probabilities(&[-5.0]);
        let probs_pos = moe.gating_probabilities(&[5.0]);

        // The dominant expert should be different (or at least the distributions
        // should be noticeably different)
        let diff: f64 = probs_neg
            .iter()
            .zip(probs_pos.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 0.01,
            "gate should route differently: neg={:?}, pos={:?}",
            probs_neg,
            probs_pos
        );
    }

    #[test]
    fn test_predict_with_gating() {
        let mut moe = MoESGBT::new(test_config(), 3);
        let sample = Sample::new(vec![1.0, 2.0], 5.0);
        moe.train_one(&sample);

        let (pred, probs) = moe.predict_with_gating(&[1.0, 2.0]);
        assert_eq!(probs.len(), 3);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);

        // Prediction should equal the weighted sum of expert predictions
        let expert_preds = moe.expert_predictions(&[1.0, 2.0]);
        let expected: f64 = probs
            .iter()
            .zip(expert_preds.iter())
            .map(|(p, e)| p * e)
            .sum();
        assert!(
            (pred - expected).abs() < 1e-10,
            "pred={} expected={}",
            pred,
            expected
        );
    }

    #[test]
    fn test_expert_predictions() {
        let mut moe = MoESGBT::new(test_config(), 3);
        for i in 0..10 {
            let sample = Sample::new(vec![i as f64], i as f64);
            moe.train_one(&sample);
        }

        let preds = moe.expert_predictions(&[5.0]);
        assert_eq!(preds.len(), 3);
        // Each expert should produce a finite prediction
        for &p in &preds {
            assert!(p.is_finite(), "expert prediction should be finite: {}", p);
        }
    }

    #[test]
    fn test_n_experts() {
        let moe = MoESGBT::new(test_config(), 7);
        assert_eq!(moe.n_experts(), 7);
        assert_eq!(moe.experts().len(), 7);
    }

    #[test]
    fn test_n_samples_seen() {
        let mut moe = MoESGBT::new(test_config(), 2);
        assert_eq!(moe.n_samples_seen(), 0);

        for i in 0..25 {
            moe.train_one(&Sample::new(vec![i as f64], i as f64));
        }
        assert_eq!(moe.n_samples_seen(), 25);
    }

    #[test]
    fn test_reset() {
        let mut moe = MoESGBT::new(test_config(), 3);

        for i in 0..50 {
            moe.train_one(&Sample::new(vec![i as f64, (i * 2) as f64], i as f64));
        }
        assert_eq!(moe.n_samples_seen(), 50);

        moe.reset();

        assert_eq!(moe.n_samples_seen(), 0);
        assert_eq!(moe.n_experts(), 3);
        // Gate is uninitialized again, so probabilities fall back to uniform
        let probs = moe.gating_probabilities(&[1.0, 2.0]);
        assert_eq!(probs.len(), 3);
        // After reset, probabilities are uniform again
        for &p in &probs {
            assert!(
                (p - 1.0 / 3.0).abs() < 1e-10,
                "expected uniform after reset, got {}",
                p
            );
        }
    }

    #[test]
    fn test_single_expert() {
        // With a single expert, MoE should behave like a plain SGBT
        let config = test_config();
        let mut moe = MoESGBT::new(config.clone(), 1);

        let mut plain = SGBT::new({
            let mut cfg = config.clone();
            cfg.seed = config.seed ^ 0x0000_0E00_0000_0000;
            cfg
        });

        // The single expert always receives weight = 1.0, so predictions should
        // be very close (both see the same data with the same seed)
        for i in 0..30 {
            let sample = Sample::new(vec![i as f64], i as f64 * 2.0);
            moe.train_one(&sample);
            // For the plain SGBT, replicate the soft-gating weight.
            // With one expert, p = 1.0, so SampleRef::weighted(features, target, 1.0)
            // is equivalent to a normal sample (weight = 1.0).
            let weighted = SampleRef::weighted(&sample.features, sample.target, 1.0);
            plain.train_one(&weighted);
        }

        let moe_pred = moe.predict(&[15.0]);
        let plain_pred = plain.predict(&[15.0]);
        assert!(
            (moe_pred - plain_pred).abs() < 1e-6,
            "single expert MoE should match plain SGBT: moe={}, plain={}",
            moe_pred,
            plain_pred
        );
    }

    #[test]
    fn test_gate_lr_effect() {
        // A higher gate learning rate should cause the gate to diverge from
        // uniform faster than a lower one.
        let config = test_config();

        let mut moe_low =
            MoESGBT::with_gating(config.clone(), SquaredLoss, 3, GatingMode::Soft, 0.001);
        let mut moe_high = MoESGBT::with_gating(config, SquaredLoss, 3, GatingMode::Soft, 0.1);

        // Train both on the same data
        for i in 0..50 {
            let sample = Sample::new(vec![i as f64], i as f64);
            moe_low.train_one(&sample);
            moe_high.train_one(&sample);
        }

        // Measure deviation from uniform for both
        let uniform = 1.0 / 3.0;
        let probs_low = moe_low.gating_probabilities(&[25.0]);
        let probs_high = moe_high.gating_probabilities(&[25.0]);

        let dev_low: f64 = probs_low.iter().map(|p| (p - uniform).abs()).sum();
        let dev_high: f64 = probs_high.iter().map(|p| (p - uniform).abs()).sum();

        assert!(
            dev_high > dev_low,
            "higher gate_lr should cause more deviation from uniform: low={}, high={}",
            dev_low,
            dev_high
        );
    }

    #[test]
    fn test_batch_training() {
        let mut moe = MoESGBT::new(test_config(), 3);

        let samples: Vec<Sample> = (0..20)
            .map(|i| Sample::new(vec![i as f64, (i * 3) as f64], i as f64))
            .collect();

        moe.train_batch(&samples);

        assert_eq!(moe.n_samples_seen(), 20);

        // Should produce a finite prediction after batch training
        let pred = moe.predict(&[10.0, 30.0]);
        assert!(pred.is_finite());
    }

    #[test]
    fn streaming_learner_trait_object() {
        let config = test_config();
        let model = MoESGBT::new(config, 3);
        let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
        for i in 0..100 {
            let x = i as f64 * 0.1;
            boxed.train(&[x], x * 2.0);
        }
        assert_eq!(boxed.n_samples_seen(), 100);
        let pred = boxed.predict(&[5.0]);
        assert!(pred.is_finite());
        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }
}