1use linfa::prelude::*;
13use linfa_clustering::{GaussianMixtureModel, GmmError, KMeans};
14use linfa_nn::distance::{Distance, L2Dist};
15use linfa_reduction::Pca;
16use linfa_tsne::TSneParams;
17use log::{debug, info};
18use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis, Dim};
19use ndarray_rand::RandomExt;
20use ndarray_stats::QuantileExt;
21use rand::distributions::Uniform;
22use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
23
24use crate::{
25 DIM_EMBEDDING, Feature, NUMBER_FEATURES,
26 errors::{ClusteringError, ProjectionError},
27};
28
/// Dataset layout accepted by the clustering back-ends: one row of
/// [`Feature`]s per sample, no targets, 1-D target shape.
pub type FitDataset = Dataset<Feature, (), Dim<[usize; 1]>>;

/// Shorthand for results whose error channel is [`ClusteringError`].
pub type ClusteringResult<T> = Result<T, ClusteringError>;
32
/// Clustering algorithm used to assign each embedded sample to a cluster.
#[derive(Clone, Copy, Debug)]
#[allow(clippy::module_name_repetitions)]
pub enum ClusteringMethod {
    /// Plain k-means with default parameters.
    KMeans,
    /// Gaussian mixture model (k-means init, regularized covariance, 10 runs);
    /// labels come from the model's `predict`.
    GaussianMixtureModel,
}
39
40impl ClusteringMethod {
41 fn fit(self, k: usize, data: &FitDataset) -> ClusteringResult<Array1<usize>> {
43 match self {
44 Self::KMeans => {
45 let model = KMeans::params(k)
46 .fit(data)?;
48 Ok(model.predict(data.records()))
49 }
50 Self::GaussianMixtureModel => {
51 let model = GaussianMixtureModel::params(k)
52 .init_method(linfa_clustering::GmmInitMethod::KMeans)
53 .reg_covariance(5e-3)
54 .n_runs(10)
55 .fit(data)
56 .inspect_err(|e| debug!("GMM fitting failed with k={k}: {e:?}, if this continues try using KMeans instead"))?;
57 Ok(model.predict(data.records()))
58 }
59 }
60 }
61}
62
/// Strategy for choosing the number of clusters `k`.
#[derive(Clone, Copy, Debug)]
pub enum KOptimal {
    /// Gap statistic: compares within-cluster dispersion against uniform
    /// reference datasets.
    GapStatistic {
        /// Number of reference datasets generated per candidate `k`.
        b: u32,
    },
    /// Davies-Bouldin index. NOTE(review): currently unimplemented — the
    /// corresponding code path is a `todo!()`.
    DaviesBouldin,
}
71
/// Dimensionality reduction applied to the raw feature matrix before
/// clustering.
#[derive(Clone, Copy, Debug, Default)]
pub enum ProjectionMethod {
    /// PCA pre-reduction followed by t-SNE down to [`EMBEDDING_SIZE`] dims.
    TSne,
    /// Whitened PCA down to [`EMBEDDING_SIZE`] dims.
    Pca,
    /// No projection: samples are used as-is.
    #[default]
    None,
}
83
impl ProjectionMethod {
    /// Projects `samples` (one row per item) into the embedding space
    /// selected by `self`.
    ///
    /// For [`Self::TSne`] and [`Self::Pca`] the output has
    /// [`EMBEDDING_SIZE`] columns, each rescaled to `[-1, 1]`; for
    /// [`Self::None`] the input is returned untouched.
    ///
    /// # Errors
    /// Returns a [`ProjectionError`] when the underlying PCA or t-SNE fit
    /// fails.
    #[inline]
    pub fn project(self, samples: Array2<Feature>) -> Result<Array2<Feature>, ProjectionError> {
        let nrecords = samples.nrows();
        let ncols = samples.ncols();
        let result = match self {
            Self::TSne => {
                // Two successive midpoints bias the intermediate width toward
                // EMBEDDING_SIZE (roughly ncols/4 + 3*EMBEDDING_SIZE/4).
                let intermediate_dim = ncols.midpoint(EMBEDDING_SIZE).midpoint(EMBEDDING_SIZE);
                debug!(
                    "Preprocessing data with PCA to speed up t-SNE ({ncols} -> {intermediate_dim})"
                );
                // Convert to f64 for the PCA fit, then cast back to Feature.
                let data = Dataset::from(samples.mapv(f64::from));
                let pca: Pca<f64> = Pca::params(intermediate_dim).fit(&data)?;
                #[allow(clippy::cast_possible_truncation)]
                let pca_samples = pca.predict(&data).mapv(|f| f as Feature);

                debug!(
                    "Generating embeddings (size: {intermediate_dim} -> {EMBEDDING_SIZE}) using t-SNE"
                );
                // Perplexity scales with the library size but never drops below 5.
                #[allow(clippy::cast_precision_loss)]
                let mut embeddings = TSneParams::embedding_size(EMBEDDING_SIZE)
                    .perplexity(f32::max(nrecords as f32 / 20., 5.))
                    .approx_threshold(0.5)
                    .max_iter(1000)
                    .transform(pca_samples)?;
                debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);

                debug!("Normalizing embeddings");
                normalize_embeddings_inplace(&mut embeddings);
                embeddings
            }
            Self::Pca => {
                debug!("Generating embeddings (size: {ncols} -> {EMBEDDING_SIZE}) using PCA");
                // Convert to f64 for the PCA fit, then cast back to Feature.
                let data = Dataset::from(samples.mapv(f64::from));
                let pca: Pca<f64> = Pca::params(EMBEDDING_SIZE).whiten(true).fit(&data)?;
                #[allow(clippy::cast_possible_truncation)]
                let mut embeddings = pca.predict(&data).mapv(|f| f as Feature);
                debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);

                debug!("Normalizing embeddings");
                normalize_embeddings_inplace(&mut embeddings);
                embeddings
            }
            Self::None => {
                debug!("Using original data as embeddings");
                samples
            }
        };
        debug!("Embeddings shape: {:?}", result.shape());
        Ok(result)
    }
}
148
149fn normalize_embeddings_inplace(embeddings: &mut Array2<Feature>) {
152 for i in 0..embeddings.ncols() {
153 let min = *embeddings.column(i).min_skipnan();
154 let max = *embeddings.column(i).max_skipnan();
155 let range = max - min;
156 embeddings
157 .column_mut(i)
158 .mapv_inplace(|v| ((v - min) / range).mul_add(2., -1.));
159 }
160}
161
162const EMBEDDING_SIZE: usize = {
167 let log2 = usize::ilog2(if DIM_EMBEDDING < NUMBER_FEATURES {
168 NUMBER_FEATURES
169 } else {
170 DIM_EMBEDDING
171 }) as usize;
172 if log2 < 2 { 2 } else { log2 }
173};
174
/// Type-state wrapper driving the clustering pipeline; the state parameter
/// `S` (one of [`EntryPoint`], [`NotInitialized`], [`Initialized`],
/// [`Finished`]) determines which methods are available.
#[allow(clippy::module_name_repetitions)]
pub struct ClusteringHelper<S> {
    // Current pipeline state. No `where S: Sized` clause is needed: `Sized`
    // is the implicit default bound on type parameters.
    state: S,
}
182
/// Initial state: no data attached yet; see [`ClusteringHelper::new`].
pub struct EntryPoint;
/// State after projection: embeddings exist but `k` has not been chosen.
pub struct NotInitialized {
    // Projected samples, one row per item.
    embeddings: Array2<Feature>,
    pub k_max: usize,
    pub optimizer: KOptimal,
    pub clustering_method: ClusteringMethod,
}
/// State after `k` has been selected via the configured optimizer.
pub struct Initialized {
    embeddings: Array2<Feature>,
    pub k: usize,
    pub clustering_method: ClusteringMethod,
}
/// Terminal state: one cluster label per input row.
pub struct Finished {
    labels: Array1<usize>,
    pub k: usize,
}
203
204impl ClusteringHelper<EntryPoint> {
206 #[allow(clippy::missing_inline_in_public_items)]
212 pub fn new(
213 samples: Array2<Feature>,
214 k_max: usize,
215 optimizer: KOptimal,
216 clustering_method: ClusteringMethod,
217 projection_method: ProjectionMethod,
218 ) -> Result<ClusteringHelper<NotInitialized>, ClusteringError> {
219 if samples.nrows() <= 15 {
220 return Err(ClusteringError::SmallLibrary);
221 }
222
223 let embeddings = projection_method.project(samples)?;
225
226 Ok(ClusteringHelper {
227 state: NotInitialized {
228 embeddings,
229 k_max,
230 optimizer,
231 clustering_method,
232 },
233 })
234 }
235}
236
impl ClusteringHelper<NotInitialized> {
    /// Selects the optimal `k` with the configured optimizer and advances to
    /// the [`Initialized`] state.
    ///
    /// # Errors
    /// Fails when an underlying clustering fit fails or no optimal `k` is
    /// found up to `k_max`.
    #[inline]
    pub fn initialize(self) -> Result<ClusteringHelper<Initialized>, ClusteringError> {
        let k = self.get_optimal_k()?;
        Ok(ClusteringHelper {
            state: Initialized {
                embeddings: self.state.embeddings,
                k,
                clustering_method: self.state.clustering_method,
            },
        })
    }

