use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Node {
Split {
feature: usize,
threshold: f32,
left: Box<Node>,
right: Box<Node>,
},
Leaf {
prob_a: f32,
prob_b: f32,
n_samples: usize,
},
}
impl Node {
pub fn predict_proba(&self, sample: &[f32]) -> f32 {
match self {
Node::Leaf { prob_a, .. } => *prob_a,
Node::Split { feature, threshold, left, right } => {
if sample[*feature] <= *threshold {
left.predict_proba(sample)
} else {
right.predict_proba(sample)
}
}
}
}
}
pub fn gini(n_a: usize, n_b: usize) -> f32 {
let n = (n_a + n_b) as f32;
if n < f32::EPSILON {
return 0.0;
}
let pa = n_a as f32 / n;
let pb = n_b as f32 / n;
1.0 - (pa * pa + pb * pb)
}
pub fn split_gini(
left_a: usize, left_b: usize,
right_a: usize, right_b: usize,
) -> f32 {
let n_left = (left_a + left_b) as f32;
let n_right = (right_a + right_b) as f32;
let n_total = n_left + n_right;
if n_total < f32::EPSILON {
return 0.0;
}
(n_left / n_total) * gini(left_a, left_b)
+ (n_right / n_total) * gini(right_a, right_b)
}
#[derive(Debug, Clone)]
pub struct TreeParams {
pub max_depth: Option<usize>,
pub min_samples_split: usize,
pub min_samples_leaf: usize,
pub max_features: Option<usize>,
pub seed: u64,
}
impl Default for TreeParams {
fn default() -> Self {
Self {
max_depth: Some(20),
min_samples_split: 2,
min_samples_leaf: 1,
max_features: None,
seed: 42,
}
}
}
pub fn build_tree(
features: &[Vec<f32>],
labels: &[u8],
params: &TreeParams,
) -> Node {
let indices: Vec<usize> = (0..features.len()).collect();
let n_features = features.first().map(|r| r.len()).unwrap_or(0);
let max_features = params.max_features.unwrap_or_else(|| {
let sq = (n_features as f64).sqrt().ceil() as usize;
sq.max(1)
});
build_node(features, labels, &indices, 0, params, max_features, params.seed)
}
fn build_node(
features: &[Vec<f32>],
labels: &[u8],
indices: &[usize],
depth: usize,
params: &TreeParams,
max_features: usize,
seed: u64,
) -> Node {
let n = indices.len();
let n_a = indices.iter().filter(|&&i| labels[i] == 0).count();
let n_b = n - n_a;
let at_max_depth = params.max_depth.is_some_and(|d| depth >= d);
let too_small = n < params.min_samples_split;
let pure = n_a == 0 || n_b == 0;
if pure || at_max_depth || too_small {
return make_leaf(n_a, n_b);
}
let n_features = features[0].len();
let feature_indices = random_feature_subset(n_features, max_features, seed ^ (depth as u64 * 2654435761));
let best = best_split(features, labels, indices, &feature_indices, params.min_samples_leaf);
match best {
None => make_leaf(n_a, n_b),
Some((feat, thresh)) => {
let (left_idx, right_idx): (Vec<usize>, Vec<usize>) = indices
.iter()
.partition(|&&i| features[i][feat] <= thresh);
if left_idx.is_empty() || right_idx.is_empty() {
return make_leaf(n_a, n_b);
}
let left = build_node(features, labels, &left_idx, depth + 1, params, max_features, seed.wrapping_add(1));
let right = build_node(features, labels, &right_idx, depth + 1, params, max_features, seed.wrapping_add(2));
Node::Split {
feature: feat,
threshold: thresh,
left: Box::new(left),
right: Box::new(right),
}
}
}
}
fn make_leaf(n_a: usize, n_b: usize) -> Node {
let n = (n_a + n_b) as f32;
let prob_a = if n < f32::EPSILON { 0.5 } else { n_a as f32 / n };
Node::Leaf { prob_a, prob_b: 1.0 - prob_a, n_samples: n_a + n_b }
}
fn best_split(
features: &[Vec<f32>],
labels: &[u8],
indices: &[usize],
feature_indices: &[usize],
min_samples_leaf: usize,
) -> Option<(usize, f32)> {
let parent_n_a = indices.iter().filter(|&&i| labels[i] == 0).count();
let parent_n_b = indices.len() - parent_n_a;
let parent_gini = gini(parent_n_a, parent_n_b);
let mut best_gain = 0.0_f32;
let mut best: Option<(usize, f32)> = None;
for &feat in feature_indices {
let mut vals: Vec<(f32, u8)> = indices.iter()
.map(|&i| (features[i][feat], labels[i]))
.collect();
vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
let n = vals.len();
let mut left_a = 0usize;
let mut left_b = 0usize;
let mut right_a = parent_n_a;
let mut right_b = parent_n_b;
for split_pos in 0..n - 1 {
if vals[split_pos].1 == 0 { left_a += 1; right_a -= 1; }
else { left_b += 1; right_b -= 1; }
if (vals[split_pos].0 - vals[split_pos + 1].0).abs() < f32::EPSILON {
continue;
}
let left_n = left_a + left_b;
let right_n = right_a + right_b;
if left_n < min_samples_leaf || right_n < min_samples_leaf {
continue;
}
let sg = split_gini(left_a, left_b, right_a, right_b);
let gain = parent_gini - sg;
if gain > best_gain {
best_gain = gain;
let threshold = (vals[split_pos].0 + vals[split_pos + 1].0) / 2.0;
best = Some((feat, threshold));
}
}
}
best
}
fn random_feature_subset(n_features: usize, k: usize, seed: u64) -> Vec<usize> {
if k >= n_features {
return (0..n_features).collect();
}
let mut rng = seed.wrapping_add(1);
let mut pool: Vec<usize> = (0..n_features).collect();
for i in 0..k {
rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
let j = i + (rng as usize % (n_features - i));
pool.swap(i, j);
}
pool[..k].to_vec()
}
#[cfg(test)]
mod tests {
use super::*;
fn features(rows: &[[f32; 2]]) -> Vec<Vec<f32>> {
rows.iter().map(|r| r.to_vec()).collect()
}
#[test]
fn gini_pure_class_a_is_zero() {
assert!((gini(10, 0)).abs() < 1e-6);
}
#[test]
fn gini_uniform_is_half() {
assert!((gini(5, 5) - 0.5).abs() < 1e-6);
}
#[test]
fn gini_zero_samples_is_zero() {
assert_eq!(gini(0, 0), 0.0);
}
#[test]
fn build_tree_pure_class_a_is_leaf() {
let f = features(&[[0.1, 0.2], [0.3, 0.4]]);
let l = vec![0u8, 0u8];
let node = build_tree(&f, &l, &TreeParams::default());
assert!(matches!(node, Node::Leaf { prob_a, .. } if (prob_a - 1.0).abs() < 1e-6));
}
#[test]
fn build_tree_pure_class_b_is_leaf() {
let f = features(&[[0.1, 0.2], [0.3, 0.4]]);
let l = vec![1u8, 1u8];
let node = build_tree(&f, &l, &TreeParams::default());
assert!(matches!(node, Node::Leaf { prob_a, .. } if prob_a.abs() < 1e-6));
}
#[test]
fn build_tree_linearly_separable() {
let f = features(&[
[0.1, 0.0], [0.2, 0.0], [0.3, 0.0],
[0.7, 0.0], [0.8, 0.0], [0.9, 0.0],
]);
let l = vec![0, 0, 0, 1, 1, 1];
let tree = build_tree(&f, &l, &TreeParams::default());
let p_a = tree.predict_proba(&[0.1, 0.0]);
assert!(p_a > 0.5, "expected model_a prob > 0.5 for left side, got {}", p_a);
let p_a2 = tree.predict_proba(&[0.9, 0.0]);
assert!(p_a2 < 0.5, "expected model_a prob < 0.5 for right side, got {}", p_a2);
}
#[test]
fn predict_proba_returns_value_in_zero_one() {
let f = features(&[[0.1, 0.5], [0.9, 0.2], [0.4, 0.8], [0.6, 0.3]]);
let l = vec![0, 1, 0, 1];
let tree = build_tree(&f, &l, &TreeParams::default());
for row in &f {
let p = tree.predict_proba(row);
assert!((0.0..=1.0).contains(&p));
}
}
#[test]
fn random_feature_subset_length() {
let subset = random_feature_subset(10, 3, 12345);
assert_eq!(subset.len(), 3);
}
#[test]
fn random_feature_subset_no_duplicates() {
let subset = random_feature_subset(10, 5, 99);
let unique: std::collections::HashSet<_> = subset.iter().collect();
assert_eq!(unique.len(), 5);
}
#[test]
fn random_feature_subset_k_ge_n_returns_all() {
let subset = random_feature_subset(4, 10, 1);
assert_eq!(subset.len(), 4);
}
}