use alloc::vec::Vec;
use crate::feature::FeatureType;
use crate::histogram::BinEdges;
use crate::math;
use crate::rng::xorshift64;
use crate::tree::split::{leaf_weight, SplitCandidate, SplitCriterion, XGBoostGain};
use super::leaf::{make_binners, LeafState};
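/// Tie-breaking threshold for the Hoeffding split test: once the confidence bound
/// shrinks below this value, the top two candidates are treated as equivalent and a
/// split is allowed even though their gains have not separated.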
pub(crate) const TAU: f64 = 0.05;
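/// Internal split machinery: candidate evaluation, the Hoeffding split decision,
/// and initialization of child leaf state after a split.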
pub(crate) mod private {
use super::*;
use crate::histogram::bins::LeafHistograms;
use crate::tree::builder::TreeConfig;
use alloc::boxed::Box;
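    /// Borrowed context shared by the split routines: the tree configuration, the
    /// number of features, the active feature mask, the gain criterion, and the
    /// builder's RNG state.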
pub struct SplitContext<'a> {
pub config: &'a TreeConfig,
pub n_features: Option<usize>,
pub n_feature_mask: &'a [usize],
pub split_criterion: &'a XGBoostGain,
#[allow(dead_code)]
pub rng_state: &'a mut u64,
}
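    /// Scores every feature in `ctx.n_feature_mask` against the leaf histograms and
    /// returns `(feature_index, candidate, categorical_bin_order)` triples sorted by
    /// descending gain. Numerical features are evaluated over their bins as-is;
    /// categorical features are evaluated over bins reordered by gradient/hessian
    /// ratio, and that ordering is returned alongside the candidate. Candidates that
    /// violate a monotone constraint are dropped.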
pub fn evaluate_split_candidates(
histograms: &LeafHistograms,
feature_types: Option<&[FeatureType]>,
ctx: &SplitContext,
) -> Vec<(usize, SplitCandidate, Option<Vec<usize>>)> {
let mut candidates: Vec<(usize, SplitCandidate, Option<Vec<usize>>)> = Vec::new();
for &feat_idx in ctx.n_feature_mask {
if feat_idx >= histograms.n_features() {
continue;
}
let hist = &histograms.histograms[feat_idx];
let total_grad = hist.total_gradient();
let total_hess = hist.total_hessian();
let is_categorical = feature_types
.as_ref()
.is_some_and(|ft| feat_idx < ft.len() && ft[feat_idx] == FeatureType::Categorical);
if is_categorical {
let n_bins = hist.grad_sums.len();
if n_bins < 2 {
continue;
}
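                // Sort the non-empty category bins by gradient/hessian ratio; scanning
                // split thresholds over this ordering lets the numerical gain criterion
                // search partitions of the categories.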
let mut bin_order: Vec<usize> = (0..n_bins)
.filter(|&i| math::abs(hist.hess_sums[i]) > 1e-15)
.collect();
if bin_order.len() < 2 {
continue;
}
bin_order.sort_by(|&a, &b| {
let ratio_a = hist.grad_sums[a] / hist.hess_sums[a];
let ratio_b = hist.grad_sums[b] / hist.hess_sums[b];
ratio_a
.partial_cmp(&ratio_b)
.unwrap_or(core::cmp::Ordering::Equal)
});
let sorted_grads: Vec<f64> = bin_order.iter().map(|&i| hist.grad_sums[i]).collect();
let sorted_hess: Vec<f64> = bin_order.iter().map(|&i| hist.hess_sums[i]).collect();
if let Some(candidate) = ctx.split_criterion.evaluate(
&sorted_grads,
&sorted_hess,
total_grad,
total_hess,
ctx.config.gamma,
ctx.config.lambda,
) {
candidates.push((feat_idx, candidate, Some(bin_order)));
}
            } else if let Some(candidate) = ctx.split_criterion.evaluate(
                &hist.grad_sums,
                &hist.hess_sums,
                total_grad,
                total_hess,
                ctx.config.gamma,
                ctx.config.lambda,
            ) {
                candidates.push((feat_idx, candidate, None));
            }
}
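        // Drop candidates whose child leaf weights would violate a monotone
        // constraint: +1 requires left <= right, -1 requires left >= right, 0 is
        // unconstrained, and features outside the constraint slice are left alone.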
if let Some(ref mc) = ctx.config.monotone_constraints {
candidates.retain(|(feat_idx, candidate, _)| {
                if *feat_idx >= mc.len() {
                    return true;
                }
                let constraint = mc[*feat_idx];
                if constraint == 0 {
                    return true;
                }
let left_val =
leaf_weight(candidate.left_grad, candidate.left_hess, ctx.config.lambda);
let right_val = leaf_weight(
candidate.right_grad,
candidate.right_hess,
ctx.config.lambda,
);
if constraint > 0 {
left_val <= right_val
} else {
left_val >= right_val
}
});
}
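        // Highest gain first.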
candidates.sort_by(|a, b| {
b.1.gain
.partial_cmp(&a.1.gain)
.unwrap_or(core::cmp::Ordering::Equal)
});
candidates
}
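    /// Decides whether the gap between the best and second-best split gain is large
    /// enough to commit to a split. Uses the Hoeffding bound
    /// `epsilon = sqrt(R^2 * ln(1/delta) / (2 * n_eff))`: a split is allowed when the
    /// gap exceeds `epsilon`, or when `epsilon` has fallen below [`TAU`] and the two
    /// candidates are effectively tied. When `adaptive_depth` is configured, an extra
    /// complexity penalty must also be cleared first.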
pub fn should_split_hoeffding(
best_gain: f64,
second_best_gain: f64,
sample_count: u64,
ctx: &SplitContext,
) -> bool {
        let n = sample_count as f64;
        // When leaf statistics are decayed with factor `alpha`, they carry at most the
        // weight of 1 / (1 - alpha) samples, so cap the count used in the bounds below.
        let effective_n = match ctx.config.leaf_decay_alpha {
            Some(alpha) => n.min(1.0 / (1.0 - alpha)),
            None => n,
        };
        // Optional complexity penalty: with adaptive depth enabled, refuse to split
        // until the best gain clears a threshold that shrinks with more samples and
        // more candidate features.
        if let Some(cir_factor) = ctx.config.adaptive_depth {
            if n > 1.0 {
                let n_feat = ctx.n_features.unwrap_or(1) as f64;
                let penalty = cir_factor / (effective_n * n_feat);
                if best_gain <= penalty {
                    return false;
                }
            }
        }
        // Hoeffding bound: epsilon = sqrt(R^2 * ln(1 / delta) / (2 * n_eff)).
        let r = ctx.config.hoeffding_r.unwrap_or(1.0);
        let ln_inv_delta = math::ln(1.0 / ctx.config.delta);
        let epsilon = math::sqrt(r * r * ln_inv_delta / (2.0 * effective_n));
        let gap = best_gain - second_best_gain;
        // Split when the observed gain gap exceeds the Hoeffding bound, or when the
        // bound has shrunk below TAU and the top two candidates are effectively tied.
        gap > epsilon || epsilon < TAU
}
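    /// Builds a pair of fresh [`LeafState`]s for the children of a newly split leaf.
    /// When the parent's histograms are provided, each child starts with empty
    /// histograms over the same bin edges and `bins_ready` set; otherwise the children
    /// start with `histograms: None` and `bins_ready` unset.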
#[allow(dead_code)]
pub fn initialize_child_states(
n_features: usize,
feature_types: Option<&[FeatureType]>,
leaf_model_factory: Option<Box<dyn Fn() -> Box<dyn crate::tree::leaf_model::LeafModel>>>,
parent_hists: Option<LeafHistograms>,
) -> (LeafState, LeafState) {
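        // Children reuse the parent's bin edges so their histograms line up with the
        // parent's binning; without parent histograms the children start unbinned.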
let edges_per_feature: Vec<BinEdges> = parent_hists
.as_ref()
.map(|h| h.histograms.iter().map(|hist| hist.edges.clone()).collect())
.unwrap_or_default();
let (left_hists, right_hists) = if parent_hists.is_some() {
(
LeafHistograms::new(&edges_per_feature),
LeafHistograms::new(&edges_per_feature),
)
} else {
(LeafHistograms::new(&[]), LeafHistograms::new(&[]))
};
let child_binners_l = make_binners(n_features, feature_types);
let child_binners_r = make_binners(n_features, feature_types);
let left_model = leaf_model_factory.as_ref().map(|f| f());
let right_model = leaf_model_factory.as_ref().map(|f| f());
let left_state = LeafState {
histograms: if parent_hists.is_some() {
Some(left_hists)
} else {
None
},
binners: child_binners_l,
bins_ready: parent_hists.is_some(),
grad_sum: 0.0,
hess_sum: 0.0,
last_reeval_count: 0,
clip_grad_mean: 0.0,
clip_grad_m2: 0.0,
clip_grad_count: 0,
output_mean: 0.0,
output_m2: 0.0,
output_count: 0,
leaf_model: left_model,
};
let right_state = LeafState {
histograms: if parent_hists.is_some() {
Some(right_hists)
} else {
None
},
binners: child_binners_r,
bins_ready: parent_hists.is_some(),
grad_sum: 0.0,
hess_sum: 0.0,
last_reeval_count: 0,
clip_grad_mean: 0.0,
clip_grad_m2: 0.0,
clip_grad_count: 0,
output_mean: 0.0,
output_m2: 0.0,
output_count: 0,
leaf_model: right_model,
};
(left_state, right_state)
}
}
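/// Samples the feature subset used for one round of column subsampling. The caller's
/// scratch vectors are reused: `feature_mask` receives the selected feature indices
/// and `feature_mask_bits` the matching bitset. Each feature is kept independently
/// with probability `subsample_rate`, then low-index features are backfilled so that
/// at least `ceil(n_features * subsample_rate)` (and at least one) feature is
/// selected. With `subsample_rate >= 1.0` every feature is selected.
///
/// A minimal usage sketch; the seed value is illustrative and not required by the API:
///
/// ```ignore
/// let mut rng_state = 0x9E37_79B9_7F4A_7C15u64;
/// let (mask, bits) = generate_feature_mask(Vec::new(), Vec::new(), &mut rng_state, 0.5, 10);
/// assert!(!mask.is_empty() && mask.len() <= 10);
/// for &f in &mask {
///     assert_ne!(bits[f / 64] & (1u64 << (f % 64)), 0);
/// }
/// ```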
pub(crate) fn generate_feature_mask(
mut feature_mask: Vec<usize>,
mut feature_mask_bits: Vec<u64>,
rng_state: &mut u64,
subsample_rate: f64,
n_features: usize,
) -> (Vec<usize>, Vec<u64>) {
feature_mask.clear();
    if subsample_rate >= 1.0 {
        // No subsampling: select every feature and keep the bitset in sync with the mask.
        feature_mask.extend(0..n_features);
        feature_mask_bits.clear();
        feature_mask_bits.resize(n_features.div_ceil(64), 0u64);
        for i in 0..n_features {
            feature_mask_bits[i / 64] |= 1u64 << (i % 64);
        }
        return (feature_mask, feature_mask_bits);
    }
let target_count = math::ceil((n_features as f64) * subsample_rate) as usize;
let target_count = target_count.max(1).min(n_features);
let n_words = n_features.div_ceil(64);
feature_mask_bits.clear();
feature_mask_bits.resize(n_words, 0u64);
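    // Independently keep each feature with probability `subsample_rate`.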
for i in 0..n_features {
let r = xorshift64(rng_state);
let p = (r as f64) / (u64::MAX as f64);
if p < subsample_rate {
feature_mask.push(i);
feature_mask_bits[i / 64] |= 1u64 << (i % 64);
}
}
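    // Bernoulli sampling may undershoot; backfill the lowest-index unselected
    // features until the minimum count is reached.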
if feature_mask.len() < target_count {
for i in 0..n_features {
if feature_mask.len() >= target_count {
break;
}
if feature_mask_bits[i / 64] & (1u64 << (i % 64)) == 0 {
feature_mask.push(i);
feature_mask_bits[i / 64] |= 1u64 << (i % 64);
}
}
}
(feature_mask, feature_mask_bits)
}