//! Classification tree learners.
//!
//! The module intentionally supports multiple tree families because they express
//! different tradeoffs:
//!
//! - `id3` / `c45` keep multiway splits for categorical-like binned features.
//! - `cart` is the standard binary threshold learner.
//! - `randomized` keeps the CART structure but cheapens split search.
//! - `oblivious` uses one split per depth, which is attractive for some runtime
//!   layouts and boosting-style ensembles.
//!
//! The hot numeric paths are written around binned histograms and in-place row
//! partitioning. That is why many helpers operate on row-index buffers instead of
//! allocating fresh row vectors at every recursive step.

use crate::ir::{
    BinaryChildren, BinarySplit, IndexedLeaf, LeafIndexing, LeafPayload, MultiwayBranch,
    MultiwaySplit, NodeStats, NodeTreeNode, ObliviousLevel, ObliviousSplit as IrObliviousSplit,
    TrainingMetadata, TreeDefinition, criterion_name, feature_name, threshold_upper_bound,
    tree_type_name,
};
use crate::tree::shared::{
    MissingBranchDirection, candidate_feature_indices, choose_random_threshold, node_seed,
    partition_rows_for_binary_split, select_best_non_canary_candidate,
};
use crate::{
    CanaryFilter, Criterion, FeaturePreprocessing, MissingValueStrategy, Parallelism,
    capture_feature_preprocessing,
};
use forestfire_data::TableAccess;
use rayon::prelude::*;
use std::collections::BTreeMap;
use std::error::Error;
use std::fmt::{Display, Formatter};

mod histogram;
mod ir_support;
mod oblivious;
mod partitioning;
mod split_scoring;

