1use crate::bootstrap::BootstrapSampler;
2use crate::ir::TrainingMetadata;
3use crate::tree::second_order::{
4 SecondOrderRegressionTreeError, SecondOrderRegressionTreeOptions,
5 train_cart_regressor_from_gradients_and_hessians_with_status,
6 train_oblivious_regressor_from_gradients_and_hessians_with_status,
7 train_randomized_regressor_from_gradients_and_hessians_with_status,
8};
9use crate::{
10 Criterion, FeaturePreprocessing, Model, Parallelism, PredictError, Task, TrainConfig, TreeType,
11 capture_feature_preprocessing,
12};
13use forestfire_data::TableAccess;
14use rand::SeedableRng;
15use rand::rngs::StdRng;
16use rand::seq::SliceRandom;
17
/// A trained gradient-boosted ensemble of second-order regression trees.
///
/// Prediction is `base_score` plus the weighted sum of every stage tree's
/// output; classification squashes that raw score through a sigmoid.
#[derive(Debug, Clone)]
pub struct GradientBoostedTrees {
    // Regression or (binary) classification.
    task: Task,
    // Tree family used for every boosting stage (cart/randomized/oblivious).
    tree_type: TreeType,
    // One stage regressor per boosting round, in training order.
    trees: Vec<Model>,
    // Per-tree multiplier applied at prediction time (the learning rate).
    tree_weights: Vec<f64>,
    // Initial raw score: target mean (regression) or prior log-odds (classification).
    base_score: f64,
    learning_rate: f64,
    // Whether each stage trained on a bootstrap resample of the rows.
    bootstrap: bool,
    // Fraction of largest-|gradient| rows always kept per stage ...
    top_gradient_fraction: f64,
    // ... plus this fraction of the remaining rows, sampled at random.
    other_gradient_fraction: f64,
    max_features: usize,
    seed: Option<u64>,
    num_features: usize,
    feature_preprocessing: Vec<FeaturePreprocessing>,
    // Sorted original class labels (classification only; index 1 = positive).
    class_labels: Option<Vec<f64>>,
    // Number of canary features present during training, kept for metadata.
    training_canaries: usize,
}
36
/// Errors produced while training a gradient-boosted model.
#[derive(Debug)]
pub enum BoostingError {
    /// A target value was NaN or infinite.
    InvalidTargetValue { row: usize, value: f64 },
    /// Classification boosting requires exactly two distinct class labels.
    UnsupportedClassificationClassCount(usize),
    /// `learning_rate` must be finite and > 0.
    /// NOTE(review): an empty training set is also reported via this variant.
    InvalidLearningRate(f64),
    /// `top_gradient_fraction` must lie in (0, 1].
    InvalidTopGradientFraction(f64),
    /// `other_gradient_fraction` must lie in [0, 1) and the two fractions
    /// must sum to at most 1.
    InvalidOtherGradientFraction(f64),
    /// A boosting stage's tree failed to train.
    SecondOrderTree(SecondOrderRegressionTreeError),
}
46
47impl std::fmt::Display for BoostingError {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 match self {
50 BoostingError::InvalidTargetValue { row, value } => write!(
51 f,
52 "Boosting targets must be finite values. Found {} at row {}.",
53 value, row
54 ),
55 BoostingError::UnsupportedClassificationClassCount(count) => write!(
56 f,
57 "Gradient boosting currently supports binary classification only. Found {} classes.",
58 count
59 ),
60 BoostingError::InvalidLearningRate(value) => write!(
61 f,
62 "learning_rate must be finite and greater than 0. Found {}.",
63 value
64 ),
65 BoostingError::InvalidTopGradientFraction(value) => write!(
66 f,
67 "top_gradient_fraction must be in the interval (0, 1]. Found {}.",
68 value
69 ),
70 BoostingError::InvalidOtherGradientFraction(value) => write!(
71 f,
72 "other_gradient_fraction must be in the interval [0, 1), and top_gradient_fraction + other_gradient_fraction must be at most 1. Found {}.",
73 value
74 ),
75 BoostingError::SecondOrderTree(err) => err.fmt(f),
76 }
77 }
78}
79
// Marker impl so `BoostingError` can travel as `dyn std::error::Error`;
// `Display`/`Debug` above carry all the detail.
impl std::error::Error for BoostingError {}
81
/// A row-remapping view over a base table: row `i` of the view resolves to
/// `row_indices[i]` of `base`. Indices may repeat (bootstrap resampling) or
/// cover only a subset of the base rows (gradient-focused sampling).
struct SampledTable<'a> {
    base: &'a dyn TableAccess,
    row_indices: Vec<usize>,
}
86
87impl GradientBoostedTrees {
    /// Assembles a model from already-trained parts; used by [`Self::train`]
    /// and presumably by model deserialization. Performs no validation.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        task: Task,
        tree_type: TreeType,
        trees: Vec<Model>,
        tree_weights: Vec<f64>,
        base_score: f64,
        learning_rate: f64,
        bootstrap: bool,
        top_gradient_fraction: f64,
        other_gradient_fraction: f64,
        max_features: usize,
        seed: Option<u64>,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
        class_labels: Option<Vec<f64>>,
        training_canaries: usize,
    ) -> Self {
        Self {
            task,
            tree_type,
            trees,
            tree_weights,
            base_score,
            learning_rate,
            bootstrap,
            top_gradient_fraction,
            other_gradient_fraction,
            max_features,
            seed,
            num_features,
            feature_preprocessing,
            class_labels,
            training_canaries,
        }
    }
