use num_traits::ToPrimitive;
use crate::{error::Error, util::validate_and_convert, ClassifiedResult, IndexRanges};
/// Optimal 1-D clustering: partitions a sequence into `k` contiguous clusters
/// minimising the total within-cluster sum of squares.
///
/// The type carries no state; all functionality is exposed through associated
/// functions such as [`KNLogN::classify`]. The default (non-`low-memory`) build
/// fills each DP layer with a divide-and-conquer search, giving O(k·n·log n) time.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct KNLogN {}
/// Prefix sums of the data and of its squares, giving the within-cluster sum of
/// squares (WCSS) of any contiguous range in O(1).
///
/// Values are shifted by the data mean before accumulating. WCSS is invariant
/// under translation, but evaluating it as `Σx² − (Σx)²/count` loses precision
/// when the values share a large common offset (e.g. numbers around 1e9): both
/// terms become huge and nearly equal, and their difference is dominated by
/// rounding error. Centering keeps the accumulated magnitudes small.
struct PrefixSums {
    /// `prefix_sum[i]` is the sum of the first `i` shifted values (`prefix_sum[0] == 0.0`).
    prefix_sum: Vec<f64>,
    /// `prefix_sum_sq[i]` is the sum of squares of the first `i` shifted values.
    prefix_sum_sq: Vec<f64>,
}

impl PrefixSums {
    /// Builds both prefix tables for `data`; each has `data.len() + 1` entries.
    fn new(data: &[f64]) -> Self {
        let n = data.len();
        let shift = Self::centering_shift(data);
        let mut prefix_sum = Vec::with_capacity(n + 1);
        let mut prefix_sum_sq = Vec::with_capacity(n + 1);
        prefix_sum.push(0.0);
        prefix_sum_sq.push(0.0);
        let mut s = 0.0;
        let mut s_sq = 0.0;
        for &x in data {
            let d = x - shift;
            s += d;
            s_sq += d * d;
            prefix_sum.push(s);
            prefix_sum_sq.push(s_sq);
        }
        Self { prefix_sum, prefix_sum_sq }
    }

    /// Offset subtracted from every value: the mean of `data`, or `0.0` when
    /// `data` is empty or its mean is not finite (so degenerate inputs keep the
    /// un-shifted behaviour instead of turning into NaN).
    fn centering_shift(data: &[f64]) -> f64 {
        if data.is_empty() {
            return 0.0;
        }
        let mean = data.iter().sum::<f64>() / data.len() as f64;
        if mean.is_finite() {
            mean
        } else {
            0.0
        }
    }

