Skip to main content

ferrolearn_tree/
decision_tree.rs

1//! CART decision tree classifiers and regressors.
2//!
3//! This module provides [`DecisionTreeClassifier`] and [`DecisionTreeRegressor`],
4//! implementing the Classification and Regression Trees (CART) algorithm with
5//! configurable splitting criteria, depth limits, and minimum sample constraints.
6//!
7//! # Examples
8//!
9//! ```
10//! use ferrolearn_tree::DecisionTreeClassifier;
11//! use ferrolearn_core::{Fit, Predict};
12//! use ndarray::{array, Array1, Array2};
13//!
14//! let x = Array2::from_shape_vec((6, 2), vec![
15//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,
16//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,
17//! ]).unwrap();
18//! let y = array![0, 0, 0, 1, 1, 1];
19//!
20//! let model = DecisionTreeClassifier::<f64>::new();
21//! let fitted = model.fit(&x, &y).unwrap();
22//! let preds = fitted.predict(&x).unwrap();
23//! ```
24
25use ferrolearn_core::error::FerroError;
26use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
27use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
28use ferrolearn_core::traits::{Fit, Predict};
29use ndarray::{Array1, Array2};
30use num_traits::{Float, FromPrimitive, ToPrimitive};
31use rand::SeedableRng;
32use rand::rngs::StdRng;
33use rand::seq::index::sample as rand_sample_indices;
34use serde::{Deserialize, Serialize};
35
36// ---------------------------------------------------------------------------
37// Splitting criterion enums
38// ---------------------------------------------------------------------------
39
40/// Splitting criterion for classification trees.
41#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
42pub enum ClassificationCriterion {
43    /// Gini impurity.
44    Gini,
45    /// Shannon entropy.
46    Entropy,
47}
48
49/// Splitting criterion for regression trees.
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
51pub enum RegressionCriterion {
52    /// Mean squared error.
53    Mse,
54}
55
56// ---------------------------------------------------------------------------
57// Node representation (flat vec for cache efficiency)
58// ---------------------------------------------------------------------------
59
60/// A single node in the decision tree, stored in a flat `Vec` for cache efficiency.
61///
62/// Internal nodes hold a split (feature index + threshold), while leaf nodes
63/// store a prediction value and optional class distribution (for classifiers).
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub enum Node<F> {
66    /// An internal split node.
67    Split {
68        /// Feature index used for the split.
69        feature: usize,
70        /// Threshold value; samples with `x[feature] <= threshold` go left.
71        threshold: F,
72        /// Index of the left child node in the flat vec.
73        left: usize,
74        /// Index of the right child node in the flat vec.
75        right: usize,
76        /// Weighted impurity decrease from this split (for feature importance).
77        impurity_decrease: F,
78        /// Number of samples that reached this node during training.
79        n_samples: usize,
80    },
81    /// A leaf node that stores a prediction.
82    Leaf {
83        /// Predicted value: class label (as F) for classifiers, mean for regressors.
84        value: F,
85        /// Class distribution (proportion of each class). Only used by classifiers.
86        class_distribution: Option<Vec<F>>,
87        /// Number of samples that reached this node during training.
88        n_samples: usize,
89    },
90}
91
92// ---------------------------------------------------------------------------
93// Internal config structs (to reduce argument counts)
94// ---------------------------------------------------------------------------
95
/// Stopping rules shared by the recursive tree builders.
///
/// Bundling them into one small `Copy` value keeps the builder signatures short.
#[derive(Clone, Copy, Debug)]
pub(crate) struct TreeParams {
    /// Depth at which growth stops; `None` lets other rules decide.
    pub(crate) max_depth: Option<usize>,
    /// Nodes holding fewer samples than this are not split.
    pub(crate) min_samples_split: usize,
    /// Minimum number of samples required in each leaf; forwarded to the
    /// split search.
    pub(crate) min_samples_leaf: usize,
}
103
104/// Data references for classification tree building.
105struct ClassificationData<'a, F> {
106    x: &'a Array2<F>,
107    y: &'a [usize],
108    n_classes: usize,
109    /// Fixed feature subset for the entire tree (used by Bagging-style
110    /// per-tree feature subsampling). Mutually exclusive with
111    /// [`max_features_per_split`].
112    feature_indices: Option<&'a [usize]>,
113    /// When set, every split samples a fresh random subset of this many
114    /// features (per-split feature sampling, the Breiman 2001 RandomForest
115    /// behaviour and what scikit-learn does).
116    max_features_per_split: Option<usize>,
117    criterion: ClassificationCriterion,
118}
119
120/// Data references for regression tree building.
121struct RegressionData<'a, F> {
122    x: &'a Array2<F>,
123    y: &'a Array1<F>,
124    feature_indices: Option<&'a [usize]>,
125    /// See [`ClassificationData::max_features_per_split`].
126    max_features_per_split: Option<usize>,
127}
128
129// ---------------------------------------------------------------------------
130// DecisionTreeClassifier
131// ---------------------------------------------------------------------------
132
133/// CART decision tree classifier.
134///
135/// Builds a binary tree by recursively finding the feature and threshold that
136/// maximises the reduction in the chosen impurity criterion (Gini or Entropy).
137///
138/// # Type Parameters
139///
140/// - `F`: The floating-point type (`f32` or `f64`).
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct DecisionTreeClassifier<F> {
143    /// Maximum depth of the tree. `None` means unlimited.
144    pub max_depth: Option<usize>,
145    /// Minimum number of samples required to split an internal node.
146    pub min_samples_split: usize,
147    /// Minimum number of samples required in a leaf node.
148    pub min_samples_leaf: usize,
149    /// Splitting criterion.
150    pub criterion: ClassificationCriterion,
151    _marker: std::marker::PhantomData<F>,
152}
153
154impl<F: Float> DecisionTreeClassifier<F> {
155    /// Create a new `DecisionTreeClassifier` with default settings.
156    ///
157    /// Defaults: `max_depth = None`, `min_samples_split = 2`,
158    /// `min_samples_leaf = 1`, `criterion = Gini`.
159    #[must_use]
160    pub fn new() -> Self {
161        Self {
162            max_depth: None,
163            min_samples_split: 2,
164            min_samples_leaf: 1,
165            criterion: ClassificationCriterion::Gini,
166            _marker: std::marker::PhantomData,
167        }
168    }
169
170    /// Set the maximum tree depth.
171    #[must_use]
172    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
173        self.max_depth = max_depth;
174        self
175    }
176
177    /// Set the minimum number of samples required to split a node.
178    #[must_use]
179    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
180        self.min_samples_split = min_samples_split;
181        self
182    }
183
184    /// Set the minimum number of samples required in a leaf node.
185    #[must_use]
186    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
187        self.min_samples_leaf = min_samples_leaf;
188        self
189    }
190
191    /// Set the splitting criterion.
192    #[must_use]
193    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
194        self.criterion = criterion;
195        self
196    }
197}
198
199impl<F: Float> Default for DecisionTreeClassifier<F> {
200    fn default() -> Self {
201        Self::new()
202    }
203}
204
205// ---------------------------------------------------------------------------
206// FittedDecisionTreeClassifier
207// ---------------------------------------------------------------------------
208
209/// A fitted CART decision tree classifier.
210///
211/// Stores the learned tree as a flat `Vec<Node<F>>` for cache-friendly traversal.
212/// Implements [`Predict`] for generating class predictions and
213/// [`HasFeatureImportances`] for inspecting per-feature importance scores.
214#[derive(Debug, Clone)]
215pub struct FittedDecisionTreeClassifier<F> {
216    /// Flat node storage; index 0 is the root.
217    nodes: Vec<Node<F>>,
218    /// Sorted unique class labels observed during training.
219    classes: Vec<usize>,
220    /// Number of features the model was trained on.
221    n_features: usize,
222    /// Per-feature importance scores (normalised to sum to 1).
223    feature_importances: Array1<F>,
224}
225
226impl<F: Float + Send + Sync + 'static> FittedDecisionTreeClassifier<F> {
227    /// Returns a reference to the flat node storage of the tree.
228    #[must_use]
229    pub fn nodes(&self) -> &[Node<F>] {
230        &self.nodes
231    }
232
233    /// Returns the number of features the model was trained on.
234    #[must_use]
235    pub fn n_features(&self) -> usize {
236        self.n_features
237    }
238
239    /// Predict class probabilities for each sample.
240    ///
241    /// Returns a 2-D array of shape `(n_samples, n_classes)`.
242    ///
243    /// # Errors
244    ///
245    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
246    /// not match the training data.
247    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
248        if x.ncols() != self.n_features {
249            return Err(FerroError::ShapeMismatch {
250                expected: vec![self.n_features],
251                actual: vec![x.ncols()],
252                context: "number of features must match fitted model".into(),
253            });
254        }
255        let n_samples = x.nrows();
256        let n_classes = self.classes.len();
257        let mut proba = Array2::zeros((n_samples, n_classes));
258        for i in 0..n_samples {
259            let row = x.row(i);
260            let leaf = traverse_tree(&self.nodes, &row);
261            if let Node::Leaf {
262                class_distribution: Some(ref dist),
263                ..
264            } = self.nodes[leaf]
265            {
266                for (j, &p) in dist.iter().enumerate() {
267                    proba[[i, j]] = p;
268                }
269            }
270        }
271        Ok(proba)
272    }
273
274    /// Mean accuracy on the given test data and labels.
275    /// Equivalent to sklearn's `ClassifierMixin.score`.
276    ///
277    /// # Errors
278    ///
279    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
280    /// the feature count does not match the training data.
281    pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
282        if x.nrows() != y.len() {
283            return Err(FerroError::ShapeMismatch {
284                expected: vec![x.nrows()],
285                actual: vec![y.len()],
286                context: "y length must match number of samples in X".into(),
287            });
288        }
289        let preds = self.predict(x)?;
290        Ok(crate::mean_accuracy(&preds, y))
291    }
292
293    /// Element-wise log of [`predict_proba`](Self::predict_proba). Mirrors
294    /// sklearn's `ClassifierMixin.predict_log_proba`.
295    ///
296    /// # Errors
297    ///
298    /// Forwards any error from [`predict_proba`](Self::predict_proba).
299    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
300        let proba = self.predict_proba(x)?;
301        Ok(crate::log_proba(&proba))
302    }
303}
304
305impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for DecisionTreeClassifier<F> {
306    type Fitted = FittedDecisionTreeClassifier<F>;
307    type Error = FerroError;
308
309    /// Fit the decision tree classifier on the training data.
310    ///
311    /// # Errors
312    ///
313    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
314    /// numbers of samples.
315    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
316    /// Returns [`FerroError::InvalidParameter`] if hyperparameters are invalid.
317    fn fit(
318        &self,
319        x: &Array2<F>,
320        y: &Array1<usize>,
321    ) -> Result<FittedDecisionTreeClassifier<F>, FerroError> {
322        let (n_samples, n_features) = x.dim();
323
324        if n_samples != y.len() {
325            return Err(FerroError::ShapeMismatch {
326                expected: vec![n_samples],
327                actual: vec![y.len()],
328                context: "y length must match number of samples in X".into(),
329            });
330        }
331        if n_samples == 0 {
332            return Err(FerroError::InsufficientSamples {
333                required: 1,
334                actual: 0,
335                context: "DecisionTreeClassifier requires at least one sample".into(),
336            });
337        }
338        if self.min_samples_split < 2 {
339            return Err(FerroError::InvalidParameter {
340                name: "min_samples_split".into(),
341                reason: "must be at least 2".into(),
342            });
343        }
344        if self.min_samples_leaf < 1 {
345            return Err(FerroError::InvalidParameter {
346                name: "min_samples_leaf".into(),
347                reason: "must be at least 1".into(),
348            });
349        }
350
351        // Determine unique classes.
352        let mut classes: Vec<usize> = y.iter().copied().collect();
353        classes.sort_unstable();
354        classes.dedup();
355        let n_classes = classes.len();
356
357        // Map class labels to indices 0..n_classes.
358        let y_mapped: Vec<usize> = y
359            .iter()
360            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
361            .collect();
362
363        let indices: Vec<usize> = (0..n_samples).collect();
364
365        let data = ClassificationData {
366            x,
367            y: &y_mapped,
368            n_classes,
369            feature_indices: None,
370            max_features_per_split: None,
371            criterion: self.criterion,
372        };
373        let params = TreeParams {
374            max_depth: self.max_depth,
375            min_samples_split: self.min_samples_split,
376            min_samples_leaf: self.min_samples_leaf,
377        };
378
379        let mut nodes: Vec<Node<F>> = Vec::new();
380        build_classification_tree(&data, &indices, &mut nodes, 0, &params, None);
381
382        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);
383
384        Ok(FittedDecisionTreeClassifier {
385            nodes,
386            classes,
387            n_features,
388            feature_importances,
389        })
390    }
391}
392
393impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedDecisionTreeClassifier<F> {
394    type Output = Array1<usize>;
395    type Error = FerroError;
396
397    /// Predict class labels for the given feature matrix.
398    ///
399    /// # Errors
400    ///
401    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
402    /// not match the fitted model.
403    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
404        if x.ncols() != self.n_features {
405            return Err(FerroError::ShapeMismatch {
406                expected: vec![self.n_features],
407                actual: vec![x.ncols()],
408                context: "number of features must match fitted model".into(),
409            });
410        }
411        let n_samples = x.nrows();
412        let mut predictions = Array1::zeros(n_samples);
413        for i in 0..n_samples {
414            let row = x.row(i);
415            let leaf = traverse_tree(&self.nodes, &row);
416            if let Node::Leaf { value, .. } = self.nodes[leaf] {
417                predictions[i] = float_to_usize(value);
418            }
419        }
420        Ok(predictions)
421    }
422}
423
424impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
425    for FittedDecisionTreeClassifier<F>
426{
427    fn feature_importances(&self) -> &Array1<F> {
428        &self.feature_importances
429    }
430}
431
432impl<F: Float + Send + Sync + 'static> HasClasses for FittedDecisionTreeClassifier<F> {
433    fn classes(&self) -> &[usize] {
434        &self.classes
435    }
436
437    fn n_classes(&self) -> usize {
438        self.classes.len()
439    }
440}
441
442// Pipeline integration.
443impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
444    for DecisionTreeClassifier<F>
445{
446    fn fit_pipeline(
447        &self,
448        x: &Array2<F>,
449        y: &Array1<F>,
450    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
451        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
452        let fitted = self.fit(x, &y_usize)?;
453        Ok(Box::new(FittedClassifierPipelineAdapter(fitted)))
454    }
455}
456
457/// Adapter to make `FittedDecisionTreeClassifier<F>` work as a pipeline estimator.
458struct FittedClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
459    FittedDecisionTreeClassifier<F>,
460);
461
462impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
463    for FittedClassifierPipelineAdapter<F>
464{
465    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
466        let preds = self.0.predict(x)?;
467        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
468    }
469}
470
471// ---------------------------------------------------------------------------
472// DecisionTreeRegressor
473// ---------------------------------------------------------------------------
474
475/// CART decision tree regressor.
476///
477/// Builds a binary tree by recursively finding the feature and threshold that
478/// minimises the mean squared error of the split.
479///
480/// # Type Parameters
481///
482/// - `F`: The floating-point type (`f32` or `f64`).
483#[derive(Debug, Clone, Serialize, Deserialize)]
484pub struct DecisionTreeRegressor<F> {
485    /// Maximum depth of the tree. `None` means unlimited.
486    pub max_depth: Option<usize>,
487    /// Minimum number of samples required to split an internal node.
488    pub min_samples_split: usize,
489    /// Minimum number of samples required in a leaf node.
490    pub min_samples_leaf: usize,
491    /// Splitting criterion.
492    pub criterion: RegressionCriterion,
493    _marker: std::marker::PhantomData<F>,
494}
495
496impl<F: Float> DecisionTreeRegressor<F> {
497    /// Create a new `DecisionTreeRegressor` with default settings.
498    ///
499    /// Defaults: `max_depth = None`, `min_samples_split = 2`,
500    /// `min_samples_leaf = 1`, `criterion = MSE`.
501    #[must_use]
502    pub fn new() -> Self {
503        Self {
504            max_depth: None,
505            min_samples_split: 2,
506            min_samples_leaf: 1,
507            criterion: RegressionCriterion::Mse,
508            _marker: std::marker::PhantomData,
509        }
510    }
511
512    /// Set the maximum tree depth.
513    #[must_use]
514    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
515        self.max_depth = max_depth;
516        self
517    }
518
519    /// Set the minimum number of samples required to split a node.
520    #[must_use]
521    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
522        self.min_samples_split = min_samples_split;
523        self
524    }
525
526    /// Set the minimum number of samples required in a leaf node.
527    #[must_use]
528    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
529        self.min_samples_leaf = min_samples_leaf;
530        self
531    }
532
533    /// Set the splitting criterion.
534    #[must_use]
535    pub fn with_criterion(mut self, criterion: RegressionCriterion) -> Self {
536        self.criterion = criterion;
537        self
538    }
539}
540
541impl<F: Float> Default for DecisionTreeRegressor<F> {
542    fn default() -> Self {
543        Self::new()
544    }
545}
546
547// ---------------------------------------------------------------------------
548// FittedDecisionTreeRegressor
549// ---------------------------------------------------------------------------
550
551/// A fitted CART decision tree regressor.
552///
553/// Stores the learned tree as a flat `Vec<Node<F>>` for cache-friendly traversal.
554#[derive(Debug, Clone)]
555pub struct FittedDecisionTreeRegressor<F> {
556    /// Flat node storage; index 0 is the root.
557    nodes: Vec<Node<F>>,
558    /// Number of features the model was trained on.
559    n_features: usize,
560    /// Per-feature importance scores (normalised to sum to 1).
561    feature_importances: Array1<F>,
562}
563
564impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for DecisionTreeRegressor<F> {
565    type Fitted = FittedDecisionTreeRegressor<F>;
566    type Error = FerroError;
567
568    /// Fit the decision tree regressor on the training data.
569    ///
570    /// # Errors
571    ///
572    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
573    /// numbers of samples.
574    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
575    /// Returns [`FerroError::InvalidParameter`] if hyperparameters are invalid.
576    fn fit(
577        &self,
578        x: &Array2<F>,
579        y: &Array1<F>,
580    ) -> Result<FittedDecisionTreeRegressor<F>, FerroError> {
581        let (n_samples, n_features) = x.dim();
582
583        if n_samples != y.len() {
584            return Err(FerroError::ShapeMismatch {
585                expected: vec![n_samples],
586                actual: vec![y.len()],
587                context: "y length must match number of samples in X".into(),
588            });
589        }
590        if n_samples == 0 {
591            return Err(FerroError::InsufficientSamples {
592                required: 1,
593                actual: 0,
594                context: "DecisionTreeRegressor requires at least one sample".into(),
595            });
596        }
597        if self.min_samples_split < 2 {
598            return Err(FerroError::InvalidParameter {
599                name: "min_samples_split".into(),
600                reason: "must be at least 2".into(),
601            });
602        }
603        if self.min_samples_leaf < 1 {
604            return Err(FerroError::InvalidParameter {
605                name: "min_samples_leaf".into(),
606                reason: "must be at least 1".into(),
607            });
608        }
609
610        let indices: Vec<usize> = (0..n_samples).collect();
611
612        let data = RegressionData {
613            x,
614            y,
615            feature_indices: None,
616            max_features_per_split: None,
617        };
618        let params = TreeParams {
619            max_depth: self.max_depth,
620            min_samples_split: self.min_samples_split,
621            min_samples_leaf: self.min_samples_leaf,
622        };
623
624        let mut nodes: Vec<Node<F>> = Vec::new();
625        build_regression_tree(&data, &indices, &mut nodes, 0, &params, None);
626
627        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);
628
629        Ok(FittedDecisionTreeRegressor {
630            nodes,
631            n_features,
632            feature_importances,
633        })
634    }
635}
636
637impl<F: Float + Send + Sync + 'static> FittedDecisionTreeRegressor<F> {
638    /// Returns a reference to the flat node storage of the tree.
639    #[must_use]
640    pub fn nodes(&self) -> &[Node<F>] {
641        &self.nodes
642    }
643
644    /// Returns the number of features the model was trained on.
645    #[must_use]
646    pub fn n_features(&self) -> usize {
647        self.n_features
648    }
649
650    /// R² coefficient of determination on the given test data.
651    /// Equivalent to sklearn's `RegressorMixin.score`.
652    ///
653    /// # Errors
654    ///
655    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
656    /// the feature count does not match the training data.
657    pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<F, FerroError> {
658        if x.nrows() != y.len() {
659            return Err(FerroError::ShapeMismatch {
660                expected: vec![x.nrows()],
661                actual: vec![y.len()],
662                context: "y length must match number of samples in X".into(),
663            });
664        }
665        let preds = self.predict(x)?;
666        Ok(crate::r2_score(&preds, y))
667    }
668}
669
670impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedDecisionTreeRegressor<F> {
671    type Output = Array1<F>;
672    type Error = FerroError;
673
674    /// Predict target values for the given feature matrix.
675    ///
676    /// # Errors
677    ///
678    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
679    /// not match the fitted model.
680    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
681        if x.ncols() != self.n_features {
682            return Err(FerroError::ShapeMismatch {
683                expected: vec![self.n_features],
684                actual: vec![x.ncols()],
685                context: "number of features must match fitted model".into(),
686            });
687        }
688        let n_samples = x.nrows();
689        let mut predictions = Array1::zeros(n_samples);
690        for i in 0..n_samples {
691            let row = x.row(i);
692            let leaf = traverse_tree(&self.nodes, &row);
693            if let Node::Leaf { value, .. } = self.nodes[leaf] {
694                predictions[i] = value;
695            }
696        }
697        Ok(predictions)
698    }
699}
700
701impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedDecisionTreeRegressor<F> {
702    fn feature_importances(&self) -> &Array1<F> {
703        &self.feature_importances
704    }
705}
706
707// Pipeline integration.
708impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for DecisionTreeRegressor<F> {
709    fn fit_pipeline(
710        &self,
711        x: &Array2<F>,
712        y: &Array1<F>,
713    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
714        let fitted = self.fit(x, y)?;
715        Ok(Box::new(fitted))
716    }
717}
718
719impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F>
720    for FittedDecisionTreeRegressor<F>
721{
722    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
723        self.predict(x)
724    }
725}
726
727// ---------------------------------------------------------------------------
728// Internal: tree building helpers
729// ---------------------------------------------------------------------------
730
731/// Traverse the tree from root to leaf for a single sample, returning the leaf node index.
732fn traverse_tree<F: Float>(nodes: &[Node<F>], sample: &ndarray::ArrayView1<F>) -> usize {
733    let mut idx = 0;
734    loop {
735        match &nodes[idx] {
736            Node::Split {
737                feature,
738                threshold,
739                left,
740                right,
741                ..
742            } => {
743                if sample[*feature] <= *threshold {
744                    idx = *left;
745                } else {
746                    idx = *right;
747                }
748            }
749            Node::Leaf { .. } => return idx,
750        }
751    }
752}
753
754/// Traverse a tree from root to leaf for a single sample (crate-public wrapper).
755///
756/// Returns the index of the leaf node in the flat node vector.
757pub(crate) fn traverse<F: Float>(nodes: &[Node<F>], sample: &ndarray::ArrayView1<F>) -> usize {
758    traverse_tree(nodes, sample)
759}
760
761/// Convert a `Float` value to `usize` (for class labels stored as floats).
762fn float_to_usize<F: Float>(v: F) -> usize {
763    v.to_f64().map_or(0, |f| f.round() as usize)
764}
765
766/// Compute the Gini impurity for a set of class counts.
767fn gini_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
768    if total == 0 {
769        return F::zero();
770    }
771    let total_f = F::from(total).unwrap();
772    let mut impurity = F::one();
773    for &count in class_counts {
774        let p = F::from(count).unwrap() / total_f;
775        impurity = impurity - p * p;
776    }
777    impurity
778}
779
780/// Compute the Shannon entropy for a set of class counts.
781fn entropy_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
782    if total == 0 {
783        return F::zero();
784    }
785    let total_f = F::from(total).unwrap();
786    let mut ent = F::zero();
787    for &count in class_counts {
788        if count > 0 {
789            let p = F::from(count).unwrap() / total_f;
790            ent = ent - p * p.ln();
791        }
792    }
793    ent
794}
795
796/// Compute the mean of target values for the given indices.
797fn mean_value<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
798    if indices.is_empty() {
799        return F::zero();
800    }
801    let sum: F = indices.iter().map(|&i| y[i]).fold(F::zero(), |a, b| a + b);
802    sum / F::from(indices.len()).unwrap()
803}
804
805/// Compute the MSE for the given indices relative to a given mean.
806fn mse_for_indices<F: Float>(y: &Array1<F>, indices: &[usize], mean: F) -> F {
807    if indices.is_empty() {
808        return F::zero();
809    }
810    let sum_sq: F = indices
811        .iter()
812        .map(|&i| {
813            let diff = y[i] - mean;
814            diff * diff
815        })
816        .fold(F::zero(), |a, b| a + b);
817    sum_sq / F::from(indices.len()).unwrap()
818}
819
820/// Compute impurity for a given classification criterion.
821fn compute_impurity<F: Float>(
822    class_counts: &[usize],
823    total: usize,
824    criterion: ClassificationCriterion,
825) -> F {
826    match criterion {
827        ClassificationCriterion::Gini => gini_impurity(class_counts, total),
828        ClassificationCriterion::Entropy => entropy_impurity(class_counts, total),
829    }
830}
831
832/// Create a classification leaf node and return its index.
833fn make_classification_leaf<F: Float>(
834    nodes: &mut Vec<Node<F>>,
835    class_counts: &[usize],
836    n_classes: usize,
837    n_samples: usize,
838) -> usize {
839    let majority_class = class_counts
840        .iter()
841        .enumerate()
842        .max_by_key(|&(_, &count)| count)
843        .map_or(0, |(i, _)| i);
844
845    let total_f = if n_samples > 0 {
846        F::from(n_samples).unwrap()
847    } else {
848        F::one()
849    };
850    let distribution: Vec<F> = (0..n_classes)
851        .map(|c| F::from(class_counts[c]).unwrap() / total_f)
852        .collect();
853
854    let idx = nodes.len();
855    nodes.push(Node::Leaf {
856        value: F::from(majority_class).unwrap(),
857        class_distribution: Some(distribution),
858        n_samples,
859    });
860    idx
861}
862
863/// Build a classification tree recursively.
864///
865/// Returns the index of the node that was created at the root of this subtree.
866fn build_classification_tree<F: Float>(
867    data: &ClassificationData<'_, F>,
868    indices: &[usize],
869    nodes: &mut Vec<Node<F>>,
870    depth: usize,
871    params: &TreeParams,
872    mut rng: Option<&mut StdRng>,
873) -> usize {
874    let n = indices.len();
875
876    let mut class_counts = vec![0usize; data.n_classes];
877    for &i in indices {
878        class_counts[data.y[i]] += 1;
879    }
880
881    let should_stop = n < params.min_samples_split
882        || params.max_depth.is_some_and(|d| depth >= d)
883        || class_counts.iter().filter(|&&c| c > 0).count() <= 1;
884
885    if should_stop {
886        return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
887    }
888
889    // Reborrow the rng for the split-finder; recursive children get fresh
890    // reborrows via `rng.as_deref_mut()` below.
891    let best = find_best_classification_split(
892        data,
893        indices,
894        params.min_samples_leaf,
895        rng.as_deref_mut(),
896    );
897
898    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
899        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
900            .iter()
901            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);
902
903        let node_idx = nodes.len();
904        nodes.push(Node::Leaf {
905            value: F::zero(),
906            class_distribution: None,
907            n_samples: 0,
908        }); // placeholder
909
910        let left_idx = build_classification_tree(
911            data,
912            &left_indices,
913            nodes,
914            depth + 1,
915            params,
916            rng.as_deref_mut(),
917        );
918        let right_idx = build_classification_tree(
919            data,
920            &right_indices,
921            nodes,
922            depth + 1,
923            params,
924            rng.as_deref_mut(),
925        );
926
927        nodes[node_idx] = Node::Split {
928            feature: best_feature,
929            threshold: best_threshold,
930            left: left_idx,
931            right: right_idx,
932            impurity_decrease: best_impurity_decrease,
933            n_samples: n,
934        };
935
936        node_idx
937    } else {
938        make_classification_leaf(nodes, &class_counts, data.n_classes, n)
939    }
940}
941
942/// Find the best split for a classification node.
943///
944/// Returns `(feature_index, threshold, weighted_impurity_decrease)` or `None`.
945///
946/// When `data.max_features_per_split` is set, `rng` must be `Some` and a fresh
947/// random subset of that many features is drawn for this single split (the
948/// per-split feature sampling used by Breiman 2001 RandomForest and
949/// scikit-learn). When `data.feature_indices` is set, the fixed per-tree
950/// subset is used instead. Otherwise all features are considered.
951fn find_best_classification_split<F: Float>(
952    data: &ClassificationData<'_, F>,
953    indices: &[usize],
954    min_samples_leaf: usize,
955    rng: Option<&mut StdRng>,
956) -> Option<(usize, F, F)> {
957    let n = indices.len();
958    let n_f = F::from(n).unwrap();
959    let n_features = data.x.ncols();
960
961    let mut parent_counts = vec![0usize; data.n_classes];
962    for &i in indices {
963        parent_counts[data.y[i]] += 1;
964    }
965    let parent_impurity = compute_impurity::<F>(&parent_counts, n, data.criterion);
966
967    let mut best_score = F::neg_infinity();
968    let mut best_feature = 0;
969    let mut best_threshold = F::zero();
970
971    // Build the candidate feature list for this split.
972    //
973    // Priority:
974    //   1. `max_features_per_split` — sample fresh subset using rng (Breiman RF).
975    //   2. `feature_indices`        — fixed per-tree subset (Bagging).
976    //   3. otherwise                — all features (plain DT).
977    let candidate_features: Vec<usize> = match (data.max_features_per_split, rng) {
978        (Some(k), Some(rng)) => {
979            let k = k.min(n_features).max(1);
980            rand_sample_indices(rng, n_features, k).into_vec()
981        }
982        _ => match data.feature_indices {
983            Some(feat) => feat.to_vec(),
984            None => (0..n_features).collect(),
985        },
986    };
987
988    for feat in candidate_features {
989        let mut sorted_indices: Vec<usize> = indices.to_vec();
990        sorted_indices.sort_by(|&a, &b| data.x[[a, feat]].partial_cmp(&data.x[[b, feat]]).unwrap());
991
992        let mut left_counts = vec![0usize; data.n_classes];
993        let mut right_counts = parent_counts.clone();
994        let mut left_n = 0usize;
995
996        for split_pos in 0..n - 1 {
997            let idx = sorted_indices[split_pos];
998            let cls = data.y[idx];
999            left_counts[cls] += 1;
1000            right_counts[cls] -= 1;
1001            left_n += 1;
1002            let right_n = n - left_n;
1003
1004            let next_idx = sorted_indices[split_pos + 1];
1005            if data.x[[idx, feat]] == data.x[[next_idx, feat]] {
1006                continue;
1007            }
1008
1009            if left_n < min_samples_leaf || right_n < min_samples_leaf {
1010                continue;
1011            }
1012
1013            let left_impurity = compute_impurity::<F>(&left_counts, left_n, data.criterion);
1014            let right_impurity = compute_impurity::<F>(&right_counts, right_n, data.criterion);
1015            let left_weight = F::from(left_n).unwrap() / n_f;
1016            let right_weight = F::from(right_n).unwrap() / n_f;
1017            let weighted_child_impurity =
1018                left_weight * left_impurity + right_weight * right_impurity;
1019            let impurity_decrease = parent_impurity - weighted_child_impurity;
1020
1021            if impurity_decrease > best_score {
1022                best_score = impurity_decrease;
1023                best_feature = feat;
1024                best_threshold =
1025                    (data.x[[idx, feat]] + data.x[[next_idx, feat]]) / F::from(2.0).unwrap();
1026            }
1027        }
1028    }
1029
1030    if best_score > F::zero() {
1031        Some((best_feature, best_threshold, best_score * n_f))
1032    } else {
1033        None
1034    }
1035}
1036
1037/// Build a regression tree recursively.
1038fn build_regression_tree<F: Float>(
1039    data: &RegressionData<'_, F>,
1040    indices: &[usize],
1041    nodes: &mut Vec<Node<F>>,
1042    depth: usize,
1043    params: &TreeParams,
1044    mut rng: Option<&mut StdRng>,
1045) -> usize {
1046    let n = indices.len();
1047    let mean = mean_value(data.y, indices);
1048
1049    let should_stop = n < params.min_samples_split || params.max_depth.is_some_and(|d| depth >= d);
1050
1051    if should_stop {
1052        let idx = nodes.len();
1053        nodes.push(Node::Leaf {
1054            value: mean,
1055            class_distribution: None,
1056            n_samples: n,
1057        });
1058        return idx;
1059    }
1060
1061    let parent_mse = mse_for_indices(data.y, indices, mean);
1062    if parent_mse <= F::epsilon() {
1063        let idx = nodes.len();
1064        nodes.push(Node::Leaf {
1065            value: mean,
1066            class_distribution: None,
1067            n_samples: n,
1068        });
1069        return idx;
1070    }
1071
1072    let best = find_best_regression_split(
1073        data,
1074        indices,
1075        params.min_samples_leaf,
1076        rng.as_deref_mut(),
1077    );
1078
1079    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
1080        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
1081            .iter()
1082            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);
1083
1084        let node_idx = nodes.len();
1085        nodes.push(Node::Leaf {
1086            value: F::zero(),
1087            class_distribution: None,
1088            n_samples: 0,
1089        }); // placeholder
1090
1091        let left_idx = build_regression_tree(
1092            data,
1093            &left_indices,
1094            nodes,
1095            depth + 1,
1096            params,
1097            rng.as_deref_mut(),
1098        );
1099        let right_idx = build_regression_tree(
1100            data,
1101            &right_indices,
1102            nodes,
1103            depth + 1,
1104            params,
1105            rng.as_deref_mut(),
1106        );
1107
1108        nodes[node_idx] = Node::Split {
1109            feature: best_feature,
1110            threshold: best_threshold,
1111            left: left_idx,
1112            right: right_idx,
1113            impurity_decrease: best_impurity_decrease,
1114            n_samples: n,
1115        };
1116
1117        node_idx
1118    } else {
1119        let idx = nodes.len();
1120        nodes.push(Node::Leaf {
1121            value: mean,
1122            class_distribution: None,
1123            n_samples: n,
1124        });
1125        idx
1126    }
1127}
1128
1129/// Find the best split for a regression node using MSE reduction.
1130///
1131/// Returns `(feature_index, threshold, weighted_mse_decrease)` or `None`.
1132///
1133/// See [`find_best_classification_split`] for the candidate-feature selection
1134/// rules (per-split sampling vs fixed subset vs all features).
1135fn find_best_regression_split<F: Float>(
1136    data: &RegressionData<'_, F>,
1137    indices: &[usize],
1138    min_samples_leaf: usize,
1139    rng: Option<&mut StdRng>,
1140) -> Option<(usize, F, F)> {
1141    let n = indices.len();
1142    let n_f = F::from(n).unwrap();
1143    let n_features = data.x.ncols();
1144
1145    let parent_sum: F = indices
1146        .iter()
1147        .map(|&i| data.y[i])
1148        .fold(F::zero(), |a, b| a + b);
1149    let parent_sum_sq: F = indices
1150        .iter()
1151        .map(|&i| data.y[i] * data.y[i])
1152        .fold(F::zero(), |a, b| a + b);
1153    let parent_mse = parent_sum_sq / n_f - (parent_sum / n_f) * (parent_sum / n_f);
1154
1155    let mut best_score = F::neg_infinity();
1156    let mut best_feature = 0;
1157    let mut best_threshold = F::zero();
1158
1159    let candidate_features: Vec<usize> = match (data.max_features_per_split, rng) {
1160        (Some(k), Some(rng)) => {
1161            let k = k.min(n_features).max(1);
1162            rand_sample_indices(rng, n_features, k).into_vec()
1163        }
1164        _ => match data.feature_indices {
1165            Some(feat) => feat.to_vec(),
1166            None => (0..n_features).collect(),
1167        },
1168    };
1169
1170    for feat in candidate_features {
1171        let mut sorted_indices: Vec<usize> = indices.to_vec();
1172        sorted_indices.sort_by(|&a, &b| data.x[[a, feat]].partial_cmp(&data.x[[b, feat]]).unwrap());
1173
1174        let mut left_sum = F::zero();
1175        let mut left_sum_sq = F::zero();
1176        let mut left_n: usize = 0;
1177
1178        for split_pos in 0..n - 1 {
1179            let idx = sorted_indices[split_pos];
1180            let val = data.y[idx];
1181            left_sum = left_sum + val;
1182            left_sum_sq = left_sum_sq + val * val;
1183            left_n += 1;
1184            let right_n = n - left_n;
1185
1186            let next_idx = sorted_indices[split_pos + 1];
1187            if data.x[[idx, feat]] == data.x[[next_idx, feat]] {
1188                continue;
1189            }
1190
1191            if left_n < min_samples_leaf || right_n < min_samples_leaf {
1192                continue;
1193            }
1194
1195            let left_n_f = F::from(left_n).unwrap();
1196            let right_n_f = F::from(right_n).unwrap();
1197
1198            let left_mean = left_sum / left_n_f;
1199            let left_mse = left_sum_sq / left_n_f - left_mean * left_mean;
1200
1201            let right_sum = parent_sum - left_sum;
1202            let right_sum_sq = parent_sum_sq - left_sum_sq;
1203            let right_mean = right_sum / right_n_f;
1204            let right_mse = right_sum_sq / right_n_f - right_mean * right_mean;
1205
1206            let weighted_child_mse = (left_n_f * left_mse + right_n_f * right_mse) / n_f;
1207            let mse_decrease = parent_mse - weighted_child_mse;
1208
1209            if mse_decrease > best_score {
1210                best_score = mse_decrease;
1211                best_feature = feat;
1212                best_threshold =
1213                    (data.x[[idx, feat]] + data.x[[next_idx, feat]]) / F::from(2.0).unwrap();
1214            }
1215        }
1216    }
1217
1218    if best_score > F::zero() {
1219        Some((best_feature, best_threshold, best_score * n_f))
1220    } else {
1221        None
1222    }
1223}
1224
1225/// Compute normalised feature importances from impurity decreases in the tree.
1226pub(crate) fn compute_feature_importances<F: Float>(
1227    nodes: &[Node<F>],
1228    n_features: usize,
1229    _total_samples: usize,
1230) -> Array1<F> {
1231    let mut importances = Array1::zeros(n_features);
1232    for node in nodes {
1233        if let Node::Split {
1234            feature,
1235            impurity_decrease,
1236            ..
1237        } = node
1238        {
1239            importances[*feature] = importances[*feature] + *impurity_decrease;
1240        }
1241    }
1242    let total: F = importances.iter().copied().fold(F::zero(), |a, b| a + b);
1243    if total > F::zero() {
1244        importances.mapv_inplace(|v| v / total);
1245    }
1246    importances
1247}
1248
1249/// Aggregate per-tree feature importances across an ensemble.
1250///
1251/// - `trees`: the per-tree node lists.
1252/// - `feature_indices`: when `Some`, each tree was trained on a feature
1253///   subset; the tree-local feature indices are remapped through
1254///   `feature_indices[t]` back to the original feature space. When `None`,
1255///   every tree uses the full feature space directly.
1256/// - `weights`: when `Some`, each tree's importances are scaled by
1257///   `weights[t]` before aggregation (used by AdaBoost). When `None`,
1258///   uniform weights of 1.
1259/// - `n_features`: width of the original feature space.
1260///
1261/// Returns an `Array1<F>` of length `n_features`, normalized to sum to 1
1262/// (or all zeros if no splits had any impurity decrease).
1263pub(crate) fn aggregate_tree_importances<F: Float>(
1264    trees: &[Vec<Node<F>>],
1265    feature_indices: Option<&[Vec<usize>]>,
1266    weights: Option<&[F]>,
1267    n_features: usize,
1268) -> Array1<F> {
1269    let mut total_imp = Array1::<F>::zeros(n_features);
1270    for (t, nodes) in trees.iter().enumerate() {
1271        let w = weights.map_or(F::one(), |ws| ws[t]);
1272        for node in nodes {
1273            if let Node::Split {
1274                feature,
1275                impurity_decrease,
1276                ..
1277            } = node
1278            {
1279                let original_feature = match feature_indices {
1280                    Some(map) => map[t][*feature],
1281                    None => *feature,
1282                };
1283                total_imp[original_feature] =
1284                    total_imp[original_feature] + w * *impurity_decrease;
1285            }
1286        }
1287    }
1288    let total: F = total_imp.iter().copied().fold(F::zero(), |a, b| a + b);
1289    if total > F::zero() {
1290        total_imp.mapv_inplace(|v| v / total);
1291    }
1292    total_imp
1293}
1294
1295// ---------------------------------------------------------------------------
1296// Public builders for forest usage
1297// ---------------------------------------------------------------------------
1298
1299/// Build a classification tree with a subset of features considered per split.
1300///
1301/// Used internally by `RandomForestClassifier` to build individual trees.
1302#[allow(clippy::too_many_arguments)]
1303pub(crate) fn build_classification_tree_with_feature_subset<F: Float>(
1304    x: &Array2<F>,
1305    y: &[usize],
1306    n_classes: usize,
1307    indices: &[usize],
1308    feature_indices: &[usize],
1309    params: &TreeParams,
1310    criterion: ClassificationCriterion,
1311) -> Vec<Node<F>> {
1312    let data = ClassificationData {
1313        x,
1314        y,
1315        n_classes,
1316        feature_indices: Some(feature_indices),
1317        max_features_per_split: None,
1318        criterion,
1319    };
1320    let mut nodes = Vec::new();
1321    build_classification_tree(&data, indices, &mut nodes, 0, params, None);
1322    nodes
1323}
1324
1325/// Build a classification tree with **per-split** random feature sampling.
1326///
1327/// At every split node, a fresh random subset of `max_features` features is
1328/// drawn from the full `0..n_features` pool. This is the Breiman (2001)
1329/// RandomForest behaviour and matches scikit-learn.
1330///
1331/// Used by `RandomForestClassifier` and `ExtraTreesClassifier`.
1332#[allow(clippy::too_many_arguments)]
1333pub(crate) fn build_classification_tree_per_split_features<F: Float>(
1334    x: &Array2<F>,
1335    y: &[usize],
1336    n_classes: usize,
1337    indices: &[usize],
1338    max_features: usize,
1339    params: &TreeParams,
1340    criterion: ClassificationCriterion,
1341    seed: u64,
1342) -> Vec<Node<F>> {
1343    let data = ClassificationData {
1344        x,
1345        y,
1346        n_classes,
1347        feature_indices: None,
1348        max_features_per_split: Some(max_features),
1349        criterion,
1350    };
1351    let mut rng = StdRng::seed_from_u64(seed);
1352    let mut nodes = Vec::new();
1353    build_classification_tree(&data, indices, &mut nodes, 0, params, Some(&mut rng));
1354    nodes
1355}
1356
1357/// Build a regression tree with a subset of features considered per split.
1358pub(crate) fn build_regression_tree_with_feature_subset<F: Float>(
1359    x: &Array2<F>,
1360    y: &Array1<F>,
1361    indices: &[usize],
1362    feature_indices: &[usize],
1363    params: &TreeParams,
1364) -> Vec<Node<F>> {
1365    let data = RegressionData {
1366        x,
1367        y,
1368        feature_indices: Some(feature_indices),
1369        max_features_per_split: None,
1370    };
1371    let mut nodes = Vec::new();
1372    build_regression_tree(&data, indices, &mut nodes, 0, params, None);
1373    nodes
1374}
1375
1376/// Build a regression tree with **per-split** random feature sampling
1377/// (Breiman 2001 RandomForest, sklearn-equivalent).
1378///
1379/// Used by `RandomForestRegressor` and `ExtraTreesRegressor`.
1380pub(crate) fn build_regression_tree_per_split_features<F: Float>(
1381    x: &Array2<F>,
1382    y: &Array1<F>,
1383    indices: &[usize],
1384    max_features: usize,
1385    params: &TreeParams,
1386    seed: u64,
1387) -> Vec<Node<F>> {
1388    let data = RegressionData {
1389        x,
1390        y,
1391        feature_indices: None,
1392        max_features_per_split: Some(max_features),
1393    };
1394    let mut rng = StdRng::seed_from_u64(seed);
1395    let mut nodes = Vec::new();
1396    build_regression_tree(&data, indices, &mut nodes, 0, params, Some(&mut rng));
1397    nodes
1398}
1399
1400// ---------------------------------------------------------------------------
1401// Tests
1402// ---------------------------------------------------------------------------
1403
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // Hand-built node helpers for the importance-aggregation tests below.

    /// A leaf holding `value`, reached by a single sample.
    fn leaf(value: f64) -> Node<f64> {
        Node::Leaf {
            value,
            class_distribution: None,
            n_samples: 1,
        }
    }

    /// A split on `feature` with the given children and impurity decrease.
    fn split(feature: usize, left: usize, right: usize, impurity_decrease: f64) -> Node<f64> {
        Node::Split {
            feature,
            threshold: 0.5,
            left,
            right,
            impurity_decrease,
            n_samples: 2,
        }
    }

    // -- Classifier tests --

    #[test]
    fn test_classifier_simple_binary() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        assert_eq!(preds, y);
    }

    #[test]
    fn test_classifier_single_class() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds, array![0, 0, 0]);
    }

    #[test]
    fn test_classifier_max_depth_1() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_max_depth(Some(1));
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // A single split is enough to separate the two halves.
        assert_eq!(preds, y);
    }

    #[test]
    fn test_classifier_min_samples_split() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_min_samples_split(7);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // The root cannot split, so every sample gets the same prediction.
        let majority = preds[0];
        for &p in &preds {
            assert_eq!(p, majority);
        }
    }

    #[test]
    fn test_classifier_min_samples_leaf() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_min_samples_leaf(4);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // Two leaves of >= 4 samples cannot fit in 6 samples: no split.
        let majority = preds[0];
        for &p in &preds {
            assert_eq!(p, majority);
        }
    }

    #[test]
    fn test_classifier_gini_vs_entropy() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, 5.0, 5.0, 5.0, 6.0, 6.0, 5.0, 6.0, 6.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let gini_model =
            DecisionTreeClassifier::<f64>::new().with_criterion(ClassificationCriterion::Gini);
        let entropy_model =
            DecisionTreeClassifier::<f64>::new().with_criterion(ClassificationCriterion::Entropy);

        let fitted_gini = gini_model.fit(&x, &y).unwrap();
        let fitted_entropy = entropy_model.fit(&x, &y).unwrap();

        let preds_gini = fitted_gini.predict(&x).unwrap();
        let preds_entropy = fitted_entropy.predict(&x).unwrap();

        assert_eq!(preds_gini, y);
        assert_eq!(preds_entropy, y);
    }

    #[test]
    fn test_classifier_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        assert_eq!(proba.dim(), (6, 2));
        for row in proba.rows() {
            let row_sum: f64 = row.iter().sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_classifier_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_classifier_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = DecisionTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_classifier_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        assert!(importances[0] > 0.0);
        let sum: f64 = importances.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_classifier_has_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1, 2, 0, 1, 2];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 1, 2]);
        assert_eq!(fitted.n_classes(), 3);
    }

    #[test]
    fn test_classifier_invalid_min_samples_split() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_min_samples_split(1);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_classifier_invalid_min_samples_leaf() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_min_samples_leaf(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_classifier_multiclass() {
        let x = Array2::from_shape_vec((9, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
            .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds, y);
    }

    #[test]
    fn test_classifier_pipeline_integration() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    // -- Regressor tests --

    #[test]
    fn test_regressor_simple() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_regressor_max_depth() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = DecisionTreeRegressor::<f64>::new().with_max_depth(Some(1));
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // One split recovers each half's constant target exactly.
        for (p, &expected) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_regressor_constant_target() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 3.0, 3.0, 3.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for &p in &preds {
            assert_relative_eq!(p, 3.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_regressor_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = DecisionTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_regressor_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = DecisionTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_regressor_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        assert!(importances[0] > 0.0);
        let sum: f64 = importances.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_regressor_min_samples_split() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = DecisionTreeRegressor::<f64>::new().with_min_samples_split(5);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let mean = 2.5;
        for &p in &preds {
            assert_relative_eq!(p, mean, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_regressor_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_regressor_f32_support() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);

        let model = DecisionTreeRegressor::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_classifier_f32_support() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    // -- Internal helper tests --

    #[test]
    fn test_gini_impurity_pure() {
        let counts = vec![5, 0];
        let imp: f64 = gini_impurity(&counts, 5);
        assert_relative_eq!(imp, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_gini_impurity_balanced() {
        let counts = vec![5, 5];
        let imp: f64 = gini_impurity(&counts, 10);
        assert_relative_eq!(imp, 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_entropy_pure() {
        let counts = vec![5, 0];
        let ent: f64 = entropy_impurity(&counts, 5);
        assert_relative_eq!(ent, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_entropy_balanced() {
        let counts = vec![5, 5];
        let ent: f64 = entropy_impurity(&counts, 10);
        assert_relative_eq!(ent, 2.0f64.ln(), epsilon = 1e-10);
    }

    #[test]
    fn test_compute_impurity_dispatches_on_criterion() {
        let counts = [5usize, 5];
        let gini: f64 = compute_impurity(&counts, 10, ClassificationCriterion::Gini);
        let entropy: f64 = compute_impurity(&counts, 10, ClassificationCriterion::Entropy);
        assert_relative_eq!(gini, 0.5, epsilon = 1e-10);
        assert_relative_eq!(entropy, 2.0f64.ln(), epsilon = 1e-10);
    }

    #[test]
    fn test_compute_feature_importances_normalizes() {
        // Root splits feature 1 (decrease 3.0); its left child splits feature 0 (1.0).
        let nodes = vec![
            split(1, 1, 4, 3.0),
            split(0, 2, 3, 1.0),
            leaf(0.0),
            leaf(1.0),
            leaf(2.0),
        ];
        let imp = compute_feature_importances(&nodes, 3, 5);
        assert_relative_eq!(imp[0], 0.25, epsilon = 1e-12);
        assert_relative_eq!(imp[1], 0.75, epsilon = 1e-12);
        assert_relative_eq!(imp[2], 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_compute_feature_importances_no_splits_is_zero() {
        let nodes = vec![leaf(1.0)];
        let imp = compute_feature_importances(&nodes, 2, 1);
        assert_eq!(imp, array![0.0, 0.0]);
    }

    #[test]
    fn test_aggregate_tree_importances_remaps_and_weights() {
        // Tree 0 splits local feature 0, which maps to original feature 2.
        // Tree 1 splits local feature 1, which maps to original feature 0.
        let trees = vec![
            vec![split(0, 1, 2, 2.0), leaf(0.0), leaf(1.0)],
            vec![split(1, 1, 2, 3.0), leaf(0.0), leaf(1.0)],
        ];
        let maps = vec![vec![2, 1], vec![1, 0]];
        let weights = [1.0, 2.0];
        let imp =
            aggregate_tree_importances(&trees, Some(maps.as_slice()), Some(&weights[..]), 3);
        // Raw totals: feature 2 -> 1.0 * 2.0, feature 0 -> 2.0 * 3.0; sum 8.0.
        assert_relative_eq!(imp[0], 0.75, epsilon = 1e-12);
        assert_relative_eq!(imp[1], 0.0, epsilon = 1e-12);
        assert_relative_eq!(imp[2], 0.25, epsilon = 1e-12);
    }

    #[test]
    fn test_aggregate_tree_importances_uniform_full_space() {
        let trees = vec![
            vec![split(0, 1, 2, 1.0), leaf(0.0), leaf(1.0)],
            vec![split(1, 1, 2, 1.0), leaf(0.0), leaf(1.0)],
        ];
        let imp = aggregate_tree_importances(&trees, None, None, 2);
        assert_relative_eq!(imp[0], 0.5, epsilon = 1e-12);
        assert_relative_eq!(imp[1], 0.5, epsilon = 1e-12);
    }
}