use serde_json::Value;
use std::os::raw::{c_char, c_longlong, c_void};
use std::{convert::TryInto, ffi::CString};
use crate::{dataset::DType, Dataset, Error, Result};
use lightgbm3_sys::BoosterHandle;
/// Safe wrapper around a LightGBM `BoosterHandle`, caching model metadata.
pub struct Booster {
    // Raw C handle owned by this struct; released in `Drop`.
    handle: BoosterHandle,
    // Number of input features, as reported by `LGBM_BoosterGetNumFeature`.
    n_features: i32,
    // Iterations present in the model when it was loaded/trained.
    n_iterations: i32,
    // Cap on iterations used at prediction time (defaults to `n_iterations`).
    max_iterations: i32,
    // Number of output classes, as reported by `LGBM_BoosterGetNumClasses`.
    n_classes: i32,
}
/// Kind of prediction requested from the C API; converted to a
/// `C_API_PREDICT_*` constant via `From<PredictType> for i32`.
enum PredictType {
    // Regular prediction output.
    Normal,
    // Raw (untransformed) model scores.
    RawScore,
    // Feature contributions; produces `n_features + 1` values per row
    // (see the output sizing in `real_predict`).
    Contrib,
}
/// Feature-importance flavor; converted to a `C_API_FEATURE_IMPORTANCE_*`
/// constant via `From<ImportanceType> for i32`.
pub enum ImportanceType {
    // Maps to `C_API_FEATURE_IMPORTANCE_SPLIT`.
    Split,
    // Maps to `C_API_FEATURE_IMPORTANCE_GAIN`.
    Gain,
}
impl Booster {
fn new(handle: BoosterHandle) -> Result<Self> {
let mut booster = Booster {
handle,
n_features: 0,
n_iterations: 0,
max_iterations: 0,
n_classes: 0,
};
booster.n_features = booster.inner_num_features()?;
booster.n_iterations = booster.inner_num_iterations()?;
booster.max_iterations = booster.n_iterations;
booster.n_classes = booster.inner_num_classes()?;
Ok(booster)
}
pub fn from_file(filename: &str) -> Result<Self> {
let filename_str = CString::new(filename).unwrap();
let mut out_num_iterations = 0;
let mut handle = std::ptr::null_mut();
lgbm_call!(lightgbm3_sys::LGBM_BoosterCreateFromModelfile(
filename_str.as_ptr(),
&mut out_num_iterations,
&mut handle
))?;
Booster::new(handle)
}
pub fn from_string(model_description: &str) -> Result<Self> {
let cstring = CString::new(model_description).unwrap();
let mut out_num_iterations = 0;
let mut handle = std::ptr::null_mut();
lgbm_call!(lightgbm3_sys::LGBM_BoosterLoadModelFromString(
cstring.as_ptr(),
&mut out_num_iterations,
&mut handle
))?;
Booster::new(handle)
}
pub fn save_file(&self, filename: &str) -> Result<()> {
let filename_str = CString::new(filename).unwrap();
lgbm_call!(lightgbm3_sys::LGBM_BoosterSaveModel(
self.handle,
0_i32,
-1_i32,
0_i32,
filename_str.as_ptr(),
))?;
Ok(())
}
    /// Serializes the model to a `String` (same format as [`Booster::save_file`]).
    pub fn save_string(&self) -> Result<String> {
        // First pass: a zero-length buffer makes the C API report the
        // required size (including the trailing NUL) in `out_size`.
        let mut out_size = 0_i64;
        lgbm_call!(lightgbm3_sys::LGBM_BoosterSaveModelToString(
            self.handle,
            0_i32,
            -1_i32,
            0_i32,
            0,
            &mut out_size,
            std::ptr::null_mut(),
        ))?;
        // Allocate exactly the reported size; a negative size is a C-API bug.
        let mut buffer = vec![
            0u8;
            out_size
                .try_into()
                .map_err(|_| Error::new("size negative"))?
        ];
        // Second pass: let the C API fill the buffer with the model text.
        lgbm_call!(lightgbm3_sys::LGBM_BoosterSaveModelToString(
            self.handle,
            0_i32,
            -1_i32,
            0_i32,
            buffer.len() as c_longlong,
            &mut out_size,
            buffer.as_mut_ptr() as *mut c_char
        ))?;
        // The final byte must be the NUL terminator written by LightGBM;
        // anything else means the C library wrote past the buffer we sized
        // for it — a broken invariant, hence the panic.
        if buffer.pop() != Some(0) {
            panic!("write out of bounds happened in lightgbm call");
        }
        // An interior NUL would mean the model text was truncated; surface it.
        let cstring = CString::new(buffer).map_err(|e| Error::new(e.to_string()))?;
        cstring
            .into_string()
            .map_err(|_| Error::new("can't convert model string to unicode"))
    }
    /// Number of output classes reported by the model.
    pub fn num_classes(&self) -> i32 {
        self.n_classes
    }
    /// Number of input features the model expects.
    pub fn num_features(&self) -> i32 {
        self.n_features
    }
    /// Number of boosting iterations in the model.
    pub fn num_iterations(&self) -> i32 {
        self.n_iterations
    }
    /// Current cap on iterations used at prediction time
    /// (initially equal to [`Booster::num_iterations`]).
    pub fn max_iterations(&self) -> i32 {
        self.max_iterations
    }
pub fn set_max_iterations(&mut self, max_iterations: i32) -> Result<()> {
if max_iterations > self.n_iterations {
return Err(Error::new(format!(
"max_iterations for prediction ({max_iterations})\
should not exceed the number of trees in the booster ({})",
self.n_iterations
)));
}
self.max_iterations = max_iterations;
Ok(())
}
    /// Trains a new booster on `dataset` using JSON `parameters`;
    /// delegates to [`Booster::train_with_valid`] with no validation set.
    pub fn train(dataset: Dataset, parameters: &Value) -> Result<Self> {
        Self::train_with_valid(dataset, None, parameters)
    }
pub fn train_with_valid(
dataset: Dataset,
valid_dataset: Option<Dataset>,
parameters: &Value,
) -> Result<Self> {
let num_iterations: i64 = parameters["num_iterations"].as_i64().unwrap_or(100);
let early_stopping_rounds: Option<i64> = parameters["early_stopping_rounds"].as_i64();
let params_string = parameters
.as_object()
.unwrap()
.iter()
.map(|(k, v)| format!("{}={}", k, v))
.collect::<Vec<_>>()
.join(" ");
let params_cstring = CString::new(params_string).unwrap();
let mut handle = std::ptr::null_mut();
lgbm_call!(lightgbm3_sys::LGBM_BoosterCreate(
dataset.handle,
params_cstring.as_ptr(),
&mut handle
))?;
if let Some(ref valid) = valid_dataset {
lgbm_call!(lightgbm3_sys::LGBM_BoosterAddValidData(
handle,
valid.handle
))?;
}
let mut is_finished: i32 = 0;
let mut best_score: Option<f64> = None;
let mut rounds_without_improvement: i64 = 0;
let has_valid = valid_dataset.is_some();
let do_early_stopping = has_valid && early_stopping_rounds.is_some();
for iter in 1..=num_iterations {
lgbm_call!(lightgbm3_sys::LGBM_BoosterUpdateOneIter(
handle,
&mut is_finished
))?;
if is_finished != 0 {
break;
}
if do_early_stopping {
let early_stop_rounds = early_stopping_rounds.unwrap();
let eval_results = Self::get_eval_at(handle, 1)?;
if let Some(¤t_score) = eval_results.first() {
let improved = match best_score {
None => true,
Some(best) => Self::is_score_better(current_score, best, parameters),
};
if improved {
best_score = Some(current_score);
rounds_without_improvement = 0;
} else {
rounds_without_improvement += 1;
if rounds_without_improvement >= early_stop_rounds {
let rollback_iters = early_stop_rounds.min(iter);
for _ in 0..rollback_iters {
lgbm_call!(lightgbm3_sys::LGBM_BoosterRollbackOneIter(handle))?;
}
break;
}
}
}
}
}
Booster::new(handle)
}
fn is_score_better(current: f64, best: f64, parameters: &Value) -> bool {
let metric = parameters["metric"].as_str().unwrap_or("binary_logloss");
let higher_is_better = matches!(
metric,
"auc" | "average_precision" | "map" | "ndcg" | "accuracy"
);
if higher_is_better {
current > best
} else {
current < best
}
}
    /// Fetches the current evaluation-metric values for dataset `data_idx`
    /// (0 is presumably the training data and 1.. the validation sets, per
    /// the C API's convention — confirm against LGBM_BoosterGetEval docs).
    fn get_eval_at(handle: BoosterHandle, data_idx: i32) -> Result<Vec<f64>> {
        // Ask how many metrics are configured so the buffer can be sized.
        let mut num_evals: i32 = 0;
        lgbm_call!(lightgbm3_sys::LGBM_BoosterGetEvalCounts(
            handle,
            &mut num_evals
        ))?;
        if num_evals == 0 {
            return Ok(vec![]);
        }
        let mut out_len: i32 = 0;
        let mut results: Vec<f64> = vec![0.0; num_evals as usize];
        lgbm_call!(lightgbm3_sys::LGBM_BoosterGetEval(
            handle,
            data_idx,
            &mut out_len,
            results.as_mut_ptr()
        ))?;
        // The call may report fewer values than `num_evals`; drop the tail.
        results.truncate(out_len as usize);
        Ok(results)
    }
fn real_predict<T: DType>(
&self,
flat_x: &[T],
n_features: i32,
is_row_major: bool,
predict_type: PredictType,
parameters: Option<&str>,
) -> Result<Vec<f64>> {
if self.n_features <= 0 {
return Err(Error::new("n_features should be greater than 0"));
}
if self.n_iterations <= 0 {
return Err(Error::new("n_iterations should be greater than 0"));
}
if n_features != self.n_features {
return Err(Error::new(
format!("Number of features in data ({}) doesn't match the number of features in booster ({})",
n_features,
self.n_features)
));
}
if flat_x.len() % n_features as usize != 0 {
return Err(Error::new(format!(
"Invalid length of data: data.len()={}, n_features={}",
flat_x.len(),
n_features
)));
}
let n_rows = flat_x.len() / n_features as usize;
let params_cstring = parameters
.map(CString::new)
.unwrap_or(CString::new(""))
.unwrap();
let mut out_length: c_longlong = 0;
let output_size = match predict_type {
PredictType::Contrib => n_rows * (self.n_features + 1) as usize,
_ => n_rows * self.n_classes as usize,
};
let mut out_result: Vec<f64> = vec![Default::default(); output_size];
lgbm_call!(lightgbm3_sys::LGBM_BoosterPredictForMat(
self.handle,
flat_x.as_ptr() as *const c_void,
T::get_c_api_dtype(),
n_rows as i32,
n_features,
if is_row_major { 1_i32 } else { 0_i32 }, predict_type.into(), 0_i32, self.max_iterations, params_cstring.as_ptr(),
&mut out_length,
out_result.as_mut_ptr()
))?;
Ok(out_result)
}
    /// Regular prediction on a flattened feature matrix.
    pub fn predict<T: DType>(
        &self,
        flat_x: &[T],
        n_features: i32,
        is_row_major: bool,
    ) -> Result<Vec<f64>> {
        self.real_predict(flat_x, n_features, is_row_major, PredictType::Normal, None)
    }
    /// Regular prediction with an extra LightGBM parameter string
    /// (e.g. `"num_threads=1"`).
    pub fn predict_with_params<T: DType>(
        &self,
        flat_x: &[T],
        n_features: i32,
        is_row_major: bool,
        params: &str,
    ) -> Result<Vec<f64>> {
        self.real_predict(
            flat_x,
            n_features,
            is_row_major,
            PredictType::Normal,
            Some(params),
        )
    }
    /// Per-feature contribution prediction; yields `n_features + 1`
    /// values per input row (see `real_predict`'s output sizing).
    pub fn predict_contrib<T: DType>(
        &self,
        flat_x: &[T],
        n_features: i32,
        is_row_major: bool,
    ) -> Result<Vec<f64>> {
        self.real_predict(flat_x, n_features, is_row_major, PredictType::Contrib, None)
    }
    /// Raw (untransformed) score prediction.
    pub fn raw_scores<T: DType>(
        &self,
        flat_x: &[T],
        n_features: i32,
        is_row_major: bool,
    ) -> Result<Vec<f64>> {
        self.real_predict(
            flat_x,
            n_features,
            is_row_major,
            PredictType::RawScore,
            None,
        )
    }
    /// Raw score prediction with an extra LightGBM parameter string.
    pub fn raw_scores_with_params<T: DType>(
        &self,
        flat_x: &[T],
        n_features: i32,
        is_row_major: bool,
        parameters: &str,
    ) -> Result<Vec<f64>> {
        self.real_predict(
            flat_x,
            n_features,
            is_row_major,
            PredictType::RawScore,
            Some(parameters),
        )
    }
pub fn predict_from_vec_of_vec<T: DType>(
&self,
x: Vec<Vec<T>>,
is_row_major: bool,
) -> Result<Vec<Vec<f64>>> {
if x.is_empty() || x[0].is_empty() {
return Err(Error::new("x is empty"));
}
let n_features = match is_row_major {
true => x[0].len() as i32,
false => x.len() as i32,
};
let flat_x = x.into_iter().flatten().collect::<Vec<T>>();
let pred_y = self.predict(&flat_x, n_features, is_row_major)?;
Ok(pred_y
.chunks(self.n_classes as usize)
.map(|x| x.to_vec())
.collect())
}
    /// Queries the number of output classes from the C API.
    fn inner_num_classes(&self) -> Result<i32> {
        let mut num_classes = 0;
        lgbm_call!(lightgbm3_sys::LGBM_BoosterGetNumClasses(
            self.handle,
            &mut num_classes
        ))?;
        Ok(num_classes)
    }
    /// Queries the number of input features from the C API.
    fn inner_num_features(&self) -> Result<i32> {
        let mut num_features = 0;
        lgbm_call!(lightgbm3_sys::LGBM_BoosterGetNumFeature(
            self.handle,
            &mut num_features
        ))?;
        Ok(num_features)
    }
    /// Queries the current iteration count from the C API.
    fn inner_num_iterations(&self) -> Result<i32> {
        let mut cur_iteration: i32 = 0;
        lgbm_call!(lightgbm3_sys::LGBM_BoosterGetCurrentIteration(
            self.handle,
            &mut cur_iteration
        ))?;
        Ok(cur_iteration)
    }
pub fn feature_name(&self) -> Result<Vec<String>> {
let num_feature = self.inner_num_features()?;
let feature_name_length = 64;
let mut num_feature_names = 0;
let mut out_buffer_len = 0;
let out_strs = (0..num_feature)
.map(|_| {
CString::new(" ".repeat(feature_name_length))
.unwrap()
.into_raw()
})
.collect::<Vec<_>>();
lgbm_call!(lightgbm3_sys::LGBM_BoosterGetFeatureNames(
self.handle,
num_feature,
&mut num_feature_names,
feature_name_length,
&mut out_buffer_len,
out_strs.as_ptr() as *mut *mut c_char
))?;
let output: Vec<String> = out_strs
.into_iter()
.map(|s| unsafe { CString::from_raw(s).into_string().unwrap() })
.collect();
Ok(output)
}
pub fn feature_importance(&self, importance_type: ImportanceType) -> Result<Vec<f64>> {
let num_feature = self.inner_num_features()?;
let mut out_result: Vec<f64> = vec![Default::default(); num_feature as usize];
lgbm_call!(lightgbm3_sys::LGBM_BoosterFeatureImportance(
self.handle,
0_i32,
importance_type.into(),
out_result.as_mut_ptr()
))?;
Ok(out_result)
}
}
impl Drop for Booster {
    /// Frees the underlying C handle.
    ///
    /// FIX: errors from `LGBM_BoosterFree` are deliberately ignored here —
    /// the previous `.unwrap()` could panic inside `drop`, and a panic while
    /// already unwinding aborts the whole process, which is strictly worse
    /// than leaking a handle that failed to free.
    fn drop(&mut self) {
        let _ = lgbm_call!(lightgbm3_sys::LGBM_BoosterFree(self.handle));
    }
}
impl From<ImportanceType> for i32 {
    /// Converts to the matching `C_API_FEATURE_IMPORTANCE_*` constant.
    fn from(value: ImportanceType) -> Self {
        let c_constant = match value {
            ImportanceType::Split => lightgbm3_sys::C_API_FEATURE_IMPORTANCE_SPLIT,
            ImportanceType::Gain => lightgbm3_sys::C_API_FEATURE_IMPORTANCE_GAIN,
        };
        c_constant as i32
    }
}
impl From<PredictType> for i32 {
    /// Converts to the matching `C_API_PREDICT_*` constant.
    fn from(value: PredictType) -> Self {
        let c_constant = match value {
            PredictType::Normal => lightgbm3_sys::C_API_PREDICT_NORMAL,
            PredictType::RawScore => lightgbm3_sys::C_API_PREDICT_RAW_SCORE,
            PredictType::Contrib => lightgbm3_sys::C_API_PREDICT_CONTRIB,
        };
        c_constant as i32
    }
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use std::{fs, path::Path};
const TMP_FOLDER: &str = "./target/tmp";
    // Loads the bundled binary-classification training data; the path is
    // relative to the crate root, so this only works when tests run there.
    fn _read_train_file() -> Result<Dataset> {
        Dataset::from_file("lightgbm3-sys/lightgbm/examples/binary_classification/binary.train")
    }
    // Trains a throwaway booster on the example data with the given params.
    fn _train_booster(params: &Value) -> Booster {
        let dataset = _read_train_file().unwrap();
        Booster::train(dataset, params).unwrap()
    }
fn _default_params() -> Value {
let params = json! {
{
"num_iterations": 1,
"objective": "binary",
"metric": "auc",
"data_random_seed": 0
}
};
params
}
#[test]
fn predict_from_vec_of_vec() {
let params = json! {
{
"num_iterations": 10,
"objective": "binary",
"metric": "auc",
"data_random_seed": 0
}
};
let bst = _train_booster(¶ms);
let feature = vec![vec![0.5; 28], vec![0.0; 28], vec![0.9; 28]];
let result = bst.predict_from_vec_of_vec(feature, true).unwrap();
let mut normalized_result = Vec::new();
for r in &result {
normalized_result.push(if r[0] > 0.5 { 1 } else { 0 });
}
assert_eq!(normalized_result, vec![0, 0, 1]);
}
#[test]
fn predict_with_params() {
let params = json! {
{
"num_iterations": 10,
"objective": "binary",
"metric": "auc",
"data_random_seed": 0
}
};
let bst = _train_booster(¶ms);
let mut feature = [0.0; 28 * 3];
for i in 0..28 {
feature[i] = 0.5;
}
for i in 56..feature.len() {
feature[i] = 0.9;
}
let result = bst
.predict_with_params(&feature, 28, true, "num_threads=1")
.unwrap();
let mut normalized_result = Vec::new();
for r in &result {
normalized_result.push(if *r > 0.5 { 1 } else { 0 });
}
assert_eq!(normalized_result, vec![0, 0, 1]);
}
#[test]
fn num_feature() {
let params = _default_params();
let bst = _train_booster(¶ms);
let num_feature = bst.inner_num_features().unwrap();
assert_eq!(num_feature, 28);
}
#[test]
fn feature_importance() {
let params = _default_params();
let bst = _train_booster(¶ms);
let feature_importance = bst.feature_importance(ImportanceType::Gain).unwrap();
assert_eq!(feature_importance.len(), bst.n_features as usize);
assert!(feature_importance.iter().sum::<f64>() > 0.0);
}
#[test]
fn feature_name() {
let params = _default_params();
let bst = _train_booster(¶ms);
let feature_name = bst.feature_name().unwrap();
let target = (0..28).map(|i| format!("Column_{}", i)).collect::<Vec<_>>();
assert_eq!(feature_name, target);
}
#[test]
fn save_file() {
let params = _default_params();
let bst = _train_booster(¶ms);
let _ = fs::create_dir(TMP_FOLDER);
let filename = format!("{TMP_FOLDER}/model1.lgb");
assert!(bst.save_file(&filename).is_ok());
assert!(Path::new(&filename).exists());
assert!(Booster::from_file(&filename).is_ok());
assert!(fs::remove_file(&filename).is_ok());
}
#[test]
fn save_string() {
let params = _default_params();
let bst = _train_booster(¶ms);
let _ = fs::create_dir(TMP_FOLDER);
let filename = format!("{TMP_FOLDER}/model2.lgb");
assert_eq!(bst.save_file(&filename), Ok(()));
assert!(Path::new(&filename).exists());
let booster_file_content = fs::read_to_string(&filename).unwrap();
assert!(fs::remove_file(&filename).is_ok());
assert!(!booster_file_content.is_empty());
assert_eq!(Ok(booster_file_content.clone()), bst.save_string());
assert!(Booster::from_string(&booster_file_content).is_ok());
}
    // Loads the example validation data, binning features consistently with
    // the reference training dataset.
    fn _read_test_file_with_ref(train: &Dataset) -> Result<Dataset> {
        Dataset::from_file_with_reference(
            "lightgbm3-sys/lightgbm/examples/binary_classification/binary.test",
            Some(train),
        )
    }
#[test]
fn train_with_valid_no_early_stopping() {
let train_dataset = _read_train_file().unwrap();
let valid_dataset = _read_test_file_with_ref(&train_dataset).unwrap();
let params = json! {
{
"num_iterations": 10,
"objective": "binary",
"metric": "auc",
"data_random_seed": 0
}
};
let bst = Booster::train_with_valid(train_dataset, Some(valid_dataset), ¶ms).unwrap();
assert_eq!(bst.num_iterations(), 10);
}
#[test]
fn train_with_early_stopping() {
let train_dataset = _read_train_file().unwrap();
let valid_dataset = _read_test_file_with_ref(&train_dataset).unwrap();
let params = json! {
{
"num_iterations": 100,
"objective": "binary",
"metric": "auc",
"data_random_seed": 0,
"early_stopping_rounds": 5
}
};
let bst = Booster::train_with_valid(train_dataset, Some(valid_dataset), ¶ms).unwrap();
assert!(
bst.num_iterations() < 100,
"Expected early stopping to trigger before 100 iterations, got {}",
bst.num_iterations()
);
}
#[test]
fn train_with_early_stopping_logloss() {
let train_dataset = _read_train_file().unwrap();
let valid_dataset = _read_test_file_with_ref(&train_dataset).unwrap();
let params = json! {
{
"num_iterations": 100,
"objective": "binary",
"metric": "binary_logloss",
"data_random_seed": 0,
"early_stopping_rounds": 5
}
};
let bst = Booster::train_with_valid(train_dataset, Some(valid_dataset), ¶ms).unwrap();
assert!(bst.num_iterations() > 0);
}
}