use indexmap::IndexMap;
use log::debug;
use std::collections::{BTreeMap, HashMap};
use std::io::{self, BufRead, BufReader, Write};
use std::os::unix::ffi::OsStrExt;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::{ffi, fmt, fs::File, ptr, slice};
use crate::parameters::{BoosterParameters, TrainingParameters};
use crate::{XGBError, XGBResult};
use super::DMatrix;
/// Signature of a user-supplied objective function.
///
/// `Booster::update_custom` calls it with the current predictions and the training matrix
/// and treats the two returned vectors as the gradient and the hessian, in that order.
pub type CustomObjective = fn(&[f32], &DMatrix) -> (Vec<f32>, Vec<f32>);
/// Options that are combined into the `option_mask` argument of `XGBoosterPredict`.
enum PredictOption {
    OutputMargin,
    PredictLeaf,
    PredictContribitions,
    PredictInteractions,
}

impl PredictOption {
    /// Returns the single bit that represents this option in the mask.
    fn bit(&self) -> i32 {
        match self {
            PredictOption::OutputMargin => 0x01,
            PredictOption::PredictLeaf => 0x02,
            PredictOption::PredictContribitions => 0x04,
            PredictOption::PredictInteractions => 0x10,
        }
    }

    /// ORs the bits of all `options` together; an empty slice yields `0`.
    fn options_as_mask(options: &[PredictOption]) -> i32 {
        options
            .iter()
            .fold(0x00, |mask, option| mask | option.bit())
    }
}
/// Wrapper around a raw XGBoost booster handle obtained from the C API.
///
/// The handle is released with `XGBoosterFree` when the value is dropped.
// NOTE(review): `Clone` is derived, so a clone copies the raw handle while `Drop` calls
// `XGBoosterFree` for every copy — confirm that clones are never dropped independently.
#[derive(Clone)]
pub struct Booster {
    // Raw handle passed to every `xgboost_rs_sys` call made by this type.
    handle: xgboost_rs_sys::BoosterHandle,
}
// NOTE(review): these impls assert that the handle may be moved to and shared between
// threads; nothing in this file demonstrates that — confirm against the XGBoost C API.
unsafe impl Send for Booster {}
unsafe impl Sync for Booster {}
impl Booster {
pub fn new(params: &BoosterParameters) -> XGBResult<Self> {
Self::new_with_cached_dmats(params, &[])
}
/// Creates a booster that caches `dmats` and applies every key/value pair of `config`
/// as a booster parameter.
///
/// # Errors
///
/// Returns an error if `XGBoosterCreate` fails.
pub fn new_with_json_config(
    dmats: &[&DMatrix],
    config: HashMap<&str, &str>,
) -> XGBResult<Self> {
    let cached_handles: Vec<xgboost_rs_sys::DMatrixHandle> =
        dmats.iter().map(|dmat| dmat.handle).collect();
    let mut raw_handle = ptr::null_mut();
    xgb_call!(xgboost_rs_sys::XGBoosterCreate(
        cached_handles.as_ptr(),
        dmats.len() as u64,
        &mut raw_handle
    ))?;

    let mut created = Booster { handle: raw_handle };
    created.set_param_from_json(config);
    Ok(created)
}
/// Creates a booster that caches `dmats` and is configured from `params`.
///
/// # Errors
///
/// Returns an error if `XGBoosterCreate` fails or if any parameter cannot be set.
pub fn new_with_cached_dmats(
    params: &BoosterParameters,
    dmats: &[&DMatrix],
) -> XGBResult<Self> {
    let cached_handles: Vec<xgboost_rs_sys::DMatrixHandle> =
        dmats.iter().map(|dmat| dmat.handle).collect();
    let mut raw_handle = ptr::null_mut();
    xgb_call!(xgboost_rs_sys::XGBoosterCreate(
        cached_handles.as_ptr(),
        dmats.len() as u64,
        &mut raw_handle
    ))?;

    let mut created = Booster { handle: raw_handle };
    created.set_params(params)?;
    Ok(created)
}
pub fn save<P: AsRef<Path>>(&self, path: P) -> XGBResult<()> {
debug!("Writing Booster to: {}", path.as_ref().display());
let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
xgb_call!(xgboost_rs_sys::XGBoosterSaveModel(
self.handle,
fname.as_ptr()
))
}
/// Loads a model from the file at `path` through `XGBoosterLoadModel`.
///
/// # Errors
///
/// Returns an error if the file does not exist, if `path` contains an interior NUL byte,
/// or if creating the booster or loading the model fails.
pub fn load<P: AsRef<Path>>(path: P) -> XGBResult<Self> {
    let path = path.as_ref();
    debug!("Loading Booster from: {}", path.display());
    if !path.exists() {
        return Err(XGBError::new(format!(
            "File not found: {}",
            path.display()
        )));
    }
    let fname = ffi::CString::new(path.as_os_str().as_bytes()).map_err(|err| {
        XGBError::new(format!("Invalid model path {}: {}", path.display(), err))
    })?;

    let mut handle = ptr::null_mut();
    xgb_call!(xgboost_rs_sys::XGBoosterCreate(ptr::null(), 0, &mut handle))?;
    // Wrap the handle right away so `Drop` frees it if loading the model fails below.
    let booster = Booster { handle };
    xgb_call!(xgboost_rs_sys::XGBoosterLoadModel(
        booster.handle,
        fname.as_ptr()
    ))?;
    Ok(booster)
}
/// Loads a model from an in-memory buffer through `XGBoosterLoadModelFromBuffer`.
///
/// # Errors
///
/// Returns an error if creating the booster or loading the model from `bytes` fails.
pub fn load_buffer(bytes: &[u8]) -> XGBResult<Self> {
    debug!("Loading Booster from buffer (length = {})", bytes.len());
    let mut handle = ptr::null_mut();
    xgb_call!(xgboost_rs_sys::XGBoosterCreate(ptr::null(), 0, &mut handle))?;
    // Wrap the handle right away so `Drop` frees it if loading the model fails below.
    let booster = Booster { handle };
    xgb_call!(xgboost_rs_sys::XGBoosterLoadModelFromBuffer(
        booster.handle,
        bytes.as_ptr().cast(),
        bytes.len() as u64
    ))?;
    Ok(booster)
}
pub fn train_increment(params: &TrainingParameters, model_name: &str) -> XGBResult<Self> {
let mut dmats = vec![params.dtrain];
if let Some(eval_sets) = params.evaluation_sets {
for (dmat, _) in eval_sets {
dmats.push(*dmat);
}
}
let path = Path::new(model_name);
let bytes = std::fs::read(path).expect("can't read saved booster file");
let mut bst = Booster::load_buffer(&bytes[..]).expect("can't load booster from buffer");
let version = bst.load_rabit_checkpoint()?;
debug!("Loaded Rabit checkpoint: version={}", version);
assert!(unsafe { xgboost_rs_sys::RabitGetWorldSize() != 1 || version == 0 });
unsafe { xgboost_rs_sys::RabitGetRank() };
let start_iteration = version / 2;
for i in start_iteration..params.boost_rounds as i32 {
if version % 2 == 0 {
if let Some(objective_fn) = params.custom_objective_fn {
debug!("Boosting in round: {}", i);
bst.update_custom(params.dtrain, objective_fn)?;
} else {
debug!("Updating in round: {}", i);
bst.update(params.dtrain, i)?;
}
bst.save_rabit_checkpoint()?;
}
assert!(unsafe {
xgboost_rs_sys::RabitGetWorldSize() == 1
|| version == xgboost_rs_sys::RabitVersionNumber()
});
if let Some(eval_sets) = params.evaluation_sets {
let mut dmat_eval_results = bst.eval_set(eval_sets, i)?;
if let Some(eval_fn) = params.custom_evaluation_fn {
let eval_name = "custom";
for (dmat, dmat_name) in eval_sets {
let margin = bst.predict_margin(dmat)?;
let eval_result = eval_fn(&margin, dmat);
let eval_results = dmat_eval_results
.entry(eval_name.to_string())
.or_insert_with(IndexMap::new);
eval_results.insert(String::from(*dmat_name), eval_result);
}
}
let mut eval_dmat_results = BTreeMap::new();
for (dmat_name, eval_results) in &dmat_eval_results {
for (eval_name, result) in eval_results {
let dmat_results = eval_dmat_results
.entry(eval_name)
.or_insert_with(BTreeMap::new);
dmat_results.insert(dmat_name, result);
}
}
}
}
Ok(bst)
}
pub fn train(
evaluation_sets: Option<&[(&DMatrix, &str)]>,
dtrain: &DMatrix,
config: HashMap<&str, &str>,
bst: Option<Booster>,
) -> XGBResult<Self> {
let cached_dmats = {
let mut dmats = vec![dtrain];
if let Some(eval_sets) = evaluation_sets {
for (dmat, _) in eval_sets {
dmats.push(*dmat);
}
}
dmats
};
let mut bst: Booster = {
if let Some(booster) = bst {
let mut length: u64 = 0;
let mut buffer_string = ptr::null();
xgb_call!(xgboost_rs_sys::XGBoosterSerializeToBuffer(
booster.handle,
&mut length,
&mut buffer_string
))
.expect("couldn't serialize to buffer!");
let mut bst_handle = ptr::null_mut();
let cached_dmat_handles: Vec<xgboost_rs_sys::DMatrixHandle> =
cached_dmats.iter().map(|x| x.handle).collect();
xgb_call!(xgboost_rs_sys::XGBoosterCreate(
cached_dmat_handles.as_ptr(),
cached_dmats.len() as u64,
&mut bst_handle
))?;
let mut bst_unserialize = Booster { handle: bst_handle };
xgb_call!(xgboost_rs_sys::XGBoosterUnserializeFromBuffer(
bst_unserialize.handle,
buffer_string as *mut ffi::c_void,
length,
))
.expect("couldn't unserialize from buffer!");
bst_unserialize.set_param_from_json(config);
bst_unserialize
} else {
Booster::new_with_json_config(&cached_dmats, config)?
}
};
for i in 0..16 {
bst.update(dtrain, i)?;
if let Some(eval_sets) = evaluation_sets {
let dmat_eval_results = bst.eval_set(eval_sets, i)?;
let mut eval_dmat_results = BTreeMap::new();
for (dmat_name, eval_results) in &dmat_eval_results {
for (eval_name, result) in eval_results {
let dmat_results = eval_dmat_results
.entry(eval_name)
.or_insert_with(BTreeMap::new);
dmat_results.insert(dmat_name, result);
}
}
}
}
Ok(bst)
}
pub fn save_config(&self) -> String {
let mut length: u64 = 1;
let mut json_string = ptr::null();
let _json = unsafe {
xgboost_rs_sys::XGBoosterSaveJsonConfig(self.handle, &mut length, &mut json_string)
};
let out = unsafe {
ffi::CStr::from_ptr(json_string)
.to_str()
.unwrap()
.to_owned()
};
out
}
/// Applies every key/value pair produced by `p.as_string_pairs()` to this booster.
///
/// # Errors
///
/// Stops at and returns the first error reported while setting a parameter.
pub fn set_params(&mut self, p: &BoosterParameters) -> XGBResult<()> {
    p.as_string_pairs()
        .into_iter()
        .try_for_each(|(key, value)| self.set_param(&key, &value))
}
/// Runs one `XGBoosterUpdateOneIter` call for `iteration` on `dtrain`.
///
/// # Errors
///
/// Returns an error if the XGBoost call fails.
pub fn update(&mut self, dtrain: &DMatrix, iteration: i32) -> XGBResult<()> {
    let booster_handle = self.handle;
    let train_handle = dtrain.handle;
    xgb_call!(xgboost_rs_sys::XGBoosterUpdateOneIter(
        booster_handle,
        iteration,
        train_handle
    ))
}
/// Performs one boosting step driven by a custom objective.
///
/// The current predictions for `dtrain` are passed to `objective_fn`, and the gradient
/// and hessian it returns are handed to the booster.
///
/// # Errors
///
/// Returns an error if predicting or boosting fails.
pub fn update_custom(
    &mut self,
    dtrain: &DMatrix,
    objective_fn: CustomObjective,
) -> XGBResult<()> {
    let predictions = self.predict(dtrain)?;
    let (grad, hess) = objective_fn(predictions.as_slice(), dtrain);
    self.boost(dtrain, grad.as_slice(), hess.as_slice())
}
/// Calls `XGBoosterBoostOneIter` with the given gradient and hessian.
///
/// # Errors
///
/// Returns an error if the two slices differ in length or if the XGBoost call fails.
fn boost(&mut self, dtrain: &DMatrix, gradient: &[f32], hessian: &[f32]) -> XGBResult<()> {
    let (grad_len, hess_len) = (gradient.len(), hessian.len());
    if grad_len != hess_len {
        return Err(XGBError::new(format!(
            "Mismatch between length of gradient and hessian arrays ({} != {})",
            grad_len, hess_len
        )));
    }

    // The C API takes mutable pointers, so it is given private copies.
    let mut grad_buf = Vec::from(gradient);
    let mut hess_buf = Vec::from(hessian);
    xgb_call!(xgboost_rs_sys::XGBoosterBoostOneIter(
        self.handle,
        dtrain.handle,
        grad_buf.as_mut_ptr(),
        hess_buf.as_mut_ptr(),
        grad_len as u64
    ))
}
fn eval_set(
&self,
evals: &[(&DMatrix, &str)],
iteration: i32,
) -> XGBResult<IndexMap<String, IndexMap<String, f32>>> {
let (dmats, names) = {
let mut dmats = Vec::with_capacity(evals.len());
let mut names = Vec::with_capacity(evals.len());
for (dmat, name) in evals {
dmats.push(dmat);
names.push(*name);
}
(dmats, names)
};
assert_eq!(dmats.len(), names.len());
let mut s: Vec<xgboost_rs_sys::DMatrixHandle> = dmats.iter().map(|x| x.handle).collect();
let mut evnames: Vec<ffi::CString> = Vec::with_capacity(names.len());
let mut evptrs: Vec<*const libc::c_char> = Vec::with_capacity(names.len());
for name in &names {
let cstr = ffi::CString::new(*name).unwrap();
evptrs.push(cstr.as_ptr());
evnames.push(cstr);
}
evptrs.shrink_to_fit();
let mut out_result = ptr::null();
xgb_call!(xgboost_rs_sys::XGBoosterEvalOneIter(
self.handle,
iteration,
s.as_mut_ptr(),
evptrs.as_mut_ptr(),
dmats.len() as u64,
&mut out_result
))?;
let out = unsafe { ffi::CStr::from_ptr(out_result).to_str().unwrap().to_owned() };
Ok(Booster::parse_eval_string(&out, &names))
}
pub fn evaluate(&self, dmat: &DMatrix, name: &str) -> XGBResult<HashMap<String, f32>> {
let mut eval = self.eval_set(&[(dmat, name)], 0)?;
let mut result = HashMap::new();
eval.remove(name).unwrap().into_iter().for_each(|(k, v)| {
result.insert(k, v);
});
Ok(result)
}
/// Looks up the attribute `key` with `XGBoosterGetAttr`.
///
/// Returns `Ok(None)` when the call reports that the attribute was not found.
///
/// # Errors
///
/// Returns an error if the XGBoost call fails.
///
/// # Panics
///
/// Panics if `key` contains an interior NUL byte, if the success flag is neither 0 nor 1,
/// or if the value is not valid UTF-8.
pub fn get_attribute(&self, key: &str) -> XGBResult<Option<String>> {
    let c_key = ffi::CString::new(key).unwrap();
    let mut value_ptr = ptr::null();
    let mut found = 0;
    xgb_call!(xgboost_rs_sys::XGBoosterGetAttr(
        self.handle,
        c_key.as_ptr(),
        &mut value_ptr,
        &mut found
    ))?;

    match found {
        0 => Ok(None),
        flag => {
            assert!(flag == 1);
            let value = unsafe { ffi::CStr::from_ptr(value_ptr) };
            Ok(Some(value.to_str().unwrap().to_owned()))
        }
    }
}
/// Stores the attribute `key` = `value` on the booster with `XGBoosterSetAttr`.
///
/// # Errors
///
/// Returns an error if the XGBoost call fails.
///
/// # Panics
///
/// Panics if `key` or `value` contains an interior NUL byte.
pub fn set_attribute(&mut self, key: &str, value: &str) -> XGBResult<()> {
    let c_key = ffi::CString::new(key).unwrap();
    let c_value = ffi::CString::new(value).unwrap();
    let (key_ptr, value_ptr) = (c_key.as_ptr(), c_value.as_ptr());
    xgb_call!(xgboost_rs_sys::XGBoosterSetAttr(
        self.handle,
        key_ptr,
        value_ptr
    ))
}
pub fn get_attribute_names(&self) -> XGBResult<Vec<String>> {
let mut out_len = 0;
let mut out = ptr::null_mut();
xgb_call!(xgboost_rs_sys::XGBoosterGetAttrNames(
self.handle,
&mut out_len,
&mut out
))?;
let out_ptr_slice = unsafe { slice::from_raw_parts(out, out_len as usize) };
let out_vec = out_ptr_slice
.iter()
.map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() })
.collect();
Ok(out_vec)
}
/// Predicts for `dmat` through `XGBoosterPredictFromDMatrix` with a fixed JSON config.
///
/// The number of returned values is the product of the shape reported by XGBoost;
/// `out_dim` receives the number of dimensions of that shape. `out_shape` is only used
/// as a fallback for the length (its first element) if no shape is reported.
///
/// # Errors
///
/// Returns an error if the XGBoost call fails.
///
/// # Panics
///
/// Panics if XGBoost returns a null result pointer.
pub fn predict_from_dmat(
    &self,
    dmat: &DMatrix,
    out_shape: &[u64; 2],
    out_dim: &mut u64,
) -> XGBResult<Vec<f32>> {
    let json_config = "{\"type\": 0,\"training\": false,\"iteration_begin\": 0,\"iteration_end\": 0,\"strict_shape\": true}".to_string();
    let c_json_config = ffi::CString::new(json_config).unwrap();

    // The shape is an *output* of the C API: it stores a pointer to its own shape array
    // here. Passing the address of a temporary (as was done before) loses that pointer.
    let mut shape_ptr: *const u64 = ptr::null();
    let mut out_result = ptr::null();
    xgb_call!(xgboost_rs_sys::XGBoosterPredictFromDMatrix(
        self.handle,
        dmat.handle,
        c_json_config.as_ptr(),
        &mut shape_ptr,
        out_dim,
        &mut out_result
    ))?;
    assert!(!out_result.is_null());

    let out_len = if shape_ptr.is_null() || *out_dim == 0 {
        out_shape[0]
    } else {
        // SAFETY: the call succeeded and `shape_ptr` is non-null; it is assumed to point
        // at `*out_dim` dimension sizes owned by the booster.
        let dims = unsafe { slice::from_raw_parts(shape_ptr, *out_dim as usize) };
        dims.iter().product::<u64>()
    };

    // SAFETY: `out_result` is non-null and is assumed to hold `out_len` floats.
    let data = unsafe { slice::from_raw_parts(out_result, out_len as usize).to_vec() };
    Ok(data)
}
/// Predicts for `dmat` with an empty option mask and returns the raw output values.
///
/// # Errors
///
/// Returns an error if the XGBoost call fails.
///
/// # Panics
///
/// Panics if XGBoost returns a null result pointer.
pub fn predict(&self, dmat: &DMatrix) -> XGBResult<Vec<f32>> {
    let mask = PredictOption::options_as_mask(&[]);
    let mut value_count = 0;
    let mut values = ptr::null();
    xgb_call!(xgboost_rs_sys::XGBoosterPredict(
        self.handle,
        dmat.handle,
        mask,
        0,
        0,
        &mut value_count,
        &mut values
    ))?;
    assert!(!values.is_null());

    let predictions = unsafe { slice::from_raw_parts(values, value_count as usize) };
    Ok(predictions.to_vec())
}
/// Predicts for `dmat` with the `OutputMargin` option set and returns the raw values.
///
/// # Errors
///
/// Returns an error if the XGBoost call fails.
///
/// # Panics
///
/// Panics if XGBoost returns a null result pointer.
pub fn predict_margin(&self, dmat: &DMatrix) -> XGBResult<Vec<f32>> {
    let mask = PredictOption::options_as_mask(&[PredictOption::OutputMargin]);
    let mut value_count = 0;
    let mut values = ptr::null();
    // NOTE(review): the fifth argument is 1 here but 0 in the sibling predict methods —
    // confirm that this difference is intended.
    xgb_call!(xgboost_rs_sys::XGBoosterPredict(
        self.handle,
        dmat.handle,
        mask,
        0,
        1,
        &mut value_count,
        &mut values
    ))?;
    assert!(!values.is_null());

    let margins = unsafe { slice::from_raw_parts(values, value_count as usize) };
    Ok(margins.to_vec())
}
/// Predicts for `dmat` with the `PredictLeaf` option set.
///
/// Returns the flat output together with a `(rows, columns)` shape, where `rows` is
/// `dmat.num_rows()` and `columns` is the output length divided by `rows`.
///
/// # Errors
///
/// Returns an error if the XGBoost call fails.
///
/// # Panics
///
/// Panics if XGBoost returns a null result pointer or if `dmat` has zero rows.
pub fn predict_leaf(&self, dmat: &DMatrix) -> XGBResult<(Vec<f32>, (usize, usize))> {
    let mask = PredictOption::options_as_mask(&[PredictOption::PredictLeaf]);
    let mut value_count = 0;
    let mut values = ptr::null();
    xgb_call!(xgboost_rs_sys::XGBoosterPredict(
        self.handle,
        dmat.handle,
        mask,
        0,
        0,
        &mut value_count,
        &mut values
    ))?;
    assert!(!values.is_null());

    let leaves = unsafe { slice::from_raw_parts(values, value_count as usize) }.to_vec();
    let rows = dmat.num_rows();
    let shape = (rows, leaves.len() / rows);
    Ok((leaves, shape))
}
/// Predicts for `dmat` with the `PredictContribitions` option set.
///
/// Returns the flat output together with a `(rows, columns)` shape, where `rows` is
/// `dmat.num_rows()` and `columns` is the output length divided by `rows`.
///
/// # Errors
///
/// Returns an error if the XGBoost call fails.
///
/// # Panics
///
/// Panics if XGBoost returns a null result pointer or if `dmat` has zero rows.
pub fn predict_contributions(&self, dmat: &DMatrix) -> XGBResult<(Vec<f32>, (usize, usize))> {
    let mask = PredictOption::options_as_mask(&[PredictOption::PredictContribitions]);
    let mut value_count = 0;
    let mut values = ptr::null();
    xgb_call!(xgboost_rs_sys::XGBoosterPredict(
        self.handle,
        dmat.handle,
        mask,
        0,
        0,
        &mut value_count,
        &mut values
    ))?;
    assert!(!values.is_null());

    let contributions =
        unsafe { slice::from_raw_parts(values, value_count as usize) }.to_vec();
    let rows = dmat.num_rows();
    let shape = (rows, contributions.len() / rows);
    Ok((contributions, shape))
}
/// Predicts for `dmat` with the `PredictInteractions` option set.
///
/// Returns the flat output together with a `(rows, dim, dim)` shape, where `rows` is
/// `dmat.num_rows()` and `dim` is the truncated square root of the per-row output length.
///
/// # Errors
///
/// Returns an error if the XGBoost call fails.
///
/// # Panics
///
/// Panics if XGBoost returns a null result pointer or if `dmat` has zero rows.
pub fn predict_interactions(
    &self,
    dmat: &DMatrix,
) -> XGBResult<(Vec<f32>, (usize, usize, usize))> {
    let mask = PredictOption::options_as_mask(&[PredictOption::PredictInteractions]);
    let mut value_count = 0;
    let mut values = ptr::null();
    xgb_call!(xgboost_rs_sys::XGBoosterPredict(
        self.handle,
        dmat.handle,
        mask,
        0,
        0,
        &mut value_count,
        &mut values
    ))?;
    assert!(!values.is_null());

    let interactions =
        unsafe { slice::from_raw_parts(values, value_count as usize) }.to_vec();
    let rows = dmat.num_rows();
    let per_row = interactions.len() / rows;
    let side = (per_row as f64).sqrt() as usize;
    Ok((interactions, (rows, side, side)))
}
pub fn dump_model(
&self,
with_statistics: bool,
feature_map: Option<&FeatureMap>,
) -> XGBResult<String> {
if let Some(fmap) = feature_map {
let tmp_dir = match tempfile::tempdir() {
Ok(dir) => dir,
Err(err) => return Err(XGBError::new(err.to_string())),
};
let file_path = tmp_dir.path().join("fmap.txt");
let mut file: File = match File::create(&file_path) {
Ok(f) => f,
Err(err) => return Err(XGBError::new(err.to_string())),
};
for (feature_num, (feature_name, feature_type)) in &fmap.0 {
writeln!(file, "{}\t{}\t{}", feature_num, feature_name, feature_type).unwrap();
}
self.dump_model_fmap(with_statistics, Some(&file_path))
} else {
self.dump_model_fmap(with_statistics, None)
}
}
fn dump_model_fmap(
&self,
with_statistics: bool,
feature_map_path: Option<&PathBuf>,
) -> XGBResult<String> {
let fmap = if let Some(path) = feature_map_path {
ffi::CString::new(path.as_os_str().as_bytes()).unwrap()
} else {
ffi::CString::new("").unwrap()
};
let format = ffi::CString::new("text").unwrap();
let mut out_len = 0;
let mut out_dump_array = ptr::null_mut();
xgb_call!(xgboost_rs_sys::XGBoosterDumpModelEx(
self.handle,
fmap.as_ptr(),
i32::from(with_statistics),
format.as_ptr(),
&mut out_len,
&mut out_dump_array
))?;
let out_ptr_slice = unsafe { slice::from_raw_parts(out_dump_array, out_len as usize) };
let out_vec: Vec<String> = out_ptr_slice
.iter()
.map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() })
.collect();
assert_eq!(out_len as usize, out_vec.len());
Ok(out_vec.join("\n"))
}
/// Calls `XGBoosterLoadRabitCheckpoint` and returns the version it reports.
///
/// # Errors
///
/// Returns an error if the XGBoost call fails.
pub(crate) fn load_rabit_checkpoint(&self) -> XGBResult<i32> {
    let mut checkpoint_version = 0;
    xgb_call!(xgboost_rs_sys::XGBoosterLoadRabitCheckpoint(
        self.handle,
        &mut checkpoint_version
    ))
    .map(|()| checkpoint_version)
}
/// Calls `XGBoosterSaveRabitCheckpoint` for this booster.
///
/// # Errors
///
/// Returns an error if the XGBoost call fails.
pub(crate) fn save_rabit_checkpoint(&self) -> XGBResult<()> {
    let booster_handle = self.handle;
    xgb_call!(xgboost_rs_sys::XGBoosterSaveRabitCheckpoint(
        booster_handle
    ))
}
fn set_param_from_json(&mut self, config: HashMap<&str, &str>) {
for (k, v) in config {
let name = ffi::CString::new(k).unwrap();
let value = ffi::CString::new(v).unwrap();
unsafe {
xgboost_rs_sys::XGBoosterSetParam(self.handle, name.as_ptr(), value.as_ptr())
};
}
}
/// Sets a single booster parameter with `XGBoosterSetParam`.
///
/// # Errors
///
/// Returns an error if the XGBoost call fails.
///
/// # Panics
///
/// Panics if `name` or `value` contains an interior NUL byte.
fn set_param(&mut self, name: &str, value: &str) -> XGBResult<()> {
    let c_name = ffi::CString::new(name).unwrap();
    let c_value = ffi::CString::new(value).unwrap();
    let (name_ptr, value_ptr) = (c_name.as_ptr(), c_value.as_ptr());
    xgb_call!(xgboost_rs_sys::XGBoosterSetParam(
        self.handle,
        name_ptr,
        value_ptr
    ))
}
/// Parses one evaluation line into a map of evaluation-set name -> metric -> score.
///
/// The line is split on tabs and its first field is skipped. Every remaining field is
/// expected to look like `<name>-<metric>:<score>`; a field is attributed to an entry of
/// `evnames` only if it starts with that name followed by the `-` separator. Fields that
/// match none of the names are ignored.
///
/// # Panics
///
/// Panics if a matching field has no `:` separator or if its score is not a valid `f32`.
fn parse_eval_string(eval: &str, evnames: &[&str]) -> IndexMap<String, IndexMap<String, f32>> {
    let mut result: IndexMap<String, IndexMap<String, f32>> = IndexMap::new();
    debug!("Parsing evaluation line: {}", &eval);

    for part in eval.split('\t').skip(1) {
        for evname in evnames {
            // Require the '-' separator right after the name, so that e.g. "train" does
            // not also claim the fields of a set called "train2".
            let metric_and_score = match part
                .strip_prefix(*evname)
                .and_then(|rest| rest.strip_prefix('-'))
            {
                Some(rest) => rest,
                None => continue,
            };

            let (metric, score) = metric_and_score
                .split_once(':')
                .unwrap_or_else(|| panic!("Unable to parse XGBoost metrics output: {}", eval));
            let score = score
                .parse::<f32>()
                .unwrap_or_else(|_| panic!("Unable to parse XGBoost metrics output: {}", eval));

            result
                .entry(String::from(*evname))
                .or_insert_with(IndexMap::new)
                .insert(metric.to_owned(), score);
        }
    }

    debug!("result: {:?}", &result);
    result
}
}
impl Drop for Booster {
fn drop(&mut self) {
xgb_call!(xgboost_rs_sys::XGBoosterFree(self.handle)).unwrap();
}
}
/// Features keyed by feature number, each with its name and [`FeatureType`].
pub struct FeatureMap(BTreeMap<u32, (String, FeatureType)>);

