// forestfire_core/forest.rs

1//! Random forest implementation.
2//!
3//! The forest deliberately reuses the single-tree trainers instead of
4//! maintaining a separate tree-building codepath. That keeps semantics aligned:
5//! every constituent tree is still "just" a normal ForestFire tree trained on a
6//! sampled table view with per-node feature subsampling.
7
8use crate::bootstrap::BootstrapSampler;
9use crate::ir::TrainingMetadata;
10use crate::tree::shared::mix_seed;
11use crate::{
12    Criterion, FeaturePreprocessing, MaxFeatures, Model, Parallelism, PredictError, Task,
13    TrainError, TreeType, capture_feature_preprocessing, training,
14};
15use forestfire_data::TableAccess;
16use rayon::prelude::*;
17
/// Bagged ensemble of decision trees.
///
/// The forest stores full semantic [`Model`] trees rather than a bespoke forest
/// node format. That costs some memory, but it keeps IR conversion,
/// introspection, and optimized lowering consistent with the single-tree path.
#[derive(Debug, Clone)]
pub struct RandomForest {
    /// Whether the forest was trained for regression or classification.
    task: Task,
    /// Split-quality criterion shared by every constituent tree.
    criterion: Criterion,
    /// Tree flavor used for all constituent trees.
    tree_type: TreeType,
    /// The constituent trees; predictions average/aggregate over these.
    /// Several accessors index `trees[0]`, so this is expected to be non-empty.
    trees: Vec<Model>,
    /// Whether an out-of-bag score was requested at training time.
    compute_oob: bool,
    /// The OOB score, if it was requested and could be computed.
    oob_score: Option<f64>,
    /// Resolved per-node feature-subsample size used during training.
    max_features: usize,
    /// User-supplied RNG seed, if any (training falls back to a fixed seed).
    seed: Option<u64>,
    /// Number of input features the forest was trained on.
    num_features: usize,
    /// Per-feature preprocessing captured from the training table.
    feature_preprocessing: Vec<FeaturePreprocessing>,
}
36
/// Output of training one bootstrap tree: the model plus the row indices
/// that were *not* part of its bootstrap sample (used for OOB scoring).
struct TrainedTree {
    model: Model,
    /// Indices into the original training table that this tree never saw.
    oob_rows: Vec<usize>,
}
41
/// Row-remapping view over a base table: row `i` of the view resolves to
/// `row_indices[i]` of `base`. Duplicates are allowed, which is how
/// bootstrap sampling-with-replacement is expressed to the tree trainers.
struct SampledTable<'a> {
    base: &'a dyn TableAccess,
    row_indices: Vec<usize>,
}
46
/// View over a base table that hides its canary columns from feature-count
/// queries (forests deliberately skip canary-based regularization).
struct NoCanaryTable<'a> {
    base: &'a dyn TableAccess,
}
50
impl RandomForest {
    /// Assembles a forest directly from pre-built parts.
    ///
    /// This is a plain constructor: no validation and no training happens
    /// here. Note that several methods (`predict_proba_table`,
    /// `training_metadata`, `class_labels`, the predict helpers) index
    /// `trees[0]` and will panic if `trees` is empty.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        task: Task,
        criterion: Criterion,
        tree_type: TreeType,
        trees: Vec<Model>,
        compute_oob: bool,
        oob_score: Option<f64>,
        max_features: usize,
        seed: Option<u64>,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
    ) -> Self {
        Self {
            task,
            criterion,
            tree_type,
            trees,
            compute_oob,
            oob_score,
            max_features,
            seed,
            num_features,
            feature_preprocessing,
        }
    }

    /// Trains a bagged forest of `config.n_trees` trees on `train_set`.
    ///
    /// Each tree trains on a bootstrap row sample (a `SampledTable` view)
    /// with per-node feature subsampling, via the regular single-tree
    /// trainer. Parallelism is applied *across* trees: each individual tree
    /// trains sequentially while rayon fans out over tree indices.
    ///
    /// # Errors
    /// - `TrainError::InvalidTreeCount` when `config.n_trees == 0`.
    /// - `TrainError::InvalidMaxFeatures` when `config.max_features` is
    ///   `MaxFeatures::Count(0)`.
    /// - Any error propagated from the single-tree trainer.
    pub(crate) fn train(
        train_set: &dyn TableAccess,
        config: training::RandomForestConfig,
        criterion: Criterion,
        parallelism: Parallelism,
    ) -> Result<Self, TrainError> {
        let n_trees = config.n_trees;
        if n_trees == 0 {
            return Err(TrainError::InvalidTreeCount(n_trees));
        }
        if matches!(config.max_features, MaxFeatures::Count(0)) {
            return Err(TrainError::InvalidMaxFeatures(0));
        }

        // Forests intentionally ignore canaries. For standalone trees they act as
        // a regularization/stopping heuristic, but in a forest they would make
        // bootstrap replicas overly conservative and reduce ensemble diversity.
        let train_set = NoCanaryTable::new(train_set);
        let sampler = BootstrapSampler::new(train_set.n_rows());
        let feature_preprocessing = capture_feature_preprocessing(&train_set);
        // Resolve the feature-subsample size against the canary-free binned
        // feature count, so canary columns never enter the subsample pool.
        let max_features = config
            .max_features
            .resolve(config.task, train_set.binned_feature_count());
        // Fixed fallback seed so unseeded runs still use a defined seeding
        // scheme (each tree derives its own seed via `mix_seed` below).
        let base_seed = config.seed.unwrap_or(0x0005_EEDF_0E57_u64);
        let tree_parallelism = Parallelism {
            thread_count: parallelism.thread_count,
        };
        // Trees are parallelized across, not within: each tree's own trainer
        // runs sequentially to avoid nested parallelism.
        let per_tree_parallelism = Parallelism::sequential();
        let train_tree = |tree_index: usize| -> Result<TrainedTree, TrainError> {
            let tree_seed = mix_seed(base_seed, tree_index as u64);
            let (sampled_rows, oob_rows) = sampler.sample_with_oob(tree_seed);
            // Sampling is implemented as a `TableAccess` view so the existing
            // tree trainers can stay oblivious to bootstrap mechanics.
            let sampled_table = SampledTable::new(&train_set, sampled_rows);
            let model = training::train_single_model_with_feature_subset(
                &sampled_table,
                training::SingleModelFeatureSubsetConfig {
                    base: training::SingleModelConfig {
                        task: config.task,
                        tree_type: config.tree_type,
                        criterion,
                        parallelism: per_tree_parallelism,
                        max_depth: config.max_depth,
                        min_samples_split: config.min_samples_split,
                        min_samples_leaf: config.min_samples_leaf,
                        missing_value_strategies: config.missing_value_strategies.clone(),
                        canary_filter: crate::CanaryFilter::default(),
                    },
                    max_features: Some(max_features),
                    random_seed: tree_seed,
                },
            )?;
            Ok(TrainedTree { model, oob_rows })
        };
        // Fan trees out over rayon when parallelism is enabled; the first
        // tree training error short-circuits the whole collection.
        let trained_trees = if tree_parallelism.enabled() {
            (0..n_trees)
                .into_par_iter()
                .map(train_tree)
                .collect::<Result<Vec<_>, _>>()?
        } else {
            (0..n_trees)
                .map(train_tree)
                .collect::<Result<Vec<_>, _>>()?
        };
        let oob_score = if config.compute_oob {
            compute_oob_score(config.task, &trained_trees, &train_set)
        } else {
            None
        };
        let trees = trained_trees.into_iter().map(|tree| tree.model).collect();

        Ok(Self::new(
            config.task,
            criterion,
            config.tree_type,
            trees,
            config.compute_oob,
            oob_score,
            max_features,
            config.seed,
            train_set.n_features(),
            feature_preprocessing,
        ))
    }

    /// Predicts one value per row: the ensemble mean for regression, or the
    /// argmax class label (from averaged probabilities) for classification.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        match self.task {
            Task::Regression => self.predict_regression_table(table),
            Task::Classification => self.predict_classification_table(table),
        }
    }

    /// Returns per-row class probability distributions, averaged over all
    /// trees.
    ///
    /// # Errors
    /// Fails with `ProbabilityPredictionRequiresClassification` on a
    /// regression forest; tree-level probability errors propagate.
    ///
    /// # Panics
    /// Panics if the forest holds no trees (indexes `trees[0]`).
    pub fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        if self.task != Task::Classification {
            return Err(PredictError::ProbabilityPredictionRequiresClassification);
        }

        // Forest classification aggregates full class distributions rather than
        // voting on hard labels. This keeps `predict` and `predict_proba`
        // consistent and matches the optimized runtime lowering.
        let mut totals = self.trees[0].predict_proba_table(table)?;
        for tree in &self.trees[1..] {
            let probs = tree.predict_proba_table(table)?;
            for (row_totals, row_probs) in totals.iter_mut().zip(probs.iter()) {
                for (total, prob) in row_totals.iter_mut().zip(row_probs.iter()) {
                    *total += *prob;
                }
            }
        }

        // Normalize the summed distributions back into averages.
        let tree_count = self.trees.len() as f64;
        for row in &mut totals {
            for value in row {
                *value /= tree_count;
            }
        }

        Ok(totals)
    }

    /// The task (regression/classification) this forest was trained for.
    pub fn task(&self) -> Task {
        self.task
    }

    /// The split criterion used by every constituent tree.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }

    /// The tree flavor shared by all constituent trees.
    pub fn tree_type(&self) -> TreeType {
        self.tree_type
    }

    /// Read-only access to the constituent trees.
    pub fn trees(&self) -> &[Model] {
        &self.trees
    }

    /// Number of input features the forest was trained on.
    pub fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-feature preprocessing captured from the training table.
    pub fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// Training metadata for the whole forest.
    ///
    /// Starts from the first tree's metadata, overwrites the forest-level
    /// fields, and clears fields that only apply to boosted ensembles.
    ///
    /// # Panics
    /// Panics if the forest holds no trees (indexes `trees[0]`).
    pub fn training_metadata(&self) -> TrainingMetadata {
        let mut metadata = self.trees[0].training_metadata();
        metadata.algorithm = "rf".to_string();
        metadata.n_trees = Some(self.trees.len());
        metadata.max_features = Some(self.max_features);
        metadata.seed = self.seed;
        metadata.compute_oob = self.compute_oob;
        metadata.oob_score = self.oob_score;
        // Boosting-only fields do not apply to a bagged forest.
        metadata.learning_rate = None;
        metadata.bootstrap = None;
        metadata.top_gradient_fraction = None;
        metadata.other_gradient_fraction = None;
        metadata
    }

    /// Class labels for classification forests (taken from the first tree);
    /// `None` for regression.
    pub fn class_labels(&self) -> Option<Vec<f64>> {
        match self.task {
            Task::Classification => self.trees[0].class_labels(),
            Task::Regression => None,
        }
    }

    /// The out-of-bag score, if it was requested and computable.
    pub fn oob_score(&self) -> Option<f64> {
        self.oob_score
    }

    /// Regression prediction: arithmetic mean of all trees' per-row outputs.
    fn predict_regression_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let mut totals = self.trees[0].predict_table(table);
        for tree in &self.trees[1..] {
            let preds = tree.predict_table(table);
            for (total, pred) in totals.iter_mut().zip(preds.iter()) {
                *total += *pred;
            }
        }

        let tree_count = self.trees.len() as f64;
        for value in &mut totals {
            *value /= tree_count;
        }

        totals
    }

    /// Classification prediction: argmax over the averaged class
    /// distribution, mapped back to the stored class labels.
    fn predict_classification_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let probabilities = self
            .predict_proba_table(table)
            .expect("classification forest supports probabilities");
        let class_labels = self
            .class_labels()
            .expect("classification forest stores class labels");

        probabilities
            .into_iter()
            .map(|row| {
                // Ties break toward the lowest class index: the reversed
                // index comparison makes the earlier entry compare greater.
                let (best_index, _) = row
                    .iter()
                    .copied()
                    .enumerate()
                    .max_by(|(left_index, left), (right_index, right)| {
                        left.total_cmp(right)
                            .then_with(|| right_index.cmp(left_index))
                    })
                    .expect("classification probability row is non-empty");
                class_labels[best_index]
            })
            .collect()
    }
}
294
295fn compute_oob_score(
296    task: Task,
297    trained_trees: &[TrainedTree],
298    train_set: &dyn TableAccess,
299) -> Option<f64> {
300    match task {
301        Task::Classification => compute_classification_oob_score(trained_trees, train_set),
302        Task::Regression => compute_regression_oob_score(trained_trees, train_set),
303    }
304}
305
306fn compute_classification_oob_score(
307    trained_trees: &[TrainedTree],
308    train_set: &dyn TableAccess,
309) -> Option<f64> {
310    let class_labels = trained_trees.first()?.model.class_labels()?;
311    let mut totals = vec![vec![0.0; class_labels.len()]; train_set.n_rows()];
312    let mut counts = vec![0usize; train_set.n_rows()];
313
314    for tree in trained_trees {
315        if tree.oob_rows.is_empty() {
316            continue;
317        }
318        let oob_table = SampledTable::new(train_set, tree.oob_rows.clone());
319        let probabilities = tree
320            .model
321            .predict_proba_table(&oob_table)
322            .expect("classification tree supports predict_proba");
323        for (&row_index, row_probs) in tree.oob_rows.iter().zip(probabilities.iter()) {
324            for (total, prob) in totals[row_index].iter_mut().zip(row_probs.iter()) {
325                *total += *prob;
326            }
327            counts[row_index] += 1;
328        }
329    }
330
331    let mut correct = 0usize;
332    let mut covered = 0usize;
333    for row_index in 0..train_set.n_rows() {
334        if counts[row_index] == 0 {
335            continue;
336        }
337        covered += 1;
338        let predicted = totals[row_index]
339            .iter()
340            .copied()
341            .enumerate()
342            .max_by(|(li, lv), (ri, rv)| lv.total_cmp(rv).then_with(|| ri.cmp(li)))
343            .map(|(index, _)| class_labels[index])
344            .expect("classification probability row is non-empty");
345        if predicted
346            .total_cmp(&train_set.target_value(row_index))
347            .is_eq()
348        {
349            correct += 1;
350        }
351    }
352
353    (covered > 0).then_some(correct as f64 / covered as f64)
354}
355
356fn compute_regression_oob_score(
357    trained_trees: &[TrainedTree],
358    train_set: &dyn TableAccess,
359) -> Option<f64> {
360    let mut totals = vec![0.0; train_set.n_rows()];
361    let mut counts = vec![0usize; train_set.n_rows()];
362
363    for tree in trained_trees {
364        if tree.oob_rows.is_empty() {
365            continue;
366        }
367        let oob_table = SampledTable::new(train_set, tree.oob_rows.clone());
368        let predictions = tree.model.predict_table(&oob_table);
369        for (&row_index, prediction) in tree.oob_rows.iter().zip(predictions.iter().copied()) {
370            totals[row_index] += prediction;
371            counts[row_index] += 1;
372        }
373    }
374
375    let covered_rows: Vec<usize> = counts
376        .iter()
377        .enumerate()
378        .filter_map(|(row_index, count)| (*count > 0).then_some(row_index))
379        .collect();
380    if covered_rows.is_empty() {
381        return None;
382    }
383
384    let mean_target = covered_rows
385        .iter()
386        .map(|row_index| train_set.target_value(*row_index))
387        .sum::<f64>()
388        / covered_rows.len() as f64;
389    let mut residual_sum = 0.0;
390    let mut total_sum = 0.0;
391    for row_index in covered_rows {
392        let actual = train_set.target_value(row_index);
393        let prediction = totals[row_index] / counts[row_index] as f64;
394        residual_sum += (actual - prediction).powi(2);
395        total_sum += (actual - mean_target).powi(2);
396    }
397    if total_sum == 0.0 {
398        return None;
399    }
400    Some(1.0 - residual_sum / total_sum)
401}
402
403impl<'a> SampledTable<'a> {
404    fn new(base: &'a dyn TableAccess, row_indices: Vec<usize>) -> Self {
405        Self { base, row_indices }
406    }
407
408    fn resolve_row(&self, row_index: usize) -> usize {
409        self.row_indices[row_index]
410    }
411}
412
413impl<'a> NoCanaryTable<'a> {
414    fn new(base: &'a dyn TableAccess) -> Self {
415        Self { base }
416    }
417}
418
/// Forwarding `TableAccess` impl for the bootstrap row view.
///
/// Row-addressed accessors translate the view-local row index through
/// `resolve_row` before hitting the base table; column-addressed metadata
/// is forwarded unchanged.
impl TableAccess for SampledTable<'_> {
    // The view exposes exactly the sampled rows (duplicates included).
    fn n_rows(&self) -> usize {
        self.row_indices.len()
    }

    fn n_features(&self) -> usize {
        self.base.n_features()
    }

    fn canaries(&self) -> usize {
        self.base.canaries()
    }

    fn numeric_bin_cap(&self) -> usize {
        self.base.numeric_bin_cap()
    }

    fn binned_feature_count(&self) -> usize {
        self.base.binned_feature_count()
    }

    // Row-addressed: remap through the sampled indices.
    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        self.base
            .feature_value(feature_index, self.resolve_row(row_index))
    }

    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
        self.base
            .is_missing(feature_index, self.resolve_row(row_index))
    }

    fn is_binary_feature(&self, index: usize) -> bool {
        self.base.is_binary_feature(index)
    }

    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        self.base
            .binned_value(feature_index, self.resolve_row(row_index))
    }

    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        self.base
            .binned_boolean_value(feature_index, self.resolve_row(row_index))
    }

    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
        self.base.binned_column_kind(index)
    }

    fn is_binary_binned_feature(&self, index: usize) -> bool {
        self.base.is_binary_binned_feature(index)
    }

    fn target_value(&self, row_index: usize) -> f64 {
        self.base.target_value(self.resolve_row(row_index))
    }
}
476
477impl TableAccess for NoCanaryTable<'_> {
478    fn n_rows(&self) -> usize {
479        self.base.n_rows()
480    }
481
482    fn n_features(&self) -> usize {
483        self.base.n_features()
484    }
485
486    fn canaries(&self) -> usize {
487        0
488    }
489
490    fn numeric_bin_cap(&self) -> usize {
491        self.base.numeric_bin_cap()
492    }
493
494    fn binned_feature_count(&self) -> usize {
495        self.base.binned_feature_count() - self.base.canaries()
496    }
497
498    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
499        self.base.feature_value(feature_index, row_index)
500    }
501
502    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
503        self.base.is_missing(feature_index, row_index)
504    }
505
506    fn is_binary_feature(&self, index: usize) -> bool {
507        self.base.is_binary_feature(index)
508    }
509
510    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
511        self.base.binned_value(feature_index, row_index)
512    }
513
514    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
515        self.base.binned_boolean_value(feature_index, row_index)
516    }
517
518    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
519        self.base.binned_column_kind(index)
520    }
521
522    fn is_binary_binned_feature(&self, index: usize) -> bool {
523        self.base.is_binary_binned_feature(index)
524    }
525
526    fn target_value(&self, row_index: usize) -> f64 {
527        self.base.target_value(row_index)
528    }
529}