// ferrolearn_tree/decision_tree.rs
//! CART decision tree classifiers and regressors.
//!
//! This module provides [`DecisionTreeClassifier`] and [`DecisionTreeRegressor`],
//! implementing the Classification and Regression Trees (CART) algorithm with
//! configurable splitting criteria, depth limits, and minimum sample constraints.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_tree::DecisionTreeClassifier;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array1, Array2};
//!
//! let x = Array2::from_shape_vec((6, 2), vec![
//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,
//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,
//! ]).unwrap();
//! let y = array![0, 0, 0, 1, 1, 1];
//!
//! let model = DecisionTreeClassifier::<f64>::new();
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! ```

use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2};
use num_traits::{Float, FromPrimitive, ToPrimitive};
use serde::{Deserialize, Serialize};

// ---------------------------------------------------------------------------
// Splitting criterion enums
// ---------------------------------------------------------------------------

/// Splitting criterion for classification trees.
///
/// Both criteria measure node impurity; splits are chosen to maximise the
/// weighted impurity decrease.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ClassificationCriterion {
    /// Gini impurity: `1 - sum(p_c^2)`.
    Gini,
    /// Shannon entropy: `-sum(p_c * ln(p_c))` (natural log, i.e. nats).
    Entropy,
}

/// Splitting criterion for regression trees.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RegressionCriterion {
    /// Mean squared error (variance reduction).
    Mse,
}

// ---------------------------------------------------------------------------
// Node representation (flat vec for cache efficiency)
// ---------------------------------------------------------------------------

/// A single node in the decision tree, stored in a flat `Vec` for cache efficiency.
///
/// Internal nodes hold a split (feature index + threshold), while leaf nodes
/// store a prediction value and optional class distribution (for classifiers).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Node<F> {
    /// An internal split node.
    Split {
        /// Feature index used for the split.
        feature: usize,
        /// Threshold value; samples with `x[feature] <= threshold` go left.
        threshold: F,
        /// Index of the left child node in the flat vec.
        left: usize,
        /// Index of the right child node in the flat vec.
        right: usize,
        /// Weighted impurity decrease from this split (used for feature importance).
        impurity_decrease: F,
        /// Number of samples that reached this node during training.
        n_samples: usize,
    },
    /// A leaf node that stores a prediction.
    Leaf {
        /// Predicted value. For classifiers this is the *dense* class index
        /// (position within the fitted model's sorted class list) stored as `F`;
        /// for regressors it is the mean target of the samples in the leaf.
        value: F,
        /// Class distribution (proportion of each class). `None` for regressors.
        class_distribution: Option<Vec<F>>,
        /// Number of samples that reached this node during training.
        n_samples: usize,
    },
}

// ---------------------------------------------------------------------------
// Internal config structs (to reduce argument counts)
// ---------------------------------------------------------------------------

/// Configuration parameters for tree building, bundled to reduce argument counts.
#[derive(Debug, Clone, Copy)]
pub(crate) struct TreeParams {
    /// Maximum tree depth; `None` means unlimited.
    pub(crate) max_depth: Option<usize>,
    /// Minimum number of samples required to attempt splitting a node.
    pub(crate) min_samples_split: usize,
    /// Minimum number of samples each child must keep for a split to be valid.
    pub(crate) min_samples_leaf: usize,
}

/// Borrowed data references and settings for classification tree building.
struct ClassificationData<'a, F> {
    /// Feature matrix, one row per sample.
    x: &'a Array2<F>,
    /// Targets remapped to dense indices `0..n_classes`.
    y: &'a [usize],
    /// Number of distinct classes.
    n_classes: usize,
    /// Optional subset of feature indices to consider; `None` means all
    /// features. (Presumably used for ensemble feature subsampling — confirm
    /// with callers elsewhere in the crate.)
    feature_indices: Option<&'a [usize]>,
    /// Impurity criterion used to score candidate splits.
    criterion: ClassificationCriterion,
}

/// Borrowed data references for regression tree building.
struct RegressionData<'a, F> {
    /// Feature matrix, one row per sample.
    x: &'a Array2<F>,
    /// Continuous target values, one per sample.
    y: &'a Array1<F>,
    /// Optional subset of feature indices to consider; `None` means all features.
    feature_indices: Option<&'a [usize]>,
}

// ---------------------------------------------------------------------------
// DecisionTreeClassifier
// ---------------------------------------------------------------------------

/// CART decision tree classifier.
///
/// Builds a binary tree by recursively finding the feature and threshold that
/// maximises the reduction in the chosen impurity criterion (Gini or Entropy).
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecisionTreeClassifier<F> {
    /// Maximum depth of the tree. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Splitting criterion.
    pub criterion: ClassificationCriterion,
    // Ties the unfitted estimator to its float type `F` without storing data.
    _marker: std::marker::PhantomData<F>,
}

142impl<F: Float> DecisionTreeClassifier<F> {
143    /// Create a new `DecisionTreeClassifier` with default settings.
144    ///
145    /// Defaults: `max_depth = None`, `min_samples_split = 2`,
146    /// `min_samples_leaf = 1`, `criterion = Gini`.
147    #[must_use]
148    pub fn new() -> Self {
149        Self {
150            max_depth: None,
151            min_samples_split: 2,
152            min_samples_leaf: 1,
153            criterion: ClassificationCriterion::Gini,
154            _marker: std::marker::PhantomData,
155        }
156    }
157
158    /// Set the maximum tree depth.
159    #[must_use]
160    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
161        self.max_depth = max_depth;
162        self
163    }
164
165    /// Set the minimum number of samples required to split a node.
166    #[must_use]
167    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
168        self.min_samples_split = min_samples_split;
169        self
170    }
171
172    /// Set the minimum number of samples required in a leaf node.
173    #[must_use]
174    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
175        self.min_samples_leaf = min_samples_leaf;
176        self
177    }
178
179    /// Set the splitting criterion.
180    #[must_use]
181    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
182        self.criterion = criterion;
183        self
184    }
185}
186
187impl<F: Float> Default for DecisionTreeClassifier<F> {
188    fn default() -> Self {
189        Self::new()
190    }
191}
192
// ---------------------------------------------------------------------------
// FittedDecisionTreeClassifier
// ---------------------------------------------------------------------------

/// A fitted CART decision tree classifier.
///
/// Stores the learned tree as a flat `Vec<Node<F>>` for cache-friendly traversal.
/// Implements [`Predict`] for generating class predictions and
/// [`HasFeatureImportances`] for inspecting per-feature importance scores.
#[derive(Debug, Clone)]
pub struct FittedDecisionTreeClassifier<F> {
    /// Flat node storage; index 0 is the root.
    nodes: Vec<Node<F>>,
    /// Sorted unique class labels observed during training. Leaf values and
    /// `predict_proba` columns refer to positions in this list.
    classes: Vec<usize>,
    /// Number of features the model was trained on.
    n_features: usize,
    /// Per-feature importance scores (normalised to sum to 1).
    feature_importances: Array1<F>,
}

214impl<F: Float + Send + Sync + 'static> FittedDecisionTreeClassifier<F> {
215    /// Returns a reference to the flat node storage of the tree.
216    #[must_use]
217    pub fn nodes(&self) -> &[Node<F>] {
218        &self.nodes
219    }
220
221    /// Returns the number of features the model was trained on.
222    #[must_use]
223    pub fn n_features(&self) -> usize {
224        self.n_features
225    }
226
227    /// Predict class probabilities for each sample.
228    ///
229    /// Returns a 2-D array of shape `(n_samples, n_classes)`.
230    ///
231    /// # Errors
232    ///
233    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
234    /// not match the training data.
235    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
236        if x.ncols() != self.n_features {
237            return Err(FerroError::ShapeMismatch {
238                expected: vec![self.n_features],
239                actual: vec![x.ncols()],
240                context: "number of features must match fitted model".into(),
241            });
242        }
243        let n_samples = x.nrows();
244        let n_classes = self.classes.len();
245        let mut proba = Array2::zeros((n_samples, n_classes));
246        for i in 0..n_samples {
247            let row = x.row(i);
248            let leaf = traverse_tree(&self.nodes, &row);
249            if let Node::Leaf {
250                class_distribution: Some(ref dist),
251                ..
252            } = self.nodes[leaf]
253            {
254                for (j, &p) in dist.iter().enumerate() {
255                    proba[[i, j]] = p;
256                }
257            }
258        }
259        Ok(proba)
260    }
261}
262
263impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for DecisionTreeClassifier<F> {
264    type Fitted = FittedDecisionTreeClassifier<F>;
265    type Error = FerroError;
266
267    /// Fit the decision tree classifier on the training data.
268    ///
269    /// # Errors
270    ///
271    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
272    /// numbers of samples.
273    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
274    /// Returns [`FerroError::InvalidParameter`] if hyperparameters are invalid.
275    fn fit(
276        &self,
277        x: &Array2<F>,
278        y: &Array1<usize>,
279    ) -> Result<FittedDecisionTreeClassifier<F>, FerroError> {
280        let (n_samples, n_features) = x.dim();
281
282        if n_samples != y.len() {
283            return Err(FerroError::ShapeMismatch {
284                expected: vec![n_samples],
285                actual: vec![y.len()],
286                context: "y length must match number of samples in X".into(),
287            });
288        }
289        if n_samples == 0 {
290            return Err(FerroError::InsufficientSamples {
291                required: 1,
292                actual: 0,
293                context: "DecisionTreeClassifier requires at least one sample".into(),
294            });
295        }
296        if self.min_samples_split < 2 {
297            return Err(FerroError::InvalidParameter {
298                name: "min_samples_split".into(),
299                reason: "must be at least 2".into(),
300            });
301        }
302        if self.min_samples_leaf < 1 {
303            return Err(FerroError::InvalidParameter {
304                name: "min_samples_leaf".into(),
305                reason: "must be at least 1".into(),
306            });
307        }
308
309        // Determine unique classes.
310        let mut classes: Vec<usize> = y.iter().copied().collect();
311        classes.sort_unstable();
312        classes.dedup();
313        let n_classes = classes.len();
314
315        // Map class labels to indices 0..n_classes.
316        let y_mapped: Vec<usize> = y
317            .iter()
318            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
319            .collect();
320
321        let indices: Vec<usize> = (0..n_samples).collect();
322
323        let data = ClassificationData {
324            x,
325            y: &y_mapped,
326            n_classes,
327            feature_indices: None,
328            criterion: self.criterion,
329        };
330        let params = TreeParams {
331            max_depth: self.max_depth,
332            min_samples_split: self.min_samples_split,
333            min_samples_leaf: self.min_samples_leaf,
334        };
335
336        let mut nodes: Vec<Node<F>> = Vec::new();
337        build_classification_tree(&data, &indices, &mut nodes, 0, &params);
338
339        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);
340
341        Ok(FittedDecisionTreeClassifier {
342            nodes,
343            classes,
344            n_features,
345            feature_importances,
346        })
347    }
348}
349
350impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedDecisionTreeClassifier<F> {
351    type Output = Array1<usize>;
352    type Error = FerroError;
353
354    /// Predict class labels for the given feature matrix.
355    ///
356    /// # Errors
357    ///
358    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
359    /// not match the fitted model.
360    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
361        if x.ncols() != self.n_features {
362            return Err(FerroError::ShapeMismatch {
363                expected: vec![self.n_features],
364                actual: vec![x.ncols()],
365                context: "number of features must match fitted model".into(),
366            });
367        }
368        let n_samples = x.nrows();
369        let mut predictions = Array1::zeros(n_samples);
370        for i in 0..n_samples {
371            let row = x.row(i);
372            let leaf = traverse_tree(&self.nodes, &row);
373            if let Node::Leaf { value, .. } = self.nodes[leaf] {
374                predictions[i] = float_to_usize(value);
375            }
376        }
377        Ok(predictions)
378    }
379}
380
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
    for FittedDecisionTreeClassifier<F>
{
    /// Per-feature importance scores computed at fit time (normalised to sum to 1).
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}

impl<F: Float + Send + Sync + 'static> HasClasses for FittedDecisionTreeClassifier<F> {
    /// Sorted unique class labels observed during training.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes observed during training.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}

399// Pipeline integration.
400impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
401    for DecisionTreeClassifier<F>
402{
403    fn fit_pipeline(
404        &self,
405        x: &Array2<F>,
406        y: &Array1<F>,
407    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
408        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
409        let fitted = self.fit(x, &y_usize)?;
410        Ok(Box::new(FittedClassifierPipelineAdapter(fitted)))
411    }
412}
413
/// Adapter to make `FittedDecisionTreeClassifier<F>` work as a pipeline estimator
/// (pipelines exchange `F`-typed targets, while the classifier predicts `usize`).
struct FittedClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedDecisionTreeClassifier<F>,
);

impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedClassifierPipelineAdapter<F>
{
    /// Predict labels and convert them to `F` for the pipeline.
    ///
    /// A label that cannot be represented in `F` becomes NaN rather than erroring.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
    }
}

// ---------------------------------------------------------------------------
// DecisionTreeRegressor
// ---------------------------------------------------------------------------

/// CART decision tree regressor.
///
/// Builds a binary tree by recursively finding the feature and threshold that
/// minimises the mean squared error of the split.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecisionTreeRegressor<F> {
    /// Maximum depth of the tree. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Splitting criterion (currently only MSE is available).
    pub criterion: RegressionCriterion,
    // Ties the unfitted estimator to its float type `F` without storing data.
    _marker: std::marker::PhantomData<F>,
}

453impl<F: Float> DecisionTreeRegressor<F> {
454    /// Create a new `DecisionTreeRegressor` with default settings.
455    ///
456    /// Defaults: `max_depth = None`, `min_samples_split = 2`,
457    /// `min_samples_leaf = 1`, `criterion = MSE`.
458    #[must_use]
459    pub fn new() -> Self {
460        Self {
461            max_depth: None,
462            min_samples_split: 2,
463            min_samples_leaf: 1,
464            criterion: RegressionCriterion::Mse,
465            _marker: std::marker::PhantomData,
466        }
467    }
468
469    /// Set the maximum tree depth.
470    #[must_use]
471    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
472        self.max_depth = max_depth;
473        self
474    }
475
476    /// Set the minimum number of samples required to split a node.
477    #[must_use]
478    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
479        self.min_samples_split = min_samples_split;
480        self
481    }
482
483    /// Set the minimum number of samples required in a leaf node.
484    #[must_use]
485    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
486        self.min_samples_leaf = min_samples_leaf;
487        self
488    }
489
490    /// Set the splitting criterion.
491    #[must_use]
492    pub fn with_criterion(mut self, criterion: RegressionCriterion) -> Self {
493        self.criterion = criterion;
494        self
495    }
496}
497
498impl<F: Float> Default for DecisionTreeRegressor<F> {
499    fn default() -> Self {
500        Self::new()
501    }
502}
503
// ---------------------------------------------------------------------------
// FittedDecisionTreeRegressor
// ---------------------------------------------------------------------------

/// A fitted CART decision tree regressor.
///
/// Stores the learned tree as a flat `Vec<Node<F>>` for cache-friendly traversal.
#[derive(Debug, Clone)]
pub struct FittedDecisionTreeRegressor<F> {
    /// Flat node storage; index 0 is the root.
    nodes: Vec<Node<F>>,
    /// Number of features the model was trained on.
    n_features: usize,
    /// Per-feature importance scores (normalised to sum to 1).
    feature_importances: Array1<F>,
}

521impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for DecisionTreeRegressor<F> {
522    type Fitted = FittedDecisionTreeRegressor<F>;
523    type Error = FerroError;
524
525    /// Fit the decision tree regressor on the training data.
526    ///
527    /// # Errors
528    ///
529    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
530    /// numbers of samples.
531    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
532    /// Returns [`FerroError::InvalidParameter`] if hyperparameters are invalid.
533    fn fit(
534        &self,
535        x: &Array2<F>,
536        y: &Array1<F>,
537    ) -> Result<FittedDecisionTreeRegressor<F>, FerroError> {
538        let (n_samples, n_features) = x.dim();
539
540        if n_samples != y.len() {
541            return Err(FerroError::ShapeMismatch {
542                expected: vec![n_samples],
543                actual: vec![y.len()],
544                context: "y length must match number of samples in X".into(),
545            });
546        }
547        if n_samples == 0 {
548            return Err(FerroError::InsufficientSamples {
549                required: 1,
550                actual: 0,
551                context: "DecisionTreeRegressor requires at least one sample".into(),
552            });
553        }
554        if self.min_samples_split < 2 {
555            return Err(FerroError::InvalidParameter {
556                name: "min_samples_split".into(),
557                reason: "must be at least 2".into(),
558            });
559        }
560        if self.min_samples_leaf < 1 {
561            return Err(FerroError::InvalidParameter {
562                name: "min_samples_leaf".into(),
563                reason: "must be at least 1".into(),
564            });
565        }
566
567        let indices: Vec<usize> = (0..n_samples).collect();
568
569        let data = RegressionData {
570            x,
571            y,
572            feature_indices: None,
573        };
574        let params = TreeParams {
575            max_depth: self.max_depth,
576            min_samples_split: self.min_samples_split,
577            min_samples_leaf: self.min_samples_leaf,
578        };
579
580        let mut nodes: Vec<Node<F>> = Vec::new();
581        build_regression_tree(&data, &indices, &mut nodes, 0, &params);
582
583        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);
584
585        Ok(FittedDecisionTreeRegressor {
586            nodes,
587            n_features,
588            feature_importances,
589        })
590    }
591}
592
impl<F: Float + Send + Sync + 'static> FittedDecisionTreeRegressor<F> {
    /// Returns a reference to the flat node storage of the tree.
    #[must_use]
    pub fn nodes(&self) -> &[Node<F>] {
        &self.nodes
    }

    /// Returns the number of features the model was trained on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }
}

607impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedDecisionTreeRegressor<F> {
608    type Output = Array1<F>;
609    type Error = FerroError;
610
611    /// Predict target values for the given feature matrix.
612    ///
613    /// # Errors
614    ///
615    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
616    /// not match the fitted model.
617    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
618        if x.ncols() != self.n_features {
619            return Err(FerroError::ShapeMismatch {
620                expected: vec![self.n_features],
621                actual: vec![x.ncols()],
622                context: "number of features must match fitted model".into(),
623            });
624        }
625        let n_samples = x.nrows();
626        let mut predictions = Array1::zeros(n_samples);
627        for i in 0..n_samples {
628            let row = x.row(i);
629            let leaf = traverse_tree(&self.nodes, &row);
630            if let Node::Leaf { value, .. } = self.nodes[leaf] {
631                predictions[i] = value;
632            }
633        }
634        Ok(predictions)
635    }
636}
637
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedDecisionTreeRegressor<F> {
    /// Per-feature importance scores computed at fit time (normalised to sum to 1).
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}

// Pipeline integration.
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for DecisionTreeRegressor<F> {
    /// Fit within a pipeline; regressors use `F`-typed targets directly, so no
    /// adapter or target conversion is required.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}

impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedDecisionTreeRegressor<F>
{
    /// Forward to [`Predict::predict`]; outputs are already `F`-typed.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}

// ---------------------------------------------------------------------------
// Internal: tree building helpers
// ---------------------------------------------------------------------------

668/// Traverse the tree from root to leaf for a single sample, returning the leaf node index.
669fn traverse_tree<F: Float>(nodes: &[Node<F>], sample: &ndarray::ArrayView1<F>) -> usize {
670    let mut idx = 0;
671    loop {
672        match &nodes[idx] {
673            Node::Split {
674                feature,
675                threshold,
676                left,
677                right,
678                ..
679            } => {
680                if sample[*feature] <= *threshold {
681                    idx = *left;
682                } else {
683                    idx = *right;
684                }
685            }
686            Node::Leaf { .. } => return idx,
687        }
688    }
689}
690
/// Traverse a tree from root to leaf for a single sample (crate-public wrapper
/// around [`traverse_tree`] for use by sibling modules, e.g. ensembles).
///
/// Returns the index of the leaf node in the flat node vector.
pub(crate) fn traverse<F: Float>(nodes: &[Node<F>], sample: &ndarray::ArrayView1<F>) -> usize {
    traverse_tree(nodes, sample)
}

698/// Convert a `Float` value to `usize` (for class labels stored as floats).
699fn float_to_usize<F: Float>(v: F) -> usize {
700    v.to_f64().map_or(0, |f| f.round() as usize)
701}
702
703/// Compute the Gini impurity for a set of class counts.
704fn gini_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
705    if total == 0 {
706        return F::zero();
707    }
708    let total_f = F::from(total).unwrap();
709    let mut impurity = F::one();
710    for &count in class_counts {
711        let p = F::from(count).unwrap() / total_f;
712        impurity = impurity - p * p;
713    }
714    impurity
715}
716
717/// Compute the Shannon entropy for a set of class counts.
718fn entropy_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
719    if total == 0 {
720        return F::zero();
721    }
722    let total_f = F::from(total).unwrap();
723    let mut ent = F::zero();
724    for &count in class_counts {
725        if count > 0 {
726            let p = F::from(count).unwrap() / total_f;
727            ent = ent - p * p.ln();
728        }
729    }
730    ent
731}
732
733/// Compute the mean of target values for the given indices.
734fn mean_value<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
735    if indices.is_empty() {
736        return F::zero();
737    }
738    let sum: F = indices.iter().map(|&i| y[i]).fold(F::zero(), |a, b| a + b);
739    sum / F::from(indices.len()).unwrap()
740}
741
742/// Compute the MSE for the given indices relative to a given mean.
743fn mse_for_indices<F: Float>(y: &Array1<F>, indices: &[usize], mean: F) -> F {
744    if indices.is_empty() {
745        return F::zero();
746    }
747    let sum_sq: F = indices
748        .iter()
749        .map(|&i| {
750            let diff = y[i] - mean;
751            diff * diff
752        })
753        .fold(F::zero(), |a, b| a + b);
754    sum_sq / F::from(indices.len()).unwrap()
755}
756
/// Dispatch to the impurity function matching the chosen classification criterion.
fn compute_impurity<F: Float>(
    class_counts: &[usize],
    total: usize,
    criterion: ClassificationCriterion,
) -> F {
    match criterion {
        ClassificationCriterion::Gini => gini_impurity(class_counts, total),
        ClassificationCriterion::Entropy => entropy_impurity(class_counts, total),
    }
}

769/// Create a classification leaf node and return its index.
770fn make_classification_leaf<F: Float>(
771    nodes: &mut Vec<Node<F>>,
772    class_counts: &[usize],
773    n_classes: usize,
774    n_samples: usize,
775) -> usize {
776    let majority_class = class_counts
777        .iter()
778        .enumerate()
779        .max_by_key(|&(_, &count)| count)
780        .map_or(0, |(i, _)| i);
781
782    let total_f = if n_samples > 0 {
783        F::from(n_samples).unwrap()
784    } else {
785        F::one()
786    };
787    let distribution: Vec<F> = (0..n_classes)
788        .map(|c| F::from(class_counts[c]).unwrap() / total_f)
789        .collect();
790
791    let idx = nodes.len();
792    nodes.push(Node::Leaf {
793        value: F::from(majority_class).unwrap(),
794        class_distribution: Some(distribution),
795        n_samples,
796    });
797    idx
798}
799
800/// Build a classification tree recursively.
801///
802/// Returns the index of the node that was created at the root of this subtree.
803fn build_classification_tree<F: Float>(
804    data: &ClassificationData<'_, F>,
805    indices: &[usize],
806    nodes: &mut Vec<Node<F>>,
807    depth: usize,
808    params: &TreeParams,
809) -> usize {
810    let n = indices.len();
811
812    let mut class_counts = vec![0usize; data.n_classes];
813    for &i in indices {
814        class_counts[data.y[i]] += 1;
815    }
816
817    let should_stop = n < params.min_samples_split
818        || params.max_depth.is_some_and(|d| depth >= d)
819        || class_counts.iter().filter(|&&c| c > 0).count() <= 1;
820
821    if should_stop {
822        return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
823    }
824
825    let best = find_best_classification_split(data, indices, params.min_samples_leaf);
826
827    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
828        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
829            .iter()
830            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);
831
832        let node_idx = nodes.len();
833        nodes.push(Node::Leaf {
834            value: F::zero(),
835            class_distribution: None,
836            n_samples: 0,
837        }); // placeholder
838
839        let left_idx = build_classification_tree(data, &left_indices, nodes, depth + 1, params);
840        let right_idx = build_classification_tree(data, &right_indices, nodes, depth + 1, params);
841
842        nodes[node_idx] = Node::Split {
843            feature: best_feature,
844            threshold: best_threshold,
845            left: left_idx,
846            right: right_idx,
847            impurity_decrease: best_impurity_decrease,
848            n_samples: n,
849        };
850
851        node_idx
852    } else {
853        make_classification_leaf(nodes, &class_counts, data.n_classes, n)
854    }
855}
856
857/// Find the best split for a classification node.
858///
859/// Returns `(feature_index, threshold, weighted_impurity_decrease)` or `None`.
860fn find_best_classification_split<F: Float>(
861    data: &ClassificationData<'_, F>,
862    indices: &[usize],
863    min_samples_leaf: usize,
864) -> Option<(usize, F, F)> {
865    let n = indices.len();
866    let n_f = F::from(n).unwrap();
867    let n_features = data.x.ncols();
868
869    let mut parent_counts = vec![0usize; data.n_classes];
870    for &i in indices {
871        parent_counts[data.y[i]] += 1;
872    }
873    let parent_impurity = compute_impurity::<F>(&parent_counts, n, data.criterion);
874
875    let mut best_score = F::neg_infinity();
876    let mut best_feature = 0;
877    let mut best_threshold = F::zero();
878
879    // Iterate over either specified feature subset or all features.
880    let feature_iter: Box<dyn Iterator<Item = usize>> =
881        if let Some(feat_indices) = data.feature_indices {
882            Box::new(feat_indices.iter().copied())
883        } else {
884            Box::new(0..n_features)
885        };
886
887    for feat in feature_iter {
888        let mut sorted_indices: Vec<usize> = indices.to_vec();
889        sorted_indices.sort_by(|&a, &b| data.x[[a, feat]].partial_cmp(&data.x[[b, feat]]).unwrap());
890
891        let mut left_counts = vec![0usize; data.n_classes];
892        let mut right_counts = parent_counts.clone();
893        let mut left_n = 0usize;
894
895        for split_pos in 0..n - 1 {
896            let idx = sorted_indices[split_pos];
897            let cls = data.y[idx];
898            left_counts[cls] += 1;
899            right_counts[cls] -= 1;
900            left_n += 1;
901            let right_n = n - left_n;
902
903            let next_idx = sorted_indices[split_pos + 1];
904            if data.x[[idx, feat]] == data.x[[next_idx, feat]] {
905                continue;
906            }
907
908            if left_n < min_samples_leaf || right_n < min_samples_leaf {
909                continue;
910            }
911
912            let left_impurity = compute_impurity::<F>(&left_counts, left_n, data.criterion);
913            let right_impurity = compute_impurity::<F>(&right_counts, right_n, data.criterion);
914            let left_weight = F::from(left_n).unwrap() / n_f;
915            let right_weight = F::from(right_n).unwrap() / n_f;
916            let weighted_child_impurity =
917                left_weight * left_impurity + right_weight * right_impurity;
918            let impurity_decrease = parent_impurity - weighted_child_impurity;
919
920            if impurity_decrease > best_score {
921                best_score = impurity_decrease;
922                best_feature = feat;
923                best_threshold =
924                    (data.x[[idx, feat]] + data.x[[next_idx, feat]]) / F::from(2.0).unwrap();
925            }
926        }
927    }
928
929    if best_score > F::zero() {
930        Some((best_feature, best_threshold, best_score * n_f))
931    } else {
932        None
933    }
934}
935
936/// Build a regression tree recursively.
937fn build_regression_tree<F: Float>(
938    data: &RegressionData<'_, F>,
939    indices: &[usize],
940    nodes: &mut Vec<Node<F>>,
941    depth: usize,
942    params: &TreeParams,
943) -> usize {
944    let n = indices.len();
945    let mean = mean_value(data.y, indices);
946
947    let should_stop = n < params.min_samples_split || params.max_depth.is_some_and(|d| depth >= d);
948
949    if should_stop {
950        let idx = nodes.len();
951        nodes.push(Node::Leaf {
952            value: mean,
953            class_distribution: None,
954            n_samples: n,
955        });
956        return idx;
957    }
958
959    let parent_mse = mse_for_indices(data.y, indices, mean);
960    if parent_mse <= F::epsilon() {
961        let idx = nodes.len();
962        nodes.push(Node::Leaf {
963            value: mean,
964            class_distribution: None,
965            n_samples: n,
966        });
967        return idx;
968    }
969
970    let best = find_best_regression_split(data, indices, params.min_samples_leaf);
971
972    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
973        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
974            .iter()
975            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);
976
977        let node_idx = nodes.len();
978        nodes.push(Node::Leaf {
979            value: F::zero(),
980            class_distribution: None,
981            n_samples: 0,
982        }); // placeholder
983
984        let left_idx = build_regression_tree(data, &left_indices, nodes, depth + 1, params);
985        let right_idx = build_regression_tree(data, &right_indices, nodes, depth + 1, params);
986
987        nodes[node_idx] = Node::Split {
988            feature: best_feature,
989            threshold: best_threshold,
990            left: left_idx,
991            right: right_idx,
992            impurity_decrease: best_impurity_decrease,
993            n_samples: n,
994        };
995
996        node_idx
997    } else {
998        let idx = nodes.len();
999        nodes.push(Node::Leaf {
1000            value: mean,
1001            class_distribution: None,
1002            n_samples: n,
1003        });
1004        idx
1005    }
1006}
1007
1008/// Find the best split for a regression node using MSE reduction.
1009///
1010/// Returns `(feature_index, threshold, weighted_mse_decrease)` or `None`.
1011fn find_best_regression_split<F: Float>(
1012    data: &RegressionData<'_, F>,
1013    indices: &[usize],
1014    min_samples_leaf: usize,
1015) -> Option<(usize, F, F)> {
1016    let n = indices.len();
1017    let n_f = F::from(n).unwrap();
1018    let n_features = data.x.ncols();
1019
1020    let parent_sum: F = indices
1021        .iter()
1022        .map(|&i| data.y[i])
1023        .fold(F::zero(), |a, b| a + b);
1024    let parent_sum_sq: F = indices
1025        .iter()
1026        .map(|&i| data.y[i] * data.y[i])
1027        .fold(F::zero(), |a, b| a + b);
1028    let parent_mse = parent_sum_sq / n_f - (parent_sum / n_f) * (parent_sum / n_f);
1029
1030    let mut best_score = F::neg_infinity();
1031    let mut best_feature = 0;
1032    let mut best_threshold = F::zero();
1033
1034    let feature_iter: Box<dyn Iterator<Item = usize>> =
1035        if let Some(feat_indices) = data.feature_indices {
1036            Box::new(feat_indices.iter().copied())
1037        } else {
1038            Box::new(0..n_features)
1039        };
1040
1041    for feat in feature_iter {
1042        let mut sorted_indices: Vec<usize> = indices.to_vec();
1043        sorted_indices.sort_by(|&a, &b| data.x[[a, feat]].partial_cmp(&data.x[[b, feat]]).unwrap());
1044
1045        let mut left_sum = F::zero();
1046        let mut left_sum_sq = F::zero();
1047        let mut left_n: usize = 0;
1048
1049        for split_pos in 0..n - 1 {
1050            let idx = sorted_indices[split_pos];
1051            let val = data.y[idx];
1052            left_sum = left_sum + val;
1053            left_sum_sq = left_sum_sq + val * val;
1054            left_n += 1;
1055            let right_n = n - left_n;
1056
1057            let next_idx = sorted_indices[split_pos + 1];
1058            if data.x[[idx, feat]] == data.x[[next_idx, feat]] {
1059                continue;
1060            }
1061
1062            if left_n < min_samples_leaf || right_n < min_samples_leaf {
1063                continue;
1064            }
1065
1066            let left_n_f = F::from(left_n).unwrap();
1067            let right_n_f = F::from(right_n).unwrap();
1068
1069            let left_mean = left_sum / left_n_f;
1070            let left_mse = left_sum_sq / left_n_f - left_mean * left_mean;
1071
1072            let right_sum = parent_sum - left_sum;
1073            let right_sum_sq = parent_sum_sq - left_sum_sq;
1074            let right_mean = right_sum / right_n_f;
1075            let right_mse = right_sum_sq / right_n_f - right_mean * right_mean;
1076
1077            let weighted_child_mse = (left_n_f * left_mse + right_n_f * right_mse) / n_f;
1078            let mse_decrease = parent_mse - weighted_child_mse;
1079
1080            if mse_decrease > best_score {
1081                best_score = mse_decrease;
1082                best_feature = feat;
1083                best_threshold =
1084                    (data.x[[idx, feat]] + data.x[[next_idx, feat]]) / F::from(2.0).unwrap();
1085            }
1086        }
1087    }
1088
1089    if best_score > F::zero() {
1090        Some((best_feature, best_threshold, best_score * n_f))
1091    } else {
1092        None
1093    }
1094}
1095
1096/// Compute normalised feature importances from impurity decreases in the tree.
1097pub(crate) fn compute_feature_importances<F: Float>(
1098    nodes: &[Node<F>],
1099    n_features: usize,
1100    _total_samples: usize,
1101) -> Array1<F> {
1102    let mut importances = Array1::zeros(n_features);
1103    for node in nodes {
1104        if let Node::Split {
1105            feature,
1106            impurity_decrease,
1107            ..
1108        } = node
1109        {
1110            importances[*feature] = importances[*feature] + *impurity_decrease;
1111        }
1112    }
1113    let total: F = importances.iter().copied().fold(F::zero(), |a, b| a + b);
1114    if total > F::zero() {
1115        importances.mapv_inplace(|v| v / total);
1116    }
1117    importances
1118}
1119
1120// ---------------------------------------------------------------------------
1121// Public builders for forest usage
1122// ---------------------------------------------------------------------------
1123
1124/// Build a classification tree with a subset of features considered per split.
1125///
1126/// Used internally by `RandomForestClassifier` to build individual trees.
1127#[allow(clippy::too_many_arguments)]
1128pub(crate) fn build_classification_tree_with_feature_subset<F: Float>(
1129    x: &Array2<F>,
1130    y: &[usize],
1131    n_classes: usize,
1132    indices: &[usize],
1133    feature_indices: &[usize],
1134    params: &TreeParams,
1135    criterion: ClassificationCriterion,
1136) -> Vec<Node<F>> {
1137    let data = ClassificationData {
1138        x,
1139        y,
1140        n_classes,
1141        feature_indices: Some(feature_indices),
1142        criterion,
1143    };
1144    let mut nodes = Vec::new();
1145    build_classification_tree(&data, indices, &mut nodes, 0, params);
1146    nodes
1147}
1148
1149/// Build a regression tree with a subset of features considered per split.
1150pub(crate) fn build_regression_tree_with_feature_subset<F: Float>(
1151    x: &Array2<F>,
1152    y: &Array1<F>,
1153    indices: &[usize],
1154    feature_indices: &[usize],
1155    params: &TreeParams,
1156) -> Vec<Node<F>> {
1157    let data = RegressionData {
1158        x,
1159        y,
1160        feature_indices: Some(feature_indices),
1161    };
1162    let mut nodes = Vec::new();
1163    build_regression_tree(&data, indices, &mut nodes, 0, params);
1164    nodes
1165}
1166
1167// ---------------------------------------------------------------------------
1168// Tests
1169// ---------------------------------------------------------------------------
1170
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // -- Classifier tests --

    // Two linearly separable clusters: the tree must classify them perfectly.
    #[test]
    fn test_classifier_simple_binary() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        for i in 0..3 {
            assert_eq!(preds[i], 0);
        }
        for i in 3..6 {
            assert_eq!(preds[i], 1);
        }
    }

    // A pure dataset should fit (to a single leaf) and always predict that class.
    #[test]
    fn test_classifier_single_class() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds, array![0, 0, 0]);
    }

    // A depth-1 stump still suffices when one split separates the classes.
    #[test]
    fn test_classifier_max_depth_1() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_max_depth(Some(1));
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for i in 0..4 {
            assert_eq!(preds[i], 0);
        }
        for i in 4..8 {
            assert_eq!(preds[i], 1);
        }
    }

    // min_samples_split larger than n forbids any split: all predictions equal
    // the majority class of the single root leaf.
    #[test]
    fn test_classifier_min_samples_split() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_min_samples_split(7);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let majority = preds[0];
        for &p in &preds {
            assert_eq!(p, majority);
        }
    }

    // min_samples_leaf of 4 makes every 3/3 split illegal for 6 samples, so the
    // tree collapses to a single leaf.
    #[test]
    fn test_classifier_min_samples_leaf() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_min_samples_leaf(4);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let majority = preds[0];
        for &p in &preds {
            assert_eq!(p, majority);
        }
    }

    // Both criteria must fit this cleanly separable dataset exactly.
    #[test]
    fn test_classifier_gini_vs_entropy() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, 5.0, 5.0, 5.0, 6.0, 6.0, 5.0, 6.0, 6.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let gini_model =
            DecisionTreeClassifier::<f64>::new().with_criterion(ClassificationCriterion::Gini);
        let entropy_model =
            DecisionTreeClassifier::<f64>::new().with_criterion(ClassificationCriterion::Entropy);

        let fitted_gini = gini_model.fit(&x, &y).unwrap();
        let fitted_entropy = entropy_model.fit(&x, &y).unwrap();

        let preds_gini = fitted_gini.predict(&x).unwrap();
        let preds_entropy = fitted_entropy.predict(&x).unwrap();

        assert_eq!(preds_gini, y);
        assert_eq!(preds_entropy, y);
    }

    // Probability output has shape (n_samples, n_classes) and rows sum to 1.
    #[test]
    fn test_classifier_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        assert_eq!(proba.dim(), (6, 2));
        for i in 0..6 {
            let row_sum: f64 = proba.row(i).iter().sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    // Mismatched x/y row counts at fit time must be rejected.
    #[test]
    fn test_classifier_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // Predicting with a different feature count than was fit must error.
    #[test]
    fn test_classifier_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    // Fitting on zero samples must error, not panic.
    #[test]
    fn test_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = DecisionTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // Only feature 0 is informative (feature 1 is constant): its importance is
    // positive and the importance vector is normalised.
    #[test]
    fn test_classifier_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        assert!(importances[0] > 0.0);
        let sum: f64 = importances.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
    }

    // The HasClasses introspection trait reports the distinct labels seen.
    #[test]
    fn test_classifier_has_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1, 2, 0, 1, 2];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 1, 2]);
        assert_eq!(fitted.n_classes(), 3);
    }

    // min_samples_split must be >= 2; 1 is an invalid hyperparameter.
    #[test]
    fn test_classifier_invalid_min_samples_split() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_min_samples_split(1);
        assert!(model.fit(&x, &y).is_err());
    }

    // min_samples_leaf must be >= 1; 0 is an invalid hyperparameter.
    #[test]
    fn test_classifier_invalid_min_samples_leaf() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = DecisionTreeClassifier::<f64>::new().with_min_samples_leaf(0);
        assert!(model.fit(&x, &y).is_err());
    }

    // Three well-separated classes should be recovered exactly.
    #[test]
    fn test_classifier_multiclass() {
        let x = Array2::from_shape_vec((9, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
            .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds, y);
    }

    // The PipelineEstimator adapter accepts f64 labels and round-trips.
    #[test]
    fn test_classifier_pipeline_integration() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let model = DecisionTreeClassifier::<f64>::new();
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    // -- Regressor tests --

    // With distinct targets and no constraints, the tree memorises the data.
    #[test]
    fn test_regressor_simple() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 1e-10);
        }
    }

    // A depth-1 stump predicts the per-side means of a step function exactly.
    #[test]
    fn test_regressor_max_depth() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = DecisionTreeRegressor::<f64>::new().with_max_depth(Some(1));
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for i in 0..4 {
            assert_relative_eq!(preds[i], 1.0, epsilon = 1e-10);
        }
        for i in 4..8 {
            assert_relative_eq!(preds[i], 5.0, epsilon = 1e-10);
        }
    }

    // A constant target collapses to one leaf predicting that constant.
    #[test]
    fn test_regressor_constant_target() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 3.0, 3.0, 3.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for &p in &preds {
            assert_relative_eq!(p, 3.0, epsilon = 1e-10);
        }
    }

    // Mismatched x/y row counts at fit time must be rejected.
    #[test]
    fn test_regressor_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = DecisionTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // Predicting with a different feature count than was fit must error.
    #[test]
    fn test_regressor_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    // Fitting on zero samples must error, not panic.
    #[test]
    fn test_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = DecisionTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // Only feature 0 is informative: positive importance, normalised vector.
    #[test]
    fn test_regressor_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        assert!(importances[0] > 0.0);
        let sum: f64 = importances.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
    }

    // min_samples_split larger than n forbids any split: every prediction is
    // the global mean (2.5 for targets 1..=4).
    #[test]
    fn test_regressor_min_samples_split() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = DecisionTreeRegressor::<f64>::new().with_min_samples_split(5);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let mean = 2.5;
        for &p in &preds {
            assert_relative_eq!(p, mean, epsilon = 1e-10);
        }
    }

    // The PipelineEstimator adapter round-trips for the regressor too.
    #[test]
    fn test_regressor_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = DecisionTreeRegressor::<f64>::new();
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // The float-generic regressor also works with f32 precision.
    #[test]
    fn test_regressor_f32_support() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);

        let model = DecisionTreeRegressor::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // The float-generic classifier also works with f32 precision.
    #[test]
    fn test_classifier_f32_support() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = DecisionTreeClassifier::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    // -- Internal helper tests --

    // A single-class node has zero Gini impurity.
    #[test]
    fn test_gini_impurity_pure() {
        let counts = vec![5, 0];
        let imp: f64 = gini_impurity(&counts, 5);
        assert_relative_eq!(imp, 0.0, epsilon = 1e-10);
    }

    // A 50/50 binary node has the maximum binary Gini impurity of 0.5.
    #[test]
    fn test_gini_impurity_balanced() {
        let counts = vec![5, 5];
        let imp: f64 = gini_impurity(&counts, 10);
        assert_relative_eq!(imp, 0.5, epsilon = 1e-10);
    }

    // A single-class node has zero entropy.
    #[test]
    fn test_entropy_pure() {
        let counts = vec![5, 0];
        let ent: f64 = entropy_impurity(&counts, 5);
        assert_relative_eq!(ent, 0.0, epsilon = 1e-10);
    }

    // A 50/50 binary node has natural-log entropy ln(2).
    #[test]
    fn test_entropy_balanced() {
        let counts = vec![5, 5];
        let ent: f64 = entropy_impurity(&counts, 10);
        assert_relative_eq!(ent, 2.0f64.ln(), epsilon = 1e-10);
    }
}