Skip to main content

forestfire_core/
boosting.rs

1//! Gradient boosting implementation.
2//!
3//! The boosting path is intentionally "LightGBM-like" rather than a direct
4//! clone. It uses second-order trees, shrinkage, and gradient-focused sampling,
5//! but it keeps ForestFire's canary mechanism active so a stage can stop before
6//! growing a tree whose best root split is indistinguishable from noise.
7
8use crate::bootstrap::BootstrapSampler;
9use crate::ir::TrainingMetadata;
10use crate::tree::second_order::{
11    SecondOrderRegressionTreeError, SecondOrderRegressionTreeOptions,
12    train_cart_regressor_from_gradients_and_hessians_with_status,
13    train_oblivious_regressor_from_gradients_and_hessians_with_status,
14    train_randomized_regressor_from_gradients_and_hessians_with_status,
15};
16use crate::tree::shared::mix_seed;
17use crate::{
18    Criterion, FeaturePreprocessing, Model, Parallelism, PredictError, Task, TrainConfig, TreeType,
19    capture_feature_preprocessing,
20};
21use forestfire_data::TableAccess;
22use rand::SeedableRng;
23use rand::rngs::StdRng;
24use rand::seq::SliceRandom;
25
/// Stage-wise gradient-boosted tree ensemble.
///
/// The ensemble keeps explicit tree weights and a base score so the semantic
/// model can reconstruct raw margins exactly for both prediction and IR export.
/// A raw margin is `base_score` plus the weighted sum of every stage's
/// prediction; classification passes that margin through a sigmoid.
#[derive(Debug, Clone)]
pub struct GradientBoostedTrees {
    /// Learning task; selects between regression and classification prediction.
    task: Task,
    /// Kind of tree fitted at every boosting stage.
    tree_type: TreeType,
    /// Fitted stage models, in the order they were trained.
    trees: Vec<Model>,
    /// Per-stage multiplier applied to each tree's prediction (training pushes
    /// the learning rate for every stage).
    tree_weights: Vec<f64>,
    /// Initial raw margin shared by all rows before any tree is applied.
    base_score: f64,
    /// Shrinkage factor used during training.
    learning_rate: f64,
    /// Whether each stage drew its base rows from the bootstrap sampler.
    bootstrap: bool,
    /// Fraction of rows with the largest absolute gradients kept at each stage.
    top_gradient_fraction: f64,
    /// Fraction used to sample the rows outside the top-gradient set.
    other_gradient_fraction: f64,
    /// Resolved feature budget handed to the tree learner.
    max_features: usize,
    /// Seed supplied by the caller, if any (`None` when the default was used).
    seed: Option<u64>,
    /// Number of features of the training table.
    num_features: usize,
    /// Feature preprocessing captured from the training table.
    feature_preprocessing: Vec<FeaturePreprocessing>,
    /// Class labels for classification (index 0 = negative, index 1 =
    /// positive); training stores `None` for regression.
    class_labels: Option<Vec<f64>>,
    /// Canary count reported by the training table.
    training_canaries: usize,
}
48
/// Errors produced while validating inputs for, or training, a boosted ensemble.
#[derive(Debug)]
pub enum BoostingError {
    /// A target value was NaN or infinite; `row` is the offending row index.
    InvalidTargetValue { row: usize, value: f64 },
    /// Classification targets did not contain exactly two distinct values;
    /// carries the number of distinct values found.
    UnsupportedClassificationClassCount(usize),
    /// `learning_rate` was not finite or not strictly positive.
    // NOTE(review): `validate_boosting_parameters` also returns this variant
    // when the training table has zero rows.
    InvalidLearningRate(f64),
    /// `top_gradient_fraction` was outside `(0, 1]`.
    InvalidTopGradientFraction(f64),
    /// `other_gradient_fraction` was outside `[0, 1)` or pushed the sum of
    /// both fractions above 1.
    InvalidOtherGradientFraction(f64),
    /// The underlying second-order tree learner failed.
    SecondOrderTree(SecondOrderRegressionTreeError),
}
58
59impl std::fmt::Display for BoostingError {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        match self {
62            BoostingError::InvalidTargetValue { row, value } => write!(
63                f,
64                "Boosting targets must be finite values. Found {} at row {}.",
65                value, row
66            ),
67            BoostingError::UnsupportedClassificationClassCount(count) => write!(
68                f,
69                "Gradient boosting currently supports binary classification only. Found {} classes.",
70                count
71            ),
72            BoostingError::InvalidLearningRate(value) => write!(
73                f,
74                "learning_rate must be finite and greater than 0. Found {}.",
75                value
76            ),
77            BoostingError::InvalidTopGradientFraction(value) => write!(
78                f,
79                "top_gradient_fraction must be in the interval (0, 1]. Found {}.",
80                value
81            ),
82            BoostingError::InvalidOtherGradientFraction(value) => write!(
83                f,
84                "other_gradient_fraction must be in the interval [0, 1), and top_gradient_fraction + other_gradient_fraction must be at most 1. Found {}.",
85                value
86            ),
87            BoostingError::SecondOrderTree(err) => err.fmt(f),
88        }
89    }
90}
91
// Marker impl: relies on the trait's default methods, so no `source()` is
// exposed; the wrapped tree error is surfaced through `Display` instead.
impl std::error::Error for BoostingError {}
93
/// Row-subset view over another table.
///
/// Row `i` of the view maps to row `row_indices[i]` of `base`, so a stage can
/// be trained on sampled rows without copying feature data.
struct SampledTable<'a> {
    /// Table that actually stores features and targets.
    base: &'a dyn TableAccess,
    /// Base-table row index for each row of the view.
    row_indices: Vec<usize>,
}
98
99impl GradientBoostedTrees {
    /// Assembles an ensemble from already-built parts.
    ///
    /// Every argument is stored as given; no validation is performed here
    /// (for example, `trees` and `tree_weights` are not checked to have the
    /// same length).
    // The constructor mirrors the struct field-for-field, hence the long
    // argument list.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        task: Task,
        tree_type: TreeType,
        trees: Vec<Model>,
        tree_weights: Vec<f64>,
        base_score: f64,
        learning_rate: f64,
        bootstrap: bool,
        top_gradient_fraction: f64,
        other_gradient_fraction: f64,
        max_features: usize,
        seed: Option<u64>,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
        class_labels: Option<Vec<f64>>,
        training_canaries: usize,
    ) -> Self {
        Self {
            task,
            tree_type,
            trees,
            tree_weights,
            base_score,
            learning_rate,
            bootstrap,
            top_gradient_fraction,
            other_gradient_fraction,
            max_features,
            seed,
            num_features,
            feature_preprocessing,
            class_labels,
            training_canaries,
        }
    }
