gbrt_rs/tree/criterion.rs

//! Split criteria for decision tree construction in gradient boosting.
//!
//! This module defines traits and implementations for evaluating the quality of potential
//! splits when building decision trees. Split criteria determine which feature and threshold
//! provide the best separation of data to minimize the loss function.
//!
//! # Key Concepts
//!
//! In gradient boosting, split criteria operate on **gradients** and **hessians** (first and
//! second derivatives of the loss function) rather than raw target values. This allows the
//! same tree-building algorithm to work with any differentiable loss function.
//!
//! # Available Criteria
//!
//! - [`MSECriterion`]: Basic mean squared error criterion with L2 regularization
//! - [`FriedmanMSECriterion`]: Enhanced MSE criterion with L1/L2 regularization and a
//!   minimum gain threshold, in the style of XGBoost and LightGBM
//!
//! # Regularization
//!
//! Both criteria support regularization parameters to prevent overfitting:
//! - `lambda`: L2 regularization on leaf weights
//! - `alpha`: L1 regularization (Friedman only)
//! - `gamma`: Minimum gain required to make a split (Friedman only)
//! - `min_hessian`: Minimum sum of hessians required for a valid split
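//!
//! # Example
//!
//! A minimal sketch of the gradient/hessian convention; the `gbrt_rs` crate
//! paths are assumed for illustration:
//!
//! ```rust,ignore
//! use gbrt_rs::tree::criterion::{MSECriterion, SplitCriterion};
//!
//! // For squared-error loss: gradient = prediction - target, hessian = 1.0.
//! let gradients = [1.0, 2.0, 3.0];
//! let hessians = [1.0, 1.0, 1.0];
//!
//! // With lambda = 0 the leaf value is the negative mean gradient: -6.0 / 3.0.
//! let criterion = MSECriterion::new(0.0);
//! let leaf = criterion.compute_leaf_value(&gradients, &hessians).unwrap();
//! assert!((leaf + 2.0).abs() < 1e-12);
//! ```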

use crate::data::FeatureMatrix;
use thiserror::Error;

/// Errors that can occur during split criterion computation.
///
/// These errors cover invalid inputs, insufficient samples, and numerical issues
/// that prevent proper calculation of split quality.
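///
/// # Example
///
/// A minimal sketch of matching on the variants; the `gbrt_rs` crate path is
/// assumed for illustration:
///
/// ```rust,ignore
/// use gbrt_rs::tree::criterion::CriterionError;
///
/// fn describe(err: &CriterionError) -> String {
///     match err {
///         CriterionError::InvalidInput(msg) => format!("invalid input: {msg}"),
///         CriterionError::InsufficientSamples { samples, min } => {
///             format!("need at least {min} samples, got {samples}")
///         }
///     }
/// }
/// ```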
#[derive(Error, Debug)]
pub enum CriterionError {
    /// Input data is invalid for computation.
    ///
    /// This includes mismatched array lengths, non-finite values, or other
    /// data integrity issues that violate the criterion's preconditions.
    #[error("Invalid input data: {0}")]
    InvalidInput(String),

    /// Not enough samples to perform a meaningful split.
    ///
    /// This error is returned when a node or split candidate has fewer samples
    /// than the minimum required for stable statistical estimation.
    #[error("Insufficient samples: {samples} (min: {min})")]
    InsufficientSamples { samples: usize, min: usize },
}

/// Core trait for evaluating split quality in decision trees.
///
/// Split criteria determine how well a candidate split separates the data
/// in terms of reducing the loss function. They operate on gradients and hessians
/// from the objective function's derivatives.
///
/// # Thread Safety
///
/// Implementations must be thread-safe (`Send + Sync`) as they may be used
/// in parallel tree construction.
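///
/// # Example
///
/// A minimal sketch of code that is generic over any criterion; the `gbrt_rs`
/// crate paths are assumed for illustration:
///
/// ```rust,ignore
/// use gbrt_rs::tree::criterion::SplitCriterion;
///
/// fn log_leaf(criterion: &dyn SplitCriterion, grads: &[f64], hess: &[f64]) {
///     match criterion.compute_leaf_value(grads, hess) {
///         Ok(value) => println!("{} leaf value: {value}", criterion.name()),
///         Err(err) => eprintln!("{err}"),
///     }
/// }
/// ```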
pub trait SplitCriterion: Send + Sync {
    /// Computes the improvement (gain) from making a split.
    ///
    /// The gain measures how much the loss is reduced by splitting a parent node
    /// into left and right children. Higher gain indicates a better split.
    ///
    /// # Arguments
    ///
    /// * `features` - Feature matrix of the training data
    /// * `gradients` - First derivatives of the loss w.r.t. predictions (one per sample)
    /// * `hessians` - Second derivatives of the loss w.r.t. predictions (one per sample)
    /// * `left_indices` - Indices of samples assigned to the left child node
    /// * `right_indices` - Indices of samples assigned to the right child node
    ///
    /// # Returns
    ///
    /// The split gain (improvement score). Higher values indicate better splits;
    /// implementations may clamp negative gains to zero.
    ///
    /// # Errors
    ///
    /// - [`CriterionError::InvalidInput`] if `gradients` and `hessians` have different lengths
    /// - [`CriterionError::InsufficientSamples`] if either child is empty
    fn compute_gain(
        &self,
        features: &FeatureMatrix,
        gradients: &[f64],
        hessians: &[f64],
        left_indices: &[usize],
        right_indices: &[usize],
    ) -> Result<f64, CriterionError>;

    /// Computes the optimal prediction value for a leaf node.
    ///
    /// The leaf value minimizes the loss for the samples reaching that leaf.
    /// For MSE-based criteria, this is typically `-sum(gradients) / (sum(hessians) + lambda)`.
    ///
    /// # Arguments
    ///
    /// * `gradients` - First derivatives of samples in the leaf
    /// * `hessians` - Second derivatives of samples in the leaf
    ///
    /// # Returns
    ///
    /// The optimal leaf prediction value
    ///
    /// # Errors
    ///
    /// Returns [`CriterionError::InsufficientSamples`] if the leaf is empty
    fn compute_leaf_value(
        &self,
        gradients: &[f64],
        hessians: &[f64],
    ) -> Result<f64, CriterionError>;

    /// Returns the name of the criterion.
    ///
    /// Used for logging, debugging, and serialization.
    ///
    /// # Returns
    ///
    /// A string slice identifying the criterion (e.g., "MSE", "FriedmanMSE")
    fn name(&self) -> &str;
}

/// Standard Mean Squared Error split criterion with L2 regularization.
///
/// This criterion computes split gain based on the reduction in squared error,
/// treating gradients as residuals and hessians as constant weights (typically 1.0).
///
/// # Formula
///
/// - Node gain: `(sum_grad)² / (sum_hess + lambda)`
/// - Leaf value: `-sum_grad / (sum_hess + lambda)`
///
/// # Parameters
///
/// - `lambda`: L2 regularization term added to the denominator; it prevents division
///   by zero and controls leaf value magnitude
/// - `min_hessian`: Minimum sum of hessians required for a valid split (avoids unstable splits)
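///
/// # Example
///
/// A worked sketch of the leaf-value formula; the `gbrt_rs` crate paths are
/// assumed for illustration:
///
/// ```rust,ignore
/// use gbrt_rs::tree::criterion::{MSECriterion, SplitCriterion};
///
/// let criterion = MSECriterion::new(1.0);
/// // -sum_grad / (sum_hess + lambda) = -(2.0 + 4.0) / (1.0 + 1.0 + 1.0) = -2.0
/// let value = criterion.compute_leaf_value(&[2.0, 4.0], &[1.0, 1.0]).unwrap();
/// assert!((value + 2.0).abs() < 1e-12);
/// ```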
#[derive(Debug, Clone)]
pub struct MSECriterion {
    /// L2 regularization parameter (lambda) for leaf weights.
    ///
    /// Larger values shrink leaf predictions toward zero, preventing overfitting.
    pub lambda: f64,
    /// Minimum sum of hessians required to consider a split valid.
    ///
    /// This prevents splits on nodes with too few effective samples.
    pub min_hessian: f64,
}

impl Default for MSECriterion {
    fn default() -> Self {
        Self {
            lambda: 1.0,
            min_hessian: 1e-8,
        }
    }
}

impl MSECriterion {
    /// Creates a new MSE criterion with the specified L2 regularization.
    ///
    /// # Arguments
    ///
    /// * `lambda` - L2 regularization parameter (typically 1.0)
    ///
    /// # Returns
    ///
    /// A new `MSECriterion` instance with default `min_hessian = 1e-8`
    pub fn new(lambda: f64) -> Self {
        Self {
            lambda,
            min_hessian: 1e-8,
        }
    }

    /// Sets the minimum hessian threshold (builder pattern).
    ///
    /// # Arguments
    ///
    /// * `min_hessian` - Minimum sum of hessians for valid splits
    ///
    /// # Returns
    ///
    /// Self with the updated threshold
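    ///
    /// # Example
    ///
    /// A minimal sketch; the `gbrt_rs` crate path is assumed for illustration:
    ///
    /// ```rust,ignore
    /// use gbrt_rs::tree::criterion::MSECriterion;
    ///
    /// let criterion = MSECriterion::new(1.0).with_min_hessian(1e-6);
    /// ```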
    pub fn with_min_hessian(mut self, min_hessian: f64) -> Self {
        self.min_hessian = min_hessian;
        self
    }

    /// Computes the gain score for a single node (before splitting).
    ///
    /// This helper calculates the base gain used for both the parent
    /// and the child nodes in split evaluation.
    ///
    /// # Arguments
    ///
    /// * `sum_grad` - Sum of gradients in the node
    /// * `sum_hess` - Sum of hessians in the node
    ///
    /// # Returns
    ///
    /// Node gain score (non-negative)
    fn compute_node_gain(&self, sum_grad: f64, sum_hess: f64) -> f64 {
        if sum_hess < self.min_hessian {
            return 0.0;
        }
        (sum_grad * sum_grad) / (sum_hess + self.lambda)
    }
}

impl SplitCriterion for MSECriterion {
    fn compute_gain(
        &self,
        _features: &FeatureMatrix,
        gradients: &[f64],
        hessians: &[f64],
        left_indices: &[usize],
        right_indices: &[usize],
    ) -> Result<f64, CriterionError> {
        if gradients.len() != hessians.len() {
            return Err(CriterionError::InvalidInput(
                "Gradients and hessians must have the same length".to_string(),
            ));
        }

        if left_indices.is_empty() || right_indices.is_empty() {
            return Err(CriterionError::InsufficientSamples {
                samples: left_indices.len().min(right_indices.len()),
                min: 1,
            });
        }

        // Compute parent statistics
        let parent_grad: f64 = gradients.iter().sum();
        let parent_hess: f64 = hessians.iter().sum();
        let parent_gain = self.compute_node_gain(parent_grad, parent_hess);

        // Compute left child statistics
        let left_grad: f64 = left_indices.iter().map(|&i| gradients[i]).sum();
        let left_hess: f64 = left_indices.iter().map(|&i| hessians[i]).sum();
        let left_gain = self.compute_node_gain(left_grad, left_hess);

        // Compute right child statistics
        let right_grad: f64 = right_indices.iter().map(|&i| gradients[i]).sum();
        let right_hess: f64 = right_indices.iter().map(|&i| hessians[i]).sum();
        let right_gain = self.compute_node_gain(right_grad, right_hess);

        // Gain is the improvement over the parent
        Ok(left_gain + right_gain - parent_gain)
    }

    fn compute_leaf_value(
        &self,
        gradients: &[f64],
        hessians: &[f64],
    ) -> Result<f64, CriterionError> {
        if gradients.is_empty() {
            return Err(CriterionError::InsufficientSamples {
                samples: 0,
                min: 1,
            });
        }

        let sum_grad: f64 = gradients.iter().sum();
        let sum_hess: f64 = hessians.iter().sum();

        if sum_hess.abs() < self.min_hessian {
            return Ok(0.0);
        }

        // For MSE loss, the optimal leaf value is -sum_grad / (sum_hess + lambda)
        Ok(-sum_grad / (sum_hess + self.lambda))
    }

    fn name(&self) -> &str {
        "MSE"
    }
}

/// Friedman's MSE criterion with regularization (XGBoost-style).
///
/// This enhanced criterion adds L1 regularization, a minimum gain threshold,
/// and regularized leaf values as used in modern gradient boosting frameworks.
///
/// # Formula
///
/// - Node gain: `(reg_grad)² / (sum_hess + lambda)`, where `reg_grad` is the
///   L1-regularized gradient sum
/// - Leaf value: `-reg_grad / (sum_hess + lambda)`
/// - Split gain: `max(left_gain + right_gain - parent_gain - gamma, 0)`
///
/// # Regularization Parameters
///
/// - `lambda`: L2 regularization on leaf weights
/// - `alpha`: L1 regularization (shrinkage effect)
/// - `gamma`: Minimum gain required to make a split (pruning)
/// - `min_hessian`: Minimum hessian sum for numerical stability
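///
/// # Example
///
/// A worked sketch of the L1 shrinkage on the leaf value; the `gbrt_rs` crate
/// paths are assumed for illustration:
///
/// ```rust,ignore
/// use gbrt_rs::tree::criterion::{FriedmanMSECriterion, SplitCriterion};
///
/// let criterion = FriedmanMSECriterion::new(1.0, 0.0).with_alpha(0.5);
/// // reg_grad = max(6.0 - 0.5, 0.0) = 5.5; leaf = -5.5 / (2.0 + 1.0)
/// let value = criterion.compute_leaf_value(&[2.0, 4.0], &[1.0, 1.0]).unwrap();
/// assert!((value + 5.5 / 3.0).abs() < 1e-12);
/// ```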
#[derive(Debug, Clone)]
pub struct FriedmanMSECriterion {
    /// L2 regularization parameter for leaf weights.
    ///
    /// Controls the magnitude of leaf predictions to prevent overfitting.
    pub lambda: f64,
    /// L1 regularization parameter for sparse models.
    ///
    /// Encourages leaf values to be exactly zero, creating sparse trees.
    pub alpha: f64,
    /// Minimum sum of hessians required for a valid split.
    ///
    /// Prevents splits on nodes with insufficient statistical support.
    pub min_hessian: f64,
    /// Minimum gain required to create a split (gamma parameter).
    ///
    /// Acts as pre-pruning: splits with gain < gamma are rejected.
    pub gamma: f64,
}

impl Default for FriedmanMSECriterion {
    fn default() -> Self {
        Self {
            lambda: 1.0,
            alpha: 0.0,
            min_hessian: 1e-8,
            gamma: 0.0,
        }
    }
}

impl FriedmanMSECriterion {
    /// Creates a new Friedman MSE criterion with L2 and minimum gain regularization.
    ///
    /// # Arguments
    ///
    /// * `lambda` - L2 regularization parameter (typically 1.0-2.0)
    /// * `gamma` - Minimum gain threshold for splits (0.0 for no pruning)
    ///
    /// # Returns
    ///
    /// A new `FriedmanMSECriterion` with default `alpha = 0.0` and `min_hessian = 1e-8`
    pub fn new(lambda: f64, gamma: f64) -> Self {
        Self {
            lambda,
            alpha: 0.0,
            min_hessian: 1e-8,
            gamma,
        }
    }

    /// Sets the L1 regularization parameter (builder pattern).
    ///
    /// # Arguments
    ///
    /// * `alpha` - L1 regularization parameter (useful for sparse data)
    ///
    /// # Returns
    ///
    /// Self with the updated alpha
    pub fn with_alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Computes the gain for a node with L1 regularization applied.
    ///
    /// L1 regularization shrinks the gradient sum toward zero,
    /// which can produce sparse leaf values.
    ///
    /// # Arguments
    ///
    /// * `sum_grad` - Sum of gradients in the node
    /// * `sum_hess` - Sum of hessians in the node
    ///
    /// # Returns
    ///
    /// Regularized node gain
    fn compute_node_gain(&self, sum_grad: f64, sum_hess: f64) -> f64 {
        if sum_hess < self.min_hessian {
            return 0.0;
        }

        // Apply L1 regularization (soft thresholding, as in XGBoost)
        let reg_sum_grad = if sum_grad >= 0.0 {
            (sum_grad - self.alpha).max(0.0)
        } else {
            (sum_grad + self.alpha).min(0.0)
        };

        (reg_sum_grad * reg_sum_grad) / (sum_hess + self.lambda)
    }

    /// Computes the optimal leaf value with L1/L2 regularization.
    ///
    /// # Arguments
    ///
    /// * `sum_grad` - Sum of gradients in the leaf
    /// * `sum_hess` - Sum of hessians in the leaf
    ///
    /// # Returns
    ///
    /// Regularized leaf prediction value
    fn compute_leaf_value_with_reg(&self, sum_grad: f64, sum_hess: f64) -> f64 {
        if sum_hess.abs() < self.min_hessian {
            return 0.0;
        }

        // Apply L1 regularization (soft thresholding)
        let reg_sum_grad = if sum_grad >= 0.0 {
            (sum_grad - self.alpha).max(0.0)
        } else {
            (sum_grad + self.alpha).min(0.0)
        };

        -reg_sum_grad / (sum_hess + self.lambda)
    }
}

impl SplitCriterion for FriedmanMSECriterion {
    fn compute_gain(
        &self,
        _features: &FeatureMatrix,
        gradients: &[f64],
        hessians: &[f64],
        left_indices: &[usize],
        right_indices: &[usize],
    ) -> Result<f64, CriterionError> {
        if gradients.len() != hessians.len() {
            return Err(CriterionError::InvalidInput(
                "Gradients and hessians must have the same length".to_string(),
            ));
        }

        if left_indices.is_empty() || right_indices.is_empty() {
            return Err(CriterionError::InsufficientSamples {
                samples: left_indices.len().min(right_indices.len()),
                min: 1,
            });
        }

        // Compute parent statistics
        let parent_grad: f64 = gradients.iter().sum();
        let parent_hess: f64 = hessians.iter().sum();
        let parent_gain = self.compute_node_gain(parent_grad, parent_hess);

        // Compute left child statistics
        let left_grad: f64 = left_indices.iter().map(|&i| gradients[i]).sum();
        let left_hess: f64 = left_indices.iter().map(|&i| hessians[i]).sum();
        let left_gain = self.compute_node_gain(left_grad, left_hess);

        // Compute right child statistics
        let right_grad: f64 = right_indices.iter().map(|&i| gradients[i]).sum();
        let right_hess: f64 = right_indices.iter().map(|&i| hessians[i]).sum();
        let right_gain = self.compute_node_gain(right_grad, right_hess);

        // Gain is the improvement over the parent, minus gamma (minimum gain
        // required), clamped at zero.
        let gain = left_gain + right_gain - parent_gain - self.gamma;
        Ok(gain.max(0.0))
    }

    fn compute_leaf_value(
        &self,
        gradients: &[f64],
        hessians: &[f64],
    ) -> Result<f64, CriterionError> {
        if gradients.is_empty() {
            return Err(CriterionError::InsufficientSamples {
                samples: 0,
                min: 1,
            });
        }

        let sum_grad: f64 = gradients.iter().sum();
        let sum_hess: f64 = hessians.iter().sum();

        Ok(self.compute_leaf_value_with_reg(sum_grad, sum_hess))
    }

    fn name(&self) -> &str {
        "FriedmanMSE"
    }
}

/// Factory function to create split criterion instances.
///
/// This convenience function creates criterion objects from string names with
/// the given regularization parameters. Unrecognized names fall back to the
/// basic [`MSECriterion`].
///
/// # Arguments
///
/// * `name` - Name of the criterion: "friedman_mse" or "mse"
/// * `lambda` - L2 regularization parameter
/// * `gamma` - Minimum gain threshold (only used by [`FriedmanMSECriterion`])
///
/// # Returns
///
/// A boxed trait object implementing [`SplitCriterion`]
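///
/// # Example
///
/// A minimal sketch of name-based dispatch, including the fallback; the
/// `gbrt_rs` crate paths are assumed for illustration:
///
/// ```rust,ignore
/// use gbrt_rs::tree::criterion::{create_criterion, SplitCriterion};
///
/// assert_eq!(create_criterion("friedman_mse", 1.0, 0.1).name(), "FriedmanMSE");
/// // Unrecognized names fall back to the basic MSE criterion.
/// assert_eq!(create_criterion("unknown", 1.0, 0.0).name(), "MSE");
/// ```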
pub fn create_criterion(name: &str, lambda: f64, gamma: f64) -> Box<dyn SplitCriterion> {
    match name {
        "friedman_mse" => Box::new(FriedmanMSECriterion::new(lambda, gamma)),
        // "mse" and any unrecognized name fall back to the basic MSE criterion.
        _ => Box::new(MSECriterion::new(lambda)),
    }
}