1use crate::ir::{
9 BinaryChildren, BinarySplit, IndexedLeaf, LeafIndexing, LeafPayload, NodeStats, NodeTreeNode,
10 ObliviousLevel, ObliviousSplit as IrObliviousSplit, TrainingMetadata, TreeDefinition,
11 criterion_name, feature_name, threshold_upper_bound,
12};
13use crate::tree::shared::{
14 FeatureHistogram, HistogramBin, build_feature_histograms, candidate_feature_indices,
15 choose_random_threshold, node_seed, partition_rows_for_binary_split,
16 subtract_feature_histograms,
17};
18use crate::{Criterion, FeaturePreprocessing, Parallelism, capture_feature_preprocessing};
19use forestfire_data::TableAccess;
20use rayon::prelude::*;
21use std::collections::BTreeSet;
22use std::error::Error;
23use std::fmt::{Display, Formatter};
24
/// Strategy used to grow a [`DecisionTreeRegressor`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RegressionTreeAlgorithm {
    /// Greedy best-split binary tree (classic CART).
    Cart,
    /// Binary tree whose split thresholds are drawn by a seeded RNG
    /// (see `choose_random_threshold`) instead of exhaustive search.
    Randomized,
    /// Level-wise tree where every node at a given depth shares one split.
    Oblivious,
}
31
/// Hyperparameters controlling tree growth; see the [`Default`] impl for defaults.
#[derive(Debug, Clone, Copy)]
pub struct RegressionTreeOptions {
    /// Maximum tree depth; nodes at this depth become leaves.
    pub max_depth: usize,
    /// Minimum number of rows a node needs before a split is attempted.
    pub min_samples_split: usize,
    /// Minimum number of rows each child of an accepted split must keep.
    pub min_samples_leaf: usize,
    /// Number of candidate features examined per split; `None` means all.
    pub max_features: Option<usize>,
    /// Seed for the deterministic per-node feature/threshold sampling.
    pub random_seed: u64,
}
41
42impl Default for RegressionTreeOptions {
43 fn default() -> Self {
44 Self {
45 max_depth: 8,
46 min_samples_split: 2,
47 min_samples_leaf: 1,
48 max_features: None,
49 random_seed: 0,
50 }
51 }
52}
53
/// Errors raised while validating regression training targets.
#[derive(Debug)]
pub enum RegressionTreeError {
    /// The training table has zero rows.
    EmptyTarget,
    /// A target value was NaN or infinite; regression requires finite targets.
    InvalidTargetValue { row: usize, value: f64 },
}
59
60impl Display for RegressionTreeError {
61 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
62 match self {
63 RegressionTreeError::EmptyTarget => {
64 write!(f, "Cannot train on an empty target vector.")
65 }
66 RegressionTreeError::InvalidTargetValue { row, value } => write!(
67 f,
68 "Regression targets must be finite values. Found {} at row {}.",
69 value, row
70 ),
71 }
72 }
73}
74
// Marker impl: both variants are leaf errors with no underlying `source`.
impl Error for RegressionTreeError {}
76
/// A fitted regression tree, produced by the `train_*_regressor` functions.
#[derive(Debug, Clone)]
pub struct DecisionTreeRegressor {
    // How the tree was grown (CART / randomized / oblivious).
    algorithm: RegressionTreeAlgorithm,
    // Split-scoring criterion used at training time.
    criterion: Criterion,
    // Pointer-based node list for CART/randomized; level list for oblivious.
    structure: RegressionTreeStructure,
    // Hyperparameters used at fit time (echoed into training metadata).
    options: RegressionTreeOptions,
    // Feature count of the training table.
    num_features: usize,
    // Per-feature preprocessing captured at fit time, used for IR export.
    feature_preprocessing: Vec<FeaturePreprocessing>,
    // Count reported by `train_set.canaries()`, recorded into metadata.
    training_canaries: usize,
}
88
/// Storage layout for a trained tree.
#[derive(Debug, Clone)]
pub(crate) enum RegressionTreeStructure {
    /// Arbitrary-shape binary tree: `nodes` is an arena indexed by child
    /// pointers and `root` is the index of the entry node.
    Standard {
        nodes: Vec<RegressionNode>,
        root: usize,
    },
    /// Oblivious tree: one split per level. `splits[0]` contributes the most
    /// significant bit of the leaf index (see `predict_row`), and the three
    /// leaf vectors are all indexed by that bit pattern.
    Oblivious {
        splits: Vec<ObliviousSplit>,
        leaf_values: Vec<f64>,
        leaf_sample_counts: Vec<usize>,
        leaf_variances: Vec<Option<f64>>,
    },
}
102
/// One node of a [`RegressionTreeStructure::Standard`] tree.
#[derive(Debug, Clone)]
pub(crate) enum RegressionNode {
    /// Terminal node: predicts `value` for every row that reaches it.
    Leaf {
        value: f64,
        sample_count: usize,
        // Target variance over the node's training rows; `None` when undefined.
        variance: Option<f64>,
    },
    /// Internal node: rows go left when
    /// `binned_value(feature_index) <= threshold_bin`, right otherwise.
    BinarySplit {
        feature_index: usize,
        threshold_bin: u16,
        left_child: usize,
        right_child: usize,
        sample_count: usize,
        // Loss of this node before splitting.
        impurity: f64,
        // Loss reduction achieved by the split (parent loss minus child losses).
        gain: f64,
        variance: Option<f64>,
    },
}
121
/// The split shared by every node at one level of an oblivious tree.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ObliviousSplit {
    pub(crate) feature_index: usize,
    // Rows go right when their binned value exceeds this bin.
    pub(crate) threshold_bin: u16,
    // Rows routed through this level (the whole training set).
    pub(crate) sample_count: usize,
    // Summed per-leaf loss after applying this level's split.
    pub(crate) impurity: f64,
    // Total loss reduction across all leaves at this level.
    pub(crate) gain: f64,
}
130
/// A scored (feature, threshold) candidate for a standard binary split.
#[derive(Debug, Clone)]
struct RegressionSplitCandidate {
    feature_index: usize,
    threshold_bin: u16,
    // Loss reduction relative to the parent; larger is better.
    score: f64,
}
137
/// Bookkeeping for one leaf while growing an oblivious tree level by level.
/// `start..end` is this leaf's slice of the shared row-index buffer, which is
/// repartitioned in place as levels are added.
#[derive(Debug, Clone)]
struct ObliviousLeafState {
    start: usize,
    end: usize,
    // Prediction for rows landing in this leaf.
    value: f64,
    // Target variance of the leaf's rows, if defined.
    variance: Option<f64>,
    // Running sum and sum of squares of the leaf's targets.
    sum: f64,
    sum_sq: f64,
}
147
impl ObliviousLeafState {
    /// Number of training rows currently assigned to this leaf.
    fn len(&self) -> usize {
        self.end - self.start
    }
}
153
/// A scored (feature, threshold) candidate for one oblivious tree level.
#[derive(Debug, Clone, Copy)]
struct ObliviousSplitCandidate {
    feature_index: usize,
    threshold_bin: u16,
    // Total loss reduction summed over all current leaves; larger is better.
    score: f64,
}
160
/// The winning (feature, threshold) choice returned by the per-feature
/// binary-split scorers during standard tree construction.
#[derive(Debug, Clone, Copy)]
struct BinarySplitChoice {
    feature_index: usize,
    threshold_bin: u16,
    // Loss reduction; only splits with a positive score are accepted.
    score: f64,
}
167
168#[derive(Debug, Clone)]
169struct RegressionHistogramBin {
170 count: usize,
171 sum: f64,
172 sum_sq: f64,
173}
174
175impl HistogramBin for RegressionHistogramBin {
176 fn subtract(parent: &Self, child: &Self) -> Self {
177 Self {
178 count: parent.count - child.count,
179 sum: parent.sum - child.sum,
180 sum_sq: parent.sum_sq - child.sum_sq,
181 }
182 }
183
184 fn is_observed(&self) -> bool {
185 self.count > 0
186 }
187}
188
// Histogram over one binned feature, carrying regression bin payloads.
type RegressionFeatureHistogram = FeatureHistogram<RegressionHistogramBin>;
190
191pub fn train_cart_regressor(
192 train_set: &dyn TableAccess,
193) -> Result<DecisionTreeRegressor, RegressionTreeError> {
194 train_cart_regressor_with_criterion(train_set, Criterion::Mean)
195}
196
197pub fn train_cart_regressor_with_criterion(
198 train_set: &dyn TableAccess,
199 criterion: Criterion,
200) -> Result<DecisionTreeRegressor, RegressionTreeError> {
201 train_cart_regressor_with_criterion_and_parallelism(
202 train_set,
203 criterion,
204 Parallelism::sequential(),
205 )
206}
207
208pub(crate) fn train_cart_regressor_with_criterion_and_parallelism(
209 train_set: &dyn TableAccess,
210 criterion: Criterion,
211 parallelism: Parallelism,
212) -> Result<DecisionTreeRegressor, RegressionTreeError> {
213 train_cart_regressor_with_criterion_parallelism_and_options(
214 train_set,
215 criterion,
216 parallelism,
217 RegressionTreeOptions::default(),
218 )
219}
220
221pub(crate) fn train_cart_regressor_with_criterion_parallelism_and_options(
222 train_set: &dyn TableAccess,
223 criterion: Criterion,
224 parallelism: Parallelism,
225 options: RegressionTreeOptions,
226) -> Result<DecisionTreeRegressor, RegressionTreeError> {
227 train_regressor(
228 train_set,
229 RegressionTreeAlgorithm::Cart,
230 criterion,
231 parallelism,
232 options,
233 )
234}
235
236pub fn train_oblivious_regressor(
237 train_set: &dyn TableAccess,
238) -> Result<DecisionTreeRegressor, RegressionTreeError> {
239 train_oblivious_regressor_with_criterion(train_set, Criterion::Mean)
240}
241
242pub fn train_oblivious_regressor_with_criterion(
243 train_set: &dyn TableAccess,
244 criterion: Criterion,
245) -> Result<DecisionTreeRegressor, RegressionTreeError> {
246 train_oblivious_regressor_with_criterion_and_parallelism(
247 train_set,
248 criterion,
249 Parallelism::sequential(),
250 )
251}
252
253pub(crate) fn train_oblivious_regressor_with_criterion_and_parallelism(
254 train_set: &dyn TableAccess,
255 criterion: Criterion,
256 parallelism: Parallelism,
257) -> Result<DecisionTreeRegressor, RegressionTreeError> {
258 train_oblivious_regressor_with_criterion_parallelism_and_options(
259 train_set,
260 criterion,
261 parallelism,
262 RegressionTreeOptions::default(),
263 )
264}
265
266pub(crate) fn train_oblivious_regressor_with_criterion_parallelism_and_options(
267 train_set: &dyn TableAccess,
268 criterion: Criterion,
269 parallelism: Parallelism,
270 options: RegressionTreeOptions,
271) -> Result<DecisionTreeRegressor, RegressionTreeError> {
272 train_regressor(
273 train_set,
274 RegressionTreeAlgorithm::Oblivious,
275 criterion,
276 parallelism,
277 options,
278 )
279}
280
281pub fn train_randomized_regressor(
282 train_set: &dyn TableAccess,
283) -> Result<DecisionTreeRegressor, RegressionTreeError> {
284 train_randomized_regressor_with_criterion(train_set, Criterion::Mean)
285}
286
287pub fn train_randomized_regressor_with_criterion(
288 train_set: &dyn TableAccess,
289 criterion: Criterion,
290) -> Result<DecisionTreeRegressor, RegressionTreeError> {
291 train_randomized_regressor_with_criterion_and_parallelism(
292 train_set,
293 criterion,
294 Parallelism::sequential(),
295 )
296}
297
298pub(crate) fn train_randomized_regressor_with_criterion_and_parallelism(
299 train_set: &dyn TableAccess,
300 criterion: Criterion,
301 parallelism: Parallelism,
302) -> Result<DecisionTreeRegressor, RegressionTreeError> {
303 train_randomized_regressor_with_criterion_parallelism_and_options(
304 train_set,
305 criterion,
306 parallelism,
307 RegressionTreeOptions::default(),
308 )
309}
310
311pub(crate) fn train_randomized_regressor_with_criterion_parallelism_and_options(
312 train_set: &dyn TableAccess,
313 criterion: Criterion,
314 parallelism: Parallelism,
315 options: RegressionTreeOptions,
316) -> Result<DecisionTreeRegressor, RegressionTreeError> {
317 train_regressor(
318 train_set,
319 RegressionTreeAlgorithm::Randomized,
320 criterion,
321 parallelism,
322 options,
323 )
324}
325
326fn train_regressor(
327 train_set: &dyn TableAccess,
328 algorithm: RegressionTreeAlgorithm,
329 criterion: Criterion,
330 parallelism: Parallelism,
331 options: RegressionTreeOptions,
332) -> Result<DecisionTreeRegressor, RegressionTreeError> {
333 if train_set.n_rows() == 0 {
334 return Err(RegressionTreeError::EmptyTarget);
335 }
336
337 let targets = finite_targets(train_set)?;
338 let structure = match algorithm {
339 RegressionTreeAlgorithm::Cart => {
340 let mut nodes = Vec::new();
341 let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
342 let context = BuildContext {
343 table: train_set,
344 targets: &targets,
345 criterion,
346 parallelism,
347 options,
348 algorithm,
349 };
350 let root = build_binary_node_in_place(&context, &mut nodes, &mut all_rows, 0);
354 RegressionTreeStructure::Standard { nodes, root }
355 }
356 RegressionTreeAlgorithm::Randomized => {
357 let mut nodes = Vec::new();
358 let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
359 let context = BuildContext {
360 table: train_set,
361 targets: &targets,
362 criterion,
363 parallelism,
364 options,
365 algorithm,
366 };
367 let root = build_binary_node_in_place(&context, &mut nodes, &mut all_rows, 0);
368 RegressionTreeStructure::Standard { nodes, root }
369 }
370 RegressionTreeAlgorithm::Oblivious => {
371 train_oblivious_structure(train_set, &targets, criterion, parallelism, options)
374 }
375 };
376
377 Ok(DecisionTreeRegressor {
378 algorithm,
379 criterion,
380 structure,
381 options,
382 num_features: train_set.n_features(),
383 feature_preprocessing: capture_feature_preprocessing(train_set),
384 training_canaries: train_set.canaries(),
385 })
386}
387
impl DecisionTreeRegressor {
    /// The algorithm used to grow this tree.
    pub fn algorithm(&self) -> RegressionTreeAlgorithm {
        self.algorithm
    }

