mecomp_analysis/
clustering.rs

//! This module contains helpers that wrap the `linfa-clustering` crate to perform clustering on the data
//! without having to choose an exact number of clusters.
//!
//! Instead, you provide the maximum number of clusters you want to try (the search always
//! starts at one cluster), and we'll use one of a range of methods to determine the optimal
//! number of clusters.
//!
//! # References:
//!
//! - The gap statistic [R. Tibshirani, G. Walther, and T. Hastie (Stanford University, 2001)](https://hastie.su.domains/Papers/gap.pdf)
//! - The Davies-Bouldin index [wikipedia](https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index)

use linfa::prelude::*;
use linfa_clustering::{GaussianMixtureModel, KMeans};
use linfa_nn::distance::{Distance, L2Dist};
use linfa_reduction::Pca;
use linfa_tsne::TSneParams;
use log::{debug, info};
use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis, Dim};
use ndarray_rand::RandomExt;
use rand::distributions::Uniform;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use statrs::statistics::Statistics;

use crate::{
    Analysis, Feature, NUMBER_FEATURES,
    errors::{ClusteringError, ProjectionError},
};

pub struct AnalysisArray(pub(crate) Array2<Feature>);

impl From<Vec<Analysis>> for AnalysisArray {
    #[inline]
    fn from(data: Vec<Analysis>) -> Self {
        let shape = (data.len(), NUMBER_FEATURES);
        debug_assert_eq!(shape, (data.len(), data[0].inner().len()));

        Self(
            Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
                .expect("Failed to convert to array, shape mismatch"),
        )
    }
}

impl From<Vec<[Feature; NUMBER_FEATURES]>> for AnalysisArray {
    #[inline]
    fn from(data: Vec<[Feature; NUMBER_FEATURES]>) -> Self {
        let shape = (data.len(), NUMBER_FEATURES);
        debug_assert_eq!(shape, (data.len(), data[0].len()));

        Self(
            Array2::from_shape_vec(shape, data.into_iter().flatten().collect())
                .expect("Failed to convert to array, shape mismatch"),
        )
    }
}

pub type FitDataset = Dataset<Feature, (), Dim<[usize; 1]>>;

#[derive(Clone, Copy, Debug)]
#[allow(clippy::module_name_repetitions)]
pub enum ClusteringMethod {
    KMeans,
    GaussianMixtureModel,
}

impl ClusteringMethod {
    /// Fit the clustering method to the dataset and get the labels
    ///
    /// # Panics
    ///
    /// Will panic if the underlying model fails to fit the data
    #[must_use]
    fn fit(self, k: usize, data: &FitDataset) -> Array1<usize> {
        match self {
            Self::KMeans => {
                let model = KMeans::params(k)
                    // .max_n_iterations(MAX_ITERATIONS)
                    .fit(data)
                    .unwrap();
                model.predict(data.records())
            }
            Self::GaussianMixtureModel => {
                let model = GaussianMixtureModel::params(k)
                    .init_method(linfa_clustering::GmmInitMethod::KMeans)
                    .n_runs(10)
                    .fit(data)
                    .unwrap();
                model.predict(data.records())
            }
        }
    }
}

#[derive(Clone, Copy, Debug)]
pub enum KOptimal {
    GapStatistic {
        /// The number of reference datasets to generate
        b: u32,
    },
    DaviesBouldin,
}

#[derive(Clone, Copy, Debug, Default)]
/// Should the data be projected into a lower-dimensional space before clustering, and if so, how?
pub enum ProjectionMethod {
    /// Use t-SNE to project the data into a lower-dimensional space
    TSne,
    /// Use PCA to project the data into a lower-dimensional space
    Pca,
    /// Don't project the data
    #[default]
    None,
}

impl ProjectionMethod {
    /// Project the data into a lower-dimensional space
    ///
    /// # Errors
    ///
    /// Will return an error if the t-SNE or PCA projection fails
    #[inline]
    pub fn project(self, samples: AnalysisArray) -> Result<Array2<Feature>, ProjectionError> {
        let result = match self {
            Self::TSne => {
                let nrecords = samples.0.nrows();
                // first use the t-SNE algorithm to project the data into a lower-dimensional space
                debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using t-SNE");
                #[allow(clippy::cast_precision_loss)]
                let mut embeddings = TSneParams::embedding_size(EMBEDDING_SIZE)
                    .perplexity(f64::max(samples.0.nrows() as f64 / 20., 5.))
                    .approx_threshold(0.5)
                    .transform(samples.0)?;
                debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);

                // normalize the embeddings so each dimension is between -1 and 1
                debug!("Normalizing embeddings");
                normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
                embeddings
            }
            Self::Pca => {
                let nrecords = samples.0.nrows();
                // use the PCA algorithm to project the data into a lower-dimensional space
                debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using PCA");
                let data = Dataset::from(samples.0);
                let pca: Pca<f64> = Pca::params(EMBEDDING_SIZE).whiten(true).fit(&data)?;
                let mut embeddings = pca.predict(&data);
                debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);

                // normalize the embeddings so each dimension is between -1 and 1
                debug!("Normalizing embeddings");
                normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
                embeddings
            }
            Self::None => {
                debug!("Using original data as embeddings");
                samples.0
            }
        };
        debug!("Embeddings shape: {:?}", result.shape());
        Ok(result)
    }
}

// Normalize the embeddings so each dimension is between -1.0 and 1.0, in-place;
// e.g. a column holding [0.0, 5.0, 10.0] is rescaled to [-1.0, 0.0, 1.0].
// Assumes every column has a nonzero range (a constant column would divide by zero).
// Pass the embedding size as a const generic to enable more compiler optimizations
fn normalize_embeddings_inplace<const SIZE: usize>(embeddings: &mut Array2<f64>) {
    for i in 0..SIZE {
        let min = embeddings.column(i).min();
        let max = embeddings.column(i).max();
        let range = max - min;
        embeddings
            .column_mut(i)
            .mapv_inplace(|v| ((v - min) / range).mul_add(2., -1.));
    }
}

/// Dimensionality that the t-SNE and PCA projection methods project the data into:
/// the floor of log2 of the number of features, clamped to a minimum of 2.
const EMBEDDING_SIZE: usize = {
    let log2 = usize::ilog2(NUMBER_FEATURES) as usize;
    if log2 < 2 { 2 } else { log2 }
};
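
// For example (hypothetical values, since NUMBER_FEATURES lives elsewhere in the
// crate): if NUMBER_FEATURES were 20, ilog2(20) == 4 and EMBEDDING_SIZE would be 4;
// if NUMBER_FEATURES were 3, ilog2(3) == 1, which is clamped up to 2.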

#[allow(clippy::module_name_repetitions)]
pub struct ClusteringHelper<S>
where
    S: Sized,
{
    state: S,
}

pub struct EntryPoint;
pub struct NotInitialized {
    /// The embeddings of our input, as a Nx`EMBEDDING_SIZE` array
    embeddings: Array2<Feature>,
    pub k_max: usize,
    pub optimizer: KOptimal,
    pub clustering_method: ClusteringMethod,
}
pub struct Initialized {
    /// The embeddings of our input, as a Nx`EMBEDDING_SIZE` array
    embeddings: Array2<Feature>,
    pub k: usize,
    pub clustering_method: ClusteringMethod,
}
pub struct Finished {
    /// The labelings of the samples, as a Nx1 array.
    /// Each element is the cluster that the corresponding sample belongs to.
    labels: Array1<usize>,
    pub k: usize,
}
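
// The state structs above drive a typestate API: each stage consumes the helper
// and returns the next stage, so the compiler enforces the call order. A minimal
// sketch of the intended flow (parameter values here are illustrative only):
//
//     let samples: AnalysisArray = analyses.clone().into(); // from a Vec<Analysis>
//     let helper = ClusteringHelper::new(
//         samples,
//         24, // k_max
//         KOptimal::GapStatistic { b: 10 },
//         ClusteringMethod::KMeans,
//         ProjectionMethod::TSne,
//     )?; // -> ClusteringHelper<NotInitialized>
//     let helper = helper.initialize()?; // -> ClusteringHelper<Initialized>, optimal k chosen
//     let helper = helper.cluster(); // -> ClusteringHelper<Finished>
//     let clusters = helper.extract_analysis_clusters(analyses); // Vec<Vec<Analysis>>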

/// Functions available for the `EntryPoint` state
impl ClusteringHelper<EntryPoint> {
    /// Create a new `ClusteringHelper` object
    ///
    /// # Errors
    ///
    /// Will return an error if there was an error projecting the data into a lower-dimensional space,
    /// or if the library is too small to be meaningfully clustered
    #[allow(clippy::missing_inline_in_public_items)]
    pub fn new(
        samples: AnalysisArray,
        k_max: usize,
        optimizer: KOptimal,
        clustering_method: ClusteringMethod,
        projection_method: ProjectionMethod,
    ) -> Result<ClusteringHelper<NotInitialized>, ClusteringError> {
        // refuse to cluster tiny libraries
        if samples.0.nrows() <= 15 {
            return Err(ClusteringError::SmallLibrary);
        }

        // project the data into a lower-dimensional space
        let embeddings = projection_method.project(samples)?;

        Ok(ClusteringHelper {
            state: NotInitialized {
                embeddings,
                k_max,
                optimizer,
                clustering_method,
            },
        })
    }
}

/// Functions available for `NotInitialized` state
impl ClusteringHelper<NotInitialized> {
    /// Initialize the `ClusteringHelper` object by computing the optimal number of clusters
    ///
    /// # Errors
    ///
    /// Will return an error if there was an error calculating the optimal number of clusters
    #[inline]
    pub fn initialize(self) -> Result<ClusteringHelper<Initialized>, ClusteringError> {
        let k = self.get_optimal_k()?;
        Ok(ClusteringHelper {
            state: Initialized {
                embeddings: self.state.embeddings,
                k,
                clustering_method: self.state.clustering_method,
            },
        })
    }

    fn get_optimal_k(&self) -> Result<usize, ClusteringError> {
        match self.state.optimizer {
            KOptimal::GapStatistic { b } => self.get_optimal_k_gap_statistic(b),
            KOptimal::DaviesBouldin => self.get_optimal_k_davies_bouldin(),
        }
    }

    /// Get the optimal number of clusters using the gap statistic
    ///
    /// # References:
    ///
    /// - [R. Tibshirani, G. Walther, and T. Hastie (Stanford University, 2001)](https://hastie.su.domains/Papers/gap.pdf)
    ///
    /// # Algorithm:
    ///
    /// 1. Cluster the observed data, varying the number of clusters from k = 1, …, kmax, and compute the corresponding total within intra-cluster variation `W_k`.
    /// 2. Generate B reference data sets with a random uniform distribution. Cluster each of these reference data sets with varying number of clusters k = 1, …, kmax,
    ///    and compute the corresponding total within intra-cluster variation `W_{kb}`.
    /// 3. Compute the estimated gap statistic as the deviation of the observed `W_k` value from its expected value under the null hypothesis:
    ///    `Gap(k) = (1/B) \sum_{b=1}^{B} \log(W^*_{kb}) - \log(W_k)`.
    ///    Compute also the standard deviation of the statistics.
    /// 4. Choose the number of clusters as the smallest value of k such that the gap statistic is within one standard deviation of the gap at k+1:
    ///    `Gap(k) ≥ Gap(k + 1) − s_{k + 1}`.
    fn get_optimal_k_gap_statistic(&self, b: u32) -> Result<usize, ClusteringError> {
        let embedding_dataset = Dataset::from(self.state.embeddings.clone());

        // our reference data sets
        let reference_datasets =
            generate_reference_datasets(embedding_dataset.records().view(), b as usize);

        let b = f64::from(b);

        let results = (1..=self.state.k_max)
            // for each k, cluster the data into k clusters
            .map(|k| {
                debug!("Fitting clustering method to embeddings with k={k}");
                let labels = self.state.clustering_method.fit(k, &embedding_dataset);
                (k, labels)
            })
            // for each k, calculate the gap statistic, and the standard deviation of the statistics
            .map(|(k, labels)| {
                // calculate the within intra-cluster variation for the reference data sets.
                // NOTE: we use log2 rather than the paper's natural log; the base is irrelevant
                // because the selection criterion in step 4 is invariant under a common rescaling
                // of the gap values and standard deviations.
                debug!(
                    "Calculating within intra-cluster variation for reference data sets with k={k}"
                );
                let w_kb_log: Vec<_> = reference_datasets
                    .par_iter()
                    .map(|ref_data| {
                        // cluster the reference data into k clusters
                        let ref_labels = self.state.clustering_method.fit(k, ref_data);
                        // calculate the within intra-cluster variation for the reference data
                        let ref_pairwise_distances = calc_pairwise_distances(
                            ref_data.records().view(),
                            k,
                            ref_labels.view(),
                        );
                        calc_within_dispersion(ref_labels.view(), k, ref_pairwise_distances.view())
                            .log2()
                    })
                    .collect();

                // calculate the within intra-cluster variation for the observed data
                let pairwise_distances =
                    calc_pairwise_distances(self.state.embeddings.view(), k, labels.view());
                let w_k = calc_within_dispersion(labels.view(), k, pairwise_distances.view());

                // finally, calculate the gap statistic
                let w_kb_log_sum: f64 = w_kb_log.iter().sum();
                // original formula: l = (1 / B) * sum_b(log(W_kb))
                let l = b.recip() * w_kb_log_sum;
                // original formula: gap_k = (1 / B) * sum_b(log(W_kb)) - log(W_k)
                let gap_k = l - w_k.log2();
                // original formula: sd_k = [(1 / B) * sum_b((log(W_kb) - l)^2)]^0.5
                let standard_deviation = (b.recip()
                    * w_kb_log
                        .iter()
                        .map(|w_kb_log| (w_kb_log - l).powi(2))
                        .sum::<f64>())
                .sqrt();
                // original formula: s_k = sd_k * (1 + 1 / B)^0.5
                let s_k = standard_deviation * (1.0 + b.recip()).sqrt();

                (k, gap_k, s_k)
            });

        // // plot the gap_k (whisker with s_k) w.r.t. k
        // #[cfg(feature = "plot_gap")]
        // plot_gap_statistic(results.clone().collect::<Vec<_>>());

        // finally, consume the results iterator until we find the optimal k:
        // stop at the first k where Gap(k-1) >= Gap(k) - s_k and report k-1
        let (mut optimal_k, mut gap_k_minus_one) = (None, None);
        for (k, gap_k, s_k) in results {
            info!("k: {k}, gap_k: {gap_k}, s_k: {s_k}");

            if let Some(gap_k_minus_one) = gap_k_minus_one
                && gap_k_minus_one >= gap_k - s_k
            {
                info!("Optimal k found: {}", k - 1);
                optimal_k = Some(k - 1);
                break;
            }

            gap_k_minus_one = Some(gap_k);
        }

        optimal_k.ok_or(ClusteringError::OptimalKNotFound(self.state.k_max))
    }
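
    // Worked illustration of the stopping rule above (numbers invented for the
    // example): with Gap = [0.80, 1.10, 1.15] and s = [0.10, 0.08, 0.07],
    // k=2 fails the check (0.80 < 1.10 - 0.08 = 1.02), but at k=3 we have
    // 1.10 >= 1.15 - 0.07 = 1.08, so the loop reports the optimal k as 2.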

    fn get_optimal_k_davies_bouldin(&self) -> Result<usize, ClusteringError> {
        todo!();
    }
}
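
// `get_optimal_k_davies_bouldin` is still a `todo!()`. As a hedged sketch (not the
// crate's actual implementation), the Davies-Bouldin index from the module's
// wikipedia reference could be computed per candidate k like this, with the
// optimal k being the one that minimizes the index. Assumes k >= 2 and that no
// cluster is empty.
#[allow(dead_code)]
#[allow(clippy::cast_precision_loss)]
fn davies_bouldin_index(samples: ArrayView2<Feature>, k: usize, labels: ArrayView1<usize>) -> f64 {
    // per-cluster centroids and sample counts
    let mut centroids = Array2::<Feature>::zeros((k, samples.ncols()));
    let mut counts = vec![0.0; k];
    for (sample, &label) in samples.outer_iter().zip(labels.iter()) {
        centroids.row_mut(label).zip_mut_with(&sample, |c, &s| *c += s);
        counts[label] += 1.0;
    }
    for (mut centroid, &count) in centroids.outer_iter_mut().zip(counts.iter()) {
        centroid.mapv_inplace(|c| c / count);
    }

    // S_i: mean distance of each cluster's samples to its centroid
    let mut scatter = vec![0.0; k];
    for (sample, &label) in samples.outer_iter().zip(labels.iter()) {
        scatter[label] += L2Dist.distance(sample, centroids.row(label)) / counts[label];
    }

    // DB = (1/k) * sum_i max_{j != i} (S_i + S_j) / d(c_i, c_j)
    (0..k)
        .map(|i| {
            (0..k)
                .filter(|&j| j != i)
                .map(|j| {
                    (scatter[i] + scatter[j])
                        / L2Dist.distance(centroids.row(i), centroids.row(j))
                })
                .fold(f64::MIN, f64::max)
        })
        .sum::<f64>()
        / k as f64
}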

/// Convert a vector of Analyses into a 2D array
///
/// (This is equivalent to the `From<Vec<Analysis>>` impl on `AnalysisArray`.)
///
/// # Panics
///
/// Will panic if the shape of the data does not match the number of features, which should never happen
#[must_use]
#[inline]
pub fn convert_to_array(data: Vec<Analysis>) -> AnalysisArray {
    // Convert vector to Array
    let shape = (data.len(), NUMBER_FEATURES);
    debug_assert_eq!(NUMBER_FEATURES, data[0].inner().len());

    AnalysisArray(
        Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
            .expect("Failed to convert to array, shape mismatch"),
    )
}

/// Generate B reference data sets with a random uniform distribution
///
/// (excerpt from reference paper)
/// """
/// We consider two choices for the reference distribution:
///
/// 1. generate each reference feature uniformly over the range of the observed values for that feature.
/// 2. generate the reference features from a uniform distribution over a box aligned with the
///    principal components of the data.
///    In detail, if X is our n by p data matrix, we assume that the columns have mean 0 and compute
///    the singular value decomposition X = UDV^T. We transform via X' = XV and then draw uniform features Z'
///    over the ranges of the columns of X', as in method (1) above.
///    Finally, we back-transform via Z=Z'V^T to give reference data Z.
///
/// Method (1) has the advantage of simplicity. Method (2) takes into account the shape of the
/// data distribution and makes the procedure rotationally invariant, as long as the
/// clustering method itself is invariant
/// """
///
/// For now, we use method (1): it is simpler to implement,
/// we know that our data is already normalized, and
/// the ordering of features is important, meaning that we can't
/// rotate the data anyway.
fn generate_reference_datasets(samples: ArrayView2<Feature>, b: usize) -> Vec<FitDataset> {
    let mut reference_datasets = Vec::with_capacity(b);
    for _ in 0..b {
        reference_datasets.push(Dataset::from(generate_ref_single(samples)));
    }

    reference_datasets
}

fn generate_ref_single(samples: ArrayView2<Feature>) -> Array2<Feature> {
    // for each feature (column), draw `n` samples uniformly from the observed [min, max] range.
    // NOTE: `Uniform::new` requires min < max, so every feature must have a nonzero range
    let feature_distributions = samples
        .axis_iter(Axis(1))
        .map(|feature| Array::random(feature.dim(), Uniform::new(feature.min(), feature.max())))
        .collect::<Vec<_>>();
    let feature_dists_views = feature_distributions
        .iter()
        .map(ndarray::ArrayBase::view)
        .collect::<Vec<_>>();
    // stack the columns into a p x n array, then transpose back to n x p
    ndarray::stack(Axis(0), &feature_dists_views)
        .unwrap()
        .t()
        .to_owned()
}

/// Calculate `W_k`, the within intra-cluster variation for the given clustering
///
/// `W_k = \sum_{r=1}^{k} \frac{D_r}{2*n_r}`
fn calc_within_dispersion(
    labels: ArrayView1<usize>,
    k: usize,
    pairwise_distances: ArrayView1<Feature>,
) -> Feature {
    debug_assert_eq!(k, labels.iter().max().unwrap() + 1);

    // we first need to convert our list of labels into a list of the number of samples in each cluster
    let counts = labels.iter().fold(vec![0u32; k], |mut counts, &label| {
        counts[label] += 1;
        counts
    });
    // then, we calculate the within intra-cluster variation
    counts
        .iter()
        .zip(pairwise_distances.iter())
        .map(|(&count, distance)| (2.0 * f64::from(count)).recip() * distance)
        .sum()
}

/// Calculate the `D_r` array, the sum of the pairwise distances in cluster r, for all clusters in the given clustering
///
/// # Arguments
///
/// - `samples`: The samples in the dataset
/// - `k`: The number of clusters
/// - `labels`: The cluster labels for each sample
fn calc_pairwise_distances(
    samples: ArrayView2<Feature>,
    k: usize,
    labels: ArrayView1<usize>,
) -> Array1<Feature> {
    debug_assert_eq!(
        samples.nrows(),
        labels.len(),
        "Samples and labels must have the same length"
    );
    debug_assert_eq!(
        k,
        labels.iter().max().unwrap() + 1,
        "Labels must be in the range 0..k"
    );

    // for each cluster, calculate the sum of the pairwise distances between samples in that cluster
    let mut distances = Array1::zeros(k);
    let mut clusters = vec![Vec::new(); k];
    // build clusters
    for (sample, label) in samples.outer_iter().zip(labels.iter()) {
        clusters[*label].push(sample);
    }
    // calculate pairwise dist. within each cluster
    // (`saturating_sub` keeps the range empty for empty and singleton clusters)
    for (r, cluster) in clusters.iter().enumerate() {
        let mut pairwise_dists = 0.;
        for i in 0..cluster.len().saturating_sub(1) {
            let a = cluster[i];
            let rest = &cluster[i + 1..];
            for &b in rest {
                pairwise_dists += L2Dist.distance(a, b);
            }
        }
        // `D_r` sums over ordered pairs, i.e. it counts every unordered pair twice,
        // so double the unique-pair sum
        distances[r] = 2.0 * pairwise_dists;
    }
    distances
}

/// Functions available for `Initialized` state
impl ClusteringHelper<Initialized> {
    /// Cluster the data into k clusters
    ///
    /// # Panics
    ///
    /// Will panic if the underlying clustering model fails to fit the data
    #[must_use]
    #[inline]
    pub fn cluster(self) -> ClusteringHelper<Finished> {
        let Initialized {
            clustering_method,
            embeddings,
            k,
        } = self.state;

        let embedding_dataset = Dataset::from(embeddings);
        let labels = clustering_method.fit(k, &embedding_dataset);

        ClusteringHelper {
            state: Finished { labels, k },
        }
    }
}

/// Functions available for `Finished` state
impl ClusteringHelper<Finished> {
    /// Use the labels to reorganize the provided samples into clusters
    #[must_use]
    #[inline]
    pub fn extract_analysis_clusters<T: Clone>(&self, samples: Vec<T>) -> Vec<Vec<T>> {
        let mut clusters = vec![Vec::new(); self.state.k];

        for (sample, &label) in samples.into_iter().zip(self.state.labels.iter()) {
            clusters[label].push(sample);
        }

        clusters
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, s};
    use ndarray_rand::rand_distr::StandardNormal;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    #[test]
    fn test_generate_reference_data_set() {
        let data = arr2(&[[10.0, -10.0], [20.0, -20.0], [30.0, -30.0]]);

        let ref_data = generate_ref_single(data.view());

        // First column: all vals between 10.0 and 30.0
        assert!(
            ref_data
                .slice(s![.., 0])
                .iter()
                .all(|v| *v >= 10.0 && *v <= 30.0)
        );

        // Second column: all vals between -30.0 and -10.0
        assert!(
            ref_data
                .slice(s![.., 1])
                .iter()
                .all(|v| *v <= -10.0 && *v >= -30.0)
        );

        // check that the shape is correct
        assert_eq!(ref_data.shape(), data.shape());

        // check that the data is not the same as the original data
        assert_ne!(ref_data, data);
    }

    #[test]
    fn test_pairwise_distances() {
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 0, 1, 1]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[1]
        );

        let samples = arr2(&[[1.0, 2.0], [1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[1]
        );
    }

    #[test]
    fn test_convert_to_array() {
        let data = vec![
            Analysis::new([1.0; NUMBER_FEATURES]),
            Analysis::new([2.0; NUMBER_FEATURES]),
            Analysis::new([3.0; NUMBER_FEATURES]),
        ];

        let array = convert_to_array(data);

        assert_eq!(array.0.shape(), &[3, NUMBER_FEATURES]);
        assert!(
            f64::EPSILON > (array.0[[0, 0]] - 1.0).abs(),
            "{} != 1.0",
            array.0[[0, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[1, 0]] - 2.0).abs(),
            "{} != 2.0",
            array.0[[1, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[2, 0]] - 3.0).abs(),
            "{} != 3.0",
            array.0[[2, 0]]
        );

        // check that axis iteration works how we expect
        // axis 0
        let mut iter = array.0.axis_iter(Axis(0));
        assert_eq!(iter.next().unwrap().to_vec(), vec![1.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![2.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![3.0; NUMBER_FEATURES]);
        // axis 1
        for column in array.0.axis_iter(Axis(1)) {
            assert_eq!(column.to_vec(), vec![1.0, 2.0, 3.0]);
        }
    }

    #[test]
    fn test_calc_within_dispersion() {
        let labels = arr1(&[0, 1, 0, 1]);
        let pairwise_distances = arr1(&[1.0, 2.0]);
        let result = calc_within_dispersion(labels.view(), 2, pairwise_distances.view());

        // `W_k = \sum_{r=1}^{k} \frac{D_r}{2*n_r}` = 1/4 * 1.0 + 1/4 * 2.0 = 0.25 + 0.5 = 0.75
        assert!(f64::EPSILON > (result - 0.75).abs(), "{result} != 0.75");
    }

    #[rstest]
    #[case::project_none(ProjectionMethod::None, NUMBER_FEATURES)]
    #[case::project_tsne(ProjectionMethod::TSne, EMBEDDING_SIZE)]
    #[case::project_pca(ProjectionMethod::Pca, EMBEDDING_SIZE)]
    fn test_project(
        #[case] projection_method: ProjectionMethod,
        #[case] expected_embedding_size: usize,
    ) {
        // generate 100 random samples; we use a normal distribution because with a uniform distribution
        // the data has no real "principal components" and PCA will not work as expected, since almost all
        // the eigenvalues will fall below the cutoff
        let mut samples = Array2::random((100, NUMBER_FEATURES), StandardNormal);
        normalize_embeddings_inplace::<NUMBER_FEATURES>(&mut samples);
        let samples = AnalysisArray(samples);

        let result = projection_method.project(samples).unwrap();

        // ensure embeddings are the correct shape
        assert_eq!(result.shape(), &[100, expected_embedding_size]);

        // ensure the data is normalized
        for i in 0..expected_embedding_size {
            let min = result.column(i).min();
            let max = result.column(i).max();
            assert!(
                f64::EPSILON > (min + 1.0).abs(),
                "Min value of column {i} is not -1.0: {min}",
            );
            assert!(
                f64::EPSILON > (max - 1.0).abs(),
                "Max value of column {i} is not 1.0: {max}",
            );
        }
    }
}

// #[cfg(feature = "plot_gap")]
// fn plot_gap_statistic(data: Vec<(usize, f64, f64)>) {
//     use plotters::prelude::*;

//     // Assuming data is a Vec<(usize, f64, f64)> of (k, gap_k, s_k)
//     let root_area = BitMapBackend::new("gap_statistic_plot.png", (640, 480)).into_drawing_area();
//     root_area.fill(&WHITE).unwrap();

//     let max_gap_k = data
//         .iter()
//         .map(|(_, gap_k, _)| *gap_k)
//         .fold(f64::MIN, f64::max);
//     let min_gap_k = data
//         .iter()
//         .map(|(_, gap_k, _)| *gap_k)
//         .fold(f64::MAX, f64::min);
//     let max_k = data.iter().map(|(k, _, _)| *k).max().unwrap_or(0);

//     let mut chart = ChartBuilder::on(&root_area)
//         .caption("Gap Statistic Plot", ("sans-serif", 30))
//         .margin(5)
//         .x_label_area_size(30)
//         .y_label_area_size(30)
//         .build_cartesian_2d(0..max_k, min_gap_k..max_gap_k)
//         .unwrap();

//     chart.configure_mesh().draw().unwrap();

//     for (k, gap_k, s_k) in data {
//         chart
//             .draw_series(PointSeries::of_element(
//                 vec![(k, gap_k)],
//                 5,
//                 &RED,
//                 &|coord, size, style| {
//                     EmptyElement::at(coord) + Circle::new((0, 0), size, style.filled())
//                 },
//             ))
//             .unwrap();

//         // Drawing error bars
//         chart
//             .draw_series(LineSeries::new(
//                 vec![(k, gap_k - s_k), (k, gap_k + s_k)],
//                 &BLACK,
//             ))
//             .unwrap();
//     }

//     root_area.present().unwrap();
// }