mecomp_analysis/clustering.rs

//! This module contains helpers that wrap the `linfa` clustering algorithms (k-means and
//! Gaussian mixture models) to perform clustering on the data without having to choose an
//! exact number of clusters.
//!
//! Instead, you provide the maximum number of clusters to try (`k_max`), and we'll
//! use one of a range of methods to determine the optimal number of clusters.
//!
//! # References:
//!
//! - The gap statistic [R. Tibshirani, G. Walther, and T. Hastie (Stanford University, 2001)](https://hastie.su.domains/Papers/gap.pdf)
//! - The Davies-Bouldin index [wikipedia](https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index)
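//!
//! The public entry point is [`ClusteringHelper`], a typestate wrapper that moves through
//! `EntryPoint` -> `NotInitialized` -> `Initialized` -> `Finished`: project the samples,
//! determine the optimal `k`, fit the final model, then extract the clustered samples.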

use linfa::prelude::*;
use linfa_clustering::{GaussianMixtureModel, KMeans};
use linfa_nn::distance::{Distance, L2Dist};
use linfa_reduction::Pca;
use linfa_tsne::TSneParams;
use log::{debug, info};
use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis, Dim};
use ndarray_rand::RandomExt;
use rand::distributions::Uniform;
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
use statrs::statistics::Statistics;

use crate::{
    Analysis, Feature, NUMBER_FEATURES,
    errors::{ClusteringError, ProjectionError},
};

/// A two-dimensional array of features, one row per [`Analysis`].
pub struct AnalysisArray(pub(crate) Array2<Feature>);

impl From<Vec<Analysis>> for AnalysisArray {
    #[inline]
    fn from(data: Vec<Analysis>) -> Self {
        let shape = (data.len(), NUMBER_FEATURES);
        debug_assert_eq!(shape, (data.len(), data[0].inner().len()));

        Self(
            Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
                .expect("Failed to convert to array, shape mismatch"),
        )
    }
}

impl From<Vec<[Feature; NUMBER_FEATURES]>> for AnalysisArray {
    #[inline]
    fn from(data: Vec<[Feature; NUMBER_FEATURES]>) -> Self {
        let shape = (data.len(), NUMBER_FEATURES);
        debug_assert_eq!(shape, (data.len(), data[0].len()));

        Self(
            Array2::from_shape_vec(shape, data.into_iter().flatten().collect())
                .expect("Failed to convert to array, shape mismatch"),
        )
    }
}

/// The dataset type consumed by the clustering methods: feature records with no targets.
pub type FitDataset = Dataset<Feature, (), Dim<[usize; 1]>>;

#[derive(Clone, Copy, Debug)]
#[allow(clippy::module_name_repetitions)]
pub enum ClusteringMethod {
    KMeans,
    GaussianMixtureModel,
}

impl ClusteringMethod {
    /// Fit the clustering method to the dataset and return the cluster label for each sample.
    ///
    /// Panics if the underlying model fails to fit.
    #[must_use]
    fn fit(self, k: usize, data: &FitDataset) -> Array1<usize> {
        match self {
            Self::KMeans => {
                let model = KMeans::params(k)
                    // .max_n_iterations(MAX_ITERATIONS)
                    .fit(data)
                    .unwrap();
                model.predict(data.records())
            }
            Self::GaussianMixtureModel => {
                let model = GaussianMixtureModel::params(k)
                    .init_method(linfa_clustering::GmmInitMethod::KMeans)
                    .n_runs(10)
                    .fit(data)
                    .unwrap();
                model.predict(data.records())
            }
        }
    }
}

#[derive(Clone, Copy, Debug)]
pub enum KOptimal {
    GapStatistic {
        /// The number of reference datasets to generate
        b: u32,
    },
    /// The Davies-Bouldin index (not yet implemented; see `get_optimal_k_davies_bouldin`)
    DaviesBouldin,
}

/// Should the data be projected into a lower-dimensional space before clustering, and if so, how?
#[derive(Clone, Copy, Debug, Default)]
pub enum ProjectionMethod {
    /// Use t-SNE to project the data into a lower-dimensional space
    TSne,
    /// Use PCA to project the data into a lower-dimensional space
    Pca,
    /// Don't project the data
    #[default]
    None,
}

impl ProjectionMethod {
    /// Project the data into a lower-dimensional space
    ///
    /// # Errors
    ///
    /// Will return an error if there was an error projecting the data into a lower-dimensional space
    #[inline]
    pub fn project(self, samples: AnalysisArray) -> Result<Array2<Feature>, ProjectionError> {
        let result = match self {
            Self::TSne => {
                let nrecords = samples.0.nrows();
                // first use the t-SNE algorithm to project the data into a lower-dimensional space
                debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using t-SNE");
                #[allow(clippy::cast_precision_loss)]
                let mut embeddings = TSneParams::embedding_size(EMBEDDING_SIZE)
                    // heuristic: perplexity of ~5% of the sample count, but at least 5
                    .perplexity(f64::max(samples.0.nrows() as f64 / 20., 5.))
                    .approx_threshold(0.5)
                    .transform(samples.0)?;
                debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);

                // normalize the embeddings so each dimension is between -1 and 1
                debug!("Normalizing embeddings");
                normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
                embeddings
            }
            Self::Pca => {
                let nrecords = samples.0.nrows();
                // use the PCA algorithm to project the data into a lower-dimensional space
                debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using PCA");
                let data = Dataset::from(samples.0);
                let pca: Pca<f64> = Pca::params(EMBEDDING_SIZE).whiten(true).fit(&data)?;
                let mut embeddings = pca.predict(&data);
                debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);

                // normalize the embeddings so each dimension is between -1 and 1
                debug!("Normalizing embeddings");
                normalize_embeddings_inplace::<EMBEDDING_SIZE>(&mut embeddings);
                embeddings
            }
            Self::None => {
                debug!("Using original data as embeddings");
                samples.0
            }
        };
        debug!("Embeddings shape: {:?}", result.shape());
        Ok(result)
    }
}

// Normalize the embeddings in-place so each dimension lies between -1.0 and 1.0.
// The embedding size is passed as a const generic to enable more compiler optimizations.
// e.g. a column of [0.0, 5.0, 10.0] is mapped to [-1.0, 0.0, 1.0].
// NOTE: a constant column (max == min) would divide by zero here.
fn normalize_embeddings_inplace<const SIZE: usize>(embeddings: &mut Array2<f64>) {
    for i in 0..SIZE {
        let min = embeddings.column(i).min();
        let max = embeddings.column(i).max();
        let range = max - min;
        embeddings
            .column_mut(i)
            .mapv_inplace(|v| ((v - min) / range).mul_add(2., -1.));
    }
}

/// Dimensionality that the t-SNE and PCA projection methods project the data into:
/// `log2(NUMBER_FEATURES)`, clamped to at least 2
/// (e.g. 20 features gives `ilog2(20) = 4`, so an embedding size of 4).
const EMBEDDING_SIZE: usize = {
    let log2 = usize::ilog2(NUMBER_FEATURES) as usize;
    if log2 < 2 { 2 } else { log2 }
};

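/// A typestate helper that walks the full clustering pipeline:
/// `EntryPoint` -> `NotInitialized` -> `Initialized` -> `Finished`.
///
/// A minimal sketch of the intended flow. This assumes `Analysis: Clone`, that
/// `ClusteringError` implements `std::error::Error`, and that this module is exposed as
/// `mecomp_analysis::clustering`; adjust the paths to the actual crate layout.
///
/// ```no_run
/// use mecomp_analysis::{Analysis, NUMBER_FEATURES};
/// use mecomp_analysis::clustering::{
///     ClusteringHelper, ClusteringMethod, KOptimal, ProjectionMethod,
/// };
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// // placeholder data; real callers would pass their library's analyses
/// let analyses: Vec<Analysis> = vec![Analysis::new([0.5; NUMBER_FEATURES]); 100];
///
/// let clusters: Vec<Vec<Analysis>> = ClusteringHelper::new(
///     analyses.clone().into(),          // AnalysisArray
///     24,                               // k_max (illustrative value)
///     KOptimal::GapStatistic { b: 50 }, // 50 reference data sets (illustrative value)
///     ClusteringMethod::KMeans,
///     ProjectionMethod::default(),
/// )?
/// .initialize()? // determine the optimal k
/// .cluster()     // fit the final model
/// .extract_analysis_clusters(analyses);
/// # Ok(())
/// # }
/// ```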
#[allow(clippy::module_name_repetitions)]
pub struct ClusteringHelper<S>
where
    S: Sized,
{
    state: S,
}

pub struct EntryPoint;
pub struct NotInitialized {
    /// The embeddings of our input, as an N×`EMBEDDING_SIZE` array
    /// (N×`NUMBER_FEATURES` when `ProjectionMethod::None` was used)
    embeddings: Array2<Feature>,
    pub k_max: usize,
    pub optimizer: KOptimal,
    pub clustering_method: ClusteringMethod,
}
pub struct Initialized {
    /// The embeddings of our input, as an N×`EMBEDDING_SIZE` array
    /// (N×`NUMBER_FEATURES` when `ProjectionMethod::None` was used)
    embeddings: Array2<Feature>,
    pub k: usize,
    pub clustering_method: ClusteringMethod,
}
pub struct Finished {
    /// The labels of the samples, as a length-N array.
    /// Each element is the cluster that the corresponding sample belongs to.
    labels: Array1<usize>,
    pub k: usize,
}

/// Functions available for the `EntryPoint` state
impl ClusteringHelper<EntryPoint> {
    /// Create a new `ClusteringHelper` object
    ///
    /// # Errors
    ///
    /// Will return an error if the library is too small (15 samples or fewer), or if there
    /// was an error projecting the data into a lower-dimensional space
    #[allow(clippy::missing_inline_in_public_items)]
    pub fn new(
        samples: AnalysisArray,
        k_max: usize,
        optimizer: KOptimal,
        clustering_method: ClusteringMethod,
        projection_method: ProjectionMethod,
    ) -> Result<ClusteringHelper<NotInitialized>, ClusteringError> {
        // clustering a very small library isn't meaningful
        if samples.0.nrows() <= 15 {
            return Err(ClusteringError::SmallLibrary);
        }

        // project the data into a lower-dimensional space
        let embeddings = projection_method.project(samples)?;

        Ok(ClusteringHelper {
            state: NotInitialized {
                embeddings,
                k_max,
                optimizer,
                clustering_method,
            },
        })
    }
}

/// Functions available for the `NotInitialized` state
impl ClusteringHelper<NotInitialized> {
    /// Initialize the `ClusteringHelper` object
    ///
    /// # Errors
    ///
    /// Will return an error if there was an error calculating the optimal number of clusters
    #[inline]
    pub fn initialize(self) -> Result<ClusteringHelper<Initialized>, ClusteringError> {
        let k = self.get_optimal_k()?;
        Ok(ClusteringHelper {
            state: Initialized {
                embeddings: self.state.embeddings,
                k,
                clustering_method: self.state.clustering_method,
            },
        })
    }

    fn get_optimal_k(&self) -> Result<usize, ClusteringError> {
        match self.state.optimizer {
            KOptimal::GapStatistic { b } => self.get_optimal_k_gap_statistic(b),
            KOptimal::DaviesBouldin => self.get_optimal_k_davies_bouldin(),
        }
    }

    /// Get the optimal number of clusters using the gap statistic
    ///
    /// # References:
    ///
    /// - [R. Tibshirani, G. Walther, and T. Hastie (Stanford University, 2001)](https://hastie.su.domains/Papers/gap.pdf)
    ///
    /// # Algorithm:
    ///
    /// 1. Cluster the observed data, varying the number of clusters from k = 1, …, k_max, and compute the corresponding total within intra-cluster variation `W_k`.
    /// 2. Generate B reference data sets with a random uniform distribution. Cluster each of these reference data sets with a varying number of clusters k = 1, …, k_max,
    ///    and compute the corresponding total within intra-cluster variation `W^*_{kb}`.
    /// 3. Compute the estimated gap statistic as the deviation of the observed `W_k` value from its expected value under the null hypothesis:
    ///    `Gap(k) = (1/B) \sum_{b=1}^{B} \log(W^*_{kb}) - \log(W_k)`.
    ///    Compute also the standard deviation of the statistics.
    /// 4. Choose the number of clusters as the smallest value of k such that the gap statistic is within one standard deviation of the gap at k+1:
    ///    `Gap(k) >= Gap(k + 1) - s_{k + 1}`.
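    ///
    /// # Example (illustrative numbers)
    ///
    /// If `Gap(3) = 0.80`, `Gap(4) = 0.82`, and `s_4 = 0.05`, then
    /// `Gap(3) >= Gap(4) - s_4` (`0.80 >= 0.77`), so k = 3 is selected.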
    fn get_optimal_k_gap_statistic(&self, b: u32) -> Result<usize, ClusteringError> {
        let embedding_dataset = Dataset::from(self.state.embeddings.clone());

        // our reference data sets
        let reference_datasets =
            generate_reference_datasets(embedding_dataset.records().view(), b as usize);

        let b = f64::from(b);

        let results = (1..=self.state.k_max)
            // for each k, cluster the data into k clusters
            .map(|k| {
                debug!(
                    "Fitting {:?} to embeddings with k={k}",
                    self.state.clustering_method
                );
                let labels = self.state.clustering_method.fit(k, &embedding_dataset);
                (k, labels)
            })
            // for each k, calculate the gap statistic and the standard deviation of the statistics
            .map(|(k, labels)| {
                // calculate the within intra-cluster variation for the reference data sets
                debug!(
                    "Calculating within intra-cluster variation for reference data sets with k={k}"
                );
                // note: base-2 logs are used instead of natural logs; the gap decision rule
                // is invariant to the choice of log base
                let w_kb_log: Vec<_> = reference_datasets
                    .par_iter()
                    .map(|ref_data| {
                        // cluster the reference data into k clusters
                        let ref_labels = self.state.clustering_method.fit(k, ref_data);
                        // calculate the within intra-cluster variation for the reference data
                        let ref_pairwise_distances = calc_pairwise_distances(
                            ref_data.records().view(),
                            k,
                            ref_labels.view(),
                        );
                        calc_within_dispersion(ref_labels.view(), k, ref_pairwise_distances.view())
                            .log2()
                    })
                    .collect();
                let w_kb_log = Array::from_vec(w_kb_log);

                // calculate the within intra-cluster variation for the observed data
                let pairwise_distances =
                    calc_pairwise_distances(self.state.embeddings.view(), k, labels.view());
                let w_k = calc_within_dispersion(labels.view(), k, pairwise_distances.view());

                // finally, calculate the gap statistic
                let w_kb_log_sum: f64 = w_kb_log.sum();
                // original formula: l = (1 / B) * sum_b(log(W_kb))
                let l = b.recip() * w_kb_log_sum;
                // original formula: gap_k = (1 / B) * sum_b(log(W_kb)) - log(W_k)
                let gap_k = l - w_k.log2();
                // original formula: sd_k = [(1 / B) * sum_b((log(W_kb) - l)^2)]^0.5
                let standard_deviation = (b.recip() * (w_kb_log - l).pow2().sum()).sqrt();
                // original formula: s_k = sd_k * (1 + 1 / B)^0.5
                // calculated differently to minimize rounding errors
                let s_k = standard_deviation * (1.0 + b.recip()).sqrt();

                (k, gap_k, s_k)
            });

        // // plot the gap_k (whisker with s_k) w.r.t. k
        // #[cfg(feature = "plot_gap")]
        // plot_gap_statistic(results.clone().collect::<Vec<_>>());

        // finally, consume the results iterator until we find the optimal k:
        // the smallest k such that Gap(k) >= Gap(k+1) - s_{k+1}
        let (mut optimal_k, mut gap_k_minus_one) = (None, None);
        for (k, gap_k, s_k) in results {
            info!("k: {k}, gap_k: {gap_k}, s_k: {s_k}");

            if let Some(gap_k_minus_one) = gap_k_minus_one
                && gap_k_minus_one >= gap_k - s_k
            {
                info!("Optimal k found: {}", k - 1);
                optimal_k = Some(k - 1);
                break;
            }

            gap_k_minus_one = Some(gap_k);
        }

        optimal_k.ok_or(ClusteringError::OptimalKNotFound(self.state.k_max))
    }

    fn get_optimal_k_davies_bouldin(&self) -> Result<usize, ClusteringError> {
        todo!();
    }
}

/// Convert a vector of Analyses into a 2D array
///
/// Equivalent to the `From<Vec<Analysis>>` impl on [`AnalysisArray`].
///
/// # Panics
///
/// Will panic if the shape of the data does not match the number of features; this should never happen
#[must_use]
#[inline]
pub fn convert_to_array(data: Vec<Analysis>) -> AnalysisArray {
    // Convert vector to Array
    let shape = (data.len(), NUMBER_FEATURES);
    debug_assert_eq!(NUMBER_FEATURES, data[0].inner().len());

    AnalysisArray(
        Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
            .expect("Failed to convert to array, shape mismatch"),
    )
}

/// Generate B reference data sets with a random uniform distribution
///
/// (excerpt from the reference paper)
/// """
/// We consider two choices for the reference distribution:
///
/// 1. generate each reference feature uniformly over the range of the observed values for that feature.
/// 2. generate the reference features from a uniform distribution over a box aligned with the
///    principal components of the data.
///    In detail, if X is our n by p data matrix, we assume that the columns have mean 0 and compute
///    the singular value decomposition X = UDV^T. We transform via X' = XV and then draw uniform features Z'
///    over the ranges of the columns of X', as in method (1) above.
///    Finally, we back-transform via Z = Z'V^T to give reference data Z.
///
/// Method (1) has the advantage of simplicity. Method (2) takes into account the shape of the
/// data distribution and makes the procedure rotationally invariant, as long as the
/// clustering method itself is invariant.
/// """
///
/// For now, we use method (1): it is simpler to implement, our data is already normalized,
/// and the ordering of features is significant, so we can't rotate the data anyway.
fn generate_reference_datasets(samples: ArrayView2<'_, Feature>, b: usize) -> Vec<FitDataset> {
    (0..b)
        .into_par_iter()
        .map(|_| Dataset::from(generate_ref_single(samples.view())))
        .collect()
}
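
/// Draw a single reference data set: each column is sampled uniformly over the
/// `[min, max]` range of the corresponding observed column (method (1) above).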
fn generate_ref_single(samples: ArrayView2<'_, Feature>) -> Array2<Feature> {
    let feature_distributions = samples
        .axis_iter(Axis(1))
        .map(|feature| Array::random(feature.dim(), Uniform::new(feature.min(), feature.max())))
        .collect::<Vec<_>>();
    let feature_dists_views = feature_distributions
        .iter()
        .map(ndarray::ArrayBase::view)
        .collect::<Vec<_>>();
    // stack the per-feature columns and transpose back to (n_samples, n_features)
    ndarray::stack(Axis(0), &feature_dists_views)
        .unwrap()
        .t()
        .to_owned()
}

/// Calculate `W_k`, the within intra-cluster variation for the given clustering
///
/// `W_k = \sum_{r=1}^{k} \frac{D_r}{2*n_r}`
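///
/// For example, with labels `[0, 1, 0, 1]` and `D = [1.0, 2.0]`:
/// `W_k = 1.0/(2*2) + 2.0/(2*2) = 0.75` (the case exercised in the tests below).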
fn calc_within_dispersion(
    labels: ArrayView1<'_, usize>,
    k: usize,
    pairwise_distances: ArrayView1<'_, Feature>,
) -> Feature {
    debug_assert_eq!(k, labels.iter().max().unwrap() + 1);

    // we first need to convert our list of labels into a list of the number of samples in each cluster
    let counts = labels.iter().fold(vec![0u32; k], |mut counts, &label| {
        counts[label] += 1;
        counts
    });
    // then, we calculate the within intra-cluster variation
    counts
        .iter()
        .zip(pairwise_distances.iter())
        .map(|(&count, distance)| (2.0 * f64::from(count)).recip() * distance)
        .sum()
}

/// Calculate the `D_r` array: the sum of the pairwise distances within cluster r, for each
/// cluster r in the given clustering
///
/// # Arguments
///
/// - `samples`: The samples in the dataset
/// - `k`: The number of clusters
/// - `labels`: The cluster labels for each sample
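///
/// Note: following the reference paper, `D_r` sums over *ordered* pairs, so each
/// unordered pair is counted twice; hence the doubling in the loop below.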
fn calc_pairwise_distances(
    samples: ArrayView2<'_, Feature>,
    k: usize,
    labels: ArrayView1<'_, usize>,
) -> Array1<Feature> {
    debug_assert_eq!(
        samples.nrows(),
        labels.len(),
        "Samples and labels must have the same length"
    );
    debug_assert_eq!(
        k,
        labels.iter().max().unwrap() + 1,
        "Labels must be in the range 0..k"
    );

    // for each cluster, calculate the sum of the pairwise distances between samples in that cluster
    let mut distances = Array1::zeros(k);
    let mut clusters = vec![Vec::new(); k];
    // build clusters
    for (sample, label) in samples.outer_iter().zip(labels.iter()) {
        clusters[*label].push(sample);
    }
    // calculate pairwise distances within each cluster
    for (cluster_idx, cluster) in clusters.iter().enumerate() {
        let mut pairwise_dists = 0.;
        // sum the distance over every unordered pair in the cluster
        // (saturating_sub avoids an underflow when a cluster is empty)
        for i in 0..cluster.len().saturating_sub(1) {
            let a = cluster[i];
            let rest = &cluster[i + 1..];
            for &b in rest {
                pairwise_dists += L2Dist.distance(a, b);
            }
        }
        // D_r counts ordered pairs, so double the unordered-pair sum
        distances[cluster_idx] = 2.0 * pairwise_dists;
    }
    distances
}

/// Functions available for the `Initialized` state
impl ClusteringHelper<Initialized> {
    /// Cluster the data into k clusters
    ///
    /// # Panics
    ///
    /// Will panic if the underlying clustering algorithm fails to fit
    #[must_use]
    #[inline]
    pub fn cluster(self) -> ClusteringHelper<Finished> {
        let Initialized {
            clustering_method,
            embeddings,
            k,
        } = self.state;

        let embedding_dataset = Dataset::from(embeddings);
        let labels = clustering_method.fit(k, &embedding_dataset);

        ClusteringHelper {
            state: Finished { labels, k },
        }
    }
}

/// Functions available for the `Finished` state
impl ClusteringHelper<Finished> {
    /// Use the labels to reorganize the provided samples into clusters
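    ///
    /// For example, with labels `[0, 1, 0]` and samples `[a, b, c]`, the result is `[[a, c], [b]]`.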
    #[must_use]
    #[inline]
    pub fn extract_analysis_clusters<T: Clone>(&self, samples: Vec<T>) -> Vec<Vec<T>> {
        let mut clusters = vec![Vec::new(); self.state.k];

        for (sample, &label) in samples.into_iter().zip(self.state.labels.iter()) {
            clusters[label].push(sample);
        }

        clusters
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, s};
    use ndarray_rand::rand_distr::StandardNormal;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    #[test]
    fn test_generate_reference_data_set() {
        let data = arr2(&[[10.0, -10.0], [20.0, -20.0], [30.0, -30.0]]);

        let ref_data = generate_ref_single(data.view());

        // first column: all values between 10.0 and 30.0
        assert!(
            ref_data
                .slice(s![.., 0])
                .iter()
                .all(|v| *v >= 10.0 && *v <= 30.0)
        );

        // second column: all values between -30.0 and -10.0
        assert!(
            ref_data
                .slice(s![.., 1])
                .iter()
                .all(|v| *v <= -10.0 && *v >= -30.0)
        );

        // check that the shape is correct
        assert_eq!(ref_data.shape(), data.shape());

        // check that the data is not the same as the original data
        assert_ne!(ref_data, data);
    }

    #[test]
    fn test_pairwise_distances() {
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 0, 1, 1]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        // identical points within each cluster: D_r = 0
        assert!(
            f64::EPSILON > (pairwise_distances[0] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[1]
        );

        let samples = arr2(&[[1.0, 2.0], [1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        // each cluster holds two points at distance 1, counted twice (ordered pairs): D_r = 2
        assert!(
            f64::EPSILON > (pairwise_distances[0] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[1]
        );
    }

    #[test]
    fn test_convert_to_array() {
        let data = vec![
            Analysis::new([1.0; NUMBER_FEATURES]),
            Analysis::new([2.0; NUMBER_FEATURES]),
            Analysis::new([3.0; NUMBER_FEATURES]),
        ];

        let array = convert_to_array(data);

        assert_eq!(array.0.shape(), &[3, NUMBER_FEATURES]);
        assert!(
            f64::EPSILON > (array.0[[0, 0]] - 1.0).abs(),
            "{} != 1.0",
            array.0[[0, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[1, 0]] - 2.0).abs(),
            "{} != 2.0",
            array.0[[1, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[2, 0]] - 3.0).abs(),
            "{} != 3.0",
            array.0[[2, 0]]
        );

        // check that axis iteration works how we expect
        // axis 0: rows (samples)
        let mut iter = array.0.axis_iter(Axis(0));
        assert_eq!(iter.next().unwrap().to_vec(), vec![1.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![2.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![3.0; NUMBER_FEATURES]);
        // axis 1: columns (features)
        for column in array.0.axis_iter(Axis(1)) {
            assert_eq!(column.to_vec(), vec![1.0, 2.0, 3.0]);
        }
    }

    #[test]
    fn test_calc_within_dispersion() {
        let labels = arr1(&[0, 1, 0, 1]);
        let pairwise_distances = arr1(&[1.0, 2.0]);
        let result = calc_within_dispersion(labels.view(), 2, pairwise_distances.view());

        // `W_k = \sum_{r=1}^{k} \frac{D_r}{2*n_r}` = 1/4 * 1.0 + 1/4 * 2.0 = 0.25 + 0.5 = 0.75
        assert!(f64::EPSILON > (result - 0.75).abs(), "{result} != 0.75");
    }

    #[rstest]
    #[case::project_none(ProjectionMethod::None, NUMBER_FEATURES)]
    #[case::project_tsne(ProjectionMethod::TSne, EMBEDDING_SIZE)]
    #[case::project_pca(ProjectionMethod::Pca, EMBEDDING_SIZE)]
    fn test_project(
        #[case] projection_method: ProjectionMethod,
        #[case] expected_embedding_size: usize,
    ) {
        // generate 100 random samples; we use a normal distribution because with a uniform
        // distribution the data has no real "principal components", and PCA will not work as
        // expected since almost all the eigenvalues will fall below the cutoff
        let mut samples = Array2::random((100, NUMBER_FEATURES), StandardNormal);
        normalize_embeddings_inplace::<NUMBER_FEATURES>(&mut samples);
        let samples = AnalysisArray(samples);

        let result = projection_method.project(samples).unwrap();

        // ensure embeddings are the correct shape
        assert_eq!(result.shape(), &[100, expected_embedding_size]);

        // ensure the data is normalized
        for i in 0..expected_embedding_size {
            let min = result.column(i).min();
            let max = result.column(i).max();
            assert!(
                f64::EPSILON > (min + 1.0).abs(),
                "Min value of column {i} is not -1.0: {min}",
            );
            assert!(
                f64::EPSILON > (max - 1.0).abs(),
                "Max value of column {i} is not 1.0: {max}",
            );
        }
    }
}

// #[cfg(feature = "plot_gap")]
// fn plot_gap_statistic(data: Vec<(usize, f64, f64)>) {
//     use plotters::prelude::*;

//     // Assuming data is a Vec<(usize, f64, f64)> of (k, gap_k, s_k)
//     let root_area = BitMapBackend::new("gap_statistic_plot.png", (640, 480)).into_drawing_area();
//     root_area.fill(&WHITE).unwrap();

//     let max_gap_k = data
//         .iter()
//         .map(|(_, gap_k, _)| *gap_k)
//         .fold(f64::MIN, f64::max);
//     let min_gap_k = data
//         .iter()
//         .map(|(_, gap_k, _)| *gap_k)
//         .fold(f64::MAX, f64::min);
//     let max_k = data.iter().map(|(k, _, _)| *k).max().unwrap_or(0);

//     let mut chart = ChartBuilder::on(&root_area)
//         .caption("Gap Statistic Plot", ("sans-serif", 30))
//         .margin(5)
//         .x_label_area_size(30)
//         .y_label_area_size(30)
//         .build_cartesian_2d(0..max_k, min_gap_k..max_gap_k)
//         .unwrap();

//     chart.configure_mesh().draw().unwrap();

//     for (k, gap_k, s_k) in data {
//         chart
//             .draw_series(PointSeries::of_element(
//                 vec![(k, gap_k)],
//                 5,
//                 &RED,
//                 &|coord, size, style| {
//                     EmptyElement::at(coord) + Circle::new((0, 0), size, style.filled())
//                 },
//             ))
//             .unwrap();

//         // Drawing error bars
//         chart
//             .draw_series(LineSeries::new(
//                 vec![(k, gap_k - s_k), (k, gap_k + s_k)],
//                 &BLACK,
//             ))
//             .unwrap();
//     }

//     root_area.present().unwrap();
// }