// forestfire_core/tree/classifier.rs — decision-tree classifier training and inference.

1use crate::ir::{
2    BinaryChildren, BinarySplit, IndexedLeaf, LeafIndexing, LeafPayload, MultiwayBranch,
3    MultiwaySplit, NodeStats, NodeTreeNode, ObliviousLevel, ObliviousSplit as IrObliviousSplit,
4    TrainingMetadata, TreeDefinition, criterion_name, feature_name, threshold_upper_bound,
5    tree_type_name,
6};
7use crate::sampling::sample_feature_subset;
8use crate::{Criterion, FeaturePreprocessing, Parallelism, capture_feature_preprocessing};
9use forestfire_data::TableAccess;
10use rand::rngs::StdRng;
11use rand::{Rng, SeedableRng};
12use rayon::prelude::*;
13use std::collections::{BTreeMap, BTreeSet};
14use std::error::Error;
15use std::fmt::{Display, Formatter};
16
/// The tree-construction algorithm used for training.
///
/// Id3 and C4.5 build multiway trees over binned features; CART and
/// Randomized build binary trees; Oblivious builds a tree that applies
/// the same (feature, threshold) split at every node of a level.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecisionTreeAlgorithm {
    Id3,
    C45,
    Cart,
    Randomized,
    Oblivious,
}
25
/// Hyperparameters controlling tree growth.
#[derive(Debug, Clone, Copy)]
pub struct DecisionTreeOptions {
    /// Maximum depth; a node at this depth is turned into a leaf.
    pub max_depth: usize,
    /// Minimum number of rows a node needs before a split is attempted.
    pub min_samples_split: usize,
    /// Minimum rows per child, enforced while scoring candidate splits.
    pub min_samples_leaf: usize,
    /// Number of candidate features sampled per node; `None` means all.
    pub max_features: Option<usize>,
    /// Seed mixed into the per-node RNG used for feature subsampling.
    pub random_seed: u64,
}
34
35impl Default for DecisionTreeOptions {
36    fn default() -> Self {
37        Self {
38            max_depth: 8,
39            min_samples_split: 2,
40            min_samples_leaf: 1,
41            max_features: None,
42            random_seed: 0,
43        }
44    }
45}
46
/// Errors that can occur while training a classifier.
#[derive(Debug)]
pub enum DecisionTreeError {
    // Training table had zero rows.
    EmptyTarget,
    // A target value was not usable as a class label (e.g. non-finite).
    InvalidTargetValue { row: usize, value: f64 },
}
52
53impl Display for DecisionTreeError {
54    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
55        match self {
56            DecisionTreeError::EmptyTarget => write!(f, "Cannot train on an empty target vector."),
57            DecisionTreeError::InvalidTargetValue { row, value } => write!(
58                f,
59                "Classification targets must be finite values. Found {} at row {}.",
60                value, row
61            ),
62        }
63    }
64}
65
66impl Error for DecisionTreeError {}
67
/// A trained decision-tree classifier.
///
/// Holds the learned tree structure plus everything needed to map a leaf
/// back to an original class label and to serialize/deserialize via IR.
#[derive(Debug, Clone)]
pub struct DecisionTreeClassifier {
    algorithm: DecisionTreeAlgorithm,
    criterion: Criterion,
    // Original target values; tree nodes store indices into this vector.
    class_labels: Vec<f64>,
    structure: TreeStructure,
    options: DecisionTreeOptions,
    num_features: usize,
    feature_preprocessing: Vec<FeaturePreprocessing>,
    // Number of canary features present in the training table.
    training_canaries: usize,
}
79
/// Learned tree layout: either a node arena or per-level oblivious splits.
#[derive(Debug, Clone)]
pub(crate) enum TreeStructure {
    // Arena of nodes addressed by index, entered at `root`.
    Standard {
        nodes: Vec<TreeNode>,
        root: usize,
    },
    // One split per level; leaves are addressed by the bit pattern of
    // the per-level comparisons (first split = most significant bit).
    Oblivious {
        splits: Vec<ObliviousSplit>,
        leaf_class_indices: Vec<usize>,
        leaf_sample_counts: Vec<usize>,
        leaf_class_counts: Vec<Vec<usize>>,
    },
}
93
/// The single (feature, threshold) test applied at one level of an
/// oblivious tree, plus training statistics for that level.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ObliviousSplit {
    pub(crate) feature_index: usize,
    // Rows with binned value > threshold_bin go right.
    pub(crate) threshold_bin: u16,
    pub(crate) sample_count: usize,
    pub(crate) impurity: f64,
    pub(crate) gain: f64,
}
102
/// A node in a `TreeStructure::Standard` arena.
#[derive(Debug, Clone)]
pub(crate) enum TreeNode {
    // Terminal node predicting `class_index` (into the classifier's
    // class_labels), with per-class training counts for probabilities.
    Leaf {
        class_index: usize,
        sample_count: usize,
        class_counts: Vec<usize>,
    },
    // Id3/C4.5 node: one child per observed bin value; rows whose bin
    // matches no branch fall back to `fallback_class_index`.
    MultiwaySplit {
        feature_index: usize,
        fallback_class_index: usize,
        // (bin value, child node index) pairs.
        branches: Vec<(u16, usize)>,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        class_counts: Vec<usize>,
    },
    // CART/Randomized node: bin <= threshold_bin goes left, else right.
    BinarySplit {
        feature_index: usize,
        threshold_bin: u16,
        left_child: usize,
        right_child: usize,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        class_counts: Vec<usize>,
    },
}
130
/// A scored candidate split carrying the row partition it would induce.
// NOTE(review): marked dead_code — usage is not visible in this chunk;
// confirm whether this type is still needed.
#[derive(Debug, Clone)]
#[allow(dead_code)]
enum SplitCandidate {
    Multiway {
        feature_index: usize,
        score: f64,
        // (bin value, rows routed to that bin's branch).
        branches: Vec<(u16, Vec<usize>)>,
    },
    Binary {
        feature_index: usize,
        score: f64,
        threshold_bin: u16,
        left_rows: Vec<usize>,
        right_rows: Vec<usize>,
    },
}
147
/// Best binary split found for a node: feature, score, and threshold only
/// (the row partition is recomputed in place afterwards).
#[derive(Debug, Clone, Copy)]
struct BinarySplitChoice {
    feature_index: usize,
    score: f64,
    threshold_bin: u16,
}

/// Best multiway split found for a node: feature, score, and the list of
/// bin values that receive their own branch.
#[derive(Debug, Clone)]
struct MultiwaySplitChoice {
    feature_index: usize,
    score: f64,
    branch_bins: Vec<u16>,
}
161
/// Per-feature class-count histogram for one node, used to score splits
/// without rescanning rows.
#[derive(Debug, Clone)]
enum ClassificationFeatureHistogram {
    // Two-bucket histogram — presumably for boolean/two-bin features;
    // one class-count vector and row total per side.
    Binary {
        false_counts: Vec<usize>,
        true_counts: Vec<usize>,
        false_size: usize,
        true_size: usize,
    },
    // Per-bin class counts for binned numeric features; `observed_bins`
    // lists the bins actually present at this node.
    Numeric {
        bin_class_counts: Vec<Vec<usize>>,
        observed_bins: Vec<usize>,
    },
}
175
/// Train an ID3 (multiway) tree with the entropy criterion.
pub fn train_id3(train_set: &dyn TableAccess) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion(train_set, Criterion::Entropy)
}

/// Train an ID3 tree with an explicit impurity criterion (sequential).
pub fn train_id3_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Train an ID3 tree with explicit parallelism and default options.
pub(crate) fn train_id3_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized ID3 entry point.
pub(crate) fn train_id3_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Id3,
        criterion,
        parallelism,
        options,
    )
}
214
/// Train a C4.5 (multiway) tree with the entropy criterion.
pub fn train_c45(train_set: &dyn TableAccess) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion(train_set, Criterion::Entropy)
}

/// Train a C4.5 tree with an explicit impurity criterion (sequential).
pub fn train_c45_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Train a C4.5 tree with explicit parallelism and default options.
pub(crate) fn train_c45_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized C4.5 entry point.
pub(crate) fn train_c45_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::C45,
        criterion,
        parallelism,
        options,
    )
}
253
/// Train a CART (binary) tree with the Gini criterion.
pub fn train_cart(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion(train_set, Criterion::Gini)
}

/// Train a CART tree with an explicit impurity criterion (sequential).
pub fn train_cart_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Train a CART tree with explicit parallelism and default options.
pub(crate) fn train_cart_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized CART entry point.
pub(crate) fn train_cart_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Cart,
        criterion,
        parallelism,
        options,
    )
}
294
/// Train an oblivious tree with the Gini criterion.
pub fn train_oblivious(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion(train_set, Criterion::Gini)
}

/// Train an oblivious tree with an explicit criterion (sequential).
pub fn train_oblivious_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Train an oblivious tree with explicit parallelism and default options.
pub(crate) fn train_oblivious_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized oblivious-tree entry point.
pub(crate) fn train_oblivious_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Oblivious,
        criterion,
        parallelism,
        options,
    )
}
335
/// Train a randomized (binary) tree with the Gini criterion.
pub fn train_randomized(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion(train_set, Criterion::Gini)
}

/// Train a randomized tree with an explicit criterion (sequential).
pub fn train_randomized_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Train a randomized tree with explicit parallelism and default options.
pub(crate) fn train_randomized_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized randomized-tree entry point.
pub(crate) fn train_randomized_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Randomized,
        criterion,
        parallelism,
        options,
    )
}
376
377fn train_classifier(
378    train_set: &dyn TableAccess,
379    algorithm: DecisionTreeAlgorithm,
380    criterion: Criterion,
381    parallelism: Parallelism,
382    options: DecisionTreeOptions,
383) -> Result<DecisionTreeClassifier, DecisionTreeError> {
384    if train_set.n_rows() == 0 {
385        return Err(DecisionTreeError::EmptyTarget);
386    }
387
388    let (class_labels, class_indices) = encode_class_labels(train_set)?;
389    let structure = match algorithm {
390        DecisionTreeAlgorithm::Oblivious => train_oblivious_structure(
391            train_set,
392            &class_indices,
393            &class_labels,
394            criterion,
395            parallelism,
396            options,
397        ),
398        DecisionTreeAlgorithm::Cart | DecisionTreeAlgorithm::Randomized => {
399            let mut nodes = Vec::new();
400            let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
401            let context = BuildContext {
402                table: train_set,
403                class_indices: &class_indices,
404                class_labels: &class_labels,
405                algorithm,
406                criterion,
407                parallelism,
408                options,
409            };
410            let root = build_binary_node_in_place(&context, &mut nodes, &mut all_rows, 0);
411            TreeStructure::Standard { nodes, root }
412        }
413        DecisionTreeAlgorithm::Id3 | DecisionTreeAlgorithm::C45 => {
414            let mut nodes = Vec::new();
415            let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
416            let context = BuildContext {
417                table: train_set,
418                class_indices: &class_indices,
419                class_labels: &class_labels,
420                algorithm,
421                criterion,
422                parallelism,
423                options,
424            };
425            let root = build_multiway_node_in_place(&context, &mut nodes, &mut all_rows, 0);
426            TreeStructure::Standard { nodes, root }
427        }
428    };
429
430    Ok(DecisionTreeClassifier {
431        algorithm,
432        criterion,
433        class_labels,
434        structure,
435        options,
436        num_features: train_set.n_features(),
437        feature_preprocessing: capture_feature_preprocessing(train_set),
438        training_canaries: train_set.canaries(),
439    })
440}
441
442impl DecisionTreeClassifier {
    /// The algorithm this tree was trained with.
    pub fn algorithm(&self) -> DecisionTreeAlgorithm {
        self.algorithm
    }

    /// The impurity criterion used during training.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }

    /// Predict a class label for every row of `table`, in row order.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        (0..table.n_rows())
            .map(|row_idx| self.predict_row(table, row_idx))
            .collect()
    }

    /// Predict per-class probability vectors for every row of `table`.
    pub fn predict_proba_table(&self, table: &dyn TableAccess) -> Vec<Vec<f64>> {
        (0..table.n_rows())
            .map(|row_idx| self.predict_proba_row(table, row_idx))
            .collect()
    }
462
    /// Predict the class label for a single row.
    ///
    /// Standard trees are walked iteratively from the root; oblivious
    /// trees compute a leaf index directly from the per-level splits.
    fn predict_row(&self, table: &dyn TableAccess, row_idx: usize) -> f64 {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;

                loop {
                    match &nodes[node_index] {
                        TreeNode::Leaf { class_index, .. } => {
                            return self.class_labels[*class_index];
                        }
                        TreeNode::MultiwaySplit {
                            feature_index,
                            fallback_class_index,
                            branches,
                            ..
                        } => {
                            // Follow the branch whose bin matches; a bin
                            // never seen in training falls back to the
                            // node's majority class.
                            let bin = table.binned_value(*feature_index, row_idx);
                            if let Some((_, child_index)) =
                                branches.iter().find(|(branch_bin, _)| *branch_bin == bin)
                            {
                                node_index = *child_index;
                            } else {
                                return self.class_labels[*fallback_class_index];
                            }
                        }
                        TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            left_child,
                            right_child,
                            ..
                        } => {
                            // bin <= threshold goes left, else right.
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_indices,
                ..
            } => {
                // Accumulate one bit per level, first split as the most
                // significant bit (matches LeafIndexing "msb_first").
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                self.class_labels[leaf_class_indices[leaf_index]]
            }
        }
    }
520
    /// Predict per-class probabilities for a single row.
    ///
    /// Same traversal as `predict_row`, but the reached leaf's (or, for
    /// an unmatched multiway bin, the current node's) training class
    /// counts are normalized into probabilities.
    fn predict_proba_row(&self, table: &dyn TableAccess, row_idx: usize) -> Vec<f64> {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;

                loop {
                    match &nodes[node_index] {
                        TreeNode::Leaf { class_counts, .. } => {
                            return normalized_class_probabilities(class_counts);
                        }
                        TreeNode::MultiwaySplit {
                            feature_index,
                            branches,
                            class_counts,
                            ..
                        } => {
                            let bin = table.binned_value(*feature_index, row_idx);
                            if let Some((_, child_index)) =
                                branches.iter().find(|(branch_bin, _)| *branch_bin == bin)
                            {
                                node_index = *child_index;
                            } else {
                                // Unseen bin: use this node's own counts.
                                return normalized_class_probabilities(class_counts);
                            }
                        }
                        TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            left_child,
                            right_child,
                            ..
                        } => {
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_counts,
                ..
            } => {
                // Same msb-first leaf indexing as predict_row.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                normalized_class_probabilities(&leaf_class_counts[leaf_index])
            }
        }
    }
