#![expect(
clippy::missing_errors_doc,
reason = "The Error-Enum is sparse and documented."
)]
use core::hash::{self, Hash};
use core::{cmp, ops};
use core::{f64, fmt, iter};
use ndarray::Array1;
use pathfinding::{num_traits::Zero, prelude::dijkstra};
use rustc_hash::{FxHashMap, FxHashSet};
use smallvec::SmallVec;
use std::collections::BinaryHeap;
/// Integer type whose bits back a [`Cluster`]: bit `i` is set when point `i` is a member.
#[cfg(not(target_pointer_width = "16"))]
pub type Storage = u32;
/// Number of bits in [`Storage`]; `verify_points` and `verify_weighted_points` reject inputs
/// with more points than this.
#[expect(
clippy::as_conversions,
reason = "`Storage::BITS` will always fit into a `usize`."
)]
pub const MAX_POINT_COUNT: usize = Storage::BITS as usize;
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub struct Cluster(Storage);
impl Cluster {
const fn new() -> Self {
Self(0)
}
const fn singleton(point_ix: usize) -> Self {
Self(1 << point_ix)
}
fn insert(&mut self, point_ix: usize) {
let point = 1 << point_ix;
debug_assert!(
(point & self.0) == 0,
"Throughout the entire implementation, we should never to add the same point twice."
);
self.0 |= point;
}
fn remove(&mut self, point_ix: usize) {
let point = 1 << point_ix;
debug_assert!(
(point & self.0) != 0,
"Throughout the entire implementation, we should never remove a non-existing point."
);
self.0 &= !point;
}
#[must_use]
#[inline]
pub const fn contains(self, point_ix: usize) -> bool {
(self.0 & (1 << point_ix)) != 0
}
#[must_use]
#[inline]
pub const fn len(self) -> Storage {
self.0.count_ones()
}
#[must_use]
#[inline]
pub const fn is_empty(self) -> bool {
self.0 == 0
}
#[inline]
#[must_use]
pub const fn iter(self) -> ClusterIter {
ClusterIter(self.0)
}
fn union_with(&mut self, other: Self) {
debug_assert!(
self.0 & other.0 == 0,
"Troughout the entire implementation, we should never be merging intersecting clusters."
);
self.0 |= other.0;
}
}
/// Iterator over the point indices stored in a cluster bitmask, lowest index first.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct ClusterIter(Storage);

impl Iterator for ClusterIter {
    type Item = usize;

    /// Pops the lowest remaining point index, or returns `None` once the mask is empty.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let remaining = self.0;
        if remaining == 0 {
            return None;
        }
        #[expect(
            clippy::as_conversions,
            reason = "I assume `usize` is at least `Storage`."
        )]
        let lowest_ix = remaining.trailing_zeros() as usize;
        // `x & (x - 1)` clears exactly the lowest set bit of `x`.
        self.0 = remaining & (remaining - 1);
        Some(lowest_ix)
    }

    /// The number of remaining indices is exactly the number of set bits.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        #[expect(
            clippy::as_conversions,
            reason = "I assume `usize` is at least `Storage`."
        )]
        let remaining = self.0.count_ones() as usize;
        (remaining, Some(remaining))
    }
}
impl IntoIterator for Cluster {
type Item = usize;
type IntoIter = ClusterIter;
#[inline]
fn into_iter(self) -> Self::IntoIter {
ClusterIter(self.0)
}
}
/// Renders the cluster as `Storage::BITS` characters, point 0 first: `#` for a member,
/// `.` for a non-member.
impl fmt::Display for Cluster {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rendered: String = (0..Storage::BITS)
            .map(|bit| {
                if (self.0 >> bit) & 1 == 1 {
                    '#'
                } else {
                    '.'
                }
            })
            .collect();
        write!(f, "{rendered}")
    }
}
/// A clustering: an unordered set of [`Cluster`]s.
pub type Clustering = FxHashSet<Cluster>;
/// Matrix of pairwise distances; entry `[p][q]` holds the distance function's value for `(p, q)`.
type Distances = Vec<Vec<f64>>;
/// A point, given by its coordinate vector.
pub type Point = Array1<f64>;
/// A point together with its weight, stored as `(weight, coordinates)`.
pub type WeightedPoint = (f64, Array1<f64>);
/// Search node consisting of a complete clustering and its total cost.
///
/// Equality and hashing only look at `clusters`; `cost` is deliberately ignored.
#[derive(Clone, Debug)]
struct ClusteringNodeMergeMultiple {
    /// The clusters of this node; the methods debug-assert that they are kept sorted.
    clusters: SmallVec<[Cluster; 6]>,
    /// Total cost of `clusters`.
    cost: f64,
}

impl PartialEq for ClusteringNodeMergeMultiple {
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.clusters, &other.clusters)
    }
}

impl Eq for ClusteringNodeMergeMultiple {}

