Skip to main content

forestfire_core/
forest.rs

1use crate::bootstrap::BootstrapSampler;
2use crate::ir::TrainingMetadata;
3use crate::{
4    Criterion, FeaturePreprocessing, MaxFeatures, Model, Parallelism, PredictError, Task,
5    TrainConfig, TrainError, TreeType, capture_feature_preprocessing, training,
6};
7use forestfire_data::TableAccess;
8use rayon::prelude::*;
9
/// A bagged ensemble of decision trees: each tree is trained on an
/// independent bootstrap sample of the training table, and predictions
/// are averaged (regression) or probability-averaged (classification).
#[derive(Debug, Clone)]
pub struct RandomForest {
    task: Task,
    criterion: Criterion,
    tree_type: TreeType,
    // Individual trained trees; `train` guarantees at least one, but `new`
    // accepts any Vec — prediction methods index `trees[0]` and would panic
    // on an empty forest.
    trees: Vec<Model>,
    // Whether OOB scoring was requested at training time.
    compute_oob: bool,
    // OOB accuracy (classification) or R² (regression); None when not
    // computed or when no row was ever out-of-bag.
    oob_score: Option<f64>,
    // Resolved number of candidate features considered per split.
    max_features: usize,
    // User-supplied seed, if any (a default is substituted during training).
    seed: Option<u64>,
    num_features: usize,
    feature_preprocessing: Vec<FeaturePreprocessing>,
}
23
/// A freshly trained tree paired with the row indices that its bootstrap
/// sample left out — the rows eligible for out-of-bag scoring by this tree.
struct TrainedTree {
    model: Model,
    oob_rows: Vec<usize>,
}
28
/// A row-remapping view over `base`: logical row `i` maps to
/// `base` row `row_indices[i]` (duplicates allowed, as produced by
/// bootstrap sampling). Column metadata passes through unchanged.
struct SampledTable<'a> {
    base: &'a dyn TableAccess,
    row_indices: Vec<usize>,
}
33
/// A view over `base` that reports zero canary columns and shrinks the
/// binned feature count accordingly, so training never splits on canary
/// features. (Canaries are presumably sentinel/diagnostic columns appended
/// by the data layer — TODO confirm against forestfire_data.)
struct NoCanaryTable<'a> {
    base: &'a dyn TableAccess,
}
37
impl RandomForest {
    /// Assembles a forest from already-trained parts. Normally called by
    /// `train`; callers are responsible for supplying a non-empty `trees`
    /// vector, since the prediction methods index `trees[0]`.
    #[allow(clippy::too_many_arguments)] // mirrors the struct fields one-to-one
    pub fn new(
        task: Task,
        criterion: Criterion,
        tree_type: TreeType,
        trees: Vec<Model>,
        compute_oob: bool,
        oob_score: Option<f64>,
        max_features: usize,
        seed: Option<u64>,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
    ) -> Self {
        Self {
            task,
            criterion,
            tree_type,
            trees,
            compute_oob,
            oob_score,
            max_features,
            seed,
            num_features,
            feature_preprocessing,
        }
    }