578
    /// Distinct class labels; node class indices refer into this slice.
    pub(crate) fn class_labels(&self) -> &[f64] {
        &self.class_labels
    }

    /// The learned tree structure.
    pub(crate) fn structure(&self) -> &TreeStructure {
        &self.structure
    }

    /// Number of features in the training table.
    pub(crate) fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-feature preprocessing captured at training time.
    pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }
594
    /// Build the IR training-metadata record for serialization.
    ///
    /// Fields that only apply to ensembles/boosting (n_trees,
    /// learning_rate, bootstrap, gradient fractions, oob) are left unset.
    pub(crate) fn training_metadata(&self) -> TrainingMetadata {
        TrainingMetadata {
            algorithm: "dt".to_string(),
            task: "classification".to_string(),
            tree_type: tree_type_name(match self.algorithm {
                DecisionTreeAlgorithm::Id3 => crate::TreeType::Id3,
                DecisionTreeAlgorithm::C45 => crate::TreeType::C45,
                DecisionTreeAlgorithm::Cart => crate::TreeType::Cart,
                DecisionTreeAlgorithm::Randomized => crate::TreeType::Randomized,
                DecisionTreeAlgorithm::Oblivious => crate::TreeType::Oblivious,
            })
            .to_string(),
            criterion: criterion_name(self.criterion).to_string(),
            canaries: self.training_canaries,
            compute_oob: false,
            max_depth: Some(self.options.max_depth),
            min_samples_split: Some(self.options.min_samples_split),
            min_samples_leaf: Some(self.options.min_samples_leaf),
            n_trees: None,
            max_features: self.options.max_features,
            // NOTE(review): options.random_seed exists but is exported as
            // None — confirm this omission is intentional.
            seed: None,
            oob_score: None,
            class_labels: Some(self.class_labels.clone()),
            learning_rate: None,
            bootstrap: None,
            top_gradient_fraction: None,
            other_gradient_fraction: None,
        }
    }
