// forestfire_core/tree/classifier.rs
//! Classification tree learners.
//!
//! The module intentionally supports multiple tree families because they express
//! different tradeoffs:
//!
//! - `id3` / `c45` keep multiway splits for categorical-like binned features.
//! - `cart` is the standard binary threshold learner.
//! - `randomized` keeps the CART structure but cheapens split search.
//! - `oblivious` uses one split per depth, which is attractive for some runtime
//!   layouts and boosting-style ensembles.
//!
//! The hot numeric paths are written around binned histograms and in-place row
//! partitioning. That is why many helpers operate on row-index buffers instead of
//! allocating fresh row vectors at every recursive step.
15
16use crate::ir::{
17    BinaryChildren, BinarySplit, IndexedLeaf, LeafIndexing, LeafPayload, MultiwayBranch,
18    MultiwaySplit, NodeStats, NodeTreeNode, ObliviousLevel, ObliviousSplit as IrObliviousSplit,
19    TrainingMetadata, TreeDefinition, criterion_name, feature_name, threshold_upper_bound,
20    tree_type_name,
21};
22use crate::tree::shared::{
23    MissingBranchDirection, candidate_feature_indices, choose_random_threshold, node_seed,
24    partition_rows_for_binary_split,
25};
26use crate::{
27    Criterion, FeaturePreprocessing, MissingValueStrategy, Parallelism,
28    capture_feature_preprocessing,
29};
30use forestfire_data::TableAccess;
31use rayon::prelude::*;
32use std::collections::BTreeMap;
33use std::error::Error;
34use std::fmt::{Display, Formatter};
35
36mod histogram;
37mod ir_support;
38mod oblivious;
39mod partitioning;
40mod split_scoring;
41
42use histogram::{
43    ClassificationFeatureHistogram, build_classification_node_histograms,
44    subtract_classification_node_histograms,
45};
46use ir_support::{
47    binary_split_ir, normalized_class_probabilities, oblivious_split_ir, standard_node_depths,
48};
49use oblivious::train_oblivious_structure;
50use partitioning::partition_rows_for_multiway_split;
51use split_scoring::{
52    MultiwayMetric, SplitScoringContext, score_binary_split_choice_from_hist,
53    score_multiway_split_choice,
54};
55
/// Tree family selector. See the module docs for the tradeoffs between families.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecisionTreeAlgorithm {
    /// Multiway splits on binned values (trained via `build_multiway_node_in_place`).
    Id3,
    /// Multiway splits like ID3; shares the same node builder in this module.
    C45,
    /// Binary threshold splits (trained via `build_binary_node_in_place`).
    Cart,
    /// CART-shaped binary tree with a cheapened split search.
    Randomized,
    /// One shared split per depth level; produces `TreeStructure::Oblivious`.
    Oblivious,
}
64
/// Shared training controls for classification tree learners.
///
/// The defaults are intentionally modest rather than "grow until pure", because
/// ForestFire wants trees to be a stable building block for ensembles and
/// interpretable standalone models.
#[derive(Debug, Clone)]
pub struct DecisionTreeOptions {
    // Maximum tree depth (default 8).
    pub max_depth: usize,
    // Minimum samples required to attempt a split (default 2).
    pub min_samples_split: usize,
    // Minimum samples each resulting leaf must hold (default 1).
    pub min_samples_leaf: usize,
    // `None` means consider all features at each split; `Some(k)` subsamples.
    pub max_features: Option<usize>,
    // Seed for randomized split/feature selection (see `node_seed` usage elsewhere).
    pub random_seed: u64,
    // Per-feature missing-value handling; empty by default.
    pub missing_value_strategies: Vec<MissingValueStrategy>,
}
79
80impl Default for DecisionTreeOptions {
81    fn default() -> Self {
82        Self {
83            max_depth: 8,
84            min_samples_split: 2,
85            min_samples_leaf: 1,
86            max_features: None,
87            random_seed: 0,
88            missing_value_strategies: Vec::new(),
89        }
90    }
91}
92
/// Errors reported by the classification tree trainers.
#[derive(Debug)]
pub enum DecisionTreeError {
    /// The training table had zero rows.
    EmptyTarget,
    /// A target value was not usable as a class label (e.g. non-finite).
    InvalidTargetValue { row: usize, value: f64 },
}
98
99impl Display for DecisionTreeError {
100    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
101        match self {
102            DecisionTreeError::EmptyTarget => write!(f, "Cannot train on an empty target vector."),
103            DecisionTreeError::InvalidTargetValue { row, value } => write!(
104                f,
105                "Classification targets must be finite values. Found {} at row {}.",
106                value, row
107            ),
108        }
109    }
110}
111
112impl Error for DecisionTreeError {}
113
/// Concrete trained classification tree.
#[derive(Debug, Clone)]
pub struct DecisionTreeClassifier {
    // Family the tree was trained as.
    algorithm: DecisionTreeAlgorithm,
    // Impurity criterion used during training.
    criterion: Criterion,
    // Distinct target values; leaves store indices into this vector.
    class_labels: Vec<f64>,
    // Standard node-arena tree or depth-wise oblivious layout.
    structure: TreeStructure,
    // Options the tree was trained with (echoed into `TrainingMetadata`).
    options: DecisionTreeOptions,
    // Feature count of the training table.
    num_features: usize,
    // Captured per-feature preprocessing, used when emitting IR splits.
    feature_preprocessing: Vec<FeaturePreprocessing>,
    // Canary count from the training table, surfaced in metadata.
    training_canaries: usize,
}
126
/// Internal trained-tree layout.
#[derive(Debug, Clone)]
pub(crate) enum TreeStructure {
    /// Flat node arena; `root` indexes into `nodes`.
    Standard {
        nodes: Vec<TreeNode>,
        root: usize,
    },
    /// One split per depth level. Leaves are addressed by the bit-path taken
    /// through `splits` (left = 0, right = 1, most significant bit first), so
    /// the three leaf vectors are indexed in parallel by that path value.
    Oblivious {
        splits: Vec<ObliviousSplit>,
        leaf_class_indices: Vec<usize>,
        leaf_sample_counts: Vec<usize>,
        leaf_class_counts: Vec<Vec<usize>>,
    },
}
140
/// One depth level of an oblivious tree: a single (feature, threshold) test
/// applied to every row at that depth.
#[derive(Debug, Clone)]
pub(crate) struct ObliviousSplit {
    pub(crate) feature_index: usize,
    // Rows with bin <= threshold_bin go left; others go right.
    pub(crate) threshold_bin: u16,
    // NOTE(review): currently unused at prediction time (see the dead_code
    // allowance); the oblivious traversal does not consult missing values.
    #[allow(dead_code)]
    pub(crate) missing_directions: Vec<MissingBranchDirection>,
    pub(crate) sample_count: usize,
    pub(crate) impurity: f64,
    pub(crate) gain: f64,
}
151
/// Node of a `TreeStructure::Standard` tree. Child fields are indices into the
/// owning `nodes` arena.
#[derive(Debug, Clone)]
pub(crate) enum TreeNode {
    /// Terminal node; predicts `class_labels[class_index]`.
    Leaf {
        class_index: usize,
        sample_count: usize,
        class_counts: Vec<usize>,
    },
    /// ID3/C4.5-style split: one branch per observed bin value.
    MultiwaySplit {
        feature_index: usize,
        // Predicted when the row's bin matches no branch (or is missing and
        // there is no `missing_child`).
        fallback_class_index: usize,
        // (bin value, child node index) pairs, searched linearly at inference.
        branches: Vec<(u16, usize)>,
        missing_child: Option<usize>,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        class_counts: Vec<usize>,
    },
    /// CART-style split: bin <= threshold_bin goes left, otherwise right.
    BinarySplit {
        feature_index: usize,
        threshold_bin: u16,
        // Where rows with a missing feature value are routed; `Node` means
        // predict this node's majority class instead of descending.
        missing_direction: MissingBranchDirection,
        left_child: usize,
        right_child: usize,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        class_counts: Vec<usize>,
    },
}
181
/// Candidate binary split produced during split search.
// NOTE(review): not referenced in this chunk; presumably consumed by the
// binary node builder later in the file — confirm before removing.
#[derive(Debug, Clone, Copy)]
struct BinarySplitChoice {
    feature_index: usize,
    score: f64,
    threshold_bin: u16,
    missing_direction: MissingBranchDirection,
}
189
/// Candidate multiway split produced during split search.
// NOTE(review): not referenced in this chunk; presumably consumed by the
// multiway node builder later in the file — confirm before removing.
#[derive(Debug, Clone)]
struct MultiwaySplitChoice {
    feature_index: usize,
    score: f64,
    branch_bins: Vec<u16>,
    missing_branch_bin: Option<u16>,
}
197
/// Trains an ID3 tree with the family's default criterion (entropy).
pub fn train_id3(train_set: &dyn TableAccess) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion(train_set, Criterion::Entropy)
}

