1use crate::ir::{
2 BinaryChildren, BinarySplit, IndexedLeaf, LeafIndexing, LeafPayload, MultiwayBranch,
3 MultiwaySplit, NodeStats, NodeTreeNode, ObliviousLevel, ObliviousSplit as IrObliviousSplit,
4 TrainingMetadata, TreeDefinition, criterion_name, feature_name, threshold_upper_bound,
5 tree_type_name,
6};
7use crate::sampling::sample_feature_subset;
8use crate::{Criterion, FeaturePreprocessing, Parallelism, capture_feature_preprocessing};
9use forestfire_data::TableAccess;
10use rand::rngs::StdRng;
11use rand::{Rng, SeedableRng};
12use rayon::prelude::*;
13use std::collections::{BTreeMap, BTreeSet};
14use std::error::Error;
15use std::fmt::{Display, Formatter};
16
/// The induction algorithm used to grow a [`DecisionTreeClassifier`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecisionTreeAlgorithm {
    /// Multiway splits scored by information gain.
    Id3,
    /// Multiway splits scored by gain ratio.
    C45,
    /// Binary splits; defaults to the Gini criterion.
    Cart,
    /// Binary-split variant sharing the CART builder; defaults to Gini.
    Randomized,
    /// Symmetric tree: a single (feature, threshold) split per depth level,
    /// with leaves addressed by the bit pattern of level decisions.
    Oblivious,
}
25
/// Hyperparameters controlling tree growth.
#[derive(Debug, Clone, Copy)]
pub struct DecisionTreeOptions {
    /// Maximum depth; nodes at this depth become leaves.
    pub max_depth: usize,
    /// Minimum number of samples a node needs to be considered for splitting.
    pub min_samples_split: usize,
    /// Minimum samples per leaf, enforced during split scoring.
    pub min_samples_leaf: usize,
    /// Number of candidate features sampled per node; `None` uses all features.
    pub max_features: Option<usize>,
    /// Seed for the deterministic per-node feature subsampling.
    pub random_seed: u64,
}
34
35impl Default for DecisionTreeOptions {
36 fn default() -> Self {
37 Self {
38 max_depth: 8,
39 min_samples_split: 2,
40 min_samples_leaf: 1,
41 max_features: None,
42 random_seed: 0,
43 }
44 }
45}
46
/// Errors produced while training a decision tree classifier.
#[derive(Debug)]
pub enum DecisionTreeError {
    /// The training table contained no rows.
    EmptyTarget,
    /// A target value was NaN or infinite at the given row.
    InvalidTargetValue { row: usize, value: f64 },
}
52
53impl Display for DecisionTreeError {
54 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
55 match self {
56 DecisionTreeError::EmptyTarget => write!(f, "Cannot train on an empty target vector."),
57 DecisionTreeError::InvalidTargetValue { row, value } => write!(
58 f,
59 "Classification targets must be finite values. Found {} at row {}.",
60 value, row
61 ),
62 }
63 }
64}
65
66impl Error for DecisionTreeError {}
67
/// A fitted decision-tree classifier over binned tabular features.
#[derive(Debug, Clone)]
pub struct DecisionTreeClassifier {
    /// Algorithm used to grow the tree.
    algorithm: DecisionTreeAlgorithm,
    /// Impurity criterion used for split scoring.
    criterion: Criterion,
    /// Sorted distinct target values; predictions index into this vector.
    class_labels: Vec<f64>,
    /// Node-based or oblivious level-based tree layout.
    structure: TreeStructure,
    /// Hyperparameters the tree was trained with.
    options: DecisionTreeOptions,
    /// Number of features in the training table.
    num_features: usize,
    /// Per-feature preprocessing captured at training time (used by IR export).
    feature_preprocessing: Vec<FeaturePreprocessing>,
    /// Number of canary features in the training table (recorded in metadata).
    training_canaries: usize,
}
79
/// Internal layout of a fitted tree.
#[derive(Debug, Clone)]
pub(crate) enum TreeStructure {
    /// Conventional tree stored as a flat node arena; children are referenced
    /// by index into `nodes`.
    Standard {
        nodes: Vec<TreeNode>,
        root: usize,
    },
    /// Oblivious (symmetric) tree: one split per level. Leaves are addressed
    /// by the bit pattern of level decisions, first level as the most
    /// significant bit (see `predict_row`).
    Oblivious {
        splits: Vec<ObliviousSplit>,
        /// Predicted class index per leaf slot.
        leaf_class_indices: Vec<usize>,
        /// Training sample count per leaf slot.
        leaf_sample_counts: Vec<usize>,
        /// Per-class training counts per leaf slot (used for probabilities).
        leaf_class_counts: Vec<Vec<usize>>,
    },
}
93
/// One level of an oblivious tree: a single binned threshold applied to every
/// sample at that depth, plus bookkeeping stats for IR export.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ObliviousSplit {
    pub(crate) feature_index: usize,
    /// Samples go right when their binned value exceeds this bin.
    pub(crate) threshold_bin: u16,
    pub(crate) sample_count: usize,
    pub(crate) impurity: f64,
    pub(crate) gain: f64,
}
102
/// A node in a standard (non-oblivious) tree arena.
#[derive(Debug, Clone)]
pub(crate) enum TreeNode {
    /// Terminal node predicting `class_index`.
    Leaf {
        class_index: usize,
        sample_count: usize,
        /// Per-class training counts (used for probability output).
        class_counts: Vec<usize>,
    },
    /// ID3/C4.5-style split with one child per observed bin value.
    MultiwaySplit {
        feature_index: usize,
        /// Majority class used when a bin unseen in training is encountered.
        fallback_class_index: usize,
        /// (bin value, child node index) pairs.
        branches: Vec<(u16, usize)>,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        class_counts: Vec<usize>,
    },
    /// CART-style split: `bin <= threshold_bin` goes left, otherwise right.
    BinarySplit {
        feature_index: usize,
        threshold_bin: u16,
        left_child: usize,
        right_child: usize,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        class_counts: Vec<usize>,
    },
}
130
/// A scored candidate split together with the row partition it induces.
/// Kept for the legacy `build_node` path (see `#[allow(dead_code)]`).
#[derive(Debug, Clone)]
#[allow(dead_code)]
enum SplitCandidate {
    /// One branch per observed bin value.
    Multiway {
        feature_index: usize,
        score: f64,
        /// (bin value, rows falling into that bin) pairs.
        branches: Vec<(u16, Vec<usize>)>,
    },
    /// Two-way partition at `threshold_bin` (left: `bin <= threshold`).
    Binary {
        feature_index: usize,
        score: f64,
        threshold_bin: u16,
        left_rows: Vec<usize>,
        right_rows: Vec<usize>,
    },
}
147
/// A scored binary split choice; rows are partitioned later, in place.
#[derive(Debug, Clone, Copy)]
struct BinarySplitChoice {
    feature_index: usize,
    score: f64,
    threshold_bin: u16,
}
154
/// A scored multiway split choice; rows are partitioned later, in place.
#[derive(Debug, Clone)]
struct MultiwaySplitChoice {
    feature_index: usize,
    score: f64,
    /// The distinct bin values that become branches.
    branch_bins: Vec<u16>,
}
161
/// Per-feature class histogram for the rows of one node, used to score
/// candidate splits without rescanning rows per threshold.
#[derive(Debug, Clone)]
enum ClassificationFeatureHistogram {
    /// Binary feature: class counts for the false/true sides plus side totals.
    Binary {
        false_counts: Vec<usize>,
        true_counts: Vec<usize>,
        false_size: usize,
        true_size: usize,
    },
    /// Numeric feature: per-bin class counts plus the list of non-empty bins.
    Numeric {
        bin_class_counts: Vec<Vec<usize>>,
        observed_bins: Vec<usize>,
    },
}
175
/// Trains an ID3 classifier (multiway splits) with the entropy criterion and
/// default options.
pub fn train_id3(train_set: &dyn TableAccess) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion(train_set, Criterion::Entropy)
}

/// Trains an ID3 classifier with a caller-chosen impurity criterion.
pub fn train_id3_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Trains an ID3 classifier with explicit parallelism and default options.
pub(crate) fn train_id3_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized ID3 entry point; delegates to the shared trainer.
pub(crate) fn train_id3_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Id3,
        criterion,
        parallelism,
        options,
    )
}
214
/// Trains a C4.5 classifier (multiway splits, gain ratio) with the entropy
/// criterion and default options.
pub fn train_c45(train_set: &dyn TableAccess) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion(train_set, Criterion::Entropy)
}

/// Trains a C4.5 classifier with a caller-chosen impurity criterion.
pub fn train_c45_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Trains a C4.5 classifier with explicit parallelism and default options.
pub(crate) fn train_c45_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized C4.5 entry point; delegates to the shared trainer.
pub(crate) fn train_c45_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::C45,
        criterion,
        parallelism,
        options,
    )
}
253
/// Trains a CART classifier (binary splits) with the Gini criterion and
/// default options.
pub fn train_cart(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion(train_set, Criterion::Gini)
}

/// Trains a CART classifier with a caller-chosen impurity criterion.
pub fn train_cart_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Trains a CART classifier with explicit parallelism and default options.
pub(crate) fn train_cart_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized CART entry point; delegates to the shared trainer.
pub(crate) fn train_cart_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Cart,
        criterion,
        parallelism,
        options,
    )
}
294
/// Trains an oblivious (symmetric) tree with the Gini criterion and default
/// options.
pub fn train_oblivious(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion(train_set, Criterion::Gini)
}

/// Trains an oblivious tree with a caller-chosen impurity criterion.
pub fn train_oblivious_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Trains an oblivious tree with explicit parallelism and default options.
pub(crate) fn train_oblivious_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized oblivious entry point; delegates to the shared trainer.
pub(crate) fn train_oblivious_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Oblivious,
        criterion,
        parallelism,
        options,
    )
}
335
/// Trains a randomized-split tree with the Gini criterion and default options.
pub fn train_randomized(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion(train_set, Criterion::Gini)
}

/// Trains a randomized-split tree with a caller-chosen impurity criterion.
pub fn train_randomized_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Trains a randomized-split tree with explicit parallelism and default
/// options.
pub(crate) fn train_randomized_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized randomized entry point; delegates to the shared trainer.
pub(crate) fn train_randomized_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Randomized,
        criterion,
        parallelism,
        options,
    )
}
376
377fn train_classifier(
378 train_set: &dyn TableAccess,
379 algorithm: DecisionTreeAlgorithm,
380 criterion: Criterion,
381 parallelism: Parallelism,
382 options: DecisionTreeOptions,
383) -> Result<DecisionTreeClassifier, DecisionTreeError> {
384 if train_set.n_rows() == 0 {
385 return Err(DecisionTreeError::EmptyTarget);
386 }
387
388 let (class_labels, class_indices) = encode_class_labels(train_set)?;
389 let structure = match algorithm {
390 DecisionTreeAlgorithm::Oblivious => train_oblivious_structure(
391 train_set,
392 &class_indices,
393 &class_labels,
394 criterion,
395 parallelism,
396 options,
397 ),
398 DecisionTreeAlgorithm::Cart | DecisionTreeAlgorithm::Randomized => {
399 let mut nodes = Vec::new();
400 let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
401 let context = BuildContext {
402 table: train_set,
403 class_indices: &class_indices,
404 class_labels: &class_labels,
405 algorithm,
406 criterion,
407 parallelism,
408 options,
409 };
410 let root = build_binary_node_in_place(&context, &mut nodes, &mut all_rows, 0);
411 TreeStructure::Standard { nodes, root }
412 }
413 DecisionTreeAlgorithm::Id3 | DecisionTreeAlgorithm::C45 => {
414 let mut nodes = Vec::new();
415 let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
416 let context = BuildContext {
417 table: train_set,
418 class_indices: &class_indices,
419 class_labels: &class_labels,
420 algorithm,
421 criterion,
422 parallelism,
423 options,
424 };
425 let root = build_multiway_node_in_place(&context, &mut nodes, &mut all_rows, 0);
426 TreeStructure::Standard { nodes, root }
427 }
428 };
429
430 Ok(DecisionTreeClassifier {
431 algorithm,
432 criterion,
433 class_labels,
434 structure,
435 options,
436 num_features: train_set.n_features(),
437 feature_preprocessing: capture_feature_preprocessing(train_set),
438 training_canaries: train_set.canaries(),
439 })
440}
441
impl DecisionTreeClassifier {
    /// Returns the algorithm this tree was trained with.
    pub fn algorithm(&self) -> DecisionTreeAlgorithm {
        self.algorithm
    }

    /// Returns the impurity criterion used during training.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }

    /// Predicts the class label of every row of `table`.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        (0..table.n_rows())
            .map(|row_idx| self.predict_row(table, row_idx))
            .collect()
    }

    /// Predicts per-class probabilities (normalized leaf class counts) for
    /// every row of `table`.
    pub fn predict_proba_table(&self, table: &dyn TableAccess) -> Vec<Vec<f64>> {
        (0..table.n_rows())
            .map(|row_idx| self.predict_proba_row(table, row_idx))
            .collect()
    }

    /// Routes one row through the tree and returns its predicted class label.
    fn predict_row(&self, table: &dyn TableAccess, row_idx: usize) -> f64 {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;

                loop {
                    match &nodes[node_index] {
                        TreeNode::Leaf { class_index, .. } => {
                            return self.class_labels[*class_index];
                        }
                        TreeNode::MultiwaySplit {
                            feature_index,
                            fallback_class_index,
                            branches,
                            ..
                        } => {
                            // Follow the branch matching this bin; bins unseen
                            // during training fall back to the majority class.
                            let bin = table.binned_value(*feature_index, row_idx);
                            if let Some((_, child_index)) =
                                branches.iter().find(|(branch_bin, _)| *branch_bin == bin)
                            {
                                node_index = *child_index;
                            } else {
                                return self.class_labels[*fallback_class_index];
                            }
                        }
                        TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            left_child,
                            right_child,
                            ..
                        } => {
                            // bin <= threshold goes left, otherwise right.
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_indices,
                ..
            } => {
                // Each level contributes one bit (1 = right, i.e. bin above the
                // threshold); the first level is the most significant bit.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                self.class_labels[leaf_class_indices[leaf_index]]
            }
        }
    }

    /// Routes one row through the tree and returns normalized class
    /// probabilities from the reached leaf. A multiway split that encounters a
    /// bin unseen during training returns that node's own class distribution.
    fn predict_proba_row(&self, table: &dyn TableAccess, row_idx: usize) -> Vec<f64> {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;

                loop {
                    match &nodes[node_index] {
                        TreeNode::Leaf { class_counts, .. } => {
                            return normalized_class_probabilities(class_counts);
                        }
                        TreeNode::MultiwaySplit {
                            feature_index,
                            branches,
                            class_counts,
                            ..
                        } => {
                            let bin = table.binned_value(*feature_index, row_idx);
                            if let Some((_, child_index)) =
                                branches.iter().find(|(branch_bin, _)| *branch_bin == bin)
                            {
                                node_index = *child_index;
                            } else {
                                // Unseen bin: use this node's distribution.
                                return normalized_class_probabilities(class_counts);
                            }
                        }
                        TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            left_child,
                            right_child,
                            ..
                        } => {
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_counts,
                ..
            } => {
                // Same MSB-first leaf addressing as `predict_row`.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                normalized_class_probabilities(&leaf_class_counts[leaf_index])
            }
        }
    }

    /// Sorted distinct target values; predicted class indices refer to this.
    pub(crate) fn class_labels(&self) -> &[f64] {
        &self.class_labels
    }

    /// Internal tree layout (node arena or oblivious levels).
    pub(crate) fn structure(&self) -> &TreeStructure {
        &self.structure
    }

    /// Number of features in the training table.
    pub(crate) fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-feature preprocessing captured at training time.
    pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// Builds the serializable training-metadata record for IR export.
    /// Fields that only apply to ensembles/boosting are left as `None`.
    pub(crate) fn training_metadata(&self) -> TrainingMetadata {
        TrainingMetadata {
            algorithm: "dt".to_string(),
            task: "classification".to_string(),
            tree_type: tree_type_name(match self.algorithm {
                DecisionTreeAlgorithm::Id3 => crate::TreeType::Id3,
                DecisionTreeAlgorithm::C45 => crate::TreeType::C45,
                DecisionTreeAlgorithm::Cart => crate::TreeType::Cart,
                DecisionTreeAlgorithm::Randomized => crate::TreeType::Randomized,
                DecisionTreeAlgorithm::Oblivious => crate::TreeType::Oblivious,
            })
            .to_string(),
            criterion: criterion_name(self.criterion).to_string(),
            canaries: self.training_canaries,
            compute_oob: false,
            max_depth: Some(self.options.max_depth),
            min_samples_split: Some(self.options.min_samples_split),
            min_samples_leaf: Some(self.options.min_samples_leaf),
            n_trees: None,
            max_features: self.options.max_features,
            seed: None,
            oob_score: None,
            class_labels: Some(self.class_labels.clone()),
            learning_rate: None,
            bootstrap: None,
            top_gradient_fraction: None,
            other_gradient_fraction: None,
        }
    }

    /// Converts the fitted tree into its IR representation: a node tree for
    /// standard structures, a level-based layout for oblivious trees.
    pub(crate) fn to_ir_tree(&self) -> TreeDefinition {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                // Node depths are recomputed from the arena for the IR.
                let depths = standard_node_depths(nodes, *root);
                TreeDefinition::NodeTree {
                    tree_id: 0,
                    weight: 1.0,
                    root_node_id: *root,
                    nodes: nodes
                        .iter()
                        .enumerate()
                        .map(|(node_id, node)| match node {
                            TreeNode::Leaf {
                                class_index,
                                sample_count,
                                class_counts,
                            } => NodeTreeNode::Leaf {
                                node_id,
                                depth: depths[node_id],
                                leaf: self.class_leaf(*class_index),
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: None,
                                    gain: None,
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                            TreeNode::BinarySplit {
                                feature_index,
                                threshold_bin,
                                left_child,
                                right_child,
                                sample_count,
                                impurity,
                                gain,
                                class_counts,
                            } => NodeTreeNode::BinaryBranch {
                                node_id,
                                depth: depths[node_id],
                                split: binary_split_ir(
                                    *feature_index,
                                    *threshold_bin,
                                    &self.feature_preprocessing,
                                ),
                                children: BinaryChildren {
                                    left: *left_child,
                                    right: *right_child,
                                },
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                            TreeNode::MultiwaySplit {
                                feature_index,
                                fallback_class_index,
                                branches,
                                sample_count,
                                impurity,
                                gain,
                                class_counts,
                            } => NodeTreeNode::MultiwayBranch {
                                node_id,
                                depth: depths[node_id],
                                split: MultiwaySplit {
                                    split_type: "binned_value_multiway".to_string(),
                                    feature_index: *feature_index,
                                    feature_name: feature_name(*feature_index),
                                    comparison_dtype: "uint16".to_string(),
                                },
                                branches: branches
                                    .iter()
                                    .map(|(bin, child)| MultiwayBranch {
                                        bin: *bin,
                                        child: *child,
                                    })
                                    .collect(),
                                // Unseen bins route to the majority-class leaf.
                                unmatched_leaf: self.class_leaf(*fallback_class_index),
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                        })
                        .collect(),
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_indices,
                leaf_sample_counts,
                leaf_class_counts,
            } => TreeDefinition::ObliviousLevels {
                tree_id: 0,
                weight: 1.0,
                depth: splits.len(),
                levels: splits
                    .iter()
                    .enumerate()
                    .map(|(level, split)| ObliviousLevel {
                        level,
                        split: oblivious_split_ir(
                            split.feature_index,
                            split.threshold_bin,
                            &self.feature_preprocessing,
                        ),
                        stats: NodeStats {
                            sample_count: split.sample_count,
                            impurity: Some(split.impurity),
                            gain: Some(split.gain),
                            class_counts: None,
                            variance: None,
                        },
                    })
                    .collect(),
                // Matches the MSB-first fold in `predict_row`.
                leaf_indexing: LeafIndexing {
                    bit_order: "msb_first".to_string(),
                    index_formula: "sum(bit[level] << (depth - 1 - level))".to_string(),
                },
                leaves: leaf_class_indices
                    .iter()
                    .enumerate()
                    .map(|(leaf_index, class_index)| IndexedLeaf {
                        leaf_index,
                        leaf: self.class_leaf(*class_index),
                        stats: NodeStats {
                            sample_count: leaf_sample_counts[leaf_index],
                            impurity: None,
                            gain: None,
                            class_counts: Some(leaf_class_counts[leaf_index].clone()),
                            variance: None,
                        },
                    })
                    .collect(),
            },
        }
    }

    /// Wraps a class index together with its label value as a leaf payload.
    fn class_leaf(&self, class_index: usize) -> LeafPayload {
        LeafPayload::ClassIndex {
            class_index,
            class_value: self.class_labels[class_index],
        }
    }

    /// Reassembles a classifier from deserialized IR parts; fields are stored
    /// verbatim with no re-validation.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn from_ir_parts(
        algorithm: DecisionTreeAlgorithm,
        criterion: Criterion,
        class_labels: Vec<f64>,
        structure: TreeStructure,
        options: DecisionTreeOptions,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
        training_canaries: usize,
    ) -> Self {
        Self {
            algorithm,
            criterion,
            class_labels,
            structure,
            options,
            num_features,
            feature_preprocessing,
            training_canaries,
        }
    }
}
800
/// Grows a binary subtree over `rows` without precomputed histograms;
/// convenience wrapper around `build_binary_node_in_place_with_hist`.
fn build_binary_node_in_place(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &mut [usize],
    depth: usize,
) -> usize {
    build_binary_node_in_place_with_hist(context, nodes, rows, depth, None)
}
809
/// Recursively grows a binary subtree over `rows`, returning the index of the
/// created node within `nodes`.
///
/// `rows` is reordered in place so each child recursion operates on a
/// contiguous sub-slice. `histograms` optionally carries precomputed
/// per-feature class histograms for exactly these rows; when absent they are
/// rebuilt from scratch.
fn build_binary_node_in_place_with_hist(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &mut [usize],
    depth: usize,
    histograms: Option<Vec<ClassificationFeatureHistogram>>,
) -> usize {
    let majority_class_index =
        majority_class(rows, context.class_indices, context.class_labels.len());
    let current_class_counts =
        class_counts(rows, context.class_indices, context.class_labels.len());

    // Stopping conditions: no rows, depth budget exhausted, too few samples
    // to split, or the node is already pure.
    if rows.is_empty()
        || depth >= context.options.max_depth
        || rows.len() < context.options.min_samples_split
        || is_pure(rows, context.class_indices)
    {
        return push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        );
    }

    let scoring = SplitScoringContext {
        table: context.table,
        class_indices: context.class_indices,
        num_classes: context.class_labels.len(),
        criterion: context.criterion,
        min_samples_leaf: context.options.min_samples_leaf,
    };
    let histograms = histograms.unwrap_or_else(|| {
        build_classification_node_histograms(
            context.table,
            context.class_indices,
            rows,
            context.class_labels.len(),
        )
    });
    // Deterministic per-node feature subsampling (honors `max_features`).
    let feature_indices = candidate_feature_indices(
        context.table.binned_feature_count(),
        context.options.max_features,
        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
    );
    // Score every candidate feature and keep the highest-scoring split;
    // the parallel and sequential paths are otherwise identical.
    let best_split = if context.parallelism.enabled() {
        feature_indices
            .into_par_iter()
            .filter_map(|feature_index| {
                score_binary_split_choice_from_hist(
                    &scoring,
                    &histograms[feature_index],
                    feature_index,
                    rows,
                    &current_class_counts,
                    context.algorithm,
                )
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    } else {
        feature_indices
            .into_iter()
            .filter_map(|feature_index| {
                score_binary_split_choice_from_hist(
                    &scoring,
                    &histograms[feature_index],
                    feature_index,
                    rows,
                    &current_class_counts,
                    context.algorithm,
                )
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    };

    match best_split {
        // Best split landed on a canary feature: refuse to fit noise and
        // emit a leaf instead.
        Some(best_split)
            if context
                .table
                .is_canary_binned_feature(best_split.feature_index) =>
        {
            push_leaf(
                nodes,
                majority_class_index,
                rows.len(),
                current_class_counts,
            )
        }
        Some(best_split) if best_split.score > 0.0 => {
            let impurity =
                classification_impurity(&current_class_counts, rows.len(), context.criterion);
            // Partition `rows` in place; rows with bin <= threshold move to
            // the front and `left_count` marks the boundary.
            let left_count = partition_rows_for_binary_split(
                context.table,
                best_split.feature_index,
                best_split.threshold_bin,
                rows,
            );
            let (left_rows, right_rows) = rows.split_at_mut(left_count);
            // Histogram-subtraction trick: rebuild histograms only for the
            // smaller child and derive the larger child's by subtracting from
            // this node's histograms.
            let (left_histograms, right_histograms) = if left_rows.len() <= right_rows.len() {
                let left_histograms = build_classification_node_histograms(
                    context.table,
                    context.class_indices,
                    left_rows,
                    context.class_labels.len(),
                );
                let right_histograms =
                    subtract_classification_node_histograms(&histograms, &left_histograms);
                (left_histograms, right_histograms)
            } else {
                let right_histograms = build_classification_node_histograms(
                    context.table,
                    context.class_indices,
                    right_rows,
                    context.class_labels.len(),
                );
                let left_histograms =
                    subtract_classification_node_histograms(&histograms, &right_histograms);
                (left_histograms, right_histograms)
            };
            let left_child = build_binary_node_in_place_with_hist(
                context,
                nodes,
                left_rows,
                depth + 1,
                Some(left_histograms),
            );
            let right_child = build_binary_node_in_place_with_hist(
                context,
                nodes,
                right_rows,
                depth + 1,
                Some(right_histograms),
            );

            push_node(
                nodes,
                TreeNode::BinarySplit {
                    feature_index: best_split.feature_index,
                    threshold_bin: best_split.threshold_bin,
                    left_child,
                    right_child,
                    sample_count: rows.len(),
                    impurity,
                    gain: best_split.score,
                    class_counts: current_class_counts,
                },
            )
        }
        // No positive-gain split available: emit a leaf.
        _ => push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        ),
    }
}
966
/// Recursively grows a multiway (ID3/C4.5) subtree over `rows`, returning the
/// index of the created node within `nodes`. `rows` is reordered in place so
/// each branch recursion operates on a contiguous sub-range.
fn build_multiway_node_in_place(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &mut [usize],
    depth: usize,
) -> usize {
    let majority_class_index =
        majority_class(rows, context.class_indices, context.class_labels.len());
    let current_class_counts =
        class_counts(rows, context.class_indices, context.class_labels.len());

    // Stopping conditions: no rows, depth budget exhausted, too few samples
    // to split, or the node is already pure.
    if rows.is_empty()
        || depth >= context.options.max_depth
        || rows.len() < context.options.min_samples_split
        || is_pure(rows, context.class_indices)
    {
        return push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        );
    }

    // ID3 scores splits by information gain, C4.5 by gain ratio; no other
    // algorithm may reach this builder.
    let metric = match context.algorithm {
        DecisionTreeAlgorithm::Id3 => MultiwayMetric::InformationGain,
        DecisionTreeAlgorithm::C45 => MultiwayMetric::GainRatio,
        _ => unreachable!("multiway builder only supports id3/c45"),
    };
    let scoring = SplitScoringContext {
        table: context.table,
        class_indices: context.class_indices,
        num_classes: context.class_labels.len(),
        criterion: context.criterion,
        min_samples_leaf: context.options.min_samples_leaf,
    };
    // Deterministic per-node feature subsampling (honors `max_features`).
    let feature_indices = candidate_feature_indices(
        context.table.binned_feature_count(),
        context.options.max_features,
        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
    );
    let best_split = if context.parallelism.enabled() {
        feature_indices
            .into_par_iter()
            .filter_map(|feature_index| {
                score_multiway_split_choice(&scoring, feature_index, rows, metric)
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    } else {
        feature_indices
            .into_iter()
            .filter_map(|feature_index| {
                score_multiway_split_choice(&scoring, feature_index, rows, metric)
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    };

    match best_split {
        // Best split landed on a canary feature: refuse to fit noise and
        // emit a leaf instead.
        Some(best_split)
            if context
                .table
                .is_canary_binned_feature(best_split.feature_index) =>
        {
            push_leaf(
                nodes,
                majority_class_index,
                rows.len(),
                current_class_counts,
            )
        }
        Some(best_split) if best_split.score > 0.0 => {
            let impurity =
                classification_impurity(&current_class_counts, rows.len(), context.criterion);
            // Reorder `rows` in place; each branch gets a (bin, start, end)
            // range over the reordered slice.
            let branch_ranges = partition_rows_for_multiway_split(
                context.table,
                best_split.feature_index,
                &best_split.branch_bins,
                rows,
            );
            let mut branch_nodes = Vec::with_capacity(branch_ranges.len());
            for (bin, start, end) in branch_ranges {
                let child =
                    build_multiway_node_in_place(context, nodes, &mut rows[start..end], depth + 1);
                branch_nodes.push((bin, child));
            }

            push_node(
                nodes,
                TreeNode::MultiwaySplit {
                    feature_index: best_split.feature_index,
                    // Rows with bins unseen at this node fall back to the
                    // majority class at prediction time.
                    fallback_class_index: majority_class_index,
                    branches: branch_nodes,
                    sample_count: rows.len(),
                    impurity,
                    gain: best_split.score,
                    class_counts: current_class_counts,
                },
            )
        }
        // No positive-gain split available: emit a leaf.
        _ => push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        ),
    }
}
1074
/// Converts raw per-class counts into probabilities summing to 1.
/// An all-zero (or empty) count vector yields zeros, avoiding division by zero.
fn normalized_class_probabilities(class_counts: &[usize]) -> Vec<f64> {
    let total: usize = class_counts.iter().sum();
    match total {
        0 => vec![0.0; class_counts.len()],
        _ => {
            let denominator = total as f64;
            class_counts
                .iter()
                .map(|&count| count as f64 / denominator)
                .collect()
        }
    }
}
1086
/// Computes the depth of every node reachable from `root` in the arena.
/// Slots not reachable from `root` keep depth 0.
fn standard_node_depths(nodes: &[TreeNode], root: usize) -> Vec<usize> {
    let mut depths = vec![0; nodes.len()];
    populate_depths(nodes, root, 0, &mut depths);
    depths
}
1092
1093fn populate_depths(nodes: &[TreeNode], node_id: usize, depth: usize, depths: &mut [usize]) {
1094 depths[node_id] = depth;
1095 match &nodes[node_id] {
1096 TreeNode::Leaf { .. } => {}
1097 TreeNode::BinarySplit {
1098 left_child,
1099 right_child,
1100 ..
1101 } => {
1102 populate_depths(nodes, *left_child, depth + 1, depths);
1103 populate_depths(nodes, *right_child, depth + 1, depths);
1104 }
1105 TreeNode::MultiwaySplit { branches, .. } => {
1106 for (_, child) in branches {
1107 populate_depths(nodes, *child, depth + 1, depths);
1108 }
1109 }
1110 }
1111}
1112
/// Builds the IR description of a binary split: a boolean test for features
/// recorded as binary, otherwise a `<=` threshold over the binned value.
fn binary_split_ir(
    feature_index: usize,
    threshold_bin: u16,
    preprocessing: &[FeaturePreprocessing],
) -> BinarySplit {
    match preprocessing.get(feature_index) {
        Some(FeaturePreprocessing::Binary) => BinarySplit::BooleanTest {
            feature_index,
            feature_name: feature_name(feature_index),
            false_child_semantics: "left".to_string(),
            true_child_semantics: "right".to_string(),
        },
        // Numeric features — and features with no recorded preprocessing —
        // are exported as a bin-threshold comparison.
        Some(FeaturePreprocessing::Numeric { .. }) | None => BinarySplit::NumericBinThreshold {
            feature_index,
            feature_name: feature_name(feature_index),
            operator: "<=".to_string(),
            threshold_bin,
            threshold_upper_bound: threshold_upper_bound(
                preprocessing,
                feature_index,
                threshold_bin,
            ),
            comparison_dtype: "uint16".to_string(),
        },
    }
}
1139
/// Builds the IR description of one oblivious level's split.
///
/// For the numeric variant the exported test is `bin <= threshold`, the
/// negation of the traversal test in `predict_row` (where `bin > threshold`
/// sets the level bit to 1); hence `bit_when_true` is 0 and `bit_when_false`
/// is 1 — the reverse of the boolean variant, not a typo.
fn oblivious_split_ir(
    feature_index: usize,
    threshold_bin: u16,
    preprocessing: &[FeaturePreprocessing],
) -> IrObliviousSplit {
    match preprocessing.get(feature_index) {
        Some(FeaturePreprocessing::Binary) => IrObliviousSplit::BooleanTest {
            feature_index,
            feature_name: feature_name(feature_index),
            bit_when_false: 0,
            bit_when_true: 1,
        },
        Some(FeaturePreprocessing::Numeric { .. }) | None => {
            IrObliviousSplit::NumericBinThreshold {
                feature_index,
                feature_name: feature_name(feature_index),
                operator: "<=".to_string(),
                threshold_bin,
                threshold_upper_bound: threshold_upper_bound(
                    preprocessing,
                    feature_index,
                    threshold_bin,
                ),
                comparison_dtype: "uint16".to_string(),
                // "true" means bin <= threshold, i.e. the sample goes left.
                bit_when_true: 0,
                bit_when_false: 1,
            }
        }
    }
}
1170
1171fn encode_class_labels(
1172 train_set: &dyn TableAccess,
1173) -> Result<(Vec<f64>, Vec<usize>), DecisionTreeError> {
1174 let targets: Vec<f64> = (0..train_set.n_rows())
1175 .map(|row_idx| {
1176 let value = train_set.target_value(row_idx);
1177 if value.is_finite() {
1178 Ok(value)
1179 } else {
1180 Err(DecisionTreeError::InvalidTargetValue {
1181 row: row_idx,
1182 value,
1183 })
1184 }
1185 })
1186 .collect::<Result<_, _>>()?;
1187
1188 let class_labels = targets
1189 .iter()
1190 .copied()
1191 .fold(Vec::<f64>::new(), |mut labels, value| {
1192 if labels
1193 .binary_search_by(|candidate| candidate.total_cmp(&value))
1194 .is_err()
1195 {
1196 labels.push(value);
1197 labels.sort_by(|left, right| left.total_cmp(right));
1198 }
1199 labels
1200 });
1201
1202 let class_indices = targets
1203 .iter()
1204 .map(|value| {
1205 class_labels
1206 .binary_search_by(|candidate| candidate.total_cmp(value))
1207 .expect("target value must exist in class label vocabulary")
1208 })
1209 .collect();
1210
1211 Ok((class_labels, class_indices))
1212}
1213
1214#[allow(dead_code)]
1215fn build_node(
1216 context: &BuildContext<'_>,
1217 nodes: &mut Vec<TreeNode>,
1218 rows: &[usize],
1219 depth: usize,
1220) -> usize {
1221 let majority_class_index =
1222 majority_class(rows, context.class_indices, context.class_labels.len());
1223 let current_class_counts =
1224 class_counts(rows, context.class_indices, context.class_labels.len());
1225
1226 if rows.is_empty()
1227 || depth >= context.options.max_depth
1228 || rows.len() < context.options.min_samples_split
1229 || is_pure(rows, context.class_indices)
1230 {
1231 return push_leaf(
1232 nodes,
1233 majority_class_index,
1234 rows.len(),
1235 current_class_counts,
1236 );
1237 }
1238
1239 let scoring = SplitScoringContext {
1240 table: context.table,
1241 class_indices: context.class_indices,
1242 num_classes: context.class_labels.len(),
1243 criterion: context.criterion,
1244 min_samples_leaf: context.options.min_samples_leaf,
1245 };
1246 let feature_indices = candidate_feature_indices(
1247 context.table.binned_feature_count(),
1248 context.options.max_features,
1249 node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
1250 );
1251 let best_split = if context.parallelism.enabled() {
1252 feature_indices
1253 .into_par_iter()
1254 .filter_map(|feature_index| {
1255 score_split(&scoring, feature_index, rows, context.algorithm)
1256 })
1257 .max_by(|left, right| split_score(left).total_cmp(&split_score(right)))
1258 } else {
1259 feature_indices
1260 .into_iter()
1261 .filter_map(|feature_index| {
1262 score_split(&scoring, feature_index, rows, context.algorithm)
1263 })
1264 .max_by(|left, right| split_score(left).total_cmp(&split_score(right)))
1265 };
1266
1267 match best_split {
1268 Some(best_split)
1269 if context
1270 .table
1271 .is_canary_binned_feature(split_feature_index(&best_split)) =>
1272 {
1273 push_leaf(
1274 nodes,
1275 majority_class_index,
1276 rows.len(),
1277 current_class_counts,
1278 )
1279 }
1280 Some(SplitCandidate::Multiway {
1281 feature_index,
1282 score,
1283 branches,
1284 }) if score > 0.0 => {
1285 let impurity =
1286 classification_impurity(¤t_class_counts, rows.len(), context.criterion);
1287 let branch_nodes = branches
1288 .into_iter()
1289 .map(|(bin, branch_rows)| {
1290 (bin, build_node(context, nodes, &branch_rows, depth + 1))
1291 })
1292 .collect();
1293
1294 push_node(
1295 nodes,
1296 TreeNode::MultiwaySplit {
1297 feature_index,
1298 fallback_class_index: majority_class_index,
1299 branches: branch_nodes,
1300 sample_count: rows.len(),
1301 impurity,
1302 gain: score,
1303 class_counts: current_class_counts,
1304 },
1305 )
1306 }
1307 Some(SplitCandidate::Binary {
1308 feature_index,
1309 score,
1310 threshold_bin,
1311 left_rows,
1312 right_rows,
1313 }) if score > 0.0 => {
1314 let impurity =
1315 classification_impurity(¤t_class_counts, rows.len(), context.criterion);
1316 let left_child = build_node(context, nodes, &left_rows, depth + 1);
1317 let right_child = build_node(context, nodes, &right_rows, depth + 1);
1318
1319 push_node(
1320 nodes,
1321 TreeNode::BinarySplit {
1322 feature_index,
1323 threshold_bin,
1324 left_child,
1325 right_child,
1326 sample_count: rows.len(),
1327 impurity,
1328 gain: score,
1329 class_counts: current_class_counts,
1330 },
1331 )
1332 }
1333 _ => push_leaf(
1334 nodes,
1335 majority_class_index,
1336 rows.len(),
1337 current_class_counts,
1338 ),
1339 }
1340}
1341
/// Shared, read-only state threaded through recursive `build_node` calls.
struct BuildContext<'a> {
    /// Binned training data.
    table: &'a dyn TableAccess,
    /// Per-row class index into `class_labels` (parallel to table rows).
    class_indices: &'a [usize],
    /// Sorted vocabulary of distinct target values.
    class_labels: &'a [f64],
    algorithm: DecisionTreeAlgorithm,
    criterion: Criterion,
    /// Controls whether candidate features are scored via rayon.
    parallelism: Parallelism,
    options: DecisionTreeOptions,
}
1351
/// Inputs needed to score a single candidate split at one node.
struct SplitScoringContext<'a> {
    table: &'a dyn TableAccess,
    /// Per-row class index (parallel to the table's rows).
    class_indices: &'a [usize],
    num_classes: usize,
    criterion: Criterion,
    /// Minimum rows each child must keep for a split to be valid.
    min_samples_leaf: usize,
}
1359
1360fn build_classification_node_histograms(
1361 table: &dyn TableAccess,
1362 class_indices: &[usize],
1363 rows: &[usize],
1364 num_classes: usize,
1365) -> Vec<ClassificationFeatureHistogram> {
1366 (0..table.binned_feature_count())
1367 .map(|feature_index| {
1368 if table.is_binary_binned_feature(feature_index) {
1369 let mut false_counts = vec![0usize; num_classes];
1370 let mut true_counts = vec![0usize; num_classes];
1371 let mut false_size = 0usize;
1372 let mut true_size = 0usize;
1373 for row_idx in rows {
1374 let class_index = class_indices[*row_idx];
1375 if !table
1376 .binned_boolean_value(feature_index, *row_idx)
1377 .expect("binary feature must expose boolean values")
1378 {
1379 false_counts[class_index] += 1;
1380 false_size += 1;
1381 } else {
1382 true_counts[class_index] += 1;
1383 true_size += 1;
1384 }
1385 }
1386 ClassificationFeatureHistogram::Binary {
1387 false_counts,
1388 true_counts,
1389 false_size,
1390 true_size,
1391 }
1392 } else {
1393 let bin_cap = table.numeric_bin_cap();
1394 let mut bin_class_counts = vec![vec![0usize; num_classes]; bin_cap];
1395 let mut observed_bins = vec![false; bin_cap];
1396 for row_idx in rows {
1397 let bin = table.binned_value(feature_index, *row_idx) as usize;
1398 bin_class_counts[bin][class_indices[*row_idx]] += 1;
1399 observed_bins[bin] = true;
1400 }
1401 ClassificationFeatureHistogram::Numeric {
1402 bin_class_counts,
1403 observed_bins: observed_bins
1404 .into_iter()
1405 .enumerate()
1406 .filter_map(|(bin, seen)| seen.then_some(bin))
1407 .collect(),
1408 }
1409 }
1410 })
1411 .collect()
1412}
1413
1414fn subtract_classification_node_histograms(
1415 parent: &[ClassificationFeatureHistogram],
1416 child: &[ClassificationFeatureHistogram],
1417) -> Vec<ClassificationFeatureHistogram> {
1418 parent
1419 .iter()
1420 .zip(child.iter())
1421 .map(
1422 |(parent_hist, child_hist)| match (parent_hist, child_hist) {
1423 (
1424 ClassificationFeatureHistogram::Binary {
1425 false_counts: parent_false_counts,
1426 true_counts: parent_true_counts,
1427 false_size: parent_false_size,
1428 true_size: parent_true_size,
1429 },
1430 ClassificationFeatureHistogram::Binary {
1431 false_counts: child_false_counts,
1432 true_counts: child_true_counts,
1433 false_size: child_false_size,
1434 true_size: child_true_size,
1435 },
1436 ) => ClassificationFeatureHistogram::Binary {
1437 false_counts: parent_false_counts
1438 .iter()
1439 .zip(child_false_counts.iter())
1440 .map(|(parent, child)| parent - child)
1441 .collect(),
1442 true_counts: parent_true_counts
1443 .iter()
1444 .zip(child_true_counts.iter())
1445 .map(|(parent, child)| parent - child)
1446 .collect(),
1447 false_size: parent_false_size - child_false_size,
1448 true_size: parent_true_size - child_true_size,
1449 },
1450 (
1451 ClassificationFeatureHistogram::Numeric {
1452 bin_class_counts: parent_bin_class_counts,
1453 ..
1454 },
1455 ClassificationFeatureHistogram::Numeric {
1456 bin_class_counts: child_bin_class_counts,
1457 ..
1458 },
1459 ) => {
1460 let bin_class_counts = parent_bin_class_counts
1461 .iter()
1462 .zip(child_bin_class_counts.iter())
1463 .map(|(parent_counts, child_counts)| {
1464 parent_counts
1465 .iter()
1466 .zip(child_counts.iter())
1467 .map(|(parent, child)| parent - child)
1468 .collect::<Vec<_>>()
1469 })
1470 .collect::<Vec<_>>();
1471 let observed_bins = bin_class_counts
1472 .iter()
1473 .enumerate()
1474 .filter_map(|(bin, counts)| {
1475 counts.iter().any(|count| *count > 0).then_some(bin)
1476 })
1477 .collect::<Vec<_>>();
1478 ClassificationFeatureHistogram::Numeric {
1479 bin_class_counts,
1480 observed_bins,
1481 }
1482 }
1483 _ => unreachable!("histogram shapes must match"),
1484 },
1485 )
1486 .collect()
1487}
1488
/// One leaf of the (level-wise) oblivious tree under construction.
///
/// Leaves own contiguous, disjoint `[start, end)` ranges of the shared
/// `row_indices` buffer; splitting a level partitions each range in place.
#[derive(Debug, Clone)]
struct ObliviousLeafState {
    start: usize,
    end: usize,
    /// Majority class for rows in this leaf (inherited from the parent
    /// when the leaf is empty).
    class_index: usize,
    /// Per-class row counts within `[start, end)`.
    class_counts: Vec<usize>,
}
1496
impl ObliviousLeafState {
    /// Number of rows routed to this leaf.
    fn len(&self) -> usize {
        self.end - self.start
    }
}
1502
/// Train an oblivious (symmetric) decision tree level by level: every
/// leaf at a given depth is split on the same `(feature, threshold)`
/// pair, chosen to maximize the summed impurity reduction across all
/// current leaves.
///
/// Rows live in one shared index buffer; each leaf owns a contiguous
/// `[start, end)` range partitioned in place at every level, so after
/// `d` splits there are `2^d` leaves in split-bit order.
///
/// Growth stops early when every leaf is below `min_samples_split`, when
/// no candidate achieves strictly positive total gain, or when the best
/// candidate lands on a canary (noise) feature.
fn train_oblivious_structure(
    table: &dyn TableAccess,
    class_indices: &[usize],
    class_labels: &[f64],
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> TreeStructure {
    let mut row_indices: Vec<usize> = (0..table.n_rows()).collect();
    let total_class_counts = class_counts(&row_indices, class_indices, class_labels.len());
    let total_impurity = classification_impurity(&total_class_counts, row_indices.len(), criterion);
    // Start with a single root leaf covering every row.
    let mut leaves = vec![ObliviousLeafState {
        start: 0,
        end: row_indices.len(),
        class_index: majority_class(&row_indices, class_indices, class_labels.len()),
        class_counts: total_class_counts.clone(),
    }];
    let mut splits = Vec::new();

    for depth in 0..options.max_depth {
        if leaves
            .iter()
            .all(|leaf| leaf.len() < options.min_samples_split)
        {
            break;
        }
        // NOTE(review): the seed hashes an empty row slice, so the feature
        // subsample depends only on the seed and depth — presumably so the
        // whole level shares a single candidate set; confirm intended.
        let feature_indices = candidate_feature_indices(
            table.binned_feature_count(),
            options.max_features,
            node_seed(options.random_seed, depth, &[], 0x0B11_A10Cu64),
        );
        // Score every candidate feature across all leaves and keep the
        // best, optionally fanning out with rayon.
        let best_split = if parallelism.enabled() {
            feature_indices
                .into_par_iter()
                .filter_map(|feature_index| {
                    score_oblivious_split(
                        table,
                        &row_indices,
                        class_indices,
                        feature_index,
                        &leaves,
                        class_labels.len(),
                        criterion,
                        options.min_samples_leaf,
                    )
                })
                .max_by(|left, right| left.score.total_cmp(&right.score))
        } else {
            feature_indices
                .into_iter()
                .filter_map(|feature_index| {
                    score_oblivious_split(
                        table,
                        &row_indices,
                        class_indices,
                        feature_index,
                        &leaves,
                        class_labels.len(),
                        criterion,
                        options.min_samples_leaf,
                    )
                })
                .max_by(|left, right| left.score.total_cmp(&right.score))
        };

        // Stop on non-positive gain or a canary-feature winner.
        let Some(best_split) = best_split.filter(|candidate| candidate.score > 0.0) else {
            break;
        };
        if table.is_canary_binned_feature(best_split.feature_index) {
            break;
        }

        leaves = split_oblivious_leaves_in_place(
            table,
            &mut row_indices,
            class_indices,
            class_labels.len(),
            leaves,
            best_split.feature_index,
            best_split.threshold_bin,
        );
        // The recorded stats describe the full training set (root sample
        // count and impurity), not the per-level state.
        splits.push(ObliviousSplit {
            feature_index: best_split.feature_index,
            threshold_bin: best_split.threshold_bin,
            sample_count: table.n_rows(),
            impurity: total_impurity,
            gain: best_split.score,
        });
    }

    TreeStructure::Oblivious {
        splits,
        leaf_class_indices: leaves.iter().map(|leaf| leaf.class_index).collect(),
        leaf_sample_counts: leaves.iter().map(ObliviousLeafState::len).collect(),
        leaf_class_counts: leaves
            .iter()
            .map(|leaf| leaf.class_counts.clone())
            .collect(),
    }
}
1603
/// A scored `(feature, threshold)` candidate for one oblivious level;
/// `score` is the summed impurity reduction over all current leaves.
#[derive(Debug, Clone, Copy)]
struct ObliviousSplitCandidate {
    feature_index: usize,
    /// Meaningful for numeric features only; reported as 0 for binary.
    threshold_bin: u16,
    score: f64,
}
1610
1611#[allow(clippy::too_many_arguments)]
1612fn score_oblivious_split(
1613 table: &dyn TableAccess,
1614 row_indices: &[usize],
1615 class_indices: &[usize],
1616 feature_index: usize,
1617 leaves: &[ObliviousLeafState],
1618 num_classes: usize,
1619 criterion: Criterion,
1620 min_samples_leaf: usize,
1621) -> Option<ObliviousSplitCandidate> {
1622 if table.is_binary_binned_feature(feature_index) {
1623 return score_binary_oblivious_split(
1624 table,
1625 row_indices,
1626 class_indices,
1627 feature_index,
1628 leaves,
1629 num_classes,
1630 criterion,
1631 min_samples_leaf,
1632 );
1633 }
1634 if let Some(candidate) = score_numeric_oblivious_split_fast(
1635 table,
1636 row_indices,
1637 class_indices,
1638 feature_index,
1639 leaves,
1640 num_classes,
1641 criterion,
1642 min_samples_leaf,
1643 ) {
1644 return Some(candidate);
1645 }
1646 let candidate_thresholds = leaves
1647 .iter()
1648 .flat_map(|leaf| {
1649 row_indices[leaf.start..leaf.end]
1650 .iter()
1651 .map(|row_idx| table.binned_value(feature_index, *row_idx))
1652 })
1653 .collect::<BTreeSet<_>>();
1654
1655 candidate_thresholds
1656 .into_iter()
1657 .filter_map(|threshold_bin| {
1658 let score = leaves.iter().fold(0.0, |score, leaf| {
1659 let leaf_rows = &row_indices[leaf.start..leaf.end];
1660 let (left_rows, right_rows): (Vec<usize>, Vec<usize>) =
1661 leaf_rows.iter().copied().partition(|row_idx| {
1662 table.binned_value(feature_index, *row_idx) <= threshold_bin
1663 });
1664
1665 if left_rows.len() < min_samples_leaf || right_rows.len() < min_samples_leaf {
1666 return score;
1667 }
1668
1669 let parent_counts = leaf.class_counts.clone();
1670 let left_counts = class_counts(&left_rows, class_indices, num_classes);
1671 let right_counts = class_counts(&right_rows, class_indices, num_classes);
1672
1673 let weighted_parent_impurity = leaf.len() as f64
1674 * classification_impurity(&parent_counts, leaf.len(), criterion);
1675 let weighted_children_impurity = left_rows.len() as f64
1676 * classification_impurity(&left_counts, left_rows.len(), criterion)
1677 + right_rows.len() as f64
1678 * classification_impurity(&right_counts, right_rows.len(), criterion);
1679
1680 score + (weighted_parent_impurity - weighted_children_impurity)
1681 });
1682
1683 (score > 0.0).then_some(ObliviousSplitCandidate {
1684 feature_index,
1685 threshold_bin,
1686 score,
1687 })
1688 })
1689 .max_by(|left, right| left.score.total_cmp(&right.score))
1690}
1691
1692fn split_oblivious_leaves_in_place(
1693 table: &dyn TableAccess,
1694 row_indices: &mut [usize],
1695 class_indices: &[usize],
1696 num_classes: usize,
1697 leaves: Vec<ObliviousLeafState>,
1698 feature_index: usize,
1699 threshold_bin: u16,
1700) -> Vec<ObliviousLeafState> {
1701 let mut next_leaves = Vec::with_capacity(leaves.len() * 2);
1702 for leaf in leaves {
1703 let left_count = partition_rows_for_binary_split(
1704 table,
1705 feature_index,
1706 threshold_bin,
1707 &mut row_indices[leaf.start..leaf.end],
1708 );
1709 let mid = leaf.start + left_count;
1710 let mut left_class_counts = vec![0usize; num_classes];
1711 let mut right_class_counts = vec![0usize; num_classes];
1712 for row_idx in &row_indices[leaf.start..mid] {
1713 left_class_counts[class_indices[*row_idx]] += 1;
1714 }
1715 for row_idx in &row_indices[mid..leaf.end] {
1716 right_class_counts[class_indices[*row_idx]] += 1;
1717 }
1718 let left_class_index = if left_count == 0 {
1719 leaf.class_index
1720 } else {
1721 majority_class_from_counts(&left_class_counts)
1722 };
1723 let right_class_index = if mid == leaf.end {
1724 leaf.class_index
1725 } else {
1726 majority_class_from_counts(&right_class_counts)
1727 };
1728 next_leaves.push(ObliviousLeafState {
1729 start: leaf.start,
1730 end: mid,
1731 class_index: left_class_index,
1732 class_counts: left_class_counts,
1733 });
1734 next_leaves.push(ObliviousLeafState {
1735 start: mid,
1736 end: leaf.end,
1737 class_index: right_class_index,
1738 class_counts: right_class_counts,
1739 });
1740 }
1741 next_leaves
1742}
1743
1744#[allow(dead_code)]
1745fn score_split(
1746 context: &SplitScoringContext<'_>,
1747 feature_index: usize,
1748 rows: &[usize],
1749 algorithm: DecisionTreeAlgorithm,
1750) -> Option<SplitCandidate> {
1751 match algorithm {
1752 DecisionTreeAlgorithm::Id3 => score_multiway_split(
1753 context,
1754 feature_index,
1755 rows,
1756 MultiwayMetric::InformationGain,
1757 ),
1758 DecisionTreeAlgorithm::C45 => {
1759 score_multiway_split(context, feature_index, rows, MultiwayMetric::GainRatio)
1760 }
1761 DecisionTreeAlgorithm::Cart => score_cart_split(context, feature_index, rows),
1762 DecisionTreeAlgorithm::Randomized => score_randomized_split(context, feature_index, rows),
1763 DecisionTreeAlgorithm::Oblivious => None,
1764 }
1765}
1766
/// Score an ID3/C4.5-style multiway split on `feature_index`: one branch
/// per observed bin value (or per boolean value for binary features),
/// returning the grouped rows alongside the score.
///
/// Returns `None` when fewer than two non-empty groups exist, when any
/// group is smaller than `min_samples_leaf`, or (for `GainRatio`) when
/// the split information is zero.
#[allow(dead_code)]
fn score_multiway_split(
    context: &SplitScoringContext<'_>,
    feature_index: usize,
    rows: &[usize],
    metric: MultiwayMetric,
) -> Option<SplitCandidate> {
    // Group rows by feature value; BTreeMap keeps branches in ascending
    // bin order so tree construction is deterministic.
    let grouped_rows = if context.table.is_binary_binned_feature(feature_index) {
        let (false_rows, true_rows): (Vec<usize>, Vec<usize>) =
            rows.iter().copied().partition(|row_idx| {
                !context
                    .table
                    .binned_boolean_value(feature_index, *row_idx)
                    .expect("binary feature must expose boolean values")
            });
        [(0u16, false_rows), (1u16, true_rows)]
            .into_iter()
            .filter(|(_bin, group_rows)| !group_rows.is_empty())
            .collect::<BTreeMap<_, _>>()
    } else {
        rows.iter()
            .fold(BTreeMap::<u16, Vec<usize>>::new(), |mut groups, row_idx| {
                groups
                    .entry(context.table.binned_value(feature_index, *row_idx))
                    .or_default()
                    .push(*row_idx);
                groups
            })
    };

    if grouped_rows.len() <= 1
        || grouped_rows
            .values()
            .any(|group| group.len() < context.min_samples_leaf)
    {
        return None;
    }

    let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
    let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
    let weighted_child_impurity = grouped_rows
        .values()
        .map(|group_rows| {
            let counts = class_counts(group_rows, context.class_indices, context.num_classes);
            (group_rows.len() as f64 / rows.len() as f64)
                * classification_impurity(&counts, group_rows.len(), context.criterion)
        })
        .sum::<f64>();
    let information_gain = parent_impurity - weighted_child_impurity;

    let score = match metric {
        MultiwayMetric::InformationGain => information_gain,
        MultiwayMetric::GainRatio => {
            // C4.5 gain ratio: normalize by the split's own entropy to
            // penalize high-arity features.
            let split_info = grouped_rows
                .values()
                .map(|group_rows| {
                    let probability = group_rows.len() as f64 / rows.len() as f64;
                    -probability * probability.log2()
                })
                .sum::<f64>();

            if split_info == 0.0 {
                return None;
            }

            information_gain / split_info
        }
    };

    Some(SplitCandidate::Multiway {
        feature_index,
        score,
        branches: grouped_rows.into_iter().collect(),
    })
}
1842
/// Histogram-based variant of `score_multiway_split`: tallies per-group
/// class counts instead of materializing per-group row lists, and
/// returns only the chosen branch bins (`MultiwaySplitChoice`) rather
/// than the grouped rows.
///
/// Same rejection rules as the row-list variant: fewer than two
/// non-empty groups, any group below `min_samples_leaf`, or zero split
/// information under `GainRatio`.
fn score_multiway_split_choice(
    context: &SplitScoringContext<'_>,
    feature_index: usize,
    rows: &[usize],
    metric: MultiwayMetric,
) -> Option<MultiwaySplitChoice> {
    // Per-group `(size, class_counts)` tallies, in ascending bin order.
    let grouped_counts = if context.table.is_binary_binned_feature(feature_index) {
        let mut false_counts = vec![0usize; context.num_classes];
        let mut true_counts = vec![0usize; context.num_classes];
        let mut false_size = 0usize;
        let mut true_size = 0usize;
        for row_idx in rows {
            let class_index = context.class_indices[*row_idx];
            if !context
                .table
                .binned_boolean_value(feature_index, *row_idx)
                .expect("binary feature must expose boolean values")
            {
                false_counts[class_index] += 1;
                false_size += 1;
            } else {
                true_counts[class_index] += 1;
                true_size += 1;
            }
        }
        [
            (0u16, (false_size, false_counts)),
            (1u16, (true_size, true_counts)),
        ]
        .into_iter()
        .filter(|(_, (size, _))| *size > 0)
        .collect::<Vec<_>>()
    } else {
        // BTreeMap gives deterministic ascending-bin iteration order.
        let mut grouped = BTreeMap::<u16, (usize, Vec<usize>)>::new();
        for row_idx in rows {
            let bin = context.table.binned_value(feature_index, *row_idx);
            let entry = grouped
                .entry(bin)
                .or_insert_with(|| (0usize, vec![0usize; context.num_classes]));
            entry.0 += 1;
            entry.1[context.class_indices[*row_idx]] += 1;
        }
        grouped.into_iter().collect::<Vec<_>>()
    };

    if grouped_counts.len() <= 1
        || grouped_counts
            .iter()
            .any(|(_, (group_size, _))| *group_size < context.min_samples_leaf)
    {
        return None;
    }

    let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
    let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
    let weighted_child_impurity = grouped_counts
        .iter()
        .map(|(_, (group_size, counts))| {
            (*group_size as f64 / rows.len() as f64)
                * classification_impurity(counts, *group_size, context.criterion)
        })
        .sum::<f64>();
    let information_gain = parent_impurity - weighted_child_impurity;

    let score = match metric {
        MultiwayMetric::InformationGain => information_gain,
        MultiwayMetric::GainRatio => {
            // C4.5 gain ratio: normalize by the split's own entropy.
            let split_info = grouped_counts
                .iter()
                .map(|(_, (group_size, _))| {
                    let probability = *group_size as f64 / rows.len() as f64;
                    -probability * probability.log2()
                })
                .sum::<f64>();
            if split_info == 0.0 {
                return None;
            }
            information_gain / split_info
        }
    };

    Some(MultiwaySplitChoice {
        feature_index,
        score,
        branch_bins: grouped_counts.into_iter().map(|(bin, _)| bin).collect(),
    })
}
1930
/// Score the best CART binary split `value <= threshold` for one feature.
///
/// Binary features delegate to `score_binary_cart_split`. Numeric
/// features first try the histogram fast path; when it declines, fall
/// back to enumerating every observed bin as a threshold and
/// partitioning the rows per candidate. Candidates violating
/// `min_samples_leaf` on either side are skipped; the highest-gain
/// candidate wins.
#[allow(dead_code)]
fn score_cart_split(
    context: &SplitScoringContext<'_>,
    feature_index: usize,
    rows: &[usize],
) -> Option<SplitCandidate> {
    if context.table.is_binary_binned_feature(feature_index) {
        return score_binary_cart_split(context, feature_index, rows);
    }
    if let Some(candidate) = score_numeric_cart_split_fast(context, feature_index, rows) {
        return Some(candidate);
    }
    let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
    let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);

    // Slow path: BTreeSet yields each distinct bin once, ascending.
    rows.iter()
        .map(|row_idx| context.table.binned_value(feature_index, *row_idx))
        .collect::<BTreeSet<_>>()
        .into_iter()
        .filter_map(|threshold_bin| {
            let (left_rows, right_rows): (Vec<usize>, Vec<usize>) =
                rows.iter().copied().partition(|row_idx| {
                    context.table.binned_value(feature_index, *row_idx) <= threshold_bin
                });

            if left_rows.len() < context.min_samples_leaf
                || right_rows.len() < context.min_samples_leaf
            {
                return None;
            }

            let left_counts = class_counts(&left_rows, context.class_indices, context.num_classes);
            let right_counts =
                class_counts(&right_rows, context.class_indices, context.num_classes);
            let weighted_impurity = (left_rows.len() as f64 / rows.len() as f64)
                * classification_impurity(&left_counts, left_rows.len(), context.criterion)
                + (right_rows.len() as f64 / rows.len() as f64)
                    * classification_impurity(&right_counts, right_rows.len(), context.criterion);

            Some(SplitCandidate::Binary {
                feature_index,
                score: parent_impurity - weighted_impurity,
                threshold_bin,
                left_rows,
                right_rows,
            })
        })
        .max_by(|left, right| split_score(left).total_cmp(&split_score(right)))
}
1980
/// Extremely-randomized (Extra-Trees-style) split: instead of searching
/// all thresholds, pick a single threshold from the observed bins and
/// score only that split.
///
/// Binary features still use the exhaustive binary scorer (there is only
/// one possible split). The threshold choice is seeded with a fixed
/// salt, so it is deterministic for a given feature and row set.
/// Returns `None` when no threshold can be chosen or either side would
/// violate `min_samples_leaf`.
#[allow(dead_code)]
fn score_randomized_split(
    context: &SplitScoringContext<'_>,
    feature_index: usize,
    rows: &[usize],
) -> Option<SplitCandidate> {
    if context.table.is_binary_binned_feature(feature_index) {
        return score_binary_cart_split(context, feature_index, rows);
    }
    if let Some(candidate) = score_numeric_randomized_split_fast(context, feature_index, rows) {
        return Some(candidate);
    }

    // Slow path: distinct observed bins, ascending.
    let candidate_thresholds = rows
        .iter()
        .map(|row_idx| context.table.binned_value(feature_index, *row_idx))
        .collect::<BTreeSet<_>>()
        .into_iter()
        .collect::<Vec<_>>();
    let threshold_bin =
        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xC1A551F1u64)?;

    let (left_rows, right_rows): (Vec<usize>, Vec<usize>) = rows
        .iter()
        .copied()
        .partition(|row_idx| context.table.binned_value(feature_index, *row_idx) <= threshold_bin);

    if left_rows.len() < context.min_samples_leaf || right_rows.len() < context.min_samples_leaf {
        return None;
    }

    let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
    let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
    let left_counts = class_counts(&left_rows, context.class_indices, context.num_classes);
    let right_counts = class_counts(&right_rows, context.class_indices, context.num_classes);
    let weighted_impurity = (left_rows.len() as f64 / rows.len() as f64)
        * classification_impurity(&left_counts, left_rows.len(), context.criterion)
        + (right_rows.len() as f64 / rows.len() as f64)
            * classification_impurity(&right_counts, right_rows.len(), context.criterion);

    Some(SplitCandidate::Binary {
        feature_index,
        score: parent_impurity - weighted_impurity,
        threshold_bin,
        left_rows,
        right_rows,
    })
}
2029
2030#[allow(dead_code)]
2031fn score_binary_cart_split(
2032 context: &SplitScoringContext<'_>,
2033 feature_index: usize,
2034 rows: &[usize],
2035) -> Option<SplitCandidate> {
2036 let (left_rows, right_rows): (Vec<usize>, Vec<usize>) =
2037 rows.iter().copied().partition(|row_idx| {
2038 !context
2039 .table
2040 .binned_boolean_value(feature_index, *row_idx)
2041 .expect("binary feature must expose boolean values")
2042 });
2043
2044 if left_rows.len() < context.min_samples_leaf || right_rows.len() < context.min_samples_leaf {
2045 return None;
2046 }
2047
2048 let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
2049 let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
2050 let left_counts = class_counts(&left_rows, context.class_indices, context.num_classes);
2051 let right_counts = class_counts(&right_rows, context.class_indices, context.num_classes);
2052 let weighted_impurity = (left_rows.len() as f64 / rows.len() as f64)
2053 * classification_impurity(&left_counts, left_rows.len(), context.criterion)
2054 + (right_rows.len() as f64 / rows.len() as f64)
2055 * classification_impurity(&right_counts, right_rows.len(), context.criterion);
2056
2057 Some(SplitCandidate::Binary {
2058 feature_index,
2059 score: parent_impurity - weighted_impurity,
2060 threshold_bin: 0,
2061 left_rows,
2062 right_rows,
2063 })
2064}
2065
/// Score the single possible oblivious split of a binary feature
/// (`false` left, `true` right), summing impurity reduction over all
/// current leaves. `threshold_bin` is reported as 0.
///
/// Leaves where either side would violate `min_samples_leaf` are skipped
/// and contribute no gain. Returns `None` unless at least one leaf was
/// valid and the total gain is strictly positive.
#[allow(clippy::too_many_arguments)]
fn score_binary_oblivious_split(
    table: &dyn TableAccess,
    row_indices: &[usize],
    class_indices: &[usize],
    feature_index: usize,
    leaves: &[ObliviousLeafState],
    num_classes: usize,
    criterion: Criterion,
    min_samples_leaf: usize,
) -> Option<ObliviousSplitCandidate> {
    let mut score = 0.0;
    let mut found_valid = false;

    for leaf in leaves {
        // Tally only the `false` (left) side; the right side is derived
        // by subtraction from the leaf's stored totals.
        let mut left_counts = vec![0usize; num_classes];
        let mut left_size = 0usize;
        for row_idx in &row_indices[leaf.start..leaf.end] {
            if !table
                .binned_boolean_value(feature_index, *row_idx)
                .expect("binary feature must expose boolean values")
            {
                left_counts[class_indices[*row_idx]] += 1;
                left_size += 1;
            }
        }
        let right_size = leaf.len() - left_size;
        if left_size < min_samples_leaf || right_size < min_samples_leaf {
            continue;
        }
        found_valid = true;
        let right_counts = leaf
            .class_counts
            .iter()
            .zip(left_counts.iter())
            .map(|(parent, left)| parent - left)
            .collect::<Vec<_>>();
        let weighted_parent_impurity =
            leaf.len() as f64 * classification_impurity(&leaf.class_counts, leaf.len(), criterion);
        let weighted_children_impurity = left_size as f64
            * classification_impurity(&left_counts, left_size, criterion)
            + right_size as f64 * classification_impurity(&right_counts, right_size, criterion);
        score += weighted_parent_impurity - weighted_children_impurity;
    }

    (found_valid && score > 0.0).then_some(ObliviousSplitCandidate {
        feature_index,
        threshold_bin: 0,
        score,
    })
}
2117
/// Histogram fast path for scoring a numeric feature as an oblivious
/// level split.
///
/// For each leaf, builds a per-bin class histogram and sweeps the
/// observed bins left-to-right, adding that leaf's impurity reduction to
/// `threshold_scores[bin]`. The bin with the highest accumulated (and
/// strictly positive) score wins.
///
/// Returns `None` — so the caller can use the slow fallback — when the
/// bin cap is zero or any row's bin exceeds the cap; also `None` when no
/// leaf has two or more distinct bins or no threshold scores positive.
///
/// NOTE(review): a leaf contributes only to thresholds at bins it itself
/// observed; thresholds falling between a leaf's observed bins receive
/// no credit from that leaf. This matches the original behavior —
/// confirm it is the intended approximation.
#[allow(clippy::too_many_arguments)]
fn score_numeric_oblivious_split_fast(
    table: &dyn TableAccess,
    row_indices: &[usize],
    class_indices: &[usize],
    feature_index: usize,
    leaves: &[ObliviousLeafState],
    num_classes: usize,
    criterion: Criterion,
    min_samples_leaf: usize,
) -> Option<ObliviousSplitCandidate> {
    let bin_cap = table.numeric_bin_cap();
    if bin_cap == 0 {
        return None;
    }

    let mut threshold_scores = vec![0.0; bin_cap];
    let mut observed_any = false;

    for leaf in leaves {
        // Per-leaf histogram of class counts keyed by bin.
        let mut bin_class_counts = vec![vec![0usize; num_classes]; bin_cap];
        let mut observed_bins = vec![false; bin_cap];
        for row_idx in &row_indices[leaf.start..leaf.end] {
            let bin = table.binned_value(feature_index, *row_idx) as usize;
            if bin >= bin_cap {
                // Out-of-range bin: abandon the fast path entirely.
                return None;
            }
            bin_class_counts[bin][class_indices[*row_idx]] += 1;
            observed_bins[bin] = true;
        }

        let observed_bins: Vec<usize> = observed_bins
            .into_iter()
            .enumerate()
            .filter_map(|(bin, seen)| seen.then_some(bin))
            .collect();
        if observed_bins.len() <= 1 {
            // A single bin cannot be split; this leaf contributes nothing.
            continue;
        }
        observed_any = true;

        let parent_weighted_impurity =
            leaf.len() as f64 * classification_impurity(&leaf.class_counts, leaf.len(), criterion);
        // Prefix sweep: accumulate left-side counts bin by bin and derive
        // the right side by subtraction from the leaf totals.
        let mut left_counts = vec![0usize; num_classes];
        let mut left_size = 0usize;

        for &bin in &observed_bins {
            for class_index in 0..num_classes {
                left_counts[class_index] += bin_class_counts[bin][class_index];
            }
            left_size += bin_class_counts[bin].iter().sum::<usize>();
            let right_size = leaf.len() - left_size;

            if left_size < min_samples_leaf || right_size < min_samples_leaf {
                continue;
            }

            let right_counts = leaf
                .class_counts
                .iter()
                .zip(left_counts.iter())
                .map(|(parent, left)| parent - left)
                .collect::<Vec<_>>();
            let weighted_children_impurity = left_size as f64
                * classification_impurity(&left_counts, left_size, criterion)
                + right_size as f64 * classification_impurity(&right_counts, right_size, criterion);
            threshold_scores[bin] += parent_weighted_impurity - weighted_children_impurity;
        }
    }

    if !observed_any {
        return None;
    }

    threshold_scores
        .into_iter()
        .enumerate()
        .filter(|(_, score)| *score > 0.0)
        .max_by(|left, right| left.1.total_cmp(&right.1))
        .map(|(threshold_bin, score)| ObliviousSplitCandidate {
            feature_index,
            threshold_bin: threshold_bin as u16,
            score,
        })
}
2203
/// Histogram fast path for CART numeric splits: one pass to build
/// per-bin class counts, then a prefix sweep over the observed bins
/// evaluating each as the `<=` threshold.
///
/// Returns `None` when the bin cap is zero, a row's bin exceeds the cap
/// (the caller then uses the slow path), fewer than two bins are
/// observed, or no threshold satisfies `min_samples_leaf` on both
/// sides. The winning threshold's row partition is materialized only
/// once, at the end.
#[allow(dead_code)]
fn score_numeric_cart_split_fast(
    context: &SplitScoringContext<'_>,
    feature_index: usize,
    rows: &[usize],
) -> Option<SplitCandidate> {
    let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
    let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
    let bin_cap = context.table.numeric_bin_cap();
    if bin_cap == 0 {
        return None;
    }

    // Single pass: class counts per bin plus the set of seen bins.
    let mut bin_class_counts = vec![vec![0usize; context.num_classes]; bin_cap];
    let mut observed_bins = vec![false; bin_cap];
    for row_idx in rows {
        let bin = context.table.binned_value(feature_index, *row_idx) as usize;
        if bin >= bin_cap {
            // Out-of-range bin: abandon the fast path.
            return None;
        }
        bin_class_counts[bin][context.class_indices[*row_idx]] += 1;
        observed_bins[bin] = true;
    }

    let observed_bins: Vec<usize> = observed_bins
        .into_iter()
        .enumerate()
        .filter_map(|(bin, seen)| seen.then_some(bin))
        .collect();
    if observed_bins.len() <= 1 {
        return None;
    }

    // Prefix sweep: grow the left side bin by bin; the right side is
    // derived by subtraction from the parent counts.
    let mut left_counts = vec![0usize; context.num_classes];
    let mut left_size = 0usize;
    let mut best_threshold = None;
    let mut best_score = f64::NEG_INFINITY;

    for &bin in &observed_bins {
        for class_index in 0..context.num_classes {
            left_counts[class_index] += bin_class_counts[bin][class_index];
        }
        left_size += bin_class_counts[bin].iter().sum::<usize>();
        let right_size = rows.len() - left_size;

        if left_size < context.min_samples_leaf || right_size < context.min_samples_leaf {
            continue;
        }

        let right_counts = parent_counts
            .iter()
            .zip(left_counts.iter())
            .map(|(parent, left)| parent - left)
            .collect::<Vec<_>>();
        let weighted_impurity = (left_size as f64 / rows.len() as f64)
            * classification_impurity(&left_counts, left_size, context.criterion)
            + (right_size as f64 / rows.len() as f64)
                * classification_impurity(&right_counts, right_size, context.criterion);
        let score = parent_impurity - weighted_impurity;
        if score > best_score {
            best_score = score;
            best_threshold = Some(bin as u16);
        }
    }

    let threshold_bin = best_threshold?;
    // Materialize the winning partition once, for the caller to reuse.
    let (left_rows, right_rows): (Vec<usize>, Vec<usize>) = rows
        .iter()
        .copied()
        .partition(|row_idx| context.table.binned_value(feature_index, *row_idx) <= threshold_bin);

    Some(SplitCandidate::Binary {
        feature_index,
        score: best_score,
        threshold_bin,
        left_rows,
        right_rows,
    })
}
2283
2284#[allow(dead_code)]
2285fn score_numeric_randomized_split_fast(
2286 context: &SplitScoringContext<'_>,
2287 feature_index: usize,
2288 rows: &[usize],
2289) -> Option<SplitCandidate> {
2290 let bin_cap = context.table.numeric_bin_cap();
2291 if bin_cap == 0 {
2292 return None;
2293 }
2294 let mut observed_bins = vec![false; bin_cap];
2295 for row_idx in rows {
2296 let bin = context.table.binned_value(feature_index, *row_idx) as usize;
2297 if bin >= bin_cap {
2298 return None;
2299 }
2300 observed_bins[bin] = true;
2301 }
2302 let candidate_thresholds = observed_bins
2303 .into_iter()
2304 .enumerate()
2305 .filter_map(|(bin, seen)| seen.then_some(bin as u16))
2306 .collect::<Vec<_>>();
2307 let threshold_bin =
2308 choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xC1A551F1u64)?;
2309
2310 let (left_rows, right_rows): (Vec<usize>, Vec<usize>) = rows
2311 .iter()
2312 .copied()
2313 .partition(|row_idx| context.table.binned_value(feature_index, *row_idx) <= threshold_bin);
2314
2315 if left_rows.len() < context.min_samples_leaf || right_rows.len() < context.min_samples_leaf {
2316 return None;
2317 }
2318
2319 let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
2320 let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
2321 let left_counts = class_counts(&left_rows, context.class_indices, context.num_classes);
2322 let right_counts = class_counts(&right_rows, context.class_indices, context.num_classes);
2323 let weighted_impurity = (left_rows.len() as f64 / rows.len() as f64)
2324 * classification_impurity(&left_counts, left_rows.len(), context.criterion)
2325 + (right_rows.len() as f64 / rows.len() as f64)
2326 * classification_impurity(&right_counts, right_rows.len(), context.criterion);
2327
2328 Some(SplitCandidate::Binary {
2329 feature_index,
2330 score: parent_impurity - weighted_impurity,
2331 threshold_bin,
2332 left_rows,
2333 right_rows,
2334 })
2335}
2336
/// Tallies how many of the given rows fall into each class.
fn class_counts(rows: &[usize], class_indices: &[usize], num_classes: usize) -> Vec<usize> {
    let mut counts = vec![0usize; num_classes];
    for &row_idx in rows {
        counts[class_indices[row_idx]] += 1;
    }
    counts
}
2344
2345fn majority_class(rows: &[usize], class_indices: &[usize], num_classes: usize) -> usize {
2346 majority_class_from_counts(&class_counts(rows, class_indices, num_classes))
2347}
2348
/// Index of the largest count. Ties resolve to the smallest class index, and
/// an empty slice yields class 0.
fn majority_class_from_counts(counts: &[usize]) -> usize {
    let mut winner = 0usize;
    let mut winner_count = 0usize;
    for (class_index, &count) in counts.iter().enumerate() {
        // Strict `>` keeps the earliest (smallest) index on ties.
        if count > winner_count {
            winner = class_index;
            winner_count = count;
        }
    }
    winner
}
2358
/// True when every row shares a single class label (vacuously true for an
/// empty row set).
fn is_pure(rows: &[usize], class_indices: &[usize]) -> bool {
    match rows.split_first() {
        None => true,
        Some((&first_row, rest)) => {
            let label = class_indices[first_row];
            rest.iter().all(|&row_idx| class_indices[row_idx] == label)
        }
    }
}
2365
/// Shannon entropy (base 2) of the class distribution `counts` over `total`
/// samples. Zero-count classes contribute nothing by convention.
fn entropy(counts: &[usize], total: usize) -> f64 {
    let mut sum = 0.0;
    for &count in counts {
        if count == 0 {
            continue;
        }
        let probability = count as f64 / total as f64;
        sum += -probability * probability.log2();
    }
    sum
}
2377
/// Gini impurity of the class distribution `counts` over `total` samples:
/// one minus the sum of squared class probabilities.
fn gini(counts: &[usize], total: usize) -> f64 {
    let mut sum_of_squares = 0.0;
    for &count in counts {
        let probability = count as f64 / total as f64;
        sum_of_squares += probability * probability;
    }
    1.0 - sum_of_squares
}
2388
2389fn classification_impurity(counts: &[usize], total: usize, criterion: Criterion) -> f64 {
2390 match criterion {
2391 Criterion::Entropy => entropy(counts, total),
2392 Criterion::Gini => gini(counts, total),
2393 _ => unreachable!("classification impurity only supports gini or entropy"),
2394 }
2395}
2396
2397#[allow(dead_code)]
2398fn split_score(candidate: &SplitCandidate) -> f64 {
2399 match candidate {
2400 SplitCandidate::Multiway { score, .. } | SplitCandidate::Binary { score, .. } => *score,
2401 }
2402}
2403
2404#[allow(dead_code)]
2405fn score_binary_split_choice(
2406 context: &SplitScoringContext<'_>,
2407 feature_index: usize,
2408 rows: &[usize],
2409 algorithm: DecisionTreeAlgorithm,
2410) -> Option<BinarySplitChoice> {
2411 match algorithm {
2412 DecisionTreeAlgorithm::Cart => {
2413 if context.table.is_binary_binned_feature(feature_index) {
2414 score_binary_cart_split_choice(context, feature_index, rows)
2415 } else {
2416 score_numeric_cart_split_choice_fast(context, feature_index, rows)
2417 }
2418 }
2419 DecisionTreeAlgorithm::Randomized => {
2420 if context.table.is_binary_binned_feature(feature_index) {
2421 score_binary_cart_split_choice(context, feature_index, rows)
2422 } else {
2423 score_numeric_randomized_split_choice_fast(context, feature_index, rows)
2424 }
2425 }
2426 _ => None,
2427 }
2428}
2429
2430fn score_binary_split_choice_from_hist(
2431 context: &SplitScoringContext<'_>,
2432 histogram: &ClassificationFeatureHistogram,
2433 feature_index: usize,
2434 rows: &[usize],
2435 parent_counts: &[usize],
2436 algorithm: DecisionTreeAlgorithm,
2437) -> Option<BinarySplitChoice> {
2438 match (algorithm, histogram) {
2439 (
2440 DecisionTreeAlgorithm::Cart,
2441 ClassificationFeatureHistogram::Binary {
2442 false_counts,
2443 true_counts,
2444 false_size,
2445 true_size,
2446 },
2447 ) => score_binary_cart_split_choice_from_counts(
2448 context,
2449 feature_index,
2450 parent_counts,
2451 false_counts,
2452 *false_size,
2453 true_counts,
2454 *true_size,
2455 ),
2456 (
2457 DecisionTreeAlgorithm::Cart,
2458 ClassificationFeatureHistogram::Numeric {
2459 bin_class_counts,
2460 observed_bins,
2461 },
2462 ) => score_numeric_cart_split_choice_from_hist(
2463 context,
2464 feature_index,
2465 parent_counts,
2466 rows.len(),
2467 bin_class_counts,
2468 observed_bins,
2469 ),
2470 (
2471 DecisionTreeAlgorithm::Randomized,
2472 ClassificationFeatureHistogram::Binary {
2473 false_counts,
2474 true_counts,
2475 false_size,
2476 true_size,
2477 },
2478 ) => score_binary_cart_split_choice_from_counts(
2479 context,
2480 feature_index,
2481 parent_counts,
2482 false_counts,
2483 *false_size,
2484 true_counts,
2485 *true_size,
2486 ),
2487 (
2488 DecisionTreeAlgorithm::Randomized,
2489 ClassificationFeatureHistogram::Numeric { observed_bins, .. },
2490 ) => score_numeric_randomized_split_choice_from_hist(
2491 context,
2492 feature_index,
2493 rows,
2494 parent_counts,
2495 observed_bins,
2496 histogram,
2497 ),
2498 _ => None,
2499 }
2500}
2501
2502fn score_binary_cart_split_choice_from_counts(
2503 context: &SplitScoringContext<'_>,
2504 feature_index: usize,
2505 parent_counts: &[usize],
2506 left_counts: &[usize],
2507 left_size: usize,
2508 right_counts: &[usize],
2509 right_size: usize,
2510) -> Option<BinarySplitChoice> {
2511 if left_size < context.min_samples_leaf || right_size < context.min_samples_leaf {
2512 return None;
2513 }
2514 let parent_impurity =
2515 classification_impurity(parent_counts, left_size + right_size, context.criterion);
2516 let weighted_impurity = (left_size as f64 / (left_size + right_size) as f64)
2517 * classification_impurity(left_counts, left_size, context.criterion)
2518 + (right_size as f64 / (left_size + right_size) as f64)
2519 * classification_impurity(right_counts, right_size, context.criterion);
2520 Some(BinarySplitChoice {
2521 feature_index,
2522 score: parent_impurity - weighted_impurity,
2523 threshold_bin: 0,
2524 })
2525}
2526
2527fn score_numeric_cart_split_choice_from_hist(
2528 context: &SplitScoringContext<'_>,
2529 feature_index: usize,
2530 parent_counts: &[usize],
2531 row_count: usize,
2532 bin_class_counts: &[Vec<usize>],
2533 observed_bins: &[usize],
2534) -> Option<BinarySplitChoice> {
2535 if observed_bins.len() <= 1 {
2536 return None;
2537 }
2538 let parent_impurity = classification_impurity(parent_counts, row_count, context.criterion);
2539 let mut left_counts = vec![0usize; context.num_classes];
2540 let mut left_size = 0usize;
2541 let mut best_threshold = None;
2542 let mut best_score = f64::NEG_INFINITY;
2543
2544 for &bin in observed_bins {
2545 for class_index in 0..context.num_classes {
2546 left_counts[class_index] += bin_class_counts[bin][class_index];
2547 }
2548 left_size += bin_class_counts[bin].iter().sum::<usize>();
2549 let right_size = row_count - left_size;
2550 if left_size < context.min_samples_leaf || right_size < context.min_samples_leaf {
2551 continue;
2552 }
2553 let right_counts = parent_counts
2554 .iter()
2555 .zip(left_counts.iter())
2556 .map(|(parent, left)| parent - left)
2557 .collect::<Vec<_>>();
2558 let weighted_impurity = (left_size as f64 / row_count as f64)
2559 * classification_impurity(&left_counts, left_size, context.criterion)
2560 + (right_size as f64 / row_count as f64)
2561 * classification_impurity(&right_counts, right_size, context.criterion);
2562 let score = parent_impurity - weighted_impurity;
2563 if score > best_score {
2564 best_score = score;
2565 best_threshold = Some(bin as u16);
2566 }
2567 }
2568
2569 best_threshold.map(|threshold_bin| BinarySplitChoice {
2570 feature_index,
2571 score: best_score,
2572 threshold_bin,
2573 })
2574}
2575
2576fn score_numeric_randomized_split_choice_from_hist(
2577 context: &SplitScoringContext<'_>,
2578 feature_index: usize,
2579 rows: &[usize],
2580 parent_counts: &[usize],
2581 observed_bins: &[usize],
2582 histogram: &ClassificationFeatureHistogram,
2583) -> Option<BinarySplitChoice> {
2584 let candidate_thresholds = observed_bins
2585 .iter()
2586 .copied()
2587 .map(|bin| bin as u16)
2588 .collect::<Vec<_>>();
2589 let threshold_bin =
2590 choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xC1A551F1u64)?;
2591 let ClassificationFeatureHistogram::Numeric {
2592 bin_class_counts, ..
2593 } = histogram
2594 else {
2595 unreachable!("randomized numeric histogram must be numeric");
2596 };
2597 let mut left_counts = vec![0usize; context.num_classes];
2598 let mut left_size = 0usize;
2599 for bin in 0..=threshold_bin as usize {
2600 if bin >= bin_class_counts.len() {
2601 break;
2602 }
2603 for class_index in 0..context.num_classes {
2604 left_counts[class_index] += bin_class_counts[bin][class_index];
2605 }
2606 left_size += bin_class_counts[bin].iter().sum::<usize>();
2607 }
2608 let row_count = rows.len();
2609 let right_size = row_count - left_size;
2610 if left_size < context.min_samples_leaf || right_size < context.min_samples_leaf {
2611 return None;
2612 }
2613 let right_counts = parent_counts
2614 .iter()
2615 .zip(left_counts.iter())
2616 .map(|(parent, left)| parent - left)
2617 .collect::<Vec<_>>();
2618 let parent_impurity = classification_impurity(parent_counts, row_count, context.criterion);
2619 let weighted_impurity = (left_size as f64 / row_count as f64)
2620 * classification_impurity(&left_counts, left_size, context.criterion)
2621 + (right_size as f64 / row_count as f64)
2622 * classification_impurity(&right_counts, right_size, context.criterion);
2623 Some(BinarySplitChoice {
2624 feature_index,
2625 score: parent_impurity - weighted_impurity,
2626 threshold_bin,
2627 })
2628}
2629
2630#[allow(dead_code)]
2631fn score_binary_cart_split_choice(
2632 context: &SplitScoringContext<'_>,
2633 feature_index: usize,
2634 rows: &[usize],
2635) -> Option<BinarySplitChoice> {
2636 let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
2637 let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
2638 let mut left_counts = vec![0usize; context.num_classes];
2639 let mut left_size = 0usize;
2640
2641 for row_idx in rows {
2642 if !context
2643 .table
2644 .binned_boolean_value(feature_index, *row_idx)
2645 .expect("binary feature must expose boolean values")
2646 {
2647 left_counts[context.class_indices[*row_idx]] += 1;
2648 left_size += 1;
2649 }
2650 }
2651
2652 let right_size = rows.len() - left_size;
2653 if left_size < context.min_samples_leaf || right_size < context.min_samples_leaf {
2654 return None;
2655 }
2656
2657 let right_counts = parent_counts
2658 .iter()
2659 .zip(left_counts.iter())
2660 .map(|(parent, left)| parent - left)
2661 .collect::<Vec<_>>();
2662 let weighted_impurity = (left_size as f64 / rows.len() as f64)
2663 * classification_impurity(&left_counts, left_size, context.criterion)
2664 + (right_size as f64 / rows.len() as f64)
2665 * classification_impurity(&right_counts, right_size, context.criterion);
2666
2667 Some(BinarySplitChoice {
2668 feature_index,
2669 score: parent_impurity - weighted_impurity,
2670 threshold_bin: 0,
2671 })
2672}
2673
2674#[allow(dead_code)]
2675fn score_numeric_cart_split_choice_fast(
2676 context: &SplitScoringContext<'_>,
2677 feature_index: usize,
2678 rows: &[usize],
2679) -> Option<BinarySplitChoice> {
2680 let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
2681 let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
2682 let bin_cap = context.table.numeric_bin_cap();
2683 if bin_cap == 0 {
2684 return None;
2685 }
2686
2687 let mut bin_class_counts = vec![vec![0usize; context.num_classes]; bin_cap];
2688 let mut observed_bins = vec![false; bin_cap];
2689 for row_idx in rows {
2690 let bin = context.table.binned_value(feature_index, *row_idx) as usize;
2691 if bin >= bin_cap {
2692 return None;
2693 }
2694 bin_class_counts[bin][context.class_indices[*row_idx]] += 1;
2695 observed_bins[bin] = true;
2696 }
2697
2698 let observed_bins: Vec<usize> = observed_bins
2699 .into_iter()
2700 .enumerate()
2701 .filter_map(|(bin, seen)| seen.then_some(bin))
2702 .collect();
2703 if observed_bins.len() <= 1 {
2704 return None;
2705 }
2706
2707 let mut left_counts = vec![0usize; context.num_classes];
2708 let mut left_size = 0usize;
2709 let mut best_threshold = None;
2710 let mut best_score = f64::NEG_INFINITY;
2711
2712 for &bin in &observed_bins {
2713 for class_index in 0..context.num_classes {
2714 left_counts[class_index] += bin_class_counts[bin][class_index];
2715 }
2716 left_size += bin_class_counts[bin].iter().sum::<usize>();
2717 let right_size = rows.len() - left_size;
2718
2719 if left_size < context.min_samples_leaf || right_size < context.min_samples_leaf {
2720 continue;
2721 }
2722
2723 let right_counts = parent_counts
2724 .iter()
2725 .zip(left_counts.iter())
2726 .map(|(parent, left)| parent - left)
2727 .collect::<Vec<_>>();
2728 let weighted_impurity = (left_size as f64 / rows.len() as f64)
2729 * classification_impurity(&left_counts, left_size, context.criterion)
2730 + (right_size as f64 / rows.len() as f64)
2731 * classification_impurity(&right_counts, right_size, context.criterion);
2732 let score = parent_impurity - weighted_impurity;
2733 if score > best_score {
2734 best_score = score;
2735 best_threshold = Some(bin as u16);
2736 }
2737 }
2738
2739 best_threshold.map(|threshold_bin| BinarySplitChoice {
2740 feature_index,
2741 score: best_score,
2742 threshold_bin,
2743 })
2744}
2745
2746#[allow(dead_code)]
2747fn score_numeric_randomized_split_choice_fast(
2748 context: &SplitScoringContext<'_>,
2749 feature_index: usize,
2750 rows: &[usize],
2751) -> Option<BinarySplitChoice> {
2752 let bin_cap = context.table.numeric_bin_cap();
2753 if bin_cap == 0 {
2754 return None;
2755 }
2756 let mut observed_bins = vec![false; bin_cap];
2757 for row_idx in rows {
2758 let bin = context.table.binned_value(feature_index, *row_idx) as usize;
2759 if bin >= bin_cap {
2760 return None;
2761 }
2762 observed_bins[bin] = true;
2763 }
2764 let candidate_thresholds = observed_bins
2765 .into_iter()
2766 .enumerate()
2767 .filter_map(|(bin, seen)| seen.then_some(bin as u16))
2768 .collect::<Vec<_>>();
2769 let threshold_bin =
2770 choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xC1A551F1u64)?;
2771
2772 let parent_counts = class_counts(rows, context.class_indices, context.num_classes);
2773 let parent_impurity = classification_impurity(&parent_counts, rows.len(), context.criterion);
2774 let mut left_counts = vec![0usize; context.num_classes];
2775 let mut left_size = 0usize;
2776 for row_idx in rows {
2777 if context.table.binned_value(feature_index, *row_idx) <= threshold_bin {
2778 left_counts[context.class_indices[*row_idx]] += 1;
2779 left_size += 1;
2780 }
2781 }
2782 let right_size = rows.len() - left_size;
2783 if left_size < context.min_samples_leaf || right_size < context.min_samples_leaf {
2784 return None;
2785 }
2786 let right_counts = parent_counts
2787 .iter()
2788 .zip(left_counts.iter())
2789 .map(|(parent, left)| parent - left)
2790 .collect::<Vec<_>>();
2791 let weighted_impurity = (left_size as f64 / rows.len() as f64)
2792 * classification_impurity(&left_counts, left_size, context.criterion)
2793 + (right_size as f64 / rows.len() as f64)
2794 * classification_impurity(&right_counts, right_size, context.criterion);
2795
2796 Some(BinarySplitChoice {
2797 feature_index,
2798 score: parent_impurity - weighted_impurity,
2799 threshold_bin,
2800 })
2801}
2802
2803fn partition_rows_for_binary_split(
2804 table: &dyn TableAccess,
2805 feature_index: usize,
2806 threshold_bin: u16,
2807 rows: &mut [usize],
2808) -> usize {
2809 let mut left = 0usize;
2810 for index in 0..rows.len() {
2811 let go_left = if table.is_binary_binned_feature(feature_index) {
2812 !table
2813 .binned_boolean_value(feature_index, rows[index])
2814 .expect("binary feature must expose boolean values")
2815 } else {
2816 table.binned_value(feature_index, rows[index]) <= threshold_bin
2817 };
2818 if go_left {
2819 rows.swap(left, index);
2820 left += 1;
2821 }
2822 }
2823 left
2824}
2825
2826fn partition_rows_for_multiway_split(
2827 table: &dyn TableAccess,
2828 feature_index: usize,
2829 branch_bins: &[u16],
2830 rows: &mut [usize],
2831) -> Vec<(u16, usize, usize)> {
2832 let mut scratch = vec![0usize; rows.len()];
2833 let mut counts = vec![0usize; branch_bins.len()];
2834
2835 for row_idx in rows.iter().copied() {
2836 let bin = if table.is_binary_binned_feature(feature_index) {
2837 if table
2838 .binned_boolean_value(feature_index, row_idx)
2839 .expect("binary feature must expose boolean values")
2840 {
2841 1
2842 } else {
2843 0
2844 }
2845 } else {
2846 table.binned_value(feature_index, row_idx)
2847 };
2848 let branch_index = branch_bins
2849 .binary_search(&bin)
2850 .expect("branch bins must cover all observed bins");
2851 counts[branch_index] += 1;
2852 }
2853
2854 let mut offsets = Vec::with_capacity(branch_bins.len());
2855 let mut next = 0usize;
2856 for count in &counts {
2857 offsets.push(next);
2858 next += *count;
2859 }
2860 let mut write_positions = offsets.clone();
2861 for row_idx in rows.iter().copied() {
2862 let bin = if table.is_binary_binned_feature(feature_index) {
2863 if table
2864 .binned_boolean_value(feature_index, row_idx)
2865 .expect("binary feature must expose boolean values")
2866 {
2867 1
2868 } else {
2869 0
2870 }
2871 } else {
2872 table.binned_value(feature_index, row_idx)
2873 };
2874 let branch_index = branch_bins
2875 .binary_search(&bin)
2876 .expect("branch bins must cover all observed bins");
2877 let write_index = write_positions[branch_index];
2878 scratch[write_index] = row_idx;
2879 write_positions[branch_index] += 1;
2880 }
2881 rows.copy_from_slice(&scratch);
2882
2883 branch_bins
2884 .iter()
2885 .copied()
2886 .zip(offsets)
2887 .zip(counts)
2888 .map(|((bin, start), count)| (bin, start, start + count))
2889 .collect()
2890}
2891
2892fn choose_random_threshold(
2893 candidate_thresholds: &[u16],
2894 feature_index: usize,
2895 rows: &[usize],
2896 salt: u64,
2897) -> Option<u16> {
2898 if candidate_thresholds.is_empty() {
2899 return None;
2900 }
2901
2902 let mut seed = salt ^ ((feature_index as u64) << 32) ^ (rows.len() as u64);
2903 for row_idx in rows {
2904 seed = seed
2905 .wrapping_mul(6364136223846793005)
2906 .wrapping_add((*row_idx as u64) + 1);
2907 }
2908 let mut rng = StdRng::seed_from_u64(seed);
2909 let selected = rng.gen_range(0..candidate_thresholds.len());
2910 candidate_thresholds.get(selected).copied()
2911}
2912
2913fn candidate_feature_indices(
2914 feature_count: usize,
2915 max_features: Option<usize>,
2916 seed: u64,
2917) -> Vec<usize> {
2918 match max_features {
2919 Some(count) => sample_feature_subset(feature_count, count, seed),
2920 None => (0..feature_count).collect(),
2921 }
2922}
2923
/// Derives a deterministic per-node seed from the base seed, the node depth,
/// and the exact row set, so identical nodes always reseed identically.
fn node_seed(base_seed: u64, depth: usize, rows: &[usize], salt: u64) -> u64 {
    let depth_mix = (depth as u64)
        .wrapping_mul(0x9E37_79B9_7F4A_7C15)
        .rotate_left(11);
    let mut seed = base_seed ^ salt ^ depth_mix;
    for &row_index in rows {
        seed = seed.wrapping_mul(0xA076_1D64_78BD_642F)
            ^ (row_index as u64).wrapping_add(0xE703_7ED1_A0B4_28DB);
    }
    seed
}
2937
2938#[allow(dead_code)]
2939fn split_feature_index(candidate: &SplitCandidate) -> usize {
2940 match candidate {
2941 SplitCandidate::Multiway { feature_index, .. }
2942 | SplitCandidate::Binary { feature_index, .. } => *feature_index,
2943 }
2944}
2945
2946fn push_leaf(
2947 nodes: &mut Vec<TreeNode>,
2948 class_index: usize,
2949 sample_count: usize,
2950 class_counts: Vec<usize>,
2951) -> usize {
2952 push_node(
2953 nodes,
2954 TreeNode::Leaf {
2955 class_index,
2956 sample_count,
2957 class_counts,
2958 },
2959 )
2960}
2961
2962fn push_node(nodes: &mut Vec<TreeNode>, node: TreeNode) -> usize {
2963 nodes.push(node);
2964 nodes.len() - 1
2965}
2966
/// Scoring metric for multiway splits. Not referenced in this part of the
/// file; presumably selected by the ID3 (gain) vs. C4.5 (gain ratio)
/// training paths — confirm at the use site.
#[derive(Debug, Clone, Copy)]
enum MultiwayMetric {
    /// Raw impurity reduction (information gain).
    InformationGain,
    /// Gain normalized by the split's own entropy (gain ratio).
    GainRatio,
}
2972
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{FeaturePreprocessing, Model, NumericBinBoundary};
    use forestfire_data::{DenseTable, NumericBins};

    // Two binary features forming an AND truth table (each row duplicated
    // once); the target is 1.0 only when both features are 1.0.
    fn and_table() -> DenseTable {
        DenseTable::new(
            vec![
                vec![0.0, 0.0],
                vec![0.0, 1.0],
                vec![1.0, 0.0],
                vec![1.0, 1.0],
                vec![0.0, 0.0],
                vec![0.0, 1.0],
                vec![1.0, 0.0],
                vec![1.0, 1.0],
            ],
            vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
        )
        .unwrap()
    }

    // Table crafted so the best depth-1 root split differs between gini and
    // entropy, letting a test confirm the configured criterion is honored.
    fn criterion_choice_table() -> DenseTable {
        DenseTable::with_options(
            vec![
                vec![0.0, 1.0],
                vec![4.0, 1.0],
                vec![4.0, 0.0],
                vec![0.0, 1.0],
                vec![5.0, 2.0],
                vec![2.0, 4.0],
                vec![1.0, 2.0],
            ],
            vec![0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0],
            0,
            NumericBins::Fixed(8),
        )
        .unwrap()
    }

    // Builds a table whose target is a deterministic function of the column
    // at index `n_features()` — presumably the extra canary column implied by
    // `with_options`' third argument; confirm in forestfire_data. A tree that
    // fits this target can only be doing so by reading the canary.
    fn canary_target_table() -> DenseTable {
        let x: Vec<Vec<f64>> = (0..8).map(|value| vec![value as f64]).collect();
        // Probe table used only to discover how the extra column is binned.
        let probe =
            DenseTable::with_options(x.clone(), vec![0.0; 8], 1, NumericBins::Auto).unwrap();
        let canary_index = probe.n_features();
        let mut observed_bins = (0..probe.n_rows())
            .map(|row_idx| probe.binned_value(canary_index, row_idx))
            .collect::<Vec<_>>();
        observed_bins.sort_unstable();
        observed_bins.dedup();
        // Median observed bin: rows at or above it get label 1.0.
        let threshold = observed_bins[observed_bins.len() / 2];
        let y = (0..probe.n_rows())
            .map(|row_idx| {
                if probe.binned_value(canary_index, row_idx) >= threshold {
                    1.0
                } else {
                    0.0
                }
            })
            .collect();

        DenseTable::with_options(x, y, 1, NumericBins::Auto).unwrap()
    }

    // Each algorithm should perfectly fit the separable AND pattern and
    // report its own identity and default criterion.
    #[test]
    fn id3_fits_basic_boolean_pattern() {
        let table = and_table();
        let model = train_id3(&table).unwrap();

        assert_eq!(model.algorithm(), DecisionTreeAlgorithm::Id3);
        assert_eq!(model.criterion(), Criterion::Entropy);
        assert_eq!(model.predict_table(&table), table_targets(&table));
    }

    #[test]
    fn c45_fits_basic_boolean_pattern() {
        let table = and_table();
        let model = train_c45(&table).unwrap();

        assert_eq!(model.algorithm(), DecisionTreeAlgorithm::C45);
        assert_eq!(model.criterion(), Criterion::Entropy);
        assert_eq!(model.predict_table(&table), table_targets(&table));
    }

    #[test]
    fn cart_fits_basic_boolean_pattern() {
        let table = and_table();
        let model = train_cart(&table).unwrap();

        assert_eq!(model.algorithm(), DecisionTreeAlgorithm::Cart);
        assert_eq!(model.criterion(), Criterion::Gini);
        assert_eq!(model.predict_table(&table), table_targets(&table));
    }

    #[test]
    fn randomized_fits_basic_boolean_pattern() {
        let table = and_table();
        let model = train_randomized(&table).unwrap();

        assert_eq!(model.algorithm(), DecisionTreeAlgorithm::Randomized);
        assert_eq!(model.criterion(), Criterion::Gini);
        assert_eq!(model.predict_table(&table), table_targets(&table));
    }

    #[test]
    fn oblivious_fits_basic_boolean_pattern() {
        let table = and_table();
        let model = train_oblivious(&table).unwrap();

        assert_eq!(model.algorithm(), DecisionTreeAlgorithm::Oblivious);
        assert_eq!(model.criterion(), Criterion::Gini);
        assert_eq!(model.predict_table(&table), table_targets(&table));
    }

    // With max_depth 1 the two criteria must pick different root features on
    // the crafted table, proving the criterion option actually flows through.
    #[test]
    fn cart_can_choose_between_gini_and_entropy() {
        let table = criterion_choice_table();
        let options = DecisionTreeOptions {
            max_depth: 1,
            ..DecisionTreeOptions::default()
        };
        let gini_model = train_classifier(
            &table,
            DecisionTreeAlgorithm::Cart,
            Criterion::Gini,
            Parallelism::sequential(),
            options,
        )
        .unwrap();
        let entropy_model = train_classifier(
            &table,
            DecisionTreeAlgorithm::Cart,
            Criterion::Entropy,
            Parallelism::sequential(),
            options,
        )
        .unwrap();

        // Extracts the root split feature; panics on any other shape.
        let root_feature = |model: &DecisionTreeClassifier| match &model.structure {
            TreeStructure::Standard { nodes, root } => match &nodes[*root] {
                TreeNode::BinarySplit { feature_index, .. } => *feature_index,
                node => panic!("expected binary root split, found {node:?}"),
            },
            TreeStructure::Oblivious { .. } => panic!("expected standard tree"),
        };

        assert_eq!(gini_model.criterion(), Criterion::Gini);
        assert_eq!(entropy_model.criterion(), Criterion::Entropy);
        assert_eq!(root_feature(&gini_model), 0);
        assert_eq!(root_feature(&entropy_model), 1);
    }

    // A NaN target must surface as InvalidTargetValue with the right row.
    #[test]
    fn rejects_non_finite_class_labels() {
        let table = DenseTable::new(vec![vec![0.0], vec![1.0]], vec![0.0, f64::NAN]).unwrap();

        let err = train_id3(&table).unwrap_err();
        assert!(matches!(
            err,
            DecisionTreeError::InvalidTargetValue { row: 1, value } if value.is_nan()
        ));
    }

    // When the target is only explainable by the canary column, training
    // should refuse to grow the tree: constant predictions, not a fit.
    #[test]
    fn stops_standard_tree_growth_when_a_canary_wins() {
        let table = canary_target_table();
        for trainer in [train_id3, train_c45, train_cart] {
            let model = trainer(&table).unwrap();
            let preds = model.predict_table(&table);

            assert!(preds.iter().all(|pred| *pred == preds[0]));
            assert_ne!(preds, table_targets(&table));
        }
    }

    #[test]
    fn stops_oblivious_tree_growth_when_a_canary_wins() {
        let table = canary_target_table();
        let model = train_oblivious(&table).unwrap();
        let preds = model.predict_table(&table);

        assert!(preds.iter().all(|pred| *pred == preds[0]));
        assert_ne!(preds, table_targets(&table));
    }

    // Hand-built models — one per tree type — exercise serialization without
    // training, and pin the emitted `tree_type` / `task` tags.
    #[test]
    fn manually_built_classifier_models_serialize_for_each_tree_type() {
        let preprocessing = vec![
            FeaturePreprocessing::Binary,
            FeaturePreprocessing::Numeric {
                bin_boundaries: vec![
                    NumericBinBoundary {
                        bin: 0,
                        upper_bound: 1.0,
                    },
                    NumericBinBoundary {
                        bin: 127,
                        upper_bound: 10.0,
                    },
                ],
            },
        ];
        let options = DecisionTreeOptions::default();
        let class_labels = vec![10.0, 20.0];

        // ID3 and C4.5 use multiway splits; nodes are stored leaves-first
        // with the root at index 2.
        let id3 = Model::DecisionTreeClassifier(DecisionTreeClassifier {
            algorithm: DecisionTreeAlgorithm::Id3,
            criterion: Criterion::Entropy,
            class_labels: class_labels.clone(),
            structure: TreeStructure::Standard {
                nodes: vec![
                    TreeNode::Leaf {
                        class_index: 0,
                        sample_count: 3,
                        class_counts: vec![3, 0],
                    },
                    TreeNode::Leaf {
                        class_index: 1,
                        sample_count: 2,
                        class_counts: vec![0, 2],
                    },
                    TreeNode::MultiwaySplit {
                        feature_index: 1,
                        fallback_class_index: 0,
                        branches: vec![(0, 0), (127, 1)],
                        sample_count: 5,
                        impurity: 0.48,
                        gain: 0.24,
                        class_counts: vec![3, 2],
                    },
                ],
                root: 2,
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing.clone(),
            training_canaries: 0,
        });
        let c45 = Model::DecisionTreeClassifier(DecisionTreeClassifier {
            algorithm: DecisionTreeAlgorithm::C45,
            criterion: Criterion::Entropy,
            class_labels: class_labels.clone(),
            structure: TreeStructure::Standard {
                nodes: vec![
                    TreeNode::Leaf {
                        class_index: 0,
                        sample_count: 3,
                        class_counts: vec![3, 0],
                    },
                    TreeNode::Leaf {
                        class_index: 1,
                        sample_count: 2,
                        class_counts: vec![0, 2],
                    },
                    TreeNode::MultiwaySplit {
                        feature_index: 1,
                        fallback_class_index: 0,
                        branches: vec![(0, 0), (127, 1)],
                        sample_count: 5,
                        impurity: 0.48,
                        gain: 0.24,
                        class_counts: vec![3, 2],
                    },
                ],
                root: 2,
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing.clone(),
            training_canaries: 0,
        });
        // CART and randomized trees use binary splits.
        let cart = Model::DecisionTreeClassifier(DecisionTreeClassifier {
            algorithm: DecisionTreeAlgorithm::Cart,
            criterion: Criterion::Gini,
            class_labels: class_labels.clone(),
            structure: TreeStructure::Standard {
                nodes: vec![
                    TreeNode::Leaf {
                        class_index: 0,
                        sample_count: 3,
                        class_counts: vec![3, 0],
                    },
                    TreeNode::Leaf {
                        class_index: 1,
                        sample_count: 2,
                        class_counts: vec![0, 2],
                    },
                    TreeNode::BinarySplit {
                        feature_index: 0,
                        threshold_bin: 0,
                        left_child: 0,
                        right_child: 1,
                        sample_count: 5,
                        impurity: 0.48,
                        gain: 0.24,
                        class_counts: vec![3, 2],
                    },
                ],
                root: 2,
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing.clone(),
            training_canaries: 0,
        });
        let randomized = Model::DecisionTreeClassifier(DecisionTreeClassifier {
            algorithm: DecisionTreeAlgorithm::Randomized,
            criterion: Criterion::Entropy,
            class_labels: class_labels.clone(),
            structure: TreeStructure::Standard {
                nodes: vec![
                    TreeNode::Leaf {
                        class_index: 0,
                        sample_count: 3,
                        class_counts: vec![3, 0],
                    },
                    TreeNode::Leaf {
                        class_index: 1,
                        sample_count: 2,
                        class_counts: vec![0, 2],
                    },
                    TreeNode::BinarySplit {
                        feature_index: 0,
                        threshold_bin: 0,
                        left_child: 0,
                        right_child: 1,
                        sample_count: 5,
                        impurity: 0.48,
                        gain: 0.2,
                        class_counts: vec![3, 2],
                    },
                ],
                root: 2,
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing.clone(),
            training_canaries: 0,
        });
        // The oblivious tree uses a level-wise structure with flat leaves.
        let oblivious = Model::DecisionTreeClassifier(DecisionTreeClassifier {
            algorithm: DecisionTreeAlgorithm::Oblivious,
            criterion: Criterion::Gini,
            class_labels,
            structure: TreeStructure::Oblivious {
                splits: vec![ObliviousSplit {
                    feature_index: 0,
                    threshold_bin: 0,
                    sample_count: 4,
                    impurity: 0.5,
                    gain: 0.25,
                }],
                leaf_class_indices: vec![0, 1],
                leaf_sample_counts: vec![2, 2],
                leaf_class_counts: vec![vec![2, 0], vec![0, 2]],
            },
            options,
            num_features: 2,
            feature_preprocessing: preprocessing,
            training_canaries: 0,
        });

        for (tree_type, model) in [
            ("id3", id3),
            ("c45", c45),
            ("cart", cart),
            ("randomized", randomized),
            ("oblivious", oblivious),
        ] {
            let json = model.serialize().unwrap();
            assert!(json.contains(&format!("\"tree_type\":\"{tree_type}\"")));
            assert!(json.contains("\"task\":\"classification\""));
        }
    }

    // Collects the target column as a Vec for whole-table comparisons.
    fn table_targets(table: &dyn TableAccess) -> Vec<f64> {
        (0..table.n_rows())
            .map(|row_idx| table.target_value(row_idx))
            .collect()
    }
}