1use linfa::prelude::*;
13use linfa_clustering::{GaussianMixtureModel, KMeans};
14use linfa_nn::distance::{Distance, L2Dist};
15use linfa_tsne::TSneParams;
16use log::{debug, info};
17use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis};
18use ndarray_rand::RandomExt;
19use rand::distributions::Uniform;
20use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
21use statrs::statistics::Statistics;
22
23use crate::{Analysis, Feature, NUMBER_FEATURES, errors::ClusteringError};
24
25pub struct AnalysisArray(pub(crate) Array2<Feature>);
26
27impl From<Vec<Analysis>> for AnalysisArray {
28 #[inline]
29 fn from(data: Vec<Analysis>) -> Self {
30 let shape = (data.len(), NUMBER_FEATURES);
31 debug_assert_eq!(shape, (data.len(), data[0].inner().len()));
32
33 Self(
34 Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
35 .expect("Failed to convert to array, shape mismatch"),
36 )
37 }
38}
39
40impl From<Vec<[Feature; NUMBER_FEATURES]>> for AnalysisArray {
41 #[inline]
42 fn from(data: Vec<[Feature; NUMBER_FEATURES]>) -> Self {
43 let shape = (data.len(), NUMBER_FEATURES);
44 debug_assert_eq!(shape, (data.len(), data[0].len()));
45
46 Self(
47 Array2::from_shape_vec(shape, data.into_iter().flatten().collect())
48 .expect("Failed to convert to array, shape mismatch"),
49 )
50 }
51}
52
/// The clustering algorithm used to assign samples to clusters.
#[derive(Clone, Copy, Debug)]
#[allow(clippy::module_name_repetitions)]
pub enum ClusteringMethod {
    /// Cluster with linfa's `KMeans`.
    KMeans,
    /// Cluster with linfa's `GaussianMixtureModel`.
    GaussianMixtureModel,
}
59
60impl ClusteringMethod {
61 #[must_use]
63 fn fit(self, k: usize, samples: &Array2<Feature>) -> Array1<usize> {
64 match self {
65 Self::KMeans => {
66 let model = KMeans::params(k)
67 .fit(&Dataset::from(samples.clone()))
69 .unwrap();
70 model.predict(samples)
71 }
72 Self::GaussianMixtureModel => {
73 let model = GaussianMixtureModel::params(k)
74 .init_method(linfa_clustering::GmmInitMethod::KMeans)
75 .n_runs(10)
76 .fit(&Dataset::from(samples.clone()))
77 .unwrap();
78 model.predict(samples)
79 }
80 }
81 }
82}
83
/// Strategy used to pick the number of clusters.
#[derive(Clone, Copy, Debug)]
pub enum KOptimal {
    /// Choose `k` with the gap statistic.
    GapStatistic {
        /// Number of reference data sets to generate.
        b: usize,
    },
    /// Choose `k` with the Davies-Bouldin index (currently a `todo!()` in this module).
    DaviesBouldin,
}
92
93const EMBEDDING_SIZE: usize =
95 {
97 let log2 = usize::ilog2(NUMBER_FEATURES) as usize;
98 if log2 < 2 { 2 } else { log2 }
99 };
100
/// Type-state wrapper for the clustering pipeline.
///
/// The type parameter `S` is the current stage of the pipeline
/// ([`EntryPoint`], [`NotInitialized`], [`Initialized`] or [`Finished`]) and
/// determines which methods are available.
#[allow(clippy::module_name_repetitions)]
pub struct ClusteringHelper<S> {
    /// Data belonging to the current stage.
    state: S,
}
108
/// Marker state: nothing has been computed yet; only `ClusteringHelper::new` is available.
pub struct EntryPoint;
110pub struct NotInitialized {
111 embeddings: Array2<Feature>,
113 pub k_max: usize,
114 pub optimizer: KOptimal,
115 pub clustering_method: ClusteringMethod,
116}
117pub struct Initialized {
118 embeddings: Array2<Feature>,
120 pub k: usize,
121 pub clustering_method: ClusteringMethod,
122}
123pub struct Finished {
124 labels: Array1<usize>,
127 pub k: usize,
128}
129
130impl ClusteringHelper<EntryPoint> {
132 #[allow(clippy::missing_inline_in_public_items)]
138 pub fn new(
139 samples: AnalysisArray,
140 k_max: usize,
141 optimizer: KOptimal,
142 clustering_method: ClusteringMethod,
143 ) -> Result<ClusteringHelper<NotInitialized>, ClusteringError> {
144 debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using t-SNE",);
146
147 if samples.0.nrows() <= 15 {
148 return Err(ClusteringError::SmallLibrary);
149 }
150
151 #[allow(clippy::cast_precision_loss)]
152 let mut embeddings = TSneParams::embedding_size(EMBEDDING_SIZE)
153 .perplexity(f64::max(samples.0.nrows() as f64 / 20., 5.))
154 .approx_threshold(0.5)
155 .transform(samples.0)?;
156
157 debug!("Embeddings shape: {:?}", embeddings.shape());
158
159 debug!("Normalizing embeddings");
161 for i in 0..EMBEDDING_SIZE {
162 let min = embeddings.column(i).min();
163 let max = embeddings.column(i).max();
164 let range = max - min;
165 embeddings
166 .column_mut(i)
167 .mapv_inplace(|v| ((v - min) / range).mul_add(2., -1.));
168 }
169
170 Ok(ClusteringHelper {
171 state: NotInitialized {
172 embeddings,
173 k_max,
174 optimizer,
175 clustering_method,
176 },
177 })
178 }
179}
180
181impl ClusteringHelper<NotInitialized> {
183 #[inline]
189 pub fn initialize(self) -> Result<ClusteringHelper<Initialized>, ClusteringError> {
190 let k = self.get_optimal_k()?;
191 Ok(ClusteringHelper {
192 state: Initialized {
193 embeddings: self.state.embeddings,
194 k,
195 clustering_method: self.state.clustering_method,
196 },
197 })
198 }
199
200 fn get_optimal_k(&self) -> Result<usize, ClusteringError> {
201 match self.state.optimizer {
202 KOptimal::GapStatistic { b } => self.get_optimal_k_gap_statistic(b),
203 KOptimal::DaviesBouldin => self.get_optimal_k_davies_bouldin(),
204 }
205 }
206
207 fn get_optimal_k_gap_statistic(&self, b: usize) -> Result<usize, ClusteringError> {
224 let reference_data_sets = generate_reference_data_set(self.state.embeddings.view(), b);
226
227 let results = (1..=self.state.k_max)
228 .map(|k| {
230 debug!("Fitting k-means to embeddings with k={k}");
231 let labels = self.state.clustering_method.fit(k, &self.state.embeddings);
232 (k, labels)
233 })
234 .map(|(k, labels)| {
236 let pairwise_distances =
238 calc_pairwise_distances(self.state.embeddings.view(), k, labels.view());
239 let w_k = calc_within_dispersion(labels.view(), k, pairwise_distances.view());
240
241 debug!(
243 "Calculating within intra-cluster variation for reference data sets with k={k}"
244 );
245 let w_kb = reference_data_sets.par_iter().map(|ref_data| {
246 let ref_labels = self.state.clustering_method.fit(k, ref_data);
248 let ref_pairwise_distances =
250 calc_pairwise_distances(ref_data.view(), k, ref_labels.view());
251 calc_within_dispersion(ref_labels.view(), k, ref_pairwise_distances.view())
252 });
253
254 let w_kb_log_sum = w_kb.clone().map(f64::log2).sum::<f64>();
256 #[allow(clippy::cast_precision_loss)]
258 let l = (1.0 / b as f64) * w_kb_log_sum;
259 #[allow(clippy::cast_precision_loss)]
261 let gap_k = l - w_k.log2();
262 #[allow(clippy::cast_precision_loss)]
264 let standard_deviation = ((1.0 / b as f64)
265 * w_kb.map(|w_kb| (w_kb.log2() - l).powi(2)).sum::<f64>())
266 .sqrt();
267 #[allow(clippy::cast_precision_loss)]
270 let s_k = standard_deviation * (1.0 + 1.0 / b as f64).sqrt();
271
272 (k, gap_k, s_k)
273 });
274
275 let (mut optimal_k, mut gap_k_minus_one) = (None, None);
281
282 for (k, gap_k, s_k) in results {
283 info!("k: {k}, gap_k: {gap_k}, s_k: {s_k}");
284
285 if let Some(gap_k_minus_one) = gap_k_minus_one {
286 if gap_k_minus_one >= gap_k - s_k {
287 info!("Optimal k found: {}", k - 1);
288 optimal_k = Some(k - 1);
289 break;
290 }
291 }
292 gap_k_minus_one = Some(gap_k);
293 }
294
295 optimal_k.ok_or(ClusteringError::OptimalKNotFound(self.state.k_max))
296 }
297
298 fn get_optimal_k_davies_bouldin(&self) -> Result<usize, ClusteringError> {
299 todo!();
300 }
301}
302
303#[must_use]
309#[inline]
310pub fn convert_to_array(data: Vec<Analysis>) -> AnalysisArray {
311 let shape = (data.len(), NUMBER_FEATURES);
313 debug_assert_eq!(shape, (data.len(), data[0].inner().len()));
314
315 AnalysisArray(
316 Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
317 .expect("Failed to convert to array, shape mismatch"),
318 )
319}
320
321fn generate_reference_data_set(samples: ArrayView2<Feature>, b: usize) -> Vec<Array2<f64>> {
345 let mut reference_data_sets = Vec::with_capacity(b);
346 for _ in 0..b {
347 reference_data_sets.push(generate_ref_single(samples));
348 }
349
350 reference_data_sets
351}
352fn generate_ref_single(samples: ArrayView2<Feature>) -> Array2<f64> {
353 let feature_distributions = samples
354 .axis_iter(Axis(1))
355 .map(|feature| Array::random(feature.dim(), Uniform::new(feature.min(), feature.max())))
356 .collect::<Vec<_>>();
357 let feature_dists_views = feature_distributions
358 .iter()
359 .map(ndarray::ArrayBase::view)
360 .collect::<Vec<_>>();
361 ndarray::stack(Axis(0), &feature_dists_views)
362 .unwrap()
363 .t()
364 .to_owned()
365}
366
367fn calc_within_dispersion(
371 labels: ArrayView1<usize>,
372 k: usize,
373 pairwise_distances: ArrayView1<Feature>,
374) -> Feature {
375 debug_assert_eq!(k, labels.iter().max().unwrap() + 1);
376
377 let counts = labels.iter().fold(vec![0u32; k], |mut counts, &label| {
379 counts[label] += 1;
380 counts
381 });
382 counts
384 .iter()
385 .zip(pairwise_distances.iter())
386 .map(|(&count, distance)| (1. / (2.0 * f64::from(count))) * distance)
387 .sum()
388}
389
390fn calc_pairwise_distances(
398 samples: ArrayView2<Feature>,
399 k: usize,
400 labels: ArrayView1<usize>,
401) -> Array1<Feature> {
402 debug_assert_eq!(
403 samples.nrows(),
404 labels.len(),
405 "Samples and labels must have the same length"
406 );
407 debug_assert_eq!(
408 k,
409 labels.iter().max().unwrap() + 1,
410 "Labels must be in the range 0..k"
411 );
412
413 (0..k)
415 .map(|k| {
416 (
417 k,
418 samples
419 .outer_iter()
420 .zip(labels.iter())
421 .filter_map(|(s, &l)| (l == k).then_some(s))
422 .collect::<Vec<_>>(),
423 )
424 })
425 .fold(Array1::zeros(k), |mut distances, (label, cluster)| {
426 distances[label] += cluster
427 .iter()
428 .enumerate()
429 .map(|(i, &a)| {
430 cluster
431 .iter()
432 .skip(i + 1)
433 .map(|&b| L2Dist.distance(a, b))
434 .sum::<Feature>()
435 })
436 .sum::<Feature>()
437 * 2.;
438 distances
439 })
440}
441
442impl ClusteringHelper<Initialized> {
444 #[must_use]
450 #[inline]
451 pub fn cluster(self) -> ClusteringHelper<Finished> {
452 let labels = self
453 .state
454 .clustering_method
455 .fit(self.state.k, &self.state.embeddings);
456
457 ClusteringHelper {
458 state: Finished {
459 labels,
460 k: self.state.k,
461 },
462 }
463 }
464}
465
466impl ClusteringHelper<Finished> {
468 #[must_use]
470 #[inline]
471 pub fn extract_analysis_clusters<T: Clone>(&self, samples: Vec<T>) -> Vec<Vec<T>> {
472 let mut clusters = vec![Vec::new(); self.state.k];
473
474 for (sample, &label) in samples.into_iter().zip(self.state.labels.iter()) {
475 clusters[label].push(sample);
476 }
477
478 clusters
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, s};
    use pretty_assertions::assert_eq;

    /// A reference data set stays within the per-column bounds of the input,
    /// has the same shape, and is not a copy of the input.
    #[test]
    fn test_generate_reference_data_set() {
        let data = arr2(&[[10.0, -10.0], [20.0, -20.0], [30.0, -30.0]]);

        let ref_data = generate_ref_single(data.view());

        // First column lies between 10 and 30.
        let first_column = ref_data.slice(s![.., 0]);
        assert!(first_column.iter().all(|v| (10.0..=30.0).contains(v)));

        // Second column lies between -30 and -10.
        let second_column = ref_data.slice(s![.., 1]);
        assert!(second_column.iter().all(|v| (-30.0..=-10.0).contains(v)));

        assert_eq!(ref_data.shape(), data.shape());

        assert_ne!(ref_data, data);
    }

    /// Pairwise distances are summed per cluster over ordered pairs.
    #[test]
    fn test_pairwise_distances() {
        let labels = arr1(&[0, 0, 1, 1]);

        // Identical points inside each cluster: every distance is zero.
        let samples = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        for cluster in 0..2 {
            assert!(
                f64::EPSILON > (pairwise_distances[cluster] - 0.0).abs(),
                "{} != 0.0",
                pairwise_distances[cluster]
            );
        }

        // Points one unit apart inside each cluster: 1.0 counted in both directions.
        let samples = arr2(&[[1.0, 2.0], [1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]);
        let pairwise_distances = calc_pairwise_distances(samples.view(), 2, labels.view());

        for cluster in 0..2 {
            assert!(
                f64::EPSILON > (pairwise_distances[cluster] - 2.0).abs(),
                "{} != 2.0",
                pairwise_distances[cluster]
            );
        }
    }

    /// `convert_to_array` lays out one analysis per row.
    #[test]
    fn test_convert_to_vec() {
        let data = vec![
            Analysis::new([1.0; NUMBER_FEATURES]),
            Analysis::new([2.0; NUMBER_FEATURES]),
            Analysis::new([3.0; NUMBER_FEATURES]),
        ];

        let array = convert_to_array(data);

        assert_eq!(array.0.shape(), &[3, NUMBER_FEATURES]);
        assert!(
            f64::EPSILON > (array.0[[0, 0]] - 1.0).abs(),
            "{} != 1.0",
            array.0[[0, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[1, 0]] - 2.0).abs(),
            "{} != 2.0",
            array.0[[1, 0]]
        );
        assert!(
            f64::EPSILON > (array.0[[2, 0]] - 3.0).abs(),
            "{} != 3.0",
            array.0[[2, 0]]
        );

        // Every row holds the features of exactly one analysis.
        let mut rows = array.0.axis_iter(Axis(0));
        for expected in [1.0, 2.0, 3.0] {
            assert_eq!(rows.next().unwrap().to_vec(), vec![expected; NUMBER_FEATURES]);
        }

        // Every column holds one feature across the three analyses.
        for column in array.0.axis_iter(Axis(1)) {
            assert_eq!(column.to_vec(), vec![1.0, 2.0, 3.0]);
        }
    }
}
590
591