use crate::dataset::Bin;
/// Accumulated statistics for a single histogram bin.
///
/// `#[repr(C)]` fixes the field order. With the explicit `_pad` word the struct is
/// 24 bytes (two `f64` followed by two `u32`) and contains no implicit padding.
/// NOTE(review): why the padding is made explicit (e.g. for byte-level copies or
/// hashing of bins) is not visible from this file — confirm before changing layout.
///
/// Because `_pad` is private, code outside this module cannot use struct-literal
/// syntax and has to start from `HistBin::default()` (all zeros).
#[derive(Debug, Clone, Copy, Default)]
#[repr(C)]
pub struct HistBin {
/// Running sum of the gradient values added to this bin (accumulated as `f64`).
pub grad: f64,
/// Running sum of the hessian values added to this bin (accumulated as `f64`).
pub hess: f64,
/// Number of samples added to this bin.
pub count: u32,
/// Explicit padding; every constructor in this module sets it to zero.
_pad: u32,
}
/// Histogram of gradient/hessian statistics for one feature column.
///
/// `bins[b]` accumulates every sample whose column value maps, via `Bin::as_usize`,
/// to bin index `b`.
#[derive(Debug, Clone)]
pub struct FeatureHistogram {
/// One accumulator per bin; `bins.len()` is the number of bins.
pub bins: Vec<HistBin>,
}
impl FeatureHistogram {
    /// Creates a histogram with `num_bins` bins, all zeroed.
    pub fn zeros(num_bins: usize) -> Self {
        Self {
            bins: vec![HistBin::default(); num_bins],
        }
    }

    /// Returns the number of bins.
    pub fn num_bins(&self) -> usize {
        self.bins.len()
    }

    /// Resets every bin to zero; the number of bins is unchanged.
    pub fn clear(&mut self) {
        self.bins.fill(HistBin::default());
    }

    /// Adds one sample's gradient/hessian pair (`gh[0]`, `gh[1]`) to `bin`.
    ///
    /// # Panics
    ///
    /// If `bin >= self.num_bins()`. The bin index comes from column data, so it
    /// cannot be proven in range here; an unchecked write would be UB.
    #[inline]
    fn add_sample(&mut self, bin: usize, gh: [f32; 2]) {
        let b = &mut self.bins[bin];
        b.grad += f64::from(gh[0]);
        b.hess += f64::from(gh[1]);
        b.count += 1;
    }

    /// Clears the histogram, then accumulates the rows listed in `indices`.
    ///
    /// For each `row` in `indices`, the sample `gradhess[row]` is added to bin
    /// `column[row].as_usize()`. A row that appears several times in `indices`
    /// is counted once per occurrence.
    ///
    /// # Panics
    ///
    /// If an index is out of range for `column` or `gradhess`, or if a column
    /// value maps to a bin `>= self.num_bins()`.
    pub fn build<B: Bin>(&mut self, column: &[B], indices: &[u32], gradhess: &[[f32; 2]]) {
        self.clear();
        for &i in indices {
            let row = i as usize;
            self.add_sample(column[row].as_usize(), gradhess[row]);
        }
    }

    /// Clears the histogram, then accumulates every row of `column`.
    ///
    /// `column` and `gradhess` are expected to have the same length (checked in
    /// debug builds).
    ///
    /// # Panics
    ///
    /// If `gradhess` has fewer rows than `column`, or if a column value maps to
    /// a bin `>= self.num_bins()`.
    pub fn build_full<B: Bin>(&mut self, column: &[B], gradhess: &[[f32; 2]]) {
        self.clear();
        debug_assert_eq!(column.len(), gradhess.len());
        // One up-front slice check replaces per-row unchecked reads: it panics if
        // `gradhess` is too short and lets the zip below run without row checks.
        let gradhess = &gradhess[..column.len()];
        for (value, &gh) in column.iter().zip(gradhess) {
            self.add_sample(value.as_usize(), gh);
        }
    }