136
137    #[allow(dead_code)]
138    pub(crate) fn train(
139        train_set: &dyn TableAccess,
140        config: TrainConfig,
141        parallelism: Parallelism,
142    ) -> Result<Self, BoostingError> {
143        let missing_value_strategies = config
144            .missing_value_strategy
145            .resolve_for_feature_count(train_set.binned_feature_count())
146            .unwrap_or_else(|err| {
147                panic!("unexpected training error while resolving missing strategy: {err}")
148            });
149        Self::train_with_missing_value_strategies(
150            train_set,
151            config,
152            parallelism,
153            missing_value_strategies,
154        )
155    }
156
157    pub(crate) fn train_with_missing_value_strategies(
158        train_set: &dyn TableAccess,
159        config: TrainConfig,
160        parallelism: Parallelism,
161        missing_value_strategies: Vec<crate::MissingValueStrategy>,
162    ) -> Result<Self, BoostingError> {
163        let n_trees = config.n_trees.unwrap_or(100);
164        let learning_rate = config.learning_rate.unwrap_or(0.1);
165        let bootstrap = config.bootstrap;
166        let top_gradient_fraction = config.top_gradient_fraction.unwrap_or(0.2);
167        let other_gradient_fraction = config.other_gradient_fraction.unwrap_or(0.1);
168        validate_boosting_parameters(
169            train_set,
170            learning_rate,
171            top_gradient_fraction,
172            other_gradient_fraction,
173        )?;
174
175        let max_features = config
176            .max_features
177            .resolve(config.task, train_set.n_features());
178        let base_seed = config.seed.unwrap_or(0xB005_7EED_u64);
179        let tree_options = crate::RegressionTreeOptions {
180            max_depth: config.max_depth.unwrap_or(8),
181            min_samples_split: config.min_samples_split.unwrap_or(2),
182            min_samples_leaf: config.min_samples_leaf.unwrap_or(1),
183            max_features: Some(max_features),
184            random_seed: 0,
185            missing_value_strategies,
186            canary_filter: config.canary_filter,
187        };
188        let tree_options = SecondOrderRegressionTreeOptions {
189            tree_options,
190            l2_regularization: 1.0,
191            min_sum_hessian_in_leaf: 1e-3,
192            min_gain_to_split: 0.0,
193        };
194        let feature_preprocessing = capture_feature_preprocessing(train_set);
195        let sampler = BootstrapSampler::new(train_set.n_rows());
196
197        let (mut raw_predictions, class_labels, base_score) = match config.task {
198            Task::Regression => {
199                let targets = finite_targets(train_set)?;
200                let base_score = targets.iter().sum::<f64>() / targets.len() as f64;
201                (vec![base_score; train_set.n_rows()], None, base_score)
202            }
203            Task::Classification => {
204                let (labels, encoded_targets) = binary_classification_targets(train_set)?;
205                let positive_rate = (encoded_targets.iter().sum::<f64>()
206                    / encoded_targets.len() as f64)
207                    .clamp(1e-6, 1.0 - 1e-6);
208                let base_score = (positive_rate / (1.0 - positive_rate)).ln();
209                (
210                    vec![base_score; train_set.n_rows()],
211                    Some(labels),
212                    base_score,
213                )
214            }
215        };
216
217        let mut trees = Vec::with_capacity(n_trees);
218        let mut tree_weights = Vec::with_capacity(n_trees);
219        let regression_targets = if config.task == Task::Regression {
220            Some(finite_targets(train_set)?)
221        } else {
222            None
223        };
224        let classification_targets = if config.task == Task::Classification {
225            Some(binary_classification_targets(train_set)?.1)
226        } else {
227            None
228        };
229
230        for tree_index in 0..n_trees {
231            let stage_seed = mix_seed(base_seed, tree_index as u64);
232            // Gradients/hessians are recomputed from the current ensemble margin
233            // at every stage, which keeps the tree learner focused on residual
234            // structure that earlier trees did not explain.
235            let (gradients, hessians) = match config.task {
236                Task::Regression => squared_error_gradients_and_hessians(
237                    raw_predictions.as_slice(),
238                    regression_targets
239                        .as_ref()
240                        .expect("regression targets exist for regression boosting"),
241                ),
242                Task::Classification => logistic_gradients_and_hessians(
243                    raw_predictions.as_slice(),
244                    classification_targets
245                        .as_ref()
246                        .expect("classification targets exist for classification boosting"),
247                ),
248            };
249
250            let base_rows = if bootstrap {
251                sampler.sample(stage_seed)
252            } else {
253                (0..train_set.n_rows()).collect()
254            };
255            // GOSS-style sampling keeps the largest gradients deterministically
256            // and samples part of the remainder. This biases work toward the rows
257            // where the current ensemble is most wrong.
258            let sampled_rows = gradient_focus_sample(
259                &base_rows,
260                &gradients,
261                &hessians,
262                top_gradient_fraction,
263                other_gradient_fraction,
264                mix_seed(stage_seed, 0x6011_5A11),
265            );
266            let sampled_table = SampledTable::new(train_set, sampled_rows.row_indices);
267            let mut stage_tree_options = tree_options.clone();
268            stage_tree_options.tree_options.random_seed = stage_seed;
269            let stage_result = match config.tree_type {
270                TreeType::Cart => train_cart_regressor_from_gradients_and_hessians_with_status(
271                    &sampled_table,
272                    &sampled_rows.gradients,
273                    &sampled_rows.hessians,
274                    parallelism,
275                    stage_tree_options,
276                ),
277                TreeType::Randomized => {
278                    train_randomized_regressor_from_gradients_and_hessians_with_status(
279                        &sampled_table,
280                        &sampled_rows.gradients,
281                        &sampled_rows.hessians,
282                        parallelism,
283                        stage_tree_options,
284                    )
285                }
286                TreeType::Oblivious => {
287                    train_oblivious_regressor_from_gradients_and_hessians_with_status(
288                        &sampled_table,
289                        &sampled_rows.gradients,
290                        &sampled_rows.hessians,
291                        parallelism,
292                        stage_tree_options,
293                    )
294                }
295                _ => unreachable!("boosting tree type validated by training dispatch"),
296            }
297            .map_err(BoostingError::SecondOrderTree)?;
298
299            // A canary root win means the stage could not find a real feature
300            // stronger than shuffled noise, so boosting stops early.
301            if stage_result.root_canary_selected {
302                break;
303            }
304
305            let stage_tree = stage_result.model;
306            let stage_model = Model::DecisionTreeRegressor(stage_tree);
307            let stage_predictions = stage_model.predict_table(train_set);
308            for (raw_prediction, stage_prediction) in raw_predictions
309                .iter_mut()
310                .zip(stage_predictions.iter().copied())
311            {
312                // Trees are fit on raw margins; shrinkage is applied only when
313                // updating the ensemble prediction.
314                *raw_prediction += learning_rate * stage_prediction;
315            }
316            tree_weights.push(learning_rate);
317            trees.push(stage_model);
318        }
319
320        Ok(Self::new(
321            config.task,
322            config.tree_type,
323            trees,
324            tree_weights,
325            base_score,
326            learning_rate,
327            bootstrap,
328            top_gradient_fraction,
329            other_gradient_fraction,
330            max_features,
331            config.seed,
332            train_set.n_features(),
333            feature_preprocessing,
334            class_labels,
335            train_set.canaries(),
336        ))
337    }
338
    /// Predicts one value per row of `table`.
    ///
    /// Regression returns the raw ensemble score; classification returns the
    /// predicted class label.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        match self.task {
            Task::Regression => self.predict_regression_table(table),
            Task::Classification => self.predict_classification_table(table),
        }
    }
