// anofox_ml_ensemble/random_forest_classifier.rs

1use anofox_ml_core::{Fit, Float, Predict, PredictProba, Result, RustMlError};
2use anofox_ml_trees::{
3    ClassWeight, DecisionTreeClassifier, FittedDecisionTreeClassifier, SplitCriterion,
4};
5use ndarray::{Array1, Array2};
6use rand::rngs::StdRng;
7use rand::{Rng, SeedableRng};
8use rayon::prelude::*;
9
10/// Random forest classifier parameters (unfitted state).
11///
12/// Trains an ensemble of decision tree classifiers, each on a bootstrap sample
13/// of the data with an optional random subset of features.
14#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
15pub struct RandomForestClassifier {
16    /// Number of trees in the forest.
17    pub n_estimators: usize,
18    /// Maximum depth of each tree.
19    pub max_depth: Option<usize>,
20    /// Minimum samples required to split a node.
21    pub min_samples_split: usize,
22    /// Minimum samples required in a leaf node.
23    pub min_samples_leaf: usize,
24    /// Number of features to consider per tree. If `None`, all features are used.
25    pub max_features: Option<usize>,
26    /// Split criterion for each tree. Default: Gini.
27    pub criterion: SplitCriterion,
28    /// Whether to use bootstrap sampling. Default: true.
29    pub bootstrap: bool,
30    /// Fraction of samples to draw for each tree (when bootstrap=true).
31    /// If `None`, draws n_samples (with replacement). Value in (0, 1].
32    pub max_samples: Option<f64>,
33    /// Class weight strategy passed to each tree.
34    pub class_weight: Option<ClassWeight>,
35    /// Whether to compute out-of-bag score after fitting.
36    pub oob_score: bool,
37    /// Random seed for reproducibility.
38    pub seed: u64,
39}
40
41impl RandomForestClassifier {
42    /// Create a new `RandomForestClassifier` with the given number of trees and default parameters.
43    pub fn new(n_estimators: usize) -> Self {
44        Self {
45            n_estimators,
46            max_depth: None,
47            min_samples_split: 2,
48            min_samples_leaf: 1,
49            max_features: None,
50            criterion: SplitCriterion::Gini,
51            bootstrap: true,
52            max_samples: None,
53            class_weight: None,
54            oob_score: false,
55            seed: 0,
56        }
57    }
58
59    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
60        self.max_depth = max_depth;
61        self
62    }
63    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
64        self.min_samples_split = min_samples_split;
65        self
66    }
67    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
68        self.min_samples_leaf = min_samples_leaf;
69        self
70    }
71    pub fn with_max_features(mut self, max_features: Option<usize>) -> Self {
72        self.max_features = max_features;
73        self
74    }
75    pub fn with_criterion(mut self, criterion: SplitCriterion) -> Self {
76        self.criterion = criterion;
77        self
78    }
79    pub fn with_bootstrap(mut self, bootstrap: bool) -> Self {
80        self.bootstrap = bootstrap;
81        self
82    }
83    pub fn with_max_samples(mut self, max_samples: Option<f64>) -> Self {
84        self.max_samples = max_samples;
85        self
86    }
87    pub fn with_class_weight(mut self, class_weight: Option<ClassWeight>) -> Self {
88        self.class_weight = class_weight;
89        self
90    }
91    pub fn with_oob_score(mut self, oob_score: bool) -> Self {
92        self.oob_score = oob_score;
93        self
94    }
95    pub fn with_seed(mut self, seed: u64) -> Self {
96        self.seed = seed;
97        self
98    }
99}
100
101impl Default for RandomForestClassifier {
102    fn default() -> Self {
103        Self::new(100)
104    }
105}
106
107/// A single tree in the forest together with its selected feature indices.
108#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
109#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
110struct ForestTree<F: Float> {
111    tree: FittedDecisionTreeClassifier<F>,
112    /// Indices of the features this tree was trained on (relative to the
113    /// original feature matrix). When `max_features` is `None` this contains
114    /// `0..n_features`.
115    feature_indices: Vec<usize>,
116}
117
118/// Fitted random forest classifier.
119#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
120#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
121pub struct FittedRandomForestClassifier<F: Float> {
122    trees: Vec<ForestTree<F>>,
123    n_features: usize,
124    /// OOB accuracy score (only set when oob_score=true).
125    oob_score_value: Option<f64>,
126}
127
128impl<F: Float> Fit<F> for RandomForestClassifier {
129    type Fitted = FittedRandomForestClassifier<F>;
130
131    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
132        if x.nrows() != y.len() {
133            return Err(RustMlError::ShapeMismatch(format!(
134                "X has {} rows but y has {} elements",
135                x.nrows(),
136                y.len()
137            )));
138        }
139        if x.is_empty() {
140            return Err(RustMlError::EmptyInput("training data is empty".into()));
141        }
142        if self.n_estimators == 0 {
143            return Err(RustMlError::InvalidParameter(
144                "n_estimators must be > 0".into(),
145            ));
146        }
147
148        let n_samples = x.nrows();
149        let n_features = x.ncols();
150
151        if let Some(k) = self.max_features {
152            if k == 0 || k > n_features {
153                return Err(RustMlError::InvalidParameter(format!(
154                    "max_features={k} is invalid for data with {n_features} features"
155                )));
156            }
157        }
158
159        let mut rng = StdRng::seed_from_u64(self.seed);
160
161        // Compute bootstrap sample size
162        let draw_size = if let Some(frac) = self.max_samples {
163            if frac <= 0.0 || frac > 1.0 {
164                return Err(RustMlError::InvalidParameter(
165                    "max_samples must be in (0, 1]".into(),
166                ));
167            }
168            (n_samples as f64 * frac).ceil() as usize
169        } else {
170            n_samples
171        };
172
173        let tree_params = DecisionTreeClassifier {
174            max_depth: self.max_depth,
175            min_samples_split: self.min_samples_split,
176            min_samples_leaf: self.min_samples_leaf,
177            criterion: self.criterion,
178            max_features: None,
179            sample_weight: None,
180            class_weight: self.class_weight.clone(),
181        };
182
183        // Pre-generate bootstrap/sample and feature indices for determinism
184        let sample_plans: Vec<(Vec<usize>, Vec<usize>)> = (0..self.n_estimators)
185            .map(|_| {
186                let row_indices: Vec<usize> = if self.bootstrap {
187                    (0..draw_size)
188                        .map(|_| rng.gen_range(0..n_samples))
189                        .collect()
190                } else {
191                    (0..n_samples).collect()
192                };
193                let feature_indices = select_features(n_features, self.max_features, &mut rng);
194                (row_indices, feature_indices)
195            })
196            .collect();
197
198        // Keep a copy of row indices for OOB scoring
199        let all_row_indices: Vec<Vec<usize>> = if self.oob_score {
200            sample_plans.iter().map(|(ri, _)| ri.clone()).collect()
201        } else {
202            Vec::new()
203        };
204
205        // Train trees in parallel
206        let trees: Result<Vec<ForestTree<F>>> = sample_plans
207            .into_par_iter()
208            .map(|(row_indices, feature_indices)| {
209                let x_sub = build_sub_matrix(x, &row_indices, &feature_indices);
210                let y_sub = Array1::from_vec(row_indices.iter().map(|&i| y[i]).collect::<Vec<F>>());
211                let fitted_tree: FittedDecisionTreeClassifier<F> =
212                    tree_params.fit(&x_sub, &y_sub)?;
213                Ok(ForestTree {
214                    tree: fitted_tree,
215                    feature_indices,
216                })
217            })
218            .collect();
219        let trees = trees?;
220
221        // Compute OOB score if requested
222        let oob_score_value = if self.oob_score && self.bootstrap {
223            compute_oob_score_classification(&trees, x, y, n_samples, &all_row_indices)
224        } else {
225            None
226        };
227
228        Ok(FittedRandomForestClassifier {
229            trees,
230            n_features,
231            oob_score_value,
232        })
233    }
234}
235
236impl<F: Float> Predict<F> for FittedRandomForestClassifier<F> {
237    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
238        if x.ncols() != self.n_features {
239            return Err(RustMlError::ShapeMismatch(format!(
240                "expected {} features, got {}",
241                self.n_features,
242                x.ncols()
243            )));
244        }
245
246        let n_samples = x.nrows();
247        let n_trees = self.trees.len();
248
249        // Collect all tree predictions in parallel
250        let all_preds: Result<Vec<Array1<F>>> = self
251            .trees
252            .par_iter()
253            .map(|forest_tree| {
254                let sub_x = build_sub_matrix_cols(x, &forest_tree.feature_indices);
255                forest_tree.tree.predict(&sub_x)
256            })
257            .collect();
258        let all_preds = all_preds?;
259
260        // Aggregate votes per sample
261        let mut predictions = Vec::with_capacity(n_samples);
262        let mut votes = Vec::with_capacity(n_trees);
263        for i in 0..n_samples {
264            votes.clear();
265            for tree_pred in &all_preds {
266                votes.push(tree_pred[i]);
267            }
268            predictions.push(majority_vote(&votes));
269        }
270
271        Ok(Array1::from_vec(predictions))
272    }
273}
274
275impl<F: Float> FittedRandomForestClassifier<F> {
276    /// Feature importances averaged across all trees and normalized to sum to 1.
277    ///
278    /// Each tree's importances are computed in its own (possibly reduced)
279    /// feature space, then mapped back to the original feature indices and
280    /// averaged.
281    pub fn feature_importances(&self) -> Array1<F> {
282        let mut importances = vec![F::zero(); self.n_features];
283        let n_trees = F::from_usize(self.trees.len()).unwrap();
284
285        for forest_tree in &self.trees {
286            let tree_importances = forest_tree.tree.feature_importances();
287            for (local_idx, &original_idx) in forest_tree.feature_indices.iter().enumerate() {
288                importances[original_idx] += tree_importances[local_idx] / n_trees;
289            }
290        }
291
292        // Normalize so importances sum to 1
293        let sum: F = importances.iter().copied().fold(F::zero(), |a, b| a + b);
294        if sum > F::zero() {
295            Array1::from_vec(importances.into_iter().map(|v| v / sum).collect())
296        } else {
297            Array1::zeros(self.n_features)
298        }
299    }
300
301    /// Number of trees in the forest.
302    pub fn n_estimators(&self) -> usize {
303        self.trees.len()
304    }
305
306    /// Predict class probabilities for each sample.
307    ///
308    /// Returns an `Array2<F>` of shape `(n_samples, n_classes)` where each row
309    /// sums to 1.0. Probabilities are averaged across all trees in the forest.
310    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>> {
311        if x.ncols() != self.n_features {
312            return Err(RustMlError::ShapeMismatch(format!(
313                "expected {} features, got {}",
314                self.n_features,
315                x.ncols()
316            )));
317        }
318
319        // Collect probabilities from each tree in parallel
320        let all_proba: Result<Vec<Array2<F>>> = self
321            .trees
322            .par_iter()
323            .map(|forest_tree| {
324                let sub_x = build_sub_matrix_cols(x, &forest_tree.feature_indices);
325                forest_tree.tree.predict_proba(&sub_x)
326            })
327            .collect();
328        let all_proba = all_proba?;
329
330        // Determine global class set (union of all trees' classes)
331        let mut global_classes: Vec<F> = Vec::new();
332        let eps = F::from_f64(1e-9).unwrap();
333        for forest_tree in &self.trees {
334            for c in forest_tree.tree.classes() {
335                if !global_classes.iter().any(|&gc| (gc - c).abs() < eps) {
336                    global_classes.push(c);
337                }
338            }
339        }
340        global_classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
341
342        let n_samples = x.nrows();
343        let n_classes = global_classes.len();
344        let n_trees_f = F::from_usize(self.trees.len()).unwrap();
345        let mut avg_proba = Array2::<F>::zeros((n_samples, n_classes));
346
347        // Map each tree's probabilities to the global class indices and average
348        for (tree_idx, forest_tree) in self.trees.iter().enumerate() {
349            let tree_classes = forest_tree.tree.classes();
350            let tree_proba = &all_proba[tree_idx];
351
352            for (local_ci, &tc) in tree_classes.iter().enumerate() {
353                if let Some(global_ci) = global_classes.iter().position(|&gc| (gc - tc).abs() < eps)
354                {
355                    for i in 0..n_samples {
356                        avg_proba[[i, global_ci]] += tree_proba[[i, local_ci]] / n_trees_f;
357                    }
358                }
359            }
360        }
361
362        Ok(avg_proba)
363    }
364
365    /// OOB accuracy score (only available when `oob_score=true` and `bootstrap=true`).
366    pub fn oob_score(&self) -> Option<f64> {
367        self.oob_score_value
368    }
369
370    /// Compute classification accuracy on the given data.
371    pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<f64> {
372        let preds = self.predict(x)?;
373        let n = y.len();
374        let correct = preds
375            .iter()
376            .zip(y.iter())
377            .filter(|(&p, &t)| (p - t).abs() < F::from_f64(1e-9).unwrap())
378            .count();
379        Ok(correct as f64 / n as f64)
380    }
381
382    /// Returns the unique sorted class labels across all trees.
383    pub fn classes(&self) -> Vec<F> {
384        let eps = F::from_f64(1e-9).unwrap();
385        let mut classes: Vec<F> = Vec::new();
386        for forest_tree in &self.trees {
387            for c in forest_tree.tree.classes() {
388                if !classes.iter().any(|&gc| (gc - c).abs() < eps) {
389                    classes.push(c);
390                }
391            }
392        }
393        classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
394        classes
395    }
396}
397
398/// Select `k` distinct feature indices from `0..n_features` without replacement.
399/// If `max_features` is `None`, returns all feature indices.
400fn select_features(n_features: usize, max_features: Option<usize>, rng: &mut StdRng) -> Vec<usize> {
401    match max_features {
402        None => (0..n_features).collect(),
403        Some(k) => {
404            // Fisher-Yates partial shuffle
405            let mut indices: Vec<usize> = (0..n_features).collect();
406            for i in 0..k {
407                let j = rng.gen_range(i..n_features);
408                indices.swap(i, j);
409            }
410            indices.truncate(k);
411            indices.sort_unstable();
412            indices
413        }
414    }
415}
416
417/// Build a sub-matrix selecting specific rows and columns from `x`.
418fn build_sub_matrix<F: Float>(
419    x: &Array2<F>,
420    row_indices: &[usize],
421    col_indices: &[usize],
422) -> Array2<F> {
423    // Select rows first (produces C-contiguous), then columns.
424    let n_rows = row_indices.len();
425    let n_cols = col_indices.len();
426    let mut data = Vec::with_capacity(n_rows * n_cols);
427    for &ri in row_indices {
428        for &ci in col_indices {
429            data.push(x[[ri, ci]]);
430        }
431    }
432    Array2::from_shape_vec((n_rows, n_cols), data).expect("shape matches data length")
433}
434
435/// Build a sub-matrix selecting all rows but only specific columns from `x`.
436/// Produces a guaranteed C-contiguous (standard layout) array so that
437/// `row.as_slice()` works in downstream predict calls.
438fn build_sub_matrix_cols<F: Float>(x: &Array2<F>, col_indices: &[usize]) -> Array2<F> {
439    let n_rows = x.nrows();
440    let n_cols = col_indices.len();
441    let mut data = Vec::with_capacity(n_rows * n_cols);
442    for i in 0..n_rows {
443        for &ci in col_indices {
444            data.push(x[[i, ci]]);
445        }
446    }
447    Array2::from_shape_vec((n_rows, n_cols), data).expect("shape matches data length")
448}
449
450/// Compute OOB classification accuracy.
451/// For each sample, only trees that did NOT include it in their bootstrap are used.
452fn compute_oob_score_classification<F: Float>(
453    trees: &[ForestTree<F>],
454    x: &Array2<F>,
455    y: &Array1<F>,
456    n_samples: usize,
457    bootstrap_indices: &[Vec<usize>],
458) -> Option<f64> {
459    use std::collections::HashSet;
460
461    let mut correct = 0usize;
462    let mut evaluated = 0usize;
463
464    for i in 0..n_samples {
465        let mut votes: Vec<F> = Vec::new();
466        for (t_idx, forest_tree) in trees.iter().enumerate() {
467            // Check if sample i was NOT in the bootstrap for tree t_idx
468            let in_bag: HashSet<usize> = bootstrap_indices[t_idx].iter().copied().collect();
469            if in_bag.contains(&i) {
470                continue;
471            }
472            // Get prediction from this tree for sample i
473            let sub_x = build_sub_matrix_cols_single(x, i, &forest_tree.feature_indices);
474            if let Ok(pred) = forest_tree.tree.predict(&sub_x) {
475                votes.push(pred[0]);
476            }
477        }
478        if !votes.is_empty() {
479            let oob_pred = majority_vote(&votes);
480            if (oob_pred - y[i]).abs() < F::from_f64(1e-9).unwrap() {
481                correct += 1;
482            }
483            evaluated += 1;
484        }
485    }
486
487    if evaluated > 0 {
488        Some(correct as f64 / evaluated as f64)
489    } else {
490        None
491    }
492}
493
494/// Build a single-row sub-matrix selecting specific columns for one sample.
495fn build_sub_matrix_cols_single<F: Float>(
496    x: &Array2<F>,
497    row: usize,
498    col_indices: &[usize],
499) -> Array2<F> {
500    let n_cols = col_indices.len();
501    let data: Vec<F> = col_indices.iter().map(|&ci| x[[row, ci]]).collect();
502    Array2::from_shape_vec((1, n_cols), data).expect("shape matches data length")
503}
504
505/// Return the class that appears most frequently in `votes`.
506/// Uses HashMap with f64 bit representation for O(1) lookup per vote.
507#[inline]
508fn majority_vote<F: Float>(votes: &[F]) -> F {
509    use std::collections::HashMap;
510    let mut counts: HashMap<u64, (F, usize)> = HashMap::new();
511    for &v in votes {
512        let key = v.to_f64().unwrap().to_bits();
513        counts.entry(key).and_modify(|e| e.1 += 1).or_insert((v, 1));
514    }
515    counts
516        .into_values()
517        .max_by_key(|&(_, count)| count)
518        .unwrap()
519        .0
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_basic_classification() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let rf = RandomForestClassifier {
            n_estimators: 20,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedRandomForestClassifier<f64> = rf.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        for (p, t) in preds.iter().zip(y.iter()) {
            assert_abs_diff_eq!(*p, *t, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let rf = RandomForestClassifier {
            n_estimators: 10,
            seed: 123,
            ..Default::default()
        };

        let fitted1: FittedRandomForestClassifier<f64> = rf.fit(&x, &y).unwrap();
        let fitted2: FittedRandomForestClassifier<f64> = rf.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        for (a, b) in preds1.iter().zip(preds2.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_max_features() {
        let x = array![
            [1.0, 100.0, 0.5],
            [2.0, 200.0, 0.6],
            [3.0, 300.0, 0.7],
            [10.0, 400.0, 0.8],
            [11.0, 500.0, 0.9],
            [12.0, 600.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let rf = RandomForestClassifier {
            n_estimators: 30,
            max_features: Some(2),
            seed: 99,
            ..Default::default()
        };
        let fitted: FittedRandomForestClassifier<f64> = rf.fit(&x, &y).unwrap();

        // Training accuracy should be high.
        let preds = fitted.predict(&x).unwrap();
        for (p, t) in preds.iter().zip(y.iter()) {
            assert_abs_diff_eq!(*p, *t, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_feature_importances_sum_to_one() {
        let x = array![
            [1.0, 100.0],
            [2.0, 200.0],
            [3.0, 300.0],
            [4.0, 400.0],
            [5.0, 500.0],
            [6.0, 600.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let rf = RandomForestClassifier {
            n_estimators: 20,
            seed: 7,
            ..Default::default()
        };
        let fitted: FittedRandomForestClassifier<f64> = rf.fit(&x, &y).unwrap();

        let importances = fitted.feature_importances();
        let sum: f64 = importances.iter().sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_n_estimators() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let rf = RandomForestClassifier {
            n_estimators: 7,
            seed: 0,
            ..Default::default()
        };
        let fitted: FittedRandomForestClassifier<f64> = rf.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 7);
    }

    #[test]
    fn test_shape_mismatch_error() {
        let x = array![[1.0], [2.0]];
        let y = array![0.0, 1.0, 2.0];

        let rf = RandomForestClassifier::default();
        let result: std::result::Result<FittedRandomForestClassifier<f64>, _> = rf.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_predict_wrong_features_error() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let rf = RandomForestClassifier {
            n_estimators: 5,
            seed: 0,
            ..Default::default()
        };
        let fitted: FittedRandomForestClassifier<f64> = rf.fit(&x, &y).unwrap();

        let x_bad = array![[1.0], [2.0]];
        let result = fitted.predict(&x_bad);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_max_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let rf = RandomForestClassifier {
            n_estimators: 5,
            max_features: Some(5),
            seed: 0,
            ..Default::default()
        };
        let result: std::result::Result<FittedRandomForestClassifier<f64>, _> = rf.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_feature_importances_non_negative() {
        let x = array![
            [1.0, 100.0, 0.5],
            [2.0, 200.0, 0.6],
            [3.0, 300.0, 0.7],
            [10.0, 400.0, 0.8],
            [11.0, 500.0, 0.9],
            [12.0, 600.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let rf = RandomForestClassifier {
            n_estimators: 20,
            seed: 7,
            ..Default::default()
        };
        let fitted: FittedRandomForestClassifier<f64> = rf.fit(&x, &y).unwrap();

        let importances = fitted.feature_importances();
        for &imp in importances.iter() {
            assert!(
                imp >= 0.0,
                "feature importance must be non-negative, got {imp}"
            );
        }
    }

    #[test]
    fn test_n_estimators_one() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let rf = RandomForestClassifier {
            n_estimators: 1,
            max_depth: Some(3),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedRandomForestClassifier<f64> = rf.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 1);

        // A single tree should still produce valid predictions.
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), y.len());
    }

    #[test]
    fn test_predictions_are_valid_labels() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0],
            [20.0, 2.0],
            [21.0, 2.0],
            [22.0, 2.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];

        let rf = RandomForestClassifier {
            n_estimators: 30,
            max_depth: Some(5),
            seed: 42,
            ..Default::default()
        };
        let fitted: FittedRandomForestClassifier<f64> = rf.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        let valid_labels: std::collections::HashSet<u64> = y.iter().map(|v| v.to_bits()).collect();
        for &p in preds.iter() {
            assert!(
                valid_labels.contains(&p.to_bits()),
                "prediction {p} is not a valid training label"
            );
        }
    }

    #[test]
    fn test_predict_proba_rows_sum_to_one() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0],
            [20.0, 2.0],
            [21.0, 2.0],
            [22.0, 2.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];

        let rf = RandomForestClassifier {
            n_estimators: 15,
            max_depth: Some(4),
            seed: 11,
            ..Default::default()
        };
        let fitted: FittedRandomForestClassifier<f64> = rf.fit(&x, &y).unwrap();

        let proba = fitted.predict_proba(&x).unwrap();
        assert_eq!(proba.nrows(), x.nrows());
        assert_eq!(proba.ncols(), fitted.classes().len());
        for row in proba.outer_iter() {
            let total: f64 = row.iter().sum();
            assert_abs_diff_eq!(total, 1.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_oob_score_in_unit_interval() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let rf = RandomForestClassifier {
            n_estimators: 25,
            oob_score: true,
            seed: 3,
            ..Default::default()
        };
        let fitted: FittedRandomForestClassifier<f64> = rf.fit(&x, &y).unwrap();

        let oob = fitted
            .oob_score()
            .expect("OOB score requested with bootstrap enabled");
        assert!((0.0..=1.0).contains(&oob), "OOB score {oob} outside [0, 1]");
    }

    #[test]
    fn test_empty_input_error() {
        let x: Array2<f64> = Array2::zeros((0, 2));
        let y: Array1<f64> = Array1::zeros(0);

        let rf = RandomForestClassifier::default();
        let result: std::result::Result<FittedRandomForestClassifier<f64>, _> = rf.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_zero_estimators_error() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let rf = RandomForestClassifier {
            n_estimators: 0,
            seed: 0,
            ..Default::default()
        };
        let result: std::result::Result<FittedRandomForestClassifier<f64>, _> = rf.fit(&x, &y);
        assert!(result.is_err());
    }

    mod prop_tests {
        use super::*;
        use proptest::prelude::*;
        use std::collections::HashSet;

        /// Generate deterministic training data for classification.
        fn make_classification_data(
            n_samples: usize,
            n_features: usize,
            n_classes: usize,
            seed: u64,
        ) -> (Array2<f64>, Array1<f64>) {
            use std::collections::hash_map::DefaultHasher;
            use std::hash::{Hash, Hasher};

            let mut x_data = Vec::with_capacity(n_samples * n_features);
            let mut y_data = Vec::with_capacity(n_samples);

            for i in 0..n_samples {
                for j in 0..n_features {
                    let mut h = DefaultHasher::new();
                    seed.hash(&mut h);
                    (i as u64).hash(&mut h);
                    (j as u64).hash(&mut h);
                    let bits = h.finish();
                    let v = (bits as f64 / u64::MAX as f64) * 20.0 - 10.0;
                    x_data.push(v);
                }
                let mut h = DefaultHasher::new();
                seed.hash(&mut h);
                (i as u64).hash(&mut h);
                0xDEAD_BEEFu64.hash(&mut h);
                let label = (h.finish() % n_classes as u64) as f64;
                y_data.push(label);
            }

            let x = Array2::from_shape_vec((n_samples, n_features), x_data).unwrap();
            let y = Array1::from_vec(y_data);
            (x, y)
        }

        proptest! {
            #[test]
            fn predictions_are_valid_labels(
                n_samples in 6..30usize,
                n_features in 1..5usize,
                n_classes in 2..5usize,
                seed in 0u64..1000,
            ) {
                let (x, y) = make_classification_data(n_samples, n_features, n_classes, seed);

                let train_labels: HashSet<u64> = y.iter()
                    .map(|&v| v.to_bits())
                    .collect();

                let rf = RandomForestClassifier {
                    n_estimators: 10,
                    max_depth: Some(5),
                    seed,
                    ..Default::default()
                };
                let fitted: FittedRandomForestClassifier<f64> = rf.fit(&x, &y).unwrap();
                let preds = fitted.predict(&x).unwrap();

                for (i, &p) in preds.iter().enumerate() {
                    prop_assert!(
                        train_labels.contains(&p.to_bits()),
                        "prediction {} at index {} is not a valid training label",
                        p, i
                    );
                }
            }

            #[test]
            fn feature_importances_sum_to_one(
                n_samples in 6..30usize,
                n_features in 1..5usize,
                seed in 0u64..1000,
            ) {
                let n_classes = 3;
                let (x, y) = make_classification_data(n_samples, n_features, n_classes, seed);

                let rf = RandomForestClassifier {
                    n_estimators: 10,
                    max_depth: Some(5),
                    seed,
                    ..Default::default()
                };
                let fitted: FittedRandomForestClassifier<f64> = rf.fit(&x, &y).unwrap();
                let importances = fitted.feature_importances();
                let sum: f64 = importances.iter().sum();

                // Valid outcomes:
                // (1) importances form a probability distribution (sum to 1); or
                // (2) every tree is a pure leaf with no splits (e.g. a degenerate
                //     tiny input produces uniform labels), so all importances are 0.
                prop_assert!(
                    (sum - 1.0).abs() < 1e-10 || sum == 0.0,
                    "feature importances sum to {} (expected ~1.0 or 0.0 for no-split case), n_samples={}, n_features={}, seed={}",
                    sum, n_samples, n_features, seed
                );
                // Importances must always be non-negative.
                for (i, &imp) in importances.iter().enumerate() {
                    prop_assert!(
                        imp >= 0.0,
                        "importance[{}] = {} is negative",
                        i, imp
                    );
                }
            }
        }
    }
}
907
908impl<F: Float> PredictProba<F> for FittedRandomForestClassifier<F> {
909    fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>> {
910        Self::predict_proba(self, x)
911    }
912}