1use crate::bootstrap::BootstrapSampler;
9use crate::ir::TrainingMetadata;
10use crate::tree::shared::mix_seed;
11use crate::{
12 Criterion, FeaturePreprocessing, MaxFeatures, Model, Parallelism, PredictError, Task,
13 TrainError, TreeType, capture_feature_preprocessing, training,
14};
15use forestfire_data::TableAccess;
16use rayon::prelude::*;
17
/// A bagged ensemble of decision trees sharing one task/criterion configuration.
///
/// Constructed either directly via [`RandomForest::new`] or by training with
/// [`RandomForest::train`]. Several methods index `trees[0]`, so the tree list
/// is expected to be non-empty (training rejects a zero tree count).
#[derive(Debug, Clone)]
pub struct RandomForest {
    // Classification or regression; selects the prediction path.
    task: Task,
    // Split-quality criterion the member trees were trained with.
    criterion: Criterion,
    // Kind of tree making up the ensemble.
    tree_type: TreeType,
    // The fitted member models; expected non-empty.
    trees: Vec<Model>,
    // Whether out-of-bag scoring was requested at training time.
    compute_oob: bool,
    // OOB score if computed: accuracy for classification, R^2 for regression.
    oob_score: Option<f64>,
    // Resolved number of candidate features considered per split.
    max_features: usize,
    // User-supplied RNG seed, if any (None means the default seed was used).
    seed: Option<u64>,
    // Feature count of the training table.
    num_features: usize,
    // Per-feature preprocessing captured from the training table.
    feature_preprocessing: Vec<FeaturePreprocessing>,
}
36
/// A single fitted tree paired with the training rows its bootstrap sample
/// left out (the out-of-bag rows used for OOB scoring).
struct TrainedTree {
    model: Model,
    // Indices into the full training table of rows NOT drawn for this tree.
    oob_rows: Vec<usize>,
}
41
/// A `TableAccess` view exposing only `row_indices` of `base`, in order.
/// Duplicates are allowed, which is what makes bootstrap sampling work.
struct SampledTable<'a> {
    base: &'a dyn TableAccess,
    // View row i maps to base row `row_indices[i]`.
    row_indices: Vec<usize>,
}
46
/// A `TableAccess` pass-through that hides the base table's canary columns
/// (reports zero canaries and a correspondingly reduced binned feature count).
struct NoCanaryTable<'a> {
    base: &'a dyn TableAccess,
}
50
impl RandomForest {
    /// Assembles a forest directly from pre-trained parts.
    ///
    /// `trees` must be non-empty: `predict_proba_table`, `training_metadata`,
    /// and `class_labels` all index `trees[0]`.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        task: Task,
        criterion: Criterion,
        tree_type: TreeType,
        trees: Vec<Model>,
        compute_oob: bool,
        oob_score: Option<f64>,
        max_features: usize,
        seed: Option<u64>,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
    ) -> Self {
        Self {
            task,
            criterion,
            tree_type,
            trees,
            compute_oob,
            oob_score,
            max_features,
            seed,
            num_features,
            feature_preprocessing,
        }
    }

    /// Trains `config.n_trees` trees, each on its own bootstrap sample.
    ///
    /// Every tree's RNG seed is derived from the forest seed via `mix_seed`
    /// on the tree index, so results are reproducible and independent of
    /// whether the trees are trained in parallel or sequentially.
    ///
    /// # Errors
    /// * `TrainError::InvalidTreeCount` when `config.n_trees == 0`.
    /// * `TrainError::InvalidMaxFeatures` when `config.max_features` is `Count(0)`.
    /// * Any error propagated from single-tree training.
    pub(crate) fn train(
        train_set: &dyn TableAccess,
        config: training::RandomForestConfig,
        criterion: Criterion,
        parallelism: Parallelism,
    ) -> Result<Self, TrainError> {
        let n_trees = config.n_trees;
        if n_trees == 0 {
            return Err(TrainError::InvalidTreeCount(n_trees));
        }
        if matches!(config.max_features, MaxFeatures::Count(0)) {
            return Err(TrainError::InvalidMaxFeatures(0));
        }

        // Hide canary columns from the per-tree trainer: the wrapper reports
        // zero canaries and shrinks the binned feature count accordingly.
        let train_set = NoCanaryTable::new(train_set);
        let sampler = BootstrapSampler::new(train_set.n_rows());
        let feature_preprocessing = capture_feature_preprocessing(&train_set);
        let max_features = config
            .max_features
            .resolve(config.task, train_set.binned_feature_count());
        // Fixed fallback seed keeps training deterministic when no seed is given.
        let base_seed = config.seed.unwrap_or(0x0005_EEDF_0E57_u64);
        let tree_parallelism = Parallelism {
            thread_count: parallelism.thread_count,
        };
        // The thread budget is spent across trees; each individual tree trains
        // sequentially, avoiding nested parallel scheduling.
        let per_tree_parallelism = Parallelism::sequential();
        let train_tree = |tree_index: usize| -> Result<TrainedTree, TrainError> {
            let tree_seed = mix_seed(base_seed, tree_index as u64);
            // Bootstrap draw (with replacement) plus the left-out rows for OOB.
            let (sampled_rows, oob_rows) = sampler.sample_with_oob(tree_seed);
            let sampled_table = SampledTable::new(&train_set, sampled_rows);
            let model = training::train_single_model_with_feature_subset(
                &sampled_table,
                training::SingleModelFeatureSubsetConfig {
                    base: training::SingleModelConfig {
                        task: config.task,
                        tree_type: config.tree_type,
                        criterion,
                        parallelism: per_tree_parallelism,
                        max_depth: config.max_depth,
                        min_samples_split: config.min_samples_split,
                        min_samples_leaf: config.min_samples_leaf,
                        missing_value_strategies: config.missing_value_strategies.clone(),
                        canary_filter: crate::CanaryFilter::default(),
                    },
                    max_features: Some(max_features),
                    random_seed: tree_seed,
                },
            )?;
            Ok(TrainedTree { model, oob_rows })
        };
        // Same closure on both paths; collecting into Result short-circuits
        // on the first tree that fails to train.
        let trained_trees = if tree_parallelism.enabled() {
            (0..n_trees)
                .into_par_iter()
                .map(train_tree)
                .collect::<Result<Vec<_>, _>>()?
        } else {
            (0..n_trees)
                .map(train_tree)
                .collect::<Result<Vec<_>, _>>()?
        };
        let oob_score = if config.compute_oob {
            compute_oob_score(config.task, &trained_trees, &train_set)
        } else {
            None
        };
        let trees = trained_trees.into_iter().map(|tree| tree.model).collect();

        Ok(Self::new(
            config.task,
            criterion,
            config.tree_type,
            trees,
            config.compute_oob,
            oob_score,
            max_features,
            config.seed,
            train_set.n_features(),
            feature_preprocessing,
        ))
    }

    /// Predicts one value per row: the tree-averaged prediction for
    /// regression, or the highest-probability class label for classification.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        match self.task {
            Task::Regression => self.predict_regression_table(table),
            Task::Classification => self.predict_classification_table(table),
        }
    }

    /// Returns per-row class probabilities averaged over all trees.
    ///
    /// # Errors
    /// `PredictError::ProbabilityPredictionRequiresClassification` for a
    /// regression forest; otherwise any error from a member tree.
    pub fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        if self.task != Task::Classification {
            return Err(PredictError::ProbabilityPredictionRequiresClassification);
        }

        // `trees` is non-empty by construction (train rejects n_trees == 0).
        let mut totals = self.trees[0].predict_proba_table(table)?;
        for tree in &self.trees[1..] {
            let probs = tree.predict_proba_table(table)?;
            for (row_totals, row_probs) in totals.iter_mut().zip(probs.iter()) {
                for (total, prob) in row_totals.iter_mut().zip(row_probs.iter()) {
                    *total += *prob;
                }
            }
        }

        // Turn the per-class sums into averages.
        let tree_count = self.trees.len() as f64;
        for row in &mut totals {
            for value in row {
                *value /= tree_count;
            }
        }

        Ok(totals)
    }

    /// The task this forest was trained for.
    pub fn task(&self) -> Task {
        self.task
    }

    /// The split criterion used during training.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }

    /// The kind of tree making up the ensemble.
    pub fn tree_type(&self) -> TreeType {
        self.tree_type
    }

    /// The fitted member models.
    pub fn trees(&self) -> &[Model] {
        &self.trees
    }

    /// Feature count of the training table.
    pub fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-feature preprocessing captured from the training table.
    pub fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// Training metadata for the ensemble, based on the first tree's metadata
    /// with forest-level fields overwritten.
    pub fn training_metadata(&self) -> TrainingMetadata {
        let mut metadata = self.trees[0].training_metadata();
        metadata.algorithm = "rf".to_string();
        metadata.n_trees = Some(self.trees.len());
        metadata.max_features = Some(self.max_features);
        metadata.seed = self.seed;
        metadata.compute_oob = self.compute_oob;
        metadata.oob_score = self.oob_score;
        // Clear fields that only apply to other algorithms (presumably
        // gradient boosting — confirm against TrainingMetadata's docs).
        metadata.learning_rate = None;
        metadata.bootstrap = None;
        metadata.top_gradient_fraction = None;
        metadata.other_gradient_fraction = None;
        metadata
    }

    /// Class labels from the first tree, or `None` for regression forests.
    pub fn class_labels(&self) -> Option<Vec<f64>> {
        match self.task {
            Task::Classification => self.trees[0].class_labels(),
            Task::Regression => None,
        }
    }

    /// The out-of-bag score, if it was computed during training.
    pub fn oob_score(&self) -> Option<f64> {
        self.oob_score
    }

    // Averages the per-row predictions of all trees.
    fn predict_regression_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let mut totals = self.trees[0].predict_table(table);
        for tree in &self.trees[1..] {
            let preds = tree.predict_table(table);
            for (total, pred) in totals.iter_mut().zip(preds.iter()) {
                *total += *pred;
            }
        }

        let tree_count = self.trees.len() as f64;
        for value in &mut totals {
            *value /= tree_count;
        }

        totals
    }

    // Picks the label whose averaged probability is highest for each row.
    fn predict_classification_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let probabilities = self
            .predict_proba_table(table)
            .expect("classification forest supports probabilities");
        let class_labels = self
            .class_labels()
            .expect("classification forest stores class labels");

        probabilities
            .into_iter()
            .map(|row| {
                // Argmax over probabilities; on exact ties the comparator
                // ranks the smaller index higher, so the lowest class
                // index wins deterministically.
                let (best_index, _) = row
                    .iter()
                    .copied()
                    .enumerate()
                    .max_by(|(left_index, left), (right_index, right)| {
                        left.total_cmp(right)
                            .then_with(|| right_index.cmp(left_index))
                    })
                    .expect("classification probability row is non-empty");
                class_labels[best_index]
            })
            .collect()
    }
}
294
295fn compute_oob_score(
296 task: Task,
297 trained_trees: &[TrainedTree],
298 train_set: &dyn TableAccess,
299) -> Option<f64> {
300 match task {
301 Task::Classification => compute_classification_oob_score(trained_trees, train_set),
302 Task::Regression => compute_regression_oob_score(trained_trees, train_set),
303 }
304}
305
/// Out-of-bag accuracy: for each training row, average the probability
/// vectors of the trees that did NOT see that row in their bootstrap sample,
/// take the argmax class, and compare it against the row's target.
///
/// Returns `None` when there are no trees, the first tree exposes no class
/// labels, or no row was left out by any tree.
fn compute_classification_oob_score(
    trained_trees: &[TrainedTree],
    train_set: &dyn TableAccess,
) -> Option<f64> {
    let class_labels = trained_trees.first()?.model.class_labels()?;
    // Per-row sums of class probabilities and per-row OOB vote counts.
    let mut totals = vec![vec![0.0; class_labels.len()]; train_set.n_rows()];
    let mut counts = vec![0usize; train_set.n_rows()];

    for tree in trained_trees {
        if tree.oob_rows.is_empty() {
            continue;
        }
        // Predict only on this tree's out-of-bag rows via a sampled view.
        let oob_table = SampledTable::new(train_set, tree.oob_rows.clone());
        let probabilities = tree
            .model
            .predict_proba_table(&oob_table)
            .expect("classification tree supports predict_proba");
        // The i-th probability row corresponds to the i-th OOB row index.
        for (&row_index, row_probs) in tree.oob_rows.iter().zip(probabilities.iter()) {
            for (total, prob) in totals[row_index].iter_mut().zip(row_probs.iter()) {
                *total += *prob;
            }
            counts[row_index] += 1;
        }
    }

    let mut correct = 0usize;
    let mut covered = 0usize;
    for row_index in 0..train_set.n_rows() {
        if counts[row_index] == 0 {
            // Row appeared in every tree's bootstrap sample; no OOB estimate.
            continue;
        }
        covered += 1;
        // Argmax over summed probabilities; ties resolve to the lower class
        // index (the comparator ranks the smaller index higher).
        let predicted = totals[row_index]
            .iter()
            .copied()
            .enumerate()
            .max_by(|(li, lv), (ri, rv)| lv.total_cmp(rv).then_with(|| ri.cmp(li)))
            .map(|(index, _)| class_labels[index])
            .expect("classification probability row is non-empty");
        // total_cmp gives a total order on f64, so label equality is exact.
        if predicted
            .total_cmp(&train_set.target_value(row_index))
            .is_eq()
        {
            correct += 1;
        }
    }

    (covered > 0).then_some(correct as f64 / covered as f64)
}
355
356fn compute_regression_oob_score(
357 trained_trees: &[TrainedTree],
358 train_set: &dyn TableAccess,
359) -> Option<f64> {
360 let mut totals = vec![0.0; train_set.n_rows()];
361 let mut counts = vec![0usize; train_set.n_rows()];
362
363 for tree in trained_trees {
364 if tree.oob_rows.is_empty() {
365 continue;
366 }
367 let oob_table = SampledTable::new(train_set, tree.oob_rows.clone());
368 let predictions = tree.model.predict_table(&oob_table);
369 for (&row_index, prediction) in tree.oob_rows.iter().zip(predictions.iter().copied()) {
370 totals[row_index] += prediction;
371 counts[row_index] += 1;
372 }
373 }
374
375 let covered_rows: Vec<usize> = counts
376 .iter()
377 .enumerate()
378 .filter_map(|(row_index, count)| (*count > 0).then_some(row_index))
379 .collect();
380 if covered_rows.is_empty() {
381 return None;
382 }
383
384 let mean_target = covered_rows
385 .iter()
386 .map(|row_index| train_set.target_value(*row_index))
387 .sum::<f64>()
388 / covered_rows.len() as f64;
389 let mut residual_sum = 0.0;
390 let mut total_sum = 0.0;
391 for row_index in covered_rows {
392 let actual = train_set.target_value(row_index);
393 let prediction = totals[row_index] / counts[row_index] as f64;
394 residual_sum += (actual - prediction).powi(2);
395 total_sum += (actual - mean_target).powi(2);
396 }
397 if total_sum == 0.0 {
398 return None;
399 }
400 Some(1.0 - residual_sum / total_sum)
401}
402
403impl<'a> SampledTable<'a> {
404 fn new(base: &'a dyn TableAccess, row_indices: Vec<usize>) -> Self {
405 Self { base, row_indices }
406 }
407
408 fn resolve_row(&self, row_index: usize) -> usize {
409 self.row_indices[row_index]
410 }
411}
412
413impl<'a> NoCanaryTable<'a> {
414 fn new(base: &'a dyn TableAccess) -> Self {
415 Self { base }
416 }
417}
418
419impl TableAccess for SampledTable<'_> {
420 fn n_rows(&self) -> usize {
421 self.row_indices.len()
422 }
423
424 fn n_features(&self) -> usize {
425 self.base.n_features()
426 }
427
428 fn canaries(&self) -> usize {
429 self.base.canaries()
430 }
431
432 fn numeric_bin_cap(&self) -> usize {
433 self.base.numeric_bin_cap()
434 }
435
436 fn binned_feature_count(&self) -> usize {
437 self.base.binned_feature_count()
438 }
439
440 fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
441 self.base
442 .feature_value(feature_index, self.resolve_row(row_index))
443 }
444
445 fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
446 self.base
447 .is_missing(feature_index, self.resolve_row(row_index))
448 }
449
450 fn is_binary_feature(&self, index: usize) -> bool {
451 self.base.is_binary_feature(index)
452 }
453
454 fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
455 self.base
456 .binned_value(feature_index, self.resolve_row(row_index))
457 }
458
459 fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
460 self.base
461 .binned_boolean_value(feature_index, self.resolve_row(row_index))
462 }
463
464 fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
465 self.base.binned_column_kind(index)
466 }
467
468 fn is_binary_binned_feature(&self, index: usize) -> bool {
469 self.base.is_binary_binned_feature(index)
470 }
471
472 fn target_value(&self, row_index: usize) -> f64 {
473 self.base.target_value(self.resolve_row(row_index))
474 }
475}
476
477impl TableAccess for NoCanaryTable<'_> {
478 fn n_rows(&self) -> usize {
479 self.base.n_rows()
480 }
481
482 fn n_features(&self) -> usize {
483 self.base.n_features()
484 }
485
486 fn canaries(&self) -> usize {
487 0
488 }
489
490 fn numeric_bin_cap(&self) -> usize {
491 self.base.numeric_bin_cap()
492 }
493
494 fn binned_feature_count(&self) -> usize {
495 self.base.binned_feature_count() - self.base.canaries()
496 }
497
498 fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
499 self.base.feature_value(feature_index, row_index)
500 }
501
502 fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
503 self.base.is_missing(feature_index, row_index)
504 }
505
506 fn is_binary_feature(&self, index: usize) -> bool {
507 self.base.is_binary_feature(index)
508 }
509
510 fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
511 self.base.binned_value(feature_index, row_index)
512 }
513
514 fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
515 self.base.binned_boolean_value(feature_index, row_index)
516 }
517
518 fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
519 self.base.binned_column_kind(index)
520 }
521
522 fn is_binary_binned_feature(&self, index: usize) -> bool {
523 self.base.is_binary_binned_feature(index)
524 }
525
526 fn target_value(&self, row_index: usize) -> f64 {
527 self.base.target_value(row_index)
528 }
529}