use histogram::{
    ClassificationFeatureHistogram, build_classification_node_histograms,
    subtract_classification_node_histograms,
};
use ir_support::{
    binary_split_ir, normalized_class_probabilities, oblivious_split_ir, standard_node_depths,
};
use oblivious::train_oblivious_structure;
use partitioning::partition_rows_for_multiway_split;
use split_scoring::{
    MultiwayMetric, SplitScoringContext, score_binary_split_choice_from_hist,
    score_multiway_split_choice,
};
55
/// Tree-growing families supported by [`DecisionTreeClassifier`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecisionTreeAlgorithm {
    /// Multiway splits on binned values; `train_id3` defaults to entropy.
    Id3,
    /// Multiway splits like ID3; `train_c45` also defaults to entropy.
    C45,
    /// Standard binary threshold splits; `train_cart` defaults to Gini.
    Cart,
    /// CART-shaped tree with cheapened, randomized split search.
    Randomized,
    /// One shared split per depth level; leaves are indexed by split outcomes.
    Oblivious,
}
64
/// Shared training controls for classification tree learners.
///
/// The defaults are intentionally modest rather than "grow until pure", because
/// ForestFire wants trees to be a stable building block for ensembles and
/// interpretable standalone models.
#[derive(Debug, Clone)]
pub struct DecisionTreeOptions {
    /// Upper bound on tree depth (default 8).
    pub max_depth: usize,
    /// Minimum rows a node must hold to be considered for splitting (default 2).
    pub min_samples_split: usize,
    /// Minimum rows a leaf must keep (default 1).
    pub min_samples_leaf: usize,
    /// Optional cap on candidate features examined per split; `None`
    /// presumably means "consider all features" — confirm in
    /// `candidate_feature_indices`.
    pub max_features: Option<usize>,
    /// Base seed for stochastic choices (e.g. randomized thresholds);
    /// presumably mixed per node via `node_seed` — confirm in `shared`.
    pub random_seed: u64,
    /// Per-feature missing-value handling; empty means none configured.
    pub missing_value_strategies: Vec<MissingValueStrategy>,
    /// Controls how canary features are excluded from split selection
    /// (see `select_best_non_canary_candidate`).
    pub canary_filter: CanaryFilter,
}
80
impl Default for DecisionTreeOptions {
    fn default() -> Self {
        Self {
            // Modest cap: deep enough to be useful standalone, shallow enough
            // to stay a stable ensemble building block (see type docs).
            max_depth: 8,
            // Loosest structural limits: any node with >= 2 rows may split
            // and single-row leaves are allowed.
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: None,
            random_seed: 0,
            // No per-feature missing-value strategies configured by default.
            missing_value_strategies: Vec::new(),
            canary_filter: CanaryFilter::default(),
        }
    }
}
94
/// Errors produced while training a classification tree.
#[derive(Debug)]
pub enum DecisionTreeError {
    /// The training table contained no rows.
    EmptyTarget,
    /// A target value was not finite and cannot serve as a class label.
    InvalidTargetValue { row: usize, value: f64 },
}
100
101impl Display for DecisionTreeError {
102    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
103        match self {
104            DecisionTreeError::EmptyTarget => write!(f, "Cannot train on an empty target vector."),
105            DecisionTreeError::InvalidTargetValue { row, value } => write!(
106                f,
107                "Classification targets must be finite values. Found {} at row {}.",
108                value, row
109            ),
110        }
111    }
112}
113
114impl Error for DecisionTreeError {}
115
/// Concrete trained classification tree.
#[derive(Debug, Clone)]
pub struct DecisionTreeClassifier {
    /// Tree family used at training time.
    algorithm: DecisionTreeAlgorithm,
    /// Impurity criterion used for split scoring.
    criterion: Criterion,
    /// Distinct target values; leaves store indices into this table.
    class_labels: Vec<f64>,
    /// Trained node/level layout.
    structure: TreeStructure,
    /// Options the tree was trained with (echoed into training metadata).
    options: DecisionTreeOptions,
    /// Feature count of the training table.
    num_features: usize,
    /// Per-feature preprocessing captured from the training table; consumed
    /// when lowering splits to IR.
    feature_preprocessing: Vec<FeaturePreprocessing>,
    /// Canary count reported by the training table, exported in metadata.
    training_canaries: usize,
}
128
/// Internal layout of a trained tree.
#[derive(Debug, Clone)]
pub(crate) enum TreeStructure {
    /// Node-based tree: traversal starts at `nodes[root]` and children are
    /// referenced by index into `nodes`.
    Standard {
        nodes: Vec<TreeNode>,
        root: usize,
    },
    /// Oblivious tree: one split per level. A row's leaf index is built by
    /// shifting in one bit per level (level 0 is the most significant bit),
    /// so the three leaf vectors are indexed by that composed value.
    Oblivious {
        splits: Vec<ObliviousSplit>,
        leaf_class_indices: Vec<usize>,
        leaf_sample_counts: Vec<usize>,
        leaf_class_counts: Vec<Vec<usize>>,
    },
}
142
/// One level of an oblivious tree: the same binned threshold test is applied
/// to every row reaching that depth.
#[derive(Debug, Clone)]
pub(crate) struct ObliviousSplit {
    /// Feature whose binned value is compared at this level.
    pub(crate) feature_index: usize,
    /// Rows with `bin > threshold_bin` go right (set this level's bit).
    pub(crate) threshold_bin: u16,
    /// Missing-value routing recorded at training time. NOTE(review): the
    /// predict paths never consult this (they read `binned_value` directly),
    /// hence the dead_code allowance — confirm this is intentional.
    #[allow(dead_code)]
    pub(crate) missing_directions: Vec<MissingBranchDirection>,
    /// Rows that reached this level during training.
    pub(crate) sample_count: usize,
    pub(crate) impurity: f64,
    pub(crate) gain: f64,
}
153
/// A node of a `TreeStructure::Standard` tree. Child references are indices
/// into the owning `nodes` vector.
#[derive(Debug, Clone)]
pub(crate) enum TreeNode {
    /// Terminal node predicting `class_labels[class_index]`.
    Leaf {
        class_index: usize,
        sample_count: usize,
        class_counts: Vec<usize>,
    },
    /// ID3/C4.5-style split with one child per observed bin value.
    MultiwaySplit {
        feature_index: usize,
        /// Class predicted when a row's bin matches no branch, or the value
        /// is missing and no `missing_child` was trained.
        fallback_class_index: usize,
        /// `(bin, child_index)` pairs; searched linearly at predict time.
        branches: Vec<(u16, usize)>,
        /// Dedicated child for missing values, if one was trained.
        missing_child: Option<usize>,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        class_counts: Vec<usize>,
    },
    /// CART-style split: rows with `bin <= threshold_bin` descend left.
    BinarySplit {
        feature_index: usize,
        threshold_bin: u16,
        /// Where missing values are routed; `Node` short-circuits to this
        /// node's majority class instead of descending.
        missing_direction: MissingBranchDirection,
        left_child: usize,
        right_child: usize,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        class_counts: Vec<usize>,
    },
}
183
/// Best binary-split candidate found for a node during split search.
/// NOTE(review): produced/consumed by code outside this chunk (presumably
/// `split_scoring`); scoring semantics of `score` not visible here.
#[derive(Debug, Clone, Copy)]
struct BinarySplitChoice {
    feature_index: usize,
    score: f64,
    threshold_bin: u16,
    missing_direction: MissingBranchDirection,
}
191
/// Best multiway-split candidate found for a node during split search.
/// NOTE(review): produced/consumed by code outside this chunk (presumably
/// `split_scoring`); scoring semantics of `score` not visible here.
#[derive(Debug, Clone)]
struct MultiwaySplitChoice {
    feature_index: usize,
    score: f64,
    /// Bin values that receive their own branch.
    branch_bins: Vec<u16>,
    /// Bin chosen to absorb missing values, if any.
    missing_branch_bin: Option<u16>,
}
199
/// Trains an ID3 tree (multiway splits) with the entropy criterion, sequential
/// execution, and default options.
pub fn train_id3(train_set: &dyn TableAccess) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion(train_set, Criterion::Entropy)
}

/// Trains an ID3 tree with an explicit impurity criterion (sequential,
/// default options).
pub fn train_id3_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal ID3 entry point that also accepts a parallelism policy.
pub(crate) fn train_id3_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_id3_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized ID3 entry point; delegates to the shared driver.
pub(crate) fn train_id3_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Id3,
        criterion,
        parallelism,
        options,
    )
}
238
/// Trains a C4.5 tree (multiway splits) with the entropy criterion, sequential
/// execution, and default options.
pub fn train_c45(train_set: &dyn TableAccess) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion(train_set, Criterion::Entropy)
}

