1use super::karcher::karcher_mean;
9use super::pairwise::{elastic_distance, elastic_self_distance_matrix};
10use super::KarcherMeanResult;
11use crate::cv::subset_rows;
12use crate::error::FdarError;
13use crate::matrix::FdMatrix;
14use rand::rngs::StdRng;
15use rand::{Rng, SeedableRng};
16
/// Parameters controlling [`elastic_kmeans`].
#[derive(Debug, Clone, PartialEq)]
pub struct ElasticClusterConfig {
    /// Number of clusters. `elastic_kmeans` rejects `k == 0` and `k > n`.
    pub k: usize,
    /// Regularisation parameter forwarded unchanged to every elastic distance and
    /// Karcher-mean computation performed during clustering.
    pub lambda: f64,
    /// Upper bound on the number of center-update / reassignment iterations.
    pub max_iter: usize,
    /// Convergence tolerance.
    ///
    /// NOTE(review): `elastic_kmeans` in this module does not read this field; it declares
    /// convergence only when an iteration leaves every label unchanged.
    pub tol: f64,
    /// Iteration cap handed to `karcher_mean` for every non-empty cluster.
    pub karcher_max_iter: usize,
    /// Tolerance handed to `karcher_mean`.
    pub karcher_tol: f64,
    /// Seed for the `StdRng` that drives the k-means++ initialisation.
    pub seed: u64,
}

