use rayon::prelude::*;
use crate::{Sample, WeakLearner};
use crate::weak_learner::common::{
type_and_struct::*,
split_rule::*,
};
use super::{
node::*,
criterion::*,
train_node::*,
dtree_classifier::DTreeClassifier,
};
use std::rc::Rc;
use std::cell::RefCell;
use std::collections::HashMap;
/// Decision-tree weak learner.
///
/// Splits are chosen with `criterion`, and the tree is grown to at most
/// `max_depth` levels of branching.
pub struct DTree {
    /// Split criterion passed to the tree builder.
    criterion: Criterion,
    /// Maximum depth of the grown tree; kept at least 1.
    max_depth: Depth,
}

impl DTree {
    /// Creates a learner configured for `sample`.
    ///
    /// Defaults: the entropy criterion, and a maximum depth of
    /// `ceil(log2(n_sample) + 1)`, clamped to at least 1.
    #[inline]
    pub fn init(sample: &Sample) -> Self {
        let criterion = Criterion::Entropy;
        let n_sample = sample.shape().0;
        let depth = ((n_sample as f64).log2() + 1.0).ceil() as usize;
        // With an empty sample `log2(0.0)` is `-inf`, and the cast saturates
        // to 0. Clamp so that `init` upholds the same invariant that
        // `max_depth` asserts (depth > 0).
        let depth = depth.max(1);
        Self {
            criterion,
            max_depth: Depth::from(depth),
        }
    }

    /// Sets the maximum depth of the tree (builder style).
    ///
    /// # Panics
    /// Panics if `depth == 0`.
    pub fn max_depth(mut self, depth: usize) -> Self {
        assert!(depth > 0, "max_depth must be at least 1");
        self.max_depth = Depth::from(depth);
        self
    }

    /// Sets the split criterion (builder style).
    #[inline]
    pub fn criterion(mut self, criterion: Criterion) -> Self {
        self.criterion = criterion;
        self
    }
}
impl WeakLearner for DTree {
    type Hypothesis = DTreeClassifier;

