1use crate::ir::{
17 BinaryChildren, BinarySplit, IndexedLeaf, LeafIndexing, LeafPayload, MultiwayBranch,
18 MultiwaySplit, NodeStats, NodeTreeNode, ObliviousLevel, ObliviousSplit as IrObliviousSplit,
19 TrainingMetadata, TreeDefinition, criterion_name, feature_name, threshold_upper_bound,
20 tree_type_name,
21};
22use crate::tree::shared::{
23 MissingBranchDirection, candidate_feature_indices, choose_random_threshold, node_seed,
24 partition_rows_for_binary_split, select_best_non_canary_candidate,
25};
26use crate::{
27 CanaryFilter, Criterion, FeaturePreprocessing, MissingValueStrategy, Parallelism,
28 capture_feature_preprocessing,
29};
30use forestfire_data::TableAccess;
31use rayon::prelude::*;
32use std::collections::BTreeMap;
33use std::error::Error;
34use std::fmt::{Display, Formatter};
35
36mod histogram;
37mod ir_support;
38mod oblivious;
39mod partitioning;
40mod split_scoring;
41
42use histogram::{
43 ClassificationFeatureHistogram, build_classification_node_histograms,
44 subtract_classification_node_histograms,
45};
46use ir_support::{
47 binary_split_ir, normalized_class_probabilities, oblivious_split_ir, standard_node_depths,
48};
49use oblivious::train_oblivious_structure;
50use partitioning::partition_rows_for_multiway_split;
51use split_scoring::{
52 MultiwayMetric, SplitScoringContext, score_binary_split_choice_from_hist,
53 score_multiway_split_choice,
54};
55
/// Which tree-construction algorithm a classifier was trained with.
///
/// Id3/C45 grow multiway trees (ID3 scores splits by information gain,
/// C4.5 by gain ratio); Cart/Randomized grow binary threshold trees;
/// Oblivious grows a level-wise tree where every node at a level shares
/// the same split.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecisionTreeAlgorithm {
    Id3,
    C45,
    Cart,
    Randomized,
    Oblivious,
}
64
/// Tunable knobs shared by all tree-training entry points.
#[derive(Debug, Clone)]
pub struct DecisionTreeOptions {
    // Nodes at depth >= max_depth become leaves.
    pub max_depth: usize,
    // Nodes with fewer rows than this become leaves.
    pub min_samples_split: usize,
    // Passed to split scoring; presumably rejects splits producing smaller children — enforced in split_scoring.
    pub min_samples_leaf: usize,
    // Number of candidate features considered per node; None presumably means all (resolved in candidate_feature_indices).
    pub max_features: Option<usize>,
    // Base seed mixed into per-node seeds via node_seed.
    pub random_seed: u64,
    // Per-feature missing-value handling; semantics live in split_scoring.
    pub missing_value_strategies: Vec<MissingValueStrategy>,
    // Controls exclusion of canary features when selecting the best split.
    pub canary_filter: CanaryFilter,
}
80
impl Default for DecisionTreeOptions {
    /// Conservative defaults: shallow-ish trees (depth 8) with no feature
    /// subsampling and no special missing-value strategies.
    fn default() -> Self {
        Self {
            max_depth: 8,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: None,
            random_seed: 0,
            missing_value_strategies: Vec::new(),
            canary_filter: CanaryFilter::default(),
        }
    }
}
94
/// Errors surfaced by the training entry points in this module.
#[derive(Debug)]
pub enum DecisionTreeError {
    // The training table had zero rows.
    EmptyTarget,
    // A target value was non-finite (NaN or infinity) at the given row.
    InvalidTargetValue { row: usize, value: f64 },
}
100
101impl Display for DecisionTreeError {
102 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
103 match self {
104 DecisionTreeError::EmptyTarget => write!(f, "Cannot train on an empty target vector."),
105 DecisionTreeError::InvalidTargetValue { row, value } => write!(
106 f,
107 "Classification targets must be finite values. Found {} at row {}.",
108 value, row
109 ),
110 }
111 }
112}
113
// Marker impl: the Debug derive and Display impl above satisfy `Error`'s defaults.
impl Error for DecisionTreeError {}
115
/// A trained decision-tree classifier: the tree structure plus everything
/// needed to predict on new tables and to export to the IR format.
#[derive(Debug, Clone)]
pub struct DecisionTreeClassifier {
    algorithm: DecisionTreeAlgorithm,
    criterion: Criterion,
    // Sorted vocabulary of distinct target values; nodes store indices into it.
    class_labels: Vec<f64>,
    structure: TreeStructure,
    options: DecisionTreeOptions,
    num_features: usize,
    // Per-feature preprocessing captured at training time, used for IR export.
    feature_preprocessing: Vec<FeaturePreprocessing>,
    // Number of canary features present in the training table.
    training_canaries: usize,
}
128
/// The two physical layouts a trained tree can have.
#[derive(Debug, Clone)]
pub(crate) enum TreeStructure {
    // Arena of nodes; child links are indices into `nodes`, entered at `root`.
    Standard {
        nodes: Vec<TreeNode>,
        root: usize,
    },
    // One shared split per level; leaves are addressed by the MSB-first
    // concatenation of per-level "went right" bits (see predict_row).
    // The three leaf vectors are parallel, indexed by that leaf index.
    Oblivious {
        splits: Vec<ObliviousSplit>,
        leaf_class_indices: Vec<usize>,
        leaf_sample_counts: Vec<usize>,
        leaf_class_counts: Vec<Vec<usize>>,
    },
}
142
/// One level of an oblivious tree: a single (feature, threshold) test shared
/// by every node at that depth.
#[derive(Debug, Clone)]
pub(crate) struct ObliviousSplit {
    pub(crate) feature_index: usize,
    pub(crate) threshold_bin: u16,
    // NOTE(review): currently unused at prediction time (see the dead_code
    // allow and the traversal in predict_row, which never consults it) —
    // confirm whether oblivious missing-value routing is still planned.
    #[allow(dead_code)]
    pub(crate) missing_directions: Vec<MissingBranchDirection>,
    pub(crate) sample_count: usize,
    pub(crate) impurity: f64,
    pub(crate) gain: f64,
}
153
/// A node in a `TreeStructure::Standard` arena. Child fields are indices
/// into the arena's `nodes` vector.
#[derive(Debug, Clone)]
pub(crate) enum TreeNode {
    Leaf {
        class_index: usize,
        sample_count: usize,
        class_counts: Vec<usize>,
    },
    // One child per observed feature bin; rows whose bin has no branch (or
    // whose missing value has no `missing_child`) resolve to the fallback class.
    MultiwaySplit {
        feature_index: usize,
        fallback_class_index: usize,
        branches: Vec<(u16, usize)>,
        missing_child: Option<usize>,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        class_counts: Vec<usize>,
    },
    // Rows with bin <= threshold_bin go left, others right (see predict_row);
    // missing values follow `missing_direction`.
    BinarySplit {
        feature_index: usize,
        threshold_bin: u16,
        missing_direction: MissingBranchDirection,
        left_child: usize,
        right_child: usize,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        class_counts: Vec<usize>,
    },
}
183
/// A scored binary-split candidate produced during training; only candidates
/// with score > 0.0 are ever materialized into a node.
#[derive(Debug, Clone, Copy)]
struct BinarySplitChoice {
    feature_index: usize,
    score: f64,
    threshold_bin: u16,
    missing_direction: MissingBranchDirection,
}
191
/// A scored multiway-split candidate (ID3/C4.5) produced during training.
#[derive(Debug, Clone)]
struct MultiwaySplitChoice {
    feature_index: usize,
    score: f64,
    // The distinct feature bins that each get their own branch.
    branch_bins: Vec<u16>,
    // Which branch (if any) missing values are routed to.
    missing_branch_bin: Option<u16>,
}
199
/// Trains an ID3 tree with the default criterion (entropy).
pub fn train_id3(train_set: &dyn TableAccess) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion(train_set, Criterion::Entropy)
}

