1use crate::ir::{
17 BinaryChildren, BinarySplit, IndexedLeaf, LeafIndexing, LeafPayload, MultiwayBranch,
18 MultiwaySplit, NodeStats, NodeTreeNode, ObliviousLevel, ObliviousSplit as IrObliviousSplit,
19 TrainingMetadata, TreeDefinition, criterion_name, feature_name, threshold_upper_bound,
20 tree_type_name,
21};
22use crate::tree::shared::{
23 MissingBranchDirection, candidate_feature_indices, choose_random_threshold, node_seed,
24 partition_rows_for_binary_split,
25};
26use crate::{
27 Criterion, FeaturePreprocessing, MissingValueStrategy, Parallelism,
28 capture_feature_preprocessing,
29};
30use forestfire_data::TableAccess;
31use rayon::prelude::*;
32use std::collections::BTreeMap;
33use std::error::Error;
34use std::fmt::{Display, Formatter};
35
36mod histogram;
37mod ir_support;
38mod oblivious;
39mod partitioning;
40mod split_scoring;
41
42use histogram::{
43 ClassificationFeatureHistogram, build_classification_node_histograms,
44 subtract_classification_node_histograms,
45};
46use ir_support::{
47 binary_split_ir, normalized_class_probabilities, oblivious_split_ir, standard_node_depths,
48};
49use oblivious::train_oblivious_structure;
50use partitioning::partition_rows_for_multiway_split;
51use split_scoring::{
52 MultiwayMetric, SplitScoringContext, score_binary_split_choice_from_hist,
53 score_multiway_split_choice,
54};
55
/// Which decision-tree variant to train; selects the structure builder used
/// by `train_classifier` (binary node tree, multiway node tree, or oblivious
/// level-wise tree).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecisionTreeAlgorithm {
    /// Multiway splits scored by information gain.
    Id3,
    /// Multiway splits scored by gain ratio.
    C45,
    /// Binary splits on binned thresholds.
    Cart,
    /// Binary splits like CART; the scoring path receives the algorithm so it
    /// can randomize threshold choice (see `choose_random_threshold` import).
    Randomized,
    /// Same split applied at every level of the tree.
    Oblivious,
}
64
/// Hyperparameters shared by every decision-tree training entry point.
#[derive(Debug, Clone)]
pub struct DecisionTreeOptions {
    /// Nodes at this depth are forced to become leaves.
    pub max_depth: usize,
    /// Minimum number of rows required before a split is attempted.
    pub min_samples_split: usize,
    /// Minimum rows per leaf, enforced by the split-scoring context.
    pub min_samples_leaf: usize,
    /// Number of candidate features sampled per node; `None` uses all.
    pub max_features: Option<usize>,
    /// Base seed mixed into the per-node feature-sampling seed.
    pub random_seed: u64,
    /// Per-feature missing-value strategies consulted during split scoring.
    pub missing_value_strategies: Vec<MissingValueStrategy>,
}
79
80impl Default for DecisionTreeOptions {
81 fn default() -> Self {
82 Self {
83 max_depth: 8,
84 min_samples_split: 2,
85 min_samples_leaf: 1,
86 max_features: None,
87 random_seed: 0,
88 missing_value_strategies: Vec::new(),
89 }
90 }
91}
92
/// Errors surfaced while training a classification tree.
#[derive(Debug)]
pub enum DecisionTreeError {
    /// The training table had zero rows.
    EmptyTarget,
    /// A target value was non-finite (NaN or infinity).
    InvalidTargetValue { row: usize, value: f64 },
}
98
99impl Display for DecisionTreeError {
100 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
101 match self {
102 DecisionTreeError::EmptyTarget => write!(f, "Cannot train on an empty target vector."),
103 DecisionTreeError::InvalidTargetValue { row, value } => write!(
104 f,
105 "Classification targets must be finite values. Found {} at row {}.",
106 value, row
107 ),
108 }
109 }
110}
111
// Marker impl: the `Display` message carries all detail; no `source()` chain.
impl Error for DecisionTreeError {}
113
/// A trained classification tree plus everything needed to predict and to
/// round-trip through the IR (`to_ir_tree` / `from_ir_parts`).
#[derive(Debug, Clone)]
pub struct DecisionTreeClassifier {
    algorithm: DecisionTreeAlgorithm,
    criterion: Criterion,
    // Sorted vocabulary of distinct target values; predictions index into it.
    class_labels: Vec<f64>,
    structure: TreeStructure,
    options: DecisionTreeOptions,
    num_features: usize,
    // Per-feature preprocessing captured at training time, used when
    // exporting splits to the IR.
    feature_preprocessing: Vec<FeaturePreprocessing>,
    // Number of canary (noise) features present in the training table.
    training_canaries: usize,
}
126
/// Physical layout of a trained tree.
#[derive(Debug, Clone)]
pub(crate) enum TreeStructure {
    /// Node-based tree (ID3/C4.5/CART/randomized): a flat node arena plus the
    /// index of the root node.
    Standard {
        nodes: Vec<TreeNode>,
        root: usize,
    },
    /// Oblivious tree: one split per level; leaves are addressed by the
    /// per-level left/right bits (MSB first), so all leaf vectors have
    /// 2^depth entries.
    Oblivious {
        splits: Vec<ObliviousSplit>,
        leaf_class_indices: Vec<usize>,
        leaf_sample_counts: Vec<usize>,
        leaf_class_counts: Vec<Vec<usize>>,
    },
}
140
/// One level of an oblivious tree: `feature <= threshold_bin` goes left.
#[derive(Debug, Clone)]
pub(crate) struct ObliviousSplit {
    pub(crate) feature_index: usize,
    pub(crate) threshold_bin: u16,
    // NOTE(review): currently unused at prediction time (prediction never
    // checks `is_missing` on the oblivious path) — kept for the trainer.
    #[allow(dead_code)]
    pub(crate) missing_directions: Vec<MissingBranchDirection>,
    pub(crate) sample_count: usize,
    pub(crate) impurity: f64,
    pub(crate) gain: f64,
}
151
/// A node in a `TreeStructure::Standard` arena. Child fields are indices into
/// the same arena.
#[derive(Debug, Clone)]
pub(crate) enum TreeNode {
    Leaf {
        // Index into `DecisionTreeClassifier::class_labels`.
        class_index: usize,
        sample_count: usize,
        class_counts: Vec<usize>,
    },
    /// One child per observed bin value (ID3/C4.5).
    MultiwaySplit {
        feature_index: usize,
        // Class returned when a row's bin has no branch (unseen at training).
        fallback_class_index: usize,
        // (bin value, child node index) pairs.
        branches: Vec<(u16, usize)>,
        // Child to follow for missing values, if one was trained.
        missing_child: Option<usize>,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        class_counts: Vec<usize>,
    },
    /// Threshold split (CART/randomized): `bin <= threshold_bin` goes left.
    BinarySplit {
        feature_index: usize,
        threshold_bin: u16,
        missing_direction: MissingBranchDirection,
        left_child: usize,
        right_child: usize,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        class_counts: Vec<usize>,
    },
}
181
/// Best binary split found for one candidate feature; the builder keeps the
/// choice with the highest `score`.
#[derive(Debug, Clone, Copy)]
struct BinarySplitChoice {
    feature_index: usize,
    score: f64,
    threshold_bin: u16,
    missing_direction: MissingBranchDirection,
}
189
/// Best multiway split found for one candidate feature (ID3/C4.5).
#[derive(Debug, Clone)]
struct MultiwaySplitChoice {
    feature_index: usize,
    score: f64,
    // Distinct bin values that get their own branch.
    branch_bins: Vec<u16>,
    // Bin whose branch should also receive missing values, if any.
    missing_branch_bin: Option<u16>,
}
197
/// Trains an ID3 tree with the default criterion (entropy), sequential
/// execution, and default options.
pub fn train_id3(train_set: &dyn TableAccess) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion(train_set, Criterion::Entropy)
}

