1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::ops::{AddAssign, DivAssign, Sub};
4
5use itertools::Itertools;
6use ndarray::{Array1, ArrayBase, ArrayView1, Data, Ix2};
7use num_traits::{float::FloatCore, FromPrimitive};
8use petal_neighbors::distance::{Euclidean, Metric};
9use petal_neighbors::BallTree;
10use serde::{Deserialize, Serialize};
11
12use super::Fit;
13use crate::mst::{condense_mst, mst_linkage, Boruvka};
14use crate::union_find::TreeUnionFind;
15
/// Parameters for HDBSCAN (Hierarchical Density-Based Spatial Clustering of
/// Applications with Noise).
///
/// `A` is the floating-point element type of the input data; `M` is the
/// distance metric used for neighbor queries and MST construction.
#[derive(Debug, Deserialize, Serialize)]
pub struct HDbscan<A, M> {
    /// Scaling factor applied when building the minimum spanning tree
    /// (forwarded to `mst_linkage`); `1` leaves distances unchanged.
    pub alpha: A,

    /// Number of neighbors queried per point to estimate its core distance
    /// (the distance to the `min_samples`-th nearest neighbor).
    pub min_samples: usize,
    /// Minimum number of points a cluster must contain to survive
    /// condensation of the cluster tree.
    pub min_cluster_size: usize,
    /// Distance metric used by the ball tree and MST construction.
    pub metric: M,
    /// When `true`, builds the MST with the Borůvka-based builder;
    /// otherwise uses `mst_linkage` with precomputed core distances.
    pub boruvka: bool,
}
59
60impl<A> Default for HDbscan<A, Euclidean>
61where
62 A: FloatCore,
63{
64 fn default() -> Self {
65 Self {
66 alpha: A::one(),
67 min_samples: 15,
68 min_cluster_size: 15,
69 metric: Euclidean::default(),
70 boruvka: true,
71 }
72 }
73}
74
impl<S, A, M>
    Fit<
        ArrayBase<S, Ix2>,
        HashMap<usize, Vec<usize>>,
        (HashMap<usize, Vec<usize>>, Vec<usize>, Vec<A>),
    > for HDbscan<A, M>