/// Trains a C4.5 tree with an explicit impurity criterion (sequential,
/// default options).
pub fn train_c45_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal C4.5 entry point that also accepts a parallelism policy.
pub(crate) fn train_c45_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_c45_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized C4.5 entry point; delegates to the shared driver.
pub(crate) fn train_c45_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::C45,
        criterion,
        parallelism,
        options,
    )
}
277
/// Trains a CART tree (binary threshold splits) with the Gini criterion,
/// sequential execution, and default options.
pub fn train_cart(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion(train_set, Criterion::Gini)
}

/// Trains a CART tree with an explicit impurity criterion (sequential,
/// default options).
pub fn train_cart_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal CART entry point that also accepts a parallelism policy.
pub(crate) fn train_cart_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_cart_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized CART entry point; delegates to the shared driver.
pub(crate) fn train_cart_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Cart,
        criterion,
        parallelism,
        options,
    )
}
318
/// Trains an oblivious tree (one shared split per depth) with the Gini
/// criterion, sequential execution, and default options.
pub fn train_oblivious(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion(train_set, Criterion::Gini)
}

/// Trains an oblivious tree with an explicit impurity criterion (sequential,
/// default options).
pub fn train_oblivious_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal oblivious entry point that also accepts a parallelism policy.
pub(crate) fn train_oblivious_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_oblivious_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized oblivious entry point; delegates to the shared driver.
pub(crate) fn train_oblivious_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Oblivious,
        criterion,
        parallelism,
        options,
    )
}
359
/// Trains a randomized CART-shaped tree with the Gini criterion, sequential
/// execution, and default options.
pub fn train_randomized(
    train_set: &dyn TableAccess,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion(train_set, Criterion::Gini)
}

/// Trains a randomized tree with an explicit impurity criterion (sequential,
/// default options).
pub fn train_randomized_with_criterion(
    train_set: &dyn TableAccess,
    criterion: Criterion,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
}

/// Crate-internal randomized entry point that also accepts a parallelism
/// policy.
pub(crate) fn train_randomized_with_criterion_and_parallelism(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_randomized_with_criterion_parallelism_and_options(
        train_set,
        criterion,
        parallelism,
        DecisionTreeOptions::default(),
    )
}

/// Fully-parameterized randomized entry point; delegates to the shared driver.
pub(crate) fn train_randomized_with_criterion_parallelism_and_options(
    train_set: &dyn TableAccess,
    criterion: Criterion,
    parallelism: Parallelism,
    options: DecisionTreeOptions,
) -> Result<DecisionTreeClassifier, DecisionTreeError> {
    train_classifier(
        train_set,
        DecisionTreeAlgorithm::Randomized,
        criterion,
        parallelism,
        options,
    )
}
400
401fn train_classifier(
402    train_set: &dyn TableAccess,
403    algorithm: DecisionTreeAlgorithm,
404    criterion: Criterion,
405    parallelism: Parallelism,
406    options: DecisionTreeOptions,
407) -> Result<DecisionTreeClassifier, DecisionTreeError> {
408    if train_set.n_rows() == 0 {
409        return Err(DecisionTreeError::EmptyTarget);
410    }
411
412    let (class_labels, class_indices) = encode_class_labels(train_set)?;
413    let structure = match algorithm {
414        DecisionTreeAlgorithm::Oblivious => train_oblivious_structure(
415            train_set,
416            &class_indices,
417            &class_labels,
418            criterion,
419            parallelism,
420            options.clone(),
421        ),
422        DecisionTreeAlgorithm::Cart | DecisionTreeAlgorithm::Randomized => {
423            let mut nodes = Vec::new();
424            let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
425            let context = BuildContext {
426                table: train_set,
427                class_indices: &class_indices,
428                class_labels: &class_labels,
429                algorithm,
430                criterion,
431                parallelism,
432                options: options.clone(),
433            };
434            let root = build_binary_node_in_place(&context, &mut nodes, &mut all_rows, 0);
435            TreeStructure::Standard { nodes, root }
436        }
437        DecisionTreeAlgorithm::Id3 | DecisionTreeAlgorithm::C45 => {
438            let mut nodes = Vec::new();
439            let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
440            let context = BuildContext {
441                table: train_set,
442                class_indices: &class_indices,
443                class_labels: &class_labels,
444                algorithm,
445                criterion,
446                parallelism,
447                options: options.clone(),
448            };
449            let root = build_multiway_node_in_place(&context, &mut nodes, &mut all_rows, 0);
450            TreeStructure::Standard { nodes, root }
451        }
452    };
453
454    Ok(DecisionTreeClassifier {
455        algorithm,
456        criterion,
457        class_labels,
458        structure,
459        options,
460        num_features: train_set.n_features(),
461        feature_preprocessing: capture_feature_preprocessing(train_set),
462        training_canaries: train_set.canaries(),
463    })
464}
465
466impl DecisionTreeClassifier {
    /// The tree family this classifier was trained with.
    pub fn algorithm(&self) -> DecisionTreeAlgorithm {
        self.algorithm
    }
470
    /// The impurity criterion used during training.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }
474
475    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
476        (0..table.n_rows())
477            .map(|row_idx| self.predict_row(table, row_idx))
478            .collect()
479    }
480
481    pub fn predict_proba_table(&self, table: &dyn TableAccess) -> Vec<Vec<f64>> {
482        (0..table.n_rows())
483            .map(|row_idx| self.predict_proba_row(table, row_idx))
484            .collect()
485    }
486
    /// Routes one row through the trained structure and returns its class label.
    fn predict_row(&self, table: &dyn TableAccess, row_idx: usize) -> f64 {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                // Iterative descent: follow child indices until a terminal
                // answer is produced.
                let mut node_index = *root;
                loop {
                    match &nodes[node_index] {
                        TreeNode::Leaf { class_index, .. } => {
                            return self.class_labels[*class_index];
                        }
                        TreeNode::MultiwaySplit {
                            feature_index,
                            fallback_class_index,
                            branches,
                            missing_child,
                            ..
                        } => {
                            // Missing value: descend into the dedicated child
                            // if one was trained, otherwise answer with the
                            // node's fallback class.
                            if table.is_missing(*feature_index, row_idx) {
                                if let Some(child_index) = missing_child {
                                    node_index = *child_index;
                                } else {
                                    return self.class_labels[*fallback_class_index];
                                }
                                continue;
                            }
                            // Linear scan over the (bin, child) branch list;
                            // unseen bins fall back to the node's class.
                            let bin = table.binned_value(*feature_index, row_idx);
                            if let Some((_, child_index)) =
                                branches.iter().find(|(branch_bin, _)| *branch_bin == bin)
                            {
                                node_index = *child_index;
                            } else {
                                return self.class_labels[*fallback_class_index];
                            }
                        }
                        TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            missing_direction,
                            left_child,
                            right_child,
                            class_counts,
                            ..
                        } => {
                            // Missing value: follow the trained direction, or
                            // stop at this node's majority class for `Node`.
                            if table.is_missing(*feature_index, row_idx) {
                                match missing_direction {
                                    MissingBranchDirection::Left => {
                                        node_index = *left_child;
                                    }
                                    MissingBranchDirection::Right => {
                                        node_index = *right_child;
                                    }
                                    MissingBranchDirection::Node => {
                                        return self.class_labels
                                            [majority_class_from_counts(class_counts)];
                                    }
                                }
                                continue;
                            }
                            // Threshold test is inclusive on the left side.
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_indices,
                ..
            } => {
                // Compose the leaf index one bit per level, msb-first (the
                // level-0 outcome is the most significant bit), matching the
                // `LeafIndexing` exported by `to_ir_tree`.
                // NOTE(review): this path reads `binned_value` without an
                // `is_missing` check and `ObliviousSplit::missing_directions`
                // is unused — confirm missing-value handling is intentional.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                self.class_labels[leaf_class_indices[leaf_index]]
            }
        }
    }