/// Trains an ID3 tree with an explicit impurity criterion, sequentially.
pub fn train_id3_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal variant that also takes a parallelism policy; uses default options.
pub(crate) fn train_id3_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized ID3 entry point; delegates to the shared trainer.
pub(crate) fn train_id3_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Id3,
        criterion,
        parallelism,
        options,
    )
}
236
/// Trains a C4.5 tree with the family's default criterion (entropy).
pub fn train_c45(train_set: &dyn TableAccess) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion(train_set, Criterion::Entropy)
}

/// Trains a C4.5 tree with an explicit impurity criterion, sequentially.
pub fn train_c45_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal variant that also takes a parallelism policy; uses default options.
pub(crate) fn train_c45_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized C4.5 entry point; delegates to the shared trainer.
pub(crate) fn train_c45_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::C45,
        criterion,
        parallelism,
        options,
    )
}
275
/// Trains a CART tree with the family's default criterion (Gini).
pub fn train_cart(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion(train_set, Criterion::Gini)
}

/// Trains a CART tree with an explicit impurity criterion, sequentially.
pub fn train_cart_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal variant that also takes a parallelism policy; uses default options.
pub(crate) fn train_cart_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized CART entry point; delegates to the shared trainer.
pub(crate) fn train_cart_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Cart,
        criterion,
        parallelism,
        options,
    )
}
316
/// Trains an oblivious tree with the family's default criterion (Gini).
pub fn train_oblivious(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion(train_set, Criterion::Gini)
}

