//! Gradient boosted trees for `forestfire_core` (boosting.rs).
1use crate::bootstrap::BootstrapSampler;
2use crate::ir::TrainingMetadata;
3use crate::tree::second_order::{
4    SecondOrderRegressionTreeError, SecondOrderRegressionTreeOptions,
5    train_cart_regressor_from_gradients_and_hessians_with_status,
6    train_oblivious_regressor_from_gradients_and_hessians_with_status,
7    train_randomized_regressor_from_gradients_and_hessians_with_status,
8};
9use crate::{
10    Criterion, FeaturePreprocessing, Model, Parallelism, PredictError, Task, TrainConfig, TreeType,
11    capture_feature_preprocessing,
12};
13use forestfire_data::TableAccess;
14use rand::SeedableRng;
15use rand::rngs::StdRng;
16use rand::seq::SliceRandom;
17
/// A trained gradient-boosted ensemble of second-order regression trees.
///
/// A raw score is `base_score + Σ tree_weights[i] * trees[i](x)`; regression
/// returns it directly, classification squashes it through a sigmoid.
#[derive(Debug, Clone)]
pub struct GradientBoostedTrees {
    // Regression or binary classification.
    task: Task,
    // Tree family used for every boosting stage.
    tree_type: TreeType,
    // One fitted stage tree per boosting round that survived training.
    trees: Vec<Model>,
    // Per-tree weight applied when summing stage predictions (the learning rate
    // at the time the stage was added).
    tree_weights: Vec<f64>,
    // Initial raw score: the target mean for regression, the log-odds of the
    // positive class for classification.
    base_score: f64,
    // Shrinkage applied to each stage's contribution.
    learning_rate: f64,
    // Whether each stage drew a bootstrap sample before gradient-focused sampling.
    bootstrap: bool,
    // Fraction of rows with the largest |gradient| kept in every stage sample.
    top_gradient_fraction: f64,
    // Fraction of the remaining rows sampled uniformly in every stage.
    other_gradient_fraction: f64,
    // Resolved feature-subsampling budget passed to the stage trees.
    max_features: usize,
    // User-provided RNG seed, if any (None means the built-in default was used).
    seed: Option<u64>,
    // Number of input features seen at training time.
    num_features: usize,
    // Captured per-feature preprocessing so inference matches training.
    feature_preprocessing: Vec<FeaturePreprocessing>,
    // Sorted original class labels (always two entries) for classification
    // models; None for regression.
    class_labels: Option<Vec<f64>>,
    // Number of canary columns present in the training table.
    training_canaries: usize,
}
36
/// Errors raised while training a gradient boosted model.
#[derive(Debug)]
pub enum BoostingError {
    /// A target value was NaN or infinite.
    InvalidTargetValue { row: usize, value: f64 },
    /// Classification boosting supports exactly two classes; the payload is the
    /// number of distinct labels that was actually found.
    UnsupportedClassificationClassCount(usize),
    /// `learning_rate` was non-finite or not strictly positive (also reused for
    /// an empty training set, which has no dedicated variant).
    InvalidLearningRate(f64),
    /// `top_gradient_fraction` fell outside `(0, 1]`.
    InvalidTopGradientFraction(f64),
    /// `other_gradient_fraction` fell outside `[0, 1)`, or the two fractions
    /// summed to more than 1.
    InvalidOtherGradientFraction(f64),
    /// A boosting stage failed while fitting its second-order regression tree.
    SecondOrderTree(SecondOrderRegressionTreeError),
}
46
47impl std::fmt::Display for BoostingError {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        match self {
50            BoostingError::InvalidTargetValue { row, value } => write!(
51                f,
52                "Boosting targets must be finite values. Found {} at row {}.",
53                value, row
54            ),
55            BoostingError::UnsupportedClassificationClassCount(count) => write!(
56                f,
57                "Gradient boosting currently supports binary classification only. Found {} classes.",
58                count
59            ),
60            BoostingError::InvalidLearningRate(value) => write!(
61                f,
62                "learning_rate must be finite and greater than 0. Found {}.",
63                value
64            ),
65            BoostingError::InvalidTopGradientFraction(value) => write!(
66                f,
67                "top_gradient_fraction must be in the interval (0, 1]. Found {}.",
68                value
69            ),
70            BoostingError::InvalidOtherGradientFraction(value) => write!(
71                f,
72                "other_gradient_fraction must be in the interval [0, 1), and top_gradient_fraction + other_gradient_fraction must be at most 1. Found {}.",
73                value
74            ),
75            BoostingError::SecondOrderTree(err) => err.fmt(f),
76        }
77    }
78}
79
// Marker impl: `Debug` + `Display` above satisfy `std::error::Error`'s defaults.
impl std::error::Error for BoostingError {}
81
/// A row-subset view over a base table, used to train each boosting stage on
/// its gradient-focused sample without copying feature data.
struct SampledTable<'a> {
    // The full training table being viewed.
    base: &'a dyn TableAccess,
    // For view row `i`, `row_indices[i]` is the corresponding base-table row.
    // Duplicate indices are allowed (bootstrap sampling).
    row_indices: Vec<usize>,
}
86
impl GradientBoostedTrees {
    /// Assembles a model from pre-computed parts (e.g. after deserialization).
    /// Performs no validation of the supplied components.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        task: Task,
        tree_type: TreeType,
        trees: Vec<Model>,
        tree_weights: Vec<f64>,
        base_score: f64,
        learning_rate: f64,
        bootstrap: bool,
        top_gradient_fraction: f64,
        other_gradient_fraction: f64,
        max_features: usize,
        seed: Option<u64>,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
        class_labels: Option<Vec<f64>>,
        training_canaries: usize,
    ) -> Self {
        Self {
            task,
            tree_type,
            trees,
            tree_weights,
            base_score,
            learning_rate,
            bootstrap,
            top_gradient_fraction,
            other_gradient_fraction,
            max_features,
            seed,
            num_features,
            feature_preprocessing,
            class_labels,
            training_canaries,
        }
    }

    /// Trains a gradient boosted ensemble on `train_set`.
    ///
    /// Each stage fits a second-order regression tree to the current gradients
    /// and hessians of the loss (squared error for regression, log-loss for
    /// binary classification) on a gradient-focused row sample, then adds the
    /// tree to the ensemble scaled by the learning rate. Boosting stops early
    /// if a stage tree's root split selects a canary column.
    ///
    /// # Errors
    /// Returns a [`BoostingError`] for invalid hyper-parameters, non-finite
    /// targets, a non-binary classification target set, or a stage-tree
    /// training failure.
    pub(crate) fn train(
        train_set: &dyn TableAccess,
        config: TrainConfig,
        parallelism: Parallelism,
    ) -> Result<Self, BoostingError> {
        // Resolve hyper-parameter defaults, then validate before doing any work.
        let n_trees = config.n_trees.unwrap_or(100);
        let learning_rate = config.learning_rate.unwrap_or(0.1);
        let bootstrap = config.bootstrap;
        let top_gradient_fraction = config.top_gradient_fraction.unwrap_or(0.2);
        let other_gradient_fraction = config.other_gradient_fraction.unwrap_or(0.1);
        validate_boosting_parameters(
            train_set,
            learning_rate,
            top_gradient_fraction,
            other_gradient_fraction,
        )?;

        let max_features = config
            .max_features
            .resolve(config.task, train_set.binned_feature_count());
        // Fixed fallback seed keeps unseeded training runs reproducible.
        let base_seed = config.seed.unwrap_or(0xB005_7EED_u64);
        // Shared per-stage tree options; `random_seed` is patched with a
        // stage-specific seed inside the boosting loop.
        let tree_options = crate::RegressionTreeOptions {
            max_depth: config.max_depth.unwrap_or(8),
            min_samples_split: config.min_samples_split.unwrap_or(2),
            min_samples_leaf: config.min_samples_leaf.unwrap_or(1),
            max_features: Some(max_features),
            random_seed: 0,
        };
        let tree_options = SecondOrderRegressionTreeOptions {
            tree_options,
            l2_regularization: 1.0,
            min_sum_hessian_in_leaf: 1e-3,
            min_gain_to_split: 0.0,
        };
        let feature_preprocessing = capture_feature_preprocessing(train_set);
        let sampler = BootstrapSampler::new(train_set.n_rows());

        // Initialize every row's raw prediction at the base score:
        // target mean (regression) or clamped log-odds of the positive class
        // (classification).
        let (mut raw_predictions, class_labels, base_score) = match config.task {
            Task::Regression => {
                let targets = finite_targets(train_set)?;
                let base_score = targets.iter().sum::<f64>() / targets.len() as f64;
                (vec![base_score; train_set.n_rows()], None, base_score)
            }
            Task::Classification => {
                let (labels, encoded_targets) = binary_classification_targets(train_set)?;
                let positive_rate = (encoded_targets.iter().sum::<f64>()
                    / encoded_targets.len() as f64)
                    .clamp(1e-6, 1.0 - 1e-6);
                let base_score = (positive_rate / (1.0 - positive_rate)).ln();
                (
                    vec![base_score; train_set.n_rows()],
                    Some(labels),
                    base_score,
                )
            }
        };

        let mut trees = Vec::with_capacity(n_trees);
        let mut tree_weights = Vec::with_capacity(n_trees);
        // Cache the task-specific target vector for the boosting loop.
        // NOTE(review): targets were already extracted once for the base score
        // above; a single shared pass would avoid this duplicate table scan.
        let regression_targets = if config.task == Task::Regression {
            Some(finite_targets(train_set)?)
        } else {
            None
        };
        let classification_targets = if config.task == Task::Classification {
            Some(binary_classification_targets(train_set)?.1)
        } else {
            None
        };

        for tree_index in 0..n_trees {
            // Stage-deterministic seed derived from the base seed and index.
            let stage_seed = mix_seed(base_seed, tree_index as u64);
            // Gradients/hessians of the loss at the current raw predictions.
            let (gradients, hessians) = match config.task {
                Task::Regression => squared_error_gradients_and_hessians(
                    raw_predictions.as_slice(),
                    regression_targets
                        .as_ref()
                        .expect("regression targets exist for regression boosting"),
                ),
                Task::Classification => logistic_gradients_and_hessians(
                    raw_predictions.as_slice(),
                    classification_targets
                        .as_ref()
                        .expect("classification targets exist for classification boosting"),
                ),
            };

            // Optionally bootstrap, then focus the sample on high-|gradient| rows.
            let base_rows = if bootstrap {
                sampler.sample(stage_seed)
            } else {
                (0..train_set.n_rows()).collect()
            };
            let sampled_rows = gradient_focus_sample(
                &base_rows,
                &gradients,
                &hessians,
                top_gradient_fraction,
                other_gradient_fraction,
                mix_seed(stage_seed, 0x6011_5A11),
            );
            let sampled_table = SampledTable::new(train_set, sampled_rows.row_indices);
            let mut stage_tree_options = tree_options;
            stage_tree_options.tree_options.random_seed = stage_seed;
            // Fit this stage's tree with the configured tree family.
            let stage_result = match config.tree_type {
                TreeType::Cart => train_cart_regressor_from_gradients_and_hessians_with_status(
                    &sampled_table,
                    &sampled_rows.gradients,
                    &sampled_rows.hessians,
                    parallelism,
                    stage_tree_options,
                ),
                TreeType::Randomized => {
                    train_randomized_regressor_from_gradients_and_hessians_with_status(
                        &sampled_table,
                        &sampled_rows.gradients,
                        &sampled_rows.hessians,
                        parallelism,
                        stage_tree_options,
                    )
                }
                TreeType::Oblivious => {
                    train_oblivious_regressor_from_gradients_and_hessians_with_status(
                        &sampled_table,
                        &sampled_rows.gradients,
                        &sampled_rows.hessians,
                        parallelism,
                        stage_tree_options,
                    )
                }
                _ => unreachable!("boosting tree type validated by training dispatch"),
            }
            .map_err(BoostingError::SecondOrderTree)?;

            // The stage tree's best root split landed on a canary column:
            // stop adding trees; already-accepted stages are kept.
            if stage_result.root_canary_selected {
                break;
            }

            // Accept the stage: update every row's raw prediction (over the
            // full training set, not just the sampled rows) by the shrunken
            // stage output.
            let stage_tree = stage_result.model;
            let stage_model = Model::DecisionTreeRegressor(stage_tree);
            let stage_predictions = stage_model.predict_table(train_set);
            for (raw_prediction, stage_prediction) in raw_predictions
                .iter_mut()
                .zip(stage_predictions.iter().copied())
            {
                *raw_prediction += learning_rate * stage_prediction;
            }
            tree_weights.push(learning_rate);
            trees.push(stage_model);
        }

        Ok(Self::new(
            config.task,
            config.tree_type,
            trees,
            tree_weights,
            base_score,
            learning_rate,
            bootstrap,
            top_gradient_fraction,
            other_gradient_fraction,
            max_features,
            config.seed,
            train_set.n_features(),
            feature_preprocessing,
            class_labels,
            train_set.canaries(),
        ))
    }

    /// Predicts one value per row: the raw score for regression, a class
    /// label for classification.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        match self.task {
            Task::Regression => self.predict_regression_table(table),
            Task::Classification => self.predict_classification_table(table),
        }
    }

    /// Predicts `[P(negative), P(positive)]` per row.
    ///
    /// # Errors
    /// Fails unless the model was trained for classification.
    pub fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        if self.task != Task::Classification {
            return Err(PredictError::ProbabilityPredictionRequiresClassification);
        }
        Ok(self
            .raw_scores(table)
            .into_iter()
            .map(|score| {
                let positive = sigmoid(score);
                vec![1.0 - positive, positive]
            })
            .collect())
    }

    /// The task this model was trained for.
    pub fn task(&self) -> Task {
        self.task
    }

    /// Boosting always trains its stage trees with the second-order criterion.
    pub fn criterion(&self) -> Criterion {
        Criterion::SecondOrder
    }

    /// The tree family used for every stage.
    pub fn tree_type(&self) -> TreeType {
        self.tree_type
    }

    /// The fitted stage trees, in boosting order.
    pub fn trees(&self) -> &[Model] {
        &self.trees
    }

    /// Per-stage weights, parallel to [`Self::trees`].
    pub fn tree_weights(&self) -> &[f64] {
        &self.tree_weights
    }

    /// The initial raw score every prediction starts from.
    pub fn base_score(&self) -> f64 {
        self.base_score
    }

    /// Number of input features seen at training time.
    pub fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-feature preprocessing captured at training time.
    pub fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// The two sorted class labels for classification models, `None` otherwise.
    pub fn class_labels(&self) -> Option<Vec<f64>> {
        self.class_labels.clone()
    }

    /// Summarizes the training configuration and outcome for serialization
    /// and inspection. Per-tree options are read back from the first stage tree.
    pub fn training_metadata(&self) -> TrainingMetadata {
        TrainingMetadata {
            algorithm: "gbm".to_string(),
            task: match self.task {
                Task::Regression => "regression".to_string(),
                Task::Classification => "classification".to_string(),
            },
            tree_type: match self.tree_type {
                TreeType::Cart => "cart".to_string(),
                TreeType::Randomized => "randomized".to_string(),
                TreeType::Oblivious => "oblivious".to_string(),
                _ => unreachable!("boosting only supports cart/randomized/oblivious"),
            },
            criterion: "second_order".to_string(),
            canaries: self.training_canaries,
            compute_oob: false,
            max_depth: self.trees.first().and_then(Model::max_depth),
            min_samples_split: self.trees.first().and_then(Model::min_samples_split),
            min_samples_leaf: self.trees.first().and_then(Model::min_samples_leaf),
            n_trees: Some(self.trees.len()),
            max_features: Some(self.max_features),
            seed: self.seed,
            oob_score: None,
            class_labels: self.class_labels.clone(),
            learning_rate: Some(self.learning_rate),
            bootstrap: Some(self.bootstrap),
            top_gradient_fraction: Some(self.top_gradient_fraction),
            other_gradient_fraction: Some(self.other_gradient_fraction),
        }
    }

    /// Raw ensemble score per row: `base_score + Σ weight_i * tree_i(row)`.
    fn raw_scores(&self, table: &dyn TableAccess) -> Vec<f64> {
        let mut scores = vec![self.base_score; table.n_rows()];
        for (tree, weight) in self.trees.iter().zip(self.tree_weights.iter().copied()) {
            let predictions = tree.predict_table(table);
            for (score, prediction) in scores.iter_mut().zip(predictions.iter().copied()) {
                *score += weight * prediction;
            }
        }
        scores
    }

    /// Regression predictions are the raw scores themselves.
    fn predict_regression_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        self.raw_scores(table)
    }

    /// Classification predictions: sigmoid(raw score) thresholded at 0.5,
    /// mapped back to the original (sorted) class labels.
    fn predict_classification_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let class_labels = self
            .class_labels
            .as_ref()
            .expect("classification boosting stores class labels");
        self.raw_scores(table)
            .into_iter()
            .map(|score| {
                if sigmoid(score) >= 0.5 {
                    class_labels[1]
                } else {
                    class_labels[0]
                }
            })
            .collect()
    }
}
417
/// One boosting stage's training sample: the chosen rows plus their
/// (possibly rescaled) gradients and hessians, kept in parallel order.
struct GradientFocusedSample {
    // Base-table row indices selected for this stage.
    row_indices: Vec<usize>,
    // Gradient per selected row; uniformly-sampled rows carry a scaled value.
    gradients: Vec<f64>,
    // Hessian per selected row; scaled the same way as the gradients.
    hessians: Vec<f64>,
}
423
424impl<'a> SampledTable<'a> {
425    fn new(base: &'a dyn TableAccess, row_indices: Vec<usize>) -> Self {
426        Self { base, row_indices }
427    }
428
429    fn resolve_row(&self, row_index: usize) -> usize {
430        self.row_indices[row_index]
431    }
432}
433
434impl TableAccess for SampledTable<'_> {
435    fn n_rows(&self) -> usize {
436        self.row_indices.len()
437    }
438
439    fn n_features(&self) -> usize {
440        self.base.n_features()
441    }
442
443    fn canaries(&self) -> usize {
444        self.base.canaries()
445    }
446
447    fn numeric_bin_cap(&self) -> usize {
448        self.base.numeric_bin_cap()
449    }
450
451    fn binned_feature_count(&self) -> usize {
452        self.base.binned_feature_count()
453    }
454
455    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
456        self.base
457            .feature_value(feature_index, self.resolve_row(row_index))
458    }
459
460    fn is_binary_feature(&self, index: usize) -> bool {
461        self.base.is_binary_feature(index)
462    }
463
464    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
465        self.base
466            .binned_value(feature_index, self.resolve_row(row_index))
467    }
468
469    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
470        self.base
471            .binned_boolean_value(feature_index, self.resolve_row(row_index))
472    }
473
474    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
475        self.base.binned_column_kind(index)
476    }
477
478    fn is_binary_binned_feature(&self, index: usize) -> bool {
479        self.base.is_binary_binned_feature(index)
480    }
481
482    fn target_value(&self, row_index: usize) -> f64 {
483        self.base.target_value(self.resolve_row(row_index))
484    }
485}
486
487fn validate_boosting_parameters(
488    train_set: &dyn TableAccess,
489    learning_rate: f64,
490    top_gradient_fraction: f64,
491    other_gradient_fraction: f64,
492) -> Result<(), BoostingError> {
493    if train_set.n_rows() == 0 {
494        return Err(BoostingError::InvalidLearningRate(learning_rate));
495    }
496    if !learning_rate.is_finite() || learning_rate <= 0.0 {
497        return Err(BoostingError::InvalidLearningRate(learning_rate));
498    }
499    if !top_gradient_fraction.is_finite()
500        || top_gradient_fraction <= 0.0
501        || top_gradient_fraction > 1.0
502    {
503        return Err(BoostingError::InvalidTopGradientFraction(
504            top_gradient_fraction,
505        ));
506    }
507    if !other_gradient_fraction.is_finite()
508        || !(0.0..1.0).contains(&other_gradient_fraction)
509        || top_gradient_fraction + other_gradient_fraction > 1.0
510    {
511        return Err(BoostingError::InvalidOtherGradientFraction(
512            other_gradient_fraction,
513        ));
514    }
515    Ok(())
516}
517
518fn finite_targets(train_set: &dyn TableAccess) -> Result<Vec<f64>, BoostingError> {
519    (0..train_set.n_rows())
520        .map(|row_index| {
521            let value = train_set.target_value(row_index);
522            if value.is_finite() {
523                Ok(value)
524            } else {
525                Err(BoostingError::InvalidTargetValue {
526                    row: row_index,
527                    value,
528                })
529            }
530        })
531        .collect()
532}
533
534fn binary_classification_targets(
535    train_set: &dyn TableAccess,
536) -> Result<(Vec<f64>, Vec<f64>), BoostingError> {
537    let mut labels = finite_targets(train_set)?;
538    labels.sort_by(|left, right| left.total_cmp(right));
539    labels.dedup_by(|left, right| left.total_cmp(right).is_eq());
540    if labels.len() != 2 {
541        return Err(BoostingError::UnsupportedClassificationClassCount(
542            labels.len(),
543        ));
544    }
545
546    let negative = labels[0];
547    let encoded = (0..train_set.n_rows())
548        .map(|row_index| {
549            if train_set
550                .target_value(row_index)
551                .total_cmp(&negative)
552                .is_eq()
553            {
554                0.0
555            } else {
556                1.0
557            }
558        })
559        .collect();
560    Ok((labels, encoded))
561}
562
/// Gradients and hessians of squared-error loss at the current predictions.
///
/// For `L = (prediction - target)^2 / 2` the gradient is
/// `prediction - target` and the hessian is identically 1.
fn squared_error_gradients_and_hessians(
    raw_predictions: &[f64],
    targets: &[f64],
) -> (Vec<f64>, Vec<f64>) {
    let mut gradients = Vec::with_capacity(targets.len());
    for (prediction, target) in raw_predictions.iter().zip(targets.iter()) {
        gradients.push(prediction - target);
    }
    let hessians = vec![1.0; targets.len()];
    (gradients, hessians)
}
576
577fn logistic_gradients_and_hessians(
578    raw_predictions: &[f64],
579    targets: &[f64],
580) -> (Vec<f64>, Vec<f64>) {
581    let mut gradients = Vec::with_capacity(targets.len());
582    let mut hessians = Vec::with_capacity(targets.len());
583    for (raw_prediction, target) in raw_predictions.iter().zip(targets.iter()) {
584        let probability = sigmoid(*raw_prediction);
585        gradients.push(probability - target);
586        hessians.push((probability * (1.0 - probability)).max(1e-12));
587    }
588    (gradients, hessians)
589}
590
/// Numerically stable logistic function `1 / (1 + e^{-value})`.
///
/// Only ever exponentiates a non-positive argument, so large `|value|`
/// saturates to 0 or 1 instead of overflowing.
fn sigmoid(value: f64) -> f64 {
    let exp_neg_abs = (-value.abs()).exp();
    if value >= 0.0 {
        1.0 / (1.0 + exp_neg_abs)
    } else {
        exp_neg_abs / (1.0 + exp_neg_abs)
    }
}
600
/// Builds a stage training sample focused on high-gradient rows
/// (GOSS-style, as in LightGBM's gradient-based one-side sampling).
///
/// Keeps the `top_gradient_fraction` of `base_rows` with the largest
/// absolute gradients, then uniformly samples `other_gradient_fraction` of
/// the remainder; the uniformly-sampled rows have their gradients AND
/// hessians scaled by `(1 - top_fraction) / other_fraction` to keep the
/// sample's loss statistics approximately unbiased.
fn gradient_focus_sample(
    base_rows: &[usize],
    gradients: &[f64],
    hessians: &[f64],
    top_gradient_fraction: f64,
    other_gradient_fraction: f64,
    seed: u64,
) -> GradientFocusedSample {
    // Rank candidate rows by |gradient|, descending; ties broken by row index
    // so the ordering (and hence the sample) is deterministic.
    let mut ranked = base_rows
        .iter()
        .copied()
        .map(|row_index| (row_index, gradients[row_index].abs()))
        .collect::<Vec<_>>();
    ranked.sort_by(|(left_row, left_abs), (right_row, right_abs)| {
        right_abs
            .total_cmp(left_abs)
            .then_with(|| left_row.cmp(right_row))
    });

    // At least one top row is always kept (clamp lower bound of 1).
    let top_count = ((ranked.len() as f64) * top_gradient_fraction)
        .ceil()
        .clamp(1.0, ranked.len() as f64) as usize;
    let mut row_indices = Vec::with_capacity(ranked.len());
    let mut sampled_gradients = Vec::with_capacity(ranked.len());
    let mut sampled_hessians = Vec::with_capacity(ranked.len());

    // Top rows keep their gradients/hessians unscaled.
    for (row_index, _) in ranked.iter().take(top_count) {
        row_indices.push(*row_index);
        sampled_gradients.push(gradients[*row_index]);
        sampled_hessians.push(hessians[*row_index]);
    }

    if top_count < ranked.len() && other_gradient_fraction > 0.0 {
        let remaining = ranked[top_count..]
            .iter()
            .map(|(row_index, _)| *row_index)
            .collect::<Vec<_>>();
        let other_count = ((remaining.len() as f64) * other_gradient_fraction)
            .ceil()
            .min(remaining.len() as f64) as usize;
        if other_count > 0 {
            // Seeded shuffle keeps the uniform draw reproducible per stage.
            let mut remaining = remaining;
            let mut rng = StdRng::seed_from_u64(seed);
            remaining.shuffle(&mut rng);
            // Compensate for under-representing low-gradient rows.
            let gradient_scale = (1.0 - top_gradient_fraction) / other_gradient_fraction;
            for row_index in remaining.into_iter().take(other_count) {
                row_indices.push(row_index);
                sampled_gradients.push(gradients[row_index] * gradient_scale);
                sampled_hessians.push(hessians[row_index] * gradient_scale);
            }
        }
    }

    GradientFocusedSample {
        row_indices,
        gradients: sampled_gradients,
        hessians: sampled_hessians,
    }
}
660
/// Derives a new deterministic seed from `base_seed` and `value`.
///
/// Scrambles `value` with a multiply by the 64-bit golden-ratio constant
/// plus a rotation, then XORs it into the base seed, so nearby stage
/// indices produce well-separated seeds.
fn mix_seed(base_seed: u64, value: u64) -> u64 {
    const GOLDEN_RATIO: u64 = 0x9E37_79B9_7F4A_7C15;
    let scrambled = value.wrapping_mul(GOLDEN_RATIO).rotate_left(17);
    base_seed ^ scrambled
}
664
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MaxFeatures, TrainAlgorithm, TrainConfig};
    use forestfire_data::{BinnedColumnKind, TableAccess};
    use forestfire_data::{DenseTable, NumericBins};

    // A monotone 1-feature signal should produce monotone boosted predictions.
    #[test]
    fn regression_boosting_fits_simple_signal() {
        let table = DenseTable::with_options(
            vec![
                vec![0.0],
                vec![0.0],
                vec![1.0],
                vec![1.0],
                vec![2.0],
                vec![2.0],
            ],
            vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
            0,
            NumericBins::fixed(8).unwrap(),
        )
        .unwrap();

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Regression,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(20),
                learning_rate: Some(0.2),
                max_depth: Some(2),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        // Predictions must increase along the increasing target.
        let predictions = model.predict_table(&table);
        assert!(predictions[0] < predictions[2]);
        assert!(predictions[2] < predictions[4]);
    }

    // A separable binary problem should yield probabilities on the correct
    // side of 0.5 for the extreme rows.
    #[test]
    fn classification_boosting_produces_binary_probabilities() {
        let table = DenseTable::with_options(
            vec![vec![0.0], vec![0.1], vec![0.9], vec![1.0]],
            vec![0.0, 0.0, 1.0, 1.0],
            0,
            NumericBins::fixed(8).unwrap(),
        )
        .unwrap();

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Classification,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(25),
                learning_rate: Some(0.2),
                max_depth: Some(2),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        let probabilities = model.predict_proba_table(&table).unwrap();
        assert_eq!(probabilities.len(), 4);
        // Index 1 of each row is P(positive class).
        assert!(probabilities[0][1] < 0.5);
        assert!(probabilities[3][1] > 0.5);
    }

    // Three distinct labels must be rejected with the class-count error.
    #[test]
    fn classification_boosting_rejects_multiclass_targets() {
        let table =
            DenseTable::new(vec![vec![0.0], vec![1.0], vec![2.0]], vec![0.0, 1.0, 2.0]).unwrap();

        let error = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Classification,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap_err();

        assert!(matches!(
            error,
            BoostingError::UnsupportedClassificationClassCount(3)
        ));
    }

    // A hand-rolled table whose ONLY informative binned column is a canary:
    // column 0 is a constant real feature, column 1 is a canary that
    // perfectly separates the targets, forcing every root split onto it.
    struct RootCanaryTable;

    impl TableAccess for RootCanaryTable {
        fn n_rows(&self) -> usize {
            4
        }

        fn n_features(&self) -> usize {
            1
        }

        fn canaries(&self) -> usize {
            1
        }

        fn numeric_bin_cap(&self) -> usize {
            2
        }

        fn binned_feature_count(&self) -> usize {
            2
        }

        fn feature_value(&self, _feature_index: usize, _row_index: usize) -> f64 {
            0.0
        }

        fn is_binary_feature(&self, _index: usize) -> bool {
            true
        }

        fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
            match feature_index {
                0 => 0,
                1 => u16::from(row_index >= 2),
                _ => unreachable!(),
            }
        }

        fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
            Some(match feature_index {
                0 => false,
                1 => row_index >= 2,
                _ => unreachable!(),
            })
        }

        fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
            match index {
                0 => BinnedColumnKind::Real { source_index: 0 },
                1 => BinnedColumnKind::Canary {
                    source_index: 0,
                    copy_index: 0,
                },
                _ => unreachable!(),
            }
        }

        fn is_binary_binned_feature(&self, _index: usize) -> bool {
            true
        }

        fn target_value(&self, row_index: usize) -> f64 {
            [0.0, 0.0, 1.0, 1.0][row_index]
        }
    }

    // The very first stage tree roots on the canary, so boosting must stop
    // with zero trees while still producing finite base-score predictions.
    #[test]
    fn boosting_stops_when_root_split_is_a_canary() {
        let table = RootCanaryTable;

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Regression,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(10),
                max_features: MaxFeatures::All,
                learning_rate: Some(0.1),
                top_gradient_fraction: Some(1.0),
                other_gradient_fraction: Some(0.0),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        assert_eq!(model.trees().len(), 0);
        assert_eq!(model.training_metadata().n_trees, Some(0));
        assert!(
            model
                .predict_table(&table)
                .iter()
                .all(|value| value.is_finite())
        );
    }
}