use std::default::Default;
use std::fmt;

use super::Interval;
/// Tree construction algorithm, emitted as the `tree_method` parameter.
///
/// Defaults to [`TreeMethod::Auto`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TreeMethod {
    /// Serialized as `auto`.
    Auto,
    /// Serialized as `exact`.
    Exact,
    /// Serialized as `approx`.
    Approx,
    /// Serialized as `hist`.
    Hist,
    /// Serialized as `gpu_exact`.
    GpuExact,
    /// Serialized as `gpu_hist`.
    GpuHist,
}

// `Display` is implemented instead of `ToString`. The standard blanket impl
// still provides `to_string()`, and the value can also be used directly with
// `format!`/`write!`. `pad` respects width/alignment flags.
impl fmt::Display for TreeMethod {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            TreeMethod::Auto => "auto",
            TreeMethod::Exact => "exact",
            TreeMethod::Approx => "approx",
            TreeMethod::Hist => "hist",
            TreeMethod::GpuExact => "gpu_exact",
            TreeMethod::GpuHist => "gpu_hist",
        };
        f.pad(name)
    }
}

impl Default for TreeMethod {
    fn default() -> Self {
        TreeMethod::Auto
    }
}
/// A single tree updater.
///
/// A sequence of these is emitted as the comma-separated `updater` parameter.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TreeUpdater {
    /// Serialized as `grow_colmaker`.
    GrowColMaker,
    /// Serialized as `distcol`.
    DistCol,
    /// Serialized as `grow_histmaker`.
    GrowHistMaker,
    /// Serialized as `grow_local_histmaker`.
    GrowLocalHistMaker,
    /// Serialized as `grow_skmaker`.
    GrowSkMaker,
    /// Serialized as `sync`.
    Sync,
    /// Serialized as `refresh`.
    Refresh,
    /// Serialized as `prune`.
    Prune,
}

// `Display` is implemented instead of `ToString`. The blanket impl still
// supplies `to_string()`.
impl fmt::Display for TreeUpdater {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            TreeUpdater::GrowColMaker => "grow_colmaker",
            TreeUpdater::DistCol => "distcol",
            TreeUpdater::GrowHistMaker => "grow_histmaker",
            TreeUpdater::GrowLocalHistMaker => "grow_local_histmaker",
            TreeUpdater::GrowSkMaker => "grow_skmaker",
            TreeUpdater::Sync => "sync",
            TreeUpdater::Refresh => "refresh",
            TreeUpdater::Prune => "prune",
        };
        f.pad(name)
    }
}
/// Value of the `process_type` parameter.
///
/// Defaults to [`ProcessType::Default`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ProcessType {
    /// Serialized as `default`.
    Default,
    /// Serialized as `update`.
    Update,
}

// `Display` is implemented instead of `ToString`. The blanket impl still
// supplies `to_string()`.
impl fmt::Display for ProcessType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            ProcessType::Default => "default",
            ProcessType::Update => "update",
        };
        f.pad(name)
    }
}

impl Default for ProcessType {
    fn default() -> Self {
        ProcessType::Default
    }
}
/// Value of the `grow_policy` parameter.
///
/// Defaults to [`GrowPolicy::Depthwise`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GrowPolicy {
    /// Serialized as `depthwise`.
    Depthwise,
    /// Serialized as `lossguide`.
    LossGuide,
}

// `Display` is implemented instead of `ToString`. The blanket impl still
// supplies `to_string()`.
impl fmt::Display for GrowPolicy {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            GrowPolicy::Depthwise => "depthwise",
            GrowPolicy::LossGuide => "lossguide",
        };
        f.pad(name)
    }
}

impl Default for GrowPolicy {
    fn default() -> Self {
        GrowPolicy::Depthwise
    }
}
/// Value of the `predictor` parameter.
///
/// Defaults to [`Predictor::Cpu`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Predictor {
    /// Serialized as `cpu_predictor`.
    Cpu,
    /// Serialized as `gpu_predictor`.
    Gpu,
}

// `Display` is implemented instead of `ToString`. The blanket impl still
// supplies `to_string()`.
impl fmt::Display for Predictor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            Predictor::Cpu => "cpu_predictor",
            Predictor::Gpu => "gpu_predictor",
        };
        f.pad(name)
    }
}