    // Dispatches to the configured k-selection strategy.
    fn get_optimal_k(&self) -> Result<usize, ClusteringError> {
        match self.state.optimizer {
            KOptimal::GapStatistic { b } => self.get_optimal_k_gap_statistic(b),
            KOptimal::DaviesBouldin => self.get_optimal_k_davies_bouldin(),
        }
    }

    /// Gap-statistic search: for each candidate `k`, the (base-2) log of the
    /// within-cluster dispersion of the real embeddings is compared against
    /// its expectation under `b` uniform reference datasets. The scan stops at
    /// the first `k` where `gap(k-1) >= gap(k) - s_k` and returns `k - 1`.
    fn get_optimal_k_gap_statistic(&self, b: u32) -> Result<usize, ClusteringError> {
        let embedding_dataset = Dataset::from(self.state.embeddings.clone());

        let reference_datasets =
            generate_reference_datasets(embedding_dataset.records().view(), b as usize);

        #[allow(clippy::cast_precision_loss)]
        let b = b as Feature;

        let (mut optimal_k, mut gap_k_minus_one) = (None, None);

        for k in 1..=self.state.k_max {
            info!("Fitting k-means to embeddings with k={k}");
            let labels = self.state.clustering_method.fit(k, &embedding_dataset)?;

            debug!("Calculating within intra-cluster variation for reference data sets with k={k}");
            // log2(W_kb) for every reference dataset, fitted in parallel.
            let w_kb_log = reference_datasets
                .par_iter()
                .map(|ref_data| {
                    let ref_labels = self.state.clustering_method.fit(k, ref_data)?;
                    let ref_dispersion =
                        calc_centroid_dispersion(ref_data.records().view(), k, ref_labels.view());
                    let dispersion = calc_within_dispersion(ref_dispersion.view()).log2();
                    Ok(dispersion)
                })
                .collect::<ClusteringResult<Vec<_>>>();
            let w_kb_log = match w_kb_log {
                Ok(w_kb_log) => Array::from_vec(w_kb_log),
                // An empty-cluster GMM error means k is already too large for
                // this library; stop scanning instead of failing outright.
                Err(ClusteringError::Gmm(GmmError::EmptyCluster(e))) => {
                    log::warn!("Library is not large enough to cluster with k={k}: {e}");
                    break;
                }
                Err(e) => return Err(e),
            };
            let centroid_dispersion =
                calc_centroid_dispersion(self.state.embeddings.view(), k, labels.view());
            let w_k = calc_within_dispersion(centroid_dispersion.view());

            // Mean reference log-dispersion, then gap(k) = mean - log2(W_k).
            let w_kb_log_sum: Feature = w_kb_log.sum();
            let l = b.recip() * w_kb_log_sum;
            let gap_k = l - w_k.log2();
            // Standard deviation of the reference log-dispersions, corrected
            // for simulation error: s_k = sd * sqrt(1 + 1/B).
            let standard_deviation = (b.recip() * (w_kb_log - l).pow2().sum()).sqrt();
            let s_k = standard_deviation * (1.0 + b.recip()).sqrt();

            info!("k: {k}, gap_k: {gap_k}, s_k: {s_k}");
            if let Some(gap_k_minus_one) = gap_k_minus_one
                && gap_k_minus_one >= gap_k - s_k
            {
                info!("Optimal k found: {}", k - 1);
                optimal_k = Some(k - 1);
                break;
            }

            gap_k_minus_one = Some(gap_k);
        }

        optimal_k.ok_or(ClusteringError::OptimalKNotFound(self.state.k_max))
    }

