use rayon::prelude::*;
use crate::{Sample, WeakLearner};
use super::bin::*;
use crate::weak_learner::common::{
type_and_struct::*,
split_rule::*,
};
use super::{
node::*,
criterion::*,
train_node::*,
decision_tree_classifier::DecisionTreeClassifier,
};
use std::fmt;
use std::rc::Rc;
use std::collections::HashMap;
/// Weak learner that grows a depth-bounded decision tree.
///
/// `'a` is the lifetime of the borrowed feature names used as keys of `bins`.
pub struct DecisionTree<'a> {
/// Bins of each feature, keyed by feature name. The whole map is handed to
/// the split criterion whenever a node looks for its best split.
bins: HashMap<&'a str, Bins>,
/// Criterion that selects the `(feature, threshold)` pair at each node.
criterion: Criterion,
/// Depth budget given to the root; every branch level consumes one unit
/// and a node with no budget left becomes a leaf.
max_depth: Depth,
}
impl<'a> DecisionTree<'a> {
    /// Assembles a learner from parts that were prepared elsewhere.
    ///
    /// * `bins` - per-feature bins, keyed by feature name.
    /// * `criterion` - rule used to choose the split at each node.
    /// * `max_depth` - depth budget for the trees this learner grows.
    #[inline]
    pub(super) fn from_components(
        bins: HashMap<&'a str, Bins>,
        criterion: Criterion,
        max_depth: Depth,
    ) -> Self {
        Self {
            bins,
            criterion,
            max_depth,
        }
    }

    /// Recursively grows the subtree for the examples listed in `indices`.
    ///
    /// `dist` is indexed by example index and supplies the weight of each
    /// example. The node becomes a leaf when any of the following holds:
    ///
    /// * its loss is exactly zero,
    /// * the remaining `depth` budget is below one,
    /// * the chosen split sends every example to the same side.
    ///
    /// Otherwise the examples are divided by the split rule and both halves
    /// are grown with the budget reduced by one.
    #[inline]
    fn full_tree(
        &self,
        sample: &'a Sample,
        dist: &[f64],
        indices: Vec<usize>,
        criterion: Criterion,
        depth: Depth,
    ) -> TrainNodePtr {
        // Total weight of the examples that reached this node.
        let total_weight: f64 = indices.par_iter().map(|&i| dist[i]).sum();
        let (conf, loss) = confidence_and_loss(sample, dist, &indices);

        // Nothing left to improve at this node.
        if loss == 0.0 {
            return TrainNode::leaf(conf, total_weight, loss);
        }
        // Depth budget exhausted.
        if depth < 1 {
            return TrainNode::leaf(conf, total_weight, loss);
        }

        let (feature, threshold) =
            criterion.best_split(&self.bins, sample, dist, indices.as_slice());
        let rule = Splitter::new(feature, Threshold::from(threshold));

        // Route every example to the side selected by the rule; the relative
        // order of the indices is kept on both sides.
        let (lindices, rindices): (Vec<usize>, Vec<usize>) = indices
            .into_iter()
            .partition(|&i| matches!(rule.split(sample, i), LR::Left));

        // A split that leaves one side empty separates nothing.
        let both_sides_populated = !lindices.is_empty() && !rindices.is_empty();
        if !both_sides_populated {
            return TrainNode::leaf(conf, total_weight, loss);
        }

        let remaining = depth - 1;
        let left = self.full_tree(sample, dist, lindices, criterion, remaining);
        let right = self.full_tree(sample, dist, rindices, criterion, remaining);

        TrainNode::branch(rule, left, right, conf, total_weight, loss)
    }
}
impl<'a> WeakLearner for DecisionTree<'a> {
    type Hypothesis = DecisionTreeClassifier;

    /// Display name of this weak learner.
    fn name(&self) -> &str {
        "Decision Tree"
    }

    /// Lists the learner's settings as `(label, value)` pairs: the largest
    /// bin count over all features (0 when there are no features), the
    /// maximum depth, and the split criterion.
    fn info(&self) -> Option<Vec<(&str, String)>> {
        let largest_bin_count = self
            .bins
            .values()
            .map(|bin| bin.len())
            .max()
            .unwrap_or(0);

        Some(vec![
            ("# of bins (max)", largest_bin_count.to_string()),
            ("Max depth", self.max_depth.to_string()),
            ("Split criterion", self.criterion.to_string()),
        ])
    }

