use std::{cell::UnsafeCell, cmp::Ordering};
use crate::data::FloatData;
use serde::{Deserialize, Serialize};
/// One histogram bin: an identifier, a cut value, and five-entry ("folded")
/// gradient / hessian / count accumulators.
///
/// `Bin::from_parent_child` and `Bin::from_parent_two_children` derive a bin by
/// element-wise subtraction of these accumulators, i.e. they are treated as
/// additive sums.
#[derive(Debug, Deserialize, Serialize, Clone, Copy)]
pub struct Bin {
/// Bin identifier. `sort_cat_bins_by_num` sorts on it, and
/// `sort_cat_bins_by_stat` always places the bin with `num == 0` first
/// (presumably a reserved / missing-value bin — TODO confirm with callers).
pub num: u16,
/// Cut value associated with this bin (presumably the split threshold for
/// the bin — TODO confirm against the histogram builder).
pub cut_value: f64,
/// Accumulated gradient values, one entry per fold (fold semantics not
/// visible here — NOTE(review): confirm what the 5 folds represent).
pub g_folded: [f32; 5],
/// Accumulated hessian values, one entry per fold. Their total is the
/// denominator of the split statistic when the hessian is not constant.
pub h_folded: [f32; 5],
/// Per-fold counts (presumably sample counts). Their total is the
/// denominator of the split statistic when the hessian is constant.
pub counts: [u32; 5],
}
impl Bin {
    /// Creates an empty bin for the constant-hessian case.
    ///
    /// Currently identical to [`Bin::empty`]: every gradient, hessian and count
    /// accumulator starts at zero.
    pub fn empty_const_hess(num: u16, cut_value: f64) -> Self {
        Self::empty(num, cut_value)
    }

    /// Creates a bin with the given `num` and `cut_value` whose gradient,
    /// hessian and count accumulators are all zero.
    pub fn empty(num: u16, cut_value: f64) -> Self {
        Bin {
            num,
            cut_value,
            g_folded: [f32::ZERO; 5],
            h_folded: [f32::ZERO; 5],
            counts: [0; 5],
        }
    }

    /// Writes `root - child` into `update_bin`, element-wise for `g_folded`,
    /// `h_folded` and `counts`. The `num` and `cut_value` of `update_bin` are
    /// not modified.
    ///
    /// # Safety
    ///
    /// * All three pointers must be non-null, aligned, and point to
    ///   initialised `Bin`s.
    /// * `update_bin` must be valid for writes, and no other reference to it may
    ///   be alive during the call.
    /// * `update_bin` MAY alias `root_bin` or `child_bin`. The inputs are copied
    ///   out before the output reference is created.
    ///
    /// # Panics
    ///
    /// With overflow checks enabled (e.g. debug builds), panics if any child
    /// count exceeds the matching root count (`u32` underflow). Without overflow
    /// checks, the count wraps.
    pub unsafe fn from_parent_child(root_bin: *mut Bin, child_bin: *mut Bin, update_bin: *mut Bin) {
        // SAFETY: pointer validity is the caller's contract (see `# Safety`).
        // `Bin` is `Copy`, so the inputs are read by value. No shared reference
        // is alive when the `&mut` is created, so aliasing is sound.
        let (rb, cb, ub) = unsafe { (*root_bin, *child_bin, &mut *update_bin) };
        ub.g_folded = std::array::from_fn(|j| rb.g_folded[j] - cb.g_folded[j]);
        ub.h_folded = std::array::from_fn(|j| rb.h_folded[j] - cb.h_folded[j]);
        ub.counts = std::array::from_fn(|j| rb.counts[j] - cb.counts[j]);
    }

