1use linfa::prelude::*;
13use linfa_clustering::{GaussianMixtureModel, KMeans};
14use linfa_nn::distance::{Distance, L2Dist};
15use linfa_reduction::Pca;
16use linfa_tsne::TSneParams;
17use log::{debug, info};
18use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis, Dim};
19use ndarray_rand::RandomExt;
20use rand::distributions::Uniform;
21use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
22use statrs::statistics::Statistics;
23
24use crate::{
25 Analysis, Feature, NUMBER_FEATURES,
26 errors::{ClusteringError, ProjectionError},
27};
28
/// Feature matrix built from analyses: one row per analyzed sample and `NUMBER_FEATURES`
/// columns, one per feature (see the `From` impls below for how it is assembled).
pub struct AnalysisArray(pub(crate) Array2<Feature>);
30
31impl From<Vec<Analysis>> for AnalysisArray {
32 #[inline]
33 fn from(data: Vec<Analysis>) -> Self {
34 let shape = (data.len(), NUMBER_FEATURES);
35 debug_assert_eq!(shape, (data.len(), data[0].inner().len()));
36
37 Self(
38 Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
39 .expect("Failed to convert to array, shape mismatch"),
40 )
41 }
42}
43
44impl From<Vec<[Feature; NUMBER_FEATURES]>> for AnalysisArray {
45 #[inline]
46 fn from(data: Vec<[Feature; NUMBER_FEATURES]>) -> Self {
47 let shape = (data.len(), NUMBER_FEATURES);
48 debug_assert_eq!(shape, (data.len(), data[0].len()));
49
50 Self(
51 Array2::from_shape_vec(shape, data.into_iter().flatten().collect())
52 .expect("Failed to convert to array, shape mismatch"),
53 )
54 }
55}
56
/// Unsupervised dataset handed to the clustering algorithms: a 2-D matrix of `Feature`
/// records paired with a 1-D array of unit `()` targets.
pub type FitDataset = Dataset<Feature, (), Dim<[usize; 1]>>;
58
/// Algorithm used to assign samples to `k` clusters.
#[derive(Clone, Copy, Debug)]
#[allow(clippy::module_name_repetitions)]
pub enum ClusteringMethod {
    /// `linfa_clustering::KMeans` built with `KMeans::params(k)` and otherwise default settings.
    KMeans,
    /// `linfa_clustering::GaussianMixtureModel` with k-means initialization and 10 runs.
    GaussianMixtureModel,
}
65
66impl ClusteringMethod {
67 #[must_use]
69 fn fit(self, k: usize, data: &FitDataset) -> Array1<usize> {
70 match self {
71 Self::KMeans => {
72 let model = KMeans::params(k)
73 .fit(data)
75 .unwrap();
76 model.predict(data.records())
77 }
78 Self::GaussianMixtureModel => {
79 let model = GaussianMixtureModel::params(k)
80 .init_method(linfa_clustering::GmmInitMethod::KMeans)
81 .n_runs(10)
82 .fit(data)
83 .unwrap();
84 model.predict(data.records())
85 }
86 }
87 }
88}
89
/// Strategy used to choose the number of clusters `k`.
#[derive(Clone, Copy, Debug)]
pub enum KOptimal {
    /// Gap statistic (Tibshirani, Walther & Hastie, 2001).
    GapStatistic {
        /// Number of reference datasets, sampled uniformly over each feature's range,
        /// that the clustered data is compared against.
        b: u32,
    },
    /// Davies–Bouldin index.
    /// NOTE(review): check `get_optimal_k_davies_bouldin` for implementation status.
    DaviesBouldin,
}
98
/// Dimensionality reduction applied to the samples before clustering.
#[derive(Clone, Copy, Debug, Default)]
pub enum ProjectionMethod {
    /// t-SNE embedding into `EMBEDDING_SIZE` dimensions; each column rescaled to `[-1, 1]`.
    TSne,
    /// Whitened PCA into `EMBEDDING_SIZE` dimensions; each column rescaled to `[-1, 1]`.
    Pca,
    /// No projection: the original features are used as the embeddings (the default).
    #[default]
    None,
}
110
111impl ProjectionMethod {
112 #[inline]
118 pub fn project(self, samples: AnalysisArray) -> Result<Array2<Feature>, ProjectionError> {
119 let result = match self {
120 Self::TSne => {
121 let nrecords = samples.0.nrows();
122 debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using t-SNE");
124 #[allow(clippy::cast_precision_loss)]
125 let mut embeddings = TSneParams::embedding_size(EMBEDDING_SIZE)
126 .perplexity(f64::max(samples.0.nrows() as f64 / 20., 5.))
127 .approx_threshold(0.5)
128 .transform(samples.0)?;
129 debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);
130
131 debug!("Normalizing embeddings");
133 normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
134 embeddings
135 }
136 Self::Pca => {
137 let nrecords = samples.0.nrows();
138 debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using PCA");
140 let data = Dataset::from(samples.0);
141 let pca: Pca<f64> = Pca::params(EMBEDDING_SIZE).whiten(true).fit(&data)?;
142 let mut embeddings = pca.predict(&data);
143 debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);
144
145 debug!("Normalizing embeddings");
147 normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
148 embeddings
149 }
150 Self::None => {
151 debug!("Using original data as embeddings");
152 samples.0
153 }
154 };
155 debug!("Embeddings shape: {:?}", result.shape());
156 Ok(result)
157 }
158}
159
160fn normalize_embeddings_inplace<const SIZE: usize>(embeddings: &mut Array2<f64>) {
163 for i in 0..SIZE {
164 let min = embeddings.column(i).min();
165 let max = embeddings.column(i).max();
166 let range = max - min;
167 embeddings
168 .column_mut(i)
169 .mapv_inplace(|v| ((v - min) / range).mul_add(2., -1.));
170 }
171}
172
173const EMBEDDING_SIZE: usize = {
176 let log2 = usize::ilog2(NUMBER_FEATURES) as usize;
177 if log2 < 2 { 2 } else { log2 }
178};
179
/// Type-state driver for the clustering pipeline.
///
/// `S` is the current stage (`EntryPoint`, `NotInitialized`, `Initialized` or `Finished`).
/// Each stage's `impl` block exposes only the operations valid at that point; transitions
/// return the helper in its next stage.
#[allow(clippy::module_name_repetitions)]
pub struct ClusteringHelper<S> {
    state: S,
}
187
/// Initial stage: no samples have been provided yet.
pub struct EntryPoint;
/// Samples have been projected; the number of clusters has not been chosen yet.
pub struct NotInitialized {
    /// Projected samples, one row per sample.
    embeddings: Array2<Feature>,
    /// Largest number of clusters considered when searching for the optimal `k`.
    pub k_max: usize,
    /// Strategy used to pick the optimal `k`.
    pub optimizer: KOptimal,
    /// Algorithm used to fit the clusters.
    pub clustering_method: ClusteringMethod,
}
/// The number of clusters `k` has been chosen; clustering has not run yet.
pub struct Initialized {
    /// Projected samples, one row per sample.
    embeddings: Array2<Feature>,
    /// Number of clusters to fit.
    pub k: usize,
    /// Algorithm used to fit the clusters.
    pub clustering_method: ClusteringMethod,
}
/// Clustering has run.
pub struct Finished {
    /// Cluster index assigned to each sample, in sample order.
    labels: Array1<usize>,
    /// Number of clusters that were fitted.
    pub k: usize,
}
208
209impl ClusteringHelper<EntryPoint> {
211 #[allow(clippy::missing_inline_in_public_items)]
217 pub fn new(
218 samples: AnalysisArray,
219 k_max: usize,
220 optimizer: KOptimal,
221 clustering_method: ClusteringMethod,
222 projection_method: ProjectionMethod,
223 ) -> Result<ClusteringHelper<NotInitialized>, ClusteringError> {
224 if samples.0.nrows() <= 15 {
225 return Err(ClusteringError::SmallLibrary);
226 }
227
228 let embeddings = projection_method.project(samples)?;
230
231 Ok(ClusteringHelper {
232 state: NotInitialized {
233 embeddings,
234 k_max,
235 optimizer,
236 clustering_method,
237 },
238 })
239 }
240}
241
242impl ClusteringHelper<NotInitialized> {
244 #[inline]
250 pub fn initialize(self) -> Result<ClusteringHelper<Initialized>, ClusteringError> {
251 let k = self.get_optimal_k()?;
252 Ok(ClusteringHelper {
253 state: Initialized {
254 embeddings: self.state.embeddings,
255 k,
256 clustering_method: self.state.clustering_method,
257 },
258 })
259 }
260
261 fn get_optimal_k(&self) -> Result<usize, ClusteringError> {
262 match self.state.optimizer {
263 KOptimal::GapStatistic { b } => self.get_optimal_k_gap_statistic(b),
264 KOptimal::DaviesBouldin => self.get_optimal_k_davies_bouldin(),
265 }
266 }
267
268 fn get_optimal_k_gap_statistic(&self, b: u32) -> Result<usize, ClusteringError> {
285 let embedding_dataset = Dataset::from(self.state.embeddings.clone());
286
287 let reference_datasets =
289 generate_reference_datasets(embedding_dataset.records().view(), b as usize);
290
291 let b = f64::from(b);
292
293 let results = (1..=self.state.k_max)
294 .map(|k| {
296 debug!("Fitting k-means to embeddings with k={k}");
297 let labels = self.state.clustering_method.fit(k, &embedding_dataset);
298 (k, labels)
299 })
300 .map(|(k, labels)| {
302 debug!(
304 "Calculating within intra-cluster variation for reference data sets with k={k}"
305 );
306 let w_kb_log: Vec<_> = reference_datasets
307 .par_iter()
308 .map(|ref_data| {
309 let ref_labels = self.state.clustering_method.fit(k, ref_data);
311 let ref_pairwise_distances = calc_pairwise_distances(
313 ref_data.records().view(),
314 k,
315 ref_labels.view(),
316 );
317 calc_within_dispersion(ref_labels.view(), k, ref_pairwise_distances.view())
318 .log2()
319 })
320 .collect();
321 let w_kb_log = Array::from_vec(w_kb_log);
322
323 let pairwise_distances =
325 calc_pairwise_distances(self.state.embeddings.view(), k, labels.view());
326 let w_k = calc_within_dispersion(labels.view(), k, pairwise_distances.view());
327
328 let w_kb_log_sum: f64 = w_kb_log.sum();
330 let l = b.recip() * w_kb_log_sum;
332 let gap_k = l - w_k.log2();
334 let standard_deviation = (b.recip() * (w_kb_log - l).pow2().sum()).sqrt();
336 let s_k = standard_deviation * (1.0 + b.recip()).sqrt();
339
340 (k, gap_k, s_k)
341 });
342
343 let (mut optimal_k, mut gap_k_minus_one) = (None, None);
349 for (k, gap_k, s_k) in results {
350 info!("k: {k}, gap_k: {gap_k}, s_k: {s_k}");
351
352 if let Some(gap_k_minus_one) = gap_k_minus_one
353 && gap_k_minus_one >= gap_k - s_k
354 {
355 info!("Optimal k found: {}", k - 1);
356 optimal_k = Some(k - 1);
357 break;
358 }
359
360 gap_k_minus_one = Some(gap_k);
361 }
362
363 optimal_k.ok_or(ClusteringError::OptimalKNotFound(self.state.k_max))
364 }
365
366 fn get_optimal_k_davies_bouldin(&self) -> Result<usize, ClusteringError> {
367 todo!();
368 }
369}
370
371#[must_use]
377#[inline]
378pub fn convert_to_array(data: Vec<Analysis>) -> AnalysisArray {
379 let shape = (data.len(), NUMBER_FEATURES);
381 debug_assert_eq!(NUMBER_FEATURES, data[0].inner().len());
382
383 AnalysisArray(
384 Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
385 .expect("Failed to convert to array, shape mismatch"),
386 )
387}
388
389fn generate_reference_datasets(samples: ArrayView2<'_, Feature>, b: usize) -> Vec<FitDataset> {
413 (0..b)
414 .into_par_iter()
415 .map(|_| Dataset::from(generate_ref_single(samples.view())))
416 .collect()
417}
418fn generate_ref_single(samples: ArrayView2<'_, Feature>) -> Array2<Feature> {
419 let feature_distributions = samples
420 .axis_iter(Axis(1))
421 .map(|feature| Array::random(feature.dim(), Uniform::new(feature.min(), feature.max())))
422 .collect::<Vec<_>>();
423 let feature_dists_views = feature_distributions
424 .iter()
425 .map(ndarray::ArrayBase::view)
426 .collect::<Vec<_>>();
427 ndarray::stack(Axis(0), &feature_dists_views)
428 .unwrap()
429 .t()
430 .to_owned()
431}
432
433fn calc_within_dispersion(
437 labels: ArrayView1<'_, usize>,
438 k: usize,
439 pairwise_distances: ArrayView1<'_, Feature>,
440) -> Feature {
441 debug_assert_eq!(k, labels.iter().max().unwrap() + 1);
442
443 let counts = labels.iter().fold(vec![0u32; k], |mut counts, &label| {
445 counts[label] += 1;
446 counts
447 });
448 counts
450 .iter()
451 .zip(pairwise_distances.iter())
452 .map(|(&count, distance)| (2.0 * f64::from(count)).recip() * distance)
453 .sum()
454}
455
456fn calc_pairwise_distances(
464 samples: ArrayView2<'_, Feature>,
465 k: usize,
466 labels: ArrayView1<'_, usize>,
467) -> Array1<Feature> {
468 debug_assert_eq!(
469 samples.nrows(),
470 labels.len(),
471 "Samples and labels must have the same length"
472 );
473 debug_assert_eq!(
474 k,
475 labels.iter().max().unwrap() + 1,
476 "Labels must be in the range 0..k"
477 );
478
479 let mut distances = Array1::zeros(k);
481 let mut clusters = vec![Vec::new(); k];
482 for (sample, label) in samples.outer_iter().zip(labels.iter()) {
484 clusters[*label].push(sample);
485 }
486 for (k, cluster) in clusters.iter().enumerate() {
488 let mut pairwise_dists = 0.;
489 for i in 0..cluster.len() - 1 {
490 let a = cluster[i];
491 let rest = &cluster[i + 1..];
492 for &b in rest {
493 pairwise_dists += L2Dist.distance(a, b);
494 }
495 }
496 distances[k] += pairwise_dists + pairwise_dists;
497 }
498 distances
499}
500
501impl ClusteringHelper<Initialized> {
503 #[must_use]
509 #[inline]
510 pub fn cluster(self) -> ClusteringHelper<Finished> {
511 let Initialized {
512 clustering_method,
513 embeddings,
514 k,
515 } = self.state;
516
517 let embedding_dataset = Dataset::from(embeddings);
518 let labels = clustering_method.fit(k, &embedding_dataset);
519
520 ClusteringHelper {
521 state: Finished { labels, k },
522 }
523 }
524}
525
526impl ClusteringHelper<Finished> {
528 #[must_use]
530 #[inline]
531 pub fn extract_analysis_clusters<T: Clone>(&self, samples: Vec<T>) -> Vec<Vec<T>> {
532 let mut clusters = vec![Vec::new(); self.state.k];
533
534 for (sample, &label) in samples.into_iter().zip(self.state.labels.iter()) {
535 clusters[label].push(sample);
536 }
537
538 clusters
539 }
540}
541
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, s};
    use ndarray_rand::rand_distr::StandardNormal;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    /// Reference data keeps the input shape and stays within each column's observed range.
    #[test]
    fn test_generate_reference_data_set() {
        let data = arr2(&[[10.0, -10.0], [20.0, -20.0], [30.0, -30.0]]);

        let ref_data = generate_ref_single(data.view());

        // First column is drawn from the range of the first input column: [10, 30].
        assert!(
            ref_data
                .slice(s![.., 0])
                .iter()
                .all(|v| *v >= 10.0 && *v <= 30.0)
        );

        // Second column is drawn from the range of the second input column: [-30, -10].
        assert!(
            ref_data
                .slice(s![.., 1])
                .iter()
                .all(|v| *v <= -10.0 && *v >= -30.0)
        );

        assert_eq!(ref_data.shape(), data.shape());

        // Random draws are (with overwhelming probability) not an exact copy of the input.
        assert_ne!(ref_data, data);
    }

    /// `calc_pairwise_distances` sums distances over ordered pairs within each cluster.
    #[test]
    fn test_pairwise_distances() {
        // Identical points within each cluster: every distance is zero.
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 0, 1, 1]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[1]
        );

        // Each cluster holds one pair at distance 1; counting both orders gives 2.
        let samples = arr2(&[[1.0, 2.0], [1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[1]
        );
    }

    /// `convert_to_array` puts one analysis per row and one feature per column.
    #[test]
    fn test_convert_to_vec() {
        let data = vec![
            Analysis::new([1.0; NUMBER_FEATURES]),
            Analysis::new([2.0; NUMBER_FEATURES]),
            Analysis::new([3.0; NUMBER_FEATURES]),
        ];

        let array = convert_to_array(data);

        assert_eq!(array.0.shape(), &[3, NUMBER_FEATURES]);
        assert!(
            f64::EPSILON > (array.0[[0, 0]] - 1.0).abs(),
            "{} != 1.0",
            array.0[[0, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[1, 0]] - 2.0).abs(),
            "{} != 2.0",
            array.0[[1, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[2, 0]] - 3.0).abs(),
            "{} != 3.0",
            array.0[[2, 0]]
        );

        // Rows are the input analyses, in order.
        let mut iter = array.0.axis_iter(Axis(0));
        assert_eq!(iter.next().unwrap().to_vec(), vec![1.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![2.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![3.0; NUMBER_FEATURES]);
        // Each column holds one feature across all analyses.
        for column in array.0.axis_iter(Axis(1)) {
            assert_eq!(column.to_vec(), vec![1.0, 2.0, 3.0]);
        }
    }

    /// Two clusters of two samples each: W = 1 / (2 * 2) + 2 / (2 * 2) = 0.75.
    #[test]
    fn test_calc_within_dispersion() {
        let labels = arr1(&[0, 1, 0, 1]);
        let pairwise_distances = arr1(&[1.0, 2.0]);
        let result = calc_within_dispersion(labels.view(), 2, pairwise_distances.view());

        assert!(f64::EPSILON > (result - 0.75).abs(), "{result} != 0.75");
    }

    /// Each projection yields the expected width, and every column spans exactly [-1, 1].
    #[rstest]
    #[case::project_none(ProjectionMethod::None, NUMBER_FEATURES)]
    #[case::project_tsne(ProjectionMethod::TSne, EMBEDDING_SIZE)]
    #[case::project_pca(ProjectionMethod::Pca, EMBEDDING_SIZE)]
    fn test_project(
        #[case] projection_method: ProjectionMethod,
        #[case] expected_embedding_size: usize,
    ) {
        let mut samples = Array2::random((100, NUMBER_FEATURES), StandardNormal);
        // Pre-normalize so that `ProjectionMethod::None`, which returns the input unchanged,
        // also satisfies the [-1, 1] bounds checked below.
        normalize_embeddings_inplace::<NUMBER_FEATURES>(&mut samples);
        let samples = AnalysisArray(samples);

        let result = projection_method.project(samples).unwrap();

        assert_eq!(result.shape(), &[100, expected_embedding_size]);

        // Every column's minimum maps to -1 and its maximum to 1.
        for i in 0..expected_embedding_size {
            let min = result.column(i).min();
            let max = result.column(i).max();
            assert!(
                f64::EPSILON > (min + 1.0).abs(),
                "Min value of column {i} is not -1.0: {min}",
            );
            assert!(
                f64::EPSILON > (max - 1.0).abs(),
                "Max value of column {i} is not 1.0: {max}",
            );
        }
    }
}
697
698