1#[cfg(test)]
8use alloc::vec::Vec;
9
/// The best split found by a [`SplitCriterion`] for one gradient/hessian histogram.
///
/// The left child aggregates histogram bins `0..=bin_idx`; the right child holds
/// the remaining mass (node total minus the left sums).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SplitCandidate {
    /// Index of the last bin that belongs to the left child.
    pub bin_idx: usize,
    /// Gain of the split as computed by the criterion (already net of any
    /// per-split penalty the criterion applies).
    pub gain: f64,
    /// Sum of gradients in the left child.
    pub left_grad: f64,
    /// Sum of hessians in the left child.
    pub left_hess: f64,
    /// Sum of gradients in the right child.
    pub right_grad: f64,
    /// Sum of hessians in the right child.
    pub right_hess: f64,
}
26
27pub trait SplitCriterion: Send + Sync + 'static {
29 fn evaluate(
41 &self,
42 grad_sums: &[f64],
43 hess_sums: &[f64],
44 total_grad: f64,
45 total_hess: f64,
46 gamma: f64,
47 lambda: f64,
48 ) -> Option<SplitCandidate>;
49}
50
/// Split criterion using the XGBoost second-order gain formula.
///
/// Candidate splits in which either child has a hessian sum below
/// [`min_child_weight`](Self::min_child_weight) are discarded.
#[derive(Debug, Clone, Copy)]
pub struct XGBoostGain {
    /// Minimum hessian sum each child of a split must carry for the split to be
    /// considered.
    pub min_child_weight: f64,
}

impl XGBoostGain {
    /// Creates a criterion with the given minimum child hessian sum.
    pub fn new(min_child_weight: f64) -> Self {
        XGBoostGain { min_child_weight }
    }
}

impl Default for XGBoostGain {
    /// Builds a criterion with `min_child_weight = 1.0`.
    fn default() -> Self {
        Self::new(1.0)
    }
}
79
80impl SplitCriterion for XGBoostGain {
81 fn evaluate(
82 &self,
83 grad_sums: &[f64],
84 hess_sums: &[f64],
85 total_grad: f64,
86 total_hess: f64,
87 gamma: f64,
88 lambda: f64,
89 ) -> Option<SplitCandidate> {
90 let n_bins = grad_sums.len();
91 debug_assert_eq!(
92 n_bins,
93 hess_sums.len(),
94 "grad_sums and hess_sums must have the same length"
95 );
96
97 if n_bins < 2 {
99 return None;
100 }
101
102 let parent_score = total_grad * total_grad / (total_hess + lambda);
104
105 let mut best_gain = f64::NEG_INFINITY;
106 let mut best_bin = 0usize;
107 let mut best_left_grad = 0.0;
108 let mut best_left_hess = 0.0;
109 let mut best_right_grad = 0.0;
110 let mut best_right_hess = 0.0;
111
112 let mut left_grad = 0.0;
114 let mut left_hess = 0.0;
115
116 for i in 0..n_bins - 1 {
120 left_grad += grad_sums[i];
121 left_hess += hess_sums[i];
122
123 let right_grad = total_grad - left_grad;
124 let right_hess = total_hess - left_hess;
125
126 if left_hess < self.min_child_weight || right_hess < self.min_child_weight {
128 continue;
129 }
130
131 let left_score = left_grad * left_grad / (left_hess + lambda);
132 let right_score = right_grad * right_grad / (right_hess + lambda);
133 let gain = 0.5 * (left_score + right_score - parent_score) - gamma;
134
135 if gain > best_gain {
136 best_gain = gain;
137 best_bin = i;
138 best_left_grad = left_grad;
139 best_left_hess = left_hess;
140 best_right_grad = right_grad;
141 best_right_hess = right_hess;
142 }
143 }
144
145 if best_gain > 0.0 {
147 Some(SplitCandidate {
148 bin_idx: best_bin,
149 gain: best_gain,
150 left_grad: best_left_grad,
151 left_hess: best_left_hess,
152 right_grad: best_right_grad,
153 right_hess: best_right_hess,
154 })
155 } else {
156 None
157 }
158 }
159}
160
/// Leaf output `w = -G / (H + lambda)` for a node with gradient sum `G` and
/// hessian sum `H`.
///
/// No guard is applied. If `hess_sum + lambda == 0`, the result is infinite,
/// or NaN when `grad_sum` is also zero.
#[inline]
pub fn leaf_weight(grad_sum: f64, hess_sum: f64, lambda: f64) -> f64 {
    let regularized_hess = hess_sum + lambda;
    -grad_sum / regularized_hess
}
169
#[cfg(test)]
mod tests {
    use super::*;

