Skip to main content

fdars_core/alignment/
clustering.rs

1//! Elastic clustering: k-means and hierarchical clustering using
2//! Fisher-Rao elastic distances.
3//!
4//! Standard L2 clustering ignores phase variation. Elastic clustering
5//! uses the Fisher-Rao distance (which factors out reparameterization)
6//! and computes cluster centers as Karcher means.
7
8use super::karcher::karcher_mean;
9use super::pairwise::{elastic_distance, elastic_self_distance_matrix};
10use super::KarcherMeanResult;
11use crate::cv::subset_rows;
12use crate::error::FdarError;
13use crate::matrix::FdMatrix;
14use rand::rngs::StdRng;
15use rand::{Rng, SeedableRng};
16
17// ─── Types ──────────────────────────────────────────────────────────────────
18
/// Tuning parameters for elastic k-means clustering.
///
/// The usual pattern is to override a few fields and take the rest from
/// [`Default`]: `ElasticClusterConfig { k: 3, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct ElasticClusterConfig {
    /// Number of clusters to form.
    pub k: usize,
    /// Roughness penalty used by every elastic alignment (`0.0` disables it).
    pub lambda: f64,
    /// Upper bound on the number of outer k-means iterations.
    pub max_iter: usize,
    /// Convergence tolerance. Not read by the k-means loop in this module;
    /// reserved for future distance-based stopping criteria.
    pub tol: f64,
    /// Iteration cap for each per-cluster Karcher mean.
    pub karcher_max_iter: usize,
    /// Convergence tolerance for each per-cluster Karcher mean.
    pub karcher_tol: f64,
    /// Seed for the random number generator used by k-means++ initialization.
    pub seed: u64,
}