    /// Trains a random forest on `train_set`.
    ///
    /// Defaults when unset in `config`: 1000 trees, max depth 8,
    /// min_samples_split 2, min_samples_leaf 1, and a fixed fallback seed.
    /// Each tree gets its own bootstrap sample and a per-tree seed derived
    /// via `mix_seed`, so results are reproducible for a given seed
    /// regardless of thread scheduling.
    ///
    /// Errors on a zero tree count or an explicit `MaxFeatures::Count(0)`.
    pub(crate) fn train(
        train_set: &dyn TableAccess,
        config: TrainConfig,
        criterion: Criterion,
        parallelism: Parallelism,
    ) -> Result<Self, TrainError> {
        let n_trees = config.n_trees.unwrap_or(1000);
        if n_trees == 0 {
            return Err(TrainError::InvalidTreeCount(n_trees));
        }
        if matches!(config.max_features, MaxFeatures::Count(0)) {
            return Err(TrainError::InvalidMaxFeatures(0));
        }

        // Shadow the input with a canary-hiding view so trees never see
        // canary columns (see `NoCanaryTable`).
        let train_set = NoCanaryTable::new(train_set);
        let sampler = BootstrapSampler::new(train_set.n_rows());
        let feature_preprocessing = capture_feature_preprocessing(&train_set);
        // Resolve e.g. sqrt/log rules against the canary-free binned count.
        let max_features = config
            .max_features
            .resolve(config.task, train_set.binned_feature_count());
        // Fallback seed constant (reads as a "SEED FOREST" mnemonic — cosmetic only).
        let base_seed = config.seed.unwrap_or(0x0005_EEDF_0E57_u64);
        let tree_parallelism = Parallelism {
            thread_count: parallelism.thread_count,
        };
        // Parallelism is applied ACROSS trees; each individual tree trains
        // sequentially to avoid nested parallelism.
        let per_tree_parallelism = Parallelism::sequential();
        let train_tree = |tree_index: usize| -> Result<TrainedTree, TrainError> {
            // Per-tree deterministic seed, independent of execution order.
            let tree_seed = mix_seed(base_seed, tree_index as u64);
            let (sampled_rows, oob_rows) = sampler.sample_with_oob(tree_seed);
            let sampled_table = SampledTable::new(&train_set, sampled_rows);
            let model = training::train_single_model_with_feature_subset(
                &sampled_table,
                training::SingleModelFeatureSubsetConfig {
                    base: training::SingleModelConfig {
                        task: config.task,
                        tree_type: config.tree_type,
                        criterion,
                        parallelism: per_tree_parallelism,
                        max_depth: config.max_depth.unwrap_or(8),
                        min_samples_split: config.min_samples_split.unwrap_or(2),
                        min_samples_leaf: config.min_samples_leaf.unwrap_or(1),
                    },
                    max_features: Some(max_features),
                    random_seed: tree_seed,
                },
            )?;
            Ok(TrainedTree { model, oob_rows })
        };
        // Same closure either way; rayon's collect preserves index order,
        // and short-circuits on the first Err.
        let trained_trees = if tree_parallelism.enabled() {
            (0..n_trees)
                .into_par_iter()
                .map(train_tree)
                .collect::<Result<Vec<_>, _>>()?
        } else {
            (0..n_trees)
                .map(train_tree)
                .collect::<Result<Vec<_>, _>>()?
        };
        let oob_score = if config.compute_oob {
            compute_oob_score(config.task, &trained_trees, &train_set)
        } else {
            None
        };
        let trees = trained_trees.into_iter().map(|tree| tree.model).collect();

        Ok(Self::new(
            config.task,
            criterion,
            config.tree_type,
            trees,
            config.compute_oob,
            oob_score,
            max_features,
            config.seed,
            // NOTE(review): NoCanaryTable forwards the base `n_features`
            // unchanged (unlike `binned_feature_count`) — confirm that the
            // stored feature count is intentionally canary-inclusive.
            train_set.n_features(),
            feature_preprocessing,
        ))
    }

