//! CART decision tree classifiers and regressors.
//!
//! This module provides [`DecisionTreeClassifier`] and [`DecisionTreeRegressor`],
//! implementing the Classification and Regression Trees (CART) algorithm with
//! configurable splitting criteria, depth limits, and minimum sample constraints.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_tree::DecisionTreeClassifier;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array1, Array2};
//!
//! let x = Array2::from_shape_vec((6, 2), vec![
//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,
//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,
//! ]).unwrap();
//! let y = array![0, 0, 0, 1, 1, 1];
//!
//! let model = DecisionTreeClassifier::<f64>::new();
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! ```
25use ferrolearn_core::error::FerroError;
26use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
27use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
28use ferrolearn_core::traits::{Fit, Predict};
29use ndarray::{Array1, Array2};
30use num_traits::{Float, FromPrimitive, ToPrimitive};
31use serde::{Deserialize, Serialize};
32
33// ---------------------------------------------------------------------------
34// Splitting criterion enums
35// ---------------------------------------------------------------------------
36
/// Splitting criterion for classification trees.
///
/// Both criteria score candidate splits by the weighted impurity decrease;
/// they usually select similar splits but can differ on skewed distributions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ClassificationCriterion {
    /// Gini impurity: `1 - sum(p_c^2)` over class proportions `p_c`.
    Gini,
    /// Shannon entropy: `-sum(p_c * ln(p_c))` (natural-log base in this crate).
    Entropy,
}
45
/// Splitting criterion for regression trees.
///
/// Currently only MSE is supported; the enum exists so additional criteria
/// (e.g. MAE) can be added without breaking the public API.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RegressionCriterion {
    /// Mean squared error (variance of targets within a node).
    Mse,
}
52
53// ---------------------------------------------------------------------------
54// Node representation (flat vec for cache efficiency)
55// ---------------------------------------------------------------------------
56
/// A single node in the decision tree, stored in a flat `Vec` for cache efficiency.
///
/// Internal nodes hold a split (feature index + threshold), while leaf nodes
/// store a prediction value and optional class distribution (for classifiers).
/// `left`/`right` are indices into the same flat vector; index 0 is the root.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Node<F> {
    /// An internal split node.
    Split {
        /// Feature index used for the split.
        feature: usize,
        /// Threshold value; samples with `x[feature] <= threshold` go left.
        threshold: F,
        /// Index of the left child node in the flat vec.
        left: usize,
        /// Index of the right child node in the flat vec.
        right: usize,
        /// Weighted impurity decrease from this split (for feature importance).
        impurity_decrease: F,
        /// Number of samples that reached this node during training.
        n_samples: usize,
    },
    /// A leaf node that stores a prediction.
    Leaf {
        /// Predicted value: class index (as F) for classifiers, mean for regressors.
        value: F,
        /// Class distribution (proportion of each class). Only used by classifiers;
        /// `None` for regressor leaves and internal placeholders.
        class_distribution: Option<Vec<F>>,
        /// Number of samples that reached this node during training.
        n_samples: usize,
    },
}
88
89// ---------------------------------------------------------------------------
90// Internal config structs (to reduce argument counts)
91// ---------------------------------------------------------------------------
92
/// Configuration parameters for tree building, bundled to reduce argument counts.
#[derive(Debug, Clone, Copy)]
pub(crate) struct TreeParams {
    /// Maximum tree depth; `None` means growth is unbounded.
    pub(crate) max_depth: Option<usize>,
    /// Minimum number of samples a node must hold to be considered for splitting.
    pub(crate) min_samples_split: usize,
    /// Minimum number of samples each child must receive for a split to be valid.
    pub(crate) min_samples_leaf: usize,
}
100
/// Borrowed data bundle for classification tree building.
struct ClassificationData<'a, F> {
    /// Feature matrix, shape `(n_samples, n_features)`.
    x: &'a Array2<F>,
    /// Targets already remapped to contiguous class indices `0..n_classes`.
    y: &'a [usize],
    /// Number of distinct classes.
    n_classes: usize,
    /// Optional subset of feature columns to consider; `None` means all
    /// features. NOTE(review): no caller in this file passes `Some` —
    /// presumably used by ensemble code elsewhere in the crate; confirm.
    feature_indices: Option<&'a [usize]>,
    /// Impurity criterion used to score candidate splits.
    criterion: ClassificationCriterion,
}
109
/// Borrowed data bundle for regression tree building.
struct RegressionData<'a, F> {
    /// Feature matrix, shape `(n_samples, n_features)`.
    x: &'a Array2<F>,
    /// Continuous target values, one per sample.
    y: &'a Array1<F>,
    /// Optional subset of feature columns to consider; `None` means all
    /// features. NOTE(review): no caller in this file passes `Some` — confirm.
    feature_indices: Option<&'a [usize]>,
}
116
117// ---------------------------------------------------------------------------
118// DecisionTreeClassifier
119// ---------------------------------------------------------------------------
120
/// CART decision tree classifier.
///
/// Builds a binary tree by recursively finding the feature and threshold that
/// maximises the reduction in the chosen impurity criterion (Gini or Entropy).
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecisionTreeClassifier<F> {
    /// Maximum depth of the tree. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Splitting criterion.
    pub criterion: ClassificationCriterion,
    // Ties the unused type parameter `F` (the float type used at fit time)
    // to the struct without storing a value of it.
    _marker: std::marker::PhantomData<F>,
}
141
142impl<F: Float> DecisionTreeClassifier<F> {
143    /// Create a new `DecisionTreeClassifier` with default settings.
144    ///
145    /// Defaults: `max_depth = None`, `min_samples_split = 2`,
146    /// `min_samples_leaf = 1`, `criterion = Gini`.
147    #[must_use]
148    pub fn new() -> Self {
149        Self {
150            max_depth: None,
151            min_samples_split: 2,
152            min_samples_leaf: 1,
153            criterion: ClassificationCriterion::Gini,
154            _marker: std::marker::PhantomData,
155        }
156    }
157
158    /// Set the maximum tree depth.
159    #[must_use]
160    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
161        self.max_depth = max_depth;
162        self
163    }
164
165    /// Set the minimum number of samples required to split a node.
166    #[must_use]
167    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
168        self.min_samples_split = min_samples_split;
169        self
170    }
171
172    /// Set the minimum number of samples required in a leaf node.
173    #[must_use]
174    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
175        self.min_samples_leaf = min_samples_leaf;
176        self
177    }
178
179    /// Set the splitting criterion.
180    #[must_use]
181    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
182        self.criterion = criterion;
183        self
184    }
185}
186
impl<F: Float> Default for DecisionTreeClassifier<F> {
    /// Equivalent to [`DecisionTreeClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