569
    /// Routes one row through the trained structure and returns normalized
    /// per-class probabilities (derived from stored class counts).
    ///
    /// Traversal mirrors `predict_row`; where that method answers with a
    /// class label, this one answers with the corresponding node's
    /// normalized `class_counts`.
    fn predict_proba_row(&self, table: &dyn TableAccess, row_idx: usize) -> Vec<f64> {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;
                loop {
                    match &nodes[node_index] {
                        TreeNode::Leaf { class_counts, .. } => {
                            return normalized_class_probabilities(class_counts);
                        }
                        TreeNode::MultiwaySplit {
                            feature_index,
                            branches,
                            missing_child,
                            class_counts,
                            ..
                        } => {
                            // Missing value: dedicated child if trained, else
                            // this node's own class distribution.
                            if table.is_missing(*feature_index, row_idx) {
                                if let Some(child_index) = missing_child {
                                    node_index = *child_index;
                                } else {
                                    return normalized_class_probabilities(class_counts);
                                }
                                continue;
                            }
                            // Unseen bins answer with this node's distribution.
                            let bin = table.binned_value(*feature_index, row_idx);
                            if let Some((_, child_index)) =
                                branches.iter().find(|(branch_bin, _)| *branch_bin == bin)
                            {
                                node_index = *child_index;
                            } else {
                                return normalized_class_probabilities(class_counts);
                            }
                        }
                        TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            missing_direction,
                            left_child,
                            right_child,
                            class_counts,
                            ..
                        } => {
                            // Missing value: follow the trained direction, or
                            // stop at this node's distribution for `Node`.
                            if table.is_missing(*feature_index, row_idx) {
                                match missing_direction {
                                    MissingBranchDirection::Left => {
                                        node_index = *left_child;
                                    }
                                    MissingBranchDirection::Right => {
                                        node_index = *right_child;
                                    }
                                    MissingBranchDirection::Node => {
                                        return normalized_class_probabilities(class_counts);
                                    }
                                }
                                continue;
                            }
                            // Threshold test is inclusive on the left side.
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_counts,
                ..
            } => {
                // Same msb-first leaf-index composition as `predict_row`.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                normalized_class_probabilities(&leaf_class_counts[leaf_index])
            }
        }
    }
651
    /// Distinct class labels; leaf `class_index` values index into this slice.
    pub(crate) fn class_labels(&self) -> &[f64] {
        &self.class_labels
    }
655
    /// Read access to the trained internal structure.
    pub(crate) fn structure(&self) -> &TreeStructure {
        &self.structure
    }
659
    /// Feature count of the training table.
    pub(crate) fn num_features(&self) -> usize {
        self.num_features
    }
663
    /// Per-feature preprocessing captured at training time (used by IR export).
    pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }
667
    /// Builds the serializable training metadata for this single tree.
    ///
    /// Single trees are exported under the umbrella algorithm tag `"dt"`;
    /// ensemble-only fields (`n_trees`, `seed`, `oob_score`, `learning_rate`,
    /// `bootstrap`, gradient fractions) are left unset.
    pub(crate) fn training_metadata(&self) -> TrainingMetadata {
        TrainingMetadata {
            algorithm: "dt".to_string(),
            task: "classification".to_string(),
            // Map the internal algorithm enum onto the public `TreeType` name.
            tree_type: tree_type_name(match self.algorithm {
                DecisionTreeAlgorithm::Id3 => crate::TreeType::Id3,
                DecisionTreeAlgorithm::C45 => crate::TreeType::C45,
                DecisionTreeAlgorithm::Cart => crate::TreeType::Cart,
                DecisionTreeAlgorithm::Randomized => crate::TreeType::Randomized,
                DecisionTreeAlgorithm::Oblivious => crate::TreeType::Oblivious,
            })
            .to_string(),
            criterion: criterion_name(self.criterion).to_string(),
            canaries: self.training_canaries,
            compute_oob: false,
            max_depth: Some(self.options.max_depth),
            min_samples_split: Some(self.options.min_samples_split),
            min_samples_leaf: Some(self.options.min_samples_leaf),
            n_trees: None,
            max_features: self.options.max_features,
            seed: None,
            oob_score: None,
            class_labels: Some(self.class_labels.clone()),
            learning_rate: None,
            bootstrap: None,
            top_gradient_fraction: None,
            other_gradient_fraction: None,
        }
    }