    /// Grows a tree on the examples of `sample` that have positive weight
    /// under `dist`. It then removes redundant nodes and returns the result
    /// as a classifier.
    ///
    /// # Panics
    /// Panics in either of these cases:
    /// - `dist` has fewer entries than `sample` has examples.
    /// - The root node is still shared (strong count > 1) after pruning.
    #[inline]
    fn produce(&self, sample: &Sample, dist: &[f64])
        -> Self::Hypothesis
    {
        let n_sample = sample.shape().0;

        // Only examples with strictly positive weight are handed to the builder.
        let indices = (0..n_sample)
            .filter(|&i| dist[i] > 0.0)
            .collect::<Vec<usize>>();

        let tree = full_tree(
            sample, dist, indices, self.criterion, self.max_depth
        );

        tree.borrow_mut().remove_redundant_nodes();

        // `Rc::try_unwrap` succeeds only when `tree` is the last strong
        // reference to the root.
        let root = Node::from(
            Rc::try_unwrap(tree)
                .expect("Root node has reference counter > 1")
                .into_inner()
        );

        DTreeClassifier::from(root)
    }
}
/// Recursively grows a tree over the examples listed in `indices`.
///
/// A node becomes a leaf in two cases:
/// - its weighted loss is already zero;
/// - the split chosen by `criterion` would send every example to the same side.
///
/// Otherwise the node becomes a branch:
/// - at `depth <= 1`, both children are forced to be leaves;
/// - otherwise the recursion continues with `depth - 1`.
#[inline]
fn full_tree(
    sample: &Sample,
    dist: &[f64],
    indices: Vec<usize>,
    criterion: Criterion,
    depth: Depth,
) -> Rc<RefCell<TrainNode>>
{
    let total_weight = indices.par_iter()
        .map(|&i| dist[i])
        .sum::<f64>();
    let (conf, loss) = confidence_and_loss(sample, dist, &indices);

    // No misclassified weight left: keep this node as a leaf.
    if loss == 0.0 {
        return TrainNode::leaf(conf, total_weight, loss);
    }

    let (feature, threshold) = criterion.best_split(
        sample, dist, &indices[..]
    );
    let rule = Splitter::new(feature, Threshold::from(threshold));

    let (lindices, rindices): (Vec<usize>, Vec<usize>) = indices
        .into_iter()
        .partition(|&i| matches!(rule.split(sample, i), LR::Left));

    // A split that leaves one side empty separates nothing.
    if lindices.is_empty() || rindices.is_empty() {
        return TrainNode::leaf(conf, total_weight, loss);
    }

    let (ltree, rtree) = if depth <= 1 {
        (
            construct_leaf(sample, dist, lindices),
            construct_leaf(sample, dist, rindices),
        )
    } else {
        let next = depth - 1;
        (
            full_tree(sample, dist, lindices, criterion, next),
            full_tree(sample, dist, rindices, criterion, next),
        )
    };

    TrainNode::branch(rule, ltree, rtree, conf, total_weight, loss)
}
/// Builds a leaf over the examples in `indices`.
///
/// The leaf stores the node's confidence and loss, and the summed weight of
/// those examples under `dist`.
#[inline]
fn construct_leaf(
    sample: &Sample,
    dist: &[f64],
    indices: Vec<usize>
) -> Rc<RefCell<TrainNode>>
{
    let (confidence, loss) = confidence_and_loss(sample, dist, &indices);
    let weight: f64 = indices.iter().map(|&i| dist[i]).sum();
    TrainNode::leaf(confidence, weight, loss)
}
/// Computes the prediction and the weighted loss of a node that contains the
/// examples in `indices`.
///
/// The node predicts the label (`target` cast to `i64`) with the largest
/// total weight under `dist`. Let `p` be that weight and `total` the weight
/// of all examples in the node.
///
/// Returns:
/// - confidence: `label * (2 * p / total - 1)`, or just `label` when `total`
///   is not positive;
/// - loss: `total * (1 - p / total)`, or 0 when `total` is not positive.
///
/// Ties between equally weighted labels go to the larger label, so the
/// result does not depend on `HashMap` iteration order. An empty `indices`
/// has no label to predict and yields a zero confidence with zero loss.
#[inline]
fn confidence_and_loss(sample: &Sample, dist: &[f64], indices: &[usize])
    -> (Confidence<f64>, LossValue)
{
    let target = sample.target();

    // Accumulate the weight carried by each label present in this node.
    let mut counter: HashMap<i64, f64> = HashMap::new();
    for &i in indices {
        let l = target[i] as i64;
        *counter.entry(l).or_insert(0.0) += dist[i];
    }

    // Pick the heaviest label with a sequential scan. The map holds only a
    // few labels, so a parallel reduction is pure overhead. The explicit
    // tie-break keeps the choice deterministic.
    let mut best: Option<(i64, f64)> = None;
    for (&label, &weight) in counter.iter() {
        let is_better = match best {
            None => true,
            Some((best_label, best_weight)) => {
                weight > best_weight
                    || (weight == best_weight && label > best_label)
            }
        };
        if is_better {
            best = Some((label, weight));
        }
    }

    let (label, p) = match best {
        Some(pair) => pair,
        // Empty node: nothing to predict, so abstain instead of panicking.
        None => return (Confidence::from(0.0), LossValue::from(0.0)),
    };

    let total = counter.values().sum::<f64>();

    let loss = if total > 0.0 {
        total * (1.0 - (p / total))
    } else {
        0.0
    };
    let confidence = if total > 0.0 {
        label as f64 * (2.0 * (p / total) - 1.0)
    } else {
        label as f64
    };

    (Confidence::from(confidence), LossValue::from(loss))
}