Skip to main content

forestfire_core/tree/
classifier.rs

1//! Classification tree learners.
2//!
3//! The module intentionally supports multiple tree families because they express
4//! different tradeoffs:
5//!
6//! - `id3` / `c45` keep multiway splits for categorical-like binned features.
7//! - `cart` is the standard binary threshold learner.
8//! - `randomized` keeps the CART structure but cheapens split search.
9//! - `oblivious` uses one split per depth, which is attractive for some runtime
10//!   layouts and boosting-style ensembles.
11//!
12//! The hot numeric paths are written around binned histograms and in-place row
13//! partitioning. That is why many helpers operate on row-index buffers instead of
14//! allocating fresh row vectors at every recursive step.
15
16use crate::ir::{
17    BinaryChildren, BinarySplit, IndexedLeaf, LeafIndexing, LeafPayload, MultiwayBranch,
18    MultiwaySplit, NodeStats, NodeTreeNode, ObliviousLevel, ObliviousSplit as IrObliviousSplit,
19    TrainingMetadata, TreeDefinition, criterion_name, feature_name, threshold_upper_bound,
20    tree_type_name,
21};
22use crate::tree::shared::{
23    candidate_feature_indices, choose_random_threshold, node_seed, partition_rows_for_binary_split,
24};
25use crate::{Criterion, FeaturePreprocessing, Parallelism, capture_feature_preprocessing};
26use forestfire_data::TableAccess;
27use rayon::prelude::*;
28use std::collections::BTreeMap;
29use std::error::Error;
30use std::fmt::{Display, Formatter};
31
32mod histogram;
33mod ir_support;
34mod oblivious;
35mod partitioning;
36mod split_scoring;
37
38use histogram::{
39    ClassificationFeatureHistogram, build_classification_node_histograms,
40    subtract_classification_node_histograms,
41};
42use ir_support::{
43    binary_split_ir, normalized_class_probabilities, oblivious_split_ir, standard_node_depths,
44};
45use oblivious::train_oblivious_structure;
46use partitioning::partition_rows_for_multiway_split;
47use split_scoring::{
48    MultiwayMetric, SplitScoringContext, score_binary_split_choice_from_hist,
49    score_multiway_split_choice,
50};
51
/// Selects which tree-growing family `train_classifier` dispatches to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecisionTreeAlgorithm {
    /// Multiway splits over binned feature values; entropy criterion by default.
    Id3,
    /// Multiway splits like ID3; entropy criterion by default.
    C45,
    /// Binary threshold splits; Gini criterion by default.
    Cart,
    /// CART-shaped binary tree with a cheapened split search; Gini by default.
    Randomized,
    /// One shared split per depth level; leaves addressed by the path bits.
    Oblivious,
}
60
/// Shared training controls for classification tree learners.
///
/// The defaults are intentionally modest rather than "grow until pure", because
/// ForestFire wants trees to be a stable building block for ensembles and
/// interpretable standalone models.
#[derive(Debug, Clone, Copy)]
pub struct DecisionTreeOptions {
    /// Depth at which a node is forced to become a leaf.
    pub max_depth: usize,
    /// Minimum number of rows a node needs before a split is attempted.
    pub min_samples_split: usize,
    /// Minimum rows required per child; enforced during split scoring.
    pub min_samples_leaf: usize,
    /// Per-node feature subsample size; `None` considers every feature.
    pub max_features: Option<usize>,
    /// Seed feeding per-node RNG (feature sampling / randomized thresholds).
    pub random_seed: u64,
}
74
75impl Default for DecisionTreeOptions {
76    fn default() -> Self {
77        Self {
78            max_depth: 8,
79            min_samples_split: 2,
80            min_samples_leaf: 1,
81            max_features: None,
82            random_seed: 0,
83        }
84    }
85}
86
/// Errors surfaced while training a classification tree.
#[derive(Debug)]
pub enum DecisionTreeError {
    /// The training table contained no rows.
    EmptyTarget,
    /// A target value was not usable as a class label (non-finite per `Display`).
    InvalidTargetValue { row: usize, value: f64 },
}
92
93impl Display for DecisionTreeError {
94    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
95        match self {
96            DecisionTreeError::EmptyTarget => write!(f, "Cannot train on an empty target vector."),
97            DecisionTreeError::InvalidTargetValue { row, value } => write!(
98                f,
99                "Classification targets must be finite values. Found {} at row {}.",
100                value, row
101            ),
102        }
103    }
104}
105
// Marker impl: the `Display` implementation above supplies the message.
impl Error for DecisionTreeError {}
107
/// Concrete trained classification tree.
#[derive(Debug, Clone)]
pub struct DecisionTreeClassifier {
    // Which learner family produced this tree.
    algorithm: DecisionTreeAlgorithm,
    // Impurity criterion used during training.
    criterion: Criterion,
    // Distinct target values; leaves store indices into this table.
    class_labels: Vec<f64>,
    // Standard node tree or oblivious level layout (see `TreeStructure`).
    structure: TreeStructure,
    // Growth controls captured for metadata/serialization.
    options: DecisionTreeOptions,
    num_features: usize,
    // Per-feature preprocessing captured from the training table.
    feature_preprocessing: Vec<FeaturePreprocessing>,
    // Number of canary features present at training time.
    training_canaries: usize,
}
120
/// Physical layout of a trained tree; both variants share one classifier type.
#[derive(Debug, Clone)]
pub(crate) enum TreeStructure {
    /// Pointer-style tree: traversal starts at `nodes[root]` and children are
    /// referenced by index into `nodes`.
    Standard {
        nodes: Vec<TreeNode>,
        root: usize,
    },
    /// Oblivious tree: one split per depth level. The three leaf arrays are
    /// parallel vectors indexed by the leaf index derived from the level-wise
    /// comparison bits (see `predict_row`).
    Oblivious {
        splits: Vec<ObliviousSplit>,
        leaf_class_indices: Vec<usize>,
        leaf_sample_counts: Vec<usize>,
        leaf_class_counts: Vec<Vec<usize>>,
    },
}
134
/// One depth level of an oblivious tree: a single (feature, threshold) pair
/// applied to every row reaching that level.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ObliviousSplit {
    pub(crate) feature_index: usize,
    /// Rows with `binned_value > threshold_bin` go right (contribute a 1 bit).
    pub(crate) threshold_bin: u16,
    /// Sample count recorded for this level's stats.
    pub(crate) sample_count: usize,
    pub(crate) impurity: f64,
    pub(crate) gain: f64,
}
143
/// A node in a `TreeStructure::Standard` tree; children are `nodes` indices.
#[derive(Debug, Clone)]
pub(crate) enum TreeNode {
    /// Terminal node predicting `class_labels[class_index]`.
    Leaf {
        class_index: usize,
        sample_count: usize,
        class_counts: Vec<usize>,
    },
    /// Multiway split (ID3/C45): each `(bin, child)` branch matches one binned
    /// value; unmatched bins fall back to `fallback_class_index`.
    MultiwaySplit {
        feature_index: usize,
        fallback_class_index: usize,
        branches: Vec<(u16, usize)>,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        class_counts: Vec<usize>,
    },
    /// Binary threshold split (CART/Randomized): `bin <= threshold_bin` goes
    /// to `left_child`, otherwise `right_child`.
    BinarySplit {
        feature_index: usize,
        threshold_bin: u16,
        left_child: usize,
        right_child: usize,
        sample_count: usize,
        impurity: f64,
        gain: f64,
        class_counts: Vec<usize>,
    },
}
171
/// Best binary-threshold candidate found for a single feature during split
/// search; candidates across features are compared by `score` (higher wins).
#[derive(Debug, Clone, Copy)]
struct BinarySplitChoice {
    feature_index: usize,
    score: f64,
    /// Rows with `bin <= threshold_bin` would go to the left child.
    threshold_bin: u16,
}
178
/// Best multiway candidate found for a single feature during split search.
#[derive(Debug, Clone)]
struct MultiwaySplitChoice {
    feature_index: usize,
    score: f64,
    /// Bins that become child branches of the multiway split.
    branch_bins: Vec<u16>,
}
185
186pub fn train_id3(train_set: &dyn TableAccess) -> Result<DecisionTreeClassifier, DecisionTreeError> {
187    train_id3_with_criterion(train_set, Criterion::Entropy)
188}
189
190pub fn train_id3_with_criterion(
191    train_set: &dyn TableAccess,
192    criterion: Criterion,
193) -> Result<DecisionTreeClassifier, DecisionTreeError> {
194    train_id3_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
195}
196
197pub(crate) fn train_id3_with_criterion_and_parallelism(
198    train_set: &dyn TableAccess,
199    criterion: Criterion,
200    parallelism: Parallelism,
201) -> Result<DecisionTreeClassifier, DecisionTreeError> {
202    train_id3_with_criterion_parallelism_and_options(
203        train_set,
204        criterion,
205        parallelism,
206        DecisionTreeOptions::default(),
207    )
208}
209
210pub(crate) fn train_id3_with_criterion_parallelism_and_options(
211    train_set: &dyn TableAccess,
212    criterion: Criterion,
213    parallelism: Parallelism,
214    options: DecisionTreeOptions,
215) -> Result<DecisionTreeClassifier, DecisionTreeError> {
216    train_classifier(
217        train_set,
218        DecisionTreeAlgorithm::Id3,
219        criterion,
220        parallelism,
221        options,
222    )
223}
224
225pub fn train_c45(train_set: &dyn TableAccess) -> Result<DecisionTreeClassifier, DecisionTreeError> {
226    train_c45_with_criterion(train_set, Criterion::Entropy)
227}
228
229pub fn train_c45_with_criterion(
230    train_set: &dyn TableAccess,
231    criterion: Criterion,
232) -> Result<DecisionTreeClassifier, DecisionTreeError> {
233    train_c45_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
234}
235
236pub(crate) fn train_c45_with_criterion_and_parallelism(
237    train_set: &dyn TableAccess,
238    criterion: Criterion,
239    parallelism: Parallelism,
240) -> Result<DecisionTreeClassifier, DecisionTreeError> {
241    train_c45_with_criterion_parallelism_and_options(
242        train_set,
243        criterion,
244        parallelism,
245        DecisionTreeOptions::default(),
246    )
247}
248
249pub(crate) fn train_c45_with_criterion_parallelism_and_options(
250    train_set: &dyn TableAccess,
251    criterion: Criterion,
252    parallelism: Parallelism,
253    options: DecisionTreeOptions,
254) -> Result<DecisionTreeClassifier, DecisionTreeError> {
255    train_classifier(
256        train_set,
257        DecisionTreeAlgorithm::C45,
258        criterion,
259        parallelism,
260        options,
261    )
262}
263
264pub fn train_cart(
265    train_set: &dyn TableAccess,
266) -> Result<DecisionTreeClassifier, DecisionTreeError> {
267    train_cart_with_criterion(train_set, Criterion::Gini)
268}
269
270pub fn train_cart_with_criterion(
271    train_set: &dyn TableAccess,
272    criterion: Criterion,
273) -> Result<DecisionTreeClassifier, DecisionTreeError> {
274    train_cart_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
275}
276
277pub(crate) fn train_cart_with_criterion_and_parallelism(
278    train_set: &dyn TableAccess,
279    criterion: Criterion,
280    parallelism: Parallelism,
281) -> Result<DecisionTreeClassifier, DecisionTreeError> {
282    train_cart_with_criterion_parallelism_and_options(
283        train_set,
284        criterion,
285        parallelism,
286        DecisionTreeOptions::default(),
287    )
288}
289
290pub(crate) fn train_cart_with_criterion_parallelism_and_options(
291    train_set: &dyn TableAccess,
292    criterion: Criterion,
293    parallelism: Parallelism,
294    options: DecisionTreeOptions,
295) -> Result<DecisionTreeClassifier, DecisionTreeError> {
296    train_classifier(
297        train_set,
298        DecisionTreeAlgorithm::Cart,
299        criterion,
300        parallelism,
301        options,
302    )
303}
304
305pub fn train_oblivious(
306    train_set: &dyn TableAccess,
307) -> Result<DecisionTreeClassifier, DecisionTreeError> {
308    train_oblivious_with_criterion(train_set, Criterion::Gini)
309}
310
311pub fn train_oblivious_with_criterion(
312    train_set: &dyn TableAccess,
313    criterion: Criterion,
314) -> Result<DecisionTreeClassifier, DecisionTreeError> {
315    train_oblivious_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
316}
317
318pub(crate) fn train_oblivious_with_criterion_and_parallelism(
319    train_set: &dyn TableAccess,
320    criterion: Criterion,
321    parallelism: Parallelism,
322) -> Result<DecisionTreeClassifier, DecisionTreeError> {
323    train_oblivious_with_criterion_parallelism_and_options(
324        train_set,
325        criterion,
326        parallelism,
327        DecisionTreeOptions::default(),
328    )
329}
330
331pub(crate) fn train_oblivious_with_criterion_parallelism_and_options(
332    train_set: &dyn TableAccess,
333    criterion: Criterion,
334    parallelism: Parallelism,
335    options: DecisionTreeOptions,
336) -> Result<DecisionTreeClassifier, DecisionTreeError> {
337    train_classifier(
338        train_set,
339        DecisionTreeAlgorithm::Oblivious,
340        criterion,
341        parallelism,
342        options,
343    )
344}
345
346pub fn train_randomized(
347    train_set: &dyn TableAccess,
348) -> Result<DecisionTreeClassifier, DecisionTreeError> {
349    train_randomized_with_criterion(train_set, Criterion::Gini)
350}
351
352pub fn train_randomized_with_criterion(
353    train_set: &dyn TableAccess,
354    criterion: Criterion,
355) -> Result<DecisionTreeClassifier, DecisionTreeError> {
356    train_randomized_with_criterion_and_parallelism(train_set, criterion, Parallelism::sequential())
357}
358
359pub(crate) fn train_randomized_with_criterion_and_parallelism(
360    train_set: &dyn TableAccess,
361    criterion: Criterion,
362    parallelism: Parallelism,
363) -> Result<DecisionTreeClassifier, DecisionTreeError> {
364    train_randomized_with_criterion_parallelism_and_options(
365        train_set,
366        criterion,
367        parallelism,
368        DecisionTreeOptions::default(),
369    )
370}
371
372pub(crate) fn train_randomized_with_criterion_parallelism_and_options(
373    train_set: &dyn TableAccess,
374    criterion: Criterion,
375    parallelism: Parallelism,
376    options: DecisionTreeOptions,
377) -> Result<DecisionTreeClassifier, DecisionTreeError> {
378    train_classifier(
379        train_set,
380        DecisionTreeAlgorithm::Randomized,
381        criterion,
382        parallelism,
383        options,
384    )
385}
386
387fn train_classifier(
388    train_set: &dyn TableAccess,
389    algorithm: DecisionTreeAlgorithm,
390    criterion: Criterion,
391    parallelism: Parallelism,
392    options: DecisionTreeOptions,
393) -> Result<DecisionTreeClassifier, DecisionTreeError> {
394    if train_set.n_rows() == 0 {
395        return Err(DecisionTreeError::EmptyTarget);
396    }
397
398    let (class_labels, class_indices) = encode_class_labels(train_set)?;
399    let structure = match algorithm {
400        DecisionTreeAlgorithm::Oblivious => train_oblivious_structure(
401            train_set,
402            &class_indices,
403            &class_labels,
404            criterion,
405            parallelism,
406            options,
407        ),
408        DecisionTreeAlgorithm::Cart | DecisionTreeAlgorithm::Randomized => {
409            let mut nodes = Vec::new();
410            let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
411            let context = BuildContext {
412                table: train_set,
413                class_indices: &class_indices,
414                class_labels: &class_labels,
415                algorithm,
416                criterion,
417                parallelism,
418                options,
419            };
420            let root = build_binary_node_in_place(&context, &mut nodes, &mut all_rows, 0);
421            TreeStructure::Standard { nodes, root }
422        }
423        DecisionTreeAlgorithm::Id3 | DecisionTreeAlgorithm::C45 => {
424            let mut nodes = Vec::new();
425            let mut all_rows: Vec<usize> = (0..train_set.n_rows()).collect();
426            let context = BuildContext {
427                table: train_set,
428                class_indices: &class_indices,
429                class_labels: &class_labels,
430                algorithm,
431                criterion,
432                parallelism,
433                options,
434            };
435            let root = build_multiway_node_in_place(&context, &mut nodes, &mut all_rows, 0);
436            TreeStructure::Standard { nodes, root }
437        }
438    };
439
440    Ok(DecisionTreeClassifier {
441        algorithm,
442        criterion,
443        class_labels,
444        structure,
445        options,
446        num_features: train_set.n_features(),
447        feature_preprocessing: capture_feature_preprocessing(train_set),
448        training_canaries: train_set.canaries(),
449    })
450}
451
impl DecisionTreeClassifier {
    /// Algorithm family this tree was trained with.
    pub fn algorithm(&self) -> DecisionTreeAlgorithm {
        self.algorithm
    }

