1use crate::error::FdarError;
30use crate::matrix::FdMatrix;
31use rand::rngs::StdRng;
32use rand::{Rng, SeedableRng};
33
/// Settings for [`kmedoids_from_distances`].
#[derive(Debug, Clone, PartialEq)]
pub struct KMedoidsConfig {
    /// Number of clusters (medoids) to form. `kmedoids_from_distances` rejects
    /// values below 1 or above the number of rows of the distance matrix.
    pub k: usize,
    /// Upper bound on the number of update/reassign rounds.
    pub max_iter: usize,
    /// Seed for the random number generator that picks the initial medoids.
    pub seed: u64,
}
46
47impl Default for KMedoidsConfig {
48 fn default() -> Self {
49 Self {
50 k: 2,
51 max_iter: 100,
52 seed: 42,
53 }
54 }
55}
56
/// Rule used by [`hierarchical_from_distances`] to compute the distance between a
/// freshly merged cluster and every other cluster.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum Linkage {
    /// Smaller of the two merged clusters' distances to the other cluster (the default).
    #[default]
    Single,
    /// Larger of the two merged clusters' distances to the other cluster.
    Complete,
    /// Mean of the two distances, weighted by the sizes of the merged clusters.
    Average,
}
69
/// Outcome of [`kmedoids_from_distances`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct KMedoidsResult {
    /// Cluster index (position into `medoid_indices`) assigned to each observation.
    pub labels: Vec<usize>,
    /// Row index in the distance matrix of each cluster's medoid.
    pub medoid_indices: Vec<usize>,
    /// For each cluster, the sum of distances from its members to its medoid.
    pub within_distances: Vec<f64>,
    /// Sum of all entries of `within_distances`.
    pub total_within_distance: f64,
    /// Number of update/reassign rounds that were executed.
    pub n_iter: usize,
    /// `true` when a round left every label unchanged before `max_iter` was exhausted.
    pub converged: bool,
}
87
/// Merge history produced by [`hierarchical_from_distances`] and consumed by
/// [`cut_dendrogram`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct Dendrogram {
    /// Merges in the order they were performed, as `(i, j, distance)`.
    ///
    /// `i` and `j` are the observation indices representing the two clusters that
    /// were joined; as built by `hierarchical_from_distances`, `i < j` and the
    /// merged cluster keeps `i` as its representative.
    pub merges: Vec<(usize, usize, f64)>,
    /// Number of observations (rows of the distance matrix).
    pub n: usize,
}
98
99fn kmeans_pp_init(dist_mat: &FdMatrix, k: usize, rng: &mut StdRng) -> Vec<usize> {
103 let n = dist_mat.nrows();
104 let mut centers = Vec::with_capacity(k);
105
106 centers.push(rng.gen_range(0..n));
107
108 let mut min_dist_sq: Vec<f64> = (0..n)
109 .map(|i| {
110 let d = dist_mat[(i, centers[0])];
111 d * d
112 })
113 .collect();
114
115 for _ in 1..k {
116 let total: f64 = min_dist_sq.iter().sum();
117 if total <= 0.0 {
118 for i in 0..n {
119 if !centers.contains(&i) {
120 centers.push(i);
121 break;
122 }
123 }
124 } else {
125 let threshold = rng.gen::<f64>() * total;
126 let mut cum = 0.0;
127 let mut chosen = n - 1;
128 for i in 0..n {
129 cum += min_dist_sq[i];
130 if cum >= threshold {
131 chosen = i;
132 break;
133 }
134 }
135 centers.push(chosen);
136 }
137
138 let new_center = *centers.last().unwrap();
139 for i in 0..n {
140 let d = dist_mat[(i, new_center)];
141 let d2 = d * d;
142 if d2 < min_dist_sq[i] {
143 min_dist_sq[i] = d2;
144 }
145 }
146 }
147
148 centers
149}
150
151#[must_use = "expensive computation whose result should not be discarded"]
169pub fn kmedoids_from_distances(
170 dist_mat: &FdMatrix,
171 config: &KMedoidsConfig,
172) -> Result<KMedoidsResult, FdarError> {
173 let n = dist_mat.nrows();
174 if dist_mat.ncols() != n {
175 return Err(FdarError::InvalidDimension {
176 parameter: "dist_mat",
177 expected: format!("{n} x {n} (square)"),
178 actual: format!("{} x {}", n, dist_mat.ncols()),
179 });
180 }
181 if config.k < 1 {
182 return Err(FdarError::InvalidParameter {
183 parameter: "k",
184 message: "k must be >= 1".to_string(),
185 });
186 }
187 if config.k > n {
188 return Err(FdarError::InvalidParameter {
189 parameter: "k",
190 message: format!("k ({}) must be <= n ({})", config.k, n),
191 });
192 }
193
194 let k = config.k;
195 let mut rng = StdRng::seed_from_u64(config.seed);
196 let mut medoids = kmeans_pp_init(dist_mat, k, &mut rng);
197
198 let mut labels = assign_to_medoids(dist_mat, &medoids, n);
200
201 let mut converged = false;
202 let mut n_iter = 0;
203
204 for iter in 0..config.max_iter {
205 n_iter = iter + 1;
206
207 for c in 0..k {
209 let members: Vec<usize> = (0..n).filter(|&i| labels[i] == c).collect();
210 if members.is_empty() {
211 continue;
212 }
213 let mut best_cost = f64::INFINITY;
214 let mut best_m = medoids[c];
215 for &candidate in &members {
216 let cost: f64 = members.iter().map(|&j| dist_mat[(candidate, j)]).sum();
217 if cost < best_cost {
218 best_cost = cost;
219 best_m = candidate;
220 }
221 }
222 medoids[c] = best_m;
223 }
224
225 let new_labels = assign_to_medoids(dist_mat, &medoids, n);
227 if new_labels == labels {
228 converged = true;
229 labels = new_labels;
230 break;
231 }
232 labels = new_labels;
233 }
234
235 let mut within_distances = vec![0.0; k];
237 for i in 0..n {
238 within_distances[labels[i]] += dist_mat[(i, medoids[labels[i]])];
239 }
240 let total_within_distance: f64 = within_distances.iter().sum();
241
242 Ok(KMedoidsResult {
243 labels,
244 medoid_indices: medoids,
245 within_distances,
246 total_within_distance,
247 n_iter,
248 converged,
249 })
250}
251
252fn assign_to_medoids(dist_mat: &FdMatrix, medoids: &[usize], n: usize) -> Vec<usize> {
253 (0..n)
254 .map(|i| {
255 let mut best_d = f64::INFINITY;
256 let mut best_c = 0;
257 for (c, &med) in medoids.iter().enumerate() {
258 let d = dist_mat[(i, med)];
259 if d < best_d {
260 best_d = d;
261 best_c = c;
262 }
263 }
264 best_c
265 })
266 .collect()
267}
268
269#[must_use = "expensive computation whose result should not be discarded"]
283pub fn hierarchical_from_distances(
284 dist_mat: &FdMatrix,
285 linkage: Linkage,
286) -> Result<Dendrogram, FdarError> {
287 let n = dist_mat.nrows();
288 if dist_mat.ncols() != n {
289 return Err(FdarError::InvalidDimension {
290 parameter: "dist_mat",
291 expected: format!("{n} x {n} (square)"),
292 actual: format!("{} x {}", n, dist_mat.ncols()),
293 });
294 }
295 if n < 2 {
296 return Err(FdarError::InvalidDimension {
297 parameter: "dist_mat",
298 expected: "at least 2 rows".to_string(),
299 actual: format!("{n} rows"),
300 });
301 }
302
303 let mut active = vec![true; n];
304 let mut cluster_sizes = vec![1usize; n];
305 let mut cluster_dist = FdMatrix::zeros(n, n);
306 for i in 0..n {
307 for j in 0..n {
308 cluster_dist[(i, j)] = dist_mat[(i, j)];
309 }
310 }
311
312 let mut merges: Vec<(usize, usize, f64)> = Vec::with_capacity(n - 1);
313
314 for _ in 0..(n - 1) {
315 let mut min_d = f64::INFINITY;
316 let mut min_i = 0;
317 let mut min_j = 1;
318 for i in 0..n {
319 if !active[i] {
320 continue;
321 }
322 for j in (i + 1)..n {
323 if !active[j] {
324 continue;
325 }
326 if cluster_dist[(i, j)] < min_d {
327 min_d = cluster_dist[(i, j)];
328 min_i = i;
329 min_j = j;
330 }
331 }
332 }
333
334 merges.push((min_i, min_j, min_d));
335
336 let size_i = cluster_sizes[min_i];
337 let size_j = cluster_sizes[min_j];
338 for k in 0..n {
339 if !active[k] || k == min_i || k == min_j {
340 continue;
341 }
342 let d_ik = cluster_dist[(min_i.min(k), min_i.max(k))];
343 let d_jk = cluster_dist[(min_j.min(k), min_j.max(k))];
344 let new_d = match linkage {
345 Linkage::Single => d_ik.min(d_jk),
346 Linkage::Complete => d_ik.max(d_jk),
347 Linkage::Average => {
348 (d_ik * size_i as f64 + d_jk * size_j as f64) / (size_i + size_j) as f64
349 }
350 };
351 let (lo, hi) = (min_i.min(k), min_i.max(k));
352 cluster_dist[(lo, hi)] = new_d;
353 cluster_dist[(hi, lo)] = new_d;
354 }
355
356 cluster_sizes[min_i] = size_i + size_j;
357 active[min_j] = false;
358 }
359
360 Ok(Dendrogram { merges, n })
361}
362
363pub fn cut_dendrogram(dendrogram: &Dendrogram, k: usize) -> Result<Vec<usize>, FdarError> {
377 let n = dendrogram.n;
378
379 if k < 1 {
380 return Err(FdarError::InvalidParameter {
381 parameter: "k",
382 message: "k must be >= 1".to_string(),
383 });
384 }
385 if k > n {
386 return Err(FdarError::InvalidParameter {
387 parameter: "k",
388 message: format!("k ({k}) must be <= n ({n})"),
389 });
390 }
391
392 let mut cluster_of: Vec<usize> = (0..n).collect();
393 let merges_to_apply = n - k;
394
395 for &(ci, cj, _) in dendrogram.merges.iter().take(merges_to_apply) {
396 let target = cluster_of[ci];
397 let source = cluster_of[cj];
398 for label in cluster_of.iter_mut() {
399 if *label == source {
400 *label = target;
401 }
402 }
403 }
404
405 let mut unique: Vec<usize> = cluster_of.clone();
407 unique.sort_unstable();
408 unique.dedup();
409 let labels = cluster_of
410 .iter()
411 .map(|&l| unique.iter().position(|&u| u == l).unwrap())
412 .collect();
413
414 Ok(labels)
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alignment::elastic_self_distance_matrix;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    /// Fixture: simulates `n` curves on an `m`-point uniform grid (fixed seed 42) and
    /// returns their elastic self-distance matrix (presumably `n x n`).
    fn make_dist_mat(n: usize, m: usize) -> FdMatrix {
        let grid = uniform_grid(m);
        let curves = sim_fundata(
            n,
            &grid,
            3,
            EFunType::Fourier,
            EValType::Exponential,
            Some(42),
        );
        elastic_self_distance_matrix(&curves, &grid, 0.0)
    }

    /// Dendrogram of `n` simulated curves under the given linkage.
    fn dendrogram_for(n: usize, linkage: Linkage) -> Dendrogram {
        let distances = make_dist_mat(n, 20);
        hierarchical_from_distances(&distances, linkage).unwrap()
    }

    #[test]
    fn kmedoids_smoke() {
        let distances = make_dist_mat(8, 20);
        let cfg = KMedoidsConfig {
            k: 2,
            max_iter: 10,
            ..Default::default()
        };
        let fit = kmedoids_from_distances(&distances, &cfg).unwrap();
        assert_eq!(fit.labels.len(), 8);
        assert_eq!(fit.medoid_indices.len(), 2);
        assert_eq!(fit.within_distances.len(), 2);
        assert!(fit.total_within_distance >= 0.0);
        assert!(fit.n_iter >= 1);
    }

    #[test]
    fn kmedoids_single_cluster() {
        let distances = make_dist_mat(5, 20);
        let cfg = KMedoidsConfig {
            k: 1,
            max_iter: 10,
            ..Default::default()
        };
        let fit = kmedoids_from_distances(&distances, &cfg).unwrap();
        assert!(fit.labels.iter().all(|&label| label == 0));
        assert_eq!(fit.medoid_indices.len(), 1);
    }

    #[test]
    fn kmedoids_k_too_large() {
        let distances = make_dist_mat(3, 20);
        let cfg = KMedoidsConfig {
            k: 5,
            ..Default::default()
        };
        let outcome = kmedoids_from_distances(&distances, &cfg);
        assert!(outcome.is_err());
    }

    #[test]
    fn kmedoids_k_zero() {
        let distances = make_dist_mat(5, 20);
        let cfg = KMedoidsConfig {
            k: 0,
            ..Default::default()
        };
        let outcome = kmedoids_from_distances(&distances, &cfg);
        assert!(outcome.is_err());
    }

    #[test]
    fn hierarchical_single_smoke() {
        let tree = dendrogram_for(5, Linkage::Single);
        assert_eq!(tree.merges.len(), 4);
        for pair in tree.merges.windows(2) {
            let (earlier, later) = (pair[0].2, pair[1].2);
            assert!(
                later >= earlier - 1e-10,
                "single linkage should be non-decreasing"
            );
        }
    }

    #[test]
    fn hierarchical_complete_smoke() {
        let tree = dendrogram_for(5, Linkage::Complete);
        assert_eq!(tree.merges.len(), 4);
    }

    #[test]
    fn hierarchical_average_smoke() {
        let tree = dendrogram_for(5, Linkage::Average);
        assert_eq!(tree.merges.len(), 4);
    }

    #[test]
    fn hierarchical_too_few() {
        let single_point = FdMatrix::zeros(1, 1);
        assert!(hierarchical_from_distances(&single_point, Linkage::Single).is_err());
    }

    #[test]
    fn cut_dendrogram_all_singletons() {
        let tree = dendrogram_for(5, Linkage::Single);
        let mut labels = cut_dendrogram(&tree, 5).unwrap();
        labels.sort_unstable();
        assert_eq!(labels, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn cut_dendrogram_one_cluster() {
        let tree = dendrogram_for(5, Linkage::Single);
        let labels = cut_dendrogram(&tree, 1).unwrap();
        assert!(labels.iter().all(|&label| label == 0));
    }

    #[test]
    fn cut_dendrogram_k_too_large() {
        let tree = dendrogram_for(5, Linkage::Single);
        assert!(cut_dendrogram(&tree, 10).is_err());
    }

    #[test]
    fn cut_dendrogram_two_clusters() {
        let tree = dendrogram_for(6, Linkage::Single);
        let labels = cut_dendrogram(&tree, 2).unwrap();
        assert_eq!(labels.len(), 6);
        // Count distinct labels.
        let mut distinct = labels.clone();
        distinct.sort_unstable();
        distinct.dedup();
        assert_eq!(distinct.len(), 2);
    }

    #[test]
    fn default_config_values() {
        let KMedoidsConfig { k, max_iter, seed } = KMedoidsConfig::default();
        assert_eq!(k, 2);
        assert_eq!(max_iter, 100);
        assert_eq!(seed, 42);
    }

    #[test]
    fn default_linkage() {
        assert_eq!(Linkage::default(), Linkage::Single);
    }

    #[test]
    fn non_square_dist_mat_error() {
        let rectangular = FdMatrix::zeros(3, 4);
        assert!(hierarchical_from_distances(&rectangular, Linkage::Single).is_err());
        let cfg = KMedoidsConfig::default();
        assert!(kmedoids_from_distances(&rectangular, &cfg).is_err());
    }
}