Crate gbrt_rs


Gradient Boosted Regression Trees (GBRT) in Rust

A high-performance, feature-rich implementation of gradient boosting for machine learning tasks including regression and binary classification.

This library provides a complete gradient boosting framework with:

  • Multiple objective functions (MSE, MAE, Huber, Log Loss)
  • Categorical feature encoding with auto-detection
  • Early stopping with validation support
  • Cross-validation with proper categorical handling
  • Feature importance analysis
  • Model serialization with metadata
  • Comprehensive error handling with unified error types

§Architecture

The library is organized into modular components:

  • core: Core configurations and loss functions
  • boosting: Gradient boosting algorithms and training loop
  • tree: Decision tree construction with pluggable split criteria
  • data: Dataset management and feature preprocessing
  • objective: Objective functions for optimization
  • utils: Mathematical utilities, validation, and serialization
  • io: Data loading and model persistence

§Examples

Basic regression example:

use gbrt_rs::{BoosterFactory, DataLoader};

// Load data from CSV
let data_loader = DataLoader::new()?;
let dataset = data_loader.load_csv("data/training.csv")?;

// Train regression model
let mut booster = BoosterFactory::create_regression_booster()?;
booster.fit(&dataset, None)?;

// Make predictions
let predictions = booster.predict(dataset.features())?;

Training with validation and early stopping:

use gbrt_rs::{DataLoader, GBRTConfig, GBRTModel};

// Configure model
let config = GBRTConfig {
    n_estimators: 100,
    learning_rate: 0.1,
    max_depth: 6,
    early_stopping_rounds: Some(10),
    ..Default::default()
};

let mut model = GBRTModel::with_config(config)?;

// Split data for validation
let data_loader = DataLoader::new()?;
let dataset = data_loader.load_csv("data.csv")?;
let (train_data, val_data) = data_loader.split_data(&dataset, 0.2, true, None)?;

// Train with early stopping
model.fit_with_validation(&train_data, &val_data)?;

// Evaluate
let predictions = model.predict(val_data.features())?;
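
Using the convenience helpers:

The free functions listed under Functions below can shorten the flows above. A minimal sketch; the argument lists are assumptions inferred from the one-line descriptions, not confirmed signatures:

use gbrt_rs::{cross_validate, train_from_csv, DataLoader, GBRTConfig};

// Hypothetical signature: train a model directly from a CSV file.
let model = train_from_csv("data/training.csv", GBRTConfig::default())?;

// Hypothetical signature: k-fold cross-validation over a loaded dataset,
// returning one score per fold.
let dataset = DataLoader::new()?.load_csv("data/training.csv")?;
let fold_scores = cross_validate(&dataset, GBRTConfig::default(), 5)?;
println!("mean CV score: {}", fold_scores.iter().sum::<f64>() / fold_scores.len() as f64);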
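
Unified error handling:

Every fallible call above returns the crate's Result alias, so errors raised in different modules (DataLoaderError, BoostingError, TreeError, and so on) surface through the single GBRTError type. A sketch of propagating and reporting such errors, assuming the usual From conversions behind the ? operator and a Display impl on GBRTError:

use gbrt_rs::{DataLoader, Result};

fn load_training_data(path: &str) -> Result<()> {
    let data_loader = DataLoader::new()?;
    // `?` converts the module-level DataLoaderError into the crate-wide
    // GBRTError (a From impl is assumed here).
    let _dataset = data_loader.load_csv(path)?;
    Ok(())
}

if let Err(err) = load_training_data("data/training.csv") {
    eprintln!("failed to load training data: {err}");
}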

Re-exports§

pub use boosting::GradientBooster;
pub use boosting::BoosterFactory;
pub use boosting::create_booster;
pub use boosting::BoostingError;
pub use boosting::BoostingResult;
pub use boosting::BoostingType;
pub use boosting::TrainingState;
pub use boosting::IterationState;
pub use data::Dataset;
pub use data::FeatureMatrix;
pub use data::DataError;
pub use data::DataResult;
pub use data::Scaler;
pub use data::StandardScaler;
pub use data::MinMaxScaler;
pub use objective::Objective;
pub use objective::RegressionObjective;
pub use objective::BinaryClassificationObjective;
pub use objective::ObjectiveFactory;
pub use objective::create_objective;
pub use objective::ObjectiveType;
pub use objective::ObjectiveError;
pub use objective::ObjectiveResult;
pub use objective::ObjectiveConfig;
pub use objective::MSEObjective;
pub use objective::MAEObjective;
pub use objective::HuberObjective;
pub use objective::LogLossObjective;
pub use tree::DecisionTree;
pub use tree::TreeBuilder;
pub use tree::TreeError;
pub use tree::TreeResult;
pub use tree::Splitter;
pub use tree::BestSplitter;
pub use tree::SplitCriterion;
pub use tree::MSECriterion;
pub use utils::MathError;
pub use utils::MathResult;
pub use utils::Statistics;
pub use utils::VectorMath;
pub use utils::ValidationError;
pub use utils::ValidationResult;
pub use utils::DataValidator;
pub use utils::SerializationError;
pub use utils::SerializationResult;
pub use utils::ModelSerializer;
pub use io::DataLoader;
pub use io::DataLoaderError;
pub use io::DataLoaderResult;
pub use io::DataFormat;
pub use io::ModelIO;
pub use io::ModelIOError;
pub use io::ModelIOResult;
pub use io::ModelFormat;
pub use io::LoadOptions as DataLoadOptions;
pub use io::SaveOptions as ModelSaveOptions;

Modules§

boosting
Gradient boosting algorithms and ensemble training.
core
Core configurations and loss functions for gradient boosting.
data
Dataset handling, feature matrices, and preprocessing.
io
Data loading and model persistence.
objective
Objective functions for gradient boosting.
preprocessing
Data preprocessing utilities.
tree
Decision tree construction and split criteria.
utils
Mathematical utilities and validation helpers.

Structs§

FeatureAnalysis
Feature importance analysis with named features.
FeatureSummary
Summary statistics for feature importance analysis.
GBRTModel
Main GBRT model struct that provides a high-level API for training and prediction.
ModelMetrics
Container for common model evaluation metrics.

Enums§

GBRTError
Unified error type for the entire crate.

Functions§

cross_validate
Performs k-fold cross-validation for model evaluation.
train_from_csv
Trains a GBRT model on data from a CSV file.
train_with_validation_split
Trains a GBRT model with train/validation split.
version
Returns the crate version string.

Type Aliases§

Result
Result type alias for the crate, using GBRTError.