use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::weights::{compute_sample_weights, ClassWeight};
use super::{
compute_impurity, compute_impurity_weighted, majority_class, weighted_majority_class,
BestSplit, FlatTree, SplitCriterion, TreeNode,
};
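/// Sorts `indices` once per feature column, ascending by feature value.
/// Later node splits filter and partition these lists rather than re-sorting,
/// which keeps the per-node split search linear in the number of samples.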
pub(crate) fn presort_indices(data: &Dataset, indices: &[usize]) -> Vec<Vec<usize>> {
let n_features = data.n_features();
let mut sorted_by_feature = Vec::with_capacity(n_features);
for feat_idx in 0..n_features {
let col = &data.features[feat_idx];
let mut sorted = indices.to_vec();
sorted.sort_unstable_by(|&a, &b| {
col[a]
.partial_cmp(&col[b])
.unwrap_or(std::cmp::Ordering::Equal)
});
sorted_by_feature.push(sorted);
}
sorted_by_feature
}
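/// Restricts each presorted index list to the samples flagged in `membership`,
/// preserving sort order. Used when a caller supplies index lists presorted
/// over the full dataset and this tree trains on only a subset of rows.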
fn filter_sorted(global_sorted: &[Vec<usize>], membership: &[bool]) -> Vec<Vec<usize>> {
global_sorted
.iter()
.map(|gs| gs.iter().copied().filter(|&idx| membership[idx]).collect())
.collect()
}
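/// Partitions every presorted index list into a left half (`value <= threshold`
/// on the split column) and a right half in one stable pass, so both children
/// stay sorted without re-sorting. The left half is compacted in place over the
/// input buffers; `_left_count` is currently unused, while `right_count`
/// pre-sizes the right-hand allocations.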
fn partition_sorted(
mut sorted_by_feature: Vec<Vec<usize>>,
split_col: &[f64],
threshold: f64,
_left_count: usize,
right_count: usize,
) -> (Vec<Vec<usize>>, Vec<Vec<usize>>) {
let n_feat = sorted_by_feature.len();
let mut right_sorted = Vec::with_capacity(n_feat);
for feat_sorted in &mut sorted_by_feature {
let mut right = Vec::with_capacity(right_count);
let mut write = 0;
for read in 0..feat_sorted.len() {
let idx = feat_sorted[read];
if split_col[idx] <= threshold {
feat_sorted[write] = idx;
write += 1;
} else {
right.push(idx);
}
}
feat_sorted.truncate(write);
right_sorted.push(right);
}
(sorted_by_feature, right_sorted)
}
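/// Fills `feature_buf` with the candidate feature indices for one split.
/// When `max_features` is set, a partial Fisher-Yates shuffle draws a uniform
/// random subset of that size; otherwise every feature is a candidate.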
fn fill_feature_buf(
feature_buf: &mut Vec<usize>,
n_features: usize,
max_features: Option<usize>,
rng: &mut crate::rng::FastRng,
) {
feature_buf.clear();
feature_buf.extend(0..n_features);
if let Some(max_f) = max_features {
let m = max_f.min(n_features);
for i in 0..m {
let j = rng.usize(i..n_features);
feature_buf.swap(i, j);
}
feature_buf.truncate(m);
}
}
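/// A CART-style decision tree classifier using axis-aligned binary splits
/// found by exhaustive search over presorted feature columns.
///
/// Hyperparameters are set through the builder methods. A minimal sketch
/// (how `data` and `rows` are constructed depends on the `Dataset` API and
/// is elided here):
///
/// ```ignore
/// let mut tree = DecisionTreeClassifier::new()
///     .max_depth(5)
///     .min_samples_leaf(2);
/// tree.fit(&data)?;
/// let predictions = tree.predict(&rows)?;
/// ```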
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct DecisionTreeClassifier {
max_depth: Option<usize>,
min_samples_split: usize,
min_samples_leaf: usize,
max_features: Option<usize>,
criterion: SplitCriterion,
ccp_alpha: f64,
pub(crate) class_weight: ClassWeight,
pub(crate) sample_weights: Option<Vec<f64>>,
pub(crate) flat_tree: Option<FlatTree>,
n_classes: usize,
n_features: usize,
pub(crate) feature_importances_: Vec<f64>,
#[cfg_attr(feature = "serde", serde(default))]
_schema_version: u32,
}
impl DecisionTreeClassifier {
pub fn new() -> Self {
Self {
max_depth: None,
min_samples_split: 2,
min_samples_leaf: 1,
max_features: None,
criterion: SplitCriterion::Gini,
ccp_alpha: 0.0,
class_weight: ClassWeight::Uniform,
sample_weights: None,
flat_tree: None,
n_classes: 0,
n_features: 0,
feature_importances_: Vec::new(),
_schema_version: crate::version::SCHEMA_VERSION,
}
}
pub fn max_depth(mut self, d: usize) -> Self {
self.max_depth = Some(d);
self
}
pub fn min_samples_split(mut self, n: usize) -> Self {
self.min_samples_split = n;
self
}
pub fn min_samples_leaf(mut self, n: usize) -> Self {
self.min_samples_leaf = n;
self
}
pub fn max_features(mut self, n: usize) -> Self {
self.max_features = Some(n);
self
}
pub fn criterion(mut self, c: SplitCriterion) -> Self {
self.criterion = c;
self
}
pub fn class_weight(mut self, cw: ClassWeight) -> Self {
self.class_weight = cw;
self
}
pub fn ccp_alpha(mut self, alpha: f64) -> Self {
self.ccp_alpha = alpha;
self
}
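    /// Fits the tree on every sample in `data`, validating first that all
    /// feature values are finite.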
pub fn fit(&mut self, data: &Dataset) -> Result<()> {
data.validate_finite()?;
let indices: Vec<usize> = (0..data.n_samples()).collect();
self.fit_on_indices(data, &indices)
}
pub(crate) fn fit_on_indices(
&mut self,
data: &Dataset,
sample_indices: &[usize],
) -> Result<()> {
let sorted_by_feature = presort_indices(data, sample_indices);
self.fit_with_sorted(data, sample_indices, sorted_by_feature)
}
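    /// Fits on a subset of rows using index lists already sorted over the full
    /// dataset, avoiding a per-tree re-sort. Membership is a boolean mask, so
    /// duplicate entries in `sample_indices` collapse to a single occurrence.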
pub(crate) fn fit_on_indices_presorted(
&mut self,
data: &Dataset,
sample_indices: &[usize],
global_sorted: &[Vec<usize>],
) -> Result<()> {
let membership_len = global_sorted.first().map_or(0, Vec::len);
let mut membership = vec![false; membership_len];
for &i in sample_indices {
membership[i] = true;
}
let sorted_by_feature = filter_sorted(global_sorted, &membership);
self.fit_with_sorted(data, sample_indices, sorted_by_feature)
}
fn fit_with_sorted(
&mut self,
data: &Dataset,
sample_indices: &[usize],
sorted_by_feature: Vec<Vec<usize>>,
) -> Result<()> {
let n = sample_indices.len();
if n == 0 {
return Err(ScryLearnError::EmptyDataset);
}
self.n_features = data.n_features();
self.n_classes = data.n_classes();
self.feature_importances_ = vec![0.0; self.n_features];
let weights = match &self.class_weight {
ClassWeight::Uniform => None,
cw => Some(compute_sample_weights(&data.target, cw)),
};
self.sample_weights = weights;
let mut feature_buf = Vec::with_capacity(self.n_features);
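        // Fixed seed: feature subsampling is deterministic, so repeated fits
        // on the same data build the same tree.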
let mut split_rng = crate::rng::FastRng::new(0);
let tree = if self.sample_weights.is_some() {
self.build_tree_weighted(
data,
sorted_by_feature,
n,
0,
&mut feature_buf,
&mut split_rng,
)
} else {
self.build_tree(
data,
sorted_by_feature,
n,
0,
&mut feature_buf,
&mut split_rng,
)
};
let tree = if self.ccp_alpha > 0.0 {
tree.prune_ccp(self.ccp_alpha)
} else {
tree
};
let flat = FlatTree::from_tree_node(&tree, self.n_classes);
self.flat_tree = Some(flat);
let total: f64 = self.feature_importances_.iter().sum();
if total > 0.0 {
for imp in &mut self.feature_importances_ {
*imp /= total;
}
}
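        // Sample weights are only needed while training; drop them so they are
        // not carried (or serialized) with the fitted model.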
self.sample_weights = None;
Ok(())
}
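    /// Predicts a class label (encoded as `f64`) for each row. Errors if the
    /// model is unfitted or was serialized under an incompatible schema version.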
pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
crate::version::check_schema_version(self._schema_version)?;
let ft = self.flat_tree.as_ref().ok_or(ScryLearnError::NotFitted)?;
Ok(ft.predict(features))
}
    pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        // Check the serialized schema version here as well, matching `predict`.
        crate::version::check_schema_version(self._schema_version)?;
        let ft = self.flat_tree.as_ref().ok_or(ScryLearnError::NotFitted)?;
let n_classes = self.n_classes;
Ok(features
.iter()
.map(|row| ft.predict_proba_sample(row, n_classes))
.collect())
}
pub fn feature_importances(&self) -> Result<Vec<f64>> {
if self.flat_tree.is_none() {
return Err(ScryLearnError::NotFitted);
}
Ok(self.feature_importances_.clone())
}
pub fn flat_tree(&self) -> Option<&FlatTree> {
self.flat_tree.as_ref()
}
pub fn depth(&self) -> usize {
self.flat_tree.as_ref().map_or(0, FlatTree::depth)
}
pub fn n_leaves(&self) -> usize {
self.flat_tree.as_ref().map_or(0, FlatTree::n_leaves)
}
pub fn n_features(&self) -> usize {
self.n_features
}
pub fn n_classes(&self) -> usize {
self.n_classes
}
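    /// Returns the effective alphas and corresponding subtree impurities for
    /// minimal cost-complexity pruning, computed from a fully grown (unpruned)
    /// copy of this tree; analogous to scikit-learn's method of the same name.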
pub fn cost_complexity_pruning_path(&self, data: &Dataset) -> Result<(Vec<f64>, Vec<f64>)> {
let mut unpruned = self.clone();
unpruned.ccp_alpha = 0.0;
        unpruned.fit(data)?;
        // `fit` clears `sample_weights` once training finishes, so recompute
        // them here; otherwise a non-uniform `class_weight` would silently fall
        // back to the unweighted builder below, and the pruning path would not
        // match the fitted tree.
        unpruned.sample_weights = match &unpruned.class_weight {
            ClassWeight::Uniform => None,
            cw => Some(compute_sample_weights(&data.target, cw)),
        };
let indices: Vec<usize> = (0..data.n_samples()).collect();
let sorted_by_feature = presort_indices(data, &indices);
let n = indices.len();
let mut feature_buf = Vec::with_capacity(unpruned.n_features);
let mut split_rng = crate::rng::FastRng::new(0);
let tree = if unpruned.sample_weights.is_some() {
unpruned.build_tree_weighted(
data,
sorted_by_feature,
n,
0,
&mut feature_buf,
&mut split_rng,
)
} else {
unpruned.build_tree(
data,
sorted_by_feature,
n,
0,
&mut feature_buf,
&mut split_rng,
)
};
Ok(tree.cost_complexity_pruning_path())
}
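    /// Recursively grows the tree (unweighted path). `sorted_by_feature[0]`
    /// serves as the canonical list of active sample indices, since every
    /// per-feature list contains the same set of samples.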
fn build_tree(
&mut self,
data: &Dataset,
sorted_by_feature: Vec<Vec<usize>>,
n_root_samples: usize,
depth: usize,
feature_buf: &mut Vec<usize>,
split_rng: &mut crate::rng::FastRng,
) -> TreeNode {
let active = &sorted_by_feature[0];
let n_actual = active.len();
let mut class_counts = vec![0usize; self.n_classes];
for &idx in active {
let c = data.target[idx] as usize;
if c < self.n_classes {
class_counts[c] += 1;
}
}
let impurity = compute_impurity(&class_counts, n_actual, self.criterion);
let max_depth_reached = self.max_depth.is_some_and(|d| depth >= d);
let too_few_samples = n_actual < self.min_samples_split;
let is_pure = impurity < 1e-12;
if max_depth_reached || too_few_samples || is_pure {
return TreeNode::Leaf {
prediction: majority_class(&class_counts),
n_samples: n_actual,
class_counts,
impurity,
};
}
let best = self.find_best_split(
data,
&sorted_by_feature,
&class_counts,
n_actual,
feature_buf,
split_rng,
);
let node_prediction = majority_class(&class_counts);
match best {
None => TreeNode::Leaf {
prediction: node_prediction,
n_samples: n_actual,
class_counts,
impurity,
},
Some(split) => {
let col = &data.features[split.feature_idx];
let threshold = split.threshold;
let mut left_count = 0usize;
let mut right_count = 0usize;
for &idx in active {
if col[idx] <= threshold {
left_count += 1;
} else {
right_count += 1;
}
}
if left_count < self.min_samples_leaf || right_count < self.min_samples_leaf {
return TreeNode::Leaf {
prediction: node_prediction,
n_samples: n_actual,
class_counts,
impurity,
};
}
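                // `BestSplit::impurity_decrease` stores the weighted child
                // impurity, so the actual decrease is the parent impurity minus
                // it, scaled by this node's share of the root samples (mean
                // decrease in impurity).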
let weighted_impurity_decrease = (n_actual as f64 / n_root_samples as f64)
* (impurity - split.impurity_decrease);
self.feature_importances_[split.feature_idx] += weighted_impurity_decrease.max(0.0);
let (left_sorted, right_sorted) =
partition_sorted(sorted_by_feature, col, threshold, left_count, right_count);
let left = self.build_tree(
data,
left_sorted,
n_root_samples,
depth + 1,
feature_buf,
split_rng,
);
let right = self.build_tree(
data,
right_sorted,
n_root_samples,
depth + 1,
feature_buf,
split_rng,
);
TreeNode::Split {
feature_idx: split.feature_idx,
threshold,
left: Box::new(left),
right: Box::new(right),
n_samples: n_actual,
impurity,
class_counts,
prediction: node_prediction,
}
}
}
}
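    /// Scans each candidate feature's presorted samples once, maintaining
    /// running left-side class counts, and evaluates a threshold at the
    /// midpoint of each pair of distinct adjacent values. Returns the split
    /// minimizing the weighted child impurity (stored in
    /// `BestSplit::impurity_decrease`), or `None` if no valid split exists.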
fn find_best_split(
&self,
data: &Dataset,
sorted_by_feature: &[Vec<usize>],
parent_counts: &[usize],
n_parent: usize,
feature_buf: &mut Vec<usize>,
split_rng: &mut crate::rng::FastRng,
) -> Option<BestSplit> {
let n_features = data.n_features();
let mut best: Option<BestSplit> = None;
fill_feature_buf(feature_buf, n_features, self.max_features, split_rng);
for &feat_idx in feature_buf.iter() {
let col = &data.features[feat_idx];
let sorted = &sorted_by_feature[feat_idx];
let mut left_counts = vec![0usize; self.n_classes];
let mut left_n = 0;
let mut prev_val = f64::NEG_INFINITY;
for &idx in sorted {
let val = col[idx];
if left_n > 0 && (val - prev_val).abs() > 1e-12 {
let right_n = n_parent - left_n;
if left_n >= self.min_samples_leaf && right_n >= self.min_samples_leaf {
let right_counts: Vec<usize> = parent_counts
.iter()
.zip(left_counts.iter())
.map(|(&p, &l)| p - l)
.collect();
let left_imp = compute_impurity(&left_counts, left_n, self.criterion);
let right_imp = compute_impurity(&right_counts, right_n, self.criterion);
let weighted_imp = (left_n as f64 * left_imp + right_n as f64 * right_imp)
/ n_parent as f64;
let threshold = f64::midpoint(prev_val, val);
let is_better = best
.as_ref()
.is_none_or(|b| weighted_imp < b.impurity_decrease);
if is_better {
best = Some(BestSplit {
feature_idx: feat_idx,
threshold,
impurity_decrease: weighted_imp,
});
}
}
}
let class = data.target[idx] as usize;
if class < self.n_classes {
left_counts[class] += 1;
}
left_n += 1;
prev_val = val;
}
}
best
}
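    /// Weighted counterpart of `build_tree`: impurity and leaf predictions use
    /// per-sample weights derived from `class_weight`, while the structural
    /// stopping rules (`min_samples_split`, `min_samples_leaf`) still count
    /// raw samples.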
fn build_tree_weighted(
&mut self,
data: &Dataset,
sorted_by_feature: Vec<Vec<usize>>,
n_root_samples: usize,
depth: usize,
feature_buf: &mut Vec<usize>,
split_rng: &mut crate::rng::FastRng,
) -> TreeNode {
let weights = self.sample_weights.as_ref().expect("weights must be set");
let active = &sorted_by_feature[0];
let n_actual = active.len();
let mut w_counts = vec![0.0_f64; self.n_classes];
let mut w_total = 0.0_f64;
let mut class_counts = vec![0usize; self.n_classes];
for &idx in active {
let c = data.target[idx] as usize;
let w = weights[idx];
if c < self.n_classes {
w_counts[c] += w;
class_counts[c] += 1;
}
w_total += w;
}
let impurity = compute_impurity_weighted(&w_counts, w_total, self.criterion);
let max_depth_reached = self.max_depth.is_some_and(|d| depth >= d);
let too_few_samples = n_actual < self.min_samples_split;
let is_pure = impurity < 1e-12;
if max_depth_reached || too_few_samples || is_pure {
return TreeNode::Leaf {
prediction: weighted_majority_class(&w_counts),
n_samples: n_actual,
class_counts,
impurity,
};
}
let best = self.find_best_split_weighted(
data,
&sorted_by_feature,
&w_counts,
w_total,
n_actual,
feature_buf,
split_rng,
);
let node_prediction = weighted_majority_class(&w_counts);
match best {
None => TreeNode::Leaf {
prediction: node_prediction,
n_samples: n_actual,
class_counts,
impurity,
},
Some(split) => {
let col = &data.features[split.feature_idx];
let threshold = split.threshold;
let mut left_count = 0usize;
let mut right_count = 0usize;
for &idx in active {
if col[idx] <= threshold {
left_count += 1;
} else {
right_count += 1;
}
}
if left_count < self.min_samples_leaf || right_count < self.min_samples_leaf {
return TreeNode::Leaf {
prediction: node_prediction,
n_samples: n_actual,
class_counts,
impurity,
};
}
let weighted_impurity_decrease = (n_actual as f64 / n_root_samples as f64)
* (impurity - split.impurity_decrease);
self.feature_importances_[split.feature_idx] += weighted_impurity_decrease.max(0.0);
let (left_sorted, right_sorted) =
partition_sorted(sorted_by_feature, col, threshold, left_count, right_count);
let left = self.build_tree_weighted(
data,
left_sorted,
n_root_samples,
depth + 1,
feature_buf,
split_rng,
);
let right = self.build_tree_weighted(
data,
right_sorted,
n_root_samples,
depth + 1,
feature_buf,
split_rng,
);
TreeNode::Split {
feature_idx: split.feature_idx,
threshold,
left: Box::new(left),
right: Box::new(right),
n_samples: n_actual,
impurity,
class_counts,
prediction: node_prediction,
}
}
}
}
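    /// Weighted counterpart of `find_best_split`, accumulating weight totals
    /// instead of raw counts. The `.max(0.0)` guards the derived right-side
    /// counts against small negatives from floating-point cancellation.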
fn find_best_split_weighted(
&self,
data: &Dataset,
sorted_by_feature: &[Vec<usize>],
parent_w_counts: &[f64],
w_parent_total: f64,
n_parent: usize,
feature_buf: &mut Vec<usize>,
split_rng: &mut crate::rng::FastRng,
) -> Option<BestSplit> {
let weights = self.sample_weights.as_ref().expect("weights must be set");
let n_features = data.n_features();
let mut best: Option<BestSplit> = None;
fill_feature_buf(feature_buf, n_features, self.max_features, split_rng);
for &feat_idx in feature_buf.iter() {
let col = &data.features[feat_idx];
let sorted = &sorted_by_feature[feat_idx];
let mut left_w_counts = vec![0.0_f64; self.n_classes];
let mut left_w_total = 0.0_f64;
let mut left_n = 0usize;
let mut prev_val = f64::NEG_INFINITY;
for &idx in sorted {
let val = col[idx];
let w = weights[idx];
if left_n > 0 && (val - prev_val).abs() > 1e-12 {
let right_n = n_parent - left_n;
if left_n >= self.min_samples_leaf && right_n >= self.min_samples_leaf {
let right_w_total = w_parent_total - left_w_total;
let right_w_counts: Vec<f64> = parent_w_counts
.iter()
.zip(left_w_counts.iter())
.map(|(&p, &l)| (p - l).max(0.0))
.collect();
let left_imp =
compute_impurity_weighted(&left_w_counts, left_w_total, self.criterion);
let right_imp = compute_impurity_weighted(
&right_w_counts,
right_w_total,
self.criterion,
);
let weighted_imp =
(left_w_total * left_imp + right_w_total * right_imp) / w_parent_total;
let threshold = f64::midpoint(prev_val, val);
let is_better = best
.as_ref()
.is_none_or(|b| weighted_imp < b.impurity_decrease);
if is_better {
best = Some(BestSplit {
feature_idx: feat_idx,
threshold,
impurity_decrease: weighted_imp,
});
}
}
}
let class = data.target[idx] as usize;
if class < self.n_classes {
left_w_counts[class] += w;
}
left_w_total += w;
left_n += 1;
prev_val = val;
}
}
best
}
}
impl Default for DecisionTreeClassifier {
fn default() -> Self {
Self::new()
}
}
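/// A CART-style decision tree regressor that minimizes within-node variance
/// (equivalently, MSE against the node mean). It shares the presort and
/// partition machinery with the classifier but tracks running sums and
/// squared sums instead of class counts.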
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct DecisionTreeRegressor {
max_depth: Option<usize>,
min_samples_split: usize,
min_samples_leaf: usize,
max_features: Option<usize>,
ccp_alpha: f64,
pub(crate) flat_tree: Option<FlatTree>,
n_features: usize,
pub(crate) feature_importances_: Vec<f64>,
#[cfg_attr(feature = "serde", serde(default))]
_schema_version: u32,
}
impl DecisionTreeRegressor {
pub fn new() -> Self {
Self {
max_depth: None,
min_samples_split: 2,
min_samples_leaf: 1,
max_features: None,
ccp_alpha: 0.0,
flat_tree: None,
n_features: 0,
feature_importances_: Vec::new(),
_schema_version: crate::version::SCHEMA_VERSION,
}
}
pub fn max_depth(mut self, d: usize) -> Self {
self.max_depth = Some(d);
self
}
pub fn min_samples_split(mut self, n: usize) -> Self {
self.min_samples_split = n;
self
}
pub fn min_samples_leaf(mut self, n: usize) -> Self {
self.min_samples_leaf = n;
self
}
pub fn max_features(mut self, n: usize) -> Self {
self.max_features = Some(n);
self
}
pub fn ccp_alpha(mut self, alpha: f64) -> Self {
self.ccp_alpha = alpha;
self
}
pub fn fit(&mut self, data: &Dataset) -> Result<()> {
data.validate_finite()?;
let indices: Vec<usize> = (0..data.n_samples()).collect();
self.fit_on_indices(data, &indices)
}
pub(crate) fn fit_on_indices(
&mut self,
data: &Dataset,
sample_indices: &[usize],
) -> Result<()> {
let n = sample_indices.len();
if n == 0 {
return Err(ScryLearnError::EmptyDataset);
}
self.n_features = data.n_features();
self.feature_importances_ = vec![0.0; self.n_features];
let sorted_by_feature = presort_indices(data, sample_indices);
let mut feature_buf = Vec::with_capacity(self.n_features);
let mut split_rng = crate::rng::FastRng::new(0);
let tree = self.build_tree_reg(
data,
sorted_by_feature,
n,
0,
&mut feature_buf,
&mut split_rng,
);
let tree = if self.ccp_alpha > 0.0 {
tree.prune_ccp(self.ccp_alpha)
} else {
tree
};
let flat = FlatTree::from_tree_node(&tree, 0);
self.flat_tree = Some(flat);
let total: f64 = self.feature_importances_.iter().sum();
if total > 0.0 {
for imp in &mut self.feature_importances_ {
*imp /= total;
}
}
Ok(())
}
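    /// Presorted variant of `fit_on_indices`. As in the classifier, duplicate
    /// entries in `sample_indices` collapse under the boolean membership mask.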
pub(crate) fn fit_on_indices_presorted(
&mut self,
data: &Dataset,
sample_indices: &[usize],
global_sorted: &[Vec<usize>],
) -> Result<()> {
let n = sample_indices.len();
if n == 0 {
return Err(ScryLearnError::EmptyDataset);
}
self.n_features = data.n_features();
self.feature_importances_ = vec![0.0; self.n_features];
let membership_len = global_sorted.first().map_or(0, Vec::len);
let mut membership = vec![false; membership_len];
for &i in sample_indices {
membership[i] = true;
}
let sorted_by_feature = filter_sorted(global_sorted, &membership);
let mut feature_buf = Vec::with_capacity(self.n_features);
let mut split_rng = crate::rng::FastRng::new(0);
let tree = self.build_tree_reg(
data,
sorted_by_feature,
n,
0,
&mut feature_buf,
&mut split_rng,
);
let tree = if self.ccp_alpha > 0.0 {
tree.prune_ccp(self.ccp_alpha)
} else {
tree
};
let flat = FlatTree::from_tree_node(&tree, 0);
self.flat_tree = Some(flat);
let total: f64 = self.feature_importances_.iter().sum();
if total > 0.0 {
for imp in &mut self.feature_importances_ {
*imp /= total;
}
}
Ok(())
}
pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
crate::version::check_schema_version(self._schema_version)?;
let ft = self.flat_tree.as_ref().ok_or(ScryLearnError::NotFitted)?;
Ok(ft.predict(features))
}
pub fn feature_importances(&self) -> Result<Vec<f64>> {
if self.flat_tree.is_none() {
return Err(ScryLearnError::NotFitted);
}
Ok(self.feature_importances_.clone())
}
pub fn flat_tree(&self) -> Option<&FlatTree> {
self.flat_tree.as_ref()
}
pub fn n_features(&self) -> usize {
self.n_features
}
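    /// Recursively grows the regression tree. Node impurity is the variance of
    /// the targets, computed from running sums as E[y^2] - E[y]^2 and clamped
    /// at zero to absorb floating-point error.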
fn build_tree_reg(
&mut self,
data: &Dataset,
sorted_by_feature: Vec<Vec<usize>>,
n_root_samples: usize,
depth: usize,
feature_buf: &mut Vec<usize>,
split_rng: &mut crate::rng::FastRng,
) -> TreeNode {
let active = &sorted_by_feature[0];
let n_actual = active.len();
if n_actual == 0 {
return TreeNode::Leaf {
prediction: 0.0,
n_samples: 0,
class_counts: Vec::new(),
impurity: 0.0,
};
}
let mut sum = 0.0;
let mut sq_sum = 0.0;
for &idx in active {
let v = data.target[idx];
sum += v;
sq_sum += v * v;
}
let mean = sum / n_actual as f64;
let mse = (sq_sum / n_actual as f64 - mean * mean).max(0.0);
let max_depth_reached = self.max_depth.is_some_and(|d| depth >= d);
let too_few = n_actual < self.min_samples_split;
if max_depth_reached || too_few || mse < 1e-12 {
return TreeNode::Leaf {
prediction: mean,
n_samples: n_actual,
class_counts: Vec::new(),
impurity: mse,
};
}
let best = self.find_best_split_reg(
data,
&sorted_by_feature,
sum,
sq_sum,
n_actual,
feature_buf,
split_rng,
);
match best {
None => TreeNode::Leaf {
prediction: mean,
n_samples: n_actual,
class_counts: Vec::new(),
impurity: mse,
},
Some(split) => {
let col = &data.features[split.feature_idx];
let threshold = split.threshold;
let mut left_count = 0usize;
let mut right_count = 0usize;
for &idx in active {
if col[idx] <= threshold {
left_count += 1;
} else {
right_count += 1;
}
}
if left_count < self.min_samples_leaf || right_count < self.min_samples_leaf {
return TreeNode::Leaf {
prediction: mean,
n_samples: n_actual,
class_counts: Vec::new(),
impurity: mse,
};
}
let decrease =
(n_actual as f64 / n_root_samples as f64) * (mse - split.impurity_decrease);
self.feature_importances_[split.feature_idx] += decrease.max(0.0);
let (left_sorted, right_sorted) =
partition_sorted(sorted_by_feature, col, threshold, left_count, right_count);
let left = self.build_tree_reg(
data,
left_sorted,
n_root_samples,
depth + 1,
feature_buf,
split_rng,
);
let right = self.build_tree_reg(
data,
right_sorted,
n_root_samples,
depth + 1,
feature_buf,
split_rng,
);
TreeNode::Split {
feature_idx: split.feature_idx,
threshold,
left: Box::new(left),
right: Box::new(right),
n_samples: n_actual,
impurity: mse,
class_counts: Vec::new(),
prediction: mean,
}
}
}
}
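    /// Regression split search: one pass per feature over presorted samples,
    /// deriving left/right variance in O(1) per candidate from running sums
    /// and squared sums.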
fn find_best_split_reg(
&self,
data: &Dataset,
sorted_by_feature: &[Vec<usize>],
total_sum: f64,
total_sq: f64,
n_parent: usize,
feature_buf: &mut Vec<usize>,
split_rng: &mut crate::rng::FastRng,
) -> Option<BestSplit> {
let n_features = data.n_features();
let mut best: Option<BestSplit> = None;
fill_feature_buf(feature_buf, n_features, self.max_features, split_rng);
for &feat_idx in feature_buf.iter() {
let col = &data.features[feat_idx];
let sorted = &sorted_by_feature[feat_idx];
let mut left_sum = 0.0;
let mut left_sq_sum = 0.0;
let mut left_n = 0usize;
let mut prev_val = f64::NEG_INFINITY;
for &idx in sorted {
let feat_val = col[idx];
if left_n > 0 && (feat_val - prev_val).abs() > 1e-12 {
let right_n = n_parent - left_n;
if left_n >= self.min_samples_leaf && right_n >= self.min_samples_leaf {
let left_mse = (left_sq_sum / left_n as f64
- (left_sum / left_n as f64).powi(2))
.max(0.0);
let right_sum = total_sum - left_sum;
let right_sq = total_sq - left_sq_sum;
let right_mse = (right_sq / right_n as f64
- (right_sum / right_n as f64).powi(2))
.max(0.0);
let weighted = (left_n as f64 * left_mse + right_n as f64 * right_mse)
/ n_parent as f64;
let threshold = f64::midpoint(prev_val, feat_val);
let is_better =
best.as_ref().is_none_or(|b| weighted < b.impurity_decrease);
if is_better {
best = Some(BestSplit {
feature_idx: feat_idx,
threshold,
impurity_decrease: weighted,
});
}
}
}
let target_val = data.target[idx];
left_sum += target_val;
left_sq_sum += target_val * target_val;
left_n += 1;
prev_val = feat_val;
}
}
best
}
}
impl Default for DecisionTreeRegressor {
fn default() -> Self {
Self::new()
}
}