use rayon::prelude::*;
use std::fmt;
use std::ops::Range;
use std::cmp::Ordering;
use crate::weak_learner::common::{
type_and_struct::*,
};
use crate::sample::{
Feature,
feature_struct::{
DenseFeature,
SparseFeature,
},
};
/// Padding added to both ends of a feature's value range when that range
/// has (effectively) no spread, so that bins get a non-zero width.
const EPS: f64 = 1e-3;
/// Absolute tolerance for floating-point comparisons made while binning.
const NUM_TOLERANCE: f64 = 1e-9;
/// A half-open interval `[start, end)` of feature values.
#[derive(Debug)]
pub struct Bin(pub Range<f64>);
impl Bin {
    /// Wraps `range` as a bin.
    #[inline]
    pub fn new(range: Range<f64>) -> Self {
        Bin(range)
    }

    /// Returns `true` when `start <= *item < end`.
    /// The comparison is done by [`Range::contains`], so a NaN `item` is never contained.
    #[inline]
    pub fn contains(&self, item: &f64) -> bool {
        let Bin(interval) = self;
        interval.contains(item)
    }
}
/// An ordered, contiguous sequence of [`Bin`]s.
///
/// Values are only built through [`Bins::cut`], which makes the sequence
/// cover `[f64::MIN, f64::MAX)`.
pub struct Bins(Vec<Bin>);
impl Bins {
    /// Number of bins.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` when there are no bins.
    /// A value produced by [`Bins::cut`] is never empty.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Splits the observed value range of `feature` into `n_bin` equal-width
    /// bins, then stretches the outermost bins so the sequence covers
    /// `[f64::MIN, f64::MAX)`.
    ///
    /// # Panics
    /// Panics if `n_bin == 0`.
    pub fn cut(feature: &Feature, n_bin: usize) -> Self
    {
        assert!(n_bin > 0, "`Bins::cut` needs at least one bin");
        let mut bins = match feature {
            Feature::Dense(feat) => Self::cut_dense(feat, n_bin),
            Feature::Sparse(feat) => Self::cut_sparse(feat, n_bin),
        };
        // `equal_width` always yields exactly `n_bin >= 1` bins.
        bins.0.first_mut().expect("at least one bin").0.start = f64::MIN;
        bins.0.last_mut().expect("at least one bin").0.end = f64::MAX;
        bins
    }

    /// Equal-width binning over the stored values of a dense feature.
    fn cut_dense(feature: &DenseFeature, n_bin: usize) -> Self
    {
        let (min, max) = feature.sample[..]
            .iter()
            .copied()
            .fold((f64::MAX, f64::MIN), |(lo, hi), val| (lo.min(val), hi.max(val)));
        Self::equal_width(min, max, n_bin)
    }

    /// Equal-width binning over a sparse feature.
    /// Unstored zeros are included in the range when `has_zero()` reports them.
    fn cut_sparse(feature: &SparseFeature, n_bin: usize) -> Self
    {
        let (mut min, mut max) = feature.sample[..]
            .iter()
            .fold((f64::MAX, f64::MIN), |(lo, hi), &(_, val)| (lo.min(val), hi.max(val)));
        if feature.has_zero() {
            min = min.min(0.0);
            max = max.max(0.0);
        }
        Self::equal_width(min, max, n_bin)
    }

    /// Splits `[min, max)` into exactly `n_bin` contiguous bins of equal width.
    ///
    /// `min > max` means no value was observed, and the range falls back to
    /// zero. A spread below `NUM_TOLERANCE` is padded by `EPS` on each side
    /// so the bins get a usable width.
    fn equal_width(min: f64, max: f64, n_bin: usize) -> Self
    {
        let (mut min, mut max) = if min <= max { (min, max) } else { (0.0, 0.0) };
        if max - min < NUM_TOLERANCE {
            min -= EPS;
            max += EPS;
        }
        let width = (max - min) / n_bin as f64;
        // Edges are derived from the index instead of by repeated addition.
        // Rounding error therefore cannot accumulate into an extra or a
        // missing bin, and the outer edges are exactly `min` and `max`.
        let edge = |i: usize| -> f64 {
            if i == 0 {
                min
            } else if i == n_bin {
                max
            } else {
                min + width * i as f64
            }
        };
        let bins = (0..n_bin)
            .map(|i| Bin::new(edge(i)..edge(i + 1)))
            .collect();
        Self(bins)
    }

