1use crate::ir::{
17 BinaryChildren, BinarySplit, IndexedLeaf, LeafIndexing, LeafPayload, MultiwayBranch,
18 MultiwaySplit, NodeStats, NodeTreeNode, ObliviousLevel, ObliviousSplit as IrObliviousSplit,
19 TrainingMetadata, TreeDefinition, criterion_name, feature_name, threshold_upper_bound,
20 tree_type_name,
21};
22use crate::tree::shared::{
23 candidate_feature_indices, choose_random_threshold, node_seed, partition_rows_for_binary_split,
24};
25use crate::{Criterion, FeaturePreprocessing, Parallelism, capture_feature_preprocessing};
26use forestfire_data::TableAccess;
27use rayon::prelude::*;
28use std::collections::BTreeMap;
29use std::error::Error;
30use std::fmt::{Display, Formatter};
31
32mod histogram;
33mod ir_support;
34mod oblivious;
35mod partitioning;
36mod split_scoring;
37
38use histogram::{
39 ClassificationFeatureHistogram, build_classification_node_histograms,
40 subtract_classification_node_histograms,
41};
42use ir_support::{
43 binary_split_ir, normalized_class_probabilities, oblivious_split_ir, standard_node_depths,
44};
45use oblivious::train_oblivious_structure;
46use partitioning::partition_rows_for_multiway_split;
47use split_scoring::{
48 MultiwayMetric, SplitScoringContext, score_binary_split_choice_from_hist,
49 score_multiway_split_choice,
50};
51
/// The tree-construction algorithm used when fitting a [`DecisionTreeClassifier`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecisionTreeAlgorithm {
    /// Multiway splits scored by information gain (see the multiway builder).
    Id3,
    /// Multiway splits scored by gain ratio.
    C45,
    /// Binary threshold splits on binned features.
    Cart,
    /// Binary splits via the same builder as CART; presumably uses
    /// `choose_random_threshold` in split scoring — confirm in `split_scoring`.
    Randomized,
    /// One shared (feature, threshold) split per depth level; leaves are
    /// addressed by the bit pattern of per-level comparisons.
    Oblivious,
}
60
/// Hyperparameters controlling tree growth.
#[derive(Debug, Clone, Copy)]
pub struct DecisionTreeOptions {
    /// Maximum depth; nodes at this depth become leaves.
    pub max_depth: usize,
    /// Minimum number of rows a node needs before a split is attempted.
    pub min_samples_split: usize,
    /// Minimum rows per child; passed to split scoring as a constraint.
    pub min_samples_leaf: usize,
    /// Optional cap on candidate features per node (`None` = all features).
    pub max_features: Option<usize>,
    /// Seed feeding the deterministic per-node feature subsampling.
    pub random_seed: u64,
}
74
75impl Default for DecisionTreeOptions {
76 fn default() -> Self {
77 Self {
78 max_depth: 8,
79 min_samples_split: 2,
80 min_samples_leaf: 1,
81 max_features: None,
82 random_seed: 0,
83 }
84 }
85}
86
/// Errors surfaced while fitting a decision tree.
#[derive(Debug)]
pub enum DecisionTreeError {
    /// The training table contains zero rows.
    EmptyTarget,
    /// A target value was not finite (NaN or infinite) at the given row.
    InvalidTargetValue { row: usize, value: f64 },
}
92
93impl Display for DecisionTreeError {
94 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
95 match self {
96 DecisionTreeError::EmptyTarget => write!(f, "Cannot train on an empty target vector."),
97 DecisionTreeError::InvalidTargetValue { row, value } => write!(
98 f,
99 "Classification targets must be finite values. Found {} at row {}.",
100 value, row
101 ),
102 }
103 }
104}
105
106impl Error for DecisionTreeError {}
107
/// A trained decision-tree classifier.
///
/// Holds the learned tree structure plus everything needed for inference
/// (class vocabulary, captured feature preprocessing) and for exporting the
/// model to IR/metadata.
#[derive(Debug, Clone)]
pub struct DecisionTreeClassifier {
    /// Algorithm the tree was trained with.
    algorithm: DecisionTreeAlgorithm,
    /// Impurity criterion used for split scoring.
    criterion: Criterion,
    /// Sorted vocabulary of distinct target values; nodes store indices into it.
    class_labels: Vec<f64>,
    /// Standard node tree or oblivious level splits.
    structure: TreeStructure,
    /// Hyperparameters used during training (exported in metadata).
    options: DecisionTreeOptions,
    /// Number of features in the training table.
    num_features: usize,
    /// Per-feature preprocessing captured at train time, used by IR export.
    feature_preprocessing: Vec<FeaturePreprocessing>,
    /// Canary count reported by the training table (exported in metadata).
    training_canaries: usize,
}
120
/// The learned tree topology.
#[derive(Debug, Clone)]
pub(crate) enum TreeStructure {
    /// Conventional node tree stored as a flat arena; `root` indexes `nodes`.
    Standard {
        nodes: Vec<TreeNode>,
        root: usize,
    },
    /// Oblivious tree: one shared split per level; the per-level comparison
    /// bits (MSB first) index directly into the parallel `leaf_*` vectors.
    Oblivious {
        splits: Vec<ObliviousSplit>,
        leaf_class_indices: Vec<usize>,
        leaf_sample_counts: Vec<usize>,
        leaf_class_counts: Vec<Vec<usize>>,
    },
}
134
/// One level of an oblivious tree: a single (feature, threshold) comparison
/// shared by every node at that depth.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ObliviousSplit {
    /// Binned feature the level compares on.
    pub(crate) feature_index: usize,
    /// Rows with bin > threshold go right (contribute a 1 bit).
    pub(crate) threshold_bin: u16,
    /// Sample count recorded for this level's stats (populated by `oblivious`).
    pub(crate) sample_count: usize,
    /// Impurity recorded for this level's stats.
    pub(crate) impurity: f64,
    /// Score improvement recorded for this level's stats.
    pub(crate) gain: f64,
}
143
/// A node in a standard (non-oblivious) tree arena; child fields are indices
/// into the arena's `nodes` vector.
#[derive(Debug, Clone)]
pub(crate) enum TreeNode {
    /// Terminal node predicting `class_index`.
    Leaf {
        class_index: usize,
        sample_count: usize,
        class_counts: Vec<usize>,
    },
    /// ID3/C4.5-style split with one child per observed bin value; rows whose
    /// bin has no branch fall back to `fallback_class_index`.
    MultiwaySplit {
        feature_index: usize,
        fallback_class_index: usize,
        // (bin value, child node index) pairs.
        branches: Vec<(u16, usize)>,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        class_counts: Vec<usize>,
    },
    /// CART-style split: bin <= threshold goes to `left_child`, else `right_child`.
    BinarySplit {
        feature_index: usize,
        threshold_bin: u16,
        left_child: usize,
        right_child: usize,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        class_counts: Vec<usize>,
    },
}
171
/// Best-scoring binary split candidate returned by split scoring for one feature.
#[derive(Debug, Clone, Copy)]
struct BinarySplitChoice {
    feature_index: usize,
    /// Higher is better; compared with `total_cmp` across features.
    score: f64,
    threshold_bin: u16,
}
178
/// Best-scoring multiway split candidate returned by split scoring for one feature.
#[derive(Debug, Clone)]
struct MultiwaySplitChoice {
    feature_index: usize,
    /// Higher is better; compared with `total_cmp` across features.
    score: f64,
    /// Distinct bin values that receive their own branch.
    branch_bins: Vec<u16>,
}
185
/// Trains an ID3 tree (multiway splits) with the entropy criterion,
/// sequential execution, and default options.
pub fn train_id3(train_set: &dyn TableAccess) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion(train_set, Criterion::Entropy)
}

