1use crate::bootstrap::BootstrapSampler;
9use crate::ir::TrainingMetadata;
10use crate::tree::shared::mix_seed;
11use crate::{
12 Criterion, FeaturePreprocessing, MaxFeatures, Model, Parallelism, PredictError, Task,
13 TrainError, TreeType, capture_feature_preprocessing, training,
14};
15use forestfire_data::TableAccess;
16use rayon::prelude::*;
17
/// A bagged ensemble of decision trees trained on bootstrap row samples
/// with per-split feature subsets.
///
/// Built via `RandomForest::train`; `new` is a raw constructor that
/// performs no validation.
#[derive(Debug, Clone)]
pub struct RandomForest {
    // Regression or classification; selects how tree outputs are combined.
    task: Task,
    // Split-quality criterion the member trees were trained with.
    criterion: Criterion,
    // Kind of tree used for every member model.
    tree_type: TreeType,
    // The trained member models. Several methods index `trees[0]`, so an
    // empty vector is an invariant violation (train enforces n_trees > 0).
    trees: Vec<Model>,
    // Whether an out-of-bag score was requested at training time.
    compute_oob: bool,
    // Cached OOB score (accuracy or R^2 depending on `task`), if computed.
    oob_score: Option<f64>,
    // Resolved per-split feature budget used during training.
    max_features: usize,
    // The user-supplied training seed, if any (None means the default
    // base seed was used).
    seed: Option<u64>,
    // Feature count of the training table, recorded for later validation.
    num_features: usize,
    // Per-feature preprocessing captured from the training table.
    feature_preprocessing: Vec<FeaturePreprocessing>,
}
36
/// A freshly trained tree paired with the rows its bootstrap sample left out.
struct TrainedTree {
    model: Model,
    // Indices into the training table that were out-of-bag for this tree;
    // used to accumulate the forest's OOB score.
    oob_rows: Vec<usize>,
}
41
/// Row-remapping view over a table: row `i` of the view is row
/// `row_indices[i]` of `base`. Duplicates are allowed, which is what makes
/// it usable for bootstrap samples (and for OOB subsets) without copying data.
struct SampledTable<'a> {
    base: &'a dyn TableAccess,
    row_indices: Vec<usize>,
}
46
/// Pass-through view that reports zero canary features and excludes the
/// base table's canaries from `binned_feature_count`.
/// NOTE(review): presumably this keeps synthetic canary columns out of the
/// per-split feature budget during forest training — confirm against
/// `TableAccess::canaries` semantics.
struct NoCanaryTable<'a> {
    base: &'a dyn TableAccess,
}
50
impl RandomForest {
    /// Raw constructor assembling a forest from pre-built parts.
    ///
    /// Performs no validation: callers must uphold the invariants
    /// (notably a non-empty `trees` vector, since several methods index
    /// `trees[0]`). `train` is the normal way to build a forest.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        task: Task,
        criterion: Criterion,
        tree_type: TreeType,
        trees: Vec<Model>,
        compute_oob: bool,
        oob_score: Option<f64>,
        max_features: usize,
        seed: Option<u64>,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
    ) -> Self {
        Self {
            task,
            criterion,
            tree_type,
            trees,
            compute_oob,
            oob_score,
            max_features,
            seed,
            num_features,
            feature_preprocessing,
        }
    }

    /// Train a random forest on `train_set`.
    ///
    /// Each tree is fit on a bootstrap row sample with a per-split feature
    /// subset of `max_features` candidates. Trees may be trained in
    /// parallel (rayon) across the ensemble while each individual tree
    /// trains sequentially. Optionally computes an out-of-bag score.
    ///
    /// # Errors
    /// * `TrainError::InvalidTreeCount` when `config.n_trees == 0`.
    /// * `TrainError::InvalidMaxFeatures` when `MaxFeatures::Count(0)`.
    /// * Any error from training an individual tree is propagated.
    pub(crate) fn train(
        train_set: &dyn TableAccess,
        config: training::RandomForestConfig,
        criterion: Criterion,
        parallelism: Parallelism,
    ) -> Result<Self, TrainError> {
        let n_trees = config.n_trees;
        if n_trees == 0 {
            return Err(TrainError::InvalidTreeCount(n_trees));
        }
        if matches!(config.max_features, MaxFeatures::Count(0)) {
            return Err(TrainError::InvalidMaxFeatures(0));
        }

        // Hide the base table's canary columns so they are not counted when
        // resolving the per-split feature budget below (NoCanaryTable
        // subtracts them from binned_feature_count).
        let train_set = NoCanaryTable::new(train_set);
        let sampler = BootstrapSampler::new(train_set.n_rows());
        let feature_preprocessing = capture_feature_preprocessing(&train_set);
        let max_features = config
            .max_features
            .resolve(config.task, train_set.binned_feature_count());
        // Fixed fallback seed keeps unseeded runs deterministic.
        let base_seed = config.seed.unwrap_or(0x0005_EEDF_0E57_u64);
        let tree_parallelism = Parallelism {
            thread_count: parallelism.thread_count,
        };
        // Parallelism is spent across trees, not within a tree.
        let per_tree_parallelism = Parallelism::sequential();
        let train_tree = |tree_index: usize| -> Result<TrainedTree, TrainError> {
            // Per-tree seed is derived from the base seed and the tree's
            // index, so results do not depend on (parallel) execution order.
            let tree_seed = mix_seed(base_seed, tree_index as u64);
            let (sampled_rows, oob_rows) = sampler.sample_with_oob(tree_seed);
            let sampled_table = SampledTable::new(&train_set, sampled_rows);
            let model = training::train_single_model_with_feature_subset(
                &sampled_table,
                training::SingleModelFeatureSubsetConfig {
                    base: training::SingleModelConfig {
                        task: config.task,
                        tree_type: config.tree_type,
                        criterion,
                        parallelism: per_tree_parallelism,
                        max_depth: config.max_depth,
                        min_samples_split: config.min_samples_split,
                        min_samples_leaf: config.min_samples_leaf,
                        missing_value_strategies: config.missing_value_strategies.clone(),
                    },
                    max_features: Some(max_features),
                    random_seed: tree_seed,
                },
            )?;
            Ok(TrainedTree { model, oob_rows })
        };
        // Same closure either way; rayon only changes the scheduling.
        // collect::<Result<..>> short-circuits on the first tree error.
        let trained_trees = if tree_parallelism.enabled() {
            (0..n_trees)
                .into_par_iter()
                .map(train_tree)
                .collect::<Result<Vec<_>, _>>()?
        } else {
            (0..n_trees)
                .map(train_tree)
                .collect::<Result<Vec<_>, _>>()?
        };
        let oob_score = if config.compute_oob {
            compute_oob_score(config.task, &trained_trees, &train_set)
        } else {
            None
        };
        let trees = trained_trees.into_iter().map(|tree| tree.model).collect();

        Ok(Self::new(
            config.task,
            criterion,
            config.tree_type,
            trees,
            config.compute_oob,
            oob_score,
            max_features,
            config.seed,
            train_set.n_features(),
            feature_preprocessing,
        ))
    }

    /// Predict one value per row: mean prediction for regression,
    /// majority (highest mean probability) class label for classification.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        match self.task {
            Task::Regression => self.predict_regression_table(table),
            Task::Classification => self.predict_classification_table(table),
        }
    }

    /// Average the per-tree class-probability rows over all trees.
    ///
    /// # Errors
    /// Fails for non-classification forests, or when a member tree cannot
    /// produce probabilities.
    ///
    /// # Panics
    /// Indexes `trees[0]`, so it panics on a forest constructed with an
    /// empty tree list (`train` guarantees at least one tree).
    pub fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        if self.task != Task::Classification {
            return Err(PredictError::ProbabilityPredictionRequiresClassification);
        }

        // Seed the accumulator with the first tree, then add the rest.
        let mut totals = self.trees[0].predict_proba_table(table)?;
        for tree in &self.trees[1..] {
            let probs = tree.predict_proba_table(table)?;
            for (row_totals, row_probs) in totals.iter_mut().zip(probs.iter()) {
                for (total, prob) in row_totals.iter_mut().zip(row_probs.iter()) {
                    *total += *prob;
                }
            }
        }

        // Turn per-class sums into means.
        let tree_count = self.trees.len() as f64;
        for row in &mut totals {
            for value in row {
                *value /= tree_count;
            }
        }

        Ok(totals)
    }

    /// The task this forest was trained for.
    pub fn task(&self) -> Task {
        self.task
    }

    /// The split criterion used during training.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }

    /// The tree type of the member models.
    pub fn tree_type(&self) -> TreeType {
        self.tree_type
    }

    /// The trained member models, in training order.
    pub fn trees(&self) -> &[Model] {
        &self.trees
    }

    /// Feature count of the table the forest was trained on.
    pub fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-feature preprocessing captured at training time.
    pub fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// Training metadata for the ensemble.
    ///
    /// Starts from the first tree's metadata, then overwrites the
    /// ensemble-level fields and clears boosting-only ones
    /// (learning rate, gradient fractions). Panics if `trees` is empty.
    pub fn training_metadata(&self) -> TrainingMetadata {
        let mut metadata = self.trees[0].training_metadata();
        metadata.algorithm = "rf".to_string();
        metadata.n_trees = Some(self.trees.len());
        metadata.max_features = Some(self.max_features);
        metadata.seed = self.seed;
        metadata.compute_oob = self.compute_oob;
        metadata.oob_score = self.oob_score;
        metadata.learning_rate = None;
        metadata.bootstrap = None;
        metadata.top_gradient_fraction = None;
        metadata.other_gradient_fraction = None;
        metadata
    }

    /// Class labels for a classification forest (taken from the first
    /// tree), `None` for regression.
    pub fn class_labels(&self) -> Option<Vec<f64>> {
        match self.task {
            Task::Classification => self.trees[0].class_labels(),
            Task::Regression => None,
        }
    }

    /// The out-of-bag score computed at training time, if any.
    pub fn oob_score(&self) -> Option<f64> {
        self.oob_score
    }

    /// Per-row mean of the member trees' predictions.
    fn predict_regression_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let mut totals = self.trees[0].predict_table(table);
        for tree in &self.trees[1..] {
            let preds = tree.predict_table(table);
            for (total, pred) in totals.iter_mut().zip(preds.iter()) {
                *total += *pred;
            }
        }

        let tree_count = self.trees.len() as f64;
        for value in &mut totals {
            *value /= tree_count;
        }

        totals
    }

    /// Argmax of the averaged class probabilities, mapped to class labels.
    fn predict_classification_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let probabilities = self
            .predict_proba_table(table)
            .expect("classification forest supports probabilities");
        let class_labels = self
            .class_labels()
            .expect("classification forest stores class labels");

        probabilities
            .into_iter()
            .map(|row| {
                // Ties resolve to the lowest class index: on equal
                // probabilities the reversed index comparison reports
                // Greater for the earlier element, so max_by keeps it.
                let (best_index, _) = row
                    .iter()
                    .copied()
                    .enumerate()
                    .max_by(|(left_index, left), (right_index, right)| {
                        left.total_cmp(right)
                            .then_with(|| right_index.cmp(left_index))
                    })
                    .expect("classification probability row is non-empty");
                class_labels[best_index]
            })
            .collect()
    }
}
293
294fn compute_oob_score(
295 task: Task,
296 trained_trees: &[TrainedTree],
297 train_set: &dyn TableAccess,
298) -> Option<f64> {
299 match task {
300 Task::Classification => compute_classification_oob_score(trained_trees, train_set),
301 Task::Regression => compute_regression_oob_score(trained_trees, train_set),
302 }
303}
304
/// Out-of-bag accuracy for a classification forest.
///
/// Each row is scored only by trees whose bootstrap sample excluded it:
/// their probability rows are summed per class, the argmax class is
/// compared against the row's target, and the fraction of correct rows is
/// returned. Returns `None` when there are no trees, the first tree has no
/// class labels, or no row was out-of-bag for any tree.
fn compute_classification_oob_score(
    trained_trees: &[TrainedTree],
    train_set: &dyn TableAccess,
) -> Option<f64> {
    // Class ordering comes from the first tree; the other trees are assumed
    // to share it since they train on the same targets — TODO confirm.
    let class_labels = trained_trees.first()?.model.class_labels()?;
    let mut totals = vec![vec![0.0; class_labels.len()]; train_set.n_rows()];
    let mut counts = vec![0usize; train_set.n_rows()];

    for tree in trained_trees {
        if tree.oob_rows.is_empty() {
            continue;
        }
        // View restricted to this tree's held-out rows; predictions come
        // back in the same order as `oob_rows`.
        let oob_table = SampledTable::new(train_set, tree.oob_rows.clone());
        let probabilities = tree
            .model
            .predict_proba_table(&oob_table)
            .expect("classification tree supports predict_proba");
        for (&row_index, row_probs) in tree.oob_rows.iter().zip(probabilities.iter()) {
            for (total, prob) in totals[row_index].iter_mut().zip(row_probs.iter()) {
                *total += *prob;
            }
            counts[row_index] += 1;
        }
    }

    let mut correct = 0usize;
    let mut covered = 0usize;
    for row_index in 0..train_set.n_rows() {
        if counts[row_index] == 0 {
            // Row appeared in every bootstrap sample: no OOB estimate.
            continue;
        }
        covered += 1;
        // Argmax over summed probabilities; ties resolve to the lowest
        // class index (the reversed index comparison keeps the earlier
        // element when values compare equal).
        let predicted = totals[row_index]
            .iter()
            .copied()
            .enumerate()
            .max_by(|(li, lv), (ri, rv)| lv.total_cmp(rv).then_with(|| ri.cmp(li)))
            .map(|(index, _)| class_labels[index])
            .expect("classification probability row is non-empty");
        // NOTE(review): exact total_cmp equality assumes targets reproduce
        // the stored labels bit-for-bit — holds when labels were captured
        // from these same targets; confirm for externally supplied labels.
        if predicted
            .total_cmp(&train_set.target_value(row_index))
            .is_eq()
        {
            correct += 1;
        }
    }

    (covered > 0).then_some(correct as f64 / covered as f64)
}
354
355fn compute_regression_oob_score(
356 trained_trees: &[TrainedTree],
357 train_set: &dyn TableAccess,
358) -> Option<f64> {
359 let mut totals = vec![0.0; train_set.n_rows()];
360 let mut counts = vec![0usize; train_set.n_rows()];
361
362 for tree in trained_trees {
363 if tree.oob_rows.is_empty() {
364 continue;
365 }
366 let oob_table = SampledTable::new(train_set, tree.oob_rows.clone());
367 let predictions = tree.model.predict_table(&oob_table);
368 for (&row_index, prediction) in tree.oob_rows.iter().zip(predictions.iter().copied()) {
369 totals[row_index] += prediction;
370 counts[row_index] += 1;
371 }
372 }
373
374 let covered_rows: Vec<usize> = counts
375 .iter()
376 .enumerate()
377 .filter_map(|(row_index, count)| (*count > 0).then_some(row_index))
378 .collect();
379 if covered_rows.is_empty() {
380 return None;
381 }
382
383 let mean_target = covered_rows
384 .iter()
385 .map(|row_index| train_set.target_value(*row_index))
386 .sum::<f64>()
387 / covered_rows.len() as f64;
388 let mut residual_sum = 0.0;
389 let mut total_sum = 0.0;
390 for row_index in covered_rows {
391 let actual = train_set.target_value(row_index);
392 let prediction = totals[row_index] / counts[row_index] as f64;
393 residual_sum += (actual - prediction).powi(2);
394 total_sum += (actual - mean_target).powi(2);
395 }
396 if total_sum == 0.0 {
397 return None;
398 }
399 Some(1.0 - residual_sum / total_sum)
400}
401
impl<'a> SampledTable<'a> {
    /// Wrap `base` so that only `row_indices` are visible, in that order
    /// (duplicates allowed — bootstrap samples rely on this).
    fn new(base: &'a dyn TableAccess, row_indices: Vec<usize>) -> Self {
        Self { base, row_indices }
    }