    const EPSILON: f64 = 1e-10;

    /// Asserts that two floats agree to within [`EPSILON`].
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Gradients that flip sign halfway through the histogram should be split
    /// exactly at the sign change.
    #[test]
    fn perfect_split() {
        let criterion = XGBoostGain::new(0.0);

        let grad_sums = [-5.0, -5.0, 5.0, 5.0];
        let hess_sums = [2.0, 2.0, 2.0, 2.0];
        let total_grad: f64 = grad_sums.iter().sum();
        let total_hess: f64 = hess_sums.iter().sum();
        let lambda = 1.0;
        let gamma = 0.0;

        let result = criterion
            .evaluate(&grad_sums, &hess_sums, total_grad, total_hess, gamma, lambda)
            .expect("should find a valid split");

        assert_eq!(result.bin_idx, 1);

        assert_close(result.left_grad, -10.0);
        assert_close(result.left_hess, 4.0);
        assert_close(result.right_grad, 10.0);
        assert_close(result.right_hess, 4.0);

        // 0.5 * (100 / 5 + 100 / 5 - 0 / 9) - 0 = 20
        assert_close(result.gain, 20.0);
        assert!(result.gain > 0.0);
    }

    /// A single bin has no interior boundary, so no split is possible.
    #[test]
    fn no_valid_split_single_bin() {
        let criterion = XGBoostGain::new(0.0);

        let grad_sums = [5.0];
        let hess_sums = [3.0];

        let result = criterion.evaluate(&grad_sums, &hess_sums, 5.0, 3.0, 0.0, 1.0);
        assert!(result.is_none());
    }

    /// When all hessian mass sits in the first bin, every boundary leaves a
    /// right child below `min_child_weight`.
    #[test]
    fn no_valid_split_all_data_one_side() {
        let criterion = XGBoostGain::new(1.0);

        let grad_sums = [5.0, 0.0, 0.0];
        let hess_sums = [3.0, 0.0, 0.0];

        let result = criterion.evaluate(&grad_sums, &hess_sums, 5.0, 3.0, 0.0, 1.0);
        assert!(result.is_none());
    }

    /// The only candidate has a light left child; it must be rejected under a
    /// strict `min_child_weight` and accepted under a lenient one.
    #[test]
    fn min_child_weight_enforcement() {
        let grad_sums = [10.0, 10.0];
        let hess_sums = [0.5, 5.0];
        let total_grad = 20.0;
        let total_hess = 5.5;

        let strict = XGBoostGain::new(1.0);
        let result = strict.evaluate(&grad_sums, &hess_sums, total_grad, total_hess, 0.0, 1.0);
        assert!(
            result.is_none(),
            "split should be rejected: left hess 0.5 < min_child_weight 1.0"
        );

        let lenient = XGBoostGain::new(0.1);
        let result = lenient.evaluate(&grad_sums, &hess_sums, total_grad, total_hess, 0.0, 1.0);
        assert!(
            result.is_some(),
            "split should be accepted with lower min_child_weight"
        );
    }

    /// A child whose hessian sum equals `min_child_weight` exactly is allowed.
    #[test]
    fn min_child_weight_is_inclusive() {
        let criterion = XGBoostGain::new(2.0);

        let result = criterion
            .evaluate(&[-5.0, 5.0], &[2.0, 2.0], 0.0, 4.0, 0.0, 1.0)
            .expect("children with hessian exactly min_child_weight should be allowed");

        assert_eq!(result.bin_idx, 0);
    }

    /// `leaf_weight` computes `-G / (H + lambda)`.
    #[test]
    fn leaf_weight_computation() {
        assert_close(leaf_weight(10.0, 5.0, 1.0), -10.0 / 6.0);
        assert_close(leaf_weight(0.0, 5.0, 1.0), 0.0);
        assert_close(leaf_weight(-3.0, 2.0, 0.5), 3.0 / 2.5);
        assert_close(leaf_weight(4.0, 2.0, 0.0), -2.0);
    }