124
    /// Trains a gradient-boosted ensemble on `train_set`.
    ///
    /// Stages are fit sequentially: each stage builds a second-order regression
    /// tree on the gradients/hessians of the current raw predictions, then the
    /// raw predictions advance by `learning_rate * stage_prediction`. Training
    /// stops early if a stage tree's root split selects a canary feature.
    ///
    /// # Errors
    /// Returns a [`BoostingError`] for invalid hyperparameters, non-finite
    /// targets, non-binary classification targets, or a stage-tree failure.
    pub(crate) fn train(
        train_set: &dyn TableAccess,
        config: TrainConfig,
        parallelism: Parallelism,
    ) -> Result<Self, BoostingError> {
        // Fill in hyperparameter defaults, then validate before doing any work.
        let n_trees = config.n_trees.unwrap_or(100);
        let learning_rate = config.learning_rate.unwrap_or(0.1);
        let bootstrap = config.bootstrap;
        let top_gradient_fraction = config.top_gradient_fraction.unwrap_or(0.2);
        let other_gradient_fraction = config.other_gradient_fraction.unwrap_or(0.1);
        validate_boosting_parameters(
            train_set,
            learning_rate,
            top_gradient_fraction,
            other_gradient_fraction,
        )?;

        let max_features = config
            .max_features
            .resolve(config.task, train_set.binned_feature_count());
        let base_seed = config.seed.unwrap_or(0xB005_7EED_u64);
        // `random_seed: 0` is a placeholder; each stage overwrites it below.
        let tree_options = crate::RegressionTreeOptions {
            max_depth: config.max_depth.unwrap_or(8),
            min_samples_split: config.min_samples_split.unwrap_or(2),
            min_samples_leaf: config.min_samples_leaf.unwrap_or(1),
            max_features: Some(max_features),
            random_seed: 0,
        };
        let tree_options = SecondOrderRegressionTreeOptions {
            tree_options,
            l2_regularization: 1.0,
            min_sum_hessian_in_leaf: 1e-3,
            min_gain_to_split: 0.0,
        };
        let feature_preprocessing = capture_feature_preprocessing(train_set);
        let sampler = BootstrapSampler::new(train_set.n_rows());

        // Initial raw score: target mean for regression, prior log-odds for
        // classification (positive rate clamped so the logit stays finite).
        let (mut raw_predictions, class_labels, base_score) = match config.task {
            Task::Regression => {
                let targets = finite_targets(train_set)?;
                let base_score = targets.iter().sum::<f64>() / targets.len() as f64;
                (vec![base_score; train_set.n_rows()], None, base_score)
            }
            Task::Classification => {
                let (labels, encoded_targets) = binary_classification_targets(train_set)?;
                let positive_rate = (encoded_targets.iter().sum::<f64>()
                    / encoded_targets.len() as f64)
                    .clamp(1e-6, 1.0 - 1e-6);
                let base_score = (positive_rate / (1.0 - positive_rate)).ln();
                (
                    vec![base_score; train_set.n_rows()],
                    Some(labels),
                    base_score,
                )
            }
        };

        let mut trees = Vec::with_capacity(n_trees);
        let mut tree_weights = Vec::with_capacity(n_trees);
        // Materialize targets once so each stage only recomputes derivatives.
        let regression_targets = if config.task == Task::Regression {
            Some(finite_targets(train_set)?)
        } else {
            None
        };
        let classification_targets = if config.task == Task::Classification {
            Some(binary_classification_targets(train_set)?.1)
        } else {
            None
        };

        for tree_index in 0..n_trees {
            // Deterministic per-stage seed derived from the base seed.
            let stage_seed = mix_seed(base_seed, tree_index as u64);
            // Loss derivatives at the current raw predictions: squared error
            // for regression, logistic loss for classification.
            let (gradients, hessians) = match config.task {
                Task::Regression => squared_error_gradients_and_hessians(
                    raw_predictions.as_slice(),
                    regression_targets
                        .as_ref()
                        .expect("regression targets exist for regression boosting"),
                ),
                Task::Classification => logistic_gradients_and_hessians(
                    raw_predictions.as_slice(),
                    classification_targets
                        .as_ref()
                        .expect("classification targets exist for classification boosting"),
                ),
            };

            // Optional bootstrap resample, then gradient-focused subsampling
            // (keep large-|gradient| rows, randomly thin the rest).
            let base_rows = if bootstrap {
                sampler.sample(stage_seed)
            } else {
                (0..train_set.n_rows()).collect()
            };
            let sampled_rows = gradient_focus_sample(
                &base_rows,
                &gradients,
                &hessians,
                top_gradient_fraction,
                other_gradient_fraction,
                mix_seed(stage_seed, 0x6011_5A11),
            );
            let sampled_table = SampledTable::new(train_set, sampled_rows.row_indices);
            let mut stage_tree_options = tree_options;
            stage_tree_options.tree_options.random_seed = stage_seed;
            let stage_result = match config.tree_type {
                TreeType::Cart => train_cart_regressor_from_gradients_and_hessians_with_status(
                    &sampled_table,
                    &sampled_rows.gradients,
                    &sampled_rows.hessians,
                    parallelism,
                    stage_tree_options,
                ),
                TreeType::Randomized => {
                    train_randomized_regressor_from_gradients_and_hessians_with_status(
                        &sampled_table,
                        &sampled_rows.gradients,
                        &sampled_rows.hessians,
                        parallelism,
                        stage_tree_options,
                    )
                }
                TreeType::Oblivious => {
                    train_oblivious_regressor_from_gradients_and_hessians_with_status(
                        &sampled_table,
                        &sampled_rows.gradients,
                        &sampled_rows.hessians,
                        parallelism,
                        stage_tree_options,
                    )
                }
                _ => unreachable!("boosting tree type validated by training dispatch"),
            }
            .map_err(BoostingError::SecondOrderTree)?;

            // A canary chosen at the root means the remaining signal is at
            // noise level; stop boosting rather than fit noise.
            if stage_result.root_canary_selected {
                break;
            }

            let stage_tree = stage_result.model;
            let stage_model = Model::DecisionTreeRegressor(stage_tree);
            // Advance raw predictions on the FULL training set (not just the
            // sampled rows) so the next stage's gradients stay consistent.
            let stage_predictions = stage_model.predict_table(train_set);
            for (raw_prediction, stage_prediction) in raw_predictions
                .iter_mut()
                .zip(stage_predictions.iter().copied())
            {
                *raw_prediction += learning_rate * stage_prediction;
            }
            tree_weights.push(learning_rate);
            trees.push(stage_model);
        }

        Ok(Self::new(
            config.task,
            config.tree_type,
            trees,
            tree_weights,
            base_score,
            learning_rate,
            bootstrap,
            top_gradient_fraction,
            other_gradient_fraction,
            max_features,
            config.seed,
            train_set.n_features(),
            feature_preprocessing,
            class_labels,
            train_set.canaries(),
        ))
    }
