Struct GradientBooster

pub struct GradientBooster { /* private fields */ }

Main gradient boosting implementation.

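GradientBooster is the core ensemble model: it trains a sequence of decision trees, each fitted to reduce the loss left by the trees before it, to minimize a given loss function. It supports several loss functions, shrinkage and L2 regularization, row subsampling, and early stopping, and it records per-iteration training metrics and feature importance scores.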

§Key Features

  • Loss Functions: MSE, MAE, Huber (robust regression), LogLoss (binary classification)
  • Regularization: Learning rate shrinkage, L2 regularization (lambda)
  • Stochastic Boosting: Row subsampling for variance reduction
  • Early Stopping: Configurable patience and improvement tolerance
  • Feature Importance: Gain-based importance scores
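
Every loss listed above enters training through its negative gradient with respect to the current raw score, and each new tree is fitted to these pseudo-residuals. The following is a minimal sketch of those gradients under the conventions noted in the comments (MSE written as 0.5 * (y - f)^2, LogLoss on 0/1 labels with a logit score). The Loss enum and its negative_gradient method are hypothetical and are not this crate's API.

```rust
/// Hypothetical loss enum mirroring the losses listed above; the crate's
/// actual type and variant names may differ.
#[derive(Clone, Copy, Debug)]
enum Loss {
    Mse,
    Mae,
    Huber { delta: f64 },
    LogLoss,
}

fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

impl Loss {
    /// Negative gradient of the loss w.r.t. the raw score `f` for one sample
    /// with target `y` (the "pseudo-residual" the next tree is fitted to).
    fn negative_gradient(self, y: f64, f: f64) -> f64 {
        let r = y - f;
        match self {
            // L = 0.5 * (y - f)^2  =>  -dL/df = y - f
            Loss::Mse => r,
            // L = |y - f|  =>  -dL/df = sign(y - f), taken as 0 at the kink
            Loss::Mae => if r > 0.0 { 1.0 } else if r < 0.0 { -1.0 } else { 0.0 },
            // Quadratic for |r| <= delta, linear beyond it (gradient clipped)
            Loss::Huber { delta } => if r.abs() <= delta { r } else { delta * r.signum() },
            // y in {0, 1}, f is a logit  =>  -dL/df = y - sigmoid(f)
            Loss::LogLoss => y - sigmoid(f),
        }
    }
}

fn main() {
    let (y, f) = (3.0, 1.0);
    for loss in [Loss::Mse, Loss::Mae, Loss::Huber { delta: 1.5 }] {
        println!("{:?}: {}", loss, loss.negative_gradient(y, f));
    }
    println!("LogLoss: {}", Loss::LogLoss.negative_gradient(1.0, 0.0));
}
```

Note how Huber behaves like MSE for small residuals but caps the gradient at delta for large ones, which is what makes it robust to outliers.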

§Serialization

This struct derives Serialize and Deserialize, allowing trained models to be saved and loaded using serde-compatible formats.

Implementations§

impl GradientBooster

pub fn new(config: GBRTConfig) -> BoostingResult<Self>

Creates a new gradient booster with the given configuration.

Validates the configuration before returning. The returned booster is untrained; call fit() to train it.

§Parameters
  • config: GBRTConfig with hyperparameters and training options.
§Errors

Returns BoostingError::ConfigError if configuration validation fails.

pub fn fit(&mut self, train_data: &Dataset, validation_data: Option<&Dataset>) -> BoostingResult<()>

Trains the model on a dataset with optional validation data.

This is the main training entry point. It performs the complete gradient boosting algorithm including:

  • Data validation and initialization
  • Iterative tree building with gradient optimization
  • Optional stochastic subsampling
  • Validation loss monitoring and early stopping
  • Feature importance computation
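
The training steps listed above can be illustrated with a deliberately small, self-contained sketch: squared-error loss, depth-1 trees (stumps), shrinkage by a learning rate, and an L2 penalty lambda on leaf values. It omits subsampling, validation, early stopping, and importance tracking. Stump, fit_stump, fit, and predict are hypothetical names, and the leaf formula sum / (count + lambda) is the common L2-regularized form for squared error, assumed here rather than taken from this crate's source.

```rust
/// Hypothetical depth-1 regression tree: one threshold split on one feature.
#[derive(Debug, Clone)]
struct Stump {
    feature: usize,
    threshold: f64,
    left: f64,  // leaf value when x[feature] <= threshold
    right: f64, // leaf value when x[feature] > threshold
}

impl Stump {
    fn predict(&self, x: &[f64]) -> f64 {
        if x[self.feature] <= self.threshold { self.left } else { self.right }
    }
}

/// Fit the stump that best explains `residuals`. With squared error and an
/// L2 penalty, the leaf value is sum / (count + lambda) and a split is scored
/// by sum_l^2 / (n_l + lambda) + sum_r^2 / (n_r + lambda).
fn fit_stump(x: &[Vec<f64>], residuals: &[f64], lambda: f64) -> Stump {
    let total: f64 = residuals.iter().sum();
    let n = residuals.len();
    let mut best_score = f64::NEG_INFINITY;
    // Fallback if no split can be scored: a tree that always outputs 0.
    let mut best = Stump { feature: 0, threshold: f64::INFINITY, left: 0.0, right: 0.0 };
    for feature in 0..x[0].len() {
        // Try every observed value of this feature as a threshold.
        for candidate in x {
            let threshold = candidate[feature];
            let mut sum_l = 0.0_f64;
            let mut n_l = 0usize;
            for (row, &r) in x.iter().zip(residuals) {
                if row[feature] <= threshold {
                    sum_l += r;
                    n_l += 1;
                }
            }
            let (sum_r, n_r) = (total - sum_l, n - n_l);
            let (den_l, den_r) = (n_l as f64 + lambda, n_r as f64 + lambda);
            let score = sum_l * sum_l / den_l + sum_r * sum_r / den_r;
            if score > best_score {
                best_score = score;
                best = Stump { feature, threshold, left: sum_l / den_l, right: sum_r / den_r };
            }
        }
    }
    best
}

/// Train `n_estimators` stumps with shrinkage `learning_rate`.
/// Returns the initial constant prediction and the fitted trees.
fn fit(x: &[Vec<f64>], y: &[f64], n_estimators: usize, learning_rate: f64, lambda: f64) -> (f64, Vec<Stump>) {
    // Start from the mean target, the best constant under squared error.
    let base = y.iter().sum::<f64>() / y.len() as f64;
    let mut pred = vec![base; y.len()];
    let mut trees = Vec::with_capacity(n_estimators);
    for _ in 0..n_estimators {
        // Pseudo-residuals: the negative gradient of 0.5 * (y - f)^2 is y - f.
        let residuals: Vec<f64> = y.iter().zip(&pred).map(|(yi, pi)| yi - pi).collect();
        let tree = fit_stump(x, &residuals, lambda);
        // Shrinkage: add only a fraction of each tree's output.
        for (p, row) in pred.iter_mut().zip(x) {
            *p += learning_rate * tree.predict(row);
        }
        trees.push(tree);
    }
    (base, trees)
}

fn predict(base: f64, trees: &[Stump], learning_rate: f64, x: &[f64]) -> f64 {
    base + learning_rate * trees.iter().map(|t| t.predict(x)).sum::<f64>()
}

fn main() {
    // Toy data: one feature, target jumps from 1.0 to 5.0 at x = 10.
    let x: Vec<Vec<f64>> = (0..20).map(|i| vec![i as f64]).collect();
    let y: Vec<f64> = x.iter().map(|r| if r[0] < 10.0 { 1.0 } else { 5.0 }).collect();
    let (base, trees) = fit(&x, &y, 50, 0.1, 1.0);
    let (lo, hi) = (predict(base, &trees, 0.1, &[2.0]), predict(base, &trees, 0.1, &[15.0]));
    println!("f(2) = {:.3}, f(15) = {:.3}", lo, hi);
}
```

With a learning rate of 0.1, each stump closes only part of the remaining gap, so the predictions approach the targets gradually over the 50 rounds instead of jumping there in one step.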
§Parameters
  • train_data: Training dataset with features and targets.
  • validation_data: Optional validation dataset for early stopping and monitoring.
§Errors

Returns errors for:

  • Invalid training data (empty, mismatched dimensions)
  • Configuration issues detected during training
  • Tree building failures
  • Loss computation problems (e.g., NaN values)
§Early Stopping

If config.early_stopping_rounds is set and validation_data is provided, training stops once the validation loss has not improved by more than config.early_stopping_tolerance for config.early_stopping_rounds rounds.
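
A minimal sketch of the patience/tolerance rule described above, assuming that improvement is measured against the best validation loss seen so far (rather than against the previous round). EarlyStopping and its fields are hypothetical; patience stands in for config.early_stopping_rounds and tolerance for config.early_stopping_tolerance.

```rust
/// Hypothetical early-stopping tracker (not this crate's type).
struct EarlyStopping {
    patience: usize, // plays the role of config.early_stopping_rounds
    tolerance: f64,  // plays the role of config.early_stopping_tolerance
    best_loss: f64,
    best_iteration: usize,
    rounds_without_improvement: usize,
}

impl EarlyStopping {
    fn new(patience: usize, tolerance: f64) -> Self {
        Self {
            patience,
            tolerance,
            best_loss: f64::INFINITY,
            best_iteration: 0,
            rounds_without_improvement: 0,
        }
    }

    /// Record this iteration's validation loss; returns true when training should stop.
    fn update(&mut self, iteration: usize, val_loss: f64) -> bool {
        // Only a drop of more than `tolerance` below the best loss counts as improvement.
        if val_loss < self.best_loss - self.tolerance {
            self.best_loss = val_loss;
            self.best_iteration = iteration;
            self.rounds_without_improvement = 0;
        } else {
            self.rounds_without_improvement += 1;
        }
        self.rounds_without_improvement >= self.patience
    }
}

fn main() {
    let losses = [1.0, 0.8, 0.7, 0.69995, 0.7, 0.71, 0.72];
    let mut es = EarlyStopping::new(3, 1e-3);
    for (i, &loss) in losses.iter().enumerate() {
        if es.update(i, loss) {
            println!("stopped at iteration {i}; best iteration {} (loss {})", es.best_iteration, es.best_loss);
            break;
        }
    }
}
```

In the sample sequence, the loss of 0.69995 at iteration 3 is lower than 0.7 but not by more than the tolerance, so it counts as a round without improvement; training stops at iteration 5, and iteration 2 is recorded as the best.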

pub fn predict(&self, features: &FeatureMatrix) -> BoostingResult<Vec<f64>>

Makes predictions for a batch of samples.

Computes each prediction by summing the contributions of all trees, each scaled by the learning rate (shrinkage), and then applies a loss-specific transformation (e.g., the sigmoid for LogLoss, which maps the raw score to a probability).
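
The prediction path can be sketched as a shrunken sum followed by an output transform. This sketch assumes the ensemble starts from a constant base_score and that each tree's output for the sample has already been computed; raw_score, apply_transform, and OutputTransform are hypothetical names, not this crate's API.

```rust
/// Hypothetical post-processing applied to the summed score.
#[derive(Clone, Copy, Debug)]
enum OutputTransform {
    Identity, // regression losses (MSE, MAE, Huber)
    Sigmoid,  // LogLoss: map a logit to a probability in (0, 1)
}

/// Raw ensemble score for one sample: the initial constant plus each tree's
/// output scaled by the learning rate (shrinkage).
fn raw_score(base_score: f64, learning_rate: f64, tree_outputs: &[f64]) -> f64 {
    base_score + learning_rate * tree_outputs.iter().sum::<f64>()
}

fn apply_transform(t: OutputTransform, raw: f64) -> f64 {
    match t {
        OutputTransform::Identity => raw,
        OutputTransform::Sigmoid => 1.0 / (1.0 + (-raw).exp()),
    }
}

fn main() {
    // Outputs of three hypothetical trees for a single sample.
    let tree_outputs = [1.0, -0.5, 1.5];
    let raw = raw_score(0.0, 0.5, &tree_outputs);
    println!(
        "raw = {raw}, regression = {}, probability = {:.4}",
        apply_transform(OutputTransform::Identity, raw),
        apply_transform(OutputTransform::Sigmoid, raw)
    );
}
```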

§Parameters
  • features: Feature matrix with shape (n_samples, n_features).
§Returns

Vector of predictions of length n_samples.

§Errors
  • PredictionError::ModelNotTrained if called before fit()
  • PredictionError::FeatureMismatch if features.n_features() doesn’t match training data

pub fn predict_single(&self, features: &[f64]) -> BoostingResult<f64>

Makes a prediction for a single sample.

More efficient than predict() for single-sample inference.

§Parameters
  • features: Slice of feature values of length n_features.
§Returns

Single prediction value.

§Errors
  • PredictionError::ModelNotTrained if called before fit()
  • PredictionError::FeatureMismatch if features.len() doesn’t match training data

pub fn feature_importance(&self) -> &[f64]

Returns feature importance scores from the trained model.

Importance is computed as the normalized total gain contributed by splits on each feature across all trees. Scores sum to 1.0.
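
A minimal sketch of gain-based importance, assuming the per-split gains have already been collected as (feature_index, gain) pairs from every tree. gain_importance is a hypothetical helper; returning all zeros when the total gain is zero is a choice made here to avoid dividing by zero, not documented behavior of this crate.

```rust
/// Sum split gains per feature, then normalize so the scores sum to 1.0.
/// `splits` is a hypothetical flat list of (feature_index, gain) pairs
/// gathered from every split in every tree.
fn gain_importance(splits: &[(usize, f64)], n_features: usize) -> Vec<f64> {
    let mut totals = vec![0.0; n_features];
    for &(feature, gain) in splits {
        totals[feature] += gain;
    }
    let sum: f64 = totals.iter().sum();
    // Guard against division by zero when no split contributed any gain.
    if sum > 0.0 {
        for t in &mut totals {
            *t /= sum;
        }
    }
    totals
}

fn main() {
    // Feature 0 is split twice (gains 3.0 and 1.0), feature 2 once (gain 1.0).
    let splits = [(0, 3.0), (2, 1.0), (0, 1.0)];
    println!("{:?}", gain_importance(&splits, 3));
}
```

Feature 1 never appears in a split, so its score is zero even though it is part of the input.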

§Returns

Slice of importance scores of length n_features. Returns zeros if config.compute_feature_importance is false.

§Panics

Panics if called before the model has been trained, because the number of features is not yet known.

pub fn training_state(&self) -> Option<&TrainingState>

Returns the training history and validation state.

Contains per-iteration metrics and the best iteration if early stopping was used. This is useful for plotting learning curves or analyzing training dynamics.

§Returns

Some(&TrainingState) after training; None before training.

pub fn n_trees(&self) -> usize

Returns the number of trees in the ensemble.

Note: this may be fewer than config.n_estimators if early stopping was triggered.

§Returns

Number of trees after training.

pub fn config(&self) -> &GBRTConfig

Returns the configuration used to create this booster.

pub fn is_trained(&self) -> bool

Checks whether the model has been trained.

Trait Implementations§

impl Clone for GradientBooster

fn clone(&self) -> GradientBooster

Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.

impl Debug for GradientBooster

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl<'de> Deserialize<'de> for GradientBooster

fn deserialize<__D>(__deserializer: __D) -> Result<Self, __D::Error>
where __D: Deserializer<'de>,

Deserialize this value from the given Serde deserializer.

impl Display for GradientBooster

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl Serialize for GradientBooster

fn serialize<__S>(&self, __serializer: __S) -> Result<__S::Ok, __S::Error>
where __S: Serializer,

Serialize this value into the given Serde serializer.

Auto Trait Implementations§

Blanket Implementations§

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

This is a nightly-only experimental API (clone_to_uninit).
Performs copy-assignment from self to dest.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of the pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.

impl<T> ToString for T
where T: Display + ?Sized,

fn to_string(&self) -> String

Converts the given value to a String.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V

impl<T> Allocation for T
where T: RefUnwindSafe + Send + Sync,

impl<T> DeserializeOwned for T
where T: for<'de> Deserialize<'de>,