    /// Gain depends on squared gradients, so negating every gradient must not
    /// change the gain or the chosen bin.
    #[test]
    fn gain_symmetry_under_gradient_sign_flip() {
        let criterion = XGBoostGain::new(0.0);
        let lambda = 1.0;
        let gamma = 0.0;

        let grad_sums = [-3.0, -2.0, 2.0, 3.0];
        let hess_sums = [1.0, 1.0, 1.0, 1.0];
        let total_grad: f64 = grad_sums.iter().sum();
        let total_hess: f64 = hess_sums.iter().sum();
        let result_pos = criterion
            .evaluate(&grad_sums, &hess_sums, total_grad, total_hess, gamma, lambda)
            .expect("should find split");

        let grad_sums_neg: Vec<f64> = grad_sums.iter().map(|g| -g).collect();
        let total_grad_neg: f64 = grad_sums_neg.iter().sum();
        let result_neg = criterion
            .evaluate(
                &grad_sums_neg,
                &hess_sums,
                total_grad_neg,
                total_hess,
                gamma,
                lambda,
            )
            .expect("should find split with negated gradients");

        assert!(
            (result_pos.gain - result_neg.gain).abs() < EPSILON,
            "gain should be invariant under gradient sign flip: {} vs {}",
            result_pos.gain,
            result_neg.gain
        );

        assert_eq!(result_pos.bin_idx, result_neg.bin_idx);
    }

    /// A `gamma` larger than the raw gain must suppress the split.
    #[test]
    fn gamma_threshold_rejects_weak_split() {
        let criterion = XGBoostGain::new(0.0);
        let lambda = 1.0;

        let grad_sums = [-1.0, 1.0];
        let hess_sums = [5.0, 5.0];
        let total_grad = 0.0;
        let total_hess = 10.0;

        let gain_no_gamma = criterion
            .evaluate(&grad_sums, &hess_sums, total_grad, total_hess, 0.0, lambda)
            .expect("should find split with gamma=0")
            .gain;

        let result = criterion.evaluate(
            &grad_sums,
            &hess_sums,
            total_grad,
            total_hess,
            gain_no_gamma + 1.0,
            lambda,
        );
        assert!(
            result.is_none(),
            "split should be rejected when gamma exceeds raw gain"
        );
    }

    /// Larger `lambda` inflates every denominator and therefore shrinks the gain.
    #[test]
    fn lambda_reduces_gain() {
        let criterion = XGBoostGain::new(0.0);
        let gamma = 0.0;

        let grad_sums = [-5.0, 5.0];
        let hess_sums = [2.0, 2.0];
        let total_grad = 0.0;
        let total_hess = 4.0;

        let result_low = criterion
            .evaluate(&grad_sums, &hess_sums, total_grad, total_hess, gamma, 0.1)
            .expect("should find split with low lambda");

        let result_high = criterion
            .evaluate(&grad_sums, &hess_sums, total_grad, total_hess, gamma, 100.0)
            .expect("should find split with high lambda");

        assert!(
            result_low.gain > result_high.gain,
            "higher lambda should reduce gain: {} vs {}",
            result_low.gain,
            result_high.gain
        );
    }

    /// An empty histogram yields no split.
    #[test]
    fn empty_histogram() {
        let criterion = XGBoostGain::new(0.0);
        let result = criterion.evaluate(&[], &[], 0.0, 0.0, 0.0, 1.0);
        assert!(result.is_none());
    }

    /// With several positive-gain boundaries, the largest gain must win.
    #[test]
    fn selects_best_among_multiple_candidates() {
        let criterion = XGBoostGain::new(0.0);
        let lambda = 1.0;
        let gamma = 0.0;

        let grad_sums = [-1.0, -1.0, -8.0, 5.0, 5.0];
        let hess_sums = [1.0, 1.0, 1.0, 1.0, 1.0];
        let total_grad: f64 = grad_sums.iter().sum();
        let total_hess: f64 = hess_sums.iter().sum();

        let result = criterion
            .evaluate(&grad_sums, &hess_sums, total_grad, total_hess, gamma, lambda)
            .expect("should find a valid split");

        assert_eq!(result.bin_idx, 2, "best split should be at bin 2");
    }

    /// An empty middle bin makes two adjacent boundaries produce identical
    /// children and hence identical gains; the lower bin must be reported.
    #[test]
    fn ties_resolve_to_lowest_bin() {
        let criterion = XGBoostGain::new(0.0);

        let result = criterion
            .evaluate(&[-5.0, 0.0, 5.0], &[2.0, 0.0, 2.0], 0.0, 4.0, 0.0, 1.0)
            .expect("should find a valid split");

        assert_eq!(result.bin_idx, 0);
    }
}