impl Hash for ClusteringNodeMergeMultiple {
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        Hash::hash(&self.clusters, state);
    }
}
impl ClusteringNodeMergeMultiple {
/// Returns every node obtained by merging exactly two clusters of this node.
///
/// The pairs are enumerated as `(i, j)` with `i < j` in index order; cluster `i` is removed
/// and united into the former cluster `j`. Each new node's cost is this node's cost minus
/// the two old cluster costs plus the cost of the merged cluster.
#[must_use]
#[inline]
fn get_all_merges<C: Cost + ?Sized>(&self, data: &mut C) -> Vec<Self> {
    debug_assert!(
        self.clusters.is_sorted(),
        "The clusters should always be sorted, to prevent duplicates."
    );
    let cluster_count = self.clusters.len();
    #[expect(
        clippy::integer_division,
        reason = "At least one of the factors is always even."
    )]
    let mut merged_nodes = Vec::with_capacity(cluster_count * (cluster_count - 1) / 2);
    for low_ix in 0..(cluster_count - 1) {
        let mut remaining = self.clusters.clone();
        let low_cluster = remaining.remove(low_ix);
        let cost_without_low = self.cost - data.cost(low_cluster);
        for high_ix in low_ix..remaining.len() {
            let mut merged_clusters = remaining.clone();
            // SAFETY: `high_ix < remaining.len()` and `merged_clusters` is a clone of
            // `remaining`, so the index is in bounds.
            let high_cluster = unsafe { merged_clusters.get_unchecked_mut(high_ix) };
            let mut merged_cost = cost_without_low - data.cost(*high_cluster);
            high_cluster.union_with(low_cluster);
            merged_cost += data.cost(*high_cluster);
            debug_assert!(merged_clusters.len() == self.clusters.len() - 1, "We should have merged two clusters, which should have reduced the number of clusters by exactly one.");
            debug_assert!(
                merged_clusters.is_sorted(),
                "The clusters should always be sorted, to prevent duplicates."
            );
            debug_assert!(
                {
                    (0..data.num_points()).all(|point_ix| {
                        merged_clusters
                            .iter()
                            .filter(|cluster| cluster.contains(point_ix))
                            .count()
                            == 1
                    })
                },
                "The clusters should always cover every point exactly once."
            );
            merged_nodes.push(Self {
                clusters: merged_clusters,
                cost: merged_cost,
            });
        }
    }
    merged_nodes
}
/// Repeatedly moves a single point from one cluster to another while such a move lowers the
/// total cost, then sorts `self.clusters` again.
///
/// `self.cost` is adjusted by the cost delta of every applied move.
///
/// NOTE(review): the visited key is `(source cluster, source index, target index)` and does
/// not include `point_ix`. Once a key has been evaluated for one point, the remaining points
/// of that same source cluster are skipped for that target — confirm this is intended.
fn optimise_locally<C: Cost + ?Sized>(&mut self, data: &mut C) {
// Moves that were already evaluated. No key is evaluated twice, so every evaluation
// consumes a fresh key.
let mut already_visited: FxHashSet<(Cluster, usize, usize)> = FxHashSet::default();
// Applies the first improving move it finds and returns `true`; returns `false` when no
// unvisited move improves the cost.
let mut found_improvement = || {
#[expect(
clippy::indexing_slicing,
reason = "These are safe, we just use indices to avoid borrow-issues."
)]
for source_cluster_ix in 0..self.clusters.len() {
let source_cluster = self.clusters[source_cluster_ix];
for point_ix in source_cluster {
let mut updated_source_cluster = source_cluster;
updated_source_cluster.remove(point_ix);
// Cost change of the source cluster when it loses `point_ix`.
let source_costdelta =
data.cost(updated_source_cluster) - data.cost(source_cluster);
for target_cluster_ix in
(0..self.clusters.len()).filter(|ix| *ix != source_cluster_ix)
{
// `insert` returns `false` when the key was already present.
if !already_visited.insert((
source_cluster,
source_cluster_ix,
target_cluster_ix,
)) {
continue;
}
let target_cluster = self.clusters[target_cluster_ix];
let mut updated_target_cluster = target_cluster;
updated_target_cluster.insert(point_ix);
let costdelta = source_costdelta + data.cost(updated_target_cluster)
- data.cost(target_cluster);
if costdelta < 0.0 {
// Write the two updated clusters back so that the smaller one ends up at the
// smaller of the two indices.
if updated_source_cluster.cmp(&updated_target_cluster)
== source_cluster_ix.cmp(&target_cluster_ix)
{
self.clusters[source_cluster_ix] = updated_source_cluster;
self.clusters[target_cluster_ix] = updated_target_cluster;
} else {
self.clusters[source_cluster_ix] = updated_target_cluster;
self.clusters[target_cluster_ix] = updated_source_cluster;
}
self.cost += costdelta;
return true;
}
}
}
}
false
};
while found_improvement() {}
// The moves above only keep the two touched clusters ordered relative to each other, so
// restore the global order here.
self.clusters.sort();
debug_assert!(
{
(0..data.num_points()).all(|point_ix| {
self.clusters
.iter()
.filter(|cluster| cluster.contains(point_ix))
.count()
== 1
})
},
"The clusters should always cover every point exactly once."
);
}
#[inline]
fn new_singletons(num_points: usize) -> Self {
let mut clusters = SmallVec::default();
for i in 0..num_points {
clusters.push(Cluster::singleton(i));
}
debug_assert!(
clusters.is_sorted(),
"The clusters should always be sorted, to prevent duplicates."
);
Self {
clusters,
cost: 0.0,
}
}
#[inline]
fn into_clustering(self) -> Clustering {
self.clusters.into_iter().collect()
}
}
/// Search node for building a clustering one point at a time.
///
/// Equality and hashing only look at `clusters`. The ordering is reversed on `cost`, so a
/// `BinaryHeap` of these nodes pops the cheapest node first; ties are broken by `clusters`.
#[derive(Clone, Debug)]
struct ClusteringNodeMergeSingle {
    /// The clusters built so far.
    clusters: SmallVec<[Cluster; 6]>,
    /// Total cost of `clusters`.
    cost: f64,
    /// Index of the next point that has to be placed.
    next_to_add: usize,
}

impl PartialEq for ClusteringNodeMergeSingle {
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.clusters, &other.clusters)
    }
}

impl Eq for ClusteringNodeMergeSingle {}

impl Hash for ClusteringNodeMergeSingle {
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        Hash::hash(&self.clusters, state);
    }
}

impl Ord for ClusteringNodeMergeSingle {
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        // `other` is compared against `self` on purpose: lower cost means "greater".
        match other.cost.total_cmp(&self.cost) {
            cmp::Ordering::Equal => self.clusters.cmp(&other.clusters),
            by_cost => by_cost,
        }
    }
}

impl PartialOrd for ClusteringNodeMergeSingle {
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl ClusteringNodeMergeSingle {
/// Yields every node that results from placing point `next_to_add`.
///
/// First, one node per existing cluster, with the point inserted into that cluster and the
/// cost adjusted by the change of that cluster's cost. Then, if there are fewer than `k`
/// clusters, one node with the point as a new singleton cluster; its cost is left unchanged.
#[inline]
fn get_next_nodes<'a, C: Cost + ?Sized>(
&'a self,
data: &'a mut C,
k: usize,
) -> impl Iterator<Item = Self> + use<'a, C> {
(0..self.clusters.len())
.map(|cluster_ix| {
let mut new_clustering_node = self.clone();
// SAFETY: `cluster_ix < self.clusters.len()` and the clone has the same length.
let cluster_to_edit =
unsafe { new_clustering_node.clusters.get_unchecked_mut(cluster_ix) };
new_clustering_node.cost -= data.cost(*cluster_to_edit);
cluster_to_edit.insert(new_clustering_node.next_to_add);
new_clustering_node.cost += data.cost(*cluster_to_edit);
new_clustering_node.next_to_add += 1;
new_clustering_node
})
// `bool::then` runs its closure right here, so the singleton node (if any) is built
// when the iterator is created, not when it is consumed.
.chain((self.clusters.len() < k).then(|| {
let mut clustering_node = self.clone();
clustering_node
.clusters
.push(Cluster::singleton(clustering_node.next_to_add));
clustering_node.next_to_add += 1;
clustering_node
}))
}
/// Creates the node with no clusters, zero cost, and point 0 as the next point to place.
fn empty() -> Self {
Self {
clusters: SmallVec::default(),
cost: 0.0,
next_to_add: 0,
}
}
}
/// Ratio between the cost of a clustering and the optimal cost for the same number of clusters.
///
/// "Adding" two ratios takes their maximum and the neutral element is `1.0`, so a path cost
/// made of these values is the largest ratio along the path.
#[derive(Debug, PartialEq, Clone, Copy)]
struct MaxRatio(f64);

