use serde::{Deserialize, Serialize};
use super::MISSING_BIN;
/// Maps raw `f64` feature values to small integer bin indices.
///
/// The boundaries are learned once by [`BinMapper::fit`] and then used to bucket values with
/// [`BinMapper::value_to_bin`].
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct BinMapper {
    /// Ascending boundaries between consecutive real bins.
    ///
    /// A finite value `v` falls in bin `1 + (number of bounds strictly below v)`. Each entry is
    /// therefore the inclusive upper edge of one real bin. The last real bin has no stored
    /// upper edge.
    pub(crate) upper_bounds: Vec<f64>,
}
impl BinMapper {
    /// Learns bin boundaries from the finite entries of `values`.
    ///
    /// NaN and ±infinity are skipped while fitting; [`Self::value_to_bin`] routes them to
    /// `MISSING_BIN`. If no finite value remains, the mapper has no boundaries and exactly one
    /// real bin.
    ///
    /// With `target_bins = max(max_bin - 1, 1)`:
    ///
    /// * If there are at most `target_bins` distinct finite values, each distinct value gets
    ///   its own bin. `min_data_in_bin` is not consulted on this path.
    /// * Otherwise distinct values are swept in ascending order. A bin is closed once it holds
    ///   at least `max(ceil(n / target_bins), min_data_in_bin, 1)` samples, and at most
    ///   `target_bins` bins are produced. The final bin takes whatever remains and may be
    ///   smaller.
    ///
    /// Every boundary `b` placed between adjacent distinct values `lo < hi` satisfies
    /// `lo <= b < hi`, so the two values always fall into different bins, even for huge
    /// magnitudes or neighbouring floats.
    ///
    /// `max_bin` is clamped to `u16::MAX` so that every index returned by
    /// [`Self::value_to_bin`] fits in a `u16` without wrapping.
    ///
    /// # Panics
    ///
    /// Debug builds panic if `max_bin < 2`.
    pub fn fit(values: &[f64], max_bin: usize, min_data_in_bin: usize) -> Self {
        debug_assert!(max_bin >= 2);
        // Real bins are numbered from 1. This cap keeps the largest index at u16::MAX - 1,
        // so the `u16` cast in `value_to_bin` cannot wrap.
        let max_bin = max_bin.min(usize::from(u16::MAX));

        let mut finite: Vec<f64> = values.iter().copied().filter(|v| v.is_finite()).collect();
        if finite.is_empty() {
            return Self {
                upper_bounds: Vec::new(),
            };
        }
        finite.sort_unstable_by(f64::total_cmp);
        let (distinct, counts) = Self::distinct_with_counts(&finite);

        // `num_bins` counts one slot beyond the real bins, hence `max_bin - 1` real bins.
        let target_bins = max_bin.saturating_sub(1).max(1);
        let upper_bounds: Vec<f64> = if distinct.len() <= target_bins {
            distinct
                .windows(2)
                .map(|pair| Self::split_point(pair[0], pair[1]))
                .collect()
        } else {
            let target_size = finite
                .len()
                .div_ceil(target_bins)
                .max(min_data_in_bin)
                .max(1);
            let mut bounds: Vec<f64> = Vec::new();
            let mut in_bin: usize = 0;
            // `windows(2)` yields one item fewer than `counts`, so the largest distinct value
            // is never closed off: it always belongs to the final bin.
            for (pair, &count) in distinct.windows(2).zip(&counts) {
                in_bin += count;
                // `bounds.len() + 1 < target_bins` caps the result at `target_bins` real bins.
                if in_bin >= target_size && bounds.len() + 1 < target_bins {
                    bounds.push(Self::split_point(pair[0], pair[1]));
                    in_bin = 0;
                }
            }
            bounds
        };
        Self { upper_bounds }
    }

    /// Collapses an ascending slice into its distinct values and how often each occurs.
    ///
    /// Values are compared with `==`, so `-0.0` and `0.0` share a single entry.
    fn distinct_with_counts(sorted: &[f64]) -> (Vec<f64>, Vec<usize>) {
        let mut distinct: Vec<f64> = Vec::new();
        let mut counts: Vec<usize> = Vec::new();
        for &v in sorted {
            if distinct.last() == Some(&v) {
                // `distinct` and `counts` grow in lockstep, so `counts` is non-empty here.
                if let Some(count) = counts.last_mut() {
                    *count += 1;
                }
            } else {
                distinct.push(v);
                counts.push(1);
            }
        }
        (distinct, counts)
    }