/// Trains an ID3 tree with an explicit criterion; execution stays sequential
/// and options stay at their defaults.
pub fn train_id3_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal ID3 entry point that additionally controls parallelism.
pub(crate) fn train_id3_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized ID3 entry point; delegates to the shared driver.
pub(crate) fn train_id3_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Id3,
        criterion,
        parallelism,
        options,
    )
}
236
/// Trains a C4.5 tree with the default criterion (entropy), sequential
/// execution, and default options.
pub fn train_c45(train_set: &dyn TableAccess) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion(train_set, Criterion::Entropy)
}

/// Trains a C4.5 tree with an explicit criterion; execution stays sequential
/// and options stay at their defaults.
pub fn train_c45_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal C4.5 entry point that additionally controls parallelism.
pub(crate) fn train_c45_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized C4.5 entry point; delegates to the shared driver.
pub(crate) fn train_c45_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::C45,
        criterion,
        parallelism,
        options,
    )
}
275
/// Trains a CART tree with the default criterion (Gini), sequential
/// execution, and default options.
pub fn train_cart(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion(train_set, Criterion::Gini)
}

/// Trains a CART tree with an explicit criterion; execution stays sequential
/// and options stay at their defaults.
pub fn train_cart_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal CART entry point that additionally controls parallelism.
pub(crate) fn train_cart_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized CART entry point; delegates to the shared driver.
pub(crate) fn train_cart_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Cart,
        criterion,
        parallelism,
        options,
    )
}
316
/// Trains an oblivious tree with the default criterion (Gini), sequential
/// execution, and default options.
pub fn train_oblivious(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion(train_set, Criterion::Gini)
}

