use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
use ndarray::{Array1, ArrayBase, Axis, Data, Ix1, Ix2};
use super::hyperparameters::{DecisionTreeParams, SplitQuality};
use super::NodeIter;
use super::Tikz;
use linfa::{
dataset::{AsTargets, Labels, Records},
error::Result,
traits::*,
DatasetBase, Float, Label,
};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
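/// Boolean mask over the rows of the training data, tracking which samples
/// are still in play at a given tree node together with how many rows are
/// currently marked.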
struct RowMask {
mask: Vec<bool>,
nsamples: usize,
}
impl RowMask {
fn all(nsamples: usize) -> Self {
RowMask {
            mask: vec![true; nsamples],
nsamples,
}
}
fn none(nsamples: usize) -> Self {
RowMask {
            mask: vec![false; nsamples],
nsamples: 0,
}
}
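    /// Marks the row at `idx` as active. Each row must be marked at most
    /// once, otherwise `nsamples` over-counts.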
fn mark(&mut self, idx: usize) {
self.mask[idx] = true;
self.nsamples += 1;
}
}
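/// A single feature column presorted in ascending order, with each value
/// paired to its original row index. Presorting once per feature lets every
/// node enumerate candidate thresholds in a single linear pass.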
struct SortedIndex<'a, F: Float> {
feature_name: &'a str,
sorted_values: Vec<(usize, F)>,
}
impl<'a, F: Float> SortedIndex<'a, F> {
fn of_array_column(
x: &ArrayBase<impl Data<Elem = F>, Ix2>,
feature_idx: usize,
feature_name: &'a str,
) -> Self {
let sliced_column: Vec<F> = x.index_axis(Axis(1), feature_idx).to_vec();
let mut pairs: Vec<(usize, F)> = sliced_column.into_iter().enumerate().collect();
pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Greater));
SortedIndex {
sorted_values: pairs,
feature_name,
}
}
}
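/// A node of a fitted decision tree: either an internal node carrying a
/// split (feature index and name, threshold, impurity decrease) with up to
/// two children, or a leaf carrying a class prediction.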
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Debug, Clone)]
pub struct TreeNode<F, L> {
feature_idx: usize,
feature_name: String,
split_value: F,
impurity_decrease: F,
left_child: Option<Box<TreeNode<F, L>>>,
right_child: Option<Box<TreeNode<F, L>>>,
leaf_node: bool,
prediction: L,
depth: usize,
}
impl<F: Float, L: Label> Hash for TreeNode<F, L> {
fn hash<H: Hasher>(&self, state: &mut H) {
        let data: Vec<u64> = vec![self.feature_idx as u64, self.leaf_node as u64];
data.hash(state);
}
}
impl<F, L> Eq for TreeNode<F, L> {}
impl<F, L> PartialEq for TreeNode<F, L> {
fn eq(&self, other: &Self) -> bool {
self.feature_idx == other.feature_idx
}
}
impl<F: Float, L: Label + std::fmt::Debug> TreeNode<F, L> {
fn empty_leaf(prediction: L, depth: usize) -> Self {
TreeNode {
feature_idx: 0,
feature_name: "".to_string(),
split_value: F::zero(),
impurity_decrease: F::zero(),
left_child: None,
right_child: None,
leaf_node: true,
prediction,
depth,
}
}
pub fn is_leaf(&self) -> bool {
self.leaf_node
}
pub fn depth(&self) -> usize {
self.depth
}
pub fn prediction(&self) -> Option<L> {
if self.is_leaf() {
Some(self.prediction.clone())
} else {
None
}
}
pub fn children(&self) -> Vec<&Option<Box<TreeNode<F, L>>>> {
vec![&self.left_child, &self.right_child]
}
pub fn split(&self) -> (usize, F, F) {
(self.feature_idx, self.split_value, self.impurity_decrease)
}
pub fn feature_name(&self) -> Option<&String> {
if self.leaf_node {
None
} else {
Some(&self.feature_name)
}
}
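    /// Recursively grows the subtree covering the rows selected by `mask`.
    ///
    /// Every feature column is scanned for the threshold (midpoint between
    /// adjacent sorted values) that minimizes the weighted impurity of the
    /// resulting partition. If a stopping criterion fires (minimum weight to
    /// split, maximum depth, or minimum impurity decrease), the node becomes
    /// a leaf predicting the modal class of its rows.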
fn fit<D: Data<Elem = F>, T: AsTargets<Elem = L> + Labels<Elem = L>>(
data: &DatasetBase<ArrayBase<D, Ix2>, T>,
mask: &RowMask,
hyperparameters: &DecisionTreeParams<F, L>,
sorted_indices: &[SortedIndex<F>],
depth: usize,
) -> Result<Self> {
let parent_class_freq = data.label_frequencies_with_mask(&mask.mask);
let prediction = find_modal_class(&parent_class_freq);
let target = data.try_single_target()?;
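        // Stop early with a leaf if this node holds too little weight to be
        // split or the maximum depth has been reached.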
if (mask.nsamples as f32) < hyperparameters.min_weight_split
|| hyperparameters
.max_depth
.map(|max_depth| depth >= max_depth)
.unwrap_or(false)
{
return Ok(Self::empty_leaf(prediction, depth));
}
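        // Search all features for the (feature, threshold) pair with the
        // lowest weighted impurity score.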
let mut best = None;
for (feature_idx, sorted_index) in sorted_indices.iter().enumerate() {
let mut right_class_freq = parent_class_freq.clone();
let mut left_class_freq = HashMap::new();
let total_weight = parent_class_freq.values().sum::<f32>();
let mut weight_on_right_side = total_weight;
let mut weight_on_left_side = 0.0;
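            // Walk the presorted values, moving one sample at a time from
            // the right partition to the left and updating the class
            // frequencies incrementally.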
for i in 0..mask.mask.len() - 1 {
let (presorted_index, mut split_value) = sorted_index.sorted_values[i];
if !mask.mask[presorted_index] {
continue;
}
let sample_class = &target[presorted_index];
let sample_weight = data.weight_for(presorted_index);
*right_class_freq.get_mut(sample_class).unwrap() -= sample_weight;
weight_on_right_side -= sample_weight;
*left_class_freq.entry(sample_class.clone()).or_insert(0.0) += sample_weight;
weight_on_left_side += sample_weight;
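                // Skip thresholds between (nearly) equal feature values and
                // splits that leave too little weight on either side.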
if (sorted_index.sorted_values[i].1 - sorted_index.sorted_values[i + 1].1).abs()
< F::from(1e-5).unwrap()
{
continue;
}
if weight_on_right_side < hyperparameters.min_weight_leaf
|| weight_on_left_side < hyperparameters.min_weight_leaf
{
continue;
}
                let (left_score, right_score) = match hyperparameters.split_quality {
                    SplitQuality::Gini => (
                        gini_impurity(&left_class_freq),
                        gini_impurity(&right_class_freq),
                    ),
                    SplitQuality::Entropy => {
                        (entropy(&left_class_freq), entropy(&right_class_freq))
                    }
                };
                // Weight each side's score by its share of the total sample
                // weight; lower is better for both Gini and entropy.
                let left_weight = weight_on_left_side / total_weight;
                let right_weight = weight_on_right_side / total_weight;
                let score = left_weight * left_score + right_weight * right_score;
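                // Place the candidate threshold halfway between this value
                // and the next one in sorted order.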
split_value =
(split_value + sorted_index.sorted_values[i + 1].1) / F::from(2.0).unwrap();
best = match best.take() {
None => Some((feature_idx, split_value, score)),
Some((_, _, best_score)) if score < best_score => {
Some((feature_idx, split_value, score))
}
x => x,
};
}
}
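        // Only split if the best candidate decreases the impurity by at
        // least `min_impurity_decrease`; otherwise emit a leaf. This also
        // covers the case where no valid split was found at all.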
let impurity_decrease = if let Some((_, _, best_score)) = best {
let parent_score = match hyperparameters.split_quality {
SplitQuality::Gini => gini_impurity(&parent_class_freq),
SplitQuality::Entropy => entropy(&parent_class_freq),
};
let parent_score = F::from(parent_score).unwrap();
parent_score - F::from(best_score).unwrap()
} else {
F::zero()
};
if impurity_decrease < hyperparameters.min_impurity_decrease {
return Ok(Self::empty_leaf(prediction, depth));
}
let (best_feature_idx, best_split_value, _) = best.unwrap();
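        // Partition the active rows into the two child masks according to
        // the chosen split.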
let mut left_mask = RowMask::none(data.nsamples());
let mut right_mask = RowMask::none(data.nsamples());
for i in 0..data.nsamples() {
if mask.mask[i] {
if data.records()[(i, best_feature_idx)] <= best_split_value {
left_mask.mark(i);
} else {
right_mask.mark(i);
}
}
}
let left_child = if left_mask.nsamples > 0 {
Some(Box::new(TreeNode::fit(
data,
&left_mask,
                hyperparameters,
                sorted_indices,
depth + 1,
)?))
} else {
None
};
let right_child = if right_mask.nsamples > 0 {
Some(Box::new(TreeNode::fit(
data,
&right_mask,
                hyperparameters,
                sorted_indices,
depth + 1,
)?))
} else {
None
};
let leaf_node = left_child.is_none() || right_child.is_none();
Ok(TreeNode {
feature_idx: best_feature_idx,
feature_name: sorted_indices[best_feature_idx].feature_name.to_owned(),
split_value: best_split_value,
impurity_decrease,
left_child,
right_child,
leaf_node,
prediction,
depth,
})
}
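    /// Prunes the subtree bottom-up: when both children have collapsed to
    /// leaves with the same prediction, the children are dropped and this
    /// node becomes a leaf itself. Returns the common prediction if the
    /// whole subtree collapsed to a single class, `None` otherwise.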
fn prune(&mut self) -> Option<L> {
if self.is_leaf() {
return Some(self.prediction.clone());
}
let left = self.left_child.as_mut().and_then(|x| x.prune());
let right = self.right_child.as_mut().and_then(|x| x.prune());
match (left, right) {
(Some(x), Some(y)) => {
if x == y {
self.prediction = x.clone();
self.right_child = None;
self.left_child = None;
self.leaf_node = true;
Some(x)
} else {
None
}
}
_ => None,
}
}
}
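/// A fitted decision tree classifier, holding the root node and the number
/// of features seen during training.
///
/// A minimal usage sketch, assuming `dataset` is any classification
/// [`Dataset`](linfa::Dataset) with `f64` records:
///
/// ```ignore
/// let model = DecisionTree::params().max_depth(Some(4)).fit(&dataset)?;
/// let predictions = model.predict(dataset.records());
/// ```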
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Debug)]
pub struct DecisionTree<F: Float, L: Label> {
root_node: TreeNode<F, L>,
num_features: usize,
}
impl<F: Float, L: Label, D: Data<Elem = F>> PredictRef<ArrayBase<D, Ix2>, Array1<L>>
for DecisionTree<F, L>
{
fn predict_ref<'a>(&'a self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
x.genrows()
.into_iter()
.map(|row| make_prediction(&row, &self.root_node))
.collect()
}
}
impl<'a, F: Float, L: Label + 'a + std::fmt::Debug, D, T> Fit<'a, ArrayBase<D, Ix2>, T>
for DecisionTreeParams<F, L>
where
D: Data<Elem = F>,
T: AsTargets<Elem = L> + Labels<Elem = L>,
{
type Object = Result<DecisionTree<F, L>>;
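    /// Fits a decision tree: validates the hyperparameters, presorts every
    /// feature column once, grows the tree recursively over the full row
    /// mask and finally prunes redundant leaves.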
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Self::Object {
        // Propagate invalid hyperparameters as an error instead of panicking.
        self.validate()?;
let x = dataset.records();
let feature_names = dataset.feature_names();
let all_idxs = RowMask::all(x.nrows());
let sorted_indices: Vec<_> = (0..(x.ncols()))
.map(|feature_idx| {
                SortedIndex::of_array_column(x, feature_idx, &feature_names[feature_idx])
})
.collect();
        let mut root_node = TreeNode::fit(dataset, &all_idxs, self, &sorted_indices, 0)?;
root_node.prune();
Ok(DecisionTree {
root_node,
num_features: dataset.records().ncols(),
})
}
}
impl<F: Float, L: Label + std::fmt::Debug> DecisionTree<F, L> {
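    /// Returns a parameter builder with the defaults: Gini split quality, no
    /// maximum depth, a minimum weight of 2.0 to split a node, a minimum
    /// weight of 1.0 per leaf, and a minimum impurity decrease of `1e-5`.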
#[allow(clippy::new_ret_no_self)]
pub fn params() -> DecisionTreeParams<F, L> {
DecisionTreeParams {
split_quality: SplitQuality::Gini,
max_depth: None,
min_weight_split: 2.0,
min_weight_leaf: 1.0,
min_impurity_decrease: F::from(0.00001).unwrap(),
phantom: PhantomData,
}
}
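    /// Returns an iterator visiting every node of the tree, starting from
    /// the root.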
pub fn iter_nodes(&self) -> NodeIter<F, L> {
let queue = vec![&self.root_node];
NodeIter::new(queue)
}
pub fn features(&self) -> Vec<usize> {
let mut fitted_features = HashSet::new();
for node in self.iter_nodes().filter(|node| !node.is_leaf()) {
            fitted_features.insert(node.feature_idx);
}
fitted_features.into_iter().collect::<Vec<_>>()
}
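    /// Returns, for every feature, the impurity decrease of the splits on
    /// that feature averaged over all internal nodes using it; features that
    /// never appear in a split get a mean decrease of zero.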
pub fn mean_impurity_decrease(&self) -> Vec<F> {
let mut impurity_decrease = vec![F::zero(); self.num_features];
let mut num_nodes = vec![0; self.num_features];
for node in self.iter_nodes().filter(|node| !node.leaf_node) {
impurity_decrease[node.feature_idx] += node.impurity_decrease;
num_nodes[node.feature_idx] += 1;
}
impurity_decrease
.into_iter()
.zip(num_nodes.into_iter())
.map(|(val, n)| {
if n == 0 {
F::zero()
} else {
val / F::from(n).unwrap()
}
})
.collect()
}
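    /// Returns the mean impurity decrease per feature, normalized so the
    /// values sum to one.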
pub fn relative_impurity_decrease(&self) -> Vec<F> {
let mean_impurity_decrease = self.mean_impurity_decrease();
let sum = mean_impurity_decrease.iter().cloned().sum();
mean_impurity_decrease
.into_iter()
.map(|x| x / sum)
.collect()
}
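    /// Alias for [`Self::relative_impurity_decrease`]: features are ranked
    /// by their normalized mean impurity decrease.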
pub fn feature_importance(&self) -> Vec<F> {
self.relative_impurity_decrease()
}
pub fn root_node(&self) -> &TreeNode<F, L> {
&self.root_node
}
pub fn max_depth(&self) -> usize {
self.iter_nodes()
.fold(0, |max, node| usize::max(max, node.depth))
}
pub fn num_leaves(&self) -> usize {
self.iter_nodes().filter(|node| node.is_leaf()).count()
}
pub fn export_to_tikz(&self) -> Tikz<F, L> {
        Tikz::new(self)
}
}
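/// Walks a single sample down the tree, descending left when the feature
/// value is below the node's split value and right otherwise, until a leaf's
/// prediction is reached.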
fn make_prediction<F: Float, L: Label>(
x: &ArrayBase<impl Data<Elem = F>, Ix1>,
node: &TreeNode<F, L>,
) -> L {
if node.leaf_node {
node.prediction.clone()
} else if x[node.feature_idx] < node.split_value {
make_prediction(x, node.left_child.as_ref().unwrap())
} else {
make_prediction(x, node.right_child.as_ref().unwrap())
}
}
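/// Returns the label with the largest total weight in the frequency map.
/// Panics if `class_freq` is empty.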
fn find_modal_class<L: Label>(class_freq: &HashMap<L, f32>) -> L {
let val = class_freq
.iter()
.fold(None, |acc, (idx, freq)| match acc {
None => Some((idx, freq)),
Some((_best_idx, best_freq)) => {
if best_freq > freq {
acc
} else {
Some((idx, freq))
}
}
})
.unwrap()
.0;
(*val).clone()
}
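/// Gini impurity of a weighted class distribution: `1 - sum_i p_i^2`, where
/// `p_i` is the weight fraction of class `i`. Zero means a pure node. For
/// the frequencies `{6.0, 2.0}` this is `1 - (0.75^2 + 0.25^2) = 0.375`, as
/// checked by `gini_impurity_example` below.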
fn gini_impurity<L: Label>(class_freq: &HashMap<L, f32>) -> f32 {
let n_samples = class_freq.values().sum::<f32>();
assert!(n_samples > 0.0);
let purity = class_freq
.values()
.map(|x| x / n_samples)
.map(|x| x * x)
.sum::<f32>();
1.0 - purity
}
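/// Shannon entropy of a weighted class distribution: `-sum_i p_i log2(p_i)`,
/// with empty classes contributing zero. For the frequencies `{6.0, 2.0}`
/// this is `-(0.75 * log2(0.75) + 0.25 * log2(0.25)) ≈ 0.81128`, as checked
/// by `entropy_example` below.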
fn entropy<L: Label>(class_freq: &HashMap<L, f32>) -> f32 {
let n_samples = class_freq.values().sum::<f32>();
assert!(n_samples > 0.0);
class_freq
.values()
.map(|x| x / n_samples)
.map(|x| if x > 0.0 { -x * x.log2() } else { 0.0 })
.sum()
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use linfa::{error::Result, metrics::ToConfusionMatrix, Dataset};
    use ndarray::{array, s, Array, Array1, Array2};
    use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};
    use rand::rngs::SmallRng;
#[test]
fn prediction_for_rows_example() {
let labels = Array::from(vec![0, 0, 0, 0, 0, 0, 1, 1]);
let row_mask = RowMask::all(labels.len());
let dataset: DatasetBase<(), Array1<usize>> = DatasetBase::new((), labels);
let class_freq = dataset.label_frequencies_with_mask(&row_mask.mask);
assert_eq!(find_modal_class(&class_freq), 0);
}
#[test]
fn gini_impurity_example() {
let class_freq = vec![(0, 6.0), (1, 2.0), (2, 0.0)].into_iter().collect();
assert_abs_diff_eq!(gini_impurity(&class_freq), 0.375, epsilon = 1e-5);
}
#[test]
fn entropy_example() {
let class_freq = vec![(0, 6.0), (1, 2.0), (2, 0.0)].into_iter().collect();
assert_abs_diff_eq!(entropy(&class_freq), 0.81127, epsilon = 1e-5);
let perfect_class_freq = vec![(0, 8.0), (1, 0.0), (2, 0.0)].into_iter().collect();
assert_abs_diff_eq!(entropy(&perfect_class_freq), 0.0, epsilon = 1e-5);
}
#[test]
fn single_feature_random_noise_binary() -> Result<()> {
let mut data = Array::random((50, 10), Uniform::new(-4., 4.));
data.slice_mut(s![.., 8]).assign(
&(0..50)
.map(|x| if x < 25 { 0.0 } else { 1.0 })
.collect::<Array1<_>>(),
);
let targets = (0..50).map(|x| x < 25).collect::<Array1<_>>();
let dataset = Dataset::new(data, targets);
let model = DecisionTree::params().max_depth(Some(2)).fit(&dataset)?;
assert_eq!(&model.features(), &[8]);
let ground_truth = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
for (imp, truth) in model.feature_importance().iter().zip(&ground_truth) {
assert_abs_diff_eq!(imp, truth, epsilon = 1e-15);
}
let cm = model
.predict(dataset.records())
.confusion_matrix(&dataset)?;
assert_abs_diff_eq!(cm.accuracy(), 1.0, epsilon = 1e-15);
Ok(())
}
#[test]
fn check_max_depth() -> Result<()> {
let mut rng = SmallRng::seed_from_u64(42);
let data = Array::random_using((50, 50), Uniform::new(-1., 1.), &mut rng);
let targets = (0..50).collect::<Array1<usize>>();
let dataset = Dataset::new(data, targets);
for max_depth in &[1, 5, 10, 20] {
let model = DecisionTree::params()
.max_depth(Some(*max_depth))
.min_impurity_decrease(1e-10f64)
.min_weight_split(1e-10)
.fit(&dataset)?;
assert_eq!(model.max_depth(), *max_depth);
}
Ok(())
}
#[test]
fn perfectly_separable_small() -> Result<()> {
let data = array![[1., 2., 3.], [1., 2., 4.], [1., 3., 3.5]];
let targets = array![0, 0, 1];
let dataset = Dataset::new(data.clone(), targets);
let model = DecisionTree::params().max_depth(Some(1)).fit(&dataset)?;
assert_eq!(model.predict(&data), array![0, 0, 1]);
Ok(())
}
#[test]
fn toy_dataset() -> Result<()> {
let data = array![
[0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 1.0, -14.0, 0.0, -4.0, 0.0, 0.0, 0.0, 0.0,],
[0.0, 0.0, 5.0, 3.0, 0.0, -4.0, 0.0, 0.0, 1.0, -5.0, 0.2, 0.0, 4.0, 1.0,],
[-1.0, -1.0, 0.0, 0.0, -4.5, 0.0, 0.0, 2.1, 1.0, 0.0, 0.0, -4.5, 0.0, 1.0,],
[-1.0, -1.0, 0.0, -1.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 1.0,],
[-1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,],
[-1.0, -2.0, 0.0, 4.0, -3.0, 10.0, 4.0, 0.0, -3.2, 0.0, 4.0, 3.0, -4.0, 1.0,],
[2.11, 0.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -3.0, 1.0,],
[2.11, 0.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.0, 0.0, -2.0, 1.0,],
[2.11, 8.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.0, 0.0, -2.0, 1.0,],
[2.11, 8.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -1.0, 0.0,],
[2.0, 8.0, 5.0, 1.0, 0.5, -4.0, 10.0, 0.0, 1.0, -5.0, 3.0, 0.0, 2.0, 0.0,],
[2.0, 0.0, 1.0, 1.0, 1.0, -1.0, 1.0, 0.0, 0.0, -2.0, 3.0, 0.0, 1.0, 0.0,],
[2.0, 0.0, 1.0, 2.0, 3.0, -1.0, 10.0, 2.0, 0.0, -1.0, 1.0, 2.0, 2.0, 0.0,],
[1.0, 1.0, 0.0, 2.0, 2.0, -1.0, 1.0, 2.0, 0.0, -5.0, 1.0, 2.0, 3.0, 0.0,],
[3.0, 1.0, 0.0, 3.0, 0.0, -4.0, 10.0, 0.0, 1.0, -5.0, 3.0, 0.0, 3.0, 1.0,],
[2.11, 8.0, -6.0, -0.5, 0.0, 1.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -3.0, 1.0,],
[2.11, 8.0, -6.0, -0.5, 0.0, 1.0, 0.0, 0.0, -3.2, 6.0, 1.5, 1.0, -1.0, -1.0,],
[2.11, 8.0, -6.0, -0.5, 0.0, 10.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -1.0, -1.0,],
[2.0, 0.0, 5.0, 1.0, 0.5, -2.0, 10.0, 0.0, 1.0, -5.0, 3.0, 1.0, 0.0, -1.0,],
[2.0, 0.0, 1.0, 1.0, 1.0, -2.0, 1.0, 0.0, 0.0, -2.0, 0.0, 0.0, 0.0, 1.0,],
[2.0, 1.0, 1.0, 1.0, 2.0, -1.0, 10.0, 2.0, 0.0, -1.0, 0.0, 2.0, 1.0, 1.0,],
[1.0, 1.0, 0.0, 0.0, 1.0, -3.0, 1.0, 2.0, 0.0, -5.0, 1.0, 2.0, 1.0, 1.0,],
[3.0, 1.0, 0.0, 1.0, 0.0, -4.0, 1.0, 0.0, 1.0, -2.0, 0.0, 0.0, 1.0, 0.0,]
];
let targets = array![1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0];
let dataset = Dataset::new(data, targets);
let model = DecisionTree::params().fit(&dataset)?;
let prediction = model.predict(&dataset);
let cm = prediction.confusion_matrix(&dataset)?;
assert!(cm.accuracy() > 0.95);
Ok(())
}
#[test]
fn multilabel_four_uniform() -> Result<()> {
        let mut data = Array2::random((40, 2), Uniform::new(-1., 1.));
data.outer_iter_mut().enumerate().for_each(|(i, mut p)| {
if i < 10 {
p += &array![-2., -2.]
} else if i < 20 {
p += &array![-2., 2.];
} else if i < 30 {
p += &array![2., -2.];
} else {
p += &array![2., 2.];
}
});
let targets = (0..40)
.map(|x| match x {
x if x < 10 => 0,
x if x < 20 => 1,
x if x < 30 => 2,
_ => 3,
})
.collect::<Array1<_>>();
let dataset = Dataset::new(data.clone(), targets);
let model = DecisionTree::params().fit(&dataset)?;
let prediction = model.predict(data);
let cm = prediction.confusion_matrix(&dataset)?;
assert!(cm.accuracy() > 0.99);
Ok(())
}
#[test]
#[should_panic]
fn panic_min_impurity_decrease() {
DecisionTree::<f64, bool>::params()
.min_impurity_decrease(0.0)
.validate()
.unwrap();
}
}