gbrt_rs/tree/splitter.rs

//! Split finding algorithms for decision tree construction.
//!
//! This module provides the core logic for identifying optimal feature splits
//! in gradient boosting trees. It defines the [`Splitter`] trait for different
//! split-finding strategies and implements [`BestSplitter`] as the default
//! exhaustive search algorithm.
//!
//! # Split Finding Strategies
//!
//! The module supports two main approaches:
//!
//! - **Exact splits**: Evaluate every possible threshold between sorted feature values.
//!   Most accurate, but O(n log n) per feature (dominated by sorting). Suitable for
//!   small to medium datasets.
//!
//! - **Approximate splits**: Use histogram binning to discretize continuous features.
//!   O(n + bins) per feature, much faster for large datasets with typically
//!   negligible accuracy loss.
//!
//! # Performance Optimizations
//!
//! [`BestSplitter`] implements several key optimizations:
//! - Precomputes prefix sums so all candidate gains for a feature are evaluated in O(n) total
//! - Merges small histogram bins to respect `min_samples_leaf`
//! - Early termination for degenerate cases (constant features, insufficient samples)
//! - Feature subsampling via the `feature_indices` parameter
//!
//! # Split Candidate Evaluation
//!
//! Splits are evaluated using the gain formula from the objective function's
//! second-order Taylor approximation. The gain measures improvement over the parent node:
//!
//! ```text
//! Gain = LeftGain + RightGain - ParentGain
//! ```
//!
//! where `NodeGain = (sum_gradients)² / (sum_hessians + lambda)`
//!
//! [`Splitter`]: trait.Splitter.html
//! [`BestSplitter`]: struct.BestSplitter.html
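//!
//! # Example
//!
//! A minimal usage sketch. The `FeatureMatrix` constructor below is assumed
//! for illustration; see `crate::data` for the actual API.
//!
//! ```ignore
//! use gbrt_rs::tree::splitter::{BestSplitter, Splitter};
//!
//! // Hypothetical 4-sample, single-feature matrix (constructor assumed).
//! let features = FeatureMatrix::from_rows(&[[1.0], [2.0], [3.0], [4.0]])?;
//! let gradients = vec![-1.0, -0.5, 0.5, 1.0];
//! let hessians = vec![1.0, 1.0, 1.0, 1.0];
//!
//! let splitter = BestSplitter::new().use_exact_splits(true);
//! if let Some(split) = splitter.find_best_split(&features, &gradients, &hessians, &[0], 1, 1.0)? {
//!     println!("feature {} at {:.2}, gain {:.3}",
//!              split.feature_index, split.split_value, split.gain);
//! }
//! ```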

use crate::data::FeatureMatrix;
use thiserror::Error;

/// Errors that can occur during split finding.
///
/// These errors cover invalid inputs, insufficient samples, and data integrity
/// issues that prevent finding valid splits.
#[derive(Error, Debug)]
pub enum SplitterError {
    /// Requested feature index is out of bounds.
    #[error("Invalid feature index: {index} (max: {max})")]
    InvalidFeatureIndex { index: usize, max: usize },

    /// Too few samples to perform a split meeting minimum leaf size requirements.
    #[error("Insufficient samples for splitting: {samples} (min: {min})")]
    InsufficientSamples { samples: usize, min: usize },

    /// Gradient or hessian arrays do not match the number of samples.
    #[error("Dimension mismatch: expected {expected} values, got {actual}")]
    DimensionMismatch { expected: usize, actual: usize },

    /// A feature had no valid split points (e.g., all values identical).
    #[error("No valid splits found for feature {feature}")]
    NoValidSplits { feature: usize },

    /// Underlying data access error from the feature matrix.
    #[error("Data error: {0}")]
    DataError(#[from] crate::data::DataError),

    /// Sample array is empty after filtering.
    #[error("Empty samples array")]
    EmptySamples,
}

/// Represents a candidate split with precomputed statistics.
///
/// This struct stores all information needed to evaluate and apply a split,
/// including the feature, threshold, gain, and sample indices for each child.
#[derive(Debug, Clone)]
pub struct SplitCandidate {
    /// Index of the feature to split on.
    pub feature_index: usize,
    /// Threshold value for the split (samples with value <= threshold go left,
    /// all others go right).
    pub split_value: f64,
    /// Improvement in loss from this split.
    pub gain: f64,
    /// Indices of samples assigned to the left child.
    pub left_indices: Vec<usize>,
    /// Indices of samples assigned to the right child.
    pub right_indices: Vec<usize>,
    /// Sum of gradients in the left child.
    pub left_grad_sum: f64,
    /// Sum of gradients in the right child.
    pub right_grad_sum: f64,
    /// Sum of hessians in the left child.
    pub left_hess_sum: f64,
    /// Sum of hessians in the right child.
    pub right_hess_sum: f64,
}

impl SplitCandidate {
    /// Creates a new split candidate with all required statistics.
    ///
    /// # Arguments
    ///
    /// * `feature_index` - Feature to split on
    /// * `split_value` - Threshold value
    /// * `gain` - Split improvement score
    /// * `left_indices` - Samples in left child
    /// * `right_indices` - Samples in right child
    /// * `left_grad_sum` - Gradient sum for left child
    /// * `right_grad_sum` - Gradient sum for right child
    /// * `left_hess_sum` - Hessian sum for left child
    /// * `right_hess_sum` - Hessian sum for right child
    pub fn new(
        feature_index: usize,
        split_value: f64,
        gain: f64,
        left_indices: Vec<usize>,
        right_indices: Vec<usize>,
        left_grad_sum: f64,
        right_grad_sum: f64,
        left_hess_sum: f64,
        right_hess_sum: f64,
    ) -> Self {
        Self {
            feature_index,
            split_value,
            gain,
            left_indices,
            right_indices,
            left_grad_sum,
            right_grad_sum,
            left_hess_sum,
            right_hess_sum,
        }
    }

    /// Validates that this split meets minimum leaf size requirements.
    ///
    /// # Arguments
    ///
    /// * `min_samples_leaf` - Minimum samples required in each child
    ///
    /// # Returns
    ///
    /// `true` if both children have at least `min_samples_leaf` samples
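    ///
    /// # Example
    ///
    /// A minimal sketch with placeholder statistics (module path assumed):
    ///
    /// ```ignore
    /// let split = SplitCandidate::new(0, 2.5, 1.2, vec![0, 1], vec![2, 3], -1.5, 1.5, 2.0, 2.0);
    /// assert!(split.is_valid(2));   // both children have 2 samples
    /// assert!(!split.is_valid(3));  // neither child has 3
    /// ```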
    pub fn is_valid(&self, min_samples_leaf: usize) -> bool {
        self.left_indices.len() >= min_samples_leaf
            && self.right_indices.len() >= min_samples_leaf
    }
}

/// Trait for pluggable split-finding algorithms.
///
/// Implementations define different strategies for finding optimal feature splits,
/// enabling experimentation with exact search, approximate methods, or random splits.
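///
/// # Example
///
/// A sketch of selecting a strategy behind dynamic dispatch:
///
/// ```ignore
/// let splitter: Box<dyn Splitter> = Box::new(BestSplitter::default());
/// ```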
pub trait Splitter: Send + Sync {
    /// Finds the best split across the specified features.
    ///
    /// This method searches for the optimal (feature, threshold) pair that maximizes
    /// the split gain while respecting `min_samples_leaf` and other constraints.
    ///
    /// # Arguments
    ///
    /// * `features` - Training feature matrix
    /// * `gradients` - First derivatives from the objective function
    /// * `hessians` - Second derivatives from the objective function
    /// * `feature_indices` - Subset of features to consider (enables feature sampling)
    /// * `min_samples_leaf` - Minimum samples in each child
    /// * `lambda` - L2 regularization parameter
    ///
    /// # Returns
    ///
    /// `Ok(Some(candidate))` if a valid split is found, `Ok(None)` if no valid split exists,
    /// or a [`SplitterError`] if computation fails.
    fn find_best_split(
        &self,
        features: &FeatureMatrix,
        gradients: &[f64],
        hessians: &[f64],
        feature_indices: &[usize],
        min_samples_leaf: usize,
        lambda: f64,
    ) -> Result<Option<SplitCandidate>, SplitterError>;
}

/// Default splitter that exhaustively searches for the optimal split.
///
/// `BestSplitter` evaluates all valid split points for each feature, computing
/// the exact gain using cumulative sums. For large datasets, it can use histogram
/// approximation to trade a small amount of accuracy for a significant speedup.
///
/// # Configuration
///
/// - `n_bins`: Number of histogram bins for approximate splitting (default: 256)
/// - `use_exact_splits`: Force exact evaluation even for large features
///
/// # Algorithm
///
/// For each feature:
/// 1. Sort samples by feature value (O(n log n))
/// 2. Precompute cumulative gradient/hessian sums (O(n))
/// 3. Evaluate all splits meeting `min_samples_leaf` (O(n))
/// 4. Track the split with maximum gain
///
/// The overall complexity is O(k × n log n), where k is the number of features considered.
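///
/// # Examples
///
/// Configuring a splitter with the builders defined below:
///
/// ```ignore
/// // Approximate search with 64 histogram bins:
/// let approx = BestSplitter::with_n_bins(64);
/// // Force exhaustive exact search:
/// let exact = BestSplitter::new().use_exact_splits(true);
/// ```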
#[derive(Debug, Clone)]
pub struct BestSplitter {
    /// Number of bins for histogram approximation (`None` = exact).
    pub n_bins: Option<usize>,
    /// Whether to disable approximation and always use exact splits.
    pub use_exact_splits: bool,
}

impl Default for BestSplitter {
    fn default() -> Self {
        Self {
            n_bins: Some(256),
            use_exact_splits: false,
        }
    }
}

impl BestSplitter {
    /// Creates a new `BestSplitter` with default settings (256 bins, approximate).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of histogram bins for approximate splitting.
    ///
    /// More bins = more accurate but slower. Fewer bins = faster but coarser approximations.
    ///
    /// # Arguments
    ///
    /// * `n_bins` - Number of histogram bins (must be ≥ 2)
    ///
    /// # Returns
    ///
    /// A new `BestSplitter` instance
    pub fn with_n_bins(n_bins: usize) -> Self {
        Self {
            n_bins: Some(n_bins),
            use_exact_splits: false,
        }
    }

    /// Enables or disables exact split evaluation.
    ///
    /// When `true`, all splits are evaluated exactly regardless of dataset size.
    /// When `false`, histogram approximation is used for large features.
    ///
    /// # Arguments
    ///
    /// * `use_exact` - Whether to force exact evaluation
    ///
    /// # Returns
    ///
    /// Self with the updated setting (builder pattern)
    pub fn use_exact_splits(mut self, use_exact: bool) -> Self {
        self.use_exact_splits = use_exact;
        self
    }

    /// Finds the best split for a single feature.
    ///
    /// This internal method delegates to either exact or approximate search
    /// based on the splitter's configuration.
    ///
    /// # Arguments
    ///
    /// * `features` - Feature matrix
    /// * `gradients` - Gradient values
    /// * `hessians` - Hessian values
    /// * `feature_index` - Feature to evaluate
    /// * `min_samples_leaf` - Minimum leaf size
    /// * `lambda` - L2 regularization
    ///
    /// # Returns
    ///
    /// The best split candidate for this feature, or `None` if no valid split exists
    fn find_best_split_for_feature(
        &self,
        features: &FeatureMatrix,
        gradients: &[f64],
        hessians: &[f64],
        feature_index: usize,
        min_samples_leaf: usize,
        lambda: f64,
    ) -> Result<Option<SplitCandidate>, SplitterError> {
        let n_samples = features.n_samples();

        if n_samples < min_samples_leaf * 2 {
            return Ok(None);
        }

        // Gather (value, gradient, hessian, index) tuples, skipping rows whose
        // feature value cannot be read.
        let mut samples: Vec<(f64, f64, f64, usize)> = (0..n_samples)
            .filter_map(|i| match features.get(i, feature_index) {
                Ok(feature_val) => Some((feature_val, gradients[i], hessians[i], i)),
                Err(_) => None,
            })
            .collect();

        if samples.is_empty() {
            return Err(SplitterError::EmptySamples);
        }

        // Sort by feature value; the NaN-safe comparison treats incomparable values as equal.
        samples.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

        // Note: duplicate feature values are deliberately kept. Deduplicating here
        // would silently discard their gradient/hessian mass and drop their sample
        // indices from both children; degenerate thresholds between tied values
        // are skipped during evaluation instead.

        if samples.len() < min_samples_leaf * 2 {
            return Ok(None);
        }

        if self.use_exact_splits {
            self.find_best_exact_split(&samples, feature_index, min_samples_leaf, lambda)
        } else {
            self.find_best_approximate_split(&samples, feature_index, min_samples_leaf, lambda)
        }
    }

    /// Exhaustively evaluates all possible split points.
    ///
    /// # Arguments
    ///
    /// * `samples` - Sorted samples: (feature_value, gradient, hessian, index)
    /// * `feature_index` - Feature being evaluated
    /// * `min_samples_leaf` - Minimum samples per child
    /// * `lambda` - L2 regularization
    ///
    /// # Returns
    ///
    /// Best split candidate found via exhaustive search
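    ///
    /// # Sketch
    ///
    /// With prefix sums, each candidate position `i` (left = `samples[0..i]`,
    /// right = `samples[i..]`) is scored in O(1):
    ///
    /// ```text
    /// left_grad  = grad_prefix[i]
    /// right_grad = total_grad - grad_prefix[i]    (hessians analogous)
    /// ```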
    fn find_best_exact_split(
        &self,
        samples: &[(f64, f64, f64, usize)],
        feature_index: usize,
        min_samples_leaf: usize,
        lambda: f64,
    ) -> Result<Option<SplitCandidate>, SplitterError> {
        let n_samples = samples.len();
        if n_samples < 2 {
            return Ok(None);
        }

        let mut best_gain = -f64::INFINITY;
        let mut best_split: Option<SplitCandidate> = None;

        // Precompute prefix sums: grad_prefix[i] holds the gradient sum of samples[0..i].
        let mut grad_prefix = Vec::with_capacity(n_samples + 1);
        let mut hess_prefix = Vec::with_capacity(n_samples + 1);

        grad_prefix.push(0.0);
        hess_prefix.push(0.0);

        for i in 0..n_samples {
            grad_prefix.push(grad_prefix[i] + samples[i].1);
            hess_prefix.push(hess_prefix[i] + samples[i].2);
        }

        let total_grad = grad_prefix[n_samples];
        let total_hess = hess_prefix[n_samples];

        if total_hess + lambda <= 0.0 {
            return Ok(None);
        }

        let parent_gain = self.compute_gain(total_grad, total_hess, lambda);

        // A split at position `i` sends samples[0..i] left and samples[i..] right,
        // so valid positions range over min_samples_leaf..=(n_samples - min_samples_leaf).
        let lo = min_samples_leaf.max(1);
        let hi = n_samples.saturating_sub(min_samples_leaf).min(n_samples - 1);
        for i in lo..=hi {
            // The decision boundary lies between samples[i - 1] and samples[i]; if
            // their feature values are (nearly) equal, no threshold can separate them.
            if (samples[i - 1].0 - samples[i].0).abs() < 1e-10 {
                continue;
            }

            let left_grad = grad_prefix[i];
            let left_hess = hess_prefix[i];
            let right_grad = total_grad - left_grad;
            let right_hess = total_hess - left_hess;

            // Avoid non-positive denominators in the gain formula.
            if left_hess + lambda <= 0.0 || right_hess + lambda <= 0.0 {
                continue;
            }

            let left_gain = self.compute_gain(left_grad, left_hess, lambda);
            let right_gain = self.compute_gain(right_grad, right_hess, lambda);
            let gain = left_gain + right_gain - parent_gain;

            // Only consider strictly positive gains.
            if gain > best_gain && gain > 1e-10 {
                best_gain = gain;

                // Midpoint between the last left sample and the first right sample,
                // consistent with the `value <= split_value goes left` convention.
                let split_value = (samples[i - 1].0 + samples[i].0) / 2.0;

                let left_indices: Vec<usize> =
                    samples[..i].iter().map(|&(_, _, _, idx)| idx).collect();
                let right_indices: Vec<usize> =
                    samples[i..].iter().map(|&(_, _, _, idx)| idx).collect();

                // Validate split sizes.
                if left_indices.len() >= min_samples_leaf
                    && right_indices.len() >= min_samples_leaf
                {
                    best_split = Some(SplitCandidate::new(
                        feature_index,
                        split_value,
                        gain,
                        left_indices,
                        right_indices,
                        left_grad,
                        right_grad,
                        left_hess,
                        right_hess,
                    ));
                }
            }
        }

        Ok(best_split)
    }

    /// Uses histogram binning for faster approximate split finding.
    ///
    /// # Arguments
    ///
    /// * `samples` - Sorted samples: (feature_value, gradient, hessian, index)
    /// * `feature_index` - Feature being evaluated
    /// * `min_samples_leaf` - Minimum samples per child
    /// * `lambda` - L2 regularization
    ///
    /// # Returns
    ///
    /// Best split candidate found via histogram approximation
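    ///
    /// # Sketch
    ///
    /// ```text
    /// values:  [0.1  0.2 | 0.4  0.5 | 0.9]        (3 non-empty bins)
    /// candidates: one threshold per bin boundary,
    ///   threshold = (max of left bin + min of right bin) / 2
    /// ```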
    fn find_best_approximate_split(
        &self,
        samples: &[(f64, f64, f64, usize)],
        feature_index: usize,
        min_samples_leaf: usize,
        lambda: f64,
    ) -> Result<Option<SplitCandidate>, SplitterError> {
        let n_samples = samples.len();
        let n_bins = self.n_bins.unwrap_or(256).min(n_samples);

        if n_bins < 2 {
            return self.find_best_exact_split(samples, feature_index, min_samples_leaf, lambda);
        }

        // Create equal-width bins over the feature's value range.
        let min_val = samples[0].0;
        let max_val = samples[n_samples - 1].0;
        if (max_val - min_val).abs() < 1e-10 {
            // Constant feature: no valid split.
            return Ok(None);
        }

        let bin_size = (max_val - min_val) / (n_bins as f64);
        // Each bin tracks (min value, max value, gradient sum, hessian sum, sample indices).
        let mut bins: Vec<(f64, f64, f64, f64, Vec<usize>)> = Vec::with_capacity(n_bins);

        for _ in 0..n_bins {
            bins.push((f64::INFINITY, f64::NEG_INFINITY, 0.0, 0.0, Vec::new()));
        }

        // Accumulate samples into bins.
        for &(feature_val, grad, hess, idx) in samples {
            let bin_idx = (((feature_val - min_val) / bin_size) as usize).min(n_bins - 1);
            let bin = &mut bins[bin_idx];
            bin.0 = bin.0.min(feature_val);
            bin.1 = bin.1.max(feature_val);
            bin.2 += grad;
            bin.3 += hess;
            bin.4.push(idx);
        }

        // Drop empty bins.
        let non_empty_bins: Vec<(f64, f64, f64, f64, Vec<usize>)> = bins
            .into_iter()
            .filter(|bin| !bin.4.is_empty())
            .collect();

        if non_empty_bins.len() < 2 {
            return Ok(None);
        }

        // Merge undersized bins so split candidates respect `min_samples_leaf`.
        // The trailing bin may still be undersized; the per-split checks below
        // enforce the constraint in that case.
        let mut merged_bins = Vec::new();
        let mut current_bin = (f64::INFINITY, f64::NEG_INFINITY, 0.0, 0.0, Vec::new());

        for bin in non_empty_bins {
            if !current_bin.4.is_empty() && current_bin.4.len() < min_samples_leaf {
                current_bin.0 = current_bin.0.min(bin.0);
                current_bin.1 = current_bin.1.max(bin.1);
                current_bin.2 += bin.2;
                current_bin.3 += bin.3;
                current_bin.4.extend(bin.4);
            } else {
                if !current_bin.4.is_empty() {
                    merged_bins.push(current_bin);
                }
                current_bin = bin;
            }
        }

        if !current_bin.4.is_empty() {
            merged_bins.push(current_bin);
        }

        if merged_bins.len() < 2 {
            return Ok(None);
        }

        // Evaluate one candidate split per boundary between merged bins.
        let mut best_gain = -f64::INFINITY;
        let mut best_split: Option<SplitCandidate> = None;

        let total_grad: f64 = merged_bins.iter().map(|bin| bin.2).sum();
        let total_hess: f64 = merged_bins.iter().map(|bin| bin.3).sum();

        if total_hess + lambda <= 0.0 {
            return Ok(None);
        }

        let parent_gain = self.compute_gain(total_grad, total_hess, lambda);

        let mut left_grad = 0.0;
        let mut left_hess = 0.0;
        let mut left_indices = Vec::new();

        for i in 0..(merged_bins.len() - 1) {
            left_grad += merged_bins[i].2;
            left_hess += merged_bins[i].3;
            left_indices.extend(&merged_bins[i].4);

            if left_indices.len() < min_samples_leaf {
                continue;
            }

            let right_grad = total_grad - left_grad;
            let right_hess = total_hess - left_hess;
            let right_indices: Vec<usize> = merged_bins[i + 1..]
                .iter()
                .flat_map(|bin| bin.4.iter().cloned())
                .collect();

            if right_indices.len() < min_samples_leaf {
                continue;
            }

            if left_hess + lambda <= 0.0 || right_hess + lambda <= 0.0 {
                continue;
            }

            let left_gain = self.compute_gain(left_grad, left_hess, lambda);
            let right_gain = self.compute_gain(right_grad, right_hess, lambda);
            let gain = left_gain + right_gain - parent_gain;

            if gain > best_gain && gain > 1e-10 {
                best_gain = gain;

                // Midpoint between the left bin's max and the right bin's min.
                let split_value = (merged_bins[i].1 + merged_bins[i + 1].0) / 2.0;

                best_split = Some(SplitCandidate::new(
                    feature_index,
                    split_value,
                    gain,
                    left_indices.clone(),
                    right_indices,
                    left_grad,
                    right_grad,
                    left_hess,
                    right_hess,
                ));
            }
        }

        Ok(best_split)
    }

    /// Computes the gain for a node given its gradient/hessian sums.
    ///
    /// # Formula
    ///
    /// `gain = (sum_grad)² / (sum_hess + lambda)`
    ///
    /// # Arguments
    ///
    /// * `sum_grad` - Sum of gradients in the node
    /// * `sum_hess` - Sum of hessians in the node
    /// * `lambda` - L2 regularization parameter
    ///
    /// # Returns
    ///
    /// Node gain, or `-∞` if the denominator would be non-positive
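    ///
    /// # Example
    ///
    /// With `sum_grad = 4.0`, `sum_hess = 3.0`, `lambda = 1.0`:
    /// `gain = 4.0² / (3.0 + 1.0) = 4.0`.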
    fn compute_gain(&self, sum_grad: f64, sum_hess: f64, lambda: f64) -> f64 {
        if sum_hess + lambda <= 0.0 {
            -f64::INFINITY
        } else {
            (sum_grad * sum_grad) / (sum_hess + lambda)
        }
    }
}

impl Splitter for BestSplitter {
    /// Finds the globally optimal split across all specified features.
    ///
    /// This method iterates through `feature_indices`, finds the best split for each,
    /// and returns the candidate with maximum gain. It respects all regularization
    /// and sampling constraints.
    ///
    /// # Arguments
    ///
    /// * `features` - Feature matrix
    /// * `gradients` - Gradient values
    /// * `hessians` - Hessian values
    /// * `feature_indices` - Features to evaluate
    /// * `min_samples_leaf` - Minimum samples per child
    /// * `lambda` - L2 regularization
    ///
    /// # Returns
    ///
    /// Best split across all features, or `None` if no valid split exists
    fn find_best_split(
        &self,
        features: &FeatureMatrix,
        gradients: &[f64],
        hessians: &[f64],
        feature_indices: &[usize],
        min_samples_leaf: usize,
        lambda: f64,
    ) -> Result<Option<SplitCandidate>, SplitterError> {
        let n_samples = features.n_samples();
        if gradients.len() != n_samples || hessians.len() != n_samples {
            return Err(SplitterError::DimensionMismatch {
                expected: n_samples,
                actual: if gradients.len() != n_samples {
                    gradients.len()
                } else {
                    hessians.len()
                },
            });
        }

        if n_samples < min_samples_leaf * 2 {
            return Ok(None);
        }

        let mut best_candidate: Option<SplitCandidate> = None;

        for &feature_idx in feature_indices {
            if feature_idx >= features.n_features() {
                return Err(SplitterError::InvalidFeatureIndex {
                    index: feature_idx,
                    max: features.n_features() - 1,
                });
            }

            if let Some(candidate) = self.find_best_split_for_feature(
                features,
                gradients,
                hessians,
                feature_idx,
                min_samples_leaf,
                lambda,
            )? {
                // Keep the candidate with the highest gain seen so far.
                if best_candidate
                    .as_ref()
                    .map_or(true, |c| c.gain < candidate.gain)
                {
                    best_candidate = Some(candidate);
                }
            }
        }

        Ok(best_candidate)
    }
}