use crate::dataset::BinEntry;
use rkyv::{Archive, Deserialize, Serialize};
pub const NUM_BINS: usize = 256;
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
pub struct Histogram {
bins: [BinEntry; NUM_BINS],
}
impl Default for Histogram {
fn default() -> Self {
Self::new()
}
}
impl Histogram {
pub fn new() -> Self {
Self {
bins: [BinEntry::default(); NUM_BINS],
}
}
pub fn from_raw_arrays(grads: &[f32; 256], hess: &[f32; 256], counts: &[u32; 256]) -> Self {
let mut bins = [BinEntry::default(); NUM_BINS];
for i in 0..NUM_BINS {
bins[i] = BinEntry {
sum_gradients: grads[i],
sum_hessians: hess[i],
count: counts[i],
};
}
Self { bins }
}
pub fn clear(&mut self) {
for bin in &mut self.bins {
*bin = BinEntry::default();
}
}
#[inline]
pub fn get(&self, bin: u8) -> &BinEntry {
&self.bins[bin as usize]
}
#[inline]
pub fn get_mut(&mut self, bin: u8) -> &mut BinEntry {
&mut self.bins[bin as usize]
}
#[inline]
pub fn accumulate(&mut self, bin: u8, gradient: f32, hessian: f32) {
unsafe {
self.bins
.get_unchecked_mut(bin as usize)
.accumulate(gradient, hessian);
}
}
#[inline]
pub fn accumulate_batch(&mut self, bins: &[u8], gradients: &[f32], hessians: &[f32]) {
debug_assert_eq!(bins.len(), gradients.len());
debug_assert_eq!(bins.len(), hessians.len());
let len = bins.len();
let chunks = len / 8;
let remainder = len % 8;
unsafe {
for i in 0..chunks {
let base = i * 8;
let bin0 = *bins.get_unchecked(base) as usize;
let bin1 = *bins.get_unchecked(base + 1) as usize;
let bin2 = *bins.get_unchecked(base + 2) as usize;
let bin3 = *bins.get_unchecked(base + 3) as usize;
let bin4 = *bins.get_unchecked(base + 4) as usize;
let bin5 = *bins.get_unchecked(base + 5) as usize;
let bin6 = *bins.get_unchecked(base + 6) as usize;
let bin7 = *bins.get_unchecked(base + 7) as usize;
let grad0 = *gradients.get_unchecked(base);
let grad1 = *gradients.get_unchecked(base + 1);
let grad2 = *gradients.get_unchecked(base + 2);
let grad3 = *gradients.get_unchecked(base + 3);
let grad4 = *gradients.get_unchecked(base + 4);
let grad5 = *gradients.get_unchecked(base + 5);
let grad6 = *gradients.get_unchecked(base + 6);
let grad7 = *gradients.get_unchecked(base + 7);
let hess0 = *hessians.get_unchecked(base);
let hess1 = *hessians.get_unchecked(base + 1);
let hess2 = *hessians.get_unchecked(base + 2);
let hess3 = *hessians.get_unchecked(base + 3);
let hess4 = *hessians.get_unchecked(base + 4);
let hess5 = *hessians.get_unchecked(base + 5);
let hess6 = *hessians.get_unchecked(base + 6);
let hess7 = *hessians.get_unchecked(base + 7);
self.bins.get_unchecked_mut(bin0).accumulate(grad0, hess0);
self.bins.get_unchecked_mut(bin1).accumulate(grad1, hess1);
self.bins.get_unchecked_mut(bin2).accumulate(grad2, hess2);
self.bins.get_unchecked_mut(bin3).accumulate(grad3, hess3);
self.bins.get_unchecked_mut(bin4).accumulate(grad4, hess4);
self.bins.get_unchecked_mut(bin5).accumulate(grad5, hess5);
self.bins.get_unchecked_mut(bin6).accumulate(grad6, hess6);
self.bins.get_unchecked_mut(bin7).accumulate(grad7, hess7);
}
let base = chunks * 8;
for i in 0..remainder {
let bin = *bins.get_unchecked(base + i) as usize;
let grad = *gradients.get_unchecked(base + i);
let hess = *hessians.get_unchecked(base + i);
self.bins.get_unchecked_mut(bin).accumulate(grad, hess);
}
}
}
#[inline]
pub fn merge(&mut self, other: &Histogram) {
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
{
self.merge_simd(other);
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
{
self.merge_scalar(other);
}
}
#[inline]
fn merge_scalar(&mut self, other: &Histogram) {
for (self_bin, other_bin) in self.bins.iter_mut().zip(other.bins.iter()) {
self_bin.merge(other_bin);
}
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
fn merge_simd(&mut self, other: &Histogram) {
use std::arch::x86_64::*;
self.merge_scalar(other);
}
#[inline]
pub fn subtract(&mut self, other: &Histogram) {
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
{
self.subtract_simd(other);
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
{
self.subtract_scalar(other);
}
}
#[inline]
fn subtract_scalar(&mut self, other: &Histogram) {
for (self_bin, other_bin) in self.bins.iter_mut().zip(other.bins.iter()) {
self_bin.subtract(other_bin);
}
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
fn subtract_simd(&mut self, other: &Histogram) {
self.subtract_scalar(other);
}
pub fn from_subtraction(parent: &Histogram, child: &Histogram) -> Self {
let mut result = parent.clone();
result.subtract(child);
result
}
pub fn total_gradient(&self) -> f32 {
self.bins.iter().map(|b| b.sum_gradients).sum()
}
pub fn total_hessian(&self) -> f32 {
self.bins.iter().map(|b| b.sum_hessians).sum()
}
pub fn total_count(&self) -> u32 {
self.bins.iter().map(|b| b.count).sum()
}
pub fn totals(&self) -> (f32, f32, u32) {
let mut g = 0.0f32;
let mut h = 0.0f32;
let mut n = 0u32;
for bin in &self.bins {
g += bin.sum_gradients;
h += bin.sum_hessians;
n += bin.count;
}
(g, h, n)
}
pub fn sum_gradients(&self) -> [f32; NUM_BINS] {
let mut result = [0.0f32; NUM_BINS];
for (i, bin) in self.bins.iter().enumerate() {
result[i] = bin.sum_gradients;
}
result
}
pub fn sum_hessians(&self) -> [f32; NUM_BINS] {
let mut result = [0.0f32; NUM_BINS];
for (i, bin) in self.bins.iter().enumerate() {
result[i] = bin.sum_hessians;
}
result
}
pub fn counts(&self) -> [u32; NUM_BINS] {
let mut result = [0u32; NUM_BINS];
for (i, bin) in self.bins.iter().enumerate() {
result[i] = bin.count;
}
result
}
pub fn iter(&self) -> impl Iterator<Item = (u8, &BinEntry)> {
self.bins.iter().enumerate().map(|(i, b)| (i as u8, b))
}
pub fn bins(&self) -> &[BinEntry; NUM_BINS] {
&self.bins
}
#[inline]
pub fn bins_mut(&mut self) -> &mut [BinEntry; NUM_BINS] {
&mut self.bins
}
}
#[derive(Debug, Clone)]
pub struct NodeHistograms {
pub(crate) histograms: Vec<Histogram>,
}
impl NodeHistograms {
pub fn new(num_features: usize) -> Self {
Self {
histograms: vec![Histogram::new(); num_features],
}
}
#[inline]
pub fn get(&self, feature_idx: usize) -> &Histogram {
&self.histograms[feature_idx]
}
#[inline]
pub fn get_mut(&mut self, feature_idx: usize) -> &mut Histogram {
&mut self.histograms[feature_idx]
}
pub fn num_features(&self) -> usize {
self.histograms.len()
}
pub fn clear(&mut self) {
for hist in &mut self.histograms {
hist.clear();
}
}
pub fn merge(&mut self, other: &NodeHistograms) {
for (self_hist, other_hist) in self.histograms.iter_mut().zip(other.histograms.iter()) {
self_hist.merge(other_hist);
}
}
pub fn subtract(&mut self, other: &NodeHistograms) {
for (self_hist, other_hist) in self.histograms.iter_mut().zip(other.histograms.iter()) {
self_hist.subtract(other_hist);
}
}
pub fn from_subtraction(parent: &NodeHistograms, child: &NodeHistograms) -> Self {
Self {
histograms: parent
.histograms
.iter()
.zip(child.histograms.iter())
.map(|(p, c)| Histogram::from_subtraction(p, c))
.collect(),
}
}
pub fn iter(&self) -> impl Iterator<Item = (usize, &Histogram)> {
self.histograms.iter().enumerate()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_histogram_accumulate() {
let mut hist = Histogram::new();
hist.accumulate(0, 1.0, 2.0);
hist.accumulate(0, 0.5, 1.0);
hist.accumulate(255, 3.0, 4.0);
assert_eq!(hist.get(0).sum_gradients, 1.5);
assert_eq!(hist.get(0).sum_hessians, 3.0);
assert_eq!(hist.get(0).count, 2);
assert_eq!(hist.get(255).count, 1);
}
#[test]
fn test_histogram_subtraction_trick() {
let mut parent = Histogram::new();
let mut child = Histogram::new();
parent.accumulate(0, 10.0, 20.0);
parent.accumulate(1, 5.0, 10.0);
child.accumulate(0, 3.0, 6.0);
child.accumulate(1, 2.0, 4.0);
let sibling = Histogram::from_subtraction(&parent, &child);
assert_eq!(sibling.get(0).sum_gradients, 7.0);
assert_eq!(sibling.get(0).sum_hessians, 14.0);
assert_eq!(sibling.get(1).sum_gradients, 3.0);
}
#[test]
fn test_node_histograms() {
let mut hists = NodeHistograms::new(3);
hists.get_mut(0).accumulate(5, 1.0, 2.0);
hists.get_mut(1).accumulate(10, 3.0, 4.0);
hists.get_mut(2).accumulate(15, 5.0, 6.0);
assert_eq!(hists.num_features(), 3);
assert_eq!(hists.get(0).get(5).sum_gradients, 1.0);
assert_eq!(hists.get(1).get(10).sum_gradients, 3.0);
}
}