Skip to main content

ferrolearn_tree/
extra_trees_ensemble.rs

1//! Extremely randomized trees ensemble classifiers and regressors.
2//!
3//! This module provides [`ExtraTreesClassifier`] and [`ExtraTreesRegressor`],
4//! which build ensembles of extremely randomized trees. Unlike
5//! [`RandomForestClassifier`](crate::RandomForestClassifier), ExtraTrees
6//! ensembles do **not** bootstrap by default: all trees see all samples, and
7//! randomness comes solely from the random split thresholds and random feature
8//! subsets at each node.
9//!
10//! # Examples
11//!
12//! ```
13//! use ferrolearn_tree::ExtraTreesClassifier;
14//! use ferrolearn_core::{Fit, Predict};
15//! use ndarray::{array, Array1, Array2};
16//!
17//! let x = Array2::from_shape_vec((8, 2), vec![
18//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,  4.0, 4.0,
19//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,  8.0, 9.0,
20//! ]).unwrap();
21//! let y = array![0, 0, 0, 0, 1, 1, 1, 1];
22//!
23//! let model = ExtraTreesClassifier::<f64>::new()
24//!     .with_n_estimators(10)
25//!     .with_random_state(42);
26//! let fitted = model.fit(&x, &y).unwrap();
27//! let preds = fitted.predict(&x).unwrap();
28//! ```
29
30use ferrolearn_core::error::FerroError;
31use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
32use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
33use ferrolearn_core::traits::{Fit, Predict};
34use ndarray::{Array1, Array2};
35use num_traits::{Float, FromPrimitive, ToPrimitive};
36use rand::SeedableRng;
37use rand::rngs::StdRng;
38use rayon::prelude::*;
39use serde::{Deserialize, Serialize};
40
41use crate::decision_tree::{
42    ClassificationCriterion, Node, TreeParams, compute_feature_importances, traverse,
43};
44use crate::extra_tree::{
45    build_extra_classification_tree_for_ensemble, build_extra_regression_tree_for_ensemble,
46};
47use crate::random_forest::MaxFeatures;
48
49/// Resolve the `MaxFeatures` strategy to a concrete number.
50fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
51    let result = match strategy {
52        MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
53        MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
54        MaxFeatures::All => n_features,
55        MaxFeatures::Fixed(n) => n.min(n_features),
56        MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
57    };
58    result.max(1).min(n_features)
59}
60
61/// Internal tree parameter struct helper.
62fn make_tree_params(
63    max_depth: Option<usize>,
64    min_samples_split: usize,
65    min_samples_leaf: usize,
66) -> TreeParams {
67    TreeParams {
68        max_depth,
69        min_samples_split,
70        min_samples_leaf,
71    }
72}
73
74// ---------------------------------------------------------------------------
75// ExtraTreesClassifier
76// ---------------------------------------------------------------------------
77
/// Extremely randomized trees classifier (ensemble).
///
/// Builds an ensemble of [`ExtraTreeClassifier`](crate::ExtraTreeClassifier)
/// base estimators, each using random split thresholds and random feature
/// subsets at every node. Final predictions are made by majority vote.
///
/// Unlike [`RandomForestClassifier`](crate::RandomForestClassifier), bootstrap
/// sampling is **disabled** by default. Randomness comes from the random
/// thresholds and random feature subsets at each split.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreesClassifier<F> {
    /// Number of trees in the ensemble.
    pub n_estimators: usize,
    /// Maximum depth of each tree. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Strategy for the number of features considered at each split.
    pub max_features: MaxFeatures,
    /// Whether to use bootstrap sampling. Default is `false`.
    pub bootstrap: bool,
    /// Splitting criterion.
    pub criterion: ClassificationCriterion,
    /// Random seed for reproducibility. `None` means non-deterministic.
    pub random_state: Option<u64>,
    /// Number of parallel jobs. `None` means use all available cores.
    pub n_jobs: Option<usize>,
    // Zero-sized marker tying the float type `F` to the struct without
    // storing a value of it (the hyperparameters themselves are not `F`).
    _marker: std::marker::PhantomData<F>,
}
113
114impl<F: Float> ExtraTreesClassifier<F> {
115    /// Create a new `ExtraTreesClassifier` with default settings.
116    ///
117    /// Defaults: `n_estimators = 100`, `max_depth = None`,
118    /// `max_features = Sqrt`, `min_samples_split = 2`,
119    /// `min_samples_leaf = 1`, `bootstrap = false`,
120    /// `criterion = Gini`, `random_state = None`, `n_jobs = None`.
121    #[must_use]
122    pub fn new() -> Self {
123        Self {
124            n_estimators: 100,
125            max_depth: None,
126            min_samples_split: 2,
127            min_samples_leaf: 1,
128            max_features: MaxFeatures::Sqrt,
129            bootstrap: false,
130            criterion: ClassificationCriterion::Gini,
131            random_state: None,
132            n_jobs: None,
133            _marker: std::marker::PhantomData,
134        }
135    }
136
137    /// Set the number of trees.
138    #[must_use]
139    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
140        self.n_estimators = n_estimators;
141        self
142    }
143
144    /// Set the maximum tree depth.
145    #[must_use]
146    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
147        self.max_depth = max_depth;
148        self
149    }
150
151    /// Set the minimum number of samples to split a node.
152    #[must_use]
153    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
154        self.min_samples_split = min_samples_split;
155        self
156    }
157
158    /// Set the minimum number of samples in a leaf.
159    #[must_use]
160    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
161        self.min_samples_leaf = min_samples_leaf;
162        self
163    }
164
165    /// Set the maximum features strategy.
166    #[must_use]
167    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
168        self.max_features = max_features;
169        self
170    }
171
172    /// Set whether to use bootstrap sampling.
173    #[must_use]
174    pub fn with_bootstrap(mut self, bootstrap: bool) -> Self {
175        self.bootstrap = bootstrap;
176        self
177    }
178
179    /// Set the splitting criterion.
180    #[must_use]
181    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
182        self.criterion = criterion;
183        self
184    }
185
186    /// Set the random seed for reproducibility.
187    #[must_use]
188    pub fn with_random_state(mut self, seed: u64) -> Self {
189        self.random_state = Some(seed);
190        self
191    }
192
193    /// Set the number of parallel jobs.
194    #[must_use]
195    pub fn with_n_jobs(mut self, n_jobs: usize) -> Self {
196        self.n_jobs = Some(n_jobs);
197        self
198    }
199}
200
201impl<F: Float> Default for ExtraTreesClassifier<F> {
202    fn default() -> Self {
203        Self::new()
204    }
205}
206
207// ---------------------------------------------------------------------------
208// FittedExtraTreesClassifier
209// ---------------------------------------------------------------------------
210
/// A fitted extremely randomized trees classifier (ensemble).
///
/// Stores the ensemble of fitted extra-trees and aggregates their
/// predictions by majority vote.
#[derive(Debug, Clone)]
pub struct FittedExtraTreesClassifier<F> {
    /// Flattened node arrays, one `Vec<Node<F>>` per tree.
    trees: Vec<Vec<Node<F>>>,
    /// Sorted unique class labels as observed in the training targets.
    classes: Vec<usize>,
    /// Number of features the model was fitted on; inputs to `predict`
    /// and `predict_proba` must have this many columns.
    n_features: usize,
    /// Per-feature importance scores (mean decrease in impurity),
    /// normalised to sum to 1 (all zeros if no split reduced impurity).
    feature_importances: Array1<F>,
}
226
227impl<F: Float + Send + Sync + 'static> FittedExtraTreesClassifier<F> {
228    /// Returns a reference to the individual tree node vectors.
229    #[must_use]
230    pub fn trees(&self) -> &[Vec<Node<F>>] {
231        &self.trees
232    }
233
234    /// Returns the number of features the model was trained on.
235    #[must_use]
236    pub fn n_features(&self) -> usize {
237        self.n_features
238    }
239
240    /// Returns the number of trees in the ensemble.
241    #[must_use]
242    pub fn n_estimators(&self) -> usize {
243        self.trees.len()
244    }
245
246    /// Predict class probabilities for each sample by averaging tree predictions.
247    ///
248    /// Returns a 2-D array of shape `(n_samples, n_classes)`.
249    ///
250    /// # Errors
251    ///
252    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
253    /// not match the training data.
254    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
255        if x.ncols() != self.n_features {
256            return Err(FerroError::ShapeMismatch {
257                expected: vec![self.n_features],
258                actual: vec![x.ncols()],
259                context: "number of features must match fitted model".into(),
260            });
261        }
262        let n_samples = x.nrows();
263        let n_classes = self.classes.len();
264        let n_trees_f = F::from(self.trees.len()).unwrap();
265        let mut proba = Array2::zeros((n_samples, n_classes));
266
267        for i in 0..n_samples {
268            let row = x.row(i);
269            for tree_nodes in &self.trees {
270                let leaf_idx = traverse(tree_nodes, &row);
271                if let Node::Leaf {
272                    class_distribution: Some(ref dist),
273                    ..
274                } = tree_nodes[leaf_idx]
275                {
276                    for (j, &p) in dist.iter().enumerate() {
277                        proba[[i, j]] = proba[[i, j]] + p;
278                    }
279                }
280            }
281            // Average across trees.
282            for j in 0..n_classes {
283                proba[[i, j]] = proba[[i, j]] / n_trees_f;
284            }
285        }
286        Ok(proba)
287    }
288}
289
290impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ExtraTreesClassifier<F> {
291    type Fitted = FittedExtraTreesClassifier<F>;
292    type Error = FerroError;
293
294    /// Fit the ensemble by building `n_estimators` extra-trees in parallel.
295    ///
296    /// Each tree uses random split thresholds and random feature subsets at
297    /// every node. If `bootstrap` is `true`, each tree is trained on a
298    /// bootstrap sample; otherwise all samples are used.
299    ///
300    /// # Errors
301    ///
302    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
303    /// numbers of samples.
304    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
305    /// Returns [`FerroError::InvalidParameter`] if `n_estimators` is 0.
306    fn fit(
307        &self,
308        x: &Array2<F>,
309        y: &Array1<usize>,
310    ) -> Result<FittedExtraTreesClassifier<F>, FerroError> {
311        let (n_samples, n_features) = x.dim();
312
313        if n_samples != y.len() {
314            return Err(FerroError::ShapeMismatch {
315                expected: vec![n_samples],
316                actual: vec![y.len()],
317                context: "y length must match number of samples in X".into(),
318            });
319        }
320        if n_samples == 0 {
321            return Err(FerroError::InsufficientSamples {
322                required: 1,
323                actual: 0,
324                context: "ExtraTreesClassifier requires at least one sample".into(),
325            });
326        }
327        if self.n_estimators == 0 {
328            return Err(FerroError::InvalidParameter {
329                name: "n_estimators".into(),
330                reason: "must be at least 1".into(),
331            });
332        }
333
334        // Determine unique classes.
335        let mut classes: Vec<usize> = y.iter().copied().collect();
336        classes.sort_unstable();
337        classes.dedup();
338        let n_classes = classes.len();
339
340        let y_mapped: Vec<usize> = y
341            .iter()
342            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
343            .collect();
344
345        let max_features_n = resolve_max_features(self.max_features, n_features);
346        let params = make_tree_params(
347            self.max_depth,
348            self.min_samples_split,
349            self.min_samples_leaf,
350        );
351        let criterion = self.criterion;
352        let bootstrap = self.bootstrap;
353
354        // Generate per-tree seeds sequentially for determinism.
355        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
356            let mut master_rng = StdRng::seed_from_u64(seed);
357            (0..self.n_estimators)
358                .map(|_| {
359                    use rand::RngCore;
360                    master_rng.next_u64()
361                })
362                .collect()
363        } else {
364            (0..self.n_estimators)
365                .map(|_| {
366                    use rand::RngCore;
367                    rand::rng().next_u64()
368                })
369                .collect()
370        };
371
372        // Optionally configure thread pool.
373        let trees: Vec<Vec<Node<F>>> = if let Some(n_jobs) = self.n_jobs {
374            let pool = rayon::ThreadPoolBuilder::new()
375                .num_threads(n_jobs)
376                .build()
377                .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
378            pool.install(|| {
379                tree_seeds
380                    .par_iter()
381                    .map(|&seed| {
382                        build_single_classification_tree(
383                            x,
384                            &y_mapped,
385                            n_classes,
386                            n_samples,
387                            n_features,
388                            max_features_n,
389                            &params,
390                            criterion,
391                            bootstrap,
392                            seed,
393                        )
394                    })
395                    .collect()
396            })
397        } else {
398            tree_seeds
399                .par_iter()
400                .map(|&seed| {
401                    build_single_classification_tree(
402                        x,
403                        &y_mapped,
404                        n_classes,
405                        n_samples,
406                        n_features,
407                        max_features_n,
408                        &params,
409                        criterion,
410                        bootstrap,
411                        seed,
412                    )
413                })
414                .collect()
415        };
416
417        // Aggregate feature importances across trees.
418        let mut total_importances = Array1::<F>::zeros(n_features);
419        for tree_nodes in &trees {
420            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
421            total_importances = total_importances + tree_imp;
422        }
423        let imp_sum: F = total_importances
424            .iter()
425            .copied()
426            .fold(F::zero(), |a, b| a + b);
427        if imp_sum > F::zero() {
428            total_importances.mapv_inplace(|v| v / imp_sum);
429        }
430
431        Ok(FittedExtraTreesClassifier {
432            trees,
433            classes,
434            n_features,
435            feature_importances: total_importances,
436        })
437    }
438}
439
/// Build a single classification extra-tree (used by parallel dispatch).
///
/// `seed` deterministically initialises this tree's private RNG. The RNG is
/// consumed first by the optional bootstrap draw and then handed to the tree
/// builder, so the draw order here is part of the reproducibility contract —
/// do not reorder these statements.
#[allow(clippy::too_many_arguments)]
fn build_single_classification_tree<F: Float>(
    x: &Array2<F>,
    y_mapped: &[usize],
    n_classes: usize,
    n_samples: usize,
    n_features: usize,
    max_features_n: usize,
    params: &TreeParams,
    criterion: ClassificationCriterion,
    bootstrap: bool,
    seed: u64,
) -> Vec<Node<F>> {
    let mut rng = StdRng::seed_from_u64(seed);

    // Bootstrap: draw `n_samples` row indices with replacement.
    // NOTE(review): `next_u64() % n_samples` carries a slight modulo bias;
    // negligible for realistic sample counts, but not exactly uniform.
    let indices: Vec<usize> = if bootstrap {
        use rand::RngCore;
        (0..n_samples)
            .map(|_| (rng.next_u64() as usize) % n_samples)
            .collect()
    } else {
        // No bootstrap (the ExtraTrees default): every tree sees all rows.
        (0..n_samples).collect()
    };

    build_extra_classification_tree_for_ensemble(
        x,
        y_mapped,
        n_classes,
        &indices,
        None, // feature selection happens inside the tree builder
        params,
        criterion,
        n_features,
        max_features_n,
        &mut rng,
    )
}
478
479impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreesClassifier<F> {
480    type Output = Array1<usize>;
481    type Error = FerroError;
482
483    /// Predict class labels by majority vote across all trees.
484    ///
485    /// # Errors
486    ///
487    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
488    /// not match the fitted model.
489    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
490        if x.ncols() != self.n_features {
491            return Err(FerroError::ShapeMismatch {
492                expected: vec![self.n_features],
493                actual: vec![x.ncols()],
494                context: "number of features must match fitted model".into(),
495            });
496        }
497
498        let n_samples = x.nrows();
499        let n_classes = self.classes.len();
500        let mut predictions = Array1::zeros(n_samples);
501
502        for i in 0..n_samples {
503            let row = x.row(i);
504            let mut votes = vec![0usize; n_classes];
505
506            for tree_nodes in &self.trees {
507                let leaf_idx = traverse(tree_nodes, &row);
508                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
509                    let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
510                    if class_idx < n_classes {
511                        votes[class_idx] += 1;
512                    }
513                }
514            }
515
516            let winner = votes
517                .iter()
518                .enumerate()
519                .max_by_key(|&(_, &count)| count)
520                .map_or(0, |(idx, _)| idx);
521            predictions[i] = self.classes[winner];
522        }
523
524        Ok(predictions)
525    }
526}
527
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreesClassifier<F> {
    /// Impurity-based importances computed during `fit`, normalised to sum
    /// to 1 (all zeros when no split reduced impurity).
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
533
impl<F: Float + Send + Sync + 'static> HasClasses for FittedExtraTreesClassifier<F> {
    /// Sorted unique class labels observed during fitting.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
543
544// Pipeline integration.
545impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
546    for ExtraTreesClassifier<F>
547{
548    fn fit_pipeline(
549        &self,
550        x: &Array2<F>,
551        y: &Array1<F>,
552    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
553        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
554        let fitted = self.fit(x, &y_usize)?;
555        Ok(Box::new(FittedExtraTreesClassifierPipelineAdapter(fitted)))
556    }
557}
558
/// Pipeline adapter for `FittedExtraTreesClassifier<F>`.
///
/// Newtype wrapper exposing the fitted classifier through the type-erased
/// `FittedPipelineEstimator<F>` pipeline interface.
struct FittedExtraTreesClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedExtraTreesClassifier<F>,
);
563
564impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
565    for FittedExtraTreesClassifierPipelineAdapter<F>
566{
567    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
568        let preds = self.0.predict(x)?;
569        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
570    }
571}
572
573// ---------------------------------------------------------------------------
574// ExtraTreesRegressor
575// ---------------------------------------------------------------------------
576
/// Extremely randomized trees regressor (ensemble).
///
/// Builds an ensemble of [`ExtraTreeRegressor`](crate::ExtraTreeRegressor)
/// base estimators, each using random split thresholds and random feature
/// subsets at every node. Final predictions are the mean across all trees.
///
/// Unlike [`RandomForestRegressor`](crate::RandomForestRegressor), bootstrap
/// sampling is **disabled** by default.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreesRegressor<F> {
    /// Number of trees in the ensemble.
    pub n_estimators: usize,
    /// Maximum depth of each tree. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Strategy for the number of features considered at each split.
    /// Note the regressor defaults to `All`, unlike the classifier's `Sqrt`.
    pub max_features: MaxFeatures,
    /// Whether to use bootstrap sampling. Default is `false`.
    pub bootstrap: bool,
    /// Random seed for reproducibility. `None` means non-deterministic.
    pub random_state: Option<u64>,
    /// Number of parallel jobs. `None` means use all available cores.
    pub n_jobs: Option<usize>,
    // Zero-sized marker tying the float type `F` to the struct without
    // storing a value of it.
    _marker: std::marker::PhantomData<F>,
}
609
610impl<F: Float> ExtraTreesRegressor<F> {
611    /// Create a new `ExtraTreesRegressor` with default settings.
612    ///
613    /// Defaults: `n_estimators = 100`, `max_depth = None`,
614    /// `max_features = All`, `min_samples_split = 2`,
615    /// `min_samples_leaf = 1`, `bootstrap = false`,
616    /// `random_state = None`, `n_jobs = None`.
617    #[must_use]
618    pub fn new() -> Self {
619        Self {
620            n_estimators: 100,
621            max_depth: None,
622            min_samples_split: 2,
623            min_samples_leaf: 1,
624            max_features: MaxFeatures::All,
625            bootstrap: false,
626            random_state: None,
627            n_jobs: None,
628            _marker: std::marker::PhantomData,
629        }
630    }
631
632    /// Set the number of trees.
633    #[must_use]
634    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
635        self.n_estimators = n_estimators;
636        self
637    }
638
639    /// Set the maximum tree depth.
640    #[must_use]
641    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
642        self.max_depth = max_depth;
643        self
644    }
645
646    /// Set the minimum number of samples to split a node.
647    #[must_use]
648    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
649        self.min_samples_split = min_samples_split;
650        self
651    }
652
653    /// Set the minimum number of samples in a leaf.
654    #[must_use]
655    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
656        self.min_samples_leaf = min_samples_leaf;
657        self
658    }
659
660    /// Set the maximum features strategy.
661    #[must_use]
662    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
663        self.max_features = max_features;
664        self
665    }
666
667    /// Set whether to use bootstrap sampling.
668    #[must_use]
669    pub fn with_bootstrap(mut self, bootstrap: bool) -> Self {
670        self.bootstrap = bootstrap;
671        self
672    }
673
674    /// Set the random seed for reproducibility.
675    #[must_use]
676    pub fn with_random_state(mut self, seed: u64) -> Self {
677        self.random_state = Some(seed);
678        self
679    }
680
681    /// Set the number of parallel jobs.
682    #[must_use]
683    pub fn with_n_jobs(mut self, n_jobs: usize) -> Self {
684        self.n_jobs = Some(n_jobs);
685        self
686    }
687}
688
689impl<F: Float> Default for ExtraTreesRegressor<F> {
690    fn default() -> Self {
691        Self::new()
692    }
693}
694
695// ---------------------------------------------------------------------------
696// FittedExtraTreesRegressor
697// ---------------------------------------------------------------------------
698
/// A fitted extremely randomized trees regressor (ensemble).
///
/// Stores the ensemble of fitted extra-trees and aggregates their
/// predictions by averaging.
#[derive(Debug, Clone)]
pub struct FittedExtraTreesRegressor<F> {
    /// Flattened node arrays, one `Vec<Node<F>>` per tree.
    trees: Vec<Vec<Node<F>>>,
    /// Number of features the model was fitted on; inputs to `predict`
    /// must have this many columns.
    n_features: usize,
    /// Per-feature importance scores (mean decrease in impurity),
    /// normalised to sum to 1 (all zeros if no split reduced impurity).
    feature_importances: Array1<F>,
}
712
713impl<F: Float + Send + Sync + 'static> FittedExtraTreesRegressor<F> {
714    /// Returns a reference to the individual tree node vectors.
715    #[must_use]
716    pub fn trees(&self) -> &[Vec<Node<F>>] {
717        &self.trees
718    }
719
720    /// Returns the number of features the model was trained on.
721    #[must_use]
722    pub fn n_features(&self) -> usize {
723        self.n_features
724    }
725
726    /// Returns the number of trees in the ensemble.
727    #[must_use]
728    pub fn n_estimators(&self) -> usize {
729        self.trees.len()
730    }
731}
732
733impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for ExtraTreesRegressor<F> {
734    type Fitted = FittedExtraTreesRegressor<F>;
735    type Error = FerroError;
736
737    /// Fit the ensemble by building `n_estimators` extra-trees in parallel.
738    ///
739    /// # Errors
740    ///
741    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
742    /// numbers of samples.
743    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
744    /// Returns [`FerroError::InvalidParameter`] if `n_estimators` is 0.
745    fn fit(
746        &self,
747        x: &Array2<F>,
748        y: &Array1<F>,
749    ) -> Result<FittedExtraTreesRegressor<F>, FerroError> {
750        let (n_samples, n_features) = x.dim();
751
752        if n_samples != y.len() {
753            return Err(FerroError::ShapeMismatch {
754                expected: vec![n_samples],
755                actual: vec![y.len()],
756                context: "y length must match number of samples in X".into(),
757            });
758        }
759        if n_samples == 0 {
760            return Err(FerroError::InsufficientSamples {
761                required: 1,
762                actual: 0,
763                context: "ExtraTreesRegressor requires at least one sample".into(),
764            });
765        }
766        if self.n_estimators == 0 {
767            return Err(FerroError::InvalidParameter {
768                name: "n_estimators".into(),
769                reason: "must be at least 1".into(),
770            });
771        }
772
773        let max_features_n = resolve_max_features(self.max_features, n_features);
774        let params = make_tree_params(
775            self.max_depth,
776            self.min_samples_split,
777            self.min_samples_leaf,
778        );
779        let bootstrap = self.bootstrap;
780
781        // Generate per-tree seeds sequentially.
782        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
783            let mut master_rng = StdRng::seed_from_u64(seed);
784            (0..self.n_estimators)
785                .map(|_| {
786                    use rand::RngCore;
787                    master_rng.next_u64()
788                })
789                .collect()
790        } else {
791            (0..self.n_estimators)
792                .map(|_| {
793                    use rand::RngCore;
794                    rand::rng().next_u64()
795                })
796                .collect()
797        };
798
799        // Build trees in parallel.
800        let trees: Vec<Vec<Node<F>>> = if let Some(n_jobs) = self.n_jobs {
801            let pool = rayon::ThreadPoolBuilder::new()
802                .num_threads(n_jobs)
803                .build()
804                .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
805            pool.install(|| {
806                tree_seeds
807                    .par_iter()
808                    .map(|&seed| {
809                        build_single_regression_tree(
810                            x,
811                            y,
812                            n_samples,
813                            n_features,
814                            max_features_n,
815                            &params,
816                            bootstrap,
817                            seed,
818                        )
819                    })
820                    .collect()
821            })
822        } else {
823            tree_seeds
824                .par_iter()
825                .map(|&seed| {
826                    build_single_regression_tree(
827                        x,
828                        y,
829                        n_samples,
830                        n_features,
831                        max_features_n,
832                        &params,
833                        bootstrap,
834                        seed,
835                    )
836                })
837                .collect()
838        };
839
840        // Aggregate feature importances.
841        let mut total_importances = Array1::<F>::zeros(n_features);
842        for tree_nodes in &trees {
843            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
844            total_importances = total_importances + tree_imp;
845        }
846        let imp_sum: F = total_importances
847            .iter()
848            .copied()
849            .fold(F::zero(), |a, b| a + b);
850        if imp_sum > F::zero() {
851            total_importances.mapv_inplace(|v| v / imp_sum);
852        }
853
854        Ok(FittedExtraTreesRegressor {
855            trees,
856            n_features,
857            feature_importances: total_importances,
858        })
859    }
860}
861
/// Build a single regression extra-tree (used by parallel dispatch).
///
/// `seed` deterministically initialises this tree's private RNG. The RNG is
/// consumed first by the optional bootstrap draw and then handed to the tree
/// builder, so the draw order here is part of the reproducibility contract —
/// do not reorder these statements.
#[allow(clippy::too_many_arguments)]
fn build_single_regression_tree<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    n_samples: usize,
    n_features: usize,
    max_features_n: usize,
    params: &TreeParams,
    bootstrap: bool,
    seed: u64,
) -> Vec<Node<F>> {
    let mut rng = StdRng::seed_from_u64(seed);

    // Bootstrap: draw `n_samples` row indices with replacement.
    // NOTE(review): `next_u64() % n_samples` carries a slight modulo bias;
    // negligible for realistic sample counts, but not exactly uniform.
    let indices: Vec<usize> = if bootstrap {
        use rand::RngCore;
        (0..n_samples)
            .map(|_| (rng.next_u64() as usize) % n_samples)
            .collect()
    } else {
        // No bootstrap (the ExtraTrees default): every tree sees all rows.
        (0..n_samples).collect()
    };

    build_extra_regression_tree_for_ensemble(
        x,
        y,
        &indices,
        None, // feature selection happens inside the tree builder
        params,
        n_features,
        max_features_n,
        &mut rng,
    )
}
896
897impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreesRegressor<F> {
898    type Output = Array1<F>;
899    type Error = FerroError;
900
901    /// Predict target values by averaging across all trees.
902    ///
903    /// # Errors
904    ///
905    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
906    /// not match the fitted model.
907    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
908        if x.ncols() != self.n_features {
909            return Err(FerroError::ShapeMismatch {
910                expected: vec![self.n_features],
911                actual: vec![x.ncols()],
912                context: "number of features must match fitted model".into(),
913            });
914        }
915
916        let n_samples = x.nrows();
917        let n_trees_f = F::from(self.trees.len()).unwrap();
918        let mut predictions = Array1::zeros(n_samples);
919
920        for i in 0..n_samples {
921            let row = x.row(i);
922            let mut sum = F::zero();
923
924            for tree_nodes in &self.trees {
925                let leaf_idx = traverse(tree_nodes, &row);
926                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
927                    sum = sum + value;
928                }
929            }
930
931            predictions[i] = sum / n_trees_f;
932        }
933
934        Ok(predictions)
935    }
936}
937
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreesRegressor<F> {
    /// Per-feature importances aggregated across all trees at fit time,
    /// normalized to sum to 1 whenever any split contributed importance.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
943
944// Pipeline integration.
945impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for ExtraTreesRegressor<F> {
946    fn fit_pipeline(
947        &self,
948        x: &Array2<F>,
949        y: &Array1<F>,
950    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
951        let fitted = self.fit(x, y)?;
952        Ok(Box::new(fitted))
953    }
954}
955
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedExtraTreesRegressor<F> {
    /// Delegate to [`Predict::predict`]; see it for shape requirements.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
961
962// ---------------------------------------------------------------------------
963// Tests
964// ---------------------------------------------------------------------------
965
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // Covers: classifier/regressor happy paths, bootstrap on/off, probability
    // calibration, feature importances, error paths (shape mismatch, empty
    // data, zero estimators), determinism under a fixed seed, and builders.

    // -- ExtraTreesClassifier tests --

    // Two well-separated classes: the ensemble should fit them exactly.
    #[test]
    fn test_ensemble_classifier_simple() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // Should classify all training points correctly.
        assert_eq!(preds, y);
    }

    #[test]
    fn test_ensemble_classifier_no_bootstrap() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        // Default: no bootstrap.
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        assert!(!model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds, y);
    }

    // With bootstrap on, per-tree resampling makes exact training-set
    // recovery non-guaranteed, so only the output shape is asserted.
    #[test]
    fn test_ensemble_classifier_with_bootstrap() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_bootstrap(true)
            .with_random_state(42);
        assert!(model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    // Probabilities must form a valid distribution: one column per class,
    // each row summing to 1.
    #[test]
    fn test_ensemble_classifier_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();

        assert_eq!(proba.dim(), (6, 2));
        for i in 0..6 {
            let row_sum = proba.row(i).sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_ensemble_classifier_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0, 1.0, 6.0, 1.0, 7.0, 1.0, 8.0, 1.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        // Feature 0 should dominate (feature 1 is constant).
        assert!(importances[0] > importances[1]);
    }

    // The fitted model must report exactly as many trees as requested.
    #[test]
    fn test_ensemble_classifier_n_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(15)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 15);
    }

    // Class labels need not be contiguous; the model should preserve the
    // original labels rather than re-indexing them.
    #[test]
    fn test_ensemble_classifier_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 3, 3, 3]; // non-contiguous

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 3]);
        assert_eq!(fitted.n_classes(), 2);
    }

    // x has 3 rows but y has 2 labels: fit must error, not panic.
    #[test]
    fn test_ensemble_classifier_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0];
        let model = ExtraTreesClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);
        let model = ExtraTreesClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    // n_estimators == 0 is a configuration error, rejected at fit time.
    #[test]
    fn test_ensemble_classifier_zero_estimators() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = array![0, 1];
        let model = ExtraTreesClassifier::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    // Same data + same random_state must yield identical predictions, even
    // though trees are built in parallel.
    #[test]
    fn test_ensemble_classifier_deterministic() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model1 = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);
        let model2 = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);

        let preds1 = model1.fit(&x, &y).unwrap().predict(&x).unwrap();
        let preds2 = model2.fit(&x, &y).unwrap().predict(&x).unwrap();
        assert_eq!(preds1, preds2);
    }

    // Predicting on data with the wrong column count must error.
    #[test]
    fn test_ensemble_classifier_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    // -- ExtraTreesRegressor tests --

    // Identity mapping y = x: the ensemble average should land within 1.0 of
    // each training target (random thresholds preclude a tighter bound).
    #[test]
    fn test_ensemble_regressor_simple() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        // Ensemble should approximate the training data well.
        for i in 0..6 {
            assert_relative_eq!(preds[i], y[i], epsilon = 1.0);
        }
    }

    // A constant target must be recovered exactly: every leaf mean is 5.0,
    // so the tree average is too.
    #[test]
    fn test_ensemble_regressor_constant_target() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![5.0, 5.0, 5.0, 5.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for &p in &preds {
            assert_relative_eq!(p, 5.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_ensemble_regressor_no_bootstrap() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        assert!(!model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_ensemble_regressor_with_bootstrap() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_bootstrap(true)
            .with_random_state(42);
        assert!(model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    // Feature 1 is constant, so all importance should accrue to feature 0.
    #[test]
    fn test_ensemble_regressor_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_ensemble_regressor_n_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(7)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 7);
    }

    #[test]
    fn test_ensemble_regressor_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0];
        let model = ExtraTreesRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let model = ExtraTreesRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ensemble_regressor_zero_estimators() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = array![1.0, 2.0];
        let model = ExtraTreesRegressor::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    // Fixed seed ⇒ bit-for-bit reproducible fits and predictions.
    #[test]
    fn test_ensemble_regressor_deterministic() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model1 = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);
        let model2 = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);

        let preds1 = model1.fit(&x, &y).unwrap().predict(&x).unwrap();
        let preds2 = model2.fit(&x, &y).unwrap().predict(&x).unwrap();

        for i in 0..6 {
            assert_relative_eq!(preds1[i], preds2[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_ensemble_regressor_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }

    // -- Builder tests --

    // Every with_* setter should store exactly the value it was given.
    #[test]
    fn test_ensemble_classifier_builder() {
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_max_depth(Some(5))
            .with_min_samples_split(10)
            .with_min_samples_leaf(3)
            .with_max_features(MaxFeatures::Log2)
            .with_bootstrap(true)
            .with_criterion(ClassificationCriterion::Entropy)
            .with_random_state(42)
            .with_n_jobs(4);

        assert_eq!(model.n_estimators, 50);
        assert_eq!(model.max_depth, Some(5));
        assert_eq!(model.min_samples_split, 10);
        assert_eq!(model.min_samples_leaf, 3);
        assert_eq!(model.max_features, MaxFeatures::Log2);
        assert!(model.bootstrap);
        assert_eq!(model.criterion, ClassificationCriterion::Entropy);
        assert_eq!(model.random_state, Some(42));
        assert_eq!(model.n_jobs, Some(4));
    }

    #[test]
    fn test_ensemble_regressor_builder() {
        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(25)
            .with_max_depth(Some(8))
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_max_features(MaxFeatures::Fraction(0.5))
            .with_bootstrap(true)
            .with_random_state(99)
            .with_n_jobs(2);

        assert_eq!(model.n_estimators, 25);
        assert_eq!(model.max_depth, Some(8));
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.min_samples_leaf, 2);
        assert_eq!(model.max_features, MaxFeatures::Fraction(0.5));
        assert!(model.bootstrap);
        assert_eq!(model.random_state, Some(99));
        assert_eq!(model.n_jobs, Some(2));
    }

    // Pin the documented defaults so accidental changes are caught.
    #[test]
    fn test_ensemble_classifier_default() {
        let model = ExtraTreesClassifier::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::Sqrt);
        assert!(!model.bootstrap);
        assert_eq!(model.criterion, ClassificationCriterion::Gini);
        assert_eq!(model.random_state, None);
        assert_eq!(model.n_jobs, None);
    }

    // Note the regressor default differs from the classifier: MaxFeatures::All.
    #[test]
    fn test_ensemble_regressor_default() {
        let model = ExtraTreesRegressor::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::All);
        assert!(!model.bootstrap);
        assert_eq!(model.random_state, None);
        assert_eq!(model.n_jobs, None);
    }
}