Skip to main content

ferrolearn_tree/
random_forest.rs

1//! Random forest classifiers and regressors.
2//!
3//! This module provides [`RandomForestClassifier`] and [`RandomForestRegressor`],
4//! which build ensembles of decision trees using bootstrap sampling and random
5//! feature subsets (bagging). Trees are built in parallel via `rayon`.
6//!
7//! # Examples
8//!
9//! ```
10//! use ferrolearn_tree::RandomForestClassifier;
11//! use ferrolearn_core::{Fit, Predict};
12//! use ndarray::{array, Array1, Array2};
13//!
14//! let x = Array2::from_shape_vec((8, 2), vec![
15//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,  4.0, 4.0,
16//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,  8.0, 9.0,
17//! ]).unwrap();
18//! let y = array![0, 0, 0, 0, 1, 1, 1, 1];
19//!
20//! let model = RandomForestClassifier::<f64>::new()
21//!     .with_n_estimators(10)
22//!     .with_random_state(42);
23//! let fitted = model.fit(&x, &y).unwrap();
24//! let preds = fitted.predict(&x).unwrap();
25//! ```
26
27use ferrolearn_core::error::FerroError;
28use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
29use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
30use ferrolearn_core::traits::{Fit, Predict};
31use ndarray::{Array1, Array2};
32use num_traits::{Float, FromPrimitive, ToPrimitive};
33use rand::SeedableRng;
34use rand::rngs::StdRng;
35use rayon::prelude::*;
36use serde::{Deserialize, Serialize};
37
38use crate::decision_tree::{
39    self, ClassificationCriterion, Node, build_classification_tree_per_split_features,
40    build_regression_tree_per_split_features, compute_feature_importances,
41};
42
43// ---------------------------------------------------------------------------
44// MaxFeatures
45// ---------------------------------------------------------------------------
46
/// Strategy for selecting the number of features considered at each split.
///
/// The strategy is converted to a concrete count at fit time. Fractional
/// results are rounded *up*, and the final count is clamped into
/// `1..=n_features`. The count is only 0 when the training data has no
/// features at all.
///
/// Note: the variant order is part of the serialized representation (for
/// positional serde formats), so new variants should be appended.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum MaxFeatures {
    /// Use `ceil(sqrt(n_features))` features (default for classifiers).
    Sqrt,
    /// Use `ceil(log2(n_features))` features, never fewer than 1.
    Log2,
    /// Use all features (default for regressors).
    All,
    /// Use exactly this many features, capped at `n_features`; `Fixed(0)`
    /// is raised to 1.
    Fixed(usize),
    /// Use `ceil(fraction * n_features)` features. Because the result is
    /// clamped, values above `1.0` behave like [`MaxFeatures::All`] and
    /// non-positive (or NaN) values select a single feature.
    Fraction(f64),
}
61
62/// Resolve the `MaxFeatures` strategy to a concrete number.
63fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
64    let result = match strategy {
65        MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
66        MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
67        MaxFeatures::All => n_features,
68        MaxFeatures::Fixed(n) => n.min(n_features),
69        MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
70    };
71    result.max(1).min(n_features)
72}
73
74/// Internal tree parameter struct reused from decision_tree.
75///
76/// Re-created here to avoid leaking internal details; the crate-internal
77/// struct is the same shape.
78fn make_tree_params(
79    max_depth: Option<usize>,
80    min_samples_split: usize,
81    min_samples_leaf: usize,
82) -> decision_tree::TreeParams {
83    decision_tree::TreeParams {
84        max_depth,
85        min_samples_split,
86        min_samples_leaf,
87    }
88}
89
90// ---------------------------------------------------------------------------
91// RandomForestClassifier
92// ---------------------------------------------------------------------------
93
/// Random forest classifier.
///
/// Builds an ensemble of decision tree classifiers, each trained on a
/// bootstrap sample with a random subset of features considered at each split.
/// Final predictions are made by majority vote.
///
/// Hyper-parameters are plain public fields; they can also be set through
/// the `with_*` builder methods. The field order is part of the serialized
/// form for positional serde formats.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomForestClassifier<F> {
    /// Number of trees in the forest. `fit` rejects 0.
    pub n_estimators: usize,
    /// Maximum depth of each tree. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Strategy for the number of features considered at each split,
    /// resolved against the training feature count at fit time.
    pub max_features: MaxFeatures,
    /// Minimum number of samples required to split an internal node
    /// (forwarded to every tree).
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node (forwarded to
    /// every tree).
    pub min_samples_leaf: usize,
    /// Random seed. `Some(seed)` derives every tree's seed from one seeded
    /// master RNG; `None` draws per-tree seeds from the thread-local RNG.
    pub random_state: Option<u64>,
    /// Splitting criterion used by every tree (Gini by default).
    pub criterion: ClassificationCriterion,
    // Ties the estimator to the float type `F` without storing any `F`.
    _marker: std::marker::PhantomData<F>,
}
121
122impl<F: Float> RandomForestClassifier<F> {
123    /// Create a new `RandomForestClassifier` with default settings.
124    ///
125    /// Defaults: `n_estimators = 100`, `max_depth = None`,
126    /// `max_features = Sqrt`, `min_samples_split = 2`,
127    /// `min_samples_leaf = 1`, `random_state = None`,
128    /// `criterion = Gini`.
129    #[must_use]
130    pub fn new() -> Self {
131        Self {
132            n_estimators: 100,
133            max_depth: None,
134            max_features: MaxFeatures::Sqrt,
135            min_samples_split: 2,
136            min_samples_leaf: 1,
137            random_state: None,
138            criterion: ClassificationCriterion::Gini,
139            _marker: std::marker::PhantomData,
140        }
141    }
142
143    /// Set the number of trees.
144    #[must_use]
145    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
146        self.n_estimators = n_estimators;
147        self
148    }
149
150    /// Set the maximum tree depth.
151    #[must_use]
152    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
153        self.max_depth = max_depth;
154        self
155    }
156
157    /// Set the maximum features strategy.
158    #[must_use]
159    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
160        self.max_features = max_features;
161        self
162    }
163
164    /// Set the minimum number of samples to split a node.
165    #[must_use]
166    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
167        self.min_samples_split = min_samples_split;
168        self
169    }
170
171    /// Set the minimum number of samples in a leaf.
172    #[must_use]
173    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
174        self.min_samples_leaf = min_samples_leaf;
175        self
176    }
177
178    /// Set the random seed for reproducibility.
179    #[must_use]
180    pub fn with_random_state(mut self, seed: u64) -> Self {
181        self.random_state = Some(seed);
182        self
183    }
184
185    /// Set the splitting criterion.
186    #[must_use]
187    pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
188        self.criterion = criterion;
189        self
190    }
191}
192
193impl<F: Float> Default for RandomForestClassifier<F> {
194    fn default() -> Self {
195        Self::new()
196    }
197}
198
199// ---------------------------------------------------------------------------
200// FittedRandomForestClassifier
201// ---------------------------------------------------------------------------
202
/// A fitted random forest classifier.
///
/// Stores the ensemble of fitted decision trees and aggregates their
/// predictions by majority vote (`predict`) or by averaging per-tree class
/// distributions (`predict_proba`).
#[derive(Debug, Clone)]
pub struct FittedRandomForestClassifier<F> {
    /// One flat node vector per fitted tree, in the order of the per-tree
    /// seeds drawn during `fit`.
    trees: Vec<Vec<Node<F>>>,
    /// Sorted unique class labels seen in `y` during `fit`. Tree leaves
    /// refer to classes by their index into this list.
    classes: Vec<usize>,
    /// Number of features seen during `fit`; `predict` rejects other widths.
    n_features: usize,
    /// Per-feature importance scores: the per-tree values returned by
    /// `compute_feature_importances`, summed over trees and then normalised
    /// to sum to 1 (left unnormalised if the total is not positive).
    feature_importances: Array1<F>,
}
218
219impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for RandomForestClassifier<F> {
220    type Fitted = FittedRandomForestClassifier<F>;
221    type Error = FerroError;
222
223    /// Fit the random forest by building `n_estimators` decision trees in parallel.
224    ///
225    /// Each tree is trained on a bootstrap sample of the data, considering only
226    /// a random subset of features at each split.
227    ///
228    /// # Errors
229    ///
230    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
231    /// numbers of samples.
232    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
233    /// Returns [`FerroError::InvalidParameter`] if `n_estimators` is 0.
234    fn fit(
235        &self,
236        x: &Array2<F>,
237        y: &Array1<usize>,
238    ) -> Result<FittedRandomForestClassifier<F>, FerroError> {
239        let (n_samples, n_features) = x.dim();
240
241        if n_samples != y.len() {
242            return Err(FerroError::ShapeMismatch {
243                expected: vec![n_samples],
244                actual: vec![y.len()],
245                context: "y length must match number of samples in X".into(),
246            });
247        }
248        if n_samples == 0 {
249            return Err(FerroError::InsufficientSamples {
250                required: 1,
251                actual: 0,
252                context: "RandomForestClassifier requires at least one sample".into(),
253            });
254        }
255        if self.n_estimators == 0 {
256            return Err(FerroError::InvalidParameter {
257                name: "n_estimators".into(),
258                reason: "must be at least 1".into(),
259            });
260        }
261
262        // Determine unique classes.
263        let mut classes: Vec<usize> = y.iter().copied().collect();
264        classes.sort_unstable();
265        classes.dedup();
266        let n_classes = classes.len();
267
268        let y_mapped: Vec<usize> = y
269            .iter()
270            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
271            .collect();
272
273        let max_features_n = resolve_max_features(self.max_features, n_features);
274        let params = make_tree_params(
275            self.max_depth,
276            self.min_samples_split,
277            self.min_samples_leaf,
278        );
279        let criterion = self.criterion;
280
281        // Generate per-tree seeds sequentially for determinism, then dispatch in parallel.
282        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
283            let mut master_rng = StdRng::seed_from_u64(seed);
284            (0..self.n_estimators)
285                .map(|_| {
286                    use rand::RngCore;
287                    master_rng.next_u64()
288                })
289                .collect()
290        } else {
291            (0..self.n_estimators)
292                .map(|_| {
293                    use rand::RngCore;
294                    rand::rng().next_u64()
295                })
296                .collect()
297        };
298
299        // Build trees in parallel.
300        //
301        // Each tree gets:
302        //   - a bootstrap sample of rows (Breiman bagging),
303        //   - per-split random feature sampling of size `max_features_n`
304        //     drawn afresh at every split node (Breiman 2001 RF, sklearn
305        //     parity). The previous implementation pre-sampled a single
306        //     `max_features_n` feature subset per tree which severely
307        //     limited each tree's capacity at large p.
308        let trees: Vec<Vec<Node<F>>> = tree_seeds
309            .par_iter()
310            .map(|&seed| {
311                let mut bootstrap_rng = StdRng::seed_from_u64(seed);
312
313                let bootstrap_indices: Vec<usize> = (0..n_samples)
314                    .map(|_| {
315                        use rand::RngCore;
316                        (bootstrap_rng.next_u64() as usize) % n_samples
317                    })
318                    .collect();
319
320                // Use a separate, derived seed for the per-split feature
321                // RNG so that bootstrap sampling and feature sampling are
322                // statistically independent.
323                use rand::RngCore;
324                let split_seed = bootstrap_rng.next_u64();
325
326                build_classification_tree_per_split_features(
327                    x,
328                    &y_mapped,
329                    n_classes,
330                    &bootstrap_indices,
331                    max_features_n,
332                    &params,
333                    criterion,
334                    split_seed,
335                )
336            })
337            .collect();
338
339        // Aggregate feature importances across trees.
340        let mut total_importances = Array1::<F>::zeros(n_features);
341        for tree_nodes in &trees {
342            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
343            total_importances = total_importances + tree_imp;
344        }
345        let imp_sum: F = total_importances
346            .iter()
347            .copied()
348            .fold(F::zero(), |a, b| a + b);
349        if imp_sum > F::zero() {
350            total_importances.mapv_inplace(|v| v / imp_sum);
351        }
352
353        Ok(FittedRandomForestClassifier {
354            trees,
355            classes,
356            n_features,
357            feature_importances: total_importances,
358        })
359    }
360}
361
362impl<F: Float + Send + Sync + 'static> FittedRandomForestClassifier<F> {
363    /// Returns a reference to the individual tree node vectors.
364    #[must_use]
365    pub fn trees(&self) -> &[Vec<Node<F>>] {
366        &self.trees
367    }
368
369    /// Returns the number of features the model was trained on.
370    #[must_use]
371    pub fn n_features(&self) -> usize {
372        self.n_features
373    }
374
375    /// Mean accuracy on the given test data and labels.
376    /// Equivalent to sklearn's `ClassifierMixin.score`.
377    ///
378    /// # Errors
379    ///
380    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
381    /// the feature count does not match the training data.
382    pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
383        if x.nrows() != y.len() {
384            return Err(FerroError::ShapeMismatch {
385                expected: vec![x.nrows()],
386                actual: vec![y.len()],
387                context: "y length must match number of samples in X".into(),
388            });
389        }
390        let preds = self.predict(x)?;
391        Ok(crate::mean_accuracy(&preds, y))
392    }
393
394    /// Predict class probabilities for each sample by averaging per-tree
395    /// class distributions across the forest. Equivalent to sklearn's
396    /// `RandomForestClassifier.predict_proba`.
397    ///
398    /// Returns an `(n_samples, n_classes)` array. Each row sums to 1.
399    /// When a leaf does not carry a class distribution, it contributes a
400    /// one-hot vote at the leaf's predicted class.
401    ///
402    /// # Errors
403    ///
404    /// Returns [`FerroError::ShapeMismatch`] if the number of features
405    /// does not match the training data.
406    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
407        if x.ncols() != self.n_features {
408            return Err(FerroError::ShapeMismatch {
409                expected: vec![self.n_features],
410                actual: vec![x.ncols()],
411                context: "number of features must match fitted model".into(),
412            });
413        }
414        let n_samples = x.nrows();
415        let n_classes = self.classes.len();
416        let n_trees_f = F::from(self.trees.len()).unwrap();
417        let mut proba = Array2::<F>::zeros((n_samples, n_classes));
418
419        for i in 0..n_samples {
420            let row = x.row(i);
421            for tree_nodes in &self.trees {
422                let leaf_idx = decision_tree::traverse(tree_nodes, &row);
423                match &tree_nodes[leaf_idx] {
424                    Node::Leaf {
425                        class_distribution: Some(dist),
426                        ..
427                    } => {
428                        for (j, &p) in dist.iter().enumerate().take(n_classes) {
429                            proba[[i, j]] = proba[[i, j]] + p;
430                        }
431                    }
432                    Node::Leaf { value, .. } => {
433                        let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
434                        if class_idx < n_classes {
435                            proba[[i, class_idx]] = proba[[i, class_idx]] + F::one();
436                        }
437                    }
438                    _ => {}
439                }
440            }
441            for j in 0..n_classes {
442                proba[[i, j]] = proba[[i, j]] / n_trees_f;
443            }
444        }
445        Ok(proba)
446    }
447
448    /// Element-wise log of [`predict_proba`](Self::predict_proba). Mirrors
449    /// sklearn's `ClassifierMixin.predict_log_proba`.
450    ///
451    /// # Errors
452    ///
453    /// Forwards any error from [`predict_proba`](Self::predict_proba).
454    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
455        let proba = self.predict_proba(x)?;
456        Ok(crate::log_proba(&proba))
457    }
458}
459
460impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedRandomForestClassifier<F> {
461    type Output = Array1<usize>;
462    type Error = FerroError;
463
464    /// Predict class labels by majority vote across all trees.
465    ///
466    /// # Errors
467    ///
468    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
469    /// not match the fitted model.
470    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
471        if x.ncols() != self.n_features {
472            return Err(FerroError::ShapeMismatch {
473                expected: vec![self.n_features],
474                actual: vec![x.ncols()],
475                context: "number of features must match fitted model".into(),
476            });
477        }
478
479        let n_samples = x.nrows();
480        let n_classes = self.classes.len();
481        let mut predictions = Array1::zeros(n_samples);
482
483        for i in 0..n_samples {
484            let row = x.row(i);
485            let mut votes = vec![0usize; n_classes];
486
487            for tree_nodes in &self.trees {
488                let leaf_idx = decision_tree::traverse(tree_nodes, &row);
489                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
490                    let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
491                    if class_idx < n_classes {
492                        votes[class_idx] += 1;
493                    }
494                }
495            }
496
497            let winner = votes
498                .iter()
499                .enumerate()
500                .max_by_key(|&(_, &count)| count)
501                .map_or(0, |(idx, _)| idx);
502            predictions[i] = self.classes[winner];
503        }
504
505        Ok(predictions)
506    }
507}
508
509impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
510    for FittedRandomForestClassifier<F>
511{
512    fn feature_importances(&self) -> &Array1<F> {
513        &self.feature_importances
514    }
515}
516
517impl<F: Float + Send + Sync + 'static> HasClasses for FittedRandomForestClassifier<F> {
518    fn classes(&self) -> &[usize] {
519        &self.classes
520    }
521
522    fn n_classes(&self) -> usize {
523        self.classes.len()
524    }
525}
526
527// Pipeline integration.
528impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
529    for RandomForestClassifier<F>
530{
531    fn fit_pipeline(
532        &self,
533        x: &Array2<F>,
534        y: &Array1<F>,
535    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
536        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
537        let fitted = self.fit(x, &y_usize)?;
538        Ok(Box::new(FittedForestClassifierPipelineAdapter(fitted)))
539    }
540}
541
542/// Pipeline adapter for `FittedRandomForestClassifier<F>`.
543struct FittedForestClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
544    FittedRandomForestClassifier<F>,
545);
546
547impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
548    for FittedForestClassifierPipelineAdapter<F>
549{
550    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
551        let preds = self.0.predict(x)?;
552        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
553    }
554}
555
556// ---------------------------------------------------------------------------
557// RandomForestRegressor
558// ---------------------------------------------------------------------------
559
/// Random forest regressor.
///
/// Builds an ensemble of decision tree regressors, each trained on a
/// bootstrap sample with a random subset of features considered at each split.
/// Final predictions are the mean across all trees.
///
/// Hyper-parameters are plain public fields; they can also be set through
/// the `with_*` builder methods. The field order is part of the serialized
/// form for positional serde formats.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomForestRegressor<F> {
    /// Number of trees in the forest. `fit` rejects 0.
    pub n_estimators: usize,
    /// Maximum depth of each tree. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Strategy for the number of features considered at each split,
    /// resolved against the training feature count at fit time.
    pub max_features: MaxFeatures,
    /// Minimum number of samples required to split an internal node
    /// (forwarded to every tree).
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf node (forwarded to
    /// every tree).
    pub min_samples_leaf: usize,
    /// Random seed. `Some(seed)` derives every tree's seed from one seeded
    /// master RNG; `None` draws per-tree seeds from the thread-local RNG.
    pub random_state: Option<u64>,
    // Ties the estimator to the float type `F` without storing any `F`.
    _marker: std::marker::PhantomData<F>,
}
585
586impl<F: Float> RandomForestRegressor<F> {
587    /// Create a new `RandomForestRegressor` with default settings.
588    ///
589    /// Defaults: `n_estimators = 100`, `max_depth = None`,
590    /// `max_features = All`, `min_samples_split = 2`,
591    /// `min_samples_leaf = 1`, `random_state = None`.
592    #[must_use]
593    pub fn new() -> Self {
594        Self {
595            n_estimators: 100,
596            max_depth: None,
597            max_features: MaxFeatures::All,
598            min_samples_split: 2,
599            min_samples_leaf: 1,
600            random_state: None,
601            _marker: std::marker::PhantomData,
602        }
603    }
604
605    /// Set the number of trees.
606    #[must_use]
607    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
608        self.n_estimators = n_estimators;
609        self
610    }
611
612    /// Set the maximum tree depth.
613    #[must_use]
614    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
615        self.max_depth = max_depth;
616        self
617    }
618
619    /// Set the maximum features strategy.
620    #[must_use]
621    pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
622        self.max_features = max_features;
623        self
624    }
625
626    /// Set the minimum number of samples to split a node.
627    #[must_use]
628    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
629        self.min_samples_split = min_samples_split;
630        self
631    }
632
633    /// Set the minimum number of samples in a leaf.
634    #[must_use]
635    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
636        self.min_samples_leaf = min_samples_leaf;
637        self
638    }
639
640    /// Set the random seed for reproducibility.
641    #[must_use]
642    pub fn with_random_state(mut self, seed: u64) -> Self {
643        self.random_state = Some(seed);
644        self
645    }
646}
647
648impl<F: Float> Default for RandomForestRegressor<F> {
649    fn default() -> Self {
650        Self::new()
651    }
652}
653
654// ---------------------------------------------------------------------------
655// FittedRandomForestRegressor
656// ---------------------------------------------------------------------------
657
/// A fitted random forest regressor.
///
/// Stores the ensemble of fitted decision trees and aggregates their
/// predictions by averaging the leaf values reached in each tree.
#[derive(Debug, Clone)]
pub struct FittedRandomForestRegressor<F> {
    /// One flat node vector per fitted tree, in the order of the per-tree
    /// seeds drawn during `fit`.
    trees: Vec<Vec<Node<F>>>,
    /// Number of features seen during `fit`; `predict` rejects other widths.
    n_features: usize,
    /// Per-feature importance scores: the per-tree values returned by
    /// `compute_feature_importances`, summed over trees and then normalised
    /// to sum to 1 (left unnormalised if the total is not positive).
    feature_importances: Array1<F>,
}
671
672impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for RandomForestRegressor<F> {
673    type Fitted = FittedRandomForestRegressor<F>;
674    type Error = FerroError;
675
676    /// Fit the random forest regressor.
677    ///
678    /// # Errors
679    ///
680    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
681    /// numbers of samples.
682    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
683    /// Returns [`FerroError::InvalidParameter`] if `n_estimators` is 0.
684    fn fit(
685        &self,
686        x: &Array2<F>,
687        y: &Array1<F>,
688    ) -> Result<FittedRandomForestRegressor<F>, FerroError> {
689        let (n_samples, n_features) = x.dim();
690
691        if n_samples != y.len() {
692            return Err(FerroError::ShapeMismatch {
693                expected: vec![n_samples],
694                actual: vec![y.len()],
695                context: "y length must match number of samples in X".into(),
696            });
697        }
698        if n_samples == 0 {
699            return Err(FerroError::InsufficientSamples {
700                required: 1,
701                actual: 0,
702                context: "RandomForestRegressor requires at least one sample".into(),
703            });
704        }
705        if self.n_estimators == 0 {
706            return Err(FerroError::InvalidParameter {
707                name: "n_estimators".into(),
708                reason: "must be at least 1".into(),
709            });
710        }
711
712        let max_features_n = resolve_max_features(self.max_features, n_features);
713        let params = make_tree_params(
714            self.max_depth,
715            self.min_samples_split,
716            self.min_samples_leaf,
717        );
718
719        // Generate per-tree seeds sequentially.
720        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
721            let mut master_rng = StdRng::seed_from_u64(seed);
722            (0..self.n_estimators)
723                .map(|_| {
724                    use rand::RngCore;
725                    master_rng.next_u64()
726                })
727                .collect()
728        } else {
729            (0..self.n_estimators)
730                .map(|_| {
731                    use rand::RngCore;
732                    rand::rng().next_u64()
733                })
734                .collect()
735        };
736
737        // Build trees in parallel — per-split feature sampling (Breiman RF,
738        // sklearn parity); see RandomForestClassifier::fit for full notes.
739        let trees: Vec<Vec<Node<F>>> = tree_seeds
740            .par_iter()
741            .map(|&seed| {
742                let mut bootstrap_rng = StdRng::seed_from_u64(seed);
743
744                let bootstrap_indices: Vec<usize> = (0..n_samples)
745                    .map(|_| {
746                        use rand::RngCore;
747                        (bootstrap_rng.next_u64() as usize) % n_samples
748                    })
749                    .collect();
750
751                use rand::RngCore;
752                let split_seed = bootstrap_rng.next_u64();
753
754                build_regression_tree_per_split_features(
755                    x,
756                    y,
757                    &bootstrap_indices,
758                    max_features_n,
759                    &params,
760                    split_seed,
761                )
762            })
763            .collect();
764
765        // Aggregate feature importances.
766        let mut total_importances = Array1::<F>::zeros(n_features);
767        for tree_nodes in &trees {
768            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
769            total_importances = total_importances + tree_imp;
770        }
771        let imp_sum: F = total_importances
772            .iter()
773            .copied()
774            .fold(F::zero(), |a, b| a + b);
775        if imp_sum > F::zero() {
776            total_importances.mapv_inplace(|v| v / imp_sum);
777        }
778
779        Ok(FittedRandomForestRegressor {
780            trees,
781            n_features,
782            feature_importances: total_importances,
783        })
784    }
785}
786
787impl<F: Float + Send + Sync + 'static> FittedRandomForestRegressor<F> {
788    /// Returns a reference to the individual tree node vectors.
789    #[must_use]
790    pub fn trees(&self) -> &[Vec<Node<F>>] {
791        &self.trees
792    }
793
794    /// Returns the number of features the model was trained on.
795    #[must_use]
796    pub fn n_features(&self) -> usize {
797        self.n_features
798    }
799
800    /// R² coefficient of determination on the given test data.
801    /// Equivalent to sklearn's `RegressorMixin.score`.
802    ///
803    /// # Errors
804    ///
805    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
806    /// the feature count does not match the training data.
807    pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<F, FerroError> {
808        if x.nrows() != y.len() {
809            return Err(FerroError::ShapeMismatch {
810                expected: vec![x.nrows()],
811                actual: vec![y.len()],
812                context: "y length must match number of samples in X".into(),
813            });
814        }
815        let preds = self.predict(x)?;
816        Ok(crate::r2_score(&preds, y))
817    }
818}
819
820impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedRandomForestRegressor<F> {
821    type Output = Array1<F>;
822    type Error = FerroError;
823
824    /// Predict target values by averaging across all trees.
825    ///
826    /// # Errors
827    ///
828    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
829    /// not match the fitted model.
830    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
831        if x.ncols() != self.n_features {
832            return Err(FerroError::ShapeMismatch {
833                expected: vec![self.n_features],
834                actual: vec![x.ncols()],
835                context: "number of features must match fitted model".into(),
836            });
837        }
838
839        let n_samples = x.nrows();
840        let n_trees_f = F::from(self.trees.len()).unwrap();
841        let mut predictions = Array1::zeros(n_samples);
842
843        for i in 0..n_samples {
844            let row = x.row(i);
845            let mut sum = F::zero();
846
847            for tree_nodes in &self.trees {
848                let leaf_idx = decision_tree::traverse(tree_nodes, &row);
849                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
850                    sum = sum + value;
851                }
852            }
853
854            predictions[i] = sum / n_trees_f;
855        }
856
857        Ok(predictions)
858    }
859}
860
861impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedRandomForestRegressor<F> {
862    fn feature_importances(&self) -> &Array1<F> {
863        &self.feature_importances
864    }
865}
866
867// Pipeline integration.
868impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for RandomForestRegressor<F> {
869    fn fit_pipeline(
870        &self,
871        x: &Array2<F>,
872        y: &Array1<F>,
873    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
874        let fitted = self.fit(x, y)?;
875        Ok(Box::new(fitted))
876    }
877}
878
879impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F>
880    for FittedRandomForestRegressor<F>
881{
882    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
883        self.predict(x)
884    }
885}
886
887// ---------------------------------------------------------------------------
888// Tests
889// ---------------------------------------------------------------------------
890
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // -- Classifier tests --

    #[test]
    fn test_forest_classifier_simple() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        for &p in preds.iter().take(4) {
            assert_eq!(p, 0);
        }
        for &p in preds.iter().skip(4) {
            assert_eq!(p, 1);
        }
    }

    #[test]
    fn test_forest_classifier_reproducibility() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);

        let fitted1 = model.fit(&x, &y).unwrap();
        let fitted2 = model.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        assert_eq!(preds1, preds2);
    }

    #[test]
    fn test_forest_classifier_feature_importances() {
        // Only the first column carries signal; the other two are constant.
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 3);
        assert!(importances[0] > importances[1]);
        assert!(importances[0] > importances[2]);
    }

    #[test]
    fn test_forest_classifier_has_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1, 2, 0, 1, 2];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 1, 2]);
        assert_eq!(fitted.n_classes(), 3);
    }

    #[test]
    fn test_forest_classifier_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1];

        let model = RandomForestClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_forest_classifier_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_forest_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = RandomForestClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_forest_classifier_zero_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = RandomForestClassifier::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_forest_classifier_single_tree() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(1)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_forest_classifier_pipeline_integration() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_forest_classifier_max_depth() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_max_depth(Some(1))
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
    }

    // -- Regressor tests --

    #[test]
    fn test_forest_regressor_simple() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(50)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        for &p in preds.iter().take(4) {
            assert!(p < 3.0, "Expected ~1.0, got {}", p);
        }
        for &p in preds.iter().skip(4) {
            assert!(p > 3.0, "Expected ~5.0, got {}", p);
        }
    }

    #[test]
    fn test_forest_regressor_constant_target_is_averaged_exactly() {
        // With a constant target every leaf of every tree holds 3.0, so the
        // per-tree average must reproduce that value for each sample.
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array1::from_elem(6, 3.0);

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(7);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
        for &p in preds.iter() {
            assert_relative_eq!(p, 3.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_forest_regressor_predict_zero_rows() {
        // A correctly-shaped but empty input is valid and yields no predictions.
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        let empty = Array2::<f64>::zeros((0, 1));
        let preds = fitted.predict(&empty).unwrap();
        assert_eq!(preds.len(), 0);
    }

    #[test]
    fn test_forest_regressor_score_shape_mismatch() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        // Target length disagrees with the number of samples.
        let y_short = array![1.0, 2.0, 3.0];
        assert!(fitted.score(&x, &y_short).is_err());

        // Feature count disagrees with the fitted model.
        let x_wide = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y_two = array![1.0, 2.0];
        assert!(fitted.score(&x_wide, &y_two).is_err());
    }

    #[test]
    fn test_forest_regressor_reproducibility() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);

        let fitted1 = model.fit(&x, &y).unwrap();
        let fitted2 = model.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        for (p1, p2) in preds1.iter().zip(preds2.iter()) {
            assert_relative_eq!(*p1, *p2, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_forest_regressor_feature_importances() {
        // The second column is constant and therefore uninformative.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_forest_regressor_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = RandomForestRegressor::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_forest_regressor_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_forest_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = RandomForestRegressor::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_forest_regressor_zero_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = RandomForestRegressor::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_forest_regressor_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_forest_regressor_max_features_strategies() {
        let x = Array2::from_shape_vec(
            (8, 4),
            vec![
                1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0, 3.0, 4.0, 5.0, 6.0, 4.0, 5.0, 6.0, 7.0,
                5.0, 6.0, 7.0, 8.0, 6.0, 7.0, 8.0, 9.0, 7.0, 8.0, 9.0, 10.0, 8.0, 9.0, 10.0, 11.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        for strategy in [
            MaxFeatures::Sqrt,
            MaxFeatures::Log2,
            MaxFeatures::All,
            MaxFeatures::Fixed(2),
            MaxFeatures::Fraction(0.5),
        ] {
            let model = RandomForestRegressor::<f64>::new()
                .with_n_estimators(5)
                .with_max_features(strategy)
                .with_random_state(42);
            let fitted = model.fit(&x, &y).unwrap();
            let preds = fitted.predict(&x).unwrap();
            assert_eq!(preds.len(), 8);
        }
    }

    // -- MaxFeatures resolution tests --

    #[test]
    fn test_resolve_max_features_sqrt() {
        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 9), 3);
        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 10), 4);
        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 1), 1);
    }

    #[test]
    fn test_resolve_max_features_log2() {
        assert_eq!(resolve_max_features(MaxFeatures::Log2, 8), 3);
        assert_eq!(resolve_max_features(MaxFeatures::Log2, 1), 1);
    }

    #[test]
    fn test_resolve_max_features_all() {
        assert_eq!(resolve_max_features(MaxFeatures::All, 10), 10);
        assert_eq!(resolve_max_features(MaxFeatures::All, 1), 1);
    }

    #[test]
    fn test_resolve_max_features_fixed() {
        assert_eq!(resolve_max_features(MaxFeatures::Fixed(3), 10), 3);
        assert_eq!(resolve_max_features(MaxFeatures::Fixed(20), 10), 10);
    }

    #[test]
    fn test_resolve_max_features_fraction() {
        assert_eq!(resolve_max_features(MaxFeatures::Fraction(0.5), 10), 5);
        assert_eq!(resolve_max_features(MaxFeatures::Fraction(0.1), 10), 1);
    }

    #[test]
    fn test_forest_classifier_f32_support() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = RandomForestClassifier::<f32>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_forest_regressor_f32_support() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);

        let model = RandomForestRegressor::<f32>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }
}