mod error;
pub use error::{KMeansError, KMeansInitError};
use crate::algorithm::Vector2DOps;
use crate::line_measures::Distance;
use crate::line_measures::metric_spaces::Euclidean;
use crate::{Centroid, GeoFloat, MultiPoint, Point};
use rand::RngExt;
use rand::SeedableRng;
use rand::distr::weighted::WeightedIndex;
use rand::prelude::Distribution;
/// Tunable settings for a k-means run.
///
/// Values are set through the constructors and builder-style methods in the
/// `impl` block for this type.
#[derive(Debug, Clone)]
pub struct KMeansParams<T: GeoFloat> {
/// Number of clusters requested.
k: usize,
/// Seed handed to the random number generator during centroid initialisation;
/// `None` when no explicit seed was supplied.
seed: Option<u64>,
/// Upper limit on the number of refinement iterations.
max_iter: usize,
/// Convergence threshold: iteration stops once the largest centroid shift
/// falls below this value.
tolerance: T,
/// Optional radius limit; a cluster whose farthest member lies beyond it is
/// split after the main clustering pass.
max_radius: Option<T>,
}
impl<T: GeoFloat> KMeansParams<T> {
    /// Builds parameters for `k` clusters with no explicit seed.
    ///
    /// Every other setting takes the same default as [`Self::new_with_seed`].
    /// This constructor is not compiled for `wasm` targets with an unknown OS.
    #[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
    pub fn new(k: usize) -> Self {
        // The seed passed here is a throwaway: it is replaced by `None` straight away.
        Self {
            seed: None,
            ..Self::new_with_seed(k, 42)
        }
    }

    /// Builds parameters for `k` clusters, seeding random initialisation with `seed`.
    ///
    /// Defaults: at most 300 iterations, a tolerance of `0.0001`, no radius limit.
    ///
    /// # Panics
    ///
    /// Panics if `0.0001` cannot be converted into `T`.
    pub fn new_with_seed(k: usize, seed: u64) -> Self {
        let tolerance = T::from(0.0001).expect("tolerance must be representable in float type");
        Self {
            k,
            seed: Some(seed),
            max_iter: 300,
            tolerance,
            max_radius: None,
        }
    }

    /// Returns the parameters with the random seed replaced by `seed`.
    pub fn seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Returns the parameters with the iteration limit replaced by `max_iter`.
    pub fn max_iter(self, max_iter: usize) -> Self {
        Self { max_iter, ..self }
    }

    /// Returns the parameters with the convergence threshold replaced by `tolerance`.
    pub fn tolerance(self, tolerance: T) -> Self {
        Self { tolerance, ..self }
    }

