1use ferrolearn_core::error::FerroError;
26use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
27use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
28use ferrolearn_core::traits::{Fit, Predict};
29use ndarray::{Array1, Array2};
30use num_traits::{Float, FromPrimitive, ToPrimitive};
31use serde::{Deserialize, Serialize};
32
/// Impurity criterion used to score candidate splits when growing a
/// classification tree.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ClassificationCriterion {
    /// Gini impurity: `1 - sum_c(p_c^2)`.
    Gini,
    /// Shannon entropy, computed in nats (natural log — see `entropy_impurity`).
    Entropy,
}
45
/// Split criterion for regression trees.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RegressionCriterion {
    /// Mean squared error (variance reduction).
    Mse,
}
52
/// A node of a flattened decision tree.
///
/// Nodes live in a `Vec` arena and reference their children by index;
/// index 0 is always the root.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Node<F> {
    /// Internal node: samples with `x[feature] <= threshold` descend to
    /// `left`, the rest to `right`.
    Split {
        /// Column of `X` tested at this node.
        feature: usize,
        /// Decision boundary (midpoint of two adjacent sorted feature values).
        threshold: F,
        /// Arena index of the left child.
        left: usize,
        /// Arena index of the right child.
        right: usize,
        /// Impurity decrease achieved by this split, pre-scaled by the node's
        /// sample count (see `find_best_*_split`); consumed by
        /// `compute_feature_importances`.
        impurity_decrease: F,
        /// Number of training samples routed through this node.
        n_samples: usize,
    },
    /// Terminal node holding the prediction.
    Leaf {
        /// Prediction: mean target (regression) or majority class index
        /// encoded as a float (classification).
        value: F,
        /// Per-class probabilities — `Some` for classification leaves,
        /// `None` for regression leaves.
        class_distribution: Option<Vec<F>>,
        /// Number of training samples that reached this leaf.
        n_samples: usize,
    },
}
88
/// Stopping-rule parameters shared by the recursive tree builders.
#[derive(Debug, Clone, Copy)]
pub(crate) struct TreeParams {
    /// Maximum tree depth; `None` means unlimited.
    pub(crate) max_depth: Option<usize>,
    /// Minimum number of samples a node must hold to be considered for a split.
    pub(crate) min_samples_split: usize,
    /// Minimum number of samples each child of a split must receive.
    pub(crate) min_samples_leaf: usize,
}
100
/// Borrowed view of the training data threaded through the recursive
/// classification-tree builder.
struct ClassificationData<'a, F> {
    // Feature matrix (samples x features).
    x: &'a Array2<F>,
    // Labels already remapped to contiguous class indices `0..n_classes`.
    y: &'a [usize],
    n_classes: usize,
    // When `Some`, split search is restricted to these feature columns
    // (used by the random-forest feature-subset builders).
    feature_indices: Option<&'a [usize]>,
    criterion: ClassificationCriterion,
}
109
/// Borrowed view of the training data threaded through the recursive
/// regression-tree builder.
struct RegressionData<'a, F> {
    // Feature matrix (samples x features).
    x: &'a Array2<F>,
    // Continuous targets, one per row of `x`.
    y: &'a Array1<F>,
    // When `Some`, split search is restricted to these feature columns.
    feature_indices: Option<&'a [usize]>,
}
116
/// Unfitted CART-style decision tree classifier (hyperparameters only).
///
/// Call [`Fit::fit`] to train; the result is a [`FittedDecisionTreeClassifier`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecisionTreeClassifier<F> {
    /// Maximum tree depth; `None` means grow until pure or exhausted.
    pub max_depth: Option<usize>,
    /// Minimum samples required to attempt a split (must be >= 2).
    pub min_samples_split: usize,
    /// Minimum samples required in each leaf (must be >= 1).
    pub min_samples_leaf: usize,
    /// Impurity criterion used to score splits.
    pub criterion: ClassificationCriterion,
    // Ties the float type `F` to the struct without storing any `F`.
    _marker: std::marker::PhantomData<F>,
}
141
142impl<F: Float> DecisionTreeClassifier<F> {
143 #[must_use]
148 pub fn new() -> Self {
149 Self {
150 max_depth: None,
151 min_samples_split: 2,
152 min_samples_leaf: 1,
153 criterion: ClassificationCriterion::Gini,
154 _marker: std::marker::PhantomData,
155 }
156 }
157
158 #[must_use]
160 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
161 self.max_depth = max_depth;
162 self
163 }
164
165 #[must_use]
167 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
168 self.min_samples_split = min_samples_split;
169 self
170 }
171
172 #[must_use]
174 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
175 self.min_samples_leaf = min_samples_leaf;
176 self
177 }
178
179 #[must_use]
181 pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
182 self.criterion = criterion;
183 self
184 }
185}
186
impl<F: Float> Default for DecisionTreeClassifier<F> {
    /// Equivalent to [`DecisionTreeClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
192
/// A trained decision tree classifier produced by [`DecisionTreeClassifier`].
#[derive(Debug, Clone)]
pub struct FittedDecisionTreeClassifier<F> {
    // Flattened tree arena; index 0 is the root.
    nodes: Vec<Node<F>>,
    // Sorted, deduplicated original class labels; internal class index i
    // maps back to label `classes[i]`.
    classes: Vec<usize>,
    // Number of feature columns seen at fit time.
    n_features: usize,
    // Normalized impurity-decrease importances, one per feature.
    feature_importances: Array1<F>,
}
213
214impl<F: Float + Send + Sync + 'static> FittedDecisionTreeClassifier<F> {
215 #[must_use]
217 pub fn nodes(&self) -> &[Node<F>] {
218 &self.nodes
219 }
220
221 #[must_use]
223 pub fn n_features(&self) -> usize {
224 self.n_features
225 }
226
227 pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
236 if x.ncols() != self.n_features {
237 return Err(FerroError::ShapeMismatch {
238 expected: vec![self.n_features],
239 actual: vec![x.ncols()],
240 context: "number of features must match fitted model".into(),
241 });
242 }
243 let n_samples = x.nrows();
244 let n_classes = self.classes.len();
245 let mut proba = Array2::zeros((n_samples, n_classes));
246 for i in 0..n_samples {
247 let row = x.row(i);
248 let leaf = traverse_tree(&self.nodes, &row);
249 if let Node::Leaf {
250 class_distribution: Some(ref dist),
251 ..
252 } = self.nodes[leaf]
253 {
254 for (j, &p) in dist.iter().enumerate() {
255 proba[[i, j]] = p;
256 }
257 }
258 }
259 Ok(proba)
260 }
261}
262
263impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for DecisionTreeClassifier<F> {
264 type Fitted = FittedDecisionTreeClassifier<F>;
265 type Error = FerroError;
266
267 fn fit(
276 &self,
277 x: &Array2<F>,
278 y: &Array1<usize>,
279 ) -> Result<FittedDecisionTreeClassifier<F>, FerroError> {
280 let (n_samples, n_features) = x.dim();
281
282 if n_samples != y.len() {
283 return Err(FerroError::ShapeMismatch {
284 expected: vec![n_samples],
285 actual: vec![y.len()],
286 context: "y length must match number of samples in X".into(),
287 });
288 }
289 if n_samples == 0 {
290 return Err(FerroError::InsufficientSamples {
291 required: 1,
292 actual: 0,
293 context: "DecisionTreeClassifier requires at least one sample".into(),
294 });
295 }
296 if self.min_samples_split < 2 {
297 return Err(FerroError::InvalidParameter {
298 name: "min_samples_split".into(),
299 reason: "must be at least 2".into(),
300 });
301 }
302 if self.min_samples_leaf < 1 {
303 return Err(FerroError::InvalidParameter {
304 name: "min_samples_leaf".into(),
305 reason: "must be at least 1".into(),
306 });
307 }
308
309 let mut classes: Vec<usize> = y.iter().copied().collect();
311 classes.sort_unstable();
312 classes.dedup();
313 let n_classes = classes.len();
314
315 let y_mapped: Vec<usize> = y
317 .iter()
318 .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
319 .collect();
320
321 let indices: Vec<usize> = (0..n_samples).collect();
322
323 let data = ClassificationData {
324 x,
325 y: &y_mapped,
326 n_classes,
327 feature_indices: None,
328 criterion: self.criterion,
329 };
330 let params = TreeParams {
331 max_depth: self.max_depth,
332 min_samples_split: self.min_samples_split,
333 min_samples_leaf: self.min_samples_leaf,
334 };
335
336 let mut nodes: Vec<Node<F>> = Vec::new();
337 build_classification_tree(&data, &indices, &mut nodes, 0, ¶ms);
338
339 let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);
340
341 Ok(FittedDecisionTreeClassifier {
342 nodes,
343 classes,
344 n_features,
345 feature_importances,
346 })
347 }
348}
349
350impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedDecisionTreeClassifier<F> {
351 type Output = Array1<usize>;
352 type Error = FerroError;
353
354 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
361 if x.ncols() != self.n_features {
362 return Err(FerroError::ShapeMismatch {
363 expected: vec![self.n_features],
364 actual: vec![x.ncols()],
365 context: "number of features must match fitted model".into(),
366 });
367 }
368 let n_samples = x.nrows();
369 let mut predictions = Array1::zeros(n_samples);
370 for i in 0..n_samples {
371 let row = x.row(i);
372 let leaf = traverse_tree(&self.nodes, &row);
373 if let Node::Leaf { value, .. } = self.nodes[leaf] {
374 predictions[i] = float_to_usize(value);
375 }
376 }
377 Ok(predictions)
378 }
379}
380
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
    for FittedDecisionTreeClassifier<F>
{
    /// Normalized impurity-decrease importances, one entry per feature.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
388
impl<F: Float + Send + Sync + 'static> HasClasses for FittedDecisionTreeClassifier<F> {
    /// Sorted unique class labels seen during fitting.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
398
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for DecisionTreeClassifier<F>
{
    /// Adapts the classifier to the float-typed pipeline interface by
    /// converting targets to `usize` labels before fitting.
    ///
    /// NOTE(review): targets that fail `to_usize` (e.g. negative or NaN)
    /// silently become class 0 via `unwrap_or(0)` — confirm callers rely on
    /// this lenient behavior rather than expecting an error.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedClassifierPipelineAdapter(fitted)))
    }
}
413
/// Newtype adapter exposing a fitted classifier through the float-typed
/// pipeline interface (labels are converted back to `F` on predict).
struct FittedClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedDecisionTreeClassifier<F>,
);
418
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedClassifierPipelineAdapter<F>
{
    /// Predicts class labels and re-encodes them as floats for the pipeline.
    /// A label that cannot be represented in `F` becomes NaN.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
    }
}
427
/// Unfitted CART-style decision tree regressor (hyperparameters only).
///
/// Call [`Fit::fit`] to train; the result is a [`FittedDecisionTreeRegressor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecisionTreeRegressor<F> {
    /// Maximum tree depth; `None` means grow until stopping rules apply.
    pub max_depth: Option<usize>,
    /// Minimum samples required to attempt a split (must be >= 2).
    pub min_samples_split: usize,
    /// Minimum samples required in each leaf (must be >= 1).
    pub min_samples_leaf: usize,
    /// Split criterion (currently MSE only).
    pub criterion: RegressionCriterion,
    // Ties the float type `F` to the struct without storing any `F`.
    _marker: std::marker::PhantomData<F>,
}
452
453impl<F: Float> DecisionTreeRegressor<F> {
454 #[must_use]
459 pub fn new() -> Self {
460 Self {
461 max_depth: None,
462 min_samples_split: 2,
463 min_samples_leaf: 1,
464 criterion: RegressionCriterion::Mse,
465 _marker: std::marker::PhantomData,
466 }
467 }
468
469 #[must_use]
471 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
472 self.max_depth = max_depth;
473 self
474 }
475
476 #[must_use]
478 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
479 self.min_samples_split = min_samples_split;
480 self
481 }
482
483 #[must_use]
485 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
486 self.min_samples_leaf = min_samples_leaf;
487 self
488 }
489
490 #[must_use]
492 pub fn with_criterion(mut self, criterion: RegressionCriterion) -> Self {
493 self.criterion = criterion;
494 self
495 }
496}
497
impl<F: Float> Default for DecisionTreeRegressor<F> {
    /// Equivalent to [`DecisionTreeRegressor::new`].
    fn default() -> Self {
        Self::new()
    }
}
503
/// A trained decision tree regressor produced by [`DecisionTreeRegressor`].
#[derive(Debug, Clone)]
pub struct FittedDecisionTreeRegressor<F> {
    // Flattened tree arena; index 0 is the root.
    nodes: Vec<Node<F>>,
    // Number of feature columns seen at fit time.
    n_features: usize,
    // Normalized impurity-decrease importances, one per feature.
    feature_importances: Array1<F>,
}
520
521impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for DecisionTreeRegressor<F> {
522 type Fitted = FittedDecisionTreeRegressor<F>;
523 type Error = FerroError;
524
525 fn fit(
534 &self,
535 x: &Array2<F>,
536 y: &Array1<F>,
537 ) -> Result<FittedDecisionTreeRegressor<F>, FerroError> {
538 let (n_samples, n_features) = x.dim();
539
540 if n_samples != y.len() {
541 return Err(FerroError::ShapeMismatch {
542 expected: vec![n_samples],
543 actual: vec![y.len()],
544 context: "y length must match number of samples in X".into(),
545 });
546 }
547 if n_samples == 0 {
548 return Err(FerroError::InsufficientSamples {
549 required: 1,
550 actual: 0,
551 context: "DecisionTreeRegressor requires at least one sample".into(),
552 });
553 }
554 if self.min_samples_split < 2 {
555 return Err(FerroError::InvalidParameter {
556 name: "min_samples_split".into(),
557 reason: "must be at least 2".into(),
558 });
559 }
560 if self.min_samples_leaf < 1 {
561 return Err(FerroError::InvalidParameter {
562 name: "min_samples_leaf".into(),
563 reason: "must be at least 1".into(),
564 });
565 }
566
567 let indices: Vec<usize> = (0..n_samples).collect();
568
569 let data = RegressionData {
570 x,
571 y,
572 feature_indices: None,
573 };
574 let params = TreeParams {
575 max_depth: self.max_depth,
576 min_samples_split: self.min_samples_split,
577 min_samples_leaf: self.min_samples_leaf,
578 };
579
580 let mut nodes: Vec<Node<F>> = Vec::new();
581 build_regression_tree(&data, &indices, &mut nodes, 0, ¶ms);
582
583 let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);
584
585 Ok(FittedDecisionTreeRegressor {
586 nodes,
587 n_features,
588 feature_importances,
589 })
590 }
591}
592
impl<F: Float + Send + Sync + 'static> FittedDecisionTreeRegressor<F> {
    /// Returns the flattened tree; index 0 is the root.
    #[must_use]
    pub fn nodes(&self) -> &[Node<F>] {
        &self.nodes
    }

    /// Returns the number of features the model was fitted on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }
}
606
607impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedDecisionTreeRegressor<F> {
608 type Output = Array1<F>;
609 type Error = FerroError;
610
611 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
618 if x.ncols() != self.n_features {
619 return Err(FerroError::ShapeMismatch {
620 expected: vec![self.n_features],
621 actual: vec![x.ncols()],
622 context: "number of features must match fitted model".into(),
623 });
624 }
625 let n_samples = x.nrows();
626 let mut predictions = Array1::zeros(n_samples);
627 for i in 0..n_samples {
628 let row = x.row(i);
629 let leaf = traverse_tree(&self.nodes, &row);
630 if let Node::Leaf { value, .. } = self.nodes[leaf] {
631 predictions[i] = value;
632 }
633 }
634 Ok(predictions)
635 }
636}
637
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedDecisionTreeRegressor<F> {
    /// Normalized impurity-decrease importances, one entry per feature.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
643
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for DecisionTreeRegressor<F> {
    /// Fits the regressor; no target conversion is needed since both the
    /// pipeline and the regressor use float targets.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}
655
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedDecisionTreeRegressor<F>
{
    /// Delegates directly to [`Predict::predict`]; outputs are already floats.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
663
664fn traverse_tree<F: Float>(nodes: &[Node<F>], sample: &ndarray::ArrayView1<F>) -> usize {
670 let mut idx = 0;
671 loop {
672 match &nodes[idx] {
673 Node::Split {
674 feature,
675 threshold,
676 left,
677 right,
678 ..
679 } => {
680 if sample[*feature] <= *threshold {
681 idx = *left;
682 } else {
683 idx = *right;
684 }
685 }
686 Node::Leaf { .. } => return idx,
687 }
688 }
689}
690
/// Crate-visible wrapper around [`traverse_tree`] so sibling modules
/// (e.g. ensemble code) can walk a tree without the private function.
pub(crate) fn traverse<F: Float>(nodes: &[Node<F>], sample: &ndarray::ArrayView1<F>) -> usize {
    traverse_tree(nodes, sample)
}
697
/// Decodes a leaf's float-encoded class index back to `usize` by rounding.
/// Falls back to 0 when the value has no f64 representation; negative or
/// NaN values also end up as 0 via Rust's saturating float-to-int cast.
fn float_to_usize<F: Float>(v: F) -> usize {
    v.to_f64().map_or(0, |f| f.round() as usize)
}
702
703fn gini_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
705 if total == 0 {
706 return F::zero();
707 }
708 let total_f = F::from(total).unwrap();
709 let mut impurity = F::one();
710 for &count in class_counts {
711 let p = F::from(count).unwrap() / total_f;
712 impurity = impurity - p * p;
713 }
714 impurity
715}
716
717fn entropy_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
719 if total == 0 {
720 return F::zero();
721 }
722 let total_f = F::from(total).unwrap();
723 let mut ent = F::zero();
724 for &count in class_counts {
725 if count > 0 {
726 let p = F::from(count).unwrap() / total_f;
727 ent = ent - p * p.ln();
728 }
729 }
730 ent
731}
732
733fn mean_value<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
735 if indices.is_empty() {
736 return F::zero();
737 }
738 let sum: F = indices.iter().map(|&i| y[i]).fold(F::zero(), |a, b| a + b);
739 sum / F::from(indices.len()).unwrap()
740}
741
742fn mse_for_indices<F: Float>(y: &Array1<F>, indices: &[usize], mean: F) -> F {
744 if indices.is_empty() {
745 return F::zero();
746 }
747 let sum_sq: F = indices
748 .iter()
749 .map(|&i| {
750 let diff = y[i] - mean;
751 diff * diff
752 })
753 .fold(F::zero(), |a, b| a + b);
754 sum_sq / F::from(indices.len()).unwrap()
755}
756
/// Dispatches to the impurity function selected by `criterion`.
fn compute_impurity<F: Float>(
    class_counts: &[usize],
    total: usize,
    criterion: ClassificationCriterion,
) -> F {
    match criterion {
        ClassificationCriterion::Gini => gini_impurity(class_counts, total),
        ClassificationCriterion::Entropy => entropy_impurity(class_counts, total),
    }
}
768
769fn make_classification_leaf<F: Float>(
771 nodes: &mut Vec<Node<F>>,
772 class_counts: &[usize],
773 n_classes: usize,
774 n_samples: usize,
775) -> usize {
776 let majority_class = class_counts
777 .iter()
778 .enumerate()
779 .max_by_key(|&(_, &count)| count)
780 .map_or(0, |(i, _)| i);
781
782 let total_f = if n_samples > 0 {
783 F::from(n_samples).unwrap()
784 } else {
785 F::one()
786 };
787 let distribution: Vec<F> = (0..n_classes)
788 .map(|c| F::from(class_counts[c]).unwrap() / total_f)
789 .collect();
790
791 let idx = nodes.len();
792 nodes.push(Node::Leaf {
793 value: F::from(majority_class).unwrap(),
794 class_distribution: Some(distribution),
795 n_samples,
796 });
797 idx
798}
799
800fn build_classification_tree<F: Float>(
804 data: &ClassificationData<'_, F>,
805 indices: &[usize],
806 nodes: &mut Vec<Node<F>>,
807 depth: usize,
808 params: &TreeParams,
809) -> usize {
810 let n = indices.len();
811
812 let mut class_counts = vec![0usize; data.n_classes];
813 for &i in indices {
814 class_counts[data.y[i]] += 1;
815 }
816
817 let should_stop = n < params.min_samples_split
818 || params.max_depth.is_some_and(|d| depth >= d)
819 || class_counts.iter().filter(|&&c| c > 0).count() <= 1;
820
821 if should_stop {
822 return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
823 }
824
825 let best = find_best_classification_split(data, indices, params.min_samples_leaf);
826
827 if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
828 let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
829 .iter()
830 .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);
831
832 let node_idx = nodes.len();
833 nodes.push(Node::Leaf {
834 value: F::zero(),
835 class_distribution: None,
836 n_samples: 0,
837 }); let left_idx = build_classification_tree(data, &left_indices, nodes, depth + 1, params);
840 let right_idx = build_classification_tree(data, &right_indices, nodes, depth + 1, params);
841
842 nodes[node_idx] = Node::Split {
843 feature: best_feature,
844 threshold: best_threshold,
845 left: left_idx,
846 right: right_idx,
847 impurity_decrease: best_impurity_decrease,
848 n_samples: n,
849 };
850
851 node_idx
852 } else {
853 make_classification_leaf(nodes, &class_counts, data.n_classes, n)
854 }
855}
856
/// Scans every candidate (feature, threshold) pair and returns the split
/// with the largest impurity decrease, or `None` when no split strictly
/// improves on the parent node.
///
/// Returns `(feature, threshold, impurity_decrease * n)` — the decrease is
/// re-scaled by the node's sample count so feature importances are weighted
/// by node size.
fn find_best_classification_split<F: Float>(
    data: &ClassificationData<'_, F>,
    indices: &[usize],
    min_samples_leaf: usize,
) -> Option<(usize, F, F)> {
    let n = indices.len();
    let n_f = F::from(n).unwrap();
    let n_features = data.x.ncols();

    // Class histogram over every sample in this node.
    let mut parent_counts = vec![0usize; data.n_classes];
    for &i in indices {
        parent_counts[data.y[i]] += 1;
    }
    let parent_impurity = compute_impurity::<F>(&parent_counts, n, data.criterion);

    let mut best_score = F::neg_infinity();
    let mut best_feature = 0;
    let mut best_threshold = F::zero();

    // Random-forest callers restrict the search to a feature subset.
    let feature_iter: Box<dyn Iterator<Item = usize>> =
        if let Some(feat_indices) = data.feature_indices {
            Box::new(feat_indices.iter().copied())
        } else {
            Box::new(0..n_features)
        };

    for feat in feature_iter {
        // Sort the node's samples by this feature so thresholds can be swept
        // left-to-right with O(1) incremental histogram updates.
        // NOTE(review): `partial_cmp(..).unwrap()` panics if the feature
        // column contains NaN — confirm inputs are validated upstream.
        let mut sorted_indices: Vec<usize> = indices.to_vec();
        sorted_indices.sort_by(|&a, &b| data.x[[a, feat]].partial_cmp(&data.x[[b, feat]]).unwrap());

        let mut left_counts = vec![0usize; data.n_classes];
        let mut right_counts = parent_counts.clone();
        let mut left_n = 0usize;

        for split_pos in 0..n - 1 {
            // Move one sample from the right partition to the left.
            let idx = sorted_indices[split_pos];
            let cls = data.y[idx];
            left_counts[cls] += 1;
            right_counts[cls] -= 1;
            left_n += 1;
            let right_n = n - left_n;

            // A threshold only exists between two distinct feature values.
            let next_idx = sorted_indices[split_pos + 1];
            if data.x[[idx, feat]] == data.x[[next_idx, feat]] {
                continue;
            }

            // Enforce the minimum leaf size on both children.
            if left_n < min_samples_leaf || right_n < min_samples_leaf {
                continue;
            }

            // Decrease = parent impurity - sample-weighted child impurity.
            let left_impurity = compute_impurity::<F>(&left_counts, left_n, data.criterion);
            let right_impurity = compute_impurity::<F>(&right_counts, right_n, data.criterion);
            let left_weight = F::from(left_n).unwrap() / n_f;
            let right_weight = F::from(right_n).unwrap() / n_f;
            let weighted_child_impurity =
                left_weight * left_impurity + right_weight * right_impurity;
            let impurity_decrease = parent_impurity - weighted_child_impurity;

            if impurity_decrease > best_score {
                best_score = impurity_decrease;
                best_feature = feat;
                // Split at the midpoint of the two adjacent sorted values.
                best_threshold =
                    (data.x[[idx, feat]] + data.x[[next_idx, feat]]) / F::from(2.0).unwrap();
            }
        }
    }

    // Require a strictly positive decrease; otherwise the caller makes a leaf.
    if best_score > F::zero() {
        Some((best_feature, best_threshold, best_score * n_f))
    } else {
        None
    }
}
935
936fn build_regression_tree<F: Float>(
938 data: &RegressionData<'_, F>,
939 indices: &[usize],
940 nodes: &mut Vec<Node<F>>,
941 depth: usize,
942 params: &TreeParams,
943) -> usize {
944 let n = indices.len();
945 let mean = mean_value(data.y, indices);
946
947 let should_stop = n < params.min_samples_split || params.max_depth.is_some_and(|d| depth >= d);
948
949 if should_stop {
950 let idx = nodes.len();
951 nodes.push(Node::Leaf {
952 value: mean,
953 class_distribution: None,
954 n_samples: n,
955 });
956 return idx;
957 }
958
959 let parent_mse = mse_for_indices(data.y, indices, mean);
960 if parent_mse <= F::epsilon() {
961 let idx = nodes.len();
962 nodes.push(Node::Leaf {
963 value: mean,
964 class_distribution: None,
965 n_samples: n,
966 });
967 return idx;
968 }
969
970 let best = find_best_regression_split(data, indices, params.min_samples_leaf);
971
972 if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
973 let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
974 .iter()
975 .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);
976
977 let node_idx = nodes.len();
978 nodes.push(Node::Leaf {
979 value: F::zero(),
980 class_distribution: None,
981 n_samples: 0,
982 }); let left_idx = build_regression_tree(data, &left_indices, nodes, depth + 1, params);
985 let right_idx = build_regression_tree(data, &right_indices, nodes, depth + 1, params);
986
987 nodes[node_idx] = Node::Split {
988 feature: best_feature,
989 threshold: best_threshold,
990 left: left_idx,
991 right: right_idx,
992 impurity_decrease: best_impurity_decrease,
993 n_samples: n,
994 };
995
996 node_idx
997 } else {
998 let idx = nodes.len();
999 nodes.push(Node::Leaf {
1000 value: mean,
1001 class_distribution: None,
1002 n_samples: n,
1003 });
1004 idx
1005 }
1006}
1007
/// Scans every candidate (feature, threshold) pair and returns the split
/// with the largest MSE reduction, or `None` when no split strictly
/// improves on the parent node.
///
/// Returns `(feature, threshold, mse_decrease * n)` — the decrease is
/// re-scaled by the node's sample count so feature importances are weighted
/// by node size. Variances are computed incrementally as E[y^2] - E[y]^2
/// from running sums. NOTE(review): that form can suffer catastrophic
/// cancellation for targets with large mean and small variance — confirm
/// the expected target ranges make this acceptable.
fn find_best_regression_split<F: Float>(
    data: &RegressionData<'_, F>,
    indices: &[usize],
    min_samples_leaf: usize,
) -> Option<(usize, F, F)> {
    let n = indices.len();
    let n_f = F::from(n).unwrap();
    let n_features = data.x.ncols();

    // Totals over the whole node; child statistics are derived from these.
    let parent_sum: F = indices
        .iter()
        .map(|&i| data.y[i])
        .fold(F::zero(), |a, b| a + b);
    let parent_sum_sq: F = indices
        .iter()
        .map(|&i| data.y[i] * data.y[i])
        .fold(F::zero(), |a, b| a + b);
    let parent_mse = parent_sum_sq / n_f - (parent_sum / n_f) * (parent_sum / n_f);

    let mut best_score = F::neg_infinity();
    let mut best_feature = 0;
    let mut best_threshold = F::zero();

    // Random-forest callers restrict the search to a feature subset.
    let feature_iter: Box<dyn Iterator<Item = usize>> =
        if let Some(feat_indices) = data.feature_indices {
            Box::new(feat_indices.iter().copied())
        } else {
            Box::new(0..n_features)
        };

    for feat in feature_iter {
        // Sort the node's samples by this feature so thresholds can be swept
        // left-to-right with O(1) incremental sum updates.
        // NOTE(review): `partial_cmp(..).unwrap()` panics if the feature
        // column contains NaN — confirm inputs are validated upstream.
        let mut sorted_indices: Vec<usize> = indices.to_vec();
        sorted_indices.sort_by(|&a, &b| data.x[[a, feat]].partial_cmp(&data.x[[b, feat]]).unwrap());

        let mut left_sum = F::zero();
        let mut left_sum_sq = F::zero();
        let mut left_n: usize = 0;

        for split_pos in 0..n - 1 {
            // Move one sample from the right partition to the left.
            let idx = sorted_indices[split_pos];
            let val = data.y[idx];
            left_sum = left_sum + val;
            left_sum_sq = left_sum_sq + val * val;
            left_n += 1;
            let right_n = n - left_n;

            // A threshold only exists between two distinct feature values.
            let next_idx = sorted_indices[split_pos + 1];
            if data.x[[idx, feat]] == data.x[[next_idx, feat]] {
                continue;
            }

            // Enforce the minimum leaf size on both children.
            if left_n < min_samples_leaf || right_n < min_samples_leaf {
                continue;
            }

            let left_n_f = F::from(left_n).unwrap();
            let right_n_f = F::from(right_n).unwrap();

            let left_mean = left_sum / left_n_f;
            let left_mse = left_sum_sq / left_n_f - left_mean * left_mean;

            // Right-side statistics fall out of the parent totals.
            let right_sum = parent_sum - left_sum;
            let right_sum_sq = parent_sum_sq - left_sum_sq;
            let right_mean = right_sum / right_n_f;
            let right_mse = right_sum_sq / right_n_f - right_mean * right_mean;

            let weighted_child_mse = (left_n_f * left_mse + right_n_f * right_mse) / n_f;
            let mse_decrease = parent_mse - weighted_child_mse;

            if mse_decrease > best_score {
                best_score = mse_decrease;
                best_feature = feat;
                // Split at the midpoint of the two adjacent sorted values.
                best_threshold =
                    (data.x[[idx, feat]] + data.x[[next_idx, feat]]) / F::from(2.0).unwrap();
            }
        }
    }

    // Require a strictly positive decrease; otherwise the caller makes a leaf.
    if best_score > F::zero() {
        Some((best_feature, best_threshold, best_score * n_f))
    } else {
        None
    }
}
1095
1096pub(crate) fn compute_feature_importances<F: Float>(
1098 nodes: &[Node<F>],
1099 n_features: usize,
1100 _total_samples: usize,
1101) -> Array1<F> {
1102 let mut importances = Array1::zeros(n_features);
1103 for node in nodes {
1104 if let Node::Split {
1105 feature,
1106 impurity_decrease,
1107 ..
1108 } = node
1109 {
1110 importances[*feature] = importances[*feature] + *impurity_decrease;
1111 }
1112 }
1113 let total: F = importances.iter().copied().fold(F::zero(), |a, b| a + b);
1114 if total > F::zero() {
1115 importances.mapv_inplace(|v| v / total);
1116 }
1117 importances
1118}
1119
/// Grows a classification tree restricted to a feature subset — the entry
/// point used by random-forest style ensembles (`y` must already be mapped
/// to contiguous class indices, `indices` selects the bootstrap sample).
#[allow(clippy::too_many_arguments)]
pub(crate) fn build_classification_tree_with_feature_subset<F: Float>(
    x: &Array2<F>,
    y: &[usize],
    n_classes: usize,
    indices: &[usize],
    feature_indices: &[usize],
    params: &TreeParams,
    criterion: ClassificationCriterion,
) -> Vec<Node<F>> {
    let data = ClassificationData {
        x,
        y,
        n_classes,
        feature_indices: Some(feature_indices),
        criterion,
    };
    let mut nodes = Vec::new();
    build_classification_tree(&data, indices, &mut nodes, 0, params);
    nodes
}
1148
/// Grows a regression tree restricted to a feature subset — the entry point
/// used by random-forest style ensembles (`indices` selects the bootstrap
/// sample).
pub(crate) fn build_regression_tree_with_feature_subset<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    indices: &[usize],
    feature_indices: &[usize],
    params: &TreeParams,
) -> Vec<Node<F>> {
    let data = RegressionData {
        x,
        y,
        feature_indices: Some(feature_indices),
    };
    let mut nodes = Vec::new();
    build_regression_tree(&data, indices, &mut nodes, 0, params);
    nodes
}
1166
1167#[cfg(test)]
1172mod tests {
1173 use super::*;
1174 use approx::assert_relative_eq;
1175 use ndarray::array;
1176
1177 #[test]
1180 fn test_classifier_simple_binary() {
1181 let x = Array2::from_shape_vec(
1182 (6, 2),
1183 vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
1184 )
1185 .unwrap();
1186 let y = array![0, 0, 0, 1, 1, 1];
1187
1188 let model = DecisionTreeClassifier::<f64>::new();
1189 let fitted = model.fit(&x, &y).unwrap();
1190 let preds = fitted.predict(&x).unwrap();
1191
1192 assert_eq!(preds.len(), 6);
1193 for i in 0..3 {
1194 assert_eq!(preds[i], 0);
1195 }
1196 for i in 3..6 {
1197 assert_eq!(preds[i], 1);
1198 }
1199 }
1200
1201 #[test]
1202 fn test_classifier_single_class() {
1203 let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
1204 let y = array![0, 0, 0];
1205
1206 let model = DecisionTreeClassifier::<f64>::new();
1207 let fitted = model.fit(&x, &y).unwrap();
1208 let preds = fitted.predict(&x).unwrap();
1209
1210 assert_eq!(preds, array![0, 0, 0]);
1211 }
1212
1213 #[test]
1214 fn test_classifier_max_depth_1() {
1215 let x =
1216 Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
1217 let y = array![0, 0, 0, 0, 1, 1, 1, 1];
1218
1219 let model = DecisionTreeClassifier::<f64>::new().with_max_depth(Some(1));
1220 let fitted = model.fit(&x, &y).unwrap();
1221 let preds = fitted.predict(&x).unwrap();
1222
1223 for i in 0..4 {
1224 assert_eq!(preds[i], 0);
1225 }
1226 for i in 4..8 {
1227 assert_eq!(preds[i], 1);
1228 }
1229 }
1230
1231 #[test]
1232 fn test_classifier_min_samples_split() {
1233 let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1234 let y = array![0, 0, 0, 1, 1, 1];
1235
1236 let model = DecisionTreeClassifier::<f64>::new().with_min_samples_split(7);
1237 let fitted = model.fit(&x, &y).unwrap();
1238 let preds = fitted.predict(&x).unwrap();
1239
1240 let majority = preds[0];
1241 for &p in &preds {
1242 assert_eq!(p, majority);
1243 }
1244 }
1245
1246 #[test]
1247 fn test_classifier_min_samples_leaf() {
1248 let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1249 let y = array![0, 0, 0, 1, 1, 1];
1250
1251 let model = DecisionTreeClassifier::<f64>::new().with_min_samples_leaf(4);
1252 let fitted = model.fit(&x, &y).unwrap();
1253 let preds = fitted.predict(&x).unwrap();
1254
1255 let majority = preds[0];
1256 for &p in &preds {
1257 assert_eq!(p, majority);
1258 }
1259 }
1260
1261 #[test]
1262 fn test_classifier_gini_vs_entropy() {
1263 let x = Array2::from_shape_vec(
1264 (8, 2),
1265 vec![
1266 1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, 5.0, 5.0, 5.0, 6.0, 6.0, 5.0, 6.0, 6.0,
1267 ],
1268 )
1269 .unwrap();
1270 let y = array![0, 0, 0, 0, 1, 1, 1, 1];
1271
1272 let gini_model =
1273 DecisionTreeClassifier::<f64>::new().with_criterion(ClassificationCriterion::Gini);
1274 let entropy_model =
1275 DecisionTreeClassifier::<f64>::new().with_criterion(ClassificationCriterion::Entropy);
1276
1277 let fitted_gini = gini_model.fit(&x, &y).unwrap();
1278 let fitted_entropy = entropy_model.fit(&x, &y).unwrap();
1279
1280 let preds_gini = fitted_gini.predict(&x).unwrap();
1281 let preds_entropy = fitted_entropy.predict(&x).unwrap();
1282
1283 assert_eq!(preds_gini, y);
1284 assert_eq!(preds_entropy, y);
1285 }
1286
1287 #[test]
1288 fn test_classifier_predict_proba() {
1289 let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1290 let y = array![0, 0, 0, 1, 1, 1];
1291
1292 let model = DecisionTreeClassifier::<f64>::new();
1293 let fitted = model.fit(&x, &y).unwrap();
1294 let proba = fitted.predict_proba(&x).unwrap();
1295
1296 assert_eq!(proba.dim(), (6, 2));
1297 for i in 0..6 {
1298 let row_sum: f64 = proba.row(i).iter().sum();
1299 assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
1300 }
1301 }
1302
1303 #[test]
1304 fn test_classifier_shape_mismatch_fit() {
1305 let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1306 let y = array![0, 1];
1307
1308 let model = DecisionTreeClassifier::<f64>::new();
1309 assert!(model.fit(&x, &y).is_err());
1310 }
1311
1312 #[test]
1313 fn test_classifier_shape_mismatch_predict() {
1314 let x =
1315 Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
1316 let y = array![0, 0, 1, 1];
1317
1318 let model = DecisionTreeClassifier::<f64>::new();
1319 let fitted = model.fit(&x, &y).unwrap();
1320
1321 let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1322 assert!(fitted.predict(&x_bad).is_err());
1323 }
1324
1325 #[test]
1326 fn test_classifier_empty_data() {
1327 let x = Array2::<f64>::zeros((0, 2));
1328 let y = Array1::<usize>::zeros(0);
1329
1330 let model = DecisionTreeClassifier::<f64>::new();
1331 assert!(model.fit(&x, &y).is_err());
1332 }
1333
1334 #[test]
1335 fn test_classifier_feature_importances() {
1336 let x = Array2::from_shape_vec(
1337 (8, 2),
1338 vec![
1339 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
1340 ],
1341 )
1342 .unwrap();
1343 let y = array![0, 0, 0, 0, 1, 1, 1, 1];
1344
1345 let model = DecisionTreeClassifier::<f64>::new();
1346 let fitted = model.fit(&x, &y).unwrap();
1347 let importances = fitted.feature_importances();
1348
1349 assert_eq!(importances.len(), 2);
1350 assert!(importances[0] > 0.0);
1351 let sum: f64 = importances.iter().sum();
1352 assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
1353 }
1354
1355 #[test]
1356 fn test_classifier_has_classes() {
1357 let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1358 let y = array![0, 1, 2, 0, 1, 2];
1359
1360 let model = DecisionTreeClassifier::<f64>::new();
1361 let fitted = model.fit(&x, &y).unwrap();
1362
1363 assert_eq!(fitted.classes(), &[0, 1, 2]);
1364 assert_eq!(fitted.n_classes(), 3);
1365 }
1366
1367 #[test]
1368 fn test_classifier_invalid_min_samples_split() {
1369 let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1370 let y = array![0, 0, 1, 1];
1371
1372 let model = DecisionTreeClassifier::<f64>::new().with_min_samples_split(1);
1373 assert!(model.fit(&x, &y).is_err());
1374 }
1375
1376 #[test]
1377 fn test_classifier_invalid_min_samples_leaf() {
1378 let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1379 let y = array![0, 0, 1, 1];
1380
1381 let model = DecisionTreeClassifier::<f64>::new().with_min_samples_leaf(0);
1382 assert!(model.fit(&x, &y).is_err());
1383 }
1384
1385 #[test]
1386 fn test_classifier_multiclass() {
1387 let x = Array2::from_shape_vec((9, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
1388 .unwrap();
1389 let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
1390
1391 let model = DecisionTreeClassifier::<f64>::new();
1392 let fitted = model.fit(&x, &y).unwrap();
1393 let preds = fitted.predict(&x).unwrap();
1394
1395 assert_eq!(preds, y);
1396 }
1397
1398 #[test]
1399 fn test_classifier_pipeline_integration() {
1400 let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1401 let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);
1402
1403 let model = DecisionTreeClassifier::<f64>::new();
1404 let fitted = model.fit_pipeline(&x, &y).unwrap();
1405 let preds = fitted.predict_pipeline(&x).unwrap();
1406 assert_eq!(preds.len(), 6);
1407 }
1408
1409 #[test]
1412 fn test_regressor_simple() {
1413 let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
1414 let y = array![1.0, 2.0, 3.0, 4.0, 5.0];
1415
1416 let model = DecisionTreeRegressor::<f64>::new();
1417 let fitted = model.fit(&x, &y).unwrap();
1418 let preds = fitted.predict(&x).unwrap();
1419
1420 for (p, &actual) in preds.iter().zip(y.iter()) {
1421 assert_relative_eq!(*p, actual, epsilon = 1e-10);
1422 }
1423 }
1424
1425 #[test]
1426 fn test_regressor_max_depth() {
1427 let x =
1428 Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
1429 let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
1430
1431 let model = DecisionTreeRegressor::<f64>::new().with_max_depth(Some(1));
1432 let fitted = model.fit(&x, &y).unwrap();
1433 let preds = fitted.predict(&x).unwrap();
1434
1435 for i in 0..4 {
1436 assert_relative_eq!(preds[i], 1.0, epsilon = 1e-10);
1437 }
1438 for i in 4..8 {
1439 assert_relative_eq!(preds[i], 5.0, epsilon = 1e-10);
1440 }
1441 }
1442
1443 #[test]
1444 fn test_regressor_constant_target() {
1445 let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1446 let y = array![3.0, 3.0, 3.0, 3.0];
1447
1448 let model = DecisionTreeRegressor::<f64>::new();
1449 let fitted = model.fit(&x, &y).unwrap();
1450 let preds = fitted.predict(&x).unwrap();
1451
1452 for &p in &preds {
1453 assert_relative_eq!(p, 3.0, epsilon = 1e-10);
1454 }
1455 }
1456
1457 #[test]
1458 fn test_regressor_shape_mismatch_fit() {
1459 let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
1460 let y = array![1.0, 2.0];
1461
1462 let model = DecisionTreeRegressor::<f64>::new();
1463 assert!(model.fit(&x, &y).is_err());
1464 }
1465
1466 #[test]
1467 fn test_regressor_shape_mismatch_predict() {
1468 let x =
1469 Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
1470 let y = array![1.0, 2.0, 3.0, 4.0];
1471
1472 let model = DecisionTreeRegressor::<f64>::new();
1473 let fitted = model.fit(&x, &y).unwrap();
1474
1475 let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1476 assert!(fitted.predict(&x_bad).is_err());
1477 }
1478
1479 #[test]
1480 fn test_regressor_empty_data() {
1481 let x = Array2::<f64>::zeros((0, 2));
1482 let y = Array1::<f64>::zeros(0);
1483
1484 let model = DecisionTreeRegressor::<f64>::new();
1485 assert!(model.fit(&x, &y).is_err());
1486 }
1487
1488 #[test]
1489 fn test_regressor_feature_importances() {
1490 let x = Array2::from_shape_vec(
1491 (8, 2),
1492 vec![
1493 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
1494 ],
1495 )
1496 .unwrap();
1497 let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
1498
1499 let model = DecisionTreeRegressor::<f64>::new();
1500 let fitted = model.fit(&x, &y).unwrap();
1501 let importances = fitted.feature_importances();
1502
1503 assert_eq!(importances.len(), 2);
1504 assert!(importances[0] > 0.0);
1505 let sum: f64 = importances.iter().sum();
1506 assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
1507 }
1508
1509 #[test]
1510 fn test_regressor_min_samples_split() {
1511 let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1512 let y = array![1.0, 2.0, 3.0, 4.0];
1513
1514 let model = DecisionTreeRegressor::<f64>::new().with_min_samples_split(5);
1515 let fitted = model.fit(&x, &y).unwrap();
1516 let preds = fitted.predict(&x).unwrap();
1517
1518 let mean = 2.5;
1519 for &p in &preds {
1520 assert_relative_eq!(p, mean, epsilon = 1e-10);
1521 }
1522 }
1523
1524 #[test]
1525 fn test_regressor_pipeline_integration() {
1526 let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1527 let y = array![1.0, 2.0, 3.0, 4.0];
1528
1529 let model = DecisionTreeRegressor::<f64>::new();
1530 let fitted = model.fit_pipeline(&x, &y).unwrap();
1531 let preds = fitted.predict_pipeline(&x).unwrap();
1532 assert_eq!(preds.len(), 4);
1533 }
1534
1535 #[test]
1536 fn test_regressor_f32_support() {
1537 let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
1538 let y = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
1539
1540 let model = DecisionTreeRegressor::<f32>::new();
1541 let fitted = model.fit(&x, &y).unwrap();
1542 let preds = fitted.predict(&x).unwrap();
1543 assert_eq!(preds.len(), 4);
1544 }
1545
1546 #[test]
1547 fn test_classifier_f32_support() {
1548 let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1549 let y = array![0, 0, 0, 1, 1, 1];
1550
1551 let model = DecisionTreeClassifier::<f32>::new();
1552 let fitted = model.fit(&x, &y).unwrap();
1553 let preds = fitted.predict(&x).unwrap();
1554 assert_eq!(preds.len(), 6);
1555 }
1556
1557 #[test]
1560 fn test_gini_impurity_pure() {
1561 let counts = vec![5, 0];
1562 let imp: f64 = gini_impurity(&counts, 5);
1563 assert_relative_eq!(imp, 0.0, epsilon = 1e-10);
1564 }
1565
1566 #[test]
1567 fn test_gini_impurity_balanced() {
1568 let counts = vec![5, 5];
1569 let imp: f64 = gini_impurity(&counts, 10);
1570 assert_relative_eq!(imp, 0.5, epsilon = 1e-10);
1571 }
1572
1573 #[test]
1574 fn test_entropy_pure() {
1575 let counts = vec![5, 0];
1576 let ent: f64 = entropy_impurity(&counts, 5);
1577 assert_relative_eq!(ent, 0.0, epsilon = 1e-10);
1578 }
1579
1580 #[test]
1581 fn test_entropy_balanced() {
1582 let counts = vec![5, 5];
1583 let ent: f64 = entropy_impurity(&counts, 10);
1584 assert_relative_eq!(ent, 2.0f64.ln(), epsilon = 1e-10);
1585 }
1586}