use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};

use linfa::dataset::AsSingleTargets;
use ndarray::{Array1, ArrayBase, Axis, Data, Ix1, Ix2};

use super::NodeIter;
use super::Tikz;
use super::{DecisionTreeValidParams, SplitQuality};
use linfa::{
    dataset::{Labels, Records},
    error::Error,
    error::Result,
    traits::*,
    DatasetBase, Float, Label,
};

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
/// Tracks which rows of the dataset belong to the node currently being
/// fitted. `nsamples` caches the number of marked rows so it does not have
/// to be recounted on every access.
struct RowMask {
    mask: Vec<bool>,
    nsamples: usize,
}

impl RowMask {
    /// Mask selecting all rows.
    fn all(nsamples: usize) -> Self {
        RowMask {
            mask: vec![true; nsamples],
            nsamples,
        }
    }

    /// Mask selecting no rows.
    fn none(nsamples: usize) -> Self {
        RowMask {
            mask: vec![false; nsamples],
            nsamples: 0,
        }
    }

    /// Adds the row at `idx` to the mask. Marking the same row twice would
    /// double-count it in `nsamples`, so callers mark each row at most once.
    fn mark(&mut self, idx: usize) {
        self.mask[idx] = true;
        self.nsamples += 1;
    }
}
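// Usage sketch (mirroring how `TreeNode::fit` partitions rows below): child
// masks start empty and each row of the parent is marked into exactly one child:
//
//     let mut left_mask = RowMask::none(nsamples);
//     left_mask.mark(0); // row 0 now belongs to the left child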

/// A single feature column presorted by value. Presorting each column once up
/// front lets the split search scan all candidate thresholds in one linear pass.
struct SortedIndex<'a, F: Float> {
    feature_name: &'a str,
    sorted_values: Vec<(usize, F)>,
}

impl<'a, F: Float> SortedIndex<'a, F> {
    fn of_array_column(
        x: &ArrayBase<impl Data<Elem = F>, Ix2>,
        feature_idx: usize,
        feature_name: &'a str,
    ) -> Self {
        let sliced_column: Vec<F> = x.index_axis(Axis(1), feature_idx).to_vec();
        let mut pairs: Vec<(usize, F)> = sliced_column.into_iter().enumerate().collect();
        // `partial_cmp` fails only on NaN, which is treated as greater and
        // therefore sorted to the end of the column.
        pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Greater));

        SortedIndex {
            sorted_values: pairs,
            feature_name,
        }
    }
}
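// Example: a column with values [3.0, 1.0, 2.0] yields
// `sorted_values == [(1, 1.0), (2, 2.0), (0, 3.0)]`, pairing each value with
// the row index it came from.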

/// A node in a fitted decision tree: either an internal split on one feature
/// at `split_value`, or a leaf carrying a `prediction`.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone)]
pub struct TreeNode<F, L> {
    feature_idx: usize,
    feature_name: String,
    split_value: F,
    impurity_decrease: F,
    left_child: Option<Box<TreeNode<F, L>>>,
    right_child: Option<Box<TreeNode<F, L>>>,
    leaf_node: bool,
    prediction: L,
    depth: usize,
}

// Hashing and equality consider only the split feature (and, for hashing,
// whether the node is a leaf), not the whole subtree.
impl<F: Float, L: Label> Hash for TreeNode<F, L> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let data: Vec<u64> = vec![self.feature_idx as u64, self.leaf_node as u64];
        data.hash(state);
    }
}

impl<F, L> Eq for TreeNode<F, L> {}

impl<F, L> PartialEq for TreeNode<F, L> {
    fn eq(&self, other: &Self) -> bool {
        self.feature_idx == other.feature_idx
    }
}
impl<F: Float, L: Label + std::fmt::Debug> TreeNode<F, L> {
    fn empty_leaf(prediction: L, depth: usize) -> Self {
        TreeNode {
            feature_idx: 0,
            feature_name: "".to_string(),
            split_value: F::zero(),
            impurity_decrease: F::zero(),
            left_child: None,
            right_child: None,
            leaf_node: true,
            prediction,
            depth,
        }
    }

    /// Returns `true` if this node has no children.
    pub fn is_leaf(&self) -> bool {
        self.leaf_node
    }

    /// Returns the depth of this node; the root has depth 0.
    pub fn depth(&self) -> usize {
        self.depth
    }

    /// Returns the prediction of this node if it is a leaf, `None` otherwise.
    pub fn prediction(&self) -> Option<L> {
        if self.is_leaf() {
            Some(self.prediction.clone())
        } else {
            None
        }
    }

    /// Returns the left and right children, in that order.
    pub fn children(&self) -> Vec<&Option<Box<TreeNode<F, L>>>> {
        vec![&self.left_child, &self.right_child]
    }

    /// Returns the feature index, threshold and impurity decrease of this split.
    pub fn split(&self) -> (usize, F, F) {
        (self.feature_idx, self.split_value, self.impurity_decrease)
    }

    /// Returns the name of the feature this node splits on, or `None` for leaves.
    pub fn feature_name(&self) -> Option<&String> {
        if self.leaf_node {
            None
        } else {
            Some(&self.feature_name)
        }
    }

    /// Recursively fits the subtree rooted at this node on the rows of `data`
    /// selected by `mask`.
    fn fit<D: Data<Elem = F>, T: AsSingleTargets<Elem = L> + Labels<Elem = L>>(
        data: &DatasetBase<ArrayBase<D, Ix2>, T>,
        mask: &RowMask,
        hyperparameters: &DecisionTreeValidParams<F, L>,
        sorted_indices: &[SortedIndex<F>],
        depth: usize,
    ) -> Result<Self> {
        // The modal (most frequent) class of the rows in this node is the
        // prediction the node will make if it ends up a leaf.
        let parent_class_freq = data.label_frequencies_with_mask(&mask.mask);
        let prediction = find_modal_class(&parent_class_freq);
        let target = data.as_single_targets();

        // Stop early if the node is too light to split or the depth cap is reached.
        if (mask.nsamples as f32) < hyperparameters.min_weight_split()
            || hyperparameters
                .max_depth()
                .map(|max_depth| depth >= max_depth)
                .unwrap_or(false)
        {
            return Ok(Self::empty_leaf(prediction, depth));
        }
        // Exhaustively search, over all features and all candidate thresholds,
        // for the split with the lowest weighted impurity. `best` holds
        // (feature index, threshold, weighted impurity score).
        let mut best = None;

        for (feature_idx, sorted_index) in sorted_indices.iter().enumerate() {
            // Sweep this feature's sorted values left to right, moving one
            // sample at a time from the right partition to the left one and
            // updating the class frequencies incrementally.
            let mut right_class_freq = parent_class_freq.clone();
            let mut left_class_freq = HashMap::new();

            let total_weight = parent_class_freq.values().sum::<f32>();
            let mut weight_on_right_side = total_weight;
            let mut weight_on_left_side = 0.0;

            for i in 0..mask.mask.len() - 1 {
                let (presorted_index, mut split_value) = sorted_index.sorted_values[i];

                // Skip rows that do not belong to this node.
                if !mask.mask[presorted_index] {
                    continue;
                }

                let sample_class = &target[presorted_index];
                let sample_weight = data.weight_for(presorted_index);

                // Move the sample from the right partition to the left one.
                *right_class_freq.get_mut(sample_class).unwrap() -= sample_weight;
                weight_on_right_side -= sample_weight;

                *left_class_freq.entry(sample_class.clone()).or_insert(0.0) += sample_weight;
                weight_on_left_side += sample_weight;

                // Only consider splits between distinct feature values.
                if (sorted_index.sorted_values[i].1 - sorted_index.sorted_values[i + 1].1).abs()
                    < F::cast(1e-5)
                {
                    continue;
                }

                // Skip splits that would leave either side lighter than allowed.
                if weight_on_right_side < hyperparameters.min_weight_leaf()
                    || weight_on_left_side < hyperparameters.min_weight_leaf()
                {
                    continue;
                }

                // Score each side with the configured quality measure.
                let (left_score, right_score) = match hyperparameters.split_quality() {
                    SplitQuality::Gini => (
                        gini_impurity(&left_class_freq),
                        gini_impurity(&right_class_freq),
                    ),
                    SplitQuality::Entropy => {
                        (entropy(&left_class_freq), entropy(&right_class_freq))
                    }
                };

                // Weight each side's impurity by the fraction of the node's
                // total weight that falls on that side.
                let w = weight_on_left_side / total_weight;
                let score = w * left_score + (1.0 - w) * right_score;

                // Place the threshold halfway between this value and the next.
                split_value = (split_value + sorted_index.sorted_values[i + 1].1) / F::cast(2.0);

                // Keep the candidate with the lowest score seen so far.
                best = match best.take() {
                    None => Some((feature_idx, split_value, score)),
                    Some((_, _, best_score)) if score < best_score => {
                        Some((feature_idx, split_value, score))
                    }
                    x => x,
                };
            }
        }
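        // Worked example with Gini: a parent with class frequencies {A: 2.0, B: 2.0}
        // has impurity 1 - (0.5^2 + 0.5^2) = 0.5. A split into left {A: 2.0} and
        // right {B: 2.0} scores 0.5 * 0.0 + 0.5 * 0.0 = 0.0, so its impurity
        // decrease is 0.5, the best this parent can achieve.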

        // The impurity decrease of the best candidate is the parent's impurity
        // minus the weighted impurity of the two proposed children.
        let impurity_decrease = if let Some((_, _, best_score)) = best {
            let parent_score = match hyperparameters.split_quality() {
                SplitQuality::Gini => gini_impurity(&parent_class_freq),
                SplitQuality::Entropy => entropy(&parent_class_freq),
            };
            let parent_score = F::cast(parent_score);

            parent_score - F::cast(best_score)
        } else {
            // No valid split was found for any feature.
            F::zero()
        };

        if impurity_decrease < hyperparameters.min_impurity_decrease() {
            return Ok(Self::empty_leaf(prediction, depth));
        }

        // Safe to unwrap: `min_impurity_decrease` is strictly positive, so
        // reaching this point implies `best` is `Some`.
        let (best_feature_idx, best_split_value, _) = best.unwrap();

        // Partition this node's rows between the two children.
        let mut left_mask = RowMask::none(data.nsamples());
        let mut right_mask = RowMask::none(data.nsamples());

        for i in 0..data.nsamples() {
            if mask.mask[i] {
                if data.records()[(i, best_feature_idx)] <= best_split_value {
                    left_mask.mark(i);
                } else {
                    right_mask.mark(i);
                }
            }
        }

        // Recurse into each non-empty side, one level deeper.
        let left_child = if left_mask.nsamples > 0 {
            Some(Box::new(TreeNode::fit(
                data,
                &left_mask,
                hyperparameters,
                sorted_indices,
                depth + 1,
            )?))
        } else {
            None
        };

        let right_child = if right_mask.nsamples > 0 {
            Some(Box::new(TreeNode::fit(
                data,
                &right_mask,
                hyperparameters,
                sorted_indices,
                depth + 1,
            )?))
        } else {
            None
        };

        // A split that leaves either side empty is degenerate; keep a leaf instead.
        let leaf_node = left_child.is_none() || right_child.is_none();

        Ok(TreeNode {
            feature_idx: best_feature_idx,
            feature_name: sorted_indices[best_feature_idx].feature_name.to_owned(),
            split_value: best_split_value,
            impurity_decrease,
            left_child,
            right_child,
            leaf_node,
            prediction,
            depth,
        })
    }

    /// Recursively prunes redundant splits: when both children collapse to
    /// leaves with the same prediction, the split itself becomes a leaf.
    /// Returns this subtree's prediction if it collapsed, `None` otherwise.
    fn prune(&mut self) -> Option<L> {
        if self.is_leaf() {
            return Some(self.prediction.clone());
        }

        let left = self.left_child.as_mut().and_then(|x| x.prune());
        let right = self.right_child.as_mut().and_then(|x| x.prune());

        match (left, right) {
            (Some(x), Some(y)) => {
                if x == y {
                    self.prediction = x.clone();
                    self.right_child = None;
                    self.left_child = None;
                    self.leaf_node = true;

                    Some(x)
                } else {
                    None
                }
            }
            _ => None,
        }
    }
}
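// Pruning example: a split whose two children are leaves that both predict
// class `A` is collapsed into a single leaf predicting `A`; if they predict
// different classes, the split is kept.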

/// A fitted decision tree classifier.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct DecisionTree<F: Float, L: Label> {
    root_node: TreeNode<F, L>,
    num_features: usize,
}

impl<F: Float, L: Label + Default, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<L>>
    for DecisionTree<F, L>
{
    /// Predicts one label per row of `x`, writing the results into `y`.
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
        assert_eq!(
            x.nrows(),
            y.len(),
            "The number of data points must match the number of output targets."
        );

        for (row, target) in x.rows().into_iter().zip(y.iter_mut()) {
            *target = make_prediction(&row, &self.root_node);
        }
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
        Array1::default(x.nrows())
    }
}

impl<F: Float, L: Label + std::fmt::Debug, D, T> Fit<ArrayBase<D, Ix2>, T, Error>
    for DecisionTreeValidParams<F, L>
where
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    type Object = DecisionTree<F, L>;

    /// Fits a decision tree on the whole dataset, then prunes redundant splits.
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
        let x = dataset.records();
        let feature_names = dataset.feature_names();
        let all_idxs = RowMask::all(x.nrows());
        // Presort every feature column once; the recursive split search reuses them.
        let sorted_indices: Vec<_> = (0..(x.ncols()))
            .map(|feature_idx| {
                SortedIndex::of_array_column(x, feature_idx, &feature_names[feature_idx])
            })
            .collect();

        let mut root_node = TreeNode::fit(dataset, &all_idxs, self, &sorted_indices, 0)?;
        root_node.prune();

        Ok(DecisionTree {
            root_node,
            num_features: dataset.records().ncols(),
        })
    }
}
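// Minimal usage sketch, mirroring the calls exercised in the tests below
// (`records` and `targets` are placeholders for the caller's data):
//
//     let dataset = Dataset::new(records, targets);
//     let model = DecisionTree::params().max_depth(Some(5)).fit(&dataset)?;
//     let predictions = model.predict(dataset.records());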

impl<F: Float, L: Label> DecisionTree<F, L> {
    /// Returns an iterator over the tree's nodes, starting from the root.
    pub fn iter_nodes(&self) -> NodeIter<F, L> {
        let queue = vec![&self.root_node];

        NodeIter::new(queue)
    }

    /// Returns the indices of the features used by at least one split.
    pub fn features(&self) -> Vec<usize> {
        // A set deduplicates features that are split on more than once.
        let mut fitted_features = HashSet::new();

        for node in self.iter_nodes().filter(|node| !node.is_leaf()) {
            fitted_features.insert(node.feature_idx);
        }

        fitted_features.into_iter().collect::<Vec<_>>()
    }

    /// Returns, for each feature, the impurity decrease averaged over every
    /// split on that feature; features never split on contribute zero.
    pub fn mean_impurity_decrease(&self) -> Vec<F> {
        let mut impurity_decrease = vec![F::zero(); self.num_features];
        let mut num_nodes = vec![0; self.num_features];

        for node in self.iter_nodes().filter(|node| !node.leaf_node) {
            impurity_decrease[node.feature_idx] += node.impurity_decrease;
            num_nodes[node.feature_idx] += 1;
        }

        impurity_decrease
            .into_iter()
            .zip(num_nodes)
            .map(|(val, n)| if n == 0 { F::zero() } else { val / F::cast(n) })
            .collect()
    }

    /// Returns the mean impurity decrease of each feature, normalized to sum to one.
    pub fn relative_impurity_decrease(&self) -> Vec<F> {
        let mean_impurity_decrease = self.mean_impurity_decrease();
        let sum = mean_impurity_decrease.iter().cloned().sum();

        mean_impurity_decrease
            .into_iter()
            .map(|x| x / sum)
            .collect()
    }

    /// Returns the feature importance, defined here as the relative impurity decrease.
    pub fn feature_importance(&self) -> Vec<F> {
        self.relative_impurity_decrease()
    }
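    // Example (see `single_feature_random_noise_binary` in the tests below): a
    // tree whose every split uses feature 8 out of 10 has importance 1.0 at
    // index 8 and 0.0 everywhere else.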

    /// Returns the root node of the tree.
    pub fn root_node(&self) -> &TreeNode<F, L> {
        &self.root_node
    }

    /// Returns the depth of the deepest node; a lone root has depth 0.
    pub fn max_depth(&self) -> usize {
        self.iter_nodes()
            .fold(0, |max, node| usize::max(max, node.depth))
    }

    /// Returns the number of leaves in the tree.
    pub fn num_leaves(&self) -> usize {
        self.iter_nodes().filter(|node| node.is_leaf()).count()
    }

    /// Exports the tree to a `Tikz` structure for rendering with LaTeX.
    pub fn export_to_tikz(&self) -> Tikz<F, L> {
        Tikz::new(self)
    }
}

/// Classifies a single sample by walking from the root to a leaf, descending
/// into the left child whenever the sample's value for the split feature lies
/// below the node's threshold.
fn make_prediction<F: Float, L: Label>(
    x: &ArrayBase<impl Data<Elem = F>, Ix1>,
    node: &TreeNode<F, L>,
) -> L {
    if node.leaf_node {
        node.prediction.clone()
    } else if x[node.feature_idx] < node.split_value {
        make_prediction(x, node.left_child.as_ref().unwrap())
    } else {
        make_prediction(x, node.right_child.as_ref().unwrap())
    }
}

/// Returns the class with the highest weighted frequency. Panics if
/// `class_freq` is empty.
fn find_modal_class<L: Label>(class_freq: &HashMap<L, f32>) -> L {
    let val = class_freq
        .iter()
        .fold(None, |acc, (idx, freq)| match acc {
            None => Some((idx, freq)),
            Some((_best_idx, best_freq)) => {
                if best_freq > freq {
                    acc
                } else {
                    Some((idx, freq))
                }
            }
        })
        .unwrap()
        .0;

    (*val).clone()
}

/// Gini impurity of a set with the given weighted class frequencies:
/// `1 - sum_i(p_i^2)`, where `p_i` is the relative frequency of class `i`.
/// It is zero for a pure set and grows as the classes mix.
fn gini_impurity<L: Label>(class_freq: &HashMap<L, f32>) -> f32 {
    let n_samples = class_freq.values().sum::<f32>();
    assert!(n_samples > 0.0);

    let purity = class_freq
        .values()
        .map(|x| x / n_samples)
        .map(|x| x * x)
        .sum::<f32>();

    1.0 - purity
}
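// Worked example (matching `gini_impurity_example` below): frequencies
// {6, 2, 0} give p = (0.75, 0.25, 0.0), so the impurity is
// 1 - (0.75^2 + 0.25^2) = 1 - 0.625 = 0.375.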

/// Shannon entropy of a set with the given weighted class frequencies:
/// `-sum_i(p_i * log2(p_i))`, with the convention `0 * log2(0) = 0`.
fn entropy<L: Label>(class_freq: &HashMap<L, f32>) -> f32 {
    let n_samples = class_freq.values().sum::<f32>();
    assert!(n_samples > 0.0);

    class_freq
        .values()
        .map(|x| x / n_samples)
        .map(|x| if x > 0.0 { -x * x.log2() } else { 0.0 })
        .sum()
}
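// Worked example (matching `entropy_example` below): frequencies {6, 2, 0}
// give p = (0.75, 0.25), so the entropy is
// -(0.75 * log2(0.75) + 0.25 * log2(0.25)) = 0.31128 + 0.5 = 0.81128 (approx.).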

#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_abs_diff_eq;
    use linfa::{error::Result, metrics::ToConfusionMatrix, Dataset, ParamGuard};
    use ndarray::{array, concatenate, s, Array, Array1, Array2, Axis};
    use rand::rngs::SmallRng;

    use crate::DecisionTreeParams;
    use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<DecisionTree<f64, bool>>();
        has_autotraits::<TreeNode<f64, bool>>();
        has_autotraits::<DecisionTreeValidParams<f64, bool>>();
        has_autotraits::<DecisionTreeParams<f64, bool>>();
        has_autotraits::<NodeIter<f64, bool>>();
        has_autotraits::<Tikz<f64, bool>>();
    }

    #[test]
    fn prediction_for_rows_example() {
        let labels = Array::from(vec![0, 0, 0, 0, 0, 0, 1, 1]);
        let row_mask = RowMask::all(labels.len());

        let dataset: DatasetBase<(), Array1<usize>> = DatasetBase::new((), labels);
        let class_freq = dataset.label_frequencies_with_mask(&row_mask.mask);

        assert_eq!(find_modal_class(&class_freq), 0);
    }

    #[test]
    fn gini_impurity_example() {
        let class_freq = vec![(0, 6.0), (1, 2.0), (2, 0.0)].into_iter().collect();

        // p = (0.75, 0.25, 0.0); 1 - (0.75^2 + 0.25^2) = 0.375
        assert_abs_diff_eq!(gini_impurity(&class_freq), 0.375, epsilon = 1e-5);
    }

    #[test]
    fn entropy_example() {
        let class_freq = vec![(0, 6.0), (1, 2.0), (2, 0.0)].into_iter().collect();

        // -(0.75 * log2(0.75) + 0.25 * log2(0.25)) = 0.81127...
        assert_abs_diff_eq!(entropy(&class_freq), 0.81127, epsilon = 1e-5);

        // A set with a single class has zero entropy.
        let perfect_class_freq = vec![(0, 8.0), (1, 0.0), (2, 0.0)].into_iter().collect();

        assert_abs_diff_eq!(entropy(&perfect_class_freq), 0.0, epsilon = 1e-5);
    }

    #[test]
    fn single_feature_random_noise_binary() -> Result<()> {
        // 10 random features, but only feature 8 explains the target.
        let mut data = Array::random((50, 10), Uniform::new(-4., 4.));
        data.slice_mut(s![.., 8]).assign(
            &(0..50)
                .map(|x| if x < 25 { 0.0 } else { 1.0 })
                .collect::<Array1<_>>(),
        );

        let targets = (0..50).map(|x| x < 25).collect::<Array1<_>>();
        let dataset = Dataset::new(data, targets);

        let model = DecisionTree::params().max_depth(Some(2)).fit(&dataset)?;

        // The fitted tree should split only on the informative feature...
        assert_eq!(&model.features(), &[8]);

        let ground_truth = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];

        for (imp, truth) in model.feature_importance().iter().zip(&ground_truth) {
            assert_abs_diff_eq!(imp, truth, epsilon = 1e-15);
        }

        // ...and classify the training data perfectly.
        let cm = model
            .predict(dataset.records())
            .confusion_matrix(&dataset)?;
        assert_abs_diff_eq!(cm.accuracy(), 1.0, epsilon = 1e-15);

        Ok(())
    }

    #[test]
    fn check_max_depth() -> Result<()> {
        let mut rng = SmallRng::seed_from_u64(42);

        // With 50 distinct labels and permissive minimum weights, the tree
        // grows until the depth cap stops it.
        let data = Array::random_using((50, 50), Uniform::new(-1., 1.), &mut rng);
        let targets = (0..50).collect::<Array1<usize>>();

        let dataset = Dataset::new(data, targets);

        for max_depth in &[1, 5, 10, 20] {
            let model = DecisionTree::params()
                .max_depth(Some(*max_depth))
                .min_impurity_decrease(1e-10f64)
                .min_weight_split(1e-10)
                .fit(&dataset)?;
            assert_eq!(model.max_depth(), *max_depth);
        }

        Ok(())
    }

    #[test]
    fn perfectly_separable_small() -> Result<()> {
        let data = array![[1., 2., 3.], [1., 2., 4.], [1., 3., 3.5]];
        let targets = array![0, 0, 1];

        let dataset = Dataset::new(data.clone(), targets);
        let model = DecisionTree::params().max_depth(Some(1)).fit(&dataset)?;

        assert_eq!(model.predict(&data), array![0, 0, 1]);

        Ok(())
    }

    #[test]
    fn toy_dataset() -> Result<()> {
        let data = array![
            [0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 1.0, -14.0, 0.0, -4.0, 0.0, 0.0, 0.0, 0.0,],
            [0.0, 0.0, 5.0, 3.0, 0.0, -4.0, 0.0, 0.0, 1.0, -5.0, 0.2, 0.0, 4.0, 1.0,],
            [-1.0, -1.0, 0.0, 0.0, -4.5, 0.0, 0.0, 2.1, 1.0, 0.0, 0.0, -4.5, 0.0, 1.0,],
            [-1.0, -1.0, 0.0, -1.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 1.0,],
            [-1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,],
            [-1.0, -2.0, 0.0, 4.0, -3.0, 10.0, 4.0, 0.0, -3.2, 0.0, 4.0, 3.0, -4.0, 1.0,],
            [2.11, 0.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -3.0, 1.0,],
            [2.11, 0.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.0, 0.0, -2.0, 1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.0, 0.0, -2.0, 1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -1.0, 0.0,],
            [2.0, 8.0, 5.0, 1.0, 0.5, -4.0, 10.0, 0.0, 1.0, -5.0, 3.0, 0.0, 2.0, 0.0,],
            [2.0, 0.0, 1.0, 1.0, 1.0, -1.0, 1.0, 0.0, 0.0, -2.0, 3.0, 0.0, 1.0, 0.0,],
            [2.0, 0.0, 1.0, 2.0, 3.0, -1.0, 10.0, 2.0, 0.0, -1.0, 1.0, 2.0, 2.0, 0.0,],
            [1.0, 1.0, 0.0, 2.0, 2.0, -1.0, 1.0, 2.0, 0.0, -5.0, 1.0, 2.0, 3.0, 0.0,],
            [3.0, 1.0, 0.0, 3.0, 0.0, -4.0, 10.0, 0.0, 1.0, -5.0, 3.0, 0.0, 3.0, 1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 1.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -3.0, 1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 1.0, 0.0, 0.0, -3.2, 6.0, 1.5, 1.0, -1.0, -1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 10.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -1.0, -1.0,],
            [2.0, 0.0, 5.0, 1.0, 0.5, -2.0, 10.0, 0.0, 1.0, -5.0, 3.0, 1.0, 0.0, -1.0,],
            [2.0, 0.0, 1.0, 1.0, 1.0, -2.0, 1.0, 0.0, 0.0, -2.0, 0.0, 0.0, 0.0, 1.0,],
            [2.0, 1.0, 1.0, 1.0, 2.0, -1.0, 10.0, 2.0, 0.0, -1.0, 0.0, 2.0, 1.0, 1.0,],
            [1.0, 1.0, 0.0, 0.0, 1.0, -3.0, 1.0, 2.0, 0.0, -5.0, 1.0, 2.0, 1.0, 1.0,],
            [3.0, 1.0, 0.0, 1.0, 0.0, -4.0, 1.0, 0.0, 1.0, -2.0, 0.0, 0.0, 1.0, 0.0,]
        ];

        let targets = array![1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0];

        let dataset = Dataset::new(data, targets);
        let model = DecisionTree::params().fit(&dataset)?;
        let prediction = model.predict(&dataset);

        let cm = prediction.confusion_matrix(&dataset)?;
        assert!(cm.accuracy() > 0.95);

        Ok(())
    }

    #[test]
    fn multilabel_four_uniform() -> Result<()> {
        let mut data = concatenate(
            Axis(0),
            &[Array2::random((40, 2), Uniform::new(-1., 1.)).view()],
        )
        .unwrap();

        data.outer_iter_mut().enumerate().for_each(|(i, mut p)| {
            if i < 10 {
                p += &array![-2., -2.]
            } else if i < 20 {
                p += &array![-2., 2.];
            } else if i < 30 {
                p += &array![2., -2.];
            } else {
                p += &array![2., 2.];
            }
        });

        let targets = (0..40)
            .map(|x| match x {
                x if x < 10 => 0,
                x if x < 20 => 1,
                x if x < 30 => 2,
                _ => 3,
            })
            .collect::<Array1<_>>();

        let dataset = Dataset::new(data.clone(), targets);

        let model = DecisionTree::params().fit(&dataset)?;
        let prediction = model.predict(data);

        let cm = prediction.confusion_matrix(&dataset)?;
        assert!(cm.accuracy() > 0.99);

        Ok(())
    }

    #[test]
    #[should_panic]
    fn panic_min_impurity_decrease() {
        // `min_impurity_decrease` must be strictly positive, so checking a
        // value of 0.0 fails validation and the `unwrap` panics.
        DecisionTree::<f64, bool>::params()
            .min_impurity_decrease(0.0)
            .check()
            .unwrap();
    }
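
    // A small sanity check of the `RowMask` invariants relied on by
    // `TreeNode::fit`; it uses only items defined in this file.
    #[test]
    fn row_mask_example() {
        // `none` starts with no rows selected; `mark` adds one row at a time.
        let mut mask = RowMask::none(3);
        assert_eq!(mask.nsamples, 0);

        mask.mark(1);
        assert!(mask.mask[1]);
        assert_eq!(mask.nsamples, 1);

        // `all` selects every row up front.
        let full = RowMask::all(3);
        assert_eq!(full.nsamples, 3);
    }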
}