    /// For each bin, sums the weight `dist[i]` of every label `y[i]` over the
    /// samples in `indices`. It then drops empty bins and normalizes the
    /// weights to sum to one, using `remove_zero_weight_pack_and_normalize`.
    ///
    /// # Panics
    /// Panics if a selected feature value is NaN, or if the selected samples
    /// have no positive total weight.
    pub(crate) fn pack(
        &self,
        indices: &[usize],
        feat: &Feature,
        y: &[f64],
        dist: &[f64]
    ) -> Vec<(Bin, LabelToWeight)>
    {
        let n_bins = self.0.len();
        let mut packed = vec![LabelToWeight::new(); n_bins];
        for &i in indices {
            let xi = feat[i];
            let yi = y[i] as i32;
            let di = dist[i];
            let search = self.0.binary_search_by(|bin| {
                if bin.contains(&xi) {
                    return Ordering::Equal;
                }
                bin.0.start
                    .partial_cmp(&xi)
                    .expect("feature values must not be NaN")
            });
            // The bins cover `[f64::MIN, f64::MAX)`, so only `±inf` and
            // `f64::MAX` itself fall outside. Assign them to the nearest
            // outermost bin instead of panicking.
            let pos = search.unwrap_or_else(|pos| pos.min(n_bins - 1));
            *packed[pos].entry(yi).or_insert(0.0) += di;
        }
        self.remove_zero_weight_pack_and_normalize(packed)
    }

    /// Merges away bins with no samples and divides every weight by the
    /// total weight.
    ///
    /// Each gap left by removed bins is split at its midpoint between the
    /// neighbouring surviving bins. The first surviving bin is stretched back
    /// to the first bin's start, and the last surviving bin forward to the
    /// last bin's end. The result therefore still covers the same range as
    /// `self`.
    fn remove_zero_weight_pack_and_normalize(
        &self,
        pack: Vec<LabelToWeight>
    ) -> Vec<(Bin, LabelToWeight)>
    {
        let total_weight = pack.par_iter()
            .map(|mp| mp.values().sum::<f64>())
            .sum::<f64>();
        assert!(total_weight > 0.0);

        // A positive total weight implies `pack`, and hence `self.0`, is non-empty.
        let first_start = self.0.first().expect("non-empty bins").0.start;
        let last_end = self.0.last().expect("non-empty bins").0.end;

        // Each map has only a handful of labels, so a sequential loop is
        // cheaper than spawning parallel work per bin.
        let normalize = |weight: &mut LabelToWeight| {
            for v in weight.values_mut() {
                *v /= total_weight;
            }
        };

        let mut survivors = self.0.iter()
            .zip(pack)
            .filter(|(_, mp)| !mp.is_empty());
        let (first_bin, mut prev_weight) = survivors
            .next()
            .expect("a positive total weight implies a non-empty bin");
        normalize(&mut prev_weight);
        let mut prev_bin = Bin::new(first_start..first_bin.0.end);

        let mut bin_and_weight = Vec::new();
        for (next_bin, mut next_weight) in survivors {
            let mid = (prev_bin.0.end + next_bin.0.start) / 2.0;
            prev_bin.0.end = mid;
            bin_and_weight.push((prev_bin, prev_weight));
            prev_bin = Bin::new(mid..next_bin.0.end);
            normalize(&mut next_weight);
            prev_weight = next_weight;
        }
        // Symmetric to the leading case: absorb any trailing empty bins.
        prev_bin.0.end = last_end;
        bin_and_weight.push((prev_bin, prev_weight));
        bin_and_weight
    }
}
/// Largest number of bins printed in full.
///
/// Longer sequences are shown as their first `PRINT_BIN_SIZE - 1` bins,
/// an ellipsis, and the last bin.
const PRINT_BIN_SIZE: usize = 3;
impl fmt::Display for Bins {
    /// Writes the bins as a comma-separated list, abbreviated in the middle
    /// when there are more than `PRINT_BIN_SIZE` of them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let bins = &self.0;
        if bins.len() > PRINT_BIN_SIZE {
            // The head length follows the constant instead of a hard-coded 2.
            for bin in &bins[..PRINT_BIN_SIZE - 1] {
                write!(f, "{bin}, ")?;
            }
            // `len() > PRINT_BIN_SIZE`, so a last bin exists.
            let tail = &bins[bins.len() - 1];
            write!(f, "... , {tail}")
        } else {
            // Write directly to the formatter; no intermediate strings.
            for (k, bin) in bins.iter().enumerate() {
                if k > 0 {
                    f.write_str(", ")?;
                }
                write!(f, "{bin}")?;
            }
            Ok(())
        }
    }
}
impl fmt::Display for Bin {
    /// Writes the bin as `[start, end)`.
    ///
    /// A start of `f64::MIN` prints as `-Inf` and an end of `f64::MAX` prints
    /// as `+Inf`. Any other bound prints as a sign character followed by its
    /// magnitude with two decimals. The sign is `+`, `-`, or a space when the
    /// value is neither positive nor negative.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fn render(value: f64, sentinel: f64, label: &str) -> String {
            if value == sentinel {
                return String::from(label);
            }
            let sgn = match value {
                v if v > 0.0 => '+',
                v if v < 0.0 => '-',
                _ => ' ',
            };
            let magnitude = value.abs();
            format!("{sgn}{magnitude: >.2}")
        }

        let lower = render(self.0.start, f64::MIN, "-Inf");
        let upper = render(self.0.end, f64::MAX, "+Inf");
        write!(f, "[{lower}, {upper})")
    }
}