use alloc::boxed::Box;
use alloc::vec;
use alloc::vec::Vec;
use crate::histogram::{BinEdges, BinningStrategy};
/// Reservoir capacity used by `KMeansBinning::new`.
const DEFAULT_MAX_RESERVOIR: usize = 10_000;
/// Upper bound on k-means refinement iterations used by `KMeansBinning::new`.
const DEFAULT_MAX_ITERS: usize = 50;
/// Absolute tolerance: k-means stops once no center moves more than this, and
/// centers/edges closer than this are merged.
const EPSILON: f64 = 1e-10;
use crate::rng::xorshift64;
/// Binning strategy that derives bin edges from a 1-D k-means clustering.
///
/// Observed values are kept in a bounded reservoir sample. Edges are placed
/// halfway between adjacent cluster centers, or halfway between adjacent
/// distinct values when there are no more distinct values than bins.
#[derive(Debug, Clone)]
pub struct KMeansBinning {
/// Sampled observations; never holds more than `max_reservoir` values.
reservoir: Vec<f64>,
/// Maximum number of retained samples (asserted > 0 in `with_params`).
max_reservoir: usize,
/// Maximum k-means refinement iterations per edge computation (asserted > 0
/// in `with_params`).
max_iters: usize,
/// Number of values passed to `observe` since construction or the last reset.
count: u64,
/// State of the `xorshift64` generator used to pick reservoir slots; set to
/// the same fixed seed in `with_params` and `reset`.
rng_state: u64,
}
impl KMeansBinning {
    /// Creates a binner with `DEFAULT_MAX_RESERVOIR` samples and
    /// `DEFAULT_MAX_ITERS` refinement iterations.
    pub fn new() -> Self {
        let (max_reservoir, max_iters) = (DEFAULT_MAX_RESERVOIR, DEFAULT_MAX_ITERS);
        Self::with_params(max_reservoir, max_iters)
    }

    /// Creates a binner that retains at most `max_reservoir` samples and runs at
    /// most `max_iters` k-means iterations whenever edges are computed.
    ///
    /// # Panics
    ///
    /// Panics if `max_reservoir` is zero or `max_iters` is zero.
    pub fn with_params(max_reservoir: usize, max_iters: usize) -> Self {
        assert!(max_reservoir > 0, "max_reservoir must be > 0");
        assert!(max_iters > 0, "max_iters must be > 0");
        // Reserve at most 1024 slots up front; larger caps grow on demand.
        let initial_capacity = usize::min(max_reservoir, 1024);
        let rng_state: u64 = 0xDEAD_BEEF_CAFE_BABE;
        Self {
            max_reservoir,
            max_iters,
            rng_state,
            count: 0,
            reservoir: Vec::with_capacity(initial_capacity),
        }
    }

    /// Number of samples currently held in the reservoir (at most the cap).
    #[inline]
    pub fn reservoir_len(&self) -> usize {
        let samples = &self.reservoir;
        samples.len()
    }

    /// Total number of values observed since construction or the last reset.
    #[inline]
    pub fn count(&self) -> u64 {
        let observed = self.count;
        observed
    }
}
impl Default for KMeansBinning {
fn default() -> Self {
Self::new()
}
}
impl BinningStrategy for KMeansBinning {
    /// Records one observation.
    ///
    /// Every call increments `count`. While the reservoir holds fewer than
    /// `max_reservoir` values the value is appended. After that, it overwrites a
    /// randomly chosen slot with probability `max_reservoir / count` (reservoir
    /// sampling, "Algorithm R"). The reservoir therefore stays an approximately
    /// uniform sample of everything observed, assuming `xorshift64` output is
    /// close to uniform.
    ///
    /// Non-finite values are stored like any other value; `compute_edges`
    /// ignores them.
    fn observe(&mut self, value: f64) {
        self.count += 1;
        if self.reservoir.len() < self.max_reservoir {
            self.reservoir.push(value);
            return;
        }
        let r = xorshift64(&mut self.rng_state) % self.count;
        // Compare in u64 before narrowing. On 32-bit targets `r as usize`
        // truncates once `count` exceeds `usize::MAX`, which would turn
        // out-of-range draws into in-range slot indices and bias the sample.
        // `usize -> u64` is lossless on every supported target.
        if r < self.max_reservoir as u64 {
            self.reservoir[r as usize] = value;
        }
    }

