// ferrolearn_tree/extra_tree.rs
//! Extremely randomized tree classifiers and regressors.
//!
//! This module provides [`ExtraTreeClassifier`] and [`ExtraTreeRegressor`],
//! which are variants of decision trees where split thresholds are chosen
//! randomly rather than via exhaustive search. For each candidate feature,
//! a random threshold is drawn uniformly between the feature's minimum and
//! maximum values in the current node.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_tree::ExtraTreeClassifier;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array1, Array2};
//!
//! let x = Array2::from_shape_vec((6, 2), vec![
//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,
//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,
//! ]).unwrap();
//! let y = array![0, 0, 0, 1, 1, 1];
//!
//! let model = ExtraTreeClassifier::<f64>::new()
//!     .with_random_state(42);
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! ```

28use ferrolearn_core::error::FerroError;
29use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
30use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
31use ferrolearn_core::traits::{Fit, Predict};
32use ndarray::{Array1, Array2};
33use num_traits::{Float, FromPrimitive, ToPrimitive};
34use rand::SeedableRng;
35use rand::rngs::StdRng;
36use rand::seq::index::sample as rand_sample_indices;
37use serde::{Deserialize, Serialize};
38
39use crate::decision_tree::{
40    ClassificationCriterion, Node, RegressionCriterion, TreeParams, compute_feature_importances,
41    traverse,
42};
43use crate::random_forest::MaxFeatures;
44
45// ---------------------------------------------------------------------------
46// Internal data structs for extra-tree building
47// ---------------------------------------------------------------------------
48
/// Data references for classification extra-tree building.
///
/// Bundles the immutable inputs threaded through the recursive tree builder
/// so the recursive functions keep manageable argument lists.
struct ClassificationData<'a, F> {
    /// Feature matrix, shape `(n_samples, n_features)`.
    x: &'a Array2<F>,
    /// Per-sample class indices in `0..n_classes` (labels are pre-mapped in `fit`).
    y: &'a [usize],
    /// Number of distinct classes.
    n_classes: usize,
    /// Optional restriction of candidate feature columns; `None` means all features.
    feature_indices: Option<&'a [usize]>,
    /// Impurity criterion used to score candidate splits.
    criterion: ClassificationCriterion,
}
57
/// Data references for regression extra-tree building.
struct RegressionData<'a, F> {
    /// Feature matrix, shape `(n_samples, n_features)`.
    x: &'a Array2<F>,
    /// Continuous targets, one per sample.
    y: &'a Array1<F>,
    /// Optional restriction of candidate feature columns; `None` means all features.
    feature_indices: Option<&'a [usize]>,
    // NOTE(review): unlike `ClassificationData` there is no criterion field here,
    // yet `ExtraTreeRegressor` exposes a `criterion` parameter — confirm the
    // regression builder actually honours the configured criterion.
}
64
65// ---------------------------------------------------------------------------
66// ExtraTreeClassifier
67// ---------------------------------------------------------------------------
68
/// Extremely randomized tree classifier.
///
/// Like a [`DecisionTreeClassifier`](crate::DecisionTreeClassifier), but split
/// thresholds are chosen randomly rather than via exhaustive search. For each
/// candidate feature, a random threshold is drawn uniformly between the
/// feature's minimum and maximum values in the current node.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreeClassifier<F> {
    /// Maximum depth of the tree. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Strategy for the number of features considered at each split.
    pub max_features: MaxFeatures,
    /// Splitting criterion (Gini impurity or entropy).
    pub criterion: ClassificationCriterion,
    /// Random seed for reproducibility. `None` means non-deterministic.
    pub random_state: Option<u64>,
    // Ties the unconstrained `F` parameter to the struct without storing data.
    _marker: std::marker::PhantomData<F>,
}
95
96impl<F: Float> ExtraTreeClassifier<F> {
97    /// Create a new `ExtraTreeClassifier` with default settings.
98    ///
99    /// Defaults: `max_depth = None`, `min_samples_split = 2`,
100    /// `min_samples_leaf = 1`, `max_features = Sqrt`,
101    /// `criterion = Gini`, `random_state = None`.
102    #[must_use]
103    pub fn new() -> Self {
104        Self {
105            max_depth: None,
106            min_samples_split: 2,
107            min_samples_leaf: 1,
108            max_features: MaxFeatures::Sqrt,
109            criterion: ClassificationCriterion::Gini,
110            random_state: None,
111            _marker: std::marker::PhantomData,
112        }
113    }
114
115    /// Set the maximum tree depth.
116    #[must_use]
117    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
118        self.max_depth = max_depth;
119        self
120    }
121
122    /// Set the minimum number of samples required to split a node.
123    #[must_use]
124    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
125        self.min_samples_split = min_samples_split;
126        self
127    }
128
129    /// Set the minimum number of samples required in a leaf node.
130    #[must_use]
131    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
132        self.min_samples_leaf = min_samples_leaf;
133        self
134    }
135
136    /// Set the maximum features strategy.
137    #[must_use]
138    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
139        self.max_features = max_features;
140        self
141    }
142
143    /// Set the splitting criterion.
144    #[must_use]
145    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
146        self.criterion = criterion;
147        self
148    }
149
150    /// Set the random seed for reproducibility.
151    #[must_use]
152    pub fn with_random_state(mut self, seed: u64) -> Self {
153        self.random_state = Some(seed);
154        self
155    }
156}
157
impl<F: Float> Default for ExtraTreeClassifier<F> {
    /// Equivalent to [`ExtraTreeClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
