use num_traits::ToPrimitive;
use crate::{error::Error, util::validate_and_convert, ClassifiedResult, IndexRanges};
/// Optimal 1-D k-means clustering via an O(k·n²) dynamic program over contiguous segments.
///
/// This is a stateless type: all functionality is exposed through associated functions
/// such as `KNSquared::classify`.
#[derive(Debug, Clone, Copy, Default)]
pub struct KNSquared {}
/// Runs the dynamic program that splits `converted_data` into `k` contiguous segments
/// minimising the total within-segment sum of squared deviations from each segment's mean.
///
/// Returns the back-pointer table `b` (shape `n × k`), where `b[i][m]` is the start index of
/// the last segment in the optimal split of `converted_data[..=i]` into `m + 1` segments.
/// Only entries with `i >= m` are meaningful; column 0 is always 0.
///
/// Empty input or `k == 0` yields the trivial (all-zero) table instead of panicking.
fn compute_dp(converted_data: &[f64], k: usize) -> Vec<Vec<usize>> {
    let n = converted_data.len();
    let mut matrix_b: Vec<Vec<usize>> = vec![vec![0; k]; n];
    if n == 0 || k == 0 {
        return matrix_b;
    }

    // Row m of the DP only reads row m - 1, so keep just two cost columns instead of an
    // n × k table. `prev_cost[i]` is the best cost of splitting `converted_data[..=i]` into
    // m segments (the previous column); `cur_cost` receives the column being built.
    let mut prev_cost = vec![f64::INFINITY; n];
    let mut cur_cost = vec![f64::INFINITY; n];

    // Column 0: a single segment covering the prefix, whose cost is the running SSE.
    let mut mu = converted_data[0];
    let mut sse = 0f64;
    prev_cost[0] = 0.0;
    for (i, &xi) in converted_data.iter().enumerate().skip(1) {
        let count = i + 1;
        sse = d_next(sse, count, xi, mu);
        mu = mu_next(xi, count, mu);
        prev_cost[i] = sse;
    }

    for m in 1..k {
        for i in m..n {
            // Candidate j = i: the last segment is the singleton [i] (cost 0).
            // A NaN prefix cost is never accepted, so lowest_d stays at +inf in that case.
            let mut lowest_d = f64::INFINITY;
            let mut b = i;
            let singleton_cost = prev_cost[i - 1];
            if singleton_cost < lowest_d {
                lowest_d = singleton_cost;
            }

            // Grow the last segment leftwards, maintaining its SSE and mean incrementally.
            let mut seg_cost = 0f64;
            let mut seg_mean = converted_data[i];
            for j in (m..i).rev() {
                let count = i - j + 1;
                seg_cost = d_next(seg_cost, count, converted_data[j], seg_mean);
                seg_mean = mu_next(converted_data[j], count, seg_mean);
                // Extending a segment never lowers its SSE, and prefix costs are >= 0,
                // so once the segment alone is no better than the best total, stop.
                if seg_cost >= lowest_d {
                    break;
                }
                let candidate = prev_cost[j - 1] + seg_cost;
                if candidate < lowest_d {
                    lowest_d = candidate;
                    b = j;
                }
            }
            cur_cost[i] = lowest_d;
            matrix_b[i][m] = b;
        }
        // Entries of cur_cost below index m are stale after the swap, but the next column
        // (m + 1) only reads indices >= m.
        std::mem::swap(&mut prev_cost, &mut cur_cost);
    }
    matrix_b
}
fn backtrack_values<T: Clone>(data: &[T], matrix_b: &[Vec<usize>], k: usize) -> ClassifiedResult<T> {
let n = data.len();
let mut result: ClassifiedResult<T> = Vec::with_capacity(k);
let mut cluster_end = n;
let mut m = k - 1;
loop {
let cluster_start = matrix_b[cluster_end - 1][m];
result.push(data[cluster_start..cluster_end].to_vec());
if m == 0 {
break;
}
cluster_end = cluster_start;
m -= 1;
}
result.reverse();
result
}
/// Reconstructs the `k` segment boundaries from the back-pointer table built by `compute_dp`.
///
/// Walks backwards from the last segment, which ends at `n`: `matrix_b[end - 1][m]` is where
/// segment `m` starts, and that is also where segment `m - 1` ends. Returns half-open
/// `(start, end)` ranges in ascending order. `k == 0` yields no ranges; previously `k - 1`
/// underflowed in that case.
fn backtrack_indices(matrix_b: &[Vec<usize>], k: usize, n: usize) -> IndexRanges {
    let mut result: IndexRanges = Vec::with_capacity(k);
    let mut cluster_end = n;
    for m in (0..k).rev() {
        let cluster_start = matrix_b[cluster_end - 1][m];
        result.push((cluster_start, cluster_end));
        cluster_end = cluster_start;
    }
    // Segments were collected last-to-first.
    result.reverse();
    result
}
impl KNSquared {
pub fn classify<T>(data: Vec<T>, k: usize) -> Result<ClassifiedResult<T>, Error>
where
T: PartialOrd + Clone + ToPrimitive,
{
let converted_data = validate_and_convert(&data, k)?;
let matrix_b = compute_dp(&converted_data, k);
Ok(backtrack_values(&data, &matrix_b, k))
}
pub fn classify_indices<T>(data: &[T], k: usize) -> Result<IndexRanges, Error>
where
T: PartialOrd + Clone + ToPrimitive,
{
let converted_data = validate_and_convert(data, k)?;
let matrix_b = compute_dp(&converted_data, k);
Ok(backtrack_indices(&matrix_b, k, data.len()))
}
pub fn classify_indices_with_sort<T>(mut data: Vec<T>, k: usize) -> Result<IndexRanges, Error>
where
T: PartialOrd + Clone + ToPrimitive,
{
for window in data.windows(2) {
if window[0].partial_cmp(&window[1]).is_none() {
return Err(Error::NaNError);
}
}
data.sort_by(|a, b| a.partial_cmp(b).unwrap());
Self::classify_indices(&data, k)
}
pub fn classify_with_sort<T>(mut data: Vec<T>, k: usize) -> Result<ClassifiedResult<T>, Error>
where
T: PartialOrd + Clone + ToPrimitive,
{
for window in data.windows(2) {
if window[0].partial_cmp(&window[1]).is_none() {
return Err(Error::NaNError);
}
}
data.sort_by(|a, b| a.partial_cmp(b).unwrap());
Self::classify(data, k)
}
}
/// Incremental mean update: `mu_prev` is the mean of the first `count - 1` values, and the
/// result is the mean after also including `xi` (so `count` values in total).
fn mu_next(xi: f64, count: usize, mu_prev: f64) -> f64 {
    let prior_weight = (count - 1) as f64;
    let total_weight = count as f64;
    (prior_weight * mu_prev + xi) / total_weight
}
/// Incremental sum-of-squared-deviations update.
///
/// `d_prev` is the SSE of the first `count - 1` values and `mu_prev` is their mean. The
/// result is the SSE after adding `xi`: `d_prev + (count - 1) / count * (xi - mu_prev)^2`.
fn d_next(d_prev: f64, count: usize, xi: f64, mu_prev: f64) -> f64 {
    let shrink = (count - 1) as f64 / count as f64;
    let delta = xi - mu_prev;
    d_prev + shrink * delta.powi(2)
}
// Unit tests for the public `KNSquared` API, covering correctness of the optimal
// partition, validation errors, and the `*_with_sort` entry points.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_clustering() {
        let data = vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0];
        let result = KNSquared::classify(data, 2).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], vec![1.0, 2.0, 3.0]);
        assert_eq!(result[1], vec![10.0, 11.0, 12.0]);
    }

    #[test]
    fn test_three_clusters() {
        let data = vec![1, 2, 3, 10, 11, 12, 50, 51, 52];
        let result = KNSquared::classify(data, 3).unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], vec![1, 2, 3]);
        assert_eq!(result[1], vec![10, 11, 12]);
        assert_eq!(result[2], vec![50, 51, 52]);
    }

    #[test]
    fn test_k_equals_n() {
        let data = vec![5.0, 10.0, 15.0];
        let result = KNSquared::classify(data, 3).unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], vec![5.0]);
        assert_eq!(result[1], vec![10.0]);
        assert_eq!(result[2], vec![15.0]);
    }

    #[test]
    fn test_single_cluster() {
        let data = vec![1.0, 2.0, 3.0];
        let result = KNSquared::classify(data, 1).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_zero_clusters_error() {
        let data = vec![1.0, 2.0];
        let result = KNSquared::classify(data, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_n_less_than_k_error() {
        let data = vec![1.0, 2.0];
        let result = KNSquared::classify(data, 5);
        assert!(result.is_err());
    }

    #[test]
    fn test_classify_with_sort() {
        let data = vec![12.0, 1.0, 11.0, 2.0, 10.0, 3.0];
        let result = KNSquared::classify_with_sort(data, 2).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], vec![1.0, 2.0, 3.0]);
        assert_eq!(result[1], vec![10.0, 11.0, 12.0]);
    }

    #[test]
    fn test_classify_indices_with_sort() {
        // Ranges refer to positions in the sorted order [1, 2, 3, 10, 11, 12].
        let data = vec![12.0, 1.0, 11.0, 2.0, 10.0, 3.0];
        let result = KNSquared::classify_indices_with_sort(data, 2).unwrap();
        assert_eq!(result, vec![(0, 3), (3, 6)]);
    }

    #[test]
    fn test_with_sort_rejects_nan() {
        let result = KNSquared::classify_with_sort(vec![1.0, f64::NAN, 2.0], 2);
        assert!(matches!(result, Err(Error::NaNError)));
        let result = KNSquared::classify_indices_with_sort(vec![3.0, 1.0, f64::NAN], 2);
        assert!(matches!(result, Err(Error::NaNError)));
    }

    #[test]
    fn test_duplicates() {
        let data = vec![1.0, 1.0, 1.0, 5.0, 5.0, 5.0];
        let result = KNSquared::classify(data, 2).unwrap();
        assert_eq!(result[0], vec![1.0, 1.0, 1.0]);
        assert_eq!(result[1], vec![5.0, 5.0, 5.0]);
    }

    #[test]
    fn test_four_clusters() {
        let data = vec![1, 2, 10, 11, 20, 21, 30, 31];
        let result = KNSquared::classify(data, 4).unwrap();
        assert_eq!(result, vec![vec![1, 2], vec![10, 11], vec![20, 21], vec![30, 31]]);
    }

    #[test]
    fn test_classify_indices() {
        let data = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0];
        let result = KNSquared::classify_indices(&data, 2).unwrap();
        assert_eq!(result, vec![(0, 3), (3, 6)]);
        let data = [1, 2, 10, 11, 20, 21, 30, 31];
        let result = KNSquared::classify_indices(&data, 4).unwrap();
        assert_eq!(result, vec![(0, 2), (2, 4), (4, 6), (6, 8)]);
    }

    #[test]
    fn test_wcss_matches_optimal() {
        // Total within-cluster sum of squares for one cluster.
        fn wcss(cluster: &[f64]) -> f64 {
            let mean = cluster.iter().sum::<f64>() / cluster.len() as f64;
            cluster.iter().map(|x| (x - mean).powi(2)).sum()
        }
        let cases: Vec<(Vec<f64>, usize, f64)> = vec![
            (vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0], 2, 4.0),
            (vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 50.0, 51.0, 52.0], 3, 6.0),
            (vec![1.0, 3.0, 5.0, 7.0, 9.0, 50.0, 52.0, 54.0], 2, 48.0),
        ];
        for (data, k, expected_wcss) in cases {
            let result = KNSquared::classify(data, k).unwrap();
            let total_wcss: f64 = result.iter().map(|c| wcss(c)).sum();
            assert!(
                (total_wcss - expected_wcss).abs() < 1e-9,
                "k={k}: expected WCSS={expected_wcss}, got {total_wcss}"
            );
        }
    }
}