/// Trains an oblivious tree with an explicit criterion; execution stays
/// sequential and options stay at their defaults.
pub fn train_oblivious_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal oblivious entry point that additionally controls
/// parallelism.
pub(crate) fn train_oblivious_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized oblivious entry point; delegates to the shared driver.
pub(crate) fn train_oblivious_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Oblivious,
        criterion,
        parallelism,
        options,
    )
}
357
/// Trains a randomized tree with the default criterion (Gini), sequential
/// execution, and default options.
pub fn train_randomized(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion(train_set, Criterion::Gini)
}

/// Trains a randomized tree with an explicit criterion; execution stays
/// sequential and options stay at their defaults.
pub fn train_randomized_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal randomized entry point that additionally controls
/// parallelism.
pub(crate) fn train_randomized_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized randomized entry point; delegates to the shared driver.
pub(crate) fn train_randomized_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Randomized,
        criterion,
        parallelism,
        options,
    )
}
398
399fn train_classifier(
400 train_set: &dyn TableAccess,
401 algorithm: DecisionTreeAlgorithm,
402 criterion: Criterion,
403 parallelism: Parallelism,
404 options: DecisionTreeOptions,
405) -> Result<DecisionTreeClassifier, DecisionTreeError> {
406 if train_set.n_rows() == 0 {
407 return Err(DecisionTreeError::EmptyTarget);
408 }
409
410 let (class_labels, class_indices) = encode_class_labels(train_set)?;
411 let structure = match algorithm {
412 DecisionTreeAlgorithm::Oblivious => train_oblivious_structure(
413 train_set,
414 &class_indices,
415 &class_labels,
416 criterion,
417 parallelism,
418 options.clone(),
419 ),
420 DecisionTreeAlgorithm::Cart | DecisionTreeAlgorithm::Randomized => {
421 let mut nodes = Vec::new();
422 let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
423 let context = BuildContext {
424 table: train_set,
425 class_indices: &class_indices,
426 class_labels: &class_labels,
427 algorithm,
428 criterion,
429 parallelism,
430 options: options.clone(),
431 };
432 let root = build_binary_node_in_place(&context, &mut nodes, &mut all_rows, 0);
433 TreeStructure::Standard { nodes, root }
434 }
435 DecisionTreeAlgorithm::Id3 | DecisionTreeAlgorithm::C45 => {
436 let mut nodes = Vec::new();
437 let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
438 let context = BuildContext {
439 table: train_set,
440 class_indices: &class_indices,
441 class_labels: &class_labels,
442 algorithm,
443 criterion,
444 parallelism,
445 options: options.clone(),
446 };
447 let root = build_multiway_node_in_place(&context, &mut nodes, &mut all_rows, 0);
448 TreeStructure::Standard { nodes, root }
449 }
450 };
451
452 Ok(DecisionTreeClassifier {
453 algorithm,
454 criterion,
455 class_labels,
456 structure,
457 options,
458 num_features: train_set.n_features(),
459 feature_preprocessing: capture_feature_preprocessing(train_set),
460 training_canaries: train_set.canaries(),
461 })
462}
463
impl DecisionTreeClassifier {
    /// Algorithm used to train this tree.
    pub fn algorithm(&self) -> DecisionTreeAlgorithm {
        self.algorithm
    }

    /// Split criterion used during training.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }

    /// Predicts the class label for every row of `table`.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        (0..table.n_rows())
            .map(|row_idx| self.predict_row(table, row_idx))
            .collect()
    }

    /// Predicts per-class probabilities for every row of `table`; inner
    /// vectors are indexed like `class_labels`.
    pub fn predict_proba_table(&self, table: &dyn TableAccess) -> Vec<Vec<f64>> {
        (0..table.n_rows())
            .map(|row_idx| self.predict_proba_row(table, row_idx))
            .collect()
    }

    /// Routes a single row from the root to a leaf and returns its label.
    fn predict_row(&self, table: &dyn TableAccess, row_idx: usize) -> f64 {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;
                loop {
                    match &nodes[node_index] {
                        TreeNode::Leaf { class_index, .. } => {
                            return self.class_labels[*class_index];
                        }
                        TreeNode::MultiwaySplit {
                            feature_index,
                            fallback_class_index,
                            branches,
                            missing_child,
                            ..
                        } => {
                            // Missing value: follow the trained missing branch
                            // if present, otherwise answer with the fallback
                            // (majority) class.
                            if table.is_missing(*feature_index, row_idx) {
                                if let Some(child_index) = missing_child {
                                    node_index = *child_index;
                                } else {
                                    return self.class_labels[*fallback_class_index];
                                }
                                continue;
                            }
                            let bin = table.binned_value(*feature_index, row_idx);
                            // A bin with no branch was unseen at training time;
                            // use the fallback class.
                            if let Some((_, child_index)) =
                                branches.iter().find(|(branch_bin, _)| *branch_bin == bin)
                            {
                                node_index = *child_index;
                            } else {
                                return self.class_labels[*fallback_class_index];
                            }
                        }
                        TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            missing_direction,
                            left_child,
                            right_child,
                            class_counts,
                            ..
                        } => {
                            if table.is_missing(*feature_index, row_idx) {
                                match missing_direction {
                                    MissingBranchDirection::Left => {
                                        node_index = *left_child;
                                    }
                                    MissingBranchDirection::Right => {
                                        node_index = *right_child;
                                    }
                                    // `Node` stops here: answer with this
                                    // node's majority class.
                                    MissingBranchDirection::Node => {
                                        return self.class_labels
                                            [majority_class_from_counts(class_counts)];
                                    }
                                }
                                continue;
                            }
                            let bin = table.binned_value(*feature_index, row_idx);
                            // Bins at or below the threshold go left.
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_indices,
                ..
            } => {
                // Each level contributes one left/right bit; bits accumulate
                // MSB-first into the leaf index.
                // NOTE(review): unlike the standard path, no `is_missing`
                // check is made here — confirm missing values are already
                // routed through binning for oblivious trees.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                self.class_labels[leaf_class_indices[leaf_index]]
            }
        }
    }

    /// Same traversal as `predict_row`, but returns normalized per-class
    /// probabilities derived from the reached node's class counts.
    fn predict_proba_row(&self, table: &dyn TableAccess, row_idx: usize) -> Vec<f64> {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;
                loop {
                    match &nodes[node_index] {
                        TreeNode::Leaf { class_counts, .. } => {
                            return normalized_class_probabilities(class_counts);
                        }
                        TreeNode::MultiwaySplit {
                            feature_index,
                            branches,
                            missing_child,
                            class_counts,
                            ..
                        } => {
                            // Missing value: follow the trained missing branch
                            // if present, otherwise answer with this node's
                            // class distribution.
                            if table.is_missing(*feature_index, row_idx) {
                                if let Some(child_index) = missing_child {
                                    node_index = *child_index;
                                } else {
                                    return normalized_class_probabilities(class_counts);
                                }
                                continue;
                            }
                            let bin = table.binned_value(*feature_index, row_idx);
                            if let Some((_, child_index)) =
                                branches.iter().find(|(branch_bin, _)| *branch_bin == bin)
                            {
                                node_index = *child_index;
                            } else {
                                return normalized_class_probabilities(class_counts);
                            }
                        }
                        TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            missing_direction,
                            left_child,
                            right_child,
                            class_counts,
                            ..
                        } => {
                            if table.is_missing(*feature_index, row_idx) {
                                match missing_direction {
                                    MissingBranchDirection::Left => {
                                        node_index = *left_child;
                                    }
                                    MissingBranchDirection::Right => {
                                        node_index = *right_child;
                                    }
                                    MissingBranchDirection::Node => {
                                        return normalized_class_probabilities(class_counts);
                                    }
                                }
                                continue;
                            }
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_counts,
                ..
            } => {
                // Same MSB-first leaf indexing as `predict_row`.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                normalized_class_probabilities(&leaf_class_counts[leaf_index])
            }
        }
    }

    /// Sorted vocabulary of distinct target values.
    pub(crate) fn class_labels(&self) -> &[f64] {
        &self.class_labels
    }

    /// Physical tree layout (node arena or oblivious levels).
    pub(crate) fn structure(&self) -> &TreeStructure {
        &self.structure
    }

    /// Number of features in the training table.
    pub(crate) fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-feature preprocessing captured at training time.
    pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// Builds the IR training-metadata record for this single tree. Fields
    /// that only apply to ensembles/boosting are left `None`.
    pub(crate) fn training_metadata(&self) -> TrainingMetadata {
        TrainingMetadata {
            algorithm: "dt".to_string(),
            task: "classification".to_string(),
            tree_type: tree_type_name(match self.algorithm {
                DecisionTreeAlgorithm::Id3 => crate::TreeType::Id3,
                DecisionTreeAlgorithm::C45 => crate::TreeType::C45,
                DecisionTreeAlgorithm::Cart => crate::TreeType::Cart,
                DecisionTreeAlgorithm::Randomized => crate::TreeType::Randomized,
                DecisionTreeAlgorithm::Oblivious => crate::TreeType::Oblivious,
            })
            .to_string(),
            criterion: criterion_name(self.criterion).to_string(),
            canaries: self.training_canaries,
            compute_oob: false,
            max_depth: Some(self.options.max_depth),
            min_samples_split: Some(self.options.min_samples_split),
            min_samples_leaf: Some(self.options.min_samples_leaf),
            n_trees: None,
            max_features: self.options.max_features,
            seed: None,
            oob_score: None,
            class_labels: Some(self.class_labels.clone()),
            learning_rate: None,
            bootstrap: None,
            top_gradient_fraction: None,
            other_gradient_fraction: None,
        }
    }

    /// Converts the trained structure into its IR `TreeDefinition`: a node
    /// tree for standard structures, oblivious levels otherwise.
    pub(crate) fn to_ir_tree(&self) -> TreeDefinition {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                // Depths are recomputed from the arena since nodes don't
                // store them.
                let depths = standard_node_depths(nodes, *root);
                TreeDefinition::NodeTree {
                    tree_id: 0,
                    weight: 1.0,
                    root_node_id: *root,
                    nodes: nodes
                        .iter()
                        .enumerate()
                        .map(|(node_id, node)| match node {
                            TreeNode::Leaf {
                                class_index,
                                sample_count,
                                class_counts,
                            } => NodeTreeNode::Leaf {
                                node_id,
                                depth: depths[node_id],
                                leaf: self.class_leaf(*class_index),
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: None,
                                    gain: None,
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                            TreeNode::BinarySplit {
                                feature_index,
                                threshold_bin,
                                missing_direction,
                                left_child,
                                right_child,
                                sample_count,
                                impurity,
                                gain,
                                class_counts,
                            } => NodeTreeNode::BinaryBranch {
                                node_id,
                                depth: depths[node_id],
                                split: binary_split_ir(
                                    *feature_index,
                                    *threshold_bin,
                                    *missing_direction,
                                    &self.feature_preprocessing,
                                ),
                                children: BinaryChildren {
                                    left: *left_child,
                                    right: *right_child,
                                },
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                            TreeNode::MultiwaySplit {
                                feature_index,
                                fallback_class_index,
                                branches,
                                // The IR carries only the unmatched-bin
                                // fallback leaf; the missing child is not
                                // exported here.
                                missing_child: _,
                                sample_count,
                                impurity,
                                gain,
                                class_counts,
                            } => NodeTreeNode::MultiwayBranch {
                                node_id,
                                depth: depths[node_id],
                                split: MultiwaySplit {
                                    split_type: "binned_value_multiway".to_string(),
                                    feature_index: *feature_index,
                                    feature_name: feature_name(*feature_index),
                                    comparison_dtype: "uint16".to_string(),
                                },
                                branches: branches
                                    .iter()
                                    .map(|(bin, child)| MultiwayBranch {
                                        bin: *bin,
                                        child: *child,
                                    })
                                    .collect(),
                                unmatched_leaf: self.class_leaf(*fallback_class_index),
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                        })
                        .collect(),
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_indices,
                leaf_sample_counts,
                leaf_class_counts,
            } => TreeDefinition::ObliviousLevels {
                tree_id: 0,
                weight: 1.0,
                depth: splits.len(),
                levels: splits
                    .iter()
                    .enumerate()
                    .map(|(level, split)| ObliviousLevel {
                        level,
                        split: oblivious_split_ir(
                            split.feature_index,
                            split.threshold_bin,
                            &self.feature_preprocessing,
                        ),
                        stats: NodeStats {
                            sample_count: split.sample_count,
                            impurity: Some(split.impurity),
                            gain: Some(split.gain),
                            class_counts: None,
                            variance: None,
                        },
                    })
                    .collect(),
                // Documents the MSB-first leaf addressing used by the
                // prediction fold above.
                leaf_indexing: LeafIndexing {
                    bit_order: "msb_first".to_string(),
                    index_formula: "sum(bit[level] << (depth - 1 - level))".to_string(),
                },
                leaves: leaf_class_indices
                    .iter()
                    .enumerate()
                    .map(|(leaf_index, class_index)| IndexedLeaf {
                        leaf_index,
                        leaf: self.class_leaf(*class_index),
                        stats: NodeStats {
                            sample_count: leaf_sample_counts[leaf_index],
                            impurity: None,
                            gain: None,
                            class_counts: Some(leaf_class_counts[leaf_index].clone()),
                            variance: None,
                        },
                    })
                    .collect(),
            },
        }
    }

    /// IR leaf payload pairing a class index with its label value.
    fn class_leaf(&self, class_index: usize) -> LeafPayload {
        LeafPayload::ClassIndex {
            class_index,
            class_value: self.class_labels[class_index],
        }
    }

    /// Reassembles a classifier from IR components; the inverse of
    /// `to_ir_tree` plus `training_metadata`.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn from_ir_parts(
        algorithm: DecisionTreeAlgorithm,
        criterion: Criterion,
        class_labels: Vec<f64>,
        structure: TreeStructure,
        options: DecisionTreeOptions,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
        training_canaries: usize,
    ) -> Self {
        Self {
            algorithm,
            criterion,
            class_labels,
            structure,
            options,
            num_features,
            feature_preprocessing,
            training_canaries,
        }
    }
}
874
/// Entry point for CART/randomized subtree construction: builds the node for
/// `rows` (and recursively its children) into `nodes`. Passing `None` makes
/// the histogram-aware builder compute fresh histograms for this node.
fn build_binary_node_in_place(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &mut [usize],
    depth: usize,
) -> usize {
    build_binary_node_in_place_with_hist(context, nodes, rows, depth, None)
}
883
/// Recursively builds a binary (CART/randomized) subtree over `rows`,
/// appending nodes to the `nodes` arena and returning the subtree root index.
///
/// `histograms` holds per-feature class histograms covering exactly `rows`;
/// `None` builds them from scratch. Children reuse the parent's histograms
/// via the subtraction trick below. `rows` is reordered in place to form the
/// left/right partitions.
fn build_binary_node_in_place_with_hist(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &mut [usize],
    depth: usize,
    histograms: Option<Vec<ClassificationFeatureHistogram>>,
) -> usize {
    let majority_class_index =
        majority_class(rows, context.class_indices, context.class_labels.len());
    let current_class_counts =
        class_counts(rows, context.class_indices, context.class_labels.len());

    // Stopping conditions: no rows, depth cap reached, too few rows to
    // split, or only one class left.
    if rows.is_empty()
        || depth >= context.options.max_depth
        || rows.len() < context.options.min_samples_split
        || is_pure(rows, context.class_indices)
    {
        return push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        );
    }

    let scoring = SplitScoringContext {
        table: context.table,
        class_indices: context.class_indices,
        num_classes: context.class_labels.len(),
        criterion: context.criterion,
        min_samples_leaf: context.options.min_samples_leaf,
        missing_value_strategies: &context.options.missing_value_strategies,
    };
    let histograms = histograms.unwrap_or_else(|| {
        build_classification_node_histograms(
            context.table,
            context.class_indices,
            rows,
            context.class_labels.len(),
        )
    });
    // Feature subsampling; the seed mixes the user seed, depth, and the row
    // set so each node draws a different candidate set.
    let feature_indices = candidate_feature_indices(
        context.table.binned_feature_count(),
        context.options.max_features,
        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
    );
    // Score each candidate feature (optionally in parallel via rayon) and
    // keep the highest-scoring split.
    let best_split = if context.parallelism.enabled() {
        feature_indices
            .into_par_iter()
            .filter_map(|feature_index| {
                score_binary_split_choice_from_hist(
                    &scoring,
                    &histograms[feature_index],
                    feature_index,
                    rows,
                    &current_class_counts,
                    context.algorithm,
                )
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    } else {
        feature_indices
            .into_iter()
            .filter_map(|feature_index| {
                score_binary_split_choice_from_hist(
                    &scoring,
                    &histograms[feature_index],
                    feature_index,
                    rows,
                    &current_class_counts,
                    context.algorithm,
                )
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    };

    match best_split {
        // If the winning split lands on a canary (noise-check) feature,
        // refuse to split and emit a leaf instead.
        Some(best_split)
            if context
                .table
                .is_canary_binned_feature(best_split.feature_index) =>
        {
            push_leaf(
                nodes,
                majority_class_index,
                rows.len(),
                current_class_counts,
            )
        }
        Some(best_split) if best_split.score > 0.0 => {
            let impurity =
                classification_impurity(&current_class_counts, rows.len(), context.criterion);
            // Reorders `rows` so the left partition becomes the prefix.
            let left_count = partition_rows_for_binary_split(
                context.table,
                best_split.feature_index,
                best_split.threshold_bin,
                best_split.missing_direction,
                rows,
            );
            let (left_rows, right_rows) = rows.split_at_mut(left_count);
            // Histogram subtraction trick: build histograms for the smaller
            // partition only and derive the larger one as parent - smaller.
            let (left_histograms, right_histograms) = if left_rows.len() <= right_rows.len() {
                let left_histograms = build_classification_node_histograms(
                    context.table,
                    context.class_indices,
                    left_rows,
                    context.class_labels.len(),
                );
                let right_histograms =
                    subtract_classification_node_histograms(&histograms, &left_histograms);
                (left_histograms, right_histograms)
            } else {
                let right_histograms = build_classification_node_histograms(
                    context.table,
                    context.class_indices,
                    right_rows,
                    context.class_labels.len(),
                );
                let left_histograms =
                    subtract_classification_node_histograms(&histograms, &right_histograms);
                (left_histograms, right_histograms)
            };
            let left_child = build_binary_node_in_place_with_hist(
                context,
                nodes,
                left_rows,
                depth + 1,
                Some(left_histograms),
            );
            let right_child = build_binary_node_in_place_with_hist(
                context,
                nodes,
                right_rows,
                depth + 1,
                Some(right_histograms),
            );

            push_node(
                nodes,
                TreeNode::BinarySplit {
                    feature_index: best_split.feature_index,
                    threshold_bin: best_split.threshold_bin,
                    missing_direction: best_split.missing_direction,
                    left_child,
                    right_child,
                    sample_count: rows.len(),
                    impurity,
                    gain: best_split.score,
                    class_counts: current_class_counts,
                },
            )
        }
        // No split, or no split with positive gain: emit a leaf.
        _ => push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        ),
    }
}
1043
/// Recursively builds a multiway (ID3/C4.5) subtree over `rows`, appending
/// nodes to the `nodes` arena and returning the subtree root index. `rows` is
/// reordered in place into one contiguous range per branch.
fn build_multiway_node_in_place(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &mut [usize],
    depth: usize,
) -> usize {
    let majority_class_index =
        majority_class(rows, context.class_indices, context.class_labels.len());
    let current_class_counts =
        class_counts(rows, context.class_indices, context.class_labels.len());

    // Stopping conditions mirror the binary builder.
    if rows.is_empty()
        || depth >= context.options.max_depth
        || rows.len() < context.options.min_samples_split
        || is_pure(rows, context.class_indices)
    {
        return push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        );
    }

    // ID3 maximizes information gain; C4.5 maximizes gain ratio.
    let metric = match context.algorithm {
        DecisionTreeAlgorithm::Id3 => MultiwayMetric::InformationGain,
        DecisionTreeAlgorithm::C45 => MultiwayMetric::GainRatio,
        _ => unreachable!("multiway builder only supports id3/c45"),
    };
    let scoring = SplitScoringContext {
        table: context.table,
        class_indices: context.class_indices,
        num_classes: context.class_labels.len(),
        criterion: context.criterion,
        min_samples_leaf: context.options.min_samples_leaf,
        missing_value_strategies: &context.options.missing_value_strategies,
    };
    // Feature subsampling; same seed salt as the binary builder.
    let feature_indices = candidate_feature_indices(
        context.table.binned_feature_count(),
        context.options.max_features,
        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
    );
    let best_split = if context.parallelism.enabled() {
        feature_indices
            .into_par_iter()
            .filter_map(|feature_index| {
                score_multiway_split_choice(&scoring, feature_index, rows, metric)
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    } else {
        feature_indices
            .into_iter()
            .filter_map(|feature_index| {
                score_multiway_split_choice(&scoring, feature_index, rows, metric)
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    };

    match best_split {
        // A winning split on a canary (noise-check) feature becomes a leaf.
        Some(best_split)
            if context
                .table
                .is_canary_binned_feature(best_split.feature_index) =>
        {
            push_leaf(
                nodes,
                majority_class_index,
                rows.len(),
                current_class_counts,
            )
        }
        Some(best_split) if best_split.score > 0.0 => {
            let impurity =
                classification_impurity(&current_class_counts, rows.len(), context.criterion);
            // Reorders `rows` into one (bin, start, end) range per branch.
            let branch_ranges = partition_rows_for_multiway_split(
                context.table,
                best_split.feature_index,
                &best_split.branch_bins,
                best_split.missing_branch_bin,
                rows,
            );
            let mut branch_nodes = Vec::with_capacity(branch_ranges.len());
            let mut missing_child = None;
            for (bin, start, end) in branch_ranges {
                let child =
                    build_multiway_node_in_place(context, nodes, &mut rows[start..end], depth + 1);
                // Remember which child doubles as the missing-value branch.
                if best_split.missing_branch_bin == Some(bin) {
                    missing_child = Some(child);
                }
                branch_nodes.push((bin, child));
            }

            push_node(
                nodes,
                TreeNode::MultiwaySplit {
                    feature_index: best_split.feature_index,
                    fallback_class_index: majority_class_index,
                    branches: branch_nodes,
                    missing_child,
                    sample_count: rows.len(),
                    impurity,
                    gain: best_split.score,
                    class_counts: current_class_counts,
                },
            )
        }
        // No split, or no split with positive gain: emit a leaf.
        _ => push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        ),
    }
}
1158
/// Read-only state threaded through the recursive node builders so their
/// signatures stay small.
struct BuildContext<'a> {
    table: &'a dyn TableAccess,
    // Per-row class index (parallel to the table's rows).
    class_indices: &'a [usize],
    class_labels: &'a [f64],
    algorithm: DecisionTreeAlgorithm,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
}
1168
1169fn encode_class_labels(
1170 train_set: &dyn TableAccess,
1171) -> Result<(Vec<f64>, Vec<usize>), DecisionTreeError> {
1172 let targets: Vec<f64> = (0..train_set.n_rows())
1173 .map(|row_idx| {
1174 let value = train_set.target_value(row_idx);
1175 if value.is_finite() {
1176 Ok(value)
1177 } else {
1178 Err(DecisionTreeError::InvalidTargetValue {
1179 row: row_idx,
1180 value,
1181 })
1182 }
1183 })
1184 .collect::<Result<_, _>>()?;
1185
1186 let class_labels = targets
1187 .iter()
1188 .copied()
1189 .fold(Vec::<f64>::new(), |mut labels, value| {
1190 if labels
1191 .binary_search_by(|candidate| candidate.total_cmp(&value))
1192 .is_err()
1193 {
1194 labels.push(value);
1195 labels.sort_by(|left, right| left.total_cmp(right));
1196 }
1197 labels
1198 });
1199
1200 let class_indices = targets
1201 .iter()
1202 .map(|value| {
1203 class_labels
1204 .binary_search_by(|candidate| candidate.total_cmp(value))
1205 .expect("target value must exist in class label vocabulary")
1206 })
1207 .collect();
1208
1209 Ok((class_labels, class_indices))
1210}
1211
/// Histogram of class occurrences over the rows listed in `rows`;
/// `class_indices` maps a row index to its class.
fn class_counts(rows: &[usize], class_indices: &[usize], num_classes: usize) -> Vec<usize> {
    let mut counts = vec![0usize; num_classes];
    for &row_idx in rows {
        counts[class_indices[row_idx]] += 1;
    }
    counts
}
1219
/// Majority class among `rows` (tie-breaking as in
/// `majority_class_from_counts`).
fn majority_class(rows: &[usize], class_indices: &[usize], num_classes: usize) -> usize {
    majority_class_from_counts(&class_counts(rows, class_indices, num_classes))
}
1223
/// Index of the largest count. Ties resolve to the smallest class index, and
/// an empty slice yields class 0.
fn majority_class_from_counts(counts: &[usize]) -> usize {
    if counts.is_empty() {
        return 0;
    }
    let mut best_index = 0;
    for (index, &count) in counts.iter().enumerate().skip(1) {
        // Strict `>` keeps the earliest index on ties.
        if count > counts[best_index] {
            best_index = index;
        }
    }
    best_index
}
1233
/// True when `rows` is empty or every listed row carries the same class.
fn is_pure(rows: &[usize], class_indices: &[usize]) -> bool {
    let Some((&first_row, rest)) = rows.split_first() else {
        return true;
    };
    let first_class = class_indices[first_row];
    rest.iter().all(|&row_idx| class_indices[row_idx] == first_class)
}
1240
/// Shannon entropy (in bits) of a count histogram over `total` samples;
/// zero counts contribute nothing.
fn entropy(counts: &[usize], total: usize) -> f64 {
    let total = total as f64;
    let mut bits = 0.0;
    for &count in counts {
        if count > 0 {
            let probability = count as f64 / total;
            bits -= probability * probability.log2();
        }
    }
    bits
}
1252
/// Gini impurity of a count histogram over `total` samples.
fn gini(counts: &[usize], total: usize) -> f64 {
    let total = total as f64;
    let mut sum_of_squares = 0.0;
    for &count in counts {
        let probability = count as f64 / total;
        sum_of_squares += probability * probability;
    }
    1.0 - sum_of_squares
}
1263
/// Dispatches to `entropy` or `gini`; the classification trainers never pass
/// any other criterion.
fn classification_impurity(counts: &[usize], total: usize, criterion: Criterion) -> f64 {
    match criterion {
        Criterion::Entropy => entropy(counts, total),
        Criterion::Gini => gini(counts, total),
        _ => unreachable!("classification impurity only supports gini or entropy"),
    }
}
1271
/// Appends a leaf node to the arena and returns its index.
fn push_leaf(
    nodes: &mut Vec<TreeNode>,
    class_index: usize,
    sample_count: usize,
    class_counts: Vec<usize>,
) -> usize {
    push_node(
        nodes,
        TreeNode::Leaf {
            class_index,
            sample_count,
            class_counts,
        },
    )
}
1287
1288fn push_node(nodes: &mut Vec<TreeNode>, node: TreeNode) -> usize {
1289 nodes.push(node);
1290 nodes.len() - 1
1291}
1292
1293#[cfg(test)]
1294mod tests;