impl MaxRatio {
    /// Computes `clustering_cost / opt_cost`.
    ///
    /// If `opt_cost` is zero the result is `1.0` when `clustering_cost` is zero as well and
    /// infinity otherwise. Debug builds assert that both costs are finite and non-negative
    /// and that `clustering_cost` is not below `opt_cost` by more than `1e-9`.
    #[inline]
    fn new(clustering_cost: f64, opt_cost: f64) -> Self {
        debug_assert!(
            clustering_cost.is_finite(),
            "hierarchy_cost {clustering_cost} should be finite."
        );
        debug_assert!(
            opt_cost.is_finite(),
            "opt_cost {opt_cost} should be finite."
        );
        debug_assert!(
            opt_cost >= 0.0,
            "opt_cost {opt_cost} should be non-negative."
        );
        debug_assert!(
            clustering_cost >= 0.0,
            "hierarchy_cost {clustering_cost} should be non-negative"
        );
        debug_assert!(
            clustering_cost >= opt_cost - 1e-9,
            "hierarchy_cost {clustering_cost} should be at least opt_cost {opt_cost}"
        );
        let ratio = match (opt_cost.is_zero(), clustering_cost.is_zero()) {
            (true, true) => 1.0,
            (true, false) => f64::INFINITY,
            (false, _) => clustering_cost / opt_cost,
        };
        Self(ratio)
    }
}

impl Eq for MaxRatio {}

impl Ord for MaxRatio {
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        f64::total_cmp(&self.0, &other.0)
    }
}

impl PartialOrd for MaxRatio {
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl ops::Add for MaxRatio {
    type Output = Self;

    /// Returns the larger of the two ratios.
    fn add(self, rhs: Self) -> Self {
        Self(f64::max(self.0, rhs.0))
    }
}

impl Zero for MaxRatio {
    /// The neutral ratio `1.0`.
    fn zero() -> Self {
        Self(1.0)
    }

    #[expect(clippy::float_cmp, reason = "This should be exact.")]
    fn is_zero(&self) -> bool {
        self.0 == 1.0
    }
}
/// Cache from a cluster to its cost; the `Cost::cost` implementations in this file fill it
/// on first use.
type Costs = FxHashMap<Cluster, f64>;
pub trait Cost {
/// Returns the cost of `cluster`. Takes `&mut self`; the implementations in this file use
/// that to cache results.
fn cost(&mut self, cluster: Cluster) -> f64;
/// Returns the sum of [`Cost::cost`] over all clusters of `clustering`.
#[inline]
fn total_cost(&mut self, clustering: &Clustering) -> f64 {
    clustering
        .iter()
        .copied()
        .map(|cluster| self.cost(cluster))
        .sum()
}
/// Returns one heuristic clustering per level, together with its cost.
///
/// Starting from all singletons, each step takes the cheapest merge of two clusters and then
/// improves it with `optimise_locally`. Index 0 of the result is the empty clustering with
/// cost `0.0`; the remaining entries run from the last clustering produced back to the
/// singletons.
#[inline]
fn approximate_clusterings(&mut self) -> Vec<(f64, Clustering)> {
    let mut current = ClusteringNodeMergeMultiple::new_singletons(self.num_points());
    let mut levels: Vec<(f64, Clustering)> = vec![(0.0, current.clone().into_clustering())];
    loop {
        if current.clusters.len() <= 1 {
            break;
        }
        let mut improved = current
            .get_all_merges(self)
            .into_iter()
            .min_by(|lhs, rhs| lhs.cost.total_cmp(&rhs.cost))
            .expect("There should always be a possible merge");
        improved.optimise_locally(self);
        levels.push((improved.cost, improved.clone().into_clustering()));
        current = improved;
    }
    // Levels were collected from many clusters down to one; add the empty level and flip.
    levels.push((0.0, Clustering::default()));
    levels.reverse();
    levels
}
/// Returns the number of points of the instance; point indices are `0..num_points()`.
fn num_points(&self) -> usize;
/// Returns, for every level `k` produced by [`Cost::approximate_clusterings`], a clustering
/// with `k` clusters and its cost.
///
/// For each level a best-first search places the points one after another. Nodes are only
/// kept while their cost is strictly below the best known complete cost, which starts at
/// the approximate cost. If the search finds no complete clustering, the approximate
/// clustering of that level is returned unchanged.
#[inline]
fn optimal_clusterings(&mut self) -> Vec<(f64, Clustering)> {
    let num_points = self.num_points();
    // There is one entry per level `0..=num_points`.
    let mut results = Vec::with_capacity(num_points.saturating_add(1));
    for (k, (approximate_cost, approximate_clustering)) in
        self.approximate_clusterings().into_iter().enumerate()
    {
        debug_assert_eq!(
            approximate_clustering.len(),
            k,
            "The approximate clustering on level {k} should have exactly {k} clusters."
        );
        let level = 'search: {
            let mut min_cost = approximate_cost;
            let mut to_see: BinaryHeap<ClusteringNodeMergeSingle> = BinaryHeap::new();
            to_see.push(ClusteringNodeMergeSingle::empty());
            // The heap pops the cheapest node first, so the first complete node popped is
            // the cheapest complete node that survived the pruning.
            while let Some(clustering_node) = to_see.pop() {
                if clustering_node.clusters.len() == k
                    && clustering_node.next_to_add == num_points
                {
                    break 'search (
                        clustering_node.cost,
                        clustering_node.clusters.into_iter().collect(),
                    );
                }
                if clustering_node.next_to_add < num_points {
                    for new_clustering_node in clustering_node.get_next_nodes(self, k) {
                        if new_clustering_node.cost < min_cost {
                            // A cheaper complete clustering tightens the pruning bound.
                            if new_clustering_node.clusters.len() == k
                                && new_clustering_node.next_to_add == num_points
                            {
                                min_cost = new_clustering_node.cost;
                            }
                            to_see.push(new_clustering_node);
                        }
                    }
                }
            }
            (approximate_cost, approximate_clustering)
        };
        results.push(level);
    }
    results
}
/// Returns the smallest achievable maximum ratio between a hierarchy's level costs and the
/// per-level optimal costs, together with a hierarchy attaining it.
///
/// The search runs `dijkstra` over merge steps, starting from all singletons, with
/// [`MaxRatio`] as the path cost. Merges whose ratio is not strictly below the greedy
/// hierarchy's price are discarded. If no path survives, the greedy result is returned.
/// The returned hierarchy starts with the empty clustering, followed by the levels from one
/// cluster up to the singletons.
///
/// NOTE(review): `optimal_clusterings` is evaluated here and again inside
/// `price_of_greedy`.
#[must_use]
#[inline]
fn price_of_hierarchy(&mut self) -> (f64, Vec<Clustering>) {
    let num_points = self.num_points();
    let opt_for_fixed_k: Vec<f64> = self
        .optimal_clusterings()
        .into_iter()
        .map(|(cost, _)| cost)
        .collect();
    let (price_of_greedy, greedy_hierarchy) = self.price_of_greedy();
    // Pruning bound. It is copied into the `move` closure below on every expansion and is
    // never tightened: the ratio of the final merge alone is not the price of the whole
    // hierarchy, so it must not replace this bound.
    let greedy_bound = MaxRatio(price_of_greedy);
    let initial_clustering = ClusteringNodeMergeMultiple::new_singletons(num_points);
    dijkstra(
        &initial_clustering,
        |clustering| {
            // A merge lowers the number of clusters by one, so compare against the optimum
            // for one cluster fewer. Checked access, because `opt_for_fixed_k` comes from
            // trait methods that implementors may override.
            let opt_cost = *opt_for_fixed_k
                .get(clustering.clusters.len() - 1)
                .expect("opt_for_fixed_k should have an entry for this number of clusters.");
            clustering
                .get_all_merges(self)
                .into_iter()
                .filter_map(move |new_clustering| {
                    let ratio = MaxRatio::new(new_clustering.cost, opt_cost);
                    (ratio < greedy_bound).then_some((new_clustering, ratio))
                })
        },
        |clustering| clustering.clusters.len() == 1,
    )
    .map_or_else(
        || (price_of_greedy, greedy_hierarchy),
        |(path, cost)| {
            (
                cost.0,
                iter::once(Clustering::default())
                    .chain(
                        path.into_iter()
                            .rev()
                            .map(ClusteringNodeMergeMultiple::into_clustering),
                    )
                    .collect(),
            )
        },
    )
}
/// Returns the hierarchy built by always applying the cheapest merge of two clusters.
///
/// Index 0 of the result is the empty clustering with cost `0.0`; the remaining entries run
/// from the last clustering produced back to the singletons, each with its cost.
#[must_use]
#[inline]
fn greedy_hierarchy(&mut self) -> Vec<(f64, Clustering)> {
    let mut current = ClusteringNodeMergeMultiple::new_singletons(self.num_points());
    let mut levels: Vec<(f64, Clustering)> = vec![(0.0, current.clone().into_clustering())];
    loop {
        if current.clusters.len() <= 1 {
            break;
        }
        let cheapest = current
            .get_all_merges(self)
            .into_iter()
            .min_by(|lhs, rhs| lhs.cost.total_cmp(&rhs.cost))
            .expect("There should always be a possible merge");
        levels.push((cheapest.cost, cheapest.clone().into_clustering()));
        current = cheapest;
    }
    // Levels were collected from many clusters down to one; add the empty level and flip.
    levels.push((0.0, Clustering::default()));
    levels.reverse();
    levels
}
/// Returns the largest ratio between a greedy level's cost and the optimal cost for the same
/// number of clusters, together with the greedy hierarchy itself.
///
/// The empty level at index 0 is skipped.
#[must_use]
#[inline]
fn price_of_greedy(&mut self) -> (f64, Vec<Clustering>) {
    let greedy_levels = self.greedy_hierarchy();
    let opt_for_fixed_k: Vec<f64> = self
        .optimal_clusterings()
        .into_iter()
        .map(|(cost, _)| cost)
        .collect();
    let worst_ratio = greedy_levels.iter().skip(1).fold(
        MaxRatio::zero(),
        |worst_so_far, (cost, clustering)| {
            let opt_cost = opt_for_fixed_k
                .get(clustering.len())
                .expect("opt_for_fixed_k should have an entry for this number of clusters.");
            worst_so_far + MaxRatio::new(*cost, *opt_cost)
        },
    );
    let hierarchy = greedy_levels
        .into_iter()
        .map(|(_, clustering)| clustering)
        .collect();
    (worst_ratio.0, hierarchy)
}
}
/// k-median instance: the cost of a cluster is the smallest sum of distances from one of its
/// points (the centre candidate) to all of its points.
#[derive(Clone, Debug)]
pub struct KMedian {
    /// Pairwise distances, indexed as `[centre candidate][point]`.
    distances: Distances,
    /// Cache of already computed cluster costs.
    costs: Costs,
}

