Skip to main content

ferrolearn_tree/
extra_tree.rs

1//! Extremely randomized tree classifiers and regressors.
2//!
3//! This module provides [`ExtraTreeClassifier`] and [`ExtraTreeRegressor`],
4//! which are variants of decision trees where split thresholds are chosen
5//! randomly rather than via exhaustive search. For each candidate feature,
6//! a random threshold is drawn uniformly between the feature's minimum and
7//! maximum values in the current node.
8//!
9//! # Examples
10//!
11//! ```
12//! use ferrolearn_tree::ExtraTreeClassifier;
13//! use ferrolearn_core::{Fit, Predict};
14//! use ndarray::{array, Array1, Array2};
15//!
16//! let x = Array2::from_shape_vec((6, 2), vec![
17//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,
18//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,
19//! ]).unwrap();
20//! let y = array![0, 0, 0, 1, 1, 1];
21//!
22//! let model = ExtraTreeClassifier::<f64>::new()
23//!     .with_random_state(42);
24//! let fitted = model.fit(&x, &y).unwrap();
25//! let preds = fitted.predict(&x).unwrap();
26//! ```
27
28use ferrolearn_core::error::FerroError;
29use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
30use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
31use ferrolearn_core::traits::{Fit, Predict};
32use ndarray::{Array1, Array2};
33use num_traits::{Float, FromPrimitive, ToPrimitive};
34use rand::SeedableRng;
35use rand::rngs::StdRng;
36use rand::seq::index::sample as rand_sample_indices;
37use serde::{Deserialize, Serialize};
38
39use crate::decision_tree::{
40    ClassificationCriterion, Node, RegressionCriterion, TreeParams, compute_feature_importances,
41    traverse,
42};
43use crate::random_forest::MaxFeatures;
44
45// ---------------------------------------------------------------------------
46// Internal data structs for extra-tree building
47// ---------------------------------------------------------------------------
48
49/// Data references for classification extra-tree building.
50struct ClassificationData<'a, F> {
51    x: &'a Array2<F>,
52    y: &'a [usize],
53    n_classes: usize,
54    feature_indices: Option<&'a [usize]>,
55    criterion: ClassificationCriterion,
56}
57
58/// Data references for regression extra-tree building.
59struct RegressionData<'a, F> {
60    x: &'a Array2<F>,
61    y: &'a Array1<F>,
62    feature_indices: Option<&'a [usize]>,
63}
64
65// ---------------------------------------------------------------------------
66// ExtraTreeClassifier
67// ---------------------------------------------------------------------------
68
69/// Extremely randomized tree classifier.
70///
71/// Like a [`DecisionTreeClassifier`](crate::DecisionTreeClassifier), but split
72/// thresholds are chosen randomly rather than via exhaustive search. For each
73/// candidate feature, a random threshold is drawn uniformly between the
74/// feature's minimum and maximum values in the current node.
75///
76/// # Type Parameters
77///
78/// - `F`: The floating-point type (`f32` or `f64`).
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct ExtraTreeClassifier<F> {
81    /// Maximum depth of the tree. `None` means unlimited.
82    pub max_depth: Option<usize>,
83    /// Minimum number of samples required to split an internal node.
84    pub min_samples_split: usize,
85    /// Minimum number of samples required in a leaf node.
86    pub min_samples_leaf: usize,
87    /// Strategy for the number of features considered at each split.
88    pub max_features: MaxFeatures,
89    /// Splitting criterion.
90    pub criterion: ClassificationCriterion,
91    /// Random seed for reproducibility. `None` means non-deterministic.
92    pub random_state: Option<u64>,
93    _marker: std::marker::PhantomData<F>,
94}
95
96impl<F: Float> ExtraTreeClassifier<F> {
97    /// Create a new `ExtraTreeClassifier` with default settings.
98    ///
99    /// Defaults: `max_depth = None`, `min_samples_split = 2`,
100    /// `min_samples_leaf = 1`, `max_features = Sqrt`,
101    /// `criterion = Gini`, `random_state = None`.
102    #[must_use]
103    pub fn new() -> Self {
104        Self {
105            max_depth: None,
106            min_samples_split: 2,
107            min_samples_leaf: 1,
108            max_features: MaxFeatures::Sqrt,
109            criterion: ClassificationCriterion::Gini,
110            random_state: None,
111            _marker: std::marker::PhantomData,
112        }
113    }
114
115    /// Set the maximum tree depth.
116    #[must_use]
117    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
118        self.max_depth = max_depth;
119        self
120    }
121
122    /// Set the minimum number of samples required to split a node.
123    #[must_use]
124    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
125        self.min_samples_split = min_samples_split;
126        self
127    }
128
129    /// Set the minimum number of samples required in a leaf node.
130    #[must_use]
131    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
132        self.min_samples_leaf = min_samples_leaf;
133        self
134    }
135
136    /// Set the maximum features strategy.
137    #[must_use]
138    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
139        self.max_features = max_features;
140        self
141    }
142
143    /// Set the splitting criterion.
144    #[must_use]
145    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
146        self.criterion = criterion;
147        self
148    }
149
150    /// Set the random seed for reproducibility.
151    #[must_use]
152    pub fn with_random_state(mut self, seed: u64) -> Self {
153        self.random_state = Some(seed);
154        self
155    }
156}
157
158impl<F: Float> Default for ExtraTreeClassifier<F> {
159    fn default() -> Self {
160        Self::new()
161    }
162}
163
164// ---------------------------------------------------------------------------
165// FittedExtraTreeClassifier
166// ---------------------------------------------------------------------------
167
168/// A fitted extremely randomized tree classifier.
169///
170/// Stores the learned tree as a flat `Vec<Node<F>>` for cache-friendly traversal.
171/// Implements [`Predict`] for generating class predictions and
172/// [`HasFeatureImportances`] for inspecting per-feature importance scores.
173#[derive(Debug, Clone)]
174pub struct FittedExtraTreeClassifier<F> {
175    /// Flat node storage; index 0 is the root.
176    nodes: Vec<Node<F>>,
177    /// Sorted unique class labels observed during training.
178    classes: Vec<usize>,
179    /// Number of features the model was trained on.
180    n_features: usize,
181    /// Per-feature importance scores (normalised to sum to 1).
182    feature_importances: Array1<F>,
183}
184
185impl<F: Float + Send + Sync + 'static> FittedExtraTreeClassifier<F> {
186    /// Returns a reference to the flat node storage of the tree.
187    #[must_use]
188    pub fn nodes(&self) -> &[Node<F>] {
189        &self.nodes
190    }
191
192    /// Returns the number of features the model was trained on.
193    #[must_use]
194    pub fn n_features(&self) -> usize {
195        self.n_features
196    }
197
198    /// Predict class probabilities for each sample.
199    ///
200    /// Returns a 2-D array of shape `(n_samples, n_classes)`.
201    ///
202    /// # Errors
203    ///
204    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
205    /// not match the training data.
206    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
207        if x.ncols() != self.n_features {
208            return Err(FerroError::ShapeMismatch {
209                expected: vec![self.n_features],
210                actual: vec![x.ncols()],
211                context: "number of features must match fitted model".into(),
212            });
213        }
214        let n_samples = x.nrows();
215        let n_classes = self.classes.len();
216        let mut proba = Array2::zeros((n_samples, n_classes));
217        for i in 0..n_samples {
218            let row = x.row(i);
219            let leaf = traverse(&self.nodes, &row);
220            if let Node::Leaf {
221                class_distribution: Some(ref dist),
222                ..
223            } = self.nodes[leaf]
224            {
225                for (j, &p) in dist.iter().enumerate() {
226                    proba[[i, j]] = p;
227                }
228            }
229        }
230        Ok(proba)
231    }
232}
233
234impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ExtraTreeClassifier<F> {
235    type Fitted = FittedExtraTreeClassifier<F>;
236    type Error = FerroError;
237
238    /// Fit the extra-tree classifier on the training data.
239    ///
240    /// # Errors
241    ///
242    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
243    /// numbers of samples.
244    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
245    /// Returns [`FerroError::InvalidParameter`] if hyperparameters are invalid.
246    fn fit(
247        &self,
248        x: &Array2<F>,
249        y: &Array1<usize>,
250    ) -> Result<FittedExtraTreeClassifier<F>, FerroError> {
251        let (n_samples, n_features) = x.dim();
252
253        if n_samples != y.len() {
254            return Err(FerroError::ShapeMismatch {
255                expected: vec![n_samples],
256                actual: vec![y.len()],
257                context: "y length must match number of samples in X".into(),
258            });
259        }
260        if n_samples == 0 {
261            return Err(FerroError::InsufficientSamples {
262                required: 1,
263                actual: 0,
264                context: "ExtraTreeClassifier requires at least one sample".into(),
265            });
266        }
267        if self.min_samples_split < 2 {
268            return Err(FerroError::InvalidParameter {
269                name: "min_samples_split".into(),
270                reason: "must be at least 2".into(),
271            });
272        }
273        if self.min_samples_leaf < 1 {
274            return Err(FerroError::InvalidParameter {
275                name: "min_samples_leaf".into(),
276                reason: "must be at least 1".into(),
277            });
278        }
279
280        // Determine unique classes.
281        let mut classes: Vec<usize> = y.iter().copied().collect();
282        classes.sort_unstable();
283        classes.dedup();
284        let n_classes = classes.len();
285
286        // Map class labels to indices 0..n_classes.
287        let y_mapped: Vec<usize> = y
288            .iter()
289            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
290            .collect();
291
292        let indices: Vec<usize> = (0..n_samples).collect();
293
294        let max_features_n = resolve_max_features(self.max_features, n_features);
295
296        let mut rng = if let Some(seed) = self.random_state {
297            StdRng::seed_from_u64(seed)
298        } else {
299            StdRng::from_os_rng()
300        };
301
302        let data = ClassificationData {
303            x,
304            y: &y_mapped,
305            n_classes,
306            feature_indices: None,
307            criterion: self.criterion,
308        };
309        let params = TreeParams {
310            max_depth: self.max_depth,
311            min_samples_split: self.min_samples_split,
312            min_samples_leaf: self.min_samples_leaf,
313        };
314
315        let mut nodes: Vec<Node<F>> = Vec::new();
316        build_extra_classification_tree(
317            &data,
318            &indices,
319            &mut nodes,
320            0,
321            &params,
322            n_features,
323            max_features_n,
324            &mut rng,
325        );
326
327        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);
328
329        Ok(FittedExtraTreeClassifier {
330            nodes,
331            classes,
332            n_features,
333            feature_importances,
334        })
335    }
336}
337
338impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreeClassifier<F> {
339    type Output = Array1<usize>;
340    type Error = FerroError;
341
342    /// Predict class labels for the given feature matrix.
343    ///
344    /// # Errors
345    ///
346    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
347    /// not match the fitted model.
348    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
349        if x.ncols() != self.n_features {
350            return Err(FerroError::ShapeMismatch {
351                expected: vec![self.n_features],
352                actual: vec![x.ncols()],
353                context: "number of features must match fitted model".into(),
354            });
355        }
356        let n_samples = x.nrows();
357        let mut predictions = Array1::zeros(n_samples);
358        for i in 0..n_samples {
359            let row = x.row(i);
360            let leaf = traverse(&self.nodes, &row);
361            if let Node::Leaf { value, .. } = self.nodes[leaf] {
362                predictions[i] = float_to_usize(value);
363            }
364        }
365        Ok(predictions)
366    }
367}
368
369impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreeClassifier<F> {
370    fn feature_importances(&self) -> &Array1<F> {
371        &self.feature_importances
372    }
373}
374
375impl<F: Float + Send + Sync + 'static> HasClasses for FittedExtraTreeClassifier<F> {
376    fn classes(&self) -> &[usize] {
377        &self.classes
378    }
379
380    fn n_classes(&self) -> usize {
381        self.classes.len()
382    }
383}
384
385// Pipeline integration.
386impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
387    for ExtraTreeClassifier<F>
388{
389    fn fit_pipeline(
390        &self,
391        x: &Array2<F>,
392        y: &Array1<F>,
393    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
394        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
395        let fitted = self.fit(x, &y_usize)?;
396        Ok(Box::new(FittedExtraTreeClassifierPipelineAdapter(fitted)))
397    }
398}
399
400/// Pipeline adapter for `FittedExtraTreeClassifier<F>`.
401struct FittedExtraTreeClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
402    FittedExtraTreeClassifier<F>,
403);
404
405impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
406    for FittedExtraTreeClassifierPipelineAdapter<F>
407{
408    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
409        let preds = self.0.predict(x)?;
410        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
411    }
412}
413
414// ---------------------------------------------------------------------------
415// ExtraTreeRegressor
416// ---------------------------------------------------------------------------
417
418/// Extremely randomized tree regressor.
419///
420/// Like a [`DecisionTreeRegressor`](crate::DecisionTreeRegressor), but split
421/// thresholds are chosen randomly rather than via exhaustive search. For each
422/// candidate feature, a random threshold is drawn uniformly between the
423/// feature's minimum and maximum values in the current node.
424///
425/// # Type Parameters
426///
427/// - `F`: The floating-point type (`f32` or `f64`).
428#[derive(Debug, Clone, Serialize, Deserialize)]
429pub struct ExtraTreeRegressor<F> {
430    /// Maximum depth of the tree. `None` means unlimited.
431    pub max_depth: Option<usize>,
432    /// Minimum number of samples required to split an internal node.
433    pub min_samples_split: usize,
434    /// Minimum number of samples required in a leaf node.
435    pub min_samples_leaf: usize,
436    /// Strategy for the number of features considered at each split.
437    pub max_features: MaxFeatures,
438    /// Splitting criterion.
439    pub criterion: RegressionCriterion,
440    /// Random seed for reproducibility. `None` means non-deterministic.
441    pub random_state: Option<u64>,
442    _marker: std::marker::PhantomData<F>,
443}
444
445impl<F: Float> ExtraTreeRegressor<F> {
446    /// Create a new `ExtraTreeRegressor` with default settings.
447    ///
448    /// Defaults: `max_depth = None`, `min_samples_split = 2`,
449    /// `min_samples_leaf = 1`, `max_features = All`,
450    /// `criterion = MSE`, `random_state = None`.
451    #[must_use]
452    pub fn new() -> Self {
453        Self {
454            max_depth: None,
455            min_samples_split: 2,
456            min_samples_leaf: 1,
457            max_features: MaxFeatures::All,
458            criterion: RegressionCriterion::Mse,
459            random_state: None,
460            _marker: std::marker::PhantomData,
461        }
462    }
463
464    /// Set the maximum tree depth.
465    #[must_use]
466    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
467        self.max_depth = max_depth;
468        self
469    }
470
471    /// Set the minimum number of samples required to split a node.
472    #[must_use]
473    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
474        self.min_samples_split = min_samples_split;
475        self
476    }
477
478    /// Set the minimum number of samples required in a leaf node.
479    #[must_use]
480    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
481        self.min_samples_leaf = min_samples_leaf;
482        self
483    }
484
485    /// Set the maximum features strategy.
486    #[must_use]
487    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
488        self.max_features = max_features;
489        self
490    }
491
492    /// Set the splitting criterion.
493    #[must_use]
494    pub fn with_criterion(mut self, criterion: RegressionCriterion) -> Self {
495        self.criterion = criterion;
496        self
497    }
498
499    /// Set the random seed for reproducibility.
500    #[must_use]
501    pub fn with_random_state(mut self, seed: u64) -> Self {
502        self.random_state = Some(seed);
503        self
504    }
505}
506
507impl<F: Float> Default for ExtraTreeRegressor<F> {
508    fn default() -> Self {
509        Self::new()
510    }
511}
512
513// ---------------------------------------------------------------------------
514// FittedExtraTreeRegressor
515// ---------------------------------------------------------------------------
516
517/// A fitted extremely randomized tree regressor.
518///
519/// Stores the learned tree as a flat `Vec<Node<F>>` for cache-friendly traversal.
520#[derive(Debug, Clone)]
521pub struct FittedExtraTreeRegressor<F> {
522    /// Flat node storage; index 0 is the root.
523    nodes: Vec<Node<F>>,
524    /// Number of features the model was trained on.
525    n_features: usize,
526    /// Per-feature importance scores (normalised to sum to 1).
527    feature_importances: Array1<F>,
528}
529
530impl<F: Float + Send + Sync + 'static> FittedExtraTreeRegressor<F> {
531    /// Returns a reference to the flat node storage of the tree.
532    #[must_use]
533    pub fn nodes(&self) -> &[Node<F>] {
534        &self.nodes
535    }
536
537    /// Returns the number of features the model was trained on.
538    #[must_use]
539    pub fn n_features(&self) -> usize {
540        self.n_features
541    }
542}
543
544impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for ExtraTreeRegressor<F> {
545    type Fitted = FittedExtraTreeRegressor<F>;
546    type Error = FerroError;
547
548    /// Fit the extra-tree regressor on the training data.
549    ///
550    /// # Errors
551    ///
552    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
553    /// numbers of samples.
554    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
555    /// Returns [`FerroError::InvalidParameter`] if hyperparameters are invalid.
556    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedExtraTreeRegressor<F>, FerroError> {
557        let (n_samples, n_features) = x.dim();
558
559        if n_samples != y.len() {
560            return Err(FerroError::ShapeMismatch {
561                expected: vec![n_samples],
562                actual: vec![y.len()],
563                context: "y length must match number of samples in X".into(),
564            });
565        }
566        if n_samples == 0 {
567            return Err(FerroError::InsufficientSamples {
568                required: 1,
569                actual: 0,
570                context: "ExtraTreeRegressor requires at least one sample".into(),
571            });
572        }
573        if self.min_samples_split < 2 {
574            return Err(FerroError::InvalidParameter {
575                name: "min_samples_split".into(),
576                reason: "must be at least 2".into(),
577            });
578        }
579        if self.min_samples_leaf < 1 {
580            return Err(FerroError::InvalidParameter {
581                name: "min_samples_leaf".into(),
582                reason: "must be at least 1".into(),
583            });
584        }
585
586        let indices: Vec<usize> = (0..n_samples).collect();
587        let max_features_n = resolve_max_features(self.max_features, n_features);
588
589        let mut rng = if let Some(seed) = self.random_state {
590            StdRng::seed_from_u64(seed)
591        } else {
592            StdRng::from_os_rng()
593        };
594
595        let data = RegressionData {
596            x,
597            y,
598            feature_indices: None,
599        };
600        let params = TreeParams {
601            max_depth: self.max_depth,
602            min_samples_split: self.min_samples_split,
603            min_samples_leaf: self.min_samples_leaf,
604        };
605
606        let mut nodes: Vec<Node<F>> = Vec::new();
607        build_extra_regression_tree(
608            &data,
609            &indices,
610            &mut nodes,
611            0,
612            &params,
613            n_features,
614            max_features_n,
615            &mut rng,
616        );
617
618        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);
619
620        Ok(FittedExtraTreeRegressor {
621            nodes,
622            n_features,
623            feature_importances,
624        })
625    }
626}
627
628impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreeRegressor<F> {
629    type Output = Array1<F>;
630    type Error = FerroError;
631
632    /// Predict target values for the given feature matrix.
633    ///
634    /// # Errors
635    ///
636    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
637    /// not match the fitted model.
638    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
639        if x.ncols() != self.n_features {
640            return Err(FerroError::ShapeMismatch {
641                expected: vec![self.n_features],
642                actual: vec![x.ncols()],
643                context: "number of features must match fitted model".into(),
644            });
645        }
646        let n_samples = x.nrows();
647        let mut predictions = Array1::zeros(n_samples);
648        for i in 0..n_samples {
649            let row = x.row(i);
650            let leaf = traverse(&self.nodes, &row);
651            if let Node::Leaf { value, .. } = self.nodes[leaf] {
652                predictions[i] = value;
653            }
654        }
655        Ok(predictions)
656    }
657}
658
659impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreeRegressor<F> {
660    fn feature_importances(&self) -> &Array1<F> {
661        &self.feature_importances
662    }
663}
664
665// Pipeline integration.
666impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for ExtraTreeRegressor<F> {
667    fn fit_pipeline(
668        &self,
669        x: &Array2<F>,
670        y: &Array1<F>,
671    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
672        let fitted = self.fit(x, y)?;
673        Ok(Box::new(fitted))
674    }
675}
676
677impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedExtraTreeRegressor<F> {
678    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
679        self.predict(x)
680    }
681}
682
683// ---------------------------------------------------------------------------
684// Internal: helpers
685// ---------------------------------------------------------------------------
686
687/// Resolve the `MaxFeatures` strategy to a concrete number.
688fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
689    let result = match strategy {
690        MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
691        MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
692        MaxFeatures::All => n_features,
693        MaxFeatures::Fixed(n) => n.min(n_features),
694        MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
695    };
696    result.max(1).min(n_features)
697}
698
699/// Convert a `Float` value to `usize` (for class labels stored as floats).
700fn float_to_usize<F: Float>(v: F) -> usize {
701    v.to_f64().map_or(0, |f| f.round() as usize)
702}
703
704/// Generate a uniform random float in `[min_val, max_val]`.
705fn random_threshold<F: Float>(rng: &mut StdRng, min_val: F, max_val: F) -> F {
706    use rand::RngCore;
707    // Generate a random f64 in [0, 1) and scale to [min_val, max_val].
708    let u = (rng.next_u64() as f64) / (u64::MAX as f64);
709    let range = max_val - min_val;
710    min_val + F::from(u).unwrap() * range
711}
712
713/// Compute the Gini impurity for a set of class counts.
714fn gini_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
715    if total == 0 {
716        return F::zero();
717    }
718    let total_f = F::from(total).unwrap();
719    let mut impurity = F::one();
720    for &count in class_counts {
721        let p = F::from(count).unwrap() / total_f;
722        impurity = impurity - p * p;
723    }
724    impurity
725}
726
727/// Compute the Shannon entropy for a set of class counts.
728fn entropy_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
729    if total == 0 {
730        return F::zero();
731    }
732    let total_f = F::from(total).unwrap();
733    let mut ent = F::zero();
734    for &count in class_counts {
735        if count > 0 {
736            let p = F::from(count).unwrap() / total_f;
737            ent = ent - p * p.ln();
738        }
739    }
740    ent
741}
742
743/// Compute impurity for a given classification criterion.
744fn compute_impurity<F: Float>(
745    class_counts: &[usize],
746    total: usize,
747    criterion: ClassificationCriterion,
748) -> F {
749    match criterion {
750        ClassificationCriterion::Gini => gini_impurity(class_counts, total),
751        ClassificationCriterion::Entropy => entropy_impurity(class_counts, total),
752    }
753}
754
755/// Create a classification leaf node and return its index.
756fn make_classification_leaf<F: Float>(
757    nodes: &mut Vec<Node<F>>,
758    class_counts: &[usize],
759    n_classes: usize,
760    n_samples: usize,
761) -> usize {
762    let majority_class = class_counts
763        .iter()
764        .enumerate()
765        .max_by_key(|&(_, &count)| count)
766        .map_or(0, |(i, _)| i);
767
768    let total_f = if n_samples > 0 {
769        F::from(n_samples).unwrap()
770    } else {
771        F::one()
772    };
773    let distribution: Vec<F> = (0..n_classes)
774        .map(|c| F::from(class_counts[c]).unwrap() / total_f)
775        .collect();
776
777    let idx = nodes.len();
778    nodes.push(Node::Leaf {
779        value: F::from(majority_class).unwrap(),
780        class_distribution: Some(distribution),
781        n_samples,
782    });
783    idx
784}
785
786/// Compute the mean of target values for the given indices.
787fn mean_value<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
788    if indices.is_empty() {
789        return F::zero();
790    }
791    let sum: F = indices.iter().map(|&i| y[i]).fold(F::zero(), |a, b| a + b);
792    sum / F::from(indices.len()).unwrap()
793}
794
795// ---------------------------------------------------------------------------
796// Extra-tree classification building
797// ---------------------------------------------------------------------------
798
799/// Build an extra-tree classification tree recursively with random thresholds.
800///
801/// At each node, a random subset of features is considered, and for each feature
802/// a random threshold is drawn uniformly between the feature's min and max in the
803/// current node.
804#[allow(clippy::too_many_arguments)]
805fn build_extra_classification_tree<F: Float>(
806    data: &ClassificationData<'_, F>,
807    indices: &[usize],
808    nodes: &mut Vec<Node<F>>,
809    depth: usize,
810    params: &TreeParams,
811    n_features: usize,
812    max_features_n: usize,
813    rng: &mut StdRng,
814) -> usize {
815    let n = indices.len();
816
817    let mut class_counts = vec![0usize; data.n_classes];
818    for &i in indices {
819        class_counts[data.y[i]] += 1;
820    }
821
822    let should_stop = n < params.min_samples_split
823        || params.max_depth.is_some_and(|d| depth >= d)
824        || class_counts.iter().filter(|&&c| c > 0).count() <= 1;
825
826    if should_stop {
827        return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
828    }
829
830    let best = find_random_classification_split(
831        data,
832        indices,
833        params.min_samples_leaf,
834        n_features,
835        max_features_n,
836        rng,
837    );
838
839    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
840        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
841            .iter()
842            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);
843
844        // Ensure both children have at least min_samples_leaf.
845        if left_indices.len() < params.min_samples_leaf
846            || right_indices.len() < params.min_samples_leaf
847        {
848            return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
849        }
850
851        let node_idx = nodes.len();
852        nodes.push(Node::Leaf {
853            value: F::zero(),
854            class_distribution: None,
855            n_samples: 0,
856        }); // placeholder
857
858        let left_idx = build_extra_classification_tree(
859            data,
860            &left_indices,
861            nodes,
862            depth + 1,
863            params,
864            n_features,
865            max_features_n,
866            rng,
867        );
868        let right_idx = build_extra_classification_tree(
869            data,
870            &right_indices,
871            nodes,
872            depth + 1,
873            params,
874            n_features,
875            max_features_n,
876            rng,
877        );
878
879        nodes[node_idx] = Node::Split {
880            feature: best_feature,
881            threshold: best_threshold,
882            left: left_idx,
883            right: right_idx,
884            impurity_decrease: best_impurity_decrease,
885            n_samples: n,
886        };
887
888        node_idx
889    } else {
890        make_classification_leaf(nodes, &class_counts, data.n_classes, n)
891    }
892}
893
894/// Find the best random split for a classification node.
895///
896/// For each candidate feature (from a random subset), pick a random threshold
897/// between min and max of that feature in the current node. Return the split
898/// with the largest impurity decrease, or `None` if no valid split exists.
899#[allow(clippy::too_many_arguments)]
900fn find_random_classification_split<F: Float>(
901    data: &ClassificationData<'_, F>,
902    indices: &[usize],
903    min_samples_leaf: usize,
904    n_features: usize,
905    max_features_n: usize,
906    rng: &mut StdRng,
907) -> Option<(usize, F, F)> {
908    let n = indices.len();
909    let n_f = F::from(n).unwrap();
910
911    let mut parent_counts = vec![0usize; data.n_classes];
912    for &i in indices {
913        parent_counts[data.y[i]] += 1;
914    }
915    let parent_impurity = compute_impurity::<F>(&parent_counts, n, data.criterion);
916
917    let mut best_score = F::neg_infinity();
918    let mut best_feature = 0;
919    let mut best_threshold = F::zero();
920
921    // Select random feature subset.
922    let feature_subset: Vec<usize> = if let Some(feat_indices) = data.feature_indices {
923        // If a feature subset is provided externally, sample from it.
924        let k = max_features_n.min(feat_indices.len());
925        rand_sample_indices(rng, feat_indices.len(), k)
926            .into_vec()
927            .into_iter()
928            .map(|i| feat_indices[i])
929            .collect()
930    } else {
931        let k = max_features_n.min(n_features);
932        rand_sample_indices(rng, n_features, k).into_vec()
933    };
934
935    for feat in feature_subset {
936        // Find min and max of this feature in the current node.
937        let mut feat_min = F::infinity();
938        let mut feat_max = F::neg_infinity();
939        for &i in indices {
940            let val = data.x[[i, feat]];
941            if val < feat_min {
942                feat_min = val;
943            }
944            if val > feat_max {
945                feat_max = val;
946            }
947        }
948
949        // Skip constant features.
950        if feat_min >= feat_max {
951            continue;
952        }
953
954        // Draw a random threshold uniformly in (min, max).
955        let threshold = random_threshold(rng, feat_min, feat_max);
956
957        // Count left and right.
958        let mut left_counts = vec![0usize; data.n_classes];
959        let mut right_counts = vec![0usize; data.n_classes];
960        let mut left_n = 0usize;
961
962        for &i in indices {
963            let cls = data.y[i];
964            if data.x[[i, feat]] <= threshold {
965                left_counts[cls] += 1;
966                left_n += 1;
967            } else {
968                right_counts[cls] += 1;
969            }
970        }
971
972        let right_n = n - left_n;
973        if left_n < min_samples_leaf || right_n < min_samples_leaf {
974            continue;
975        }
976
977        let left_impurity = compute_impurity::<F>(&left_counts, left_n, data.criterion);
978        let right_impurity = compute_impurity::<F>(&right_counts, right_n, data.criterion);
979        let left_weight = F::from(left_n).unwrap() / n_f;
980        let right_weight = F::from(right_n).unwrap() / n_f;
981        let weighted_child_impurity = left_weight * left_impurity + right_weight * right_impurity;
982        let impurity_decrease = parent_impurity - weighted_child_impurity;
983
984        if impurity_decrease > best_score {
985            best_score = impurity_decrease;
986            best_feature = feat;
987            best_threshold = threshold;
988        }
989    }
990
991    if best_score > F::zero() {
992        Some((best_feature, best_threshold, best_score * n_f))
993    } else {
994        None
995    }
996}
997
998// ---------------------------------------------------------------------------
999// Extra-tree regression building
1000// ---------------------------------------------------------------------------
1001
1002/// Build an extra-tree regression tree recursively with random thresholds.
1003#[allow(clippy::too_many_arguments)]
1004fn build_extra_regression_tree<F: Float>(
1005    data: &RegressionData<'_, F>,
1006    indices: &[usize],
1007    nodes: &mut Vec<Node<F>>,
1008    depth: usize,
1009    params: &TreeParams,
1010    n_features: usize,
1011    max_features_n: usize,
1012    rng: &mut StdRng,
1013) -> usize {
1014    let n = indices.len();
1015    let mean = mean_value(data.y, indices);
1016
1017    let should_stop = n < params.min_samples_split || params.max_depth.is_some_and(|d| depth >= d);
1018
1019    if should_stop {
1020        let idx = nodes.len();
1021        nodes.push(Node::Leaf {
1022            value: mean,
1023            class_distribution: None,
1024            n_samples: n,
1025        });
1026        return idx;
1027    }
1028
1029    // Check if variance is essentially zero.
1030    let parent_sum_sq: F = indices
1031        .iter()
1032        .map(|&i| {
1033            let diff = data.y[i] - mean;
1034            diff * diff
1035        })
1036        .fold(F::zero(), |a, b| a + b);
1037    let parent_mse = parent_sum_sq / F::from(n).unwrap();
1038
1039    if parent_mse <= F::epsilon() {
1040        let idx = nodes.len();
1041        nodes.push(Node::Leaf {
1042            value: mean,
1043            class_distribution: None,
1044            n_samples: n,
1045        });
1046        return idx;
1047    }
1048
1049    let best = find_random_regression_split(
1050        data,
1051        indices,
1052        params.min_samples_leaf,
1053        n_features,
1054        max_features_n,
1055        rng,
1056    );
1057
1058    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
1059        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
1060            .iter()
1061            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);
1062
1063        // Ensure both children have at least min_samples_leaf.
1064        if left_indices.len() < params.min_samples_leaf
1065            || right_indices.len() < params.min_samples_leaf
1066        {
1067            let idx = nodes.len();
1068            nodes.push(Node::Leaf {
1069                value: mean,
1070                class_distribution: None,
1071                n_samples: n,
1072            });
1073            return idx;
1074        }
1075
1076        let node_idx = nodes.len();
1077        nodes.push(Node::Leaf {
1078            value: F::zero(),
1079            class_distribution: None,
1080            n_samples: 0,
1081        }); // placeholder
1082
1083        let left_idx = build_extra_regression_tree(
1084            data,
1085            &left_indices,
1086            nodes,
1087            depth + 1,
1088            params,
1089            n_features,
1090            max_features_n,
1091            rng,
1092        );
1093        let right_idx = build_extra_regression_tree(
1094            data,
1095            &right_indices,
1096            nodes,
1097            depth + 1,
1098            params,
1099            n_features,
1100            max_features_n,
1101            rng,
1102        );
1103
1104        nodes[node_idx] = Node::Split {
1105            feature: best_feature,
1106            threshold: best_threshold,
1107            left: left_idx,
1108            right: right_idx,
1109            impurity_decrease: best_impurity_decrease,
1110            n_samples: n,
1111        };
1112
1113        node_idx
1114    } else {
1115        let idx = nodes.len();
1116        nodes.push(Node::Leaf {
1117            value: mean,
1118            class_distribution: None,
1119            n_samples: n,
1120        });
1121        idx
1122    }
1123}
1124
1125/// Find the best random split for a regression node.
1126///
1127/// For each candidate feature (from a random subset), pick a random threshold
1128/// between min and max of that feature in the current node. Return the split
1129/// with the largest MSE decrease, or `None` if no valid split exists.
1130#[allow(clippy::too_many_arguments)]
1131fn find_random_regression_split<F: Float>(
1132    data: &RegressionData<'_, F>,
1133    indices: &[usize],
1134    min_samples_leaf: usize,
1135    n_features: usize,
1136    max_features_n: usize,
1137    rng: &mut StdRng,
1138) -> Option<(usize, F, F)> {
1139    let n = indices.len();
1140    let n_f = F::from(n).unwrap();
1141
1142    let parent_sum: F = indices
1143        .iter()
1144        .map(|&i| data.y[i])
1145        .fold(F::zero(), |a, b| a + b);
1146    let parent_sum_sq: F = indices
1147        .iter()
1148        .map(|&i| data.y[i] * data.y[i])
1149        .fold(F::zero(), |a, b| a + b);
1150    let parent_mse = parent_sum_sq / n_f - (parent_sum / n_f) * (parent_sum / n_f);
1151
1152    let mut best_score = F::neg_infinity();
1153    let mut best_feature = 0;
1154    let mut best_threshold = F::zero();
1155
1156    // Select random feature subset.
1157    let feature_subset: Vec<usize> = if let Some(feat_indices) = data.feature_indices {
1158        let k = max_features_n.min(feat_indices.len());
1159        rand_sample_indices(rng, feat_indices.len(), k)
1160            .into_vec()
1161            .into_iter()
1162            .map(|i| feat_indices[i])
1163            .collect()
1164    } else {
1165        let k = max_features_n.min(n_features);
1166        rand_sample_indices(rng, n_features, k).into_vec()
1167    };
1168
1169    for feat in feature_subset {
1170        // Find min and max of this feature in the current node.
1171        let mut feat_min = F::infinity();
1172        let mut feat_max = F::neg_infinity();
1173        for &i in indices {
1174            let val = data.x[[i, feat]];
1175            if val < feat_min {
1176                feat_min = val;
1177            }
1178            if val > feat_max {
1179                feat_max = val;
1180            }
1181        }
1182
1183        // Skip constant features.
1184        if feat_min >= feat_max {
1185            continue;
1186        }
1187
1188        // Draw a random threshold uniformly in (min, max).
1189        let threshold = random_threshold(rng, feat_min, feat_max);
1190
1191        // Compute left/right statistics.
1192        let mut left_sum = F::zero();
1193        let mut left_sum_sq = F::zero();
1194        let mut left_n: usize = 0;
1195
1196        for &i in indices {
1197            if data.x[[i, feat]] <= threshold {
1198                let val = data.y[i];
1199                left_sum = left_sum + val;
1200                left_sum_sq = left_sum_sq + val * val;
1201                left_n += 1;
1202            }
1203        }
1204
1205        let right_n = n - left_n;
1206        if left_n < min_samples_leaf || right_n < min_samples_leaf {
1207            continue;
1208        }
1209
1210        let left_n_f = F::from(left_n).unwrap();
1211        let right_n_f = F::from(right_n).unwrap();
1212
1213        let left_mean = left_sum / left_n_f;
1214        let left_mse = left_sum_sq / left_n_f - left_mean * left_mean;
1215
1216        let right_sum = parent_sum - left_sum;
1217        let right_sum_sq = parent_sum_sq - left_sum_sq;
1218        let right_mean = right_sum / right_n_f;
1219        let right_mse = right_sum_sq / right_n_f - right_mean * right_mean;
1220
1221        let weighted_child_mse = (left_n_f * left_mse + right_n_f * right_mse) / n_f;
1222        let mse_decrease = parent_mse - weighted_child_mse;
1223
1224        if mse_decrease > best_score {
1225            best_score = mse_decrease;
1226            best_feature = feat;
1227            best_threshold = threshold;
1228        }
1229    }
1230
1231    if best_score > F::zero() {
1232        Some((best_feature, best_threshold, best_score * n_f))
1233    } else {
1234        None
1235    }
1236}
1237
1238// ---------------------------------------------------------------------------
1239// Crate-internal functions for ensemble usage
1240// ---------------------------------------------------------------------------
1241
1242/// Build a classification extra-tree with a subset of features for ensemble use.
1243///
1244/// Used internally by `ExtraTreesClassifier` to build individual trees.
1245#[allow(clippy::too_many_arguments)]
1246pub(crate) fn build_extra_classification_tree_for_ensemble<F: Float>(
1247    x: &Array2<F>,
1248    y: &[usize],
1249    n_classes: usize,
1250    indices: &[usize],
1251    feature_indices: Option<&[usize]>,
1252    params: &TreeParams,
1253    criterion: ClassificationCriterion,
1254    n_features: usize,
1255    max_features_n: usize,
1256    rng: &mut StdRng,
1257) -> Vec<Node<F>> {
1258    let data = ClassificationData {
1259        x,
1260        y,
1261        n_classes,
1262        feature_indices,
1263        criterion,
1264    };
1265    let mut nodes = Vec::new();
1266    build_extra_classification_tree(
1267        &data,
1268        indices,
1269        &mut nodes,
1270        0,
1271        params,
1272        n_features,
1273        max_features_n,
1274        rng,
1275    );
1276    nodes
1277}
1278
1279/// Build a regression extra-tree with a subset of features for ensemble use.
1280///
1281/// Used internally by `ExtraTreesRegressor` to build individual trees.
1282#[allow(clippy::too_many_arguments)]
1283pub(crate) fn build_extra_regression_tree_for_ensemble<F: Float>(
1284    x: &Array2<F>,
1285    y: &Array1<F>,
1286    indices: &[usize],
1287    feature_indices: Option<&[usize]>,
1288    params: &TreeParams,
1289    n_features: usize,
1290    max_features_n: usize,
1291    rng: &mut StdRng,
1292) -> Vec<Node<F>> {
1293    let data = RegressionData {
1294        x,
1295        y,
1296        feature_indices,
1297    };
1298    let mut nodes = Vec::new();
1299    build_extra_regression_tree(
1300        &data,
1301        indices,
1302        &mut nodes,
1303        0,
1304        params,
1305        n_features,
1306        max_features_n,
1307        rng,
1308    );
1309    nodes
1310}
1311
1312// ---------------------------------------------------------------------------
1313// Tests
1314// ---------------------------------------------------------------------------
1315
// Unit tests for `ExtraTreeClassifier` and `ExtraTreeRegressor`, covering:
// - fitting and prediction on small hand-made datasets;
// - input-validation errors;
// - determinism under a fixed `random_state`;
// - builder setters and `Default` values.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // -- Classifier tests --

    #[test]
    fn test_extra_classifier_simple_binary() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        // ExtraTrees should separate linearly separable data.
        for i in 0..3 {
            assert_eq!(preds[i], 0, "sample {i} should be class 0");
        }
        for i in 3..6 {
            assert_eq!(preds[i], 1, "sample {i} should be class 1");
        }
    }

    #[test]
    fn test_extra_classifier_single_class() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds, array![0, 0, 0]);
    }

    #[test]
    fn test_extra_classifier_max_depth_1() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_depth(Some(1))
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let _preds = fitted.predict(&x).unwrap();

        // With max_depth = 1 the root may split once, but its children must be leaves.
        // For this data, any threshold that leaves both children non-empty lowers the
        // impurity, so the tree is exactly one split node plus two leaves. The threshold
        // is random, so perfect class separation is NOT guaranteed and is not asserted.
        assert_eq!(fitted.nodes().len(), 3);
    }

    #[test]
    fn test_extra_classifier_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        assert_eq!(proba.dim(), (6, 2));
        // Each row sums to 1.
        for i in 0..6 {
            let row_sum = proba.row(i).sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_extra_classifier_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0, 1.0, 6.0, 1.0, 7.0, 1.0, 8.0, 1.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        // The sum of importances should be 1 (normalised).
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        // Feature 0 should have higher importance (feature 1 is constant).
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_extra_classifier_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0]; // wrong length

        let model = ExtraTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = ExtraTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_classifier_invalid_min_samples_split() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 1];

        let model = ExtraTreeClassifier::<f64>::new().with_min_samples_split(1);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_classifier_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 2, 2, 2]; // non-contiguous classes

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 2]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_extra_classifier_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    #[test]
    fn test_extra_classifier_f32() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0f32, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f32>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_extra_classifier_deterministic() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        // Same seed, same data => identical trees and therefore identical predictions.
        let model1 = ExtraTreeClassifier::<f64>::new().with_random_state(123);
        let model2 = ExtraTreeClassifier::<f64>::new().with_random_state(123);

        let fitted1 = model1.fit(&x, &y).unwrap();
        let fitted2 = model2.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        assert_eq!(preds1, preds2);
    }

    // -- Regressor tests --

    #[test]
    fn test_extra_regressor_simple() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // A deep extra-tree should roughly memorize the training data.
        assert_eq!(preds.len(), 6);
        for i in 0..6 {
            assert_relative_eq!(preds[i], y[i], epsilon = 1.0);
        }
    }

    #[test]
    fn test_extra_regressor_constant_target() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![5.0, 5.0, 5.0, 5.0];

        let model = ExtraTreeRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for &p in &preds {
            assert_relative_eq!(p, 5.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_extra_regressor_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        // Feature 0 drives the target; feature 1 is constant.
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_extra_regressor_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0]; // wrong length

        let model = ExtraTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = ExtraTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_regressor_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreeRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    #[test]
    fn test_extra_regressor_max_depth() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_depth(Some(1))
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        // With depth 1, the tree should have exactly 3 nodes: one split + two leaves.
        assert_eq!(fitted.nodes().len(), 3);
    }

    #[test]
    fn test_extra_regressor_deterministic() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        // Same seed, same data => identical trees and therefore identical predictions.
        let model1 = ExtraTreeRegressor::<f64>::new().with_random_state(99);
        let model2 = ExtraTreeRegressor::<f64>::new().with_random_state(99);

        let fitted1 = model1.fit(&x, &y).unwrap();
        let fitted2 = model2.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        for i in 0..6 {
            assert_relative_eq!(preds1[i], preds2[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_extra_regressor_f32() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0f32, 2.0, 3.0, 4.0];

        let model = ExtraTreeRegressor::<f32>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // -- Builder tests --

    #[test]
    fn test_classifier_builder_methods() {
        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_depth(Some(5))
            .with_min_samples_split(10)
            .with_min_samples_leaf(3)
            .with_max_features(MaxFeatures::Log2)
            .with_criterion(ClassificationCriterion::Entropy)
            .with_random_state(42);

        assert_eq!(model.max_depth, Some(5));
        assert_eq!(model.min_samples_split, 10);
        assert_eq!(model.min_samples_leaf, 3);
        assert_eq!(model.max_features, MaxFeatures::Log2);
        assert_eq!(model.criterion, ClassificationCriterion::Entropy);
        assert_eq!(model.random_state, Some(42));
    }

    #[test]
    fn test_regressor_builder_methods() {
        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_depth(Some(10))
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_max_features(MaxFeatures::Fixed(3))
            .with_criterion(RegressionCriterion::Mse)
            .with_random_state(99);

        assert_eq!(model.max_depth, Some(10));
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.min_samples_leaf, 2);
        assert_eq!(model.max_features, MaxFeatures::Fixed(3));
        assert_eq!(model.criterion, RegressionCriterion::Mse);
        assert_eq!(model.random_state, Some(99));
    }

    // Defaults: the classifier samples sqrt(n_features) candidates per split,
    // the regressor considers all features.
    #[test]
    fn test_classifier_default() {
        let model = ExtraTreeClassifier::<f64>::default();
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::Sqrt);
        assert_eq!(model.criterion, ClassificationCriterion::Gini);
        assert_eq!(model.random_state, None);
    }

    #[test]
    fn test_regressor_default() {
        let model = ExtraTreeRegressor::<f64>::default();
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::All);
        assert_eq!(model.criterion, RegressionCriterion::Mse);
        assert_eq!(model.random_state, None);
    }
}