where
    A: AddAssign + DivAssign + FloatCore + FromPrimitive + Sync + Send,
    S: Data<Elem = A>,
    M: Metric<A> + Clone + Sync + Send,
{
    /// Clusters `input` (one point per row) with HDBSCAN.
    ///
    /// Returns `(clusters, outliers, outlier_scores)` where `clusters` maps
    /// a cluster id to the indices of its member points, `outliers` lists
    /// indices of points assigned to no cluster, and `outlier_scores` holds
    /// one GLOSH score per point.
    ///
    /// `partial_labels` optionally maps a label id to point indices; when
    /// present it influences cluster selection via BCubed scoring in
    /// `find_clusters`.
    fn fit(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        partial_labels: Option<&HashMap<usize, Vec<usize>>>,
    ) -> (HashMap<usize, Vec<usize>>, Vec<usize>, Vec<A>) {
        // Nothing to cluster.
        if input.is_empty() {
            return (HashMap::new(), Vec::new(), Vec::new());
        }
        // The ball tree expects a contiguous (standard-layout) array.
        let input = input.as_standard_layout();
        let db = BallTree::new(input.view(), self.metric.clone()).expect("non-empty array");

        // Build the minimum spanning tree either with the Borůvka-based
        // builder or with `mst_linkage` over precomputed core distances.
        let (mut mst, _offset) = if self.boruvka {
            let boruvka = Boruvka::new(db, self.min_samples);
            boruvka.min_spanning_tree().into_raw_vec_and_offset()
        } else {
            // Core distance of a point = distance to the farthest of its
            // `min_samples` nearest neighbors returned by the ball tree.
            let core_distances = Array1::from_vec(
                input
                    .rows()
                    .into_iter()
                    .map(|r| {
                        db.query(&r, self.min_samples)
                            .1
                            .last()
                            .copied()
                            .expect("at least one point should be returned")
                    })
                    .collect(),
            );
            mst_linkage(
                input.view(),
                &self.metric,
                core_distances.view(),
                self.alpha,
            )
            .into_raw_vec_and_offset()
        };

        // `label` requires the edges in ascending distance order.
        mst.sort_unstable_by(|a, b| a.2.partial_cmp(&(b.2)).expect("invalid distance"));
        let labeled = label(&mst);
        // Collapse the dendrogram, pruning clusters smaller than
        // `min_cluster_size`.
        let condensed = condense_mst(&labeled, self.min_cluster_size);
        let outlier_scores = glosh(&condensed, self.min_cluster_size);
        let (clusters, outliers) =
            find_clusters(&Array1::from_vec(condensed).view(), partial_labels);
        (clusters, outliers, outlier_scores)
    }
}
154
/// Converts a distance-sorted MST into single-linkage dendrogram rows of the
/// form `(parent, child, distance, child_size)`.
///
/// Leaf points keep their ids `0..n`; each merge creates a fresh internal
/// node labeled from `n` upward. `mst` must already be sorted by ascending
/// edge distance (see the sort in `fit`).
fn label<A: FloatCore>(mst: &[(usize, usize, A)]) -> Vec<(usize, usize, A, usize)> {
    // A spanning tree over n points has n - 1 edges.
    let n = mst.len() + 1;
    let mut result: Vec<(usize, usize, A, usize)> = Vec::with_capacity(2 * n);
    let mut next_label = n;
    // `label[i]`: current dendrogram label of component root i;
    // `sizes[l]`: number of points under label l (leaves start at 1,
    // internal labels at 0); `uf` tracks connected components.
    let mut label = (0..2 * n).collect::<Vec<_>>(); let mut sizes = [vec![1; n], vec![0; n]].concat(); let mut uf = TreeUnionFind::new(n);

    // Process edges grouped by equal distance so that all merges happening
    // at the same level share parent labels per resulting component.
    for (eps, edges) in &mst.iter().chunk_by(|(_, _, eps)| *eps) {
        let edges = edges.collect::<Vec<_>>();

        // Component roots touched at this level, captured before the unions
        // below change them.
        let subtree_roots = edges
            .iter()
            .flat_map(|(u, v, _)| [uf.find(*u), uf.find(*v)])
            .unique()
            .collect::<Vec<_>>();

        for (u, v, _) in edges {
            uf.union(*u, *v);
        }

        // Attach each pre-merge component to the freshly labeled merged
        // component; one new label per distinct post-merge root.
        let mut level: HashMap<usize, usize> = HashMap::new();
        for child in subtree_roots {
            let parent = uf.find(child);
            let parent_label = level.entry(parent).or_insert_with(|| {
                next_label += 1;
                next_label - 1
            });
            let child_label = label[child];
            result.push((*parent_label, child_label, eps, sizes[child_label]));
            sizes[*parent_label] += sizes[child_label];
            label[child] = *parent_label;
        }
    }
    result
}
196
197fn get_stability<A: FloatCore + FromPrimitive + AddAssign + Sub>(
198 condensed_tree: &ArrayView1<(usize, usize, A, usize)>,
199) -> HashMap<usize, A> {
200 let mut births: HashMap<_, _> = condensed_tree.iter().fold(HashMap::new(), |mut births, v| {
201 let entry = births.entry(v.1).or_insert(v.2);
202 if *entry > v.2 {
203 *entry = v.2;
204 }
205 births
206 });
207
208 let min_parent = condensed_tree
209 .iter()
210 .min_by_key(|v| v.0)
211 .expect("couldn't find the smallest cluster")
212 .0;
213
214 let entry = births.entry(min_parent).or_insert_with(A::zero);
215 *entry = A::zero();
216
217 condensed_tree.iter().fold(
218 HashMap::new(),
219 |mut stability, (parent, _child, lambda, size)| {
220 let entry = stability.entry(*parent).or_insert_with(A::zero);
221 let birth = births.get(parent).expect("invalid child node.");
222 let Some(size) = A::from_usize(*size) else {
223 panic!("invalid size");
224 };
225 *entry += (*lambda - *birth) * size;
226 stability
227 },
228 )
229}
230
/// Computes a BCubed-style agreement score for every cluster in the
/// condensed tree against user-supplied `partial_labels`
/// (label id -> indices of the points carrying that label).
fn get_bcubed<A: FloatCore + FromPrimitive + AddAssign + Sub>(
    condensed_tree: &ArrayView1<(usize, usize, A, usize)>,
    partial_labels: &HashMap<usize, Vec<usize>>,
) -> HashMap<usize, A> {
    // Total number of labeled points; normalizes the f-measure sums below.
    let num_labelled = partial_labels.values().fold(0, |acc, v| acc + v.len());

    // Leaf points occupy ids below the smallest parent id, so that minimum
    // doubles as the point count.
    let num_events = condensed_tree
        .iter()
        .map(|(parent, _, _, _)| *parent)
        .min()
        .map_or(0, |min_parent| min_parent);

    // Invert `partial_labels` into a per-point label lookup.
    let mut labels: Vec<Option<usize>> = vec![None; num_events];
    for (label, points) in partial_labels {
        for point in points {
            labels[*point] = Some(*label);
        }
    }

    // Largest node id bounds the per-cluster arrays.
    let num_clusters = condensed_tree
        .iter()
        .map(|(parent, child, _, _)| parent.max(child))
        .max()
        .expect("empty condensed_mst");

    // `label_map[cluster]`: count of labeled points per label currently
    // attributed to that cluster; `num_labels[cluster]`: their total;
    // `bcubed[cluster]`: the accumulated score.
    let mut label_map: HashMap<usize, HashMap<usize, A>> = HashMap::new();
    let mut num_labels: Vec<A> = vec![A::zero(); num_clusters + 1];
    let mut bcubed: Vec<A> = vec![A::zero(); num_clusters + 1];
    // Walk the tree bottom-up (rows are stored top-down) so a cluster's
    // counts are complete before being merged into its parent.
    for (parent, child, _, _) in condensed_tree.iter().rev() {
        if *child < num_events {
            // Leaf point: count its label (if any) toward its cluster.
            if let Some(label) = labels[*child] {
                let entry = label_map.entry(*parent).or_default();
                let count = entry.entry(label).or_insert(A::zero());
                *count += A::one();
                num_labels[*parent] += A::one();
            }
        } else {
            // Internal cluster: score it from its accumulated counts, then
            // fold those counts into the parent.
            let child_map = label_map.remove(child).unwrap_or_default(); let child_num_labelled = num_labels[*child];

            let parent_map = label_map.entry(*parent).or_default();
            for (label, count) in child_map {
                // precision: fraction of the cluster's labeled points that
                // carry this label; recall: fraction of this label's points
                // that fall inside the cluster.
                let precision = count / child_num_labelled;
                let recall = count / A::from(partial_labels[&label].len()).expect("invalid count");
                let fmeasure =
                    A::from(2).expect("invalid count") * precision * recall / (precision + recall);
                bcubed[*child] += count * fmeasure / A::from(num_labelled).expect("invalid count");

                let c = parent_map.entry(label).or_insert(A::zero());
                *c += count;
                num_labels[*parent] += count;
            }
        }
    }

    // Emit one entry per node that appears as a parent in the tree.
    condensed_tree
        .iter()
        .fold(HashMap::new(), |mut scores, (parent, _child, _, _)| {
            scores.entry(*parent).or_insert_with(|| bcubed[*parent]);
            scores
        })
}
301
/// Selects the final flat clustering from a condensed tree and partitions
/// the points into clusters and outliers.
///
/// Selection is driven by cluster stability; when `partial_labels` is
/// supplied, the BCubed score derived from them takes precedence over
/// stability (ties on BCubed fall back to stability).
fn find_clusters<A: FloatCore + FromPrimitive + AddAssign + Sub>(
    condensed_tree: &ArrayView1<(usize, usize, A, usize)>,
    partial_labels: Option<&HashMap<usize, Vec<usize>>>,
) -> (HashMap<usize, Vec<usize>>, Vec<usize>) {
    let mut stability = get_stability(condensed_tree);
    let mut bcubed = if let Some(partial_labels) = partial_labels {
        get_bcubed(condensed_tree, partial_labels)
    } else {
        HashMap::new()
    };

    // Candidate cluster nodes in ascending id order; drop the smallest
    // (the root), which is never selected as a cluster itself. Then build
    // a parent -> children adjacency list.
    let mut nodes: Vec<_> = stability.keys().copied().collect();
    nodes.sort_unstable();
    nodes.remove(0); let adj: HashMap<usize, Vec<usize>> =
        condensed_tree
            .iter()
            .fold(HashMap::new(), |mut adj, (p, c, _, _)| {
                adj.entry(*p).or_default().push(*c);
                adj
            });

    // Highest parent id bounds the per-node `clusters` array.
    let num_clusters = condensed_tree
        .iter()
        .max_by_key(|v| v.0)
        .expect("no maximum parent available")
        .0;

    // `clusters[i] = Some(c)` means node i belongs to selected cluster c.
    let mut clusters: Vec<Option<usize>> = vec![None; num_clusters + 1];
    // Bottom-up pass: select a node when its own (bcubed, stability) beats
    // the sum over its children; afterwards propagate the larger of the two
    // values upward so ancestors compare against the best subtree score.
    for node in nodes.iter().rev() {
        let subtree_stability = adj.get(node).map_or(A::zero(), |children| {
            children.iter().fold(A::zero(), |acc, c| {
                acc + *stability.get(c).unwrap_or(&A::zero())
            })
        });

        let subtree_bcubed = adj.get(node).map_or(A::zero(), |children| {
            children.iter().fold(A::zero(), |acc, c| {
                acc + *bcubed.get(c).unwrap_or(&A::zero())
            })
        });

        stability.entry(*node).and_modify(|node_stability| {
            let node_bcubed = bcubed.entry(*node).or_insert(A::zero());
            if *node_bcubed > subtree_bcubed
                || (*node_bcubed == subtree_bcubed && *node_stability >= subtree_stability)
            {
                clusters[*node] = Some(*node);
            }
            *node_bcubed = node_bcubed.max(subtree_bcubed);
            *node_stability = node_stability.max(subtree_stability);
        });
    }

    // Top-down pass: a selected ancestor claims all of its descendants,
    // overriding any selection made deeper in the tree.
    for node in nodes {
        if let Some(cluster) = clusters[node] {
            let children = adj.get(&node).expect("corrupted adjacency dictionary");
            for child in children {
                clusters[*child] = Some(cluster);
            }
        }
    }

    // Leaf points occupy ids below the smallest parent id.
    let num_events = condensed_tree
        .iter()
        .min_by_key(|v| v.0)
        .expect("no minimum parent available")
        .0;

    // Partition the points: assigned ones go to their cluster, the rest
    // are outliers.
    let mut res_clusters: HashMap<_, Vec<_>> = HashMap::new();
    let mut outliers = vec![];
    for (point, cluster) in clusters.iter().enumerate().take(num_events) {
        if let Some(cluster) = cluster {
            let c = res_clusters.entry(*cluster).or_default();
            c.push(point);
        } else {
            outliers.push(point);
        }
    }
    (res_clusters, outliers)
}
387
388fn glosh<A: FloatCore>(
413 condensed_mst: &[(usize, usize, A, usize)],
414 min_cluster_size: usize,
415) -> Vec<A> {
416 let deaths = max_lambdas(condensed_mst, min_cluster_size);
417
418 let num_events = condensed_mst
420 .iter()
421 .map(|(parent, _, _, _)| *parent)
422 .min()
423 .map_or(0, |min_parent| min_parent);
424
425 let mut scores = vec![A::zero(); num_events];
426 for (parent, child, lambda, _) in condensed_mst {
427 if *child >= num_events {
428 continue;
429 }
430 let lambda_max = deaths[*parent];
431 if lambda_max == A::zero() {
432 scores[*child] = A::zero();
433 } else {
434 scores[*child] = (lambda_max - *lambda) / lambda_max;
435 }
436 }
437 scores
438}
439
440fn max_lambdas<A: FloatCore>(
443 condensed_mst: &[(usize, usize, A, usize)],
444 min_cluster_size: usize,
445) -> Vec<A> {
446 let num_clusters = condensed_mst
447 .iter()
448 .map(|(parent, child, _, _)| parent.max(child))
449 .max()
450 .expect("empty condensed_mst");
451
452 let mut parent_sizes: Vec<usize> = vec![0; num_clusters + 1];
455 let mut deaths_arr: Vec<A> = vec![A::zero(); num_clusters + 1];
456 for (parent, child, lambda, child_size) in condensed_mst.iter().rev() {
457 parent_sizes[*parent] += *child_size;
458 if parent_sizes[*parent] >= min_cluster_size {
459 deaths_arr[*parent] = deaths_arr[*parent].max(*lambda);
460 }
461 if *child_size >= min_cluster_size {
462 deaths_arr[*parent] = deaths_arr[*parent].max(deaths_arr[*child]);
463 }
464 }
465 deaths_arr
466}
467
// Missing `#[cfg(test)]` meant this module (and its `use` statements) was
// compiled into non-test builds, producing dead-code warnings and bloat;
// the attribute restricts it to `cargo test` builds.
#[cfg(test)]
mod test {
    /// Two well-separated blobs with `f32` data should yield two clusters;
    /// everything not clustered must be reported as an outlier.
    #[test]
    fn hdbscan32() {
        use ndarray::{array, Array2};
        use petal_neighbors::distance::Euclidean;

        use crate::Fit;

        let data: Array2<f32> = array![
            [1.0, 2.0],
            [1.1, 2.2],
            [0.9, 1.9],
            [1.0, 2.1],
            [-2.0, 3.0],
            [-2.2, 3.1],
        ];
        let mut hdbscan = super::HDbscan {
            alpha: 1.,
            min_samples: 2,
            min_cluster_size: 2,
            metric: Euclidean::default(),
            boruvka: false,
        };
        let (clusters, outliers, _) = hdbscan.fit(&data, None);
        assert_eq!(clusters.len(), 2);
        // Every point is either clustered or an outlier.
        assert_eq!(
            outliers.len(),
            data.nrows() - clusters.values().fold(0, |acc, v| acc + v.len())
        );
    }

