Struct MoESGBT 

pub struct MoESGBT<L: Loss = SquaredLoss> { /* private fields */ }
Available on crate feature alloc only.

Streaming Mixture of Experts over SGBT ensembles.

Combines K independent SGBT<L> experts with a learned linear softmax gating network. The gate is trained online via SGD to route samples to the expert with the lowest loss, while all experts (or the top-k in hard gating mode) are trained on each incoming sample.

Generic over L: Loss so the expert loss function is monomorphized. The default is SquaredLoss for regression tasks.

§Gate Architecture

The gate is a single linear layer that computes logits z_k = W_k · x + b_k and turns them into probabilities with a softmax, p_k = exp(z_k) / Σ_j exp(z_j). Weights are lazily initialized to zeros on the first sample, since the feature dimensionality is not known at construction time. The gate learns via cross-entropy gradient descent against the one-hot indicator of the best expert for each sample.
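
A minimal sketch of the gate's forward pass, assuming the weights are stored as one row per expert. `LinearGate`, its fields, and `probabilities` are hypothetical names (the real fields of MoESGBT are private), and subtracting the largest logit before exponentiating is a numerical-stability choice of this sketch, not something the documentation states. It also shows why zero-initialized weights make the gate start out uniform: every logit is 0, so every p_k is 1/K.

```rust
/// Hypothetical stand-in for the gate's internal state (the real fields are private).
struct LinearGate {
    weights: Vec<Vec<f64>>, // one weight row W_k per expert
    biases: Vec<f64>,       // one bias b_k per expert
}

impl LinearGate {
    /// Zero-initialize K weight rows once the feature dimension is known.
    fn zeros(n_experts: usize, n_features: usize) -> Self {
        LinearGate {
            weights: vec![vec![0.0; n_features]; n_experts],
            biases: vec![0.0; n_experts],
        }
    }

    /// z_k = W_k · x + b_k, then p_k = exp(z_k) / Σ_j exp(z_j).
    fn probabilities(&self, x: &[f64]) -> Vec<f64> {
        let logits: Vec<f64> = self
            .weights
            .iter()
            .zip(&self.biases)
            .map(|(w, b)| w.iter().zip(x).map(|(wi, xi)| wi * xi).sum::<f64>() + b)
            .collect();
        // Subtracting the max logit avoids overflow in exp() and leaves the softmax unchanged.
        let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = logits.iter().map(|z| (z - max).exp()).collect();
        let total: f64 = exps.iter().sum();
        exps.iter().map(|e| e / total).collect()
    }
}

fn main() {
    // With all-zero weights every logit is 0, so the gate starts out uniform.
    let gate = LinearGate::zeros(4, 3);
    let p = gate.probabilities(&[1.0, -2.0, 0.5]);
    println!("{:?}", p); // [0.25, 0.25, 0.25, 0.25]
}
```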

Implementations§

impl MoESGBT<SquaredLoss>

pub fn new(config: SGBTConfig, n_experts: usize) -> Self

Create a new MoE ensemble with squared loss (regression) and soft gating.

Each expert is seeded uniquely via config.seed ^ (0x0000_0E00_0000_0000 | i). The gating learning rate defaults to 0.01.
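
The seed formula can be checked in isolation. This sketch assumes `config.seed` is a `u64` and that the expert index `i` is widened to `u64` before the OR; `expert_seed` is a hypothetical helper, not part of the crate's API.

```rust
/// Per-expert seed: base_seed ^ (0x0000_0E00_0000_0000 | i).
/// Assumes a u64 base seed and widens the expert index to u64.
fn expert_seed(base_seed: u64, i: usize) -> u64 {
    base_seed ^ (0x0000_0E00_0000_0000 | i as u64)
}

fn main() {
    let base = 42u64;
    for i in 0..3 {
        println!("expert {i}: {:#018x}", expert_seed(base, i));
    }
}
```

Because XOR with a fixed base seed is a bijection, distinct expert indices yield distinct seeds as long as `i` stays below the bits set by the constant (any `i` below 2^41).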

§Panics

Panics if n_experts < 1.

impl<L: Loss + Clone> MoESGBT<L>

pub fn with_loss(config: SGBTConfig, loss: L, n_experts: usize) -> Self

Create a new MoE ensemble with a custom loss and soft gating.

§Panics

Panics if n_experts < 1.

pub fn with_gating(config: SGBTConfig, loss: L, n_experts: usize, gating_mode: GatingMode, gate_lr: f64) -> Self

Create a new MoE ensemble with full control over gating mode and gate learning rate.

§Panics

Panics if n_experts < 1.

impl<L: Loss> MoESGBT<L>

pub fn gating_probabilities(&self, features: &[f64]) -> Vec<f64>

Compute gating probabilities for a feature vector.

Returns a vector of K probabilities, one per expert, that sum to 1.0. If the gate has not been initialized yet (no training sample seen), uniform probabilities (1/K each) are returned.

pub fn train_one(&mut self, sample: &impl Observation)

Train on a single observation.

  1. Lazily initializes the gate weights if this is the first sample.
  2. Computes gating probabilities via softmax over the linear gate.
  3. Routes the sample to experts according to the gating mode:
    • Soft: all experts receive the sample, each weighted by its gating probability (via SampleRef::weighted).
    • Hard(top_k): only the top-k experts by probability receive the sample (with unit weight).
  4. Updates gate weights via SGD on the cross-entropy gradient: find the best expert (lowest loss), compute dz_k = p_k - 1{k==best}, and apply W_k -= gate_lr * dz_k * x, b_k -= gate_lr * dz_k.
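
The routing and gate-update steps above can be sketched on plain slices. `best_expert`, `top_k_experts` and `gate_sgd_step` are hypothetical helpers; the per-expert `losses` are taken as an input because the documentation only says "lowest loss" and does not specify which predictions that loss is computed from. Tie-breaking (lowest index wins) is likewise a choice of this sketch.

```rust
use std::cmp::Ordering;

/// Index of the expert with the lowest loss on the current sample (first one on ties).
fn best_expert(losses: &[f64]) -> usize {
    losses
        .iter()
        .enumerate()
        .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(Ordering::Equal))
        .map(|(k, _)| k)
        .expect("at least one expert")
}

/// Hard gating: indices of the `top_k` experts with the highest gate probability.
fn top_k_experts(probs: &[f64], top_k: usize) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..probs.len()).collect();
    // Sort descending by probability; the stable sort keeps lower indices first on ties.
    idx.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap_or(Ordering::Equal));
    idx.truncate(top_k);
    idx
}

/// One SGD step on the gate's cross-entropy loss:
/// dz_k = p_k - 1{k == best}; W_k -= lr * dz_k * x; b_k -= lr * dz_k.
fn gate_sgd_step(
    weights: &mut [Vec<f64>],
    biases: &mut [f64],
    probs: &[f64],
    x: &[f64],
    best: usize,
    lr: f64,
) {
    for k in 0..probs.len() {
        let target = if k == best { 1.0 } else { 0.0 };
        let dz = probs[k] - target;
        for (w, xi) in weights[k].iter_mut().zip(x) {
            *w -= lr * dz * xi;
        }
        biases[k] -= lr * dz;
    }
}

fn main() {
    let probs = [0.1, 0.6, 0.3];
    // Soft gating trains every expert with sample weight probs[k];
    // hard gating with top_k = 2 trains only the two most probable experts.
    println!("top-2 experts: {:?}", top_k_experts(&probs, 2));

    // Suppose these were the per-expert losses on this sample.
    let losses = [0.9, 0.4, 0.1];
    let best = best_expert(&losses);

    let x = [1.0, -0.5];
    let mut weights = vec![vec![0.0; x.len()]; probs.len()];
    let mut biases = vec![0.0; probs.len()];
    gate_sgd_step(&mut weights, &mut biases, &probs, &x, best, 0.01);
    println!("best = {best}, biases = {:?}", biases);
}
```

Since the probabilities sum to 1, the dz_k also sum to 0, so each step moves the bias terms in a zero-sum way: probability mass is shifted toward the best expert and away from the others.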

pub fn train_batch<O: Observation>(&mut self, samples: &[O])

Train on a batch of observations.

pub fn predict(&self, features: &[f64]) -> f64

Predict the output for a feature vector.

Computes the probability-weighted sum of expert predictions: ŷ = Σ_k p_k(x) · f_k(x).
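
The weighted sum can be reproduced from the documented `gating_probabilities` and `expert_predictions` outputs, which are both K-length vectors. A minimal sketch on plain slices; `mixture_predict` is a hypothetical helper, not part of the crate.

```rust
/// ŷ = Σ_k p_k(x) · f_k(x): gate probabilities times per-expert predictions.
fn mixture_predict(probs: &[f64], expert_preds: &[f64]) -> f64 {
    assert_eq!(probs.len(), expert_preds.len(), "one probability per expert");
    probs.iter().zip(expert_preds).map(|(p, f)| p * f).sum()
}

fn main() {
    // Example values standing in for gating_probabilities(x) and expert_predictions(x).
    let probs = [0.25, 0.75];
    let preds = [2.0, 4.0];
    println!("{}", mixture_predict(&probs, &preds)); // 3.5
}
```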

pub fn predict_with_gating(&self, features: &[f64]) -> (f64, Vec<f64>)

Predict with gating probabilities returned alongside the prediction.

Returns (prediction, probabilities) where probabilities is a K-length vector summing to 1.0.

pub fn expert_predictions(&self, features: &[f64]) -> Vec<f64>

Get each expert’s individual prediction for a feature vector.

Returns a K-length vector of raw predictions, one per expert.

pub fn n_experts(&self) -> usize

Number of experts in the mixture.

pub fn n_samples_seen(&self) -> u64

Total training samples seen.

pub fn experts(&self) -> &[SGBT<L>]

Immutable access to all experts.

pub fn expert(&self, idx: usize) -> &SGBT<L>

Immutable access to a specific expert.

§Panics

Panics if idx >= n_experts.

pub fn reset(&mut self)

Reset the entire MoE to its initial state.

Resets all experts, clears gate weights and biases back to zeros, and resets the sample counter.

Trait Implementations§

impl<L: Loss + Clone> Clone for MoESGBT<L>

fn clone(&self) -> Self

Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.

impl<L: Loss> Debug for MoESGBT<L>

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl<L: Loss> StreamingLearner for MoESGBT<L>

fn train_one(&mut self, features: &[f64], target: f64, weight: f64)

Train on a single observation with explicit sample weight.

fn predict(&self, features: &[f64]) -> f64

Predict the target for the given feature vector.

fn n_samples_seen(&self) -> u64

Total number of observations trained on since creation or last reset.

fn reset(&mut self)

Reset the model to its initial (untrained) state.

fn train(&mut self, features: &[f64], target: f64)

Train on a single observation with unit weight.

fn predict_batch(&self, feature_matrix: &[&[f64]]) -> Vec<f64>

Predict for each row in a feature matrix.

Auto Trait Implementations§

impl<L> Freeze for MoESGBT<L>
where L: Freeze,

impl<L = SquaredLoss> !RefUnwindSafe for MoESGBT<L>

impl<L> Send for MoESGBT<L>

impl<L> Sync for MoESGBT<L>

impl<L> Unpin for MoESGBT<L>
where L: Unpin,

impl<L> UnsafeUnpin for MoESGBT<L>
where L: UnsafeUnpin,

impl<L = SquaredLoss> !UnwindSafe for MoESGBT<L>

Blanket Implementations§

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.