MesaTEE GBDT-RS: a fast and secure GBDT library, supporting TEEs such as Intel SGX and ARM TrustZone
MesaTEE GBDT-RS is a gradient boosting decision tree library written in safe Rust. There is no unsafe Rust code in the library.
MesaTEE GBDT-RS provides both training and inference capabilities. It can also run inference with models trained by xgboost.
New! The MesaTEE GBDT-RS paper has been accepted by IEEE S&P'19!
Supported Tasks
Supported tasks for both training and inference
- Linear regression: use SquaredError and LAD loss types
- Binary classification (labeled with 1 and -1): use LogLikelyhood loss type
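To make the three loss types more concrete, here is a minimal sketch of the negative gradients (pseudo-residuals) that a gradient boosting round fits a tree to. The function names are hypothetical, and the formulas are the common textbook forms (the log-likelihood one follows Friedman's formulation for labels 1 and -1); this is an assumption about what the loss names mean, not the library's actual code.

```rust
/// Negative gradient of squared error 0.5 * (y - f)^2 with respect to f:
/// the plain residual.
fn squared_error_residual(y: f64, f: f64) -> f64 {
    y - f
}

/// Negative gradient of least absolute deviation |y - f|:
/// the sign of the residual.
fn lad_residual(y: f64, f: f64) -> f64 {
    if y > f {
        1.0
    } else if y < f {
        -1.0
    } else {
        0.0
    }
}

/// Negative gradient of ln(1 + exp(-2 * y * f)) for y in {-1, 1}
/// (assumed textbook form, not taken from the library).
fn log_likelihood_residual(y: f64, f: f64) -> f64 {
    2.0 * y / (1.0 + (2.0 * y * f).exp())
}

fn main() {
    // Label 3.0, current prediction 2.5.
    println!("squared error: {}", squared_error_residual(3.0, 2.5));
    println!("LAD:           {}", lad_residual(3.0, 2.5));
    // Label 1, current raw score 0.0.
    println!("log-likelihood: {}", log_likelihood_residual(1.0, 0.0));
}
```

LAD reacts only to the direction of the error, which is why it is usually less sensitive to outliers than squared error.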
Compatibility with xgboost
At this time, MesaTEE GBDT-RS can run inference with models trained by xgboost. The model should be trained by xgboost with the following configuration:
- booster: gbtree
- objective: "reg:linear", "reg:logistic", "binary:logistic", "binary:logitraw", "multi:softprob", "multi:softmax" or "rank:pairwise".
We have tested MesaTEE GBDT-RS for compatibility with xgboost 0.81 and 0.82.
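The objective matters at inference time because xgboost maps the summed tree outputs (the raw margin) to the final prediction differently per objective. The sketch below illustrates the usual transforms with hypothetical helper names: "reg:logistic" and "binary:logistic" apply a sigmoid, "multi:softprob" applies a softmax over per-class margins, "multi:softmax" returns the index of the largest margin, and "reg:linear", "binary:logitraw" and "rank:pairwise" return the raw value. It illustrates the objectives only and is not how MesaTEE GBDT-RS implements them.

```rust
/// Sigmoid transform used by "reg:logistic" and "binary:logistic".
fn sigmoid(margin: f64) -> f64 {
    1.0 / (1.0 + (-margin).exp())
}

/// Softmax over per-class margins, as used by "multi:softprob".
/// Subtracting the maximum keeps the exponentials numerically stable.
fn softmax(margins: &[f64]) -> Vec<f64> {
    let max = margins.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = margins.iter().map(|m| (m - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Index of the largest margin, as returned by "multi:softmax".
/// Expects a non-empty slice.
fn argmax(values: &[f64]) -> usize {
    let mut best = 0;
    for (i, v) in values.iter().enumerate() {
        if *v > values[best] {
            best = i;
        }
    }
    best
}

fn main() {
    let margins = [0.1, 2.0, -1.0];
    println!("binary:logistic -> {}", sigmoid(margins[0]));
    println!("multi:softprob  -> {:?}", softmax(&margins));
    println!("multi:softmax   -> class {}", argmax(&margins));
}
```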
Quick Start
Training Steps
- Set configuration
- Load training data
- Train the model
- (optional) Save the model
Inference Steps
- Load the model
- Load the test data
- Run inference on the test data
Example
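use gbdt::config::Config;
use gbdt::decision_tree::{DataVec, PredVec};
use gbdt::gradient_boost::GBDT;
use gbdt::input::{InputFormat, load};
// set configuration (the dataset has 22 features)
let mut cfg = Config::new();
cfg.set_feature_size(22);
cfg.set_max_depth(3);
cfg.set_iterations(50);
cfg.set_shrinkage(0.1);
cfg.set_loss("LogLikelyhood");
cfg.set_debug(true);
cfg.set_data_sample_ratio(1.0);
cfg.set_feature_sample_ratio(1.0);
cfg.set_training_optimization_level(2);
// load data
let train_file = "dataset/agaricus-lepiota/train.txt";
let test_file = "dataset/agaricus-lepiota/test.txt";
let mut input_format = InputFormat::csv_format();
input_format.set_feature_size(22);
input_format.set_label_index(22);
let mut train_dv: DataVec = load(train_file, input_format).expect("failed to load training data");
let test_dv: DataVec = load(test_file, input_format).expect("failed to load test data");
// train and save model
let mut gbdt = GBDT::new(&cfg);
gbdt.fit(&mut train_dv);
gbdt.save_model("gbdt.model").expect("failed to save the model");
// load model and do inference
let model = GBDT::load_model("gbdt.model").expect("failed to load the model");
let predicted: PredVec = model.predict(&test_dv);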
Example code
- Linear regression: examples/iris.rs
- Binary classification: examples/agaricus-lepiota.rs
Use models trained by xgboost
Steps
- Use xgboost to train a model
- Use examples/convert_xgboost.py to convert the model
- Usage: python convert_xgboost.py xgboost_model_path objective output_path
- Note: convert_xgboost.py depends on the xgboost Python libraries. The converted model can be used on machines without xgboost.
- In Rust code, call GBDT::load_from_xgboost(model_path, objective) to load the model
- Do inference
- (optional) Call GBDT::save_model to save the model to MesaTEE GBDT-RS native format.
Example code
- "reg:linear": examples/test-xgb-reg-linear.rs
- "reg:logistic": examples/test-xgb-reg-logistic.rs
- "binary:logistic": examples/test-xgb-binary-logistic.rs
- "binary:logitraw": examples/test-xgb-binary-logitraw.rs
- "multi:softprob": examples/test-xgb-multi-softprob.rs
- "multi:softmax": examples/test-xgb-multi-softmax.rs
- "rank:pairwise": examples/test-xgb-rank-pairwise.rs
Multi-threading
Training:
At this time, training in MesaTEE GBDT-RS is single-threaded.
Inference:
The inference functions are single-threaded, but they are thread-safe. We provide an example of multi-threaded inference in examples/test-multithreads.rs.
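Since inference only needs shared read access to the model, one common pattern is to wrap the model in an `Arc` and let each thread predict on its own slice of the data. The sketch below shows that pattern with a hypothetical stand-in `Model` type (a simple weighted sum) so that it is self-contained; with the real library, a loaded GBDT model would take its place, assuming its prediction method takes `&self`.

```rust
use std::sync::Arc;
use std::thread;

/// Hypothetical stand-in for a trained model. Prediction only needs `&self`,
/// so the model can be shared between threads without locking.
struct Model {
    weights: Vec<f32>,
}

impl Model {
    /// Toy prediction: weighted sum of the features of each row.
    fn predict(&self, rows: &[Vec<f32>]) -> Vec<f32> {
        rows.iter()
            .map(|r| r.iter().zip(&self.weights).map(|(x, w)| x * w).sum::<f32>())
            .collect()
    }
}

/// Splits `rows` into at most `n_threads` chunks, predicts each chunk on its
/// own thread, and returns the predictions in the original row order.
fn parallel_predict(model: Arc<Model>, rows: Vec<Vec<f32>>, n_threads: usize) -> Vec<f32> {
    let n_threads = n_threads.max(1);
    let chunk_size = ((rows.len() + n_threads - 1) / n_threads).max(1);
    let handles: Vec<_> = rows
        .chunks(chunk_size)
        .map(|chunk| {
            let model = Arc::clone(&model);
            let chunk = chunk.to_vec();
            thread::spawn(move || model.predict(&chunk))
        })
        .collect();
    // Joining in spawn order keeps the output aligned with the input rows.
    handles
        .into_iter()
        .flat_map(|h| h.join().expect("worker thread panicked"))
        .collect()
}

fn main() {
    let model = Arc::new(Model { weights: vec![1.0, 2.0] });
    let rows = vec![vec![1.0, 1.0], vec![2.0, 0.0], vec![0.0, 3.0]];
    let predictions = parallel_predict(model, rows, 2);
    println!("{:?}", predictions);
}
```

Each worker copies its chunk of rows to keep the example short; scoped threads would avoid the copy.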
SGX usage
Because MesaTEE GBDT-RS is written in pure Rust, it can be used inside an SGX enclave with the help of rust-sgx-sdk. Add the following dependency:
gbdt_sgx = { git = "https://github.com/mesalock-linux/gbdt-rs" }
This imports a crate named gbdt_sgx. If you prefer to refer to it as gbdt, as in the non-SGX case:
gbdt = { package = "gbdt_sgx", git = "https://github.com/mesalock-linux/gbdt-rs" }
For more information and concrete examples, see the directory sgx/gbdt-sgx-test.
License
Apache 2.0
Authors
Tianyi Li @n0b0dyCN n0b0dypku@gmail.com
Tongxin Li @litongxin1991 litongxin1991@gmail.com
Yu Ding @dingelish dingelish@gmail.com
Steering Committee
Tao Wei, Yulong Zhang
Acknowledgment
Thanks to @qiyiping for his/her great previous work gbdt. We read his/her code before starting this project.