impl KMedian {
    /// Builds an instance whose distance is the sum of squared coordinate differences.
    ///
    /// Fails with the error reported by `verify_points` if `points` is rejected.
    #[inline]
    pub fn l2_squared(points: &[Point]) -> Result<Self, Error> {
        let verified = verify_points(points)?;
        let distances = distances_from_points_with_element_norm(verified, |x| x.powi(2));
        Ok(Self {
            distances,
            costs: Costs::default(),
        })
    }

    /// Builds an instance whose distance is the square root of the sum of squared coordinate
    /// differences.
    ///
    /// Fails with the error reported by `verify_points` if `points` is rejected.
    #[inline]
    pub fn l2(points: &[Point]) -> Result<Self, Error> {
        let verified = verify_points(points)?;
        let squared = distances_from_points_with_element_norm(verified, |x| x.powi(2));
        let distances = squared
            .into_iter()
            .map(|row| row.into_iter().map(f64::sqrt).collect())
            .collect();
        Ok(Self {
            distances,
            costs: Costs::default(),
        })
    }

    /// Builds an instance whose distance is the sum of absolute coordinate differences.
    ///
    /// Fails with the error reported by `verify_points` if `points` is rejected.
    #[inline]
    pub fn l1(points: &[Point]) -> Result<Self, Error> {
        let verified = verify_points(points)?;
        let distances = distances_from_points_with_element_norm(verified, f64::abs);
        Ok(Self {
            distances,
            costs: Costs::default(),
        })
    }

    /// Weighted variant of [`KMedian::l2_squared`].
    ///
    /// Fails with the error reported by `verify_weighted_points` if the input is rejected.
    #[inline]
    pub fn weighted_l2_squared(weighted_points: &[WeightedPoint]) -> Result<Self, Error> {
        let verified = verify_weighted_points(weighted_points)?;
        let distances =
            distances_from_weighted_points_with_element_norm(verified, |x| x.powi(2));
        Ok(Self {
            distances,
            costs: Costs::default(),
        })
    }

    /// Weighted variant of [`KMedian::l2`]; the square root is taken of the weighted squared
    /// distance.
    ///
    /// Fails with the error reported by `verify_weighted_points` if the input is rejected.
    #[inline]
    pub fn weighted_l2(weighted_points: &[WeightedPoint]) -> Result<Self, Error> {
        let verified = verify_weighted_points(weighted_points)?;
        let squared = distances_from_weighted_points_with_element_norm(verified, |x| x.powi(2));
        let distances = squared
            .into_iter()
            .map(|row| row.into_iter().map(f64::sqrt).collect())
            .collect();
        Ok(Self {
            distances,
            costs: Costs::default(),
        })
    }

