use super::{Leaf, Node, RegressionLeaf, RegressionNode, RegressionTreeNode, TreeNode};
use std::collections::HashSet;
/// Serializes `node` and all of its descendants into six parallel vectors,
/// appending one slot per tree node.
///
/// The node itself is written first, followed by its whole left subtree and
/// then its whole right subtree. Slot contents:
///
/// * leaf — `features = -1.0`, `thresholds = 0.0`, `classes`/`samples` carry
///   the leaf's label and sample count, both child slots are `-1.0`;
/// * internal node — `features`/`thresholds` carry the split, `classes` and
///   `samples` are `0.0`, and the child slots carry the indices at which the
///   left and right children were written.
///
/// The slot index is taken from `features.len()`, so all six vectors are
/// expected to have the same length on entry.
///
/// Returns the index at which `node` was written.
pub(super) fn flatten_tree_node(
    node: &TreeNode,
    features: &mut Vec<f32>,
    thresholds: &mut Vec<f32>,
    classes: &mut Vec<f32>,
    samples: &mut Vec<f32>,
    left_children: &mut Vec<f32>,
    right_children: &mut Vec<f32>,
) -> usize {
    let slot = features.len();

    // Leaves are terminal: write the slot and stop here.
    let internal = match node {
        TreeNode::Leaf(leaf) => {
            features.push(-1.0);
            thresholds.push(0.0);
            classes.push(leaf.class_label as f32);
            samples.push(leaf.n_samples as f32);
            left_children.push(-1.0);
            right_children.push(-1.0);
            return slot;
        }
        TreeNode::Node(internal) => internal,
    };

    features.push(internal.feature_idx as f32);
    thresholds.push(internal.threshold);
    classes.push(0.0);
    samples.push(0.0);
    // Child indices are unknown until the subtrees have been written, so
    // reserve the slots now and patch them afterwards.
    left_children.push(0.0);
    right_children.push(0.0);

    let left_slot = flatten_tree_node(
        &internal.left,
        features,
        thresholds,
        classes,
        samples,
        left_children,
        right_children,
    );
    let right_slot = flatten_tree_node(
        &internal.right,
        features,
        thresholds,
        classes,
        samples,
        left_children,
        right_children,
    );

    left_children[slot] = left_slot as f32;
    right_children[slot] = right_slot as f32;
    slot
}
/// Rebuilds the subtree rooted at slot `idx` from the six parallel slices.
///
/// A slot whose `features` entry is negative is read as a leaf (label and
/// sample count come from `classes` and `samples`). Any other slot is read as
/// an internal node whose children live at the indices stored in
/// `left_children[idx]` and `right_children[idx]`.
///
/// # Panics
///
/// Panics if `idx`, or any child index reached during the walk, is out of
/// bounds for the slices.
pub(super) fn reconstruct_tree_node(
    idx: usize,
    features: &[f32],
    thresholds: &[f32],
    classes: &[f32],
    samples: &[f32],
    left_children: &[f32],
    right_children: &[f32],
) -> TreeNode {
    if features[idx] < 0.0 {
        return TreeNode::Leaf(Leaf {
            class_label: classes[idx] as usize,
            n_samples: samples[idx] as usize,
        });
    }

    // Both children are rebuilt from the same set of slices; only the slot differs.
    let rebuild = |child: usize| {
        Box::new(reconstruct_tree_node(
            child,
            features,
            thresholds,
            classes,
            samples,
            left_children,
            right_children,
        ))
    };

    let left_slot = left_children[idx] as usize;
    let right_slot = right_children[idx] as usize;
    TreeNode::Node(Node {
        feature_idx: features[idx] as usize,
        threshold: thresholds[idx],
        left: rebuild(left_slot),
        right: rebuild(right_slot),
    })
}
/// Computes the Gini impurity of a set of class labels.
///
/// The result is `1 - Σ pᵢ²`, where `pᵢ` is the fraction of `labels` equal to
/// the i-th distinct label. An empty slice yields `0.0`.
pub fn gini_impurity(labels: &[usize]) -> f32 {
    contract_pre_gini_impurity!();
    if labels.is_empty() {
        return 0.0;
    }

    // Tally occurrences per label; the ordered map gives a fixed summation order.
    let mut tally: std::collections::BTreeMap<usize, i32> = std::collections::BTreeMap::new();
    for &label in labels {
        *tally.entry(label).or_insert(0) += 1;
    }

    let total = labels.len() as f32;
    tally.values().fold(1.0, |remaining, &occurrences| {
        let share = occurrences as f32 / total;
        remaining - share * share
    })
}
/// Computes the Gini impurity of a two-way split.
///
/// Each side's impurity is weighted by the fraction of all labels that fall
/// on that side. Returns `0.0` when both sides are empty.
pub fn gini_split(left_labels: &[usize], right_labels: &[usize]) -> f32 {
    contract_pre_gini_split!();
    let n_total = left_labels.len() as f32 + right_labels.len() as f32;
    if n_total == 0.0 {
        return 0.0;
    }

    // Size-weighted impurity contribution of one side of the split.
    let weighted = |side: &[usize]| side.len() as f32 / n_total * gini_impurity(side);
    weighted(left_labels) + weighted(right_labels)
}
/// Returns the values of `x` in ascending order with duplicates removed.
///
/// A value is treated as a duplicate when it differs from the previously kept
/// value by no more than `1e-10`. An empty input yields an empty vector.
///
/// # Panics
///
/// Panics if two elements cannot be ordered during the sort (i.e. a NaN is
/// compared against another value).
pub(super) fn get_sorted_unique_values(x: &[f32]) -> Vec<f32> {
    // Sort a copy of the values directly; the sort is stable, so equal values
    // keep their input order.
    let mut sorted = x.to_vec();
    sorted.sort_by(|a, b| a.partial_cmp(b).expect("f32 values should be comparable"));

    let mut unique_values: Vec<f32> = Vec::new();
    for value in sorted {
        // The first value is always kept. Afterwards a value is kept only when
        // it is strictly farther than the tolerance from the last kept one; a
        // NaN difference (e.g. inf - inf) therefore counts as a duplicate.
        let is_new = unique_values
            .last()
            .map_or(true, |&prev| (value - prev).abs() > 1e-10);
        if is_new {
            unique_values.push(value);
        }
    }
    unique_values
}
/// Partitions the labels in `y` according to the matching feature values in `x`.
///
/// The label at position `i` goes to the left group when `x[i] <= threshold`
/// and to the right group otherwise. Returns `None` when either group ends up
/// empty.
///
/// # Panics
///
/// Panics if `y` is shorter than `x`.
pub(super) fn split_labels_by_threshold(
    x: &[f32],
    y: &[usize],
    threshold: f32,
) -> Option<(Vec<usize>, Vec<usize>)> {
    let mut below: Vec<usize> = Vec::new();
    let mut above: Vec<usize> = Vec::new();

    for (position, &value) in x.iter().enumerate() {
        let bucket = if value <= threshold {
            &mut below
        } else {
            &mut above
        };
        bucket.push(y[position]);
    }

    // A split that leaves one side empty separates nothing.
    let both_sides_populated = !below.is_empty() && !above.is_empty();
    both_sides_populated.then_some((below, above))
}
/// Returns how much a split reduces impurity: `current_impurity` minus the
/// weighted Gini impurity of the two label groups produced by the split.
pub(super) fn calculate_information_gain(
    current_impurity: f32,
    left_labels: &[usize],
    right_labels: &[usize],
) -> f32 {
    current_impurity - gini_split(left_labels, right_labels)
}
/// Searches a single feature column for the threshold with the highest
/// impurity reduction.
///
/// Candidate thresholds are the midpoints between consecutive distinct sorted
/// values of `x`. Returns `Some((threshold, gain))` for the first candidate
/// achieving the largest strictly positive gain, or `None` when `x` has fewer
/// than two samples, fewer than two distinct values, or no candidate improves
/// on the impurity of `y`.
pub(super) fn find_best_split_for_feature(x: &[f32], y: &[usize]) -> Option<(f32, f32)> {
    if x.len() < 2 {
        return None;
    }
    let candidates = get_sorted_unique_values(x);
    if candidates.len() < 2 {
        return None;
    }

    let parent_impurity = gini_impurity(y);
    // `(threshold, gain)` of the best candidate so far; `None` until some
    // candidate achieves a strictly positive gain.
    let mut best: Option<(f32, f32)> = None;

    for neighbours in candidates.windows(2) {
        let threshold = f32::midpoint(neighbours[0], neighbours[1]);
        let Some((left_labels, right_labels)) = split_labels_by_threshold(x, y, threshold) else {
            continue;
        };
        let gain = calculate_information_gain(parent_impurity, &left_labels, &right_labels);
        let gain_to_beat = best.map_or(0.0, |(_, best_gain)| best_gain);
        if gain > gain_to_beat {
            best = Some((threshold, gain));
        }
    }
    best
}
/// Searches every feature column of `x_matrix` for the split with the highest
/// impurity reduction.
///
/// Returns `Some((feature_idx, threshold, gain))` for the best split found,
/// preferring the lowest feature index on ties, or `None` when there are fewer
/// than two samples or no feature offers a strictly positive gain.
pub(super) fn find_best_split(
    x_matrix: &crate::primitives::Matrix<f32>,
    y: &[usize],
) -> Option<(usize, f32, f32)> {
    let (n_samples, n_features) = x_matrix.shape();
    if n_samples < 2 {
        return None;
    }

    // `(feature, threshold, gain)` of the best split so far; `None` until a
    // feature yields a strictly positive gain.
    let mut best: Option<(usize, f32, f32)> = None;

    for feature_idx in 0..n_features {
        // Copy the feature's values for all rows into a contiguous column.
        let column: Vec<f32> = (0..n_samples)
            .map(|row| x_matrix.get(row, feature_idx))
            .collect();

        let Some((threshold, gain)) = find_best_split_for_feature(&column, y) else {
            continue;
        };
        let gain_to_beat = best.map_or(0.0, |(_, _, best_gain)| best_gain);
        if gain > gain_to_beat {
            best = Some((feature_idx, threshold, gain));
        }
    }
    best
}
/// Returns the label that occurs most often in `labels`.
///
/// When several labels share the highest count, the numerically largest of
/// them is returned.
///
/// # Panics
///
/// Panics if `labels` is empty.
pub(super) fn majority_class(labels: &[usize]) -> usize {
    let mut tally: std::collections::BTreeMap<usize, usize> = std::collections::BTreeMap::new();
    for &label in labels {
        *tally.entry(label).or_insert(0usize) += 1;
    }

    // Walk labels in ascending order; `>=` lets a later label replace an
    // earlier one with the same count.
    let mut winner: Option<(usize, usize)> = None;
    for (label, occurrences) in tally {
        let takes_lead = winner.map_or(true, |(_, best)| occurrences >= best);
        if takes_lead {
            winner = Some((label, occurrences));
        }
    }

    winner.expect("at least one label should exist").0
}
/// Builds a new matrix and label vector containing only the rows of `x` and
/// entries of `y` named by `indices`, in the order given.
///
/// # Panics
///
/// Panics if an index is out of bounds for `y`, or if the matrix cannot be
/// constructed from the gathered data.
pub(super) fn split_data_by_indices(
    x: &crate::primitives::Matrix<f32>,
    y: &[usize],
    indices: &[usize],
) -> (crate::primitives::Matrix<f32>, Vec<usize>) {
    let (_, n_cols) = x.shape();
    let n_rows = indices.len();

    let mut data = Vec::with_capacity(n_rows * n_cols);
    let mut labels = Vec::with_capacity(n_rows);
    for &row in indices {
        // Append the selected row's values column by column.
        data.extend((0..n_cols).map(|col| x.get(row, col)));
        labels.push(y[row]);
    }

    let matrix = crate::primitives::Matrix::from_vec(n_rows, n_cols, data)
        .expect("matrix creation should succeed with valid indices");
    (matrix, labels)
}
/// Decides whether tree growth should stop at the current node.
///
/// Returns a leaf when every label in `y` is the same (labelled with that
/// class), or when `max_depth` is set and `depth` has reached it (labelled
/// with the majority class). Returns `None` when growth may continue.
pub(super) fn check_stopping_criteria(
    y: &[usize],
    depth: usize,
    max_depth: Option<usize>,
) -> Option<TreeNode> {
    let n_samples = y.len();
    let leaf_for = |class_label: usize| {
        TreeNode::Leaf(Leaf {
            class_label,
            n_samples,
        })
    };

    // Pure node: exactly one distinct label remains.
    let distinct: HashSet<&usize> = y.iter().collect();
    if distinct.len() == 1 {
        return Some(leaf_for(y[0]));
    }

    match max_depth {
        Some(limit) if depth >= limit => Some(leaf_for(majority_class(y))),
        _ => None,
    }
}
/// Partitions the row indices `0..n_samples` by comparing column
/// `feature_idx` of `x` against `threshold`.
///
/// Rows whose value is `<= threshold` go left, all others go right; both
/// lists keep ascending row order. Returns `None` when either side is empty.
pub(super) fn split_indices_by_threshold(
    x: &crate::primitives::Matrix<f32>,
    feature_idx: usize,
    threshold: f32,
    n_samples: usize,
) -> Option<(Vec<usize>, Vec<usize>)> {
    let (below, above): (Vec<usize>, Vec<usize>) =
        (0..n_samples).partition(|&row| x.get(row, feature_idx) <= threshold);

    // A split that leaves one side empty separates nothing.
    let both_sides_populated = !below.is_empty() && !above.is_empty();
    both_sides_populated.then_some((below, above))
}
/// Recursively grows a classification tree over the samples in `x` / `y`.
///
/// A leaf is produced when the stopping criteria fire, when no split with a
/// positive gain exists, or when the chosen split would leave one side empty;
/// in the latter two cases the leaf carries the majority class. Otherwise the
/// samples are partitioned by the best split and each side is grown at
/// `depth + 1`.
pub(super) fn build_tree(
    x: &crate::primitives::Matrix<f32>,
    y: &[usize],
    depth: usize,
    max_depth: Option<usize>,
) -> TreeNode {
    if let Some(leaf) = check_stopping_criteria(y, depth, max_depth) {
        return leaf;
    }
    let n_samples = y.len();

    // Find the best split and the row partition it induces in one step;
    // failure of either part means this node becomes a leaf.
    let partition = find_best_split(x, y).and_then(|(feature_idx, threshold, _gain)| {
        split_indices_by_threshold(x, feature_idx, threshold, n_samples)
            .map(|(left, right)| (feature_idx, threshold, left, right))
    });
    let Some((feature_idx, threshold, left_indices, right_indices)) = partition else {
        return TreeNode::Leaf(Leaf {
            class_label: majority_class(y),
            n_samples,
        });
    };

    // Gather the rows for one side and grow its subtree one level deeper.
    let grow = |rows: &[usize]| {
        let (sub_matrix, sub_labels) = split_data_by_indices(x, y, rows);
        Box::new(build_tree(&sub_matrix, &sub_labels, depth + 1, max_depth))
    };

    let left = grow(&left_indices);
    let right = grow(&right_indices);
    TreeNode::Node(Node {
        feature_idx,
        threshold,
        left,
        right,
    })
}
include!("regression_helpers.rs");
include!("helpers_tests.rs");