mecomp_analysis/
clustering.rs

//! This module contains helpers that wrap the `linfa` clustering algorithms (k-means and
//! Gaussian mixture models) to perform clustering on the data without having to choose an
//! exact number of clusters.
//!
//! Instead, you provide the maximum number of clusters you want to try, and we'll
//! use one of a range of methods to determine the optimal number of clusters.
//!
//! # References:
//!
//! - The gap statistic [R. Tibshirani, G. Walther, and T. Hastie (Stanford University, 2001)](https://hastie.su.domains/Papers/gap.pdf)
//! - The Davies-Bouldin index [wikipedia](https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index)
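//!
//! # Example
//!
//! A sketch of the intended flow (not a doctest; `analyses` is a hypothetical
//! `Vec<Analysis>` and the parameter values are illustrative):
//!
//! ```ignore
//! let helper = ClusteringHelper::new(
//!     AnalysisArray::from(analyses.clone()),
//!     30,                               // k_max
//!     KOptimal::GapStatistic { b: 10 }, // number of reference data sets
//!     ClusteringMethod::KMeans,
//!     ProjectionMethod::TSne,
//! )?;
//! let clusters = helper.initialize()?.cluster().extract_analysis_clusters(analyses);
//! ```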

use linfa::prelude::*;
use linfa_clustering::{GaussianMixtureModel, KMeans};
use linfa_nn::distance::{Distance, L2Dist};
use linfa_reduction::Pca;
use linfa_tsne::TSneParams;
use log::{debug, info};
use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis};
use ndarray_rand::RandomExt;
use rand::distributions::Uniform;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use statrs::statistics::Statistics;

use crate::{
    Analysis, Feature, NUMBER_FEATURES,
    errors::{ClusteringError, ProjectionError},
};

pub struct AnalysisArray(pub(crate) Array2<Feature>);

impl From<Vec<Analysis>> for AnalysisArray {
    #[inline]
    fn from(data: Vec<Analysis>) -> Self {
        let shape = (data.len(), NUMBER_FEATURES);
        debug_assert_eq!(shape, (data.len(), data[0].inner().len()));

        Self(
            Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
                .expect("Failed to convert to array, shape mismatch"),
        )
    }
}

impl From<Vec<[Feature; NUMBER_FEATURES]>> for AnalysisArray {
    #[inline]
    fn from(data: Vec<[Feature; NUMBER_FEATURES]>) -> Self {
        let shape = (data.len(), NUMBER_FEATURES);
        debug_assert_eq!(shape, (data.len(), data[0].len()));

        Self(
            Array2::from_shape_vec(shape, data.into_iter().flatten().collect())
                .expect("Failed to convert to array, shape mismatch"),
        )
    }
}

#[derive(Clone, Copy, Debug)]
#[allow(clippy::module_name_repetitions)]
pub enum ClusteringMethod {
    KMeans,
    GaussianMixtureModel,
}

impl ClusteringMethod {
    /// Fit the clustering method to the samples and get the labels
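    ///
    /// # Panics
    ///
    /// Panics if the underlying `linfa` fit fails (its `Result` is unwrapped here).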
    #[must_use]
    fn fit(self, k: usize, samples: &Array2<Feature>) -> Array1<usize> {
        match self {
            Self::KMeans => {
                let model = KMeans::params(k)
                    // .max_n_iterations(MAX_ITERATIONS)
                    .fit(&Dataset::from(samples.clone()))
                    .unwrap();
                model.predict(samples)
            }
            Self::GaussianMixtureModel => {
                let model = GaussianMixtureModel::params(k)
                    .init_method(linfa_clustering::GmmInitMethod::KMeans)
                    .n_runs(10)
                    .fit(&Dataset::from(samples.clone()))
                    .unwrap();
                model.predict(samples)
            }
        }
    }
}

#[derive(Clone, Copy, Debug)]
pub enum KOptimal {
    GapStatistic {
        /// The number of reference datasets to generate
        b: usize,
    },
    DaviesBouldin,
}

/// Should the data be projected into a lower-dimensional space before clustering, and if so, how?
#[derive(Clone, Copy, Debug, Default)]
pub enum ProjectionMethod {
    /// Use t-SNE to project the data into a lower-dimensional space
    TSne,
    /// Use PCA to project the data into a lower-dimensional space
    Pca,
    /// Don't project the data
    #[default]
    None,
}

impl ProjectionMethod {
    /// Project the data into a lower-dimensional space
    ///
    /// # Errors
    ///
    /// Will return an error if there was an error projecting the data into a lower-dimensional space
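    ///
    /// # Example
    ///
    /// An illustrative sketch (not a doctest; `analyses` is a hypothetical `Vec<Analysis>`):
    ///
    /// ```ignore
    /// let embeddings = ProjectionMethod::Pca.project(AnalysisArray::from(analyses))?;
    /// assert_eq!(embeddings.ncols(), EMBEDDING_SIZE);
    /// ```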
    #[inline]
    pub fn project(self, samples: AnalysisArray) -> Result<Array2<Feature>, ProjectionError> {
        let result = match self {
            Self::TSne => {
                let nrecords = samples.0.nrows();
                // first use the t-SNE algorithm to project the data into a lower-dimensional space
                debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using t-SNE");
                #[allow(clippy::cast_precision_loss)]
                let mut embeddings = TSneParams::embedding_size(EMBEDDING_SIZE)
                    .perplexity(f64::max(samples.0.nrows() as f64 / 20., 5.))
                    .approx_threshold(0.5)
                    .transform(samples.0)?;
                debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);

                // normalize the embeddings so each dimension is between -1 and 1
                debug!("Normalizing embeddings");
                normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
                embeddings
            }
            Self::Pca => {
                let nrecords = samples.0.nrows();
                // use the PCA algorithm to project the data into a lower-dimensional space
                debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using PCA");
                let data = Dataset::from(samples.0);
                let pca: Pca<f64> = Pca::params(EMBEDDING_SIZE).whiten(true).fit(&data)?;
                let mut embeddings = pca.predict(&data);
                debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);

                // normalize the embeddings so each dimension is between -1 and 1
                debug!("Normalizing embeddings");
                normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
                embeddings
            }
            Self::None => {
                debug!("Using original data as embeddings");
                samples.0
            }
        };
        debug!("Embeddings shape: {:?}", result.shape());
        Ok(result)
    }
}

// Normalize the embeddings so that each dimension spans -1.0 to 1.0, in-place
// (the column minimum maps to -1.0 and the column maximum to 1.0).
// Pass the embedding size as a const generic to enable more compiler optimizations
fn normalize_embeddings_inplace<const SIZE: usize>(embeddings: &mut Array2<f64>) {
    for i in 0..SIZE {
        let min = embeddings.column(i).min();
        let max = embeddings.column(i).max();
        let range = max - min;
        embeddings
            .column_mut(i)
            .mapv_inplace(|v| ((v - min) / range).mul_add(2., -1.));
    }
}

/// Dimensionality that the t-SNE and PCA projection methods aim to project the data into:
/// the base-2 logarithm of the number of features, clamped to a minimum of 2.
const EMBEDDING_SIZE: usize = {
    let log2 = usize::ilog2(NUMBER_FEATURES) as usize;
    if log2 < 2 { 2 } else { log2 }
};

#[allow(clippy::module_name_repetitions)]
pub struct ClusteringHelper<S>
where
    S: Sized,
{
    state: S,
}

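// The following marker states form a typestate pipeline:
// `EntryPoint` --`new`--> `NotInitialized` --`initialize`--> `Initialized` --`cluster`--> `Finished`.
// Each transition consumes the helper, so steps cannot be skipped or repeated.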
pub struct EntryPoint;
pub struct NotInitialized {
    /// The embeddings of our input, as a Nx`EMBEDDING_SIZE` array
    embeddings: Array2<Feature>,
    pub k_max: usize,
    pub optimizer: KOptimal,
    pub clustering_method: ClusteringMethod,
}
pub struct Initialized {
    /// The embeddings of our input, as a Nx`EMBEDDING_SIZE` array
    embeddings: Array2<Feature>,
    pub k: usize,
    pub clustering_method: ClusteringMethod,
}
pub struct Finished {
    /// The labelings of the samples, as a Nx1 array.
    /// Each element is the cluster that the corresponding sample belongs to.
    labels: Array1<usize>,
    pub k: usize,
}

/// Functions available for the `EntryPoint` state
impl ClusteringHelper<EntryPoint> {
    /// Create a new `ClusteringHelper`
    ///
    /// # Errors
    ///
    /// Will return an error if there was an error projecting the data into a lower-dimensional space,
    /// or if the library is too small to cluster (15 samples or fewer)
    #[allow(clippy::missing_inline_in_public_items)]
    pub fn new(
        samples: AnalysisArray,
        k_max: usize,
        optimizer: KOptimal,
        clustering_method: ClusteringMethod,
        projection_method: ProjectionMethod,
    ) -> Result<ClusteringHelper<NotInitialized>, ClusteringError> {
        if samples.0.nrows() <= 15 {
            return Err(ClusteringError::SmallLibrary);
        }

        // project the data into a lower-dimensional space
        let embeddings = projection_method.project(samples)?;

        Ok(ClusteringHelper {
            state: NotInitialized {
                embeddings,
                k_max,
                optimizer,
                clustering_method,
            },
        })
    }
}

/// Functions available for `NotInitialized` state
impl ClusteringHelper<NotInitialized> {
    /// Initialize the `ClusteringHelper` by finding the optimal number of clusters
    ///
    /// # Errors
    ///
    /// Will return an error if there was an error calculating the optimal number of clusters
    #[inline]
    pub fn initialize(self) -> Result<ClusteringHelper<Initialized>, ClusteringError> {
        let k = self.get_optimal_k()?;
        Ok(ClusteringHelper {
            state: Initialized {
                embeddings: self.state.embeddings,
                k,
                clustering_method: self.state.clustering_method,
            },
        })
    }

    fn get_optimal_k(&self) -> Result<usize, ClusteringError> {
        match self.state.optimizer {
            KOptimal::GapStatistic { b } => self.get_optimal_k_gap_statistic(b),
            KOptimal::DaviesBouldin => self.get_optimal_k_davies_bouldin(),
        }
    }

    /// Get the optimal number of clusters using the gap statistic
    ///
    /// # References:
    ///
    /// - [R. Tibshirani, G. Walther, and T. Hastie (Stanford University, 2001)](https://hastie.su.domains/Papers/gap.pdf)
    ///
    /// # Algorithm:
    ///
    /// 1. Cluster the observed data, varying the number of clusters from k = 1, …, kmax, and compute the corresponding total within intra-cluster variation `W_k`.
    /// 2. Generate B reference data sets with a random uniform distribution. Cluster each of these reference data sets with varying number of clusters k = 1, …, kmax,
    ///    and compute the corresponding total within intra-cluster variation `W_{kb}`.
    /// 3. Compute the estimated gap statistic as the deviation of the observed `W_k` value from its expected value under the null hypothesis:
    ///    `Gap(k) = (1/B) \sum_{b=1}^{B} \log(W^*_{kb}) − \log(W_k)`.
    ///    Compute also the standard deviation of the statistics.
    /// 4. Choose the number of clusters as the smallest value of k such that the gap statistic is within one standard deviation of the gap at k+1:
    ///    `Gap(k) ≥ Gap(k + 1) − s_{k + 1}`.
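    ///
    /// Note: this implementation uses base-2 logarithms rather than the paper's natural
    /// logarithms; since the base is applied consistently to both `W_k` and `W_{kb}`, the
    /// chosen k is unaffected. For example (hypothetical numbers), with `W_k = 4`, `B = 3`,
    /// and all three reference dispersions equal to 8: `l = log2(8) = 3`,
    /// `Gap(k) = 3 − log2(4) = 1`, and `s_k = 0`.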
    fn get_optimal_k_gap_statistic(&self, b: usize) -> Result<usize, ClusteringError> {
        // our reference data sets
        let reference_data_sets = generate_reference_data_set(self.state.embeddings.view(), b);

        let results = (1..=self.state.k_max)
            // for each k, cluster the data into k clusters
            .map(|k| {
                debug!("Fitting k-means to embeddings with k={k}");
                let labels = self.state.clustering_method.fit(k, &self.state.embeddings);
                (k, labels)
            })
            // for each k, calculate the gap statistic and the standard deviation of the statistics
            .map(|(k, labels)| {
                // first, we calculate the within intra-cluster variation for the observed data
                let pairwise_distances =
                    calc_pairwise_distances(self.state.embeddings.view(), k, labels.view());
                let w_k = calc_within_dispersion(labels.view(), k, pairwise_distances.view());

                // then, we calculate the within intra-cluster variation for the reference data sets
                debug!(
                    "Calculating within intra-cluster variation for reference data sets with k={k}"
                );
                // collect into a Vec so the (expensive) reference clusterings run only once,
                // even though the values are consumed twice below
                let w_kb: Vec<Feature> = reference_data_sets
                    .par_iter()
                    .map(|ref_data| {
                        // cluster the reference data into k clusters
                        let ref_labels = self.state.clustering_method.fit(k, ref_data);
                        // calculate the within intra-cluster variation for the reference data
                        let ref_pairwise_distances =
                            calc_pairwise_distances(ref_data.view(), k, ref_labels.view());
                        calc_within_dispersion(ref_labels.view(), k, ref_pairwise_distances.view())
                    })
                    .collect();

                // finally, we calculate the gap statistic
                let w_kb_log_sum = w_kb.iter().copied().map(f64::log2).sum::<f64>();
                // original formula: l = (1 / B) * sum_b(log(W_kb))
                #[allow(clippy::cast_precision_loss)]
                let l = (1.0 / b as f64) * w_kb_log_sum;
                // original formula: gap_k = (1 / B) * sum_b(log(W_kb)) - log(W_k)
                let gap_k = l - w_k.log2();
                // original formula: sd_k = [(1 / B) * sum_b((log(W_kb) - l)^2)]^0.5
                #[allow(clippy::cast_precision_loss)]
                let standard_deviation = ((1.0 / b as f64)
                    * w_kb.iter().map(|w| (w.log2() - l).powi(2)).sum::<f64>())
                .sqrt();
                // original formula: s_k = sd_k * (1 + 1 / B)^0.5
                // calculated differently to minimize rounding errors
                #[allow(clippy::cast_precision_loss)]
                let s_k = standard_deviation * (1.0 + 1.0 / b as f64).sqrt();

                (k, gap_k, s_k)
            });

        // // plot the gap_k (whisker with s_k) w.r.t. k
        // #[cfg(feature = "plot_gap")]
        // plot_gap_statistic(results.clone().collect::<Vec<_>>());

        // finally, we go over the iterator to find the optimal k
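        // per step 4 above: stop at the first k where Gap(k-1) >= Gap(k) - s_k;
        // that k-1 is the optimal cluster count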
        let (mut optimal_k, mut gap_k_minus_one) = (None, None);

        for (k, gap_k, s_k) in results {
            info!("k: {k}, gap_k: {gap_k}, s_k: {s_k}");

            if let Some(gap_k_minus_one) = gap_k_minus_one {
                if gap_k_minus_one >= gap_k - s_k {
                    info!("Optimal k found: {}", k - 1);
                    optimal_k = Some(k - 1);
                    break;
                }
            }
            gap_k_minus_one = Some(gap_k);
        }

        optimal_k.ok_or(ClusteringError::OptimalKNotFound(self.state.k_max))
    }

    fn get_optimal_k_davies_bouldin(&self) -> Result<usize, ClusteringError> {
        todo!();
    }
}

/// Convert a vector of Analyses into a 2D array
///
/// # Panics
///
/// Will panic if the shape of the data does not match the number of features; this should never happen
#[must_use]
#[inline]
pub fn convert_to_array(data: Vec<Analysis>) -> AnalysisArray {
    // delegate to the `From<Vec<Analysis>>` impl above, which performs exactly this conversion
    AnalysisArray::from(data)
}

/// Generate B reference data sets with a random uniform distribution
///
/// (excerpt from the reference paper)
/// """
/// We consider two choices for the reference distribution:
///
/// 1. generate each reference feature uniformly over the range of the observed values for that feature.
/// 2. generate the reference features from a uniform distribution over a box aligned with the
///    principal components of the data.
///    In detail, if X is our n by p data matrix, we assume that the columns have mean 0 and compute
///    the singular value decomposition X = UDV^T. We transform via X' = XV and then draw uniform features Z'
///    over the ranges of the columns of X', as in method (1) above.
///    Finally, we back-transform via Z = Z'V^T to give reference data Z.
///
/// Method (1) has the advantage of simplicity. Method (2) takes into account the shape of the
/// data distribution and makes the procedure rotationally invariant, as long as the
/// clustering method itself is invariant.
/// """
///
/// For now, we use method (1): it is simpler to implement, our data is already normalized,
/// and the ordering of features is meaningful, so we can't rotate the data anyway.
fn generate_reference_data_set(samples: ArrayView2<Feature>, b: usize) -> Vec<Array2<f64>> {
    let mut reference_data_sets = Vec::with_capacity(b);
    for _ in 0..b {
        reference_data_sets.push(generate_ref_single(samples));
    }

    reference_data_sets
}
fn generate_ref_single(samples: ArrayView2<Feature>) -> Array2<f64> {
    // one uniformly-sampled column per feature
    let random_columns = samples
        .axis_iter(Axis(1))
        .map(|feature| Array::random(feature.dim(), Uniform::new(feature.min(), feature.max())))
        .collect::<Vec<_>>();
    let column_views = random_columns
        .iter()
        .map(ndarray::ArrayBase::view)
        .collect::<Vec<_>>();
    // stack the columns into a pxn array, then transpose back to nxp
    ndarray::stack(Axis(0), &column_views)
        .unwrap()
        .t()
        .to_owned()
}

/// Calculate `W_k`, the within intra-cluster variation for the given clustering
///
/// `W_k = \sum_{r=1}^{k} \frac{D_r}{2 n_r}`
fn calc_within_dispersion(
    labels: ArrayView1<usize>,
    k: usize,
    pairwise_distances: ArrayView1<Feature>,
) -> Feature {
    debug_assert_eq!(k, labels.iter().max().unwrap() + 1);

    // we first need to convert our list of labels into a list of the number of samples in each cluster
    let counts = labels.iter().fold(vec![0u32; k], |mut counts, &label| {
        counts[label] += 1;
        counts
    });
    // then, we calculate the within intra-cluster variation
    counts
        .iter()
        .zip(pairwise_distances.iter())
        .map(|(&count, distance)| (1. / (2.0 * f64::from(count))) * distance)
        .sum()
}

/// Calculate the `D_r` array: for each cluster r in the given clustering, the sum of the pairwise distances within cluster r
///
/// # Arguments
///
/// - `samples`: The samples in the dataset
/// - `k`: The number of clusters
/// - `labels`: The cluster labels for each sample
fn calc_pairwise_distances(
    samples: ArrayView2<Feature>,
    k: usize,
    labels: ArrayView1<usize>,
) -> Array1<Feature> {
    debug_assert_eq!(
        samples.nrows(),
        labels.len(),
        "Samples and labels must have the same length"
    );
    debug_assert_eq!(
        k,
        labels.iter().max().unwrap() + 1,
        "Labels must be in the range 0..k"
    );

    // for each cluster, calculate the sum of the pairwise distances between samples in that cluster
    let mut distances = Array1::zeros(k);
    let mut clusters = vec![Vec::new(); k];
    // build clusters
    for (sample, label) in samples.outer_iter().zip(labels.iter()) {
        clusters[*label].push(sample);
    }
    // calculate the pairwise distances within each cluster;
    // `D_r` sums over ordered pairs, so each unordered pair is counted twice
    for (cluster_idx, cluster) in clusters.iter().enumerate() {
        let mut pairwise_dists = 0.;
        // `saturating_sub` guards against an underflow panic when a cluster is empty
        for i in 0..cluster.len().saturating_sub(1) {
            let a = cluster[i];
            let rest = &cluster[i + 1..];
            for &b in rest {
                pairwise_dists += L2Dist.distance(a, b);
            }
        }
        distances[cluster_idx] = 2.0 * pairwise_dists;
    }
    distances
}

/// Functions available for `Initialized` state
impl ClusteringHelper<Initialized> {
    /// Cluster the data into k clusters
    ///
    /// # Panics
    ///
    /// Will panic if the underlying clustering algorithm fails to fit the data
    #[must_use]
    #[inline]
    pub fn cluster(self) -> ClusteringHelper<Finished> {
        let labels = self
            .state
            .clustering_method
            .fit(self.state.k, &self.state.embeddings);

        ClusteringHelper {
            state: Finished {
                labels,
                k: self.state.k,
            },
        }
    }
}

/// Functions available for `Finished` state
impl ClusteringHelper<Finished> {
    /// Use the labels to reorganize the provided samples into clusters
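    ///
    /// For example (hypothetical values): with labels `[0, 1, 0]` and `k = 2`,
    /// `vec![a, b, c]` becomes `vec![vec![a, c], vec![b]]`.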
    #[must_use]
    #[inline]
    pub fn extract_analysis_clusters<T: Clone>(&self, samples: Vec<T>) -> Vec<Vec<T>> {
        let mut clusters = vec![Vec::new(); self.state.k];

        for (sample, &label) in samples.into_iter().zip(self.state.labels.iter()) {
            clusters[label].push(sample);
        }

        clusters
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, s};
    use ndarray_rand::rand_distr::StandardNormal;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    #[test]
    fn test_generate_reference_data_set() {
        let data = arr2(&[[10.0, -10.0], [20.0, -20.0], [30.0, -30.0]]);

        let ref_data = generate_ref_single(data.view());

        // first column: all values between 10.0 and 30.0
        assert!(
            ref_data
                .slice(s![.., 0])
                .iter()
                .all(|v| *v >= 10.0 && *v <= 30.0)
        );

        // second column: all values between -30.0 and -10.0
        assert!(
            ref_data
                .slice(s![.., 1])
                .iter()
                .all(|v| *v <= -10.0 && *v >= -30.0)
        );

        // check that the shape is correct
        assert_eq!(ref_data.shape(), data.shape());

        // check that the data is not the same as the original data
        assert_ne!(ref_data, data);
    }

    #[test]
    fn test_pairwise_distances() {
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 0, 1, 1]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[1]
        );

        let samples = arr2(&[[1.0, 2.0], [1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[1]
        );
    }

    #[test]
    fn test_convert_to_array() {
        let data = vec![
            Analysis::new([1.0; NUMBER_FEATURES]),
            Analysis::new([2.0; NUMBER_FEATURES]),
            Analysis::new([3.0; NUMBER_FEATURES]),
        ];

        let array = convert_to_array(data);

        assert_eq!(array.0.shape(), &[3, NUMBER_FEATURES]);
        assert!(
            f64::EPSILON > (array.0[[0, 0]] - 1.0).abs(),
            "{} != 1.0",
            array.0[[0, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[1, 0]] - 2.0).abs(),
            "{} != 2.0",
            array.0[[1, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[2, 0]] - 3.0).abs(),
            "{} != 3.0",
            array.0[[2, 0]]
        );

        // check that axis iteration works how we expect
        // axis 0
        let mut iter = array.0.axis_iter(Axis(0));
        assert_eq!(iter.next().unwrap().to_vec(), vec![1.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![2.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![3.0; NUMBER_FEATURES]);
        // axis 1
        for column in array.0.axis_iter(Axis(1)) {
            assert_eq!(column.to_vec(), vec![1.0, 2.0, 3.0]);
        }
    }

    #[test]
    fn test_calc_within_dispersion() {
        let labels = arr1(&[0, 1, 0, 1]);
        let pairwise_distances = arr1(&[1.0, 2.0]);
        let result = calc_within_dispersion(labels.view(), 2, pairwise_distances.view());

        // `W_k = \sum_{r=1}^{k} \frac{D_r}{2 n_r}` = (1/4) * 1.0 + (1/4) * 2.0 = 0.25 + 0.5 = 0.75
        assert!(f64::EPSILON > (result - 0.75).abs(), "{} != 0.75", result);
    }
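
    /// A minimal sanity check added as a sketch: after normalization, every
    /// column should span exactly [-1.0, 1.0].
    #[test]
    fn test_normalize_embeddings_inplace() {
        let mut data = arr2(&[[2.0, -4.0], [4.0, 0.0], [6.0, 4.0]]);
        normalize_embeddings_inplace::<2>(&mut data);
        // column minima map to -1.0, midpoints to 0.0, maxima to 1.0 (all exact in f64)
        assert_eq!(data, arr2(&[[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]]));
    }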

    #[rstest]
    #[case::project_none(ProjectionMethod::None, NUMBER_FEATURES)]
    #[case::project_tsne(ProjectionMethod::TSne, EMBEDDING_SIZE)]
    #[case::project_pca(ProjectionMethod::Pca, EMBEDDING_SIZE)]
    fn test_project(
        #[case] projection_method: ProjectionMethod,
        #[case] expected_embedding_size: usize,
    ) {
        // generate 100 random samples; we use a normal distribution because with a uniform
        // distribution the data has no real "principal components", and PCA will not work as
        // expected since almost all the eigenvalues will fall below the cutoff
        let mut samples = Array2::random((100, NUMBER_FEATURES), StandardNormal);
        normalize_embeddings_inplace::<NUMBER_FEATURES>(&mut samples);
        let samples = AnalysisArray(samples);

        let result = projection_method.project(samples).unwrap();

        // ensure the embeddings are the correct shape
        assert_eq!(result.shape(), &[100, expected_embedding_size]);

        // ensure the data is normalized
        for i in 0..expected_embedding_size {
            let min = result.column(i).min();
            let max = result.column(i).max();
            assert!(
                f64::EPSILON > (min + 1.0).abs(),
                "Min value of column {i} is not -1.0: {min}",
            );
            assert!(
                f64::EPSILON > (max - 1.0).abs(),
                "Max value of column {i} is not 1.0: {max}",
            );
        }
    }
}

// #[cfg(feature = "plot_gap")]
// fn plot_gap_statistic(data: Vec<(usize, f64, f64)>) {
//     use plotters::prelude::*;

//     // Assuming data is a Vec<(usize, f64, f64)> of (k, gap_k, s_k)
//     let root_area = BitMapBackend::new("gap_statistic_plot.png", (640, 480)).into_drawing_area();
//     root_area.fill(&WHITE).unwrap();

//     let max_gap_k = data
//         .iter()
//         .map(|(_, gap_k, _)| *gap_k)
//         .fold(f64::MIN, f64::max);
//     let min_gap_k = data
//         .iter()
//         .map(|(_, gap_k, _)| *gap_k)
//         .fold(f64::MAX, f64::min);
//     let max_k = data.iter().map(|(k, _, _)| *k).max().unwrap_or(0);

//     let mut chart = ChartBuilder::on(&root_area)
//         .caption("Gap Statistic Plot", ("sans-serif", 30))
//         .margin(5)
//         .x_label_area_size(30)
//         .y_label_area_size(30)
//         .build_cartesian_2d(0..max_k, min_gap_k..max_gap_k)
//         .unwrap();

//     chart.configure_mesh().draw().unwrap();

//     for (k, gap_k, s_k) in data {
//         chart
//             .draw_series(PointSeries::of_element(
//                 vec![(k, gap_k)],
//                 5,
//                 &RED,
//                 &|coord, size, style| {
//                     EmptyElement::at(coord) + Circle::new((0, 0), size, style.filled())
//                 },
//             ))
//             .unwrap();

//         // Drawing error bars
//         chart
//             .draw_series(LineSeries::new(
//                 vec![(k, gap_k - s_k), (k, gap_k + s_k)],
//                 &BLACK,
//             ))
//             .unwrap();
//     }

//     root_area.present().unwrap();
// }