    /// Computes at most `n_bins - 1` strictly increasing bin edges.
    ///
    /// Only finite reservoir values are used; NaN and ±∞ are ignored. Returns no
    /// edges when `n_bins <= 1` or when fewer than two distinct finite values are
    /// available.
    ///
    /// If there are at most `n_bins` distinct values, the edges are the
    /// midpoints between consecutive distinct values. Otherwise the steps are:
    /// - seed `n_bins` centers at evenly spaced quantiles;
    /// - refine them with up to `max_iters` Lloyd iterations, stopping early
    ///   once no center moves more than `EPSILON`;
    /// - place edges at the midpoints between consecutive centers, after
    ///   merging centers (and edges) closer than `EPSILON`.
    fn compute_edges(&self, n_bins: usize) -> BinEdges {
        // Zero or one bin needs no edges. For a single bin, k-means would
        // collapse to one center anyway, so skip the iterations.
        if n_bins <= 1 {
            return BinEdges { edges: Vec::new() };
        }
        // NaN cannot be ordered and infinities turn means into NaN/inf, so only
        // finite samples take part. Reservoir order is kept for the assignment
        // pass so cluster sums accumulate in the same order as before.
        let data: Vec<f64> = self
            .reservoir
            .iter()
            .copied()
            .filter(|x| x.is_finite())
            .collect();
        if data.is_empty() {
            return BinEdges { edges: Vec::new() };
        }
        let mut sorted = data.clone();
        sorted.sort_unstable_by(f64::total_cmp);
        let n_unique = 1 + sorted.windows(2).filter(|w| w[0] != w[1]).count();
        if n_unique <= n_bins {
            // Few enough distinct values that each gets its own bin. Cut halfway
            // between neighbours; a single distinct value yields no edges.
            sorted.dedup();
            return BinEdges {
                edges: midpoints(&sorted),
            };
        }
        let seeds = quantile_seeds(&sorted, n_bins);
        let mut centers = lloyd_refine(&data, seeds, self.max_iters);
        // A cluster sum can still overflow to infinity for values near f64::MAX.
        centers.retain(|c| c.is_finite());
        centers.sort_unstable_by(f64::total_cmp);
        centers.dedup_by(|a, b| (*a - *b).abs() <= EPSILON);
        let mut edges = midpoints(&centers);
        edges.dedup_by(|a, b| (*a - *b).abs() <= EPSILON);
        BinEdges { edges }
    }

    /// Discards all samples, zeroes the observation count and restores the RNG
    /// seed used by `with_params`. Reservoir allocation is kept for reuse.
    fn reset(&mut self) {
        self.reservoir.clear();
        self.count = 0;
        self.rng_state = 0xDEAD_BEEF_CAFE_BABE;
    }

    /// Returns a new, empty binner with the same `max_reservoir` and
    /// `max_iters`; no samples are copied.
    fn clone_fresh(&self) -> Box<dyn BinningStrategy> {
        Box::new(KMeansBinning::with_params(
            self.max_reservoir,
            self.max_iters,
        ))
    }
}

/// Midpoints between consecutive elements of `values` (empty if fewer than two).
///
/// Computed as `0.5 * a + 0.5 * b` rather than `(a + b) / 2.0`. Outside the
/// subnormal range both round identically, but this form cannot overflow to
/// infinity when `a` and `b` are both near `f64::MAX`.
fn midpoints(values: &[f64]) -> Vec<f64> {
    values.windows(2).map(|w| 0.5 * w[0] + 0.5 * w[1]).collect()
}

/// Seeds `k` centers at the `(i + 0.5) / k` quantiles of `sorted`, linearly
/// interpolating between neighbouring order statistics.
///
/// `sorted` must be non-empty and in ascending order.
fn quantile_seeds(sorted: &[f64], k: usize) -> Vec<f64> {
    let last = sorted.len() - 1;
    (0..k)
        .map(|i| {
            let pos = (i as f64 + 0.5) / k as f64 * last as f64;
            let lo = libm::floor(pos) as usize;
            let hi = (lo + 1).min(last);
            let frac = pos - lo as f64;
            sorted[lo] * (1.0 - frac) + sorted[hi] * frac
        })
        .collect()
}

