1use linfa::prelude::*;
13use linfa_clustering::{GaussianMixtureModel, KMeans};
14use linfa_nn::distance::{Distance, L2Dist};
15use linfa_reduction::Pca;
16use linfa_tsne::TSneParams;
17use log::{debug, info};
18use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis};
19use ndarray_rand::RandomExt;
20use rand::distributions::Uniform;
21use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
22use statrs::statistics::Statistics;
23
24use crate::{
25 Analysis, Feature, NUMBER_FEATURES,
26 errors::{ClusteringError, ProjectionError},
27};
28
29pub struct AnalysisArray(pub(crate) Array2<Feature>);
30
31impl From<Vec<Analysis>> for AnalysisArray {
32 #[inline]
33 fn from(data: Vec<Analysis>) -> Self {
34 let shape = (data.len(), NUMBER_FEATURES);
35 debug_assert_eq!(shape, (data.len(), data[0].inner().len()));
36
37 Self(
38 Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
39 .expect("Failed to convert to array, shape mismatch"),
40 )
41 }
42}
43
44impl From<Vec<[Feature; NUMBER_FEATURES]>> for AnalysisArray {
45 #[inline]
46 fn from(data: Vec<[Feature; NUMBER_FEATURES]>) -> Self {
47 let shape = (data.len(), NUMBER_FEATURES);
48 debug_assert_eq!(shape, (data.len(), data[0].len()));
49
50 Self(
51 Array2::from_shape_vec(shape, data.into_iter().flatten().collect())
52 .expect("Failed to convert to array, shape mismatch"),
53 )
54 }
55}
56
/// Clustering algorithm used to assign samples to clusters.
#[derive(Clone, Copy, Debug)]
#[allow(clippy::module_name_repetitions)]
pub enum ClusteringMethod {
    /// Cluster with k-means.
    KMeans,
    /// Cluster with a Gaussian mixture model.
    GaussianMixtureModel,
}
63
64impl ClusteringMethod {
65 #[must_use]
67 fn fit(self, k: usize, samples: &Array2<Feature>) -> Array1<usize> {
68 match self {
69 Self::KMeans => {
70 let model = KMeans::params(k)
71 .fit(&Dataset::from(samples.clone()))
73 .unwrap();
74 model.predict(samples)
75 }
76 Self::GaussianMixtureModel => {
77 let model = GaussianMixtureModel::params(k)
78 .init_method(linfa_clustering::GmmInitMethod::KMeans)
79 .n_runs(10)
80 .fit(&Dataset::from(samples.clone()))
81 .unwrap();
82 model.predict(samples)
83 }
84 }
85 }
86}
87
/// Strategy used to pick the number of clusters.
#[derive(Clone, Copy, Debug)]
pub enum KOptimal {
    /// Pick `k` with the gap statistic.
    GapStatistic {
        /// Number of reference data sets to generate.
        b: usize,
    },
    /// Pick `k` with the Davies-Bouldin index (not implemented in this module yet).
    DaviesBouldin,
}
96
/// Dimensionality reduction applied to the samples before clustering.
#[derive(Clone, Copy, Debug, Default)]
pub enum ProjectionMethod {
    /// Project with t-SNE.
    TSne,
    /// Project with principal component analysis.
    Pca,
    /// Use the samples as they are (the default).
    #[default]
    None,
}
108
109impl ProjectionMethod {
110 #[inline]
116 pub fn project(self, samples: AnalysisArray) -> Result<Array2<Feature>, ProjectionError> {
117 let result = match self {
118 Self::TSne => {
119 let nrecords = samples.0.nrows();
120 debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using t-SNE");
122 #[allow(clippy::cast_precision_loss)]
123 let mut embeddings = TSneParams::embedding_size(EMBEDDING_SIZE)
124 .perplexity(f64::max(samples.0.nrows() as f64 / 20., 5.))
125 .approx_threshold(0.5)
126 .transform(samples.0)?;
127 debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);
128
129 debug!("Normalizing embeddings");
131 normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
132 embeddings
133 }
134 Self::Pca => {
135 let nrecords = samples.0.nrows();
136 debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using PCA");
138 let data = Dataset::from(samples.0);
139 let pca: Pca<f64> = Pca::params(EMBEDDING_SIZE).whiten(true).fit(&data)?;
140 let mut embeddings = pca.predict(&data);
141 debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);
142
143 debug!("Normalizing embeddings");
145 normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
146 embeddings
147 }
148 Self::None => {
149 debug!("Using original data as embeddings");
150 samples.0
151 }
152 };
153 debug!("Embeddings shape: {:?}", result.shape());
154 Ok(result)
155 }
156}
157
158fn normalize_embeddings_inplace<const SIZE: usize>(embeddings: &mut Array2<f64>) {
161 for i in 0..SIZE {
162 let min = embeddings.column(i).min();
163 let max = embeddings.column(i).max();
164 let range = max - min;
165 embeddings
166 .column_mut(i)
167 .mapv_inplace(|v| ((v - min) / range).mul_add(2., -1.));
168 }
169}
170
171const EMBEDDING_SIZE: usize = {
174 let log2 = usize::ilog2(NUMBER_FEATURES) as usize;
175 if log2 < 2 { 2 } else { log2 }
176};
177
/// Clustering pipeline whose current stage is encoded in the type parameter `S`.
///
/// The stages defined in this module are `EntryPoint`, `NotInitialized`,
/// `Initialized` and `Finished`; each stage's `impl` block consumes the helper
/// and produces the next one.
#[allow(clippy::module_name_repetitions)]
pub struct ClusteringHelper<S> {
    /// Data belonging to the current stage.
    state: S,
}
185
/// Marker state of a [`ClusteringHelper`] before any data has been supplied.
pub struct EntryPoint;
187pub struct NotInitialized {
188 embeddings: Array2<Feature>,
190 pub k_max: usize,
191 pub optimizer: KOptimal,
192 pub clustering_method: ClusteringMethod,
193}
194pub struct Initialized {
195 embeddings: Array2<Feature>,
197 pub k: usize,
198 pub clustering_method: ClusteringMethod,
199}
200pub struct Finished {
201 labels: Array1<usize>,
204 pub k: usize,
205}
206
207impl ClusteringHelper<EntryPoint> {
209 #[allow(clippy::missing_inline_in_public_items)]
215 pub fn new(
216 samples: AnalysisArray,
217 k_max: usize,
218 optimizer: KOptimal,
219 clustering_method: ClusteringMethod,
220 projection_method: ProjectionMethod,
221 ) -> Result<ClusteringHelper<NotInitialized>, ClusteringError> {
222 if samples.0.nrows() <= 15 {
223 return Err(ClusteringError::SmallLibrary);
224 }
225
226 let embeddings = projection_method.project(samples)?;
228
229 Ok(ClusteringHelper {
230 state: NotInitialized {
231 embeddings,
232 k_max,
233 optimizer,
234 clustering_method,
235 },
236 })
237 }
238}
239
240impl ClusteringHelper<NotInitialized> {
242 #[inline]
248 pub fn initialize(self) -> Result<ClusteringHelper<Initialized>, ClusteringError> {
249 let k = self.get_optimal_k()?;
250 Ok(ClusteringHelper {
251 state: Initialized {
252 embeddings: self.state.embeddings,
253 k,
254 clustering_method: self.state.clustering_method,
255 },
256 })
257 }
258
259 fn get_optimal_k(&self) -> Result<usize, ClusteringError> {
260 match self.state.optimizer {
261 KOptimal::GapStatistic { b } => self.get_optimal_k_gap_statistic(b),
262 KOptimal::DaviesBouldin => self.get_optimal_k_davies_bouldin(),
263 }
264 }
265
266 fn get_optimal_k_gap_statistic(&self, b: usize) -> Result<usize, ClusteringError> {
283 let reference_data_sets = generate_reference_data_set(self.state.embeddings.view(), b);
285
286 let results = (1..=self.state.k_max)
287 .map(|k| {
289 debug!("Fitting k-means to embeddings with k={k}");
290 let labels = self.state.clustering_method.fit(k, &self.state.embeddings);
291 (k, labels)
292 })
293 .map(|(k, labels)| {
295 let pairwise_distances =
297 calc_pairwise_distances(self.state.embeddings.view(), k, labels.view());
298 let w_k = calc_within_dispersion(labels.view(), k, pairwise_distances.view());
299
300 debug!(
302 "Calculating within intra-cluster variation for reference data sets with k={k}"
303 );
304 let w_kb = reference_data_sets.par_iter().map(|ref_data| {
305 let ref_labels = self.state.clustering_method.fit(k, ref_data);
307 let ref_pairwise_distances =
309 calc_pairwise_distances(ref_data.view(), k, ref_labels.view());
310 calc_within_dispersion(ref_labels.view(), k, ref_pairwise_distances.view())
311 });
312
313 let w_kb_log_sum = w_kb.clone().map(f64::log2).sum::<f64>();
315 #[allow(clippy::cast_precision_loss)]
317 let l = (1.0 / b as f64) * w_kb_log_sum;
318 #[allow(clippy::cast_precision_loss)]
320 let gap_k = l - w_k.log2();
321 #[allow(clippy::cast_precision_loss)]
323 let standard_deviation = ((1.0 / b as f64)
324 * w_kb.map(|w_kb| (w_kb.log2() - l).powi(2)).sum::<f64>())
325 .sqrt();
326 #[allow(clippy::cast_precision_loss)]
329 let s_k = standard_deviation * (1.0 + 1.0 / b as f64).sqrt();
330
331 (k, gap_k, s_k)
332 });
333
334 let (mut optimal_k, mut gap_k_minus_one) = (None, None);
340
341 for (k, gap_k, s_k) in results {
342 info!("k: {k}, gap_k: {gap_k}, s_k: {s_k}");
343
344 if let Some(gap_k_minus_one) = gap_k_minus_one {
345 if gap_k_minus_one >= gap_k - s_k {
346 info!("Optimal k found: {}", k - 1);
347 optimal_k = Some(k - 1);
348 break;
349 }
350 }
351 gap_k_minus_one = Some(gap_k);
352 }
353
354 optimal_k.ok_or(ClusteringError::OptimalKNotFound(self.state.k_max))
355 }
356
357 fn get_optimal_k_davies_bouldin(&self) -> Result<usize, ClusteringError> {
358 todo!();
359 }
360}
361
362#[must_use]
368#[inline]
369pub fn convert_to_array(data: Vec<Analysis>) -> AnalysisArray {
370 let shape = (data.len(), NUMBER_FEATURES);
372 debug_assert_eq!(shape, (data.len(), data[0].inner().len()));
373
374 AnalysisArray(
375 Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
376 .expect("Failed to convert to array, shape mismatch"),
377 )
378}
379
380fn generate_reference_data_set(samples: ArrayView2<Feature>, b: usize) -> Vec<Array2<f64>> {
404 let mut reference_data_sets = Vec::with_capacity(b);
405 for _ in 0..b {
406 reference_data_sets.push(generate_ref_single(samples));
407 }
408
409 reference_data_sets
410}
411fn generate_ref_single(samples: ArrayView2<Feature>) -> Array2<f64> {
412 let feature_distributions = samples
413 .axis_iter(Axis(1))
414 .map(|feature| Array::random(feature.dim(), Uniform::new(feature.min(), feature.max())))
415 .collect::<Vec<_>>();
416 let feature_dists_views = feature_distributions
417 .iter()
418 .map(ndarray::ArrayBase::view)
419 .collect::<Vec<_>>();
420 ndarray::stack(Axis(0), &feature_dists_views)
421 .unwrap()
422 .t()
423 .to_owned()
424}
425
426fn calc_within_dispersion(
430 labels: ArrayView1<usize>,
431 k: usize,
432 pairwise_distances: ArrayView1<Feature>,
433) -> Feature {
434 debug_assert_eq!(k, labels.iter().max().unwrap() + 1);
435
436 let counts = labels.iter().fold(vec![0u32; k], |mut counts, &label| {
438 counts[label] += 1;
439 counts
440 });
441 counts
443 .iter()
444 .zip(pairwise_distances.iter())
445 .map(|(&count, distance)| (1. / (2.0 * f64::from(count))) * distance)
446 .sum()
447}
448
449fn calc_pairwise_distances(
457 samples: ArrayView2<Feature>,
458 k: usize,
459 labels: ArrayView1<usize>,
460) -> Array1<Feature> {
461 debug_assert_eq!(
462 samples.nrows(),
463 labels.len(),
464 "Samples and labels must have the same length"
465 );
466 debug_assert_eq!(
467 k,
468 labels.iter().max().unwrap() + 1,
469 "Labels must be in the range 0..k"
470 );
471
472 let mut distances = Array1::zeros(k);
474 let mut clusters = vec![Vec::new(); k];
475 for (sample, label) in samples.outer_iter().zip(labels.iter()) {
477 clusters[*label].push(sample);
478 }
479 for (k, cluster) in clusters.iter().enumerate() {
481 let mut pairwise_dists = 0.;
482 for i in 0..cluster.len() - 1 {
483 let a = cluster[i];
484 let rest = &cluster[i + 1..];
485 for &b in rest {
486 pairwise_dists += L2Dist.distance(a, b);
487 }
488 }
489 distances[k] += pairwise_dists + pairwise_dists;
490 }
491 distances
492}
493
494impl ClusteringHelper<Initialized> {
496 #[must_use]
502 #[inline]
503 pub fn cluster(self) -> ClusteringHelper<Finished> {
504 let labels = self
505 .state
506 .clustering_method
507 .fit(self.state.k, &self.state.embeddings);
508
509 ClusteringHelper {
510 state: Finished {
511 labels,
512 k: self.state.k,
513 },
514 }
515 }
516}
517
518impl ClusteringHelper<Finished> {
520 #[must_use]
522 #[inline]
523 pub fn extract_analysis_clusters<T: Clone>(&self, samples: Vec<T>) -> Vec<Vec<T>> {
524 let mut clusters = vec![Vec::new(); self.state.k];
525
526 for (sample, &label) in samples.into_iter().zip(self.state.labels.iter()) {
527 clusters[label].push(sample);
528 }
529
530 clusters
531 }
532}
533
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, s};
    use ndarray_rand::rand_distr::StandardNormal;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    /// Asserts that `actual` differs from `expected` by less than `f64::EPSILON`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            f64::EPSILON > (actual - expected).abs(),
            "{actual} != {expected:?}",
        );
    }

    #[test]
    fn test_generate_reference_data_set() {
        let data = arr2(&[[10.0, -10.0], [20.0, -20.0], [30.0, -30.0]]);

        let ref_data = generate_ref_single(data.view());

        // First column stays within the range of the first input column.
        assert!(
            ref_data
                .slice(s![.., 0])
                .iter()
                .all(|v| (10.0..=30.0).contains(v))
        );

        // Second column stays within the range of the second input column.
        assert!(
            ref_data
                .slice(s![.., 1])
                .iter()
                .all(|v| (-30.0..=-10.0).contains(v))
        );

        // Same shape as the input...
        assert_eq!(ref_data.shape(), data.shape());

        // ...but not the same values.
        assert_ne!(ref_data, data);
    }

    #[test]
    fn test_pairwise_distances() {
        let labels = arr1(&[0, 0, 1, 1]);

        // Identical points inside each cluster: no distance at all.
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());
        assert_close(pairwise_distances[0], 0.0);
        assert_close(pairwise_distances[1], 0.0);

        // Points one unit apart inside each cluster, counted in both directions.
        let samples = arr2(&[[1.0, 2.0], [1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]);
        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());
        assert_close(pairwise_distances[0], 2.0);
        assert_close(pairwise_distances[1], 2.0);
    }

    #[test]
    fn test_convert_to_vec() {
        let data = vec![
            Analysis::new([1.0; NUMBER_FEATURES]),
            Analysis::new([2.0; NUMBER_FEATURES]),
            Analysis::new([3.0; NUMBER_FEATURES]),
        ];

        let array = convert_to_array(data);

        assert_eq!(array.0.shape(), &[3, NUMBER_FEATURES]);
        assert_close(array.0[[0, 0]], 1.0);
        assert_close(array.0[[1, 0]], 2.0);
        assert_close(array.0[[2, 0]], 3.0);

        // Every row holds the features of exactly one analysis.
        let mut rows = array.0.axis_iter(Axis(0));
        for expected in [1.0, 2.0, 3.0] {
            assert_eq!(
                rows.next().unwrap().to_vec(),
                vec![expected; NUMBER_FEATURES]
            );
        }

        // Every column holds one feature across all analyses.
        for column in array.0.axis_iter(Axis(1)) {
            assert_eq!(column.to_vec(), vec![1.0, 2.0, 3.0]);
        }
    }

    #[test]
    fn test_calc_within_dispersion() {
        let labels = arr1(&[0, 1, 0, 1]);
        let pairwise_distances = arr1(&[1.0, 2.0]);

        let result = calc_within_dispersion(labels.view(), 2, pairwise_distances.view());

        // 1 / (2 * 2) * 1.0 + 1 / (2 * 2) * 2.0
        assert_close(result, 0.75);
    }

    #[rstest]
    #[case::project_none(ProjectionMethod::None, NUMBER_FEATURES)]
    #[case::project_tsne(ProjectionMethod::TSne, EMBEDDING_SIZE)]
    #[case::project_pca(ProjectionMethod::Pca, EMBEDDING_SIZE)]
    fn test_project(
        #[case] projection_method: ProjectionMethod,
        #[case] expected_embedding_size: usize,
    ) {
        // Normalise the input up front so that the range check below also holds
        // for `ProjectionMethod::None`, which returns its input unchanged.
        let mut samples = Array2::random((100, NUMBER_FEATURES), StandardNormal);
        normalize_embeddings_inplace::<NUMBER_FEATURES>(&mut samples);
        let samples = AnalysisArray(samples);

        let result = projection_method.project(samples).unwrap();

        assert_eq!(result.shape(), &[100, expected_embedding_size]);

        // Every column must span exactly [-1, 1].
        for i in 0..expected_embedding_size {
            let min = result.column(i).min();
            let max = result.column(i).max();
            assert!(
                f64::EPSILON > (min + 1.0).abs(),
                "Min value of column {i} is not -1.0: {min}",
            );
            assert!(
                f64::EPSILON > (max - 1.0).abs(),
                "Max value of column {i} is not 1.0: {max}",
            );
        }
    }
}
689
690