1use crate::ir::{
9 BinaryChildren, BinarySplit, IndexedLeaf, LeafIndexing, LeafPayload, NodeStats, NodeTreeNode,
10 ObliviousLevel, ObliviousSplit as IrObliviousSplit, TrainingMetadata, TreeDefinition,
11 criterion_name, feature_name, threshold_upper_bound,
12};
13use crate::tree::shared::{
14 FeatureHistogram, HistogramBin, MissingBranchDirection, build_feature_histograms,
15 candidate_feature_indices, choose_random_threshold, node_seed, partition_rows_for_binary_split,
16 subtract_feature_histograms,
17};
18use crate::{
19 Criterion, FeaturePreprocessing, MissingValueStrategy, Parallelism,
20 capture_feature_preprocessing,
21};
22use forestfire_data::TableAccess;
23use rayon::prelude::*;
24use std::collections::BTreeSet;
25use std::error::Error;
26use std::fmt::{Display, Formatter};
27
/// Tree-construction strategy used when training a regression tree.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RegressionTreeAlgorithm {
    /// Greedy top-down CART: scores every distinct binned threshold per candidate feature.
    Cart,
    /// Extra-trees style: scores a single randomly chosen threshold per candidate feature.
    Randomized,
    /// Oblivious (symmetric) tree: one shared split per depth level.
    Oblivious,
}
34
/// Hyperparameters shared by all regression-tree training entry points.
#[derive(Debug, Clone)]
pub struct RegressionTreeOptions {
    /// Maximum tree depth; nodes at this depth become leaves.
    pub max_depth: usize,
    /// Minimum number of rows required before a node may be split.
    pub min_samples_split: usize,
    /// Minimum number of rows each side of a split must retain.
    pub min_samples_leaf: usize,
    /// Number of candidate features sampled per split; `None` means all features.
    pub max_features: Option<usize>,
    /// Seed driving per-node feature sampling and randomized thresholds.
    pub random_seed: u64,
    /// Per-feature missing-value handling; features beyond the end of this
    /// vector fall back to `MissingValueStrategy::Heuristic`.
    pub missing_value_strategies: Vec<MissingValueStrategy>,
}
45
impl Default for RegressionTreeOptions {
    /// Defaults: depth-8 trees, no feature subsampling, seed 0, and heuristic
    /// missing-value handling everywhere (empty strategy vector).
    fn default() -> Self {
        Self {
            max_depth: 8,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: None,
            random_seed: 0,
            missing_value_strategies: Vec::new(),
        }
    }
}
58
59impl RegressionTreeOptions {
60 fn missing_value_strategy(&self, feature_index: usize) -> MissingValueStrategy {
61 self.missing_value_strategies
62 .get(feature_index)
63 .copied()
64 .unwrap_or(MissingValueStrategy::Heuristic)
65 }
66}
67
/// Errors that can abort regression-tree training.
#[derive(Debug)]
pub enum RegressionTreeError {
    /// The training table contained zero rows.
    EmptyTarget,
    /// A target value was NaN or infinite.
    InvalidTargetValue { row: usize, value: f64 },
}
73
74impl Display for RegressionTreeError {
75 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
76 match self {
77 RegressionTreeError::EmptyTarget => {
78 write!(f, "Cannot train on an empty target vector.")
79 }
80 RegressionTreeError::InvalidTargetValue { row, value } => write!(
81 f,
82 "Regression targets must be finite values. Found {} at row {}.",
83 value, row
84 ),
85 }
86 }
87}
88
// Marker impl: the `Debug` and `Display` impls above satisfy `std::error::Error`.
impl Error for RegressionTreeError {}
90
/// A trained regression tree together with the preprocessing metadata needed
/// to evaluate new rows and export the model to IR.
#[derive(Debug, Clone)]
pub struct DecisionTreeRegressor {
    // How the tree was grown (CART / randomized / oblivious).
    algorithm: RegressionTreeAlgorithm,
    // Split-quality criterion used during training.
    criterion: Criterion,
    // Learned tree topology and leaf payloads.
    structure: RegressionTreeStructure,
    // Hyperparameters the tree was trained with (re-exported in metadata).
    options: RegressionTreeOptions,
    // Feature count of the training table.
    num_features: usize,
    // Per-feature binning/preprocessing captured from the training table.
    feature_preprocessing: Vec<FeaturePreprocessing>,
    // Canary-column count reported by the training table.
    training_canaries: usize,
}
102
/// Learned tree topology: either an arbitrary binary node tree or a
/// depth-symmetric oblivious tree.
#[derive(Debug, Clone)]
pub(crate) enum RegressionTreeStructure {
    /// Flat node arena; `root` indexes into `nodes`.
    Standard {
        nodes: Vec<RegressionNode>,
        root: usize,
    },
    /// One shared split per level; leaves are addressed by the MSB-first path
    /// bits, so the leaf vectors have `2^splits.len()` entries.
    Oblivious {
        splits: Vec<ObliviousSplit>,
        leaf_values: Vec<f64>,
        leaf_sample_counts: Vec<usize>,
        leaf_variances: Vec<Option<f64>>,
    },
}
116
/// One node of a standard (non-oblivious) regression tree.
#[derive(Debug, Clone)]
pub(crate) enum RegressionNode {
    /// Terminal node holding the predicted value plus training statistics.
    Leaf {
        value: f64,
        sample_count: usize,
        variance: Option<f64>,
    },
    /// Internal node testing `binned_value(feature_index) <= threshold_bin`.
    BinarySplit {
        feature_index: usize,
        threshold_bin: u16,
        // Where rows with a missing feature value are routed at predict time.
        missing_direction: MissingBranchDirection,
        // Value returned directly when `missing_direction` is `Node`.
        missing_value: f64,
        left_child: usize,
        right_child: usize,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        variance: Option<f64>,
    },
}
137
/// The single split applied across an entire level of an oblivious tree.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ObliviousSplit {
    pub(crate) feature_index: usize,
    pub(crate) threshold_bin: u16,
    // Every row passes through every level, so this is the full row count.
    pub(crate) sample_count: usize,
    pub(crate) impurity: f64,
    pub(crate) gain: f64,
}
146
/// Best split found for a single feature while growing a standard tree node.
#[derive(Debug, Clone)]
struct RegressionSplitCandidate {
    feature_index: usize,
    threshold_bin: u16,
    // Impurity decrease; higher is better and only positive scores are kept.
    score: f64,
    missing_direction: MissingBranchDirection,
}
154
/// One leaf of a partially built oblivious tree: a contiguous `[start, end)`
/// slice of the shared row-index buffer plus cached target statistics.
#[derive(Debug, Clone)]
struct ObliviousLeafState {
    start: usize,
    end: usize,
    // Prediction for rows that land in this leaf.
    value: f64,
    variance: Option<f64>,
    // Sum and sum of squares of the leaf's target values.
    sum: f64,
    sum_sq: f64,
}
164
impl ObliviousLeafState {
    /// Number of rows covered by this leaf's `[start, end)` range.
    fn len(&self) -> usize {
        self.end - self.start
    }
}
170
/// Best level-wide split found for one feature of an oblivious tree.
#[derive(Debug, Clone, Copy)]
struct ObliviousSplitCandidate {
    feature_index: usize,
    threshold_bin: u16,
    score: f64,
}
177
/// Scored binary split choice.
// NOTE(review): not referenced in this part of the file — presumably the
// return type of `score_binary_split_choice` / the fast-path scorers defined
// elsewhere (their results expose exactly these fields); confirm before removing.
#[derive(Debug, Clone, Copy)]
struct BinarySplitChoice {
    feature_index: usize,
    threshold_bin: u16,
    score: f64,
    missing_direction: MissingBranchDirection,
}
185
/// Per-bin accumulator: row count plus target sum and sum of squares —
/// sufficient statistics for mean/variance-based split scoring.
#[derive(Debug, Clone)]
struct RegressionHistogramBin {
    count: usize,
    sum: f64,
    sum_sq: f64,
}
192
193impl HistogramBin for RegressionHistogramBin {
194 fn subtract(parent: &Self, child: &Self) -> Self {
195 Self {
196 count: parent.count - child.count,
197 sum: parent.sum - child.sum,
198 sum_sq: parent.sum_sq - child.sum_sq,
199 }
200 }
201
202 fn is_observed(&self) -> bool {
203 self.count > 0
204 }
205}
206
/// Per-feature histogram whose bins carry regression target statistics.
type RegressionFeatureHistogram = FeatureHistogram<RegressionHistogramBin>;
208
209pub fn train_cart_regressor(
210 train_set: &dyn TableAccess,
211) -> Result<DecisionTreeRegressor, RegressionTreeError> {
212 train_cart_regressor_with_criterion(train_set, Criterion::Mean)
213}
214
215pub fn train_cart_regressor_with_criterion(
216 train_set: &dyn TableAccess,
217 criterion: Criterion,
218) -> Result<DecisionTreeRegressor, RegressionTreeError> {
219 train_cart_regressor_with_criterion_and_parallelism(
220 train_set,
221 criterion,
222 Parallelism::sequential(),
223 )
224}
225
226pub(crate) fn train_cart_regressor_with_criterion_and_parallelism(
227 train_set: &dyn TableAccess,
228 criterion: Criterion,
229 parallelism: Parallelism,
230) -> Result<DecisionTreeRegressor, RegressionTreeError> {
231 train_cart_regressor_with_criterion_parallelism_and_options(
232 train_set,
233 criterion,
234 parallelism,
235 RegressionTreeOptions::default(),
236 )
237}
238
239pub(crate) fn train_cart_regressor_with_criterion_parallelism_and_options(
240 train_set: &dyn TableAccess,
241 criterion: Criterion,
242 parallelism: Parallelism,
243 options: RegressionTreeOptions,
244) -> Result<DecisionTreeRegressor, RegressionTreeError> {
245 train_regressor(
246 train_set,
247 RegressionTreeAlgorithm::Cart,
248 criterion,
249 parallelism,
250 options,
251 )
252}
253
254pub fn train_oblivious_regressor(
255 train_set: &dyn TableAccess,
256) -> Result<DecisionTreeRegressor, RegressionTreeError> {
257 train_oblivious_regressor_with_criterion(train_set, Criterion::Mean)
258}
259
260pub fn train_oblivious_regressor_with_criterion(
261 train_set: &dyn TableAccess,
262 criterion: Criterion,
263) -> Result<DecisionTreeRegressor, RegressionTreeError> {
264 train_oblivious_regressor_with_criterion_and_parallelism(
265 train_set,
266 criterion,
267 Parallelism::sequential(),
268 )
269}
270
271pub(crate) fn train_oblivious_regressor_with_criterion_and_parallelism(
272 train_set: &dyn TableAccess,
273 criterion: Criterion,
274 parallelism: Parallelism,
275) -> Result<DecisionTreeRegressor, RegressionTreeError> {
276 train_oblivious_regressor_with_criterion_parallelism_and_options(
277 train_set,
278 criterion,
279 parallelism,
280 RegressionTreeOptions::default(),
281 )
282}
283
284pub(crate) fn train_oblivious_regressor_with_criterion_parallelism_and_options(
285 train_set: &dyn TableAccess,
286 criterion: Criterion,
287 parallelism: Parallelism,
288 options: RegressionTreeOptions,
289) -> Result<DecisionTreeRegressor, RegressionTreeError> {
290 train_regressor(
291 train_set,
292 RegressionTreeAlgorithm::Oblivious,
293 criterion,
294 parallelism,
295 options,
296 )
297}
298
299pub fn train_randomized_regressor(
300 train_set: &dyn TableAccess,
301) -> Result<DecisionTreeRegressor, RegressionTreeError> {
302 train_randomized_regressor_with_criterion(train_set, Criterion::Mean)
303}
304
305pub fn train_randomized_regressor_with_criterion(
306 train_set: &dyn TableAccess,
307 criterion: Criterion,
308) -> Result<DecisionTreeRegressor, RegressionTreeError> {
309 train_randomized_regressor_with_criterion_and_parallelism(
310 train_set,
311 criterion,
312 Parallelism::sequential(),
313 )
314}
315
316pub(crate) fn train_randomized_regressor_with_criterion_and_parallelism(
317 train_set: &dyn TableAccess,
318 criterion: Criterion,
319 parallelism: Parallelism,
320) -> Result<DecisionTreeRegressor, RegressionTreeError> {
321 train_randomized_regressor_with_criterion_parallelism_and_options(
322 train_set,
323 criterion,
324 parallelism,
325 RegressionTreeOptions::default(),
326 )
327}
328
329pub(crate) fn train_randomized_regressor_with_criterion_parallelism_and_options(
330 train_set: &dyn TableAccess,
331 criterion: Criterion,
332 parallelism: Parallelism,
333 options: RegressionTreeOptions,
334) -> Result<DecisionTreeRegressor, RegressionTreeError> {
335 train_regressor(
336 train_set,
337 RegressionTreeAlgorithm::Randomized,
338 criterion,
339 parallelism,
340 options,
341 )
342}
343
344fn train_regressor(
345 train_set: &dyn TableAccess,
346 algorithm: RegressionTreeAlgorithm,
347 criterion: Criterion,
348 parallelism: Parallelism,
349 options: RegressionTreeOptions,
350) -> Result<DecisionTreeRegressor, RegressionTreeError> {
351 if train_set.n_rows() == 0 {
352 return Err(RegressionTreeError::EmptyTarget);
353 }
354
355 let targets = finite_targets(train_set)?;
356 let structure = match algorithm {
357 RegressionTreeAlgorithm::Cart => {
358 let mut nodes = Vec::new();
359 let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
360 let context = BuildContext {
361 table: train_set,
362 targets: &targets,
363 criterion,
364 parallelism,
365 options: options.clone(),
366 algorithm,
367 };
368 let root = build_binary_node_in_place(&context, &mut nodes, &mut all_rows, 0);
372 RegressionTreeStructure::Standard { nodes, root }
373 }
374 RegressionTreeAlgorithm::Randomized => {
375 let mut nodes = Vec::new();
376 let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
377 let context = BuildContext {
378 table: train_set,
379 targets: &targets,
380 criterion,
381 parallelism,
382 options: options.clone(),
383 algorithm,
384 };
385 let root = build_binary_node_in_place(&context, &mut nodes, &mut all_rows, 0);
386 RegressionTreeStructure::Standard { nodes, root }
387 }
388 RegressionTreeAlgorithm::Oblivious => {
389 train_oblivious_structure(train_set, &targets, criterion, parallelism, options.clone())
392 }
393 };
394
395 Ok(DecisionTreeRegressor {
396 algorithm,
397 criterion,
398 structure,
399 options,
400 num_features: train_set.n_features(),
401 feature_preprocessing: capture_feature_preprocessing(train_set),
402 training_canaries: train_set.canaries(),
403 })
404}
405
406impl DecisionTreeRegressor {
407 pub fn algorithm(&self) -> RegressionTreeAlgorithm {
409 self.algorithm
410 }
411
412 pub fn criterion(&self) -> Criterion {
414 self.criterion
415 }
416
417 pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
419 (0..table.n_rows())
420 .map(|row_idx| self.predict_row(table, row_idx))
421 .collect()
422 }
423
424 fn predict_row(&self, table: &dyn TableAccess, row_idx: usize) -> f64 {
425 match &self.structure {
426 RegressionTreeStructure::Standard { nodes, root } => {
427 let mut node_index = *root;
428
429 loop {
430 match &nodes[node_index] {
431 RegressionNode::Leaf { value, .. } => return *value,
432 RegressionNode::BinarySplit {
433 feature_index,
434 threshold_bin,
435 missing_direction,
436 missing_value,
437 left_child,
438 right_child,
439 ..
440 } => {
441 if table.is_missing(*feature_index, row_idx) {
442 match missing_direction {
443 MissingBranchDirection::Left => {
444 node_index = *left_child;
445 }
446 MissingBranchDirection::Right => {
447 node_index = *right_child;
448 }
449 MissingBranchDirection::Node => return *missing_value,
450 }
451 continue;
452 }
453 let bin = table.binned_value(*feature_index, row_idx);
454 node_index = if bin <= *threshold_bin {
455 *left_child
456 } else {
457 *right_child
458 };
459 }
460 }
461 }
462 }
463 RegressionTreeStructure::Oblivious {
464 splits,
465 leaf_values,
466 ..
467 } => {
468 let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
469 let go_right =
470 table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
471 (leaf_index << 1) | usize::from(go_right)
472 });
473
474 leaf_values[leaf_index]
475 }
476 }
477 }
478
479 pub(crate) fn num_features(&self) -> usize {
480 self.num_features
481 }
482
483 pub(crate) fn structure(&self) -> &RegressionTreeStructure {
484 &self.structure
485 }
486
487 pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
488 &self.feature_preprocessing
489 }
490
491 pub(crate) fn training_metadata(&self) -> TrainingMetadata {
492 TrainingMetadata {
493 algorithm: "dt".to_string(),
494 task: "regression".to_string(),
495 tree_type: match self.algorithm {
496 RegressionTreeAlgorithm::Cart => "cart".to_string(),
497 RegressionTreeAlgorithm::Randomized => "randomized".to_string(),
498 RegressionTreeAlgorithm::Oblivious => "oblivious".to_string(),
499 },
500 criterion: criterion_name(self.criterion).to_string(),
501 canaries: self.training_canaries,
502 compute_oob: false,
503 max_depth: Some(self.options.max_depth),
504 min_samples_split: Some(self.options.min_samples_split),
505 min_samples_leaf: Some(self.options.min_samples_leaf),
506 n_trees: None,
507 max_features: self.options.max_features,
508 seed: None,
509 oob_score: None,
510 class_labels: None,
511 learning_rate: None,
512 bootstrap: None,
513 top_gradient_fraction: None,
514 other_gradient_fraction: None,
515 }
516 }
517
518 pub(crate) fn to_ir_tree(&self) -> TreeDefinition {
519 match &self.structure {
520 RegressionTreeStructure::Standard { nodes, root } => {
521 let depths = standard_node_depths(nodes, *root);
522 TreeDefinition::NodeTree {
523 tree_id: 0,
524 weight: 1.0,
525 root_node_id: *root,
526 nodes: nodes
527 .iter()
528 .enumerate()
529 .map(|(node_id, node)| match node {
530 RegressionNode::Leaf {
531 value,
532 sample_count,
533 variance,
534 } => NodeTreeNode::Leaf {
535 node_id,
536 depth: depths[node_id],
537 leaf: LeafPayload::RegressionValue { value: *value },
538 stats: NodeStats {
539 sample_count: *sample_count,
540 impurity: None,
541 gain: None,
542 class_counts: None,
543 variance: *variance,
544 },
545 },
546 RegressionNode::BinarySplit {
547 feature_index,
548 threshold_bin,
549 missing_direction,
550 missing_value: _,
551 left_child,
552 right_child,
553 sample_count,
554 impurity,
555 gain,
556 variance,
557 } => NodeTreeNode::BinaryBranch {
558 node_id,
559 depth: depths[node_id],
560 split: binary_split_ir(
561 *feature_index,
562 *threshold_bin,
563 *missing_direction,
564 &self.feature_preprocessing,
565 ),
566 children: BinaryChildren {
567 left: *left_child,
568 right: *right_child,
569 },
570 stats: NodeStats {
571 sample_count: *sample_count,
572 impurity: Some(*impurity),
573 gain: Some(*gain),
574 class_counts: None,
575 variance: *variance,
576 },
577 },
578 })
579 .collect(),
580 }
581 }
582 RegressionTreeStructure::Oblivious {
583 splits,
584 leaf_values,
585 leaf_sample_counts,
586 leaf_variances,
587 } => TreeDefinition::ObliviousLevels {
588 tree_id: 0,
589 weight: 1.0,
590 depth: splits.len(),
591 levels: splits
592 .iter()
593 .enumerate()
594 .map(|(level, split)| ObliviousLevel {
595 level,
596 split: oblivious_split_ir(
597 split.feature_index,
598 split.threshold_bin,
599 &self.feature_preprocessing,
600 ),
601 stats: NodeStats {
602 sample_count: split.sample_count,
603 impurity: Some(split.impurity),
604 gain: Some(split.gain),
605 class_counts: None,
606 variance: None,
607 },
608 })
609 .collect(),
610 leaf_indexing: LeafIndexing {
611 bit_order: "msb_first".to_string(),
612 index_formula: "sum(bit[level] << (depth - 1 - level))".to_string(),
613 },
614 leaves: leaf_values
615 .iter()
616 .enumerate()
617 .map(|(leaf_index, value)| IndexedLeaf {
618 leaf_index,
619 leaf: LeafPayload::RegressionValue { value: *value },
620 stats: NodeStats {
621 sample_count: leaf_sample_counts[leaf_index],
622 impurity: None,
623 gain: None,
624 class_counts: None,
625 variance: leaf_variances[leaf_index],
626 },
627 })
628 .collect(),
629 },
630 }
631 }
632
633 pub(crate) fn from_ir_parts(
634 algorithm: RegressionTreeAlgorithm,
635 criterion: Criterion,
636 structure: RegressionTreeStructure,
637 options: RegressionTreeOptions,
638 num_features: usize,
639 feature_preprocessing: Vec<FeaturePreprocessing>,
640 training_canaries: usize,
641 ) -> Self {
642 Self {
643 algorithm,
644 criterion,
645 structure,
646 options: options.clone(),
647 num_features,
648 feature_preprocessing,
649 training_canaries,
650 }
651 }
652}
653
/// Computes the depth of every node in the arena, indexed by node id.
/// Nodes unreachable from `root` keep depth 0.
fn standard_node_depths(nodes: &[RegressionNode], root: usize) -> Vec<usize> {
    let mut depths = vec![0; nodes.len()];
    populate_depths(nodes, root, 0, &mut depths);
    depths
}
659
660fn populate_depths(nodes: &[RegressionNode], node_id: usize, depth: usize, depths: &mut [usize]) {
661 depths[node_id] = depth;
662 match &nodes[node_id] {
663 RegressionNode::Leaf { .. } => {}
664 RegressionNode::BinarySplit {
665 left_child,
666 right_child,
667 ..
668 } => {
669 populate_depths(nodes, *left_child, depth + 1, depths);
670 populate_depths(nodes, *right_child, depth + 1, depths);
671 }
672 }
673}
674
675fn binary_split_ir(
676 feature_index: usize,
677 threshold_bin: u16,
678 _missing_direction: MissingBranchDirection,
679 preprocessing: &[FeaturePreprocessing],
680) -> BinarySplit {
681 match preprocessing.get(feature_index) {
682 Some(FeaturePreprocessing::Binary) => BinarySplit::BooleanTest {
683 feature_index,
684 feature_name: feature_name(feature_index),
685 false_child_semantics: "left".to_string(),
686 true_child_semantics: "right".to_string(),
687 },
688 Some(FeaturePreprocessing::Numeric { .. }) | None => BinarySplit::NumericBinThreshold {
689 feature_index,
690 feature_name: feature_name(feature_index),
691 operator: "<=".to_string(),
692 threshold_bin,
693 threshold_upper_bound: threshold_upper_bound(
694 preprocessing,
695 feature_index,
696 threshold_bin,
697 ),
698 comparison_dtype: "uint16".to_string(),
699 },
700 }
701}
702
703fn oblivious_split_ir(
704 feature_index: usize,
705 threshold_bin: u16,
706 preprocessing: &[FeaturePreprocessing],
707) -> IrObliviousSplit {
708 match preprocessing.get(feature_index) {
709 Some(FeaturePreprocessing::Binary) => IrObliviousSplit::BooleanTest {
710 feature_index,
711 feature_name: feature_name(feature_index),
712 bit_when_false: 0,
713 bit_when_true: 1,
714 },
715 Some(FeaturePreprocessing::Numeric { .. }) | None => {
716 IrObliviousSplit::NumericBinThreshold {
717 feature_index,
718 feature_name: feature_name(feature_index),
719 operator: "<=".to_string(),
720 threshold_bin,
721 threshold_upper_bound: threshold_upper_bound(
722 preprocessing,
723 feature_index,
724 threshold_bin,
725 ),
726 comparison_dtype: "uint16".to_string(),
727 bit_when_true: 0,
728 bit_when_false: 1,
729 }
730 }
731 }
732}
733
/// Everything the recursive binary-tree builder needs, bundled so the
/// recursion signatures stay short.
struct BuildContext<'a> {
    table: &'a dyn TableAccess,
    targets: &'a [f64],
    criterion: Criterion,
    parallelism: Parallelism,
    options: RegressionTreeOptions,
    // Selects CART vs randomized split scoring inside the shared builder.
    algorithm: RegressionTreeAlgorithm,
}
742
/// Builds per-feature histograms over `rows`, accumulating each row's target
/// count, sum, and sum of squares into the bin the row falls in.
fn build_regression_node_histograms(
    table: &dyn TableAccess,
    targets: &[f64],
    rows: &[usize],
) -> Vec<RegressionFeatureHistogram> {
    build_feature_histograms(
        table,
        rows,
        // Constructor for an empty bin.
        |_| RegressionHistogramBin {
            count: 0,
            sum: 0.0,
            sum_sq: 0.0,
        },
        // Fold one row's target into its bin.
        |_feature_index, payload, row_idx| {
            let value = targets[row_idx];
            payload.count += 1;
            payload.sum += value;
            payload.sum_sq += value * value;
        },
    )
}
764
/// Derives one child's histograms as parent − sibling, bin by bin, avoiding a
/// second scan over the rows.
fn subtract_regression_node_histograms(
    parent: &[RegressionFeatureHistogram],
    child: &[RegressionFeatureHistogram],
) -> Vec<RegressionFeatureHistogram> {
    subtract_feature_histograms(parent, child)
}
771
772fn finite_targets(train_set: &dyn TableAccess) -> Result<Vec<f64>, RegressionTreeError> {
773 (0..train_set.n_rows())
774 .map(|row_idx| {
775 let value = train_set.target_value(row_idx);
776 if value.is_finite() {
777 Ok(value)
778 } else {
779 Err(RegressionTreeError::InvalidTargetValue {
780 row: row_idx,
781 value,
782 })
783 }
784 })
785 .collect()
786}
787
/// Grows a subtree over `rows` without precomputed histograms; returns the
/// arena index of the subtree root.
fn build_binary_node_in_place(
    context: &BuildContext<'_>,
    nodes: &mut Vec<RegressionNode>,
    rows: &mut [usize],
    depth: usize,
) -> usize {
    build_binary_node_in_place_with_hist(context, nodes, rows, depth, None)
}
796
/// Recursively grows a subtree over `rows`, appending nodes to the arena and
/// returning the index of the subtree root.
///
/// `histograms` optionally carries this node's per-feature histograms (Mean
/// criterion only) so that the smaller child can be built fresh and the
/// sibling derived by subtraction instead of a second scan.
fn build_binary_node_in_place_with_hist(
    context: &BuildContext<'_>,
    nodes: &mut Vec<RegressionNode>,
    rows: &mut [usize],
    depth: usize,
    histograms: Option<Vec<RegressionFeatureHistogram>>,
) -> usize {
    let leaf_value = regression_value(rows, context.targets, context.criterion);
    let leaf_variance = variance(rows, context.targets);

    // Stopping rules: no rows, depth exhausted, too few samples, or pure target.
    if rows.is_empty()
        || depth >= context.options.max_depth
        || rows.len() < context.options.min_samples_split
        || has_constant_target(rows, context.targets)
    {
        return push_leaf(nodes, leaf_value, rows.len(), leaf_variance);
    }

    // Histogram-based scoring is only used with the Mean criterion; build the
    // histograms lazily when the caller did not pass them down.
    let histograms = if matches!(context.criterion, Criterion::Mean) {
        Some(histograms.unwrap_or_else(|| {
            build_regression_node_histograms(context.table, context.targets, rows)
        }))
    } else {
        None
    };
    let feature_indices = candidate_feature_indices(
        context.table.binned_feature_count(),
        context.options.max_features,
        node_seed(context.options.random_seed, depth, rows, 0xA11C_E5E1u64),
    );
    // Score every candidate feature (optionally with rayon) and keep the best.
    let best_split = if context.parallelism.enabled() {
        feature_indices
            .into_par_iter()
            .filter_map(|feature_index| {
                if let Some(histograms) = histograms.as_ref() {
                    score_binary_split_choice_from_hist(
                        context,
                        &histograms[feature_index],
                        feature_index,
                        rows,
                    )
                } else {
                    score_binary_split_choice(context, feature_index, rows)
                }
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    } else {
        feature_indices
            .into_iter()
            .filter_map(|feature_index| {
                if let Some(histograms) = histograms.as_ref() {
                    score_binary_split_choice_from_hist(
                        context,
                        &histograms[feature_index],
                        feature_index,
                        rows,
                    )
                } else {
                    score_binary_split_choice(context, feature_index, rows)
                }
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    };

    match best_split {
        // Refuse to split on a canary feature: collapse to a leaf instead.
        Some(best_split)
            if context
                .table
                .is_canary_binned_feature(best_split.feature_index) =>
        {
            push_leaf(nodes, leaf_value, rows.len(), leaf_variance)
        }
        Some(best_split) if best_split.score > 0.0 => {
            let impurity = regression_loss(rows, context.targets, context.criterion);
            // Reorder `rows` in place so the left partition comes first.
            let left_count = partition_rows_for_binary_split(
                context.table,
                best_split.feature_index,
                best_split.threshold_bin,
                best_split.missing_direction,
                rows,
            );
            let (left_rows, right_rows) = rows.split_at_mut(left_count);
            let (left_child, right_child) = if let Some(histograms) = histograms {
                // Build histograms for the smaller child and derive the larger
                // child's histograms by subtraction from the parent.
                if left_rows.len() <= right_rows.len() {
                    let left_histograms =
                        build_regression_node_histograms(context.table, context.targets, left_rows);
                    let right_histograms =
                        subtract_regression_node_histograms(&histograms, &left_histograms);
                    (
                        build_binary_node_in_place_with_hist(
                            context,
                            nodes,
                            left_rows,
                            depth + 1,
                            Some(left_histograms),
                        ),
                        build_binary_node_in_place_with_hist(
                            context,
                            nodes,
                            right_rows,
                            depth + 1,
                            Some(right_histograms),
                        ),
                    )
                } else {
                    let right_histograms = build_regression_node_histograms(
                        context.table,
                        context.targets,
                        right_rows,
                    );
                    let left_histograms =
                        subtract_regression_node_histograms(&histograms, &right_histograms);
                    (
                        build_binary_node_in_place_with_hist(
                            context,
                            nodes,
                            left_rows,
                            depth + 1,
                            Some(left_histograms),
                        ),
                        build_binary_node_in_place_with_hist(
                            context,
                            nodes,
                            right_rows,
                            depth + 1,
                            Some(right_histograms),
                        ),
                    )
                }
            } else {
                (
                    build_binary_node_in_place(context, nodes, left_rows, depth + 1),
                    build_binary_node_in_place(context, nodes, right_rows, depth + 1),
                )
            };

            push_node(
                nodes,
                RegressionNode::BinarySplit {
                    feature_index: best_split.feature_index,
                    threshold_bin: best_split.threshold_bin,
                    missing_direction: best_split.missing_direction,
                    // Node-level fallback value for rows missing this feature.
                    missing_value: leaf_value,
                    left_child,
                    right_child,
                    sample_count: rows.len(),
                    impurity,
                    gain: best_split.score,
                    variance: leaf_variance,
                },
            )
        }
        // No candidate or no positive gain: terminate with a leaf.
        _ => push_leaf(nodes, leaf_value, rows.len(), leaf_variance),
    }
}
952
/// Grows an oblivious (symmetric) tree: at each depth a single split is chosen
/// and applied to every current leaf simultaneously, doubling the leaf count.
fn train_oblivious_structure(
    table: &dyn TableAccess,
    targets: &[f64],
    criterion: Criterion,
    parallelism: Parallelism,
    options: RegressionTreeOptions,
) -> RegressionTreeStructure {
    // Shared row-index buffer; each leaf owns a contiguous [start, end) slice.
    let mut row_indices: Vec<usize> = (0..table.n_rows()).collect();
    let (root_sum, root_sum_sq) = sum_stats(&row_indices, targets);
    let mut leaves = vec![ObliviousLeafState {
        start: 0,
        end: row_indices.len(),
        value: regression_value_from_stats(&row_indices, targets, criterion, root_sum),
        variance: variance_from_stats(row_indices.len(), root_sum, root_sum_sq),
        sum: root_sum,
        sum_sq: root_sum_sq,
    }];
    let mut splits = Vec::new();

    for depth in 0..options.max_depth {
        // Stop once no leaf is large enough to split.
        if leaves
            .iter()
            .all(|leaf| leaf.len() < options.min_samples_split)
        {
            break;
        }
        let feature_indices = candidate_feature_indices(
            table.binned_feature_count(),
            options.max_features,
            node_seed(options.random_seed, depth, &[], 0x0B11_A10Cu64),
        );
        // Score each candidate feature across all leaves at once.
        let best_split = if parallelism.enabled() {
            feature_indices
                .into_par_iter()
                .filter_map(|feature_index| {
                    score_oblivious_split(
                        table,
                        &row_indices,
                        targets,
                        feature_index,
                        &leaves,
                        criterion,
                        options.min_samples_leaf,
                    )
                })
                .max_by(|left, right| left.score.total_cmp(&right.score))
        } else {
            feature_indices
                .into_iter()
                .filter_map(|feature_index| {
                    score_oblivious_split(
                        table,
                        &row_indices,
                        targets,
                        feature_index,
                        &leaves,
                        criterion,
                        options.min_samples_leaf,
                    )
                })
                .max_by(|left, right| left.score.total_cmp(&right.score))
        };

        // Require strictly positive gain to keep growing.
        let Some(best_split) = best_split.filter(|candidate| candidate.score > 0.0) else {
            break;
        };
        // Never split on canary features.
        if table.is_canary_binned_feature(best_split.feature_index) {
            break;
        }

        // Apply the winning split to every leaf, doubling the leaf count.
        leaves = split_oblivious_leaves_in_place(
            table,
            &mut row_indices,
            targets,
            leaves,
            best_split.feature_index,
            best_split.threshold_bin,
            criterion,
        );
        splits.push(ObliviousSplit {
            feature_index: best_split.feature_index,
            threshold_bin: best_split.threshold_bin,
            // Every row passes through every level.
            sample_count: table.n_rows(),
            // Post-split total loss across the new leaves.
            impurity: leaves
                .iter()
                .map(|leaf| leaf_regression_loss(leaf, &row_indices, targets, criterion))
                .sum(),
            gain: best_split.score,
        });
    }

    RegressionTreeStructure::Oblivious {
        splits,
        leaf_values: leaves.iter().map(|leaf| leaf.value).collect(),
        leaf_sample_counts: leaves.iter().map(ObliviousLeafState::len).collect(),
        leaf_variances: leaves.iter().map(|leaf| leaf.variance).collect(),
    }
}
1051
/// Scores the best split of `rows` on one feature, dispatching through a
/// ladder of fast paths before falling back to exhaustive threshold search.
///
/// Order: binary-feature scorer → Mean-criterion fast paths (no missing
/// values) → randomized scorer → heuristic missing-value assignment →
/// exhaustive scan over all distinct observed bins.
#[allow(clippy::too_many_arguments)]
fn score_split(
    table: &dyn TableAccess,
    targets: &[f64],
    feature_index: usize,
    rows: &[usize],
    criterion: Criterion,
    min_samples_leaf: usize,
    algorithm: RegressionTreeAlgorithm,
    strategy: MissingValueStrategy,
) -> Option<RegressionSplitCandidate> {
    if table.is_binary_binned_feature(feature_index) {
        return score_binary_split(
            table,
            targets,
            feature_index,
            rows,
            criterion,
            min_samples_leaf,
            strategy,
        );
    }
    let has_missing = feature_has_missing(table, feature_index, rows);
    // Mean-criterion fast paths are only valid when no value is missing.
    if matches!(criterion, Criterion::Mean) && !has_missing {
        if matches!(algorithm, RegressionTreeAlgorithm::Randomized) {
            if let Some(candidate) = score_randomized_split_mean_fast(
                table,
                targets,
                feature_index,
                rows,
                min_samples_leaf,
            ) {
                return Some(candidate);
            }
        } else if let Some(candidate) =
            score_numeric_split_mean_fast(table, targets, feature_index, rows, min_samples_leaf)
        {
            return Some(candidate);
        }
    }
    if matches!(algorithm, RegressionTreeAlgorithm::Randomized) {
        return score_randomized_split(
            table,
            targets,
            feature_index,
            rows,
            criterion,
            min_samples_leaf,
            strategy,
        );
    }
    if has_missing && matches!(strategy, MissingValueStrategy::Heuristic) {
        return score_split_heuristic_missing_assignment(
            table,
            targets,
            feature_index,
            rows,
            criterion,
            min_samples_leaf,
        );
    }
    let parent_loss = regression_loss(rows, targets, criterion);

    // Exhaustive fallback: evaluate every distinct observed bin as a threshold
    // (BTreeSet dedupes and keeps them ordered) and keep the highest score.
    rows.iter()
        .copied()
        .filter(|row_idx| !table.is_missing(feature_index, *row_idx))
        .map(|row_idx| table.binned_value(feature_index, row_idx))
        .collect::<BTreeSet<_>>()
        .into_iter()
        .filter_map(|threshold_bin| {
            evaluate_regression_missing_assignment(
                table,
                targets,
                feature_index,
                rows,
                criterion,
                min_samples_leaf,
                threshold_bin,
                parent_loss,
            )
        })
        .max_by(|left, right| left.score.total_cmp(&right.score))
}
1135
/// Extra-trees style scoring: picks one seeded-random threshold among the
/// distinct observed bins and evaluates only that split.
fn score_randomized_split(
    table: &dyn TableAccess,
    targets: &[f64],
    feature_index: usize,
    rows: &[usize],
    criterion: Criterion,
    min_samples_leaf: usize,
    _strategy: MissingValueStrategy,
) -> Option<RegressionSplitCandidate> {
    // Distinct observed bins, ordered (missing values excluded).
    let candidate_thresholds = rows
        .iter()
        .copied()
        .filter(|row_idx| !table.is_missing(feature_index, *row_idx))
        .map(|row_idx| table.binned_value(feature_index, row_idx))
        .collect::<BTreeSet<_>>()
        .into_iter()
        .collect::<Vec<_>>();
    // Deterministic pseudo-random choice; None when no threshold is available.
    let threshold_bin =
        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xA11CE551u64)?;

    let parent_loss = regression_loss(rows, targets, criterion);
    evaluate_regression_missing_assignment(
        table,
        targets,
        feature_index,
        rows,
        criterion,
        min_samples_leaf,
        threshold_bin,
        parent_loss,
    )
}
1168
/// Scores one feature as a level-wide oblivious split: the candidate's score
/// is the summed loss reduction across every current leaf. Leaves where a
/// threshold would violate `min_samples_leaf` contribute zero.
fn score_oblivious_split(
    table: &dyn TableAccess,
    row_indices: &[usize],
    targets: &[f64],
    feature_index: usize,
    leaves: &[ObliviousLeafState],
    criterion: Criterion,
    min_samples_leaf: usize,
) -> Option<ObliviousSplitCandidate> {
    if table.is_binary_binned_feature(feature_index) {
        // Mean-criterion fast path for binary features, then the general scorer.
        if matches!(criterion, Criterion::Mean)
            && let Some(candidate) = score_binary_oblivious_split_mean_fast(
                table,
                row_indices,
                targets,
                feature_index,
                leaves,
                min_samples_leaf,
            )
        {
            return Some(candidate);
        }
        return score_binary_oblivious_split(
            table,
            row_indices,
            targets,
            feature_index,
            leaves,
            criterion,
            min_samples_leaf,
        );
    }
    // Mean-criterion fast path for numeric features.
    if matches!(criterion, Criterion::Mean)
        && let Some(candidate) = score_numeric_oblivious_split_mean_fast(
            table,
            row_indices,
            targets,
            feature_index,
            leaves,
            min_samples_leaf,
        )
    {
        return Some(candidate);
    }
    // General fallback: every distinct bin observed in any leaf is a candidate.
    let candidate_thresholds = leaves
        .iter()
        .flat_map(|leaf| {
            row_indices[leaf.start..leaf.end]
                .iter()
                .map(|row_idx| table.binned_value(feature_index, *row_idx))
        })
        .collect::<BTreeSet<_>>();

    candidate_thresholds
        .into_iter()
        .filter_map(|threshold_bin| {
            // Accumulate the loss reduction this threshold yields per leaf.
            let score = leaves.iter().fold(0.0, |score, leaf| {
                let leaf_rows = &row_indices[leaf.start..leaf.end];
                let (left_rows, right_rows): (Vec<usize>, Vec<usize>) =
                    leaf_rows.iter().copied().partition(|row_idx| {
                        table.binned_value(feature_index, *row_idx) <= threshold_bin
                    });

                // Leaf would produce an undersized child: no contribution.
                if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
                    return score;
                }

                score + regression_loss(leaf_rows, targets, criterion)
                    - (regression_loss(&left_rows, targets, criterion)
                        + regression_loss(&right_rows, targets, criterion))
            });

            (score > 0.0).then_some(ObliviousSplitCandidate {
                feature_index,
                threshold_bin,
                score,
            })
        })
        .max_by(|left, right| left.score.total_cmp(&right.score))
}
1249
/// Applies one oblivious split (same feature/threshold for every leaf) by
/// partitioning each leaf's row range in place, returning the next level's
/// leaf states — two children per input leaf, possibly empty.
///
/// Rows with missing values are routed to the right branch (the
/// `MissingBranchDirection::Right` argument below). An empty child inherits
/// the parent's prediction so lookups never see an undefined value.
fn split_oblivious_leaves_in_place(
    table: &dyn TableAccess,
    row_indices: &mut [usize],
    targets: &[f64],
    leaves: Vec<ObliviousLeafState>,
    feature_index: usize,
    threshold_bin: u16,
    criterion: Criterion,
) -> Vec<ObliviousLeafState> {
    let mut next_leaves = Vec::with_capacity(leaves.len() * 2);
    for leaf in leaves {
        let fallback_value = leaf.value;
        // Reorders the leaf's slice so left-branch rows come first, and
        // returns how many went left.
        let left_count = partition_rows_for_binary_split(
            table,
            feature_index,
            threshold_bin,
            MissingBranchDirection::Right,
            &mut row_indices[leaf.start..leaf.end],
        );
        let mid = leaf.start + left_count;
        let left_rows = &row_indices[leaf.start..mid];
        let right_rows = &row_indices[mid..leaf.end];
        let (left_sum, left_sum_sq) = sum_stats(left_rows, targets);
        let (right_sum, right_sum_sq) = sum_stats(right_rows, targets);
        next_leaves.push(ObliviousLeafState {
            start: leaf.start,
            end: mid,
            value: if left_rows.is_empty() {
                fallback_value
            } else {
                regression_value_from_stats(left_rows, targets, criterion, left_sum)
            },
            variance: variance_from_stats(left_rows.len(), left_sum, left_sum_sq),
            sum: left_sum,
            sum_sq: left_sum_sq,
        });
        next_leaves.push(ObliviousLeafState {
            start: mid,
            end: leaf.end,
            value: if right_rows.is_empty() {
                fallback_value
            } else {
                regression_value_from_stats(right_rows, targets, criterion, right_sum)
            },
            variance: variance_from_stats(right_rows.len(), right_sum, right_sum_sq),
            sum: right_sum,
            sum_sq: right_sum_sq,
        });
    }
    next_leaves
}
1301
1302fn variance(rows: &[usize], targets: &[f64]) -> Option<f64> {
1303 let (sum, sum_sq) = sum_stats(rows, targets);
1304 variance_from_stats(rows.len(), sum, sum_sq)
1305}
1306
/// Arithmetic mean of the targets selected by `rows`; 0.0 for an empty slice.
fn mean(rows: &[usize], targets: &[f64]) -> f64 {
    match rows.len() {
        0 => 0.0,
        n => {
            let total: f64 = rows.iter().map(|&row_idx| targets[row_idx]).sum();
            total / n as f64
        }
    }
}
1314
/// Median of the targets selected by `rows`; 0.0 for an empty slice.
///
/// Even-length selections return the midpoint of the two central values.
fn median(rows: &[usize], targets: &[f64]) -> f64 {
    let mut values: Vec<f64> = rows.iter().map(|&row_idx| targets[row_idx]).collect();
    if values.is_empty() {
        return 0.0;
    }
    // total_cmp gives a total order over f64 (handles NaN deterministically).
    values.sort_by(f64::total_cmp);
    let mid = values.len() / 2;
    match values.len() % 2 {
        0 => (values[mid - 1] + values[mid]) / 2.0,
        _ => values[mid],
    }
}
1329
1330fn sum_squared_error(rows: &[usize], targets: &[f64]) -> f64 {
1331 let mean = mean(rows, targets);
1332 rows.iter()
1333 .map(|row_idx| {
1334 let diff = targets[*row_idx] - mean;
1335 diff * diff
1336 })
1337 .sum()
1338}
1339
1340fn sum_absolute_error(rows: &[usize], targets: &[f64]) -> f64 {
1341 let median = median(rows, targets);
1342 rows.iter()
1343 .map(|row_idx| (targets[*row_idx] - median).abs())
1344 .sum()
1345}
1346
1347fn regression_value(rows: &[usize], targets: &[f64], criterion: Criterion) -> f64 {
1348 let (sum, _sum_sq) = sum_stats(rows, targets);
1349 regression_value_from_stats(rows, targets, criterion, sum)
1350}
1351
1352fn regression_value_from_stats(
1353 rows: &[usize],
1354 targets: &[f64],
1355 criterion: Criterion,
1356 sum: f64,
1357) -> f64 {
1358 match criterion {
1359 Criterion::Mean => {
1360 if rows.is_empty() {
1361 0.0
1362 } else {
1363 sum / rows.len() as f64
1364 }
1365 }
1366 Criterion::Median => median(rows, targets),
1367 _ => unreachable!("regression criterion only supports mean or median"),
1368 }
1369}
1370
/// Impurity of `rows` under `criterion`: SSE about the mean for `Mean`, sum of
/// absolute deviations about the median for `Median`.
fn regression_loss(rows: &[usize], targets: &[f64], criterion: Criterion) -> f64 {
    match criterion {
        Criterion::Mean => sum_squared_error(rows, targets),
        Criterion::Median => sum_absolute_error(rows, targets),
        _ => unreachable!("regression criterion only supports mean or median"),
    }
}
1378
1379fn score_binary_split(
1380 table: &dyn TableAccess,
1381 targets: &[f64],
1382 feature_index: usize,
1383 rows: &[usize],
1384 criterion: Criterion,
1385 min_samples_leaf: usize,
1386 strategy: MissingValueStrategy,
1387) -> Option<RegressionSplitCandidate> {
1388 if matches!(strategy, MissingValueStrategy::Heuristic) {
1389 return score_binary_split_heuristic(
1390 table,
1391 targets,
1392 feature_index,
1393 rows,
1394 criterion,
1395 min_samples_leaf,
1396 );
1397 }
1398 let parent_loss = regression_loss(rows, targets, criterion);
1399 evaluate_regression_missing_assignment(
1400 table,
1401 targets,
1402 feature_index,
1403 rows,
1404 criterion,
1405 min_samples_leaf,
1406 0,
1407 parent_loss,
1408 )
1409}
1410
1411fn score_binary_split_heuristic(
1412 table: &dyn TableAccess,
1413 targets: &[f64],
1414 feature_index: usize,
1415 rows: &[usize],
1416 criterion: Criterion,
1417 min_samples_leaf: usize,
1418) -> Option<RegressionSplitCandidate> {
1419 let observed_rows = rows
1420 .iter()
1421 .copied()
1422 .filter(|row_idx| !table.is_missing(feature_index, *row_idx))
1423 .collect::<Vec<_>>();
1424 if observed_rows.is_empty() {
1425 return None;
1426 }
1427 let parent_loss = regression_loss(&observed_rows, targets, criterion);
1428 let mut left_rows = Vec::new();
1429 let mut right_rows = Vec::new();
1430 for row_idx in observed_rows.iter().copied() {
1431 if !table
1432 .binned_boolean_value(feature_index, row_idx)
1433 .expect("observed binary feature must expose boolean values")
1434 {
1435 left_rows.push(row_idx);
1436 } else {
1437 right_rows.push(row_idx);
1438 }
1439 }
1440 if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
1441 return None;
1442 }
1443 evaluate_regression_missing_assignment(
1444 table,
1445 targets,
1446 feature_index,
1447 rows,
1448 criterion,
1449 min_samples_leaf,
1450 0,
1451 parent_loss,
1452 )
1453}
1454
/// Heuristic missing-value scorer for numeric features: picks the best
/// threshold using only observed (non-missing) rows, then re-evaluates that
/// single threshold with the missing rows assigned to each branch.
fn score_split_heuristic_missing_assignment(
    table: &dyn TableAccess,
    targets: &[f64],
    feature_index: usize,
    rows: &[usize],
    criterion: Criterion,
    min_samples_leaf: usize,
) -> Option<RegressionSplitCandidate> {
    let observed_rows = rows
        .iter()
        .copied()
        .filter(|row_idx| !table.is_missing(feature_index, *row_idx))
        .collect::<Vec<_>>();
    if observed_rows.is_empty() {
        return None;
    }
    // Parent loss is measured on observed rows only, matching the candidate
    // thresholds evaluated below.
    let parent_loss = regression_loss(&observed_rows, targets, criterion);
    // Deduplicated, ordered thresholds from observed bins; keep the one with
    // the largest observed-rows gain. `?` bails out if none is feasible.
    let threshold_bin = observed_rows
        .iter()
        .copied()
        .map(|row_idx| table.binned_value(feature_index, row_idx))
        .collect::<BTreeSet<_>>()
        .into_iter()
        .filter_map(|threshold_bin| {
            evaluate_regression_observed_split(
                table,
                targets,
                feature_index,
                &observed_rows,
                criterion,
                min_samples_leaf,
                threshold_bin,
                parent_loss,
            )
            .map(|score| (threshold_bin, score))
        })
        .max_by(|left, right| left.1.total_cmp(&right.1))
        .map(|(threshold_bin, _)| threshold_bin)?;
    // Decide where the missing rows go for the chosen threshold.
    evaluate_regression_missing_assignment(
        table,
        targets,
        feature_index,
        rows,
        criterion,
        min_samples_leaf,
        threshold_bin,
        parent_loss,
    )
}
1504
1505#[allow(clippy::too_many_arguments)]
1506fn evaluate_regression_observed_split(
1507 table: &dyn TableAccess,
1508 targets: &[f64],
1509 feature_index: usize,
1510 observed_rows: &[usize],
1511 criterion: Criterion,
1512 min_samples_leaf: usize,
1513 threshold_bin: u16,
1514 parent_loss: f64,
1515) -> Option<f64> {
1516 let mut left_rows = Vec::new();
1517 let mut right_rows = Vec::new();
1518 for row_idx in observed_rows.iter().copied() {
1519 if table.binned_value(feature_index, row_idx) <= threshold_bin {
1520 left_rows.push(row_idx);
1521 } else {
1522 right_rows.push(row_idx);
1523 }
1524 }
1525 if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
1526 return None;
1527 }
1528 Some(
1529 parent_loss
1530 - (regression_loss(&left_rows, targets, criterion)
1531 + regression_loss(&right_rows, targets, criterion)),
1532 )
1533}
1534
1535fn score_binary_split_choice(
1536 context: &BuildContext<'_>,
1537 feature_index: usize,
1538 rows: &[usize],
1539) -> Option<BinarySplitChoice> {
1540 if matches!(context.criterion, Criterion::Mean) {
1541 if context.table.is_binary_binned_feature(feature_index) {
1542 if feature_has_missing(context.table, feature_index, rows) {
1543 return score_split(
1544 context.table,
1545 context.targets,
1546 feature_index,
1547 rows,
1548 context.criterion,
1549 context.options.min_samples_leaf,
1550 context.algorithm,
1551 context.options.missing_value_strategy(feature_index),
1552 )
1553 .map(|candidate| BinarySplitChoice {
1554 feature_index: candidate.feature_index,
1555 threshold_bin: candidate.threshold_bin,
1556 score: candidate.score,
1557 missing_direction: candidate.missing_direction,
1558 });
1559 }
1560 return score_binary_split_choice_mean(context, feature_index, rows);
1561 }
1562 if feature_has_missing(context.table, feature_index, rows) {
1563 return score_split(
1564 context.table,
1565 context.targets,
1566 feature_index,
1567 rows,
1568 context.criterion,
1569 context.options.min_samples_leaf,
1570 context.algorithm,
1571 context.options.missing_value_strategy(feature_index),
1572 )
1573 .map(|candidate| BinarySplitChoice {
1574 feature_index: candidate.feature_index,
1575 threshold_bin: candidate.threshold_bin,
1576 score: candidate.score,
1577 missing_direction: candidate.missing_direction,
1578 });
1579 }
1580 return match context.algorithm {
1581 RegressionTreeAlgorithm::Cart => {
1582 score_numeric_split_choice_mean_fast(context, feature_index, rows)
1583 }
1584 RegressionTreeAlgorithm::Randomized => {
1585 score_randomized_split_choice_mean_fast(context, feature_index, rows)
1586 }
1587 RegressionTreeAlgorithm::Oblivious => None,
1588 };
1589 }
1590
1591 score_split(
1592 context.table,
1593 context.targets,
1594 feature_index,
1595 rows,
1596 context.criterion,
1597 context.options.min_samples_leaf,
1598 context.algorithm,
1599 context.options.missing_value_strategy(feature_index),
1600 )
1601 .map(|candidate| BinarySplitChoice {
1602 feature_index: candidate.feature_index,
1603 threshold_bin: candidate.threshold_bin,
1604 score: candidate.score,
1605 missing_direction: candidate.missing_direction,
1606 })
1607}
1608
/// Like [`score_binary_split_choice`], but reuses a prebuilt feature
/// histogram when the criterion is `Mean` and the feature has no missing
/// values in this node; otherwise falls back to the row-scanning scorer.
fn score_binary_split_choice_from_hist(
    context: &BuildContext<'_>,
    histogram: &RegressionFeatureHistogram,
    feature_index: usize,
    rows: &[usize],
) -> Option<BinarySplitChoice> {
    if !matches!(context.criterion, Criterion::Mean) {
        // Histograms only cache mean-criterion (sum/sum_sq) aggregates.
        return score_binary_split_choice(context, feature_index, rows);
    }

    match histogram {
        // Binary feature, no missing rows: both branch stats are already
        // aggregated in the histogram.
        RegressionFeatureHistogram::Binary {
            false_bin,
            true_bin,
            missing_bin,
        } if missing_bin.count == 0 => score_binary_split_choice_mean_from_stats(
            context,
            feature_index,
            false_bin.count,
            false_bin.sum,
            false_bin.sum_sq,
            true_bin.count,
            true_bin.sum,
            true_bin.sum_sq,
        ),
        // Binary feature with missing rows: needs the generic path.
        RegressionFeatureHistogram::Binary { .. } => {
            score_binary_split_choice(context, feature_index, rows)
        }
        // Numeric feature whose missing bin (at index numeric_bin_cap) is
        // absent or empty: scan the histogram directly.
        RegressionFeatureHistogram::Numeric {
            bins,
            observed_bins,
        } if bins
            .get(context.table.numeric_bin_cap())
            .is_none_or(|missing_bin| missing_bin.count == 0) =>
        {
            match context.algorithm {
                RegressionTreeAlgorithm::Cart => score_numeric_split_choice_mean_from_hist(
                    context,
                    feature_index,
                    rows.len(),
                    bins,
                    observed_bins,
                ),
                RegressionTreeAlgorithm::Randomized => {
                    score_randomized_split_choice_mean_from_hist(
                        context,
                        feature_index,
                        rows,
                        bins,
                        observed_bins,
                    )
                }
                // Oblivious trees never take the per-node histogram path.
                RegressionTreeAlgorithm::Oblivious => None,
            }
        }
        // Numeric feature with missing rows: fall back to row scanning.
        RegressionFeatureHistogram::Numeric { .. } => {
            score_binary_split_choice(context, feature_index, rows)
        }
    }
}
1669
1670#[allow(clippy::too_many_arguments)]
1671fn score_binary_split_choice_mean_from_stats(
1672 context: &BuildContext<'_>,
1673 feature_index: usize,
1674 left_count: usize,
1675 left_sum: f64,
1676 left_sum_sq: f64,
1677 right_count: usize,
1678 right_sum: f64,
1679 right_sum_sq: f64,
1680) -> Option<BinarySplitChoice> {
1681 if left_count < context.options.min_samples_leaf
1682 || right_count < context.options.min_samples_leaf
1683 {
1684 return None;
1685 }
1686 let total_count = left_count + right_count;
1687 let total_sum = left_sum + right_sum;
1688 let total_sum_sq = left_sum_sq + right_sum_sq;
1689 let parent_loss = total_sum_sq - (total_sum * total_sum) / total_count as f64;
1690 let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
1691 let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
1692 Some(BinarySplitChoice {
1693 feature_index,
1694 threshold_bin: 0,
1695 score: parent_loss - (left_loss + right_loss),
1696 missing_direction: MissingBranchDirection::Node,
1697 })
1698}
1699
/// CART scan over a prebuilt histogram (mean criterion, no missing values):
/// sweeps observed bins left-to-right while maintaining running left-side
/// aggregates, and returns the threshold with the largest SSE reduction.
fn score_numeric_split_choice_mean_from_hist(
    context: &BuildContext<'_>,
    feature_index: usize,
    row_count: usize,
    bins: &[RegressionHistogramBin],
    observed_bins: &[usize],
) -> Option<BinarySplitChoice> {
    if observed_bins.len() <= 1 {
        // A single observed bin cannot be split.
        return None;
    }
    let total_sum = bins.iter().map(|bin| bin.sum).sum::<f64>();
    let total_sum_sq = bins.iter().map(|bin| bin.sum_sq).sum::<f64>();
    // SSE from aggregates: Σy² − (Σy)² / n.
    let parent_loss = total_sum_sq - (total_sum * total_sum) / row_count as f64;
    let mut left_count = 0usize;
    let mut left_sum = 0.0;
    let mut left_sum_sq = 0.0;
    let mut best_threshold = None;
    let mut best_score = f64::NEG_INFINITY;

    for &bin in observed_bins {
        // Move this bin's rows to the left side; the threshold is "<= bin".
        left_count += bins[bin].count;
        left_sum += bins[bin].sum;
        left_sum_sq += bins[bin].sum_sq;
        let right_count = row_count - left_count;
        if left_count < context.options.min_samples_leaf
            || right_count < context.options.min_samples_leaf
        {
            continue;
        }
        let right_sum = total_sum - left_sum;
        let right_sum_sq = total_sum_sq - left_sum_sq;
        let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
        let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
        let score = parent_loss - (left_loss + right_loss);
        if score > best_score {
            best_score = score;
            best_threshold = Some(bin as u16);
        }
    }

    best_threshold.map(|threshold_bin| BinarySplitChoice {
        feature_index,
        threshold_bin,
        score: best_score,
        missing_direction: MissingBranchDirection::Node,
    })
}
1747
/// Extra-Trees style scorer over a prebuilt histogram: picks ONE threshold
/// pseudo-randomly (deterministic in feature/rows and the fixed salt) and
/// scores only that split instead of sweeping every bin.
fn score_randomized_split_choice_mean_from_hist(
    context: &BuildContext<'_>,
    feature_index: usize,
    rows: &[usize],
    bins: &[RegressionHistogramBin],
    observed_bins: &[usize],
) -> Option<BinarySplitChoice> {
    let candidate_thresholds = observed_bins
        .iter()
        .copied()
        .map(|bin| bin as u16)
        .collect::<Vec<_>>();
    // Same salt as the row-scanning randomized scorers, so both paths pick
    // the same threshold for the same node.
    let threshold_bin =
        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xA11CE551u64)?;
    let total_sum = bins.iter().map(|bin| bin.sum).sum::<f64>();
    let total_sum_sq = bins.iter().map(|bin| bin.sum_sq).sum::<f64>();
    let mut left_count = 0usize;
    let mut left_sum = 0.0;
    let mut left_sum_sq = 0.0;
    // Accumulate every bin up to and including the threshold; empty bins
    // contribute zero. The bound check is defensive.
    for bin in 0..=threshold_bin as usize {
        if bin >= bins.len() {
            break;
        }
        left_count += bins[bin].count;
        left_sum += bins[bin].sum;
        left_sum_sq += bins[bin].sum_sq;
    }
    let right_count = rows.len() - left_count;
    if left_count < context.options.min_samples_leaf
        || right_count < context.options.min_samples_leaf
    {
        return None;
    }
    // SSE from aggregates: Σy² − (Σy)² / n.
    let parent_loss = total_sum_sq - (total_sum * total_sum) / rows.len() as f64;
    let right_sum = total_sum - left_sum;
    let right_sum_sq = total_sum_sq - left_sum_sq;
    let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
    let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
    Some(BinarySplitChoice {
        feature_index,
        threshold_bin,
        score: parent_loss - (left_loss + right_loss),
        missing_direction: MissingBranchDirection::Node,
    })
}
1793
/// Mean-criterion scorer for a binary feature with no missing values in this
/// node: a single pass accumulates total and left-branch aggregates, then the
/// split is scored as the SSE reduction.
fn score_binary_split_choice_mean(
    context: &BuildContext<'_>,
    feature_index: usize,
    rows: &[usize],
) -> Option<BinarySplitChoice> {
    let mut left_count = 0usize;
    let mut left_sum = 0.0;
    let mut left_sum_sq = 0.0;
    let mut total_sum = 0.0;
    let mut total_sum_sq = 0.0;

    for row_idx in rows {
        let target = context.targets[*row_idx];
        total_sum += target;
        total_sum_sq += target * target;
        // `false` routes left, `true` routes right.
        if !context
            .table
            .binned_boolean_value(feature_index, *row_idx)
            .expect("binary feature must expose boolean values")
        {
            left_count += 1;
            left_sum += target;
            left_sum_sq += target * target;
        }
    }

    let right_count = rows.len() - left_count;
    if left_count < context.options.min_samples_leaf
        || right_count < context.options.min_samples_leaf
    {
        return None;
    }

    // SSE from aggregates: Σy² − (Σy)² / n.
    let parent_loss = total_sum_sq - (total_sum * total_sum) / rows.len() as f64;
    let right_sum = total_sum - left_sum;
    let right_sum_sq = total_sum_sq - left_sum_sq;
    let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
    let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;

    Some(BinarySplitChoice {
        feature_index,
        threshold_bin: 0,
        score: parent_loss - (left_loss + right_loss),
        missing_direction: MissingBranchDirection::Node,
    })
}
1840
/// Histogram-based CART scorer for a numeric feature under the mean
/// criterion: builds per-bin count/sum/sum-of-squares in one pass over the
/// rows, then sweeps observed bins to find the best SSE-reducing threshold.
///
/// Returns `None` when the fast path does not apply (a bin at or above
/// `numeric_bin_cap`, e.g. a missing-value bin) or no legal split exists.
fn score_numeric_split_mean_fast(
    table: &dyn TableAccess,
    targets: &[f64],
    feature_index: usize,
    rows: &[usize],
    min_samples_leaf: usize,
) -> Option<RegressionSplitCandidate> {
    let bin_cap = table.numeric_bin_cap();
    if bin_cap == 0 {
        return None;
    }

    let mut bin_count = vec![0usize; bin_cap];
    let mut bin_sum = vec![0.0; bin_cap];
    let mut bin_sum_sq = vec![0.0; bin_cap];
    let mut observed_bins = vec![false; bin_cap];
    let mut total_sum = 0.0;
    let mut total_sum_sq = 0.0;

    for row_idx in rows {
        let bin = table.binned_value(feature_index, *row_idx) as usize;
        if bin >= bin_cap {
            // Out-of-range bin: this feature cannot use the fast path.
            return None;
        }
        let target = targets[*row_idx];
        bin_count[bin] += 1;
        bin_sum[bin] += target;
        bin_sum_sq[bin] += target * target;
        observed_bins[bin] = true;
        total_sum += target;
        total_sum_sq += target * target;
    }

    // Shadow the flag vector with the ordered list of non-empty bin indices.
    let observed_bins: Vec<usize> = observed_bins
        .into_iter()
        .enumerate()
        .filter_map(|(bin, seen)| seen.then_some(bin))
        .collect();
    if observed_bins.len() <= 1 {
        return None;
    }

    // SSE from aggregates: Σy² − (Σy)² / n.
    let parent_loss = total_sum_sq - (total_sum * total_sum) / rows.len() as f64;
    let mut left_count = 0usize;
    let mut left_sum = 0.0;
    let mut left_sum_sq = 0.0;
    let mut best_threshold = None;
    let mut best_score = f64::NEG_INFINITY;

    for &bin in &observed_bins {
        // Move this bin's rows to the left side; the threshold is "<= bin".
        left_count += bin_count[bin];
        left_sum += bin_sum[bin];
        left_sum_sq += bin_sum_sq[bin];
        let right_count = rows.len() - left_count;

        if left_count < min_samples_leaf || right_count < min_samples_leaf {
            continue;
        }

        let right_sum = total_sum - left_sum;
        let right_sum_sq = total_sum_sq - left_sum_sq;
        let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
        let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
        let score = parent_loss - (left_loss + right_loss);
        if score > best_score {
            best_score = score;
            best_threshold = Some(bin as u16);
        }
    }

    let threshold_bin = best_threshold?;
    Some(RegressionSplitCandidate {
        feature_index,
        threshold_bin,
        score: best_score,
        missing_direction: MissingBranchDirection::Node,
    })
}
1919
1920fn score_numeric_split_choice_mean_fast(
1921 context: &BuildContext<'_>,
1922 feature_index: usize,
1923 rows: &[usize],
1924) -> Option<BinarySplitChoice> {
1925 score_numeric_split_mean_fast(
1926 context.table,
1927 context.targets,
1928 feature_index,
1929 rows,
1930 context.options.min_samples_leaf,
1931 )
1932 .map(|candidate| BinarySplitChoice {
1933 feature_index: candidate.feature_index,
1934 threshold_bin: candidate.threshold_bin,
1935 score: candidate.score,
1936 missing_direction: MissingBranchDirection::Node,
1937 })
1938}
1939
/// Extra-Trees style mean-criterion scorer for a numeric feature: chooses a
/// single pseudo-random threshold among observed bins, then scores that one
/// split in a second pass over the rows.
///
/// Returns `None` when the fast path does not apply (a bin at or above
/// `numeric_bin_cap`) or either branch violates `min_samples_leaf`.
fn score_randomized_split_mean_fast(
    table: &dyn TableAccess,
    targets: &[f64],
    feature_index: usize,
    rows: &[usize],
    min_samples_leaf: usize,
) -> Option<RegressionSplitCandidate> {
    let bin_cap = table.numeric_bin_cap();
    if bin_cap == 0 {
        return None;
    }
    // First pass: mark which bins actually occur in this node.
    let mut observed_bins = vec![false; bin_cap];
    for row_idx in rows {
        let bin = table.binned_value(feature_index, *row_idx) as usize;
        if bin >= bin_cap {
            return None;
        }
        observed_bins[bin] = true;
    }
    let candidate_thresholds = observed_bins
        .into_iter()
        .enumerate()
        .filter_map(|(bin, seen)| seen.then_some(bin as u16))
        .collect::<Vec<_>>();
    // Same salt as the histogram-based randomized scorer, so both paths pick
    // the same threshold for the same node.
    let threshold_bin =
        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xA11CE551u64)?;

    // Second pass: accumulate total and left-branch aggregates.
    let mut left_count = 0usize;
    let mut left_sum = 0.0;
    let mut left_sum_sq = 0.0;
    let mut total_sum = 0.0;
    let mut total_sum_sq = 0.0;
    for row_idx in rows {
        let target = targets[*row_idx];
        total_sum += target;
        total_sum_sq += target * target;
        if table.binned_value(feature_index, *row_idx) <= threshold_bin {
            left_count += 1;
            left_sum += target;
            left_sum_sq += target * target;
        }
    }
    let right_count = rows.len() - left_count;
    if left_count < min_samples_leaf || right_count < min_samples_leaf {
        return None;
    }
    // SSE from aggregates: Σy² − (Σy)² / n.
    let parent_loss = total_sum_sq - (total_sum * total_sum) / rows.len() as f64;
    let right_sum = total_sum - left_sum;
    let right_sum_sq = total_sum_sq - left_sum_sq;
    let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
    let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
    let score = parent_loss - (left_loss + right_loss);

    Some(RegressionSplitCandidate {
        feature_index,
        threshold_bin,
        score,
        missing_direction: MissingBranchDirection::Node,
    })
}
2000
2001fn score_randomized_split_choice_mean_fast(
2002 context: &BuildContext<'_>,
2003 feature_index: usize,
2004 rows: &[usize],
2005) -> Option<BinarySplitChoice> {
2006 score_randomized_split_mean_fast(
2007 context.table,
2008 context.targets,
2009 feature_index,
2010 rows,
2011 context.options.min_samples_leaf,
2012 )
2013 .map(|candidate| BinarySplitChoice {
2014 feature_index: candidate.feature_index,
2015 threshold_bin: candidate.threshold_bin,
2016 score: candidate.score,
2017 missing_direction: MissingBranchDirection::Node,
2018 })
2019}
2020
2021fn feature_has_missing(table: &dyn TableAccess, feature_index: usize, rows: &[usize]) -> bool {
2022 rows.iter()
2023 .any(|row_idx| table.is_missing(feature_index, *row_idx))
2024}
2025
#[allow(clippy::too_many_arguments)]
/// Partitions `rows` into left/right/missing for the given threshold, then
/// returns the better candidate of sending missing rows left vs right.
///
/// Binary features split on their boolean value (`false` → left); numeric
/// features split on `binned_value <= threshold_bin`. With no missing rows,
/// the single `Node`-direction candidate is evaluated instead.
fn evaluate_regression_missing_assignment(
    table: &dyn TableAccess,
    targets: &[f64],
    feature_index: usize,
    rows: &[usize],
    criterion: Criterion,
    min_samples_leaf: usize,
    threshold_bin: u16,
    parent_loss: f64,
) -> Option<RegressionSplitCandidate> {
    let mut left_rows = Vec::new();
    let mut right_rows = Vec::new();
    let mut missing_rows = Vec::new();

    for row_idx in rows.iter().copied() {
        if table.is_missing(feature_index, row_idx) {
            missing_rows.push(row_idx);
        } else if table.is_binary_binned_feature(feature_index) {
            if !table
                .binned_boolean_value(feature_index, row_idx)
                .expect("observed binary feature must expose boolean values")
            {
                left_rows.push(row_idx);
            } else {
                right_rows.push(row_idx);
            }
        } else if table.binned_value(feature_index, row_idx) <= threshold_bin {
            left_rows.push(row_idx);
        } else {
            right_rows.push(row_idx);
        }
    }

    // Scores one missing-branch assignment against `parent_loss`; `None` when
    // the resulting branches would violate `min_samples_leaf`.
    let evaluate = |direction: MissingBranchDirection| {
        let mut candidate_left = left_rows.clone();
        let mut candidate_right = right_rows.clone();
        match direction {
            MissingBranchDirection::Left => candidate_left.extend(missing_rows.iter().copied()),
            MissingBranchDirection::Right => candidate_right.extend(missing_rows.iter().copied()),
            MissingBranchDirection::Node => {}
        }
        if candidate_left.len() < min_samples_leaf || candidate_right.len() < min_samples_leaf {
            return None;
        }

        let score = parent_loss
            - (regression_loss(&candidate_left, targets, criterion)
                + regression_loss(&candidate_right, targets, criterion));
        Some(RegressionSplitCandidate {
            feature_index,
            threshold_bin,
            score,
            missing_direction: direction,
        })
    };

    if missing_rows.is_empty() {
        evaluate(MissingBranchDirection::Node)
    } else {
        // Try both assignments and keep the higher-scoring one.
        [MissingBranchDirection::Left, MissingBranchDirection::Right]
            .into_iter()
            .filter_map(evaluate)
            .max_by(|left, right| left.score.total_cmp(&right.score))
    }
}
2092
2093fn score_numeric_oblivious_split_mean_fast(
2094 table: &dyn TableAccess,
2095 row_indices: &[usize],
2096 targets: &[f64],
2097 feature_index: usize,
2098 leaves: &[ObliviousLeafState],
2099 min_samples_leaf: usize,
2100) -> Option<ObliviousSplitCandidate> {
2101 let bin_cap = table.numeric_bin_cap();
2102 if bin_cap == 0 {
2103 return None;
2104 }
2105 let mut threshold_scores = vec![0.0; bin_cap];
2106 let mut observed_any = false;
2107
2108 for leaf in leaves {
2109 let mut bin_count = vec![0usize; bin_cap];
2110 let mut bin_sum = vec![0.0; bin_cap];
2111 let mut bin_sum_sq = vec![0.0; bin_cap];
2112 let mut observed_bins = vec![false; bin_cap];
2113
2114 for row_idx in &row_indices[leaf.start..leaf.end] {
2115 let bin = table.binned_value(feature_index, *row_idx) as usize;
2116 if bin >= bin_cap {
2117 return None;
2118 }
2119 let target = targets[*row_idx];
2120 bin_count[bin] += 1;
2121 bin_sum[bin] += target;
2122 bin_sum_sq[bin] += target * target;
2123 observed_bins[bin] = true;
2124 }
2125
2126 let observed_bins: Vec<usize> = observed_bins
2127 .into_iter()
2128 .enumerate()
2129 .filter_map(|(bin, seen)| seen.then_some(bin))
2130 .collect();
2131 if observed_bins.len() <= 1 {
2132 continue;
2133 }
2134 observed_any = true;
2135
2136 let parent_loss = leaf.sum_sq - (leaf.sum * leaf.sum) / leaf.len() as f64;
2137 let mut left_count = 0usize;
2138 let mut left_sum = 0.0;
2139 let mut left_sum_sq = 0.0;
2140
2141 for &bin in &observed_bins {
2142 left_count += bin_count[bin];
2143 left_sum += bin_sum[bin];
2144 left_sum_sq += bin_sum_sq[bin];
2145 let right_count = leaf.len() - left_count;
2146
2147 if left_count < min_samples_leaf || right_count < min_samples_leaf {
2148 continue;
2149 }
2150
2151 let right_sum = leaf.sum - left_sum;
2152 let right_sum_sq = leaf.sum_sq - left_sum_sq;
2153 let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
2154 let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
2155 threshold_scores[bin] += parent_loss - (left_loss + right_loss);
2156 }
2157 }
2158
2159 if !observed_any {
2160 return None;
2161 }
2162
2163 threshold_scores
2164 .into_iter()
2165 .enumerate()
2166 .filter(|(_, score)| *score > 0.0)
2167 .max_by(|left, right| left.1.total_cmp(&right.1))
2168 .map(|(threshold_bin, score)| ObliviousSplitCandidate {
2169 feature_index,
2170 threshold_bin: threshold_bin as u16,
2171 score,
2172 })
2173}
2174
/// Generic oblivious scorer for a binary feature: each leaf partitions on the
/// boolean value (`false` → left) and contributes its loss reduction to one
/// shared score. Leaves that cannot be split legally contribute nothing.
/// Returns a candidate only when the total score is strictly positive.
fn score_binary_oblivious_split(
    table: &dyn TableAccess,
    row_indices: &[usize],
    targets: &[f64],
    feature_index: usize,
    leaves: &[ObliviousLeafState],
    criterion: Criterion,
    min_samples_leaf: usize,
) -> Option<ObliviousSplitCandidate> {
    let score = leaves.iter().fold(0.0, |score, leaf| {
        let leaf_rows = &row_indices[leaf.start..leaf.end];
        let (left_rows, right_rows): (Vec<usize>, Vec<usize>) =
            leaf_rows.iter().copied().partition(|row_idx| {
                !table
                    .binned_boolean_value(feature_index, *row_idx)
                    .expect("binary feature must expose boolean values")
            });

        if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
            return score;
        }

        score + regression_loss(leaf_rows, targets, criterion)
            - (regression_loss(&left_rows, targets, criterion)
                + regression_loss(&right_rows, targets, criterion))
    });

    // Binary features always use threshold bin 0.
    (score > 0.0).then_some(ObliviousSplitCandidate {
        feature_index,
        threshold_bin: 0,
        score,
    })
}
2208
/// Mean-criterion fast path for oblivious binary splits: uses each leaf's
/// cached sum/sum_sq plus a single pass to accumulate left-branch aggregates,
/// avoiding row-level loss recomputation.
///
/// `found_valid` guards against returning a zero-score candidate when no leaf
/// admitted a legal split at all.
fn score_binary_oblivious_split_mean_fast(
    table: &dyn TableAccess,
    row_indices: &[usize],
    targets: &[f64],
    feature_index: usize,
    leaves: &[ObliviousLeafState],
    min_samples_leaf: usize,
) -> Option<ObliviousSplitCandidate> {
    let mut score = 0.0;
    let mut found_valid = false;

    for leaf in leaves {
        let mut left_count = 0usize;
        let mut left_sum = 0.0;
        let mut left_sum_sq = 0.0;

        // `false` routes left; accumulate only left-branch stats — the right
        // branch is derived from the leaf's cached totals below.
        for row_idx in &row_indices[leaf.start..leaf.end] {
            if !table
                .binned_boolean_value(feature_index, *row_idx)
                .expect("binary feature must expose boolean values")
            {
                let target = targets[*row_idx];
                left_count += 1;
                left_sum += target;
                left_sum_sq += target * target;
            }
        }

        let right_count = leaf.len() - left_count;
        if left_count < min_samples_leaf || right_count < min_samples_leaf {
            continue;
        }

        found_valid = true;
        // SSE from aggregates: Σy² − (Σy)² / n.
        let parent_loss = leaf.sum_sq - (leaf.sum * leaf.sum) / leaf.len() as f64;
        let right_sum = leaf.sum - left_sum;
        let right_sum_sq = leaf.sum_sq - left_sum_sq;
        let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
        let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
        score += parent_loss - (left_loss + right_loss);
    }

    (found_valid && score > 0.0).then_some(ObliviousSplitCandidate {
        feature_index,
        threshold_bin: 0,
        score,
    })
}
2257
/// Sum and sum-of-squares of the targets selected by `rows`, in one pass.
fn sum_stats(rows: &[usize], targets: &[f64]) -> (f64, f64) {
    let mut sum = 0.0;
    let mut sum_sq = 0.0;
    for &row_idx in rows {
        let value = targets[row_idx];
        sum += value;
        sum_sq += value * value;
    }
    (sum, sum_sq)
}
2264
/// Population variance from aggregate statistics; `None` when `count` is zero.
fn variance_from_stats(count: usize, sum: f64, sum_sq: f64) -> Option<f64> {
    if count == 0 {
        return None;
    }
    let n = count as f64;
    // E[y²] − (E[y])².
    Some((sum_sq / n) - (sum / n).powi(2))
}
2272
2273fn leaf_regression_loss(
2274 leaf: &ObliviousLeafState,
2275 row_indices: &[usize],
2276 targets: &[f64],
2277 criterion: Criterion,
2278) -> f64 {
2279 match criterion {
2280 Criterion::Mean => leaf.sum_sq - (leaf.sum * leaf.sum) / leaf.len() as f64,
2281 Criterion::Median => {
2282 regression_loss(&row_indices[leaf.start..leaf.end], targets, criterion)
2283 }
2284 _ => unreachable!("regression criterion only supports mean or median"),
2285 }
2286}
2287
/// True when every selected row's target equals the first row's target
/// (vacuously true for an empty selection).
fn has_constant_target(rows: &[usize], targets: &[f64]) -> bool {
    match rows.split_first() {
        None => true,
        Some((&first_row, rest)) => {
            let reference = targets[first_row];
            rest.iter().all(|&row_idx| targets[row_idx] == reference)
        }
    }
}
2294
2295fn push_leaf(
2296 nodes: &mut Vec<RegressionNode>,
2297 value: f64,
2298 sample_count: usize,
2299 variance: Option<f64>,
2300) -> usize {
2301 push_node(
2302 nodes,
2303 RegressionNode::Leaf {
2304 value,
2305 sample_count,
2306 variance,
2307 },
2308 )
2309}
2310
2311fn push_node(nodes: &mut Vec<RegressionNode>, node: RegressionNode) -> usize {
2312 nodes.push(node);
2313 nodes.len() - 1
2314}
2315
2316#[cfg(test)]
2317mod tests {
2318 use super::*;
2319 use crate::{FeaturePreprocessing, Model, NumericBinBoundary};
2320 use forestfire_data::{DenseTable, NumericBins};
2321
2322 fn quadratic_table() -> DenseTable {
2323 DenseTable::with_options(
2324 vec![
2325 vec![0.0],
2326 vec![1.0],
2327 vec![2.0],
2328 vec![3.0],
2329 vec![4.0],
2330 vec![5.0],
2331 vec![6.0],
2332 vec![7.0],
2333 ],
2334 vec![0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0],
2335 0,
2336 NumericBins::Fixed(128),
2337 )
2338 .unwrap()
2339 }
2340
    /// Builds a table whose target reproduces a canary feature's binned
    /// values, so no real feature can beat the canary during training.
    fn canary_target_table() -> DenseTable {
        let x: Vec<Vec<f64>> = (0..8).map(|value| vec![value as f64]).collect();
        // Probe table with one canary; only used to read that canary column
        // back so the real table's targets can copy it.
        let probe =
            DenseTable::with_options(x.clone(), vec![0.0; 8], 1, NumericBins::Auto).unwrap();
        // Index just past the real features — presumably the first appended
        // canary column (TODO confirm against DenseTable's layout).
        let canary_index = probe.n_features();
        let y = (0..probe.n_rows())
            .map(|row_idx| probe.binned_value(canary_index, row_idx) as f64)
            .collect();

        DenseTable::with_options(x, y, 1, NumericBins::Auto).unwrap()
    }
2352
    /// Twelve rows over three low-cardinality features with all-distinct
    /// targets, giving seeded feature/threshold choices room to produce
    /// visibly different trees across random seeds.
    fn randomized_permutation_table() -> DenseTable {
        DenseTable::with_options(
            vec![
                vec![0.0, 0.0, 0.0],
                vec![0.0, 0.0, 1.0],
                vec![0.0, 1.0, 0.0],
                vec![0.0, 1.0, 1.0],
                vec![1.0, 0.0, 0.0],
                vec![1.0, 0.0, 1.0],
                vec![1.0, 1.0, 0.0],
                vec![1.0, 1.0, 1.0],
                vec![0.0, 0.0, 2.0],
                vec![0.0, 1.0, 2.0],
                vec![1.0, 0.0, 2.0],
                vec![1.0, 1.0, 2.0],
            ],
            vec![0.0, 1.0, 2.5, 3.5, 4.0, 5.0, 6.5, 7.5, 2.0, 4.5, 6.0, 8.5],
            0,
            NumericBins::Fixed(8),
        )
        .unwrap()
    }
2375
2376 #[test]
2377 fn cart_regressor_fits_basic_numeric_pattern() {
2378 let table = quadratic_table();
2379 let model = train_cart_regressor(&table).unwrap();
2380 let preds = model.predict_table(&table);
2381
2382 assert_eq!(model.algorithm(), RegressionTreeAlgorithm::Cart);
2383 assert_eq!(model.criterion(), Criterion::Mean);
2384 assert_eq!(preds, table_targets(&table));
2385 }
2386
2387 #[test]
2388 fn randomized_regressor_fits_basic_numeric_pattern() {
2389 let table = quadratic_table();
2390 let model = train_randomized_regressor(&table).unwrap();
2391 let preds = model.predict_table(&table);
2392 let targets = table_targets(&table);
2393 let baseline_mean = targets.iter().sum::<f64>() / targets.len() as f64;
2394 let baseline_sse = targets
2395 .iter()
2396 .map(|target| {
2397 let diff = target - baseline_mean;
2398 diff * diff
2399 })
2400 .sum::<f64>();
2401 let model_sse = preds
2402 .iter()
2403 .zip(targets.iter())
2404 .map(|(pred, target)| {
2405 let diff = pred - target;
2406 diff * diff
2407 })
2408 .sum::<f64>();
2409
2410 assert_eq!(model.algorithm(), RegressionTreeAlgorithm::Randomized);
2411 assert_eq!(model.criterion(), Criterion::Mean);
2412 assert!(model_sse < baseline_sse);
2413 }
2414
    #[test]
    fn randomized_regressor_is_repeatable_for_fixed_seed_and_varies_across_seeds() {
        let table = randomized_permutation_table();
        // Options with a feature subsample (`max_features: Some(2)`) so
        // different seeds can select different features/thresholds.
        let make_options = |random_seed| RegressionTreeOptions {
            max_depth: 4,
            max_features: Some(2),
            random_seed,
            ..RegressionTreeOptions::default()
        };

        // Two runs with the same seed must serialize identically.
        let base_model = train_randomized_regressor_with_criterion_parallelism_and_options(
            &table,
            Criterion::Mean,
            Parallelism::sequential(),
            make_options(91),
        )
        .unwrap();
        let repeated_model = train_randomized_regressor_with_criterion_parallelism_and_options(
            &table,
            Criterion::Mean,
            Parallelism::sequential(),
            make_options(91),
        )
        .unwrap();
        // Sweep 32 seeds and collect the distinct serialized forms; several
        // different trees should appear if the seed actually influences
        // training.
        let unique_serializations = (0..32u64)
            .map(|seed| {
                Model::DecisionTreeRegressor(
                    train_randomized_regressor_with_criterion_parallelism_and_options(
                        &table,
                        Criterion::Mean,
                        Parallelism::sequential(),
                        make_options(seed),
                    )
                    .unwrap(),
                )
                .serialize()
                .unwrap()
            })
            .collect::<std::collections::BTreeSet<_>>();

        assert_eq!(
            Model::DecisionTreeRegressor(base_model.clone())
                .serialize()
                .unwrap(),
            Model::DecisionTreeRegressor(repeated_model)
                .serialize()
                .unwrap()
        );
        assert!(unique_serializations.len() >= 4);
    }
2465
2466 #[test]
2467 fn oblivious_regressor_fits_basic_numeric_pattern() {
2468 let table = quadratic_table();
2469 let model = train_oblivious_regressor(&table).unwrap();
2470 let preds = model.predict_table(&table);
2471
2472 assert_eq!(model.algorithm(), RegressionTreeAlgorithm::Oblivious);
2473 assert_eq!(model.criterion(), Criterion::Mean);
2474 assert_eq!(preds, table_targets(&table));
2475 }
2476
2477 #[test]
2478 fn regressors_can_choose_between_mean_and_median() {
2479 let table = DenseTable::with_canaries(
2480 vec![vec![0.0], vec![0.0], vec![0.0]],
2481 vec![0.0, 0.0, 100.0],
2482 0,
2483 )
2484 .unwrap();
2485
2486 let mean_model = train_cart_regressor_with_criterion(&table, Criterion::Mean).unwrap();
2487 let median_model = train_cart_regressor_with_criterion(&table, Criterion::Median).unwrap();
2488
2489 assert_eq!(mean_model.criterion(), Criterion::Mean);
2490 assert_eq!(median_model.criterion(), Criterion::Median);
2491 assert_eq!(
2492 mean_model.predict_table(&table),
2493 vec![100.0 / 3.0, 100.0 / 3.0, 100.0 / 3.0]
2494 );
2495 assert_eq!(median_model.predict_table(&table), vec![0.0, 0.0, 0.0]);
2496 }
2497
2498 #[test]
2499 fn rejects_non_finite_targets() {
2500 let table = DenseTable::new(vec![vec![0.0], vec![1.0]], vec![0.0, f64::NAN]).unwrap();
2501
2502 let err = train_cart_regressor(&table).unwrap_err();
2503 assert!(matches!(
2504 err,
2505 RegressionTreeError::InvalidTargetValue { row: 1, value } if value.is_nan()
2506 ));
2507 }
2508
2509 #[test]
2510 fn stops_cart_regressor_growth_when_a_canary_wins() {
2511 let table = canary_target_table();
2512 let model = train_cart_regressor(&table).unwrap();
2513 let preds = model.predict_table(&table);
2514
2515 assert!(preds.iter().all(|pred| *pred == preds[0]));
2516 assert_ne!(preds, table_targets(&table));
2517 }
2518
2519 #[test]
2520 fn stops_oblivious_regressor_growth_when_a_canary_wins() {
2521 let table = canary_target_table();
2522 let model = train_oblivious_regressor(&table).unwrap();
2523 let preds = model.predict_table(&table);
2524
2525 assert!(preds.iter().all(|pred| *pred == preds[0]));
2526 assert_ne!(preds, table_targets(&table));
2527 }
2528
    #[test]
    fn manually_built_regressor_models_serialize_for_each_tree_type() {
        // Hand-assembled models (bypassing training) for all three
        // algorithms, sharing the same two-feature preprocessing metadata:
        // one binary feature and one binned numeric feature.
        let preprocessing = vec![
            FeaturePreprocessing::Binary,
            FeaturePreprocessing::Numeric {
                bin_boundaries: vec![
                    NumericBinBoundary {
                        bin: 0,
                        upper_bound: 1.0,
                    },
                    NumericBinBoundary {
                        bin: 127,
                        upper_bound: 10.0,
                    },
                ],
                missing_bin: 128,
            },
        ];
        let options = RegressionTreeOptions::default();

        // CART: a standard tree with two leaves under one binary split.
        let cart = Model::DecisionTreeRegressor(DecisionTreeRegressor {
            algorithm: RegressionTreeAlgorithm::Cart,
            criterion: Criterion::Mean,
            structure: RegressionTreeStructure::Standard {
                nodes: vec![
                    RegressionNode::Leaf {
                        value: -1.0,
                        sample_count: 2,
                        variance: Some(0.25),
                    },
                    RegressionNode::Leaf {
                        value: 2.5,
                        sample_count: 3,
                        variance: Some(1.0),
                    },
                    RegressionNode::BinarySplit {
                        feature_index: 0,
                        threshold_bin: 0,
                        missing_direction: crate::tree::shared::MissingBranchDirection::Node,
                        missing_value: -1.0,
                        left_child: 0,
                        right_child: 1,
                        sample_count: 5,
                        impurity: 3.5,
                        gain: 1.25,
                        variance: Some(0.7),
                    },
                ],
                root: 2,
            },
            options: options.clone(),
            num_features: 2,
            feature_preprocessing: preprocessing.clone(),
            training_canaries: 0,
        });
        // Randomized: same shape as CART but with the median criterion and a
        // different gain value.
        let randomized = Model::DecisionTreeRegressor(DecisionTreeRegressor {
            algorithm: RegressionTreeAlgorithm::Randomized,
            criterion: Criterion::Median,
            structure: RegressionTreeStructure::Standard {
                nodes: vec![
                    RegressionNode::Leaf {
                        value: -1.0,
                        sample_count: 2,
                        variance: Some(0.25),
                    },
                    RegressionNode::Leaf {
                        value: 2.5,
                        sample_count: 3,
                        variance: Some(1.0),
                    },
                    RegressionNode::BinarySplit {
                        feature_index: 0,
                        threshold_bin: 0,
                        missing_direction: crate::tree::shared::MissingBranchDirection::Node,
                        missing_value: -1.0,
                        left_child: 0,
                        right_child: 1,
                        sample_count: 5,
                        impurity: 3.5,
                        gain: 0.8,
                        variance: Some(0.7),
                    },
                ],
                root: 2,
            },
            options: options.clone(),
            num_features: 2,
            feature_preprocessing: preprocessing.clone(),
            training_canaries: 0,
        });
        // Oblivious: a single shared level with two leaves.
        let oblivious = Model::DecisionTreeRegressor(DecisionTreeRegressor {
            algorithm: RegressionTreeAlgorithm::Oblivious,
            criterion: Criterion::Median,
            structure: RegressionTreeStructure::Oblivious {
                splits: vec![ObliviousSplit {
                    feature_index: 1,
                    threshold_bin: 127,
                    sample_count: 4,
                    impurity: 2.0,
                    gain: 0.5,
                }],
                leaf_values: vec![0.0, 10.0],
                leaf_sample_counts: vec![2, 2],
                leaf_variances: vec![Some(0.0), Some(1.0)],
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing,
            training_canaries: 0,
        });

        // Every serialized model must carry its tree-type tag and the
        // regression task marker.
        for (tree_type, model) in [
            ("cart", cart),
            ("randomized", randomized),
            ("oblivious", oblivious),
        ] {
            let json = model.serialize().unwrap();
            assert!(json.contains(&format!("\"tree_type\":\"{tree_type}\"")));
            assert!(json.contains("\"task\":\"regression\""));
        }
    }
2650
2651 #[test]
2652 fn cart_regressor_assigns_training_missing_values_to_best_child() {
2653 let table = DenseTable::with_canaries(
2654 vec![
2655 vec![0.0],
2656 vec![0.0],
2657 vec![1.0],
2658 vec![1.0],
2659 vec![f64::NAN],
2660 vec![f64::NAN],
2661 ],
2662 vec![0.0, 0.0, 10.0, 10.0, 0.0, 0.0],
2663 0,
2664 )
2665 .unwrap();
2666
2667 let model = train_cart_regressor(&table).unwrap();
2668
2669 let wrapped = Model::DecisionTreeRegressor(model.clone());
2670 assert_eq!(
2671 wrapped.predict_rows(vec![vec![f64::NAN]]).unwrap(),
2672 vec![0.0]
2673 );
2674 }
2675
2676 #[test]
2677 fn cart_regressor_defaults_unseen_missing_to_node_mean() {
2678 let table = DenseTable::with_canaries(
2679 vec![vec![0.0], vec![0.0], vec![1.0]],
2680 vec![0.0, 0.0, 9.0],
2681 0,
2682 )
2683 .unwrap();
2684
2685 let model = train_cart_regressor(&table).unwrap();
2686 let wrapped = Model::DecisionTreeRegressor(model.clone());
2687 let prediction = wrapped.predict_rows(vec![vec![f64::NAN]]).unwrap()[0];
2688
2689 assert!((prediction - 3.0).abs() < 1e-9);
2690 }
2691
2692 fn table_targets(table: &dyn TableAccess) -> Vec<f64> {
2693 (0..table.n_rows())
2694 .map(|row_idx| table.target_value(row_idx))
2695 .collect()
2696 }
2697}