use std::collections::HashMap;
use anofox_ml_core::Float;
use ndarray::{Array1, Array2};
/// Hashable key for a floating-point label.
///
/// Labels are grouped by the bit pattern of their `f64` representation. `-0.0` is folded into
/// `+0.0` first. The two compare equal under IEEE 754 but have different bit patterns, and
/// without this step they would be counted as two distinct classes.
///
/// # Panics
/// Panics if `v` cannot be converted to `f64` (`to_f64` returns `None`).
#[inline]
fn float_key<F: Float>(v: F) -> u64 {
    let value = v.to_f64().unwrap();
    // `-0.0 == 0.0` is true, so this only rewrites the negative zero.
    let canonical = if value == 0.0 { 0.0 } else { value };
    canonical.to_bits()
}
/// Impurity measure used to score candidate splits.
///
/// `Gini` and `Entropy` treat the targets as class labels; `Mse` treats them as continuous
/// values (see `compute_impurity`).
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum SplitCriterion {
    /// Gini impurity, `1 - sum_k p_k^2` over class proportions.
    Gini,
    /// Shannon entropy with the natural logarithm, `-sum_k p_k ln p_k`.
    Entropy,
    /// Variance of the target around the node mean.
    Mse,
}
/// How split thresholds are chosen at a node.
///
/// NOTE(review): presumably `Best` selects the exhaustive search (`find_best_split*`) and
/// `Random` selects `find_random_split`. The dispatch on this enum is not visible here, so
/// confirm against the caller.
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum SplitStrategy {
    /// Exhaustive search over thresholds.
    Best,
    /// Randomly drawn thresholds.
    Random,
}
/// Rule for how many features are considered at each split.
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum MaxFeatures {
    /// `floor(sqrt(n_features))`.
    Sqrt,
    /// `floor(log2(n_features))`.
    Log2,
    /// A fixed number of features.
    Fixed(usize),
    /// `floor(fraction * n_features)`.
    Fraction(f64),
}

