#![expect(
    clippy::missing_errors_doc,
    reason = "The Error-Enum is sparse and documented."
)]

use core::hash::{self, Hash};
use core::{cmp, ops};
use core::{f64, fmt, iter};
use ndarray::Array1;
use pathfinding::{num_traits::Zero, prelude::dijkstra};
use rustc_hash::{FxHashMap, FxHashSet};
use smallvec::SmallVec;
use std::collections::BinaryHeap;

/// The unsigned integer type that backs a [`Cluster`] bit set.
#[cfg(not(target_pointer_width = "16"))]
pub type Storage = u32;

/// The maximum number of points that can be clustered: one per bit of [`Storage`].
#[expect(
    clippy::as_conversions,
    reason = "`Storage::BITS` will always fit into a `usize`."
)]
pub const MAX_POINT_COUNT: usize = Storage::BITS as usize;

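/// A set of points, stored as a bit set: bit `i` of the underlying
/// [`Storage`] is set iff point `i` is in the cluster, so membership,
/// size and union are single bitwise operations.
///
/// A minimal usage sketch via [`cluster_from_iterator`] (doctest ignored,
/// since the crate path is assumed):
///
/// ```ignore
/// let cluster = cluster_from_iterator([1, 4]);
/// assert!(cluster.contains(4));
/// assert!(!cluster.contains(0));
/// assert_eq!(cluster.len(), 2);
/// assert_eq!(cluster.iter().collect::<Vec<_>>(), vec![1, 4]);
/// ```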
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub struct Cluster(Storage);
impl Cluster {
    /// The empty cluster.
    const fn new() -> Self {
        Self(0)
    }

    /// The cluster containing exactly the point `point_ix`.
    const fn singleton(point_ix: usize) -> Self {
        Self(1 << point_ix)
    }

    /// Adds the point `point_ix`, which must not be present yet.
    fn insert(&mut self, point_ix: usize) {
        let point = 1 << point_ix;
        debug_assert!(
            (point & self.0) == 0,
            "Throughout the entire implementation, we should never add the same point twice."
        );
        self.0 |= point;
    }

    /// Removes the point `point_ix`, which must be present.
    fn remove(&mut self, point_ix: usize) {
        let point = 1 << point_ix;
        debug_assert!(
            (point & self.0) != 0,
            "Throughout the entire implementation, we should never remove a non-existing point."
        );
        self.0 &= !point;
    }

    /// Whether the point `point_ix` is in the cluster.
    #[must_use]
    #[inline]
    pub const fn contains(self, point_ix: usize) -> bool {
        (self.0 & (1 << point_ix)) != 0
    }

    /// The number of points in the cluster.
    #[must_use]
    #[inline]
    pub const fn len(self) -> Storage {
        self.0.count_ones()
    }

    /// Whether the cluster contains no points.
    #[must_use]
    #[inline]
    pub const fn is_empty(self) -> bool {
        self.0 == 0
    }

    /// Iterates over the point indices in the cluster, in ascending order.
    #[inline]
    #[must_use]
    pub const fn iter(self) -> ClusterIter {
        ClusterIter(self.0)
    }

    /// Merges `other` into `self`; the two clusters must be disjoint.
    fn union_with(&mut self, other: Self) {
        debug_assert!(
            self.0 & other.0 == 0,
            "Throughout the entire implementation, we should never be merging intersecting clusters."
        );
        self.0 |= other.0;
    }
}

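/// Iterates over the point indices of a [`Cluster`], in ascending order.
///
/// Each `next` extracts the lowest set bit via `trailing_zeros` and then
/// clears it with `self.0 &= self.0 - 1`.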
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct ClusterIter(Storage);
impl Iterator for ClusterIter {
    type Item = usize;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.0 == 0 {
            None
        } else {
            #[expect(
                clippy::as_conversions,
                reason = "I assume `usize` is at least `Storage`."
            )]
            let ix = self.0.trailing_zeros() as usize;
            // Clear the lowest set bit.
            self.0 &= self.0 - 1;
            Some(ix)
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        #[expect(
            clippy::as_conversions,
            reason = "I assume `usize` is at least `Storage`."
        )]
        let count = self.0.count_ones() as usize;
        (count, Some(count))
    }
}

impl IntoIterator for Cluster {
    type Item = usize;
    type IntoIter = ClusterIter;

    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        ClusterIter(self.0)
    }
}

/// Renders one character per possible point: `#` if present, `.` if not.
impl fmt::Display for Cluster {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        #[expect(
            clippy::as_conversions,
            reason = "I assume `usize` is at least `Storage`."
        )]
        let mut result = String::with_capacity(Storage::BITS as usize);
        let mut bits = self.0;
        for _ in 0..Storage::BITS {
            if (bits & 1) == 1 {
                result.push('#');
            } else {
                result.push('.');
            }
            bits >>= 1;
        }
        write!(f, "{result}")
    }
}

/// A partition of the points into disjoint [`Cluster`]s.
pub type Clustering = FxHashSet<Cluster>;

/// A distance matrix: `distances[i][j]` is the cost of point `j` when point `i` is the chosen center.
type Distances = Vec<Vec<f64>>;

/// A point in d-dimensional space.
pub type Point = Array1<f64>;
/// A point together with its positive, finite weight.
pub type WeightedPoint = (f64, Array1<f64>);

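/// A search node for hierarchies built by merging whole clusters: a complete
/// clustering of all points, together with its cached total cost.
///
/// Equality and hashing deliberately ignore `cost`, so search states are
/// deduplicated purely by their clusters.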
#[derive(Clone, Debug)]
struct ClusteringNodeMergeMultiple {
    /// The clusters, kept sorted to prevent duplicate search states.
    clusters: SmallVec<[Cluster; 6]>,
    /// The cached total cost of `clusters`; ignored by `Eq` and `Hash`.
    cost: f64,
}
impl PartialEq for ClusteringNodeMergeMultiple {
    fn eq(&self, other: &Self) -> bool {
        self.clusters == other.clusters
    }
}
impl Eq for ClusteringNodeMergeMultiple {}
impl Hash for ClusteringNodeMergeMultiple {
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        self.clusters.hash(state);
    }
}
impl ClusteringNodeMergeMultiple {
    /// Returns every clustering reachable by merging two of the current clusters.
    #[must_use]
    #[inline]
    fn get_all_merges<C: Cost + ?Sized>(&self, data: &mut C) -> Vec<Self> {
        debug_assert!(
            self.clusters.is_sorted(),
            "The clusters should always be sorted, to prevent duplicates."
        );

        #[expect(
            clippy::integer_division,
            reason = "At least one of the factors is always even."
        )]
        let mut nodes = Vec::with_capacity(self.clusters.len() * (self.clusters.len() - 1) / 2);
        for i in 0..(self.clusters.len() - 1) {
            let (cluster_i, clusters_minus_i) = {
                let mut clusters_minus_i = self.clusters.clone();
                let cluster_i = clusters_minus_i.remove(i);
                (cluster_i, clusters_minus_i)
            };
            let cost_minus_i = self.cost - data.cost(cluster_i);
            nodes.extend((i..clusters_minus_i.len()).map(|j| {
                let mut new_clusters = clusters_minus_i.clone();
                // SAFETY: `j < clusters_minus_i.len() == new_clusters.len()`.
                let cluster_j = unsafe { new_clusters.get_unchecked_mut(j) };
                let mut new_cost = cost_minus_i - data.cost(*cluster_j);
                cluster_j.union_with(cluster_i);
                new_cost += data.cost(*cluster_j);

                debug_assert!(
                    new_clusters.len() == self.clusters.len() - 1,
                    "We should have merged two clusters, which should have reduced the number of clusters by exactly one."
                );
                debug_assert!(
                    new_clusters.is_sorted(),
                    "The clusters should always be sorted, to prevent duplicates."
                );
                debug_assert!(
                    (0..data.num_points()).all(|point_ix| {
                        new_clusters
                            .iter()
                            .filter(|cluster| cluster.contains(point_ix))
                            .count()
                            == 1
                    }),
                    "The clusters should always cover every point exactly once."
                );
                Self {
                    clusters: new_clusters,
                    cost: new_cost,
                }
            }));
        }
        nodes
    }

    /// Repeatedly moves single points between clusters while doing so lowers the total cost.
    fn optimise_locally<C: Cost + ?Sized>(&mut self, data: &mut C) {
        // Guard against cycles caused by inconsistent floating-point cost deltas
        // (see the `infinite_loop_optimise_locally` regression tests).
        let mut already_visited: FxHashSet<(Cluster, usize, usize)> = FxHashSet::default();
        let mut found_improvement = || {
            #[expect(
                clippy::indexing_slicing,
                reason = "These are safe, we just use indices to avoid borrow-issues."
            )]
            for source_cluster_ix in 0..self.clusters.len() {
                let source_cluster = self.clusters[source_cluster_ix];
                for point_ix in source_cluster {
                    let mut updated_source_cluster = source_cluster;
                    updated_source_cluster.remove(point_ix);
                    let source_costdelta =
                        data.cost(updated_source_cluster) - data.cost(source_cluster);

                    for target_cluster_ix in
                        (0..self.clusters.len()).filter(|ix| *ix != source_cluster_ix)
                    {
                        if !already_visited.insert((
                            source_cluster,
                            source_cluster_ix,
                            target_cluster_ix,
                        )) {
                            continue;
                        }
                        let target_cluster = self.clusters[target_cluster_ix];

                        let mut updated_target_cluster = target_cluster;
                        updated_target_cluster.insert(point_ix);
                        let costdelta = source_costdelta + data.cost(updated_target_cluster)
                            - data.cost(target_cluster);
                        if costdelta < 0.0 {
                            // Store the two new clusters so that their relative order
                            // matches the order of their values.
                            if updated_source_cluster.cmp(&updated_target_cluster)
                                == source_cluster_ix.cmp(&target_cluster_ix)
                            {
                                self.clusters[source_cluster_ix] = updated_source_cluster;
                                self.clusters[target_cluster_ix] = updated_target_cluster;
                            } else {
                                self.clusters[source_cluster_ix] = updated_target_cluster;
                                self.clusters[target_cluster_ix] = updated_source_cluster;
                            }
                            self.cost += costdelta;
                            return true;
                        }
                    }
                }
            }
            false
        };

        while found_improvement() {}

        self.clusters.sort();

        debug_assert!(
            {
                (0..data.num_points()).all(|point_ix| {
                    self.clusters
                        .iter()
                        .filter(|cluster| cluster.contains(point_ix))
                        .count()
                        == 1
                })
            },
            "The clusters should always cover every point exactly once."
        );
    }

    /// The clustering in which every point is its own cluster, with cost `0.0`.
    #[inline]
    fn new_singletons(num_points: usize) -> Self {
        let mut clusters = SmallVec::default();
        for i in 0..num_points {
            clusters.push(Cluster::singleton(i));
        }
        debug_assert!(
            clusters.is_sorted(),
            "The clusters should always be sorted, to prevent duplicates."
        );
        Self {
            clusters,
            cost: 0.0,
        }
    }

    #[inline]
    fn into_clustering(self) -> Clustering {
        self.clusters.into_iter().collect()
    }
}

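/// A search node for the exact search in [`Cost::optimal_clusterings`]: a
/// partial clustering of the points `0..next_to_add`, grown one point at a
/// time.
///
/// The `Ord` instance is reversed on `cost`, so that a [`BinaryHeap`] of
/// these nodes pops the cheapest partial clustering first.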
#[derive(Clone, Debug)]
struct ClusteringNodeMergeSingle {
    /// The clusters built so far, covering the points `0..next_to_add`.
    clusters: SmallVec<[Cluster; 6]>,
    /// The cached total cost of `clusters`; ignored by `Eq` and `Hash`.
    cost: f64,
    /// The next point index to assign to a cluster.
    next_to_add: usize,
}
impl PartialEq for ClusteringNodeMergeSingle {
    fn eq(&self, other: &Self) -> bool {
        self.clusters == other.clusters
    }
}
impl Eq for ClusteringNodeMergeSingle {}
impl Hash for ClusteringNodeMergeSingle {
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        self.clusters.hash(state);
    }
}
impl Ord for ClusteringNodeMergeSingle {
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        // Reversed on cost, so that the max-heap `BinaryHeap` pops the cheapest node first.
        other
            .cost
            .total_cmp(&self.cost)
            .then_with(|| self.clusters.cmp(&other.clusters))
    }
}
impl PartialOrd for ClusteringNodeMergeSingle {
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl ClusteringNodeMergeSingle {
    /// Returns every node that assigns the next point: one per existing
    /// cluster, plus a new singleton cluster if fewer than `k` exist.
    #[inline]
    fn get_next_nodes<'a, C: Cost + ?Sized>(
        &'a self,
        data: &'a mut C,
        k: usize,
    ) -> impl Iterator<Item = Self> + use<'a, C> {
        (0..self.clusters.len())
            .map(|cluster_ix| {
                let mut new_clustering_node = self.clone();
                // SAFETY: `cluster_ix < self.clusters.len()`, which equals
                // `new_clustering_node.clusters.len()`.
                let cluster_to_edit =
                    unsafe { new_clustering_node.clusters.get_unchecked_mut(cluster_ix) };
                new_clustering_node.cost -= data.cost(*cluster_to_edit);
                cluster_to_edit.insert(new_clustering_node.next_to_add);
                new_clustering_node.cost += data.cost(*cluster_to_edit);
                new_clustering_node.next_to_add += 1;
                new_clustering_node
            })
            .chain((self.clusters.len() < k).then(|| {
                let mut clustering_node = self.clone();
                clustering_node
                    .clusters
                    .push(Cluster::singleton(clustering_node.next_to_add));
                clustering_node.next_to_add += 1;
                clustering_node
            }))
    }

    /// The empty partial clustering: no clusters, no assigned points, cost `0.0`.
    fn empty() -> Self {
        Self {
            clusters: SmallVec::default(),
            cost: 0.0,
            next_to_add: 0,
        }
    }
}

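/// The approximation ratio of a hierarchy: the maximum, over all levels, of
/// `clustering_cost / opt_cost` for that level.
///
/// To reuse [`dijkstra`] for minimising this maximum, `MaxRatio` implements
/// the path-cost interface with unusual semantics: "addition" is `max`, and
/// the neutral element ([`Zero::zero`]) is the ratio `1.0`, the smallest
/// ratio any clustering can achieve.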
#[derive(Debug, PartialEq, Clone, Copy)]
struct MaxRatio(f64);
impl MaxRatio {
    #[inline]
    fn new(clustering_cost: f64, opt_cost: f64) -> Self {
        debug_assert!(
            clustering_cost.is_finite(),
            "clustering_cost {clustering_cost} should be finite."
        );
        debug_assert!(
            opt_cost.is_finite(),
            "opt_cost {opt_cost} should be finite."
        );
        debug_assert!(
            opt_cost >= 0.0,
            "opt_cost {opt_cost} should be non-negative."
        );
        debug_assert!(
            clustering_cost >= 0.0,
            "clustering_cost {clustering_cost} should be non-negative."
        );
        debug_assert!(
            clustering_cost >= opt_cost - 1e-9,
            "clustering_cost {clustering_cost} should be at least opt_cost {opt_cost}"
        );
        Self(if opt_cost.is_zero() {
            if clustering_cost.is_zero() {
                1.0
            } else {
                f64::INFINITY
            }
        } else {
            clustering_cost / opt_cost
        })
    }
}
impl Eq for MaxRatio {}
impl Ord for MaxRatio {
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
impl PartialOrd for MaxRatio {
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl ops::Add for MaxRatio {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        // "Adding" two ratios along a path keeps the worse one.
        Self(self.0.max(rhs.0))
    }
}
impl Zero for MaxRatio {
    fn zero() -> Self {
        // The neutral element of `max` over ratios: every ratio is at least `1.0`.
        Self(1.0)
    }
    #[expect(clippy::float_cmp, reason = "This should be exact.")]
    fn is_zero(&self) -> bool {
        self.0 == 1.0
    }
}

/// Memoised per-cluster costs.
type Costs = FxHashMap<Cluster, f64>;

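/// A clustering-cost oracle: anything that can price a single [`Cluster`]
/// can be clustered hierarchically by the provided methods.
///
/// A minimal end-to-end sketch (doctest ignored, since crate paths are
/// assumed; `array!` is the `ndarray` macro):
///
/// ```ignore
/// let mut kmedian = KMedian::l1(&[array![0.0], array![1.0], array![10.0]])
///     .expect("three valid points");
/// let (price, hierarchy) = kmedian.price_of_hierarchy();
/// assert!(price >= 1.0);
/// assert_eq!(hierarchy.len(), 4); // one clustering per level k = 0..=3
/// ```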
pub trait Cost {
    /// The cost of a single cluster. Implementations may cache results, hence `&mut self`.
    fn cost(&mut self, cluster: Cluster) -> f64;

    /// The total cost of a clustering: the sum of its clusters' costs.
    #[inline]
    fn total_cost(&mut self, clustering: &Clustering) -> f64 {
        clustering.iter().map(|cluster| self.cost(*cluster)).sum()
    }

    /// Approximates the cheapest clustering for every number of clusters by
    /// greedily merging the cheapest pair and then optimising locally.
    #[inline]
    fn approximate_clusterings(&mut self) -> Vec<(f64, Clustering)> {
        let num_points = self.num_points();

        let mut clustering = ClusteringNodeMergeMultiple::new_singletons(num_points);
        let mut solution: Vec<(f64, Clustering)> =
            vec![(0.0, clustering.clone().into_clustering())];

        while clustering.clusters.len() > 1 {
            let mut best_merge = clustering
                .get_all_merges(self)
                .into_iter()
                .min_by(|a, b| a.cost.total_cmp(&b.cost))
                .expect("There should always be a possible merge");
            best_merge.optimise_locally(self);

            solution.push((best_merge.cost, best_merge.clone().into_clustering()));
            clustering = best_merge;
        }

        solution.push((0.0, Clustering::default()));
        solution.reverse();
        solution
    }

    /// The number of points being clustered.
    fn num_points(&self) -> usize;

    /// Computes an optimal clustering for every number of clusters via
    /// branch-and-bound, using the approximate clusterings as upper bounds.
    #[inline]
    fn optimal_clusterings(&mut self) -> Vec<(f64, Clustering)> {
        let num_points = self.num_points();
        let mut results = Vec::with_capacity(num_points + 1);

        for (k, (approximate_cost, approximate_clustering)) in
            self.approximate_clusterings().into_iter().enumerate()
        {
            results.push((|| {
                debug_assert_eq!(
                    approximate_clustering.len(),
                    k,
                    "The approximate clustering on level {k} should have exactly {k} clusters."
                );
                let mut min_cost = approximate_cost;

                let mut to_see: BinaryHeap<ClusteringNodeMergeSingle> = BinaryHeap::new();
                to_see.push(ClusteringNodeMergeSingle::empty());

                while let Some(clustering_node) = to_see.pop() {
                    if clustering_node.clusters.len() == k
                        && clustering_node.next_to_add == num_points
                    {
                        return (
                            clustering_node.cost,
                            clustering_node.clusters.into_iter().collect(),
                        );
                    }
                    if clustering_node.next_to_add < num_points {
                        for new_clustering_node in clustering_node.get_next_nodes(self, k) {
                            if new_clustering_node.cost < min_cost {
                                if new_clustering_node.clusters.len() == k
                                    && new_clustering_node.next_to_add == num_points
                                {
                                    min_cost = new_clustering_node.cost;
                                }
                                to_see.push(new_clustering_node);
                            }
                        }
                    }
                }
                // The search never undercut the approximation, so keep it.
                (approximate_cost, approximate_clustering)
            })());
        }
        results
    }

    /// The price of hierarchy: the smallest achievable maximum ratio between
    /// a hierarchy level's cost and the optimal cost for that number of
    /// clusters, together with a hierarchy attaining it. Falls back to the
    /// greedy hierarchy when the search finds nothing strictly better.
    #[must_use]
    #[inline]
    fn price_of_hierarchy(&mut self) -> (f64, Vec<Clustering>) {
        let num_points = self.num_points();
        let opt_for_fixed_k: Vec<f64> = self
            .optimal_clusterings()
            .into_iter()
            .map(|(cost, _)| cost)
            .collect();

        // The greedy hierarchy bounds the search from above.
        let (price_of_greedy, greedy_hierarchy) = self.price_of_greedy();
        let mut min_hierarchy_price = MaxRatio(price_of_greedy);
        let initial_clustering = ClusteringNodeMergeMultiple::new_singletons(num_points);
        dijkstra(
            &initial_clustering,
            |clustering| {
                // SAFETY: `clusters.len() - 1 <= num_points`, and
                // `opt_for_fixed_k` has `num_points + 1` entries.
                let opt_cost =
                    *unsafe { opt_for_fixed_k.get_unchecked(clustering.clusters.len() - 1) };
                clustering
                    .get_all_merges(self)
                    .into_iter()
                    .filter_map(move |new_clustering| {
                        let ratio = MaxRatio::new(new_clustering.cost, opt_cost);
                        (ratio < min_hierarchy_price).then(|| {
                            if new_clustering.clusters.len() == 1 {
                                min_hierarchy_price = ratio;
                            }
                            (new_clustering, ratio)
                        })
                    })
            },
            |clustering| clustering.clusters.len() == 1,
        )
        .map_or_else(
            || (price_of_greedy, greedy_hierarchy),
            |(path, cost)| {
                (
                    cost.0,
                    iter::once(Clustering::default())
                        .chain(
                            path.into_iter()
                                .rev()
                                .map(ClusteringNodeMergeMultiple::into_clustering),
                        )
                        .collect(),
                )
            },
        )
    }

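    /// Builds a hierarchy by greedily merging the pair of clusters whose
    /// merge increases the total cost the least, without the local-search
    /// step used by [`Cost::approximate_clusterings`].
    ///
    /// Returns one `(cost, clustering)` pair per level, from zero clusters
    /// up to one singleton cluster per point.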
    #[must_use]
    #[inline]
    fn greedy_hierarchy(&mut self) -> Vec<(f64, Clustering)> {
        let num_points = self.num_points();

        let mut clustering = ClusteringNodeMergeMultiple::new_singletons(num_points);
        let mut solution: Vec<(f64, Clustering)> =
            vec![(0.0, clustering.clone().into_clustering())];

        while clustering.clusters.len() > 1 {
            let best_merge = clustering
                .get_all_merges(self)
                .into_iter()
                .min_by(|a, b| a.cost.total_cmp(&b.cost))
                .expect("There should always be a possible merge");
            solution.push((best_merge.cost, best_merge.clone().into_clustering()));
            clustering = best_merge;
        }

        solution.push((0.0, Clustering::default()));
        solution.reverse();
        solution
    }

    /// The price of the greedy hierarchy: the maximum ratio between a greedy
    /// level's cost and the optimal cost for that number of clusters.
    #[must_use]
    #[inline]
    fn price_of_greedy(&mut self) -> (f64, Vec<Clustering>) {
        let mut max_ratio = MaxRatio::zero();
        let greedy_hierarchy = self.greedy_hierarchy();
        let opt_for_fixed_k: Vec<f64> = self
            .optimal_clusterings()
            .into_iter()
            .map(|(cost, _)| cost)
            .collect();

        // Skip the trivial level with zero clusters.
        for (cost, clustering) in greedy_hierarchy.iter().skip(1) {
            let opt_cost = opt_for_fixed_k
                .get(clustering.len())
                .expect("opt_for_fixed_k should have an entry for this number of clusters.");
            let ratio = MaxRatio::new(*cost, *opt_cost);
            max_ratio = max_ratio + ratio;
        }

        let hierarchy = greedy_hierarchy.into_iter().map(|x| x.1).collect();
        (max_ratio.0, hierarchy)
    }
}

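/// Discrete k-median costs: each cluster is priced by its best member
/// acting as the center, under a precomputed distance matrix.
///
/// A small sketch (doctest ignored, since crate paths are assumed):
///
/// ```ignore
/// let mut kmedian = KMedian::l2_squared(&[array![0.0], array![1.0], array![2.0]])
///     .expect("three valid points");
/// // With the middle point as center, the outer points cost 1.0 each.
/// assert_eq!(kmedian.cost(cluster_from_iterator([0, 1, 2])), 2.0);
/// ```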
#[derive(Clone, Debug)]
pub struct KMedian {
    /// The pairwise distance matrix.
    distances: Distances,
    /// Memoised cluster costs.
    costs: Costs,
}
impl KMedian {
    /// k-median under squared Euclidean distances.
    #[inline]
    pub fn l2_squared(points: &[Point]) -> Result<Self, Error> {
        let verified_points = verify_points(points)?;
        Ok(Self {
            distances: distances_from_points_with_element_norm(verified_points, |x| x.powi(2)),
            costs: Costs::default(),
        })
    }

    /// k-median under Euclidean distances.
    #[inline]
    pub fn l2(points: &[Point]) -> Result<Self, Error> {
        let verified_points = verify_points(points)?;
        Ok(Self {
            distances: distances_from_points_with_element_norm(verified_points, |x| x.powi(2))
                .iter()
                .map(|vec| vec.iter().map(|x| x.sqrt()).collect())
                .collect(),
            costs: Costs::default(),
        })
    }

    /// k-median under Manhattan distances.
    #[inline]
    pub fn l1(points: &[Point]) -> Result<Self, Error> {
        let verified_points = verify_points(points)?;
        Ok(Self {
            distances: distances_from_points_with_element_norm(verified_points, f64::abs),
            costs: Costs::default(),
        })
    }

    /// Weighted k-median under squared Euclidean distances.
    #[inline]
    pub fn weighted_l2_squared(weighted_points: &[WeightedPoint]) -> Result<Self, Error> {
        let verified_weighted_points = verify_weighted_points(weighted_points)?;
        Ok(Self {
            distances: distances_from_weighted_points_with_element_norm(
                verified_weighted_points,
                |x| x.powi(2),
            ),
            costs: Costs::default(),
        })
    }

    /// Weighted k-median under Euclidean distances.
    #[inline]
    pub fn weighted_l2(weighted_points: &[WeightedPoint]) -> Result<Self, Error> {
        let verified_weighted_points = verify_weighted_points(weighted_points)?;
        Ok(Self {
            distances: distances_from_weighted_points_with_element_norm(
                verified_weighted_points,
                |x| x.powi(2),
            )
            .iter()
            .map(|vec| vec.iter().map(|x| x.sqrt()).collect())
            .collect(),
            costs: Costs::default(),
        })
    }

    /// Weighted k-median under Manhattan distances.
    #[inline]
    pub fn weighted_l1(weighted_points: &[WeightedPoint]) -> Result<Self, Error> {
        let verified_weighted_points = verify_weighted_points(weighted_points)?;
        Ok(Self {
            distances: distances_from_weighted_points_with_element_norm(
                verified_weighted_points,
                f64::abs,
            ),
            costs: Costs::default(),
        })
    }
}
impl Cost for KMedian {
    #[inline]
    fn num_points(&self) -> usize {
        self.distances.len()
    }
    #[inline]
    fn cost(&mut self, cluster: Cluster) -> f64 {
        *self.costs.entry(cluster).or_insert_with(|| {
            cluster
                .iter()
                .map(|center_candidate_ix| {
                    // SAFETY: Every index in a cluster is below `num_points`.
                    let center_candidate_row =
                        unsafe { self.distances.get_unchecked(center_candidate_ix) };
                    cluster
                        .iter()
                        // SAFETY: Same invariant as above, for the row entries.
                        .map(|ix| *unsafe { center_candidate_row.get_unchecked(ix) })
                        .sum()
                })
                .min_by(f64::total_cmp)
                .unwrap_or(0.0)
        })
    }
}

/// Builds the full distance matrix by applying `distance_function` to every pair of points.
fn distances_from_points_with_distance_function<T>(
    points: &[T],
    distance_function: impl Fn(&T, &T) -> f64,
) -> Distances {
    points
        .iter()
        .map(|p| points.iter().map(|q| distance_function(p, q)).collect())
        .collect()
}

/// Distances as the sum of `elementnorm` over the coordinate-wise differences.
fn distances_from_points_with_element_norm(
    points: &[Point],
    elementnorm: impl Fn(f64) -> f64,
) -> Distances {
    distances_from_points_with_distance_function(points, |p, q| {
        (p - q).map(|x| elementnorm(*x)).sum()
    })
}

/// Like [`distances_from_points_with_element_norm`], with each entry scaled by
/// the weight of the non-center point.
fn distances_from_weighted_points_with_element_norm(
    points: &[WeightedPoint],
    elementnorm: impl Fn(f64) -> f64,
) -> Distances {
    distances_from_points_with_distance_function(points, |p, q| {
        q.0 * (&p.1 - &q.1).map(|x| elementnorm(*x)).sum()
    })
}

/// The ways in which constructing a cost oracle can fail.
#[derive(Debug, PartialEq, Eq)]
#[expect(
    clippy::exhaustive_enums,
    reason = "Extending this enum should be a breaking change."
)]
pub enum Error {
    /// No points were supplied.
    EmptyPoints,
    /// More points were supplied than [`MAX_POINT_COUNT`] allows; carries the count.
    TooManyPoints(usize),
    /// The two points at these indices have different dimensions.
    ShapeMismatch(usize, usize),
    /// The point at this index has a weight that is not finite and positive.
    BadWeight(usize),
}

impl fmt::Display for Error {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let msg = match *self {
            Self::EmptyPoints => "no points supplied".to_owned(),
            Self::TooManyPoints(pointcount) => {
                format!("can cluster at most {MAX_POINT_COUNT} points, but got {pointcount}")
            }
            Self::ShapeMismatch(ix1, ix2) => {
                format!("points {ix1} and {ix2} have different dimensions")
            }
            Self::BadWeight(ix) => {
                format!("point {ix} doesn't have a finite and positive weight")
            }
        };
        f.write_str(&msg)
    }
}

#[expect(
    clippy::absolute_paths,
    reason = "Not worth bringing into scope for one use."
)]
impl core::error::Error for Error {}

/// Checks that there is at least one point, not too many, and that all points share one dimension.
fn verify_points(points: &[Point]) -> Result<&[Point], Error> {
    let point_count = points.len();
    if point_count > MAX_POINT_COUNT {
        return Err(Error::TooManyPoints(point_count));
    }

    let first_point = points.first().ok_or(Error::EmptyPoints)?;
    let first_dim = first_point.raw_dim();

    if let Some(ix) = points.iter().position(|p| p.raw_dim() != first_dim) {
        return Err(Error::ShapeMismatch(0, ix));
    }

    Ok(points)
}

/// Like [`verify_points`], and additionally checks that every weight is finite and positive.
fn verify_weighted_points(weighted_points: &[WeightedPoint]) -> Result<&[WeightedPoint], Error> {
    let point_count = weighted_points.len();
    if point_count > MAX_POINT_COUNT {
        return Err(Error::TooManyPoints(point_count));
    }

    let first_point = weighted_points.first().ok_or(Error::EmptyPoints)?;
    let first_dim = first_point.1.raw_dim();

    if let Some(ix) = weighted_points
        .iter()
        .position(|p| p.1.raw_dim() != first_dim)
    {
        return Err(Error::ShapeMismatch(0, ix));
    }

    if let Some(ix) = weighted_points
        .iter()
        .position(|p| !p.0.is_finite() || p.0 <= 0.0)
    {
        return Err(Error::BadWeight(ix));
    }

    Ok(weighted_points)
}

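/// k-means costs: each cluster is priced by the sum of squared distances of
/// its points to their mean.
///
/// A small sketch (doctest ignored, since crate paths are assumed):
///
/// ```ignore
/// let mut kmeans = KMeans::new(&[array![0.0], array![2.0]]).expect("two valid points");
/// // The mean is 1.0, so each point contributes a squared distance of 1.0.
/// assert_eq!(kmeans.cost(cluster_from_iterator([0, 1])), 2.0);
/// ```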
#[derive(Clone, Debug)]
pub struct KMeans {
    /// The points to cluster.
    points: Vec<Point>,
    /// Memoised cluster costs.
    costs: Costs,
}
impl Cost for KMeans {
    #[inline]
    fn num_points(&self) -> usize {
        self.points.len()
    }
    #[inline]
    fn cost(&mut self, cluster: Cluster) -> f64 {
        *self.costs.entry(cluster).or_insert_with(|| {
            // SAFETY: `verify_points` guarantees that there is at least one point.
            let first_point_dimensions =
                unsafe { self.points.first().unwrap_unchecked() }.raw_dim();
            let mut center = Array1::zeros(first_point_dimensions);

            cluster
                .iter()
                // SAFETY: Every index in a cluster is below `num_points`.
                .for_each(|i| center += unsafe { self.points.get_unchecked(i) });

            center /= f64::from(cluster.len());
            cluster
                .iter()
                .map(|i| {
                    // SAFETY: Every index in a cluster is below `num_points`.
                    let p = unsafe { self.points.get_unchecked(i) };
                    (p - &center).map(|x| x.powi(2)).sum()
                })
                .sum()
        })
    }
    #[inline]
    fn approximate_clusterings(&mut self) -> Vec<(f64, Clustering)> {
        use clustering::kmeans;
        let mut results = Vec::with_capacity(self.num_points() + 1);
        results.push((0.0, Clustering::default()));
        let max_iter = 1000;
        let samples: Vec<Vec<f64>> = self
            .points
            .iter()
            .map(|x| x.into_iter().copied().collect())
            .collect();
        results.extend((1..=self.num_points()).map(|k| {
            let kmeans_clustering = kmeans(k, &samples, max_iter);
            let mut clusters = vec![Cluster::new(); k];
            for (point_ix, cluster_ix) in kmeans_clustering.membership.iter().enumerate() {
                clusters
                    .get_mut(*cluster_ix)
                    .expect("Cluster index out of range")
                    .insert(point_ix);
            }
            let clustering: Clustering = clusters.into_iter().collect();
            (self.total_cost(&clustering), clustering)
        }));
        results
    }
}
impl KMeans {
    /// Creates a k-means cost oracle over the given points.
    #[inline]
    pub fn new(points: &[Point]) -> Result<Self, Error> {
        let verified_points = verify_points(points)?;
        Ok(Self {
            points: verified_points.to_vec(),
            costs: Costs::default(),
        })
    }
}

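/// Weighted k-means costs: each cluster is priced by the weighted sum of
/// squared distances of its points to their weighted mean.
///
/// A small sketch (doctest ignored, since crate paths are assumed):
///
/// ```ignore
/// let points = [(3.0, array![0.0]), (1.0, array![4.0])];
/// let mut kmeans = WeightedKMeans::new(&points).expect("two valid weighted points");
/// // The weighted mean is 1.0; the cost is 3 * 1^2 + 1 * 3^2 = 12.
/// assert_eq!(kmeans.cost(cluster_from_iterator([0, 1])), 12.0);
/// ```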
#[derive(Clone, Debug)]
pub struct WeightedKMeans {
    /// The weighted points to cluster.
    weighted_points: Vec<WeightedPoint>,
    /// Memoised cluster costs.
    costs: Costs,
}
impl Cost for WeightedKMeans {
    #[inline]
    fn num_points(&self) -> usize {
        self.weighted_points.len()
    }
    #[inline]
    fn cost(&mut self, cluster: Cluster) -> f64 {
        *self.costs.entry(cluster).or_insert_with(|| {
            let mut total_weight = 0.0;
            // SAFETY: `verify_weighted_points` guarantees that there is at least one point.
            let first_point_dimensions =
                unsafe { self.weighted_points.first().unwrap_unchecked() }
                    .1
                    .raw_dim();
            let mut center: Array1<f64> = Array1::zeros(first_point_dimensions);

            cluster.iter().for_each(|i| {
                // SAFETY: Every index in a cluster is below `num_points`.
                let weighted_point = unsafe { self.weighted_points.get_unchecked(i) };
                total_weight += weighted_point.0;
                center += &(&weighted_point.1 * weighted_point.0);
            });

            center /= total_weight;

            cluster
                .iter()
                .map(|i| {
                    // SAFETY: Every index in a cluster is below `num_points`.
                    let weighted_point = unsafe { self.weighted_points.get_unchecked(i) };
                    weighted_point.0 * (&weighted_point.1 - &center).map(|x| x.powi(2)).sum()
                })
                .sum()
        })
    }
}
impl WeightedKMeans {
    /// Creates a weighted k-means cost oracle over the given weighted points.
    #[inline]
    pub fn new(weighted_points: &[WeightedPoint]) -> Result<Self, Error> {
        let verified_weighted_points = verify_weighted_points(weighted_points)?;
        Ok(Self {
            weighted_points: verified_weighted_points.to_vec(),
            costs: Costs::default(),
        })
    }
}

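/// Builds a [`Cluster`] from an iterator of distinct point indices.
///
/// In debug builds, supplying the same index twice trips a `debug_assert`.
/// A small sketch (doctest ignored, since the crate path is assumed):
///
/// ```ignore
/// let cluster = cluster_from_iterator([0, 2]);
/// assert!(cluster.contains(0) && !cluster.contains(1) && cluster.contains(2));
/// assert_eq!(cluster.len(), 2);
/// ```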
#[inline]
pub fn cluster_from_iterator<I: IntoIterator<Item = usize>>(it: I) -> Cluster {
    let mut cluster = Cluster::new();
    for i in it {
        cluster.insert(i);
    }
    cluster
}

#[cfg(test)]
mod tests {
    use super::*;
    use core::f64::consts::SQRT_2;
    use itertools::Itertools as _;
    use ndarray::array;
    use smallvec::smallvec;
    use std::panic::catch_unwind;

    #[test]
    #[should_panic(
        expected = "Throughout the entire implementation, we should never add the same point twice."
    )]
    fn cluster_double_insert() {
        let mut cluster = Cluster::singleton(7);
        cluster.insert(7);
    }

    #[test]
    #[should_panic(
        expected = "Throughout the entire implementation, we should never be merging intersecting clusters."
    )]
    fn cluster_intersecting_merge() {
        let mut cluster7 = Cluster::singleton(7);
        let mut cluster9 = Cluster::singleton(7);
        cluster7.insert(8);
        cluster9.insert(8);
        cluster7.union_with(cluster9);
    }

    #[test]
    fn cluster() {
        for i in 0..8 {
            let cluster = Cluster::singleton(i);
            assert!(!cluster.is_empty());
            assert_eq!(cluster.len(), 1);
            assert_eq!(cluster.iter().collect_vec(), vec![i]);
            for j in 0..8 {
                assert_eq!(cluster.contains(j), j == i);
                let cluster2 = {
                    let mut cluster2 = cluster;
                    if i != j {
                        cluster2.insert(j);
                    }
                    assert!(!cluster2.is_empty());
                    cluster2
                };
                assert!(!cluster2.is_empty());
                assert_eq!(cluster2.len(), if i == j { 1 } else { 2 });
                assert_eq!(
                    cluster2.iter().collect_vec(),
                    match i.cmp(&j) {
                        cmp::Ordering::Less => vec![i, j],
                        cmp::Ordering::Equal => vec![i],
                        cmp::Ordering::Greater => vec![j, i],
                    }
                );
            }
        }
        let mut cluster_div_3 = Cluster::new();
        let mut cluster_div_5 = Cluster::new();
        assert!(cluster_div_3.is_empty());
        assert!(cluster_div_5.is_empty());
        for i in 1..=14 {
            if i % 3 == 0 {
                cluster_div_3.insert(i);
                assert!(!cluster_div_3.is_empty());
            }
            if i % 5 == 0 {
                cluster_div_5.insert(i);
                assert!(!cluster_div_5.is_empty());
            }
        }
        assert_eq!(cluster_div_3.iter().collect_vec(), vec![3, 6, 9, 12]);
        assert_eq!(cluster_div_5.iter().collect_vec(), vec![5, 10]);
        let merged = {
            let mut merged = cluster_div_3;
            merged.union_with(cluster_div_5);
            merged
        };
        assert_eq!(merged.iter().collect_vec(), vec![3, 5, 6, 9, 10, 12]);

        assert_eq!(merged.to_string(), "...#.##..##.#...................");
    }

    #[expect(clippy::float_cmp, reason = "This should be exact.")]
    #[expect(
        clippy::assertions_on_result_states,
        reason = "We'd like to catch the errors."
    )]
    #[test]
    fn max_ratio() {
        assert_eq!(MaxRatio::new(3.0, 1.5).0, 2.0);
        assert_eq!(MaxRatio::new(SQRT_2, SQRT_2).0, 1.0);
        assert_eq!(MaxRatio::new(SQRT_2, 0.0).0, f64::INFINITY);
        assert_eq!(MaxRatio::new(SQRT_2, -0.0).0, f64::INFINITY);
        assert_eq!(MaxRatio::new(0.0, 0.0).0, 1.0);
        assert_eq!(MaxRatio::new(-0.0, 0.0).0, 1.0);
        assert_eq!(MaxRatio::new(0.0, -0.0).0, 1.0);
        assert_eq!(MaxRatio::new(-0.0, -0.0).0, 1.0);
        assert!(catch_unwind(|| MaxRatio::new(1.0 - 1e-3, 1.0)).is_err());
        assert!(catch_unwind(|| MaxRatio::new(1.0 - 1e-12, 1.0)).is_ok());
        assert!(catch_unwind(|| MaxRatio::new(0.0 - 1e-12, 0.0)).is_err());
        assert!(catch_unwind(|| MaxRatio::new(f64::INFINITY, 1.0)).is_err());
        assert!(catch_unwind(|| MaxRatio::new(f64::NAN, 1.0)).is_err());
        assert!(catch_unwind(|| MaxRatio::new(f64::NEG_INFINITY, 1.0)).is_err());
        assert!(catch_unwind(|| MaxRatio::new(1.0, f64::INFINITY)).is_err());
        assert!(catch_unwind(|| MaxRatio::new(1.0, f64::NAN)).is_err());
        assert!(catch_unwind(|| MaxRatio::new(1.0, f64::NEG_INFINITY)).is_err());
        assert!(catch_unwind(|| MaxRatio::new(1.0, 0.0)).is_ok());
        assert!(catch_unwind(|| MaxRatio::new(1.0, -1e-12)).is_err());
    }

    macro_rules! clusterings {
        ( $( [ $( [ $( $num:expr ),* ] ),* ] ),* $(,)? ) => {
            [
                $(
                    vec![
                        $(
                            cluster_from_iterator([$( $num ),*]),
                        )*
                    ],
                )*
            ]
        }
    }

    #[test]
    fn node_merge_multiple() {
        fn clusters_are_correct(
            expected_clusterings: &[Vec<Cluster>],
            nodes: &[ClusteringNodeMergeMultiple],
        ) {
            let actual = nodes.iter().map(|x| x.clusters.to_vec()).collect_vec();
            assert_eq!(
                expected_clusterings, actual,
                "Clustering should match expected clustering. Maybe the order of returned Clusters has changed?"
            );
        }
        let mut kmedian =
            KMedian::l2_squared(&[array![0.0], array![1.0], array![2.0], array![3.0]])
                .expect("Creating kmedian should not fail.");
        let mut update_nodes = |nodes: &mut Vec<ClusteringNodeMergeMultiple>| {
            *nodes = nodes
                .iter()
                .flat_map(|n| n.get_all_merges(&mut kmedian))
                .collect();
        };
        let mut nodes = vec![ClusteringNodeMergeMultiple::new_singletons(4)];
        let expected_init_clusters = smallvec![
            Cluster::singleton(0),
            Cluster::singleton(1),
            Cluster::singleton(2),
            Cluster::singleton(3)
        ];
        assert_eq!(
            nodes,
            vec![ClusteringNodeMergeMultiple {
                clusters: expected_init_clusters,
                cost: f64::NAN,
            }],
            "Testing nodes for equality should only depend on clusters, not on their cost."
        );
        clusters_are_correct(&clusterings![[[0], [1], [2], [3]]], &nodes);

        update_nodes(&mut nodes);
        clusters_are_correct(
            &clusterings![
                [[0, 1], [2], [3]],
                [[1], [0, 2], [3]],
                [[1], [2], [0, 3]],
                [[0], [1, 2], [3]],
                [[0], [2], [1, 3]],
                [[0], [1], [2, 3]],
            ],
            &nodes,
        );

        update_nodes(&mut nodes);
        clusters_are_correct(
            &clusterings![
                [[0, 1, 2], [3]],
                [[2], [0, 1, 3]],
                [[0, 1], [2, 3]],
                [[1, 0, 2], [3]],
                [[0, 2], [1, 3]],
                [[1], [0, 2, 3]],
                [[1, 2], [0, 3]],
                [[2], [1, 0, 3]],
                [[1], [2, 0, 3]],
                [[0, 1, 2], [3]],
                [[1, 2], [0, 3]],
                [[0], [1, 2, 3]],
                [[0, 2], [1, 3]],
                [[2], [0, 1, 3]],
                [[0], [2, 1, 3]],
                [[0, 1], [2, 3]],
                [[1], [0, 2, 3]],
                [[0], [1, 2, 3]],
            ],
            &nodes,
        );

        update_nodes(&mut nodes);
        clusters_are_correct(&vec![vec![Cluster(15)]; 18], &nodes);
    }

    #[test]
    #[should_panic(expected = "The clusters should always be sorted, to prevent duplicates.")]
    fn unsorted_node_merge_multiple() {
        let unsorted = ClusteringNodeMergeMultiple {
            clusters: smallvec![Cluster(1), Cluster(0)],
            cost: 0.0,
        };
        let mut small_kmedian =
            KMedian::l1(&[array![0.0], array![1.0]]).expect("Creating kmedian should not fail.");
        let _: Vec<_> = unsorted
            .get_all_merges(&mut small_kmedian)
            .into_iter()
            .collect_vec();
    }

    #[test]
    fn node_merge_single() {
        fn clusters_are_correct(
            expected_clusterings: &[Vec<Cluster>],
            nodes: &[ClusteringNodeMergeSingle],
        ) {
            let actual = nodes.iter().map(|x| x.clusters.to_vec()).collect_vec();
            assert_eq!(
                expected_clusterings, actual,
                "Clustering should match expected clustering. Maybe the order of returned Clusters has changed?"
            );
        }
        let mut kmedian =
            KMedian::l2_squared(&[array![0.0], array![1.0], array![2.0], array![3.0]])
                .expect("Creating kmedian should not fail.");
        let mut update_nodes = |nodes: &mut Vec<ClusteringNodeMergeSingle>| {
            *nodes = nodes
                .iter()
                .flat_map(|n| n.get_next_nodes(&mut kmedian, 3).collect_vec())
                .collect();
        };
        let mut nodes = vec![ClusteringNodeMergeSingle::empty()];
        clusters_are_correct(&clusterings![[]], &nodes);

        update_nodes(&mut nodes);
        clusters_are_correct(&clusterings![[[0]]], &nodes);

        update_nodes(&mut nodes);
        clusters_are_correct(&clusterings![[[0, 1]], [[0], [1]]], &nodes);

        update_nodes(&mut nodes);
        clusters_are_correct(
            &clusterings![
                [[0, 1, 2]],
                [[0, 1], [2]],
                [[0, 2], [1]],
                [[0], [1, 2]],
                [[0], [1], [2]],
            ],
            &nodes,
        );

        update_nodes(&mut nodes);
        clusters_are_correct(
            &clusterings![
                [[0, 1, 2, 3]],
                [[0, 1, 2], [3]],
                [[0, 1, 3], [2]],
                [[0, 1], [2, 3]],
                [[0, 1], [2], [3]],
                [[0, 2, 3], [1]],
                [[0, 2], [1, 3]],
                [[0, 2], [1], [3]],
                [[0, 3], [1, 2]],
                [[0], [1, 2, 3]],
                [[0], [1, 2], [3]],
                [[0, 3], [1], [2]],
                [[0], [1, 3], [2]],
                [[0], [1], [2, 3]],
            ],
            &nodes,
        );
    }

    // Regression test: `optimise_locally` must terminate on this input.
    #[test]
    fn infinite_loop_optimise_locally() {
        let (weight_a, point_a) = (0.588_906_661, array![-0.487_778_761_130_834]);
        let (weight_b, point_b) = (0.434_371_596, array![-0.438_191_407_837_575]);
        let points = [
            (weight_a, -point_a.clone()),
            (weight_b, -point_b.clone()),
            (1.0, array![0.0]),
            (weight_a, point_a),
            (weight_b, point_b),
        ];
        let mut kmedian = KMedian::weighted_l1(&points).expect("Creating kmedian should not fail.");

        let mut clustering = ClusteringNodeMergeMultiple {
            clusters: SmallVec::from_iter([
                cluster_from_iterator([0, 1, 2]),
                cluster_from_iterator([3, 4]),
            ]),
            cost: 0.488_933_068_284_744_25,
        };

        clustering.optimise_locally(&mut kmedian);
    }

    // Regression test: `optimise_locally` must terminate on this input, too.
    #[test]
    fn infinite_loop_optimise_locally_1() {
        let points = vec![
            (1.870_423_609_633_216e24, array![1000.0, -1000.0, 1000.0]),
            (3.817_589_201_683_946e23, array![1000.0, 1000.0, -1000.0]),
            (2.074_998_884_450_784_5e21, array![1000.0, 1000.0, 1000.0]),
            (
                1.0,
                array![
                    -400.240_609_956_200_4,
                    616.506_453_035_030_1,
                    -79.475_319_067_602_64
                ],
            ),
            (1.0, array![-1000.0, 415.010_128_673_398_5, 1000.0]),
        ];
        let mut kmedian = KMedian::weighted_l1(&points).expect("Creating kmedian should not fail.");

        let mut clustering = ClusteringNodeMergeMultiple {
            clusters: SmallVec::from_iter([
                cluster_from_iterator([0, 2, 4]),
                cluster_from_iterator([1, 3]),
            ]),
            cost: 4.149_997_768_901_569e24,
        };

        clustering.optimise_locally(&mut kmedian);
    }
}