1use crate::bootstrap::BootstrapSampler;
2use crate::ir::TrainingMetadata;
3use crate::{
4 Criterion, FeaturePreprocessing, MaxFeatures, Model, Parallelism, PredictError, Task,
5 TrainConfig, TrainError, TreeType, capture_feature_preprocessing, training,
6};
7use forestfire_data::TableAccess;
8use rayon::prelude::*;
9
/// A trained random-forest ensemble: independently trained trees whose
/// outputs are combined at prediction time (mean for regression, argmax of
/// averaged class probabilities for classification).
#[derive(Debug, Clone)]
pub struct RandomForest {
    /// Prediction task; selects how per-tree outputs are combined.
    task: Task,
    /// Split-quality criterion the trees were trained with.
    criterion: Criterion,
    /// Kind of tree grown for each ensemble member.
    tree_type: TreeType,
    /// The trained trees. Prediction and metadata methods index `trees[0]`,
    /// so this is expected to be non-empty (`train` rejects `n_trees == 0`).
    trees: Vec<Model>,
    /// Whether an out-of-bag score was requested at training time.
    compute_oob: bool,
    /// OOB accuracy (classification) or OOB R² (regression), when computed.
    oob_score: Option<f64>,
    /// Number of candidate features considered at each split.
    max_features: usize,
    /// User-supplied RNG seed, if any (a fixed default is used otherwise).
    seed: Option<u64>,
    /// Feature count of the training table, exposed via `num_features()`.
    num_features: usize,
    /// Per-feature preprocessing captured from the training table.
    feature_preprocessing: Vec<FeaturePreprocessing>,
}
23
/// Result of training one ensemble member: the tree model plus the row
/// indices that were left out of its bootstrap sample (used for OOB scoring).
struct TrainedTree {
    model: Model,
    oob_rows: Vec<usize>,
}
28
/// Row-remapping view over a base table: row `i` of this view is row
/// `row_indices[i]` of `base`. Used both for bootstrap samples (indices may
/// repeat) and for per-tree OOB row subsets.
struct SampledTable<'a> {
    base: &'a dyn TableAccess,
    row_indices: Vec<usize>,
}
33
/// View over a base table that reports zero canaries and shrinks the binned
/// feature count accordingly, so canary columns do not count toward the
/// per-split feature budget during training.
struct NoCanaryTable<'a> {
    base: &'a dyn TableAccess,
}
37
impl RandomForest {
    /// Assembles a forest from pre-trained trees and its recorded training
    /// settings. Performs no validation; `train` is the normal entry point.
    // The argument count mirrors the struct fields one-to-one; a builder
    // would add noise for a crate-internal constructor.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        task: Task,
        criterion: Criterion,
        tree_type: TreeType,
        trees: Vec<Model>,
        compute_oob: bool,
        oob_score: Option<f64>,
        max_features: usize,
        seed: Option<u64>,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
    ) -> Self {
        Self {
            task,
            criterion,
            tree_type,
            trees,
            compute_oob,
            oob_score,
            max_features,
            seed,
            num_features,
            feature_preprocessing,
        }
    }

    /// Trains a random forest on `train_set`.
    ///
    /// Each tree is grown on an independent bootstrap sample with a
    /// deterministic per-tree seed derived from `config.seed` (or a fixed
    /// default) via `mix_seed`, so results are reproducible regardless of
    /// thread scheduling. Parallelism is spent across trees; each individual
    /// tree trains sequentially.
    ///
    /// # Errors
    /// `TrainError::InvalidTreeCount` when `n_trees` resolves to 0,
    /// `TrainError::InvalidMaxFeatures` for `MaxFeatures::Count(0)`; errors
    /// from per-tree training are propagated.
    pub(crate) fn train(
        train_set: &dyn TableAccess,
        config: TrainConfig,
        criterion: Criterion,
        parallelism: Parallelism,
    ) -> Result<Self, TrainError> {
        let n_trees = config.n_trees.unwrap_or(1000);
        if n_trees == 0 {
            return Err(TrainError::InvalidTreeCount(n_trees));
        }
        if matches!(config.max_features, MaxFeatures::Count(0)) {
            return Err(TrainError::InvalidMaxFeatures(0));
        }

        // Shadow the input with a view that hides canary columns, so they are
        // excluded from the binned feature count used to resolve max_features.
        let train_set = NoCanaryTable::new(train_set);
        let sampler = BootstrapSampler::new(train_set.n_rows());
        let feature_preprocessing = capture_feature_preprocessing(&train_set);
        let max_features = config
            .max_features
            .resolve(config.task, train_set.binned_feature_count());
        // Fixed fallback seed keeps un-seeded training deterministic.
        let base_seed = config.seed.unwrap_or(0x0005_EEDF_0E57_u64);
        let tree_parallelism = Parallelism {
            thread_count: parallelism.thread_count,
        };
        // Each tree is trained single-threaded; concurrency comes from
        // training many trees at once below.
        let per_tree_parallelism = Parallelism::sequential();
        let train_tree = |tree_index: usize| -> Result<TrainedTree, TrainError> {
            // Seed depends only on (base seed, tree index): independent of
            // execution order, so parallel and sequential runs agree.
            let tree_seed = mix_seed(base_seed, tree_index as u64);
            let (sampled_rows, oob_rows) = sampler.sample_with_oob(tree_seed);
            let sampled_table = SampledTable::new(&train_set, sampled_rows);
            let model = training::train_single_model_with_feature_subset(
                &sampled_table,
                training::SingleModelFeatureSubsetConfig {
                    base: training::SingleModelConfig {
                        task: config.task,
                        tree_type: config.tree_type,
                        criterion,
                        parallelism: per_tree_parallelism,
                        // Defaults: depth 8, split at >=2 samples, leaves >=1.
                        max_depth: config.max_depth.unwrap_or(8),
                        min_samples_split: config.min_samples_split.unwrap_or(2),
                        min_samples_leaf: config.min_samples_leaf.unwrap_or(1),
                    },
                    max_features: Some(max_features),
                    random_seed: tree_seed,
                },
            )?;
            Ok(TrainedTree { model, oob_rows })
        };
        // collect::<Result<...>> short-circuits on the first failing tree.
        let trained_trees = if tree_parallelism.enabled() {
            (0..n_trees)
                .into_par_iter()
                .map(train_tree)
                .collect::<Result<Vec<_>, _>>()?
        } else {
            (0..n_trees)
                .map(train_tree)
                .collect::<Result<Vec<_>, _>>()?
        };
        let oob_score = if config.compute_oob {
            compute_oob_score(config.task, &trained_trees, &train_set)
        } else {
            None
        };
        let trees = trained_trees.into_iter().map(|tree| tree.model).collect();

        Ok(Self::new(
            config.task,
            criterion,
            config.tree_type,
            trees,
            config.compute_oob,
            oob_score,
            max_features,
            config.seed,
            train_set.n_features(),
            feature_preprocessing,
        ))
    }

    /// Predicts one value per row: the ensemble mean for regression, or the
    /// argmax class label of the averaged probabilities for classification.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        match self.task {
            Task::Regression => self.predict_regression_table(table),
            Task::Classification => self.predict_classification_table(table),
        }
    }

    /// Per-row class probabilities, averaged over all trees.
    ///
    /// # Errors
    /// `ProbabilityPredictionRequiresClassification` for regression forests;
    /// per-tree probability errors are propagated.
    ///
    /// # Panics
    /// Indexes `trees[0]`, so this panics on an empty ensemble; forests built
    /// by `train` always contain at least one tree.
    pub fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        if self.task != Task::Classification {
            return Err(PredictError::ProbabilityPredictionRequiresClassification);
        }

        // Seed the accumulator with the first tree, then add the rest.
        let mut totals = self.trees[0].predict_proba_table(table)?;
        for tree in &self.trees[1..] {
            let probs = tree.predict_proba_table(table)?;
            for (row_totals, row_probs) in totals.iter_mut().zip(probs.iter()) {
                for (total, prob) in row_totals.iter_mut().zip(row_probs.iter()) {
                    *total += *prob;
                }
            }
        }

        // Turn the per-class sums into means.
        let tree_count = self.trees.len() as f64;
        for row in &mut totals {
            for value in row {
                *value /= tree_count;
            }
        }

        Ok(totals)
    }

    /// Task this forest was trained for.
    pub fn task(&self) -> Task {
        self.task
    }

    /// Split criterion used during training.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }

    /// Kind of tree grown for each ensemble member.
    pub fn tree_type(&self) -> TreeType {
        self.tree_type
    }

    /// The trained trees, in training order.
    pub fn trees(&self) -> &[Model] {
        &self.trees
    }

    /// Number of features in the training table.
    pub fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-feature preprocessing captured at training time.
    pub fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// Aggregated training metadata: the first tree's metadata with
    /// ensemble-level fields overridden and boosting-only fields cleared.
    /// Panics on an empty ensemble (indexes `trees[0]`).
    pub fn training_metadata(&self) -> TrainingMetadata {
        let mut metadata = self.trees[0].training_metadata();
        metadata.algorithm = "rf".to_string();
        metadata.n_trees = Some(self.trees.len());
        metadata.max_features = Some(self.max_features);
        metadata.seed = self.seed;
        metadata.compute_oob = self.compute_oob;
        metadata.oob_score = self.oob_score;
        // Gradient-boosting-only fields are not applicable here.
        // NOTE(review): `bootstrap` is also cleared even though RF training
        // does bootstrap — confirm this is the intended serialization shape.
        metadata.learning_rate = None;
        metadata.bootstrap = None;
        metadata.top_gradient_fraction = None;
        metadata.other_gradient_fraction = None;
        metadata
    }

    /// Class labels for classification forests, taken from the first tree
    /// (all trees are trained on the same targets); `None` for regression.
    pub fn class_labels(&self) -> Option<Vec<f64>> {
        match self.task {
            Task::Classification => self.trees[0].class_labels(),
            Task::Regression => None,
        }
    }

    /// Out-of-bag score computed at training time, if requested and available.
    pub fn oob_score(&self) -> Option<f64> {
        self.oob_score
    }

    /// Mean of the per-tree regression predictions for each row.
    /// Panics on an empty ensemble (indexes `trees[0]`).
    fn predict_regression_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let mut totals = self.trees[0].predict_table(table);
        for tree in &self.trees[1..] {
            let preds = tree.predict_table(table);
            for (total, pred) in totals.iter_mut().zip(preds.iter()) {
                *total += *pred;
            }
        }

        let tree_count = self.trees.len() as f64;
        for value in &mut totals {
            *value /= tree_count;
        }

        totals
    }

    /// Argmax over the averaged class probabilities, mapped to class labels.
    fn predict_classification_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let probabilities = self
            .predict_proba_table(table)
            .expect("classification forest supports probabilities");
        let class_labels = self
            .class_labels()
            .expect("classification forest stores class labels");

        probabilities
            .into_iter()
            .map(|row| {
                // On equal probabilities the comparator ranks the earlier
                // index higher, so ties break toward the LOWEST class index.
                let (best_index, _) = row
                    .iter()
                    .copied()
                    .enumerate()
                    .max_by(|(left_index, left), (right_index, right)| {
                        left.total_cmp(right)
                            .then_with(|| right_index.cmp(left_index))
                    })
                    .expect("classification probability row is non-empty");
                class_labels[best_index]
            })
            .collect()
    }
}
271
/// Derives a deterministic per-tree seed from a base seed and a tree index.
///
/// The index is scrambled with a SplitMix64-style golden-ratio multiply plus
/// a rotation so that consecutive indices yield decorrelated seeds, then
/// XOR-folded into the base seed.
fn mix_seed(base_seed: u64, value: u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    let scrambled = value.wrapping_mul(GOLDEN_GAMMA).rotate_left(17);
    base_seed ^ scrambled
}
275
276fn compute_oob_score(
277 task: Task,
278 trained_trees: &[TrainedTree],
279 train_set: &dyn TableAccess,
280) -> Option<f64> {
281 match task {
282 Task::Classification => compute_classification_oob_score(trained_trees, train_set),
283 Task::Regression => compute_regression_oob_score(trained_trees, train_set),
284 }
285}
286
/// Out-of-bag accuracy for a classification forest.
///
/// For each row, class probabilities are summed over every tree whose
/// bootstrap sample excluded that row; the row's OOB prediction is the
/// argmax class. Returns `None` when the first tree exposes no class labels
/// or when no row was ever out-of-bag.
fn compute_classification_oob_score(
    trained_trees: &[TrainedTree],
    train_set: &dyn TableAccess,
) -> Option<f64> {
    let class_labels = trained_trees.first()?.model.class_labels()?;
    // Per-row probability sums and the number of trees that voted on each row.
    let mut totals = vec![vec![0.0; class_labels.len()]; train_set.n_rows()];
    let mut counts = vec![0usize; train_set.n_rows()];

    for tree in trained_trees {
        if tree.oob_rows.is_empty() {
            continue;
        }
        // View the training table through this tree's OOB rows only.
        let oob_table = SampledTable::new(train_set, tree.oob_rows.clone());
        let probabilities = tree
            .model
            .predict_proba_table(&oob_table)
            .expect("classification tree supports predict_proba");
        for (&row_index, row_probs) in tree.oob_rows.iter().zip(probabilities.iter()) {
            for (total, prob) in totals[row_index].iter_mut().zip(row_probs.iter()) {
                *total += *prob;
            }
            counts[row_index] += 1;
        }
    }

    // Score only rows that were OOB for at least one tree.
    let mut correct = 0usize;
    let mut covered = 0usize;
    for row_index in 0..train_set.n_rows() {
        if counts[row_index] == 0 {
            continue;
        }
        covered += 1;
        // On equal sums the comparator ranks the earlier index higher, so
        // ties break toward the LOWEST class index (matches prediction-time
        // tie-breaking in predict_classification_table).
        let predicted = totals[row_index]
            .iter()
            .copied()
            .enumerate()
            .max_by(|(li, lv), (ri, rv)| lv.total_cmp(rv).then_with(|| ri.cmp(li)))
            .map(|(index, _)| class_labels[index])
            .expect("classification probability row is non-empty");
        // total_cmp equality is a total-order match on the f64 labels
        // (e.g. -0.0 and +0.0 are NOT equal under this comparison).
        if predicted
            .total_cmp(&train_set.target_value(row_index))
            .is_eq()
        {
            correct += 1;
        }
    }

    // Accuracy over covered rows; None when nothing was covered.
    (covered > 0).then_some(correct as f64 / covered as f64)
}
336
337fn compute_regression_oob_score(
338 trained_trees: &[TrainedTree],
339 train_set: &dyn TableAccess,
340) -> Option<f64> {
341 let mut totals = vec![0.0; train_set.n_rows()];
342 let mut counts = vec![0usize; train_set.n_rows()];
343
344 for tree in trained_trees {
345 if tree.oob_rows.is_empty() {
346 continue;
347 }
348 let oob_table = SampledTable::new(train_set, tree.oob_rows.clone());
349 let predictions = tree.model.predict_table(&oob_table);
350 for (&row_index, prediction) in tree.oob_rows.iter().zip(predictions.iter().copied()) {
351 totals[row_index] += prediction;
352 counts[row_index] += 1;
353 }
354 }
355
356 let covered_rows: Vec<usize> = counts
357 .iter()
358 .enumerate()
359 .filter_map(|(row_index, count)| (*count > 0).then_some(row_index))
360 .collect();
361 if covered_rows.is_empty() {
362 return None;
363 }
364
365 let mean_target = covered_rows
366 .iter()
367 .map(|row_index| train_set.target_value(*row_index))
368 .sum::<f64>()
369 / covered_rows.len() as f64;
370 let mut residual_sum = 0.0;
371 let mut total_sum = 0.0;
372 for row_index in covered_rows {
373 let actual = train_set.target_value(row_index);
374 let prediction = totals[row_index] / counts[row_index] as f64;
375 residual_sum += (actual - prediction).powi(2);
376 total_sum += (actual - mean_target).powi(2);
377 }
378 if total_sum == 0.0 {
379 return None;
380 }
381 Some(1.0 - residual_sum / total_sum)
382}
383
384impl<'a> SampledTable<'a> {
385 fn new(base: &'a dyn TableAccess, row_indices: Vec<usize>) -> Self {
386 Self { base, row_indices }
387 }
388
389 fn resolve_row(&self, row_index: usize) -> usize {
390 self.row_indices[row_index]
391 }
392}
393
394impl<'a> NoCanaryTable<'a> {
395 fn new(base: &'a dyn TableAccess) -> Self {
396 Self { base }
397 }
398}
399
// Table-level metadata is delegated to the base table unchanged; every
// row-addressed accessor remaps the row through `resolve_row` first.
impl TableAccess for SampledTable<'_> {
    fn n_rows(&self) -> usize {
        // The view's row count is the number of selected rows, not the base's.
        self.row_indices.len()
    }

    fn n_features(&self) -> usize {
        self.base.n_features()
    }

    fn canaries(&self) -> usize {
        self.base.canaries()
    }

    fn numeric_bin_cap(&self) -> usize {
        self.base.numeric_bin_cap()
    }

    fn binned_feature_count(&self) -> usize {
        self.base.binned_feature_count()
    }

    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        self.base
            .feature_value(feature_index, self.resolve_row(row_index))
    }

    fn is_binary_feature(&self, index: usize) -> bool {
        self.base.is_binary_feature(index)
    }

    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        self.base
            .binned_value(feature_index, self.resolve_row(row_index))
    }

    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        self.base
            .binned_boolean_value(feature_index, self.resolve_row(row_index))
    }

    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
        self.base.binned_column_kind(index)
    }

    fn is_binary_binned_feature(&self, index: usize) -> bool {
        self.base.is_binary_binned_feature(index)
    }

    fn target_value(&self, row_index: usize) -> f64 {
        self.base.target_value(self.resolve_row(row_index))
    }
}
452
// Pure pass-through except for the two canary-related queries, which make
// the base table appear to have no canary columns.
impl TableAccess for NoCanaryTable<'_> {
    fn n_rows(&self) -> usize {
        self.base.n_rows()
    }

    fn n_features(&self) -> usize {
        self.base.n_features()
    }

    fn canaries(&self) -> usize {
        // The whole point of this view: report no canaries.
        0
    }

    fn numeric_bin_cap(&self) -> usize {
        self.base.numeric_bin_cap()
    }

    fn binned_feature_count(&self) -> usize {
        // Exclude canary columns from the feature count.
        // NOTE(review): underflows if canaries() > binned_feature_count() on
        // the base table; presumably canaries are a subset of the binned
        // features — confirm with the binning code.
        self.base.binned_feature_count() - self.base.canaries()
    }

    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        self.base.feature_value(feature_index, row_index)
    }

    fn is_binary_feature(&self, index: usize) -> bool {
        self.base.is_binary_feature(index)
    }

    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        self.base.binned_value(feature_index, row_index)
    }

    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        self.base.binned_boolean_value(feature_index, row_index)
    }

    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
        self.base.binned_column_kind(index)
    }

    fn is_binary_binned_feature(&self, index: usize) -> bool {
        self.base.is_binary_binned_feature(index)
    }

    fn target_value(&self, row_index: usize) -> f64 {
        self.base.target_value(row_index)
    }
}