Skip to main content

forestfire_core/tree/
regressor.rs

1//! First-order regression tree learners.
2//!
3//! This module is the regression analogue of `classifier`, but the split logic
4//! differs in one important way: regression quality depends on leaf value
5//! statistics rather than class counts. The implementation therefore leans on
6//! cached count/sum/sum-of-squares histograms in the mean-criterion hot path.
7
8use crate::ir::{
9    BinaryChildren, BinarySplit, IndexedLeaf, LeafIndexing, LeafPayload, NodeStats, NodeTreeNode,
10    ObliviousLevel, ObliviousSplit as IrObliviousSplit, TrainingMetadata, TreeDefinition,
11    criterion_name, feature_name, threshold_upper_bound,
12};
13use crate::tree::shared::{
14    FeatureHistogram, HistogramBin, build_feature_histograms, candidate_feature_indices,
15    choose_random_threshold, node_seed, partition_rows_for_binary_split,
16    subtract_feature_histograms,
17};
18use crate::{Criterion, FeaturePreprocessing, Parallelism, capture_feature_preprocessing};
19use forestfire_data::TableAccess;
20use rayon::prelude::*;
21use std::collections::BTreeSet;
22use std::error::Error;
23use std::fmt::{Display, Formatter};
24
/// Families of regression-tree learners this module can train.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RegressionTreeAlgorithm {
    /// Greedy top-down binary tree (CART-style construction).
    Cart,
    /// Binary tree built on the same path as CART but with randomized split
    /// selection (presumably via `choose_random_threshold` — scoring lives
    /// outside this chunk).
    Randomized,
    /// Level-wise tree in which every node at a given depth shares one split.
    Oblivious,
}
31
/// Shared training controls for regression tree learners.
#[derive(Debug, Clone, Copy)]
pub struct RegressionTreeOptions {
    /// Depth budget: nodes at this depth are never split further.
    pub max_depth: usize,
    /// Minimum number of rows a node must hold before a split is attempted.
    pub min_samples_split: usize,
    /// Minimum number of rows per leaf (enforced by split scoring outside
    /// this chunk).
    pub min_samples_leaf: usize,
    /// Optional cap on how many candidate features are scored per split.
    pub max_features: Option<usize>,
    /// Seed for the deterministic per-node RNG.
    pub random_seed: u64,
}

impl Default for RegressionTreeOptions {
    /// Defaults: depth 8, split threshold 2, leaf minimum 1, all features
    /// considered, seed 0.
    fn default() -> Self {
        RegressionTreeOptions {
            random_seed: 0,
            max_features: None,
            min_samples_leaf: 1,
            min_samples_split: 2,
            max_depth: 8,
        }
    }
}
53
/// Errors surfaced while validating regression training inputs.
#[derive(Debug)]
pub enum RegressionTreeError {
    /// The training table contains no rows.
    EmptyTarget,
    /// A target value was NaN or infinite.
    InvalidTargetValue { row: usize, value: f64 },
}

impl Display for RegressionTreeError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptyTarget => write!(f, "Cannot train on an empty target vector."),
            Self::InvalidTargetValue { row, value } => write!(
                f,
                "Regression targets must be finite values. Found {} at row {}.",
                value, row
            ),
        }
    }
}

impl Error for RegressionTreeError {}
76
/// Concrete trained regression tree.
#[derive(Debug, Clone)]
pub struct DecisionTreeRegressor {
    /// Learner family that produced this tree.
    algorithm: RegressionTreeAlgorithm,
    /// Split-quality criterion used during training.
    criterion: Criterion,
    /// Learned topology: either a node arena or oblivious level splits.
    structure: RegressionTreeStructure,
    /// Training controls the tree was built with (recorded in metadata).
    options: RegressionTreeOptions,
    /// Feature count of the training table.
    num_features: usize,
    /// Per-feature preprocessing captured at training time; used when
    /// exporting splits to IR.
    feature_preprocessing: Vec<FeaturePreprocessing>,
    /// Number of canary features in the training table (reported in metadata).
    training_canaries: usize,
}
88
/// Trained tree topology, specialized per learner family.
#[derive(Debug, Clone)]
pub(crate) enum RegressionTreeStructure {
    /// CART/randomized trees: a flat node arena plus the root's index.
    Standard {
        nodes: Vec<RegressionNode>,
        root: usize,
    },
    /// Oblivious trees: one shared split per level; the leaf arrays are
    /// indexed by the MSB-first path bits (see `predict_row`).
    Oblivious {
        splits: Vec<ObliviousSplit>,
        leaf_values: Vec<f64>,
        leaf_sample_counts: Vec<usize>,
        leaf_variances: Vec<Option<f64>>,
    },
}
102
/// One node of a standard (non-oblivious) regression tree.
#[derive(Debug, Clone)]
pub(crate) enum RegressionNode {
    /// Terminal node carrying the prediction for rows that reach it.
    Leaf {
        /// Value emitted at prediction time.
        value: f64,
        /// Number of training rows that landed here.
        sample_count: usize,
        /// Target variance of those rows, when defined.
        variance: Option<f64>,
    },
    /// Internal node testing `binned_value(feature_index) <= threshold_bin`
    /// (rows passing the test go left).
    BinarySplit {
        feature_index: usize,
        threshold_bin: u16,
        /// Arena index of the left child.
        left_child: usize,
        /// Arena index of the right child.
        right_child: usize,
        /// Training rows routed through this node.
        sample_count: usize,
        /// Node impurity under the training criterion, before splitting.
        impurity: f64,
        /// Score of the chosen split.
        gain: f64,
        /// Target variance at this node, when defined.
        variance: Option<f64>,
    },
}
121
/// The single split shared by every node at one level of an oblivious tree.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ObliviousSplit {
    pub(crate) feature_index: usize,
    /// Rows with `binned_value <= threshold_bin` take the "true"/left branch.
    pub(crate) threshold_bin: u16,
    /// Rows routed through this level.
    pub(crate) sample_count: usize,
    pub(crate) impurity: f64,
    pub(crate) gain: f64,
}
130
/// A scored `(feature, threshold_bin)` split candidate.
/// NOTE(review): not referenced in this chunk; presumably used by the split
/// scoring helpers further down the file — confirm before removing.
#[derive(Debug, Clone)]
struct RegressionSplitCandidate {
    feature_index: usize,
    threshold_bin: u16,
    score: f64,
}
137
/// Per-leaf bookkeeping while growing an oblivious tree level by level.
///
/// `start..end` is this leaf's slice of the shared row-index buffer; `sum`
/// and `sum_sq` cache target statistics over those rows.
#[derive(Debug, Clone)]
struct ObliviousLeafState {
    start: usize,
    end: usize,
    value: f64,
    variance: Option<f64>,
    sum: f64,
    sum_sq: f64,
}

impl ObliviousLeafState {
    /// Number of rows covered by this leaf's `[start, end)` slice.
    fn len(&self) -> usize {
        let Self { start, end, .. } = self;
        end - start
    }
}
153
/// Scored split candidate for one level of an oblivious tree.
#[derive(Debug, Clone, Copy)]
struct ObliviousSplitCandidate {
    feature_index: usize,
    threshold_bin: u16,
    score: f64,
}
160
/// Best split found for one feature at a standard (non-oblivious) node;
/// candidates are compared by `score` to pick the node's split.
#[derive(Debug, Clone, Copy)]
struct BinarySplitChoice {
    feature_index: usize,
    threshold_bin: u16,
    score: f64,
}
167
168#[derive(Debug, Clone)]
169struct RegressionHistogramBin {
170    count: usize,
171    sum: f64,
172    sum_sq: f64,
173}
174
175impl HistogramBin for RegressionHistogramBin {
176    fn subtract(parent: &Self, child: &Self) -> Self {
177        Self {
178            count: parent.count - child.count,
179            sum: parent.sum - child.sum,
180            sum_sq: parent.sum_sq - child.sum_sq,
181        }
182    }
183
184    fn is_observed(&self) -> bool {
185        self.count > 0
186    }
187}
188
189type RegressionFeatureHistogram = FeatureHistogram<RegressionHistogramBin>;
190
/// Train a CART regression tree with the `Mean` criterion, sequential
/// execution, and default options.
pub fn train_cart_regressor(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_cart_regressor_with_criterion(train_set, Criterion::Mean)
}

/// Train a CART regression tree with a caller-chosen criterion; runs
/// sequentially with default options.
pub fn train_cart_regressor_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_cart_regressor_with_criterion_and_parallelism(
        train_set,
        criterion,
        Parallelism::sequential(),
    )
}

/// Crate-internal CART entry point that also chooses the parallelism mode;
/// uses default `RegressionTreeOptions`.
pub(crate) fn train_cart_regressor_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_cart_regressor_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        RegressionTreeOptions::default(),
    )
}

/// Fully parameterized CART entry point; all public CART wrappers funnel
/// into `train_regressor` through here.
pub(crate) fn train_cart_regressor_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: RegressionTreeOptions,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_regressor(
        train_set,
        RegressionTreeAlgorithm::Cart,
        criterion,
        parallelism,
        options,
    )
}
235
/// Train an oblivious regression tree with the `Mean` criterion, sequential
/// execution, and default options.
pub fn train_oblivious_regressor(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_oblivious_regressor_with_criterion(train_set, Criterion::Mean)
}

/// Train an oblivious regression tree with a caller-chosen criterion; runs
/// sequentially with default options.
pub fn train_oblivious_regressor_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_oblivious_regressor_with_criterion_and_parallelism(
        train_set,
        criterion,
        Parallelism::sequential(),
    )
}

/// Crate-internal oblivious entry point that also chooses the parallelism
/// mode; uses default `RegressionTreeOptions`.
pub(crate) fn train_oblivious_regressor_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_oblivious_regressor_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        RegressionTreeOptions::default(),
    )
}

/// Fully parameterized oblivious entry point; all public oblivious wrappers
/// funnel into `train_regressor` through here.
pub(crate) fn train_oblivious_regressor_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: RegressionTreeOptions,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_regressor(
        train_set,
        RegressionTreeAlgorithm::Oblivious,
        criterion,
        parallelism,
        options,
    )
}
280
/// Train a randomized regression tree with the `Mean` criterion, sequential
/// execution, and default options.
pub fn train_randomized_regressor(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_randomized_regressor_with_criterion(train_set, Criterion::Mean)
}

/// Train a randomized regression tree with a caller-chosen criterion; runs
/// sequentially with default options.
pub fn train_randomized_regressor_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_randomized_regressor_with_criterion_and_parallelism(
        train_set,
        criterion,
        Parallelism::sequential(),
    )
}