    /// WCSS of the inclusive range `data[start..=end]`.
    ///
    /// The result is clamped at `0.0` because rounding can make the difference
    /// slightly negative for (near-)constant ranges.
    #[inline]
    fn wcss(&self, start: usize, end: usize) -> f64 {
        debug_assert!(start <= end);
        let count = (end - start + 1) as f64;
        let sum = self.prefix_sum[end + 1] - self.prefix_sum[start];
        let sum_sq = self.prefix_sum_sq[end + 1] - self.prefix_sum_sq[start];
        (sum_sq - sum * sum / count).max(0.0)
    }
}
/// Fills `best_cost_curr[i]` and `last_split_curr[i]` for every end index `i`
/// in `row_lo..=row_hi` of the current DP layer.
///
/// For each `i`, the start `j` of the last cluster is chosen from
/// `search_lo..=min(search_hi, i)` to minimise
/// `best_cost_prev[j - 1] + wcss(j, i)`, where the left term is `0.0` when
/// `j == 0`. Ties keep the smallest `j`. If that window is empty, the row gets
/// cost `f64::INFINITY` and split `search_lo`.
///
/// Divide and conquer: the middle row is solved first, then rows below it only
/// search up to its split and rows above it only search from its split. This
/// relies on the optimal split being non-decreasing in `i`, the monotonicity
/// property this optimisation assumes for 1-D sum-of-squares clustering.
// The eight inputs are inherent to the recursion (two read-only tables, two
// output rows, two index windows); bundling them would not make it clearer.
#[allow(clippy::too_many_arguments)]
fn fill_row_dc(
    prefix: &PrefixSums,
    best_cost_prev: &[f64],
    best_cost_curr: &mut [f64],
    last_split_curr: &mut [usize],
    mut row_lo: usize,
    row_hi: usize,
    mut search_lo: usize,
    search_hi: usize,
) {
    // Only the lower half recurses; the upper half is a tail call, so it is
    // handled by advancing the window and looping.
    while row_lo <= row_hi {
        let mid_row = row_lo + (row_hi - row_lo) / 2;
        let (best_cost, best_split) = (search_lo..=search_hi.min(mid_row)).fold(
            (f64::INFINITY, search_lo),
            |(cost_so_far, split_so_far), split| {
                let left = match split {
                    0 => 0.0,
                    s => best_cost_prev[s - 1],
                };
                let total = left + prefix.wcss(split, mid_row);
                if total < cost_so_far {
                    (total, split)
                } else {
                    (cost_so_far, split_so_far)
                }
            },
        );
        best_cost_curr[mid_row] = best_cost;
        last_split_curr[mid_row] = best_split;
        if mid_row > row_lo {
            fill_row_dc(
                prefix,
                best_cost_prev,
                best_cost_curr,
                last_split_curr,
                row_lo,
                mid_row - 1,
                search_lo,
                best_split,
            );
        }
        row_lo = mid_row + 1;
        search_lo = best_split;
    }
}
/// Runs the full clustering DP over `data` and returns the split table.
///
/// The result has `data.len()` rows of `k` entries: `table[i][m]` is the start
/// index of cluster `m` in the best partition of `data[..=i]` into `m + 1`
/// clusters. Column 0 is always `0`. Only two cost rows are kept alive at a
/// time; the split table itself is `n × k`.
#[cfg(not(feature = "low-memory"))]
fn compute_dp(data: &[f64], k: usize) -> Vec<Vec<usize>> {
    let len = data.len();
    let sums = PrefixSums::new(data);
    // Zero-initialised, so column 0 (single cluster starting at 0) is already correct.
    let mut splits: Vec<Vec<usize>> = vec![vec![0; k]; len];
    let mut prev_row: Vec<f64> = (0..len).map(|i| sums.wcss(0, i)).collect();
    let mut curr_row: Vec<f64> = vec![f64::INFINITY; len];
    // Reused across layers: fill_row_dc writes every row in `cluster..len`
    // before those rows are copied out below.
    let mut row_splits: Vec<usize> = vec![0; len];
    for cluster in 1..k {
        curr_row.fill(f64::INFINITY);
        fill_row_dc(
            &sums,
            &prev_row,
            &mut curr_row,
            &mut row_splits,
            cluster,
            len - 1,
            cluster,
            len - 1,
        );
        for (row, &split) in splits.iter_mut().zip(row_splits.iter()).skip(cluster) {
            row[cluster] = split;
        }
        std::mem::swap(&mut prev_row, &mut curr_row);
    }
    splits
}
/// Re-runs the DP on the prefix `data[..=upto]` with `k` clusters and returns
/// the start index of the last cluster (always `0` when `k == 1`).
///
/// Only O(`upto`) memory is used: two cost rows and one split row. The `n`
/// argument is not used by this function.
#[cfg(feature = "low-memory")]
fn solve_last_split(prefix: &PrefixSums, n: usize, upto: usize, k: usize) -> usize {
    let _ = n;
    if k == 1 {
        return 0;
    }
    let width = upto + 1;
    let mut prev: Vec<f64> = (0..width).map(|i| prefix.wcss(0, i)).collect();
    let mut curr: Vec<f64> = vec![f64::INFINITY; width];
    let mut splits: Vec<usize> = vec![0; width];
    for layer in 1..k {
        curr.fill(f64::INFINITY);
        splits.fill(0);
        fill_row_dc(prefix, &prev, &mut curr, &mut splits, layer, upto, layer, upto);
        std::mem::swap(&mut prev, &mut curr);
    }
    // `splits` is never swapped, so it still holds the final layer's splits.
    splits[upto]
}
/// Low-memory reconstruction of the optimal `k` clusters as half-open
/// `(start, end)` index ranges.
///
/// Clusters are peeled off the right end one at a time: for each remaining
/// prefix, `solve_last_split` re-runs the DP to find where its last cluster
/// starts. This keeps memory O(n) at the cost of repeating the DP once per cluster.
#[cfg(feature = "low-memory")]
fn compute_ranges(data: &[f64], k: usize) -> IndexRanges {
    let len = data.len();
    let sums = PrefixSums::new(data);
    let last = k - 1;
    let mut ranges: IndexRanges = vec![(0, 0); k];
    let mut end = len;
    for m in (0..=last).rev() {
        let start = solve_last_split(&sums, len, end - 1, m + 1);
        ranges[m] = (start, end);
        end = start;
    }
    ranges
}
/// Rebuilds the values of each optimal cluster from the split table returned
/// by `compute_dp`.
///
/// Starting from the last cluster (which ends at `data.len()`), each cluster's
/// start is read from `matrix_b[end - 1][m]` and becomes the previous cluster's
/// end. Clusters are gathered back to front and then reversed into order.
///
/// # Panics
///
/// Panics if `k == 0` or `data` is empty (the index arithmetic underflows).
#[cfg(not(feature = "low-memory"))]
fn backtrack_values<T: Clone>(data: &[T], matrix_b: &[Vec<usize>], k: usize) -> ClassifiedResult<T> {
    let last = k - 1;
    let mut clusters: ClassifiedResult<T> = Vec::with_capacity(k);
    let mut end = data.len();
    for m in (0..=last).rev() {
        let start = matrix_b[end - 1][m];
        clusters.push(data[start..end].to_vec());
        end = start;
    }
    clusters.reverse();
    clusters
}
/// Rebuilds the half-open `(start, end)` index range of each optimal cluster
/// from the split table returned by `compute_dp`, for data of length `n`.
///
/// Ranges are recovered right to left and written directly into their slot,
/// so no final reversal is needed.
///
/// # Panics
///
/// Panics if `k == 0` or `n == 0` (the index arithmetic underflows).
#[cfg(not(feature = "low-memory"))]
fn backtrack_indices(matrix_b: &[Vec<usize>], k: usize, n: usize) -> IndexRanges {
    let last = k - 1;
    let mut ranges: IndexRanges = vec![(0, 0); k];
    let mut end = n;
    for m in (0..=last).rev() {
        let start = matrix_b[end - 1][m];
        ranges[m] = (start, end);
        end = start;
    }
    ranges
}
impl KNLogN {
#[cfg(not(feature = "low-memory"))]
pub fn classify<T>(data: Vec<T>, k: usize) -> Result<ClassifiedResult<T>, Error>
where
T: PartialOrd + Clone + ToPrimitive,
{
let converted_data = validate_and_convert(&data, k)?;
let matrix_b = compute_dp(&converted_data, k);
Ok(backtrack_values(&data, &matrix_b, k))
}
#[cfg(feature = "low-memory")]
pub fn classify<T>(data: Vec<T>, k: usize) -> Result<ClassifiedResult<T>, Error>
where
T: PartialOrd + Clone + ToPrimitive,
{
let converted_data = validate_and_convert(&data, k)?;
let ranges = compute_ranges(&converted_data, k);
let result = ranges
.into_iter()
.map(|(start, end)| data[start..end].to_vec())
.collect();
Ok(result)
}
#[cfg(not(feature = "low-memory"))]
pub fn classify_indices<T>(data: &[T], k: usize) -> Result<IndexRanges, Error>
where
T: PartialOrd + Clone + ToPrimitive,
{
let converted_data = validate_and_convert(data, k)?;
let matrix_b = compute_dp(&converted_data, k);
Ok(backtrack_indices(&matrix_b, k, data.len()))
}
#[cfg(feature = "low-memory")]
pub fn classify_indices<T>(data: &[T], k: usize) -> Result<IndexRanges, Error>
where
T: PartialOrd + Clone + ToPrimitive,
{
let converted_data = validate_and_convert(data, k)?;
Ok(compute_ranges(&converted_data, k))
}
pub fn classify_indices_with_sort<T>(mut data: Vec<T>, k: usize) -> Result<IndexRanges, Error>
where
T: PartialOrd + Clone + ToPrimitive,
{
for window in data.windows(2) {
if window[0].partial_cmp(&window[1]).is_none() {
return Err(Error::NaNError);
}
}
data.sort_by(|a, b| a.partial_cmp(b).unwrap());
Self::classify_indices(&data, k)
}
pub fn classify_with_sort<T>(mut data: Vec<T>, k: usize) -> Result<ClassifiedResult<T>, Error>
where
T: PartialOrd + Clone + ToPrimitive,
{
for window in data.windows(2) {
if window[0].partial_cmp(&window[1]).is_none() {
return Err(Error::NaNError);
}
}
data.sort_by(|a, b| a.partial_cmp(b).unwrap());
Self::classify(data, k)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_clustering() {
        let data = vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0];
        let result = KNLogN::classify(data, 2).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], vec![1.0, 2.0, 3.0]);
        assert_eq!(result[1], vec![10.0, 11.0, 12.0]);
    }

    #[test]
    fn test_three_clusters() {
        let data = vec![1, 2, 3, 10, 11, 12, 50, 51, 52];
        let result = KNLogN::classify(data, 3).unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], vec![1, 2, 3]);
        assert_eq!(result[1], vec![10, 11, 12]);
        assert_eq!(result[2], vec![50, 51, 52]);
    }

    #[test]
    fn test_k_equals_n() {
        let data = vec![5.0, 10.0, 15.0];
        let result = KNLogN::classify(data, 3).unwrap();
        assert_eq!(result, vec![vec![5.0], vec![10.0], vec![15.0]]);
    }

    #[test]
    fn test_single_cluster() {
        let data = vec![1.0, 2.0, 3.0];
        let result = KNLogN::classify(data, 1).unwrap();
        assert_eq!(result, vec![vec![1.0, 2.0, 3.0]]);
    }

    #[test]
    fn test_duplicates() {
        let data = vec![1.0, 1.0, 1.0, 5.0, 5.0, 5.0];
        let result = KNLogN::classify(data, 2).unwrap();
        assert_eq!(result[0], vec![1.0, 1.0, 1.0]);
        assert_eq!(result[1], vec![5.0, 5.0, 5.0]);
    }

    #[test]
    fn test_four_clusters() {
        let data = vec![1, 2, 10, 11, 20, 21, 30, 31];
        let result = KNLogN::classify(data, 4).unwrap();
        assert_eq!(result, vec![vec![1, 2], vec![10, 11], vec![20, 21], vec![30, 31]]);
    }

    #[test]
    fn test_wcss_matches_optimal() {
        fn wcss(cluster: &[f64]) -> f64 {
            let mean = cluster.iter().sum::<f64>() / cluster.len() as f64;
            cluster.iter().map(|x| (x - mean).powi(2)).sum()
        }
        let cases: Vec<(Vec<f64>, usize, f64)> = vec![
            (vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0], 2, 4.0),
            (vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 50.0, 51.0, 52.0], 3, 6.0),
            (vec![1.0, 3.0, 5.0, 7.0, 9.0, 50.0, 52.0, 54.0], 2, 48.0),
        ];
        for (data, k, expected_wcss) in cases {
            let result = KNLogN::classify(data, k).unwrap();
            let total_wcss: f64 = result.iter().map(|c| wcss(c)).sum();
            assert!(
                (total_wcss - expected_wcss).abs() < 1e-9,
                "k={k}: expected WCSS={expected_wcss}, got {total_wcss}"
            );
        }
    }

    #[test]
    fn test_agrees_with_kn_squared() {
        use crate::k_n2::KNSquared;
        fn total_wcss(clusters: &[Vec<f64>]) -> f64 {
            clusters
                .iter()
                .map(|c| {
                    let mean = c.iter().sum::<f64>() / c.len() as f64;
                    c.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
                })
                .sum()
        }
        let cases: Vec<(Vec<f64>, usize)> = vec![
            (vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0], 2),
            (vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 50.0, 51.0, 52.0], 3),
            (vec![1.0, 3.0, 5.0, 7.0, 9.0, 50.0, 52.0, 54.0], 2),
            (vec![1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0], 3),
            (vec![0.1, 0.2, 0.3, 0.4, 0.5, 10.0, 10.1, 10.2, 100.0, 100.1], 4),
        ];
        for (data, k) in cases {
            let log_result = KNLogN::classify(data.clone(), k).unwrap();
            let sq_result = KNSquared::classify(data.clone(), k).unwrap();
            let log_wcss = total_wcss(&log_result);
            let sq_wcss = total_wcss(&sq_result);
            assert!(
                (log_wcss - sq_wcss).abs() < 1e-9,
                "WCSS mismatch on data={data:?}, k={k}: KNLogN={log_wcss}, KNSquared={sq_wcss}"
            );
        }
    }

    // `classify_indices` must describe exactly the clusters `classify` returns.
    #[test]
    fn test_classify_indices_matches_classify() {
        let data = vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0];
        let ranges = KNLogN::classify_indices(&data, 2).unwrap();
        assert_eq!(ranges, vec![(0, 3), (3, 6)]);
        let clusters = KNLogN::classify(data.clone(), 2).unwrap();
        assert_eq!(ranges.len(), clusters.len());
        for ((start, end), cluster) in ranges.iter().zip(&clusters) {
            assert_eq!(&data[*start..*end], cluster.as_slice());
        }
    }

    // The `_with_sort` variants accept unsorted input; indices refer to the sorted order.
    #[test]
    fn test_with_sort_variants() {
        let unsorted = vec![12.0, 1.0, 11.0, 2.0, 10.0, 3.0];
        let clusters = KNLogN::classify_with_sort(unsorted.clone(), 2).unwrap();
        assert_eq!(clusters, vec![vec![1.0, 2.0, 3.0], vec![10.0, 11.0, 12.0]]);
        let ranges = KNLogN::classify_indices_with_sort(unsorted, 2).unwrap();
        assert_eq!(ranges, vec![(0, 3), (3, 6)]);
    }

    // A NaN must be reported as an error rather than panicking inside the sort.
    #[test]
    fn test_with_sort_rejects_nan() {
        let data = vec![1.0, f64::NAN, 2.0];
        assert!(matches!(
            KNLogN::classify_with_sort(data.clone(), 2),
            Err(crate::error::Error::NaNError)
        ));
        assert!(matches!(
            KNLogN::classify_indices_with_sort(data, 2),
            Err(crate::error::Error::NaNError)
        ));
    }
}