pub struct MoESGBT<L: Loss = SquaredLoss> { /* private fields */ }
Available on alloc only.
Streaming Mixture of Experts over SGBT ensembles.
Combines K independent SGBT<L> experts with a learned linear softmax
gating network. The gate is trained online via SGD to route samples to the
expert with the lowest loss, while all experts (or the top-k in hard gating
mode) are trained on each incoming sample.
Generic over L: Loss so the expert loss function is monomorphized. The
default is SquaredLoss for regression tasks.
Gate Architecture
The gate is a single linear layer: z_k = W_k · x + b_k followed by
softmax. Weights are lazily initialized to zeros on the first sample
(since the feature dimensionality is not known at construction time).
The gate learns via cross-entropy gradient descent against the one-hot
indicator of the best expert per sample.
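As an illustration of the gate described above (a sketch, not the crate's actual internals; the function and parameter names are assumptions), the forward pass is a per-expert linear layer followed by a numerically stable softmax:

```rust
// Sketch of the gate's forward pass: z_k = W_k . x + b_k, then softmax.
// `w` holds K weight rows, `b` holds K biases; all names are illustrative.
fn gate_probabilities(w: &[Vec<f64>], b: &[f64], x: &[f64]) -> Vec<f64> {
    // Linear logits, one per expert.
    let z: Vec<f64> = w
        .iter()
        .zip(b)
        .map(|(wk, bk)| wk.iter().zip(x).map(|(wi, xi)| wi * xi).sum::<f64>() + bk)
        .collect();
    // Softmax, shifted by max(z) for numerical stability.
    let m = z.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let e: Vec<f64> = z.iter().map(|zk| (zk - m).exp()).collect();
    let s: f64 = e.iter().sum();
    e.iter().map(|ek| ek / s).collect()
}
```

Note that zero-initialized weights produce uniform probabilities (1/K per expert), which is why lazy initialization to zeros gives a sensible starting point.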
Implementations

impl MoESGBT<SquaredLoss>

pub fn new(config: SGBTConfig, n_experts: usize) -> Self
Create a new MoE ensemble with squared loss (regression) and soft gating.
Each expert is seeded uniquely via config.seed ^ (0x0000_0E00_0000_0000 | i).
The gating learning rate defaults to 0.01.
Panics
Panics if n_experts < 1.
impl<L: Loss + Clone> MoESGBT<L>

pub fn with_loss(config: SGBTConfig, loss: L, n_experts: usize) -> Self
pub fn with_gating(
    config: SGBTConfig,
    loss: L,
    n_experts: usize,
    gating_mode: GatingMode,
    gate_lr: f64,
) -> Self
Create a new MoE ensemble with full control over gating mode and gate learning rate.
Panics
Panics if n_experts < 1.
impl<L: Loss> MoESGBT<L>

pub fn gating_probabilities(&self, features: &[f64]) -> Vec<f64>
Compute gating probabilities for a feature vector.
Returns a K-length vector of probabilities summing to 1.0, one per expert. If the gate has not yet been initialized (no training samples seen), returns uniform probabilities (1/K per expert).
pub fn train_one(&mut self, sample: &impl Observation)
Train on a single observation.
- Lazily initializes the gate weights if this is the first sample.
- Computes gating probabilities via softmax over the linear gate.
- Routes the sample to experts according to the gating mode:
  - Soft: all experts receive the sample, each weighted by its gating probability (via SampleRef::weighted).
  - Hard(top_k): only the top-k experts by probability receive the sample (with unit weight).
- Updates gate weights via SGD on the cross-entropy gradient: find the best expert (lowest loss), compute dz_k = p_k - 1{k==best}, and apply W_k -= gate_lr * dz_k * x and b_k -= gate_lr * dz_k.
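The gate update in the final step can be sketched as follows (an illustrative standalone function, not the crate's internals; all names are assumptions):

```rust
// Sketch of the gate's cross-entropy SGD step: given gating probabilities
// `p`, the index of the best (lowest-loss) expert, the feature vector `x`,
// and the gate learning rate, nudge the gate toward routing to `best`.
fn gate_sgd_step(
    w: &mut [Vec<f64>],
    b: &mut [f64],
    p: &[f64],
    best: usize,
    x: &[f64],
    gate_lr: f64,
) {
    for k in 0..p.len() {
        // Gradient of cross-entropy w.r.t. logit k against the one-hot
        // target: dz_k = p_k - 1{k==best}.
        let dz = p[k] - if k == best { 1.0 } else { 0.0 };
        for (wi, xi) in w[k].iter_mut().zip(x) {
            *wi -= gate_lr * dz * xi; // W_k -= gate_lr * dz_k * x
        }
        b[k] -= gate_lr * dz; // b_k -= gate_lr * dz_k
    }
}
```

Because dz_k is negative only for the best expert, its logit rises while all others fall, so repeated updates concentrate probability mass on experts that tend to win in a given feature region.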
pub fn train_batch<O: Observation>(&mut self, samples: &[O])
Train on a batch of observations.
pub fn predict(&self, features: &[f64]) -> f64
Predict the output for a feature vector.
Computes the probability-weighted sum of expert predictions:
ŷ = Σ_k p_k(x) · f_k(x).
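The combination step reduces to a dot product between the gating distribution and the per-expert outputs. A minimal sketch (illustrative names, not the crate's code):

```rust
// Probability-weighted mixture prediction: ŷ = Σ_k p_k(x) · f_k(x).
// `probs` is the gating distribution, `expert_preds` the raw expert outputs.
fn mixture_predict(probs: &[f64], expert_preds: &[f64]) -> f64 {
    probs
        .iter()
        .zip(expert_preds)
        .map(|(pk, fk)| pk * fk)
        .sum()
}
```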
pub fn predict_with_gating(&self, features: &[f64]) -> (f64, Vec<f64>)
Predict with gating probabilities returned alongside the prediction.
Returns (prediction, probabilities) where probabilities is a K-length
vector summing to 1.0.
pub fn expert_predictions(&self, features: &[f64]) -> Vec<f64>
Get each expert’s individual prediction for a feature vector.
Returns a K-length vector of raw predictions, one per expert.
pub fn n_samples_seen(&self) -> u64
Total training samples seen.