exact_clustering/
lib.rs

1/*!
2Find optimal [clusterings](https://en.wikipedia.org/wiki/Cluster_analysis) and
3[hierarchical clusterings](https://en.wikipedia.org/wiki/Hierarchical_clustering) on up to 32 points.
4If you only need approximate clusterings, there are excellent
5[other crates](https://www.arewelearningyet.com/clustering/) that run significantly faster.
6
7For weighted and unweighted kmedian-clustering (where centers are chosen from the points themselves), see [`KMedian`].
8For kmeans-clustering (where centers are chosen from the ambient space), see [`KMeans`] and [`WeightedKMeans`].
9If you'd like to solve other clustering-problems, implement the [`Cost`]-trait (and feel free to submit
10a pull-request!), or submit an issue on GitHub.
11
12Among others, the [`Cost`]-trait allows you to calculate:
13- Optimal clusterings using [`Cost::optimal_clusterings`]
14- Optimal hierarchical clusterings using [`Cost::price_of_hierarchy`]
15- Greedy hierarchical clusterings using [`Cost::price_of_greedy`]
16
17# Example
18
19```
20use ndarray::prelude::*;
21use std::collections::BTreeSet;
22use exact_clustering::{Cost as _, KMedian, KMeans};
23
// Set of 2d-points looking like ⠥
25let points = vec![
26    array![0.0, 0.0],
27    array![1.0, 0.0],
28    array![0.0, 2.0],
29];
30// Instances are mutable to allow caching cluster-costs
31let mut kmedian = KMedian::l1(&points).unwrap();
32// All optimal clusterings are calculated at once to permit some speedups.
33let (cost, clusters) = &kmedian.optimal_clusterings()[2];
34
35assert_eq!(*cost, 1.0);
36// Each cluster in the returned clustering is a set of point-indices:
37assert_eq!(
38    BTreeSet::from([BTreeSet::from([0, 1]), BTreeSet::from([2])]),
39    clusters
40        .iter()
41        .cloned()
42        .map(BTreeSet::from_iter)
43        .collect(),
44);
45
46let price_of_hierarchy = kmedian.price_of_hierarchy().0;
47assert_eq!(price_of_hierarchy, 1.0);
48
49let price_of_greedy = kmedian.price_of_greedy().0;
50assert_eq!(price_of_greedy, 1.0);
51```
52*/
53
54// TODO: Derive serde, see:
55// https://rust-lang.github.io/api-guidelines/interoperability.html#c-serde
56
57#![expect(
58    clippy::missing_errors_doc,
59    reason = "The Error-Enum is sparse and documented."
60)]
61
62use core::hash::{self, Hash};
63use core::{cmp, ops};
64use core::{f64, fmt, iter};
65use ndarray::Array1;
66use pathfinding::{num_traits::Zero, prelude::dijkstra};
67use rustc_hash::{FxHashMap, FxHashSet};
68use smallvec::SmallVec;
69use std::collections::BinaryHeap;
70
/// The storage-medium for representing clusters compactly via a bitset.
///
/// Bit `i` is set iff the point with index `i` belongs to the cluster, so the bit-width of this
/// type bounds how many points can be clustered (see [`MAX_POINT_COUNT`]).
/// We'll hopefully never have to calculate clusters on more than 32 points, so 32 bits is enough for now.
///
/// If you are not on a 32-bit-or-above-platform, this will cause issues with indices. On 16-bit
/// targets this alias is not defined at all, so the crate does not compile there.
#[cfg(not(target_pointer_width = "16"))]
pub type Storage = u32;

