linfa_trees/decision_trees/algorithm.rs

//! Decision trees
//!
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};

use linfa::dataset::AsSingleTargets;
use ndarray::{Array1, ArrayBase, Axis, Data, Ix1, Ix2};

use super::NodeIter;
use super::Tikz;
use super::{DecisionTreeValidParams, SplitQuality};
use linfa::{
    dataset::{Labels, Records},
    error::Error,
    error::Result,
    traits::*,
    DatasetBase, Float, Label,
};

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

/// RowMask tracks observations
///
/// The decision tree algorithm splits observations at a certain split value for a specific feature. The
/// left and right children can then only use a subset of the observations. In order to track
/// that, the observations are masked with a boolean vector, hiding all observations which are not
/// applicable in a lower tree.
struct RowMask {
    mask: Vec<bool>,
    nsamples: usize,
}

impl RowMask {
    /// Generates a RowMask without hidden observations
    ///
    /// ### Parameters
    ///
    /// * `nsamples`: the total number of observations
    ///
    fn all(nsamples: usize) -> Self {
        RowMask {
            mask: vec![true; nsamples],
            nsamples,
        }
    }

    /// Generates a RowMask where all observations are hidden
    ///
    /// ### Parameters
    ///
    /// * `nsamples`: the total number of observations
    fn none(nsamples: usize) -> Self {
        RowMask {
            mask: vec![false; nsamples],
            nsamples: 0,
        }
    }

    /// Sets the observation at the specified index as visible
    ///
    /// ### Parameters
    ///
    /// * `idx`: the index of the observation to turn visible
    ///
    /// ### Panics
    ///
    /// If `idx` is out of bounds
    ///
    fn mark(&mut self, idx: usize) {
        self.mask[idx] = true;
        self.nsamples += 1;
    }
}
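
// Minimal illustrative sketch (not from the original implementation): how a
// parent `RowMask` can be partitioned into the mask of one child after a split.
#[cfg(test)]
mod row_mask_sketch {
    use super::RowMask;

    #[test]
    fn partition_example() {
        // five observations, all visible in the parent node
        let parent = RowMask::all(5);
        assert_eq!(parent.nsamples, 5);

        // suppose observations 0, 1 and 4 fall on the left side of a split
        let mut left = RowMask::none(5);
        for &idx in &[0, 1, 4] {
            left.mark(idx);
        }
        assert_eq!(left.mask, vec![true, true, false, false, true]);
        assert_eq!(left.nsamples, 3);
    }
}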

/// Sorted values of observations with indices (always for a particular feature)
struct SortedIndex<'a, F: Float> {
    feature_name: &'a str,
    sorted_values: Vec<(usize, F)>,
}

impl<'a, F: Float> SortedIndex<'a, F> {
    /// Sorts the values of a given feature in ascending order
    ///
    /// ### Parameters
    ///
    /// * `x`: the observations to sort
    /// * `feature_idx`: the index of the feature on which to sort the data
    /// * `feature_name`: the human-readable name of the feature
    ///
    /// ### Returns
    ///
    /// A sorted vector of (index, value) pairs obtained by sorting the observations by
    /// the value of the specified feature.
    fn of_array_column(
        x: &ArrayBase<impl Data<Elem = F>, Ix2>,
        feature_idx: usize,
        feature_name: &'a str,
    ) -> Self {
        let sliced_column: Vec<F> = x.index_axis(Axis(1), feature_idx).to_vec();
        let mut pairs: Vec<(usize, F)> = sliced_column.into_iter().enumerate().collect();
        pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Greater));

        SortedIndex {
            sorted_values: pairs,
            feature_name,
        }
    }
}
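
// Another minimal illustrative sketch (not from the original implementation):
// `of_array_column` pairs each row index with its value for the chosen feature
// and sorts the pairs by that value in ascending order.
#[cfg(test)]
mod sorted_index_sketch {
    use super::SortedIndex;
    use ndarray::array;

    #[test]
    fn sorts_a_single_column() {
        let x = array![[3.0, 10.0], [1.0, 20.0], [2.0, 30.0]];
        let sorted = SortedIndex::of_array_column(&x, 0, "feature-0");

        // rows are ordered by their value in column 0: 1.0 < 2.0 < 3.0
        let order: Vec<usize> = sorted.sorted_values.iter().map(|&(idx, _)| idx).collect();
        assert_eq!(order, vec![1, 2, 0]);
        assert_eq!(sorted.feature_name, "feature-0");
    }
}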

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone)]
/// A node in the decision tree
pub struct TreeNode<F, L> {
    feature_idx: usize,
    feature_name: String,
    split_value: F,
    impurity_decrease: F,
    left_child: Option<Box<TreeNode<F, L>>>,
    right_child: Option<Box<TreeNode<F, L>>>,
    leaf_node: bool,
    prediction: L,
    depth: usize,
}

impl<F: Float, L: Label> Hash for TreeNode<F, L> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let data: Vec<u64> = vec![self.feature_idx as u64, self.leaf_node as u64];
        data.hash(state);
    }
}

impl<F, L> Eq for TreeNode<F, L> {}

impl<F, L> PartialEq for TreeNode<F, L> {
    fn eq(&self, other: &Self) -> bool {
        self.feature_idx == other.feature_idx
    }
}

impl<F: Float, L: Label + std::fmt::Debug> TreeNode<F, L> {
    fn empty_leaf(prediction: L, depth: usize) -> Self {
        TreeNode {
            feature_idx: 0,
            feature_name: "".to_string(),
            split_value: F::zero(),
            impurity_decrease: F::zero(),
            left_child: None,
            right_child: None,
            leaf_node: true,
            prediction,
            depth,
        }
    }

    /// Returns true if the node has no children
    pub fn is_leaf(&self) -> bool {
        self.leaf_node
    }

    /// Returns the depth of the node in the decision tree
    pub fn depth(&self) -> usize {
        self.depth
    }

    /// Returns `Some(prediction)` for leaf nodes and `None` for internal nodes.
    pub fn prediction(&self) -> Option<L> {
        if self.is_leaf() {
            Some(self.prediction.clone())
        } else {
            None
        }
    }

    /// Returns both children, first left then right
    pub fn children(&self) -> Vec<&Option<Box<TreeNode<F, L>>>> {
        vec![&self.left_child, &self.right_child]
    }

    /// Return the split (feature index, value) and its impurity decrease
    pub fn split(&self) -> (usize, F, F) {
        (self.feature_idx, self.split_value, self.impurity_decrease)
    }

    /// Returns the name of the feature used in the split if the node is internal,
    /// `None` otherwise
    pub fn feature_name(&self) -> Option<&String> {
        if self.leaf_node {
            None
        } else {
            Some(&self.feature_name)
        }
    }

    /// Recursively fits the node
    fn fit<D: Data<Elem = F>, T: AsSingleTargets<Elem = L> + Labels<Elem = L>>(
        data: &DatasetBase<ArrayBase<D, Ix2>, T>,
        mask: &RowMask,
        hyperparameters: &DecisionTreeValidParams<F, L>,
        sorted_indices: &[SortedIndex<F>],
        depth: usize,
    ) -> Result<Self> {
        // compute weighted frequencies for target classes
        let parent_class_freq = data.label_frequencies_with_mask(&mask.mask);
        // set our prediction for this subset to the modal class
        let prediction = find_modal_class(&parent_class_freq);
        // get targets from dataset
        let target = data.as_single_targets();

        // return empty leaf when we don't have enough samples or the maximal depth is reached
        if (mask.nsamples as f32) < hyperparameters.min_weight_split()
            || hyperparameters
                .max_depth()
                .map(|max_depth| depth >= max_depth)
                .unwrap_or(false)
        {
            return Ok(Self::empty_leaf(prediction, depth));
        }

        // Find best split for current level
        let mut best = None;

        // Iterate over all features
        for (feature_idx, sorted_index) in sorted_indices.iter().enumerate() {
            let mut right_class_freq = parent_class_freq.clone();
            let mut left_class_freq = HashMap::new();

            // We keep a running total of the aggregate weight in the right split
            // to avoid having to sum over the hash map
            let total_weight = parent_class_freq.values().sum::<f32>();
            let mut weight_on_right_side = total_weight;
            let mut weight_on_left_side = 0.0;

            // We start by putting all available observations in the right subtree
            // and then move the (sorted by `feature_idx`) observations one by one to
            // the left subtree and evaluate the quality of the resulting split. At each
            // iteration, the obtained split is compared with `best`, in order
            // to find the best possible split.
            // The resulting split will then have the observations with a value of their `feature_idx`
            // feature smaller than the split value in the left subtree and the others still in the right
            // subtree.
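            // (Illustrative sketch: for sorted feature values [1.0, 2.0, 4.0] the loop below
            // would consider the candidate split values 1.5 and 3.0, i.e. the midpoints
            // between consecutive distinct values.)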
            for i in 0..mask.mask.len() - 1 {
                // (index of the observation, value of its `feature_idx` feature)
                let (presorted_index, mut split_value) = sorted_index.sorted_values[i];

                // Skip if the observation is unavailable in this subtree
                if !mask.mask[presorted_index] {
                    continue;
                }

                // Target and weight of the current observation
                let sample_class = &target[presorted_index];
                let sample_weight = data.weight_for(presorted_index);

                // Move the observation from the right subtree to the left subtree

                // Decrement the weight on the class for this sample on the right
                // side by the weight of this sample
                *right_class_freq.get_mut(sample_class).unwrap() -= sample_weight;
                weight_on_right_side -= sample_weight;

                // Increment the weight on the class for this sample on the
                // left side by the weight of this sample
                *left_class_freq.entry(sample_class.clone()).or_insert(0.0) += sample_weight;
                weight_on_left_side += sample_weight;

                // Continue if the next value is equal, so that equal values end up in the same subtree
                if (sorted_index.sorted_values[i].1 - sorted_index.sorted_values[i + 1].1).abs()
                    < F::cast(1e-5)
                {
                    continue;
                }

                // If the split would result in too few samples in a leaf
                // then skip computing the quality
                if weight_on_right_side < hyperparameters.min_weight_leaf()
                    || weight_on_left_side < hyperparameters.min_weight_leaf()
                {
                    continue;
                }

                // Calculate the quality of each resulting subset of the dataset
                let (left_score, right_score) = match hyperparameters.split_quality() {
                    SplitQuality::Gini => (
                        gini_impurity(&left_class_freq),
                        gini_impurity(&right_class_freq),
                    ),
                    SplitQuality::Entropy => {
                        (entropy(&left_class_freq), entropy(&right_class_freq))
                    }
                };

                // Weight the qualities based on the share of the total weight in each subset
                let w = weight_on_right_side / total_weight;
                let score = w * right_score + (1.0 - w) * left_score;
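                // (Illustrative numbers: if 3/4 of the total weight is still on the right
                // with impurity 0.5 and the left quarter is pure, the weighted score is
                // 0.75 * 0.5 + 0.25 * 0.0 = 0.375.)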

                // Take the midpoint from this value and the next one as split_value
                split_value = (split_value + sorted_index.sorted_values[i + 1].1) / F::cast(2.0);

                // override best indices when score improved
                best = match best.take() {
                    None => Some((feature_idx, split_value, score)),
                    Some((_, _, best_score)) if score < best_score => {
                        Some((feature_idx, split_value, score))
                    }
                    x => x,
                };
            }
        }

        // At this point all possible splits for all possible features have been computed
        // and the best one (if any) is stored in `best`. Now we can compute the
        // impurity decrease as `impurity of the node before splitting - impurity of the split`.
        // If the impurity decrease is above the threshold set in the parameters, then the split is
        // applied and `fit` is recursively called on the two resulting subtrees. If there is no
        // possible split, or if it doesn't bring enough impurity decrease, then the node is set as
        // a leaf node that predicts the most common label in the available observations.

        let impurity_decrease = if let Some((_, _, best_score)) = best {
            let parent_score = match hyperparameters.split_quality() {
                SplitQuality::Gini => gini_impurity(&parent_class_freq),
                SplitQuality::Entropy => entropy(&parent_class_freq),
            };
            let parent_score = F::cast(parent_score);

            // impurity decrease of the best split; checked against the threshold below
            parent_score - F::cast(best_score)
        } else {
            // no valid split was found, so the impurity decrease is zero
            F::zero()
        };

        if impurity_decrease < hyperparameters.min_impurity_decrease() {
            return Ok(Self::empty_leaf(prediction, depth));
        }

        let (best_feature_idx, best_split_value, _) = best.unwrap();

        // determine new masks for the left and right subtrees
        let mut left_mask = RowMask::none(data.nsamples());
        let mut right_mask = RowMask::none(data.nsamples());

        for i in 0..data.nsamples() {
            if mask.mask[i] {
                if data.records()[(i, best_feature_idx)] <= best_split_value {
                    left_mask.mark(i);
                } else {
                    right_mask.mark(i);
                }
            }
        }

        // Recurse and refit on left and right subtrees
        let left_child = if left_mask.nsamples > 0 {
            Some(Box::new(TreeNode::fit(
                data,
                &left_mask,
                hyperparameters,
                sorted_indices,
                depth + 1,
            )?))
        } else {
            None
        };

        let right_child = if right_mask.nsamples > 0 {
            Some(Box::new(TreeNode::fit(
                data,
                &right_mask,
                hyperparameters,
                sorted_indices,
                depth + 1,
            )?))
        } else {
            None
        };

        let leaf_node = left_child.is_none() || right_child.is_none();

        Ok(TreeNode {
            feature_idx: best_feature_idx,
            feature_name: sorted_indices[best_feature_idx].feature_name.to_owned(),
            split_value: best_split_value,
            impurity_decrease,
            left_child,
            right_child,
            leaf_node,
            prediction,
            depth,
        })
    }

    /// Prune the tree after fitting it
    ///
    /// This collapses subtrees whose leaves all make the same prediction into a
    /// single leaf. It is called right after fitting to keep the tree small.
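    ///
    /// For example, if both children of a node are leaves predicting label `A`,
    /// the children are removed and the node itself becomes a leaf predicting `A`.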
    fn prune(&mut self) -> Option<L> {
        if self.is_leaf() {
            return Some(self.prediction.clone());
        }

        let left = self.left_child.as_mut().and_then(|x| x.prune());
        let right = self.right_child.as_mut().and_then(|x| x.prune());

        match (left, right) {
            (Some(x), Some(y)) => {
                if x == y {
                    self.prediction = x.clone();
                    self.right_child = None;
                    self.left_child = None;
                    self.leaf_node = true;

                    Some(x)
                } else {
                    None
                }
            }
            _ => None,
        }
    }
}

/// A fitted decision tree model for classification.
///
/// ### Structure
/// A decision tree structure is a binary tree where:
/// * Each internal node specifies a decision, represented by a choice of a feature and a "split value" such that all observations for which
///     `feature <= split_value` is true fall in the left subtree, while the others fall in the right subtree.
///
/// * Leaf nodes make predictions, and their prediction is the most common label in the node
///
/// ### Algorithm
///
/// Starting with a single root node, decision trees are trained recursively by applying the following rule to every
/// node considered:
///
/// * Find the best split value for each feature of the observations belonging to the node;
/// * Select the feature (and its best split value) that maximizes the quality of the split;
/// * If the score of the split is sufficiently larger than the score of the unsplit node, then two child nodes are generated, the left one
///   containing all observations with `feature <= split value` and the right one containing the rest.
/// * If no suitable split is found, the node is marked as a leaf and its prediction is set to be the most common label in the node;
///
/// The [quality score](SplitQuality) used can be specified in the [parameters](crate::DecisionTreeParams).
///
/// ### Predictions
///
/// To predict the label of a sample, the tree is traversed from the root to a leaf, choosing between left and right children according to
/// the values of the features of the sample. The final prediction for the sample is the prediction of the reached leaf.
///
/// ### Additional constraints
///
/// In order to avoid overfitting the training data, some additional constraints on the quality/quantity of splits can be added to the tree.
/// A description of these additional rules is provided in the [parameters](crate::DecisionTreeParams) page.
///
/// ### Example
///
/// Here is an example of how to train a decision tree from its parameters:
///
/// ```rust
///
/// use linfa_trees::DecisionTree;
/// use linfa::prelude::*;
/// use linfa_datasets;
///
/// // Load the dataset
/// let dataset = linfa_datasets::iris();
/// // Fit the tree
/// let tree = DecisionTree::params().fit(&dataset).unwrap();
/// // Get accuracy on training set
/// let accuracy = tree.predict(&dataset).confusion_matrix(&dataset).unwrap().accuracy();
///
/// assert!(accuracy > 0.9);
///
/// ```
///
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct DecisionTree<F: Float, L: Label> {
    root_node: TreeNode<F, L>,
    num_features: usize,
}

impl<F: Float, L: Label + Default, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<L>>
    for DecisionTree<F, L>
{
    /// Make predictions for each row of a matrix of features `x`.
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
        assert_eq!(
            x.nrows(),
            y.len(),
            "The number of data points must match the number of output targets."
        );

        for (row, target) in x.rows().into_iter().zip(y.iter_mut()) {
            *target = make_prediction(&row, &self.root_node);
        }
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
        Array1::default(x.nrows())
    }
}

impl<F: Float, L: Label + std::fmt::Debug, D, T> Fit<ArrayBase<D, Ix2>, T, Error>
    for DecisionTreeValidParams<F, L>
where
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    type Object = DecisionTree<F, L>;

    /// Fit a decision tree using these hyperparameters on a dataset consisting of
    /// a matrix of features and an array of labels.
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
        let x = dataset.records();
        let feature_names = dataset.feature_names();
        let all_idxs = RowMask::all(x.nrows());
        let sorted_indices: Vec<_> = (0..(x.ncols()))
            .map(|feature_idx| {
                SortedIndex::of_array_column(x, feature_idx, &feature_names[feature_idx])
            })
            .collect();

        let mut root_node = TreeNode::fit(dataset, &all_idxs, self, &sorted_indices, 0)?;
        root_node.prune();

        Ok(DecisionTree {
            root_node,
            num_features: dataset.records().ncols(),
        })
    }
}

impl<F: Float, L: Label> DecisionTree<F, L> {
    /// Create a node iterator in level-order (breadth-first traversal)
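    ///
    /// A short illustrative sketch of walking a fitted tree (it reuses the
    /// `linfa_datasets` iris helper from the struct-level example above):
    ///
    /// ```rust
    /// use linfa::prelude::*;
    /// use linfa_trees::DecisionTree;
    ///
    /// let dataset = linfa_datasets::iris();
    /// let tree = DecisionTree::params().max_depth(Some(2)).fit(&dataset).unwrap();
    /// // nodes are yielded in breadth-first order; at least the root is a split node
    /// let splits = tree.iter_nodes().filter(|node| !node.is_leaf()).count();
    /// assert!(splits >= 1);
    /// ```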
    pub fn iter_nodes(&self) -> NodeIter<F, L> {
        // queue of nodes yet to explore
        let queue = vec![&self.root_node];

        NodeIter::new(queue)
    }

    /// Return the set of feature indices used by the splits of this tree
    pub fn features(&self) -> Vec<usize> {
        // set of feature indices used in at least one split
        let mut fitted_features = HashSet::new();

        for node in self.iter_nodes().filter(|node| !node.is_leaf()) {
            if !fitted_features.contains(&node.feature_idx) {
                fitted_features.insert(node.feature_idx);
            }
        }

        fitted_features.into_iter().collect::<Vec<_>>()
    }

    /// Return the mean impurity decrease for each feature
    pub fn mean_impurity_decrease(&self) -> Vec<F> {
        // total impurity decrease for each feature
        let mut impurity_decrease = vec![F::zero(); self.num_features];
        let mut num_nodes = vec![0; self.num_features];

        for node in self.iter_nodes().filter(|node| !node.leaf_node) {
            // add feature impurity decrease to list
            impurity_decrease[node.feature_idx] += node.impurity_decrease;
            num_nodes[node.feature_idx] += 1;
        }

        impurity_decrease
            .into_iter()
            .zip(num_nodes)
            .map(|(val, n)| if n == 0 { F::zero() } else { val / F::cast(n) })
            .collect()
    }

    /// Return the relative impurity decrease for each feature
    pub fn relative_impurity_decrease(&self) -> Vec<F> {
        let mean_impurity_decrease = self.mean_impurity_decrease();
        let sum = mean_impurity_decrease.iter().cloned().sum();

        mean_impurity_decrease
            .into_iter()
            .map(|x| x / sum)
            .collect()
    }

    /// Return the feature importance, i.e. the relative impurity decrease, for each feature
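    ///
    /// Illustrative usage (again relying on the `linfa_datasets` iris helper
    /// used in the struct-level example above):
    ///
    /// ```rust
    /// use linfa::prelude::*;
    /// use linfa_trees::DecisionTree;
    ///
    /// let dataset = linfa_datasets::iris();
    /// let tree = DecisionTree::params().fit(&dataset).unwrap();
    /// // one importance value per feature; the values sum to one
    /// let importance = tree.feature_importance();
    /// assert_eq!(importance.len(), 4);
    /// ```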
    pub fn feature_importance(&self) -> Vec<F> {
        self.relative_impurity_decrease()
    }

    /// Return the root node of the tree
    pub fn root_node(&self) -> &TreeNode<F, L> {
        &self.root_node
    }

    /// Return the maximum depth of the tree
    pub fn max_depth(&self) -> usize {
        self.iter_nodes()
            .fold(0, |max, node| usize::max(max, node.depth))
    }

    /// Return the number of leaves in this tree
    pub fn num_leaves(&self) -> usize {
        self.iter_nodes().filter(|node| node.is_leaf()).count()
    }

    /// Generates a [`Tikz`] structure to print the
    /// fitted tree in TeX using the `tikz` and `forest` packages, with the following default parameters:
    ///
    /// * `legend=false`
    /// * `complete=true`
    ///
    pub fn export_to_tikz(&self) -> Tikz<F, L> {
        Tikz::new(self)
    }
}

/// Classify the sample `x` by recursively traversing the tree, starting from `node`.
fn make_prediction<F: Float, L: Label>(
    x: &ArrayBase<impl Data<Elem = F>, Ix1>,
    node: &TreeNode<F, L>,
) -> L {
    if node.leaf_node {
        node.prediction.clone()
    } else if x[node.feature_idx] <= node.split_value {
        make_prediction(x, node.left_child.as_ref().unwrap())
    } else {
        make_prediction(x, node.right_child.as_ref().unwrap())
    }
}

/// Finds the most frequent class in a hash map of frequencies. If two
/// classes have the same weight, the winner depends on the iteration order
/// of the map.
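///
/// For example, the frequencies `{0 => 6.0, 1 => 2.0}` yield class `0`.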
fn find_modal_class<L: Label>(class_freq: &HashMap<L, f32>) -> L {
    // TODO: Refactor this with fold_first

    let val = class_freq
        .iter()
        .fold(None, |acc, (idx, freq)| match acc {
            None => Some((idx, freq)),
            Some((_best_idx, best_freq)) => {
                if best_freq > freq {
                    acc
                } else {
                    Some((idx, freq))
                }
            }
        })
        .unwrap()
        .0;

    (*val).clone()
}

/// Given the class frequencies, calculates the Gini impurity of the subset.
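///
/// For example, the frequencies `{0 => 6.0, 1 => 2.0}` give an impurity of
/// `1 - (0.75 * 0.75 + 0.25 * 0.25) = 0.375` (see the `gini_impurity_example` test below).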
fn gini_impurity<L: Label>(class_freq: &HashMap<L, f32>) -> f32 {
    let n_samples = class_freq.values().sum::<f32>();
    assert!(n_samples > 0.0);

    let purity = class_freq
        .values()
        .map(|x| x / n_samples)
        .map(|x| x * x)
        .sum::<f32>();

    1.0 - purity
}

/// Given the class frequencies, calculates the entropy of the subset.
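///
/// For example, the frequencies `{0 => 6.0, 1 => 2.0}` give an entropy of
/// `-0.75 * log2(0.75) - 0.25 * log2(0.25)`, approximately `0.8113`
/// (see the `entropy_example` test below).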
fn entropy<L: Label>(class_freq: &HashMap<L, f32>) -> f32 {
    let n_samples = class_freq.values().sum::<f32>();
    assert!(n_samples > 0.0);

    class_freq
        .values()
        .map(|x| x / n_samples)
        .map(|x| if x > 0.0 { -x * x.log2() } else { 0.0 })
        .sum()
}

#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_abs_diff_eq;
    use linfa::{error::Result, metrics::ToConfusionMatrix, Dataset, ParamGuard};
    use ndarray::{array, concatenate, s, Array, Array1, Array2, Axis};
    use rand::rngs::SmallRng;

    use crate::DecisionTreeParams;
    use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<DecisionTree<f64, bool>>();
        has_autotraits::<TreeNode<f64, bool>>();
        has_autotraits::<DecisionTreeValidParams<f64, bool>>();
        has_autotraits::<DecisionTreeParams<f64, bool>>();
        has_autotraits::<NodeIter<f64, bool>>();
        has_autotraits::<Tikz<f64, bool>>();
    }

    #[test]
    fn prediction_for_rows_example() {
        let labels = Array::from(vec![0, 0, 0, 0, 0, 0, 1, 1]);
        let row_mask = RowMask::all(labels.len());

        let dataset: DatasetBase<(), Array1<usize>> = DatasetBase::new((), labels);
        let class_freq = dataset.label_frequencies_with_mask(&row_mask.mask);

        assert_eq!(find_modal_class(&class_freq), 0);
    }

    #[test]
    fn gini_impurity_example() {
        let class_freq = vec![(0, 6.0), (1, 2.0), (2, 0.0)].into_iter().collect();

        // Class 0 occurs 75% of the time
        // Class 1 occurs 25% of the time
        // Class 2 occurs 0% of the time
        // Gini impurity is 1 - 0.75*0.75 - 0.25*0.25 - 0*0 = 0.375
        assert_abs_diff_eq!(gini_impurity(&class_freq), 0.375, epsilon = 1e-5);
    }

    #[test]
    fn entropy_example() {
        let class_freq = vec![(0, 6.0), (1, 2.0), (2, 0.0)].into_iter().collect();

        // Class 0 occurs 75% of the time
        // Class 1 occurs 25% of the time
        // Class 2 occurs 0% of the time
        // Entropy is -0.75*log2(0.75) - 0.25*log2(0.25) - 0*log2(0) = 0.81127812
        assert_abs_diff_eq!(entropy(&class_freq), 0.81127, epsilon = 1e-5);

        // If split is perfect then entropy is zero
        let perfect_class_freq = vec![(0, 8.0), (1, 0.0), (2, 0.0)].into_iter().collect();

        assert_abs_diff_eq!(entropy(&perfect_class_freq), 0.0, epsilon = 1e-5);
    }

    #[test]
    /// Single feature test
    ///
    /// Generate a dataset where a single feature perfectly correlates
    /// with the target while the remaining features are random uniform
    /// noise and do not add any information.
    fn single_feature_random_noise_binary() -> Result<()> {
        // generate data with nine noise features and a single correlated feature
        let mut data = Array::random((50, 10), Uniform::new(-4., 4.));
        data.slice_mut(s![.., 8]).assign(
            &(0..50)
                .map(|x| if x < 25 { 0.0 } else { 1.0 })
                .collect::<Array1<_>>(),
        );

        let targets = (0..50).map(|x| x < 25).collect::<Array1<_>>();
        let dataset = Dataset::new(data, targets);

        let model = DecisionTree::params().max_depth(Some(2)).fit(&dataset)?;

        // we should only use feature index 8 here
        assert_eq!(&model.features(), &[8]);

        let ground_truth = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];

        for (imp, truth) in model.feature_importance().iter().zip(&ground_truth) {
            assert_abs_diff_eq!(imp, truth, epsilon = 1e-15);
        }

        // check for perfect accuracy
        let cm = model
            .predict(dataset.records())
            .confusion_matrix(&dataset)?;
        assert_abs_diff_eq!(cm.accuracy(), 1.0, epsilon = 1e-15);

        Ok(())
    }

    #[test]
    /// Check that for random data the max depth is used
    fn check_max_depth() -> Result<()> {
        let mut rng = SmallRng::seed_from_u64(42);

        // create random data where every sample carries its own label
        let data = Array::random_using((50, 50), Uniform::new(-1., 1.), &mut rng);
        let targets = (0..50).collect::<Array1<usize>>();

        let dataset = Dataset::new(data, targets);

        // check that the provided depth is actually used
        for max_depth in &[1, 5, 10, 20] {
            let model = DecisionTree::params()
                .max_depth(Some(*max_depth))
                .min_impurity_decrease(1e-10f64)
                .min_weight_split(1e-10)
                .fit(&dataset)?;
            assert_eq!(model.max_depth(), *max_depth);
        }

        Ok(())
    }

    #[test]
    /// Small perfectly separable dataset test
    ///
    /// This dataset of three elements is perfectly separable using the second feature.
    fn perfectly_separable_small() -> Result<()> {
        let data = array![[1., 2., 3.], [1., 2., 4.], [1., 3., 3.5]];
        let targets = array![0, 0, 1];

        let dataset = Dataset::new(data.clone(), targets);
        let model = DecisionTree::params().max_depth(Some(1)).fit(&dataset)?;

        assert_eq!(model.predict(&data), array![0, 0, 1]);

        Ok(())
    }

    #[test]
    /// Small toy dataset from scikit-learn
    fn toy_dataset() -> Result<()> {
        let data = array![
            [0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 1.0, -14.0, 0.0, -4.0, 0.0, 0.0, 0.0, 0.0,],
            [0.0, 0.0, 5.0, 3.0, 0.0, -4.0, 0.0, 0.0, 1.0, -5.0, 0.2, 0.0, 4.0, 1.0,],
            [-1.0, -1.0, 0.0, 0.0, -4.5, 0.0, 0.0, 2.1, 1.0, 0.0, 0.0, -4.5, 0.0, 1.0,],
            [-1.0, -1.0, 0.0, -1.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 1.0,],
            [-1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,],
            [-1.0, -2.0, 0.0, 4.0, -3.0, 10.0, 4.0, 0.0, -3.2, 0.0, 4.0, 3.0, -4.0, 1.0,],
            [2.11, 0.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -3.0, 1.0,],
            [2.11, 0.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.0, 0.0, -2.0, 1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.0, 0.0, -2.0, 1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -1.0, 0.0,],
            [2.0, 8.0, 5.0, 1.0, 0.5, -4.0, 10.0, 0.0, 1.0, -5.0, 3.0, 0.0, 2.0, 0.0,],
            [2.0, 0.0, 1.0, 1.0, 1.0, -1.0, 1.0, 0.0, 0.0, -2.0, 3.0, 0.0, 1.0, 0.0,],
            [2.0, 0.0, 1.0, 2.0, 3.0, -1.0, 10.0, 2.0, 0.0, -1.0, 1.0, 2.0, 2.0, 0.0,],
            [1.0, 1.0, 0.0, 2.0, 2.0, -1.0, 1.0, 2.0, 0.0, -5.0, 1.0, 2.0, 3.0, 0.0,],
            [3.0, 1.0, 0.0, 3.0, 0.0, -4.0, 10.0, 0.0, 1.0, -5.0, 3.0, 0.0, 3.0, 1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 1.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -3.0, 1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 1.0, 0.0, 0.0, -3.2, 6.0, 1.5, 1.0, -1.0, -1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 10.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -1.0, -1.0,],
            [2.0, 0.0, 5.0, 1.0, 0.5, -2.0, 10.0, 0.0, 1.0, -5.0, 3.0, 1.0, 0.0, -1.0,],
            [2.0, 0.0, 1.0, 1.0, 1.0, -2.0, 1.0, 0.0, 0.0, -2.0, 0.0, 0.0, 0.0, 1.0,],
            [2.0, 1.0, 1.0, 1.0, 2.0, -1.0, 10.0, 2.0, 0.0, -1.0, 0.0, 2.0, 1.0, 1.0,],
            [1.0, 1.0, 0.0, 0.0, 1.0, -3.0, 1.0, 2.0, 0.0, -5.0, 1.0, 2.0, 1.0, 1.0,],
            [3.0, 1.0, 0.0, 1.0, 0.0, -4.0, 1.0, 0.0, 1.0, -2.0, 0.0, 0.0, 1.0, 0.0,]
        ];

        let targets = array![1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0];

        let dataset = Dataset::new(data, targets);
        let model = DecisionTree::params().fit(&dataset)?;
        let prediction = model.predict(&dataset);

        let cm = prediction.confusion_matrix(&dataset)?;
        assert!(cm.accuracy() > 0.95);

        Ok(())
    }

    #[test]
    /// Multiclass classification with four uniform blobs
    fn multilabel_four_uniform() -> Result<()> {
        let mut data = concatenate(
            Axis(0),
            &[Array2::random((40, 2), Uniform::new(-1., 1.)).view()],
        )
        .unwrap();

        data.outer_iter_mut().enumerate().for_each(|(i, mut p)| {
            if i < 10 {
                p += &array![-2., -2.]
            } else if i < 20 {
                p += &array![-2., 2.];
            } else if i < 30 {
                p += &array![2., -2.];
            } else {
                p += &array![2., 2.];
            }
        });

        let targets = (0..40)
            .map(|x| match x {
                x if x < 10 => 0,
                x if x < 20 => 1,
                x if x < 30 => 2,
                _ => 3,
            })
            .collect::<Array1<_>>();

        let dataset = Dataset::new(data.clone(), targets);

        let model = DecisionTree::params().fit(&dataset)?;
        let prediction = model.predict(data);

        let cm = prediction.confusion_matrix(&dataset)?;
        assert!(cm.accuracy() > 0.99);

        Ok(())
    }

    #[test]
    #[should_panic]
    /// Check that a zero or negative minimum impurity decrease fails parameter validation
    fn panic_min_impurity_decrease() {
        DecisionTree::<f64, bool>::params()
            .min_impurity_decrease(0.0)
            .check()
            .unwrap();
    }
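
    #[test]
    /// Illustrative sketch (an addition mirroring `perfectly_separable_small`
    /// above): a tree fitted on that tiny dataset exposes consistent metadata.
    fn tree_metadata_consistency() -> Result<()> {
        let data = array![[1., 2., 3.], [1., 2., 4.], [1., 3., 3.5]];
        let targets = array![0, 0, 1];
        let dataset = Dataset::new(data, targets);

        let model = DecisionTree::params().max_depth(Some(1)).fit(&dataset)?;

        // a depth-one tree consists of a single split node with two leaves
        assert!(!model.root_node().is_leaf());
        assert_eq!(model.max_depth(), 1);
        assert_eq!(model.num_leaves(), 2);

        Ok(())
    }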
}