/// Runs up to `max_iters` Lloyd iterations of 1-D k-means over `data`.
///
/// Each pass assigns every point to its nearest center (ties go to the lower
/// index) and moves each non-empty cluster's center to the mean of its points.
/// Empty clusters keep their previous center. Iteration stops early once no
/// center moves by more than `EPSILON`.
fn lloyd_refine(data: &[f64], mut centers: Vec<f64>, max_iters: usize) -> Vec<f64> {
    if centers.is_empty() {
        return centers;
    }
    let k = centers.len();
    let mut sums = vec![0.0_f64; k];
    let mut counts = vec![0_usize; k];
    for _ in 0..max_iters {
        sums.fill(0.0);
        counts.fill(0);
        for &x in data {
            let mut best_j = 0;
            let mut best_dist = (x - centers[0]).abs();
            for (j, &center) in centers.iter().enumerate().skip(1) {
                let d = (x - center).abs();
                if d < best_dist {
                    best_dist = d;
                    best_j = j;
                }
            }
            sums[best_j] += x;
            counts[best_j] += 1;
        }
        let mut max_move: f64 = 0.0;
        for ((center, &sum), &count) in centers.iter_mut().zip(&sums).zip(&counts) {
            if count > 0 {
                let new_center = sum / count as f64;
                let move_dist = (new_center - *center).abs();
                if move_dist > max_move {
                    max_move = move_dist;
                }
                *center = new_center;
            }
        }
        if max_move <= EPSILON {
            break;
        }
    }
    centers
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Advances a 64-bit LCG state and returns its top 31 bits scaled to `[0, 1)`.
    fn next_unit(state: &mut u64) -> f64 {
        *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        (*state >> 33) as f64 / (1u64 << 31) as f64
    }

    /// Panics unless `edges` is strictly increasing.
    fn assert_sorted(edges: &[f64]) {
        for i in 1..edges.len() {
            assert!(
                edges[i] > edges[i - 1],
                "edges not strictly sorted: edges[{}]={} <= edges[{}]={}",
                i,
                edges[i],
                i - 1,
                edges[i - 1],
            );
        }
    }

    #[test]
    fn normal_ish_data_edges_sorted() {
        let mut binner = KMeansBinning::new();
        let mut lcg: u64 = 12345;
        for _ in 0..5000 {
            // Sum of four uniforms, shifted to be centred on zero: a cheap
            // bell-shaped sample.
            let sum: f64 = (0..4).map(|_| next_unit(&mut lcg)).sum();
            binner.observe(sum - 2.0);
        }
        let edges = binner.compute_edges(8);
        assert!(!edges.edges.is_empty());
        assert!(edges.edges.len() <= 7, "at most n_bins-1 edges");
        assert_sorted(&edges.edges);
    }

    #[test]
    fn skewed_data_denser_bins_at_low_end() {
        let mut binner = KMeansBinning::new();
        for i in 0..1000u64 {
            binner.observe((i * i) as f64);
        }
        let edges = binner.compute_edges(8);
        assert_sorted(&edges.edges);
        let max_val = 999.0 * 999.0;
        let equal_width_first = max_val / 8.0;
        assert!(
            edges.edges[0] < equal_width_first * 0.6,
            "first k-means edge {} should be well below equal-width edge {} (60% threshold = {})",
            edges.edges[0],
            equal_width_first,
            equal_width_first * 0.6,
        );
    }

    #[test]
    fn uniform_data_roughly_equal_width() {
        let mut binner = KMeansBinning::new();
        for i in 0..2000 {
            binner.observe(i as f64);
        }
        let edges = binner.compute_edges(5);
        assert_sorted(&edges.edges);
        assert_eq!(edges.edges.len(), 4, "5 bins -> 4 edges");
        let expected = [400.0, 800.0, 1200.0, 1600.0];
        for (got, want) in edges.edges.iter().zip(expected.iter()) {
            assert!(
                (got - want).abs() < 200.0,
                "edge {} too far from expected {} for uniform data",
                got,
                want,
            );
        }
        let gaps: Vec<f64> = core::iter::once(edges.edges[0])
            .chain(edges.edges.windows(2).map(|w| w[1] - w[0]))
            .collect();
        let min_gap = gaps.iter().cloned().fold(f64::MAX, f64::min);
        let max_gap = gaps.iter().cloned().fold(f64::MIN, f64::max);
        assert!(
            max_gap / min_gap < 2.0,
            "gaps should be roughly even for uniform data, got min={} max={}",
            min_gap,
            max_gap,
        );
    }

    #[test]
    fn reservoir_capped_at_max() {
        let cap = 500;
        let mut binner = KMeansBinning::with_params(cap, 20);
        for i in 0..100_000u64 {
            binner.observe(i as f64);
        }
        assert_eq!(binner.count(), 100_000);
        assert!(
            binner.reservoir_len() <= cap,
            "reservoir {} exceeds cap {}",
            binner.reservoir_len(),
            cap,
        );
        assert_eq!(binner.reservoir_len(), cap);
    }

    #[test]
    fn fewer_samples_than_bins() {
        let mut binner = KMeansBinning::new();
        binner.observe(1.0);
        binner.observe(3.0);
        binner.observe(5.0);
        let edges = binner.compute_edges(10);
        assert!(
            edges.edges.len() <= 2,
            "expected at most 2 edges, got {}",
            edges.edges.len(),
        );
        assert_sorted(&edges.edges);
        if edges.edges.len() == 2 {
            assert!((edges.edges[0] - 2.0).abs() < 1e-10);
            assert!((edges.edges[1] - 4.0).abs() < 1e-10);
        }
    }

    #[test]
    fn single_value_repeated() {
        let mut binner = KMeansBinning::new();
        for _ in 0..1000 {
            binner.observe(42.0);
        }
        let edges = binner.compute_edges(8);
        assert!(
            edges.edges.is_empty(),
            "expected no edges for constant data, got {:?}",
            edges.edges,
        );
    }

    #[test]
    fn convergence_within_iters() {
        let mut binner_fast = KMeansBinning::with_params(5000, 3);
        let mut binner_full = KMeansBinning::with_params(5000, 200);
        let mut lcg: u64 = 99999;
        let data: Vec<f64> = (0..3000).map(|_| next_unit(&mut lcg)).collect();
        for &x in &data {
            binner_fast.observe(x);
            binner_full.observe(x);
        }
        let edges_fast = binner_fast.compute_edges(6);
        let edges_full = binner_full.compute_edges(6);
        assert_sorted(&edges_fast.edges);
        assert_sorted(&edges_full.edges);
        assert_eq!(edges_fast.edges.len(), edges_full.edges.len());
        // The data lies in [0, 1), so allow 10% of that unit range.
        let tolerance = 0.1;
        for (a, b) in edges_fast.edges.iter().zip(edges_full.edges.iter()) {
            assert!(
                (a - b).abs() < tolerance,
                "3-iter edge {} too far from converged edge {} (tol {})",
                a,
                b,
                tolerance,
            );
        }
    }

    #[test]
    fn empty_binner_returns_no_edges() {
        let binner = KMeansBinning::new();
        let edges = binner.compute_edges(8);
        assert!(edges.edges.is_empty());
    }

    #[test]
    fn zero_bins_returns_empty() {
        let mut binner = KMeansBinning::new();
        binner.observe(1.0);
        binner.observe(2.0);
        let edges = binner.compute_edges(0);
        assert!(edges.edges.is_empty());
    }

    #[test]
    fn one_bin_returns_no_edges() {
        let mut binner = KMeansBinning::new();
        for i in 0..100 {
            binner.observe(i as f64);
        }
        let edges = binner.compute_edges(1);
        assert!(edges.edges.is_empty());
    }

    #[test]
    fn two_values_two_bins() {
        let mut binner = KMeansBinning::new();
        binner.observe(0.0);
        binner.observe(10.0);
        let edges = binner.compute_edges(2);
        assert_eq!(edges.edges.len(), 1);
        assert!((edges.edges[0] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn reset_clears_all_state() {
        let mut binner = KMeansBinning::with_params(100, 10);
        for i in 0..200 {
            binner.observe(i as f64);
        }
        assert_eq!(binner.count(), 200);
        assert_eq!(binner.reservoir_len(), 100);
        binner.reset();
        assert_eq!(binner.count(), 0);
        assert_eq!(binner.reservoir_len(), 0);
        let edges = binner.compute_edges(4);
        assert!(edges.edges.is_empty());
    }

    /// Only checks that the clone starts empty: the parameters themselves are
    /// not observable through `Box<dyn BinningStrategy>`.
    #[test]
    fn clone_fresh_preserves_params() {
        let mut binner = KMeansBinning::with_params(256, 25);
        for i in 0..500 {
            binner.observe(i as f64);
        }
        let fresh = binner.clone_fresh();
        let edges = fresh.compute_edges(4);
        assert!(edges.edges.is_empty());
    }

    #[test]
    fn bimodal_data_finds_two_clusters() {
        let mut binner = KMeansBinning::new();
        let mut lcg: u64 = 77777;
        for _ in 0..2500 {
            let noise = (next_unit(&mut lcg) - 0.5) * 2.0;
            binner.observe(10.0 + noise);
        }
        for _ in 0..2500 {
            let noise = (next_unit(&mut lcg) - 0.5) * 2.0;
            binner.observe(90.0 + noise);
        }
        let edges = binner.compute_edges(4);
        assert_sorted(&edges.edges);
        let has_gap_edge = edges.edges.iter().any(|&e| e > 20.0 && e < 80.0);
        assert!(
            has_gap_edge,
            "expected at least one edge in the gap between clusters, edges={:?}",
            edges.edges,
        );
    }

    #[test]
    fn find_bin_integration() {
        let mut binner = KMeansBinning::new();
        for i in 0..1000 {
            binner.observe(i as f64);
        }
        let edges = binner.compute_edges(5);
        assert_sorted(&edges.edges);
        let n_bins = edges.n_bins();
        for i in 0..1000 {
            let bin = edges.find_bin(i as f64);
            assert!(
                bin < n_bins,
                "find_bin({}) = {} but n_bins = {}",
                i,
                bin,
                n_bins,
            );
        }
    }
}