
Struct SGBT 

Source
pub struct SGBT<L: Loss = SquaredLoss> { /* private fields */ }
Available on crate feature alloc only.

Streaming Gradient Boosted Trees ensemble.

The primary entry point for training and prediction. Generic over L: Loss so the loss function’s gradient/hessian calls are monomorphized (inlined) into the boosting hot loop – no virtual dispatch overhead.

The default type parameter L = SquaredLoss means SGBT::new(config) creates a regression model without specifying the loss type explicitly.

§Examples

use irithyll::{SGBTConfig, SGBT};

// Regression with squared loss (default):
let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let model = SGBT::new(config);
use irithyll::{SGBTConfig, SGBT};
use irithyll::loss::logistic::LogisticLoss;

// Classification with logistic loss -- no Box::new()!
let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let model = SGBT::with_loss(config, LogisticLoss);

Implementations§

Source§

impl SGBT<SquaredLoss>

Source

pub fn new(config: SGBTConfig) -> Self

Create a new SGBT ensemble with squared loss (regression).

This is the most common constructor. For classification or custom losses, use with_loss.

Source§

impl<L: Loss> SGBT<L>

Source

pub fn with_loss(config: SGBTConfig, loss: L) -> Self

Create a new SGBT ensemble with a specific loss function.

The loss is stored by value (monomorphized), giving zero-cost gradient/hessian dispatch.

use irithyll::{SGBTConfig, SGBT};
use irithyll::loss::logistic::LogisticLoss;

let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let model = SGBT::with_loss(config, LogisticLoss);
Source

pub fn train_one(&mut self, sample: &impl Observation)

Train on a single observation.

Accepts any type implementing Observation, including Sample, SampleRef, or tuples like (&[f64], f64) for zero-copy training.

Source

pub fn train_batch<O: Observation>(&mut self, samples: &[O])

Train on a batch of observations.

Source

pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
    &mut self,
    samples: &[O],
    interval: usize,
    callback: F,
)

Train on a batch with periodic callback for cooperative yielding.

The callback is invoked every interval samples with the number of samples processed so far. This allows long-running training to yield to other tasks in an async runtime, update progress bars, or perform periodic checkpointing.

§Example
use irithyll::{SGBTConfig, SGBT};

let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let mut model = SGBT::new(config);
let data: Vec<(Vec<f64>, f64)> = Vec::new(); // your data

model.train_batch_with_callback(&data, 1000, |processed| {
    println!("Trained {} samples", processed);
});
Source

pub fn train_batch_subsampled<O: Observation>(
    &mut self,
    samples: &[O],
    max_samples: usize,
)

Train on a random subsample of a batch using reservoir sampling.

When max_samples < samples.len(), selects a representative subset using Algorithm R (Vitter, 1985) – a uniform random sample without replacement. The selected samples are then trained in their original order to preserve sequential dependencies.

This is ideal for large replay buffers where training on the full dataset is prohibitively slow but a representative subset gives equivalent model quality (e.g., 1M of 4.3M samples with R²=0.997).

When max_samples >= samples.len(), all samples are trained.
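The selection step can be sketched in plain Rust, independent of this crate. The reservoir logic follows Algorithm R as described above; the xorshift64 generator is an illustrative stand-in to keep the sketch dependency-free, not the crate's actual RNG:

```rust
// Sketch of Algorithm R (Vitter, 1985): pick `k` indices uniformly at random
// from `0..n` without replacement, then sort so training preserves the
// original sequential order.
fn reservoir_indices(n: usize, k: usize, seed: u64) -> Vec<usize> {
    // xorshift64: a cheap deterministic PRNG, for illustration only.
    let mut state = seed.max(1);
    let mut next = || {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        state
    };
    if k >= n {
        // Subsampling disabled: every index is selected.
        return (0..n).collect();
    }
    // Fill the reservoir with the first k indices.
    let mut reservoir: Vec<usize> = (0..k).collect();
    // Each later index i replaces a random reservoir slot with probability k/(i+1).
    for i in k..n {
        let j = (next() % (i as u64 + 1)) as usize;
        if j < k {
            reservoir[j] = i;
        }
    }
    // Restore original order to preserve sequential dependencies.
    reservoir.sort_unstable();
    reservoir
}

fn main() {
    let idx = reservoir_indices(1000, 10, 42);
    assert_eq!(idx.len(), 10);
    assert!(idx.windows(2).all(|w| w[0] < w[1])); // sorted and unique
    assert!(idx.iter().all(|&i| i < 1000));
    assert_eq!(reservoir_indices(3, 10, 1), vec![0, 1, 2]);
}
```

The final sort is what keeps subsampled training order-preserving, as the method documentation requires.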

Source

pub fn train_batch_subsampled_with_callback<O: Observation, F: FnMut(usize)>(
    &mut self,
    samples: &[O],
    max_samples: usize,
    interval: usize,
    callback: F,
)

Train on a batch with both subsampling and periodic callbacks.

Combines reservoir subsampling with cooperative yield points. Ideal for long-running daemon training where you need both efficiency (subsampling) and cooperation (yielding).

Source

pub fn predict(&self, features: &[f64]) -> f64

Predict the raw output for a feature vector.

Always uses sigmoid-blended soft routing with auto-calibrated per-feature bandwidths derived from median split threshold gaps. Features that have never been split on use hard routing (bandwidth = infinity).
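As a sketch of the routing mechanism at a single split node (plain Rust, independent of this crate; the exact internal blend formula is an assumption, but the `f64::INFINITY`-means-hard-routing convention matches the description above):

```rust
// Sketch of sigmoid-blended soft routing at one split node.
fn route_weight(x: f64, threshold: f64, bandwidth: f64) -> f64 {
    if bandwidth.is_infinite() {
        // Hard routing: step function at the threshold.
        return if x > threshold { 1.0 } else { 0.0 };
    }
    // Soft routing: logistic blend that sharpens as the bandwidth shrinks.
    1.0 / (1.0 + (-(x - threshold) / bandwidth).exp())
}

fn soft_predict(x: f64, threshold: f64, bandwidth: f64, left: f64, right: f64) -> f64 {
    let w = route_weight(x, threshold, bandwidth);
    // w ~ 0 far below the threshold (all left), w ~ 1 far above (all right).
    w * right + (1.0 - w) * left
}

fn main() {
    // Exactly at the threshold, soft routing averages both subtrees.
    assert!((soft_predict(1.0, 1.0, 0.5, 2.0, 4.0) - 3.0).abs() < 1e-12);
    // Infinite bandwidth reduces to hard routing.
    assert_eq!(soft_predict(0.9, 1.0, f64::INFINITY, 2.0, 4.0), 2.0);
}
```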

Source

pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64

Predict using sigmoid-blended soft routing with an explicit bandwidth.

Uses a single bandwidth for all features. For auto-calibrated per-feature bandwidths, use predict() which always uses smooth routing.

Source

pub fn auto_bandwidths(&self) -> &[f64]

Per-feature auto-calibrated bandwidths used by predict().

Empty before the first training sample. Each entry corresponds to a feature index; f64::INFINITY means that feature has no splits and uses hard routing.

Source

pub fn predict_interpolated(&self, features: &[f64]) -> f64

Predict with parent-leaf linear interpolation.

Blends each leaf prediction with its parent’s preserved prediction based on sample count, preventing stale predictions from fresh leaves.
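One plausible shape for this blend, as a standalone sketch: weight the leaf by its accumulated sample count so fresh leaves lean on the parent. Both the `n / (n + n0)` ramp and the constant `N0` are illustrative assumptions, not this crate's actual schedule:

```rust
// Sketch of parent-leaf interpolation: blend a fresh leaf's prediction
// toward its parent's preserved prediction until the leaf has seen
// enough samples to be trusted on its own.
fn interpolated_leaf(leaf_pred: f64, parent_pred: f64, leaf_samples: u64) -> f64 {
    const N0: f64 = 20.0; // hypothetical "trust" scale, in samples
    let alpha = leaf_samples as f64 / (leaf_samples as f64 + N0);
    alpha * leaf_pred + (1.0 - alpha) * parent_pred
}

fn main() {
    // A brand-new leaf returns its parent's prediction.
    assert_eq!(interpolated_leaf(1.0, 0.5, 0), 0.5);
    // At n = N0 the blend is exactly halfway.
    assert!((interpolated_leaf(1.0, 0.5, 20) - 0.75).abs() < 1e-12);
}
```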

Source

pub fn predict_sibling_interpolated(&self, features: &[f64]) -> f64

Predict with sibling-based interpolation for feature-continuous predictions.

At each split node near the threshold boundary, blends left and right subtree predictions linearly based on distance from the threshold. Uses auto-calibrated bandwidths as the interpolation margin. Predictions vary continuously as features change, eliminating step-function artifacts.

Source

pub fn predict_graduated(&self, features: &[f64]) -> f64

Predict with graduated active-shadow blending.

Smoothly transitions between active and shadow trees during replacement, eliminating prediction dips. Requires shadow_warmup to be configured. When disabled, equivalent to predict().
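A minimal sketch of the handoff, assuming a linear ramp over the shadow warmup window (the actual blending schedule is an internal detail; the fall-back-to-`predict()` behavior when warmup is disabled matches the description above):

```rust
// Sketch of graduated active-shadow blending: during a replacement window,
// linearly hand off from the active tree to its warming shadow.
fn graduated_blend(active: f64, shadow: f64, shadow_samples: u64, warmup: u64) -> f64 {
    if warmup == 0 {
        return active; // feature disabled: behaves like predict()
    }
    let t = (shadow_samples as f64 / warmup as f64).min(1.0);
    (1.0 - t) * active + t * shadow
}

fn main() {
    assert_eq!(graduated_blend(2.0, 4.0, 0, 100), 2.0);   // shadow just created
    assert_eq!(graduated_blend(2.0, 4.0, 50, 100), 3.0);  // halfway through warmup
    assert_eq!(graduated_blend(2.0, 4.0, 200, 100), 4.0); // fully handed off
    assert_eq!(graduated_blend(2.0, 4.0, 50, 0), 2.0);    // warmup disabled
}
```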

Source

pub fn predict_graduated_sibling_interpolated(&self, features: &[f64]) -> f64

Predict with graduated blending + sibling interpolation (premium path).

Combines graduated active-shadow handoff (no prediction dips during tree replacement) with feature-continuous sibling interpolation (no step-function artifacts near split boundaries).

Source

pub fn predict_transformed(&self, features: &[f64]) -> f64

Predict with loss transform applied (e.g., sigmoid for logistic loss).

Source

pub fn predict_proba(&self, features: &[f64]) -> f64

Predict probability (alias for predict_transformed).

Source

pub fn predict_with_confidence(&self, features: &[f64]) -> (f64, f64)

Predict with confidence estimation.

Returns (prediction, confidence) where confidence = 1 / sqrt(sum_variance). Higher confidence indicates more certain predictions (leaves have seen more hessian mass). Confidence of 0.0 means the model has no information.

This enables execution engines to modulate aggressiveness:

  • High confidence + favorable prediction → act immediately
  • Low confidence → fall back to simpler models or wait for more data

The variance per tree is estimated as 1 / (H_sum + lambda) at the leaf where the sample lands. The ensemble variance is the sum of per-tree variances (scaled by learning_rate²), and confidence is the reciprocal of the standard deviation.
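The formula above can be written out directly. This standalone sketch takes the per-tree hessian sums at the landed leaves as input; the function name and signature are illustrative:

```rust
// Sketch of the confidence computation: per-tree variance 1 / (H_sum + lambda),
// summed across trees with a learning_rate^2 scale, and
// confidence = 1 / sqrt(total variance).
fn confidence(leaf_hessian_sums: &[f64], lambda: f64, learning_rate: f64) -> f64 {
    let total_variance: f64 = leaf_hessian_sums
        .iter()
        .map(|h| learning_rate * learning_rate / (h + lambda))
        .sum();
    if total_variance == 0.0 {
        0.0 // no trees: the model has no information
    } else {
        1.0 / total_variance.sqrt()
    }
}

fn main() {
    // One tree with H_sum = 99, lambda = 1: variance = 0.01, confidence = 10.
    assert!((confidence(&[99.0], 1.0, 1.0) - 10.0).abs() < 1e-9);
    // Untrained model reports zero confidence.
    assert_eq!(confidence(&[], 1.0, 0.1), 0.0);
}
```

More hessian mass at the leaves shrinks each per-tree variance, so confidence grows as the relevant leaves see more data.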

Source

pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64>

Batch prediction.

Source

pub fn n_steps(&self) -> usize

Number of boosting steps.

Source

pub fn n_trees(&self) -> usize

Total trees (active + alternates).

Source

pub fn total_leaves(&self) -> usize

Total leaves across all active trees.

Source

pub fn n_samples_seen(&self) -> u64

Total samples trained.

Source

pub fn base_prediction(&self) -> f64

The current base prediction.

Source

pub fn is_initialized(&self) -> bool

Whether the base prediction has been initialized.

Source

pub fn config(&self) -> &SGBTConfig

Access the configuration.

Source

pub fn set_learning_rate(&mut self, lr: f64)

Set the learning rate for future boosting rounds.

This allows external schedulers (e.g., lr_schedule::LRScheduler) to adapt the rate over time without rebuilding the model.

§Arguments
  • lr – New learning rate (should be positive and finite)
Source

pub fn steps(&self) -> &[BoostingStep]

Immutable access to the boosting steps.

Useful for model inspection and export (e.g., ONNX serialization).

Source

pub fn loss(&self) -> &L

Immutable access to the loss function.

Source

pub fn feature_importances(&self) -> Vec<f64>

Feature importances based on accumulated split gains across all trees.

Returns normalized importances (sum to 1.0) indexed by feature. Returns an empty Vec if no splits have occurred yet.
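The normalization step is straightforward; as a standalone sketch over pre-accumulated per-feature gains (the accumulation itself happens inside the trees):

```rust
// Sketch of gain-based importances: given accumulated split gain per feature,
// normalize so the importances sum to 1.0. Returns an empty Vec when no
// splits have occurred, matching the behavior described above.
fn normalize_importances(gains: &[f64]) -> Vec<f64> {
    let total: f64 = gains.iter().sum();
    if total <= 0.0 {
        return Vec::new(); // no splits yet
    }
    gains.iter().map(|g| g / total).collect()
}

fn main() {
    assert_eq!(normalize_importances(&[3.0, 1.0]), vec![0.75, 0.25]);
    assert!(normalize_importances(&[]).is_empty());
    assert!(normalize_importances(&[0.0, 0.0]).is_empty());
}
```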

Source

pub fn feature_names(&self) -> Option<&[String]>

Feature names, if configured.

Source

pub fn named_feature_importances(&self) -> Option<Vec<(String, f64)>>

Feature importances paired with their names.

Returns None if feature names are not configured. Otherwise returns (name, importance) pairs sorted by importance descending.

Source

pub fn train_one_named(&mut self, features: &HashMap<String, f64>, target: f64)

Available on crate feature std only.

Train on a single sample with named features.

Converts a HashMap<String, f64> of named features into a positional vector using the configured feature names. Missing features default to 0.0.

§Panics

Panics if feature_names is not configured.
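The conversion itself can be sketched in plain Rust (the helper name is illustrative; the missing-feature-defaults-to-0.0 behavior matches the documentation above):

```rust
use std::collections::HashMap;

// Sketch of the named-feature conversion: map a HashMap of named features
// onto a positional vector using the configured name order, with missing
// features defaulting to 0.0.
fn to_positional(names: &[String], features: &HashMap<String, f64>) -> Vec<f64> {
    names
        .iter()
        .map(|name| features.get(name).copied().unwrap_or(0.0))
        .collect()
}

fn main() {
    let names = vec!["price".to_string(), "volume".to_string()];
    let mut feats = HashMap::new();
    feats.insert("volume".to_string(), 2.5);
    // "price" is absent, so it defaults to 0.0.
    assert_eq!(to_positional(&names, &feats), vec![0.0, 2.5]);
}
```

predict_named applies the same conversion before calling the positional predict path.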

Source

pub fn predict_named(&self, features: &HashMap<String, f64>) -> f64

Available on crate feature std only.

Predict with named features.

Converts named features into a positional vector, same as train_one_named.

§Panics

Panics if feature_names is not configured.

Source

pub fn reset(&mut self)

Reset the ensemble to initial state.

Trait Implementations§

Source§

impl<L: Loss + Clone> Clone for SGBT<L>

Source§

fn clone(&self) -> Self

Returns a duplicate of the value. Read more
1.0.0 · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl<L: Loss> Debug for SGBT<L>

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more

Auto Trait Implementations§

§

impl<L> Freeze for SGBT<L>
where L: Freeze,

§

impl<L = SquaredLoss> !RefUnwindSafe for SGBT<L>

§

impl<L> Send for SGBT<L>

§

impl<L> Sync for SGBT<L>

§

impl<L> Unpin for SGBT<L>
where L: Unpin,

§

impl<L> UnsafeUnpin for SGBT<L>
where L: UnsafeUnpin,

§

impl<L = SquaredLoss> !UnwindSafe for SGBT<L>

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.