/// A node of a binary decision tree.
///
/// A tree is either a single [`TreeNode::Leaf`] or a [`TreeNode::Split`] that
/// routes a sample to its `left` child when `sample[feature_idx] <= threshold`
/// and to its `right` child otherwise.
///
/// Both variants carry the same node statistics (`prediction`, `n_samples`,
/// `class_counts`, `impurity`) so that a split can be collapsed into a leaf by
/// [`TreeNode::prune_ccp`] without recomputing anything.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum TreeNode {
    /// Terminal node: prediction stops here.
    Leaf {
        /// Value returned by [`TreeNode::predict`] for samples that reach this leaf.
        prediction: f64,
        /// Sample count of this node; used as the denominator in
        /// [`TreeNode::predict_proba`] and as the impurity weight.
        n_samples: usize,
        /// Per-class counts, indexed by class.
        class_counts: Vec<usize>,
        /// Impurity of this node (weighted by `n_samples` in the pruning code).
        impurity: f64,
    },
    /// Internal node that compares one feature against a threshold.
    Split {
        /// Index into the sample slice of the feature being tested.
        feature_idx: usize,
        /// Samples with `sample[feature_idx] <= threshold` go left, all others right.
        threshold: f64,
        /// Subtree for samples at or below the threshold.
        left: Box<TreeNode>,
        /// Subtree for samples above the threshold (a NaN feature value also
        /// lands here, because `NaN <= threshold` is false).
        right: Box<TreeNode>,
        /// Sample count of this node.
        n_samples: usize,
        /// Impurity of this node if it were treated as a leaf.
        impurity: f64,
        /// Per-class counts, kept so the node can be collapsed into a leaf.
        class_counts: Vec<usize>,
        /// Prediction this node takes on when collapsed into a leaf.
        prediction: f64,
    },
}

impl TreeNode {
    /// Routes `sample` down the tree and returns the `prediction` of the leaf it reaches.
    ///
    /// # Panics
    ///
    /// Panics if a split on the path has a `feature_idx` that is out of bounds
    /// for `sample`.
    pub fn predict(&self, sample: &[f64]) -> f64 {
        let mut node = self;
        loop {
            match node {
                TreeNode::Leaf { prediction, .. } => return *prediction,
                TreeNode::Split {
                    feature_idx,
                    threshold,
                    left,
                    right,
                    ..
                } => {
                    node = if sample[*feature_idx] <= *threshold {
                        &**left
                    } else {
                        &**right
                    };
                }
            }
        }
    }

    /// Routes `sample` down the tree and returns the class frequencies of the
    /// leaf it reaches.
    ///
    /// The result always has length `n_classes`. Entry `i` is
    /// `class_counts[i] / n_samples` of the leaf; classes missing from
    /// `class_counts` stay at `0.0` and counts beyond `n_classes` are ignored.
    /// A leaf with `n_samples == 0` yields all zeros rather than NaN.
    ///
    /// # Panics
    ///
    /// Panics if a split on the path has a `feature_idx` that is out of bounds
    /// for `sample`.
    pub fn predict_proba(&self, sample: &[f64], n_classes: usize) -> Vec<f64> {
        let mut node = self;
        loop {
            match node {
                TreeNode::Leaf {
                    class_counts,
                    n_samples,
                    ..
                } => {
                    let mut proba = vec![0.0; n_classes];
                    // Guard the division: 0 / 0 would fill the vector with NaN.
                    if *n_samples > 0 {
                        let total = *n_samples as f64;
                        // `zip` stops at the shorter side, which bounds the write
                        // to `n_classes` entries.
                        for (slot, &count) in proba.iter_mut().zip(class_counts.iter()) {
                            *slot = count as f64 / total;
                        }
                    }
                    return proba;
                }
                TreeNode::Split {
                    feature_idx,
                    threshold,
                    left,
                    right,
                    ..
                } => {
                    node = if sample[*feature_idx] <= *threshold {
                        &**left
                    } else {
                        &**right
                    };
                }
            }
        }
    }

    /// Number of nodes on the longest root-to-leaf path; a lone leaf has depth 1.
    pub fn depth(&self) -> usize {
        match self {
            TreeNode::Leaf { .. } => 1,
            TreeNode::Split { left, right, .. } => 1 + left.depth().max(right.depth()),
        }
    }

    /// Number of leaves in this subtree.
    pub fn n_leaves(&self) -> usize {
        match self {
            TreeNode::Leaf { .. } => 1,
            TreeNode::Split { left, right, .. } => left.n_leaves() + right.n_leaves(),
        }
    }

    /// Sample count stored on this node.
    pub fn n_samples(&self) -> usize {
        match self {
            TreeNode::Leaf { n_samples, .. } | TreeNode::Split { n_samples, .. } => *n_samples,
        }
    }

    /// Sum over all leaves of `impurity * n_samples`.
    ///
    /// The weights are raw sample counts; no normalisation by the root's
    /// sample count is applied.
    pub fn total_leaf_impurity(&self) -> f64 {
        match self {
            TreeNode::Leaf {
                impurity,
                n_samples,
                ..
            } => *impurity * (*n_samples as f64),
            TreeNode::Split { left, right, .. } => {
                left.total_leaf_impurity() + right.total_leaf_impurity()
            }
        }
    }