/// Trains an ID3 tree with an explicit criterion, running sequentially.
pub fn train_id3_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Trains an ID3 tree with explicit criterion and parallelism, default options.
pub(crate) fn train_id3_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized ID3 entry point; delegates to the shared trainer.
pub(crate) fn train_id3_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Id3,
        criterion,
        parallelism,
        options,
    )
}
238
/// Trains a C4.5 tree with the default criterion (entropy).
pub fn train_c45(train_set: &dyn TableAccess) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion(train_set, Criterion::Entropy)
}

/// Trains a C4.5 tree with an explicit criterion, running sequentially.
pub fn train_c45_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Trains a C4.5 tree with explicit criterion and parallelism, default options.
pub(crate) fn train_c45_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized C4.5 entry point; delegates to the shared trainer.
pub(crate) fn train_c45_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::C45,
        criterion,
        parallelism,
        options,
    )
}
277
/// Trains a CART tree with the default criterion (Gini).
pub fn train_cart(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion(train_set, Criterion::Gini)
}

/// Trains a CART tree with an explicit criterion, running sequentially.
pub fn train_cart_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Trains a CART tree with explicit criterion and parallelism, default options.
pub(crate) fn train_cart_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized CART entry point; delegates to the shared trainer.
pub(crate) fn train_cart_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Cart,
        criterion,
        parallelism,
        options,
    )
}
318
/// Trains an oblivious tree with the default criterion (Gini).
pub fn train_oblivious(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion(train_set, Criterion::Gini)
}