    /// Same scenario as `hdbscan32` but with `f64` data.
    #[test]
    fn hdbscan64() {
        use ndarray::{array, Array2};
        use petal_neighbors::distance::Euclidean;

        use crate::Fit;

        let data: Array2<f64> = array![
            [1.0, 2.0],
            [1.1, 2.2],
            [0.9, 1.9],
            [1.0, 2.1],
            [-2.0, 3.0],
            [-2.2, 3.1],
        ];
        let mut hdbscan = super::HDbscan {
            alpha: 1.,
            min_samples: 2,
            min_cluster_size: 2,
            metric: Euclidean::default(),
            boruvka: false,
        };
        let (clusters, outliers, _) = hdbscan.fit(&data, None);
        assert_eq!(clusters.len(), 2);
        // Every point is either clustered or an outlier.
        assert_eq!(
            outliers.len(),
            data.nrows() - clusters.values().fold(0, |acc, v| acc + v.len())
        );
    }

    /// Verifies the GLOSH outlier scores of the two isolated points
    /// (indices 18 and 19) against hand-computed expected values.
    #[test]
    fn outlier_scores() {
        use ndarray::array;
        use petal_neighbors::distance::Euclidean;

        use crate::Fit;

        let data = array![
            // points 0-7: first dense blob
            [2., 9.],
            [3., 9.],
            [2., 8.],
            [3., 8.],
            [2., 7.],
            [3., 7.],
            [1., 8.],
            [4., 8.],
            // points 8-12: second blob
            [7., 9.],
            [7., 8.],
            [8., 8.],
            [8., 7.],
            [9., 7.],
            // points 13-17: third blob
            [6., 3.],
            [5., 2.],
            [6., 2.],
            [7., 2.],
            [6., 1.],
            // points 18-19: isolated outliers
            [8., 4.],
            [3., 3.],
        ];
        let mut hdbscan = super::HDbscan {
            alpha: 1.,
            min_samples: 5,
            min_cluster_size: 5,
            metric: Euclidean::default(),
            boruvka: true,
        };
        let (_, _, outlier_scores) = hdbscan.fit(&data, None);

        // Hand-computed GLOSH score for point 18.
        let expected = 1.0 - 2.0_f64.sqrt() / 3.0_f64;
        let actual = outlier_scores[18];
        assert!(
            (actual - expected).abs() < f64::EPSILON,
            "Expected: {}, got: {}",
            expected,
            actual
        );

        // Hand-computed GLOSH score for point 19.
        let expected = 1.0 - 2.0_f64.sqrt() / 13.0_f64.sqrt();
        let actual = outlier_scores[19];
        assert!(
            (actual - expected).abs() < f64::EPSILON,
            "Expected: {}, got: {}",
            expected,
            actual
        );
    }

