// ferrolearn_tree/extra_trees_ensemble.rs
//! Extremely randomized trees ensemble classifiers and regressors.
//!
//! This module provides [`ExtraTreesClassifier`] and [`ExtraTreesRegressor`],
//! which build ensembles of extremely randomized trees. Unlike
//! [`RandomForestClassifier`](crate::RandomForestClassifier), ExtraTrees
//! ensembles do **not** bootstrap by default: all trees see all samples, and
//! randomness comes solely from the random split thresholds and random feature
//! subsets at each node.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_tree::ExtraTreesClassifier;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array1, Array2};
//!
//! let x = Array2::from_shape_vec((8, 2), vec![
//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,  4.0, 4.0,
//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,  8.0, 9.0,
//! ]).unwrap();
//! let y = array![0, 0, 0, 0, 1, 1, 1, 1];
//!
//! let model = ExtraTreesClassifier::<f64>::new()
//!     .with_n_estimators(10)
//!     .with_random_state(42);
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! ```

use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2};
use num_traits::{Float, FromPrimitive, ToPrimitive};
use rand::SeedableRng;
use rand::rngs::StdRng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};

use crate::decision_tree::{
    ClassificationCriterion, Node, TreeParams, compute_feature_importances, traverse,
};
use crate::extra_tree::{
    build_extra_classification_tree_for_ensemble, build_extra_regression_tree_for_ensemble,
};
use crate::random_forest::MaxFeatures;
48
49/// Resolve the `MaxFeatures` strategy to a concrete number.
50fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
51    let result = match strategy {
52        MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
53        MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
54        MaxFeatures::All => n_features,
55        MaxFeatures::Fixed(n) => n.min(n_features),
56        MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
57    };
58    result.max(1).min(n_features)
59}
60
61/// Internal tree parameter struct helper.
62fn make_tree_params(
63    max_depth: Option<usize>,
64    min_samples_split: usize,
65    min_samples_leaf: usize,
66) -> TreeParams {
67    TreeParams {
68        max_depth,
69        min_samples_split,
70        min_samples_leaf,
71    }
72}
73
// ---------------------------------------------------------------------------
// ExtraTreesClassifier
// ---------------------------------------------------------------------------
77
/// Extremely randomized trees classifier (ensemble).
///
/// Builds an ensemble of [`ExtraTreeClassifier`](crate::ExtraTreeClassifier)
/// base estimators, each using random split thresholds and random feature
/// subsets at every node. Final predictions are made by majority vote.
///
/// Unlike [`RandomForestClassifier`](crate::RandomForestClassifier), bootstrap
/// sampling is **disabled** by default. Randomness comes from the random
/// thresholds and random feature subsets at each split.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreesClassifier<F> {
    /// Number of trees in the ensemble. Default (via [`Self::new`]): 100.
    pub n_estimators: usize,
    /// Maximum depth of each tree. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node. Default: 2.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node. Default: 1.
    pub min_samples_leaf: usize,
    /// Strategy for the number of features considered at each split.
    /// Default: `MaxFeatures::Sqrt`.
    pub max_features: MaxFeatures,
    /// Whether to use bootstrap sampling. Default is `false`.
    pub bootstrap: bool,
    /// Splitting criterion. Default: `ClassificationCriterion::Gini`.
    pub criterion: ClassificationCriterion,
    /// Random seed for reproducibility. `None` means non-deterministic.
    pub random_state: Option<u64>,
    /// Number of parallel jobs. `None` means use all available cores.
    pub n_jobs: Option<usize>,
    // Ties the otherwise-unused type parameter `F` to the struct.
    _marker: std::marker::PhantomData<F>,
}
113
114impl<F: Float> ExtraTreesClassifier<F> {
115    /// Create a new `ExtraTreesClassifier` with default settings.
116    ///
117    /// Defaults: `n_estimators = 100`, `max_depth = None`,
118    /// `max_features = Sqrt`, `min_samples_split = 2`,
119    /// `min_samples_leaf = 1`, `bootstrap = false`,
120    /// `criterion = Gini`, `random_state = None`, `n_jobs = None`.
121    #[must_use]
122    pub fn new() -> Self {
123        Self {
124            n_estimators: 100,
125            max_depth: None,
126            min_samples_split: 2,
127            min_samples_leaf: 1,
128            max_features: MaxFeatures::Sqrt,
129            bootstrap: false,
130            criterion: ClassificationCriterion::Gini,
131            random_state: None,
132            n_jobs: None,
133            _marker: std::marker::PhantomData,
134        }
135    }
136
137    /// Set the number of trees.
138    #[must_use]
139    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
140        self.n_estimators = n_estimators;
141        self
142    }
143
144    /// Set the maximum tree depth.
145    #[must_use]
146    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
147        self.max_depth = max_depth;
148        self
149    }
150
151    /// Set the minimum number of samples to split a node.
152    #[must_use]
153    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
154        self.min_samples_split = min_samples_split;
155        self
156    }
157
158    /// Set the minimum number of samples in a leaf.
159    #[must_use]
160    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
161        self.min_samples_leaf = min_samples_leaf;
162        self
163    }
164
165    /// Set the maximum features strategy.
166    #[must_use]
167    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
168        self.max_features = max_features;
169        self
170    }
171
172    /// Set whether to use bootstrap sampling.
173    #[must_use]
174    pub fn with_bootstrap(mut self, bootstrap: bool) -> Self {
175        self.bootstrap = bootstrap;
176        self
177    }
178
179    /// Set the splitting criterion.
180    #[must_use]
181    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
182        self.criterion = criterion;
183        self
184    }
185
186    /// Set the random seed for reproducibility.
187    #[must_use]
188    pub fn with_random_state(mut self, seed: u64) -> Self {
189        self.random_state = Some(seed);
190        self
191    }
192
193    /// Set the number of parallel jobs.
194    #[must_use]
195    pub fn with_n_jobs(mut self, n_jobs: usize) -> Self {
196        self.n_jobs = Some(n_jobs);
197        self
198    }
199}
200
impl<F: Float> Default for ExtraTreesClassifier<F> {
    /// Equivalent to [`ExtraTreesClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
