1use linfa::prelude::*;
13use linfa_clustering::{GaussianMixtureModel, KMeans};
14use linfa_nn::distance::{Distance, L2Dist};
15use linfa_reduction::Pca;
16use linfa_tsne::TSneParams;
17use log::{debug, info};
18use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis, Dim};
19use ndarray_rand::RandomExt;
20use rand::distributions::Uniform;
21use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
22use statrs::statistics::Statistics;
23
24use crate::{
25 Analysis, Feature, NUMBER_FEATURES,
26 errors::{ClusteringError, ProjectionError},
27};
28
/// Feature matrix holding one row per analysis and one column per feature.
///
/// Built through the `From` conversions in this module, which produce a
/// `(number of analyses, NUMBER_FEATURES)` shaped array.
pub struct AnalysisArray(pub(crate) Array2<Feature>);
30
31impl From<Vec<Analysis>> for AnalysisArray {
32 #[inline]
33 fn from(data: Vec<Analysis>) -> Self {
34 let shape = (data.len(), NUMBER_FEATURES);
35 debug_assert_eq!(shape, (data.len(), data[0].inner().len()));
36
37 Self(
38 Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
39 .expect("Failed to convert to array, shape mismatch"),
40 )
41 }
42}
43
44impl From<Vec<[Feature; NUMBER_FEATURES]>> for AnalysisArray {
45 #[inline]
46 fn from(data: Vec<[Feature; NUMBER_FEATURES]>) -> Self {
47 let shape = (data.len(), NUMBER_FEATURES);
48 debug_assert_eq!(shape, (data.len(), data[0].len()));
49
50 Self(
51 Array2::from_shape_vec(shape, data.into_iter().flatten().collect())
52 .expect("Failed to convert to array, shape mismatch"),
53 )
54 }
55}
56
/// Dataset of `Feature` records with unit (`()`) targets, as consumed by
/// `ClusteringMethod::fit`.
pub type FitDataset = Dataset<Feature, (), Dim<[usize; 1]>>;
58
/// Clustering algorithm used to assign samples to clusters.
#[derive(Clone, Copy, Debug)]
#[allow(clippy::module_name_repetitions)]
pub enum ClusteringMethod {
    /// Cluster with `linfa_clustering::KMeans`.
    KMeans,
    /// Cluster with `linfa_clustering::GaussianMixtureModel`
    /// (initialised via k-means, see `ClusteringMethod::fit`).
    GaussianMixtureModel,
}
65
66impl ClusteringMethod {
67 #[must_use]
69 fn fit(self, k: usize, data: &FitDataset) -> Array1<usize> {
70 match self {
71 Self::KMeans => {
72 let model = KMeans::params(k)
73 .fit(data)
75 .unwrap();
76 model.predict(data.records())
77 }
78 Self::GaussianMixtureModel => {
79 let model = GaussianMixtureModel::params(k)
80 .init_method(linfa_clustering::GmmInitMethod::KMeans)
81 .n_runs(10)
82 .fit(data)
83 .unwrap();
84 model.predict(data.records())
85 }
86 }
87 }
88}
89
/// Strategy used to pick the number of clusters `k`.
#[derive(Clone, Copy, Debug)]
pub enum KOptimal {
    /// Choose `k` with the gap statistic.
    GapStatistic {
        /// Number of reference datasets generated for the comparison.
        b: u32,
    },
    /// Choose `k` with the Davies-Bouldin index.
    ///
    /// NOTE(review): the corresponding optimizer in this file is still a
    /// `todo!()` stub, so selecting this variant panics.
    DaviesBouldin,
}
98
/// Dimensionality reduction applied to the samples before clustering.
#[derive(Clone, Copy, Debug, Default)]
pub enum ProjectionMethod {
    /// Project to `EMBEDDING_SIZE` dimensions with t-SNE.
    TSne,
    /// Project to `EMBEDDING_SIZE` dimensions with whitened PCA.
    Pca,
    /// Use the samples unchanged (the default).
    #[default]
    None,
}
110
111impl ProjectionMethod {
112 #[inline]
118 pub fn project(self, samples: AnalysisArray) -> Result<Array2<Feature>, ProjectionError> {
119 let result = match self {
120 Self::TSne => {
121 let nrecords = samples.0.nrows();
122 debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using t-SNE");
124 #[allow(clippy::cast_precision_loss)]
125 let mut embeddings = TSneParams::embedding_size(EMBEDDING_SIZE)
126 .perplexity(f64::max(samples.0.nrows() as f64 / 20., 5.))
127 .approx_threshold(0.5)
128 .transform(samples.0)?;
129 debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);
130
131 debug!("Normalizing embeddings");
133 normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
134 embeddings
135 }
136 Self::Pca => {
137 let nrecords = samples.0.nrows();
138 debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using PCA");
140 let data = Dataset::from(samples.0);
141 let pca: Pca<f64> = Pca::params(EMBEDDING_SIZE).whiten(true).fit(&data)?;
142 let mut embeddings = pca.predict(&data);
143 debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);
144
145 debug!("Normalizing embeddings");
147 normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
148 embeddings
149 }
150 Self::None => {
151 debug!("Using original data as embeddings");
152 samples.0
153 }
154 };
155 debug!("Embeddings shape: {:?}", result.shape());
156 Ok(result)
157 }
158}
159
160fn normalize_embeddings_inplace<const SIZE: usize>(embeddings: &mut Array2<f64>) {
163 for i in 0..SIZE {
164 let min = embeddings.column(i).min();
165 let max = embeddings.column(i).max();
166 let range = max - min;
167 embeddings
168 .column_mut(i)
169 .mapv_inplace(|v| ((v - min) / range).mul_add(2., -1.));
170 }
171}
172
173const EMBEDDING_SIZE: usize = {
176 let log2 = usize::ilog2(NUMBER_FEATURES) as usize;
177 if log2 < 2 { 2 } else { log2 }
178};
179
/// Type-state wrapper around the clustering pipeline.
///
/// The state type `S` (`EntryPoint`, `NotInitialized`, `Initialized` or
/// `Finished`) determines which methods are available; each step consumes the
/// helper and returns it in the next state.
#[allow(clippy::module_name_repetitions)]
pub struct ClusteringHelper<S>
where
    S: Sized,
{
    // Data carried by the current pipeline stage.
    state: S,
}
187
/// Marker state that carries no data; it hosts the `ClusteringHelper::new`
/// constructor.
pub struct EntryPoint;
/// State after projection: embeddings are available, `k` is not chosen yet.
pub struct NotInitialized {
    // Projected samples, one row per sample.
    embeddings: Array2<Feature>,
    /// Largest number of clusters tried when searching for the optimal `k`.
    pub k_max: usize,
    /// Strategy used to pick the optimal `k`.
    pub optimizer: KOptimal,
    /// Algorithm used to fit clusters.
    pub clustering_method: ClusteringMethod,
}
/// State after the number of clusters has been determined.
pub struct Initialized {
    // Projected samples, one row per sample.
    embeddings: Array2<Feature>,
    /// Number of clusters to fit.
    pub k: usize,
    /// Algorithm used to fit clusters.
    pub clustering_method: ClusteringMethod,
}
/// State after clustering: every sample has been assigned a cluster label.
pub struct Finished {
    // Cluster index predicted for each sample, in sample order.
    labels: Array1<usize>,
    /// Number of clusters that were fitted.
    pub k: usize,
}
208
209impl ClusteringHelper<EntryPoint> {
211 #[allow(clippy::missing_inline_in_public_items)]
217 pub fn new(
218 samples: AnalysisArray,
219 k_max: usize,
220 optimizer: KOptimal,
221 clustering_method: ClusteringMethod,
222 projection_method: ProjectionMethod,
223 ) -> Result<ClusteringHelper<NotInitialized>, ClusteringError> {
224 if samples.0.nrows() <= 15 {
225 return Err(ClusteringError::SmallLibrary);
226 }
227
228 let embeddings = projection_method.project(samples)?;
230
231 Ok(ClusteringHelper {
232 state: NotInitialized {
233 embeddings,
234 k_max,
235 optimizer,
236 clustering_method,
237 },
238 })
239 }
240}
241
242impl ClusteringHelper<NotInitialized> {
244 #[inline]
250 pub fn initialize(self) -> Result<ClusteringHelper<Initialized>, ClusteringError> {
251 let k = self.get_optimal_k()?;
252 Ok(ClusteringHelper {
253 state: Initialized {
254 embeddings: self.state.embeddings,
255 k,
256 clustering_method: self.state.clustering_method,
257 },
258 })
259 }
260
261 fn get_optimal_k(&self) -> Result<usize, ClusteringError> {
262 match self.state.optimizer {
263 KOptimal::GapStatistic { b } => self.get_optimal_k_gap_statistic(b),
264 KOptimal::DaviesBouldin => self.get_optimal_k_davies_bouldin(),
265 }
266 }
267
268 fn get_optimal_k_gap_statistic(&self, b: u32) -> Result<usize, ClusteringError> {
285 let embedding_dataset = Dataset::from(self.state.embeddings.clone());
286
287 let reference_datasets =
289 generate_reference_datasets(embedding_dataset.records().view(), b as usize);
290
291 let b = f64::from(b);
292
293 let results = (1..=self.state.k_max)
294 .map(|k| {
296 debug!("Fitting k-means to embeddings with k={k}");
297 let labels = self.state.clustering_method.fit(k, &embedding_dataset);
298 (k, labels)
299 })
300 .map(|(k, labels)| {
302 debug!(
304 "Calculating within intra-cluster variation for reference data sets with k={k}"
305 );
306 let w_kb_log: Vec<_> = reference_datasets
307 .par_iter()
308 .map(|ref_data| {
309 let ref_labels = self.state.clustering_method.fit(k, ref_data);
311 let ref_pairwise_distances = calc_pairwise_distances(
313 ref_data.records().view(),
314 k,
315 ref_labels.view(),
316 );
317 calc_within_dispersion(ref_labels.view(), k, ref_pairwise_distances.view())
318 .log2()
319 })
320 .collect();
321
322 let pairwise_distances =
324 calc_pairwise_distances(self.state.embeddings.view(), k, labels.view());
325 let w_k = calc_within_dispersion(labels.view(), k, pairwise_distances.view());
326
327 let w_kb_log_sum: f64 = w_kb_log.iter().sum();
329 let l = b.recip() * w_kb_log_sum;
331 let gap_k = l - w_k.log2();
333 let standard_deviation = (b.recip()
335 * w_kb_log
336 .iter()
337 .map(|w_kb_log| (w_kb_log - l).powi(2))
338 .sum::<f64>())
339 .sqrt();
340 let s_k = standard_deviation * (1.0 + b.recip()).sqrt();
343
344 (k, gap_k, s_k)
345 });
346
347 let (mut optimal_k, mut gap_k_minus_one) = (None, None);
353 for (k, gap_k, s_k) in results {
354 info!("k: {k}, gap_k: {gap_k}, s_k: {s_k}");
355
356 if let Some(gap_k_minus_one) = gap_k_minus_one
357 && gap_k_minus_one >= gap_k - s_k
358 {
359 info!("Optimal k found: {}", k - 1);
360 optimal_k = Some(k - 1);
361 break;
362 }
363
364 gap_k_minus_one = Some(gap_k);
365 }
366
367 optimal_k.ok_or(ClusteringError::OptimalKNotFound(self.state.k_max))
368 }
369
370 fn get_optimal_k_davies_bouldin(&self) -> Result<usize, ClusteringError> {
371 todo!();
372 }
373}
374
375#[must_use]
381#[inline]
382pub fn convert_to_array(data: Vec<Analysis>) -> AnalysisArray {
383 let shape = (data.len(), NUMBER_FEATURES);
385 debug_assert_eq!(NUMBER_FEATURES, data[0].inner().len());
386
387 AnalysisArray(
388 Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
389 .expect("Failed to convert to array, shape mismatch"),
390 )
391}
392
393fn generate_reference_datasets(samples: ArrayView2<Feature>, b: usize) -> Vec<FitDataset> {
417 (0..b)
418 .into_par_iter()
419 .map(|_| Dataset::from(generate_ref_single(samples.view())))
420 .collect()
421}
422fn generate_ref_single(samples: ArrayView2<Feature>) -> Array2<Feature> {
423 let feature_distributions = samples
424 .axis_iter(Axis(1))
425 .map(|feature| Array::random(feature.dim(), Uniform::new(feature.min(), feature.max())))
426 .collect::<Vec<_>>();
427 let feature_dists_views = feature_distributions
428 .iter()
429 .map(ndarray::ArrayBase::view)
430 .collect::<Vec<_>>();
431 ndarray::stack(Axis(0), &feature_dists_views)
432 .unwrap()
433 .t()
434 .to_owned()
435}
436
437fn calc_within_dispersion(
441 labels: ArrayView1<usize>,
442 k: usize,
443 pairwise_distances: ArrayView1<Feature>,
444) -> Feature {
445 debug_assert_eq!(k, labels.iter().max().unwrap() + 1);
446
447 let counts = labels.iter().fold(vec![0u32; k], |mut counts, &label| {
449 counts[label] += 1;
450 counts
451 });
452 counts
454 .iter()
455 .zip(pairwise_distances.iter())
456 .map(|(&count, distance)| (2.0 * f64::from(count)).recip() * distance)
457 .sum()
458}
459
460fn calc_pairwise_distances(
468 samples: ArrayView2<Feature>,
469 k: usize,
470 labels: ArrayView1<usize>,
471) -> Array1<Feature> {
472 debug_assert_eq!(
473 samples.nrows(),
474 labels.len(),
475 "Samples and labels must have the same length"
476 );
477 debug_assert_eq!(
478 k,
479 labels.iter().max().unwrap() + 1,
480 "Labels must be in the range 0..k"
481 );
482
483 let mut distances = Array1::zeros(k);
485 let mut clusters = vec![Vec::new(); k];
486 for (sample, label) in samples.outer_iter().zip(labels.iter()) {
488 clusters[*label].push(sample);
489 }
490 for (k, cluster) in clusters.iter().enumerate() {
492 let mut pairwise_dists = 0.;
493 for i in 0..cluster.len() - 1 {
494 let a = cluster[i];
495 let rest = &cluster[i + 1..];
496 for &b in rest {
497 pairwise_dists += L2Dist.distance(a, b);
498 }
499 }
500 distances[k] += pairwise_dists + pairwise_dists;
501 }
502 distances
503}
504
505impl ClusteringHelper<Initialized> {
507 #[must_use]
513 #[inline]
514 pub fn cluster(self) -> ClusteringHelper<Finished> {
515 let Initialized {
516 clustering_method,
517 embeddings,
518 k,
519 } = self.state;
520
521 let embedding_dataset = Dataset::from(embeddings);
522 let labels = clustering_method.fit(k, &embedding_dataset);
523
524 ClusteringHelper {
525 state: Finished { labels, k },
526 }
527 }
528}
529
530impl ClusteringHelper<Finished> {
532 #[must_use]
534 #[inline]
535 pub fn extract_analysis_clusters<T: Clone>(&self, samples: Vec<T>) -> Vec<Vec<T>> {
536 let mut clusters = vec![Vec::new(); self.state.k];
537
538 for (sample, &label) in samples.into_iter().zip(self.state.labels.iter()) {
539 clusters[label].push(sample);
540 }
541
542 clusters
543 }
544}
545
// Unit tests for the private helpers and the projection step of this module.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, s};
    use ndarray_rand::rand_distr::StandardNormal;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    // A reference dataset must stay within the per-column value range of the
    // input, keep its shape, and differ from the input itself.
    #[test]
    fn test_generate_reference_data_set() {
        let data = arr2(&[[10.0, -10.0], [20.0, -20.0], [30.0, -30.0]]);

        let ref_data = generate_ref_single(data.view());

        // First column: values drawn from [10, 30].
        assert!(
            ref_data
                .slice(s![.., 0])
                .iter()
                .all(|v| *v >= 10.0 && *v <= 30.0)
        );

        // Second column: values drawn from [-30, -10].
        assert!(
            ref_data
                .slice(s![.., 1])
                .iter()
                .all(|v| *v <= -10.0 && *v >= -30.0)
        );

        assert_eq!(ref_data.shape(), data.shape());

        assert_ne!(ref_data, data);
    }

    // Per-cluster distance sums: identical points give 0, and one pair at
    // distance 1 gives 2 because each pair is counted in both directions.
    #[test]
    fn test_pairwise_distances() {
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 0, 1, 1]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[1]
        );

        let samples = arr2(&[[1.0, 2.0], [1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[1]
        );
    }

    // `convert_to_array` must place each analysis in its own row.
    #[test]
    fn test_convert_to_vec() {
        let data = vec![
            Analysis::new([1.0; NUMBER_FEATURES]),
            Analysis::new([2.0; NUMBER_FEATURES]),
            Analysis::new([3.0; NUMBER_FEATURES]),
        ];

        let array = convert_to_array(data);

        assert_eq!(array.0.shape(), &[3, NUMBER_FEATURES]);
        assert!(
            f64::EPSILON > (array.0[[0, 0]] - 1.0).abs(),
            "{} != 1.0",
            array.0[[0, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[1, 0]] - 2.0).abs(),
            "{} != 2.0",
            array.0[[1, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[2, 0]] - 3.0).abs(),
            "{} != 3.0",
            array.0[[2, 0]]
        );

        // Rows (axis 0) are the individual analyses.
        let mut iter = array.0.axis_iter(Axis(0));
        assert_eq!(iter.next().unwrap().to_vec(), vec![1.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![2.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![3.0; NUMBER_FEATURES]);
        // Columns (axis 1) hold one feature across all analyses.
        for column in array.0.axis_iter(Axis(1)) {
            assert_eq!(column.to_vec(), vec![1.0, 2.0, 3.0]);
        }
    }

    // Two clusters of two samples each: 1 / (2 * 2) + 2 / (2 * 2) = 0.75.
    #[test]
    fn test_calc_within_dispersion() {
        let labels = arr1(&[0, 1, 0, 1]);
        let pairwise_distances = arr1(&[1.0, 2.0]);
        let result = calc_within_dispersion(labels.view(), 2, pairwise_distances.view());

        assert!(f64::EPSILON > (result - 0.75).abs(), "{result} != 0.75");
    }

    // Every projection method must return the expected number of columns, each
    // spanning exactly [-1, 1].
    #[rstest]
    #[case::project_none(ProjectionMethod::None, NUMBER_FEATURES)]
    #[case::project_tsne(ProjectionMethod::TSne, EMBEDDING_SIZE)]
    #[case::project_pca(ProjectionMethod::Pca, EMBEDDING_SIZE)]
    fn test_project(
        #[case] projection_method: ProjectionMethod,
        #[case] expected_embedding_size: usize,
    ) {
        let mut samples = Array2::random((100, NUMBER_FEATURES), StandardNormal);
        // Normalize the input up front so that the range check below also holds
        // for `ProjectionMethod::None`, which returns its input unchanged.
        normalize_embeddings_inplace::<NUMBER_FEATURES>(&mut samples);
        let samples = AnalysisArray(samples);

        let result = projection_method.project(samples).unwrap();

        assert_eq!(result.shape(), &[100, expected_embedding_size]);

        for i in 0..expected_embedding_size {
            let min = result.column(i).min();
            let max = result.column(i).max();
            assert!(
                f64::EPSILON > (min + 1.0).abs(),
                "Min value of column {i} is not -1.0: {min}",
            );
            assert!(
                f64::EPSILON > (max - 1.0).abs(),
                "Max value of column {i} is not 1.0: {max}",
            );
        }
    }
}
701
702