use std::fmt;
use crate::ensemble::config::SGBTConfig;
use crate::ensemble::SGBT;
use crate::learner::StreamingLearner;
use crate::loss::squared::SquaredLoss;
use crate::loss::Loss;
use crate::sample::{Observation, SampleRef};
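/// How the gate routes each training sample to the experts.
///
/// * `Soft`: every expert trains on every sample, weighted by its gate
///   probability.
/// * `Hard { top_k }`: only the `top_k` most probable experts train on the
///   sample, unweighted. `top_k` is clamped to the number of experts; a
///   value of 0 trains no expert.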
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum GatingMode {
Soft,
Hard {
top_k: usize,
},
}
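/// Numerically stable softmax: the max logit is subtracted before
/// exponentiating, which avoids overflow without changing the result.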
pub(crate) fn softmax(logits: &[f64]) -> Vec<f64> {
let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let exps: Vec<f64> = logits.iter().map(|&z| (z - max).exp()).collect();
let sum: f64 = exps.iter().sum();
exps.iter().map(|&e| e / sum).collect()
}
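/// A mixture-of-experts ensemble of streaming gradient-boosted trees.
///
/// `n_experts` independent [`SGBT`] experts are combined by a linear
/// softmax gate over the raw features. The gate is trained online: after
/// each sample, its weights take one SGD step toward the expert that
/// incurred the lowest loss on that sample.
///
/// A minimal usage sketch (imports elided, builder values illustrative;
/// marked `ignore` so it is not run as a doc-test):
///
/// ```ignore
/// let config = SGBTConfig::builder()
///     .n_steps(5)
///     .learning_rate(0.1)
///     .grace_period(5)
///     .build()
///     .unwrap();
/// let mut moe = MoESGBT::new(config, 3);
/// moe.train_one(&Sample::new(vec![1.0, 2.0], 5.0));
/// let (pred, gate_probs) = moe.predict_with_gating(&[1.0, 2.0]);
/// ```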
pub struct MoESGBT<L: Loss = SquaredLoss> {
experts: Vec<SGBT<L>>,
gate_weights: Vec<Vec<f64>>,
gate_bias: Vec<f64>,
gate_lr: f64,
n_features: Option<usize>,
gating_mode: GatingMode,
config: SGBTConfig,
loss: L,
samples_seen: u64,
}
impl<L: Loss + Clone> Clone for MoESGBT<L> {
fn clone(&self) -> Self {
Self {
experts: self.experts.clone(),
gate_weights: self.gate_weights.clone(),
gate_bias: self.gate_bias.clone(),
gate_lr: self.gate_lr,
n_features: self.n_features,
gating_mode: self.gating_mode.clone(),
config: self.config.clone(),
loss: self.loss.clone(),
samples_seen: self.samples_seen,
}
}
}
impl<L: Loss> fmt::Debug for MoESGBT<L> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MoESGBT")
.field("n_experts", &self.experts.len())
.field("gating_mode", &self.gating_mode)
.field("samples_seen", &self.samples_seen)
            .finish_non_exhaustive()
}
}
impl MoESGBT<SquaredLoss> {
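    /// Creates a mixture of `n_experts` squared-loss experts with soft
    /// gating and a gate learning rate of 0.01.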
pub fn new(config: SGBTConfig, n_experts: usize) -> Self {
Self::with_loss(config, SquaredLoss, n_experts)
}
}
impl<L: Loss + Clone> MoESGBT<L> {
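    /// Creates a mixture with a custom loss, soft gating, and a gate
    /// learning rate of 0.01.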
pub fn with_loss(config: SGBTConfig, loss: L, n_experts: usize) -> Self {
Self::with_gating(config, loss, n_experts, GatingMode::Soft, 0.01)
}
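    /// Creates a mixture with explicit gating mode and gate learning rate.
    ///
    /// Each expert receives a distinct seed derived from `config.seed`.
    /// A sketch with hard top-2 routing (values illustrative; `ignore`d as
    /// a doc-test):
    ///
    /// ```ignore
    /// let moe = MoESGBT::with_gating(
    ///     config,
    ///     SquaredLoss,
    ///     4,
    ///     GatingMode::Hard { top_k: 2 },
    ///     0.01,
    /// );
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if `n_experts` is 0.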
pub fn with_gating(
config: SGBTConfig,
loss: L,
n_experts: usize,
gating_mode: GatingMode,
gate_lr: f64,
) -> Self {
assert!(n_experts >= 1, "MoESGBT requires at least 1 expert");
let experts = (0..n_experts)
.map(|i| {
let mut cfg = config.clone();
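                // Derive a distinct, reproducible seed per expert so the
                // experts do not evolve identically.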
cfg.seed = config.seed ^ (0x0000_0E00_0000_0000 | i as u64);
SGBT::with_loss(cfg, loss.clone())
})
.collect();
let gate_bias = vec![0.0; n_experts];
Self {
experts,
            gate_weights: Vec::new(),
            gate_bias,
gate_lr,
n_features: None,
gating_mode,
config,
loss,
samples_seen: 0,
}
}
}
impl<L: Loss> MoESGBT<L> {
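    /// Lazily sizes the gate to the feature dimension of the first
    /// training sample; later calls are no-ops.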
fn ensure_gate_init(&mut self, d: usize) {
if self.n_features.is_none() {
let k = self.experts.len();
self.gate_weights = vec![vec![0.0; d]; k];
self.n_features = Some(d);
}
}
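    /// Linear gate scores: `logit_i = w_i . x + b_i` for each expert.
    /// Debug builds assert that the feature dimension matches the gate.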
    fn gate_logits(&self, features: &[f64]) -> Vec<f64> {
        debug_assert_eq!(Some(features.len()), self.n_features);
        let k = self.experts.len();
        let mut logits = Vec::with_capacity(k);
for i in 0..k {
let dot: f64 = self.gate_weights[i]
.iter()
.zip(features.iter())
.map(|(&w, &x)| w * x)
.sum();
logits.push(dot + self.gate_bias[i]);
}
logits
}
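    /// Softmax gate probabilities for `features`.
    ///
    /// Returns a uniform distribution until the first training sample
    /// initializes the gate.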
pub fn gating_probabilities(&self, features: &[f64]) -> Vec<f64> {
let k = self.experts.len();
if self.n_features.is_none() {
return vec![1.0 / k as f64; k];
}
let logits = self.gate_logits(features);
softmax(&logits)
}
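    /// Trains the mixture on a single observation, in two phases:
    ///
    /// 1. Experts train according to the gating mode: under
    ///    [`GatingMode::Soft`] every expert sees the sample weighted by its
    ///    gate probability; under [`GatingMode::Hard`] only the `top_k`
    ///    most probable experts see it, unweighted.
    /// 2. The gate takes one SGD step on a cross-entropy objective whose
    ///    one-hot pseudo-label is the expert with the lowest post-update
    ///    loss on the sample.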
pub fn train_one(&mut self, sample: &impl Observation) {
let features = sample.features();
let target = sample.target();
let d = features.len();
self.ensure_gate_init(d);
let logits = self.gate_logits(features);
let probs = softmax(&logits);
let k = self.experts.len();
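        // Phase 1: train the experts according to the gating mode.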
match &self.gating_mode {
GatingMode::Soft => {
for (expert, &prob) in self.experts.iter_mut().zip(probs.iter()) {
let weighted = SampleRef::weighted(features, target, prob);
expert.train_one(&weighted);
}
}
GatingMode::Hard { top_k } => {
let top_k = (*top_k).min(k);
let mut indices: Vec<usize> = (0..k).collect();
indices.sort_unstable_by(|&a, &b| {
probs[b]
.partial_cmp(&probs[a])
.unwrap_or(std::cmp::Ordering::Equal)
});
                let obs = SampleRef::new(features, target);
                for &i in indices.iter().take(top_k) {
                    self.experts[i].train_one(&obs);
                }
}
}
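        // Phase 2a: the expert with the lowest loss on this sample (after
        // the expert updates above) becomes the gate's one-hot pseudo-label.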
let mut best_idx = 0;
let mut best_loss = f64::INFINITY;
for (i, expert) in self.experts.iter().enumerate() {
let pred = expert.predict(features);
let l = self.loss.loss(target, pred);
if l < best_loss {
best_loss = l;
best_idx = i;
}
}
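        // Phase 2b: one SGD step on the gate's cross-entropy against that
        // pseudo-label; the softmax gradient is `p_i - 1[i == best_idx]`.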
        let lr = self.gate_lr;
        for (i, (weights_row, bias)) in self
            .gate_weights
            .iter_mut()
            .zip(self.gate_bias.iter_mut())
            .enumerate()
        {
            let indicator = if i == best_idx { 1.0 } else { 0.0 };
            let grad = probs[i] - indicator;
            for (j, &xj) in features.iter().enumerate() {
                weights_row[j] -= lr * grad * xj;
            }
            *bias -= lr * grad;
        }
self.samples_seen += 1;
}
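    /// Trains on each sample in order; equivalent to calling
    /// [`train_one`](Self::train_one) per sample.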
pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
for sample in samples {
self.train_one(sample);
}
}
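    /// Gate-weighted mixture prediction: `sum_i p_i(x) * expert_i(x)`.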
    pub fn predict(&self, features: &[f64]) -> f64 {
        self.predict_with_gating(features).0
    }
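    /// Like [`predict`](Self::predict), but also returns the gate
    /// probabilities used for the blend.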
pub fn predict_with_gating(&self, features: &[f64]) -> (f64, Vec<f64>) {
let probs = self.gating_probabilities(features);
let mut pred = 0.0;
for (i, &p) in probs.iter().enumerate() {
pred += p * self.experts[i].predict(features);
}
(pred, probs)
}
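    /// Per-expert predictions in expert order, unweighted by the gate.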
pub fn expert_predictions(&self, features: &[f64]) -> Vec<f64> {
self.experts.iter().map(|e| e.predict(features)).collect()
}
#[inline]
pub fn n_experts(&self) -> usize {
self.experts.len()
}
#[inline]
pub fn n_samples_seen(&self) -> u64 {
self.samples_seen
}
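    /// All experts, in construction order.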
pub fn experts(&self) -> &[SGBT<L>] {
&self.experts
}
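    /// The expert at `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.n_experts()`.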
pub fn expert(&self, idx: usize) -> &SGBT<L> {
&self.experts[idx]
}
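    /// Resets every expert and the gate to their untrained state; the
    /// next training sample re-initializes the gate dimensions.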
pub fn reset(&mut self) {
for expert in &mut self.experts {
expert.reset();
}
let k = self.experts.len();
self.gate_weights.clear();
self.gate_bias = vec![0.0; k];
self.n_features = None;
self.samples_seen = 0;
}
}
impl<L: Loss + Clone> StreamingLearner for MoESGBT<L> {
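    // The incoming `weight` is carried on the wrapped sample, but
    // `MoESGBT::train_one` reads only features and target, so per-expert
    // training weights come from the gate alone.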
fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
let sample = SampleRef::weighted(features, target, weight);
MoESGBT::train_one(self, &sample);
}
fn predict(&self, features: &[f64]) -> f64 {
MoESGBT::predict(self, features)
}
fn n_samples_seen(&self) -> u64 {
self.samples_seen
}
fn reset(&mut self) {
MoESGBT::reset(self);
}
#[allow(deprecated)]
fn diagnostics_array(&self) -> [f64; 5] {
<Self as crate::learner::Tunable>::diagnostics_array(self)
}
#[allow(deprecated)]
fn replacement_count(&self) -> u64 {
<Self as crate::learner::Structural>::replacement_count(self)
}
#[allow(deprecated)]
fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
<Self as crate::learner::Tunable>::adjust_config(self, lr_multiplier, lambda_delta);
}
}
impl<L: Loss + Clone> crate::learner::Tunable for MoESGBT<L> {
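    // Diagnostics are proxied through the first expert, which is cloned
    // into a temporary `SGBTLearner` for the query.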
fn diagnostics_array(&self) -> [f64; 5] {
use crate::automl::DiagnosticSource;
if let Some(first) = self.experts.first() {
use crate::learner::SGBTLearner;
let learner = SGBTLearner::new(first.clone());
match learner.config_diagnostics() {
Some(d) => [
d.residual_alignment,
d.regularization_sensitivity,
d.depth_sufficiency,
d.effective_dof,
d.uncertainty,
],
None => [0.0; 5],
}
} else {
[0.0; 5]
}
}
fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
for expert in &mut self.experts {
let new_lr = expert.config().learning_rate * lr_multiplier;
expert.set_learning_rate(new_lr);
let new_lambda = expert.config().lambda + lambda_delta;
expert.set_lambda(new_lambda);
}
}
}
impl<L: Loss + Clone> crate::learner::Structural for MoESGBT<L> {
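    // Currently a no-op: depth/steps deltas are not propagated to the
    // experts here.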
fn apply_structural_change(&mut self, _depth_delta: i32, _steps_delta: i32) {
}
fn replacement_count(&self) -> u64 {
self.experts.iter().map(|e| e.total_replacements()).sum()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::loss::huber::HuberLoss;
use crate::sample::Sample;
fn test_config() -> SGBTConfig {
SGBTConfig::builder()
.n_steps(5)
.learning_rate(0.1)
.grace_period(5)
.build()
.unwrap()
}
#[test]
fn test_creation() {
let moe = MoESGBT::new(test_config(), 3);
assert_eq!(moe.n_experts(), 3);
assert_eq!(moe.n_samples_seen(), 0);
}
#[test]
fn test_with_loss() {
let moe = MoESGBT::with_loss(test_config(), HuberLoss { delta: 1.0 }, 4);
assert_eq!(moe.n_experts(), 4);
assert_eq!(moe.n_samples_seen(), 0);
}
#[test]
fn test_soft_gating_trains_all() {
let mut moe = MoESGBT::new(test_config(), 3);
let sample = Sample::new(vec![1.0, 2.0], 5.0);
moe.train_one(&sample);
for i in 0..3 {
assert_eq!(moe.expert(i).n_samples_seen(), 1);
}
}
#[test]
fn test_hard_gating_top_k() {
let mut moe = MoESGBT::with_gating(
test_config(),
SquaredLoss,
4,
GatingMode::Hard { top_k: 2 },
0.01,
);
let sample = Sample::new(vec![1.0, 2.0], 5.0);
moe.train_one(&sample);
let trained_count = (0..4)
.filter(|&i| moe.expert(i).n_samples_seen() > 0)
.count();
assert_eq!(trained_count, 2);
}
#[test]
fn test_gating_probabilities_sum_to_one() {
let mut moe = MoESGBT::new(test_config(), 5);
let probs = moe.gating_probabilities(&[1.0, 2.0]);
let sum: f64 = probs.iter().sum();
assert!((sum - 1.0).abs() < 1e-10, "pre-training sum = {}", sum);
for i in 0..20 {
let sample = Sample::new(vec![i as f64, (i * 2) as f64], i as f64);
moe.train_one(&sample);
}
let probs = moe.gating_probabilities(&[5.0, 10.0]);
let sum: f64 = probs.iter().sum();
assert!((sum - 1.0).abs() < 1e-10, "post-training sum = {}", sum);
}
#[test]
fn test_prediction_changes_after_training() {
let mut moe = MoESGBT::new(test_config(), 3);
let features = vec![1.0, 2.0, 3.0];
let pred_before = moe.predict(&features);
for i in 0..50 {
let sample = Sample::new(features.clone(), 10.0 + i as f64 * 0.1);
moe.train_one(&sample);
}
let pred_after = moe.predict(&features);
assert!(
(pred_after - pred_before).abs() > 1e-6,
"prediction should change after training: before={}, after={}",
pred_before,
pred_after
);
}
#[test]
fn test_expert_specialization() {
let mut moe = MoESGBT::with_gating(test_config(), SquaredLoss, 2, GatingMode::Soft, 0.05);
for i in 0..200 {
let x = if i % 2 == 0 {
-(i as f64 + 1.0)
} else {
i as f64 + 1.0
};
let target = if x < 0.0 { -10.0 } else { 10.0 };
let sample = Sample::new(vec![x], target);
moe.train_one(&sample);
}
let probs_neg = moe.gating_probabilities(&[-5.0]);
let probs_pos = moe.gating_probabilities(&[5.0]);
let diff: f64 = probs_neg
.iter()
.zip(probs_pos.iter())
.map(|(a, b)| (a - b).abs())
.sum();
assert!(
diff > 0.01,
"gate should route differently: neg={:?}, pos={:?}",
probs_neg,
probs_pos
);
}
#[test]
fn test_predict_with_gating() {
let mut moe = MoESGBT::new(test_config(), 3);
let sample = Sample::new(vec![1.0, 2.0], 5.0);
moe.train_one(&sample);
let (pred, probs) = moe.predict_with_gating(&[1.0, 2.0]);
assert_eq!(probs.len(), 3);
let sum: f64 = probs.iter().sum();
assert!((sum - 1.0).abs() < 1e-10);
let expert_preds = moe.expert_predictions(&[1.0, 2.0]);
let expected: f64 = probs
.iter()
.zip(expert_preds.iter())
.map(|(p, e)| p * e)
.sum();
assert!(
(pred - expected).abs() < 1e-10,
"pred={} expected={}",
pred,
expected
);
}
#[test]
fn test_expert_predictions() {
let mut moe = MoESGBT::new(test_config(), 3);
for i in 0..10 {
let sample = Sample::new(vec![i as f64], i as f64);
moe.train_one(&sample);
}
let preds = moe.expert_predictions(&[5.0]);
assert_eq!(preds.len(), 3);
for &p in &preds {
assert!(p.is_finite(), "expert prediction should be finite: {}", p);
}
}
#[test]
fn test_n_experts() {
let moe = MoESGBT::new(test_config(), 7);
assert_eq!(moe.n_experts(), 7);
assert_eq!(moe.experts().len(), 7);
}
#[test]
fn test_n_samples_seen() {
let mut moe = MoESGBT::new(test_config(), 2);
assert_eq!(moe.n_samples_seen(), 0);
for i in 0..25 {
moe.train_one(&Sample::new(vec![i as f64], i as f64));
}
assert_eq!(moe.n_samples_seen(), 25);
}
#[test]
fn test_reset() {
let mut moe = MoESGBT::new(test_config(), 3);
for i in 0..50 {
moe.train_one(&Sample::new(vec![i as f64, (i * 2) as f64], i as f64));
}
assert_eq!(moe.n_samples_seen(), 50);
moe.reset();
assert_eq!(moe.n_samples_seen(), 0);
assert_eq!(moe.n_experts(), 3);
let probs = moe.gating_probabilities(&[1.0, 2.0]);
assert_eq!(probs.len(), 3);
for &p in &probs {
assert!(
(p - 1.0 / 3.0).abs() < 1e-10,
"expected uniform after reset, got {}",
p
);
}
}
#[test]
fn test_single_expert() {
let config = test_config();
let mut moe = MoESGBT::new(config.clone(), 1);
let mut plain = SGBT::new({
let mut cfg = config.clone();
cfg.seed = config.seed ^ 0x0000_0E00_0000_0000;
cfg
});
for i in 0..30 {
let sample = Sample::new(vec![i as f64], i as f64 * 2.0);
moe.train_one(&sample);
let weighted = SampleRef::weighted(&sample.features, sample.target, 1.0);
plain.train_one(&weighted);
}
let moe_pred = moe.predict(&[15.0]);
let plain_pred = plain.predict(&[15.0]);
assert!(
(moe_pred - plain_pred).abs() < 1e-6,
"single expert MoE should match plain SGBT: moe={}, plain={}",
moe_pred,
plain_pred
);
}
#[test]
fn test_gate_lr_effect() {
let config = test_config();
let mut moe_low =
MoESGBT::with_gating(config.clone(), SquaredLoss, 3, GatingMode::Soft, 0.001);
let mut moe_high = MoESGBT::with_gating(config, SquaredLoss, 3, GatingMode::Soft, 0.1);
for i in 0..50 {
let sample = Sample::new(vec![i as f64], i as f64);
moe_low.train_one(&sample);
moe_high.train_one(&sample);
}
let uniform = 1.0 / 3.0;
let probs_low = moe_low.gating_probabilities(&[25.0]);
let probs_high = moe_high.gating_probabilities(&[25.0]);
let dev_low: f64 = probs_low.iter().map(|p| (p - uniform).abs()).sum();
let dev_high: f64 = probs_high.iter().map(|p| (p - uniform).abs()).sum();
assert!(
dev_high > dev_low,
"higher gate_lr should cause more deviation from uniform: low={}, high={}",
dev_low,
dev_high
);
}
#[test]
fn test_batch_training() {
let mut moe = MoESGBT::new(test_config(), 3);
let samples: Vec<Sample> = (0..20)
.map(|i| Sample::new(vec![i as f64, (i * 3) as f64], i as f64))
.collect();
moe.train_batch(&samples);
assert_eq!(moe.n_samples_seen(), 20);
let pred = moe.predict(&[10.0, 30.0]);
assert!(pred.is_finite());
}
#[test]
fn streaming_learner_trait_object() {
let config = test_config();
let model = MoESGBT::new(config, 3);
let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
for i in 0..100 {
let x = i as f64 * 0.1;
boxed.train(&[x], x * 2.0);
}
assert_eq!(boxed.n_samples_seen(), 100);
let pred = boxed.predict(&[5.0]);
assert!(pred.is_finite());
boxed.reset();
assert_eq!(boxed.n_samples_seen(), 0);
}
}