impl MaxFeatures {
    /// Resolve the rule to a concrete feature count for a dataset with `n_features` columns.
    ///
    /// Every variant is clamped to `1..=n_features`, or to exactly `1` when `n_features == 0`.
    /// Previously only `Fixed` was clamped, so `Fraction(1.5)` could request more features
    /// than exist. Negative, NaN or `-inf` intermediate values saturate to `0` in the
    /// `as usize` cast and then clamp up to `1`, as before.
    pub fn resolve(&self, n_features: usize) -> usize {
        let raw = match self {
            MaxFeatures::Sqrt => (n_features as f64).sqrt().floor() as usize,
            MaxFeatures::Log2 => (n_features as f64).log2().floor() as usize,
            MaxFeatures::Fixed(k) => *k,
            MaxFeatures::Fraction(f) => (*f * n_features as f64).floor() as usize,
        };
        raw.min(n_features).max(1)
    }
}
/// Pick `k` distinct feature indices from `0..n_features`, deterministically for a given `seed`.
///
/// The selection is a partial Fisher–Yates shuffle driven by a xorshift64 generator (shifts
/// 13/7/17) seeded with `seed + 0x9E3779B97F4A7C15`. The chosen indices are returned sorted
/// ascending. When `k >= n_features`, all indices are returned.
///
/// Note: each draw is reduced with `%`, so there is a slight modulo bias.
pub fn select_feature_subset(n_features: usize, k: usize, seed: u64) -> Vec<usize> {
    let mut pool: Vec<usize> = (0..n_features).collect();
    if k >= n_features {
        return pool;
    }
    let mut rng = seed.wrapping_add(0x9E3779B97F4A7C15);
    let mut next_random = move || {
        rng ^= rng << 13;
        rng ^= rng >> 7;
        rng ^= rng << 17;
        rng
    };
    for slot in 0..k {
        // Swap a uniformly chosen not-yet-picked index into position `slot`.
        let remaining = n_features - slot;
        let pick = slot + (next_random() as usize) % remaining;
        pool.swap(slot, pick);
    }
    pool.truncate(k);
    pool.sort_unstable();
    pool
}
/// Result of a successful split search.
#[derive(Debug, Clone)]
pub struct BestSplit<F: Float> {
    /// Column of `x` that the split tests.
    pub feature_index: usize,
    /// Samples with `x[[i, feature_index]] <= threshold` go left; all others go right.
    pub threshold: F,
    /// Sample indices routed to the left child.
    pub left_indices: Vec<usize>,
    /// Sample indices routed to the right child.
    pub right_indices: Vec<usize>,
    /// Parent impurity minus the weighted average impurity of the two children.
    pub improvement: F,
}
/// Exhaustive best-split search over every column of `x`.
///
/// Convenience wrapper around [`find_best_split_with_features`] with the candidate feature set
/// `0..x.ncols()`.
pub fn find_best_split<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    indices: &[usize],
    criterion: SplitCriterion,
    min_samples_leaf: usize,
) -> Option<BestSplit<F>> {
    let feature_count = x.ncols();
    let candidates = (0..feature_count).collect::<Vec<usize>>();
    find_best_split_with_features(x, y, indices, criterion, min_samples_leaf, &candidates)
}
/// Find the split of `indices` that maximizes the impurity decrease, searching only the
/// columns listed in `feature_indices`.
///
/// Classification criteria (Gini/Entropy) and MSE each use their own incremental
/// accumulator but share the same sorted-scan search.
///
/// Returns `None` in any of these cases:
/// * the node has fewer than two samples;
/// * the node is too small to give both children `min_samples_leaf` samples;
/// * no listed feature offers a valid split.
pub fn find_best_split_with_features<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    indices: &[usize],
    criterion: SplitCriterion,
    min_samples_leaf: usize,
    feature_indices: &[usize],
) -> Option<BestSplit<F>> {
    let n = indices.len();
    // A split always needs at least one sample per side, even when `min_samples_leaf == 0`.
    // Without this, an empty node made the `n - 1` scan bound underflow.
    // `saturating_mul` keeps a huge `min_samples_leaf` from overflowing.
    if n < 2 || n < min_samples_leaf.saturating_mul(2) {
        return None;
    }
    let parent_impurity = compute_impurity(y, indices, criterion);
    match criterion {
        SplitCriterion::Gini | SplitCriterion::Entropy => find_best_split_classification(
            x,
            y,
            indices,
            criterion,
            min_samples_leaf,
            feature_indices,
            n,
            parent_impurity,
        ),
        SplitCriterion::Mse => find_best_split_regression(
            x,
            y,
            indices,
            min_samples_leaf,
            feature_indices,
            n,
            parent_impurity,
        ),
    }
}
/// Fill `sorted_pairs` with `(x[[i, feature]], i)` for every `i` in `indices`, ordered by
/// feature value.
///
/// The sort is stable, so rows with equal values keep their `indices` order. The buffer is
/// cleared first, which lets one allocation be reused across features.
///
/// # Panics
/// Panics when the comparator meets a NaN feature value (`partial_cmp` returns `None`).
#[inline]
fn sort_feature_pairs<F: Float>(
    x: &Array2<F>,
    indices: &[usize],
    feature: usize,
    sorted_pairs: &mut Vec<(F, usize)>,
) {
    sorted_pairs.clear();
    for &row in indices {
        sorted_pairs.push((x[[row, feature]], row));
    }
    sorted_pairs.sort_by(|(lhs, _), (rhs, _)| lhs.partial_cmp(rhs).unwrap());
}
/// Best split seen so far during a search. It is materialized into a [`BestSplit`] only
/// once the search is finished, which avoids building index vectors for losing candidates.
struct CandidateSplit<F: Float> {
    /// Column the split tests.
    feature: usize,
    /// Rows with `x <= threshold` go left.
    threshold: F,
    /// Impurity decrease achieved by this split.
    improvement: F,
}
/// Record `(feature, threshold, improvement)` as the new best candidate if `improvement`
/// strictly exceeds `best_improvement`.
///
/// Because the comparison is strict, the first of several equal improvements is kept, and a
/// NaN improvement never replaces the current best.
#[inline]
fn try_update_best_split<F: Float>(
    improvement: F,
    best_improvement: &mut F,
    best: &mut Option<CandidateSplit<F>>,
    feature: usize,
    threshold: F,
) {
    let beats_current = improvement > *best_improvement;
    if beats_current {
        let candidate = CandidateSplit {
            feature,
            threshold,
            improvement,
        };
        *best_improvement = improvement;
        *best = Some(candidate);
    }
}
/// Incremental statistics for a left/right partition of one node.
///
/// An implementation lets every threshold along a sorted feature be scored in a single pass.
/// Samples start on the right and move to the left one at a time.
trait SplitAccumulator<F: Float> {
    /// Put every sample of `indices` on the right side and empty the left side.
    fn reset(&mut self, y: &Array1<F>, indices: &[usize]);
    /// Move sample `idx` from the right side to the left side.
    fn move_to_left(&mut self, y: &Array1<F>, idx: usize);
    /// Child impurities combined with weights `n_left / n` and `n_right / n`.
    fn weighted_impurity(&self, n: usize) -> F;
    /// Number of samples currently on the left side.
    fn n_left(&self) -> usize;
    /// Number of samples currently on the right side.
    fn n_right(&self) -> usize;
}
/// [`SplitAccumulator`] for Gini/Entropy. It tracks per-class sample counts on each side of
/// the candidate split.
struct ClassificationAccumulator<F: Float> {
    left_counts: Vec<usize>,
    right_counts: Vec<usize>,
    n_left: usize,
    n_right: usize,
    criterion: SplitCriterion,
    /// Label key (from `float_key`) -> dense class index into the count vectors.
    class_map: HashMap<u64, usize>,
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float> ClassificationAccumulator<F> {
    /// Build an accumulator for the node `indices` with every sample on the right side.
    ///
    /// The criterion defaults to Gini; use [`Self::with_criterion`] to change it.
    fn new(y: &Array1<F>, indices: &[usize]) -> Self {
        let class_map = build_class_map(y, indices);
        let n_classes = class_map.len();
        let mut acc = Self {
            left_counts: vec![0usize; n_classes],
            right_counts: vec![0usize; n_classes],
            n_left: 0,
            n_right: 0,
            criterion: SplitCriterion::Gini,
            class_map,
            _marker: std::marker::PhantomData,
        };
        // Filling the right side is exactly what `reset` does.
        acc.reset(y, indices);
        acc
    }

    /// Select the impurity measure (Gini or Entropy) used by `weighted_impurity`.
    fn with_criterion(mut self, criterion: SplitCriterion) -> Self {
        self.criterion = criterion;
        self
    }

    /// Dense class index of the label of sample `idx`.
    ///
    /// Panics if that label was not present when the class map was built.
    fn class_of(&self, y: &Array1<F>, idx: usize) -> usize {
        self.class_map[&float_key(y[idx])]
    }
}

impl<F: Float> SplitAccumulator<F> for ClassificationAccumulator<F> {
    fn reset(&mut self, y: &Array1<F>, indices: &[usize]) {
        self.left_counts.fill(0);
        self.right_counts.fill(0);
        for &sample in indices {
            let class = self.class_of(y, sample);
            self.right_counts[class] += 1;
        }
        self.n_left = 0;
        self.n_right = indices.len();
    }

