1use linfa::prelude::*;
13use linfa_clustering::{GaussianMixtureModel, KMeans};
14use linfa_nn::distance::{Distance, L2Dist};
15use linfa_reduction::Pca;
16use linfa_tsne::TSneParams;
17use log::{debug, info};
18use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis};
19use ndarray_rand::RandomExt;
20use rand::distributions::Uniform;
21use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
22use statrs::statistics::Statistics;
23
24use crate::{
25 Analysis, Feature, NUMBER_FEATURES,
26 errors::{ClusteringError, ProjectionError},
27};
28
29pub struct AnalysisArray(pub(crate) Array2<Feature>);
30
31impl From<Vec<Analysis>> for AnalysisArray {
32 #[inline]
33 fn from(data: Vec<Analysis>) -> Self {
34 let shape = (data.len(), NUMBER_FEATURES);
35 debug_assert_eq!(shape, (data.len(), data[0].inner().len()));
36
37 Self(
38 Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
39 .expect("Failed to convert to array, shape mismatch"),
40 )
41 }
42}
43
44impl From<Vec<[Feature; NUMBER_FEATURES]>> for AnalysisArray {
45 #[inline]
46 fn from(data: Vec<[Feature; NUMBER_FEATURES]>) -> Self {
47 let shape = (data.len(), NUMBER_FEATURES);
48 debug_assert_eq!(shape, (data.len(), data[0].len()));
49
50 Self(
51 Array2::from_shape_vec(shape, data.into_iter().flatten().collect())
52 .expect("Failed to convert to array, shape mismatch"),
53 )
54 }
55}
56
57#[derive(Clone, Copy, Debug)]
58#[allow(clippy::module_name_repetitions)]
59pub enum ClusteringMethod {
60 KMeans,
61 GaussianMixtureModel,
62}
63
64impl ClusteringMethod {
65 #[must_use]
67 fn fit(self, k: usize, samples: &Array2<Feature>) -> Array1<usize> {
68 match self {
69 Self::KMeans => {
70 let model = KMeans::params(k)
71 .fit(&Dataset::from(samples.clone()))
73 .unwrap();
74 model.predict(samples)
75 }
76 Self::GaussianMixtureModel => {
77 let model = GaussianMixtureModel::params(k)
78 .init_method(linfa_clustering::GmmInitMethod::KMeans)
79 .n_runs(10)
80 .fit(&Dataset::from(samples.clone()))
81 .unwrap();
82 model.predict(samples)
83 }
84 }
85 }
86}
87
88#[derive(Clone, Copy, Debug)]
89pub enum KOptimal {
90 GapStatistic {
91 b: u32,
93 },
94 DaviesBouldin,
95}
96
97#[derive(Clone, Copy, Debug, Default)]
98pub enum ProjectionMethod {
100 TSne,
102 Pca,
104 #[default]
105 None,
107}
108
109impl ProjectionMethod {
110 #[inline]
116 pub fn project(self, samples: AnalysisArray) -> Result<Array2<Feature>, ProjectionError> {
117 let result = match self {
118 Self::TSne => {
119 let nrecords = samples.0.nrows();
120 debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using t-SNE");
122 #[allow(clippy::cast_precision_loss)]
123 let mut embeddings = TSneParams::embedding_size(EMBEDDING_SIZE)
124 .perplexity(f64::max(samples.0.nrows() as f64 / 20., 5.))
125 .approx_threshold(0.5)
126 .transform(samples.0)?;
127 debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);
128
129 debug!("Normalizing embeddings");
131 normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
132 embeddings
133 }
134 Self::Pca => {
135 let nrecords = samples.0.nrows();
136 debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using PCA");
138 let data = Dataset::from(samples.0);
139 let pca: Pca<f64> = Pca::params(EMBEDDING_SIZE).whiten(true).fit(&data)?;
140 let mut embeddings = pca.predict(&data);
141 debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);
142
143 debug!("Normalizing embeddings");
145 normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
146 embeddings
147 }
148 Self::None => {
149 debug!("Using original data as embeddings");
150 samples.0
151 }
152 };
153 debug!("Embeddings shape: {:?}", result.shape());
154 Ok(result)
155 }
156}
157
158fn normalize_embeddings_inplace<const SIZE: usize>(embeddings: &mut Array2<f64>) {
161 for i in 0..SIZE {
162 let min = embeddings.column(i).min();
163 let max = embeddings.column(i).max();
164 let range = max - min;
165 embeddings
166 .column_mut(i)
167 .mapv_inplace(|v| ((v - min) / range).mul_add(2., -1.));
168 }
169}
170
171const EMBEDDING_SIZE: usize = {
174 let log2 = usize::ilog2(NUMBER_FEATURES) as usize;
175 if log2 < 2 { 2 } else { log2 }
176};
177
178#[allow(clippy::module_name_repetitions)]
179pub struct ClusteringHelper<S>
180where
181 S: Sized,
182{
183 state: S,
184}
185
186pub struct EntryPoint;
187pub struct NotInitialized {
188 embeddings: Array2<Feature>,
190 pub k_max: usize,
191 pub optimizer: KOptimal,
192 pub clustering_method: ClusteringMethod,
193}
194pub struct Initialized {
195 embeddings: Array2<Feature>,
197 pub k: usize,
198 pub clustering_method: ClusteringMethod,
199}
200pub struct Finished {
201 labels: Array1<usize>,
204 pub k: usize,
205}
206
207impl ClusteringHelper<EntryPoint> {
209 #[allow(clippy::missing_inline_in_public_items)]
215 pub fn new(
216 samples: AnalysisArray,
217 k_max: usize,
218 optimizer: KOptimal,
219 clustering_method: ClusteringMethod,
220 projection_method: ProjectionMethod,
221 ) -> Result<ClusteringHelper<NotInitialized>, ClusteringError> {
222 if samples.0.nrows() <= 15 {
223 return Err(ClusteringError::SmallLibrary);
224 }
225
226 let embeddings = projection_method.project(samples)?;
228
229 Ok(ClusteringHelper {
230 state: NotInitialized {
231 embeddings,
232 k_max,
233 optimizer,
234 clustering_method,
235 },
236 })
237 }
238}
239
240impl ClusteringHelper<NotInitialized> {
242 #[inline]
248 pub fn initialize(self) -> Result<ClusteringHelper<Initialized>, ClusteringError> {
249 let k = self.get_optimal_k()?;
250 Ok(ClusteringHelper {
251 state: Initialized {
252 embeddings: self.state.embeddings,
253 k,
254 clustering_method: self.state.clustering_method,
255 },
256 })
257 }
258
259 fn get_optimal_k(&self) -> Result<usize, ClusteringError> {
260 match self.state.optimizer {
261 KOptimal::GapStatistic { b } => self.get_optimal_k_gap_statistic(b),
262 KOptimal::DaviesBouldin => self.get_optimal_k_davies_bouldin(),
263 }
264 }
265
266 fn get_optimal_k_gap_statistic(&self, b: u32) -> Result<usize, ClusteringError> {
283 let reference_data_sets =
285 generate_reference_data_set(self.state.embeddings.view(), b as usize);
286
287 let b = f64::from(b);
288
289 let results = (1..=self.state.k_max)
290 .map(|k| {
292 debug!("Fitting k-means to embeddings with k={k}");
293 let labels = self.state.clustering_method.fit(k, &self.state.embeddings);
294 (k, labels)
295 })
296 .map(|(k, labels)| {
298 debug!(
300 "Calculating within intra-cluster variation for reference data sets with k={k}"
301 );
302 let w_kb_log: Vec<_> = reference_data_sets
303 .par_iter()
304 .map(|ref_data| {
305 let ref_labels = self.state.clustering_method.fit(k, ref_data);
307 let ref_pairwise_distances =
309 calc_pairwise_distances(ref_data.view(), k, ref_labels.view());
310 calc_within_dispersion(ref_labels.view(), k, ref_pairwise_distances.view())
311 .log2()
312 })
313 .collect();
314
315 let pairwise_distances =
317 calc_pairwise_distances(self.state.embeddings.view(), k, labels.view());
318 let w_k = calc_within_dispersion(labels.view(), k, pairwise_distances.view());
319
320 let w_kb_log_sum: f64 = w_kb_log.iter().sum();
322 let l = b.recip() * w_kb_log_sum;
324 let gap_k = l - w_k.log2();
326 let standard_deviation = (b.recip()
328 * w_kb_log
329 .iter()
330 .map(|w_kb_log| (w_kb_log - l).powi(2))
331 .sum::<f64>())
332 .sqrt();
333 let s_k = standard_deviation * (1.0 + b.recip()).sqrt();
336
337 (k, gap_k, s_k)
338 });
339
340 let (mut optimal_k, mut gap_k_minus_one) = (None, None);
346 for (k, gap_k, s_k) in results {
347 info!("k: {k}, gap_k: {gap_k}, s_k: {s_k}");
348
349 if let Some(gap_k_minus_one) = gap_k_minus_one
350 && gap_k_minus_one >= gap_k - s_k
351 {
352 info!("Optimal k found: {}", k - 1);
353 optimal_k = Some(k - 1);
354 break;
355 }
356
357 gap_k_minus_one = Some(gap_k);
358 }
359
360 optimal_k.ok_or(ClusteringError::OptimalKNotFound(self.state.k_max))
361 }
362
363 fn get_optimal_k_davies_bouldin(&self) -> Result<usize, ClusteringError> {
364 todo!();
365 }
366}
367
368#[must_use]
374#[inline]
375pub fn convert_to_array(data: Vec<Analysis>) -> AnalysisArray {
376 let shape = (data.len(), NUMBER_FEATURES);
378 debug_assert_eq!(NUMBER_FEATURES, data[0].inner().len());
379
380 AnalysisArray(
381 Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
382 .expect("Failed to convert to array, shape mismatch"),
383 )
384}
385
386fn generate_reference_data_set(samples: ArrayView2<Feature>, b: usize) -> Vec<Array2<f64>> {
410 let mut reference_data_sets = Vec::with_capacity(b);
411 for _ in 0..b {
412 reference_data_sets.push(generate_ref_single(samples));
413 }
414
415 reference_data_sets
416}
417fn generate_ref_single(samples: ArrayView2<Feature>) -> Array2<f64> {
418 let feature_distributions = samples
419 .axis_iter(Axis(1))
420 .map(|feature| Array::random(feature.dim(), Uniform::new(feature.min(), feature.max())))
421 .collect::<Vec<_>>();
422 let feature_dists_views = feature_distributions
423 .iter()
424 .map(ndarray::ArrayBase::view)
425 .collect::<Vec<_>>();
426 ndarray::stack(Axis(0), &feature_dists_views)
427 .unwrap()
428 .t()
429 .to_owned()
430}
431
432fn calc_within_dispersion(
436 labels: ArrayView1<usize>,
437 k: usize,
438 pairwise_distances: ArrayView1<Feature>,
439) -> Feature {
440 debug_assert_eq!(k, labels.iter().max().unwrap() + 1);
441
442 let counts = labels.iter().fold(vec![0u32; k], |mut counts, &label| {
444 counts[label] += 1;
445 counts
446 });
447 counts
449 .iter()
450 .zip(pairwise_distances.iter())
451 .map(|(&count, distance)| (2.0 * f64::from(count)).recip() * distance)
452 .sum()
453}
454
455fn calc_pairwise_distances(
463 samples: ArrayView2<Feature>,
464 k: usize,
465 labels: ArrayView1<usize>,
466) -> Array1<Feature> {
467 debug_assert_eq!(
468 samples.nrows(),
469 labels.len(),
470 "Samples and labels must have the same length"
471 );
472 debug_assert_eq!(
473 k,
474 labels.iter().max().unwrap() + 1,
475 "Labels must be in the range 0..k"
476 );
477
478 let mut distances = Array1::zeros(k);
480 let mut clusters = vec![Vec::new(); k];
481 for (sample, label) in samples.outer_iter().zip(labels.iter()) {
483 clusters[*label].push(sample);
484 }
485 for (k, cluster) in clusters.iter().enumerate() {
487 let mut pairwise_dists = 0.;
488 for i in 0..cluster.len() - 1 {
489 let a = cluster[i];
490 let rest = &cluster[i + 1..];
491 for &b in rest {
492 pairwise_dists += L2Dist.distance(a, b);
493 }
494 }
495 distances[k] += pairwise_dists + pairwise_dists;
496 }
497 distances
498}
499
500impl ClusteringHelper<Initialized> {
502 #[must_use]
508 #[inline]
509 pub fn cluster(self) -> ClusteringHelper<Finished> {
510 let labels = self
511 .state
512 .clustering_method
513 .fit(self.state.k, &self.state.embeddings);
514
515 ClusteringHelper {
516 state: Finished {
517 labels,
518 k: self.state.k,
519 },
520 }
521 }
522}
523
524impl ClusteringHelper<Finished> {
526 #[must_use]
528 #[inline]
529 pub fn extract_analysis_clusters<T: Clone>(&self, samples: Vec<T>) -> Vec<Vec<T>> {
530 let mut clusters = vec![Vec::new(); self.state.k];
531
532 for (sample, &label) in samples.into_iter().zip(self.state.labels.iter()) {
533 clusters[label].push(sample);
534 }
535
536 clusters
537 }
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, s};
    use ndarray_rand::rand_distr::StandardNormal;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    /// Reference data must keep the input's shape and stay inside each feature's range.
    #[test]
    fn test_generate_reference_data_set() {
        let data = arr2(&[[10.0, -10.0], [20.0, -20.0], [30.0, -30.0]]);

        let ref_data = generate_ref_single(data.view());

        assert!(
            // feature 0 was observed in [10, 30]
            ref_data
                .slice(s![.., 0])
                .iter()
                .all(|v| *v >= 10.0 && *v <= 30.0)
        );

        assert!(
            // feature 1 was observed in [-30, -10]
            ref_data
                .slice(s![.., 1])
                .iter()
                .all(|v| *v <= -10.0 && *v >= -30.0)
        );

        assert_eq!(ref_data.shape(), data.shape());
        // the reference set is freshly sampled, so it should not equal the input
        assert_ne!(ref_data, data);
    }

    /// Per-cluster pairwise distance sums, where each unordered pair is counted twice.
    #[test]
    fn test_pairwise_distances() {
        // identical points within each cluster: every distance is zero
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 0, 1, 1]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[1]
        );

        // each cluster now holds one pair at distance 1.0, counted twice => 2.0
        let samples = arr2(&[[1.0, 2.0], [1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[1]
        );
    }

    /// Rows of the converted array follow the order of the input analyses.
    #[test]
    fn test_convert_to_vec() {
        let data = vec![
            Analysis::new([1.0; NUMBER_FEATURES]),
            Analysis::new([2.0; NUMBER_FEATURES]),
            Analysis::new([3.0; NUMBER_FEATURES]),
        ];

        let array = convert_to_array(data);

        assert_eq!(array.0.shape(), &[3, NUMBER_FEATURES]);
        assert!(
            f64::EPSILON > (array.0[[0, 0]] - 1.0).abs(),
            "{} != 1.0",
            array.0[[0, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[1, 0]] - 2.0).abs(),
            "{} != 2.0",
            array.0[[1, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[2, 0]] - 3.0).abs(),
            "{} != 3.0",
            array.0[[2, 0]]
        );

        // iterate over rows: each row is one analysis
        let mut iter = array.0.axis_iter(Axis(0));
        assert_eq!(iter.next().unwrap().to_vec(), vec![1.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![2.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![3.0; NUMBER_FEATURES]);
        for column in array.0.axis_iter(Axis(1)) {
            // every column holds one value per analysis, in input order
            assert_eq!(column.to_vec(), vec![1.0, 2.0, 3.0]);
        }
    }

    #[test]
    fn test_calc_within_dispersion() {
        let labels = arr1(&[0, 1, 0, 1]);
        let pairwise_distances = arr1(&[1.0, 2.0]);
        let result = calc_within_dispersion(labels.view(), 2, pairwise_distances.view());

        // two samples per cluster: 1.0 / (2 * 2) + 2.0 / (2 * 2) = 0.75
        assert!(f64::EPSILON > (result - 0.75).abs(), "{result} != 0.75");
    }

    #[rstest]
    #[case::project_none(ProjectionMethod::None, NUMBER_FEATURES)]
    #[case::project_tsne(ProjectionMethod::TSne, EMBEDDING_SIZE)]
    #[case::project_pca(ProjectionMethod::Pca, EMBEDDING_SIZE)]
    fn test_project(
        #[case] projection_method: ProjectionMethod,
        #[case] expected_embedding_size: usize,
    ) {
        let mut samples = Array2::random((100, NUMBER_FEATURES), StandardNormal);
        // Pre-normalise the input so that `ProjectionMethod::None`, which returns it
        // unchanged, also satisfies the [-1, 1] checks below.
        normalize_embeddings_inplace::<NUMBER_FEATURES>(&mut samples);
        let samples = AnalysisArray(samples);

        let result = projection_method.project(samples).unwrap();

        assert_eq!(result.shape(), &[100, expected_embedding_size]);

        // every output column must span exactly [-1, 1]
        for i in 0..expected_embedding_size {
            let min = result.column(i).min();
            let max = result.column(i).max();
            assert!(
                f64::EPSILON > (min + 1.0).abs(),
                "Min value of column {i} is not -1.0: {min}",
            );
            assert!(
                f64::EPSILON > (max - 1.0).abs(),
                "Max value of column {i} is not 1.0: {max}",
            );
        }
    }
}
695
696