impl FeatureMap {
    /// Reads a feature map from a text file.
    ///
    /// Every line must consist of exactly three tab-separated values: the feature number,
    /// the feature name and the feature type.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the file cannot be opened or read, and an
    /// `InvalidData` error naming the 1-based line number if a line is malformed.
    pub fn from_file<P: AsRef<Path>>(path: P) -> io::Result<FeatureMap> {
        let invalid = |msg: String| io::Error::new(io::ErrorKind::InvalidData, msg);

        let file = File::open(path)?;
        let mut entries = BTreeMap::new();
        for (index, line) in BufReader::new(&file).lines().enumerate() {
            let line = line?;
            let line_no = index + 1;

            let fields: Vec<&str> = line.split('\t').collect();
            let (number, name, kind) = match fields.as_slice() {
                [number, name, kind] => (*number, *name, *kind),
                _ => {
                    return Err(invalid(format!(
                        "Unable to parse features from line {}, expected 3 tab separated values",
                        line_no
                    )))
                }
            };

            let feature_num = number.parse::<u32>().map_err(|err| {
                invalid(format!(
                    "Unable to parse features from line {}, could not parse feature number: {}",
                    line_no, err
                ))
            })?;
            let feature_type = FeatureType::from_str(kind).map_err(|msg| {
                invalid(format!(
                    "Unable to parse features from line {}: {}",
                    line_no, msg
                ))
            })?;

            entries.insert(feature_num, (name.to_string(), feature_type));
        }
        Ok(FeatureMap(entries))
    }
}
/// Kind of a feature, written in feature-map files as `i`, `q` or `int`.
pub enum FeatureType {
    /// Written as `i`.
    Binary,
    /// Written as `q`.
    Quantitative,
    /// Written as `int`.
    Integer,
}

