#[cfg(test)]
use alloc::vec::Vec;
/// A split point chosen by a split criterion over one feature histogram.
///
/// For `XGBoostGain`, the left child aggregates bins `0..=bin_idx`, and the
/// right child receives everything else (`total - left`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SplitCandidate {
    /// Index of the last histogram bin assigned to the left child.
    pub bin_idx: usize,
    /// Gain of this split as computed by the criterion (for `XGBoostGain`,
    /// the complexity penalty `gamma` has already been subtracted).
    pub gain: f64,
    /// Gradient sum of the left child.
    pub left_grad: f64,
    /// Hessian sum of the left child.
    pub left_hess: f64,
    /// Gradient sum of the right child.
    pub right_grad: f64,
    /// Hessian sum of the right child.
    pub right_hess: f64,
}
/// Strategy for locating the best split point in a feature histogram.
///
/// The `Send + Sync + 'static` supertraits let a criterion be shared between
/// threads and stored without borrowing from anything.
pub trait SplitCriterion: Send + Sync + 'static {
    /// Searches a histogram for the most profitable split.
    ///
    /// * `grad_sums`, `hess_sums` - per-bin gradient and hessian sums, one
    ///   entry per bin.
    /// * `total_grad`, `total_hess` - sums over the whole node being split.
    /// * `gamma`, `lambda` - regularization parameters; how they are applied
    ///   is up to the implementation.
    ///
    /// Returns `None` when the implementation finds no acceptable split.
    fn evaluate(
        &self,
        grad_sums: &[f64],
        hess_sums: &[f64],
        total_grad: f64,
        total_hess: f64,
        gamma: f64,
        lambda: f64,
    ) -> Option<SplitCandidate>;
}
/// XGBoost-style split criterion.
///
/// A split into left/right children with gradient sums `G_L`, `G_R` and
/// hessian sums `H_L`, `H_R` (parent `G`, `H`) is scored as
/// `0.5 * (G_L^2/(H_L+lambda) + G_R^2/(H_R+lambda) - G^2/(H+lambda)) - gamma`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct XGBoostGain {
    /// Minimum hessian sum required in each child; candidate splits that
    /// leave either child below this value are skipped.
    pub min_child_weight: f64,
}
impl Default for XGBoostGain {
fn default() -> Self {
Self {
min_child_weight: 1.0,
}
}
}
impl XGBoostGain {
pub fn new(min_child_weight: f64) -> Self {
Self { min_child_weight }
}
}
impl SplitCriterion for XGBoostGain {
    /// Scans every boundary between adjacent bins and returns the split with
    /// the highest XGBoost gain, provided that gain is strictly positive.
    ///
    /// For a split after bin `i`, the left child aggregates bins `0..=i` and
    /// the right child receives `total - left`. The gain is
    /// `0.5 * (G_L^2/(H_L+lambda) + G_R^2/(H_R+lambda) - G^2/(H+lambda)) - gamma`.
    ///
    /// A candidate is skipped when:
    /// * either child's hessian sum is below `min_child_weight`, or
    /// * either child's denominator `H + lambda` is not strictly positive.
    ///
    /// The second case would otherwise produce an infinite or NaN score, for
    /// example a zero-hessian child with `lambda == 0`. Ties keep the earliest
    /// bin.
    ///
    /// Returns `None` for fewer than two bins or when no candidate has a
    /// positive gain.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `grad_sums` and `hess_sums` differ in
    /// length. In all builds, panics if `hess_sums` has fewer than
    /// `grad_sums.len() - 1` entries.
    fn evaluate(
        &self,
        grad_sums: &[f64],
        hess_sums: &[f64],
        total_grad: f64,
        total_hess: f64,
        gamma: f64,
        lambda: f64,
    ) -> Option<SplitCandidate> {
        let n_bins = grad_sums.len();
        debug_assert_eq!(
            n_bins,
            hess_sums.len(),
            "grad_sums and hess_sums must have the same length"
        );
        if n_bins < 2 {
            return None;
        }
        // Splitting after the last bin would leave the right child empty, so
        // only the first `n_bins - 1` bins end a candidate left child. Slicing
        // both inputs once also removes per-iteration bounds checks.
        let n_candidates = n_bins - 1;
        let grads = &grad_sums[..n_candidates];
        let hesses = &hess_sums[..n_candidates];

        let parent_score = total_grad * total_grad / (total_hess + lambda);

        let mut best: Option<SplitCandidate> = None;
        let mut left_grad = 0.0;
        let mut left_hess = 0.0;
        for (bin_idx, (&g, &h)) in grads.iter().zip(hesses).enumerate() {
            left_grad += g;
            left_hess += h;
            let right_grad = total_grad - left_grad;
            let right_hess = total_hess - left_hess;
            if left_hess < self.min_child_weight || right_hess < self.min_child_weight {
                continue;
            }
            let left_denom = left_hess + lambda;
            let right_denom = right_hess + lambda;
            // A zero (or negative) denominator turns the score into +/-inf or
            // NaN; an infinite gain would otherwise win the search outright.
            let denominators_valid = left_denom > 0.0 && right_denom > 0.0;
            if !denominators_valid {
                continue;
            }
            let left_score = left_grad * left_grad / left_denom;
            let right_score = right_grad * right_grad / right_denom;
            let gain = 0.5 * (left_score + right_score - parent_score) - gamma;
            // Strict `>` keeps the earliest bin on ties and never accepts NaN.
            let improves = match best {
                Some(current) => gain > current.gain,
                None => gain > f64::NEG_INFINITY,
            };
            if improves {
                best = Some(SplitCandidate {
                    bin_idx,
                    gain,
                    left_grad,
                    left_hess,
                    right_grad,
                    right_hess,
                });
            }
        }
        best.filter(|candidate| candidate.gain > 0.0)
    }
}
/// Optimal leaf value `-G / (H + lambda)` for a leaf with gradient sum
/// `grad_sum` and hessian sum `hess_sum`.
///
/// Returns `0.0` when `hess_sum + lambda` is not strictly positive (e.g. an
/// empty leaf with `lambda == 0`). In that case the formula would otherwise
/// produce an infinite or NaN weight that poisons every prediction routed to
/// the leaf.
#[inline]
#[must_use]
pub fn leaf_weight(grad_sum: f64, hess_sum: f64, lambda: f64) -> f64 {
    let denom = hess_sum + lambda;
    if denom > 0.0 {
        -grad_sum / denom
    } else {
        0.0
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the XGBoost split criterion and the leaf-weight helper.

    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPSILON: f64 = 1e-10;

    #[test]
    fn perfect_split() {
        let criterion = XGBoostGain::new(0.0);
        // Two negative-gradient bins followed by two positive-gradient bins.
        let grad_sums = [-5.0, -5.0, 5.0, 5.0];
        let hess_sums = [2.0, 2.0, 2.0, 2.0];
        let total_grad: f64 = grad_sums.iter().sum();
        let total_hess: f64 = hess_sums.iter().sum();
        let lambda = 1.0;
        let gamma = 0.0;
        let result = criterion
            .evaluate(&grad_sums, &hess_sums, total_grad, total_hess, gamma, lambda)
            .expect("should find a valid split");
        assert_eq!(result.bin_idx, 1);
        assert!((result.left_grad - (-10.0)).abs() < EPSILON);
        assert!((result.left_hess - 4.0).abs() < EPSILON);
        assert!((result.right_grad - 10.0).abs() < EPSILON);
        assert!((result.right_hess - 4.0).abs() < EPSILON);
        // The parent score is 0 because G = 0. Each child scores
        // 10^2 / (4 + 1) = 20, so gain = 0.5 * (20 + 20 - 0) - 0 = 20.
        assert!((result.gain - 20.0).abs() < EPSILON);
        assert!(result.gain > 0.0);
    }

    #[test]
    fn no_valid_split_single_bin() {
        let criterion = XGBoostGain::new(0.0);
        let grad_sums = [5.0];
        let hess_sums = [3.0];
        let result = criterion.evaluate(&grad_sums, &hess_sums, 5.0, 3.0, 0.0, 1.0);
        assert!(result.is_none());
    }

    #[test]
    fn no_valid_split_all_data_one_side() {
        // Every candidate leaves the right child with a zero hessian sum,
        // which is below min_child_weight = 1.0.
        let criterion = XGBoostGain::new(1.0);
        let grad_sums = [5.0, 0.0, 0.0];
        let hess_sums = [3.0, 0.0, 0.0];
        let result = criterion.evaluate(&grad_sums, &hess_sums, 5.0, 3.0, 0.0, 1.0);
        assert!(result.is_none());
    }

    #[test]
    fn min_child_weight_enforcement() {
        let grad_sums = [10.0, 10.0];
        let hess_sums = [0.5, 5.0];
        let total_grad = 20.0;
        let total_hess = 5.5;

        let strict = XGBoostGain::new(1.0);
        let result = strict.evaluate(&grad_sums, &hess_sums, total_grad, total_hess, 0.0, 1.0);
        assert!(
            result.is_none(),
            "split should be rejected: left hess 0.5 < min_child_weight 1.0"
        );

        let lenient = XGBoostGain::new(0.1);
        let result = lenient.evaluate(&grad_sums, &hess_sums, total_grad, total_hess, 0.0, 1.0);
        assert!(
            result.is_some(),
            "split should be accepted with lower min_child_weight"
        );
    }

    #[test]
    fn leaf_weight_computation() {
        // leaf_weight = -G / (H + lambda)
        assert!((leaf_weight(10.0, 5.0, 1.0) - (-10.0 / 6.0)).abs() < EPSILON);
        assert!((leaf_weight(0.0, 5.0, 1.0) - 0.0).abs() < EPSILON);
        assert!((leaf_weight(-3.0, 2.0, 0.5) - (3.0 / 2.5)).abs() < EPSILON);
        assert!((leaf_weight(4.0, 2.0, 0.0) - (-2.0)).abs() < EPSILON);
    }

    #[test]
    fn gain_symmetry_under_gradient_sign_flip() {
        // Every score term squares a gradient sum, so negating all gradients
        // must not change the gain or the chosen bin.
        let criterion = XGBoostGain::new(0.0);
        let lambda = 1.0;
        let gamma = 0.0;
        let grad_sums = [-3.0, -2.0, 2.0, 3.0];
        let hess_sums = [1.0, 1.0, 1.0, 1.0];
        let total_grad: f64 = grad_sums.iter().sum();
        let total_hess: f64 = hess_sums.iter().sum();
        let result_pos = criterion
            .evaluate(&grad_sums, &hess_sums, total_grad, total_hess, gamma, lambda)
            .expect("should find split");

        let grad_sums_neg: Vec<f64> = grad_sums.iter().map(|g| -g).collect();
        let total_grad_neg: f64 = grad_sums_neg.iter().sum();
        let result_neg = criterion
            .evaluate(&grad_sums_neg, &hess_sums, total_grad_neg, total_hess, gamma, lambda)
            .expect("should find split with negated gradients");

        assert!(
            (result_pos.gain - result_neg.gain).abs() < EPSILON,
            "gain should be invariant under gradient sign flip: {} vs {}",
            result_pos.gain,
            result_neg.gain
        );
        assert_eq!(result_pos.bin_idx, result_neg.bin_idx);
    }

    #[test]
    fn gamma_threshold_rejects_weak_split() {
        let criterion = XGBoostGain::new(0.0);
        let lambda = 1.0;
        let grad_sums = [-1.0, 1.0];
        let hess_sums = [5.0, 5.0];
        let total_grad = 0.0;
        let total_hess = 10.0;

        let gain_no_gamma = criterion
            .evaluate(&grad_sums, &hess_sums, total_grad, total_hess, 0.0, lambda)
            .expect("should find split with gamma=0")
            .gain;

        // A penalty larger than the raw gain pushes the net gain below zero.
        let result = criterion.evaluate(
            &grad_sums,
            &hess_sums,
            total_grad,
            total_hess,
            gain_no_gamma + 1.0,
            lambda,
        );
        assert!(
            result.is_none(),
            "split should be rejected when gamma exceeds raw gain"
        );
    }

    #[test]
    fn lambda_reduces_gain() {
        let criterion = XGBoostGain::new(0.0);
        let gamma = 0.0;
        let grad_sums = [-5.0, 5.0];
        let hess_sums = [2.0, 2.0];
        let total_grad = 0.0;
        let total_hess = 4.0;
        let result_low = criterion
            .evaluate(&grad_sums, &hess_sums, total_grad, total_hess, gamma, 0.1)
            .expect("should find split with low lambda");
        let result_high = criterion
            .evaluate(&grad_sums, &hess_sums, total_grad, total_hess, gamma, 100.0)
            .expect("should find split with high lambda");
        assert!(
            result_low.gain > result_high.gain,
            "higher lambda should reduce gain: {} vs {}",
            result_low.gain,
            result_high.gain
        );
    }

    #[test]
    fn empty_histogram() {
        let criterion = XGBoostGain::new(0.0);
        let result = criterion.evaluate(&[], &[], 0.0, 0.0, 0.0, 1.0);
        assert!(result.is_none());
    }

    #[test]
    fn selects_best_among_multiple_candidates() {
        // Splitting after bin 2 isolates the strongly negative prefix
        // (G_L = -10, H_L = 3) from the positive suffix (G_R = 10, H_R = 2),
        // which scores higher than any other boundary.
        let criterion = XGBoostGain::new(0.0);
        let lambda = 1.0;
        let gamma = 0.0;
        let grad_sums = [-1.0, -1.0, -8.0, 5.0, 5.0];
        let hess_sums = [1.0, 1.0, 1.0, 1.0, 1.0];
        let total_grad: f64 = grad_sums.iter().sum();
        let total_hess: f64 = hess_sums.iter().sum();
        let result = criterion
            .evaluate(&grad_sums, &hess_sums, total_grad, total_hess, gamma, lambda)
            .expect("should find a valid split");
        assert_eq!(result.bin_idx, 2, "best split should be at bin 2");
    }
}