Gradient Boosted Regression Trees (GBRT) in Rust
A high-performance, feature-rich implementation of gradient boosting for machine learning tasks including regression and binary classification.
This library provides a complete gradient boosting framework with:
- Multiple objective functions (MSE, MAE, Huber, Log Loss)
- Categorical feature encoding with auto-detection
- Early stopping with validation support
- Cross-validation with proper categorical handling
- Feature importance analysis
- Model serialization with metadata
- Comprehensive error handling with unified error types
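To illustrate how these objectives differ, here is a minimal, self-contained sketch of the Huber loss and its negative gradient (the pseudo-residual used during boosting). This is illustrative only; the crate's HuberObjective may use a different signature and default delta.

```rust
/// Huber loss: quadratic near zero (like MSE), linear in the tails (like MAE).
fn huber_loss(residual: f64, delta: f64) -> f64 {
    let a = residual.abs();
    if a <= delta {
        0.5 * residual * residual
    } else {
        delta * (a - 0.5 * delta)
    }
}

/// Negative gradient of the Huber loss w.r.t. the prediction,
/// i.e. the pseudo-residual a boosting iteration fits a tree to.
fn huber_negative_gradient(residual: f64, delta: f64) -> f64 {
    if residual.abs() <= delta {
        residual
    } else {
        delta * residual.signum()
    }
}

fn main() {
    // Inside the quadratic region the loss behaves like MSE...
    assert_eq!(huber_loss(0.5, 1.0), 0.125);
    // ...and in the tails it grows linearly, capping the gradient at delta.
    assert_eq!(huber_loss(3.0, 1.0), 2.5);
    assert_eq!(huber_negative_gradient(3.0, 1.0), 1.0);
}
```

The capped gradient is what makes Huber more robust to outliers than plain MSE.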
§Architecture
The library is organized into modular components:
- core: Core configurations and loss functions
- boosting: Gradient boosting algorithms and training loop
- tree: Decision tree construction with pluggable split criteria
- data: Dataset management and feature preprocessing
- objective: Objective functions for optimization
- utils: Mathematical utilities, validation, and serialization
- io: Data loading and model persistence
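To make the boosting module's training loop concrete, the following self-contained sketch shows its basic shape, using a constant predictor in place of a real decision tree at each iteration. The function boost_constants and its parameters are illustrative, not part of the crate's API.

```rust
// Minimal sketch of a gradient boosting training loop with an MSE
// objective. Each "weak learner" is just a constant fit to the
// pseudo-residuals, shrunk by the learning rate before being added
// to the ensemble; the real GradientBooster fits decision trees.
fn boost_constants(targets: &[f64], n_estimators: usize, learning_rate: f64) -> f64 {
    let n = targets.len() as f64;
    let mut prediction = 0.0; // shared initial prediction for all samples
    for _ in 0..n_estimators {
        // Pseudo-residuals for MSE: negative gradient = y - f(x).
        let mean_residual: f64 =
            targets.iter().map(|y| y - prediction).sum::<f64>() / n;
        // Shrink the weak learner's output and add it to the ensemble.
        prediction += learning_rate * mean_residual;
    }
    prediction
}

fn main() {
    let y = [1.0, 2.0, 3.0];
    let f = boost_constants(&y, 200, 0.1);
    // With enough iterations the ensemble approaches the target mean (2.0).
    assert!((f - 2.0).abs() < 1e-6);
}
```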
§Examples
Basic regression example:
use gbrt_rs::{
    GradientBooster, BoosterFactory, Dataset, FeatureMatrix,
    DataLoader, ObjectiveType
};
// Load data from CSV
let data_loader = DataLoader::new()?;
let dataset = data_loader.load_csv("data/training.csv")?;
// Train regression model
let mut booster = BoosterFactory::create_regression_booster()?;
booster.fit(&dataset, None)?;
// Make predictions
let predictions = booster.predict(dataset.features())?;

Training with validation and early stopping:
use gbrt_rs::{
    GBRTModel, Dataset, DataLoader, GBRTConfig,
    LossFunction, ObjectiveType
};
// Configure model
let config = GBRTConfig {
    n_estimators: 100,
    learning_rate: 0.1,
    max_depth: 6,
    early_stopping_rounds: Some(10),
    ..Default::default()
};
let mut model = GBRTModel::with_config(config)?;
// Split data for validation
let data_loader = DataLoader::new()?;
let dataset = data_loader.load_csv("data.csv")?;
let (train_data, val_data) = data_loader.split_data(&dataset, 0.2, true, None)?;
// Train with early stopping
model.fit_with_validation(&train_data, &val_data)?;
// Evaluate
let predictions = model.predict(val_data.features())?;

Re-exports§
pub use boosting::GradientBooster;
pub use boosting::BoosterFactory;
pub use boosting::create_booster;
pub use boosting::BoostingError;
pub use boosting::BoostingResult;
pub use boosting::BoostingType;
pub use boosting::TrainingState;
pub use boosting::IterationState;
pub use data::Dataset;
pub use data::FeatureMatrix;
pub use data::DataError;
pub use data::DataResult;
pub use data::Scaler;
pub use data::StandardScaler;
pub use data::MinMaxScaler;
pub use objective::Objective;
pub use objective::RegressionObjective;
pub use objective::BinaryClassificationObjective;
pub use objective::ObjectiveFactory;
pub use objective::create_objective;
pub use objective::ObjectiveType;
pub use objective::ObjectiveError;
pub use objective::ObjectiveResult;
pub use objective::ObjectiveConfig;
pub use objective::MSEObjective;
pub use objective::MAEObjective;
pub use objective::HuberObjective;
pub use objective::LogLossObjective;
pub use tree::DecisionTree;
pub use tree::TreeBuilder;
pub use tree::TreeError;
pub use tree::TreeResult;
pub use tree::Splitter;
pub use tree::BestSplitter;
pub use tree::SplitCriterion;
pub use tree::MSECriterion;
pub use utils::MathError;
pub use utils::MathResult;
pub use utils::Statistics;
pub use utils::VectorMath;
pub use utils::ValidationError;
pub use utils::ValidationResult;
pub use utils::DataValidator;
pub use utils::SerializationError;
pub use utils::SerializationResult;
pub use utils::ModelSerializer;
pub use io::DataLoader;
pub use io::DataLoaderError;
pub use io::DataLoaderResult;
pub use io::DataFormat;
pub use io::ModelIO;
pub use io::ModelIOError;
pub use io::ModelIOResult;
pub use io::ModelFormat;
pub use io::LoadOptions as DataLoadOptions;
pub use io::SaveOptions as ModelSaveOptions;
Modules§
- boosting
- Gradient boosting algorithms and model training.
- core
- Core configurations and loss functions for gradient boosting.
- data
- Data handling and preprocessing for gradient boosting.
- io
- Input/output operations for data loading and model persistence.
- objective
- Objective functions for gradient boosting.
- preprocessing
- Data preprocessing utilities.
- tree
- Decision tree implementations for gradient boosting.
- utils
- Utility modules for common operations in gradient boosting.
Structs§
- FeatureAnalysis
- Feature importance analysis with named features.
- FeatureSummary
- Summary statistics for feature importance analysis.
- GBRTModel
- Main GBRT model struct that provides a high-level API for training and prediction.
- ModelMetrics
- Container for common model evaluation metrics.
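The kinds of metrics such a container holds can be computed as in this self-contained sketch; the functions rmse and mae below are illustrative helpers, not the crate's API.

```rust
// Root mean squared error over paired true/predicted values.
fn rmse(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len() as f64;
    let mse: f64 = y_true.iter().zip(y_pred)
        .map(|(t, p)| (t - p).powi(2))
        .sum::<f64>() / n;
    mse.sqrt()
}

// Mean absolute error over paired true/predicted values.
fn mae(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len() as f64;
    y_true.iter().zip(y_pred)
        .map(|(t, p)| (t - p).abs())
        .sum::<f64>() / n
}

fn main() {
    let y = [1.0, 2.0, 3.0];
    let p = [1.0, 2.0, 5.0];
    // Errors are 0, 0, 2: MAE = 2/3, RMSE = sqrt(4/3).
    assert!((mae(&y, &p) - 2.0 / 3.0).abs() < 1e-12);
    assert!((rmse(&y, &p) - (4.0f64 / 3.0).sqrt()).abs() < 1e-12);
}
```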
Enums§
- GBRTError
- Unified error type for the entire crate.
Functions§
- cross_validate
- Performs k-fold cross-validation for model evaluation.
- train_from_csv
- Trains a GBRT model on data from a CSV file.
- train_with_validation_split
- Trains a GBRT model with a train/validation split.
- version
- Returns the crate version string.
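The index bookkeeping behind k-fold cross-validation can be sketched as follows. kfold_indices is an illustrative helper, not the crate's function; cross_validate itself also handles shuffling and the per-fold categorical handling noted above.

```rust
// Partition sample indices 0..n_samples into k (train, validation)
// folds. Each sample lands in exactly one validation fold; the last
// fold absorbs any remainder when n_samples is not divisible by k.
fn kfold_indices(n_samples: usize, k: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    let fold_size = n_samples / k;
    let mut folds = Vec::with_capacity(k);
    for fold in 0..k {
        let start = fold * fold_size;
        let end = if fold == k - 1 { n_samples } else { start + fold_size };
        let val: Vec<usize> = (start..end).collect();
        // Training indices are everything outside the validation slice.
        let train: Vec<usize> =
            (0..n_samples).filter(|i| *i < start || *i >= end).collect();
        folds.push((train, val));
    }
    folds
}

fn main() {
    let folds = kfold_indices(10, 3);
    assert_eq!(folds.len(), 3);
    // Every sample appears in exactly one validation fold.
    let total_val: usize = folds.iter().map(|(_, v)| v.len()).sum();
    assert_eq!(total_val, 10);
    // Train and validation indices never overlap within a fold.
    assert!(folds.iter().all(|(t, v)| v.iter().all(|i| !t.contains(i))));
}
```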
Type Aliases§
- Result
- Result type alias for the crate using GBRTError.