/// Trains an oblivious tree with an explicit impurity criterion, sequentially.
pub fn train_oblivious_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal variant that also takes a parallelism policy; uses default options.
pub(crate) fn train_oblivious_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized oblivious entry point; delegates to the shared trainer.
pub(crate) fn train_oblivious_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Oblivious,
        criterion,
        parallelism,
        options,
    )
}
357
/// Trains a randomized (CART-shaped, cheapened split search) tree with the
/// family's default criterion (Gini).
pub fn train_randomized(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion(train_set, Criterion::Gini)
}

/// Trains a randomized tree with an explicit impurity criterion, sequentially.
pub fn train_randomized_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal variant that also takes a parallelism policy; uses default options.
pub(crate) fn train_randomized_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized randomized entry point; delegates to the shared trainer.
pub(crate) fn train_randomized_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Randomized,
        criterion,
        parallelism,
        options,
    )
}
398
399fn train_classifier(
400    train_set: &dyn TableAccess,
401    algorithm: DecisionTreeAlgorithm,
402    criterion: Criterion,
403    parallelism: Parallelism,
404    options: DecisionTreeOptions,
405) -> Result<DecisionTreeClassifier, DecisionTreeError> {
406    if train_set.n_rows() == 0 {
407        return Err(DecisionTreeError::EmptyTarget);
408    }
409
410    let (class_labels, class_indices) = encode_class_labels(train_set)?;
411    let structure = match algorithm {
412        DecisionTreeAlgorithm::Oblivious => train_oblivious_structure(
413            train_set,
414            &class_indices,
415            &class_labels,
416            criterion,
417            parallelism,
418            options.clone(),
419        ),
420        DecisionTreeAlgorithm::Cart | DecisionTreeAlgorithm::Randomized => {
421            let mut nodes = Vec::new();
422            let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
423            let context = BuildContext {
424                table: train_set,
425                class_indices: &class_indices,
426                class_labels: &class_labels,
427                algorithm,
428                criterion,
429                parallelism,
430                options: options.clone(),
431            };
432            let root = build_binary_node_in_place(&context, &mut nodes, &mut all_rows, 0);
433            TreeStructure::Standard { nodes, root }
434        }
435        DecisionTreeAlgorithm::Id3 | DecisionTreeAlgorithm::C45 => {
436            let mut nodes = Vec::new();
437            let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
438            let context = BuildContext {
439                table: train_set,
440                class_indices: &class_indices,
441                class_labels: &class_labels,
442                algorithm,
443                criterion,
444                parallelism,
445                options: options.clone(),
446            };
447            let root = build_multiway_node_in_place(&context, &mut nodes, &mut all_rows, 0);
448            TreeStructure::Standard { nodes, root }
449        }
450    };
451
452    Ok(DecisionTreeClassifier {
453        algorithm,
454        criterion,
455        class_labels,
456        structure,
457        options,
458        num_features: train_set.n_features(),
459        feature_preprocessing: capture_feature_preprocessing(train_set),
460        training_canaries: train_set.canaries(),
461    })
462}
463
464impl DecisionTreeClassifier {
    /// Returns the tree family this classifier was trained as.
    pub fn algorithm(&self) -> DecisionTreeAlgorithm {
        self.algorithm
    }

    /// Returns the impurity criterion used during training.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }
472
473    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
474        (0..table.n_rows())
475            .map(|row_idx| self.predict_row(table, row_idx))
476            .collect()
477    }
478
479    pub fn predict_proba_table(&self, table: &dyn TableAccess) -> Vec<Vec<f64>> {
480        (0..table.n_rows())
481            .map(|row_idx| self.predict_proba_row(table, row_idx))
482            .collect()
483    }
484
    /// Routes one row through the tree and returns the predicted class label.
    ///
    /// Standard trees walk the node arena from `root` until hitting a leaf.
    /// Missing values follow the split's missing policy; an unmatched multiway
    /// bin falls back to the node's fallback class.
    fn predict_row(&self, table: &dyn TableAccess, row_idx: usize) -> f64 {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;
                loop {
                    match &nodes[node_index] {
                        TreeNode::Leaf { class_index, .. } => {
                            return self.class_labels[*class_index];
                        }
                        TreeNode::MultiwaySplit {
                            feature_index,
                            fallback_class_index,
                            branches,
                            missing_child,
                            ..
                        } => {
                            // Missing value: descend into the dedicated child if
                            // one exists, otherwise answer with the fallback class.
                            if table.is_missing(*feature_index, row_idx) {
                                if let Some(child_index) = missing_child {
                                    node_index = *child_index;
                                } else {
                                    return self.class_labels[*fallback_class_index];
                                }
                                continue;
                            }
                            // Linear scan over branches; an unseen bin value also
                            // resolves to the fallback class.
                            let bin = table.binned_value(*feature_index, row_idx);
                            if let Some((_, child_index)) =
                                branches.iter().find(|(branch_bin, _)| *branch_bin == bin)
                            {
                                node_index = *child_index;
                            } else {
                                return self.class_labels[*fallback_class_index];
                            }
                        }
                        TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            missing_direction,
                            left_child,
                            right_child,
                            class_counts,
                            ..
                        } => {
                            if table.is_missing(*feature_index, row_idx) {
                                match missing_direction {
                                    MissingBranchDirection::Left => {
                                        node_index = *left_child;
                                    }
                                    MissingBranchDirection::Right => {
                                        node_index = *right_child;
                                    }
                                    // `Node`: stop here and answer with this
                                    // node's majority class.
                                    MissingBranchDirection::Node => {
                                        return self.class_labels
                                            [majority_class_from_counts(class_counts)];
                                    }
                                }
                                continue;
                            }
                            // Binary rule: bin <= threshold goes left.
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_indices,
                ..
            } => {
                // Build the leaf index bit-by-bit, MSB first: each level shifts
                // the accumulated path left and appends 1 when the row goes right.
                // NOTE(review): this path never consults `table.is_missing`, even
                // though `ObliviousSplit` carries (dead_code) missing_directions —
                // confirm that oblivious trees intentionally ignore missingness.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                self.class_labels[leaf_class_indices[leaf_index]]
            }
        }
    }
