/// Disjoint-set (union-find) structure with union by size and full path
/// compression. `parent[x] == x` marks a root; `size[r]` is only meaningful
/// while `r` is a root, where it counts the members of that set.
#[derive(Clone, Debug)]
pub(crate) struct UnionFind {
    pub(crate) parent: Vec<usize>,
    pub(crate) size: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets `{0}, {1}, ..., {n-1}`.
    pub(crate) fn new(n: usize) -> Self {
        Self {
            parent: (0..n).collect(),
            size: vec![1; n],
        }
    }

    /// Returns the representative (root) of the set containing `x`,
    /// re-pointing every node on the walked path directly at the root.
    pub(crate) fn find(&mut self, x: usize) -> usize {
        // First pass: locate the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Second pass: compress the whole path onto the root so later
        // finds along it are a single hop.
        let mut node = x;
        while self.parent[node] != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }
        root
    }

    /// Merges the sets containing `a` and `b`; returns the surviving root.
    pub(crate) fn union(&mut self, a: usize, b: usize) -> usize {
        let root_a = self.find(a);
        let root_b = self.find(b);
        self.union_roots(root_a, root_b)
    }

    /// Merges two sets given their roots, attaching the smaller tree under
    /// the larger one (a size tie keeps `ra` as the root). Returns the
    /// surviving root.
    pub(crate) fn union_roots(&mut self, ra: usize, rb: usize) -> usize {
        if ra == rb {
            return ra;
        }
        let (winner, loser) = if self.size[ra] < self.size[rb] {
            (rb, ra)
        } else {
            (ra, rb)
        };
        self.parent[loser] = winner;
        self.size[winner] += self.size[loser];
        winner
    }
}
use super::distance::DistanceMetric;
use super::flat::DataRef;
use crate::error::{Error, Result};
use rand::prelude::*;
/// Checks that every value in `data` is finite.
///
/// Scans rows in order and returns an error naming the first entry that is
/// NaN or ±infinity; returns `Ok(())` when all values are finite (including
/// for an empty dataset).
pub(crate) fn validate_finite(data: &(impl DataRef + ?Sized)) -> Result<()> {
    for i in 0..data.n() {
        // `position` yields the column of the first offending value, if any.
        if let Some(j) = data.row(i).iter().position(|v| !v.is_finite()) {
            return Err(Error::Other(format!(
                "data[{i}][{j}] is not finite (NaN or infinity)"
            )));
        }
    }
    Ok(())
}
/// Mean of the per-dimension (population) variances of `data`.
///
/// Degenerate inputs — no rows, no columns, or (near-)constant data whose
/// mean variance falls below `f64::EPSILON` — return 1.0 so callers can
/// safely use the result as a divisor/scale.
pub(crate) fn mean_variance(data: &(impl DataRef + ?Sized)) -> f64 {
    let rows = data.n();
    let d = data.d();
    // Degenerate shapes: report a neutral scale of 1.0.
    if rows == 0 || d == 0 {
        return 1.0;
    }
    let n = rows as f64;
    let mut total_var = 0.0f64;
    for j in 0..d {
        // Two-pass variance per dimension: mean first, then the average of
        // squared deviations (divisor n, i.e. population variance).
        let mut sum = 0.0f64;
        for i in 0..rows {
            sum += data.row(i)[j] as f64;
        }
        let mean = sum / n;
        let mut sq_dev = 0.0f64;
        for i in 0..rows {
            let diff = data.row(i)[j] as f64 - mean;
            sq_dev += diff * diff;
        }
        total_var += sq_dev / n;
    }
    let mv = total_var / d as f64;
    // Clamp away from zero so the value is safe to divide by.
    if mv < f64::EPSILON {
        1.0
    } else {
        mv
    }
}
/// Relaxes each point's running "distance to nearest chosen centroid"
/// against a newly added `centroid`:
/// `min_dists[i] = min(min_dists[i], metric.distance(row_i, centroid))`.
///
/// Distances are clamped at 0.0 because some metrics can return tiny
/// negative values from floating-point cancellation.
fn update_min_dists_for_centroid<D: DistanceMetric>(
    data: &(impl DataRef + ?Sized),
    centroid: &[f32],
    min_dists: &mut [f32],
    metric: &D,
) {
    // Parallel fast path for large inputs; 20_000 rows is the heuristic
    // cutover below which per-task overhead would dominate the speedup.
    #[cfg(feature = "parallel")]
    if data.n() >= 20_000 {
        use rayon::prelude::*;
        min_dists.par_iter_mut().enumerate().for_each(|(i, md)| {
            let d = metric.distance(data.row(i), centroid).max(0.0);
            if d < *md {
                *md = d;
            }
        });
        return;
    }
    // Serial path (also taken for small inputs when "parallel" is enabled).
    for (i, md) in min_dists.iter_mut().enumerate() {
        let d = metric.distance(data.row(i), centroid).max(0.0);
        if d < *md {
            *md = d;
        }
    }
}
/// k-means++ initialization with a power-weighting exponent.
///
/// Picks the first centroid uniformly at random, then repeatedly samples
/// the next one with probability proportional to `min_dist^(alpha / 2)`;
/// plain k-means++ corresponds to weights equal to the distances reported
/// by `metric`. If a round's weight total is zero or non-finite (all
/// remaining distances zero, or a weight overflowed), that round falls
/// back to a uniform random pick.
///
/// Fixes applied in review:
/// - restored `&centroids[0]` (the `&c` had been garbled into `¢`,
///   which does not compile);
/// - the roulette-wheel fallback index is now the last point instead of
///   point 0 (float rounding can leave `cumsum` just below `threshold`,
///   and defaulting to index 0 silently biased selection).
pub(crate) fn kmeanspp_init<D: DistanceMetric>(
    data: &(impl DataRef + ?Sized),
    k: usize,
    metric: &D,
    alpha: f32,
    rng: &mut StdRng,
) -> Vec<Vec<f32>> {
    let n = data.n();
    let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(k);
    let mut min_dists = vec![f32::MAX; n];
    // First centroid: uniform random data point.
    let first = rng.random_range(0..n);
    centroids.push(data.row(first).to_vec());
    update_min_dists_for_centroid(data, &centroids[0], &mut min_dists, metric);
    let exp = alpha / 2.0;
    // Skip the (comparatively expensive) powf when the exponent is 1.
    let identity_exp = (exp - 1.0).abs() < f32::EPSILON;
    for _ in 1..k {
        let total: f32 = if identity_exp {
            min_dists.iter().sum()
        } else {
            min_dists.iter().map(|&d| d.powf(exp)).sum()
        };
        if total == 0.0 || !total.is_finite() {
            // Degenerate weights: fall back to a uniform random pick.
            let idx = rng.random_range(0..n);
            centroids.push(data.row(idx).to_vec());
            let new_c = centroids.last().unwrap();
            update_min_dists_for_centroid(data, new_c, &mut min_dists, metric);
            continue;
        }
        // Roulette-wheel selection over the weights.
        let threshold = rng.random::<f32>() * total;
        let mut cumsum = 0.0f32;
        // Fallback to the last index in case rounding keeps cumsum below
        // the threshold for every point.
        let mut selected = n - 1;
        for (j, &d) in min_dists.iter().enumerate() {
            let w = if identity_exp { d } else { d.powf(exp) };
            cumsum += w;
            if cumsum >= threshold {
                selected = j;
                break;
            }
        }
        centroids.push(data.row(selected).to_vec());
        let new_c = centroids.last().unwrap();
        update_min_dists_for_centroid(data, new_c, &mut min_dists, metric);
    }
    centroids
}
/// Index of the centroid nearest to `point` under `metric`.
///
/// Ties keep the earlier index (strict `<` comparison); an empty centroid
/// slice yields 0.
pub(crate) fn assign_nearest<D: DistanceMetric>(
    point: &[f32],
    centroids: &[Vec<f32>],
    metric: &D,
) -> usize {
    centroids
        .iter()
        .enumerate()
        .fold((0usize, f32::MAX), |(best_k, best_d), (k, c)| {
            let d = metric.distance(point, c);
            if d < best_d {
                (k, d)
            } else {
                (best_k, best_d)
            }
        })
        .0
}
#[cfg(not(feature = "parallel"))]
pub(crate) fn geometric_assign<D: DistanceMetric>(
data: &(impl DataRef + ?Sized),
centroids: &[Vec<f32>],
labels: &mut [usize],
centroid_shifts: &[f32],
metric: &D,
first_iter: bool,
) {
let n = data.n();
let k = centroids.len();
if first_iter || k <= 1 {
#[allow(clippy::needless_range_loop)] for i in 0..n {
let mut best = f32::MAX;
let mut best_k = 0;
for (j, c) in centroids.iter().enumerate() {
let d = metric.distance(data.row(i), c);
if d < best {
best = d;
best_k = j;
}
}
labels[i] = best_k;
}
return;
}
let mut half_inter = vec![vec![0.0f32; k]; k];
for j1 in 0..k {
for j2 in (j1 + 1)..k {
let d = metric.distance(¢roids[j1], ¢roids[j2]) * 0.5;
half_inter[j1][j2] = d;
half_inter[j2][j1] = d;
}
}
let use_projection = metric.supports_expanded_form();
let centroid_sq_norms: Vec<f32> = if use_projection {
centroids
.iter()
.map(|c| c.iter().map(|&v| v * v).sum())
.collect()
} else {
Vec::new()
};
#[allow(clippy::needless_range_loop)] for i in 0..n {
let a = labels[i];
let a_shift = centroid_shifts[a];
let mut can_skip = true;
for j in 0..k {
if j == a {
continue;
}
if half_inter[a][j] <= a_shift + centroid_shifts[j] {
can_skip = false;
break;
}
}
if can_skip {
continue;
}
let point = data.row(i);
let dist_a = metric.distance(point, ¢roids[a]);
let mut best = dist_a;
let mut best_k = a;
for j in 0..k {
if j == a {
continue;
}
if half_inter[a][j] > a_shift + centroid_shifts[j] {
continue;
}
if use_projection {
let ca = ¢roids[a];
let cj = ¢roids[j];
let mut dot_x_dir = 0.0f32;
for (idx, &xv) in point.iter().enumerate() {
dot_x_dir += xv * (cj[idx] - ca[idx]);
}
let bias = (centroid_sq_norms[j] - centroid_sq_norms[a]) * 0.5;
if dot_x_dir <= bias {
continue;
}
}
let d = metric.distance(point, ¢roids[j]);
if d < best {
best = d;
best_k = j;
}
}
labels[i] = best_k;
}
}
/// Hamerly-accelerated assignment step (parallel build).
///
/// Per point, `upper[i]` bounds the distance to its assigned centroid from
/// above and `lower[i]` bounds the distance to the second-closest centroid
/// from below. After the centroids move by `centroid_shifts`, both bounds
/// are loosened; a full scan over all k centroids is only performed for
/// points whose loosened bounds cannot prove the label is unchanged.
///
/// `flat_buf` is caller-provided scratch; centroids are flattened into it
/// so the hot loop reads one contiguous slice.
///
/// NOTE(review): the bound arithmetic assumes `metric` behaves like a true
/// metric (triangle inequality) — confirm this holds for the metrics used
/// with this path (e.g. squared distances).
#[cfg(feature = "parallel")]
#[allow(clippy::too_many_arguments)]
pub(crate) fn hamerly_assign_parallel<D: DistanceMetric>(
    data: &(impl DataRef + ?Sized),
    centroids: &[Vec<f32>],
    labels: &mut [usize],
    upper: &mut [f32],
    lower: &mut [f32],
    centroid_shifts: &[f32],
    metric: &D,
    first_iter: bool,
    flat_buf: &mut Vec<f32>,
) {
    use rayon::prelude::*;
    let n = data.n();
    let k = centroids.len();
    let dim = if k > 0 { centroids[0].len() } else { 0 };
    // Flatten centroids into one contiguous buffer for cache-friendly reads.
    flat_buf.clear();
    flat_buf.extend(centroids.iter().flat_map(|c| c.iter().copied()));
    let flat_centroids = &*flat_buf;
    if first_iter || k <= 1 {
        // No usable bounds yet: full scan, recording best and second-best
        // distances to seed upper/lower.
        let results: Vec<(usize, f32, f32)> = (0..n)
            .into_par_iter()
            .map(|i| {
                let mut best = f32::MAX;
                let mut second = f32::MAX;
                let mut best_k = 0;
                for j in 0..k {
                    let c = &flat_centroids[j * dim..(j + 1) * dim];
                    let d = metric.distance(data.row(i), c);
                    if d < best {
                        second = best;
                        best = d;
                        best_k = j;
                    } else if d < second {
                        second = d;
                    }
                }
                (best_k, best, second)
            })
            .collect();
        for (i, (lbl, u, l)) in results.into_iter().enumerate() {
            labels[i] = lbl;
            upper[i] = u;
            lower[i] = l;
        }
        return;
    }
    // Largest and second-largest centroid shifts. A point's lower bound must
    // be loosened by the largest shift among centroids *other than* its own,
    // which is `second_max_shift` when its own centroid is the max shifter.
    let max_shift = centroid_shifts.iter().copied().fold(0.0f32, f32::max);
    let mut max_shift_idx = 0;
    for (j, &s) in centroid_shifts.iter().enumerate() {
        // `>=` keeps the last index attaining the maximum.
        if s >= max_shift {
            max_shift_idx = j;
        }
    }
    let mut second_max_shift = 0.0f32;
    for (j, &s) in centroid_shifts.iter().enumerate() {
        if j != max_shift_idx && s > second_max_shift {
            second_max_shift = s;
        }
    }
    let updates: Vec<(usize, f32, f32)> = (0..n)
        .into_par_iter()
        .map(|i| {
            let old_label = labels[i];
            // Loosen both bounds by how far the relevant centroids moved.
            let mut u = upper[i] + centroid_shifts[old_label];
            let relevant_max = if old_label == max_shift_idx {
                second_max_shift
            } else {
                max_shift
            };
            let l = lower[i] - relevant_max;
            // First filter: loosened bounds still prove the label holds.
            if u <= l {
                return (old_label, u, l);
            }
            // Tighten the upper bound to the exact distance and retry.
            let old_c = &flat_centroids[old_label * dim..(old_label + 1) * dim];
            u = metric.distance(data.row(i), old_c);
            if u <= l {
                return (old_label, u, l);
            }
            // Bounds failed: full scan tracking best and second-best.
            let mut best = u;
            let mut second = f32::MAX;
            let mut best_k = old_label;
            for j in 0..k {
                if j == old_label {
                    // The exact distance to the current centroid also
                    // competes for the second-best slot.
                    if best < second {
                        second = best;
                    }
                    continue;
                }
                let c = &flat_centroids[j * dim..(j + 1) * dim];
                let dist = metric.distance(data.row(i), c);
                if dist < best {
                    second = best;
                    best = dist;
                    best_k = j;
                } else if dist < second {
                    second = dist;
                }
            }
            (best_k, best, second)
        })
        .collect();
    // Scatter the results back serially (labels/upper/lower are &mut slices).
    for (i, (lbl, u, l)) in updates.into_iter().enumerate() {
        labels[i] = lbl;
        upper[i] = u;
        lower[i] = l;
    }
}
#[cfg(not(feature = "parallel"))]
#[allow(clippy::too_many_arguments)]
pub(crate) fn hamerly_assign<D: DistanceMetric>(
data: &(impl DataRef + ?Sized),
centroids: &[Vec<f32>],
labels: &mut [usize],
upper: &mut [f32],
lower: &mut [f32],
centroid_shifts: &[f32],
metric: &D,
first_iter: bool,
flat_buf: &mut Vec<f32>,
) -> usize {
let n = data.n();
let k = centroids.len();
let dim = if k > 0 { centroids[0].len() } else { 0 };
let mut recomputed = 0;
let use_flat = k >= 16;
flat_buf.clear();
if use_flat {
flat_buf.extend(centroids.iter().flat_map(|c| c.iter().copied()));
}
let flat_centroids = &*flat_buf;
let batch_dist = |point: &[f32], centroid_idx: usize| -> f32 {
if use_flat {
let c = &flat_centroids[centroid_idx * dim..(centroid_idx + 1) * dim];
metric.distance(point, c)
} else {
metric.distance(point, ¢roids[centroid_idx])
}
};
if first_iter || k <= 1 {
for i in 0..n {
let mut best = f32::MAX;
let mut second = f32::MAX;
let mut best_k = 0;
for j in 0..k {
let dist = batch_dist(data.row(i), j);
if dist < best {
second = best;
best = dist;
best_k = j;
} else if dist < second {
second = dist;
}
}
labels[i] = best_k;
upper[i] = best;
lower[i] = second;
}
return n;
}
let max_shift = centroid_shifts.iter().copied().fold(0.0f32, f32::max);
let mut second_max_shift = 0.0f32;
let mut max_shift_idx = 0;
for (j, &s) in centroid_shifts.iter().enumerate() {
if s >= max_shift {
max_shift_idx = j;
}
}
for (j, &s) in centroid_shifts.iter().enumerate() {
if j != max_shift_idx && s > second_max_shift {
second_max_shift = s;
}
}
for i in 0..n {
upper[i] += centroid_shifts[labels[i]];
let relevant_max = if labels[i] == max_shift_idx {
second_max_shift
} else {
max_shift
};
lower[i] -= relevant_max;
if upper[i] <= lower[i] {
continue;
}
upper[i] = batch_dist(data.row(i), labels[i]);
if upper[i] <= lower[i] {
continue;
}
recomputed += 1;
let mut best = upper[i];
let mut second = f32::MAX;
let mut best_k = labels[i];
for j in 0..k {
if j == labels[i] {
if best < second {
second = best;
}
continue;
}
let dist = batch_dist(data.row(i), j);
if dist < best {
second = best;
best = dist;
best_k = j;
} else if dist < second {
second = dist;
}
}
labels[i] = best_k;
upper[i] = best;
lower[i] = second;
}
recomputed
}
/// Squared Euclidean norm `||x||^2` of every row in `data`, in row order.
pub(crate) fn squared_norms(data: &(impl DataRef + ?Sized)) -> Vec<f32> {
    let n = data.n();
    let mut norms = Vec::with_capacity(n);
    for i in 0..n {
        let mut acc = 0.0f32;
        for &x in data.row(i) {
            acc += x * x;
        }
        norms.push(acc);
    }
    norms
}
/// Initial assignment via the expanded squared-Euclidean form
/// `||x - c||^2 = ||x||^2 + ||c||^2 - 2 * (x . c)`, using precomputed
/// `data_norms` (`||x||^2` per row) and `centroid_norms` (`||c||^2`).
///
/// Returns `(labels, upper, lower)` where `upper[i]` is the best distance
/// and `lower[i]` the second-best — seed bounds for the Hamerly-style
/// assignment steps.
#[cfg(not(feature = "parallel"))]
pub(crate) fn assign_expanded(
    data: &(impl DataRef + ?Sized),
    centroids: &[Vec<f32>],
    data_norms: &[f32],
    centroid_norms: &[f32],
) -> (Vec<usize>, Vec<f32>, Vec<f32>) {
    let n = data.n();
    let k = centroids.len();
    let dim = if k > 0 { centroids[0].len() } else { 0 };
    let mut labels = vec![0usize; n];
    let mut upper = vec![f32::MAX; n];
    let mut lower = vec![f32::MAX; n];
    // Contiguous centroid storage for cache-friendly inner loops.
    let flat_c: Vec<f32> = centroids.iter().flat_map(|c| c.iter().copied()).collect();
    for i in 0..n {
        let xn = data_norms[i];
        let mut best_dist = f32::MAX;
        let mut second_dist = f32::MAX;
        let mut best_k = 0;
        let point = data.row(i);
        for j in 0..k {
            let cn = centroid_norms[j];
            let c_slice = &flat_c[j * dim..(j + 1) * dim];
            // SIMD dot product for wide rows; scalar fallback otherwise.
            #[cfg(feature = "simd")]
            let dot = if dim >= 16 {
                innr::dot(point, c_slice)
            } else {
                point.iter().zip(c_slice).map(|(&a, &b)| a * b).sum()
            };
            #[cfg(not(feature = "simd"))]
            let dot: f32 = point.iter().zip(c_slice).map(|(&a, &b)| a * b).sum();
            // Clamp at 0: cancellation in the expanded form can dip negative.
            let dist = (xn + cn - 2.0 * dot).max(0.0);
            if dist < best_dist {
                second_dist = best_dist;
                best_dist = dist;
                best_k = j;
            } else if dist < second_dist {
                second_dist = dist;
            }
        }
        labels[i] = best_k;
        upper[i] = best_dist;
        lower[i] = second_dist;
    }
    (labels, upper, lower)
}
/// Parallel counterpart of `assign_expanded`: initial assignment via the
/// expanded squared-Euclidean form
/// `||x - c||^2 = ||x||^2 + ||c||^2 - 2 * (x . c)` with precomputed norms,
/// mapped over the rows with rayon.
///
/// Returns `(labels, upper, lower)` where `upper[i]` is the best distance
/// and `lower[i]` the second-best.
#[cfg(feature = "parallel")]
pub(crate) fn assign_expanded_parallel(
    data: &(impl DataRef + ?Sized),
    centroids: &[Vec<f32>],
    data_norms: &[f32],
    centroid_norms: &[f32],
) -> (Vec<usize>, Vec<f32>, Vec<f32>) {
    use rayon::prelude::*;
    let k = centroids.len();
    let dim = if k > 0 { centroids[0].len() } else { 0 };
    // Contiguous centroid storage for cache-friendly inner loops.
    let flat_c: Vec<f32> = centroids.iter().flat_map(|c| c.iter().copied()).collect();
    let results: Vec<(usize, f32, f32)> = (0..data.n())
        .into_par_iter()
        .map(|i| {
            let point = data.row(i);
            let xn = data_norms[i];
            let mut best_dist = f32::MAX;
            let mut second_dist = f32::MAX;
            let mut best_k = 0;
            for j in 0..k {
                let cn = centroid_norms[j];
                let c_slice = &flat_c[j * dim..(j + 1) * dim];
                // SIMD dot product for wide rows; scalar fallback otherwise.
                #[cfg(feature = "simd")]
                let dot = if dim >= 16 {
                    innr::dot(point, c_slice)
                } else {
                    point.iter().zip(c_slice).map(|(&a, &b)| a * b).sum()
                };
                #[cfg(not(feature = "simd"))]
                let dot: f32 = point.iter().zip(c_slice).map(|(&a, &b)| a * b).sum();
                // Clamp at 0: cancellation in the expanded form can dip negative.
                let dist = (xn + cn - 2.0 * dot).max(0.0);
                if dist < best_dist {
                    second_dist = best_dist;
                    best_dist = dist;
                    best_k = j;
                } else if dist < second_dist {
                    second_dist = dist;
                }
            }
            (best_k, best_dist, second_dist)
        })
        .collect();
    // Scatter the per-row results into the output vectors serially.
    let n = data.n();
    let mut labels = vec![0usize; n];
    let mut upper = vec![0.0f32; n];
    let mut lower = vec![0.0f32; n];
    for (i, (lbl, u, l)) in results.into_iter().enumerate() {
        labels[i] = lbl;
        upper[i] = u;
        lower[i] = l;
    }
    (labels, upper, lower)
}
/// Dense symmetric n×n distance matrix in row-major order
/// (`dists[i * n + j]`), with a zero diagonal.
///
/// Only the upper triangle is computed; each value is mirrored into the
/// lower triangle. With the "parallel" feature, the upper-triangle rows
/// are computed on rayon workers and mirrored serially afterwards.
pub(crate) fn pairwise_distance_matrix<D: DistanceMetric>(
    data: &(impl DataRef + ?Sized),
    metric: &D,
) -> Vec<f32> {
    let n = data.n();
    let mut dists = vec![0.0f32; n * n];
    #[cfg(feature = "parallel")]
    {
        use rayon::prelude::*;
        // Row i computes only columns j > i (upper triangle).
        let rows: Vec<Vec<(usize, f32)>> = (0..n)
            .into_par_iter()
            .map(|i| {
                ((i + 1)..n)
                    .map(|j| (j, metric.distance(data.row(i), data.row(j))))
                    .collect()
            })
            .collect();
        for (i, row) in rows.into_iter().enumerate() {
            for (j, d) in row {
                dists[i * n + j] = d;
                dists[j * n + i] = d;
            }
        }
    }
    #[cfg(not(feature = "parallel"))]
    for i in 0..n {
        for j in (i + 1)..n {
            let d = metric.distance(data.row(i), data.row(j));
            dists[i * n + j] = d;
            dists[j * n + i] = d;
        }
    }
    dists
}
/// Prim's algorithm over the implicit complete graph whose edge weights are
/// given by `dist_fn`.
///
/// Returns MST edges as `(parent, child, weight)` triples — up to `n - 1`
/// of them; vertices unreachable under `dist_fn` (infinite distance) are
/// left out. `best[v]` holds the lightest known edge weight connecting `v`
/// to the tree, `parent[v]` the tree endpoint of that edge, and `next_u`
/// the vertex to absorb next.
pub(crate) fn prim_mst(
    n: usize,
    dist_fn: impl Fn(usize, usize) -> f32 + Sync,
) -> Vec<(usize, usize, f32)> {
    if n <= 1 {
        return Vec::new();
    }
    let mut in_tree = vec![false; n];
    let mut best = vec![f32::INFINITY; n];
    let mut parent = vec![usize::MAX; n];
    best[0] = 0.0;
    let mut next_u = 0usize;
    for _ in 0..n {
        let u = next_u;
        // An infinite key on a non-start vertex means the remaining graph
        // is unreachable; stop growing the tree.
        if best[u] == f32::INFINITY && u != 0 {
            break;
        }
        in_tree[u] = true;
        // Parallel relaxation for large n: each chunk records its proposed
        // key updates, applied serially afterwards (no shared mutation
        // inside the parallel map).
        #[cfg(feature = "parallel")]
        if n >= 5000 {
            use rayon::prelude::*;
            let chunk_size = (n / rayon::current_num_threads().max(1)).max(256);
            #[allow(clippy::type_complexity)]
            let results: Vec<(usize, f32, Vec<(usize, f32, usize)>)> = (0..n)
                .collect::<Vec<_>>()
                .par_chunks(chunk_size)
                .map(|chunk| {
                    let mut local_best_v = usize::MAX;
                    let mut local_best_val = f32::INFINITY;
                    let mut updates = Vec::new();
                    for &v in chunk {
                        if in_tree[v] {
                            continue;
                        }
                        let d = dist_fn(u, v);
                        if d < best[v] {
                            updates.push((v, d, u));
                        }
                        let current = if d < best[v] { d } else { best[v] };
                        if current < local_best_val {
                            local_best_val = current;
                            local_best_v = v;
                        }
                    }
                    (local_best_v, local_best_val, updates)
                })
                .collect();
            // Apply the collected key updates.
            for (_, _, updates) in &results {
                for &(v, d, p) in updates {
                    if d < best[v] {
                        best[v] = d;
                        parent[v] = p;
                    }
                }
            }
            next_u = usize::MAX;
            let mut next_best = f32::INFINITY;
            // Seed the next-vertex search from the per-chunk minima...
            for &(v, _val, _) in &results {
                if v != usize::MAX && best[v] < next_best {
                    next_best = best[v];
                    next_u = v;
                }
            }
            // ...then confirm against every out-of-tree vertex.
            // NOTE(review): this full scan independently re-derives the
            // minimum, making the per-chunk seeding above redundant — a
            // candidate for simplification.
            for v in 0..n {
                if !in_tree[v] && best[v] < next_best {
                    next_best = best[v];
                    next_u = v;
                }
            }
            if next_u == usize::MAX {
                break;
            }
            continue;
        }
        // Serial path: relax keys and pick the next vertex in one pass.
        let mut next_best = f32::INFINITY;
        next_u = usize::MAX;
        for v in 0..n {
            if in_tree[v] {
                continue;
            }
            let d = dist_fn(u, v);
            if d < best[v] {
                best[v] = d;
                parent[v] = u;
            }
            if best[v] < next_best {
                next_best = best[v];
                next_u = v;
            }
        }
        if next_u == usize::MAX {
            break;
        }
    }
    // Emit edges; vertex 0 is the root and has no parent edge. `best[v]` is
    // frozen once v joins the tree, so it is the accepted edge's weight.
    let mut edges: Vec<(usize, usize, f32)> = Vec::with_capacity(n - 1);
    for v in 1..n {
        let u = parent[v];
        if u != usize::MAX {
            edges.push((u, v, best[v]));
        }
    }
    edges
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cluster::distance::{Euclidean, SquaredEuclidean};

    #[test]
    fn validate_finite_accepts_good_data() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        assert!(validate_finite(&data).is_ok());
    }

    #[test]
    fn validate_finite_rejects_nan() {
        let data = vec![vec![1.0, f32::NAN]];
        assert!(validate_finite(&data).is_err());
    }

    #[test]
    fn validate_finite_rejects_inf() {
        let data = vec![vec![f32::INFINITY, 0.0]];
        assert!(validate_finite(&data).is_err());
    }

    #[test]
    fn validate_finite_empty_is_ok() {
        let data: Vec<Vec<f32>> = vec![];
        assert!(validate_finite(&data).is_ok());
    }

    #[test]
    fn mean_variance_constant_data() {
        let data = vec![vec![5.0, 5.0]; 10];
        let mv = mean_variance(&data);
        assert!(
            (mv - 1.0).abs() < 1e-6,
            "constant data should give clamped variance 1.0, got {mv}"
        );
    }

    #[test]
    fn mean_variance_spread_data() {
        let data = vec![vec![0.0], vec![10.0]];
        let mv = mean_variance(&data);
        assert!(mv > 0.0, "spread data should have positive variance");
    }

    #[test]
    fn squared_norms_basic() {
        let data = vec![vec![3.0, 4.0]];
        let norms = squared_norms(&data);
        assert!((norms[0] - 25.0).abs() < 1e-6);
    }

    #[test]
    fn assign_nearest_finds_closest() {
        let point = &[0.0, 0.0];
        let centroids = vec![vec![10.0, 10.0], vec![0.1, 0.1], vec![5.0, 5.0]];
        // Fixed in review: `&centroids` had been garbled into `¢roids`.
        let idx = assign_nearest(point, &centroids, &SquaredEuclidean);
        assert_eq!(idx, 1);
    }

    #[test]
    fn prim_mst_triangle() {
        let points = [vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];
        let mst = prim_mst(3, |i, j| Euclidean.distance(&points[i], &points[j]));
        assert_eq!(mst.len(), 2, "MST of 3 nodes has 2 edges");
        let total_weight: f32 = mst.iter().map(|(_, _, w)| w).sum();
        assert!(
            (total_weight - 2.0).abs() < 1e-5,
            "MST weight should be 2.0, got {total_weight}"
        );
    }

    #[test]
    fn prim_mst_single_node() {
        let mst = prim_mst(1, |_, _| 0.0);
        assert!(mst.is_empty());
    }

    #[test]
    fn prim_mst_two_nodes() {
        let mst = prim_mst(2, |_, _| 5.0);
        assert_eq!(mst.len(), 1);
        assert!((mst[0].2 - 5.0).abs() < 1e-6);
    }

    #[test]
    fn pairwise_distance_matrix_symmetric() {
        let data = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];
        let dists = pairwise_distance_matrix(&data, &Euclidean);
        let n = 3;
        for i in 0..n {
            for j in 0..n {
                assert!(
                    (dists[i * n + j] - dists[j * n + i]).abs() < 1e-6,
                    "dist matrix not symmetric at ({i},{j})"
                );
            }
            assert!(
                dists[i * n + i].abs() < 1e-6,
                "diagonal should be 0 at ({i},{i})"
            );
        }
    }

    #[test]
    fn union_find_basic() {
        let mut uf = UnionFind::new(5);
        assert_ne!(uf.find(0), uf.find(1));
        uf.union(0, 1);
        assert_eq!(uf.find(0), uf.find(1));
        uf.union(2, 3);
        uf.union(0, 2);
        assert_eq!(uf.find(0), uf.find(3));
    }

    #[test]
    fn kmeanspp_init_correct_count() {
        let data = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![10.0, 10.0],
            vec![11.0, 10.0],
        ];
        let mut rng = StdRng::seed_from_u64(42);
        let centroids = kmeanspp_init(&data, 3, &SquaredEuclidean, 2.0, &mut rng);
        assert_eq!(centroids.len(), 3);
        // Fixed in review: `&centroids` had been garbled into `¢roids`.
        for c in &centroids {
            assert!(data.contains(c), "centroid {:?} not in dataset", c);
        }
    }

    #[test]
    fn kmeanspp_init_distinct() {
        let data = vec![vec![0.0, 0.0], vec![100.0, 0.0], vec![0.0, 100.0]];
        let mut rng = StdRng::seed_from_u64(42);
        let centroids = kmeanspp_init(&data, 3, &SquaredEuclidean, 2.0, &mut rng);
        let unique: std::collections::HashSet<Vec<u32>> = centroids
            .iter()
            .map(|c| c.iter().map(|&x| x.to_bits()).collect())
            .collect();
        assert_eq!(
            unique.len(),
            3,
            "k-means++ should pick 3 distinct centroids"
        );
    }
}