293
294 pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
295 match self.task {
296 Task::Regression => self.predict_regression_table(table),
297 Task::Classification => self.predict_classification_table(table),
298 }
299 }
300
301 pub fn predict_proba_table(
302 &self,
303 table: &dyn TableAccess,
304 ) -> Result<Vec<Vec<f64>>, PredictError> {
305 if self.task != Task::Classification {
306 return Err(PredictError::ProbabilityPredictionRequiresClassification);
307 }
308 Ok(self
309 .raw_scores(table)
310 .into_iter()
311 .map(|score| {
312 let positive = sigmoid(score);
313 vec![1.0 - positive, positive]
314 })
315 .collect())
316 }
317
    /// The task this model was trained for.
    pub fn task(&self) -> Task {
        self.task
    }

    /// Boosting always trains stage trees with the second-order criterion.
    pub fn criterion(&self) -> Criterion {
        Criterion::SecondOrder
    }

    /// Tree family used for every boosting stage.
    pub fn tree_type(&self) -> TreeType {
        self.tree_type
    }

    /// The trained stage trees, in boosting order.
    pub fn trees(&self) -> &[Model] {
        &self.trees
    }

    /// Per-stage multipliers (one per tree, currently the learning rate).
    pub fn tree_weights(&self) -> &[f64] {
        &self.tree_weights
    }

    /// The initial raw score all predictions start from.
    pub fn base_score(&self) -> f64 {
        self.base_score
    }

    /// Number of (unbinned) features in the training table.
    pub fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-feature preprocessing captured from the training table.
    pub fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// Sorted class labels for classification models; `None` for regression.
    pub fn class_labels(&self) -> Option<Vec<f64>> {
        self.class_labels.clone()
    }
