1use crate::bootstrap::BootstrapSampler;
9use crate::ir::TrainingMetadata;
10use crate::tree::shared::mix_seed;
11use crate::{
12 Criterion, FeaturePreprocessing, MaxFeatures, Model, Parallelism, PredictError, Task,
13 TrainConfig, TrainError, TreeType, capture_feature_preprocessing, training,
14};
15use forestfire_data::TableAccess;
16use rayon::prelude::*;
17
/// A random-forest ensemble of decision trees trained on bootstrap samples.
#[derive(Debug, Clone)]
pub struct RandomForest {
    /// Whether the forest solves a regression or classification task.
    task: Task,
    /// Split-quality criterion the trees were trained with.
    criterion: Criterion,
    /// Kind of tree grown for each ensemble member.
    tree_type: TreeType,
    /// The trained ensemble members; predictions are averaged over these.
    trees: Vec<Model>,
    /// Whether an out-of-bag score was requested at training time.
    compute_oob: bool,
    /// Out-of-bag score, if computed (accuracy for classification, R^2 for
    /// regression — see `compute_oob_score`).
    oob_score: Option<f64>,
    /// Number of candidate features considered at each split.
    max_features: usize,
    /// Seed supplied by the caller, if any; training falls back to a fixed
    /// default when absent.
    seed: Option<u64>,
    /// Number of features in the training table.
    num_features: usize,
    /// Per-feature preprocessing captured from the training table.
    feature_preprocessing: Vec<FeaturePreprocessing>,
}
36
/// A freshly trained tree paired with the rows left out of its bootstrap
/// sample, for out-of-bag scoring.
struct TrainedTree {
    model: Model,
    /// Row indices of the training table this tree never saw.
    oob_rows: Vec<usize>,
}
41
/// A row-subset view over a base table: row `i` of this view maps to
/// `row_indices[i]` of `base`. Used for bootstrap samples and OOB rows;
/// duplicate indices are allowed (bootstrap sampling is with replacement).
struct SampledTable<'a> {
    base: &'a dyn TableAccess,
    row_indices: Vec<usize>,
}
46
/// A view over a base table that reports zero canary columns and excludes
/// them from the binned feature count.
struct NoCanaryTable<'a> {
    base: &'a dyn TableAccess,
}
50
impl RandomForest {
    /// Assembles a forest from pre-built parts.
    ///
    /// Callers are expected to supply at least one tree: accessors such as
    /// `training_metadata` and `class_labels`, and the prediction methods,
    /// index `trees[0]` and will panic on an empty ensemble.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        task: Task,
        criterion: Criterion,
        tree_type: TreeType,
        trees: Vec<Model>,
        compute_oob: bool,
        oob_score: Option<f64>,
        max_features: usize,
        seed: Option<u64>,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
    ) -> Self {
        Self {
            task,
            criterion,
            tree_type,
            trees,
            compute_oob,
            oob_score,
            max_features,
            seed,
            num_features,
            feature_preprocessing,
        }
    }

    /// Trains `config.n_trees` trees (default 1000), each on its own
    /// bootstrap sample of `train_set`, and optionally computes an
    /// out-of-bag score.
    ///
    /// Each tree's seed is derived deterministically from the base seed and
    /// the tree index via `mix_seed`, so the result does not depend on
    /// whether trees are trained in parallel or sequentially.
    ///
    /// # Errors
    ///
    /// Returns `TrainError::InvalidTreeCount` when the tree count is 0 and
    /// `TrainError::InvalidMaxFeatures` when `MaxFeatures::Count(0)` is
    /// requested.
    pub(crate) fn train(
        train_set: &dyn TableAccess,
        config: TrainConfig,
        criterion: Criterion,
        parallelism: Parallelism,
    ) -> Result<Self, TrainError> {
        let n_trees = config.n_trees.unwrap_or(1000);
        if n_trees == 0 {
            return Err(TrainError::InvalidTreeCount(n_trees));
        }
        if matches!(config.max_features, MaxFeatures::Count(0)) {
            return Err(TrainError::InvalidMaxFeatures(0));
        }

        // Hide canary columns from per-tree training and from max_features
        // resolution (NoCanaryTable subtracts them from the binned count).
        let train_set = NoCanaryTable::new(train_set);
        let sampler = BootstrapSampler::new(train_set.n_rows());
        let feature_preprocessing = capture_feature_preprocessing(&train_set);
        let max_features = config
            .max_features
            .resolve(config.task, train_set.binned_feature_count());
        let base_seed = config.seed.unwrap_or(0x0005_EEDF_0E57_u64);
        // The parallelism budget is spent across trees; each individual tree
        // trains sequentially (presumably so a single tree's result is
        // independent of thread count — confirm against `training`'s contract).
        let tree_parallelism = Parallelism {
            thread_count: parallelism.thread_count,
        };
        let per_tree_parallelism = Parallelism::sequential();
        let train_tree = |tree_index: usize| -> Result<TrainedTree, TrainError> {
            // The seed is a pure function of (base_seed, tree_index), so the
            // same tree index always draws the same bootstrap sample.
            let tree_seed = mix_seed(base_seed, tree_index as u64);
            let (sampled_rows, oob_rows) = sampler.sample_with_oob(tree_seed);
            let sampled_table = SampledTable::new(&train_set, sampled_rows);
            let model = training::train_single_model_with_feature_subset(
                &sampled_table,
                training::SingleModelFeatureSubsetConfig {
                    base: training::SingleModelConfig {
                        task: config.task,
                        tree_type: config.tree_type,
                        criterion,
                        parallelism: per_tree_parallelism,
                        max_depth: config.max_depth.unwrap_or(8),
                        min_samples_split: config.min_samples_split.unwrap_or(2),
                        min_samples_leaf: config.min_samples_leaf.unwrap_or(1),
                    },
                    max_features: Some(max_features),
                    random_seed: tree_seed,
                },
            )?;
            Ok(TrainedTree { model, oob_rows })
        };
        // Same closure either way; rayon only changes scheduling, not seeds.
        let trained_trees = if tree_parallelism.enabled() {
            (0..n_trees)
                .into_par_iter()
                .map(train_tree)
                .collect::<Result<Vec<_>, _>>()?
        } else {
            (0..n_trees)
                .map(train_tree)
                .collect::<Result<Vec<_>, _>>()?
        };
        let oob_score = if config.compute_oob {
            compute_oob_score(config.task, &trained_trees, &train_set)
        } else {
            None
        };
        let trees = trained_trees.into_iter().map(|tree| tree.model).collect();

        Ok(Self::new(
            config.task,
            criterion,
            config.tree_type,
            trees,
            config.compute_oob,
            oob_score,
            max_features,
            config.seed,
            train_set.n_features(),
            feature_preprocessing,
        ))
    }

    /// Predicts one value per row of `table`: the averaged tree output for
    /// regression, or the highest-probability class label for classification.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        match self.task {
            Task::Regression => self.predict_regression_table(table),
            Task::Classification => self.predict_classification_table(table),
        }
    }

    /// Returns per-row class probabilities, averaged over all trees.
    ///
    /// # Errors
    ///
    /// Fails unless the forest was trained for classification.
    ///
    /// # Panics
    ///
    /// Panics if a classification forest contains no trees (indexes
    /// `trees[0]`).
    pub fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        if self.task != Task::Classification {
            return Err(PredictError::ProbabilityPredictionRequiresClassification);
        }

        // Sum every tree's per-class probabilities into the first tree's
        // output, then divide by the tree count for the ensemble average.
        let mut totals = self.trees[0].predict_proba_table(table)?;
        for tree in &self.trees[1..] {
            let probs = tree.predict_proba_table(table)?;
            for (row_totals, row_probs) in totals.iter_mut().zip(probs.iter()) {
                for (total, prob) in row_totals.iter_mut().zip(row_probs.iter()) {
                    *total += *prob;
                }
            }
        }

        let tree_count = self.trees.len() as f64;
        for row in &mut totals {
            for value in row {
                *value /= tree_count;
            }
        }

        Ok(totals)
    }

    /// The task this forest was trained for.
    pub fn task(&self) -> Task {
        self.task
    }

    /// The split criterion the trees were trained with.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }

    /// The kind of tree grown for each ensemble member.
    pub fn tree_type(&self) -> TreeType {
        self.tree_type
    }

    /// The trained ensemble members.
    pub fn trees(&self) -> &[Model] {
        &self.trees
    }

    /// Number of features in the training table.
    pub fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-feature preprocessing captured at training time.
    pub fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// Training metadata for the whole forest.
    ///
    /// Starts from the first tree's metadata, overwrites the ensemble-level
    /// fields, and clears fields that apply only to boosted ensembles.
    ///
    /// # Panics
    ///
    /// Panics if the forest contains no trees.
    pub fn training_metadata(&self) -> TrainingMetadata {
        let mut metadata = self.trees[0].training_metadata();
        metadata.algorithm = "rf".to_string();
        metadata.n_trees = Some(self.trees.len());
        metadata.max_features = Some(self.max_features);
        metadata.seed = self.seed;
        metadata.compute_oob = self.compute_oob;
        metadata.oob_score = self.oob_score;
        // Boosting-only fields do not apply to a random forest.
        metadata.learning_rate = None;
        metadata.bootstrap = None;
        metadata.top_gradient_fraction = None;
        metadata.other_gradient_fraction = None;
        metadata
    }

    /// Class labels for a classification forest (taken from the first tree);
    /// `None` for regression.
    pub fn class_labels(&self) -> Option<Vec<f64>> {
        match self.task {
            Task::Classification => self.trees[0].class_labels(),
            Task::Regression => None,
        }
    }

    /// The out-of-bag score, if it was computed during training.
    pub fn oob_score(&self) -> Option<f64> {
        self.oob_score
    }

    /// Averages the trees' regression outputs for every row of `table`.
    fn predict_regression_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let mut totals = self.trees[0].predict_table(table);
        for tree in &self.trees[1..] {
            let preds = tree.predict_table(table);
            for (total, pred) in totals.iter_mut().zip(preds.iter()) {
                *total += *pred;
            }
        }

        let tree_count = self.trees.len() as f64;
        for value in &mut totals {
            *value /= tree_count;
        }

        totals
    }

    /// Picks the class with the highest averaged probability for each row.
    fn predict_classification_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let probabilities = self
            .predict_proba_table(table)
            .expect("classification forest supports probabilities");
        let class_labels = self
            .class_labels()
            .expect("classification forest stores class labels");

        probabilities
            .into_iter()
            .map(|row| {
                // On equal probabilities the reversed index comparison makes
                // `max_by` keep the earlier (lower) class index.
                let (best_index, _) = row
                    .iter()
                    .copied()
                    .enumerate()
                    .max_by(|(left_index, left), (right_index, right)| {
                        left.total_cmp(right)
                            .then_with(|| right_index.cmp(left_index))
                    })
                    .expect("classification probability row is non-empty");
                class_labels[best_index]
            })
            .collect()
    }
}
292
293fn compute_oob_score(
294 task: Task,
295 trained_trees: &[TrainedTree],
296 train_set: &dyn TableAccess,
297) -> Option<f64> {
298 match task {
299 Task::Classification => compute_classification_oob_score(trained_trees, train_set),
300 Task::Regression => compute_regression_oob_score(trained_trees, train_set),
301 }
302}
303
304fn compute_classification_oob_score(
305 trained_trees: &[TrainedTree],
306 train_set: &dyn TableAccess,
307) -> Option<f64> {
308 let class_labels = trained_trees.first()?.model.class_labels()?;
309 let mut totals = vec![vec![0.0; class_labels.len()]; train_set.n_rows()];
310 let mut counts = vec![0usize; train_set.n_rows()];
311
312 for tree in trained_trees {
313 if tree.oob_rows.is_empty() {
314 continue;
315 }
316 let oob_table = SampledTable::new(train_set, tree.oob_rows.clone());
317 let probabilities = tree
318 .model
319 .predict_proba_table(&oob_table)
320 .expect("classification tree supports predict_proba");
321 for (&row_index, row_probs) in tree.oob_rows.iter().zip(probabilities.iter()) {
322 for (total, prob) in totals[row_index].iter_mut().zip(row_probs.iter()) {
323 *total += *prob;
324 }
325 counts[row_index] += 1;
326 }
327 }
328
329 let mut correct = 0usize;
330 let mut covered = 0usize;
331 for row_index in 0..train_set.n_rows() {
332 if counts[row_index] == 0 {
333 continue;
334 }
335 covered += 1;
336 let predicted = totals[row_index]
337 .iter()
338 .copied()
339 .enumerate()
340 .max_by(|(li, lv), (ri, rv)| lv.total_cmp(rv).then_with(|| ri.cmp(li)))
341 .map(|(index, _)| class_labels[index])
342 .expect("classification probability row is non-empty");
343 if predicted
344 .total_cmp(&train_set.target_value(row_index))
345 .is_eq()
346 {
347 correct += 1;
348 }
349 }
350
351 (covered > 0).then_some(correct as f64 / covered as f64)
352}
353
354fn compute_regression_oob_score(
355 trained_trees: &[TrainedTree],
356 train_set: &dyn TableAccess,
357) -> Option<f64> {
358 let mut totals = vec![0.0; train_set.n_rows()];
359 let mut counts = vec![0usize; train_set.n_rows()];
360
361 for tree in trained_trees {
362 if tree.oob_rows.is_empty() {
363 continue;
364 }
365 let oob_table = SampledTable::new(train_set, tree.oob_rows.clone());
366 let predictions = tree.model.predict_table(&oob_table);
367 for (&row_index, prediction) in tree.oob_rows.iter().zip(predictions.iter().copied()) {
368 totals[row_index] += prediction;
369 counts[row_index] += 1;
370 }
371 }
372
373 let covered_rows: Vec<usize> = counts
374 .iter()
375 .enumerate()
376 .filter_map(|(row_index, count)| (*count > 0).then_some(row_index))
377 .collect();
378 if covered_rows.is_empty() {
379 return None;
380 }
381
382 let mean_target = covered_rows
383 .iter()
384 .map(|row_index| train_set.target_value(*row_index))
385 .sum::<f64>()
386 / covered_rows.len() as f64;
387 let mut residual_sum = 0.0;
388 let mut total_sum = 0.0;
389 for row_index in covered_rows {
390 let actual = train_set.target_value(row_index);
391 let prediction = totals[row_index] / counts[row_index] as f64;
392 residual_sum += (actual - prediction).powi(2);
393 total_sum += (actual - mean_target).powi(2);
394 }
395 if total_sum == 0.0 {
396 return None;
397 }
398 Some(1.0 - residual_sum / total_sum)
399}
400
401impl<'a> SampledTable<'a> {
402 fn new(base: &'a dyn TableAccess, row_indices: Vec<usize>) -> Self {
403 Self { base, row_indices }
404 }
405
406 fn resolve_row(&self, row_index: usize) -> usize {
407 self.row_indices[row_index]
408 }
409}
410
411impl<'a> NoCanaryTable<'a> {
412 fn new(base: &'a dyn TableAccess) -> Self {
413 Self { base }
414 }
415}
416
417impl TableAccess for SampledTable<'_> {
418 fn n_rows(&self) -> usize {
419 self.row_indices.len()
420 }
421
422 fn n_features(&self) -> usize {
423 self.base.n_features()
424 }
425
426 fn canaries(&self) -> usize {
427 self.base.canaries()
428 }
429
430 fn numeric_bin_cap(&self) -> usize {
431 self.base.numeric_bin_cap()
432 }
433
434 fn binned_feature_count(&self) -> usize {
435 self.base.binned_feature_count()
436 }
437
438 fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
439 self.base
440 .feature_value(feature_index, self.resolve_row(row_index))
441 }
442
443 fn is_binary_feature(&self, index: usize) -> bool {
444 self.base.is_binary_feature(index)
445 }
446
447 fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
448 self.base
449 .binned_value(feature_index, self.resolve_row(row_index))
450 }
451
452 fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
453 self.base
454 .binned_boolean_value(feature_index, self.resolve_row(row_index))
455 }
456
457 fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
458 self.base.binned_column_kind(index)
459 }
460
461 fn is_binary_binned_feature(&self, index: usize) -> bool {
462 self.base.is_binary_binned_feature(index)
463 }
464
465 fn target_value(&self, row_index: usize) -> f64 {
466 self.base.target_value(self.resolve_row(row_index))
467 }
468}
469
470impl TableAccess for NoCanaryTable<'_> {
471 fn n_rows(&self) -> usize {
472 self.base.n_rows()
473 }
474
475 fn n_features(&self) -> usize {
476 self.base.n_features()
477 }
478
479 fn canaries(&self) -> usize {
480 0
481 }
482
483 fn numeric_bin_cap(&self) -> usize {
484 self.base.numeric_bin_cap()
485 }
486
487 fn binned_feature_count(&self) -> usize {
488 self.base.binned_feature_count() - self.base.canaries()
489 }
490
491 fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
492 self.base.feature_value(feature_index, row_index)
493 }
494
495 fn is_binary_feature(&self, index: usize) -> bool {
496 self.base.is_binary_feature(index)
497 }
498
499 fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
500 self.base.binned_value(feature_index, row_index)
501 }
502
503 fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
504 self.base.binned_boolean_value(feature_index, row_index)
505 }
506
507 fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
508 self.base.binned_column_kind(index)
509 }
510
511 fn is_binary_binned_feature(&self, index: usize) -> bool {
512 self.base.is_binary_binned_feature(index)
513 }
514
515 fn target_value(&self, row_index: usize) -> f64 {
516 self.base.target_value(row_index)
517 }
518}