    /// The split-scoring criterion used at training time.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }

    /// Predicts one value per row of `table`, in row order.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        (0..table.n_rows())
            .map(|row_idx| self.predict_row(table, row_idx))
            .collect()
    }

    /// Routes a single row to a leaf and returns that leaf's value.
    fn predict_row(&self, table: &dyn TableAccess, row_idx: usize) -> f64 {
        match &self.structure {
            RegressionTreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;

                // Descend from the root: left when the binned value is <= the
                // split's threshold bin, right otherwise.
                loop {
                    match &nodes[node_index] {
                        RegressionNode::Leaf { value, .. } => return *value,
                        RegressionNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            left_child,
                            right_child,
                            ..
                        } => {
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            RegressionTreeStructure::Oblivious {
                splits,
                leaf_values,
                ..
            } => {
                // Each level contributes one bit (1 = went right). The first
                // split ends up as the most significant bit of the leaf index,
                // matching the "msb_first" indexing exported by `to_ir_tree`.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                leaf_values[leaf_index]
            }
        }
    }

    /// Feature count of the table this tree was trained on.
    pub(crate) fn num_features(&self) -> usize {
        self.num_features
    }

    /// Read-only access to the stored tree structure.
    pub(crate) fn structure(&self) -> &RegressionTreeStructure {
        &self.structure
    }

    /// Per-feature preprocessing captured at training time.
    pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// Summarizes the training configuration for export alongside the tree IR.
    pub(crate) fn training_metadata(&self) -> TrainingMetadata {
        TrainingMetadata {
            algorithm: "dt".to_string(),
            task: "regression".to_string(),
            tree_type: match self.algorithm {
                RegressionTreeAlgorithm::Cart => "cart".to_string(),
                RegressionTreeAlgorithm::Randomized => "randomized".to_string(),
                RegressionTreeAlgorithm::Oblivious => "oblivious".to_string(),
            },
            criterion: criterion_name(self.criterion).to_string(),
            canaries: self.training_canaries,
            compute_oob: false,
            max_depth: Some(self.options.max_depth),
            min_samples_split: Some(self.options.min_samples_split),
            min_samples_leaf: Some(self.options.min_samples_leaf),
            // Fields below are ensemble/boosting-specific and stay unset for a
            // single tree (classification fields likewise).
            n_trees: None,
            max_features: self.options.max_features,
            seed: None,
            oob_score: None,
            class_labels: None,
            learning_rate: None,
            bootstrap: None,
            top_gradient_fraction: None,
            other_gradient_fraction: None,
        }
    }

    /// Converts the tree into the serializable IR representation.
    pub(crate) fn to_ir_tree(&self) -> TreeDefinition {
        match &self.structure {
            RegressionTreeStructure::Standard { nodes, root } => {
                // Node depths are not stored on the nodes; recompute them so
                // the IR can carry them per node.
                let depths = standard_node_depths(nodes, *root);
                TreeDefinition::NodeTree {
                    tree_id: 0,
                    weight: 1.0,
                    root_node_id: *root,
                    nodes: nodes
                        .iter()
                        .enumerate()
                        .map(|(node_id, node)| match node {
                            RegressionNode::Leaf {
                                value,
                                sample_count,
                                variance,
                            } => NodeTreeNode::Leaf {
                                node_id,
                                depth: depths[node_id],
                                leaf: LeafPayload::RegressionValue { value: *value },
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: None,
                                    gain: None,
                                    class_counts: None,
                                    variance: *variance,
                                },
                            },
                            RegressionNode::BinarySplit {
                                feature_index,
                                threshold_bin,
                                left_child,
                                right_child,
                                sample_count,
                                impurity,
                                gain,
                                variance,
                            } => NodeTreeNode::BinaryBranch {
                                node_id,
                                depth: depths[node_id],
                                // Split description depends on the feature's
                                // preprocessing (boolean vs numeric threshold).
                                split: binary_split_ir(
                                    *feature_index,
                                    *threshold_bin,
                                    &self.feature_preprocessing,
                                ),
                                children: BinaryChildren {
                                    left: *left_child,
                                    right: *right_child,
                                },
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: None,
                                    variance: *variance,
                                },
                            },
                        })
                        .collect(),
                }
            }
            RegressionTreeStructure::Oblivious {
                splits,
                leaf_values,
                leaf_sample_counts,
                leaf_variances,
            } => TreeDefinition::ObliviousLevels {
                tree_id: 0,
                weight: 1.0,
                depth: splits.len(),
                levels: splits
                    .iter()
                    .enumerate()
                    .map(|(level, split)| ObliviousLevel {
                        level,
                        split: oblivious_split_ir(
                            split.feature_index,
                            split.threshold_bin,
                            &self.feature_preprocessing,
                        ),
                        stats: NodeStats {
                            sample_count: split.sample_count,
                            impurity: Some(split.impurity),
                            gain: Some(split.gain),
                            class_counts: None,
                            variance: None,
                        },
                    })
                    .collect(),
                // Mirrors the fold in `predict_row`: level 0 is the most
                // significant bit of the leaf index.
                leaf_indexing: LeafIndexing {
                    bit_order: "msb_first".to_string(),
                    index_formula: "sum(bit[level] << (depth - 1 - level))".to_string(),
                },
                leaves: leaf_values
                    .iter()
                    .enumerate()
                    .map(|(leaf_index, value)| IndexedLeaf {
                        leaf_index,
                        leaf: LeafPayload::RegressionValue { value: *value },
                        stats: NodeStats {
                            sample_count: leaf_sample_counts[leaf_index],
                            impurity: None,
                            gain: None,
                            class_counts: None,
                            variance: leaf_variances[leaf_index],
                        },
                    })
                    .collect(),
            },
        }
    }

    /// Reassembles a regressor from previously exported IR components.
    pub(crate) fn from_ir_parts(
        algorithm: RegressionTreeAlgorithm,
        criterion: Criterion,
        structure: RegressionTreeStructure,
        options: RegressionTreeOptions,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
        training_canaries: usize,
    ) -> Self {
        Self {
            algorithm,
            criterion,
            structure,
            options,
            num_features,
            feature_preprocessing,
            training_canaries,
        }
    }
}
618
619fn standard_node_depths(nodes: &[RegressionNode], root: usize) -> Vec<usize> {
620 let mut depths = vec![0; nodes.len()];
621 populate_depths(nodes, root, 0, &mut depths);
622 depths
623}
624
625fn populate_depths(nodes: &[RegressionNode], node_id: usize, depth: usize, depths: &mut [usize]) {
626 depths[node_id] = depth;
627 match &nodes[node_id] {
628 RegressionNode::Leaf { .. } => {}
629 RegressionNode::BinarySplit {
630 left_child,
631 right_child,
632 ..
633 } => {
634 populate_depths(nodes, *left_child, depth + 1, depths);
635 populate_depths(nodes, *right_child, depth + 1, depths);
636 }
637 }
638}
639
640fn binary_split_ir(
641 feature_index: usize,
642 threshold_bin: u16,
643 preprocessing: &[FeaturePreprocessing],
644) -> BinarySplit {
645 match preprocessing.get(feature_index) {
646 Some(FeaturePreprocessing::Binary) => BinarySplit::BooleanTest {
647 feature_index,
648 feature_name: feature_name(feature_index),
649 false_child_semantics: "left".to_string(),
650 true_child_semantics: "right".to_string(),
651 },
652 Some(FeaturePreprocessing::Numeric { .. }) | None => BinarySplit::NumericBinThreshold {
653 feature_index,
654 feature_name: feature_name(feature_index),
655 operator: "<=".to_string(),
656 threshold_bin,
657 threshold_upper_bound: threshold_upper_bound(
658 preprocessing,
659 feature_index,
660 threshold_bin,
661 ),
662 comparison_dtype: "uint16".to_string(),
663 },
664 }
665}
666
/// Builds the IR split description for one oblivious tree level.
///
/// Bit conventions mirror `predict_row`, where a row contributes bit 1 when
/// its binned value exceeds the threshold. The numeric variant's test is
/// `<=`, so its TRUE outcome is bit 0 — the reverse of the boolean variant,
/// where a true feature presumably exceeds its threshold bin and maps to
/// bit 1. The asymmetry is intentional, not a typo.
fn oblivious_split_ir(
    feature_index: usize,
    threshold_bin: u16,
    preprocessing: &[FeaturePreprocessing],
) -> IrObliviousSplit {
    match preprocessing.get(feature_index) {
        Some(FeaturePreprocessing::Binary) => IrObliviousSplit::BooleanTest {
            feature_index,
            feature_name: feature_name(feature_index),
            bit_when_false: 0,
            bit_when_true: 1,
        },
        // Unknown preprocessing falls back to the numeric description.
        Some(FeaturePreprocessing::Numeric { .. }) | None => {
            IrObliviousSplit::NumericBinThreshold {
                feature_index,
                feature_name: feature_name(feature_index),
                operator: "<=".to_string(),
                threshold_bin,
                threshold_upper_bound: threshold_upper_bound(
                    preprocessing,
                    feature_index,
                    threshold_bin,
                ),
                comparison_dtype: "uint16".to_string(),
                bit_when_true: 0,
                bit_when_false: 1,
            }
        }
    }
}
697
/// Borrowed inputs threaded through the recursive binary-tree builder.
struct BuildContext<'a> {
    table: &'a dyn TableAccess,
    targets: &'a [f64],
    criterion: Criterion,
    parallelism: Parallelism,
    options: RegressionTreeOptions,
    // Cart vs Randomized: selects the threshold-search strategy when scoring.
    algorithm: RegressionTreeAlgorithm,
}
706
707fn build_regression_node_histograms(
708 table: &dyn TableAccess,
709 targets: &[f64],
710 rows: &[usize],
711) -> Vec<RegressionFeatureHistogram> {
712 build_feature_histograms(
713 table,
714 rows,
715 |_| RegressionHistogramBin {
716 count: 0,
717 sum: 0.0,
718 sum_sq: 0.0,
719 },
720 |_feature_index, payload, row_idx| {
721 let value = targets[row_idx];
722 payload.count += 1;
723 payload.sum += value;
724 payload.sum_sq += value * value;
725 },
726 )
727}
728
/// Derives a sibling's histograms as `parent - child`, bin by bin, avoiding a
/// second pass over the rows (the histogram-subtraction trick).
fn subtract_regression_node_histograms(
    parent: &[RegressionFeatureHistogram],
    child: &[RegressionFeatureHistogram],
) -> Vec<RegressionFeatureHistogram> {
    subtract_feature_histograms(parent, child)
}
735
736fn finite_targets(train_set: &dyn TableAccess) -> Result<Vec<f64>, RegressionTreeError> {
737 (0..train_set.n_rows())
738 .map(|row_idx| {
739 let value = train_set.target_value(row_idx);
740 if value.is_finite() {
741 Ok(value)
742 } else {
743 Err(RegressionTreeError::InvalidTargetValue {
744 row: row_idx,
745 value,
746 })
747 }
748 })
749 .collect()
750}
751
752fn build_binary_node_in_place(
753 context: &BuildContext<'_>,
754 nodes: &mut Vec<RegressionNode>,
755 rows: &mut [usize],
756 depth: usize,
757) -> usize {
758 build_binary_node_in_place_with_hist(context, nodes, rows, depth, None)
759}
760
/// Recursively grows a subtree over `rows` (repartitioned in place) and
/// returns the index of the subtree's root within `nodes`.
///
/// `histograms` optionally carries this node's per-feature target histograms;
/// they are only used under the `Mean` criterion, and when absent are built
/// here. Child histograms are derived via parent-minus-child subtraction so
/// only the smaller child is rescanned.
fn build_binary_node_in_place_with_hist(
    context: &BuildContext<'_>,
    nodes: &mut Vec<RegressionNode>,
    rows: &mut [usize],
    depth: usize,
    histograms: Option<Vec<RegressionFeatureHistogram>>,
) -> usize {
    // Computed up front: every early-exit path below emits a leaf with these.
    let leaf_value = regression_value(rows, context.targets, context.criterion);
    let leaf_variance = variance(rows, context.targets);

    // Stopping conditions: no rows, depth cap, too few rows to split, or a
    // constant target (no split can reduce the loss).
    if rows.is_empty()
        || depth >= context.options.max_depth
        || rows.len() < context.options.min_samples_split
        || has_constant_target(rows, context.targets)
    {
        return push_leaf(nodes, leaf_value, rows.len(), leaf_variance);
    }

    // Histograms are only maintained for the Mean criterion; reuse the
    // parent-provided ones when available.
    let histograms = if matches!(context.criterion, Criterion::Mean) {
        Some(histograms.unwrap_or_else(|| {
            build_regression_node_histograms(context.table, context.targets, rows)
        }))
    } else {
        None
    };
    // Deterministic per-node feature subsampling (seed mixes depth and rows).
    let feature_indices = candidate_feature_indices(
        context.table.binned_feature_count(),
        context.options.max_features,
        node_seed(context.options.random_seed, depth, rows, 0xA11C_E5E1u64),
    );
    // Score each candidate feature (optionally in parallel) and keep the
    // highest-scoring split; `total_cmp` gives a total order over f64 scores.
    let best_split = if context.parallelism.enabled() {
        feature_indices
            .into_par_iter()
            .filter_map(|feature_index| {
                if let Some(histograms) = histograms.as_ref() {
                    score_binary_split_choice_from_hist(
                        context,
                        &histograms[feature_index],
                        feature_index,
                        rows,
                    )
                } else {
                    score_binary_split_choice(context, feature_index, rows)
                }
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    } else {
        feature_indices
            .into_iter()
            .filter_map(|feature_index| {
                if let Some(histograms) = histograms.as_ref() {
                    score_binary_split_choice_from_hist(
                        context,
                        &histograms[feature_index],
                        feature_index,
                        rows,
                    )
                } else {
                    score_binary_split_choice(context, feature_index, rows)
                }
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    };

    match best_split {
        // Guard: if the winner is a canary feature, refuse to split on it and
        // emit a leaf instead.
        Some(best_split)
            if context
                .table
                .is_canary_binned_feature(best_split.feature_index) =>
        {
            push_leaf(nodes, leaf_value, rows.len(), leaf_variance)
        }
        // Only accept splits that strictly reduce the loss.
        Some(best_split) if best_split.score > 0.0 => {
            let impurity = regression_loss(rows, context.targets, context.criterion);
            // Partition `rows` in place: [left block | right block].
            let left_count = partition_rows_for_binary_split(
                context.table,
                best_split.feature_index,
                best_split.threshold_bin,
                rows,
            );
            let (left_rows, right_rows) = rows.split_at_mut(left_count);
            let (left_child, right_child) = if let Some(histograms) = histograms {
                // Rescan only the smaller child; derive the larger child's
                // histograms by subtracting from the parent's.
                if left_rows.len() <= right_rows.len() {
                    let left_histograms =
                        build_regression_node_histograms(context.table, context.targets, left_rows);
                    let right_histograms =
                        subtract_regression_node_histograms(&histograms, &left_histograms);
                    (
                        build_binary_node_in_place_with_hist(
                            context,
                            nodes,
                            left_rows,
                            depth + 1,
                            Some(left_histograms),
                        ),
                        build_binary_node_in_place_with_hist(
                            context,
                            nodes,
                            right_rows,
                            depth + 1,
                            Some(right_histograms),
                        ),
                    )
                } else {
                    let right_histograms = build_regression_node_histograms(
                        context.table,
                        context.targets,
                        right_rows,
                    );
                    let left_histograms =
                        subtract_regression_node_histograms(&histograms, &right_histograms);
                    (
                        build_binary_node_in_place_with_hist(
                            context,
                            nodes,
                            left_rows,
                            depth + 1,
                            Some(left_histograms),
                        ),
                        build_binary_node_in_place_with_hist(
                            context,
                            nodes,
                            right_rows,
                            depth + 1,
                            Some(right_histograms),
                        ),
                    )
                }
            } else {
                // Non-Mean criterion: children recompute whatever they need.
                (
                    build_binary_node_in_place(context, nodes, left_rows, depth + 1),
                    build_binary_node_in_place(context, nodes, right_rows, depth + 1),
                )
            };

            push_node(
                nodes,
                RegressionNode::BinarySplit {
                    feature_index: best_split.feature_index,
                    threshold_bin: best_split.threshold_bin,
                    left_child,
                    right_child,
                    sample_count: rows.len(),
                    impurity,
                    gain: best_split.score,
                    variance: leaf_variance,
                },
            )
        }
        // No candidate or no positive gain: this node becomes a leaf.
        _ => push_leaf(nodes, leaf_value, rows.len(), leaf_variance),
    }
}
913
/// Grows an oblivious tree level by level: every level picks one shared
/// (feature, threshold) split applied simultaneously to all current leaves,
/// doubling the leaf count.
fn train_oblivious_structure(
    table: &dyn TableAccess,
    targets: &[f64],
    criterion: Criterion,
    parallelism: Parallelism,
    options: RegressionTreeOptions,
) -> RegressionTreeStructure {
    // Single shared index buffer; each leaf owns a contiguous start..end slice.
    let mut row_indices: Vec<usize> = (0..table.n_rows()).collect();
    let (root_sum, root_sum_sq) = sum_stats(&row_indices, targets);
    let mut leaves = vec![ObliviousLeafState {
        start: 0,
        end: row_indices.len(),
        value: regression_value_from_stats(&row_indices, targets, criterion, root_sum),
        variance: variance_from_stats(row_indices.len(), root_sum, root_sum_sq),
        sum: root_sum,
        sum_sq: root_sum_sq,
    }];
    let mut splits = Vec::new();

    for depth in 0..options.max_depth {
        // Stop once no leaf is large enough to be split.
        if leaves
            .iter()
            .all(|leaf| leaf.len() < options.min_samples_split)
        {
            break;
        }
        // Deterministic per-level feature subsampling (no row context here,
        // hence the empty slice in the seed).
        let feature_indices = candidate_feature_indices(
            table.binned_feature_count(),
            options.max_features,
            node_seed(options.random_seed, depth, &[], 0x0B11_A10Cu64),
        );
        // Score candidate features (optionally in parallel); the score of a
        // candidate sums the loss reduction across all current leaves.
        let best_split = if parallelism.enabled() {
            feature_indices
                .into_par_iter()
                .filter_map(|feature_index| {
                    score_oblivious_split(
                        table,
                        &row_indices,
                        targets,
                        feature_index,
                        &leaves,
                        criterion,
                        options.min_samples_leaf,
                    )
                })
                .max_by(|left, right| left.score.total_cmp(&right.score))
        } else {
            feature_indices
                .into_iter()
                .filter_map(|feature_index| {
                    score_oblivious_split(
                        table,
                        &row_indices,
                        targets,
                        feature_index,
                        &leaves,
                        criterion,
                        options.min_samples_leaf,
                    )
                })
                .max_by(|left, right| left.score.total_cmp(&right.score))
        };

        // Require a strictly positive total gain to keep growing.
        let Some(best_split) = best_split.filter(|candidate| candidate.score > 0.0) else {
            break;
        };
        // NOTE(review): a canary winner stops growth entirely rather than
        // falling back to the next-best split — confirm this is intended.
        if table.is_canary_binned_feature(best_split.feature_index) {
            break;
        }

        // Apply the shared split: each leaf is partitioned in place into a
        // left and a right child, doubling the leaf list.
        leaves = split_oblivious_leaves_in_place(
            table,
            &mut row_indices,
            targets,
            leaves,
            best_split.feature_index,
            best_split.threshold_bin,
            criterion,
        );
        splits.push(ObliviousSplit {
            feature_index: best_split.feature_index,
            threshold_bin: best_split.threshold_bin,
            // Every row passes through every level of an oblivious tree.
            sample_count: table.n_rows(),
            // Post-split loss, summed over the new leaves.
            impurity: leaves
                .iter()
                .map(|leaf| leaf_regression_loss(leaf, &row_indices, targets, criterion))
                .sum(),
            gain: best_split.score,
        });
    }

    RegressionTreeStructure::Oblivious {
        splits,
        leaf_values: leaves.iter().map(|leaf| leaf.value).collect(),
        leaf_sample_counts: leaves.iter().map(ObliviousLeafState::len).collect(),
        leaf_variances: leaves.iter().map(|leaf| leaf.variance).collect(),
    }
}
1012
/// Finds the best (feature, threshold) split of `rows` on one feature.
///
/// Dispatch order: binary features use the dedicated binary scorer; under the
/// `Mean` criterion a fast histogram-based path is tried first; randomized
/// trees pick a single random threshold; otherwise every distinct binned
/// value is tried exhaustively. Returns `None` when no candidate satisfies
/// `min_samples_leaf` on both sides.
fn score_split(
    table: &dyn TableAccess,
    targets: &[f64],
    feature_index: usize,
    rows: &[usize],
    criterion: Criterion,
    min_samples_leaf: usize,
    algorithm: RegressionTreeAlgorithm,
) -> Option<RegressionSplitCandidate> {
    if table.is_binary_binned_feature(feature_index) {
        return score_binary_split(
            table,
            targets,
            feature_index,
            rows,
            criterion,
            min_samples_leaf,
        );
    }
    // Fast paths for the Mean criterion; on failure fall through to the
    // generic implementations below.
    if matches!(criterion, Criterion::Mean) {
        if matches!(algorithm, RegressionTreeAlgorithm::Randomized) {
            if let Some(candidate) = score_randomized_split_mean_fast(
                table,
                targets,
                feature_index,
                rows,
                min_samples_leaf,
            ) {
                return Some(candidate);
            }
        } else if let Some(candidate) =
            score_numeric_split_mean_fast(table, targets, feature_index, rows, min_samples_leaf)
        {
            return Some(candidate);
        }
    }
    if matches!(algorithm, RegressionTreeAlgorithm::Randomized) {
        return score_randomized_split(
            table,
            targets,
            feature_index,
            rows,
            criterion,
            min_samples_leaf,
        );
    }
    let parent_loss = regression_loss(rows, targets, criterion);

    // Exhaustive search: every distinct binned value (BTreeSet dedups and
    // sorts) is a candidate threshold; score = parent loss - children losses.
    rows.iter()
        .map(|row_idx| table.binned_value(feature_index, *row_idx))
        .collect::<BTreeSet<_>>()
        .into_iter()
        .filter_map(|threshold_bin| {
            let (left_rows, right_rows): (Vec<usize>, Vec<usize>) = rows
                .iter()
                .copied()
                .partition(|row_idx| table.binned_value(feature_index, *row_idx) <= threshold_bin);

            if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
                return None;
            }

            let score = parent_loss
                - (regression_loss(&left_rows, targets, criterion)
                    + regression_loss(&right_rows, targets, criterion));

            Some(RegressionSplitCandidate {
                feature_index,
                threshold_bin,
                score,
            })
        })
        .max_by(|left, right| left.score.total_cmp(&right.score))
}
1087
/// Extra-trees-style scoring: draw ONE threshold deterministically from the
/// sorted set of distinct binned values (seeded RNG), then score just that
/// split. The BTreeSet ordering matters — `choose_random_threshold` indexes
/// into the sorted candidate list.
fn score_randomized_split(
    table: &dyn TableAccess,
    targets: &[f64],
    feature_index: usize,
    rows: &[usize],
    criterion: Criterion,
    min_samples_leaf: usize,
) -> Option<RegressionSplitCandidate> {
    let candidate_thresholds = rows
        .iter()
        .map(|row_idx| table.binned_value(feature_index, *row_idx))
        .collect::<BTreeSet<_>>()
        .into_iter()
        .collect::<Vec<_>>();
    let threshold_bin =
        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xA11CE551u64)?;

    let (left_rows, right_rows): (Vec<usize>, Vec<usize>) = rows
        .iter()
        .copied()
        .partition(|row_idx| table.binned_value(feature_index, *row_idx) <= threshold_bin);

    // Reject the draw entirely if either side is too small.
    if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
        return None;
    }

    let parent_loss = regression_loss(rows, targets, criterion);
    let score = parent_loss
        - (regression_loss(&left_rows, targets, criterion)
            + regression_loss(&right_rows, targets, criterion));

    Some(RegressionSplitCandidate {
        feature_index,
        threshold_bin,
        score,
    })
}
1125
/// Scores one feature as the shared split for an entire oblivious level.
///
/// The candidate's score sums the loss reduction over all current leaves;
/// leaves whose children would violate `min_samples_leaf` contribute zero
/// (they are skipped, not rejected). Fast mean-criterion paths are tried
/// first for both binary and numeric features.
fn score_oblivious_split(
    table: &dyn TableAccess,
    row_indices: &[usize],
    targets: &[f64],
    feature_index: usize,
    leaves: &[ObliviousLeafState],
    criterion: Criterion,
    min_samples_leaf: usize,
) -> Option<ObliviousSplitCandidate> {
    if table.is_binary_binned_feature(feature_index) {
        if matches!(criterion, Criterion::Mean)
            && let Some(candidate) = score_binary_oblivious_split_mean_fast(
                table,
                row_indices,
                targets,
                feature_index,
                leaves,
                min_samples_leaf,
            )
        {
            return Some(candidate);
        }
        return score_binary_oblivious_split(
            table,
            row_indices,
            targets,
            feature_index,
            leaves,
            criterion,
            min_samples_leaf,
        );
    }
    if matches!(criterion, Criterion::Mean)
        && let Some(candidate) = score_numeric_oblivious_split_mean_fast(
            table,
            row_indices,
            targets,
            feature_index,
            leaves,
            min_samples_leaf,
        )
    {
        return Some(candidate);
    }
    // Candidate thresholds: every distinct binned value across all leaves.
    let candidate_thresholds = leaves
        .iter()
        .flat_map(|leaf| {
            row_indices[leaf.start..leaf.end]
                .iter()
                .map(|row_idx| table.binned_value(feature_index, *row_idx))
        })
        .collect::<BTreeSet<_>>();

    candidate_thresholds
        .into_iter()
        .filter_map(|threshold_bin| {
            // Accumulate loss reduction leaf by leaf.
            let score = leaves.iter().fold(0.0, |score, leaf| {
                let leaf_rows = &row_indices[leaf.start..leaf.end];
                let (left_rows, right_rows): (Vec<usize>, Vec<usize>) =
                    leaf_rows.iter().copied().partition(|row_idx| {
                        table.binned_value(feature_index, *row_idx) <= threshold_bin
                    });

                // Undersized children: this leaf contributes no gain.
                if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
                    return score;
                }

                score + regression_loss(leaf_rows, targets, criterion)
                    - (regression_loss(&left_rows, targets, criterion)
                        + regression_loss(&right_rows, targets, criterion))
            });

            // Only strictly positive total gains survive.
            (score > 0.0).then_some(ObliviousSplitCandidate {
                feature_index,
                threshold_bin,
                score,
            })
        })
        .max_by(|left, right| left.score.total_cmp(&right.score))
}
1206
/// Applies one shared split to every leaf, partitioning each leaf's slice of
/// `row_indices` in place and returning the doubled leaf list (left child
/// then right child, per leaf — this order defines the leaf bit encoding).
fn split_oblivious_leaves_in_place(
    table: &dyn TableAccess,
    row_indices: &mut [usize],
    targets: &[f64],
    leaves: Vec<ObliviousLeafState>,
    feature_index: usize,
    threshold_bin: u16,
    criterion: Criterion,
) -> Vec<ObliviousLeafState> {
    let mut next_leaves = Vec::with_capacity(leaves.len() * 2);
    for leaf in leaves {
        // An empty child inherits the parent's prediction.
        let fallback_value = leaf.value;
        // Reorder this leaf's slice so rows with value <= threshold come first.
        let left_count = partition_rows_for_binary_split(
            table,
            feature_index,
            threshold_bin,
            &mut row_indices[leaf.start..leaf.end],
        );
        let mid = leaf.start + left_count;
        let left_rows = &row_indices[leaf.start..mid];
        let right_rows = &row_indices[mid..leaf.end];
        let (left_sum, left_sum_sq) = sum_stats(left_rows, targets);
        let (right_sum, right_sum_sq) = sum_stats(right_rows, targets);
        next_leaves.push(ObliviousLeafState {
            start: leaf.start,
            end: mid,
            value: if left_rows.is_empty() {
                fallback_value
            } else {
                regression_value_from_stats(left_rows, targets, criterion, left_sum)
            },
            variance: variance_from_stats(left_rows.len(), left_sum, left_sum_sq),
            sum: left_sum,
            sum_sq: left_sum_sq,
        });
        next_leaves.push(ObliviousLeafState {
            start: mid,
            end: leaf.end,
            value: if right_rows.is_empty() {
                fallback_value
            } else {
                regression_value_from_stats(right_rows, targets, criterion, right_sum)
            },
            variance: variance_from_stats(right_rows.len(), right_sum, right_sum_sq),
            sum: right_sum,
            sum_sq: right_sum_sq,
        });
    }
    next_leaves
}
1257
1258fn variance(rows: &[usize], targets: &[f64]) -> Option<f64> {
1259 let (sum, sum_sq) = sum_stats(rows, targets);
1260 variance_from_stats(rows.len(), sum, sum_sq)
1261}
1262
/// Arithmetic mean of the targets selected by `rows`; 0.0 for an empty
/// selection.
fn mean(rows: &[usize], targets: &[f64]) -> f64 {
    let count = rows.len();
    if count == 0 {
        return 0.0;
    }
    let total: f64 = rows.iter().map(|&row_idx| targets[row_idx]).sum();
    total / count as f64
}
1270
/// Median of the targets selected by `rows`, using total-order comparison
/// for the sort; 0.0 for an empty selection, midpoint of the two central
/// values for an even-sized one.
fn median(rows: &[usize], targets: &[f64]) -> f64 {
    let mut values: Vec<f64> = rows.iter().map(|&row_idx| targets[row_idx]).collect();
    if values.is_empty() {
        return 0.0;
    }
    values.sort_by(f64::total_cmp);

    let mid = values.len() / 2;
    match values.len() % 2 {
        0 => (values[mid - 1] + values[mid]) / 2.0,
        _ => values[mid],
    }
}
1285
1286fn sum_squared_error(rows: &[usize], targets: &[f64]) -> f64 {
1287 let mean = mean(rows, targets);
1288 rows.iter()
1289 .map(|row_idx| {
1290 let diff = targets[*row_idx] - mean;
1291 diff * diff
1292 })
1293 .sum()
1294}
1295
1296fn sum_absolute_error(rows: &[usize], targets: &[f64]) -> f64 {
1297 let median = median(rows, targets);
1298 rows.iter()
1299 .map(|row_idx| (targets[*row_idx] - median).abs())
1300 .sum()
1301}
1302
1303fn regression_value(rows: &[usize], targets: &[f64], criterion: Criterion) -> f64 {
1304 let (sum, _sum_sq) = sum_stats(rows, targets);
1305 regression_value_from_stats(rows, targets, criterion, sum)
1306}
1307
1308fn regression_value_from_stats(
1309 rows: &[usize],
1310 targets: &[f64],
1311 criterion: Criterion,
1312 sum: f64,
1313) -> f64 {
1314 match criterion {
1315 Criterion::Mean => {
1316 if rows.is_empty() {
1317 0.0
1318 } else {
1319 sum / rows.len() as f64
1320 }
1321 }
1322 Criterion::Median => median(rows, targets),
1323 _ => unreachable!("regression criterion only supports mean or median"),
1324 }
1325}
1326
1327fn regression_loss(rows: &[usize], targets: &[f64], criterion: Criterion) -> f64 {
1328 match criterion {
1329 Criterion::Mean => sum_squared_error(rows, targets),
1330 Criterion::Median => sum_absolute_error(rows, targets),
1331 _ => unreachable!("regression criterion only supports mean or median"),
1332 }
1333}
1334
1335fn score_binary_split(
1336 table: &dyn TableAccess,
1337 targets: &[f64],
1338 feature_index: usize,
1339 rows: &[usize],
1340 criterion: Criterion,
1341 min_samples_leaf: usize,
1342) -> Option<RegressionSplitCandidate> {
1343 let (left_rows, right_rows): (Vec<usize>, Vec<usize>) =
1344 rows.iter().copied().partition(|row_idx| {
1345 !table
1346 .binned_boolean_value(feature_index, *row_idx)
1347 .expect("binary feature must expose boolean values")
1348 });
1349
1350 if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
1351 return None;
1352 }
1353
1354 let parent_loss = regression_loss(rows, targets, criterion);
1355 let score = parent_loss
1356 - (regression_loss(&left_rows, targets, criterion)
1357 + regression_loss(&right_rows, targets, criterion));
1358
1359 Some(RegressionSplitCandidate {
1360 feature_index,
1361 threshold_bin: 0,
1362 score,
1363 })
1364}
1365
1366fn score_binary_split_choice(
1367 context: &BuildContext<'_>,
1368 feature_index: usize,
1369 rows: &[usize],
1370) -> Option<BinarySplitChoice> {
1371 if matches!(context.criterion, Criterion::Mean) {
1372 if context.table.is_binary_binned_feature(feature_index) {
1373 return score_binary_split_choice_mean(context, feature_index, rows);
1374 }
1375 return match context.algorithm {
1376 RegressionTreeAlgorithm::Cart => {
1377 score_numeric_split_choice_mean_fast(context, feature_index, rows)
1378 }
1379 RegressionTreeAlgorithm::Randomized => {
1380 score_randomized_split_choice_mean_fast(context, feature_index, rows)
1381 }
1382 RegressionTreeAlgorithm::Oblivious => None,
1383 };
1384 }
1385
1386 score_split(
1387 context.table,
1388 context.targets,
1389 feature_index,
1390 rows,
1391 context.criterion,
1392 context.options.min_samples_leaf,
1393 context.algorithm,
1394 )
1395 .map(|candidate| BinarySplitChoice {
1396 feature_index: candidate.feature_index,
1397 threshold_bin: candidate.threshold_bin,
1398 score: candidate.score,
1399 })
1400}
1401
1402fn score_binary_split_choice_from_hist(
1403 context: &BuildContext<'_>,
1404 histogram: &RegressionFeatureHistogram,
1405 feature_index: usize,
1406 rows: &[usize],
1407) -> Option<BinarySplitChoice> {
1408 if !matches!(context.criterion, Criterion::Mean) {
1409 return score_binary_split_choice(context, feature_index, rows);
1410 }
1411
1412 match histogram {
1413 RegressionFeatureHistogram::Binary {
1414 false_bin,
1415 true_bin,
1416 } => score_binary_split_choice_mean_from_stats(
1417 context,
1418 feature_index,
1419 false_bin.count,
1420 false_bin.sum,
1421 false_bin.sum_sq,
1422 true_bin.count,
1423 true_bin.sum,
1424 true_bin.sum_sq,
1425 ),
1426 RegressionFeatureHistogram::Numeric {
1427 bins,
1428 observed_bins,
1429 } => match context.algorithm {
1430 RegressionTreeAlgorithm::Cart => score_numeric_split_choice_mean_from_hist(
1431 context,
1432 feature_index,
1433 rows.len(),
1434 bins,
1435 observed_bins,
1436 ),
1437 RegressionTreeAlgorithm::Randomized => score_randomized_split_choice_mean_from_hist(
1438 context,
1439 feature_index,
1440 rows,
1441 bins,
1442 observed_bins,
1443 ),
1444 RegressionTreeAlgorithm::Oblivious => None,
1445 },
1446 }
1447}
1448
1449#[allow(clippy::too_many_arguments)]
1450fn score_binary_split_choice_mean_from_stats(
1451 context: &BuildContext<'_>,
1452 feature_index: usize,
1453 left_count: usize,
1454 left_sum: f64,
1455 left_sum_sq: f64,
1456 right_count: usize,
1457 right_sum: f64,
1458 right_sum_sq: f64,
1459) -> Option<BinarySplitChoice> {
1460 if left_count < context.options.min_samples_leaf
1461 || right_count < context.options.min_samples_leaf
1462 {
1463 return None;
1464 }
1465 let total_count = left_count + right_count;
1466 let total_sum = left_sum + right_sum;
1467 let total_sum_sq = left_sum_sq + right_sum_sq;
1468 let parent_loss = total_sum_sq - (total_sum * total_sum) / total_count as f64;
1469 let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
1470 let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
1471 Some(BinarySplitChoice {
1472 feature_index,
1473 threshold_bin: 0,
1474 score: parent_loss - (left_loss + right_loss),
1475 })
1476}
1477
/// CART split search over a prebuilt per-bin histogram (mean criterion).
///
/// Sweeps the observed bins left-to-right, maintaining running left-side
/// count/sum/sum-of-squares, and returns the threshold bin with the largest
/// SSE reduction, or `None` when no threshold satisfies `min_samples_leaf`.
fn score_numeric_split_choice_mean_from_hist(
    context: &BuildContext<'_>,
    feature_index: usize,
    row_count: usize,
    bins: &[RegressionHistogramBin],
    observed_bins: &[usize],
) -> Option<BinarySplitChoice> {
    // A single occupied bin offers nothing to split on.
    if observed_bins.len() <= 1 {
        return None;
    }
    let total_sum = bins.iter().map(|bin| bin.sum).sum::<f64>();
    let total_sum_sq = bins.iter().map(|bin| bin.sum_sq).sum::<f64>();
    // SSE identity: sum((y - mean)^2) == sum(y^2) - (sum y)^2 / n.
    let parent_loss = total_sum_sq - (total_sum * total_sum) / row_count as f64;
    let mut left_count = 0usize;
    let mut left_sum = 0.0;
    let mut left_sum_sq = 0.0;
    let mut best_threshold = None;
    let mut best_score = f64::NEG_INFINITY;

    for &bin in observed_bins {
        left_count += bins[bin].count;
        left_sum += bins[bin].sum;
        left_sum_sq += bins[bin].sum_sq;
        let right_count = row_count - left_count;
        // Threshold is still accumulated into the running stats even when
        // skipped, so later thresholds see the full prefix.
        if left_count < context.options.min_samples_leaf
            || right_count < context.options.min_samples_leaf
        {
            continue;
        }
        let right_sum = total_sum - left_sum;
        let right_sum_sq = total_sum_sq - left_sum_sq;
        let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
        let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
        let score = parent_loss - (left_loss + right_loss);
        if score > best_score {
            best_score = score;
            best_threshold = Some(bin as u16);
        }
    }

    best_threshold.map(|threshold_bin| BinarySplitChoice {
        feature_index,
        threshold_bin,
        score: best_score,
    })
}
1524
/// Randomized-tree variant of the histogram split scorer (mean criterion):
/// draws one threshold deterministically from the observed bins (seeded by
/// feature and rows), then scores only that threshold from the histogram.
fn score_randomized_split_choice_mean_from_hist(
    context: &BuildContext<'_>,
    feature_index: usize,
    rows: &[usize],
    bins: &[RegressionHistogramBin],
    observed_bins: &[usize],
) -> Option<BinarySplitChoice> {
    let candidate_thresholds = observed_bins
        .iter()
        .copied()
        .map(|bin| bin as u16)
        .collect::<Vec<_>>();
    // Deterministic pseudo-random pick; `None` when there are no candidates.
    let threshold_bin =
        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xA11CE551u64)?;
    let total_sum = bins.iter().map(|bin| bin.sum).sum::<f64>();
    let total_sum_sq = bins.iter().map(|bin| bin.sum_sq).sum::<f64>();
    let mut left_count = 0usize;
    let mut left_sum = 0.0;
    let mut left_sum_sq = 0.0;
    // Accumulate every bin at or below the chosen threshold into the left side.
    for bin in 0..=threshold_bin as usize {
        if bin >= bins.len() {
            break;
        }
        left_count += bins[bin].count;
        left_sum += bins[bin].sum;
        left_sum_sq += bins[bin].sum_sq;
    }
    let right_count = rows.len() - left_count;
    if left_count < context.options.min_samples_leaf
        || right_count < context.options.min_samples_leaf
    {
        return None;
    }
    // SSE identity: sum((y - mean)^2) == sum(y^2) - (sum y)^2 / n.
    let parent_loss = total_sum_sq - (total_sum * total_sum) / rows.len() as f64;
    let right_sum = total_sum - left_sum;
    let right_sum_sq = total_sum_sq - left_sum_sq;
    let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
    let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
    Some(BinarySplitChoice {
        feature_index,
        threshold_bin,
        score: parent_loss - (left_loss + right_loss),
    })
}
1569
/// Single-pass mean-criterion scorer for a binary feature: accumulates
/// total and left-branch (boolean == false) count/sum/sum-of-squares in one
/// sweep over the rows, then computes the SSE reduction from those stats.
fn score_binary_split_choice_mean(
    context: &BuildContext<'_>,
    feature_index: usize,
    rows: &[usize],
) -> Option<BinarySplitChoice> {
    let mut left_count = 0usize;
    let mut left_sum = 0.0;
    let mut left_sum_sq = 0.0;
    let mut total_sum = 0.0;
    let mut total_sum_sq = 0.0;

    for row_idx in rows {
        let target = context.targets[*row_idx];
        total_sum += target;
        total_sum_sq += target * target;
        // Rows whose boolean value is false belong to the left branch.
        if !context
            .table
            .binned_boolean_value(feature_index, *row_idx)
            .expect("binary feature must expose boolean values")
        {
            left_count += 1;
            left_sum += target;
            left_sum_sq += target * target;
        }
    }

    let right_count = rows.len() - left_count;
    if left_count < context.options.min_samples_leaf
        || right_count < context.options.min_samples_leaf
    {
        return None;
    }

    // SSE identity: sum((y - mean)^2) == sum(y^2) - (sum y)^2 / n.
    let parent_loss = total_sum_sq - (total_sum * total_sum) / rows.len() as f64;
    let right_sum = total_sum - left_sum;
    let right_sum_sq = total_sum_sq - left_sum_sq;
    let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
    let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;

    Some(BinarySplitChoice {
        feature_index,
        threshold_bin: 0,
        score: parent_loss - (left_loss + right_loss),
    })
}
1615
/// Exhaustive CART split search for a numeric binned feature under the mean
/// criterion: builds a per-bin count/sum/sum-of-squares histogram in one
/// pass, then sweeps the occupied bins to find the threshold with the
/// largest SSE reduction.
///
/// Returns `None` when the table has no numeric bins, a row's bin exceeds
/// the declared cap, only one bin is occupied, or no threshold satisfies
/// `min_samples_leaf`.
fn score_numeric_split_mean_fast(
    table: &dyn TableAccess,
    targets: &[f64],
    feature_index: usize,
    rows: &[usize],
    min_samples_leaf: usize,
) -> Option<RegressionSplitCandidate> {
    let bin_cap = table.numeric_bin_cap();
    if bin_cap == 0 {
        return None;
    }

    let mut bin_count = vec![0usize; bin_cap];
    let mut bin_sum = vec![0.0; bin_cap];
    let mut bin_sum_sq = vec![0.0; bin_cap];
    let mut observed_bins = vec![false; bin_cap];
    let mut total_sum = 0.0;
    let mut total_sum_sq = 0.0;

    for row_idx in rows {
        let bin = table.binned_value(feature_index, *row_idx) as usize;
        // Out-of-range bin: bail out rather than index out of bounds.
        if bin >= bin_cap {
            return None;
        }
        let target = targets[*row_idx];
        bin_count[bin] += 1;
        bin_sum[bin] += target;
        bin_sum_sq[bin] += target * target;
        observed_bins[bin] = true;
        total_sum += target;
        total_sum_sq += target * target;
    }

    // Collapse the occupancy mask into a sorted list of occupied bin indices.
    let observed_bins: Vec<usize> = observed_bins
        .into_iter()
        .enumerate()
        .filter_map(|(bin, seen)| seen.then_some(bin))
        .collect();
    if observed_bins.len() <= 1 {
        return None;
    }

    // SSE identity: sum((y - mean)^2) == sum(y^2) - (sum y)^2 / n.
    let parent_loss = total_sum_sq - (total_sum * total_sum) / rows.len() as f64;
    let mut left_count = 0usize;
    let mut left_sum = 0.0;
    let mut left_sum_sq = 0.0;
    let mut best_threshold = None;
    let mut best_score = f64::NEG_INFINITY;

    for &bin in &observed_bins {
        left_count += bin_count[bin];
        left_sum += bin_sum[bin];
        left_sum_sq += bin_sum_sq[bin];
        let right_count = rows.len() - left_count;

        if left_count < min_samples_leaf || right_count < min_samples_leaf {
            continue;
        }

        let right_sum = total_sum - left_sum;
        let right_sum_sq = total_sum_sq - left_sum_sq;
        let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
        let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
        let score = parent_loss - (left_loss + right_loss);
        if score > best_score {
            best_score = score;
            best_threshold = Some(bin as u16);
        }
    }

    let threshold_bin = best_threshold?;
    Some(RegressionSplitCandidate {
        feature_index,
        threshold_bin,
        score: best_score,
    })
}
1693
1694fn score_numeric_split_choice_mean_fast(
1695 context: &BuildContext<'_>,
1696 feature_index: usize,
1697 rows: &[usize],
1698) -> Option<BinarySplitChoice> {
1699 score_numeric_split_mean_fast(
1700 context.table,
1701 context.targets,
1702 feature_index,
1703 rows,
1704 context.options.min_samples_leaf,
1705 )
1706 .map(|candidate| BinarySplitChoice {
1707 feature_index: candidate.feature_index,
1708 threshold_bin: candidate.threshold_bin,
1709 score: candidate.score,
1710 })
1711}
1712
/// Randomized-tree split scorer for a numeric binned feature under the mean
/// criterion: picks one threshold deterministically from the occupied bins
/// (seeded by feature and rows), then scores only that threshold with a
/// single statistics pass over the rows.
fn score_randomized_split_mean_fast(
    table: &dyn TableAccess,
    targets: &[f64],
    feature_index: usize,
    rows: &[usize],
    min_samples_leaf: usize,
) -> Option<RegressionSplitCandidate> {
    let bin_cap = table.numeric_bin_cap();
    if bin_cap == 0 {
        return None;
    }
    // First pass: which bins actually occur for these rows.
    let mut observed_bins = vec![false; bin_cap];
    for row_idx in rows {
        let bin = table.binned_value(feature_index, *row_idx) as usize;
        if bin >= bin_cap {
            return None;
        }
        observed_bins[bin] = true;
    }
    let candidate_thresholds = observed_bins
        .into_iter()
        .enumerate()
        .filter_map(|(bin, seen)| seen.then_some(bin as u16))
        .collect::<Vec<_>>();
    // Deterministic pseudo-random pick; `None` when there are no candidates.
    let threshold_bin =
        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xA11CE551u64)?;

    // Second pass: total and left-side (bin <= threshold) statistics.
    let mut left_count = 0usize;
    let mut left_sum = 0.0;
    let mut left_sum_sq = 0.0;
    let mut total_sum = 0.0;
    let mut total_sum_sq = 0.0;
    for row_idx in rows {
        let target = targets[*row_idx];
        total_sum += target;
        total_sum_sq += target * target;
        if table.binned_value(feature_index, *row_idx) <= threshold_bin {
            left_count += 1;
            left_sum += target;
            left_sum_sq += target * target;
        }
    }
    let right_count = rows.len() - left_count;
    if left_count < min_samples_leaf || right_count < min_samples_leaf {
        return None;
    }
    // SSE identity: sum((y - mean)^2) == sum(y^2) - (sum y)^2 / n.
    let parent_loss = total_sum_sq - (total_sum * total_sum) / rows.len() as f64;
    let right_sum = total_sum - left_sum;
    let right_sum_sq = total_sum_sq - left_sum_sq;
    let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
    let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
    let score = parent_loss - (left_loss + right_loss);

    Some(RegressionSplitCandidate {
        feature_index,
        threshold_bin,
        score,
    })
}
1772
1773fn score_randomized_split_choice_mean_fast(
1774 context: &BuildContext<'_>,
1775 feature_index: usize,
1776 rows: &[usize],
1777) -> Option<BinarySplitChoice> {
1778 score_randomized_split_mean_fast(
1779 context.table,
1780 context.targets,
1781 feature_index,
1782 rows,
1783 context.options.min_samples_leaf,
1784 )
1785 .map(|candidate| BinarySplitChoice {
1786 feature_index: candidate.feature_index,
1787 threshold_bin: candidate.threshold_bin,
1788 score: candidate.score,
1789 })
1790}
1791
/// Oblivious-tree split search for a numeric binned feature under the mean
/// criterion: for every current leaf, builds a per-bin histogram and adds
/// each valid threshold's SSE reduction into a shared per-threshold score
/// array; the threshold with the largest positive total wins.
fn score_numeric_oblivious_split_mean_fast(
    table: &dyn TableAccess,
    row_indices: &[usize],
    targets: &[f64],
    feature_index: usize,
    leaves: &[ObliviousLeafState],
    min_samples_leaf: usize,
) -> Option<ObliviousSplitCandidate> {
    let bin_cap = table.numeric_bin_cap();
    if bin_cap == 0 {
        return None;
    }
    // Accumulated gain per candidate threshold, summed across all leaves.
    let mut threshold_scores = vec![0.0; bin_cap];
    let mut observed_any = false;

    for leaf in leaves {
        let mut bin_count = vec![0usize; bin_cap];
        let mut bin_sum = vec![0.0; bin_cap];
        let mut bin_sum_sq = vec![0.0; bin_cap];
        let mut observed_bins = vec![false; bin_cap];

        for row_idx in &row_indices[leaf.start..leaf.end] {
            let bin = table.binned_value(feature_index, *row_idx) as usize;
            // Out-of-range bin: bail out rather than index out of bounds.
            if bin >= bin_cap {
                return None;
            }
            let target = targets[*row_idx];
            bin_count[bin] += 1;
            bin_sum[bin] += target;
            bin_sum_sq[bin] += target * target;
            observed_bins[bin] = true;
        }

        let observed_bins: Vec<usize> = observed_bins
            .into_iter()
            .enumerate()
            .filter_map(|(bin, seen)| seen.then_some(bin))
            .collect();
        // A leaf concentrated in one bin contributes no splittable signal.
        if observed_bins.len() <= 1 {
            continue;
        }
        observed_any = true;

        // SSE identity using the leaf's cached sum / sum of squares.
        let parent_loss = leaf.sum_sq - (leaf.sum * leaf.sum) / leaf.len() as f64;
        let mut left_count = 0usize;
        let mut left_sum = 0.0;
        let mut left_sum_sq = 0.0;

        for &bin in &observed_bins {
            left_count += bin_count[bin];
            left_sum += bin_sum[bin];
            left_sum_sq += bin_sum_sq[bin];
            let right_count = leaf.len() - left_count;

            if left_count < min_samples_leaf || right_count < min_samples_leaf {
                continue;
            }

            let right_sum = leaf.sum - left_sum;
            let right_sum_sq = leaf.sum_sq - left_sum_sq;
            let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
            let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
            threshold_scores[bin] += parent_loss - (left_loss + right_loss);
        }
    }

    if !observed_any {
        return None;
    }

    // Only strictly positive aggregate gains qualify as candidates.
    threshold_scores
        .into_iter()
        .enumerate()
        .filter(|(_, score)| *score > 0.0)
        .max_by(|left, right| left.1.total_cmp(&right.1))
        .map(|(threshold_bin, score)| ObliviousSplitCandidate {
            feature_index,
            threshold_bin: threshold_bin as u16,
            score,
        })
}
1873
1874fn score_binary_oblivious_split(
1875 table: &dyn TableAccess,
1876 row_indices: &[usize],
1877 targets: &[f64],
1878 feature_index: usize,
1879 leaves: &[ObliviousLeafState],
1880 criterion: Criterion,
1881 min_samples_leaf: usize,
1882) -> Option<ObliviousSplitCandidate> {
1883 let score = leaves.iter().fold(0.0, |score, leaf| {
1884 let leaf_rows = &row_indices[leaf.start..leaf.end];
1885 let (left_rows, right_rows): (Vec<usize>, Vec<usize>) =
1886 leaf_rows.iter().copied().partition(|row_idx| {
1887 !table
1888 .binned_boolean_value(feature_index, *row_idx)
1889 .expect("binary feature must expose boolean values")
1890 });
1891
1892 if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
1893 return score;
1894 }
1895
1896 score + regression_loss(leaf_rows, targets, criterion)
1897 - (regression_loss(&left_rows, targets, criterion)
1898 + regression_loss(&right_rows, targets, criterion))
1899 });
1900
1901 (score > 0.0).then_some(ObliviousSplitCandidate {
1902 feature_index,
1903 threshold_bin: 0,
1904 score,
1905 })
1906}
1907
/// Mean-criterion fast path for the oblivious binary split: per leaf,
/// accumulates left-branch (boolean == false) count/sum/sum-of-squares in
/// one pass and adds the leaf's SSE reduction into the aggregate score.
/// Returns a candidate only when at least one leaf satisfied
/// `min_samples_leaf` and the aggregate gain is strictly positive.
fn score_binary_oblivious_split_mean_fast(
    table: &dyn TableAccess,
    row_indices: &[usize],
    targets: &[f64],
    feature_index: usize,
    leaves: &[ObliviousLeafState],
    min_samples_leaf: usize,
) -> Option<ObliviousSplitCandidate> {
    let mut score = 0.0;
    let mut found_valid = false;

    for leaf in leaves {
        let mut left_count = 0usize;
        let mut left_sum = 0.0;
        let mut left_sum_sq = 0.0;

        for row_idx in &row_indices[leaf.start..leaf.end] {
            // Rows whose boolean value is false belong to the left branch.
            if !table
                .binned_boolean_value(feature_index, *row_idx)
                .expect("binary feature must expose boolean values")
            {
                let target = targets[*row_idx];
                left_count += 1;
                left_sum += target;
                left_sum_sq += target * target;
            }
        }

        let right_count = leaf.len() - left_count;
        if left_count < min_samples_leaf || right_count < min_samples_leaf {
            continue;
        }

        found_valid = true;
        // SSE identity using the leaf's cached sum / sum of squares.
        let parent_loss = leaf.sum_sq - (leaf.sum * leaf.sum) / leaf.len() as f64;
        let right_sum = leaf.sum - left_sum;
        let right_sum_sq = leaf.sum_sq - left_sum_sq;
        let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
        let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
        score += parent_loss - (left_loss + right_loss);
    }

    (found_valid && score > 0.0).then_some(ObliviousSplitCandidate {
        feature_index,
        threshold_bin: 0,
        score,
    })
}
1956
/// Returns `(sum, sum of squares)` of the targets selected by `rows`.
fn sum_stats(rows: &[usize], targets: &[f64]) -> (f64, f64) {
    let mut sum = 0.0;
    let mut sum_sq = 0.0;
    for &row_idx in rows {
        let value = targets[row_idx];
        sum += value;
        sum_sq += value * value;
    }
    (sum, sum_sq)
}
1963
/// Population variance from sufficient statistics: `E[y^2] - E[y]^2`.
///
/// Returns `None` for an empty sample. The result is clamped to zero
/// because floating-point cancellation in `sum_sq/n - mean^2` can produce a
/// tiny negative value for near-constant targets, and a reported variance
/// must never be negative.
fn variance_from_stats(count: usize, sum: f64, sum_sq: f64) -> Option<f64> {
    if count == 0 {
        None
    } else {
        Some(((sum_sq / count as f64) - (sum / count as f64).powi(2)).max(0.0))
    }
}
1971
1972fn leaf_regression_loss(
1973 leaf: &ObliviousLeafState,
1974 row_indices: &[usize],
1975 targets: &[f64],
1976 criterion: Criterion,
1977) -> f64 {
1978 match criterion {
1979 Criterion::Mean => leaf.sum_sq - (leaf.sum * leaf.sum) / leaf.len() as f64,
1980 Criterion::Median => {
1981 regression_loss(&row_indices[leaf.start..leaf.end], targets, criterion)
1982 }
1983 _ => unreachable!("regression criterion only supports mean or median"),
1984 }
1985}
1986
/// True when every selected target equals the first one (vacuously true for
/// an empty selection).
fn has_constant_target(rows: &[usize], targets: &[f64]) -> bool {
    match rows.first() {
        None => true,
        Some(&first_row) => {
            let reference = targets[first_row];
            rows.iter().all(|&row_idx| targets[row_idx] == reference)
        }
    }
}
1993
1994fn push_leaf(
1995 nodes: &mut Vec<RegressionNode>,
1996 value: f64,
1997 sample_count: usize,
1998 variance: Option<f64>,
1999) -> usize {
2000 push_node(
2001 nodes,
2002 RegressionNode::Leaf {
2003 value,
2004 sample_count,
2005 variance,
2006 },
2007 )
2008}
2009
2010fn push_node(nodes: &mut Vec<RegressionNode>, node: RegressionNode) -> usize {
2011 nodes.push(node);
2012 nodes.len() - 1
2013}
2014
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{FeaturePreprocessing, Model, NumericBinBoundary};
    use forestfire_data::{DenseTable, NumericBins};

    // Eight rows of y = x^2 on a single numeric feature with 128 fixed bins.
    fn quadratic_table() -> DenseTable {
        DenseTable::with_options(
            vec![
                vec![0.0],
                vec![1.0],
                vec![2.0],
                vec![3.0],
                vec![4.0],
                vec![5.0],
                vec![6.0],
                vec![7.0],
            ],
            vec![0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0],
            0,
            NumericBins::Fixed(128),
        )
        .unwrap()
    }

    // Builds a table whose target equals the binned value of the canary
    // feature, making the canary the strongest possible predictor.
    fn canary_target_table() -> DenseTable {
        let x: Vec<Vec<f64>> = (0..8).map(|value| vec![value as f64]).collect();
        // Probe table is only used to read back the canary's binned values.
        let probe =
            DenseTable::with_options(x.clone(), vec![0.0; 8], 1, NumericBins::Auto).unwrap();
        let canary_index = probe.n_features();
        let y = (0..probe.n_rows())
            .map(|row_idx| probe.binned_value(canary_index, row_idx) as f64)
            .collect();

        DenseTable::with_options(x, y, 1, NumericBins::Auto).unwrap()
    }

    // Three features whose different subsets explain the target differently,
    // so different feature subsamples produce visibly different trees.
    fn randomized_permutation_table() -> DenseTable {
        DenseTable::with_options(
            vec![
                vec![0.0, 0.0, 0.0],
                vec![0.0, 0.0, 1.0],
                vec![0.0, 1.0, 0.0],
                vec![0.0, 1.0, 1.0],
                vec![1.0, 0.0, 0.0],
                vec![1.0, 0.0, 1.0],
                vec![1.0, 1.0, 0.0],
                vec![1.0, 1.0, 1.0],
                vec![0.0, 0.0, 2.0],
                vec![0.0, 1.0, 2.0],
                vec![1.0, 0.0, 2.0],
                vec![1.0, 1.0, 2.0],
            ],
            vec![0.0, 1.0, 2.5, 3.5, 4.0, 5.0, 6.5, 7.5, 2.0, 4.5, 6.0, 8.5],
            0,
            NumericBins::Fixed(8),
        )
        .unwrap()
    }

    // CART with default depth should reproduce the quadratic targets exactly.
    #[test]
    fn cart_regressor_fits_basic_numeric_pattern() {
        let table = quadratic_table();
        let model = train_cart_regressor(&table).unwrap();
        let preds = model.predict_table(&table);

        assert_eq!(model.algorithm(), RegressionTreeAlgorithm::Cart);
        assert_eq!(model.criterion(), Criterion::Mean);
        assert_eq!(preds, table_targets(&table));
    }

    // Randomized trees need not fit exactly, but must beat the mean baseline.
    #[test]
    fn randomized_regressor_fits_basic_numeric_pattern() {
        let table = quadratic_table();
        let model = train_randomized_regressor(&table).unwrap();
        let preds = model.predict_table(&table);
        let targets = table_targets(&table);
        let baseline_mean = targets.iter().sum::<f64>() / targets.len() as f64;
        let baseline_sse = targets
            .iter()
            .map(|target| {
                let diff = target - baseline_mean;
                diff * diff
            })
            .sum::<f64>();
        let model_sse = preds
            .iter()
            .zip(targets.iter())
            .map(|(pred, target)| {
                let diff = pred - target;
                diff * diff
            })
            .sum::<f64>();

        assert_eq!(model.algorithm(), RegressionTreeAlgorithm::Randomized);
        assert_eq!(model.criterion(), Criterion::Mean);
        assert!(model_sse < baseline_sse);
    }

    // Same seed must serialize byte-identically; different seeds must
    // produce at least a few distinct trees.
    #[test]
    fn randomized_regressor_is_repeatable_for_fixed_seed_and_varies_across_seeds() {
        let table = randomized_permutation_table();
        let make_options = |random_seed| RegressionTreeOptions {
            max_depth: 4,
            max_features: Some(2),
            random_seed,
            ..RegressionTreeOptions::default()
        };

        let base_model = train_randomized_regressor_with_criterion_parallelism_and_options(
            &table,
            Criterion::Mean,
            Parallelism::sequential(),
            make_options(91),
        )
        .unwrap();
        let repeated_model = train_randomized_regressor_with_criterion_parallelism_and_options(
            &table,
            Criterion::Mean,
            Parallelism::sequential(),
            make_options(91),
        )
        .unwrap();
        let unique_serializations = (0..32u64)
            .map(|seed| {
                Model::DecisionTreeRegressor(
                    train_randomized_regressor_with_criterion_parallelism_and_options(
                        &table,
                        Criterion::Mean,
                        Parallelism::sequential(),
                        make_options(seed),
                    )
                    .unwrap(),
                )
                .serialize()
                .unwrap()
            })
            .collect::<std::collections::BTreeSet<_>>();

        assert_eq!(
            Model::DecisionTreeRegressor(base_model.clone())
                .serialize()
                .unwrap(),
            Model::DecisionTreeRegressor(repeated_model)
                .serialize()
                .unwrap()
        );
        assert!(unique_serializations.len() >= 4);
    }

    // Oblivious trees should also fit the quadratic targets exactly.
    #[test]
    fn oblivious_regressor_fits_basic_numeric_pattern() {
        let table = quadratic_table();
        let model = train_oblivious_regressor(&table).unwrap();
        let preds = model.predict_table(&table);

        assert_eq!(model.algorithm(), RegressionTreeAlgorithm::Oblivious);
        assert_eq!(model.criterion(), Criterion::Mean);
        assert_eq!(preds, table_targets(&table));
    }

    // With one outlier (100.0), mean and median leaves must differ.
    #[test]
    fn regressors_can_choose_between_mean_and_median() {
        let table = DenseTable::with_canaries(
            vec![vec![0.0], vec![0.0], vec![0.0]],
            vec![0.0, 0.0, 100.0],
            0,
        )
        .unwrap();

        let mean_model = train_cart_regressor_with_criterion(&table, Criterion::Mean).unwrap();
        let median_model = train_cart_regressor_with_criterion(&table, Criterion::Median).unwrap();

        assert_eq!(mean_model.criterion(), Criterion::Mean);
        assert_eq!(median_model.criterion(), Criterion::Median);
        assert_eq!(
            mean_model.predict_table(&table),
            vec![100.0 / 3.0, 100.0 / 3.0, 100.0 / 3.0]
        );
        assert_eq!(median_model.predict_table(&table), vec![0.0, 0.0, 0.0]);
    }

    // A NaN target must surface as InvalidTargetValue with the row index.
    #[test]
    fn rejects_non_finite_targets() {
        let table = DenseTable::new(vec![vec![0.0], vec![1.0]], vec![0.0, f64::NAN]).unwrap();

        let err = train_cart_regressor(&table).unwrap_err();
        assert!(matches!(
            err,
            RegressionTreeError::InvalidTargetValue { row: 1, value } if value.is_nan()
        ));
    }

    // When a canary feature is the best split, growth stops at the root.
    #[test]
    fn stops_cart_regressor_growth_when_a_canary_wins() {
        let table = canary_target_table();
        let model = train_cart_regressor(&table).unwrap();
        let preds = model.predict_table(&table);

        assert!(preds.iter().all(|pred| *pred == preds[0]));
        assert_ne!(preds, table_targets(&table));
    }

    // Same canary guard, exercised through the oblivious training path.
    #[test]
    fn stops_oblivious_regressor_growth_when_a_canary_wins() {
        let table = canary_target_table();
        let model = train_oblivious_regressor(&table).unwrap();
        let preds = model.predict_table(&table);

        assert!(preds.iter().all(|pred| *pred == preds[0]));
        assert_ne!(preds, table_targets(&table));
    }

    // Hand-assembled models of all three tree types must serialize with the
    // expected tree_type and task tags.
    #[test]
    fn manually_built_regressor_models_serialize_for_each_tree_type() {
        let preprocessing = vec![
            FeaturePreprocessing::Binary,
            FeaturePreprocessing::Numeric {
                bin_boundaries: vec![
                    NumericBinBoundary {
                        bin: 0,
                        upper_bound: 1.0,
                    },
                    NumericBinBoundary {
                        bin: 127,
                        upper_bound: 10.0,
                    },
                ],
            },
        ];
        let options = RegressionTreeOptions::default();

        let cart = Model::DecisionTreeRegressor(DecisionTreeRegressor {
            algorithm: RegressionTreeAlgorithm::Cart,
            criterion: Criterion::Mean,
            structure: RegressionTreeStructure::Standard {
                nodes: vec![
                    RegressionNode::Leaf {
                        value: -1.0,
                        sample_count: 2,
                        variance: Some(0.25),
                    },
                    RegressionNode::Leaf {
                        value: 2.5,
                        sample_count: 3,
                        variance: Some(1.0),
                    },
                    RegressionNode::BinarySplit {
                        feature_index: 0,
                        threshold_bin: 0,
                        left_child: 0,
                        right_child: 1,
                        sample_count: 5,
                        impurity: 3.5,
                        gain: 1.25,
                        variance: Some(0.7),
                    },
                ],
                root: 2,
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing.clone(),
            training_canaries: 0,
        });
        let randomized = Model::DecisionTreeRegressor(DecisionTreeRegressor {
            algorithm: RegressionTreeAlgorithm::Randomized,
            criterion: Criterion::Median,
            structure: RegressionTreeStructure::Standard {
                nodes: vec![
                    RegressionNode::Leaf {
                        value: -1.0,
                        sample_count: 2,
                        variance: Some(0.25),
                    },
                    RegressionNode::Leaf {
                        value: 2.5,
                        sample_count: 3,
                        variance: Some(1.0),
                    },
                    RegressionNode::BinarySplit {
                        feature_index: 0,
                        threshold_bin: 0,
                        left_child: 0,
                        right_child: 1,
                        sample_count: 5,
                        impurity: 3.5,
                        gain: 0.8,
                        variance: Some(0.7),
                    },
                ],
                root: 2,
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing.clone(),
            training_canaries: 0,
        });
        let oblivious = Model::DecisionTreeRegressor(DecisionTreeRegressor {
            algorithm: RegressionTreeAlgorithm::Oblivious,
            criterion: Criterion::Median,
            structure: RegressionTreeStructure::Oblivious {
                splits: vec![ObliviousSplit {
                    feature_index: 1,
                    threshold_bin: 127,
                    sample_count: 4,
                    impurity: 2.0,
                    gain: 0.5,
                }],
                leaf_values: vec![0.0, 10.0],
                leaf_sample_counts: vec![2, 2],
                leaf_variances: vec![Some(0.0), Some(1.0)],
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing,
            training_canaries: 0,
        });

        for (tree_type, model) in [
            ("cart", cart),
            ("randomized", randomized),
            ("oblivious", oblivious),
        ] {
            let json = model.serialize().unwrap();
            assert!(json.contains(&format!("\"tree_type\":\"{tree_type}\"")));
            assert!(json.contains("\"task\":\"regression\""));
        }
    }

    // Collects the table's target column for direct comparison with preds.
    fn table_targets(table: &dyn TableAccess) -> Vec<f64> {
        (0..table.n_rows())
            .map(|row_idx| table.target_value(row_idx))
            .collect()
    }
}