/// Same as [`train_id3`] but with an explicit impurity criterion.
pub fn train_id3_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal variant that also takes a parallelism policy.
pub(crate) fn train_id3_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully parameterized ID3 entry point; delegates to [`train_classifier`].
pub(crate) fn train_id3_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Id3,
        criterion,
        parallelism,
        options,
    )
}
224
/// Trains a C4.5 tree (multiway splits) with the entropy criterion,
/// sequential execution, and default options.
pub fn train_c45(train_set: &dyn TableAccess) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion(train_set, Criterion::Entropy)
}

/// Same as [`train_c45`] but with an explicit impurity criterion.
pub fn train_c45_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal variant that also takes a parallelism policy.
pub(crate) fn train_c45_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully parameterized C4.5 entry point; delegates to [`train_classifier`].
pub(crate) fn train_c45_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::C45,
        criterion,
        parallelism,
        options,
    )
}
263
/// Trains a CART tree (binary splits) with the Gini criterion,
/// sequential execution, and default options.
pub fn train_cart(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion(train_set, Criterion::Gini)
}

/// Same as [`train_cart`] but with an explicit impurity criterion.
pub fn train_cart_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal variant that also takes a parallelism policy.
pub(crate) fn train_cart_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully parameterized CART entry point; delegates to [`train_classifier`].
pub(crate) fn train_cart_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Cart,
        criterion,
        parallelism,
        options,
    )
}
304
/// Trains an oblivious tree (one shared split per level) with the Gini
/// criterion, sequential execution, and default options.
pub fn train_oblivious(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion(train_set, Criterion::Gini)
}