    fn move_to_left(&mut self, y: &Array1<F>, idx: usize) {
        let class = self.class_of(y, idx);
        self.left_counts[class] += 1;
        self.right_counts[class] -= 1;
        self.n_left += 1;
        self.n_right -= 1;
    }

    fn weighted_impurity(&self, n: usize) -> F {
        let total = F::from_usize(n).unwrap();
        let left_size = F::from_usize(self.n_left).unwrap();
        let right_size = F::from_usize(self.n_right).unwrap();
        let left = impurity_from_counts::<F>(&self.left_counts, self.n_left, self.criterion);
        let right = impurity_from_counts::<F>(&self.right_counts, self.n_right, self.criterion);
        (left_size / total) * left + (right_size / total) * right
    }

    fn n_left(&self) -> usize {
        self.n_left
    }

    fn n_right(&self) -> usize {
        self.n_right
    }
}
/// [`SplitAccumulator`] for the MSE criterion. It keeps running sums of `y` and `y^2` on
/// each side of the candidate split.
struct RegressionAccumulator<F: Float> {
    left_sum: F,
    left_sum_sq: F,
    right_sum: F,
    right_sum_sq: F,
    n_left: usize,
    n_right: usize,
}

impl<F: Float> RegressionAccumulator<F> {
    /// Accumulator for the node `indices` with every sample on the right side.
    fn new(y: &Array1<F>, indices: &[usize]) -> Self {
        let mut acc = Self {
            left_sum: F::zero(),
            left_sum_sq: F::zero(),
            right_sum: F::zero(),
            right_sum_sq: F::zero(),
            n_left: 0,
            n_right: 0,
        };
        acc.reset(y, indices);
        acc
    }

    /// Population variance of one side from its running sums, `sum_sq / count - mean^2`.
    ///
    /// This one-pass formula cancels badly when the mean is large relative to the spread.
    /// For (near-)constant targets it can come out slightly negative. A variance is never
    /// negative, so the result is clamped at zero. Otherwise the spurious negative term would
    /// push the reported improvement above the parent impurity. A `count` of zero still
    /// yields NaN.
    fn side_variance(sum: F, sum_sq: F, count: F) -> F {
        let mean = sum / count;
        let variance = sum_sq / count - mean * mean;
        if variance < F::zero() {
            F::zero()
        } else {
            variance
        }
    }
}

impl<F: Float> SplitAccumulator<F> for RegressionAccumulator<F> {
    fn reset(&mut self, y: &Array1<F>, indices: &[usize]) {
        self.left_sum = F::zero();
        self.left_sum_sq = F::zero();
        self.right_sum = F::zero();
        self.right_sum_sq = F::zero();
        for &i in indices {
            let v = y[i];
            self.right_sum += v;
            self.right_sum_sq += v * v;
        }
        self.n_left = 0;
        self.n_right = indices.len();
    }

    fn move_to_left(&mut self, y: &Array1<F>, idx: usize) {
        let v = y[idx];
        self.left_sum += v;
        self.left_sum_sq += v * v;
        self.right_sum -= v;
        self.right_sum_sq -= v * v;
        self.n_left += 1;
        self.n_right -= 1;
    }

    fn weighted_impurity(&self, n: usize) -> F {
        let n_f = F::from_usize(n).unwrap();
        let nl = F::from_usize(self.n_left).unwrap();
        let nr = F::from_usize(self.n_right).unwrap();
        let left_mse = Self::side_variance(self.left_sum, self.left_sum_sq, nl);
        let right_mse = Self::side_variance(self.right_sum, self.right_sum_sq, nr);
        (nl / n_f) * left_mse + (nr / n_f) * right_mse
    }

    fn n_left(&self) -> usize {
        self.n_left
    }

