#[derive(Clone, Debug, Default)]
pub enum MaxFeatures {
#[default]
None,
Fraction(f64),
Value(usize),
Sqrt,
Callable(fn(usize) -> usize),
}
impl MaxFeatures {
pub fn from_n_features(&self, n_features: usize) -> usize {
let value = match self {
MaxFeatures::None => n_features,
MaxFeatures::Fraction(fraction) => (fraction * n_features as f64) as usize,
MaxFeatures::Value(number) => *number,
MaxFeatures::Sqrt => (n_features as f64).sqrt() as usize,
MaxFeatures::Callable(callable) => callable(n_features),
};
value.max(1).min(n_features)
}
}
#[derive(Clone, Debug)]
pub struct DecisionTreeParameters {
pub max_depth: Option<usize>,
pub max_features: MaxFeatures,
pub min_samples_split: usize,
pub min_samples_leaf: usize,
pub random_state: u64,
}
impl Default for DecisionTreeParameters {
fn default() -> Self {
DecisionTreeParameters {
max_depth: None,
max_features: MaxFeatures::default(),
min_samples_split: 2,
min_samples_leaf: 1,
random_state: 0,
}
}
}
impl DecisionTreeParameters {
pub fn new(
max_depth: Option<usize>,
max_features: MaxFeatures,
min_samples_split: usize,
min_samples_leaf: usize,
random_state: u64,
) -> Self {
DecisionTreeParameters {
max_depth,
max_features,
min_samples_split,
min_samples_leaf,
random_state,
}
}
pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
self.max_depth = max_depth;
self
}
pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
self.max_features = max_features;
self
}
pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
self.min_samples_split = min_samples_split;
self
}
pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
self.min_samples_leaf = min_samples_leaf;
self
}
pub fn with_random_state(mut self, random_state: u64) -> Self {
self.random_state = random_state;
self
}
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::*;
#[rstest]
#[case(MaxFeatures::None, 10, 10)]
#[case(MaxFeatures::Fraction(1.), 10, 10)]
#[case(MaxFeatures::Fraction(0.8), 10, 8)]
#[case(MaxFeatures::Fraction(0.01), 10, 1)]
#[case(MaxFeatures::Fraction(2.), 10, 10)]
#[case(MaxFeatures::Value(5), 10, 5)]
#[case(MaxFeatures::Sqrt, 10, 3)]
#[case(MaxFeatures::Callable(|x| x % 4), 10, 2)]
fn test_MaxFeatures(
#[case] max_features: MaxFeatures,
#[case] n_features: usize,
#[case] expected: usize,
) {
assert_eq!(max_features.from_n_features(n_features), expected);
}
}