use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2};
use num_traits::{Float, FromPrimitive, ToPrimitive};
use rand::SeedableRng;
use rand::rngs::StdRng;
use rand::seq::index::sample as rand_sample_indices;
use serde::{Deserialize, Serialize};

/// Split-quality criterion for classification trees.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ClassificationCriterion {
    /// Gini impurity.
    Gini,
    /// Shannon entropy.
    Entropy,
}

/// Split-quality criterion for regression trees.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RegressionCriterion {
    /// Mean squared error around the node mean.
    Mse,
}

/// A node in a flat, arena-allocated tree: children are addressed by index
/// into the owning `Vec<Node<F>>` rather than by pointer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Node<F> {
    Split {
        /// Index of the feature tested at this node.
        feature: usize,
        /// Samples with `x[feature] <= threshold` go left.
        threshold: F,
        /// Arena index of the left child.
        left: usize,
        /// Arena index of the right child.
        right: usize,
        /// Unnormalized impurity decrease contributed by this split.
        impurity_decrease: F,
        /// Number of training samples that reached this node.
        n_samples: usize,
    },
    Leaf {
        /// Predicted value (class index for classification, mean for regression).
        value: F,
        /// Per-class probabilities (classification only).
        class_distribution: Option<Vec<F>>,
        /// Number of training samples that reached this leaf.
        n_samples: usize,
    },
}
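
// Illustration: a depth-1 stump testing `x[0] <= 0.5`, laid out as a flat
// arena with the root at index 0. A sketch only; the field values are made up:
//
//     let stump: Vec<Node<f64>> = vec![
//         Node::Split {
//             feature: 0,
//             threshold: 0.5,
//             left: 1,
//             right: 2,
//             impurity_decrease: 0.5,
//             n_samples: 4,
//         },
//         Node::Leaf { value: 0.0, class_distribution: None, n_samples: 2 },
//         Node::Leaf { value: 1.0, class_distribution: None, n_samples: 2 },
//     ];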

/// Hyperparameters shared by the recursive tree builders.
#[derive(Debug, Clone, Copy)]
pub(crate) struct TreeParams {
    pub(crate) max_depth: Option<usize>,
    pub(crate) min_samples_split: usize,
    pub(crate) min_samples_leaf: usize,
}

/// Borrowed training data and split options for classification trees.
struct ClassificationData<'a, F> {
    x: &'a Array2<F>,
    y: &'a [usize],
    n_classes: usize,
    /// Restrict splits to this fixed feature subset (e.g. for ensemble training).
    feature_indices: Option<&'a [usize]>,
    /// Sample this many candidate features at every split.
    max_features_per_split: Option<usize>,
    criterion: ClassificationCriterion,
}

/// Borrowed training data and split options for regression trees.
struct RegressionData<'a, F> {
    x: &'a Array2<F>,
    y: &'a Array1<F>,
    feature_indices: Option<&'a [usize]>,
    max_features_per_split: Option<usize>,
}

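/// A CART-style decision tree classifier over a flat node arena.
///
/// Minimal usage sketch (marked `ignore`, so it is not compiled as a doctest;
/// assumes the types in this module are in scope):
///
/// ```ignore
/// use ndarray::array;
///
/// let x = array![[1.0, 2.0], [2.0, 3.0], [5.0, 6.0], [6.0, 7.0]];
/// let y = array![0usize, 0, 1, 1];
///
/// let model = DecisionTreeClassifier::<f64>::new()
///     .with_max_depth(Some(3))
///     .with_criterion(ClassificationCriterion::Gini);
/// let fitted = model.fit(&x, &y).unwrap();
/// let preds = fitted.predict(&x).unwrap();
/// assert_eq!(preds.len(), 4);
/// ```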
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecisionTreeClassifier<F> {
    /// Maximum tree depth; `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in each leaf.
    pub min_samples_leaf: usize,
    /// Split-quality criterion.
    pub criterion: ClassificationCriterion,
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float> DecisionTreeClassifier<F> {
    /// Creates a classifier with default hyperparameters
    /// (unlimited depth, `min_samples_split = 2`, `min_samples_leaf = 1`, Gini).
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            criterion: ClassificationCriterion::Gini,
            _marker: std::marker::PhantomData,
        }
    }

    /// Sets the maximum depth.
    #[must_use]
    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Sets the minimum number of samples required to split a node.
    #[must_use]
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Sets the minimum number of samples required in each leaf.
    #[must_use]
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Sets the split criterion.
    #[must_use]
    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
        self.criterion = criterion;
        self
    }
}

impl<F: Float> Default for DecisionTreeClassifier<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// A fitted decision tree classifier produced by [`DecisionTreeClassifier::fit`].
#[derive(Debug, Clone)]
pub struct FittedDecisionTreeClassifier<F> {
    /// Flat node arena; the root is at index 0.
    nodes: Vec<Node<F>>,
    /// Sorted, deduplicated class labels seen during fitting.
    classes: Vec<usize>,
    /// Number of features the model was fitted on.
    n_features: usize,
    /// Normalized impurity-decrease feature importances.
    feature_importances: Array1<F>,
}

impl<F: Float + Send + Sync + 'static> FittedDecisionTreeClassifier<F> {
    /// Returns the tree's node arena.
    #[must_use]
    pub fn nodes(&self) -> &[Node<F>] {
        &self.nodes
    }

    /// Returns the number of features the model was fitted on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }

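    /// Returns per-class probabilities, one row per sample, taken from the
    /// class distribution of the leaf each sample lands in.
    ///
    /// Sketch (marked `ignore`, not a compiled doctest; assumes a fitted
    /// model `fitted` and a matching matrix `x`):
    ///
    /// ```ignore
    /// let proba = fitted.predict_proba(&x)?;
    /// assert_eq!(proba.dim(), (x.nrows(), fitted.classes().len()));
    /// ```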
    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let n_classes = self.classes.len();
        let mut proba = Array2::zeros((n_samples, n_classes));
        for i in 0..n_samples {
            let row = x.row(i);
            let leaf = traverse_tree(&self.nodes, &row);
            if let Node::Leaf {
                class_distribution: Some(ref dist),
                ..
            } = self.nodes[leaf]
            {
                for (j, &p) in dist.iter().enumerate() {
                    proba[[i, j]] = p;
                }
            }
        }
        Ok(proba)
    }

    /// Returns the mean accuracy of `self.predict(x)` against `y`.
    pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
        if x.nrows() != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows()],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        let preds = self.predict(x)?;
        Ok(crate::mean_accuracy(&preds, y))
    }

    /// Returns the natural logarithm of [`Self::predict_proba`].
    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let proba = self.predict_proba(x)?;
        Ok(crate::log_proba(&proba))
    }
}

impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for DecisionTreeClassifier<F> {
    type Fitted = FittedDecisionTreeClassifier<F>;
    type Error = FerroError;

    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedDecisionTreeClassifier<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "DecisionTreeClassifier requires at least one sample".into(),
            });
        }
        if self.min_samples_split < 2 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_split".into(),
                reason: "must be at least 2".into(),
            });
        }
        if self.min_samples_leaf < 1 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_leaf".into(),
                reason: "must be at least 1".into(),
            });
        }

        // Sorted, deduplicated class labels.
        let mut classes: Vec<usize> = y.iter().copied().collect();
        classes.sort_unstable();
        classes.dedup();
        let n_classes = classes.len();

        // Map raw labels to contiguous class indices 0..n_classes.
        let y_mapped: Vec<usize> = y
            .iter()
            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
            .collect();

        let indices: Vec<usize> = (0..n_samples).collect();

        let data = ClassificationData {
            x,
            y: &y_mapped,
            n_classes,
            feature_indices: None,
            max_features_per_split: None,
            criterion: self.criterion,
        };
        let params = TreeParams {
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
        };

        let mut nodes: Vec<Node<F>> = Vec::new();
        build_classification_tree(&data, &indices, &mut nodes, 0, &params, None);

        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);

        Ok(FittedDecisionTreeClassifier {
            nodes,
            classes,
            n_features,
            feature_importances,
        })
    }
}

impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedDecisionTreeClassifier<F> {
    type Output = Array1<usize>;
    type Error = FerroError;

    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let mut predictions = Array1::zeros(n_samples);
        for i in 0..n_samples {
            let row = x.row(i);
            let leaf = traverse_tree(&self.nodes, &row);
            if let Node::Leaf { value, .. } = self.nodes[leaf] {
                predictions[i] = float_to_usize(value);
            }
        }
        Ok(predictions)
    }
}

impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
    for FittedDecisionTreeClassifier<F>
{
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}

impl<F: Float + Send + Sync + 'static> HasClasses for FittedDecisionTreeClassifier<F> {
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}

impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for DecisionTreeClassifier<F>
{
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        // Pipelines carry float targets; labels that cannot be converted fall
        // back to class 0.
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedClassifierPipelineAdapter(fitted)))
    }
}

/// Adapts a fitted classifier to the float-typed pipeline interface.
struct FittedClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedDecisionTreeClassifier<F>,
);

impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedClassifierPipelineAdapter<F>
{
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
    }
}

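/// A CART-style decision tree regressor (MSE splits, mean-valued leaves).
///
/// Minimal usage sketch (marked `ignore`, so it is not compiled as a doctest):
///
/// ```ignore
/// use ndarray::array;
///
/// let x = array![[1.0], [2.0], [3.0], [4.0]];
/// let y = array![1.0, 2.0, 3.0, 4.0];
///
/// let fitted = DecisionTreeRegressor::<f64>::new()
///     .with_max_depth(Some(2))
///     .fit(&x, &y)
///     .unwrap();
/// let preds = fitted.predict(&x).unwrap();
/// ```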
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecisionTreeRegressor<F> {
    /// Maximum tree depth; `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in each leaf.
    pub min_samples_leaf: usize,
    /// Split-quality criterion.
    pub criterion: RegressionCriterion,
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float> DecisionTreeRegressor<F> {
    /// Creates a regressor with default hyperparameters
    /// (unlimited depth, `min_samples_split = 2`, `min_samples_leaf = 1`, MSE).
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            criterion: RegressionCriterion::Mse,
            _marker: std::marker::PhantomData,
        }
    }

    /// Sets the maximum depth.
    #[must_use]
    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Sets the minimum number of samples required to split a node.
    #[must_use]
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Sets the minimum number of samples required in each leaf.
    #[must_use]
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Sets the split criterion.
    #[must_use]
    pub fn with_criterion(mut self, criterion: RegressionCriterion) -> Self {
        self.criterion = criterion;
        self
    }
}

impl<F: Float> Default for DecisionTreeRegressor<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// A fitted decision tree regressor produced by [`DecisionTreeRegressor::fit`].
#[derive(Debug, Clone)]
pub struct FittedDecisionTreeRegressor<F> {
    /// Flat node arena; the root is at index 0.
    nodes: Vec<Node<F>>,
    /// Number of features the model was fitted on.
    n_features: usize,
    /// Normalized impurity-decrease feature importances.
    feature_importances: Array1<F>,
}

impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for DecisionTreeRegressor<F> {
    type Fitted = FittedDecisionTreeRegressor<F>;
    type Error = FerroError;

    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedDecisionTreeRegressor<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "DecisionTreeRegressor requires at least one sample".into(),
            });
        }
        if self.min_samples_split < 2 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_split".into(),
                reason: "must be at least 2".into(),
            });
        }
        if self.min_samples_leaf < 1 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_leaf".into(),
                reason: "must be at least 1".into(),
            });
        }

        let indices: Vec<usize> = (0..n_samples).collect();

        let data = RegressionData {
            x,
            y,
            feature_indices: None,
            max_features_per_split: None,
        };
        let params = TreeParams {
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
        };

        let mut nodes: Vec<Node<F>> = Vec::new();
        build_regression_tree(&data, &indices, &mut nodes, 0, &params, None);

        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);

        Ok(FittedDecisionTreeRegressor {
            nodes,
            n_features,
            feature_importances,
        })
    }
}

impl<F: Float + Send + Sync + 'static> FittedDecisionTreeRegressor<F> {
    /// Returns the tree's node arena.
    #[must_use]
    pub fn nodes(&self) -> &[Node<F>] {
        &self.nodes
    }

    /// Returns the number of features the model was fitted on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// Returns the R² score of `self.predict(x)` against `y`.
    pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<F, FerroError> {
        if x.nrows() != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows()],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        let preds = self.predict(x)?;
        Ok(crate::r2_score(&preds, y))
    }
}

impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedDecisionTreeRegressor<F> {
    type Output = Array1<F>;
    type Error = FerroError;

    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let mut predictions = Array1::zeros(n_samples);
        for i in 0..n_samples {
            let row = x.row(i);
            let leaf = traverse_tree(&self.nodes, &row);
            if let Node::Leaf { value, .. } = self.nodes[leaf] {
                predictions[i] = value;
            }
        }
        Ok(predictions)
    }
}

impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedDecisionTreeRegressor<F> {
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}

impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for DecisionTreeRegressor<F> {
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}

impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedDecisionTreeRegressor<F>
{
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}

/// Follows split nodes from the root until a leaf is reached; returns the
/// leaf's arena index.
fn traverse_tree<F: Float>(nodes: &[Node<F>], sample: &ndarray::ArrayView1<F>) -> usize {
    let mut idx = 0;
    loop {
        match &nodes[idx] {
            Node::Split {
                feature,
                threshold,
                left,
                right,
                ..
            } => {
                if sample[*feature] <= *threshold {
                    idx = *left;
                } else {
                    idx = *right;
                }
            }
            Node::Leaf { .. } => return idx,
        }
    }
}

/// Crate-visible wrapper around [`traverse_tree`].
pub(crate) fn traverse<F: Float>(nodes: &[Node<F>], sample: &ndarray::ArrayView1<F>) -> usize {
    traverse_tree(nodes, sample)
}

/// Rounds a float-encoded class index back to `usize` (0 if not representable).
fn float_to_usize<F: Float>(v: F) -> usize {
    v.to_f64().map_or(0, |f| f.round() as usize)
}

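/// Gini impurity `1 - Σ pᵢ²` of a class-count histogram. For counts `[5, 5]`
/// this is `1 - 0.5² - 0.5² = 0.5` (cf. `test_gini_impurity_balanced` below).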
fn gini_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
    if total == 0 {
        return F::zero();
    }
    let total_f = F::from(total).unwrap();
    let mut impurity = F::one();
    for &count in class_counts {
        let p = F::from(count).unwrap() / total_f;
        impurity = impurity - p * p;
    }
    impurity
}

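/// Shannon entropy `-Σ pᵢ ln pᵢ` (natural log, so the unit is nats). For
/// counts `[5, 5]` this is `ln 2 ≈ 0.693` (cf. `test_entropy_balanced` below).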
fn entropy_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
    if total == 0 {
        return F::zero();
    }
    let total_f = F::from(total).unwrap();
    let mut ent = F::zero();
    for &count in class_counts {
        if count > 0 {
            let p = F::from(count).unwrap() / total_f;
            ent = ent - p * p.ln();
        }
    }
    ent
}

/// Mean of `y` over the given sample indices (zero for an empty slice).
fn mean_value<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
    if indices.is_empty() {
        return F::zero();
    }
    let sum: F = indices.iter().map(|&i| y[i]).fold(F::zero(), |a, b| a + b);
    sum / F::from(indices.len()).unwrap()
}

/// Mean squared error of `y` around `mean` over the given sample indices.
fn mse_for_indices<F: Float>(y: &Array1<F>, indices: &[usize], mean: F) -> F {
    if indices.is_empty() {
        return F::zero();
    }
    let sum_sq: F = indices
        .iter()
        .map(|&i| {
            let diff = y[i] - mean;
            diff * diff
        })
        .fold(F::zero(), |a, b| a + b);
    sum_sq / F::from(indices.len()).unwrap()
}

/// Dispatches to the impurity function selected by `criterion`.
fn compute_impurity<F: Float>(
    class_counts: &[usize],
    total: usize,
    criterion: ClassificationCriterion,
) -> F {
    match criterion {
        ClassificationCriterion::Gini => gini_impurity(class_counts, total),
        ClassificationCriterion::Entropy => entropy_impurity(class_counts, total),
    }
}

/// Pushes a leaf predicting the majority class, with the full class
/// distribution attached, and returns its arena index.
fn make_classification_leaf<F: Float>(
    nodes: &mut Vec<Node<F>>,
    class_counts: &[usize],
    n_classes: usize,
    n_samples: usize,
) -> usize {
    let majority_class = class_counts
        .iter()
        .enumerate()
        .max_by_key(|&(_, &count)| count)
        .map_or(0, |(i, _)| i);

    let total_f = if n_samples > 0 {
        F::from(n_samples).unwrap()
    } else {
        F::one()
    };
    let distribution: Vec<F> = (0..n_classes)
        .map(|c| F::from(class_counts[c]).unwrap() / total_f)
        .collect();

    let idx = nodes.len();
    nodes.push(Node::Leaf {
        value: F::from(majority_class).unwrap(),
        class_distribution: Some(distribution),
        n_samples,
    });
    idx
}

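/// Recursively grows a classification subtree over `indices` and returns the
/// new node's arena index. A placeholder leaf is pushed first so the node has
/// a stable slot; once both children are built it is overwritten with the
/// final `Split`.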
fn build_classification_tree<F: Float>(
    data: &ClassificationData<'_, F>,
    indices: &[usize],
    nodes: &mut Vec<Node<F>>,
    depth: usize,
    params: &TreeParams,
    mut rng: Option<&mut StdRng>,
) -> usize {
    let n = indices.len();

    let mut class_counts = vec![0usize; data.n_classes];
    for &i in indices {
        class_counts[data.y[i]] += 1;
    }

    // Stop if the node is too small, too deep, or already pure.
    let should_stop = n < params.min_samples_split
        || params.max_depth.is_some_and(|d| depth >= d)
        || class_counts.iter().filter(|&&c| c > 0).count() <= 1;

    if should_stop {
        return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
    }

    let best = find_best_classification_split(
        data,
        indices,
        params.min_samples_leaf,
        rng.as_deref_mut(),
    );

    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
            .iter()
            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);

        let node_idx = nodes.len();
        // Reserve a stable arena slot with a placeholder leaf; it is
        // overwritten with the final `Split` once both children are built.
        nodes.push(Node::Leaf {
            value: F::zero(),
            class_distribution: None,
            n_samples: 0,
        });

        let left_idx = build_classification_tree(
            data,
            &left_indices,
            nodes,
            depth + 1,
            params,
            rng.as_deref_mut(),
        );
        let right_idx = build_classification_tree(
            data,
            &right_indices,
            nodes,
            depth + 1,
            params,
            rng.as_deref_mut(),
        );

        nodes[node_idx] = Node::Split {
            feature: best_feature,
            threshold: best_threshold,
            left: left_idx,
            right: right_idx,
            impurity_decrease: best_impurity_decrease,
            n_samples: n,
        };

        node_idx
    } else {
        make_classification_leaf(nodes, &class_counts, data.n_classes, n)
    }
}

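/// Scans every candidate feature with running class counts and scores each
/// threshold by the weighted impurity decrease
/// `Δ = I(parent) - (n_L/n)·I(left) - (n_R/n)·I(right)`.
/// Returns `(feature, threshold, Δ·n)` for the best strictly positive `Δ`,
/// or `None` when no split improves on the parent.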
fn find_best_classification_split<F: Float>(
    data: &ClassificationData<'_, F>,
    indices: &[usize],
    min_samples_leaf: usize,
    rng: Option<&mut StdRng>,
) -> Option<(usize, F, F)> {
    let n = indices.len();
    let n_f = F::from(n).unwrap();
    let n_features = data.x.ncols();

    let mut parent_counts = vec![0usize; data.n_classes];
    for &i in indices {
        parent_counts[data.y[i]] += 1;
    }
    let parent_impurity = compute_impurity::<F>(&parent_counts, n, data.criterion);

    let mut best_score = F::neg_infinity();
    let mut best_feature = 0;
    let mut best_threshold = F::zero();

    // Per-split random feature sampling takes precedence; otherwise fall back
    // to a fixed feature subset, or to all features.
    let candidate_features: Vec<usize> = match (data.max_features_per_split, rng) {
        (Some(k), Some(rng)) => {
            let k = k.min(n_features).max(1);
            rand_sample_indices(rng, n_features, k).into_vec()
        }
        _ => match data.feature_indices {
            Some(feat) => feat.to_vec(),
            None => (0..n_features).collect(),
        },
    };

    for feat in candidate_features {
        let mut sorted_indices: Vec<usize> = indices.to_vec();
        sorted_indices.sort_by(|&a, &b| data.x[[a, feat]].partial_cmp(&data.x[[b, feat]]).unwrap());

        let mut left_counts = vec![0usize; data.n_classes];
        let mut right_counts = parent_counts.clone();
        let mut left_n = 0usize;

        for split_pos in 0..n - 1 {
            let idx = sorted_indices[split_pos];
            let cls = data.y[idx];
            left_counts[cls] += 1;
            right_counts[cls] -= 1;
            left_n += 1;
            let right_n = n - left_n;

            // Only split between distinct feature values.
            let next_idx = sorted_indices[split_pos + 1];
            if data.x[[idx, feat]] == data.x[[next_idx, feat]] {
                continue;
            }

            if left_n < min_samples_leaf || right_n < min_samples_leaf {
                continue;
            }

            let left_impurity = compute_impurity::<F>(&left_counts, left_n, data.criterion);
            let right_impurity = compute_impurity::<F>(&right_counts, right_n, data.criterion);
            let left_weight = F::from(left_n).unwrap() / n_f;
            let right_weight = F::from(right_n).unwrap() / n_f;
            let weighted_child_impurity =
                left_weight * left_impurity + right_weight * right_impurity;
            let impurity_decrease = parent_impurity - weighted_child_impurity;

            if impurity_decrease > best_score {
                best_score = impurity_decrease;
                best_feature = feat;
                // Threshold at the midpoint between adjacent distinct values.
                best_threshold =
                    (data.x[[idx, feat]] + data.x[[next_idx, feat]]) / F::from(2.0).unwrap();
            }
        }
    }

    if best_score > F::zero() {
        Some((best_feature, best_threshold, best_score * n_f))
    } else {
        None
    }
}

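/// Recursively grows a regression subtree over `indices`; leaves predict the
/// node mean. Uses the same placeholder-then-overwrite arena scheme as the
/// classification builder, and stops early when the node is already (near-)
/// pure (`parent_mse <= F::epsilon()`).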
fn build_regression_tree<F: Float>(
    data: &RegressionData<'_, F>,
    indices: &[usize],
    nodes: &mut Vec<Node<F>>,
    depth: usize,
    params: &TreeParams,
    mut rng: Option<&mut StdRng>,
) -> usize {
    let n = indices.len();
    let mean = mean_value(data.y, indices);

    let should_stop = n < params.min_samples_split || params.max_depth.is_some_and(|d| depth >= d);

    if should_stop {
        let idx = nodes.len();
        nodes.push(Node::Leaf {
            value: mean,
            class_distribution: None,
            n_samples: n,
        });
        return idx;
    }

    // A (near-)constant target cannot be improved by splitting.
    let parent_mse = mse_for_indices(data.y, indices, mean);
    if parent_mse <= F::epsilon() {
        let idx = nodes.len();
        nodes.push(Node::Leaf {
            value: mean,
            class_distribution: None,
            n_samples: n,
        });
        return idx;
    }

    let best = find_best_regression_split(
        data,
        indices,
        params.min_samples_leaf,
        rng.as_deref_mut(),
    );

    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
            .iter()
            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);

        let node_idx = nodes.len();
        // Reserve a stable arena slot with a placeholder leaf; it is
        // overwritten with the final `Split` once both children are built.
        nodes.push(Node::Leaf {
            value: F::zero(),
            class_distribution: None,
            n_samples: 0,
        });

        let left_idx = build_regression_tree(
            data,
            &left_indices,
            nodes,
            depth + 1,
            params,
            rng.as_deref_mut(),
        );
        let right_idx = build_regression_tree(
            data,
            &right_indices,
            nodes,
            depth + 1,
            params,
            rng.as_deref_mut(),
        );

        nodes[node_idx] = Node::Split {
            feature: best_feature,
            threshold: best_threshold,
            left: left_idx,
            right: right_idx,
            impurity_decrease: best_impurity_decrease,
            n_samples: n,
        };

        node_idx
    } else {
        let idx = nodes.len();
        nodes.push(Node::Leaf {
            value: mean,
            class_distribution: None,
            n_samples: n,
        });
        idx
    }
}

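/// Regression analogue of the classification split search: maintains running
/// `Σy` and `Σy²` so each side's MSE comes from `E[y²] - E[y]²` in O(1) per
/// candidate threshold, then scores splits by the weighted MSE decrease.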
fn find_best_regression_split<F: Float>(
    data: &RegressionData<'_, F>,
    indices: &[usize],
    min_samples_leaf: usize,
    rng: Option<&mut StdRng>,
) -> Option<(usize, F, F)> {
    let n = indices.len();
    let n_f = F::from(n).unwrap();
    let n_features = data.x.ncols();

    let parent_sum: F = indices
        .iter()
        .map(|&i| data.y[i])
        .fold(F::zero(), |a, b| a + b);
    let parent_sum_sq: F = indices
        .iter()
        .map(|&i| data.y[i] * data.y[i])
        .fold(F::zero(), |a, b| a + b);
    let parent_mse = parent_sum_sq / n_f - (parent_sum / n_f) * (parent_sum / n_f);

    let mut best_score = F::neg_infinity();
    let mut best_feature = 0;
    let mut best_threshold = F::zero();

    let candidate_features: Vec<usize> = match (data.max_features_per_split, rng) {
        (Some(k), Some(rng)) => {
            let k = k.min(n_features).max(1);
            rand_sample_indices(rng, n_features, k).into_vec()
        }
        _ => match data.feature_indices {
            Some(feat) => feat.to_vec(),
            None => (0..n_features).collect(),
        },
    };

    for feat in candidate_features {
        let mut sorted_indices: Vec<usize> = indices.to_vec();
        sorted_indices.sort_by(|&a, &b| data.x[[a, feat]].partial_cmp(&data.x[[b, feat]]).unwrap());

        let mut left_sum = F::zero();
        let mut left_sum_sq = F::zero();
        let mut left_n: usize = 0;

        for split_pos in 0..n - 1 {
            let idx = sorted_indices[split_pos];
            let val = data.y[idx];
            left_sum = left_sum + val;
            left_sum_sq = left_sum_sq + val * val;
            left_n += 1;
            let right_n = n - left_n;

            // Only split between distinct feature values.
            let next_idx = sorted_indices[split_pos + 1];
            if data.x[[idx, feat]] == data.x[[next_idx, feat]] {
                continue;
            }

            if left_n < min_samples_leaf || right_n < min_samples_leaf {
                continue;
            }

            let left_n_f = F::from(left_n).unwrap();
            let right_n_f = F::from(right_n).unwrap();

            let left_mean = left_sum / left_n_f;
            let left_mse = left_sum_sq / left_n_f - left_mean * left_mean;

            // The right side's sums follow from the parent's totals.
            let right_sum = parent_sum - left_sum;
            let right_sum_sq = parent_sum_sq - left_sum_sq;
            let right_mean = right_sum / right_n_f;
            let right_mse = right_sum_sq / right_n_f - right_mean * right_mean;

            let weighted_child_mse = (left_n_f * left_mse + right_n_f * right_mse) / n_f;
            let mse_decrease = parent_mse - weighted_child_mse;

            if mse_decrease > best_score {
                best_score = mse_decrease;
                best_feature = feat;
                best_threshold =
                    (data.x[[idx, feat]] + data.x[[next_idx, feat]]) / F::from(2.0).unwrap();
            }
        }
    }

    if best_score > F::zero() {
        Some((best_feature, best_threshold, best_score * n_f))
    } else {
        None
    }
}

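/// Sums each split's (unnormalized) impurity decrease per feature and
/// normalizes so the importances sum to 1 (all zeros for a single-leaf tree).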
pub(crate) fn compute_feature_importances<F: Float>(
    nodes: &[Node<F>],
    n_features: usize,
    _total_samples: usize,
) -> Array1<F> {
    let mut importances = Array1::zeros(n_features);
    for node in nodes {
        if let Node::Split {
            feature,
            impurity_decrease,
            ..
        } = node
        {
            importances[*feature] = importances[*feature] + *impurity_decrease;
        }
    }
    let total: F = importances.iter().copied().fold(F::zero(), |a, b| a + b);
    if total > F::zero() {
        importances.mapv_inplace(|v| v / total);
    }
    importances
}

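/// Ensemble version of [`compute_feature_importances`]: accumulates weighted
/// split importances across trees, mapping each tree's local feature index
/// back to the original column via `feature_indices` when the tree was
/// trained on a feature subset, then normalizes to sum to 1.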
pub(crate) fn aggregate_tree_importances<F: Float>(
    trees: &[Vec<Node<F>>],
    feature_indices: Option<&[Vec<usize>]>,
    weights: Option<&[F]>,
    n_features: usize,
) -> Array1<F> {
    let mut total_imp = Array1::<F>::zeros(n_features);
    for (t, nodes) in trees.iter().enumerate() {
        let w = weights.map_or(F::one(), |ws| ws[t]);
        for node in nodes {
            if let Node::Split {
                feature,
                impurity_decrease,
                ..
            } = node
            {
                // Map the tree-local feature index back to the original
                // column when the tree was trained on a feature subset.
                let original_feature = match feature_indices {
                    Some(map) => map[t][*feature],
                    None => *feature,
                };
                total_imp[original_feature] =
                    total_imp[original_feature] + w * *impurity_decrease;
            }
        }
    }
    let total: F = total_imp.iter().copied().fold(F::zero(), |a, b| a + b);
    if total > F::zero() {
        total_imp.mapv_inplace(|v| v / total);
    }
    total_imp
}

/// Builds a classification tree restricted to a fixed feature subset
/// (crate-internal helper, e.g. for ensemble estimators).
#[allow(clippy::too_many_arguments)]
pub(crate) fn build_classification_tree_with_feature_subset<F: Float>(
    x: &Array2<F>,
    y: &[usize],
    n_classes: usize,
    indices: &[usize],
    feature_indices: &[usize],
    params: &TreeParams,
    criterion: ClassificationCriterion,
) -> Vec<Node<F>> {
    let data = ClassificationData {
        x,
        y,
        n_classes,
        feature_indices: Some(feature_indices),
        max_features_per_split: None,
        criterion,
    };
    let mut nodes = Vec::new();
    build_classification_tree(&data, indices, &mut nodes, 0, params, None);
    nodes
}

/// Builds a classification tree that samples `max_features` candidate
/// features at every split, seeded for reproducibility.
#[allow(clippy::too_many_arguments)]
pub(crate) fn build_classification_tree_per_split_features<F: Float>(
    x: &Array2<F>,
    y: &[usize],
    n_classes: usize,
    indices: &[usize],
    max_features: usize,
    params: &TreeParams,
    criterion: ClassificationCriterion,
    seed: u64,
) -> Vec<Node<F>> {
    let data = ClassificationData {
        x,
        y,
        n_classes,
        feature_indices: None,
        max_features_per_split: Some(max_features),
        criterion,
    };
    let mut rng = StdRng::seed_from_u64(seed);
    let mut nodes = Vec::new();
    build_classification_tree(&data, indices, &mut nodes, 0, params, Some(&mut rng));
    nodes
}

/// Regression counterpart of [`build_classification_tree_with_feature_subset`].
pub(crate) fn build_regression_tree_with_feature_subset<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    indices: &[usize],
    feature_indices: &[usize],
    params: &TreeParams,
) -> Vec<Node<F>> {
    let data = RegressionData {
        x,
        y,
        feature_indices: Some(feature_indices),
        max_features_per_split: None,
    };
    let mut nodes = Vec::new();
    build_regression_tree(&data, indices, &mut nodes, 0, params, None);
    nodes
}

/// Regression counterpart of [`build_classification_tree_per_split_features`].
pub(crate) fn build_regression_tree_per_split_features<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    indices: &[usize],
    max_features: usize,
    params: &TreeParams,
    seed: u64,
) -> Vec<Node<F>> {
    let data = RegressionData {
        x,
        y,
        feature_indices: None,
        max_features_per_split: Some(max_features),
    };
    let mut rng = StdRng::seed_from_u64(seed);
    let mut nodes = Vec::new();
    build_regression_tree(&data, indices, &mut nodes, 0, params, Some(&mut rng));
    nodes
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // ---------- classifier ----------

    #[test]
    fn test_classifier_simple_binary() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        for i in 0..3 {
            assert_eq!(preds[i], 0);
        }
        for i in 3..6 {
            assert_eq!(preds[i], 1);
        }
    }

    #[test]
    fn test_classifier_single_class() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds, array![0, 0, 0]);
    }

    #[test]
    fn test_classifier_max_depth_1() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_max_depth(Some(1));
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for i in 0..4 {
            assert_eq!(preds[i], 0);
        }
        for i in 4..8 {
            assert_eq!(preds[i], 1);
        }
    }

    #[test]
    fn test_classifier_min_samples_split() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_min_samples_split(7);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let majority = preds[0];
        for &p in &preds {
            assert_eq!(p, majority);
        }
    }

    #[test]
    fn test_classifier_min_samples_leaf() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_min_samples_leaf(4);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let majority = preds[0];
        for &p in &preds {
            assert_eq!(p, majority);
        }
    }

    #[test]
    fn test_classifier_gini_vs_entropy() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, 5.0, 5.0, 5.0, 6.0, 6.0, 5.0, 6.0, 6.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let gini_model =
            DecisionTreeClassifier::<f64>::new().with_criterion(ClassificationCriterion::Gini);
        let entropy_model =
            DecisionTreeClassifier::<f64>::new().with_criterion(ClassificationCriterion::Entropy);

        let fitted_gini = gini_model.fit(&x, &y).unwrap();
        let fitted_entropy = entropy_model.fit(&x, &y).unwrap();

        let preds_gini = fitted_gini.predict(&x).unwrap();
        let preds_entropy = fitted_entropy.predict(&x).unwrap();

        assert_eq!(preds_gini, y);
        assert_eq!(preds_entropy, y);
    }

    #[test]
    fn test_classifier_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        assert_eq!(proba.dim(), (6, 2));
        for i in 0..6 {
            let row_sum: f64 = proba.row(i).iter().sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_classifier_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_classifier_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = DecisionTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_classifier_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        assert!(importances[0] > 0.0);
        let sum: f64 = importances.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_classifier_has_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1, 2, 0, 1, 2];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 1, 2]);
        assert_eq!(fitted.n_classes(), 3);
    }

    #[test]
    fn test_classifier_invalid_min_samples_split() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_min_samples_split(1);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_classifier_invalid_min_samples_leaf() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_min_samples_leaf(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_classifier_multiclass() {
        let x = Array2::from_shape_vec((9, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
            .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds, y);
    }

    #[test]
    fn test_classifier_pipeline_integration() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    // ---------- regressor ----------

    #[test]
    fn test_regressor_simple() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_regressor_max_depth() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = DecisionTreeRegressor::<f64>::new().with_max_depth(Some(1));
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for i in 0..4 {
            assert_relative_eq!(preds[i], 1.0, epsilon = 1e-10);
        }
        for i in 4..8 {
            assert_relative_eq!(preds[i], 5.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_regressor_constant_target() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 3.0, 3.0, 3.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for &p in &preds {
            assert_relative_eq!(p, 3.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_regressor_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = DecisionTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_regressor_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = DecisionTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_regressor_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        assert!(importances[0] > 0.0);
        let sum: f64 = importances.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_regressor_min_samples_split() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = DecisionTreeRegressor::<f64>::new().with_min_samples_split(5);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let mean = 2.5;
        for &p in &preds {
            assert_relative_eq!(p, mean, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_regressor_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_regressor_f32_support() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);

        let model = DecisionTreeRegressor::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_classifier_f32_support() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    // ---------- impurity helpers ----------

    #[test]
    fn test_gini_impurity_pure() {
        let counts = vec![5, 0];
        let imp: f64 = gini_impurity(&counts, 5);
        assert_relative_eq!(imp, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_gini_impurity_balanced() {
        let counts = vec![5, 5];
        let imp: f64 = gini_impurity(&counts, 10);
        assert_relative_eq!(imp, 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_entropy_pure() {
        let counts = vec![5, 0];
        let ent: f64 = entropy_impurity(&counts, 5);
        assert_relative_eq!(ent, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_entropy_balanced() {
        let counts = vec![5, 5];
        let ent: f64 = entropy_impurity(&counts, 10);
        assert_relative_eq!(ent, 2.0f64.ln(), epsilon = 1e-10);
    }
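
    // Sketch: `score` should report perfect training accuracy on perfectly
    // separable data (assuming `crate::mean_accuracy` is the fraction of
    // exact label matches).
    #[test]
    fn test_classifier_score_on_training_data() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let fitted = DecisionTreeClassifier::<f64>::new().fit(&x, &y).unwrap();
        let acc = fitted.score(&x, &y).unwrap();
        assert_relative_eq!(acc, 1.0, epsilon = 1e-10);
    }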
}