Struct SGBT

pub struct SGBT<L: Loss = SquaredLoss> { /* private fields */ }

Streaming Gradient Boosted Trees ensemble.

The primary entry point for training and prediction. The struct is generic over L: Loss, so the loss function's gradient/hessian calls are monomorphized (inlined) into the boosting hot loop, with no virtual dispatch overhead.

The default type parameter L = SquaredLoss means SGBT::new(config) creates a regression model without specifying the loss type explicitly.
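To make the monomorphization point concrete, here is a stand-alone sketch (simplified stand-in types, not the crate's actual definitions) of how a generic Loss bound with a default type parameter yields static dispatch:

```rust
// Stand-in trait and types for illustration only.
trait Loss {
    fn gradient(&self, pred: f64, target: f64) -> f64;
}

struct SquaredLoss;
impl Loss for SquaredLoss {
    fn gradient(&self, pred: f64, target: f64) -> f64 {
        pred - target // derivative of 0.5 * (pred - target)^2 w.r.t. pred
    }
}

// Default type parameter mirrors `SGBT<L: Loss = SquaredLoss>`.
struct Ensemble<L: Loss = SquaredLoss> {
    loss: L,
}

impl<L: Loss> Ensemble<L> {
    fn step(&self, pred: f64, target: f64) -> f64 {
        // Statically dispatched: the compiler can inline this call,
        // unlike a call through Box<dyn Loss>, which uses a vtable.
        self.loss.gradient(pred, target)
    }
}

fn main() {
    let model = Ensemble { loss: SquaredLoss };
    assert_eq!(model.step(3.0, 1.0), 2.0);
    println!("gradient = {}", model.step(3.0, 1.0));
}
```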

§Examples

use irithyll::{SGBTConfig, SGBT};

// Regression with squared loss (default):
let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let model = SGBT::new(config);

use irithyll::{SGBTConfig, SGBT};
use irithyll::loss::logistic::LogisticLoss;

// Classification with logistic loss — no Box::new()!
let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let model = SGBT::with_loss(config, LogisticLoss);

Implementations§

impl SGBT<SquaredLoss>

pub fn new(config: SGBTConfig) -> Self

Create a new SGBT ensemble with squared loss (regression).

This is the most common constructor. For classification or custom losses, use with_loss.

impl<L: Loss> SGBT<L>

pub fn with_loss(config: SGBTConfig, loss: L) -> Self

Create a new SGBT ensemble with a specific loss function.

The loss is stored by value (monomorphized), giving zero-cost gradient/hessian dispatch.

use irithyll::{SGBTConfig, SGBT};
use irithyll::loss::logistic::LogisticLoss;

let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let model = SGBT::with_loss(config, LogisticLoss);

pub fn train_one(&mut self, sample: &impl Observation)

Train on a single observation.

Accepts any type implementing Observation, including Sample, SampleRef, or tuples like (&[f64], f64) for zero-copy training.

pub fn train_batch<O: Observation>(&mut self, samples: &[O])

Train on a batch of observations.

pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
    &mut self,
    samples: &[O],
    interval: usize,
    callback: F,
)

Train on a batch with periodic callback for cooperative yielding.

The callback is invoked every interval samples with the number of samples processed so far. This allows long-running training to yield to other tasks in an async runtime, update progress bars, or perform periodic checkpointing.

§Example
use irithyll::{SGBTConfig, SGBT};

let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let mut model = SGBT::new(config);
let data: Vec<(Vec<f64>, f64)> = Vec::new(); // your data

model.train_batch_with_callback(&data, 1000, |processed| {
    println!("Trained {} samples", processed);
});

pub fn train_batch_subsampled<O: Observation>(
    &mut self,
    samples: &[O],
    max_samples: usize,
)

Train on a random subsample of a batch using reservoir sampling.

When max_samples < samples.len(), this selects a representative subset using Algorithm R (Vitter, 1985), a uniform random sample without replacement. The selected samples are then trained in their original order to preserve sequential dependencies.

This is ideal for large replay buffers where training on the full dataset is prohibitively slow but a representative subset gives equivalent model quality (e.g., 1M of 4.3M samples with R²=0.997).

When max_samples >= samples.len(), all samples are trained.
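The selection scheme described above can be sketched in isolation. This is an illustrative reimplementation of Algorithm R over sample indices (not the crate's code, and with a stand-in deterministic "RNG" closure for the demonstration):

```rust
// Pick k of n indices uniformly without replacement (Algorithm R),
// then sort so the chosen samples keep their original training order.
fn reservoir_indices(n: usize, k: usize, mut rand: impl FnMut(usize) -> usize) -> Vec<usize> {
    if k >= n {
        return (0..n).collect(); // all samples are trained
    }
    let mut reservoir: Vec<usize> = (0..k).collect();
    for i in k..n {
        // Index i survives with probability k / (i + 1).
        let j = rand(i + 1);
        if j < k {
            reservoir[j] = i;
        }
    }
    reservoir.sort_unstable(); // restore sequential order
    reservoir
}

fn main() {
    // Deterministic stand-in for a real RNG, for demonstration only.
    let mut state: u64 = 7;
    let rng = move |bound: usize| {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        (state % bound as u64) as usize
    };
    let picked = reservoir_indices(10, 4, rng);
    assert_eq!(picked.len(), 4);
    assert!(picked.windows(2).all(|w| w[0] < w[1])); // sorted, distinct
    println!("{:?}", picked);
}
```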

pub fn train_batch_subsampled_with_callback<O: Observation, F: FnMut(usize)>(
    &mut self,
    samples: &[O],
    max_samples: usize,
    interval: usize,
    callback: F,
)

Train on a batch with both subsampling and periodic callbacks.

Combines reservoir subsampling with cooperative yield points. Ideal for long-running daemon training where you need both efficiency (subsampling) and cooperation (yielding).

pub fn predict(&self, features: &[f64]) -> f64

Predict the raw output for a feature vector.

pub fn predict_transformed(&self, features: &[f64]) -> f64

Predict with loss transform applied (e.g., sigmoid for logistic loss).

pub fn predict_proba(&self, features: &[f64]) -> f64

Predict probability (alias for predict_transformed).

pub fn predict_with_confidence(&self, features: &[f64]) -> (f64, f64)

Predict with confidence estimation.

Returns (prediction, confidence) where confidence = 1 / sqrt(sum_variance). Higher confidence indicates more certain predictions (leaves have seen more hessian mass). Confidence of 0.0 means the model has no information.

This enables execution engines to modulate aggressiveness:

  • High confidence + favorable prediction → act immediately
  • Low confidence → fall back to simpler models or wait for more data

The variance per tree is estimated as 1 / (H_sum + lambda) at the leaf where the sample lands. The ensemble variance is the sum of per-tree variances (scaled by learning_rate²), and confidence is the reciprocal of the standard deviation.
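As a numeric illustration of the formula above (the function name and the input values are invented for the example; `h_sums` stands for the hessian mass at the leaf each tree routes the sample to):

```rust
// confidence = 1 / sqrt(sum over trees of learning_rate^2 / (H_sum + lambda))
fn confidence_from_leaves(h_sums: &[f64], lambda: f64, learning_rate: f64) -> f64 {
    let sum_variance: f64 = h_sums
        .iter()
        .map(|h| learning_rate * learning_rate / (h + lambda))
        .sum();
    if sum_variance > 0.0 {
        1.0 / sum_variance.sqrt()
    } else {
        0.0 // no information yet
    }
}

fn main() {
    let c = confidence_from_leaves(&[50.0, 80.0, 120.0], 1.0, 0.1);
    assert!(c > 0.0);
    // More hessian mass per leaf means higher confidence.
    let c2 = confidence_from_leaves(&[500.0, 800.0, 1200.0], 1.0, 0.1);
    assert!(c2 > c);
    println!("confidence = {c:.1}");
}
```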

pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64>

Batch prediction.

pub fn n_steps(&self) -> usize

Number of boosting steps.

pub fn n_trees(&self) -> usize

Total trees (active + alternates).

pub fn total_leaves(&self) -> usize

Total leaves across all active trees.

pub fn n_samples_seen(&self) -> u64

Total samples trained.

pub fn base_prediction(&self) -> f64

The current base prediction.

pub fn is_initialized(&self) -> bool

Whether the base prediction has been initialized.

pub fn config(&self) -> &SGBTConfig

Access the configuration.

pub fn steps(&self) -> &[BoostingStep]

Immutable access to the boosting steps.

Useful for model inspection and export (e.g., ONNX serialization).

pub fn loss(&self) -> &L

Immutable access to the loss function.

pub fn feature_importances(&self) -> Vec<f64>

Feature importances based on accumulated split gains across all trees.

Returns normalized importances (sum to 1.0) indexed by feature. Returns an empty Vec if no splits have occurred yet.
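The normalization contract can be illustrated with a stand-alone sketch (invented gains, not the crate's internals): accumulated split gains per feature are divided by their total so the result sums to 1.0, and an empty Vec is returned before any splits.

```rust
// Normalize raw per-feature split gains into importances summing to 1.0.
fn normalize_importances(gains: &[f64]) -> Vec<f64> {
    let total: f64 = gains.iter().sum();
    if total <= 0.0 {
        return Vec::new(); // no splits have occurred yet
    }
    gains.iter().map(|g| g / total).collect()
}

fn main() {
    let importances = normalize_importances(&[3.0, 1.0, 0.0, 4.0]);
    let sum: f64 = importances.iter().sum();
    assert!((sum - 1.0).abs() < 1e-12);
    assert_eq!(importances[3], 0.5); // 4.0 of 8.0 total gain
    println!("{:?}", importances);
}
```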

pub fn feature_names(&self) -> Option<&[String]>

Feature names, if configured.

pub fn named_feature_importances(&self) -> Option<Vec<(String, f64)>>

Feature importances paired with their names.

Returns None if feature names are not configured. Otherwise returns (name, importance) pairs sorted by importance descending.

pub fn train_one_named(&mut self, features: &HashMap<String, f64>, target: f64)

Train on a single sample with named features.

Converts a HashMap<String, f64> of named features into a positional vector using the configured feature names. Missing features default to 0.0.

§Panics

Panics if feature_names is not configured.
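The conversion described above amounts to a name-to-position lookup with a 0.0 default for missing features. A minimal stand-alone sketch (assumed logic with invented feature names, not the crate's code):

```rust
use std::collections::HashMap;

// Map named features onto the positional layout given by `names`;
// names absent from the map default to 0.0.
fn to_positional(names: &[&str], features: &HashMap<String, f64>) -> Vec<f64> {
    names
        .iter()
        .map(|name| features.get(*name).copied().unwrap_or(0.0))
        .collect()
}

fn main() {
    let names = ["bid", "ask", "volume"];
    let mut features = HashMap::new();
    features.insert("bid".to_string(), 101.5);
    features.insert("ask".to_string(), 101.7);
    // "volume" is missing and defaults to 0.0.
    let positional = to_positional(&names, &features);
    assert_eq!(positional, vec![101.5, 101.7, 0.0]);
    println!("{:?}", positional);
}
```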

pub fn predict_named(&self, features: &HashMap<String, f64>) -> f64

Predict with named features.

Converts named features into a positional vector, same as train_one_named.

§Panics

Panics if feature_names is not configured.

pub fn explain(&self, features: &[f64]) -> ShapValues

Compute per-feature SHAP explanations for a prediction.

Returns ShapValues containing per-feature contributions and a base value. The invariant holds: base_value + sum(values) ≈ self.predict(features).
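The additivity invariant can be checked numerically; the values below are invented purely to illustrate it.

```rust
// SHAP additivity: base value plus all per-feature contributions
// reconstructs the raw prediction.
fn reconstruct(base_value: f64, shap_values: &[f64]) -> f64 {
    base_value + shap_values.iter().sum::<f64>()
}

fn main() {
    let base_value = 0.25;
    let contributions = [0.40, -0.10, 0.05]; // per-feature SHAP values
    let prediction = 0.60; // what predict() would return for this input
    assert!((reconstruct(base_value, &contributions) - prediction).abs() < 1e-9);
    println!("reconstructed = {:.2}", reconstruct(base_value, &contributions));
}
```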

pub fn explain_named(&self, features: &[f64]) -> Option<NamedShapValues>

Compute named SHAP explanations (requires feature_names configured).

Returns None if feature names are not set. Otherwise returns NamedShapValues with (name, contribution) pairs sorted by absolute contribution descending.

pub fn reset(&mut self)

Reset the ensemble to initial state.

pub fn to_model_state(&self) -> Result<ModelState>

Serialize the model into a ModelState.

Auto-detects the LossType from the loss function’s Loss::loss_type() implementation.

§Errors

Returns IrithyllError::Serialization if the loss does not report a LossType (i.e., its loss_type() returns None). For custom losses, use to_model_state_with instead.

pub fn to_model_state_with(&self, loss_type: LossType) -> ModelState

Serialize the model with an explicit LossType tag.

Use this for custom loss functions that don’t implement loss_type().

impl SGBT<Box<dyn Loss>>

pub fn from_model_state(state: ModelState) -> Self

Reconstruct an SGBT model from a ModelState.

Returns a DynSGBT (SGBT<Box<dyn Loss>>) because the concrete loss type is determined at runtime from the serialized tag.

Rebuilds the full ensemble including tree topology and leaf values. Histogram accumulators are left empty and will rebuild from continued training. If drift detector state was serialized, it is restored; otherwise a fresh detector is created from the config.

Trait Implementations§

impl<L: Loss + Clone> Clone for SGBT<L>

fn clone(&self) -> Self

Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.

impl<L: Loss> Debug for SGBT<L>

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

Auto Trait Implementations§

impl<L> Freeze for SGBT<L>
where L: Freeze,

impl<L = SquaredLoss> !RefUnwindSafe for SGBT<L>

impl<L> Send for SGBT<L>

impl<L> Sync for SGBT<L>

impl<L> Unpin for SGBT<L>
where L: Unpin,

impl<L> UnsafeUnpin for SGBT<L>
where L: UnsafeUnpin,

impl<L = SquaredLoss> !UnwindSafe for SGBT<L>

Blanket Implementations§

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬 This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T> Instrument for T

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper.

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<T> WithSubscriber for T

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper.

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper.