Skip to main content

irithyll_core/tree/
split.rs

1//! Split criterion traits and XGBoost gain implementation.
2//!
3//! This module defines how candidate splits are evaluated from histogram
4//! gradient/hessian sums. The primary implementation is [`XGBoostGain`],
5//! which uses the exact gain formula from Chen & Guestrin (2016).
6
7#[cfg(test)]
8use alloc::vec::Vec;
9
/// The winning split found while scanning a feature histogram.
///
/// Bundles the chosen bin boundary, the gain achieved there, and the
/// gradient/hessian totals that land on each side of the boundary.
#[derive(Debug, Clone, Copy)]
pub struct SplitCandidate {
    /// Index of the last bin routed left: a sample whose bin is
    /// `<= bin_idx` goes to the left child.
    pub bin_idx: usize,
    /// Gain obtained by splitting at `bin_idx`.
    pub gain: f64,
    /// Sum of the gradients assigned to the left child.
    pub left_grad: f64,
    /// Sum of the hessians assigned to the left child.
    pub left_hess: f64,
    /// Sum of the gradients assigned to the right child.
    pub right_grad: f64,
    /// Sum of the hessians assigned to the right child.
    pub right_hess: f64,
}
26
27/// Evaluates split quality from histogram gradient/hessian sums.
28pub trait SplitCriterion: Send + Sync + 'static {
29    /// Evaluate all possible split points across histogram bins.
30    /// Returns the best split candidate, or `None` if no valid split exists.
31    ///
32    /// # Arguments
33    ///
34    /// * `grad_sums` - gradient sums per bin
35    /// * `hess_sums` - hessian sums per bin
36    /// * `total_grad` - total gradient across all bins
37    /// * `total_hess` - total hessian across all bins
38    /// * `gamma` - minimum split gain threshold
39    /// * `lambda` - L2 regularization on leaf weights
40    fn evaluate(
41        &self,
42        grad_sums: &[f64],
43        hess_sums: &[f64],
44        total_grad: f64,
45        total_hess: f64,
46        gamma: f64,
47        lambda: f64,
48    ) -> Option<SplitCandidate>;
49}
50
/// Split criterion implementing the XGBoost gain formula.
///
/// For a candidate split the gain is
///
/// `0.5 * (G_L^2 / (H_L + lambda) + G_R^2 / (H_R + lambda) - G^2 / (H + lambda)) - gamma`
///
/// with `G_L`, `H_L` the gradient/hessian sums of the left child,
/// `G_R`, `H_R` those of the right child, and `G`, `H` the total sums.
///
/// Reference: Chen & Guestrin, "XGBoost: A Scalable Tree Boosting System", KDD 2016.
#[derive(Debug, Clone, Copy)]
pub struct XGBoostGain {
    /// Smallest hessian sum a child may have for a split to be valid.
    pub min_child_weight: f64,
}
64
65impl Default for XGBoostGain {
66    fn default() -> Self {
67        Self {
68            min_child_weight: 1.0,
69        }
70    }
71}
72
73impl XGBoostGain {
74    /// Create a new `XGBoostGain` with the given minimum child weight.
75    pub fn new(min_child_weight: f64) -> Self {
76        Self { min_child_weight }
77    }
78}
79
80impl SplitCriterion for XGBoostGain {
81    fn evaluate(
82        &self,
83        grad_sums: &[f64],
84        hess_sums: &[f64],
85        total_grad: f64,
86        total_hess: f64,
87        gamma: f64,
88        lambda: f64,
89    ) -> Option<SplitCandidate> {
90        let n_bins = grad_sums.len();
91        debug_assert_eq!(
92            n_bins,
93            hess_sums.len(),
94            "grad_sums and hess_sums must have the same length"
95        );
96
97        // Need at least 2 bins to form a left and right child.
98        if n_bins < 2 {
99            return None;
100        }
101
102        // Precompute the parent term, which is constant across all candidate splits.
103        let parent_score = total_grad * total_grad / (total_hess + lambda);
104
105        let mut best_gain = f64::NEG_INFINITY;
106        let mut best_bin = 0usize;
107        let mut best_left_grad = 0.0;
108        let mut best_left_hess = 0.0;
109        let mut best_right_grad = 0.0;
110        let mut best_right_hess = 0.0;
111
112        // Running prefix sums for the left child.
113        let mut left_grad = 0.0;
114        let mut left_hess = 0.0;
115
116        // Iterate through bins 0..n_bins-1. The split point at bin_idx=i means
117        // bins [0..=i] go left and bins [i+1..n_bins-1] go right.
118        // We stop at n_bins-2 so the right child always has at least 1 bin.
119        for i in 0..n_bins - 1 {
120            left_grad += grad_sums[i];
121            left_hess += hess_sums[i];
122
123            let right_grad = total_grad - left_grad;
124            let right_hess = total_hess - left_hess;
125
126            // Enforce minimum child weight on both children.
127            if left_hess < self.min_child_weight || right_hess < self.min_child_weight {
128                continue;
129            }
130
131            let left_score = left_grad * left_grad / (left_hess + lambda);
132            let right_score = right_grad * right_grad / (right_hess + lambda);
133            let gain = 0.5 * (left_score + right_score - parent_score) - gamma;
134
135            if gain > best_gain {
136                best_gain = gain;
137                best_bin = i;
138                best_left_grad = left_grad;
139                best_left_hess = left_hess;
140                best_right_grad = right_grad;
141                best_right_hess = right_hess;
142            }
143        }
144
145        // Only return a split if the best gain is strictly positive.
146        if best_gain > 0.0 {
147            Some(SplitCandidate {
148                bin_idx: best_bin,
149                gain: best_gain,
150                left_grad: best_left_grad,
151                left_hess: best_left_hess,
152                right_grad: best_right_grad,
153                right_hess: best_right_hess,
154            })
155        } else {
156            None
157        }
158    }
159}
160
/// Closed-form optimal leaf value: `-G / (H + lambda)`.
///
/// This is the leaf value that minimizes the regularized second-order
/// Taylor expansion of the loss.
#[inline]
pub fn leaf_weight(grad_sum: f64, hess_sum: f64, lambda: f64) -> f64 {
    let regularized_hess = hess_sum + lambda;
    -grad_sum / regularized_hess
}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for floating-point comparisons.
    const EPSILON: f64 = 1e-10;

    /// True when `actual` lies within `EPSILON` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPSILON
    }

    /// Bins 0-1 carry negative gradients and bins 2-3 positive ones, so the
    /// boundary after bin 1 separates the two groups.
    #[test]
    fn perfect_split() {
        // First half sums to -10, second half to +10.
        let grads = [-5.0, -5.0, 5.0, 5.0];
        // Each half sums to 4.
        let hessians = [2.0, 2.0, 2.0, 2.0];
        let grad_total: f64 = grads.iter().sum(); // 0.0
        let hess_total: f64 = hessians.iter().sum(); // 8.0

        // gamma = 0.0, lambda = 1.0
        let split = XGBoostGain::new(0.0)
            .evaluate(&grads, &hessians, grad_total, hess_total, 0.0, 1.0)
            .expect("should find a valid split");

        // Bins 0 and 1 go left, bins 2 and 3 go right.
        assert_eq!(split.bin_idx, 1);

        // Child sums.
        assert!(close(split.left_grad, -10.0));
        assert!(close(split.left_hess, 4.0));
        assert!(close(split.right_grad, 10.0));
        assert!(close(split.right_hess, 4.0));

        // Hand-computed gain:
        //   parent = 0^2 / (8 + 1)     = 0
        //   left   = (-10)^2 / (4 + 1) = 20
        //   right  = 10^2 / (4 + 1)    = 20
        //   gain   = 0.5 * (20 + 20 - 0) - 0 = 20
        assert!(close(split.gain, 20.0));
        assert!(split.gain > 0.0);
    }

    /// A histogram with a single bin has no boundary to split on.
    #[test]
    fn no_valid_split_single_bin() {
        let outcome = XGBoostGain::new(0.0).evaluate(&[5.0], &[3.0], 5.0, 3.0, 0.0, 1.0);
        assert!(outcome.is_none());
    }

    /// All gradient and hessian mass sits in bin 0. Every candidate leaves
    /// the right child with zero hessian, which is below the
    /// `min_child_weight` of 1.0 used here, so no split is returned.
    #[test]
    fn no_valid_split_all_data_one_side() {
        let grads = [5.0, 0.0, 0.0];
        let hessians = [3.0, 0.0, 0.0];

        let outcome = XGBoostGain::new(1.0).evaluate(&grads, &hessians, 5.0, 3.0, 0.0, 1.0);
        assert!(outcome.is_none());
    }

    /// The same histogram is rejected or accepted depending only on
    /// `min_child_weight`.
    #[test]
    fn min_child_weight_enforcement() {
        // The only candidate gives the left child hessian 0.5 and the right 5.0.
        let grads = [10.0, 10.0];
        let hessians = [0.5, 5.0];
        let grad_total = 20.0;
        let hess_total = 5.5;

        // Threshold 1.0 is above the left child's 0.5.
        let rejected =
            XGBoostGain::new(1.0).evaluate(&grads, &hessians, grad_total, hess_total, 0.0, 1.0);
        assert!(
            rejected.is_none(),
            "split should be rejected: left hess 0.5 < min_child_weight 1.0"
        );

        // Threshold 0.1 is below both children.
        let accepted =
            XGBoostGain::new(0.1).evaluate(&grads, &hessians, grad_total, hess_total, 0.0, 1.0);
        assert!(
            accepted.is_some(),
            "split should be accepted with lower min_child_weight"
        );
    }

    /// Checks `leaf_weight` against `-G / (H + lambda)` on a few inputs.
    #[test]
    fn leaf_weight_computation() {
        // Generic case.
        assert!(close(leaf_weight(10.0, 5.0, 1.0), -10.0 / 6.0));

        // Zero gradient gives zero weight.
        assert!(close(leaf_weight(0.0, 5.0, 1.0), 0.0));

        // Negative gradient gives positive weight.
        assert!(close(leaf_weight(-3.0, 2.0, 0.5), 3.0 / 2.5));

        // No regularization.
        assert!(close(leaf_weight(4.0, 2.0, 0.0), -2.0));
    }

    /// Gain is built from squared gradient sums, so negating every gradient
    /// must leave both the gain and the chosen bin unchanged.
    #[test]
    fn gain_symmetry_under_gradient_sign_flip() {
        let criterion = XGBoostGain::new(0.0);
        let gamma = 0.0;
        let lambda = 1.0;

        let grads = [-3.0, -2.0, 2.0, 3.0];
        let hessians = [1.0, 1.0, 1.0, 1.0];
        let grad_total: f64 = grads.iter().sum(); // 0.0
        let hess_total: f64 = hessians.iter().sum(); // 4.0

        let baseline = criterion
            .evaluate(&grads, &hessians, grad_total, hess_total, gamma, lambda)
            .expect("should find split");

        // Same histogram with every gradient negated.
        let flipped: Vec<f64> = grads.iter().map(|&g| -g).collect();
        let flipped_total: f64 = flipped.iter().sum(); // still 0.0

        let mirrored = criterion
            .evaluate(&flipped, &hessians, flipped_total, hess_total, gamma, lambda)
            .expect("should find split with negated gradients");

        assert!(
            close(baseline.gain, mirrored.gain),
            "gain should be invariant under gradient sign flip: {} vs {}",
            baseline.gain,
            mirrored.gain
        );

        assert_eq!(baseline.bin_idx, mirrored.bin_idx);
    }

    /// A split found with `gamma = 0` disappears once `gamma` is larger
    /// than that split's gain.
    #[test]
    fn gamma_threshold_rejects_weak_split() {
        let criterion = XGBoostGain::new(0.0);
        let lambda = 1.0;

        let grads = [-1.0, 1.0];
        let hessians = [5.0, 5.0];
        let grad_total = 0.0;
        let hess_total = 10.0;

        let unpenalized = criterion.evaluate(&grads, &hessians, grad_total, hess_total, 0.0, lambda);
        assert!(unpenalized.is_some(), "should find split with gamma=0");
        let raw_gain = unpenalized.unwrap().gain;

        let penalized = criterion.evaluate(
            &grads,
            &hessians,
            grad_total,
            hess_total,
            raw_gain + 1.0,
            lambda,
        );
        assert!(
            penalized.is_none(),
            "split should be rejected when gamma exceeds raw gain"
        );
    }

    /// On the same histogram, a larger `lambda` must produce a smaller gain.
    #[test]
    fn lambda_reduces_gain() {
        let criterion = XGBoostGain::new(0.0);
        let gamma = 0.0;

        let grads = [-5.0, 5.0];
        let hessians = [2.0, 2.0];
        let grad_total = 0.0;
        let hess_total = 4.0;

        let weakly_regularized = criterion
            .evaluate(&grads, &hessians, grad_total, hess_total, gamma, 0.1)
            .expect("should find split with low lambda");

        let strongly_regularized = criterion
            .evaluate(&grads, &hessians, grad_total, hess_total, gamma, 100.0)
            .expect("should find split with high lambda");

        assert!(
            weakly_regularized.gain > strongly_regularized.gain,
            "higher lambda should reduce gain: {} vs {}",
            weakly_regularized.gain,
            strongly_regularized.gain
        );
    }

    /// Zero bins means nothing to split.
    #[test]
    fn empty_histogram() {
        let outcome = XGBoostGain::new(0.0).evaluate(&[], &[], 0.0, 0.0, 0.0, 1.0);
        assert!(outcome.is_none());
    }

    /// With several valid boundaries, the one with the largest gain wins.
    #[test]
    fn selects_best_among_multiple_candidates() {
        // Gradients [-1, -1, -8, 5, 5] sum to 0. Splitting after bin 2 gives
        // left_grad = -10 and right_grad = 10.
        let grads = [-1.0, -1.0, -8.0, 5.0, 5.0];
        let hessians = [1.0, 1.0, 1.0, 1.0, 1.0];
        let grad_total: f64 = grads.iter().sum();
        let hess_total: f64 = hessians.iter().sum();

        // gamma = 0.0, lambda = 1.0
        let split = XGBoostGain::new(0.0)
            .evaluate(&grads, &hessians, grad_total, hess_total, 0.0, 1.0)
            .expect("should find a valid split");

        assert_eq!(split.bin_idx, 2, "best split should be at bin 2");
    }
}