use std::fmt;
use std::ops::Range;
use std::cmp::Ordering;
use crate::sample::{
Feature,
feature_struct::{
DenseFeature,
SparseFeature,
},
};
/// Half-width added on each side of a degenerate (single-valued) feature
/// range so that binning still has a non-zero span to split.
const EPS: f64 = 1e-3;
/// Absolute tolerance for comparing floating-point bin edges.
const NUM_TOLERANCE: f64 = 1e-9;
/// Accumulated first-order (gradient) and second-order (hessian) statistics.
#[derive(Clone, Debug, Default)]
pub(crate) struct GradientHessian {
    pub(crate) grad: f64,
    pub(crate) hess: f64,
}
impl GradientHessian {
pub(super) fn new(grad: f64, hess: f64) -> Self {
Self { grad, hess }
}
}
/// A single histogram bin covering the half-open interval `[start, end)`.
#[derive(Clone, Debug, PartialEq)]
pub struct Bin(pub Range<f64>);
impl Bin {
#[inline(always)]
pub fn new(range: Range<f64>) -> Self {
Self(range)
}
#[inline(always)]
pub fn contains(&self, item: &f64) -> bool {
self.0.contains(item)
}
}
/// An ordered, contiguous sequence of [`Bin`]s partitioning a feature's value range.
#[derive(Debug)]
pub struct Bins(Vec<Bin>);
impl Bins {
    /// Returns the number of bins.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` when there are no bins.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Cuts the value range of `feature` into `n_bin` equal-width bins.
    ///
    /// After cutting, the first bin's lower bound is set to `f64::MIN` and the
    /// last bin's upper bound to `f64::MAX`, so values outside the observed
    /// range still map to an end bin.
    ///
    /// # Panics
    ///
    /// Panics if `n_bin == 0`.
    pub fn cut(feature: &Feature, n_bin: usize) -> Self {
        let mut bins = match feature {
            Feature::Dense(feat) => Self::cut_dense(feat, n_bin),
            Feature::Sparse(feat) => Self::cut_sparse(feat, n_bin),
        };
        if let Some(first) = bins.0.first_mut() {
            first.0.start = f64::MIN;
        }
        if let Some(last) = bins.0.last_mut() {
            last.0.end = f64::MAX;
        }
        bins
    }

    /// Equal-width binning of a dense feature over its observed `[min, max]`.
    fn cut_dense(feature: &DenseFeature, n_bin: usize) -> Self {
        let (min, max) = Self::value_range(feature.sample[..].iter().copied());
        Self::equal_width(min, max, n_bin)
    }

    /// Equal-width binning of a sparse feature over its stored values.
    ///
    /// When `has_zero()` is true the range is stretched to include 0.0
    /// (presumably for the implicit zero entries — TODO confirm).
    fn cut_sparse(feature: &SparseFeature, n_bin: usize) -> Self {
        let (mut min, mut max) =
            Self::value_range(feature.sample[..].iter().map(|&(_, val)| val));
        if min > 0.0 && feature.has_zero() {
            min = 0.0;
        }
        if max < 0.0 && feature.has_zero() {
            max = 0.0;
        }
        Self::equal_width(min, max, n_bin)
    }

    /// Returns `(min, max)` over `values`, or `(f64::MAX, f64::MIN)` if empty.
    ///
    /// NaNs are skipped because `f64::min`/`f64::max` return the non-NaN operand.
    fn value_range(values: impl Iterator<Item = f64>) -> (f64, f64) {
        values.fold((f64::MAX, f64::MIN), |(lo, hi), v| (lo.min(v), hi.max(v)))
    }

    /// Splits `[min, max)` into exactly `n_bin` contiguous bins of equal width.
    ///
    /// A degenerate range (`min == max`) is widened by `EPS` on both sides; an
    /// inverted range (`min > max`, i.e. no values were observed) is treated
    /// as the single point 0.0.
    ///
    /// # Panics
    ///
    /// Panics if `n_bin == 0`.
    fn equal_width(mut min: f64, mut max: f64, n_bin: usize) -> Self {
        assert!(n_bin > 0, "`n_bin` must be positive");
        if min > max {
            min = 0.0;
            max = 0.0;
        }
        if min == max {
            min -= EPS;
            max += EPS;
        }
        let width = (max - min) / n_bin as f64;
        // Each edge is computed from its index rather than by accumulating
        // `left += width`: accumulated rounding drift could otherwise produce
        // one bin too many or too few. Neighbouring bins share the exact same
        // edge value, so the bins stay contiguous.
        let edge = |k: usize| if k == n_bin { max } else { min + width * k as f64 };
        let bins = (0..n_bin).map(|k| Bin::new(edge(k)..edge(k + 1))).collect();
        Self(bins)
    }