206
// ---------------------------------------------------------------------------
// FittedExtraTreesClassifier
// ---------------------------------------------------------------------------
210
/// A fitted extremely randomized trees classifier (ensemble).
///
/// Stores the ensemble of fitted extra-trees and aggregates their
/// predictions by majority vote.
#[derive(Debug, Clone)]
pub struct FittedExtraTreesClassifier<F> {
    /// Individual tree node vectors (one flat node array per tree).
    trees: Vec<Vec<Node<F>>>,
    /// Sorted unique class labels observed during fitting.
    classes: Vec<usize>,
    /// Number of features the model was trained on.
    n_features: usize,
    /// Per-feature importance scores (mean decrease in impurity, normalised).
    feature_importances: Array1<F>,
}
226
227impl<F: Float + Send + Sync + 'static> FittedExtraTreesClassifier<F> {
228    /// Returns a reference to the individual tree node vectors.
229    #[must_use]
230    pub fn trees(&self) -> &[Vec<Node<F>>] {
231        &self.trees
232    }
233
234    /// Returns the number of features the model was trained on.
235    #[must_use]
236    pub fn n_features(&self) -> usize {
237        self.n_features
238    }
239
240    /// Returns the number of trees in the ensemble.
241    #[must_use]
242    pub fn n_estimators(&self) -> usize {
243        self.trees.len()
244    }
245
246    /// Predict class probabilities for each sample by averaging tree predictions.
247    ///
248    /// Returns a 2-D array of shape `(n_samples, n_classes)`.
249    ///
250    /// # Errors
251    ///
252    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
253    /// not match the training data.
254    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
255        if x.ncols() != self.n_features {
256            return Err(FerroError::ShapeMismatch {
257                expected: vec![self.n_features],
258                actual: vec![x.ncols()],
259                context: "number of features must match fitted model".into(),
260            });
261        }
262        let n_samples = x.nrows();
263        let n_classes = self.classes.len();
264        let n_trees_f = F::from(self.trees.len()).unwrap();
265        let mut proba = Array2::zeros((n_samples, n_classes));
266
267        for i in 0..n_samples {
268            let row = x.row(i);
269            for tree_nodes in &self.trees {
270                let leaf_idx = traverse(tree_nodes, &row);
271                if let Node::Leaf {
272                    class_distribution: Some(ref dist),
273                    ..
274                } = tree_nodes[leaf_idx]
275                {
276                    for (j, &p) in dist.iter().enumerate() {
277                        proba[[i, j]] = proba[[i, j]] + p;
278                    }
279                }
280            }
281            // Average across trees.
282            for j in 0..n_classes {
283                proba[[i, j]] = proba[[i, j]] / n_trees_f;
284            }
285        }
286        Ok(proba)
287    }
288}
289
290impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ExtraTreesClassifier<F> {
291    type Fitted = FittedExtraTreesClassifier<F>;
292    type Error = FerroError;
293
294    /// Fit the ensemble by building `n_estimators` extra-trees in parallel.
295    ///
296    /// Each tree uses random split thresholds and random feature subsets at
297    /// every node. If `bootstrap` is `true`, each tree is trained on a
298    /// bootstrap sample; otherwise all samples are used.
299    ///
300    /// # Errors
301    ///
302    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
303    /// numbers of samples.
304    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
305    /// Returns [`FerroError::InvalidParameter`] if `n_estimators` is 0.
306    fn fit(
307        &self,
308        x: &Array2<F>,
309        y: &Array1<usize>,
310    ) -> Result<FittedExtraTreesClassifier<F>, FerroError> {
311        let (n_samples, n_features) = x.dim();
312
313        if n_samples != y.len() {
314            return Err(FerroError::ShapeMismatch {
315                expected: vec![n_samples],
316                actual: vec![y.len()],
317                context: "y length must match number of samples in X".into(),
318            });
319        }
320        if n_samples == 0 {
321            return Err(FerroError::InsufficientSamples {
322                required: 1,
323                actual: 0,
324                context: "ExtraTreesClassifier requires at least one sample".into(),
325            });
326        }
327        if self.n_estimators == 0 {
328            return Err(FerroError::InvalidParameter {
329                name: "n_estimators".into(),
330                reason: "must be at least 1".into(),
331            });
332        }
333
334        // Determine unique classes.
335        let mut classes: Vec<usize> = y.iter().copied().collect();
336        classes.sort_unstable();
337        classes.dedup();
338        let n_classes = classes.len();
339
340        let y_mapped: Vec<usize> = y
341            .iter()
342            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
343            .collect();
344
345        let max_features_n = resolve_max_features(self.max_features, n_features);
346        let params = make_tree_params(
347            self.max_depth,
348            self.min_samples_split,
349            self.min_samples_leaf,
350        );
351        let criterion = self.criterion;
352        let bootstrap = self.bootstrap;
353
354        // Generate per-tree seeds sequentially for determinism.
355        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
356            let mut master_rng = StdRng::seed_from_u64(seed);
357            (0..self.n_estimators)
358                .map(|_| {
359                    use rand::RngCore;
360                    master_rng.next_u64()
361                })
362                .collect()
363        } else {
364            (0..self.n_estimators)
365                .map(|_| {
366                    use rand::RngCore;
367                    rand::rng().next_u64()
368                })
369                .collect()
370        };
371
372        // Optionally configure thread pool.
373        let trees: Vec<Vec<Node<F>>> = if let Some(n_jobs) = self.n_jobs {
374            let pool = rayon::ThreadPoolBuilder::new()
375                .num_threads(n_jobs)
376                .build()
377                .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
378            pool.install(|| {
379                tree_seeds
380                    .par_iter()
381                    .map(|&seed| {
382                        build_single_classification_tree(
383                            x,
384                            &y_mapped,
385                            n_classes,
386                            n_samples,
387                            n_features,
388                            max_features_n,
389                            &params,
390                            criterion,
391                            bootstrap,
392                            seed,
393                        )
394                    })
395                    .collect()
396            })
397        } else {
398            tree_seeds
399                .par_iter()
400                .map(|&seed| {
401                    build_single_classification_tree(
402                        x,
403                        &y_mapped,
404                        n_classes,
405                        n_samples,
406                        n_features,
407                        max_features_n,
408                        &params,
409                        criterion,
410                        bootstrap,
411                        seed,
412                    )
413                })
414                .collect()
415        };
416
417        // Aggregate feature importances across trees.
418        let mut total_importances = Array1::<F>::zeros(n_features);
419        for tree_nodes in &trees {
420            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
421            total_importances = total_importances + tree_imp;
422        }
423        let imp_sum: F = total_importances
424            .iter()
425            .copied()
426            .fold(F::zero(), |a, b| a + b);
427        if imp_sum > F::zero() {
428            total_importances.mapv_inplace(|v| v / imp_sum);
429        }
430
431        Ok(FittedExtraTreesClassifier {
432            trees,
433            classes,
434            n_features,
435            feature_importances: total_importances,
436        })
437    }
438}
439
440/// Build a single classification extra-tree (used by parallel dispatch).
441#[allow(clippy::too_many_arguments)]
442fn build_single_classification_tree<F: Float>(
443    x: &Array2<F>,
444    y_mapped: &[usize],
445    n_classes: usize,
446    n_samples: usize,
447    n_features: usize,
448    max_features_n: usize,
449    params: &TreeParams,
450    criterion: ClassificationCriterion,
451    bootstrap: bool,
452    seed: u64,
453) -> Vec<Node<F>> {
454    let mut rng = StdRng::seed_from_u64(seed);
455
456    let indices: Vec<usize> = if bootstrap {
457        use rand::RngCore;
458        (0..n_samples)
459            .map(|_| (rng.next_u64() as usize) % n_samples)
460            .collect()
461    } else {
462        (0..n_samples).collect()
463    };
464
465    build_extra_classification_tree_for_ensemble(
466        x,
467        y_mapped,
468        n_classes,
469        &indices,
470        None, // feature selection happens inside the tree builder
471        params,
472        criterion,
473        n_features,
474        max_features_n,
475        &mut rng,
476    )
477}
478
479impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreesClassifier<F> {
480    type Output = Array1<usize>;
481    type Error = FerroError;
482
483    /// Predict class labels by majority vote across all trees.
484    ///
485    /// # Errors
486    ///
487    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
488    /// not match the fitted model.
489    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
490        if x.ncols() != self.n_features {
491            return Err(FerroError::ShapeMismatch {
492                expected: vec![self.n_features],
493                actual: vec![x.ncols()],
494                context: "number of features must match fitted model".into(),
495            });
496        }
497
498        let n_samples = x.nrows();
499        let n_classes = self.classes.len();
500        let mut predictions = Array1::zeros(n_samples);
501
502        for i in 0..n_samples {
503            let row = x.row(i);
504            let mut votes = vec![0usize; n_classes];
505
506            for tree_nodes in &self.trees {
507                let leaf_idx = traverse(tree_nodes, &row);
508                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
509                    let class_idx = value.to_f64().map(|f| f.round() as usize).unwrap_or(0);
510                    if class_idx < n_classes {
511                        votes[class_idx] += 1;
512                    }
513                }
514            }
515
516            let winner = votes
517                .iter()
518                .enumerate()
519                .max_by_key(|&(_, &count)| count)
520                .map(|(idx, _)| idx)
521                .unwrap_or(0);
522            predictions[i] = self.classes[winner];
523        }
524
525        Ok(predictions)
526    }
527}
528
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreesClassifier<F> {
    /// Per-feature mean-decrease-in-impurity scores, normalised in `fit`.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
534
impl<F: Float + Send + Sync + 'static> HasClasses for FittedExtraTreesClassifier<F> {
    /// Sorted unique class labels seen during fitting.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
544
// Pipeline integration: lets the classifier participate in float-typed
// pipelines by converting the float targets to usize class labels.
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for ExtraTreesClassifier<F>
{
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        // NOTE(review): labels that cannot be converted to usize (negative,
        // NaN, too large) are silently coerced to class 0 — confirm intended.
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedExtraTreesClassifierPipelineAdapter(fitted)))
    }
}
559
/// Pipeline adapter for `FittedExtraTreesClassifier<F>`.
///
/// Wraps the fitted classifier so its `usize` predictions can flow through
/// the float-typed pipeline interface.
struct FittedExtraTreesClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedExtraTreesClassifier<F>,
);
564
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedExtraTreesClassifierPipelineAdapter<F>
{
    /// Predict via the wrapped classifier, converting class labels to floats.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        // Labels not representable in F become NaN rather than a wrong value.
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
    }
}
573
// ---------------------------------------------------------------------------
// ExtraTreesRegressor
// ---------------------------------------------------------------------------
577
/// Extremely randomized trees regressor (ensemble).
///
/// Builds an ensemble of [`ExtraTreeRegressor`](crate::ExtraTreeRegressor)
/// base estimators, each using random split thresholds and random feature
/// subsets at every node. Final predictions are the mean across all trees.
///
/// Unlike [`RandomForestRegressor`](crate::RandomForestRegressor), bootstrap
/// sampling is **disabled** by default.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreesRegressor<F> {
    /// Number of trees in the ensemble. Default (via [`Self::new`]): 100.
    pub n_estimators: usize,
    /// Maximum depth of each tree. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node. Default: 2.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node. Default: 1.
    pub min_samples_leaf: usize,
    /// Strategy for the number of features considered at each split.
    /// Default: `MaxFeatures::All`.
    pub max_features: MaxFeatures,
    /// Whether to use bootstrap sampling. Default is `false`.
    pub bootstrap: bool,
    /// Random seed for reproducibility. `None` means non-deterministic.
    pub random_state: Option<u64>,
    /// Number of parallel jobs. `None` means use all available cores.
    pub n_jobs: Option<usize>,
    // Ties the otherwise-unused type parameter `F` to the struct.
    _marker: std::marker::PhantomData<F>,
}
610
611impl<F: Float> ExtraTreesRegressor<F> {
612    /// Create a new `ExtraTreesRegressor` with default settings.
613    ///
614    /// Defaults: `n_estimators = 100`, `max_depth = None`,
615    /// `max_features = All`, `min_samples_split = 2`,
616    /// `min_samples_leaf = 1`, `bootstrap = false`,
617    /// `random_state = None`, `n_jobs = None`.
618    #[must_use]
619    pub fn new() -> Self {
620        Self {
621            n_estimators: 100,
622            max_depth: None,
623            min_samples_split: 2,
624            min_samples_leaf: 1,
625            max_features: MaxFeatures::All,
626            bootstrap: false,
627            random_state: None,
628            n_jobs: None,
629            _marker: std::marker::PhantomData,
630        }
631    }
632
633    /// Set the number of trees.
634    #[must_use]
635    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
636        self.n_estimators = n_estimators;
637        self
638    }
639
640    /// Set the maximum tree depth.
641    #[must_use]
642    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
643        self.max_depth = max_depth;
644        self
645    }
646
647    /// Set the minimum number of samples to split a node.
648    #[must_use]
649    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
650        self.min_samples_split = min_samples_split;
651        self
652    }
653
654    /// Set the minimum number of samples in a leaf.
655    #[must_use]
656    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
657        self.min_samples_leaf = min_samples_leaf;
658        self
659    }
660
661    /// Set the maximum features strategy.
662    #[must_use]
663    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
664        self.max_features = max_features;
665        self
666    }
667
668    /// Set whether to use bootstrap sampling.
669    #[must_use]
670    pub fn with_bootstrap(mut self, bootstrap: bool) -> Self {
671        self.bootstrap = bootstrap;
672        self
673    }
674
675    /// Set the random seed for reproducibility.
676    #[must_use]
677    pub fn with_random_state(mut self, seed: u64) -> Self {
678        self.random_state = Some(seed);
679        self
680    }
681
682    /// Set the number of parallel jobs.
683    #[must_use]
684    pub fn with_n_jobs(mut self, n_jobs: usize) -> Self {
685        self.n_jobs = Some(n_jobs);
686        self
687    }
688}
689
impl<F: Float> Default for ExtraTreesRegressor<F> {
    /// Equivalent to [`ExtraTreesRegressor::new`].
    fn default() -> Self {
        Self::new()
    }
}
695
// ---------------------------------------------------------------------------
// FittedExtraTreesRegressor
// ---------------------------------------------------------------------------
699
/// A fitted extremely randomized trees regressor (ensemble).
///
/// Stores the ensemble of fitted extra-trees and aggregates their
/// predictions by averaging.
#[derive(Debug, Clone)]
pub struct FittedExtraTreesRegressor<F> {
    /// Individual tree node vectors (one flat node array per tree).
    trees: Vec<Vec<Node<F>>>,
    /// Number of features the model was trained on.
    n_features: usize,
    /// Per-feature importance scores (mean decrease in impurity, normalised).
    feature_importances: Array1<F>,
}
713
714impl<F: Float + Send + Sync + 'static> FittedExtraTreesRegressor<F> {
715    /// Returns a reference to the individual tree node vectors.
716    #[must_use]
717    pub fn trees(&self) -> &[Vec<Node<F>>] {
718        &self.trees
719    }
720
721    /// Returns the number of features the model was trained on.
722    #[must_use]
723    pub fn n_features(&self) -> usize {
724        self.n_features
725    }
726
727    /// Returns the number of trees in the ensemble.
728    #[must_use]
729    pub fn n_estimators(&self) -> usize {
730        self.trees.len()
731    }
732}
733
734impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for ExtraTreesRegressor<F> {
735    type Fitted = FittedExtraTreesRegressor<F>;
736    type Error = FerroError;
737
738    /// Fit the ensemble by building `n_estimators` extra-trees in parallel.
739    ///
740    /// # Errors
741    ///
742    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
743    /// numbers of samples.
744    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
745    /// Returns [`FerroError::InvalidParameter`] if `n_estimators` is 0.
746    fn fit(
747        &self,
748        x: &Array2<F>,
749        y: &Array1<F>,
750    ) -> Result<FittedExtraTreesRegressor<F>, FerroError> {
751        let (n_samples, n_features) = x.dim();
752
753        if n_samples != y.len() {
754            return Err(FerroError::ShapeMismatch {
755                expected: vec![n_samples],
756                actual: vec![y.len()],
757                context: "y length must match number of samples in X".into(),
758            });
759        }
760        if n_samples == 0 {
761            return Err(FerroError::InsufficientSamples {
762                required: 1,
763                actual: 0,
764                context: "ExtraTreesRegressor requires at least one sample".into(),
765            });
766        }
767        if self.n_estimators == 0 {
768            return Err(FerroError::InvalidParameter {
769                name: "n_estimators".into(),
770                reason: "must be at least 1".into(),
771            });
772        }
773
774        let max_features_n = resolve_max_features(self.max_features, n_features);
775        let params = make_tree_params(
776            self.max_depth,
777            self.min_samples_split,
778            self.min_samples_leaf,
779        );
780        let bootstrap = self.bootstrap;
781
782        // Generate per-tree seeds sequentially.
783        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
784            let mut master_rng = StdRng::seed_from_u64(seed);
785            (0..self.n_estimators)
786                .map(|_| {
787                    use rand::RngCore;
788                    master_rng.next_u64()
789                })
790                .collect()
791        } else {
792            (0..self.n_estimators)
793                .map(|_| {
794                    use rand::RngCore;
795                    rand::rng().next_u64()
796                })
797                .collect()
798        };
799
800        // Build trees in parallel.
801        let trees: Vec<Vec<Node<F>>> = if let Some(n_jobs) = self.n_jobs {
802            let pool = rayon::ThreadPoolBuilder::new()
803                .num_threads(n_jobs)
804                .build()
805                .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
806            pool.install(|| {
807                tree_seeds
808                    .par_iter()
809                    .map(|&seed| {
810                        build_single_regression_tree(
811                            x,
812                            y,
813                            n_samples,
814                            n_features,
815                            max_features_n,
816                            &params,
817                            bootstrap,
818                            seed,
819                        )
820                    })
821                    .collect()
822            })
823        } else {
824            tree_seeds
825                .par_iter()
826                .map(|&seed| {
827                    build_single_regression_tree(
828                        x,
829                        y,
830                        n_samples,
831                        n_features,
832                        max_features_n,
833                        &params,
834                        bootstrap,
835                        seed,
836                    )
837                })
838                .collect()
839        };
840
841        // Aggregate feature importances.
842        let mut total_importances = Array1::<F>::zeros(n_features);
843        for tree_nodes in &trees {
844            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
845            total_importances = total_importances + tree_imp;
846        }
847        let imp_sum: F = total_importances
848            .iter()
849            .copied()
850            .fold(F::zero(), |a, b| a + b);
851        if imp_sum > F::zero() {
852            total_importances.mapv_inplace(|v| v / imp_sum);
853        }
854
855        Ok(FittedExtraTreesRegressor {
856            trees,
857            n_features,
858            feature_importances: total_importances,
859        })
860    }
861}
862
863/// Build a single regression extra-tree (used by parallel dispatch).
864#[allow(clippy::too_many_arguments)]
865fn build_single_regression_tree<F: Float>(
866    x: &Array2<F>,
867    y: &Array1<F>,
868    n_samples: usize,
869    n_features: usize,
870    max_features_n: usize,
871    params: &TreeParams,
872    bootstrap: bool,
873    seed: u64,
874) -> Vec<Node<F>> {
875    let mut rng = StdRng::seed_from_u64(seed);
876
877    let indices: Vec<usize> = if bootstrap {
878        use rand::RngCore;
879        (0..n_samples)
880            .map(|_| (rng.next_u64() as usize) % n_samples)
881            .collect()
882    } else {
883        (0..n_samples).collect()
884    };
885
886    build_extra_regression_tree_for_ensemble(
887        x,
888        y,
889        &indices,
890        None, // feature selection happens inside the tree builder
891        params,
892        n_features,
893        max_features_n,
894        &mut rng,
895    )
896}
897
898impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreesRegressor<F> {
899    type Output = Array1<F>;
900    type Error = FerroError;
901
902    /// Predict target values by averaging across all trees.
903    ///
904    /// # Errors
905    ///
906    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
907    /// not match the fitted model.
908    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
909        if x.ncols() != self.n_features {
910            return Err(FerroError::ShapeMismatch {
911                expected: vec![self.n_features],
912                actual: vec![x.ncols()],
913                context: "number of features must match fitted model".into(),
914            });
915        }
916
917        let n_samples = x.nrows();
918        let n_trees_f = F::from(self.trees.len()).unwrap();
919        let mut predictions = Array1::zeros(n_samples);
920
921        for i in 0..n_samples {
922            let row = x.row(i);
923            let mut sum = F::zero();
924
925            for tree_nodes in &self.trees {
926                let leaf_idx = traverse(tree_nodes, &row);
927                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
928                    sum = sum + value;
929                }
930            }
931
932            predictions[i] = sum / n_trees_f;
933        }
934
935        Ok(predictions)
936    }
937}
938
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreesRegressor<F> {
    /// Impurity-based feature importances aggregated over all trees during
    /// `fit` (normalized to sum to 1 when the total importance is non-zero).
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
944
945// Pipeline integration.
946impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for ExtraTreesRegressor<F> {
947    fn fit_pipeline(
948        &self,
949        x: &Array2<F>,
950        y: &Array1<F>,
951    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
952        let fitted = self.fit(x, y)?;
953        Ok(Box::new(fitted))
954    }
955}
956
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedExtraTreesRegressor<F> {
    /// Delegate pipeline prediction to [`Predict::predict`].
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
962
963// ---------------------------------------------------------------------------
964// Tests
965// ---------------------------------------------------------------------------
966
#[cfg(test)]
mod tests {
    //! Unit tests for the ExtraTrees classifier/regressor ensembles,
    //! covering fitting, prediction, probabilities, feature importances,
    //! error paths, determinism, and the builder/default configurations.

    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // -- ExtraTreesClassifier tests --

    // A linearly separable two-class dataset should be fit perfectly.
    #[test]
    fn test_ensemble_classifier_simple() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // Should classify all training points correctly.
        assert_eq!(preds, y);
    }

    // Without bootstrap (the ExtraTrees default) every tree sees all samples.
    #[test]
    fn test_ensemble_classifier_no_bootstrap() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        // Default: no bootstrap.
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        assert!(!model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds, y);
    }

    // Opting into bootstrap still produces predictions of the right length
    // (exact values are not asserted since resampling changes them).
    #[test]
    fn test_ensemble_classifier_with_bootstrap() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_bootstrap(true)
            .with_random_state(42);
        assert!(model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    // Class probability rows must each sum to 1.
    #[test]
    fn test_ensemble_classifier_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        assert_eq!(proba.dim(), (6, 2));
        for i in 0..6 {
            let row_sum = proba.row(i).sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    // Importances are normalized and the informative feature dominates.
    #[test]
    fn test_ensemble_classifier_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0, 1.0, 6.0, 1.0, 7.0, 1.0, 8.0, 1.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        // Feature 0 should dominate (feature 1 is constant).
        assert!(importances[0] > importances[1]);
    }

    // The fitted model reports the configured number of estimators.
    #[test]
    fn test_ensemble_classifier_n_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(15)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 15);
    }

    // Non-contiguous class labels are preserved as-is in `classes()`.
    #[test]
    fn test_ensemble_classifier_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 3, 3, 3]; // non-contiguous

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 3]);
        assert_eq!(fitted.n_classes(), 2);
    }

    // Mismatched x/y lengths must be rejected at fit time.
    #[test]
    fn test_ensemble_classifier_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0];
        let model = ExtraTreesClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // Fitting on zero samples is an error.
    #[test]
    fn test_ensemble_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);
        let model = ExtraTreesClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // An ensemble with zero trees is an invalid configuration.
    #[test]
    fn test_ensemble_classifier_zero_estimators() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = array![0, 1];
        let model = ExtraTreesClassifier::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    // Identical random states must yield identical predictions.
    #[test]
    fn test_ensemble_classifier_deterministic() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model1 = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);
        let model2 = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);

        let preds1 = model1.fit(&x, &y).unwrap().predict(&x).unwrap();
        let preds2 = model2.fit(&x, &y).unwrap().predict(&x).unwrap();
        assert_eq!(preds1, preds2);
    }

    // Predicting with a different feature count than fitted must error.
    #[test]
    fn test_ensemble_classifier_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    // -- ExtraTreesRegressor tests --

    // The ensemble should roughly reproduce a simple linear target.
    #[test]
    fn test_ensemble_regressor_simple() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        // Ensemble should approximate the training data well.
        for i in 0..6 {
            assert_relative_eq!(preds[i], y[i], epsilon = 1.0);
        }
    }

    // A constant target must be predicted exactly (every leaf holds it).
    #[test]
    fn test_ensemble_regressor_constant_target() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![5.0, 5.0, 5.0, 5.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for &p in preds.iter() {
            assert_relative_eq!(p, 5.0, epsilon = 1e-10);
        }
    }

    // The regressor also defaults to no bootstrap.
    #[test]
    fn test_ensemble_regressor_no_bootstrap() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        assert!(!model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // Bootstrap resampling still yields one prediction per input row.
    #[test]
    fn test_ensemble_regressor_with_bootstrap() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_bootstrap(true)
            .with_random_state(42);
        assert!(model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    // Regressor importances are normalized and the informative feature wins.
    #[test]
    fn test_ensemble_regressor_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        assert!(importances[0] > importances[1]);
    }

    // The fitted regressor reports the configured number of estimators.
    #[test]
    fn test_ensemble_regressor_n_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(7)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 7);
    }

    // Mismatched x/y lengths must be rejected at fit time.
    #[test]
    fn test_ensemble_regressor_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0];
        let model = ExtraTreesRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // Fitting on zero samples is an error.
    #[test]
    fn test_ensemble_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let model = ExtraTreesRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // An ensemble with zero trees is an invalid configuration.
    #[test]
    fn test_ensemble_regressor_zero_estimators() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = array![1.0, 2.0];
        let model = ExtraTreesRegressor::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    // Identical random states must yield (numerically) identical predictions.
    #[test]
    fn test_ensemble_regressor_deterministic() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model1 = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);
        let model2 = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);

        let preds1 = model1.fit(&x, &y).unwrap().predict(&x).unwrap();
        let preds2 = model2.fit(&x, &y).unwrap().predict(&x).unwrap();

        for i in 0..6 {
            assert_relative_eq!(preds1[i], preds2[i], epsilon = 1e-12);
        }
    }

    // Predicting with a different feature count than fitted must error.
    #[test]
    fn test_ensemble_regressor_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    // -- Builder tests --

    // Every classifier builder setter stores its value.
    #[test]
    fn test_ensemble_classifier_builder() {
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_max_depth(Some(5))
            .with_min_samples_split(10)
            .with_min_samples_leaf(3)
            .with_max_features(MaxFeatures::Log2)
            .with_bootstrap(true)
            .with_criterion(ClassificationCriterion::Entropy)
            .with_random_state(42)
            .with_n_jobs(4);

        assert_eq!(model.n_estimators, 50);
        assert_eq!(model.max_depth, Some(5));
        assert_eq!(model.min_samples_split, 10);
        assert_eq!(model.min_samples_leaf, 3);
        assert_eq!(model.max_features, MaxFeatures::Log2);
        assert!(model.bootstrap);
        assert_eq!(model.criterion, ClassificationCriterion::Entropy);
        assert_eq!(model.random_state, Some(42));
        assert_eq!(model.n_jobs, Some(4));
    }

    // Every regressor builder setter stores its value.
    #[test]
    fn test_ensemble_regressor_builder() {
        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(25)
            .with_max_depth(Some(8))
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_max_features(MaxFeatures::Fraction(0.5))
            .with_bootstrap(true)
            .with_random_state(99)
            .with_n_jobs(2);

        assert_eq!(model.n_estimators, 25);
        assert_eq!(model.max_depth, Some(8));
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.min_samples_leaf, 2);
        assert_eq!(model.max_features, MaxFeatures::Fraction(0.5));
        assert!(model.bootstrap);
        assert_eq!(model.random_state, Some(99));
        assert_eq!(model.n_jobs, Some(2));
    }

    // Classifier defaults: 100 trees, Sqrt features, Gini, no bootstrap.
    #[test]
    fn test_ensemble_classifier_default() {
        let model = ExtraTreesClassifier::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::Sqrt);
        assert!(!model.bootstrap);
        assert_eq!(model.criterion, ClassificationCriterion::Gini);
        assert_eq!(model.random_state, None);
        assert_eq!(model.n_jobs, None);
    }

    // Regressor defaults: 100 trees, All features, no bootstrap.
    #[test]
    fn test_ensemble_regressor_default() {
        let model = ExtraTreesRegressor::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::All);
        assert!(!model.bootstrap);
        assert_eq!(model.random_state, None);
        assert_eq!(model.n_jobs, None);
    }
}