impl Default for ElasticClusterConfig {
    /// Defaults: two clusters, no roughness penalty, 20 outer iterations
    /// (tolerance `1e-4`), 15 Karcher iterations per center (tolerance
    /// `1e-3`), and seed 42.
    fn default() -> Self {
        let k = 2;
        let lambda = 0.0;
        let (max_iter, tol) = (20, 1e-4);
        let (karcher_max_iter, karcher_tol) = (15, 1e-3);
        let seed = 42;
        Self {
            k,
            lambda,
            max_iter,
            tol,
            karcher_max_iter,
            karcher_tol,
            seed,
        }
    }
}
51
/// Chooses the clustering algorithm or linkage rule.
///
/// `elastic_hierarchical` handles the three `Hierarchical*` variants and
/// falls back to single linkage when given `KMeans`.
#[non_exhaustive]
#[derive(Default, Debug, Copy, Clone, Eq, PartialEq)]
pub enum ElasticClusterMethod {
    /// Elastic k-means whose centers are Karcher means (the default).
    #[default]
    KMeans,
    /// Agglomerative clustering; cluster distance = smallest pairwise distance.
    HierarchicalSingle,
    /// Agglomerative clustering; cluster distance = largest pairwise distance.
    HierarchicalComplete,
    /// Agglomerative clustering; cluster distance = size-weighted mean
    /// pairwise distance (UPGMA).
    HierarchicalAverage,
}
66
67/// Result of elastic k-means clustering.
68#[derive(Debug, Clone, PartialEq)]
69#[non_exhaustive]
70pub struct ElasticClusterResult {
71    /// Cluster label for each curve (0-indexed, length n).
72    pub labels: Vec<usize>,
73    /// Karcher mean for each cluster.
74    pub centers: Vec<KarcherMeanResult>,
75    /// Within-cluster sum of elastic distances for each cluster.
76    pub within_distances: Vec<f64>,
77    /// Total within-cluster distance (sum of `within_distances`).
78    pub total_within_distance: f64,
79    /// Number of iterations performed.
80    pub n_iter: usize,
81    /// Whether the algorithm converged (labels stabilized).
82    pub converged: bool,
83}
84
85/// Result of hierarchical elastic clustering (dendrogram).
86#[derive(Debug, Clone, PartialEq)]
87#[non_exhaustive]
88pub struct ElasticDendrogram {
89    /// Merge history: each entry `(i, j, distance)` records merging cluster
90    /// indices i and j at the given distance.
91    pub merges: Vec<(usize, usize, f64)>,
92    /// Full elastic distance matrix used to build the dendrogram.
93    pub distance_matrix: FdMatrix,
94}
95
96// ─── K-Means++ Initialization ───────────────────────────────────────────────
97
98/// Select k initial center indices using k-means++ on a precomputed distance matrix.
99fn kmeans_pp_init(dist_mat: &FdMatrix, k: usize, rng: &mut StdRng) -> Vec<usize> {
100    let n = dist_mat.nrows();
101    let mut centers = Vec::with_capacity(k);
102
103    // Pick the first center uniformly at random.
104    centers.push(rng.gen_range(0..n));
105
106    // min_dist_sq[i] = min distance² from curve i to any chosen center.
107    let mut min_dist_sq: Vec<f64> = (0..n)
108        .map(|i| {
109            let d = dist_mat[(i, centers[0])];
110            d * d
111        })
112        .collect();
113
114    for _ in 1..k {
115        let total: f64 = min_dist_sq.iter().sum();
116        if total <= 0.0 {
117            // All remaining points are at distance 0; pick any unselected.
118            for i in 0..n {
119                if !centers.contains(&i) {
120                    centers.push(i);
121                    break;
122                }
123            }
124        } else {
125            let threshold = rng.gen::<f64>() * total;
126            let mut cum = 0.0;
127            let mut chosen = n - 1;
128            for i in 0..n {
129                cum += min_dist_sq[i];
130                if cum >= threshold {
131                    chosen = i;
132                    break;
133                }
134            }
135            centers.push(chosen);
136        }
137
138        // Update minimum distances with the new center.
139        let new_center = *centers.last().unwrap();
140        for i in 0..n {
141            let d = dist_mat[(i, new_center)];
142            let d2 = d * d;
143            if d2 < min_dist_sq[i] {
144                min_dist_sq[i] = d2;
145            }
146        }
147    }
148
149    centers
150}
151
152// ─── Empty Cluster Handling ─────────────────────────────────────────────────
153
154/// Find the curve farthest from its assigned center (by avg peer distance)
155/// in the largest cluster. Used to reassign when a cluster becomes empty.
156fn reassign_empty_cluster(labels: &[usize], dist_mat: &FdMatrix) -> usize {
157    let n = labels.len();
158
159    // Find the largest cluster.
160    let max_label = labels.iter().copied().max().unwrap_or(0);
161    let mut counts = vec![0usize; max_label + 1];
162    for &l in labels {
163        counts[l] += 1;
164    }
165    let largest_cluster = counts
166        .iter()
167        .enumerate()
168        .max_by_key(|&(_, &cnt)| cnt)
169        .map(|(c, _)| c)
170        .unwrap_or(0);
171
172    // Find the member farthest from its peers in that cluster.
173    let members: Vec<usize> = (0..n).filter(|&i| labels[i] == largest_cluster).collect();
174    let mut max_avg_dist = -1.0_f64;
175    let mut farthest = members[0];
176    for &i in &members {
177        let avg_d: f64 =
178            members.iter().map(|&j| dist_mat[(i, j)]).sum::<f64>() / members.len() as f64;
179        if avg_d > max_avg_dist {
180            max_avg_dist = avg_d;
181            farthest = i;
182        }
183    }
184    farthest
185}
186
187// ─── K-Means ────────────────────────────────────────────────────────────────
188
189/// Elastic k-means clustering using Fisher-Rao distances and Karcher means.
190///
191/// Partitions functional data into `k` clusters in the elastic metric. Cluster
192/// centers are Karcher (Frechet) means, and assignment uses the Fisher-Rao
193/// distance.
194///
195/// # Algorithm
196/// 1. Compute the full elastic distance matrix.
197/// 2. Initialize centers using k-means++.
198/// 3. Iterate: assign curves to nearest center, recompute Karcher means,
199///    recompute distances, and check for convergence (label stability).
200///
201/// # Arguments
202/// * `data`    — Functional data matrix (n x m).
203/// * `argvals` — Evaluation points (length m).
204/// * `config`  — Clustering configuration.
205///
206/// # Errors
207/// Returns [`FdarError::InvalidParameter`] if `k < 1` or `k > n`.
208/// Returns [`FdarError::InvalidDimension`] if `argvals` length does not match `m`.
209#[must_use = "expensive computation whose result should not be discarded"]
210pub fn elastic_kmeans(
211    data: &FdMatrix,
212    argvals: &[f64],
213    config: &ElasticClusterConfig,
214) -> Result<ElasticClusterResult, FdarError> {
215    let (n, m) = data.shape();
216
217    if config.k < 1 {
218        return Err(FdarError::InvalidParameter {
219            parameter: "k",
220            message: "k must be >= 1".to_string(),
221        });
222    }
223    if config.k > n {
224        return Err(FdarError::InvalidParameter {
225            parameter: "k",
226            message: format!("k ({}) must be <= n ({})", config.k, n),
227        });
228    }
229    if argvals.len() != m {
230        return Err(FdarError::InvalidDimension {
231            parameter: "argvals",
232            expected: format!("{m}"),
233            actual: format!("{}", argvals.len()),
234        });
235    }
236
237    let k = config.k;
238
239    // Step 1: Compute full elastic distance matrix.
240    let dist_mat = elastic_self_distance_matrix(data, argvals, config.lambda);
241
242    // Step 2: K-means++ initialization.
243    let mut rng = StdRng::seed_from_u64(config.seed);
244    let center_indices = kmeans_pp_init(&dist_mat, k, &mut rng);
245
246    // Initial assignment: each curve goes to its nearest initial center.
247    let mut labels = vec![0usize; n];
248    for i in 0..n {
249        let mut best_d = f64::INFINITY;
250        for (c, &ci) in center_indices.iter().enumerate() {
251            let d = dist_mat[(i, ci)];
252            if d < best_d {
253                best_d = d;
254                labels[i] = c;
255            }
256        }
257    }
258
259    // Step 3: Iterate.
260    let mut converged = false;
261    let mut n_iter = 0;
262    let mut centers: Vec<KarcherMeanResult> = Vec::with_capacity(k);
263
264    for iter in 0..config.max_iter {
265        n_iter = iter + 1;
266
267        // Compute Karcher mean for each cluster.
268        centers = compute_cluster_centers(data, argvals, &labels, k, &dist_mat, config);
269
270        // Reassign: compute distance from each curve to each center's mean.
271        let new_labels: Vec<usize> = (0..n)
272            .map(|i| {
273                let fi = data.row(i);
274                let mut best_d = f64::INFINITY;
275                let mut best_c = 0;
276                for (c, center) in centers.iter().enumerate() {
277                    let d = elastic_distance(&fi, &center.mean, argvals, config.lambda);
278                    if d < best_d {
279                        best_d = d;
280                        best_c = c;
281                    }
282                }
283                best_c
284            })
285            .collect();
286
287        // Check convergence: labels unchanged.
288        if new_labels == labels {
289            converged = true;
290            labels = new_labels;
291            break;
292        }
293
294        labels = new_labels;
295    }
296
297    // If we exited without converging, recompute final centers.
298    if !converged {
299        centers = compute_cluster_centers(data, argvals, &labels, k, &dist_mat, config);
300    }
301
302    // Compute within-cluster distances.
303    let mut within_distances = vec![0.0; k];
304    for i in 0..n {
305        let fi = data.row(i);
306        let c = labels[i];
307        let d = elastic_distance(&fi, &centers[c].mean, argvals, config.lambda);
308        within_distances[c] += d;
309    }
310    let total_within_distance: f64 = within_distances.iter().sum();
311
312    Ok(ElasticClusterResult {
313        labels,
314        centers,
315        within_distances,
316        total_within_distance,
317        n_iter,
318        converged,
319    })
320}
321
322/// Compute Karcher mean centers for each cluster.
323fn compute_cluster_centers(
324    data: &FdMatrix,
325    argvals: &[f64],
326    labels: &[usize],
327    k: usize,
328    dist_mat: &FdMatrix,
329    config: &ElasticClusterConfig,
330) -> Vec<KarcherMeanResult> {
331    let n = data.nrows();
332    let mut centers = Vec::with_capacity(k);
333    for c in 0..k {
334        let members: Vec<usize> = (0..n).filter(|&i| labels[i] == c).collect();
335        if members.is_empty() {
336            // Empty cluster: steal the farthest point from the largest cluster.
337            let singleton_idx = reassign_empty_cluster(labels, dist_mat);
338            let sub = subset_rows(data, &[singleton_idx]);
339            centers.push(karcher_mean(
340                &sub,
341                argvals,
342                1,
343                config.karcher_tol,
344                config.lambda,
345            ));
346        } else {
347            let sub = subset_rows(data, &members);
348            centers.push(karcher_mean(
349                &sub,
350                argvals,
351                config.karcher_max_iter,
352                config.karcher_tol,
353                config.lambda,
354            ));
355        }
356    }
357    centers
358}
359
360// ─── Hierarchical Clustering ────────────────────────────────────────────────
361
362/// Hierarchical elastic clustering using Fisher-Rao distances.
363///
364/// Builds a dendrogram by agglomerative clustering. Supported linkage methods
365/// are single, complete, and average. Passing [`ElasticClusterMethod::KMeans`]
366/// is treated as single linkage.
367///
368/// # Arguments
369/// * `data`    — Functional data matrix (n x m).
370/// * `argvals` — Evaluation points (length m).
371/// * `method`  — Linkage method.
372/// * `lambda`  — Roughness penalty for elastic alignment.
373///
374/// # Errors
375/// Returns [`FdarError::InvalidDimension`] if `argvals` length does not match `m`
376/// or `n < 2`.
377#[must_use = "expensive computation whose result should not be discarded"]
378pub fn elastic_hierarchical(
379    data: &FdMatrix,
380    argvals: &[f64],
381    method: ElasticClusterMethod,
382    lambda: f64,
383) -> Result<ElasticDendrogram, FdarError> {
384    let (n, m) = data.shape();
385
386    if argvals.len() != m {
387        return Err(FdarError::InvalidDimension {
388            parameter: "argvals",
389            expected: format!("{m}"),
390            actual: format!("{}", argvals.len()),
391        });
392    }
393    if n < 2 {
394        return Err(FdarError::InvalidDimension {
395            parameter: "data",
396            expected: "at least 2 rows".to_string(),
397            actual: format!("{n} rows"),
398        });
399    }
400
401    // Step 1: Compute full distance matrix.
402    let dist_mat = elastic_self_distance_matrix(data, argvals, lambda);
403
404    // Step 2: Initialize — working cluster distance matrix and metadata.
405    let mut active = vec![true; n];
406    let mut cluster_sizes = vec![1usize; n];
407    let mut cluster_dist = FdMatrix::zeros(n, n);
408    for i in 0..n {
409        for j in 0..n {
410            cluster_dist[(i, j)] = dist_mat[(i, j)];
411        }
412    }
413
414    let mut merges: Vec<(usize, usize, f64)> = Vec::with_capacity(n - 1);
415
416    // Step 3: n-1 merge steps.
417    for _ in 0..(n - 1) {
418        // Find the minimum-distance active pair.
419        let mut min_d = f64::INFINITY;
420        let mut min_i = 0;
421        let mut min_j = 1;
422        for i in 0..n {
423            if !active[i] {
424                continue;
425            }
426            for j in (i + 1)..n {
427                if !active[j] {
428                    continue;
429                }
430                if cluster_dist[(i, j)] < min_d {
431                    min_d = cluster_dist[(i, j)];
432                    min_i = i;
433                    min_j = j;
434                }
435            }
436        }
437
438        merges.push((min_i, min_j, min_d));
439
440        // Merge j into i: update distances to all other active clusters.
441        let size_i = cluster_sizes[min_i];
442        let size_j = cluster_sizes[min_j];
443        for k in 0..n {
444            if !active[k] || k == min_i || k == min_j {
445                continue;
446            }
447            let d_ik = cluster_dist[(min_i.min(k), min_i.max(k))];
448            let d_jk = cluster_dist[(min_j.min(k), min_j.max(k))];
449            let new_d = match method {
450                ElasticClusterMethod::HierarchicalSingle | ElasticClusterMethod::KMeans => {
451                    d_ik.min(d_jk)
452                }
453                ElasticClusterMethod::HierarchicalComplete => d_ik.max(d_jk),
454                ElasticClusterMethod::HierarchicalAverage => {
455                    (d_ik * size_i as f64 + d_jk * size_j as f64) / (size_i + size_j) as f64
456                }
457            };
458            let (lo, hi) = (min_i.min(k), min_i.max(k));
459            cluster_dist[(lo, hi)] = new_d;
460            cluster_dist[(hi, lo)] = new_d;
461        }
462
463        cluster_sizes[min_i] = size_i + size_j;
464        active[min_j] = false;
465    }
466
467    Ok(ElasticDendrogram {
468        merges,
469        distance_matrix: dist_mat,
470    })
471}
472
473// ─── Cut Dendrogram ─────────────────────────────────────────────────────────
474
475/// Cut a dendrogram to produce k clusters.
476///
477/// Replays the merge history, stopping after `n - k` merges, and returns
478/// cluster labels for each original observation.
479///
480/// # Arguments
481/// * `dendrogram` — Result of [`elastic_hierarchical`].
482/// * `k`          — Number of clusters desired.
483///
484/// # Errors
485/// Returns [`FdarError::InvalidParameter`] if `k < 1` or `k > n`.
486pub fn cut_dendrogram(dendrogram: &ElasticDendrogram, k: usize) -> Result<Vec<usize>, FdarError> {
487    let n = dendrogram.distance_matrix.nrows();
488
489    if k < 1 {
490        return Err(FdarError::InvalidParameter {
491            parameter: "k",
492            message: "k must be >= 1".to_string(),
493        });
494    }
495    if k > n {
496        return Err(FdarError::InvalidParameter {
497            parameter: "k",
498            message: format!("k ({k}) must be <= n ({n})"),
499        });
500    }
501
502    // Start with n singleton clusters, each labeled by its own index.
503    let mut cluster_of: Vec<usize> = (0..n).collect();
504    let merges_to_apply = n - k;
505
506    for &(ci, cj, _) in dendrogram.merges.iter().take(merges_to_apply) {
507        // Relabel all points in cj's current cluster to ci's current cluster.
508        let target = cluster_of[ci];
509        let source = cluster_of[cj];
510        for label in cluster_of.iter_mut() {
511            if *label == source {
512                *label = target;
513            }
514        }
515    }
516
517    // Compress labels to 0..k-1.
518    let mut unique: Vec<usize> = cluster_of.clone();
519    unique.sort_unstable();
520    unique.dedup();
521
522    let labels: Vec<usize> = cluster_of
523        .iter()
524        .map(|&c| unique.iter().position(|&u| u == c).unwrap())
525        .collect();
526
527    Ok(labels)
528}
529
530// ─── Tests ──────────────────────────────────────────────────────────────────
531
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;
    use std::collections::HashSet;

    /// `n` simulated Fourier curves on a uniform grid of `m` points (fixed seed).
    fn make_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let grid = uniform_grid(m);
        let curves = sim_fundata(n, &grid, 3, EFunType::Fourier, EValType::Exponential, Some(42));
        (curves, grid)
    }

    /// K-means configuration with small iteration caps to keep tests fast.
    fn quick_config(k: usize) -> ElasticClusterConfig {
        ElasticClusterConfig {
            k,
            max_iter: 3,
            karcher_max_iter: 3,
            ..Default::default()
        }
    }

    /// Dendrogram of `n` simulated curves (20 grid points, no penalty).
    fn build_dendrogram(n: usize, method: ElasticClusterMethod) -> ElasticDendrogram {
        let (data, t) = make_data(n, 20);
        elastic_hierarchical(&data, &t, method, 0.0).unwrap()
    }

    #[test]
    fn kmeans_smoke() {
        let (data, t) = make_data(8, 20);
        let res = elastic_kmeans(&data, &t, &quick_config(2)).unwrap();
        assert_eq!(res.labels.len(), 8);
        assert_eq!(res.centers.len(), 2);
        assert_eq!(res.within_distances.len(), 2);
        assert!(res.total_within_distance >= 0.0);
        assert!(res.n_iter >= 1);
    }

    #[test]
    fn kmeans_single_cluster() {
        let (data, t) = make_data(5, 20);
        let res = elastic_kmeans(&data, &t, &quick_config(1)).unwrap();
        assert!(res.labels.iter().all(|&label| label == 0));
        assert_eq!(res.centers.len(), 1);
    }

    #[test]
    fn kmeans_k_too_large() {
        let (data, t) = make_data(3, 20);
        let cfg = ElasticClusterConfig {
            k: 5,
            ..ElasticClusterConfig::default()
        };
        assert!(elastic_kmeans(&data, &t, &cfg).is_err());
    }

    #[test]
    fn kmeans_k_zero() {
        let (data, t) = make_data(5, 20);
        let cfg = ElasticClusterConfig {
            k: 0,
            ..ElasticClusterConfig::default()
        };
        assert!(elastic_kmeans(&data, &t, &cfg).is_err());
    }

    #[test]
    fn hierarchical_single_smoke() {
        let dendro = build_dendrogram(5, ElasticClusterMethod::HierarchicalSingle);
        assert_eq!(dendro.merges.len(), 4);
        // Single-linkage merge heights never decrease.
        for pair in dendro.merges.windows(2) {
            let (prev, next) = (pair[0].2, pair[1].2);
            assert!(
                next >= prev - 1e-10,
                "single linkage should be non-decreasing"
            );
        }
    }

    #[test]
    fn hierarchical_complete_smoke() {
        let dendro = build_dendrogram(5, ElasticClusterMethod::HierarchicalComplete);
        assert_eq!(dendro.merges.len(), 4);
    }

    #[test]
    fn hierarchical_average_smoke() {
        let dendro = build_dendrogram(5, ElasticClusterMethod::HierarchicalAverage);
        assert_eq!(dendro.merges.len(), 4);
    }

    #[test]
    fn hierarchical_too_few_curves() {
        let t = uniform_grid(20);
        let curve: Vec<f64> = t.iter().map(|x| x.sin()).collect();
        let data = FdMatrix::from_slice(&curve, 1, 20).unwrap();
        let outcome = elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalSingle, 0.0);
        assert!(outcome.is_err());
    }

    #[test]
    fn cut_dendrogram_all_singletons() {
        let dendro = build_dendrogram(5, ElasticClusterMethod::HierarchicalSingle);
        let mut labels = cut_dendrogram(&dendro, 5).unwrap();
        // Every observation ends up in its own cluster.
        labels.sort_unstable();
        assert_eq!(labels, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn cut_dendrogram_one_cluster() {
        let dendro = build_dendrogram(5, ElasticClusterMethod::HierarchicalSingle);
        let labels = cut_dendrogram(&dendro, 1).unwrap();
        assert!(labels.iter().all(|&label| label == 0));
    }

    #[test]
    fn cut_dendrogram_k_too_large() {
        let dendro = build_dendrogram(5, ElasticClusterMethod::HierarchicalSingle);
        assert!(cut_dendrogram(&dendro, 10).is_err());
    }

    #[test]
    fn cut_dendrogram_two_clusters() {
        let dendro = build_dendrogram(6, ElasticClusterMethod::HierarchicalSingle);
        let labels = cut_dendrogram(&dendro, 2).unwrap();
        assert_eq!(labels.len(), 6);
        let distinct: HashSet<usize> = labels.into_iter().collect();
        assert_eq!(distinct.len(), 2);
    }

    #[test]
    fn default_config_values() {
        let cfg = ElasticClusterConfig::default();
        let close = |a: f64, b: f64| (a - b).abs() < f64::EPSILON;
        assert_eq!(cfg.k, 2);
        assert!(close(cfg.lambda, 0.0));
        assert_eq!(cfg.max_iter, 20);
        assert!(close(cfg.tol, 1e-4));
        assert_eq!(cfg.karcher_max_iter, 15);
        assert!(close(cfg.karcher_tol, 1e-3));
        assert_eq!(cfg.seed, 42);
    }

    #[test]
    fn default_method() {
        let method = ElasticClusterMethod::default();
        assert_eq!(method, ElasticClusterMethod::KMeans);
    }
}