Skip to main content

ferrolearn_tree/
extra_trees_ensemble.rs

1//! Extremely randomized trees ensemble classifiers and regressors.
2//!
3//! This module provides [`ExtraTreesClassifier`] and [`ExtraTreesRegressor`],
4//! which build ensembles of extremely randomized trees. Unlike
5//! [`RandomForestClassifier`](crate::RandomForestClassifier), ExtraTrees
6//! ensembles do **not** bootstrap by default: all trees see all samples, and
7//! randomness comes solely from the random split thresholds and random feature
8//! subsets at each node.
9//!
10//! # Examples
11//!
12//! ```
13//! use ferrolearn_tree::ExtraTreesClassifier;
14//! use ferrolearn_core::{Fit, Predict};
15//! use ndarray::{array, Array1, Array2};
16//!
17//! let x = Array2::from_shape_vec((8, 2), vec![
18//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,  4.0, 4.0,
19//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,  8.0, 9.0,
20//! ]).unwrap();
21//! let y = array![0, 0, 0, 0, 1, 1, 1, 1];
22//!
23//! let model = ExtraTreesClassifier::<f64>::new()
24//!     .with_n_estimators(10)
25//!     .with_random_state(42);
26//! let fitted = model.fit(&x, &y).unwrap();
27//! let preds = fitted.predict(&x).unwrap();
28//! ```
29
30use ferrolearn_core::error::FerroError;
31use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
32use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
33use ferrolearn_core::traits::{Fit, Predict};
34use ndarray::{Array1, Array2};
35use num_traits::{Float, FromPrimitive, ToPrimitive};
36use rand::SeedableRng;
37use rand::rngs::StdRng;
38use rayon::prelude::*;
39use serde::{Deserialize, Serialize};
40
41use crate::decision_tree::{
42    ClassificationCriterion, Node, TreeParams, compute_feature_importances, traverse,
43};
44use crate::extra_tree::{
45    build_extra_classification_tree_for_ensemble, build_extra_regression_tree_for_ensemble,
46};
47use crate::random_forest::MaxFeatures;
48
49/// Resolve the `MaxFeatures` strategy to a concrete number.
50fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
51    let result = match strategy {
52        MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
53        MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
54        MaxFeatures::All => n_features,
55        MaxFeatures::Fixed(n) => n.min(n_features),
56        MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
57    };
58    result.max(1).min(n_features)
59}
60
61/// Internal tree parameter struct helper.
62fn make_tree_params(
63    max_depth: Option<usize>,
64    min_samples_split: usize,
65    min_samples_leaf: usize,
66) -> TreeParams {
67    TreeParams {
68        max_depth,
69        min_samples_split,
70        min_samples_leaf,
71    }
72}
73
74// ---------------------------------------------------------------------------
75// ExtraTreesClassifier
76// ---------------------------------------------------------------------------
77
78/// Extremely randomized trees classifier (ensemble).
79///
80/// Builds an ensemble of [`ExtraTreeClassifier`](crate::ExtraTreeClassifier)
81/// base estimators, each using random split thresholds and random feature
82/// subsets at every node. Final predictions are made by majority vote.
83///
84/// Unlike [`RandomForestClassifier`](crate::RandomForestClassifier), bootstrap
85/// sampling is **disabled** by default. Randomness comes from the random
86/// thresholds and random feature subsets at each split.
87///
88/// # Type Parameters
89///
90/// - `F`: The floating-point type (`f32` or `f64`).
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct ExtraTreesClassifier<F> {
93    /// Number of trees in the ensemble.
94    pub n_estimators: usize,
95    /// Maximum depth of each tree. `None` means unlimited.
96    pub max_depth: Option<usize>,
97    /// Minimum number of samples required to split an internal node.
98    pub min_samples_split: usize,
99    /// Minimum number of samples required in a leaf node.
100    pub min_samples_leaf: usize,
101    /// Strategy for the number of features considered at each split.
102    pub max_features: MaxFeatures,
103    /// Whether to use bootstrap sampling. Default is `false`.
104    pub bootstrap: bool,
105    /// Splitting criterion.
106    pub criterion: ClassificationCriterion,
107    /// Random seed for reproducibility. `None` means non-deterministic.
108    pub random_state: Option<u64>,
109    /// Number of parallel jobs. `None` means use all available cores.
110    pub n_jobs: Option<usize>,
111    _marker: std::marker::PhantomData<F>,
112}
113
114impl<F: Float> ExtraTreesClassifier<F> {
115    /// Create a new `ExtraTreesClassifier` with default settings.
116    ///
117    /// Defaults: `n_estimators = 100`, `max_depth = None`,
118    /// `max_features = Sqrt`, `min_samples_split = 2`,
119    /// `min_samples_leaf = 1`, `bootstrap = false`,
120    /// `criterion = Gini`, `random_state = None`, `n_jobs = None`.
121    #[must_use]
122    pub fn new() -> Self {
123        Self {
124            n_estimators: 100,
125            max_depth: None,
126            min_samples_split: 2,
127            min_samples_leaf: 1,
128            max_features: MaxFeatures::Sqrt,
129            bootstrap: false,
130            criterion: ClassificationCriterion::Gini,
131            random_state: None,
132            n_jobs: None,
133            _marker: std::marker::PhantomData,
134        }
135    }
136
137    /// Set the number of trees.
138    #[must_use]
139    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
140        self.n_estimators = n_estimators;
141        self
142    }
143
144    /// Set the maximum tree depth.
145    #[must_use]
146    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
147        self.max_depth = max_depth;
148        self
149    }
150
151    /// Set the minimum number of samples to split a node.
152    #[must_use]
153    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
154        self.min_samples_split = min_samples_split;
155        self
156    }
157
158    /// Set the minimum number of samples in a leaf.
159    #[must_use]
160    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
161        self.min_samples_leaf = min_samples_leaf;
162        self
163    }
164
165    /// Set the maximum features strategy.
166    #[must_use]
167    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
168        self.max_features = max_features;
169        self
170    }
171
172    /// Set whether to use bootstrap sampling.
173    #[must_use]
174    pub fn with_bootstrap(mut self, bootstrap: bool) -> Self {
175        self.bootstrap = bootstrap;
176        self
177    }
178
179    /// Set the splitting criterion.
180    #[must_use]
181    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
182        self.criterion = criterion;
183        self
184    }
185
186    /// Set the random seed for reproducibility.
187    #[must_use]
188    pub fn with_random_state(mut self, seed: u64) -> Self {
189        self.random_state = Some(seed);
190        self
191    }
192
193    /// Set the number of parallel jobs.
194    #[must_use]
195    pub fn with_n_jobs(mut self, n_jobs: usize) -> Self {
196        self.n_jobs = Some(n_jobs);
197        self
198    }
199}
200
201impl<F: Float> Default for ExtraTreesClassifier<F> {
202    fn default() -> Self {
203        Self::new()
204    }
205}
206
207// ---------------------------------------------------------------------------
208// FittedExtraTreesClassifier
209// ---------------------------------------------------------------------------
210
211/// A fitted extremely randomized trees classifier (ensemble).
212///
213/// Stores the ensemble of fitted extra-trees and aggregates their
214/// predictions by majority vote.
215#[derive(Debug, Clone)]
216pub struct FittedExtraTreesClassifier<F> {
217    /// Individual tree node vectors.
218    trees: Vec<Vec<Node<F>>>,
219    /// Sorted unique class labels.
220    classes: Vec<usize>,
221    /// Number of features.
222    n_features: usize,
223    /// Per-feature importance scores (mean decrease in impurity, normalised).
224    feature_importances: Array1<F>,
225}
226
227impl<F: Float + Send + Sync + 'static> FittedExtraTreesClassifier<F> {
228    /// Returns a reference to the individual tree node vectors.
229    #[must_use]
230    pub fn trees(&self) -> &[Vec<Node<F>>] {
231        &self.trees
232    }
233
234    /// Returns the number of features the model was trained on.
235    #[must_use]
236    pub fn n_features(&self) -> usize {
237        self.n_features
238    }
239
240    /// Returns the number of trees in the ensemble.
241    #[must_use]
242    pub fn n_estimators(&self) -> usize {
243        self.trees.len()
244    }
245
246    /// Predict class probabilities for each sample by averaging tree predictions.
247    ///
248    /// Returns a 2-D array of shape `(n_samples, n_classes)`.
249    ///
250    /// # Errors
251    ///
252    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
253    /// not match the training data.
254    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
255        if x.ncols() != self.n_features {
256            return Err(FerroError::ShapeMismatch {
257                expected: vec![self.n_features],
258                actual: vec![x.ncols()],
259                context: "number of features must match fitted model".into(),
260            });
261        }
262        let n_samples = x.nrows();
263        let n_classes = self.classes.len();
264        let n_trees_f = F::from(self.trees.len()).unwrap();
265        let mut proba = Array2::zeros((n_samples, n_classes));
266
267        for i in 0..n_samples {
268            let row = x.row(i);
269            for tree_nodes in &self.trees {
270                let leaf_idx = traverse(tree_nodes, &row);
271                if let Node::Leaf {
272                    class_distribution: Some(ref dist),
273                    ..
274                } = tree_nodes[leaf_idx]
275                {
276                    for (j, &p) in dist.iter().enumerate() {
277                        proba[[i, j]] = proba[[i, j]] + p;
278                    }
279                }
280            }
281            // Average across trees.
282            for j in 0..n_classes {
283                proba[[i, j]] = proba[[i, j]] / n_trees_f;
284            }
285        }
286        Ok(proba)
287    }
288
289    /// Mean accuracy on the given test data and labels.
290    /// Equivalent to sklearn's `ClassifierMixin.score`.
291    ///
292    /// # Errors
293    ///
294    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
295    /// the feature count does not match the training data.
296    pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
297        if x.nrows() != y.len() {
298            return Err(FerroError::ShapeMismatch {
299                expected: vec![x.nrows()],
300                actual: vec![y.len()],
301                context: "y length must match number of samples in X".into(),
302            });
303        }
304        let preds = self.predict(x)?;
305        Ok(crate::mean_accuracy(&preds, y))
306    }
307
308    /// Element-wise log of [`predict_proba`](Self::predict_proba). Mirrors
309    /// sklearn's `ClassifierMixin.predict_log_proba`.
310    ///
311    /// # Errors
312    ///
313    /// Forwards any error from [`predict_proba`](Self::predict_proba).
314    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
315        let proba = self.predict_proba(x)?;
316        Ok(crate::log_proba(&proba))
317    }
318}
319
320impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ExtraTreesClassifier<F> {
321    type Fitted = FittedExtraTreesClassifier<F>;
322    type Error = FerroError;
323
324    /// Fit the ensemble by building `n_estimators` extra-trees in parallel.
325    ///
326    /// Each tree uses random split thresholds and random feature subsets at
327    /// every node. If `bootstrap` is `true`, each tree is trained on a
328    /// bootstrap sample; otherwise all samples are used.
329    ///
330    /// # Errors
331    ///
332    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
333    /// numbers of samples.
334    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
335    /// Returns [`FerroError::InvalidParameter`] if `n_estimators` is 0.
336    fn fit(
337        &self,
338        x: &Array2<F>,
339        y: &Array1<usize>,
340    ) -> Result<FittedExtraTreesClassifier<F>, FerroError> {
341        let (n_samples, n_features) = x.dim();
342
343        if n_samples != y.len() {
344            return Err(FerroError::ShapeMismatch {
345                expected: vec![n_samples],
346                actual: vec![y.len()],
347                context: "y length must match number of samples in X".into(),
348            });
349        }
350        if n_samples == 0 {
351            return Err(FerroError::InsufficientSamples {
352                required: 1,
353                actual: 0,
354                context: "ExtraTreesClassifier requires at least one sample".into(),
355            });
356        }
357        if self.n_estimators == 0 {
358            return Err(FerroError::InvalidParameter {
359                name: "n_estimators".into(),
360                reason: "must be at least 1".into(),
361            });
362        }
363
364        // Determine unique classes.
365        let mut classes: Vec<usize> = y.iter().copied().collect();
366        classes.sort_unstable();
367        classes.dedup();
368        let n_classes = classes.len();
369
370        let y_mapped: Vec<usize> = y
371            .iter()
372            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
373            .collect();
374
375        let max_features_n = resolve_max_features(self.max_features, n_features);
376        let params = make_tree_params(
377            self.max_depth,
378            self.min_samples_split,
379            self.min_samples_leaf,
380        );
381        let criterion = self.criterion;
382        let bootstrap = self.bootstrap;
383
384        // Generate per-tree seeds sequentially for determinism.
385        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
386            let mut master_rng = StdRng::seed_from_u64(seed);
387            (0..self.n_estimators)
388                .map(|_| {
389                    use rand::RngCore;
390                    master_rng.next_u64()
391                })
392                .collect()
393        } else {
394            (0..self.n_estimators)
395                .map(|_| {
396                    use rand::RngCore;
397                    rand::rng().next_u64()
398                })
399                .collect()
400        };
401
402        // Optionally configure thread pool.
403        let trees: Vec<Vec<Node<F>>> = if let Some(n_jobs) = self.n_jobs {
404            let pool = rayon::ThreadPoolBuilder::new()
405                .num_threads(n_jobs)
406                .build()
407                .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
408            pool.install(|| {
409                tree_seeds
410                    .par_iter()
411                    .map(|&seed| {
412                        build_single_classification_tree(
413                            x,
414                            &y_mapped,
415                            n_classes,
416                            n_samples,
417                            n_features,
418                            max_features_n,
419                            &params,
420                            criterion,
421                            bootstrap,
422                            seed,
423                        )
424                    })
425                    .collect()
426            })
427        } else {
428            tree_seeds
429                .par_iter()
430                .map(|&seed| {
431                    build_single_classification_tree(
432                        x,
433                        &y_mapped,
434                        n_classes,
435                        n_samples,
436                        n_features,
437                        max_features_n,
438                        &params,
439                        criterion,
440                        bootstrap,
441                        seed,
442                    )
443                })
444                .collect()
445        };
446
447        // Aggregate feature importances across trees.
448        let mut total_importances = Array1::<F>::zeros(n_features);
449        for tree_nodes in &trees {
450            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
451            total_importances = total_importances + tree_imp;
452        }
453        let imp_sum: F = total_importances
454            .iter()
455            .copied()
456            .fold(F::zero(), |a, b| a + b);
457        if imp_sum > F::zero() {
458            total_importances.mapv_inplace(|v| v / imp_sum);
459        }
460
461        Ok(FittedExtraTreesClassifier {
462            trees,
463            classes,
464            n_features,
465            feature_importances: total_importances,
466        })
467    }
468}
469
470/// Build a single classification extra-tree (used by parallel dispatch).
471#[allow(clippy::too_many_arguments)]
472fn build_single_classification_tree<F: Float>(
473    x: &Array2<F>,
474    y_mapped: &[usize],
475    n_classes: usize,
476    n_samples: usize,
477    n_features: usize,
478    max_features_n: usize,
479    params: &TreeParams,
480    criterion: ClassificationCriterion,
481    bootstrap: bool,
482    seed: u64,
483) -> Vec<Node<F>> {
484    let mut rng = StdRng::seed_from_u64(seed);
485
486    let indices: Vec<usize> = if bootstrap {
487        use rand::RngCore;
488        (0..n_samples)
489            .map(|_| (rng.next_u64() as usize) % n_samples)
490            .collect()
491    } else {
492        (0..n_samples).collect()
493    };
494
495    build_extra_classification_tree_for_ensemble(
496        x,
497        y_mapped,
498        n_classes,
499        &indices,
500        None, // feature selection happens inside the tree builder
501        params,
502        criterion,
503        n_features,
504        max_features_n,
505        &mut rng,
506    )
507}
508
509impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreesClassifier<F> {
510    type Output = Array1<usize>;
511    type Error = FerroError;
512
513    /// Predict class labels by majority vote across all trees.
514    ///
515    /// # Errors
516    ///
517    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
518    /// not match the fitted model.
519    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
520        if x.ncols() != self.n_features {
521            return Err(FerroError::ShapeMismatch {
522                expected: vec![self.n_features],
523                actual: vec![x.ncols()],
524                context: "number of features must match fitted model".into(),
525            });
526        }
527
528        let n_samples = x.nrows();
529        let n_classes = self.classes.len();
530        let mut predictions = Array1::zeros(n_samples);
531
532        for i in 0..n_samples {
533            let row = x.row(i);
534            let mut votes = vec![0usize; n_classes];
535
536            for tree_nodes in &self.trees {
537                let leaf_idx = traverse(tree_nodes, &row);
538                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
539                    let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
540                    if class_idx < n_classes {
541                        votes[class_idx] += 1;
542                    }
543                }
544            }
545
546            let winner = votes
547                .iter()
548                .enumerate()
549                .max_by_key(|&(_, &count)| count)
550                .map_or(0, |(idx, _)| idx);
551            predictions[i] = self.classes[winner];
552        }
553
554        Ok(predictions)
555    }
556}
557
558impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreesClassifier<F> {
559    fn feature_importances(&self) -> &Array1<F> {
560        &self.feature_importances
561    }
562}
563
564impl<F: Float + Send + Sync + 'static> HasClasses for FittedExtraTreesClassifier<F> {
565    fn classes(&self) -> &[usize] {
566        &self.classes
567    }
568
569    fn n_classes(&self) -> usize {
570        self.classes.len()
571    }
572}
573
574// Pipeline integration.
575impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
576    for ExtraTreesClassifier<F>
577{
578    fn fit_pipeline(
579        &self,
580        x: &Array2<F>,
581        y: &Array1<F>,
582    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
583        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
584        let fitted = self.fit(x, &y_usize)?;
585        Ok(Box::new(FittedExtraTreesClassifierPipelineAdapter(fitted)))
586    }
587}
588
589/// Pipeline adapter for `FittedExtraTreesClassifier<F>`.
590struct FittedExtraTreesClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
591    FittedExtraTreesClassifier<F>,
592);
593
594impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
595    for FittedExtraTreesClassifierPipelineAdapter<F>
596{
597    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
598        let preds = self.0.predict(x)?;
599        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
600    }
601}
602
603// ---------------------------------------------------------------------------
604// ExtraTreesRegressor
605// ---------------------------------------------------------------------------
606
607/// Extremely randomized trees regressor (ensemble).
608///
609/// Builds an ensemble of [`ExtraTreeRegressor`](crate::ExtraTreeRegressor)
610/// base estimators, each using random split thresholds and random feature
611/// subsets at every node. Final predictions are the mean across all trees.
612///
613/// Unlike [`RandomForestRegressor`](crate::RandomForestRegressor), bootstrap
614/// sampling is **disabled** by default.
615///
616/// # Type Parameters
617///
618/// - `F`: The floating-point type (`f32` or `f64`).
619#[derive(Debug, Clone, Serialize, Deserialize)]
620pub struct ExtraTreesRegressor<F> {
621    /// Number of trees in the ensemble.
622    pub n_estimators: usize,
623    /// Maximum depth of each tree. `None` means unlimited.
624    pub max_depth: Option<usize>,
625    /// Minimum number of samples required to split an internal node.
626    pub min_samples_split: usize,
627    /// Minimum number of samples required in a leaf node.
628    pub min_samples_leaf: usize,
629    /// Strategy for the number of features considered at each split.
630    pub max_features: MaxFeatures,
631    /// Whether to use bootstrap sampling. Default is `false`.
632    pub bootstrap: bool,
633    /// Random seed for reproducibility. `None` means non-deterministic.
634    pub random_state: Option<u64>,
635    /// Number of parallel jobs. `None` means use all available cores.
636    pub n_jobs: Option<usize>,
637    _marker: std::marker::PhantomData<F>,
638}
639
640impl<F: Float> ExtraTreesRegressor<F> {
641    /// Create a new `ExtraTreesRegressor` with default settings.
642    ///
643    /// Defaults: `n_estimators = 100`, `max_depth = None`,
644    /// `max_features = All`, `min_samples_split = 2`,
645    /// `min_samples_leaf = 1`, `bootstrap = false`,
646    /// `random_state = None`, `n_jobs = None`.
647    #[must_use]
648    pub fn new() -> Self {
649        Self {
650            n_estimators: 100,
651            max_depth: None,
652            min_samples_split: 2,
653            min_samples_leaf: 1,
654            max_features: MaxFeatures::All,
655            bootstrap: false,
656            random_state: None,
657            n_jobs: None,
658            _marker: std::marker::PhantomData,
659        }
660    }
661
662    /// Set the number of trees.
663    #[must_use]
664    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
665        self.n_estimators = n_estimators;
666        self
667    }
668
669    /// Set the maximum tree depth.
670    #[must_use]
671    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
672        self.max_depth = max_depth;
673        self
674    }
675
676    /// Set the minimum number of samples to split a node.
677    #[must_use]
678    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
679        self.min_samples_split = min_samples_split;
680        self
681    }
682
683    /// Set the minimum number of samples in a leaf.
684    #[must_use]
685    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
686        self.min_samples_leaf = min_samples_leaf;
687        self
688    }
689
690    /// Set the maximum features strategy.
691    #[must_use]
692    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
693        self.max_features = max_features;
694        self
695    }
696
697    /// Set whether to use bootstrap sampling.
698    #[must_use]
699    pub fn with_bootstrap(mut self, bootstrap: bool) -> Self {
700        self.bootstrap = bootstrap;
701        self
702    }
703
704    /// Set the random seed for reproducibility.
705    #[must_use]
706    pub fn with_random_state(mut self, seed: u64) -> Self {
707        self.random_state = Some(seed);
708        self
709    }
710
711    /// Set the number of parallel jobs.
712    #[must_use]
713    pub fn with_n_jobs(mut self, n_jobs: usize) -> Self {
714        self.n_jobs = Some(n_jobs);
715        self
716    }
717}
718
719impl<F: Float> Default for ExtraTreesRegressor<F> {
720    fn default() -> Self {
721        Self::new()
722    }
723}
724
725// ---------------------------------------------------------------------------
726// FittedExtraTreesRegressor
727// ---------------------------------------------------------------------------
728
729/// A fitted extremely randomized trees regressor (ensemble).
730///
731/// Stores the ensemble of fitted extra-trees and aggregates their
732/// predictions by averaging.
733#[derive(Debug, Clone)]
734pub struct FittedExtraTreesRegressor<F> {
735    /// Individual tree node vectors.
736    trees: Vec<Vec<Node<F>>>,
737    /// Number of features.
738    n_features: usize,
739    /// Per-feature importance scores (mean decrease in impurity, normalised).
740    feature_importances: Array1<F>,
741}
742
743impl<F: Float + Send + Sync + 'static> FittedExtraTreesRegressor<F> {
744    /// Returns a reference to the individual tree node vectors.
745    #[must_use]
746    pub fn trees(&self) -> &[Vec<Node<F>>] {
747        &self.trees
748    }
749
750    /// Returns the number of features the model was trained on.
751    #[must_use]
752    pub fn n_features(&self) -> usize {
753        self.n_features
754    }
755
756    /// Returns the number of trees in the ensemble.
757    #[must_use]
758    pub fn n_estimators(&self) -> usize {
759        self.trees.len()
760    }
761
762    /// R² coefficient of determination on the given test data.
763    /// Equivalent to sklearn's `RegressorMixin.score`.
764    ///
765    /// # Errors
766    ///
767    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
768    /// the feature count does not match the training data.
769    pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<F, FerroError> {
770        if x.nrows() != y.len() {
771            return Err(FerroError::ShapeMismatch {
772                expected: vec![x.nrows()],
773                actual: vec![y.len()],
774                context: "y length must match number of samples in X".into(),
775            });
776        }
777        let preds = self.predict(x)?;
778        Ok(crate::r2_score(&preds, y))
779    }
780}
781
782impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for ExtraTreesRegressor<F> {
783    type Fitted = FittedExtraTreesRegressor<F>;
784    type Error = FerroError;
785
786    /// Fit the ensemble by building `n_estimators` extra-trees in parallel.
787    ///
788    /// # Errors
789    ///
790    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
791    /// numbers of samples.
792    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
793    /// Returns [`FerroError::InvalidParameter`] if `n_estimators` is 0.
794    fn fit(
795        &self,
796        x: &Array2<F>,
797        y: &Array1<F>,
798    ) -> Result<FittedExtraTreesRegressor<F>, FerroError> {
799        let (n_samples, n_features) = x.dim();
800
801        if n_samples != y.len() {
802            return Err(FerroError::ShapeMismatch {
803                expected: vec![n_samples],
804                actual: vec![y.len()],
805                context: "y length must match number of samples in X".into(),
806            });
807        }
808        if n_samples == 0 {
809            return Err(FerroError::InsufficientSamples {
810                required: 1,
811                actual: 0,
812                context: "ExtraTreesRegressor requires at least one sample".into(),
813            });
814        }
815        if self.n_estimators == 0 {
816            return Err(FerroError::InvalidParameter {
817                name: "n_estimators".into(),
818                reason: "must be at least 1".into(),
819            });
820        }
821
822        let max_features_n = resolve_max_features(self.max_features, n_features);
823        let params = make_tree_params(
824            self.max_depth,
825            self.min_samples_split,
826            self.min_samples_leaf,
827        );
828        let bootstrap = self.bootstrap;
829
830        // Generate per-tree seeds sequentially.
831        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
832            let mut master_rng = StdRng::seed_from_u64(seed);
833            (0..self.n_estimators)
834                .map(|_| {
835                    use rand::RngCore;
836                    master_rng.next_u64()
837                })
838                .collect()
839        } else {
840            (0..self.n_estimators)
841                .map(|_| {
842                    use rand::RngCore;
843                    rand::rng().next_u64()
844                })
845                .collect()
846        };
847
848        // Build trees in parallel.
849        let trees: Vec<Vec<Node<F>>> = if let Some(n_jobs) = self.n_jobs {
850            let pool = rayon::ThreadPoolBuilder::new()
851                .num_threads(n_jobs)
852                .build()
853                .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
854            pool.install(|| {
855                tree_seeds
856                    .par_iter()
857                    .map(|&seed| {
858                        build_single_regression_tree(
859                            x,
860                            y,
861                            n_samples,
862                            n_features,
863                            max_features_n,
864                            &params,
865                            bootstrap,
866                            seed,
867                        )
868                    })
869                    .collect()
870            })
871        } else {
872            tree_seeds
873                .par_iter()
874                .map(|&seed| {
875                    build_single_regression_tree(
876                        x,
877                        y,
878                        n_samples,
879                        n_features,
880                        max_features_n,
881                        &params,
882                        bootstrap,
883                        seed,
884                    )
885                })
886                .collect()
887        };
888
889        // Aggregate feature importances.
890        let mut total_importances = Array1::<F>::zeros(n_features);
891        for tree_nodes in &trees {
892            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
893            total_importances = total_importances + tree_imp;
894        }
895        let imp_sum: F = total_importances
896            .iter()
897            .copied()
898            .fold(F::zero(), |a, b| a + b);
899        if imp_sum > F::zero() {
900            total_importances.mapv_inplace(|v| v / imp_sum);
901        }
902
903        Ok(FittedExtraTreesRegressor {
904            trees,
905            n_features,
906            feature_importances: total_importances,
907        })
908    }
909}
910
911/// Build a single regression extra-tree (used by parallel dispatch).
912#[allow(clippy::too_many_arguments)]
913fn build_single_regression_tree<F: Float>(
914    x: &Array2<F>,
915    y: &Array1<F>,
916    n_samples: usize,
917    n_features: usize,
918    max_features_n: usize,
919    params: &TreeParams,
920    bootstrap: bool,
921    seed: u64,
922) -> Vec<Node<F>> {
923    let mut rng = StdRng::seed_from_u64(seed);
924
925    let indices: Vec<usize> = if bootstrap {
926        use rand::RngCore;
927        (0..n_samples)
928            .map(|_| (rng.next_u64() as usize) % n_samples)
929            .collect()
930    } else {
931        (0..n_samples).collect()
932    };
933
934    build_extra_regression_tree_for_ensemble(
935        x,
936        y,
937        &indices,
938        None, // feature selection happens inside the tree builder
939        params,
940        n_features,
941        max_features_n,
942        &mut rng,
943    )
944}
945
946impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreesRegressor<F> {
947    type Output = Array1<F>;
948    type Error = FerroError;
949
950    /// Predict target values by averaging across all trees.
951    ///
952    /// # Errors
953    ///
954    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
955    /// not match the fitted model.
956    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
957        if x.ncols() != self.n_features {
958            return Err(FerroError::ShapeMismatch {
959                expected: vec![self.n_features],
960                actual: vec![x.ncols()],
961                context: "number of features must match fitted model".into(),
962            });
963        }
964
965        let n_samples = x.nrows();
966        let n_trees_f = F::from(self.trees.len()).unwrap();
967        let mut predictions = Array1::zeros(n_samples);
968
969        for i in 0..n_samples {
970            let row = x.row(i);
971            let mut sum = F::zero();
972
973            for tree_nodes in &self.trees {
974                let leaf_idx = traverse(tree_nodes, &row);
975                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
976                    sum = sum + value;
977                }
978            }
979
980            predictions[i] = sum / n_trees_f;
981        }
982
983        Ok(predictions)
984    }
985}
986
987impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreesRegressor<F> {
988    fn feature_importances(&self) -> &Array1<F> {
989        &self.feature_importances
990    }
991}
992
993// Pipeline integration.
994impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for ExtraTreesRegressor<F> {
995    fn fit_pipeline(
996        &self,
997        x: &Array2<F>,
998        y: &Array1<F>,
999    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
1000        let fitted = self.fit(x, y)?;
1001        Ok(Box::new(fitted))
1002    }
1003}
1004
1005impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedExtraTreesRegressor<F> {
1006    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
1007        self.predict(x)
1008    }
1009}
1010
1011// ---------------------------------------------------------------------------
1012// Tests
1013// ---------------------------------------------------------------------------
1014
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // -- Shared fixtures --

    /// Eight samples in two separable groups: the first four rows are class 0
    /// and the last four rows are class 1.
    fn separable_classification_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        (x, y)
    }

    /// Six samples of a single feature whose target equals the feature (`y = x`).
    fn identity_regression_data() -> (Array2<f64>, Array1<f64>) {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        (x, y)
    }

    // -- ExtraTreesClassifier tests --

    #[test]
    fn test_ensemble_classifier_simple() {
        let (x, y) = separable_classification_data();

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // Should classify all training points correctly.
        assert_eq!(preds, y);
    }

    #[test]
    fn test_ensemble_classifier_no_bootstrap() {
        let (x, y) = separable_classification_data();

        // Default: no bootstrap.
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        assert!(!model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds, y);
    }

    #[test]
    fn test_ensemble_classifier_with_bootstrap() {
        let (x, y) = separable_classification_data();

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_bootstrap(true)
            .with_random_state(42);
        assert!(model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_ensemble_classifier_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        assert_eq!(proba.dim(), (6, 2));
        for i in 0..6 {
            let row_sum = proba.row(i).sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_ensemble_classifier_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0, 1.0, 6.0, 1.0, 7.0, 1.0, 8.0, 1.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        // Feature 0 should dominate (feature 1 is constant).
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_ensemble_classifier_n_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(15)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 15);
    }

    #[test]
    fn test_ensemble_classifier_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 3, 3, 3]; // non-contiguous

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 3]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_ensemble_classifier_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0];
        let model = ExtraTreesClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);
        let model = ExtraTreesClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_classifier_zero_estimators() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = array![0, 1];
        let model = ExtraTreesClassifier::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_classifier_deterministic() {
        let (x, y) = separable_classification_data();

        let model1 = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);
        let model2 = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);

        let preds1 = model1.fit(&x, &y).unwrap().predict(&x).unwrap();
        let preds2 = model2.fit(&x, &y).unwrap().predict(&x).unwrap();
        assert_eq!(preds1, preds2);
    }

    #[test]
    fn test_ensemble_classifier_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    // -- ExtraTreesRegressor tests --

    #[test]
    fn test_ensemble_regressor_simple() {
        let (x, y) = identity_regression_data();

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        // Ensemble should approximate the training data well.
        for i in 0..6 {
            assert_relative_eq!(preds[i], y[i], epsilon = 1.0);
        }
    }

    #[test]
    fn test_ensemble_regressor_constant_target() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![5.0, 5.0, 5.0, 5.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for &p in &preds {
            assert_relative_eq!(p, 5.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_ensemble_regressor_no_bootstrap() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        assert!(!model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_ensemble_regressor_with_bootstrap() {
        let (x, y) = identity_regression_data();

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_bootstrap(true)
            .with_random_state(42);
        assert!(model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_ensemble_regressor_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_ensemble_regressor_n_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(7)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 7);
    }

    #[test]
    fn test_ensemble_regressor_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0];
        let model = ExtraTreesRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let model = ExtraTreesRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_regressor_zero_estimators() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = array![1.0, 2.0];
        let model = ExtraTreesRegressor::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_regressor_deterministic() {
        let (x, y) = identity_regression_data();

        let model1 = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);
        let model2 = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);

        let preds1 = model1.fit(&x, &y).unwrap().predict(&x).unwrap();
        let preds2 = model2.fit(&x, &y).unwrap().predict(&x).unwrap();

        for i in 0..6 {
            assert_relative_eq!(preds1[i], preds2[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_ensemble_regressor_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    #[test]
    fn test_ensemble_regressor_n_jobs_matches_global_pool() {
        // Per-tree seeds are drawn sequentially before the parallel build, so
        // fitting on a dedicated `n_jobs` thread pool must produce the same
        // model as fitting on rayon's global pool.
        let (x, y) = identity_regression_data();

        let global_model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(7);
        let pooled_model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(7)
            .with_n_jobs(2);

        let global_fit = global_model.fit(&x, &y).unwrap();
        let pooled_fit = pooled_model.fit(&x, &y).unwrap();
        assert_eq!(pooled_fit.n_estimators(), 10);

        let global_preds = global_fit.predict(&x).unwrap();
        let pooled_preds = pooled_fit.predict(&x).unwrap();
        for i in 0..6 {
            assert_relative_eq!(global_preds[i], pooled_preds[i], epsilon = 1e-12);
        }

        let global_imp = global_fit.feature_importances();
        let pooled_imp = pooled_fit.feature_importances();
        assert_eq!(global_imp.len(), pooled_imp.len());
        for j in 0..global_imp.len() {
            assert_relative_eq!(global_imp[j], pooled_imp[j], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_ensemble_regressor_pipeline_matches_predict() {
        // The boxed pipeline adapters must be thin wrappers around fit/predict.
        let (x, y) = identity_regression_data();

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);

        let direct = model.fit(&x, &y).unwrap().predict(&x).unwrap();
        let step = model.fit_pipeline(&x, &y).unwrap();
        let via_pipeline = step.predict_pipeline(&x).unwrap();

        assert_eq!(via_pipeline.len(), direct.len());
        for i in 0..direct.len() {
            assert_relative_eq!(via_pipeline[i], direct[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_ensemble_regressor_predict_zero_rows() {
        // A correctly shaped but empty input yields an empty prediction vector.
        let (x, y) = identity_regression_data();

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let empty = Array2::<f64>::zeros((0, 1));
        let preds = fitted.predict(&empty).unwrap();
        assert_eq!(preds.len(), 0);
    }

    // -- Builder tests --

    #[test]
    fn test_ensemble_classifier_builder() {
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_max_depth(Some(5))
            .with_min_samples_split(10)
            .with_min_samples_leaf(3)
            .with_max_features(MaxFeatures::Log2)
            .with_bootstrap(true)
            .with_criterion(ClassificationCriterion::Entropy)
            .with_random_state(42)
            .with_n_jobs(4);

        assert_eq!(model.n_estimators, 50);
        assert_eq!(model.max_depth, Some(5));
        assert_eq!(model.min_samples_split, 10);
        assert_eq!(model.min_samples_leaf, 3);
        assert_eq!(model.max_features, MaxFeatures::Log2);
        assert!(model.bootstrap);
        assert_eq!(model.criterion, ClassificationCriterion::Entropy);
        assert_eq!(model.random_state, Some(42));
        assert_eq!(model.n_jobs, Some(4));
    }

    #[test]
    fn test_ensemble_regressor_builder() {
        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(25)
            .with_max_depth(Some(8))
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_max_features(MaxFeatures::Fraction(0.5))
            .with_bootstrap(true)
            .with_random_state(99)
            .with_n_jobs(2);

        assert_eq!(model.n_estimators, 25);
        assert_eq!(model.max_depth, Some(8));
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.min_samples_leaf, 2);
        assert_eq!(model.max_features, MaxFeatures::Fraction(0.5));
        assert!(model.bootstrap);
        assert_eq!(model.random_state, Some(99));
        assert_eq!(model.n_jobs, Some(2));
    }

    #[test]
    fn test_ensemble_classifier_default() {
        let model = ExtraTreesClassifier::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::Sqrt);
        assert!(!model.bootstrap);
        assert_eq!(model.criterion, ClassificationCriterion::Gini);
        assert_eq!(model.random_state, None);
        assert_eq!(model.n_jobs, None);
    }

    #[test]
    fn test_ensemble_regressor_default() {
        let model = ExtraTreesRegressor::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::All);
        assert!(!model.bootstrap);
        assert_eq!(model.random_state, None);
        assert_eq!(model.n_jobs, None);
    }
}