gbrt_rs/tree/criterion.rs
//! Split criteria for decision tree construction in gradient boosting.
//!
//! This module defines traits and implementations for evaluating the quality of potential
//! splits when building decision trees. Split criteria determine which feature and threshold
//! provide the best separation of data to minimize the loss function.
//!
//! # Key Concepts
//!
//! In gradient boosting, split criteria operate on **gradients** and **hessians** (first and
//! second derivatives of the loss function) rather than raw target values. This allows the
//! same tree-building algorithm to work with any differentiable loss function.
//!
//! # Available Criteria
//!
//! - [`MSECriterion`]: Basic mean squared error criterion with L2 regularization
//! - [`FriedmanMSECriterion`]: Enhanced MSE criterion with L1/L2 regularization and a
//!   minimum gain threshold (the style of regularized gain used in XGBoost and LightGBM)
//!
//! # Regularization
//!
//! Both criteria support regularization parameters to prevent overfitting:
//!
//! - `lambda`: L2 regularization on leaf weights
//! - `alpha`: L1 regularization (Friedman only)
//! - `gamma`: Minimum gain required to make a split (Friedman only)
//! - `min_hessian`: Minimum sum of hessians required for a valid split
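//!
//! # Example
//!
//! A minimal usage sketch (the import path assumes this file lives at
//! `gbrt_rs/tree/criterion.rs` and the module is publicly reachable; only
//! `compute_leaf_value` is shown because it does not need a feature matrix):
//!
//! ```ignore
//! use gbrt_rs::tree::criterion::{FriedmanMSECriterion, SplitCriterion};
//!
//! let criterion = FriedmanMSECriterion::new(1.0, 0.1).with_alpha(0.5);
//! // leaf = -soft_threshold(sum_grad, alpha) / (sum_hess + lambda)
//! let leaf = criterion.compute_leaf_value(&[-1.0, -2.0], &[1.0, 1.0]).unwrap();
//! assert!((leaf - 2.5 / 3.0).abs() < 1e-12);
//! ```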

use crate::data::FeatureMatrix;
use thiserror::Error;

/// Errors that can occur during split criterion computation.
///
/// These errors cover invalid inputs, insufficient samples, and numerical issues
/// that prevent proper calculation of split quality.
#[derive(Error, Debug)]
pub enum CriterionError {
    /// Input data is invalid for computation.
    ///
    /// This includes mismatched array lengths, non-finite values, or other
    /// data integrity issues that violate the criterion's preconditions.
    #[error("Invalid input data: {0}")]
    InvalidInput(String),

    /// Not enough samples to perform a meaningful split.
    ///
    /// This error is returned when a node or split candidate has fewer samples
    /// than the minimum required for stable statistical estimation.
    #[error("Insufficient samples: {samples} (min: {min})")]
    InsufficientSamples { samples: usize, min: usize },
}

/// Core trait for evaluating split quality in decision trees.
///
/// Split criteria determine how well a candidate split separates the data
/// in terms of reducing the loss function. They operate on gradients and hessians
/// supplied by the objective function's first and second derivatives.
///
/// # Thread Safety
///
/// Implementations must be thread-safe (`Send + Sync`) as they may be used
/// in parallel tree construction.
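///
/// # Example
///
/// A sketch of using a criterion behind a trait object, as a tree builder
/// might (see [`create_criterion`] at the bottom of this module):
///
/// ```ignore
/// let crit: Box<dyn SplitCriterion> = create_criterion("friedman_mse", 1.0, 0.0);
/// assert_eq!(crit.name(), "FriedmanMSE");
/// ```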
pub trait SplitCriterion: Send + Sync {
    /// Computes the improvement (gain) from making a split.
    ///
    /// The gain measures how much the loss is reduced by splitting a parent node
    /// into left and right children. Higher gain indicates a better split.
    ///
    /// # Arguments
    ///
    /// * `features` - Feature matrix of the training data
    /// * `gradients` - First derivatives of the loss w.r.t. the predictions (one per sample)
    /// * `hessians` - Second derivatives of the loss w.r.t. the predictions (one per sample)
    /// * `left_indices` - Indices of samples assigned to the left child node
    /// * `right_indices` - Indices of samples assigned to the right child node
    ///
    /// The gradient and hessian slices are assumed to cover exactly the samples in
    /// the node being split, with `left_indices` and `right_indices` partitioning
    /// them; parent statistics are computed from the full slices.
    ///
    /// # Returns
    ///
    /// The split gain (improvement score). With regularization the raw gain can be
    /// negative, indicating the split is not worthwhile; implementations may clamp
    /// it to zero.
    ///
    /// # Errors
    ///
    /// * [`CriterionError::InvalidInput`] if gradients and hessians have different lengths
    /// * [`CriterionError::InsufficientSamples`] if either child is empty
    fn compute_gain(
        &self,
        features: &FeatureMatrix,
        gradients: &[f64],
        hessians: &[f64],
        left_indices: &[usize],
        right_indices: &[usize],
    ) -> Result<f64, CriterionError>;

    /// Computes the optimal prediction value for a leaf node.
    ///
    /// The leaf value minimizes the loss for the samples reaching that leaf.
    /// For MSE-based criteria, this is typically `-sum(gradients) / (sum(hessians) + lambda)`.
    ///
    /// # Arguments
    ///
    /// * `gradients` - First derivatives of samples in the leaf
    /// * `hessians` - Second derivatives of samples in the leaf
    ///
    /// # Returns
    ///
    /// The optimal leaf prediction value
    ///
    /// # Errors
    ///
    /// Returns [`CriterionError::InsufficientSamples`] if the leaf is empty
    fn compute_leaf_value(
        &self,
        gradients: &[f64],
        hessians: &[f64],
    ) -> Result<f64, CriterionError>;

    /// Returns the name of the criterion.
    ///
    /// Used for logging, debugging, and serialization.
    ///
    /// # Returns
    ///
    /// A short string slice identifying the criterion (e.g., "MSE", "FriedmanMSE")
    fn name(&self) -> &str;
}

/// Standard mean squared error split criterion with L2 regularization.
///
/// This criterion computes split gain based on the reduction in squared error,
/// treating gradients as residuals and hessians as constant weights (typically 1.0).
///
/// # Formula
///
/// - Node gain: `(sum_grad)² / (sum_hess + lambda)`
/// - Leaf value: `-sum_grad / (sum_hess + lambda)`
///
/// # Parameters
///
/// - `lambda`: L2 regularization term added to the denominator to prevent division by zero
///   and control leaf value magnitude
/// - `min_hessian`: Minimum sum of hessians required for a valid split (avoids unstable splits)
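///
/// # Example
///
/// A minimal sketch of the leaf-value math (values chosen so the arithmetic is
/// easy to check by hand):
///
/// ```ignore
/// let c = MSECriterion::new(1.0);
/// // -sum_grad / (sum_hess + lambda) = -(-3.0) / (2.0 + 1.0) = 1.0
/// let leaf = c.compute_leaf_value(&[-1.0, -2.0], &[1.0, 1.0]).unwrap();
/// assert!((leaf - 1.0).abs() < 1e-12);
/// ```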
#[derive(Debug, Clone)]
pub struct MSECriterion {
    /// L2 regularization parameter (lambda) for leaf weights.
    ///
    /// Larger values shrink leaf predictions toward zero, preventing overfitting.
    pub lambda: f64,
    /// Minimum sum of hessians required to consider a split valid.
    ///
    /// This prevents splits on nodes with too few effective samples.
    pub min_hessian: f64,
}

impl Default for MSECriterion {
    fn default() -> Self {
        Self {
            lambda: 1.0,
            min_hessian: 1e-8,
        }
    }
}

impl MSECriterion {
    /// Creates a new MSE criterion with the specified L2 regularization.
    ///
    /// # Arguments
    ///
    /// * `lambda` - L2 regularization parameter (typically 1.0)
    ///
    /// # Returns
    ///
    /// A new `MSECriterion` instance with default `min_hessian = 1e-8`
    pub fn new(lambda: f64) -> Self {
        Self {
            lambda,
            min_hessian: 1e-8,
        }
    }

    /// Sets the minimum hessian threshold (builder pattern).
    ///
    /// # Arguments
    ///
    /// * `min_hessian` - Minimum sum of hessians for valid splits
    ///
    /// # Returns
    ///
    /// Self with the updated threshold
    pub fn with_min_hessian(mut self, min_hessian: f64) -> Self {
        self.min_hessian = min_hessian;
        self
    }

    /// Computes the gain score for a single node (without children).
    ///
    /// This helper calculates the base score used for both the parent
    /// and the child nodes in split evaluation.
    ///
    /// # Arguments
    ///
    /// * `sum_grad` - Sum of gradients in the node
    /// * `sum_hess` - Sum of hessians in the node
    ///
    /// # Returns
    ///
    /// Node gain score (non-negative)
    fn compute_node_gain(&self, sum_grad: f64, sum_hess: f64) -> f64 {
        if sum_hess < self.min_hessian {
            return 0.0;
        }
        (sum_grad * sum_grad) / (sum_hess + self.lambda)
    }
}

impl SplitCriterion for MSECriterion {
    fn compute_gain(
        &self,
        _features: &FeatureMatrix,
        gradients: &[f64],
        hessians: &[f64],
        left_indices: &[usize],
        right_indices: &[usize],
    ) -> Result<f64, CriterionError> {
        if gradients.len() != hessians.len() {
            return Err(CriterionError::InvalidInput(
                "Gradients and hessians must have the same length".to_string(),
            ));
        }

        if left_indices.is_empty() || right_indices.is_empty() {
            return Err(CriterionError::InsufficientSamples {
                samples: left_indices.len().min(right_indices.len()),
                min: 1,
            });
        }

        // Parent statistics: sums over all samples in the node.
        let parent_grad: f64 = gradients.iter().sum();
        let parent_hess: f64 = hessians.iter().sum();
        let parent_gain = self.compute_node_gain(parent_grad, parent_hess);

        // Left child statistics.
        let left_grad: f64 = left_indices.iter().map(|&i| gradients[i]).sum();
        let left_hess: f64 = left_indices.iter().map(|&i| hessians[i]).sum();
        let left_gain = self.compute_node_gain(left_grad, left_hess);

        // Right child statistics.
        let right_grad: f64 = right_indices.iter().map(|&i| gradients[i]).sum();
        let right_hess: f64 = right_indices.iter().map(|&i| hessians[i]).sum();
        let right_gain = self.compute_node_gain(right_grad, right_hess);

        // Gain is the children's combined score minus the parent's. With lambda > 0
        // this can be slightly negative; callers should treat a negative gain as
        // "not worth splitting".
        Ok(left_gain + right_gain - parent_gain)
    }

    fn compute_leaf_value(
        &self,
        gradients: &[f64],
        hessians: &[f64],
    ) -> Result<f64, CriterionError> {
        if gradients.is_empty() {
            return Err(CriterionError::InsufficientSamples {
                samples: 0,
                min: 1,
            });
        }

        let sum_grad: f64 = gradients.iter().sum();
        let sum_hess: f64 = hessians.iter().sum();

        if sum_hess.abs() < self.min_hessian {
            return Ok(0.0);
        }

        // For MSE loss, the optimal leaf value is -sum_grad / (sum_hess + lambda).
        Ok(-sum_grad / (sum_hess + self.lambda))
    }

    fn name(&self) -> &str {
        "MSE"
    }
}

/// Friedman-style MSE criterion with XGBoost-style regularization.
///
/// This enhanced criterion adds L1 regularization, a minimum gain threshold,
/// and regularized leaf values as used in modern gradient boosting frameworks.
///
/// # Formula
///
/// - Node gain: `(reg_grad)² / (sum_hess + lambda)`, where `reg_grad` is the
///   L1-regularized (soft-thresholded) gradient sum
/// - Leaf value: `-reg_grad / (sum_hess + lambda)`
/// - Split gain: `max(left_gain + right_gain - parent_gain - gamma, 0)`
///
/// # Regularization Parameters
///
/// - `lambda`: L2 regularization on leaf weights
/// - `alpha`: L1 regularization (shrinkage effect)
/// - `gamma`: Minimum gain required to make a split (pruning)
/// - `min_hessian`: Minimum hessian sum for numerical stability
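///
/// # Example
///
/// A minimal sketch of the L1 soft-thresholding: when `|sum_grad| <= alpha`,
/// the regularized gradient, and hence the leaf value, is exactly zero:
///
/// ```ignore
/// let c = FriedmanMSECriterion::new(1.0, 0.0).with_alpha(0.5);
/// // sum_grad = 0.3 is within alpha = 0.5, so it is shrunk to 0.
/// let leaf = c.compute_leaf_value(&[0.1, 0.2], &[1.0, 1.0]).unwrap();
/// assert_eq!(leaf, 0.0);
/// ```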
#[derive(Debug, Clone)]
pub struct FriedmanMSECriterion {
    /// L2 regularization parameter for leaf weights.
    ///
    /// Controls the magnitude of leaf predictions to prevent overfitting.
    pub lambda: f64,
    /// L1 regularization parameter for sparse models.
    ///
    /// Encourages leaf values to be exactly zero, creating sparser trees.
    pub alpha: f64,
    /// Minimum sum of hessians required for a valid split.
    ///
    /// Prevents splits on nodes with insufficient statistical support.
    pub min_hessian: f64,
    /// Minimum gain required to create a split (gamma parameter).
    ///
    /// Acts as pre-pruning: splits with gain below gamma are rejected.
    pub gamma: f64,
}

impl Default for FriedmanMSECriterion {
    fn default() -> Self {
        Self {
            lambda: 1.0,
            alpha: 0.0,
            min_hessian: 1e-8,
            gamma: 0.0,
        }
    }
}

impl FriedmanMSECriterion {
    /// Creates a new Friedman MSE criterion with L2 and minimum-gain regularization.
    ///
    /// # Arguments
    ///
    /// * `lambda` - L2 regularization parameter (typically 1.0-2.0)
    /// * `gamma` - Minimum gain threshold for splits (0.0 for no pruning)
    ///
    /// # Returns
    ///
    /// A new `FriedmanMSECriterion` with default `alpha = 0.0` and `min_hessian = 1e-8`
    pub fn new(lambda: f64, gamma: f64) -> Self {
        Self {
            lambda,
            alpha: 0.0,
            min_hessian: 1e-8,
            gamma,
        }
    }

    /// Sets the L1 regularization parameter (builder pattern).
    ///
    /// # Arguments
    ///
    /// * `alpha` - L1 regularization parameter (useful for sparse data)
    ///
    /// # Returns
    ///
    /// Self with updated alpha
    pub fn with_alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Computes the gain score for a node with L1 regularization applied.
    ///
    /// L1 regularization soft-thresholds the gradient sum toward zero,
    /// which can produce exactly-zero leaf values.
    ///
    /// # Arguments
    ///
    /// * `sum_grad` - Sum of gradients in the node
    /// * `sum_hess` - Sum of hessians in the node
    ///
    /// # Returns
    ///
    /// Regularized node gain
    fn compute_node_gain(&self, sum_grad: f64, sum_hess: f64) -> f64 {
        if sum_hess < self.min_hessian {
            return 0.0;
        }

        // Apply L1 regularization by soft-thresholding the gradient sum
        // (as in XGBoost): shrink |sum_grad| by alpha, clamping at zero.
        let reg_sum_grad = if sum_grad >= 0.0 {
            (sum_grad - self.alpha).max(0.0)
        } else {
            (sum_grad + self.alpha).min(0.0)
        };

        (reg_sum_grad * reg_sum_grad) / (sum_hess + self.lambda)
    }

    /// Computes the optimal leaf value with L1/L2 regularization.
    ///
    /// # Arguments
    ///
    /// * `sum_grad` - Sum of gradients in the leaf
    /// * `sum_hess` - Sum of hessians in the leaf
    ///
    /// # Returns
    ///
    /// Regularized leaf prediction value
    fn compute_leaf_value_with_reg(&self, sum_grad: f64, sum_hess: f64) -> f64 {
        if sum_hess.abs() < self.min_hessian {
            return 0.0;
        }

        // Soft-threshold the gradient sum by alpha, as in compute_node_gain.
        let reg_sum_grad = if sum_grad >= 0.0 {
            (sum_grad - self.alpha).max(0.0)
        } else {
            (sum_grad + self.alpha).min(0.0)
        };

        -reg_sum_grad / (sum_hess + self.lambda)
    }
}

impl SplitCriterion for FriedmanMSECriterion {
    fn compute_gain(
        &self,
        _features: &FeatureMatrix,
        gradients: &[f64],
        hessians: &[f64],
        left_indices: &[usize],
        right_indices: &[usize],
    ) -> Result<f64, CriterionError> {
        if gradients.len() != hessians.len() {
            return Err(CriterionError::InvalidInput(
                "Gradients and hessians must have the same length".to_string(),
            ));
        }

        if left_indices.is_empty() || right_indices.is_empty() {
            return Err(CriterionError::InsufficientSamples {
                samples: left_indices.len().min(right_indices.len()),
                min: 1,
            });
        }

        // Parent statistics: sums over all samples in the node.
        let parent_grad: f64 = gradients.iter().sum();
        let parent_hess: f64 = hessians.iter().sum();
        let parent_gain = self.compute_node_gain(parent_grad, parent_hess);

        // Left child statistics.
        let left_grad: f64 = left_indices.iter().map(|&i| gradients[i]).sum();
        let left_hess: f64 = left_indices.iter().map(|&i| hessians[i]).sum();
        let left_gain = self.compute_node_gain(left_grad, left_hess);

        // Right child statistics.
        let right_grad: f64 = right_indices.iter().map(|&i| gradients[i]).sum();
        let right_hess: f64 = right_indices.iter().map(|&i| hessians[i]).sum();
        let right_gain = self.compute_node_gain(right_grad, right_hess);

        // Improvement over the parent, minus gamma (the minimum required gain),
        // clamped at zero so rejected splits report no gain.
        let gain = left_gain + right_gain - parent_gain - self.gamma;
        Ok(gain.max(0.0))
    }

    fn compute_leaf_value(
        &self,
        gradients: &[f64],
        hessians: &[f64],
    ) -> Result<f64, CriterionError> {
        if gradients.is_empty() {
            return Err(CriterionError::InsufficientSamples {
                samples: 0,
                min: 1,
            });
        }

        let sum_grad: f64 = gradients.iter().sum();
        let sum_hess: f64 = hessians.iter().sum();

        Ok(self.compute_leaf_value_with_reg(sum_grad, sum_hess))
    }

    fn name(&self) -> &str {
        "FriedmanMSE"
    }
}

/// Factory function to create split criterion instances.
///
/// This convenience function creates criterion objects from string names.
/// Unrecognized names fall back to [`MSECriterion`].
///
/// # Arguments
///
/// * `name` - Name of the criterion: `"friedman_mse"` or `"mse"` (anything else falls back to MSE)
/// * `lambda` - L2 regularization parameter
/// * `gamma` - Minimum gain threshold (only used by [`FriedmanMSECriterion`])
///
/// # Returns
///
/// A boxed trait object implementing [`SplitCriterion`]
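///
/// # Example
///
/// A short usage sketch:
///
/// ```ignore
/// let crit = create_criterion("friedman_mse", 1.0, 0.1);
/// assert_eq!(crit.name(), "FriedmanMSE");
/// ```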
pub fn create_criterion(name: &str, lambda: f64, gamma: f64) -> Box<dyn SplitCriterion> {
    match name {
        "friedman_mse" => Box::new(FriedmanMSECriterion::new(lambda, gamma)),
        // Anything else (including "mse") falls back to the plain MSE criterion.
        _ => Box::new(MSECriterion::new(lambda)),
    }
}
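
// A minimal test sketch for the leaf-value math. Constructing a `FeatureMatrix`
// is crate-specific, so these tests exercise only `compute_leaf_value`, which
// does not touch the feature matrix.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mse_leaf_value_applies_l2() {
        let c = MSECriterion::new(1.0);
        // -sum_grad / (sum_hess + lambda) = -(-3.0) / (2.0 + 1.0) = 1.0
        let leaf = c.compute_leaf_value(&[-1.0, -2.0], &[1.0, 1.0]).unwrap();
        assert!((leaf - 1.0).abs() < 1e-12);
    }

    #[test]
    fn friedman_alpha_soft_thresholds_small_gradients() {
        let c = FriedmanMSECriterion::new(1.0, 0.0).with_alpha(0.5);
        // |sum_grad| = 0.3 < alpha = 0.5, so the regularized gradient is 0.
        let leaf = c.compute_leaf_value(&[0.1, 0.2], &[1.0, 1.0]).unwrap();
        assert_eq!(leaf, 0.0);
    }

    #[test]
    fn empty_leaf_is_an_error() {
        let c = MSECriterion::default();
        assert!(matches!(
            c.compute_leaf_value(&[], &[]),
            Err(CriterionError::InsufficientSamples { .. })
        ));
    }
}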