use crate::{Sample, DecisionTree};
use crate::weak_learner::common::type_and_struct::*;
use super::bin::*;
use super::criterion::*;
use std::collections::HashMap;
/// Upper bound on the per-feature bin count that `DecisionTreeBuilder::new`
/// assigns by default.
pub const DEFAULT_NBIN: usize = 255;
/// Maximum tree depth that `DecisionTreeBuilder::new` assigns by default.
pub const DEFAULT_MAX_DEPTH: usize = 2;
/// Builder that collects the settings needed to construct a `DecisionTree`
/// from a borrowed `Sample`.
#[derive(Clone)]
pub struct DecisionTreeBuilder<'a> {
/// Sample whose features are binned when the tree is built.
sample: &'a Sample,
/// Number of bins to use for each feature, keyed by feature name.
n_bins: HashMap<&'a str, usize>,
/// Maximum depth handed to the resulting tree.
max_depth: Depth,
/// Criterion handed to the resulting tree.
criterion: Criterion,
}
impl<'a> DecisionTreeBuilder<'a> {
    /// Creates a builder for `sample` with default settings.
    ///
    /// For every feature the bin count is the number of distinct values of
    /// that feature, capped at [`DEFAULT_NBIN`]. The maximum depth is
    /// [`DEFAULT_MAX_DEPTH`] and the criterion is `Criterion::Entropy`.
    pub fn new(sample: &'a Sample) -> Self {
        let n_bins = sample.features()
            .iter()
            .map(|feat| {
                let n_bin = feat.distinct_value_count().min(DEFAULT_NBIN);
                (feat.name(), n_bin)
            })
            .collect();

        Self {
            sample,
            n_bins,
            max_depth: Depth::from(DEFAULT_MAX_DEPTH),
            criterion: Criterion::Entropy,
        }
    }

    /// Sets the maximum depth of the tree and returns the updated builder.
    ///
    /// # Panics
    ///
    /// Panics if `depth` is zero.
    #[must_use]
    pub fn max_depth(mut self, depth: usize) -> Self {
        assert!(depth > 0, "Tree must have positive depth");
        self.max_depth = Depth::from(depth);
        self
    }

    /// Sets the criterion and returns the updated builder.
    #[inline]
    #[must_use]
    pub fn criterion(mut self, criterion: Criterion) -> Self {
        self.criterion = criterion;
        self
    }

    /// Overrides the number of bins used for the feature called `name`.
    ///
    /// Unlike the other setters this one mutates the builder in place, so it
    /// cannot be chained.
    ///
    /// NOTE(review): `n_bins == 0` is not rejected here; confirm that
    /// `Bins::cut` handles a zero bin count.
    ///
    /// # Panics
    ///
    /// Panics if the sample has no feature named `name`.
    pub fn set_nbins<T>(&mut self, name: T, n_bins: usize)
        where T: AsRef<str>
    {
        let name = name.as_ref();
        match self.n_bins.get_mut(name) {
            Some(slot) => *slot = n_bins,
            None => panic!("The feature named `{name}` does not exist"),
        }
    }

    /// Consumes the builder and constructs the `DecisionTree`.
    ///
    /// Every feature of the sample is cut into its configured number of bins,
    /// and the bins are passed to `DecisionTree::from_components` together
    /// with the criterion and the maximum depth.
    pub fn build(self) -> DecisionTree<'a> {
        let bins = self.sample.features()
            .iter()
            .map(|feature| {
                let name = feature.name();
                // `new` inserts one entry per feature name and `set_nbins`
                // only updates existing entries, so this lookup cannot miss.
                let n_bins = *self.n_bins
                    .get(name)
                    .expect("every feature name is registered in `n_bins` by `new`");
                (name, Bins::cut(feature, n_bins))
            })
            .collect::<HashMap<_, _>>();

        DecisionTree::from_components(bins, self.criterion, self.max_depth)
    }
}