Skip to main content

ferrolearn_tree/
random_forest.rs

1//! Random forest classifiers and regressors.
2//!
3//! This module provides [`RandomForestClassifier`] and [`RandomForestRegressor`],
4//! which build ensembles of decision trees using bootstrap sampling and random
5//! feature subsets (bagging). Trees are built in parallel via `rayon`.
6//!
7//! # Examples
8//!
9//! ```
10//! use ferrolearn_tree::RandomForestClassifier;
11//! use ferrolearn_core::{Fit, Predict};
12//! use ndarray::{array, Array1, Array2};
13//!
14//! let x = Array2::from_shape_vec((8, 2), vec![
15//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,  4.0, 4.0,
16//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,  8.0, 9.0,
17//! ]).unwrap();
18//! let y = array![0, 0, 0, 0, 1, 1, 1, 1];
19//!
20//! let model = RandomForestClassifier::<f64>::new()
21//!     .with_n_estimators(10)
22//!     .with_random_state(42);
23//! let fitted = model.fit(&x, &y).unwrap();
24//! let preds = fitted.predict(&x).unwrap();
25//! ```
26
27use ferrolearn_core::error::FerroError;
28use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
29use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
30use ferrolearn_core::traits::{Fit, Predict};
31use ndarray::{Array1, Array2};
32use num_traits::{Float, FromPrimitive, ToPrimitive};
33use rand::SeedableRng;
34use rand::rngs::StdRng;
35use rand::seq::index::sample as rand_sample_indices;
36use rayon::prelude::*;
37use serde::{Deserialize, Serialize};
38
39use crate::decision_tree::{
40    self, ClassificationCriterion, Node, build_classification_tree_with_feature_subset,
41    build_regression_tree_with_feature_subset, compute_feature_importances,
42};
43
44// ---------------------------------------------------------------------------
45// MaxFeatures
46// ---------------------------------------------------------------------------
47
48/// Strategy for selecting the number of features considered at each split.
49#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
50pub enum MaxFeatures {
51    /// Use the square root of the total number of features (default for classifiers).
52    Sqrt,
53    /// Use the log2 of the total number of features.
54    Log2,
55    /// Use all features (default for regressors).
56    All,
57    /// Use a specific number of features.
58    Fixed(usize),
59    /// Use a fraction of the total number of features.
60    Fraction(f64),
61}
62
63/// Resolve the `MaxFeatures` strategy to a concrete number.
64fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
65    let result = match strategy {
66        MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
67        MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
68        MaxFeatures::All => n_features,
69        MaxFeatures::Fixed(n) => n.min(n_features),
70        MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
71    };
72    result.max(1).min(n_features)
73}
74
75/// Internal tree parameter struct reused from decision_tree.
76///
77/// Re-created here to avoid leaking internal details; the crate-internal
78/// struct is the same shape.
79fn make_tree_params(
80    max_depth: Option<usize>,
81    min_samples_split: usize,
82    min_samples_leaf: usize,
83) -> decision_tree::TreeParams {
84    decision_tree::TreeParams {
85        max_depth,
86        min_samples_split,
87        min_samples_leaf,
88    }
89}
90
91// ---------------------------------------------------------------------------
92// RandomForestClassifier
93// ---------------------------------------------------------------------------
94
95/// Random forest classifier.
96///
97/// Builds an ensemble of decision tree classifiers, each trained on a
98/// bootstrap sample with a random subset of features considered at each split.
99/// Final predictions are made by majority vote.
100///
101/// # Type Parameters
102///
103/// - `F`: The floating-point type (`f32` or `f64`).
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct RandomForestClassifier<F> {
106    /// Number of trees in the forest.
107    pub n_estimators: usize,
108    /// Maximum depth of each tree. `None` means unlimited.
109    pub max_depth: Option<usize>,
110    /// Strategy for the number of features considered at each split.
111    pub max_features: MaxFeatures,
112    /// Minimum number of samples required to split an internal node.
113    pub min_samples_split: usize,
114    /// Minimum number of samples required in a leaf node.
115    pub min_samples_leaf: usize,
116    /// Random seed for reproducibility. `None` means non-deterministic.
117    pub random_state: Option<u64>,
118    /// Splitting criterion.
119    pub criterion: ClassificationCriterion,
120    _marker: std::marker::PhantomData<F>,
121}
122
123impl<F: Float> RandomForestClassifier<F> {
124    /// Create a new `RandomForestClassifier` with default settings.
125    ///
126    /// Defaults: `n_estimators = 100`, `max_depth = None`,
127    /// `max_features = Sqrt`, `min_samples_split = 2`,
128    /// `min_samples_leaf = 1`, `random_state = None`,
129    /// `criterion = Gini`.
130    #[must_use]
131    pub fn new() -> Self {
132        Self {
133            n_estimators: 100,
134            max_depth: None,
135            max_features: MaxFeatures::Sqrt,
136            min_samples_split: 2,
137            min_samples_leaf: 1,
138            random_state: None,
139            criterion: ClassificationCriterion::Gini,
140            _marker: std::marker::PhantomData,
141        }
142    }
143
144    /// Set the number of trees.
145    #[must_use]
146    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
147        self.n_estimators = n_estimators;
148        self
149    }
150
151    /// Set the maximum tree depth.
152    #[must_use]
153    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
154        self.max_depth = max_depth;
155        self
156    }
157
158    /// Set the maximum features strategy.
159    #[must_use]
160    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
161        self.max_features = max_features;
162        self
163    }
164
165    /// Set the minimum number of samples to split a node.
166    #[must_use]
167    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
168        self.min_samples_split = min_samples_split;
169        self
170    }
171
172    /// Set the minimum number of samples in a leaf.
173    #[must_use]
174    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
175        self.min_samples_leaf = min_samples_leaf;
176        self
177    }
178
179    /// Set the random seed for reproducibility.
180    #[must_use]
181    pub fn with_random_state(mut self, seed: u64) -> Self {
182        self.random_state = Some(seed);
183        self
184    }
185
186    /// Set the splitting criterion.
187    #[must_use]
188    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
189        self.criterion = criterion;
190        self
191    }
192}
193
194impl<F: Float> Default for RandomForestClassifier<F> {
195    fn default() -> Self {
196        Self::new()
197    }
198}
199
200// ---------------------------------------------------------------------------
201// FittedRandomForestClassifier
202// ---------------------------------------------------------------------------
203
204/// A fitted random forest classifier.
205///
206/// Stores the ensemble of fitted decision trees and aggregates their
207/// predictions by majority vote.
208#[derive(Debug, Clone)]
209pub struct FittedRandomForestClassifier<F> {
210    /// Individual tree node vectors.
211    trees: Vec<Vec<Node<F>>>,
212    /// Sorted unique class labels.
213    classes: Vec<usize>,
214    /// Number of features.
215    n_features: usize,
216    /// Per-feature importance scores (mean decrease in impurity, normalised).
217    feature_importances: Array1<F>,
218}
219
220impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for RandomForestClassifier<F> {
221    type Fitted = FittedRandomForestClassifier<F>;
222    type Error = FerroError;
223
224    /// Fit the random forest by building `n_estimators` decision trees in parallel.
225    ///
226    /// Each tree is trained on a bootstrap sample of the data, considering only
227    /// a random subset of features at each split.
228    ///
229    /// # Errors
230    ///
231    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
232    /// numbers of samples.
233    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
234    /// Returns [`FerroError::InvalidParameter`] if `n_estimators` is 0.
235    fn fit(
236        &self,
237        x: &Array2<F>,
238        y: &Array1<usize>,
239    ) -> Result<FittedRandomForestClassifier<F>, FerroError> {
240        let (n_samples, n_features) = x.dim();
241
242        if n_samples != y.len() {
243            return Err(FerroError::ShapeMismatch {
244                expected: vec![n_samples],
245                actual: vec![y.len()],
246                context: "y length must match number of samples in X".into(),
247            });
248        }
249        if n_samples == 0 {
250            return Err(FerroError::InsufficientSamples {
251                required: 1,
252                actual: 0,
253                context: "RandomForestClassifier requires at least one sample".into(),
254            });
255        }
256        if self.n_estimators == 0 {
257            return Err(FerroError::InvalidParameter {
258                name: "n_estimators".into(),
259                reason: "must be at least 1".into(),
260            });
261        }
262
263        // Determine unique classes.
264        let mut classes: Vec<usize> = y.iter().copied().collect();
265        classes.sort_unstable();
266        classes.dedup();
267        let n_classes = classes.len();
268
269        let y_mapped: Vec<usize> = y
270            .iter()
271            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
272            .collect();
273
274        let max_features_n = resolve_max_features(self.max_features, n_features);
275        let params = make_tree_params(
276            self.max_depth,
277            self.min_samples_split,
278            self.min_samples_leaf,
279        );
280        let criterion = self.criterion;
281
282        // Generate per-tree seeds sequentially for determinism, then dispatch in parallel.
283        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
284            let mut master_rng = StdRng::seed_from_u64(seed);
285            (0..self.n_estimators)
286                .map(|_| {
287                    use rand::RngCore;
288                    master_rng.next_u64()
289                })
290                .collect()
291        } else {
292            (0..self.n_estimators)
293                .map(|_| {
294                    use rand::RngCore;
295                    rand::rng().next_u64()
296                })
297                .collect()
298        };
299
300        // Build trees in parallel.
301        let trees: Vec<Vec<Node<F>>> = tree_seeds
302            .par_iter()
303            .map(|&seed| {
304                let mut rng = StdRng::seed_from_u64(seed);
305
306                // Bootstrap sample (with replacement).
307                let bootstrap_indices: Vec<usize> = (0..n_samples)
308                    .map(|_| {
309                        use rand::RngCore;
310                        (rng.next_u64() as usize) % n_samples
311                    })
312                    .collect();
313
314                // Random feature subset.
315                let feature_indices: Vec<usize> =
316                    rand_sample_indices(&mut rng, n_features, max_features_n).into_vec();
317
318                build_classification_tree_with_feature_subset(
319                    x,
320                    &y_mapped,
321                    n_classes,
322                    &bootstrap_indices,
323                    &feature_indices,
324                    &params,
325                    criterion,
326                )
327            })
328            .collect();
329
330        // Aggregate feature importances across trees.
331        let mut total_importances = Array1::<F>::zeros(n_features);
332        for tree_nodes in &trees {
333            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
334            total_importances = total_importances + tree_imp;
335        }
336        let imp_sum: F = total_importances
337            .iter()
338            .copied()
339            .fold(F::zero(), |a, b| a + b);
340        if imp_sum > F::zero() {
341            total_importances.mapv_inplace(|v| v / imp_sum);
342        }
343
344        Ok(FittedRandomForestClassifier {
345            trees,
346            classes,
347            n_features,
348            feature_importances: total_importances,
349        })
350    }
351}
352
353impl<F: Float + Send + Sync + 'static> FittedRandomForestClassifier<F> {
354    /// Returns a reference to the individual tree node vectors.
355    #[must_use]
356    pub fn trees(&self) -> &[Vec<Node<F>>] {
357        &self.trees
358    }
359
360    /// Returns the number of features the model was trained on.
361    #[must_use]
362    pub fn n_features(&self) -> usize {
363        self.n_features
364    }
365}
366
367impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedRandomForestClassifier<F> {
368    type Output = Array1<usize>;
369    type Error = FerroError;
370
371    /// Predict class labels by majority vote across all trees.
372    ///
373    /// # Errors
374    ///
375    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
376    /// not match the fitted model.
377    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
378        if x.ncols() != self.n_features {
379            return Err(FerroError::ShapeMismatch {
380                expected: vec![self.n_features],
381                actual: vec![x.ncols()],
382                context: "number of features must match fitted model".into(),
383            });
384        }
385
386        let n_samples = x.nrows();
387        let n_classes = self.classes.len();
388        let mut predictions = Array1::zeros(n_samples);
389
390        for i in 0..n_samples {
391            let row = x.row(i);
392            let mut votes = vec![0usize; n_classes];
393
394            for tree_nodes in &self.trees {
395                let leaf_idx = decision_tree::traverse(tree_nodes, &row);
396                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
397                    let class_idx = value.to_f64().map(|f| f.round() as usize).unwrap_or(0);
398                    if class_idx < n_classes {
399                        votes[class_idx] += 1;
400                    }
401                }
402            }
403
404            let winner = votes
405                .iter()
406                .enumerate()
407                .max_by_key(|&(_, &count)| count)
408                .map(|(idx, _)| idx)
409                .unwrap_or(0);
410            predictions[i] = self.classes[winner];
411        }
412
413        Ok(predictions)
414    }
415}
416
417impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
418    for FittedRandomForestClassifier<F>
419{
420    fn feature_importances(&self) -> &Array1<F> {
421        &self.feature_importances
422    }
423}
424
425impl<F: Float + Send + Sync + 'static> HasClasses for FittedRandomForestClassifier<F> {
426    fn classes(&self) -> &[usize] {
427        &self.classes
428    }
429
430    fn n_classes(&self) -> usize {
431        self.classes.len()
432    }
433}
434
435// Pipeline integration.
436impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
437    for RandomForestClassifier<F>
438{
439    fn fit_pipeline(
440        &self,
441        x: &Array2<F>,
442        y: &Array1<F>,
443    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
444        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
445        let fitted = self.fit(x, &y_usize)?;
446        Ok(Box::new(FittedForestClassifierPipelineAdapter(fitted)))
447    }
448}
449
450/// Pipeline adapter for `FittedRandomForestClassifier<F>`.
451struct FittedForestClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
452    FittedRandomForestClassifier<F>,
453);
454
455impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
456    for FittedForestClassifierPipelineAdapter<F>
457{
458    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
459        let preds = self.0.predict(x)?;
460        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
461    }
462}
463
464// ---------------------------------------------------------------------------
465// RandomForestRegressor
466// ---------------------------------------------------------------------------
467
468/// Random forest regressor.
469///
470/// Builds an ensemble of decision tree regressors, each trained on a
471/// bootstrap sample with a random subset of features considered at each split.
472/// Final predictions are the mean across all trees.
473///
474/// # Type Parameters
475///
476/// - `F`: The floating-point type (`f32` or `f64`).
477#[derive(Debug, Clone, Serialize, Deserialize)]
478pub struct RandomForestRegressor<F> {
479    /// Number of trees in the forest.
480    pub n_estimators: usize,
481    /// Maximum depth of each tree. `None` means unlimited.
482    pub max_depth: Option<usize>,
483    /// Strategy for the number of features considered at each split.
484    pub max_features: MaxFeatures,
485    /// Minimum number of samples required to split an internal node.
486    pub min_samples_split: usize,
487    /// Minimum number of samples required in a leaf node.
488    pub min_samples_leaf: usize,
489    /// Random seed for reproducibility. `None` means non-deterministic.
490    pub random_state: Option<u64>,
491    _marker: std::marker::PhantomData<F>,
492}
493
494impl<F: Float> RandomForestRegressor<F> {
495    /// Create a new `RandomForestRegressor` with default settings.
496    ///
497    /// Defaults: `n_estimators = 100`, `max_depth = None`,
498    /// `max_features = All`, `min_samples_split = 2`,
499    /// `min_samples_leaf = 1`, `random_state = None`.
500    #[must_use]
501    pub fn new() -> Self {
502        Self {
503            n_estimators: 100,
504            max_depth: None,
505            max_features: MaxFeatures::All,
506            min_samples_split: 2,
507            min_samples_leaf: 1,
508            random_state: None,
509            _marker: std::marker::PhantomData,
510        }
511    }
512
513    /// Set the number of trees.
514    #[must_use]
515    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
516        self.n_estimators = n_estimators;
517        self
518    }
519
520    /// Set the maximum tree depth.
521    #[must_use]
522    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
523        self.max_depth = max_depth;
524        self
525    }
526
527    /// Set the maximum features strategy.
528    #[must_use]
529    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
530        self.max_features = max_features;
531        self
532    }
533
534    /// Set the minimum number of samples to split a node.
535    #[must_use]
536    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
537        self.min_samples_split = min_samples_split;
538        self
539    }
540
541    /// Set the minimum number of samples in a leaf.
542    #[must_use]
543    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
544        self.min_samples_leaf = min_samples_leaf;
545        self
546    }
547
548    /// Set the random seed for reproducibility.
549    #[must_use]
550    pub fn with_random_state(mut self, seed: u64) -> Self {
551        self.random_state = Some(seed);
552        self
553    }
554}
555
556impl<F: Float> Default for RandomForestRegressor<F> {
557    fn default() -> Self {
558        Self::new()
559    }
560}
561
562// ---------------------------------------------------------------------------
563// FittedRandomForestRegressor
564// ---------------------------------------------------------------------------
565
566/// A fitted random forest regressor.
567///
568/// Stores the ensemble of fitted decision trees and aggregates their
569/// predictions by averaging.
570#[derive(Debug, Clone)]
571pub struct FittedRandomForestRegressor<F> {
572    /// Individual tree node vectors.
573    trees: Vec<Vec<Node<F>>>,
574    /// Number of features.
575    n_features: usize,
576    /// Per-feature importance scores (mean decrease in impurity, normalised).
577    feature_importances: Array1<F>,
578}
579
580impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for RandomForestRegressor<F> {
581    type Fitted = FittedRandomForestRegressor<F>;
582    type Error = FerroError;
583
584    /// Fit the random forest regressor.
585    ///
586    /// # Errors
587    ///
588    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
589    /// numbers of samples.
590    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
591    /// Returns [`FerroError::InvalidParameter`] if `n_estimators` is 0.
592    fn fit(
593        &self,
594        x: &Array2<F>,
595        y: &Array1<F>,
596    ) -> Result<FittedRandomForestRegressor<F>, FerroError> {
597        let (n_samples, n_features) = x.dim();
598
599        if n_samples != y.len() {
600            return Err(FerroError::ShapeMismatch {
601                expected: vec![n_samples],
602                actual: vec![y.len()],
603                context: "y length must match number of samples in X".into(),
604            });
605        }
606        if n_samples == 0 {
607            return Err(FerroError::InsufficientSamples {
608                required: 1,
609                actual: 0,
610                context: "RandomForestRegressor requires at least one sample".into(),
611            });
612        }
613        if self.n_estimators == 0 {
614            return Err(FerroError::InvalidParameter {
615                name: "n_estimators".into(),
616                reason: "must be at least 1".into(),
617            });
618        }
619
620        let max_features_n = resolve_max_features(self.max_features, n_features);
621        let params = make_tree_params(
622            self.max_depth,
623            self.min_samples_split,
624            self.min_samples_leaf,
625        );
626
627        // Generate per-tree seeds sequentially.
628        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
629            let mut master_rng = StdRng::seed_from_u64(seed);
630            (0..self.n_estimators)
631                .map(|_| {
632                    use rand::RngCore;
633                    master_rng.next_u64()
634                })
635                .collect()
636        } else {
637            (0..self.n_estimators)
638                .map(|_| {
639                    use rand::RngCore;
640                    rand::rng().next_u64()
641                })
642                .collect()
643        };
644
645        // Build trees in parallel.
646        let trees: Vec<Vec<Node<F>>> = tree_seeds
647            .par_iter()
648            .map(|&seed| {
649                let mut rng = StdRng::seed_from_u64(seed);
650
651                let bootstrap_indices: Vec<usize> = (0..n_samples)
652                    .map(|_| {
653                        use rand::RngCore;
654                        (rng.next_u64() as usize) % n_samples
655                    })
656                    .collect();
657
658                let feature_indices: Vec<usize> =
659                    rand_sample_indices(&mut rng, n_features, max_features_n).into_vec();
660
661                build_regression_tree_with_feature_subset(
662                    x,
663                    y,
664                    &bootstrap_indices,
665                    &feature_indices,
666                    &params,
667                )
668            })
669            .collect();
670
671        // Aggregate feature importances.
672        let mut total_importances = Array1::<F>::zeros(n_features);
673        for tree_nodes in &trees {
674            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
675            total_importances = total_importances + tree_imp;
676        }
677        let imp_sum: F = total_importances
678            .iter()
679            .copied()
680            .fold(F::zero(), |a, b| a + b);
681        if imp_sum > F::zero() {
682            total_importances.mapv_inplace(|v| v / imp_sum);
683        }
684
685        Ok(FittedRandomForestRegressor {
686            trees,
687            n_features,
688            feature_importances: total_importances,
689        })
690    }
691}
692
693impl<F: Float + Send + Sync + 'static> FittedRandomForestRegressor<F> {
694    /// Returns a reference to the individual tree node vectors.
695    #[must_use]
696    pub fn trees(&self) -> &[Vec<Node<F>>] {
697        &self.trees
698    }
699
700    /// Returns the number of features the model was trained on.
701    #[must_use]
702    pub fn n_features(&self) -> usize {
703        self.n_features
704    }
705}
706
707impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedRandomForestRegressor<F> {
708    type Output = Array1<F>;
709    type Error = FerroError;
710
711    /// Predict target values by averaging across all trees.
712    ///
713    /// # Errors
714    ///
715    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
716    /// not match the fitted model.
717    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
718        if x.ncols() != self.n_features {
719            return Err(FerroError::ShapeMismatch {
720                expected: vec![self.n_features],
721                actual: vec![x.ncols()],
722                context: "number of features must match fitted model".into(),
723            });
724        }
725
726        let n_samples = x.nrows();
727        let n_trees_f = F::from(self.trees.len()).unwrap();
728        let mut predictions = Array1::zeros(n_samples);
729
730        for i in 0..n_samples {
731            let row = x.row(i);
732            let mut sum = F::zero();
733
734            for tree_nodes in &self.trees {
735                let leaf_idx = decision_tree::traverse(tree_nodes, &row);
736                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
737                    sum = sum + value;
738                }
739            }
740
741            predictions[i] = sum / n_trees_f;
742        }
743
744        Ok(predictions)
745    }
746}
747
748impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedRandomForestRegressor<F> {
749    fn feature_importances(&self) -> &Array1<F> {
750        &self.feature_importances
751    }
752}
753
754// Pipeline integration.
755impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for RandomForestRegressor<F> {
756    fn fit_pipeline(
757        &self,
758        x: &Array2<F>,
759        y: &Array1<F>,
760    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
761        let fitted = self.fit(x, y)?;
762        Ok(Box::new(fitted))
763    }
764}
765
766impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F>
767    for FittedRandomForestRegressor<F>
768{
769    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
770        self.predict(x)
771    }
772}
773
774// ---------------------------------------------------------------------------
775// Tests
776// ---------------------------------------------------------------------------
777
778#[cfg(test)]
779mod tests {
780    use super::*;
781    use approx::assert_relative_eq;
782    use ndarray::array;
783
784    // -- Classifier tests --
785
786    #[test]
787    fn test_forest_classifier_simple() {
788        let x = Array2::from_shape_vec(
789            (8, 2),
790            vec![
791                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
792            ],
793        )
794        .unwrap();
795        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
796
797        let model = RandomForestClassifier::<f64>::new()
798            .with_n_estimators(20)
799            .with_random_state(42);
800        let fitted = model.fit(&x, &y).unwrap();
801        let preds = fitted.predict(&x).unwrap();
802
803        assert_eq!(preds.len(), 8);
804        for i in 0..4 {
805            assert_eq!(preds[i], 0);
806        }
807        for i in 4..8 {
808            assert_eq!(preds[i], 1);
809        }
810    }
811
812    #[test]
813    fn test_forest_classifier_reproducibility() {
814        let x = Array2::from_shape_vec(
815            (8, 2),
816            vec![
817                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
818            ],
819        )
820        .unwrap();
821        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
822
823        let model = RandomForestClassifier::<f64>::new()
824            .with_n_estimators(10)
825            .with_random_state(123);
826
827        let fitted1 = model.fit(&x, &y).unwrap();
828        let fitted2 = model.fit(&x, &y).unwrap();
829
830        let preds1 = fitted1.predict(&x).unwrap();
831        let preds2 = fitted2.predict(&x).unwrap();
832
833        assert_eq!(preds1, preds2);
834    }
835
836    #[test]
837    fn test_forest_classifier_feature_importances() {
838        let x = Array2::from_shape_vec(
839            (10, 3),
840            vec![
841                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
842                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
843            ],
844        )
845        .unwrap();
846        let y = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
847
848        let model = RandomForestClassifier::<f64>::new()
849            .with_n_estimators(20)
850            .with_max_features(MaxFeatures::All)
851            .with_random_state(42);
852        let fitted = model.fit(&x, &y).unwrap();
853        let importances = fitted.feature_importances();
854
855        assert_eq!(importances.len(), 3);
856        assert!(importances[0] > importances[1]);
857        assert!(importances[0] > importances[2]);
858    }
859
860    #[test]
861    fn test_forest_classifier_has_classes() {
862        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
863        let y = array![0, 1, 2, 0, 1, 2];
864
865        let model = RandomForestClassifier::<f64>::new()
866            .with_n_estimators(5)
867            .with_random_state(0);
868        let fitted = model.fit(&x, &y).unwrap();
869
870        assert_eq!(fitted.classes(), &[0, 1, 2]);
871        assert_eq!(fitted.n_classes(), 3);
872    }
873
874    #[test]
875    fn test_forest_classifier_shape_mismatch_fit() {
876        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
877        let y = array![0, 1];
878
879        let model = RandomForestClassifier::<f64>::new().with_n_estimators(5);
880        assert!(model.fit(&x, &y).is_err());
881    }
882
883    #[test]
884    fn test_forest_classifier_shape_mismatch_predict() {
885        let x =
886            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
887        let y = array![0, 0, 1, 1];
888
889        let model = RandomForestClassifier::<f64>::new()
890            .with_n_estimators(5)
891            .with_random_state(0);
892        let fitted = model.fit(&x, &y).unwrap();
893
894        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
895        assert!(fitted.predict(&x_bad).is_err());
896    }
897
898    #[test]
899    fn test_forest_classifier_empty_data() {
900        let x = Array2::<f64>::zeros((0, 2));
901        let y = Array1::<usize>::zeros(0);
902
903        let model = RandomForestClassifier::<f64>::new().with_n_estimators(5);
904        assert!(model.fit(&x, &y).is_err());
905    }
906
907    #[test]
908    fn test_forest_classifier_zero_estimators() {
909        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
910        let y = array![0, 0, 1, 1];
911
912        let model = RandomForestClassifier::<f64>::new().with_n_estimators(0);
913        assert!(model.fit(&x, &y).is_err());
914    }
915
916    #[test]
917    fn test_forest_classifier_single_tree() {
918        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
919        let y = array![0, 0, 0, 1, 1, 1];
920
921        let model = RandomForestClassifier::<f64>::new()
922            .with_n_estimators(1)
923            .with_max_features(MaxFeatures::All)
924            .with_random_state(42);
925        let fitted = model.fit(&x, &y).unwrap();
926        let preds = fitted.predict(&x).unwrap();
927
928        assert_eq!(preds.len(), 6);
929    }
930
931    #[test]
932    fn test_forest_classifier_pipeline_integration() {
933        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
934        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);
935
936        let model = RandomForestClassifier::<f64>::new()
937            .with_n_estimators(5)
938            .with_random_state(42);
939        let fitted = model.fit_pipeline(&x, &y).unwrap();
940        let preds = fitted.predict_pipeline(&x).unwrap();
941        assert_eq!(preds.len(), 6);
942    }
943
944    #[test]
945    fn test_forest_classifier_max_depth() {
946        let x =
947            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
948        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
949
950        let model = RandomForestClassifier::<f64>::new()
951            .with_n_estimators(10)
952            .with_max_depth(Some(1))
953            .with_max_features(MaxFeatures::All)
954            .with_random_state(42);
955        let fitted = model.fit(&x, &y).unwrap();
956        let preds = fitted.predict(&x).unwrap();
957
958        assert_eq!(preds.len(), 8);
959    }
960
961    // -- Regressor tests --
962
963    #[test]
964    fn test_forest_regressor_simple() {
965        let x =
966            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
967        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
968
969        let model = RandomForestRegressor::<f64>::new()
970            .with_n_estimators(50)
971            .with_random_state(42);
972        let fitted = model.fit(&x, &y).unwrap();
973        let preds = fitted.predict(&x).unwrap();
974
975        assert_eq!(preds.len(), 8);
976        for i in 0..4 {
977            assert!(preds[i] < 3.0, "Expected ~1.0, got {}", preds[i]);
978        }
979        for i in 4..8 {
980            assert!(preds[i] > 3.0, "Expected ~5.0, got {}", preds[i]);
981        }
982    }
983
984    #[test]
985    fn test_forest_regressor_reproducibility() {
986        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
987        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
988
989        let model = RandomForestRegressor::<f64>::new()
990            .with_n_estimators(10)
991            .with_random_state(99);
992
993        let fitted1 = model.fit(&x, &y).unwrap();
994        let fitted2 = model.fit(&x, &y).unwrap();
995
996        let preds1 = fitted1.predict(&x).unwrap();
997        let preds2 = fitted2.predict(&x).unwrap();
998
999        for (p1, p2) in preds1.iter().zip(preds2.iter()) {
1000            assert_relative_eq!(*p1, *p2, epsilon = 1e-10);
1001        }
1002    }
1003
1004    #[test]
1005    fn test_forest_regressor_feature_importances() {
1006        let x = Array2::from_shape_vec(
1007            (8, 2),
1008            vec![
1009                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
1010            ],
1011        )
1012        .unwrap();
1013        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
1014
1015        let model = RandomForestRegressor::<f64>::new()
1016            .with_n_estimators(20)
1017            .with_max_features(MaxFeatures::All)
1018            .with_random_state(42);
1019        let fitted = model.fit(&x, &y).unwrap();
1020        let importances = fitted.feature_importances();
1021
1022        assert_eq!(importances.len(), 2);
1023        assert!(importances[0] > importances[1]);
1024    }
1025
1026    #[test]
1027    fn test_forest_regressor_shape_mismatch_fit() {
1028        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
1029        let y = array![1.0, 2.0];
1030
1031        let model = RandomForestRegressor::<f64>::new().with_n_estimators(5);
1032        assert!(model.fit(&x, &y).is_err());
1033    }
1034
1035    #[test]
1036    fn test_forest_regressor_shape_mismatch_predict() {
1037        let x =
1038            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
1039        let y = array![1.0, 2.0, 3.0, 4.0];
1040
1041        let model = RandomForestRegressor::<f64>::new()
1042            .with_n_estimators(5)
1043            .with_random_state(0);
1044        let fitted = model.fit(&x, &y).unwrap();
1045
1046        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1047        assert!(fitted.predict(&x_bad).is_err());
1048    }
1049
1050    #[test]
1051    fn test_forest_regressor_empty_data() {
1052        let x = Array2::<f64>::zeros((0, 2));
1053        let y = Array1::<f64>::zeros(0);
1054
1055        let model = RandomForestRegressor::<f64>::new().with_n_estimators(5);
1056        assert!(model.fit(&x, &y).is_err());
1057    }
1058
1059    #[test]
1060    fn test_forest_regressor_zero_estimators() {
1061        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1062        let y = array![1.0, 2.0, 3.0, 4.0];
1063
1064        let model = RandomForestRegressor::<f64>::new().with_n_estimators(0);
1065        assert!(model.fit(&x, &y).is_err());
1066    }
1067
1068    #[test]
1069    fn test_forest_regressor_pipeline_integration() {
1070        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1071        let y = array![1.0, 2.0, 3.0, 4.0];
1072
1073        let model = RandomForestRegressor::<f64>::new()
1074            .with_n_estimators(5)
1075            .with_random_state(42);
1076        let fitted = model.fit_pipeline(&x, &y).unwrap();
1077        let preds = fitted.predict_pipeline(&x).unwrap();
1078        assert_eq!(preds.len(), 4);
1079    }
1080
1081    #[test]
1082    fn test_forest_regressor_max_features_strategies() {
1083        let x = Array2::from_shape_vec(
1084            (8, 4),
1085            vec![
1086                1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0, 3.0, 4.0, 5.0, 6.0, 4.0, 5.0, 6.0, 7.0,
1087                5.0, 6.0, 7.0, 8.0, 6.0, 7.0, 8.0, 9.0, 7.0, 8.0, 9.0, 10.0, 8.0, 9.0, 10.0, 11.0,
1088            ],
1089        )
1090        .unwrap();
1091        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
1092
1093        for strategy in &[
1094            MaxFeatures::Sqrt,
1095            MaxFeatures::Log2,
1096            MaxFeatures::All,
1097            MaxFeatures::Fixed(2),
1098            MaxFeatures::Fraction(0.5),
1099        ] {
1100            let model = RandomForestRegressor::<f64>::new()
1101                .with_n_estimators(5)
1102                .with_max_features(*strategy)
1103                .with_random_state(42);
1104            let fitted = model.fit(&x, &y).unwrap();
1105            let preds = fitted.predict(&x).unwrap();
1106            assert_eq!(preds.len(), 8);
1107        }
1108    }
1109
1110    // -- MaxFeatures resolution tests --
1111
1112    #[test]
1113    fn test_resolve_max_features_sqrt() {
1114        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 9), 3);
1115        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 10), 4);
1116        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 1), 1);
1117    }
1118
1119    #[test]
1120    fn test_resolve_max_features_log2() {
1121        assert_eq!(resolve_max_features(MaxFeatures::Log2, 8), 3);
1122        assert_eq!(resolve_max_features(MaxFeatures::Log2, 1), 1);
1123    }
1124
1125    #[test]
1126    fn test_resolve_max_features_all() {
1127        assert_eq!(resolve_max_features(MaxFeatures::All, 10), 10);
1128        assert_eq!(resolve_max_features(MaxFeatures::All, 1), 1);
1129    }
1130
1131    #[test]
1132    fn test_resolve_max_features_fixed() {
1133        assert_eq!(resolve_max_features(MaxFeatures::Fixed(3), 10), 3);
1134        assert_eq!(resolve_max_features(MaxFeatures::Fixed(20), 10), 10);
1135    }
1136
1137    #[test]
1138    fn test_resolve_max_features_fraction() {
1139        assert_eq!(resolve_max_features(MaxFeatures::Fraction(0.5), 10), 5);
1140        assert_eq!(resolve_max_features(MaxFeatures::Fraction(0.1), 10), 1);
1141    }
1142
1143    #[test]
1144    fn test_forest_classifier_f32_support() {
1145        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1146        let y = array![0, 0, 0, 1, 1, 1];
1147
1148        let model = RandomForestClassifier::<f32>::new()
1149            .with_n_estimators(5)
1150            .with_random_state(42);
1151        let fitted = model.fit(&x, &y).unwrap();
1152        let preds = fitted.predict(&x).unwrap();
1153        assert_eq!(preds.len(), 6);
1154    }
1155
1156    #[test]
1157    fn test_forest_regressor_f32_support() {
1158        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
1159        let y = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
1160
1161        let model = RandomForestRegressor::<f32>::new()
1162            .with_n_estimators(5)
1163            .with_random_state(42);
1164        let fitted = model.fit(&x, &y).unwrap();
1165        let preds = fitted.predict(&x).unwrap();
1166        assert_eq!(preds.len(), 4);
1167    }
1168}