// forestfire_core/tree/regressor.rs
1//! First-order regression tree learners.
2//!
3//! This module is the regression analogue of `classifier`, but the split logic
4//! differs in one important way: regression quality depends on leaf value
5//! statistics rather than class counts. The implementation therefore leans on
6//! cached count/sum/sum-of-squares histograms in the mean-criterion hot path.
7
8use crate::ir::{
9    BinaryChildren, BinarySplit, IndexedLeaf, LeafIndexing, LeafPayload, NodeStats, NodeTreeNode,
10    ObliviousLevel, ObliviousSplit as IrObliviousSplit, TrainingMetadata, TreeDefinition,
11    criterion_name, feature_name, threshold_upper_bound,
12};
13use crate::tree::shared::{
14    FeatureHistogram, HistogramBin, MissingBranchDirection, build_feature_histograms,
15    candidate_feature_indices, choose_random_threshold, node_seed, partition_rows_for_binary_split,
16    subtract_feature_histograms,
17};
18use crate::{
19    Criterion, FeaturePreprocessing, MissingValueStrategy, Parallelism,
20    capture_feature_preprocessing,
21};
22use forestfire_data::TableAccess;
23use rayon::prelude::*;
24use std::collections::BTreeSet;
25use std::error::Error;
26use std::fmt::{Display, Formatter};
27
/// Which learner family a trained regression tree came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RegressionTreeAlgorithm {
    /// Classic greedy top-down binary splits.
    Cart,
    /// Shares the binary node builder with CART; split selection is
    /// randomized (see `choose_random_threshold` in `shared`).
    Randomized,
    /// Level-wise tree where every node at a given depth shares one split.
    Oblivious,
}
34
/// Shared training controls for regression tree learners.
#[derive(Debug, Clone)]
pub struct RegressionTreeOptions {
    /// Maximum tree depth; nodes at this depth are forced into leaves.
    pub max_depth: usize,
    /// Minimum rows a node must hold before a split is attempted.
    pub min_samples_split: usize,
    /// Minimum rows per leaf.
    pub min_samples_leaf: usize,
    /// Cap on candidate features considered per split (`None` = all).
    pub max_features: Option<usize>,
    /// Seed for per-node randomness (feature sampling, random thresholds).
    pub random_seed: u64,
    /// Per-feature missing-value policies; features without an entry fall
    /// back to `MissingValueStrategy::Heuristic`.
    pub missing_value_strategies: Vec<MissingValueStrategy>,
}
45
impl Default for RegressionTreeOptions {
    /// Defaults: depth-8 trees, no practical split/leaf constraints, all
    /// features considered, deterministic seed 0, heuristic missing handling.
    fn default() -> Self {
        Self {
            max_depth: 8,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: None,
            random_seed: 0,
            missing_value_strategies: Vec::new(),
        }
    }
}
58
59impl RegressionTreeOptions {
60    fn missing_value_strategy(&self, feature_index: usize) -> MissingValueStrategy {
61        self.missing_value_strategies
62            .get(feature_index)
63            .copied()
64            .unwrap_or(MissingValueStrategy::Heuristic)
65    }
66}
67
/// Errors surfaced while training a regression tree.
#[derive(Debug)]
pub enum RegressionTreeError {
    /// The training table contained zero rows.
    EmptyTarget,
    /// A regression target was NaN or infinite.
    InvalidTargetValue { row: usize, value: f64 },
}
73
74impl Display for RegressionTreeError {
75    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
76        match self {
77            RegressionTreeError::EmptyTarget => {
78                write!(f, "Cannot train on an empty target vector.")
79            }
80            RegressionTreeError::InvalidTargetValue { row, value } => write!(
81                f,
82                "Regression targets must be finite values. Found {} at row {}.",
83                value, row
84            ),
85        }
86    }
87}
88
// Marker impl: the `Debug` derive and `Display` impl satisfy `Error`'s contract.
impl Error for RegressionTreeError {}
90
/// Concrete trained regression tree.
#[derive(Debug, Clone)]
pub struct DecisionTreeRegressor {
    // Learner family that produced this tree.
    algorithm: RegressionTreeAlgorithm,
    // Split criterion used during training.
    criterion: Criterion,
    // Standard node arena or oblivious level/leaf tables.
    structure: RegressionTreeStructure,
    // Training options, retained for metadata export.
    options: RegressionTreeOptions,
    // Feature count of the training table.
    num_features: usize,
    // Per-feature preprocessing captured at training time; used when
    // exporting splits to IR.
    feature_preprocessing: Vec<FeaturePreprocessing>,
    // Number of canary features present during training.
    training_canaries: usize,
}
102
/// Trained tree topology: either a pointer-based node arena or the flat
/// level/leaf tables of an oblivious tree.
#[derive(Debug, Clone)]
pub(crate) enum RegressionTreeStructure {
    Standard {
        // Flat arena; children reference siblings by index.
        nodes: Vec<RegressionNode>,
        // Index of the root node inside `nodes`.
        root: usize,
    },
    Oblivious {
        // One shared split per level, ordered root-first.
        splits: Vec<ObliviousSplit>,
        // Leaf tables indexed by the MSB-first bit path (one entry per
        // reachable leaf, i.e. 2^depth).
        leaf_values: Vec<f64>,
        leaf_sample_counts: Vec<usize>,
        leaf_variances: Vec<Option<f64>>,
    },
}
116
/// A node in a standard (non-oblivious) regression tree.
#[derive(Debug, Clone)]
pub(crate) enum RegressionNode {
    Leaf {
        /// Predicted value for rows reaching this leaf.
        value: f64,
        sample_count: usize,
        variance: Option<f64>,
    },
    BinarySplit {
        feature_index: usize,
        /// Rows with binned value `<= threshold_bin` go left.
        threshold_bin: u16,
        /// Where rows with a missing feature value are routed; `Node` means
        /// they short-circuit to `missing_value` instead of descending.
        missing_direction: MissingBranchDirection,
        /// Fallback prediction used when `missing_direction == Node`.
        missing_value: f64,
        left_child: usize,
        right_child: usize,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        variance: Option<f64>,
    },
}
137
/// One level of an oblivious tree: the split shared by every node at that
/// depth, plus aggregate statistics for IR export.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ObliviousSplit {
    pub(crate) feature_index: usize,
    // Rows with binned value `> threshold_bin` take the right (bit = 1) branch.
    pub(crate) threshold_bin: u16,
    pub(crate) sample_count: usize,
    pub(crate) impurity: f64,
    pub(crate) gain: f64,
}
146
/// A scored split candidate for one feature.
///
/// NOTE(review): not referenced in the portion of the file shown here —
/// `BinarySplitChoice` plays this role in the binary builder; confirm this
/// type is still used elsewhere in the file.
#[derive(Debug, Clone)]
struct RegressionSplitCandidate {
    feature_index: usize,
    threshold_bin: u16,
    score: f64,
    missing_direction: MissingBranchDirection,
}
154
/// Working state for one leaf during level-wise oblivious training.
#[derive(Debug, Clone)]
struct ObliviousLeafState {
    // `[start, end)` range of this leaf's rows — presumably indices into a
    // shared row-index buffer, as in the binary builder; confirm in the
    // oblivious trainer below.
    start: usize,
    end: usize,
    // Current leaf prediction and variance.
    value: f64,
    variance: Option<f64>,
    // Cached target sum and sum-of-squares for the leaf's rows.
    sum: f64,
    sum_sq: f64,
}
164
impl ObliviousLeafState {
    /// Number of rows covered by this leaf's `[start, end)` range.
    fn len(&self) -> usize {
        self.end - self.start
    }
}
170
/// A scored level-wide split candidate for an oblivious tree.
#[derive(Debug, Clone, Copy)]
struct ObliviousSplitCandidate {
    feature_index: usize,
    threshold_bin: u16,
    score: f64,
}
177
/// Best split found for one feature in the binary (CART/randomized) builder,
/// including where missing values should be routed.
#[derive(Debug, Clone, Copy)]
struct BinarySplitChoice {
    feature_index: usize,
    threshold_bin: u16,
    score: f64,
    missing_direction: MissingBranchDirection,
}
185
/// Per-bin target statistics: row count plus target sum and sum-of-squares.
/// These three moments are enough to score mean-criterion splits without
/// revisiting rows.
#[derive(Debug, Clone)]
struct RegressionHistogramBin {
    count: usize,
    sum: f64,
    sum_sq: f64,
}
192
193impl HistogramBin for RegressionHistogramBin {
194    fn subtract(parent: &Self, child: &Self) -> Self {
195        Self {
196            count: parent.count - child.count,
197            sum: parent.sum - child.sum,
198            sum_sq: parent.sum_sq - child.sum_sq,
199        }
200    }
201
202    fn is_observed(&self) -> bool {
203        self.count > 0
204    }
205}
206
/// Per-feature histogram whose bins carry regression target statistics.
type RegressionFeatureHistogram = FeatureHistogram<RegressionHistogramBin>;
208
/// Train a CART regression tree with the `Mean` criterion, sequential
/// execution, and default options.
pub fn train_cart_regressor(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_cart_regressor_with_criterion(train_set, Criterion::Mean)
}
214
/// Train a CART regression tree with a caller-chosen criterion, sequential
/// execution, and default options.
pub fn train_cart_regressor_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_cart_regressor_with_criterion_and_parallelism(
        train_set,
        criterion,
        Parallelism::sequential(),
    )
}
225
/// Crate-internal CART entry point that also selects the parallelism mode;
/// uses default training options.
pub(crate) fn train_cart_regressor_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_cart_regressor_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        RegressionTreeOptions::default(),
    )
}
238
/// Fully-parameterized CART entry point; delegates to the shared
/// `train_regressor` dispatcher with `RegressionTreeAlgorithm::Cart`.
pub(crate) fn train_cart_regressor_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: RegressionTreeOptions,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_regressor(
        train_set,
        RegressionTreeAlgorithm::Cart,
        criterion,
        parallelism,
        options,
    )
}
253
/// Train an oblivious regression tree with the `Mean` criterion, sequential
/// execution, and default options.
pub fn train_oblivious_regressor(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_oblivious_regressor_with_criterion(train_set, Criterion::Mean)
}
259
/// Train an oblivious regression tree with a caller-chosen criterion,
/// sequential execution, and default options.
pub fn train_oblivious_regressor_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_oblivious_regressor_with_criterion_and_parallelism(
        train_set,
        criterion,
        Parallelism::sequential(),
    )
}
270
/// Crate-internal oblivious entry point that also selects the parallelism
/// mode; uses default training options.
pub(crate) fn train_oblivious_regressor_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_oblivious_regressor_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        RegressionTreeOptions::default(),
    )
}
283
/// Fully-parameterized oblivious entry point; delegates to the shared
/// `train_regressor` dispatcher with `RegressionTreeAlgorithm::Oblivious`.
pub(crate) fn train_oblivious_regressor_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: RegressionTreeOptions,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_regressor(
        train_set,
        RegressionTreeAlgorithm::Oblivious,
        criterion,
        parallelism,
        options,
    )
}
298
/// Train a randomized regression tree with the `Mean` criterion, sequential
/// execution, and default options.
pub fn train_randomized_regressor(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_randomized_regressor_with_criterion(train_set, Criterion::Mean)
}
304
/// Train a randomized regression tree with a caller-chosen criterion,
/// sequential execution, and default options.
pub fn train_randomized_regressor_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_randomized_regressor_with_criterion_and_parallelism(
        train_set,
        criterion,
        Parallelism::sequential(),
    )
}
315
/// Crate-internal randomized entry point that also selects the parallelism
/// mode; uses default training options.
pub(crate) fn train_randomized_regressor_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_randomized_regressor_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        RegressionTreeOptions::default(),
    )
}
328
/// Fully-parameterized randomized entry point; delegates to the shared
/// `train_regressor` dispatcher with `RegressionTreeAlgorithm::Randomized`.
pub(crate) fn train_randomized_regressor_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: RegressionTreeOptions,
) -> Result<DecisionTreeRegressor, RegressionTreeError> {
    train_regressor(
        train_set,
        RegressionTreeAlgorithm::Randomized,
        criterion,
        parallelism,
        options,
    )
}
343
344fn train_regressor(
345    train_set: &dyn TableAccess,
346    algorithm: RegressionTreeAlgorithm,
347    criterion: Criterion,
348    parallelism: Parallelism,
349    options: RegressionTreeOptions,
350) -> Result<DecisionTreeRegressor, RegressionTreeError> {
351    if train_set.n_rows() == 0 {
352        return Err(RegressionTreeError::EmptyTarget);
353    }
354
355    let targets = finite_targets(train_set)?;
356    let structure = match algorithm {
357        RegressionTreeAlgorithm::Cart => {
358            let mut nodes = Vec::new();
359            let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
360            let context = BuildContext {
361                table: train_set,
362                targets: &targets,
363                criterion,
364                parallelism,
365                options: options.clone(),
366                algorithm,
367            };
368            // CART and randomized regression reuse a single mutable row-index
369            // buffer so child partitions are formed in place instead of by
370            // allocating fresh row vectors for every split.
371            let root = build_binary_node_in_place(&context, &mut nodes, &mut all_rows, 0);
372            RegressionTreeStructure::Standard { nodes, root }
373        }
374        RegressionTreeAlgorithm::Randomized => {
375            let mut nodes = Vec::new();
376            let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
377            let context = BuildContext {
378                table: train_set,
379                targets: &targets,
380                criterion,
381                parallelism,
382                options: options.clone(),
383                algorithm,
384            };
385            let root = build_binary_node_in_place(&context, &mut nodes, &mut all_rows, 0);
386            RegressionTreeStructure::Standard { nodes, root }
387        }
388        RegressionTreeAlgorithm::Oblivious => {
389            // Oblivious trees are built level by level because every node at a
390            // given depth must share the same split.
391            train_oblivious_structure(train_set, &targets, criterion, parallelism, options.clone())
392        }
393    };
394
395    Ok(DecisionTreeRegressor {
396        algorithm,
397        criterion,
398        structure,
399        options,
400        num_features: train_set.n_features(),
401        feature_preprocessing: capture_feature_preprocessing(train_set),
402        training_canaries: train_set.canaries(),
403    })
404}
405
impl DecisionTreeRegressor {
    /// Which learner family produced this tree.
    pub fn algorithm(&self) -> RegressionTreeAlgorithm {
        self.algorithm
    }