    /// Predicts one value per row: the mean of tree outputs for regression,
    /// or the argmax class label of averaged probabilities for classification.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        match self.task {
            Task::Regression => self.predict_regression_table(table),
            Task::Classification => self.predict_classification_table(table),
        }
    }

    /// Returns per-row class probabilities, averaged uniformly over all trees.
    ///
    /// Errors if the forest is not a classification forest; panics if the
    /// forest has no trees (train guarantees at least one).
    pub fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        if self.task != Task::Classification {
            return Err(PredictError::ProbabilityPredictionRequiresClassification);
        }

        // Seed the accumulator with the first tree, then add the rest.
        let mut totals = self.trees[0].predict_proba_table(table)?;
        for tree in &self.trees[1..] {
            let probs = tree.predict_proba_table(table)?;
            for (row_totals, row_probs) in totals.iter_mut().zip(probs.iter()) {
                for (total, prob) in row_totals.iter_mut().zip(row_probs.iter()) {
                    *total += *prob;
                }
            }
        }

        // Normalize sums into an average.
        let tree_count = self.trees.len() as f64;
        for row in &mut totals {
            for value in row {
                *value /= tree_count;
            }
        }

        Ok(totals)
    }

    pub fn task(&self) -> Task {
        self.task
    }

    pub fn criterion(&self) -> Criterion {
        self.criterion
    }

    pub fn tree_type(&self) -> TreeType {
        self.tree_type
    }

    /// The individual trained trees, in training-index order.
    pub fn trees(&self) -> &[Model] {
        &self.trees
    }

    pub fn num_features(&self) -> usize {
        self.num_features
    }

    pub fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// Builds forest-level metadata by taking the first tree's metadata and
    /// overriding the ensemble fields; boosting-only fields are cleared.
    /// Panics if the forest has no trees.
    pub fn training_metadata(&self) -> TrainingMetadata {
        let mut metadata = self.trees[0].training_metadata();
        metadata.algorithm = "rf".to_string();
        metadata.n_trees = Some(self.trees.len());
        metadata.max_features = Some(self.max_features);
        metadata.seed = self.seed;
        metadata.compute_oob = self.compute_oob;
        metadata.oob_score = self.oob_score;
        // Not applicable to random forests (boosting parameters).
        metadata.learning_rate = None;
        metadata.bootstrap = None;
        metadata.top_gradient_fraction = None;
        metadata.other_gradient_fraction = None;
        metadata
    }

    /// Class labels as stored by the first tree (all trees share training
    /// data, so presumably identical across trees — relies on that invariant).
    /// `None` for regression forests.
    pub fn class_labels(&self) -> Option<Vec<f64>> {
        match self.task {
            Task::Classification => self.trees[0].class_labels(),
            Task::Regression => None,
        }
    }

    pub fn oob_score(&self) -> Option<f64> {
        self.oob_score
    }

    /// Mean of per-tree predictions. Panics if the forest has no trees.
    fn predict_regression_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let mut totals = self.trees[0].predict_table(table);
        for tree in &self.trees[1..] {
            let preds = tree.predict_table(table);
            for (total, pred) in totals.iter_mut().zip(preds.iter()) {
                *total += *pred;
            }
        }

        let tree_count = self.trees.len() as f64;
        for value in &mut totals {
            *value /= tree_count;
        }

        totals
    }

    /// Argmax of the averaged class probabilities, mapped to class labels.
    /// Exact probability ties are broken toward the lower class index
    /// (the reversed `right_index.cmp(left_index)` tiebreaker makes the
    /// earlier element win under `max_by`).
    fn predict_classification_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let probabilities = self
            .predict_proba_table(table)
            .expect("classification forest supports probabilities");
        let class_labels = self
            .class_labels()
            .expect("classification forest stores class labels");

        probabilities
            .into_iter()
            .map(|row| {
                let (best_index, _) = row
                    .iter()
                    .copied()
                    .enumerate()
                    .max_by(|(left_index, left), (right_index, right)| {
                        left.total_cmp(right)
                            .then_with(|| right_index.cmp(left_index))
                    })
                    .expect("classification probability row is non-empty");
                class_labels[best_index]
            })
            .collect()
    }
}
271
/// Derives a deterministic per-tree seed from the forest seed and a tree
/// index: the index is scrambled by a golden-ratio multiply plus rotation,
/// then XOR-folded into the base seed. `value == 0` leaves `base_seed`
/// unchanged.
fn mix_seed(base_seed: u64, value: u64) -> u64 {
    // 0x9E37_79B9_7F4A_7C15 ≈ 2^64 / φ, a common multiplicative mixer.
    const GOLDEN_RATIO: u64 = 0x9E37_79B9_7F4A_7C15;
    let scrambled = value.wrapping_mul(GOLDEN_RATIO).rotate_left(17);
    base_seed ^ scrambled
}
275
276fn compute_oob_score(
277    task: Task,
278    trained_trees: &[TrainedTree],
279    train_set: &dyn TableAccess,
280) -> Option<f64> {
281    match task {
282        Task::Classification => compute_classification_oob_score(trained_trees, train_set),
283        Task::Regression => compute_regression_oob_score(trained_trees, train_set),
284    }
285}
286
/// Computes out-of-bag accuracy for a classification forest.
///
/// Each row is scored only by trees whose bootstrap sample excluded it:
/// per-tree class probabilities are summed per row, and the prediction is
/// the argmax of the sums (exact ties toward the lower class index).
/// Returns `None` if there are no trees, the first tree exposes no class
/// labels, or no row was ever out-of-bag.
fn compute_classification_oob_score(
    trained_trees: &[TrainedTree],
    train_set: &dyn TableAccess,
) -> Option<f64> {
    let class_labels = trained_trees.first()?.model.class_labels()?;
    // Per-row probability sums and the number of trees that scored each row.
    let mut totals = vec![vec![0.0; class_labels.len()]; train_set.n_rows()];
    let mut counts = vec![0usize; train_set.n_rows()];

    for tree in trained_trees {
        if tree.oob_rows.is_empty() {
            continue;
        }
        // Restrict prediction to this tree's out-of-bag rows only.
        let oob_table = SampledTable::new(train_set, tree.oob_rows.clone());
        let probabilities = tree
            .model
            .predict_proba_table(&oob_table)
            .expect("classification tree supports predict_proba");
        for (&row_index, row_probs) in tree.oob_rows.iter().zip(probabilities.iter()) {
            for (total, prob) in totals[row_index].iter_mut().zip(row_probs.iter()) {
                *total += *prob;
            }
            counts[row_index] += 1;
        }
    }

    let mut correct = 0usize;
    let mut covered = 0usize;
    for row_index in 0..train_set.n_rows() {
        if counts[row_index] == 0 {
            // Row appeared in every bootstrap sample: no OOB prediction for it.
            continue;
        }
        covered += 1;
        // Argmax over summed probabilities; the reversed `ri.cmp(li)`
        // tiebreaker makes the lower class index win under `max_by`.
        let predicted = totals[row_index]
            .iter()
            .copied()
            .enumerate()
            .max_by(|(li, lv), (ri, rv)| lv.total_cmp(rv).then_with(|| ri.cmp(li)))
            .map(|(index, _)| class_labels[index])
            .expect("classification probability row is non-empty");
        // Labels are f64s; total_cmp gives an exact (bitwise-total-order) match.
        if predicted
            .total_cmp(&train_set.target_value(row_index))
            .is_eq()
        {
            correct += 1;
        }
    }

    (covered > 0).then_some(correct as f64 / covered as f64)
}
336
/// Computes the out-of-bag R² score for a regression forest.
///
/// Each row's OOB prediction is the mean over trees whose bootstrap sample
/// excluded it; R² = 1 - SS_res / SS_tot is evaluated over covered rows
/// only. Returns `None` when no row was ever out-of-bag, or when the
/// covered targets are constant (SS_tot == 0, R² undefined).
fn compute_regression_oob_score(
    trained_trees: &[TrainedTree],
    train_set: &dyn TableAccess,
) -> Option<f64> {
    // Per-row prediction sums and the number of trees that scored each row.
    let mut totals = vec![0.0; train_set.n_rows()];
    let mut counts = vec![0usize; train_set.n_rows()];

    for tree in trained_trees {
        if tree.oob_rows.is_empty() {
            continue;
        }
        // Restrict prediction to this tree's out-of-bag rows only.
        let oob_table = SampledTable::new(train_set, tree.oob_rows.clone());
        let predictions = tree.model.predict_table(&oob_table);
        for (&row_index, prediction) in tree.oob_rows.iter().zip(predictions.iter().copied()) {
            totals[row_index] += prediction;
            counts[row_index] += 1;
        }
    }

    // Rows with at least one OOB prediction; uncovered rows are excluded
    // from both the mean and the sums below.
    let covered_rows: Vec<usize> = counts
        .iter()
        .enumerate()
        .filter_map(|(row_index, count)| (*count > 0).then_some(row_index))
        .collect();
    if covered_rows.is_empty() {
        return None;
    }

    // Mean target over covered rows only (the R² baseline).
    let mean_target = covered_rows
        .iter()
        .map(|row_index| train_set.target_value(*row_index))
        .sum::<f64>()
        / covered_rows.len() as f64;
    let mut residual_sum = 0.0;
    let mut total_sum = 0.0;
    for row_index in covered_rows {
        let actual = train_set.target_value(row_index);
        // Average the summed predictions into the row's OOB prediction.
        let prediction = totals[row_index] / counts[row_index] as f64;
        residual_sum += (actual - prediction).powi(2);
        total_sum += (actual - mean_target).powi(2);
    }
    if total_sum == 0.0 {
        return None;
    }
    Some(1.0 - residual_sum / total_sum)
}
383
impl<'a> SampledTable<'a> {
    /// Creates a view over `base` restricted (and possibly reordered or
    /// duplicated) to `row_indices`.
    fn new(base: &'a dyn TableAccess, row_indices: Vec<usize>) -> Self {
        Self { base, row_indices }
    }

