ferrolearn_tree/extra_tree.rs

//! Extremely randomized tree classifiers and regressors.
//!
//! This module provides [`ExtraTreeClassifier`] and [`ExtraTreeRegressor`],
//! which are variants of decision trees where split thresholds are chosen
//! randomly rather than via exhaustive search. For each candidate feature,
//! a random threshold is drawn uniformly between the feature's minimum and
//! maximum values in the current node.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_tree::ExtraTreeClassifier;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array2};
//!
//! let x = Array2::from_shape_vec((6, 2), vec![
//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,
//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,
//! ]).unwrap();
//! let y = array![0, 0, 0, 1, 1, 1];
//!
//! let model = ExtraTreeClassifier::<f64>::new()
//!     .with_random_state(42);
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 6);
//! ```

use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2};
use num_traits::{Float, FromPrimitive, ToPrimitive};
use rand::SeedableRng;
use rand::rngs::StdRng;
use rand::seq::index::sample as rand_sample_indices;
use serde::{Deserialize, Serialize};

use crate::decision_tree::{
    ClassificationCriterion, Node, RegressionCriterion, TreeParams, compute_feature_importances,
    traverse,
};
use crate::random_forest::MaxFeatures;

// ---------------------------------------------------------------------------
// Internal data structs for extra-tree building
// ---------------------------------------------------------------------------

/// Data references for classification extra-tree building.
struct ClassificationData<'a, F> {
    x: &'a Array2<F>,
    y: &'a [usize],
    n_classes: usize,
    feature_indices: Option<&'a [usize]>,
    criterion: ClassificationCriterion,
}

/// Data references for regression extra-tree building.
struct RegressionData<'a, F> {
    x: &'a Array2<F>,
    y: &'a Array1<F>,
    feature_indices: Option<&'a [usize]>,
}

// ---------------------------------------------------------------------------
// ExtraTreeClassifier
// ---------------------------------------------------------------------------

/// Extremely randomized tree classifier.
///
/// Like a [`DecisionTreeClassifier`](crate::DecisionTreeClassifier), but split
/// thresholds are chosen randomly rather than via exhaustive search. For each
/// candidate feature, a random threshold is drawn uniformly between the
/// feature's minimum and maximum values in the current node.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
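///
/// # Examples
///
/// A minimal usage sketch, mirroring the module-level example (it assumes
/// the same crate-root re-exports):
///
/// ```
/// use ferrolearn_tree::ExtraTreeClassifier;
/// use ferrolearn_core::{Fit, Predict};
/// use ndarray::{array, Array2};
///
/// let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 5.0, 6.0]).unwrap();
/// let y = array![0, 0, 1, 1];
/// let fitted = ExtraTreeClassifier::<f64>::new()
///     .with_random_state(42)
///     .fit(&x, &y)
///     .unwrap();
/// let preds = fitted.predict(&x).unwrap();
/// assert_eq!(preds.len(), 4);
/// ```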
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreeClassifier<F> {
    /// Maximum depth of the tree. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Strategy for the number of features considered at each split.
    pub max_features: MaxFeatures,
    /// Splitting criterion.
    pub criterion: ClassificationCriterion,
    /// Random seed for reproducibility. `None` means non-deterministic.
    pub random_state: Option<u64>,
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float> ExtraTreeClassifier<F> {
    /// Create a new `ExtraTreeClassifier` with default settings.
    ///
    /// Defaults: `max_depth = None`, `min_samples_split = 2`,
    /// `min_samples_leaf = 1`, `max_features = Sqrt`,
    /// `criterion = Gini`, `random_state = None`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: MaxFeatures::Sqrt,
            criterion: ClassificationCriterion::Gini,
            random_state: None,
            _marker: std::marker::PhantomData,
        }
    }

    /// Set the maximum tree depth.
    #[must_use]
    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Set the minimum number of samples required to split a node.
    #[must_use]
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Set the minimum number of samples required in a leaf node.
    #[must_use]
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Set the maximum features strategy.
    #[must_use]
    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
        self.max_features = max_features;
        self
    }

    /// Set the splitting criterion.
    #[must_use]
    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
        self.criterion = criterion;
        self
    }

    /// Set the random seed for reproducibility.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl<F: Float> Default for ExtraTreeClassifier<F> {
    fn default() -> Self {
        Self::new()
    }
}

// ---------------------------------------------------------------------------
// FittedExtraTreeClassifier
// ---------------------------------------------------------------------------

/// A fitted extremely randomized tree classifier.
///
/// Stores the learned tree as a flat `Vec<Node<F>>` for cache-friendly traversal.
/// Implements [`Predict`] for generating class predictions and
/// [`HasFeatureImportances`] for inspecting per-feature importance scores.
#[derive(Debug, Clone)]
pub struct FittedExtraTreeClassifier<F> {
    /// Flat node storage; index 0 is the root.
    nodes: Vec<Node<F>>,
    /// Sorted unique class labels observed during training.
    classes: Vec<usize>,
    /// Number of features the model was trained on.
    n_features: usize,
    /// Per-feature importance scores (normalised to sum to 1).
    feature_importances: Array1<F>,
}

impl<F: Float + Send + Sync + 'static> FittedExtraTreeClassifier<F> {
    /// Returns a reference to the flat node storage of the tree.
    #[must_use]
    pub fn nodes(&self) -> &[Node<F>] {
        &self.nodes
    }

    /// Returns the number of features the model was trained on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// Predict class probabilities for each sample.
    ///
    /// Returns a 2-D array of shape `(n_samples, n_classes)`.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
    /// not match the training data.
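    ///
    /// # Examples
    ///
    /// A minimal sketch (same setup and crate-root re-exports as the
    /// module-level example):
    ///
    /// ```
    /// use ferrolearn_tree::ExtraTreeClassifier;
    /// use ferrolearn_core::Fit;
    /// use ndarray::{array, Array2};
    ///
    /// let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 5.0, 6.0]).unwrap();
    /// let y = array![0, 0, 1, 1];
    /// let fitted = ExtraTreeClassifier::<f64>::new().with_random_state(42).fit(&x, &y).unwrap();
    /// // One probability row per sample, one column per class.
    /// let proba = fitted.predict_proba(&x).unwrap();
    /// assert_eq!(proba.dim(), (4, 2));
    /// ```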
    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let n_classes = self.classes.len();
        let mut proba = Array2::zeros((n_samples, n_classes));
        for i in 0..n_samples {
            let row = x.row(i);
            let leaf = traverse(&self.nodes, &row);
            if let Node::Leaf {
                class_distribution: Some(ref dist),
                ..
            } = self.nodes[leaf]
            {
                for (j, &p) in dist.iter().enumerate() {
                    proba[[i, j]] = p;
                }
            }
        }
        Ok(proba)
    }

    /// Mean accuracy on the given test data and labels.
    /// Equivalent to sklearn's `ClassifierMixin.score`.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
    /// the feature count does not match the training data.
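    ///
    /// # Examples
    ///
    /// A minimal sketch; accuracy always lies in `[0, 1]` (assumptions as in
    /// the module-level example):
    ///
    /// ```
    /// use ferrolearn_tree::ExtraTreeClassifier;
    /// use ferrolearn_core::Fit;
    /// use ndarray::{array, Array2};
    ///
    /// let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 5.0, 6.0]).unwrap();
    /// let y = array![0, 0, 1, 1];
    /// let fitted = ExtraTreeClassifier::<f64>::new().with_random_state(42).fit(&x, &y).unwrap();
    /// let acc = fitted.score(&x, &y).unwrap();
    /// assert!((0.0..=1.0).contains(&acc));
    /// ```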
    pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
        if x.nrows() != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows()],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        let preds = self.predict(x)?;
        Ok(crate::mean_accuracy(&preds, y))
    }

    /// Element-wise log of [`predict_proba`](Self::predict_proba). Mirrors
    /// sklearn's `ClassifierMixin.predict_log_proba`.
    ///
    /// # Errors
    ///
    /// Forwards any error from [`predict_proba`](Self::predict_proba).
    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let proba = self.predict_proba(x)?;
        Ok(crate::log_proba(&proba))
    }
}

impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ExtraTreeClassifier<F> {
    type Fitted = FittedExtraTreeClassifier<F>;
    type Error = FerroError;

    /// Fit the extra-tree classifier on the training data.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
    /// numbers of samples.
    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
    /// Returns [`FerroError::InvalidParameter`] if hyperparameters are invalid.
    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedExtraTreeClassifier<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "ExtraTreeClassifier requires at least one sample".into(),
            });
        }
        if self.min_samples_split < 2 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_split".into(),
                reason: "must be at least 2".into(),
            });
        }
        if self.min_samples_leaf < 1 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_leaf".into(),
                reason: "must be at least 1".into(),
            });
        }

        // Determine unique classes.
        let mut classes: Vec<usize> = y.iter().copied().collect();
        classes.sort_unstable();
        classes.dedup();
        let n_classes = classes.len();

        // Map class labels to indices 0..n_classes.
        let y_mapped: Vec<usize> = y
            .iter()
            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
            .collect();

        let indices: Vec<usize> = (0..n_samples).collect();

        let max_features_n = resolve_max_features(self.max_features, n_features);

        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_os_rng()
        };

        let data = ClassificationData {
            x,
            y: &y_mapped,
            n_classes,
            feature_indices: None,
            criterion: self.criterion,
        };
        let params = TreeParams {
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
        };

        let mut nodes: Vec<Node<F>> = Vec::new();
        build_extra_classification_tree(
            &data,
            &indices,
            &mut nodes,
            0,
            &params,
            n_features,
            max_features_n,
            &mut rng,
        );

        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);

        Ok(FittedExtraTreeClassifier {
            nodes,
            classes,
            n_features,
            feature_importances,
        })
    }
}

impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreeClassifier<F> {
    type Output = Array1<usize>;
    type Error = FerroError;

    /// Predict class labels for the given feature matrix.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
    /// not match the fitted model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let mut predictions = Array1::zeros(n_samples);
        for i in 0..n_samples {
            let row = x.row(i);
            let leaf = traverse(&self.nodes, &row);
            if let Node::Leaf { value, .. } = self.nodes[leaf] {
                // Leaves store the class *index* (labels were remapped to
                // 0..n_classes during fitting), so map back to the original
                // label here. Otherwise non-contiguous labels such as {0, 2}
                // would be predicted as {0, 1}.
                predictions[i] = self.classes[float_to_usize(value)];
            }
        }
        Ok(predictions)
    }
}

impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreeClassifier<F> {
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}

impl<F: Float + Send + Sync + 'static> HasClasses for FittedExtraTreeClassifier<F> {
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}

// Pipeline integration.
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for ExtraTreeClassifier<F>
{
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedExtraTreeClassifierPipelineAdapter(fitted)))
    }
}

/// Pipeline adapter for `FittedExtraTreeClassifier<F>`.
struct FittedExtraTreeClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedExtraTreeClassifier<F>,
);

impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedExtraTreeClassifierPipelineAdapter<F>
{
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
    }
}

// ---------------------------------------------------------------------------
// ExtraTreeRegressor
// ---------------------------------------------------------------------------

/// Extremely randomized tree regressor.
///
/// Like a [`DecisionTreeRegressor`](crate::DecisionTreeRegressor), but split
/// thresholds are chosen randomly rather than via exhaustive search. For each
/// candidate feature, a random threshold is drawn uniformly between the
/// feature's minimum and maximum values in the current node.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
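///
/// # Examples
///
/// A minimal usage sketch, analogous to the classifier example in the module
/// docs (it assumes `ExtraTreeRegressor` is re-exported at the crate root in
/// the same way):
///
/// ```
/// use ferrolearn_tree::ExtraTreeRegressor;
/// use ferrolearn_core::{Fit, Predict};
/// use ndarray::{array, Array2};
///
/// let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
/// let y = array![1.0, 2.0, 3.0, 4.0];
/// let fitted = ExtraTreeRegressor::<f64>::new()
///     .with_random_state(42)
///     .fit(&x, &y)
///     .unwrap();
/// let preds = fitted.predict(&x).unwrap();
/// assert_eq!(preds.len(), 4);
/// ```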
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreeRegressor<F> {
    /// Maximum depth of the tree. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Strategy for the number of features considered at each split.
    pub max_features: MaxFeatures,
    /// Splitting criterion.
    pub criterion: RegressionCriterion,
    /// Random seed for reproducibility. `None` means non-deterministic.
    pub random_state: Option<u64>,
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float> ExtraTreeRegressor<F> {
    /// Create a new `ExtraTreeRegressor` with default settings.
    ///
    /// Defaults: `max_depth = None`, `min_samples_split = 2`,
    /// `min_samples_leaf = 1`, `max_features = All`,
    /// `criterion = MSE`, `random_state = None`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: MaxFeatures::All,
            criterion: RegressionCriterion::Mse,
            random_state: None,
            _marker: std::marker::PhantomData,
        }
    }

    /// Set the maximum tree depth.
    #[must_use]
    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Set the minimum number of samples required to split a node.
    #[must_use]
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Set the minimum number of samples required in a leaf node.
    #[must_use]
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Set the maximum features strategy.
    #[must_use]
    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
        self.max_features = max_features;
        self
    }

    /// Set the splitting criterion.
    #[must_use]
    pub fn with_criterion(mut self, criterion: RegressionCriterion) -> Self {
        self.criterion = criterion;
        self
    }

    /// Set the random seed for reproducibility.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl<F: Float> Default for ExtraTreeRegressor<F> {
    fn default() -> Self {
        Self::new()
    }
}

// ---------------------------------------------------------------------------
// FittedExtraTreeRegressor
// ---------------------------------------------------------------------------

/// A fitted extremely randomized tree regressor.
///
/// Stores the learned tree as a flat `Vec<Node<F>>` for cache-friendly traversal.
#[derive(Debug, Clone)]
pub struct FittedExtraTreeRegressor<F> {
    /// Flat node storage; index 0 is the root.
    nodes: Vec<Node<F>>,
    /// Number of features the model was trained on.
    n_features: usize,
    /// Per-feature importance scores (normalised to sum to 1).
    feature_importances: Array1<F>,
}

impl<F: Float + Send + Sync + 'static> FittedExtraTreeRegressor<F> {
    /// Returns a reference to the flat node storage of the tree.
    #[must_use]
    pub fn nodes(&self) -> &[Node<F>] {
        &self.nodes
    }

    /// Returns the number of features the model was trained on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// R² coefficient of determination on the given test data.
    /// Equivalent to sklearn's `RegressorMixin.score`.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
    /// the feature count does not match the training data.
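    ///
    /// # Examples
    ///
    /// A minimal sketch; R² is at most 1 by definition (assumptions as in the
    /// struct-level example):
    ///
    /// ```
    /// use ferrolearn_tree::ExtraTreeRegressor;
    /// use ferrolearn_core::Fit;
    /// use ndarray::{array, Array2};
    ///
    /// let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    /// let y = array![1.0, 2.0, 3.0, 4.0];
    /// let fitted = ExtraTreeRegressor::<f64>::new().with_random_state(42).fit(&x, &y).unwrap();
    /// let r2 = fitted.score(&x, &y).unwrap();
    /// assert!(r2 <= 1.0);
    /// ```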
    pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<F, FerroError> {
        if x.nrows() != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows()],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        let preds = self.predict(x)?;
        Ok(crate::r2_score(&preds, y))
    }
}

impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for ExtraTreeRegressor<F> {
    type Fitted = FittedExtraTreeRegressor<F>;
    type Error = FerroError;

    /// Fit the extra-tree regressor on the training data.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
    /// numbers of samples.
    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
    /// Returns [`FerroError::InvalidParameter`] if hyperparameters are invalid.
    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedExtraTreeRegressor<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "ExtraTreeRegressor requires at least one sample".into(),
            });
        }
        if self.min_samples_split < 2 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_split".into(),
                reason: "must be at least 2".into(),
            });
        }
        if self.min_samples_leaf < 1 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples_leaf".into(),
                reason: "must be at least 1".into(),
            });
        }

        let indices: Vec<usize> = (0..n_samples).collect();
        let max_features_n = resolve_max_features(self.max_features, n_features);

        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_os_rng()
        };

        let data = RegressionData {
            x,
            y,
            feature_indices: None,
        };
        let params = TreeParams {
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
        };

        let mut nodes: Vec<Node<F>> = Vec::new();
        build_extra_regression_tree(
            &data,
            &indices,
            &mut nodes,
            0,
            &params,
            n_features,
            max_features_n,
            &mut rng,
        );

        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);

        Ok(FittedExtraTreeRegressor {
            nodes,
            n_features,
            feature_importances,
        })
    }
}

impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreeRegressor<F> {
    type Output = Array1<F>;
    type Error = FerroError;

    /// Predict target values for the given feature matrix.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
    /// not match the fitted model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let mut predictions = Array1::zeros(n_samples);
        for i in 0..n_samples {
            let row = x.row(i);
            let leaf = traverse(&self.nodes, &row);
            if let Node::Leaf { value, .. } = self.nodes[leaf] {
                predictions[i] = value;
            }
        }
        Ok(predictions)
    }
}

impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreeRegressor<F> {
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}

// Pipeline integration.
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for ExtraTreeRegressor<F> {
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}

impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedExtraTreeRegressor<F> {
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}

// ---------------------------------------------------------------------------
// Internal: helpers
// ---------------------------------------------------------------------------

/// Resolve the `MaxFeatures` strategy to a concrete number.
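///
/// Worked examples (every result is clamped to `1..=n_features`): `Sqrt`
/// with 9 features gives `ceil(sqrt(9)) = 3`; `Log2` with 9 gives
/// `ceil(log2(9)) = 4`; `Fraction(0.5)` with 9 gives `ceil(4.5) = 5`;
/// `Fixed(20)` with 9 is capped at 9. See also the helper tests at the
/// bottom of this file.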
fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
    let result = match strategy {
        MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
        MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
        MaxFeatures::All => n_features,
        MaxFeatures::Fixed(n) => n.min(n_features),
        MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
    };
    result.max(1).min(n_features)
}

/// Convert a `Float` value to `usize` (for class labels stored as floats).
fn float_to_usize<F: Float>(v: F) -> usize {
    v.to_f64().map_or(0, |f| f.round() as usize)
}

/// Generate a uniform random float in `[min_val, max_val]`.
fn random_threshold<F: Float>(rng: &mut StdRng, min_val: F, max_val: F) -> F {
    use rand::RngCore;
    // Generate a random f64 in [0, 1] (division by u64::MAX can round up to
    // exactly 1.0) and scale to [min_val, max_val].
    let u = (rng.next_u64() as f64) / (u64::MAX as f64);
    let range = max_val - min_val;
    min_val + F::from(u).unwrap() * range
}

/// Compute the Gini impurity for a set of class counts.
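///
/// For example, counts `[2, 2]` over 4 samples give
/// `1 - (1/2)^2 - (1/2)^2 = 0.5`, and a pure node gives `0`.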
fn gini_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
    if total == 0 {
        return F::zero();
    }
    let total_f = F::from(total).unwrap();
    let mut impurity = F::one();
    for &count in class_counts {
        let p = F::from(count).unwrap() / total_f;
        impurity = impurity - p * p;
    }
    impurity
}

/// Compute the entropy for a set of class counts, using the natural
/// logarithm (i.e. measured in nats rather than bits).
fn entropy_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
    if total == 0 {
        return F::zero();
    }
    let total_f = F::from(total).unwrap();
    let mut ent = F::zero();
    for &count in class_counts {
        if count > 0 {
            let p = F::from(count).unwrap() / total_f;
            ent = ent - p * p.ln();
        }
    }
    ent
}

/// Compute impurity for a given classification criterion.
fn compute_impurity<F: Float>(
    class_counts: &[usize],
    total: usize,
    criterion: ClassificationCriterion,
) -> F {
    match criterion {
        ClassificationCriterion::Gini => gini_impurity(class_counts, total),
        ClassificationCriterion::Entropy => entropy_impurity(class_counts, total),
    }
}

/// Create a classification leaf node and return its index.
fn make_classification_leaf<F: Float>(
    nodes: &mut Vec<Node<F>>,
    class_counts: &[usize],
    n_classes: usize,
    n_samples: usize,
) -> usize {
    let majority_class = class_counts
        .iter()
        .enumerate()
        .max_by_key(|&(_, &count)| count)
        .map_or(0, |(i, _)| i);

    let total_f = if n_samples > 0 {
        F::from(n_samples).unwrap()
    } else {
        F::one()
    };
    let distribution: Vec<F> = (0..n_classes)
        .map(|c| F::from(class_counts[c]).unwrap() / total_f)
        .collect();

    let idx = nodes.len();
    nodes.push(Node::Leaf {
        value: F::from(majority_class).unwrap(),
        class_distribution: Some(distribution),
        n_samples,
    });
    idx
}

/// Compute the mean of target values for the given indices.
fn mean_value<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
    if indices.is_empty() {
        return F::zero();
    }
    let sum: F = indices.iter().map(|&i| y[i]).fold(F::zero(), |a, b| a + b);
    sum / F::from(indices.len()).unwrap()
}

// ---------------------------------------------------------------------------
// Extra-tree classification building
// ---------------------------------------------------------------------------

/// Build an extra-tree classification tree recursively with random thresholds.
///
/// At each node, a random subset of features is considered, and for each feature
/// a random threshold is drawn uniformly between the feature's min and max in the
/// current node.
#[allow(clippy::too_many_arguments)]
fn build_extra_classification_tree<F: Float>(
    data: &ClassificationData<'_, F>,
    indices: &[usize],
    nodes: &mut Vec<Node<F>>,
    depth: usize,
    params: &TreeParams,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> usize {
    let n = indices.len();

    let mut class_counts = vec![0usize; data.n_classes];
    for &i in indices {
        class_counts[data.y[i]] += 1;
    }

    let should_stop = n < params.min_samples_split
        || params.max_depth.is_some_and(|d| depth >= d)
        || class_counts.iter().filter(|&&c| c > 0).count() <= 1;

    if should_stop {
        return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
    }

    let best = find_random_classification_split(
        data,
        indices,
        params.min_samples_leaf,
        n_features,
        max_features_n,
        rng,
    );

    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
            .iter()
            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);

        // Ensure both children have at least min_samples_leaf.
        if left_indices.len() < params.min_samples_leaf
            || right_indices.len() < params.min_samples_leaf
        {
            return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
        }

        let node_idx = nodes.len();
        nodes.push(Node::Leaf {
            value: F::zero(),
            class_distribution: None,
            n_samples: 0,
        }); // placeholder

        let left_idx = build_extra_classification_tree(
            data,
            &left_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );
        let right_idx = build_extra_classification_tree(
            data,
            &right_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );

        nodes[node_idx] = Node::Split {
            feature: best_feature,
            threshold: best_threshold,
            left: left_idx,
            right: right_idx,
            impurity_decrease: best_impurity_decrease,
            n_samples: n,
        };

        node_idx
    } else {
        make_classification_leaf(nodes, &class_counts, data.n_classes, n)
    }
}

/// Find the best random split for a classification node.
///
/// For each candidate feature (from a random subset), pick a random threshold
/// between min and max of that feature in the current node. Return the split
/// with the largest impurity decrease, or `None` if no valid split exists.
#[allow(clippy::too_many_arguments)]
fn find_random_classification_split<F: Float>(
    data: &ClassificationData<'_, F>,
    indices: &[usize],
    min_samples_leaf: usize,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> Option<(usize, F, F)> {
    let n = indices.len();
    let n_f = F::from(n).unwrap();

    let mut parent_counts = vec![0usize; data.n_classes];
    for &i in indices {
        parent_counts[data.y[i]] += 1;
    }
    let parent_impurity = compute_impurity::<F>(&parent_counts, n, data.criterion);

    let mut best_score = F::neg_infinity();
    let mut best_feature = 0;
    let mut best_threshold = F::zero();

    // Select random feature subset.
    let feature_subset: Vec<usize> = if let Some(feat_indices) = data.feature_indices {
        // If a feature subset is provided externally, sample from it.
        let k = max_features_n.min(feat_indices.len());
        rand_sample_indices(rng, feat_indices.len(), k)
            .into_vec()
            .into_iter()
            .map(|i| feat_indices[i])
            .collect()
    } else {
        let k = max_features_n.min(n_features);
        rand_sample_indices(rng, n_features, k).into_vec()
    };

    for feat in feature_subset {
        // Find min and max of this feature in the current node.
        let mut feat_min = F::infinity();
        let mut feat_max = F::neg_infinity();
        for &i in indices {
            let val = data.x[[i, feat]];
            if val < feat_min {
                feat_min = val;
            }
            if val > feat_max {
                feat_max = val;
            }
        }

        // Skip constant features.
        if feat_min >= feat_max {
            continue;
        }

        // Draw a random threshold uniformly in [feat_min, feat_max].
        let threshold = random_threshold(rng, feat_min, feat_max);

        // Count left and right.
        let mut left_counts = vec![0usize; data.n_classes];
        let mut right_counts = vec![0usize; data.n_classes];
        let mut left_n = 0usize;

        for &i in indices {
            let cls = data.y[i];
            if data.x[[i, feat]] <= threshold {
                left_counts[cls] += 1;
                left_n += 1;
            } else {
                right_counts[cls] += 1;
            }
        }

        let right_n = n - left_n;
        if left_n < min_samples_leaf || right_n < min_samples_leaf {
            continue;
        }

        let left_impurity = compute_impurity::<F>(&left_counts, left_n, data.criterion);
        let right_impurity = compute_impurity::<F>(&right_counts, right_n, data.criterion);
        let left_weight = F::from(left_n).unwrap() / n_f;
        let right_weight = F::from(right_n).unwrap() / n_f;
        let weighted_child_impurity = left_weight * left_impurity + right_weight * right_impurity;
        let impurity_decrease = parent_impurity - weighted_child_impurity;

        if impurity_decrease > best_score {
            best_score = impurity_decrease;
            best_feature = feat;
            best_threshold = threshold;
        }
    }

    if best_score > F::zero() {
        Some((best_feature, best_threshold, best_score * n_f))
    } else {
        None
    }
}

// ---------------------------------------------------------------------------
// Extra-tree regression building
// ---------------------------------------------------------------------------

/// Build an extra-tree regression tree recursively with random thresholds.
#[allow(clippy::too_many_arguments)]
fn build_extra_regression_tree<F: Float>(
    data: &RegressionData<'_, F>,
    indices: &[usize],
    nodes: &mut Vec<Node<F>>,
    depth: usize,
    params: &TreeParams,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> usize {
    let n = indices.len();
    let mean = mean_value(data.y, indices);

    let should_stop = n < params.min_samples_split || params.max_depth.is_some_and(|d| depth >= d);

    if should_stop {
        let idx = nodes.len();
        nodes.push(Node::Leaf {
            value: mean,
            class_distribution: None,
            n_samples: n,
        });
        return idx;
    }

    // Check if variance is essentially zero.
    let parent_sum_sq: F = indices
        .iter()
        .map(|&i| {
            let diff = data.y[i] - mean;
            diff * diff
        })
        .fold(F::zero(), |a, b| a + b);
    let parent_mse = parent_sum_sq / F::from(n).unwrap();

    if parent_mse <= F::epsilon() {
        let idx = nodes.len();
        nodes.push(Node::Leaf {
            value: mean,
            class_distribution: None,
            n_samples: n,
        });
        return idx;
    }

    let best = find_random_regression_split(
        data,
        indices,
        params.min_samples_leaf,
        n_features,
        max_features_n,
        rng,
    );

    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
            .iter()
            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);

        // Ensure both children have at least min_samples_leaf.
        if left_indices.len() < params.min_samples_leaf
            || right_indices.len() < params.min_samples_leaf
        {
            let idx = nodes.len();
            nodes.push(Node::Leaf {
                value: mean,
                class_distribution: None,
                n_samples: n,
            });
            return idx;
        }

        let node_idx = nodes.len();
        nodes.push(Node::Leaf {
            value: F::zero(),
            class_distribution: None,
            n_samples: 0,
        }); // placeholder

        let left_idx = build_extra_regression_tree(
            data,
            &left_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );
        let right_idx = build_extra_regression_tree(
            data,
            &right_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );

        nodes[node_idx] = Node::Split {
            feature: best_feature,
            threshold: best_threshold,
            left: left_idx,
            right: right_idx,
            impurity_decrease: best_impurity_decrease,
            n_samples: n,
        };

        node_idx
    } else {
        let idx = nodes.len();
        nodes.push(Node::Leaf {
            value: mean,
            class_distribution: None,
            n_samples: n,
        });
        idx
    }
}

/// Find the best random split for a regression node.
///
/// For each candidate feature (from a random subset), pick a random threshold
/// between min and max of that feature in the current node. Return the split
/// with the largest MSE decrease, or `None` if no valid split exists.
#[allow(clippy::too_many_arguments)]
fn find_random_regression_split<F: Float>(
    data: &RegressionData<'_, F>,
    indices: &[usize],
    min_samples_leaf: usize,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> Option<(usize, F, F)> {
    let n = indices.len();
    let n_f = F::from(n).unwrap();

    let parent_sum: F = indices
        .iter()
        .map(|&i| data.y[i])
        .fold(F::zero(), |a, b| a + b);
    let parent_sum_sq: F = indices
        .iter()
        .map(|&i| data.y[i] * data.y[i])
        .fold(F::zero(), |a, b| a + b);
    let parent_mse = parent_sum_sq / n_f - (parent_sum / n_f) * (parent_sum / n_f);

    let mut best_score = F::neg_infinity();
    let mut best_feature = 0;
    let mut best_threshold = F::zero();

    // Select random feature subset.
    let feature_subset: Vec<usize> = if let Some(feat_indices) = data.feature_indices {
        let k = max_features_n.min(feat_indices.len());
        rand_sample_indices(rng, feat_indices.len(), k)
            .into_vec()
            .into_iter()
            .map(|i| feat_indices[i])
            .collect()
    } else {
        let k = max_features_n.min(n_features);
        rand_sample_indices(rng, n_features, k).into_vec()
    };

    for feat in feature_subset {
        // Find min and max of this feature in the current node.
        let mut feat_min = F::infinity();
        let mut feat_max = F::neg_infinity();
        for &i in indices {
            let val = data.x[[i, feat]];
            if val < feat_min {
                feat_min = val;
            }
            if val > feat_max {
                feat_max = val;
            }
        }

        // Skip constant features.
        if feat_min >= feat_max {
            continue;
        }

        // Draw a random threshold uniformly in [feat_min, feat_max].
        let threshold = random_threshold(rng, feat_min, feat_max);

        // Compute left/right statistics.
        let mut left_sum = F::zero();
        let mut left_sum_sq = F::zero();
        let mut left_n: usize = 0;

        for &i in indices {
            if data.x[[i, feat]] <= threshold {
                let val = data.y[i];
                left_sum = left_sum + val;
                left_sum_sq = left_sum_sq + val * val;
                left_n += 1;
            }
        }

        let right_n = n - left_n;
        if left_n < min_samples_leaf || right_n < min_samples_leaf {
            continue;
        }

        let left_n_f = F::from(left_n).unwrap();
        let right_n_f = F::from(right_n).unwrap();

        let left_mean = left_sum / left_n_f;
        let left_mse = left_sum_sq / left_n_f - left_mean * left_mean;

        let right_sum = parent_sum - left_sum;
        let right_sum_sq = parent_sum_sq - left_sum_sq;
        let right_mean = right_sum / right_n_f;
        let right_mse = right_sum_sq / right_n_f - right_mean * right_mean;

        let weighted_child_mse = (left_n_f * left_mse + right_n_f * right_mse) / n_f;
        let mse_decrease = parent_mse - weighted_child_mse;

        if mse_decrease > best_score {
            best_score = mse_decrease;
            best_feature = feat;
            best_threshold = threshold;
        }
    }

    if best_score > F::zero() {
        Some((best_feature, best_threshold, best_score * n_f))
    } else {
        None
    }
}

// ---------------------------------------------------------------------------
// Crate-internal functions for ensemble usage
// ---------------------------------------------------------------------------

/// Build a classification extra-tree with a subset of features for ensemble use.
///
/// Used internally by `ExtraTreesClassifier` to build individual trees.
#[allow(clippy::too_many_arguments)]
pub(crate) fn build_extra_classification_tree_for_ensemble<F: Float>(
    x: &Array2<F>,
    y: &[usize],
    n_classes: usize,
    indices: &[usize],
    feature_indices: Option<&[usize]>,
    params: &TreeParams,
    criterion: ClassificationCriterion,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> Vec<Node<F>> {
    let data = ClassificationData {
        x,
        y,
        n_classes,
        feature_indices,
        criterion,
    };
    let mut nodes = Vec::new();
    build_extra_classification_tree(
        &data,
        indices,
        &mut nodes,
        0,
        params,
        n_features,
        max_features_n,
        rng,
    );
    nodes
}

/// Build a regression extra-tree with a subset of features for ensemble use.
///
/// Used internally by `ExtraTreesRegressor` to build individual trees.
#[allow(clippy::too_many_arguments)]
pub(crate) fn build_extra_regression_tree_for_ensemble<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    indices: &[usize],
    feature_indices: Option<&[usize]>,
    params: &TreeParams,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> Vec<Node<F>> {
    let data = RegressionData {
        x,
        y,
        feature_indices,
    };
    let mut nodes = Vec::new();
    build_extra_regression_tree(
        &data,
        indices,
        &mut nodes,
        0,
        params,
        n_features,
        max_features_n,
        rng,
    );
    nodes
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // -- Classifier tests --

    #[test]
    fn test_extra_classifier_simple_binary() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        // ExtraTrees should separate linearly separable data.
        for i in 0..3 {
            assert_eq!(preds[i], 0, "sample {i} should be class 0");
        }
        for i in 3..6 {
            assert_eq!(preds[i], 1, "sample {i} should be class 1");
        }
    }

    #[test]
    fn test_extra_classifier_single_class() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds, array![0, 0, 0]);
    }

    #[test]
    fn test_extra_classifier_max_depth_1() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_depth(Some(1))
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let _preds = fitted.predict(&x).unwrap();

        // With depth 1 and a single feature, it should still separate the classes.
        // The tree has exactly one split node and two leaves.
        assert_eq!(fitted.nodes().len(), 3);
    }

    #[test]
    fn test_extra_classifier_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        assert_eq!(proba.dim(), (6, 2));
        // Each row sums to 1.
        for i in 0..6 {
            let row_sum = proba.row(i).sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_extra_classifier_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0, 1.0, 6.0, 1.0, 7.0, 1.0, 8.0, 1.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        // The sum of importances should be 1 (normalised).
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        // Feature 0 should have higher importance (feature 1 is constant).
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_extra_classifier_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0]; // wrong length

        let model = ExtraTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = ExtraTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_classifier_invalid_min_samples_split() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 1];

        let model = ExtraTreeClassifier::<f64>::new().with_min_samples_split(1);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_classifier_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 2, 2, 2]; // non-contiguous classes

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 2]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_extra_classifier_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    #[test]
    fn test_extra_classifier_f32() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0f32, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f32>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_extra_classifier_deterministic() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model1 = ExtraTreeClassifier::<f64>::new().with_random_state(123);
        let model2 = ExtraTreeClassifier::<f64>::new().with_random_state(123);

        let fitted1 = model1.fit(&x, &y).unwrap();
        let fitted2 = model2.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        assert_eq!(preds1, preds2);
    }

    // -- Regressor tests --

    #[test]
    fn test_extra_regressor_simple() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // A deep extra-tree should roughly memorize the training data.
        assert_eq!(preds.len(), 6);
        for i in 0..6 {
            assert_relative_eq!(preds[i], y[i], epsilon = 1.0);
        }
    }

    #[test]
    fn test_extra_regressor_constant_target() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![5.0, 5.0, 5.0, 5.0];

        let model = ExtraTreeRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for &p in &preds {
            assert_relative_eq!(p, 5.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_extra_regressor_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        // Feature 0 drives the target; feature 1 is constant.
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_extra_regressor_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0]; // wrong length

        let model = ExtraTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = ExtraTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_regressor_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreeRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    #[test]
    fn test_extra_regressor_max_depth() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_depth(Some(1))
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        // With depth 1, the tree should have exactly 3 nodes: one split + two leaves.
        assert_eq!(fitted.nodes().len(), 3);
    }

    #[test]
    fn test_extra_regressor_deterministic() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model1 = ExtraTreeRegressor::<f64>::new().with_random_state(99);
        let model2 = ExtraTreeRegressor::<f64>::new().with_random_state(99);

        let fitted1 = model1.fit(&x, &y).unwrap();
        let fitted2 = model2.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        for i in 0..6 {
            assert_relative_eq!(preds1[i], preds2[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_extra_regressor_f32() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0f32, 2.0, 3.0, 4.0];

        let model = ExtraTreeRegressor::<f32>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // -- Builder tests --

    #[test]
    fn test_classifier_builder_methods() {
        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_depth(Some(5))
            .with_min_samples_split(10)
            .with_min_samples_leaf(3)
            .with_max_features(MaxFeatures::Log2)
            .with_criterion(ClassificationCriterion::Entropy)
            .with_random_state(42);

        assert_eq!(model.max_depth, Some(5));
        assert_eq!(model.min_samples_split, 10);
        assert_eq!(model.min_samples_leaf, 3);
        assert_eq!(model.max_features, MaxFeatures::Log2);
        assert_eq!(model.criterion, ClassificationCriterion::Entropy);
        assert_eq!(model.random_state, Some(42));
    }

    #[test]
    fn test_regressor_builder_methods() {
        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_depth(Some(10))
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_max_features(MaxFeatures::Fixed(3))
            .with_criterion(RegressionCriterion::Mse)
            .with_random_state(99);

        assert_eq!(model.max_depth, Some(10));
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.min_samples_leaf, 2);
        assert_eq!(model.max_features, MaxFeatures::Fixed(3));
        assert_eq!(model.criterion, RegressionCriterion::Mse);
        assert_eq!(model.random_state, Some(99));
    }

    #[test]
    fn test_classifier_default() {
        let model = ExtraTreeClassifier::<f64>::default();
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::Sqrt);
        assert_eq!(model.criterion, ClassificationCriterion::Gini);
        assert_eq!(model.random_state, None);
    }

    #[test]
    fn test_regressor_default() {
        let model = ExtraTreeRegressor::<f64>::default();
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::All);
        assert_eq!(model.criterion, RegressionCriterion::Mse);
        assert_eq!(model.random_state, None);
    }
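
    // -- Helper tests --
    //
    // A few sketches pinning down the private helpers above; the expected
    // values follow directly from the formulas in `resolve_max_features`,
    // `gini_impurity`, and `entropy_impurity`.

    #[test]
    fn test_resolve_max_features() {
        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 16), 4);
        assert_eq!(resolve_max_features(MaxFeatures::Log2, 16), 4);
        assert_eq!(resolve_max_features(MaxFeatures::All, 16), 16);
        assert_eq!(resolve_max_features(MaxFeatures::Fixed(100), 16), 16);
        assert_eq!(resolve_max_features(MaxFeatures::Fraction(0.25), 16), 4);
        // The result is clamped to at least 1, even for degenerate inputs.
        assert_eq!(resolve_max_features(MaxFeatures::Fixed(0), 16), 1);
    }

    #[test]
    fn test_impurity_helpers() {
        // Gini of a 50/50 split is 0.5; entropy (natural log) is ln 2.
        let g: f64 = gini_impurity(&[2, 2], 4);
        assert_relative_eq!(g, 0.5, epsilon = 1e-12);
        let e: f64 = entropy_impurity(&[2, 2], 4);
        assert_relative_eq!(e, std::f64::consts::LN_2, epsilon = 1e-12);
        // A pure node has zero impurity under both criteria.
        assert_relative_eq!(gini_impurity::<f64>(&[4, 0], 4), 0.0, epsilon = 1e-12);
        assert_relative_eq!(entropy_impurity::<f64>(&[4, 0], 4), 0.0, epsilon = 1e-12);
    }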
}