567
    /// Routes one row through the tree and returns normalized per-class
    /// probabilities derived from the reached node's class counts.
    ///
    /// Mirrors `predict_row`'s traversal exactly, but every early exit returns
    /// a normalized count distribution instead of a single label.
    fn predict_proba_row(&self, table: &dyn TableAccess, row_idx: usize) -> Vec<f64> {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;
                loop {
                    match &nodes[node_index] {
                        TreeNode::Leaf { class_counts, .. } => {
                            return normalized_class_probabilities(class_counts);
                        }
                        TreeNode::MultiwaySplit {
                            feature_index,
                            branches,
                            missing_child,
                            class_counts,
                            ..
                        } => {
                            // Missing value without a dedicated child: answer with
                            // this node's own class distribution.
                            if table.is_missing(*feature_index, row_idx) {
                                if let Some(child_index) = missing_child {
                                    node_index = *child_index;
                                } else {
                                    return normalized_class_probabilities(class_counts);
                                }
                                continue;
                            }
                            // Unseen bin values also fall back to the node's
                            // distribution.
                            let bin = table.binned_value(*feature_index, row_idx);
                            if let Some((_, child_index)) =
                                branches.iter().find(|(branch_bin, _)| *branch_bin == bin)
                            {
                                node_index = *child_index;
                            } else {
                                return normalized_class_probabilities(class_counts);
                            }
                        }
                        TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            missing_direction,
                            left_child,
                            right_child,
                            class_counts,
                            ..
                        } => {
                            if table.is_missing(*feature_index, row_idx) {
                                match missing_direction {
                                    MissingBranchDirection::Left => {
                                        node_index = *left_child;
                                    }
                                    MissingBranchDirection::Right => {
                                        node_index = *right_child;
                                    }
                                    // `Node`: stop and report this node's
                                    // distribution.
                                    MissingBranchDirection::Node => {
                                        return normalized_class_probabilities(class_counts);
                                    }
                                }
                                continue;
                            }
                            // Binary rule: bin <= threshold goes left.
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_counts,
                ..
            } => {
                // Same MSB-first bit-path indexing as `predict_row`; missing
                // values are likewise not consulted here.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                normalized_class_probabilities(&leaf_class_counts[leaf_index])
            }
        }
    }
649
    /// Distinct class labels; leaf indices point into this slice.
    pub(crate) fn class_labels(&self) -> &[f64] {
        &self.class_labels
    }

    /// The trained internal structure (standard arena or oblivious levels).
    pub(crate) fn structure(&self) -> &TreeStructure {
        &self.structure
    }

    /// Feature count of the table the tree was trained on.
    pub(crate) fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-feature preprocessing captured at training time.
    pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }
665
    /// Builds the serializable metadata record describing how this tree was
    /// trained. Ensemble-only fields (n_trees, learning_rate, bootstrap, …)
    /// are `None` for a standalone tree.
    pub(crate) fn training_metadata(&self) -> TrainingMetadata {
        TrainingMetadata {
            algorithm: "dt".to_string(),
            task: "classification".to_string(),
            // Map the local algorithm enum onto the crate-wide TreeType naming.
            tree_type: tree_type_name(match self.algorithm {
                DecisionTreeAlgorithm::Id3 => crate::TreeType::Id3,
                DecisionTreeAlgorithm::C45 => crate::TreeType::C45,
                DecisionTreeAlgorithm::Cart => crate::TreeType::Cart,
                DecisionTreeAlgorithm::Randomized => crate::TreeType::Randomized,
                DecisionTreeAlgorithm::Oblivious => crate::TreeType::Oblivious,
            })
            .to_string(),
            criterion: criterion_name(self.criterion).to_string(),
            canaries: self.training_canaries,
            compute_oob: false,
            max_depth: Some(self.options.max_depth),
            min_samples_split: Some(self.options.min_samples_split),
            min_samples_leaf: Some(self.options.min_samples_leaf),
            n_trees: None,
            max_features: self.options.max_features,
            seed: None,
            oob_score: None,
            class_labels: Some(self.class_labels.clone()),
            learning_rate: None,
            bootstrap: None,
            top_gradient_fraction: None,
            other_gradient_fraction: None,
        }
    }