    /// Sums the gradient/hessian of each sample in `indices` into the bin
    /// holding its feature value, then drops empty bins (see
    /// `remove_zero_weight_pack_and_normalize`).
    ///
    /// Values at or beyond the outermost bin bounds (e.g. `f64::MAX`, ±inf)
    /// are clamped into the first/last bin.
    ///
    /// # Panics
    ///
    /// Panics if a feature value is NaN, or if an index is out of range for
    /// `feat` or `gh`.
    pub(crate) fn pack(
        &self,
        indices: &[usize],
        feat: &Feature,
        gh: &[GradientHessian],
    ) -> Vec<(Bin, GradientHessian)> {
        let n_bins = self.0.len();
        if n_bins == 0 {
            return Vec::new();
        }
        let mut packed = vec![GradientHessian::default(); n_bins];
        for &i in indices {
            let xi = feat[i];
            assert!(!xi.is_nan(), "cannot bin NaN feature value of sample {i}");
            let search = self.0.binary_search_by(|bin| {
                if bin.contains(&xi) {
                    Ordering::Equal
                } else if bin.0.end <= xi {
                    Ordering::Less
                } else {
                    Ordering::Greater
                }
            });
            let pos = match search {
                Ok(pos) => pos,
                // Outside every half-open bin (below the first start or at/above
                // the last end): clamp to the nearest end bin.
                Err(pos) => pos.min(n_bins - 1),
            };
            packed[pos].grad += gh[i].grad;
            packed[pos].hess += gh[i].hess;
        }
        self.remove_zero_weight_pack_and_normalize(packed)
    }

    /// Drops bins whose gradient and hessian are both zero and re-partitions
    /// the survivors so that together they still span `self`'s full range.
    ///
    /// - Leading empty bins are absorbed into the first non-empty bin.
    /// - Trailing empty bins are absorbed into the last non-empty bin.
    /// - A run of empty bins between two non-empty bins is split at the
    ///   midpoint of the gap.
    ///
    /// If every bin is empty, a single zero-weight bin covering the whole
    /// range is returned. `pack` is expected to align index-for-index with
    /// `self`'s bins.
    fn remove_zero_weight_pack_and_normalize(
        &self,
        pack: Vec<GradientHessian>,
    ) -> Vec<(Bin, GradientHessian)> {
        let (lower, upper) = match (self.0.first(), self.0.last()) {
            (Some(first), Some(last)) => (first.0.start, last.0.end),
            _ => return Vec::new(),
        };
        let mut nonzero = self
            .0
            .iter()
            .zip(pack)
            .filter(|(_, gh)| gh.grad != 0.0 || gh.hess != 0.0);
        let (first_bin, first_gh) = match nonzero.next() {
            Some(pair) => pair,
            None => return vec![(Bin::new(lower..upper), GradientHessian::default())],
        };

        let mut cur_range = lower..first_bin.0.end;
        let mut cur_gh = first_gh;
        let mut merged = Vec::new();
        for (bin, gh) in nonzero {
            let mid = (cur_range.end + bin.0.start) / 2.0;
            merged.push((Bin::new(cur_range.start..mid), cur_gh));
            cur_range = mid..bin.0.end;
            cur_gh = gh;
        }
        cur_range.end = upper;
        merged.push((Bin::new(cur_range), cur_gh));
        merged
    }
}
/// Maximum number of bins that `Bins`' `Display` impl prints; longer lists are
/// abbreviated with `...`.
const PRINT_BIN_SIZE: usize = 3;
impl fmt::Display for Bins {
    /// Formats the bins as a comma-separated list.
    ///
    /// When there are more than `PRINT_BIN_SIZE` bins, only the first
    /// `PRINT_BIN_SIZE - 1` bins and the last bin are shown, separated by
    /// `...`, so exactly `PRINT_BIN_SIZE` bins are printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let bins = &self.0;
        let n_bins = bins.len();
        if n_bins > PRINT_BIN_SIZE {
            let head_len = PRINT_BIN_SIZE.saturating_sub(1);
            for bin in &bins[..head_len] {
                write!(f, "{bin}, ")?;
            }
            // n_bins > PRINT_BIN_SIZE >= 0, so the last bin exists.
            write!(f, "... , {}", bins[n_bins - 1])
        } else {
            for (k, bin) in bins.iter().enumerate() {
                if k > 0 {
                    f.write_str(", ")?;
                }
                write!(f, "{bin}")?;
            }
            Ok(())
        }
    }
}
impl fmt::Display for Bin {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let start = if self.0.start == f64::MIN {
String::from("-Inf")
} else {
let start = self.0.start;
let sgn = if start > 0.0 {
'+'
} else if start < 0.0 {
'-'
} else {
' '
};
let start = start.abs();
format!("{sgn}{start: >.2}")
};
let end = if self.0.end == f64::MAX {
String::from("+Inf")
} else {
let end = self.0.end;
let sgn = if end > 0.0 {
'+'
} else if end < 0.0 {
'-'
} else {
' '
};
let end = end.abs();
format!("{sgn}{end: >.2}")
};
write!(f, "[{start}, {end})")
}
}