gbrt_rs/tree/splitter.rs
//! Split finding algorithms for decision tree construction.
//!
//! This module provides the core logic for identifying optimal feature splits
//! in gradient boosting trees. It defines the [`Splitter`] trait for different
//! split-finding strategies and implements [`BestSplitter`] as the default
//! exhaustive search algorithm.
//!
//! # Split Finding Strategies
//!
//! The module supports two main approaches:
//!
//! - **Exact splits**: Evaluate every possible threshold between sorted feature
//!   values. Most accurate, but O(n log n) per feature since sorting dominates.
//!   Suitable for small to medium datasets.
//!
//! - **Approximate splits**: Use histogram binning to discretize continuous
//!   features. Only O(bins) candidate splits per feature after an O(n) binning
//!   pass, much faster for large datasets with negligible accuracy loss.
//!
//! # Performance Optimizations
//!
//! [`BestSplitter`] implements several key optimizations:
//! - Precomputes cumulative sums so all thresholds are scored in a single O(n) pass
//! - Merges small histogram bins to respect `min_samples_leaf`
//! - Terminates early in degenerate cases (constant features, insufficient samples)
//! - Supports feature subsampling via the `feature_indices` parameter
//!
//! # Split Candidate Evaluation
//!
//! Splits are evaluated using the gain formula from the objective function's
//! second-order Taylor approximation. The gain measures improvement over the
//! parent node:
//!
//! ```text
//! Gain = LeftGain + RightGain - ParentGain
//! ```
//!
//! where `NodeGain = (sum_gradients)² / (sum_hessians + lambda)`
//!
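//! For example, with `lambda = 1.0`, a parent node with gradient sum −10.0 and
//! hessian sum 8.0 that splits into children with sums (−9.0, 4.0) and
//! (−1.0, 4.0) — a minimal sketch of the arithmetic, not tied to any dataset:
//!
//! ```
//! let lambda = 1.0;
//! let node_gain = |g: f64, h: f64| (g * g) / (h + lambda);
//!
//! let parent = node_gain(-10.0, 8.0); // 100 / 9 ≈ 11.11
//! let left = node_gain(-9.0, 4.0);    // 81 / 5  = 16.20
//! let right = node_gain(-1.0, 4.0);   // 1 / 5   = 0.20
//!
//! let gain = left + right - parent;   // ≈ 5.29: a worthwhile split
//! assert!(gain > 0.0);
//! ```
//!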
//! [`Splitter`]: trait.Splitter.html
//! [`BestSplitter`]: struct.BestSplitter.html

use crate::data::FeatureMatrix;
use thiserror::Error;

/// Errors that can occur during split finding.
///
/// These errors cover invalid inputs, insufficient samples, and data integrity
/// issues that prevent finding valid splits.
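///
/// # Example
///
/// A sketch of handling a splitter's result (module path assumed to be
/// `gbrt_rs::tree::splitter`; `splitter`, `features`, `grads`, and `hess` are
/// placeholders):
///
/// ```ignore
/// use gbrt_rs::tree::splitter::SplitterError;
///
/// match splitter.find_best_split(&features, &grads, &hess, &[0], 1, 1.0) {
///     Ok(Some(split)) => println!("best gain: {}", split.gain),
///     Ok(None) => println!("no useful split; turn this node into a leaf"),
///     Err(SplitterError::InvalidFeatureIndex { index, max }) => {
///         eprintln!("feature {index} out of range (max {max})");
///     }
///     Err(e) => eprintln!("split finding failed: {e}"),
/// }
/// ```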
#[derive(Error, Debug)]
pub enum SplitterError {
    /// Requested feature index is out of bounds.
    #[error("Invalid feature index: {index} (max: {max})")]
    InvalidFeatureIndex { index: usize, max: usize },

    /// Too few samples to perform a split meeting minimum leaf size requirements.
    #[error("Insufficient samples for splitting: {samples} (min: {min})")]
    InsufficientSamples { samples: usize, min: usize },

    /// Gradient/hessian array lengths do not match the number of samples.
    #[error("Length mismatch: {n_samples} samples, {n_gradients} gradients, {n_hessians} hessians")]
    LengthMismatch {
        n_samples: usize,
        n_gradients: usize,
        n_hessians: usize,
    },

    /// A feature had no valid split points (e.g., all values identical).
    #[error("No valid splits found for feature {feature}")]
    NoValidSplits { feature: usize },

    /// Underlying data access error from the feature matrix.
    #[error("Data error: {0}")]
    DataError(#[from] crate::data::DataError),

    /// Sample array is empty after filtering.
    #[error("Empty samples array")]
    EmptySamples,
}

/// Represents a candidate split with precomputed statistics.
///
/// This struct stores all information needed to evaluate and apply a split,
/// including the feature, threshold, gain, and sample indices for each child.
#[derive(Debug, Clone)]
pub struct SplitCandidate {
    /// Index of the feature to split on.
    pub feature_index: usize,
    /// Threshold value for the split (values <= threshold go left, values > go right).
    pub split_value: f64,
    /// Improvement in loss from this split.
    pub gain: f64,
    /// Indices of samples assigned to the left child.
    pub left_indices: Vec<usize>,
    /// Indices of samples assigned to the right child.
    pub right_indices: Vec<usize>,
    /// Sum of gradients in the left child.
    pub left_grad_sum: f64,
    /// Sum of gradients in the right child.
    pub right_grad_sum: f64,
    /// Sum of hessians in the left child.
    pub left_hess_sum: f64,
    /// Sum of hessians in the right child.
    pub right_hess_sum: f64,
}

impl SplitCandidate {
    /// Creates a new split candidate with all required statistics.
    ///
    /// # Arguments
    ///
    /// * `feature_index` - Feature to split on
    /// * `split_value` - Threshold value
    /// * `gain` - Split improvement score
    /// * `left_indices` - Samples in left child
    /// * `right_indices` - Samples in right child
    /// * `left_grad_sum` - Gradient sum for left child
    /// * `right_grad_sum` - Gradient sum for right child
    /// * `left_hess_sum` - Hessian sum for left child
    /// * `right_hess_sum` - Hessian sum for right child
    pub fn new(
        feature_index: usize,
        split_value: f64,
        gain: f64,
        left_indices: Vec<usize>,
        right_indices: Vec<usize>,
        left_grad_sum: f64,
        right_grad_sum: f64,
        left_hess_sum: f64,
        right_hess_sum: f64,
    ) -> Self {
        Self {
            feature_index,
            split_value,
            gain,
            left_indices,
            right_indices,
            left_grad_sum,
            right_grad_sum,
            left_hess_sum,
            right_hess_sum,
        }
    }

    /// Validates that this split meets minimum leaf size requirements.
    ///
    /// # Arguments
    ///
    /// * `min_samples_leaf` - Minimum samples required in each child
    ///
    /// # Returns
    ///
    /// `true` if both children have at least `min_samples_leaf` samples
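    ///
    /// # Example
    ///
    /// A minimal sketch (module path assumed to be `gbrt_rs::tree::splitter`):
    ///
    /// ```ignore
    /// use gbrt_rs::tree::splitter::SplitCandidate;
    ///
    /// let split = SplitCandidate::new(
    ///     0,             // feature_index
    ///     0.5,           // split_value
    ///     1.2,           // gain
    ///     vec![0, 1, 2], // left_indices
    ///     vec![3, 4],    // right_indices
    ///     -3.0, -1.0,    // gradient sums
    ///     3.0, 2.0,      // hessian sums
    /// );
    /// assert!(split.is_valid(2));  // both children hold at least 2 samples
    /// assert!(!split.is_valid(3)); // the right child holds only 2
    /// ```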
    pub fn is_valid(&self, min_samples_leaf: usize) -> bool {
        self.left_indices.len() >= min_samples_leaf
            && self.right_indices.len() >= min_samples_leaf
    }
}

/// Trait for pluggable split-finding algorithms.
///
/// Implementations define different strategies for finding optimal feature splits,
/// enabling experimentation with exact search, approximate methods, or random splits.
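///
/// # Example
///
/// A sketch of driving a splitter during tree growth. The
/// `FeatureMatrix::from_rows` constructor is hypothetical (illustration only);
/// substitute whatever the data module actually provides:
///
/// ```ignore
/// use gbrt_rs::data::FeatureMatrix;
/// use gbrt_rs::tree::splitter::{BestSplitter, Splitter};
///
/// // Hypothetical constructor: one row per sample, one column per feature.
/// let features = FeatureMatrix::from_rows(vec![
///     vec![1.0], vec![2.0], vec![3.0], vec![4.0],
/// ]);
/// let gradients = vec![-1.0, -0.5, 0.5, 1.0];
/// let hessians = vec![1.0; 4];
///
/// let splitter = BestSplitter::new().use_exact_splits(true);
/// let maybe_split = splitter
///     .find_best_split(&features, &gradients, &hessians, &[0], 1, 1.0)
///     .expect("split finding failed");
/// if let Some(split) = maybe_split {
///     println!("split feature {} at {}", split.feature_index, split.split_value);
/// }
/// ```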
pub trait Splitter: Send + Sync {
    /// Finds the best split across the specified features.
    ///
    /// This method searches for the optimal (feature, threshold) pair that maximizes
    /// the split gain while respecting `min_samples_leaf` and other constraints.
    ///
    /// # Arguments
    ///
    /// * `features` - Training feature matrix
    /// * `gradients` - First derivatives from the objective function
    /// * `hessians` - Second derivatives from the objective function
    /// * `feature_indices` - Subset of features to consider (enables feature sampling)
    /// * `min_samples_leaf` - Minimum samples in each child
    /// * `lambda` - L2 regularization parameter
    ///
    /// # Returns
    ///
    /// `Ok(Some(candidate))` if a valid split is found, `Ok(None)` if no valid split exists,
    /// or a [`SplitterError`] if computation fails.
    fn find_best_split(
        &self,
        features: &FeatureMatrix,
        gradients: &[f64],
        hessians: &[f64],
        feature_indices: &[usize],
        min_samples_leaf: usize,
        lambda: f64,
    ) -> Result<Option<SplitCandidate>, SplitterError>;
}

/// Default splitter that exhaustively searches for the optimal split.
///
/// `BestSplitter` evaluates all valid split points for each feature, computing
/// the exact gain using cumulative sums. For large datasets, it can use histogram
/// approximation to trade a small amount of accuracy for a significant speedup.
///
/// # Configuration
///
/// - `n_bins`: Number of histogram bins for approximate splitting (default: 256)
/// - `use_exact_splits`: Force exact evaluation even for large features
///
/// # Algorithm
///
/// For each feature:
/// 1. Sort samples by feature value (O(n log n))
/// 2. Precompute cumulative gradient/hessian sums (O(n))
/// 3. Evaluate all splits meeting `min_samples_leaf` (O(n))
/// 4. Track the split with maximum gain
///
/// The overall complexity is O(k × n log n), where k is the number of features considered.
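///
/// # Example
///
/// A minimal configuration sketch (module path assumed to be
/// `gbrt_rs::tree::splitter`):
///
/// ```ignore
/// use gbrt_rs::tree::splitter::BestSplitter;
///
/// // Default: 256-bin histogram approximation.
/// let fast = BestSplitter::new();
///
/// // Coarser histograms trade accuracy for speed on very large data.
/// let coarse = BestSplitter::with_n_bins(64);
///
/// // Force exhaustive evaluation of every threshold.
/// let exact = BestSplitter::new().use_exact_splits(true);
/// ```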
#[derive(Debug, Clone)]
pub struct BestSplitter {
    /// Number of bins for histogram approximation (`None` = exact).
    pub n_bins: Option<usize>,
    /// Whether to disable approximation and use exact splits.
    pub use_exact_splits: bool,
}

impl Default for BestSplitter {
    fn default() -> Self {
        Self {
            n_bins: Some(256),
            use_exact_splits: false,
        }
    }
}

impl BestSplitter {
    /// Creates a new `BestSplitter` with default settings (256 bins, approximate).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of histogram bins for approximate splitting.
    ///
    /// More bins give a more accurate but slower search; fewer bins are faster
    /// but coarser.
    ///
    /// # Arguments
    ///
    /// * `n_bins` - Number of histogram bins (values below 2 fall back to exact splitting)
    ///
    /// # Returns
    ///
    /// A new `BestSplitter` instance
    pub fn with_n_bins(n_bins: usize) -> Self {
        Self {
            n_bins: Some(n_bins),
            use_exact_splits: false,
        }
    }

    /// Enables or disables exact split evaluation.
    ///
    /// When `true`, all splits are evaluated exactly regardless of dataset size.
    /// When `false`, histogram approximation is used for large features.
    ///
    /// # Arguments
    ///
    /// * `use_exact` - Whether to force exact evaluation
    ///
    /// # Returns
    ///
    /// Self with the updated setting (builder pattern)
    pub fn use_exact_splits(mut self, use_exact: bool) -> Self {
        self.use_exact_splits = use_exact;
        self
    }

    /// Finds the best split for a single feature.
    ///
    /// This internal method delegates to either exact or approximate search
    /// based on the splitter's configuration.
    ///
    /// # Arguments
    ///
    /// * `features` - Feature matrix
    /// * `gradients` - Gradient values
    /// * `hessians` - Hessian values
    /// * `feature_index` - Feature to evaluate
    /// * `min_samples_leaf` - Minimum leaf size
    /// * `lambda` - L2 regularization
    ///
    /// # Returns
    ///
    /// The best split candidate for this feature, or `None` if no valid split exists
    fn find_best_split_for_feature(
        &self,
        features: &FeatureMatrix,
        gradients: &[f64],
        hessians: &[f64],
        feature_index: usize,
        min_samples_leaf: usize,
        lambda: f64,
    ) -> Result<Option<SplitCandidate>, SplitterError> {
        let n_samples = features.n_samples();

        if n_samples < min_samples_leaf * 2 {
            return Ok(None);
        }

        // Gather (feature_value, gradient, hessian, index) tuples, skipping
        // rows the feature matrix cannot serve.
        let mut samples: Vec<(f64, f64, f64, usize)> = (0..n_samples)
            .filter_map(|i| {
                features
                    .get(i, feature_index)
                    .ok()
                    .map(|feature_val| (feature_val, gradients[i], hessians[i], i))
            })
            .collect();

        if samples.is_empty() {
            return Err(SplitterError::EmptySamples);
        }

        // Sort by feature value; NaNs have no total order, so they compare equal.
        samples.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

        // Duplicate feature values are deliberately kept: dropping them would
        // silently discard samples and their gradient/hessian mass. Degenerate
        // split points between equal values are skipped during evaluation instead.

        if samples.len() < min_samples_leaf * 2 {
            return Ok(None);
        }

        if self.use_exact_splits {
            self.find_best_exact_split(&samples, feature_index, min_samples_leaf, lambda)
        } else {
            self.find_best_approximate_split(&samples, feature_index, min_samples_leaf, lambda)
        }
    }

    /// Exhaustively evaluates all possible split points.
    ///
    /// # Arguments
    ///
    /// * `samples` - Sorted samples: (feature_value, gradient, hessian, index)
    /// * `feature_index` - Feature being evaluated
    /// * `min_samples_leaf` - Minimum samples per child
    /// * `lambda` - L2 regularization
    ///
    /// # Returns
    ///
    /// Best split candidate found via exhaustive search
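    ///
    /// The core trick, as a self-contained sketch: with prefix sums of the
    /// gradients, any left/right partition's sums become O(1) lookups, so all
    /// thresholds can be scored in a single pass.
    ///
    /// ```
    /// let grads = [0.3, -0.8, 0.1, 0.5];
    /// let mut prefix = vec![0.0];
    /// for g in grads {
    ///     prefix.push(prefix.last().unwrap() + g);
    /// }
    /// // The sum of grads[0..i] is prefix[i]; the right side is total - left.
    /// let total = prefix[grads.len()];
    /// let left = prefix[2];     // grads[0] + grads[1] = -0.5
    /// let right = total - left; // grads[2] + grads[3] = 0.6
    /// assert!((left - (-0.5)).abs() < 1e-12);
    /// assert!((right - 0.6).abs() < 1e-12);
    /// ```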
    fn find_best_exact_split(
        &self,
        samples: &[(f64, f64, f64, usize)],
        feature_index: usize,
        min_samples_leaf: usize,
        lambda: f64,
    ) -> Result<Option<SplitCandidate>, SplitterError> {
        let n_samples = samples.len();
        let mut best_gain = -f64::INFINITY;
        let mut best_split: Option<SplitCandidate> = None;

        // Precompute cumulative sums: grad_prefix[i] is the gradient sum over samples[0..i].
        let mut grad_prefix = Vec::with_capacity(n_samples + 1);
        let mut hess_prefix = Vec::with_capacity(n_samples + 1);

        grad_prefix.push(0.0);
        hess_prefix.push(0.0);

        for i in 0..n_samples {
            grad_prefix.push(grad_prefix[i] + samples[i].1);
            hess_prefix.push(hess_prefix[i] + samples[i].2);
        }

        let total_grad = grad_prefix[n_samples];
        let total_hess = hess_prefix[n_samples];

        if total_hess + lambda <= 0.0 {
            return Ok(None);
        }

        let parent_gain = self.compute_gain(total_grad, total_hess, lambda);

        // Consider every boundary i that leaves at least min_leaf samples on each
        // side; a split at i sends samples[0..i] left and samples[i..] right.
        let min_leaf = min_samples_leaf.max(1);
        for i in min_leaf..=n_samples.saturating_sub(min_leaf) {
            // The boundary separates samples[i - 1] from samples[i]; equal values
            // there admit no threshold, so skip the degenerate case.
            if (samples[i].0 - samples[i - 1].0).abs() < 1e-10 {
                continue;
            }

            let left_grad = grad_prefix[i];
            let left_hess = hess_prefix[i];
            let right_grad = total_grad - left_grad;
            let right_hess = total_hess - left_hess;

            // Avoid division by zero or negative denominators.
            if left_hess + lambda <= 0.0 || right_hess + lambda <= 0.0 {
                continue;
            }

            let left_gain = self.compute_gain(left_grad, left_hess, lambda);
            let right_gain = self.compute_gain(right_grad, right_hess, lambda);
            let gain = left_gain + right_gain - parent_gain;

            // Only consider strictly positive gains.
            if gain > best_gain && gain > 1e-10 {
                best_gain = gain;

                // Midpoint between the last left value and the first right value,
                // so that `value <= split_value` routes exactly samples[0..i] left.
                let split_value = (samples[i - 1].0 + samples[i].0) / 2.0;

                let left_indices: Vec<usize> =
                    samples[0..i].iter().map(|&(_, _, _, idx)| idx).collect();
                let right_indices: Vec<usize> =
                    samples[i..].iter().map(|&(_, _, _, idx)| idx).collect();

                best_split = Some(SplitCandidate::new(
                    feature_index,
                    split_value,
                    gain,
                    left_indices,
                    right_indices,
                    left_grad,
                    right_grad,
                    left_hess,
                    right_hess,
                ));
            }
        }

        Ok(best_split)
    }

    /// Uses histogram binning for faster approximate split finding.
    ///
    /// # Arguments
    ///
    /// * `samples` - Sorted samples: (feature_value, gradient, hessian, index)
    /// * `feature_index` - Feature being evaluated
    /// * `min_samples_leaf` - Minimum samples per child
    /// * `lambda` - L2 regularization
    ///
    /// # Returns
    ///
    /// Best split candidate found via histogram approximation
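    ///
    /// Bin assignment, as a self-contained sketch: values map to equal-width
    /// bins over `[min_val, max_val]`, clamped into the last bin.
    ///
    /// ```
    /// let (min_val, max_val, n_bins) = (0.0_f64, 10.0_f64, 4_usize);
    /// let bin_size = (max_val - min_val) / n_bins as f64; // 2.5
    /// let bin = |v: f64| (((v - min_val) / bin_size) as usize).min(n_bins - 1);
    ///
    /// assert_eq!(bin(0.0), 0);
    /// assert_eq!(bin(4.9), 1);
    /// assert_eq!(bin(10.0), 3); // clamped into the last bin
    /// ```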
    fn find_best_approximate_split(
        &self,
        samples: &[(f64, f64, f64, usize)],
        feature_index: usize,
        min_samples_leaf: usize,
        lambda: f64,
    ) -> Result<Option<SplitCandidate>, SplitterError> {
        let n_samples = samples.len();
        let n_bins = self.n_bins.unwrap_or(256).min(n_samples);

        if n_bins < 2 {
            return self.find_best_exact_split(samples, feature_index, min_samples_leaf, lambda);
        }

        // Create equal-width bins spanning the feature's value range.
        let min_val = samples[0].0;
        let max_val = samples[n_samples - 1].0;
        if (max_val - min_val).abs() < 1e-10 {
            // Constant feature: no split possible.
            return Ok(None);
        }

        let bin_size = (max_val - min_val) / (n_bins as f64);
        // Each bin tracks (min value, max value, gradient sum, hessian sum, indices).
        let mut bins: Vec<(f64, f64, f64, f64, Vec<usize>)> = Vec::with_capacity(n_bins);

        for _ in 0..n_bins {
            bins.push((f64::INFINITY, f64::NEG_INFINITY, 0.0, 0.0, Vec::new()));
        }

        // Accumulate samples into bins, clamping into the last bin.
        for &(feature_val, grad, hess, idx) in samples {
            let bin_idx = (((feature_val - min_val) / bin_size) as usize).min(n_bins - 1);
            let bin = &mut bins[bin_idx];
            bin.0 = bin.0.min(feature_val);
            bin.1 = bin.1.max(feature_val);
            bin.2 += grad;
            bin.3 += hess;
            bin.4.push(idx);
        }

        // Drop empty bins.
        let non_empty_bins: Vec<(f64, f64, f64, f64, Vec<usize>)> = bins
            .into_iter()
            .filter(|bin| !bin.4.is_empty())
            .collect();

        if non_empty_bins.len() < 2 {
            return Ok(None);
        }

        // Merge undersized bins forward until each merged bin reaches
        // min_samples_leaf (the trailing bin may remain small; the split-size
        // checks below reject any invalid candidate).
        let mut merged_bins = Vec::new();
        let mut current_bin = (f64::INFINITY, f64::NEG_INFINITY, 0.0, 0.0, Vec::new());

        for bin in non_empty_bins {
            if !current_bin.4.is_empty() && current_bin.4.len() < min_samples_leaf {
                current_bin.0 = current_bin.0.min(bin.0);
                current_bin.1 = current_bin.1.max(bin.1);
                current_bin.2 += bin.2;
                current_bin.3 += bin.3;
                current_bin.4.extend(bin.4);
            } else {
                if !current_bin.4.is_empty() {
                    merged_bins.push(current_bin);
                }
                current_bin = bin;
            }
        }

        if !current_bin.4.is_empty() {
            merged_bins.push(current_bin);
        }

        if merged_bins.len() < 2 {
            return Ok(None);
        }

        // Find the best split among bin boundaries.
        let mut best_gain = -f64::INFINITY;
        let mut best_split: Option<SplitCandidate> = None;

        let total_grad: f64 = merged_bins.iter().map(|bin| bin.2).sum();
        let total_hess: f64 = merged_bins.iter().map(|bin| bin.3).sum();

        if total_hess + lambda <= 0.0 {
            return Ok(None);
        }

        let parent_gain = self.compute_gain(total_grad, total_hess, lambda);

        let mut left_grad = 0.0;
        let mut left_hess = 0.0;
        let mut left_indices = Vec::new();

        for i in 0..(merged_bins.len() - 1) {
            left_grad += merged_bins[i].2;
            left_hess += merged_bins[i].3;
            left_indices.extend(&merged_bins[i].4);

            if left_indices.len() < min_samples_leaf {
                continue;
            }

            let right_grad = total_grad - left_grad;
            let right_hess = total_hess - left_hess;
            let right_indices: Vec<usize> = merged_bins[i + 1..]
                .iter()
                .flat_map(|bin| bin.4.iter().cloned())
                .collect();

            if right_indices.len() < min_samples_leaf {
                continue;
            }

            if left_hess + lambda <= 0.0 || right_hess + lambda <= 0.0 {
                continue;
            }

            let left_gain = self.compute_gain(left_grad, left_hess, lambda);
            let right_gain = self.compute_gain(right_grad, right_hess, lambda);
            let gain = left_gain + right_gain - parent_gain;

            if gain > best_gain && gain > 1e-10 {
                best_gain = gain;

                // Midpoint between the left bin's max and the right bin's min.
                let split_value = (merged_bins[i].1 + merged_bins[i + 1].0) / 2.0;

                best_split = Some(SplitCandidate::new(
                    feature_index,
                    split_value,
                    gain,
                    left_indices.clone(),
                    right_indices,
                    left_grad,
                    right_grad,
                    left_hess,
                    right_hess,
                ));
            }
        }

        Ok(best_split)
    }

    /// Computes the gain for a node given its gradient/hessian sums.
    ///
    /// # Formula
    ///
    /// `gain = (sum_grad)² / (sum_hess + lambda)`
    ///
    /// # Arguments
    ///
    /// * `sum_grad` - Sum of gradients in the node
    /// * `sum_hess` - Sum of hessians in the node
    /// * `lambda` - L2 regularization parameter
    ///
    /// # Returns
    ///
    /// Node gain, or `-∞` if the denominator would be non-positive
    fn compute_gain(&self, sum_grad: f64, sum_hess: f64, lambda: f64) -> f64 {
        if sum_hess + lambda <= 0.0 {
            -f64::INFINITY
        } else {
            (sum_grad * sum_grad) / (sum_hess + lambda)
        }
    }
}

impl Splitter for BestSplitter {
    /// Finds the globally optimal split across all specified features.
    ///
    /// This method iterates through `feature_indices`, finds the best split for each,
    /// and returns the candidate with maximum gain. It respects all regularization
    /// and sampling constraints.
    ///
    /// # Arguments
    ///
    /// * `features` - Feature matrix
    /// * `gradients` - Gradient values
    /// * `hessians` - Hessian values
    /// * `feature_indices` - Features to evaluate
    /// * `min_samples_leaf` - Minimum samples per child
    /// * `lambda` - L2 regularization
    ///
    /// # Returns
    ///
    /// Best split across all features, or `None` if no valid split exists
    fn find_best_split(
        &self,
        features: &FeatureMatrix,
        gradients: &[f64],
        hessians: &[f64],
        feature_indices: &[usize],
        min_samples_leaf: usize,
        lambda: f64,
    ) -> Result<Option<SplitCandidate>, SplitterError> {
        if features.n_samples() != gradients.len() || gradients.len() != hessians.len() {
            return Err(SplitterError::LengthMismatch {
                n_samples: features.n_samples(),
                n_gradients: gradients.len(),
                n_hessians: hessians.len(),
            });
        }

        if features.n_samples() < min_samples_leaf * 2 {
            return Ok(None);
        }

        let mut best_candidate: Option<SplitCandidate> = None;

        for &feature_idx in feature_indices {
            if feature_idx >= features.n_features() {
                return Err(SplitterError::InvalidFeatureIndex {
                    index: feature_idx,
                    max: features.n_features().saturating_sub(1),
                });
            }

            if let Some(candidate) = self.find_best_split_for_feature(
                features,
                gradients,
                hessians,
                feature_idx,
                min_samples_leaf,
                lambda,
            )? {
                // Keep the candidate with the highest gain seen so far.
                if best_candidate
                    .as_ref()
                    .map_or(true, |c| c.gain < candidate.gain)
                {
                    best_candidate = Some(candidate);
                }
            }
        }

        Ok(best_candidate)
    }
}