Skip to main content

forestfire_core/
forest.rs

//! Random forest implementation.
//!
//! The forest deliberately reuses the single-tree trainers instead of
//! maintaining a separate tree-building code path. That keeps semantics aligned:
//! every constituent tree is still "just" a normal ForestFire tree, trained on a
//! bootstrap-sampled table view with per-node feature subsampling.
7
8use crate::bootstrap::BootstrapSampler;
9use crate::ir::TrainingMetadata;
10use crate::tree::shared::mix_seed;
11use crate::{
12    Criterion, FeaturePreprocessing, MaxFeatures, Model, Parallelism, PredictError, Task,
13    TrainError, TreeType, capture_feature_preprocessing, training,
14};
15use forestfire_data::TableAccess;
16use rayon::prelude::*;
17
/// Bagged ensemble of decision trees.
///
/// The forest stores full semantic [`Model`] trees rather than a bespoke forest
/// node format. That costs some memory, but it keeps IR conversion,
/// introspection, and optimized lowering consistent with the single-tree path.
#[derive(Debug, Clone)]
pub struct RandomForest {
    // Learning task; selects mean aggregation (regression) or averaged class
    // probabilities (classification) at prediction time.
    task: Task,
    // Split criterion every constituent tree was trained with.
    criterion: Criterion,
    // Tree type passed to every constituent tree's trainer.
    tree_type: TreeType,
    // Constituent trees. The prediction, metadata, and class-label accessors
    // read the first tree, so this must be non-empty; `train` rejects
    // `n_trees == 0`, but `new` does not check.
    trees: Vec<Model>,
    // Whether out-of-bag scoring was requested at training time.
    compute_oob: bool,
    // Out-of-bag accuracy (classification) or R² (regression); `None` when not
    // requested or when no score could be computed.
    oob_score: Option<f64>,
    // Resolved number of candidate features sampled per node.
    max_features: usize,
    // Seed exactly as configured; `None` means training fell back to a fixed
    // default base seed.
    seed: Option<u64>,
    // `n_features()` of the training table.
    num_features: usize,
    // Per-feature preprocessing captured from the training table.
    feature_preprocessing: Vec<FeaturePreprocessing>,
}
36
/// One tree produced during forest training, kept together with the rows of
/// the training table that its sample did not draw.
struct TrainedTree {
    // The fitted tree.
    model: Model,
    // Out-of-bag row indices (into the full training table) for this tree;
    // used only to compute the forest's OOB score.
    oob_rows: Vec<usize>,
}
41
/// Row-remapping view over another table: row `i` of the view reads row
/// `row_indices[i]` of `base`. Column-level metadata is forwarded unchanged.
struct SampledTable<'a> {
    // Table the view reads from.
    base: &'a dyn TableAccess,
    // View row -> base row mapping; its length is the view's row count.
    row_indices: Vec<usize>,
}
46
/// View that hides canary features from forest training: it reports zero
/// canaries and subtracts them from `binned_feature_count`, forwarding every
/// other accessor to `base` unchanged.
struct NoCanaryTable<'a> {
    // Table whose canaries are being hidden.
    base: &'a dyn TableAccess,
}
50
51impl RandomForest {
52    #[allow(clippy::too_many_arguments)]
53    pub fn new(
54        task: Task,
55        criterion: Criterion,
56        tree_type: TreeType,
57        trees: Vec<Model>,
58        compute_oob: bool,
59        oob_score: Option<f64>,
60        max_features: usize,
61        seed: Option<u64>,
62        num_features: usize,
63        feature_preprocessing: Vec<FeaturePreprocessing>,
64    ) -> Self {
65        Self {
66            task,
67            criterion,
68            tree_type,
69            trees,
70            compute_oob,
71            oob_score,
72            max_features,
73            seed,
74            num_features,
75            feature_preprocessing,
76        }
77    }
78
79    pub(crate) fn train(
80        train_set: &dyn TableAccess,
81        config: training::RandomForestConfig,
82        criterion: Criterion,
83        parallelism: Parallelism,
84    ) -> Result<Self, TrainError> {
85        let n_trees = config.n_trees;
86        if n_trees == 0 {
87            return Err(TrainError::InvalidTreeCount(n_trees));
88        }
89        if matches!(config.max_features, MaxFeatures::Count(0)) {
90            return Err(TrainError::InvalidMaxFeatures(0));
91        }
92
93        // Forests intentionally ignore canaries. For standalone trees they act as
94        // a regularization/stopping heuristic, but in a forest they would make
95        // bootstrap replicas overly conservative and reduce ensemble diversity.
96        let train_set = NoCanaryTable::new(train_set);
97        let sampler = BootstrapSampler::new(train_set.n_rows());
98        let feature_preprocessing = capture_feature_preprocessing(&train_set);
99        let max_features = config
100            .max_features
101            .resolve(config.task, train_set.binned_feature_count());
102        let base_seed = config.seed.unwrap_or(0x0005_EEDF_0E57_u64);
103        let tree_parallelism = Parallelism {
104            thread_count: parallelism.thread_count,
105        };
106        let per_tree_parallelism = Parallelism::sequential();
107        let train_tree = |tree_index: usize| -> Result<TrainedTree, TrainError> {
108            let tree_seed = mix_seed(base_seed, tree_index as u64);
109            let (sampled_rows, oob_rows) = sampler.sample_with_oob(tree_seed);
110            // Sampling is implemented as a `TableAccess` view so the existing
111            // tree trainers can stay oblivious to bootstrap mechanics.
112            let sampled_table = SampledTable::new(&train_set, sampled_rows);
113            let model = training::train_single_model_with_feature_subset(
114                &sampled_table,
115                training::SingleModelFeatureSubsetConfig {
116                    base: training::SingleModelConfig {
117                        task: config.task,
118                        tree_type: config.tree_type,
119                        criterion,
120                        parallelism: per_tree_parallelism,
121                        max_depth: config.max_depth,
122                        min_samples_split: config.min_samples_split,
123                        min_samples_leaf: config.min_samples_leaf,
124                        missing_value_strategies: config.missing_value_strategies.clone(),
125                    },
126                    max_features: Some(max_features),
127                    random_seed: tree_seed,
128                },
129            )?;
130            Ok(TrainedTree { model, oob_rows })
131        };
132        let trained_trees = if tree_parallelism.enabled() {
133            (0..n_trees)
134                .into_par_iter()
135                .map(train_tree)
136                .collect::<Result<Vec<_>, _>>()?
137        } else {
138            (0..n_trees)
139                .map(train_tree)
140                .collect::<Result<Vec<_>, _>>()?
141        };
142        let oob_score = if config.compute_oob {
143            compute_oob_score(config.task, &trained_trees, &train_set)
144        } else {
145            None
146        };
147        let trees = trained_trees.into_iter().map(|tree| tree.model).collect();
148
149        Ok(Self::new(
150            config.task,
151            criterion,
152            config.tree_type,
153            trees,
154            config.compute_oob,
155            oob_score,
156            max_features,
157            config.seed,
158            train_set.n_features(),
159            feature_preprocessing,
160        ))
161    }
162
163    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
164        match self.task {
165            Task::Regression => self.predict_regression_table(table),
166            Task::Classification => self.predict_classification_table(table),
167        }
168    }
169
170    pub fn predict_proba_table(
171        &self,
172        table: &dyn TableAccess,
173    ) -> Result<Vec<Vec<f64>>, PredictError> {
174        if self.task != Task::Classification {
175            return Err(PredictError::ProbabilityPredictionRequiresClassification);
176        }
177
178        // Forest classification aggregates full class distributions rather than
179        // voting on hard labels. This keeps `predict` and `predict_proba`
180        // consistent and matches the optimized runtime lowering.
181        let mut totals = self.trees[0].predict_proba_table(table)?;
182        for tree in &self.trees[1..] {
183            let probs = tree.predict_proba_table(table)?;
184            for (row_totals, row_probs) in totals.iter_mut().zip(probs.iter()) {
185                for (total, prob) in row_totals.iter_mut().zip(row_probs.iter()) {
186                    *total += *prob;
187                }
188            }
189        }
190
191        let tree_count = self.trees.len() as f64;
192        for row in &mut totals {
193            for value in row {
194                *value /= tree_count;
195            }
196        }
197
198        Ok(totals)
199    }
200
201    pub fn task(&self) -> Task {
202        self.task
203    }
204
205    pub fn criterion(&self) -> Criterion {
206        self.criterion
207    }
208
209    pub fn tree_type(&self) -> TreeType {
210        self.tree_type
211    }
212
213    pub fn trees(&self) -> &[Model] {
214        &self.trees
215    }
216
217    pub fn num_features(&self) -> usize {
218        self.num_features
219    }
220
221    pub fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
222        &self.feature_preprocessing
223    }
224
225    pub fn training_metadata(&self) -> TrainingMetadata {
226        let mut metadata = self.trees[0].training_metadata();
227        metadata.algorithm = "rf".to_string();
228        metadata.n_trees = Some(self.trees.len());
229        metadata.max_features = Some(self.max_features);
230        metadata.seed = self.seed;
231        metadata.compute_oob = self.compute_oob;
232        metadata.oob_score = self.oob_score;
233        metadata.learning_rate = None;
234        metadata.bootstrap = None;
235        metadata.top_gradient_fraction = None;
236        metadata.other_gradient_fraction = None;
237        metadata
238    }
239
240    pub fn class_labels(&self) -> Option<Vec<f64>> {
241        match self.task {
242            Task::Classification => self.trees[0].class_labels(),
243            Task::Regression => None,
244        }
245    }
246
247    pub fn oob_score(&self) -> Option<f64> {
248        self.oob_score
249    }
250
251    fn predict_regression_table(&self, table: &dyn TableAccess) -> Vec<f64> {
252        let mut totals = self.trees[0].predict_table(table);
253        for tree in &self.trees[1..] {
254            let preds = tree.predict_table(table);
255            for (total, pred) in totals.iter_mut().zip(preds.iter()) {
256                *total += *pred;
257            }
258        }
259
260        let tree_count = self.trees.len() as f64;
261        for value in &mut totals {
262            *value /= tree_count;
263        }
264
265        totals
266    }
267
268    fn predict_classification_table(&self, table: &dyn TableAccess) -> Vec<f64> {
269        let probabilities = self
270            .predict_proba_table(table)
271            .expect("classification forest supports probabilities");
272        let class_labels = self
273            .class_labels()
274            .expect("classification forest stores class labels");
275
276        probabilities
277            .into_iter()
278            .map(|row| {
279                let (best_index, _) = row
280                    .iter()
281                    .copied()
282                    .enumerate()
283                    .max_by(|(left_index, left), (right_index, right)| {
284                        left.total_cmp(right)
285                            .then_with(|| right_index.cmp(left_index))
286                    })
287                    .expect("classification probability row is non-empty");
288                class_labels[best_index]
289            })
290            .collect()
291    }
292}
293
294fn compute_oob_score(
295    task: Task,
296    trained_trees: &[TrainedTree],
297    train_set: &dyn TableAccess,
298) -> Option<f64> {
299    match task {
300        Task::Classification => compute_classification_oob_score(trained_trees, train_set),
301        Task::Regression => compute_regression_oob_score(trained_trees, train_set),
302    }
303}
304
305fn compute_classification_oob_score(
306    trained_trees: &[TrainedTree],
307    train_set: &dyn TableAccess,
308) -> Option<f64> {
309    let class_labels = trained_trees.first()?.model.class_labels()?;
310    let mut totals = vec![vec![0.0; class_labels.len()]; train_set.n_rows()];
311    let mut counts = vec![0usize; train_set.n_rows()];
312
313    for tree in trained_trees {
314        if tree.oob_rows.is_empty() {
315            continue;
316        }
317        let oob_table = SampledTable::new(train_set, tree.oob_rows.clone());
318        let probabilities = tree
319            .model
320            .predict_proba_table(&oob_table)
321            .expect("classification tree supports predict_proba");
322        for (&row_index, row_probs) in tree.oob_rows.iter().zip(probabilities.iter()) {
323            for (total, prob) in totals[row_index].iter_mut().zip(row_probs.iter()) {
324                *total += *prob;
325            }
326            counts[row_index] += 1;
327        }
328    }
329
330    let mut correct = 0usize;
331    let mut covered = 0usize;
332    for row_index in 0..train_set.n_rows() {
333        if counts[row_index] == 0 {
334            continue;
335        }
336        covered += 1;
337        let predicted = totals[row_index]
338            .iter()
339            .copied()
340            .enumerate()
341            .max_by(|(li, lv), (ri, rv)| lv.total_cmp(rv).then_with(|| ri.cmp(li)))
342            .map(|(index, _)| class_labels[index])
343            .expect("classification probability row is non-empty");
344        if predicted
345            .total_cmp(&train_set.target_value(row_index))
346            .is_eq()
347        {
348            correct += 1;
349        }
350    }
351
352    (covered > 0).then_some(correct as f64 / covered as f64)
353}
354
355fn compute_regression_oob_score(
356    trained_trees: &[TrainedTree],
357    train_set: &dyn TableAccess,
358) -> Option<f64> {
359    let mut totals = vec![0.0; train_set.n_rows()];
360    let mut counts = vec![0usize; train_set.n_rows()];
361
362    for tree in trained_trees {
363        if tree.oob_rows.is_empty() {
364            continue;
365        }
366        let oob_table = SampledTable::new(train_set, tree.oob_rows.clone());
367        let predictions = tree.model.predict_table(&oob_table);
368        for (&row_index, prediction) in tree.oob_rows.iter().zip(predictions.iter().copied()) {
369            totals[row_index] += prediction;
370            counts[row_index] += 1;
371        }
372    }
373
374    let covered_rows: Vec<usize> = counts
375        .iter()
376        .enumerate()
377        .filter_map(|(row_index, count)| (*count > 0).then_some(row_index))
378        .collect();
379    if covered_rows.is_empty() {
380        return None;
381    }
382
383    let mean_target = covered_rows
384        .iter()
385        .map(|row_index| train_set.target_value(*row_index))
386        .sum::<f64>()
387        / covered_rows.len() as f64;
388    let mut residual_sum = 0.0;
389    let mut total_sum = 0.0;
390    for row_index in covered_rows {
391        let actual = train_set.target_value(row_index);
392        let prediction = totals[row_index] / counts[row_index] as f64;
393        residual_sum += (actual - prediction).powi(2);
394        total_sum += (actual - mean_target).powi(2);
395    }
396    if total_sum == 0.0 {
397        return None;
398    }
399    Some(1.0 - residual_sum / total_sum)
400}
401
402impl<'a> SampledTable<'a> {
403    fn new(base: &'a dyn TableAccess, row_indices: Vec<usize>) -> Self {
404        Self { base, row_indices }
405    }
406
407    fn resolve_row(&self, row_index: usize) -> usize {
408        self.row_indices[row_index]
409    }
410}
411
412impl<'a> NoCanaryTable<'a> {
413    fn new(base: &'a dyn TableAccess) -> Self {
414        Self { base }
415    }
416}
417
418impl TableAccess for SampledTable<'_> {
419    fn n_rows(&self) -> usize {
420        self.row_indices.len()
421    }
422
423    fn n_features(&self) -> usize {
424        self.base.n_features()
425    }
426
427    fn canaries(&self) -> usize {
428        self.base.canaries()
429    }
430
431    fn numeric_bin_cap(&self) -> usize {
432        self.base.numeric_bin_cap()
433    }
434
435    fn binned_feature_count(&self) -> usize {
436        self.base.binned_feature_count()
437    }
438
439    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
440        self.base
441            .feature_value(feature_index, self.resolve_row(row_index))
442    }
443
444    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
445        self.base
446            .is_missing(feature_index, self.resolve_row(row_index))
447    }
448
449    fn is_binary_feature(&self, index: usize) -> bool {
450        self.base.is_binary_feature(index)
451    }
452
453    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
454        self.base
455            .binned_value(feature_index, self.resolve_row(row_index))
456    }
457
458    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
459        self.base
460            .binned_boolean_value(feature_index, self.resolve_row(row_index))
461    }
462
463    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
464        self.base.binned_column_kind(index)
465    }
466
467    fn is_binary_binned_feature(&self, index: usize) -> bool {
468        self.base.is_binary_binned_feature(index)
469    }
470
471    fn target_value(&self, row_index: usize) -> f64 {
472        self.base.target_value(self.resolve_row(row_index))
473    }
474}
475
476impl TableAccess for NoCanaryTable<'_> {
477    fn n_rows(&self) -> usize {
478        self.base.n_rows()
479    }
480
481    fn n_features(&self) -> usize {
482        self.base.n_features()
483    }
484
485    fn canaries(&self) -> usize {
486        0
487    }
488
489    fn numeric_bin_cap(&self) -> usize {
490        self.base.numeric_bin_cap()
491    }
492
493    fn binned_feature_count(&self) -> usize {
494        self.base.binned_feature_count() - self.base.canaries()
495    }
496
497    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
498        self.base.feature_value(feature_index, row_index)
499    }
500
501    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
502        self.base.is_missing(feature_index, row_index)
503    }
504
505    fn is_binary_feature(&self, index: usize) -> bool {
506        self.base.is_binary_feature(index)
507    }
508
509    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
510        self.base.binned_value(feature_index, row_index)
511    }
512
513    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
514        self.base.binned_boolean_value(feature_index, row_index)
515    }
516
517    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
518        self.base.binned_column_kind(index)
519    }
520
521    fn is_binary_binned_feature(&self, index: usize) -> bool {
522        self.base.is_binary_binned_feature(index)
523    }
524
525    fn target_value(&self, row_index: usize) -> f64 {
526        self.base.target_value(row_index)
527    }
528}