mecomp_analysis/
clustering.rs

//! This module contains helpers that wrap the `linfa-clustering` crate to perform clustering on the data
//! without having to choose an exact number of clusters.
//!
//! Instead, you provide the minimum and maximum number of clusters you want to try, and we'll
//! use one of a range of methods to determine the optimal number of clusters.
//!
//! # References:
//!
//! - The gap statistic [R. Tibshirani, G. Walther, and T. Hastie (Stanford University, 2001)](https://hastie.su.domains/Papers/gap.pdf)
//! - The Davies-Bouldin index [wikipedia](https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index)
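//!
//! # Example:
//!
//! A minimal end-to-end sketch (marked `ignore`: it assumes a `Vec<Analysis>` named
//! `analyses` is already in scope, and the `k_max`/`b` values are illustrative):
//!
//! ```ignore
//! use mecomp_analysis::clustering::{ClusteringHelper, ClusteringMethod, KOptimal};
//!
//! let clusters: Vec<Vec<Analysis>> = ClusteringHelper::<EntryPoint>::new(
//!     analyses.clone().into(),          // Vec<Analysis> -> AnalysisArray
//!     16,                               // k_max: the largest k to try
//!     KOptimal::GapStatistic { b: 10 }, // b: number of reference data sets
//!     ClusteringMethod::KMeans,
//! )?
//! .initialize()? // determine the optimal k
//! .cluster()     // fit the final model
//! .extract_analysis_clusters(analyses);
//! ```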

use linfa::prelude::*;
use linfa_clustering::{GaussianMixtureModel, KMeans};
use linfa_nn::distance::{Distance, L2Dist};
use linfa_tsne::TSneParams;
use log::{debug, info};
use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis};
use ndarray_rand::RandomExt;
use rand::distributions::Uniform;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use statrs::statistics::Statistics;

use crate::{Analysis, Feature, NUMBER_FEATURES, errors::ClusteringError};

pub struct AnalysisArray(pub(crate) Array2<Feature>);

impl From<Vec<Analysis>> for AnalysisArray {
    #[inline]
    fn from(data: Vec<Analysis>) -> Self {
        let shape = (data.len(), NUMBER_FEATURES);
        debug_assert_eq!(shape, (data.len(), data[0].inner().len()));

        Self(
            Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
                .expect("Failed to convert to array, shape mismatch"),
        )
    }
}

impl From<Vec<[Feature; NUMBER_FEATURES]>> for AnalysisArray {
    #[inline]
    fn from(data: Vec<[Feature; NUMBER_FEATURES]>) -> Self {
        let shape = (data.len(), NUMBER_FEATURES);
        debug_assert_eq!(shape, (data.len(), data[0].len()));

        Self(
            Array2::from_shape_vec(shape, data.into_iter().flatten().collect())
                .expect("Failed to convert to array, shape mismatch"),
        )
    }
}

#[derive(Clone, Copy, Debug)]
#[allow(clippy::module_name_repetitions)]
pub enum ClusteringMethod {
    KMeans,
    GaussianMixtureModel,
}

impl ClusteringMethod {
    /// Fit the clustering method to the samples and get the labels
    #[must_use]
    fn fit(self, k: usize, samples: &Array2<Feature>) -> Array1<usize> {
        match self {
            Self::KMeans => {
                let model = KMeans::params(k)
                    // .max_n_iterations(MAX_ITERATIONS)
                    .fit(&Dataset::from(samples.clone()))
                    .unwrap();
                model.predict(samples)
            }
            Self::GaussianMixtureModel => {
                let model = GaussianMixtureModel::params(k)
                    .init_method(linfa_clustering::GmmInitMethod::KMeans)
                    .n_runs(10)
                    .fit(&Dataset::from(samples.clone()))
                    .unwrap();
                model.predict(samples)
            }
        }
    }
}

#[derive(Clone, Copy, Debug)]
pub enum KOptimal {
    GapStatistic {
        /// The number of reference datasets to generate
        b: usize,
    },
    DaviesBouldin,
}

// The t-SNE embedding size: the base-2 logarithm of the number of features,
// clamped to a minimum of 2.
const EMBEDDING_SIZE: usize =
    //  2;
    {
        let log2 = usize::ilog2(NUMBER_FEATURES) as usize;
        if log2 < 2 { 2 } else { log2 }
    };
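// e.g., if NUMBER_FEATURES were 20, ilog2(20) = 4, so the embeddings would be 4-dimensional.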

#[allow(clippy::module_name_repetitions)]
pub struct ClusteringHelper<S>
where
    S: Sized,
{
    state: S,
}

pub struct EntryPoint;
pub struct NotInitialized {
    /// The embeddings of our input, as an Nx`EMBEDDING_SIZE` array
    embeddings: Array2<Feature>,
    pub k_max: usize,
    pub optimizer: KOptimal,
    pub clustering_method: ClusteringMethod,
}
pub struct Initialized {
    /// The embeddings of our input, as an Nx`EMBEDDING_SIZE` array
    embeddings: Array2<Feature>,
    pub k: usize,
    pub clustering_method: ClusteringMethod,
}
pub struct Finished {
    /// The labelings of the samples, as an Nx1 array.
    /// Each element is the cluster that the corresponding sample belongs to.
    labels: Array1<usize>,
    pub k: usize,
}

/// Functions available for the `EntryPoint` state
impl ClusteringHelper<EntryPoint> {
    /// Create a new `ClusteringHelper` object
    ///
    /// # Errors
    ///
    /// Will return an error if there was an error projecting the data into a lower-dimensional space
    #[allow(clippy::missing_inline_in_public_items)]
    pub fn new(
        samples: AnalysisArray,
        k_max: usize,
        optimizer: KOptimal,
        clustering_method: ClusteringMethod,
    ) -> Result<ClusteringHelper<NotInitialized>, ClusteringError> {
        // first use the t-SNE algorithm to project the data into a lower-dimensional space
        debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using t-SNE");

        if samples.0.nrows() <= 15 {
            return Err(ClusteringError::SmallLibrary);
        }

        #[allow(clippy::cast_precision_loss)]
        let mut embeddings = TSneParams::embedding_size(EMBEDDING_SIZE)
            .perplexity(f64::max(samples.0.nrows() as f64 / 20., 5.))
            .approx_threshold(0.5)
            .transform(samples.0)?;

        debug!("Embeddings shape: {:?}", embeddings.shape());

        // normalize the embeddings so each dimension is between -1 and 1
        debug!("Normalizing embeddings");
        for i in 0..EMBEDDING_SIZE {
            let min = embeddings.column(i).min();
            let max = embeddings.column(i).max();
            let range = max - min;
            embeddings
                .column_mut(i)
                .mapv_inplace(|v| ((v - min) / range).mul_add(2., -1.));
        }

        Ok(ClusteringHelper {
            state: NotInitialized {
                embeddings,
                k_max,
                optimizer,
                clustering_method,
            },
        })
    }
}

/// Functions available for the `NotInitialized` state
impl ClusteringHelper<NotInitialized> {
    /// Initialize the `ClusteringHelper` object
    ///
    /// # Errors
    ///
    /// Will return an error if there was an error calculating the optimal number of clusters
    #[inline]
    pub fn initialize(self) -> Result<ClusteringHelper<Initialized>, ClusteringError> {
        let k = self.get_optimal_k()?;
        Ok(ClusteringHelper {
            state: Initialized {
                embeddings: self.state.embeddings,
                k,
                clustering_method: self.state.clustering_method,
            },
        })
    }

    fn get_optimal_k(&self) -> Result<usize, ClusteringError> {
        match self.state.optimizer {
            KOptimal::GapStatistic { b } => self.get_optimal_k_gap_statistic(b),
            KOptimal::DaviesBouldin => self.get_optimal_k_davies_bouldin(),
        }
    }

    /// Get the optimal number of clusters using the gap statistic
    ///
    /// # References:
    ///
    /// - [R. Tibshirani, G. Walther, and T. Hastie (Stanford University, 2001)](https://hastie.su.domains/Papers/gap.pdf)
    ///
    /// # Algorithm:
    ///
    /// 1. Cluster the observed data, varying the number of clusters from k = 1, …, `k_max`, and compute the corresponding total within intra-cluster variation `W_k`.
    /// 2. Generate B reference data sets with a random uniform distribution. Cluster each of these reference data sets with varying number of clusters k = 1, …, `k_max`,
    ///    and compute the corresponding total within intra-cluster variation `W^*_{kb}`.
    /// 3. Compute the estimated gap statistic as the deviation of the observed `W_k` value from its expected value `W^*_{kb}` under the null hypothesis:
    ///    `Gap(k) = (1/B) \sum_{b=1}^{B} \log(W^*_{kb}) − \log(W_k)`.
    ///    Compute also the standard deviation of the statistics.
    /// 4. Choose the number of clusters as the smallest value of k such that the gap statistic is within one standard deviation of the gap at k+1:
    ///    `Gap(k) ≥ Gap(k + 1) − s_{k + 1}`.
    fn get_optimal_k_gap_statistic(&self, b: usize) -> Result<usize, ClusteringError> {
        // our reference data sets
        let reference_data_sets = generate_reference_data_set(self.state.embeddings.view(), b);

        let results = (1..=self.state.k_max)
            // for each k, cluster the data into k clusters
            .map(|k| {
                debug!("Fitting k-means to embeddings with k={k}");
                let labels = self.state.clustering_method.fit(k, &self.state.embeddings);
                (k, labels)
            })
            // for each k, calculate the gap statistic, and the standard deviation of the statistics
            .map(|(k, labels)| {
                // first, we calculate the within intra-cluster variation for the observed data
                let pairwise_distances =
                    calc_pairwise_distances(self.state.embeddings.view(), k, labels.view());
                let w_k = calc_within_dispersion(labels.view(), k, pairwise_distances.view());

                // then, we calculate the within intra-cluster variation for the reference data sets
                debug!(
                    "Calculating within intra-cluster variation for reference data sets with k={k}"
                );
                let w_kb = reference_data_sets.par_iter().map(|ref_data| {
                    // cluster the reference data into k clusters
                    let ref_labels = self.state.clustering_method.fit(k, ref_data);
                    // calculate the within intra-cluster variation for the reference data
                    let ref_pairwise_distances =
                        calc_pairwise_distances(ref_data.view(), k, ref_labels.view());
                    calc_within_dispersion(ref_labels.view(), k, ref_pairwise_distances.view())
                });

                // finally, we calculate the gap statistic
                let w_kb_log_sum = w_kb.clone().map(f64::log2).sum::<f64>();
                // original formula: l = (1 / B) * sum_b(log(W_kb))
                #[allow(clippy::cast_precision_loss)]
                let l = (1.0 / b as f64) * w_kb_log_sum;
                // original formula: gap_k = (1 / B) * sum_b(log(W_kb)) - log(W_k)
                let gap_k = l - w_k.log2();
                // original formula: sd_k = [(1 / B) * sum_b((log(W_kb) - l)^2)]^0.5
                #[allow(clippy::cast_precision_loss)]
                let standard_deviation = ((1.0 / b as f64)
                    * w_kb.map(|w_kb| (w_kb.log2() - l).powi(2)).sum::<f64>())
                .sqrt();
                // original formula: s_k = sd_k * (1 + 1 / B)^0.5
                // calculate differently to minimize rounding errors
                #[allow(clippy::cast_precision_loss)]
                let s_k = standard_deviation * (1.0 + 1.0 / b as f64).sqrt();

                (k, gap_k, s_k)
            });

        // // plot the gap_k (whisker with s_k) w.r.t. k
        // #[cfg(feature = "plot_gap")]
        // plot_gap_statistic(results.clone().collect::<Vec<_>>());

        // finally, we go over the iterator to find the optimal k
        let (mut optimal_k, mut gap_k_minus_one) = (None, None);

        for (k, gap_k, s_k) in results {
            info!("k: {k}, gap_k: {gap_k}, s_k: {s_k}");

            if let Some(gap_k_minus_one) = gap_k_minus_one {
                if gap_k_minus_one >= gap_k - s_k {
                    info!("Optimal k found: {}", k - 1);
                    optimal_k = Some(k - 1);
                    break;
                }
            }
            gap_k_minus_one = Some(gap_k);
        }

        optimal_k.ok_or(ClusteringError::OptimalKNotFound(self.state.k_max))
    }

    fn get_optimal_k_davies_bouldin(&self) -> Result<usize, ClusteringError> {
        todo!();
    }
}

/// Convert a vector of Analyses into a 2D array
///
/// # Panics
///
/// Will panic if the shape of the data does not match the number of features, which should never happen
#[must_use]
#[inline]
pub fn convert_to_array(data: Vec<Analysis>) -> AnalysisArray {
    // Convert vector to Array
    let shape = (data.len(), NUMBER_FEATURES);
    debug_assert_eq!(shape, (data.len(), data[0].inner().len()));

    AnalysisArray(
        Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
            .expect("Failed to convert to array, shape mismatch"),
    )
}

/// Generate B reference data sets with a random uniform distribution
///
/// (excerpt from reference paper)
/// """
/// We consider two choices for the reference distribution:
///
/// 1. generate each reference feature uniformly over the range of the observed values for that feature.
/// 2. generate the reference features from a uniform distribution over a box aligned with the
///    principal components of the data.
///    In detail, if X is our n by p data matrix, we assume that the columns have mean 0 and compute
///    the singular value decomposition X = UDV^T. We transform via X' = XV and then draw uniform features Z'
///    over the ranges of the columns of X', as in method (1) above.
///    Finally, we back-transform via Z=Z'V^T to give reference data Z.
///
/// Method (1) has the advantage of simplicity. Method (2) takes into account the shape of the
/// data distribution and makes the procedure rotationally invariant, as long as the
/// clustering method itself is invariant
/// """
///
/// For now, we will use method (1) as it is simpler to implement
/// and we know that our data is already normalized and that
/// the ordering of features is important, meaning that we can't
/// rotate the data anyway.
fn generate_reference_data_set(samples: ArrayView2<Feature>, b: usize) -> Vec<Array2<f64>> {
    let mut reference_data_sets = Vec::with_capacity(b);
    for _ in 0..b {
        reference_data_sets.push(generate_ref_single(samples));
    }

    reference_data_sets
}

fn generate_ref_single(samples: ArrayView2<Feature>) -> Array2<f64> {
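    // Sample each feature independently from a uniform distribution over that
    // feature's observed [min, max] range, then stack the per-feature columns
    // back into an (n, p) matrix.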
    let feature_distributions = samples
        .axis_iter(Axis(1))
        .map(|feature| Array::random(feature.dim(), Uniform::new(feature.min(), feature.max())))
        .collect::<Vec<_>>();
    let feature_dists_views = feature_distributions
        .iter()
        .map(ndarray::ArrayBase::view)
        .collect::<Vec<_>>();
    ndarray::stack(Axis(0), &feature_dists_views)
        .unwrap()
        .t()
        .to_owned()
}

/// Calculate `W_k`, the within intra-cluster variation for the given clustering
///
/// `W_k = \sum_{r=1}^{k} \frac{D_r}{2 n_r}`
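///
/// For example (illustrative numbers): with k = 2, per-cluster pairwise-distance
/// sums `D = [2.0, 2.0]`, and cluster sizes `n = [2, 2]`, we get
/// `W_k = 2.0 / (2 * 2) + 2.0 / (2 * 2) = 1.0`.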
fn calc_within_dispersion(
    labels: ArrayView1<usize>,
    k: usize,
    pairwise_distances: ArrayView1<Feature>,
) -> Feature {
    debug_assert_eq!(k, labels.iter().max().unwrap() + 1);

    // we first need to convert our list of labels into a list of the number of samples in each cluster
    let counts = labels.iter().fold(vec![0u32; k], |mut counts, &label| {
        counts[label] += 1;
        counts
    });
    // then, we calculate the within intra-cluster variation
    counts
        .iter()
        .zip(pairwise_distances.iter())
        .map(|(&count, distance)| (1. / (2.0 * f64::from(count))) * distance)
        .sum()
}

/// Calculate the `D_r` array, the sum of the pairwise distances in cluster r, for all clusters in the given clustering
///
/// # Arguments
///
/// - `samples`: The samples in the dataset
/// - `k`: The number of clusters
/// - `labels`: The cluster labels for each sample
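///
/// Note: each unordered pair within a cluster is summed once and then doubled, so
/// `D_r` counts every ordered pair; the `2 n_r` denominator in
/// `calc_within_dispersion` divides that factor back out.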
fn calc_pairwise_distances(
    samples: ArrayView2<Feature>,
    k: usize,
    labels: ArrayView1<usize>,
) -> Array1<Feature> {
    debug_assert_eq!(
        samples.nrows(),
        labels.len(),
        "Samples and labels must have the same length"
    );
    debug_assert_eq!(
        k,
        labels.iter().max().unwrap() + 1,
        "Labels must be in the range 0..k"
    );

    // for each cluster, collect the samples that belong to it
    (0..k)
        .map(|label| {
            (
                label,
                samples
                    .outer_iter()
                    .zip(labels.iter())
                    .filter_map(|(s, &l)| (l == label).then_some(s))
                    .collect::<Vec<_>>(),
            )
        })
        // then sum the pairwise distances between samples in each cluster
        .fold(Array1::zeros(k), |mut distances, (label, cluster)| {
            distances[label] += cluster
                .iter()
                .enumerate()
                .map(|(i, &a)| {
                    cluster
                        .iter()
                        .skip(i + 1)
                        .map(|&b| L2Dist.distance(a, b))
                        .sum::<Feature>()
                })
                .sum::<Feature>()
                * 2.;
            distances
        })
}

/// Functions available for the `Initialized` state
impl ClusteringHelper<Initialized> {
    /// Cluster the data into k clusters
    ///
    /// # Panics
    ///
    /// Will panic if the underlying clustering model fails to fit the data
    #[must_use]
    #[inline]
    pub fn cluster(self) -> ClusteringHelper<Finished> {
        let labels = self
            .state
            .clustering_method
            .fit(self.state.k, &self.state.embeddings);

        ClusteringHelper {
            state: Finished {
                labels,
                k: self.state.k,
            },
        }
    }
}

/// Functions available for the `Finished` state
impl ClusteringHelper<Finished> {
    /// Use the labels to reorganize the provided samples into clusters
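    ///
    /// For example (illustrative): with internal labels `[0, 1, 0]` and
    /// `samples = vec!["a", "b", "c"]`, the result is `vec![vec!["a", "c"], vec!["b"]]`.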
    #[must_use]
    #[inline]
    pub fn extract_analysis_clusters<T: Clone>(&self, samples: Vec<T>) -> Vec<Vec<T>> {
        let mut clusters = vec![Vec::new(); self.state.k];

        for (sample, &label) in samples.into_iter().zip(self.state.labels.iter()) {
            clusters[label].push(sample);
        }

        clusters
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, s};
    use pretty_assertions::assert_eq;

    #[test]
    fn test_generate_reference_data_set() {
        let data = arr2(&[[10.0, -10.0], [20.0, -20.0], [30.0, -30.0]]);

        let ref_data = generate_ref_single(data.view());

        // First column: all values between 10.0 and 30.0
        assert!(
            ref_data
                .slice(s![.., 0])
                .iter()
                .all(|v| *v >= 10.0 && *v <= 30.0)
        );

        // Second column: all values between -30.0 and -10.0
        assert!(
            ref_data
                .slice(s![.., 1])
                .iter()
                .all(|v| *v <= -10.0 && *v >= -30.0)
        );

        // check that the shape is correct
        assert_eq!(ref_data.shape(), data.shape());

        // check that the data is not the same as the original data
        assert_ne!(ref_data, data);
    }

    #[test]
    fn test_pairwise_distances() {
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 0, 1, 1]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[1]
        );

        let samples = arr2(&[[1.0, 2.0], [1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[1]
        );
    }

    #[test]
    fn test_convert_to_array() {
        let data = vec![
            Analysis::new([1.0; NUMBER_FEATURES]),
            Analysis::new([2.0; NUMBER_FEATURES]),
            Analysis::new([3.0; NUMBER_FEATURES]),
        ];

        let array = convert_to_array(data);

        assert_eq!(array.0.shape(), &[3, NUMBER_FEATURES]);
        assert!(
            f64::EPSILON > (array.0[[0, 0]] - 1.0).abs(),
            "{} != 1.0",
            array.0[[0, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[1, 0]] - 2.0).abs(),
            "{} != 2.0",
            array.0[[1, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[2, 0]] - 3.0).abs(),
            "{} != 3.0",
            array.0[[2, 0]]
        );

        // check that axis iteration works how we expect
        // axis 0
        let mut iter = array.0.axis_iter(Axis(0));
        assert_eq!(iter.next().unwrap().to_vec(), vec![1.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![2.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![3.0; NUMBER_FEATURES]);
        // axis 1
        for column in array.0.axis_iter(Axis(1)) {
            assert_eq!(column.to_vec(), vec![1.0, 2.0, 3.0]);
        }
    }
}

// #[cfg(feature = "plot_gap")]
// fn plot_gap_statistic(data: Vec<(usize, f64, f64)>) {
//     use plotters::prelude::*;

//     // Assuming data is a Vec<(usize, f64, f64)> of (k, gap_k, s_k)
//     let root_area = BitMapBackend::new("gap_statistic_plot.png", (640, 480)).into_drawing_area();
//     root_area.fill(&WHITE).unwrap();

//     let max_gap_k = data
//         .iter()
//         .map(|(_, gap_k, _)| *gap_k)
//         .fold(f64::MIN, f64::max);
//     let min_gap_k = data
//         .iter()
//         .map(|(_, gap_k, _)| *gap_k)
//         .fold(f64::MAX, f64::min);
//     let max_k = data.iter().map(|(k, _, _)| *k).max().unwrap_or(0);

//     let mut chart = ChartBuilder::on(&root_area)
//         .caption("Gap Statistic Plot", ("sans-serif", 30))
//         .margin(5)
//         .x_label_area_size(30)
//         .y_label_area_size(30)
//         .build_cartesian_2d(0..max_k, min_gap_k..max_gap_k)
//         .unwrap();

//     chart.configure_mesh().draw().unwrap();

//     for (k, gap_k, s_k) in data {
//         chart
//             .draw_series(PointSeries::of_element(
//                 vec![(k, gap_k)],
//                 5,
//                 &RED,
//                 &|coord, size, style| {
//                     EmptyElement::at(coord) + Circle::new((0, 0), size, style.filled())
//                 },
//             ))
//             .unwrap();

//         // Drawing error bars
//         chart
//             .draw_series(LineSeries::new(
//                 vec![(k, gap_k - s_k), (k, gap_k + s_k)],
//                 &BLACK,
//             ))
//             .unwrap();
//     }

//     root_area.present().unwrap();
// }