    fn n_right(&self) -> usize {
        self.n_right
    }
}
/// Score the boundary between two adjacent sorted feature values, `cur_val < next_val`.
///
/// `acc` must already hold, on its left side, every sample up to and including the one at
/// `cur_val`.
///
/// Returns `Some((threshold, improvement))`. Returns `None` when the values are
/// indistinguishable (closer than `1e-15`) or when either child would have fewer than
/// `min_samples_leaf` samples.
///
/// The threshold is the midpoint computed as `cur/2 + next/2`, so large magnitudes cannot
/// overflow to infinity. It falls back to `cur_val` in two cases:
/// * rounding pushed it up to `next_val` (adjacent floats);
/// * it is still infinite.
///
/// This guarantees that `x <= threshold` reproduces exactly the partition that was scored.
#[inline]
fn evaluate_candidate_split<F: Float, A: SplitAccumulator<F>>(
    acc: &A,
    n: usize,
    min_samples_leaf: usize,
    cur_val: F,
    next_val: F,
    parent_impurity: F,
) -> Option<(F, F)> {
    // `next_val <= cur_val` also rejects two equal infinities, whose difference is NaN.
    if next_val <= cur_val || (next_val - cur_val).abs() < F::from_f64(1e-15).unwrap() {
        return None;
    }
    if acc.n_left() < min_samples_leaf || acc.n_right() < min_samples_leaf {
        return None;
    }
    let two = F::one() + F::one();
    let midpoint = cur_val / two + next_val / two;
    let threshold = if midpoint >= next_val || midpoint.is_infinite() {
        cur_val
    } else {
        midpoint
    };
    let improvement = parent_impurity - acc.weighted_impurity(n);
    Some((threshold, improvement))
}
/// Shared sorted-scan search behind the classification and regression entry points.
///
/// For each feature in `feature_indices`:
/// 1. The node's samples are sorted by that feature.
/// 2. Samples move one at a time from the right side of `acc` to the left.
/// 3. Each boundary between adjacent sorted values is scored by `evaluate_candidate_split`.
///
/// The strictly best candidate (the first one wins ties) is then materialized into a
/// [`BestSplit`] by routing each sample with `x <= threshold` to the left.
///
/// `n` is expected to equal `indices.len()`, and `acc` must be built for the same `indices`.
#[allow(clippy::too_many_arguments)]
fn find_best_split_inner<F, A>(
    x: &Array2<F>,
    y: &Array1<F>,
    indices: &[usize],
    min_samples_leaf: usize,
    feature_indices: &[usize],
    n: usize,
    parent_impurity: F,
    mut acc: A,
) -> Option<BestSplit<F>>
where
    F: Float,
    A: SplitAccumulator<F>,
{
    let mut best: Option<CandidateSplit<F>> = None;
    let mut best_improvement = F::neg_infinity();
    let mut sorted_pairs: Vec<(F, usize)> = Vec::with_capacity(n);
    for &feature in feature_indices {
        sort_feature_pairs(x, indices, feature, &mut sorted_pairs);
        acc.reset(y, indices);
        // `windows(2)` yields nothing for fewer than two samples. Unlike `0..n - 1`, it
        // cannot underflow on an empty node.
        for pair in sorted_pairs.windows(2) {
            let (cur_val, cur_idx) = pair[0];
            let next_val = pair[1].0;
            acc.move_to_left(y, cur_idx);
            if let Some((threshold, improvement)) = evaluate_candidate_split(
                &acc,
                n,
                min_samples_leaf,
                cur_val,
                next_val,
                parent_impurity,
            ) {
                try_update_best_split(
                    improvement,
                    &mut best_improvement,
                    &mut best,
                    feature,
                    threshold,
                );
            }
        }
    }
    best.map(|candidate| {
        let mut left_indices = Vec::with_capacity(n);
        let mut right_indices = Vec::with_capacity(n);
        for &i in indices {
            if x[[i, candidate.feature]] <= candidate.threshold {
                left_indices.push(i);
            } else {
                right_indices.push(i);
            }
        }
        BestSplit {
            feature_index: candidate.feature,
            threshold: candidate.threshold,
            left_indices,
            right_indices,
            improvement: candidate.improvement,
        }
    })
}
/// Classification entry point for the exhaustive search.
///
/// It builds a [`ClassificationAccumulator`] configured for `criterion` (Gini or Entropy)
/// and runs the shared sorted-scan search.
#[allow(clippy::too_many_arguments)]
fn find_best_split_classification<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    indices: &[usize],
    criterion: SplitCriterion,
    min_samples_leaf: usize,
    feature_indices: &[usize],
    n: usize,
    parent_impurity: F,
) -> Option<BestSplit<F>> {
    find_best_split_inner(
        x,
        y,
        indices,
        min_samples_leaf,
        feature_indices,
        n,
        parent_impurity,
        ClassificationAccumulator::<F>::new(y, indices).with_criterion(criterion),
    )
}
/// Regression (MSE) entry point for the exhaustive search.
///
/// It builds a [`RegressionAccumulator`] for the node and runs the shared sorted-scan search.
fn find_best_split_regression<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    indices: &[usize],
    min_samples_leaf: usize,
    feature_indices: &[usize],
    n: usize,
    parent_impurity: F,
) -> Option<BestSplit<F>> {
    find_best_split_inner(
        x,
        y,
        indices,
        min_samples_leaf,
        feature_indices,
        n,
        parent_impurity,
        RegressionAccumulator::<F>::new(y, indices),
    )
}
/// Assign a dense class index `0, 1, 2, ...` to each distinct label of `y[indices]`, keyed by
/// `float_key`. Indices follow the order in which labels first appear in `indices`.
fn build_class_map<F: Float>(y: &Array1<F>, indices: &[usize]) -> HashMap<u64, usize> {
    let mut map = HashMap::new();
    for &i in indices {
        // The next free index is the number of classes seen so far.
        let next_free = map.len();
        map.entry(float_key(y[i])).or_insert(next_free);
    }
    map
}
/// Impurity of a node from its per-class sample counts.
///
/// `total` is the node size. Zero counts are skipped, so an empty node scores `1` under Gini
/// and `0` under Entropy.
///
/// # Panics
/// Panics (via `unreachable!`) for [`SplitCriterion::Mse`], which has no class counts.
#[inline]
fn impurity_from_counts<F: Float>(counts: &[usize], total: usize, criterion: SplitCriterion) -> F {
    let n = F::from_usize(total).unwrap();
    let proportions = counts
        .iter()
        .filter(|&&c| c > 0)
        .map(|&c| F::from_usize(c).unwrap() / n);
    match criterion {
        SplitCriterion::Gini => {
            let sum_sq = proportions.fold(F::zero(), |acc, p| acc + p * p);
            F::one() - sum_sq
        }
        SplitCriterion::Entropy => {
            let sum = proportions.fold(F::zero(), |acc, p| acc + p * p.ln());
            -sum
        }
        SplitCriterion::Mse => unreachable!("MSE does not use class counts"),
    }
}
#[inline]
pub fn compute_impurity<F: Float>(
y: &Array1<F>,
indices: &[usize],
criterion: SplitCriterion,
) -> F {
match criterion {
SplitCriterion::Gini => gini(y, indices),
SplitCriterion::Entropy => entropy(y, indices),
SplitCriterion::Mse => mse_impurity(y, indices),
}
}
/// Gini impurity `1 - sum_k p_k^2` of the labels `y[indices]`.
///
/// An empty node has no impurity and returns `0`, matching `compute_weighted_impurity`.
/// Previously the empty sum made it return `1`.
#[inline]
fn gini<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
    if indices.is_empty() {
        return F::zero();
    }
    let n = F::from_usize(indices.len()).unwrap();
    let class_counts = count_classes(y, indices);
    let sum_sq: F = class_counts
        .iter()
        .map(|&(_, count)| {
            let p = F::from_usize(count).unwrap() / n;
            p * p
        })
        .fold(F::zero(), |a, b| a + b);
    F::one() - sum_sq
}
/// Shannon entropy `-sum_k p_k ln p_k` (natural log) of the labels `y[indices]`.
#[inline]
fn entropy<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
    let n = F::from_usize(indices.len()).unwrap();
    let mut sum = F::zero();
    for (_, count) in count_classes(y, indices) {
        let p = F::from_usize(count).unwrap() / n;
        // Guard against `0 * ln 0`; the limit of `p ln p` as p -> 0 is 0.
        let term = if p > F::zero() { p * p.ln() } else { F::zero() };
        sum = sum + term;
    }
    -sum
}
/// Population variance of `y[indices]` around its mean, computed in two passes.
///
/// An empty node returns `0` instead of the NaN produced by `0 / 0`, matching
/// `compute_weighted_impurity`.
#[inline]
fn mse_impurity<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
    if indices.is_empty() {
        return F::zero();
    }
    let n = F::from_usize(indices.len()).unwrap();
    let mean: F = indices.iter().map(|&i| y[i]).fold(F::zero(), |a, b| a + b) / n;
    indices
        .iter()
        .map(|&i| (y[i] - mean) * (y[i] - mean))
        .fold(F::zero(), |a, b| a + b)
        / n
}
/// Count the occurrences of each distinct label in `y[indices]`.
///
/// Labels are grouped by `float_key`, and each entry keeps the first value seen for its key.
/// The result is sorted by label value in ascending order.
///
/// # Panics
/// Panics if two distinct labels cannot be ordered (e.g. a NaN label alongside others).
pub fn count_classes<F: Float>(y: &Array1<F>, indices: &[usize]) -> Vec<(F, usize)> {
    let mut tally: HashMap<u64, (F, usize)> = HashMap::new();
    for label in indices.iter().map(|&i| y[i]) {
        let slot = tally.entry(float_key(label)).or_insert((label, 0));
        slot.1 += 1;
    }
    let mut counts = tally.into_values().collect::<Vec<(F, usize)>>();
    counts.sort_by(|lhs, rhs| lhs.0.partial_cmp(&rhs.0).unwrap());
    counts
}
/// Prediction stored in a leaf that holds `indices`.
///
/// * `Mse`: mean of `y[indices]` (NaN for an empty leaf, from `0 / 0`).
/// * `Gini` / `Entropy`: the most frequent label. On a tie the largest label wins, because
///   `count_classes` sorts ascending and `max_by_key` returns the last maximum.
///
/// # Panics
/// Panics for `Gini`/`Entropy` when `indices` is empty.
#[inline]
pub fn leaf_value<F: Float>(y: &Array1<F>, indices: &[usize], criterion: SplitCriterion) -> F {
    match criterion {
        SplitCriterion::Gini | SplitCriterion::Entropy => {
            let (majority, _) = count_classes(y, indices)
                .into_iter()
                .max_by_key(|&(_, count)| count)
                .unwrap();
            majority
        }
        SplitCriterion::Mse => {
            let count = F::from_usize(indices.len()).unwrap();
            let total = indices.iter().fold(F::zero(), |acc, &i| acc + y[i]);
            total / count
        }
    }
}
/// Extremely-randomized split search.
///
/// For every non-constant feature, one threshold is drawn uniformly in `[min, max]` of that
/// feature over `indices`. The draw with the largest impurity decrease is kept.
///
/// Thresholds come from a xorshift64 generator seeded with `seed`, so results are
/// reproducible. Features whose values span less than `1e-15` are skipped and do not consume
/// a random draw.
///
/// Returns `None` in any of these cases:
/// * the node has fewer than two samples;
/// * the node cannot give both children `min_samples_leaf` samples;
/// * no drawn threshold yields a valid partition.
///
/// Each child always receives at least one sample, even when `min_samples_leaf` is `0`.
pub fn find_random_split<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    indices: &[usize],
    criterion: SplitCriterion,
    min_samples_leaf: usize,
    seed: u64,
) -> Option<BestSplit<F>> {
    let n = indices.len();
    // Also protects the `indices[0]` access below, which panicked on an empty node.
    if n < 2 || n < min_samples_leaf.saturating_mul(2) {
        return None;
    }
    // A draw can land exactly on the feature maximum. That would empty the right child,
    // which is never a useful split, so require at least one sample per side.
    let min_leaf = min_samples_leaf.max(1);
    let parent_impurity = compute_impurity(y, indices, criterion);
    let n_f = F::from_usize(n).unwrap();
    let eps = F::from_f64(1e-15).unwrap();

    let mut best: Option<CandidateSplit<F>> = None;
    let mut best_improvement = F::neg_infinity();
    let mut rng_state = seed.wrapping_add(0x9E3779B97F4A7C15);
    // Partition buffers reused across features: one pass fills both, with no per-feature
    // allocation.
    let mut left: Vec<usize> = Vec::with_capacity(n);
    let mut right: Vec<usize> = Vec::with_capacity(n);

    for feature in 0..x.ncols() {
        let first = x[[indices[0], feature]];
        let (min_val, max_val) = indices[1..].iter().fold((first, first), |(lo, hi), &i| {
            let v = x[[i, feature]];
            (if v < lo { v } else { lo }, if v > hi { v } else { hi })
        });
        if (max_val - min_val).abs() < eps {
            continue;
        }
        rng_state ^= rng_state << 13;
        rng_state ^= rng_state >> 7;
        rng_state ^= rng_state << 17;
        let t = F::from_f64((rng_state as f64) / (u64::MAX as f64)).unwrap();
        let threshold = min_val + t * (max_val - min_val);

        left.clear();
        right.clear();
        for &i in indices {
            if x[[i, feature]] <= threshold {
                left.push(i);
            } else {
                right.push(i);
            }
        }
        if left.len() < min_leaf || right.len() < min_leaf {
            continue;
        }
        let left_imp = compute_impurity(y, &left, criterion);
        let right_imp = compute_impurity(y, &right, criterion);
        let nl_f = F::from_usize(left.len()).unwrap();
        let nr_f = F::from_usize(right.len()).unwrap();
        let weighted = (nl_f / n_f) * left_imp + (nr_f / n_f) * right_imp;
        let improvement = parent_impurity - weighted;
        try_update_best_split(
            improvement,
            &mut best_improvement,
            &mut best,
            feature,
            threshold,
        );
    }
    best.map(|candidate| {
        let mut left_indices = Vec::with_capacity(n);
        let mut right_indices = Vec::with_capacity(n);
        for &i in indices {
            if x[[i, candidate.feature]] <= candidate.threshold {
                left_indices.push(i);
            } else {
                right_indices.push(i);
            }
        }
        BestSplit {
            feature_index: candidate.feature,
            threshold: candidate.threshold,
            left_indices,
            right_indices,
            improvement: candidate.improvement,
        }
    })
}
/// Per-class sample weighting scheme, expanded by `compute_sample_weights_from_class_weight`.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum ClassWeight {
    /// Weight each sample by `n_samples / (n_classes * size_of_its_class)`.
    Balanced,
    /// Explicit `(class_label, weight)` pairs. Labels match within an absolute tolerance of
    /// `1e-9`, and samples of unlisted classes keep weight `1`.
    Manual(Vec<(f64, f64)>),
}
pub fn compute_sample_weights_from_class_weight<F: Float>(
y: &Array1<F>,
class_weight: &ClassWeight,
) -> Array1<F> {
let n_samples = y.len();
match class_weight {
ClassWeight::Balanced => {
let counts = count_classes(y, &(0..n_samples).collect::<Vec<_>>());
let n_classes = counts.len();
let n_f = F::from_usize(n_samples).unwrap();
let nc_f = F::from_usize(n_classes).unwrap();
let mut weights = Array1::<F>::ones(n_samples);
for i in 0..n_samples {
for &(class_val, count) in &counts {
if (y[i] - class_val).abs() < F::from_f64(1e-9).unwrap() {
weights[i] = n_f / (nc_f * F::from_usize(count).unwrap());
break;
}
}
}
weights
}
ClassWeight::Manual(mapping) => {
let mut weights = Array1::<F>::ones(n_samples);
for i in 0..n_samples {
let yi = y[i].to_f64().unwrap();
for &(class_val, w) in mapping {
if (yi - class_val).abs() < 1e-9 {
weights[i] = F::from_f64(w).unwrap();
break;
}
}
}
weights
}
}
}
/// Weighted impurity of the node `indices`, where sample `i` contributes `weights[i]`.
///
/// * `Gini`: `1 - sum_k p_k^2`, with `p_k` the class's share of the total weight.
/// * `Entropy`: `-sum_k p_k ln p_k`, over classes with positive weight.
/// * `Mse`: weighted variance of `y` around its weighted mean.
///
/// Returns `0` when the total weight is not positive, which includes an empty node.
pub fn compute_weighted_impurity<F: Float>(
    y: &Array1<F>,
    indices: &[usize],
    weights: &Array1<F>,
    criterion: SplitCriterion,
) -> F {
    let total_weight: F = indices
        .iter()
        .map(|&i| weights[i])
        .fold(F::zero(), |a, b| a + b);
    if total_weight <= F::zero() {
        return F::zero();
    }
    // Per-class weight totals, iterated in key order. A `HashMap` iterates in an order that
    // is randomized per instance, which made the floating-point sums below differ in their
    // last bits from run to run. A `BTreeMap` makes them reproducible.
    let class_weights = || {
        let mut totals: std::collections::BTreeMap<u64, F> = std::collections::BTreeMap::new();
        for &i in indices {
            *totals.entry(float_key(y[i])).or_insert(F::zero()) += weights[i];
        }
        totals
    };
    match criterion {
        SplitCriterion::Gini => {
            let sum_sq = class_weights().values().fold(F::zero(), |acc, &w| {
                let p = w / total_weight;
                acc + p * p
            });
            F::one() - sum_sq
        }
        SplitCriterion::Entropy => {
            let sum = class_weights()
                .values()
                .filter(|&&w| w > F::zero())
                .fold(F::zero(), |acc, &w| {
                    let p = w / total_weight;
                    acc + p * p.ln()
                });
            -sum
        }
        SplitCriterion::Mse => {
            let w_mean: F = indices
                .iter()
                .map(|&i| weights[i] * y[i])
                .fold(F::zero(), |a, b| a + b)
                / total_weight;
            indices
                .iter()
                .map(|&i| weights[i] * (y[i] - w_mean) * (y[i] - w_mean))
                .fold(F::zero(), |a, b| a + b)
                / total_weight
        }
    }
}
/// Prediction stored in a leaf holding `indices`, using per-sample `weights`.
///
/// * `Mse`: weighted mean of `y[indices]`, or `0` when the total weight is not positive.
/// * `Gini` / `Entropy`: the label with the largest total weight. Ties are broken
///   deterministically in favour of the largest label, as in `leaf_value`. Previously the
///   winner among tied classes depended on `HashMap` iteration order, which changes
///   between runs.
///
/// # Panics
/// For `Gini`/`Entropy`, panics if `indices` is empty or if a class weight total is NaN.
pub fn weighted_leaf_value<F: Float>(
    y: &Array1<F>,
    indices: &[usize],
    weights: &Array1<F>,
    criterion: SplitCriterion,
) -> F {
    match criterion {
        SplitCriterion::Mse => {
            let total_weight: F = indices
                .iter()
                .map(|&i| weights[i])
                .fold(F::zero(), |a, b| a + b);
            if total_weight <= F::zero() {
                return F::zero();
            }
            indices
                .iter()
                .map(|&i| weights[i] * y[i])
                .fold(F::zero(), |a, b| a + b)
                / total_weight
        }
        SplitCriterion::Gini | SplitCriterion::Entropy => {
            let mut class_weights: HashMap<u64, (F, F)> = HashMap::new();
            for &i in indices {
                let key = float_key(y[i]);
                class_weights
                    .entry(key)
                    .and_modify(|e| e.1 += weights[i])
                    .or_insert((y[i], weights[i]));
            }
            let mut totals: Vec<(F, F)> = class_weights.into_values().collect();
            // Fix the candidate order by label (`total_cmp` is a total order, so NaN labels
            // cannot panic here). `max_by` then returns the last of equal maxima, i.e. the
            // largest tied label.
            totals.sort_by(|a, b| a.0.to_f64().unwrap().total_cmp(&b.0.to_f64().unwrap()));
            totals
                .into_iter()
                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
                .unwrap()
                .0
        }
    }
}
/// Total weight of each distinct label in `y[indices]`.
///
/// Labels are grouped by `float_key`, and each entry keeps the first value seen for its key.
/// Returns `(label, total_weight)` pairs sorted by label in ascending order.
///
/// # Panics
/// Panics if two distinct labels cannot be ordered (e.g. a NaN label alongside others).
pub fn weighted_count_classes<F: Float>(
    y: &Array1<F>,
    indices: &[usize],
    weights: &Array1<F>,
) -> Vec<(F, F)> {
    let mut totals: HashMap<u64, (F, F)> = HashMap::new();
    for &i in indices {
        let label = y[i];
        let w = weights[i];
        match totals.entry(float_key(label)) {
            std::collections::hash_map::Entry::Occupied(mut slot) => slot.get_mut().1 += w,
            std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert((label, w));
            }
        }
    }
    let mut counts = totals.into_values().collect::<Vec<(F, F)>>();
    counts.sort_by(|lhs, rhs| lhs.0.partial_cmp(&rhs.0).unwrap());
    counts
}
pub fn find_best_split_weighted<F: Float>(
x: &Array2<F>,
y: &Array1<F>,
indices: &[usize],
weights: &Array1<F>,
criterion: SplitCriterion,
min_samples_leaf: usize,
feature_indices: &[usize],
) -> Option<BestSplit<F>> {
let n = indices.len();
if n < 2 * min_samples_leaf {
return None;
}
let parent_impurity = compute_weighted_impurity(y, indices, weights, criterion);
let mut best: Option<CandidateSplit<F>> = None;
let mut best_improvement = F::neg_infinity();
let mut sorted_pairs: Vec<(F, usize)> = Vec::with_capacity(n);
let total_weight: F = indices
.iter()
.map(|&i| weights[i])
.fold(F::zero(), |a, b| a + b);
for &feature in feature_indices {
sort_feature_pairs(x, indices, feature, &mut sorted_pairs);
let mut left_weight = F::zero();
let mut right_weight = total_weight;
let mut left_class_weights: HashMap<u64, F> = HashMap::new();
let mut right_class_weights: HashMap<u64, F> = HashMap::new();
for &i in indices {
let key = float_key(y[i]);
*right_class_weights.entry(key).or_insert(F::zero()) += weights[i];
}
for pos in 0..n - 1 {
let (cur_val, cur_idx) = sorted_pairs[pos];
let w = weights[cur_idx];
let key = float_key(y[cur_idx]);
left_weight += w;
right_weight -= w;
*left_class_weights.entry(key).or_insert(F::zero()) += w;
*right_class_weights.entry(key).or_insert(F::zero()) -= w;
let next_val = sorted_pairs[pos + 1].0;
if (next_val - cur_val).abs() < F::from_f64(1e-15).unwrap() {
continue;
}
let n_left = pos + 1;
let n_right = n - n_left;
if n_left < min_samples_leaf || n_right < min_samples_leaf {
continue;
}
let left_imp = match criterion {
SplitCriterion::Gini => {
let sum_sq: F = left_class_weights
.values()
.filter(|&&w| w > F::zero())
.map(|&w| {
let p = w / left_weight;
p * p
})
.fold(F::zero(), |a, b| a + b);
F::one() - sum_sq
}
SplitCriterion::Entropy => {
let sum: F = left_class_weights
.values()
.filter(|&&w| w > F::zero())
.map(|&w| {
let p = w / left_weight;
p * p.ln()
})
.fold(F::zero(), |a, b| a + b);
-sum
}
SplitCriterion::Mse => {
let left_indices: Vec<usize> =
sorted_pairs[..=pos].iter().map(|&(_, i)| i).collect();
compute_weighted_impurity(y, &left_indices, weights, criterion)
}
};
let right_imp = match criterion {
SplitCriterion::Gini => {
let sum_sq: F = right_class_weights
.values()
.filter(|&&w| w > F::zero())
.map(|&w| {
let p = w / right_weight;
p * p
})
.fold(F::zero(), |a, b| a + b);
F::one() - sum_sq
}
SplitCriterion::Entropy => {
let sum: F = right_class_weights
.values()
.filter(|&&w| w > F::zero())
.map(|&w| {
let p = w / right_weight;
p * p.ln()
})
.fold(F::zero(), |a, b| a + b);
-sum
}
SplitCriterion::Mse => {
let right_indices: Vec<usize> =
sorted_pairs[pos + 1..].iter().map(|&(_, i)| i).collect();
compute_weighted_impurity(y, &right_indices, weights, criterion)
}
};
let weighted_imp =
(left_weight / total_weight) * left_imp + (right_weight / total_weight) * right_imp;
let improvement = parent_impurity - weighted_imp;
let threshold = (cur_val + next_val) / (F::one() + F::one());
try_update_best_split(
improvement,
&mut best_improvement,
&mut best,
feature,
threshold,
);
}
}
best.map(|candidate| {
let mut left_indices = Vec::with_capacity(n);
let mut right_indices = Vec::with_capacity(n);
for &i in indices {
if x[[i, candidate.feature]] <= candidate.threshold {
left_indices.push(i);
} else {
right_indices.push(i);
}
}
BestSplit {
feature_index: candidate.feature,
threshold: candidate.threshold,
left_indices,
right_indices,
improvement: candidate.improvement,
}
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Row ids `0..n`, i.e. every sample of an `n`-row dataset.
    fn all_rows(n: usize) -> Vec<usize> {
        (0..n).collect()
    }

    #[test]
    fn test_gini_pure() {
        let y = array![1.0, 1.0, 1.0];
        assert_abs_diff_eq!(gini(&y, &all_rows(3)), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_gini_balanced() {
        let y = array![0.0, 1.0];
        assert_abs_diff_eq!(gini(&y, &all_rows(2)), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_mse_pure() {
        let y = array![5.0, 5.0, 5.0];
        assert_abs_diff_eq!(mse_impurity(&y, &all_rows(3)), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_find_best_split() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];
        let split = find_best_split(&x, &y, &all_rows(4), SplitCriterion::Gini, 1).unwrap();
        assert!(split.threshold > 2.0 && split.threshold < 3.0);
    }

    #[test]
    fn test_find_best_split_regression() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 1.5, 10.0, 10.5];
        let split = find_best_split(&x, &y, &all_rows(4), SplitCriterion::Mse, 1).unwrap();
        assert!(split.threshold > 2.0 && split.threshold < 3.0);
        assert_eq!(split.left_indices.len(), 2);
        assert_eq!(split.right_indices.len(), 2);
    }

    #[test]
    fn test_count_classes_uses_exact_bits() {
        let y = array![0.0, 1.0, 0.0, 2.0, 1.0];
        let counts = count_classes(&y, &all_rows(5));
        assert_eq!(counts.len(), 3);
        // Sorted by label: 0.0 twice, 1.0 twice, 2.0 once.
        assert_eq!(counts[0].1, 2);
        assert_eq!(counts[1].1, 2);
        assert_eq!(counts[2].1, 1);
    }

    #[test]
    fn test_find_best_split_entropy() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];
        let split = find_best_split(&x, &y, &all_rows(4), SplitCriterion::Entropy, 1).unwrap();
        assert!(split.threshold > 2.0 && split.threshold < 3.0);
    }

    #[test]
    fn test_min_samples_leaf_respected() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];
        // Four samples cannot give both children three samples each.
        let split = find_best_split(&x, &y, &all_rows(4), SplitCriterion::Gini, 3);
        assert!(split.is_none());
    }
}