use crate::defaults::split as split_defaults;
use crate::histogram::Histogram;
pub use crate::dataset::SparseColumn;
/// Read-only access to a table of binned (`u8`) feature values.
///
/// Implementors must be `Sync`. Only the five accessors without a default
/// body are required; the remaining methods have conservative defaults that
/// an implementor may override.
pub trait BinStorage: Sync {
    /// Returns the bin stored for `feature` at `row`.
    fn get_bin(&self, row: usize, feature: usize) -> u8;

    /// Number of rows in the storage.
    fn num_rows(&self) -> usize;

    /// Number of features in the storage.
    fn num_features(&self) -> usize;

    /// Returns the bins of `feature` as a byte slice, or `None` when the
    /// implementor cannot provide one.
    fn feature_column(&self, feature: usize) -> Option<&[u8]>;

    /// Returns the sparse representation of `feature`, or `None` when the
    /// implementor has none for it.
    fn sparse_column(&self, feature: usize) -> Option<&SparseColumn>;

    /// `true` exactly when [`sparse_column`](Self::sparse_column) yields a
    /// column for `feature`.
    fn is_sparse(&self, feature: usize) -> bool {
        let sparse = self.sparse_column(feature);
        sparse.is_some()
    }

    /// Row-major byte view of the storage; the default provides none.
    fn as_row_major(&self) -> Option<&[u8]> {
        None
    }

    /// Bin-count limit reported by this storage; the default is `u8::MAX` (255).
    fn max_bins(&self) -> u8 {
        u8::MAX
    }

    /// `true` when [`max_bins`](Self::max_bins) is at most 16.
    // NOTE(review): 16 is presumably the number of values a 4-bit bin can
    // hold — confirm against the packed-storage implementations.
    fn supports_4bit(&self) -> bool {
        const FOUR_BIT_BIN_LIMIT: u8 = 16;
        self.max_bins() <= FOUR_BIT_BIN_LIMIT
    }

    /// 4-bit row-major byte view of the storage; the default provides none.
    fn as_row_major_4bit(&self) -> Option<&[u8]> {
        None
    }

    /// Half the feature count, rounded up.
    // NOTE(review): presumably two 4-bit bins share one byte — confirm with
    // the implementors of `as_row_major_4bit`.
    fn bytes_per_row_4bit(&self) -> usize {
        let features = self.num_features();
        // Ceiling division by two: the odd remainder adds one extra byte.
        features / 2 + features % 2
    }
}
/// Tunable parameters handed to `HistogramBackend::find_best_split`.
///
/// The `Default` implementation fills every field from the matching
/// `DEFAULT_SPLIT_*` constant in the split defaults module.
#[derive(Clone, Copy, Debug)]
pub struct SplitConfig {
    /// Regularisation term.
    // NOTE(review): presumably the L2 penalty added to hessian sums — confirm
    // against the gain formula used by the backends.
    pub lambda: f32,
    /// Minimum sample count per leaf (presumably enforced on both children — TODO confirm).
    pub min_samples_leaf: u32,
    /// Minimum hessian per leaf (presumably the hessian sum of each child — TODO confirm).
    pub min_hessian_leaf: f32,
    /// Minimum gain (presumably splits below this are rejected — TODO confirm).
    pub min_gain: f32,
    /// Weight of an entropy term; how it enters the gain is not visible here.
    pub entropy_weight: f32,
}
impl Default for SplitConfig {
fn default() -> Self {
Self {
lambda: split_defaults::DEFAULT_SPLIT_LAMBDA,
min_samples_leaf: split_defaults::DEFAULT_SPLIT_MIN_SAMPLES_LEAF,
min_hessian_leaf: split_defaults::DEFAULT_SPLIT_MIN_HESSIAN_LEAF,
min_gain: split_defaults::DEFAULT_SPLIT_MIN_GAIN,
entropy_weight: split_defaults::DEFAULT_SPLIT_ENTROPY_WEIGHT,
}
}
}
/// A proposed split, as returned by `HistogramBackend::find_best_split`.
///
/// Carries the split location, its gain, and gradient/hessian/count
/// statistics for the left and right sides.
// NOTE(review): the statistics are presumably sums over the rows routed to
// each child — confirm against the backend implementations.
#[derive(Clone, Copy, Debug)]
pub struct SplitCandidate {
    /// Index of the feature the split is on.
    pub feature: usize,
    /// Threshold of the split, expressed as a `u8`.
    // NOTE(review): presumably a bin index; which side the threshold bin
    // itself falls on is not visible here.
    pub threshold: u8,
    /// Gain attributed to this split.
    pub gain: f32,
    /// Gradient statistic of the left side.
    pub left_gradient: f32,
    /// Hessian statistic of the left side.
    pub left_hessian: f32,
    /// Count for the left side (presumably number of rows — TODO confirm).
    pub left_count: u32,
    /// Gradient statistic of the right side.
    pub right_gradient: f32,
    /// Hessian statistic of the right side.
    pub right_hessian: f32,
    /// Count for the right side (presumably number of rows — TODO confirm).
    pub right_count: u32,
}
/// A pluggable engine that builds per-feature histograms and searches them
/// for a split.
///
/// Implementors must be `Send + Sync`. The batched and per-era builders have
/// default implementations expressed through the required methods and
/// [`BinStorage::get_bin`]; an implementor may override them.
pub trait HistogramBackend: Send + Sync {
    /// Static name identifying this backend.
    fn name(&self) -> &'static str;

