forestfire_core/forest.rs
1//! Random forest implementation.
2//!
3//! The forest deliberately reuses the single-tree trainers instead of
4//! maintaining a separate tree-building codepath. That keeps semantics aligned:
5//! every constituent tree is still "just" a normal ForestFire tree trained on a
6//! sampled table view with per-node feature subsampling.
7
8use crate::bootstrap::BootstrapSampler;
9use crate::ir::TrainingMetadata;
10use crate::tree::shared::mix_seed;
11use crate::{
12    Criterion, FeaturePreprocessing, MaxFeatures, Model, Parallelism, PredictError, Task,
13    TrainConfig, TrainError, TreeType, capture_feature_preprocessing, training,
14};
15use forestfire_data::TableAccess;
16use rayon::prelude::*;
17
/// Bagged ensemble of decision trees.
///
/// The forest stores full semantic [`Model`] trees rather than a bespoke forest
/// node format. That costs some memory, but it keeps IR conversion,
/// introspection, and optimized lowering consistent with the single-tree path.
#[derive(Debug, Clone)]
pub struct RandomForest {
    // Regression vs. classification; decides how tree outputs are aggregated.
    task: Task,
    // Split-quality criterion the constituent trees were trained with.
    criterion: Criterion,
    // Kind of tree trained for every ensemble member.
    tree_type: TreeType,
    // The trained trees; `train` rejects a zero tree count, so this is
    // non-empty for forests built through the training path.
    trees: Vec<Model>,
    // Whether an out-of-bag score was requested at training time.
    compute_oob: bool,
    // OOB accuracy (classification) or R² (regression) when it was computed
    // and at least one row was covered; `None` otherwise.
    oob_score: Option<f64>,
    // Resolved per-node feature-subsample size used during training.
    max_features: usize,
    // User-supplied RNG seed, if any (`None` means the fixed default was used).
    seed: Option<u64>,
    // Number of input features of the training table.
    num_features: usize,
    // Per-feature preprocessing captured from the training table.
    feature_preprocessing: Vec<FeaturePreprocessing>,
}
36
/// A single trained ensemble member plus the bookkeeping needed for OOB scoring.
struct TrainedTree {
    // The fitted tree model.
    model: Model,
    // Row indices of the training table that were NOT drawn into this tree's
    // bootstrap sample (its out-of-bag rows).
    oob_rows: Vec<usize>,
}
41
/// Row-remapping view over a base table: row `i` of the view is row
/// `row_indices[i]` of `base` (duplicates allowed, as produced by bootstrap
/// sampling). Lets tree trainers run on a sample without copying any data.
struct SampledTable<'a> {
    base: &'a dyn TableAccess,
    row_indices: Vec<usize>,
}
46
/// View that hides the base table's canary features: reports zero canaries and
/// shrinks the binned feature count accordingly, while delegating all data access.
struct NoCanaryTable<'a> {
    base: &'a dyn TableAccess,
}
50
impl RandomForest {
    /// Assembles a forest from already-trained components.
    ///
    /// Used by [`RandomForest::train`]; callers are responsible for passing
    /// mutually consistent parts (trees trained on `num_features` features,
    /// `oob_score` matching `compute_oob`, and so on).
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        task: Task,
        criterion: Criterion,
        tree_type: TreeType,
        trees: Vec<Model>,
        compute_oob: bool,
        oob_score: Option<f64>,
        max_features: usize,
        seed: Option<u64>,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
    ) -> Self {
        Self {
            task,
            criterion,
            tree_type,
            trees,
            compute_oob,
            oob_score,
            max_features,
            seed,
            num_features,
            feature_preprocessing,
        }
    }

    /// Trains a bagged forest of `config.n_trees` trees (default 1000).
    ///
    /// Each tree is trained on a bootstrap sample of `train_set` with per-node
    /// feature subsampling. `parallelism` controls whether trees are trained
    /// concurrently; every individual tree always trains sequentially.
    ///
    /// # Errors
    /// Returns [`TrainError::InvalidTreeCount`] when `n_trees` is zero,
    /// [`TrainError::InvalidMaxFeatures`] for `MaxFeatures::Count(0)`, and
    /// propagates any error from the single-tree trainer.
    pub(crate) fn train(
        train_set: &dyn TableAccess,
        config: TrainConfig,
        criterion: Criterion,
        parallelism: Parallelism,
    ) -> Result<Self, TrainError> {
        let n_trees = config.n_trees.unwrap_or(1000);
        if n_trees == 0 {
            return Err(TrainError::InvalidTreeCount(n_trees));
        }
        if matches!(config.max_features, MaxFeatures::Count(0)) {
            return Err(TrainError::InvalidMaxFeatures(0));
        }

        // Forests intentionally ignore canaries. For standalone trees they act as
        // a regularization/stopping heuristic, but in a forest they would make
        // bootstrap replicas overly conservative and reduce ensemble diversity.
        let train_set = NoCanaryTable::new(train_set);
        let sampler = BootstrapSampler::new(train_set.n_rows());
        let feature_preprocessing = capture_feature_preprocessing(&train_set);
        // Resolve the per-node subsample size against the canary-free count.
        let max_features = config
            .max_features
            .resolve(config.task, train_set.binned_feature_count());
        // Fixed fallback seed keeps unseeded training deterministic.
        let base_seed = config.seed.unwrap_or(0x0005_EEDF_0E57_u64);
        // Parallelism is spent across trees; each tree trains sequentially so
        // the configured thread budget is not multiplied.
        let tree_parallelism = Parallelism {
            thread_count: parallelism.thread_count,
        };
        let per_tree_parallelism = Parallelism::sequential();
        let train_tree = |tree_index: usize| -> Result<TrainedTree, TrainError> {
            // Per-tree seed derived from the base seed, so results are
            // reproducible regardless of tree scheduling order.
            let tree_seed = mix_seed(base_seed, tree_index as u64);
            let (sampled_rows, oob_rows) = sampler.sample_with_oob(tree_seed);
            // Sampling is implemented as a `TableAccess` view so the existing
            // tree trainers can stay oblivious to bootstrap mechanics.
            let sampled_table = SampledTable::new(&train_set, sampled_rows);
            let model = training::train_single_model_with_feature_subset(
                &sampled_table,
                training::SingleModelFeatureSubsetConfig {
                    base: training::SingleModelConfig {
                        task: config.task,
                        tree_type: config.tree_type,
                        criterion,
                        parallelism: per_tree_parallelism,
                        max_depth: config.max_depth.unwrap_or(8),
                        min_samples_split: config.min_samples_split.unwrap_or(2),
                        min_samples_leaf: config.min_samples_leaf.unwrap_or(1),
                    },
                    max_features: Some(max_features),
                    random_seed: tree_seed,
                },
            )?;
            Ok(TrainedTree { model, oob_rows })
        };
        // Train all trees, short-circuiting on the first error in either mode.
        let trained_trees = if tree_parallelism.enabled() {
            (0..n_trees)
                .into_par_iter()
                .map(train_tree)
                .collect::<Result<Vec<_>, _>>()?
        } else {
            (0..n_trees)
                .map(train_tree)
                .collect::<Result<Vec<_>, _>>()?
        };
        let oob_score = if config.compute_oob {
            compute_oob_score(config.task, &trained_trees, &train_set)
        } else {
            None
        };
        let trees = trained_trees.into_iter().map(|tree| tree.model).collect();

        Ok(Self::new(
            config.task,
            criterion,
            config.tree_type,
            trees,
            config.compute_oob,
            oob_score,
            max_features,
            config.seed,
            train_set.n_features(),
            feature_preprocessing,
        ))
    }

    /// Predicts one value per row of `table`: the ensemble mean for
    /// regression, or the highest-probability class label for classification.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        match self.task {
            Task::Regression => self.predict_regression_table(table),
            Task::Classification => self.predict_classification_table(table),
        }
    }

    /// Per-row class probabilities, averaged over all trees.
    ///
    /// # Errors
    /// Returns [`PredictError::ProbabilityPredictionRequiresClassification`]
    /// for regression forests, and propagates any per-tree prediction error.
    pub fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        if self.task != Task::Classification {
            return Err(PredictError::ProbabilityPredictionRequiresClassification);
        }

        // Forest classification aggregates full class distributions rather than
        // voting on hard labels. This keeps `predict` and `predict_proba`
        // consistent and matches the optimized runtime lowering.
        // Seed the totals with tree 0 (a trained forest has at least one tree).
        let mut totals = self.trees[0].predict_proba_table(table)?;
        for tree in &self.trees[1..] {
            let probs = tree.predict_proba_table(table)?;
            for (row_totals, row_probs) in totals.iter_mut().zip(probs.iter()) {
                for (total, prob) in row_totals.iter_mut().zip(row_probs.iter()) {
                    *total += *prob;
                }
            }
        }

        // Normalize the summed distributions into averages.
        let tree_count = self.trees.len() as f64;
        for row in &mut totals {
            for value in row {
                *value /= tree_count;
            }
        }

        Ok(totals)
    }

    /// The task (regression or classification) this forest was trained for.
    pub fn task(&self) -> Task {
        self.task
    }

    /// The split-quality criterion used during training.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }

    /// The tree type of every ensemble member.
    pub fn tree_type(&self) -> TreeType {
        self.tree_type
    }

    /// The trained trees, in training order.
    pub fn trees(&self) -> &[Model] {
        &self.trees
    }

    /// Number of input features the forest was trained on.
    pub fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-feature preprocessing captured from the training table.
    pub fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// Training metadata for the whole forest.
    ///
    /// Clones tree 0's metadata as a template, overwrites the forest-level
    /// fields, and clears the gradient-boosting-only fields that do not apply
    /// to a bagged ensemble.
    pub fn training_metadata(&self) -> TrainingMetadata {
        let mut metadata = self.trees[0].training_metadata();
        metadata.algorithm = "rf".to_string();
        metadata.n_trees = Some(self.trees.len());
        metadata.max_features = Some(self.max_features);
        metadata.seed = self.seed;
        metadata.compute_oob = self.compute_oob;
        metadata.oob_score = self.oob_score;
        metadata.learning_rate = None;
        metadata.bootstrap = None;
        metadata.top_gradient_fraction = None;
        metadata.other_gradient_fraction = None;
        metadata
    }

    /// Class labels for classification forests (taken from tree 0, which all
    /// trees share by construction); `None` for regression.
    pub fn class_labels(&self) -> Option<Vec<f64>> {
        match self.task {
            Task::Classification => self.trees[0].class_labels(),
            Task::Regression => None,
        }
    }

    /// The out-of-bag score, if it was computed during training.
    pub fn oob_score(&self) -> Option<f64> {
        self.oob_score
    }

    /// Mean of the per-tree regression predictions for every row.
    fn predict_regression_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let mut totals = self.trees[0].predict_table(table);
        for tree in &self.trees[1..] {
            let preds = tree.predict_table(table);
            for (total, pred) in totals.iter_mut().zip(preds.iter()) {
                *total += *pred;
            }
        }

        let tree_count = self.trees.len() as f64;
        for value in &mut totals {
            *value /= tree_count;
        }

        totals
    }

    /// Arg-max class label of the averaged probability distribution per row.
    fn predict_classification_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let probabilities = self
            .predict_proba_table(table)
            .expect("classification forest supports probabilities");
        let class_labels = self
            .class_labels()
            .expect("classification forest stores class labels");

        probabilities
            .into_iter()
            .map(|row| {
                // Ties on probability resolve to the lowest class index: for
                // equal values the `right_index.cmp(left_index)` tiebreaker
                // reports the earlier entry as greater, so `max_by` keeps it.
                let (best_index, _) = row
                    .iter()
                    .copied()
                    .enumerate()
                    .max_by(|(left_index, left), (right_index, right)| {
                        left.total_cmp(right)
                            .then_with(|| right_index.cmp(left_index))
                    })
                    .expect("classification probability row is non-empty");
                class_labels[best_index]
            })
            .collect()
    }
}
292
293fn compute_oob_score(
294    task: Task,
295    trained_trees: &[TrainedTree],
296    train_set: &dyn TableAccess,
297) -> Option<f64> {
298    match task {
299        Task::Classification => compute_classification_oob_score(trained_trees, train_set),
300        Task::Regression => compute_regression_oob_score(trained_trees, train_set),
301    }
302}
303
/// Out-of-bag accuracy for a classification forest.
///
/// For every training row, sums the class distributions predicted by the
/// trees that did NOT see that row during bootstrap sampling, takes the
/// arg-max class, and scores it against the true target. Rows that were
/// in-bag for every tree are skipped. Returns `None` when there are no
/// trees, the first tree exposes no class labels, or no row is covered.
fn compute_classification_oob_score(
    trained_trees: &[TrainedTree],
    train_set: &dyn TableAccess,
) -> Option<f64> {
    let class_labels = trained_trees.first()?.model.class_labels()?;
    // Per-row summed class probabilities and per-row OOB vote counts.
    let mut totals = vec![vec![0.0; class_labels.len()]; train_set.n_rows()];
    let mut counts = vec![0usize; train_set.n_rows()];

    for tree in trained_trees {
        if tree.oob_rows.is_empty() {
            continue;
        }
        // Predict only on this tree's out-of-bag rows via a sampled view.
        let oob_table = SampledTable::new(train_set, tree.oob_rows.clone());
        let probabilities = tree
            .model
            .predict_proba_table(&oob_table)
            .expect("classification tree supports predict_proba");
        for (&row_index, row_probs) in tree.oob_rows.iter().zip(probabilities.iter()) {
            for (total, prob) in totals[row_index].iter_mut().zip(row_probs.iter()) {
                *total += *prob;
            }
            counts[row_index] += 1;
        }
    }

    let mut correct = 0usize;
    let mut covered = 0usize;
    for row_index in 0..train_set.n_rows() {
        if counts[row_index] == 0 {
            // Row was in-bag for every tree; it cannot be scored.
            continue;
        }
        covered += 1;
        // Arg-max over summed probabilities; ties resolve to the lowest class
        // index (the `ri.cmp(li)` tiebreaker keeps the earlier entry), matching
        // the forest's `predict_classification_table` behavior.
        let predicted = totals[row_index]
            .iter()
            .copied()
            .enumerate()
            .max_by(|(li, lv), (ri, rv)| lv.total_cmp(rv).then_with(|| ri.cmp(li)))
            .map(|(index, _)| class_labels[index])
            .expect("classification probability row is non-empty");
        // Bitwise-exact comparison via total_cmp — assumes class labels are
        // exact copies of the target values seen in training (TODO confirm
        // against `Model::class_labels`).
        if predicted
            .total_cmp(&train_set.target_value(row_index))
            .is_eq()
        {
            correct += 1;
        }
    }

    (covered > 0).then_some(correct as f64 / covered as f64)
}
353
354fn compute_regression_oob_score(
355    trained_trees: &[TrainedTree],
356    train_set: &dyn TableAccess,
357) -> Option<f64> {
358    let mut totals = vec![0.0; train_set.n_rows()];
359    let mut counts = vec![0usize; train_set.n_rows()];
360
361    for tree in trained_trees {
362        if tree.oob_rows.is_empty() {
363            continue;
364        }
365        let oob_table = SampledTable::new(train_set, tree.oob_rows.clone());
366        let predictions = tree.model.predict_table(&oob_table);
367        for (&row_index, prediction) in tree.oob_rows.iter().zip(predictions.iter().copied()) {
368            totals[row_index] += prediction;
369            counts[row_index] += 1;
370        }
371    }
372
373    let covered_rows: Vec<usize> = counts
374        .iter()
375        .enumerate()
376        .filter_map(|(row_index, count)| (*count > 0).then_some(row_index))
377        .collect();
378    if covered_rows.is_empty() {
379        return None;
380    }
381
382    let mean_target = covered_rows
383        .iter()
384        .map(|row_index| train_set.target_value(*row_index))
385        .sum::<f64>()
386        / covered_rows.len() as f64;
387    let mut residual_sum = 0.0;
388    let mut total_sum = 0.0;
389    for row_index in covered_rows {
390        let actual = train_set.target_value(row_index);
391        let prediction = totals[row_index] / counts[row_index] as f64;
392        residual_sum += (actual - prediction).powi(2);
393        total_sum += (actual - mean_target).powi(2);
394    }
395    if total_sum == 0.0 {
396        return None;
397    }
398    Some(1.0 - residual_sum / total_sum)
399}
400
401impl<'a> SampledTable<'a> {
402    fn new(base: &'a dyn TableAccess, row_indices: Vec<usize>) -> Self {
403        Self { base, row_indices }
404    }
405
406    fn resolve_row(&self, row_index: usize) -> usize {
407        self.row_indices[row_index]
408    }
409}
410
411impl<'a> NoCanaryTable<'a> {
412    fn new(base: &'a dyn TableAccess) -> Self {
413        Self { base }
414    }
415}
416
417impl TableAccess for SampledTable<'_> {
418    fn n_rows(&self) -> usize {
419        self.row_indices.len()
420    }
421
422    fn n_features(&self) -> usize {
423        self.base.n_features()
424    }
425
426    fn canaries(&self) -> usize {
427        self.base.canaries()
428    }
429
430    fn numeric_bin_cap(&self) -> usize {
431        self.base.numeric_bin_cap()
432    }
433
434    fn binned_feature_count(&self) -> usize {
435        self.base.binned_feature_count()
436    }
437
438    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
439        self.base
440            .feature_value(feature_index, self.resolve_row(row_index))
441    }
442
443    fn is_binary_feature(&self, index: usize) -> bool {
444        self.base.is_binary_feature(index)
445    }
446
447    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
448        self.base
449            .binned_value(feature_index, self.resolve_row(row_index))
450    }
451
452    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
453        self.base
454            .binned_boolean_value(feature_index, self.resolve_row(row_index))
455    }
456
457    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
458        self.base.binned_column_kind(index)
459    }
460
461    fn is_binary_binned_feature(&self, index: usize) -> bool {
462        self.base.is_binary_binned_feature(index)
463    }
464
465    fn target_value(&self, row_index: usize) -> f64 {
466        self.base.target_value(self.resolve_row(row_index))
467    }
468}
469
470impl TableAccess for NoCanaryTable<'_> {
471    fn n_rows(&self) -> usize {
472        self.base.n_rows()
473    }
474
475    fn n_features(&self) -> usize {
476        self.base.n_features()
477    }
478
479    fn canaries(&self) -> usize {
480        0
481    }
482
483    fn numeric_bin_cap(&self) -> usize {
484        self.base.numeric_bin_cap()
485    }
486
487    fn binned_feature_count(&self) -> usize {
488        self.base.binned_feature_count() - self.base.canaries()
489    }
490
491    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
492        self.base.feature_value(feature_index, row_index)
493    }
494
495    fn is_binary_feature(&self, index: usize) -> bool {
496        self.base.is_binary_feature(index)
497    }
498
499    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
500        self.base.binned_value(feature_index, row_index)
501    }
502
503    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
504        self.base.binned_boolean_value(feature_index, row_index)
505    }
506
507    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
508        self.base.binned_column_kind(index)
509    }
510
511    fn is_binary_binned_feature(&self, index: usize) -> bool {
512        self.base.is_binary_binned_feature(index)
513    }
514
515    fn target_value(&self, row_index: usize) -> f64 {
516        self.base.target_value(row_index)
517    }
518}