/// Same as [`train_oblivious`] but with an explicit impurity criterion.
pub fn train_oblivious_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal variant that also takes a parallelism policy.
pub(crate) fn train_oblivious_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully parameterized oblivious-tree entry point; delegates to [`train_classifier`].
pub(crate) fn train_oblivious_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Oblivious,
        criterion,
        parallelism,
        options,
    )
}
345
/// Trains a randomized tree (binary splits, shared builder with CART) with the
/// Gini criterion, sequential execution, and default options.
pub fn train_randomized(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion(train_set, Criterion::Gini)
}

/// Same as [`train_randomized`] but with an explicit impurity criterion.
pub fn train_randomized_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal variant that also takes a parallelism policy.
pub(crate) fn train_randomized_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully parameterized randomized-tree entry point; delegates to [`train_classifier`].
pub(crate) fn train_randomized_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Randomized,
        criterion,
        parallelism,
        options,
    )
}
386
387fn train_classifier(
388 train_set: &dyn TableAccess,
389 algorithm: DecisionTreeAlgorithm,
390 criterion: Criterion,
391 parallelism: Parallelism,
392 options: DecisionTreeOptions,
393) -> Result<DecisionTreeClassifier, DecisionTreeError> {
394 if train_set.n_rows() == 0 {
395 return Err(DecisionTreeError::EmptyTarget);
396 }
397
398 let (class_labels, class_indices) = encode_class_labels(train_set)?;
399 let structure = match algorithm {
400 DecisionTreeAlgorithm::Oblivious => train_oblivious_structure(
401 train_set,
402 &class_indices,
403 &class_labels,
404 criterion,
405 parallelism,
406 options,
407 ),
408 DecisionTreeAlgorithm::Cart | DecisionTreeAlgorithm::Randomized => {
409 let mut nodes = Vec::new();
410 let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
411 let context = BuildContext {
412 table: train_set,
413 class_indices: &class_indices,
414 class_labels: &class_labels,
415 algorithm,
416 criterion,
417 parallelism,
418 options,
419 };
420 let root = build_binary_node_in_place(&context, &mut nodes, &mut all_rows, 0);
421 TreeStructure::Standard { nodes, root }
422 }
423 DecisionTreeAlgorithm::Id3 | DecisionTreeAlgorithm::C45 => {
424 let mut nodes = Vec::new();
425 let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
426 let context = BuildContext {
427 table: train_set,
428 class_indices: &class_indices,
429 class_labels: &class_labels,
430 algorithm,
431 criterion,
432 parallelism,
433 options,
434 };
435 let root = build_multiway_node_in_place(&context, &mut nodes, &mut all_rows, 0);
436 TreeStructure::Standard { nodes, root }
437 }
438 };
439
440 Ok(DecisionTreeClassifier {
441 algorithm,
442 criterion,
443 class_labels,
444 structure,
445 options,
446 num_features: train_set.n_features(),
447 feature_preprocessing: capture_feature_preprocessing(train_set),
448 training_canaries: train_set.canaries(),
449 })
450}
451
impl DecisionTreeClassifier {
    /// The algorithm this tree was trained with.
    pub fn algorithm(&self) -> DecisionTreeAlgorithm {
        self.algorithm
    }

    /// The impurity criterion used to score splits during training.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }

    /// Predicts a class label for every row of `table`.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        (0..table.n_rows())
            .map(|row_idx| self.predict_row(table, row_idx))
            .collect()
    }

    /// Predicts normalized per-class probabilities for every row of `table`.
    pub fn predict_proba_table(&self, table: &dyn TableAccess) -> Vec<Vec<f64>> {
        (0..table.n_rows())
            .map(|row_idx| self.predict_proba_row(table, row_idx))
            .collect()
    }

    /// Routes one row from the root to a leaf and returns its class label.
    fn predict_row(&self, table: &dyn TableAccess, row_idx: usize) -> f64 {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;
                loop {
                    match &nodes[node_index] {
                        TreeNode::Leaf { class_index, .. } => {
                            return self.class_labels[*class_index];
                        }
                        TreeNode::MultiwaySplit {
                            feature_index,
                            fallback_class_index,
                            branches,
                            ..
                        } => {
                            let bin = table.binned_value(*feature_index, row_idx);
                            // A bin with no branch (unseen at this node during
                            // training) falls back to the stored majority class.
                            if let Some((_, child_index)) =
                                branches.iter().find(|(branch_bin, _)| *branch_bin == bin)
                            {
                                node_index = *child_index;
                            } else {
                                return self.class_labels[*fallback_class_index];
                            }
                        }
                        TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            left_child,
                            right_child,
                            ..
                        } => {
                            // Bins <= threshold descend left, otherwise right.
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_indices,
                ..
            } => {
                // Each level contributes one bit (MSB first); the accumulated
                // bit pattern indexes directly into the leaf arrays.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                self.class_labels[leaf_class_indices[leaf_index]]
            }
        }
    }

    /// Routes one row to a leaf and returns its normalized class distribution.
    /// Mirrors `predict_row` but resolves class counts instead of a label.
    fn predict_proba_row(&self, table: &dyn TableAccess, row_idx: usize) -> Vec<f64> {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;
                loop {
                    match &nodes[node_index] {
                        TreeNode::Leaf { class_counts, .. } => {
                            return normalized_class_probabilities(class_counts);
                        }
                        TreeNode::MultiwaySplit {
                            feature_index,
                            branches,
                            class_counts,
                            ..
                        } => {
                            let bin = table.binned_value(*feature_index, row_idx);
                            // Unmatched bins use this node's own class counts
                            // rather than a fallback child.
                            if let Some((_, child_index)) =
                                branches.iter().find(|(branch_bin, _)| *branch_bin == bin)
                            {
                                node_index = *child_index;
                            } else {
                                return normalized_class_probabilities(class_counts);
                            }
                        }
                        TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            left_child,
                            right_child,
                            ..
                        } => {
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_counts,
                ..
            } => {
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                normalized_class_probabilities(&leaf_class_counts[leaf_index])
            }
        }
    }

    /// Sorted vocabulary of distinct target values.
    pub(crate) fn class_labels(&self) -> &[f64] {
        &self.class_labels
    }

    /// The learned tree topology.
    pub(crate) fn structure(&self) -> &TreeStructure {
        &self.structure
    }

    /// Number of features in the training table.
    pub(crate) fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-feature preprocessing captured at train time.
    pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// Summarizes the training configuration for model export.
    pub(crate) fn training_metadata(&self) -> TrainingMetadata {
        TrainingMetadata {
            algorithm: "dt".to_string(),
            task: "classification".to_string(),
            tree_type: tree_type_name(match self.algorithm {
                DecisionTreeAlgorithm::Id3 => crate::TreeType::Id3,
                DecisionTreeAlgorithm::C45 => crate::TreeType::C45,
                DecisionTreeAlgorithm::Cart => crate::TreeType::Cart,
                DecisionTreeAlgorithm::Randomized => crate::TreeType::Randomized,
                DecisionTreeAlgorithm::Oblivious => crate::TreeType::Oblivious,
            })
            .to_string(),
            criterion: criterion_name(self.criterion).to_string(),
            canaries: self.training_canaries,
            compute_oob: false,
            max_depth: Some(self.options.max_depth),
            min_samples_split: Some(self.options.min_samples_split),
            min_samples_leaf: Some(self.options.min_samples_leaf),
            n_trees: None,
            max_features: self.options.max_features,
            // NOTE(review): options.random_seed is not exported here — confirm
            // `seed` is intentionally reserved for ensemble trainers.
            seed: None,
            oob_score: None,
            class_labels: Some(self.class_labels.clone()),
            learning_rate: None,
            bootstrap: None,
            top_gradient_fraction: None,
            other_gradient_fraction: None,
        }
    }

    /// Converts the in-memory structure into the serializable IR tree.
    pub(crate) fn to_ir_tree(&self) -> TreeDefinition {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                // Depths are recomputed from the arena since nodes don't store them.
                let depths = standard_node_depths(nodes, *root);
                TreeDefinition::NodeTree {
                    tree_id: 0,
                    weight: 1.0,
                    root_node_id: *root,
                    nodes: nodes
                        .iter()
                        .enumerate()
                        .map(|(node_id, node)| match node {
                            TreeNode::Leaf {
                                class_index,
                                sample_count,
                                class_counts,
                            } => NodeTreeNode::Leaf {
                                node_id,
                                depth: depths[node_id],
                                leaf: self.class_leaf(*class_index),
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: None,
                                    gain: None,
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                            TreeNode::BinarySplit {
                                feature_index,
                                threshold_bin,
                                left_child,
                                right_child,
                                sample_count,
                                impurity,
                                gain,
                                class_counts,
                            } => NodeTreeNode::BinaryBranch {
                                node_id,
                                depth: depths[node_id],
                                split: binary_split_ir(
                                    *feature_index,
                                    *threshold_bin,
                                    &self.feature_preprocessing,
                                ),
                                children: BinaryChildren {
                                    left: *left_child,
                                    right: *right_child,
                                },
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                            TreeNode::MultiwaySplit {
                                feature_index,
                                fallback_class_index,
                                branches,
                                sample_count,
                                impurity,
                                gain,
                                class_counts,
                            } => NodeTreeNode::MultiwayBranch {
                                node_id,
                                depth: depths[node_id],
                                split: MultiwaySplit {
                                    split_type: "binned_value_multiway".to_string(),
                                    feature_index: *feature_index,
                                    feature_name: feature_name(*feature_index),
                                    comparison_dtype: "uint16".to_string(),
                                },
                                branches: branches
                                    .iter()
                                    .map(|(bin, child)| MultiwayBranch {
                                        bin: *bin,
                                        child: *child,
                                    })
                                    .collect(),
                                unmatched_leaf: self.class_leaf(*fallback_class_index),
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                        })
                        .collect(),
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_indices,
                leaf_sample_counts,
                leaf_class_counts,
            } => TreeDefinition::ObliviousLevels {
                tree_id: 0,
                weight: 1.0,
                depth: splits.len(),
                levels: splits
                    .iter()
                    .enumerate()
                    .map(|(level, split)| ObliviousLevel {
                        level,
                        split: oblivious_split_ir(
                            split.feature_index,
                            split.threshold_bin,
                            &self.feature_preprocessing,
                        ),
                        stats: NodeStats {
                            sample_count: split.sample_count,
                            impurity: Some(split.impurity),
                            gain: Some(split.gain),
                            class_counts: None,
                            variance: None,
                        },
                    })
                    .collect(),
                // Documents the leaf addressing scheme used by predict_row:
                // one comparison bit per level, most significant first.
                leaf_indexing: LeafIndexing {
                    bit_order: "msb_first".to_string(),
                    index_formula: "sum(bit[level] << (depth - 1 - level))".to_string(),
                },
                leaves: leaf_class_indices
                    .iter()
                    .enumerate()
                    .map(|(leaf_index, class_index)| IndexedLeaf {
                        leaf_index,
                        leaf: self.class_leaf(*class_index),
                        stats: NodeStats {
                            sample_count: leaf_sample_counts[leaf_index],
                            impurity: None,
                            gain: None,
                            class_counts: Some(leaf_class_counts[leaf_index].clone()),
                            variance: None,
                        },
                    })
                    .collect(),
            },
        }
    }

    /// IR leaf payload pairing a class index with its label value.
    fn class_leaf(&self, class_index: usize) -> LeafPayload {
        LeafPayload::ClassIndex {
            class_index,
            class_value: self.class_labels[class_index],
        }
    }

    /// Reassembles a classifier from deserialized IR parts; performs no validation.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn from_ir_parts(
        algorithm: DecisionTreeAlgorithm,
        criterion: Criterion,
        class_labels: Vec<f64>,
        structure: TreeStructure,
        options: DecisionTreeOptions,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
        training_canaries: usize,
    ) -> Self {
        Self {
            algorithm,
            criterion,
            class_labels,
            structure,
            options,
            num_features,
            feature_preprocessing,
            training_canaries,
        }
    }
}
808
/// Entry point for growing a binary (CART/Randomized) subtree; per-feature
/// histograms are built lazily inside the recursive worker.
fn build_binary_node_in_place(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &mut [usize],
    depth: usize,
) -> usize {
    build_binary_node_in_place_with_hist(context, nodes, rows, depth, None)
}
817
/// Recursively grows a binary subtree over `rows` and returns the index of the
/// node pushed into `nodes`.
///
/// `rows` is reordered in place so each child recurses on a contiguous
/// sub-slice. `histograms` holds per-feature class histograms covering exactly
/// `rows` when the caller already has them; `None` rebuilds them here.
/// Children reuse the parent's histograms via subtraction: build the smaller
/// side from scratch, derive the larger side as parent minus smaller.
fn build_binary_node_in_place_with_hist(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &mut [usize],
    depth: usize,
    histograms: Option<Vec<ClassificationFeatureHistogram>>,
) -> usize {
    let majority_class_index =
        majority_class(rows, context.class_indices, context.class_labels.len());
    let current_class_counts =
        class_counts(rows, context.class_indices, context.class_labels.len());

    // Stopping conditions: no rows, depth or size limits reached, or the node
    // is already pure — emit a majority-class leaf.
    if rows.is_empty()
        || depth >= context.options.max_depth
        || rows.len() < context.options.min_samples_split
        || is_pure(rows, context.class_indices)
    {
        return push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        );
    }

    let scoring = SplitScoringContext {
        table: context.table,
        class_indices: context.class_indices,
        num_classes: context.class_labels.len(),
        criterion: context.criterion,
        min_samples_leaf: context.options.min_samples_leaf,
    };
    let histograms = histograms.unwrap_or_else(|| {
        build_classification_node_histograms(
            context.table,
            context.class_indices,
            rows,
            context.class_labels.len(),
        )
    });
    // Candidate features are chosen deterministically from a per-node seed so
    // `max_features` subsampling is reproducible.
    let feature_indices = candidate_feature_indices(
        context.table.binned_feature_count(),
        context.options.max_features,
        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
    );
    // Score every candidate feature and keep the best-scoring split; the
    // parallel and sequential paths are otherwise identical.
    let best_split = if context.parallelism.enabled() {
        feature_indices
            .into_par_iter()
            .filter_map(|feature_index| {
                score_binary_split_choice_from_hist(
                    &scoring,
                    &histograms[feature_index],
                    feature_index,
                    rows,
                    &current_class_counts,
                    context.algorithm,
                )
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    } else {
        feature_indices
            .into_iter()
            .filter_map(|feature_index| {
                score_binary_split_choice_from_hist(
                    &scoring,
                    &histograms[feature_index],
                    feature_index,
                    rows,
                    &current_class_counts,
                    context.algorithm,
                )
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    };

    match best_split {
        // A winning split on a canary feature is refused and the node becomes
        // a leaf — presumably canaries are synthetic features that must never
        // drive real splits; confirm against the `TableAccess` contract.
        Some(best_split)
            if context
                .table
                .is_canary_binned_feature(best_split.feature_index) =>
        {
            push_leaf(
                nodes,
                majority_class_index,
                rows.len(),
                current_class_counts,
            )
        }
        Some(best_split) if best_split.score > 0.0 => {
            let impurity =
                classification_impurity(&current_class_counts, rows.len(), context.criterion);
            // Partition `rows` in place: the first `left_count` entries are
            // the rows with bin <= threshold.
            let left_count = partition_rows_for_binary_split(
                context.table,
                best_split.feature_index,
                best_split.threshold_bin,
                rows,
            );
            let (left_rows, right_rows) = rows.split_at_mut(left_count);
            // Histogram-subtraction trick: only the smaller child is counted
            // from scratch; the larger child is parent minus smaller.
            let (left_histograms, right_histograms) = if left_rows.len() <= right_rows.len() {
                let left_histograms = build_classification_node_histograms(
                    context.table,
                    context.class_indices,
                    left_rows,
                    context.class_labels.len(),
                );
                let right_histograms =
                    subtract_classification_node_histograms(&histograms, &left_histograms);
                (left_histograms, right_histograms)
            } else {
                let right_histograms = build_classification_node_histograms(
                    context.table,
                    context.class_indices,
                    right_rows,
                    context.class_labels.len(),
                );
                let left_histograms =
                    subtract_classification_node_histograms(&histograms, &right_histograms);
                (left_histograms, right_histograms)
            };
            let left_child = build_binary_node_in_place_with_hist(
                context,
                nodes,
                left_rows,
                depth + 1,
                Some(left_histograms),
            );
            let right_child = build_binary_node_in_place_with_hist(
                context,
                nodes,
                right_rows,
                depth + 1,
                Some(right_histograms),
            );

            push_node(
                nodes,
                TreeNode::BinarySplit {
                    feature_index: best_split.feature_index,
                    threshold_bin: best_split.threshold_bin,
                    left_child,
                    right_child,
                    sample_count: rows.len(),
                    impurity,
                    gain: best_split.score,
                    class_counts: current_class_counts,
                },
            )
        }
        // No candidate improved on the parent: emit a leaf.
        _ => push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        ),
    }
}
974
/// Recursively grows a multiway (ID3/C4.5) subtree over `rows` and returns the
/// index of the node pushed into `nodes`. `rows` is reordered in place so each
/// branch recurses on a contiguous sub-slice.
fn build_multiway_node_in_place(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &mut [usize],
    depth: usize,
) -> usize {
    let majority_class_index =
        majority_class(rows, context.class_indices, context.class_labels.len());
    let current_class_counts =
        class_counts(rows, context.class_indices, context.class_labels.len());

    // Stopping conditions: no rows, depth or size limits reached, or the node
    // is already pure — emit a majority-class leaf.
    if rows.is_empty()
        || depth >= context.options.max_depth
        || rows.len() < context.options.min_samples_split
        || is_pure(rows, context.class_indices)
    {
        return push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        );
    }

    // ID3 scores candidates by raw information gain; C4.5 by gain ratio.
    let metric = match context.algorithm {
        DecisionTreeAlgorithm::Id3 => MultiwayMetric::InformationGain,
        DecisionTreeAlgorithm::C45 => MultiwayMetric::GainRatio,
        _ => unreachable!("multiway builder only supports id3/c45"),
    };
    let scoring = SplitScoringContext {
        table: context.table,
        class_indices: context.class_indices,
        num_classes: context.class_labels.len(),
        criterion: context.criterion,
        min_samples_leaf: context.options.min_samples_leaf,
    };
    // Candidate features are chosen deterministically from a per-node seed so
    // `max_features` subsampling is reproducible.
    let feature_indices = candidate_feature_indices(
        context.table.binned_feature_count(),
        context.options.max_features,
        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
    );
    let best_split = if context.parallelism.enabled() {
        feature_indices
            .into_par_iter()
            .filter_map(|feature_index| {
                score_multiway_split_choice(&scoring, feature_index, rows, metric)
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    } else {
        feature_indices
            .into_iter()
            .filter_map(|feature_index| {
                score_multiway_split_choice(&scoring, feature_index, rows, metric)
            })
            .max_by(|left, right| left.score.total_cmp(&right.score))
    };

    match best_split {
        // A winning split on a canary feature is refused and the node becomes
        // a leaf — presumably canaries must never drive real splits; confirm
        // against the `TableAccess` contract.
        Some(best_split)
            if context
                .table
                .is_canary_binned_feature(best_split.feature_index) =>
        {
            push_leaf(
                nodes,
                majority_class_index,
                rows.len(),
                current_class_counts,
            )
        }
        Some(best_split) if best_split.score > 0.0 => {
            let impurity =
                classification_impurity(&current_class_counts, rows.len(), context.criterion);
            // Groups `rows` in place; each (bin, start, end) range covers the
            // rows whose binned value equals `bin`.
            let branch_ranges = partition_rows_for_multiway_split(
                context.table,
                best_split.feature_index,
                &best_split.branch_bins,
                rows,
            );
            let mut branch_nodes = Vec::with_capacity(branch_ranges.len());
            for (bin, start, end) in branch_ranges {
                let child =
                    build_multiway_node_in_place(context, nodes, &mut rows[start..end], depth + 1);
                branch_nodes.push((bin, child));
            }

            push_node(
                nodes,
                TreeNode::MultiwaySplit {
                    feature_index: best_split.feature_index,
                    // Rows with unseen bins at prediction time get this class.
                    fallback_class_index: majority_class_index,
                    branches: branch_nodes,
                    sample_count: rows.len(),
                    impurity,
                    gain: best_split.score,
                    class_counts: current_class_counts,
                },
            )
        }
        // No candidate improved on the parent: emit a leaf.
        _ => push_leaf(
            nodes,
            majority_class_index,
            rows.len(),
            current_class_counts,
        ),
    }
}
1082
/// Read-only state shared by the recursive node builders.
struct BuildContext<'a> {
    /// Binned training table.
    table: &'a dyn TableAccess,
    /// Per-row class index, parallel to the table's rows.
    class_indices: &'a [usize],
    /// Sorted class vocabulary (its length is the class count).
    class_labels: &'a [f64],
    algorithm: DecisionTreeAlgorithm,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
}
1092
1093fn encode_class_labels(
1094 train_set: &dyn TableAccess,
1095) -> Result<(Vec<f64>, Vec<usize>), DecisionTreeError> {
1096 let targets: Vec<f64> = (0..train_set.n_rows())
1097 .map(|row_idx| {
1098 let value = train_set.target_value(row_idx);
1099 if value.is_finite() {
1100 Ok(value)
1101 } else {
1102 Err(DecisionTreeError::InvalidTargetValue {
1103 row: row_idx,
1104 value,
1105 })
1106 }
1107 })
1108 .collect::<Result<_, _>>()?;
1109
1110 let class_labels = targets
1111 .iter()
1112 .copied()
1113 .fold(Vec::<f64>::new(), |mut labels, value| {
1114 if labels
1115 .binary_search_by(|candidate| candidate.total_cmp(&value))
1116 .is_err()
1117 {
1118 labels.push(value);
1119 labels.sort_by(|left, right| left.total_cmp(right));
1120 }
1121 labels
1122 });
1123
1124 let class_indices = targets
1125 .iter()
1126 .map(|value| {
1127 class_labels
1128 .binary_search_by(|candidate| candidate.total_cmp(value))
1129 .expect("target value must exist in class label vocabulary")
1130 })
1131 .collect();
1132
1133 Ok((class_labels, class_indices))
1134}
1135
/// Tallies how many of the given rows fall into each of the `num_classes`
/// classes, using `class_indices` as the per-row class lookup.
fn class_counts(rows: &[usize], class_indices: &[usize], num_classes: usize) -> Vec<usize> {
    let mut counts = vec![0usize; num_classes];
    for &row_idx in rows {
        counts[class_indices[row_idx]] += 1;
    }
    counts
}
1143
/// Most frequent class index among `rows`; empty input yields class 0.
fn majority_class(rows: &[usize], class_indices: &[usize], num_classes: usize) -> usize {
    majority_class_from_counts(&class_counts(rows, class_indices, num_classes))
}
1147
/// Index of the largest count; ties resolve to the smallest class index, and
/// an empty slice yields 0.
fn majority_class_from_counts(counts: &[usize]) -> usize {
    let mut best_index = 0;
    let mut best_count = 0;
    for (class_index, &count) in counts.iter().enumerate() {
        // Strict comparison keeps the earliest (smallest) index on ties.
        if count > best_count {
            best_count = count;
            best_index = class_index;
        }
    }
    best_index
}
1157
/// True when every row in `rows` belongs to the same class (vacuously true
/// for an empty slice).
fn is_pure(rows: &[usize], class_indices: &[usize]) -> bool {
    match rows.split_first() {
        None => true,
        Some((first_row, rest)) => {
            let first_class = class_indices[*first_row];
            rest.iter()
                .all(|row_idx| class_indices[*row_idx] == first_class)
        }
    }
}
1164
/// Shannon entropy (base 2) of the class distribution given by `counts` over
/// `total` samples; zero counts contribute nothing.
fn entropy(counts: &[usize], total: usize) -> f64 {
    let total = total as f64;
    let mut result = 0.0;
    for &count in counts {
        if count > 0 {
            let probability = count as f64 / total;
            result -= probability * probability.log2();
        }
    }
    result
}
1176
/// Gini impurity of the class distribution given by `counts` over `total`
/// samples: 1 minus the sum of squared class probabilities.
fn gini(counts: &[usize], total: usize) -> f64 {
    let total = total as f64;
    let mut sum_of_squares = 0.0;
    for &count in counts {
        let probability = count as f64 / total;
        sum_of_squares += probability * probability;
    }
    1.0 - sum_of_squares
}
1187
/// Node impurity of `counts` over `total` under the chosen criterion.
///
/// Only `Entropy` and `Gini` are meaningful for classification; any other
/// criterion is a programming error and panics.
fn classification_impurity(counts: &[usize], total: usize, criterion: Criterion) -> f64 {
    match criterion {
        Criterion::Entropy => entropy(counts, total),
        Criterion::Gini => gini(counts, total),
        _ => unreachable!("classification impurity only supports gini or entropy"),
    }
}
1195
1196fn push_leaf(
1197 nodes: &mut Vec<TreeNode>,
1198 class_index: usize,
1199 sample_count: usize,
1200 class_counts: Vec<usize>,
1201) -> usize {
1202 push_node(
1203 nodes,
1204 TreeNode::Leaf {
1205 class_index,
1206 sample_count,
1207 class_counts,
1208 },
1209 )
1210}
1211
1212fn push_node(nodes: &mut Vec<TreeNode>, node: TreeNode) -> usize {
1213 nodes.push(node);
1214 nodes.len() - 1
1215}
1216
1217#[cfg(test)]
1218mod tests;