    /// Weighted variant of [`KMedian::l1`].
    ///
    /// Fails with the error reported by `verify_weighted_points` if the input is rejected.
    #[inline]
    pub fn weighted_l1(weighted_points: &[WeightedPoint]) -> Result<Self, Error> {
        let verified = verify_weighted_points(weighted_points)?;
        let distances = distances_from_weighted_points_with_element_norm(verified, f64::abs);
        Ok(Self {
            distances,
            costs: Costs::default(),
        })
    }
}
impl Cost for KMedian {
    /// Returns the number of points, i.e. the number of rows of the distance matrix.
    #[inline]
    fn num_points(&self) -> usize {
        self.distances.len()
    }

    /// Returns the k-median cost of `cluster`, computing and caching it on first use.
    ///
    /// Every point of the cluster is tried as the centre; the cost is the smallest sum of
    /// distances from that centre to all points of the cluster. The empty cluster costs `0.0`.
    ///
    /// Panics if `cluster` contains a point index that is not below [`Cost::num_points`].
    #[inline]
    fn cost(&mut self, cluster: Cluster) -> f64 {
        let distances = &self.distances;
        *self.costs.entry(cluster).or_insert_with(|| {
            // Clusters can be built from arbitrary indices through the public API, so
            // validate once before the unchecked accesses below.
            assert!(
                cluster.iter().all(|ix| ix < distances.len()),
                "cluster {cluster} contains a point index outside of 0..{}",
                distances.len()
            );
            cluster
                .iter()
                .map(|center_candidate_ix| {
                    // SAFETY: every index of `cluster` was checked against
                    // `distances.len()` above.
                    let center_candidate_row =
                        unsafe { distances.get_unchecked(center_candidate_ix) };
                    cluster
                        .iter()
                        // SAFETY: every constructor builds `distances` as a square matrix,
                        // so each row is as long as `distances`, and all indices of
                        // `cluster` were checked against that length above.
                        .map(|ix| *unsafe { center_candidate_row.get_unchecked(ix) })
                        .sum::<f64>()
                })
                .min_by(f64::total_cmp)
                .unwrap_or(0.0)
        })
    }
}
fn distances_from_points_with_distance_function<T>(
points: &[T],
distance_function: impl Fn(&T, &T) -> f64,
) -> Distances {
points
.iter()
.map(|p| points.iter().map(|q| distance_function(p, q)).collect())
.collect()
}
/// Builds the distance matrix in which the distance of two points is the sum of `elementnorm`
/// applied to each coordinate difference.
fn distances_from_points_with_element_norm(
    points: &[Point],
    elementnorm: impl Fn(f64) -> f64,
) -> Distances {
    let distance = |p: &Point, q: &Point| {
        let difference = p - q;
        difference.map(|x| elementnorm(*x)).sum()
    };
    distances_from_points_with_distance_function(points, distance)
}
/// Weighted distance matrix: entry `[p][q]` is the weight of `q` times the sum of
/// `elementnorm` applied to each coordinate difference of `p` and `q`.
fn distances_from_weighted_points_with_element_norm(
    points: &[WeightedPoint],
    elementnorm: impl Fn(f64) -> f64,
) -> Distances {
    let distance = |p: &WeightedPoint, q: &WeightedPoint| {
        let difference = &p.1 - &q.1;
        q.0 * difference.map(|x| elementnorm(*x)).sum()
    };
    distances_from_points_with_distance_function(points, distance)
}
/// Reasons why a set of (weighted) points is rejected.
#[derive(Debug, PartialEq, Eq)]
#[expect(
    clippy::exhaustive_enums,
    reason = "Extending this enum should be a breaking change."
)]
pub enum Error {
    /// No points were supplied.
    EmptyPoints,
    /// More than `MAX_POINT_COUNT` points were supplied; carries the supplied count.
    TooManyPoints(usize),
    /// The points at the two carried indices have different dimensions.
    ShapeMismatch(usize, usize),
    /// The weight of the point at the carried index is not finite and positive.
    BadWeight(usize),
}

impl fmt::Display for Error {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            Self::EmptyPoints => f.write_str("no points supplied"),
            Self::TooManyPoints(pointcount) => write!(
                f,
                "can cluster at most {MAX_POINT_COUNT} points, but got {pointcount}"
            ),
            Self::ShapeMismatch(ix1, ix2) => {
                write!(f, "points {ix1} and {ix2} have different dimensions")
            }
            Self::BadWeight(ix) => {
                write!(f, "point {ix} doesn't have a finite and positive weight")
            }
        }
    }
}

#[expect(
    clippy::absolute_paths,
    reason = "Not worth bringing into scope for one use."
)]
impl core::error::Error for Error {}
fn verify_points(points: &[Point]) -> Result<&[Point], Error> {
let point_count = points.len();
if point_count > MAX_POINT_COUNT {
return Err(Error::TooManyPoints(point_count));
}
let first_point = points.first().ok_or(Error::EmptyPoints)?;
let first_dim = first_point.raw_dim();
if let Some(ix) = points.iter().position(|p| p.raw_dim() != first_dim) {
return Err(Error::ShapeMismatch(0, ix));
}
Ok(points)
}
fn verify_weighted_points(weighted_points: &[WeightedPoint]) -> Result<&[WeightedPoint], Error> {
let point_count = weighted_points.len();
if point_count > MAX_POINT_COUNT {
return Err(Error::TooManyPoints(point_count));
}
let first_point = weighted_points.first().ok_or(Error::EmptyPoints)?;
let first_dim = first_point.1.raw_dim();
if let Some(ix) = weighted_points
.iter()
.position(|p| p.1.raw_dim() != first_dim)
{
return Err(Error::ShapeMismatch(0, ix));
}
if let Some(ix) = weighted_points
.iter()
.position(|p| !p.0.is_finite() || p.0 <= 0.0)
{
return Err(Error::BadWeight(ix));
}
Ok(weighted_points)
}
/// k-means instance: the cost of a cluster is the sum of squared distances of its points to
/// their mean.
#[derive(Clone, Debug)]
pub struct KMeans {
// The points, as accepted by `verify_points` in `KMeans::new`.
points: Vec<Point>,
// Cache of already computed cluster costs.
costs: Costs,
}
impl Cost for KMeans {
    /// Returns the number of points of the instance.
    #[inline]
    fn num_points(&self) -> usize {
        self.points.len()
    }

    /// Returns the k-means cost of `cluster`, computing and caching it on first use.
    ///
    /// The cost is the sum of squared coordinate differences between each point of the
    /// cluster and the mean of the cluster's points.
    ///
    /// Panics if `cluster` contains a point index that is not below [`Cost::num_points`].
    #[inline]
    fn cost(&mut self, cluster: Cluster) -> f64 {
        let points = &self.points;
        *self.costs.entry(cluster).or_insert_with(|| {
            // Clusters can be built from arbitrary indices through the public API, so
            // validate once before the unchecked accesses below.
            assert!(
                cluster.iter().all(|ix| ix < points.len()),
                "cluster {cluster} contains a point index outside of 0..{}",
                points.len()
            );
            // SAFETY: `KMeans::new` is the only constructor and `verify_points` rejects
            // empty input, so there is always a first point.
            let first_point_dimensions = unsafe { points.first().unwrap_unchecked() }.raw_dim();
            let mut center = Array1::zeros(first_point_dimensions);
            for i in cluster {
                // SAFETY: every index of `cluster` was checked against `points.len()` above.
                center += unsafe { points.get_unchecked(i) };
            }
            center /= f64::from(cluster.len());
            cluster
                .iter()
                .map(|i| {
                    // SAFETY: every index of `cluster` was checked against `points.len()`
                    // above.
                    let p = unsafe { points.get_unchecked(i) };
                    (p - &center).map(|x| x.powi(2)).sum()
                })
                .sum()
        })
    }

