1use crate::bootstrap::BootstrapSampler;
9use crate::ir::TrainingMetadata;
10use crate::tree::second_order::{
11 SecondOrderRegressionTreeError, SecondOrderRegressionTreeOptions,
12 train_cart_regressor_from_gradients_and_hessians_with_status,
13 train_oblivious_regressor_from_gradients_and_hessians_with_status,
14 train_randomized_regressor_from_gradients_and_hessians_with_status,
15};
16use crate::tree::shared::mix_seed;
17use crate::{
18 Criterion, FeaturePreprocessing, Model, Parallelism, PredictError, Task, TrainConfig, TreeType,
19 capture_feature_preprocessing,
20};
21use forestfire_data::TableAccess;
22use rand::SeedableRng;
23use rand::rngs::StdRng;
24use rand::seq::SliceRandom;
25
/// A gradient boosted tree ensemble trained with second-order (Newton-style)
/// leaf estimates. Supports regression and binary classification.
#[derive(Debug, Clone)]
pub struct GradientBoostedTrees {
    // Regression or (binary) classification.
    task: Task,
    // Weak-learner family used for every stage (cart / randomized / oblivious).
    tree_type: TreeType,
    // Stage trees, in training order.
    trees: Vec<Model>,
    // Per-stage weight applied at prediction time; `train` pushes the
    // learning rate for every accepted stage, so entries align with `trees`.
    tree_weights: Vec<f64>,
    // Initial raw score: target mean (regression) or prior log-odds
    // of the positive class (classification).
    base_score: f64,
    // Shrinkage applied to each stage's contribution.
    learning_rate: f64,
    // Whether each stage drew a bootstrap sample before gradient-focused sampling.
    bootstrap: bool,
    // Fraction of rows with the largest |gradient| that every stage keeps.
    top_gradient_fraction: f64,
    // Fraction of the remaining rows sampled uniformly (and re-weighted).
    other_gradient_fraction: f64,
    // Resolved per-split feature-subsampling budget.
    max_features: usize,
    // User-supplied RNG seed; `None` means the built-in default seed was used.
    seed: Option<u64>,
    // Number of raw features the model expects at prediction time.
    num_features: usize,
    // Per-feature preprocessing captured from the training table.
    feature_preprocessing: Vec<FeaturePreprocessing>,
    // Sorted original class labels for classification; `None` for regression.
    class_labels: Option<Vec<f64>>,
    // Number of canary features present in the training table.
    training_canaries: usize,
}
48
/// Errors produced while validating or running gradient boosted training.
#[derive(Debug)]
pub enum BoostingError {
    /// A target value was NaN or infinite at the given row.
    InvalidTargetValue { row: usize, value: f64 },
    /// Classification targets contained a class count other than two.
    UnsupportedClassificationClassCount(usize),
    /// `learning_rate` was non-finite or not strictly positive.
    /// (Also emitted for an empty training set — see `validate_boosting_parameters`.)
    InvalidLearningRate(f64),
    /// `top_gradient_fraction` fell outside the interval (0, 1].
    InvalidTopGradientFraction(f64),
    /// `other_gradient_fraction` fell outside [0, 1), or the two
    /// fractions summed to more than 1.
    InvalidOtherGradientFraction(f64),
    /// The underlying second-order tree trainer failed.
    SecondOrderTree(SecondOrderRegressionTreeError),
}
58
59impl std::fmt::Display for BoostingError {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 BoostingError::InvalidTargetValue { row, value } => write!(
63 f,
64 "Boosting targets must be finite values. Found {} at row {}.",
65 value, row
66 ),
67 BoostingError::UnsupportedClassificationClassCount(count) => write!(
68 f,
69 "Gradient boosting currently supports binary classification only. Found {} classes.",
70 count
71 ),
72 BoostingError::InvalidLearningRate(value) => write!(
73 f,
74 "learning_rate must be finite and greater than 0. Found {}.",
75 value
76 ),
77 BoostingError::InvalidTopGradientFraction(value) => write!(
78 f,
79 "top_gradient_fraction must be in the interval (0, 1]. Found {}.",
80 value
81 ),
82 BoostingError::InvalidOtherGradientFraction(value) => write!(
83 f,
84 "other_gradient_fraction must be in the interval [0, 1), and top_gradient_fraction + other_gradient_fraction must be at most 1. Found {}.",
85 value
86 ),
87 BoostingError::SecondOrderTree(err) => err.fmt(f),
88 }
89 }
90}
91
// Marker impl: the Debug derive and Display impl above satisfy the supertraits.
impl std::error::Error for BoostingError {}
93
/// A row-remapped (and possibly row-repeating) view over a base table.
///
/// View row `i` is backed by base-table row `row_indices[i]`.
struct SampledTable<'a> {
    base: &'a dyn TableAccess,
    row_indices: Vec<usize>,
}
98
impl GradientBoostedTrees {
    /// Assembles an ensemble from already-built parts.
    ///
    /// Performs no validation; used by `train` and presumably by model
    /// (de)serialization elsewhere in the crate — TODO confirm.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        task: Task,
        tree_type: TreeType,
        trees: Vec<Model>,
        tree_weights: Vec<f64>,
        base_score: f64,
        learning_rate: f64,
        bootstrap: bool,
        top_gradient_fraction: f64,
        other_gradient_fraction: f64,
        max_features: usize,
        seed: Option<u64>,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
        class_labels: Option<Vec<f64>>,
        training_canaries: usize,
    ) -> Self {
        Self {
            task,
            tree_type,
            trees,
            tree_weights,
            base_score,
            learning_rate,
            bootstrap,
            top_gradient_fraction,
            other_gradient_fraction,
            max_features,
            seed,
            num_features,
            feature_preprocessing,
            class_labels,
            training_canaries,
        }
    }

    /// Trains a boosted ensemble on `train_set`.
    ///
    /// Each stage fits a second-order regression tree to the gradients and
    /// hessians of the current raw predictions (squared error for
    /// regression, log-loss for binary classification), optionally on a
    /// bootstrap draw, then on a gradient-focused subsample. Training stops
    /// early if a stage's root split selects a canary feature.
    ///
    /// # Errors
    /// Returns a `BoostingError` for invalid hyperparameters, non-finite
    /// targets, a non-binary classification target set, or a tree-trainer
    /// failure.
    pub(crate) fn train(
        train_set: &dyn TableAccess,
        config: TrainConfig,
        parallelism: Parallelism,
    ) -> Result<Self, BoostingError> {
        // Hyperparameter defaults: 100 stages, 0.1 shrinkage, keep the top
        // 20% of gradients and sample 10% of the rest.
        let n_trees = config.n_trees.unwrap_or(100);
        let learning_rate = config.learning_rate.unwrap_or(0.1);
        let bootstrap = config.bootstrap;
        let top_gradient_fraction = config.top_gradient_fraction.unwrap_or(0.2);
        let other_gradient_fraction = config.other_gradient_fraction.unwrap_or(0.1);
        validate_boosting_parameters(
            train_set,
            learning_rate,
            top_gradient_fraction,
            other_gradient_fraction,
        )?;

        let max_features = config
            .max_features
            .resolve(config.task, train_set.binned_feature_count());
        // Fixed default seed when none is configured, so runs are reproducible.
        let base_seed = config.seed.unwrap_or(0xB005_7EED_u64);
        // `random_seed` is a placeholder; each stage overwrites it below.
        let tree_options = crate::RegressionTreeOptions {
            max_depth: config.max_depth.unwrap_or(8),
            min_samples_split: config.min_samples_split.unwrap_or(2),
            min_samples_leaf: config.min_samples_leaf.unwrap_or(1),
            max_features: Some(max_features),
            random_seed: 0,
        };
        // Second-order regularization settings are currently hard-coded.
        let tree_options = SecondOrderRegressionTreeOptions {
            tree_options,
            l2_regularization: 1.0,
            min_sum_hessian_in_leaf: 1e-3,
            min_gain_to_split: 0.0,
        };
        let feature_preprocessing = capture_feature_preprocessing(train_set);
        let sampler = BootstrapSampler::new(train_set.n_rows());

        // Initialize raw predictions with the task-appropriate base score:
        // the target mean for regression, the clamped prior log-odds for
        // binary classification.
        let (mut raw_predictions, class_labels, base_score) = match config.task {
            Task::Regression => {
                let targets = finite_targets(train_set)?;
                let base_score = targets.iter().sum::<f64>() / targets.len() as f64;
                (vec![base_score; train_set.n_rows()], None, base_score)
            }
            Task::Classification => {
                let (labels, encoded_targets) = binary_classification_targets(train_set)?;
                // Clamp so the log-odds stay finite for single-class rates.
                let positive_rate = (encoded_targets.iter().sum::<f64>()
                    / encoded_targets.len() as f64)
                    .clamp(1e-6, 1.0 - 1e-6);
                let base_score = (positive_rate / (1.0 - positive_rate)).ln();
                (
                    vec![base_score; train_set.n_rows()],
                    Some(labels),
                    base_score,
                )
            }
        };

        let mut trees = Vec::with_capacity(n_trees);
        let mut tree_weights = Vec::with_capacity(n_trees);
        // NOTE(review): targets are extracted a second time here (they were
        // already read while computing the base score above) — harmless but
        // redundant work over the table.
        let regression_targets = if config.task == Task::Regression {
            Some(finite_targets(train_set)?)
        } else {
            None
        };
        let classification_targets = if config.task == Task::Classification {
            Some(binary_classification_targets(train_set)?.1)
        } else {
            None
        };

        for tree_index in 0..n_trees {
            // Distinct deterministic seed per stage.
            let stage_seed = mix_seed(base_seed, tree_index as u64);
            // Gradients/hessians of the loss at the current raw predictions.
            let (gradients, hessians) = match config.task {
                Task::Regression => squared_error_gradients_and_hessians(
                    raw_predictions.as_slice(),
                    regression_targets
                        .as_ref()
                        .expect("regression targets exist for regression boosting"),
                ),
                Task::Classification => logistic_gradients_and_hessians(
                    raw_predictions.as_slice(),
                    classification_targets
                        .as_ref()
                        .expect("classification targets exist for classification boosting"),
                ),
            };

            // Candidate rows: bootstrap draw, or every row when disabled.
            let base_rows = if bootstrap {
                sampler.sample(stage_seed)
            } else {
                (0..train_set.n_rows()).collect()
            };
            // Gradient-focused (GOSS-style) subsample of the candidates;
            // the constant salts the shuffle seed away from the tree seed.
            let sampled_rows = gradient_focus_sample(
                &base_rows,
                &gradients,
                &hessians,
                top_gradient_fraction,
                other_gradient_fraction,
                mix_seed(stage_seed, 0x6011_5A11),
            );
            let sampled_table = SampledTable::new(train_set, sampled_rows.row_indices);
            // Per-stage copy of the shared options, re-seeded for this stage.
            let mut stage_tree_options = tree_options;
            stage_tree_options.tree_options.random_seed = stage_seed;
            let stage_result = match config.tree_type {
                TreeType::Cart => train_cart_regressor_from_gradients_and_hessians_with_status(
                    &sampled_table,
                    &sampled_rows.gradients,
                    &sampled_rows.hessians,
                    parallelism,
                    stage_tree_options,
                ),
                TreeType::Randomized => {
                    train_randomized_regressor_from_gradients_and_hessians_with_status(
                        &sampled_table,
                        &sampled_rows.gradients,
                        &sampled_rows.hessians,
                        parallelism,
                        stage_tree_options,
                    )
                }
                TreeType::Oblivious => {
                    train_oblivious_regressor_from_gradients_and_hessians_with_status(
                        &sampled_table,
                        &sampled_rows.gradients,
                        &sampled_rows.hessians,
                        parallelism,
                        stage_tree_options,
                    )
                }
                // Other tree types are rejected before boosting is invoked.
                _ => unreachable!("boosting tree type validated by training dispatch"),
            }
            .map_err(BoostingError::SecondOrderTree)?;

            // Early stop: a canary winning the root split signals that
            // further stages would fit noise; discard this stage too.
            if stage_result.root_canary_selected {
                break;
            }

            let stage_tree = stage_result.model;
            let stage_model = Model::DecisionTreeRegressor(stage_tree);
            // Predict over the FULL training set (not just the sampled rows)
            // so every row's raw prediction advances.
            let stage_predictions = stage_model.predict_table(train_set);
            for (raw_prediction, stage_prediction) in raw_predictions
                .iter_mut()
                .zip(stage_predictions.iter().copied())
            {
                *raw_prediction += learning_rate * stage_prediction;
            }
            tree_weights.push(learning_rate);
            trees.push(stage_model);
        }

        Ok(Self::new(
            config.task,
            config.tree_type,
            trees,
            tree_weights,
            base_score,
            learning_rate,
            bootstrap,
            top_gradient_fraction,
            other_gradient_fraction,
            max_features,
            config.seed,
            train_set.n_features(),
            feature_preprocessing,
            class_labels,
            train_set.canaries(),
        ))
    }

    /// Predicts one value per row: raw scores for regression, the decoded
    /// class label for classification.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        match self.task {
            Task::Regression => self.predict_regression_table(table),
            Task::Classification => self.predict_classification_table(table),
        }
    }

    /// Predicts `[p(negative), p(positive)]` per row via the logistic link.
    ///
    /// # Errors
    /// Fails for regression models, which have no probability semantics.
    pub fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        if self.task != Task::Classification {
            return Err(PredictError::ProbabilityPredictionRequiresClassification);
        }
        Ok(self
            .raw_scores(table)
            .into_iter()
            .map(|score| {
                let positive = sigmoid(score);
                vec![1.0 - positive, positive]
            })
            .collect())
    }

    pub fn task(&self) -> Task {
        self.task
    }

    /// Boosting always uses the second-order split criterion.
    pub fn criterion(&self) -> Criterion {
        Criterion::SecondOrder
    }

    pub fn tree_type(&self) -> TreeType {
        self.tree_type
    }

    pub fn trees(&self) -> &[Model] {
        &self.trees
    }

    pub fn tree_weights(&self) -> &[f64] {
        &self.tree_weights
    }

    pub fn base_score(&self) -> f64 {
        self.base_score
    }

    pub fn num_features(&self) -> usize {
        self.num_features
    }

    pub fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    pub fn class_labels(&self) -> Option<Vec<f64>> {
        self.class_labels.clone()
    }

    /// Summarizes the training configuration for reporting/serialization.
    /// Depth/split/leaf limits are read back from the first stage tree.
    pub fn training_metadata(&self) -> TrainingMetadata {
        TrainingMetadata {
            algorithm: "gbm".to_string(),
            task: match self.task {
                Task::Regression => "regression".to_string(),
                Task::Classification => "classification".to_string(),
            },
            tree_type: match self.tree_type {
                TreeType::Cart => "cart".to_string(),
                TreeType::Randomized => "randomized".to_string(),
                TreeType::Oblivious => "oblivious".to_string(),
                _ => unreachable!("boosting only supports cart/randomized/oblivious"),
            },
            criterion: "second_order".to_string(),
            canaries: self.training_canaries,
            compute_oob: false,
            max_depth: self.trees.first().and_then(Model::max_depth),
            min_samples_split: self.trees.first().and_then(Model::min_samples_split),
            min_samples_leaf: self.trees.first().and_then(Model::min_samples_leaf),
            n_trees: Some(self.trees.len()),
            max_features: Some(self.max_features),
            seed: self.seed,
            oob_score: None,
            class_labels: self.class_labels.clone(),
            learning_rate: Some(self.learning_rate),
            bootstrap: Some(self.bootstrap),
            top_gradient_fraction: Some(self.top_gradient_fraction),
            other_gradient_fraction: Some(self.other_gradient_fraction),
        }
    }

    /// Accumulates `base_score + Σ weight_i · tree_i(row)` for every row.
    fn raw_scores(&self, table: &dyn TableAccess) -> Vec<f64> {
        let mut scores = vec![self.base_score; table.n_rows()];
        for (tree, weight) in self.trees.iter().zip(self.tree_weights.iter().copied()) {
            let predictions = tree.predict_table(table);
            for (score, prediction) in scores.iter_mut().zip(predictions.iter().copied()) {
                *score += weight * prediction;
            }
        }
        scores
    }

    fn predict_regression_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        self.raw_scores(table)
    }

    /// Thresholds the sigmoid of each raw score at 0.5 and maps the result
    /// back to the stored original class labels.
    fn predict_classification_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let class_labels = self
            .class_labels
            .as_ref()
            .expect("classification boosting stores class labels");
        self.raw_scores(table)
            .into_iter()
            .map(|score| {
                if sigmoid(score) >= 0.5 {
                    class_labels[1]
                } else {
                    class_labels[0]
                }
            })
            .collect()
    }
}
439
/// Result of `gradient_focus_sample`: selected rows with their (possibly
/// re-weighted) gradients and hessians; the three vectors are index-aligned.
struct GradientFocusedSample {
    row_indices: Vec<usize>,
    gradients: Vec<f64>,
    hessians: Vec<f64>,
}
445
impl<'a> SampledTable<'a> {
    /// Wraps `base`, exposing only (and exactly) the rows named in `row_indices`.
    fn new(base: &'a dyn TableAccess, row_indices: Vec<usize>) -> Self {
        Self { base, row_indices }
    }

