use linfa::{
error::{Error, Result},
Float, Label, ParamGuard,
};
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use crate::DecisionTree;
/// Identifies the criterion used to score candidate splits of a decision tree.
///
/// This file only declares the available choices and stores one of them in the
/// hyperparameter set; the code that evaluates each criterion lives elsewhere.
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SplitQuality {
/// Gini criterion (presumably Gini impurity; the formula is not in this file).
Gini,
/// Entropy criterion (presumably information gain; the formula is not in this file).
Entropy,
}
/// The validated set of decision tree hyperparameters.
///
/// Values of this type are handed out by the `ParamGuard` implementation of
/// `DecisionTreeParams` after its check succeeds. All fields are private and
/// are read through the accessor methods of the same name.
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DecisionTreeValidParams<F, L> {
// Criterion used to score candidate splits.
split_quality: SplitQuality,
// Optional depth limit; `None` is the default set by `DecisionTreeParams::new`.
// NOTE(review): presumably `None` means "unbounded" -- confirm in the fitting code.
max_depth: Option<usize>,
// NOTE(review): presumably the minimum total sample weight a node needs before
// it may be split -- confirm in the fitting code.
min_weight_split: f32,
// NOTE(review): presumably the minimum total sample weight each leaf must
// keep -- confirm in the fitting code.
min_weight_leaf: f32,
// Threshold that the parameter check compares against `F::epsilon()`.
// NOTE(review): presumably the smallest impurity decrease that justifies a
// split -- confirm in the fitting code.
min_impurity_decrease: F,
// Zero-sized marker that ties the label type `L` to the parameter set.
label_marker: PhantomData<L>,
}
/// Read-only accessors for the validated hyperparameters.
impl<F: Float, L> DecisionTreeValidParams<F, L> {
    /// Returns the configured split-quality criterion.
    pub fn split_quality(&self) -> SplitQuality {
        let Self { split_quality, .. } = self;
        *split_quality
    }

    /// Returns the configured depth limit, or `None` when no limit was set.
    pub fn max_depth(&self) -> Option<usize> {
        let Self { max_depth, .. } = self;
        *max_depth
    }

    /// Returns the stored `min_weight_split` value.
    pub fn min_weight_split(&self) -> f32 {
        let Self {
            min_weight_split, ..
        } = self;
        *min_weight_split
    }

    /// Returns the stored `min_weight_leaf` value.
    pub fn min_weight_leaf(&self) -> f32 {
        let Self {
            min_weight_leaf, ..
        } = self;
        *min_weight_leaf
    }

    /// Returns the stored `min_impurity_decrease` value.
    pub fn min_impurity_decrease(&self) -> F {
        let Self {
            min_impurity_decrease,
            ..
        } = self;
        *min_impurity_decrease
    }
}
/// Unchecked decision tree hyperparameters, configured builder-style.
///
/// This is a thin wrapper around `DecisionTreeValidParams`; the wrapped value
/// is only exposed through the `ParamGuard` implementation once the
/// parameter check has passed.
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DecisionTreeParams<F, L>(DecisionTreeValidParams<F, L>);
/// Constructor and builder-style setters for the unchecked parameter set.
impl<F: Float, L: Label> DecisionTreeParams<F, L> {
    /// Creates a parameter set populated with the default values:
    ///
    /// * `split_quality`: [`SplitQuality::Gini`]
    /// * `max_depth`: `None`
    /// * `min_weight_split`: `2.0`
    /// * `min_weight_leaf`: `1.0`
    /// * `min_impurity_decrease`: `0.00001`
    pub fn new() -> Self {
        let defaults = DecisionTreeValidParams {
            label_marker: PhantomData,
            min_impurity_decrease: F::cast(0.00001),
            min_weight_leaf: 1.0,
            min_weight_split: 2.0,
            max_depth: None,
            split_quality: SplitQuality::Gini,
        };
        Self(defaults)
    }

    /// Replaces the split-quality criterion and returns the updated builder.
    pub fn split_quality(self, split_quality: SplitQuality) -> Self {
        Self(DecisionTreeValidParams {
            split_quality,
            ..self.0
        })
    }

    /// Replaces the optional depth limit and returns the updated builder.
    pub fn max_depth(self, max_depth: Option<usize>) -> Self {
        Self(DecisionTreeValidParams {
            max_depth,
            ..self.0
        })
    }

    /// Replaces `min_weight_split` and returns the updated builder.
    pub fn min_weight_split(self, min_weight_split: f32) -> Self {
        Self(DecisionTreeValidParams {
            min_weight_split,
            ..self.0
        })
    }

    /// Replaces `min_weight_leaf` and returns the updated builder.
    pub fn min_weight_leaf(self, min_weight_leaf: f32) -> Self {
        Self(DecisionTreeValidParams {
            min_weight_leaf,
            ..self.0
        })
    }

    /// Replaces `min_impurity_decrease` and returns the updated builder.
    ///
    /// The value is not validated here; the `ParamGuard` check on this type
    /// rejects values smaller than `F::epsilon()`.
    pub fn min_impurity_decrease(self, min_impurity_decrease: F) -> Self {
        Self(DecisionTreeValidParams {
            min_impurity_decrease,
            ..self.0
        })
    }
}
impl<F: Float, L: Label> Default for DecisionTreeParams<F, L> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float, L: Label> DecisionTree<F, L> {
    /// Entry point for configuring a decision tree: returns a fresh
    /// `DecisionTreeParams` holding the default hyperparameters, ready to be
    /// customised through its builder methods.
    #[allow(clippy::new_ret_no_self)]
    pub fn params() -> DecisionTreeParams<F, L> {
        DecisionTreeParams::<F, L>::new()
    }
}
/// Validation of the unchecked hyperparameters.
///
/// A successful check exposes the wrapped `DecisionTreeValidParams`.
impl<F: Float, L> ParamGuard for DecisionTreeParams<F, L> {
    type Checked = DecisionTreeValidParams<F, L>;
    type Error = Error;

    /// Checks the parameters and, on success, borrows the validated set.
    ///
    /// # Errors
    ///
    /// Returns `Error::Parameters` when `min_impurity_decrease` is NaN or
    /// smaller than `F::epsilon()`.
    fn check_ref(&self) -> Result<&Self::Checked> {
        let threshold = self.0.min_impurity_decrease;
        // Every ordered comparison against NaN is false, so a plain `<` test
        // would silently accept a NaN threshold; reject it explicitly.
        if threshold.is_nan() || threshold < F::epsilon() {
            Err(Error::Parameters(format!(
                "Minimum impurity decrease should be greater than zero, but was {}",
                threshold
            )))
        } else {
            Ok(&self.0)
        }
    }

    /// Checks the parameters and, on success, returns the validated set by value.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as `check_ref`.
    fn check(self) -> Result<Self::Checked> {
        self.check_ref()?;
        Ok(self.0)
    }
}