695
    /// Lowers the trained structure into the crate's IR `TreeDefinition`.
    ///
    /// Standard trees map node-for-node (same node ids / child indices) into a
    /// `NodeTree`; oblivious trees become `ObliviousLevels` with explicit
    /// MSB-first leaf indexing. `tree_id` 0 and `weight` 1.0 are used since a
    /// standalone tree is its own single-member ensemble.
    pub(crate) fn to_ir_tree(&self) -> TreeDefinition {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                // Depths are recomputed from the arena so the IR can carry them.
                let depths = standard_node_depths(nodes, *root);
                TreeDefinition::NodeTree {
                    tree_id: 0,
                    weight: 1.0,
                    root_node_id: *root,
                    nodes: nodes
                        .iter()
                        .enumerate()
                        .map(|(node_id, node)| match node {
                            TreeNode::Leaf {
                                class_index,
                                sample_count,
                                class_counts,
                            } => NodeTreeNode::Leaf {
                                node_id,
                                depth: depths[node_id],
                                leaf: self.class_leaf(*class_index),
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: None,
                                    gain: None,
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                            TreeNode::BinarySplit {
                                feature_index,
                                threshold_bin,
                                missing_direction,
                                left_child,
                                right_child,
                                sample_count,
                                impurity,
                                gain,
                                class_counts,
                            } => NodeTreeNode::BinaryBranch {
                                node_id,
                                depth: depths[node_id],
                                // Threshold bins are translated through the captured
                                // feature preprocessing by the IR helper.
                                split: binary_split_ir(
                                    *feature_index,
                                    *threshold_bin,
                                    *missing_direction,
                                    &self.feature_preprocessing,
                                ),
                                children: BinaryChildren {
                                    left: *left_child,
                                    right: *right_child,
                                },
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                            TreeNode::MultiwaySplit {
                                feature_index,
                                fallback_class_index,
                                branches,
                                // NOTE(review): the missing child is dropped in the
                                // IR; missing rows resolve to `unmatched_leaf` there.
                                // Confirm this lossy mapping is intentional.
                                missing_child: _,
                                sample_count,
                                impurity,
                                gain,
                                class_counts,
                            } => NodeTreeNode::MultiwayBranch {
                                node_id,
                                depth: depths[node_id],
                                split: MultiwaySplit {
                                    split_type: "binned_value_multiway".to_string(),
                                    feature_index: *feature_index,
                                    feature_name: feature_name(*feature_index),
                                    comparison_dtype: "uint16".to_string(),
                                },
                                branches: branches
                                    .iter()
                                    .map(|(bin, child)| MultiwayBranch {
                                        bin: *bin,
                                        child: *child,
                                    })
                                    .collect(),
                                unmatched_leaf: self.class_leaf(*fallback_class_index),
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                        })
                        .collect(),
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_indices,
                leaf_sample_counts,
                leaf_class_counts,
            } => TreeDefinition::ObliviousLevels {
                tree_id: 0,
                weight: 1.0,
                depth: splits.len(),
                levels: splits
                    .iter()
                    .enumerate()
                    .map(|(level, split)| ObliviousLevel {
                        level,
                        split: oblivious_split_ir(
                            split.feature_index,
                            split.threshold_bin,
                            &self.feature_preprocessing,
                        ),
                        stats: NodeStats {
                            sample_count: split.sample_count,
                            impurity: Some(split.impurity),
                            gain: Some(split.gain),
                            class_counts: None,
                            variance: None,
                        },
                    })
                    .collect(),
                // Documents the same bit-path addressing the predictors use:
                // earlier levels contribute higher-order bits.
                leaf_indexing: LeafIndexing {
                    bit_order: "msb_first".to_string(),
                    index_formula: "sum(bit[level] << (depth - 1 - level))".to_string(),
                },
                leaves: leaf_class_indices
                    .iter()
                    .enumerate()
                    .map(|(leaf_index, class_index)| IndexedLeaf {
                        leaf_index,
                        leaf: self.class_leaf(*class_index),
                        stats: NodeStats {
                            sample_count: leaf_sample_counts[leaf_index],
                            impurity: None,
                            gain: None,
                            class_counts: Some(leaf_class_counts[leaf_index].clone()),
                            variance: None,
                        },
                    })
                    .collect(),
            },
        }
    }
843
844    fn class_leaf(&self, class_index: usize) -> LeafPayload {
845        LeafPayload::ClassIndex {
846            class_index,
847            class_value: self.class_labels[class_index],
848        }
849    }
850
851    #[allow(clippy::too_many_arguments)]
852    pub(crate) fn from_ir_parts(
853        algorithm: DecisionTreeAlgorithm,
854        criterion: Criterion,
855        class_labels: Vec<f64>,
856        structure: TreeStructure,
857        options: DecisionTreeOptions,
858        num_features: usize,
859        feature_preprocessing: Vec<FeaturePreprocessing>,
860        training_canaries: usize,
861    ) -> Self {
862        Self {
863            algorithm,
864            criterion,
865            class_labels,
866            structure,
867            options,
868            num_features,
869            feature_preprocessing,
870            training_canaries,
871        }
872    }
873}
874
/// Entry point for growing a binary subtree over `rows`.
///
/// Delegates to the histogram-aware variant with `None`, so the root node
/// builds its own per-feature histograms; children then derive theirs from the
/// parent's inside the recursive calls.
fn build_binary_node_in_place(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &mut [usize],
    depth: usize,
) -> usize {
    build_binary_node_in_place_with_hist(context, nodes, rows, depth, None)
}
883
884fn build_binary_node_in_place_with_hist(
885    context: &BuildContext<'_>,
886    nodes: &mut Vec<TreeNode>,
887    rows: &mut [usize],
888    depth: usize,
889    histograms: Option<Vec<ClassificationFeatureHistogram>>,
890) -> usize {
891    let majority_class_index =
892        majority_class(rows, context.class_indices, context.class_labels.len());
893    let current_class_counts =
894        class_counts(rows, context.class_indices, context.class_labels.len());
895
896    if rows.is_empty()
897        || depth >= context.options.max_depth
898        || rows.len() < context.options.min_samples_split
899        || is_pure(rows, context.class_indices)
900    {
901        return push_leaf(
902            nodes,
903            majority_class_index,
904            rows.len(),
905            current_class_counts,
906        );
907    }
908
909    let scoring = SplitScoringContext {
910        table: context.table,
911        class_indices: context.class_indices,
912        num_classes: context.class_labels.len(),
913        criterion: context.criterion,
914        min_samples_leaf: context.options.min_samples_leaf,
915        missing_value_strategies: &context.options.missing_value_strategies,
916    };
917    let histograms = histograms.unwrap_or_else(|| {
918        build_classification_node_histograms(
919            context.table,
920            context.class_indices,
921            rows,
922            context.class_labels.len(),
923        )
924    });
925    let feature_indices = candidate_feature_indices(
926        context.table.binned_feature_count(),
927        context.options.max_features,
928        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
929    );
930    let best_split = if context.parallelism.enabled() {
931        feature_indices
932            .into_par_iter()
933            .filter_map(|feature_index| {
934                score_binary_split_choice_from_hist(
935                    &scoring,
936                    &histograms[feature_index],
937                    feature_index,
938                    rows,
939                    &current_class_counts,
940                    context.algorithm,
941                )
942            })
943            .max_by(|left, right| left.score.total_cmp(&right.score))
944    } else {
945        feature_indices
946            .into_iter()
947            .filter_map(|feature_index| {
948                score_binary_split_choice_from_hist(
949                    &scoring,
950                    &histograms[feature_index],
951                    feature_index,
952                    rows,
953                    &current_class_counts,
954                    context.algorithm,
955                )
956            })
957            .max_by(|left, right| left.score.total_cmp(&right.score))
958    };
959
960    match best_split {
961        Some(best_split)
962            if context
963                .table
964                .is_canary_binned_feature(best_split.feature_index) =>
965        {
966            push_leaf(
967                nodes,
968                majority_class_index,
969                rows.len(),
970                current_class_counts,
971            )
972        }
973        Some(best_split) if best_split.score > 0.0 => {
974            let impurity =
975                classification_impurity(&current_class_counts, rows.len(), context.criterion);
976            let left_count = partition_rows_for_binary_split(
977                context.table,
978                best_split.feature_index,
979                best_split.threshold_bin,
980                best_split.missing_direction,
981                rows,
982            );
983            let (left_rows, right_rows) = rows.split_at_mut(left_count);
984            let (left_histograms, right_histograms) = if left_rows.len() <= right_rows.len() {
985                let left_histograms = build_classification_node_histograms(
986                    context.table,
987                    context.class_indices,
988                    left_rows,
989                    context.class_labels.len(),
990                );
991                let right_histograms =
992                    subtract_classification_node_histograms(&histograms, &left_histograms);
993                (left_histograms, right_histograms)
994            } else {
995                let right_histograms = build_classification_node_histograms(
996                    context.table,
997                    context.class_indices,
998                    right_rows,
999                    context.class_labels.len(),
1000                );
1001                let left_histograms =
1002                    subtract_classification_node_histograms(&histograms, &right_histograms);
1003                (left_histograms, right_histograms)
1004            };
1005            let left_child = build_binary_node_in_place_with_hist(
1006                context,
1007                nodes,
1008                left_rows,
1009                depth + 1,
1010                Some(left_histograms),
1011            );
1012            let right_child = build_binary_node_in_place_with_hist(
1013                context,
1014                nodes,
1015                right_rows,
1016                depth + 1,
1017                Some(right_histograms),
1018            );
1019
1020            push_node(
1021                nodes,
1022                TreeNode::BinarySplit {
1023                    feature_index: best_split.feature_index,
1024                    threshold_bin: best_split.threshold_bin,
1025                    missing_direction: best_split.missing_direction,
1026                    left_child,
1027                    right_child,
1028                    sample_count: rows.len(),
1029                    impurity,
1030                    gain: best_split.score,
1031                    class_counts: current_class_counts,
1032                },
1033            )
1034        }
1035        _ => push_leaf(
1036            nodes,
1037            majority_class_index,
1038            rows.len(),
1039            current_class_counts,
1040        ),
1041    }
1042}
1043
1044fn build_multiway_node_in_place(
1045    context: &BuildContext<'_>,
1046    nodes: &mut Vec<TreeNode>,
1047    rows: &mut [usize],
1048    depth: usize,
1049) -> usize {
1050    let majority_class_index =
1051        majority_class(rows, context.class_indices, context.class_labels.len());
1052    let current_class_counts =
1053        class_counts(rows, context.class_indices, context.class_labels.len());
1054
1055    if rows.is_empty()
1056        || depth >= context.options.max_depth
1057        || rows.len() < context.options.min_samples_split
1058        || is_pure(rows, context.class_indices)
1059    {
1060        return push_leaf(
1061            nodes,
1062            majority_class_index,
1063            rows.len(),
1064            current_class_counts,
1065        );
1066    }
1067
1068    let metric = match context.algorithm {
1069        DecisionTreeAlgorithm::Id3 => MultiwayMetric::InformationGain,
1070        DecisionTreeAlgorithm::C45 => MultiwayMetric::GainRatio,
1071        _ => unreachable!("multiway builder only supports id3/c45"),
1072    };
1073    let scoring = SplitScoringContext {
1074        table: context.table,
1075        class_indices: context.class_indices,
1076        num_classes: context.class_labels.len(),
1077        criterion: context.criterion,
1078        min_samples_leaf: context.options.min_samples_leaf,
1079        missing_value_strategies: &context.options.missing_value_strategies,
1080    };
1081    let feature_indices = candidate_feature_indices(
1082        context.table.binned_feature_count(),
1083        context.options.max_features,
1084        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
1085    );
1086    let best_split = if context.parallelism.enabled() {
1087        feature_indices
1088            .into_par_iter()
1089            .filter_map(|feature_index| {
1090                score_multiway_split_choice(&scoring, feature_index, rows, metric)
1091            })
1092            .max_by(|left, right| left.score.total_cmp(&right.score))
1093    } else {
1094        feature_indices
1095            .into_iter()
1096            .filter_map(|feature_index| {
1097                score_multiway_split_choice(&scoring, feature_index, rows, metric)
1098            })
1099            .max_by(|left, right| left.score.total_cmp(&right.score))
1100    };
1101
1102    match best_split {
1103        Some(best_split)
1104            if context
1105                .table
1106                .is_canary_binned_feature(best_split.feature_index) =>
1107        {
1108            push_leaf(
1109                nodes,
1110                majority_class_index,
1111                rows.len(),
1112                current_class_counts,
1113            )
1114        }
1115        Some(best_split) if best_split.score > 0.0 => {
1116            let impurity =
1117                classification_impurity(&current_class_counts, rows.len(), context.criterion);
1118            let branch_ranges = partition_rows_for_multiway_split(
1119                context.table,
1120                best_split.feature_index,
1121                &best_split.branch_bins,
1122                best_split.missing_branch_bin,
1123                rows,
1124            );
1125            let mut branch_nodes = Vec::with_capacity(branch_ranges.len());
1126            let mut missing_child = None;
1127            for (bin, start, end) in branch_ranges {
1128                let child =
1129                    build_multiway_node_in_place(context, nodes, &mut rows[start..end], depth + 1);
1130                if best_split.missing_branch_bin == Some(bin) {
1131                    missing_child = Some(child);
1132                }
1133                branch_nodes.push((bin, child));
1134            }
1135
1136            push_node(
1137                nodes,
1138                TreeNode::MultiwaySplit {
1139                    feature_index: best_split.feature_index,
1140                    fallback_class_index: majority_class_index,
1141                    branches: branch_nodes,
1142                    missing_child,
1143                    sample_count: rows.len(),
1144                    impurity,
1145                    gain: best_split.score,
1146                    class_counts: current_class_counts,
1147                },
1148            )
1149        }
1150        _ => push_leaf(
1151            nodes,
1152            majority_class_index,
1153            rows.len(),
1154            current_class_counts,
1155        ),
1156    }
1157}
1158
/// Shared, read-only state threaded through the recursive tree builders.
struct BuildContext<'a> {
    /// Binned feature table the tree is trained on.
    table: &'a dyn TableAccess,
    /// Per-row class index, parallel to the table's rows.
    class_indices: &'a [usize],
    /// Sorted class-label vocabulary; entries of `class_indices` index into it.
    class_labels: &'a [f64],
    /// Tree family being trained (selects the builder and split scoring).
    algorithm: DecisionTreeAlgorithm,
    /// Impurity criterion (gini or entropy for classification).
    criterion: Criterion,
    /// Whether per-feature split scoring may run on rayon.
    parallelism: Parallelism,
    /// Depth/size limits, seeding, and missing-value configuration.
    options: DecisionTreeOptions,
}
1168
1169fn encode_class_labels(
1170    train_set: &dyn TableAccess,
1171) -> Result<(Vec<f64>, Vec<usize>), DecisionTreeError> {
1172    let targets: Vec<f64> = (0..train_set.n_rows())
1173        .map(|row_idx| {
1174            let value = train_set.target_value(row_idx);
1175            if value.is_finite() {
1176                Ok(value)
1177            } else {
1178                Err(DecisionTreeError::InvalidTargetValue {
1179                    row: row_idx,
1180                    value,
1181                })
1182            }
1183        })
1184        .collect::<Result<_, _>>()?;
1185
1186    let class_labels = targets
1187        .iter()
1188        .copied()
1189        .fold(Vec::<f64>::new(), |mut labels, value| {
1190            if labels
1191                .binary_search_by(|candidate| candidate.total_cmp(&value))
1192                .is_err()
1193            {
1194                labels.push(value);
1195                labels.sort_by(|left, right| left.total_cmp(right));
1196            }
1197            labels
1198        });
1199
1200    let class_indices = targets
1201        .iter()
1202        .map(|value| {
1203            class_labels
1204                .binary_search_by(|candidate| candidate.total_cmp(value))
1205                .expect("target value must exist in class label vocabulary")
1206        })
1207        .collect();
1208
1209    Ok((class_labels, class_indices))
1210}
1211
/// Histogram of class occurrences for the rows in `rows`, sized `num_classes`.
fn class_counts(rows: &[usize], class_indices: &[usize], num_classes: usize) -> Vec<usize> {
    let mut counts = vec![0usize; num_classes];
    for &row_idx in rows {
        counts[class_indices[row_idx]] += 1;
    }
    counts
}
1219
1220fn majority_class(rows: &[usize], class_indices: &[usize], num_classes: usize) -> usize {
1221    majority_class_from_counts(&class_counts(rows, class_indices, num_classes))
1222}
1223
/// Index of the largest count; ties favor the lowest index, and an empty
/// slice yields class 0.
fn majority_class_from_counts(counts: &[usize]) -> usize {
    let mut best_class = 0usize;
    let mut best_count = 0usize;
    for (class_index, &count) in counts.iter().enumerate() {
        // Strictly greater keeps the earliest class on ties.
        if count > best_count {
            best_class = class_index;
            best_count = count;
        }
    }
    best_class
}
1233
/// True when every row in `rows` carries the same class index (vacuously true
/// for an empty node).
fn is_pure(rows: &[usize], class_indices: &[usize]) -> bool {
    match rows.split_first() {
        None => true,
        Some((&first_row, rest)) => {
            let label = class_indices[first_row];
            rest.iter().all(|&row_idx| class_indices[row_idx] == label)
        }
    }
}
1240
/// Shannon entropy (base 2) of the class distribution `counts` over `total`
/// samples; zero-count classes contribute nothing.
fn entropy(counts: &[usize], total: usize) -> f64 {
    let total = total as f64;
    let mut bits = 0.0;
    for &count in counts {
        if count > 0 {
            let probability = count as f64 / total;
            bits -= probability * probability.log2();
        }
    }
    bits
}
1252
/// Gini impurity of the class distribution `counts` over `total` samples:
/// one minus the sum of squared class probabilities.
fn gini(counts: &[usize], total: usize) -> f64 {
    let mut sum_of_squares = 0.0f64;
    for &count in counts {
        let probability = count as f64 / total as f64;
        sum_of_squares += probability * probability;
    }
    1.0 - sum_of_squares
}
1263
1264fn classification_impurity(counts: &[usize], total: usize, criterion: Criterion) -> f64 {
1265    match criterion {
1266        Criterion::Entropy => entropy(counts, total),
1267        Criterion::Gini => gini(counts, total),
1268        _ => unreachable!("classification impurity only supports gini or entropy"),
1269    }
1270}
1271
1272fn push_leaf(
1273    nodes: &mut Vec<TreeNode>,
1274    class_index: usize,
1275    sample_count: usize,
1276    class_counts: Vec<usize>,
1277) -> usize {
1278    push_node(
1279        nodes,
1280        TreeNode::Leaf {
1281            class_index,
1282            sample_count,
1283            class_counts,
1284        },
1285    )
1286}
1287
1288fn push_node(nodes: &mut Vec<TreeNode>, node: TreeNode) -> usize {
1289    nodes.push(node);
1290    nodes.len() - 1
1291}
1292
1293#[cfg(test)]
1294mod tests;