345
346    pub fn predict_proba_table(
347        &self,
348        table: &dyn TableAccess,
349    ) -> Result<Vec<Vec<f64>>, PredictError> {
350        if self.task != Task::Classification {
351            return Err(PredictError::ProbabilityPredictionRequiresClassification);
352        }
353        Ok(self
354            .raw_scores(table)
355            .into_iter()
356            .map(|score| {
357                let positive = sigmoid(score);
358                vec![1.0 - positive, positive]
359            })
360            .collect())
361    }
362
    /// Returns the task stored in this ensemble.
    pub fn task(&self) -> Task {
        self.task
    }
366
    /// Returns the split criterion; always [`Criterion::SecondOrder`].
    pub fn criterion(&self) -> Criterion {
        Criterion::SecondOrder
    }
370
    /// Returns the tree type stored in this ensemble.
    pub fn tree_type(&self) -> TreeType {
        self.tree_type
    }
374
    /// Returns the stage models in stored order.
    pub fn trees(&self) -> &[Model] {
        &self.trees
    }
378
    /// Returns the per-tree weights, paired positionally with the trees when
    /// raw scores are computed.
    pub fn tree_weights(&self) -> &[f64] {
        &self.tree_weights
    }
382
    /// Returns the initial raw score every row starts from.
    pub fn base_score(&self) -> f64 {
        self.base_score
    }
