Skip to main content

ferrolearn_tree/
random_forest.rs

1//! Random forest classifiers and regressors.
2//!
3//! This module provides [`RandomForestClassifier`] and [`RandomForestRegressor`],
4//! which build ensembles of decision trees using bootstrap sampling and random
5//! feature subsets (bagging). Trees are built in parallel via `rayon`.
6//!
7//! # Examples
8//!
9//! ```
10//! use ferrolearn_tree::RandomForestClassifier;
11//! use ferrolearn_core::{Fit, Predict};
12//! use ndarray::{array, Array1, Array2};
13//!
14//! let x = Array2::from_shape_vec((8, 2), vec![
15//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,  4.0, 4.0,
16//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,  8.0, 9.0,
17//! ]).unwrap();
18//! let y = array![0, 0, 0, 0, 1, 1, 1, 1];
19//!
20//! let model = RandomForestClassifier::<f64>::new()
21//!     .with_n_estimators(10)
22//!     .with_random_state(42);
23//! let fitted = model.fit(&x, &y).unwrap();
24//! let preds = fitted.predict(&x).unwrap();
25//! ```
26
27use ferrolearn_core::error::FerroError;
28use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
29use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
30use ferrolearn_core::traits::{Fit, Predict};
31use ndarray::{Array1, Array2};
32use num_traits::{Float, FromPrimitive, ToPrimitive};
33use rand::SeedableRng;
34use rand::rngs::StdRng;
35use rand::seq::index::sample as rand_sample_indices;
36use rayon::prelude::*;
37use serde::{Deserialize, Serialize};
38
39use crate::decision_tree::{
40    self, ClassificationCriterion, Node, build_classification_tree_with_feature_subset,
41    build_regression_tree_with_feature_subset, compute_feature_importances,
42};
43
44// ---------------------------------------------------------------------------
45// MaxFeatures
46// ---------------------------------------------------------------------------
47
/// Strategy for selecting the number of features a forest's trees may use.
///
/// The variant is turned into a concrete count by `resolve_max_features`,
/// which always yields a value between 1 and the number of available
/// features (or 0 when there are no features at all).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum MaxFeatures {
    /// Use the square root of the total number of features, rounded up
    /// (default for classifiers).
    Sqrt,
    /// Use the base-2 logarithm of the total number of features, rounded up.
    Log2,
    /// Use all features (default for regressors).
    All,
    /// Use a specific number of features; values larger than the number of
    /// available features are capped.
    Fixed(usize),
    /// Use a fraction of the total number of features, rounded up.
    Fraction(f64),
}
62
63/// Resolve the `MaxFeatures` strategy to a concrete number.
64fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
65    let result = match strategy {
66        MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
67        MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
68        MaxFeatures::All => n_features,
69        MaxFeatures::Fixed(n) => n.min(n_features),
70        MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
71    };
72    result.max(1).min(n_features)
73}
74
75/// Internal tree parameter struct reused from decision_tree.
76///
77/// Re-created here to avoid leaking internal details; the crate-internal
78/// struct is the same shape.
79fn make_tree_params(
80    max_depth: Option<usize>,
81    min_samples_split: usize,
82    min_samples_leaf: usize,
83) -> decision_tree::TreeParams {
84    decision_tree::TreeParams {
85        max_depth,
86        min_samples_split,
87        min_samples_leaf,
88    }
89}
90
91// ---------------------------------------------------------------------------
92// RandomForestClassifier
93// ---------------------------------------------------------------------------
94
95/// Random forest classifier.
96///
97/// Builds an ensemble of decision tree classifiers, each trained on a
98/// bootstrap sample with a random subset of features considered at each split.
99/// Final predictions are made by majority vote.
100///
101/// # Type Parameters
102///
103/// - `F`: The floating-point type (`f32` or `f64`).
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct RandomForestClassifier<F> {
106    /// Number of trees in the forest.
107    pub n_estimators: usize,
108    /// Maximum depth of each tree. `None` means unlimited.
109    pub max_depth: Option<usize>,
110    /// Strategy for the number of features considered at each split.
111    pub max_features: MaxFeatures,
112    /// Minimum number of samples required to split an internal node.
113    pub min_samples_split: usize,
114    /// Minimum number of samples required in a leaf node.
115    pub min_samples_leaf: usize,
116    /// Random seed for reproducibility. `None` means non-deterministic.
117    pub random_state: Option<u64>,
118    /// Splitting criterion.
119    pub criterion: ClassificationCriterion,
120    _marker: std::marker::PhantomData<F>,
121}
122
123impl<F: Float> RandomForestClassifier<F> {
124    /// Create a new `RandomForestClassifier` with default settings.
125    ///
126    /// Defaults: `n_estimators = 100`, `max_depth = None`,
127    /// `max_features = Sqrt`, `min_samples_split = 2`,
128    /// `min_samples_leaf = 1`, `random_state = None`,
129    /// `criterion = Gini`.
130    #[must_use]
131    pub fn new() -> Self {
132        Self {
133            n_estimators: 100,
134            max_depth: None,
135            max_features: MaxFeatures::Sqrt,
136            min_samples_split: 2,
137            min_samples_leaf: 1,
138            random_state: None,
139            criterion: ClassificationCriterion::Gini,
140            _marker: std::marker::PhantomData,
141        }
142    }
143
144    /// Set the number of trees.
145    #[must_use]
146    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
147        self.n_estimators = n_estimators;
148        self
149    }
150
151    /// Set the maximum tree depth.
152    #[must_use]
153    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
154        self.max_depth = max_depth;
155        self
156    }
157
158    /// Set the maximum features strategy.
159    #[must_use]
160    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
161        self.max_features = max_features;
162        self
163    }
164
165    /// Set the minimum number of samples to split a node.
166    #[must_use]
167    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
168        self.min_samples_split = min_samples_split;
169        self
170    }
171
172    /// Set the minimum number of samples in a leaf.
173    #[must_use]
174    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
175        self.min_samples_leaf = min_samples_leaf;
176        self
177    }
178
179    /// Set the random seed for reproducibility.
180    #[must_use]
181    pub fn with_random_state(mut self, seed: u64) -> Self {
182        self.random_state = Some(seed);
183        self
184    }
185
186    /// Set the splitting criterion.
187    #[must_use]
188    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
189        self.criterion = criterion;
190        self
191    }
192}
193
194impl<F: Float> Default for RandomForestClassifier<F> {
195    fn default() -> Self {
196        Self::new()
197    }
198}
199
200// ---------------------------------------------------------------------------
201// FittedRandomForestClassifier
202// ---------------------------------------------------------------------------
203
204/// A fitted random forest classifier.
205///
206/// Stores the ensemble of fitted decision trees and aggregates their
207/// predictions by majority vote.
208#[derive(Debug, Clone)]
209pub struct FittedRandomForestClassifier<F> {
210    /// Individual tree node vectors.
211    trees: Vec<Vec<Node<F>>>,
212    /// Sorted unique class labels.
213    classes: Vec<usize>,
214    /// Number of features.
215    n_features: usize,
216    /// Per-feature importance scores (mean decrease in impurity, normalised).
217    feature_importances: Array1<F>,
218}
219
220impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for RandomForestClassifier<F> {
221    type Fitted = FittedRandomForestClassifier<F>;
222    type Error = FerroError;
223
224    /// Fit the random forest by building `n_estimators` decision trees in parallel.
225    ///
226    /// Each tree is trained on a bootstrap sample of the data, considering only
227    /// a random subset of features at each split.
228    ///
229    /// # Errors
230    ///
231    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
232    /// numbers of samples.
233    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
234    /// Returns [`FerroError::InvalidParameter`] if `n_estimators` is 0.
235    fn fit(
236        &self,
237        x: &Array2<F>,
238        y: &Array1<usize>,
239    ) -> Result<FittedRandomForestClassifier<F>, FerroError> {
240        let (n_samples, n_features) = x.dim();
241
242        if n_samples != y.len() {
243            return Err(FerroError::ShapeMismatch {
244                expected: vec![n_samples],
245                actual: vec![y.len()],
246                context: "y length must match number of samples in X".into(),
247            });
248        }
249        if n_samples == 0 {
250            return Err(FerroError::InsufficientSamples {
251                required: 1,
252                actual: 0,
253                context: "RandomForestClassifier requires at least one sample".into(),
254            });
255        }
256        if self.n_estimators == 0 {
257            return Err(FerroError::InvalidParameter {
258                name: "n_estimators".into(),
259                reason: "must be at least 1".into(),
260            });
261        }
262
263        // Determine unique classes.
264        let mut classes: Vec<usize> = y.iter().copied().collect();
265        classes.sort_unstable();
266        classes.dedup();
267        let n_classes = classes.len();
268
269        let y_mapped: Vec<usize> = y
270            .iter()
271            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
272            .collect();
273
274        let max_features_n = resolve_max_features(self.max_features, n_features);
275        let params = make_tree_params(
276            self.max_depth,
277            self.min_samples_split,
278            self.min_samples_leaf,
279        );
280        let criterion = self.criterion;
281
282        // Generate per-tree seeds sequentially for determinism, then dispatch in parallel.
283        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
284            let mut master_rng = StdRng::seed_from_u64(seed);
285            (0..self.n_estimators)
286                .map(|_| {
287                    use rand::RngCore;
288                    master_rng.next_u64()
289                })
290                .collect()
291        } else {
292            (0..self.n_estimators)
293                .map(|_| {
294                    use rand::RngCore;
295                    rand::rng().next_u64()
296                })
297                .collect()
298        };
299
300        // Build trees in parallel.
301        let trees: Vec<Vec<Node<F>>> = tree_seeds
302            .par_iter()
303            .map(|&seed| {
304                let mut rng = StdRng::seed_from_u64(seed);
305
306                // Bootstrap sample (with replacement).
307                let bootstrap_indices: Vec<usize> = (0..n_samples)
308                    .map(|_| {
309                        use rand::RngCore;
310                        (rng.next_u64() as usize) % n_samples
311                    })
312                    .collect();
313
314                // Random feature subset.
315                let feature_indices: Vec<usize> =
316                    rand_sample_indices(&mut rng, n_features, max_features_n).into_vec();
317
318                build_classification_tree_with_feature_subset(
319                    x,
320                    &y_mapped,
321                    n_classes,
322                    &bootstrap_indices,
323                    &feature_indices,
324                    &params,
325                    criterion,
326                )
327            })
328            .collect();
329
330        // Aggregate feature importances across trees.
331        let mut total_importances = Array1::<F>::zeros(n_features);
332        for tree_nodes in &trees {
333            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
334            total_importances = total_importances + tree_imp;
335        }
336        let imp_sum: F = total_importances
337            .iter()
338            .copied()
339            .fold(F::zero(), |a, b| a + b);
340        if imp_sum > F::zero() {
341            total_importances.mapv_inplace(|v| v / imp_sum);
342        }
343
344        Ok(FittedRandomForestClassifier {
345            trees,
346            classes,
347            n_features,
348            feature_importances: total_importances,
349        })
350    }
351}
352
353impl<F: Float + Send + Sync + 'static> FittedRandomForestClassifier<F> {
354    /// Returns a reference to the individual tree node vectors.
355    #[must_use]
356    pub fn trees(&self) -> &[Vec<Node<F>>] {
357        &self.trees
358    }
359
360    /// Returns the number of features the model was trained on.
361    #[must_use]
362    pub fn n_features(&self) -> usize {
363        self.n_features
364    }
365}
366
367impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedRandomForestClassifier<F> {
368    type Output = Array1<usize>;
369    type Error = FerroError;
370
371    /// Predict class labels by majority vote across all trees.
372    ///
373    /// # Errors
374    ///
375    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
376    /// not match the fitted model.
377    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
378        if x.ncols() != self.n_features {
379            return Err(FerroError::ShapeMismatch {
380                expected: vec![self.n_features],
381                actual: vec![x.ncols()],
382                context: "number of features must match fitted model".into(),
383            });
384        }
385
386        let n_samples = x.nrows();
387        let n_classes = self.classes.len();
388        let mut predictions = Array1::zeros(n_samples);
389
390        for i in 0..n_samples {
391            let row = x.row(i);
392            let mut votes = vec![0usize; n_classes];
393
394            for tree_nodes in &self.trees {
395                let leaf_idx = decision_tree::traverse(tree_nodes, &row);
396                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
397                    let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
398                    if class_idx < n_classes {
399                        votes[class_idx] += 1;
400                    }
401                }
402            }
403
404            let winner = votes
405                .iter()
406                .enumerate()
407                .max_by_key(|&(_, &count)| count)
408                .map_or(0, |(idx, _)| idx);
409            predictions[i] = self.classes[winner];
410        }
411
412        Ok(predictions)
413    }
414}
415
416impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
417    for FittedRandomForestClassifier<F>
418{
419    fn feature_importances(&self) -> &Array1<F> {
420        &self.feature_importances
421    }
422}
423
424impl<F: Float + Send + Sync + 'static> HasClasses for FittedRandomForestClassifier<F> {
425    fn classes(&self) -> &[usize] {
426        &self.classes
427    }
428
429    fn n_classes(&self) -> usize {
430        self.classes.len()
431    }
432}
433
434// Pipeline integration.
435impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
436    for RandomForestClassifier<F>
437{
438    fn fit_pipeline(
439        &self,
440        x: &Array2<F>,
441        y: &Array1<F>,
442    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
443        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
444        let fitted = self.fit(x, &y_usize)?;
445        Ok(Box::new(FittedForestClassifierPipelineAdapter(fitted)))
446    }
447}
448
449/// Pipeline adapter for `FittedRandomForestClassifier<F>`.
450struct FittedForestClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
451    FittedRandomForestClassifier<F>,
452);
453
454impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
455    for FittedForestClassifierPipelineAdapter<F>
456{
457    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
458        let preds = self.0.predict(x)?;
459        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
460    }
461}
462
463// ---------------------------------------------------------------------------
464// RandomForestRegressor
465// ---------------------------------------------------------------------------
466
467/// Random forest regressor.
468///
469/// Builds an ensemble of decision tree regressors, each trained on a
470/// bootstrap sample with a random subset of features considered at each split.
471/// Final predictions are the mean across all trees.
472///
473/// # Type Parameters
474///
475/// - `F`: The floating-point type (`f32` or `f64`).
476#[derive(Debug, Clone, Serialize, Deserialize)]
477pub struct RandomForestRegressor<F> {
478    /// Number of trees in the forest.
479    pub n_estimators: usize,
480    /// Maximum depth of each tree. `None` means unlimited.
481    pub max_depth: Option<usize>,
482    /// Strategy for the number of features considered at each split.
483    pub max_features: MaxFeatures,
484    /// Minimum number of samples required to split an internal node.
485    pub min_samples_split: usize,
486    /// Minimum number of samples required in a leaf node.
487    pub min_samples_leaf: usize,
488    /// Random seed for reproducibility. `None` means non-deterministic.
489    pub random_state: Option<u64>,
490    _marker: std::marker::PhantomData<F>,
491}
492
493impl<F: Float> RandomForestRegressor<F> {
494    /// Create a new `RandomForestRegressor` with default settings.
495    ///
496    /// Defaults: `n_estimators = 100`, `max_depth = None`,
497    /// `max_features = All`, `min_samples_split = 2`,
498    /// `min_samples_leaf = 1`, `random_state = None`.
499    #[must_use]
500    pub fn new() -> Self {
501        Self {
502            n_estimators: 100,
503            max_depth: None,
504            max_features: MaxFeatures::All,
505            min_samples_split: 2,
506            min_samples_leaf: 1,
507            random_state: None,
508            _marker: std::marker::PhantomData,
509        }
510    }
511
512    /// Set the number of trees.
513    #[must_use]
514    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
515        self.n_estimators = n_estimators;
516        self
517    }
518
519    /// Set the maximum tree depth.
520    #[must_use]
521    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
522        self.max_depth = max_depth;
523        self
524    }
525
526    /// Set the maximum features strategy.
527    #[must_use]
528    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
529        self.max_features = max_features;
530        self
531    }
532
533    /// Set the minimum number of samples to split a node.
534    #[must_use]
535    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
536        self.min_samples_split = min_samples_split;
537        self
538    }
539
540    /// Set the minimum number of samples in a leaf.
541    #[must_use]
542    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
543        self.min_samples_leaf = min_samples_leaf;
544        self
545    }
546
547    /// Set the random seed for reproducibility.
548    #[must_use]
549    pub fn with_random_state(mut self, seed: u64) -> Self {
550        self.random_state = Some(seed);
551        self
552    }
553}
554
555impl<F: Float> Default for RandomForestRegressor<F> {
556    fn default() -> Self {
557        Self::new()
558    }
559}
560
561// ---------------------------------------------------------------------------
562// FittedRandomForestRegressor
563// ---------------------------------------------------------------------------
564
565/// A fitted random forest regressor.
566///
567/// Stores the ensemble of fitted decision trees and aggregates their
568/// predictions by averaging.
569#[derive(Debug, Clone)]
570pub struct FittedRandomForestRegressor<F> {
571    /// Individual tree node vectors.
572    trees: Vec<Vec<Node<F>>>,
573    /// Number of features.
574    n_features: usize,
575    /// Per-feature importance scores (mean decrease in impurity, normalised).
576    feature_importances: Array1<F>,
577}
578
579impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for RandomForestRegressor<F> {
580    type Fitted = FittedRandomForestRegressor<F>;
581    type Error = FerroError;
582
583    /// Fit the random forest regressor.
584    ///
585    /// # Errors
586    ///
587    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
588    /// numbers of samples.
589    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
590    /// Returns [`FerroError::InvalidParameter`] if `n_estimators` is 0.
591    fn fit(
592        &self,
593        x: &Array2<F>,
594        y: &Array1<F>,
595    ) -> Result<FittedRandomForestRegressor<F>, FerroError> {
596        let (n_samples, n_features) = x.dim();
597
598        if n_samples != y.len() {
599            return Err(FerroError::ShapeMismatch {
600                expected: vec![n_samples],
601                actual: vec![y.len()],
602                context: "y length must match number of samples in X".into(),
603            });
604        }
605        if n_samples == 0 {
606            return Err(FerroError::InsufficientSamples {
607                required: 1,
608                actual: 0,
609                context: "RandomForestRegressor requires at least one sample".into(),
610            });
611        }
612        if self.n_estimators == 0 {
613            return Err(FerroError::InvalidParameter {
614                name: "n_estimators".into(),
615                reason: "must be at least 1".into(),
616            });
617        }
618
619        let max_features_n = resolve_max_features(self.max_features, n_features);
620        let params = make_tree_params(
621            self.max_depth,
622            self.min_samples_split,
623            self.min_samples_leaf,
624        );
625
626        // Generate per-tree seeds sequentially.
627        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
628            let mut master_rng = StdRng::seed_from_u64(seed);
629            (0..self.n_estimators)
630                .map(|_| {
631                    use rand::RngCore;
632                    master_rng.next_u64()
633                })
634                .collect()
635        } else {
636            (0..self.n_estimators)
637                .map(|_| {
638                    use rand::RngCore;
639                    rand::rng().next_u64()
640                })
641                .collect()
642        };
643
644        // Build trees in parallel.
645        let trees: Vec<Vec<Node<F>>> = tree_seeds
646            .par_iter()
647            .map(|&seed| {
648                let mut rng = StdRng::seed_from_u64(seed);
649
650                let bootstrap_indices: Vec<usize> = (0..n_samples)
651                    .map(|_| {
652                        use rand::RngCore;
653                        (rng.next_u64() as usize) % n_samples
654                    })
655                    .collect();
656
657                let feature_indices: Vec<usize> =
658                    rand_sample_indices(&mut rng, n_features, max_features_n).into_vec();
659
660                build_regression_tree_with_feature_subset(
661                    x,
662                    y,
663                    &bootstrap_indices,
664                    &feature_indices,
665                    &params,
666                )
667            })
668            .collect();
669
670        // Aggregate feature importances.
671        let mut total_importances = Array1::<F>::zeros(n_features);
672        for tree_nodes in &trees {
673            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
674            total_importances = total_importances + tree_imp;
675        }
676        let imp_sum: F = total_importances
677            .iter()
678            .copied()
679            .fold(F::zero(), |a, b| a + b);
680        if imp_sum > F::zero() {
681            total_importances.mapv_inplace(|v| v / imp_sum);
682        }
683
684        Ok(FittedRandomForestRegressor {
685            trees,
686            n_features,
687            feature_importances: total_importances,
688        })
689    }
690}
691
692impl<F: Float + Send + Sync + 'static> FittedRandomForestRegressor<F> {
693    /// Returns a reference to the individual tree node vectors.
694    #[must_use]
695    pub fn trees(&self) -> &[Vec<Node<F>>] {
696        &self.trees
697    }
698
699    /// Returns the number of features the model was trained on.
700    #[must_use]
701    pub fn n_features(&self) -> usize {
702        self.n_features
703    }
704}
705
706impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedRandomForestRegressor<F> {
707    type Output = Array1<F>;
708    type Error = FerroError;
709
710    /// Predict target values by averaging across all trees.
711    ///
712    /// # Errors
713    ///
714    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
715    /// not match the fitted model.
716    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
717        if x.ncols() != self.n_features {
718            return Err(FerroError::ShapeMismatch {
719                expected: vec![self.n_features],
720                actual: vec![x.ncols()],
721                context: "number of features must match fitted model".into(),
722            });
723        }
724
725        let n_samples = x.nrows();
726        let n_trees_f = F::from(self.trees.len()).unwrap();
727        let mut predictions = Array1::zeros(n_samples);
728
729        for i in 0..n_samples {
730            let row = x.row(i);
731            let mut sum = F::zero();
732
733            for tree_nodes in &self.trees {
734                let leaf_idx = decision_tree::traverse(tree_nodes, &row);
735                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
736                    sum = sum + value;
737                }
738            }
739
740            predictions[i] = sum / n_trees_f;
741        }
742
743        Ok(predictions)
744    }
745}
746
747impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedRandomForestRegressor<F> {
748    fn feature_importances(&self) -> &Array1<F> {
749        &self.feature_importances
750    }
751}
752
753// Pipeline integration.
754impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for RandomForestRegressor<F> {
755    fn fit_pipeline(
756        &self,
757        x: &Array2<F>,
758        y: &Array1<F>,
759    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
760        let fitted = self.fit(x, y)?;
761        Ok(Box::new(fitted))
762    }
763}
764
765impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F>
766    for FittedRandomForestRegressor<F>
767{
768    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
769        self.predict(x)
770    }
771}
772
773// ---------------------------------------------------------------------------
774// Tests
775// ---------------------------------------------------------------------------
776
777#[cfg(test)]
778mod tests {
779    use super::*;
780    use approx::assert_relative_eq;
781    use ndarray::array;
782
783    // -- Classifier tests --
784
785    #[test]
786    fn test_forest_classifier_simple() {
787        let x = Array2::from_shape_vec(
788            (8, 2),
789            vec![
790                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
791            ],
792        )
793        .unwrap();
794        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
795
796        let model = RandomForestClassifier::<f64>::new()
797            .with_n_estimators(20)
798            .with_random_state(42);
799        let fitted = model.fit(&x, &y).unwrap();
800        let preds = fitted.predict(&x).unwrap();
801
802        assert_eq!(preds.len(), 8);
803        for i in 0..4 {
804            assert_eq!(preds[i], 0);
805        }
806        for i in 4..8 {
807            assert_eq!(preds[i], 1);
808        }
809    }
810
811    #[test]
812    fn test_forest_classifier_reproducibility() {
813        let x = Array2::from_shape_vec(
814            (8, 2),
815            vec![
816                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
817            ],
818        )
819        .unwrap();
820        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
821
822        let model = RandomForestClassifier::<f64>::new()
823            .with_n_estimators(10)
824            .with_random_state(123);
825
826        let fitted1 = model.fit(&x, &y).unwrap();
827        let fitted2 = model.fit(&x, &y).unwrap();
828
829        let preds1 = fitted1.predict(&x).unwrap();
830        let preds2 = fitted2.predict(&x).unwrap();
831
832        assert_eq!(preds1, preds2);
833    }
834
835    #[test]
836    fn test_forest_classifier_feature_importances() {
837        let x = Array2::from_shape_vec(
838            (10, 3),
839            vec![
840                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
841                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
842            ],
843        )
844        .unwrap();
845        let y = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
846
847        let model = RandomForestClassifier::<f64>::new()
848            .with_n_estimators(20)
849            .with_max_features(MaxFeatures::All)
850            .with_random_state(42);
851        let fitted = model.fit(&x, &y).unwrap();
852        let importances = fitted.feature_importances();
853
854        assert_eq!(importances.len(), 3);
855        assert!(importances[0] > importances[1]);
856        assert!(importances[0] > importances[2]);
857    }
858
859    #[test]
860    fn test_forest_classifier_has_classes() {
861        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
862        let y = array![0, 1, 2, 0, 1, 2];
863
864        let model = RandomForestClassifier::<f64>::new()
865            .with_n_estimators(5)
866            .with_random_state(0);
867        let fitted = model.fit(&x, &y).unwrap();
868
869        assert_eq!(fitted.classes(), &[0, 1, 2]);
870        assert_eq!(fitted.n_classes(), 3);
871    }
872
873    #[test]
874    fn test_forest_classifier_shape_mismatch_fit() {
875        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
876        let y = array![0, 1];
877
878        let model = RandomForestClassifier::<f64>::new().with_n_estimators(5);
879        assert!(model.fit(&x, &y).is_err());
880    }
881
/// Predicting must fail when the input has three columns but the forest was
/// trained on two.
#[test]
fn test_forest_classifier_shape_mismatch_predict() {
    let train_x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
    let train_y = array![0, 0, 1, 1];

    let forest = RandomForestClassifier::<f64>::new()
        .with_n_estimators(5)
        .with_random_state(0)
        .fit(&train_x, &train_y)
        .unwrap();

    // Two rows, three features: one more column than the training data.
    let wrong_width = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
    let outcome = forest.predict(&wrong_width);
    assert!(outcome.is_err());
}
896
/// Fitting on a dataset with zero samples must return an error.
#[test]
fn test_forest_classifier_empty_data() {
    let no_rows = Array2::<f64>::zeros((0, 2));
    let no_labels = Array1::<usize>::zeros(0);

    let forest = RandomForestClassifier::<f64>::new().with_n_estimators(5);
    let outcome = forest.fit(&no_rows, &no_labels);
    assert!(outcome.is_err());
}
905
/// A forest configured with zero estimators must refuse to fit.
#[test]
fn test_forest_classifier_zero_estimators() {
    let features = array![[1.0], [2.0], [3.0], [4.0]];
    let labels = array![0, 0, 1, 1];

    let outcome = RandomForestClassifier::<f64>::new()
        .with_n_estimators(0)
        .fit(&features, &labels);
    assert!(outcome.is_err());
}
914
/// A forest of exactly one tree still fits and yields one prediction per sample.
#[test]
fn test_forest_classifier_single_tree() {
    let features = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
    let labels = array![0, 0, 0, 1, 1, 1];

    let forest = RandomForestClassifier::<f64>::new()
        .with_n_estimators(1)
        .with_max_features(MaxFeatures::All)
        .with_random_state(42)
        .fit(&features, &labels)
        .unwrap();

    let predicted = forest.predict(&features).unwrap();
    assert_eq!(predicted.len(), 6);
}
929
/// The classifier can be fitted and queried through the pipeline entry points
/// (`fit_pipeline` / `predict_pipeline`) using floating-point targets.
#[test]
fn test_forest_classifier_pipeline_integration() {
    let features = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
    // The pipeline entry point is given targets as floats rather than integers.
    let targets = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

    let forest = RandomForestClassifier::<f64>::new()
        .with_n_estimators(5)
        .with_random_state(42);

    let trained = forest.fit_pipeline(&features, &targets).unwrap();
    let predicted = trained.predict_pipeline(&features).unwrap();
    assert_eq!(predicted.len(), 6);
}
942
/// Limiting every tree to depth 1 still produces one prediction per sample.
#[test]
fn test_forest_classifier_max_depth() {
    let features = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]];
    let labels = array![0, 0, 0, 0, 1, 1, 1, 1];

    let forest = RandomForestClassifier::<f64>::new()
        .with_n_estimators(10)
        .with_max_depth(Some(1))
        .with_max_features(MaxFeatures::All)
        .with_random_state(42)
        .fit(&features, &labels)
        .unwrap();

    let predicted = forest.predict(&features).unwrap();
    assert_eq!(predicted.len(), 8);
}
959
960    // -- Regressor tests --
961
/// On a step-shaped target (1.0 for the first four samples, 5.0 for the last
/// four) the forest's predictions must fall on the correct side of 3.0.
#[test]
fn test_forest_regressor_simple() {
    let features = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]];
    let targets = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

    let forest = RandomForestRegressor::<f64>::new()
        .with_n_estimators(50)
        .with_random_state(42)
        .fit(&features, &targets)
        .unwrap();
    let predicted = forest.predict(&features).unwrap();

    assert_eq!(predicted.len(), 8);

    // Low half of the step: samples 0..4.
    for value in predicted.iter().take(4) {
        assert!(*value < 3.0, "Expected ~1.0, got {}", value);
    }
    // High half of the step: samples 4..8.
    for value in predicted.iter().skip(4).take(4) {
        assert!(*value > 3.0, "Expected ~5.0, got {}", value);
    }
}
982
/// Two fits of the same seeded configuration on the same data must give
/// matching predictions (within 1e-10).
#[test]
fn test_forest_regressor_reproducibility() {
    let features = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
    let targets = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

    let forest = RandomForestRegressor::<f64>::new()
        .with_n_estimators(10)
        .with_random_state(99);

    let first_fit = forest.fit(&features, &targets).unwrap();
    let second_fit = forest.fit(&features, &targets).unwrap();

    let first_preds = first_fit.predict(&features).unwrap();
    let second_preds = second_fit.predict(&features).unwrap();

    first_preds
        .iter()
        .zip(second_preds.iter())
        .for_each(|(a, b)| {
            assert_relative_eq!(*a, *b, epsilon = 1e-10);
        });
}
1002
/// With an informative first column and an all-zero second column, the first
/// feature must receive the larger importance.
#[test]
fn test_forest_regressor_feature_importances() {
    // Column 0 varies with the sample; column 1 is constant zero.
    let features = array![
        [1.0, 0.0],
        [2.0, 0.0],
        [3.0, 0.0],
        [4.0, 0.0],
        [5.0, 0.0],
        [6.0, 0.0],
        [7.0, 0.0],
        [8.0, 0.0],
    ];
    let targets = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

    let forest = RandomForestRegressor::<f64>::new()
        .with_n_estimators(20)
        .with_max_features(MaxFeatures::All)
        .with_random_state(42)
        .fit(&features, &targets)
        .unwrap();
    let weights = forest.feature_importances();

    assert_eq!(weights.len(), 2);
    assert!(weights[0] > weights[1]);
}
1024
/// Fitting must fail when `x` has three rows but only two targets are supplied.
#[test]
fn test_forest_regressor_shape_mismatch_fit() {
    let features = array![[1.0], [2.0], [3.0]];
    let targets = array![1.0, 2.0];

    let outcome = RandomForestRegressor::<f64>::new()
        .with_n_estimators(5)
        .fit(&features, &targets);
    assert!(outcome.is_err());
}
1033
/// Predicting must fail when the input has three columns but the forest was
/// trained on two.
#[test]
fn test_forest_regressor_shape_mismatch_predict() {
    let train_x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
    let train_y = array![1.0, 2.0, 3.0, 4.0];

    let forest = RandomForestRegressor::<f64>::new()
        .with_n_estimators(5)
        .with_random_state(0)
        .fit(&train_x, &train_y)
        .unwrap();

    // Two rows, three features: one more column than the training data.
    let wrong_width = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
    let outcome = forest.predict(&wrong_width);
    assert!(outcome.is_err());
}
1048
/// Fitting on a dataset with zero samples must return an error.
#[test]
fn test_forest_regressor_empty_data() {
    let no_rows = Array2::<f64>::zeros((0, 2));
    let no_targets = Array1::<f64>::zeros(0);

    let forest = RandomForestRegressor::<f64>::new().with_n_estimators(5);
    let outcome = forest.fit(&no_rows, &no_targets);
    assert!(outcome.is_err());
}
1057
/// A forest configured with zero estimators must refuse to fit.
#[test]
fn test_forest_regressor_zero_estimators() {
    let features = array![[1.0], [2.0], [3.0], [4.0]];
    let targets = array![1.0, 2.0, 3.0, 4.0];

    let outcome = RandomForestRegressor::<f64>::new()
        .with_n_estimators(0)
        .fit(&features, &targets);
    assert!(outcome.is_err());
}
1066
/// The regressor can be fitted and queried through the pipeline entry points
/// (`fit_pipeline` / `predict_pipeline`).
#[test]
fn test_forest_regressor_pipeline_integration() {
    let features = array![[1.0], [2.0], [3.0], [4.0]];
    let targets = array![1.0, 2.0, 3.0, 4.0];

    let forest = RandomForestRegressor::<f64>::new()
        .with_n_estimators(5)
        .with_random_state(42);

    let trained = forest.fit_pipeline(&features, &targets).unwrap();
    let predicted = trained.predict_pipeline(&features).unwrap();
    assert_eq!(predicted.len(), 4);
}
1079
/// Every `MaxFeatures` strategy must fit and predict on a four-feature dataset,
/// returning one prediction per sample.
#[test]
fn test_forest_regressor_max_features_strategies() {
    // Eight samples, four features; each row is the previous row plus one.
    let features = array![
        [1.0, 2.0, 3.0, 4.0],
        [2.0, 3.0, 4.0, 5.0],
        [3.0, 4.0, 5.0, 6.0],
        [4.0, 5.0, 6.0, 7.0],
        [5.0, 6.0, 7.0, 8.0],
        [6.0, 7.0, 8.0, 9.0],
        [7.0, 8.0, 9.0, 10.0],
        [8.0, 9.0, 10.0, 11.0],
    ];
    let targets = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

    let strategies = [
        MaxFeatures::Sqrt,
        MaxFeatures::Log2,
        MaxFeatures::All,
        MaxFeatures::Fixed(2),
        MaxFeatures::Fraction(0.5),
    ];

    for &strategy in &strategies {
        let forest = RandomForestRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_max_features(strategy)
            .with_random_state(42)
            .fit(&features, &targets)
            .unwrap();
        let predicted = forest.predict(&features).unwrap();
        assert_eq!(predicted.len(), 8);
    }
}
1108
1109    // -- MaxFeatures resolution tests --
1110
/// `MaxFeatures::Sqrt` is expected to resolve to 3 of 9, 4 of 10 and 1 of 1
/// features.
#[test]
fn test_resolve_max_features_sqrt() {
    // (total feature count, expected resolved count)
    for &(n_features, expected) in &[(9, 3), (10, 4), (1, 1)] {
        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, n_features), expected);
    }
}
1117
/// `MaxFeatures::Log2` is expected to resolve to 3 of 8 and 1 of 1 features.
#[test]
fn test_resolve_max_features_log2() {
    // (total feature count, expected resolved count)
    for &(n_features, expected) in &[(8, 3), (1, 1)] {
        assert_eq!(resolve_max_features(MaxFeatures::Log2, n_features), expected);
    }
}
1123
/// `MaxFeatures::All` is expected to resolve to the full feature count.
#[test]
fn test_resolve_max_features_all() {
    for &n_features in &[10, 1] {
        assert_eq!(resolve_max_features(MaxFeatures::All, n_features), n_features);
    }
}
1129
/// `MaxFeatures::Fixed` is expected to resolve to the requested count, or to
/// the total of 10 when 20 are requested.
#[test]
fn test_resolve_max_features_fixed() {
    // (requested count, expected resolved count) out of 10 features
    for &(requested, expected) in &[(3, 3), (20, 10)] {
        assert_eq!(
            resolve_max_features(MaxFeatures::Fixed(requested), 10),
            expected
        );
    }
}
1135
/// `MaxFeatures::Fraction` is expected to resolve to 5 of 10 features for 0.5
/// and 1 of 10 for 0.1.
#[test]
fn test_resolve_max_features_fraction() {
    // (fraction, expected resolved count) out of 10 features
    for &(fraction, expected) in &[(0.5, 5), (0.1, 1)] {
        assert_eq!(
            resolve_max_features(MaxFeatures::Fraction(fraction), 10),
            expected
        );
    }
}
1141
/// The classifier also works with `f32` features, yielding one prediction per
/// sample.
#[test]
fn test_forest_classifier_f32_support() {
    let features = array![[1.0f32], [2.0], [3.0], [4.0], [5.0], [6.0]];
    let labels = array![0, 0, 0, 1, 1, 1];

    let forest = RandomForestClassifier::<f32>::new()
        .with_n_estimators(5)
        .with_random_state(42)
        .fit(&features, &labels)
        .unwrap();

    let predicted = forest.predict(&features).unwrap();
    assert_eq!(predicted.len(), 6);
}
1154
/// The regressor also works with `f32` features and targets, yielding one
/// prediction per sample.
#[test]
fn test_forest_regressor_f32_support() {
    let features = array![[1.0f32], [2.0], [3.0], [4.0]];
    let targets = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);

    let forest = RandomForestRegressor::<f32>::new()
        .with_n_estimators(5)
        .with_random_state(42)
        .fit(&features, &targets)
        .unwrap();

    let predicted = forest.predict(&features).unwrap();
    assert_eq!(predicted.len(), 4);
}
1167}