mecomp_analysis/
clustering.rs

//! This module contains helpers that wrap a k-means crate to perform clustering on the data
2//! without having to choose an exact number of clusters.
3//!
4//! Instead, you provide the minimum and maximum number of clusters you want to try, and we'll
5//! use one of a range of methods to determine the optimal number of clusters.
6//!
7//! # References:
8//!
//! - The gap statistic [R. Tibshirani, G. Walther, and T. Hastie (Stanford University, 2001)](https://hastie.su.domains/Papers/gap.pdf)
10//! - The Davies-Bouldin index [wikipedia](https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index)
11
12use linfa::prelude::*;
13use linfa_clustering::{GaussianMixtureModel, KMeans};
14use linfa_nn::distance::{Distance, L2Dist};
15use linfa_tsne::TSneParams;
16use log::{debug, info};
17use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis};
18use ndarray_rand::RandomExt;
19use rand::distributions::Uniform;
20use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
21use statrs::statistics::Statistics;
22
23use crate::{errors::ClusteringError, Analysis, Feature, NUMBER_FEATURES};
24
/// Newtype wrapper around an Nx`NUMBER_FEATURES` array of analysis features,
/// one row per analyzed sample.
pub struct AnalysisArray(pub(crate) Array2<Feature>);
26
27impl From<Vec<Analysis>> for AnalysisArray {
28    fn from(data: Vec<Analysis>) -> Self {
29        let shape = (data.len(), NUMBER_FEATURES);
30        debug_assert_eq!(shape, (data.len(), data[0].inner().len()));
31
32        Self(
33            Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
34                .expect("Failed to convert to array, shape mismatch"),
35        )
36    }
37}
38
39impl From<Vec<[Feature; NUMBER_FEATURES]>> for AnalysisArray {
40    fn from(data: Vec<[Feature; NUMBER_FEATURES]>) -> Self {
41        let shape = (data.len(), NUMBER_FEATURES);
42        debug_assert_eq!(shape, (data.len(), data[0].len()));
43
44        Self(
45            Array2::from_shape_vec(shape, data.into_iter().flatten().collect())
46                .expect("Failed to convert to array, shape mismatch"),
47        )
48    }
49}
50
/// The clustering algorithm used to assign samples to clusters.
#[derive(Clone, Copy, Debug)]
#[allow(clippy::module_name_repetitions)]
pub enum ClusteringMethod {
    /// Standard k-means clustering.
    KMeans,
    /// Gaussian mixture model clustering, initialized with k-means.
    GaussianMixtureModel,
}
57
58impl ClusteringMethod {
59    /// Fit the clustering method to the samples and get the labels
60    #[must_use]
61    fn fit(self, k: usize, samples: &Array2<Feature>) -> Array1<usize> {
62        match self {
63            Self::KMeans => {
64                let model = KMeans::params(k)
65                    // .max_n_iterations(MAX_ITERATIONS)
66                    .fit(&Dataset::from(samples.clone()))
67                    .unwrap();
68                model.predict(samples)
69            }
70            Self::GaussianMixtureModel => {
71                let model = GaussianMixtureModel::params(k)
72                    .init_method(linfa_clustering::GmmInitMethod::KMeans)
73                    .n_runs(10)
74                    .fit(&Dataset::from(samples.clone()))
75                    .unwrap();
76                model.predict(samples)
77            }
78        }
79    }
80}
81
/// The strategy used to choose the optimal number of clusters.
#[derive(Clone, Copy, Debug)]
pub enum KOptimal {
    /// Choose `k` via the gap statistic (Tibshirani, Walther & Hastie, 2001).
    GapStatistic {
        /// The number of reference datasets to generate
        b: usize,
    },
    /// Choose `k` via the Davies-Bouldin index (not yet implemented).
    DaviesBouldin,
}
90
91// log the number of features
92const EMBEDDING_SIZE: usize =
93    //  2;
94    {
95        let log2 = usize::ilog2(NUMBER_FEATURES) as usize;
96        if log2 < 2 {
97            2
98        } else {
99            log2
100        }
101    };
102
/// A type-state wrapper that drives the clustering pipeline.
///
/// The state parameter walks through `EntryPoint` -> `NotInitialized`
/// (embeddings computed) -> `Initialized` (optimal `k` chosen) -> `Finished`
/// (samples labeled), so the steps can only be run in order.
// NOTE: the previous `where S: Sized` bound was redundant — `Sized` is the
// implicit default for type parameters.
#[allow(clippy::module_name_repetitions)]
pub struct ClusteringHelper<S> {
    state: S,
}
110
/// Marker state: the helper has been created but holds no data yet.
pub struct EntryPoint;
/// State holding the embedded samples, before the optimal `k` has been chosen.
pub struct NotInitialized {
    /// The embeddings of our input, as a Nx`EMBEDDING_SIZE` array
    embeddings: Array2<Feature>,
    /// The maximum number of clusters to try when searching for the optimal `k`.
    pub k_max: usize,
    /// The strategy used to choose the optimal number of clusters.
    pub optimizer: KOptimal,
    /// The clustering algorithm to use.
    pub clustering_method: ClusteringMethod,
}
/// State holding the embedded samples and the chosen number of clusters.
pub struct Initialized {
    /// The embeddings of our input, as a Nx`EMBEDDING_SIZE` array
    embeddings: Array2<Feature>,
    /// The number of clusters to use.
    pub k: usize,
    /// The clustering algorithm to use.
    pub clustering_method: ClusteringMethod,
}
/// Final state: the samples have been clustered.
pub struct Finished {
    /// The labelings of the samples, as a Nx1 array.
    /// Each element is the cluster that the corresponding sample belongs to.
    labels: Array1<usize>,
    /// The number of clusters.
    pub k: usize,
}
131
132/// Functions available for all states
133impl ClusteringHelper<EntryPoint> {
134    /// Create a new `KMeansHelper` object
135    ///
136    /// # Errors
137    ///
138    /// Will return an error if there was an error projecting the data into a lower-dimensional space
139    pub fn new(
140        samples: AnalysisArray,
141        k_max: usize,
142        optimizer: KOptimal,
143        clustering_method: ClusteringMethod,
144    ) -> Result<ClusteringHelper<NotInitialized>, ClusteringError> {
145        // first use the t-SNE algorithm to project the data into a lower-dimensional space
146        debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using t-SNE",);
147
148        if samples.0.nrows() <= 15 {
149            return Err(ClusteringError::SmallLibrary);
150        }
151
152        #[allow(clippy::cast_precision_loss)]
153        let mut embeddings = TSneParams::embedding_size(EMBEDDING_SIZE)
154            .perplexity(f64::max(samples.0.nrows() as f64 / 20., 5.))
155            .approx_threshold(0.5)
156            .transform(samples.0)?;
157
158        debug!("Embeddings shape: {:?}", embeddings.shape());
159
160        // normalize the embeddings so each dimension is between -1 and 1
161        debug!("Normalizing embeddings");
162        for i in 0..EMBEDDING_SIZE {
163            let min = embeddings.column(i).min();
164            let max = embeddings.column(i).max();
165            let range = max - min;
166            embeddings
167                .column_mut(i)
168                .mapv_inplace(|v| ((v - min) / range).mul_add(2., -1.));
169        }
170
171        Ok(ClusteringHelper {
172            state: NotInitialized {
173                embeddings,
174                k_max,
175                optimizer,
176                clustering_method,
177            },
178        })
179    }
180}
181
182/// Functions available for `NotInitialized` state
183impl ClusteringHelper<NotInitialized> {
184    /// Initialize the `KMeansHelper` object
185    ///
186    /// # Errors
187    ///
188    /// Will return an error if there was an error calculating the optimal number of clusters
189    pub fn initialize(self) -> Result<ClusteringHelper<Initialized>, ClusteringError> {
190        let k = self.get_optimal_k()?;
191        Ok(ClusteringHelper {
192            state: Initialized {
193                embeddings: self.state.embeddings,
194                k,
195                clustering_method: self.state.clustering_method,
196            },
197        })
198    }
199
200    fn get_optimal_k(&self) -> Result<usize, ClusteringError> {
201        match self.state.optimizer {
202            KOptimal::GapStatistic { b } => self.get_optimal_k_gap_statistic(b),
203            KOptimal::DaviesBouldin => self.get_optimal_k_davies_bouldin(),
204        }
205    }
206
207    /// Get the optimal number of clusters using the gap statistic
208    ///
209    /// # References:
210    ///
211    /// - [R. Tibshirani, G. Walther, and T. Hastie (Standford University, 2001)](https://hastie.su.domains/Papers/gap.pdf)
212    ///
213    /// # Algorithm:
214    ///
215    /// 1. Cluster the observed data, varying the number of clusters from k = 1, …, kmax, and compute the corresponding total within intra-cluster variation Wk.
216    /// 2. Generate B reference data sets with a random uniform distribution. Cluster each of these reference data sets with varying number of clusters k = 1, …, kmax,
217    ///    and compute the corresponding total within intra-cluster variation `W_{kb}`.
218    /// 3. Compute the estimated gap statistic as the deviation of the observed `W_k` value from its expected value `W_kb` under the null hypothesis:
219    ///    `Gap(k)=(1/B) \sum_{b=1}^{B} \log(W^*_{kb})−\log(W_k)`.
220    ///    Compute also the standard deviation of the statistics.
221    /// 4. Choose the number of clusters as the smallest value of k such that the gap statistic is within one standard deviation of the gap at k+1:
222    ///    `Gap(k)≥Gap(k + 1)−s_{k + 1}`.
223    fn get_optimal_k_gap_statistic(&self, b: usize) -> Result<usize, ClusteringError> {
224        // our reference data sets
225        let reference_data_sets = generate_reference_data_set(self.state.embeddings.view(), b);
226
227        let results = (1..=self.state.k_max)
228            // for each k, cluster the data into k clusters
229            .map(|k| {
230                debug!("Fitting k-means to embeddings with k={k}");
231                let labels = self.state.clustering_method.fit(k, &self.state.embeddings);
232                (k, labels)
233            })
234            // for each k, calculate the gap statistic, and the standard deviation of the statistics
235            .map(|(k, labels)| {
236                // first, we calculate the within intra-cluster variation for the observed data
237                let pairwise_distances =
238                    calc_pairwise_distances(self.state.embeddings.view(), k, labels.view());
239                let w_k = calc_within_dispersion(labels.view(), k, pairwise_distances.view());
240
241                // then, we calculate the within intra-cluster variation for the reference data sets
242                debug!(
243                    "Calculating within intra-cluster variation for reference data sets with k={k}"
244                );
245                let w_kb = reference_data_sets.par_iter().map(|ref_data| {
246                    // cluster the reference data into k clusters
247                    let ref_labels = self.state.clustering_method.fit(k, ref_data);
248                    // calculate the within intra-cluster variation for the reference data
249                    let ref_pairwise_distances =
250                        calc_pairwise_distances(ref_data.view(), k, ref_labels.view());
251                    calc_within_dispersion(ref_labels.view(), k, ref_pairwise_distances.view())
252                });
253
254                // finally, we calculate the gap statistic
255                #[allow(clippy::cast_precision_loss)]
256                let gap_k = (1.0 / b as f64)
257                    .mul_add(w_kb.clone().map(f64::log2).sum::<f64>().log2(), -w_k.log2());
258
259                #[allow(clippy::cast_precision_loss)]
260                let l = (1.0 / b as f64) * w_kb.clone().map(f64::log2).sum::<f64>();
261                #[allow(clippy::cast_precision_loss)]
262                let standard_deviation = ((1.0 / b as f64)
263                    * w_kb.map(|w_kb| (w_kb.log2() - l).powi(2)).sum::<f64>())
264                .sqrt();
265                #[allow(clippy::cast_precision_loss)]
266                let s_k = standard_deviation * (1.0 + 1.0 / b as f64).sqrt();
267
268                (k, gap_k, s_k)
269            });
270
271        // // plot the gap_k (whisker with s_k) w.r.t. k
272        // #[cfg(feature = "plot_gap")]
273        // plot_gap_statistic(results.clone().collect::<Vec<_>>());
274
275        // finally, we go over the iterator to find the optimal k
276        let (mut optimal_k, mut gap_k_minus_one) = (None, None);
277
278        for (k, gap_k, s_k) in results {
279            info!("k: {k}, gap_k: {gap_k}, s_k: {s_k}");
280
281            if let Some(gap_k_minus_one) = gap_k_minus_one {
282                if gap_k_minus_one >= gap_k - s_k {
283                    info!("Optimal k found: {}", k - 1);
284                    optimal_k = Some(k - 1);
285                    break;
286                }
287            }
288            gap_k_minus_one = Some(gap_k);
289        }
290
291        optimal_k.ok_or(ClusteringError::OptimalKNotFound(self.state.k_max))
292    }
293
294    fn get_optimal_k_davies_bouldin(&self) -> Result<usize, ClusteringError> {
295        todo!();
296    }
297}
298
299/// Convert a vector of Analyses into a 2D array
300///
301/// # Panics
302///
303/// Will panic if the shape of the data does not match the number of features, should never happen
304#[must_use]
305pub fn convert_to_array(data: Vec<Analysis>) -> AnalysisArray {
306    // Convert vector to Array
307    let shape = (data.len(), NUMBER_FEATURES);
308    debug_assert_eq!(shape, (data.len(), data[0].inner().len()));
309
310    AnalysisArray(
311        Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
312            .expect("Failed to convert to array, shape mismatch"),
313    )
314}
315
316/// Generate B reference data sets with a random uniform distribution
317///
318/// (excerpt from reference paper)
319/// """
320/// We consider two choices for the reference distribution:
321///
322/// 1. generate each reference feature uniformly over the range of the observed values for that feature.
323/// 2. generate the reference features from a uniform distribution over a box aligned with the
324///    principle components of the data.
325///    In detail, if X is our n by p data matrix, we assume that the columns have mean 0 and compute
326///    the singular value decomposition X = UDV^T. We transform via X' = XV and then draw uniform features Z'
327///    over the ranges of the columns of X', as in method (1) above.
328///    Finally, we back-transform via Z=Z'V^T to give reference data Z.
329///
330/// Method (1) has the advantage of simplicity. Method (2) takes into account the shape of the
331/// data distribution and makes the procedure rotationally invariant, as long as the
332/// clustering method itself is invariant
333/// """
334///
335/// For now, we will use method (1) as it is simpler to implement
336/// and we know that our data is already normalized and that
337/// the ordering of features is important, meaning that we can't
338/// rotate the data anyway.
339fn generate_reference_data_set(samples: ArrayView2<Feature>, b: usize) -> Vec<Array2<f64>> {
340    let mut reference_data_sets = Vec::with_capacity(b);
341    for _ in 0..b {
342        reference_data_sets.push(generate_ref_single(samples));
343    }
344
345    reference_data_sets
346}
/// Generate a single reference data set with the same shape as `samples`:
/// each feature (column) is drawn uniformly over the observed range of that
/// feature (method (1) of the gap-statistic paper).
///
/// NOTE(review): `Uniform::new(min, max)` requires `min < max`, so a constant
/// column in `samples` would panic here — confirm upstream data can never be
/// degenerate.
fn generate_ref_single(samples: ArrayView2<Feature>) -> Array2<f64> {
    // One random 1-D array per feature, sampled over that feature's observed range.
    let feature_distributions = samples
        .axis_iter(Axis(1))
        .map(|feature| Array::random(feature.dim(), Uniform::new(feature.min(), feature.max())))
        .collect::<Vec<_>>();
    let feature_dists_views = feature_distributions
        .iter()
        .map(ndarray::ArrayBase::view)
        .collect::<Vec<_>>();
    // Stack the per-feature rows into a features-by-samples array, then
    // transpose back to samples-by-features.
    ndarray::stack(Axis(0), &feature_dists_views)
        .unwrap()
        .t()
        .to_owned()
}
361
362/// Calculate `W_k`, the within intra-cluster variation for the given clustering
363///
364/// `W_k = \sum_{r=1}^{k} \frac{D_r}{2*n_r}`
365fn calc_within_dispersion(
366    labels: ArrayView1<usize>,
367    k: usize,
368    pairwise_distances: ArrayView1<Feature>,
369) -> Feature {
370    debug_assert_eq!(k, labels.iter().max().unwrap() + 1);
371
372    // we first need to convert our list of labels into a list of the number of samples in each cluster
373    let counts = labels.iter().fold(vec![0; k], |mut counts, &label| {
374        counts[label] += 1;
375        counts
376    });
377    // then, we calculate the within intra-cluster variation
378    counts
379        .iter()
380        .zip(pairwise_distances.iter())
381        .map(|(&count, distance)| distance / (2.0 * f64::from(count)))
382        .sum()
383}
384
385/// Calculate the `D_r` array, the sum of the pairwise distances in cluster r, for all clusters in the given clustering
386///
387/// # Arguments
388///
389/// - `samples`: The samples in the dataset
390/// - `k`: The number of clusters
391/// - `labels`: The cluster labels for each sample
392fn calc_pairwise_distances(
393    samples: ArrayView2<Feature>,
394    k: usize,
395    labels: ArrayView1<usize>,
396) -> Array1<Feature> {
397    debug_assert_eq!(
398        samples.nrows(),
399        labels.len(),
400        "Samples and labels must have the same length"
401    );
402    debug_assert_eq!(
403        k,
404        labels.iter().max().unwrap() + 1,
405        "Labels must be in the range 0..k"
406    );
407
408    // for each cluster, calculate the sum of the pairwise distances between samples in that cluster
409    (0..k)
410        .map(|k| {
411            (
412                k,
413                samples
414                    .outer_iter()
415                    .zip(labels.iter())
416                    .filter_map(|(s, &l)| (l == k).then_some(s))
417                    .collect::<Vec<_>>(),
418            )
419        })
420        .fold(Array1::zeros(k), |mut distances, (label, cluster)| {
421            distances[label] += cluster
422                .iter()
423                .enumerate()
424                .map(|(i, &a)| {
425                    cluster
426                        .iter()
427                        .skip(i + 1)
428                        .map(|&b| L2Dist.distance(a, b))
429                        .sum::<Feature>()
430                })
431                .sum::<Feature>();
432            distances
433        })
434}
435
/// Functions available for Initialized state
impl ClusteringHelper<Initialized> {
    /// Cluster the data into k clusters, consuming this helper and returning
    /// one in the `Finished` state that holds the per-sample labels.
    ///
    /// # Panics
    ///
    /// Panics if the underlying clustering model fails to fit the embeddings.
    /// (The old `# Errors` doc was wrong: this method returns no `Result`.)
    #[must_use]
    pub fn cluster(self) -> ClusteringHelper<Finished> {
        let labels = self
            .state
            .clustering_method
            .fit(self.state.k, &self.state.embeddings);

        ClusteringHelper {
            state: Finished {
                labels,
                k: self.state.k,
            },
        }
    }
}
458
459/// Functions available for Finished state
460impl ClusteringHelper<Finished> {
461    /// use the labels to reorganize the provided samples into clusters
462    #[must_use]
463    pub fn extract_analysis_clusters<T: Clone>(&self, samples: Vec<T>) -> Vec<Vec<T>> {
464        let mut clusters = vec![Vec::new(); self.state.k];
465
466        for (sample, &label) in samples.into_iter().zip(self.state.labels.iter()) {
467            clusters[label].push(sample);
468        }
469
470        clusters
471    }
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, s};
    use pretty_assertions::assert_eq;

    #[test]
    fn test_generate_reference_data_set() {
        let data = arr2(&[[10.0, -10.0], [20.0, -20.0], [30.0, -30.0]]);

        let ref_data = generate_ref_single(data.view());

        // First column all vals between 10.0 and 30.0
        assert!(ref_data
            .slice(s![.., 0])
            .iter()
            .all(|v| *v >= 10.0 && *v <= 30.0));

        // Second column all vals between -30.0 and -10.0
        assert!(ref_data
            .slice(s![.., 1])
            .iter()
            .all(|v| *v <= -10.0 && *v >= -30.0));

        // check that the shape is correct
        assert_eq!(ref_data.shape(), data.shape());

        // check that the data is not the same as the original data
        // (vanishingly unlikely for continuous uniform samples)
        assert_ne!(ref_data, data);
    }

    #[test]
    fn test_pairwise_distances() {
        // identical points within each cluster: all pairwise distances are 0
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 0, 1, 1]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert_eq!(pairwise_distances[0], 0.0);
        assert_eq!(pairwise_distances[1], 0.0);

        // each cluster's two points are exactly distance 1 apart
        let samples = arr2(&[[1.0, 2.0], [1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert_eq!(pairwise_distances[0], 1.0);
        assert_eq!(pairwise_distances[1], 1.0);
    }

    // Renamed from `test_convert_to_vec`: this exercises `convert_to_array`,
    // which produces an array, not a vec.
    #[test]
    fn test_convert_to_array() {
        let data = vec![
            Analysis::new([1.0; NUMBER_FEATURES]),
            Analysis::new([2.0; NUMBER_FEATURES]),
            Analysis::new([3.0; NUMBER_FEATURES]),
        ];

        let array = convert_to_array(data);

        assert_eq!(array.0.shape(), &[3, NUMBER_FEATURES]);
        assert_eq!(array.0[[0, 0]], 1.0);
        assert_eq!(array.0[[1, 0]], 2.0);
        assert_eq!(array.0[[2, 0]], 3.0);

        // check that axis iteration works how we expect
        // axis 0: one row per input Analysis
        let mut iter = array.0.axis_iter(Axis(0));
        assert_eq!(iter.next().unwrap().to_vec(), vec![1.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![2.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![3.0; NUMBER_FEATURES]);
        // axis 1: every column holds one value from each Analysis, in order
        for column in array.0.axis_iter(Axis(1)) {
            assert_eq!(column.to_vec(), vec![1.0, 2.0, 3.0]);
        }
    }
}
550
551// #[cfg(feature = "plot_gap")]
552// fn plot_gap_statistic(data: Vec<(usize, f64, f64)>) {
553//     use plotters::prelude::*;
554
555//     // Assuming data is a Vec<(usize, f64, f64)> of (k, gap_k, s_k)
556//     let root_area = BitMapBackend::new("gap_statistic_plot.png", (640, 480)).into_drawing_area();
557//     root_area.fill(&WHITE).unwrap();
558
559//     let max_gap_k = data
560//         .iter()
561//         .map(|(_, gap_k, _)| *gap_k)
562//         .fold(f64::MIN, f64::max);
563//     let min_gap_k = data
564//         .iter()
565//         .map(|(_, gap_k, _)| *gap_k)
566//         .fold(f64::MAX, f64::min);
567//     let max_k = data.iter().map(|(k, _, _)| *k).max().unwrap_or(0);
568
569//     let mut chart = ChartBuilder::on(&root_area)
570//         .caption("Gap Statistic Plot", ("sans-serif", 30))
571//         .margin(5)
572//         .x_label_area_size(30)
573//         .y_label_area_size(30)
574//         .build_cartesian_2d(0..max_k, min_gap_k..max_gap_k)
575//         .unwrap();
576
577//     chart.configure_mesh().draw().unwrap();
578
579//     for (k, gap_k, s_k) in data {
580//         chart
581//             .draw_series(PointSeries::of_element(
582//                 vec![(k, gap_k)],
583//                 5,
584//                 &RED,
585//                 &|coord, size, style| {
586//                     EmptyElement::at(coord) + Circle::new((0, 0), size, style.filled())
587//                 },
588//             ))
589//             .unwrap();
590
591//         // Drawing error bars
592//         chart
593//             .draw_series(LineSeries::new(
594//                 vec![(k, gap_k - s_k), (k, gap_k + s_k)],
595//                 &BLACK,
596//             ))
597//             .unwrap();
598//     }
599
600//     root_area.present().unwrap();
601// }