Skip to main content

fdars_core/alignment/
clustering.rs

1//! Distance-based clustering: k-means (k-medoids) and hierarchical.
2//!
3//! These algorithms work with **any** precomputed distance matrix — elastic
4//! (Fisher-Rao), DTW, Lp, amplitude-only, phase-only, or user-defined.
5//!
6//! # Examples
7//!
8//! ```
9//! use fdars_core::alignment::{
10//!     elastic_self_distance_matrix, hierarchical_from_distances,
11//!     kmedoids_from_distances, cut_dendrogram, Linkage, KMedoidsConfig,
12//! };
13//! use fdars_core::matrix::FdMatrix;
14//!
15//! // Compute any distance matrix
16//! let t: Vec<f64> = (0..20).map(|i| i as f64 / 19.0).collect();
17//! let data = FdMatrix::zeros(5, 20);
18//! let dist = elastic_self_distance_matrix(&data, &t, 0.0);
19//!
20//! // Hierarchical clustering — works with any distance matrix
21//! let dendro = hierarchical_from_distances(&dist, Linkage::Complete).unwrap();
22//! let labels = cut_dendrogram(&dendro, 2).unwrap();
23//!
24//! // K-medoids — works with any distance matrix
25//! let config = KMedoidsConfig { k: 2, ..Default::default() };
26//! let result = kmedoids_from_distances(&dist, &config).unwrap();
27//! ```
28
29use crate::error::FdarError;
30use crate::matrix::FdMatrix;
31use rand::rngs::StdRng;
32use rand::{Rng, SeedableRng};
33
34// ─── Types ──────────────────────────────────────────────────────────────────
35
/// Settings for [`kmedoids_from_distances`].
///
/// Usually built with struct-update syntax, e.g.
/// `KMedoidsConfig { k: 3, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct KMedoidsConfig {
    /// Number of clusters to form; validated to satisfy `1 <= k <= n`.
    pub k: usize,
    /// Upper bound on the number of medoid-update / reassignment rounds.
    pub max_iter: usize,
    /// Seed for the random number generator that drives k-means++ seeding.
    pub seed: u64,
}

