1use linfa::prelude::*;
13use linfa_clustering::{GaussianMixtureModel, KMeans};
14use linfa_nn::distance::{Distance, L2Dist};
15use linfa_tsne::TSneParams;
16use log::{debug, info};
17use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis};
18use ndarray_rand::RandomExt;
19use rand::distributions::Uniform;
20use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
21use statrs::statistics::Statistics;
22
23use crate::{errors::ClusteringError, Analysis, Feature, NUMBER_FEATURES};
24
/// Row-major feature matrix: one row per analyzed sample, `NUMBER_FEATURES` columns.
pub struct AnalysisArray(pub(crate) Array2<Feature>);
26
27impl From<Vec<Analysis>> for AnalysisArray {
28 fn from(data: Vec<Analysis>) -> Self {
29 let shape = (data.len(), NUMBER_FEATURES);
30 debug_assert_eq!(shape, (data.len(), data[0].inner().len()));
31
32 Self(
33 Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
34 .expect("Failed to convert to array, shape mismatch"),
35 )
36 }
37}
38
39impl From<Vec<[Feature; NUMBER_FEATURES]>> for AnalysisArray {
40 fn from(data: Vec<[Feature; NUMBER_FEATURES]>) -> Self {
41 let shape = (data.len(), NUMBER_FEATURES);
42 debug_assert_eq!(shape, (data.len(), data[0].len()));
43
44 Self(
45 Array2::from_shape_vec(shape, data.into_iter().flatten().collect())
46 .expect("Failed to convert to array, shape mismatch"),
47 )
48 }
49}
50
/// Clustering algorithm used to group the t-SNE embeddings.
#[derive(Clone, Copy, Debug)]
#[allow(clippy::module_name_repetitions)]
pub enum ClusteringMethod {
    /// K-means clustering (`linfa_clustering::KMeans`).
    KMeans,
    /// Gaussian mixture model, initialized via k-means with 10 runs.
    GaussianMixtureModel,
}
57
58impl ClusteringMethod {
59 #[must_use]
61 fn fit(self, k: usize, samples: &Array2<Feature>) -> Array1<usize> {
62 match self {
63 Self::KMeans => {
64 let model = KMeans::params(k)
65 .fit(&Dataset::from(samples.clone()))
67 .unwrap();
68 model.predict(samples)
69 }
70 Self::GaussianMixtureModel => {
71 let model = GaussianMixtureModel::params(k)
72 .init_method(linfa_clustering::GmmInitMethod::KMeans)
73 .n_runs(10)
74 .fit(&Dataset::from(samples.clone()))
75 .unwrap();
76 model.predict(samples)
77 }
78 }
79 }
80}
81
/// Strategy for choosing the optimal number of clusters.
#[derive(Clone, Copy, Debug)]
pub enum KOptimal {
    /// Gap statistic: compares within-cluster dispersion against uniformly
    /// sampled reference data sets.
    GapStatistic {
        /// Number of reference data sets to generate.
        b: usize,
    },
    /// Davies–Bouldin index (not implemented yet).
    DaviesBouldin,
}
90
91const EMBEDDING_SIZE: usize =
93 {
95 let log2 = usize::ilog2(NUMBER_FEATURES) as usize;
96 if log2 < 2 {
97 2
98 } else {
99 log2
100 }
101 };
102
/// Typestate wrapper driving the clustering pipeline; `S` is the current
/// stage (`EntryPoint` → `NotInitialized` → `Initialized` → `Finished`).
///
/// The redundant `where S: Sized` bound was dropped: `Sized` is implicit for
/// type parameters.
#[allow(clippy::module_name_repetitions)]
pub struct ClusteringHelper<S> {
    state: S,
}
110
/// Initial typestate: nothing computed yet.
pub struct EntryPoint;
/// Typestate after embedding generation, before the optimal k is chosen.
pub struct NotInitialized {
    // t-SNE embeddings, one row per sample, each column normalized to [-1, 1].
    embeddings: Array2<Feature>,
    /// Upper bound on the number of clusters to try.
    pub k_max: usize,
    /// Strategy used to pick the optimal k.
    pub optimizer: KOptimal,
    /// Algorithm used for the actual clustering.
    pub clustering_method: ClusteringMethod,
}
/// Typestate after the optimal `k` has been determined.
pub struct Initialized {
    // Same embeddings carried over from `NotInitialized`.
    embeddings: Array2<Feature>,
    /// Chosen number of clusters.
    pub k: usize,
    /// Algorithm used for the actual clustering.
    pub clustering_method: ClusteringMethod,
}
/// Terminal typestate: clustering has run and produced one label per sample.
pub struct Finished {
    // Cluster label (0..k) for each sample, in input order.
    labels: Array1<usize>,
    /// Number of clusters actually used.
    pub k: usize,
}
131
132impl ClusteringHelper<EntryPoint> {
134 pub fn new(
140 samples: AnalysisArray,
141 k_max: usize,
142 optimizer: KOptimal,
143 clustering_method: ClusteringMethod,
144 ) -> Result<ClusteringHelper<NotInitialized>, ClusteringError> {
145 debug!("Generating embeddings (size: {EMBEDDING_SIZE}) using t-SNE",);
147
148 if samples.0.nrows() <= 15 {
149 return Err(ClusteringError::SmallLibrary);
150 }
151
152 #[allow(clippy::cast_precision_loss)]
153 let mut embeddings = TSneParams::embedding_size(EMBEDDING_SIZE)
154 .perplexity(f64::max(samples.0.nrows() as f64 / 20., 5.))
155 .approx_threshold(0.5)
156 .transform(samples.0)?;
157
158 debug!("Embeddings shape: {:?}", embeddings.shape());
159
160 debug!("Normalizing embeddings");
162 for i in 0..EMBEDDING_SIZE {
163 let min = embeddings.column(i).min();
164 let max = embeddings.column(i).max();
165 let range = max - min;
166 embeddings
167 .column_mut(i)
168 .mapv_inplace(|v| ((v - min) / range).mul_add(2., -1.));
169 }
170
171 Ok(ClusteringHelper {
172 state: NotInitialized {
173 embeddings,
174 k_max,
175 optimizer,
176 clustering_method,
177 },
178 })
179 }
180}
181
182impl ClusteringHelper<NotInitialized> {
184 pub fn initialize(self) -> Result<ClusteringHelper<Initialized>, ClusteringError> {
190 let k = self.get_optimal_k()?;
191 Ok(ClusteringHelper {
192 state: Initialized {
193 embeddings: self.state.embeddings,
194 k,
195 clustering_method: self.state.clustering_method,
196 },
197 })
198 }
199
200 fn get_optimal_k(&self) -> Result<usize, ClusteringError> {
201 match self.state.optimizer {
202 KOptimal::GapStatistic { b } => self.get_optimal_k_gap_statistic(b),
203 KOptimal::DaviesBouldin => self.get_optimal_k_davies_bouldin(),
204 }
205 }
206
207 fn get_optimal_k_gap_statistic(&self, b: usize) -> Result<usize, ClusteringError> {
224 let reference_data_sets = generate_reference_data_set(self.state.embeddings.view(), b);
226
227 let results = (1..=self.state.k_max)
228 .map(|k| {
230 debug!("Fitting k-means to embeddings with k={k}");
231 let labels = self.state.clustering_method.fit(k, &self.state.embeddings);
232 (k, labels)
233 })
234 .map(|(k, labels)| {
236 let pairwise_distances =
238 calc_pairwise_distances(self.state.embeddings.view(), k, labels.view());
239 let w_k = calc_within_dispersion(labels.view(), k, pairwise_distances.view());
240
241 debug!(
243 "Calculating within intra-cluster variation for reference data sets with k={k}"
244 );
245 let w_kb = reference_data_sets.par_iter().map(|ref_data| {
246 let ref_labels = self.state.clustering_method.fit(k, ref_data);
248 let ref_pairwise_distances =
250 calc_pairwise_distances(ref_data.view(), k, ref_labels.view());
251 calc_within_dispersion(ref_labels.view(), k, ref_pairwise_distances.view())
252 });
253
254 #[allow(clippy::cast_precision_loss)]
256 let gap_k = (1.0 / b as f64)
257 .mul_add(w_kb.clone().map(f64::log2).sum::<f64>().log2(), -w_k.log2());
258
259 #[allow(clippy::cast_precision_loss)]
260 let l = (1.0 / b as f64) * w_kb.clone().map(f64::log2).sum::<f64>();
261 #[allow(clippy::cast_precision_loss)]
262 let standard_deviation = ((1.0 / b as f64)
263 * w_kb.map(|w_kb| (w_kb.log2() - l).powi(2)).sum::<f64>())
264 .sqrt();
265 #[allow(clippy::cast_precision_loss)]
266 let s_k = standard_deviation * (1.0 + 1.0 / b as f64).sqrt();
267
268 (k, gap_k, s_k)
269 });
270
271 let (mut optimal_k, mut gap_k_minus_one) = (None, None);
277
278 for (k, gap_k, s_k) in results {
279 info!("k: {k}, gap_k: {gap_k}, s_k: {s_k}");
280
281 if let Some(gap_k_minus_one) = gap_k_minus_one {
282 if gap_k_minus_one >= gap_k - s_k {
283 info!("Optimal k found: {}", k - 1);
284 optimal_k = Some(k - 1);
285 break;
286 }
287 }
288 gap_k_minus_one = Some(gap_k);
289 }
290
291 optimal_k.ok_or(ClusteringError::OptimalKNotFound(self.state.k_max))
292 }
293
294 fn get_optimal_k_davies_bouldin(&self) -> Result<usize, ClusteringError> {
295 todo!();
296 }
297}
298
299#[must_use]
305pub fn convert_to_array(data: Vec<Analysis>) -> AnalysisArray {
306 let shape = (data.len(), NUMBER_FEATURES);
308 debug_assert_eq!(shape, (data.len(), data[0].inner().len()));
309
310 AnalysisArray(
311 Array2::from_shape_vec(shape, data.into_iter().flat_map(|a| *a.inner()).collect())
312 .expect("Failed to convert to array, shape mismatch"),
313 )
314}
315
316fn generate_reference_data_set(samples: ArrayView2<Feature>, b: usize) -> Vec<Array2<f64>> {
340 let mut reference_data_sets = Vec::with_capacity(b);
341 for _ in 0..b {
342 reference_data_sets.push(generate_ref_single(samples));
343 }
344
345 reference_data_sets
346}
347fn generate_ref_single(samples: ArrayView2<Feature>) -> Array2<f64> {
348 let feature_distributions = samples
349 .axis_iter(Axis(1))
350 .map(|feature| Array::random(feature.dim(), Uniform::new(feature.min(), feature.max())))
351 .collect::<Vec<_>>();
352 let feature_dists_views = feature_distributions
353 .iter()
354 .map(ndarray::ArrayBase::view)
355 .collect::<Vec<_>>();
356 ndarray::stack(Axis(0), &feature_dists_views)
357 .unwrap()
358 .t()
359 .to_owned()
360}
361
362fn calc_within_dispersion(
366 labels: ArrayView1<usize>,
367 k: usize,
368 pairwise_distances: ArrayView1<Feature>,
369) -> Feature {
370 debug_assert_eq!(k, labels.iter().max().unwrap() + 1);
371
372 let counts = labels.iter().fold(vec![0; k], |mut counts, &label| {
374 counts[label] += 1;
375 counts
376 });
377 counts
379 .iter()
380 .zip(pairwise_distances.iter())
381 .map(|(&count, distance)| distance / (2.0 * f64::from(count)))
382 .sum()
383}
384
385fn calc_pairwise_distances(
393 samples: ArrayView2<Feature>,
394 k: usize,
395 labels: ArrayView1<usize>,
396) -> Array1<Feature> {
397 debug_assert_eq!(
398 samples.nrows(),
399 labels.len(),
400 "Samples and labels must have the same length"
401 );
402 debug_assert_eq!(
403 k,
404 labels.iter().max().unwrap() + 1,
405 "Labels must be in the range 0..k"
406 );
407
408 (0..k)
410 .map(|k| {
411 (
412 k,
413 samples
414 .outer_iter()
415 .zip(labels.iter())
416 .filter_map(|(s, &l)| (l == k).then_some(s))
417 .collect::<Vec<_>>(),
418 )
419 })
420 .fold(Array1::zeros(k), |mut distances, (label, cluster)| {
421 distances[label] += cluster
422 .iter()
423 .enumerate()
424 .map(|(i, &a)| {
425 cluster
426 .iter()
427 .skip(i + 1)
428 .map(|&b| L2Dist.distance(a, b))
429 .sum::<Feature>()
430 })
431 .sum::<Feature>();
432 distances
433 })
434}
435
436impl ClusteringHelper<Initialized> {
438 #[must_use]
444 pub fn cluster(self) -> ClusteringHelper<Finished> {
445 let labels = self
446 .state
447 .clustering_method
448 .fit(self.state.k, &self.state.embeddings);
449
450 ClusteringHelper {
451 state: Finished {
452 labels,
453 k: self.state.k,
454 },
455 }
456 }
457}
458
459impl ClusteringHelper<Finished> {
461 #[must_use]
463 pub fn extract_analysis_clusters<T: Clone>(&self, samples: Vec<T>) -> Vec<Vec<T>> {
464 let mut clusters = vec![Vec::new(); self.state.k];
465
466 for (sample, &label) in samples.into_iter().zip(self.state.labels.iter()) {
467 clusters[label].push(sample);
468 }
469
470 clusters
471 }
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, s};
    use pretty_assertions::assert_eq;

    #[test]
    fn test_generate_reference_data_set() {
        let data = arr2(&[[10.0, -10.0], [20.0, -20.0], [30.0, -30.0]]);

        let ref_data = generate_ref_single(data.view());

        // Every generated value must stay within its column's observed bounds…
        for v in ref_data.slice(s![.., 0]) {
            assert!((10.0..=30.0).contains(v));
        }
        for v in ref_data.slice(s![.., 1]) {
            assert!((-30.0..=-10.0).contains(v));
        }

        // …the shape must match, and the content (almost surely) differs.
        assert_eq!(ref_data.shape(), data.shape());
        assert_ne!(ref_data, data);
    }

    #[test]
    fn test_pairwise_distances() {
        let labels = arr1(&[0, 0, 1, 1]);

        // Two perfectly tight clusters: zero intra-cluster distance.
        let tight = arr2(&[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]);
        let distances = calc_pairwise_distances(tight.view(), 2, labels.view());
        assert_eq!(distances[0], 0.0);
        assert_eq!(distances[1], 0.0);

        // Each cluster's two points are exactly one unit apart.
        let spread = arr2(&[[1.0, 2.0], [1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]);
        let distances = calc_pairwise_distances(spread.view(), 2, labels.view());
        assert_eq!(distances[0], 1.0);
        assert_eq!(distances[1], 1.0);
    }

    #[test]
    fn test_convert_to_vec() {
        let analyses = vec![
            Analysis::new([1.0; NUMBER_FEATURES]),
            Analysis::new([2.0; NUMBER_FEATURES]),
            Analysis::new([3.0; NUMBER_FEATURES]),
        ];

        let array = convert_to_array(analyses);

        assert_eq!(array.0.shape(), &[3, NUMBER_FEATURES]);

        // Each row holds the constant value its analysis was filled with…
        for (row, expected) in array.0.axis_iter(Axis(0)).zip([1.0, 2.0, 3.0]) {
            assert_eq!(row.to_vec(), vec![expected; NUMBER_FEATURES]);
        }
        // …so every column reads 1, 2, 3 top to bottom.
        for column in array.0.axis_iter(Axis(1)) {
            assert_eq!(column.to_vec(), vec![1.0, 2.0, 3.0]);
        }
    }
}
550
551