386
    /// Returns the stored feature count.
    pub fn num_features(&self) -> usize {
        self.num_features
    }
390
    /// Returns the stored feature preprocessing descriptors.
    pub fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }
394
    /// Returns a copy of the stored class labels, if any.
    ///
    /// Note that this clones the label vector on every call.
    pub fn class_labels(&self) -> Option<Vec<f64>> {
        self.class_labels.clone()
    }
398
399    pub fn training_metadata(&self) -> TrainingMetadata {
400        TrainingMetadata {
401            algorithm: "gbm".to_string(),
402            task: match self.task {
403                Task::Regression => "regression".to_string(),
404                Task::Classification => "classification".to_string(),
405            },
406            tree_type: match self.tree_type {
407                TreeType::Cart => "cart".to_string(),
408                TreeType::Randomized => "randomized".to_string(),
409                TreeType::Oblivious => "oblivious".to_string(),
410                _ => unreachable!("boosting only supports cart/randomized/oblivious"),
411            },
412            criterion: "second_order".to_string(),
413            canaries: self.training_canaries,
414            compute_oob: false,
415            max_depth: self.trees.first().and_then(Model::max_depth),
416            min_samples_split: self.trees.first().and_then(Model::min_samples_split),
417            min_samples_leaf: self.trees.first().and_then(Model::min_samples_leaf),
418            n_trees: Some(self.trees.len()),
419            max_features: Some(self.max_features),
420            seed: self.seed,
421            oob_score: None,
422            class_labels: self.class_labels.clone(),
423            learning_rate: Some(self.learning_rate),
424            bootstrap: Some(self.bootstrap),
425            top_gradient_fraction: Some(self.top_gradient_fraction),
426            other_gradient_fraction: Some(self.other_gradient_fraction),
427        }
428    }
429
430    fn raw_scores(&self, table: &dyn TableAccess) -> Vec<f64> {
431        let mut scores = vec![self.base_score; table.n_rows()];
432        for (tree, weight) in self.trees.iter().zip(self.tree_weights.iter().copied()) {
433            let predictions = tree.predict_table(table);
434            for (score, prediction) in scores.iter_mut().zip(predictions.iter().copied()) {
435                *score += weight * prediction;
436            }
437        }
438        scores
439    }
440
    /// Regression predictions are the raw ensemble scores, unchanged.
    fn predict_regression_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        self.raw_scores(table)
    }