    /// Returns the parameters with a cluster radius limit of `max_radius`.
    pub fn max_radius(self, max_radius: T) -> Self {
        Self {
            max_radius: Some(max_radius),
            ..self
        }
    }
}
/// K-means clustering over a collection of points.
///
/// Implementors supply [`KMeans::kmeans_with_params`]; the other methods are
/// conveniences that build the parameters and forward to it.
pub trait KMeans<T>
where
    T: GeoFloat,
{
    /// Clusters into `k` groups using default parameters and no explicit seed.
    ///
    /// Not compiled for `wasm` targets with an unknown OS.
    #[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
    fn kmeans(&self, k: usize) -> Result<Vec<usize>, KMeansError<T>> {
        let params = KMeansParams::new(k);
        self.kmeans_with_params(params)
    }

    /// Clusters into `k` groups using default parameters and the given `seed`.
    fn kmeans_with_seed(&self, k: usize, seed: u64) -> Result<Vec<usize>, KMeansError<T>> {
        let params = KMeansParams::new_with_seed(k, seed);
        self.kmeans_with_params(params)
    }

    /// Clusters according to `params`, returning cluster labels or a [`KMeansError`].
    fn kmeans_with_params(&self, params: KMeansParams<T>) -> Result<Vec<usize>, KMeansError<T>>;
}
/// Internal run configuration: the caller-supplied parameters plus a setting
/// that is not exposed through `KMeansParams`.
#[derive(Debug, Clone)]
struct KMeansConfig<T: GeoFloat> {
/// Parameters supplied by the caller.
params: KMeansParams<T>,
/// Depth budget passed to `apply_max_radius_constraint` as `remaining_depth`.
max_split_depth: usize,
}
#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
impl<T: GeoFloat> Default for KMeansConfig<T> {
    /// Wraps unseeded parameters with `k = 0` and a split depth of 10.
    fn default() -> Self {
        // The `From` conversion supplies the same split depth (10) as a literal would.
        Self::from(KMeansParams::<T>::new(0))
    }
}
impl<T: GeoFloat> From<KMeansParams<T>> for KMeansConfig<T> {
    /// Wraps caller-supplied parameters, fixing the split depth at 10.
    fn from(params: KMeansParams<T>) -> Self {
        let max_split_depth = 10;
        Self {
            max_split_depth,
            params,
        }
    }
}
fn kmeans_impl<T>(
points: &[Point<T>],
params: KMeansParams<T>,
) -> Result<Vec<usize>, KMeansError<T>>
where
T: GeoFloat,
{
let n = points.len();
let k = params.k;
if n == 0 {
return Ok(Vec::new());
}
if k == 0 || k > n {
return Err(KMeansError::InvalidK { k, n });
}
if k == 1 {
return Ok(vec![0; n]);
}
if k == n {
return Ok((0..n).collect());
}
let config = KMeansConfig::from(params);
kmeans_impl_with_config(points, config)
}
/// Runs the main clustering loop for `2 <= k < n` and returns one label per point.
///
/// Centroids are seeded with `kmeans_plusplus_init`, then refined for at most
/// `config.params.max_iter` iterations. For every point the loop keeps an upper
/// bound on the distance to its assigned centroid and a lower bound on the
/// distance to any other centroid, and skips the full nearest-centroid search
/// whenever those bounds show the assignment cannot change.
///
/// # Errors
///
/// * Any error from `kmeans_plusplus_init` or `update_centroids`.
/// * `KMeansError::MaxIterationsReached` (carrying the current assignments) when
///   the loop finishes without converging.
/// * Any error from `apply_max_radius_constraint` when a radius limit is set.
fn kmeans_impl_with_config<T>(
    points: &[Point<T>],
    config: KMeansConfig<T>,
) -> Result<Vec<usize>, KMeansError<T>>
where
    T: GeoFloat,
{
    let n = points.len();
    let k = config.params.k;
    let two = T::from(2.0).unwrap();

    // Squared norms are cached so distances can be formed from a single dot product.
    let point_sq_norms: Vec<T> = points.iter().map(|p| p.0.magnitude_squared()).collect();
    let mut centroids = kmeans_plusplus_init(points, k, config.params.seed)?;
    let mut centroid_sq_norms: Vec<T> =
        centroids.iter().map(|c| c.0.magnitude_squared()).collect();

    let mut assignments = vec![0; n];
    let mut upper_bounds = vec![T::infinity(); n];
    let mut lower_bounds = vec![T::zero(); n];

    // Initial assignment: every point goes to its nearest seed centroid.
    for (i, (point, assignment)) in points.iter().zip(assignments.iter_mut()).enumerate() {
        let (nearest_idx, _, _) = find_nearest_and_second_nearest(
            *point,
            &centroids,
            point_sq_norms[i],
            &centroid_sq_norms,
        );
        *assignment = nearest_idx;
    }

    // Move the centroids once, then initialise the bounds against the moved centroids.
    centroids = update_centroids(points, &mut assignments, &centroids, k, 0)?;
    centroid_sq_norms = centroids.iter().map(|c| c.0.magnitude_squared()).collect();
    for (i, ((point, &assignment), (upper, lower))) in points
        .iter()
        .zip(assignments.iter())
        .zip(upper_bounds.iter_mut().zip(lower_bounds.iter_mut()))
        .enumerate()
    {
        let (_, nearest_sq_dist, second_nearest_sq_dist) = find_nearest_and_second_nearest(
            *point,
            &centroids,
            point_sq_norms[i],
            &centroid_sq_norms,
        );
        let sq_dist_to_assigned = squared_distance_using_norms(
            *point,
            centroids[assignment],
            point_sq_norms[i],
            centroid_sq_norms[assignment],
        );
        *upper = sq_dist_to_assigned.sqrt();
        *lower = second_nearest_sq_dist.sqrt();
        // If the assigned centroid is no longer the nearest one, the nearest
        // centroid is "another" centroid and therefore caps the lower bound.
        if nearest_sq_dist < sq_dist_to_assigned {
            *lower = nearest_sq_dist.sqrt().min(*lower);
        }
    }

    let mut final_max_delta = T::zero();
    let mut final_changed_count = 0;
    let mut converged = false;

    for iter in 0..config.params.max_iter {
        let mut changed_count = 0;
        let centroid_distances = compute_centroid_distances(&centroids);

        for (i, point) in points.iter().enumerate() {
            let assigned_idx = assignments[i];
            // Bounds already prove the assigned centroid is the closest.
            if upper_bounds[i] <= lower_bounds[i] {
                continue;
            }
            // Closer than half the gap to the nearest other centroid: cannot switch.
            if upper_bounds[i] <= centroid_distances[assigned_idx] / two {
                continue;
            }
            // Tighten the upper bound with an exact distance before a full search.
            let sq_dist_to_assigned = squared_distance_using_norms(
                *point,
                centroids[assigned_idx],
                point_sq_norms[i],
                centroid_sq_norms[assigned_idx],
            );
            upper_bounds[i] = sq_dist_to_assigned.sqrt();
            if upper_bounds[i] <= lower_bounds[i] {
                continue;
            }
            let (new_nearest_idx, new_nearest_sq_dist, new_second_nearest_sq_dist) =
                find_nearest_and_second_nearest(
                    *point,
                    &centroids,
                    point_sq_norms[i],
                    &centroid_sq_norms,
                );
            if new_nearest_idx != assigned_idx {
                assignments[i] = new_nearest_idx;
                changed_count += 1;
            }
            upper_bounds[i] = new_nearest_sq_dist.sqrt();
            lower_bounds[i] = new_second_nearest_sq_dist.sqrt();
        }

        if changed_count == 0 {
            converged = true;
            break;
        }

        let new_centroids = update_centroids(points, &mut assignments, &centroids, k, iter)?;
        centroid_sq_norms = new_centroids
            .iter()
            .map(|c| c.0.magnitude_squared())
            .collect();
        let deltas: Vec<T> = centroids
            .iter()
            .zip(new_centroids.iter())
            .map(|(old, new)| Euclidean.distance(*old, *new))
            .collect();
        let max_delta = deltas.iter().fold(T::zero(), |a, &b| a.max(b));
        final_max_delta = max_delta;
        final_changed_count = changed_count;
        if max_delta < config.params.tolerance {
            converged = true;
            break;
        }

        // Centroids moved: loosen each bound by the relevant shift so it stays valid.
        for ((upper, lower), &assigned_idx) in upper_bounds
            .iter_mut()
            .zip(lower_bounds.iter_mut())
            .zip(assignments.iter())
        {
            *upper = *upper + deltas[assigned_idx];
            *lower = (*lower - max_delta).max(T::zero());
        }
        centroids = new_centroids;
    }

    if !converged {
        return Err(KMeansError::MaxIterationsReached {
            assignments,
            iterations: config.params.max_iter,
            max_centroid_shift: final_max_delta,
            tolerance: config.params.tolerance,
            changed_assignments: final_changed_count,
        });
    }

    if let Some(max_radius) = config.params.max_radius {
        apply_max_radius_constraint(
            points,
            &mut assignments,
            max_radius,
            config.max_split_depth,
            config.params.seed,
        )?;
    }
    Ok(assignments)
}
/// Picks `k` starting centroids from `points`.
///
/// The first centroid is drawn uniformly at random. Each further centroid is
/// drawn with probability proportional to a point's squared distance from the
/// nearest centroid chosen so far.
///
/// A running per-point minimum is kept, so each round only measures distances to
/// the centroid added in the previous round rather than to all of them again.
///
/// # Errors
///
/// * `KMeansInitError::NaNCoordinate` / `InfiniteCoordinate` when a
///   point-to-centroid distance is NaN or infinite.
/// * `KMeansInitError::DegenerateData` when every weight is zero.
/// * `KMeansInitError::WeightedDistributionFailed` for any other failure to
///   build the weighted distribution.
///
/// # Panics
///
/// Panics if `points` is empty (the range for the first draw would be empty).
fn kmeans_plusplus_init<T>(
    points: &[Point<T>],
    k: usize,
    seed: Option<u64>,
) -> Result<Vec<Point<T>>, KMeansError<T>>
where
    T: GeoFloat,
{
    let n = points.len();
    // NOTE(review): a missing seed falls back to the fixed seed 0, so unseeded
    // runs are deterministic as well — confirm this is intended.
    let mut rng = rand_pcg::Pcg32::seed_from_u64(seed.unwrap_or(0));
    let mut centroids = Vec::with_capacity(k);
    let first_idx = rng.random_range(0..n);
    centroids.push(points[first_idx]);

    // Squared distance from each point to its nearest already-chosen centroid.
    let mut min_sq_dists = vec![T::infinity(); n];

    for _ in 1..k {
        // Only the most recently added centroid can lower any of the minima.
        let newest = centroids[centroids.len() - 1];
        for (point, min_sq_dist) in points.iter().zip(min_sq_dists.iter_mut()) {
            let dist = Euclidean.distance(*point, newest);
            if dist.is_nan() {
                return Err(KMeansError::InitializationFailed(
                    KMeansInitError::NaNCoordinate,
                ));
            }
            if dist.is_infinite() {
                return Err(KMeansError::InitializationFailed(
                    KMeansInitError::InfiniteCoordinate,
                ));
            }
            *min_sq_dist = (*min_sq_dist).min(dist * dist);
        }

        let distances_f64: Vec<f64> = min_sq_dists
            .iter()
            .map(|sq_dist| {
                sq_dist
                    .to_f64()
                    .expect("Valid distance should convert to f64")
            })
            .collect();

        let dist = WeightedIndex::new(&distances_f64).map_err(|e| {
            let all_zero = distances_f64.iter().all(|&d| d == 0.0);
            if all_zero {
                KMeansError::InitializationFailed(KMeansInitError::DegenerateData)
            } else {
                KMeansError::InitializationFailed(KMeansInitError::WeightedDistributionFailed {
                    error: e,
                })
            }
        })?;
        let next_idx = dist.sample(&mut rng);
        centroids.push(points[next_idx]);
    }
    Ok(centroids)
}
/// For each centroid, returns the Euclidean distance to its closest other centroid.
///
/// An entry stays at infinity when there is no other centroid to compare with.
fn compute_centroid_distances<T>(centroids: &[Point<T>]) -> Vec<T>
where
    T: GeoFloat,
{
    let mut nearest_other = vec![T::infinity(); centroids.len()];
    for (i, &first) in centroids.iter().enumerate() {
        // Each unordered pair is measured once and recorded for both members.
        for (offset, &second) in centroids[i + 1..].iter().enumerate() {
            let j = i + 1 + offset;
            let separation = Euclidean.distance(first, second);
            nearest_other[i] = nearest_other[i].min(separation);
            nearest_other[j] = nearest_other[j].min(separation);
        }
    }
    nearest_other
}
/// Squared Euclidean distance between `point` and `centroid`, computed from
/// their precomputed squared norms as `|p|² + |c|² − 2·(p·c)`.
///
/// Floating-point cancellation can push this expansion slightly below zero when
/// the two positions nearly coincide. Callers take the square root of the
/// result, which would then be NaN, so negative values are clamped to zero. An
/// explicit comparison is used instead of `max` so that a NaN result is still
/// passed through unchanged.
#[inline]
fn squared_distance_using_norms<T>(
    point: Point<T>,
    centroid: Point<T>,
    point_sq_norm: T,
    centroid_sq_norm: T,
) -> T
where
    T: GeoFloat,
{
    let dot_prod = point.0.dot_product(centroid.0);
    let sq_dist = point_sq_norm + centroid_sq_norm - (T::from(2.0).unwrap() * dot_prod);
    if sq_dist < T::zero() {
        T::zero()
    } else {
        sq_dist
    }
}
/// Scans all centroids and returns `(nearest_idx, nearest_sq_dist, second_nearest_sq_dist)`
/// for `point`.
///
/// Distances are squared and come from `squared_distance_using_norms`, using
/// the supplied precomputed norms. Comparisons are strict, so on a tie the
/// centroid with the lower index is reported as nearest. With a single centroid
/// the second-nearest distance stays at infinity; with none, index `0` and two
/// infinities are returned.
fn find_nearest_and_second_nearest<T>(
    point: Point<T>,
    centroids: &[Point<T>],
    point_sq_norm: T,
    centroid_sq_norms: &[T],
) -> (usize, T, T)
where
    T: GeoFloat,
{
    let mut nearest_idx = 0;
    let mut nearest_sq_dist = T::infinity();
    let mut second_nearest_sq_dist = T::infinity();
    for (idx, (centroid, &centroid_sq_norm)) in
        centroids.iter().zip(centroid_sq_norms.iter()).enumerate()
    {
        let sq_dist =
            squared_distance_using_norms(point, *centroid, point_sq_norm, centroid_sq_norm);
        if sq_dist < nearest_sq_dist {
            // The previous best is demoted to runner-up.
            second_nearest_sq_dist = nearest_sq_dist;
            nearest_sq_dist = sq_dist;
            nearest_idx = idx;
        } else if sq_dist < second_nearest_sq_dist {
            second_nearest_sq_dist = sq_dist;
        }
    }
    (nearest_idx, nearest_sq_dist, second_nearest_sq_dist)
}
/// Finds the point lying farthest from the centroid of its own cluster.
///
/// Returns the point's index together with that distance, or `None` when no
/// point is strictly farther than zero from its centroid. Points whose label
/// has no matching entry in `centroids` are ignored. On a tie the earliest
/// point wins.
fn find_farthest_point<T>(
    points: &[Point<T>],
    assignments: &[usize],
    centroids: &[Point<T>],
) -> Option<(usize, T)>
where
    T: GeoFloat,
{
    let mut best: Option<(usize, T)> = None;
    for (idx, (point, &label)) in points.iter().zip(assignments.iter()).enumerate() {
        if let Some(&centroid) = centroids.get(label) {
            let separation = Euclidean.distance(*point, centroid);
            // With no candidate yet, the distance to beat is zero.
            let to_beat = best.map_or(T::zero(), |(_, d)| d);
            if separation > to_beat {
                best = Some((idx, separation));
            }
        }
    }
    best
}
/// Recomputes each centroid as the mean of the points currently assigned to it.
///
/// A cluster with no members is refilled by taking the point that lies farthest
/// from its own (previous) centroid, as reported by `find_farthest_point`;
/// `assignments` is updated in place to reflect the move.
///
/// # Errors
///
/// Returns `KMeansError::EmptyCluster` when a cluster is empty and
/// `find_farthest_point` offers no point to move into it.
///
/// # Panics
///
/// Panics if a cluster size cannot be converted into `T`.
fn update_centroids<T>(
    points: &[Point<T>],
    assignments: &mut [usize],
    centroids: &[Point<T>],
    k: usize,
    iteration: usize,
) -> Result<Vec<Point<T>>, KMeansError<T>>
where
    T: GeoFloat,
{
    let mut sum_x: Vec<T> = vec![T::zero(); k];
    let mut sum_y: Vec<T> = vec![T::zero(); k];
    let mut members: Vec<usize> = vec![0; k];

    // Accumulate coordinate totals and member counts per cluster.
    for (point, &label) in points.iter().zip(assignments.iter()) {
        if label >= k {
            continue;
        }
        sum_x[label] = sum_x[label] + point.x();
        sum_y[label] = sum_y[label] + point.y();
        members[label] += 1;
    }

    for cluster_id in 0..k {
        if members[cluster_id] != 0 {
            continue;
        }
        let donor_idx = match find_farthest_point(points, assignments, centroids) {
            Some((idx, _)) => idx,
            None => {
                return Err(KMeansError::EmptyCluster {
                    iteration,
                    cluster_id,
                });
            }
        };
        // NOTE(review): the donor cluster is not checked for having more than one
        // member, so this move can leave it empty — confirm that is acceptable.
        let donor_cluster = assignments[donor_idx];
        let moved = points[donor_idx];
        sum_x[donor_cluster] = sum_x[donor_cluster] - moved.x();
        sum_y[donor_cluster] = sum_y[donor_cluster] - moved.y();
        members[donor_cluster] -= 1;
        sum_x[cluster_id] = sum_x[cluster_id] + moved.x();
        sum_y[cluster_id] = sum_y[cluster_id] + moved.y();
        members[cluster_id] += 1;
        assignments[donor_idx] = cluster_id;
    }

    let mut new_centroids = Vec::with_capacity(k);
    for ((&total_x, &total_y), &count) in sum_x.iter().zip(sum_y.iter()).zip(members.iter()) {
        let count_t = T::from(count).expect("Cluster count must be representable as float");
        new_centroids.push(Point::new(total_x / count_t, total_y / count_t));
    }
    Ok(new_centroids)
}
/// Splits every cluster whose farthest member lies more than `max_radius` from
/// the cluster centroid, rewriting `assignments` in place.
///
/// An oversized cluster is divided in two with a nested 2-means run: one half
/// keeps the original label and the other receives a fresh label. After a pass
/// that split at least one cluster, the function calls itself with
/// `remaining_depth - 1` so that the halves are checked as well. Recursion stops
/// once a pass performs no split or the depth budget reaches zero, in which case
/// some clusters may still exceed the limit.
///
/// # Errors
///
/// Propagates any error from the nested clustering runs.
fn apply_max_radius_constraint<T>(
    points: &[Point<T>],
    assignments: &mut [usize],
    max_radius: T,
    remaining_depth: usize,
    seed: Option<u64>,
) -> Result<(), KMeansError<T>>
where
    T: GeoFloat,
{
    if remaining_depth == 0 {
        return Ok(());
    }

    // One past the largest label in use: both the next free label and the size
    // needed for the `seen` table below.
    let mut next_cluster_id = assignments.iter().max().map_or(0, |&a| a + 1);

    // Distinct labels in order of first appearance.
    let mut cluster_ids = Vec::with_capacity(next_cluster_id);
    let mut seen = vec![false; next_cluster_id];
    for &assignment in &*assignments {
        if !seen[assignment] {
            seen[assignment] = true;
            cluster_ids.push(assignment);
        }
    }

    let mut any_split = false;
    for cluster_id in cluster_ids {
        let cluster_points: Vec<(usize, Point<T>)> = points
            .iter()
            .enumerate()
            .filter(|(idx, _)| assignments[*idx] == cluster_id)
            .map(|(idx, &point)| (idx, point))
            .collect();
        if cluster_points.is_empty() {
            continue;
        }
        let cluster_point_vec: Vec<Point<T>> = cluster_points.iter().map(|(_, p)| *p).collect();
        let multipoint = MultiPoint::new(cluster_point_vec);
        let centroid = multipoint
            .centroid()
            .expect("MultiPoint cannot be empty after filtering non-empty cluster");
        let max_dist = multipoint
            .iter()
            .map(|p| Euclidean.distance(&centroid, p))
            .fold(T::zero(), T::max);
        if max_dist > max_radius {
            let params = if let Some(s) = seed {
                KMeansParams::new_with_seed(2, s)
            } else {
                #[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
                {
                    KMeansParams::new(2)
                }
                #[cfg(all(target_family = "wasm", target_os = "unknown"))]
                {
                    unreachable!(
                        "Geo only supports KMeansParams with an explicit seed on WASM Targets"
                    )
                }
            };
            let sub_assignments = kmeans_impl(&multipoint.0, params)?;
            // Sub-label 0 keeps the original label; everything else moves to the new one.
            for ((original_idx, _), &sub_assignment) in
                cluster_points.iter().zip(sub_assignments.iter())
            {
                assignments[*original_idx] = if sub_assignment == 0 {
                    cluster_id
                } else {
                    next_cluster_id
                };
            }
            next_cluster_id += 1;
            any_split = true;
        }
    }

    // A freshly split half can still be too large; re-check with a smaller budget.
    if any_split {
        apply_max_radius_constraint(points, assignments, max_radius, remaining_depth - 1, seed)?;
    }
    Ok(())
}
/// Clusters the points held by an owned `MultiPoint`.
impl<T> KMeans<T> for MultiPoint<T>
where
    T: GeoFloat,
{
    fn kmeans_with_params(&self, params: KMeansParams<T>) -> Result<Vec<usize>, KMeansError<T>> {
        let points = self.0.as_slice();
        kmeans_impl(points, params)
    }
}
/// Clusters the points held by a borrowed `MultiPoint`.
impl<T> KMeans<T> for &MultiPoint<T>
where
    T: GeoFloat,
{
    fn kmeans_with_params(&self, params: KMeansParams<T>) -> Result<Vec<usize>, KMeansError<T>> {
        // `self` is a reference to a reference; peel one layer before reading the points.
        let multipoint: &MultiPoint<T> = self;
        kmeans_impl(multipoint.0.as_slice(), params)
    }
}
/// Clusters a slice of points.
impl<T> KMeans<T> for [Point<T>]
where
    T: GeoFloat,
{
    fn kmeans_with_params(&self, params: KMeansParams<T>) -> Result<Vec<usize>, KMeansError<T>> {
        let points: &[Point<T>] = self;
        kmeans_impl(points, params)
    }
}
/// Clusters a borrowed slice of points.
impl<T> KMeans<T> for &[Point<T>]
where
    T: GeoFloat,
{
    fn kmeans_with_params(&self, params: KMeansParams<T>) -> Result<Vec<usize>, KMeansError<T>> {
        // `self` is `&&[Point<T>]`; dereference once to get the slice reference.
        kmeans_impl(*self, params)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::point;

    // No points: an empty label vector, not an error.
    #[test]
    fn test_kmeans_empty() {
        let points: Vec<Point<f64>> = vec![];
        let labels = points.kmeans(2).unwrap();
        assert_eq!(labels.len(), 0);
    }

    #[test]
    fn test_kmeans_single_point() {
        let points = [point!(x: 0.0, y: 0.0)];
        let labels = points.kmeans(1).unwrap();
        assert_eq!(labels, vec![0]);
    }

    // `k == 0` is rejected.
    #[test]
    fn test_kmeans_k_zero() {
        let points = [point!(x: 0.0, y: 0.0), point!(x: 1.0, y: 1.0)];
        let result = points.kmeans(0);
        assert!(matches!(result, Err(KMeansError::InvalidK { k: 0, n: 2 })));
    }

    // More clusters than points is rejected.
    #[test]
    fn test_kmeans_k_too_large() {
        let points = [point!(x: 0.0, y: 0.0), point!(x: 1.0, y: 1.0)];
        let result = points.kmeans(10);
        assert!(matches!(result, Err(KMeansError::InvalidK { k: 10, n: 2 })));
    }

    // With `k == n` every point gets its own label.
    #[test]
    fn test_kmeans_k_equals_n() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 1.0),
            point!(x: 2.0, y: 2.0),
        ];
        let labels = points.kmeans(3).unwrap();
        assert_eq!(labels, vec![0, 1, 2]);
    }

    #[test]
    fn test_kmeans_single_cluster() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 0.0, y: 1.0),
            point!(x: 1.0, y: 1.0),
        ];
        let labels = points.kmeans(1).unwrap();
        assert!(labels.iter().all(|&label| label == 0));
    }

    // Two well-separated groups must end up with distinct labels.
    #[test]
    fn test_kmeans_two_clusters() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 0.0, y: 1.0),
            point!(x: 10.0, y: 10.0),
            point!(x: 11.0, y: 10.0),
            point!(x: 10.0, y: 11.0),
        ];
        let labels = points.kmeans(2).unwrap();
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[1], labels[2]);
        assert_eq!(labels[3], labels[4]);
        assert_eq!(labels[4], labels[5]);
        assert_ne!(labels[0], labels[3]);
    }

    // The trait is also usable on a `MultiPoint`.
    #[test]
    fn test_kmeans_multipoint() {
        let points = MultiPoint::new(vec![
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 0.0, y: 1.0),
            point!(x: 10.0, y: 10.0),
            point!(x: 11.0, y: 10.0),
            point!(x: 10.0, y: 11.0),
        ]);
        let labels = points.kmeans(2).unwrap();
        let cluster_0_count = labels.iter().filter(|&&l| l == 0).count();
        let cluster_1_count = labels.iter().filter(|&&l| l == 1).count();
        assert_eq!(cluster_0_count + cluster_1_count, 6);
    }

    #[test]
    fn test_kmeans_three_clusters() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 10.0, y: 10.0),
            point!(x: 11.0, y: 10.0),
            point!(x: 20.0, y: 20.0),
            point!(x: 21.0, y: 20.0),
        ];
        let params = KMeansParams::new(3).seed(42);
        let labels = points.kmeans_with_params(params).unwrap();
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_eq!(labels[4], labels[5]);
        assert_ne!(labels[0], labels[2]);
        assert_ne!(labels[2], labels[4]);
        assert_ne!(labels[0], labels[4]);
    }

    #[test]
    fn test_kmeans_identical_points() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 0.0, y: 0.0),
            point!(x: 0.0, y: 0.0),
        ];
        let labels = points.kmeans(1).unwrap();
        assert!(labels.iter().all(|&label| label == 0));
    }

    // Collinear points must still produce exactly two distinct labels.
    #[test]
    fn test_kmeans_linear_cluster() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 2.0, y: 0.0),
            point!(x: 3.0, y: 0.0),
            point!(x: 4.0, y: 0.0),
        ];
        let labels = points.kmeans(2).unwrap();
        let mut unique_clusters: Vec<_> = labels.to_vec();
        unique_clusters.sort_unstable();
        unique_clusters.dedup();
        assert_eq!(unique_clusters.len(), 2);
    }

    // Coincident points cannot seed two centroids.
    #[test]
    fn test_kmeans_degenerate_all_same_location() {
        let points = [
            point!(x: 5.0, y: 5.0),
            point!(x: 5.0, y: 5.0),
            point!(x: 5.0, y: 5.0),
            point!(x: 5.0, y: 5.0),
        ];
        let result = points.kmeans(2);
        assert!(matches!(
            result,
            Err(KMeansError::InitializationFailed(
                KMeansInitError::DegenerateData
            ))
        ));
    }

    #[test]
    fn test_kmeans_nan_coordinates() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: f64::NAN, y: 1.0),
            point!(x: 2.0, y: 2.0),
        ];
        let result = points.kmeans(2);
        assert!(matches!(
            result,
            Err(KMeansError::InitializationFailed(
                KMeansInitError::NaNCoordinate
            ))
        ));
    }

    #[test]
    fn test_kmeans_infinite_coordinates() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: f64::INFINITY, y: 1.0),
            point!(x: 2.0, y: 2.0),
        ];
        let result = points.kmeans(2);
        assert!(matches!(
            result,
            Err(KMeansError::InitializationFailed(
                KMeansInitError::InfiniteCoordinate
            ))
        ));
    }

    // The same seed must give the same labels on repeated runs.
    #[test]
    fn test_kmeans_reproducibility_with_seed() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 0.0, y: 1.0),
            point!(x: 10.0, y: 10.0),
            point!(x: 11.0, y: 10.0),
            point!(x: 10.0, y: 11.0),
        ];
        let params = KMeansParams::new(2).seed(42);
        let result1 = points.kmeans_with_params(params.clone()).unwrap();
        let result2 = points.kmeans_with_params(params).unwrap();
        assert_eq!(result1, result2);
    }

    #[test]
    fn test_kmeans_reproducibility_with_seed_and_max_radius() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 100.0, y: 0.0),
            point!(x: 0.0, y: 100.0),
            point!(x: 100.0, y: 100.0),
            point!(x: 50.0, y: 50.0),
        ];
        let params = KMeansParams::new(1).seed(42).max_radius(40.0);
        let result1 = points.kmeans_with_params(params.clone()).unwrap();
        let result2 = points.kmeans_with_params(params).unwrap();
        assert_eq!(result1, result2);
    }

    // Builder methods can be chained.
    #[test]
    fn test_kmeans_builder_pattern() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 0.0, y: 1.0),
            point!(x: 10.0, y: 10.0),
            point!(x: 11.0, y: 10.0),
            point!(x: 10.0, y: 11.0),
        ];
        let params = KMeansParams::new(2).seed(42).max_iter(100).tolerance(0.001);
        let labels = points.kmeans_with_params(params).unwrap();
        let unique_labels: std::collections::HashSet<_> = labels.iter().copied().collect();
        assert_eq!(unique_labels.len(), 2);
    }

    #[test]
    fn test_kmeans_converges_with_sufficient_iterations() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 0.0, y: 1.0),
            point!(x: 10.0, y: 10.0),
            point!(x: 11.0, y: 10.0),
            point!(x: 10.0, y: 11.0),
        ];
        let params = KMeansParams::new(2).seed(42).max_iter(100);
        let result = points.kmeans_with_params(params);
        assert!(result.is_ok());
    }

    // A tight group plus a long line of points: the radius limit should split
    // the line and leave the tight group alone.
    #[test]
    fn test_kmeans_max_radius() {
        use std::collections::HashSet;
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 0.5, y: 0.5),
            point!(x: 0.0, y: 1.0),
            point!(x: 1.0, y: 1.0),
            point!(x: 20.0, y: 0.0),
            point!(x: 30.0, y: 0.0),
            point!(x: 40.0, y: 0.0),
            point!(x: 50.0, y: 0.0),
            point!(x: 60.0, y: 0.0),
            point!(x: 70.0, y: 0.0),
            point!(x: 80.0, y: 0.0),
        ];
        let params = KMeansParams::new(2).seed(42).max_radius(15.0);
        let labels = points.kmeans_with_params(params).unwrap();
        let unique_labels: HashSet<_> = labels.iter().copied().collect();
        assert!(
            unique_labels.len() > 2,
            "Expected more than 2 clusters due to max_radius splitting, got {}",
            unique_labels.len()
        );
        for &cluster_id in &unique_labels {
            let cluster_points: Vec<Point<f64>> = points
                .iter()
                .zip(labels.iter())
                .filter(|&(_, &label)| label == cluster_id)
                .map(|(&p, _)| p)
                .collect();
            if cluster_points.is_empty() {
                continue;
            }
            let multipoint = MultiPoint::new(cluster_points);
            let centroid = multipoint.centroid().unwrap();
            let max_dist = multipoint
                .iter()
                .map(|p| Euclidean.distance(&centroid, p))
                .fold(0.0, f64::max);
            // Small slack on top of the requested radius.
            let epsilon = 2.0;
            assert!(
                max_dist <= 15.0 + epsilon,
                "Cluster {} has radius {:.2} significantly exceeding max_radius 15.0",
                cluster_id,
                max_dist
            );
        }
        let tight_cluster_labels: HashSet<_> = labels[0..5].iter().copied().collect();
        assert_eq!(
            tight_cluster_labels.len(),
            1,
            "Tight cluster should remain as one cluster, but has {} different labels",
            tight_cluster_labels.len()
        );
        let elongated_cluster_labels: HashSet<_> = labels[5..].iter().copied().collect();
        assert!(
            elongated_cluster_labels.len() >= 2,
            "Elongated cluster should be split into at least 2 clusters, got {}",
            elongated_cluster_labels.len()
        );
    }
}