192
193// ---------------------------------------------------------------------------
194// FittedDecisionTreeClassifier
195// ---------------------------------------------------------------------------
196
/// A fitted CART decision tree classifier.
///
/// Stores the learned tree as a flat `Vec<Node<F>>` for cache-friendly traversal.
/// Implements [`Predict`] for generating class predictions and
/// [`HasFeatureImportances`] for inspecting per-feature importance scores.
#[derive(Debug, Clone)]
pub struct FittedDecisionTreeClassifier<F> {
    /// Flat node storage; index 0 is the root.
    nodes: Vec<Node<F>>,
    /// Sorted unique class labels observed during training. Leaf `value`s are
    /// indices into this list.
    classes: Vec<usize>,
    /// Number of features the model was trained on.
    n_features: usize,
    /// Per-feature importance scores (normalised to sum to 1).
    feature_importances: Array1<F>,
}
213
214impl<F: Float + Send + Sync + 'static> FittedDecisionTreeClassifier<F> {
215    /// Returns a reference to the flat node storage of the tree.
216    #[must_use]
217    pub fn nodes(&self) -> &[Node<F>] {
218        &self.nodes
219    }
220
221    /// Returns the number of features the model was trained on.
222    #[must_use]
223    pub fn n_features(&self) -> usize {
224        self.n_features
225    }
226
227    /// Predict class probabilities for each sample.
228    ///
229    /// Returns a 2-D array of shape `(n_samples, n_classes)`.
230    ///
231    /// # Errors
232    ///
233    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
234    /// not match the training data.
235    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
236        if x.ncols() != self.n_features {
237            return Err(FerroError::ShapeMismatch {
238                expected: vec![self.n_features],
239                actual: vec![x.ncols()],
240                context: "number of features must match fitted model".into(),
241            });
242        }
243        let n_samples = x.nrows();
244        let n_classes = self.classes.len();
245        let mut proba = Array2::zeros((n_samples, n_classes));
246        for i in 0..n_samples {
247            let row = x.row(i);
248            let leaf = traverse_tree(&self.nodes, &row);
249            if let Node::Leaf {
250                class_distribution: Some(ref dist),
251                ..
252            } = self.nodes[leaf]
253            {
254                for (j, &p) in dist.iter().enumerate() {
255                    proba[[i, j]] = p;
256                }
257            }
258        }
259        Ok(proba)
260    }
261}
262
263impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for DecisionTreeClassifier<F> {
264    type Fitted = FittedDecisionTreeClassifier<F>;
265    type Error = FerroError;
266
267    /// Fit the decision tree classifier on the training data.
268    ///
269    /// # Errors
270    ///
271    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
272    /// numbers of samples.
273    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
274    /// Returns [`FerroError::InvalidParameter`] if hyperparameters are invalid.
275    fn fit(
276        &self,
277        x: &Array2<F>,
278        y: &Array1<usize>,
279    ) -> Result<FittedDecisionTreeClassifier<F>, FerroError> {
280        let (n_samples, n_features) = x.dim();
281
282        if n_samples != y.len() {
283            return Err(FerroError::ShapeMismatch {
284                expected: vec![n_samples],
285                actual: vec![y.len()],
286                context: "y length must match number of samples in X".into(),
287            });
288        }
289        if n_samples == 0 {
290            return Err(FerroError::InsufficientSamples {
291                required: 1,
292                actual: 0,
293                context: "DecisionTreeClassifier requires at least one sample".into(),
294            });
295        }
296        if self.min_samples_split < 2 {
297            return Err(FerroError::InvalidParameter {
298                name: "min_samples_split".into(),
299                reason: "must be at least 2".into(),
300            });
301        }
302        if self.min_samples_leaf < 1 {
303            return Err(FerroError::InvalidParameter {
304                name: "min_samples_leaf".into(),
305                reason: "must be at least 1".into(),
306            });
307        }
308
309        // Determine unique classes.
310        let mut classes: Vec<usize> = y.iter().copied().collect();
311        classes.sort_unstable();
312        classes.dedup();
313        let n_classes = classes.len();
314
315        // Map class labels to indices 0..n_classes.
316        let y_mapped: Vec<usize> = y
317            .iter()
318            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
319            .collect();
320
321        let indices: Vec<usize> = (0..n_samples).collect();
322
323        let data = ClassificationData {
324            x,
325            y: &y_mapped,
326            n_classes,
327            feature_indices: None,
328            criterion: self.criterion,
329        };
330        let params = TreeParams {
331            max_depth: self.max_depth,
332            min_samples_split: self.min_samples_split,
333            min_samples_leaf: self.min_samples_leaf,
334        };
335
336        let mut nodes: Vec<Node<F>> = Vec::new();
337        build_classification_tree(&data, &indices, &mut nodes, 0, &params);
338
339        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);
340
341        Ok(FittedDecisionTreeClassifier {
342            nodes,
343            classes,
344            n_features,
345            feature_importances,
346        })
347    }
348}
349
350impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedDecisionTreeClassifier<F> {
351    type Output = Array1<usize>;
352    type Error = FerroError;
353
354    /// Predict class labels for the given feature matrix.
355    ///
356    /// # Errors
357    ///
358    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
359    /// not match the fitted model.
360    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
361        if x.ncols() != self.n_features {
362            return Err(FerroError::ShapeMismatch {
363                expected: vec![self.n_features],
364                actual: vec![x.ncols()],
365                context: "number of features must match fitted model".into(),
366            });
367        }
368        let n_samples = x.nrows();
369        let mut predictions = Array1::zeros(n_samples);
370        for i in 0..n_samples {
371            let row = x.row(i);
372            let leaf = traverse_tree(&self.nodes, &row);
373            if let Node::Leaf { value, .. } = self.nodes[leaf] {
374                predictions[i] = float_to_usize(value);
375            }
376        }
377        Ok(predictions)
378    }
379}
380
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
    for FittedDecisionTreeClassifier<F>
{
    /// Per-feature importance scores computed during fitting.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
388
impl<F: Float + Send + Sync + 'static> HasClasses for FittedDecisionTreeClassifier<F> {
    /// Sorted unique class labels observed during training.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
398
399// Pipeline integration.
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for DecisionTreeClassifier<F>
{
    /// Fit within a pipeline, converting float targets back to class labels.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        // Pipelines carry targets as floats; recover integer labels via
        // `to_usize`. NOTE(review): negative or non-finite targets silently
        // map to class 0 here — confirm this fallback is intentional.
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedClassifierPipelineAdapter(fitted)))
    }
}
413
/// Adapter to make `FittedDecisionTreeClassifier<F>` work as a pipeline estimator
/// (newtype wrapper so predictions can be exposed as floats).
struct FittedClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedDecisionTreeClassifier<F>,
);
418
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedClassifierPipelineAdapter<F>
{
    /// Predict class labels and re-encode them as floats for the pipeline.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        // Labels not representable in F become NaN rather than a wrong value.
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
    }
}
427
428// ---------------------------------------------------------------------------
429// DecisionTreeRegressor
430// ---------------------------------------------------------------------------
431
/// CART decision tree regressor.
///
/// Builds a binary tree by recursively finding the feature and threshold that
/// minimises the mean squared error of the split.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecisionTreeRegressor<F> {
    /// Maximum depth of the tree. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Splitting criterion.
    pub criterion: RegressionCriterion,
    // Ties the unused type parameter `F` to the struct without storing a value.
    _marker: std::marker::PhantomData<F>,
}
452
453impl<F: Float> DecisionTreeRegressor<F> {
454    /// Create a new `DecisionTreeRegressor` with default settings.
455    ///
456    /// Defaults: `max_depth = None`, `min_samples_split = 2`,
457    /// `min_samples_leaf = 1`, `criterion = MSE`.
458    #[must_use]
459    pub fn new() -> Self {
460        Self {
461            max_depth: None,
462            min_samples_split: 2,
463            min_samples_leaf: 1,
464            criterion: RegressionCriterion::Mse,
465            _marker: std::marker::PhantomData,
466        }
467    }
468
469    /// Set the maximum tree depth.
470    #[must_use]
471    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
472        self.max_depth = max_depth;
473        self
474    }
475
476    /// Set the minimum number of samples required to split a node.
477    #[must_use]
478    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
479        self.min_samples_split = min_samples_split;
480        self
481    }
482
483    /// Set the minimum number of samples required in a leaf node.
484    #[must_use]
485    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
486        self.min_samples_leaf = min_samples_leaf;
487        self
488    }
489
490    /// Set the splitting criterion.
491    #[must_use]
492    pub fn with_criterion(mut self, criterion: RegressionCriterion) -> Self {
493        self.criterion = criterion;
494        self
495    }
496}
497
impl<F: Float> Default for DecisionTreeRegressor<F> {
    /// Equivalent to [`DecisionTreeRegressor::new`].
    fn default() -> Self {
        Self::new()
    }
}
503
504// ---------------------------------------------------------------------------
505// FittedDecisionTreeRegressor
506// ---------------------------------------------------------------------------
507
/// A fitted CART decision tree regressor.
///
/// Stores the learned tree as a flat `Vec<Node<F>>` for cache-friendly traversal.
#[derive(Debug, Clone)]
pub struct FittedDecisionTreeRegressor<F> {
    /// Flat node storage; index 0 is the root.
    nodes: Vec<Node<F>>,
    /// Number of features the model was trained on.
    n_features: usize,
    /// Per-feature importance scores (normalised to sum to 1).
    feature_importances: Array1<F>,
}
520
521impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for DecisionTreeRegressor<F> {
522    type Fitted = FittedDecisionTreeRegressor<F>;
523    type Error = FerroError;
524
525    /// Fit the decision tree regressor on the training data.
526    ///
527    /// # Errors
528    ///
529    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
530    /// numbers of samples.
531    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
532    /// Returns [`FerroError::InvalidParameter`] if hyperparameters are invalid.
533    fn fit(
534        &self,
535        x: &Array2<F>,
536        y: &Array1<F>,
537    ) -> Result<FittedDecisionTreeRegressor<F>, FerroError> {
538        let (n_samples, n_features) = x.dim();
539
540        if n_samples != y.len() {
541            return Err(FerroError::ShapeMismatch {
542                expected: vec![n_samples],
543                actual: vec![y.len()],
544                context: "y length must match number of samples in X".into(),
545            });
546        }
547        if n_samples == 0 {
548            return Err(FerroError::InsufficientSamples {
549                required: 1,
550                actual: 0,
551                context: "DecisionTreeRegressor requires at least one sample".into(),
552            });
553        }
554        if self.min_samples_split < 2 {
555            return Err(FerroError::InvalidParameter {
556                name: "min_samples_split".into(),
557                reason: "must be at least 2".into(),
558            });
559        }
560        if self.min_samples_leaf < 1 {
561            return Err(FerroError::InvalidParameter {
562                name: "min_samples_leaf".into(),
563                reason: "must be at least 1".into(),
564            });
565        }
566
567        let indices: Vec<usize> = (0..n_samples).collect();
568
569        let data = RegressionData {
570            x,
571            y,
572            feature_indices: None,
573        };
574        let params = TreeParams {
575            max_depth: self.max_depth,
576            min_samples_split: self.min_samples_split,
577            min_samples_leaf: self.min_samples_leaf,
578        };
579
580        let mut nodes: Vec<Node<F>> = Vec::new();
581        build_regression_tree(&data, &indices, &mut nodes, 0, &params);
582
583        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);
584
585        Ok(FittedDecisionTreeRegressor {
586            nodes,
587            n_features,
588            feature_importances,
589        })
590    }
591}
592
593impl<F: Float + Send + Sync + 'static> FittedDecisionTreeRegressor<F> {
594    /// Returns a reference to the flat node storage of the tree.
595    #[must_use]
596    pub fn nodes(&self) -> &[Node<F>] {
597        &self.nodes
598    }
599
600    /// Returns the number of features the model was trained on.
601    #[must_use]
602    pub fn n_features(&self) -> usize {
603        self.n_features
604    }
605}
606
607impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedDecisionTreeRegressor<F> {
608    type Output = Array1<F>;
609    type Error = FerroError;
610
611    /// Predict target values for the given feature matrix.
612    ///
613    /// # Errors
614    ///
615    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
616    /// not match the fitted model.
617    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
618        if x.ncols() != self.n_features {
619            return Err(FerroError::ShapeMismatch {
620                expected: vec![self.n_features],
621                actual: vec![x.ncols()],
622                context: "number of features must match fitted model".into(),
623            });
624        }
625        let n_samples = x.nrows();
626        let mut predictions = Array1::zeros(n_samples);
627        for i in 0..n_samples {
628            let row = x.row(i);
629            let leaf = traverse_tree(&self.nodes, &row);
630            if let Node::Leaf { value, .. } = self.nodes[leaf] {
631                predictions[i] = value;
632            }
633        }
634        Ok(predictions)
635    }
636}
637
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedDecisionTreeRegressor<F> {
    /// Per-feature importance scores computed during fitting.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
643
644// Pipeline integration.
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for DecisionTreeRegressor<F> {
    /// Fit within a pipeline; regression targets are already floats, so the
    /// fitted model is returned directly (no adapter needed).
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}
655
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedDecisionTreeRegressor<F>
{
    /// Delegates directly to [`Predict::predict`].
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
663
664// ---------------------------------------------------------------------------
665// Internal: tree building helpers
666// ---------------------------------------------------------------------------
667
668/// Traverse the tree from root to leaf for a single sample, returning the leaf node index.
669fn traverse_tree<F: Float>(nodes: &[Node<F>], sample: &ndarray::ArrayView1<F>) -> usize {
670    let mut idx = 0;
671    loop {
672        match &nodes[idx] {
673            Node::Split {
674                feature,
675                threshold,
676                left,
677                right,
678                ..
679            } => {
680                if sample[*feature] <= *threshold {
681                    idx = *left;
682                } else {
683                    idx = *right;
684                }
685            }
686            Node::Leaf { .. } => return idx,
687        }
688    }
689}
690
/// Traverse a tree from root to leaf for a single sample (crate-public wrapper).
///
/// Returns the index of the leaf node in the flat node vector. Exists so that
/// sibling modules in the crate can reuse the private traversal logic.
pub(crate) fn traverse<F: Float>(nodes: &[Node<F>], sample: &ndarray::ArrayView1<F>) -> usize {
    traverse_tree(nodes, sample)
}
697
698/// Convert a `Float` value to `usize` (for class labels stored as floats).
699fn float_to_usize<F: Float>(v: F) -> usize {
700    v.to_f64().map(|f| f.round() as usize).unwrap_or(0)
701}
702
703/// Compute the Gini impurity for a set of class counts.
704fn gini_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
705    if total == 0 {
706        return F::zero();
707    }
708    let total_f = F::from(total).unwrap();
709    let mut impurity = F::one();
710    for &count in class_counts {
711        let p = F::from(count).unwrap() / total_f;
712        impurity = impurity - p * p;
713    }
714    impurity
715}
716
717/// Compute the Shannon entropy for a set of class counts.
718fn entropy_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
719    if total == 0 {
720        return F::zero();
721    }
722    let total_f = F::from(total).unwrap();
723    let mut ent = F::zero();
724    for &count in class_counts {
725        if count > 0 {
726            let p = F::from(count).unwrap() / total_f;
727            ent = ent - p * p.ln();
728        }
729    }
730    ent
731}
732
733/// Compute the mean of target values for the given indices.
734fn mean_value<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
735    if indices.is_empty() {
736        return F::zero();
737    }
738    let sum: F = indices.iter().map(|&i| y[i]).fold(F::zero(), |a, b| a + b);
739    sum / F::from(indices.len()).unwrap()
740}
741
742/// Compute the MSE for the given indices relative to a given mean.
743fn mse_for_indices<F: Float>(y: &Array1<F>, indices: &[usize], mean: F) -> F {
744    if indices.is_empty() {
745        return F::zero();
746    }
747    let sum_sq: F = indices
748        .iter()
749        .map(|&i| {
750            let diff = y[i] - mean;
751            diff * diff
752        })
753        .fold(F::zero(), |a, b| a + b);
754    sum_sq / F::from(indices.len()).unwrap()
755}
756
/// Compute impurity for a given classification criterion.
///
/// Dispatches to [`gini_impurity`] or [`entropy_impurity`].
fn compute_impurity<F: Float>(
    class_counts: &[usize],
    total: usize,
    criterion: ClassificationCriterion,
) -> F {
    match criterion {
        ClassificationCriterion::Gini => gini_impurity(class_counts, total),
        ClassificationCriterion::Entropy => entropy_impurity(class_counts, total),
    }
}
768
769/// Create a classification leaf node and return its index.
770fn make_classification_leaf<F: Float>(
771    nodes: &mut Vec<Node<F>>,
772    class_counts: &[usize],
773    n_classes: usize,
774    n_samples: usize,
775) -> usize {
776    let majority_class = class_counts
777        .iter()
778        .enumerate()
779        .max_by_key(|&(_, &count)| count)
780        .map(|(i, _)| i)
781        .unwrap_or(0);
782
783    let total_f = if n_samples > 0 {
784        F::from(n_samples).unwrap()
785    } else {
786        F::one()
787    };
788    let distribution: Vec<F> = (0..n_classes)
789        .map(|c| F::from(class_counts[c]).unwrap() / total_f)
790        .collect();
791
792    let idx = nodes.len();
793    nodes.push(Node::Leaf {
794        value: F::from(majority_class).unwrap(),
795        class_distribution: Some(distribution),
796        n_samples,
797    });
798    idx
799}
800
/// Build a classification tree recursively.
///
/// Appends the nodes of this subtree to `nodes` and returns the index of the
/// subtree's root node. `indices` selects the training rows that reached this
/// node; `depth` is the current recursion depth (root = 0).
fn build_classification_tree<F: Float>(
    data: &ClassificationData<'_, F>,
    indices: &[usize],
    nodes: &mut Vec<Node<F>>,
    depth: usize,
    params: &TreeParams,
) -> usize {
    let n = indices.len();

    // Histogram of class indices among the samples at this node.
    let mut class_counts = vec![0usize; data.n_classes];
    for &i in indices {
        class_counts[data.y[i]] += 1;
    }

    // Stop if the node is too small to split, the depth limit is reached,
    // or the node is already pure (at most one class present).
    let should_stop = n < params.min_samples_split
        || params.max_depth.is_some_and(|d| depth >= d)
        || class_counts.iter().filter(|&&c| c > 0).count() <= 1;

    if should_stop {
        return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
    }

    let best = find_best_classification_split(data, indices, params.min_samples_leaf);

    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
        // Samples with x[feature] <= threshold go left, matching traversal.
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
            .iter()
            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);

        // Reserve this node's slot *before* recursing so the children are
        // appended after it; the placeholder leaf is overwritten below once
        // the child indices are known.
        let node_idx = nodes.len();
        nodes.push(Node::Leaf {
            value: F::zero(),
            class_distribution: None,
            n_samples: 0,
        }); // placeholder

        let left_idx = build_classification_tree(data, &left_indices, nodes, depth + 1, params);
        let right_idx = build_classification_tree(data, &right_indices, nodes, depth + 1, params);

        nodes[node_idx] = Node::Split {
            feature: best_feature,
            threshold: best_threshold,
            left: left_idx,
            right: right_idx,
            impurity_decrease: best_impurity_decrease,
            n_samples: n,
        };

        node_idx
    } else {
        // No split strictly reduced impurity; emit a majority-vote leaf.
        make_classification_leaf(nodes, &class_counts, data.n_classes, n)
    }
}
857
/// Find the best split for a classification node.
///
/// For each candidate feature, sorts the node's samples by that feature and
/// sweeps every split position while maintaining the left/right class counts
/// incrementally (O(n) per feature after the O(n log n) sort).
///
/// Returns `(feature_index, threshold, impurity_decrease * n)` for the best
/// split, or `None` when no split strictly reduces impurity.
fn find_best_classification_split<F: Float>(
    data: &ClassificationData<'_, F>,
    indices: &[usize],
    min_samples_leaf: usize,
) -> Option<(usize, F, F)> {
    let n = indices.len();
    let n_f = F::from(n).unwrap();
    let n_features = data.x.ncols();

    // Class histogram for the whole node; the sweep peels samples off it.
    let mut parent_counts = vec![0usize; data.n_classes];
    for &i in indices {
        parent_counts[data.y[i]] += 1;
    }
    let parent_impurity = compute_impurity::<F>(&parent_counts, n, data.criterion);

    let mut best_score = F::neg_infinity();
    let mut best_feature = 0;
    let mut best_threshold = F::zero();

    // Iterate over either specified feature subset or all features.
    let feature_iter: Box<dyn Iterator<Item = usize>> =
        if let Some(feat_indices) = data.feature_indices {
            Box::new(feat_indices.iter().copied())
        } else {
            Box::new(0..n_features)
        };

    for feat in feature_iter {
        // NOTE(review): `partial_cmp(..).unwrap()` panics if a feature value
        // is NaN — confirm inputs are validated upstream.
        let mut sorted_indices: Vec<usize> = indices.to_vec();
        sorted_indices.sort_by(|&a, &b| data.x[[a, feat]].partial_cmp(&data.x[[b, feat]]).unwrap());

        let mut left_counts = vec![0usize; data.n_classes];
        let mut right_counts = parent_counts.clone();
        let mut left_n = 0usize;

        for split_pos in 0..n - 1 {
            // Move one sample from the right partition to the left, updating
            // both histograms in O(1).
            let idx = sorted_indices[split_pos];
            let cls = data.y[idx];
            left_counts[cls] += 1;
            right_counts[cls] -= 1;
            left_n += 1;
            let right_n = n - left_n;

            // Cannot split between two identical feature values — there is
            // no threshold that separates them.
            let next_idx = sorted_indices[split_pos + 1];
            if data.x[[idx, feat]] == data.x[[next_idx, feat]] {
                continue;
            }

            // Enforce the minimum leaf size on both children.
            if left_n < min_samples_leaf || right_n < min_samples_leaf {
                continue;
            }

            let left_impurity = compute_impurity::<F>(&left_counts, left_n, data.criterion);
            let right_impurity = compute_impurity::<F>(&right_counts, right_n, data.criterion);
            let left_weight = F::from(left_n).unwrap() / n_f;
            let right_weight = F::from(right_n).unwrap() / n_f;
            let weighted_child_impurity =
                left_weight * left_impurity + right_weight * right_impurity;
            let impurity_decrease = parent_impurity - weighted_child_impurity;

            if impurity_decrease > best_score {
                best_score = impurity_decrease;
                best_feature = feat;
                // Threshold is the midpoint between the two adjacent values.
                best_threshold =
                    (data.x[[idx, feat]] + data.x[[next_idx, feat]]) / F::from(2.0).unwrap();
            }
        }
    }

    // Only accept a split with a strictly positive impurity decrease. The
    // returned decrease is scaled by the node's sample count, presumably so
    // feature importances weight nodes by size — see
    // `compute_feature_importances` (defined elsewhere); confirm.
    if best_score > F::zero() {
        Some((best_feature, best_threshold, best_score * n_f))
    } else {
        None
    }
}
936
937/// Build a regression tree recursively.
938fn build_regression_tree<F: Float>(
939    data: &RegressionData<'_, F>,
940    indices: &[usize],
941    nodes: &mut Vec<Node<F>>,
942    depth: usize,
943    params: &TreeParams,
944) -> usize {
945    let n = indices.len();
946    let mean = mean_value(data.y, indices);
947
948    let should_stop = n < params.min_samples_split || params.max_depth.is_some_and(|d| depth >= d);
949
950    if should_stop {
951        let idx = nodes.len();
952        nodes.push(Node::Leaf {
953            value: mean,
954            class_distribution: None,
955            n_samples: n,
956        });
957        return idx;
958    }
959
960    let parent_mse = mse_for_indices(data.y, indices, mean);
961    if parent_mse <= F::epsilon() {
962        let idx = nodes.len();
963        nodes.push(Node::Leaf {
964            value: mean,
965            class_distribution: None,
966            n_samples: n,
967        });
968        return idx;
969    }
970
971    let best = find_best_regression_split(data, indices, params.min_samples_leaf);
972
973    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
974        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
975            .iter()
976            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);
977
978        let node_idx = nodes.len();
979        nodes.push(Node::Leaf {
980            value: F::zero(),
981            class_distribution: None,
982            n_samples: 0,
983        }); // placeholder
984
985        let left_idx = build_regression_tree(data, &left_indices, nodes, depth + 1, params);
986        let right_idx = build_regression_tree(data, &right_indices, nodes, depth + 1, params);
987
988        nodes[node_idx] = Node::Split {
989            feature: best_feature,
990            threshold: best_threshold,
991            left: left_idx,
992            right: right_idx,
993            impurity_decrease: best_impurity_decrease,
994            n_samples: n,
995        };
996
997        node_idx
998    } else {
999        let idx = nodes.len();
1000        nodes.push(Node::Leaf {
1001            value: mean,
1002            class_distribution: None,
1003            n_samples: n,
1004        });
1005        idx
1006    }
1007}
1008
1009/// Find the best split for a regression node using MSE reduction.
1010///
1011/// Returns `(feature_index, threshold, weighted_mse_decrease)` or `None`.
1012fn find_best_regression_split<F: Float>(
1013    data: &RegressionData<'_, F>,
1014    indices: &[usize],
1015    min_samples_leaf: usize,
1016) -> Option<(usize, F, F)> {
1017    let n = indices.len();
1018    let n_f = F::from(n).unwrap();
1019    let n_features = data.x.ncols();
1020
1021    let parent_sum: F = indices
1022        .iter()
1023        .map(|&i| data.y[i])
1024        .fold(F::zero(), |a, b| a + b);
1025    let parent_sum_sq: F = indices
1026        .iter()
1027        .map(|&i| data.y[i] * data.y[i])
1028        .fold(F::zero(), |a, b| a + b);
1029    let parent_mse = parent_sum_sq / n_f - (parent_sum / n_f) * (parent_sum / n_f);
1030
1031    let mut best_score = F::neg_infinity();
1032    let mut best_feature = 0;
1033    let mut best_threshold = F::zero();
1034
1035    let feature_iter: Box<dyn Iterator<Item = usize>> =
1036        if let Some(feat_indices) = data.feature_indices {
1037            Box::new(feat_indices.iter().copied())
1038        } else {
1039            Box::new(0..n_features)
1040        };
1041
1042    for feat in feature_iter {
1043        let mut sorted_indices: Vec<usize> = indices.to_vec();
1044        sorted_indices.sort_by(|&a, &b| data.x[[a, feat]].partial_cmp(&data.x[[b, feat]]).unwrap());
1045
1046        let mut left_sum = F::zero();
1047        let mut left_sum_sq = F::zero();
1048        let mut left_n: usize = 0;
1049
1050        for split_pos in 0..n - 1 {
1051            let idx = sorted_indices[split_pos];
1052            let val = data.y[idx];
1053            left_sum = left_sum + val;
1054            left_sum_sq = left_sum_sq + val * val;
1055            left_n += 1;
1056            let right_n = n - left_n;
1057
1058            let next_idx = sorted_indices[split_pos + 1];
1059            if data.x[[idx, feat]] == data.x[[next_idx, feat]] {
1060                continue;
1061            }
1062
1063            if left_n < min_samples_leaf || right_n < min_samples_leaf {
1064                continue;
1065            }
1066
1067            let left_n_f = F::from(left_n).unwrap();
1068            let right_n_f = F::from(right_n).unwrap();
1069
1070            let left_mean = left_sum / left_n_f;
1071            let left_mse = left_sum_sq / left_n_f - left_mean * left_mean;
1072
1073            let right_sum = parent_sum - left_sum;
1074            let right_sum_sq = parent_sum_sq - left_sum_sq;
1075            let right_mean = right_sum / right_n_f;
1076            let right_mse = right_sum_sq / right_n_f - right_mean * right_mean;
1077
1078            let weighted_child_mse = (left_n_f * left_mse + right_n_f * right_mse) / n_f;
1079            let mse_decrease = parent_mse - weighted_child_mse;
1080
1081            if mse_decrease > best_score {
1082                best_score = mse_decrease;
1083                best_feature = feat;
1084                best_threshold =
1085                    (data.x[[idx, feat]] + data.x[[next_idx, feat]]) / F::from(2.0).unwrap();
1086            }
1087        }
1088    }
1089
1090    if best_score > F::zero() {
1091        Some((best_feature, best_threshold, best_score * n_f))
1092    } else {
1093        None
1094    }
1095}
1096
1097/// Compute normalised feature importances from impurity decreases in the tree.
1098pub(crate) fn compute_feature_importances<F: Float>(
1099    nodes: &[Node<F>],
1100    n_features: usize,
1101    _total_samples: usize,
1102) -> Array1<F> {
1103    let mut importances = Array1::zeros(n_features);
1104    for node in nodes {
1105        if let Node::Split {
1106            feature,
1107            impurity_decrease,
1108            ..
1109        } = node
1110        {
1111            importances[*feature] = importances[*feature] + *impurity_decrease;
1112        }
1113    }
1114    let total: F = importances.iter().copied().fold(F::zero(), |a, b| a + b);
1115    if total > F::zero() {
1116        importances.mapv_inplace(|v| v / total);
1117    }
1118    importances
1119}
1120
1121// ---------------------------------------------------------------------------
1122// Public builders for forest usage
1123// ---------------------------------------------------------------------------
1124
1125/// Build a classification tree with a subset of features considered per split.
1126///
1127/// Used internally by `RandomForestClassifier` to build individual trees.
1128#[allow(clippy::too_many_arguments)]
1129pub(crate) fn build_classification_tree_with_feature_subset<F: Float>(
1130    x: &Array2<F>,
1131    y: &[usize],
1132    n_classes: usize,
1133    indices: &[usize],
1134    feature_indices: &[usize],
1135    params: &TreeParams,
1136    criterion: ClassificationCriterion,
1137) -> Vec<Node<F>> {
1138    let data = ClassificationData {
1139        x,
1140        y,
1141        n_classes,
1142        feature_indices: Some(feature_indices),
1143        criterion,
1144    };
1145    let mut nodes = Vec::new();
1146    build_classification_tree(&data, indices, &mut nodes, 0, params);
1147    nodes
1148}
1149
1150/// Build a regression tree with a subset of features considered per split.
1151pub(crate) fn build_regression_tree_with_feature_subset<F: Float>(
1152    x: &Array2<F>,
1153    y: &Array1<F>,
1154    indices: &[usize],
1155    feature_indices: &[usize],
1156    params: &TreeParams,
1157) -> Vec<Node<F>> {
1158    let data = RegressionData {
1159        x,
1160        y,
1161        feature_indices: Some(feature_indices),
1162    };
1163    let mut nodes = Vec::new();
1164    build_regression_tree(&data, indices, &mut nodes, 0, params);
1165    nodes
1166}
1167
1168// ---------------------------------------------------------------------------
1169// Tests
1170// ---------------------------------------------------------------------------
1171
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // -- Classifier tests --

    // Two well-separated 2-D clusters: the tree should classify the training
    // data perfectly.
    #[test]
    fn test_classifier_simple_binary() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        for i in 0..3 {
            assert_eq!(preds[i], 0);
        }
        for i in 3..6 {
            assert_eq!(preds[i], 1);
        }
    }

    // All samples share one label: fitting degenerates to a single leaf that
    // always predicts that label.
    #[test]
    fn test_classifier_single_class() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds, array![0, 0, 0]);
    }

    // With max_depth = 1 the tree is a single stump, which should still find
    // the 4/4 split point on this 1-D data.
    #[test]
    fn test_classifier_max_depth_1() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_max_depth(Some(1));
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for i in 0..4 {
            assert_eq!(preds[i], 0);
        }
        for i in 4..8 {
            assert_eq!(preds[i], 1);
        }
    }

    // A min_samples_split larger than the dataset forbids any split, so every
    // prediction is the same (majority) class.
    #[test]
    fn test_classifier_min_samples_split() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_min_samples_split(7);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let majority = preds[0];
        for &p in preds.iter() {
            assert_eq!(p, majority);
        }
    }

    // min_samples_leaf = 4 makes every 3/3 (or more lopsided) split illegal on
    // 6 samples, so the tree stays a single leaf.
    #[test]
    fn test_classifier_min_samples_leaf() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_min_samples_leaf(4);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let majority = preds[0];
        for &p in preds.iter() {
            assert_eq!(p, majority);
        }
    }

    // Both Gini and entropy criteria should perfectly separate these two
    // clearly separable clusters.
    #[test]
    fn test_classifier_gini_vs_entropy() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, 5.0, 5.0, 5.0, 6.0, 6.0, 5.0, 6.0, 6.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let gini_model =
            DecisionTreeClassifier::<f64>::new().with_criterion(ClassificationCriterion::Gini);
        let entropy_model =
            DecisionTreeClassifier::<f64>::new().with_criterion(ClassificationCriterion::Entropy);

        let fitted_gini = gini_model.fit(&x, &y).unwrap();
        let fitted_entropy = entropy_model.fit(&x, &y).unwrap();

        let preds_gini = fitted_gini.predict(&x).unwrap();
        let preds_entropy = fitted_entropy.predict(&x).unwrap();

        assert_eq!(preds_gini, y);
        assert_eq!(preds_entropy, y);
    }

    // Probability output must have one column per class and each row must sum
    // to exactly 1.
    #[test]
    fn test_classifier_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        assert_eq!(proba.dim(), (6, 2));
        for i in 0..6 {
            let row_sum: f64 = proba.row(i).iter().sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    // A label vector shorter than the number of rows in x must be rejected at
    // fit time.
    #[test]
    fn test_classifier_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // Predicting with a different feature count than was used at fit time
    // must fail.
    #[test]
    fn test_classifier_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    // Fitting on zero samples is an error.
    #[test]
    fn test_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = DecisionTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // Only feature 0 is informative (feature 1 is constant zero), so it must
    // get positive importance and the importances must be normalised to 1.
    #[test]
    fn test_classifier_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        assert!(importances[0] > 0.0);
        let sum: f64 = importances.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
    }

    // The HasClasses introspection must report the distinct labels seen
    // during fit.
    #[test]
    fn test_classifier_has_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1, 2, 0, 1, 2];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 1, 2]);
        assert_eq!(fitted.n_classes(), 3);
    }

    // min_samples_split below 2 is an invalid hyperparameter and must be
    // rejected at fit time.
    #[test]
    fn test_classifier_invalid_min_samples_split() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_min_samples_split(1);
        assert!(model.fit(&x, &y).is_err());
    }

    // min_samples_leaf of 0 is an invalid hyperparameter and must be rejected
    // at fit time.
    #[test]
    fn test_classifier_invalid_min_samples_leaf() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_min_samples_leaf(0);
        assert!(model.fit(&x, &y).is_err());
    }

    // Three separable classes on one feature should be fit exactly.
    #[test]
    fn test_classifier_multiclass() {
        let x = Array2::from_shape_vec((9, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
            .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds, y);
    }

    // The classifier should work through the pipeline traits, which carry
    // labels as floats.
    #[test]
    fn test_classifier_pipeline_integration() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    // -- Regressor tests --

    // A fully grown tree memorises distinct targets exactly on the training
    // data.
    #[test]
    fn test_regressor_simple() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 1e-10);
        }
    }

    // A depth-1 stump should predict each half's mean (1.0 and 5.0).
    #[test]
    fn test_regressor_max_depth() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = DecisionTreeRegressor::<f64>::new().with_max_depth(Some(1));
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for i in 0..4 {
            assert_relative_eq!(preds[i], 1.0, epsilon = 1e-10);
        }
        for i in 4..8 {
            assert_relative_eq!(preds[i], 5.0, epsilon = 1e-10);
        }
    }

    // A constant target yields a single leaf predicting that constant.
    #[test]
    fn test_regressor_constant_target() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 3.0, 3.0, 3.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for &p in preds.iter() {
            assert_relative_eq!(p, 3.0, epsilon = 1e-10);
        }
    }

    // A target vector shorter than the number of rows in x must be rejected
    // at fit time.
    #[test]
    fn test_regressor_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = DecisionTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // Predicting with a different feature count than was used at fit time
    // must fail.
    #[test]
    fn test_regressor_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    // Fitting on zero samples is an error.
    #[test]
    fn test_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = DecisionTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // Only feature 0 is informative (feature 1 is constant zero), so it must
    // get positive importance and the importances must be normalised to 1.
    #[test]
    fn test_regressor_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        assert!(importances[0] > 0.0);
        let sum: f64 = importances.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
    }

    // When no split is allowed, every prediction is the global mean (2.5).
    #[test]
    fn test_regressor_min_samples_split() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = DecisionTreeRegressor::<f64>::new().with_min_samples_split(5);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let mean = 2.5;
        for &p in preds.iter() {
            assert_relative_eq!(p, mean, epsilon = 1e-10);
        }
    }

    // The regressor should work through the pipeline traits.
    #[test]
    fn test_regressor_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // The regressor must compile and run with f32 as well as f64.
    #[test]
    fn test_regressor_f32_support() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);

        let model = DecisionTreeRegressor::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // The classifier must compile and run with f32 as well as f64.
    #[test]
    fn test_classifier_f32_support() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    // -- Internal helper tests --

    // A pure node has zero Gini impurity.
    #[test]
    fn test_gini_impurity_pure() {
        let counts = vec![5, 0];
        let imp: f64 = gini_impurity(&counts, 5);
        assert_relative_eq!(imp, 0.0, epsilon = 1e-10);
    }

    // A perfectly balanced binary node has the maximum Gini impurity of 0.5.
    #[test]
    fn test_gini_impurity_balanced() {
        let counts = vec![5, 5];
        let imp: f64 = gini_impurity(&counts, 10);
        assert_relative_eq!(imp, 0.5, epsilon = 1e-10);
    }

    // A pure node has zero entropy.
    #[test]
    fn test_entropy_pure() {
        let counts = vec![5, 0];
        let ent: f64 = entropy_impurity(&counts, 5);
        assert_relative_eq!(ent, 0.0, epsilon = 1e-10);
    }

    // A balanced binary node has natural-log entropy of ln(2).
    #[test]
    fn test_entropy_balanced() {
        let counts = vec![5, 5];
        let ent: f64 = entropy_impurity(&counts, 10);
        assert_relative_eq!(ent, 2.0f64.ln(), epsilon = 1e-10);
    }
}