// forestfire_core/boosting.rs
//! Gradient boosting implementation.
//!
//! The boosting path is intentionally "LightGBM-like" rather than a direct
//! clone. It uses second-order trees, shrinkage, and gradient-focused sampling,
//! but it keeps ForestFire's canary mechanism active so a stage can stop before
//! growing a tree whose best root split is indistinguishable from noise.
8use crate::bootstrap::BootstrapSampler;
9use crate::ir::TrainingMetadata;
10use crate::tree::second_order::{
11    SecondOrderRegressionTreeError, SecondOrderRegressionTreeOptions,
12    train_cart_regressor_from_gradients_and_hessians_with_status,
13    train_oblivious_regressor_from_gradients_and_hessians_with_status,
14    train_randomized_regressor_from_gradients_and_hessians_with_status,
15};
16use crate::tree::shared::mix_seed;
17use crate::{
18    Criterion, FeaturePreprocessing, Model, Parallelism, PredictError, Task, TrainConfig, TreeType,
19    capture_feature_preprocessing,
20};
21use forestfire_data::TableAccess;
22use rand::SeedableRng;
23use rand::rngs::StdRng;
24use rand::seq::SliceRandom;
25
/// Stage-wise gradient-boosted tree ensemble.
///
/// The ensemble keeps explicit tree weights and a base score so the semantic
/// model can reconstruct raw margins exactly for both prediction and IR export.
#[derive(Debug, Clone)]
pub struct GradientBoostedTrees {
    /// Regression or (binary) classification.
    task: Task,
    /// Learner used for every stage tree (cart / randomized / oblivious).
    tree_type: TreeType,
    /// Stage trees in the order they were fit.
    trees: Vec<Model>,
    /// Weight applied to each stage tree's raw prediction; one entry per tree.
    tree_weights: Vec<f64>,
    /// Constant margin added before any tree contribution.
    base_score: f64,
    /// Shrinkage applied when a stage tree was accepted (recorded for metadata).
    learning_rate: f64,
    /// Whether each stage drew a bootstrap sample before gradient sampling.
    bootstrap: bool,
    /// Fraction of rows kept deterministically by |gradient| (GOSS-style).
    top_gradient_fraction: f64,
    /// Fraction of the remaining rows sampled and up-weighted (GOSS-style).
    other_gradient_fraction: f64,
    /// Resolved per-split feature budget used during training.
    max_features: usize,
    /// User-provided seed, if any (the effective base seed has a default).
    seed: Option<u64>,
    /// Feature count of the training table, for input validation on predict.
    num_features: usize,
    /// Captured per-feature preprocessing, for IR export.
    feature_preprocessing: Vec<FeaturePreprocessing>,
    /// Sorted distinct class labels for classification; `None` for regression.
    class_labels: Option<Vec<f64>>,
    /// Number of canary columns present during training (metadata only).
    training_canaries: usize,
}
48
/// Errors surfaced by gradient-boosting training.
#[derive(Debug)]
pub enum BoostingError {
    /// A target cell was NaN or infinite.
    InvalidTargetValue { row: usize, value: f64 },
    /// Classification boosting handles exactly two distinct labels.
    UnsupportedClassificationClassCount(usize),
    /// `learning_rate` was non-finite or not strictly positive.
    // NOTE(review): also returned for an empty training table, which reads
    // oddly — see `validate_boosting_parameters`.
    InvalidLearningRate(f64),
    /// `top_gradient_fraction` was outside (0, 1].
    InvalidTopGradientFraction(f64),
    /// `other_gradient_fraction` was outside [0, 1), or the two fractions
    /// summed above 1.
    InvalidOtherGradientFraction(f64),
    /// A stage tree failed to train; wraps the tree-level error.
    SecondOrderTree(SecondOrderRegressionTreeError),
}
58
59impl std::fmt::Display for BoostingError {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        match self {
62            BoostingError::InvalidTargetValue { row, value } => write!(
63                f,
64                "Boosting targets must be finite values. Found {} at row {}.",
65                value, row
66            ),
67            BoostingError::UnsupportedClassificationClassCount(count) => write!(
68                f,
69                "Gradient boosting currently supports binary classification only. Found {} classes.",
70                count
71            ),
72            BoostingError::InvalidLearningRate(value) => write!(
73                f,
74                "learning_rate must be finite and greater than 0. Found {}.",
75                value
76            ),
77            BoostingError::InvalidTopGradientFraction(value) => write!(
78                f,
79                "top_gradient_fraction must be in the interval (0, 1]. Found {}.",
80                value
81            ),
82            BoostingError::InvalidOtherGradientFraction(value) => write!(
83                f,
84                "other_gradient_fraction must be in the interval [0, 1), and top_gradient_fraction + other_gradient_fraction must be at most 1. Found {}.",
85                value
86            ),
87            BoostingError::SecondOrderTree(err) => err.fmt(f),
88        }
89    }
90}
91
92impl std::error::Error for BoostingError {}
93
/// Row-subset view over a base table, used to train a stage tree on the
/// gradient-focused sample without copying any feature data.
struct SampledTable<'a> {
    base: &'a dyn TableAccess,
    /// View row `i` reads base row `row_indices[i]`; duplicates are allowed
    /// (bootstrap samples may repeat rows).
    row_indices: Vec<usize>,
}
98
99impl GradientBoostedTrees {
100    #[allow(clippy::too_many_arguments)]
101    pub fn new(
102        task: Task,
103        tree_type: TreeType,
104        trees: Vec<Model>,
105        tree_weights: Vec<f64>,
106        base_score: f64,
107        learning_rate: f64,
108        bootstrap: bool,
109        top_gradient_fraction: f64,
110        other_gradient_fraction: f64,
111        max_features: usize,
112        seed: Option<u64>,
113        num_features: usize,
114        feature_preprocessing: Vec<FeaturePreprocessing>,
115        class_labels: Option<Vec<f64>>,
116        training_canaries: usize,
117    ) -> Self {
118        Self {
119            task,
120            tree_type,
121            trees,
122            tree_weights,
123            base_score,
124            learning_rate,
125            bootstrap,
126            top_gradient_fraction,
127            other_gradient_fraction,
128            max_features,
129            seed,
130            num_features,
131            feature_preprocessing,
132            class_labels,
133            training_canaries,
134        }
135    }
136
137    pub(crate) fn train(
138        train_set: &dyn TableAccess,
139        config: TrainConfig,
140        parallelism: Parallelism,
141    ) -> Result<Self, BoostingError> {
142        let n_trees = config.n_trees.unwrap_or(100);
143        let learning_rate = config.learning_rate.unwrap_or(0.1);
144        let bootstrap = config.bootstrap;
145        let top_gradient_fraction = config.top_gradient_fraction.unwrap_or(0.2);
146        let other_gradient_fraction = config.other_gradient_fraction.unwrap_or(0.1);
147        validate_boosting_parameters(
148            train_set,
149            learning_rate,
150            top_gradient_fraction,
151            other_gradient_fraction,
152        )?;
153
154        let max_features = config
155            .max_features
156            .resolve(config.task, train_set.binned_feature_count());
157        let base_seed = config.seed.unwrap_or(0xB005_7EED_u64);
158        let tree_options = crate::RegressionTreeOptions {
159            max_depth: config.max_depth.unwrap_or(8),
160            min_samples_split: config.min_samples_split.unwrap_or(2),
161            min_samples_leaf: config.min_samples_leaf.unwrap_or(1),
162            max_features: Some(max_features),
163            random_seed: 0,
164        };
165        let tree_options = SecondOrderRegressionTreeOptions {
166            tree_options,
167            l2_regularization: 1.0,
168            min_sum_hessian_in_leaf: 1e-3,
169            min_gain_to_split: 0.0,
170        };
171        let feature_preprocessing = capture_feature_preprocessing(train_set);
172        let sampler = BootstrapSampler::new(train_set.n_rows());
173
174        let (mut raw_predictions, class_labels, base_score) = match config.task {
175            Task::Regression => {
176                let targets = finite_targets(train_set)?;
177                let base_score = targets.iter().sum::<f64>() / targets.len() as f64;
178                (vec![base_score; train_set.n_rows()], None, base_score)
179            }
180            Task::Classification => {
181                let (labels, encoded_targets) = binary_classification_targets(train_set)?;
182                let positive_rate = (encoded_targets.iter().sum::<f64>()
183                    / encoded_targets.len() as f64)
184                    .clamp(1e-6, 1.0 - 1e-6);
185                let base_score = (positive_rate / (1.0 - positive_rate)).ln();
186                (
187                    vec![base_score; train_set.n_rows()],
188                    Some(labels),
189                    base_score,
190                )
191            }
192        };
193
194        let mut trees = Vec::with_capacity(n_trees);
195        let mut tree_weights = Vec::with_capacity(n_trees);
196        let regression_targets = if config.task == Task::Regression {
197            Some(finite_targets(train_set)?)
198        } else {
199            None
200        };
201        let classification_targets = if config.task == Task::Classification {
202            Some(binary_classification_targets(train_set)?.1)
203        } else {
204            None
205        };
206
207        for tree_index in 0..n_trees {
208            let stage_seed = mix_seed(base_seed, tree_index as u64);
209            // Gradients/hessians are recomputed from the current ensemble margin
210            // at every stage, which keeps the tree learner focused on residual
211            // structure that earlier trees did not explain.
212            let (gradients, hessians) = match config.task {
213                Task::Regression => squared_error_gradients_and_hessians(
214                    raw_predictions.as_slice(),
215                    regression_targets
216                        .as_ref()
217                        .expect("regression targets exist for regression boosting"),
218                ),
219                Task::Classification => logistic_gradients_and_hessians(
220                    raw_predictions.as_slice(),
221                    classification_targets
222                        .as_ref()
223                        .expect("classification targets exist for classification boosting"),
224                ),
225            };
226
227            let base_rows = if bootstrap {
228                sampler.sample(stage_seed)
229            } else {
230                (0..train_set.n_rows()).collect()
231            };
232            // GOSS-style sampling keeps the largest gradients deterministically
233            // and samples part of the remainder. This biases work toward the rows
234            // where the current ensemble is most wrong.
235            let sampled_rows = gradient_focus_sample(
236                &base_rows,
237                &gradients,
238                &hessians,
239                top_gradient_fraction,
240                other_gradient_fraction,
241                mix_seed(stage_seed, 0x6011_5A11),
242            );
243            let sampled_table = SampledTable::new(train_set, sampled_rows.row_indices);
244            let mut stage_tree_options = tree_options;
245            stage_tree_options.tree_options.random_seed = stage_seed;
246            let stage_result = match config.tree_type {
247                TreeType::Cart => train_cart_regressor_from_gradients_and_hessians_with_status(
248                    &sampled_table,
249                    &sampled_rows.gradients,
250                    &sampled_rows.hessians,
251                    parallelism,
252                    stage_tree_options,
253                ),
254                TreeType::Randomized => {
255                    train_randomized_regressor_from_gradients_and_hessians_with_status(
256                        &sampled_table,
257                        &sampled_rows.gradients,
258                        &sampled_rows.hessians,
259                        parallelism,
260                        stage_tree_options,
261                    )
262                }
263                TreeType::Oblivious => {
264                    train_oblivious_regressor_from_gradients_and_hessians_with_status(
265                        &sampled_table,
266                        &sampled_rows.gradients,
267                        &sampled_rows.hessians,
268                        parallelism,
269                        stage_tree_options,
270                    )
271                }
272                _ => unreachable!("boosting tree type validated by training dispatch"),
273            }
274            .map_err(BoostingError::SecondOrderTree)?;
275
276            // A canary root win means the stage could not find a real feature
277            // stronger than shuffled noise, so boosting stops early.
278            if stage_result.root_canary_selected {
279                break;
280            }
281
282            let stage_tree = stage_result.model;
283            let stage_model = Model::DecisionTreeRegressor(stage_tree);
284            let stage_predictions = stage_model.predict_table(train_set);
285            for (raw_prediction, stage_prediction) in raw_predictions
286                .iter_mut()
287                .zip(stage_predictions.iter().copied())
288            {
289                // Trees are fit on raw margins; shrinkage is applied only when
290                // updating the ensemble prediction.
291                *raw_prediction += learning_rate * stage_prediction;
292            }
293            tree_weights.push(learning_rate);
294            trees.push(stage_model);
295        }
296
297        Ok(Self::new(
298            config.task,
299            config.tree_type,
300            trees,
301            tree_weights,
302            base_score,
303            learning_rate,
304            bootstrap,
305            top_gradient_fraction,
306            other_gradient_fraction,
307            max_features,
308            config.seed,
309            train_set.n_features(),
310            feature_preprocessing,
311            class_labels,
312            train_set.canaries(),
313        ))
314    }
315
316    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
317        match self.task {
318            Task::Regression => self.predict_regression_table(table),
319            Task::Classification => self.predict_classification_table(table),
320        }
321    }
322
323    pub fn predict_proba_table(
324        &self,
325        table: &dyn TableAccess,
326    ) -> Result<Vec<Vec<f64>>, PredictError> {
327        if self.task != Task::Classification {
328            return Err(PredictError::ProbabilityPredictionRequiresClassification);
329        }
330        Ok(self
331            .raw_scores(table)
332            .into_iter()
333            .map(|score| {
334                let positive = sigmoid(score);
335                vec![1.0 - positive, positive]
336            })
337            .collect())
338    }
339
340    pub fn task(&self) -> Task {
341        self.task
342    }
343
344    pub fn criterion(&self) -> Criterion {
345        Criterion::SecondOrder
346    }
347
348    pub fn tree_type(&self) -> TreeType {
349        self.tree_type
350    }
351
352    pub fn trees(&self) -> &[Model] {
353        &self.trees
354    }
355
356    pub fn tree_weights(&self) -> &[f64] {
357        &self.tree_weights
358    }
359
360    pub fn base_score(&self) -> f64 {
361        self.base_score
362    }
363
364    pub fn num_features(&self) -> usize {
365        self.num_features
366    }
367
368    pub fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
369        &self.feature_preprocessing
370    }
371
372    pub fn class_labels(&self) -> Option<Vec<f64>> {
373        self.class_labels.clone()
374    }
375
376    pub fn training_metadata(&self) -> TrainingMetadata {
377        TrainingMetadata {
378            algorithm: "gbm".to_string(),
379            task: match self.task {
380                Task::Regression => "regression".to_string(),
381                Task::Classification => "classification".to_string(),
382            },
383            tree_type: match self.tree_type {
384                TreeType::Cart => "cart".to_string(),
385                TreeType::Randomized => "randomized".to_string(),
386                TreeType::Oblivious => "oblivious".to_string(),
387                _ => unreachable!("boosting only supports cart/randomized/oblivious"),
388            },
389            criterion: "second_order".to_string(),
390            canaries: self.training_canaries,
391            compute_oob: false,
392            max_depth: self.trees.first().and_then(Model::max_depth),
393            min_samples_split: self.trees.first().and_then(Model::min_samples_split),
394            min_samples_leaf: self.trees.first().and_then(Model::min_samples_leaf),
395            n_trees: Some(self.trees.len()),
396            max_features: Some(self.max_features),
397            seed: self.seed,
398            oob_score: None,
399            class_labels: self.class_labels.clone(),
400            learning_rate: Some(self.learning_rate),
401            bootstrap: Some(self.bootstrap),
402            top_gradient_fraction: Some(self.top_gradient_fraction),
403            other_gradient_fraction: Some(self.other_gradient_fraction),
404        }
405    }
406
407    fn raw_scores(&self, table: &dyn TableAccess) -> Vec<f64> {
408        let mut scores = vec![self.base_score; table.n_rows()];
409        for (tree, weight) in self.trees.iter().zip(self.tree_weights.iter().copied()) {
410            let predictions = tree.predict_table(table);
411            for (score, prediction) in scores.iter_mut().zip(predictions.iter().copied()) {
412                *score += weight * prediction;
413            }
414        }
415        scores
416    }
417
418    fn predict_regression_table(&self, table: &dyn TableAccess) -> Vec<f64> {
419        self.raw_scores(table)
420    }
421
422    fn predict_classification_table(&self, table: &dyn TableAccess) -> Vec<f64> {
423        let class_labels = self
424            .class_labels
425            .as_ref()
426            .expect("classification boosting stores class labels");
427        self.raw_scores(table)
428            .into_iter()
429            .map(|score| {
430                if sigmoid(score) >= 0.5 {
431                    class_labels[1]
432                } else {
433                    class_labels[0]
434                }
435            })
436            .collect()
437    }
438}
439
/// Output of `gradient_focus_sample`: the selected rows plus their (possibly
/// rescaled) gradients and hessians, all aligned by position.
struct GradientFocusedSample {
    row_indices: Vec<usize>,
    gradients: Vec<f64>,
    hessians: Vec<f64>,
}
445
446impl<'a> SampledTable<'a> {
447    fn new(base: &'a dyn TableAccess, row_indices: Vec<usize>) -> Self {
448        Self { base, row_indices }
449    }
450
451    fn resolve_row(&self, row_index: usize) -> usize {
452        self.row_indices[row_index]
453    }
454}
455
456impl TableAccess for SampledTable<'_> {
457    fn n_rows(&self) -> usize {
458        self.row_indices.len()
459    }
460
461    fn n_features(&self) -> usize {
462        self.base.n_features()
463    }
464
465    fn canaries(&self) -> usize {
466        self.base.canaries()
467    }
468
469    fn numeric_bin_cap(&self) -> usize {
470        self.base.numeric_bin_cap()
471    }
472
473    fn binned_feature_count(&self) -> usize {
474        self.base.binned_feature_count()
475    }
476
477    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
478        self.base
479            .feature_value(feature_index, self.resolve_row(row_index))
480    }
481
482    fn is_binary_feature(&self, index: usize) -> bool {
483        self.base.is_binary_feature(index)
484    }
485
486    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
487        self.base
488            .binned_value(feature_index, self.resolve_row(row_index))
489    }
490
491    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
492        self.base
493            .binned_boolean_value(feature_index, self.resolve_row(row_index))
494    }
495
496    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
497        self.base.binned_column_kind(index)
498    }
499
500    fn is_binary_binned_feature(&self, index: usize) -> bool {
501        self.base.is_binary_binned_feature(index)
502    }
503
504    fn target_value(&self, row_index: usize) -> f64 {
505        self.base.target_value(self.resolve_row(row_index))
506    }
507}
508
509fn validate_boosting_parameters(
510    train_set: &dyn TableAccess,
511    learning_rate: f64,
512    top_gradient_fraction: f64,
513    other_gradient_fraction: f64,
514) -> Result<(), BoostingError> {
515    if train_set.n_rows() == 0 {
516        return Err(BoostingError::InvalidLearningRate(learning_rate));
517    }
518    if !learning_rate.is_finite() || learning_rate <= 0.0 {
519        return Err(BoostingError::InvalidLearningRate(learning_rate));
520    }
521    if !top_gradient_fraction.is_finite()
522        || top_gradient_fraction <= 0.0
523        || top_gradient_fraction > 1.0
524    {
525        return Err(BoostingError::InvalidTopGradientFraction(
526            top_gradient_fraction,
527        ));
528    }
529    if !other_gradient_fraction.is_finite()
530        || !(0.0..1.0).contains(&other_gradient_fraction)
531        || top_gradient_fraction + other_gradient_fraction > 1.0
532    {
533        return Err(BoostingError::InvalidOtherGradientFraction(
534            other_gradient_fraction,
535        ));
536    }
537    Ok(())
538}
539
540fn finite_targets(train_set: &dyn TableAccess) -> Result<Vec<f64>, BoostingError> {
541    (0..train_set.n_rows())
542        .map(|row_index| {
543            let value = train_set.target_value(row_index);
544            if value.is_finite() {
545                Ok(value)
546            } else {
547                Err(BoostingError::InvalidTargetValue {
548                    row: row_index,
549                    value,
550                })
551            }
552        })
553        .collect()
554}
555
556fn binary_classification_targets(
557    train_set: &dyn TableAccess,
558) -> Result<(Vec<f64>, Vec<f64>), BoostingError> {
559    let mut labels = finite_targets(train_set)?;
560    labels.sort_by(|left, right| left.total_cmp(right));
561    labels.dedup_by(|left, right| left.total_cmp(right).is_eq());
562    if labels.len() != 2 {
563        return Err(BoostingError::UnsupportedClassificationClassCount(
564            labels.len(),
565        ));
566    }
567
568    let negative = labels[0];
569    let encoded = (0..train_set.n_rows())
570        .map(|row_index| {
571            if train_set
572                .target_value(row_index)
573                .total_cmp(&negative)
574                .is_eq()
575            {
576                0.0
577            } else {
578                1.0
579            }
580        })
581        .collect();
582    Ok((labels, encoded))
583}
584
/// Gradients and hessians of 0.5 * (prediction - target)^2 per row:
/// gradient is the residual `prediction - target`, hessian is constant 1.
fn squared_error_gradients_and_hessians(
    raw_predictions: &[f64],
    targets: &[f64],
) -> (Vec<f64>, Vec<f64>) {
    let mut gradients = Vec::with_capacity(targets.len());
    for (&prediction, &target) in raw_predictions.iter().zip(targets) {
        gradients.push(prediction - target);
    }
    (gradients, vec![1.0; targets.len()])
}
598
599fn logistic_gradients_and_hessians(
600    raw_predictions: &[f64],
601    targets: &[f64],
602) -> (Vec<f64>, Vec<f64>) {
603    let mut gradients = Vec::with_capacity(targets.len());
604    let mut hessians = Vec::with_capacity(targets.len());
605    for (raw_prediction, target) in raw_predictions.iter().zip(targets.iter()) {
606        let probability = sigmoid(*raw_prediction);
607        gradients.push(probability - target);
608        hessians.push((probability * (1.0 - probability)).max(1e-12));
609    }
610    (gradients, hessians)
611}
612
/// Numerically stable logistic function.
///
/// Only `exp` of a non-positive argument is ever taken, so the intermediate
/// never overflows for large |value|.
fn sigmoid(value: f64) -> f64 {
    let decay = (-value.abs()).exp();
    if value >= 0.0 {
        1.0 / (1.0 + decay)
    } else {
        decay / (1.0 + decay)
    }
}
622
/// GOSS-style row sampling: keep the rows with the largest absolute gradients
/// deterministically, then sample a fraction of the remainder and up-weight it
/// so the retained gradient/hessian mass compensates for the dropped rows.
fn gradient_focus_sample(
    base_rows: &[usize],
    gradients: &[f64],
    hessians: &[f64],
    top_gradient_fraction: f64,
    other_gradient_fraction: f64,
    seed: u64,
) -> GradientFocusedSample {
    // Rank candidate rows by |gradient| descending; ties break on row index so
    // the ordering (and therefore the sample) is fully deterministic.
    let mut ranked = base_rows
        .iter()
        .copied()
        .map(|row_index| (row_index, gradients[row_index].abs()))
        .collect::<Vec<_>>();
    ranked.sort_by(|(left_row, left_abs), (right_row, right_abs)| {
        right_abs
            .total_cmp(left_abs)
            .then_with(|| left_row.cmp(right_row))
    });

    // Ceil plus the clamp floor of 1.0 guarantees at least one row survives,
    // even for tiny fractions.
    let top_count = ((ranked.len() as f64) * top_gradient_fraction)
        .ceil()
        .clamp(1.0, ranked.len() as f64) as usize;
    let mut row_indices = Vec::with_capacity(ranked.len());
    let mut sampled_gradients = Vec::with_capacity(ranked.len());
    let mut sampled_hessians = Vec::with_capacity(ranked.len());

    // Top segment: kept verbatim with unscaled gradients/hessians.
    for (row_index, _) in ranked.iter().take(top_count) {
        row_indices.push(*row_index);
        sampled_gradients.push(gradients[*row_index]);
        sampled_hessians.push(hessians[*row_index]);
    }

    if top_count < ranked.len() && other_gradient_fraction > 0.0 {
        let remaining = ranked[top_count..]
            .iter()
            .map(|(row_index, _)| *row_index)
            .collect::<Vec<_>>();
        let other_count = ((remaining.len() as f64) * other_gradient_fraction)
            .ceil()
            .min(remaining.len() as f64) as usize;
        if other_count > 0 {
            let mut remaining = remaining;
            // Seeded shuffle keeps the "other" sample reproducible per stage;
            // the caller mixes the stage seed with a salt before passing it in.
            let mut rng = StdRng::seed_from_u64(seed);
            remaining.shuffle(&mut rng);
            // LightGBM-style GOSS correction: sampled small-gradient rows are
            // scaled by (1 - a) / b to stand in for the rows that were dropped.
            let gradient_scale = (1.0 - top_gradient_fraction) / other_gradient_fraction;
            for row_index in remaining.into_iter().take(other_count) {
                row_indices.push(row_index);
                sampled_gradients.push(gradients[row_index] * gradient_scale);
                sampled_hessians.push(hessians[row_index] * gradient_scale);
            }
        }
    }

    GradientFocusedSample {
        row_indices,
        gradients: sampled_gradients,
        hessians: sampled_hessians,
    }
}
682
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MaxFeatures, TrainAlgorithm, TrainConfig};
    use forestfire_data::{BinnedColumnKind, TableAccess};
    use forestfire_data::{DenseTable, NumericBins};

    #[test]
    fn regression_boosting_fits_simple_signal() {
        // Single monotone feature: boosted predictions should preserve the
        // ordering of the targets.
        let table = DenseTable::with_options(
            vec![
                vec![0.0],
                vec![0.0],
                vec![1.0],
                vec![1.0],
                vec![2.0],
                vec![2.0],
            ],
            vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
            0,
            NumericBins::fixed(8).unwrap(),
        )
        .unwrap();

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Regression,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(20),
                learning_rate: Some(0.2),
                max_depth: Some(2),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        let predictions = model.predict_table(&table);
        assert!(predictions[0] < predictions[2]);
        assert!(predictions[2] < predictions[4]);
    }

    #[test]
    fn classification_boosting_produces_binary_probabilities() {
        let table = DenseTable::with_options(
            vec![vec![0.0], vec![0.1], vec![0.9], vec![1.0]],
            vec![0.0, 0.0, 1.0, 1.0],
            0,
            NumericBins::fixed(8).unwrap(),
        )
        .unwrap();

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Classification,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(25),
                learning_rate: Some(0.2),
                max_depth: Some(2),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        // Each row yields [negative, positive]; separable data should place
        // the two extremes on opposite sides of 0.5.
        let probabilities = model.predict_proba_table(&table).unwrap();
        assert_eq!(probabilities.len(), 4);
        assert!(probabilities[0][1] < 0.5);
        assert!(probabilities[3][1] > 0.5);
    }

    #[test]
    fn classification_boosting_rejects_multiclass_targets() {
        // Three distinct labels must be rejected up front.
        let table =
            DenseTable::new(vec![vec![0.0], vec![1.0], vec![2.0]], vec![0.0, 1.0, 2.0]).unwrap();

        let error = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Classification,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap_err();

        assert!(matches!(
            error,
            BoostingError::UnsupportedClassificationClassCount(3)
        ));
    }

    /// Table where the only informative binned column is a canary: the real
    /// feature (binned index 0) is constant while the canary (binned index 1)
    /// perfectly separates the targets, so the best root split is always the
    /// canary and boosting must stop before keeping any tree.
    struct RootCanaryTable;

    impl TableAccess for RootCanaryTable {
        fn n_rows(&self) -> usize {
            4
        }

        fn n_features(&self) -> usize {
            1
        }

        fn canaries(&self) -> usize {
            1
        }

        fn numeric_bin_cap(&self) -> usize {
            2
        }

        fn binned_feature_count(&self) -> usize {
            2
        }

        // The real feature is constant everywhere.
        fn feature_value(&self, _feature_index: usize, _row_index: usize) -> f64 {
            0.0
        }

        fn is_binary_feature(&self, _index: usize) -> bool {
            true
        }

        fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
            match feature_index {
                0 => 0,
                // Canary column splits rows {0,1} from {2,3} — exactly the
                // target boundary.
                1 => u16::from(row_index >= 2),
                _ => unreachable!(),
            }
        }

        fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
            Some(match feature_index {
                0 => false,
                1 => row_index >= 2,
                _ => unreachable!(),
            })
        }

        fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
            match index {
                0 => BinnedColumnKind::Real { source_index: 0 },
                1 => BinnedColumnKind::Canary {
                    source_index: 0,
                    copy_index: 0,
                },
                _ => unreachable!(),
            }
        }

        fn is_binary_binned_feature(&self, _index: usize) -> bool {
            true
        }

        fn target_value(&self, row_index: usize) -> f64 {
            [0.0, 0.0, 1.0, 1.0][row_index]
        }
    }

    #[test]
    fn boosting_stops_when_root_split_is_a_canary() {
        let table = RootCanaryTable;

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Regression,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(10),
                max_features: MaxFeatures::All,
                learning_rate: Some(0.1),
                // Disable gradient sampling so every stage sees every row.
                top_gradient_fraction: Some(1.0),
                other_gradient_fraction: Some(0.0),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        // No tree was kept; predictions fall back to the finite base score.
        assert_eq!(model.trees().len(), 0);
        assert_eq!(model.training_metadata().n_trees, Some(0));
        assert!(
            model
                .predict_table(&table)
                .iter()
                .all(|value| value.is_finite())
        );
    }
}