444
445    fn predict_classification_table(&self, table: &dyn TableAccess) -> Vec<f64> {
446        let class_labels = self
447            .class_labels
448            .as_ref()
449            .expect("classification boosting stores class labels");
450        self.raw_scores(table)
451            .into_iter()
452            .map(|score| {
453                if sigmoid(score) >= 0.5 {
454                    class_labels[1]
455                } else {
456                    class_labels[0]
457                }
458            })
459            .collect()
460    }
461}
462
/// Result of gradient-focused row sampling.
///
/// The three vectors are parallel: entry `i` of `gradients` and `hessians`
/// belongs to the row `row_indices[i]`.
struct GradientFocusedSample {
    /// Selected row indices (top-gradient rows first, then sampled rows).
    row_indices: Vec<usize>,
    /// Gradient per selected row; sampled rows carry a re-weighting factor.
    gradients: Vec<f64>,
    /// Hessian per selected row; sampled rows carry the same factor.
    hessians: Vec<f64>,
}
468
impl<'a> SampledTable<'a> {
    /// Creates a view exposing `row_indices` of `base`, in the given order.
    fn new(base: &'a dyn TableAccess, row_indices: Vec<usize>) -> Self {
        Self { base, row_indices }
    }

    /// Translates a view-local row index into the base table's row index.
    ///
    /// Panics if `row_index` is not smaller than the number of sampled rows.
    fn resolve_row(&self, row_index: usize) -> usize {
        self.row_indices[row_index]
    }
}
478
// Row-addressed accessors translate the view-local row index through
// `resolve_row`; column-level metadata is forwarded to the base table as is.
impl TableAccess for SampledTable<'_> {
    /// Number of sampled rows, not the base table's row count.
    fn n_rows(&self) -> usize {
        self.row_indices.len()
    }

    fn n_features(&self) -> usize {
        self.base.n_features()
    }

    fn canaries(&self) -> usize {
        self.base.canaries()
    }

    fn numeric_bin_cap(&self) -> usize {
        self.base.numeric_bin_cap()
    }

    fn binned_feature_count(&self) -> usize {
        self.base.binned_feature_count()
    }

    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        self.base
            .feature_value(feature_index, self.resolve_row(row_index))
    }

    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
        self.base
            .is_missing(feature_index, self.resolve_row(row_index))
    }

    fn is_binary_feature(&self, index: usize) -> bool {
        self.base.is_binary_feature(index)
    }

    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        self.base
            .binned_value(feature_index, self.resolve_row(row_index))
    }

    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        self.base
            .binned_boolean_value(feature_index, self.resolve_row(row_index))
    }

    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
        self.base.binned_column_kind(index)
    }

    fn is_binary_binned_feature(&self, index: usize) -> bool {
        self.base.is_binary_binned_feature(index)
    }

    /// Target of the base row backing view row `row_index`.
    fn target_value(&self, row_index: usize) -> f64 {
        self.base.target_value(self.resolve_row(row_index))
    }
}
536
537fn validate_boosting_parameters(
538    train_set: &dyn TableAccess,
539    learning_rate: f64,
540    top_gradient_fraction: f64,
541    other_gradient_fraction: f64,
542) -> Result<(), BoostingError> {
543    if train_set.n_rows() == 0 {
544        return Err(BoostingError::InvalidLearningRate(learning_rate));
545    }
546    if !learning_rate.is_finite() || learning_rate <= 0.0 {
547        return Err(BoostingError::InvalidLearningRate(learning_rate));
548    }
549    if !top_gradient_fraction.is_finite()
550        || top_gradient_fraction <= 0.0
551        || top_gradient_fraction > 1.0
552    {
553        return Err(BoostingError::InvalidTopGradientFraction(
554            top_gradient_fraction,
555        ));
556    }
557    if !other_gradient_fraction.is_finite()
558        || !(0.0..1.0).contains(&other_gradient_fraction)
559        || top_gradient_fraction + other_gradient_fraction > 1.0
560    {
561        return Err(BoostingError::InvalidOtherGradientFraction(
562            other_gradient_fraction,
563        ));
564    }
565    Ok(())
566}
567
568fn finite_targets(train_set: &dyn TableAccess) -> Result<Vec<f64>, BoostingError> {
569    (0..train_set.n_rows())
570        .map(|row_index| {
571            let value = train_set.target_value(row_index);
572            if value.is_finite() {
573                Ok(value)
574            } else {
575                Err(BoostingError::InvalidTargetValue {
576                    row: row_index,
577                    value,
578                })
579            }
580        })
581        .collect()
582}
583
584fn binary_classification_targets(
585    train_set: &dyn TableAccess,
586) -> Result<(Vec<f64>, Vec<f64>), BoostingError> {
587    let mut labels = finite_targets(train_set)?;
588    labels.sort_by(|left, right| left.total_cmp(right));
589    labels.dedup_by(|left, right| left.total_cmp(right).is_eq());
590    if labels.len() != 2 {
591        return Err(BoostingError::UnsupportedClassificationClassCount(
592            labels.len(),
593        ));
594    }
595
596    let negative = labels[0];
597    let encoded = (0..train_set.n_rows())
598        .map(|row_index| {
599            if train_set
600                .target_value(row_index)
601                .total_cmp(&negative)
602                .is_eq()
603            {
604                0.0
605            } else {
606                1.0
607            }
608        })
609        .collect();
610    Ok((labels, encoded))
611}
612
/// Gradients and hessians of the squared-error loss.
///
/// The gradient of each row is `prediction - target`; every hessian is `1.0`.
/// Gradients cover the rows present in both slices, while the hessian vector
/// always has `targets.len()` entries.
fn squared_error_gradients_and_hessians(
    raw_predictions: &[f64],
    targets: &[f64],
) -> (Vec<f64>, Vec<f64>) {
    let mut residuals = Vec::with_capacity(raw_predictions.len().min(targets.len()));
    for (&prediction, &target) in raw_predictions.iter().zip(targets) {
        residuals.push(prediction - target);
    }
    let unit_hessians = vec![1.0; targets.len()];
    (residuals, unit_hessians)
}
626
627fn logistic_gradients_and_hessians(
628    raw_predictions: &[f64],
629    targets: &[f64],
630) -> (Vec<f64>, Vec<f64>) {
631    let mut gradients = Vec::with_capacity(targets.len());
632    let mut hessians = Vec::with_capacity(targets.len());
633    for (raw_prediction, target) in raw_predictions.iter().zip(targets.iter()) {
634        let probability = sigmoid(*raw_prediction);
635        gradients.push(probability - target);
636        hessians.push((probability * (1.0 - probability)).max(1e-12));
637    }
638    (gradients, hessians)
639}
640
/// Logistic function `1 / (1 + e^(-value))`.
///
/// Only `e^(-|value|)` is ever evaluated, so the exponential cannot overflow
/// for inputs of large magnitude.
fn sigmoid(value: f64) -> f64 {
    let decay = (-value.abs()).exp();
    let denominator = 1.0 + decay;
    if value >= 0.0 {
        1.0 / denominator
    } else {
        decay / denominator
    }
}
650
651fn gradient_focus_sample(
652    base_rows: &[usize],
653    gradients: &[f64],
654    hessians: &[f64],
655    top_gradient_fraction: f64,
656    other_gradient_fraction: f64,
657    seed: u64,
658) -> GradientFocusedSample {
659    let mut ranked = base_rows
660        .iter()
661        .copied()
662        .map(|row_index| (row_index, gradients[row_index].abs()))
663        .collect::<Vec<_>>();
664    ranked.sort_by(|(left_row, left_abs), (right_row, right_abs)| {
665        right_abs
666            .total_cmp(left_abs)
667            .then_with(|| left_row.cmp(right_row))
668    });
669
670    let top_count = ((ranked.len() as f64) * top_gradient_fraction)
671        .ceil()
672        .clamp(1.0, ranked.len() as f64) as usize;
673    let mut row_indices = Vec::with_capacity(ranked.len());
674    let mut sampled_gradients = Vec::with_capacity(ranked.len());
675    let mut sampled_hessians = Vec::with_capacity(ranked.len());
676
677    for (row_index, _) in ranked.iter().take(top_count) {
678        row_indices.push(*row_index);
679        sampled_gradients.push(gradients[*row_index]);
680        sampled_hessians.push(hessians[*row_index]);
681    }
682
683    if top_count < ranked.len() && other_gradient_fraction > 0.0 {
684        let remaining = ranked[top_count..]
685            .iter()
686            .map(|(row_index, _)| *row_index)
687            .collect::<Vec<_>>();
688        let other_count = ((remaining.len() as f64) * other_gradient_fraction)
689            .ceil()
690            .min(remaining.len() as f64) as usize;
691        if other_count > 0 {
692            let mut remaining = remaining;
693            let mut rng = StdRng::seed_from_u64(seed);
694            remaining.shuffle(&mut rng);
695            let gradient_scale = (1.0 - top_gradient_fraction) / other_gradient_fraction;
696            for row_index in remaining.into_iter().take(other_count) {
697                row_indices.push(row_index);
698                sampled_gradients.push(gradients[row_index] * gradient_scale);
699                sampled_hessians.push(hessians[row_index] * gradient_scale);
700            }
701        }
702    }
703
704    GradientFocusedSample {
705        row_indices,
706        gradients: sampled_gradients,
707        hessians: sampled_hessians,
708    }
709}
710
711#[cfg(test)]
712mod tests {
713    use super::*;
714    use crate::{CanaryFilter, MaxFeatures, TrainAlgorithm, TrainConfig};
715    use forestfire_data::{BinnedColumnKind, TableAccess};
716    use forestfire_data::{DenseTable, NumericBins};
717
    /// Regression on a target equal to the single feature: predictions must
    /// increase across the three feature levels.
    #[test]
    fn regression_boosting_fits_simple_signal() {
        // Six rows, one feature taking the values 0, 1, 2 (two rows each);
        // the target equals the feature value.
        let table = DenseTable::with_options(
            vec![
                vec![0.0],
                vec![0.0],
                vec![1.0],
                vec![1.0],
                vec![2.0],
                vec![2.0],
            ],
            vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
            0,
            NumericBins::fixed(8).unwrap(),
        )
        .unwrap();

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Regression,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(20),
                learning_rate: Some(0.2),
                max_depth: Some(2),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        // Only the ordering is asserted, not exact values.
        let predictions = model.predict_table(&table);
        assert!(predictions[0] < predictions[2]);
        assert!(predictions[2] < predictions[4]);
    }
755
    /// Binary classification on a separable feature: one probability pair per
    /// row, with the positive probability on the correct side of 0.5 for the
    /// first and last rows.
    #[test]
    fn classification_boosting_produces_binary_probabilities() {
        // Low feature values carry label 0, high feature values label 1.
        let table = DenseTable::with_options(
            vec![vec![0.0], vec![0.1], vec![0.9], vec![1.0]],
            vec![0.0, 0.0, 1.0, 1.0],
            0,
            NumericBins::fixed(8).unwrap(),
        )
        .unwrap();

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Classification,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(25),
                learning_rate: Some(0.2),
                max_depth: Some(2),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        // Index 1 of each pair is the positive-class probability.
        let probabilities = model.predict_proba_table(&table).unwrap();
        assert_eq!(probabilities.len(), 4);
        assert!(probabilities[0][1] < 0.5);
        assert!(probabilities[3][1] > 0.5);
    }
787
    /// Three distinct target values must be rejected with the class count.
    #[test]
    fn classification_boosting_rejects_multiclass_targets() {
        let table =
            DenseTable::new(vec![vec![0.0], vec![1.0], vec![2.0]], vec![0.0, 1.0, 2.0]).unwrap();

        let error = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Classification,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap_err();

        assert!(matches!(
            error,
            BoostingError::UnsupportedClassificationClassCount(3)
        ));
    }
811
    /// Mock table whose only informative binned column is a canary.
    struct RootCanaryTable;
813
    /// Mock table with a partially informative real column next to a canary
    /// that matches the target exactly.
    struct FilteredRootCanaryTable;
815
    // Four rows, one real feature and one canary column. The real column is
    // constant, while the canary column is set exactly on the rows whose
    // target is 1.0.
    impl TableAccess for RootCanaryTable {
        fn n_rows(&self) -> usize {
            4
        }

        fn n_features(&self) -> usize {
            1
        }

        fn canaries(&self) -> usize {
            1
        }

        fn numeric_bin_cap(&self) -> usize {
            2
        }

        // One real column plus one canary column.
        fn binned_feature_count(&self) -> usize {
            2
        }

        fn feature_value(&self, _feature_index: usize, _row_index: usize) -> f64 {
            0.0
        }

        fn is_missing(&self, _feature_index: usize, _row_index: usize) -> bool {
            false
        }

        fn is_binary_feature(&self, _index: usize) -> bool {
            true
        }

        fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
            match feature_index {
                // Real column: constant, so it offers no split.
                0 => 0,
                // Canary column: 1 for rows 2 and 3, matching the targets.
                1 => u16::from(row_index >= 2),
                _ => unreachable!(),
            }
        }

        fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
            Some(match feature_index {
                0 => false,
                1 => row_index >= 2,
                _ => unreachable!(),
            })
        }

        fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
            match index {
                0 => BinnedColumnKind::Real { source_index: 0 },
                1 => BinnedColumnKind::Canary {
                    source_index: 0,
                    copy_index: 0,
                },
                _ => unreachable!(),
            }
        }

        fn is_binary_binned_feature(&self, _index: usize) -> bool {
            true
        }

        fn target_value(&self, row_index: usize) -> f64 {
            [0.0, 0.0, 1.0, 1.0][row_index]
        }
    }
