petal_clustering/
hdbscan.rs

1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::ops::{AddAssign, DivAssign, Sub};
4
5use itertools::Itertools;
6use ndarray::{Array1, ArrayBase, ArrayView1, Data, Ix2};
7use num_traits::{float::FloatCore, FromPrimitive};
8use petal_neighbors::distance::{Euclidean, Metric};
9use petal_neighbors::BallTree;
10use serde::{Deserialize, Serialize};
11
12use super::Fit;
13use crate::mst::{condense_mst, mst_linkage, Boruvka};
14use crate::union_find::TreeUnionFind;
15
16/// HDBSCAN (hierarchical density-based spatial clustering of applications with noise)
17/// clustering algorithm.
18///
19/// # Examples
20///
21/// ```
22/// use ndarray::array;
23/// use petal_neighbors::distance::Euclidean;
24/// use petal_clustering::{HDbscan, Fit};
25///
26/// let points = array![
27///             [1.0, 2.0],
28///             [1.1, 2.2],
29///             [0.9, 1.9],
30///             [1.0, 2.1],
31///             [-2.0, 3.0],
32///             [-2.2, 3.1],
33///         ];
34/// let mut hdbscan = HDbscan {
35///    alpha: 1.,
36///    min_samples: 2,
37///    min_cluster_size: 2,
38///    metric: Euclidean::default(),
39///    boruvka: false,
40/// };
41/// let (clusters, outliers, _outlier_scores) = hdbscan.fit(&points, None);
42/// assert_eq!(clusters.len(), 2);   // two clusters found
43///
44/// assert_eq!(
45///     outliers.len(),
46///     points.nrows() - clusters.values().fold(0, |acc, v| acc + v.len()));
47/// ```
48#[derive(Debug, Deserialize, Serialize)]
49pub struct HDbscan<A, M> {
50    /// The radius of a neighborhood.
51    pub alpha: A,
52
53    /// The minimum number of points required to form a dense region.
54    pub min_samples: usize,
55    pub min_cluster_size: usize,
56    pub metric: M,
57    pub boruvka: bool,
58}
59
60impl<A> Default for HDbscan<A, Euclidean>
61where
62    A: FloatCore,
63{
64    fn default() -> Self {
65        Self {
66            alpha: A::one(),
67            min_samples: 15,
68            min_cluster_size: 15,
69            metric: Euclidean::default(),
70            boruvka: true,
71        }
72    }
73}
74
75/// Fits the HDBSCAN clustering algorithm to the given input data.
76///
77/// # Parameters
78/// - `input`: A 2D array representing the dataset to cluster. Each row corresponds to a data point.
79/// - `partial_labels`: An optional parameter for prelabelled data.
80///
81/// # Returns
82/// A tuple containing:
83/// - `HashMap<usize, Vec<usize>>`: A mapping of cluster IDs to the indices of points in each cluster.
84/// - `Vec<usize>`: A vector of indices representing the noise points that do not belong to any cluster.
85/// - `Vec<A>`: A vector of outlier scores for each data point.
86///
87/// # Notes
88/// - The outlier scores are computed using the GLOSH algorithm.
89/// - If `partial_labels` is provided, the algorithm will perform semi-supervised clustering using BC (`BCubed`) algorithm,
90///   otherwise, it will perform unsupervised clustering using Excess of Mass (`EoM`) algorithm.
91///
92/// # References
93/// - Campello, Ricardo JGB, et al. "Hierarchical density estimates for data clustering, visualization, and outlier detection."
94///   ACM Transactions on Knowledge Discovery from Data (TKDD) 10.1 (2015): 1-51.
95/// - Castro Gertrudes, Jadson, et al. "A unified view of density-based methods for semi-supervised clustering and classification."
96///   Data mining and knowledge discovery 33.6 (2019): 1894-1952.
97impl<S, A, M>
98    Fit<
99        ArrayBase<S, Ix2>,
100        HashMap<usize, Vec<usize>>,
101        (HashMap<usize, Vec<usize>>, Vec<usize>, Vec<A>),
102    > for HDbscan<A, M>
103where
104    A: AddAssign + DivAssign + FloatCore + FromPrimitive + Sync + Send,
105    S: Data<Elem = A>,
106    M: Metric<A> + Clone + Sync + Send,
107{
108    fn fit(
109        &mut self,
110        input: &ArrayBase<S, Ix2>,
111        partial_labels: Option<&HashMap<usize, Vec<usize>>>,
112    ) -> (HashMap<usize, Vec<usize>>, Vec<usize>, Vec<A>) {
113        if input.is_empty() {
114            return (HashMap::new(), Vec::new(), Vec::new());
115        }
116        let input = input.as_standard_layout();
117        let db = BallTree::new(input.view(), self.metric.clone()).expect("non-empty array");
118
119        let (mut mst, _offset) = if self.boruvka {
120            let boruvka = Boruvka::new(db, self.min_samples);
121            boruvka.min_spanning_tree().into_raw_vec_and_offset()
122        } else {
123            let core_distances = Array1::from_vec(
124                input
125                    .rows()
126                    .into_iter()
127                    .map(|r| {
128                        db.query(&r, self.min_samples)
129                            .1
130                            .last()
131                            .copied()
132                            .expect("at least one point should be returned")
133                    })
134                    .collect(),
135            );
136            mst_linkage(
137                input.view(),
138                &self.metric,
139                core_distances.view(),
140                self.alpha,
141            )
142            .into_raw_vec_and_offset()
143        };
144
145        mst.sort_unstable_by(|a, b| a.2.partial_cmp(&(b.2)).expect("invalid distance"));
146        let labeled = label(&mst);
147        let condensed = condense_mst(&labeled, self.min_cluster_size);
148        let outlier_scores = glosh(&condensed, self.min_cluster_size);
149        let (clusters, outliers) =
150            find_clusters(&Array1::from_vec(condensed).view(), partial_labels);
151        (clusters, outliers, outlier_scores)
152    }
153}
154
155fn label<A: FloatCore>(mst: &[(usize, usize, A)]) -> Vec<(usize, usize, A, usize)> {
156    let n = mst.len() + 1;
157    let mut result: Vec<(usize, usize, A, usize)> = Vec::with_capacity(2 * n);
158    let mut next_label = n;
159    let mut label = (0..2 * n).collect::<Vec<_>>(); // labels of subtrees
160    let mut sizes = [vec![1; n], vec![0; n]].concat(); // sizes of subtrees
161    let mut uf = TreeUnionFind::new(n);
162
163    // HDBSCAN merges subtrees in the order of eps (distance)
164    // where ties in eps should be merged at the same time:
165    for (eps, edges) in &mst.iter().chunk_by(|(_, _, eps)| *eps) {
166        let edges = edges.collect::<Vec<_>>();
167
168        // Collect unique subtree roots (children)
169        let subtree_roots = edges
170            .iter()
171            .flat_map(|(u, v, _)| [uf.find(*u), uf.find(*v)])
172            .unique()
173            .collect::<Vec<_>>();
174
175        // Merge the subtrees
176        for (u, v, _) in edges {
177            uf.union(*u, *v);
178        }
179
180        // Assign parent-child labels
181        let mut level: HashMap<usize, usize> = HashMap::new();
182        for child in subtree_roots {
183            let parent = uf.find(child);
184            let parent_label = level.entry(parent).or_insert_with(|| {
185                next_label += 1;
186                next_label - 1
187            });
188            let child_label = label[child];
189            result.push((*parent_label, child_label, eps, sizes[child_label]));
190            sizes[*parent_label] += sizes[child_label];
191            label[child] = *parent_label;
192        }
193    }
194    result
195}
196
197fn get_stability<A: FloatCore + FromPrimitive + AddAssign + Sub>(
198    condensed_tree: &ArrayView1<(usize, usize, A, usize)>,
199) -> HashMap<usize, A> {
200    let mut births: HashMap<_, _> = condensed_tree.iter().fold(HashMap::new(), |mut births, v| {
201        let entry = births.entry(v.1).or_insert(v.2);
202        if *entry > v.2 {
203            *entry = v.2;
204        }
205        births
206    });
207
208    let min_parent = condensed_tree
209        .iter()
210        .min_by_key(|v| v.0)
211        .expect("couldn't find the smallest cluster")
212        .0;
213
214    let entry = births.entry(min_parent).or_insert_with(A::zero);
215    *entry = A::zero();
216
217    condensed_tree.iter().fold(
218        HashMap::new(),
219        |mut stability, (parent, _child, lambda, size)| {
220            let entry = stability.entry(*parent).or_insert_with(A::zero);
221            let birth = births.get(parent).expect("invalid child node.");
222            let Some(size) = A::from_usize(*size) else {
223                panic!("invalid size");
224            };
225            *entry += (*lambda - *birth) * size;
226            stability
227        },
228    )
229}
230
231fn get_bcubed<A: FloatCore + FromPrimitive + AddAssign + Sub>(
232    condensed_tree: &ArrayView1<(usize, usize, A, usize)>,
233    partial_labels: &HashMap<usize, Vec<usize>>,
234) -> HashMap<usize, A> {
235    let num_labelled = partial_labels.values().fold(0, |acc, v| acc + v.len());
236
237    // min_parent gives the number of events in the hierarchy
238    let num_events = condensed_tree
239        .iter()
240        .map(|(parent, _, _, _)| *parent)
241        .min()
242        .map_or(0, |min_parent| min_parent);
243
244    // initialize the labels with the partial labels (if any)
245    let mut labels: Vec<Option<usize>> = vec![None; num_events];
246    for (label, points) in partial_labels {
247        for point in points {
248            labels[*point] = Some(*label);
249        }
250    }
251
252    let num_clusters = condensed_tree
253        .iter()
254        .map(|(parent, child, _, _)| parent.max(child))
255        .max()
256        .expect("empty condensed_mst");
257
258    // bottom-up traverse the hierarchy to keep track of the counts of the labelled points
259    // (same with the reverse order iteration on the condensed_mst)
260    let mut label_map: HashMap<usize, HashMap<usize, A>> = HashMap::new();
261    let mut num_labels: Vec<A> = vec![A::zero(); num_clusters + 1];
262    let mut bcubed: Vec<A> = vec![A::zero(); num_clusters + 1];
263    for (parent, child, _, _) in condensed_tree.iter().rev() {
264        if *child < num_events {
265            // point is labelled
266            if let Some(label) = labels[*child] {
267                let entry = label_map.entry(*parent).or_default();
268                let count = entry.entry(label).or_insert(A::zero());
269                *count += A::one();
270                num_labels[*parent] += A::one();
271            }
272        } else {
273            // extend with child cluster count map
274            let child_map = label_map.remove(child).unwrap_or_default(); // remove to save space
275            let child_num_labelled = num_labels[*child];
276
277            let parent_map = label_map.entry(*parent).or_default();
278            for (label, count) in child_map {
279                // compute bcubed of the child cluster
280                let precision = count / child_num_labelled;
281                let recall = count / A::from(partial_labels[&label].len()).expect("invalid count");
282                let fmeasure =
283                    A::from(2).expect("invalid count") * precision * recall / (precision + recall);
284                bcubed[*child] += count * fmeasure / A::from(num_labelled).expect("invalid count");
285
286                // update the parent cluster label count map
287                let c = parent_map.entry(label).or_insert(A::zero());
288                *c += count;
289                num_labels[*parent] += count;
290            }
291        }
292    }
293
294    condensed_tree
295        .iter()
296        .fold(HashMap::new(), |mut scores, (parent, _child, _, _)| {
297            scores.entry(*parent).or_insert_with(|| bcubed[*parent]);
298            scores
299        })
300}
301
302fn find_clusters<A: FloatCore + FromPrimitive + AddAssign + Sub>(
303    condensed_tree: &ArrayView1<(usize, usize, A, usize)>,
304    partial_labels: Option<&HashMap<usize, Vec<usize>>>,
305) -> (HashMap<usize, Vec<usize>>, Vec<usize>) {
306    let mut stability = get_stability(condensed_tree);
307    let mut bcubed = if let Some(partial_labels) = partial_labels {
308        get_bcubed(condensed_tree, partial_labels)
309    } else {
310        HashMap::new()
311    };
312
313    let mut nodes: Vec<_> = stability.keys().copied().collect();
314    nodes.sort_unstable();
315    nodes.remove(0); // remove the root node
316
317    let adj: HashMap<usize, Vec<usize>> =
318        condensed_tree
319            .iter()
320            .fold(HashMap::new(), |mut adj, (p, c, _, _)| {
321                adj.entry(*p).or_default().push(*c);
322                adj
323            });
324
325    let num_clusters = condensed_tree
326        .iter()
327        .max_by_key(|v| v.0)
328        .expect("no maximum parent available")
329        .0;
330
331    // bottom-up traverse the nodes to select the most top-level clusters
332    let mut clusters: Vec<Option<usize>> = vec![None; num_clusters + 1];
333    for node in nodes.iter().rev() {
334        let subtree_stability = adj.get(node).map_or(A::zero(), |children| {
335            children.iter().fold(A::zero(), |acc, c| {
336                acc + *stability.get(c).unwrap_or(&A::zero())
337            })
338        });
339
340        let subtree_bcubed = adj.get(node).map_or(A::zero(), |children| {
341            children.iter().fold(A::zero(), |acc, c| {
342                acc + *bcubed.get(c).unwrap_or(&A::zero())
343            })
344        });
345
346        stability.entry(*node).and_modify(|node_stability| {
347            let node_bcubed = bcubed.entry(*node).or_insert(A::zero());
348            // ties are broken by stability
349            if *node_bcubed > subtree_bcubed
350                || (*node_bcubed == subtree_bcubed && *node_stability >= subtree_stability)
351            {
352                clusters[*node] = Some(*node);
353            }
354            *node_bcubed = node_bcubed.max(subtree_bcubed);
355            *node_stability = node_stability.max(subtree_stability);
356        });
357    }
358
359    // now tow-down pass to assign the clusters
360    for node in nodes {
361        if let Some(cluster) = clusters[node] {
362            let children = adj.get(&node).expect("corrupted adjacency dictionary");
363            for child in children {
364                clusters[*child] = Some(cluster);
365            }
366        }
367    }
368
369    let num_events = condensed_tree
370        .iter()
371        .min_by_key(|v| v.0)
372        .expect("no minimum parent available")
373        .0;
374
375    let mut res_clusters: HashMap<_, Vec<_>> = HashMap::new();
376    let mut outliers = vec![];
377    for (point, cluster) in clusters.iter().enumerate().take(num_events) {
378        if let Some(cluster) = cluster {
379            let c = res_clusters.entry(*cluster).or_default();
380            c.push(point);
381        } else {
382            outliers.push(point);
383        }
384    }
385    (res_clusters, outliers)
386}
387
388// GLOSH: Global-Local Outlier Score from Hierarchies
389// Reference: https://dl.acm.org/doi/10.1145/2733381
390//
391// Given the following hierarchy (min_cluster_size = 3),
392//               Root
393//              /    \
394//             A     ...
395// eps_x ->   / \
396//           x   A
397//              / \
398//             y   A
399//                /|\   <- eps_A: A is still a cluster w.r.t. min_cluster_size at this level
400//               a b c
401//
402// To compute the outlier score of point x, we need:
403//    - eps_x: eps that x joins to cluster A (A is the first cluster that x joins to)
404//    - eps_A: lowest eps that A or any of A's child clusters survives w.r.t. min_cluster_size.
405// Then, the outlier score of x is defined as:
406//    score(x) = 1 - eps_A / eps_x
407//
408// Since we are working with density lambda values (where lambda = 1/eps):
409//    lambda_x = 1 / eps_x
410//    lambda_A = 1 / eps_A
411//    score(x) = 1 - lambda_x / lambda_A
412fn glosh<A: FloatCore>(
413    condensed_mst: &[(usize, usize, A, usize)],
414    min_cluster_size: usize,
415) -> Vec<A> {
416    let deaths = max_lambdas(condensed_mst, min_cluster_size);
417
418    // min_parent gives the number of events in the hierarchy
419    let num_events = condensed_mst
420        .iter()
421        .map(|(parent, _, _, _)| *parent)
422        .min()
423        .map_or(0, |min_parent| min_parent);
424
425    let mut scores = vec![A::zero(); num_events];
426    for (parent, child, lambda, _) in condensed_mst {
427        if *child >= num_events {
428            continue;
429        }
430        let lambda_max = deaths[*parent];
431        if lambda_max == A::zero() {
432            scores[*child] = A::zero();
433        } else {
434            scores[*child] = (lambda_max - *lambda) / lambda_max;
435        }
436    }
437    scores
438}
439
440// Return the maximum lambda value (min eps) for each cluster C such that
441// the cluster or any of its child clusters has at least min_cluster_size points.
442fn max_lambdas<A: FloatCore>(
443    condensed_mst: &[(usize, usize, A, usize)],
444    min_cluster_size: usize,
445) -> Vec<A> {
446    let num_clusters = condensed_mst
447        .iter()
448        .map(|(parent, child, _, _)| parent.max(child))
449        .max()
450        .expect("empty condensed_mst");
451
452    // bottom-up traverse the hierarchy to keep track of the maximum lambda values
453    // (same with the reverse order iteration on the condensed_mst)
454    let mut parent_sizes: Vec<usize> = vec![0; num_clusters + 1];
455    let mut deaths_arr: Vec<A> = vec![A::zero(); num_clusters + 1];
456    for (parent, child, lambda, child_size) in condensed_mst.iter().rev() {
457        parent_sizes[*parent] += *child_size;
458        if parent_sizes[*parent] >= min_cluster_size {
459            deaths_arr[*parent] = deaths_arr[*parent].max(*lambda);
460        }
461        if *child_size >= min_cluster_size {
462            deaths_arr[*parent] = deaths_arr[*parent].max(deaths_arr[*child]);
463        }
464    }
465    deaths_arr
466}
467
468mod test {
469    #[test]
470    fn hdbscan32() {
471        use ndarray::{array, Array2};
472        use petal_neighbors::distance::Euclidean;
473
474        use crate::Fit;
475
476        let data: Array2<f32> = array![
477            [1.0, 2.0],
478            [1.1, 2.2],
479            [0.9, 1.9],
480            [1.0, 2.1],
481            [-2.0, 3.0],
482            [-2.2, 3.1],
483        ];
484        let mut hdbscan = super::HDbscan {
485            alpha: 1.,
486            min_samples: 2,
487            min_cluster_size: 2,
488            metric: Euclidean::default(),
489            boruvka: false,
490        };
491        let (clusters, outliers, _) = hdbscan.fit(&data, None);
492        assert_eq!(clusters.len(), 2);
493        assert_eq!(
494            outliers.len(),
495            data.nrows() - clusters.values().fold(0, |acc, v| acc + v.len())
496        );
497    }
498
499    #[test]
500    fn hdbscan64() {
501        use ndarray::{array, Array2};
502        use petal_neighbors::distance::Euclidean;
503
504        use crate::Fit;
505
506        let data: Array2<f64> = array![
507            [1.0, 2.0],
508            [1.1, 2.2],
509            [0.9, 1.9],
510            [1.0, 2.1],
511            [-2.0, 3.0],
512            [-2.2, 3.1],
513        ];
514        let mut hdbscan = super::HDbscan {
515            alpha: 1.,
516            min_samples: 2,
517            min_cluster_size: 2,
518            metric: Euclidean::default(),
519            boruvka: false,
520        };
521        let (clusters, outliers, _) = hdbscan.fit(&data, None);
522        assert_eq!(clusters.len(), 2);
523        assert_eq!(
524            outliers.len(),
525            data.nrows() - clusters.values().fold(0, |acc, v| acc + v.len())
526        );
527    }
528
529    #[test]
530    fn outlier_scores() {
531        use ndarray::array;
532        use petal_neighbors::distance::Euclidean;
533
534        use crate::Fit;
535
536        let data = array![
537            // Cluster A (formed at eps = √2)
538            [2., 9.],
539            [3., 9.],
540            [2., 8.],
541            [3., 8.],
542            [2., 7.],
543            [3., 7.],
544            [1., 8.],
545            [4., 8.],
546            // Cluster B (formed at eps = √8)
547            [7., 9.],
548            [7., 8.],
549            [8., 8.],
550            [8., 7.],
551            [9., 7.],
552            // Cluster C (formed at eps = 2)
553            [6., 3.],
554            [5., 2.],
555            [6., 2.],
556            [7., 2.],
557            [6., 1.],
558            // Outliers:
559            [8., 4.], // outlier1 (joins the root cluster at eps = 3.0)
560            [3., 3.], // outlier2 (joins the root cluster at eps = √13)
561        ];
562        let mut hdbscan = super::HDbscan {
563            alpha: 1.,
564            min_samples: 5,
565            min_cluster_size: 5,
566            metric: Euclidean::default(),
567            boruvka: true,
568        };
569        let (_, _, outlier_scores) = hdbscan.fit(&data, None);
570
571        // Outlier1 joins the root cluster at:
572        //      eps_outlier1 = 3.0
573        // The lowest eps that the root or any of its child clusters survive w.r.t. min_cluster_size = 5 is:
574        //      eps_Root = √2
575        // Then the outlier score of outlier1 is:
576        //      glosh(outlier1) =  1 - √2 / 3.0 = 0.53
577        let expected = 1.0 - 2.0_f64.sqrt() / 3.0_f64;
578        let actual = outlier_scores[18];
579        assert!(
580            (actual - expected).abs() < f64::EPSILON,
581            "Expected: {}, got: {}",
582            expected,
583            actual
584        );
585
586        // Outlier2 joins the root cluster at:
587        //      eps_outlier2 = √13
588        // The lowest eps that the root or any of its child clusters survive w.r.t. min_cluster_size = 5 is:
589        //      eps_root = √2
590        // Then the outlier score of outlier2 is:
591        //      glosh(outlier2) =  1 - √2 / √13 = 0.61
592        let expected = 1.0 - 2.0_f64.sqrt() / 13.0_f64.sqrt();
593        let actual = outlier_scores[19];
594        assert!(
595            (actual - expected).abs() < f64::EPSILON,
596            "Expected: {}, got: {}",
597            expected,
598            actual
599        );
600    }
601
602    #[test]
603    fn partial_labels() {
604        use std::collections::HashMap;
605
606        use ndarray::array;
607        use petal_neighbors::distance::Euclidean;
608
609        use crate::Fit;
610
611        let data = array![
612            // Group 1 (formed at eps = √2)
613            [1., 9.],
614            [2., 9.],
615            [1., 8.],
616            [2., 8.],
617            [3., 7.],
618            // Group 2 (formed at eps = √2)
619            [5., 4.],
620            [6., 4.],
621            [5., 3.],
622            [6., 3.],
623            // Group 3 (formed at eps = √2)
624            [8., 3.],
625            [9., 3.],
626            [8., 2.],
627            [9., 2.],
628            [8., 1.],
629            [9., 1.],
630            // noise (joins the root cluster at eps = √37)
631            [7., 8.],
632        ];
633        let mut hdbscan = super::HDbscan {
634            min_samples: 4,
635            min_cluster_size: 4,
636            metric: Euclidean::default(),
637            boruvka: false,
638            ..Default::default()
639        };
640
641        // Unsupervised clusters
642        let (clusters, noise, _) = hdbscan.fit(&data, None);
643        assert_eq!(clusters.len(), 2); // 2 clusters found
644        assert_eq!(noise, [15]); // 1 outlier found
645        let c1 = clusters.keys().find(|k| clusters[k].contains(&0)).unwrap();
646        assert_eq!(clusters[c1], [0, 1, 2, 3, 4]);
647        let c2 = clusters.keys().find(|k| clusters[k].contains(&5)).unwrap();
648        assert_eq!(clusters[c2], [5, 6, 7, 8, 9, 10, 11, 12, 13, 14]);
649        assert_eq!(noise, [15]);
650
651        // Empty partial labels (should return the same result as unsupervised clustering)
652        let partial_labels: HashMap<usize, Vec<usize>> = HashMap::new();
653        let (answer, noise, _) = hdbscan.fit(&data, Some(&partial_labels));
654        assert_eq!(answer, clusters);
655        assert_eq!(noise, [15]);
656
657        // Semi-supervised clustering
658        let mut partial_labels: HashMap<usize, Vec<usize>> = HashMap::new();
659        partial_labels.insert(0, vec![0]);
660        partial_labels.insert(1, vec![3, 4]);
661        partial_labels.insert(2, vec![6]);
662        partial_labels.insert(3, vec![11]);
663        let (clusters, noise, _) = hdbscan.fit(&data, Some(&partial_labels));
664        assert_eq!(clusters.len(), 3); // 3 clusters found
665        assert_eq!(noise, [15]); // 1 outlier found
666        let c1 = clusters.keys().find(|k| clusters[k].contains(&0)).unwrap();
667        assert_eq!(clusters[c1], [0, 1, 2, 3, 4]);
668        let c2 = clusters.keys().find(|k| clusters[k].contains(&5)).unwrap();
669        assert_eq!(clusters[c2], [5, 6, 7, 8]);
670        let c3 = clusters.keys().find(|k| clusters[k].contains(&9)).unwrap();
671        assert_eq!(clusters[c3], [9, 10, 11, 12, 13, 14]);
672    }
673
674    #[test]
675    fn label() {
676        let mst = vec![
677            (0, 1, 4.),
678            (2, 3, 4.),
679            (4, 5, 4.),
680            (1, 2, 7.), // <-- this (having eps = 7.0)
681            (3, 4, 7.), // <-- and this (also with eps = 7.0) should have the same parent label
682            (5, 6, 8.),
683        ];
684        // Resulting labels should be:
685        //            11
686        //           /  \        <-- eps = 8.0
687        //          10   6
688        //         / | \         <-- eps = 7.0
689        //        7  8  9
690        //       /|  |\  |\      <-- eps = 4.0
691        //      0 1  2 3 4 5
692        let labeled_mst = super::label(&mst);
693        assert_eq!(
694            labeled_mst,
695            vec![
696                (7, 0, 4., 1),
697                (7, 1, 4., 1),
698                (8, 2, 4., 1),
699                (8, 3, 4., 1),
700                (9, 4, 4., 1),
701                (9, 5, 4., 1),
702                (10, 7, 7., 2),
703                (10, 8, 7., 2),
704                (10, 9, 7., 2),
705                (11, 10, 8., 6),
706                (11, 6, 8., 1),
707            ]
708        );
709    }
710
711    #[test]
712    fn get_stability() {
713        use std::collections::HashMap;
714
715        use ndarray::arr1;
716
717        let condensed = arr1(&[
718            (7, 6, 1. / 9., 1),
719            (7, 4, 1. / 7., 1),
720            (7, 2, 1. / 7., 1),
721            (7, 1, 1. / 7., 1),
722            (7, 5, 1. / 6., 1),
723            (7, 0, 1. / 6., 1),
724            (7, 3, 1. / 6., 1),
725        ]);
726        let stability_map = super::get_stability(&condensed.view());
727        let mut answer = HashMap::new();
728        answer.insert(7, 1. / 9. + 3. / 7. + 3. / 6.);
729        assert_eq!(stability_map, answer);
730    }
731
732    #[test]
733    fn get_bcubed() {
734        use std::collections::HashMap;
735
736        use ndarray::arr1;
737
738        let condensed = arr1(&[
739            (8, 9, 1. / 10., 4),
740            (8, 10, 1. / 10., 4),
741            (9, 0, 1. / 6., 1),
742            (9, 1, 1. / 7., 1),
743            (9, 2, 1. / 7., 1),
744            (9, 3, 1. / 6., 1),
745            (10, 4, 1. / 7., 1),
746            (10, 5, 1. / 6., 1),
747            (10, 6, 1. / 9., 1),
748            (10, 7, 1. / 9., 1),
749        ]);
750        let mut partial_labels = HashMap::new();
751        partial_labels.insert(0, vec![0, 1, 4]);
752        partial_labels.insert(1, vec![5]);
753        partial_labels.insert(2, vec![7]);
754        let bcubed_map: HashMap<usize, f64> = super::get_bcubed(&condensed.view(), &partial_labels);
755        assert_eq!(bcubed_map.len(), 3);
756        assert_eq!(bcubed_map[&8], 0.0);
757        assert!((bcubed_map[&9] - 8. / 25.).abs() < f64::EPSILON);
758        assert!((bcubed_map[&10] - 4. / 15.).abs() < f64::EPSILON);
759    }
760}