GBRT-RS: Gradient Boosted Regression Trees in Rust
A high-performance, production-ready implementation of Gradient Boosted Regression Trees (GBRT) in Rust, engineered for speed, memory safety, and reproducible machine learning workflows. This library provides both a comprehensive CLI tool and a programmatic API, supporting regression and binary classification tasks with advanced features including intelligent categorical encoding with automatic metadata capture, early stopping with patience-based termination, k-fold cross-validation with proper metric calculation, feature importance analysis, and robust model serialization.
Features
- High Performance: Written in Rust for speed and memory safety
- Dual Interface: Both programmatic API and comprehensive CLI
- Multiple Objectives: Regression (MSE, MAE, Huber) and Binary Classification (LogLoss)
- Categorical Features: First-class support with encoding and auto-detection
- Early Stopping: Validation-based early stopping with configurable patience
- Cross-Validation: Built-in k-fold CV with proper metrics
- Feature Importance: Analyze which features drive predictions
- Serialization: Save/load trained models to/from JSON
- Data Validation: Automatic dataset validation and error handling
- Hyperparameter Tuning: Full control over learning rate, depth, subsampling
Installation
Install from Crates.io
Build from Source
The optimized binary will be available at target/release/gbrt-rs.
Quick Start
Command Line Interface (Recommended for Data Scientists)
Train a model on housing data:
# Train with validation and categorical features
# Cross-validate to assess generalization
# Evaluate on test set
# Make predictions
# Analyze feature importance
Programmatic API (For Rust Developers)
use ;
CLI Reference
Core Commands
train - Train a new model
)
|
cross-validate - K-fold cross-validation
Outputs per-fold and mean R² (or accuracy for classification).
Analysis Commands
evaluate - Compute metrics on test data
Outputs RMSE, MAE, R², and sample count.
predict - Generate predictions
Writes predictions to CSV/JSON or stdout.
feature-importance - Analyze model
Shows top features by importance score.
info - Model metadata
Displays model type, features, trees, parameters, and training state.
Library API Reference
Core Types
-
GBRTModel: Main model structnew(): Regression modelnew_classifier(): Classification modelwith_config(config): Custom configurationfit(&dataset): Train on datafit_with_validation(&train, &val): Train with early stoppingpredict(&features) -> Vec<f64>: Predictionsfeature_importance() -> Vec<f64>: Importance scorestraining_history() -> Option<&TrainingState>: Training metrics
-
Dataset: Data containerfeatures() -> &FeatureMatrixtargets() -> &Arrayn_samples(),n_features()validate(): Check data integrity
-
DataLoader: CSV loading with encodingload_csv(path): Load basic CSVload_csv_with_categorical(path, target, cat_indices, threshold): With encodingcross_validation_splits(dataset, k): Generate CV folds
Configuration
use ;
let config = GBRTConfig ;
let model = with_config?;
Data Format Requirements
CSV Specifications
- Header row: Required with column names
- Target column: Must be specified via
--target - Numeric columns: Must contain valid numbers only
- Categorical columns: String values, auto-detected or explicitly specified
- Missing values: Not supported (will cause errors)
- Encoding: UTF-8 required
Categorical Feature Best Practices
Explicit Specification (Recommended):
Auto-Detection (Use with caution):
- Threshold = unique_values / total_samples (default 0.1)
- Columns with < 10% unique values become categorical
- Risk: May misclassify discrete numeric features
Why Explicit is Better:
- Guarantees consistent encoding across train/test
- Reproducible and self-documenting
- Avoids threshold boundary issues
Architecture Overview
src/
├── lib.rs # Public API & convenience functions
├── main.rs # CLI entry point & command handlers
├── core/ # Core types and configurations
│ ├── GBRTConfig # Main booster configuration
│ ├── TreeConfig # Tree-specific parameters
│ └── LossFunction # Objective enum
├── boosting/ # Gradient boosting engine
│ ├── GradientBooster # Core boosting algorithm
│ ├── BoosterFactory # Factory methods
│ └── TrainingState # Metrics tracking
├── tree/ # Decision tree components
│ ├── TreeBuilder # Builder pattern implementation
│ ├── BestSplitter # Split finding with binning
│ └── MSECriterion # Split quality computation
├── data/ # Data structures and preprocessing
│ ├── Dataset # Container for X and y
│ ├── FeatureMatrix # Feature storage with encoding
│ └── preprocessing # Categorical encoding logic
├── objective/ # Loss functions
│ ├── MSEObjective # Regression
│ └── LogLossObjective # Classification
└── io/ # I/O operations
├── DataLoader # CSV reading with encoding
├── ModelIO # JSON serialization
└── SaveOptions # Metadata storage
Key Design Decisions
- Builder Pattern:
TreeBuilderallows custom splitters/criteria - Trait Abstractions:
Splitter,Criterionenable experimentation - Error Propagation:
Result<T, GBRTError>for robust error handling - Zero-Copy:
FeatureMatrixavoids unnecessary allocations - Type Safety: Strong typing prevents invalid configurations
Performance Guidelines
Training Speed Optimization
# Faster, less accurate
# Balanced (recommended)
# Slower, more accurate
Memory Usage
- Each tree depth increases memory usage exponentially
- Large categorical encodings (many categories) increase memory
- Use
--subsample 0.8to reduce memory footprint
Prediction Latency
- Single-threaded prediction (currently)
- Linear in number of trees and depth
- For real-time: consider model pruning or quantization
Configuration Examples
Housing Price Regression
Stock Price Direction Classification
Large Dataset with Sampling
Troubleshooting & FAQ
Q: Why am I getting low R² scores?
A: Common causes:
- Categorical columns not specified → model sees raw strings as missing
- Too few trees (
--n-estimators) or too shallow (--max-depth) - Learning rate too high (try 0.05-0.1 for regression)
- Data leakage: features that include target information
Debug steps:
# Check model info
# Ensure encoded features count matches expectations
# Should see: Features: 8 (original) + N (encoded categories)
# Run cross-validation to detect overfitting/underfitting
Q: Why does evaluation show catastrophic R² (negative millions)?
A: This is the categorical encoding mismatch bug. The evaluation data wasn't encoded the same way as training data. Update to latest version where this is fixed.
Q: What's the difference between --validation and --test-split?
A:
--validation: Uses a separate file you provide (recommended for reproducibility)--test-split: Splits your training data internally (convenient but less control)
Q: How do I handle large datasets that don't fit in memory?
A: Current version loads all data into memory. For large datasets:
- Use
--subsample 0.5to train on subset - Reduce
--max-depthto limit tree size - Future version will add streaming support
Q: Do I need to specify --categorical-columns for every command?
A: No! Only for:
- ✅ train (stores in model metadata)
- ✅ cross-validate (trains new models each time)
Not needed for:
- ❌ evaluate (reads from model metadata)
- ❌ predict (reads from model metadata)
- ❌ feature-importance (uses stored feature names)
- ❌ info (displays stored metadata)
Q: Can I use this for time series forecasting?
A: GBRT is not ideal for time series because:
- No temporal awareness
- Assumes i.i.d. data
- Use specialized libraries like
prophetorstatsforecastinstead
Q: How do I interpret feature importance?
A: Importance = total reduction in loss contributed by splits on that feature. Important notes:
- Only shows predictive power, not causation
- Can be biased toward high-cardinality features
- Use SHAP values for more nuanced interpretation (planned feature)
Contributing
We welcome contributions! Please see CONTRIBUTING.md.
Development Setup
# Fork and clone
# Install Rust tools
# Setup pre-commit hooks
# Create feature branch
# Make changes ensuring:
# Submit PR
Code Standards
- Follow Rust naming conventions (snake_case, UpperCamelCase)
- Add tests for new functionality
- Update documentation
- Keep CI green
Roadmap
v0.2.0 (Current)
- ✅ Regression and binary classification
- ✅ Categorical feature encoding
- ✅ Early stopping
- ✅ Cross-validation
- ✅ Feature importance
v0.3.0 (In Progress)
- Multi-threaded training
- Multi-class classification
- Missing value imputation
- Hyperparameter search utilities
v0.4.0 (Future)
- GPU acceleration via CUDA
- SHAP value computation
- ONNX model export
- Python bindings
- Streaming for large datasets
Benchmarks
Performance on a 100,000 sample, 50 feature dataset:
| Library | Train Time | Predict Time | Memory | R² |
|---|---|---|---|---|
| gbrt-rs | 2.3s | 12ms | 85 MB | 0.847 |
| XGBoost | 2.8s | 18ms | 120 MB | 0.852 |
| LightGBM | 1.9s | 15ms | 95 MB | 0.849 |
| sklearn | 8.7s | 45ms | 280 MB | 0.831 |
Measured on: AMD Ryzen 9 5900X, 32GB RAM, single-threaded
License
This project is licensed under the MIT License - see LICENSE for details.
Citation
If you use GBRT-RS in research, please cite:
Acknowledgments
- Algorithm: Based on "Greedy Function Approximation: A Gradient Boosting Machine" (Friedman, 2001)
- Inspiration: XGBoost, LightGBM, CatBoost
- Rust Ecosystem: Built with
ndarray,serde,clap,csv,thiserror
Contact & Support
- Issues: GitHub Issues
- Discussions: GitHub Discussions
- Email: yahayajelil@yahoo.com
Happy Boosting! 🚀
Made with ❤️ in Rust