use alloc::boxed::Box;
use alloc::vec::Vec;
use super::{BinEdges, BinningStrategy};
/// Upper bound on the number of distinct categories a [`CategoricalBinning`] records.
/// Once this many codes are known, `observe` ignores values whose code is not already present.
pub const MAX_CATEGORIES: usize = 64;
/// Binning strategy that gives every distinct observed value its own bin.
///
/// Observed `f64` values are converted to `i64` codes by truncation toward zero,
/// so `2.0` and `2.7` share category `2`. Categories are kept sorted ascending and
/// free of duplicates, which lets lookups use binary search.
#[derive(Debug, Clone)]
pub struct CategoricalBinning {
    /// Distinct category codes, sorted ascending with no duplicates.
    categories: Vec<i64>,
}

impl CategoricalBinning {
    /// Creates an empty binner with no recorded categories.
    pub fn new() -> Self {
        Self {
            categories: Vec::new(),
        }
    }

    /// Returns the number of distinct categories recorded so far.
    #[inline]
    pub fn n_categories(&self) -> usize {
        self.categories.len()
    }

    /// Returns the recorded category codes in ascending order.
    pub fn categories(&self) -> &[i64] {
        &self.categories
    }

    /// Returns the position of `value`'s category within [`categories`](Self::categories),
    /// or `None` if that category has not been recorded.
    ///
    /// `value` is truncated toward zero before the lookup, so `2.7` finds category `2`.
    /// Non-finite inputs (NaN, ±∞) always yield `None`. A bare `as i64` cast would turn
    /// NaN into `0` and infinities into `i64::MAX`/`i64::MIN`, and those could then
    /// alias real categories.
    pub fn category_index(&self, value: f64) -> Option<usize> {
        if !value.is_finite() {
            return None;
        }
        self.categories.binary_search(&(value as i64)).ok()
    }
}

impl Default for CategoricalBinning {
    /// Equivalent to [`CategoricalBinning::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl BinningStrategy for CategoricalBinning {
    /// Records the category of `value`.
    ///
    /// `value` is truncated toward zero to an `i64` code. Non-finite values (NaN, ±∞)
    /// are ignored. Casting them with `as` would map NaN to `0` and infinities to
    /// `i64::MAX`/`i64::MIN`, creating phantom categories that consume one of the
    /// [`MAX_CATEGORIES`] slots and shift the bin edges. Once [`MAX_CATEGORIES`]
    /// distinct codes are recorded, codes not already present are dropped.
    fn observe(&mut self, value: f64) {
        if !value.is_finite() {
            return;
        }
        let code = value as i64;
        // `Err(pos)` is the insertion point that keeps `categories` sorted;
        // `Ok(_)` means the code is already known and nothing needs to change.
        if let Err(pos) = self.categories.binary_search(&code) {
            if self.categories.len() < MAX_CATEGORIES {
                self.categories.insert(pos, code);
            }
        }
    }

    /// Returns one edge at the midpoint of each pair of adjacent categories.
    ///
    /// `_n_bins` is ignored: the edge count is dictated by the number of distinct
    /// categories (one fewer than that count). With zero or one category the edge
    /// list is empty, because `windows(2)` yields no pairs for slices shorter than two.
    fn compute_edges(&self, _n_bins: usize) -> BinEdges {
        let edges: Vec<f64> = self
            .categories
            .windows(2)
            .map(|pair| (pair[0] as f64 + pair[1] as f64) / 2.0)
            .collect();
        BinEdges { edges }
    }

    /// Forgets every recorded category.
    fn reset(&mut self) {
        self.categories.clear();
    }

    /// Returns a new, empty `CategoricalBinning` as a boxed trait object.
    /// Recorded categories are not copied.
    fn clone_fresh(&self) -> Box<dyn BinningStrategy> {
        Box::new(CategoricalBinning::new())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a binner that has observed each of `values`, in order.
    fn binner_from(values: &[f64]) -> CategoricalBinning {
        let mut binner = CategoricalBinning::new();
        values.iter().for_each(|&v| binner.observe(v));
        binner
    }

    /// Asserts `actual` has the same length as `expected` and each edge is within 1e-10.
    fn assert_edges_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (got, want) in actual.iter().zip(expected) {
            assert!((got - want).abs() < 1e-10);
        }
    }

    #[test]
    fn empty_binner() {
        let binner = CategoricalBinning::new();
        assert_eq!(binner.n_categories(), 0);

        let edges = binner.compute_edges(10);
        assert!(edges.edges.is_empty());
        assert_eq!(edges.n_bins(), 1);
    }

    #[test]
    fn single_category() {
        let binner = binner_from(&[3.0, 3.0]);
        assert_eq!(binner.n_categories(), 1);

        let edges = binner.compute_edges(10);
        assert!(edges.edges.is_empty());
        assert_eq!(edges.n_bins(), 1);
    }

    #[test]
    fn two_categories() {
        let binner = binner_from(&[1.0, 5.0]);
        assert_eq!(binner.n_categories(), 2);

        let edges = binner.compute_edges(10);
        assert_edges_close(&edges.edges, &[3.0]);
        assert_eq!(edges.n_bins(), 2);
    }

    #[test]
    fn multiple_categories_sorted() {
        let binner = binner_from(&[5.0, 1.0, 3.0, 7.0]);
        assert_eq!(binner.categories(), &[1, 3, 5, 7]);

        let edges = binner.compute_edges(10);
        assert_edges_close(&edges.edges, &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn find_bin_routes_correctly() {
        let binner = binner_from(&[0.0, 2.0, 4.0, 6.0, 8.0]);
        let edges = binner.compute_edges(10);

        let cases = [(0.0, 0), (2.0, 1), (4.0, 2), (6.0, 3), (8.0, 4)];
        for &(value, expected_bin) in cases.iter() {
            assert_eq!(edges.find_bin(value), expected_bin);
        }
    }

    #[test]
    fn category_index_lookup() {
        let binner = binner_from(&[10.0, 20.0, 30.0]);
        assert_eq!(binner.category_index(10.0), Some(0));
        assert_eq!(binner.category_index(20.0), Some(1));
        assert_eq!(binner.category_index(30.0), Some(2));
        assert_eq!(binner.category_index(15.0), None);
    }

    #[test]
    fn max_categories_enforced() {
        let mut binner = CategoricalBinning::new();
        (0..100).for_each(|i| binner.observe(i as f64));
        assert_eq!(binner.n_categories(), MAX_CATEGORIES);
    }

    #[test]
    fn reset_clears() {
        let mut binner = binner_from(&[1.0, 2.0, 3.0]);
        binner.reset();
        assert_eq!(binner.n_categories(), 0);
    }

    #[test]
    fn n_bins_ignored() {
        let binner = binner_from(&[1.0, 2.0, 3.0]);
        let coarse = binner.compute_edges(2);
        let fine = binner.compute_edges(100);
        assert_eq!(coarse.edges, fine.edges);
    }
}