impl FromStr for FeatureType {
    type Err = String;

    /// Parses the textual form (`i`, `q` or `int`); anything else is an error message.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "i" {
            return Ok(FeatureType::Binary);
        }
        if s == "q" {
            return Ok(FeatureType::Quantitative);
        }
        if s == "int" {
            return Ok(FeatureType::Integer);
        }
        Err(format!(
            "unrecognised feature type '{}', must be one of: 'i', 'q', 'int'",
            s
        ))
    }
}

impl fmt::Display for FeatureType {
    /// Writes the same textual form that [`FromStr`] accepts.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(match self {
            FeatureType::Binary => "i",
            FeatureType::Quantitative => "q",
            FeatureType::Integer => "int",
        })
    }
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use indexmap::IndexMap;
use ndarray::arr2;
use crate::{
parameters::{self, learning, tree, BoosterParameters},
Booster, DMatrix, XGBResult,
};
/// Loads `data.csv` from the crate's `src` directory as a `DMatrix`.
fn read_train_matrix() -> XGBResult<DMatrix> {
    let csv_uri = format!(
        "{}/data.csv?format=csv",
        concat!(env!("CARGO_MANIFEST_DIR"), "/src")
    );
    DMatrix::load(csv_uri)
}
/// Builds a booster with default parameters that caches the training matrix.
fn load_test_booster() -> Booster {
    let train = read_train_matrix().expect("Reading train matrix failed");
    let default_params = BoosterParameters::default();
    Booster::new_with_cached_dmats(&default_params, &[&train])
        .expect("Creating Booster failed")
}
/// Setting an arbitrary key/value parameter on a fresh booster succeeds.
#[test]
fn set_booster_parhm() {
    let outcome = load_test_booster().set_param("key", "value");
    assert!(outcome.is_ok());
}
/// A fresh booster reports Rabit checkpoint version 0.
#[test]
fn load_rabit_version() {
    let booster = load_test_booster();
    let checkpoint_version = booster.load_rabit_checkpoint().unwrap();
    assert_eq!(checkpoint_version, 0);
}
/// An attribute is absent until it is set, and then reads back the stored value.
#[test]
fn get_set_attr() {
    let mut booster = load_test_booster();

    let before = booster
        .get_attribute("foo")
        .expect("Getting attribute failed");
    assert_eq!(before, None);

    booster
        .set_attribute("foo", "bar")
        .expect("Setting attribute failed");

    let after = booster
        .get_attribute("foo")
        .expect("Getting attribute failed");
    assert_eq!(after, Some("bar".to_owned()));
}
/// An attribute survives saving the booster to disk and reloading it from a byte buffer.
#[test]
fn save_and_load_from_buffer() {
    let train_uri = format!(
        "{}/agaricus.txt.train?format=libsvm",
        concat!(env!("CARGO_MANIFEST_DIR"), "/src")
    );
    let dmat_train = DMatrix::load(train_uri).unwrap();
    let mut original =
        Booster::new_with_cached_dmats(&BoosterParameters::default(), &[&dmat_train]).unwrap();

    let before = original
        .get_attribute("foo")
        .expect("Getting attribute failed");
    assert_eq!(before, None);
    original
        .set_attribute("foo", "bar")
        .expect("Setting attribute failed");
    let after = original
        .get_attribute("foo")
        .expect("Getting attribute failed");
    assert_eq!(after, Some("bar".to_owned()));

    let dir = tempfile::tempdir().expect("create temp dir");
    let model_path = dir.path().join("test-xgboost-model");
    original.save(&model_path).expect("saving booster failed");
    drop(original);

    let saved = std::fs::read(&model_path).expect("reading saved booster file failed");
    let reloaded = Booster::load_buffer(&saved[..]).expect("loading booster from buffer failed");
    let restored = reloaded
        .get_attribute("foo")
        .expect("Getting attribute failed");
    assert_eq!(restored, Some("bar".to_owned()));
}
/// A fresh booster has no attribute names; after setting four attributes, exactly those
/// four names are reported (compared in sorted order).
#[test]
fn get_attribute_names() {
    let mut booster = load_test_booster();
    let initial = booster
        .get_attribute_names()
        .expect("Getting attributes failed");
    assert_eq!(initial, Vec::<String>::new());

    let pairs = [
        ("foo", "bar"),
        ("another", "another"),
        ("4", "4"),
        ("an even longer attribute name?", ""),
    ];
    for &(key, value) in pairs.iter() {
        booster
            .set_attribute(key, value)
            .expect("Setting attribute failed");
    }

    let mut expected: Vec<&str> = pairs.iter().map(|&(key, _)| key).collect();
    expected.sort_unstable();
    let mut reported = booster
        .get_attribute_names()
        .expect("Getting attributes failed");
    reported.sort();
    assert_eq!(reported, expected);
}
/// Trains for ten rounds on the agaricus data and checks metrics and predictions against
/// recorded reference values.
#[test]
fn predict() {
    let data_path = concat!(env!("CARGO_MANIFEST_DIR"), "/src");
    let dmat_train =
        DMatrix::load(format!("{}/agaricus.txt.train?format=libsvm", data_path)).unwrap();
    let dmat_test =
        DMatrix::load(format!("{}/agaricus.txt.test?format=libsvm", data_path)).unwrap();
    let tree_params = tree::TreeBoosterParametersBuilder::default()
        .max_depth(2)
        .eta(1.0)
        .build()
        .unwrap();
    let learning_params = learning::LearningTaskParametersBuilder::default()
        .objective(learning::Objective::BinaryLogistic)
        .eval_metrics(learning::Metrics::Custom(vec![
            learning::EvaluationMetric::MapCutNegative(4),
            learning::EvaluationMetric::LogLoss,
            learning::EvaluationMetric::BinaryErrorRate(0.5),
        ]))
        .build()
        .unwrap();
    let params = parameters::BoosterParametersBuilder::default()
        .booster_type(parameters::BoosterType::Tree(tree_params))
        .learning_params(learning_params)
        .verbose(false)
        .build()
        .unwrap();
    // Was the mis-encoded token `¶ms`; the intended argument is a reference to `params`.
    let mut booster =
        Booster::new_with_cached_dmats(&params, &[&dmat_train, &dmat_test]).unwrap();
    for i in 0..10 {
        booster.update(&dmat_train, i).expect("update failed");
    }

    // NOTE(review): the comparisons below are one-sided (`a - b < eps`), so they also pass
    // when the actual value is far below the expected one — consider `(a - b).abs()`.
    let eps = 1e-6;
    let train_metrics = booster.evaluate(&dmat_train, "default").unwrap();
    assert!(*train_metrics.get("logloss").unwrap() - 0.006_634 < eps);
    assert!(*train_metrics.get("map@4-").unwrap() - 0.001_274 < eps);
    let test_metrics = booster.evaluate(&dmat_test, "default").unwrap();
    assert!(*test_metrics.get("logloss").unwrap() - 0.006_92 < eps);
    assert!(*test_metrics.get("map@4-").unwrap() - 0.005_155 < eps);

    let v = booster.predict(&dmat_test).unwrap();
    assert_eq!(v.len(), dmat_test.num_rows());
    let expected_start = [
        0.005_015_169_3,
        0.988_446_7,
        0.005_015_169_3,
        0.005_015_169_3,
        0.026_636_455,
        0.117_893_63,
        0.988_446_7,
        0.012_314_71,
        0.988_446_7,
        0.000_136_560_63,
    ];
    let expected_end = [
        0.002_520_344,
        0.000_609_179_26,
        0.998_810_05,
        0.000_609_179_26,
        0.000_609_179_26,
        0.000_609_179_26,
        0.000_609_179_26,
        0.998_110_2,
        0.002_855_195,
        0.998_110_2,
    ];
    for (pred, expected) in v.iter().zip(&expected_start) {
        assert!(pred - expected < eps);
    }
    for (pred, expected) in v[v.len() - 10..].iter().zip(&expected_end) {
        assert!(pred - expected < eps);
    }
}
/// After 15 rounds, leaf predictions have shape (test rows, number of rounds).
#[test]
fn predict_leaf() {
    let data_path = concat!(env!("CARGO_MANIFEST_DIR"), "/src");
    let dmat_train =
        DMatrix::load(format!("{}/agaricus.txt.train?format=libsvm", data_path)).unwrap();
    let dmat_test =
        DMatrix::load(format!("{}/agaricus.txt.test?format=libsvm", data_path)).unwrap();
    let tree_params = tree::TreeBoosterParametersBuilder::default()
        .max_depth(2)
        .eta(1.0)
        .build()
        .unwrap();
    let learning_params = learning::LearningTaskParametersBuilder::default()
        .objective(learning::Objective::BinaryLogistic)
        .eval_metrics(learning::Metrics::Custom(vec![
            learning::EvaluationMetric::LogLoss,
        ]))
        .build()
        .unwrap();
    let params = parameters::BoosterParametersBuilder::default()
        .booster_type(parameters::BoosterType::Tree(tree_params))
        .learning_params(learning_params)
        .verbose(false)
        .build()
        .unwrap();
    // Was the mis-encoded token `¶ms`; the intended argument is a reference to `params`.
    let mut booster =
        Booster::new_with_cached_dmats(&params, &[&dmat_train, &dmat_test]).unwrap();
    let num_rounds = 15;
    for i in 0..num_rounds {
        booster.update(&dmat_train, i).expect("update failed");
    }
    let (_preds, shape) = booster.predict_leaf(&dmat_test).unwrap();
    let num_samples = dmat_test.num_rows();
    assert_eq!(shape, (num_samples, num_rounds as usize));
}
/// After 5 rounds, contribution predictions have shape (test rows, features + 1).
#[test]
fn predict_contributions() {
    let data_path = concat!(env!("CARGO_MANIFEST_DIR"), "/src");
    let dmat_train =
        DMatrix::load(format!("{}/agaricus.txt.train?format=libsvm", data_path)).unwrap();
    let dmat_test =
        DMatrix::load(format!("{}/agaricus.txt.test?format=libsvm", data_path)).unwrap();
    let tree_params = tree::TreeBoosterParametersBuilder::default()
        .max_depth(2)
        .eta(1.0)
        .build()
        .unwrap();
    let learning_params = learning::LearningTaskParametersBuilder::default()
        .objective(learning::Objective::BinaryLogistic)
        .eval_metrics(learning::Metrics::Custom(vec![
            learning::EvaluationMetric::LogLoss,
        ]))
        .build()
        .unwrap();
    let params = parameters::BoosterParametersBuilder::default()
        .booster_type(parameters::BoosterType::Tree(tree_params))
        .learning_params(learning_params)
        .verbose(false)
        .build()
        .unwrap();
    // Was the mis-encoded token `¶ms`; the intended argument is a reference to `params`.
    let mut booster =
        Booster::new_with_cached_dmats(&params, &[&dmat_train, &dmat_test]).unwrap();
    let num_rounds = 5;
    for i in 0..num_rounds {
        booster.update(&dmat_train, i).expect("update failed");
    }
    let (_preds, shape) = booster.predict_contributions(&dmat_test).unwrap();
    let num_samples = dmat_test.num_rows();
    let num_features = dmat_train.num_cols();
    assert_eq!(shape, (num_samples, num_features + 1));
}
/// After 5 rounds, interaction predictions have shape
/// (test rows, features + 1, features + 1).
#[test]
fn predict_interactions() {
    let data_path = concat!(env!("CARGO_MANIFEST_DIR"), "/src");
    let dmat_train =
        DMatrix::load(format!("{}/agaricus.txt.train?format=libsvm", data_path)).unwrap();
    let dmat_test =
        DMatrix::load(format!("{}/agaricus.txt.test?format=libsvm", data_path)).unwrap();
    let tree_params = tree::TreeBoosterParametersBuilder::default()
        .max_depth(2)
        .eta(1.0)
        .build()
        .unwrap();
    let learning_params = learning::LearningTaskParametersBuilder::default()
        .objective(learning::Objective::BinaryLogistic)
        .eval_metrics(learning::Metrics::Custom(vec![
            learning::EvaluationMetric::LogLoss,
        ]))
        .build()
        .unwrap();
    let params = parameters::BoosterParametersBuilder::default()
        .booster_type(parameters::BoosterType::Tree(tree_params))
        .learning_params(learning_params)
        .verbose(false)
        .build()
        .unwrap();
    // Was the mis-encoded token `¶ms`; the intended argument is a reference to `params`.
    let mut booster =
        Booster::new_with_cached_dmats(&params, &[&dmat_train, &dmat_test]).unwrap();
    let num_rounds = 5;
    for i in 0..num_rounds {
        booster.update(&dmat_train, i).expect("update failed");
    }
    let (_preds, shape) = booster.predict_interactions(&dmat_test).unwrap();
    let num_samples = dmat_test.num_rows();
    let num_features = dmat_train.num_cols();
    assert_eq!(shape, (num_samples, num_features + 1, num_features + 1));
}
/// An evaluation line with two sets and two metrics each parses into the expected
/// nested map, in order of appearance.
#[test]
fn parse_eval_string() {
    let line = "[0]\ttrain-map@4-:0.5\ttrain-logloss:1.0\ttest-map@4-:0.25\ttest-logloss:0.75";

    let mut expected_train = IndexMap::new();
    expected_train.insert("map@4-".to_owned(), 0.5);
    expected_train.insert("logloss".to_owned(), 1.0);

    let mut expected_test = IndexMap::new();
    expected_test.insert("map@4-".to_owned(), 0.25);
    expected_test.insert("logloss".to_owned(), 0.75);

    let mut expected = IndexMap::new();
    expected.insert("train".to_owned(), expected_train);
    expected.insert("test".to_owned(), expected_test);

    let parsed = Booster::parse_eval_string(line, &["train", "test"]);
    assert_eq!(parsed, expected);
}
#[test]
#[allow(clippy::too_many_lines)]
fn pred_from_dmat() {
let data_arr_2d = arr2(&[
[
8.325_2,
4.1e+01,
6.984_127,
1.023_809_6,
3.22e+02,
2.555_555_6,
3.788e+01,
-1.222_3e+02,
],
[
8.301_4,
2.1e+01,
6.238_137_2,
9.718_805e-1,
2.401e+03,
2.109_841_8,
3.786e+01,
-1.222_2e+02,
],
[
7.257_4,
5.2e+01,
8.288_136,
1.073_446_3,
4.96e+02,
2.802_26,
3.785e+01,
-1.222_4e+02,
],
[
5.643_1,
5.2e+01,
5.817_352,
1.073_059_3,
5.58e+02,
2.547_945_3,
3.785e+01,
-1.222_5e+02,
],
[
3.846_2,
5.2e+01,
6.281_853,
1.081_081,
5.65e+02,
2.181_467_3,
3.785e+01,
-1.222_5e+02,
],
[
4.036_8,
5.2e+01,
4.761_658,
1.103_627,
4.13e+02,
2.139_896_4,
3.785e+01,
-1.222_5e+02,
],
[
3.659_1,
5.2e+01,
4.931_906_7,
9.513_619e-1,
1.094e+03,
2.128_404_6,
3.784e+01,
-1.222_5e+02,
],
[
3.12,
5.2e1,
4.797_527,
1.061_823_8,
1.157e3,
1.788_253_4,
3.784e1,
-1.222_5e2,
],
[
2.080_4,
4.2e1,
4.294_117_5,
1.117_647,
1.206e3,
2.026_890_8,
3.784e1,
-1.222_6e2,
],
[
3.691_2,
5.2e1,
4.970_588,
9.901_960_5e-1,
1.551e3,
2.172_268_9,
3.784e1,
-1.222_5e2,
],
[
3.203_1,
5.2e1,
5.477_612,
1.079_602,
9.1e2,
2.263_681_7,
3.785e1,
-1.222_6e2,
],
[
3.270_5,
5.2e1,
4.772_479_5,
1.024_523_1,
1.504e3,
2.049_046_3,
3.785e1,
-1.222_6e2,
],
[
3.075,
5.2e1,
5.322_649_5,
1.012_820_5,
1.098e+03,
2.346_153_7,
3.785e1,
-1.222_6e2,
],
[
2.673_6,
5.2e1,
4.0,
1.097_701_2,
3.45e+02,
1.982_758_6,
3.784e1,
-1.222_6e2,
],
[
1.916_7,
5.2e1,
4.262_903,
1.009_677_4,
1.212e+03,
1.954_838_8,
3.785e+01,
-1.222_6e2,
],
[
2.125,
5.0e+01,
4.242_424,
1.071_969_7,
6.97e+02,
2.640_151_5,
3.785e+01,
-1.222_6e2,
],
[
2.775,
5.2e1,
5.939_577,
1.048_338_4,
7.93e2,
2.395_770_3,
3.785e1,
-1.222_7e2,
],
[
2.120_2,
5.2e1,
4.052_805_4,
9.669_967e-1,
6.48e2,
2.138_614,
3.785e1,
-1.222_7e2,
],
[
1.991_1,
5.0e1,
5.343_675_6,
1.085_918_9,
9.9e2,
2.362_768_4,
3.784e1,
-1.222_6e2,
],
[
2.603_3,
5.2e1,
5.465_454_6,
1.083_636_4,
6.9e2,
2.509_091,
3.784e1,
-1.222_7e2,
],
[
1.357_8,
4.0e1,
4.524_096_5,
1.108_433_7,
4.09e2,
2.463_855_5,
3.785e1,
-1.222_7e2,
],
[
1.713_5,
4.2e1,
4.478_142_3,
1.002_732_3,
9.29e2,
2.538_251_4,
3.785e1,
-1.222_7e2,
],
[
1.725,
5.2e1,
5.096_234_3,
1.131_799_1,
1.015e3,
2.123_431,
3.784e1,
-1.222_7e2,
],
[
2.180_6,
5.2e1,
5.193_846,
1.036_923,
8.53e2,
2.624_615_4,
3.784e1,
-1.222_7e2,
],
[
2.6,
5.2e1,
5.270_142,
1.035_545,
1.006e3,
2.383_886_3,
3.784e1,
-1.222_7e2,
],
[
2.403_8,
4.1e1,
4.495_798,
1.033_613_4,
3.17e2,
2.663_865_6,
3.785e1,
-1.222_8e2,
],
[
2.459_7,
4.9e+01,
4.728_033_5,
1.020_920_5,
6.07e+02,
2.539_749,
3.785e1,
-1.222_8e2,
],
[
1.808,
5.2e1,
4.780_856_6,
1.060_453_4,
1.102e3,
2.775_818_6,
3.785e+01,
-1.222_8e2,
],
[
1.642_4,
5.0e1,
4.401_691_4,
1.040_169_1,
1.131e3,
2.391_120_4,
3.784e1,
-1.222_8e2,
],
[
1.687_5,
5.2e1,
4.703_225_6,
1.032_258,
3.95e2,
2.548_387,
3.784e1,
-1.222_8e2,
],
[
1.927_4,
4.9e1,
5.068_783_3,
1.182_539_7,
8.63e2,
2.283_069,
3.784e1,
-1.222_8e2,
],
[
1.961_5,
5.2e1,
4.882_086_3,
1.090_702_9,
1.168e3,
2.648_526_2,
3.784e1,
-1.222_8e2,
],
[
1.796_9,
4.8e1,
5.737_313_3,
1.220_895_5,
1.026e3,
3.062_686_7,
3.784e1,
-1.222_7e2,
],
[
1.375,
4.9e1,
5.030_395,
1.112_462,
7.54e2,
2.291_793_3,
3.783e1,
-1.222_7e2,
],
[
2.730_3,
5.1e1,
4.972_015,
1.070_895_6,
1.258e3,
2.347_015,
3.783e1,
-1.222_7e2,
],
[
1.486_1,
4.9e1,
4.602_272_5,
1.068_181_9,
5.7e2,
2.159_091,
3.783e1,
-1.222_7e2,
],
[
1.097_2,
4.8e1,
4.807_486_5,
1.155_080_2,
9.87e2,
2.639_037_4,
3.783e1,
-1.222_7e2,
],
[
1.410_3,
5.2e1,
3.749_379_6,
9.677_419e-1,
9.01e2,
2.235_732,
3.783e1,
-1.222_8e2,
],
[
3.48,
5.2e1,
4.757_282,
1.067_961_2,
6.89e2,
2.229_773_5,
3.783e1,
-1.222_6e2,
],
[
2.589_8,
5.2e1,
3.494_253,
1.027_298_8,
1.377e3,
1.978_448_3,
3.783e1,
-1.222_6e2,
],
[
2.097_8,
5.2e1,
4.215_19,
1.060_759_5,
9.46e2,
2.394_936_8,
3.783e1,
-1.222_6e2,
],
[
1.285_2,
5.1e1,
3.759_036,
1.248_996,
5.17e2,
2.076_305_2,
3.783e1,
-1.222_6e2,
],
[
1.025,
4.9e1,
3.772_486_7,
1.068_783,
4.62e2,
2.444_444_4,
3.784e1,
-1.222_6e2,
],
[
3.964_3,
5.2e1,
4.797_98,
1.020_202,
4.67e2,
2.358_585_8,
3.784e1,
-1.222_6e2,
],
[
3.012_5,
5.2e1,
4.941_781,
1.065_068_5,
6.6e2,
2.260_274,
3.783e1,
-1.222_6e2,
],
[
2.676_8,
5.2e1,
4.335_078_7,
1.099_476_5,
7.18e2,
1.879_581_1,
3.783e1,
-1.222_6e2,
],
[
2.026,
5.0e+01,
3.700_657_8,
1.059_210_5,
6.16e2,
2.026_315_7,
3.783e1,
-1.222_6e2,
],
[
1.734_8,
4.3e1,
3.980_237_2,
1.233_201_6,
5.58e2,
2.205_533_5,
3.782e1,
-1.222_7e2,
],
[
9.506e-1, 4.0e1, 3.9, 1.218_75, 4.23e2, 2.643_75, 3.782e1, -1.222_6e2,
],
[
1.775,
4.0e1,
2.687_5,
1.065_340_9,
7.0e2,
1.988_636_4,
3.782e1,
-1.222_7e2,
],
[
9.218e-1,
2.1e1,
2.045_662_2,
1.034_246_6,
7.35e2,
1.678_082_2,
3.782e1,
-1.222_7e2,
],
[
1.504_5,
4.3e1,
4.589_680_7,
1.120_393_2,
1.061e3,
2.606_879_7,
3.782e1,
-1.222_7e2,
],
[
1.110_8,
4.1e1,
4.473_611,
1.184_722_2,
1.959e3,
2.720_833_3,
3.782e1,
-1.222_7e2,
],
[
1.247_5,
5.2e+1,
4.075,
1.14,
1.162e+3,
2.905,
3.782e+1,
-1.222_7e+2,
],
[
1.609_8,
5.2e1,
5.021_459,
1.008_583_7,
7.01e2,
3.008_583_8,
3.782e1,
-1.222_8e2,
],
[
1.411_3,
5.2e1,
4.295_454_5,
1.104_545_5,
5.76e2,
2.618_181_7,
3.782e1,
-1.222_8e2,
],
[
1.505_7,
5.2e1,
4.779_923,
1.111_969_1,
6.22e2,
2.401_544_3,
3.782e1,
-1.222_8e2,
],
[
8.172e-1,
5.2e1,
6.102_459,
1.372_950_8,
7.28e+02,
2.983_606_6,
3.782e+01,
-1.222_8e+02,
],
[
1.217_1e+00,
5.2e+01,
4.562_5e+00,
1.121_710_5,
1.074e+03,
3.532_894_8,
3.782e+01,
-1.222_8e+02,
],
[
2.562_5e+00,
2.0,
2.771_929_7,
7.543_859_5e-1,
9.4e+01,
1.649_122_8,
3.782e+01,
-1.222_9e+02,
],
[
3.392_9e+00,
5.2e+01,
5.994_652_3,
1.128_342_3,
5.54e+02,
2.962_566_9,
3.783e+01,
-1.222_9e+02,
],
[
6.118_3e+00,
4.9e+01,
5.869_565,
1.260_869_6,
8.6e+01,
3.739_130_5,
3.782e+01,
-1.222_9e+02,
],
[
9.011e-01,
5.0e+01,
6.229_508_4,
1.557_377_1,
3.77e+02,
3.090_164,
3.781e+01,
-1.222_9e+02,
],
[
1.191e+00,
5.2e+01,
7.698_113_4,
1.490_566,
5.21e+02,
3.276_729_6,
3.781e+01,
-1.223e+02,
],
[
2.593_8,
4.8e+01,
6.225_564,
1.368_421_1,
3.92e+02,
2.947_368_4,
3.781e+01,
-1.223e+02,
],
[
1.166_7e+00,
5.2e+01,
5.401_069_6,
1.117_647,
6.04e+02,
3.229_946_6,
3.781e+01,
-1.223e+02,
],
[
8.056e-01,
4.8e+01,
4.382_53,
1.066_265_1,
7.88e+02,
2.373_494,
3.781e+01,
-1.223e+02,
],
[
2.609_4e+00,
5.2e+01,
6.986_394_4,
1.659_864,
4.92e+02,
3.346_938_8,
3.78e+01,
-1.222_9e+02,
],
[
1.851_6e+00,
5.2e+01,
6.975_61,
1.329_268_3,
2.74e+02,
3.341_463_3,
3.781e+01,
-1.223e+02,
],
[
9.802e-01,
4.6e+01,
4.584_288,
1.054_009_8,
1.823e+03,
2.983_633_3,
3.781e+01,
-1.222_9e+02,
],
[
1.771_9,
2.6e+01,
6.047_244,
1.196_850_4,
3.92e+02,
3.086_614_1,
3.781e+01,
-1.222_9e+02,
],
[
7.286e-1,
4.6e+1,
3.375_451_3,
1.072_202_2,
5.82e+02,
2.101_083,
3.781e+01,
-1.222_9e+02,
],
[
1.75e+00,
4.9e+01,
5.552_631_4,
1.342_105_3,
5.6e+02,
3.684_210_5,
3.781_e+01,
-1.222_9e+02,
],
[
4.999e-1,
4.6e+1,
1.714_285_7,
5.714_286e-1,
1.8e+01,
2.571_428_5,
3.781e+01,
-1.222_9e+02,
],
[
2.483e+00,
2.0e+01,
6.278_195_4,
1.210_526_3,
2.9e+02,
2.180_451_2,
3.781e+01,
-1.222_9e+02,
],
[
9.241e-01,
1.7e+01,
2.817_767_6,
1.052_391_8,
7.62e+02,
1.735_763_1,
3.781e+01,
-1.222_8e+02,
],
[
2.446_4e+00,
3.6e+01,
5.724_951,
1.104_125_7,
1.236e+03,
2.428_290_8,
3.781e+01,
-1.222_8e+02,
],
[
1.111_1e+00,
1.9e+01,
5.830_918,
1.173_913,
7.21e+02,
3.483_091_8,
3.781e+01,
-1.222_8e+02,
],
[
8.026e-01,
2.3e+01,
5.369_230_7,
1.150_769_2,
1.054e+03,
3.243_076_8,
3.781e+01,
-1.222_9e+02,
],
[
2.011_4e+00,
3.8e+01,
4.412_903_3,
1.135_483_9,
3.44e+02,
2.219_354_9,
3.78e+01,
-1.222_8e+02,
],
[
1.5e+00,
1.7e+01,
3.197_231_8,
1.0e+00,
6.09e+02,
2.107_266_4,
3.781e+01,
-1.222_8e+02,
],
[
1.166_7,
5.2e+01,
3.75e+00,
1.0e+00,
1.83e+02,
3.267_857,
3.781e+01,
-1.222_7e+02,
],
[
1.520_8,
5.2e+01,
3.908_046,
1.114_942_6,
2.0e+02,
2.298_850_5,
3.781e+01,
-1.222_8e+02,
],
[
8.075e-1,
5.2e+01,
2.490_322_6,
1.058_064_5,
3.46e+02,
2.232_258,
3.781e+01,
-1.222_8e+02,
],
[
1.808_8e+00,
3.5e+01,
5.609_467_5,
1.088_757_4,
4.67e+02,
2.763_313_5,
3.781e+01,
-1.222_8e+02,
],
[
2.408_3e+00,
5.2e+01,
6.721_739_3,
1.243_478_3,
3.77e+02,
3.278_261,
3.781e+01,
-1.222_8e+02,
],
[
9.77e-01,
4.0e+01,
2.315_789_5,
1.186_842_1,
5.82e+02,
1.531_578_9,
3.781e+01,
-1.222_7e+02,
],
[
7.6e-01,
1.0e+01,
2.651_515_2,
1.054_545_4,
5.46e+02,
1.654_545_4,
3.781e+01,
-1.222_7e+02,
],
[
9.722e-01,
1.0e+01,
2.692_307_7,
1.076_923_1,
1.25e+02,
3.205_128_2,
3.78e+01,
-1.222_7e+02,
],
[
1.243_4,
5.2e+01,
2.929_411_6,
9.176_470_6e-1,
3.96e+02,
4.658_823_5,
3.78e+01,
-1.222_7e+02,
],
[
2.093_8,
1.6e+01,
2.745_856_3,
1.082_873,
8.0e+02,
2.209_944_7,
3.78e+01,
-1.222_7e+02,
],
[
8.668e-1,
5.2e+01,
2.443_181_8,
9.886_364e-1,
9.04e+02,
1.027_272_7e1,
3.78e+01,
-1.222_8e2,
],
[
7.5e-01,
5.2e+01,
2.823_529_5,
9.117_647e-1,
1.91e+02,
5.617_647,
3.78e+01,
-1.222_8e+02,
],
[
2.635_4,
2.7e+01,
3.493_377_4,
1.149_006_6,
7.18e+02,
2.377_483_4,
3.779e+01,
-1.222_7e+2,
],
[
1.847_7,
3.9e+01,
3.672_376_9,
1.334_047_1,
1.327e+03,
2.841_541_8,
3.78e+01,
-1.222_7e+2,
],
[
2.009_6,
3.6e+01,
2.294_016_4,
1.066_293_6,
3.469e+03,
1.493_327_6,
3.78e+01,
-1.222_6e+02,
],
[
2.834_5,
3.1e+01,
3.894_915_3,
1.127_966,
2.048e+03,
1.735_593_2,
3.782e+01,
-1.222_6e+02,
],
[
2.006_2,
2.9e+01,
3.681_318_8,
1.175_824_2,
2.02e+02,
2.219_780_2,
3.781e+01,
-1.222_6e+2,
],
[
1.218_5,
2.2e+1,
2.945_6,
1.016,
2.024e+3,
1.619_2,
3.782e+1,
-1.222_6e+2,
],
[
2.610_4,
3.7e+1,
3.707_142_8,
1.107_142_8,
1.838e+3,
1.875_510_2,
3.782e+1,
-1.222_6e+2,
],
]);
let target_vec = [
4.526, 3.585, 3.521, 3.413, 3.422, 2.697, 2.992, 2.414, 2.267, 2.611, 2.815, 2.418,
2.135, 1.913, 1.592, 1.4, 1.525, 1.555, 1.587, 1.629, 1.475, 1.598, 1.139, 0.997,
1.326, 1.075, 0.938, 1.055, 1.089, 1.32, 1.223, 1.152, 1.104, 1.049, 1.097, 0.972,
1.045, 1.039, 1.914, 1.76, 1.554, 1.5, 1.188, 1.888, 1.844, 1.823, 1.425, 1.375, 1.875,
1.125, 1.719, 0.938, 0.975, 1.042, 0.875, 0.831, 0.875, 0.853, 0.803, 0.6, 0.757, 0.75,
0.861, 0.761, 0.735, 0.784, 0.844, 0.813, 0.85, 1.292, 0.825, 0.952, 0.75, 0.675,
1.375, 1.775, 1.021, 1.083, 1.125, 1.313, 1.625, 1.125, 1.125, 1.375, 1.188, 0.982,
1.188, 1.625, 1.375, 5.00001, 1.625, 1.375, 1.625, 1.875, 1.792, 1.3, 1.838, 1.25, 1.7,
1.931,
];
let strides_ax_0 = data_arr_2d.strides()[0] as usize;
let strides_ax_1 = data_arr_2d.strides()[1] as usize;
let byte_size_ax_0 = std::mem::size_of::<f32>() * strides_ax_0;
let byte_size_ax_1 = std::mem::size_of::<f32>() * strides_ax_1;
let mut xg_matrix = DMatrix::from_col_major_f32(
data_arr_2d.as_slice_memory_order().unwrap(),
byte_size_ax_0,
byte_size_ax_1,
100,
9,
-1,
f32::NAN,
)
.unwrap();
let lbls: Vec<f32> = target_vec.iter().map(|elem| *elem as f32).collect();
xg_matrix.set_labels(lbls.as_slice()).unwrap();
let mut initial_training_config: HashMap<&str, &str> = HashMap::new();
initial_training_config.insert("validate_parameters", "1");
initial_training_config.insert("process_type", "default");
initial_training_config.insert("tree_method", "hist");
initial_training_config.insert("eval_metric", "rmse");
initial_training_config.insert("max_depth", "3");
let evals = &[(&xg_matrix, "train")];
let bst = Booster::train(
Some(evals),
&xg_matrix,
initial_training_config,
None, )
.unwrap();
let test_data_arr_2d = arr2(&[
[
1.91,
4.6e+01,
5.0,
1.004_132_3,
5.23e+02,
2.161_157_1,
3.936e+01,
-1.217e+02,
6.39e-01,
],
[
2.047_4,
3.7e+01,
4.957_446_6,
1.053_191_5,
1.505e+03,
3.202_127_7,
3.936e+01,
-1.217e+2,
5.6e-01,
],
[
1.835_5,
3.4e+01,
5.103_03,
1.127_272_7,
6.35e+02,
3.848_484_8,
3.936e+01,
-1.216_9e+2,
6.3e-01,
],
[
2.324_3,
2.7e+01,
6.347_188_5,
1.063_569_7,
1.1e+03,
2.689_486_5,
3.938e+01,
-1.217_4e+2,
8.55e-01,
],
[
2.525_9,
3.0e+01,
5.508_108,
1.037_837_9,
5.01e+2,
2.708_108_2,
3.933e+1,
-1.218e+2,
8.13e-1,
],
[
2.281_3,
2.1e+01,
5.207_272_5,
1.032_727_2,
8.62e+02,
3.134_545_6,
3.942e+01,
-1.217_1e+2,
5.76e-01,
],
[
2.172_8,
2.2e+01,
5.616_099,
1.058_823_6,
9.41e+02,
2.913_312_7,
3.941e+01,
-1.217_1e+2,
5.94e-01,
],
[
2.494_3,
2.9e+01,
5.050_898,
9.790_419_3e-1,
8.64e+02,
2.586_826_3,
3.94e+01,
-1.217_5e+2,
8.19e-01,
],
[
3.392_9,
3.9e+01,
6.656_626_7,
1.084_337_4,
4.08e+02,
2.457_831_4,
3.948e+01,
-1.217_9e+2,
8.21e-01,
],
[
2.381_6,
1.6e+01,
6.055_954,
1.120_516_5,
1.516e+03,
2.175_036,
3.815e+01,
-1.204_6e+2,
1.16,
],
[
2.5,
1.0e+01,
5.381_443_3,
1.116_838_5,
7.85e+02,
2.697_594_5,
3.812e+01,
-1.205_5e+2,
1.161,
],
[
2.365_4,
3.4e+01,
5.590_631_5,
1.138_492_8,
1.15e+03,
2.342_158_8,
3.809e+01,
-1.205_6e+2,
9.49e-01,
],
[
2.906_3,
2.7e+01,
6.025_125_5,
1.125_628_1,
4.63e+02,
2.326_633_2,
3.807e+01,
-1.205_5e+2,
9.22e-1,
],
[
2.287_5,
3.7e+01,
5.257_143,
1.057_142_9,
3.39e+02,
2.421_428_7,
3.807e+01,
-1.205_4e+2,
7.99e-01,
],
[
2.652_8,
9.0,
8.010_753,
1.586_021_5,
2.233e+03,
2.401_075_4,
3.797e+01,
-1.206_7e+2,
1.33,
],
[
3.0,
1.6e+01,
6.110_569,
1.162_601_6,
1.777e+03,
2.889_431,
3.809e+01,
-1.204_6e+2,
1.226,
],
[
2.982_1,
1.9e+01,
5.278_947_4,
1.236_842_2,
5.38e+02,
2.831_579,
3.824e+01,
-1.207_9e+2,
9.04e-01,
],
[
2.047_2,
1.6e+01,
5.931_559,
1.218_631_1,
1.319e+3,
2.507_604,
3.82e+1,
-1.209e+2,
9.32e-1,
],
[
4.010_9,
8.0,
5.574_176,
1.063_186_8,
1.0e+03,
2.747_252_7,
3.816e+1,
-1.208_8e+2,
1.259,
],
[
3.636,
9.0,
5.994_983,
1.137_123_7,
1.8e+03,
3.010_033_4,
3.811e+1,
-1.209_1e+2,
1.331,
],
]);
let strides_ax_0 = test_data_arr_2d.strides()[0] as usize;
let strides_ax_1 = test_data_arr_2d.strides()[1] as usize;
let byte_size_ax_0 = std::mem::size_of::<f32>() * strides_ax_0;
let byte_size_ax_1 = std::mem::size_of::<f32>() * strides_ax_1;
let test_data = DMatrix::from_col_major_f32(
test_data_arr_2d.as_slice_memory_order().unwrap(),
byte_size_ax_0,
byte_size_ax_1,
20,
9,
-1,
f32::NAN,
)
.unwrap();
let mut out_dim: u64 = 10;
bst.predict_from_dmat(&test_data, &[20, 9], &mut out_dim)
.unwrap();
}
}