    /// Impurity criterion used during training.
    pub fn criterion(&self) -> Criterion {
        self.criterion
    }

    /// Predicts a class label for every row of `table`, in row order.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        (0..table.n_rows())
            .map(|row_idx| self.predict_row(table, row_idx))
            .collect()
    }

    /// Predicts per-class probability vectors for every row of `table`.
    pub fn predict_proba_table(&self, table: &dyn TableAccess) -> Vec<Vec<f64>> {
        (0..table.n_rows())
            .map(|row_idx| self.predict_proba_row(table, row_idx))
            .collect()
    }

    /// Routes one row through the tree and returns the predicted label.
    fn predict_row(&self, table: &dyn TableAccess, row_idx: usize) -> f64 {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;
                loop {
                    match &nodes[node_index] {
                        TreeNode::Leaf { class_index, .. } => {
                            return self.class_labels[*class_index];
                        }
                        TreeNode::MultiwaySplit {
                            feature_index,
                            fallback_class_index,
                            branches,
                            ..
                        } => {
                            let bin = table.binned_value(*feature_index, row_idx);
                            if let Some((_, child_index)) =
                                branches.iter().find(|(branch_bin, _)| *branch_bin == bin)
                            {
                                node_index = *child_index;
                            } else {
                                // Bin unseen at this node during training:
                                // answer with the node's fallback class.
                                return self.class_labels[*fallback_class_index];
                            }
                        }
                        TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            left_child,
                            right_child,
                            ..
                        } => {
                            let bin = table.binned_value(*feature_index, row_idx);
                            // Bins at or below the threshold go left.
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_indices,
                ..
            } => {
                // One comparison per depth level; each outcome appends one
                // bit (MSB first) to the leaf index.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                self.class_labels[leaf_class_indices[leaf_index]]
            }
        }
    }

    /// Same traversal as `predict_row`, but returns the normalized class-count
    /// distribution of the reached leaf (or of the node itself on an unmatched
    /// multiway bin).
    fn predict_proba_row(&self, table: &dyn TableAccess, row_idx: usize) -> Vec<f64> {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                let mut node_index = *root;
                loop {
                    match &nodes[node_index] {
                        TreeNode::Leaf { class_counts, .. } => {
                            return normalized_class_probabilities(class_counts);
                        }
                        TreeNode::MultiwaySplit {
                            feature_index,
                            branches,
                            class_counts,
                            ..
                        } => {
                            let bin = table.binned_value(*feature_index, row_idx);
                            if let Some((_, child_index)) =
                                branches.iter().find(|(branch_bin, _)| *branch_bin == bin)
                            {
                                node_index = *child_index;
                            } else {
                                // Unseen bin: use this node's own distribution.
                                return normalized_class_probabilities(class_counts);
                            }
                        }
                        TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            left_child,
                            right_child,
                            ..
                        } => {
                            let bin = table.binned_value(*feature_index, row_idx);
                            node_index = if bin <= *threshold_bin {
                                *left_child
                            } else {
                                *right_child
                            };
                        }
                    }
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_counts,
                ..
            } => {
                // Identical leaf addressing to `predict_row`.
                let leaf_index = splits.iter().fold(0usize, |leaf_index, split| {
                    let go_right =
                        table.binned_value(split.feature_index, row_idx) > split.threshold_bin;
                    (leaf_index << 1) | usize::from(go_right)
                });

                normalized_class_probabilities(&leaf_class_counts[leaf_index])
            }
        }
    }

    /// Distinct class labels; leaves store indices into this slice.
    pub(crate) fn class_labels(&self) -> &[f64] {
        &self.class_labels
    }

    /// Borrow the trained tree layout.
    pub(crate) fn structure(&self) -> &TreeStructure {
        &self.structure
    }

    /// Number of features seen at training time.
    pub(crate) fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-feature preprocessing captured from the training table.
    pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        &self.feature_preprocessing
    }

    /// Builds the serializable training-metadata record for this tree.
    /// Fields that only apply to ensembles/boosting are left `None`.
    pub(crate) fn training_metadata(&self) -> TrainingMetadata {
        TrainingMetadata {
            algorithm: "dt".to_string(),
            task: "classification".to_string(),
            tree_type: tree_type_name(match self.algorithm {
                DecisionTreeAlgorithm::Id3 => crate::TreeType::Id3,
                DecisionTreeAlgorithm::C45 => crate::TreeType::C45,
                DecisionTreeAlgorithm::Cart => crate::TreeType::Cart,
                DecisionTreeAlgorithm::Randomized => crate::TreeType::Randomized,
                DecisionTreeAlgorithm::Oblivious => crate::TreeType::Oblivious,
            })
            .to_string(),
            criterion: criterion_name(self.criterion).to_string(),
            canaries: self.training_canaries,
            compute_oob: false,
            max_depth: Some(self.options.max_depth),
            min_samples_split: Some(self.options.min_samples_split),
            min_samples_leaf: Some(self.options.min_samples_leaf),
            n_trees: None,
            max_features: self.options.max_features,
            seed: None,
            oob_score: None,
            class_labels: Some(self.class_labels.clone()),
            learning_rate: None,
            bootstrap: None,
            top_gradient_fraction: None,
            other_gradient_fraction: None,
        }
    }

    /// Converts the trained structure into the portable IR `TreeDefinition`.
    ///
    /// Standard trees map node-for-node (indices preserved); oblivious trees
    /// map to one `ObliviousLevel` per split plus indexed leaves.
    pub(crate) fn to_ir_tree(&self) -> TreeDefinition {
        match &self.structure {
            TreeStructure::Standard { nodes, root } => {
                // Depths are recomputed from the root for the IR's per-node
                // depth field.
                let depths = standard_node_depths(nodes, *root);
                TreeDefinition::NodeTree {
                    tree_id: 0,
                    weight: 1.0,
                    root_node_id: *root,
                    nodes: nodes
                        .iter()
                        .enumerate()
                        .map(|(node_id, node)| match node {
                            TreeNode::Leaf {
                                class_index,
                                sample_count,
                                class_counts,
                            } => NodeTreeNode::Leaf {
                                node_id,
                                depth: depths[node_id],
                                leaf: self.class_leaf(*class_index),
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: None,
                                    gain: None,
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                            TreeNode::BinarySplit {
                                feature_index,
                                threshold_bin,
                                left_child,
                                right_child,
                                sample_count,
                                impurity,
                                gain,
                                class_counts,
                            } => NodeTreeNode::BinaryBranch {
                                node_id,
                                depth: depths[node_id],
                                split: binary_split_ir(
                                    *feature_index,
                                    *threshold_bin,
                                    &self.feature_preprocessing,
                                ),
                                children: BinaryChildren {
                                    left: *left_child,
                                    right: *right_child,
                                },
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                            TreeNode::MultiwaySplit {
                                feature_index,
                                fallback_class_index,
                                branches,
                                sample_count,
                                impurity,
                                gain,
                                class_counts,
                            } => NodeTreeNode::MultiwayBranch {
                                node_id,
                                depth: depths[node_id],
                                split: MultiwaySplit {
                                    split_type: "binned_value_multiway".to_string(),
                                    feature_index: *feature_index,
                                    feature_name: feature_name(*feature_index),
                                    comparison_dtype: "uint16".to_string(),
                                },
                                branches: branches
                                    .iter()
                                    .map(|(bin, child)| MultiwayBranch {
                                        bin: *bin,
                                        child: *child,
                                    })
                                    .collect(),
                                // Fallback class for bins with no branch.
                                unmatched_leaf: self.class_leaf(*fallback_class_index),
                                stats: NodeStats {
                                    sample_count: *sample_count,
                                    impurity: Some(*impurity),
                                    gain: Some(*gain),
                                    class_counts: Some(class_counts.clone()),
                                    variance: None,
                                },
                            },
                        })
                        .collect(),
                }
            }
            TreeStructure::Oblivious {
                splits,
                leaf_class_indices,
                leaf_sample_counts,
                leaf_class_counts,
            } => TreeDefinition::ObliviousLevels {
                tree_id: 0,
                weight: 1.0,
                depth: splits.len(),
                levels: splits
                    .iter()
                    .enumerate()
                    .map(|(level, split)| ObliviousLevel {
                        level,
                        split: oblivious_split_ir(
                            split.feature_index,
                            split.threshold_bin,
                            &self.feature_preprocessing,
                        ),
                        stats: NodeStats {
                            sample_count: split.sample_count,
                            impurity: Some(split.impurity),
                            gain: Some(split.gain),
                            class_counts: None,
                            variance: None,
                        },
                    })
                    .collect(),
                // Documents the bit packing used by predict_row's fold:
                // level 0 contributes the most significant bit.
                leaf_indexing: LeafIndexing {
                    bit_order: "msb_first".to_string(),
                    index_formula: "sum(bit[level] << (depth - 1 - level))".to_string(),
                },
                leaves: leaf_class_indices
                    .iter()
                    .enumerate()
                    .map(|(leaf_index, class_index)| IndexedLeaf {
                        leaf_index,
                        leaf: self.class_leaf(*class_index),
                        stats: NodeStats {
                            sample_count: leaf_sample_counts[leaf_index],
                            impurity: None,
                            gain: None,
                            class_counts: Some(leaf_class_counts[leaf_index].clone()),
                            variance: None,
                        },
                    })
                    .collect(),
            },
        }
    }

    /// Leaf payload carrying both the class index and its original label value.
    fn class_leaf(&self, class_index: usize) -> LeafPayload {
        LeafPayload::ClassIndex {
            class_index,
            class_value: self.class_labels[class_index],
        }
    }

    /// Reassembles a classifier from deserialized IR parts. No validation is
    /// performed here; callers are expected to pass consistent parts.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn from_ir_parts(
        algorithm: DecisionTreeAlgorithm,
        criterion: Criterion,
        class_labels: Vec<f64>,
        structure: TreeStructure,
        options: DecisionTreeOptions,
        num_features: usize,
        feature_preprocessing: Vec<FeaturePreprocessing>,
        training_canaries: usize,
    ) -> Self {
        Self {
            algorithm,
            criterion,
            class_labels,
            structure,
            options,
            num_features,
            feature_preprocessing,
            training_canaries,
        }
    }
}
808
809fn build_binary_node_in_place(
810    context: &BuildContext<'_>,
811    nodes: &mut Vec<TreeNode>,
812    rows: &mut [usize],
813    depth: usize,
814) -> usize {
815    build_binary_node_in_place_with_hist(context, nodes, rows, depth, None)
816}
817
818fn build_binary_node_in_place_with_hist(
819    context: &BuildContext<'_>,
820    nodes: &mut Vec<TreeNode>,
821    rows: &mut [usize],
822    depth: usize,
823    histograms: Option<Vec<ClassificationFeatureHistogram>>,
824) -> usize {
825    let majority_class_index =
826        majority_class(rows, context.class_indices, context.class_labels.len());
827    let current_class_counts =
828        class_counts(rows, context.class_indices, context.class_labels.len());
829
830    if rows.is_empty()
831        || depth >= context.options.max_depth
832        || rows.len() < context.options.min_samples_split
833        || is_pure(rows, context.class_indices)
834    {
835        return push_leaf(
836            nodes,
837            majority_class_index,
838            rows.len(),
839            current_class_counts,
840        );
841    }
842
843    let scoring = SplitScoringContext {
844        table: context.table,
845        class_indices: context.class_indices,
846        num_classes: context.class_labels.len(),
847        criterion: context.criterion,
848        min_samples_leaf: context.options.min_samples_leaf,
849    };
850    let histograms = histograms.unwrap_or_else(|| {
851        build_classification_node_histograms(
852            context.table,
853            context.class_indices,
854            rows,
855            context.class_labels.len(),
856        )
857    });
858    let feature_indices = candidate_feature_indices(
859        context.table.binned_feature_count(),
860        context.options.max_features,
861        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
862    );
863    let best_split = if context.parallelism.enabled() {
864        feature_indices
865            .into_par_iter()
866            .filter_map(|feature_index| {
867                score_binary_split_choice_from_hist(
868                    &scoring,
869                    &histograms[feature_index],
870                    feature_index,
871                    rows,
872                    &current_class_counts,
873                    context.algorithm,
874                )
875            })
876            .max_by(|left, right| left.score.total_cmp(&right.score))
877    } else {
878        feature_indices
879            .into_iter()
880            .filter_map(|feature_index| {
881                score_binary_split_choice_from_hist(
882                    &scoring,
883                    &histograms[feature_index],
884                    feature_index,
885                    rows,
886                    &current_class_counts,
887                    context.algorithm,
888                )
889            })
890            .max_by(|left, right| left.score.total_cmp(&right.score))
891    };
892
893    match best_split {
894        Some(best_split)
895            if context
896                .table
897                .is_canary_binned_feature(best_split.feature_index) =>
898        {
899            push_leaf(
900                nodes,
901                majority_class_index,
902                rows.len(),
903                current_class_counts,
904            )
905        }
906        Some(best_split) if best_split.score > 0.0 => {
907            let impurity =
908                classification_impurity(&current_class_counts, rows.len(), context.criterion);
909            let left_count = partition_rows_for_binary_split(
910                context.table,
911                best_split.feature_index,
912                best_split.threshold_bin,
913                rows,
914            );
915            let (left_rows, right_rows) = rows.split_at_mut(left_count);
916            let (left_histograms, right_histograms) = if left_rows.len() <= right_rows.len() {
917                let left_histograms = build_classification_node_histograms(
918                    context.table,
919                    context.class_indices,
920                    left_rows,
921                    context.class_labels.len(),
922                );
923                let right_histograms =
924                    subtract_classification_node_histograms(&histograms, &left_histograms);
925                (left_histograms, right_histograms)
926            } else {
927                let right_histograms = build_classification_node_histograms(
928                    context.table,
929                    context.class_indices,
930                    right_rows,
931                    context.class_labels.len(),
932                );
933                let left_histograms =
934                    subtract_classification_node_histograms(&histograms, &right_histograms);
935                (left_histograms, right_histograms)
936            };
937            let left_child = build_binary_node_in_place_with_hist(
938                context,
939                nodes,
940                left_rows,
941                depth + 1,
942                Some(left_histograms),
943            );
944            let right_child = build_binary_node_in_place_with_hist(
945                context,
946                nodes,
947                right_rows,
948                depth + 1,
949                Some(right_histograms),
950            );
951
952            push_node(
953                nodes,
954                TreeNode::BinarySplit {
955                    feature_index: best_split.feature_index,
956                    threshold_bin: best_split.threshold_bin,
957                    left_child,
958                    right_child,
959                    sample_count: rows.len(),
960                    impurity,
961                    gain: best_split.score,
962                    class_counts: current_class_counts,
963                },
964            )
965        }
966        _ => push_leaf(
967            nodes,
968            majority_class_index,
969            rows.len(),
970            current_class_counts,
971        ),
972    }
973}
974
975fn build_multiway_node_in_place(
976    context: &BuildContext<'_>,
977    nodes: &mut Vec<TreeNode>,
978    rows: &mut [usize],
979    depth: usize,
980) -> usize {
981    let majority_class_index =
982        majority_class(rows, context.class_indices, context.class_labels.len());
983    let current_class_counts =
984        class_counts(rows, context.class_indices, context.class_labels.len());
985
986    if rows.is_empty()
987        || depth >= context.options.max_depth
988        || rows.len() < context.options.min_samples_split
989        || is_pure(rows, context.class_indices)
990    {
991        return push_leaf(
992            nodes,
993            majority_class_index,
994            rows.len(),
995            current_class_counts,
996        );
997    }
998
999    let metric = match context.algorithm {
1000        DecisionTreeAlgorithm::Id3 => MultiwayMetric::InformationGain,
1001        DecisionTreeAlgorithm::C45 => MultiwayMetric::GainRatio,
1002        _ => unreachable!("multiway builder only supports id3/c45"),
1003    };
1004    let scoring = SplitScoringContext {
1005        table: context.table,
1006        class_indices: context.class_indices,
1007        num_classes: context.class_labels.len(),
1008        criterion: context.criterion,
1009        min_samples_leaf: context.options.min_samples_leaf,
1010    };
1011    let feature_indices = candidate_feature_indices(
1012        context.table.binned_feature_count(),
1013        context.options.max_features,
1014        node_seed(context.options.random_seed, depth, rows, 0xC1A5_5EEDu64),
1015    );
1016    let best_split = if context.parallelism.enabled() {
1017        feature_indices
1018            .into_par_iter()
1019            .filter_map(|feature_index| {
1020                score_multiway_split_choice(&scoring, feature_index, rows, metric)
1021            })
1022            .max_by(|left, right| left.score.total_cmp(&right.score))
1023    } else {
1024        feature_indices
1025            .into_iter()
1026            .filter_map(|feature_index| {
1027                score_multiway_split_choice(&scoring, feature_index, rows, metric)
1028            })
1029            .max_by(|left, right| left.score.total_cmp(&right.score))
1030    };
1031
1032    match best_split {
1033        Some(best_split)
1034            if context
1035                .table
1036                .is_canary_binned_feature(best_split.feature_index) =>
1037        {
1038            push_leaf(
1039                nodes,
1040                majority_class_index,
1041                rows.len(),
1042                current_class_counts,
1043            )
1044        }
1045        Some(best_split) if best_split.score > 0.0 => {
1046            let impurity =
1047                classification_impurity(&current_class_counts, rows.len(), context.criterion);
1048            let branch_ranges = partition_rows_for_multiway_split(
1049                context.table,
1050                best_split.feature_index,
1051                &best_split.branch_bins,
1052                rows,
1053            );
1054            let mut branch_nodes = Vec::with_capacity(branch_ranges.len());
1055            for (bin, start, end) in branch_ranges {
1056                let child =
1057                    build_multiway_node_in_place(context, nodes, &mut rows[start..end], depth + 1);
1058                branch_nodes.push((bin, child));
1059            }
1060
1061            push_node(
1062                nodes,
1063                TreeNode::MultiwaySplit {
1064                    feature_index: best_split.feature_index,
1065                    fallback_class_index: majority_class_index,
1066                    branches: branch_nodes,
1067                    sample_count: rows.len(),
1068                    impurity,
1069                    gain: best_split.score,
1070                    class_counts: current_class_counts,
1071                },
1072            )
1073        }
1074        _ => push_leaf(
1075            nodes,
1076            majority_class_index,
1077            rows.len(),
1078            current_class_counts,
1079        ),
1080    }
1081}
1082
/// Shared, read-only state threaded through the recursive node builders.
struct BuildContext<'a> {
    // Binned feature table the tree is trained on.
    table: &'a dyn TableAccess,
    // Encoded class index per table row (produced by `encode_class_labels`).
    class_indices: &'a [usize],
    // Sorted distinct class label values; position = encoded class index.
    class_labels: &'a [f64],
    // Which tree family is being trained (id3/c45/cart/...).
    algorithm: DecisionTreeAlgorithm,
    // Impurity measure used for split scoring (gini or entropy here).
    criterion: Criterion,
    // Whether split scoring fans out over rayon.
    parallelism: Parallelism,
    // Depth/size/feature-sampling limits for the build.
    options: DecisionTreeOptions,
}
1092
1093fn encode_class_labels(
1094    train_set: &dyn TableAccess,
1095) -> Result<(Vec<f64>, Vec<usize>), DecisionTreeError> {
1096    let targets: Vec<f64> = (0..train_set.n_rows())
1097        .map(|row_idx| {
1098            let value = train_set.target_value(row_idx);
1099            if value.is_finite() {
1100                Ok(value)
1101            } else {
1102                Err(DecisionTreeError::InvalidTargetValue {
1103                    row: row_idx,
1104                    value,
1105                })
1106            }
1107        })
1108        .collect::<Result<_, _>>()?;
1109
1110    let class_labels = targets
1111        .iter()
1112        .copied()
1113        .fold(Vec::<f64>::new(), |mut labels, value| {
1114            if labels
1115                .binary_search_by(|candidate| candidate.total_cmp(&value))
1116                .is_err()
1117            {
1118                labels.push(value);
1119                labels.sort_by(|left, right| left.total_cmp(right));
1120            }
1121            labels
1122        });
1123
1124    let class_indices = targets
1125        .iter()
1126        .map(|value| {
1127            class_labels
1128                .binary_search_by(|candidate| candidate.total_cmp(value))
1129                .expect("target value must exist in class label vocabulary")
1130        })
1131        .collect();
1132
1133    Ok((class_labels, class_indices))
1134}
1135
/// Tally how many of the given `rows` fall into each class.
///
/// `rows` holds indices into `class_indices`; the result has exactly
/// `num_classes` entries.
fn class_counts(rows: &[usize], class_indices: &[usize], num_classes: usize) -> Vec<usize> {
    let mut counts = vec![0usize; num_classes];
    for &row_idx in rows {
        counts[class_indices[row_idx]] += 1;
    }
    counts
}
1143
1144fn majority_class(rows: &[usize], class_indices: &[usize], num_classes: usize) -> usize {
1145    majority_class_from_counts(&class_counts(rows, class_indices, num_classes))
1146}
1147
/// Index of the largest count; ties resolve to the smallest class index.
/// An empty slice yields class 0.
fn majority_class_from_counts(counts: &[usize]) -> usize {
    let mut best_index = 0usize;
    let mut best_count = counts.first().copied().unwrap_or(0);
    for (class_index, &count) in counts.iter().enumerate().skip(1) {
        // Strict `>` keeps the earliest index on ties.
        if count > best_count {
            best_index = class_index;
            best_count = count;
        }
    }
    best_index
}
1157
/// True when every listed row carries the same class (an empty row set is
/// vacuously pure).
fn is_pure(rows: &[usize], class_indices: &[usize]) -> bool {
    match rows.split_first() {
        None => true,
        Some((first_row, rest)) => {
            let reference_class = class_indices[*first_row];
            rest.iter()
                .all(|row_idx| class_indices[*row_idx] == reference_class)
        }
    }
}
1164
/// Shannon entropy (bits) of the class distribution given per-class counts
/// and the total number of samples. Zero counts contribute nothing.
fn entropy(counts: &[usize], total: usize) -> f64 {
    let mut bits = 0.0_f64;
    for &count in counts {
        if count > 0 {
            let probability = count as f64 / total as f64;
            bits += -probability * probability.log2();
        }
    }
    bits
}
1176
/// Gini impurity of the class distribution: `1 - sum(p_i^2)` over the
/// per-class probabilities.
fn gini(counts: &[usize], total: usize) -> f64 {
    let mut sum_of_squares = 0.0_f64;
    for &count in counts {
        let probability = count as f64 / total as f64;
        sum_of_squares += probability * probability;
    }
    1.0 - sum_of_squares
}
1187
1188fn classification_impurity(counts: &[usize], total: usize, criterion: Criterion) -> f64 {
1189    match criterion {
1190        Criterion::Entropy => entropy(counts, total),
1191        Criterion::Gini => gini(counts, total),
1192        _ => unreachable!("classification impurity only supports gini or entropy"),
1193    }
1194}
1195
1196fn push_leaf(
1197    nodes: &mut Vec<TreeNode>,
1198    class_index: usize,
1199    sample_count: usize,
1200    class_counts: Vec<usize>,
1201) -> usize {
1202    push_node(
1203        nodes,
1204        TreeNode::Leaf {
1205            class_index,
1206            sample_count,
1207            class_counts,
1208        },
1209    )
1210}
1211
1212fn push_node(nodes: &mut Vec<TreeNode>, node: TreeNode) -> usize {
1213    nodes.push(node);
1214    nodes.len() - 1
1215}
1216
1217#[cfg(test)]
1218mod tests;