624
    /// Convert the learned structure into its IR `TreeDefinition`.
    ///
    /// Standard trees become a `NodeTree` (node ids are arena indices);
    /// oblivious trees become `ObliviousLevels` with msb-first leaf
    /// indexing matching the traversal in `predict_row`.
    pub(crate) fn to_ir_tree(&self) -> TreeDefinition {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                // Depths are recomputed from the arena for the IR record.
                let depths = standard_node_depths(nodes, *root);
                TreeDefinition::NodeTree {
                    tree_id: 0,
                    weight: 1.0,
                    root_node_id: *root,
                    nodes: nodes
                        .iter()
                        .enumerate()
                        .map(|(node_id, node)| match node {
                            TreeNode::Leaf {
                                class_index,
                                sample_count,
                                class_counts,
                            } => NodeTreeNode::Leaf {
                                node_id,
                                depth: depths[node_id],
                                leaf: self.class_leaf(*class_index),
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: None,
                                    gain: None,
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                            TreeNode::BinarySplit {
                                feature_index,
                                threshold_bin,
                                left_child,
                                right_child,
                                sample_count,
                                impurity,
                                gain,
                                class_counts,
                            } => NodeTreeNode::BinaryBranch {
                                node_id,
                                depth: depths[node_id],
                                split: binary_split_ir(
                                    *feature_index,
                                    *threshold_bin,
                                    &self.feature_preprocessing,
                                ),
                                children: BinaryChildren {
                                    left: *left_child,
                                    right: *right_child,
                                },
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                            TreeNode::MultiwaySplit {
                                feature_index,
                                fallback_class_index,
                                branches,
                                sample_count,
                                impurity,
                                gain,
                                class_counts,
                            } => NodeTreeNode::MultiwayBranch {
                                node_id,
                                depth: depths[node_id],
                                split: MultiwaySplit {
                                    split_type: "binned_value_multiway".to_string(),
                                    feature_index: *feature_index,
                                    feature_name: feature_name(*feature_index),
                                    comparison_dtype: "uint16".to_string(),
                                },
                                branches: branches
                                    .iter()
                                    .map(|(bin, child)| MultiwayBranch {
                                        bin: *bin,
                                        child: *child,
                                    })
                                    .collect(),
                                // Majority-class fallback for unseen bins.
                                unmatched_leaf: self.class_leaf(*fallback_class_index),
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                        })
                        .collect(),
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_indices,
                leaf_sample_counts,
                leaf_class_counts,
            } => TreeDefinition::ObliviousLevels {
                tree_id: 0,
                weight: 1.0,
                depth: splits.len(),
                levels: splits
                    .iter()
                    .enumerate()
                    .map(|(level, split)| ObliviousLevel {
                        level,
                        split: oblivious_split_ir(
                            split.feature_index,
                            split.threshold_bin,
                            &self.feature_preprocessing,
                        ),
                        stats: NodeStats {
                            sample_count: split.sample_count,
                            impurity: Some(split.impurity),
                            gain: Some(split.gain),
                            class_counts: None,
                            variance: None,
                        },
                    })
                    .collect(),
                // Must agree with the bit accumulation in predict_row.
                leaf_indexing: LeafIndexing {
                    bit_order: "msb_first".to_string(),
                    index_formula: "sum(bit[level] << (depth - 1 - level))".to_string(),
                },
                leaves: leaf_class_indices
                    .iter()
                    .enumerate()
                    .map(|(leaf_index, class_index)| IndexedLeaf {
                        leaf_index,
                        leaf: self.class_leaf(*class_index),
                        stats: NodeStats {
                            sample_count: leaf_sample_counts[leaf_index],
                            impurity: None,
                            gain: None,
                            class_counts: Some(leaf_class_counts[leaf_index].clone()),
                            variance: None,
                        },
                    })
                    .collect(),
            },
        }
    }
769
    /// IR leaf payload pairing a class index with its original label value.
    fn class_leaf(&self, class_index: usize) -> LeafPayload {
        LeafPayload::ClassIndex {
            class_index,
            class_value: self.class_labels[class_index],
        }
    }
776
    /// Reassemble a classifier from IR-deserialized parts.
    ///
    /// Pure field pass-through; no validation is performed here, so
    /// callers are responsible for supplying mutually consistent parts.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn from_ir_parts(
        algorithm: DecisionTreeAlgorithm,
        criterion: Criterion,
        class_labels: Vec<f64>,
        structure: TreeStructure,
        options: DecisionTreeOptions,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
        training_canaries: usize,
    ) -> Self {
        Self {
            algorithm,
            criterion,
            class_labels,
            structure,
            options,
            num_features,
            feature_preprocessing,
            training_canaries,
        }
    }
799}
800
/// Grow a binary subtree for `rows`, appending nodes to the arena and
/// returning the index of the subtree root.
///
/// Convenience wrapper over `build_binary_node_in_place_with_hist` with
/// no precomputed histograms (the root call; children reuse histograms).
fn build_binary_node_in_place(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &mut [usize],
    depth: usize,
) -> usize {
    build_binary_node_in_place_with_hist(context, nodes, rows, depth, None)
}
809
810fn build_binary_node_in_place_with_hist(
811    context: &BuildContext<'_>,
812    nodes: &mut Vec<TreeNode>,
813    rows: &mut [usize],
814    depth: usize,
815    histograms: Option<Vec<ClassificationFeatureHistogram>>,
816) -> usize {
817    let majority_class_index =
818        majority_class(rows, context.class_indices, context.class_labels.len());
819    let current_class_counts =
820        class_counts(rows, context.class_indices, context.class_labels.len());
821
822    if rows.is_empty()
823        || depth >= context.options.max_depth
824        || rows.len() < context.options.min_samples_split
825        || is_pure(rows, context.class_indices)
826    {
827        return push_leaf(
828            nodes,
829            majority_class_index,
830            rows.len(),
831            current_class_counts,
832        );
833    }
834
835    let scoring = SplitScoringContext {
836        table: context.table,
837        class_indices: context.class_indices,
838        num_classes: context.class_labels.len(),
839        criterion: context.criterion,
840        min_samples_leaf: context.options.min_samples_leaf,
841    };
842    let histograms = histograms.unwrap_or_else(|| {
843        build_classification_node_histograms(
844            context.table,
845            context.class_indices,
846            rows,
847            context.class_labels.len(),
848        )
849    });
850    let feature_indices = candidate_feature_indices(
851        context.table.binned_feature_count(),
852        context.options.max_features,
853        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
854    );
855    let best_split = if context.parallelism.enabled() {
856        feature_indices
857            .into_par_iter()
858            .filter_map(|feature_index| {
859                score_binary_split_choice_from_hist(
860                    &scoring,
861                    &histograms[feature_index],
862                    feature_index,
863                    rows,
864                    &current_class_counts,
865                    context.algorithm,
866                )
867            })
868            .max_by(|left, right| left.score.total_cmp(&right.score))
869    } else {
870        feature_indices
871            .into_iter()
872            .filter_map(|feature_index| {
873                score_binary_split_choice_from_hist(
874                    &scoring,
875                    &histograms[feature_index],
876                    feature_index,
877                    rows,
878                    &current_class_counts,
879                    context.algorithm,
880                )
881            })
882            .max_by(|left, right| left.score.total_cmp(&right.score))
883    };
884
885    match best_split {
886        Some(best_split)
887            if context
888                .table
889                .is_canary_binned_feature(best_split.feature_index) =>
890        {
891            push_leaf(
892                nodes,
893                majority_class_index,
894                rows.len(),
895                current_class_counts,
896            )
897        }
898        Some(best_split) if best_split.score > 0.0 => {
899            let impurity =
900                classification_impurity(&current_class_counts, rows.len(), context.criterion);
901            let left_count = partition_rows_for_binary_split(
902                context.table,
903                best_split.feature_index,
904                best_split.threshold_bin,
905                rows,
906            );
907            let (left_rows, right_rows) = rows.split_at_mut(left_count);
908            let (left_histograms, right_histograms) = if left_rows.len() <= right_rows.len() {
909                let left_histograms = build_classification_node_histograms(
910                    context.table,
911                    context.class_indices,
912                    left_rows,
913                    context.class_labels.len(),
914                );
915                let right_histograms =
916                    subtract_classification_node_histograms(&histograms, &left_histograms);
917                (left_histograms, right_histograms)
918            } else {
919                let right_histograms = build_classification_node_histograms(
920                    context.table,
921                    context.class_indices,
922                    right_rows,
923                    context.class_labels.len(),
924                );
925                let left_histograms =
926                    subtract_classification_node_histograms(&histograms, &right_histograms);
927                (left_histograms, right_histograms)
928            };
929            let left_child = build_binary_node_in_place_with_hist(
930                context,
931                nodes,
932                left_rows,
933                depth + 1,
934                Some(left_histograms),
935            );
936            let right_child = build_binary_node_in_place_with_hist(
937                context,
938                nodes,
939                right_rows,
940                depth + 1,
941                Some(right_histograms),
942            );
943
944            push_node(
945                nodes,
946                TreeNode::BinarySplit {
947                    feature_index: best_split.feature_index,
948                    threshold_bin: best_split.threshold_bin,
949                    left_child,
950                    right_child,
951                    sample_count: rows.len(),
952                    impurity,
953                    gain: best_split.score,
954                    class_counts: current_class_counts,
955                },
956            )
957        }
958        _ => push_leaf(
959            nodes,
960            majority_class_index,
961            rows.len(),
962            current_class_counts,
963        ),
964    }
965}
966
/// Recursively grows an ID3/C4.5 multiway subtree over `rows`, appending
/// nodes to `nodes` and returning the index of the new subtree root.
///
/// `rows` is reordered in place so that each branch's rows become
/// contiguous before recursing into that branch.
fn build_multiway_node_in_place(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &mut [usize],
    depth: usize,
) -> usize {
    let majority_class_index =
        majority_class(rows, context.class_indices, context.class_labels.len());
    let current_class_counts =
        class_counts(rows, context.class_indices, context.class_labels.len());

    // Stopping conditions: no rows, depth budget exhausted, node too small
    // to split, or node already pure — emit a majority-class leaf.
    if rows.is_empty()
        || depth >= context.options.max_depth
        || rows.len() < context.options.min_samples_split
        || is_pure(rows, context.class_indices)
    {
        return push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        );
    }

    // ID3 maximizes information gain; C4.5 uses its gain-ratio variant.
    let metric = match context.algorithm {
        DecisionTreeAlgorithm::Id3 => MultiwayMetric::InformationGain,
        DecisionTreeAlgorithm::C45 => MultiwayMetric::GainRatio,
        _ => unreachable!("multiway builder only supports id3/c45"),
    };
    let scoring = SplitScoringContext {
        table: context.table,
        class_indices: context.class_indices,
        num_classes: context.class_labels.len(),
        criterion: context.criterion,
        min_samples_leaf: context.options.min_samples_leaf,
    };
    // Feature subsample is derived deterministically from (seed, depth, rows)
    // so training is reproducible for a fixed random_seed.
    let feature_indices = candidate_feature_indices(
        context.table.binned_feature_count(),
        context.options.max_features,
        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
    );
    // Score every candidate feature, fanning out across rayon threads when
    // parallelism is enabled; keep the highest-scoring split.
    let best_split = if context.parallelism.enabled() {
        feature_indices
            .into_par_iter()
            .filter_map(|feature_index| {
                score_multiway_split_choice(&scoring, feature_index, rows, metric)
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    } else {
        feature_indices
            .into_iter()
            .filter_map(|feature_index| {
                score_multiway_split_choice(&scoring, feature_index, rows, metric)
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    };

    match best_split {
        // A win by a canary feature is treated as a stop signal: emit a leaf
        // instead of splitting.
        Some(best_split)
            if context
                .table
                .is_canary_binned_feature(best_split.feature_index) =>
        {
            push_leaf(
                nodes,
                majority_class_index,
                rows.len(),
                current_class_counts,
            )
        }
        Some(best_split) if best_split.score > 0.0 => {
            let impurity =
                classification_impurity(&current_class_counts, rows.len(), context.criterion);
            // Reorder `rows` so each branch's rows are contiguous; yields
            // (bin, start, end) ranges into the reordered slice.
            let branch_ranges = partition_rows_for_multiway_split(
                context.table,
                best_split.feature_index,
                &best_split.branch_bins,
                rows,
            );
            let mut branch_nodes = Vec::with_capacity(branch_ranges.len());
            for (bin, start, end) in branch_ranges {
                let child =
                    build_multiway_node_in_place(context, nodes, &mut rows[start..end], depth + 1);
                branch_nodes.push((bin, child));
            }

            push_node(
                nodes,
                TreeNode::MultiwaySplit {
                    feature_index: best_split.feature_index,
                    // Used when a bin unseen at this node shows up at inference.
                    fallback_class_index: majority_class_index,
                    branches: branch_nodes,
                    sample_count: rows.len(),
                    impurity,
                    gain: best_split.score,
                    class_counts: current_class_counts,
                },
            )
        }
        // No candidate achieved positive gain: emit a majority-class leaf.
        _ => push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        ),
    }
}
1074
/// Converts per-class sample counts into probabilities that sum to 1.0.
///
/// An all-zero (or empty) count vector yields an all-zero vector of the
/// same length rather than dividing by zero.
fn normalized_class_probabilities(class_counts: &[usize]) -> Vec<f64> {
    let total: usize = class_counts.iter().sum();
    match total {
        0 => class_counts.iter().map(|_| 0.0).collect(),
        _ => {
            let denominator = total as f64;
            class_counts
                .iter()
                .map(|&count| count as f64 / denominator)
                .collect()
        }
    }
}
1086
1087fn standard_node_depths(nodes: &[TreeNode], root: usize) -> Vec<usize> {
1088    let mut depths = vec![0; nodes.len()];
1089    populate_depths(nodes, root, 0, &mut depths);
1090    depths
1091}
1092
1093fn populate_depths(nodes: &[TreeNode], node_id: usize, depth: usize, depths: &mut [usize]) {
1094    depths[node_id] = depth;
1095    match &nodes[node_id] {
1096        TreeNode::Leaf { .. } => {}
1097        TreeNode::BinarySplit {
1098            left_child,
1099            right_child,
1100            ..
1101        } => {
1102            populate_depths(nodes, *left_child, depth + 1, depths);
1103            populate_depths(nodes, *right_child, depth + 1, depths);
1104        }
1105        TreeNode::MultiwaySplit { branches, .. } => {
1106            for (_, child) in branches {
1107                populate_depths(nodes, *child, depth + 1, depths);
1108            }
1109        }
1110    }
1111}
1112
1113fn binary_split_ir(
1114    feature_index: usize,
1115    threshold_bin: u16,
1116    preprocessing: &[FeaturePreprocessing],
1117) -> BinarySplit {
1118    match preprocessing.get(feature_index) {
1119        Some(FeaturePreprocessing::Binary) => BinarySplit::BooleanTest {
1120            feature_index,
1121            feature_name: feature_name(feature_index),
1122            false_child_semantics: "left".to_string(),
1123            true_child_semantics: "right".to_string(),
1124        },
1125        Some(FeaturePreprocessing::Numeric { .. }) | None => BinarySplit::NumericBinThreshold {
1126            feature_index,
1127            feature_name: feature_name(feature_index),
1128            operator: "<=".to_string(),
1129            threshold_bin,
1130            threshold_upper_bound: threshold_upper_bound(
1131                preprocessing,
1132                feature_index,
1133                threshold_bin,
1134            ),
1135            comparison_dtype: "uint16".to_string(),
1136        },
1137    }
1138}
1139
/// Builds the IR description of one oblivious-level split test.
///
/// Boolean features are described as a raw bit test; all other features
/// (numeric or missing preprocessing metadata) as a `<=` comparison against
/// a bin threshold.
///
/// NOTE(review): the bit assignments differ between the two variants
/// (boolean: false -> 0 / true -> 1; numeric: true -> 0 / false -> 1).
/// This mirrors the binary-tree convention where a passing `<=` test routes
/// left while a `true` boolean routes right, but it is worth confirming
/// against the IR consumer.
fn oblivious_split_ir(
    feature_index: usize,
    threshold_bin: u16,
    preprocessing: &[FeaturePreprocessing],
) -> IrObliviousSplit {
    match preprocessing.get(feature_index) {
        Some(FeaturePreprocessing::Binary) => IrObliviousSplit::BooleanTest {
            feature_index,
            feature_name: feature_name(feature_index),
            bit_when_false: 0,
            bit_when_true: 1,
        },
        // Features without preprocessing metadata fall back to the numeric form.
        Some(FeaturePreprocessing::Numeric { .. }) | None => {
            IrObliviousSplit::NumericBinThreshold {
                feature_index,
                feature_name: feature_name(feature_index),
                operator: "<=".to_string(),
                threshold_bin,
                // Decoded via `threshold_upper_bound` from the preprocessing
                // metadata.
                threshold_upper_bound: threshold_upper_bound(
                    preprocessing,
                    feature_index,
                    threshold_bin,
                ),
                comparison_dtype: "uint16".to_string(),
                bit_when_true: 0,
                bit_when_false: 1,
            }
        }
    }
}
1170
1171fn encode_class_labels(
1172    train_set: &dyn TableAccess,
1173) -> Result<(Vec<f64>, Vec<usize>), DecisionTreeError> {
1174    let targets: Vec<f64> = (0..train_set.n_rows())
1175        .map(|row_idx| {
1176            let value = train_set.target_value(row_idx);
1177            if value.is_finite() {
1178                Ok(value)
1179            } else {
1180                Err(DecisionTreeError::InvalidTargetValue {
1181                    row: row_idx,
1182                    value,
1183                })
1184            }
1185        })
1186        .collect::<Result<_, _>>()?;
1187
1188    let class_labels = targets
1189        .iter()
1190        .copied()
1191        .fold(Vec::<f64>::new(), |mut labels, value| {
1192            if labels
1193                .binary_search_by(|candidate| candidate.total_cmp(&value))
1194                .is_err()
1195            {
1196                labels.push(value);
1197                labels.sort_by(|left, right| left.total_cmp(right));
1198            }
1199            labels
1200        });
1201
1202    let class_indices = targets
1203        .iter()
1204        .map(|value| {
1205            class_labels
1206                .binary_search_by(|candidate| candidate.total_cmp(value))
1207                .expect("target value must exist in class label vocabulary")
1208        })
1209        .collect();
1210
1211    Ok((class_labels, class_indices))
1212}
1213
/// Legacy allocating tree builder, kept for reference (hence `dead_code`);
/// the in-place `build_*_node_in_place*` variants supersede it.
///
/// Recursively grows a subtree over `rows`, appending nodes to `nodes` and
/// returning the new subtree root's index. Unlike the in-place builders,
/// split candidates here carry owned row vectors for each side/branch.
#[allow(dead_code)]
fn build_node(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &[usize],
    depth: usize,
) -> usize {
    let majority_class_index =
        majority_class(rows, context.class_indices, context.class_labels.len());
    let current_class_counts =
        class_counts(rows, context.class_indices, context.class_labels.len());

    // Stopping conditions: no rows, depth budget exhausted, node too small
    // to split, or node already pure — emit a majority-class leaf.
    if rows.is_empty()
        || depth >= context.options.max_depth
        || rows.len() < context.options.min_samples_split
        || is_pure(rows, context.class_indices)
    {
        return push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        );
    }

    let scoring = SplitScoringContext {
        table: context.table,
        class_indices: context.class_indices,
        num_classes: context.class_labels.len(),
        criterion: context.criterion,
        min_samples_leaf: context.options.min_samples_leaf,
    };
    // Deterministic per-node feature subsample (same salt as the in-place
    // binary builder).
    let feature_indices = candidate_feature_indices(
        context.table.binned_feature_count(),
        context.options.max_features,
        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
    );
    // Score candidates (algorithm decides binary vs. multiway shape) and
    // keep the highest-scoring one.
    let best_split = if context.parallelism.enabled() {
        feature_indices
            .into_par_iter()
            .filter_map(|feature_index| {
                score_split(&scoring, feature_index, rows, context.algorithm)
            })
            .max_by(|left, right| split_score(left).total_cmp(&split_score(right)))
    } else {
        feature_indices
            .into_iter()
            .filter_map(|feature_index| {
                score_split(&scoring, feature_index, rows, context.algorithm)
            })
            .max_by(|left, right| split_score(left).total_cmp(&split_score(right)))
    };

    match best_split {
        // A win by a canary feature stops growth at this node.
        Some(best_split)
            if context
                .table
                .is_canary_binned_feature(split_feature_index(&best_split)) =>
        {
            push_leaf(
                nodes,
                majority_class_index,
                rows.len(),
                current_class_counts,
            )
        }
        Some(SplitCandidate::Multiway {
            feature_index,
            score,
            branches,
        }) if score > 0.0 => {
            let impurity =
                classification_impurity(&current_class_counts, rows.len(), context.criterion);
            // Each branch already owns its row subset; recurse per branch.
            let branch_nodes = branches
                .into_iter()
                .map(|(bin, branch_rows)| {
                    (bin, build_node(context, nodes, &branch_rows, depth + 1))
                })
                .collect();

            push_node(
                nodes,
                TreeNode::MultiwaySplit {
                    feature_index,
                    // Used for bins unseen at this node during inference.
                    fallback_class_index: majority_class_index,
                    branches: branch_nodes,
                    sample_count: rows.len(),
                    impurity,
                    gain: score,
                    class_counts: current_class_counts,
                },
            )
        }
        Some(SplitCandidate::Binary {
            feature_index,
            score,
            threshold_bin,
            left_rows,
            right_rows,
        }) if score > 0.0 => {
            let impurity =
                classification_impurity(&current_class_counts, rows.len(), context.criterion);
            let left_child = build_node(context, nodes, &left_rows, depth + 1);
            let right_child = build_node(context, nodes, &right_rows, depth + 1);

            push_node(
                nodes,
                TreeNode::BinarySplit {
                    feature_index,
                    threshold_bin,
                    left_child,
                    right_child,
                    sample_count: rows.len(),
                    impurity,
                    gain: score,
                    class_counts: current_class_counts,
                },
            )
        }
        // No candidate achieved positive gain: emit a majority-class leaf.
        _ => push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        ),
    }
}
1341
/// Read-only state shared by the recursive tree-building routines.
struct BuildContext<'a> {
    /// Binned training-data accessor.
    table: &'a dyn TableAccess,
    /// Per-row index into `class_labels`.
    class_indices: &'a [usize],
    /// Sorted vocabulary of distinct class label values.
    class_labels: &'a [f64],
    /// Which tree-growing algorithm is being run.
    algorithm: DecisionTreeAlgorithm,
    /// Impurity criterion used for split scoring.
    criterion: Criterion,
    /// Whether per-feature split scoring may fan out across rayon threads.
    parallelism: Parallelism,
    /// Depth/size limits, feature subsampling, and the random seed.
    options: DecisionTreeOptions,
}
1351
/// The subset of build state needed to score a single candidate split.
struct SplitScoringContext<'a> {
    /// Binned training-data accessor.
    table: &'a dyn TableAccess,
    /// Per-row class index.
    class_indices: &'a [usize],
    /// Number of distinct classes (length of the label vocabulary).
    num_classes: usize,
    /// Impurity criterion used for split scoring.
    criterion: Criterion,
    /// Minimum rows allowed on either side of an accepted split.
    min_samples_leaf: usize,
}
1359
1360fn build_classification_node_histograms(
1361    table: &dyn TableAccess,
1362    class_indices: &[usize],
1363    rows: &[usize],
1364    num_classes: usize,
1365) -> Vec<ClassificationFeatureHistogram> {
1366    (0..table.binned_feature_count())
1367        .map(|feature_index| {
1368            if table.is_binary_binned_feature(feature_index) {
1369                let mut false_counts = vec![0usize; num_classes];
1370                let mut true_counts = vec![0usize; num_classes];
1371                let mut false_size = 0usize;
1372                let mut true_size = 0usize;
1373                for row_idx in rows {
1374                    let class_index = class_indices[*row_idx];
1375                    if !table
1376                        .binned_boolean_value(feature_index, *row_idx)
1377                        .expect("binary feature must expose boolean values")
1378                    {
1379                        false_counts[class_index] += 1;
1380                        false_size += 1;
1381                    } else {
1382                        true_counts[class_index] += 1;
1383                        true_size += 1;
1384                    }
1385                }
1386                ClassificationFeatureHistogram::Binary {
1387                    false_counts,
1388                    true_counts,
1389                    false_size,
1390                    true_size,
1391                }
1392            } else {
1393                let bin_cap = table.numeric_bin_cap();
1394                let mut bin_class_counts = vec![vec![0usize; num_classes]; bin_cap];
1395                let mut observed_bins = vec![false; bin_cap];
1396                for row_idx in rows {
1397                    let bin = table.binned_value(feature_index, *row_idx) as usize;
1398                    bin_class_counts[bin][class_indices[*row_idx]] += 1;
1399                    observed_bins[bin] = true;
1400                }
1401                ClassificationFeatureHistogram::Numeric {
1402                    bin_class_counts,
1403                    observed_bins: observed_bins
1404                        .into_iter()
1405                        .enumerate()
1406                        .filter_map(|(bin, seen)| seen.then_some(bin))
1407                        .collect(),
1408                }
1409            }
1410        })
1411        .collect()
1412}
1413
/// Derives a sibling's histograms by element-wise subtraction
/// (`sibling = parent - child`), avoiding a second counting pass over the
/// sibling's rows.
///
/// Both slices must be feature-aligned with matching variants and shapes,
/// and every child count must not exceed the corresponding parent count —
/// the unsigned subtraction would otherwise panic in debug builds.
fn subtract_classification_node_histograms(
    parent: &[ClassificationFeatureHistogram],
    child: &[ClassificationFeatureHistogram],
) -> Vec<ClassificationFeatureHistogram> {
    parent
        .iter()
        .zip(child.iter())
        .map(
            |(parent_hist, child_hist)| match (parent_hist, child_hist) {
                (
                    ClassificationFeatureHistogram::Binary {
                        false_counts: parent_false_counts,
                        true_counts: parent_true_counts,
                        false_size: parent_false_size,
                        true_size: parent_true_size,
                    },
                    ClassificationFeatureHistogram::Binary {
                        false_counts: child_false_counts,
                        true_counts: child_true_counts,
                        false_size: child_false_size,
                        true_size: child_true_size,
                    },
                ) => ClassificationFeatureHistogram::Binary {
                    false_counts: parent_false_counts
                        .iter()
                        .zip(child_false_counts.iter())
                        .map(|(parent, child)| parent - child)
                        .collect(),
                    true_counts: parent_true_counts
                        .iter()
                        .zip(child_true_counts.iter())
                        .map(|(parent, child)| parent - child)
                        .collect(),
                    false_size: parent_false_size - child_false_size,
                    true_size: parent_true_size - child_true_size,
                },
                (
                    ClassificationFeatureHistogram::Numeric {
                        bin_class_counts: parent_bin_class_counts,
                        ..
                    },
                    ClassificationFeatureHistogram::Numeric {
                        bin_class_counts: child_bin_class_counts,
                        ..
                    },
                ) => {
                    // Per-bin, per-class subtraction over the full bin range.
                    let bin_class_counts = parent_bin_class_counts
                        .iter()
                        .zip(child_bin_class_counts.iter())
                        .map(|(parent_counts, child_counts)| {
                            parent_counts
                                .iter()
                                .zip(child_counts.iter())
                                .map(|(parent, child)| parent - child)
                                .collect::<Vec<_>>()
                        })
                        .collect::<Vec<_>>();
                    // Observed bins must be recomputed: a bin fully consumed
                    // by the child is no longer observed in the sibling.
                    let observed_bins = bin_class_counts
                        .iter()
                        .enumerate()
                        .filter_map(|(bin, counts)| {
                            counts.iter().any(|count| *count > 0).then_some(bin)
                        })
                        .collect::<Vec<_>>();
                    ClassificationFeatureHistogram::Numeric {
                        bin_class_counts,
                        observed_bins,
                    }
                }
                // Parent and child are built over the same features, so their
                // variants can never disagree.
                _ => unreachable!("histogram shapes must match"),
            },
        )
        .collect()
}
1488
/// One leaf of an oblivious tree while it is being grown level by level.
///
/// Leaves are half-open `[start, end)` ranges into a shared row-index
/// buffer that is partitioned in place as levels are added.
#[derive(Debug, Clone)]
struct ObliviousLeafState {
    // Start of this leaf's range in the shared row-index buffer.
    start: usize,
    // One past the end of this leaf's range.
    end: usize,
    // Majority class of the rows in this leaf (inherited from the parent
    // when the leaf is empty).
    class_index: usize,
    // Per-class row counts for this leaf.
    class_counts: Vec<usize>,
}
1496
impl ObliviousLeafState {
    /// Number of rows covered by this leaf (`end - start`).
    fn len(&self) -> usize {
        self.end - self.start
    }
}
1502
/// Grows an oblivious (symmetric) tree: every level applies the same
/// `(feature, threshold)` test to all current leaves, doubling the leaf
/// count per level.
///
/// Growth stops when the depth budget is exhausted, no leaf is large
/// enough to split, no candidate achieves positive gain, or the winning
/// feature is a canary feature.
fn train_oblivious_structure(
    table: &dyn TableAccess,
    class_indices: &[usize],
    class_labels: &[f64],
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> TreeStructure {
    // All leaves share this buffer; splitting a level partitions it in place.
    let mut row_indices: Vec<usize> = (0..table.n_rows()).collect();
    let total_class_counts = class_counts(&row_indices, class_indices, class_labels.len());
    let total_impurity = classification_impurity(&total_class_counts, row_indices.len(), criterion);
    // Start from a single leaf covering every row.
    let mut leaves = vec![ObliviousLeafState {
        start: 0,
        end: row_indices.len(),
        class_index: majority_class(&row_indices, class_indices, class_labels.len()),
        class_counts: total_class_counts.clone(),
    }];
    let mut splits = Vec::new();

    for depth in 0..options.max_depth {
        // If no leaf is large enough to split, further levels are pointless.
        if leaves
            .iter()
            .all(|leaf| leaf.len() < options.min_samples_split)
        {
            break;
        }
        // Deterministic per-level feature subsample; the seed varies only
        // with depth here (row slice is empty), using a level-specific salt.
        let feature_indices = candidate_feature_indices(
            table.binned_feature_count(),
            options.max_features,
            node_seed(options.random_seed, depth, &[], 0x0B11_A10Cu64),
        );
        // Score every candidate feature across all leaves simultaneously,
        // optionally in parallel via rayon; keep the highest total gain.
        let best_split = if parallelism.enabled() {
            feature_indices
                .into_par_iter()
                .filter_map(|feature_index| {
                    score_oblivious_split(
                        table,
                        &row_indices,
                        class_indices,
                        feature_index,
                        &leaves,
                        class_labels.len(),
                        criterion,
                        options.min_samples_leaf,
                    )
                })
                .max_by(|left, right| left.score.total_cmp(&right.score))
        } else {
            feature_indices
                .into_iter()
                .filter_map(|feature_index| {
                    score_oblivious_split(
                        table,
                        &row_indices,
                        class_indices,
                        feature_index,
                        &leaves,
                        class_labels.len(),
                        criterion,
                        options.min_samples_leaf,
                    )
                })
                .max_by(|left, right| left.score.total_cmp(&right.score))
        };

        // Stop growing if no candidate exists or the best has no gain.
        let Some(best_split) = best_split.filter(|candidate| candidate.score > 0.0) else {
            break;
        };
        // A win by a canary feature is treated as a signal to stop growing.
        if table.is_canary_binned_feature(best_split.feature_index) {
            break;
        }

        // Apply the winning test to every leaf, doubling the leaf count.
        leaves = split_oblivious_leaves_in_place(
            table,
            &mut row_indices,
            class_indices,
            class_labels.len(),
            leaves,
            best_split.feature_index,
            best_split.threshold_bin,
        );
        splits.push(ObliviousSplit {
            feature_index: best_split.feature_index,
            threshold_bin: best_split.threshold_bin,
            // NOTE(review): every level records the full row count and the
            // root impurity rather than per-level statistics — confirm this
            // is the intended reporting convention.
            sample_count: table.n_rows(),
            impurity: total_impurity,
            gain: best_split.score,
        });
    }

    TreeStructure::Oblivious {
        splits,
        leaf_class_indices: leaves.iter().map(|leaf| leaf.class_index).collect(),
        leaf_sample_counts: leaves.iter().map(ObliviousLeafState::len).collect(),
        leaf_class_counts: leaves
            .iter()
            .map(|leaf| leaf.class_counts.clone())
            .collect(),
    }
}
1603
/// A scored `(feature, threshold)` candidate for one oblivious level.
#[derive(Debug, Clone, Copy)]
struct ObliviousSplitCandidate {
    // Feature the level would split on.
    feature_index: usize,
    // Split test is `binned value <= threshold_bin`.
    threshold_bin: u16,
    // Total weighted impurity decrease summed across all current leaves.
    score: f64,
}
1610
1611#[allow(clippy::too_many_arguments)]
1612fn score_oblivious_split(
1613    table: &dyn TableAccess,
1614    row_indices: &[usize],
1615    class_indices: &[usize],
1616    feature_index: usize,
1617    leaves: &[ObliviousLeafState],
1618    num_classes: usize,
1619    criterion: Criterion,
1620    min_samples_leaf: usize,
1621) -> Option<ObliviousSplitCandidate> {
1622    if table.is_binary_binned_feature(feature_index) {
1623        return score_binary_oblivious_split(
1624            table,
1625            row_indices,
1626            class_indices,
1627            feature_index,
1628            leaves,
1629            num_classes,
1630            criterion,
1631            min_samples_leaf,
1632        );
1633    }
1634    if let Some(candidate) = score_numeric_oblivious_split_fast(
1635        table,
1636        row_indices,
1637        class_indices,
1638        feature_index,
1639        leaves,
1640        num_classes,
1641        criterion,
1642        min_samples_leaf,
1643    ) {
1644        return Some(candidate);
1645    }
1646    let candidate_thresholds = leaves
1647        .iter()
1648        .flat_map(|leaf| {
1649            row_indices[leaf.start..leaf.end]
1650                .iter()
1651                .map(|row_idx| table.binned_value(feature_index, *row_idx))
1652        })
1653        .collect::<BTreeSet<_>>();
1654
1655    candidate_thresholds
1656        .into_iter()
1657        .filter_map(|threshold_bin| {
1658            let score = leaves.iter().fold(0.0, |score, leaf| {
1659                let leaf_rows = &row_indices[leaf.start..leaf.end];
1660                let (left_rows, right_rows): (Vec<usize>, Vec<usize>) =
1661                    leaf_rows.iter().copied().partition(|row_idx| {
1662                        table.binned_value(feature_index, *row_idx) <= threshold_bin
1663                    });
1664
1665                if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
1666                    return score;
1667                }
1668
1669                let parent_counts = leaf.class_counts.clone();
1670                let left_counts = class_counts(&left_rows, class_indices, num_classes);
1671                let right_counts = class_counts(&right_rows, class_indices, num_classes);
1672
1673                let weighted_parent_impurity = leaf.len() as f64
1674                    * classification_impurity(&parent_counts, leaf.len(), criterion);
1675                let weighted_children_impurity = left_rows.len() as f64
1676                    * classification_impurity(&left_counts, left_rows.len(), criterion)
1677                    + right_rows.len() as f64
1678                        * classification_impurity(&right_counts, right_rows.len(), criterion);
1679
1680                score + (weighted_parent_impurity - weighted_children_impurity)
1681            });
1682
1683            (score > 0.0).then_some(ObliviousSplitCandidate {
1684                feature_index,
1685                threshold_bin,
1686                score,
1687            })
1688        })
1689        .max_by(|left, right| left.score.total_cmp(&right.score))
1690}
1691
/// Applies one `(feature, threshold)` test to every leaf, partitioning each
/// leaf's row range in place and returning the doubled leaf set (left then
/// right child for each input leaf, preserving leaf order).
fn split_oblivious_leaves_in_place(
    table: &dyn TableAccess,
    row_indices: &mut [usize],
    class_indices: &[usize],
    num_classes: usize,
    leaves: Vec<ObliviousLeafState>,
    feature_index: usize,
    threshold_bin: u16,
) -> Vec<ObliviousLeafState> {
    let mut next_leaves = Vec::with_capacity(leaves.len() * 2);
    for leaf in leaves {
        // Reorder this leaf's rows so the left side is contiguous; returns
        // how many rows went left.
        let left_count = partition_rows_for_binary_split(
            table,
            feature_index,
            threshold_bin,
            &mut row_indices[leaf.start..leaf.end],
        );
        let mid = leaf.start + left_count;
        // Tally class membership for each side of the partition.
        let mut left_class_counts = vec![0usize; num_classes];
        let mut right_class_counts = vec![0usize; num_classes];
        for row_idx in &row_indices[leaf.start..mid] {
            left_class_counts[class_indices[*row_idx]] += 1;
        }
        for row_idx in &row_indices[mid..leaf.end] {
            right_class_counts[class_indices[*row_idx]] += 1;
        }
        // Empty children inherit the parent's majority class so no leaf is
        // left without a prediction.
        let left_class_index = if left_count == 0 {
            leaf.class_index
        } else {
            majority_class_from_counts(&left_class_counts)
        };
        let right_class_index = if mid == leaf.end {
            leaf.class_index
        } else {
            majority_class_from_counts(&right_class_counts)
        };
        next_leaves.push(ObliviousLeafState {
            start: leaf.start,
            end: mid,
            class_index: left_class_index,
            class_counts: left_class_counts,
        });
        next_leaves.push(ObliviousLeafState {
            start: mid,
            end: leaf.end,
            class_index: right_class_index,
            class_counts: right_class_counts,
        });
    }
    next_leaves
}
1743
1744#[allow(dead_code)]
1745fn score_split(
1746    context: &SplitScoringContext<'_>,
1747    feature_index: usize,
1748    rows: &[usize],
1749    algorithm: DecisionTreeAlgorithm,
1750) -> Option<SplitCandidate> {
1751    match algorithm {
1752        DecisionTreeAlgorithm::Id3 => score_multiway_split(
1753            context,
1754            feature_index,
1755            rows,
1756            MultiwayMetric::InformationGain,
1757        ),
1758        DecisionTreeAlgorithm::C45 => {
1759            score_multiway_split(context, feature_index, rows, MultiwayMetric::GainRatio)
1760        }
1761        DecisionTreeAlgorithm::Cart => score_cart_split(context, feature_index, rows),
1762        DecisionTreeAlgorithm::Randomized => score_randomized_split(context, feature_index, rows),
1763        DecisionTreeAlgorithm::Oblivious => None,
1764    }
1765}
1766
1767#[allow(dead_code)]
1768fn score_multiway_split(
1769    context: &SplitScoringContext<'_>,
1770    feature_index: usize,
1771    rows: &[usize],
1772    metric: MultiwayMetric,
1773) -> Option<SplitCandidate> {
1774    let grouped_rows = if context.table.is_binary_binned_feature(feature_index) {
1775        let (false_rows, true_rows): (Vec<usize>, Vec<usize>) =
1776            rows.iter().copied().partition(|row_idx| {
1777                !context
1778                    .table
1779                    .binned_boolean_value(feature_index, *row_idx)
1780                    .expect("binary feature must expose boolean values")
1781            });
1782        [(0u16, false_rows), (1u16, true_rows)]
1783            .into_iter()
1784            .filter(|(_bin, group_rows)| !group_rows.is_empty())
1785            .collect::<BTreeMap<_, _>>()
1786    } else {
1787        rows.iter()
1788            .fold(BTreeMap::<u16, Vec<usize>>::new(), |mut groups, row_idx| {
1789                groups
1790                    .entry(context.table.binned_value(feature_index, *row_idx))
1791                    .or_default()
1792                    .push(*row_idx);
1793                groups
1794            })
1795    };
1796
1797    if grouped_rows.len() <= 1
1798        || grouped_rows
1799            .values()
1800            .any(|group| group.len() < context.min_samples_leaf)
1801    {
1802        return None;
1803    }
1804
1805    let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
1806    let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
1807    let weighted_child_impurity = grouped_rows
1808        .values()
1809        .map(|group_rows| {
1810            let counts = class_counts(group_rows, context.class_indices, context.num_classes);
1811            (group_rows.len() as f64 / rows.len() as f64)
1812                * classification_impurity(&counts, group_rows.len(), context.criterion)
1813        })
1814        .sum::<f64>();
1815    let information_gain = parent_impurity - weighted_child_impurity;
1816
1817    let score = match metric {
1818        MultiwayMetric::InformationGain => information_gain,
1819        MultiwayMetric::GainRatio => {
1820            let split_info = grouped_rows
1821                .values()
1822                .map(|group_rows| {
1823                    let probability = group_rows.len() as f64 / rows.len() as f64;
1824                    -probability * probability.log2()
1825                })
1826                .sum::<f64>();
1827
1828            if split_info == 0.0 {
1829                return None;
1830            }
1831
1832            information_gain / split_info
1833        }
1834    };
1835
1836    Some(SplitCandidate::Multiway {
1837        feature_index,
1838        score,
1839        branches: grouped_rows.into_iter().collect(),
1840    })
1841}
1842
/// Scores a multiway split for ID3/C4.5 from per-bin class histograms only,
/// without materializing child row lists (cheaper than `score_multiway_split`
/// when only the winning feature and its bins matter).
///
/// Returns `None` when fewer than two non-empty branches exist, when any
/// branch is smaller than `min_samples_leaf`, or when the gain-ratio
/// denominator is zero.
fn score_multiway_split_choice(
    context: &SplitScoringContext<'_>,
    feature_index: usize,
    rows: &[usize],
    metric: MultiwayMetric,
) -> Option<MultiwaySplitChoice> {
    // (bin, (branch size, per-class counts)) for each non-empty branch,
    // kept in ascending bin order in both arms.
    let grouped_counts = if context.table.is_binary_binned_feature(feature_index) {
        // Boolean feature: at most two branches (false = bin 0, true = bin 1).
        let mut false_counts = vec![0usize; context.num_classes];
        let mut true_counts = vec![0usize; context.num_classes];
        let mut false_size = 0usize;
        let mut true_size = 0usize;
        for row_idx in rows {
            let class_index = context.class_indices[*row_idx];
            if !context
                .table
                .binned_boolean_value(feature_index, *row_idx)
                .expect("binary feature must expose boolean values")
            {
                false_counts[class_index] += 1;
                false_size += 1;
            } else {
                true_counts[class_index] += 1;
                true_size += 1;
            }
        }
        [
            (0u16, (false_size, false_counts)),
            (1u16, (true_size, true_counts)),
        ]
        .into_iter()
        .filter(|(_, (size, _))| *size > 0) // drop empty branches
        .collect::<Vec<_>>()
    } else {
        // General case: accumulate a size + class histogram per observed bin.
        let mut grouped = BTreeMap::<u16, (usize, Vec<usize>)>::new();
        for row_idx in rows {
            let bin = context.table.binned_value(feature_index, *row_idx);
            let entry = grouped
                .entry(bin)
                .or_insert_with(|| (0usize, vec![0usize; context.num_classes]));
            entry.0 += 1;
            entry.1[context.class_indices[*row_idx]] += 1;
        }
        grouped.into_iter().collect::<Vec<_>>()
    };

    // A usable split needs at least two branches, each meeting the leaf floor.
    if grouped_counts.len() <= 1
        || grouped_counts
            .iter()
            .any(|(_, (group_size, _))| *group_size < context.min_samples_leaf)
    {
        return None;
    }

    let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
    let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
    // Children impurity, weighted by branch size.
    let weighted_child_impurity = grouped_counts
        .iter()
        .map(|(_, (group_size, counts))| {
            (*group_size as f64 / rows.len() as f64)
                * classification_impurity(counts, *group_size, context.criterion)
        })
        .sum::<f64>();
    let information_gain = parent_impurity - weighted_child_impurity;

    let score = match metric {
        MultiwayMetric::InformationGain => information_gain,
        MultiwayMetric::GainRatio => {
            // Split information: entropy of the branch-size distribution,
            // penalizing features that fan out into many tiny branches.
            let split_info = grouped_counts
                .iter()
                .map(|(_, (group_size, _))| {
                    let probability = *group_size as f64 / rows.len() as f64;
                    -probability * probability.log2()
                })
                .sum::<f64>();
            if split_info == 0.0 {
                return None;
            }
            information_gain / split_info
        }
    };

    Some(MultiwaySplitChoice {
        feature_index,
        score,
        // Ascending bin order, matching grouped_counts above.
        branch_bins: grouped_counts.into_iter().map(|(bin, _)| bin).collect(),
    })
}
1930
1931#[allow(dead_code)]
1932fn score_cart_split(
1933    context: &SplitScoringContext<'_>,
1934    feature_index: usize,
1935    rows: &[usize],
1936) -> Option<SplitCandidate> {
1937    if context.table.is_binary_binned_feature(feature_index) {
1938        return score_binary_cart_split(context, feature_index, rows);
1939    }
1940    if let Some(candidate) = score_numeric_cart_split_fast(context, feature_index, rows) {
1941        return Some(candidate);
1942    }
1943    let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
1944    let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
1945
1946    rows.iter()
1947        .map(|row_idx| context.table.binned_value(feature_index, *row_idx))
1948        .collect::<BTreeSet<_>>()
1949        .into_iter()
1950        .filter_map(|threshold_bin| {
1951            let (left_rows, right_rows): (Vec<usize>, Vec<usize>) =
1952                rows.iter().copied().partition(|row_idx| {
1953                    context.table.binned_value(feature_index, *row_idx) <= threshold_bin
1954                });
1955
1956            if left_rows.len() < context.min_samples_leaf
1957                || right_rows.len() < context.min_samples_leaf
1958            {
1959                return None;
1960            }
1961
1962            let left_counts = class_counts(&left_rows, context.class_indices, context.num_classes);
1963            let right_counts =
1964                class_counts(&right_rows, context.class_indices, context.num_classes);
1965            let weighted_impurity = (left_rows.len() as f64 / rows.len() as f64)
1966                * classification_impurity(&left_counts, left_rows.len(), context.criterion)
1967                + (right_rows.len() as f64 / rows.len() as f64)
1968                    * classification_impurity(&right_counts, right_rows.len(), context.criterion);
1969
1970            Some(SplitCandidate::Binary {
1971                feature_index,
1972                score: parent_impurity - weighted_impurity,
1973                threshold_bin,
1974                left_rows,
1975                right_rows,
1976            })
1977        })
1978        .max_by(|left, right| split_score(left).total_cmp(&split_score(right)))
1979}
1980
1981#[allow(dead_code)]
1982fn score_randomized_split(
1983    context: &SplitScoringContext<'_>,
1984    feature_index: usize,
1985    rows: &[usize],
1986) -> Option<SplitCandidate> {
1987    if context.table.is_binary_binned_feature(feature_index) {
1988        return score_binary_cart_split(context, feature_index, rows);
1989    }
1990    if let Some(candidate) = score_numeric_randomized_split_fast(context, feature_index, rows) {
1991        return Some(candidate);
1992    }
1993
1994    let candidate_thresholds = rows
1995        .iter()
1996        .map(|row_idx| context.table.binned_value(feature_index, *row_idx))
1997        .collect::<BTreeSet<_>>()
1998        .into_iter()
1999        .collect::<Vec<_>>();
2000    let threshold_bin =
2001        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xC1A551F1u64)?;
2002
2003    let (left_rows, right_rows): (Vec<usize>, Vec<usize>) = rows
2004        .iter()
2005        .copied()
2006        .partition(|row_idx| context.table.binned_value(feature_index, *row_idx) <= threshold_bin);
2007
2008    if left_rows.len() < context.min_samples_leaf || right_rows.len() < context.min_samples_leaf {
2009        return None;
2010    }
2011
2012    let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
2013    let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
2014    let left_counts = class_counts(&left_rows, context.class_indices, context.num_classes);
2015    let right_counts = class_counts(&right_rows, context.class_indices, context.num_classes);
2016    let weighted_impurity = (left_rows.len() as f64 / rows.len() as f64)
2017        * classification_impurity(&left_counts, left_rows.len(), context.criterion)
2018        + (right_rows.len() as f64 / rows.len() as f64)
2019            * classification_impurity(&right_counts, right_rows.len(), context.criterion);
2020
2021    Some(SplitCandidate::Binary {
2022        feature_index,
2023        score: parent_impurity - weighted_impurity,
2024        threshold_bin,
2025        left_rows,
2026        right_rows,
2027    })
2028}
2029
2030#[allow(dead_code)]
2031fn score_binary_cart_split(
2032    context: &SplitScoringContext<'_>,
2033    feature_index: usize,
2034    rows: &[usize],
2035) -> Option<SplitCandidate> {
2036    let (left_rows, right_rows): (Vec<usize>, Vec<usize>) =
2037        rows.iter().copied().partition(|row_idx| {
2038            !context
2039                .table
2040                .binned_boolean_value(feature_index, *row_idx)
2041                .expect("binary feature must expose boolean values")
2042        });
2043
2044    if left_rows.len() < context.min_samples_leaf || right_rows.len() < context.min_samples_leaf {
2045        return None;
2046    }
2047
2048    let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
2049    let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
2050    let left_counts = class_counts(&left_rows, context.class_indices, context.num_classes);
2051    let right_counts = class_counts(&right_rows, context.class_indices, context.num_classes);
2052    let weighted_impurity = (left_rows.len() as f64 / rows.len() as f64)
2053        * classification_impurity(&left_counts, left_rows.len(), context.criterion)
2054        + (right_rows.len() as f64 / rows.len() as f64)
2055            * classification_impurity(&right_counts, right_rows.len(), context.criterion);
2056
2057    Some(SplitCandidate::Binary {
2058        feature_index,
2059        score: parent_impurity - weighted_impurity,
2060        threshold_bin: 0,
2061        left_rows,
2062        right_rows,
2063    })
2064}
2065
2066#[allow(clippy::too_many_arguments)]
2067fn score_binary_oblivious_split(
2068    table: &dyn TableAccess,
2069    row_indices: &[usize],
2070    class_indices: &[usize],
2071    feature_index: usize,
2072    leaves: &[ObliviousLeafState],
2073    num_classes: usize,
2074    criterion: Criterion,
2075    min_samples_leaf: usize,
2076) -> Option<ObliviousSplitCandidate> {
2077    let mut score = 0.0;
2078    let mut found_valid = false;
2079
2080    for leaf in leaves {
2081        let mut left_counts = vec![0usize; num_classes];
2082        let mut left_size = 0usize;
2083        for row_idx in &row_indices[leaf.start..leaf.end] {
2084            if !table
2085                .binned_boolean_value(feature_index, *row_idx)
2086                .expect("binary feature must expose boolean values")
2087            {
2088                left_counts[class_indices[*row_idx]] += 1;
2089                left_size += 1;
2090            }
2091        }
2092        let right_size = leaf.len() - left_size;
2093        if left_size < min_samples_leaf || right_size < min_samples_leaf {
2094            continue;
2095        }
2096        found_valid = true;
2097        let right_counts = leaf
2098            .class_counts
2099            .iter()
2100            .zip(left_counts.iter())
2101            .map(|(parent, left)| parent - left)
2102            .collect::<Vec<_>>();
2103        let weighted_parent_impurity =
2104            leaf.len() as f64 * classification_impurity(&leaf.class_counts, leaf.len(), criterion);
2105        let weighted_children_impurity = left_size as f64
2106            * classification_impurity(&left_counts, left_size, criterion)
2107            + right_size as f64 * classification_impurity(&right_counts, right_size, criterion);
2108        score += weighted_parent_impurity - weighted_children_impurity;
2109    }
2110
2111    (found_valid && score > 0.0).then_some(ObliviousSplitCandidate {
2112        feature_index,
2113        threshold_bin: 0,
2114        score,
2115    })
2116}
2117
/// Scores a numeric feature for an oblivious (level-wise) split using
/// per-leaf bin histograms: every leaf must share the same threshold, so
/// each candidate threshold's gain is accumulated across all leaves.
///
/// Per leaf, observed bins are swept in ascending order with running
/// left-side class counts, making each threshold O(num_classes) to evaluate.
/// Returns `None` when the table exposes no bin layout, when a bin exceeds
/// the cap (the caller falls back to a slower path), or when no leaf saw at
/// least two distinct bins.
#[allow(clippy::too_many_arguments)]
fn score_numeric_oblivious_split_fast(
    table: &dyn TableAccess,
    row_indices: &[usize],
    class_indices: &[usize],
    feature_index: usize,
    leaves: &[ObliviousLeafState],
    num_classes: usize,
    criterion: Criterion,
    min_samples_leaf: usize,
) -> Option<ObliviousSplitCandidate> {
    let bin_cap = table.numeric_bin_cap();
    if bin_cap == 0 {
        return None;
    }

    // Accumulated gain per candidate threshold bin, summed over all leaves.
    let mut threshold_scores = vec![0.0; bin_cap];
    let mut observed_any = false;

    for leaf in leaves {
        // Per-bin class histogram for this leaf's contiguous row range.
        let mut bin_class_counts = vec![vec![0usize; num_classes]; bin_cap];
        let mut observed_bins = vec![false; bin_cap];
        for row_idx in &row_indices[leaf.start..leaf.end] {
            let bin = table.binned_value(feature_index, *row_idx) as usize;
            if bin >= bin_cap {
                // Bin layout mismatch: abandon the fast path entirely.
                return None;
            }
            bin_class_counts[bin][class_indices[*row_idx]] += 1;
            observed_bins[bin] = true;
        }

        // Collapse the flag vector into the sorted list of bins actually seen.
        let observed_bins: Vec<usize> = observed_bins
            .into_iter()
            .enumerate()
            .filter_map(|(bin, seen)| seen.then_some(bin))
            .collect();
        if observed_bins.len() <= 1 {
            // This leaf cannot be split on this feature; it adds no gain.
            continue;
        }
        observed_any = true;

        let parent_weighted_impurity =
            leaf.len() as f64 * classification_impurity(&leaf.class_counts, leaf.len(), criterion);
        // Running totals for rows in bins at or below the current threshold.
        let mut left_counts = vec![0usize; num_classes];
        let mut left_size = 0usize;

        for &bin in &observed_bins {
            for class_index in 0..num_classes {
                left_counts[class_index] += bin_class_counts[bin][class_index];
            }
            left_size += bin_class_counts[bin].iter().sum::<usize>();
            let right_size = leaf.len() - left_size;

            // Thresholds violating the leaf-size floor add nothing here.
            if left_size < min_samples_leaf || right_size < min_samples_leaf {
                continue;
            }

            // Right-side counts follow from the leaf totals by subtraction.
            let right_counts = leaf
                .class_counts
                .iter()
                .zip(left_counts.iter())
                .map(|(parent, left)| parent - left)
                .collect::<Vec<_>>();
            let weighted_children_impurity = left_size as f64
                * classification_impurity(&left_counts, left_size, criterion)
                + right_size as f64 * classification_impurity(&right_counts, right_size, criterion);
            threshold_scores[bin] += parent_weighted_impurity - weighted_children_impurity;
        }
    }

    if !observed_any {
        return None;
    }

    // Pick the threshold with the largest strictly-positive accumulated gain.
    threshold_scores
        .into_iter()
        .enumerate()
        .filter(|(_, score)| *score > 0.0)
        .max_by(|left, right| left.1.total_cmp(&right.1))
        .map(|(threshold_bin, score)| ObliviousSplitCandidate {
            feature_index,
            threshold_bin: threshold_bin as u16,
            score,
        })
}
2203
2204#[allow(dead_code)]
2205fn score_numeric_cart_split_fast(
2206    context: &SplitScoringContext<'_>,
2207    feature_index: usize,
2208    rows: &[usize],
2209) -> Option<SplitCandidate> {
2210    let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
2211    let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
2212    let bin_cap = context.table.numeric_bin_cap();
2213    if bin_cap == 0 {
2214        return None;
2215    }
2216
2217    let mut bin_class_counts = vec![vec![0usize; context.num_classes]; bin_cap];
2218    let mut observed_bins = vec![false; bin_cap];
2219    for row_idx in rows {
2220        let bin = context.table.binned_value(feature_index, *row_idx) as usize;
2221        if bin >= bin_cap {
2222            return None;
2223        }
2224        bin_class_counts[bin][context.class_indices[*row_idx]] += 1;
2225        observed_bins[bin] = true;
2226    }
2227
2228    let observed_bins: Vec<usize> = observed_bins
2229        .into_iter()
2230        .enumerate()
2231        .filter_map(|(bin, seen)| seen.then_some(bin))
2232        .collect();
2233    if observed_bins.len() <= 1 {
2234        return None;
2235    }
2236
2237    let mut left_counts = vec![0usize; context.num_classes];
2238    let mut left_size = 0usize;
2239    let mut best_threshold = None;
2240    let mut best_score = f64::NEG_INFINITY;
2241
2242    for &bin in &observed_bins {
2243        for class_index in 0..context.num_classes {
2244            left_counts[class_index] += bin_class_counts[bin][class_index];
2245        }
2246        left_size += bin_class_counts[bin].iter().sum::<usize>();
2247        let right_size = rows.len() - left_size;
2248
2249        if left_size < context.min_samples_leaf || right_size < context.min_samples_leaf {
2250            continue;
2251        }
2252
2253        let right_counts = parent_counts
2254            .iter()
2255            .zip(left_counts.iter())
2256            .map(|(parent, left)| parent - left)
2257            .collect::<Vec<_>>();
2258        let weighted_impurity = (left_size as f64 / rows.len() as f64)
2259            * classification_impurity(&left_counts, left_size, context.criterion)
2260            + (right_size as f64 / rows.len() as f64)
2261                * classification_impurity(&right_counts, right_size, context.criterion);
2262        let score = parent_impurity - weighted_impurity;
2263        if score > best_score {
2264            best_score = score;
2265            best_threshold = Some(bin as u16);
2266        }
2267    }
2268
2269    let threshold_bin = best_threshold?;
2270    let (left_rows, right_rows): (Vec<usize>, Vec<usize>) = rows
2271        .iter()
2272        .copied()
2273        .partition(|row_idx| context.table.binned_value(feature_index, *row_idx) <= threshold_bin);
2274
2275    Some(SplitCandidate::Binary {
2276        feature_index,
2277        score: best_score,
2278        threshold_bin,
2279        left_rows,
2280        right_rows,
2281    })
2282}
2283
2284#[allow(dead_code)]
2285fn score_numeric_randomized_split_fast(
2286    context: &SplitScoringContext<'_>,
2287    feature_index: usize,
2288    rows: &[usize],
2289) -> Option<SplitCandidate> {
2290    let bin_cap = context.table.numeric_bin_cap();
2291    if bin_cap == 0 {
2292        return None;
2293    }
2294    let mut observed_bins = vec![false; bin_cap];
2295    for row_idx in rows {
2296        let bin = context.table.binned_value(feature_index, *row_idx) as usize;
2297        if bin >= bin_cap {
2298            return None;
2299        }
2300        observed_bins[bin] = true;
2301    }
2302    let candidate_thresholds = observed_bins
2303        .into_iter()
2304        .enumerate()
2305        .filter_map(|(bin, seen)| seen.then_some(bin as u16))
2306        .collect::<Vec<_>>();
2307    let threshold_bin =
2308        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xC1A551F1u64)?;
2309
2310    let (left_rows, right_rows): (Vec<usize>, Vec<usize>) = rows
2311        .iter()
2312        .copied()
2313        .partition(|row_idx| context.table.binned_value(feature_index, *row_idx) <= threshold_bin);
2314
2315    if left_rows.len() < context.min_samples_leaf || right_rows.len() < context.min_samples_leaf {
2316        return None;
2317    }
2318
2319    let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
2320    let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
2321    let left_counts = class_counts(&left_rows, context.class_indices, context.num_classes);
2322    let right_counts = class_counts(&right_rows, context.class_indices, context.num_classes);
2323    let weighted_impurity = (left_rows.len() as f64 / rows.len() as f64)
2324        * classification_impurity(&left_counts, left_rows.len(), context.criterion)
2325        + (right_rows.len() as f64 / rows.len() as f64)
2326            * classification_impurity(&right_counts, right_rows.len(), context.criterion);
2327
2328    Some(SplitCandidate::Binary {
2329        feature_index,
2330        score: parent_impurity - weighted_impurity,
2331        threshold_bin,
2332        left_rows,
2333        right_rows,
2334    })
2335}
2336
/// Builds a class histogram over the selected rows: `result[c]` is the
/// number of rows in `rows` whose label (via `class_indices`) equals `c`.
fn class_counts(rows: &[usize], class_indices: &[usize], num_classes: usize) -> Vec<usize> {
    let mut counts = vec![0usize; num_classes];
    for row_idx in rows {
        counts[class_indices[*row_idx]] += 1;
    }
    counts
}
2344
2345fn majority_class(rows: &[usize], class_indices: &[usize], num_classes: usize) -> usize {
2346    majority_class_from_counts(&class_counts(rows, class_indices, num_classes))
2347}
2348
/// Index of the largest count; on ties the lowest class index wins, and an
/// empty histogram yields class 0.
fn majority_class_from_counts(counts: &[usize]) -> usize {
    let mut best_index = 0usize;
    let mut best_count = 0usize;
    for (class_index, &count) in counts.iter().enumerate() {
        // Strict '>' keeps the earliest (lowest) index on equal counts.
        if count > best_count {
            best_count = count;
            best_index = class_index;
        }
    }
    best_index
}
2358
/// True when every row in `rows` carries the same class label.
/// An empty node is considered pure.
fn is_pure(rows: &[usize], class_indices: &[usize]) -> bool {
    match rows.split_first() {
        None => true,
        Some((first_row, rest)) => {
            let first_class = class_indices[*first_row];
            rest.iter().all(|row_idx| class_indices[*row_idx] == first_class)
        }
    }
}
2365
/// Shannon entropy (in bits) of the class distribution `counts` over
/// `total` samples. Empty classes contribute nothing (avoids log(0)).
fn entropy(counts: &[usize], total: usize) -> f64 {
    let mut bits = 0.0;
    for &count in counts {
        if count == 0 {
            continue;
        }
        let probability = count as f64 / total as f64;
        bits -= probability * probability.log2();
    }
    bits
}
2377
/// Gini impurity of the class distribution `counts` over `total` samples:
/// `1 - sum(p_i^2)`.
fn gini(counts: &[usize], total: usize) -> f64 {
    let mut squared_probability_sum = 0.0;
    for &count in counts {
        let probability = count as f64 / total as f64;
        squared_probability_sum += probability * probability;
    }
    1.0 - squared_probability_sum
}
2388
2389fn classification_impurity(counts: &[usize], total: usize, criterion: Criterion) -> f64 {
2390    match criterion {
2391        Criterion::Entropy => entropy(counts, total),
2392        Criterion::Gini => gini(counts, total),
2393        _ => unreachable!("classification impurity only supports gini or entropy"),
2394    }
2395}
2396
2397#[allow(dead_code)]
2398fn split_score(candidate: &SplitCandidate) -> f64 {
2399    match candidate {
2400        SplitCandidate::Multiway { score, .. } | SplitCandidate::Binary { score, .. } => *score,
2401    }
2402}
2403
2404#[allow(dead_code)]
2405fn score_binary_split_choice(
2406    context: &SplitScoringContext<'_>,
2407    feature_index: usize,
2408    rows: &[usize],
2409    algorithm: DecisionTreeAlgorithm,
2410) -> Option<BinarySplitChoice> {
2411    match algorithm {
2412        DecisionTreeAlgorithm::Cart => {
2413            if context.table.is_binary_binned_feature(feature_index) {
2414                score_binary_cart_split_choice(context, feature_index, rows)
2415            } else {
2416                score_numeric_cart_split_choice_fast(context, feature_index, rows)
2417            }
2418        }
2419        DecisionTreeAlgorithm::Randomized => {
2420            if context.table.is_binary_binned_feature(feature_index) {
2421                score_binary_cart_split_choice(context, feature_index, rows)
2422            } else {
2423                score_numeric_randomized_split_choice_fast(context, feature_index, rows)
2424            }
2425        }
2426        _ => None,
2427    }
2428}
2429
2430fn score_binary_split_choice_from_hist(
2431    context: &SplitScoringContext<'_>,
2432    histogram: &ClassificationFeatureHistogram,
2433    feature_index: usize,
2434    rows: &[usize],
2435    parent_counts: &[usize],
2436    algorithm: DecisionTreeAlgorithm,
2437) -> Option<BinarySplitChoice> {
2438    match (algorithm, histogram) {
2439        (
2440            DecisionTreeAlgorithm::Cart,
2441            ClassificationFeatureHistogram::Binary {
2442                false_counts,
2443                true_counts,
2444                false_size,
2445                true_size,
2446            },
2447        ) => score_binary_cart_split_choice_from_counts(
2448            context,
2449            feature_index,
2450            parent_counts,
2451            false_counts,
2452            *false_size,
2453            true_counts,
2454            *true_size,
2455        ),
2456        (
2457            DecisionTreeAlgorithm::Cart,
2458            ClassificationFeatureHistogram::Numeric {
2459                bin_class_counts,
2460                observed_bins,
2461            },
2462        ) => score_numeric_cart_split_choice_from_hist(
2463            context,
2464            feature_index,
2465            parent_counts,
2466            rows.len(),
2467            bin_class_counts,
2468            observed_bins,
2469        ),
2470        (
2471            DecisionTreeAlgorithm::Randomized,
2472            ClassificationFeatureHistogram::Binary {
2473                false_counts,
2474                true_counts,
2475                false_size,
2476                true_size,
2477            },
2478        ) => score_binary_cart_split_choice_from_counts(
2479            context,
2480            feature_index,
2481            parent_counts,
2482            false_counts,
2483            *false_size,
2484            true_counts,
2485            *true_size,
2486        ),
2487        (
2488            DecisionTreeAlgorithm::Randomized,
2489            ClassificationFeatureHistogram::Numeric { observed_bins, .. },
2490        ) => score_numeric_randomized_split_choice_from_hist(
2491            context,
2492            feature_index,
2493            rows,
2494            parent_counts,
2495            observed_bins,
2496            histogram,
2497        ),
2498        _ => None,
2499    }
2500}
2501
2502fn score_binary_cart_split_choice_from_counts(
2503    context: &SplitScoringContext<'_>,
2504    feature_index: usize,
2505    parent_counts: &[usize],
2506    left_counts: &[usize],
2507    left_size: usize,
2508    right_counts: &[usize],
2509    right_size: usize,
2510) -> Option<BinarySplitChoice> {
2511    if left_size < context.min_samples_leaf || right_size < context.min_samples_leaf {
2512        return None;
2513    }
2514    let parent_impurity =
2515        classification_impurity(parent_counts, left_size + right_size, context.criterion);
2516    let weighted_impurity = (left_size as f64 / (left_size + right_size) as f64)
2517        * classification_impurity(left_counts, left_size, context.criterion)
2518        + (right_size as f64 / (left_size + right_size) as f64)
2519            * classification_impurity(right_counts, right_size, context.criterion);
2520    Some(BinarySplitChoice {
2521        feature_index,
2522        score: parent_impurity - weighted_impurity,
2523        threshold_bin: 0,
2524    })
2525}
2526
/// Finds the best CART threshold for a numeric feature from a pre-built
/// per-bin class histogram, without touching the row data.
///
/// Sweeps `observed_bins` in ascending order with running left-side counts,
/// so each candidate threshold costs O(num_classes). Returns `None` when
/// fewer than two bins were observed or when no threshold satisfies
/// `min_samples_leaf` on both sides.
fn score_numeric_cart_split_choice_from_hist(
    context: &SplitScoringContext<'_>,
    feature_index: usize,
    parent_counts: &[usize],
    row_count: usize,
    bin_class_counts: &[Vec<usize>],
    observed_bins: &[usize],
) -> Option<BinarySplitChoice> {
    if observed_bins.len() <= 1 {
        return None;
    }
    let parent_impurity = classification_impurity(parent_counts, row_count, context.criterion);
    // Running totals for rows in bins at or below the current threshold.
    let mut left_counts = vec![0usize; context.num_classes];
    let mut left_size = 0usize;
    let mut best_threshold = None;
    let mut best_score = f64::NEG_INFINITY;

    for &bin in observed_bins {
        for class_index in 0..context.num_classes {
            left_counts[class_index] += bin_class_counts[bin][class_index];
        }
        left_size += bin_class_counts[bin].iter().sum::<usize>();
        let right_size = row_count - left_size;
        if left_size < context.min_samples_leaf || right_size < context.min_samples_leaf {
            continue;
        }
        // Right-side counts derive from the parent totals by subtraction.
        let right_counts = parent_counts
            .iter()
            .zip(left_counts.iter())
            .map(|(parent, left)| parent - left)
            .collect::<Vec<_>>();
        let weighted_impurity = (left_size as f64 / row_count as f64)
            * classification_impurity(&left_counts, left_size, context.criterion)
            + (right_size as f64 / row_count as f64)
                * classification_impurity(&right_counts, right_size, context.criterion);
        let score = parent_impurity - weighted_impurity;
        // Strict '>' keeps the lowest-bin threshold on score ties.
        if score > best_score {
            best_score = score;
            best_threshold = Some(bin as u16);
        }
    }

    best_threshold.map(|threshold_bin| BinarySplitChoice {
        feature_index,
        score: best_score,
        threshold_bin,
    })
}
2575
2576fn score_numeric_randomized_split_choice_from_hist(
2577    context: &SplitScoringContext<'_>,
2578    feature_index: usize,
2579    rows: &[usize],
2580    parent_counts: &[usize],
2581    observed_bins: &[usize],
2582    histogram: &ClassificationFeatureHistogram,
2583) -> Option<BinarySplitChoice> {
2584    let candidate_thresholds = observed_bins
2585        .iter()
2586        .copied()
2587        .map(|bin| bin as u16)
2588        .collect::<Vec<_>>();
2589    let threshold_bin =
2590        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xC1A551F1u64)?;
2591    let ClassificationFeatureHistogram::Numeric {
2592        bin_class_counts, ..
2593    } = histogram
2594    else {
2595        unreachable!("randomized numeric histogram must be numeric");
2596    };
2597    let mut left_counts = vec![0usize; context.num_classes];
2598    let mut left_size = 0usize;
2599    for bin in 0..=threshold_bin as usize {
2600        if bin >= bin_class_counts.len() {
2601            break;
2602        }
2603        for class_index in 0..context.num_classes {
2604            left_counts[class_index] += bin_class_counts[bin][class_index];
2605        }
2606        left_size += bin_class_counts[bin].iter().sum::<usize>();
2607    }
2608    let row_count = rows.len();
2609    let right_size = row_count - left_size;
2610    if left_size < context.min_samples_leaf || right_size < context.min_samples_leaf {
2611        return None;
2612    }
2613    let right_counts = parent_counts
2614        .iter()
2615        .zip(left_counts.iter())
2616        .map(|(parent, left)| parent - left)
2617        .collect::<Vec<_>>();
2618    let parent_impurity = classification_impurity(parent_counts, row_count, context.criterion);
2619    let weighted_impurity = (left_size as f64 / row_count as f64)
2620        * classification_impurity(&left_counts, left_size, context.criterion)
2621        + (right_size as f64 / row_count as f64)
2622            * classification_impurity(&right_counts, right_size, context.criterion);
2623    Some(BinarySplitChoice {
2624        feature_index,
2625        score: parent_impurity - weighted_impurity,
2626        threshold_bin,
2627    })
2628}
2629
2630#[allow(dead_code)]
2631fn score_binary_cart_split_choice(
2632    context: &SplitScoringContext<'_>,
2633    feature_index: usize,
2634    rows: &[usize],
2635) -> Option<BinarySplitChoice> {
2636    let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
2637    let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
2638    let mut left_counts = vec![0usize; context.num_classes];
2639    let mut left_size = 0usize;
2640
2641    for row_idx in rows {
2642        if !context
2643            .table
2644            .binned_boolean_value(feature_index, *row_idx)
2645            .expect("binary feature must expose boolean values")
2646        {
2647            left_counts[context.class_indices[*row_idx]] += 1;
2648            left_size += 1;
2649        }
2650    }
2651
2652    let right_size = rows.len() - left_size;
2653    if left_size < context.min_samples_leaf || right_size < context.min_samples_leaf {
2654        return None;
2655    }
2656
2657    let right_counts = parent_counts
2658        .iter()
2659        .zip(left_counts.iter())
2660        .map(|(parent, left)| parent - left)
2661        .collect::<Vec<_>>();
2662    let weighted_impurity = (left_size as f64 / rows.len() as f64)
2663        * classification_impurity(&left_counts, left_size, context.criterion)
2664        + (right_size as f64 / rows.len() as f64)
2665            * classification_impurity(&right_counts, right_size, context.criterion);
2666
2667    Some(BinarySplitChoice {
2668        feature_index,
2669        score: parent_impurity - weighted_impurity,
2670        threshold_bin: 0,
2671    })
2672}
2673
2674#[allow(dead_code)]
2675fn score_numeric_cart_split_choice_fast(
2676    context: &SplitScoringContext<'_>,
2677    feature_index: usize,
2678    rows: &[usize],
2679) -> Option<BinarySplitChoice> {
2680    let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
2681    let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
2682    let bin_cap = context.table.numeric_bin_cap();
2683    if bin_cap == 0 {
2684        return None;
2685    }
2686
2687    let mut bin_class_counts = vec![vec![0usize; context.num_classes]; bin_cap];
2688    let mut observed_bins = vec![false; bin_cap];
2689    for row_idx in rows {
2690        let bin = context.table.binned_value(feature_index, *row_idx) as usize;
2691        if bin >= bin_cap {
2692            return None;
2693        }
2694        bin_class_counts[bin][context.class_indices[*row_idx]] += 1;
2695        observed_bins[bin] = true;
2696    }
2697
2698    let observed_bins: Vec<usize> = observed_bins
2699        .into_iter()
2700        .enumerate()
2701        .filter_map(|(bin, seen)| seen.then_some(bin))
2702        .collect();
2703    if observed_bins.len() <= 1 {
2704        return None;
2705    }
2706
2707    let mut left_counts = vec![0usize; context.num_classes];
2708    let mut left_size = 0usize;
2709    let mut best_threshold = None;
2710    let mut best_score = f64::NEG_INFINITY;
2711
2712    for &bin in &observed_bins {
2713        for class_index in 0..context.num_classes {
2714            left_counts[class_index] += bin_class_counts[bin][class_index];
2715        }
2716        left_size += bin_class_counts[bin].iter().sum::<usize>();
2717        let right_size = rows.len() - left_size;
2718
2719        if left_size < context.min_samples_leaf || right_size < context.min_samples_leaf {
2720            continue;
2721        }
2722
2723        let right_counts = parent_counts
2724            .iter()
2725            .zip(left_counts.iter())
2726            .map(|(parent, left)| parent - left)
2727            .collect::<Vec<_>>();
2728        let weighted_impurity = (left_size as f64 / rows.len() as f64)
2729            * classification_impurity(&left_counts, left_size, context.criterion)
2730            + (right_size as f64 / rows.len() as f64)
2731                * classification_impurity(&right_counts, right_size, context.criterion);
2732        let score = parent_impurity - weighted_impurity;
2733        if score > best_score {
2734            best_score = score;
2735            best_threshold = Some(bin as u16);
2736        }
2737    }
2738
2739    best_threshold.map(|threshold_bin| BinarySplitChoice {
2740        feature_index,
2741        score: best_score,
2742        threshold_bin,
2743    })
2744}
2745
2746#[allow(dead_code)]
2747fn score_numeric_randomized_split_choice_fast(
2748    context: &SplitScoringContext<'_>,
2749    feature_index: usize,
2750    rows: &[usize],
2751) -> Option<BinarySplitChoice> {
2752    let bin_cap = context.table.numeric_bin_cap();
2753    if bin_cap == 0 {
2754        return None;
2755    }
2756    let mut observed_bins = vec![false; bin_cap];
2757    for row_idx in rows {
2758        let bin = context.table.binned_value(feature_index, *row_idx) as usize;
2759        if bin >= bin_cap {
2760            return None;
2761        }
2762        observed_bins[bin] = true;
2763    }
2764    let candidate_thresholds = observed_bins
2765        .into_iter()
2766        .enumerate()
2767        .filter_map(|(bin, seen)| seen.then_some(bin as u16))
2768        .collect::<Vec<_>>();
2769    let threshold_bin =
2770        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xC1A551F1u64)?;
2771
2772    let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
2773    let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
2774    let mut left_counts = vec![0usize; context.num_classes];
2775    let mut left_size = 0usize;
2776    for row_idx in rows {
2777        if context.table.binned_value(feature_index, *row_idx) <= threshold_bin {
2778            left_counts[context.class_indices[*row_idx]] += 1;
2779            left_size += 1;
2780        }
2781    }
2782    let right_size = rows.len() - left_size;
2783    if left_size < context.min_samples_leaf || right_size < context.min_samples_leaf {
2784        return None;
2785    }
2786    let right_counts = parent_counts
2787        .iter()
2788        .zip(left_counts.iter())
2789        .map(|(parent, left)| parent - left)
2790        .collect::<Vec<_>>();
2791    let weighted_impurity = (left_size as f64 / rows.len() as f64)
2792        * classification_impurity(&left_counts, left_size, context.criterion)
2793        + (right_size as f64 / rows.len() as f64)
2794            * classification_impurity(&right_counts, right_size, context.criterion);
2795
2796    Some(BinarySplitChoice {
2797        feature_index,
2798        score: parent_impurity - weighted_impurity,
2799        threshold_bin,
2800    })
2801}
2802
2803fn partition_rows_for_binary_split(
2804    table: &dyn TableAccess,
2805    feature_index: usize,
2806    threshold_bin: u16,
2807    rows: &mut [usize],
2808) -> usize {
2809    let mut left = 0usize;
2810    for index in 0..rows.len() {
2811        let go_left = if table.is_binary_binned_feature(feature_index) {
2812            !table
2813                .binned_boolean_value(feature_index, rows[index])
2814                .expect("binary feature must expose boolean values")
2815        } else {
2816            table.binned_value(feature_index, rows[index]) <= threshold_bin
2817        };
2818        if go_left {
2819            rows.swap(left, index);
2820            left += 1;
2821        }
2822    }
2823    left
2824}
2825
2826fn partition_rows_for_multiway_split(
2827    table: &dyn TableAccess,
2828    feature_index: usize,
2829    branch_bins: &[u16],
2830    rows: &mut [usize],
2831) -> Vec<(u16, usize, usize)> {
2832    let mut scratch = vec![0usize; rows.len()];
2833    let mut counts = vec![0usize; branch_bins.len()];
2834
2835    for row_idx in rows.iter().copied() {
2836        let bin = if table.is_binary_binned_feature(feature_index) {
2837            if table
2838                .binned_boolean_value(feature_index, row_idx)
2839                .expect("binary feature must expose boolean values")
2840            {
2841                1
2842            } else {
2843                0
2844            }
2845        } else {
2846            table.binned_value(feature_index, row_idx)
2847        };
2848        let branch_index = branch_bins
2849            .binary_search(&bin)
2850            .expect("branch bins must cover all observed bins");
2851        counts[branch_index] += 1;
2852    }
2853
2854    let mut offsets = Vec::with_capacity(branch_bins.len());
2855    let mut next = 0usize;
2856    for count in &counts {
2857        offsets.push(next);
2858        next += *count;
2859    }
2860    let mut write_positions = offsets.clone();
2861    for row_idx in rows.iter().copied() {
2862        let bin = if table.is_binary_binned_feature(feature_index) {
2863            if table
2864                .binned_boolean_value(feature_index, row_idx)
2865                .expect("binary feature must expose boolean values")
2866            {
2867                1
2868            } else {
2869                0
2870            }
2871        } else {
2872            table.binned_value(feature_index, row_idx)
2873        };
2874        let branch_index = branch_bins
2875            .binary_search(&bin)
2876            .expect("branch bins must cover all observed bins");
2877        let write_index = write_positions[branch_index];
2878        scratch[write_index] = row_idx;
2879        write_positions[branch_index] += 1;
2880    }
2881    rows.copy_from_slice(&scratch);
2882
2883    branch_bins
2884        .iter()
2885        .copied()
2886        .zip(offsets)
2887        .zip(counts)
2888        .map(|((bin, start), count)| (bin, start, start + count))
2889        .collect()
2890}
2891
2892fn choose_random_threshold(
2893    candidate_thresholds: &[u16],
2894    feature_index: usize,
2895    rows: &[usize],
2896    salt: u64,
2897) -> Option<u16> {
2898    if candidate_thresholds.is_empty() {
2899        return None;
2900    }
2901
2902    let mut seed = salt ^ ((feature_index as u64) << 32) ^ (rows.len() as u64);
2903    for row_idx in rows {
2904        seed = seed
2905            .wrapping_mul(6364136223846793005)
2906            .wrapping_add((*row_idx as u64) + 1);
2907    }
2908    let mut rng = StdRng::seed_from_u64(seed);
2909    let selected = rng.gen_range(0..candidate_thresholds.len());
2910    candidate_thresholds.get(selected).copied()
2911}
2912
2913fn candidate_feature_indices(
2914    feature_count: usize,
2915    max_features: Option<usize>,
2916    seed: u64,
2917) -> Vec<usize> {
2918    match max_features {
2919        Some(count) => sample_feature_subset(feature_count, count, seed),
2920        None => (0..feature_count).collect(),
2921    }
2922}
2923
/// Derives a deterministic per-node RNG seed from the tree's base seed, the
/// node depth, the node's row indices, and a caller-supplied salt.
fn node_seed(base_seed: u64, depth: usize, rows: &[usize], salt: u64) -> u64 {
    let depth_mix = (depth as u64)
        .wrapping_mul(0x9E37_79B9_7F4A_7C15)
        .rotate_left(11);
    let mut seed = base_seed ^ salt ^ depth_mix;
    // Mix every row index in, so distinct row sets yield distinct seeds.
    for &row_index in rows {
        seed = seed.wrapping_mul(0xA076_1D64_78BD_642F)
            ^ (row_index as u64).wrapping_add(0xE703_7ED1_A0B4_28DB);
    }
    seed
}
2937
2938#[allow(dead_code)]
2939fn split_feature_index(candidate: &SplitCandidate) -> usize {
2940    match candidate {
2941        SplitCandidate::Multiway { feature_index, .. }
2942        | SplitCandidate::Binary { feature_index, .. } => *feature_index,
2943    }
2944}
2945
2946fn push_leaf(
2947    nodes: &mut Vec<TreeNode>,
2948    class_index: usize,
2949    sample_count: usize,
2950    class_counts: Vec<usize>,
2951) -> usize {
2952    push_node(
2953        nodes,
2954        TreeNode::Leaf {
2955            class_index,
2956            sample_count,
2957            class_counts,
2958        },
2959    )
2960}
2961
2962fn push_node(nodes: &mut Vec<TreeNode>, node: TreeNode) -> usize {
2963    nodes.push(node);
2964    nodes.len() - 1
2965}
2966
/// Metric used to score multiway split candidates.
#[derive(Debug, Clone, Copy)]
enum MultiwayMetric {
    // Plain information gain (entropy reduction); presumably the ID3 path — confirm at call sites.
    InformationGain,
    // Gain ratio (gain normalized by intrinsic information); presumably the C4.5 path — confirm at call sites.
    GainRatio,
}
2972
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{FeaturePreprocessing, Model, NumericBinBoundary};
    use forestfire_data::{DenseTable, NumericBins};

    // Fixture: two boolean features whose target is their logical AND; each
    // truth-table row appears twice so both classes have enough support.
    fn and_table() -> DenseTable {
        DenseTable::new(
            vec![
                vec![0.0, 0.0],
                vec![0.0, 1.0],
                vec![1.0, 0.0],
                vec![1.0, 1.0],
                vec![0.0, 0.0],
                vec![0.0, 1.0],
                vec![1.0, 0.0],
                vec![1.0, 1.0],
            ],
            vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
        )
        .unwrap()
    }

    // Fixture: numeric table crafted so that Gini and Entropy prefer different
    // root features (asserted in cart_can_choose_between_gini_and_entropy).
    fn criterion_choice_table() -> DenseTable {
        DenseTable::with_options(
            vec![
                vec![0.0, 1.0],
                vec![4.0, 1.0],
                vec![4.0, 0.0],
                vec![0.0, 1.0],
                vec![5.0, 2.0],
                vec![2.0, 4.0],
                vec![1.0, 2.0],
            ],
            vec![0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0],
            0,
            NumericBins::Fixed(8),
        )
        .unwrap()
    }

    // Fixture: a table whose target is perfectly predicted by the extra column
    // at index n_features() — presumably the canary feature implied by the
    // third `with_options` argument (1); confirm against DenseTable docs.
    // Trainers must refuse to split on it, leaving a constant prediction.
    fn canary_target_table() -> DenseTable {
        let x: Vec<Vec<f64>> = (0..8).map(|value| vec![value as f64]).collect();
        let probe =
            DenseTable::with_options(x.clone(), vec![0.0; 8], 1, NumericBins::Auto).unwrap();
        let canary_index = probe.n_features();
        let mut observed_bins = (0..probe.n_rows())
            .map(|row_idx| probe.binned_value(canary_index, row_idx))
            .collect::<Vec<_>>();
        observed_bins.sort_unstable();
        observed_bins.dedup();
        // Split the observed canary bins at their median so both classes occur.
        let threshold = observed_bins[observed_bins.len() / 2];
        let y = (0..probe.n_rows())
            .map(|row_idx| {
                if probe.binned_value(canary_index, row_idx) >= threshold {
                    1.0
                } else {
                    0.0
                }
            })
            .collect();

        DenseTable::with_options(x, y, 1, NumericBins::Auto).unwrap()
    }

    // ID3 must reproduce the AND pattern exactly on its training data.
    #[test]
    fn id3_fits_basic_boolean_pattern() {
        let table = and_table();
        let model = train_id3(&table).unwrap();

        assert_eq!(model.algorithm(), DecisionTreeAlgorithm::Id3);
        assert_eq!(model.criterion(), Criterion::Entropy);
        assert_eq!(model.predict_table(&table), table_targets(&table));
    }

    // C4.5 must reproduce the AND pattern exactly on its training data.
    #[test]
    fn c45_fits_basic_boolean_pattern() {
        let table = and_table();
        let model = train_c45(&table).unwrap();

        assert_eq!(model.algorithm(), DecisionTreeAlgorithm::C45);
        assert_eq!(model.criterion(), Criterion::Entropy);
        assert_eq!(model.predict_table(&table), table_targets(&table));
    }

    // CART must reproduce the AND pattern exactly on its training data.
    #[test]
    fn cart_fits_basic_boolean_pattern() {
        let table = and_table();
        let model = train_cart(&table).unwrap();

        assert_eq!(model.algorithm(), DecisionTreeAlgorithm::Cart);
        assert_eq!(model.criterion(), Criterion::Gini);
        assert_eq!(model.predict_table(&table), table_targets(&table));
    }

    // The randomized trainer must still fit the AND pattern despite its
    // random threshold selection.
    #[test]
    fn randomized_fits_basic_boolean_pattern() {
        let table = and_table();
        let model = train_randomized(&table).unwrap();

        assert_eq!(model.algorithm(), DecisionTreeAlgorithm::Randomized);
        assert_eq!(model.criterion(), Criterion::Gini);
        assert_eq!(model.predict_table(&table), table_targets(&table));
    }

    // The oblivious trainer (one shared split per level) must also fit AND.
    #[test]
    fn oblivious_fits_basic_boolean_pattern() {
        let table = and_table();
        let model = train_oblivious(&table).unwrap();

        assert_eq!(model.algorithm(), DecisionTreeAlgorithm::Oblivious);
        assert_eq!(model.criterion(), Criterion::Gini);
        assert_eq!(model.predict_table(&table), table_targets(&table));
    }

    // The criterion must actually change CART's split choice: on this table
    // Gini picks feature 0 for the root while Entropy picks feature 1.
    #[test]
    fn cart_can_choose_between_gini_and_entropy() {
        let table = criterion_choice_table();
        let options = DecisionTreeOptions {
            max_depth: 1,
            ..DecisionTreeOptions::default()
        };
        let gini_model = train_classifier(
            &table,
            DecisionTreeAlgorithm::Cart,
            Criterion::Gini,
            Parallelism::sequential(),
            options,
        )
        .unwrap();
        let entropy_model = train_classifier(
            &table,
            DecisionTreeAlgorithm::Cart,
            Criterion::Entropy,
            Parallelism::sequential(),
            options,
        )
        .unwrap();

        // Helper: digs the feature index out of the root binary split.
        let root_feature = |model: &DecisionTreeClassifier| match &model.structure {
            TreeStructure::Standard { nodes, root } => match &nodes[*root] {
                TreeNode::BinarySplit { feature_index, .. } => *feature_index,
                node => panic!("expected binary root split, found {node:?}"),
            },
            TreeStructure::Oblivious { .. } => panic!("expected standard tree"),
        };

        assert_eq!(gini_model.criterion(), Criterion::Gini);
        assert_eq!(entropy_model.criterion(), Criterion::Entropy);
        assert_eq!(root_feature(&gini_model), 0);
        assert_eq!(root_feature(&entropy_model), 1);
    }

    // A NaN class label must be rejected with the row/value in the error.
    #[test]
    fn rejects_non_finite_class_labels() {
        let table = DenseTable::new(vec![vec![0.0], vec![1.0]], vec![0.0, f64::NAN]).unwrap();

        let err = train_id3(&table).unwrap_err();
        assert!(matches!(
            err,
            DecisionTreeError::InvalidTargetValue { row: 1, value } if value.is_nan()
        ));
    }

    // When only the canary column explains the target, standard trainers must
    // stop growing and emit a constant (and therefore wrong) prediction.
    #[test]
    fn stops_standard_tree_growth_when_a_canary_wins() {
        let table = canary_target_table();
        for trainer in [train_id3, train_c45, train_cart] {
            let model = trainer(&table).unwrap();
            let preds = model.predict_table(&table);

            assert!(preds.iter().all(|pred| *pred == preds[0]));
            assert_ne!(preds, table_targets(&table));
        }
    }

    // Same canary stopping behavior for the oblivious trainer.
    #[test]
    fn stops_oblivious_tree_growth_when_a_canary_wins() {
        let table = canary_target_table();
        let model = train_oblivious(&table).unwrap();
        let preds = model.predict_table(&table);

        assert!(preds.iter().all(|pred| *pred == preds[0]));
        assert_ne!(preds, table_targets(&table));
    }

    // Hand-builds one model per algorithm and checks the serialized JSON
    // carries the expected tree_type tag and the classification task marker.
    #[test]
    fn manually_built_classifier_models_serialize_for_each_tree_type() {
        let preprocessing = vec![
            FeaturePreprocessing::Binary,
            FeaturePreprocessing::Numeric {
                bin_boundaries: vec![
                    NumericBinBoundary {
                        bin: 0,
                        upper_bound: 1.0,
                    },
                    NumericBinBoundary {
                        bin: 127,
                        upper_bound: 10.0,
                    },
                ],
            },
        ];
        let options = DecisionTreeOptions::default();
        let class_labels = vec![10.0, 20.0];

        let id3 = Model::DecisionTreeClassifier(DecisionTreeClassifier {
            algorithm: DecisionTreeAlgorithm::Id3,
            criterion: Criterion::Entropy,
            class_labels: class_labels.clone(),
            structure: TreeStructure::Standard {
                nodes: vec![
                    TreeNode::Leaf {
                        class_index: 0,
                        sample_count: 3,
                        class_counts: vec![3, 0],
                    },
                    TreeNode::Leaf {
                        class_index: 1,
                        sample_count: 2,
                        class_counts: vec![0, 2],
                    },
                    TreeNode::MultiwaySplit {
                        feature_index: 1,
                        fallback_class_index: 0,
                        branches: vec![(0, 0), (127, 1)],
                        sample_count: 5,
                        impurity: 0.48,
                        gain: 0.24,
                        class_counts: vec![3, 2],
                    },
                ],
                root: 2,
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing.clone(),
            training_canaries: 0,
        });
        let c45 = Model::DecisionTreeClassifier(DecisionTreeClassifier {
            algorithm: DecisionTreeAlgorithm::C45,
            criterion: Criterion::Entropy,
            class_labels: class_labels.clone(),
            structure: TreeStructure::Standard {
                nodes: vec![
                    TreeNode::Leaf {
                        class_index: 0,
                        sample_count: 3,
                        class_counts: vec![3, 0],
                    },
                    TreeNode::Leaf {
                        class_index: 1,
                        sample_count: 2,
                        class_counts: vec![0, 2],
                    },
                    TreeNode::MultiwaySplit {
                        feature_index: 1,
                        fallback_class_index: 0,
                        branches: vec![(0, 0), (127, 1)],
                        sample_count: 5,
                        impurity: 0.48,
                        gain: 0.24,
                        class_counts: vec![3, 2],
                    },
                ],
                root: 2,
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing.clone(),
            training_canaries: 0,
        });
        let cart = Model::DecisionTreeClassifier(DecisionTreeClassifier {
            algorithm: DecisionTreeAlgorithm::Cart,
            criterion: Criterion::Gini,
            class_labels: class_labels.clone(),
            structure: TreeStructure::Standard {
                nodes: vec![
                    TreeNode::Leaf {
                        class_index: 0,
                        sample_count: 3,
                        class_counts: vec![3, 0],
                    },
                    TreeNode::Leaf {
                        class_index: 1,
                        sample_count: 2,
                        class_counts: vec![0, 2],
                    },
                    TreeNode::BinarySplit {
                        feature_index: 0,
                        threshold_bin: 0,
                        left_child: 0,
                        right_child: 1,
                        sample_count: 5,
                        impurity: 0.48,
                        gain: 0.24,
                        class_counts: vec![3, 2],
                    },
                ],
                root: 2,
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing.clone(),
            training_canaries: 0,
        });
        let randomized = Model::DecisionTreeClassifier(DecisionTreeClassifier {
            algorithm: DecisionTreeAlgorithm::Randomized,
            criterion: Criterion::Entropy,
            class_labels: class_labels.clone(),
            structure: TreeStructure::Standard {
                nodes: vec![
                    TreeNode::Leaf {
                        class_index: 0,
                        sample_count: 3,
                        class_counts: vec![3, 0],
                    },
                    TreeNode::Leaf {
                        class_index: 1,
                        sample_count: 2,
                        class_counts: vec![0, 2],
                    },
                    TreeNode::BinarySplit {
                        feature_index: 0,
                        threshold_bin: 0,
                        left_child: 0,
                        right_child: 1,
                        sample_count: 5,
                        impurity: 0.48,
                        gain: 0.2,
                        class_counts: vec![3, 2],
                    },
                ],
                root: 2,
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing.clone(),
            training_canaries: 0,
        });
        let oblivious = Model::DecisionTreeClassifier(DecisionTreeClassifier {
            algorithm: DecisionTreeAlgorithm::Oblivious,
            criterion: Criterion::Gini,
            class_labels,
            structure: TreeStructure::Oblivious {
                splits: vec![ObliviousSplit {
                    feature_index: 0,
                    threshold_bin: 0,
                    sample_count: 4,
                    impurity: 0.5,
                    gain: 0.25,
                }],
                leaf_class_indices: vec![0, 1],
                leaf_sample_counts: vec![2, 2],
                leaf_class_counts: vec![vec![2, 0], vec![0, 2]],
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing,
            training_canaries: 0,
        });

        for (tree_type, model) in [
            ("id3", id3),
            ("c45", c45),
            ("cart", cart),
            ("randomized", randomized),
            ("oblivious", oblivious),
        ] {
            let json = model.serialize().unwrap();
            assert!(json.contains(&format!("\"tree_type\":\"{tree_type}\"")));
            assert!(json.contains("\"task\":\"classification\""));
        }
    }

    // Collects the target column of a table as a plain vector for comparison.
    fn table_targets(table: &dyn TableAccess) -> Vec<f64> {
        (0..table.n_rows())
            .map(|row_idx| table.target_value(row_idx))
            .collect()
    }
}