    /// Picks a boundary between adjacent distinct finite values `lo < hi`, returning `b` with
    /// `lo <= b < hi`.
    ///
    /// Halving each operand before adding avoids the overflow to infinity that
    /// `(lo + hi) / 2.0` hits near `f64::MAX`. For ordinary magnitudes the result is
    /// identical. When `lo` and `hi` are neighbouring floats the rounded midpoint can equal
    /// `hi`, which would send `hi` into `lo`'s bin, so `lo` itself is used instead.
    fn split_point(lo: f64, hi: f64) -> f64 {
        let mid = lo / 2.0 + hi / 2.0;
        if (lo..hi).contains(&mid) {
            mid
        } else {
            lo
        }
    }

    /// Number of bins that finite values can land in (boundary count + 1).
    pub fn num_real_bins(&self) -> usize {
        self.upper_bounds.len() + 1
    }

    /// Total bin count: the real bins plus one extra slot.
    ///
    /// Real bins occupy indices `1..=num_real_bins()`, so the extra slot is index 0
    /// (presumably `MISSING_BIN`; confirm against its definition).
    pub fn num_bins(&self) -> usize {
        self.num_real_bins() + 1
    }

    /// Ascending bin boundaries learned by [`Self::fit`].
    pub fn upper_bounds(&self) -> &[f64] {
        &self.upper_bounds
    }

    /// Maps `v` to its bin index.
    ///
    /// Non-finite inputs (NaN, ±inf) return `MISSING_BIN`. A finite `v` returns
    /// `1 + (number of boundaries strictly less than v)`, so real bins are numbered
    /// `1..=num_real_bins()`. A value equal to a boundary lands in the lower bin.
    #[inline]
    pub fn value_to_bin(&self, v: f64) -> u16 {
        if !v.is_finite() {
            return MISSING_BIN;
        }
        let idx = self.upper_bounds.partition_point(|&ub| ub < v);
        // `fit` keeps the boundary count at most u16::MAX - 2, so this cast cannot truncate
        // for mappers it produced. NOTE(review): deserialized mappers are not re-validated.
        (idx as u16) + 1
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn few_distinct_values_each_get_own_bin() {
        let vals = vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0];
        let bm = BinMapper::fit(&vals, 16, 1);
        assert_eq!(bm.num_real_bins(), 3);
        assert_eq!(bm.value_to_bin(1.0), 1);
        assert_eq!(bm.value_to_bin(2.0), 2);
        assert_eq!(bm.value_to_bin(3.0), 3);
    }

    #[test]
    fn nan_maps_to_missing() {
        let vals = vec![1.0, 2.0, 3.0];
        let bm = BinMapper::fit(&vals, 16, 1);
        assert_eq!(bm.value_to_bin(f64::NAN), MISSING_BIN);
    }

    #[test]
    fn many_values_capped_at_max_bin() {
        let vals: Vec<f64> = (0..1000).map(|i| i as f64).collect();
        let bm = BinMapper::fit(&vals, 16, 1);
        assert!(bm.num_real_bins() <= 15);

        // Bins must be non-decreasing in the input value and span 1..=num_real_bins().
        let bins: Vec<u16> = vals.iter().map(|&v| bm.value_to_bin(v)).collect();
        assert!(bins.windows(2).all(|w| w[0] <= w[1]));
        assert_eq!(bins.first(), Some(&1));
        assert_eq!(usize::from(*bins.last().unwrap()), bm.num_real_bins());
    }

    #[test]
    fn infinities_map_to_missing_and_are_ignored_by_fit() {
        let vals = vec![f64::NEG_INFINITY, 1.0, f64::NAN, 2.0, f64::INFINITY];
        let bm = BinMapper::fit(&vals, 16, 1);
        assert_eq!(bm.num_real_bins(), 2);
        assert_eq!(bm.value_to_bin(f64::INFINITY), MISSING_BIN);
        assert_eq!(bm.value_to_bin(f64::NEG_INFINITY), MISSING_BIN);
    }

    #[test]
    fn no_finite_values_yields_single_real_bin() {
        for vals in [Vec::new(), vec![f64::NAN, f64::INFINITY]] {
            let bm = BinMapper::fit(&vals, 16, 1);
            assert!(bm.upper_bounds().is_empty());
            assert_eq!(bm.num_real_bins(), 1);
            assert_eq!(bm.num_bins(), 2);
            assert_eq!(bm.value_to_bin(42.0), 1);
        }
    }

    #[test]
    fn min_data_in_bin_limits_bin_count() {
        // 100 distinct values exceed the 15 real bins allowed by max_bin = 16. Each closed bin
        // must then hold at least 20 samples, which leaves room for only 5 bins.
        let vals: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let bm = BinMapper::fit(&vals, 16, 20);
        assert_eq!(bm.num_real_bins(), 5);
    }
}