impl Default for KMedoidsConfig {
    /// Two clusters, at most 100 rounds, seed 42.
    fn default() -> Self {
        let (k, max_iter, seed) = (2, 100, 42);
        Self { k, max_iter, seed }
    }
}
56
/// Linkage criterion for hierarchical clustering: how the distance from a newly
/// merged cluster to every other cluster is derived.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum Linkage {
    /// Single linkage: minimum pairwise distance between members of the two
    /// clusters. This is the default.
    #[default]
    Single,
    /// Complete linkage: maximum pairwise distance between members of the two
    /// clusters.
    Complete,
    /// Average linkage (UPGMA, *unweighted* pair-group method with arithmetic
    /// mean): the plain mean of all pairwise distances between members of the
    /// two clusters, maintained through a cluster-size-weighted update.
    Average,
}
69
/// Outcome of [`kmedoids_from_distances`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct KMedoidsResult {
    /// Zero-based cluster index assigned to each of the `n` observations.
    pub labels: Vec<usize>,
    /// Observation index acting as the medoid of each of the `k` clusters.
    pub medoid_indices: Vec<usize>,
    /// Per cluster (length `k`), the sum of distances from its members to its medoid.
    pub within_distances: Vec<f64>,
    /// Sum of `within_distances` over all clusters.
    pub total_within_distance: f64,
    /// Number of medoid-update / reassignment rounds that were executed.
    pub n_iter: usize,
    /// `true` when iteration stopped because the labels no longer changed,
    /// `false` when the round limit was reached first.
    pub converged: bool,
}
87
/// Result of hierarchical clustering (dendrogram).
///
/// Produced by [`hierarchical_from_distances`] and consumed by [`cut_dendrogram`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct Dendrogram {
    /// Merge history in the order the merges happened (`n - 1` entries when
    /// built by [`hierarchical_from_distances`]).
    ///
    /// Each entry `(i, j, distance)` records that the cluster identified by `i`
    /// absorbed the cluster identified by `j` at the given linkage distance,
    /// with `i < j`. A cluster is identified by the smallest original
    /// observation index it contains, so both ids lie in `0..n`; after the
    /// merge the combined cluster keeps id `i` and id `j` is never used again.
    pub merges: Vec<(usize, usize, f64)>,
    /// Number of observations.
    pub n: usize,
}
98
99// ─── K-Means++ Initialization ───────────────────────────────────────────────
100
101/// Select k initial center indices using k-means++ on a precomputed distance matrix.
102fn kmeans_pp_init(dist_mat: &FdMatrix, k: usize, rng: &mut StdRng) -> Vec<usize> {
103    let n = dist_mat.nrows();
104    let mut centers = Vec::with_capacity(k);
105
106    centers.push(rng.gen_range(0..n));
107
108    let mut min_dist_sq: Vec<f64> = (0..n)
109        .map(|i| {
110            let d = dist_mat[(i, centers[0])];
111            d * d
112        })
113        .collect();
114
115    for _ in 1..k {
116        let total: f64 = min_dist_sq.iter().sum();
117        if total <= 0.0 {
118            for i in 0..n {
119                if !centers.contains(&i) {
120                    centers.push(i);
121                    break;
122                }
123            }
124        } else {
125            let threshold = rng.gen::<f64>() * total;
126            let mut cum = 0.0;
127            let mut chosen = n - 1;
128            for i in 0..n {
129                cum += min_dist_sq[i];
130                if cum >= threshold {
131                    chosen = i;
132                    break;
133                }
134            }
135            centers.push(chosen);
136        }
137
138        let new_center = *centers.last().unwrap();
139        for i in 0..n {
140            let d = dist_mat[(i, new_center)];
141            let d2 = d * d;
142            if d2 < min_dist_sq[i] {
143                min_dist_sq[i] = d2;
144            }
145        }
146    }
147
148    centers
149}
150
151// ─── K-Medoids ─────────────────────────────────────────────────────────────
152
153/// K-medoids (PAM-style) clustering from a precomputed distance matrix.
154///
155/// Uses k-means++ initialization, then alternates between assigning each
156/// observation to its nearest medoid and selecting the medoid that minimizes
157/// within-cluster distances.
158///
159/// Works with **any** distance matrix — elastic, DTW, Lp, or user-defined.
160///
161/// # Arguments
162/// * `dist_mat` — Symmetric n x n distance matrix.
163/// * `config`   — Clustering configuration.
164///
165/// # Errors
166/// Returns [`FdarError::InvalidParameter`] if `k < 1` or `k > n`.
167/// Returns [`FdarError::InvalidDimension`] if `dist_mat` is not square.
168#[must_use = "expensive computation whose result should not be discarded"]
169pub fn kmedoids_from_distances(
170    dist_mat: &FdMatrix,
171    config: &KMedoidsConfig,
172) -> Result<KMedoidsResult, FdarError> {
173    let n = dist_mat.nrows();
174    if dist_mat.ncols() != n {
175        return Err(FdarError::InvalidDimension {
176            parameter: "dist_mat",
177            expected: format!("{n} x {n} (square)"),
178            actual: format!("{} x {}", n, dist_mat.ncols()),
179        });
180    }
181    if config.k < 1 {
182        return Err(FdarError::InvalidParameter {
183            parameter: "k",
184            message: "k must be >= 1".to_string(),
185        });
186    }
187    if config.k > n {
188        return Err(FdarError::InvalidParameter {
189            parameter: "k",
190            message: format!("k ({}) must be <= n ({})", config.k, n),
191        });
192    }
193
194    let k = config.k;
195    let mut rng = StdRng::seed_from_u64(config.seed);
196    let mut medoids = kmeans_pp_init(dist_mat, k, &mut rng);
197
198    // Assign each point to nearest medoid.
199    let mut labels = assign_to_medoids(dist_mat, &medoids, n);
200
201    let mut converged = false;
202    let mut n_iter = 0;
203
204    for iter in 0..config.max_iter {
205        n_iter = iter + 1;
206
207        // Update medoids: for each cluster, pick the member minimizing total distance.
208        for c in 0..k {
209            let members: Vec<usize> = (0..n).filter(|&i| labels[i] == c).collect();
210            if members.is_empty() {
211                continue;
212            }
213            let mut best_cost = f64::INFINITY;
214            let mut best_m = medoids[c];
215            for &candidate in &members {
216                let cost: f64 = members.iter().map(|&j| dist_mat[(candidate, j)]).sum();
217                if cost < best_cost {
218                    best_cost = cost;
219                    best_m = candidate;
220                }
221            }
222            medoids[c] = best_m;
223        }
224
225        // Reassign.
226        let new_labels = assign_to_medoids(dist_mat, &medoids, n);
227        if new_labels == labels {
228            converged = true;
229            labels = new_labels;
230            break;
231        }
232        labels = new_labels;
233    }
234
235    // Compute within-cluster distances.
236    let mut within_distances = vec![0.0; k];
237    for i in 0..n {
238        within_distances[labels[i]] += dist_mat[(i, medoids[labels[i]])];
239    }
240    let total_within_distance: f64 = within_distances.iter().sum();
241
242    Ok(KMedoidsResult {
243        labels,
244        medoid_indices: medoids,
245        within_distances,
246        total_within_distance,
247        n_iter,
248        converged,
249    })
250}
251
252fn assign_to_medoids(dist_mat: &FdMatrix, medoids: &[usize], n: usize) -> Vec<usize> {
253    (0..n)
254        .map(|i| {
255            let mut best_d = f64::INFINITY;
256            let mut best_c = 0;
257            for (c, &med) in medoids.iter().enumerate() {
258                let d = dist_mat[(i, med)];
259                if d < best_d {
260                    best_d = d;
261                    best_c = c;
262                }
263            }
264            best_c
265        })
266        .collect()
267}
268
269// ─── Hierarchical Clustering ───────────────────────────────────────────────
270
271/// Hierarchical agglomerative clustering from a precomputed distance matrix.
272///
273/// Builds a [`Dendrogram`] by iteratively merging the closest pair of clusters.
274/// Works with **any** distance matrix — elastic, DTW, Lp, or user-defined.
275///
276/// # Arguments
277/// * `dist_mat` — Symmetric n x n distance matrix.
278/// * `linkage`  — Linkage criterion.
279///
280/// # Errors
281/// Returns [`FdarError::InvalidDimension`] if `dist_mat` is not square or `n < 2`.
282#[must_use = "expensive computation whose result should not be discarded"]
283pub fn hierarchical_from_distances(
284    dist_mat: &FdMatrix,
285    linkage: Linkage,
286) -> Result<Dendrogram, FdarError> {
287    let n = dist_mat.nrows();
288    if dist_mat.ncols() != n {
289        return Err(FdarError::InvalidDimension {
290            parameter: "dist_mat",
291            expected: format!("{n} x {n} (square)"),
292            actual: format!("{} x {}", n, dist_mat.ncols()),
293        });
294    }
295    if n < 2 {
296        return Err(FdarError::InvalidDimension {
297            parameter: "dist_mat",
298            expected: "at least 2 rows".to_string(),
299            actual: format!("{n} rows"),
300        });
301    }
302
303    let mut active = vec![true; n];
304    let mut cluster_sizes = vec![1usize; n];
305    let mut cluster_dist = FdMatrix::zeros(n, n);
306    for i in 0..n {
307        for j in 0..n {
308            cluster_dist[(i, j)] = dist_mat[(i, j)];
309        }
310    }
311
312    let mut merges: Vec<(usize, usize, f64)> = Vec::with_capacity(n - 1);
313
314    for _ in 0..(n - 1) {
315        let mut min_d = f64::INFINITY;
316        let mut min_i = 0;
317        let mut min_j = 1;
318        for i in 0..n {
319            if !active[i] {
320                continue;
321            }
322            for j in (i + 1)..n {
323                if !active[j] {
324                    continue;
325                }
326                if cluster_dist[(i, j)] < min_d {
327                    min_d = cluster_dist[(i, j)];
328                    min_i = i;
329                    min_j = j;
330                }
331            }
332        }
333
334        merges.push((min_i, min_j, min_d));
335
336        let size_i = cluster_sizes[min_i];
337        let size_j = cluster_sizes[min_j];
338        for k in 0..n {
339            if !active[k] || k == min_i || k == min_j {
340                continue;
341            }
342            let d_ik = cluster_dist[(min_i.min(k), min_i.max(k))];
343            let d_jk = cluster_dist[(min_j.min(k), min_j.max(k))];
344            let new_d = match linkage {
345                Linkage::Single => d_ik.min(d_jk),
346                Linkage::Complete => d_ik.max(d_jk),
347                Linkage::Average => {
348                    (d_ik * size_i as f64 + d_jk * size_j as f64) / (size_i + size_j) as f64
349                }
350            };
351            let (lo, hi) = (min_i.min(k), min_i.max(k));
352            cluster_dist[(lo, hi)] = new_d;
353            cluster_dist[(hi, lo)] = new_d;
354        }
355
356        cluster_sizes[min_i] = size_i + size_j;
357        active[min_j] = false;
358    }
359
360    Ok(Dendrogram { merges, n })
361}
362
363// ─── Cut Dendrogram ─────────────────────────────────────────────────────────
364
365/// Cut a dendrogram to produce k clusters.
366///
367/// Replays the merge history, stopping after `n - k` merges, and returns
368/// cluster labels for each original observation.
369///
370/// # Arguments
371/// * `dendrogram` — Result of [`hierarchical_from_distances`].
372/// * `k`          — Number of clusters desired.
373///
374/// # Errors
375/// Returns [`FdarError::InvalidParameter`] if `k < 1` or `k > n`.
376pub fn cut_dendrogram(dendrogram: &Dendrogram, k: usize) -> Result<Vec<usize>, FdarError> {
377    let n = dendrogram.n;
378
379    if k < 1 {
380        return Err(FdarError::InvalidParameter {
381            parameter: "k",
382            message: "k must be >= 1".to_string(),
383        });
384    }
385    if k > n {
386        return Err(FdarError::InvalidParameter {
387            parameter: "k",
388            message: format!("k ({k}) must be <= n ({n})"),
389        });
390    }
391
392    let mut cluster_of: Vec<usize> = (0..n).collect();
393    let merges_to_apply = n - k;
394
395    for &(ci, cj, _) in dendrogram.merges.iter().take(merges_to_apply) {
396        let target = cluster_of[ci];
397        let source = cluster_of[cj];
398        for label in cluster_of.iter_mut() {
399            if *label == source {
400                *label = target;
401            }
402        }
403    }
404
405    // Compress labels to 0..k-1.
406    let mut unique: Vec<usize> = cluster_of.clone();
407    unique.sort_unstable();
408    unique.dedup();
409    let labels = cluster_of
410        .iter()
411        .map(|&l| unique.iter().position(|&u| u == l).unwrap())
412        .collect();
413
414    Ok(labels)
415}
416
417// ─── Tests ──────────────────────────────────────────────────────────────────
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alignment::elastic_self_distance_matrix;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    fn make_dist_mat(n: usize, m: usize) -> FdMatrix {
        let t = uniform_grid(m);
        let data = sim_fundata(n, &t, 3, EFunType::Fourier, EValType::Exponential, Some(42));
        elastic_self_distance_matrix(&data, &t, 0.0)
    }

    /// Exact distance matrix for points on a line: `d(i, j) = |x_i - x_j|`.
    fn line_dist_mat(xs: &[f64]) -> FdMatrix {
        let n = xs.len();
        let mut dist = FdMatrix::zeros(n, n);
        for (i, &xi) in xs.iter().enumerate() {
            for (j, &xj) in xs.iter().enumerate() {
                dist[(i, j)] = (xi - xj).abs();
            }
        }
        dist
    }

    #[test]
    fn kmedoids_smoke() {
        let dist = make_dist_mat(8, 20);
        let config = KMedoidsConfig {
            k: 2,
            max_iter: 10,
            ..Default::default()
        };
        let result = kmedoids_from_distances(&dist, &config).unwrap();
        assert_eq!(result.labels.len(), 8);
        assert_eq!(result.medoid_indices.len(), 2);
        assert_eq!(result.within_distances.len(), 2);
        assert!(result.total_within_distance >= 0.0);
        assert!(result.n_iter >= 1);
    }

    #[test]
    fn kmedoids_single_cluster() {
        let dist = make_dist_mat(5, 20);
        let config = KMedoidsConfig {
            k: 1,
            max_iter: 10,
            ..Default::default()
        };
        let result = kmedoids_from_distances(&dist, &config).unwrap();
        assert!(result.labels.iter().all(|&l| l == 0));
        assert_eq!(result.medoid_indices.len(), 1);
    }

    #[test]
    fn kmedoids_k_too_large() {
        let dist = make_dist_mat(3, 20);
        let config = KMedoidsConfig {
            k: 5,
            ..Default::default()
        };
        assert!(kmedoids_from_distances(&dist, &config).is_err());
    }

    #[test]
    fn kmedoids_k_zero() {
        let dist = make_dist_mat(5, 20);
        let config = KMedoidsConfig {
            k: 0,
            ..Default::default()
        };
        assert!(kmedoids_from_distances(&dist, &config).is_err());
    }

    #[test]
    fn kmedoids_separates_distant_groups() {
        // Any two distinct seeds converge to medoids {1, 4} on this layout.
        let dist = line_dist_mat(&[0.0, 1.0, 2.0, 100.0, 101.0, 102.0]);
        let config = KMedoidsConfig {
            k: 2,
            ..Default::default()
        };
        let result = kmedoids_from_distances(&dist, &config).unwrap();
        assert!(result.converged);
        let l = &result.labels;
        assert_eq!(l[0], l[1]);
        assert_eq!(l[1], l[2]);
        assert_eq!(l[3], l[4]);
        assert_eq!(l[4], l[5]);
        assert_ne!(l[0], l[3]);
        let mut medoids = result.medoid_indices.clone();
        medoids.sort_unstable();
        assert_eq!(medoids, vec![1, 4]);
        assert!((result.total_within_distance - 4.0).abs() < 1e-12);
    }

    #[test]
    fn kmedoids_zero_iterations() {
        let dist = line_dist_mat(&[0.0, 1.0, 10.0]);
        let config = KMedoidsConfig {
            k: 2,
            max_iter: 0,
            ..Default::default()
        };
        let result = kmedoids_from_distances(&dist, &config).unwrap();
        assert_eq!(result.n_iter, 0);
        assert!(!result.converged);
        assert_eq!(result.labels.len(), 3);
        let sum: f64 = result.within_distances.iter().sum();
        assert!((sum - result.total_within_distance).abs() < 1e-12);
    }

    #[test]
    fn hierarchical_single_smoke() {
        let dist = make_dist_mat(5, 20);
        let dendro = hierarchical_from_distances(&dist, Linkage::Single).unwrap();
        assert_eq!(dendro.merges.len(), 4);
        for w in dendro.merges.windows(2) {
            assert!(
                w[1].2 >= w[0].2 - 1e-10,
                "single linkage should be non-decreasing"
            );
        }
    }

    #[test]
    fn hierarchical_complete_smoke() {
        let dist = make_dist_mat(5, 20);
        let dendro = hierarchical_from_distances(&dist, Linkage::Complete).unwrap();
        assert_eq!(dendro.merges.len(), 4);
    }

    #[test]
    fn hierarchical_average_smoke() {
        let dist = make_dist_mat(5, 20);
        let dendro = hierarchical_from_distances(&dist, Linkage::Average).unwrap();
        assert_eq!(dendro.merges.len(), 4);
    }

    #[test]
    fn hierarchical_too_few() {
        let dist = FdMatrix::zeros(1, 1);
        assert!(hierarchical_from_distances(&dist, Linkage::Single).is_err());
    }

    #[test]
    fn hierarchical_and_cut_recover_distant_groups() {
        let dist = line_dist_mat(&[0.0, 1.0, 2.0, 100.0, 101.0, 102.0]);
        for linkage in [Linkage::Single, Linkage::Complete, Linkage::Average] {
            let dendro = hierarchical_from_distances(&dist, linkage).unwrap();
            assert_eq!(dendro.n, 6);
            assert_eq!(dendro.merges.len(), 5);
            for &(i, j, _) in &dendro.merges {
                assert!(i < j, "merge ids must satisfy i < j");
            }
            let labels = cut_dendrogram(&dendro, 2).unwrap();
            assert_eq!(labels, vec![0, 0, 0, 1, 1, 1]);
        }
    }

    #[test]
    fn cut_dendrogram_all_singletons() {
        let dist = make_dist_mat(5, 20);
        let dendro = hierarchical_from_distances(&dist, Linkage::Single).unwrap();
        let labels = cut_dendrogram(&dendro, 5).unwrap();
        let mut sorted = labels.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn cut_dendrogram_one_cluster() {
        let dist = make_dist_mat(5, 20);
        let dendro = hierarchical_from_distances(&dist, Linkage::Single).unwrap();
        let labels = cut_dendrogram(&dendro, 1).unwrap();
        assert!(labels.iter().all(|&l| l == 0));
    }

    #[test]
    fn cut_dendrogram_k_too_large() {
        let dist = make_dist_mat(5, 20);
        let dendro = hierarchical_from_distances(&dist, Linkage::Single).unwrap();
        assert!(cut_dendrogram(&dendro, 10).is_err());
    }

    #[test]
    fn cut_dendrogram_k_zero() {
        let dist = make_dist_mat(5, 20);
        let dendro = hierarchical_from_distances(&dist, Linkage::Single).unwrap();
        assert!(cut_dendrogram(&dendro, 0).is_err());
    }

    #[test]
    fn cut_dendrogram_two_clusters() {
        let dist = make_dist_mat(6, 20);
        let dendro = hierarchical_from_distances(&dist, Linkage::Single).unwrap();
        let labels = cut_dendrogram(&dendro, 2).unwrap();
        assert_eq!(labels.len(), 6);
        let unique: std::collections::HashSet<usize> = labels.iter().copied().collect();
        assert_eq!(unique.len(), 2);
    }

    #[test]
    fn default_config_values() {
        let cfg = KMedoidsConfig::default();
        assert_eq!(cfg.k, 2);
        assert_eq!(cfg.max_iter, 100);
        assert_eq!(cfg.seed, 42);
    }

    #[test]
    fn default_linkage() {
        assert_eq!(Linkage::default(), Linkage::Single);
    }

    #[test]
    fn non_square_dist_mat_error() {
        let dist = FdMatrix::zeros(3, 4);
        assert!(hierarchical_from_distances(&dist, Linkage::Single).is_err());
        let config = KMedoidsConfig::default();
        assert!(kmedoids_from_distances(&dist, &config).is_err());
    }
}