pub struct SGBT<L: Loss = SquaredLoss> { /* private fields */ }
Streaming Gradient Boosted Trees ensemble.
The primary entry point for training and prediction. Generic over L: Loss
so the loss function’s gradient/hessian calls are monomorphized (inlined)
into the boosting hot loop — no virtual dispatch overhead.
The default type parameter L = SquaredLoss means SGBT::new(config)
creates a regression model without specifying the loss type explicitly.
§Examples
use irithyll::{SGBTConfig, SGBT};

// Regression with squared loss (default):
let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let model = SGBT::new(config);

use irithyll::{SGBTConfig, SGBT};
use irithyll::loss::logistic::LogisticLoss;

// Classification with logistic loss — no Box::new()!
let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let model = SGBT::with_loss(config, LogisticLoss);

Implementations§
impl SGBT<SquaredLoss>

pub fn new(config: SGBTConfig) -> Self
Create a new SGBT ensemble with squared loss (regression).
This is the most common constructor. For classification or custom
losses, use with_loss.
impl<L: Loss> SGBT<L>

pub fn with_loss(config: SGBTConfig, loss: L) -> Self
Create a new SGBT ensemble with a specific loss function.
The loss is stored by value (monomorphized), giving zero-cost gradient/hessian dispatch.
use irithyll::{SGBTConfig, SGBT};
use irithyll::loss::logistic::LogisticLoss;
let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let model = SGBT::with_loss(config, LogisticLoss);

pub fn train_one(&mut self, sample: &impl Observation)
Train on a single observation.
Accepts any type implementing Observation, including Sample,
SampleRef, or tuples like (&[f64], f64) for
zero-copy training.
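As a sketch of the zero-copy tuple form, assuming the builder shown in the examples above:

```rust
use irithyll::{SGBTConfig, SGBT};

let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let mut model = SGBT::new(config);

// A (&[f64], f64) tuple implements Observation, so a single
// update needs no owned Sample allocation.
let features = [1.0, 2.0, 3.0];
model.train_one(&(&features[..], 0.5));
```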
pub fn train_batch<O: Observation>(&mut self, samples: &[O])
Train on a batch of observations.
pub fn train_batch_with_callback<O: Observation, F: FnMut(usize)>(
    &mut self,
    samples: &[O],
    interval: usize,
    callback: F,
)
Train on a batch with periodic callback for cooperative yielding.
The callback is invoked every interval samples with the number of
samples processed so far. This allows long-running training to yield
to other tasks in an async runtime, update progress bars, or perform
periodic checkpointing.
§Example
use irithyll::{SGBTConfig, SGBT};
let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let mut model = SGBT::new(config);
let data: Vec<(Vec<f64>, f64)> = Vec::new(); // your data
model.train_batch_with_callback(&data, 1000, |processed| {
println!("Trained {} samples", processed);
});

pub fn train_batch_subsampled<O: Observation>(
    &mut self,
    samples: &[O],
    max_samples: usize,
)
Train on a random subsample of a batch using reservoir sampling.
When max_samples < samples.len(), selects a representative subset
using Algorithm R (Vitter, 1985) — a uniform random sample without
replacement. The selected samples are then trained in their original
order to preserve sequential dependencies.
This is ideal for large replay buffers where training on the full dataset is prohibitively slow but a representative subset gives equivalent model quality (e.g., 1M of 4.3M samples with R²=0.997).
When max_samples >= samples.len(), all samples are trained.
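A sketch of subsampled training on a large buffer (the buffer contents and the 100_000 cap are illustrative, not values from this crate's docs):

```rust
use irithyll::{SGBTConfig, SGBT};

let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let mut model = SGBT::new(config);

// Stand-in for a large replay buffer.
let buffer: Vec<(Vec<f64>, f64)> = Vec::new();

// Reservoir-samples at most 100_000 observations, then trains
// them in their original order.
model.train_batch_subsampled(&buffer, 100_000);
```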
pub fn train_batch_subsampled_with_callback<O: Observation, F: FnMut(usize)>(
    &mut self,
    samples: &[O],
    max_samples: usize,
    interval: usize,
    callback: F,
)
Train on a batch with both subsampling and periodic callbacks.
Combines reservoir subsampling with cooperative yield points. Ideal for long-running daemon training where you need both efficiency (subsampling) and cooperation (yielding).
pub fn predict_transformed(&self, features: &[f64]) -> f64
Predict with loss transform applied (e.g., sigmoid for logistic loss).
pub fn predict_proba(&self, features: &[f64]) -> f64
Predict probability (alias for predict_transformed).
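For a logistic-loss model the transform is the sigmoid, so the result should land in [0, 1]; a minimal sketch:

```rust
use irithyll::{SGBTConfig, SGBT};
use irithyll::loss::logistic::LogisticLoss;

let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let model = SGBT::with_loss(config, LogisticLoss);

// Sigmoid-transformed score, interpretable as a probability.
let p = model.predict_proba(&[0.5, 1.2]);
assert!((0.0..=1.0).contains(&p));
```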
pub fn predict_with_confidence(&self, features: &[f64]) -> (f64, f64)
Predict with confidence estimation.
Returns (prediction, confidence) where confidence = 1 / sqrt(sum_variance).
Higher confidence indicates more certain predictions (leaves have seen
more hessian mass). Confidence of 0.0 means the model has no information.
This enables execution engines to modulate aggressiveness:
- High confidence + favorable prediction → act immediately
- Low confidence → fall back to simpler models or wait for more data
The variance per tree is estimated as 1 / (H_sum + lambda) at the
leaf where the sample lands. The ensemble variance is the sum of
per-tree variances (scaled by learning_rate²), and confidence is
the reciprocal of the standard deviation.
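The modulation pattern above might be sketched as follows (the threshold is illustrative, not a crate default):

```rust
use irithyll::{SGBTConfig, SGBT};

let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let model = SGBT::new(config);

let (prediction, confidence) = model.predict_with_confidence(&[0.5, 1.2]);

// Illustrative gating threshold; tune per application.
if confidence > 10.0 && prediction > 0.0 {
    // high confidence + favorable prediction: act immediately
} else {
    // low confidence: fall back to a simpler model or wait
}
```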
pub fn total_leaves(&self) -> usize
Total leaves across all active trees.
pub fn n_samples_seen(&self) -> u64
Total samples trained.
pub fn base_prediction(&self) -> f64
The current base prediction.
pub fn is_initialized(&self) -> bool
Whether the base prediction has been initialized.
pub fn config(&self) -> &SGBTConfig
Access the configuration.
pub fn steps(&self) -> &[BoostingStep]
Immutable access to the boosting steps.
Useful for model inspection and export (e.g., ONNX serialization).
pub fn feature_importances(&self) -> Vec<f64>
Feature importances based on accumulated split gains across all trees.
Returns normalized importances (sum to 1.0) indexed by feature. Returns an empty Vec if no splits have occurred yet.
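A sketch of consuming the importances, relying only on the normalization contract stated above:

```rust
use irithyll::{SGBTConfig, SGBT};

let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let model = SGBT::new(config);

let importances = model.feature_importances();
// Either empty (no splits yet) or normalized to sum to 1.0.
if !importances.is_empty() {
    let total: f64 = importances.iter().sum();
    assert!((total - 1.0).abs() < 1e-9);
}
```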
pub fn feature_names(&self) -> Option<&[String]>
Feature names, if configured.
pub fn named_feature_importances(&self) -> Option<Vec<(String, f64)>>
Feature importances paired with their names.
Returns None if feature names are not configured. Otherwise returns
(name, importance) pairs sorted by importance descending.
pub fn train_one_named(&mut self, features: &HashMap<String, f64>, target: f64)
Train on a single sample with named features.
Converts a HashMap<String, f64> of named features into a positional
vector using the configured feature names. Missing features default to 0.0.
§Panics
Panics if feature_names is not configured.
pub fn predict_named(&self, features: &HashMap<String, f64>) -> f64
Predict with named features.
Converts named features into a positional vector, same as train_one_named.
§Panics
Panics if feature_names is not configured.
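A sketch of the named-feature workflow. The feature_names(...) builder call is an assumption (the config API is not shown in this excerpt); adapt it to the real SGBTConfig builder:

```rust
use std::collections::HashMap;
use irithyll::{SGBTConfig, SGBT};

// `feature_names(...)` is a hypothetical builder method here.
let config = SGBTConfig::builder()
    .n_steps(10)
    .feature_names(vec!["spread".to_string(), "volume".to_string()])
    .build()
    .unwrap();
let mut model = SGBT::new(config);

let mut features = HashMap::new();
features.insert("spread".to_string(), 0.8);
// "volume" is absent, so it defaults to 0.0.
model.train_one_named(&features, 1.5);
let y = model.predict_named(&features);
```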
pub fn explain(&self, features: &[f64]) -> ShapValues
Compute per-feature SHAP explanations for a prediction.
Returns ShapValues containing
per-feature contributions and a base value. The invariant holds:
base_value + sum(values) ≈ self.predict(features).
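The additivity invariant can be checked directly; this sketch assumes ShapValues exposes base_value and values fields (names inferred from the text above):

```rust
use irithyll::{SGBTConfig, SGBT};

let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let model = SGBT::new(config);

let x = [0.5, 1.2];
let shap = model.explain(&x);
// base_value and values are assumed field names on ShapValues.
let reconstructed = shap.base_value + shap.values.iter().sum::<f64>();
// reconstructed should approximately equal model.predict(&x)
```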
pub fn explain_named(&self, features: &[f64]) -> Option<NamedShapValues>
Compute named SHAP explanations (requires feature_names configured).
Returns None if feature names are not set. Otherwise returns
NamedShapValues with
(name, contribution) pairs sorted by absolute contribution descending.
pub fn to_model_state(&self) -> Result<ModelState>
Serialize the model into a ModelState.
Auto-detects the LossType from the loss
function’s Loss::loss_type() implementation.
§Errors
Returns IrithyllError::Serialization
if the loss does not implement loss_type() (returns None). For custom
losses, use to_model_state_with instead.
pub fn to_model_state_with(&self, loss_type: LossType) -> ModelState
Serialize the model with an explicit LossType tag.
Use this for custom loss functions that don’t implement loss_type().
impl SGBT<Box<dyn Loss>>

pub fn from_model_state(state: ModelState) -> Self
Reconstruct an SGBT model from a ModelState.
Returns a DynSGBT (SGBT<Box<dyn Loss>>) because the concrete
loss type is determined at runtime from the serialized tag.
Rebuilds the full ensemble including tree topology and leaf values. Histogram accumulators are left empty and will rebuild from continued training. If drift detector state was serialized, it is restored; otherwise a fresh detector is created from the config.
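A save/restore roundtrip sketch. The turbofish pins the dynamically-typed impl, since from_model_state is only defined on SGBT<Box<dyn Loss>> (the Loss trait is assumed to be importable from the crate root):

```rust
use irithyll::{Loss, SGBTConfig, SGBT};

let config = SGBTConfig::builder().n_steps(10).build().unwrap();
let model = SGBT::new(config);

// SquaredLoss implements loss_type(), so auto-detection succeeds.
let state = model.to_model_state().unwrap();

// Reconstruct as a DynSGBT; the concrete loss is chosen at runtime
// from the serialized tag.
let restored = SGBT::<Box<dyn Loss>>::from_model_state(state);
```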