use anofox_ml_core::Float;
/// A node of a binary decision tree, generic over the floating-point type `F`.
///
/// A tree is either a single `Leaf` or a `Split` that owns its two child subtrees.
/// During prediction a sample is routed to `left` when `features[feature_index] <= threshold`
/// and to `right` otherwise, until a `Leaf` is reached.
///
/// The type is (de)serializable with serde. Deserializing requires
/// `F: serde::de::DeserializeOwned` (see the `bound` attribute). Reordering variants or fields
/// can change the serialized representation, so keep the declaration order stable.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub enum TreeNode<F: Float> {
    /// An internal node that tests a single feature against a threshold.
    Split {
        /// Index into the feature slice of the feature tested at this node.
        feature_index: usize,
        /// Samples whose feature value is `<= threshold` go to `left`; all others go to `right`.
        threshold: F,
        /// Subtree for samples that satisfy `feature <= threshold`.
        left: Box<TreeNode<F>>,
        /// Subtree for samples that do not satisfy `feature <= threshold`.
        right: Box<TreeNode<F>>,
        /// Number of samples at this node (presumably training samples — confirm); used as
        /// the weight of this split in the feature-importance computation.
        n_samples: usize,
        /// Impurity of this node. The criterion (e.g. Gini, entropy, MSE) is not visible
        /// here — TODO confirm with the tree builder.
        impurity: F,
    },
    /// A terminal node that holds the prediction for every sample routed to it.
    Leaf {
        /// The value returned for samples that reach this leaf.
        value: F,
        /// Number of samples at this leaf (presumably training samples — confirm).
        n_samples: usize,
        /// NOTE(review): presumably `(class label, count)` pairs for classification trees
        /// and `None` for regression trees — verify against the builder.
        class_counts: Option<Vec<(F, usize)>>,
    },
}
impl<F: Float> TreeNode<F> {
    /// Predicts the output for one sample by walking from this node down to a leaf.
    ///
    /// At each `Split` the sample is routed to `left` when
    /// `features[feature_index] <= threshold` and to `right` otherwise. A NaN feature value
    /// therefore always goes right. The `value` of the leaf that is reached is returned.
    ///
    /// The walk is a loop rather than recursion, so stack usage does not grow with the depth
    /// of the tree.
    ///
    /// # Panics
    ///
    /// Panics if a visited split's `feature_index` is out of bounds for `features`.
    #[inline]
    pub fn predict_one(&self, features: &[F]) -> F {
        let mut node = self;
        loop {
            match node {
                TreeNode::Leaf { value, .. } => return *value,
                TreeNode::Split {
                    feature_index,
                    threshold,
                    left,
                    right,
                    ..
                } => {
                    node = if features[*feature_index] <= *threshold {
                        left
                    } else {
                        right
                    };
                }
            }
        }
    }

    /// Computes impurity-decrease (MDI-style) feature importances for this tree.
    ///
    /// Returns a vector of length `n_features`. Entry `i` is the sum, over every split on
    /// feature `i`, of that split's weighted impurity decrease. The weight is the split's
    /// `n_samples / total_samples`. The values are **not** normalized to sum to one.
    ///
    /// `total_samples` is the weighting denominator, presumably the sample count at the
    /// root. If it is `0`, every importance stays zero.
    ///
    /// # Panics
    ///
    /// Panics if any split's `feature_index` is `>= n_features`.
    pub fn feature_importances(&self, n_features: usize, total_samples: usize) -> Vec<F> {
        let mut importances = vec![F::zero(); n_features];
        self.accumulate_importances(&mut importances, total_samples);
        importances
    }

    /// Recursively adds this subtree's weighted impurity decreases into `importances`.
    ///
    /// Each split adds the following to `importances[feature_index]`:
    /// `(n / total) * (impurity - (n_left / n) * impurity_left - (n_right / n) * impurity_right)`.
    /// Here `n`, `n_left` and `n_right` are the sample counts of the split and its children,
    /// and the child impurities come from [`node_impurity`]. Leaves contribute nothing.
    ///
    /// A split with `n_samples == 0`, or a call with `total_samples == 0`, is skipped rather
    /// than dividing by zero. Dividing by zero would put NaN/inf into the accumulator and
    /// corrupt every later sum. The children of a skipped split are still visited.
    ///
    /// # Panics
    ///
    /// Panics if a split's `feature_index` is out of bounds for `importances`.
    fn accumulate_importances(&self, importances: &mut [F], total_samples: usize) {
        if let TreeNode::Split {
            feature_index,
            left,
            right,
            n_samples,
            impurity,
            ..
        } = self
        {
            if *n_samples > 0 && total_samples > 0 {
                let left_samples = node_samples(left);
                let right_samples = node_samples(right);
                let left_impurity = node_impurity(left);
                let right_impurity = node_impurity(right);
                let n = num_traits::FromPrimitive::from_usize(*n_samples).unwrap_or(F::one());
                let nl = num_traits::FromPrimitive::from_usize(left_samples).unwrap_or(F::zero());
                let nr =
                    num_traits::FromPrimitive::from_usize(right_samples).unwrap_or(F::zero());
                let total =
                    num_traits::FromPrimitive::from_usize(total_samples).unwrap_or(F::one());
                let decrease = (n / total)
                    * (*impurity - (nl / n) * left_impurity - (nr / n) * right_impurity);
                importances[*feature_index] += decrease;
            }
            left.accumulate_importances(importances, total_samples);
            right.accumulate_importances(importances, total_samples);
        }
    }
}
/// Returns the sample count recorded on `node`, whichever variant it is.
///
/// Both `Split` and `Leaf` store an `n_samples` field, so a single or-pattern arm covers
/// the whole enum.
fn node_samples<F: Float>(node: &TreeNode<F>) -> usize {
    match *node {
        TreeNode::Split {
            n_samples: count, ..
        }
        | TreeNode::Leaf {
            n_samples: count, ..
        } => count,
    }
}
/// Returns the impurity stored on `node`, or zero for a `Leaf`.
///
/// NOTE(review): `Leaf` has no impurity field, so every leaf is treated as perfectly pure.
/// Leaves may be impure, for example when tree growth is stopped early. In that case the
/// impurity decrease credited to the parent split in the importance computation is
/// overstated. Confirm this is intended.
fn node_impurity<F: Float>(node: &TreeNode<F>) -> F {
    match *node {
        TreeNode::Split {
            impurity: node_impurity_value,
            ..
        } => node_impurity_value,
        TreeNode::Leaf { .. } => F::zero(),
    }
}