    /// Checks that partial labels steer cluster selection: without them two
    /// clusters emerge; with conflicting labels the larger one is split.
    #[test]
    fn partial_labels() {
        use std::collections::HashMap;

        use ndarray::array;
        use petal_neighbors::distance::Euclidean;

        use crate::Fit;

        let data = array![
            [1., 9.],
            [2., 9.],
            [1., 8.],
            [2., 8.],
            [3., 7.],
            [5., 4.],
            [6., 4.],
            [5., 3.],
            [6., 3.],
            [8., 3.],
            [9., 3.],
            [8., 2.],
            [9., 2.],
            [8., 1.],
            [9., 1.],
            [7., 8.],
        ];
        let mut hdbscan = super::HDbscan {
            min_samples: 4,
            min_cluster_size: 4,
            metric: Euclidean::default(),
            boruvka: false,
            ..Default::default()
        };

        // Baseline: no partial labels -> two clusters, point 15 is noise.
        let (clusters, noise, _) = hdbscan.fit(&data, None);
        assert_eq!(clusters.len(), 2);
        assert_eq!(noise, [15]);
        let c1 = clusters.keys().find(|k| clusters[k].contains(&0)).unwrap();
        assert_eq!(clusters[c1], [0, 1, 2, 3, 4]);
        let c2 = clusters.keys().find(|k| clusters[k].contains(&5)).unwrap();
        assert_eq!(clusters[c2], [5, 6, 7, 8, 9, 10, 11, 12, 13, 14]);
        assert_eq!(noise, [15]);

        // Empty partial labels must not change the result.
        let partial_labels: HashMap<usize, Vec<usize>> = HashMap::new();
        let (answer, noise, _) = hdbscan.fit(&data, Some(&partial_labels));
        assert_eq!(answer, clusters);
        assert_eq!(noise, [15]);

        // Labels placing points 6 and 11 in different groups split the
        // larger cluster, yielding three clusters.
        let mut partial_labels: HashMap<usize, Vec<usize>> = HashMap::new();
        partial_labels.insert(0, vec![0]);
        partial_labels.insert(1, vec![3, 4]);
        partial_labels.insert(2, vec![6]);
        partial_labels.insert(3, vec![11]);
        let (clusters, noise, _) = hdbscan.fit(&data, Some(&partial_labels));
        assert_eq!(clusters.len(), 3);
        assert_eq!(noise, [15]);
        let c1 = clusters.keys().find(|k| clusters[k].contains(&0)).unwrap();
        assert_eq!(clusters[c1], [0, 1, 2, 3, 4]);
        let c2 = clusters.keys().find(|k| clusters[k].contains(&5)).unwrap();
        assert_eq!(clusters[c2], [5, 6, 7, 8]);
        let c3 = clusters.keys().find(|k| clusters[k].contains(&9)).unwrap();
        assert_eq!(clusters[c3], [9, 10, 11, 12, 13, 14]);
    }