    /// Map a view-local row index to the base table's row index.
    /// Panics if `row_index >= row_indices.len()`.
    fn resolve_row(&self, row_index: usize) -> usize {
        self.row_indices[row_index]
    }
}
411
impl<'a> NoCanaryTable<'a> {
    /// Wrap `base` in a view that reports zero canary features.
    fn new(base: &'a dyn TableAccess) -> Self {
        Self { base }
    }
}
417
// `TableAccess` for the row-remapping view: every row-addressed method
// translates the view row through `resolve_row` first; row-independent
// metadata queries delegate straight to the base table.
impl TableAccess for SampledTable<'_> {
    fn n_rows(&self) -> usize {
        // The view has as many rows as selected indices, not as the base.
        self.row_indices.len()
    }

    fn n_features(&self) -> usize {
        self.base.n_features()
    }

    fn canaries(&self) -> usize {
        self.base.canaries()
    }

    fn numeric_bin_cap(&self) -> usize {
        self.base.numeric_bin_cap()
    }

    fn binned_feature_count(&self) -> usize {
        self.base.binned_feature_count()
    }

    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        self.base
            .feature_value(feature_index, self.resolve_row(row_index))
    }

    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
        self.base
            .is_missing(feature_index, self.resolve_row(row_index))
    }

    fn is_binary_feature(&self, index: usize) -> bool {
        self.base.is_binary_feature(index)
    }

    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        self.base
            .binned_value(feature_index, self.resolve_row(row_index))
    }

    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        self.base
            .binned_boolean_value(feature_index, self.resolve_row(row_index))
    }

    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
        self.base.binned_column_kind(index)
    }

    fn is_binary_binned_feature(&self, index: usize) -> bool {
        self.base.is_binary_binned_feature(index)
    }

    fn target_value(&self, row_index: usize) -> f64 {
        self.base.target_value(self.resolve_row(row_index))
    }
}
475
// `TableAccess` for the canary-hiding view: a pure pass-through except for
// the two canary-related queries.
impl TableAccess for NoCanaryTable<'_> {
    fn n_rows(&self) -> usize {
        self.base.n_rows()
    }

    fn n_features(&self) -> usize {
        self.base.n_features()
    }

    fn canaries(&self) -> usize {
        // The whole point of this wrapper: downstream code sees no canaries.
        0
    }

    fn numeric_bin_cap(&self) -> usize {
        self.base.numeric_bin_cap()
    }

    fn binned_feature_count(&self) -> usize {
        // Exclude the base table's canary columns from the binned count.
        // NOTE(review): assumes `canaries() <= binned_feature_count()` on
        // the base table, otherwise this usize subtraction underflows —
        // confirm the trait contract guarantees it.
        self.base.binned_feature_count() - self.base.canaries()
    }

    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        self.base.feature_value(feature_index, row_index)
    }

    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
        self.base.is_missing(feature_index, row_index)
    }

    fn is_binary_feature(&self, index: usize) -> bool {
        self.base.is_binary_feature(index)
    }

    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        self.base.binned_value(feature_index, row_index)
    }

    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        self.base.binned_boolean_value(feature_index, row_index)
    }

    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
        self.base.binned_column_kind(index)
    }

    fn is_binary_binned_feature(&self, index: usize) -> bool {
        self.base.is_binary_binned_feature(index)
    }

    fn target_value(&self, row_index: usize) -> f64 {
        self.base.target_value(row_index)
    }
}