353
    /// Summarizes the training configuration and artifacts, e.g. for
    /// serialization or introspection. Per-tree limits are read off the
    /// first stage tree.
    pub fn training_metadata(&self) -> TrainingMetadata {
        TrainingMetadata {
            algorithm: "gbm".to_string(),
            task: match self.task {
                Task::Regression => "regression".to_string(),
                Task::Classification => "classification".to_string(),
            },
            tree_type: match self.tree_type {
                TreeType::Cart => "cart".to_string(),
                TreeType::Randomized => "randomized".to_string(),
                TreeType::Oblivious => "oblivious".to_string(),
                _ => unreachable!("boosting only supports cart/randomized/oblivious"),
            },
            criterion: "second_order".to_string(),
            canaries: self.training_canaries,
            // Boosting never computes out-of-bag estimates.
            compute_oob: false,
            max_depth: self.trees.first().and_then(Model::max_depth),
            min_samples_split: self.trees.first().and_then(Model::min_samples_split),
            min_samples_leaf: self.trees.first().and_then(Model::min_samples_leaf),
            // Trees actually trained; canary early-stop may undershoot the config.
            n_trees: Some(self.trees.len()),
            max_features: Some(self.max_features),
            seed: self.seed,
            oob_score: None,
            class_labels: self.class_labels.clone(),
            learning_rate: Some(self.learning_rate),
            bootstrap: Some(self.bootstrap),
            top_gradient_fraction: Some(self.top_gradient_fraction),
            other_gradient_fraction: Some(self.other_gradient_fraction),
        }
    }
