use crate::dataset::{BinnedDataset, SparseColumn, DEFAULT_BIN};
use crate::histogram::NodeHistograms;
use crate::loss::LossFunction;
use rayon::prelude::*;
/// Number of rows processed per parallel work unit. Also fixes the size of the
/// per-block stack caches for gradients/hessians (2 × 2048 × 4 B = 16 KiB per
/// rayon task), so keep it modest.
const BLOCK_SIZE: usize = 2048;
/// Result of one fused gradient + histogram pass over a set of rows.
pub struct FusedResult {
    /// Per-feature histograms accumulated over all processed rows.
    pub histograms: NodeHistograms,
    /// Sum of per-row gradients over the processed rows.
    pub total_gradient: f32,
    /// Sum of per-row hessians over the processed rows.
    pub total_hessian: f32,
}
/// Stateless builder that fuses gradient/hessian evaluation with histogram
/// construction in a single pass over the binned dataset, avoiding a separate
/// gradient sweep.
#[derive(Debug, Clone, Copy)]
pub struct FusedHistogramBuilder;
impl Default for FusedHistogramBuilder {
fn default() -> Self {
Self::new()
}
}
impl FusedHistogramBuilder {
pub fn new() -> Self {
Self
}
#[allow(clippy::too_many_arguments)]
pub fn build_root(
&self,
dataset: &BinnedDataset,
row_indices: &[usize],
targets: &[f32],
predictions: &[f32],
loss_fn: &dyn LossFunction,
gradients: &mut [f32],
hessians: &mut [f32],
) -> FusedResult {
let num_rows = row_indices.len();
if num_rows < BLOCK_SIZE {
return self.build_single_block(
dataset,
row_indices,
targets,
predictions,
loss_fn,
gradients,
hessians,
);
}
if Self::is_contiguous(row_indices) {
self.build_blocked_contiguous(
dataset,
num_rows,
targets,
predictions,
loss_fn,
gradients,
hessians,
)
} else {
self.build_blocked_indexed(
dataset,
row_indices,
targets,
predictions,
loss_fn,
gradients,
hessians,
)
}
}
    /// Parallel fused pass for the contiguous case (`row_indices == 0..num_rows`).
    ///
    /// The row range is split into `BLOCK_SIZE` blocks processed in parallel
    /// by rayon. Each block computes gradients/hessians into stack caches,
    /// copies them back into the shared output slices, and accumulates a
    /// thread-local `NodeHistograms`; partials are merged by `reduce_results`.
    #[allow(clippy::too_many_arguments)]
    fn build_blocked_contiguous(
        &self,
        dataset: &BinnedDataset,
        num_rows: usize,
        targets: &[f32],
        predictions: &[f32],
        loss_fn: &dyn LossFunction,
        gradients: &mut [f32],
        hessians: &mut [f32],
    ) -> FusedResult {
        let num_features = dataset.num_features();
        let partial_results: Vec<(NodeHistograms, f32, f32)> = (0..num_rows)
            .into_par_iter()
            .step_by(BLOCK_SIZE)
            .map(|block_start| {
                let block_end = (block_start + BLOCK_SIZE).min(num_rows);
                let block_len = block_end - block_start;
                let mut local_hists = NodeHistograms::new(num_features);
                // Stack caches; only the first `block_len` slots are used.
                let mut grad_cache = [0.0f32; BLOCK_SIZE];
                let mut hess_cache = [0.0f32; BLOCK_SIZE];
                for i in 0..block_len {
                    let row = block_start + i;
                    let (g, h) = loss_fn.gradient_hessian(targets[row], predictions[row]);
                    grad_cache[i] = g;
                    hess_cache[i] = h;
                }
                // SAFETY: each rayon task writes only its own disjoint
                // `block_start..block_end` range, so writes never overlap
                // between threads and stay in bounds — assumes
                // `gradients.len() >= num_rows` and `hessians.len() >= num_rows`
                // (TODO confirm callers always pass full-length slices).
                // NOTE(review): writing through a `*mut` cast of the pointer
                // obtained from a shared borrow (`as_ptr`) is dubious under
                // Rust's aliasing rules even when the ranges are disjoint;
                // a `par_chunks_mut`/`split_at_mut` scheme would express the
                // same disjointness safely.
                unsafe {
                    let grad_ptr = gradients.as_ptr() as *mut f32;
                    let hess_ptr = hessians.as_ptr() as *mut f32;
                    std::ptr::copy_nonoverlapping(
                        grad_cache.as_ptr(),
                        grad_ptr.add(block_start),
                        block_len,
                    );
                    std::ptr::copy_nonoverlapping(
                        hess_cache.as_ptr(),
                        hess_ptr.add(block_start),
                        block_len,
                    );
                }
                // Block totals; also reused by the sparse path to derive the
                // default-bin contribution without touching every row.
                let block_grad: f32 = grad_cache[..block_len].iter().sum();
                let block_hess: f32 = hess_cache[..block_len].iter().sum();
                for feature_idx in 0..num_features {
                    if let Some(sparse_col) = dataset.sparse_column(feature_idx) {
                        Self::build_sparse_histogram_block(
                            local_hists.get_mut(feature_idx),
                            sparse_col,
                            block_start,
                            block_len,
                            &grad_cache,
                            &hess_cache,
                            block_grad,
                            block_hess,
                        );
                    } else {
                        let feature_column = dataset.feature_column(feature_idx);
                        let hist = local_hists.get_mut(feature_idx);
                        let bins = hist.bins_mut();
                        // 8-way manual unroll: loads are batched ahead of the
                        // accumulates to shorten dependency chains, and
                        // `get_unchecked` drops per-access bounds checks.
                        let chunks = block_len / 8;
                        let remainder = block_len % 8;
                        // SAFETY: row indices stay below `block_end <= num_rows`
                        // (in bounds for `feature_column`, assuming one entry
                        // per row) and cache indices stay below
                        // `block_len <= BLOCK_SIZE`. Bins are indexed by `u8`
                        // values, which assumes every histogram holds at least
                        // 256 bins — TODO confirm `NodeHistograms` guarantees
                        // this.
                        unsafe {
                            for i in 0..chunks {
                                let base = i * 8;
                                let row_base = block_start + base;
                                let bin0 = *feature_column.get_unchecked(row_base) as usize;
                                let bin1 = *feature_column.get_unchecked(row_base + 1) as usize;
                                let bin2 = *feature_column.get_unchecked(row_base + 2) as usize;
                                let bin3 = *feature_column.get_unchecked(row_base + 3) as usize;
                                let bin4 = *feature_column.get_unchecked(row_base + 4) as usize;
                                let bin5 = *feature_column.get_unchecked(row_base + 5) as usize;
                                let bin6 = *feature_column.get_unchecked(row_base + 6) as usize;
                                let bin7 = *feature_column.get_unchecked(row_base + 7) as usize;
                                let grad0 = *grad_cache.get_unchecked(base);
                                let grad1 = *grad_cache.get_unchecked(base + 1);
                                let grad2 = *grad_cache.get_unchecked(base + 2);
                                let grad3 = *grad_cache.get_unchecked(base + 3);
                                let grad4 = *grad_cache.get_unchecked(base + 4);
                                let grad5 = *grad_cache.get_unchecked(base + 5);
                                let grad6 = *grad_cache.get_unchecked(base + 6);
                                let grad7 = *grad_cache.get_unchecked(base + 7);
                                let hess0 = *hess_cache.get_unchecked(base);
                                let hess1 = *hess_cache.get_unchecked(base + 1);
                                let hess2 = *hess_cache.get_unchecked(base + 2);
                                let hess3 = *hess_cache.get_unchecked(base + 3);
                                let hess4 = *hess_cache.get_unchecked(base + 4);
                                let hess5 = *hess_cache.get_unchecked(base + 5);
                                let hess6 = *hess_cache.get_unchecked(base + 6);
                                let hess7 = *hess_cache.get_unchecked(base + 7);
                                bins.get_unchecked_mut(bin0).accumulate(grad0, hess0);
                                bins.get_unchecked_mut(bin1).accumulate(grad1, hess1);
                                bins.get_unchecked_mut(bin2).accumulate(grad2, hess2);
                                bins.get_unchecked_mut(bin3).accumulate(grad3, hess3);
                                bins.get_unchecked_mut(bin4).accumulate(grad4, hess4);
                                bins.get_unchecked_mut(bin5).accumulate(grad5, hess5);
                                bins.get_unchecked_mut(bin6).accumulate(grad6, hess6);
                                bins.get_unchecked_mut(bin7).accumulate(grad7, hess7);
                            }
                            // Tail: up to 7 rows that didn't fill an unrolled group.
                            let rem_base = chunks * 8;
                            for i in 0..remainder {
                                let bin = *feature_column.get_unchecked(block_start + rem_base + i)
                                    as usize;
                                let grad = *grad_cache.get_unchecked(rem_base + i);
                                let hess = *hess_cache.get_unchecked(rem_base + i);
                                bins.get_unchecked_mut(bin).accumulate(grad, hess);
                            }
                        }
                    }
                }
                (local_hists, block_grad, block_hess)
            })
            .collect();
        self.reduce_results(partial_results, num_features)
    }
    /// Parallel fused pass over an arbitrary row-index list (non-root nodes).
    ///
    /// `row_indices` is split into `BLOCK_SIZE` chunks processed in parallel;
    /// each chunk gathers rows through the index list instead of addressing
    /// them directly. The sparse path additionally requires each chunk to be
    /// sorted ascending (it uses `binary_search` on the chunk).
    #[allow(clippy::too_many_arguments)]
    fn build_blocked_indexed(
        &self,
        dataset: &BinnedDataset,
        row_indices: &[usize],
        targets: &[f32],
        predictions: &[f32],
        loss_fn: &dyn LossFunction,
        gradients: &mut [f32],
        hessians: &mut [f32],
    ) -> FusedResult {
        let num_features = dataset.num_features();
        let partial_results: Vec<(NodeHistograms, f32, f32)> = row_indices
            .par_chunks(BLOCK_SIZE)
            .map(|chunk| {
                let block_len = chunk.len();
                let mut local_hists = NodeHistograms::new(num_features);
                // Stack caches; only the first `block_len` slots are used.
                let mut grad_cache = [0.0f32; BLOCK_SIZE];
                let mut hess_cache = [0.0f32; BLOCK_SIZE];
                for (i, &row_idx) in chunk.iter().enumerate() {
                    let (g, h) = loss_fn.gradient_hessian(targets[row_idx], predictions[row_idx]);
                    grad_cache[i] = g;
                    hess_cache[i] = h;
                    // SAFETY: writes land at distinct `row_idx` positions;
                    // chunks partition `row_indices`, so no two tasks touch
                    // the same slot — assumes `row_indices` contains no
                    // duplicates and every index is < gradients.len()
                    // (TODO confirm upstream guarantees both).
                    // NOTE(review): same shared-borrow-to-`*mut` aliasing
                    // concern as in `build_blocked_contiguous`.
                    unsafe {
                        let grad_ptr = gradients.as_ptr() as *mut f32;
                        let hess_ptr = hessians.as_ptr() as *mut f32;
                        *grad_ptr.add(row_idx) = g;
                        *hess_ptr.add(row_idx) = h;
                    }
                }
                // Chunk totals; reused by the sparse path for the default bin.
                let block_grad: f32 = grad_cache[..block_len].iter().sum();
                let block_hess: f32 = hess_cache[..block_len].iter().sum();
                for feature_idx in 0..num_features {
                    if let Some(sparse_col) = dataset.sparse_column(feature_idx) {
                        Self::build_sparse_histogram_indexed(
                            local_hists.get_mut(feature_idx),
                            sparse_col,
                            chunk,
                            &grad_cache,
                            &hess_cache,
                            block_len,
                            block_grad,
                            block_hess,
                        );
                    } else {
                        let feature_column = dataset.feature_column(feature_idx);
                        let hist = local_hists.get_mut(feature_idx);
                        let bins = hist.bins_mut();
                        // 8-way manual unroll, mirroring the contiguous path
                        // but with an extra indirection through `chunk`.
                        let chunks_count = block_len / 8;
                        let remainder = block_len % 8;
                        // SAFETY: chunk/cache offsets stay below
                        // `block_len <= BLOCK_SIZE`; the gathered row indices
                        // are assumed valid for `feature_column` (one entry
                        // per row — TODO confirm), and `u8` bin values assume
                        // histograms hold at least 256 bins.
                        unsafe {
                            for i in 0..chunks_count {
                                let base = i * 8;
                                let idx0 = *chunk.get_unchecked(base);
                                let idx1 = *chunk.get_unchecked(base + 1);
                                let idx2 = *chunk.get_unchecked(base + 2);
                                let idx3 = *chunk.get_unchecked(base + 3);
                                let idx4 = *chunk.get_unchecked(base + 4);
                                let idx5 = *chunk.get_unchecked(base + 5);
                                let idx6 = *chunk.get_unchecked(base + 6);
                                let idx7 = *chunk.get_unchecked(base + 7);
                                let bin0 = *feature_column.get_unchecked(idx0) as usize;
                                let bin1 = *feature_column.get_unchecked(idx1) as usize;
                                let bin2 = *feature_column.get_unchecked(idx2) as usize;
                                let bin3 = *feature_column.get_unchecked(idx3) as usize;
                                let bin4 = *feature_column.get_unchecked(idx4) as usize;
                                let bin5 = *feature_column.get_unchecked(idx5) as usize;
                                let bin6 = *feature_column.get_unchecked(idx6) as usize;
                                let bin7 = *feature_column.get_unchecked(idx7) as usize;
                                let grad0 = *grad_cache.get_unchecked(base);
                                let grad1 = *grad_cache.get_unchecked(base + 1);
                                let grad2 = *grad_cache.get_unchecked(base + 2);
                                let grad3 = *grad_cache.get_unchecked(base + 3);
                                let grad4 = *grad_cache.get_unchecked(base + 4);
                                let grad5 = *grad_cache.get_unchecked(base + 5);
                                let grad6 = *grad_cache.get_unchecked(base + 6);
                                let grad7 = *grad_cache.get_unchecked(base + 7);
                                let hess0 = *hess_cache.get_unchecked(base);
                                let hess1 = *hess_cache.get_unchecked(base + 1);
                                let hess2 = *hess_cache.get_unchecked(base + 2);
                                let hess3 = *hess_cache.get_unchecked(base + 3);
                                let hess4 = *hess_cache.get_unchecked(base + 4);
                                let hess5 = *hess_cache.get_unchecked(base + 5);
                                let hess6 = *hess_cache.get_unchecked(base + 6);
                                let hess7 = *hess_cache.get_unchecked(base + 7);
                                bins.get_unchecked_mut(bin0).accumulate(grad0, hess0);
                                bins.get_unchecked_mut(bin1).accumulate(grad1, hess1);
                                bins.get_unchecked_mut(bin2).accumulate(grad2, hess2);
                                bins.get_unchecked_mut(bin3).accumulate(grad3, hess3);
                                bins.get_unchecked_mut(bin4).accumulate(grad4, hess4);
                                bins.get_unchecked_mut(bin5).accumulate(grad5, hess5);
                                bins.get_unchecked_mut(bin6).accumulate(grad6, hess6);
                                bins.get_unchecked_mut(bin7).accumulate(grad7, hess7);
                            }
                            // Tail: up to 7 rows that didn't fill an unrolled group.
                            let rem_base = chunks_count * 8;
                            for i in 0..remainder {
                                let idx = *chunk.get_unchecked(rem_base + i);
                                let bin = *feature_column.get_unchecked(idx) as usize;
                                let grad = *grad_cache.get_unchecked(rem_base + i);
                                let hess = *hess_cache.get_unchecked(rem_base + i);
                                bins.get_unchecked_mut(bin).accumulate(grad, hess);
                            }
                        }
                    }
                }
                (local_hists, block_grad, block_hess)
            })
            .collect();
        self.reduce_results(partial_results, num_features)
    }
#[allow(clippy::too_many_arguments)]
fn build_single_block(
&self,
dataset: &BinnedDataset,
row_indices: &[usize],
targets: &[f32],
predictions: &[f32],
loss_fn: &dyn LossFunction,
gradients: &mut [f32],
hessians: &mut [f32],
) -> FusedResult {
let num_features = dataset.num_features();
let num_rows = row_indices.len();
let mut histograms = NodeHistograms::new(num_features);
let mut total_gradient = 0.0f32;
let mut total_hessian = 0.0f32;
let mut grad_cache = vec![0.0f32; num_rows];
let mut hess_cache = vec![0.0f32; num_rows];
for (i, &row_idx) in row_indices.iter().enumerate() {
let (g, h) = loss_fn.gradient_hessian(targets[row_idx], predictions[row_idx]);
grad_cache[i] = g;
hess_cache[i] = h;
gradients[row_idx] = g;
hessians[row_idx] = h;
total_gradient += g;
total_hessian += h;
}
for feature_idx in 0..num_features {
let feature_column = dataset.feature_column(feature_idx);
let hist = histograms.get_mut(feature_idx);
for (i, &row_idx) in row_indices.iter().enumerate() {
let bin = feature_column[row_idx];
hist.accumulate(bin, grad_cache[i], hess_cache[i]);
}
}
FusedResult {
histograms,
total_gradient,
total_hessian,
}
}
fn reduce_results(
&self,
partials: Vec<(NodeHistograms, f32, f32)>,
num_features: usize,
) -> FusedResult {
if partials.is_empty() {
return FusedResult {
histograms: NodeHistograms::new(num_features),
total_gradient: 0.0,
total_hessian: 0.0,
};
}
if partials.len() == 1 {
let (hists, grad, hess) = partials.into_iter().next().unwrap();
return FusedResult {
histograms: hists,
total_gradient: grad,
total_hessian: hess,
};
}
let mut result_hists = NodeHistograms::new(num_features);
let mut total_grad = 0.0f32;
let mut total_hess = 0.0f32;
for (partial_hists, grad, hess) in partials {
result_hists.merge(&partial_hists);
total_grad += grad;
total_hess += hess;
}
FusedResult {
histograms: result_hists,
total_gradient: total_grad,
total_hessian: total_hess,
}
}
#[inline]
fn is_contiguous(row_indices: &[usize]) -> bool {
if row_indices.is_empty() {
return true;
}
row_indices[0] == 0 && row_indices.last() == Some(&(row_indices.len() - 1))
}
    /// Accumulates a sparse feature's histogram for the contiguous row range
    /// `[block_start, block_start + block_len)`.
    ///
    /// Only explicitly stored entries touch non-default bins; everything else
    /// is folded into `DEFAULT_BIN` in O(1) by subtracting the non-default
    /// sums from the block totals. Requires `sparse_col.indices` to be sorted
    /// ascending (needed for `partition_point` and the early `break`).
    #[allow(clippy::too_many_arguments)]
    fn build_sparse_histogram_block(
        hist: &mut crate::histogram::Histogram,
        sparse_col: &SparseColumn,
        block_start: usize,
        block_len: usize,
        grad_cache: &[f32; BLOCK_SIZE],
        hess_cache: &[f32; BLOCK_SIZE],
        block_total_grad: f32,
        block_total_hess: f32,
    ) {
        let block_end = block_start + block_len;
        let bins = hist.bins_mut();
        // Running sums over stored entries; used to derive the default bin.
        let mut non_default_grad = 0.0f32;
        let mut non_default_hess = 0.0f32;
        let mut non_default_count = 0u32;
        // First stored entry at or after the block's start row.
        let start_pos = sparse_col
            .indices
            .partition_point(|&idx| (idx as usize) < block_start);
        for i in start_pos..sparse_col.indices.len() {
            let row_idx = sparse_col.indices[i] as usize;
            if row_idx >= block_end {
                break;
            }
            let bin = sparse_col.values[i];
            let cache_idx = row_idx - block_start;
            // SAFETY: `cache_idx < block_len <= BLOCK_SIZE` because
            // `block_start <= row_idx < block_end`. Indexing `bins` by a `u8`
            // value assumes the histogram holds at least 256 bins — TODO
            // confirm this invariant of `Histogram`.
            unsafe {
                let grad = *grad_cache.get_unchecked(cache_idx);
                let hess = *hess_cache.get_unchecked(cache_idx);
                bins.get_unchecked_mut(bin as usize).accumulate(grad, hess);
                non_default_grad += grad;
                non_default_hess += hess;
                non_default_count += 1;
            }
        }
        // Default-bin remainder = block totals minus stored-entry sums.
        // NOTE(review): this subtraction can differ from direct summation by
        // f32 rounding — presumably acceptable for split finding; confirm.
        let default_grad = block_total_grad - non_default_grad;
        let default_hess = block_total_hess - non_default_hess;
        let default_count = block_len as u32 - non_default_count;
        if default_count > 0 {
            // SAFETY: `DEFAULT_BIN as usize` is assumed in bounds for `bins`
            // (same >=256-bins invariant as above).
            unsafe {
                bins.get_unchecked_mut(DEFAULT_BIN as usize)
                    .accumulate_with_count(default_grad, default_hess, default_count);
            }
        }
    }
    /// Accumulates a sparse feature's histogram for an arbitrary chunk of row
    /// indices (non-contiguous case).
    ///
    /// Picks between two strategies based on how many stored entries fall in
    /// the chunk's index range: when the sparse side is smaller, each stored
    /// row is binary-searched in `chunk`; otherwise each chunk row is
    /// binary-searched in the sparse index list. Both require `chunk` and
    /// `sparse_col.indices` to be sorted ascending. Rows with no stored entry
    /// are folded into `DEFAULT_BIN` via the block-total subtraction trick.
    #[allow(clippy::too_many_arguments)]
    fn build_sparse_histogram_indexed(
        hist: &mut crate::histogram::Histogram,
        sparse_col: &SparseColumn,
        chunk: &[usize],
        grad_cache: &[f32; BLOCK_SIZE],
        hess_cache: &[f32; BLOCK_SIZE],
        block_len: usize,
        block_total_grad: f32,
        block_total_hess: f32,
    ) {
        let bins = hist.bins_mut();
        // Running sums over matched entries; used to derive the default bin.
        let mut non_default_grad = 0.0f32;
        let mut non_default_hess = 0.0f32;
        let mut non_default_count = 0u32;
        let chunk_min = chunk.first().copied().unwrap_or(0);
        let chunk_max = chunk.last().copied().unwrap_or(0);
        // Stored entries whose row index falls inside [chunk_min, chunk_max].
        let sparse_start = sparse_col
            .indices
            .partition_point(|&idx| (idx as usize) < chunk_min);
        let sparse_end = sparse_col
            .indices
            .partition_point(|&idx| (idx as usize) <= chunk_max);
        let sparse_range = sparse_end - sparse_start;
        // Heuristic: iterate whichever side is smaller, binary-search the other.
        if sparse_range <= chunk.len() / 2 {
            for i in sparse_start..sparse_end {
                let row_idx = sparse_col.indices[i] as usize;
                let bin = sparse_col.values[i];
                // A stored row may not be part of this node's chunk at all.
                if let Ok(cache_idx) = chunk.binary_search(&row_idx) {
                    // SAFETY: `cache_idx < chunk.len() == block_len <= BLOCK_SIZE`.
                    // `u8` bin indexing assumes >=256 bins — TODO confirm.
                    unsafe {
                        let grad = *grad_cache.get_unchecked(cache_idx);
                        let hess = *hess_cache.get_unchecked(cache_idx);
                        bins.get_unchecked_mut(bin as usize).accumulate(grad, hess);
                        non_default_grad += grad;
                        non_default_hess += hess;
                        non_default_count += 1;
                    }
                }
            }
        } else {
            for (cache_idx, &row_idx) in chunk.iter().enumerate() {
                // NOTE(review): `row_idx as u32` assumes row indices fit in
                // u32, matching SparseColumn's index storage — confirm
                // upstream enforces this.
                if let Ok(sparse_pos) = sparse_col.indices.binary_search(&(row_idx as u32)) {
                    let bin = sparse_col.values[sparse_pos];
                    // SAFETY: `cache_idx < block_len <= BLOCK_SIZE`; same
                    // >=256-bins assumption for `bins` as above.
                    unsafe {
                        let grad = *grad_cache.get_unchecked(cache_idx);
                        let hess = *hess_cache.get_unchecked(cache_idx);
                        bins.get_unchecked_mut(bin as usize).accumulate(grad, hess);
                        non_default_grad += grad;
                        non_default_hess += hess;
                        non_default_count += 1;
                    }
                }
            }
        }
        // Default-bin remainder = chunk totals minus matched-entry sums.
        let default_grad = block_total_grad - non_default_grad;
        let default_hess = block_total_hess - non_default_hess;
        let default_count = block_len as u32 - non_default_count;
        if default_count > 0 {
            // SAFETY: `DEFAULT_BIN as usize` is assumed in bounds for `bins`.
            unsafe {
                bins.get_unchecked_mut(DEFAULT_BIN as usize)
                    .accumulate_with_count(default_grad, default_hess, default_count);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::{FeatureInfo, FeatureType};
    use crate::loss::MseLoss;

    /// Builds a deterministic dense dataset: features are stored
    /// feature-major (all rows of feature 0, then feature 1, …) with bin
    /// values `(r * (f + 1) * 17) % 256`, and targets follow a sine curve.
    ///
    /// NOTE(review): bin values span 0..=255 while `num_bins` is declared as
    /// 255 — confirm whether `num_bins` means "bin count" or "max bin index".
    fn create_test_dataset(num_rows: usize, num_features: usize) -> BinnedDataset {
        let mut features = Vec::with_capacity(num_rows * num_features);
        for f in 0..num_features {
            for r in 0..num_rows {
                features.push(((r * (f + 1) * 17) % 256) as u8);
            }
        }
        let targets: Vec<f32> = (0..num_rows).map(|i| (i as f32 * 0.01).sin()).collect();
        let feature_info = (0..num_features)
            .map(|i| FeatureInfo {
                name: format!("f{}", i),
                feature_type: FeatureType::Numeric,
                num_bins: 255,
                bin_boundaries: vec![],
            })
            .collect();
        BinnedDataset::new(num_rows, features, targets, feature_info)
    }

    /// Smoke test of the small-input (serial) path: 1000 rows < BLOCK_SIZE,
    /// so `build_single_block` is exercised.
    #[test]
    fn test_fused_basic() {
        let num_rows = 1000;
        let num_features = 10;
        let dataset = create_test_dataset(num_rows, num_features);
        let targets = dataset.targets().to_vec();
        let predictions = vec![0.0f32; num_rows];
        let mut gradients = vec![0.0f32; num_rows];
        let mut hessians = vec![0.0f32; num_rows];
        let row_indices: Vec<usize> = (0..num_rows).collect();
        let loss_fn = MseLoss::new();
        let builder = FusedHistogramBuilder::new();
        let result = builder.build_root(
            &dataset,
            &row_indices,
            &targets,
            &predictions,
            &loss_fn,
            &mut gradients,
            &mut hessians,
        );
        assert_eq!(result.histograms.num_features(), num_features);
        // Every row must land in exactly one bin of feature 0.
        let total_count: u32 = result
            .histograms
            .get(0)
            .bins()
            .iter()
            .map(|b| b.count)
            .sum();
        assert_eq!(total_count, num_rows as u32);
        assert!(gradients.iter().any(|&g| g != 0.0));
        // MSE: hessian is the constant 1.0 for every row.
        assert!(hessians.iter().all(|&h| h == 1.0));
    }

    /// The fused (parallel, 5000 rows > BLOCK_SIZE) pass must reproduce the
    /// gradients, hessians, and histograms of the separate two-pass pipeline.
    #[test]
    fn test_fused_matches_separate() {
        use crate::histogram::HistogramBuilder;
        let num_rows = 5000;
        let num_features = 20;
        let dataset = create_test_dataset(num_rows, num_features);
        let targets = dataset.targets().to_vec();
        let predictions = vec![0.5f32; num_rows];
        let row_indices: Vec<usize> = (0..num_rows).collect();
        let loss_fn = MseLoss::new();
        let mut fused_grads = vec![0.0f32; num_rows];
        let mut fused_hess = vec![0.0f32; num_rows];
        let fused_builder = FusedHistogramBuilder::new();
        let fused_result = fused_builder.build_root(
            &dataset,
            &row_indices,
            &targets,
            &predictions,
            &loss_fn,
            &mut fused_grads,
            &mut fused_hess,
        );
        // Reference: compute gradients/hessians separately, then histograms.
        let mut sep_grads = vec![0.0f32; num_rows];
        let mut sep_hess = vec![0.0f32; num_rows];
        for &idx in &row_indices {
            let (g, h) = loss_fn.gradient_hessian(targets[idx], predictions[idx]);
            sep_grads[idx] = g;
            sep_hess[idx] = h;
        }
        let hist_builder = HistogramBuilder::new();
        let sep_hists = hist_builder.build(&dataset, &row_indices, &sep_grads, &sep_hess);
        for i in 0..num_rows {
            assert!(
                (fused_grads[i] - sep_grads[i]).abs() < 1e-6,
                "Gradient mismatch at row {}",
                i
            );
            assert!(
                (fused_hess[i] - sep_hess[i]).abs() < 1e-6,
                "Hessian mismatch at row {}",
                i
            );
        }
        // Looser tolerance for bin sums: parallel blocking changes the f32
        // summation order.
        for f in 0..num_features {
            for bin in 0..=255u8 {
                let fused_entry = fused_result.histograms.get(f).get(bin);
                let sep_entry = sep_hists.get(f).get(bin);
                assert!(
                    (fused_entry.sum_gradients - sep_entry.sum_gradients).abs() < 1e-4,
                    "Gradient sum mismatch at feature {} bin {}",
                    f,
                    bin
                );
                assert!(
                    (fused_entry.sum_hessians - sep_entry.sum_hessians).abs() < 1e-4,
                    "Hessian sum mismatch at feature {} bin {}",
                    f,
                    bin
                );
                assert_eq!(
                    fused_entry.count, sep_entry.count,
                    "Count mismatch at feature {} bin {}",
                    f, bin
                );
            }
        }
    }

    /// Non-contiguous row set (even rows only): exercises the indexed path's
    /// dispatch and checks every selected row is counted exactly once.
    #[test]
    fn test_fused_indexed_rows() {
        let num_rows = 1000;
        let num_features = 5;
        let dataset = create_test_dataset(num_rows, num_features);
        let targets = dataset.targets().to_vec();
        let predictions = vec![0.0f32; num_rows];
        let mut gradients = vec![0.0f32; num_rows];
        let mut hessians = vec![0.0f32; num_rows];
        let row_indices: Vec<usize> = (0..num_rows).filter(|i| i % 2 == 0).collect();
        let loss_fn = MseLoss::new();
        let builder = FusedHistogramBuilder::new();
        let result = builder.build_root(
            &dataset,
            &row_indices,
            &targets,
            &predictions,
            &loss_fn,
            &mut gradients,
            &mut hessians,
        );
        let total_count: u32 = result
            .histograms
            .get(0)
            .bins()
            .iter()
            .map(|b| b.count)
            .sum();
        assert_eq!(total_count, row_indices.len() as u32);
    }
}