use std::default::Default;
use derive_builder::Builder;
use super::Interval;
/// Value of XGBoost's `tree_method` parameter.
///
/// `to_string()` yields the exact string sent to XGBoost. The same strings are accepted by
/// [`str::parse`] (non-panicking) and by the `From<&str>` / `From<String>` conversions
/// (which panic on unknown names).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum TreeMethod {
    /// `"auto"`; this is the [`Default`].
    Auto,
    /// `"exact"`.
    Exact,
    /// `"approx"`.
    Approx,
    /// `"hist"`.
    Hist,
    /// `"gpu_exact"`.
    GpuExact,
    /// `"gpu_hist"`.
    GpuHist,
}

// Implementing `Display` (rather than `ToString` directly) still provides `to_string()` through
// the standard blanket impl, and additionally lets the value be used with `format!`/`write!`.
impl std::fmt::Display for TreeMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(match *self {
            TreeMethod::Auto => "auto",
            TreeMethod::Exact => "exact",
            TreeMethod::Approx => "approx",
            TreeMethod::Hist => "hist",
            TreeMethod::GpuExact => "gpu_exact",
            TreeMethod::GpuHist => "gpu_hist",
        })
    }
}

impl Default for TreeMethod {
    fn default() -> Self {
        TreeMethod::Auto
    }
}

impl std::str::FromStr for TreeMethod {
    type Err = String;

    /// Parses an XGBoost `tree_method` name.
    ///
    /// # Errors
    ///
    /// Returns `"no known tree_method for <s>"` when `s` is not one of the known names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "auto" => Ok(TreeMethod::Auto),
            "exact" => Ok(TreeMethod::Exact),
            "approx" => Ok(TreeMethod::Approx),
            "hist" => Ok(TreeMethod::Hist),
            "gpu_exact" => Ok(TreeMethod::GpuExact),
            "gpu_hist" => Ok(TreeMethod::GpuHist),
            _ => Err(format!("no known tree_method for {}", s)),
        }
    }
}

impl From<String> for TreeMethod {
    /// Converts an owned `tree_method` name.
    ///
    /// # Panics
    ///
    /// Panics if `s` is not a known `tree_method` name.
    fn from(s: String) -> Self {
        Self::from(s.as_str())
    }
}

impl<'a> From<&'a str> for TreeMethod {
    /// Converts a known `tree_method` name. Prefer `s.parse::<TreeMethod>()` when the input
    /// may be invalid and a panic is not acceptable.
    ///
    /// # Panics
    ///
    /// Panics with `no known tree_method for <s>` if `s` is not a known name.
    fn from(s: &'a str) -> Self {
        match s.parse::<TreeMethod>() {
            Ok(method) => method,
            Err(msg) => panic!("{}", msg),
        }
    }
}
/// One entry of XGBoost's `updater` parameter; a list of these is sent comma-separated.
///
/// `to_string()` yields the exact string sent to XGBoost.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum TreeUpdater {
    /// `"grow_colmaker"`.
    GrowColMaker,
    /// `"distcol"`.
    DistCol,
    /// `"grow_histmaker"`.
    GrowHistMaker,
    /// `"grow_local_histmaker"`.
    GrowLocalHistMaker,
    /// `"grow_skmaker"`.
    GrowSkMaker,
    /// `"sync"`.
    Sync,
    /// `"refresh"`.
    Refresh,
    /// `"prune"`.
    Prune,
}

// `Display` supplies `to_string()` via the standard blanket impl and also supports `format!`.
impl std::fmt::Display for TreeUpdater {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let name = match *self {
            TreeUpdater::GrowColMaker => "grow_colmaker",
            TreeUpdater::DistCol => "distcol",
            TreeUpdater::GrowHistMaker => "grow_histmaker",
            TreeUpdater::GrowLocalHistMaker => "grow_local_histmaker",
            TreeUpdater::GrowSkMaker => "grow_skmaker",
            TreeUpdater::Sync => "sync",
            TreeUpdater::Refresh => "refresh",
            TreeUpdater::Prune => "prune",
        };
        f.write_str(name)
    }
}
/// Value of XGBoost's `process_type` parameter.
///
/// `to_string()` yields the exact string sent to XGBoost.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ProcessType {
    /// `"default"`; this is also the [`Default`] value.
    Default,
    /// `"update"`.
    Update,
}

