mecomp_analysis/
clustering.rs

//! This module contains helpers that wrap the clustering algorithms from the `linfa-clustering`
//! crate (k-means and Gaussian mixture models) to perform clustering on the data without
//! having to choose an exact number of clusters.
//!
//! Instead, you provide the maximum number of clusters you want to try, and we'll
//! use one of a range of methods to determine the optimal number of clusters.
//!
//! # References:
//!
//! - The gap statistic [R. Tibshirani, G. Walther, and T. Hastie (Stanford University, 2001)](https://hastie.su.domains/Papers/gap.pdf)
//! - The Davies-Bouldin index [wikipedia](https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index)

use linfa::prelude::*;
use linfa_clustering::{GaussianMixtureModel, KMeans};
use linfa_nn::distance::{Distance, L2Dist};
use linfa_reduction::Pca;
use linfa_tsne::TSneParams;
use log::{debug, info};
use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis};
use ndarray_rand::RandomExt;
use rand::distributions::Uniform;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use statrs::statistics::Statistics;

use crate::{
    Analysis, Feature, NUMBER_FEATURES,
    errors::{ClusteringError, ProjectionError},
};

/// A newtype wrapper holding analysis data as a 2D array with one row per sample.
pub struct AnalysisArray(pub(crate) Array2<Feature>);

impl From<Vec<Analysis>> for AnalysisArray {
    #[inline]
    fn from(data: Vec<Analysis>) -> Self {
        let shape = (data.len(), NUMBER_FEATURES);
        // guard against indexing into an empty vector in debug builds
        debug_assert!(data.is_empty() || data[0].inner().len() == NUMBER_FEATURES);

        Self(
            Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
                .expect("Failed to convert to array, shape mismatch"),
        )
    }
}

impl From<Vec<[Feature; NUMBER_FEATURES]>> for AnalysisArray {
    #[inline]
    fn from(data: Vec<[Feature; NUMBER_FEATURES]>) -> Self {
        let shape = (data.len(), NUMBER_FEATURES);
        // the array type already guarantees the row length; guard the empty case in debug builds
        debug_assert!(data.is_empty() || data[0].len() == NUMBER_FEATURES);

        Self(
            Array2::from_shape_vec(shape, data.into_iter().flatten().collect())
                .expect("Failed to convert to array, shape mismatch"),
        )
    }
}

#[derive(Clone, Copy, Debug)]
#[allow(clippy::module_name_repetitions)]
pub enum ClusteringMethod {
    KMeans,
    GaussianMixtureModel,
}

impl ClusteringMethod {
    /// Fit the clustering method to the samples and get the labels.
    ///
    /// Panics if the underlying model fails to fit the data.
    #[must_use]
    fn fit(self, k: usize, samples: &Array2<Feature>) -> Array1<usize> {
        match self {
            Self::KMeans => {
                let model = KMeans::params(k)
                    // .max_n_iterations(MAX_ITERATIONS)
                    .fit(&Dataset::from(samples.clone()))
                    .unwrap();
                model.predict(samples)
            }
            Self::GaussianMixtureModel => {
                let model = GaussianMixtureModel::params(k)
                    .init_method(linfa_clustering::GmmInitMethod::KMeans)
                    .n_runs(10)
                    .fit(&Dataset::from(samples.clone()))
                    .unwrap();
                model.predict(samples)
            }
        }
    }
}

#[derive(Clone, Copy, Debug)]
pub enum KOptimal {
    GapStatistic {
        /// The number of reference datasets to generate
        b: u32,
    },
    DaviesBouldin,
}

/// Should the data be projected into a lower-dimensional space before clustering, and if so, how?
#[derive(Clone, Copy, Debug, Default)]
pub enum ProjectionMethod {
    /// Use t-SNE to project the data into a lower-dimensional space
    TSne,
    /// Use PCA to project the data into a lower-dimensional space
    Pca,
    /// Don't project the data
    #[default]
    None,
}

impl ProjectionMethod {
    /// Project the data into a lower-dimensional space
    ///
    /// # Errors
    ///
    /// Will return an error if there was an error projecting the data into a lower-dimensional space
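    ///
    /// A minimal usage sketch (assumes a `samples: AnalysisArray` binding is in scope;
    /// the choice of `Pca` here is illustrative):
    ///
    /// ```ignore
    /// let embeddings = ProjectionMethod::Pca.project(samples)?;
    /// assert_eq!(embeddings.ncols(), EMBEDDING_SIZE);
    /// ```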
    #[inline]
    pub fn project(self, samples: AnalysisArray) -> Result<Array2<Feature>, ProjectionError> {
        let result = match self {
            Self::TSne => {
                let nrecords = samples.0.nrows();
                // first use the t-SNE algorithm to project the data into a lower-dimensional space
                debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using t-SNE");
                #[allow(clippy::cast_precision_loss)]
                let mut embeddings = TSneParams::embedding_size(EMBEDDING_SIZE)
                    .perplexity(f64::max(samples.0.nrows() as f64 / 20., 5.))
                    .approx_threshold(0.5)
                    .transform(samples.0)?;
                debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);

                // normalize the embeddings so each dimension is between -1 and 1
                debug!("Normalizing embeddings");
                normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
                embeddings
            }
            Self::Pca => {
                let nrecords = samples.0.nrows();
                // use the PCA algorithm to project the data into a lower-dimensional space
                debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using PCA");
                let data = Dataset::from(samples.0);
                let pca: Pca<f64> = Pca::params(EMBEDDING_SIZE).whiten(true).fit(&data)?;
                let mut embeddings = pca.predict(&data);
                debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);

                // normalize the embeddings so each dimension is between -1 and 1
                debug!("Normalizing embeddings");
                normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
                embeddings
            }
            Self::None => {
                debug!("Using original data as embeddings");
                samples.0
            }
        };
        debug!("Embeddings shape: {:?}", result.shape());
        Ok(result)
    }
}

// Normalize the embeddings so each column lies between -1.0 and 1.0, in-place.
// Pass the embedding size as a const generic to enable more compiler optimizations.
fn normalize_embeddings_inplace<const SIZE: usize>(embeddings: &mut Array2<f64>) {
    for i in 0..SIZE {
        let min = embeddings.column(i).min();
        let max = embeddings.column(i).max();
        let range = max - min;
        embeddings
            .column_mut(i)
            // map v from [min, max] onto [0, 1], then rescale onto [-1, 1]
            .mapv_inplace(|v| ((v - min) / range).mul_add(2., -1.));
    }
}

/// Dimensionality that the t-SNE and PCA projection methods aim to project the data into:
/// the base-2 logarithm of the number of features, clamped below at 2.
const EMBEDDING_SIZE: usize = {
    let log2 = usize::ilog2(NUMBER_FEATURES) as usize;
    if log2 < 2 { 2 } else { log2 }
};
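
/// A typestate helper that walks the clustering pipeline through
/// `EntryPoint` -> `NotInitialized` -> `Initialized` -> `Finished` states.
///
/// A minimal usage sketch (the `analyses` binding and the parameter values here are
/// illustrative assumptions, not defaults of this module):
///
/// ```ignore
/// // analyses: Vec<Analysis>
/// let helper = ClusteringHelper::new(
///     analyses.clone().into(),
///     16,                               // k_max
///     KOptimal::GapStatistic { b: 10 }, // number of reference datasets
///     ClusteringMethod::KMeans,
///     ProjectionMethod::default(),      // no projection
/// )?;
/// let finished = helper.initialize()?.cluster();
/// let clusters: Vec<Vec<Analysis>> = finished.extract_analysis_clusters(analyses);
/// ```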
#[allow(clippy::module_name_repetitions)]
pub struct ClusteringHelper<S>
where
    S: Sized,
{
    state: S,
}

/// Marker state: no data has been provided yet
pub struct EntryPoint;
/// State: data is loaded, but the number of clusters has not been chosen yet
pub struct NotInitialized {
    /// The embeddings of our input, as a Nx`EMBEDDING_SIZE` array
    embeddings: Array2<Feature>,
    pub k_max: usize,
    pub optimizer: KOptimal,
    pub clustering_method: ClusteringMethod,
}
/// State: the optimal number of clusters has been determined
pub struct Initialized {
    /// The embeddings of our input, as a Nx`EMBEDDING_SIZE` array
    embeddings: Array2<Feature>,
    pub k: usize,
    pub clustering_method: ClusteringMethod,
}
/// State: clustering is complete
pub struct Finished {
    /// The labelings of the samples, as a Nx1 array.
    /// Each element is the cluster that the corresponding sample belongs to.
    labels: Array1<usize>,
    pub k: usize,
}

/// Functions available for the `EntryPoint` state
impl ClusteringHelper<EntryPoint> {
    /// Create a new `ClusteringHelper` object
    ///
    /// # Errors
    ///
    /// Will return an error if there are too few samples (15 or fewer), or if there was
    /// an error projecting the data into a lower-dimensional space
    #[allow(clippy::missing_inline_in_public_items)]
    pub fn new(
        samples: AnalysisArray,
        k_max: usize,
        optimizer: KOptimal,
        clustering_method: ClusteringMethod,
        projection_method: ProjectionMethod,
    ) -> Result<ClusteringHelper<NotInitialized>, ClusteringError> {
        if samples.0.nrows() <= 15 {
            return Err(ClusteringError::SmallLibrary);
        }

        // project the data into a lower-dimensional space
        let embeddings = projection_method.project(samples)?;

        Ok(ClusteringHelper {
            state: NotInitialized {
                embeddings,
                k_max,
                optimizer,
                clustering_method,
            },
        })
    }
}

/// Functions available for the `NotInitialized` state
impl ClusteringHelper<NotInitialized> {
    /// Initialize the `ClusteringHelper` object by determining the optimal number of clusters
    ///
    /// # Errors
    ///
    /// Will return an error if there was an error calculating the optimal number of clusters
    #[inline]
    pub fn initialize(self) -> Result<ClusteringHelper<Initialized>, ClusteringError> {
        let k = self.get_optimal_k()?;
        Ok(ClusteringHelper {
            state: Initialized {
                embeddings: self.state.embeddings,
                k,
                clustering_method: self.state.clustering_method,
            },
        })
    }

    fn get_optimal_k(&self) -> Result<usize, ClusteringError> {
        match self.state.optimizer {
            KOptimal::GapStatistic { b } => self.get_optimal_k_gap_statistic(b),
            KOptimal::DaviesBouldin => self.get_optimal_k_davies_bouldin(),
        }
    }

    /// Get the optimal number of clusters using the gap statistic
    ///
    /// # References:
    ///
    /// - [R. Tibshirani, G. Walther, and T. Hastie (Stanford University, 2001)](https://hastie.su.domains/Papers/gap.pdf)
    ///
    /// # Algorithm:
    ///
    /// 1. Cluster the observed data, varying the number of clusters from k = 1, …, `k_max`, and compute the corresponding total within intra-cluster variation `W_k`.
    /// 2. Generate B reference data sets with a random uniform distribution. Cluster each of these reference data sets with varying number of clusters k = 1, …, `k_max`,
    ///    and compute the corresponding total within intra-cluster variation `W^*_{kb}`.
    /// 3. Compute the estimated gap statistic as the deviation of the observed `W_k` value from its expected value under the null hypothesis:
    ///    `Gap(k) = (1/B) \sum_{b=1}^{B} \log(W^*_{kb}) - \log(W_k)`.
    ///    Compute also the standard deviation of the statistics.
    /// 4. Choose the number of clusters as the smallest value of k such that the gap statistic is within one standard deviation of the gap at k+1:
    ///    `Gap(k) >= Gap(k+1) - s_{k+1}`.
    ///
    /// Note: this implementation uses base-2 logarithms where the paper uses natural
    /// logarithms; the change of base rescales `Gap(k)` and `s_k` by the same positive
    /// constant, so the selected k is unaffected.
    fn get_optimal_k_gap_statistic(&self, b: u32) -> Result<usize, ClusteringError> {
        // our reference data sets
        let reference_data_sets =
            generate_reference_data_set(self.state.embeddings.view(), b as usize);

        let b = f64::from(b);

        let results = (1..=self.state.k_max)
            // for each k, cluster the data into k clusters
            .map(|k| {
                debug!("Fitting {:?} to embeddings with k={k}", self.state.clustering_method);
                let labels = self.state.clustering_method.fit(k, &self.state.embeddings);
                (k, labels)
            })
            // for each k, calculate the gap statistic, and the standard deviation of the statistics
            .map(|(k, labels)| {
                // calculate the within intra-cluster variation for the reference data sets
                debug!(
                    "Calculating within intra-cluster variation for reference data sets with k={k}"
                );
                let w_kb_log: Vec<_> = reference_data_sets
                    .par_iter()
                    .map(|ref_data| {
                        // cluster the reference data into k clusters
                        let ref_labels = self.state.clustering_method.fit(k, ref_data);
                        // calculate the within intra-cluster variation for the reference data
                        let ref_pairwise_distances =
                            calc_pairwise_distances(ref_data.view(), k, ref_labels.view());
                        calc_within_dispersion(ref_labels.view(), k, ref_pairwise_distances.view())
                            .log2()
                    })
                    .collect();

                // calculate the within intra-cluster variation for the observed data
                let pairwise_distances =
                    calc_pairwise_distances(self.state.embeddings.view(), k, labels.view());
                let w_k = calc_within_dispersion(labels.view(), k, pairwise_distances.view());

                // finally, calculate the gap statistic
                let w_kb_log_sum: f64 = w_kb_log.iter().sum();
                // original formula: l = (1 / B) * sum_b(log(W_kb))
                let l = b.recip() * w_kb_log_sum;
                // original formula: gap_k = (1 / B) * sum_b(log(W_kb)) - log(W_k)
                let gap_k = l - w_k.log2();
                // original formula: sd_k = [(1 / B) * sum_b((log(W_kb) - l)^2)]^0.5
                let standard_deviation = (b.recip()
                    * w_kb_log
                        .iter()
                        .map(|w_kb_log| (w_kb_log - l).powi(2))
                        .sum::<f64>())
                .sqrt();
                // original formula: s_k = sd_k * (1 + 1 / B)^0.5
                // calculated differently to minimize rounding errors
                let s_k = standard_deviation * (1.0 + b.recip()).sqrt();

                (k, gap_k, s_k)
            });

        // // plot the gap_k (whisker with s_k) w.r.t. k
        // #[cfg(feature = "plot_gap")]
        // plot_gap_statistic(results.clone().collect::<Vec<_>>());

        // finally, consume the results iterator until we find the optimal k,
        // i.e. the first k where Gap(k) >= Gap(k+1) - s_{k+1}
        let (mut optimal_k, mut gap_k_minus_one) = (None, None);
        for (k, gap_k, s_k) in results {
            info!("k: {k}, gap_k: {gap_k}, s_k: {s_k}");

            if let Some(gap_k_minus_one) = gap_k_minus_one
                && gap_k_minus_one >= gap_k - s_k
            {
                info!("Optimal k found: {}", k - 1);
                optimal_k = Some(k - 1);
                break;
            }

            gap_k_minus_one = Some(gap_k);
        }

        optimal_k.ok_or(ClusteringError::OptimalKNotFound(self.state.k_max))
    }

    fn get_optimal_k_davies_bouldin(&self) -> Result<usize, ClusteringError> {
        todo!();
    }
}

/// Convert a vector of Analyses into a 2D array
///
/// # Panics
///
/// Will panic if the shape of the data does not match the number of features, which should never happen
#[must_use]
#[inline]
pub fn convert_to_array(data: Vec<Analysis>) -> AnalysisArray {
    // Convert vector to Array
    let shape = (data.len(), NUMBER_FEATURES);
    debug_assert!(data.is_empty() || data[0].inner().len() == NUMBER_FEATURES);

    AnalysisArray(
        Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
            .expect("Failed to convert to array, shape mismatch"),
    )
}

/// Generate B reference data sets with a random uniform distribution
///
/// (excerpt from the reference paper)
/// """
/// We consider two choices for the reference distribution:
///
/// 1. generate each reference feature uniformly over the range of the observed values for that feature.
/// 2. generate the reference features from a uniform distribution over a box aligned with the
///    principal components of the data.
///    In detail, if X is our n by p data matrix, we assume that the columns have mean 0 and compute
///    the singular value decomposition X = UDV^T. We transform via X' = XV and then draw uniform features Z'
///    over the ranges of the columns of X', as in method (1) above.
///    Finally, we back-transform via Z = Z'V^T to give reference data Z.
///
/// Method (1) has the advantage of simplicity. Method (2) takes into account the shape of the
/// data distribution and makes the procedure rotationally invariant, as long as the
/// clustering method itself is invariant.
/// """
///
/// For now, we use method (1) as it is simpler to implement;
/// we know that our data is already normalized, and the ordering
/// of features is significant, meaning we can't rotate the data anyway.
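///
/// For example, a feature column whose observed values span [10.0, 30.0] has each of its
/// reference values drawn from `Uniform::new(10.0, 30.0)` (see `test_generate_reference_data_set`).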
fn generate_reference_data_set(samples: ArrayView2<Feature>, b: usize) -> Vec<Array2<f64>> {
    let mut reference_data_sets = Vec::with_capacity(b);
    for _ in 0..b {
        reference_data_sets.push(generate_ref_single(samples));
    }

    reference_data_sets
}

fn generate_ref_single(samples: ArrayView2<Feature>) -> Array2<f64> {
    // draw each reference feature uniformly over the observed range of that feature
    // (note: `Uniform::new` panics if a feature column is constant, i.e. min == max)
    let feature_distributions = samples
        .axis_iter(Axis(1))
        .map(|feature| Array::random(feature.dim(), Uniform::new(feature.min(), feature.max())))
        .collect::<Vec<_>>();
    let feature_dists_views = feature_distributions
        .iter()
        .map(ndarray::ArrayBase::view)
        .collect::<Vec<_>>();
    // stack the per-feature rows, then transpose back to (samples, features)
    ndarray::stack(Axis(0), &feature_dists_views)
        .unwrap()
        .t()
        .to_owned()
}

/// Calculate `W_k`, the within intra-cluster variation for the given clustering
///
/// `W_k = \sum_{r=1}^{k} \frac{D_r}{2 n_r}`
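///
/// Worked example (mirroring `test_calc_within_dispersion` below): two clusters of two
/// samples each with `D = [1.0, 2.0]` give `W_k = 1.0/(2*2) + 2.0/(2*2) = 0.75`.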
fn calc_within_dispersion(
    labels: ArrayView1<usize>,
    k: usize,
    pairwise_distances: ArrayView1<Feature>,
) -> Feature {
    debug_assert_eq!(k, labels.iter().max().unwrap() + 1);

    // we first need to convert our list of labels into a list of the number of samples in each cluster
    let counts = labels.iter().fold(vec![0u32; k], |mut counts, &label| {
        counts[label] += 1;
        counts
    });
    // then, we calculate the within intra-cluster variation
    counts
        .iter()
        .zip(pairwise_distances.iter())
        .map(|(&count, distance)| (2.0 * f64::from(count)).recip() * distance)
        .sum()
}

/// Calculate the `D_r` array, the sum of the pairwise distances in cluster r, for all clusters
/// in the given clustering. Following the reference paper, `D_r` sums over ordered pairs, so
/// each unordered pair of samples contributes its distance twice.
///
/// # Arguments
///
/// - `samples`: The samples in the dataset
/// - `k`: The number of clusters
/// - `labels`: The cluster labels for each sample
fn calc_pairwise_distances(
    samples: ArrayView2<Feature>,
    k: usize,
    labels: ArrayView1<usize>,
) -> Array1<Feature> {
    debug_assert_eq!(
        samples.nrows(),
        labels.len(),
        "Samples and labels must have the same length"
    );
    debug_assert_eq!(
        k,
        labels.iter().max().unwrap() + 1,
        "Labels must be in the range 0..k"
    );

    // for each cluster, calculate the sum of the pairwise distances between samples in that cluster
    let mut distances = Array1::zeros(k);
    let mut clusters = vec![Vec::new(); k];
    // build clusters
    for (sample, label) in samples.outer_iter().zip(labels.iter()) {
        clusters[*label].push(sample);
    }
    // calculate pairwise dist. within each cluster
    for (r, cluster) in clusters.iter().enumerate() {
        // sum the distance over each unordered pair once
        // (`saturating_sub` avoids an underflow when a cluster is empty)
        let mut pairwise_dists = 0.;
        for i in 0..cluster.len().saturating_sub(1) {
            let a = cluster[i];
            let rest = &cluster[i + 1..];
            for &b in rest {
                pairwise_dists += L2Dist.distance(a, b);
            }
        }
        // double the sum so each pair is counted in both orders, per the paper's `D_r`
        distances[r] = 2.0 * pairwise_dists;
    }
    distances
}

/// Functions available for the `Initialized` state
impl ClusteringHelper<Initialized> {
    /// Cluster the data into k clusters
    ///
    /// # Panics
    ///
    /// Will panic if the underlying clustering model fails to fit
    #[must_use]
    #[inline]
    pub fn cluster(self) -> ClusteringHelper<Finished> {
        let labels = self
            .state
            .clustering_method
            .fit(self.state.k, &self.state.embeddings);

        ClusteringHelper {
            state: Finished {
                labels,
                k: self.state.k,
            },
        }
    }
}

/// Functions available for the `Finished` state
impl ClusteringHelper<Finished> {
    /// Use the labels to reorganize the provided samples into clusters
    #[must_use]
    #[inline]
    pub fn extract_analysis_clusters<T: Clone>(&self, samples: Vec<T>) -> Vec<Vec<T>> {
        let mut clusters = vec![Vec::new(); self.state.k];

        for (sample, &label) in samples.into_iter().zip(self.state.labels.iter()) {
            clusters[label].push(sample);
        }

        clusters
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, s};
    use ndarray_rand::rand_distr::StandardNormal;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    #[test]
    fn test_generate_reference_data_set() {
        let data = arr2(&[[10.0, -10.0], [20.0, -20.0], [30.0, -30.0]]);

        let ref_data = generate_ref_single(data.view());

        // first column: all values between 10.0 and 30.0
        assert!(
            ref_data
                .slice(s![.., 0])
                .iter()
                .all(|v| *v >= 10.0 && *v <= 30.0)
        );

        // second column: all values between -30.0 and -10.0
        assert!(
            ref_data
                .slice(s![.., 1])
                .iter()
                .all(|v| *v <= -10.0 && *v >= -30.0)
        );

        // check that the shape is correct
        assert_eq!(ref_data.shape(), data.shape());

        // check that the data is not the same as the original data
        assert_ne!(ref_data, data);
    }

    #[test]
    fn test_pairwise_distances() {
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 0, 1, 1]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[1]
        );

        // each cluster now contains one pair at distance 1.0, counted twice (ordered pairs)
        let samples = arr2(&[[1.0, 2.0], [1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[1]
        );
    }

    #[test]
    fn test_convert_to_array() {
        let data = vec![
            Analysis::new([1.0; NUMBER_FEATURES]),
            Analysis::new([2.0; NUMBER_FEATURES]),
            Analysis::new([3.0; NUMBER_FEATURES]),
        ];

        let array = convert_to_array(data);

        assert_eq!(array.0.shape(), &[3, NUMBER_FEATURES]);
        assert!(
            f64::EPSILON > (array.0[[0, 0]] - 1.0).abs(),
            "{} != 1.0",
            array.0[[0, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[1, 0]] - 2.0).abs(),
            "{} != 2.0",
            array.0[[1, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[2, 0]] - 3.0).abs(),
            "{} != 3.0",
            array.0[[2, 0]]
        );

        // check that axis iteration works how we expect
        // axis 0
        let mut iter = array.0.axis_iter(Axis(0));
        assert_eq!(iter.next().unwrap().to_vec(), vec![1.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![2.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![3.0; NUMBER_FEATURES]);
        // axis 1
        for column in array.0.axis_iter(Axis(1)) {
            assert_eq!(column.to_vec(), vec![1.0, 2.0, 3.0]);
        }
    }

    #[test]
    fn test_calc_within_dispersion() {
        let labels = arr1(&[0, 1, 0, 1]);
        let pairwise_distances = arr1(&[1.0, 2.0]);
        let result = calc_within_dispersion(labels.view(), 2, pairwise_distances.view());

        // `W_k = \sum_{r=1}^{k} \frac{D_r}{2 n_r}` = 1/4 * 1.0 + 1/4 * 2.0 = 0.25 + 0.5 = 0.75
        assert!(f64::EPSILON > (result - 0.75).abs(), "{result} != 0.75");
    }

    #[rstest]
    #[case::project_none(ProjectionMethod::None, NUMBER_FEATURES)]
    #[case::project_tsne(ProjectionMethod::TSne, EMBEDDING_SIZE)]
    #[case::project_pca(ProjectionMethod::Pca, EMBEDDING_SIZE)]
    fn test_project(
        #[case] projection_method: ProjectionMethod,
        #[case] expected_embedding_size: usize,
    ) {
        // generate 100 random samples; we use a normal distribution because with a uniform
        // distribution the data has no real "principal components" and PCA will not work as
        // expected, since almost all the eigenvalues will fall below the cutoff
        let mut samples = Array2::random((100, NUMBER_FEATURES), StandardNormal);
        normalize_embeddings_inplace::<NUMBER_FEATURES>(&mut samples);
        let samples = AnalysisArray(samples);

        let result = projection_method.project(samples).unwrap();

        // ensure embeddings are the correct shape
        assert_eq!(result.shape(), &[100, expected_embedding_size]);

        // ensure the data is normalized
        for i in 0..expected_embedding_size {
            let min = result.column(i).min();
            let max = result.column(i).max();
            assert!(
                f64::EPSILON > (min + 1.0).abs(),
                "Min value of column {i} is not -1.0: {min}",
            );
            assert!(
                f64::EPSILON > (max - 1.0).abs(),
                "Max value of column {i} is not 1.0: {max}",
            );
        }
    }
}

// #[cfg(feature = "plot_gap")]
// fn plot_gap_statistic(data: Vec<(usize, f64, f64)>) {
//     use plotters::prelude::*;

//     // Assuming data is a Vec<(usize, f64, f64)> of (k, gap_k, s_k)
//     let root_area = BitMapBackend::new("gap_statistic_plot.png", (640, 480)).into_drawing_area();
//     root_area.fill(&WHITE).unwrap();

//     let max_gap_k = data
//         .iter()
//         .map(|(_, gap_k, _)| *gap_k)
//         .fold(f64::MIN, f64::max);
//     let min_gap_k = data
//         .iter()
//         .map(|(_, gap_k, _)| *gap_k)
//         .fold(f64::MAX, f64::min);
//     let max_k = data.iter().map(|(k, _, _)| *k).max().unwrap_or(0);

//     let mut chart = ChartBuilder::on(&root_area)
//         .caption("Gap Statistic Plot", ("sans-serif", 30))
//         .margin(5)
//         .x_label_area_size(30)
//         .y_label_area_size(30)
//         .build_cartesian_2d(0..max_k, min_gap_k..max_gap_k)
//         .unwrap();

//     chart.configure_mesh().draw().unwrap();

//     for (k, gap_k, s_k) in data {
//         chart
//             .draw_series(PointSeries::of_element(
//                 vec![(k, gap_k)],
//                 5,
//                 &RED,
//                 &|coord, size, style| {
//                     EmptyElement::at(coord) + Circle::new((0, 0), size, style.filled())
//                 },
//             ))
//             .unwrap();

//         // Drawing error bars
//         chart
//             .draw_series(LineSeries::new(
//                 vec![(k, gap_k - s_k), (k, gap_k + s_k)],
//                 &BLACK,
//             ))
//             .unwrap();
//     }

//     root_area.present().unwrap();
// }