384
385 fn raw_scores(&self, table: &dyn TableAccess) -> Vec<f64> {
386 let mut scores = vec![self.base_score; table.n_rows()];
387 for (tree, weight) in self.trees.iter().zip(self.tree_weights.iter().copied()) {
388 let predictions = tree.predict_table(table);
389 for (score, prediction) in scores.iter_mut().zip(predictions.iter().copied()) {
390 *score += weight * prediction;
391 }
392 }
393 scores
394 }
395
396 fn predict_regression_table(&self, table: &dyn TableAccess) -> Vec<f64> {
397 self.raw_scores(table)
398 }
399
400 fn predict_classification_table(&self, table: &dyn TableAccess) -> Vec<f64> {
401 let class_labels = self
402 .class_labels
403 .as_ref()
404 .expect("classification boosting stores class labels");
405 self.raw_scores(table)
406 .into_iter()
407 .map(|score| {
408 if sigmoid(score) >= 0.5 {
409 class_labels[1]
410 } else {
411 class_labels[0]
412 }
413 })
414 .collect()
415 }
416}
417
/// Rows chosen by `gradient_focus_sample`, together with their (possibly
/// rescaled) gradients and hessians, aligned index-for-index with
/// `row_indices`.
struct GradientFocusedSample {
    row_indices: Vec<usize>,
    gradients: Vec<f64>,
    hessians: Vec<f64>,
}
423
impl<'a> SampledTable<'a> {
    /// Wraps `base` so that view row `i` reads base row `row_indices[i]`.
    fn new(base: &'a dyn TableAccess, row_indices: Vec<usize>) -> Self {
        Self { base, row_indices }
    }