/// Trains an oblivious tree with an explicit criterion, running sequentially.
pub fn train_oblivious_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Trains an oblivious tree with explicit criterion and parallelism, default options.
pub(crate) fn train_oblivious_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized oblivious entry point; delegates to the shared trainer.
pub(crate) fn train_oblivious_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Oblivious,
        criterion,
        parallelism,
        options,
    )
}
359
/// Trains a randomized tree (random thresholds) with the default criterion (Gini).
pub fn train_randomized(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion(train_set, Criterion::Gini)
}

/// Trains a randomized tree with an explicit criterion, running sequentially.
pub fn train_randomized_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Trains a randomized tree with explicit criterion and parallelism, default options.
pub(crate) fn train_randomized_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized randomized entry point; delegates to the shared trainer.
pub(crate) fn train_randomized_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Randomized,
        criterion,
        parallelism,
        options,
    )
}
400
401fn train_classifier(
402 train_set: &dyn TableAccess,
403 algorithm: DecisionTreeAlgorithm,
404 criterion: Criterion,
405 parallelism: Parallelism,
406 options: DecisionTreeOptions,
407) -> Result<DecisionTreeClassifier, DecisionTreeError> {
408 if train_set.n_rows() == 0 {
409 return Err(DecisionTreeError::EmptyTarget);
410 }
411
412 let (class_labels, class_indices) = encode_class_labels(train_set)?;
413 let structure = match algorithm {
414 DecisionTreeAlgorithm::Oblivious => train_oblivious_structure(
415 train_set,
416 &class_indices,
417 &class_labels,
418 criterion,
419 parallelism,
420 options.clone(),
421 ),
422 DecisionTreeAlgorithm::Cart | DecisionTreeAlgorithm::Randomized => {
423 let mut nodes = Vec::new();
424 let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
425 let context = BuildContext {
426 table: train_set,
427 class_indices: &class_indices,
428 class_labels: &class_labels,
429 algorithm,
430 criterion,
431 parallelism,
432 options: options.clone(),
433 };
434 let root = build_binary_node_in_place(&context, &mut nodes, &mut all_rows, 0);
435 TreeStructure::Standard { nodes, root }
436 }
437 DecisionTreeAlgorithm::Id3 | DecisionTreeAlgorithm::C45 => {
438 let mut nodes = Vec::new();
439 let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
440 let context = BuildContext {
441 table: train_set,
442 class_indices: &class_indices,
443 class_labels: &class_labels,
444 algorithm,
445 criterion,
446 parallelism,
447 options: options.clone(),
448 };
449 let root = build_multiway_node_in_place(&context, &mut nodes, &mut all_rows, 0);
450 TreeStructure::Standard { nodes, root }
451 }
452 };
453
454 Ok(DecisionTreeClassifier {
455 algorithm,
456 criterion,
457 class_labels,
458 structure,
459 options,
460 num_features: train_set.n_features(),
461 feature_preprocessing: capture_feature_preprocessing(train_set),
462 training_canaries: train_set.canaries(),
463 })
464}
465
impl DecisionTreeClassifier {
    /// The algorithm this tree was trained with.
    pub fn algorithm(&self) -> DecisionTreeAlgorithm {
        self.algorithm
    }

    /// The impurity criterion this tree was trained with.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }

    /// Predicts a class label for every row of `table`, in row order.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        (0..table.n_rows())
            .map(|row_idx| self.predict_row(table, row_idx))
            .collect()
    }

    /// Predicts a per-class probability vector for every row of `table`.
    /// Probabilities are ordered to match `class_labels()`.
    pub fn predict_proba_table(&self, table: &dyn TableAccess) -> Vec<Vec<f64>> {
        (0..table.n_rows())
            .map(|row_idx| self.predict_proba_row(table, row_idx))
            .collect()
    }

    /// Routes one row through the tree and returns the predicted class label.
    fn predict_row(&self, table: &dyn TableAccess, row_idx: usize) -> f64 {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;
                loop {
                    match &nodes[node_index] {
                        TreeNode::Leaf { class_index, .. } => {
                            return self.class_labels[*class_index];
                        }
                        TreeNode::MultiwaySplit {
                            feature_index,
                            fallback_class_index,
                            branches,
                            missing_child,
                            ..
                        } => {
                            // Missing value: follow the trained missing branch if any,
                            // otherwise answer with this node's fallback class.
                            if table.is_missing(*feature_index, row_idx) {
                                if let Some(child_index) = missing_child {
                                    node_index = *child_index;
                                } else {
                                    return self.class_labels[*fallback_class_index];
                                }
                                continue;
                            }
                            // A bin never seen during training also resolves to the
                            // fallback class.
                            let bin = table.binned_value(*feature_index, row_idx);
                            if let Some((_, child_index)) =
                                branches.iter().find(|(branch_bin, _)| *branch_bin == bin)
                            {
                                node_index = *child_index;
                            } else {
                                return self.class_labels[*fallback_class_index];
                            }
                        }
                        TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            missing_direction,
                            left_child,
                            right_child,
                            class_counts,
                            ..
                        } => {
                            // Missing value: follow the trained direction; `Node` means
                            // stop here and answer with this node's majority class.
                            if table.is_missing(*feature_index, row_idx) {
                                match missing_direction {
                                    MissingBranchDirection::Left => {
                                        node_index = *left_child;
                                    }
                                    MissingBranchDirection::Right => {
                                        node_index = *right_child;
                                    }
                                    MissingBranchDirection::Node => {
                                        return self.class_labels
                                            [majority_class_from_counts(class_counts)];
                                    }
                                }
                                continue;
                            }
                            // Bins at or below the threshold go left.
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_indices,
                ..
            } => {
                // Each level contributes one bit (MSB first); the concatenated
                // bits index directly into the leaf vectors.
                // NOTE(review): missing values get no special handling here —
                // confirm binned_value encodes them acceptably for oblivious trees.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                self.class_labels[leaf_class_indices[leaf_index]]
            }
        }
    }

    /// Routes one row through the tree and returns normalized class
    /// probabilities derived from the reached node's class counts.
    fn predict_proba_row(&self, table: &dyn TableAccess, row_idx: usize) -> Vec<f64> {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;
                loop {
                    match &nodes[node_index] {
                        TreeNode::Leaf { class_counts, .. } => {
                            return normalized_class_probabilities(class_counts);
                        }
                        TreeNode::MultiwaySplit {
                            feature_index,
                            branches,
                            missing_child,
                            class_counts,
                            ..
                        } => {
                            // Same routing as predict_row, but dead ends answer with the
                            // current node's class distribution instead of a single label.
                            if table.is_missing(*feature_index, row_idx) {
                                if let Some(child_index) = missing_child {
                                    node_index = *child_index;
                                } else {
                                    return normalized_class_probabilities(class_counts);
                                }
                                continue;
                            }
                            let bin = table.binned_value(*feature_index, row_idx);
                            if let Some((_, child_index)) =
                                branches.iter().find(|(branch_bin, _)| *branch_bin == bin)
                            {
                                node_index = *child_index;
                            } else {
                                return normalized_class_probabilities(class_counts);
                            }
                        }
                        TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            missing_direction,
                            left_child,
                            right_child,
                            class_counts,
                            ..
                        } => {
                            if table.is_missing(*feature_index, row_idx) {
                                match missing_direction {
                                    MissingBranchDirection::Left => {
                                        node_index = *left_child;
                                    }
                                    MissingBranchDirection::Right => {
                                        node_index = *right_child;
                                    }
                                    MissingBranchDirection::Node => {
                                        return normalized_class_probabilities(class_counts);
                                    }
                                }
                                continue;
                            }
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_counts,
                ..
            } => {
                // Same MSB-first leaf addressing as predict_row.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                normalized_class_probabilities(&leaf_class_counts[leaf_index])
            }
        }
    }

    /// Sorted vocabulary of distinct target values seen during training.
    pub(crate) fn class_labels(&self) -> &[f64] {
        &self.class_labels
    }

    /// The trained tree structure (arena or oblivious levels).
    pub(crate) fn structure(&self) -> &TreeStructure {
        &self.structure
    }

    /// Number of features in the training table.
    pub(crate) fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-feature preprocessing captured at training time (for IR export).
    pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// Builds the IR training-metadata record describing how this tree was
    /// trained.
    pub(crate) fn training_metadata(&self) -> TrainingMetadata {
        TrainingMetadata {
            algorithm: "dt".to_string(),
            task: "classification".to_string(),
            tree_type: tree_type_name(match self.algorithm {
                DecisionTreeAlgorithm::Id3 => crate::TreeType::Id3,
                DecisionTreeAlgorithm::C45 => crate::TreeType::C45,
                DecisionTreeAlgorithm::Cart => crate::TreeType::Cart,
                DecisionTreeAlgorithm::Randomized => crate::TreeType::Randomized,
                DecisionTreeAlgorithm::Oblivious => crate::TreeType::Oblivious,
            })
            .to_string(),
            criterion: criterion_name(self.criterion).to_string(),
            canaries: self.training_canaries,
            compute_oob: false,
            max_depth: Some(self.options.max_depth),
            min_samples_split: Some(self.options.min_samples_split),
            min_samples_leaf: Some(self.options.min_samples_leaf),
            n_trees: None,
            max_features: self.options.max_features,
            // NOTE(review): options.random_seed exists but is not exported
            // here — confirm `seed: None` is intentional.
            seed: None,
            oob_score: None,
            class_labels: Some(self.class_labels.clone()),
            learning_rate: None,
            bootstrap: None,
            top_gradient_fraction: None,
            other_gradient_fraction: None,
        }
    }

    /// Serializes the tree into the IR `TreeDefinition` (node tree for
    /// standard structures, level list for oblivious ones).
    pub(crate) fn to_ir_tree(&self) -> TreeDefinition {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                let depths = standard_node_depths(nodes, *root);
                TreeDefinition::NodeTree {
                    tree_id: 0,
                    weight: 1.0,
                    root_node_id: *root,
                    // Arena indices double as IR node ids, so nodes are
                    // emitted in arena order.
                    nodes: nodes
                        .iter()
                        .enumerate()
                        .map(|(node_id, node)| match node {
                            TreeNode::Leaf {
                                class_index,
                                sample_count,
                                class_counts,
                            } => NodeTreeNode::Leaf {
                                node_id,
                                depth: depths[node_id],
                                leaf: self.class_leaf(*class_index),
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: None,
                                    gain: None,
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                            TreeNode::BinarySplit {
                                feature_index,
                                threshold_bin,
                                missing_direction,
                                left_child,
                                right_child,
                                sample_count,
                                impurity,
                                gain,
                                class_counts,
                            } => NodeTreeNode::BinaryBranch {
                                node_id,
                                depth: depths[node_id],
                                split: binary_split_ir(
                                    *feature_index,
                                    *threshold_bin,
                                    *missing_direction,
                                    &self.feature_preprocessing,
                                ),
                                children: BinaryChildren {
                                    left: *left_child,
                                    right: *right_child,
                                },
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                            // NOTE(review): missing_child is dropped on export —
                            // confirm the IR round-trip is meant to lose the
                            // dedicated missing branch.
                            TreeNode::MultiwaySplit {
                                feature_index,
                                fallback_class_index,
                                branches,
                                missing_child: _,
                                sample_count,
                                impurity,
                                gain,
                                class_counts,
                            } => NodeTreeNode::MultiwayBranch {
                                node_id,
                                depth: depths[node_id],
                                split: MultiwaySplit {
                                    split_type: "binned_value_multiway".to_string(),
                                    feature_index: *feature_index,
                                    feature_name: feature_name(*feature_index),
                                    comparison_dtype: "uint16".to_string(),
                                },
                                branches: branches
                                    .iter()
                                    .map(|(bin, child)| MultiwayBranch {
                                        bin: *bin,
                                        child: *child,
                                    })
                                    .collect(),
                                unmatched_leaf: self.class_leaf(*fallback_class_index),
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                        })
                        .collect(),
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_indices,
                leaf_sample_counts,
                leaf_class_counts,
            } => TreeDefinition::ObliviousLevels {
                tree_id: 0,
                weight: 1.0,
                depth: splits.len(),
                levels: splits
                    .iter()
                    .enumerate()
                    .map(|(level, split)| ObliviousLevel {
                        level,
                        split: oblivious_split_ir(
                            split.feature_index,
                            split.threshold_bin,
                            &self.feature_preprocessing,
                        ),
                        stats: NodeStats {
                            sample_count: split.sample_count,
                            impurity: Some(split.impurity),
                            gain: Some(split.gain),
                            class_counts: None,
                            variance: None,
                        },
                    })
                    .collect(),
                // Documents the MSB-first bit packing used by predict_row.
                leaf_indexing: LeafIndexing {
                    bit_order: "msb_first".to_string(),
                    index_formula: "sum(bit[level] << (depth - 1 - level))".to_string(),
                },
                leaves: leaf_class_indices
                    .iter()
                    .enumerate()
                    .map(|(leaf_index, class_index)| IndexedLeaf {
                        leaf_index,
                        leaf: self.class_leaf(*class_index),
                        stats: NodeStats {
                            sample_count: leaf_sample_counts[leaf_index],
                            impurity: None,
                            gain: None,
                            class_counts: Some(leaf_class_counts[leaf_index].clone()),
                            variance: None,
                        },
                    })
                    .collect(),
            },
        }
    }

    /// Wraps a class index and its label value into an IR leaf payload.
    fn class_leaf(&self, class_index: usize) -> LeafPayload {
        LeafPayload::ClassIndex {
            class_index,
            class_value: self.class_labels[class_index],
        }
    }

    /// Reassembles a classifier from IR-deserialized parts; fields are stored
    /// verbatim with no validation.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn from_ir_parts(
        algorithm: DecisionTreeAlgorithm,
        criterion: Criterion,
        class_labels: Vec<f64>,
        structure: TreeStructure,
        options: DecisionTreeOptions,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
        training_canaries: usize,
    ) -> Self {
        Self {
            algorithm,
            criterion,
            class_labels,
            structure,
            options,
            num_features,
            feature_preprocessing,
            training_canaries,
        }
    }
}
876
/// Recursively grows a binary subtree over `rows`, computing this node's
/// feature histograms from scratch (no parent histograms to reuse).
/// Returns the arena index of the subtree root pushed into `nodes`.
fn build_binary_node_in_place(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &mut [usize],
    depth: usize,
) -> usize {
    build_binary_node_in_place_with_hist(context, nodes, rows, depth, None)
}
885
/// Core binary (CART/Randomized) node builder. `rows` is reordered in place
/// when partitioning; `histograms` is this node's per-feature class
/// histograms if the parent already derived them, otherwise they are built
/// here. Returns the arena index of the node pushed into `nodes`.
fn build_binary_node_in_place_with_hist(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &mut [usize],
    depth: usize,
    histograms: Option<Vec<ClassificationFeatureHistogram>>,
) -> usize {
    let majority_class_index =
        majority_class(rows, context.class_indices, context.class_labels.len());
    let current_class_counts =
        class_counts(rows, context.class_indices, context.class_labels.len());

    // Stopping conditions: no rows, depth limit, too few rows to split, or
    // all rows already share one class.
    if rows.is_empty()
        || depth >= context.options.max_depth
        || rows.len() < context.options.min_samples_split
        || is_pure(rows, context.class_indices)
    {
        return push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        );
    }

    let scoring = SplitScoringContext {
        table: context.table,
        class_indices: context.class_indices,
        num_classes: context.class_labels.len(),
        criterion: context.criterion,
        min_samples_leaf: context.options.min_samples_leaf,
        missing_value_strategies: &context.options.missing_value_strategies,
    };
    // Reuse parent-derived histograms when available; only the root (or a
    // caller without them) pays for a full rebuild.
    let histograms = histograms.unwrap_or_else(|| {
        build_classification_node_histograms(
            context.table,
            context.class_indices,
            rows,
            context.class_labels.len(),
        )
    });
    // Feature subsampling is seeded per node so tree growth is deterministic
    // for a fixed random_seed.
    let feature_indices = candidate_feature_indices(
        context.table,
        context.options.max_features,
        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
    );
    // Score one candidate split per feature, in parallel when enabled.
    let split_candidates = if context.parallelism.enabled() {
        feature_indices
            .into_par_iter()
            .filter_map(|feature_index| {
                score_binary_split_choice_from_hist(
                    &scoring,
                    &histograms[feature_index],
                    feature_index,
                    rows,
                    &current_class_counts,
                    context.algorithm,
                )
            })
            .collect::<Vec<_>>()
    } else {
        feature_indices
            .into_iter()
            .filter_map(|feature_index| {
                score_binary_split_choice_from_hist(
                    &scoring,
                    &histograms[feature_index],
                    feature_index,
                    rows,
                    &current_class_counts,
                    context.algorithm,
                )
            })
            .collect::<Vec<_>>()
    };
    // Best candidate after canary filtering; semantics of the filter live in
    // select_best_non_canary_candidate.
    let best_split = select_best_non_canary_candidate(
        context.table,
        split_candidates,
        context.options.canary_filter,
        |candidate| candidate.score,
        |candidate| candidate.feature_index,
    )
    .selected;

    match best_split {
        // Only split when the best candidate has strictly positive gain.
        Some(best_split) if best_split.score > 0.0 => {
            let impurity =
                classification_impurity(&current_class_counts, rows.len(), context.criterion);
            let left_count = partition_rows_for_binary_split(
                context.table,
                best_split.feature_index,
                best_split.threshold_bin,
                best_split.missing_direction,
                rows,
            );
            let (left_rows, right_rows) = rows.split_at_mut(left_count);
            // Histogram subtraction trick: build histograms for the smaller
            // child and derive the larger child's as parent minus sibling.
            let (left_histograms, right_histograms) = if left_rows.len() <= right_rows.len() {
                let left_histograms = build_classification_node_histograms(
                    context.table,
                    context.class_indices,
                    left_rows,
                    context.class_labels.len(),
                );
                let right_histograms =
                    subtract_classification_node_histograms(&histograms, &left_histograms);
                (left_histograms, right_histograms)
            } else {
                let right_histograms = build_classification_node_histograms(
                    context.table,
                    context.class_indices,
                    right_rows,
                    context.class_labels.len(),
                );
                let left_histograms =
                    subtract_classification_node_histograms(&histograms, &right_histograms);
                (left_histograms, right_histograms)
            };
            let left_child = build_binary_node_in_place_with_hist(
                context,
                nodes,
                left_rows,
                depth + 1,
                Some(left_histograms),
            );
            let right_child = build_binary_node_in_place_with_hist(
                context,
                nodes,
                right_rows,
                depth + 1,
                Some(right_histograms),
            );

            push_node(
                nodes,
                TreeNode::BinarySplit {
                    feature_index: best_split.feature_index,
                    threshold_bin: best_split.threshold_bin,
                    missing_direction: best_split.missing_direction,
                    left_child,
                    right_child,
                    sample_count: rows.len(),
                    impurity,
                    gain: best_split.score,
                    class_counts: current_class_counts,
                },
            )
        }
        // No usable split: emit a majority-class leaf.
        _ => push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        ),
    }
}
1041
/// Recursively grows an ID3/C4.5 multiway subtree over `rows` (reordered in
/// place when partitioning). Returns the arena index of the node pushed into
/// `nodes`.
fn build_multiway_node_in_place(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &mut [usize],
    depth: usize,
) -> usize {
    let majority_class_index =
        majority_class(rows, context.class_indices, context.class_labels.len());
    let current_class_counts =
        class_counts(rows, context.class_indices, context.class_labels.len());

    // Same stopping conditions as the binary builder.
    if rows.is_empty()
        || depth >= context.options.max_depth
        || rows.len() < context.options.min_samples_split
        || is_pure(rows, context.class_indices)
    {
        return push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        );
    }

    // ID3 scores splits by information gain, C4.5 by gain ratio.
    let metric = match context.algorithm {
        DecisionTreeAlgorithm::Id3 => MultiwayMetric::InformationGain,
        DecisionTreeAlgorithm::C45 => MultiwayMetric::GainRatio,
        _ => unreachable!("multiway builder only supports id3/c45"),
    };
    let scoring = SplitScoringContext {
        table: context.table,
        class_indices: context.class_indices,
        num_classes: context.class_labels.len(),
        criterion: context.criterion,
        min_samples_leaf: context.options.min_samples_leaf,
        missing_value_strategies: &context.options.missing_value_strategies,
    };
    // Per-node deterministic seed for feature subsampling (same salt as the
    // binary builder).
    let feature_indices = candidate_feature_indices(
        context.table,
        context.options.max_features,
        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
    );
    let split_candidates = if context.parallelism.enabled() {
        feature_indices
            .into_par_iter()
            .filter_map(|feature_index| {
                score_multiway_split_choice(&scoring, feature_index, rows, metric)
            })
            .collect::<Vec<_>>()
    } else {
        feature_indices
            .into_iter()
            .filter_map(|feature_index| {
                score_multiway_split_choice(&scoring, feature_index, rows, metric)
            })
            .collect::<Vec<_>>()
    };
    let best_split = select_best_non_canary_candidate(
        context.table,
        split_candidates,
        context.options.canary_filter,
        |candidate| candidate.score,
        |candidate| candidate.feature_index,
    )
    .selected;

    match best_split {
        // Only split on strictly positive gain.
        Some(best_split) if best_split.score > 0.0 => {
            let impurity =
                classification_impurity(&current_class_counts, rows.len(), context.criterion);
            // Groups `rows` into contiguous (bin, start, end) ranges, one per branch.
            let branch_ranges = partition_rows_for_multiway_split(
                context.table,
                best_split.feature_index,
                &best_split.branch_bins,
                best_split.missing_branch_bin,
                rows,
            );
            let mut branch_nodes = Vec::with_capacity(branch_ranges.len());
            let mut missing_child = None;
            for (bin, start, end) in branch_ranges {
                let child =
                    build_multiway_node_in_place(context, nodes, &mut rows[start..end], depth + 1);
                // Remember which child receives missing values at prediction time.
                if best_split.missing_branch_bin == Some(bin) {
                    missing_child = Some(child);
                }
                branch_nodes.push((bin, child));
            }

            push_node(
                nodes,
                TreeNode::MultiwaySplit {
                    feature_index: best_split.feature_index,
                    fallback_class_index: majority_class_index,
                    branches: branch_nodes,
                    missing_child,
                    sample_count: rows.len(),
                    impurity,
                    gain: best_split.score,
                    class_counts: current_class_counts,
                },
            )
        }
        _ => push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        ),
    }
}
1152
/// Read-only state threaded through the recursive node builders so their
/// signatures stay small.
struct BuildContext<'a> {
    table: &'a dyn TableAccess,
    // Per-row class index into `class_labels`, parallel to table rows.
    class_indices: &'a [usize],
    class_labels: &'a [f64],
    algorithm: DecisionTreeAlgorithm,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
}
1162
1163fn encode_class_labels(
1164 train_set: &dyn TableAccess,
1165) -> Result<(Vec<f64>, Vec<usize>), DecisionTreeError> {
1166 let targets: Vec<f64> = (0..train_set.n_rows())
1167 .map(|row_idx| {
1168 let value = train_set.target_value(row_idx);
1169 if value.is_finite() {
1170 Ok(value)
1171 } else {
1172 Err(DecisionTreeError::InvalidTargetValue {
1173 row: row_idx,
1174 value,
1175 })
1176 }
1177 })
1178 .collect::<Result<_, _>>()?;
1179
1180 let class_labels = targets
1181 .iter()
1182 .copied()
1183 .fold(Vec::<f64>::new(), |mut labels, value| {
1184 if labels
1185 .binary_search_by(|candidate| candidate.total_cmp(&value))
1186 .is_err()
1187 {
1188 labels.push(value);
1189 labels.sort_by(|left, right| left.total_cmp(right));
1190 }
1191 labels
1192 });
1193
1194 let class_indices = targets
1195 .iter()
1196 .map(|value| {
1197 class_labels
1198 .binary_search_by(|candidate| candidate.total_cmp(value))
1199 .expect("target value must exist in class label vocabulary")
1200 })
1201 .collect();
1202
1203 Ok((class_labels, class_indices))
1204}
1205
/// Tallies how many of the given rows fall into each class.
/// Returns a vector of length `num_classes`.
fn class_counts(rows: &[usize], class_indices: &[usize], num_classes: usize) -> Vec<usize> {
    let mut counts = vec![0usize; num_classes];
    for &row_idx in rows {
        counts[class_indices[row_idx]] += 1;
    }
    counts
}
1213
1214fn majority_class(rows: &[usize], class_indices: &[usize], num_classes: usize) -> usize {
1215 majority_class_from_counts(&class_counts(rows, class_indices, num_classes))
1216}
1217
/// Index of the largest count; ties resolve to the lowest index, and an
/// empty slice yields 0.
fn majority_class_from_counts(counts: &[usize]) -> usize {
    let mut best_class = 0usize;
    let mut best_count = 0usize;
    for (class_index, &count) in counts.iter().enumerate() {
        // Strictly-greater keeps the first (lowest-index) maximum.
        if count > best_count {
            best_class = class_index;
            best_count = count;
        }
    }
    best_class
}
1227
/// True when every row shares the same class (vacuously true for no rows).
fn is_pure(rows: &[usize], class_indices: &[usize]) -> bool {
    match rows.split_first() {
        None => true,
        Some((&first_row, rest)) => {
            let first_class = class_indices[first_row];
            rest.iter().all(|&row_idx| class_indices[row_idx] == first_class)
        }
    }
}
1234
/// Shannon entropy (base 2) of the class distribution `counts` over `total`
/// samples; zero-count classes contribute nothing.
fn entropy(counts: &[usize], total: usize) -> f64 {
    let mut sum = 0.0;
    for &count in counts {
        if count > 0 {
            let probability = count as f64 / total as f64;
            sum += -probability * probability.log2();
        }
    }
    sum
}
1246
/// Gini impurity of the class distribution `counts` over `total` samples:
/// 1 minus the sum of squared class probabilities.
fn gini(counts: &[usize], total: usize) -> f64 {
    let mut sum_of_squares = 0.0f64;
    for &count in counts {
        let probability = count as f64 / total as f64;
        sum_of_squares += probability * probability;
    }
    1.0 - sum_of_squares
}
1257
1258fn classification_impurity(counts: &[usize], total: usize, criterion: Criterion) -> f64 {
1259 match criterion {
1260 Criterion::Entropy => entropy(counts, total),
1261 Criterion::Gini => gini(counts, total),
1262 _ => unreachable!("classification impurity only supports gini or entropy"),
1263 }
1264}
1265
1266fn push_leaf(
1267 nodes: &mut Vec<TreeNode>,
1268 class_index: usize,
1269 sample_count: usize,
1270 class_counts: Vec<usize>,
1271) -> usize {
1272 push_node(
1273 nodes,
1274 TreeNode::Leaf {
1275 class_index,
1276 sample_count,
1277 class_counts,
1278 },
1279 )
1280}
1281
1282fn push_node(nodes: &mut Vec<TreeNode>, node: TreeNode) -> usize {
1283 nodes.push(node);
1284 nodes.len() - 1
1285}
1286
1287#[cfg(test)]
1288mod tests;