use linfa::prelude::*;
use linfa_clustering::{GaussianMixtureModel, KMeans};
use linfa_nn::distance::{Distance, L2Dist};
use linfa_reduction::Pca;
use linfa_tsne::TSneParams;
use log::{debug, info};
use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis, Dim};
use ndarray_rand::RandomExt;
use rand::distributions::Uniform;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use statrs::statistics::Statistics;

use crate::{
    Analysis, Feature, NUMBER_FEATURES,
    errors::{ClusteringError, ProjectionError},
};

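/// A `(n_samples, NUMBER_FEATURES)` feature matrix wrapping an [`Array2`],
/// with one row per analyzed sample.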
pub struct AnalysisArray(pub(crate) Array2<Feature>);

impl From<Vec<Analysis>> for AnalysisArray {
    #[inline]
    fn from(data: Vec<Analysis>) -> Self {
        let shape = (data.len(), NUMBER_FEATURES);
        debug_assert_eq!(shape, (data.len(), data[0].inner().len()));

        Self(
            Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
                .expect("Failed to convert to array, shape mismatch"),
        )
    }
}

impl From<Vec<[Feature; NUMBER_FEATURES]>> for AnalysisArray {
    #[inline]
    fn from(data: Vec<[Feature; NUMBER_FEATURES]>) -> Self {
        let shape = (data.len(), NUMBER_FEATURES);
        debug_assert_eq!(shape, (data.len(), data[0].len()));

        Self(
            Array2::from_shape_vec(shape, data.into_iter().flatten().collect())
                .expect("Failed to convert to array, shape mismatch"),
        )
    }
}

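/// Unlabeled [`Dataset`] of feature rows, in the form consumed by the linfa
/// fitting APIs.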
pub type FitDataset = Dataset<Feature, (), Dim<[usize; 1]>>;

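/// The clustering algorithm used to assign samples to clusters.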
#[derive(Clone, Copy, Debug)]
#[allow(clippy::module_name_repetitions)]
pub enum ClusteringMethod {
    KMeans,
    GaussianMixtureModel,
}

impl ClusteringMethod {
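    /// Fit the selected algorithm with `k` clusters and return one label per
    /// sample. Panics if the underlying linfa fit fails.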
    #[must_use]
    fn fit(self, k: usize, data: &FitDataset) -> Array1<usize> {
        match self {
            Self::KMeans => {
                let model = KMeans::params(k).fit(data).unwrap();
                model.predict(data.records())
            }
            Self::GaussianMixtureModel => {
                let model = GaussianMixtureModel::params(k)
                    .init_method(linfa_clustering::GmmInitMethod::KMeans)
                    .n_runs(10)
                    .fit(data)
                    .unwrap();
                model.predict(data.records())
            }
        }
    }
}

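/// Strategy for choosing the optimal number of clusters `k`.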
#[derive(Clone, Copy, Debug)]
pub enum KOptimal {
    GapStatistic {
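        /// Number of reference datasets to sample (the `B` of the gap
        /// statistic).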
        b: u32,
    },
    DaviesBouldin,
}

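/// Dimensionality reduction applied to the feature matrix before clustering.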
#[derive(Clone, Copy, Debug, Default)]
pub enum ProjectionMethod {
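    /// t-SNE projection down to [`EMBEDDING_SIZE`] dimensions, normalized to
    /// `[-1, 1]` per column.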
    TSne,
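    /// Whitened PCA projection down to [`EMBEDDING_SIZE`] dimensions,
    /// normalized to `[-1, 1]` per column.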
    Pca,
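    /// No projection; cluster on the raw feature matrix.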
    #[default]
    None,
}

impl ProjectionMethod {
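    /// Project `samples` into the embedding space selected by `self`.
    ///
    /// # Errors
    ///
    /// Returns a [`ProjectionError`] if the t-SNE transform or the PCA fit
    /// fails.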
    #[inline]
    pub fn project(self, samples: AnalysisArray) -> Result<Array2<Feature>, ProjectionError> {
        let result = match self {
            Self::TSne => {
                let nrecords = samples.0.nrows();
                debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using t-SNE");
                #[allow(clippy::cast_precision_loss)]
                let mut embeddings = TSneParams::embedding_size(EMBEDDING_SIZE)
                    .perplexity(f64::max(samples.0.nrows() as f64 / 20., 5.))
                    .approx_threshold(0.5)
                    .transform(samples.0)?;
                debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);

                debug!("Normalizing embeddings");
                normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
                embeddings
            }
            Self::Pca => {
                let nrecords = samples.0.nrows();
                debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using PCA");
                let data = Dataset::from(samples.0);
                let pca: Pca<f64> = Pca::params(EMBEDDING_SIZE).whiten(true).fit(&data)?;
                let mut embeddings = pca.predict(&data);
                debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);

                debug!("Normalizing embeddings");
                normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
                embeddings
            }
            Self::None => {
                debug!("Using original data as embeddings");
                samples.0
            }
        };
        debug!("Embeddings shape: {:?}", result.shape());
        Ok(result)
    }
}

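/// Linearly rescale the first `SIZE` columns of `embeddings` so that each
/// spans exactly `[-1.0, 1.0]`. Assumes every column contains at least two
/// distinct values; a constant column would yield a zero range and NaNs.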
fn normalize_embeddings_inplace<const SIZE: usize>(embeddings: &mut Array2<f64>) {
    for i in 0..SIZE {
        let min = embeddings.column(i).min();
        let max = embeddings.column(i).max();
        let range = max - min;
        embeddings
            .column_mut(i)
            .mapv_inplace(|v| ((v - min) / range).mul_add(2., -1.));
    }
}

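/// Size of the projected embedding space: the integer `log2` of the feature
/// count, clamped to a minimum of 2.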
const EMBEDDING_SIZE: usize = {
    let log2 = usize::ilog2(NUMBER_FEATURES) as usize;
    if log2 < 2 { 2 } else { log2 }
};

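/// Typestate helper that drives a library of analyses through projection,
/// optimal-`k` selection, and clustering ([`EntryPoint`] -> [`NotInitialized`]
/// -> [`Initialized`] -> [`Finished`]). A minimal sketch of the intended flow
/// (`samples`, `items`, and the parameter values are placeholders):
///
/// ```ignore
/// let clusters = ClusteringHelper::new(
///     samples,
///     10,
///     KOptimal::GapStatistic { b: 20 },
///     ClusteringMethod::KMeans,
///     ProjectionMethod::Pca,
/// )?
/// .initialize()?
/// .cluster()
/// .extract_analysis_clusters(items);
/// ```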
#[allow(clippy::module_name_repetitions)]
pub struct ClusteringHelper<S> {
    state: S,
}

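/// Initial state: no data attached yet; see [`ClusteringHelper::new`].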
pub struct EntryPoint;
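/// State after projection: embeddings exist, but the optimal `k` is unknown.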
pub struct NotInitialized {
    embeddings: Array2<Feature>,
    pub k_max: usize,
    pub optimizer: KOptimal,
    pub clustering_method: ClusteringMethod,
}
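/// State after `k` selection: ready to run the final clustering.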
pub struct Initialized {
    embeddings: Array2<Feature>,
    pub k: usize,
    pub clustering_method: ClusteringMethod,
}
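/// Terminal state: holds the chosen `k` and one cluster label per sample.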
pub struct Finished {
    labels: Array1<usize>,
    pub k: usize,
}

impl ClusteringHelper<EntryPoint> {
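    /// Build a helper from raw samples, projecting them with
    /// `projection_method`.
    ///
    /// # Errors
    ///
    /// Returns [`ClusteringError::SmallLibrary`] for libraries of 15 samples
    /// or fewer, or a projection error if the embedding step fails.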
    #[allow(clippy::missing_inline_in_public_items)]
    pub fn new(
        samples: AnalysisArray,
        k_max: usize,
        optimizer: KOptimal,
        clustering_method: ClusteringMethod,
        projection_method: ProjectionMethod,
    ) -> Result<ClusteringHelper<NotInitialized>, ClusteringError> {
        if samples.0.nrows() <= 15 {
            return Err(ClusteringError::SmallLibrary);
        }

        let embeddings = projection_method.project(samples)?;

        Ok(ClusteringHelper {
            state: NotInitialized {
                embeddings,
                k_max,
                optimizer,
                clustering_method,
            },
        })
    }
}

impl ClusteringHelper<NotInitialized> {
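    /// Determine the optimal `k` with the configured optimizer and transition
    /// to [`Initialized`].
    ///
    /// # Errors
    ///
    /// Returns [`ClusteringError::OptimalKNotFound`] if no `k` up to `k_max`
    /// satisfies the stopping criterion.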
    #[inline]
    pub fn initialize(self) -> Result<ClusteringHelper<Initialized>, ClusteringError> {
        let k = self.get_optimal_k()?;
        Ok(ClusteringHelper {
            state: Initialized {
                embeddings: self.state.embeddings,
                k,
                clustering_method: self.state.clustering_method,
            },
        })
    }

    fn get_optimal_k(&self) -> Result<usize, ClusteringError> {
        match self.state.optimizer {
            KOptimal::GapStatistic { b } => self.get_optimal_k_gap_statistic(b),
            KOptimal::DaviesBouldin => self.get_optimal_k_davies_bouldin(),
        }
    }

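    /// Estimate the optimal `k` via the gap statistic (Tibshirani, Walther &
    /// Hastie, 2001). For each `k` in `1..=k_max`, the within-cluster
    /// dispersion `W_k` of the real data is compared against `b` reference
    /// datasets drawn uniformly over each feature's range:
    ///
    /// `Gap(k) = (1/B) * sum_b log2(W_kb) - log2(W_k)`
    ///
    /// The smallest `k` satisfying `Gap(k) >= Gap(k+1) - s_{k+1}` is returned,
    /// where `s_k` accounts for the simulation error of the reference draws.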
    fn get_optimal_k_gap_statistic(&self, b: u32) -> Result<usize, ClusteringError> {
        let embedding_dataset = Dataset::from(self.state.embeddings.clone());

        let reference_datasets =
            generate_reference_datasets(embedding_dataset.records().view(), b as usize);

        let b = f64::from(b);

        let results = (1..=self.state.k_max)
            .map(|k| {
                debug!("Fitting clustering model to embeddings with k={k}");
                let labels = self.state.clustering_method.fit(k, &embedding_dataset);
                (k, labels)
            })
            .map(|(k, labels)| {
                debug!("Calculating within-cluster dispersion of the reference datasets with k={k}");
                let w_kb_log: Vec<_> = reference_datasets
                    .par_iter()
                    .map(|ref_data| {
                        let ref_labels = self.state.clustering_method.fit(k, ref_data);
                        let ref_pairwise_distances = calc_pairwise_distances(
                            ref_data.records().view(),
                            k,
                            ref_labels.view(),
                        );
                        calc_within_dispersion(ref_labels.view(), k, ref_pairwise_distances.view())
                            .log2()
                    })
                    .collect();

                let pairwise_distances =
                    calc_pairwise_distances(self.state.embeddings.view(), k, labels.view());
                let w_k = calc_within_dispersion(labels.view(), k, pairwise_distances.view());

                let w_kb_log_sum: f64 = w_kb_log.iter().sum();
                let w_kb_log_mean = b.recip() * w_kb_log_sum;
                let gap_k = w_kb_log_mean - w_k.log2();
                let standard_deviation = (b.recip()
                    * w_kb_log
                        .iter()
                        .map(|w_kb_log| (w_kb_log - w_kb_log_mean).powi(2))
                        .sum::<f64>())
                .sqrt();
                let s_k = standard_deviation * (1.0 + b.recip()).sqrt();

                (k, gap_k, s_k)
            });

        let (mut optimal_k, mut gap_k_minus_one) = (None, None);
        for (k, gap_k, s_k) in results {
            info!("k: {k}, gap_k: {gap_k}, s_k: {s_k}");

            if let Some(gap_k_minus_one) = gap_k_minus_one
                && gap_k_minus_one >= gap_k - s_k
            {
                info!("Optimal k found: {}", k - 1);
                optimal_k = Some(k - 1);
                break;
            }

            gap_k_minus_one = Some(gap_k);
        }

        optimal_k.ok_or(ClusteringError::OptimalKNotFound(self.state.k_max))
    }

    fn get_optimal_k_davies_bouldin(&self) -> Result<usize, ClusteringError> {
        todo!();
    }
}

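/// Convenience wrapper around the [`From<Vec<Analysis>>`] impl of
/// [`AnalysisArray`].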
#[must_use]
#[inline]
pub fn convert_to_array(data: Vec<Analysis>) -> AnalysisArray {
    data.into()
}

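/// Draw `b` reference datasets for the gap statistic, each sampled uniformly
/// within the observed per-feature range of `samples`.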
fn generate_reference_datasets(samples: ArrayView2<Feature>, b: usize) -> Vec<FitDataset> {
    (0..b)
        .map(|_| Dataset::from(generate_ref_single(samples)))
        .collect()
}

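/// Sample a single reference dataset: each feature is drawn i.i.d. from
/// `Uniform(min, max)` of the corresponding column of `samples` (which must
/// therefore not be constant).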
fn generate_ref_single(samples: ArrayView2<Feature>) -> Array2<Feature> {
    let feature_distributions = samples
        .axis_iter(Axis(1))
        .map(|feature| Array::random(feature.dim(), Uniform::new(feature.min(), feature.max())))
        .collect::<Vec<_>>();
    let feature_dists_views = feature_distributions
        .iter()
        .map(ndarray::ArrayBase::view)
        .collect::<Vec<_>>();
    ndarray::stack(Axis(0), &feature_dists_views)
        .unwrap()
        .t()
        .to_owned()
}

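/// Pooled within-cluster dispersion `W_k = sum_r D_r / (2 * n_r)`, where `D_r`
/// is the ordered-pair distance sum of cluster `r` (see
/// [`calc_pairwise_distances`]) and `n_r` its size.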
fn calc_within_dispersion(
    labels: ArrayView1<usize>,
    k: usize,
    pairwise_distances: ArrayView1<Feature>,
) -> Feature {
    debug_assert_eq!(k, labels.iter().max().unwrap() + 1);

    let counts = labels.iter().fold(vec![0u32; k], |mut counts, &label| {
        counts[label] += 1;
        counts
    });
    counts
        .iter()
        .zip(pairwise_distances.iter())
        .map(|(&count, distance)| (2.0 * f64::from(count)).recip() * distance)
        .sum()
}

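/// For each cluster `r` in `0..k`, compute `D_r`: the sum of Euclidean
/// distances over all ordered pairs of its members (each unordered pair is
/// counted twice).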
fn calc_pairwise_distances(
    samples: ArrayView2<Feature>,
    k: usize,
    labels: ArrayView1<usize>,
) -> Array1<Feature> {
    debug_assert_eq!(
        samples.nrows(),
        labels.len(),
        "Samples and labels must have the same length"
    );
    debug_assert_eq!(
        k,
        labels.iter().max().unwrap() + 1,
        "Labels must be in the range 0..k"
    );

    let mut distances = Array1::zeros(k);
    let mut clusters = vec![Vec::new(); k];
    for (sample, label) in samples.outer_iter().zip(labels.iter()) {
        clusters[*label].push(sample);
    }
    for (cluster_idx, cluster) in clusters.iter().enumerate() {
        // Sum the distance of every unordered pair once; iterating this way
        // (rather than `0..cluster.len() - 1`) avoids an underflow panic on
        // empty clusters.
        let mut pairwise_dists = 0.;
        for (i, &a) in cluster.iter().enumerate() {
            for &b in &cluster[i + 1..] {
                pairwise_dists += L2Dist.distance(a, b);
            }
        }
        // `D_r` counts ordered pairs, so double the unordered sum.
        distances[cluster_idx] = 2.0 * pairwise_dists;
    }
    distances
}

impl ClusteringHelper<Initialized> {
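    /// Run the configured clustering with the selected `k` and transition to
    /// [`Finished`].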
    #[must_use]
    #[inline]
    pub fn cluster(self) -> ClusteringHelper<Finished> {
        let Initialized {
            clustering_method,
            embeddings,
            k,
        } = self.state;

        let embedding_dataset = Dataset::from(embeddings);
        let labels = clustering_method.fit(k, &embedding_dataset);

        ClusteringHelper {
            state: Finished { labels, k },
        }
    }
}

impl ClusteringHelper<Finished> {
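    /// Partition `samples` into `k` clusters according to the fitted labels,
    /// preserving the input order within each cluster. `samples` must be in
    /// the same order as the data the labels were fitted on.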
    #[must_use]
    #[inline]
    pub fn extract_analysis_clusters<T: Clone>(&self, samples: Vec<T>) -> Vec<Vec<T>> {
        let mut clusters = vec![Vec::new(); self.state.k];

        for (sample, &label) in samples.into_iter().zip(self.state.labels.iter()) {
            clusters[label].push(sample);
        }

        clusters
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, s};
    use ndarray_rand::rand_distr::StandardNormal;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    #[test]
    fn test_generate_reference_data_set() {
        let data = arr2(&[[10.0, -10.0], [20.0, -20.0], [30.0, -30.0]]);

        let ref_data = generate_ref_single(data.view());

        assert!(
            ref_data
                .slice(s![.., 0])
                .iter()
                .all(|v| *v >= 10.0 && *v <= 30.0)
        );

        assert!(
            ref_data
                .slice(s![.., 1])
                .iter()
                .all(|v| *v <= -10.0 && *v >= -30.0)
        );

        assert_eq!(ref_data.shape(), data.shape());

        assert_ne!(ref_data, data);
    }

    #[test]
    fn test_pairwise_distances() {
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 0, 1, 1]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[1]
        );

        let samples = arr2(&[[1.0, 2.0], [1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[1]
        );
    }

    #[test]
    fn test_convert_to_array() {
        let data = vec![
            Analysis::new([1.0; NUMBER_FEATURES]),
            Analysis::new([2.0; NUMBER_FEATURES]),
            Analysis::new([3.0; NUMBER_FEATURES]),
        ];

        let array = convert_to_array(data);

        assert_eq!(array.0.shape(), &[3, NUMBER_FEATURES]);
        assert!(
            f64::EPSILON > (array.0[[0, 0]] - 1.0).abs(),
            "{} != 1.0",
            array.0[[0, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[1, 0]] - 2.0).abs(),
            "{} != 2.0",
            array.0[[1, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[2, 0]] - 3.0).abs(),
            "{} != 3.0",
            array.0[[2, 0]]
        );

        let mut iter = array.0.axis_iter(Axis(0));
        assert_eq!(iter.next().unwrap().to_vec(), vec![1.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![2.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![3.0; NUMBER_FEATURES]);
        for column in array.0.axis_iter(Axis(1)) {
            assert_eq!(column.to_vec(), vec![1.0, 2.0, 3.0]);
        }
    }

    #[test]
    fn test_calc_within_dispersion() {
        let labels = arr1(&[0, 1, 0, 1]);
        let pairwise_distances = arr1(&[1.0, 2.0]);
        let result = calc_within_dispersion(labels.view(), 2, pairwise_distances.view());

        assert!(f64::EPSILON > (result - 0.75).abs(), "{result} != 0.75");
    }

    #[rstest]
    #[case::project_none(ProjectionMethod::None, NUMBER_FEATURES)]
    #[case::project_tsne(ProjectionMethod::TSne, EMBEDDING_SIZE)]
    #[case::project_pca(ProjectionMethod::Pca, EMBEDDING_SIZE)]
    fn test_project(
        #[case] projection_method: ProjectionMethod,
        #[case] expected_embedding_size: usize,
    ) {
        let mut samples = Array2::random((100, NUMBER_FEATURES), StandardNormal);
        normalize_embeddings_inplace::<NUMBER_FEATURES>(&mut samples);
        let samples = AnalysisArray(samples);

        let result = projection_method.project(samples).unwrap();

        assert_eq!(result.shape(), &[100, expected_embedding_size]);

        for i in 0..expected_embedding_size {
            let min = result.column(i).min();
            let max = result.column(i).max();
            assert!(
                f64::EPSILON > (min + 1.0).abs(),
                "Min value of column {i} is not -1.0: {min}",
            );
            assert!(
                f64::EPSILON > (max - 1.0).abs(),
                "Max value of column {i} is not 1.0: {max}",
            );
        }
    }
}