    /// Cost-complexity prunes the tree bottom-up and returns the pruned tree.
    ///
    /// After its children have been pruned, a split is collapsed into a leaf
    /// (keeping its own `prediction`, `n_samples`, `class_counts` and
    /// `impurity`) when its effective alpha
    /// `(R(node) - R(subtree)) / (leaves(subtree) - 1)` is `<= ccp_alpha`, where
    /// `R` is `impurity * n_samples` summed as in
    /// [`TreeNode::total_leaf_impurity`]. A NaN effective alpha or a NaN
    /// `ccp_alpha` never collapses a node.
    pub fn prune_ccp(self, ccp_alpha: f64) -> TreeNode {
        self.prune_with_stats(ccp_alpha).0
    }

    /// Prunes like [`TreeNode::prune_ccp`] and also returns the pruned
    /// subtree's leaf count and total leaf impurity, so parents do not have to
    /// re-walk their children.
    fn prune_with_stats(self, ccp_alpha: f64) -> (TreeNode, usize, f64) {
        match self {
            TreeNode::Leaf {
                n_samples,
                impurity,
                ..
            } => {
                let risk = impurity * (n_samples as f64);
                (self, 1, risk)
            }
            TreeNode::Split {
                feature_idx,
                threshold,
                left,
                right,
                n_samples,
                impurity,
                class_counts,
                prediction,
            } => {
                let (left, left_leaves, left_risk) = (*left).prune_with_stats(ccp_alpha);
                let (right, right_leaves, right_risk) = (*right).prune_with_stats(ccp_alpha);

                // A split always has at least two leaves, so the divisor is >= 1.
                let n_leaves = left_leaves + right_leaves;
                let r_subtree = left_risk + right_risk;
                let r_node = impurity * (n_samples as f64);
                let effective_alpha = (r_node - r_subtree) / (n_leaves as f64 - 1.0);

                if effective_alpha <= ccp_alpha {
                    let collapsed = TreeNode::Leaf {
                        prediction,
                        n_samples,
                        class_counts,
                        impurity,
                    };
                    (collapsed, 1, r_node)
                } else {
                    let kept = TreeNode::Split {
                        feature_idx,
                        threshold,
                        left: Box::new(left),
                        right: Box::new(right),
                        n_samples,
                        impurity,
                        class_counts,
                        prediction,
                    };
                    (kept, n_leaves, r_subtree)
                }
            }
        }
    }

    /// Computes the cost-complexity pruning path of this tree.
    ///
    /// Returns `(alphas, impurities)` of equal length. The first entry is
    /// `0.0` paired with the total leaf impurity of the unpruned tree. Each
    /// following entry is the smallest effective alpha of the current tree,
    /// paired with the total leaf impurity after pruning at that alpha. The
    /// walk ends when only a leaf remains.
    ///
    /// It also ends if a pruning round removes no leaf (which happens when the
    /// smallest effective alpha is NaN); that alpha is not recorded. `self` is
    /// not modified; the walk works on a clone.
    pub fn cost_complexity_pruning_path(&self) -> (Vec<f64>, Vec<f64>) {
        let mut alphas = vec![0.0];
        let mut impurities = vec![self.total_leaf_impurity()];
        let mut current = self.clone();
        let mut leaves = current.n_leaves();

        while let Some(alpha) = Self::min_effective_alpha(&current) {
            let (pruned, remaining, risk) = current.prune_with_stats(alpha);
            current = pruned;
            // No progress means `alpha` can never prune anything (NaN); stop
            // instead of spinning forever.
            if remaining >= leaves {
                break;
            }
            leaves = remaining;
            alphas.push(alpha);
            impurities.push(risk);
        }
        (alphas, impurities)
    }

    /// Smallest effective alpha over all splits in `node`, or `None` for a leaf.
    ///
    /// NaN candidates are skipped by `f64::min` whenever a non-NaN candidate exists.
    fn min_effective_alpha(node: &TreeNode) -> Option<f64> {
        Self::alpha_stats(node).2
    }

    /// Single-pass helper returning `(leaf count, total leaf impurity, minimum
    /// effective alpha)` for `node`.
    fn alpha_stats(node: &TreeNode) -> (usize, f64, Option<f64>) {
        match node {
            TreeNode::Leaf {
                impurity,
                n_samples,
                ..
            } => (1, *impurity * (*n_samples as f64), None),
            TreeNode::Split {
                left,
                right,
                n_samples,
                impurity,
                ..
            } => {
                let (left_leaves, left_risk, left_alpha) = Self::alpha_stats(left);
                let (right_leaves, right_risk, right_alpha) = Self::alpha_stats(right);

                let n_leaves = left_leaves + right_leaves;
                let r_subtree = left_risk + right_risk;
                let r_node = *impurity * (*n_samples as f64);
                let my_alpha = (r_node - r_subtree) / (n_leaves as f64 - 1.0);

                let min_alpha = [Some(my_alpha), left_alpha, right_alpha]
                    .iter()
                    .flatten()
                    .copied()
                    .reduce(f64::min);
                (n_leaves, r_subtree, min_alpha)
            }
        }
    }
}