    // TODO: implement Davies-Bouldin index based selection.
    fn get_optimal_k_davies_bouldin(&self) -> Result<usize, ClusteringError> {
        todo!();
    }
}
357
/// Upper bound on rows used when generating reference datasets; larger
/// libraries are randomly subsampled to this size to keep the gap-statistic
/// fits affordable.
const MAX_REF_DATASET_SAMPLES: usize = 2000;
359
360fn generate_reference_datasets(samples: ArrayView2<'_, Feature>, b: usize) -> Vec<FitDataset> {
384 if samples.nrows() < MAX_REF_DATASET_SAMPLES {
385 return (0..b)
387 .into_par_iter()
388 .map(|_| Dataset::from(generate_ref_single(samples.view())))
389 .collect();
390 }
391
392 let mut rng = rand::thread_rng();
394 let indices = rand::seq::index::sample(&mut rng, samples.nrows(), MAX_REF_DATASET_SAMPLES);
395 let samples = samples.select(Axis(0), &indices.into_vec());
396
397 (0..b)
398 .into_par_iter()
399 .map(|_| Dataset::from(generate_ref_single(samples.view())))
400 .collect()
401}
402fn generate_ref_single(samples: ArrayView2<'_, Feature>) -> Array2<Feature> {
403 let feature_distributions = samples
404 .axis_iter(Axis(1))
405 .map(|feature| {
406 let min = *feature.min_skipnan();
407 let max = *feature.max_skipnan();
408 if min >= max {
409 return Array::from_elem(feature.dim(), min);
411 }
412 Array::random(feature.dim(), Uniform::new(min, max))
413 })
414 .collect::<Vec<_>>();
415 let feature_dists_views = feature_distributions
416 .iter()
417 .map(ndarray::ArrayBase::view)
418 .collect::<Vec<_>>();
419 ndarray::stack(Axis(0), &feature_dists_views)
420 .unwrap()
421 .t()
422 .to_owned()
423}
424
425fn calc_within_dispersion(pairwise_distances: ArrayView1<'_, Feature>) -> Feature {
433 pairwise_distances.sum()
434}
435
436#[allow(unused)]
444fn calc_pairwise_distances(
445 samples: ArrayView2<'_, Feature>,
446 k: usize,
447 labels: ArrayView1<'_, usize>,
448) -> Array1<Feature> {
449 debug_assert_eq!(
450 samples.nrows(),
451 labels.len(),
452 "Samples and labels must have the same length"
453 );
454 debug_assert_eq!(
455 k,
456 labels.iter().max().unwrap() + 1,
457 "Labels must be in the range 0..k"
458 );
459
460 let mut distances = Array1::zeros(k);
462 let mut clusters = vec![Vec::new(); k];
463 for (sample, &label) in samples.outer_iter().zip(labels.iter()) {
465 clusters[label].push(sample);
466 }
467 #[allow(clippy::cast_precision_loss)]
469 for (k, cluster) in clusters.iter().enumerate() {
470 let mut pairwise_dists = 0.;
471 for i in 0..cluster.len() - 1 {
472 let a = cluster[i];
473 let rest = &cluster[i + 1..];
474 for &b in rest {
475 pairwise_dists += L2Dist.distance(a, b);
476 }
477 }
478 distances[k] += pairwise_dists / cluster.len() as Feature; }
480 distances
481}
482
483fn calc_centroid_dispersion(
501 samples: ArrayView2<'_, Feature>,
502 k: usize,
503 labels: ArrayView1<'_, usize>,
504) -> Array1<Feature> {
505 debug_assert_eq!(
506 samples.nrows(),
507 labels.len(),
508 "Samples and labels must have the same length"
509 );
510 debug_assert_eq!(
511 k,
512 labels.iter().max().unwrap() + 1,
513 "Labels must be in the range 0..k"
514 );
515
516 let mut dispersions = Array1::zeros(k);
517 let mut centroids = Array2::zeros((k, samples.ncols()));
518 let mut counts = vec![0usize; k];
519
520 for (sample, &label) in samples.outer_iter().zip(labels.iter()) {
522 centroids.row_mut(label).scaled_add(1.0, &sample);
523 counts[label] += 1;
524 }
525
526 for (i, &count) in counts.iter().enumerate() {
528 if count > 0 {
529 #[allow(clippy::cast_precision_loss)]
530 centroids.row_mut(i).mapv_inplace(|v| v / count as Feature);
531 }
532 }
533
534 for (sample, &label) in samples.outer_iter().zip(labels.iter()) {
536 let dist = L2Dist.distance(sample, centroids.row(label));
537 dispersions[label] += dist; }
539
540 dispersions
541}
542
543impl ClusteringHelper<Initialized> {
545 #[inline]
551 pub fn cluster(self) -> ClusteringResult<ClusteringHelper<Finished>> {
552 let Initialized {
553 clustering_method,
554 embeddings,
555 k,
556 } = self.state;
557
558 let embedding_dataset = Dataset::from(embeddings);
559 let labels = clustering_method.fit(k, &embedding_dataset)?;
560
561 Ok(ClusteringHelper {
562 state: Finished { labels, k },
563 })
564 }
565}
566
567impl ClusteringHelper<Finished> {
569 #[must_use]
571 #[inline]
572 pub fn extract_analysis_clusters<T: Clone>(&self, samples: Vec<T>) -> Vec<Vec<T>> {
573 let mut clusters = vec![Vec::new(); self.state.k];
574
575 for (sample, &label) in samples.into_iter().zip(self.state.labels.iter()) {
576 clusters[label].push(sample);
577 }
578
579 clusters
580 }
581}
582
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, s};
    use ndarray_rand::rand_distr::StandardNormal;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    // Reference data must stay inside the per-column [min, max] bounds of the
    // input while (almost surely) differing from it.
    #[test]
    fn test_generate_reference_data_set() {
        let data = arr2(&[[10.0, -10.0], [20.0, -20.0], [30.0, -30.0]]);

        let ref_data = generate_ref_single(data.view());

        assert!(
            ref_data
                .slice(s![.., 0])
                .iter()
                .all(|v| *v >= 10.0 && *v <= 30.0)
        );

        assert!(
            ref_data
                .slice(s![.., 1])
                .iter()
                .all(|v| *v <= -10.0 && *v >= -30.0)
        );

        assert_eq!(ref_data.shape(), data.shape());

        assert_ne!(ref_data, data);
    }

    #[test]
    fn test_pairwise_distances() {
        // Two clusters of identical points: zero pairwise dispersion each.
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 0, 1, 1]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f32::EPSILON > (pairwise_distances[0] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[0]
        );
        assert!(
            f32::EPSILON > (pairwise_distances[1] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[1]
        );

        // Each cluster now holds one pair at distance 1.0; divided by the
        // cluster size (2) that gives 2.0/4.0 = 0.5 per cluster.
        let samples = arr2(&[[1.0, 2.0], [1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f32::EPSILON > (pairwise_distances[0] - 2.0 / 4.0).abs(),
            "{} != 2.0",
            pairwise_distances[0]
        );
        assert!(
            f32::EPSILON > (pairwise_distances[1] - 2.0 / 4.0).abs(),
            "{} != 2.0",
            pairwise_distances[1]
        );
    }

    // W_k is simply the sum of the per-cluster dispersions.
    #[test]
    fn test_calc_within_dispersion() {
        let pairwise_distances = arr1(&[1.0 / 4.0, 2.0 / 4.0]);
        let result = calc_within_dispersion(pairwise_distances.view());

        assert!(f32::EPSILON > (result - 0.75).abs(), "{result} != 0.75");
    }

    #[rstest]
    #[case::project_none(ProjectionMethod::None, NUMBER_FEATURES)]
    #[case::project_tsne(ProjectionMethod::TSne, EMBEDDING_SIZE)]
    #[case::project_pca(ProjectionMethod::Pca, EMBEDDING_SIZE)]
    fn test_project(
        #[case] projection_method: ProjectionMethod,
        #[case] expected_embedding_size: usize,
    ) {
        // Pre-normalize so the `None` case also satisfies the [-1, 1] checks.
        let mut samples = Array2::random((100, NUMBER_FEATURES), StandardNormal);
        normalize_embeddings_inplace(&mut samples);

        let result = projection_method.project(samples).unwrap();

        assert_eq!(result.shape(), &[100, expected_embedding_size]);

        // Every output column must span exactly [-1, 1] after normalization.
        for i in 0..expected_embedding_size {
            let min = result.column(i).min().copied().unwrap_or_default();
            let max = result.column(i).max().copied().unwrap_or_default();
            assert!(
                f32::EPSILON > (min + 1.0).abs(),
                "Min value of column {i} is not -1.0: {min}",
            );
            assert!(
                f32::EPSILON > (max - 1.0).abs(),
                "Max value of column {i} is not 1.0: {max}",
            );
        }
    }

    // Clusters of identical points sit exactly on their centroid.
    #[test]
    fn test_centroid_dispersion_identical_points() {
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 0, 1, 1]);

        let dispersion = calc_centroid_dispersion(samples.view(), 2, labels.view());

        assert!(
            f32::EPSILON > dispersion[0].abs(),
            "Cluster 0 dispersion should be ~0, got {}",
            dispersion[0]
        );
        assert!(
            f32::EPSILON > dispersion[1].abs(),
            "Cluster 1 dispersion should be ~0, got {}",
            dispersion[1]
        );
    }

    #[test]
    fn test_centroid_dispersion_known_values() {
        // Cluster 0 is a square around centroid (1, 1): four corners, each at
        // distance sqrt(2). Cluster 1 is a single point (zero dispersion).
        let samples = arr2(&[
            [0.0, 0.0],
            [2.0, 0.0],
            [0.0, 2.0],
            [2.0, 2.0],
            [10.0, 10.0],
        ]);
        let labels = arr1(&[0, 0, 0, 0, 1]);

        let dispersion = calc_centroid_dispersion(samples.view(), 2, labels.view());

        let expected = 4.0 * 2.0_f32.sqrt();
        assert!(
            (dispersion[0] - expected).abs() < 0.001,
            "Cluster 0 dispersion should be ~{}, got {}",
            expected,
            dispersion[0]
        );

        assert!(
            dispersion[1].abs() < f32::EPSILON,
            "Cluster 1 dispersion should be ~0, got {}",
            dispersion[1]
        );
    }

    // Output length matches k and all dispersions are non-negative.
    #[test]
    fn test_centroid_dispersion_matches_shape() {
        let samples = Array2::random((100, 10), StandardNormal);
        let labels = Array1::from_vec((0..100).map(|i| i % 5).collect());

        let dispersion = calc_centroid_dispersion(samples.view(), 5, labels.view());

        assert_eq!(dispersion.len(), 5, "Should have dispersion for 5 clusters");
        assert!(
            dispersion.iter().all(|&d| d >= 0.0),
            "All dispersions should be non-negative"
        );
    }

    // Three clusters with increasing spread: both dispersion measures must
    // agree on the relative ordering and show a clear variation.
    #[test]
    fn test_centroid_vs_pairwise_relative_ordering() {
        let mut samples = Array2::zeros((60, 2));
        let mut labels = Array1::zeros(60);

        for i in 0..20 {
            samples[[i, 0]] = (i as f32 * 0.1) - 1.0;
            samples[[i, 1]] = (i as f32 * 0.1) - 1.0;
            labels[i] = 0;
        }

        for i in 20..40 {
            samples[[i, 0]] = ((i - 20) as f32 * 0.3) + 4.0;
            samples[[i, 1]] = ((i - 20) as f32 * 0.3) + 4.0;
            labels[i] = 1;
        }

        for i in 40..60 {
            samples[[i, 0]] = ((i - 40) as f32 * 0.5) + 8.0;
            samples[[i, 1]] = ((i - 40) as f32 * 0.5) + 8.0;
            labels[i] = 2;
        }

        let pairwise = calc_pairwise_distances(samples.view(), 3, labels.view());
        let centroid = calc_centroid_dispersion(samples.view(), 3, labels.view());

        assert!(
            pairwise[0] < pairwise[1] && pairwise[1] < pairwise[2],
            "Pairwise ordering incorrect: {:?}",
            pairwise
        );
        assert!(
            centroid[0] < centroid[1] && centroid[1] < centroid[2],
            "Centroid ordering incorrect: {:?}",
            centroid
        );

        let pairwise_spread = pairwise.max_skipnan() / pairwise.min_skipnan();
        let centroid_spread = centroid.max_skipnan() / centroid.min_skipnan();

        assert!(
            pairwise_spread > 2.0 && centroid_spread > 2.0,
            "Both methods should show significant dispersion variation: pairwise={}, centroid={}",
            pairwise_spread,
            centroid_spread
        );
    }

    // Label 1 is unused (empty cluster): its dispersion must stay zero, and
    // single-member clusters coincide with their centroid.
    #[test]
    fn test_centroid_dispersion_with_empty_cluster() {
        let samples = arr2(&[[1.0, 1.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 2]);
        let dispersion = calc_centroid_dispersion(samples.view(), 3, labels.view());

        assert!(dispersion[0].abs() < f32::EPSILON);
        assert!(dispersion[2].abs() < f32::EPSILON);
        assert!(dispersion[1].abs() < f32::EPSILON);
    }
}
851
852