697
    /// Lowers the trained structure into the serializable IR tree.
    ///
    /// Standard trees become a `NodeTree` whose node ids are the indices into
    /// the internal `nodes` vector (so child references carry over unchanged);
    /// oblivious trees become `ObliviousLevels` with explicit leaf-indexing
    /// metadata.
    pub(crate) fn to_ir_tree(&self) -> TreeDefinition {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                // Depths are recomputed from the stored topology; they are not
                // persisted on the nodes themselves.
                let depths = standard_node_depths(nodes, *root);
                TreeDefinition::NodeTree {
                    tree_id: 0,
                    weight: 1.0,
                    root_node_id: *root,
                    nodes: nodes
                        .iter()
                        .enumerate()
                        .map(|(node_id, node)| match node {
                            TreeNode::Leaf {
                                class_index,
                                sample_count,
                                class_counts,
                            } => NodeTreeNode::Leaf {
                                node_id,
                                depth: depths[node_id],
                                leaf: self.class_leaf(*class_index),
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: None,
                                    gain: None,
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                            TreeNode::BinarySplit {
                                feature_index,
                                threshold_bin,
                                missing_direction,
                                left_child,
                                right_child,
                                sample_count,
                                impurity,
                                gain,
                                class_counts,
                            } => NodeTreeNode::BinaryBranch {
                                node_id,
                                depth: depths[node_id],
                                // Threshold/missing handling is translated via
                                // the captured feature preprocessing.
                                split: binary_split_ir(
                                    *feature_index,
                                    *threshold_bin,
                                    *missing_direction,
                                    &self.feature_preprocessing,
                                ),
                                children: BinaryChildren {
                                    left: *left_child,
                                    right: *right_child,
                                },
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                            TreeNode::MultiwaySplit {
                                feature_index,
                                fallback_class_index,
                                branches,
                                // NOTE(review): a trained `missing_child` is
                                // deliberately dropped here; in IR form missing
                                // rows hit `unmatched_leaf` instead — confirm
                                // this lowering is intentional.
                                missing_child: _,
                                sample_count,
                                impurity,
                                gain,
                                class_counts,
                            } => NodeTreeNode::MultiwayBranch {
                                node_id,
                                depth: depths[node_id],
                                split: MultiwaySplit {
                                    split_type: "binned_value_multiway".to_string(),
                                    feature_index: *feature_index,
                                    feature_name: feature_name(*feature_index),
                                    comparison_dtype: "uint16".to_string(),
                                },
                                branches: branches
                                    .iter()
                                    .map(|(bin, child)| MultiwayBranch {
                                        bin: *bin,
                                        child: *child,
                                    })
                                    .collect(),
                                unmatched_leaf: self.class_leaf(*fallback_class_index),
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                        })
                        .collect(),
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_indices,
                leaf_sample_counts,
                leaf_class_counts,
            } => TreeDefinition::ObliviousLevels {
                tree_id: 0,
                weight: 1.0,
                depth: splits.len(),
                levels: splits
                    .iter()
                    .enumerate()
                    .map(|(level, split)| ObliviousLevel {
                        level,
                        split: oblivious_split_ir(
                            split.feature_index,
                            split.threshold_bin,
                            &self.feature_preprocessing,
                        ),
                        stats: NodeStats {
                            sample_count: split.sample_count,
                            impurity: Some(split.impurity),
                            gain: Some(split.gain),
                            class_counts: None,
                            variance: None,
                        },
                    })
                    .collect(),
                // Matches the fold in `predict_row`: the level-0 outcome is
                // the most significant bit of the leaf index.
                leaf_indexing: LeafIndexing {
                    bit_order: "msb_first".to_string(),
                    index_formula: "sum(bit[level] << (depth - 1 - level))".to_string(),
                },
                leaves: leaf_class_indices
                    .iter()
                    .enumerate()
                    .map(|(leaf_index, class_index)| IndexedLeaf {
                        leaf_index,
                        leaf: self.class_leaf(*class_index),
                        stats: NodeStats {
                            sample_count: leaf_sample_counts[leaf_index],
                            impurity: None,
                            gain: None,
                            class_counts: Some(leaf_class_counts[leaf_index].clone()),
                            variance: None,
                        },
                    })
                    .collect(),
            },
        }
    }
845
846    fn class_leaf(&self, class_index: usize) -> LeafPayload {
847        LeafPayload::ClassIndex {
848            class_index,
849            class_value: self.class_labels[class_index],
850        }
851    }
852
    /// Reassembles a trained classifier tree from its constituent parts.
    ///
    /// This is a plain field-for-field constructor with no validation, so the
    /// caller is responsible for supplying mutually consistent parts (e.g. a
    /// `structure` whose class indices fit `class_labels`).
    /// NOTE(review): the name suggests this sits on the IR-decode path, but
    /// the call site is outside this chunk — confirm before relying on that.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn from_ir_parts(
        algorithm: DecisionTreeAlgorithm,
        criterion: Criterion,
        class_labels: Vec<f64>,
        structure: TreeStructure,
        options: DecisionTreeOptions,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
        training_canaries: usize,
    ) -> Self {
        Self {
            algorithm,
            criterion,
            class_labels,
            structure,
            options,
            num_features,
            feature_preprocessing,
            training_canaries,
        }
    }
875}
876
/// Entry point for growing a CART-style binary subtree over `rows`.
///
/// Delegates to [`build_binary_node_in_place_with_hist`] with no cached
/// histograms, so the callee rebuilds the per-feature class histograms for
/// this node from scratch. Returns the index of the subtree root in `nodes`.
fn build_binary_node_in_place(
    context: &BuildContext<'_>,
    nodes: &mut Vec<TreeNode>,
    rows: &mut [usize],
    depth: usize,
) -> usize {
    build_binary_node_in_place_with_hist(context, nodes, rows, depth, None)
}
885
886fn build_binary_node_in_place_with_hist(
887    context: &BuildContext<'_>,
888    nodes: &mut Vec<TreeNode>,
889    rows: &mut [usize],
890    depth: usize,
891    histograms: Option<Vec<ClassificationFeatureHistogram>>,
892) -> usize {
893    let majority_class_index =
894        majority_class(rows, context.class_indices, context.class_labels.len());
895    let current_class_counts =
896        class_counts(rows, context.class_indices, context.class_labels.len());
897
898    if rows.is_empty()
899        || depth >= context.options.max_depth
900        || rows.len() < context.options.min_samples_split
901        || is_pure(rows, context.class_indices)
902    {
903        return push_leaf(
904            nodes,
905            majority_class_index,
906            rows.len(),
907            current_class_counts,
908        );
909    }
910
911    let scoring = SplitScoringContext {
912        table: context.table,
913        class_indices: context.class_indices,
914        num_classes: context.class_labels.len(),
915        criterion: context.criterion,
916        min_samples_leaf: context.options.min_samples_leaf,
917        missing_value_strategies: &context.options.missing_value_strategies,
918    };
919    let histograms = histograms.unwrap_or_else(|| {
920        build_classification_node_histograms(
921            context.table,
922            context.class_indices,
923            rows,
924            context.class_labels.len(),
925        )
926    });
927    let feature_indices = candidate_feature_indices(
928        context.table,
929        context.options.max_features,
930        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
931    );
932    let split_candidates = if context.parallelism.enabled() {
933        feature_indices
934            .into_par_iter()
935            .filter_map(|feature_index| {
936                score_binary_split_choice_from_hist(
937                    &scoring,
938                    &histograms[feature_index],
939                    feature_index,
940                    rows,
941                    &current_class_counts,
942                    context.algorithm,
943                )
944            })
945            .collect::<Vec<_>>()
946    } else {
947        feature_indices
948            .into_iter()
949            .filter_map(|feature_index| {
950                score_binary_split_choice_from_hist(
951                    &scoring,
952                    &histograms[feature_index],
953                    feature_index,
954                    rows,
955                    &current_class_counts,
956                    context.algorithm,
957                )
958            })
959            .collect::<Vec<_>>()
960    };
961    let best_split = select_best_non_canary_candidate(
962        context.table,
963        split_candidates,
964        context.options.canary_filter,
965        |candidate| candidate.score,
966        |candidate| candidate.feature_index,
967    )
968    .selected;
969
970    match best_split {
971        Some(best_split) if best_split.score > 0.0 => {
972            let impurity =
973                classification_impurity(&current_class_counts, rows.len(), context.criterion);
974            let left_count = partition_rows_for_binary_split(
975                context.table,
976                best_split.feature_index,
977                best_split.threshold_bin,
978                best_split.missing_direction,
979                rows,
980            );
981            let (left_rows, right_rows) = rows.split_at_mut(left_count);
982            let (left_histograms, right_histograms) = if left_rows.len() <= right_rows.len() {
983                let left_histograms = build_classification_node_histograms(
984                    context.table,
985                    context.class_indices,
986                    left_rows,
987                    context.class_labels.len(),
988                );
989                let right_histograms =
990                    subtract_classification_node_histograms(&histograms, &left_histograms);
991                (left_histograms, right_histograms)
992            } else {
993                let right_histograms = build_classification_node_histograms(
994                    context.table,
995                    context.class_indices,
996                    right_rows,
997                    context.class_labels.len(),
998                );
999                let left_histograms =
1000                    subtract_classification_node_histograms(&histograms, &right_histograms);
1001                (left_histograms, right_histograms)
1002            };
1003            let left_child = build_binary_node_in_place_with_hist(
1004                context,
1005                nodes,
1006                left_rows,
1007                depth + 1,
1008                Some(left_histograms),
1009            );
1010            let right_child = build_binary_node_in_place_with_hist(
1011                context,
1012                nodes,
1013                right_rows,
1014                depth + 1,
1015                Some(right_histograms),
1016            );
1017
1018            push_node(
1019                nodes,
1020                TreeNode::BinarySplit {
1021                    feature_index: best_split.feature_index,
1022                    threshold_bin: best_split.threshold_bin,
1023                    missing_direction: best_split.missing_direction,
1024                    left_child,
1025                    right_child,
1026                    sample_count: rows.len(),
1027                    impurity,
1028                    gain: best_split.score,
1029                    class_counts: current_class_counts,
1030                },
1031            )
1032        }
1033        _ => push_leaf(
1034            nodes,
1035            majority_class_index,
1036            rows.len(),
1037            current_class_counts,
1038        ),
1039    }
1040}
1041
1042fn build_multiway_node_in_place(
1043    context: &BuildContext<'_>,
1044    nodes: &mut Vec<TreeNode>,
1045    rows: &mut [usize],
1046    depth: usize,
1047) -> usize {
1048    let majority_class_index =
1049        majority_class(rows, context.class_indices, context.class_labels.len());
1050    let current_class_counts =
1051        class_counts(rows, context.class_indices, context.class_labels.len());
1052
1053    if rows.is_empty()
1054        || depth >= context.options.max_depth
1055        || rows.len() < context.options.min_samples_split
1056        || is_pure(rows, context.class_indices)
1057    {
1058        return push_leaf(
1059            nodes,
1060            majority_class_index,
1061            rows.len(),
1062            current_class_counts,
1063        );
1064    }
1065
1066    let metric = match context.algorithm {
1067        DecisionTreeAlgorithm::Id3 => MultiwayMetric::InformationGain,
1068        DecisionTreeAlgorithm::C45 => MultiwayMetric::GainRatio,
1069        _ => unreachable!("multiway builder only supports id3/c45"),
1070    };
1071    let scoring = SplitScoringContext {
1072        table: context.table,
1073        class_indices: context.class_indices,
1074        num_classes: context.class_labels.len(),
1075        criterion: context.criterion,
1076        min_samples_leaf: context.options.min_samples_leaf,
1077        missing_value_strategies: &context.options.missing_value_strategies,
1078    };
1079    let feature_indices = candidate_feature_indices(
1080        context.table,
1081        context.options.max_features,
1082        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
1083    );
1084    let split_candidates = if context.parallelism.enabled() {
1085        feature_indices
1086            .into_par_iter()
1087            .filter_map(|feature_index| {
1088                score_multiway_split_choice(&scoring, feature_index, rows, metric)
1089            })
1090            .collect::<Vec<_>>()
1091    } else {
1092        feature_indices
1093            .into_iter()
1094            .filter_map(|feature_index| {
1095                score_multiway_split_choice(&scoring, feature_index, rows, metric)
1096            })
1097            .collect::<Vec<_>>()
1098    };
1099    let best_split = select_best_non_canary_candidate(
1100        context.table,
1101        split_candidates,
1102        context.options.canary_filter,
1103        |candidate| candidate.score,
1104        |candidate| candidate.feature_index,
1105    )
1106    .selected;
1107
1108    match best_split {
1109        Some(best_split) if best_split.score > 0.0 => {
1110            let impurity =
1111                classification_impurity(&current_class_counts, rows.len(), context.criterion);
1112            let branch_ranges = partition_rows_for_multiway_split(
1113                context.table,
1114                best_split.feature_index,
1115                &best_split.branch_bins,
1116                best_split.missing_branch_bin,
1117                rows,
1118            );
1119            let mut branch_nodes = Vec::with_capacity(branch_ranges.len());
1120            let mut missing_child = None;
1121            for (bin, start, end) in branch_ranges {
1122                let child =
1123                    build_multiway_node_in_place(context, nodes, &mut rows[start..end], depth + 1);
1124                if best_split.missing_branch_bin == Some(bin) {
1125                    missing_child = Some(child);
1126                }
1127                branch_nodes.push((bin, child));
1128            }
1129
1130            push_node(
1131                nodes,
1132                TreeNode::MultiwaySplit {
1133                    feature_index: best_split.feature_index,
1134                    fallback_class_index: majority_class_index,
1135                    branches: branch_nodes,
1136                    missing_child,
1137                    sample_count: rows.len(),
1138                    impurity,
1139                    gain: best_split.score,
1140                    class_counts: current_class_counts,
1141                },
1142            )
1143        }
1144        _ => push_leaf(
1145            nodes,
1146            majority_class_index,
1147            rows.len(),
1148            current_class_counts,
1149        ),
1150    }
1151}
1152
/// Shared, read-only state threaded through the recursive node builders.
struct BuildContext<'a> {
    // Binned feature table the tree is trained on.
    table: &'a dyn TableAccess,
    // Per-row encoded class index (a position into `class_labels`).
    class_indices: &'a [usize],
    // Sorted vocabulary of distinct target values.
    class_labels: &'a [f64],
    // Which tree family is being trained; passed into split scoring and
    // matched on by the multiway builder.
    algorithm: DecisionTreeAlgorithm,
    // Impurity criterion (gini or entropy for classification).
    criterion: Criterion,
    // Gates whether split scoring runs on rayon's parallel iterators.
    parallelism: Parallelism,
    // Training hyperparameters: depth/sample limits, feature sampling,
    // random seed, canary filter, missing-value strategies.
    options: DecisionTreeOptions,
}
1162
1163fn encode_class_labels(
1164    train_set: &dyn TableAccess,
1165) -> Result<(Vec<f64>, Vec<usize>), DecisionTreeError> {
1166    let targets: Vec<f64> = (0..train_set.n_rows())
1167        .map(|row_idx| {
1168            let value = train_set.target_value(row_idx);
1169            if value.is_finite() {
1170                Ok(value)
1171            } else {
1172                Err(DecisionTreeError::InvalidTargetValue {
1173                    row: row_idx,
1174                    value,
1175                })
1176            }
1177        })
1178        .collect::<Result<_, _>>()?;
1179
1180    let class_labels = targets
1181        .iter()
1182        .copied()
1183        .fold(Vec::<f64>::new(), |mut labels, value| {
1184            if labels
1185                .binary_search_by(|candidate| candidate.total_cmp(&value))
1186                .is_err()
1187            {
1188                labels.push(value);
1189                labels.sort_by(|left, right| left.total_cmp(right));
1190            }
1191            labels
1192        });
1193
1194    let class_indices = targets
1195        .iter()
1196        .map(|value| {
1197            class_labels
1198                .binary_search_by(|candidate| candidate.total_cmp(value))
1199                .expect("target value must exist in class label vocabulary")
1200        })
1201        .collect();
1202
1203    Ok((class_labels, class_indices))
1204}
1205
/// Per-class sample counts for the rows listed in `rows`.
fn class_counts(rows: &[usize], class_indices: &[usize], num_classes: usize) -> Vec<usize> {
    let mut counts = vec![0usize; num_classes];
    for &row_idx in rows {
        counts[class_indices[row_idx]] += 1;
    }
    counts
}
1213
1214fn majority_class(rows: &[usize], class_indices: &[usize], num_classes: usize) -> usize {
1215    majority_class_from_counts(&class_counts(rows, class_indices, num_classes))
1216}
1217
/// Index of the largest count; ties resolve to the lowest class index, and
/// an empty slice falls back to class 0.
fn majority_class_from_counts(counts: &[usize]) -> usize {
    let mut best: Option<(usize, usize)> = None;
    for (class_index, &count) in counts.iter().enumerate() {
        match best {
            // Keep the incumbent on ties so the lowest index wins.
            Some((_, best_count)) if best_count >= count => {}
            _ => best = Some((class_index, count)),
        }
    }
    best.map_or(0, |(class_index, _)| class_index)
}
1227
/// True when every row in `rows` carries the same class label; an empty
/// node counts as pure.
fn is_pure(rows: &[usize], class_indices: &[usize]) -> bool {
    let Some((&first_row, rest)) = rows.split_first() else {
        return true;
    };
    let reference = class_indices[first_row];
    rest.iter().all(|row_idx| class_indices[*row_idx] == reference)
}
1234
/// Shannon entropy (base 2) of the class distribution `counts` over `total`
/// samples; zero counts contribute nothing.
fn entropy(counts: &[usize], total: usize) -> f64 {
    let mut bits = 0.0;
    for &count in counts {
        if count == 0 {
            continue;
        }
        let probability = count as f64 / total as f64;
        bits -= probability * probability.log2();
    }
    bits
}
1246
/// Gini impurity of the class distribution `counts` over `total` samples.
fn gini(counts: &[usize], total: usize) -> f64 {
    let mut sum_of_squares = 0.0;
    for &count in counts {
        let probability = count as f64 / total as f64;
        sum_of_squares += probability * probability;
    }
    1.0 - sum_of_squares
}
1257
1258fn classification_impurity(counts: &[usize], total: usize, criterion: Criterion) -> f64 {
1259    match criterion {
1260        Criterion::Entropy => entropy(counts, total),
1261        Criterion::Gini => gini(counts, total),
1262        _ => unreachable!("classification impurity only supports gini or entropy"),
1263    }
1264}
1265
1266fn push_leaf(
1267    nodes: &mut Vec<TreeNode>,
1268    class_index: usize,
1269    sample_count: usize,
1270    class_counts: Vec<usize>,
1271) -> usize {
1272    push_node(
1273        nodes,
1274        TreeNode::Leaf {
1275            class_index,
1276            sample_count,
1277            class_counts,
1278        },
1279    )
1280}
1281
1282fn push_node(nodes: &mut Vec<TreeNode>, node: TreeNode) -> usize {
1283    nodes.push(node);
1284    nodes.len() - 1
1285}
1286
1287#[cfg(test)]
1288mod tests;