1use crate::bootstrap::BootstrapSampler;
9use crate::ir::TrainingMetadata;
10use crate::tree::second_order::{
11 SecondOrderRegressionTreeError, SecondOrderRegressionTreeOptions,
12 train_cart_regressor_from_gradients_and_hessians_with_status,
13 train_oblivious_regressor_from_gradients_and_hessians_with_status,
14 train_randomized_regressor_from_gradients_and_hessians_with_status,
15};
16use crate::tree::shared::mix_seed;
17use crate::{
18 Criterion, FeaturePreprocessing, Model, Parallelism, PredictError, Task, TrainConfig, TreeType,
19 capture_feature_preprocessing,
20};
21use forestfire_data::TableAccess;
22use rand::SeedableRng;
23use rand::rngs::StdRng;
24use rand::seq::SliceRandom;
25
/// A trained gradient-boosted tree ensemble.
///
/// Predictions are `base_score + Σ tree_weights[i] * trees[i](x)`; for
/// classification the raw score is passed through a sigmoid (see
/// `predict_proba_table` / `predict_classification_table`).
#[derive(Debug, Clone)]
pub struct GradientBoostedTrees {
    // Regression or (binary) classification.
    task: Task,
    // Base learner family used for every stage (cart/randomized/oblivious).
    tree_type: TreeType,
    // One fitted regressor per completed boosting stage.
    trees: Vec<Model>,
    // Per-stage multipliers; training pushes the learning rate for each stage.
    tree_weights: Vec<f64>,
    // Initial raw prediction: target mean (regression) or log-odds (classification).
    base_score: f64,
    learning_rate: f64,
    // Whether each stage resampled rows with a BootstrapSampler.
    bootstrap: bool,
    // GOSS-style sampling fractions used during training (kept for metadata).
    top_gradient_fraction: f64,
    other_gradient_fraction: f64,
    max_features: usize,
    // Seed as supplied in the TrainConfig (None means the default was used).
    seed: Option<u64>,
    num_features: usize,
    feature_preprocessing: Vec<FeaturePreprocessing>,
    // Sorted distinct target values; Some only for classification models.
    class_labels: Option<Vec<f64>>,
    // Number of canary features present in the training table.
    training_canaries: usize,
}
48
/// Errors surfaced while validating parameters or fitting a boosted ensemble.
#[derive(Debug)]
pub enum BoostingError {
    /// A target value was NaN or infinite at the given row.
    InvalidTargetValue { row: usize, value: f64 },
    /// Classification boosting requires exactly two distinct labels.
    UnsupportedClassificationClassCount(usize),
    /// learning_rate must be finite and > 0.
    InvalidLearningRate(f64),
    /// top_gradient_fraction must lie in (0, 1].
    InvalidTopGradientFraction(f64),
    /// other_gradient_fraction must lie in [0, 1) and the two fractions must sum to at most 1.
    InvalidOtherGradientFraction(f64),
    /// Failure bubbled up from the per-stage second-order tree trainer.
    SecondOrderTree(SecondOrderRegressionTreeError),
}
58
59impl std::fmt::Display for BoostingError {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 BoostingError::InvalidTargetValue { row, value } => write!(
63 f,
64 "Boosting targets must be finite values. Found {} at row {}.",
65 value, row
66 ),
67 BoostingError::UnsupportedClassificationClassCount(count) => write!(
68 f,
69 "Gradient boosting currently supports binary classification only. Found {} classes.",
70 count
71 ),
72 BoostingError::InvalidLearningRate(value) => write!(
73 f,
74 "learning_rate must be finite and greater than 0. Found {}.",
75 value
76 ),
77 BoostingError::InvalidTopGradientFraction(value) => write!(
78 f,
79 "top_gradient_fraction must be in the interval (0, 1]. Found {}.",
80 value
81 ),
82 BoostingError::InvalidOtherGradientFraction(value) => write!(
83 f,
84 "other_gradient_fraction must be in the interval [0, 1), and top_gradient_fraction + other_gradient_fraction must be at most 1. Found {}.",
85 value
86 ),
87 BoostingError::SecondOrderTree(err) => err.fmt(f),
88 }
89 }
90}
91
// Marker impl: Display + Debug above satisfy the default Error methods.
impl std::error::Error for BoostingError {}
93
/// A row-remapped view over another table: row `i` of this view reads row
/// `row_indices[i]` of `base`. Used to train each boosting stage on the
/// gradient-focused sample without copying feature data.
struct SampledTable<'a> {
    base: &'a dyn TableAccess,
    row_indices: Vec<usize>,
}
98
impl GradientBoostedTrees {
    /// Assembles an ensemble from already-trained parts.
    /// Performs no validation or training; used by `train_*` and deserialization paths.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        task: Task,
        tree_type: TreeType,
        trees: Vec<Model>,
        tree_weights: Vec<f64>,
        base_score: f64,
        learning_rate: f64,
        bootstrap: bool,
        top_gradient_fraction: f64,
        other_gradient_fraction: f64,
        max_features: usize,
        seed: Option<u64>,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
        class_labels: Option<Vec<f64>>,
        training_canaries: usize,
    ) -> Self {
        Self {
            task,
            tree_type,
            trees,
            tree_weights,
            base_score,
            learning_rate,
            bootstrap,
            top_gradient_fraction,
            other_gradient_fraction,
            max_features,
            seed,
            num_features,
            feature_preprocessing,
            class_labels,
            training_canaries,
        }
    }

    /// Trains an ensemble, resolving per-feature missing-value strategies from
    /// the config first.
    ///
    /// Panics if the configured missing-value strategy cannot be resolved for
    /// the table's binned feature count (treated as a programming error here).
    #[allow(dead_code)]
    pub(crate) fn train(
        train_set: &dyn TableAccess,
        config: TrainConfig,
        parallelism: Parallelism,
    ) -> Result<Self, BoostingError> {
        let missing_value_strategies = config
            .missing_value_strategy
            .resolve_for_feature_count(train_set.binned_feature_count())
            .unwrap_or_else(|err| {
                panic!("unexpected training error while resolving missing strategy: {err}")
            });
        Self::train_with_missing_value_strategies(
            train_set,
            config,
            parallelism,
            missing_value_strategies,
        )
    }

    /// Core training loop: fits up to `n_trees` second-order regression trees
    /// on the gradients/hessians of the running raw predictions.
    ///
    /// Each stage is seeded deterministically from the config seed and the
    /// stage index, rows are optionally bootstrapped and then focused on
    /// large-gradient examples (`gradient_focus_sample`), and training stops
    /// early if a stage's root split selects a canary feature.
    pub(crate) fn train_with_missing_value_strategies(
        train_set: &dyn TableAccess,
        config: TrainConfig,
        parallelism: Parallelism,
        missing_value_strategies: Vec<crate::MissingValueStrategy>,
    ) -> Result<Self, BoostingError> {
        // Hyperparameter defaults when the config leaves them unset.
        let n_trees = config.n_trees.unwrap_or(100);
        let learning_rate = config.learning_rate.unwrap_or(0.1);
        let bootstrap = config.bootstrap;
        let top_gradient_fraction = config.top_gradient_fraction.unwrap_or(0.2);
        let other_gradient_fraction = config.other_gradient_fraction.unwrap_or(0.1);
        validate_boosting_parameters(
            train_set,
            learning_rate,
            top_gradient_fraction,
            other_gradient_fraction,
        )?;

        let max_features = config
            .max_features
            .resolve(config.task, train_set.n_features());
        let base_seed = config.seed.unwrap_or(0xB005_7EED_u64);
        // Per-stage base-tree options; random_seed is overwritten for each stage below.
        let tree_options = crate::RegressionTreeOptions {
            max_depth: config.max_depth.unwrap_or(8),
            min_samples_split: config.min_samples_split.unwrap_or(2),
            min_samples_leaf: config.min_samples_leaf.unwrap_or(1),
            max_features: Some(max_features),
            random_seed: 0,
            missing_value_strategies,
            canary_filter: config.canary_filter,
        };
        // Shadowing: wrap the plain tree options with second-order regularization knobs.
        let tree_options = SecondOrderRegressionTreeOptions {
            tree_options,
            l2_regularization: 1.0,
            min_sum_hessian_in_leaf: 1e-3,
            min_gain_to_split: 0.0,
        };
        let feature_preprocessing = capture_feature_preprocessing(train_set);
        let sampler = BootstrapSampler::new(train_set.n_rows());

        // Initialize raw predictions with the task-specific base score:
        // target mean for regression, log-odds of the positive rate (clamped
        // away from 0/1) for classification.
        let (mut raw_predictions, class_labels, base_score) = match config.task {
            Task::Regression => {
                let targets = finite_targets(train_set)?;
                let base_score = targets.iter().sum::<f64>() / targets.len() as f64;
                (vec![base_score; train_set.n_rows()], None, base_score)
            }
            Task::Classification => {
                let (labels, encoded_targets) = binary_classification_targets(train_set)?;
                let positive_rate = (encoded_targets.iter().sum::<f64>()
                    / encoded_targets.len() as f64)
                    .clamp(1e-6, 1.0 - 1e-6);
                let base_score = (positive_rate / (1.0 - positive_rate)).ln();
                (
                    vec![base_score; train_set.n_rows()],
                    Some(labels),
                    base_score,
                )
            }
        };

        let mut trees = Vec::with_capacity(n_trees);
        let mut tree_weights = Vec::with_capacity(n_trees);
        // NOTE(review): targets are extracted a second time here (finite_targets /
        // binary_classification_targets were already called for the base score
        // above) — redundant work, though harmless for correctness.
        let regression_targets = if config.task == Task::Regression {
            Some(finite_targets(train_set)?)
        } else {
            None
        };
        let classification_targets = if config.task == Task::Classification {
            Some(binary_classification_targets(train_set)?.1)
        } else {
            None
        };

        for tree_index in 0..n_trees {
            // Deterministic per-stage seed derived from the base seed.
            let stage_seed = mix_seed(base_seed, tree_index as u64);
            // First/second-order derivatives of the loss w.r.t. current raw predictions.
            let (gradients, hessians) = match config.task {
                Task::Regression => squared_error_gradients_and_hessians(
                    raw_predictions.as_slice(),
                    regression_targets
                        .as_ref()
                        .expect("regression targets exist for regression boosting"),
                ),
                Task::Classification => logistic_gradients_and_hessians(
                    raw_predictions.as_slice(),
                    classification_targets
                        .as_ref()
                        .expect("classification targets exist for classification boosting"),
                ),
            };

            // Candidate rows: bootstrap resample or the full row range.
            let base_rows = if bootstrap {
                sampler.sample(stage_seed)
            } else {
                (0..train_set.n_rows()).collect()
            };
            // GOSS-style focus: keep large-|gradient| rows plus a reweighted
            // random slice of the rest.
            let sampled_rows = gradient_focus_sample(
                &base_rows,
                &gradients,
                &hessians,
                top_gradient_fraction,
                other_gradient_fraction,
                mix_seed(stage_seed, 0x6011_5A11),
            );
            let sampled_table = SampledTable::new(train_set, sampled_rows.row_indices);
            let mut stage_tree_options = tree_options.clone();
            stage_tree_options.tree_options.random_seed = stage_seed;
            let stage_result = match config.tree_type {
                TreeType::Cart => train_cart_regressor_from_gradients_and_hessians_with_status(
                    &sampled_table,
                    &sampled_rows.gradients,
                    &sampled_rows.hessians,
                    parallelism,
                    stage_tree_options,
                ),
                TreeType::Randomized => {
                    train_randomized_regressor_from_gradients_and_hessians_with_status(
                        &sampled_table,
                        &sampled_rows.gradients,
                        &sampled_rows.hessians,
                        parallelism,
                        stage_tree_options,
                    )
                }
                TreeType::Oblivious => {
                    train_oblivious_regressor_from_gradients_and_hessians_with_status(
                        &sampled_table,
                        &sampled_rows.gradients,
                        &sampled_rows.hessians,
                        parallelism,
                        stage_tree_options,
                    )
                }
                // Other tree types are rejected before boosting is invoked.
                _ => unreachable!("boosting tree type validated by training dispatch"),
            }
            .map_err(BoostingError::SecondOrderTree)?;

            // A canary chosen at the root means remaining signal is noise-level:
            // stop boosting without keeping this stage.
            if stage_result.root_canary_selected {
                break;
            }

            let stage_tree = stage_result.model;
            let stage_model = Model::DecisionTreeRegressor(stage_tree);
            // Update raw predictions on ALL rows (not just the sampled ones).
            let stage_predictions = stage_model.predict_table(train_set);
            for (raw_prediction, stage_prediction) in raw_predictions
                .iter_mut()
                .zip(stage_predictions.iter().copied())
            {
                *raw_prediction += learning_rate * stage_prediction;
            }
            tree_weights.push(learning_rate);
            trees.push(stage_model);
        }

        Ok(Self::new(
            config.task,
            config.tree_type,
            trees,
            tree_weights,
            base_score,
            learning_rate,
            bootstrap,
            top_gradient_fraction,
            other_gradient_fraction,
            max_features,
            config.seed,
            train_set.n_features(),
            feature_preprocessing,
            class_labels,
            train_set.canaries(),
        ))
    }

    /// Predicts for every row of `table`: raw scores for regression,
    /// hard class labels for classification.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        match self.task {
            Task::Regression => self.predict_regression_table(table),
            Task::Classification => self.predict_classification_table(table),
        }
    }

    /// Per-row `[P(negative), P(positive)]` probabilities via the sigmoid of the
    /// raw score. Errors unless this is a classification model.
    pub fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        if self.task != Task::Classification {
            return Err(PredictError::ProbabilityPredictionRequiresClassification);
        }
        Ok(self
            .raw_scores(table)
            .into_iter()
            .map(|score| {
                let positive = sigmoid(score);
                vec![1.0 - positive, positive]
            })
            .collect())
    }

    pub fn task(&self) -> Task {
        self.task
    }

    /// Boosted stages are always trained with the second-order criterion.
    pub fn criterion(&self) -> Criterion {
        Criterion::SecondOrder
    }

    pub fn tree_type(&self) -> TreeType {
        self.tree_type
    }

    pub fn trees(&self) -> &[Model] {
        &self.trees
    }

    pub fn tree_weights(&self) -> &[f64] {
        &self.tree_weights
    }

    pub fn base_score(&self) -> f64 {
        self.base_score
    }

    pub fn num_features(&self) -> usize {
        self.num_features
    }

    pub fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    pub fn class_labels(&self) -> Option<Vec<f64>> {
        self.class_labels.clone()
    }

    /// Snapshot of the hyperparameters this ensemble was trained with.
    /// Depth/split/leaf limits are read back from the first stage tree, if any.
    pub fn training_metadata(&self) -> TrainingMetadata {
        TrainingMetadata {
            algorithm: "gbm".to_string(),
            task: match self.task {
                Task::Regression => "regression".to_string(),
                Task::Classification => "classification".to_string(),
            },
            tree_type: match self.tree_type {
                TreeType::Cart => "cart".to_string(),
                TreeType::Randomized => "randomized".to_string(),
                TreeType::Oblivious => "oblivious".to_string(),
                _ => unreachable!("boosting only supports cart/randomized/oblivious"),
            },
            criterion: "second_order".to_string(),
            canaries: self.training_canaries,
            compute_oob: false,
            max_depth: self.trees.first().and_then(Model::max_depth),
            min_samples_split: self.trees.first().and_then(Model::min_samples_split),
            min_samples_leaf: self.trees.first().and_then(Model::min_samples_leaf),
            n_trees: Some(self.trees.len()),
            max_features: Some(self.max_features),
            seed: self.seed,
            oob_score: None,
            class_labels: self.class_labels.clone(),
            learning_rate: Some(self.learning_rate),
            bootstrap: Some(self.bootstrap),
            top_gradient_fraction: Some(self.top_gradient_fraction),
            other_gradient_fraction: Some(self.other_gradient_fraction),
        }
    }

    /// Additive model evaluation: base score plus the weighted sum of every
    /// stage tree's predictions.
    fn raw_scores(&self, table: &dyn TableAccess) -> Vec<f64> {
        let mut scores = vec![self.base_score; table.n_rows()];
        for (tree, weight) in self.trees.iter().zip(self.tree_weights.iter().copied()) {
            let predictions = tree.predict_table(table);
            for (score, prediction) in scores.iter_mut().zip(predictions.iter().copied()) {
                *score += weight * prediction;
            }
        }
        scores
    }

    fn predict_regression_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        self.raw_scores(table)
    }

    /// Thresholds the sigmoid of the raw score at 0.5 and maps back to the
    /// original (sorted) class label values.
    fn predict_classification_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        let class_labels = self
            .class_labels
            .as_ref()
            .expect("classification boosting stores class labels");
        self.raw_scores(table)
            .into_iter()
            .map(|score| {
                if sigmoid(score) >= 0.5 {
                    class_labels[1]
                } else {
                    class_labels[0]
                }
            })
            .collect()
    }
}
462
/// Output of `gradient_focus_sample`: the selected base-table row indices and
/// the (possibly reweighted) gradients/hessians aligned with them.
struct GradientFocusedSample {
    row_indices: Vec<usize>,
    gradients: Vec<f64>,
    hessians: Vec<f64>,
}
468
469impl<'a> SampledTable<'a> {
470 fn new(base: &'a dyn TableAccess, row_indices: Vec<usize>) -> Self {
471 Self { base, row_indices }
472 }
473
474 fn resolve_row(&self, row_index: usize) -> usize {
475 self.row_indices[row_index]
476 }
477}
478
479impl TableAccess for SampledTable<'_> {
480 fn n_rows(&self) -> usize {
481 self.row_indices.len()
482 }
483
484 fn n_features(&self) -> usize {
485 self.base.n_features()
486 }
487
488 fn canaries(&self) -> usize {
489 self.base.canaries()
490 }
491
492 fn numeric_bin_cap(&self) -> usize {
493 self.base.numeric_bin_cap()
494 }
495
496 fn binned_feature_count(&self) -> usize {
497 self.base.binned_feature_count()
498 }
499
500 fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
501 self.base
502 .feature_value(feature_index, self.resolve_row(row_index))
503 }
504
505 fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
506 self.base
507 .is_missing(feature_index, self.resolve_row(row_index))
508 }
509
510 fn is_binary_feature(&self, index: usize) -> bool {
511 self.base.is_binary_feature(index)
512 }
513
514 fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
515 self.base
516 .binned_value(feature_index, self.resolve_row(row_index))
517 }
518
519 fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
520 self.base
521 .binned_boolean_value(feature_index, self.resolve_row(row_index))
522 }
523
524 fn binned_column_kind(&self, index: usize) -> forestfire_data::BinnedColumnKind {
525 self.base.binned_column_kind(index)
526 }
527
528 fn is_binary_binned_feature(&self, index: usize) -> bool {
529 self.base.is_binary_binned_feature(index)
530 }
531
532 fn target_value(&self, row_index: usize) -> f64 {
533 self.base.target_value(self.resolve_row(row_index))
534 }
535}
536
537fn validate_boosting_parameters(
538 train_set: &dyn TableAccess,
539 learning_rate: f64,
540 top_gradient_fraction: f64,
541 other_gradient_fraction: f64,
542) -> Result<(), BoostingError> {
543 if train_set.n_rows() == 0 {
544 return Err(BoostingError::InvalidLearningRate(learning_rate));
545 }
546 if !learning_rate.is_finite() || learning_rate <= 0.0 {
547 return Err(BoostingError::InvalidLearningRate(learning_rate));
548 }
549 if !top_gradient_fraction.is_finite()
550 || top_gradient_fraction <= 0.0
551 || top_gradient_fraction > 1.0
552 {
553 return Err(BoostingError::InvalidTopGradientFraction(
554 top_gradient_fraction,
555 ));
556 }
557 if !other_gradient_fraction.is_finite()
558 || !(0.0..1.0).contains(&other_gradient_fraction)
559 || top_gradient_fraction + other_gradient_fraction > 1.0
560 {
561 return Err(BoostingError::InvalidOtherGradientFraction(
562 other_gradient_fraction,
563 ));
564 }
565 Ok(())
566}
567
568fn finite_targets(train_set: &dyn TableAccess) -> Result<Vec<f64>, BoostingError> {
569 (0..train_set.n_rows())
570 .map(|row_index| {
571 let value = train_set.target_value(row_index);
572 if value.is_finite() {
573 Ok(value)
574 } else {
575 Err(BoostingError::InvalidTargetValue {
576 row: row_index,
577 value,
578 })
579 }
580 })
581 .collect()
582}
583
584fn binary_classification_targets(
585 train_set: &dyn TableAccess,
586) -> Result<(Vec<f64>, Vec<f64>), BoostingError> {
587 let mut labels = finite_targets(train_set)?;
588 labels.sort_by(|left, right| left.total_cmp(right));
589 labels.dedup_by(|left, right| left.total_cmp(right).is_eq());
590 if labels.len() != 2 {
591 return Err(BoostingError::UnsupportedClassificationClassCount(
592 labels.len(),
593 ));
594 }
595
596 let negative = labels[0];
597 let encoded = (0..train_set.n_rows())
598 .map(|row_index| {
599 if train_set
600 .target_value(row_index)
601 .total_cmp(&negative)
602 .is_eq()
603 {
604 0.0
605 } else {
606 1.0
607 }
608 })
609 .collect();
610 Ok((labels, encoded))
611}
612
/// Derivatives of the squared-error loss 0.5 * (prediction - target)^2:
/// gradient is `prediction - target`, hessian is constant 1.
fn squared_error_gradients_and_hessians(
    raw_predictions: &[f64],
    targets: &[f64],
) -> (Vec<f64>, Vec<f64>) {
    let mut gradients = Vec::with_capacity(targets.len());
    for (prediction, target) in raw_predictions.iter().zip(targets.iter()) {
        gradients.push(prediction - target);
    }
    (gradients, vec![1.0; targets.len()])
}
626
627fn logistic_gradients_and_hessians(
628 raw_predictions: &[f64],
629 targets: &[f64],
630) -> (Vec<f64>, Vec<f64>) {
631 let mut gradients = Vec::with_capacity(targets.len());
632 let mut hessians = Vec::with_capacity(targets.len());
633 for (raw_prediction, target) in raw_predictions.iter().zip(targets.iter()) {
634 let probability = sigmoid(*raw_prediction);
635 gradients.push(probability - target);
636 hessians.push((probability * (1.0 - probability)).max(1e-12));
637 }
638 (gradients, hessians)
639}
640
/// Numerically stable logistic function: the exponential is always taken of a
/// non-positive argument, so it can underflow to 0 but never overflow.
fn sigmoid(value: f64) -> f64 {
    if value < 0.0 {
        let exp = value.exp();
        exp / (1.0 + exp)
    } else {
        1.0 / (1.0 + (-value).exp())
    }
}
650
/// Gradient-based one-side sampling (GOSS-style, as in LightGBM): keep every
/// row whose |gradient| is in the top `top_gradient_fraction` of `base_rows`,
/// then add a seeded random `other_gradient_fraction` slice of the remainder,
/// scaling those rows' gradients and hessians by
/// `(1 - top_gradient_fraction) / other_gradient_fraction` to keep the total
/// derivative mass approximately unbiased.
///
/// Deterministic for a given `seed`; output order is top rows (by descending
/// |gradient|) followed by the shuffled "other" rows.
fn gradient_focus_sample(
    base_rows: &[usize],
    gradients: &[f64],
    hessians: &[f64],
    top_gradient_fraction: f64,
    other_gradient_fraction: f64,
    seed: u64,
) -> GradientFocusedSample {
    // Rank candidate rows by |gradient| descending, ties broken by row index
    // so the ordering (and hence the sample) is fully deterministic.
    let mut ranked = base_rows
        .iter()
        .copied()
        .map(|row_index| (row_index, gradients[row_index].abs()))
        .collect::<Vec<_>>();
    ranked.sort_by(|(left_row, left_abs), (right_row, right_abs)| {
        right_abs
            .total_cmp(left_abs)
            .then_with(|| left_row.cmp(right_row))
    });

    // Always keep at least one row and at most all of them.
    // NOTE(review): assumes `base_rows` is non-empty (guaranteed upstream by
    // validate_boosting_parameters); clamp(1.0, 0.0) would panic otherwise.
    let top_count = ((ranked.len() as f64) * top_gradient_fraction)
        .ceil()
        .clamp(1.0, ranked.len() as f64) as usize;
    let mut row_indices = Vec::with_capacity(ranked.len());
    let mut sampled_gradients = Vec::with_capacity(ranked.len());
    let mut sampled_hessians = Vec::with_capacity(ranked.len());

    // Top rows keep their gradients/hessians unscaled.
    for (row_index, _) in ranked.iter().take(top_count) {
        row_indices.push(*row_index);
        sampled_gradients.push(gradients[*row_index]);
        sampled_hessians.push(hessians[*row_index]);
    }

    if top_count < ranked.len() && other_gradient_fraction > 0.0 {
        let remaining = ranked[top_count..]
            .iter()
            .map(|(row_index, _)| *row_index)
            .collect::<Vec<_>>();
        let other_count = ((remaining.len() as f64) * other_gradient_fraction)
            .ceil()
            .min(remaining.len() as f64) as usize;
        if other_count > 0 {
            // Seeded shuffle, then take a prefix: a uniform random subset of
            // the small-gradient rows.
            let mut remaining = remaining;
            let mut rng = StdRng::seed_from_u64(seed);
            remaining.shuffle(&mut rng);
            // Compensation factor so the sampled subset stands in for the full
            // small-gradient population.
            let gradient_scale = (1.0 - top_gradient_fraction) / other_gradient_fraction;
            for row_index in remaining.into_iter().take(other_count) {
                row_indices.push(row_index);
                sampled_gradients.push(gradients[row_index] * gradient_scale);
                sampled_hessians.push(hessians[row_index] * gradient_scale);
            }
        }
    }

    GradientFocusedSample {
        row_indices,
        gradients: sampled_gradients,
        hessians: sampled_hessians,
    }
}
710
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CanaryFilter, MaxFeatures, TrainAlgorithm, TrainConfig};
    use forestfire_data::{BinnedColumnKind, TableAccess};
    use forestfire_data::{DenseTable, NumericBins};

    // A small boosted ensemble should recover the ordering of a monotone
    // one-feature regression signal.
    #[test]
    fn regression_boosting_fits_simple_signal() {
        let table = DenseTable::with_options(
            vec![
                vec![0.0],
                vec![0.0],
                vec![1.0],
                vec![1.0],
                vec![2.0],
                vec![2.0],
            ],
            vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
            0,
            NumericBins::fixed(8).unwrap(),
        )
        .unwrap();

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Regression,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(20),
                learning_rate: Some(0.2),
                max_depth: Some(2),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        // Predictions should preserve the monotone order of the targets.
        let predictions = model.predict_table(&table);
        assert!(predictions[0] < predictions[2]);
        assert!(predictions[2] < predictions[4]);
    }

    // Binary classification must yield two-entry probability vectors that
    // separate the clearly-negative and clearly-positive rows.
    #[test]
    fn classification_boosting_produces_binary_probabilities() {
        let table = DenseTable::with_options(
            vec![vec![0.0], vec![0.1], vec![0.9], vec![1.0]],
            vec![0.0, 0.0, 1.0, 1.0],
            0,
            NumericBins::fixed(8).unwrap(),
        )
        .unwrap();

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Classification,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(25),
                learning_rate: Some(0.2),
                max_depth: Some(2),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        let probabilities = model.predict_proba_table(&table).unwrap();
        assert_eq!(probabilities.len(), 4);
        // Index 1 is P(positive class).
        assert!(probabilities[0][1] < 0.5);
        assert!(probabilities[3][1] > 0.5);
    }

    // Three distinct target values must be rejected — only binary
    // classification is supported.
    #[test]
    fn classification_boosting_rejects_multiclass_targets() {
        let table =
            DenseTable::new(vec![vec![0.0], vec![1.0], vec![2.0]], vec![0.0, 1.0, 2.0]).unwrap();

        let error = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Classification,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap_err();

        assert!(matches!(
            error,
            BoostingError::UnsupportedClassificationClassCount(3)
        ));
    }

    // Fixture where the only informative binned feature is a canary: the real
    // feature (index 0) is constant, while the canary (index 1) perfectly
    // splits the targets, so any stage's root split would pick the canary.
    struct RootCanaryTable;

    // Fixture where the real feature carries some signal (it isolates row 3),
    // so with a relaxed canary filter a stage can still find a real root split.
    struct FilteredRootCanaryTable;

    impl TableAccess for RootCanaryTable {
        fn n_rows(&self) -> usize {
            4
        }

        fn n_features(&self) -> usize {
            1
        }

        fn canaries(&self) -> usize {
            1
        }

        fn numeric_bin_cap(&self) -> usize {
            2
        }

        fn binned_feature_count(&self) -> usize {
            2
        }

        fn feature_value(&self, _feature_index: usize, _row_index: usize) -> f64 {
            0.0
        }

        fn is_missing(&self, _feature_index: usize, _row_index: usize) -> bool {
            false
        }

        fn is_binary_feature(&self, _index: usize) -> bool {
            true
        }

        // Feature 0 is constant; feature 1 (the canary) is 1 exactly for the
        // rows whose target is 1.0.
        fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
            match feature_index {
                0 => 0,
                1 => u16::from(row_index >= 2),
                _ => unreachable!(),
            }
        }

        fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
            Some(match feature_index {
                0 => false,
                1 => row_index >= 2,
                _ => unreachable!(),
            })
        }

        fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
            match index {
                0 => BinnedColumnKind::Real { source_index: 0 },
                1 => BinnedColumnKind::Canary {
                    source_index: 0,
                    copy_index: 0,
                },
                _ => unreachable!(),
            }
        }

        fn is_binary_binned_feature(&self, _index: usize) -> bool {
            true
        }

        fn target_value(&self, row_index: usize) -> f64 {
            [0.0, 0.0, 1.0, 1.0][row_index]
        }
    }

    impl TableAccess for FilteredRootCanaryTable {
        fn n_rows(&self) -> usize {
            4
        }

        fn n_features(&self) -> usize {
            1
        }

        fn canaries(&self) -> usize {
            1
        }

        fn numeric_bin_cap(&self) -> usize {
            2
        }

        fn binned_feature_count(&self) -> usize {
            2
        }

        fn feature_value(&self, _feature_index: usize, _row_index: usize) -> f64 {
            0.0
        }

        fn is_missing(&self, _feature_index: usize, _row_index: usize) -> bool {
            false
        }

        fn is_binary_feature(&self, _index: usize) -> bool {
            true
        }

        fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
            u16::from(
                self.binned_boolean_value(feature_index, row_index)
                    .expect("all features are observed"),
            )
        }

        // Feature 0 (real) is true only for row 3 — imperfect but informative;
        // feature 1 (canary) still splits the targets perfectly.
        fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
            Some(match feature_index {
                0 => row_index == 3,
                1 => row_index >= 2,
                _ => unreachable!(),
            })
        }

        fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
            match index {
                0 => BinnedColumnKind::Real { source_index: 0 },
                1 => BinnedColumnKind::Canary {
                    source_index: 0,
                    copy_index: 0,
                },
                _ => unreachable!(),
            }
        }

        fn is_binary_binned_feature(&self, _index: usize) -> bool {
            true
        }

        fn target_value(&self, row_index: usize) -> f64 {
            [0.0, 0.0, 1.0, 1.0][row_index]
        }
    }

    // When the very first stage's root split selects the canary, boosting must
    // stop with zero trees and still produce finite (base-score) predictions.
    #[test]
    fn boosting_stops_when_root_split_is_a_canary() {
        let table = RootCanaryTable;

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Regression,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(10),
                max_features: MaxFeatures::All,
                learning_rate: Some(0.1),
                top_gradient_fraction: Some(1.0),
                other_gradient_fraction: Some(0.0),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        assert_eq!(model.trees().len(), 0);
        assert_eq!(model.training_metadata().n_trees, Some(0));
        assert!(
            model
                .predict_table(&table)
                .iter()
                .all(|value| value.is_finite())
        );
    }

    // With a TopN(2) canary filter, the real (weaker) root split remains
    // eligible, so at least one stage should be kept.
    #[test]
    fn boosting_can_use_a_real_root_split_inside_top_n_canary_filter() {
        let table = FilteredRootCanaryTable;

        let model = GradientBoostedTrees::train(
            &table,
            TrainConfig {
                algorithm: TrainAlgorithm::Gbm,
                task: Task::Regression,
                tree_type: TreeType::Cart,
                criterion: Criterion::SecondOrder,
                n_trees: Some(10),
                max_features: MaxFeatures::All,
                learning_rate: Some(0.1),
                top_gradient_fraction: Some(1.0),
                other_gradient_fraction: Some(0.0),
                canary_filter: CanaryFilter::TopN(2),
                ..TrainConfig::default()
            },
            Parallelism::sequential(),
        )
        .unwrap();

        assert!(!model.trees().is_empty());
    }
}