    /// Translates a view-row index into the underlying base-table row index.
    fn resolve_row(&self, row_index: usize) -> usize {
        self.row_indices[row_index]
    }
}
455
456impl TableAccess for SampledTable<'_> {
457 fn n_rows(&self) -> usize {
458 self.row_indices.len()
459 }
460
461 fn n_features(&self) -> usize {
462 self.base.n_features()
463 }
464
465 fn canaries(&self) -> usize {
466 self.base.canaries()
467 }
468
469 fn numeric_bin_cap(&self) -> usize {
470 self.base.numeric_bin_cap()
471 }
472
473 fn binned_feature_count(&self) -> usize {
474 self.base.binned_feature_count()
475 }
476
477 fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
478 self.base
479 .feature_value(feature_index, self.resolve_row(row_index))
480 }
481
482 fn is_binary_feature(&self, index: usize) -> bool {
483 self.base.is_binary_feature(index)
484 }
485
486 fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
487 self.base
488 .binned_value(feature_index, self.resolve_row(row_index))
489 }
490
491 fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
492 self.base
493 .binned_boolean_value(feature_index, self.resolve_row(row_index))
494 }
495
496 fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
497 self.base.binned_column_kind(index)
498 }
499
500 fn is_binary_binned_feature(&self, index: usize) -> bool {
501 self.base.is_binary_binned_feature(index)
502 }
503
504 fn target_value(&self, row_index: usize) -> f64 {
505 self.base.target_value(self.resolve_row(row_index))
506 }
507}
508
509fn validate_boosting_parameters(
510 train_set: &dyn TableAccess,
511 learning_rate: f64,
512 top_gradient_fraction: f64,
513 other_gradient_fraction: f64,
514) -> Result<(), BoostingError> {
515 if train_set.n_rows() == 0 {
516 return Err(BoostingError::InvalidLearningRate(learning_rate));
517 }
518 if !learning_rate.is_finite() || learning_rate <= 0.0 {
519 return Err(BoostingError::InvalidLearningRate(learning_rate));
520 }
521 if !top_gradient_fraction.is_finite()
522 || top_gradient_fraction <= 0.0
523 || top_gradient_fraction > 1.0
524 {
525 return Err(BoostingError::InvalidTopGradientFraction(
526 top_gradient_fraction,
527 ));
528 }
529 if !other_gradient_fraction.is_finite()
530 || !(0.0..1.0).contains(&other_gradient_fraction)
531 || top_gradient_fraction + other_gradient_fraction > 1.0
532 {
533 return Err(BoostingError::InvalidOtherGradientFraction(
534 other_gradient_fraction,
535 ));
536 }
537 Ok(())
538}
539
540fn finite_targets(train_set: &dyn TableAccess) -> Result<Vec<f64>, BoostingError> {
541 (0..train_set.n_rows())
542 .map(|row_index| {
543 let value = train_set.target_value(row_index);
544 if value.is_finite() {
545 Ok(value)
546 } else {
547 Err(BoostingError::InvalidTargetValue {
548 row: row_index,
549 value,
550 })
551 }
552 })
553 .collect()
554}
555
556fn binary_classification_targets(
557 train_set: &dyn TableAccess,
558) -> Result<(Vec<f64>, Vec<f64>), BoostingError> {
559 let mut labels = finite_targets(train_set)?;
560 labels.sort_by(|left, right| left.total_cmp(right));
561 labels.dedup_by(|left, right| left.total_cmp(right).is_eq());
562 if labels.len() != 2 {
563 return Err(BoostingError::UnsupportedClassificationClassCount(
564 labels.len(),
565 ));
566 }
567
568 let negative = labels[0];
569 let encoded = (0..train_set.n_rows())
570 .map(|row_index| {
571 if train_set
572 .target_value(row_index)
573 .total_cmp(&negative)
574 .is_eq()
575 {
576 0.0
577 } else {
578 1.0
579 }
580 })
581 .collect();
582 Ok((labels, encoded))
583}
584
/// First- and second-order derivatives of half squared error at the
/// current raw predictions: gradient = prediction − target, hessian = 1.
fn squared_error_gradients_and_hessians(
    raw_predictions: &[f64],
    targets: &[f64],
) -> (Vec<f64>, Vec<f64>) {
    let mut gradients = Vec::with_capacity(targets.len());
    for (&prediction, &target) in raw_predictions.iter().zip(targets.iter()) {
        gradients.push(prediction - target);
    }
    // The squared-error hessian is constant.
    let hessians = vec![1.0; targets.len()];
    (gradients, hessians)
}
598
599fn logistic_gradients_and_hessians(
600 raw_predictions: &[f64],
601 targets: &[f64],
602) -> (Vec<f64>, Vec<f64>) {
603 let mut gradients = Vec::with_capacity(targets.len());
604 let mut hessians = Vec::with_capacity(targets.len());
605 for (raw_prediction, target) in raw_predictions.iter().zip(targets.iter()) {
606 let probability = sigmoid(*raw_prediction);
607 gradients.push(probability - target);
608 hessians.push((probability * (1.0 - probability)).max(1e-12));
609 }
610 (gradients, hessians)
611}
612
/// Numerically stable logistic function: only the non-positive magnitude is
/// exponentiated, so neither branch can overflow for large |value|.
fn sigmoid(value: f64) -> f64 {
    let exp = (-value.abs()).exp();
    if value < 0.0 {
        exp / (1.0 + exp)
    } else {
        1.0 / (1.0 + exp)
    }
}
622
623fn gradient_focus_sample(
624 base_rows: &[usize],
625 gradients: &[f64],
626 hessians: &[f64],
627 top_gradient_fraction: f64,
628 other_gradient_fraction: f64,
629 seed: u64,
630) -> GradientFocusedSample {
631 let mut ranked = base_rows
632 .iter()
633 .copied()
634 .map(|row_index| (row_index, gradients[row_index].abs()))
635 .collect::<Vec<_>>();
636 ranked.sort_by(|(left_row, left_abs), (right_row, right_abs)| {
637 right_abs
638 .total_cmp(left_abs)
639 .then_with(|| left_row.cmp(right_row))
640 });
641
642 let top_count = ((ranked.len() as f64) * top_gradient_fraction)
643 .ceil()
644 .clamp(1.0, ranked.len() as f64) as usize;
645 let mut row_indices = Vec::with_capacity(ranked.len());
646 let mut sampled_gradients = Vec::with_capacity(ranked.len());
647 let mut sampled_hessians = Vec::with_capacity(ranked.len());
648
649 for (row_index, _) in ranked.iter().take(top_count) {
650 row_indices.push(*row_index);
651 sampled_gradients.push(gradients[*row_index]);
652 sampled_hessians.push(hessians[*row_index]);
653 }
654
655 if top_count < ranked.len() && other_gradient_fraction > 0.0 {
656 let remaining = ranked[top_count..]
657 .iter()
658 .map(|(row_index, _)| *row_index)
659 .collect::<Vec<_>>();
660 let other_count = ((remaining.len() as f64) * other_gradient_fraction)
661 .ceil()
662 .min(remaining.len() as f64) as usize;
663 if other_count > 0 {
664 let mut remaining = remaining;
665 let mut rng = StdRng::seed_from_u64(seed);
666 remaining.shuffle(&mut rng);
667 let gradient_scale = (1.0 - top_gradient_fraction) / other_gradient_fraction;
668 for row_index in remaining.into_iter().take(other_count) {
669 row_indices.push(row_index);
670 sampled_gradients.push(gradients[row_index] * gradient_scale);
671 sampled_hessians.push(hessians[row_index] * gradient_scale);
672 }
673 }
674 }
675
676 GradientFocusedSample {
677 row_indices,
678 gradients: sampled_gradients,
679 hessians: sampled_hessians,
680 }
681}
682
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MaxFeatures, TrainAlgorithm, TrainConfig};
    use forestfire_data::{BinnedColumnKind, TableAccess};
    use forestfire_data::{DenseTable, NumericBins};

    // A monotone single-feature signal should yield monotone predictions.
    #[test]
    fn regression_boosting_fits_simple_signal() {
        let table = DenseTable::with_options(
            vec![
                vec![0.0],
                vec![0.0],
                vec![1.0],
                vec![1.0],
                vec![2.0],
                vec![2.0],
            ],
            vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
            0,
            NumericBins::fixed(8).unwrap(),
        )
        .unwrap();

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Regression,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(20),
                learning_rate: Some(0.2),
                max_depth: Some(2),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        // Only the ordering is asserted, not the exact fitted values.
        let predictions = model.predict_table(&table);
        assert!(predictions[0] < predictions[2]);
        assert!(predictions[2] < predictions[4]);
    }

    // Binary classification should separate the low and high feature values
    // on either side of the 0.5 probability threshold.
    #[test]
    fn classification_boosting_produces_binary_probabilities() {
        let table = DenseTable::with_options(
            vec![vec![0.0], vec![0.1], vec![0.9], vec![1.0]],
            vec![0.0, 0.0, 1.0, 1.0],
            0,
            NumericBins::fixed(8).unwrap(),
        )
        .unwrap();

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Classification,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(25),
                learning_rate: Some(0.2),
                max_depth: Some(2),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        let probabilities = model.predict_proba_table(&table).unwrap();
        assert_eq!(probabilities.len(), 4);
        assert!(probabilities[0][1] < 0.5);
        assert!(probabilities[3][1] > 0.5);
    }

    // Three distinct target values must be rejected with the class-count error.
    #[test]
    fn classification_boosting_rejects_multiclass_targets() {
        let table =
            DenseTable::new(vec![vec![0.0], vec![1.0], vec![2.0]], vec![0.0, 1.0, 2.0]).unwrap();

        let error = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Classification,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap_err();

        assert!(matches!(
            error,
            BoostingError::UnsupportedClassificationClassCount(3)
        ));
    }

    // Hand-rolled table whose only informative binned column is a canary,
    // forcing the root split onto it.
    struct RootCanaryTable;

    impl TableAccess for RootCanaryTable {
        fn n_rows(&self) -> usize {
            4
        }

        fn n_features(&self) -> usize {
            1
        }

        fn canaries(&self) -> usize {
            1
        }

        fn numeric_bin_cap(&self) -> usize {
            2
        }

        fn binned_feature_count(&self) -> usize {
            2
        }

        fn feature_value(&self, _feature_index: usize, _row_index: usize) -> f64 {
            0.0
        }

        fn is_binary_feature(&self, _index: usize) -> bool {
            true
        }

        // Column 0 (real) is constant; column 1 (canary) splits the rows
        // exactly along the target boundary.
        fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
            match feature_index {
                0 => 0,
                1 => u16::from(row_index >= 2),
                _ => unreachable!(),
            }
        }

        fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
            Some(match feature_index {
                0 => false,
                1 => row_index >= 2,
                _ => unreachable!(),
            })
        }

        fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
            match index {
                0 => BinnedColumnKind::Real { source_index: 0 },
                1 => BinnedColumnKind::Canary {
                    source_index: 0,
                    copy_index: 0,
                },
                _ => unreachable!(),
            }
        }

        fn is_binary_binned_feature(&self, _index: usize) -> bool {
            true
        }

        fn target_value(&self, row_index: usize) -> f64 {
            [0.0, 0.0, 1.0, 1.0][row_index]
        }
    }

    // A canary winning the root split on the first stage should abort
    // boosting before any tree is accepted, leaving base-score-only
    // (still finite) predictions.
    #[test]
    fn boosting_stops_when_root_split_is_a_canary() {
        let table = RootCanaryTable;

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Regression,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(10),
                max_features: MaxFeatures::All,
                learning_rate: Some(0.1),
                top_gradient_fraction: Some(1.0),
                other_gradient_fraction: Some(0.0),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        assert_eq!(model.trees().len(), 0);
        assert_eq!(model.training_metadata().n_trees, Some(0));
        assert!(
            model
                .predict_table(&table)
                .iter()
                .all(|value| value.is_finite())
        );
    }
}