    /// Whether this backend is a tensor-tile backend (meaning defined by the
    /// implementors).
    fn is_tensor_tile(&self) -> bool;

    /// Builds histograms from `bins` and `grad_hess` for the rows listed in
    /// `row_indices`.
    // NOTE(review): presumably one histogram per feature, with `grad_hess`
    // holding (gradient, hessian) pairs indexed by row — confirm with
    // implementors.
    fn build_histograms(
        &self,
        bins: &dyn BinStorage,
        grad_hess: &[(f32, f32)],
        row_indices: &[usize],
    ) -> Vec<Histogram>;

    /// Produces the sibling's histograms from the `parent` histograms and
    /// those of the `smaller_child`.
    // NOTE(review): presumably parent minus smaller child, feature by
    // feature — confirm with implementors.
    fn build_histograms_sibling(
        &self,
        parent: &[Histogram],
        smaller_child: &[Histogram],
    ) -> Vec<Histogram>;

    /// Searches `histograms` for a split under `config`; `None` when the
    /// backend reports no candidate.
    fn find_best_split(
        &self,
        histograms: &[Histogram],
        config: &SplitConfig,
    ) -> Option<SplitCandidate>;

    /// Builds histograms for several row sets.
    ///
    /// The default calls [`build_histograms`](Self::build_histograms) once
    /// per entry of `batches`, in order, and returns the results in that
    /// same order.
    fn build_histograms_batched(
        &self,
        bins: &dyn BinStorage,
        grad_hess: &[(f32, f32)],
        batches: &[&[usize]],
    ) -> Vec<Vec<Histogram>> {
        let mut per_batch = Vec::with_capacity(batches.len());
        for &rows in batches {
            per_batch.push(self.build_histograms(bins, grad_hess, rows));
        }
        per_batch
    }

    /// Builds one set of per-feature histograms for every era.
    ///
    /// The result is indexed as `result[era][feature]` and has
    /// `num_eras * bins.num_features()` histograms. Each row in
    /// `row_indices` is accumulated into the era given by
    /// `era_indices[row]`, using the pair stored at `grad_hess[row]`.
    ///
    /// # Panics
    ///
    /// Panics if a row index is out of bounds for `era_indices` or
    /// `grad_hess`, or if a row's era is not below `num_eras`.
    fn build_era_histograms(
        &self,
        bins: &dyn BinStorage,
        grad_hess: &[(f32, f32)],
        row_indices: &[usize],
        era_indices: &[u16],
        num_eras: usize,
    ) -> Vec<Vec<Histogram>> {
        let feature_count = bins.num_features();
        // One fresh histogram per feature, replicated for every era.
        let era_template = vec![Histogram::new(); feature_count];
        let mut per_era = vec![era_template; num_eras];

        for &row in row_indices {
            let era = usize::from(era_indices[row]);
            let (grad, hess) = grad_hess[row];
            per_era[era]
                .iter_mut()
                .enumerate()
                .for_each(|(feature, hist)| {
                    hist.accumulate(bins.get_bin(row, feature), grad, hess);
                });
        }

        per_era
    }
}