    // Maps a view-local row index to the underlying table's row index.
    fn resolve_row(&self, row_index: usize) -> usize {
        self.row_indices[row_index]
    }
}
433
// Delegates everything to the base table. Only the per-row accessors
// (`feature_value`, `binned_value`, `binned_boolean_value`, `target_value`)
// translate row indices through the sample mapping; per-feature metadata is
// row-independent and passes straight through.
impl TableAccess for SampledTable<'_> {
    fn n_rows(&self) -> usize {
        // The view's row count is the sample size, not the base table's.
        self.row_indices.len()
    }

    fn n_features(&self) -> usize {
        self.base.n_features()
    }

    fn canaries(&self) -> usize {
        self.base.canaries()
    }

    fn numeric_bin_cap(&self) -> usize {
        self.base.numeric_bin_cap()
    }

    fn binned_feature_count(&self) -> usize {
        self.base.binned_feature_count()
    }

    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        self.base
            .feature_value(feature_index, self.resolve_row(row_index))
    }

    fn is_binary_feature(&self, index: usize) -> bool {
        self.base.is_binary_feature(index)
    }

    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        self.base
            .binned_value(feature_index, self.resolve_row(row_index))
    }

    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        self.base
            .binned_boolean_value(feature_index, self.resolve_row(row_index))
    }

    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
        self.base.binned_column_kind(index)
    }

    fn is_binary_binned_feature(&self, index: usize) -> bool {
        self.base.is_binary_binned_feature(index)
    }

    fn target_value(&self, row_index: usize) -> f64 {
        self.base.target_value(self.resolve_row(row_index))
    }
}
486
487fn validate_boosting_parameters(
488 train_set: &dyn TableAccess,
489 learning_rate: f64,
490 top_gradient_fraction: f64,
491 other_gradient_fraction: f64,
492) -> Result<(), BoostingError> {
493 if train_set.n_rows() == 0 {
494 return Err(BoostingError::InvalidLearningRate(learning_rate));
495 }
496 if !learning_rate.is_finite() || learning_rate <= 0.0 {
497 return Err(BoostingError::InvalidLearningRate(learning_rate));
498 }
499 if !top_gradient_fraction.is_finite()
500 || top_gradient_fraction <= 0.0
501 || top_gradient_fraction > 1.0
502 {
503 return Err(BoostingError::InvalidTopGradientFraction(
504 top_gradient_fraction,
505 ));
506 }
507 if !other_gradient_fraction.is_finite()
508 || !(0.0..1.0).contains(&other_gradient_fraction)
509 || top_gradient_fraction + other_gradient_fraction > 1.0
510 {
511 return Err(BoostingError::InvalidOtherGradientFraction(
512 other_gradient_fraction,
513 ));
514 }
515 Ok(())
516}
517
518fn finite_targets(train_set: &dyn TableAccess) -> Result<Vec<f64>, BoostingError> {
519 (0..train_set.n_rows())
520 .map(|row_index| {
521 let value = train_set.target_value(row_index);
522 if value.is_finite() {
523 Ok(value)
524 } else {
525 Err(BoostingError::InvalidTargetValue {
526 row: row_index,
527 value,
528 })
529 }
530 })
531 .collect()
532}
533
534fn binary_classification_targets(
535 train_set: &dyn TableAccess,
536) -> Result<(Vec<f64>, Vec<f64>), BoostingError> {
537 let mut labels = finite_targets(train_set)?;
538 labels.sort_by(|left, right| left.total_cmp(right));
539 labels.dedup_by(|left, right| left.total_cmp(right).is_eq());
540 if labels.len() != 2 {
541 return Err(BoostingError::UnsupportedClassificationClassCount(
542 labels.len(),
543 ));
544 }
545
546 let negative = labels[0];
547 let encoded = (0..train_set.n_rows())
548 .map(|row_index| {
549 if train_set
550 .target_value(row_index)
551 .total_cmp(&negative)
552 .is_eq()
553 {
554 0.0
555 } else {
556 1.0
557 }
558 })
559 .collect();
560 Ok((labels, encoded))
561}
562
/// First/second derivatives of 0.5·(prediction − target)² per row:
/// gradient = prediction − target, hessian = 1.
fn squared_error_gradients_and_hessians(
    raw_predictions: &[f64],
    targets: &[f64],
) -> (Vec<f64>, Vec<f64>) {
    let mut gradients = Vec::with_capacity(targets.len());
    for (prediction, target) in raw_predictions.iter().zip(targets) {
        gradients.push(prediction - target);
    }
    // The squared-error hessian is constant.
    let hessians = vec![1.0; targets.len()];
    (gradients, hessians)
}
576
577fn logistic_gradients_and_hessians(
578 raw_predictions: &[f64],
579 targets: &[f64],
580) -> (Vec<f64>, Vec<f64>) {
581 let mut gradients = Vec::with_capacity(targets.len());
582 let mut hessians = Vec::with_capacity(targets.len());
583 for (raw_prediction, target) in raw_predictions.iter().zip(targets.iter()) {
584 let probability = sigmoid(*raw_prediction);
585 gradients.push(probability - target);
586 hessians.push((probability * (1.0 - probability)).max(1e-12));
587 }
588 (gradients, hessians)
589}
590
/// Numerically stable logistic function: never exponentiates a positive
/// argument, so it cannot overflow for large |value|.
fn sigmoid(value: f64) -> f64 {
    if value < 0.0 {
        let exp_value = value.exp();
        exp_value / (1.0 + exp_value)
    } else {
        1.0 / (1.0 + (-value).exp())
    }
}
600
/// Selects training rows by gradient magnitude, in the style of gradient-based
/// one-side sampling: every one of the `top_gradient_fraction` highest-|gradient|
/// rows is kept, plus a uniformly random `other_gradient_fraction` share of the
/// remainder whose gradients AND hessians are amplified by
/// `(1 - top_gradient_fraction) / other_gradient_fraction` to approximately
/// preserve the full-sample derivative sums.
///
/// NOTE(review): `clamp(1.0, len)` panics if `base_rows` is empty (min > max);
/// upstream validation rejects empty training sets, so this is presumed
/// unreachable — confirm for bootstrap samplers that could return zero rows.
fn gradient_focus_sample(
    base_rows: &[usize],
    gradients: &[f64],
    hessians: &[f64],
    top_gradient_fraction: f64,
    other_gradient_fraction: f64,
    seed: u64,
) -> GradientFocusedSample {
    // Rank candidate rows by |gradient| descending; ties break on the lower
    // row index so the ordering is deterministic.
    let mut ranked = base_rows
        .iter()
        .copied()
        .map(|row_index| (row_index, gradients[row_index].abs()))
        .collect::<Vec<_>>();
    ranked.sort_by(|(left_row, left_abs), (right_row, right_abs)| {
        right_abs
            .total_cmp(left_abs)
            .then_with(|| left_row.cmp(right_row))
    });

    // Always keep at least one row, never more than all of them.
    let top_count = ((ranked.len() as f64) * top_gradient_fraction)
        .ceil()
        .clamp(1.0, ranked.len() as f64) as usize;
    let mut row_indices = Vec::with_capacity(ranked.len());
    let mut sampled_gradients = Vec::with_capacity(ranked.len());
    let mut sampled_hessians = Vec::with_capacity(ranked.len());

    // The top slice keeps its derivatives unscaled.
    for (row_index, _) in ranked.iter().take(top_count) {
        row_indices.push(*row_index);
        sampled_gradients.push(gradients[*row_index]);
        sampled_hessians.push(hessians[*row_index]);
    }

    if top_count < ranked.len() && other_gradient_fraction > 0.0 {
        let remaining = ranked[top_count..]
            .iter()
            .map(|(row_index, _)| *row_index)
            .collect::<Vec<_>>();
        let other_count = ((remaining.len() as f64) * other_gradient_fraction)
            .ceil()
            .min(remaining.len() as f64) as usize;
        if other_count > 0 {
            // Seeded shuffle keeps the random slice deterministic per stage.
            let mut remaining = remaining;
            let mut rng = StdRng::seed_from_u64(seed);
            remaining.shuffle(&mut rng);
            // Amplification compensates for thinning the low-gradient pool.
            let gradient_scale = (1.0 - top_gradient_fraction) / other_gradient_fraction;
            for row_index in remaining.into_iter().take(other_count) {
                row_indices.push(row_index);
                sampled_gradients.push(gradients[row_index] * gradient_scale);
                sampled_hessians.push(hessians[row_index] * gradient_scale);
            }
        }
    }

    GradientFocusedSample {
        row_indices,
        gradients: sampled_gradients,
        hessians: sampled_hessians,
    }
}
660
/// Derives a new seed by diffusing `value` (multiply by a 64-bit golden-ratio
/// constant, rotate) and XOR-ing it into `base_seed`. Deterministic; note that
/// `value == 0` returns `base_seed` unchanged.
fn mix_seed(base_seed: u64, value: u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    let diffused = value.wrapping_mul(GOLDEN_GAMMA).rotate_left(17);
    base_seed ^ diffused
}
664
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MaxFeatures, TrainAlgorithm, TrainConfig};
    use forestfire_data::{BinnedColumnKind, TableAccess};
    use forestfire_data::{DenseTable, NumericBins};

    // Boosted regression on a monotone 1-D signal should at least preserve
    // the ordering of the targets.
    #[test]
    fn regression_boosting_fits_simple_signal() {
        let table = DenseTable::with_options(
            vec![
                vec![0.0],
                vec![0.0],
                vec![1.0],
                vec![1.0],
                vec![2.0],
                vec![2.0],
            ],
            vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
            0,
            NumericBins::fixed(8).unwrap(),
        )
        .unwrap();

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Regression,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(20),
                learning_rate: Some(0.2),
                max_depth: Some(2),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        let predictions = model.predict_table(&table);
        assert!(predictions[0] < predictions[2]);
        assert!(predictions[2] < predictions[4]);
    }

    // Two separable clusters: predicted positive-class probabilities must fall
    // on the correct side of 0.5 for the extreme rows.
    #[test]
    fn classification_boosting_produces_binary_probabilities() {
        let table = DenseTable::with_options(
            vec![vec![0.0], vec![0.1], vec![0.9], vec![1.0]],
            vec![0.0, 0.0, 1.0, 1.0],
            0,
            NumericBins::fixed(8).unwrap(),
        )
        .unwrap();

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Classification,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(25),
                learning_rate: Some(0.2),
                max_depth: Some(2),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        let probabilities = model.predict_proba_table(&table).unwrap();
        assert_eq!(probabilities.len(), 4);
        assert!(probabilities[0][1] < 0.5);
        assert!(probabilities[3][1] > 0.5);
    }

    // Three distinct target values must be rejected: only binary
    // classification is supported.
    #[test]
    fn classification_boosting_rejects_multiclass_targets() {
        let table =
            DenseTable::new(vec![vec![0.0], vec![1.0], vec![2.0]], vec![0.0, 1.0, 2.0]).unwrap();

        let error = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Classification,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap_err();

        assert!(matches!(
            error,
            BoostingError::UnsupportedClassificationClassCount(3)
        ));
    }

    // Hand-rolled fixture: binned column 0 is a constant real feature and
    // column 1 is a canary that perfectly splits the targets, so the first
    // boosting stage must pick the canary at the root.
    struct RootCanaryTable;

    impl TableAccess for RootCanaryTable {
        fn n_rows(&self) -> usize {
            4
        }

        fn n_features(&self) -> usize {
            1
        }

        fn canaries(&self) -> usize {
            1
        }

        fn numeric_bin_cap(&self) -> usize {
            2
        }

        fn binned_feature_count(&self) -> usize {
            2
        }

        fn feature_value(&self, _feature_index: usize, _row_index: usize) -> f64 {
            0.0
        }

        fn is_binary_feature(&self, _index: usize) -> bool {
            true
        }

        // Column 0 is constant; column 1 splits rows {0,1} from {2,3}.
        fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
            match feature_index {
                0 => 0,
                1 => u16::from(row_index >= 2),
                _ => unreachable!(),
            }
        }

        fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
            Some(match feature_index {
                0 => false,
                1 => row_index >= 2,
                _ => unreachable!(),
            })
        }

        fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
            match index {
                0 => BinnedColumnKind::Real { source_index: 0 },
                1 => BinnedColumnKind::Canary {
                    source_index: 0,
                    copy_index: 0,
                },
                _ => unreachable!(),
            }
        }

        fn is_binary_binned_feature(&self, _index: usize) -> bool {
            true
        }

        // Targets align exactly with the canary column's split.
        fn target_value(&self, row_index: usize) -> f64 {
            [0.0, 0.0, 1.0, 1.0][row_index]
        }
    }

    // When the root split lands on a canary, boosting stops before adding any
    // trees and predictions fall back to the (finite) base score.
    #[test]
    fn boosting_stops_when_root_split_is_a_canary() {
        let table = RootCanaryTable;

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Regression,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(10),
                max_features: MaxFeatures::All,
                learning_rate: Some(0.1),
                top_gradient_fraction: Some(1.0),
                other_gradient_fraction: Some(0.0),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        assert_eq!(model.trees().len(), 0);
        assert_eq!(model.training_metadata().n_trees, Some(0));
        assert!(
            model
                .predict_table(&table)
                .iter()
                .all(|value| value.is_finite())
        );
    }
}