    /// Checks the dendrogram produced by `label` on a hand-built,
    /// distance-sorted MST of 7 points.
    #[test]
    fn label() {
        let mst = vec![
            (0, 1, 4.),
            (2, 3, 4.),
            (4, 5, 4.),
            (1, 2, 7.),
            (3, 4, 7.),
            (5, 6, 8.),
        ];
        let labeled_mst = super::label(&mst);
        assert_eq!(
            labeled_mst,
            vec![
                (7, 0, 4., 1),
                (7, 1, 4., 1),
                (8, 2, 4., 1),
                (8, 3, 4., 1),
                (9, 4, 4., 1),
                (9, 5, 4., 1),
                (10, 7, 7., 2),
                (10, 8, 7., 2),
                (10, 9, 7., 2),
                (11, 10, 8., 6),
                (11, 6, 8., 1),
            ]
        );
    }

    /// Stability of a single cluster: sum of (lambda - birth) over its
    /// children, with the root's birth fixed at zero.
    #[test]
    fn get_stability() {
        use std::collections::HashMap;

        use ndarray::arr1;

        let condensed = arr1(&[
            (7, 6, 1. / 9., 1),
            (7, 4, 1. / 7., 1),
            (7, 2, 1. / 7., 1),
            (7, 1, 1. / 7., 1),
            (7, 5, 1. / 6., 1),
            (7, 0, 1. / 6., 1),
            (7, 3, 1. / 6., 1),
        ]);
        let stability_map = super::get_stability(&condensed.view());
        let mut answer = HashMap::new();
        answer.insert(7, 1. / 9. + 3. / 7. + 3. / 6.);
        assert_eq!(stability_map, answer);
    }

