// forestfire_core/boosting.rs
//! Gradient boosting implementation.
//!
//! The boosting path is intentionally "LightGBM-like" rather than a direct
//! clone. It uses second-order trees, shrinkage, and gradient-focused sampling,
//! but it keeps ForestFire's canary mechanism active so a stage can stop before
//! growing a tree whose best root split is indistinguishable from noise.

use crate::bootstrap::BootstrapSampler;
use crate::ir::TrainingMetadata;
use crate::tree::second_order::{
    SecondOrderRegressionTreeError, SecondOrderRegressionTreeOptions,
    train_cart_regressor_from_gradients_and_hessians_with_status,
    train_oblivious_regressor_from_gradients_and_hessians_with_status,
    train_randomized_regressor_from_gradients_and_hessians_with_status,
};
use crate::tree::shared::mix_seed;
use crate::{
    Criterion, FeaturePreprocessing, Model, Parallelism, PredictError, Task, TrainConfig, TreeType,
    capture_feature_preprocessing,
};
use forestfire_data::TableAccess;
use rand::SeedableRng;
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
25
26/// Stage-wise gradient-boosted tree ensemble.
27///
28/// The ensemble keeps explicit tree weights and a base score so the semantic
29/// model can reconstruct raw margins exactly for both prediction and IR export.
30#[derive(Debug, Clone)]
31pub struct GradientBoostedTrees {
32    task: Task,
33    tree_type: TreeType,
34    trees: Vec<Model>,
35    tree_weights: Vec<f64>,
36    base_score: f64,
37    learning_rate: f64,
38    bootstrap: bool,
39    top_gradient_fraction: f64,
40    other_gradient_fraction: f64,
41    max_features: usize,
42    seed: Option<u64>,
43    num_features: usize,
44    feature_preprocessing: Vec<FeaturePreprocessing>,
45    class_labels: Option<Vec<f64>>,
46    training_canaries: usize,
47}
48
49#[derive(Debug)]
50pub enum BoostingError {
51    InvalidTargetValue { row: usize, value: f64 },
52    UnsupportedClassificationClassCount(usize),
53    InvalidLearningRate(f64),
54    InvalidTopGradientFraction(f64),
55    InvalidOtherGradientFraction(f64),
56    SecondOrderTree(SecondOrderRegressionTreeError),
57}
58
59impl std::fmt::Display for BoostingError {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        match self {
62            BoostingError::InvalidTargetValue { row, value } => write!(
63                f,
64                "Boosting targets must be finite values. Found {} at row {}.",
65                value, row
66            ),
67            BoostingError::UnsupportedClassificationClassCount(count) => write!(
68                f,
69                "Gradient boosting currently supports binary classification only. Found {} classes.",
70                count
71            ),
72            BoostingError::InvalidLearningRate(value) => write!(
73                f,
74                "learning_rate must be finite and greater than 0. Found {}.",
75                value
76            ),
77            BoostingError::InvalidTopGradientFraction(value) => write!(
78                f,
79                "top_gradient_fraction must be in the interval (0, 1]. Found {}.",
80                value
81            ),
82            BoostingError::InvalidOtherGradientFraction(value) => write!(
83                f,
84                "other_gradient_fraction must be in the interval [0, 1), and top_gradient_fraction + other_gradient_fraction must be at most 1. Found {}.",
85                value
86            ),
87            BoostingError::SecondOrderTree(err) => err.fmt(f),
88        }
89    }
90}
91
92impl std::error::Error for BoostingError {}
93
94struct SampledTable<'a> {
95    base: &'a dyn TableAccess,
96    row_indices: Vec<usize>,
97}
98
99impl GradientBoostedTrees {
100    #[allow(clippy::too_many_arguments)]
101    pub fn new(
102        task: Task,
103        tree_type: TreeType,
104        trees: Vec<Model>,
105        tree_weights: Vec<f64>,
106        base_score: f64,
107        learning_rate: f64,
108        bootstrap: bool,
109        top_gradient_fraction: f64,
110        other_gradient_fraction: f64,
111        max_features: usize,
112        seed: Option<u64>,
113        num_features: usize,
114        feature_preprocessing: Vec<FeaturePreprocessing>,
115        class_labels: Option<Vec<f64>>,
116        training_canaries: usize,
117    ) -> Self {
118        Self {
119            task,
120            tree_type,
121            trees,
122            tree_weights,
123            base_score,
124            learning_rate,
125            bootstrap,
126            top_gradient_fraction,
127            other_gradient_fraction,
128            max_features,
129            seed,
130            num_features,
131            feature_preprocessing,
132            class_labels,
133            training_canaries,
134        }
135    }
136
137    #[allow(dead_code)]
138    pub(crate) fn train(
139        train_set: &dyn TableAccess,
140        config: TrainConfig,
141        parallelism: Parallelism,
142    ) -> Result<Self, BoostingError> {
143        let missing_value_strategies = config
144            .missing_value_strategy
145            .resolve_for_feature_count(train_set.binned_feature_count())
146            .unwrap_or_else(|err| {
147                panic!("unexpected training error while resolving missing strategy: {err}")
148            });
149        Self::train_with_missing_value_strategies(
150            train_set,
151            config,
152            parallelism,
153            missing_value_strategies,
154        )
155    }
156
157    pub(crate) fn train_with_missing_value_strategies(
158        train_set: &dyn TableAccess,
159        config: TrainConfig,
160        parallelism: Parallelism,
161        missing_value_strategies: Vec<crate::MissingValueStrategy>,
162    ) -> Result<Self, BoostingError> {
163        let n_trees = config.n_trees.unwrap_or(100);
164        let learning_rate = config.learning_rate.unwrap_or(0.1);
165        let bootstrap = config.bootstrap;
166        let top_gradient_fraction = config.top_gradient_fraction.unwrap_or(0.2);
167        let other_gradient_fraction = config.other_gradient_fraction.unwrap_or(0.1);
168        validate_boosting_parameters(
169            train_set,
170            learning_rate,
171            top_gradient_fraction,
172            other_gradient_fraction,
173        )?;
174
175        let max_features = config
176            .max_features
177            .resolve(config.task, train_set.binned_feature_count());
178        let base_seed = config.seed.unwrap_or(0xB005_7EED_u64);
179        let tree_options = crate::RegressionTreeOptions {
180            max_depth: config.max_depth.unwrap_or(8),
181            min_samples_split: config.min_samples_split.unwrap_or(2),
182            min_samples_leaf: config.min_samples_leaf.unwrap_or(1),
183            max_features: Some(max_features),
184            random_seed: 0,
185            missing_value_strategies,
186        };
187        let tree_options = SecondOrderRegressionTreeOptions {
188            tree_options,
189            l2_regularization: 1.0,
190            min_sum_hessian_in_leaf: 1e-3,
191            min_gain_to_split: 0.0,
192        };
193        let feature_preprocessing = capture_feature_preprocessing(train_set);
194        let sampler = BootstrapSampler::new(train_set.n_rows());
195
196        let (mut raw_predictions, class_labels, base_score) = match config.task {
197            Task::Regression => {
198                let targets = finite_targets(train_set)?;
199                let base_score = targets.iter().sum::<f64>() / targets.len() as f64;
200                (vec![base_score; train_set.n_rows()], None, base_score)
201            }
202            Task::Classification => {
203                let (labels, encoded_targets) = binary_classification_targets(train_set)?;
204                let positive_rate = (encoded_targets.iter().sum::<f64>()
205                    / encoded_targets.len() as f64)
206                    .clamp(1e-6, 1.0 - 1e-6);
207                let base_score = (positive_rate / (1.0 - positive_rate)).ln();
208                (
209                    vec![base_score; train_set.n_rows()],
210                    Some(labels),
211                    base_score,
212                )
213            }
214        };
215
216        let mut trees = Vec::with_capacity(n_trees);
217        let mut tree_weights = Vec::with_capacity(n_trees);
218        let regression_targets = if config.task == Task::Regression {
219            Some(finite_targets(train_set)?)
220        } else {
221            None
222        };
223        let classification_targets = if config.task == Task::Classification {
224            Some(binary_classification_targets(train_set)?.1)
225        } else {
226            None
227        };
228
229        for tree_index in 0..n_trees {
230            let stage_seed = mix_seed(base_seed, tree_index as u64);
231            // Gradients/hessians are recomputed from the current ensemble margin
232            // at every stage, which keeps the tree learner focused on residual
233            // structure that earlier trees did not explain.
234            let (gradients, hessians) = match config.task {
235                Task::Regression => squared_error_gradients_and_hessians(
236                    raw_predictions.as_slice(),
237                    regression_targets
238                        .as_ref()
239                        .expect("regression targets exist for regression boosting"),
240                ),
241                Task::Classification => logistic_gradients_and_hessians(
242                    raw_predictions.as_slice(),
243                    classification_targets
244                        .as_ref()
245                        .expect("classification targets exist for classification boosting"),
246                ),
247            };
248
249            let base_rows = if bootstrap {
250                sampler.sample(stage_seed)
251            } else {
252                (0..train_set.n_rows()).collect()
253            };
254            // GOSS-style sampling keeps the largest gradients deterministically
255            // and samples part of the remainder. This biases work toward the rows
256            // where the current ensemble is most wrong.
257            let sampled_rows = gradient_focus_sample(
258                &base_rows,
259                &gradients,
260                &hessians,
261                top_gradient_fraction,
262                other_gradient_fraction,
263                mix_seed(stage_seed, 0x6011_5A11),
264            );
265            let sampled_table = SampledTable::new(train_set, sampled_rows.row_indices);
266            let mut stage_tree_options = tree_options.clone();
267            stage_tree_options.tree_options.random_seed = stage_seed;
268            let stage_result = match config.tree_type {
269                TreeType::Cart => train_cart_regressor_from_gradients_and_hessians_with_status(
270                    &sampled_table,
271                    &sampled_rows.gradients,
272                    &sampled_rows.hessians,
273                    parallelism,
274                    stage_tree_options,
275                ),
276                TreeType::Randomized => {
277                    train_randomized_regressor_from_gradients_and_hessians_with_status(
278                        &sampled_table,
279                        &sampled_rows.gradients,
280                        &sampled_rows.hessians,
281                        parallelism,
282                        stage_tree_options,
283                    )
284                }
285                TreeType::Oblivious => {
286                    train_oblivious_regressor_from_gradients_and_hessians_with_status(
287                        &sampled_table,
288                        &sampled_rows.gradients,
289                        &sampled_rows.hessians,
290                        parallelism,
291                        stage_tree_options,
292                    )
293                }
294                _ => unreachable!("boosting tree type validated by training dispatch"),
295            }
296            .map_err(BoostingError::SecondOrderTree)?;
297
298            // A canary root win means the stage could not find a real feature
299            // stronger than shuffled noise, so boosting stops early.
300            if stage_result.root_canary_selected {
301                break;
302            }
303
304            let stage_tree = stage_result.model;
305            let stage_model = Model::DecisionTreeRegressor(stage_tree);
306            let stage_predictions = stage_model.predict_table(train_set);
307            for (raw_prediction, stage_prediction) in raw_predictions
308                .iter_mut()
309                .zip(stage_predictions.iter().copied())
310            {
311                // Trees are fit on raw margins; shrinkage is applied only when
312                // updating the ensemble prediction.
313                *raw_prediction += learning_rate * stage_prediction;
314            }
315            tree_weights.push(learning_rate);
316            trees.push(stage_model);
317        }
318
319        Ok(Self::new(
320            config.task,
321            config.tree_type,
322            trees,
323            tree_weights,
324            base_score,
325            learning_rate,
326            bootstrap,
327            top_gradient_fraction,
328            other_gradient_fraction,
329            max_features,
330            config.seed,
331            train_set.n_features(),
332            feature_preprocessing,
333            class_labels,
334            train_set.canaries(),
335        ))
336    }
337
338    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
339        match self.task {
340            Task::Regression => self.predict_regression_table(table),
341            Task::Classification => self.predict_classification_table(table),
342        }
343    }
344
345    pub fn predict_proba_table(
346        &self,
347        table: &dyn TableAccess,
348    ) -> Result<Vec<Vec<f64>>, PredictError> {
349        if self.task != Task::Classification {
350            return Err(PredictError::ProbabilityPredictionRequiresClassification);
351        }
352        Ok(self
353            .raw_scores(table)
354            .into_iter()
355            .map(|score| {
356                let positive = sigmoid(score);
357                vec![1.0 - positive, positive]
358            })
359            .collect())
360    }
361
362    pub fn task(&self) -> Task {
363        self.task
364    }
365
366    pub fn criterion(&self) -> Criterion {
367        Criterion::SecondOrder
368    }
369
370    pub fn tree_type(&self) -> TreeType {
371        self.tree_type
372    }
373
374    pub fn trees(&self) -> &[Model] {
375        &self.trees
376    }
377
378    pub fn tree_weights(&self) -> &[f64] {
379        &self.tree_weights
380    }
381
382    pub fn base_score(&self) -> f64 {
383        self.base_score
384    }
385
386    pub fn num_features(&self) -> usize {
387        self.num_features
388    }
389
390    pub fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
391        &self.feature_preprocessing
392    }
393
394    pub fn class_labels(&self) -> Option<Vec<f64>> {
395        self.class_labels.clone()
396    }
397
398    pub fn training_metadata(&self) -> TrainingMetadata {
399        TrainingMetadata {
400            algorithm: "gbm".to_string(),
401            task: match self.task {
402                Task::Regression => "regression".to_string(),
403                Task::Classification => "classification".to_string(),
404            },
405            tree_type: match self.tree_type {
406                TreeType::Cart => "cart".to_string(),
407                TreeType::Randomized => "randomized".to_string(),
408                TreeType::Oblivious => "oblivious".to_string(),
409                _ => unreachable!("boosting only supports cart/randomized/oblivious"),
410            },
411            criterion: "second_order".to_string(),
412            canaries: self.training_canaries,
413            compute_oob: false,
414            max_depth: self.trees.first().and_then(Model::max_depth),
415            min_samples_split: self.trees.first().and_then(Model::min_samples_split),
416            min_samples_leaf: self.trees.first().and_then(Model::min_samples_leaf),
417            n_trees: Some(self.trees.len()),
418            max_features: Some(self.max_features),
419            seed: self.seed,
420            oob_score: None,
421            class_labels: self.class_labels.clone(),
422            learning_rate: Some(self.learning_rate),
423            bootstrap: Some(self.bootstrap),
424            top_gradient_fraction: Some(self.top_gradient_fraction),
425            other_gradient_fraction: Some(self.other_gradient_fraction),
426        }
427    }
428
429    fn raw_scores(&self, table: &dyn TableAccess) -> Vec<f64> {
430        let mut scores = vec![self.base_score; table.n_rows()];
431        for (tree, weight) in self.trees.iter().zip(self.tree_weights.iter().copied()) {
432            let predictions = tree.predict_table(table);
433            for (score, prediction) in scores.iter_mut().zip(predictions.iter().copied()) {
434                *score += weight * prediction;
435            }
436        }
437        scores
438    }
439
440    fn predict_regression_table(&self, table: &dyn TableAccess) -> Vec<f64> {
441        self.raw_scores(table)
442    }
443
444    fn predict_classification_table(&self, table: &dyn TableAccess) -> Vec<f64> {
445        let class_labels = self
446            .class_labels
447            .as_ref()
448            .expect("classification boosting stores class labels");
449        self.raw_scores(table)
450            .into_iter()
451            .map(|score| {
452                if sigmoid(score) >= 0.5 {
453                    class_labels[1]
454                } else {
455                    class_labels[0]
456                }
457            })
458            .collect()
459    }
460}
461
462struct GradientFocusedSample {
463    row_indices: Vec<usize>,
464    gradients: Vec<f64>,
465    hessians: Vec<f64>,
466}
467
468impl<'a> SampledTable<'a> {
469    fn new(base: &'a dyn TableAccess, row_indices: Vec<usize>) -> Self {
470        Self { base, row_indices }
471    }
472
473    fn resolve_row(&self, row_index: usize) -> usize {
474        self.row_indices[row_index]
475    }
476}
477
478impl TableAccess for SampledTable<'_> {
479    fn n_rows(&self) -> usize {
480        self.row_indices.len()
481    }
482
483    fn n_features(&self) -> usize {
484        self.base.n_features()
485    }
486
487    fn canaries(&self) -> usize {
488        self.base.canaries()
489    }
490
491    fn numeric_bin_cap(&self) -> usize {
492        self.base.numeric_bin_cap()
493    }
494
495    fn binned_feature_count(&self) -> usize {
496        self.base.binned_feature_count()
497    }
498
499    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
500        self.base
501            .feature_value(feature_index, self.resolve_row(row_index))
502    }
503
504    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
505        self.base
506            .is_missing(feature_index, self.resolve_row(row_index))
507    }
508
509    fn is_binary_feature(&self, index: usize) -> bool {
510        self.base.is_binary_feature(index)
511    }
512
513    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
514        self.base
515            .binned_value(feature_index, self.resolve_row(row_index))
516    }
517
518    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
519        self.base
520            .binned_boolean_value(feature_index, self.resolve_row(row_index))
521    }
522
523    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
524        self.base.binned_column_kind(index)
525    }
526
527    fn is_binary_binned_feature(&self, index: usize) -> bool {
528        self.base.is_binary_binned_feature(index)
529    }
530
531    fn target_value(&self, row_index: usize) -> f64 {
532        self.base.target_value(self.resolve_row(row_index))
533    }
534}
535
536fn validate_boosting_parameters(
537    train_set: &dyn TableAccess,
538    learning_rate: f64,
539    top_gradient_fraction: f64,
540    other_gradient_fraction: f64,
541) -> Result<(), BoostingError> {
542    if train_set.n_rows() == 0 {
543        return Err(BoostingError::InvalidLearningRate(learning_rate));
544    }
545    if !learning_rate.is_finite() || learning_rate <= 0.0 {
546        return Err(BoostingError::InvalidLearningRate(learning_rate));
547    }
548    if !top_gradient_fraction.is_finite()
549        || top_gradient_fraction <= 0.0
550        || top_gradient_fraction > 1.0
551    {
552        return Err(BoostingError::InvalidTopGradientFraction(
553            top_gradient_fraction,
554        ));
555    }
556    if !other_gradient_fraction.is_finite()
557        || !(0.0..1.0).contains(&other_gradient_fraction)
558        || top_gradient_fraction + other_gradient_fraction > 1.0
559    {
560        return Err(BoostingError::InvalidOtherGradientFraction(
561            other_gradient_fraction,
562        ));
563    }
564    Ok(())
565}
566
567fn finite_targets(train_set: &dyn TableAccess) -> Result<Vec<f64>, BoostingError> {
568    (0..train_set.n_rows())
569        .map(|row_index| {
570            let value = train_set.target_value(row_index);
571            if value.is_finite() {
572                Ok(value)
573            } else {
574                Err(BoostingError::InvalidTargetValue {
575                    row: row_index,
576                    value,
577                })
578            }
579        })
580        .collect()
581}
582
583fn binary_classification_targets(
584    train_set: &dyn TableAccess,
585) -> Result<(Vec<f64>, Vec<f64>), BoostingError> {
586    let mut labels = finite_targets(train_set)?;
587    labels.sort_by(|left, right| left.total_cmp(right));
588    labels.dedup_by(|left, right| left.total_cmp(right).is_eq());
589    if labels.len() != 2 {
590        return Err(BoostingError::UnsupportedClassificationClassCount(
591            labels.len(),
592        ));
593    }
594
595    let negative = labels[0];
596    let encoded = (0..train_set.n_rows())
597        .map(|row_index| {
598            if train_set
599                .target_value(row_index)
600                .total_cmp(&negative)
601                .is_eq()
602            {
603                0.0
604            } else {
605                1.0
606            }
607        })
608        .collect();
609    Ok((labels, encoded))
610}
611
612fn squared_error_gradients_and_hessians(
613    raw_predictions: &[f64],
614    targets: &[f64],
615) -> (Vec<f64>, Vec<f64>) {
616    (
617        raw_predictions
618            .iter()
619            .zip(targets.iter())
620            .map(|(prediction, target)| prediction - target)
621            .collect(),
622        vec![1.0; targets.len()],
623    )
624}
625
626fn logistic_gradients_and_hessians(
627    raw_predictions: &[f64],
628    targets: &[f64],
629) -> (Vec<f64>, Vec<f64>) {
630    let mut gradients = Vec::with_capacity(targets.len());
631    let mut hessians = Vec::with_capacity(targets.len());
632    for (raw_prediction, target) in raw_predictions.iter().zip(targets.iter()) {
633        let probability = sigmoid(*raw_prediction);
634        gradients.push(probability - target);
635        hessians.push((probability * (1.0 - probability)).max(1e-12));
636    }
637    (gradients, hessians)
638}
639
640fn sigmoid(value: f64) -> f64 {
641    if value >= 0.0 {
642        let exp = (-value).exp();
643        1.0 / (1.0 + exp)
644    } else {
645        let exp = value.exp();
646        exp / (1.0 + exp)
647    }
648}
649
650fn gradient_focus_sample(
651    base_rows: &[usize],
652    gradients: &[f64],
653    hessians: &[f64],
654    top_gradient_fraction: f64,
655    other_gradient_fraction: f64,
656    seed: u64,
657) -> GradientFocusedSample {
658    let mut ranked = base_rows
659        .iter()
660        .copied()
661        .map(|row_index| (row_index, gradients[row_index].abs()))
662        .collect::<Vec<_>>();
663    ranked.sort_by(|(left_row, left_abs), (right_row, right_abs)| {
664        right_abs
665            .total_cmp(left_abs)
666            .then_with(|| left_row.cmp(right_row))
667    });
668
669    let top_count = ((ranked.len() as f64) * top_gradient_fraction)
670        .ceil()
671        .clamp(1.0, ranked.len() as f64) as usize;
672    let mut row_indices = Vec::with_capacity(ranked.len());
673    let mut sampled_gradients = Vec::with_capacity(ranked.len());
674    let mut sampled_hessians = Vec::with_capacity(ranked.len());
675
676    for (row_index, _) in ranked.iter().take(top_count) {
677        row_indices.push(*row_index);
678        sampled_gradients.push(gradients[*row_index]);
679        sampled_hessians.push(hessians[*row_index]);
680    }
681
682    if top_count < ranked.len() && other_gradient_fraction > 0.0 {
683        let remaining = ranked[top_count..]
684            .iter()
685            .map(|(row_index, _)| *row_index)
686            .collect::<Vec<_>>();
687        let other_count = ((remaining.len() as f64) * other_gradient_fraction)
688            .ceil()
689            .min(remaining.len() as f64) as usize;
690        if other_count > 0 {
691            let mut remaining = remaining;
692            let mut rng = StdRng::seed_from_u64(seed);
693            remaining.shuffle(&mut rng);
694            let gradient_scale = (1.0 - top_gradient_fraction) / other_gradient_fraction;
695            for row_index in remaining.into_iter().take(other_count) {
696                row_indices.push(row_index);
697                sampled_gradients.push(gradients[row_index] * gradient_scale);
698                sampled_hessians.push(hessians[row_index] * gradient_scale);
699            }
700        }
701    }
702
703    GradientFocusedSample {
704        row_indices,
705        gradients: sampled_gradients,
706        hessians: sampled_hessians,
707    }
708}
709
710#[cfg(test)]
711mod tests {
712    use super::*;
713    use crate::{MaxFeatures, TrainAlgorithm, TrainConfig};
714    use forestfire_data::{BinnedColumnKind, TableAccess};
715    use forestfire_data::{DenseTable, NumericBins};
716
717    #[test]
718    fn regression_boosting_fits_simple_signal() {
719        let table = DenseTable::with_options(
720            vec![
721                vec![0.0],
722                vec![0.0],
723                vec![1.0],
724                vec![1.0],
725                vec![2.0],
726                vec![2.0],
727            ],
728            vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
729            0,
730            NumericBins::fixed(8).unwrap(),
731        )
732        .unwrap();
733
734        let model = GradientBoostedTrees::train(
735            &table,
736            TrainConfig {
737                algorithm: TrainAlgorithm::Gbm,
738                task: Task::Regression,
739                tree_type: TreeType::Cart,
740                criterion: Criterion::SecondOrder,
741                n_trees: Some(20),
742                learning_rate: Some(0.2),
743                max_depth: Some(2),
744                ..TrainConfig::default()
745            },
746            Parallelism::sequential(),
747        )
748        .unwrap();
749
750        let predictions = model.predict_table(&table);
751        assert!(predictions[0] < predictions[2]);
752        assert!(predictions[2] < predictions[4]);
753    }
754
755    #[test]
756    fn classification_boosting_produces_binary_probabilities() {
757        let table = DenseTable::with_options(
758            vec![vec![0.0], vec![0.1], vec![0.9], vec![1.0]],
759            vec![0.0, 0.0, 1.0, 1.0],
760            0,
761            NumericBins::fixed(8).unwrap(),
762        )
763        .unwrap();
764
765        let model = GradientBoostedTrees::train(
766            &table,
767            TrainConfig {
768                algorithm: TrainAlgorithm::Gbm,
769                task: Task::Classification,
770                tree_type: TreeType::Cart,
771                criterion: Criterion::SecondOrder,
772                n_trees: Some(25),
773                learning_rate: Some(0.2),
774                max_depth: Some(2),
775                ..TrainConfig::default()
776            },
777            Parallelism::sequential(),
778        )
779        .unwrap();
780
781        let probabilities = model.predict_proba_table(&table).unwrap();
782        assert_eq!(probabilities.len(), 4);
783        assert!(probabilities[0][1] < 0.5);
784        assert!(probabilities[3][1] > 0.5);
785    }
786
787    #[test]
788    fn classification_boosting_rejects_multiclass_targets() {
789        let table =
790            DenseTable::new(vec![vec![0.0], vec![1.0], vec![2.0]], vec![0.0, 1.0, 2.0]).unwrap();
791
792        let error = GradientBoostedTrees::train(
793            &table,
794            TrainConfig {
795                algorithm: TrainAlgorithm::Gbm,
796                task: Task::Classification,
797                tree_type: TreeType::Cart,
798                criterion: Criterion::SecondOrder,
799                ..TrainConfig::default()
800            },
801            Parallelism::sequential(),
802        )
803        .unwrap_err();
804
805        assert!(matches!(
806            error,
807            BoostingError::UnsupportedClassificationClassCount(3)
808        ));
809    }
810
811    struct RootCanaryTable;
812
813    impl TableAccess for RootCanaryTable {
814        fn n_rows(&self) -> usize {
815            4
816        }
817
818        fn n_features(&self) -> usize {
819            1
820        }
821
822        fn canaries(&self) -> usize {
823            1
824        }
825
826        fn numeric_bin_cap(&self) -> usize {
827            2
828        }
829
830        fn binned_feature_count(&self) -> usize {
831            2
832        }
833
834        fn feature_value(&self, _feature_index: usize, _row_index: usize) -> f64 {
835            0.0
836        }
837
838        fn is_missing(&self, _feature_index: usize, _row_index: usize) -> bool {
839            false
840        }
841
842        fn is_binary_feature(&self, _index: usize) -> bool {
843            true
844        }
845
846        fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
847            match feature_index {
848                0 => 0,
849                1 => u16::from(row_index >= 2),
850                _ => unreachable!(),
851            }
852        }
853
854        fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
855            Some(match feature_index {
856                0 => false,
857                1 => row_index >= 2,
858                _ => unreachable!(),
859            })
860        }
861
862        fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
863            match index {
864                0 => BinnedColumnKind::Real { source_index: 0 },
865                1 => BinnedColumnKind::Canary {
866                    source_index: 0,
867                    copy_index: 0,
868                },
869                _ => unreachable!(),
870            }
871        }
872
873        fn is_binary_binned_feature(&self, _index: usize) -> bool {
874            true
875        }
876
877        fn target_value(&self, row_index: usize) -> f64 {
878            [0.0, 0.0, 1.0, 1.0][row_index]
879        }
880    }
881
882    #[test]
883    fn boosting_stops_when_root_split_is_a_canary() {
884        let table = RootCanaryTable;
885
886        let model = GradientBoostedTrees::train(
887            &table,
888            TrainConfig {
889                algorithm: TrainAlgorithm::Gbm,
890                task: Task::Regression,
891                tree_type: TreeType::Cart,
892                criterion: Criterion::SecondOrder,
893                n_trees: Some(10),
894                max_features: MaxFeatures::All,
895                learning_rate: Some(0.1),
896                top_gradient_fraction: Some(1.0),
897                other_gradient_fraction: Some(0.0),
898                ..TrainConfig::default()
899            },
900            Parallelism::sequential(),
901        )
902        .unwrap();
903
904        assert_eq!(model.trees().len(), 0);
905        assert_eq!(model.training_metadata().n_trees, Some(0));
906        assert!(
907            model
908                .predict_table(&table)
909                .iter()
910                .all(|value| value.is_finite())
911        );
912    }
913}