/// The maximum number of points we can cluster before we overflow [`Storage`].
///
/// This is the bit-width of [`Storage`]: one bit per point-index.
#[expect(
    clippy::as_conversions,
    reason = "`Storage::BITS` will always fit into a `usize`."
)]
pub const MAX_POINT_COUNT: usize = Storage::BITS as usize;
85
86/// A compact representation of a cluster of points, using a bitset.
87#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
88pub struct Cluster(Storage);
89impl Cluster {
90    /// Create a new empty cluster containing no points.
91    const fn new() -> Self {
92        Self(0)
93    }
94
95    /// Create a new cluster containing a single point.
96    const fn singleton(point_ix: usize) -> Self {
97        Self(1 << point_ix)
98    }
99
100    /// Insert a point into the cluster.
101    fn insert(&mut self, point_ix: usize) {
102        let point = 1 << point_ix;
103        debug_assert!(
104            (point & self.0) == 0,
105            "Throughout the entire implementation, we should never to add the same point twice."
106        );
107        self.0 |= point;
108    }
109
110    /// Remove a point from the cluster.
111    fn remove(&mut self, point_ix: usize) {
112        let point = 1 << point_ix;
113        debug_assert!(
114            (point & self.0) != 0,
115            "Throughout the entire implementation, we should never remove a non-existing point."
116        );
117        self.0 &= !point;
118    }
119
120    /// Check whether the set contains a point-index.
121    #[must_use]
122    #[inline]
123    pub const fn contains(self, point_ix: usize) -> bool {
124        (self.0 & (1 << point_ix)) != 0
125    }
126
127    /// Count the number of points in the cluster.
128    #[must_use]
129    #[inline]
130    pub const fn len(self) -> Storage {
131        self.0.count_ones()
132    }
133
134    /// Check whether the cluster doesn't contain any points.
135    #[must_use]
136    #[inline]
137    pub const fn is_empty(self) -> bool {
138        self.0 == 0
139    }
140
141    /// Construct an iterator over the point-indices in the cluster.
142    #[inline]
143    #[must_use]
144    pub const fn iter(self) -> ClusterIter {
145        ClusterIter(self.0)
146    }
147
148    /// Merge this cluster with another one.
149    fn union_with(&mut self, other: Self) {
150        debug_assert!(
151            self.0 & other.0 == 0,
152            "Troughout the entire implementation, we should never be merging intersecting clusters."
153        );
154        self.0 |= other.0;
155    }
156}
157
158/// An iterator over the point-indices in a cluster.
159#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
160pub struct ClusterIter(Storage);
161impl Iterator for ClusterIter {
162    type Item = usize;
163
164    #[inline]
165    fn next(&mut self) -> Option<Self::Item> {
166        if self.0 == 0 {
167            None
168        } else {
169            #[expect(
170                clippy::as_conversions,
171                reason = "I assume `usize` is at least `Storage`."
172            )]
173            let ix = self.0.trailing_zeros() as usize;
174            self.0 &= self.0 - 1;
175            Some(ix)
176        }
177    }
178
179    #[inline]
180    fn size_hint(&self) -> (usize, Option<usize>) {
181        #[expect(
182            clippy::as_conversions,
183            reason = "I assume `usize` is at least `Storage`."
184        )]
185        let count = self.0.count_ones() as usize;
186        (count, Some(count))
187    }
188}
189
190impl IntoIterator for Cluster {
191    type Item = usize;
192    type IntoIter = ClusterIter;
193
194    #[inline]
195    fn into_iter(self) -> Self::IntoIter {
196        ClusterIter(self.0)
197    }
198}
199
200impl fmt::Display for Cluster {
201    #[inline]
202    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203        #[expect(
204            clippy::as_conversions,
205            reason = "I assume `usize` is at least `Storage`."
206        )]
207        let mut result = String::with_capacity(Storage::BITS as usize);
208        let mut bits = self.0;
209        for _ in 0..Storage::BITS {
210            if (bits & 1) == 1 {
211                result.push('#');
212            } else {
213                result.push('.');
214            }
215            bits >>= 1;
216        }
217        write!(f, "{result}")
218    }
219}
220
221/// A partition of a set of points into disjoint clusters.
222pub type Clustering = FxHashSet<Cluster>;
223
224/// A distance-matrix between pairs of points.
225///
226/// The distance between the points with indices `i` and `j` is `Distances[i][j]`.
227type Distances = Vec<Vec<f64>>;
228
229/// A single point.
230pub type Point = Array1<f64>;
231/// A weighted point.
232///
233/// The distance between two weighted points `d((w0, p0), (w1, p1))` is `w1 * d(p0, p1)`.
234pub type WeightedPoint = (f64, Array1<f64>);
235
236#[derive(Clone, Debug)]
237/// A helper-struct for efficiently merging clusters, used for finding optimal hierarchies.
238struct ClusteringNodeMergeMultiple {
239    /// The current clusters.
240    ///
241    /// We use a [`SmallVec`] because this will allocate frequently. The smallvec
242    /// must remain sorted so that two Nodes with the same clusters are recognised as
243    /// equal.
244    ///
245    /// TODO: Once [`generic_const_exprs`](https://github.com/rust-lang/rust/issues/76560) is stable,
246    /// try using an array with dynamic dispatch instead. Consider hiding this behind a feature-gate
247    /// to reduce compile-times.
248    clusters: SmallVec<[Cluster; 6]>,
249    /// The total cost of the clustering.
250    ///
251    /// We keep track of this to efficiently recalculate costs after merging.
252    ///
253    /// TODO: Try not keeping track of it, instead having [`ClusteringNodeMergeMultiple::get_all_merges`]
254    /// return a delta, and using that delta in Dijkstra.
255    cost: f64,
256}
257// Only consider `clusters` in equality-checks, costs should be near-equal anyway.
258impl PartialEq for ClusteringNodeMergeMultiple {
259    fn eq(&self, other: &Self) -> bool {
260        self.clusters == other.clusters
261    }
262}
263impl Eq for ClusteringNodeMergeMultiple {}
264impl Hash for ClusteringNodeMergeMultiple {
265    fn hash<H: hash::Hasher>(&self, state: &mut H) {
266        self.clusters.hash(state);
267    }
268}
269impl ClusteringNodeMergeMultiple {
270    /// Get all possible pairs of merges of the current clusters.
271    ///
272    /// If performance is at a premium, you can avoid the vec-allocation by inlining the loops at the
273    /// callsite, but that increases code-duplication (this method is required in [`Cost::price_of_hierarchy`]
274    /// and [`Cost::price_of_greedy`]) and prevents unit-testing. I wasn't able to return
275    /// a lifetimed iterator because it'd move costs into a closure. The performance-gain was about 5% on benchmarks.
276    #[must_use]
277    #[inline]
278    fn get_all_merges<C: Cost + ?Sized>(&self, data: &mut C) -> Vec<Self> {
279        debug_assert!(
280            self.clusters.is_sorted(),
281            "The clusters should always be sorted, to prevent duplicates."
282        );
283
284        #[expect(
285            clippy::integer_division,
286            reason = "At least one of the factors is always even."
287        )]
288        let mut nodes = Vec::with_capacity(self.clusters.len() * (self.clusters.len() - 1) / 2);
289        for i in 0..(self.clusters.len() - 1) {
290            // Split off cluster_i.
291            let (cluster_i, clusters_minus_i) = {
292                let mut clusters_minus_i = self.clusters.clone();
293                // This must *not* be a swap_remove, to preserve order.
294                let cluster_i = clusters_minus_i.remove(i);
295                (cluster_i, clusters_minus_i)
296            };
297            let cost_minus_i = self.cost - data.cost(cluster_i);
298            // Index `i` is gone now, so the lower bound is still `i`
299            nodes.extend((i..clusters_minus_i.len()).map(|j| {
300                let mut new_clusters = clusters_minus_i.clone();
301                // SAFETY:
302                // `j` is less than `clusters_minus_i.len()`, and `new_clusters` is a clone of
303                // `clusters_minus_i`, so it's a valid index.
304                let cluster_j = unsafe { new_clusters.get_unchecked_mut(j) };
305                let mut new_cost = cost_minus_i - data.cost(*cluster_j);
306                cluster_j.union_with(cluster_i);
307                new_cost += data.cost(*cluster_j);
308
309                debug_assert!(new_clusters.len() == self.clusters.len() - 1, "We should have merged two clusters, which should have reduced the number of clusters by exactly one.");
310                debug_assert!(new_clusters.is_sorted(), "The clusters should always be sorted, to prevent duplicates.");
311                debug_assert!({
312                    (0..data.num_points()).all(|point_ix| new_clusters.iter().filter(|cluster| cluster.contains(point_ix)).count()==1)
313                },"The clusters should always cover every point exactly once.");
314                Self {
315                    clusters: new_clusters,
316                    cost: new_cost,
317                }
318            }));
319        }
320        nodes
321    }
322
323    /// Change `self` to be locally optimal.
324    ///
325    /// For every point, try re-assigning that point to a different cluster and check if it decreases the
326    /// cost, repeating this until re-assigning points no longer decreases the cost.
327    fn optimise_locally<C: Cost + ?Sized>(&mut self, data: &mut C) {
328        // Due to floating-point inaccuracies, we could enter an infinite loop if
329        // we accept an "improvement" as "improves cost by some positive amount", so
330        // we additionally keep track of already-visited states.
331        let mut already_visited: FxHashSet<(Cluster, usize, usize)> = FxHashSet::default();
332        let mut found_improvement = || {
333            #[expect(
334                clippy::indexing_slicing,
335                reason = "These are safe, we just use indices to avoid borrow-issues."
336            )]
337            for source_cluster_ix in 0..self.clusters.len() {
338                let source_cluster = self.clusters[source_cluster_ix];
339                for point_ix in source_cluster {
340                    let mut updated_source_cluster = source_cluster;
341                    updated_source_cluster.remove(point_ix);
342                    let source_costdelta =
343                        data.cost(updated_source_cluster) - data.cost(source_cluster);
344
345                    for target_cluster_ix in
346                        (0..self.clusters.len()).filter(|ix| *ix != source_cluster_ix)
347                    {
348                        if !already_visited.insert((
349                            source_cluster,
350                            source_cluster_ix,
351                            target_cluster_ix,
352                        )) {
353                            continue;
354                        }
355                        let target_cluster = self.clusters[target_cluster_ix];
356
357                        let mut updated_target_cluster = target_cluster;
358                        updated_target_cluster.insert(point_ix);
359                        let costdelta = source_costdelta + data.cost(updated_target_cluster)
360                            - data.cost(target_cluster);
361                        if costdelta < 0.0 {
362                            // Keep the clusters in order:
363                            if updated_source_cluster.cmp(&updated_target_cluster)
364                                == source_cluster_ix.cmp(&target_cluster_ix)
365                            {
366                                self.clusters[source_cluster_ix] = updated_source_cluster;
367                                self.clusters[target_cluster_ix] = updated_target_cluster;
368                            } else {
369                                self.clusters[source_cluster_ix] = updated_target_cluster;
370                                self.clusters[target_cluster_ix] = updated_source_cluster;
371                            }
372                            self.cost += costdelta;
373                            return true;
374                        }
375                    }
376                }
377            }
378            false
379        };
380
381        while found_improvement() {}
382
383        self.clusters.sort();
384
385        debug_assert!(
386            {
387                (0..data.num_points()).all(|point_ix| {
388                    self.clusters
389                        .iter()
390                        .filter(|cluster| cluster.contains(point_ix))
391                        .count()
392                        == 1
393                })
394            },
395            "The clusters should always cover every point exactly once."
396        );
397    }
398
399    /// Create a new node with `num_points` singleton-clusters.
400    #[inline]
401    fn new_singletons(num_points: usize) -> Self {
402        let mut clusters = SmallVec::default();
403        for i in 0..num_points {
404            clusters.push(Cluster::singleton(i));
405        }
406        debug_assert!(
407            clusters.is_sorted(),
408            "The clusters should always be sorted, to prevent duplicates."
409        );
410        Self {
411            clusters,
412            cost: 0.0,
413        }
414    }
415
416    /// Convert `self` to a [`Clustering`].
417    #[inline]
418    fn into_clustering(self) -> Clustering {
419        self.clusters.into_iter().collect()
420    }
421}
422
423#[derive(Clone, Debug)]
424/// A helper-struct for efficiently merging clusters, used for finding optimal clusterings.
425///
426/// Unlike [`ClusteringNodeMergeMultiple`], this only allows for
427/// merging of a cluster with a singleton-cluster.
428struct ClusteringNodeMergeSingle {
429    /// The nonempty clusters already created.
430    ///
431    /// This may include singletons. It will always contain at most `k` clusters.
432    ///
433    /// TODO: Once [`generic_const_exprs`](https://github.com/rust-lang/rust/issues/76560) is stable,
434    /// try using an array with dynamic dispatch instead. Consider hiding this behind a feature-gate
435    /// to reduce compile-times.
436    clusters: SmallVec<[Cluster; 6]>,
437    /// The cost of the clustering represented by [`Self::clusters`].
438    ///
439    /// This should always be nearly-equal to [`Cost::total_cost`] of [`Self::clusters`].
440    cost: f64,
441    /// The next point to add.
442    ///
443    /// This point, and implicitly every point after it, are singleton clusters.
444    /// (TODO: This assumption means singleton clusters must have cost 0.)
445    /// It's more efficient to track them this way, because
446    /// - It means we can store fewer clusters in [`Self::clusters`].
447    /// - It ensures we enumerate the clusterings less redundantly
448    ///
449    /// TODO: Try not storing this, but instead doing best-first-search level-wise?
450    /// We could also use the significantly smaller u8 here.
451    next_to_add: usize,
452}
453impl PartialEq for ClusteringNodeMergeSingle {
454    fn eq(&self, other: &Self) -> bool {
455        self.clusters == other.clusters
456    }
457}
458impl Eq for ClusteringNodeMergeSingle {}
459impl Hash for ClusteringNodeMergeSingle {
460    fn hash<H: hash::Hasher>(&self, state: &mut H) {
461        self.clusters.hash(state);
462    }
463}
464impl Ord for ClusteringNodeMergeSingle {
465    // Order them by highest cost first, because the BinaryHeap is a max-heap.
466    fn cmp(&self, other: &Self) -> cmp::Ordering {
467        other
468            .cost
469            .total_cmp(&self.cost)
470            .then_with(|| self.clusters.cmp(&other.clusters))
471    }
472}
473impl PartialOrd for ClusteringNodeMergeSingle {
474    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
475        Some(self.cmp(other))
476    }
477}
478impl ClusteringNodeMergeSingle {
479    /// Returns all possible next nodes from the current node.
480    ///
481    /// This merges [`Self::next_to_add`] into existing clusters, and creates a new singleton-cluster
482    /// for it if [`Self::clusters`] has fewer than `k` clusters.
483    #[inline]
484    fn get_next_nodes<'a, C: Cost + ?Sized>(
485        &'a self,
486        data: &'a mut C,
487        k: usize,
488    ) -> impl Iterator<Item = Self> + use<'a, C> {
489        (0..self.clusters.len())
490            .map(|cluster_ix| {
491                let mut new_clustering_node = self.clone();
492                // SAFETY:
493                // `cluster_ix` is less than `self.clusters.len()`, and `new_clustering_node.clusters` is
494                // a clone of `self.clusters`, so `cluster_ix` is in bounds.
495                let cluster_to_edit =
496                    unsafe { new_clustering_node.clusters.get_unchecked_mut(cluster_ix) };
497                new_clustering_node.cost -= data.cost(*cluster_to_edit);
498                cluster_to_edit.insert(new_clustering_node.next_to_add);
499                new_clustering_node.cost += data.cost(*cluster_to_edit);
500                new_clustering_node.next_to_add += 1;
501                new_clustering_node
502            })
503            .chain((self.clusters.len() < k).then(|| {
504                let mut clustering_node = self.clone();
505                clustering_node
506                    .clusters
507                    .push(Cluster::singleton(clustering_node.next_to_add));
508                clustering_node.next_to_add += 1;
509                clustering_node
510            }))
511    }
512
513    /// Create a new empty [`ClusteringNodeMergeSingle`].
514    ///
515    /// In this, every node implicitly lives in a singleton-cluster.
516    fn empty() -> Self {
517        Self {
518            clusters: SmallVec::default(),
519            cost: 0.0,
520            next_to_add: 0,
521        }
522    }
523}
524
525#[derive(Debug, PartialEq, Clone, Copy)]
526/// Tracking the hierarchy-cost in Dijkstra.
527///
528/// The cost of a hierarchy-level is the cost of its clustering over the cost of the best-possible-clustering on
529/// that level.
530/// The cost of a hierarchy is the maximum of the costs of all its levels, so addition between two [`MaxRatio`]s
531/// is the maximum of the two.
532struct MaxRatio(f64);
533impl MaxRatio {
534    /// Create a new [`MaxRatio`] from a clustering-cost and an optimal-cost.
535    ///
536    /// The clustering-cost will usually be the total cost of a level in the hierarchy,
537    /// whereas the optimal-cost will be the optimal-cost of the same level.
538    #[inline]
539    fn new(clustering_cost: f64, opt_cost: f64) -> Self {
540        debug_assert!(
541            clustering_cost.is_finite(),
542            "hierarchy_cost {clustering_cost} should be finite."
543        );
544        debug_assert!(
545            opt_cost.is_finite(),
546            "opt_cost {opt_cost} should be finite."
547        );
548        debug_assert!(
549            opt_cost >= 0.0,
550            "opt_cost {opt_cost} should be non-negative."
551        );
552        debug_assert!(
553            clustering_cost >= 0.0,
554            "hierarchy_cost {clustering_cost} should be non-negative"
555        );
556        debug_assert!(
557            clustering_cost >= opt_cost - 1e-9,
558            "hierarchy_cost {clustering_cost} should be at least opt_cost {opt_cost}"
559        );
560        Self(if opt_cost.is_zero() {
561            if clustering_cost.is_zero() {
562                1.0
563            } else {
564                f64::INFINITY
565            }
566        } else {
567            clustering_cost / opt_cost
568        })
569    }
570}
571impl Eq for MaxRatio {} // The max-ratios should always be finite.
572impl Ord for MaxRatio {
573    fn cmp(&self, other: &Self) -> cmp::Ordering {
574        self.0.total_cmp(&other.0)
575    }
576}
577impl PartialOrd for MaxRatio {
578    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
579        Some(self.cmp(other))
580    }
581}
582impl ops::Add for MaxRatio {
583    type Output = Self;
584    fn add(self, rhs: Self) -> Self {
585        Self(self.0.max(rhs.0))
586    }
587}
588impl Zero for MaxRatio {
589    fn zero() -> Self {
590        Self(1.0)
591    }
592    #[expect(clippy::float_cmp, reason = "This should be exact.")]
593    fn is_zero(&self) -> bool {
594        self.0 == 1.0
595    }
596}
597
/// A map from clusters to costs, used for memoization.
///
/// NOTE(review): not used in the visible part of this file; presumably the [`Cost`]-implementors
/// cache their [`Cost::cost`]-results in it — confirm.
type Costs = FxHashMap<Cluster, f64>;
600
601/// A trait for cost-functions for a class of clustering-problems.
602///
603/// TODO: Specify contract for trait-implementors.
604pub trait Cost {
    /// Get the cost of a cluster.
    ///
    /// This can be memoized via data in `self`, which is why it takes `&mut self`.
    /// The `cluster` will never contain an index higher than `self.num_points()-1` and will
    /// never be empty.
    ///
    /// NOTE(review): the default methods of this trait start from singleton-clusters with a total
    /// cost of `0.0`, and never add the cost of a newly opened singleton-cluster. They therefore
    /// assume that every singleton-cluster costs `0.0`.
    fn cost(&mut self, cluster: Cluster) -> f64;
611
612    /// Get the total cost of a clustering.
613    #[inline]
614    fn total_cost(&mut self, clustering: &Clustering) -> f64 {
615        clustering.iter().map(|cluster| self.cost(*cluster)).sum()
616    }
617
618    /// Quickly calculate a not-necessarily-optimal clustering.
619    ///
620    /// This can speed up the search for an optimal clustering by pruning the search-tree earlier.
621    ///
622    /// For the returned vector, `vec[k]` must be a tuple containing the approximate clustering
623    /// for level `k`, along with the [`total_cost`](Cost::total_cost) of that clustering.
624    /// Here, `vec[0]` can have an arbitrary score (usually `0.0`) and arbitrary clustering (usually empty).
625    #[inline]
626    fn approximate_clusterings(&mut self) -> Vec<(f64, Clustering)> {
627        // This is similar to `greedy_hierarchy`, but with local search after each merge.
628        // TODO: Can we reduce code-duplication here?
629        let num_points = self.num_points();
630
631        let mut clustering = ClusteringNodeMergeMultiple::new_singletons(num_points);
632        let mut solution: Vec<(f64, Clustering)> =
633            vec![(0.0, clustering.clone().into_clustering())];
634
635        while clustering.clusters.len() > 1 {
636            let mut best_merge = clustering
637                .get_all_merges(self)
638                .into_iter()
639                .min_by(|a, b| a.cost.total_cmp(&b.cost))
640                .expect("There should always be a possible merge");
641            best_merge.optimise_locally(self);
642
643            solution.push((best_merge.cost, best_merge.clone().into_clustering()));
644            clustering = best_merge;
645        }
646
647        solution.push((0.0, Clustering::default()));
648        solution.reverse();
649        solution
650    }
651
    /// Get the number of points that must be clustered.
    ///
    /// This must never exceed [`MAX_POINT_COUNT`], because clusters are bitsets over [`Storage`].
    /// The points are referred to by the indices `0..self.num_points()`.
    fn num_points(&self) -> usize;
656
657    /// Calculate an optimal `k`-clustering for every `0 тЙд k тЙд self.num_points()`.
658    ///
659    /// For the returned vector, `vec[k]` must be a tuple containing the optimal clustering
660    /// for level `k`, along with the [`total_cost`](Cost::total_cost) of that clustering.
661    /// Here, `vec[0]` can have an arbitrary score (usually `0.0`) and arbitrary clustering (usually empty).
662    #[inline]
663    fn optimal_clusterings(&mut self) -> Vec<(f64, Clustering)> {
664        let num_points = self.num_points();
665        let mut results = Vec::with_capacity(num_points);
666
667        // TODO: Could we instead use some good A* heuristic? Then ordering the points by weight might also
668        // be redundant.
669        for (k, (approximate_cost, approximate_clustering)) in
670            self.approximate_clusterings().into_iter().enumerate()
671        {
672            results.push((|| {
673                debug_assert_eq!(
674                    approximate_clustering.len(),
675                    k,
676                    "The approximate clustering on level {k} should have exactly {k} clusters."
677                );
678                let mut min_cost = approximate_cost;
679
680                let mut to_see: BinaryHeap<ClusteringNodeMergeSingle> = BinaryHeap::new();
681                to_see.push(ClusteringNodeMergeSingle::empty());
682
683                while let Some(clustering_node) = to_see.pop() {
684                    if clustering_node.clusters.len() == k
685                        && clustering_node.next_to_add == num_points
686                    {
687                        return (
688                            clustering_node.cost,
689                            clustering_node.clusters.into_iter().collect(),
690                        );
691                    }
692                    if clustering_node.next_to_add < num_points {
693                        for new_clustering_node in clustering_node.get_next_nodes(self, k) {
694                            if new_clustering_node.cost < min_cost {
695                                if new_clustering_node.clusters.len() == k
696                                    && new_clustering_node.next_to_add == num_points
697                                {
698                                    min_cost = new_clustering_node.cost;
699                                }
700                                to_see.push(new_clustering_node);
701                            }
702                        }
703                    }
704                }
705                // This can only happen due to floating-point-rounding-errors, or
706                // if the approximate_clustering_cost was off.
707                (approximate_cost, approximate_clustering)
708            })());
709        }
710        results
711    }
712
713    /// Calculate the price-of-hierarchy of the clustering-problem, together with an optimal hierarchy.
714    ///
715    /// A hierarchical clustering is a set of nested clusterings, one for each possible value of k.
716    /// The cost-ratio of level `k` in the hierarchy is its [total cost](`Cost::total_cost`) divided by the cost of
717    /// an [optimal `k`-clustering](`Cost::optimal_clusterings`).
718    ///
719    /// The cost-ratio of the hierarchy is the maximum of the cost-ratios across all its levels.
720    /// The price-of-hierarchy is the lowest-possible cost-ratio across all hierarchical clusterings.
721    ///
722    /// For the returned vector, `vec[k]` is the cluster for level `k`, defaulting to the empty clustering
723    /// for `k==0`. Note that the algorithm constructs this hierarchy in reverse, starting with every
724    /// point in a singleton-cluster.
725    #[must_use]
726    #[inline]
727    fn price_of_hierarchy(&mut self) -> (f64, Vec<Clustering>) {
728        let num_points = self.num_points();
729        let opt_for_fixed_k: Vec<f64> = self
730            .optimal_clusterings()
731            .into_iter()
732            .map(|(cost, _)| cost)
733            .collect();
734
735        let (price_of_greedy, greedy_hierarchy) = self.price_of_greedy();
736        let mut min_hierarchy_price = MaxRatio(price_of_greedy);
737        let initial_clustering = ClusteringNodeMergeMultiple::new_singletons(num_points);
738        // TODO: If we ever decide to inline dijkstra, we should also have a workhorse-variable for collecting the
739        // get_all_merges results, unless inlining dijkstra makes allocations entirely obsolete due to inline
740        // iterators.
741        // TODO: If we ever decide to inline dijkstra, benchmark running `retain` on all nodes, discarding those
742        // whose cost is below the new `min_hierarchy_price`.
743        // TODO: Could we instead use some good A* heuristic? Then ordering the points by weight might also
744        // be redundant.
745        dijkstra(
746            &initial_clustering,
747            |clustering| {
748                let opt_cost =
749                    // SAFETY:
750                    // We'll never have more than `num_points` clusters, and `opt_for_fixed_k` can index
751                    // up to `num_points`.
752                    // Dijkstra also terminates after encountering a custering with only 1 cluster.
753                    *unsafe { opt_for_fixed_k.get_unchecked(clustering.clusters.len()-1) };
754                clustering
755                    .get_all_merges(self)
756                    .into_iter()
757                    .filter_map(move |new_clustering| {
758                        let ratio = MaxRatio::new(new_clustering.cost, opt_cost);
759                        (ratio < min_hierarchy_price).then(|| {
760                            if new_clustering.clusters.len() == 1 {
761                                min_hierarchy_price = ratio;
762                            }
763                            (new_clustering, ratio)
764                        })
765                    })
766            },
767            |clustering| clustering.clusters.len() == 1,
768        )
769        .map_or_else(
770            || (price_of_greedy, greedy_hierarchy),
771            |(path, cost)| {
772                (
773                    cost.0,
774                    iter::once(Clustering::default())
775                        .chain(
776                            path.into_iter()
777                                .rev()
778                                .map(ClusteringNodeMergeMultiple::into_clustering),
779                        )
780                        .collect(),
781                )
782            },
783        )
784    }
785
786    #[must_use]
787    #[inline]
788    /// Calculate a greedy hierarchical clustering.
789    ///
790    /// The greedy-hierarchy calculates a hierarchical clustering by starting with
791    /// each point in a singleton-cluster, and then repeatedly merging those clusters whose
792    /// merging yields the smallest increase in cost. Ties are broken arbitrarily.
793    ///
794    /// For the returned vector, `vec[k]` is a tuple containing the greedy clustering
795    /// for level `k`, along with the [`total_cost`](Cost::total_cost) of that clustering.
796    ///
797    /// Here, `vec[0]` has a score of `0.0` and an empty clustering. Note that the clusterings
798    /// are constructed in reverse, as we start with every point in a singleton-cluster.
799    fn greedy_hierarchy(&mut self) -> Vec<(f64, Clustering)> {
800        let num_points = self.num_points();
801
802        let mut clustering = ClusteringNodeMergeMultiple::new_singletons(num_points);
803        let mut solution: Vec<(f64, Clustering)> =
804            vec![(0.0, clustering.clone().into_clustering())];
805
806        while clustering.clusters.len() > 1 {
807            let best_merge = clustering
808                .get_all_merges(self)
809                .into_iter()
810                .min_by(|a, b| a.cost.total_cmp(&b.cost))
811                .expect("There should always be a possible merge");
812            solution.push((best_merge.cost, best_merge.clone().into_clustering()));
813            clustering = best_merge;
814        }
815
816        solution.push((0.0, Clustering::default()));
817        solution.reverse();
818        solution
819    }
820
821    /// Calculate the cost-ratio of a greedy hierarchical clustering.
822    ///
823    /// See [`Cost::price_of_hierarchy`] for information about the cost-ratio of a hierarchical clustering,
824    /// and the returned hierarchy.
825    #[must_use]
826    #[inline]
827    fn price_of_greedy(&mut self) -> (f64, Vec<Clustering>) {
828        let mut max_ratio = MaxRatio::zero();
829        let greedy_hierarchy = self.greedy_hierarchy();
830        // TODO: Calculation of optimal_clusterings can be sped up by feeding the
831        // greedy-hierarchy into it as a starting-point, perhaps? Though currently approximate-clusterings
832        // uses a better local-search algorithm.
833        let opt_for_fixed_k: Vec<f64> = self
834            .optimal_clusterings()
835            .into_iter()
836            .map(|(cost, _)| cost)
837            .collect();
838
839        // Skip the first (empty) level
840        for (cost, clustering) in greedy_hierarchy.iter().skip(1) {
841            let opt_cost = opt_for_fixed_k
842                .get(clustering.len())
843                .expect("opt_for_fixed_k should have an entry for this number of clusters.");
844            let ratio = MaxRatio::new(*cost, *opt_cost);
845            max_ratio = max_ratio + ratio;
846        }
847
848        let hierarchy = greedy_hierarchy.into_iter().map(|x| x.1).collect();
849        (max_ratio.0, hierarchy)
850    }
851}
852
853/// A clustering-problem where each center must be one of the points that are to be clustered.
854///
855/// The cost is supplied using a distance-function between points. The cost of a cluster, given a center,
856/// is the sum of the distances between the center and all points in the cluster. The cost of a cluster
857/// will always be calculated by choosing the center yielding the smallest cost.
858#[derive(Clone, Debug)]
859pub struct KMedian {
860    /// The distances between the points.
861    distances: Distances,
862    /// A cache for storing already-calculated costs of clusters.
863    costs: Costs,
864}
865impl KMedian {
866    /// Create a `k`-median clustering instance using the squared L2-norm.
867    ///
868    /// # Examples
869    ///
870    /// ```
871    /// use ndarray::array;
872    /// use exact_clustering::KMedian;
873    ///
874    /// KMedian::l2_squared(&[array![0.0, 0.0], array![1.0, 2.0]]).unwrap();
875    /// ```
876    #[inline]
877    pub fn l2_squared(points: &[Point]) -> Result<Self, Error> {
878        let verified_points = verify_points(points)?;
879        Ok(Self {
880            distances: distances_from_points_with_element_norm(verified_points, |x| x.powi(2)),
881            costs: Costs::default(),
882        })
883    }
884
885    /// Create a `k`-median clustering instance using the L2-norm.
886    ///
887    /// # Examples
888    ///
889    /// ```
890    /// use ndarray::array;
891    /// use exact_clustering::KMedian;
892    ///
893    /// KMedian::l2(&[array![0.0, 0.0], array![1.0, 2.0]]).unwrap();
894    /// ```
895    ///
896    /// TODO: This is uncovered by tests.
897    #[inline]
898    pub fn l2(points: &[Point]) -> Result<Self, Error> {
899        let verified_points = verify_points(points)?;
900        Ok(Self {
901            distances: distances_from_points_with_element_norm(verified_points, |x| x.powi(2))
902                .iter()
903                .map(|vec| vec.iter().map(|x| x.sqrt()).collect())
904                .collect(),
905            costs: Costs::default(),
906        })
907    }
908
909    /// Create a `k`-median clustering instance using the L1-norm.
910    ///
911    /// # Examples
912    ///
913    /// ```
914    /// use ndarray::array;
915    /// use exact_clustering::KMedian;
916    ///
917    /// KMedian::l1(&[array![0.0, 0.0], array![1.0, 2.0]]).unwrap();
918    /// ```
919    #[inline]
920    pub fn l1(points: &[Point]) -> Result<Self, Error> {
921        let verified_points = verify_points(points)?;
922        Ok(Self {
923            distances: distances_from_points_with_element_norm(verified_points, f64::abs),
924            costs: Costs::default(),
925        })
926    }
927
928    /// Create a `k`-median clustering instance using the squared L2-norm.
929    ///
930    /// Use [`KMedian::l2_squared`] instead if all your points have the same weight.
931    ///
932    /// The distance between a weighted point `(w, p)` and the center `(v, c)` is the
933    /// squared [euclidean-distance](https://en.wikipedia.org/wiki/Euclidean_norm) between `c` and `p`,
934    /// multiplied by `w`.
935    /// For instance, the center of the cluster {`(1, [0,0])`, `(2, [3,0])`} will be `(2, [3,0])`, because the cost
936    /// of choosing that center is `9`, whereas the cost of choosing `(1, [0,0])` as a center is `18`.
937    ///
938    /// # Examples
939    ///
940    /// ```
941    /// use ndarray::array;
942    /// use exact_clustering::KMedian;
943    ///
944    /// KMedian::weighted_l2_squared(&[(1.0, array![0.0, 0.0]), (2.0, array![1.0, 2.0])]).unwrap();
945    /// ```
946    #[inline]
947    pub fn weighted_l2_squared(weighted_points: &[WeightedPoint]) -> Result<Self, Error> {
948        let verified_weighted_points = verify_weighted_points(weighted_points)?;
949        Ok(Self {
950            distances: distances_from_weighted_points_with_element_norm(
951                verified_weighted_points,
952                |x| x.powi(2),
953            ),
954            costs: Costs::default(),
955        })
956    }
957
958    /// Create a `k`-median clustering instance using the L2-norm.
959    ///
960    /// Use [`KMedian::l2`] instead if all your points have the same weight.
961    ///
962    /// The distance between a weighted point `(w, p)` and the center `(v, c)` is the
963    /// [euclidean-distance](https://en.wikipedia.org/wiki/Euclidean_norm) between `c` and `p`,
964    /// multiplied by `w`.
965    ///
966    /// # Examples
967    ///
968    /// ```
969    /// use ndarray::array;
970    /// use exact_clustering::KMedian;
971    ///
972    /// KMedian::weighted_l2(&[(1.0, array![0.0, 0.0]), (2.0, array![1.0, 2.0])]).unwrap();
973    /// ```
974    ///
975    /// TODO: This is uncovered by tests.
976    #[inline]
977    pub fn weighted_l2(weighted_points: &[WeightedPoint]) -> Result<Self, Error> {
978        let verified_weighted_points = verify_weighted_points(weighted_points)?;
979        Ok(Self {
980            distances: distances_from_weighted_points_with_element_norm(
981                verified_weighted_points,
982                |x| x.powi(2),
983            )
984            .iter()
985            .map(|vec| vec.iter().map(|x| x.sqrt()).collect())
986            .collect(),
987            costs: Costs::default(),
988        })
989    }
990
991    /// Create a `k`-median clustering instance using the L1-norm.
992    ///
993    /// Use [`KMedian::l1`] instead if all your points have the same weight.
994    ///
995    /// The distance between a weighted point `(w, p)` and the center `(v, c)` is the
996    /// [taxicab-distance](https://en.wikipedia.org/wiki/Taxicab_geometry) between `c` and `p`, multiplied by `w`.
997    /// For instance, the center of the cluster {`(1, [0,0])`, `(2, [3,0])`} will be `(2, [3,0])`, because the cost
998    /// of choosing that center is `3`, whereas the cost of choosing `(1, [0,0])` as a center is `6`.
999    ///
1000    /// # Examples
1001    ///
1002    /// ```
1003    /// use ndarray::array;
1004    /// use exact_clustering::KMedian;
1005    ///
1006    /// KMedian::weighted_l1(&[(1.0, array![0.0, 0.0]), (2.0, array![1.0, 2.0])]).unwrap();
1007    /// ```
1008    #[inline]
1009    pub fn weighted_l1(weighted_points: &[WeightedPoint]) -> Result<Self, Error> {
1010        let verified_weighted_points = verify_weighted_points(weighted_points)?;
1011        Ok(Self {
1012            distances: distances_from_weighted_points_with_element_norm(
1013                verified_weighted_points,
1014                f64::abs,
1015            ),
1016            costs: Costs::default(),
1017        })
1018    }
1019}
1020impl Cost for KMedian {
1021    // TODO: Could we achieve a faster optimal-clusterings-impl for KMedian
1022    // by not searching for clusterings but for centroids?
1023    #[inline]
1024    fn num_points(&self) -> usize {
1025        self.distances.len()
1026    }
1027    #[inline]
1028    fn cost(&mut self, cluster: Cluster) -> f64 {
1029        *self.costs.entry(cluster).or_insert_with(|| {
1030            cluster
1031                .iter()
1032                .map(|center_candidate_ix| {
1033                    let center_candidate_row =
1034                        // SAFETY:
1035                        // [`Cost::cost`] promises that `cluster` will never contain an index
1036                        // higher than `self.num_points()-1`. Because `self.num_points()` is
1037                        // the length of `self.distances`, this bound is safe.
1038                        unsafe { self.distances.get_unchecked(center_candidate_ix) };
1039                    cluster
1040                        .iter()
1041                        // SAFETY:
1042                        // Similar to the above safety-comment, and noting that `center_candidate_row`
1043                        // has length `self.num_points()`, as well.
1044                        .map(|ix| *unsafe { center_candidate_row.get_unchecked(ix) })
1045                        .sum()
1046                })
1047                .min_by(f64::total_cmp)
1048                .unwrap_or(0.0)
1049        })
1050    }
1051}
1052
1053/// Create [`Distances`] from Points using a distance-function.
1054///
1055/// This function must be non-negative, but need not be symmetric or satisfy the triangle-inequality.
1056fn distances_from_points_with_distance_function<T>(
1057    points: &[T],
1058    distance_function: impl Fn(&T, &T) -> f64,
1059) -> Distances {
1060    points
1061        .iter()
1062        .map(|p| points.iter().map(|q| distance_function(p, q)).collect())
1063        .collect()
1064}
1065
1066/// Create [`Distances`] from Points using a function that will be applied to each coordinate of the difference
1067/// between two points, and then summed up.
1068///
1069/// The function must be non-negative, but need not be symmetric or satisfy the triangle-inequality.
1070fn distances_from_points_with_element_norm(
1071    points: &[Point],
1072    elementnorm: impl Fn(f64) -> f64,
1073) -> Distances {
1074    distances_from_points_with_distance_function(points, |p, q| {
1075        (p - q).map(|x| elementnorm(*x)).sum()
1076    })
1077}
1078
1079/// Create [`Distances`] from weighted points using a function that will be applied to each coordinate of the
1080/// difference between two points, summed up, and then multiplied by the second point's weight.
1081///
1082/// The function must be non-negative, but need not be symmetric or satisfy the triangle-inequality. The weights
1083/// must be non-negative.
1084fn distances_from_weighted_points_with_element_norm(
1085    points: &[WeightedPoint],
1086    elementnorm: impl Fn(f64) -> f64,
1087) -> Distances {
1088    distances_from_points_with_distance_function(points, |p, q| {
1089        q.0 * (&p.1 - &q.1).map(|x| elementnorm(*x)).sum()
1090    })
1091}
1092
/// An error-type for creating clustering-problems.
#[derive(Debug, PartialEq, Eq)]
#[expect(
    clippy::exhaustive_enums,
    reason = "Extending this enum should be a breaking change."
)]
pub enum Error {
    /// No points were supplied.
    EmptyPoints,
    /// The number of points in the problem is too large. It must not exceed [`MAX_POINT_COUNT`].
    ///
    /// Contains the number of points that were supplied.
    TooManyPoints(usize),
    /// Two points (specified by their indices in the points-vec) have different dimensions.
    ///
    /// The crate's validation reports point `0` as the first index and, as the second, the first
    /// point whose dimension differs from that of point `0`.
    ShapeMismatch(usize, usize),
    /// A point's (specified by its index in the points-vec) weight is non-finite or non-positive.
    ///
    /// Positive infinity is not allowed to avoid degenerate cases for multiple +∞-points in the same cluster:
    /// If we have `{(+∞, x), (+∞, y)}` with `x!=y`, then we can reasonably set the cost to +∞.
    /// But if `x==y`, should the cost still be +∞, or should it be 0? `+∞ * 0.0` is NaN.
    BadWeight(usize),
}
1113
1114impl fmt::Display for Error {
1115    #[inline]
1116    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1117        let msg = match *self {
1118            Self::EmptyPoints => "no points supplied".to_owned(),
1119            Self::TooManyPoints(pointcount) => {
1120                format!("can cluster at most {MAX_POINT_COUNT} points, but got {pointcount}")
1121            }
1122            Self::ShapeMismatch(ix1, ix2) => {
1123                format!("points {ix1} and {ix2} have different dimensions",)
1124            }
1125            Self::BadWeight(ix) => {
1126                format!("point {ix} doesn't have a finite and positive weight",)
1127            }
1128        };
1129        f.write_str(&msg)
1130    }
1131}
1132
1133#[expect(
1134    clippy::absolute_paths,
1135    reason = "Not worth bringing into scope for one use."
1136)]
1137impl core::error::Error for Error {}
1138
1139/// Check whether a set of points is valid for clustering.
1140fn verify_points(points: &[Point]) -> Result<&[Point], Error> {
1141    let point_count = points.len();
1142    if point_count > MAX_POINT_COUNT {
1143        return Err(Error::TooManyPoints(point_count));
1144    }
1145
1146    let first_point = points.first().ok_or(Error::EmptyPoints)?;
1147    let first_dim = first_point.raw_dim();
1148
1149    if let Some(ix) = points.iter().position(|p| p.raw_dim() != first_dim) {
1150        return Err(Error::ShapeMismatch(0, ix));
1151    }
1152
1153    Ok(points)
1154}
1155
1156/// Check whether a set of weighted points is valid for clustering.
1157fn verify_weighted_points(weighted_points: &[WeightedPoint]) -> Result<&[WeightedPoint], Error> {
1158    let point_count = weighted_points.len();
1159    if point_count > MAX_POINT_COUNT {
1160        return Err(Error::TooManyPoints(point_count));
1161    }
1162
1163    let first_point = weighted_points.first().ok_or(Error::EmptyPoints)?;
1164    let first_dim = first_point.1.raw_dim();
1165
1166    if let Some(ix) = weighted_points
1167        .iter()
1168        .position(|p| p.1.raw_dim() != first_dim)
1169    {
1170        return Err(Error::ShapeMismatch(0, ix));
1171    }
1172
1173    if let Some(ix) = weighted_points
1174        .iter()
1175        .position(|p| !p.0.is_finite() || p.0 <= 0.0)
1176    {
1177        return Err(Error::BadWeight(ix));
1178    }
1179
1180    Ok(weighted_points)
1181}
1182
1183/// A clustering-problem where each center can be any point in the metric space.
1184///
1185/// The metric space is the same space the points live in.
1186///
1187/// The cost of a cluster, given a center, is the sum of the squared Euclidean distances between
1188/// the center and all points in the cluster.
1189/// The center is automatically calculated to minimise the cost, which turns out to simply be the
1190/// average of all point-positions in the cluster.
1191///
1192/// See the [wikipedia-article on k-means-clustering](https://en.wikipedia.org/wiki/K-means_clustering)
1193/// for more information.
1194#[derive(Clone, Debug)]
1195pub struct KMeans {
1196    /// The points to be clustered.
1197    points: Vec<Point>,
1198    /// A cache for storing already-calculated costs of clusters.
1199    costs: Costs,
1200}
/// [`Cost`] for [`KMeans`]: a cluster costs the sum of squared euclidean distances from its points
/// to their mean. Approximate clusterings come from the external `clustering` crate's `kmeans`.
impl Cost for KMeans {
    #[inline]
    fn num_points(&self) -> usize {
        self.points.len()
    }
    /// Sum of squared euclidean distances from each point of `cluster` to the cluster's mean.
    /// The result is memoised in `self.costs`.
    #[inline]
    fn cost(&mut self, cluster: Cluster) -> f64 {
        *self.costs.entry(cluster).or_insert_with(|| {
            let first_point_dimensions =
                // SAFETY:
                // [`verify_points`] ensures that we always have at least one point.
                unsafe { self.points.first().unwrap_unchecked() }.raw_dim();
            // Accumulates the sum of the cluster's point-positions; divided into the mean below.
            let mut center = Array1::zeros(first_point_dimensions);

            // For some reason, this is 30% faster than a for-loop.
            cluster
                .iter()
                // SAFETY:
                // [`Cost::cost`] promises us that this index is in-bounds.
                .for_each(|i| center += unsafe { self.points.get_unchecked(i) });

            // Turn the sum into the mean. For a non-empty cluster the divisor is positive.
            // NOTE(review): an empty cluster (which `approximate_clusterings` below can produce if
            // `kmeans` leaves a cluster without members) makes this `0/0`, i.e. a NaN center; the
            // sum below then has no terms, so the NaN never reaches the returned cost.
            center /= f64::from(cluster.len());
            cluster
                .iter()
                .map(|i| {
                    // SAFETY:
                    // [`Cost::cost`] promises us that this index is in-bounds.
                    let p = unsafe { self.points.get_unchecked(i) };
                    // Squared euclidean distance from `p` to the mean.
                    (p - &center).map(|x| x.powi(2)).sum()
                })
                .sum()
        })
    }
    /// Heuristic clusterings for every `k` in `1..=num_points`, each computed by
    /// `clustering::kmeans` with an iteration cap of `max_iter`.
    ///
    /// Entry `k` of the result holds a clustering made of `k` clusters together with its
    /// [`total_cost`](Cost::total_cost); entry `0` is an empty clustering with cost `0.0`.
    #[inline]
    fn approximate_clusterings(&mut self) -> Vec<(f64, Clustering)> {
        use clustering::kmeans;
        let mut results = Vec::with_capacity(self.num_points() + 1);
        results.push((0.0, Clustering::default()));
        // Iteration cap handed to `kmeans` for every `k`.
        let max_iter = 1000;
        // `kmeans` is called with plain `Vec<f64>` samples, so copy the ndarray-points once up front.
        let samples: Vec<Vec<f64>> = self
            .points
            .iter()
            .map(|x| x.into_iter().copied().collect())
            .collect();
        results.extend((1..=self.num_points()).map(|k| {
            let kmeans_clustering = kmeans(k, &samples, max_iter);
            // One initially-empty cluster per requested center. A cluster that receives no member
            // stays empty but is still part of the returned clustering.
            let mut clusters = vec![Cluster::new(); k];
            // `membership[point_ix]` is used as the index of the cluster `point_ix` belongs to;
            // the `expect` assumes every such index is `< k`.
            for (point_ix, cluster_ix) in kmeans_clustering.membership.iter().enumerate() {
                clusters
                    .get_mut(*cluster_ix)
                    .expect("Cluster index out of range")
                    .insert(point_ix);
            }
            let clustering: Clustering = clusters.into_iter().collect();
            (self.total_cost(&clustering), clustering)
        }));
        results
    }
}
1261impl KMeans {
1262    /// Construct a new `k`-means clustering instance from a slice of points.
1263    ///
1264    /// # Examples
1265    ///
1266    /// ```
1267    /// use ndarray::array;
1268    /// use exact_clustering::KMeans;
1269    ///
1270    /// KMeans::new(&[array![0.0, 0.0], array![1.0, 2.0]]).unwrap();
1271    /// ```
1272    #[inline]
1273    pub fn new(points: &[Point]) -> Result<Self, Error> {
1274        let verified_points = verify_points(points)?;
1275        Ok(Self {
1276            points: verified_points.to_vec(),
1277            costs: Costs::default(),
1278        })
1279    }
1280}
1281
1282/// A weighted clustering-problem where each center can be any point in the metric space.
1283///
1284/// Use [`KMeans`] instead if all your points have the same weight.
1285///
1286/// The metric space is the same space the weighted points live in.
1287///
1288/// The cost of a cluster, given a center, is the sum of the squared Euclidean distances between
1289/// the center and each point in the cluster, multiplied by the point's weight.
1290/// The center is automatically calculated to minimise the cost, which turns out to simply be the
1291/// weighted average of all point-positions in the cluster.
1292///
1293/// See the [wikipedia-article on k-means-clustering](https://en.wikipedia.org/wiki/K-means_clustering)
1294/// for more information.
1295#[derive(Clone, Debug)]
1296pub struct WeightedKMeans {
1297    /// The points to be clustered.
1298    weighted_points: Vec<WeightedPoint>,
1299    /// A cache for storing already-calculated costs of clusters.
1300    costs: Costs,
1301}
/// [`Cost`] for [`WeightedKMeans`]: a cluster costs the weighted sum of squared euclidean distances
/// from its points to their weighted mean.
impl Cost for WeightedKMeans {
    #[inline]
    fn num_points(&self) -> usize {
        self.weighted_points.len()
    }
    /// Sum over the points of `cluster` of `weight * (squared euclidean distance to the weighted
    /// mean)`. The result is memoised in `self.costs`.
    #[inline]
    fn cost(&mut self, cluster: Cluster) -> f64 {
        *self.costs.entry(cluster).or_insert_with(|| {
            // Sum of the weights of the cluster's points; the divisor of the weighted mean.
            let mut total_weight = 0.0;
            let first_point_dimensions =
                // SAFETY:
                // [`verify_weighted_points`] ensures that we always have at least one point.
                unsafe { self.weighted_points.first().unwrap_unchecked() }.1.raw_dim();
            // Accumulates the weighted sum of positions; divided into the weighted mean below.
            let mut center: Array1<f64> = Array1::zeros(first_point_dimensions);

            // For some reason, this is 30% faster than a for-loop.
            // TODO: If this is hot, benchmark changes in assignments (e.g. assign let weight = weighted_point.0 first)
            cluster.iter().for_each(|i| {
                // SAFETY:
                // [`Cost::cost`] promises us that this index is in-bounds.
                let weighted_point = unsafe { self.weighted_points.get_unchecked(i) };
                total_weight += weighted_point.0;
                center += &(&weighted_point.1 * weighted_point.0);
            });

            // [`verify_weighted_points`] rejects non-positive weights, so `total_weight > 0` for
            // any non-empty cluster and we never divide by 0 here.
            // NOTE(review): this assumes `cluster` is never empty — confirm. For an empty cluster
            // the center would be NaN, but the sum below would then have no terms.
            center /= total_weight;

            cluster
                .iter()
                .map(|i| {
                    // SAFETY:
                    // [`Cost::cost`] promises us that this index is in-bounds.
                    let weighted_point = unsafe { self.weighted_points.get_unchecked(i) };
                    // Weighted squared euclidean distance from the point to the weighted mean.
                    weighted_point.0 * (&weighted_point.1 - &center).map(|x| x.powi(2)).sum()
                })
                .sum()
        })
    }
}
1343impl WeightedKMeans {
1344    /// Construct a new `k`-means clustering instance.
1345    ///
1346    /// The algorithm runs significantly faster if you sort the points by weight first, merge
1347    /// points that have the same positions into one (adding their weights), and break symmetries
1348    /// by applying a small amount of noise.
1349    ///
1350    /// TODO: Do so internally instead of at callsite. This probably requires better return-values
1351    /// in the API. Afterwards, rework the `get_high_kmeans_price_of_greedy_instance` function in
1352    /// integration-tests to not sort the points on its own.
1353    ///
1354    /// # Examples
1355    ///
1356    /// ```
1357    /// use ndarray::array;
1358    /// use exact_clustering::WeightedKMeans;
1359    ///
1360    /// WeightedKMeans::new(&[(1.0, array![0.0, 0.0]), (2.0, array![1.0, 2.0])]).unwrap();
1361    /// ```
1362    #[inline]
1363    pub fn new(weighted_points: &[WeightedPoint]) -> Result<Self, Error> {
1364        let verified_weighted_points = verify_weighted_points(weighted_points)?;
1365        Ok(Self {
1366            weighted_points: verified_weighted_points.to_vec(),
1367            costs: Costs::default(),
1368        })
1369    }
1370}
1371
1372/// Construct a cluster from an iterator of point-indices.
1373///
1374/// TODO: This method only exists due to a malnourished API. An API-improvement should make
1375/// it obsolete.
1376///
1377/// # Examples
1378///
1379/// ```
1380/// use exact_clustering::cluster_from_iterator;
1381///
1382/// let cluster = cluster_from_iterator([0,2]);
1383///
1384/// assert!(cluster.contains(0));
1385/// assert!(!cluster.contains(1));
1386/// assert!(cluster.contains(2));
1387///
1388/// assert_eq!(cluster.len(), 2);
1389///
1390/// assert!(!cluster.is_empty());
1391///
1392/// let mut cluster_iter = cluster.iter();
1393/// assert_eq!(cluster_iter.next(), Some(0));
1394/// assert_eq!(cluster_iter.next(), Some(2));
1395/// assert_eq!(cluster_iter.next(), None);
1396/// ```
1397#[inline]
1398pub fn cluster_from_iterator<I: IntoIterator<Item = usize>>(it: I) -> Cluster {
1399    let mut cluster = Cluster::new();
1400    for i in it {
1401        cluster.insert(i);
1402    }
1403    cluster
1404}
1405
1406#[cfg(test)]
1407mod tests {
1408    use super::*;
1409    use core::f64::consts::SQRT_2;
1410    use itertools::Itertools as _;
1411    use ndarray::array;
1412    use smallvec::smallvec;
1413    use std::panic::catch_unwind;
1414
1415    #[test]
1416    #[should_panic(
1417        expected = "Throughout the entire implementation, we should never to add the same point twice."
1418    )]
1419    fn cluster_double_insert() {
1420        let mut cluster = Cluster::singleton(7);
1421        cluster.insert(7);
1422    }
1423
1424    #[test]
1425    #[should_panic(
1426        expected = "Troughout the entire implementation, we should never be merging intersecting clusters."
1427    )]
1428    fn cluster_intersecting_merge() {
1429        let mut cluster7 = Cluster::singleton(7);
1430        let mut cluster9 = Cluster::singleton(7);
1431        cluster7.insert(8);
1432        cluster9.insert(8);
1433        cluster7.union_with(cluster9);
1434    }
1435
1436    #[test]
1437    fn cluster() {
1438        for i in 0..8 {
1439            let cluster = Cluster::singleton(i);
1440            assert!(!cluster.is_empty());
1441            assert_eq!(cluster.len(), 1);
1442            assert_eq!(cluster.iter().collect_vec(), vec![i]);
1443            for j in 0..8 {
1444                assert_eq!(cluster.contains(j), j == i);
1445                let cluster2 = {
1446                    let mut cluster2 = cluster;
1447                    if i != j {
1448                        cluster2.insert(j);
1449                    }
1450                    assert!(!cluster2.is_empty());
1451                    cluster2
1452                };
1453                assert!(!cluster2.is_empty());
1454                assert_eq!(cluster2.len(), if i == j { 1 } else { 2 });
1455                assert_eq!(
1456                    cluster2.iter().collect_vec(),
1457                    match i.cmp(&j) {
1458                        cmp::Ordering::Less => vec![i, j],
1459                        cmp::Ordering::Equal => vec![i],
1460                        cmp::Ordering::Greater => vec![j, i],
1461                    }
1462                );
1463            }
1464        }
1465        let mut cluster_div_3 = Cluster::new();
1466        let mut cluster_div_5 = Cluster::new();
1467        assert!(cluster_div_3.is_empty());
1468        assert!(cluster_div_5.is_empty());
1469        // Only go up to 14, we don't want any intersections between the two.
1470        for i in 1..=14 {
1471            if i % 3 == 0 {
1472                cluster_div_3.insert(i);
1473                assert!(!cluster_div_3.is_empty());
1474            }
1475            if i % 5 == 0 {
1476                cluster_div_5.insert(i);
1477                assert!(!cluster_div_5.is_empty());
1478            }
1479        }
1480        assert_eq!(cluster_div_3.iter().collect_vec(), vec![3, 6, 9, 12]);
1481        assert_eq!(cluster_div_5.iter().collect_vec(), vec![5, 10]);
1482        let merged = {
1483            let mut merged = cluster_div_3;
1484            merged.union_with(cluster_div_5);
1485            merged
1486        };
1487        assert_eq!(merged.iter().collect_vec(), vec![3, 5, 6, 9, 10, 12]);
1488
1489        assert_eq!(merged.to_string(), "...#.##..##.#...................");
1490    }
1491
1492    #[expect(clippy::float_cmp, reason = "This should be exact.")]
1493    #[expect(
1494        clippy::assertions_on_result_states,
1495        reason = "We'd like to catch the errors."
1496    )]
1497    #[test]
1498    fn max_ratio() {
1499        assert_eq!(MaxRatio::new(3.0, 1.5).0, 2.0);
1500        assert_eq!(MaxRatio::new(SQRT_2, SQRT_2).0, 1.0);
1501        assert_eq!(MaxRatio::new(SQRT_2, 0.0).0, f64::INFINITY);
1502        assert_eq!(MaxRatio::new(SQRT_2, -0.0).0, f64::INFINITY);
1503        assert_eq!(MaxRatio::new(0.0, 0.0).0, 1.0);
1504        assert_eq!(MaxRatio::new(-0.0, 0.0).0, 1.0);
1505        assert_eq!(MaxRatio::new(0.0, -0.0).0, 1.0);
1506        assert_eq!(MaxRatio::new(-0.0, -0.0).0, 1.0);
1507        assert!(catch_unwind(|| MaxRatio::new(1.0 - 1e-3, 1.0)).is_err());
1508        assert!(catch_unwind(|| MaxRatio::new(1.0 - 1e-12, 1.0)).is_ok());
1509        assert!(catch_unwind(|| MaxRatio::new(0.0 - 1e-12, 0.0)).is_err());
1510        assert!(catch_unwind(|| MaxRatio::new(f64::INFINITY, 1.0)).is_err());
1511        assert!(catch_unwind(|| MaxRatio::new(f64::NAN, 1.0)).is_err());
1512        assert!(catch_unwind(|| MaxRatio::new(f64::NEG_INFINITY, 1.0)).is_err());
1513        assert!(catch_unwind(|| MaxRatio::new(1.0, f64::INFINITY)).is_err());
1514        assert!(catch_unwind(|| MaxRatio::new(1.0, f64::NAN)).is_err());
1515        assert!(catch_unwind(|| MaxRatio::new(1.0, f64::NEG_INFINITY)).is_err());
1516        assert!(catch_unwind(|| MaxRatio::new(1.0, 0.0)).is_ok());
1517        assert!(catch_unwind(|| MaxRatio::new(1.0, -1e-12)).is_err());
1518    }
1519
1520    macro_rules! clusterings {
1521        ( $( [ $( [ $( $num:expr ),* ] ),* ] ),* $(,)? ) => {
1522            [
1523                $(
1524                    vec![
1525                        $(
1526                            cluster_from_iterator([$( $num ),*]),
1527                        )*
1528                    ],
1529                )*
1530            ]
1531        }
1532    }
1533
1534    #[test]
1535    fn node_merge_multiple() {
1536        fn clusters_are_correct(
1537            expected_clusterings: &[Vec<Cluster>],
1538            nodes: &[ClusteringNodeMergeMultiple],
1539        ) {
1540            let actual = nodes.iter().map(|x| x.clusters.to_vec()).collect_vec();
1541            assert_eq!(
1542                expected_clusterings, actual,
1543                "Clustering should match expected clustering. Maybe the order of returned Clusters has changed?"
1544            );
1545        }
1546        let mut kmedian =
1547            KMedian::l2_squared(&[array![0.0], array![1.0], array![2.0], array![3.0]])
1548                .expect("Creating kmedian should not fail.");
1549        let mut update_nodes = |nodes: &mut Vec<ClusteringNodeMergeMultiple>| {
1550            *nodes = nodes
1551                .iter()
1552                .flat_map(|n| n.get_all_merges(&mut kmedian))
1553                .collect();
1554        };
1555        let mut nodes = vec![ClusteringNodeMergeMultiple::new_singletons(4)];
1556        let expected_init_clusters = smallvec![
1557            Cluster::singleton(0),
1558            Cluster::singleton(1),
1559            Cluster::singleton(2),
1560            Cluster::singleton(3)
1561        ];
1562        assert_eq!(
1563            nodes,
1564            vec![ClusteringNodeMergeMultiple {
1565                clusters: expected_init_clusters,
1566                cost: f64::NAN,
1567            }],
1568            "Testing nodes for equality should only depend on clusters, not on their cost."
1569        );
1570        clusters_are_correct(&clusterings![[[0], [1], [2], [3]]], &nodes);
1571
1572        update_nodes(&mut nodes);
1573        clusters_are_correct(
1574            &clusterings![
1575                [[0, 1], [2], [3]],
1576                [[1], [0, 2], [3]],
1577                [[1], [2], [0, 3]],
1578                [[0], [1, 2], [3]],
1579                [[0], [2], [1, 3]],
1580                [[0], [1], [2, 3]],
1581            ],
1582            &nodes,
1583        );
1584
1585        update_nodes(&mut nodes);
1586        clusters_are_correct(
1587            &clusterings![
1588                [[0, 1, 2], [3]],
1589                [[2], [0, 1, 3]],
1590                [[0, 1], [2, 3]],
1591                [[1, 0, 2], [3]],
1592                [[0, 2], [1, 3]],
1593                [[1], [0, 2, 3]],
1594                [[1, 2], [0, 3]],
1595                [[2], [1, 0, 3]],
1596                [[1], [2, 0, 3]],
1597                [[0, 1, 2], [3]],
1598                [[1, 2], [0, 3]],
1599                [[0], [1, 2, 3]],
1600                [[0, 2], [1, 3]],
1601                [[2], [0, 1, 3]],
1602                [[0], [2, 1, 3]],
1603                [[0, 1], [2, 3]],
1604                [[1], [0, 2, 3]],
1605                [[0], [1, 2, 3]],
1606            ],
1607            &nodes,
1608        );
1609
1610        update_nodes(&mut nodes);
1611        clusters_are_correct(&vec![vec![Cluster(15)]; 18], &nodes);
1612    }
1613
1614    #[test]
1615    #[should_panic(expected = "The clusters should always be sorted, to prevent duplicates.")]
1616    fn unsorted_node_merge_multiple() {
1617        let unsorted = ClusteringNodeMergeMultiple {
1618            clusters: smallvec![Cluster(1), Cluster(0)],
1619            cost: 0.0,
1620        };
1621        let mut small_kmedian =
1622            KMedian::l1(&[array![0.0], array![1.0]]).expect("Creating kmedian should not fail.");
1623        let _: Vec<_> = unsorted
1624            .get_all_merges(&mut small_kmedian) // This should fail.
1625            .into_iter()
1626            .collect_vec();
1627    }
1628
1629    #[test]
1630    fn node_merge_single() {
1631        fn clusters_are_correct(
1632            expected_clusterings: &[Vec<Cluster>],
1633            nodes: &[ClusteringNodeMergeSingle],
1634        ) {
1635            let actual = nodes.iter().map(|x| x.clusters.to_vec()).collect_vec();
1636            assert_eq!(
1637                expected_clusterings, actual,
1638                "Clustering should match expected clustering. Maybe the order of returned Clusters has changed?"
1639            );
1640        }
1641        let mut kmedian =
1642            KMedian::l2_squared(&[array![0.0], array![1.0], array![2.0], array![3.0]])
1643                .expect("Creating kmedian should not fail.");
1644        let mut update_nodes = |nodes: &mut Vec<ClusteringNodeMergeSingle>| {
1645            *nodes = nodes
1646                .iter()
1647                .flat_map(|n| n.get_next_nodes(&mut kmedian, 3).collect_vec())
1648                .collect();
1649        };
1650        let mut nodes = vec![ClusteringNodeMergeSingle::empty()];
1651        clusters_are_correct(&clusterings![[]], &nodes);
1652
1653        update_nodes(&mut nodes);
1654        clusters_are_correct(&clusterings![[[0]]], &nodes);
1655
1656        update_nodes(&mut nodes);
1657        clusters_are_correct(&clusterings![[[0, 1]], [[0], [1]]], &nodes);
1658
1659        update_nodes(&mut nodes);
1660        clusters_are_correct(
1661            &clusterings![
1662                [[0, 1, 2]],
1663                [[0, 1], [2]],
1664                [[0, 2], [1]],
1665                [[0], [1, 2]],
1666                [[0], [1], [2]],
1667            ],
1668            &nodes,
1669        );
1670
1671        update_nodes(&mut nodes);
1672        clusters_are_correct(
1673            &clusterings![
1674                [[0, 1, 2, 3]],
1675                [[0, 1, 2], [3]],
1676                [[0, 1, 3], [2]],
1677                [[0, 1], [2, 3]],
1678                [[0, 1], [2], [3]],
1679                [[0, 2, 3], [1]],
1680                [[0, 2], [1, 3]],
1681                [[0, 2], [1], [3]],
1682                [[0, 3], [1, 2]],
1683                [[0], [1, 2, 3]],
1684                [[0], [1, 2], [3]],
1685                [[0, 3], [1], [2]],
1686                [[0], [1, 3], [2]],
1687                [[0], [1], [2, 3]],
1688                // Notice that [[0],[1],[2],[3]] is not in this list.
1689            ],
1690            &nodes,
1691        );
1692    }
1693
1694    #[test]
1695    fn infinite_loop_optimise_locally() {
1696        // Due to floating-point inaccuracies, `optimise_locally` could enter an infinite loop if
1697        // one accepts an "improvement" as "improves cost by some positive amount".
1698
1699        // Magic values that came up during random search.
1700        let (weight_a, point_a) = (0.588_906_661, array![-0.487_778_761_130_834]);
1701        let (weight_b, point_b) = (0.434_371_596, array![-0.438_191_407_837_575]);
1702        let points = [
1703            (weight_a, -point_a.clone()),
1704            (weight_b, -point_b.clone()),
1705            (1.0, array![0.0]),
1706            (weight_a, point_a),
1707            (weight_b, point_b),
1708        ];
1709        let mut kmedian = KMedian::weighted_l1(&points).expect("Creating kmedian should not fail.");
1710
1711        let mut clustering = ClusteringNodeMergeMultiple {
1712            clusters: SmallVec::from_iter([
1713                cluster_from_iterator([0, 1, 2]),
1714                cluster_from_iterator([3, 4]),
1715            ]),
1716            cost: 0.488_933_068_284_744_25,
1717        };
1718
1719        // In a careless implementation, this would enter an infinite loop.
1720        clustering.optimise_locally(&mut kmedian);
1721    }
1722
1723    #[test]
1724    fn infinite_loop_optimise_locally_1() {
1725        // Another batch of numbers that came up tails in random search
1726        // Regrettably still looped before v0.4.0.
1727        let points = vec![
1728            (1.870_423_609_633_216e24, array![1000.0, -1000.0, 1000.0]),
1729            (3.817_589_201_683_946e23, array![1000.0, 1000.0, -1000.0]),
1730            (2.074_998_884_450_784_5e21, array![1000.0, 1000.0, 1000.0]),
1731            (
1732                1.0,
1733                array![
1734                    -400.240_609_956_200_4,
1735                    616.506_453_035_030_1,
1736                    -79.475_319_067_602_64
1737                ],
1738            ),
1739            (1.0, array![-1000.0, 415.010_128_673_398_5, 1000.0]),
1740        ];
1741        let mut kmedian = KMedian::weighted_l1(&points).expect("Creating kmedian should not fail.");
1742
1743        let mut clustering = ClusteringNodeMergeMultiple {
1744            clusters: SmallVec::from_iter([
1745                cluster_from_iterator([0, 2, 4]),
1746                cluster_from_iterator([1, 3]),
1747            ]),
1748            cost: 4.149_997_768_901_569e24,
1749        };
1750
1751        // This used to loop infinitely
1752        clustering.optimise_locally(&mut kmedian);
1753    }
1754}