    /// Returns, for every `k` in `0..=num_points`, the clustering that the `clustering`
    /// crate's `kmeans` produces for `k` clusters, together with its total cost.
    ///
    /// Index 0 holds the empty clustering with cost `0.0`.
    ///
    /// NOTE(review): if `kmeans` leaves a cluster without points, the resulting set can hold
    /// an empty cluster or fewer than `k` clusters — confirm against the crate's behaviour.
    #[inline]
    fn approximate_clusterings(&mut self) -> Vec<(f64, Clustering)> {
        use clustering::kmeans;
        let mut results = Vec::with_capacity(self.num_points() + 1);
        results.push((0.0, Clustering::default()));
        // Passed to `kmeans` as its third argument; presumably the iteration limit.
        let max_iter = 1000;
        let samples: Vec<Vec<f64>> = self
            .points
            .iter()
            .map(|x| x.into_iter().copied().collect())
            .collect();
        results.extend((1..=self.num_points()).map(|k| {
            let kmeans_clustering = kmeans(k, &samples, max_iter);
            let mut clusters = vec![Cluster::new(); k];
            // `membership[point_ix]` is the index of the cluster the point was assigned to.
            for (point_ix, cluster_ix) in kmeans_clustering.membership.iter().enumerate() {
                clusters
                    .get_mut(*cluster_ix)
                    .expect("Cluster index out of range")
                    .insert(point_ix);
            }
            let clustering: Clustering = clusters.into_iter().collect();
            (self.total_cost(&clustering), clustering)
        }));
        results
    }
}
impl KMeans {
    /// Creates a k-means instance from a copy of `points`.
    ///
    /// Fails with the error reported by `verify_points` if `points` is rejected.
    #[inline]
    pub fn new(points: &[Point]) -> Result<Self, Error> {
        verify_points(points).map(|verified| Self {
            points: verified.to_vec(),
            costs: Costs::default(),
        })
    }
}
/// Weighted k-means instance: the cost of a cluster is the weighted sum of squared distances
/// of its points to their weighted mean.
#[derive(Clone, Debug)]
pub struct WeightedKMeans {
// The `(weight, coordinates)` pairs, as accepted by `verify_weighted_points` in
// `WeightedKMeans::new`.
weighted_points: Vec<WeightedPoint>,
// Cache of already computed cluster costs.
costs: Costs,
}
impl Cost for WeightedKMeans {
    /// Returns the number of points of the instance.
    #[inline]
    fn num_points(&self) -> usize {
        self.weighted_points.len()
    }