    /// BCubed scores for a two-cluster tree with three partial labels,
    /// checked against hand-computed fractions.
    #[test]
    fn get_bcubed() {
        use std::collections::HashMap;

        use ndarray::arr1;

        let condensed = arr1(&[
            (8, 9, 1. / 10., 4),
            (8, 10, 1. / 10., 4),
            (9, 0, 1. / 6., 1),
            (9, 1, 1. / 7., 1),
            (9, 2, 1. / 7., 1),
            (9, 3, 1. / 6., 1),
            (10, 4, 1. / 7., 1),
            (10, 5, 1. / 6., 1),
            (10, 6, 1. / 9., 1),
            (10, 7, 1. / 9., 1),
        ]);
        let mut partial_labels = HashMap::new();
        partial_labels.insert(0, vec![0, 1, 4]);
        partial_labels.insert(1, vec![5]);
        partial_labels.insert(2, vec![7]);
        let bcubed_map: HashMap<usize, f64> = super::get_bcubed(&condensed.view(), &partial_labels);
        assert_eq!(bcubed_map.len(), 3);
        assert_eq!(bcubed_map[&8], 0.0);
        assert!((bcubed_map[&9] - 8. / 25.).abs() < f64::EPSILON);
        assert!((bcubed_map[&10] - 4. / 15.).abs() < f64::EPSILON);
    }
}