163
164// ---------------------------------------------------------------------------
165// FittedExtraTreeClassifier
166// ---------------------------------------------------------------------------
167
/// A fitted extremely randomized tree classifier.
///
/// Stores the learned tree as a flat `Vec<Node<F>>` for cache-friendly traversal.
/// Implements [`Predict`] for generating class predictions and
/// [`HasFeatureImportances`] for inspecting per-feature importance scores.
#[derive(Debug, Clone)]
pub struct FittedExtraTreeClassifier<F> {
    /// Flat node storage; index 0 is the root.
    nodes: Vec<Node<F>>,
    /// Sorted unique class labels observed during training. Leaf class
    /// distributions (and leaf class indices) are ordered to match this list.
    classes: Vec<usize>,
    /// Number of features the model was trained on.
    n_features: usize,
    /// Per-feature importance scores (normalised to sum to 1).
    feature_importances: Array1<F>,
}
184
185impl<F: Float + Send + Sync + 'static> FittedExtraTreeClassifier<F> {
186    /// Returns a reference to the flat node storage of the tree.
187    #[must_use]
188    pub fn nodes(&self) -> &[Node<F>] {
189        &self.nodes
190    }
191
192    /// Returns the number of features the model was trained on.
193    #[must_use]
194    pub fn n_features(&self) -> usize {
195        self.n_features
196    }
197
198    /// Predict class probabilities for each sample.
199    ///
200    /// Returns a 2-D array of shape `(n_samples, n_classes)`.
201    ///
202    /// # Errors
203    ///
204    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
205    /// not match the training data.
206    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
207        if x.ncols() != self.n_features {
208            return Err(FerroError::ShapeMismatch {
209                expected: vec![self.n_features],
210                actual: vec![x.ncols()],
211                context: "number of features must match fitted model".into(),
212            });
213        }
214        let n_samples = x.nrows();
215        let n_classes = self.classes.len();
216        let mut proba = Array2::zeros((n_samples, n_classes));
217        for i in 0..n_samples {
218            let row = x.row(i);
219            let leaf = traverse(&self.nodes, &row);
220            if let Node::Leaf {
221                class_distribution: Some(ref dist),
222                ..
223            } = self.nodes[leaf]
224            {
225                for (j, &p) in dist.iter().enumerate() {
226                    proba[[i, j]] = p;
227                }
228            }
229        }
230        Ok(proba)
231    }
232}
233
234impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ExtraTreeClassifier<F> {
235    type Fitted = FittedExtraTreeClassifier<F>;
236    type Error = FerroError;
237
238    /// Fit the extra-tree classifier on the training data.
239    ///
240    /// # Errors
241    ///
242    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
243    /// numbers of samples.
244    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
245    /// Returns [`FerroError::InvalidParameter`] if hyperparameters are invalid.
246    fn fit(
247        &self,
248        x: &Array2<F>,
249        y: &Array1<usize>,
250    ) -> Result<FittedExtraTreeClassifier<F>, FerroError> {
251        let (n_samples, n_features) = x.dim();
252
253        if n_samples != y.len() {
254            return Err(FerroError::ShapeMismatch {
255                expected: vec![n_samples],
256                actual: vec![y.len()],
257                context: "y length must match number of samples in X".into(),
258            });
259        }
260        if n_samples == 0 {
261            return Err(FerroError::InsufficientSamples {
262                required: 1,
263                actual: 0,
264                context: "ExtraTreeClassifier requires at least one sample".into(),
265            });
266        }
267        if self.min_samples_split < 2 {
268            return Err(FerroError::InvalidParameter {
269                name: "min_samples_split".into(),
270                reason: "must be at least 2".into(),
271            });
272        }
273        if self.min_samples_leaf < 1 {
274            return Err(FerroError::InvalidParameter {
275                name: "min_samples_leaf".into(),
276                reason: "must be at least 1".into(),
277            });
278        }
279
280        // Determine unique classes.
281        let mut classes: Vec<usize> = y.iter().copied().collect();
282        classes.sort_unstable();
283        classes.dedup();
284        let n_classes = classes.len();
285
286        // Map class labels to indices 0..n_classes.
287        let y_mapped: Vec<usize> = y
288            .iter()
289            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
290            .collect();
291
292        let indices: Vec<usize> = (0..n_samples).collect();
293
294        let max_features_n = resolve_max_features(self.max_features, n_features);
295
296        let mut rng = if let Some(seed) = self.random_state {
297            StdRng::seed_from_u64(seed)
298        } else {
299            StdRng::from_os_rng()
300        };
301
302        let data = ClassificationData {
303            x,
304            y: &y_mapped,
305            n_classes,
306            feature_indices: None,
307            criterion: self.criterion,
308        };
309        let params = TreeParams {
310            max_depth: self.max_depth,
311            min_samples_split: self.min_samples_split,
312            min_samples_leaf: self.min_samples_leaf,
313        };
314
315        let mut nodes: Vec<Node<F>> = Vec::new();
316        build_extra_classification_tree(
317            &data,
318            &indices,
319            &mut nodes,
320            0,
321            &params,
322            n_features,
323            max_features_n,
324            &mut rng,
325        );
326
327        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);
328
329        Ok(FittedExtraTreeClassifier {
330            nodes,
331            classes,
332            n_features,
333            feature_importances,
334        })
335    }
336}
337
338impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreeClassifier<F> {
339    type Output = Array1<usize>;
340    type Error = FerroError;
341
342    /// Predict class labels for the given feature matrix.
343    ///
344    /// # Errors
345    ///
346    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
347    /// not match the fitted model.
348    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
349        if x.ncols() != self.n_features {
350            return Err(FerroError::ShapeMismatch {
351                expected: vec![self.n_features],
352                actual: vec![x.ncols()],
353                context: "number of features must match fitted model".into(),
354            });
355        }
356        let n_samples = x.nrows();
357        let mut predictions = Array1::zeros(n_samples);
358        for i in 0..n_samples {
359            let row = x.row(i);
360            let leaf = traverse(&self.nodes, &row);
361            if let Node::Leaf { value, .. } = self.nodes[leaf] {
362                predictions[i] = float_to_usize(value);
363            }
364        }
365        Ok(predictions)
366    }
367}
368
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreeClassifier<F> {
    /// Per-feature importance scores (normalised to sum to 1).
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
374
impl<F: Float + Send + Sync + 'static> HasClasses for FittedExtraTreeClassifier<F> {
    /// Sorted unique class labels observed during training.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes observed during training.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
384
// Pipeline integration.
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for ExtraTreeClassifier<F>
{
    /// Fit within a pipeline, converting float targets to integer class labels.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        // NOTE(review): targets that fail `to_usize` (negative, NaN, too
        // large) silently become class 0 — confirm this is the intended
        // pipeline convention rather than an error condition.
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedExtraTreeClassifierPipelineAdapter(fitted)))
    }
}
399
/// Pipeline adapter for `FittedExtraTreeClassifier<F>`.
///
/// Newtype wrapper so the classifier (which predicts `usize` labels) can
/// satisfy the float-output `FittedPipelineEstimator` interface.
struct FittedExtraTreeClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedExtraTreeClassifier<F>,
);
404
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedExtraTreeClassifierPipelineAdapter<F>
{
    /// Predict class labels and convert them to floats for pipeline use.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        // Labels that cannot be represented in F become NaN.
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
    }
}
413
414// ---------------------------------------------------------------------------
415// ExtraTreeRegressor
416// ---------------------------------------------------------------------------
417
/// Extremely randomized tree regressor.
///
/// Like a [`DecisionTreeRegressor`](crate::DecisionTreeRegressor), but split
/// thresholds are chosen randomly rather than via exhaustive search. For each
/// candidate feature, a random threshold is drawn uniformly between the
/// feature's minimum and maximum values in the current node.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreeRegressor<F> {
    /// Maximum depth of the tree. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Strategy for the number of features considered at each split.
    pub max_features: MaxFeatures,
    /// Splitting criterion.
    pub criterion: RegressionCriterion,
    /// Random seed for reproducibility. `None` means non-deterministic.
    pub random_state: Option<u64>,
    // Ties the unconstrained `F` parameter to the struct without storing data.
    _marker: std::marker::PhantomData<F>,
}
444
445impl<F: Float> ExtraTreeRegressor<F> {
446    /// Create a new `ExtraTreeRegressor` with default settings.
447    ///
448    /// Defaults: `max_depth = None`, `min_samples_split = 2`,
449    /// `min_samples_leaf = 1`, `max_features = All`,
450    /// `criterion = MSE`, `random_state = None`.
451    #[must_use]
452    pub fn new() -> Self {
453        Self {
454            max_depth: None,
455            min_samples_split: 2,
456            min_samples_leaf: 1,
457            max_features: MaxFeatures::All,
458            criterion: RegressionCriterion::Mse,
459            random_state: None,
460            _marker: std::marker::PhantomData,
461        }
462    }
463
464    /// Set the maximum tree depth.
465    #[must_use]
466    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
467        self.max_depth = max_depth;
468        self
469    }
470
471    /// Set the minimum number of samples required to split a node.
472    #[must_use]
473    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
474        self.min_samples_split = min_samples_split;
475        self
476    }
477
478    /// Set the minimum number of samples required in a leaf node.
479    #[must_use]
480    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
481        self.min_samples_leaf = min_samples_leaf;
482        self
483    }
484
485    /// Set the maximum features strategy.
486    #[must_use]
487    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
488        self.max_features = max_features;
489        self
490    }
491
492    /// Set the splitting criterion.
493    #[must_use]
494    pub fn with_criterion(mut self, criterion: RegressionCriterion) -> Self {
495        self.criterion = criterion;
496        self
497    }
498
499    /// Set the random seed for reproducibility.
500    #[must_use]
501    pub fn with_random_state(mut self, seed: u64) -> Self {
502        self.random_state = Some(seed);
503        self
504    }
505}
506
impl<F: Float> Default for ExtraTreeRegressor<F> {
    /// Equivalent to [`ExtraTreeRegressor::new`].
    fn default() -> Self {
        Self::new()
    }
}
512
513// ---------------------------------------------------------------------------
514// FittedExtraTreeRegressor
515// ---------------------------------------------------------------------------
516
/// A fitted extremely randomized tree regressor.
///
/// Stores the learned tree as a flat `Vec<Node<F>>` for cache-friendly traversal.
#[derive(Debug, Clone)]
pub struct FittedExtraTreeRegressor<F> {
    /// Flat node storage; index 0 is the root.
    nodes: Vec<Node<F>>,
    /// Number of features the model was trained on.
    n_features: usize,
    /// Per-feature importance scores (normalised to sum to 1).
    feature_importances: Array1<F>,
}
529
530impl<F: Float + Send + Sync + 'static> FittedExtraTreeRegressor<F> {
531    /// Returns a reference to the flat node storage of the tree.
532    #[must_use]
533    pub fn nodes(&self) -> &[Node<F>] {
534        &self.nodes
535    }
536
537    /// Returns the number of features the model was trained on.
538    #[must_use]
539    pub fn n_features(&self) -> usize {
540        self.n_features
541    }
542}
543
544impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for ExtraTreeRegressor<F> {
545    type Fitted = FittedExtraTreeRegressor<F>;
546    type Error = FerroError;
547
548    /// Fit the extra-tree regressor on the training data.
549    ///
550    /// # Errors
551    ///
552    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
553    /// numbers of samples.
554    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
555    /// Returns [`FerroError::InvalidParameter`] if hyperparameters are invalid.
556    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedExtraTreeRegressor<F>, FerroError> {
557        let (n_samples, n_features) = x.dim();
558
559        if n_samples != y.len() {
560            return Err(FerroError::ShapeMismatch {
561                expected: vec![n_samples],
562                actual: vec![y.len()],
563                context: "y length must match number of samples in X".into(),
564            });
565        }
566        if n_samples == 0 {
567            return Err(FerroError::InsufficientSamples {
568                required: 1,
569                actual: 0,
570                context: "ExtraTreeRegressor requires at least one sample".into(),
571            });
572        }
573        if self.min_samples_split < 2 {
574            return Err(FerroError::InvalidParameter {
575                name: "min_samples_split".into(),
576                reason: "must be at least 2".into(),
577            });
578        }
579        if self.min_samples_leaf < 1 {
580            return Err(FerroError::InvalidParameter {
581                name: "min_samples_leaf".into(),
582                reason: "must be at least 1".into(),
583            });
584        }
585
586        let indices: Vec<usize> = (0..n_samples).collect();
587        let max_features_n = resolve_max_features(self.max_features, n_features);
588
589        let mut rng = if let Some(seed) = self.random_state {
590            StdRng::seed_from_u64(seed)
591        } else {
592            StdRng::from_os_rng()
593        };
594
595        let data = RegressionData {
596            x,
597            y,
598            feature_indices: None,
599        };
600        let params = TreeParams {
601            max_depth: self.max_depth,
602            min_samples_split: self.min_samples_split,
603            min_samples_leaf: self.min_samples_leaf,
604        };
605
606        let mut nodes: Vec<Node<F>> = Vec::new();
607        build_extra_regression_tree(
608            &data,
609            &indices,
610            &mut nodes,
611            0,
612            &params,
613            n_features,
614            max_features_n,
615            &mut rng,
616        );
617
618        let feature_importances = compute_feature_importances(&nodes, n_features, n_samples);
619
620        Ok(FittedExtraTreeRegressor {
621            nodes,
622            n_features,
623            feature_importances,
624        })
625    }
626}
627
628impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreeRegressor<F> {
629    type Output = Array1<F>;
630    type Error = FerroError;
631
632    /// Predict target values for the given feature matrix.
633    ///
634    /// # Errors
635    ///
636    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
637    /// not match the fitted model.
638    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
639        if x.ncols() != self.n_features {
640            return Err(FerroError::ShapeMismatch {
641                expected: vec![self.n_features],
642                actual: vec![x.ncols()],
643                context: "number of features must match fitted model".into(),
644            });
645        }
646        let n_samples = x.nrows();
647        let mut predictions = Array1::zeros(n_samples);
648        for i in 0..n_samples {
649            let row = x.row(i);
650            let leaf = traverse(&self.nodes, &row);
651            if let Node::Leaf { value, .. } = self.nodes[leaf] {
652                predictions[i] = value;
653            }
654        }
655        Ok(predictions)
656    }
657}
658
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreeRegressor<F> {
    /// Per-feature importance scores (normalised to sum to 1).
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
664
// Pipeline integration.
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for ExtraTreeRegressor<F> {
    /// Fit within a pipeline; regression targets pass through unchanged.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}
676
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedExtraTreeRegressor<F> {
    /// Delegates directly to [`Predict::predict`]; outputs are already floats.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
682
683// ---------------------------------------------------------------------------
684// Internal: helpers
685// ---------------------------------------------------------------------------
686
687/// Resolve the `MaxFeatures` strategy to a concrete number.
688fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
689    let result = match strategy {
690        MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
691        MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
692        MaxFeatures::All => n_features,
693        MaxFeatures::Fixed(n) => n.min(n_features),
694        MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
695    };
696    result.max(1).min(n_features)
697}
698
699/// Convert a `Float` value to `usize` (for class labels stored as floats).
700fn float_to_usize<F: Float>(v: F) -> usize {
701    v.to_f64().map(|f| f.round() as usize).unwrap_or(0)
702}
703
/// Generate a uniform random float in `[min_val, max_val]`.
fn random_threshold<F: Float>(rng: &mut StdRng, min_val: F, max_val: F) -> F {
    use rand::RngCore;
    // Draw u in [0, 1] — inclusive at both ends, since `next_u64()` can equal
    // `u64::MAX` — and rescale linearly onto [min_val, max_val].
    let u = (rng.next_u64() as f64) / (u64::MAX as f64);
    let range = max_val - min_val;
    min_val + F::from(u).unwrap() * range
}
712
713/// Compute the Gini impurity for a set of class counts.
714fn gini_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
715    if total == 0 {
716        return F::zero();
717    }
718    let total_f = F::from(total).unwrap();
719    let mut impurity = F::one();
720    for &count in class_counts {
721        let p = F::from(count).unwrap() / total_f;
722        impurity = impurity - p * p;
723    }
724    impurity
725}
726
727/// Compute the Shannon entropy for a set of class counts.
728fn entropy_impurity<F: Float>(class_counts: &[usize], total: usize) -> F {
729    if total == 0 {
730        return F::zero();
731    }
732    let total_f = F::from(total).unwrap();
733    let mut ent = F::zero();
734    for &count in class_counts {
735        if count > 0 {
736            let p = F::from(count).unwrap() / total_f;
737            ent = ent - p * p.ln();
738        }
739    }
740    ent
741}
742
/// Compute impurity for a given classification criterion.
///
/// Dispatches to [`gini_impurity`] or [`entropy_impurity`].
fn compute_impurity<F: Float>(
    class_counts: &[usize],
    total: usize,
    criterion: ClassificationCriterion,
) -> F {
    match criterion {
        ClassificationCriterion::Gini => gini_impurity(class_counts, total),
        ClassificationCriterion::Entropy => entropy_impurity(class_counts, total),
    }
}
754
755/// Create a classification leaf node and return its index.
756fn make_classification_leaf<F: Float>(
757    nodes: &mut Vec<Node<F>>,
758    class_counts: &[usize],
759    n_classes: usize,
760    n_samples: usize,
761) -> usize {
762    let majority_class = class_counts
763        .iter()
764        .enumerate()
765        .max_by_key(|&(_, &count)| count)
766        .map(|(i, _)| i)
767        .unwrap_or(0);
768
769    let total_f = if n_samples > 0 {
770        F::from(n_samples).unwrap()
771    } else {
772        F::one()
773    };
774    let distribution: Vec<F> = (0..n_classes)
775        .map(|c| F::from(class_counts[c]).unwrap() / total_f)
776        .collect();
777
778    let idx = nodes.len();
779    nodes.push(Node::Leaf {
780        value: F::from(majority_class).unwrap(),
781        class_distribution: Some(distribution),
782        n_samples,
783    });
784    idx
785}
786
787/// Compute the mean of target values for the given indices.
788fn mean_value<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
789    if indices.is_empty() {
790        return F::zero();
791    }
792    let sum: F = indices.iter().map(|&i| y[i]).fold(F::zero(), |a, b| a + b);
793    sum / F::from(indices.len()).unwrap()
794}
795
796// ---------------------------------------------------------------------------
797// Extra-tree classification building
798// ---------------------------------------------------------------------------
799
/// Build an extra-tree classification tree recursively with random thresholds.
///
/// At each node, a random subset of features is considered, and for each feature
/// a random threshold is drawn uniformly between the feature's min and max in the
/// current node. Nodes are appended to `nodes` in depth-first order and the index
/// of the subtree root is returned.
#[allow(clippy::too_many_arguments)]
fn build_extra_classification_tree<F: Float>(
    data: &ClassificationData<'_, F>,
    indices: &[usize],
    nodes: &mut Vec<Node<F>>,
    depth: usize,
    params: &TreeParams,
    n_features: usize,
    max_features_n: usize,
    rng: &mut StdRng,
) -> usize {
    let n = indices.len();

    // Class histogram over the samples in this node.
    let mut class_counts = vec![0usize; data.n_classes];
    for &i in indices {
        class_counts[data.y[i]] += 1;
    }

    // Stop when too few samples, at max depth, or when the node is pure
    // (at most one class with a non-zero count).
    let should_stop = n < params.min_samples_split
        || params.max_depth.is_some_and(|d| depth >= d)
        || class_counts.iter().filter(|&&c| c > 0).count() <= 1;

    if should_stop {
        return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
    }

    let best = find_random_classification_split(
        data,
        indices,
        params.min_samples_leaf,
        n_features,
        max_features_n,
        rng,
    );

    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
        // Samples with feature value <= threshold go left, the rest right.
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
            .iter()
            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);

        // Ensure both children have at least min_samples_leaf.
        if left_indices.len() < params.min_samples_leaf
            || right_indices.len() < params.min_samples_leaf
        {
            return make_classification_leaf(nodes, &class_counts, data.n_classes, n);
        }

        // Reserve this node's slot *before* recursing so the flat-storage
        // index is fixed; the placeholder leaf is overwritten below once the
        // children's indices are known.
        let node_idx = nodes.len();
        nodes.push(Node::Leaf {
            value: F::zero(),
            class_distribution: None,
            n_samples: 0,
        }); // placeholder

        let left_idx = build_extra_classification_tree(
            data,
            &left_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );
        let right_idx = build_extra_classification_tree(
            data,
            &right_indices,
            nodes,
            depth + 1,
            params,
            n_features,
            max_features_n,
            rng,
        );

        // Patch the reserved slot with the real split node.
        nodes[node_idx] = Node::Split {
            feature: best_feature,
            threshold: best_threshold,
            left: left_idx,
            right: right_idx,
            impurity_decrease: best_impurity_decrease,
            n_samples: n,
        };

        node_idx
    } else {
        // No valid split found (e.g. all candidate features constant).
        make_classification_leaf(nodes, &class_counts, data.n_classes, n)
    }
}
894
895/// Find the best random split for a classification node.
896///
897/// For each candidate feature (from a random subset), pick a random threshold
898/// between min and max of that feature in the current node. Return the split
899/// with the largest impurity decrease, or `None` if no valid split exists.
900#[allow(clippy::too_many_arguments)]
901fn find_random_classification_split<F: Float>(
902    data: &ClassificationData<'_, F>,
903    indices: &[usize],
904    min_samples_leaf: usize,
905    n_features: usize,
906    max_features_n: usize,
907    rng: &mut StdRng,
908) -> Option<(usize, F, F)> {
909    let n = indices.len();
910    let n_f = F::from(n).unwrap();
911
912    let mut parent_counts = vec![0usize; data.n_classes];
913    for &i in indices {
914        parent_counts[data.y[i]] += 1;
915    }
916    let parent_impurity = compute_impurity::<F>(&parent_counts, n, data.criterion);
917
918    let mut best_score = F::neg_infinity();
919    let mut best_feature = 0;
920    let mut best_threshold = F::zero();
921
922    // Select random feature subset.
923    let feature_subset: Vec<usize> = if let Some(feat_indices) = data.feature_indices {
924        // If a feature subset is provided externally, sample from it.
925        let k = max_features_n.min(feat_indices.len());
926        rand_sample_indices(rng, feat_indices.len(), k)
927            .into_vec()
928            .into_iter()
929            .map(|i| feat_indices[i])
930            .collect()
931    } else {
932        let k = max_features_n.min(n_features);
933        rand_sample_indices(rng, n_features, k).into_vec()
934    };
935
936    for feat in feature_subset {
937        // Find min and max of this feature in the current node.
938        let mut feat_min = F::infinity();
939        let mut feat_max = F::neg_infinity();
940        for &i in indices {
941            let val = data.x[[i, feat]];
942            if val < feat_min {
943                feat_min = val;
944            }
945            if val > feat_max {
946                feat_max = val;
947            }
948        }
949
950        // Skip constant features.
951        if feat_min >= feat_max {
952            continue;
953        }
954
955        // Draw a random threshold uniformly in (min, max).
956        let threshold = random_threshold(rng, feat_min, feat_max);
957
958        // Count left and right.
959        let mut left_counts = vec![0usize; data.n_classes];
960        let mut right_counts = vec![0usize; data.n_classes];
961        let mut left_n = 0usize;
962
963        for &i in indices {
964            let cls = data.y[i];
965            if data.x[[i, feat]] <= threshold {
966                left_counts[cls] += 1;
967                left_n += 1;
968            } else {
969                right_counts[cls] += 1;
970            }
971        }
972
973        let right_n = n - left_n;
974        if left_n < min_samples_leaf || right_n < min_samples_leaf {
975            continue;
976        }
977
978        let left_impurity = compute_impurity::<F>(&left_counts, left_n, data.criterion);
979        let right_impurity = compute_impurity::<F>(&right_counts, right_n, data.criterion);
980        let left_weight = F::from(left_n).unwrap() / n_f;
981        let right_weight = F::from(right_n).unwrap() / n_f;
982        let weighted_child_impurity = left_weight * left_impurity + right_weight * right_impurity;
983        let impurity_decrease = parent_impurity - weighted_child_impurity;
984
985        if impurity_decrease > best_score {
986            best_score = impurity_decrease;
987            best_feature = feat;
988            best_threshold = threshold;
989        }
990    }
991
992    if best_score > F::zero() {
993        Some((best_feature, best_threshold, best_score * n_f))
994    } else {
995        None
996    }
997}
998
999// ---------------------------------------------------------------------------
1000// Extra-tree regression building
1001// ---------------------------------------------------------------------------
1002
1003/// Build an extra-tree regression tree recursively with random thresholds.
1004#[allow(clippy::too_many_arguments)]
1005fn build_extra_regression_tree<F: Float>(
1006    data: &RegressionData<'_, F>,
1007    indices: &[usize],
1008    nodes: &mut Vec<Node<F>>,
1009    depth: usize,
1010    params: &TreeParams,
1011    n_features: usize,
1012    max_features_n: usize,
1013    rng: &mut StdRng,
1014) -> usize {
1015    let n = indices.len();
1016    let mean = mean_value(data.y, indices);
1017
1018    let should_stop = n < params.min_samples_split || params.max_depth.is_some_and(|d| depth >= d);
1019
1020    if should_stop {
1021        let idx = nodes.len();
1022        nodes.push(Node::Leaf {
1023            value: mean,
1024            class_distribution: None,
1025            n_samples: n,
1026        });
1027        return idx;
1028    }
1029
1030    // Check if variance is essentially zero.
1031    let parent_sum_sq: F = indices
1032        .iter()
1033        .map(|&i| {
1034            let diff = data.y[i] - mean;
1035            diff * diff
1036        })
1037        .fold(F::zero(), |a, b| a + b);
1038    let parent_mse = parent_sum_sq / F::from(n).unwrap();
1039
1040    if parent_mse <= F::epsilon() {
1041        let idx = nodes.len();
1042        nodes.push(Node::Leaf {
1043            value: mean,
1044            class_distribution: None,
1045            n_samples: n,
1046        });
1047        return idx;
1048    }
1049
1050    let best = find_random_regression_split(
1051        data,
1052        indices,
1053        params.min_samples_leaf,
1054        n_features,
1055        max_features_n,
1056        rng,
1057    );
1058
1059    if let Some((best_feature, best_threshold, best_impurity_decrease)) = best {
1060        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
1061            .iter()
1062            .partition(|&&i| data.x[[i, best_feature]] <= best_threshold);
1063
1064        // Ensure both children have at least min_samples_leaf.
1065        if left_indices.len() < params.min_samples_leaf
1066            || right_indices.len() < params.min_samples_leaf
1067        {
1068            let idx = nodes.len();
1069            nodes.push(Node::Leaf {
1070                value: mean,
1071                class_distribution: None,
1072                n_samples: n,
1073            });
1074            return idx;
1075        }
1076
1077        let node_idx = nodes.len();
1078        nodes.push(Node::Leaf {
1079            value: F::zero(),
1080            class_distribution: None,
1081            n_samples: 0,
1082        }); // placeholder
1083
1084        let left_idx = build_extra_regression_tree(
1085            data,
1086            &left_indices,
1087            nodes,
1088            depth + 1,
1089            params,
1090            n_features,
1091            max_features_n,
1092            rng,
1093        );
1094        let right_idx = build_extra_regression_tree(
1095            data,
1096            &right_indices,
1097            nodes,
1098            depth + 1,
1099            params,
1100            n_features,
1101            max_features_n,
1102            rng,
1103        );
1104
1105        nodes[node_idx] = Node::Split {
1106            feature: best_feature,
1107            threshold: best_threshold,
1108            left: left_idx,
1109            right: right_idx,
1110            impurity_decrease: best_impurity_decrease,
1111            n_samples: n,
1112        };
1113
1114        node_idx
1115    } else {
1116        let idx = nodes.len();
1117        nodes.push(Node::Leaf {
1118            value: mean,
1119            class_distribution: None,
1120            n_samples: n,
1121        });
1122        idx
1123    }
1124}
1125
1126/// Find the best random split for a regression node.
1127///
1128/// For each candidate feature (from a random subset), pick a random threshold
1129/// between min and max of that feature in the current node. Return the split
1130/// with the largest MSE decrease, or `None` if no valid split exists.
1131#[allow(clippy::too_many_arguments)]
1132fn find_random_regression_split<F: Float>(
1133    data: &RegressionData<'_, F>,
1134    indices: &[usize],
1135    min_samples_leaf: usize,
1136    n_features: usize,
1137    max_features_n: usize,
1138    rng: &mut StdRng,
1139) -> Option<(usize, F, F)> {
1140    let n = indices.len();
1141    let n_f = F::from(n).unwrap();
1142
1143    let parent_sum: F = indices
1144        .iter()
1145        .map(|&i| data.y[i])
1146        .fold(F::zero(), |a, b| a + b);
1147    let parent_sum_sq: F = indices
1148        .iter()
1149        .map(|&i| data.y[i] * data.y[i])
1150        .fold(F::zero(), |a, b| a + b);
1151    let parent_mse = parent_sum_sq / n_f - (parent_sum / n_f) * (parent_sum / n_f);
1152
1153    let mut best_score = F::neg_infinity();
1154    let mut best_feature = 0;
1155    let mut best_threshold = F::zero();
1156
1157    // Select random feature subset.
1158    let feature_subset: Vec<usize> = if let Some(feat_indices) = data.feature_indices {
1159        let k = max_features_n.min(feat_indices.len());
1160        rand_sample_indices(rng, feat_indices.len(), k)
1161            .into_vec()
1162            .into_iter()
1163            .map(|i| feat_indices[i])
1164            .collect()
1165    } else {
1166        let k = max_features_n.min(n_features);
1167        rand_sample_indices(rng, n_features, k).into_vec()
1168    };
1169
1170    for feat in feature_subset {
1171        // Find min and max of this feature in the current node.
1172        let mut feat_min = F::infinity();
1173        let mut feat_max = F::neg_infinity();
1174        for &i in indices {
1175            let val = data.x[[i, feat]];
1176            if val < feat_min {
1177                feat_min = val;
1178            }
1179            if val > feat_max {
1180                feat_max = val;
1181            }
1182        }
1183
1184        // Skip constant features.
1185        if feat_min >= feat_max {
1186            continue;
1187        }
1188
1189        // Draw a random threshold uniformly in (min, max).
1190        let threshold = random_threshold(rng, feat_min, feat_max);
1191
1192        // Compute left/right statistics.
1193        let mut left_sum = F::zero();
1194        let mut left_sum_sq = F::zero();
1195        let mut left_n: usize = 0;
1196
1197        for &i in indices {
1198            if data.x[[i, feat]] <= threshold {
1199                let val = data.y[i];
1200                left_sum = left_sum + val;
1201                left_sum_sq = left_sum_sq + val * val;
1202                left_n += 1;
1203            }
1204        }
1205
1206        let right_n = n - left_n;
1207        if left_n < min_samples_leaf || right_n < min_samples_leaf {
1208            continue;
1209        }
1210
1211        let left_n_f = F::from(left_n).unwrap();
1212        let right_n_f = F::from(right_n).unwrap();
1213
1214        let left_mean = left_sum / left_n_f;
1215        let left_mse = left_sum_sq / left_n_f - left_mean * left_mean;
1216
1217        let right_sum = parent_sum - left_sum;
1218        let right_sum_sq = parent_sum_sq - left_sum_sq;
1219        let right_mean = right_sum / right_n_f;
1220        let right_mse = right_sum_sq / right_n_f - right_mean * right_mean;
1221
1222        let weighted_child_mse = (left_n_f * left_mse + right_n_f * right_mse) / n_f;
1223        let mse_decrease = parent_mse - weighted_child_mse;
1224
1225        if mse_decrease > best_score {
1226            best_score = mse_decrease;
1227            best_feature = feat;
1228            best_threshold = threshold;
1229        }
1230    }
1231
1232    if best_score > F::zero() {
1233        Some((best_feature, best_threshold, best_score * n_f))
1234    } else {
1235        None
1236    }
1237}
1238
1239// ---------------------------------------------------------------------------
1240// Crate-internal functions for ensemble usage
1241// ---------------------------------------------------------------------------
1242
1243/// Build a classification extra-tree with a subset of features for ensemble use.
1244///
1245/// Used internally by `ExtraTreesClassifier` to build individual trees.
1246#[allow(clippy::too_many_arguments)]
1247pub(crate) fn build_extra_classification_tree_for_ensemble<F: Float>(
1248    x: &Array2<F>,
1249    y: &[usize],
1250    n_classes: usize,
1251    indices: &[usize],
1252    feature_indices: Option<&[usize]>,
1253    params: &TreeParams,
1254    criterion: ClassificationCriterion,
1255    n_features: usize,
1256    max_features_n: usize,
1257    rng: &mut StdRng,
1258) -> Vec<Node<F>> {
1259    let data = ClassificationData {
1260        x,
1261        y,
1262        n_classes,
1263        feature_indices,
1264        criterion,
1265    };
1266    let mut nodes = Vec::new();
1267    build_extra_classification_tree(
1268        &data,
1269        indices,
1270        &mut nodes,
1271        0,
1272        params,
1273        n_features,
1274        max_features_n,
1275        rng,
1276    );
1277    nodes
1278}
1279
1280/// Build a regression extra-tree with a subset of features for ensemble use.
1281///
1282/// Used internally by `ExtraTreesRegressor` to build individual trees.
1283#[allow(clippy::too_many_arguments)]
1284pub(crate) fn build_extra_regression_tree_for_ensemble<F: Float>(
1285    x: &Array2<F>,
1286    y: &Array1<F>,
1287    indices: &[usize],
1288    feature_indices: Option<&[usize]>,
1289    params: &TreeParams,
1290    n_features: usize,
1291    max_features_n: usize,
1292    rng: &mut StdRng,
1293) -> Vec<Node<F>> {
1294    let data = RegressionData {
1295        x,
1296        y,
1297        feature_indices,
1298    };
1299    let mut nodes = Vec::new();
1300    build_extra_regression_tree(
1301        &data,
1302        indices,
1303        &mut nodes,
1304        0,
1305        params,
1306        n_features,
1307        max_features_n,
1308        rng,
1309    );
1310    nodes
1311}
1312
1313// ---------------------------------------------------------------------------
1314// Tests
1315// ---------------------------------------------------------------------------
1316
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // -- Classifier tests --

    #[test]
    fn test_extra_classifier_simple_binary() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        // ExtraTrees should separate linearly separable data.
        for i in 0..3 {
            assert_eq!(preds[i], 0, "sample {i} should be class 0");
        }
        for i in 3..6 {
            assert_eq!(preds[i], 1, "sample {i} should be class 1");
        }
    }

    #[test]
    fn test_extra_classifier_single_class() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds, array![0, 0, 0]);
    }

    #[test]
    fn test_extra_classifier_max_depth_1() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_depth(Some(1))
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // With depth 1 the tree has exactly one split node and two leaves,
        // and prediction yields one label per sample. (The random threshold
        // need not land exactly between the classes, so we do not assert
        // perfect separation here.)
        assert_eq!(preds.len(), 8);
        assert_eq!(fitted.nodes().len(), 3);
    }

    #[test]
    fn test_extra_classifier_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        assert_eq!(proba.dim(), (6, 2));
        // Each row sums to 1.
        for i in 0..6 {
            let row_sum = proba.row(i).sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_extra_classifier_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0, 1.0, 6.0, 1.0, 7.0, 1.0, 8.0, 1.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        // The sum of importances should be 1 (normalised).
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        // Feature 0 should have higher importance (feature 1 is constant).
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_extra_classifier_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0]; // wrong length

        let model = ExtraTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = ExtraTreeClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_classifier_invalid_min_samples_split() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 1];

        let model = ExtraTreeClassifier::<f64>::new().with_min_samples_split(1);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_classifier_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 2, 2, 2]; // non-contiguous classes

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 2]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_extra_classifier_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreeClassifier::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    #[test]
    fn test_extra_classifier_f32() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0f32, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreeClassifier::<f32>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_extra_classifier_deterministic() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model1 = ExtraTreeClassifier::<f64>::new().with_random_state(123);
        let model2 = ExtraTreeClassifier::<f64>::new().with_random_state(123);

        let fitted1 = model1.fit(&x, &y).unwrap();
        let fitted2 = model2.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        assert_eq!(preds1, preds2);
    }

    // -- Regressor tests --

    #[test]
    fn test_extra_regressor_simple() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // A deep extra-tree should roughly memorize the training data.
        assert_eq!(preds.len(), 6);
        for i in 0..6 {
            assert_relative_eq!(preds[i], y[i], epsilon = 1.0);
        }
    }

    #[test]
    fn test_extra_regressor_constant_target() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![5.0, 5.0, 5.0, 5.0];

        let model = ExtraTreeRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for &p in preds.iter() {
            assert_relative_eq!(p, 5.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_extra_regressor_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        // Feature 0 drives the target; feature 1 is constant.
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_extra_regressor_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0]; // wrong length

        let model = ExtraTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = ExtraTreeRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_extra_regressor_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreeRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    #[test]
    fn test_extra_regressor_max_depth() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_depth(Some(1))
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        // With depth 1, the tree should have exactly 3 nodes: one split + two leaves.
        assert_eq!(fitted.nodes().len(), 3);
    }

    #[test]
    fn test_extra_regressor_deterministic() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model1 = ExtraTreeRegressor::<f64>::new().with_random_state(99);
        let model2 = ExtraTreeRegressor::<f64>::new().with_random_state(99);

        let fitted1 = model1.fit(&x, &y).unwrap();
        let fitted2 = model2.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        for i in 0..6 {
            assert_relative_eq!(preds1[i], preds2[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_extra_regressor_f32() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0f32, 2.0, 3.0, 4.0];

        let model = ExtraTreeRegressor::<f32>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // -- Builder tests --

    #[test]
    fn test_classifier_builder_methods() {
        let model = ExtraTreeClassifier::<f64>::new()
            .with_max_depth(Some(5))
            .with_min_samples_split(10)
            .with_min_samples_leaf(3)
            .with_max_features(MaxFeatures::Log2)
            .with_criterion(ClassificationCriterion::Entropy)
            .with_random_state(42);

        assert_eq!(model.max_depth, Some(5));
        assert_eq!(model.min_samples_split, 10);
        assert_eq!(model.min_samples_leaf, 3);
        assert_eq!(model.max_features, MaxFeatures::Log2);
        assert_eq!(model.criterion, ClassificationCriterion::Entropy);
        assert_eq!(model.random_state, Some(42));
    }

    #[test]
    fn test_regressor_builder_methods() {
        let model = ExtraTreeRegressor::<f64>::new()
            .with_max_depth(Some(10))
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_max_features(MaxFeatures::Fixed(3))
            .with_criterion(RegressionCriterion::Mse)
            .with_random_state(99);

        assert_eq!(model.max_depth, Some(10));
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.min_samples_leaf, 2);
        assert_eq!(model.max_features, MaxFeatures::Fixed(3));
        assert_eq!(model.criterion, RegressionCriterion::Mse);
        assert_eq!(model.random_state, Some(99));
    }

    #[test]
    fn test_classifier_default() {
        let model = ExtraTreeClassifier::<f64>::default();
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::Sqrt);
        assert_eq!(model.criterion, ClassificationCriterion::Gini);
        assert_eq!(model.random_state, None);
    }

    #[test]
    fn test_regressor_default() {
        let model = ExtraTreeRegressor::<f64>::default();
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::All);
        assert_eq!(model.criterion, RegressionCriterion::Mse);
        assert_eq!(model.random_state, None);
    }
}