// mecomp_analysis/clustering.rs

//! This module contains helpers that wrap the clustering algorithms from the `linfa_clustering`
//! crate to perform clustering on the data without having to choose an exact number of clusters.
//!
//! Instead, you provide the maximum number of clusters you want to try, and we'll
//! use one of a range of methods to determine the optimal number of clusters.
//!
//! # References:
//!
//! - The gap statistic [R. Tibshirani, G. Walther, and T. Hastie (Stanford University, 2001)](https://hastie.su.domains/Papers/gap.pdf)
//! - The Davies-Bouldin index [wikipedia](https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index)
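//!
//! # Example
//!
//! A minimal sketch of the intended flow (the `samples` matrix below is hypothetical, and the
//! parameter values are only illustrative):
//!
//! ```ignore
//! // `samples` is an N x NUMBER_FEATURES matrix of analyzed features.
//! let samples: ndarray::Array2<Feature> = unimplemented!();
//!
//! let finished = ClusteringHelper::new(
//!     samples,
//!     24,                               // k_max: largest k to consider
//!     KOptimal::GapStatistic { b: 10 }, // number of reference datasets
//!     ClusteringMethod::KMeans,
//!     ProjectionMethod::Pca,
//! )?
//! .initialize()? // determine the optimal k
//! .cluster()?;   // fit the final clustering
//!
//! // `finished.extract_analysis_clusters(items)` then groups your items by cluster label.
//! ```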

use linfa::prelude::*;
use linfa_clustering::{GaussianMixtureModel, GmmError, KMeans};
use linfa_nn::distance::{Distance, L2Dist};
use linfa_reduction::Pca;
use linfa_tsne::TSneParams;
use log::{debug, info};
use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis, Dim};
use ndarray_rand::RandomExt;
use ndarray_stats::QuantileExt;
use rand::distributions::Uniform;
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};

use crate::{
    DIM_EMBEDDING, Feature, NUMBER_FEATURES,
    errors::{ClusteringError, ProjectionError},
};

pub type FitDataset = Dataset<Feature, (), Dim<[usize; 1]>>;

pub type ClusteringResult<T> = Result<T, ClusteringError>;

#[derive(Clone, Copy, Debug)]
#[allow(clippy::module_name_repetitions)]
pub enum ClusteringMethod {
    KMeans,
    GaussianMixtureModel,
}

impl ClusteringMethod {
    /// Fit the clustering method to the dataset and get the labels
    fn fit(self, k: usize, data: &FitDataset) -> ClusteringResult<Array1<usize>> {
        match self {
            Self::KMeans => {
                let model = KMeans::params(k)
                    // .max_n_iterations(MAX_ITERATIONS)
                    .fit(data)?;
                Ok(model.predict(data.records()))
            }
            Self::GaussianMixtureModel => {
                let model = GaussianMixtureModel::params(k)
                    .init_method(linfa_clustering::GmmInitMethod::KMeans)
                    .reg_covariance(5e-3)
                    .n_runs(10)
                    .fit(data)
                    .inspect_err(|e| debug!("GMM fitting failed with k={k}: {e:?}, if this continues try using KMeans instead"))?;
                Ok(model.predict(data.records()))
            }
        }
    }
}

#[derive(Clone, Copy, Debug)]
pub enum KOptimal {
    GapStatistic {
        /// The number of reference datasets to generate
        b: u32,
    },
    DaviesBouldin,
}

#[derive(Clone, Copy, Debug, Default)]
/// Should the data be projected into a lower-dimensional space before clustering, and if so, how?
pub enum ProjectionMethod {
    /// Use t-SNE to project the data into a lower-dimensional space
    TSne,
    /// Use PCA to project the data into a lower-dimensional space
    Pca,
    #[default]
    /// Don't project the data
    None,
}

impl ProjectionMethod {
    /// Project the data into a lower-dimensional space
    ///
    /// # Errors
    ///
    /// Will return an error if projecting the data into a lower-dimensional space fails
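    ///
    /// # Example
    ///
    /// A minimal sketch (the `samples` matrix here is hypothetical):
    ///
    /// ```ignore
    /// let samples: ndarray::Array2<Feature> = unimplemented!();
    /// let embeddings = ProjectionMethod::Pca.project(samples)?;
    /// assert_eq!(embeddings.ncols(), EMBEDDING_SIZE);
    /// ```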
    #[inline]
    pub fn project(self, samples: Array2<Feature>) -> Result<Array2<Feature>, ProjectionError> {
        let nrecords = samples.nrows();
        let ncols = samples.ncols();
        let result = match self {
            Self::TSne => {
                // first, preprocess the data with PCA into an intermediate dimensionality
                // to speed up t-SNE
                let intermediate_dim = ncols.midpoint(EMBEDDING_SIZE).midpoint(EMBEDDING_SIZE);
                debug!(
                    "Preprocessing data with PCA to speed up t-SNE ({ncols} -> {intermediate_dim})"
                );
                let data = Dataset::from(samples.mapv(f64::from));
                let pca: Pca<f64> = Pca::params(intermediate_dim).fit(&data)?;
                #[allow(clippy::cast_possible_truncation)]
                let pca_samples = pca.predict(&data).mapv(|f| f as Feature);

                // then use the t-SNE algorithm to project the data into a lower-dimensional space
                debug!(
                    "Generating embeddings (size: {intermediate_dim} -> {EMBEDDING_SIZE}) using t-SNE"
                );
                #[allow(clippy::cast_precision_loss)]
                let mut embeddings = TSneParams::embedding_size(EMBEDDING_SIZE)
                    .perplexity(f32::max(nrecords as f32 / 20., 5.))
                    .approx_threshold(0.5)
                    .max_iter(1000)
                    .transform(pca_samples)?;
                debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);

                // normalize the embeddings so each dimension is between -1 and 1
                debug!("Normalizing embeddings");
                normalize_embeddings_inplace(&mut embeddings);
                embeddings
            }
            Self::Pca => {
                // use the PCA algorithm to project the data into a lower-dimensional space
                debug!("Generating embeddings (size: {ncols} -> {EMBEDDING_SIZE}) using PCA");
                // linfa_reduction::pca::PCA only works for f64, see: https://github.com/rust-ml/linfa/issues/232
                let data = Dataset::from(samples.mapv(f64::from));
                let pca: Pca<f64> = Pca::params(EMBEDDING_SIZE).whiten(true).fit(&data)?;
                #[allow(clippy::cast_possible_truncation)]
                let mut embeddings = pca.predict(&data).mapv(|f| f as Feature);
                debug_assert_eq!(embeddings.shape(), &[nrecords, EMBEDDING_SIZE]);

                // normalize the embeddings so each dimension is between -1 and 1
                debug!("Normalizing embeddings");
                normalize_embeddings_inplace(&mut embeddings);
                embeddings
            }
            Self::None => {
                debug!("Using original data as embeddings");
                samples
            }
        };
        debug!("Embeddings shape: {:?}", result.shape());
        Ok(result)
    }
}

// Normalize the embeddings in-place so each dimension (column) lies between -1.0 and 1.0.
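// For example (illustrative values): a column of [0.0, 5.0, 10.0] has min 0 and max 10,
// so it is mapped to [-1.0, 0.0, 1.0].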
fn normalize_embeddings_inplace(embeddings: &mut Array2<Feature>) {
    for i in 0..embeddings.ncols() {
        let min = *embeddings.column(i).min_skipnan();
        let max = *embeddings.column(i).max_skipnan();
        let range = max - min;
        embeddings
            .column_mut(i)
            .mapv_inplace(|v| ((v - min) / range).mul_add(2., -1.));
    }
}

/// Baseline dimensionality that the t-SNE and PCA projection methods aim to project the data into:
/// `log2(max(NUMBER_FEATURES, DIM_EMBEDDING))`, clamped to a minimum of 2.
///
/// t-SNE does a two-stage projection: first into an intermediate dimensionality using PCA, then into `EMBEDDING_SIZE` dimensions using t-SNE.
/// PCA directly projects the data into `EMBEDDING_SIZE` dimensions.
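///
/// For example, with hypothetical values `NUMBER_FEATURES = 20` and `DIM_EMBEDDING = 16`,
/// `usize::ilog2(20) == 4`, so `EMBEDDING_SIZE` would be 4.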
const EMBEDDING_SIZE: usize = {
    let log2 = usize::ilog2(if DIM_EMBEDDING < NUMBER_FEATURES {
        NUMBER_FEATURES
    } else {
        DIM_EMBEDDING
    }) as usize;
    if log2 < 2 { 2 } else { log2 }
};

#[allow(clippy::module_name_repetitions)]
pub struct ClusteringHelper<S>
where
    S: Sized,
{
    state: S,
}

pub struct EntryPoint;
pub struct NotInitialized {
    /// The embeddings of our input, as an Nx`EMBEDDING_SIZE` array
    embeddings: Array2<Feature>,
    pub k_max: usize,
    pub optimizer: KOptimal,
    pub clustering_method: ClusteringMethod,
}
pub struct Initialized {
    /// The embeddings of our input, as an Nx`EMBEDDING_SIZE` array
    embeddings: Array2<Feature>,
    pub k: usize,
    pub clustering_method: ClusteringMethod,
}
pub struct Finished {
    /// The labelings of the samples, as an Nx1 array.
    /// Each element is the cluster that the corresponding sample belongs to.
    labels: Array1<usize>,
    pub k: usize,
}

/// Functions for constructing the helper (the `EntryPoint` state)
impl ClusteringHelper<EntryPoint> {
    /// Create a new `ClusteringHelper` object
    ///
    /// # Errors
    ///
    /// Will return an error if the library is too small to cluster (15 or fewer samples),
    /// or if projecting the data into a lower-dimensional space fails
    #[allow(clippy::missing_inline_in_public_items)]
    pub fn new(
        samples: Array2<Feature>,
        k_max: usize,
        optimizer: KOptimal,
        clustering_method: ClusteringMethod,
        projection_method: ProjectionMethod,
    ) -> Result<ClusteringHelper<NotInitialized>, ClusteringError> {
        if samples.nrows() <= 15 {
            return Err(ClusteringError::SmallLibrary);
        }

        // project the data into a lower-dimensional space
        let embeddings = projection_method.project(samples)?;

        Ok(ClusteringHelper {
            state: NotInitialized {
                embeddings,
                k_max,
                optimizer,
                clustering_method,
            },
        })
    }
}

/// Functions available for the `NotInitialized` state
impl ClusteringHelper<NotInitialized> {
    /// Initialize the `ClusteringHelper` object
    ///
    /// # Errors
    ///
    /// Will return an error if the optimal number of clusters could not be determined
    #[inline]
    pub fn initialize(self) -> Result<ClusteringHelper<Initialized>, ClusteringError> {
        let k = self.get_optimal_k()?;
        Ok(ClusteringHelper {
            state: Initialized {
                embeddings: self.state.embeddings,
                k,
                clustering_method: self.state.clustering_method,
            },
        })
    }

    fn get_optimal_k(&self) -> Result<usize, ClusteringError> {
        match self.state.optimizer {
            KOptimal::GapStatistic { b } => self.get_optimal_k_gap_statistic(b),
            KOptimal::DaviesBouldin => self.get_optimal_k_davies_bouldin(),
        }
    }

    /// Get the optimal number of clusters using the gap statistic
    ///
    /// # References:
    ///
    /// - [R. Tibshirani, G. Walther, and T. Hastie (Stanford University, 2001)](https://hastie.su.domains/Papers/gap.pdf)
    ///
    /// # Algorithm:
    ///
    /// 1. Cluster the observed data, varying the number of clusters from k = 1, …, kmax, and compute the corresponding total within intra-cluster variation `W_k`.
    /// 2. Generate B reference data sets with a random uniform distribution. Cluster each of these reference data sets with varying number of clusters k = 1, …, kmax,
    ///    and compute the corresponding total within intra-cluster variation `W^*_{kb}`.
    /// 3. Compute the estimated gap statistic as the deviation of the observed `W_k` value from its expected value under the null hypothesis:
    ///    `Gap(k) = (1/B) \sum_{b=1}^{B} \log(W^*_{kb}) − \log(W_k)`.
    ///    Also compute the standard deviation of the statistics.
    /// 4. Choose the number of clusters as the smallest value of k such that the gap statistic is within one standard deviation of the gap at k+1:
    ///    `Gap(k) ≥ Gap(k + 1) − s_{k + 1}`.
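    ///
    /// For example (illustrative numbers only): if `Gap(3) = 0.90`, `Gap(4) = 0.95`, and `s_4 = 0.08`,
    /// then `Gap(3) ≥ Gap(4) − s_4` (0.90 ≥ 0.87), so k = 3 is selected.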
    fn get_optimal_k_gap_statistic(&self, b: u32) -> Result<usize, ClusteringError> {
        let embedding_dataset = Dataset::from(self.state.embeddings.clone());

        // our reference data sets
        let reference_datasets =
            generate_reference_datasets(embedding_dataset.records().view(), b as usize);

        #[allow(clippy::cast_precision_loss)]
        let b = b as Feature;

        // track the best k until we get an optimal one
        let (mut optimal_k, mut gap_k_minus_one) = (None, None);

        for k in 1..=self.state.k_max {
            // for each k, cluster the data into k clusters
            info!("Fitting {:?} to embeddings with k={k}", self.state.clustering_method);
            let labels = self.state.clustering_method.fit(k, &embedding_dataset)?;

            // for each k, calculate the gap statistic, and the standard deviation of the statistics
            // 1. calculate the within intra-cluster variation for the reference data sets
            debug!("Calculating within intra-cluster variation for reference data sets with k={k}");
            let w_kb_log = reference_datasets
                .par_iter()
                .map(|ref_data| {
                    // cluster the reference data into k clusters
                    let ref_labels = self.state.clustering_method.fit(k, ref_data)?;
                    // calculate the within intra-cluster variation for the reference data
                    let ref_dispersion =
                        calc_centroid_dispersion(ref_data.records().view(), k, ref_labels.view());
                    let dispersion = calc_within_dispersion(ref_dispersion.view()).log2();
                    Ok(dispersion)
                })
                .collect::<ClusteringResult<Vec<_>>>();
            let w_kb_log = match w_kb_log {
                Ok(w_kb_log) => Array::from_vec(w_kb_log),
                Err(ClusteringError::Gmm(GmmError::EmptyCluster(e))) => {
                    log::warn!("Library is not large enough to cluster with k={k}: {e}");
                    break;
                }
                Err(e) => return Err(e),
            };
            // 2. calculate the within intra-cluster variation for the observed data
            let centroid_dispersion =
                calc_centroid_dispersion(self.state.embeddings.view(), k, labels.view());
            let w_k = calc_within_dispersion(centroid_dispersion.view());

            // 3. finally, calculate the gap statistic
            let w_kb_log_sum: Feature = w_kb_log.sum();
            // original formula: l = (1 / B) * sum_b(log(W_kb))
            let l = b.recip() * w_kb_log_sum;
            // original formula: gap_k = (1 / B) * sum_b(log(W_kb)) - log(W_k)
            let gap_k = l - w_k.log2();
            // original formula: sd_k = [(1 / B) * sum_b((log(W_kb) - l)^2)]^0.5
            let standard_deviation = (b.recip() * (w_kb_log - l).pow2().sum()).sqrt();
            // original formula: s_k = sd_k * (1 + 1 / B)^0.5
            // calculate differently to minimize rounding errors
            let s_k = standard_deviation * (1.0 + b.recip()).sqrt();

            // finally, update the optimal k if needed
            info!("k: {k}, gap_k: {gap_k}, s_k: {s_k}");
            if let Some(gap_k_minus_one) = gap_k_minus_one
                && gap_k_minus_one >= gap_k - s_k
            {
                info!("Optimal k found: {}", k - 1);
                optimal_k = Some(k - 1);
                break;
            }

            gap_k_minus_one = Some(gap_k);
        }

        optimal_k.ok_or(ClusteringError::OptimalKNotFound(self.state.k_max))
    }

    fn get_optimal_k_davies_bouldin(&self) -> Result<usize, ClusteringError> {
        todo!();
    }
}

const MAX_REF_DATASET_SAMPLES: usize = 2000;

/// Generate B reference data sets with a random uniform distribution
///
/// (excerpt from reference paper)
/// """
/// We consider two choices for the reference distribution:
///
/// 1. generate each reference feature uniformly over the range of the observed values for that feature.
/// 2. generate the reference features from a uniform distribution over a box aligned with the
///    principal components of the data.
///    In detail, if X is our n by p data matrix, we assume that the columns have mean 0 and compute
///    the singular value decomposition X = UDV^T. We transform via X' = XV and then draw uniform features Z'
///    over the ranges of the columns of X', as in method (1) above.
///    Finally, we back-transform via Z = Z'V^T to give reference data Z.
///
/// Method (1) has the advantage of simplicity. Method (2) takes into account the shape of the
/// data distribution and makes the procedure rotationally invariant, as long as the
/// clustering method itself is invariant.
/// """
///
/// For now, we use method (1) as it is simpler to implement,
/// we know that our data is already normalized, and
/// the ordering of features is important, meaning that we can't
/// rotate the data anyway.
fn generate_reference_datasets(samples: ArrayView2<'_, Feature>, b: usize) -> Vec<FitDataset> {
    if samples.nrows() < MAX_REF_DATASET_SAMPLES {
        // for small datasets, we use all samples
        return (0..b)
            .into_par_iter()
            .map(|_| Dataset::from(generate_ref_single(samples.view())))
            .collect();
    }

    // for large datasets, we randomly sample 2000 samples to speed up the reference data generation
    let mut rng = rand::thread_rng();
    let indices = rand::seq::index::sample(&mut rng, samples.nrows(), MAX_REF_DATASET_SAMPLES);
    let samples = samples.select(Axis(0), &indices.into_vec());

    (0..b)
        .into_par_iter()
        .map(|_| Dataset::from(generate_ref_single(samples.view())))
        .collect()
}

fn generate_ref_single(samples: ArrayView2<'_, Feature>) -> Array2<Feature> {
    let feature_distributions = samples
        .axis_iter(Axis(1))
        .map(|feature| {
            let min = *feature.min_skipnan();
            let max = *feature.max_skipnan();
            if min >= max {
                // if all values are the same, we just create a data set with that same value
                return Array::from_elem(feature.dim(), min);
            }
            Array::random(feature.dim(), Uniform::new(min, max))
        })
        .collect::<Vec<_>>();
    let feature_dists_views = feature_distributions
        .iter()
        .map(ndarray::ArrayBase::view)
        .collect::<Vec<_>>();
    ndarray::stack(Axis(0), &feature_dists_views)
        .unwrap()
        .t()
        .to_owned()
}

/// Calculate `W_k`, the within intra-cluster variation for the given clustering:
///
/// `W_k = \sum_{r=1}^{k} \frac{D_r}{2 n_r}`
///
/// # Arguments
///
/// - `pairwise_distances`: The `D_r / (2*n_r)` values, i.e. the sum of pairwise distances within cluster r divided by twice its size, for each cluster in the given clustering
fn calc_within_dispersion(pairwise_distances: ArrayView1<'_, Feature>) -> Feature {
    pairwise_distances.sum()
}

/// Calculate the `D_r / (2*n_r)` array: the sum of pairwise distances within cluster r, divided by twice the cluster size, for each cluster in the given clustering
///
/// # Arguments
///
/// - `samples`: The samples in the dataset
/// - `k`: The number of clusters
/// - `labels`: The cluster labels for each sample
#[allow(unused)]
fn calc_pairwise_distances(
    samples: ArrayView2<'_, Feature>,
    k: usize,
    labels: ArrayView1<'_, usize>,
) -> Array1<Feature> {
    debug_assert_eq!(
        samples.nrows(),
        labels.len(),
        "Samples and labels must have the same length"
    );
    debug_assert_eq!(
        k,
        labels.iter().max().unwrap() + 1,
        "Labels must be in the range 0..k"
    );

    // for each cluster, calculate the sum of the pairwise distances between samples in that cluster
    let mut distances = Array1::zeros(k);
    let mut clusters = vec![Vec::new(); k];
    // build clusters
    for (sample, &label) in samples.outer_iter().zip(labels.iter()) {
        clusters[label].push(sample);
    }
    // calculate pairwise dist. within each cluster
    #[allow(clippy::cast_precision_loss)]
    for (k, cluster) in clusters.iter().enumerate() {
        if cluster.len() < 2 {
            // empty or singleton clusters contribute no pairwise distance
            continue;
        }
        let mut pairwise_dists = 0.;
        for i in 0..cluster.len() - 1 {
            let a = cluster[i];
            let rest = &cluster[i + 1..];
            for &b in rest {
                pairwise_dists += L2Dist.distance(a, b);
            }
        }
        // each unordered pair is summed once, so D_r = 2 * pairwise_dists and D_r / (2 * n_r) = pairwise_dists / n_r
        distances[k] += pairwise_dists / cluster.len() as Feature;
    }
    distances
}

/// Calculate within-cluster dispersion using cluster centroids, normalized by the number of samples in each cluster.
///
/// This is an O(n) approximation of the sum of pairwise distances within each cluster divided by twice the cluster size.
/// Effectively, this calculates `D_r / (2 * n_r)` for each cluster `r`.
///
/// You can then approximate `W_k` (within intra-cluster variation, defined in the paper as `W_k = \sum_{r=1}^{k} \frac{D_r}{2 n_r}`)
/// by simply summing the returned dispersions.
///
/// # Arguments
///
/// - `samples`: The samples in the dataset
/// - `k`: The number of clusters
/// - `labels`: The cluster labels for each sample
///
/// # Returns
///
/// Array of dispersions (`D_r / (2 * n_r)`) for each cluster
fn calc_centroid_dispersion(
    samples: ArrayView2<'_, Feature>,
    k: usize,
    labels: ArrayView1<'_, usize>,
) -> Array1<Feature> {
    debug_assert_eq!(
        samples.nrows(),
        labels.len(),
        "Samples and labels must have the same length"
    );
    debug_assert_eq!(
        k,
        labels.iter().max().unwrap() + 1,
        "Labels must be in the range 0..k"
    );

    let mut dispersions = Array1::zeros(k);
    let mut centroids = Array2::zeros((k, samples.ncols()));
    let mut counts = vec![0usize; k];

    // Calculate centroids
    for (sample, &label) in samples.outer_iter().zip(labels.iter()) {
        centroids.row_mut(label).scaled_add(1.0, &sample);
        counts[label] += 1;
    }

    // Normalize centroids by count
    for (i, &count) in counts.iter().enumerate() {
        if count > 0 {
            #[allow(clippy::cast_precision_loss)]
            centroids.row_mut(i).mapv_inplace(|v| v / count as Feature);
        }
    }

    // Accumulate each sample's distance to its cluster centroid
    for (sample, &label) in samples.outer_iter().zip(labels.iter()) {
        let dist = L2Dist.distance(sample, centroids.row(label));
        dispersions[label] += dist; // the paper's definition would use dist^2, but plain distances track the pairwise-distance version more closely here
    }

    dispersions
}

/// Functions available for Initialized state
impl ClusteringHelper<Initialized> {
    /// Cluster the data into k clusters
    ///
    /// # Errors
    ///
    /// Will return an error if the clustering fails
    #[inline]
    pub fn cluster(self) -> ClusteringResult<ClusteringHelper<Finished>> {
        let Initialized {
            clustering_method,
            embeddings,
            k,
        } = self.state;

        let embedding_dataset = Dataset::from(embeddings);
        let labels = clustering_method.fit(k, &embedding_dataset)?;

        Ok(ClusteringHelper {
            state: Finished { labels, k },
        })
    }
}

/// Functions available for Finished state
impl ClusteringHelper<Finished> {
    /// Use the labels to reorganize the provided samples into clusters
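    ///
    /// For example (illustrative): with cluster labels `[0, 1, 0]` (so k = 2) and samples
    /// `["a", "b", "c"]`, this returns `[["a", "c"], ["b"]]`.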
    #[must_use]
    #[inline]
    pub fn extract_analysis_clusters<T: Clone>(&self, samples: Vec<T>) -> Vec<Vec<T>> {
        let mut clusters = vec![Vec::new(); self.state.k];

        for (sample, &label) in samples.into_iter().zip(self.state.labels.iter()) {
            clusters[label].push(sample);
        }

        clusters
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, s};
    use ndarray_rand::rand_distr::StandardNormal;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    #[test]
    fn test_generate_reference_data_set() {
        let data = arr2(&[[10.0, -10.0], [20.0, -20.0], [30.0, -30.0]]);

        let ref_data = generate_ref_single(data.view());

        // First column: all values between 10.0 and 30.0
        assert!(
            ref_data
                .slice(s![.., 0])
                .iter()
                .all(|v| *v >= 10.0 && *v <= 30.0)
        );

        // Second column: all values between -30.0 and -10.0
        assert!(
            ref_data
                .slice(s![.., 1])
                .iter()
                .all(|v| *v <= -10.0 && *v >= -30.0)
        );

        // check that the shape is correct
        assert_eq!(ref_data.shape(), data.shape());

        // check that the data is not the same as the original data
        assert_ne!(ref_data, data);
    }

    #[test]
    fn test_pairwise_distances() {
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 0, 1, 1]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f32::EPSILON > (pairwise_distances[0] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[0]
        );
        assert!(
            f32::EPSILON > (pairwise_distances[1] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[1]
        );

        let samples = arr2(&[[1.0, 2.0], [1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f32::EPSILON > (pairwise_distances[0] - 2.0 / 4.0).abs(),
            "{} != 0.5",
            pairwise_distances[0]
        );
        assert!(
            f32::EPSILON > (pairwise_distances[1] - 2.0 / 4.0).abs(),
            "{} != 0.5",
            pairwise_distances[1]
        );
    }

    #[test]
    fn test_calc_within_dispersion() {
        // let labels = arr1(&[0, 1, 0, 1]);
        let pairwise_distances = arr1(&[1.0 / 4.0, 2.0 / 4.0]); // D_1 / (2*n_1) = 1.0 / 4, D_2 / (2*n_2) = 2.0 / 4
        let result = calc_within_dispersion(pairwise_distances.view());

        // `W_k = \sum_{r=1}^{k} \frac{D_r}{2*n_r}` = 1/4 * 1.0 + 1/4 * 2.0 = 0.25 + 0.5 = 0.75
        assert!(f32::EPSILON > (result - 0.75).abs(), "{result} != 0.75");
    }

    #[rstest]
    #[case::project_none(ProjectionMethod::None, NUMBER_FEATURES)]
    #[case::project_tsne(ProjectionMethod::TSne, EMBEDDING_SIZE)]
    #[case::project_pca(ProjectionMethod::Pca, EMBEDDING_SIZE)]
    fn test_project(
        #[case] projection_method: ProjectionMethod,
        #[case] expected_embedding_size: usize,
    ) {
        // generate 100 random samples; we use a normal distribution because with a uniform distribution
        // the data has no real "principal components" and PCA will not work as expected, since almost all the eigenvalues
        // will fall below the cutoff
        let mut samples = Array2::random((100, NUMBER_FEATURES), StandardNormal);
        normalize_embeddings_inplace(&mut samples);

        let result = projection_method.project(samples).unwrap();

        // ensure embeddings are the correct shape
        assert_eq!(result.shape(), &[100, expected_embedding_size]);

        // ensure the data is normalized
        for i in 0..expected_embedding_size {
            let min = result.column(i).min().copied().unwrap_or_default();
            let max = result.column(i).max().copied().unwrap_or_default();
            assert!(
                f32::EPSILON > (min + 1.0).abs(),
                "Min value of column {i} is not -1.0: {min}",
            );
            assert!(
                f32::EPSILON > (max - 1.0).abs(),
                "Max value of column {i} is not 1.0: {max}",
            );
        }
    }

    #[test]
    fn test_centroid_dispersion_identical_points() {
        // Test case: All points in a cluster are identical (dispersion should be 0)
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 0, 1, 1]);

        let dispersion = calc_centroid_dispersion(samples.view(), 2, labels.view());

        // Both clusters have identical points, so dispersion should be near 0
        assert!(
            f32::EPSILON > dispersion[0].abs(),
            "Cluster 0 dispersion should be ~0, got {}",
            dispersion[0]
        );
        assert!(
            f32::EPSILON > dispersion[1].abs(),
            "Cluster 1 dispersion should be ~0, got {}",
            dispersion[1]
        );
    }

    #[test]
    fn test_centroid_dispersion_known_values() {
        // Test case: Points at known distances from centroid
        // Cluster 0: points at [0,0], [2,0], [0,2], [2,2] - centroid at [1,1]
        // Each point is sqrt(2) from the centroid,
        // so total dispersion = 4 * sqrt(2)
        let samples = arr2(&[
            [0.0, 0.0],
            [2.0, 0.0],
            [0.0, 2.0],
            [2.0, 2.0],
            [10.0, 10.0], // Cluster 1: single point
        ]);
        let labels = arr1(&[0, 0, 0, 0, 1]);

        let dispersion = calc_centroid_dispersion(samples.view(), 2, labels.view());

        // Cluster 0: 4 points, each sqrt(2) from [1,1],
        // Dispersion = 4 * sqrt(2)
        let expected = 4.0 * 2.0_f32.sqrt();
        assert!(
            (dispersion[0] - expected).abs() < 0.001,
            "Cluster 0 dispersion should be ~{}, got {}",
            expected,
            dispersion[0]
        );

        // Cluster 1: single point, dispersion = 0
        assert!(
            dispersion[1].abs() < f32::EPSILON,
            "Cluster 1 dispersion should be ~0, got {}",
            dispersion[1]
        );
    }

    #[test]
    fn test_centroid_dispersion_matches_shape() {
        // Test that dispersion output has correct shape
        let samples = Array2::random((100, 10), StandardNormal);
        let labels = Array1::from_vec((0..100).map(|i| i % 5).collect());

        let dispersion = calc_centroid_dispersion(samples.view(), 5, labels.view());

        assert_eq!(dispersion.len(), 5, "Should have dispersion for 5 clusters");
        assert!(
            dispersion.iter().all(|&d| d >= 0.0),
            "All dispersions should be non-negative"
        );
    }

    #[test]
    fn test_centroid_vs_pairwise_relative_ordering() {
        // Test that centroid-based and pairwise methods produce similar relative orderings
        // This is crucial for the gap statistic, which compares relative values

        // Create 3 clusters with varying tightness
        let mut samples = Array2::zeros((60, 2));
        let mut labels = Array1::zeros(60);

        // Tight cluster 0: points clustered around [0, 0]
        for i in 0..20 {
            samples[[i, 0]] = (i as f32 * 0.1) - 1.0;
            samples[[i, 1]] = (i as f32 * 0.1) - 1.0;
            labels[i] = 0;
        }

        // Medium cluster 1: points clustered around [5, 5]
        for i in 20..40 {
            samples[[i, 0]] = ((i - 20) as f32 * 0.3) + 4.0;
            samples[[i, 1]] = ((i - 20) as f32 * 0.3) + 4.0;
            labels[i] = 1;
        }

        // Loose cluster 2: points spread around [10, 10]
        for i in 40..60 {
            samples[[i, 0]] = ((i - 40) as f32 * 0.5) + 8.0;
            samples[[i, 1]] = ((i - 40) as f32 * 0.5) + 8.0;
            labels[i] = 2;
        }

        let pairwise = calc_pairwise_distances(samples.view(), 3, labels.view());
        let centroid = calc_centroid_dispersion(samples.view(), 3, labels.view());

        // Both methods should agree on relative ordering: tight < medium < loose
        // Cluster 0 (tight) should have smallest dispersion
        assert!(
            pairwise[0] < pairwise[1] && pairwise[1] < pairwise[2],
            "Pairwise ordering incorrect: {:?}",
            pairwise
        );
        assert!(
            centroid[0] < centroid[1] && centroid[1] < centroid[2],
            "Centroid ordering incorrect: {:?}",
            centroid
        );

        // The key requirement for the gap statistic is that relative ordering is preserved.
        // The absolute values will differ:
        // - Pairwise: sums the distance over every unordered pair in a cluster, divided by the cluster size
        // - Centroid: sums each point's distance to its cluster centroid
        //
        // Because the two statistics weight spread differently, the spread ratios will differ.
        // This is expected and doesn't affect the gap statistic, since we only care about
        // relative ordering (already verified above).
        //
        // Both should show significant spread (not all clusters equally dispersed)
        let pairwise_spread = pairwise.max_skipnan() / pairwise.min_skipnan();
        let centroid_spread = centroid.max_skipnan() / centroid.min_skipnan();

        assert!(
            pairwise_spread > 2.0 && centroid_spread > 2.0,
            "Both methods should show significant dispersion variation: pairwise={}, centroid={}",
            pairwise_spread,
            centroid_spread
        );
    }

    #[test]
    fn test_centroid_dispersion_with_empty_cluster() {
        // Edge case: what if a cluster is empty? (shouldn't happen but good to test)
        let samples = arr2(&[[1.0, 1.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 2]); // Skip cluster 1

        let dispersion = calc_centroid_dispersion(samples.view(), 3, labels.view());

        // Clusters 0 and 2 should have dispersion 0 (single points)
        assert!(dispersion[0].abs() < f32::EPSILON);
        assert!(dispersion[2].abs() < f32::EPSILON);
        // Cluster 1 should have dispersion 0 (empty)
        assert!(dispersion[1].abs() < f32::EPSILON);
    }
}

// #[cfg(feature = "plot_gap")]
// fn plot_gap_statistic(data: Vec<(usize, f64, f64)>) {
//     use plotters::prelude::*;

//     // Assuming data is a Vec<(usize, f64, f64)> of (k, gap_k, s_k)
//     let root_area = BitMapBackend::new("gap_statistic_plot.png", (640, 480)).into_drawing_area();
//     root_area.fill(&WHITE).unwrap();

//     let max_gap_k = data
//         .iter()
//         .map(|(_, gap_k, _)| *gap_k)
//         .fold(f64::MIN, f64::max);
//     let min_gap_k = data
//         .iter()
//         .map(|(_, gap_k, _)| *gap_k)
//         .fold(f64::MAX, f64::min);
//     let max_k = data.iter().map(|(k, _, _)| *k).max().unwrap_or(0);

//     let mut chart = ChartBuilder::on(&root_area)
//         .caption("Gap Statistic Plot", ("sans-serif", 30))
//         .margin(5)
//         .x_label_area_size(30)
//         .y_label_area_size(30)
//         .build_cartesian_2d(0..max_k, min_gap_k..max_gap_k)
//         .unwrap();

//     chart.configure_mesh().draw().unwrap();

//     for (k, gap_k, s_k) in data {
//         chart
//             .draw_series(PointSeries::of_element(
//                 vec![(k, gap_k)],
//                 5,
//                 &RED,
//                 &|coord, size, style| {
//                     EmptyElement::at(coord) + Circle::new((0, 0), size, style.filled())
//                 },
//             ))
//             .unwrap();

//         // Drawing error bars
//         chart
//             .draw_series(LineSeries::new(
//                 vec![(k, gap_k - s_k), (k, gap_k + s_k)],
//                 &BLACK,
//             ))
//             .unwrap();
//     }

//     root_area.present().unwrap();
// }