/// Crate-internal randomized entry point that also chooses the parallelism
/// mode; uses default `RegressionTreeOptions`.
pub(crate) fn train_randomized_regressor_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_randomized_regressor_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        RegressionTreeOptions::default(),
    )
}

/// Fully parameterized randomized entry point; all public randomized
/// wrappers funnel into `train_regressor` through here.
pub(crate) fn train_randomized_regressor_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: RegressionTreeOptions,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_regressor(
        train_set,
        RegressionTreeAlgorithm::Randomized,
        criterion,
        parallelism,
        options,
    )
}
325
326fn train_regressor(
327    train_set: &dyn TableAccess,
328    algorithm: RegressionTreeAlgorithm,
329    criterion: Criterion,
330    parallelism: Parallelism,
331    options: RegressionTreeOptions,
332) -> Result<DecisionTreeRegressor, RegressionTreeError> {
333    if train_set.n_rows() == 0 {
334        return Err(RegressionTreeError::EmptyTarget);
335    }
336
337    let targets = finite_targets(train_set)?;
338    let structure = match algorithm {
339        RegressionTreeAlgorithm::Cart => {
340            let mut nodes = Vec::new();
341            let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
342            let context = BuildContext {
343                table: train_set,
344                targets: &targets,
345                criterion,
346                parallelism,
347                options,
348                algorithm,
349            };
350            // CART and randomized regression reuse a single mutable row-index
351            // buffer so child partitions are formed in place instead of by
352            // allocating fresh row vectors for every split.
353            let root = build_binary_node_in_place(&context, &mut nodes, &mut all_rows, 0);
354            RegressionTreeStructure::Standard { nodes, root }
355        }
356        RegressionTreeAlgorithm::Randomized => {
357            let mut nodes = Vec::new();
358            let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
359            let context = BuildContext {
360                table: train_set,
361                targets: &targets,
362                criterion,
363                parallelism,
364                options,
365                algorithm,
366            };
367            let root = build_binary_node_in_place(&context, &mut nodes, &mut all_rows, 0);
368            RegressionTreeStructure::Standard { nodes, root }
369        }
370        RegressionTreeAlgorithm::Oblivious => {
371            // Oblivious trees are built level by level because every node at a
372            // given depth must share the same split.
373            train_oblivious_structure(train_set, &targets, criterion, parallelism, options)
374        }
375    };
376
377    Ok(DecisionTreeRegressor {
378        algorithm,
379        criterion,
380        structure,
381        options,
382        num_features: train_set.n_features(),
383        feature_preprocessing: capture_feature_preprocessing(train_set),
384        training_canaries: train_set.canaries(),
385    })
386}
387
impl DecisionTreeRegressor {
    /// Which learner family produced this tree.
    pub fn algorithm(&self) -> RegressionTreeAlgorithm {
        self.algorithm
    }

    /// Split criterion used during training.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }

    /// Predict one numeric value per row from a preprocessed table.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        (0..table.n_rows())
            .map(|row_idx| self.predict_row(table, row_idx))
            .collect()
    }

    /// Route one row through the tree and return the value of the leaf it
    /// reaches.
    fn predict_row(&self, table: &dyn TableAccess, row_idx: usize) -> f64 {
        match &self.structure {
            RegressionTreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;

                // Walk from the root, going left while the row's bin is at or
                // below the split threshold, until a leaf is reached.
                loop {
                    match &nodes[node_index] {
                        RegressionNode::Leaf { value, .. } => return *value,
                        RegressionNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            left_child,
                            right_child,
                            ..
                        } => {
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            RegressionTreeStructure::Oblivious {
                splits,
                leaf_values,
                ..
            } => {
                // Each level contributes one bit (1 = right); earlier levels
                // end up in higher bits, i.e. the leaf index is MSB-first.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                leaf_values[leaf_index]
            }
        }
    }

    /// Number of features in the training table.
    pub(crate) fn num_features(&self) -> usize {
        self.num_features
    }

    /// Borrow the trained tree topology.
    pub(crate) fn structure(&self) -> &RegressionTreeStructure {
        &self.structure
    }

    /// Per-feature preprocessing captured at training time.
    pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// Summarize the training configuration for export alongside the model.
    pub(crate) fn training_metadata(&self) -> TrainingMetadata {
        TrainingMetadata {
            algorithm: "dt".to_string(),
            task: "regression".to_string(),
            tree_type: match self.algorithm {
                RegressionTreeAlgorithm::Cart => "cart".to_string(),
                RegressionTreeAlgorithm::Randomized => "randomized".to_string(),
                RegressionTreeAlgorithm::Oblivious => "oblivious".to_string(),
            },
            criterion: criterion_name(self.criterion).to_string(),
            canaries: self.training_canaries,
            compute_oob: false,
            max_depth: Some(self.options.max_depth),
            min_samples_split: Some(self.options.min_samples_split),
            min_samples_leaf: Some(self.options.min_samples_leaf),
            // Ensemble- and boosting-only fields are left unset for a single
            // tree.
            n_trees: None,
            max_features: self.options.max_features,
            seed: None,
            oob_score: None,
            class_labels: None,
            learning_rate: None,
            bootstrap: None,
            top_gradient_fraction: None,
            other_gradient_fraction: None,
        }
    }

    /// Convert the trained structure into the serializable IR tree form.
    pub(crate) fn to_ir_tree(&self) -> TreeDefinition {
        match &self.structure {
            RegressionTreeStructure::Standard { nodes, root } => {
                // Depths are not stored on the nodes, so they are recomputed
                // by walking the tree from the root.
                let depths = standard_node_depths(nodes, *root);
                TreeDefinition::NodeTree {
                    tree_id: 0,
                    weight: 1.0,
                    root_node_id: *root,
                    nodes: nodes
                        .iter()
                        .enumerate()
                        .map(|(node_id, node)| match node {
                            RegressionNode::Leaf {
                                value,
                                sample_count,
                                variance,
                            } => NodeTreeNode::Leaf {
                                node_id,
                                depth: depths[node_id],
                                leaf: LeafPayload::RegressionValue { value: *value },
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: None,
                                    gain: None,
                                    class_counts: None,
                                    variance: *variance,
                                },
                            },
                            RegressionNode::BinarySplit {
                                feature_index,
                                threshold_bin,
                                left_child,
                                right_child,
                                sample_count,
                                impurity,
                                gain,
                                variance,
                            } => NodeTreeNode::BinaryBranch {
                                node_id,
                                depth: depths[node_id],
                                split: binary_split_ir(
                                    *feature_index,
                                    *threshold_bin,
                                    &self.feature_preprocessing,
                                ),
                                children: BinaryChildren {
                                    left: *left_child,
                                    right: *right_child,
                                },
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: None,
                                    variance: *variance,
                                },
                            },
                        })
                        .collect(),
                }
            }
            RegressionTreeStructure::Oblivious {
                splits,
                leaf_values,
                leaf_sample_counts,
                leaf_variances,
            } => TreeDefinition::ObliviousLevels {
                tree_id: 0,
                weight: 1.0,
                depth: splits.len(),
                levels: splits
                    .iter()
                    .enumerate()
                    .map(|(level, split)| ObliviousLevel {
                        level,
                        split: oblivious_split_ir(
                            split.feature_index,
                            split.threshold_bin,
                            &self.feature_preprocessing,
                        ),
                        stats: NodeStats {
                            sample_count: split.sample_count,
                            impurity: Some(split.impurity),
                            gain: Some(split.gain),
                            class_counts: None,
                            variance: None,
                        },
                    })
                    .collect(),
                // Must match `predict_row`'s fold, which shifts earlier
                // levels into higher bits.
                leaf_indexing: LeafIndexing {
                    bit_order: "msb_first".to_string(),
                    index_formula: "sum(bit[level] << (depth - 1 - level))".to_string(),
                },
                leaves: leaf_values
                    .iter()
                    .enumerate()
                    .map(|(leaf_index, value)| IndexedLeaf {
                        leaf_index,
                        leaf: LeafPayload::RegressionValue { value: *value },
                        stats: NodeStats {
                            sample_count: leaf_sample_counts[leaf_index],
                            impurity: None,
                            gain: None,
                            class_counts: None,
                            variance: leaf_variances[leaf_index],
                        },
                    })
                    .collect(),
            },
        }
    }

    /// Reassemble a regressor from already-decoded constituent parts
    /// (used when loading a tree back from IR).
    pub(crate) fn from_ir_parts(
        algorithm: RegressionTreeAlgorithm,
        criterion: Criterion,
        structure: RegressionTreeStructure,
        options: RegressionTreeOptions,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
        training_canaries: usize,
    ) -> Self {
        Self {
            algorithm,
            criterion,
            structure,
            options,
            num_features,
            feature_preprocessing,
            training_canaries,
        }
    }
}
618
619fn standard_node_depths(nodes: &[RegressionNode], root: usize) -> Vec<usize> {
620    let mut depths = vec![0; nodes.len()];
621    populate_depths(nodes, root, 0, &mut depths);
622    depths
623}
624
625fn populate_depths(nodes: &[RegressionNode], node_id: usize, depth: usize, depths: &mut [usize]) {
626    depths[node_id] = depth;
627    match &nodes[node_id] {
628        RegressionNode::Leaf { .. } => {}
629        RegressionNode::BinarySplit {
630            left_child,
631            right_child,
632            ..
633        } => {
634            populate_depths(nodes, *left_child, depth + 1, depths);
635            populate_depths(nodes, *right_child, depth + 1, depths);
636        }
637    }
638}
639
/// Translate an internal `(feature, threshold_bin)` split into the IR's
/// `BinarySplit`: binary-preprocessed features become a boolean test, all
/// others a numeric bin threshold.
fn binary_split_ir(
    feature_index: usize,
    threshold_bin: u16,
    preprocessing: &[FeaturePreprocessing],
) -> BinarySplit {
    match preprocessing.get(feature_index) {
        Some(FeaturePreprocessing::Binary) => BinarySplit::BooleanTest {
            feature_index,
            feature_name: feature_name(feature_index),
            false_child_semantics: "left".to_string(),
            true_child_semantics: "right".to_string(),
        },
        // Missing preprocessing info (index out of range) falls back to the
        // numeric encoding.
        Some(FeaturePreprocessing::Numeric { .. }) | None => BinarySplit::NumericBinThreshold {
            feature_index,
            feature_name: feature_name(feature_index),
            operator: "<=".to_string(),
            threshold_bin,
            threshold_upper_bound: threshold_upper_bound(
                preprocessing,
                feature_index,
                threshold_bin,
            ),
            comparison_dtype: "uint16".to_string(),
        },
    }
}
666
/// Translate one oblivious-level split into the IR's `ObliviousSplit`.
///
/// Bit assignments mirror `predict_row`, where the right branch contributes
/// bit 1: a boolean feature being true sends the row right (bit 1), whereas
/// a numeric `<=` test being true sends it left (bit 0).
fn oblivious_split_ir(
    feature_index: usize,
    threshold_bin: u16,
    preprocessing: &[FeaturePreprocessing],
) -> IrObliviousSplit {
    match preprocessing.get(feature_index) {
        Some(FeaturePreprocessing::Binary) => IrObliviousSplit::BooleanTest {
            feature_index,
            feature_name: feature_name(feature_index),
            bit_when_false: 0,
            bit_when_true: 1,
        },
        // Missing preprocessing info falls back to the numeric encoding.
        Some(FeaturePreprocessing::Numeric { .. }) | None => {
            IrObliviousSplit::NumericBinThreshold {
                feature_index,
                feature_name: feature_name(feature_index),
                operator: "<=".to_string(),
                threshold_bin,
                threshold_upper_bound: threshold_upper_bound(
                    preprocessing,
                    feature_index,
                    threshold_bin,
                ),
                comparison_dtype: "uint16".to_string(),
                // "true" here means the `<=` test passed, i.e. the left
                // branch, hence bit 0.
                bit_when_true: 0,
                bit_when_false: 1,
            }
        }
    }
}
697
/// Immutable inputs threaded through the recursive standard-tree builder,
/// bundled so the recursion passes one reference instead of six arguments.
struct BuildContext<'a> {
    table: &'a dyn TableAccess,
    targets: &'a [f64],
    criterion: Criterion,
    parallelism: Parallelism,
    options: RegressionTreeOptions,
    /// Distinguishes CART from randomized split selection in scoring.
    algorithm: RegressionTreeAlgorithm,
}
706
/// Build per-feature histograms over `rows`, accumulating the row count,
/// target sum, and target sum-of-squares in every bin (the statistics the
/// mean-criterion scorer needs).
fn build_regression_node_histograms(
    table: &dyn TableAccess,
    targets: &[f64],
    rows: &[usize],
) -> Vec<RegressionFeatureHistogram> {
    build_feature_histograms(
        table,
        rows,
        // Zero-initialized bin payload.
        |_| RegressionHistogramBin {
            count: 0,
            sum: 0.0,
            sum_sq: 0.0,
        },
        // Fold one row's target into the bin it falls in.
        |_feature_index, payload, row_idx| {
            let value = targets[row_idx];
            payload.count += 1;
            payload.sum += value;
            payload.sum_sq += value * value;
        },
    )
}
728
/// Derive a sibling node's histograms as `parent - child` so only the
/// smaller child has to be rebuilt from its rows (see
/// `HistogramBin::subtract`).
fn subtract_regression_node_histograms(
    parent: &[RegressionFeatureHistogram],
    child: &[RegressionFeatureHistogram],
) -> Vec<RegressionFeatureHistogram> {
    subtract_feature_histograms(parent, child)
}
735
736fn finite_targets(train_set: &dyn TableAccess) -> Result<Vec<f64>, RegressionTreeError> {
737    (0..train_set.n_rows())
738        .map(|row_idx| {
739            let value = train_set.target_value(row_idx);
740            if value.is_finite() {
741                Ok(value)
742            } else {
743                Err(RegressionTreeError::InvalidTargetValue {
744                    row: row_idx,
745                    value,
746                })
747            }
748        })
749        .collect()
750}
751
/// Grow a subtree over `rows` without precomputed histograms; the node
/// builder lazily creates them when the criterion benefits.
fn build_binary_node_in_place(
    context: &BuildContext<'_>,
    nodes: &mut Vec<RegressionNode>,
    rows: &mut [usize],
    depth: usize,
) -> usize {
    build_binary_node_in_place_with_hist(context, nodes, rows, depth, None)
}
760
/// Recursively grow one subtree over `rows` and return the new node's index
/// in the `nodes` arena.
///
/// `histograms`, when present, carries this node's per-feature
/// count/sum/sum-of-squares histograms. They are only maintained for the
/// `Mean` criterion: one child's histograms are rebuilt from its rows and
/// the sibling's are derived by subtracting from the parent's, so the
/// expensive rebuild always happens on the smaller side.
fn build_binary_node_in_place_with_hist(
    context: &BuildContext<'_>,
    nodes: &mut Vec<RegressionNode>,
    rows: &mut [usize],
    depth: usize,
    histograms: Option<Vec<RegressionFeatureHistogram>>,
) -> usize {
    // Computed up front: used for the leaf if we stop here, and as the
    // variance annotation on the split node if we don't.
    let leaf_value = regression_value(rows, context.targets, context.criterion);
    let leaf_variance = variance(rows, context.targets);

    // Stopping conditions: no rows, depth budget spent, too few rows to
    // split, or a constant target that no split can improve.
    if rows.is_empty()
        || depth >= context.options.max_depth
        || rows.len() < context.options.min_samples_split
        || has_constant_target(rows, context.targets)
    {
        return push_leaf(nodes, leaf_value, rows.len(), leaf_variance);
    }

    // Histograms only accelerate the mean-criterion scoring path; other
    // criteria score splits directly from the rows.
    let histograms = if matches!(context.criterion, Criterion::Mean) {
        Some(histograms.unwrap_or_else(|| {
            build_regression_node_histograms(context.table, context.targets, rows)
        }))
    } else {
        None
    };
    let feature_indices = candidate_feature_indices(
        context.table.binned_feature_count(),
        context.options.max_features,
        node_seed(context.options.random_seed, depth, rows, 0xA11C_E5E1u64),
    );
    // Score every candidate feature and keep the best split; the parallel
    // and sequential arms run the same per-feature logic.
    let best_split = if context.parallelism.enabled() {
        feature_indices
            .into_par_iter()
            .filter_map(|feature_index| {
                if let Some(histograms) = histograms.as_ref() {
                    score_binary_split_choice_from_hist(
                        context,
                        &histograms[feature_index],
                        feature_index,
                        rows,
                    )
                } else {
                    score_binary_split_choice(context, feature_index, rows)
                }
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    } else {
        feature_indices
            .into_iter()
            .filter_map(|feature_index| {
                if let Some(histograms) = histograms.as_ref() {
                    score_binary_split_choice_from_hist(
                        context,
                        &histograms[feature_index],
                        feature_index,
                        rows,
                    )
                } else {
                    score_binary_split_choice(context, feature_index, rows)
                }
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    };

    match best_split {
        // Never commit to a split on a canary feature: become a leaf instead.
        Some(best_split)
            if context
                .table
                .is_canary_binned_feature(best_split.feature_index) =>
        {
            push_leaf(nodes, leaf_value, rows.len(), leaf_variance)
        }
        // Only split when the best candidate strictly improves the score.
        Some(best_split) if best_split.score > 0.0 => {
            let impurity = regression_loss(rows, context.targets, context.criterion);
            // Reorder `rows` in place so the left partition occupies the
            // prefix; children then recurse on disjoint sub-slices of the
            // same buffer.
            let left_count = partition_rows_for_binary_split(
                context.table,
                best_split.feature_index,
                best_split.threshold_bin,
                rows,
            );
            let (left_rows, right_rows) = rows.split_at_mut(left_count);
            let (left_child, right_child) = if let Some(histograms) = histograms {
                // Rebuild histograms for the smaller child; the larger
                // child's come from parent-minus-child subtraction.
                if left_rows.len() <= right_rows.len() {
                    let left_histograms =
                        build_regression_node_histograms(context.table, context.targets, left_rows);
                    let right_histograms =
                        subtract_regression_node_histograms(&histograms, &left_histograms);
                    (
                        build_binary_node_in_place_with_hist(
                            context,
                            nodes,
                            left_rows,
                            depth + 1,
                            Some(left_histograms),
                        ),
                        build_binary_node_in_place_with_hist(
                            context,
                            nodes,
                            right_rows,
                            depth + 1,
                            Some(right_histograms),
                        ),
                    )
                } else {
                    let right_histograms = build_regression_node_histograms(
                        context.table,
                        context.targets,
                        right_rows,
                    );
                    let left_histograms =
                        subtract_regression_node_histograms(&histograms, &right_histograms);
                    (
                        build_binary_node_in_place_with_hist(
                            context,
                            nodes,
                            left_rows,
                            depth + 1,
                            Some(left_histograms),
                        ),
                        build_binary_node_in_place_with_hist(
                            context,
                            nodes,
                            right_rows,
                            depth + 1,
                            Some(right_histograms),
                        ),
                    )
                }
            } else {
                (
                    build_binary_node_in_place(context, nodes, left_rows, depth + 1),
                    build_binary_node_in_place(context, nodes, right_rows, depth + 1),
                )
            };

            push_node(
                nodes,
                RegressionNode::BinarySplit {
                    feature_index: best_split.feature_index,
                    threshold_bin: best_split.threshold_bin,
                    left_child,
                    right_child,
                    sample_count: rows.len(),
                    impurity,
                    gain: best_split.score,
                    variance: leaf_variance,
                },
            )
        }
        // No candidate, or no candidate with positive gain: become a leaf.
        _ => push_leaf(nodes, leaf_value, rows.len(), leaf_variance),
    }
}
913
/// Trains an oblivious (depth-synchronous) regression tree: at each depth a
/// single shared `(feature, threshold_bin)` split is applied to every leaf
/// on the frontier at once.
///
/// `row_indices` is permuted in place so each leaf owns a contiguous
/// `[start, end)` range, and per-leaf sum / sum-of-squares statistics are
/// carried forward so leaf values and variances need no extra passes.
fn train_oblivious_structure(
    table: &dyn TableAccess,
    targets: &[f64],
    criterion: Criterion,
    parallelism: Parallelism,
    options: RegressionTreeOptions,
) -> RegressionTreeStructure {
    let mut row_indices: Vec<usize> = (0..table.n_rows()).collect();
    let (root_sum, root_sum_sq) = sum_stats(&row_indices, targets);
    // Depth 0 frontier: one leaf spanning every row.
    let mut leaves = vec![ObliviousLeafState {
        start: 0,
        end: row_indices.len(),
        value: regression_value_from_stats(&row_indices, targets, criterion, root_sum),
        variance: variance_from_stats(row_indices.len(), root_sum, root_sum_sq),
        sum: root_sum,
        sum_sq: root_sum_sq,
    }];
    let mut splits = Vec::new();

    for depth in 0..options.max_depth {
        // Stop growing once no leaf on the frontier is large enough to split.
        if leaves
            .iter()
            .all(|leaf| leaf.len() < options.min_samples_split)
        {
            break;
        }
        // Deterministic per-depth feature subsample (seed mixes the user seed,
        // the depth, and a fixed salt).
        let feature_indices = candidate_feature_indices(
            table.binned_feature_count(),
            options.max_features,
            node_seed(options.random_seed, depth, &[], 0x0B11_A10Cu64),
        );
        // The split is shared by the whole frontier, so each candidate
        // feature is scored against all leaves at once (optionally in
        // parallel via rayon).
        let best_split = if parallelism.enabled() {
            feature_indices
                .into_par_iter()
                .filter_map(|feature_index| {
                    score_oblivious_split(
                        table,
                        &row_indices,
                        targets,
                        feature_index,
                        &leaves,
                        criterion,
                        options.min_samples_leaf,
                    )
                })
                .max_by(|left, right| left.score.total_cmp(&right.score))
        } else {
            feature_indices
                .into_iter()
                .filter_map(|feature_index| {
                    score_oblivious_split(
                        table,
                        &row_indices,
                        targets,
                        feature_index,
                        &leaves,
                        criterion,
                        options.min_samples_leaf,
                    )
                })
                .max_by(|left, right| left.score.total_cmp(&right.score))
        };

        // Require strictly positive gain to keep growing.
        let Some(best_split) = best_split.filter(|candidate| candidate.score > 0.0) else {
            break;
        };
        // If the winner is a canary feature (presumably a noise-detection
        // guard — see `TableAccess`), stop growing instead of splitting on it.
        if table.is_canary_binned_feature(best_split.feature_index) {
            break;
        }

        // Partition every leaf's row range in place by the shared split.
        leaves = split_oblivious_leaves_in_place(
            table,
            &mut row_indices,
            targets,
            leaves,
            best_split.feature_index,
            best_split.threshold_bin,
            criterion,
        );
        // Record the level; `impurity` is the post-split loss summed over the
        // new frontier, and every level covers the full row count.
        splits.push(ObliviousSplit {
            feature_index: best_split.feature_index,
            threshold_bin: best_split.threshold_bin,
            sample_count: table.n_rows(),
            impurity: leaves
                .iter()
                .map(|leaf| leaf_regression_loss(leaf, &row_indices, targets, criterion))
                .sum(),
            gain: best_split.score,
        });
    }

    // Flatten the frontier into the structure's parallel leaf arrays.
    RegressionTreeStructure::Oblivious {
        splits,
        leaf_values: leaves.iter().map(|leaf| leaf.value).collect(),
        leaf_sample_counts: leaves.iter().map(ObliviousLeafState::len).collect(),
        leaf_variances: leaves.iter().map(|leaf| leaf.variance).collect(),
    }
}
1012
/// Scores the best binary split of `feature_index` over `rows`.
///
/// Dispatch, in priority order: binary features use the dedicated scorer;
/// the mean criterion uses O(rows + bins) histogram fast paths (randomized
/// or exhaustive); randomized trees fall back to a single random threshold;
/// everything else runs an exhaustive CART scan over the distinct observed
/// bin values.
fn score_split(
    table: &dyn TableAccess,
    targets: &[f64],
    feature_index: usize,
    rows: &[usize],
    criterion: Criterion,
    min_samples_leaf: usize,
    algorithm: RegressionTreeAlgorithm,
) -> Option<RegressionSplitCandidate> {
    // Binary features admit exactly one split; no threshold search needed.
    if table.is_binary_binned_feature(feature_index) {
        return score_binary_split(
            table,
            targets,
            feature_index,
            rows,
            criterion,
            min_samples_leaf,
        );
    }
    // Mean criterion: try the histogram fast paths first. A `None` from a
    // fast path means binning was unusable and we fall through to the
    // generic scans below.
    if matches!(criterion, Criterion::Mean) {
        if matches!(algorithm, RegressionTreeAlgorithm::Randomized) {
            if let Some(candidate) = score_randomized_split_mean_fast(
                table,
                targets,
                feature_index,
                rows,
                min_samples_leaf,
            ) {
                return Some(candidate);
            }
        } else if let Some(candidate) =
            score_numeric_split_mean_fast(table, targets, feature_index, rows, min_samples_leaf)
        {
            return Some(candidate);
        }
    }
    // Randomized (Extra-Trees style) fallback: one random threshold, any
    // criterion.
    if matches!(algorithm, RegressionTreeAlgorithm::Randomized) {
        return score_randomized_split(
            table,
            targets,
            feature_index,
            rows,
            criterion,
            min_samples_leaf,
        );
    }
    let parent_loss = regression_loss(rows, targets, criterion);

    // Exhaustive CART fallback: evaluate every distinct bin value as a
    // threshold (the BTreeSet dedupes and sorts), re-partitioning per
    // threshold.
    rows.iter()
        .map(|row_idx| table.binned_value(feature_index, *row_idx))
        .collect::<BTreeSet<_>>()
        .into_iter()
        .filter_map(|threshold_bin| {
            let (left_rows, right_rows): (Vec<usize>, Vec<usize>) = rows
                .iter()
                .copied()
                .partition(|row_idx| table.binned_value(feature_index, *row_idx) <= threshold_bin);

            // Reject splits that leave either child under-populated.
            if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
                return None;
            }

            // Gain = parent loss minus summed child losses.
            let score = parent_loss
                - (regression_loss(&left_rows, targets, criterion)
                    + regression_loss(&right_rows, targets, criterion));

            Some(RegressionSplitCandidate {
                feature_index,
                threshold_bin,
                score,
            })
        })
        .max_by(|left, right| left.score.total_cmp(&right.score))
}
1087
1088fn score_randomized_split(
1089    table: &dyn TableAccess,
1090    targets: &[f64],
1091    feature_index: usize,
1092    rows: &[usize],
1093    criterion: Criterion,
1094    min_samples_leaf: usize,
1095) -> Option<RegressionSplitCandidate> {
1096    let candidate_thresholds = rows
1097        .iter()
1098        .map(|row_idx| table.binned_value(feature_index, *row_idx))
1099        .collect::<BTreeSet<_>>()
1100        .into_iter()
1101        .collect::<Vec<_>>();
1102    let threshold_bin =
1103        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xA11CE551u64)?;
1104
1105    let (left_rows, right_rows): (Vec<usize>, Vec<usize>) = rows
1106        .iter()
1107        .copied()
1108        .partition(|row_idx| table.binned_value(feature_index, *row_idx) <= threshold_bin);
1109
1110    if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
1111        return None;
1112    }
1113
1114    let parent_loss = regression_loss(rows, targets, criterion);
1115    let score = parent_loss
1116        - (regression_loss(&left_rows, targets, criterion)
1117            + regression_loss(&right_rows, targets, criterion));
1118
1119    Some(RegressionSplitCandidate {
1120        feature_index,
1121        threshold_bin,
1122        score,
1123    })
1124}
1125
/// Scores one shared split of `feature_index` across the whole oblivious
/// frontier.
///
/// A threshold's score is the per-leaf gain summed over all leaves; leaves
/// whose children would violate `min_samples_leaf` contribute nothing to
/// the score (a winning split will still partition them later).
fn score_oblivious_split(
    table: &dyn TableAccess,
    row_indices: &[usize],
    targets: &[f64],
    feature_index: usize,
    leaves: &[ObliviousLeafState],
    criterion: Criterion,
    min_samples_leaf: usize,
) -> Option<ObliviousSplitCandidate> {
    // Binary features: mean-criterion fast path first, then the generic
    // binary scorer.
    if table.is_binary_binned_feature(feature_index) {
        if matches!(criterion, Criterion::Mean)
            && let Some(candidate) = score_binary_oblivious_split_mean_fast(
                table,
                row_indices,
                targets,
                feature_index,
                leaves,
                min_samples_leaf,
            )
        {
            return Some(candidate);
        }
        return score_binary_oblivious_split(
            table,
            row_indices,
            targets,
            feature_index,
            leaves,
            criterion,
            min_samples_leaf,
        );
    }
    // Numeric features: histogram fast path for the mean criterion.
    if matches!(criterion, Criterion::Mean)
        && let Some(candidate) = score_numeric_oblivious_split_mean_fast(
            table,
            row_indices,
            targets,
            feature_index,
            leaves,
            min_samples_leaf,
        )
    {
        return Some(candidate);
    }
    // Slow fallback: candidate thresholds are the distinct bins observed in
    // any leaf (BTreeSet dedupes and sorts).
    let candidate_thresholds = leaves
        .iter()
        .flat_map(|leaf| {
            row_indices[leaf.start..leaf.end]
                .iter()
                .map(|row_idx| table.binned_value(feature_index, *row_idx))
        })
        .collect::<BTreeSet<_>>();

    candidate_thresholds
        .into_iter()
        .filter_map(|threshold_bin| {
            // Aggregate the gain this threshold achieves over every leaf.
            let score = leaves.iter().fold(0.0, |score, leaf| {
                let leaf_rows = &row_indices[leaf.start..leaf.end];
                let (left_rows, right_rows): (Vec<usize>, Vec<usize>) =
                    leaf_rows.iter().copied().partition(|row_idx| {
                        table.binned_value(feature_index, *row_idx) <= threshold_bin
                    });

                // Under-populated children: this leaf contributes no gain.
                if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
                    return score;
                }

                score + regression_loss(leaf_rows, targets, criterion)
                    - (regression_loss(&left_rows, targets, criterion)
                        + regression_loss(&right_rows, targets, criterion))
            });

            // Only a strictly positive aggregate gain qualifies.
            (score > 0.0).then_some(ObliviousSplitCandidate {
                feature_index,
                threshold_bin,
                score,
            })
        })
        .max_by(|left, right| left.score.total_cmp(&right.score))
}
1206
/// Applies the winning `(feature_index, threshold_bin)` split to every
/// frontier leaf, partitioning each leaf's contiguous slice of
/// `row_indices` in place and emitting two children per leaf.
///
/// Empty children inherit the parent's prediction as a fallback so the
/// doubled frontier always carries a usable value at every position.
fn split_oblivious_leaves_in_place(
    table: &dyn TableAccess,
    row_indices: &mut [usize],
    targets: &[f64],
    leaves: Vec<ObliviousLeafState>,
    feature_index: usize,
    threshold_bin: u16,
    criterion: Criterion,
) -> Vec<ObliviousLeafState> {
    // An oblivious split doubles the frontier.
    let mut next_leaves = Vec::with_capacity(leaves.len() * 2);
    for leaf in leaves {
        let fallback_value = leaf.value;
        // Reorder this leaf's rows in place; the returned count is taken as
        // the size of the left child.
        let left_count = partition_rows_for_binary_split(
            table,
            feature_index,
            threshold_bin,
            &mut row_indices[leaf.start..leaf.end],
        );
        let mid = leaf.start + left_count;
        let left_rows = &row_indices[leaf.start..mid];
        let right_rows = &row_indices[mid..leaf.end];
        let (left_sum, left_sum_sq) = sum_stats(left_rows, targets);
        let (right_sum, right_sum_sq) = sum_stats(right_rows, targets);
        next_leaves.push(ObliviousLeafState {
            start: leaf.start,
            end: mid,
            // Empty side: keep the parent's prediction.
            value: if left_rows.is_empty() {
                fallback_value
            } else {
                regression_value_from_stats(left_rows, targets, criterion, left_sum)
            },
            variance: variance_from_stats(left_rows.len(), left_sum, left_sum_sq),
            sum: left_sum,
            sum_sq: left_sum_sq,
        });
        next_leaves.push(ObliviousLeafState {
            start: mid,
            end: leaf.end,
            value: if right_rows.is_empty() {
                fallback_value
            } else {
                regression_value_from_stats(right_rows, targets, criterion, right_sum)
            },
            variance: variance_from_stats(right_rows.len(), right_sum, right_sum_sq),
            sum: right_sum,
            sum_sq: right_sum_sq,
        });
    }
    next_leaves
}
1257
1258fn variance(rows: &[usize], targets: &[f64]) -> Option<f64> {
1259    let (sum, sum_sq) = sum_stats(rows, targets);
1260    variance_from_stats(rows.len(), sum, sum_sq)
1261}
1262
/// Arithmetic mean of `targets[row]` over `rows`; `0.0` for an empty slice.
fn mean(rows: &[usize], targets: &[f64]) -> f64 {
    if rows.is_empty() {
        return 0.0;
    }
    let total: f64 = rows.iter().map(|&row_idx| targets[row_idx]).sum();
    total / rows.len() as f64
}
1270
/// Median of `targets[row]` over `rows`; `0.0` for an empty slice.
///
/// Uses `total_cmp` so NaN targets sort to a consistent position instead of
/// panicking; even-length selections average the two middle values.
fn median(rows: &[usize], targets: &[f64]) -> f64 {
    let mut values: Vec<f64> = rows.iter().map(|&row_idx| targets[row_idx]).collect();
    if values.is_empty() {
        return 0.0;
    }
    values.sort_by(f64::total_cmp);

    let mid = values.len() / 2;
    match values.len() % 2 {
        0 => (values[mid - 1] + values[mid]) / 2.0,
        _ => values[mid],
    }
}
1285
/// Total squared deviation (SSE) of the selected targets from their mean.
fn sum_squared_error(rows: &[usize], targets: &[f64]) -> f64 {
    // Centre on the mean, using the same empty-slice convention as `mean`
    // (an empty selection yields 0.0 and therefore an SSE of 0.0).
    let centre = if rows.is_empty() {
        0.0
    } else {
        rows.iter().map(|&row_idx| targets[row_idx]).sum::<f64>() / rows.len() as f64
    };
    rows.iter().fold(0.0, |acc, &row_idx| {
        let diff = targets[row_idx] - centre;
        acc + diff * diff
    })
}
1295
1296fn sum_absolute_error(rows: &[usize], targets: &[f64]) -> f64 {
1297    let median = median(rows, targets);
1298    rows.iter()
1299        .map(|row_idx| (targets[*row_idx] - median).abs())
1300        .sum()
1301}
1302
1303fn regression_value(rows: &[usize], targets: &[f64], criterion: Criterion) -> f64 {
1304    let (sum, _sum_sq) = sum_stats(rows, targets);
1305    regression_value_from_stats(rows, targets, criterion, sum)
1306}
1307
1308fn regression_value_from_stats(
1309    rows: &[usize],
1310    targets: &[f64],
1311    criterion: Criterion,
1312    sum: f64,
1313) -> f64 {
1314    match criterion {
1315        Criterion::Mean => {
1316            if rows.is_empty() {
1317                0.0
1318            } else {
1319                sum / rows.len() as f64
1320            }
1321        }
1322        Criterion::Median => median(rows, targets),
1323        _ => unreachable!("regression criterion only supports mean or median"),
1324    }
1325}
1326
1327fn regression_loss(rows: &[usize], targets: &[f64], criterion: Criterion) -> f64 {
1328    match criterion {
1329        Criterion::Mean => sum_squared_error(rows, targets),
1330        Criterion::Median => sum_absolute_error(rows, targets),
1331        _ => unreachable!("regression criterion only supports mean or median"),
1332    }
1333}
1334
1335fn score_binary_split(
1336    table: &dyn TableAccess,
1337    targets: &[f64],
1338    feature_index: usize,
1339    rows: &[usize],
1340    criterion: Criterion,
1341    min_samples_leaf: usize,
1342) -> Option<RegressionSplitCandidate> {
1343    let (left_rows, right_rows): (Vec<usize>, Vec<usize>) =
1344        rows.iter().copied().partition(|row_idx| {
1345            !table
1346                .binned_boolean_value(feature_index, *row_idx)
1347                .expect("binary feature must expose boolean values")
1348        });
1349
1350    if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
1351        return None;
1352    }
1353
1354    let parent_loss = regression_loss(rows, targets, criterion);
1355    let score = parent_loss
1356        - (regression_loss(&left_rows, targets, criterion)
1357            + regression_loss(&right_rows, targets, criterion));
1358
1359    Some(RegressionSplitCandidate {
1360        feature_index,
1361        threshold_bin: 0,
1362        score,
1363    })
1364}
1365
1366fn score_binary_split_choice(
1367    context: &BuildContext<'_>,
1368    feature_index: usize,
1369    rows: &[usize],
1370) -> Option<BinarySplitChoice> {
1371    if matches!(context.criterion, Criterion::Mean) {
1372        if context.table.is_binary_binned_feature(feature_index) {
1373            return score_binary_split_choice_mean(context, feature_index, rows);
1374        }
1375        return match context.algorithm {
1376            RegressionTreeAlgorithm::Cart => {
1377                score_numeric_split_choice_mean_fast(context, feature_index, rows)
1378            }
1379            RegressionTreeAlgorithm::Randomized => {
1380                score_randomized_split_choice_mean_fast(context, feature_index, rows)
1381            }
1382            RegressionTreeAlgorithm::Oblivious => None,
1383        };
1384    }
1385
1386    score_split(
1387        context.table,
1388        context.targets,
1389        feature_index,
1390        rows,
1391        context.criterion,
1392        context.options.min_samples_leaf,
1393        context.algorithm,
1394    )
1395    .map(|candidate| BinarySplitChoice {
1396        feature_index: candidate.feature_index,
1397        threshold_bin: candidate.threshold_bin,
1398        score: candidate.score,
1399    })
1400}
1401
/// Histogram-backed variant of `score_binary_split_choice`.
///
/// Only the mean criterion can be answered straight from the cached
/// count/sum/sum-of-squares histogram; any other criterion falls back to
/// the row-scanning entry point.
fn score_binary_split_choice_from_hist(
    context: &BuildContext<'_>,
    histogram: &RegressionFeatureHistogram,
    feature_index: usize,
    rows: &[usize],
) -> Option<BinarySplitChoice> {
    if !matches!(context.criterion, Criterion::Mean) {
        return score_binary_split_choice(context, feature_index, rows);
    }

    match histogram {
        // Binary feature: both sides' stats are already materialised.
        RegressionFeatureHistogram::Binary {
            false_bin,
            true_bin,
        } => score_binary_split_choice_mean_from_stats(
            context,
            feature_index,
            false_bin.count,
            false_bin.sum,
            false_bin.sum_sq,
            true_bin.count,
            true_bin.sum,
            true_bin.sum_sq,
        ),
        // Numeric feature: prefix-scan (CART) or random-threshold
        // (Randomized) over the per-bin stats; Oblivious yields no
        // candidate through this path.
        RegressionFeatureHistogram::Numeric {
            bins,
            observed_bins,
        } => match context.algorithm {
            RegressionTreeAlgorithm::Cart => score_numeric_split_choice_mean_from_hist(
                context,
                feature_index,
                rows.len(),
                bins,
                observed_bins,
            ),
            RegressionTreeAlgorithm::Randomized => score_randomized_split_choice_mean_from_hist(
                context,
                feature_index,
                rows,
                bins,
                observed_bins,
            ),
            RegressionTreeAlgorithm::Oblivious => None,
        },
    }
}
1448
1449#[allow(clippy::too_many_arguments)]
1450fn score_binary_split_choice_mean_from_stats(
1451    context: &BuildContext<'_>,
1452    feature_index: usize,
1453    left_count: usize,
1454    left_sum: f64,
1455    left_sum_sq: f64,
1456    right_count: usize,
1457    right_sum: f64,
1458    right_sum_sq: f64,
1459) -> Option<BinarySplitChoice> {
1460    if left_count < context.options.min_samples_leaf
1461        || right_count < context.options.min_samples_leaf
1462    {
1463        return None;
1464    }
1465    let total_count = left_count + right_count;
1466    let total_sum = left_sum + right_sum;
1467    let total_sum_sq = left_sum_sq + right_sum_sq;
1468    let parent_loss = total_sum_sq - (total_sum * total_sum) / total_count as f64;
1469    let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
1470    let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
1471    Some(BinarySplitChoice {
1472        feature_index,
1473        threshold_bin: 0,
1474        score: parent_loss - (left_loss + right_loss),
1475    })
1476}
1477
/// Exhaustive CART scorer for the mean criterion, driven by a prebuilt
/// per-bin count/sum/sum-of-squares histogram.
///
/// A single prefix scan over `observed_bins` (assumed ascending — matches
/// how `score_numeric_split_mean_fast` builds its observed list; confirm
/// for other producers) evaluates every threshold, deriving each side's
/// SSE from the identity `sse = sum_sq - sum^2 / count`.
fn score_numeric_split_choice_mean_from_hist(
    context: &BuildContext<'_>,
    feature_index: usize,
    row_count: usize,
    bins: &[RegressionHistogramBin],
    observed_bins: &[usize],
) -> Option<BinarySplitChoice> {
    // A split needs at least two distinct bin values.
    if observed_bins.len() <= 1 {
        return None;
    }
    let total_sum = bins.iter().map(|bin| bin.sum).sum::<f64>();
    let total_sum_sq = bins.iter().map(|bin| bin.sum_sq).sum::<f64>();
    let parent_loss = total_sum_sq - (total_sum * total_sum) / row_count as f64;
    let mut left_count = 0usize;
    let mut left_sum = 0.0;
    let mut left_sum_sq = 0.0;
    let mut best_threshold = None;
    let mut best_score = f64::NEG_INFINITY;

    for &bin in observed_bins {
        // Move this bin's mass to the left side; the threshold under test
        // is this bin (rows with bin <= threshold go left).
        left_count += bins[bin].count;
        left_sum += bins[bin].sum;
        left_sum_sq += bins[bin].sum_sq;
        let right_count = row_count - left_count;
        if left_count < context.options.min_samples_leaf
            || right_count < context.options.min_samples_leaf
        {
            continue;
        }
        // Right-side stats follow by subtraction from the totals.
        let right_sum = total_sum - left_sum;
        let right_sum_sq = total_sum_sq - left_sum_sq;
        let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
        let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
        let score = parent_loss - (left_loss + right_loss);
        if score > best_score {
            best_score = score;
            best_threshold = Some(bin as u16);
        }
    }

    best_threshold.map(|threshold_bin| BinarySplitChoice {
        feature_index,
        threshold_bin,
        score: best_score,
    })
}
1524
/// Randomized (Extra-Trees style) mean-criterion scorer over a prebuilt
/// histogram: draws one deterministic pseudo-random threshold from the
/// observed bins and evaluates only that split.
fn score_randomized_split_choice_mean_from_hist(
    context: &BuildContext<'_>,
    feature_index: usize,
    rows: &[usize],
    bins: &[RegressionHistogramBin],
    observed_bins: &[usize],
) -> Option<BinarySplitChoice> {
    let candidate_thresholds = observed_bins
        .iter()
        .copied()
        .map(|bin| bin as u16)
        .collect::<Vec<_>>();
    // Same salt as the row-scanning randomized scorers, so the draw is
    // reproducible across code paths.
    let threshold_bin =
        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xA11CE551u64)?;
    let total_sum = bins.iter().map(|bin| bin.sum).sum::<f64>();
    let total_sum_sq = bins.iter().map(|bin| bin.sum_sq).sum::<f64>();
    // Accumulate every bin at or below the threshold into the left side.
    let mut left_count = 0usize;
    let mut left_sum = 0.0;
    let mut left_sum_sq = 0.0;
    for bin in 0..=threshold_bin as usize {
        if bin >= bins.len() {
            break;
        }
        left_count += bins[bin].count;
        left_sum += bins[bin].sum;
        left_sum_sq += bins[bin].sum_sq;
    }
    let right_count = rows.len() - left_count;
    if left_count < context.options.min_samples_leaf
        || right_count < context.options.min_samples_leaf
    {
        return None;
    }
    // SSE identity: sse = sum_sq - sum^2 / count, per side and parent.
    let parent_loss = total_sum_sq - (total_sum * total_sum) / rows.len() as f64;
    let right_sum = total_sum - left_sum;
    let right_sum_sq = total_sum_sq - left_sum_sq;
    let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
    let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
    Some(BinarySplitChoice {
        feature_index,
        threshold_bin,
        score: parent_loss - (left_loss + right_loss),
    })
}
1569
1570fn score_binary_split_choice_mean(
1571    context: &BuildContext<'_>,
1572    feature_index: usize,
1573    rows: &[usize],
1574) -> Option<BinarySplitChoice> {
1575    let mut left_count = 0usize;
1576    let mut left_sum = 0.0;
1577    let mut left_sum_sq = 0.0;
1578    let mut total_sum = 0.0;
1579    let mut total_sum_sq = 0.0;
1580
1581    for row_idx in rows {
1582        let target = context.targets[*row_idx];
1583        total_sum += target;
1584        total_sum_sq += target * target;
1585        if !context
1586            .table
1587            .binned_boolean_value(feature_index, *row_idx)
1588            .expect("binary feature must expose boolean values")
1589        {
1590            left_count += 1;
1591            left_sum += target;
1592            left_sum_sq += target * target;
1593        }
1594    }
1595
1596    let right_count = rows.len() - left_count;
1597    if left_count < context.options.min_samples_leaf
1598        || right_count < context.options.min_samples_leaf
1599    {
1600        return None;
1601    }
1602
1603    let parent_loss = total_sum_sq - (total_sum * total_sum) / rows.len() as f64;
1604    let right_sum = total_sum - left_sum;
1605    let right_sum_sq = total_sum_sq - left_sum_sq;
1606    let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
1607    let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
1608
1609    Some(BinarySplitChoice {
1610        feature_index,
1611        threshold_bin: 0,
1612        score: parent_loss - (left_loss + right_loss),
1613    })
1614}
1615
/// Exhaustive CART scorer for the mean criterion in O(rows + bins).
///
/// One pass bins every row into count/sum/sum-of-squares accumulators,
/// then a prefix scan over the observed bins evaluates every threshold
/// using the identity `sse = sum_sq - sum^2 / count`.
///
/// Returns `None` when binning is unusable (no bin capacity, or a row's
/// bin exceeds the cap) so the caller can fall back to the slow scan.
fn score_numeric_split_mean_fast(
    table: &dyn TableAccess,
    targets: &[f64],
    feature_index: usize,
    rows: &[usize],
    min_samples_leaf: usize,
) -> Option<RegressionSplitCandidate> {
    let bin_cap = table.numeric_bin_cap();
    if bin_cap == 0 {
        return None;
    }

    // Per-bin target statistics, filled in a single pass over the rows.
    let mut bin_count = vec![0usize; bin_cap];
    let mut bin_sum = vec![0.0; bin_cap];
    let mut bin_sum_sq = vec![0.0; bin_cap];
    let mut observed_bins = vec![false; bin_cap];
    let mut total_sum = 0.0;
    let mut total_sum_sq = 0.0;

    for row_idx in rows {
        let bin = table.binned_value(feature_index, *row_idx) as usize;
        if bin >= bin_cap {
            return None;
        }
        let target = targets[*row_idx];
        bin_count[bin] += 1;
        bin_sum[bin] += target;
        bin_sum_sq[bin] += target * target;
        observed_bins[bin] = true;
        total_sum += target;
        total_sum_sq += target * target;
    }

    // Collapse the seen-flags into an ascending list of occupied bins.
    let observed_bins: Vec<usize> = observed_bins
        .into_iter()
        .enumerate()
        .filter_map(|(bin, seen)| seen.then_some(bin))
        .collect();
    // A split needs at least two distinct bin values.
    if observed_bins.len() <= 1 {
        return None;
    }

    let parent_loss = total_sum_sq - (total_sum * total_sum) / rows.len() as f64;
    let mut left_count = 0usize;
    let mut left_sum = 0.0;
    let mut left_sum_sq = 0.0;
    let mut best_threshold = None;
    let mut best_score = f64::NEG_INFINITY;

    // Prefix scan: each observed bin in turn becomes the threshold, moving
    // its mass to the left side.
    for &bin in &observed_bins {
        left_count += bin_count[bin];
        left_sum += bin_sum[bin];
        left_sum_sq += bin_sum_sq[bin];
        let right_count = rows.len() - left_count;

        if left_count < min_samples_leaf || right_count < min_samples_leaf {
            continue;
        }

        // Right-side stats follow by subtraction from the totals.
        let right_sum = total_sum - left_sum;
        let right_sum_sq = total_sum_sq - left_sum_sq;
        let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
        let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
        let score = parent_loss - (left_loss + right_loss);
        if score > best_score {
            best_score = score;
            best_threshold = Some(bin as u16);
        }
    }

    let threshold_bin = best_threshold?;
    Some(RegressionSplitCandidate {
        feature_index,
        threshold_bin,
        score: best_score,
    })
}
1693
1694fn score_numeric_split_choice_mean_fast(
1695    context: &BuildContext<'_>,
1696    feature_index: usize,
1697    rows: &[usize],
1698) -> Option<BinarySplitChoice> {
1699    score_numeric_split_mean_fast(
1700        context.table,
1701        context.targets,
1702        feature_index,
1703        rows,
1704        context.options.min_samples_leaf,
1705    )
1706    .map(|candidate| BinarySplitChoice {
1707        feature_index: candidate.feature_index,
1708        threshold_bin: candidate.threshold_bin,
1709        score: candidate.score,
1710    })
1711}
1712
/// Fast-path randomized (Extra-Trees style) scorer for the mean criterion.
///
/// First pass collects the distinct bins so one random threshold can be
/// drawn; second pass accumulates left/total target statistics, from which
/// both child SSEs follow without per-threshold rescans.
///
/// Returns `None` when binning is unusable (no bin capacity, or a row's
/// bin exceeds the cap) so the caller can fall back to the slow scan.
fn score_randomized_split_mean_fast(
    table: &dyn TableAccess,
    targets: &[f64],
    feature_index: usize,
    rows: &[usize],
    min_samples_leaf: usize,
) -> Option<RegressionSplitCandidate> {
    let bin_cap = table.numeric_bin_cap();
    if bin_cap == 0 {
        return None;
    }
    // Pass 1: mark which bins occur among the rows.
    let mut observed_bins = vec![false; bin_cap];
    for row_idx in rows {
        let bin = table.binned_value(feature_index, *row_idx) as usize;
        if bin >= bin_cap {
            return None;
        }
        observed_bins[bin] = true;
    }
    // Ascending list of occupied bins; same salt as the other randomized
    // scorers so the draw is reproducible across code paths.
    let candidate_thresholds = observed_bins
        .into_iter()
        .enumerate()
        .filter_map(|(bin, seen)| seen.then_some(bin as u16))
        .collect::<Vec<_>>();
    let threshold_bin =
        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xA11CE551u64)?;

    // Pass 2: left-side and total statistics for the drawn threshold.
    let mut left_count = 0usize;
    let mut left_sum = 0.0;
    let mut left_sum_sq = 0.0;
    let mut total_sum = 0.0;
    let mut total_sum_sq = 0.0;
    for row_idx in rows {
        let target = targets[*row_idx];
        total_sum += target;
        total_sum_sq += target * target;
        if table.binned_value(feature_index, *row_idx) <= threshold_bin {
            left_count += 1;
            left_sum += target;
            left_sum_sq += target * target;
        }
    }
    let right_count = rows.len() - left_count;
    if left_count < min_samples_leaf || right_count < min_samples_leaf {
        return None;
    }
    // SSE identity: sse = sum_sq - sum^2 / count, per side and parent.
    let parent_loss = total_sum_sq - (total_sum * total_sum) / rows.len() as f64;
    let right_sum = total_sum - left_sum;
    let right_sum_sq = total_sum_sq - left_sum_sq;
    let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
    let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
    let score = parent_loss - (left_loss + right_loss);

    Some(RegressionSplitCandidate {
        feature_index,
        threshold_bin,
        score,
    })
}
1772
1773fn score_randomized_split_choice_mean_fast(
1774    context: &BuildContext<'_>,
1775    feature_index: usize,
1776    rows: &[usize],
1777) -> Option<BinarySplitChoice> {
1778    score_randomized_split_mean_fast(
1779        context.table,
1780        context.targets,
1781        feature_index,
1782        rows,
1783        context.options.min_samples_leaf,
1784    )
1785    .map(|candidate| BinarySplitChoice {
1786        feature_index: candidate.feature_index,
1787        threshold_bin: candidate.threshold_bin,
1788        score: candidate.score,
1789    })
1790}
1791
/// Scores numeric oblivious-split candidates for `feature_index` using the
/// cached mean-criterion aggregates (`sum` / `sum_sq`) stored on each leaf.
///
/// For every leaf the rows are histogrammed by bin, then a prefix scan over
/// the occupied bins evaluates the SSE gain of splitting at each bin. Gains
/// are summed across leaves because an oblivious level applies one shared
/// split to every leaf; the bin with the best strictly positive aggregate
/// gain wins.
///
/// Returns `None` when the table exposes no numeric bins, a row maps to an
/// out-of-range bin, or no leaf contains at least two distinct bins.
fn score_numeric_oblivious_split_mean_fast(
    table: &dyn TableAccess,
    row_indices: &[usize],
    targets: &[f64],
    feature_index: usize,
    leaves: &[ObliviousLeafState],
    min_samples_leaf: usize,
) -> Option<ObliviousSplitCandidate> {
    let bin_cap = table.numeric_bin_cap();
    if bin_cap == 0 {
        return None;
    }
    // Aggregate gain per candidate threshold bin, accumulated across leaves.
    let mut threshold_scores = vec![0.0; bin_cap];
    let mut observed_any = false;

    for leaf in leaves {
        // Per-leaf histograms: row count, target sum, and target sum of
        // squares for each bin of this feature.
        let mut bin_count = vec![0usize; bin_cap];
        let mut bin_sum = vec![0.0; bin_cap];
        let mut bin_sum_sq = vec![0.0; bin_cap];
        let mut observed_bins = vec![false; bin_cap];

        for row_idx in &row_indices[leaf.start..leaf.end] {
            let bin = table.binned_value(feature_index, *row_idx) as usize;
            if bin >= bin_cap {
                // Defensive: the table reported a bin outside its own cap.
                return None;
            }
            let target = targets[*row_idx];
            bin_count[bin] += 1;
            bin_sum[bin] += target;
            bin_sum_sq[bin] += target * target;
            observed_bins[bin] = true;
        }

        // Collapse the seen-flags into an ascending list of occupied bins.
        let observed_bins: Vec<usize> = observed_bins
            .into_iter()
            .enumerate()
            .filter_map(|(bin, seen)| seen.then_some(bin))
            .collect();
        if observed_bins.len() <= 1 {
            // A single occupied bin cannot be split; this leaf contributes
            // nothing to any threshold's score.
            continue;
        }
        observed_any = true;

        // Parent SSE from the cached leaf aggregates: sum_sq - sum^2 / n.
        let parent_loss = leaf.sum_sq - (leaf.sum * leaf.sum) / leaf.len() as f64;
        let mut left_count = 0usize;
        let mut left_sum = 0.0;
        let mut left_sum_sq = 0.0;

        // Prefix scan: after folding bin `b` in, the running totals describe
        // the left partition of the candidate split at threshold `b`.
        for &bin in &observed_bins {
            left_count += bin_count[bin];
            left_sum += bin_sum[bin];
            left_sum_sq += bin_sum_sq[bin];
            let right_count = leaf.len() - left_count;

            // NOTE(review): at the last occupied bin `right_count` is 0; that
            // case is only rejected because `min_samples_leaf` is >= 1 in
            // practice. With `min_samples_leaf == 0` the `right_loss` below
            // would divide by zero — confirm upstream options validation.
            if left_count < min_samples_leaf || right_count < min_samples_leaf {
                continue;
            }

            let right_sum = leaf.sum - left_sum;
            let right_sum_sq = leaf.sum_sq - left_sum_sq;
            let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
            let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
            threshold_scores[bin] += parent_loss - (left_loss + right_loss);
        }
    }

    if !observed_any {
        return None;
    }

    // Keep only strictly positive aggregate gains and return the best bin.
    threshold_scores
        .into_iter()
        .enumerate()
        .filter(|(_, score)| *score > 0.0)
        .max_by(|left, right| left.1.total_cmp(&right.1))
        .map(|(threshold_bin, score)| ObliviousSplitCandidate {
            feature_index,
            threshold_bin: threshold_bin as u16,
            score,
        })
}
1873
1874fn score_binary_oblivious_split(
1875    table: &dyn TableAccess,
1876    row_indices: &[usize],
1877    targets: &[f64],
1878    feature_index: usize,
1879    leaves: &[ObliviousLeafState],
1880    criterion: Criterion,
1881    min_samples_leaf: usize,
1882) -> Option<ObliviousSplitCandidate> {
1883    let score = leaves.iter().fold(0.0, |score, leaf| {
1884        let leaf_rows = &row_indices[leaf.start..leaf.end];
1885        let (left_rows, right_rows): (Vec<usize>, Vec<usize>) =
1886            leaf_rows.iter().copied().partition(|row_idx| {
1887                !table
1888                    .binned_boolean_value(feature_index, *row_idx)
1889                    .expect("binary feature must expose boolean values")
1890            });
1891
1892        if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
1893            return score;
1894        }
1895
1896        score + regression_loss(leaf_rows, targets, criterion)
1897            - (regression_loss(&left_rows, targets, criterion)
1898                + regression_loss(&right_rows, targets, criterion))
1899    });
1900
1901    (score > 0.0).then_some(ObliviousSplitCandidate {
1902        feature_index,
1903        threshold_bin: 0,
1904        score,
1905    })
1906}
1907
1908fn score_binary_oblivious_split_mean_fast(
1909    table: &dyn TableAccess,
1910    row_indices: &[usize],
1911    targets: &[f64],
1912    feature_index: usize,
1913    leaves: &[ObliviousLeafState],
1914    min_samples_leaf: usize,
1915) -> Option<ObliviousSplitCandidate> {
1916    let mut score = 0.0;
1917    let mut found_valid = false;
1918
1919    for leaf in leaves {
1920        let mut left_count = 0usize;
1921        let mut left_sum = 0.0;
1922        let mut left_sum_sq = 0.0;
1923
1924        for row_idx in &row_indices[leaf.start..leaf.end] {
1925            if !table
1926                .binned_boolean_value(feature_index, *row_idx)
1927                .expect("binary feature must expose boolean values")
1928            {
1929                let target = targets[*row_idx];
1930                left_count += 1;
1931                left_sum += target;
1932                left_sum_sq += target * target;
1933            }
1934        }
1935
1936        let right_count = leaf.len() - left_count;
1937        if left_count < min_samples_leaf || right_count < min_samples_leaf {
1938            continue;
1939        }
1940
1941        found_valid = true;
1942        let parent_loss = leaf.sum_sq - (leaf.sum * leaf.sum) / leaf.len() as f64;
1943        let right_sum = leaf.sum - left_sum;
1944        let right_sum_sq = leaf.sum_sq - left_sum_sq;
1945        let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
1946        let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
1947        score += parent_loss - (left_loss + right_loss);
1948    }
1949
1950    (found_valid && score > 0.0).then_some(ObliviousSplitCandidate {
1951        feature_index,
1952        threshold_bin: 0,
1953        score,
1954    })
1955}
1956
/// Accumulates the target sum and sum of squares over the selected rows.
///
/// Returns `(sum, sum_of_squares)`; both are `0.0` for an empty row set.
fn sum_stats(rows: &[usize], targets: &[f64]) -> (f64, f64) {
    let mut sum = 0.0;
    let mut sum_sq = 0.0;
    for &row_idx in rows {
        let value = targets[row_idx];
        sum += value;
        sum_sq += value * value;
    }
    (sum, sum_sq)
}
1963
/// Computes the population variance from aggregate statistics via
/// `E[x^2] - E[x]^2`.
///
/// Returns `None` for `count == 0`, where the variance is undefined.
fn variance_from_stats(count: usize, sum: f64, sum_sq: f64) -> Option<f64> {
    if count == 0 {
        return None;
    }
    let n = count as f64;
    let mean = sum / n;
    Some(sum_sq / n - mean * mean)
}
1971
1972fn leaf_regression_loss(
1973    leaf: &ObliviousLeafState,
1974    row_indices: &[usize],
1975    targets: &[f64],
1976    criterion: Criterion,
1977) -> f64 {
1978    match criterion {
1979        Criterion::Mean => leaf.sum_sq - (leaf.sum * leaf.sum) / leaf.len() as f64,
1980        Criterion::Median => {
1981            regression_loss(&row_indices[leaf.start..leaf.end], targets, criterion)
1982        }
1983        _ => unreachable!("regression criterion only supports mean or median"),
1984    }
1985}
1986
/// Returns `true` when every selected row shares the first row's target
/// value (vacuously true for an empty selection).
///
/// The scan deliberately includes the first row itself so the result matches
/// a plain equality sweep even for non-finite targets.
fn has_constant_target(rows: &[usize], targets: &[f64]) -> bool {
    let Some(&first_row) = rows.first() else {
        return true;
    };
    let reference = targets[first_row];
    rows.iter().all(|&row_idx| targets[row_idx] == reference)
}
1993
1994fn push_leaf(
1995    nodes: &mut Vec<RegressionNode>,
1996    value: f64,
1997    sample_count: usize,
1998    variance: Option<f64>,
1999) -> usize {
2000    push_node(
2001        nodes,
2002        RegressionNode::Leaf {
2003            value,
2004            sample_count,
2005            variance,
2006        },
2007    )
2008}
2009
2010fn push_node(nodes: &mut Vec<RegressionNode>, node: RegressionNode) -> usize {
2011    nodes.push(node);
2012    nodes.len() - 1
2013}
2014
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{FeaturePreprocessing, Model, NumericBinBoundary};
    use forestfire_data::{DenseTable, NumericBins};

    // Single numeric feature x with target y = x^2; 128 fixed bins give the
    // tree enough resolution to isolate every row and fit targets exactly.
    fn quadratic_table() -> DenseTable {
        DenseTable::with_options(
            vec![
                vec![0.0],
                vec![1.0],
                vec![2.0],
                vec![3.0],
                vec![4.0],
                vec![5.0],
                vec![6.0],
                vec![7.0],
            ],
            vec![0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0],
            0,
            NumericBins::Fixed(128),
        )
        .unwrap()
    }

    // Builds a table whose target equals the binned values of the injected
    // canary feature (presumably appended after the real features — the
    // probe's `n_features()` is used as its column index). The canary is then
    // the single best split, which the learners must refuse to use.
    fn canary_target_table() -> DenseTable {
        let x: Vec<Vec<f64>> = (0..8).map(|value| vec![value as f64]).collect();
        let probe =
            DenseTable::with_options(x.clone(), vec![0.0; 8], 1, NumericBins::Auto).unwrap();
        let canary_index = probe.n_features();
        let y = (0..probe.n_rows())
            .map(|row_idx| probe.binned_value(canary_index, row_idx) as f64)
            .collect();

        DenseTable::with_options(x, y, 1, NumericBins::Auto).unwrap()
    }

    // Three informative features with a roughly additive target, so many
    // distinct trees are learnable and seed choice can change the result.
    fn randomized_permutation_table() -> DenseTable {
        DenseTable::with_options(
            vec![
                vec![0.0, 0.0, 0.0],
                vec![0.0, 0.0, 1.0],
                vec![0.0, 1.0, 0.0],
                vec![0.0, 1.0, 1.0],
                vec![1.0, 0.0, 0.0],
                vec![1.0, 0.0, 1.0],
                vec![1.0, 1.0, 0.0],
                vec![1.0, 1.0, 1.0],
                vec![0.0, 0.0, 2.0],
                vec![0.0, 1.0, 2.0],
                vec![1.0, 0.0, 2.0],
                vec![1.0, 1.0, 2.0],
            ],
            vec![0.0, 1.0, 2.5, 3.5, 4.0, 5.0, 6.5, 7.5, 2.0, 4.5, 6.0, 8.5],
            0,
            NumericBins::Fixed(8),
        )
        .unwrap()
    }

    // CART should memorize the small quadratic table exactly.
    #[test]
    fn cart_regressor_fits_basic_numeric_pattern() {
        let table = quadratic_table();
        let model = train_cart_regressor(&table).unwrap();
        let preds = model.predict_table(&table);

        assert_eq!(model.algorithm(), RegressionTreeAlgorithm::Cart);
        assert_eq!(model.criterion(), Criterion::Mean);
        assert_eq!(preds, table_targets(&table));
    }

    // Randomized splits need not fit exactly, but must beat the constant
    // mean predictor in sum of squared errors.
    #[test]
    fn randomized_regressor_fits_basic_numeric_pattern() {
        let table = quadratic_table();
        let model = train_randomized_regressor(&table).unwrap();
        let preds = model.predict_table(&table);
        let targets = table_targets(&table);
        let baseline_mean = targets.iter().sum::<f64>() / targets.len() as f64;
        let baseline_sse = targets
            .iter()
            .map(|target| {
                let diff = target - baseline_mean;
                diff * diff
            })
            .sum::<f64>();
        let model_sse = preds
            .iter()
            .zip(targets.iter())
            .map(|(pred, target)| {
                let diff = pred - target;
                diff * diff
            })
            .sum::<f64>();

        assert_eq!(model.algorithm(), RegressionTreeAlgorithm::Randomized);
        assert_eq!(model.criterion(), Criterion::Mean);
        assert!(model_sse < baseline_sse);
    }

    // Same seed must reproduce the identical serialized model; sweeping 32
    // seeds must produce at least 4 distinct models (seed actually matters).
    #[test]
    fn randomized_regressor_is_repeatable_for_fixed_seed_and_varies_across_seeds() {
        let table = randomized_permutation_table();
        let make_options = |random_seed| RegressionTreeOptions {
            max_depth: 4,
            max_features: Some(2),
            random_seed,
            ..RegressionTreeOptions::default()
        };

        let base_model = train_randomized_regressor_with_criterion_parallelism_and_options(
            &table,
            Criterion::Mean,
            Parallelism::sequential(),
            make_options(91),
        )
        .unwrap();
        let repeated_model = train_randomized_regressor_with_criterion_parallelism_and_options(
            &table,
            Criterion::Mean,
            Parallelism::sequential(),
            make_options(91),
        )
        .unwrap();
        let unique_serializations = (0..32u64)
            .map(|seed| {
                Model::DecisionTreeRegressor(
                    train_randomized_regressor_with_criterion_parallelism_and_options(
                        &table,
                        Criterion::Mean,
                        Parallelism::sequential(),
                        make_options(seed),
                    )
                    .unwrap(),
                )
                .serialize()
                .unwrap()
            })
            .collect::<std::collections::BTreeSet<_>>();

        assert_eq!(
            Model::DecisionTreeRegressor(base_model.clone())
                .serialize()
                .unwrap(),
            Model::DecisionTreeRegressor(repeated_model)
                .serialize()
                .unwrap()
        );
        assert!(unique_serializations.len() >= 4);
    }

    // The oblivious learner should also memorize the quadratic table.
    #[test]
    fn oblivious_regressor_fits_basic_numeric_pattern() {
        let table = quadratic_table();
        let model = train_oblivious_regressor(&table).unwrap();
        let preds = model.predict_table(&table);

        assert_eq!(model.algorithm(), RegressionTreeAlgorithm::Oblivious);
        assert_eq!(model.criterion(), Criterion::Mean);
        assert_eq!(preds, table_targets(&table));
    }

    // With a constant feature the tree is a single leaf, so the leaf value
    // directly exposes the criterion: mean of {0, 0, 100} vs median.
    #[test]
    fn regressors_can_choose_between_mean_and_median() {
        let table = DenseTable::with_canaries(
            vec![vec![0.0], vec![0.0], vec![0.0]],
            vec![0.0, 0.0, 100.0],
            0,
        )
        .unwrap();

        let mean_model = train_cart_regressor_with_criterion(&table, Criterion::Mean).unwrap();
        let median_model = train_cart_regressor_with_criterion(&table, Criterion::Median).unwrap();

        assert_eq!(mean_model.criterion(), Criterion::Mean);
        assert_eq!(median_model.criterion(), Criterion::Median);
        assert_eq!(
            mean_model.predict_table(&table),
            vec![100.0 / 3.0, 100.0 / 3.0, 100.0 / 3.0]
        );
        assert_eq!(median_model.predict_table(&table), vec![0.0, 0.0, 0.0]);
    }

    // A NaN target must be rejected up front with the offending row index.
    #[test]
    fn rejects_non_finite_targets() {
        let table = DenseTable::new(vec![vec![0.0], vec![1.0]], vec![0.0, f64::NAN]).unwrap();

        let err = train_cart_regressor(&table).unwrap_err();
        assert!(matches!(
            err,
            RegressionTreeError::InvalidTargetValue { row: 1, value } if value.is_nan()
        ));
    }

    // When the canary feature would be the winning split, CART must stop
    // growing: predictions collapse to a constant and do not fit the target.
    #[test]
    fn stops_cart_regressor_growth_when_a_canary_wins() {
        let table = canary_target_table();
        let model = train_cart_regressor(&table).unwrap();
        let preds = model.predict_table(&table);

        assert!(preds.iter().all(|pred| *pred == preds[0]));
        assert_ne!(preds, table_targets(&table));
    }

    // Same canary guard for the oblivious learner.
    #[test]
    fn stops_oblivious_regressor_growth_when_a_canary_wins() {
        let table = canary_target_table();
        let model = train_oblivious_regressor(&table).unwrap();
        let preds = model.predict_table(&table);

        assert!(preds.iter().all(|pred| *pred == preds[0]));
        assert_ne!(preds, table_targets(&table));
    }

    // Hand-assembles one model per tree type and checks the serialized JSON
    // carries the expected `tree_type` and `task` tags.
    #[test]
    fn manually_built_regressor_models_serialize_for_each_tree_type() {
        let preprocessing = vec![
            FeaturePreprocessing::Binary,
            FeaturePreprocessing::Numeric {
                bin_boundaries: vec![
                    NumericBinBoundary {
                        bin: 0,
                        upper_bound: 1.0,
                    },
                    NumericBinBoundary {
                        bin: 127,
                        upper_bound: 10.0,
                    },
                ],
            },
        ];
        let options = RegressionTreeOptions::default();

        let cart = Model::DecisionTreeRegressor(DecisionTreeRegressor {
            algorithm: RegressionTreeAlgorithm::Cart,
            criterion: Criterion::Mean,
            structure: RegressionTreeStructure::Standard {
                nodes: vec![
                    RegressionNode::Leaf {
                        value: -1.0,
                        sample_count: 2,
                        variance: Some(0.25),
                    },
                    RegressionNode::Leaf {
                        value: 2.5,
                        sample_count: 3,
                        variance: Some(1.0),
                    },
                    RegressionNode::BinarySplit {
                        feature_index: 0,
                        threshold_bin: 0,
                        left_child: 0,
                        right_child: 1,
                        sample_count: 5,
                        impurity: 3.5,
                        gain: 1.25,
                        variance: Some(0.7),
                    },
                ],
                root: 2,
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing.clone(),
            training_canaries: 0,
        });
        let randomized = Model::DecisionTreeRegressor(DecisionTreeRegressor {
            algorithm: RegressionTreeAlgorithm::Randomized,
            criterion: Criterion::Median,
            structure: RegressionTreeStructure::Standard {
                nodes: vec![
                    RegressionNode::Leaf {
                        value: -1.0,
                        sample_count: 2,
                        variance: Some(0.25),
                    },
                    RegressionNode::Leaf {
                        value: 2.5,
                        sample_count: 3,
                        variance: Some(1.0),
                    },
                    RegressionNode::BinarySplit {
                        feature_index: 0,
                        threshold_bin: 0,
                        left_child: 0,
                        right_child: 1,
                        sample_count: 5,
                        impurity: 3.5,
                        gain: 0.8,
                        variance: Some(0.7),
                    },
                ],
                root: 2,
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing.clone(),
            training_canaries: 0,
        });
        let oblivious = Model::DecisionTreeRegressor(DecisionTreeRegressor {
            algorithm: RegressionTreeAlgorithm::Oblivious,
            criterion: Criterion::Median,
            structure: RegressionTreeStructure::Oblivious {
                splits: vec![ObliviousSplit {
                    feature_index: 1,
                    threshold_bin: 127,
                    sample_count: 4,
                    impurity: 2.0,
                    gain: 0.5,
                }],
                leaf_values: vec![0.0, 10.0],
                leaf_sample_counts: vec![2, 2],
                leaf_variances: vec![Some(0.0), Some(1.0)],
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing,
            training_canaries: 0,
        });

        for (tree_type, model) in [
            ("cart", cart),
            ("randomized", randomized),
            ("oblivious", oblivious),
        ] {
            let json = model.serialize().unwrap();
            assert!(json.contains(&format!("\"tree_type\":\"{tree_type}\"")));
            assert!(json.contains("\"task\":\"regression\""));
        }
    }

    // Collects the per-row target column for direct comparison with preds.
    fn table_targets(table: &dyn TableAccess) -> Vec<f64> {
        (0..table.n_rows())
            .map(|row_idx| table.target_value(row_idx))
            .collect()
    }
}