mecomp_analysis/clustering.rs

//! This module contains helpers that wrap the `linfa_clustering` crate to perform clustering on
//! the data without having to choose an exact number of clusters.
//!
//! Instead, you provide the maximum number of clusters you want to try, and we'll use one of a
//! range of methods to determine the optimal number of clusters.
//!
//! # References:
//!
//! - The gap statistic [R. Tibshirani, G. Walther, and T. Hastie (Stanford University, 2001)](https://hastie.su.domains/Papers/gap.pdf)
//! - The Davies-Bouldin index [wikipedia](https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index)
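//!
//! # Usage (sketch)
//!
//! A minimal sketch of the intended flow, assuming `analyses: Vec<Analysis>` has already been
//! computed and that `k_max = 24` and `b = 10` suit the library size (both values are
//! illustrative); not compiled as a doctest:
//!
//! ```ignore
//! let helper = ClusteringHelper::new(
//!     analyses.clone().into(),
//!     24,
//!     KOptimal::GapStatistic { b: 10 },
//!     ClusteringMethod::KMeans,
//! )?;
//! let clusters: Vec<Vec<Analysis>> = helper
//!     .initialize()?
//!     .cluster()
//!     .extract_analysis_clusters(analyses);
//! ```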

use linfa::prelude::*;
use linfa_clustering::{GaussianMixtureModel, KMeans};
use linfa_nn::distance::{Distance, L2Dist};
use linfa_tsne::TSneParams;
use log::{debug, info};
use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis};
use ndarray_rand::RandomExt;
use rand::distributions::Uniform;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use statrs::statistics::Statistics;

use crate::{errors::ClusteringError, Analysis, Feature, NUMBER_FEATURES};

/// A newtype wrapper around a 2D array of features, with one row per `Analysis`
pub struct AnalysisArray(pub(crate) Array2<Feature>);

impl From<Vec<Analysis>> for AnalysisArray {
    #[inline]
    fn from(data: Vec<Analysis>) -> Self {
        let shape = (data.len(), NUMBER_FEATURES);
        debug_assert_eq!(shape, (data.len(), data[0].inner().len()));

        Self(
            Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
                .expect("Failed to convert to array, shape mismatch"),
        )
    }
}

impl From<Vec<[Feature; NUMBER_FEATURES]>> for AnalysisArray {
    #[inline]
    fn from(data: Vec<[Feature; NUMBER_FEATURES]>) -> Self {
        let shape = (data.len(), NUMBER_FEATURES);
        debug_assert_eq!(shape, (data.len(), data[0].len()));

        Self(
            Array2::from_shape_vec(shape, data.into_iter().flatten().collect())
                .expect("Failed to convert to array, shape mismatch"),
        )
    }
}

/// The clustering algorithm used to fit clusters to the embeddings
#[derive(Clone, Copy, Debug)]
#[allow(clippy::module_name_repetitions)]
pub enum ClusteringMethod {
    KMeans,
    GaussianMixtureModel,
}

impl ClusteringMethod {
    /// Fit the clustering method to the samples and get the labels
    #[must_use]
    fn fit(self, k: usize, samples: &Array2<Feature>) -> Array1<usize> {
        match self {
            Self::KMeans => {
                let model = KMeans::params(k)
                    // .max_n_iterations(MAX_ITERATIONS)
                    .fit(&Dataset::from(samples.clone()))
                    .unwrap();
                model.predict(samples)
            }
            Self::GaussianMixtureModel => {
                let model = GaussianMixtureModel::params(k)
                    .init_method(linfa_clustering::GmmInitMethod::KMeans)
                    .n_runs(10)
                    .fit(&Dataset::from(samples.clone()))
                    .unwrap();
                model.predict(samples)
            }
        }
    }
}

/// The method used to determine the optimal number of clusters
#[derive(Clone, Copy, Debug)]
pub enum KOptimal {
    GapStatistic {
        /// The number of reference datasets to generate
        b: usize,
    },
    DaviesBouldin,
}

/// The dimensionality of the t-SNE embedding: the base-2 log of the number of features,
/// clamped to a minimum of 2.
const EMBEDDING_SIZE: usize = {
    let log2 = usize::ilog2(NUMBER_FEATURES) as usize;
    if log2 < 2 {
        2
    } else {
        log2
    }
};

/// A typestate helper that walks through embedding the data, choosing the number of clusters,
/// and clustering: `EntryPoint` -> `NotInitialized` -> `Initialized` -> `Finished`.
#[allow(clippy::module_name_repetitions)]
pub struct ClusteringHelper<S>
where
    S: Sized,
{
    state: S,
}

pub struct EntryPoint;
pub struct NotInitialized {
    /// The embeddings of our input, as a Nx`EMBEDDING_SIZE` array
    embeddings: Array2<Feature>,
    pub k_max: usize,
    pub optimizer: KOptimal,
    pub clustering_method: ClusteringMethod,
}
pub struct Initialized {
    /// The embeddings of our input, as a Nx`EMBEDDING_SIZE` array
    embeddings: Array2<Feature>,
    pub k: usize,
    pub clustering_method: ClusteringMethod,
}
pub struct Finished {
    /// The labelings of the samples, as a Nx1 array.
    /// Each element is the cluster that the corresponding sample belongs to.
    labels: Array1<usize>,
    pub k: usize,
}

/// Functions available for the `EntryPoint` state
impl ClusteringHelper<EntryPoint> {
    /// Create a new `ClusteringHelper` object
    ///
    /// # Errors
    ///
    /// Will return an error if there was an error projecting the data into a lower-dimensional
    /// space, or if the library is too small to cluster (15 samples or fewer)
    #[allow(clippy::missing_inline_in_public_items)]
    pub fn new(
        samples: AnalysisArray,
        k_max: usize,
        optimizer: KOptimal,
        clustering_method: ClusteringMethod,
    ) -> Result<ClusteringHelper<NotInitialized>, ClusteringError> {
        // first use the t-SNE algorithm to project the data into a lower-dimensional space
        debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using t-SNE");

        if samples.0.nrows() <= 15 {
            return Err(ClusteringError::SmallLibrary);
        }

        #[allow(clippy::cast_precision_loss)]
        let mut embeddings = TSneParams::embedding_size(EMBEDDING_SIZE)
            .perplexity(f64::max(samples.0.nrows() as f64 / 20., 5.))
            .approx_threshold(0.5)
            .transform(samples.0)?;

        debug!("Embeddings shape: {:?}", embeddings.shape());

        // normalize the embeddings so each dimension is between -1 and 1
        // (assumes each embedding dimension has a nonzero range)
        debug!("Normalizing embeddings");
        for i in 0..EMBEDDING_SIZE {
            let min = embeddings.column(i).min();
            let max = embeddings.column(i).max();
            let range = max - min;
            embeddings
                .column_mut(i)
                .mapv_inplace(|v| ((v - min) / range).mul_add(2., -1.));
        }

        Ok(ClusteringHelper {
            state: NotInitialized {
                embeddings,
                k_max,
                optimizer,
                clustering_method,
            },
        })
    }
}

/// Functions available for `NotInitialized` state
impl ClusteringHelper<NotInitialized> {
    /// Initialize the `ClusteringHelper` object by computing the optimal number of clusters
    ///
    /// # Errors
    ///
    /// Will return an error if there was an error calculating the optimal number of clusters
    #[inline]
    pub fn initialize(self) -> Result<ClusteringHelper<Initialized>, ClusteringError> {
        let k = self.get_optimal_k()?;
        Ok(ClusteringHelper {
            state: Initialized {
                embeddings: self.state.embeddings,
                k,
                clustering_method: self.state.clustering_method,
            },
        })
    }

    fn get_optimal_k(&self) -> Result<usize, ClusteringError> {
        match self.state.optimizer {
            KOptimal::GapStatistic { b } => self.get_optimal_k_gap_statistic(b),
            KOptimal::DaviesBouldin => self.get_optimal_k_davies_bouldin(),
        }
    }

    /// Get the optimal number of clusters using the gap statistic
    ///
    /// # References:
    ///
    /// - [R. Tibshirani, G. Walther, and T. Hastie (Stanford University, 2001)](https://hastie.su.domains/Papers/gap.pdf)
    ///
    /// # Algorithm:
    ///
    /// 1. Cluster the observed data, varying the number of clusters from k = 1, …, kmax, and compute the corresponding total within intra-cluster variation `W_k`.
    /// 2. Generate B reference data sets with a random uniform distribution. Cluster each of these reference data sets with a varying number of clusters k = 1, …, kmax,
    ///    and compute the corresponding total within intra-cluster variation `W^*_{kb}`.
    /// 3. Compute the estimated gap statistic as the deviation of the observed `W_k` value from its expected value under the null hypothesis:
    ///    `Gap(k) = (1/B) \sum_{b=1}^{B} \log(W^*_{kb}) − \log(W_k)`.
    ///    Compute also the standard deviation of the statistics.
    /// 4. Choose the number of clusters as the smallest value of k such that the gap statistic is within one standard deviation of the gap at k+1:
    ///    `Gap(k) ≥ Gap(k+1) − s_{k+1}`.
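    ///
    /// For example (illustrative values): if `Gap(3) = 1.20`, `Gap(4) = 1.25`, and `s_4 = 0.10`,
    /// then `Gap(3) ≥ Gap(4) − s_4 = 1.15` holds, so k = 3 is selected.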
    fn get_optimal_k_gap_statistic(&self, b: usize) -> Result<usize, ClusteringError> {
        // our reference data sets
        let reference_data_sets = generate_reference_data_set(self.state.embeddings.view(), b);

        let results = (1..=self.state.k_max)
            // for each k, cluster the data into k clusters
            .map(|k| {
                debug!("Fitting k-means to embeddings with k={k}");
                let labels = self.state.clustering_method.fit(k, &self.state.embeddings);
                (k, labels)
            })
            // for each k, calculate the gap statistic, and the standard deviation of the statistics
            .map(|(k, labels)| {
                // first, we calculate the within intra-cluster variation for the observed data
                let pairwise_distances =
                    calc_pairwise_distances(self.state.embeddings.view(), k, labels.view());
                let w_k = calc_within_dispersion(labels.view(), k, pairwise_distances.view());

                // then, we calculate the within intra-cluster variation for the reference data sets
                debug!(
                    "Calculating within intra-cluster variation for reference data sets with k={k}"
                );
                // materialize the dispersions so the mean and standard deviation below are
                // computed over the same fits (a lazy iterator would re-run the clustering)
                let w_kb: Vec<Feature> = reference_data_sets
                    .par_iter()
                    .map(|ref_data| {
                        // cluster the reference data into k clusters
                        let ref_labels = self.state.clustering_method.fit(k, ref_data);
                        // calculate the within intra-cluster variation for the reference data
                        let ref_pairwise_distances =
                            calc_pairwise_distances(ref_data.view(), k, ref_labels.view());
                        calc_within_dispersion(ref_labels.view(), k, ref_pairwise_distances.view())
                    })
                    .collect();

                // finally, we calculate the gap statistic
                // (base-2 logs are used instead of the paper's natural logs; the chosen k is
                // unaffected, since gap_k and s_k scale by the same constant factor)
                let w_kb_log_sum = w_kb.iter().copied().map(f64::log2).sum::<f64>();
                // original formula: l = (1 / B) * sum_b(log(W_kb))
                #[allow(clippy::cast_precision_loss)]
                let l = (1.0 / b as f64) * w_kb_log_sum;
                // original formula: gap_k = (1 / B) * sum_b(log(W_kb)) - log(W_k)
                let gap_k = l - w_k.log2();
                // original formula: sd_k = [(1 / B) * sum_b((log(W_kb) - l)^2)]^0.5
                #[allow(clippy::cast_precision_loss)]
                let standard_deviation = ((1.0 / b as f64)
                    * w_kb.iter().map(|w_kb| (w_kb.log2() - l).powi(2)).sum::<f64>())
                .sqrt();
                // original formula: s_k = sd_k * (1 + 1 / B)^0.5
                #[allow(clippy::cast_precision_loss)]
                let s_k = standard_deviation * (1.0 + 1.0 / b as f64).sqrt();

                (k, gap_k, s_k)
            });

        // // plot the gap_k (whisker with s_k) w.r.t. k
        // #[cfg(feature = "plot_gap")]
        // plot_gap_statistic(results.clone().collect::<Vec<_>>());

        // finally, we go over the iterator to find the optimal k
        let (mut optimal_k, mut gap_k_minus_one) = (None, None);

        for (k, gap_k, s_k) in results {
            info!("k: {k}, gap_k: {gap_k}, s_k: {s_k}");

            if let Some(gap_k_minus_one) = gap_k_minus_one {
                if gap_k_minus_one >= gap_k - s_k {
                    info!("Optimal k found: {}", k - 1);
                    optimal_k = Some(k - 1);
                    break;
                }
            }
            gap_k_minus_one = Some(gap_k);
        }

        optimal_k.ok_or(ClusteringError::OptimalKNotFound(self.state.k_max))
    }

    fn get_optimal_k_davies_bouldin(&self) -> Result<usize, ClusteringError> {
        todo!();
    }
}

/// Convert a vector of Analyses into a 2D array
///
/// # Panics
///
/// Will panic if the shape of the data does not match the number of features, which should never happen
#[must_use]
#[inline]
pub fn convert_to_array(data: Vec<Analysis>) -> AnalysisArray {
    // Convert vector to Array
    let shape = (data.len(), NUMBER_FEATURES);
    debug_assert_eq!(shape, (data.len(), data[0].inner().len()));

    AnalysisArray(
        Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
            .expect("Failed to convert to array, shape mismatch"),
    )
}

/// Generate B reference data sets with a random uniform distribution
///
/// (excerpt from the reference paper)
/// """
/// We consider two choices for the reference distribution:
///
/// 1. generate each reference feature uniformly over the range of the observed values for that feature.
/// 2. generate the reference features from a uniform distribution over a box aligned with the
///    principal components of the data.
///    In detail, if X is our n by p data matrix, we assume that the columns have mean 0 and compute
///    the singular value decomposition X = UDV^T. We transform via X' = XV and then draw uniform features Z'
///    over the ranges of the columns of X', as in method (1) above.
///    Finally, we back-transform via Z = Z'V^T to give reference data Z.
///
/// Method (1) has the advantage of simplicity. Method (2) takes into account the shape of the
/// data distribution and makes the procedure rotationally invariant, as long as the
/// clustering method itself is invariant.
/// """
///
/// We use method (1): it is simpler to implement, our data is already normalized, and the
/// ordering of features is meaningful, so we couldn't rotate the data anyway.
fn generate_reference_data_set(samples: ArrayView2<Feature>, b: usize) -> Vec<Array2<f64>> {
    let mut reference_data_sets = Vec::with_capacity(b);
    for _ in 0..b {
        reference_data_sets.push(generate_ref_single(samples));
    }

    reference_data_sets
}

fn generate_ref_single(samples: ArrayView2<Feature>) -> Array2<f64> {
    // draw each feature uniformly over the observed range of that feature (method (1) above)
    let feature_distributions = samples
        .axis_iter(Axis(1))
        .map(|feature| Array::random(feature.dim(), Uniform::new(feature.min(), feature.max())))
        .collect::<Vec<_>>();
    let feature_dists_views = feature_distributions
        .iter()
        .map(ndarray::ArrayBase::view)
        .collect::<Vec<_>>();
    // stack the per-feature rows, then transpose back to an (n samples) x (p features) matrix
    ndarray::stack(Axis(0), &feature_dists_views)
        .unwrap()
        .t()
        .to_owned()
}

/// Calculate `W_k`, the within intra-cluster variation for the given clustering
///
/// `W_k = \sum_{r=1}^{k} \frac{D_r}{2 n_r}`
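///
/// For example (illustrative values): with k = 2, `D_1 = D_2 = 2.0`, and cluster sizes
/// `n_1 = n_2 = 2`, this gives `W_k = 2.0/(2*2) + 2.0/(2*2) = 1.0`.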
fn calc_within_dispersion(
    labels: ArrayView1<usize>,
    k: usize,
    pairwise_distances: ArrayView1<Feature>,
) -> Feature {
    debug_assert_eq!(k, labels.iter().max().unwrap() + 1);

    // we first need to convert our list of labels into a list of the number of samples in each cluster
    let counts = labels.iter().fold(vec![0u32; k], |mut counts, &label| {
        counts[label] += 1;
        counts
    });
    // then, we calculate the within intra-cluster variation
    counts
        .iter()
        .zip(pairwise_distances.iter())
        .map(|(&count, distance)| (1. / (2.0 * f64::from(count))) * distance)
        .sum()
}

/// Calculate the `D_r` array, the sum of the pairwise distances in cluster r, for all clusters in the given clustering
///
/// # Arguments
///
/// - `samples`: The samples in the dataset
/// - `k`: The number of clusters
/// - `labels`: The cluster labels for each sample
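///
/// For example (illustrative): two samples `[1.0, 2.0]` and `[1.0, 1.0]` alone in a cluster are at
/// L2 distance 1.0, so that cluster's entry is `2.0 * 1.0 = 2.0` (each unordered pair counts twice).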
fn calc_pairwise_distances(
    samples: ArrayView2<Feature>,
    k: usize,
    labels: ArrayView1<usize>,
) -> Array1<Feature> {
    debug_assert_eq!(
        samples.nrows(),
        labels.len(),
        "Samples and labels must have the same length"
    );
    debug_assert_eq!(
        k,
        labels.iter().max().unwrap() + 1,
        "Labels must be in the range 0..k"
    );

    // for each cluster, calculate the sum of the pairwise distances between samples in that cluster
    (0..k)
        .map(|label| {
            (
                label,
                samples
                    .outer_iter()
                    .zip(labels.iter())
                    .filter_map(|(s, &l)| (l == label).then_some(s))
                    .collect::<Vec<_>>(),
            )
        })
        .fold(Array1::zeros(k), |mut distances, (label, cluster)| {
            // sum the L2 distance over each unordered pair once, then double it:
            // `D_r` in the paper sums over ordered pairs
            distances[label] += cluster
                .iter()
                .enumerate()
                .map(|(i, &a)| {
                    cluster
                        .iter()
                        .skip(i + 1)
                        .map(|&b| L2Dist.distance(a, b))
                        .sum::<Feature>()
                })
                .sum::<Feature>()
                * 2.;
            distances
        })
}

/// Functions available for the `Initialized` state
impl ClusteringHelper<Initialized> {
    /// Cluster the data into k clusters
    ///
    /// # Panics
    ///
    /// Will panic if the underlying clustering model fails to fit
    #[must_use]
    #[inline]
    pub fn cluster(self) -> ClusteringHelper<Finished> {
        let labels = self
            .state
            .clustering_method
            .fit(self.state.k, &self.state.embeddings);

        ClusteringHelper {
            state: Finished {
                labels,
                k: self.state.k,
            },
        }
    }
}

/// Functions available for the `Finished` state
impl ClusteringHelper<Finished> {
    /// Use the labels to reorganize the provided samples into clusters
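    ///
    /// For example (illustrative): with `k = 2`, labels `[0, 1, 0]`, and samples `[a, b, c]`, the
    /// result is `[[a, c], [b]]`.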
    #[must_use]
    #[inline]
    pub fn extract_analysis_clusters<T: Clone>(&self, samples: Vec<T>) -> Vec<Vec<T>> {
        let mut clusters = vec![Vec::new(); self.state.k];

        for (sample, &label) in samples.into_iter().zip(self.state.labels.iter()) {
            clusters[label].push(sample);
        }

        clusters
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, s};
    use pretty_assertions::assert_eq;

    #[test]
    fn test_generate_reference_data_set() {
        let data = arr2(&[[10.0, -10.0], [20.0, -20.0], [30.0, -30.0]]);

        let ref_data = generate_ref_single(data.view());

        // First column: all values between 10.0 and 30.0
        assert!(ref_data
            .slice(s![.., 0])
            .iter()
            .all(|v| *v >= 10.0 && *v <= 30.0));

        // Second column: all values between -30.0 and -10.0
        assert!(ref_data
            .slice(s![.., 1])
            .iter()
            .all(|v| *v <= -10.0 && *v >= -30.0));

        // check that the shape is correct
        assert_eq!(ref_data.shape(), data.shape());

        // check that the data is not the same as the original data
        assert_ne!(ref_data, data);
    }

    #[test]
    fn test_pairwise_distances() {
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let labels = arr1(&[0, 0, 1, 1]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 0.0).abs(),
            "{} != 0.0",
            pairwise_distances[1]
        );

        let samples = arr2(&[[1.0, 2.0], [1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]);

        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        assert!(
            f64::EPSILON > (pairwise_distances[0] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[0]
        );
        assert!(
            f64::EPSILON > (pairwise_distances[1] - 2.0).abs(),
            "{} != 2.0",
            pairwise_distances[1]
        );
    }

    #[test]
    fn test_convert_to_array() {
        let data = vec![
            Analysis::new([1.0; NUMBER_FEATURES]),
            Analysis::new([2.0; NUMBER_FEATURES]),
            Analysis::new([3.0; NUMBER_FEATURES]),
        ];

        let array = convert_to_array(data);

        assert_eq!(array.0.shape(), &[3, NUMBER_FEATURES]);
        assert!(
            f64::EPSILON > (array.0[[0, 0]] - 1.0).abs(),
            "{} != 1.0",
            array.0[[0, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[1, 0]] - 2.0).abs(),
            "{} != 2.0",
            array.0[[1, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[2, 0]] - 3.0).abs(),
            "{} != 3.0",
            array.0[[2, 0]]
        );

        // check that axis iteration works how we expect
        // axis 0 (rows)
        let mut iter = array.0.axis_iter(Axis(0));
        assert_eq!(iter.next().unwrap().to_vec(), vec![1.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![2.0; NUMBER_FEATURES]);
        assert_eq!(iter.next().unwrap().to_vec(), vec![3.0; NUMBER_FEATURES]);
        // axis 1 (columns)
        for column in array.0.axis_iter(Axis(1)) {
            assert_eq!(column.to_vec(), vec![1.0, 2.0, 3.0]);
        }
    }
}

// #[cfg(feature = "plot_gap")]
// fn plot_gap_statistic(data: Vec<(usize, f64, f64)>) {
//     use plotters::prelude::*;

//     // Assuming data is a Vec<(usize, f64, f64)> of (k, gap_k, s_k)
//     let root_area = BitMapBackend::new("gap_statistic_plot.png", (640, 480)).into_drawing_area();
//     root_area.fill(&WHITE).unwrap();

//     let max_gap_k = data
//         .iter()
//         .map(|(_, gap_k, _)| *gap_k)
//         .fold(f64::MIN, f64::max);
//     let min_gap_k = data
//         .iter()
//         .map(|(_, gap_k, _)| *gap_k)
//         .fold(f64::MAX, f64::min);
//     let max_k = data.iter().map(|(k, _, _)| *k).max().unwrap_or(0);

//     let mut chart = ChartBuilder::on(&root_area)
//         .caption("Gap Statistic Plot", ("sans-serif", 30))
//         .margin(5)
//         .x_label_area_size(30)
//         .y_label_area_size(30)
//         .build_cartesian_2d(0..max_k, min_gap_k..max_gap_k)
//         .unwrap();

//     chart.configure_mesh().draw().unwrap();

//     for (k, gap_k, s_k) in data {
//         chart
//             .draw_series(PointSeries::of_element(
//                 vec![(k, gap_k)],
//                 5,
//                 &RED,
//                 &|coord, size, style| {
//                     EmptyElement::at(coord) + Circle::new((0, 0), size, style.filled())
//                 },
//             ))
//             .unwrap();

//         // Drawing error bars
//         chart
//             .draw_series(LineSeries::new(
//                 vec![(k, gap_k - s_k), (k, gap_k + s_k)],
//                 &BLACK,
//             ))
//             .unwrap();
//     }

//     root_area.present().unwrap();
// }