// `Display` supplies `to_string()` via the standard blanket impl and also supports `format!`.
impl std::fmt::Display for ProcessType {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(match *self {
            ProcessType::Default => "default",
            ProcessType::Update => "update",
        })
    }
}

impl Default for ProcessType {
    fn default() -> Self {
        ProcessType::Default
    }
}
/// Value of XGBoost's `grow_policy` parameter.
///
/// `to_string()` yields the exact string sent to XGBoost.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum GrowPolicy {
    /// `"depthwise"`; this is the [`Default`].
    Depthwise,
    /// `"lossguide"`.
    LossGuide,
}

// `Display` supplies `to_string()` via the standard blanket impl and also supports `format!`.
impl std::fmt::Display for GrowPolicy {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(match *self {
            GrowPolicy::Depthwise => "depthwise",
            GrowPolicy::LossGuide => "lossguide",
        })
    }
}

impl Default for GrowPolicy {
    fn default() -> Self {
        GrowPolicy::Depthwise
    }
}
/// Value of XGBoost's `predictor` parameter.
///
/// `to_string()` yields the exact string sent to XGBoost.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Predictor {
    /// `"cpu_predictor"`; this is the [`Default`].
    Cpu,
    /// `"gpu_predictor"`.
    Gpu,
}

// `Display` supplies `to_string()` via the standard blanket impl and also supports `format!`.
impl std::fmt::Display for Predictor {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(match *self {
            Predictor::Cpu => "cpu_predictor",
            Predictor::Gpu => "gpu_predictor",
        })
    }
}