    /// Returns the weighted k-means cost of `cluster`, computing and caching it on first use.
    ///
    /// The centre is the weighted mean of the cluster's points. The cost is the sum over the
    /// cluster's points of the weight times the sum of squared coordinate differences to the
    /// centre.
    ///
    /// Panics if `cluster` contains a point index that is not below [`Cost::num_points`].
    #[inline]
    fn cost(&mut self, cluster: Cluster) -> f64 {
        let weighted_points = &self.weighted_points;
        *self.costs.entry(cluster).or_insert_with(|| {
            // Clusters can be built from arbitrary indices through the public API, so
            // validate once before the unchecked accesses below.
            assert!(
                cluster.iter().all(|ix| ix < weighted_points.len()),
                "cluster {cluster} contains a point index outside of 0..{}",
                weighted_points.len()
            );
            let mut total_weight = 0.0;
            // SAFETY: `WeightedKMeans::new` is the only constructor and
            // `verify_weighted_points` rejects empty input, so there is always a first point.
            let first_point_dimensions = unsafe { weighted_points.first().unwrap_unchecked() }
                .1
                .raw_dim();
            let mut center: Array1<f64> = Array1::zeros(first_point_dimensions);
            for i in cluster {
                // SAFETY: every index of `cluster` was checked against
                // `weighted_points.len()` above.
                let weighted_point = unsafe { weighted_points.get_unchecked(i) };
                total_weight += weighted_point.0;
                center += &(&weighted_point.1 * weighted_point.0);
            }
            center /= total_weight;
            cluster
                .iter()
                .map(|i| {
                    // SAFETY: every index of `cluster` was checked against
                    // `weighted_points.len()` above.
                    let weighted_point = unsafe { weighted_points.get_unchecked(i) };
                    weighted_point.0 * (&weighted_point.1 - &center).map(|x| x.powi(2)).sum()
                })
                .sum()
        })
    }
}
impl WeightedKMeans {
    /// Creates a weighted k-means instance from a copy of `weighted_points`.
    ///
    /// Fails with the error reported by `verify_weighted_points` if the input is rejected.
    #[inline]
    pub fn new(weighted_points: &[WeightedPoint]) -> Result<Self, Error> {
        verify_weighted_points(weighted_points).map(|verified| Self {
            weighted_points: verified.to_vec(),
            costs: Costs::default(),
        })
    }
}
/// Builds the cluster containing exactly the point indices yielded by `it`.
///
/// Each index is added with `Cluster::insert`, which debug-asserts that no index occurs twice.
#[inline]
pub fn cluster_from_iterator<I: IntoIterator<Item = usize>>(it: I) -> Cluster {
    it.into_iter().fold(Cluster::new(), |mut cluster, point_ix| {
        cluster.insert(point_ix);
        cluster
    })
}
#[cfg(test)]
mod tests {
use super::*;
use core::f64::consts::SQRT_2;
use itertools::Itertools as _;
use ndarray::array;
use smallvec::smallvec;
use std::panic::catch_unwind;
/// Inserting a point that is already a member must trip the debug assertion of
/// `Cluster::insert`.
#[test]
#[should_panic(
    expected = "Throughout the entire implementation, we should never to add the same point twice."
)]
fn cluster_double_insert() {
    let point_ix = 7;
    let mut cluster = Cluster::singleton(point_ix);
    cluster.insert(point_ix);
}
/// Merging two clusters that share a point must trip the debug assertion of
/// `Cluster::union_with`.
#[test]
#[should_panic(
    expected = "Troughout the entire implementation, we should never be merging intersecting clusters."
)]
fn cluster_intersecting_merge() {
    // `cluster7` is {7, 8} and `cluster9` is {8, 9}: they overlap in point 8 only.
    let mut cluster7 = Cluster::singleton(7);
    let mut cluster9 = Cluster::singleton(9);
    cluster7.insert(8);
    cluster9.insert(8);
    cluster7.union_with(cluster9);
}
// Exercises `singleton`, `insert`, `len`, `is_empty`, `contains`, `iter`, `union_with` and
// the `Display` output of `Cluster`.
#[test]
fn cluster() {
// Every singleton, and every singleton extended by one other point.
for i in 0..8 {
let cluster = Cluster::singleton(i);
assert!(!cluster.is_empty());
assert_eq!(cluster.len(), 1);
assert_eq!(cluster.iter().collect_vec(), vec![i]);
for j in 0..8 {
assert_eq!(cluster.contains(j), j == i);
let cluster2 = {
let mut cluster2 = cluster;
// Inserting `i` again would trip the debug assertion of `insert`.
if i != j {
cluster2.insert(j);
}
assert!(!cluster2.is_empty());
cluster2
};
assert!(!cluster2.is_empty());
assert_eq!(cluster2.len(), if i == j { 1 } else { 2 });
// Iteration yields the indices in ascending order.
assert_eq!(
cluster2.iter().collect_vec(),
match i.cmp(&j) {
cmp::Ordering::Less => vec![i, j],
cmp::Ordering::Equal => vec![i],
cmp::Ordering::Greater => vec![j, i],
}
);
}
}
// Two disjoint clusters: the multiples of 3 and the multiples of 5 in 1..=14.
let mut cluster_div_3 = Cluster::new();
let mut cluster_div_5 = Cluster::new();
assert!(cluster_div_3.is_empty());
assert!(cluster_div_5.is_empty());
for i in 1..=14 {
if i % 3 == 0 {
cluster_div_3.insert(i);
assert!(!cluster_div_3.is_empty());
}
if i % 5 == 0 {
cluster_div_5.insert(i);
assert!(!cluster_div_5.is_empty());
}
}
assert_eq!(cluster_div_3.iter().collect_vec(), vec![3, 6, 9, 12]);
assert_eq!(cluster_div_5.iter().collect_vec(), vec![5, 10]);
let merged = {
let mut merged = cluster_div_3;
merged.union_with(cluster_div_5);
merged
};
assert_eq!(merged.iter().collect_vec(), vec![3, 5, 6, 9, 10, 12]);
// One character per bit of `Storage`, point 0 first.
assert_eq!(merged.to_string(), "...#.##..##.#...................");
}
// Checks the values computed by `MaxRatio::new` and which inputs its debug assertions reject.
#[expect(clippy::float_cmp, reason = "This should be exact.")]
#[expect(
clippy::assertions_on_result_states,
reason = "We'd like to catch the errors."
)]
#[test]
fn max_ratio() {
assert_eq!(MaxRatio::new(3.0, 1.5).0, 2.0);
assert_eq!(MaxRatio::new(SQRT_2, SQRT_2).0, 1.0);
// A zero optimum (of either sign) gives infinity for a positive cost and 1.0 for zero cost.
assert_eq!(MaxRatio::new(SQRT_2, 0.0).0, f64::INFINITY);
assert_eq!(MaxRatio::new(SQRT_2, -0.0).0, f64::INFINITY);
assert_eq!(MaxRatio::new(0.0, 0.0).0, 1.0);
assert_eq!(MaxRatio::new(-0.0, 0.0).0, 1.0);
assert_eq!(MaxRatio::new(0.0, -0.0).0, 1.0);
assert_eq!(MaxRatio::new(-0.0, -0.0).0, 1.0);
// The checks below rely on the `debug_assert!`s in `MaxRatio::new` panicking.
// A cost below the optimum is only tolerated within 1e-9.
assert!(catch_unwind(|| MaxRatio::new(1.0 - 1e-3, 1.0)).is_err());
assert!(catch_unwind(|| MaxRatio::new(1.0 - 1e-12, 1.0)).is_ok());
assert!(catch_unwind(|| MaxRatio::new(0.0 - 1e-12, 0.0)).is_err());
// Non-finite arguments are rejected in either position.
assert!(catch_unwind(|| MaxRatio::new(f64::INFINITY, 1.0)).is_err());
assert!(catch_unwind(|| MaxRatio::new(f64::NAN, 1.0)).is_err());
assert!(catch_unwind(|| MaxRatio::new(f64::NEG_INFINITY, 1.0)).is_err());
assert!(catch_unwind(|| MaxRatio::new(1.0, f64::INFINITY)).is_err());
assert!(catch_unwind(|| MaxRatio::new(1.0, f64::NAN)).is_err());
assert!(catch_unwind(|| MaxRatio::new(1.0, f64::NEG_INFINITY)).is_err());
assert!(catch_unwind(|| MaxRatio::new(1.0, 0.0)).is_ok());
assert!(catch_unwind(|| MaxRatio::new(1.0, -1e-12)).is_err());
}
// Builds an array of clusterings from nested lists of point indices.
// `clusterings![[[0, 1], [2]], [[0], [1, 2]]]` expands to an array with one `Vec<Cluster>`
// per outer list, where every innermost list is turned into a `Cluster` by
// `cluster_from_iterator`.
macro_rules! clusterings {
( $( [ $( [ $( $num:expr ),* ] ),* ] ),* $(,)? ) => {
[
$(
vec![
$(
cluster_from_iterator([$( $num ),*]),
)*
],
)*
]
}
}
// Pins the exact clusterings, and their order, that repeated calls to `get_all_merges`
// produce when starting from four singletons.
#[test]
fn node_merge_multiple() {
// Compares only the clusters of `nodes` against the expectation; costs are ignored.
fn clusters_are_correct(
expected_clusterings: &[Vec<Cluster>],
nodes: &[ClusteringNodeMergeMultiple],
) {
let actual = nodes.iter().map(|x| x.clusters.to_vec()).collect_vec();
assert_eq!(
expected_clusterings, actual,
"Clustering should match expected clustering. Maybe the order of returned Clusters has changed?"
);
}
let mut kmedian =
KMedian::l2_squared(&[array![0.0], array![1.0], array![2.0], array![3.0]])
.expect("Creating kmedian should not fail.");
// Replaces `nodes` by the concatenation of all merges of every node, in order.
let mut update_nodes = |nodes: &mut Vec<ClusteringNodeMergeMultiple>| {
*nodes = nodes
.iter()
.flat_map(|n| n.get_all_merges(&mut kmedian))
.collect();
};
let mut nodes = vec![ClusteringNodeMergeMultiple::new_singletons(4)];
let expected_init_clusters = smallvec![
Cluster::singleton(0),
Cluster::singleton(1),
Cluster::singleton(2),
Cluster::singleton(3)
];
// The NaN cost on the right-hand side shows that `==` on nodes ignores the cost.
assert_eq!(
nodes,
vec![ClusteringNodeMergeMultiple {
clusters: expected_init_clusters,
cost: f64::NAN,
}],
"Testing nodes for equality should only depend on clusters, not on their cost."
);
clusters_are_correct(&clusterings![[[0], [1], [2], [3]]], &nodes);
update_nodes(&mut nodes);
clusters_are_correct(
&clusterings![
[[0, 1], [2], [3]],
[[1], [0, 2], [3]],
[[1], [2], [0, 3]],
[[0], [1, 2], [3]],
[[0], [2], [1, 3]],
[[0], [1], [2, 3]],
],
&nodes,
);
update_nodes(&mut nodes);
// The same clustering shows up several times because it can be reached by different
// merge orders. The order of the indices inside an inner list does not affect the
// resulting `Cluster`.
clusters_are_correct(
&clusterings![
[[0, 1, 2], [3]],
[[2], [0, 1, 3]],
[[0, 1], [2, 3]],
[[1, 0, 2], [3]],
[[0, 2], [1, 3]],
[[1], [0, 2, 3]],
[[1, 2], [0, 3]],
[[2], [1, 0, 3]],
[[1], [2, 0, 3]],
[[0, 1, 2], [3]],
[[1, 2], [0, 3]],
[[0], [1, 2, 3]],
[[0, 2], [1, 3]],
[[2], [0, 1, 3]],
[[0], [2, 1, 3]],
[[0, 1], [2, 3]],
[[1], [0, 2, 3]],
[[0], [1, 2, 3]],
],
&nodes,
);
update_nodes(&mut nodes);
// Each of the 18 two-cluster nodes has exactly one merge: all four points (bitmask 15).
clusters_are_correct(&vec![vec![Cluster(15)]; 18], &nodes);
}
/// `get_all_merges` must trip its debug assertion when the clusters are not sorted.
#[test]
#[should_panic(expected = "The clusters should always be sorted, to prevent duplicates.")]
fn unsorted_node_merge_multiple() {
    let mut small_kmedian =
        KMedian::l1(&[array![0.0], array![1.0]]).expect("Creating kmedian should not fail.");
    let unsorted = ClusteringNodeMergeMultiple {
        clusters: smallvec![Cluster(1), Cluster(0)],
        cost: 0.0,
    };
    let _: Vec<_> = unsorted.get_all_merges(&mut small_kmedian);
}
// Pins the exact clusterings, and their order, that repeated calls to `get_next_nodes`
// with `k = 3` produce when starting from the empty node.
#[test]
fn node_merge_single() {
// Compares only the clusters of `nodes` against the expectation; costs are ignored.
fn clusters_are_correct(
expected_clusterings: &[Vec<Cluster>],
nodes: &[ClusteringNodeMergeSingle],
) {
let actual = nodes.iter().map(|x| x.clusters.to_vec()).collect_vec();
assert_eq!(
expected_clusterings, actual,
"Clustering should match expected clustering. Maybe the order of returned Clusters has changed?"
);
}
let mut kmedian =
KMedian::l2_squared(&[array![0.0], array![1.0], array![2.0], array![3.0]])
.expect("Creating kmedian should not fail.");
// Replaces `nodes` by the concatenation of all successors of every node, in order.
let mut update_nodes = |nodes: &mut Vec<ClusteringNodeMergeSingle>| {
*nodes = nodes
.iter()
.flat_map(|n| n.get_next_nodes(&mut kmedian, 3).collect_vec())
.collect();
};
let mut nodes = vec![ClusteringNodeMergeSingle::empty()];
clusters_are_correct(&clusterings![[]], &nodes);
update_nodes(&mut nodes);
clusters_are_correct(&clusterings![[[0]]], &nodes);
update_nodes(&mut nodes);
clusters_are_correct(&clusterings![[[0, 1]], [[0], [1]]], &nodes);
update_nodes(&mut nodes);
clusters_are_correct(
&clusterings![
[[0, 1, 2]],
[[0, 1], [2]],
[[0, 2], [1]],
[[0], [1, 2]],
[[0], [1], [2]],
],
&nodes,
);
update_nodes(&mut nodes);
// With `k = 3` no node gets a fourth cluster, so `[[0], [1], [2], [3]]` is absent.
clusters_are_correct(
&clusterings![
[[0, 1, 2, 3]],
[[0, 1, 2], [3]],
[[0, 1, 3], [2]],
[[0, 1], [2, 3]],
[[0, 1], [2], [3]],
[[0, 2, 3], [1]],
[[0, 2], [1, 3]],
[[0, 2], [1], [3]],
[[0, 3], [1, 2]],
[[0], [1, 2, 3]],
[[0], [1, 2], [3]],
[[0, 3], [1], [2]],
[[0], [1, 3], [2]],
[[0], [1], [2, 3]],
],
&nodes,
);
}
// Runs `optimise_locally` on a fixed weighted instance. There are no assertions: the test
// passes as soon as the call returns. Judging by the name, this input presumably made an
// earlier version loop forever.
#[test]
fn infinite_loop_optimise_locally() {
let (weight_a, point_a) = (0.588_906_661, array![-0.487_778_761_130_834]);
let (weight_b, point_b) = (0.434_371_596, array![-0.438_191_407_837_575]);
// Points a and b, mirrored around a point of weight 1.0 at the origin.
let points = [
(weight_a, -point_a.clone()),
(weight_b, -point_b.clone()),
(1.0, array![0.0]),
(weight_a, point_a),
(weight_b, point_b),
];
let mut kmedian = KMedian::weighted_l1(&points).expect("Creating kmedian should not fail.");
let mut clustering = ClusteringNodeMergeMultiple {
clusters: SmallVec::from_iter([
cluster_from_iterator([0, 1, 2]),
cluster_from_iterator([3, 4]),
]),
cost: 0.488_933_068_284_744_25,
};
clustering.optimise_locally(&mut kmedian);
}
// Runs `optimise_locally` on a fixed weighted instance whose weights span many orders of
// magnitude. There are no assertions: the test passes as soon as the call returns. Judging
// by the name, this input presumably made an earlier version loop forever.
#[test]
fn infinite_loop_optimise_locally_1() {
let points = vec![
(1.870_423_609_633_216e24, array![1000.0, -1000.0, 1000.0]),
(3.817_589_201_683_946e23, array![1000.0, 1000.0, -1000.0]),
(2.074_998_884_450_784_5e21, array![1000.0, 1000.0, 1000.0]),
(
1.0,
array![
-400.240_609_956_200_4,
616.506_453_035_030_1,
-79.475_319_067_602_64
],
),
(1.0, array![-1000.0, 415.010_128_673_398_5, 1000.0]),
];
let mut kmedian = KMedian::weighted_l1(&points).expect("Creating kmedian should not fail.");
let mut clustering = ClusteringNodeMergeMultiple {
clusters: SmallVec::from_iter([
cluster_from_iterator([0, 2, 4]),
cluster_from_iterator([1, 3]),
]),
cost: 4.149_997_768_901_569e24,
};
clustering.optimise_locally(&mut kmedian);
}
}