    /// Trains a tree on `sample` weighted by `dist` and returns it as a
    /// classifier.
    ///
    /// Only examples with strictly positive weight take part in training.
    ///
    /// # Panics
    ///
    /// Panics when no example has positive weight, or when the root of the
    /// trained tree is still shared and cannot be unwrapped.
    #[inline]
    fn produce(&self, sample: &Sample, dist: &[f64]) -> Self::Hypothesis {
        let n_sample = sample.shape().0;

        // Zero-weight examples cannot influence the tree, so drop them here.
        let indices: Vec<usize> = (0..n_sample)
            .filter(|&i| dist[i] > 0.0)
            .collect();
        assert_ne!(indices.len(), 0);

        let tree = self.full_tree(
            sample,
            dist,
            indices,
            self.criterion,
            self.max_depth,
        );
        tree.borrow_mut().remove_redundant_nodes();

        // `Rc::try_unwrap` succeeds only when `tree` is the sole strong
        // reference to the root.
        let root = Rc::try_unwrap(tree)
            .expect("Root node has reference counter >= 1")
            .into_inner();

        DecisionTreeClassifier::from(Node::from(root))
    }
}
/// Computes the prediction confidence and the loss of a node that holds the
/// examples listed in `indices`.
///
/// The weight `dist[i]` of every listed example is accumulated per label
/// (the target value cast to `i64`). With `total` the sum of all accumulated
/// weights and `p` the weight of the heaviest label:
///
/// * loss is `total * (1 - p / total)`, i.e. the weight that does not belong
///   to the heaviest label;
/// * confidence is `label * (2 * p / total - 1)`, clamped to `[-1, 1]`.
///
/// When `total` is not positive the loss is `0` and the confidence is the
/// label itself, clamped to `[-1, 1]`.
///
/// Labels with exactly equal weight are resolved in favour of the larger
/// label, so the result does not depend on hash-map iteration order.
///
/// # Panics
///
/// Panics if `indices` is empty, or if two label weights cannot be ordered
/// because one of them is NaN.
#[inline]
fn confidence_and_loss(sample: &Sample, dist: &[f64], indices: &[usize])
    -> (Confidence<f64>, LossValue)
{
    assert_ne!(indices.len(), 0);
    let target = sample.target();

    // Accumulate the weight carried by each label.
    let mut counter: HashMap<i64, f64> = HashMap::new();
    for &i in indices {
        let l = target[i] as i64;
        *counter.entry(l).or_insert(0.0) += dist[i];
    }
    let total = counter.values().sum::<f64>();

    // The map holds one entry per distinct label, so a sequential scan is
    // cheaper than a parallel reduction and lets ties be broken explicitly.
    let (label, p) = counter
        .into_iter()
        .max_by(|a, b| {
            a.1.partial_cmp(&b.1)
                .expect("label weights must not be NaN")
                .then_with(|| a.0.cmp(&b.0))
        })
        .expect("`indices` is non-empty, so at least one label was counted");

    let (confidence, loss) = if total > 0.0 {
        let ratio = p / total;
        (
            (label as f64 * (2.0 * ratio - 1.0)).clamp(-1.0, 1.0),
            total * (1.0 - ratio),
        )
    } else {
        ((label as f64).clamp(-1.0, 1.0), 0.0)
    };

    (Confidence::from(confidence), LossValue::from(loss))
}
impl fmt::Display for DecisionTree<'_> {
    /// Writes a summary of the learner: maximum depth, splitting criterion,
    /// and one line per feature with its name, bin count and bins.
    ///
    /// Feature names are left-aligned to the longest name and bin counts are
    /// right-aligned to the widest count. A learner without features prints
    /// only the header and the closing rule, and a feature with zero bins is
    /// printed with a count of `0`; neither case panics.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(
            f,
            "\
            ----------\n\
            # Decision Tree Weak Learner\n\n\
            - Max depth: {}\n\
            - Splitting criterion: {}\n\
            - Bins:\
            ",
            self.max_depth,
            self.criterion,
        )?;

        // Column width for feature names; irrelevant when there are none.
        let width = self.bins.keys()
            .map(|key| key.len())
            .max()
            .unwrap_or(0);

        // Number of decimal digits of the largest bin count. `ilog10` is
        // undefined for 0, and "0" is one character wide.
        let max_bin_width = self.bins.values()
            .map(|bin| bin.len().checked_ilog10().map_or(1, |d| d as usize + 1))
            .max()
            .unwrap_or(1);

        for (feat_name, feat_bins) in self.bins.iter() {
            let n_bins = feat_bins.len();
            writeln!(
                f,
                "\
                \t* [{feat_name: <width$} | \
                {n_bins: >max_bin_width$} bins] \
                {feat_bins}\
                "
            )?;
        }

        write!(f, "----------")
    }
}