884
    // Same shape as `RootCanaryTable`, but the real column is set for row 3
    // only, so it carries some signal while the canary column still matches
    // the targets exactly.
    impl TableAccess for FilteredRootCanaryTable {
        fn n_rows(&self) -> usize {
            4
        }

        fn n_features(&self) -> usize {
            1
        }

        fn canaries(&self) -> usize {
            1
        }

        fn numeric_bin_cap(&self) -> usize {
            2
        }

        // One real column plus one canary column.
        fn binned_feature_count(&self) -> usize {
            2
        }

        fn feature_value(&self, _feature_index: usize, _row_index: usize) -> f64 {
            0.0
        }

        fn is_missing(&self, _feature_index: usize, _row_index: usize) -> bool {
            false
        }

        fn is_binary_feature(&self, _index: usize) -> bool {
            true
        }

        // Binned values are derived from the boolean view so the two stay in
        // sync.
        fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
            u16::from(
                self.binned_boolean_value(feature_index, row_index)
                    .expect("all features are observed"),
            )
        }

        fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
            Some(match feature_index {
                // Real column: set only for the last row.
                0 => row_index == 3,
                // Canary column: set for rows 2 and 3, matching the targets.
                1 => row_index >= 2,
                _ => unreachable!(),
            })
        }

        fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
            match index {
                0 => BinnedColumnKind::Real { source_index: 0 },
                1 => BinnedColumnKind::Canary {
                    source_index: 0,
                    copy_index: 0,
                },
                _ => unreachable!(),
            }
        }

        fn is_binary_binned_feature(&self, _index: usize) -> bool {
            true
        }

        fn target_value(&self, row_index: usize) -> f64 {
            [0.0, 0.0, 1.0, 1.0][row_index]
        }
    }