impl Default for ElasticClusterConfig {
    /// Two clusters, no regularisation (`lambda = 0.0`), at most 20 outer iterations
    /// (`tol = 1e-4`), Karcher means capped at 15 iterations with tolerance `1e-3`, seed 42.
    fn default() -> Self {
        let (max_iter, tol) = (20, 1e-4);
        let (karcher_max_iter, karcher_tol) = (15, 1e-3);
        ElasticClusterConfig {
            k: 2,
            lambda: 0.0,
            max_iter,
            tol,
            karcher_max_iter,
            karcher_tol,
            seed: 42,
        }
    }
}
51
/// Selects which clustering procedure to use.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum ElasticClusterMethod {
    /// Elastic k-means; this is the `Default` variant.
    ///
    /// When passed to `elastic_hierarchical`, this variant is handled exactly like
    /// `HierarchicalSingle`.
    #[default]
    KMeans,
    /// Agglomerative clustering; the distance from a merged cluster to another cluster is the
    /// minimum of the two component distances.
    HierarchicalSingle,
    /// Agglomerative clustering; the distance from a merged cluster to another cluster is the
    /// maximum of the two component distances.
    HierarchicalComplete,
    /// Agglomerative clustering; the distance from a merged cluster to another cluster is the
    /// size-weighted average of the two component distances.
    HierarchicalAverage,
}
66
/// Output of [`elastic_kmeans`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ElasticClusterResult {
    /// Cluster index in `0..k` for each row (curve) of the input data.
    pub labels: Vec<usize>,
    /// One Karcher mean per cluster, indexed by label. A cluster that had no members when the
    /// final centers were computed is seeded from a single curve of the largest cluster
    /// (see `reassign_empty_cluster`).
    pub centers: Vec<KarcherMeanResult>,
    /// Per-cluster sum of elastic distances from each member curve to its center's mean
    /// (`0.0` for a cluster without members).
    pub within_distances: Vec<f64>,
    /// Sum of `within_distances`.
    pub total_within_distance: f64,
    /// Number of center-update / reassignment iterations performed (`0` when `max_iter == 0`).
    pub n_iter: usize,
    /// `true` if some iteration's reassignment left every label unchanged; `false` if
    /// `max_iter` iterations ran without that happening.
    pub converged: bool,
}
84
/// Agglomerative merge history produced by [`elastic_hierarchical`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ElasticDendrogram {
    /// Merge steps in the order they were performed (`n - 1` entries for `n` curves).
    ///
    /// Each entry is `(i, j, d)`: the clusters represented by original curve indices `i` and
    /// `j` (with `i < j`) were joined at linkage distance `d`. The merged cluster keeps `i` as
    /// its representative; `j` is retired and is not expected to appear in later merges.
    pub merges: Vec<(usize, usize, f64)>,
    /// Pairwise elastic distance matrix between the original curves (`n × n`). Its row count
    /// is what [`cut_dendrogram`] uses as `n`.
    pub distance_matrix: FdMatrix,
}
95
96fn kmeans_pp_init(dist_mat: &FdMatrix, k: usize, rng: &mut StdRng) -> Vec<usize> {
100 let n = dist_mat.nrows();
101 let mut centers = Vec::with_capacity(k);
102
103 centers.push(rng.gen_range(0..n));
105
106 let mut min_dist_sq: Vec<f64> = (0..n)
108 .map(|i| {
109 let d = dist_mat[(i, centers[0])];
110 d * d
111 })
112 .collect();
113
114 for _ in 1..k {
115 let total: f64 = min_dist_sq.iter().sum();
116 if total <= 0.0 {
117 for i in 0..n {
119 if !centers.contains(&i) {
120 centers.push(i);
121 break;
122 }
123 }
124 } else {
125 let threshold = rng.gen::<f64>() * total;
126 let mut cum = 0.0;
127 let mut chosen = n - 1;
128 for i in 0..n {
129 cum += min_dist_sq[i];
130 if cum >= threshold {
131 chosen = i;
132 break;
133 }
134 }
135 centers.push(chosen);
136 }
137
138 let new_center = *centers.last().unwrap();
140 for i in 0..n {
141 let d = dist_mat[(i, new_center)];
142 let d2 = d * d;
143 if d2 < min_dist_sq[i] {
144 min_dist_sq[i] = d2;
145 }
146 }
147 }
148
149 centers
150}
151
152fn reassign_empty_cluster(labels: &[usize], dist_mat: &FdMatrix) -> usize {
157 let n = labels.len();
158
159 let max_label = labels.iter().copied().max().unwrap_or(0);
161 let mut counts = vec![0usize; max_label + 1];
162 for &l in labels {
163 counts[l] += 1;
164 }
165 let largest_cluster = counts
166 .iter()
167 .enumerate()
168 .max_by_key(|&(_, &cnt)| cnt)
169 .map(|(c, _)| c)
170 .unwrap_or(0);
171
172 let members: Vec<usize> = (0..n).filter(|&i| labels[i] == largest_cluster).collect();
174 let mut max_avg_dist = -1.0_f64;
175 let mut farthest = members[0];
176 for &i in &members {
177 let avg_d: f64 =
178 members.iter().map(|&j| dist_mat[(i, j)]).sum::<f64>() / members.len() as f64;
179 if avg_d > max_avg_dist {
180 max_avg_dist = avg_d;
181 farthest = i;
182 }
183 }
184 farthest
185}
186
187#[must_use = "expensive computation whose result should not be discarded"]
210pub fn elastic_kmeans(
211 data: &FdMatrix,
212 argvals: &[f64],
213 config: &ElasticClusterConfig,
214) -> Result<ElasticClusterResult, FdarError> {
215 let (n, m) = data.shape();
216
217 if config.k < 1 {
218 return Err(FdarError::InvalidParameter {
219 parameter: "k",
220 message: "k must be >= 1".to_string(),
221 });
222 }
223 if config.k > n {
224 return Err(FdarError::InvalidParameter {
225 parameter: "k",
226 message: format!("k ({}) must be <= n ({})", config.k, n),
227 });
228 }
229 if argvals.len() != m {
230 return Err(FdarError::InvalidDimension {
231 parameter: "argvals",
232 expected: format!("{m}"),
233 actual: format!("{}", argvals.len()),
234 });
235 }
236
237 let k = config.k;
238
239 let dist_mat = elastic_self_distance_matrix(data, argvals, config.lambda);
241
242 let mut rng = StdRng::seed_from_u64(config.seed);
244 let center_indices = kmeans_pp_init(&dist_mat, k, &mut rng);
245
246 let mut labels = vec![0usize; n];
248 for i in 0..n {
249 let mut best_d = f64::INFINITY;
250 for (c, &ci) in center_indices.iter().enumerate() {
251 let d = dist_mat[(i, ci)];
252 if d < best_d {
253 best_d = d;
254 labels[i] = c;
255 }
256 }
257 }
258
259 let mut converged = false;
261 let mut n_iter = 0;
262 let mut centers: Vec<KarcherMeanResult> = Vec::with_capacity(k);
263
264 for iter in 0..config.max_iter {
265 n_iter = iter + 1;
266
267 centers = compute_cluster_centers(data, argvals, &labels, k, &dist_mat, config);
269
270 let new_labels: Vec<usize> = (0..n)
272 .map(|i| {
273 let fi = data.row(i);
274 let mut best_d = f64::INFINITY;
275 let mut best_c = 0;
276 for (c, center) in centers.iter().enumerate() {
277 let d = elastic_distance(&fi, ¢er.mean, argvals, config.lambda);
278 if d < best_d {
279 best_d = d;
280 best_c = c;
281 }
282 }
283 best_c
284 })
285 .collect();
286
287 if new_labels == labels {
289 converged = true;
290 labels = new_labels;
291 break;
292 }
293
294 labels = new_labels;
295 }
296
297 if !converged {
299 centers = compute_cluster_centers(data, argvals, &labels, k, &dist_mat, config);
300 }
301
302 let mut within_distances = vec![0.0; k];
304 for i in 0..n {
305 let fi = data.row(i);
306 let c = labels[i];
307 let d = elastic_distance(&fi, ¢ers[c].mean, argvals, config.lambda);
308 within_distances[c] += d;
309 }
310 let total_within_distance: f64 = within_distances.iter().sum();
311
312 Ok(ElasticClusterResult {
313 labels,
314 centers,
315 within_distances,
316 total_within_distance,
317 n_iter,
318 converged,
319 })
320}
321
322fn compute_cluster_centers(
324 data: &FdMatrix,
325 argvals: &[f64],
326 labels: &[usize],
327 k: usize,
328 dist_mat: &FdMatrix,
329 config: &ElasticClusterConfig,
330) -> Vec<KarcherMeanResult> {
331 let n = data.nrows();
332 let mut centers = Vec::with_capacity(k);
333 for c in 0..k {
334 let members: Vec<usize> = (0..n).filter(|&i| labels[i] == c).collect();
335 if members.is_empty() {
336 let singleton_idx = reassign_empty_cluster(labels, dist_mat);
338 let sub = subset_rows(data, &[singleton_idx]);
339 centers.push(karcher_mean(
340 &sub,
341 argvals,
342 1,
343 config.karcher_tol,
344 config.lambda,
345 ));
346 } else {
347 let sub = subset_rows(data, &members);
348 centers.push(karcher_mean(
349 &sub,
350 argvals,
351 config.karcher_max_iter,
352 config.karcher_tol,
353 config.lambda,
354 ));
355 }
356 }
357 centers
358}
359
360#[must_use = "expensive computation whose result should not be discarded"]
378pub fn elastic_hierarchical(
379 data: &FdMatrix,
380 argvals: &[f64],
381 method: ElasticClusterMethod,
382 lambda: f64,
383) -> Result<ElasticDendrogram, FdarError> {
384 let (n, m) = data.shape();
385
386 if argvals.len() != m {
387 return Err(FdarError::InvalidDimension {
388 parameter: "argvals",
389 expected: format!("{m}"),
390 actual: format!("{}", argvals.len()),
391 });
392 }
393 if n < 2 {
394 return Err(FdarError::InvalidDimension {
395 parameter: "data",
396 expected: "at least 2 rows".to_string(),
397 actual: format!("{n} rows"),
398 });
399 }
400
401 let dist_mat = elastic_self_distance_matrix(data, argvals, lambda);
403
404 let mut active = vec![true; n];
406 let mut cluster_sizes = vec![1usize; n];
407 let mut cluster_dist = FdMatrix::zeros(n, n);
408 for i in 0..n {
409 for j in 0..n {
410 cluster_dist[(i, j)] = dist_mat[(i, j)];
411 }
412 }
413
414 let mut merges: Vec<(usize, usize, f64)> = Vec::with_capacity(n - 1);
415
416 for _ in 0..(n - 1) {
418 let mut min_d = f64::INFINITY;
420 let mut min_i = 0;
421 let mut min_j = 1;
422 for i in 0..n {
423 if !active[i] {
424 continue;
425 }
426 for j in (i + 1)..n {
427 if !active[j] {
428 continue;
429 }
430 if cluster_dist[(i, j)] < min_d {
431 min_d = cluster_dist[(i, j)];
432 min_i = i;
433 min_j = j;
434 }
435 }
436 }
437
438 merges.push((min_i, min_j, min_d));
439
440 let size_i = cluster_sizes[min_i];
442 let size_j = cluster_sizes[min_j];
443 for k in 0..n {
444 if !active[k] || k == min_i || k == min_j {
445 continue;
446 }
447 let d_ik = cluster_dist[(min_i.min(k), min_i.max(k))];
448 let d_jk = cluster_dist[(min_j.min(k), min_j.max(k))];
449 let new_d = match method {
450 ElasticClusterMethod::HierarchicalSingle | ElasticClusterMethod::KMeans => {
451 d_ik.min(d_jk)
452 }
453 ElasticClusterMethod::HierarchicalComplete => d_ik.max(d_jk),
454 ElasticClusterMethod::HierarchicalAverage => {
455 (d_ik * size_i as f64 + d_jk * size_j as f64) / (size_i + size_j) as f64
456 }
457 };
458 let (lo, hi) = (min_i.min(k), min_i.max(k));
459 cluster_dist[(lo, hi)] = new_d;
460 cluster_dist[(hi, lo)] = new_d;
461 }
462
463 cluster_sizes[min_i] = size_i + size_j;
464 active[min_j] = false;
465 }
466
467 Ok(ElasticDendrogram {
468 merges,
469 distance_matrix: dist_mat,
470 })
471}
472
473pub fn cut_dendrogram(dendrogram: &ElasticDendrogram, k: usize) -> Result<Vec<usize>, FdarError> {
487 let n = dendrogram.distance_matrix.nrows();
488
489 if k < 1 {
490 return Err(FdarError::InvalidParameter {
491 parameter: "k",
492 message: "k must be >= 1".to_string(),
493 });
494 }
495 if k > n {
496 return Err(FdarError::InvalidParameter {
497 parameter: "k",
498 message: format!("k ({k}) must be <= n ({n})"),
499 });
500 }
501
502 let mut cluster_of: Vec<usize> = (0..n).collect();
504 let merges_to_apply = n - k;
505
506 for &(ci, cj, _) in dendrogram.merges.iter().take(merges_to_apply) {
507 let target = cluster_of[ci];
509 let source = cluster_of[cj];
510 for label in cluster_of.iter_mut() {
511 if *label == source {
512 *label = target;
513 }
514 }
515 }
516
517 let mut unique: Vec<usize> = cluster_of.clone();
519 unique.sort_unstable();
520 unique.dedup();
521
522 let labels: Vec<usize> = cluster_of
523 .iter()
524 .map(|&c| unique.iter().position(|&u| u == c).unwrap())
525 .collect();
526
527 Ok(labels)
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    /// `n` simulated Fourier curves on a uniform grid of `m` points (fixed seed).
    fn make_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let data = sim_fundata(n, &t, 3, EFunType::Fourier, EValType::Exponential, Some(42));
        (data, t)
    }

    #[test]
    fn kmeans_smoke() {
        let (data, t) = make_data(8, 20);
        let config = ElasticClusterConfig {
            k: 2,
            max_iter: 3,
            karcher_max_iter: 3,
            ..Default::default()
        };
        let result = elastic_kmeans(&data, &t, &config).unwrap();
        assert_eq!(result.labels.len(), 8);
        assert_eq!(result.centers.len(), 2);
        assert_eq!(result.within_distances.len(), 2);
        assert!(result.total_within_distance >= 0.0);
        assert!(result.n_iter >= 1);
    }

    #[test]
    fn kmeans_single_cluster() {
        let (data, t) = make_data(5, 20);
        let config = ElasticClusterConfig {
            k: 1,
            max_iter: 3,
            karcher_max_iter: 3,
            ..Default::default()
        };
        let result = elastic_kmeans(&data, &t, &config).unwrap();
        assert!(result.labels.iter().all(|&l| l == 0));
        assert_eq!(result.centers.len(), 1);
    }

    #[test]
    fn kmeans_labels_within_range() {
        let (data, t) = make_data(8, 20);
        let config = ElasticClusterConfig {
            k: 3,
            max_iter: 3,
            karcher_max_iter: 3,
            ..Default::default()
        };
        let result = elastic_kmeans(&data, &t, &config).unwrap();
        assert!(result.labels.iter().all(|&l| l < 3));
        assert_eq!(result.centers.len(), 3);
        let summed: f64 = result.within_distances.iter().sum();
        assert!((summed - result.total_within_distance).abs() < 1e-12);
    }

    #[test]
    fn kmeans_k_too_large() {
        let (data, t) = make_data(3, 20);
        let config = ElasticClusterConfig {
            k: 5,
            ..Default::default()
        };
        assert!(elastic_kmeans(&data, &t, &config).is_err());
    }

    #[test]
    fn kmeans_k_zero() {
        let (data, t) = make_data(5, 20);
        let config = ElasticClusterConfig {
            k: 0,
            ..Default::default()
        };
        assert!(elastic_kmeans(&data, &t, &config).is_err());
    }

    #[test]
    fn kmeans_argvals_length_mismatch() {
        let (data, _) = make_data(5, 20);
        let wrong_t = uniform_grid(19);
        let config = ElasticClusterConfig::default();
        assert!(elastic_kmeans(&data, &wrong_t, &config).is_err());
    }

    #[test]
    fn hierarchical_single_smoke() {
        let (data, t) = make_data(5, 20);
        let dendro =
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalSingle, 0.0).unwrap();
        assert_eq!(dendro.merges.len(), 4);
        // Single linkage merge heights must be non-decreasing.
        for w in dendro.merges.windows(2) {
            assert!(
                w[1].2 >= w[0].2 - 1e-10,
                "single linkage should be non-decreasing"
            );
        }
    }

    #[test]
    fn hierarchical_complete_smoke() {
        let (data, t) = make_data(5, 20);
        let dendro =
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalComplete, 0.0)
                .unwrap();
        assert_eq!(dendro.merges.len(), 4);
    }

    #[test]
    fn hierarchical_average_smoke() {
        let (data, t) = make_data(5, 20);
        let dendro =
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalAverage, 0.0)
                .unwrap();
        assert_eq!(dendro.merges.len(), 4);
    }

    #[test]
    fn hierarchical_merges_reference_live_clusters() {
        let (data, t) = make_data(6, 20);
        for method in [
            ElasticClusterMethod::HierarchicalSingle,
            ElasticClusterMethod::HierarchicalComplete,
            ElasticClusterMethod::HierarchicalAverage,
        ] {
            let dendro = elastic_hierarchical(&data, &t, method, 0.0).unwrap();
            assert_eq!(dendro.merges.len(), 5);
            let mut active = [true; 6];
            for &(i, j, _) in &dendro.merges {
                assert!(i < j, "merge ({i}, {j}) must list the survivor first");
                assert!(
                    active[i] && active[j],
                    "merge ({i}, {j}) touches a retired cluster"
                );
                active[j] = false;
            }
        }
    }

    #[test]
    fn hierarchical_too_few_curves() {
        let t = uniform_grid(20);
        let curve: Vec<f64> = t.iter().map(|&x| x.sin()).collect();
        let data = FdMatrix::from_slice(&curve, 1, 20).unwrap();
        assert!(
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalSingle, 0.0).is_err()
        );
    }

    #[test]
    fn hierarchical_argvals_length_mismatch() {
        let (data, _) = make_data(5, 20);
        let wrong_t = uniform_grid(19);
        assert!(
            elastic_hierarchical(&data, &wrong_t, ElasticClusterMethod::HierarchicalSingle, 0.0)
                .is_err()
        );
    }

    #[test]
    fn cut_dendrogram_all_singletons() {
        let (data, t) = make_data(5, 20);
        let dendro =
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalSingle, 0.0).unwrap();
        let labels = cut_dendrogram(&dendro, 5).unwrap();
        // Every curve must land in its own cluster.
        let mut sorted = labels.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn cut_dendrogram_one_cluster() {
        let (data, t) = make_data(5, 20);
        let dendro =
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalSingle, 0.0).unwrap();
        let labels = cut_dendrogram(&dendro, 1).unwrap();
        assert!(labels.iter().all(|&l| l == 0));
    }

    #[test]
    fn cut_dendrogram_k_too_large() {
        let (data, t) = make_data(5, 20);
        let dendro =
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalSingle, 0.0).unwrap();
        assert!(cut_dendrogram(&dendro, 10).is_err());
    }

    #[test]
    fn cut_dendrogram_k_zero() {
        let (data, t) = make_data(5, 20);
        let dendro =
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalSingle, 0.0).unwrap();
        assert!(cut_dendrogram(&dendro, 0).is_err());
    }

    #[test]
    fn cut_dendrogram_two_clusters() {
        let (data, t) = make_data(6, 20);
        let dendro =
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalSingle, 0.0).unwrap();
        let labels = cut_dendrogram(&dendro, 2).unwrap();
        assert_eq!(labels.len(), 6);
        let unique: std::collections::HashSet<usize> = labels.iter().copied().collect();
        assert_eq!(unique.len(), 2);
    }

    #[test]
    fn cut_dendrogram_labels_are_contiguous() {
        let (data, t) = make_data(6, 20);
        let dendro =
            elastic_hierarchical(&data, &t, ElasticClusterMethod::HierarchicalAverage, 0.0)
                .unwrap();
        for k in 1..=6 {
            let labels = cut_dendrogram(&dendro, k).unwrap();
            let mut unique = labels.clone();
            unique.sort_unstable();
            unique.dedup();
            assert_eq!(unique, (0..k).collect::<Vec<_>>(), "k = {k}");
        }
    }

    #[test]
    fn default_config_values() {
        let cfg = ElasticClusterConfig::default();
        assert_eq!(cfg.k, 2);
        assert!((cfg.lambda - 0.0).abs() < f64::EPSILON);
        assert_eq!(cfg.max_iter, 20);
        assert!((cfg.tol - 1e-4).abs() < f64::EPSILON);
        assert_eq!(cfg.karcher_max_iter, 15);
        assert!((cfg.karcher_tol - 1e-3).abs() < f64::EPSILON);
        assert_eq!(cfg.seed, 42);
    }

    #[test]
    fn default_method() {
        assert_eq!(
            ElasticClusterMethod::default(),
            ElasticClusterMethod::KMeans
        );
    }
}