use std::default::Default;
use super::Interval;
/// Value of the DART booster's `sample_type` parameter.
///
/// Rendered as `"uniform"` / `"weighted"` when emitted. The precise meaning
/// (presumably how dropped trees are selected) is defined by the underlying
/// booster library — see the XGBoost DART documentation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SampleType {
    /// Rendered as `"uniform"`.
    Uniform,
    /// Rendered as `"weighted"`.
    Weighted,
}
/// Formats the variant as the lowercase name used for the `sample_type`
/// parameter.
///
/// Implementing `Display` (instead of `ToString` directly) keeps
/// `to_string()` available through the standard blanket impl while also
/// allowing the value to be used with `format!` / `write!`.
impl std::fmt::Display for SampleType {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(match *self {
            SampleType::Uniform => "uniform",
            SampleType::Weighted => "weighted",
        })
    }
}
impl Default for SampleType {
    /// Returns `SampleType::Uniform`.
    fn default() -> SampleType {
        SampleType::Uniform
    }
}
/// Value of the DART booster's `normalize_type` parameter.
///
/// Rendered as `"tree"` / `"forest"` when emitted. The precise normalization
/// rule each variant selects is defined by the underlying booster library —
/// see the XGBoost DART documentation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NormalizeType {
    /// Rendered as `"tree"`.
    Tree,
    /// Rendered as `"forest"`.
    Forest,
}
/// Formats the variant as the lowercase name used for the `normalize_type`
/// parameter.
///
/// Implementing `Display` (instead of `ToString` directly) keeps
/// `to_string()` available through the standard blanket impl while also
/// allowing the value to be used with `format!` / `write!`.
impl std::fmt::Display for NormalizeType {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(match *self {
            NormalizeType::Tree => "tree",
            NormalizeType::Forest => "forest",
        })
    }
}
impl Default for NormalizeType {
    /// Returns `NormalizeType::Tree`.
    fn default() -> NormalizeType {
        NormalizeType::Tree
    }
}
/// Parameters for the DART booster; `as_string_pairs` emits them together
/// with `booster = "dart"`.
///
/// Construct through the derived `DartBoosterParametersBuilder`. With the
/// struct-level `#[builder(default)]`, unset fields are presumably taken from
/// `DartBoosterParameters::default()` (derive_builder semantics — confirm
/// against the crate version in use), and `build()` calls
/// `DartBoosterParametersBuilder::validate` to range-check the float fields.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
#[builder(default)]
pub struct DartBoosterParameters {
    /// Emitted as the `sample_type` parameter.
    sample_type: SampleType,
    /// Emitted as the `normalize_type` parameter.
    normalize_type: NormalizeType,
    /// Emitted as `rate_drop`; the builder rejects values outside `[0.0, 1.0]`.
    /// Presumably the dropout rate — see the XGBoost DART docs.
    rate_drop: f32,
    /// Emitted as `one_drop`, encoded as `"1"` (true) or `"0"` (false).
    one_drop: bool,
    /// Emitted as `skip_drop`; the builder rejects values outside `[0.0, 1.0]`.
    /// Presumably the probability of skipping dropout — see the XGBoost DART docs.
    skip_drop: f32,
}
impl Default for DartBoosterParameters {
fn default() -> Self {
DartBoosterParameters {
sample_type: SampleType::default(),
normalize_type: NormalizeType::default(),
rate_drop: 0.0,
one_drop: false,
skip_drop: 0.0,
}
}
}
impl DartBoosterParameters {
    /// Flattens these settings into `(name, value)` string pairs.
    ///
    /// The first pair always selects the booster (`"booster"` = `"dart"`); the
    /// remaining pairs follow in field order. `one_drop` is encoded as `"1"`
    /// or `"0"`, and the float fields use `f32`'s `Display` formatting
    /// (e.g. `0.0` becomes `"0"`).
    pub(crate) fn as_string_pairs(&self) -> Vec<(String, String)> {
        // A `vec![]` literal replaces `Vec::new()` plus repeated `push`
        // (clippy `vec_init_then_push`) and allocates exactly once.
        vec![
            ("booster".to_owned(), "dart".to_owned()),
            ("sample_type".to_owned(), self.sample_type.to_string()),
            ("normalize_type".to_owned(), self.normalize_type.to_string()),
            ("rate_drop".to_owned(), self.rate_drop.to_string()),
            ("one_drop".to_owned(), u8::from(self.one_drop).to_string()),
            ("skip_drop".to_owned(), self.skip_drop.to_string()),
        ]
    }
}
impl DartBoosterParametersBuilder {
    /// Range-checks the builder's float fields. It is wired in as the
    /// `build()` validator by `#[builder(build_fn(validate = "Self::validate"))]`
    /// on `DartBoosterParameters`.
    ///
    /// # Errors
    ///
    /// Propagates (via `?`) the error from `Interval::validate` for
    /// `rate_drop`, then for `skip_drop`. The first failure short-circuits.
    /// Presumably an error means the value lies outside the closed interval
    /// `[0.0, 1.0]`.
    ///
    /// NOTE(review): derive_builder typically stores builder fields as
    /// `Option<_>`. Whether an unset field passes therefore depends on
    /// `Interval::validate` — confirm.
    fn validate(&self) -> Result<(), String> {
        Interval::new_closed_closed(0.0, 1.0).validate(&self.rate_drop, "rate_drop")?;
        Interval::new_closed_closed(0.0, 1.0).validate(&self.skip_drop, "skip_drop")?;
        Ok(())
    }
}