952
    /// When the canary is the only informative column, training must stop
    /// before storing any tree and still yield finite predictions.
    #[test]
    fn boosting_stops_when_root_split_is_a_canary() {
        let table = RootCanaryTable;

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Regression,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(10),
                max_features: MaxFeatures::All,
                learning_rate: Some(0.1),
                // Keep every row at every stage, so row sampling plays no
                // part in the outcome.
                top_gradient_fraction: Some(1.0),
                other_gradient_fraction: Some(0.0),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        assert_eq!(model.trees().len(), 0);
        assert_eq!(model.training_metadata().n_trees, Some(0));
        // With no trees, predictions reduce to the base score.
        assert!(
            model
                .predict_table(&table)
                .iter()
                .all(|value| value.is_finite())
        );
    }
984
    /// With `CanaryFilter::TopN(2)` a real column that is weaker than the
    /// canary can still be used, so at least one tree is stored.
    // NOTE(review): presumably TopN(2) accepts a real split as long as it
    // ranks among the two best candidates — confirm against `CanaryFilter`.
    #[test]
    fn boosting_can_use_a_real_root_split_inside_top_n_canary_filter() {
        let table = FilteredRootCanaryTable;

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Regression,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(10),
                max_features: MaxFeatures::All,
                learning_rate: Some(0.1),
                // Keep every row at every stage, so row sampling plays no
                // part in the outcome.
                top_gradient_fraction: Some(1.0),
                other_gradient_fraction: Some(0.0),
                canary_filter: CanaryFilter::TopN(2),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        assert!(!model.trees().is_empty());
    }
1010}