    /// Writes `parent - sibling` bin by bin into `out`.
    ///
    /// Intended for the case where `sibling` was built from a subset of the rows
    /// of `parent`, so the difference is the histogram of the remaining rows.
    /// All three histograms are expected to have the same number of bins
    /// (checked in debug builds).
    ///
    /// # Panics
    ///
    /// If `sibling` or `out` has fewer bins than `parent`. Also panics in debug
    /// builds if a sibling bin count exceeds the parent's (u32 underflow).
    pub fn subtract_into(
        parent: &FeatureHistogram,
        sibling: &FeatureHistogram,
        out: &mut FeatureHistogram,
    ) {
        let n = parent.num_bins();
        debug_assert_eq!(n, sibling.num_bins());
        debug_assert_eq!(n, out.num_bins());
        // Reslicing to `n` turns a length mismatch into a panic instead of an
        // out-of-bounds read/write.
        let sibling_bins = &sibling.bins[..n];
        let out_bins = &mut out.bins[..n];
        for ((o, p), s) in out_bins.iter_mut().zip(&parent.bins).zip(sibling_bins) {
            *o = HistBin {
                grad: p.grad - s.grad,
                hess: p.hess - s.hess,
                count: p.count - s.count,
                _pad: 0,
            };
        }
    }
}
/// Builds one histogram per feature over the rows in `indices`, in a single pass.
///
/// `histograms[f]` is filled from `columns[f]`. The loop is row-major: each row's
/// gradient/hessian pair is read once and scattered into every feature's
/// histogram. Every histogram in `histograms` is cleared first, including any
/// beyond `columns.len()`. `columns` and `histograms` are expected to have the
/// same length (checked in debug builds).
///
/// # Panics
///
/// If `histograms` is shorter than `columns`, if an index in `indices` is out of
/// range for `gradhess` or for any column, or if a column value maps to a bin
/// outside its histogram.
pub fn build_histograms_batched<B: Bin>(
    columns: &[&[B]],
    indices: &[u32],
    gradhess: &[[f32; 2]],
    histograms: &mut [FeatureHistogram],
) {
    debug_assert_eq!(columns.len(), histograms.len());
    let histograms = batched_prepare(columns.len(), histograms);
    for &i in indices {
        let row = i as usize;
        batched_scatter_row(columns, histograms, row, gradhess[row]);
    }
}

/// Builds one histogram per feature over every row of `gradhess`, in a single pass.
///
/// Same as [`build_histograms_batched`] with `indices = 0..gradhess.len()`. Each
/// column must have at least `gradhess.len()` rows.
///
/// # Panics
///
/// If `histograms` is shorter than `columns`, if any column has fewer rows than
/// `gradhess`, or if a column value maps to a bin outside its histogram.
pub fn build_histograms_batched_full<B: Bin>(
    columns: &[&[B]],
    gradhess: &[[f32; 2]],
    histograms: &mut [FeatureHistogram],
) {
    debug_assert_eq!(columns.len(), histograms.len());
    let histograms = batched_prepare(columns.len(), histograms);
    for (row, &gh) in gradhess.iter().enumerate() {
        batched_scatter_row(columns, histograms, row, gh);
    }
}

/// Clears every histogram and returns the first `n_feat`, which pair up with the columns.
///
/// # Panics
///
/// If there are fewer than `n_feat` histograms. The previous unchecked code
/// indexed past the end in that case.
fn batched_prepare(
    n_feat: usize,
    histograms: &mut [FeatureHistogram],
) -> &mut [FeatureHistogram] {
    for h in histograms.iter_mut() {
        h.clear();
    }
    &mut histograms[..n_feat]
}

/// Adds row `row` (with gradient/hessian pair `gh`) to the matching bin of every
/// feature histogram.
///
/// `histograms` must be exactly as long as `columns` (callers reslice it), so the
/// zip visits every feature.
///
/// # Panics
///
/// If `row` is out of range for a column, or if a column value maps to a bin
/// outside its histogram. These are bounds-checked because the bin index is
/// data-dependent and cannot be validated up front.
#[inline]
fn batched_scatter_row<B: Bin>(
    columns: &[&[B]],
    histograms: &mut [FeatureHistogram],
    row: usize,
    gh: [f32; 2],
) {
    // Convert once per row, not once per feature.
    let g = f64::from(gh[0]);
    let h = f64::from(gh[1]);
    for (column, hist) in columns.iter().zip(histograms.iter_mut()) {
        let b = &mut hist.bins[column[row].as_usize()];
        b.grad += g;
        b.hess += h;
        b.count += 1;
    }
}