impl Default for Predictor {
    fn default() -> Self {
        Predictor::Cpu
    }
}
/// Parameters for the tree booster (`booster` = `gbtree`).
///
/// Construct with `TreeBoosterParametersBuilder`. Fields left unset on the
/// builder take the values from this type's `Default` impl (struct-level
/// `#[builder(default)]`).
///
/// The builder's `validate` range-checks `eta`, `subsample`,
/// `colsample_bytree`, `colsample_bylevel` and `sketch_eps`. Each field is
/// emitted by `as_string_pairs` under a key equal to its own name.
///
/// NOTE(review): `gamma`, `min_child_weight`, `max_delta_step`, `lambda` and
/// `alpha` are stored as `u32`, so fractional values cannot be expressed.
/// Confirm whether the booster needs them to be floats.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
#[builder(default)]
pub struct TreeBoosterParameters {
/// Default 0.3; the builder requires a value in [0, 1].
eta: f32,
/// Default 0.
gamma: u32,
/// Default 6.
max_depth: u32,
/// Default 1.
min_child_weight: u32,
/// Default 0.
max_delta_step: u32,
/// Default 1.0; the builder requires a value in (0, 1].
subsample: f32,
/// Default 1.0; the builder requires a value in (0, 1].
colsample_bytree: f32,
/// Default 1.0; the builder requires a value in (0, 1].
colsample_bylevel: f32,
/// Default 1.
lambda: u32,
/// Default 0.
alpha: u32,
/// Default `TreeMethod::Auto`.
#[builder(default = "TreeMethod::default()")]
tree_method: TreeMethod,
/// Default 0.03; the builder requires a value in (0, 1).
sketch_eps: f32,
/// Default 1.0.
scale_pos_weight: f32,
/// Updater sequence, emitted comma-joined; default `[GrowColMaker, Prune]`.
updater: Vec<TreeUpdater>,
/// Default `true`; emitted as `1`/`0`.
refresh_leaf: bool,
/// Default `ProcessType::Default`.
process_type: ProcessType,
/// Default `GrowPolicy::Depthwise`.
grow_policy: GrowPolicy,
/// Default 0.
max_leaves: u32,
/// Default 256.
max_bin: u32,
/// Default `Predictor::Cpu`.
predictor: Predictor,
}
impl Default for TreeBoosterParameters {
    /// Returns the stock tree-booster settings used when a field is not set
    /// explicitly.
    fn default() -> Self {
        Self {
            // Fractional settings.
            eta: 0.3,
            subsample: 1.0,
            colsample_bytree: 1.0,
            colsample_bylevel: 1.0,
            sketch_eps: 0.03,
            scale_pos_weight: 1.0,

            // Integer settings.
            gamma: 0,
            max_depth: 6,
            min_child_weight: 1,
            max_delta_step: 0,
            lambda: 1,
            alpha: 0,
            max_leaves: 0,
            max_bin: 256,

            // Enum-valued settings delegate to each type's own `Default`.
            tree_method: Default::default(),
            process_type: Default::default(),
            grow_policy: Default::default(),
            predictor: Default::default(),

            // `TreeUpdater` has no `Default`, so the sequence is spelled out.
            updater: vec![TreeUpdater::GrowColMaker, TreeUpdater::Prune],
            refresh_leaf: true,
        }
    }
}
impl TreeBoosterParameters {
    /// Flattens these settings into owned `(name, value)` string pairs.
    ///
    /// The first pair always selects the `gbtree` booster. The remaining pairs
    /// follow the field declaration order, each keyed by its field name.
    /// `refresh_leaf` is encoded as `"1"`/`"0"`, and the updater sequence is
    /// joined with commas.
    pub(crate) fn as_string_pairs(&self) -> Vec<(String, String)> {
        let updater_seq = self
            .updater
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<String>>()
            .join(",");
        let refresh_flag = u8::from(self.refresh_leaf);

        let entries: Vec<(&str, String)> = vec![
            ("booster", "gbtree".to_owned()),
            ("eta", self.eta.to_string()),
            ("gamma", self.gamma.to_string()),
            ("max_depth", self.max_depth.to_string()),
            ("min_child_weight", self.min_child_weight.to_string()),
            ("max_delta_step", self.max_delta_step.to_string()),
            ("subsample", self.subsample.to_string()),
            ("colsample_bytree", self.colsample_bytree.to_string()),
            ("colsample_bylevel", self.colsample_bylevel.to_string()),
            ("lambda", self.lambda.to_string()),
            ("alpha", self.alpha.to_string()),
            ("tree_method", self.tree_method.to_string()),
            ("sketch_eps", self.sketch_eps.to_string()),
            ("scale_pos_weight", self.scale_pos_weight.to_string()),
            ("updater", updater_seq),
            ("refresh_leaf", refresh_flag.to_string()),
            ("process_type", self.process_type.to_string()),
            ("grow_policy", self.grow_policy.to_string()),
            ("max_leaves", self.max_leaves.to_string()),
            ("max_bin", self.max_bin.to_string()),
            ("predictor", self.predictor.to_string()),
        ];

        entries
            .into_iter()
            .map(|(name, value)| (name.to_owned(), value))
            .collect()
    }
}
impl TreeBoosterParametersBuilder {
    /// Range-checks the fractional parameters held by the builder.
    ///
    /// The checks run in a fixed order, and the first failing one is returned
    /// as the error. `Interval::validate` receives the builder's stored field
    /// value. Presumably an unset field passes the check — TODO confirm
    /// against `Interval`.
    fn validate(&self) -> Result<(), String> {
        // Both ends inclusive, judging by the constructor name.
        let eta_bounds = Interval::new_closed_closed(0.0, 1.0);
        eta_bounds.validate(&self.eta, "eta")?;

        // Lower end excluded, upper end included.
        let subsample_bounds = Interval::new_open_closed(0.0, 1.0);
        subsample_bounds.validate(&self.subsample, "subsample")?;

        let bytree_bounds = Interval::new_open_closed(0.0, 1.0);
        bytree_bounds.validate(&self.colsample_bytree, "colsample_bytree")?;

        let bylevel_bounds = Interval::new_open_closed(0.0, 1.0);
        bylevel_bounds.validate(&self.colsample_bylevel, "colsample_bylevel")?;

        // Both ends excluded.
        let sketch_bounds = Interval::new_open_open(0.0, 1.0);
        sketch_bounds.validate(&self.sketch_eps, "sketch_eps")?;

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the value stored under `key` in the serialized pairs, if any.
    fn value_of<'a>(pairs: &'a [(String, String)], key: &str) -> Option<&'a str> {
        pairs.iter().find(|p| p.0 == key).map(|p| p.1.as_str())
    }

    #[test]
    fn tree_params() {
        let p = TreeBoosterParameters::default();
        assert_eq!(p.eta, 0.3);
        let p = TreeBoosterParametersBuilder::default().build().unwrap();
        assert_eq!(p.eta, 0.3);
    }

    #[test]
    fn default_string_pairs() {
        let pairs = TreeBoosterParameters::default().as_string_pairs();
        assert_eq!(pairs.len(), 21);
        assert_eq!(pairs[0], ("booster".to_owned(), "gbtree".to_owned()));
        assert_eq!(value_of(&pairs, "tree_method"), Some("auto"));
        assert_eq!(value_of(&pairs, "updater"), Some("grow_colmaker,prune"));
        assert_eq!(value_of(&pairs, "refresh_leaf"), Some("1"));
        assert_eq!(value_of(&pairs, "process_type"), Some("default"));
        assert_eq!(value_of(&pairs, "grow_policy"), Some("depthwise"));
        assert_eq!(value_of(&pairs, "max_bin"), Some("256"));
        assert_eq!(value_of(&pairs, "predictor"), Some("cpu_predictor"));
    }

    #[test]
    fn enum_names() {
        assert_eq!(TreeMethod::GpuHist.to_string(), "gpu_hist");
        assert_eq!(TreeUpdater::GrowLocalHistMaker.to_string(), "grow_local_histmaker");
        assert_eq!(ProcessType::Update.to_string(), "update");
        assert_eq!(GrowPolicy::LossGuide.to_string(), "lossguide");
        assert_eq!(Predictor::Gpu.to_string(), "gpu_predictor");
    }

    #[test]
    fn builder_validates_ranges() {
        // Values outside the documented intervals are rejected.
        assert!(TreeBoosterParametersBuilder::default().eta(1.5).build().is_err());
        assert!(TreeBoosterParametersBuilder::default().subsample(0.0).build().is_err());
        assert!(TreeBoosterParametersBuilder::default().colsample_bytree(0.0).build().is_err());
        assert!(TreeBoosterParametersBuilder::default().colsample_bylevel(1.5).build().is_err());
        assert!(TreeBoosterParametersBuilder::default().sketch_eps(1.0).build().is_err());
        // Inclusive upper bounds are accepted.
        assert!(TreeBoosterParametersBuilder::default().eta(1.0).build().is_ok());
        assert!(TreeBoosterParametersBuilder::default().subsample(1.0).build().is_ok());
    }
}