    /// Writes `root - first - second` into `update_bin`, element-wise for
    /// `g_folded`, `h_folded` and `counts`. The `num` and `cut_value` of
    /// `update_bin` are not modified.
    ///
    /// # Safety
    ///
    /// * All four pointers must be non-null, aligned, and point to
    ///   initialised `Bin`s.
    /// * `update_bin` must be valid for writes, and no other reference to it may
    ///   be alive during the call.
    /// * `update_bin` MAY alias any of the inputs. They are copied out first.
    ///
    /// # Panics
    ///
    /// With overflow checks enabled, panics if `first + second` exceeds the root
    /// count in any fold (`u32` underflow).
    pub unsafe fn from_parent_two_children(
        root_bin: *mut Bin,
        first_bin: *mut Bin,
        second_bin: *mut Bin,
        update_bin: *mut Bin,
    ) {
        // SAFETY: pointer validity is the caller's contract (see `# Safety`).
        // Inputs are copied before the output reference is created.
        let (rb, fb, sb, ub) =
            unsafe { (*root_bin, *first_bin, *second_bin, &mut *update_bin) };
        // Evaluated as `(root - first) - second`, the same association as before.
        ub.g_folded =
            std::array::from_fn(|j| rb.g_folded[j] - fb.g_folded[j] - sb.g_folded[j]);
        ub.h_folded =
            std::array::from_fn(|j| rb.h_folded[j] - fb.h_folded[j] - sb.h_folded[j]);
        ub.counts = std::array::from_fn(|j| rb.counts[j] - fb.counts[j] - sb.counts[j]);
    }
}
pub fn sort_cat_bins_by_num(histogram: &mut [&UnsafeCell<Bin>]) {
unsafe {
histogram.sort_unstable_by_key(|bin| bin.get().as_ref().unwrap().num);
}
}
pub fn sort_cat_bins_by_stat(histogram: &mut [&UnsafeCell<Bin>], is_const_hess: bool) {
unsafe {
if is_const_hess {
histogram.sort_unstable_by(|bin1, bin2| {
let b1 = bin1.get().as_ref().unwrap();
let b2 = bin2.get().as_ref().unwrap();
if b1.num == 0 {
return Ordering::Less;
} else if b2.num == 0 {
return Ordering::Greater;
}
let div1: f32 = b1.g_folded.iter().sum::<f32>() / b1.counts.iter().sum::<u32>() as f32;
let div2: f32 = b2.g_folded.iter().sum::<f32>() / b2.counts.iter().sum::<u32>() as f32;
div2.total_cmp(&div1)
});
} else {
histogram.sort_unstable_by(|bin1, bin2| {
let b1 = bin1.get().as_ref().unwrap();
let b2 = bin2.get().as_ref().unwrap();
if b1.num == 0 {
return Ordering::Less;
} else if b2.num == 0 {
return Ordering::Greater;
}
let div1: f32 = b1.g_folded.iter().sum::<f32>() / b1.h_folded.iter().sum::<f32>();
let div2: f32 = b2.g_folded.iter().sum::<f32>() / b2.h_folded.iter().sum::<f32>();
div2.total_cmp(&div1)
});
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the `num` of every bin behind the cells, in slice order.
    fn nums(hist: &[&UnsafeCell<Bin>]) -> Vec<u16> {
        // SAFETY: the cells are owned by the calling test and are not mutated here.
        hist.iter().map(|cell| unsafe { (*cell.get()).num }).collect()
    }

    #[test]
    fn test_bin() {
        let mut root_bin = Bin::empty_const_hess(0, 0.0);
        root_bin.counts = [10, 10, 10, 10, 10];
        let mut child_bin = Bin::empty_const_hess(1, 0.0);
        child_bin.counts = [9, 8, 7, 6, 5];
        let mut update_bin = Bin::empty_const_hess(2, 0.0);
        unsafe {
            Bin::from_parent_child(
                &mut root_bin as *mut Bin,
                &mut child_bin as *mut Bin,
                &mut update_bin as *mut Bin,
            )
        };
        assert_eq!(update_bin.counts, [1, 2, 3, 4, 5]);
        // Zero-initialised gradients/hessians stay zero after subtraction.
        assert_eq!(update_bin.g_folded, [0.0; 5]);
        assert_eq!(update_bin.h_folded, [0.0; 5]);
        // Only the accumulators are written; the identifier is untouched.
        assert_eq!(update_bin.num, 2);
    }

    #[test]
    fn test_from_parent_two_children() {
        let mut root = Bin::empty(0, 0.0);
        root.g_folded = [10.0, 20.0, 30.0, 40.0, 50.0];
        root.h_folded = [5.0, 5.0, 5.0, 5.0, 5.0];
        root.counts = [100, 100, 100, 100, 100];

        let mut c1 = Bin::empty(1, 0.0);
        c1.g_folded = [3.0, 6.0, 9.0, 12.0, 15.0];
        c1.h_folded = [1.0, 1.0, 1.0, 1.0, 1.0];
        c1.counts = [30, 30, 30, 30, 30];

        let mut c2 = Bin::empty(2, 0.0);
        c2.g_folded = [2.0, 4.0, 6.0, 8.0, 10.0];
        c2.h_folded = [1.0, 2.0, 1.0, 2.0, 1.0];
        c2.counts = [20, 20, 20, 20, 20];

        let mut update = Bin::empty(3, 0.0);
        unsafe {
            Bin::from_parent_two_children(
                &mut root as *mut Bin,
                &mut c1 as *mut Bin,
                &mut c2 as *mut Bin,
                &mut update as *mut Bin,
            );
        }
        // All operands are small integers, so the f32 results are exact.
        assert_eq!(update.g_folded, [5.0, 10.0, 15.0, 20.0, 25.0]);
        assert_eq!(update.h_folded, [3.0, 2.0, 3.0, 2.0, 3.0]);
        assert_eq!(update.counts, [50, 50, 50, 50, 50]);
        assert_eq!(update.num, 3);
    }

    #[test]
    fn test_sort_cat_bins_by_num() {
        let c1 = UnsafeCell::new(Bin::empty_const_hess(3, 0.0));
        let c2 = UnsafeCell::new(Bin::empty_const_hess(1, 0.0));
        let c3 = UnsafeCell::new(Bin::empty_const_hess(2, 0.0));
        let mut hist: Vec<&UnsafeCell<Bin>> = vec![&c1, &c2, &c3];
        sort_cat_bins_by_num(&mut hist);
        assert_eq!(nums(&hist), vec![1, 2, 3]);
    }

    #[test]
    fn test_sort_cat_bins_by_stat_const_hess() {
        let b0 = Bin::empty_const_hess(0, 0.0);
        let mut b1 = Bin::empty_const_hess(1, 0.0);
        b1.g_folded = [1.0; 5];
        b1.counts = [10; 5]; // mean gradient 5 / 50 = 0.1
        let mut b2 = Bin::empty_const_hess(2, 0.0);
        b2.g_folded = [5.0; 5];
        b2.counts = [10; 5]; // mean gradient 25 / 50 = 0.5

        let c0 = UnsafeCell::new(b0);
        let c1 = UnsafeCell::new(b1);
        let c2 = UnsafeCell::new(b2);
        let mut hist: Vec<&UnsafeCell<Bin>> = vec![&c2, &c0, &c1];
        sort_cat_bins_by_stat(&mut hist, true);
        // num 0 first, then descending by statistic.
        assert_eq!(nums(&hist), vec![0, 2, 1]);
    }

    #[test]
    fn test_sort_cat_bins_by_stat_non_const() {
        let b0 = Bin::empty(0, 0.0);
        let mut b1 = Bin::empty(1, 0.0);
        b1.g_folded = [1.0; 5];
        b1.h_folded = [10.0; 5]; // 5 / 50 = 0.1
        let mut b2 = Bin::empty(2, 0.0);
        b2.g_folded = [5.0; 5];
        b2.h_folded = [10.0; 5]; // 25 / 50 = 0.5

        let c0 = UnsafeCell::new(b0);
        let c1 = UnsafeCell::new(b1);
        let c2 = UnsafeCell::new(b2);
        let mut hist: Vec<&UnsafeCell<Bin>> = vec![&c2, &c0, &c1];
        sort_cat_bins_by_stat(&mut hist, false);
        // num 0 first, then descending by statistic.
        assert_eq!(nums(&hist), vec![0, 2, 1]);
    }
}