1use anofox_ml_core::Float;
2
/// A node of a binary decision tree, generic over the float type `F`.
///
/// A tree is either a terminal [`TreeNode::Leaf`] carrying a prediction, or a
/// [`TreeNode::Split`] that routes a sample to one of two boxed child subtrees
/// by comparing a single feature against a threshold.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub enum TreeNode<F: Float> {
    /// Internal node that tests one feature against a threshold.
    Split {
        /// Index into the feature slice of the value tested at this node.
        feature_index: usize,
        /// Split point: samples with `feature <= threshold` go to `left`,
        /// all others go to `right`.
        threshold: F,
        /// Subtree taken when the feature value is `<= threshold`.
        left: Box<TreeNode<F>>,
        /// Subtree taken when the feature value is not `<= threshold`.
        right: Box<TreeNode<F>>,
        /// Sample count recorded for this node; used as its weight when
        /// computing feature importances.
        // NOTE(review): presumably the number of training samples that reached
        // this node — confirm against the tree builder.
        n_samples: usize,
        /// Impurity recorded for this node; used to compute the impurity
        /// decrease attributed to this split.
        // NOTE(review): the impurity criterion (e.g. Gini, MSE) is chosen by
        // the builder and is not visible here.
        impurity: F,
    },
    /// Terminal node holding the prediction.
    Leaf {
        /// Value returned by `predict_one` when a sample ends up in this leaf.
        value: F,
        /// Sample count recorded for this leaf.
        n_samples: usize,
        /// Optional per-class tallies.
        // NOTE(review): presumably `(class label, count)` pairs populated for
        // classification trees and `None` for regression — confirm.
        class_counts: Option<Vec<(F, usize)>>,
    },
}
27
28impl<F: Float> TreeNode<F> {
29 #[inline]
31 pub fn predict_one(&self, features: &[F]) -> F {
32 match self {
33 TreeNode::Leaf { value, .. } => *value,
34 TreeNode::Split {
35 feature_index,
36 threshold,
37 left,
38 right,
39 ..
40 } => {
41 if features[*feature_index] <= *threshold {
42 left.predict_one(features)
43 } else {
44 right.predict_one(features)
45 }
46 }
47 }
48 }
49
50 pub fn feature_importances(&self, n_features: usize, total_samples: usize) -> Vec<F> {
52 let mut importances = vec![F::zero(); n_features];
53 self.accumulate_importances(&mut importances, total_samples);
54 importances
55 }
56
57 fn accumulate_importances(&self, importances: &mut [F], total_samples: usize) {
58 if let TreeNode::Split {
59 feature_index,
60 left,
61 right,
62 n_samples,
63 impurity,
64 ..
65 } = self
66 {
67 let left_samples = node_samples(left);
68 let right_samples = node_samples(right);
69 let left_impurity = node_impurity(left);
70 let right_impurity = node_impurity(right);
71
72 let n = num_traits::FromPrimitive::from_usize(*n_samples).unwrap_or(F::one());
73 let nl = num_traits::FromPrimitive::from_usize(left_samples).unwrap_or(F::zero());
74 let nr = num_traits::FromPrimitive::from_usize(right_samples).unwrap_or(F::zero());
75 let total = num_traits::FromPrimitive::from_usize(total_samples).unwrap_or(F::one());
76
77 let decrease =
79 (n / total) * (*impurity - (nl / n) * left_impurity - (nr / n) * right_impurity);
80
81 importances[*feature_index] += decrease;
82
83 left.accumulate_importances(importances, total_samples);
84 right.accumulate_importances(importances, total_samples);
85 }
86 }
87}
88
89fn node_samples<F: Float>(node: &TreeNode<F>) -> usize {
90 match node {
91 TreeNode::Leaf { n_samples, .. } => *n_samples,
92 TreeNode::Split { n_samples, .. } => *n_samples,
93 }
94}
95
96fn node_impurity<F: Float>(node: &TreeNode<F>) -> F {
97 match node {
98 TreeNode::Leaf { .. } => F::zero(),
99 TreeNode::Split { impurity, .. } => *impurity,
100 }
101}