impl Default for Predictor {
    fn default() -> Self {
        Predictor::Cpu
    }
}
/// Parameters for XGBoost's tree booster (serialized with `booster = "gbtree"`).
///
/// Create an instance with [`TreeBoosterParameters::default()`] or through the generated
/// `TreeBoosterParametersBuilder`. The builder's `build()` runs
/// `TreeBoosterParametersBuilder::validate` before returning. Every field is passed to XGBoost
/// under a parameter name spelled exactly like the field; see the XGBoost "Parameters for Tree
/// Booster" documentation for what each one controls.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
// Any field left unset on the builder is taken from `TreeBoosterParameters::default()`. No
// field-level `#[builder(default = ...)]` is used, so that `Default` impl stays the single place
// where default values are defined.
#[builder(default)]
pub struct TreeBoosterParameters {
    /// `eta`; the builder rejects values outside `[0, 1]`. Default `0.3`.
    eta: f32,
    /// `gamma`. Default `0.0`.
    gamma: f32,
    /// `max_depth`. Default `6`.
    max_depth: u32,
    /// `min_child_weight`. Default `1.0`.
    min_child_weight: f32,
    /// `max_delta_step`. Default `0.0`.
    max_delta_step: f32,
    /// `subsample`; the builder rejects values outside `(0, 1]`. Default `1.0`.
    subsample: f32,
    /// `colsample_bytree`; the builder rejects values outside `(0, 1]`. Default `1.0`.
    colsample_bytree: f32,
    /// `colsample_bylevel`; the builder rejects values outside `(0, 1]`. Default `1.0`.
    colsample_bylevel: f32,
    /// `colsample_bynode`; the builder rejects values outside `(0, 1]`. Default `1.0`.
    colsample_bynode: f32,
    /// `lambda`. Default `1.0`.
    lambda: f32,
    /// `alpha`. Default `0.0`.
    alpha: f32,
    /// `tree_method`. Default [`TreeMethod::Auto`].
    tree_method: TreeMethod,
    /// `sketch_eps`; the builder rejects values outside `(0, 1)`. Default `0.03`.
    sketch_eps: f32,
    /// `scale_pos_weight`. Default `1.0`.
    scale_pos_weight: f32,
    /// `updater`, sent as a comma-separated list and omitted entirely when empty. Default empty.
    updater: Vec<TreeUpdater>,
    /// `refresh_leaf`, sent as `1` or `0`. Default `true`.
    refresh_leaf: bool,
    /// `process_type`. Default [`ProcessType::Default`].
    process_type: ProcessType,
    /// `grow_policy`. Default [`GrowPolicy::Depthwise`].
    grow_policy: GrowPolicy,
    /// `max_leaves`. Default `0`.
    max_leaves: u32,
    /// `max_bin`. Default `256`.
    max_bin: u32,
    /// `num_parallel_tree`. Default `1`.
    num_parallel_tree: u32,
    /// `predictor`. Default [`Predictor::Cpu`].
    predictor: Predictor,
}
impl Default for TreeBoosterParameters {
fn default() -> Self {
TreeBoosterParameters {
eta: 0.3,
gamma: 0.0,
max_depth: 6,
min_child_weight: 1.0,
max_delta_step: 0.0,
subsample: 1.0,
colsample_bytree: 1.0,
colsample_bylevel: 1.0,
colsample_bynode: 1.0,
lambda: 1.0,
alpha: 0.0,
tree_method: TreeMethod::default(),
sketch_eps: 0.03,
scale_pos_weight: 1.0,
updater: Vec::new(),
refresh_leaf: true,
process_type: ProcessType::default(),
grow_policy: GrowPolicy::default(),
max_leaves: 0,
max_bin: 256,
num_parallel_tree: 1,
predictor: Predictor::default(),
}
}
}
impl TreeBoosterParameters {
    /// Renders every setting as an XGBoost `(parameter name, value)` string pair.
    ///
    /// The list always starts with `("booster", "gbtree")`, followed by one pair per field under
    /// the field's own name. `refresh_leaf` is written as `"1"`/`"0"`. `updater` is appended last,
    /// as a comma-separated list, and only when at least one updater was configured.
    pub(crate) fn as_string_pairs(&self) -> Vec<(String, String)> {
        // Pairs a parameter name with its already-stringified value.
        fn kv(name: &str, value: String) -> (String, String) {
            (name.to_owned(), value)
        }

        let mut pairs = vec![
            kv("booster", "gbtree".to_owned()),
            kv("eta", self.eta.to_string()),
            kv("gamma", self.gamma.to_string()),
            kv("max_depth", self.max_depth.to_string()),
            kv("min_child_weight", self.min_child_weight.to_string()),
            kv("max_delta_step", self.max_delta_step.to_string()),
            kv("subsample", self.subsample.to_string()),
            kv("colsample_bytree", self.colsample_bytree.to_string()),
            kv("colsample_bylevel", self.colsample_bylevel.to_string()),
            kv("colsample_bynode", self.colsample_bynode.to_string()),
            kv("lambda", self.lambda.to_string()),
            kv("alpha", self.alpha.to_string()),
            kv("tree_method", self.tree_method.to_string()),
            kv("sketch_eps", self.sketch_eps.to_string()),
            kv("scale_pos_weight", self.scale_pos_weight.to_string()),
            // The boolean flag is encoded numerically: true -> "1", false -> "0".
            kv("refresh_leaf", u8::from(self.refresh_leaf).to_string()),
            kv("process_type", self.process_type.to_string()),
            kv("grow_policy", self.grow_policy.to_string()),
            kv("max_leaves", self.max_leaves.to_string()),
            kv("max_bin", self.max_bin.to_string()),
            kv("num_parallel_tree", self.num_parallel_tree.to_string()),
            kv("predictor", self.predictor.to_string()),
        ];

        // With no updaters configured, the `updater` key is left out entirely.
        if self.updater.is_empty() {
            return pairs;
        }
        let names: Vec<String> = self.updater.iter().map(ToString::to_string).collect();
        pairs.push(kv("updater", names.join(",")));
        pairs
    }
}
impl TreeBoosterParametersBuilder {
    /// Range checks run by the generated `build()` before a `TreeBoosterParameters` is created.
    ///
    /// Fields that were never set on the builder pass. A plain `default().build()` succeeds, and
    /// such fields later take the struct's `Default` values.
    /// - Bounded parameters are checked with `Interval`.
    /// - Parameters that only need to be non-negative are checked directly; NaN is rejected too.
    ///
    /// # Errors
    ///
    /// Returns a message describing the first parameter found out of range.
    fn validate(&self) -> Result<(), String> {
        Interval::new_closed_closed(0.0, 1.0).validate(&self.eta, "eta")?;
        Interval::new_open_closed(0.0, 1.0).validate(&self.subsample, "subsample")?;
        Interval::new_open_closed(0.0, 1.0).validate(&self.colsample_bytree, "colsample_bytree")?;
        Interval::new_open_closed(0.0, 1.0)
            .validate(&self.colsample_bylevel, "colsample_bylevel")?;
        Interval::new_open_closed(0.0, 1.0).validate(&self.colsample_bynode, "colsample_bynode")?;
        Interval::new_open_open(0.0, 1.0).validate(&self.sketch_eps, "sketch_eps")?;

        // XGBoost documents these parameters as ranging over [0, inf). Rejecting negatives (and
        // NaN) here surfaces the mistake at build time instead of passing it on to XGBoost.
        let non_negative = [
            (self.gamma, "gamma"),
            (self.min_child_weight, "min_child_weight"),
            (self.max_delta_step, "max_delta_step"),
            (self.lambda, "lambda"),
            (self.alpha, "alpha"),
        ];
        for &(value, name) in non_negative.iter() {
            if let Some(v) = value {
                if v.is_nan() || v < 0.0 {
                    return Err(format!("{} must be a non-negative number, got {}", name, v));
                }
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tree_params() {
        let error_margin = f32::EPSILON;
        let p = TreeBoosterParameters::default();
        assert!((p.eta - 0.3).abs() < error_margin);
        // Unset builder fields must fall back to the struct's `Default`.
        let p = TreeBoosterParametersBuilder::default().build().unwrap();
        assert!((p.eta - 0.3).abs() < error_margin);
    }

    #[test]
    fn builder_enforces_interval_bounds() {
        // eta is closed at both ends.
        assert!(TreeBoosterParametersBuilder::default().eta(1.0).build().is_ok());
        assert!(TreeBoosterParametersBuilder::default().eta(1.5).build().is_err());
        // subsample is open at 0.
        assert!(TreeBoosterParametersBuilder::default().subsample(0.0).build().is_err());
        // sketch_eps is open at both ends.
        assert!(TreeBoosterParametersBuilder::default().sketch_eps(1.0).build().is_err());
    }

    #[test]
    fn string_pairs_include_booster_flag_and_updaters() {
        let p = TreeBoosterParametersBuilder::default()
            .updater(vec![TreeUpdater::GrowColMaker, TreeUpdater::Prune])
            .build()
            .unwrap();
        let pairs = p.as_string_pairs();
        assert_eq!(pairs[0], ("booster".to_owned(), "gbtree".to_owned()));
        assert!(pairs.contains(&("refresh_leaf".to_owned(), "1".to_owned())));
        assert!(pairs.contains(&("updater".to_owned(), "grow_colmaker,prune".to_owned())));
    }

    #[test]
    fn default_string_pairs_omit_empty_updater() {
        let pairs = TreeBoosterParameters::default().as_string_pairs();
        assert!(pairs.iter().all(|(key, _)| key != "updater"));
    }

    #[test]
    fn tree_method_from_known_names() {
        assert_eq!(TreeMethod::from("gpu_hist").to_string(), "gpu_hist");
        assert_eq!(TreeMethod::from("hist".to_owned()).to_string(), "hist");
    }

    #[test]
    #[should_panic(expected = "no known tree_method for bogus")]
    fn tree_method_from_unknown_name_panics() {
        let _ = TreeMethod::from("bogus");
    }
}