    /// Split criterion used during training.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }

    /// Predict one numeric value per row from a preprocessed table.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        (0..table.n_rows())
            .map(|row_idx| self.predict_row(table, row_idx))
            .collect()
    }

    /// Route a single row through the tree and return its leaf value.
    fn predict_row(&self, table: &dyn TableAccess, row_idx: usize) -> f64 {
        match &self.structure {
            RegressionTreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;

                loop {
                    match &nodes[node_index] {
                        RegressionNode::Leaf { value, .. } => return *value,
                        RegressionNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            missing_direction,
                            missing_value,
                            left_child,
                            right_child,
                            ..
                        } => {
                            // Missing values follow the branch fixed at
                            // training time, or (for `Node`) short-circuit to
                            // this node's own fallback prediction.
                            if table.is_missing(*feature_index, row_idx) {
                                match missing_direction {
                                    MissingBranchDirection::Left => {
                                        node_index = *left_child;
                                    }
                                    MissingBranchDirection::Right => {
                                        node_index = *right_child;
                                    }
                                    MissingBranchDirection::Node => return *missing_value,
                                }
                                continue;
                            }
                            // `<=` goes left, matching the "<=" operator the
                            // IR exporter writes for this split.
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            RegressionTreeStructure::Oblivious {
                splits,
                leaf_values,
                ..
            } => {
                // Each level contributes one bit (1 = "went right"), MSB
                // first, so the fold accumulates the leaf index directly.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                leaf_values[leaf_index]
            }
        }
    }

    /// Feature count of the table this tree was trained on.
    pub(crate) fn num_features(&self) -> usize {
        self.num_features
    }

    /// Borrow the underlying tree topology.
    pub(crate) fn structure(&self) -> &RegressionTreeStructure {
        &self.structure
    }

    /// Per-feature preprocessing captured at training time.
    pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// Describe how this tree was trained, for IR export.
    ///
    /// Fields that only apply to ensembles or boosting (`n_trees`,
    /// `learning_rate`, `bootstrap`, gradient fractions, ...) are left unset
    /// for a single tree.
    pub(crate) fn training_metadata(&self) -> TrainingMetadata {
        TrainingMetadata {
            algorithm: "dt".to_string(),
            task: "regression".to_string(),
            tree_type: match self.algorithm {
                RegressionTreeAlgorithm::Cart => "cart".to_string(),
                RegressionTreeAlgorithm::Randomized => "randomized".to_string(),
                RegressionTreeAlgorithm::Oblivious => "oblivious".to_string(),
            },
            criterion: criterion_name(self.criterion).to_string(),
            canaries: self.training_canaries,
            compute_oob: false,
            max_depth: Some(self.options.max_depth),
            min_samples_split: Some(self.options.min_samples_split),
            min_samples_leaf: Some(self.options.min_samples_leaf),
            n_trees: None,
            max_features: self.options.max_features,
            seed: None,
            oob_score: None,
            class_labels: None,
            learning_rate: None,
            bootstrap: None,
            top_gradient_fraction: None,
            other_gradient_fraction: None,
        }
    }

    /// Export the tree to its IR form: a node arena for standard trees, or
    /// per-level splits plus an indexed leaf table for oblivious trees.
    pub(crate) fn to_ir_tree(&self) -> TreeDefinition {
        match &self.structure {
            RegressionTreeStructure::Standard { nodes, root } => {
                // Depths are not stored on the nodes, so recompute them by
                // walking from the root.
                let depths = standard_node_depths(nodes, *root);
                TreeDefinition::NodeTree {
                    tree_id: 0,
                    weight: 1.0,
                    root_node_id: *root,
                    nodes: nodes
                        .iter()
                        .enumerate()
                        .map(|(node_id, node)| match node {
                            RegressionNode::Leaf {
                                value,
                                sample_count,
                                variance,
                            } => NodeTreeNode::Leaf {
                                node_id,
                                depth: depths[node_id],
                                leaf: LeafPayload::RegressionValue { value: *value },
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: None,
                                    gain: None,
                                    class_counts: None,
                                    variance: *variance,
                                },
                            },
                            RegressionNode::BinarySplit {
                                feature_index,
                                threshold_bin,
                                missing_direction,
                                // The missing fallback value is an in-memory
                                // prediction detail and is not exported.
                                missing_value: _,
                                left_child,
                                right_child,
                                sample_count,
                                impurity,
                                gain,
                                variance,
                            } => NodeTreeNode::BinaryBranch {
                                node_id,
                                depth: depths[node_id],
                                split: binary_split_ir(
                                    *feature_index,
                                    *threshold_bin,
                                    *missing_direction,
                                    &self.feature_preprocessing,
                                ),
                                children: BinaryChildren {
                                    left: *left_child,
                                    right: *right_child,
                                },
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: None,
                                    variance: *variance,
                                },
                            },
                        })
                        .collect(),
                }
            }
            RegressionTreeStructure::Oblivious {
                splits,
                leaf_values,
                leaf_sample_counts,
                leaf_variances,
            } => TreeDefinition::ObliviousLevels {
                tree_id: 0,
                weight: 1.0,
                depth: splits.len(),
                levels: splits
                    .iter()
                    .enumerate()
                    .map(|(level, split)| ObliviousLevel {
                        level,
                        split: oblivious_split_ir(
                            split.feature_index,
                            split.threshold_bin,
                            &self.feature_preprocessing,
                        ),
                        stats: NodeStats {
                            sample_count: split.sample_count,
                            impurity: Some(split.impurity),
                            gain: Some(split.gain),
                            class_counts: None,
                            variance: None,
                        },
                    })
                    .collect(),
                // Documents the bit packing used by `predict_row`'s fold:
                // level 0 is the most significant bit.
                leaf_indexing: LeafIndexing {
                    bit_order: "msb_first".to_string(),
                    index_formula: "sum(bit[level] << (depth - 1 - level))".to_string(),
                },
                leaves: leaf_values
                    .iter()
                    .enumerate()
                    .map(|(leaf_index, value)| IndexedLeaf {
                        leaf_index,
                        leaf: LeafPayload::RegressionValue { value: *value },
                        stats: NodeStats {
                            sample_count: leaf_sample_counts[leaf_index],
                            impurity: None,
                            gain: None,
                            class_counts: None,
                            variance: leaf_variances[leaf_index],
                        },
                    })
                    .collect(),
            },
        }
    }

    /// Reassemble a regressor from already-decoded parts — presumably the IR
    /// deserialization path; confirm against the caller.
    pub(crate) fn from_ir_parts(
        algorithm: RegressionTreeAlgorithm,
        criterion: Criterion,
        structure: RegressionTreeStructure,
        options: RegressionTreeOptions,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
        training_canaries: usize,
    ) -> Self {
        Self {
            algorithm,
            criterion,
            structure,
            // NOTE(review): `options` is owned by this function, so the
            // `clone()` is redundant — `options` could move in directly.
            options: options.clone(),
            num_features,
            feature_preprocessing,
            training_canaries,
        }
    }
}
653
/// Compute the depth of every node in the arena, starting from `root`.
/// Slots not reachable from `root` (if any exist) keep the initial depth 0.
fn standard_node_depths(nodes: &[RegressionNode], root: usize) -> Vec<usize> {
    let mut depths = vec![0; nodes.len()];
    populate_depths(nodes, root, 0, &mut depths);
    depths
}
659
660fn populate_depths(nodes: &[RegressionNode], node_id: usize, depth: usize, depths: &mut [usize]) {
661    depths[node_id] = depth;
662    match &nodes[node_id] {
663        RegressionNode::Leaf { .. } => {}
664        RegressionNode::BinarySplit {
665            left_child,
666            right_child,
667            ..
668        } => {
669            populate_depths(nodes, *left_child, depth + 1, depths);
670            populate_depths(nodes, *right_child, depth + 1, depths);
671        }
672    }
673}
674
675fn binary_split_ir(
676    feature_index: usize,
677    threshold_bin: u16,
678    _missing_direction: MissingBranchDirection,
679    preprocessing: &[FeaturePreprocessing],
680) -> BinarySplit {
681    match preprocessing.get(feature_index) {
682        Some(FeaturePreprocessing::Binary) => BinarySplit::BooleanTest {
683            feature_index,
684            feature_name: feature_name(feature_index),
685            false_child_semantics: "left".to_string(),
686            true_child_semantics: "right".to_string(),
687        },
688        Some(FeaturePreprocessing::Numeric { .. }) | None => BinarySplit::NumericBinThreshold {
689            feature_index,
690            feature_name: feature_name(feature_index),
691            operator: "<=".to_string(),
692            threshold_bin,
693            threshold_upper_bound: threshold_upper_bound(
694                preprocessing,
695                feature_index,
696                threshold_bin,
697            ),
698            comparison_dtype: "uint16".to_string(),
699        },
700    }
701}
702
/// Translate an oblivious-level split into its IR form.
fn oblivious_split_ir(
    feature_index: usize,
    threshold_bin: u16,
    preprocessing: &[FeaturePreprocessing],
) -> IrObliviousSplit {
    match preprocessing.get(feature_index) {
        Some(FeaturePreprocessing::Binary) => IrObliviousSplit::BooleanTest {
            feature_index,
            feature_name: feature_name(feature_index),
            // Feature value false -> bit 0, true -> bit 1, matching
            // `predict_row`'s "bit 1 when bin > threshold" fold.
            bit_when_false: 0,
            bit_when_true: 1,
        },
        Some(FeaturePreprocessing::Numeric { .. }) | None => {
            IrObliviousSplit::NumericBinThreshold {
                feature_index,
                feature_name: feature_name(feature_index),
                operator: "<=".to_string(),
                threshold_bin,
                threshold_upper_bound: threshold_upper_bound(
                    preprocessing,
                    feature_index,
                    threshold_bin,
                ),
                comparison_dtype: "uint16".to_string(),
                // Here "true" means the `<=` test passed (row goes left),
                // so the bit assignment is intentionally inverted relative
                // to the boolean-test arm above; both agree with the
                // `bin > threshold -> bit 1` routing in `predict_row`.
                bit_when_true: 0,
                bit_when_false: 1,
            }
        }
    }
}
733
/// Read-only bundle of everything the recursive binary builder needs,
/// threaded through instead of a long parameter list.
struct BuildContext<'a> {
    // Binned training table.
    table: &'a dyn TableAccess,
    // Validated (all finite) targets, one per table row.
    targets: &'a [f64],
    criterion: Criterion,
    parallelism: Parallelism,
    options: RegressionTreeOptions,
    // Distinguishes CART from randomized split scoring.
    algorithm: RegressionTreeAlgorithm,
}
742
/// Build per-feature histograms for `rows`, accumulating target count, sum,
/// and sum-of-squares per bin — the cached statistics used by the
/// mean-criterion split scoring hot path.
fn build_regression_node_histograms(
    table: &dyn TableAccess,
    targets: &[f64],
    rows: &[usize],
) -> Vec<RegressionFeatureHistogram> {
    build_feature_histograms(
        table,
        rows,
        // Every bin starts empty.
        |_| RegressionHistogramBin {
            count: 0,
            sum: 0.0,
            sum_sq: 0.0,
        },
        // Fold one row's target into the bin it falls in.
        |_feature_index, payload, row_idx| {
            let value = targets[row_idx];
            payload.count += 1;
            payload.sum += value;
            payload.sum_sq += value * value;
        },
    )
}
764
/// Derive a sibling node's histograms as `parent - child`, avoiding a second
/// pass over the sibling's rows (histogram-subtraction trick).
fn subtract_regression_node_histograms(
    parent: &[RegressionFeatureHistogram],
    child: &[RegressionFeatureHistogram],
) -> Vec<RegressionFeatureHistogram> {
    subtract_feature_histograms(parent, child)
}
771
772fn finite_targets(train_set: &dyn TableAccess) -> Result<Vec<f64>, RegressionTreeError> {
773    (0..train_set.n_rows())
774        .map(|row_idx| {
775            let value = train_set.target_value(row_idx);
776            if value.is_finite() {
777                Ok(value)
778            } else {
779                Err(RegressionTreeError::InvalidTargetValue {
780                    row: row_idx,
781                    value,
782                })
783            }
784        })
785        .collect()
786}
787
/// Entry point for the recursive binary builder: starts with no cached
/// histograms (they are built lazily, and only for the mean criterion).
/// Returns the index of the created node inside `nodes`.
fn build_binary_node_in_place(
    context: &BuildContext<'_>,
    nodes: &mut Vec<RegressionNode>,
    rows: &mut [usize],
    depth: usize,
) -> usize {
    build_binary_node_in_place_with_hist(context, nodes, rows, depth, None)
}
796
797fn build_binary_node_in_place_with_hist(
798    context: &BuildContext<'_>,
799    nodes: &mut Vec<RegressionNode>,
800    rows: &mut [usize],
801    depth: usize,
802    histograms: Option<Vec<RegressionFeatureHistogram>>,
803) -> usize {
804    let leaf_value = regression_value(rows, context.targets, context.criterion);
805    let leaf_variance = variance(rows, context.targets);
806
807    if rows.is_empty()
808        || depth >= context.options.max_depth
809        || rows.len() < context.options.min_samples_split
810        || has_constant_target(rows, context.targets)
811    {
812        return push_leaf(nodes, leaf_value, rows.len(), leaf_variance);
813    }
814
815    let histograms = if matches!(context.criterion, Criterion::Mean) {
816        Some(histograms.unwrap_or_else(|| {
817            build_regression_node_histograms(context.table, context.targets, rows)
818        }))
819    } else {
820        None
821    };
822    let feature_indices = candidate_feature_indices(
823        context.table.binned_feature_count(),
824        context.options.max_features,
825        node_seed(context.options.random_seed, depth, rows, 0xA11C_E5E1u64),
826    );
827    let best_split = if context.parallelism.enabled() {
828        feature_indices
829            .into_par_iter()
830            .filter_map(|feature_index| {
831                if let Some(histograms) = histograms.as_ref() {
832                    score_binary_split_choice_from_hist(
833                        context,
834                        &histograms[feature_index],
835                        feature_index,
836                        rows,
837                    )
838                } else {
839                    score_binary_split_choice(context, feature_index, rows)
840                }
841            })
842            .max_by(|left, right| left.score.total_cmp(&right.score))
843    } else {
844        feature_indices
845            .into_iter()
846            .filter_map(|feature_index| {
847                if let Some(histograms) = histograms.as_ref() {
848                    score_binary_split_choice_from_hist(
849                        context,
850                        &histograms[feature_index],
851                        feature_index,
852                        rows,
853                    )
854                } else {
855                    score_binary_split_choice(context, feature_index, rows)
856                }
857            })
858            .max_by(|left, right| left.score.total_cmp(&right.score))
859    };
860
861    match best_split {
862        Some(best_split)
863            if context
864                .table
865                .is_canary_binned_feature(best_split.feature_index) =>
866        {
867            push_leaf(nodes, leaf_value, rows.len(), leaf_variance)
868        }
869        Some(best_split) if best_split.score > 0.0 => {
870            let impurity = regression_loss(rows, context.targets, context.criterion);
871            let left_count = partition_rows_for_binary_split(
872                context.table,
873                best_split.feature_index,
874                best_split.threshold_bin,
875                best_split.missing_direction,
876                rows,
877            );
878            let (left_rows, right_rows) = rows.split_at_mut(left_count);
879            let (left_child, right_child) = if let Some(histograms) = histograms {
880                if left_rows.len() <= right_rows.len() {
881                    let left_histograms =
882                        build_regression_node_histograms(context.table, context.targets, left_rows);
883                    let right_histograms =
884                        subtract_regression_node_histograms(&histograms, &left_histograms);
885                    (
886                        build_binary_node_in_place_with_hist(
887                            context,
888                            nodes,
889                            left_rows,
890                            depth + 1,
891                            Some(left_histograms),
892                        ),
893                        build_binary_node_in_place_with_hist(
894                            context,
895                            nodes,
896                            right_rows,
897                            depth + 1,
898                            Some(right_histograms),
899                        ),
900                    )
901                } else {
902                    let right_histograms = build_regression_node_histograms(
903                        context.table,
904                        context.targets,
905                        right_rows,
906                    );
907                    let left_histograms =
908                        subtract_regression_node_histograms(&histograms, &right_histograms);
909                    (
910                        build_binary_node_in_place_with_hist(
911                            context,
912                            nodes,
913                            left_rows,
914                            depth + 1,
915                            Some(left_histograms),
916                        ),
917                        build_binary_node_in_place_with_hist(
918                            context,
919                            nodes,
920                            right_rows,
921                            depth + 1,
922                            Some(right_histograms),
923                        ),
924                    )
925                }
926            } else {
927                (
928                    build_binary_node_in_place(context, nodes, left_rows, depth + 1),
929                    build_binary_node_in_place(context, nodes, right_rows, depth + 1),
930                )
931            };
932
933            push_node(
934                nodes,
935                RegressionNode::BinarySplit {
936                    feature_index: best_split.feature_index,
937                    threshold_bin: best_split.threshold_bin,
938                    missing_direction: best_split.missing_direction,
939                    missing_value: leaf_value,
940                    left_child,
941                    right_child,
942                    sample_count: rows.len(),
943                    impurity,
944                    gain: best_split.score,
945                    variance: leaf_variance,
946                },
947            )
948        }
949        _ => push_leaf(nodes, leaf_value, rows.len(), leaf_variance),
950    }
951}
952
/// Grows an oblivious (symmetric) regression tree: every level of the tree
/// shares one feature/threshold pair, so each level doubles the leaf count.
///
/// Rows are permuted in place inside `row_indices`; each leaf tracks a
/// contiguous `[start, end)` window of that permutation plus cached
/// sum / sum-of-squares statistics for its targets.
fn train_oblivious_structure(
    table: &dyn TableAccess,
    targets: &[f64],
    criterion: Criterion,
    parallelism: Parallelism,
    options: RegressionTreeOptions,
) -> RegressionTreeStructure {
    let mut row_indices: Vec<usize> = (0..table.n_rows()).collect();
    let (root_sum, root_sum_sq) = sum_stats(&row_indices, targets);
    // Start from a single leaf that covers every row.
    let mut leaves = vec![ObliviousLeafState {
        start: 0,
        end: row_indices.len(),
        value: regression_value_from_stats(&row_indices, targets, criterion, root_sum),
        variance: variance_from_stats(row_indices.len(), root_sum, root_sum_sq),
        sum: root_sum,
        sum_sq: root_sum_sq,
    }];
    let mut splits = Vec::new();

    for depth in 0..options.max_depth {
        // Stop when no leaf is large enough to be split further.
        if leaves
            .iter()
            .all(|leaf| leaf.len() < options.min_samples_split)
        {
            break;
        }
        let feature_indices = candidate_feature_indices(
            table.binned_feature_count(),
            options.max_features,
            node_seed(options.random_seed, depth, &[], 0x0B11_A10Cu64),
        );
        // Score each candidate feature across *all* current leaves at once;
        // the winning split is applied to every leaf of this level.
        let best_split = if parallelism.enabled() {
            feature_indices
                .into_par_iter()
                .filter_map(|feature_index| {
                    score_oblivious_split(
                        table,
                        &row_indices,
                        targets,
                        feature_index,
                        &leaves,
                        criterion,
                        options.min_samples_leaf,
                    )
                })
                .max_by(|left, right| left.score.total_cmp(&right.score))
        } else {
            feature_indices
                .into_iter()
                .filter_map(|feature_index| {
                    score_oblivious_split(
                        table,
                        &row_indices,
                        targets,
                        feature_index,
                        &leaves,
                        criterion,
                        options.min_samples_leaf,
                    )
                })
                .max_by(|left, right| left.score.total_cmp(&right.score))
        };

        // No positive-gain candidate means the tree is finished at this depth.
        let Some(best_split) = best_split.filter(|candidate| candidate.score > 0.0) else {
            break;
        };
        // A canary feature winning a split is treated as a stop signal.
        if table.is_canary_binned_feature(best_split.feature_index) {
            break;
        }

        leaves = split_oblivious_leaves_in_place(
            table,
            &mut row_indices,
            targets,
            leaves,
            best_split.feature_index,
            best_split.threshold_bin,
            criterion,
        );
        splits.push(ObliviousSplit {
            feature_index: best_split.feature_index,
            threshold_bin: best_split.threshold_bin,
            sample_count: table.n_rows(),
            // Recorded impurity is the post-split total across the new leaves.
            impurity: leaves
                .iter()
                .map(|leaf| leaf_regression_loss(leaf, &row_indices, targets, criterion))
                .sum(),
            gain: best_split.score,
        });
    }

    RegressionTreeStructure::Oblivious {
        splits,
        leaf_values: leaves.iter().map(|leaf| leaf.value).collect(),
        leaf_sample_counts: leaves.iter().map(ObliviousLeafState::len).collect(),
        leaf_variances: leaves.iter().map(|leaf| leaf.variance).collect(),
    }
}
1051
1052#[allow(clippy::too_many_arguments)]
1053fn score_split(
1054    table: &dyn TableAccess,
1055    targets: &[f64],
1056    feature_index: usize,
1057    rows: &[usize],
1058    criterion: Criterion,
1059    min_samples_leaf: usize,
1060    algorithm: RegressionTreeAlgorithm,
1061    strategy: MissingValueStrategy,
1062) -> Option<RegressionSplitCandidate> {
1063    if table.is_binary_binned_feature(feature_index) {
1064        return score_binary_split(
1065            table,
1066            targets,
1067            feature_index,
1068            rows,
1069            criterion,
1070            min_samples_leaf,
1071            strategy,
1072        );
1073    }
1074    let has_missing = feature_has_missing(table, feature_index, rows);
1075    if matches!(criterion, Criterion::Mean) && !has_missing {
1076        if matches!(algorithm, RegressionTreeAlgorithm::Randomized) {
1077            if let Some(candidate) = score_randomized_split_mean_fast(
1078                table,
1079                targets,
1080                feature_index,
1081                rows,
1082                min_samples_leaf,
1083            ) {
1084                return Some(candidate);
1085            }
1086        } else if let Some(candidate) =
1087            score_numeric_split_mean_fast(table, targets, feature_index, rows, min_samples_leaf)
1088        {
1089            return Some(candidate);
1090        }
1091    }
1092    if matches!(algorithm, RegressionTreeAlgorithm::Randomized) {
1093        return score_randomized_split(
1094            table,
1095            targets,
1096            feature_index,
1097            rows,
1098            criterion,
1099            min_samples_leaf,
1100            strategy,
1101        );
1102    }
1103    if has_missing && matches!(strategy, MissingValueStrategy::Heuristic) {
1104        return score_split_heuristic_missing_assignment(
1105            table,
1106            targets,
1107            feature_index,
1108            rows,
1109            criterion,
1110            min_samples_leaf,
1111        );
1112    }
1113    let parent_loss = regression_loss(rows, targets, criterion);
1114
1115    rows.iter()
1116        .copied()
1117        .filter(|row_idx| !table.is_missing(feature_index, *row_idx))
1118        .map(|row_idx| table.binned_value(feature_index, row_idx))
1119        .collect::<BTreeSet<_>>()
1120        .into_iter()
1121        .filter_map(|threshold_bin| {
1122            evaluate_regression_missing_assignment(
1123                table,
1124                targets,
1125                feature_index,
1126                rows,
1127                criterion,
1128                min_samples_leaf,
1129                threshold_bin,
1130                parent_loss,
1131            )
1132        })
1133        .max_by(|left, right| left.score.total_cmp(&right.score))
1134}
1135
1136fn score_randomized_split(
1137    table: &dyn TableAccess,
1138    targets: &[f64],
1139    feature_index: usize,
1140    rows: &[usize],
1141    criterion: Criterion,
1142    min_samples_leaf: usize,
1143    _strategy: MissingValueStrategy,
1144) -> Option<RegressionSplitCandidate> {
1145    let candidate_thresholds = rows
1146        .iter()
1147        .copied()
1148        .filter(|row_idx| !table.is_missing(feature_index, *row_idx))
1149        .map(|row_idx| table.binned_value(feature_index, row_idx))
1150        .collect::<BTreeSet<_>>()
1151        .into_iter()
1152        .collect::<Vec<_>>();
1153    let threshold_bin =
1154        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xA11CE551u64)?;
1155
1156    let parent_loss = regression_loss(rows, targets, criterion);
1157    evaluate_regression_missing_assignment(
1158        table,
1159        targets,
1160        feature_index,
1161        rows,
1162        criterion,
1163        min_samples_leaf,
1164        threshold_bin,
1165        parent_loss,
1166    )
1167}
1168
/// Scores `feature_index` as a level-wide (oblivious) split candidate.
///
/// The gain is summed over every current leaf; a leaf whose children would
/// fall under `min_samples_leaf` contributes zero gain rather than
/// disqualifying the feature outright.
fn score_oblivious_split(
    table: &dyn TableAccess,
    row_indices: &[usize],
    targets: &[f64],
    feature_index: usize,
    leaves: &[ObliviousLeafState],
    criterion: Criterion,
    min_samples_leaf: usize,
) -> Option<ObliviousSplitCandidate> {
    if table.is_binary_binned_feature(feature_index) {
        // Mean criterion first tries the cached-statistics fast path.
        if matches!(criterion, Criterion::Mean)
            && let Some(candidate) = score_binary_oblivious_split_mean_fast(
                table,
                row_indices,
                targets,
                feature_index,
                leaves,
                min_samples_leaf,
            )
        {
            return Some(candidate);
        }
        return score_binary_oblivious_split(
            table,
            row_indices,
            targets,
            feature_index,
            leaves,
            criterion,
            min_samples_leaf,
        );
    }
    // Numeric feature, mean criterion: cached-statistics fast path.
    if matches!(criterion, Criterion::Mean)
        && let Some(candidate) = score_numeric_oblivious_split_mean_fast(
            table,
            row_indices,
            targets,
            feature_index,
            leaves,
            min_samples_leaf,
        )
    {
        return Some(candidate);
    }
    // Fallback: enumerate every distinct bin value seen across all leaves.
    let candidate_thresholds = leaves
        .iter()
        .flat_map(|leaf| {
            row_indices[leaf.start..leaf.end]
                .iter()
                .map(|row_idx| table.binned_value(feature_index, *row_idx))
        })
        .collect::<BTreeSet<_>>();

    candidate_thresholds
        .into_iter()
        .filter_map(|threshold_bin| {
            // Accumulate the loss reduction this threshold achieves per leaf.
            let score = leaves.iter().fold(0.0, |score, leaf| {
                let leaf_rows = &row_indices[leaf.start..leaf.end];
                let (left_rows, right_rows): (Vec<usize>, Vec<usize>) =
                    leaf_rows.iter().copied().partition(|row_idx| {
                        table.binned_value(feature_index, *row_idx) <= threshold_bin
                    });

                // Undersized children: this leaf contributes no gain.
                if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
                    return score;
                }

                score + regression_loss(leaf_rows, targets, criterion)
                    - (regression_loss(&left_rows, targets, criterion)
                        + regression_loss(&right_rows, targets, criterion))
            });

            // Only positive-gain thresholds survive as candidates.
            (score > 0.0).then_some(ObliviousSplitCandidate {
                feature_index,
                threshold_bin,
                score,
            })
        })
        .max_by(|left, right| left.score.total_cmp(&right.score))
}
1249
/// Applies one level-wide split to every leaf, reordering each leaf's window
/// of `row_indices` in place so both children own contiguous ranges.
///
/// Missing values are always routed right (`MissingBranchDirection::Right`).
/// An empty child inherits the parent's value so prediction never reads an
/// undefined leaf.
fn split_oblivious_leaves_in_place(
    table: &dyn TableAccess,
    row_indices: &mut [usize],
    targets: &[f64],
    leaves: Vec<ObliviousLeafState>,
    feature_index: usize,
    threshold_bin: u16,
    criterion: Criterion,
) -> Vec<ObliviousLeafState> {
    // Every split doubles the leaf count.
    let mut next_leaves = Vec::with_capacity(leaves.len() * 2);
    for leaf in leaves {
        let fallback_value = leaf.value;
        // Reorder this leaf's window so left-child rows come first; the
        // helper returns how many rows landed on the left.
        let left_count = partition_rows_for_binary_split(
            table,
            feature_index,
            threshold_bin,
            MissingBranchDirection::Right,
            &mut row_indices[leaf.start..leaf.end],
        );
        let mid = leaf.start + left_count;
        let left_rows = &row_indices[leaf.start..mid];
        let right_rows = &row_indices[mid..leaf.end];
        let (left_sum, left_sum_sq) = sum_stats(left_rows, targets);
        let (right_sum, right_sum_sq) = sum_stats(right_rows, targets);
        next_leaves.push(ObliviousLeafState {
            start: leaf.start,
            end: mid,
            value: if left_rows.is_empty() {
                fallback_value
            } else {
                regression_value_from_stats(left_rows, targets, criterion, left_sum)
            },
            variance: variance_from_stats(left_rows.len(), left_sum, left_sum_sq),
            sum: left_sum,
            sum_sq: left_sum_sq,
        });
        next_leaves.push(ObliviousLeafState {
            start: mid,
            end: leaf.end,
            value: if right_rows.is_empty() {
                fallback_value
            } else {
                regression_value_from_stats(right_rows, targets, criterion, right_sum)
            },
            variance: variance_from_stats(right_rows.len(), right_sum, right_sum_sq),
            sum: right_sum,
            sum_sq: right_sum_sq,
        });
    }
    next_leaves
}
1301
1302fn variance(rows: &[usize], targets: &[f64]) -> Option<f64> {
1303    let (sum, sum_sq) = sum_stats(rows, targets);
1304    variance_from_stats(rows.len(), sum, sum_sq)
1305}
1306
/// Arithmetic mean of the targets selected by `rows`; an empty selection
/// yields 0.0.
fn mean(rows: &[usize], targets: &[f64]) -> f64 {
    match rows.len() {
        0 => 0.0,
        count => {
            let total: f64 = rows.iter().map(|row_idx| targets[*row_idx]).sum();
            total / count as f64
        }
    }
}
1314
/// Median of the targets selected by `rows` (midpoint of the two central
/// values for an even count); an empty selection yields 0.0.
fn median(rows: &[usize], targets: &[f64]) -> f64 {
    if rows.is_empty() {
        return 0.0;
    }
    let mut ordered: Vec<f64> = rows.iter().map(|row_idx| targets[*row_idx]).collect();
    // total_cmp gives a total order over f64, so the sort never panics.
    ordered.sort_by(f64::total_cmp);

    let upper = ordered.len() / 2;
    if ordered.len() % 2 == 1 {
        ordered[upper]
    } else {
        (ordered[upper - 1] + ordered[upper]) / 2.0
    }
}
1329
/// Sum of squared deviations from the selection's mean (the mean-criterion
/// node loss); zero for an empty selection.
fn sum_squared_error(rows: &[usize], targets: &[f64]) -> f64 {
    if rows.is_empty() {
        return 0.0;
    }
    let centre = rows.iter().map(|row_idx| targets[*row_idx]).sum::<f64>() / rows.len() as f64;
    rows.iter()
        .map(|row_idx| targets[*row_idx] - centre)
        .map(|diff| diff * diff)
        .sum()
}
1339
/// Sum of absolute deviations from the selection's median (the
/// median-criterion node loss); zero for an empty selection.
fn sum_absolute_error(rows: &[usize], targets: &[f64]) -> f64 {
    if rows.is_empty() {
        return 0.0;
    }
    // Median computed inline: sort a copy, take the middle (or the midpoint
    // of the two middle values for an even count).
    let mut ordered: Vec<f64> = rows.iter().map(|row_idx| targets[*row_idx]).collect();
    ordered.sort_by(f64::total_cmp);
    let upper = ordered.len() / 2;
    let centre = if ordered.len() % 2 == 0 {
        (ordered[upper - 1] + ordered[upper]) / 2.0
    } else {
        ordered[upper]
    };
    rows.iter()
        .map(|row_idx| (targets[*row_idx] - centre).abs())
        .sum()
}
1346
1347fn regression_value(rows: &[usize], targets: &[f64], criterion: Criterion) -> f64 {
1348    let (sum, _sum_sq) = sum_stats(rows, targets);
1349    regression_value_from_stats(rows, targets, criterion, sum)
1350}
1351
1352fn regression_value_from_stats(
1353    rows: &[usize],
1354    targets: &[f64],
1355    criterion: Criterion,
1356    sum: f64,
1357) -> f64 {
1358    match criterion {
1359        Criterion::Mean => {
1360            if rows.is_empty() {
1361                0.0
1362            } else {
1363                sum / rows.len() as f64
1364            }
1365        }
1366        Criterion::Median => median(rows, targets),
1367        _ => unreachable!("regression criterion only supports mean or median"),
1368    }
1369}
1370
1371fn regression_loss(rows: &[usize], targets: &[f64], criterion: Criterion) -> f64 {
1372    match criterion {
1373        Criterion::Mean => sum_squared_error(rows, targets),
1374        Criterion::Median => sum_absolute_error(rows, targets),
1375        _ => unreachable!("regression criterion only supports mean or median"),
1376    }
1377}
1378
1379fn score_binary_split(
1380    table: &dyn TableAccess,
1381    targets: &[f64],
1382    feature_index: usize,
1383    rows: &[usize],
1384    criterion: Criterion,
1385    min_samples_leaf: usize,
1386    strategy: MissingValueStrategy,
1387) -> Option<RegressionSplitCandidate> {
1388    if matches!(strategy, MissingValueStrategy::Heuristic) {
1389        return score_binary_split_heuristic(
1390            table,
1391            targets,
1392            feature_index,
1393            rows,
1394            criterion,
1395            min_samples_leaf,
1396        );
1397    }
1398    let parent_loss = regression_loss(rows, targets, criterion);
1399    evaluate_regression_missing_assignment(
1400        table,
1401        targets,
1402        feature_index,
1403        rows,
1404        criterion,
1405        min_samples_leaf,
1406        0,
1407        parent_loss,
1408    )
1409}
1410
1411fn score_binary_split_heuristic(
1412    table: &dyn TableAccess,
1413    targets: &[f64],
1414    feature_index: usize,
1415    rows: &[usize],
1416    criterion: Criterion,
1417    min_samples_leaf: usize,
1418) -> Option<RegressionSplitCandidate> {
1419    let observed_rows = rows
1420        .iter()
1421        .copied()
1422        .filter(|row_idx| !table.is_missing(feature_index, *row_idx))
1423        .collect::<Vec<_>>();
1424    if observed_rows.is_empty() {
1425        return None;
1426    }
1427    let parent_loss = regression_loss(&observed_rows, targets, criterion);
1428    let mut left_rows = Vec::new();
1429    let mut right_rows = Vec::new();
1430    for row_idx in observed_rows.iter().copied() {
1431        if !table
1432            .binned_boolean_value(feature_index, row_idx)
1433            .expect("observed binary feature must expose boolean values")
1434        {
1435            left_rows.push(row_idx);
1436        } else {
1437            right_rows.push(row_idx);
1438        }
1439    }
1440    if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
1441        return None;
1442    }
1443    evaluate_regression_missing_assignment(
1444        table,
1445        targets,
1446        feature_index,
1447        rows,
1448        criterion,
1449        min_samples_leaf,
1450        0,
1451        parent_loss,
1452    )
1453}
1454
1455fn score_split_heuristic_missing_assignment(
1456    table: &dyn TableAccess,
1457    targets: &[f64],
1458    feature_index: usize,
1459    rows: &[usize],
1460    criterion: Criterion,
1461    min_samples_leaf: usize,
1462) -> Option<RegressionSplitCandidate> {
1463    let observed_rows = rows
1464        .iter()
1465        .copied()
1466        .filter(|row_idx| !table.is_missing(feature_index, *row_idx))
1467        .collect::<Vec<_>>();
1468    if observed_rows.is_empty() {
1469        return None;
1470    }
1471    let parent_loss = regression_loss(&observed_rows, targets, criterion);
1472    let threshold_bin = observed_rows
1473        .iter()
1474        .copied()
1475        .map(|row_idx| table.binned_value(feature_index, row_idx))
1476        .collect::<BTreeSet<_>>()
1477        .into_iter()
1478        .filter_map(|threshold_bin| {
1479            evaluate_regression_observed_split(
1480                table,
1481                targets,
1482                feature_index,
1483                &observed_rows,
1484                criterion,
1485                min_samples_leaf,
1486                threshold_bin,
1487                parent_loss,
1488            )
1489            .map(|score| (threshold_bin, score))
1490        })
1491        .max_by(|left, right| left.1.total_cmp(&right.1))
1492        .map(|(threshold_bin, _)| threshold_bin)?;
1493    evaluate_regression_missing_assignment(
1494        table,
1495        targets,
1496        feature_index,
1497        rows,
1498        criterion,
1499        min_samples_leaf,
1500        threshold_bin,
1501        parent_loss,
1502    )
1503}
1504
1505#[allow(clippy::too_many_arguments)]
1506fn evaluate_regression_observed_split(
1507    table: &dyn TableAccess,
1508    targets: &[f64],
1509    feature_index: usize,
1510    observed_rows: &[usize],
1511    criterion: Criterion,
1512    min_samples_leaf: usize,
1513    threshold_bin: u16,
1514    parent_loss: f64,
1515) -> Option<f64> {
1516    let mut left_rows = Vec::new();
1517    let mut right_rows = Vec::new();
1518    for row_idx in observed_rows.iter().copied() {
1519        if table.binned_value(feature_index, row_idx) <= threshold_bin {
1520            left_rows.push(row_idx);
1521        } else {
1522            right_rows.push(row_idx);
1523        }
1524    }
1525    if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
1526        return None;
1527    }
1528    Some(
1529        parent_loss
1530            - (regression_loss(&left_rows, targets, criterion)
1531                + regression_loss(&right_rows, targets, criterion)),
1532    )
1533}
1534
1535fn score_binary_split_choice(
1536    context: &BuildContext<'_>,
1537    feature_index: usize,
1538    rows: &[usize],
1539) -> Option<BinarySplitChoice> {
1540    if matches!(context.criterion, Criterion::Mean) {
1541        if context.table.is_binary_binned_feature(feature_index) {
1542            if feature_has_missing(context.table, feature_index, rows) {
1543                return score_split(
1544                    context.table,
1545                    context.targets,
1546                    feature_index,
1547                    rows,
1548                    context.criterion,
1549                    context.options.min_samples_leaf,
1550                    context.algorithm,
1551                    context.options.missing_value_strategy(feature_index),
1552                )
1553                .map(|candidate| BinarySplitChoice {
1554                    feature_index: candidate.feature_index,
1555                    threshold_bin: candidate.threshold_bin,
1556                    score: candidate.score,
1557                    missing_direction: candidate.missing_direction,
1558                });
1559            }
1560            return score_binary_split_choice_mean(context, feature_index, rows);
1561        }
1562        if feature_has_missing(context.table, feature_index, rows) {
1563            return score_split(
1564                context.table,
1565                context.targets,
1566                feature_index,
1567                rows,
1568                context.criterion,
1569                context.options.min_samples_leaf,
1570                context.algorithm,
1571                context.options.missing_value_strategy(feature_index),
1572            )
1573            .map(|candidate| BinarySplitChoice {
1574                feature_index: candidate.feature_index,
1575                threshold_bin: candidate.threshold_bin,
1576                score: candidate.score,
1577                missing_direction: candidate.missing_direction,
1578            });
1579        }
1580        return match context.algorithm {
1581            RegressionTreeAlgorithm::Cart => {
1582                score_numeric_split_choice_mean_fast(context, feature_index, rows)
1583            }
1584            RegressionTreeAlgorithm::Randomized => {
1585                score_randomized_split_choice_mean_fast(context, feature_index, rows)
1586            }
1587            RegressionTreeAlgorithm::Oblivious => None,
1588        };
1589    }
1590
1591    score_split(
1592        context.table,
1593        context.targets,
1594        feature_index,
1595        rows,
1596        context.criterion,
1597        context.options.min_samples_leaf,
1598        context.algorithm,
1599        context.options.missing_value_strategy(feature_index),
1600    )
1601    .map(|candidate| BinarySplitChoice {
1602        feature_index: candidate.feature_index,
1603        threshold_bin: candidate.threshold_bin,
1604        score: candidate.score,
1605        missing_direction: candidate.missing_direction,
1606    })
1607}
1608
/// Histogram-backed analogue of `score_binary_split_choice`.
///
/// Only the mean criterion can be scored from cached count/sum/sum-of-squares
/// bins; the median criterion, and any feature whose histogram records
/// missing rows, fall back to the row-scanning scorer.
fn score_binary_split_choice_from_hist(
    context: &BuildContext<'_>,
    histogram: &RegressionFeatureHistogram,
    feature_index: usize,
    rows: &[usize],
) -> Option<BinarySplitChoice> {
    if !matches!(context.criterion, Criterion::Mean) {
        return score_binary_split_choice(context, feature_index, rows);
    }

    match histogram {
        // Binary feature, no missing rows: score directly from the two bins.
        RegressionFeatureHistogram::Binary {
            false_bin,
            true_bin,
            missing_bin,
        } if missing_bin.count == 0 => score_binary_split_choice_mean_from_stats(
            context,
            feature_index,
            false_bin.count,
            false_bin.sum,
            false_bin.sum_sq,
            true_bin.count,
            true_bin.sum,
            true_bin.sum_sq,
        ),
        // Binary feature with missing rows: fall back to the row scan.
        RegressionFeatureHistogram::Binary { .. } => {
            score_binary_split_choice(context, feature_index, rows)
        }
        // Numeric feature whose bin at index `numeric_bin_cap` is absent or
        // empty — presumably that index is reserved for missing values;
        // TODO(review): confirm against the histogram builder.
        RegressionFeatureHistogram::Numeric {
            bins,
            observed_bins,
        } if bins
            .get(context.table.numeric_bin_cap())
            .is_none_or(|missing_bin| missing_bin.count == 0) =>
        {
            match context.algorithm {
                RegressionTreeAlgorithm::Cart => score_numeric_split_choice_mean_from_hist(
                    context,
                    feature_index,
                    rows.len(),
                    bins,
                    observed_bins,
                ),
                RegressionTreeAlgorithm::Randomized => {
                    score_randomized_split_choice_mean_from_hist(
                        context,
                        feature_index,
                        rows,
                        bins,
                        observed_bins,
                    )
                }
                // Oblivious trees never take the per-node split path.
                RegressionTreeAlgorithm::Oblivious => None,
            }
        }
        // Numeric feature with missing rows: fall back to the row scan.
        RegressionFeatureHistogram::Numeric { .. } => {
            score_binary_split_choice(context, feature_index, rows)
        }
    }
}
1669
1670#[allow(clippy::too_many_arguments)]
1671fn score_binary_split_choice_mean_from_stats(
1672    context: &BuildContext<'_>,
1673    feature_index: usize,
1674    left_count: usize,
1675    left_sum: f64,
1676    left_sum_sq: f64,
1677    right_count: usize,
1678    right_sum: f64,
1679    right_sum_sq: f64,
1680) -> Option<BinarySplitChoice> {
1681    if left_count < context.options.min_samples_leaf
1682        || right_count < context.options.min_samples_leaf
1683    {
1684        return None;
1685    }
1686    let total_count = left_count + right_count;
1687    let total_sum = left_sum + right_sum;
1688    let total_sum_sq = left_sum_sq + right_sum_sq;
1689    let parent_loss = total_sum_sq - (total_sum * total_sum) / total_count as f64;
1690    let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
1691    let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
1692    Some(BinarySplitChoice {
1693        feature_index,
1694        threshold_bin: 0,
1695        score: parent_loss - (left_loss + right_loss),
1696        missing_direction: MissingBranchDirection::Node,
1697    })
1698}
1699
1700fn score_numeric_split_choice_mean_from_hist(
1701    context: &BuildContext<'_>,
1702    feature_index: usize,
1703    row_count: usize,
1704    bins: &[RegressionHistogramBin],
1705    observed_bins: &[usize],
1706) -> Option<BinarySplitChoice> {
1707    if observed_bins.len() <= 1 {
1708        return None;
1709    }
1710    let total_sum = bins.iter().map(|bin| bin.sum).sum::<f64>();
1711    let total_sum_sq = bins.iter().map(|bin| bin.sum_sq).sum::<f64>();
1712    let parent_loss = total_sum_sq - (total_sum * total_sum) / row_count as f64;
1713    let mut left_count = 0usize;
1714    let mut left_sum = 0.0;
1715    let mut left_sum_sq = 0.0;
1716    let mut best_threshold = None;
1717    let mut best_score = f64::NEG_INFINITY;
1718
1719    for &bin in observed_bins {
1720        left_count += bins[bin].count;
1721        left_sum += bins[bin].sum;
1722        left_sum_sq += bins[bin].sum_sq;
1723        let right_count = row_count - left_count;
1724        if left_count < context.options.min_samples_leaf
1725            || right_count < context.options.min_samples_leaf
1726        {
1727            continue;
1728        }
1729        let right_sum = total_sum - left_sum;
1730        let right_sum_sq = total_sum_sq - left_sum_sq;
1731        let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
1732        let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
1733        let score = parent_loss - (left_loss + right_loss);
1734        if score > best_score {
1735            best_score = score;
1736            best_threshold = Some(bin as u16);
1737        }
1738    }
1739
1740    best_threshold.map(|threshold_bin| BinarySplitChoice {
1741        feature_index,
1742        threshold_bin,
1743        score: best_score,
1744        missing_direction: MissingBranchDirection::Node,
1745    })
1746}
1747
1748fn score_randomized_split_choice_mean_from_hist(
1749    context: &BuildContext<'_>,
1750    feature_index: usize,
1751    rows: &[usize],
1752    bins: &[RegressionHistogramBin],
1753    observed_bins: &[usize],
1754) -> Option<BinarySplitChoice> {
1755    let candidate_thresholds = observed_bins
1756        .iter()
1757        .copied()
1758        .map(|bin| bin as u16)
1759        .collect::<Vec<_>>();
1760    let threshold_bin =
1761        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xA11CE551u64)?;
1762    let total_sum = bins.iter().map(|bin| bin.sum).sum::<f64>();
1763    let total_sum_sq = bins.iter().map(|bin| bin.sum_sq).sum::<f64>();
1764    let mut left_count = 0usize;
1765    let mut left_sum = 0.0;
1766    let mut left_sum_sq = 0.0;
1767    for bin in 0..=threshold_bin as usize {
1768        if bin >= bins.len() {
1769            break;
1770        }
1771        left_count += bins[bin].count;
1772        left_sum += bins[bin].sum;
1773        left_sum_sq += bins[bin].sum_sq;
1774    }
1775    let right_count = rows.len() - left_count;
1776    if left_count < context.options.min_samples_leaf
1777        || right_count < context.options.min_samples_leaf
1778    {
1779        return None;
1780    }
1781    let parent_loss = total_sum_sq - (total_sum * total_sum) / rows.len() as f64;
1782    let right_sum = total_sum - left_sum;
1783    let right_sum_sq = total_sum_sq - left_sum_sq;
1784    let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
1785    let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
1786    Some(BinarySplitChoice {
1787        feature_index,
1788        threshold_bin,
1789        score: parent_loss - (left_loss + right_loss),
1790        missing_direction: MissingBranchDirection::Node,
1791    })
1792}
1793
1794fn score_binary_split_choice_mean(
1795    context: &BuildContext<'_>,
1796    feature_index: usize,
1797    rows: &[usize],
1798) -> Option<BinarySplitChoice> {
1799    let mut left_count = 0usize;
1800    let mut left_sum = 0.0;
1801    let mut left_sum_sq = 0.0;
1802    let mut total_sum = 0.0;
1803    let mut total_sum_sq = 0.0;
1804
1805    for row_idx in rows {
1806        let target = context.targets[*row_idx];
1807        total_sum += target;
1808        total_sum_sq += target * target;
1809        if !context
1810            .table
1811            .binned_boolean_value(feature_index, *row_idx)
1812            .expect("binary feature must expose boolean values")
1813        {
1814            left_count += 1;
1815            left_sum += target;
1816            left_sum_sq += target * target;
1817        }
1818    }
1819
1820    let right_count = rows.len() - left_count;
1821    if left_count < context.options.min_samples_leaf
1822        || right_count < context.options.min_samples_leaf
1823    {
1824        return None;
1825    }
1826
1827    let parent_loss = total_sum_sq - (total_sum * total_sum) / rows.len() as f64;
1828    let right_sum = total_sum - left_sum;
1829    let right_sum_sq = total_sum_sq - left_sum_sq;
1830    let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
1831    let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
1832
1833    Some(BinarySplitChoice {
1834        feature_index,
1835        threshold_bin: 0,
1836        score: parent_loss - (left_loss + right_loss),
1837        missing_direction: MissingBranchDirection::Node,
1838    })
1839}
1840
1841fn score_numeric_split_mean_fast(
1842    table: &dyn TableAccess,
1843    targets: &[f64],
1844    feature_index: usize,
1845    rows: &[usize],
1846    min_samples_leaf: usize,
1847) -> Option<RegressionSplitCandidate> {
1848    let bin_cap = table.numeric_bin_cap();
1849    if bin_cap == 0 {
1850        return None;
1851    }
1852
1853    let mut bin_count = vec![0usize; bin_cap];
1854    let mut bin_sum = vec![0.0; bin_cap];
1855    let mut bin_sum_sq = vec![0.0; bin_cap];
1856    let mut observed_bins = vec![false; bin_cap];
1857    let mut total_sum = 0.0;
1858    let mut total_sum_sq = 0.0;
1859
1860    for row_idx in rows {
1861        let bin = table.binned_value(feature_index, *row_idx) as usize;
1862        if bin >= bin_cap {
1863            return None;
1864        }
1865        let target = targets[*row_idx];
1866        bin_count[bin] += 1;
1867        bin_sum[bin] += target;
1868        bin_sum_sq[bin] += target * target;
1869        observed_bins[bin] = true;
1870        total_sum += target;
1871        total_sum_sq += target * target;
1872    }
1873
1874    let observed_bins: Vec<usize> = observed_bins
1875        .into_iter()
1876        .enumerate()
1877        .filter_map(|(bin, seen)| seen.then_some(bin))
1878        .collect();
1879    if observed_bins.len() <= 1 {
1880        return None;
1881    }
1882
1883    let parent_loss = total_sum_sq - (total_sum * total_sum) / rows.len() as f64;
1884    let mut left_count = 0usize;
1885    let mut left_sum = 0.0;
1886    let mut left_sum_sq = 0.0;
1887    let mut best_threshold = None;
1888    let mut best_score = f64::NEG_INFINITY;
1889
1890    for &bin in &observed_bins {
1891        left_count += bin_count[bin];
1892        left_sum += bin_sum[bin];
1893        left_sum_sq += bin_sum_sq[bin];
1894        let right_count = rows.len() - left_count;
1895
1896        if left_count < min_samples_leaf || right_count < min_samples_leaf {
1897            continue;
1898        }
1899
1900        let right_sum = total_sum - left_sum;
1901        let right_sum_sq = total_sum_sq - left_sum_sq;
1902        let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
1903        let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
1904        let score = parent_loss - (left_loss + right_loss);
1905        if score > best_score {
1906            best_score = score;
1907            best_threshold = Some(bin as u16);
1908        }
1909    }
1910
1911    let threshold_bin = best_threshold?;
1912    Some(RegressionSplitCandidate {
1913        feature_index,
1914        threshold_bin,
1915        score: best_score,
1916        missing_direction: MissingBranchDirection::Node,
1917    })
1918}
1919
1920fn score_numeric_split_choice_mean_fast(
1921    context: &BuildContext<'_>,
1922    feature_index: usize,
1923    rows: &[usize],
1924) -> Option<BinarySplitChoice> {
1925    score_numeric_split_mean_fast(
1926        context.table,
1927        context.targets,
1928        feature_index,
1929        rows,
1930        context.options.min_samples_leaf,
1931    )
1932    .map(|candidate| BinarySplitChoice {
1933        feature_index: candidate.feature_index,
1934        threshold_bin: candidate.threshold_bin,
1935        score: candidate.score,
1936        missing_direction: MissingBranchDirection::Node,
1937    })
1938}
1939
1940fn score_randomized_split_mean_fast(
1941    table: &dyn TableAccess,
1942    targets: &[f64],
1943    feature_index: usize,
1944    rows: &[usize],
1945    min_samples_leaf: usize,
1946) -> Option<RegressionSplitCandidate> {
1947    let bin_cap = table.numeric_bin_cap();
1948    if bin_cap == 0 {
1949        return None;
1950    }
1951    let mut observed_bins = vec![false; bin_cap];
1952    for row_idx in rows {
1953        let bin = table.binned_value(feature_index, *row_idx) as usize;
1954        if bin >= bin_cap {
1955            return None;
1956        }
1957        observed_bins[bin] = true;
1958    }
1959    let candidate_thresholds = observed_bins
1960        .into_iter()
1961        .enumerate()
1962        .filter_map(|(bin, seen)| seen.then_some(bin as u16))
1963        .collect::<Vec<_>>();
1964    let threshold_bin =
1965        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xA11CE551u64)?;
1966
1967    let mut left_count = 0usize;
1968    let mut left_sum = 0.0;
1969    let mut left_sum_sq = 0.0;
1970    let mut total_sum = 0.0;
1971    let mut total_sum_sq = 0.0;
1972    for row_idx in rows {
1973        let target = targets[*row_idx];
1974        total_sum += target;
1975        total_sum_sq += target * target;
1976        if table.binned_value(feature_index, *row_idx) <= threshold_bin {
1977            left_count += 1;
1978            left_sum += target;
1979            left_sum_sq += target * target;
1980        }
1981    }
1982    let right_count = rows.len() - left_count;
1983    if left_count < min_samples_leaf || right_count < min_samples_leaf {
1984        return None;
1985    }
1986    let parent_loss = total_sum_sq - (total_sum * total_sum) / rows.len() as f64;
1987    let right_sum = total_sum - left_sum;
1988    let right_sum_sq = total_sum_sq - left_sum_sq;
1989    let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
1990    let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
1991    let score = parent_loss - (left_loss + right_loss);
1992
1993    Some(RegressionSplitCandidate {
1994        feature_index,
1995        threshold_bin,
1996        score,
1997        missing_direction: MissingBranchDirection::Node,
1998    })
1999}
2000
2001fn score_randomized_split_choice_mean_fast(
2002    context: &BuildContext<'_>,
2003    feature_index: usize,
2004    rows: &[usize],
2005) -> Option<BinarySplitChoice> {
2006    score_randomized_split_mean_fast(
2007        context.table,
2008        context.targets,
2009        feature_index,
2010        rows,
2011        context.options.min_samples_leaf,
2012    )
2013    .map(|candidate| BinarySplitChoice {
2014        feature_index: candidate.feature_index,
2015        threshold_bin: candidate.threshold_bin,
2016        score: candidate.score,
2017        missing_direction: MissingBranchDirection::Node,
2018    })
2019}
2020
2021fn feature_has_missing(table: &dyn TableAccess, feature_index: usize, rows: &[usize]) -> bool {
2022    rows.iter()
2023        .any(|row_idx| table.is_missing(feature_index, *row_idx))
2024}
2025
/// Scores a candidate split on a feature that has (or may have) missing
/// values, choosing the better of routing missing rows left or right.
///
/// Rows are partitioned into left / right / missing buckets: binary features
/// route on the boolean bin value (`false` goes left), numeric features route
/// on `bin <= threshold_bin`. When no rows are missing the split is scored
/// once with `MissingBranchDirection::Node`; otherwise both `Left` and
/// `Right` assignments are evaluated and the higher-scoring one is returned.
/// Returns `None` when every evaluated assignment violates
/// `min_samples_leaf` on either side.
#[allow(clippy::too_many_arguments)]
fn evaluate_regression_missing_assignment(
    table: &dyn TableAccess,
    targets: &[f64],
    feature_index: usize,
    rows: &[usize],
    criterion: Criterion,
    min_samples_leaf: usize,
    threshold_bin: u16,
    parent_loss: f64,
) -> Option<RegressionSplitCandidate> {
    let mut left_rows = Vec::new();
    let mut right_rows = Vec::new();
    let mut missing_rows = Vec::new();

    // Partition once; the per-direction evaluation below only re-homes the
    // missing bucket.
    for row_idx in rows.iter().copied() {
        if table.is_missing(feature_index, row_idx) {
            missing_rows.push(row_idx);
        } else if table.is_binary_binned_feature(feature_index) {
            if !table
                .binned_boolean_value(feature_index, row_idx)
                .expect("observed binary feature must expose boolean values")
            {
                left_rows.push(row_idx);
            } else {
                right_rows.push(row_idx);
            }
        } else if table.binned_value(feature_index, row_idx) <= threshold_bin {
            left_rows.push(row_idx);
        } else {
            right_rows.push(row_idx);
        }
    }

    // Scores one assignment of the missing rows; yields `None` when either
    // resulting child would be smaller than `min_samples_leaf`.
    let evaluate = |direction: MissingBranchDirection| {
        let mut candidate_left = left_rows.clone();
        let mut candidate_right = right_rows.clone();
        match direction {
            MissingBranchDirection::Left => candidate_left.extend(missing_rows.iter().copied()),
            MissingBranchDirection::Right => candidate_right.extend(missing_rows.iter().copied()),
            MissingBranchDirection::Node => {}
        }
        if candidate_left.len() < min_samples_leaf || candidate_right.len() < min_samples_leaf {
            return None;
        }

        // Gain = parent loss minus the summed losses of the two children.
        let score = parent_loss
            - (regression_loss(&candidate_left, targets, criterion)
                + regression_loss(&candidate_right, targets, criterion));
        Some(RegressionSplitCandidate {
            feature_index,
            threshold_bin,
            score,
            missing_direction: direction,
        })
    };

    if missing_rows.is_empty() {
        evaluate(MissingBranchDirection::Node)
    } else {
        // `max_by` returns the last maximum on ties, so `Right` wins a tie.
        [MissingBranchDirection::Left, MissingBranchDirection::Right]
            .into_iter()
            .filter_map(evaluate)
            .max_by(|left, right| left.score.total_cmp(&right.score))
    }
}
2092
2093fn score_numeric_oblivious_split_mean_fast(
2094    table: &dyn TableAccess,
2095    row_indices: &[usize],
2096    targets: &[f64],
2097    feature_index: usize,
2098    leaves: &[ObliviousLeafState],
2099    min_samples_leaf: usize,
2100) -> Option<ObliviousSplitCandidate> {
2101    let bin_cap = table.numeric_bin_cap();
2102    if bin_cap == 0 {
2103        return None;
2104    }
2105    let mut threshold_scores = vec![0.0; bin_cap];
2106    let mut observed_any = false;
2107
2108    for leaf in leaves {
2109        let mut bin_count = vec![0usize; bin_cap];
2110        let mut bin_sum = vec![0.0; bin_cap];
2111        let mut bin_sum_sq = vec![0.0; bin_cap];
2112        let mut observed_bins = vec![false; bin_cap];
2113
2114        for row_idx in &row_indices[leaf.start..leaf.end] {
2115            let bin = table.binned_value(feature_index, *row_idx) as usize;
2116            if bin >= bin_cap {
2117                return None;
2118            }
2119            let target = targets[*row_idx];
2120            bin_count[bin] += 1;
2121            bin_sum[bin] += target;
2122            bin_sum_sq[bin] += target * target;
2123            observed_bins[bin] = true;
2124        }
2125
2126        let observed_bins: Vec<usize> = observed_bins
2127            .into_iter()
2128            .enumerate()
2129            .filter_map(|(bin, seen)| seen.then_some(bin))
2130            .collect();
2131        if observed_bins.len() <= 1 {
2132            continue;
2133        }
2134        observed_any = true;
2135
2136        let parent_loss = leaf.sum_sq - (leaf.sum * leaf.sum) / leaf.len() as f64;
2137        let mut left_count = 0usize;
2138        let mut left_sum = 0.0;
2139        let mut left_sum_sq = 0.0;
2140
2141        for &bin in &observed_bins {
2142            left_count += bin_count[bin];
2143            left_sum += bin_sum[bin];
2144            left_sum_sq += bin_sum_sq[bin];
2145            let right_count = leaf.len() - left_count;
2146
2147            if left_count < min_samples_leaf || right_count < min_samples_leaf {
2148                continue;
2149            }
2150
2151            let right_sum = leaf.sum - left_sum;
2152            let right_sum_sq = leaf.sum_sq - left_sum_sq;
2153            let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
2154            let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
2155            threshold_scores[bin] += parent_loss - (left_loss + right_loss);
2156        }
2157    }
2158
2159    if !observed_any {
2160        return None;
2161    }
2162
2163    threshold_scores
2164        .into_iter()
2165        .enumerate()
2166        .filter(|(_, score)| *score > 0.0)
2167        .max_by(|left, right| left.1.total_cmp(&right.1))
2168        .map(|(threshold_bin, score)| ObliviousSplitCandidate {
2169            feature_index,
2170            threshold_bin: threshold_bin as u16,
2171            score,
2172        })
2173}
2174
2175fn score_binary_oblivious_split(
2176    table: &dyn TableAccess,
2177    row_indices: &[usize],
2178    targets: &[f64],
2179    feature_index: usize,
2180    leaves: &[ObliviousLeafState],
2181    criterion: Criterion,
2182    min_samples_leaf: usize,
2183) -> Option<ObliviousSplitCandidate> {
2184    let score = leaves.iter().fold(0.0, |score, leaf| {
2185        let leaf_rows = &row_indices[leaf.start..leaf.end];
2186        let (left_rows, right_rows): (Vec<usize>, Vec<usize>) =
2187            leaf_rows.iter().copied().partition(|row_idx| {
2188                !table
2189                    .binned_boolean_value(feature_index, *row_idx)
2190                    .expect("binary feature must expose boolean values")
2191            });
2192
2193        if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
2194            return score;
2195        }
2196
2197        score + regression_loss(leaf_rows, targets, criterion)
2198            - (regression_loss(&left_rows, targets, criterion)
2199                + regression_loss(&right_rows, targets, criterion))
2200    });
2201
2202    (score > 0.0).then_some(ObliviousSplitCandidate {
2203        feature_index,
2204        threshold_bin: 0,
2205        score,
2206    })
2207}
2208
2209fn score_binary_oblivious_split_mean_fast(
2210    table: &dyn TableAccess,
2211    row_indices: &[usize],
2212    targets: &[f64],
2213    feature_index: usize,
2214    leaves: &[ObliviousLeafState],
2215    min_samples_leaf: usize,
2216) -> Option<ObliviousSplitCandidate> {
2217    let mut score = 0.0;
2218    let mut found_valid = false;
2219
2220    for leaf in leaves {
2221        let mut left_count = 0usize;
2222        let mut left_sum = 0.0;
2223        let mut left_sum_sq = 0.0;
2224
2225        for row_idx in &row_indices[leaf.start..leaf.end] {
2226            if !table
2227                .binned_boolean_value(feature_index, *row_idx)
2228                .expect("binary feature must expose boolean values")
2229            {
2230                let target = targets[*row_idx];
2231                left_count += 1;
2232                left_sum += target;
2233                left_sum_sq += target * target;
2234            }
2235        }
2236
2237        let right_count = leaf.len() - left_count;
2238        if left_count < min_samples_leaf || right_count < min_samples_leaf {
2239            continue;
2240        }
2241
2242        found_valid = true;
2243        let parent_loss = leaf.sum_sq - (leaf.sum * leaf.sum) / leaf.len() as f64;
2244        let right_sum = leaf.sum - left_sum;
2245        let right_sum_sq = leaf.sum_sq - left_sum_sq;
2246        let left_loss = left_sum_sq - (left_sum * left_sum) / left_count as f64;
2247        let right_loss = right_sum_sq - (right_sum * right_sum) / right_count as f64;
2248        score += parent_loss - (left_loss + right_loss);
2249    }
2250
2251    (found_valid && score > 0.0).then_some(ObliviousSplitCandidate {
2252        feature_index,
2253        threshold_bin: 0,
2254        score,
2255    })
2256}
2257
/// Accumulates the target sum and sum of squares over the given rows.
fn sum_stats(rows: &[usize], targets: &[f64]) -> (f64, f64) {
    let mut sum = 0.0;
    let mut sum_sq = 0.0;
    for &row_idx in rows {
        let value = targets[row_idx];
        sum += value;
        sum_sq += value * value;
    }
    (sum, sum_sq)
}
2264
/// Population variance from aggregate statistics, or `None` for `count == 0`.
///
/// Computed as `E[x^2] - E[x]^2`. The result is clamped at zero because
/// catastrophic cancellation between the two terms can otherwise yield a
/// tiny negative value for (near-)constant targets, and a negative variance
/// is never meaningful.
fn variance_from_stats(count: usize, sum: f64, sum_sq: f64) -> Option<f64> {
    if count == 0 {
        return None;
    }
    let n = count as f64;
    let variance = (sum_sq / n) - (sum / n).powi(2);
    // Clamp floating-point noise: variance is mathematically non-negative.
    Some(variance.max(0.0))
}
2272
2273fn leaf_regression_loss(
2274    leaf: &ObliviousLeafState,
2275    row_indices: &[usize],
2276    targets: &[f64],
2277    criterion: Criterion,
2278) -> f64 {
2279    match criterion {
2280        Criterion::Mean => leaf.sum_sq - (leaf.sum * leaf.sum) / leaf.len() as f64,
2281        Criterion::Median => {
2282            regression_loss(&row_indices[leaf.start..leaf.end], targets, criterion)
2283        }
2284        _ => unreachable!("regression criterion only supports mean or median"),
2285    }
2286}
2287
/// Returns `true` when every row shares the same target value.
/// Vacuously true for an empty row set.
fn has_constant_target(rows: &[usize], targets: &[f64]) -> bool {
    let Some(&first_row) = rows.first() else {
        return true;
    };
    let reference = targets[first_row];
    rows.iter().all(|&row_idx| targets[row_idx] == reference)
}
2294
2295fn push_leaf(
2296    nodes: &mut Vec<RegressionNode>,
2297    value: f64,
2298    sample_count: usize,
2299    variance: Option<f64>,
2300) -> usize {
2301    push_node(
2302        nodes,
2303        RegressionNode::Leaf {
2304            value,
2305            sample_count,
2306            variance,
2307        },
2308    )
2309}
2310
2311fn push_node(nodes: &mut Vec<RegressionNode>, node: RegressionNode) -> usize {
2312    nodes.push(node);
2313    nodes.len() - 1
2314}
2315
2316#[cfg(test)]
2317mod tests {
2318    use super::*;
2319    use crate::{FeaturePreprocessing, Model, NumericBinBoundary};
2320    use forestfire_data::{DenseTable, NumericBins};
2321
2322    fn quadratic_table() -> DenseTable {
2323        DenseTable::with_options(
2324            vec![
2325                vec![0.0],
2326                vec![1.0],
2327                vec![2.0],
2328                vec![3.0],
2329                vec![4.0],
2330                vec![5.0],
2331                vec![6.0],
2332                vec![7.0],
2333            ],
2334            vec![0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0],
2335            0,
2336            NumericBins::Fixed(128),
2337        )
2338        .unwrap()
2339    }
2340
2341    fn canary_target_table() -> DenseTable {
2342        let x: Vec<Vec<f64>> = (0..8).map(|value| vec![value as f64]).collect();
2343        let probe =
2344            DenseTable::with_options(x.clone(), vec![0.0; 8], 1, NumericBins::Auto).unwrap();
2345        let canary_index = probe.n_features();
2346        let y = (0..probe.n_rows())
2347            .map(|row_idx| probe.binned_value(canary_index, row_idx) as f64)
2348            .collect();
2349
2350        DenseTable::with_options(x, y, 1, NumericBins::Auto).unwrap()
2351    }
2352
2353    fn randomized_permutation_table() -> DenseTable {
2354        DenseTable::with_options(
2355            vec![
2356                vec![0.0, 0.0, 0.0],
2357                vec![0.0, 0.0, 1.0],
2358                vec![0.0, 1.0, 0.0],
2359                vec![0.0, 1.0, 1.0],
2360                vec![1.0, 0.0, 0.0],
2361                vec![1.0, 0.0, 1.0],
2362                vec![1.0, 1.0, 0.0],
2363                vec![1.0, 1.0, 1.0],
2364                vec![0.0, 0.0, 2.0],
2365                vec![0.0, 1.0, 2.0],
2366                vec![1.0, 0.0, 2.0],
2367                vec![1.0, 1.0, 2.0],
2368            ],
2369            vec![0.0, 1.0, 2.5, 3.5, 4.0, 5.0, 6.5, 7.5, 2.0, 4.5, 6.0, 8.5],
2370            0,
2371            NumericBins::Fixed(8),
2372        )
2373        .unwrap()
2374    }
2375
2376    #[test]
2377    fn cart_regressor_fits_basic_numeric_pattern() {
2378        let table = quadratic_table();
2379        let model = train_cart_regressor(&table).unwrap();
2380        let preds = model.predict_table(&table);
2381
2382        assert_eq!(model.algorithm(), RegressionTreeAlgorithm::Cart);
2383        assert_eq!(model.criterion(), Criterion::Mean);
2384        assert_eq!(preds, table_targets(&table));
2385    }
2386
2387    #[test]
2388    fn randomized_regressor_fits_basic_numeric_pattern() {
2389        let table = quadratic_table();
2390        let model = train_randomized_regressor(&table).unwrap();
2391        let preds = model.predict_table(&table);
2392        let targets = table_targets(&table);
2393        let baseline_mean = targets.iter().sum::<f64>() / targets.len() as f64;
2394        let baseline_sse = targets
2395            .iter()
2396            .map(|target| {
2397                let diff = target - baseline_mean;
2398                diff * diff
2399            })
2400            .sum::<f64>();
2401        let model_sse = preds
2402            .iter()
2403            .zip(targets.iter())
2404            .map(|(pred, target)| {
2405                let diff = pred - target;
2406                diff * diff
2407            })
2408            .sum::<f64>();
2409
2410        assert_eq!(model.algorithm(), RegressionTreeAlgorithm::Randomized);
2411        assert_eq!(model.criterion(), Criterion::Mean);
2412        assert!(model_sse < baseline_sse);
2413    }
2414
2415    #[test]
2416    fn randomized_regressor_is_repeatable_for_fixed_seed_and_varies_across_seeds() {
2417        let table = randomized_permutation_table();
2418        let make_options = |random_seed| RegressionTreeOptions {
2419            max_depth: 4,
2420            max_features: Some(2),
2421            random_seed,
2422            ..RegressionTreeOptions::default()
2423        };
2424
2425        let base_model = train_randomized_regressor_with_criterion_parallelism_and_options(
2426            &table,
2427            Criterion::Mean,
2428            Parallelism::sequential(),
2429            make_options(91),
2430        )
2431        .unwrap();
2432        let repeated_model = train_randomized_regressor_with_criterion_parallelism_and_options(
2433            &table,
2434            Criterion::Mean,
2435            Parallelism::sequential(),
2436            make_options(91),
2437        )
2438        .unwrap();
2439        let unique_serializations = (0..32u64)
2440            .map(|seed| {
2441                Model::DecisionTreeRegressor(
2442                    train_randomized_regressor_with_criterion_parallelism_and_options(
2443                        &table,
2444                        Criterion::Mean,
2445                        Parallelism::sequential(),
2446                        make_options(seed),
2447                    )
2448                    .unwrap(),
2449                )
2450                .serialize()
2451                .unwrap()
2452            })
2453            .collect::<std::collections::BTreeSet<_>>();
2454
2455        assert_eq!(
2456            Model::DecisionTreeRegressor(base_model.clone())
2457                .serialize()
2458                .unwrap(),
2459            Model::DecisionTreeRegressor(repeated_model)
2460                .serialize()
2461                .unwrap()
2462        );
2463        assert!(unique_serializations.len() >= 4);
2464    }
2465
2466    #[test]
2467    fn oblivious_regressor_fits_basic_numeric_pattern() {
2468        let table = quadratic_table();
2469        let model = train_oblivious_regressor(&table).unwrap();
2470        let preds = model.predict_table(&table);
2471
2472        assert_eq!(model.algorithm(), RegressionTreeAlgorithm::Oblivious);
2473        assert_eq!(model.criterion(), Criterion::Mean);
2474        assert_eq!(preds, table_targets(&table));
2475    }
2476
2477    #[test]
2478    fn regressors_can_choose_between_mean_and_median() {
2479        let table = DenseTable::with_canaries(
2480            vec![vec![0.0], vec![0.0], vec![0.0]],
2481            vec![0.0, 0.0, 100.0],
2482            0,
2483        )
2484        .unwrap();
2485
2486        let mean_model = train_cart_regressor_with_criterion(&table, Criterion::Mean).unwrap();
2487        let median_model = train_cart_regressor_with_criterion(&table, Criterion::Median).unwrap();
2488
2489        assert_eq!(mean_model.criterion(), Criterion::Mean);
2490        assert_eq!(median_model.criterion(), Criterion::Median);
2491        assert_eq!(
2492            mean_model.predict_table(&table),
2493            vec![100.0 / 3.0, 100.0 / 3.0, 100.0 / 3.0]
2494        );
2495        assert_eq!(median_model.predict_table(&table), vec![0.0, 0.0, 0.0]);
2496    }
2497
2498    #[test]
2499    fn rejects_non_finite_targets() {
2500        let table = DenseTable::new(vec![vec![0.0], vec![1.0]], vec![0.0, f64::NAN]).unwrap();
2501
2502        let err = train_cart_regressor(&table).unwrap_err();
2503        assert!(matches!(
2504            err,
2505            RegressionTreeError::InvalidTargetValue { row: 1, value } if value.is_nan()
2506        ));
2507    }
2508
2509    #[test]
2510    fn stops_cart_regressor_growth_when_a_canary_wins() {
2511        let table = canary_target_table();
2512        let model = train_cart_regressor(&table).unwrap();
2513        let preds = model.predict_table(&table);
2514
2515        assert!(preds.iter().all(|pred| *pred == preds[0]));
2516        assert_ne!(preds, table_targets(&table));
2517    }
2518
2519    #[test]
2520    fn stops_oblivious_regressor_growth_when_a_canary_wins() {
2521        let table = canary_target_table();
2522        let model = train_oblivious_regressor(&table).unwrap();
2523        let preds = model.predict_table(&table);
2524
2525        assert!(preds.iter().all(|pred| *pred == preds[0]));
2526        assert_ne!(preds, table_targets(&table));
2527    }
2528
2529    #[test]
2530    fn manually_built_regressor_models_serialize_for_each_tree_type() {
2531        let preprocessing = vec![
2532            FeaturePreprocessing::Binary,
2533            FeaturePreprocessing::Numeric {
2534                bin_boundaries: vec![
2535                    NumericBinBoundary {
2536                        bin: 0,
2537                        upper_bound: 1.0,
2538                    },
2539                    NumericBinBoundary {
2540                        bin: 127,
2541                        upper_bound: 10.0,
2542                    },
2543                ],
2544                missing_bin: 128,
2545            },
2546        ];
2547        let options = RegressionTreeOptions::default();
2548
2549        let cart = Model::DecisionTreeRegressor(DecisionTreeRegressor {
2550            algorithm: RegressionTreeAlgorithm::Cart,
2551            criterion: Criterion::Mean,
2552            structure: RegressionTreeStructure::Standard {
2553                nodes: vec![
2554                    RegressionNode::Leaf {
2555                        value: -1.0,
2556                        sample_count: 2,
2557                        variance: Some(0.25),
2558                    },
2559                    RegressionNode::Leaf {
2560                        value: 2.5,
2561                        sample_count: 3,
2562                        variance: Some(1.0),
2563                    },
2564                    RegressionNode::BinarySplit {
2565                        feature_index: 0,
2566                        threshold_bin: 0,
2567                        missing_direction: crate::tree::shared::MissingBranchDirection::Node,
2568                        missing_value: -1.0,
2569                        left_child: 0,
2570                        right_child: 1,
2571                        sample_count: 5,
2572                        impurity: 3.5,
2573                        gain: 1.25,
2574                        variance: Some(0.7),
2575                    },
2576                ],
2577                root: 2,
2578            },
2579            options: options.clone(),
2580            num_features: 2,
2581            feature_preprocessing: preprocessing.clone(),
2582            training_canaries: 0,
2583        });
2584        let randomized = Model::DecisionTreeRegressor(DecisionTreeRegressor {
2585            algorithm: RegressionTreeAlgorithm::Randomized,
2586            criterion: Criterion::Median,
2587            structure: RegressionTreeStructure::Standard {
2588                nodes: vec![
2589                    RegressionNode::Leaf {
2590                        value: -1.0,
2591                        sample_count: 2,
2592                        variance: Some(0.25),
2593                    },
2594                    RegressionNode::Leaf {
2595                        value: 2.5,
2596                        sample_count: 3,
2597                        variance: Some(1.0),
2598                    },
2599                    RegressionNode::BinarySplit {
2600                        feature_index: 0,
2601                        threshold_bin: 0,
2602                        missing_direction: crate::tree::shared::MissingBranchDirection::Node,
2603                        missing_value: -1.0,
2604                        left_child: 0,
2605                        right_child: 1,
2606                        sample_count: 5,
2607                        impurity: 3.5,
2608                        gain: 0.8,
2609                        variance: Some(0.7),
2610                    },
2611                ],
2612                root: 2,
2613            },
2614            options: options.clone(),
2615            num_features: 2,
2616            feature_preprocessing: preprocessing.clone(),
2617            training_canaries: 0,
2618        });
2619        let oblivious = Model::DecisionTreeRegressor(DecisionTreeRegressor {
2620            algorithm: RegressionTreeAlgorithm::Oblivious,
2621            criterion: Criterion::Median,
2622            structure: RegressionTreeStructure::Oblivious {
2623                splits: vec![ObliviousSplit {
2624                    feature_index: 1,
2625                    threshold_bin: 127,
2626                    sample_count: 4,
2627                    impurity: 2.0,
2628                    gain: 0.5,
2629                }],
2630                leaf_values: vec![0.0, 10.0],
2631                leaf_sample_counts: vec![2, 2],
2632                leaf_variances: vec![Some(0.0), Some(1.0)],
2633            },
2634            options,
2635            num_features: 2,
2636            feature_preprocessing: preprocessing,
2637            training_canaries: 0,
2638        });
2639
2640        for (tree_type, model) in [
2641            ("cart", cart),
2642            ("randomized", randomized),
2643            ("oblivious", oblivious),
2644        ] {
2645            let json = model.serialize().unwrap();
2646            assert!(json.contains(&format!("\"tree_type\":\"{tree_type}\"")));
2647            assert!(json.contains("\"task\":\"regression\""));
2648        }
2649    }
2650
2651    #[test]
2652    fn cart_regressor_assigns_training_missing_values_to_best_child() {
2653        let table = DenseTable::with_canaries(
2654            vec![
2655                vec![0.0],
2656                vec![0.0],
2657                vec![1.0],
2658                vec![1.0],
2659                vec![f64::NAN],
2660                vec![f64::NAN],
2661            ],
2662            vec![0.0, 0.0, 10.0, 10.0, 0.0, 0.0],
2663            0,
2664        )
2665        .unwrap();
2666
2667        let model = train_cart_regressor(&table).unwrap();
2668
2669        let wrapped = Model::DecisionTreeRegressor(model.clone());
2670        assert_eq!(
2671            wrapped.predict_rows(vec![vec![f64::NAN]]).unwrap(),
2672            vec![0.0]
2673        );
2674    }
2675
2676    #[test]
2677    fn cart_regressor_defaults_unseen_missing_to_node_mean() {
2678        let table = DenseTable::with_canaries(
2679            vec![vec![0.0], vec![0.0], vec![1.0]],
2680            vec![0.0, 0.0, 9.0],
2681            0,
2682        )
2683        .unwrap();
2684
2685        let model = train_cart_regressor(&table).unwrap();
2686        let wrapped = Model::DecisionTreeRegressor(model.clone());
2687        let prediction = wrapped.predict_rows(vec![vec![f64::NAN]]).unwrap()[0];
2688
2689        assert!((prediction - 3.0).abs() < 1e-9);
2690    }
2691
2692    fn table_targets(table: &dyn TableAccess) -> Vec<f64> {
2693        (0..table.n_rows())
2694            .map(|row_idx| table.target_value(row_idx))
2695            .collect()
2696    }
2697}