    /// Maps a logical row of this view to the underlying base-table row.
    /// Panics if `row_index >= row_indices.len()`.
    fn resolve_row(&self, row_index: usize) -> usize {
        self.row_indices[row_index]
    }
}
393
impl<'a> NoCanaryTable<'a> {
    /// Wraps `base` in a view that hides its canary columns (see the
    /// `TableAccess` impl below).
    fn new(base: &'a dyn TableAccess) -> Self {
        Self { base }
    }
}
399
/// `TableAccess` for the row-remapped view: row-indexed accessors go
/// through `resolve_row`, while column/metadata queries are forwarded to
/// the base table unchanged.
impl TableAccess for SampledTable<'_> {
    // Row count is the sample size, not the base table's row count.
    fn n_rows(&self) -> usize {
        self.row_indices.len()
    }

    fn n_features(&self) -> usize {
        self.base.n_features()
    }

    fn canaries(&self) -> usize {
        self.base.canaries()
    }

    fn numeric_bin_cap(&self) -> usize {
        self.base.numeric_bin_cap()
    }

    fn binned_feature_count(&self) -> usize {
        self.base.binned_feature_count()
    }

    // Row-indexed accessor: remapped through the sampled indices.
    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        self.base
            .feature_value(feature_index, self.resolve_row(row_index))
    }

    fn is_binary_feature(&self, index: usize) -> bool {
        self.base.is_binary_feature(index)
    }

    // Row-indexed accessor: remapped through the sampled indices.
    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        self.base
            .binned_value(feature_index, self.resolve_row(row_index))
    }

    // Row-indexed accessor: remapped through the sampled indices.
    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        self.base
            .binned_boolean_value(feature_index, self.resolve_row(row_index))
    }

    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
        self.base.binned_column_kind(index)
    }

    fn is_binary_binned_feature(&self, index: usize) -> bool {
        self.base.is_binary_binned_feature(index)
    }

    // Row-indexed accessor: remapped through the sampled indices.
    fn target_value(&self, row_index: usize) -> f64 {
        self.base.target_value(self.resolve_row(row_index))
    }
}
452
/// `TableAccess` for the canary-hiding view: reports zero canaries and a
/// correspondingly reduced binned feature count; everything else is
/// forwarded to the base table verbatim.
impl TableAccess for NoCanaryTable<'_> {
    fn n_rows(&self) -> usize {
        self.base.n_rows()
    }

    // NOTE(review): unlike `binned_feature_count`, this does NOT subtract
    // canaries — confirm raw features and binned features are indexed
    // differently (canaries presumably live only in the binned columns).
    fn n_features(&self) -> usize {
        self.base.n_features()
    }

    // The whole point of this view: downstream code sees no canary columns.
    fn canaries(&self) -> usize {
        0
    }

    fn numeric_bin_cap(&self) -> usize {
        self.base.numeric_bin_cap()
    }

    // Exclude canary columns from the binned feature count; this assumes
    // canaries occupy the trailing binned-column indices — TODO confirm.
    fn binned_feature_count(&self) -> usize {
        self.base.binned_feature_count() - self.base.canaries()
    }

    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        self.base.feature_value(feature_index, row_index)
    }

    fn is_binary_feature(&self, index: usize) -> bool {
        self.base.is_binary_feature(index)
    }

    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        self.base.binned_value(feature_index, row_index)
    }

    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        self.base.binned_boolean_value(feature_index, row_index)
    }

    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
        self.base.binned_column_kind(index)
    }

    fn is_binary_binned_feature(&self, index: usize) -> bool {
        self.base.is_binary_binned_feature(index)
    }

    fn target_value(&self, row_index: usize) -> f64 {
        self.base.target_value(row_index)
    }
}