1use crate::bootstrap::BootstrapSampler;
9use crate::ir::TrainingMetadata;
10use crate::tree::second_order::{
11 SecondOrderRegressionTreeError, SecondOrderRegressionTreeOptions,
12 train_cart_regressor_from_gradients_and_hessians_with_status,
13 train_oblivious_regressor_from_gradients_and_hessians_with_status,
14 train_randomized_regressor_from_gradients_and_hessians_with_status,
15};
16use crate::tree::shared::mix_seed;
17use crate::{
18 Criterion, FeaturePreprocessing, Model, Parallelism, PredictError, Task, TrainConfig, TreeType,
19 capture_feature_preprocessing,
20};
21use forestfire_data::TableAccess;
22use rand::SeedableRng;
23use rand::rngs::StdRng;
24use rand::seq::SliceRandom;
25
/// A trained gradient-boosted ensemble of second-order regression trees.
///
/// A prediction starts from `base_score` and adds each stage tree's output
/// scaled by its entry in `tree_weights`.
#[derive(Debug, Clone)]
pub struct GradientBoostedTrees {
    // Regression or binary classification; fixed at training time.
    task: Task,
    // Base-learner kind (cart / randomized / oblivious).
    tree_type: TreeType,
    // Boosting stages, in the order they were trained.
    trees: Vec<Model>,
    // Per-stage multiplier applied to each tree's raw prediction
    // (the learning rate at the time the stage was added).
    tree_weights: Vec<f64>,
    // Initial raw score: target mean for regression, positive-class
    // log-odds for classification.
    base_score: f64,
    learning_rate: f64,
    // Whether each stage drew a bootstrap sample of rows.
    bootstrap: bool,
    // Gradient-focused sampling fractions used during training
    // (largest-|gradient| share kept, random share of the rest).
    top_gradient_fraction: f64,
    other_gradient_fraction: f64,
    max_features: usize,
    seed: Option<u64>,
    num_features: usize,
    feature_preprocessing: Vec<FeaturePreprocessing>,
    // The two original class labels, sorted ascending; None for regression.
    class_labels: Option<Vec<f64>>,
    // Number of canary features present in the training table.
    training_canaries: usize,
}
48
/// Errors surfaced while training a gradient-boosted model.
#[derive(Debug)]
pub enum BoostingError {
    /// A target value was NaN or infinite at the given row.
    InvalidTargetValue { row: usize, value: f64 },
    /// Classification boosting supports exactly two classes; this carries the
    /// number of distinct labels actually found.
    UnsupportedClassificationClassCount(usize),
    /// `learning_rate` was not finite-and-positive (also reused to reject an
    /// empty training set — see `validate_boosting_parameters`).
    InvalidLearningRate(f64),
    /// `top_gradient_fraction` fell outside (0, 1].
    InvalidTopGradientFraction(f64),
    /// `other_gradient_fraction` fell outside [0, 1) or the two fractions
    /// summed to more than 1.
    InvalidOtherGradientFraction(f64),
    /// The underlying second-order tree trainer failed.
    SecondOrderTree(SecondOrderRegressionTreeError),
}
58
59impl std::fmt::Display for BoostingError {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 BoostingError::InvalidTargetValue { row, value } => write!(
63 f,
64 "Boosting targets must be finite values. Found {} at row {}.",
65 value, row
66 ),
67 BoostingError::UnsupportedClassificationClassCount(count) => write!(
68 f,
69 "Gradient boosting currently supports binary classification only. Found {} classes.",
70 count
71 ),
72 BoostingError::InvalidLearningRate(value) => write!(
73 f,
74 "learning_rate must be finite and greater than 0. Found {}.",
75 value
76 ),
77 BoostingError::InvalidTopGradientFraction(value) => write!(
78 f,
79 "top_gradient_fraction must be in the interval (0, 1]. Found {}.",
80 value
81 ),
82 BoostingError::InvalidOtherGradientFraction(value) => write!(
83 f,
84 "other_gradient_fraction must be in the interval [0, 1), and top_gradient_fraction + other_gradient_fraction must be at most 1. Found {}.",
85 value
86 ),
87 BoostingError::SecondOrderTree(err) => err.fmt(f),
88 }
89 }
90}
91
// Marker impl: the message lives in `Display`, and there is no underlying
// `source` to expose (the tree error is forwarded through `Display` instead).
impl std::error::Error for BoostingError {}
93
/// A row-subsetting view over another table: row `i` of this view is row
/// `row_indices[i]` of `base`. Indices need not be unique or ordered, so the
/// same underlying row may appear multiple times.
struct SampledTable<'a> {
    base: &'a dyn TableAccess,
    row_indices: Vec<usize>,
}
98
impl GradientBoostedTrees {
    /// Assembles a model from pre-computed parts. Performs no validation; the
    /// training routines are responsible for supplying consistent values.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        task: Task,
        tree_type: TreeType,
        trees: Vec<Model>,
        tree_weights: Vec<f64>,
        base_score: f64,
        learning_rate: f64,
        bootstrap: bool,
        top_gradient_fraction: f64,
        other_gradient_fraction: f64,
        max_features: usize,
        seed: Option<u64>,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
        class_labels: Option<Vec<f64>>,
        training_canaries: usize,
    ) -> Self {
        Self {
            task,
            tree_type,
            trees,
            tree_weights,
            base_score,
            learning_rate,
            bootstrap,
            top_gradient_fraction,
            other_gradient_fraction,
            max_features,
            seed,
            num_features,
            feature_preprocessing,
            class_labels,
            training_canaries,
        }
    }

    /// Trains a boosted ensemble, first resolving the per-feature
    /// missing-value strategies from `config`.
    ///
    /// # Panics
    /// Panics if the missing-value strategy cannot be resolved for the
    /// table's binned feature count; that is treated as a caller bug rather
    /// than a recoverable training error.
    #[allow(dead_code)]
    pub(crate) fn train(
        train_set: &dyn TableAccess,
        config: TrainConfig,
        parallelism: Parallelism,
    ) -> Result<Self, BoostingError> {
        let missing_value_strategies = config
            .missing_value_strategy
            .resolve_for_feature_count(train_set.binned_feature_count())
            .unwrap_or_else(|err| {
                panic!("unexpected training error while resolving missing strategy: {err}")
            });
        Self::train_with_missing_value_strategies(
            train_set,
            config,
            parallelism,
            missing_value_strategies,
        )
    }

    /// Core boosting loop.
    ///
    /// Each stage:
    /// 1. computes per-row gradients/hessians of the task loss at the current
    ///    raw predictions,
    /// 2. selects rows (optional bootstrap resample, then gradient-focused
    ///    sampling),
    /// 3. fits one second-order regression tree to those rows, and
    /// 4. adds `learning_rate * tree_prediction` to every row's raw score.
    ///
    /// Training stops early (keeping the stages built so far) if a stage's
    /// root split selects a canary feature.
    pub(crate) fn train_with_missing_value_strategies(
        train_set: &dyn TableAccess,
        config: TrainConfig,
        parallelism: Parallelism,
        missing_value_strategies: Vec<crate::MissingValueStrategy>,
    ) -> Result<Self, BoostingError> {
        let n_trees = config.n_trees.unwrap_or(100);
        let learning_rate = config.learning_rate.unwrap_or(0.1);
        let bootstrap = config.bootstrap;
        let top_gradient_fraction = config.top_gradient_fraction.unwrap_or(0.2);
        let other_gradient_fraction = config.other_gradient_fraction.unwrap_or(0.1);
        validate_boosting_parameters(
            train_set,
            learning_rate,
            top_gradient_fraction,
            other_gradient_fraction,
        )?;

        let max_features = config
            .max_features
            .resolve(config.task, train_set.binned_feature_count());
        let base_seed = config.seed.unwrap_or(0xB005_7EED_u64);
        // `random_seed: 0` is a placeholder; every stage overwrites it with a
        // per-stage seed derived from `base_seed` inside the loop below.
        let tree_options = crate::RegressionTreeOptions {
            max_depth: config.max_depth.unwrap_or(8),
            min_samples_split: config.min_samples_split.unwrap_or(2),
            min_samples_leaf: config.min_samples_leaf.unwrap_or(1),
            max_features: Some(max_features),
            random_seed: 0,
            missing_value_strategies,
        };
        let tree_options = SecondOrderRegressionTreeOptions {
            tree_options,
            l2_regularization: 1.0,
            min_sum_hessian_in_leaf: 1e-3,
            min_gain_to_split: 0.0,
        };
        let feature_preprocessing = capture_feature_preprocessing(train_set);
        let sampler = BootstrapSampler::new(train_set.n_rows());

        // Initialize every row's raw score with the task's base score: the
        // target mean for regression, or the positive-class log-odds
        // (clamped away from 0 and 1) for binary classification.
        let (mut raw_predictions, class_labels, base_score) = match config.task {
            Task::Regression => {
                let targets = finite_targets(train_set)?;
                let base_score = targets.iter().sum::<f64>() / targets.len() as f64;
                (vec![base_score; train_set.n_rows()], None, base_score)
            }
            Task::Classification => {
                let (labels, encoded_targets) = binary_classification_targets(train_set)?;
                let positive_rate = (encoded_targets.iter().sum::<f64>()
                    / encoded_targets.len() as f64)
                    .clamp(1e-6, 1.0 - 1e-6);
                let base_score = (positive_rate / (1.0 - positive_rate)).ln();
                (
                    vec![base_score; train_set.n_rows()],
                    Some(labels),
                    base_score,
                )
            }
        };

        let mut trees = Vec::with_capacity(n_trees);
        let mut tree_weights = Vec::with_capacity(n_trees);
        // Materialize the task's targets once for the stage loop; only the
        // matching Option is populated.
        let regression_targets = if config.task == Task::Regression {
            Some(finite_targets(train_set)?)
        } else {
            None
        };
        let classification_targets = if config.task == Task::Classification {
            Some(binary_classification_targets(train_set)?.1)
        } else {
            None
        };

        for tree_index in 0..n_trees {
            // Deterministic per-stage seed derived from the base seed.
            let stage_seed = mix_seed(base_seed, tree_index as u64);
            // Loss derivatives at the current raw predictions.
            let (gradients, hessians) = match config.task {
                Task::Regression => squared_error_gradients_and_hessians(
                    raw_predictions.as_slice(),
                    regression_targets
                        .as_ref()
                        .expect("regression targets exist for regression boosting"),
                ),
                Task::Classification => logistic_gradients_and_hessians(
                    raw_predictions.as_slice(),
                    classification_targets
                        .as_ref()
                        .expect("classification targets exist for classification boosting"),
                ),
            };

            // Row selection: optional bootstrap resample, then keep the
            // highest-|gradient| rows plus a rescaled random slice of the rest.
            let base_rows = if bootstrap {
                sampler.sample(stage_seed)
            } else {
                (0..train_set.n_rows()).collect()
            };
            let sampled_rows = gradient_focus_sample(
                &base_rows,
                &gradients,
                &hessians,
                top_gradient_fraction,
                other_gradient_fraction,
                mix_seed(stage_seed, 0x6011_5A11),
            );
            let sampled_table = SampledTable::new(train_set, sampled_rows.row_indices);
            let mut stage_tree_options = tree_options.clone();
            stage_tree_options.tree_options.random_seed = stage_seed;
            let stage_result = match config.tree_type {
                TreeType::Cart => train_cart_regressor_from_gradients_and_hessians_with_status(
                    &sampled_table,
                    &sampled_rows.gradients,
                    &sampled_rows.hessians,
                    parallelism,
                    stage_tree_options,
                ),
                TreeType::Randomized => {
                    train_randomized_regressor_from_gradients_and_hessians_with_status(
                        &sampled_table,
                        &sampled_rows.gradients,
                        &sampled_rows.hessians,
                        parallelism,
                        stage_tree_options,
                    )
                }
                TreeType::Oblivious => {
                    train_oblivious_regressor_from_gradients_and_hessians_with_status(
                        &sampled_table,
                        &sampled_rows.gradients,
                        &sampled_rows.hessians,
                        parallelism,
                        stage_tree_options,
                    )
                }
                _ => unreachable!("boosting tree type validated by training dispatch"),
            }
            .map_err(BoostingError::SecondOrderTree)?;

            // The stage tree's root split landed on a canary feature: stop
            // adding stages. The ensemble built so far is still returned.
            if stage_result.root_canary_selected {
                break;
            }

            let stage_tree = stage_result.model;
            let stage_model = Model::DecisionTreeRegressor(stage_tree);
            // Predict over the FULL training table (not the sampled view) so
            // raw scores stay index-aligned with every training row.
            let stage_predictions = stage_model.predict_table(train_set);
            for (raw_prediction, stage_prediction) in raw_predictions
                .iter_mut()
                .zip(stage_predictions.iter().copied())
            {
                *raw_prediction += learning_rate * stage_prediction;
            }
            tree_weights.push(learning_rate);
            trees.push(stage_model);
        }

        Ok(Self::new(
            config.task,
            config.tree_type,
            trees,
            tree_weights,
            base_score,
            learning_rate,
            bootstrap,
            top_gradient_fraction,
            other_gradient_fraction,
            max_features,
            config.seed,
            train_set.n_features(),
            feature_preprocessing,
            class_labels,
            train_set.canaries(),
        ))
    }

    /// Predicts one value per row: the raw score for regression, the decoded
    /// class label for classification.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        match self.task {
            Task::Regression => self.predict_regression_table(table),
            Task::Classification => self.predict_classification_table(table),
        }
    }

    /// Per-row `[P(negative), P(positive)]` via the logistic transform of the
    /// raw score.
    ///
    /// # Errors
    /// Returns `ProbabilityPredictionRequiresClassification` for regression
    /// models.
    pub fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        if self.task != Task::Classification {
            return Err(PredictError::ProbabilityPredictionRequiresClassification);
        }
        Ok(self
            .raw_scores(table)
            .into_iter()
            .map(|score| {
                let positive = sigmoid(score);
                vec![1.0 - positive, positive]
            })
            .collect())
    }

    /// Task this model was trained for.
    pub fn task(&self) -> Task {
        self.task
    }

    /// Boosting always trains its stage trees with the second-order criterion.
    pub fn criterion(&self) -> Criterion {
        Criterion::SecondOrder
    }

    /// Base-learner kind used for every stage.
    pub fn tree_type(&self) -> TreeType {
        self.tree_type
    }

    /// The trained stage trees, in training order.
    pub fn trees(&self) -> &[Model] {
        &self.trees
    }

    /// Per-stage weights, index-aligned with `trees()`.
    pub fn tree_weights(&self) -> &[f64] {
        &self.tree_weights
    }

    /// Initial raw score every prediction starts from.
    pub fn base_score(&self) -> f64 {
        self.base_score
    }

    /// Number of features in the training table.
    pub fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-feature preprocessing captured from the training table.
    pub fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// The two original class labels (ascending), or None for regression.
    pub fn class_labels(&self) -> Option<Vec<f64>> {
        self.class_labels.clone()
    }

    /// Summarizes the training configuration for serialization/reporting.
    /// Depth/split/leaf limits are read back from the first stage tree.
    pub fn training_metadata(&self) -> TrainingMetadata {
        TrainingMetadata {
            algorithm: "gbm".to_string(),
            task: match self.task {
                Task::Regression => "regression".to_string(),
                Task::Classification => "classification".to_string(),
            },
            tree_type: match self.tree_type {
                TreeType::Cart => "cart".to_string(),
                TreeType::Randomized => "randomized".to_string(),
                TreeType::Oblivious => "oblivious".to_string(),
                _ => unreachable!("boosting only supports cart/randomized/oblivious"),
            },
            criterion: "second_order".to_string(),
            canaries: self.training_canaries,
            compute_oob: false,
            max_depth: self.trees.first().and_then(Model::max_depth),
            min_samples_split: self.trees.first().and_then(Model::min_samples_split),
            min_samples_leaf: self.trees.first().and_then(Model::min_samples_leaf),
            // Actual count (may be less than configured if training stopped
            // early on a canary root split).
            n_trees: Some(self.trees.len()),
            max_features: Some(self.max_features),
            seed: self.seed,
            oob_score: None,
            class_labels: self.class_labels.clone(),
            learning_rate: Some(self.learning_rate),
            bootstrap: Some(self.bootstrap),
            top_gradient_fraction: Some(self.top_gradient_fraction),
            other_gradient_fraction: Some(self.other_gradient_fraction),
        }
    }

    /// Accumulates `base_score + sum_i(weight_i * tree_i(row))` per row.
    fn raw_scores(&self, table: &dyn TableAccess) -> Vec<f64> {
        let mut scores = vec![self.base_score; table.n_rows()];
        for (tree, weight) in self.trees.iter().zip(self.tree_weights.iter().copied()) {
            let predictions = tree.predict_table(table);
            for (score, prediction) in scores.iter_mut().zip(predictions.iter().copied()) {
                *score += weight * prediction;
            }
        }
        scores
    }

    /// Regression predictions are the raw scores themselves.
    fn predict_regression_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        self.raw_scores(table)
    }

    /// Thresholds the positive-class probability at 0.5 and maps back to the
    /// stored (ascending) class labels.
    fn predict_classification_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let class_labels = self
            .class_labels
            .as_ref()
            .expect("classification boosting stores class labels");
        self.raw_scores(table)
            .into_iter()
            .map(|score| {
                if sigmoid(score) >= 0.5 {
                    class_labels[1]
                } else {
                    class_labels[0]
                }
            })
            .collect()
    }
}
461
/// Output of `gradient_focus_sample`: the selected rows plus their (possibly
/// rescaled) gradients and hessians, all index-aligned with `row_indices`.
struct GradientFocusedSample {
    row_indices: Vec<usize>,
    gradients: Vec<f64>,
    hessians: Vec<f64>,
}
467
468impl<'a> SampledTable<'a> {
469 fn new(base: &'a dyn TableAccess, row_indices: Vec<usize>) -> Self {
470 Self { base, row_indices }
471 }
472
473 fn resolve_row(&self, row_index: usize) -> usize {
474 self.row_indices[row_index]
475 }
476}
477
// Table-view implementation: row-level queries are remapped through
// `resolve_row`; table-level and feature-level queries delegate unchanged.
impl TableAccess for SampledTable<'_> {
    fn n_rows(&self) -> usize {
        // The view exposes exactly the sampled rows (duplicates included).
        self.row_indices.len()
    }

    fn n_features(&self) -> usize {
        self.base.n_features()
    }

    fn canaries(&self) -> usize {
        self.base.canaries()
    }

    fn numeric_bin_cap(&self) -> usize {
        self.base.numeric_bin_cap()
    }

    fn binned_feature_count(&self) -> usize {
        self.base.binned_feature_count()
    }

    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        self.base
            .feature_value(feature_index, self.resolve_row(row_index))
    }

    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
        self.base
            .is_missing(feature_index, self.resolve_row(row_index))
    }

    fn is_binary_feature(&self, index: usize) -> bool {
        self.base.is_binary_feature(index)
    }

    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        self.base
            .binned_value(feature_index, self.resolve_row(row_index))
    }

    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        self.base
            .binned_boolean_value(feature_index, self.resolve_row(row_index))
    }

    fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
        self.base.binned_column_kind(index)
    }

    fn is_binary_binned_feature(&self, index: usize) -> bool {
        self.base.is_binary_binned_feature(index)
    }

    fn target_value(&self, row_index: usize) -> f64 {
        self.base.target_value(self.resolve_row(row_index))
    }
}
535
536fn validate_boosting_parameters(
537 train_set: &dyn TableAccess,
538 learning_rate: f64,
539 top_gradient_fraction: f64,
540 other_gradient_fraction: f64,
541) -> Result<(), BoostingError> {
542 if train_set.n_rows() == 0 {
543 return Err(BoostingError::InvalidLearningRate(learning_rate));
544 }
545 if !learning_rate.is_finite() || learning_rate <= 0.0 {
546 return Err(BoostingError::InvalidLearningRate(learning_rate));
547 }
548 if !top_gradient_fraction.is_finite()
549 || top_gradient_fraction <= 0.0
550 || top_gradient_fraction > 1.0
551 {
552 return Err(BoostingError::InvalidTopGradientFraction(
553 top_gradient_fraction,
554 ));
555 }
556 if !other_gradient_fraction.is_finite()
557 || !(0.0..1.0).contains(&other_gradient_fraction)
558 || top_gradient_fraction + other_gradient_fraction > 1.0
559 {
560 return Err(BoostingError::InvalidOtherGradientFraction(
561 other_gradient_fraction,
562 ));
563 }
564 Ok(())
565}
566
567fn finite_targets(train_set: &dyn TableAccess) -> Result<Vec<f64>, BoostingError> {
568 (0..train_set.n_rows())
569 .map(|row_index| {
570 let value = train_set.target_value(row_index);
571 if value.is_finite() {
572 Ok(value)
573 } else {
574 Err(BoostingError::InvalidTargetValue {
575 row: row_index,
576 value,
577 })
578 }
579 })
580 .collect()
581}
582
583fn binary_classification_targets(
584 train_set: &dyn TableAccess,
585) -> Result<(Vec<f64>, Vec<f64>), BoostingError> {
586 let mut labels = finite_targets(train_set)?;
587 labels.sort_by(|left, right| left.total_cmp(right));
588 labels.dedup_by(|left, right| left.total_cmp(right).is_eq());
589 if labels.len() != 2 {
590 return Err(BoostingError::UnsupportedClassificationClassCount(
591 labels.len(),
592 ));
593 }
594
595 let negative = labels[0];
596 let encoded = (0..train_set.n_rows())
597 .map(|row_index| {
598 if train_set
599 .target_value(row_index)
600 .total_cmp(&negative)
601 .is_eq()
602 {
603 0.0
604 } else {
605 1.0
606 }
607 })
608 .collect();
609 Ok((labels, encoded))
610}
611
/// Derivatives of the squared-error loss at the current raw predictions:
/// gradient = prediction - target, hessian = 1 for every row.
fn squared_error_gradients_and_hessians(
    raw_predictions: &[f64],
    targets: &[f64],
) -> (Vec<f64>, Vec<f64>) {
    let mut gradients = Vec::with_capacity(targets.len());
    for (prediction, target) in raw_predictions.iter().zip(targets) {
        gradients.push(prediction - target);
    }
    (gradients, vec![1.0; targets.len()])
}
625
626fn logistic_gradients_and_hessians(
627 raw_predictions: &[f64],
628 targets: &[f64],
629) -> (Vec<f64>, Vec<f64>) {
630 let mut gradients = Vec::with_capacity(targets.len());
631 let mut hessians = Vec::with_capacity(targets.len());
632 for (raw_prediction, target) in raw_predictions.iter().zip(targets.iter()) {
633 let probability = sigmoid(*raw_prediction);
634 gradients.push(probability - target);
635 hessians.push((probability * (1.0 - probability)).max(1e-12));
636 }
637 (gradients, hessians)
638}
639
/// Numerically stable logistic function: the exponential is always taken of
/// a non-positive argument, so it can never overflow to infinity.
fn sigmoid(value: f64) -> f64 {
    if value < 0.0 {
        let exp = value.exp();
        exp / (1.0 + exp)
    } else {
        1.0 / (1.0 + (-value).exp())
    }
}
649
650fn gradient_focus_sample(
651 base_rows: &[usize],
652 gradients: &[f64],
653 hessians: &[f64],
654 top_gradient_fraction: f64,
655 other_gradient_fraction: f64,
656 seed: u64,
657) -> GradientFocusedSample {
658 let mut ranked = base_rows
659 .iter()
660 .copied()
661 .map(|row_index| (row_index, gradients[row_index].abs()))
662 .collect::<Vec<_>>();
663 ranked.sort_by(|(left_row, left_abs), (right_row, right_abs)| {
664 right_abs
665 .total_cmp(left_abs)
666 .then_with(|| left_row.cmp(right_row))
667 });
668
669 let top_count = ((ranked.len() as f64) * top_gradient_fraction)
670 .ceil()
671 .clamp(1.0, ranked.len() as f64) as usize;
672 let mut row_indices = Vec::with_capacity(ranked.len());
673 let mut sampled_gradients = Vec::with_capacity(ranked.len());
674 let mut sampled_hessians = Vec::with_capacity(ranked.len());
675
676 for (row_index, _) in ranked.iter().take(top_count) {
677 row_indices.push(*row_index);
678 sampled_gradients.push(gradients[*row_index]);
679 sampled_hessians.push(hessians[*row_index]);
680 }
681
682 if top_count < ranked.len() && other_gradient_fraction > 0.0 {
683 let remaining = ranked[top_count..]
684 .iter()
685 .map(|(row_index, _)| *row_index)
686 .collect::<Vec<_>>();
687 let other_count = ((remaining.len() as f64) * other_gradient_fraction)
688 .ceil()
689 .min(remaining.len() as f64) as usize;
690 if other_count > 0 {
691 let mut remaining = remaining;
692 let mut rng = StdRng::seed_from_u64(seed);
693 remaining.shuffle(&mut rng);
694 let gradient_scale = (1.0 - top_gradient_fraction) / other_gradient_fraction;
695 for row_index in remaining.into_iter().take(other_count) {
696 row_indices.push(row_index);
697 sampled_gradients.push(gradients[row_index] * gradient_scale);
698 sampled_hessians.push(hessians[row_index] * gradient_scale);
699 }
700 }
701 }
702
703 GradientFocusedSample {
704 row_indices,
705 gradients: sampled_gradients,
706 hessians: sampled_hessians,
707 }
708}
709
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MaxFeatures, TrainAlgorithm, TrainConfig};
    use forestfire_data::{BinnedColumnKind, TableAccess};
    use forestfire_data::{DenseTable, NumericBins};

    // Boosted regression on a monotone 1-D signal should at least recover the
    // ordering of the targets, even with shallow trees.
    #[test]
    fn regression_boosting_fits_simple_signal() {
        let table = DenseTable::with_options(
            vec![
                vec![0.0],
                vec![0.0],
                vec![1.0],
                vec![1.0],
                vec![2.0],
                vec![2.0],
            ],
            vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
            0,
            NumericBins::fixed(8).unwrap(),
        )
        .unwrap();

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Regression,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(20),
                learning_rate: Some(0.2),
                max_depth: Some(2),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        let predictions = model.predict_table(&table);
        // Only the ordering is checked; exact values depend on tree structure.
        assert!(predictions[0] < predictions[2]);
        assert!(predictions[2] < predictions[4]);
    }

    // Binary classification should emit two-column probabilities that
    // separate the clearly-negative and clearly-positive rows at 0.5.
    #[test]
    fn classification_boosting_produces_binary_probabilities() {
        let table = DenseTable::with_options(
            vec![vec![0.0], vec![0.1], vec![0.9], vec![1.0]],
            vec![0.0, 0.0, 1.0, 1.0],
            0,
            NumericBins::fixed(8).unwrap(),
        )
        .unwrap();

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Classification,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(25),
                learning_rate: Some(0.2),
                max_depth: Some(2),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        let probabilities = model.predict_proba_table(&table).unwrap();
        assert_eq!(probabilities.len(), 4);
        // Column 1 is the positive-class probability.
        assert!(probabilities[0][1] < 0.5);
        assert!(probabilities[3][1] > 0.5);
    }

    // Three distinct labels must be rejected with the class-count error.
    #[test]
    fn classification_boosting_rejects_multiclass_targets() {
        let table =
            DenseTable::new(vec![vec![0.0], vec![1.0], vec![2.0]], vec![0.0, 1.0, 2.0]).unwrap();

        let error = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Classification,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap_err();

        assert!(matches!(
            error,
            BoostingError::UnsupportedClassificationClassCount(3)
        ));
    }

    // Fixture: 4 rows with one constant real feature (binned column 0) and one
    // canary binned column (1) whose values split rows 0-1 from 2-3 — exactly
    // matching the targets. The constant real feature offers no gain, so any
    // root split the trainer finds must use the canary column.
    struct RootCanaryTable;

    impl TableAccess for RootCanaryTable {
        fn n_rows(&self) -> usize {
            4
        }

        fn n_features(&self) -> usize {
            1
        }

        fn canaries(&self) -> usize {
            1
        }

        fn numeric_bin_cap(&self) -> usize {
            2
        }

        fn binned_feature_count(&self) -> usize {
            2
        }

        // The single real feature is constant, so it carries no signal.
        fn feature_value(&self, _feature_index: usize, _row_index: usize) -> f64 {
            0.0
        }

        fn is_missing(&self, _feature_index: usize, _row_index: usize) -> bool {
            false
        }

        fn is_binary_feature(&self, _index: usize) -> bool {
            true
        }

        // Column 0: constant. Column 1 (canary): 0 for rows 0-1, 1 for 2-3.
        fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
            match feature_index {
                0 => 0,
                1 => u16::from(row_index >= 2),
                _ => unreachable!(),
            }
        }

        fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
            Some(match feature_index {
                0 => false,
                1 => row_index >= 2,
                _ => unreachable!(),
            })
        }

        fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
            match index {
                0 => BinnedColumnKind::Real { source_index: 0 },
                1 => BinnedColumnKind::Canary {
                    source_index: 0,
                    copy_index: 0,
                },
                _ => unreachable!(),
            }
        }

        fn is_binary_binned_feature(&self, _index: usize) -> bool {
            true
        }

        // Targets line up with the canary column's split.
        fn target_value(&self, row_index: usize) -> f64 {
            [0.0, 0.0, 1.0, 1.0][row_index]
        }
    }

    // When the first stage's root split selects a canary feature, boosting
    // must stop without adding any trees — yet still return a usable model
    // that predicts the finite base score.
    #[test]
    fn boosting_stops_when_root_split_is_a_canary() {
        let table = RootCanaryTable;

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Regression,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(10),
                max_features: MaxFeatures::All,
                learning_rate: Some(0.1),
                // Keep all rows so sampling cannot hide the canary column.
                top_gradient_fraction: Some(1.0),
                other_gradient_fraction: Some(0.0),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        assert_eq!(model.trees().len(), 0);
        assert_eq!(model.training_metadata().n_trees, Some(0));
        assert!(
            model
                .predict_table(&table)
                .iter()
                .all(|value| value.is_finite())
        );
    }
}