Skip to main content

anofox_ml_trees/
node.rs

1use anofox_ml_core::Float;
2
3/// A node in a decision tree.
4#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
5#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
6pub enum TreeNode<F: Float> {
7    /// Internal node: split on a feature at a threshold.
8    Split {
9        feature_index: usize,
10        threshold: F,
11        left: Box<TreeNode<F>>,
12        right: Box<TreeNode<F>>,
13        /// Number of training samples that reached this node.
14        n_samples: usize,
15        /// Impurity at this node before splitting.
16        impurity: F,
17    },
18    /// Leaf node: predict a value.
19    Leaf {
20        value: F,
21        /// Number of training samples in this leaf.
22        n_samples: usize,
23        /// For classifiers: class distribution counts.
24        class_counts: Option<Vec<(F, usize)>>,
25    },
26}
27
28impl<F: Float> TreeNode<F> {
29    /// Predict a single sample by traversing the tree.
30    #[inline]
31    pub fn predict_one(&self, features: &[F]) -> F {
32        match self {
33            TreeNode::Leaf { value, .. } => *value,
34            TreeNode::Split {
35                feature_index,
36                threshold,
37                left,
38                right,
39                ..
40            } => {
41                if features[*feature_index] <= *threshold {
42                    left.predict_one(features)
43                } else {
44                    right.predict_one(features)
45                }
46            }
47        }
48    }
49
50    /// Compute feature importances by accumulating weighted impurity decreases.
51    pub fn feature_importances(&self, n_features: usize, total_samples: usize) -> Vec<F> {
52        let mut importances = vec![F::zero(); n_features];
53        self.accumulate_importances(&mut importances, total_samples);
54        importances
55    }
56
57    fn accumulate_importances(&self, importances: &mut [F], total_samples: usize) {
58        if let TreeNode::Split {
59            feature_index,
60            left,
61            right,
62            n_samples,
63            impurity,
64            ..
65        } = self
66        {
67            let left_samples = node_samples(left);
68            let right_samples = node_samples(right);
69            let left_impurity = node_impurity(left);
70            let right_impurity = node_impurity(right);
71
72            let n = num_traits::FromPrimitive::from_usize(*n_samples).unwrap_or(F::one());
73            let nl = num_traits::FromPrimitive::from_usize(left_samples).unwrap_or(F::zero());
74            let nr = num_traits::FromPrimitive::from_usize(right_samples).unwrap_or(F::zero());
75            let total = num_traits::FromPrimitive::from_usize(total_samples).unwrap_or(F::one());
76
77            // Weighted impurity decrease
78            let decrease =
79                (n / total) * (*impurity - (nl / n) * left_impurity - (nr / n) * right_impurity);
80
81            importances[*feature_index] += decrease;
82
83            left.accumulate_importances(importances, total_samples);
84            right.accumulate_importances(importances, total_samples);
85        }
86    }
87}
88
89fn node_samples<F: Float>(node: &TreeNode<F>) -> usize {
90    match node {
91        TreeNode::Leaf { n_samples, .. } => *n_samples,
92        TreeNode::Split { n_samples, .. } => *n_samples,
93    }
94}
95
96fn node_impurity<F: Float>(node: &TreeNode<F>) -> F {
97    match node {
98        TreeNode::Leaf { .. } => F::zero(),
99        TreeNode::Split { impurity, .. } => *impurity,
100    }
101}