1use crate::helpers::{extract_curves, l2_distance, simpsons_weights, NUMERICAL_EPS};
7use crate::{iter_maybe_parallel, slice_maybe_parallel};
8use rand::prelude::*;
9#[cfg(feature = "parallel")]
10use rayon::iter::ParallelIterator;
11
/// Outcome of a functional k-means run, as produced by `kmeans_fd`.
///
/// All vectors are empty (and `converged` is `false`) when the inputs were rejected.
#[derive(Debug, Clone, PartialEq)]
pub struct KmeansResult {
    /// Index of the cluster each curve was assigned to, one entry per curve.
    pub cluster: Vec<usize>,
    /// Cluster centers, flattened so that evaluation point `j` of center `c`
    /// is stored at index `c + j * k`.
    pub centers: Vec<f64>,
    /// Sum of squared distances between each cluster's curves and its center.
    pub withinss: Vec<f64>,
    /// Sum of all entries of `withinss`.
    pub tot_withinss: f64,
    /// Number of iterations that were executed.
    pub iter: usize,
    /// Whether the iteration stopped because a convergence criterion was met
    /// (as opposed to running out of iterations).
    pub converged: bool,
}
27
28fn kmeans_plusplus_init(
39 curves: &[Vec<f64>],
40 k: usize,
41 weights: &[f64],
42 rng: &mut StdRng,
43) -> Vec<Vec<f64>> {
44 let n = curves.len();
45 let mut centers: Vec<Vec<f64>> = Vec::with_capacity(k);
46
47 let first_idx = rng.gen_range(0..n);
49 centers.push(curves[first_idx].clone());
50
51 for _ in 1..k {
53 let distances: Vec<f64> = curves
54 .iter()
55 .map(|curve| {
56 centers
57 .iter()
58 .map(|c| l2_distance(curve, c, weights))
59 .fold(f64::INFINITY, f64::min)
60 })
61 .collect();
62
63 let dist_sq: Vec<f64> = distances.iter().map(|d| d * d).collect();
64 let total: f64 = dist_sq.iter().sum();
65
66 if total < NUMERICAL_EPS {
67 let idx = rng.gen_range(0..n);
68 centers.push(curves[idx].clone());
69 } else {
70 let r = rng.gen::<f64>() * total;
71 let mut cumsum = 0.0;
72 let mut chosen = 0;
73 for (i, &d) in dist_sq.iter().enumerate() {
74 cumsum += d;
75 if cumsum >= r {
76 chosen = i;
77 break;
78 }
79 }
80 centers.push(curves[chosen].clone());
81 }
82 }
83
84 centers
85}
86
87fn compute_fuzzy_membership(distances: &[f64], exponent: f64) -> Vec<f64> {
96 let k = distances.len();
97 let mut membership = vec![0.0; k];
98
99 for (c, &dist) in distances.iter().enumerate() {
101 if dist < NUMERICAL_EPS {
102 membership[c] = 1.0;
104 return membership;
105 }
106 }
107
108 for c in 0..k {
110 let mut sum = 0.0;
111 for c2 in 0..k {
112 if distances[c2] > NUMERICAL_EPS {
113 sum += (distances[c] / distances[c2]).powf(exponent);
114 }
115 }
116 membership[c] = if sum > NUMERICAL_EPS { 1.0 / sum } else { 1.0 };
117 }
118
119 membership
120}
121
/// Flattens `k` centers of length `m` into one vector in which evaluation
/// point `j` of center `c` sits at index `c + j * k`.
fn flatten_centers_colmajor(centers: &[Vec<f64>], k: usize, m: usize) -> Vec<f64> {
    // Walk the output in storage order: all centers for point 0, then point 1, ...
    (0..m)
        .flat_map(|j| (0..k).map(move |c| centers[c][j]))
        .collect()
}
132
133fn init_random_membership(n: usize, k: usize, rng: &mut StdRng) -> Vec<f64> {
135 let mut membership = vec![0.0; n * k];
136 for i in 0..n {
137 let mut row_sum = 0.0;
138 for c in 0..k {
139 let val = rng.gen::<f64>();
140 membership[i + c * n] = val;
141 row_sum += val;
142 }
143 for c in 0..k {
144 membership[i + c * n] /= row_sum;
145 }
146 }
147 membership
148}
149
/// Groups curve positions by cluster label: entry `c` of the result lists, in
/// ascending order, every position whose label is `c`.
///
/// Panics if a label is not smaller than `k`.
fn cluster_member_indices(cluster: &[usize], k: usize) -> Vec<Vec<usize>> {
    let mut groups: Vec<Vec<usize>> = (0..k).map(|_| Vec::new()).collect();
    cluster
        .iter()
        .enumerate()
        .for_each(|(pos, &label)| groups[label].push(pos));
    groups
}
158
159fn assign_clusters(curves: &[Vec<f64>], centers: &[Vec<f64>], weights: &[f64]) -> Vec<usize> {
161 slice_maybe_parallel!(curves)
162 .map(|curve| {
163 let mut best_cluster = 0;
164 let mut best_dist = f64::INFINITY;
165 for (c, center) in centers.iter().enumerate() {
166 let dist = l2_distance(curve, center, weights);
167 if dist < best_dist {
168 best_dist = dist;
169 best_cluster = c;
170 }
171 }
172 best_cluster
173 })
174 .collect()
175}
176
/// Recomputes each center as the pointwise mean of the curves assigned to it.
///
/// * `curves` / `assignments` - curves and their cluster labels, paired by position.
/// * `centers` - previous centers; a cluster without members keeps its previous center.
/// * `k` - number of clusters; labels outside `0..k` are ignored.
/// * `m` - number of evaluation points averaged per curve.
fn update_kmeans_centers(
    curves: &[Vec<f64>],
    assignments: &[usize],
    centers: &[Vec<f64>],
    k: usize,
    m: usize,
) -> Vec<Vec<f64>> {
    // One pass over the data accumulates every cluster's sum, instead of
    // re-scanning all assignments (and allocating an index list) per cluster.
    // Curves are still added in their original order, so the sums are identical.
    let mut sums = vec![vec![0.0; m]; k];
    let mut counts = vec![0usize; k];

    for (curve, &label) in curves.iter().zip(assignments) {
        if label >= k {
            continue;
        }
        counts[label] += 1;
        for (acc, &value) in sums[label].iter_mut().zip(&curve[..m]) {
            *acc += value;
        }
    }

    sums.into_iter()
        .zip(counts)
        .enumerate()
        .map(|(c, (mut sum, count))| {
            if count == 0 {
                centers[c].clone()
            } else {
                let size = count as f64;
                for value in &mut sum {
                    *value /= size;
                }
                sum
            }
        })
        .collect()
}
212
213fn compute_within_ss(
215 curves: &[Vec<f64>],
216 centers: &[Vec<f64>],
217 assignments: &[usize],
218 k: usize,
219 weights: &[f64],
220) -> Vec<f64> {
221 let mut withinss = vec![0.0; k];
222 for (i, curve) in curves.iter().enumerate() {
223 let c = assignments[i];
224 let dist = l2_distance(curve, ¢ers[c], weights);
225 withinss[c] += dist * dist;
226 }
227 withinss
228}
229
230fn update_fuzzy_centers(
232 curves: &[Vec<f64>],
233 membership: &[f64],
234 n: usize,
235 k: usize,
236 m: usize,
237 fuzziness: f64,
238) -> Vec<Vec<f64>> {
239 let mut centers = vec![vec![0.0; m]; k];
240 for c in 0..k {
241 let mut numerator = vec![0.0; m];
242 let mut denominator = 0.0;
243
244 for (i, curve) in curves.iter().enumerate() {
245 let weight = membership[i + c * n].powf(fuzziness);
246 for j in 0..m {
247 numerator[j] += weight * curve[j];
248 }
249 denominator += weight;
250 }
251
252 if denominator > NUMERICAL_EPS {
253 for j in 0..m {
254 centers[c][j] = numerator[j] / denominator;
255 }
256 }
257 }
258 centers
259}
260
261fn update_fuzzy_membership_step(
263 curves: &[Vec<f64>],
264 centers: &[Vec<f64>],
265 old_membership: &[f64],
266 n: usize,
267 k: usize,
268 exponent: f64,
269 weights: &[f64],
270) -> (Vec<f64>, f64) {
271 let mut new_membership = vec![0.0; n * k];
272 let mut max_change = 0.0;
273
274 for (i, curve) in curves.iter().enumerate() {
275 let distances: Vec<f64> = centers
276 .iter()
277 .map(|c| l2_distance(curve, c, weights))
278 .collect();
279
280 let memberships = compute_fuzzy_membership(&distances, exponent);
281
282 for c in 0..k {
283 new_membership[i + c * n] = memberships[c];
284 let change = (memberships[c] - old_membership[i + c * n]).abs();
285 if change > max_change {
286 max_change = change;
287 }
288 }
289 }
290
291 (new_membership, max_change)
292}
293
294fn mean_cluster_distance(
296 curve: &[f64],
297 curves: &[Vec<f64>],
298 indices: &[usize],
299 weights: &[f64],
300) -> f64 {
301 if indices.is_empty() {
302 return 0.0;
303 }
304 let sum: f64 = indices
305 .iter()
306 .map(|&j| l2_distance(curve, &curves[j], weights))
307 .sum();
308 sum / indices.len() as f64
309}
310
/// Computes, from labelled curves, the per-cluster mean curves, the mean of
/// all curves and the size of every cluster.
///
/// Returns `(centers, global_mean, counts)`. A cluster without members keeps
/// an all-zero center and a count of 0.
fn compute_centers_and_global_mean(
    curves: &[Vec<f64>],
    assignments: &[usize],
    k: usize,
    m: usize,
) -> (Vec<Vec<f64>>, Vec<f64>, Vec<usize>) {
    let mut global_mean = vec![0.0; m];
    let mut centers = vec![vec![0.0; m]; k];
    let mut counts = vec![0usize; k];

    // Accumulate the global and the per-cluster sums in one sweep.
    for (i, curve) in curves.iter().enumerate() {
        let label = assignments[i];
        counts[label] += 1;
        let center = &mut centers[label];
        for j in 0..m {
            let value = curve[j];
            global_mean[j] += value;
            center[j] += value;
        }
    }

    let total = curves.len() as f64;
    global_mean.iter_mut().for_each(|v| *v /= total);

    for (center, &count) in centers.iter_mut().zip(&counts) {
        if count > 0 {
            let size = count as f64;
            center.iter_mut().for_each(|v| *v /= size);
        }
    }

    (centers, global_mean, counts)
}
348
349fn kmeans_step(
351 curves: &[Vec<f64>],
352 centers: &[Vec<f64>],
353 weights: &[f64],
354 k: usize,
355 m: usize,
356) -> (Vec<usize>, Vec<Vec<f64>>, f64) {
357 let new_cluster = assign_clusters(curves, centers, weights);
358 let new_centers = update_kmeans_centers(curves, &new_cluster, centers, k, m);
359 let max_movement = centers
360 .iter()
361 .zip(new_centers.iter())
362 .map(|(old, new)| l2_distance(old, new, weights))
363 .fold(0.0, f64::max);
364 (new_cluster, new_centers, max_movement)
365}
366
367fn kmeans_iterate(
369 curves: &[Vec<f64>],
370 mut centers: Vec<Vec<f64>>,
371 weights: &[f64],
372 k: usize,
373 m: usize,
374 max_iter: usize,
375 tol: f64,
376) -> (Vec<usize>, Vec<Vec<f64>>, usize, bool) {
377 let n = curves.len();
378 let mut cluster = vec![0usize; n];
379 let mut converged = false;
380 let mut iter = 0;
381
382 for iteration in 0..max_iter {
383 iter = iteration + 1;
384 let (new_cluster, new_centers, max_movement) = kmeans_step(curves, ¢ers, weights, k, m);
385
386 if new_cluster == cluster {
387 converged = true;
388 break;
389 }
390 cluster = new_cluster;
391 centers = new_centers;
392
393 if max_movement < tol {
394 converged = true;
395 break;
396 }
397 }
398
399 (cluster, centers, iter, converged)
400}
401
402pub fn kmeans_fd(
414 data: &[f64],
415 n: usize,
416 m: usize,
417 argvals: &[f64],
418 k: usize,
419 max_iter: usize,
420 tol: f64,
421 seed: u64,
422) -> KmeansResult {
423 if n == 0 || m == 0 || k == 0 || k > n || argvals.len() != m {
424 return KmeansResult {
425 cluster: Vec::new(),
426 centers: Vec::new(),
427 withinss: Vec::new(),
428 tot_withinss: 0.0,
429 iter: 0,
430 converged: false,
431 };
432 }
433
434 let weights = simpsons_weights(argvals);
435 let mut rng = StdRng::seed_from_u64(seed);
436
437 let curves = extract_curves(data, n, m);
439
440 let centers = kmeans_plusplus_init(&curves, k, &weights, &mut rng);
442
443 let (cluster, centers, iter, converged) =
444 kmeans_iterate(&curves, centers, &weights, k, m, max_iter, tol);
445
446 let withinss = compute_within_ss(&curves, ¢ers, &cluster, k, &weights);
447 let tot_withinss: f64 = withinss.iter().sum();
448 let centers_flat = flatten_centers_colmajor(¢ers, k, m);
449
450 KmeansResult {
451 cluster,
452 centers: centers_flat,
453 withinss,
454 tot_withinss,
455 iter,
456 converged,
457 }
458}
459
/// Outcome of a fuzzy c-means run, as produced by `fuzzy_cmeans_fd`.
///
/// Both vectors are empty (and `converged` is `false`) when the inputs were rejected.
#[derive(Debug, Clone, PartialEq)]
pub struct FuzzyCmeansResult {
    /// Membership degrees, flattened so that the membership of curve `i` in
    /// cluster `c` is stored at index `i + c * n`.
    pub membership: Vec<f64>,
    /// Cluster centers, flattened so that evaluation point `j` of center `c`
    /// is stored at index `c + j * k`.
    pub centers: Vec<f64>,
    /// Number of iterations that were executed.
    pub iter: usize,
    /// Whether the largest membership change dropped below the tolerance
    /// before the iteration limit was reached.
    pub converged: bool,
}
471
472pub fn fuzzy_cmeans_fd(
485 data: &[f64],
486 n: usize,
487 m: usize,
488 argvals: &[f64],
489 k: usize,
490 fuzziness: f64,
491 max_iter: usize,
492 tol: f64,
493 seed: u64,
494) -> FuzzyCmeansResult {
495 if n == 0 || m == 0 || k == 0 || k > n || argvals.len() != m || fuzziness <= 1.0 {
496 return FuzzyCmeansResult {
497 membership: Vec::new(),
498 centers: Vec::new(),
499 iter: 0,
500 converged: false,
501 };
502 }
503
504 let weights = simpsons_weights(argvals);
505 let mut rng = StdRng::seed_from_u64(seed);
506
507 let curves = extract_curves(data, n, m);
509
510 let mut membership = init_random_membership(n, k, &mut rng);
511
512 let mut centers = vec![vec![0.0; m]; k];
513 let mut converged = false;
514 let mut iter = 0;
515 let exponent = 2.0 / (fuzziness - 1.0);
516
517 for iteration in 0..max_iter {
518 iter = iteration + 1;
519
520 centers = update_fuzzy_centers(&curves, &membership, n, k, m, fuzziness);
521
522 let (new_membership, max_change) =
523 update_fuzzy_membership_step(&curves, ¢ers, &membership, n, k, exponent, &weights);
524
525 membership = new_membership;
526
527 if max_change < tol {
528 converged = true;
529 break;
530 }
531 }
532
533 let centers_flat = flatten_centers_colmajor(¢ers, k, m);
534
535 FuzzyCmeansResult {
536 membership,
537 centers: centers_flat,
538 iter,
539 converged,
540 }
541}
542
543pub fn silhouette_score(
545 data: &[f64],
546 n: usize,
547 m: usize,
548 argvals: &[f64],
549 cluster: &[usize],
550) -> Vec<f64> {
551 if n == 0 || m == 0 || cluster.len() != n || argvals.len() != m {
552 return Vec::new();
553 }
554
555 let weights = simpsons_weights(argvals);
556 let curves = extract_curves(data, n, m);
557
558 let k = cluster.iter().cloned().max().unwrap_or(0) + 1;
559 let members = cluster_member_indices(cluster, k);
560
561 iter_maybe_parallel!(0..n)
562 .map(|i| {
563 let my_cluster = cluster[i];
564
565 let same_indices: Vec<usize> = members[my_cluster]
566 .iter()
567 .copied()
568 .filter(|&j| j != i)
569 .collect();
570 let a_i = mean_cluster_distance(&curves[i], &curves, &same_indices, &weights);
571
572 let mut b_i = f64::INFINITY;
573 for c in 0..k {
574 if c != my_cluster && !members[c].is_empty() {
575 b_i = b_i.min(mean_cluster_distance(
576 &curves[i],
577 &curves,
578 &members[c],
579 &weights,
580 ));
581 }
582 }
583
584 if b_i.is_infinite() {
585 0.0
586 } else {
587 let max_ab = a_i.max(b_i);
588 if max_ab > NUMERICAL_EPS {
589 (b_i - a_i) / max_ab
590 } else {
591 0.0
592 }
593 }
594 })
595 .collect()
596}
597
598pub fn calinski_harabasz(
600 data: &[f64],
601 n: usize,
602 m: usize,
603 argvals: &[f64],
604 cluster: &[usize],
605) -> f64 {
606 if n == 0 || m == 0 || cluster.len() != n || argvals.len() != m {
607 return 0.0;
608 }
609
610 let weights = simpsons_weights(argvals);
611 let curves = extract_curves(data, n, m);
612
613 let k = cluster.iter().cloned().max().unwrap_or(0) + 1;
614 if k < 2 {
615 return 0.0;
616 }
617
618 let (centers, global_mean, counts) = compute_centers_and_global_mean(&curves, cluster, k, m);
619
620 let mut bgss = 0.0;
621 for c in 0..k {
622 let dist = l2_distance(¢ers[c], &global_mean, &weights);
623 bgss += counts[c] as f64 * dist * dist;
624 }
625
626 let wgss_vec = compute_within_ss(&curves, ¢ers, cluster, k, &weights);
627 let wgss: f64 = wgss_vec.iter().sum();
628
629 if wgss < NUMERICAL_EPS {
630 return f64::INFINITY;
631 }
632
633 (bgss / (k - 1) as f64) / (wgss / (n - k) as f64)
634}
635
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `n` equally spaced points covering `[0, 1]`.
    ///
    /// A one-point grid is `[0.0]` (dividing by `n - 1` would give NaN).
    fn uniform_grid(n: usize) -> Vec<f64> {
        if n < 2 {
            return vec![0.0; n];
        }
        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
    }

    /// Two well-separated groups of `n_per_cluster` sine curves each, sampled
    /// on a uniform grid of `m` points.
    ///
    /// The second group is shifted up by 5; within a group the curves differ by
    /// a small vertical offset. Returns the column-major data (point `j` of
    /// curve `i` at `i + j * n`) and the grid.
    fn generate_two_clusters(n_per_cluster: usize, m: usize) -> (Vec<f64>, Vec<f64>) {
        let t = uniform_grid(m);
        let n = 2 * n_per_cluster;
        let mut col_major = vec![0.0; n * m];

        for i in 0..n {
            let within = (i % n_per_cluster) as f64 / n_per_cluster as f64;
            for (j, &tj) in t.iter().enumerate() {
                let base = (2.0 * PI * tj).sin();
                col_major[i + j * n] = if i < n_per_cluster {
                    base + 0.1 * within
                } else {
                    base + 5.0 + 0.1 * within
                };
            }
        }

        (col_major, t)
    }

    /// Ground-truth labels for `generate_two_clusters`: first half 0, second half 1.
    fn two_cluster_labels(n_per_cluster: usize) -> Vec<usize> {
        (0..2 * n_per_cluster)
            .map(|i| if i < n_per_cluster { 0 } else { 1 })
            .collect()
    }

    // ---- kmeans_fd ----

    #[test]
    fn test_kmeans_fd_basic() {
        let m = 50;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = kmeans_fd(&data, n, m, &t, 2, 100, 1e-6, 42);

        assert_eq!(result.cluster.len(), n);
        assert!(result.converged);
        assert!(result.iter > 0 && result.iter <= 100);
    }

    #[test]
    fn test_kmeans_fd_finds_clusters() {
        let m = 50;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = kmeans_fd(&data, n, m, &t, 2, 100, 1e-6, 42);

        let cluster_0 = result.cluster[0];
        let cluster_1 = result.cluster[n_per];

        assert_ne!(cluster_0, cluster_1, "Clusters should be different");

        for i in 0..n_per {
            assert_eq!(result.cluster[i], cluster_0);
        }

        for i in n_per..n {
            assert_eq!(result.cluster[i], cluster_1);
        }
    }

    #[test]
    fn test_kmeans_fd_deterministic() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result1 = kmeans_fd(&data, n, m, &t, 2, 100, 1e-6, 42);
        let result2 = kmeans_fd(&data, n, m, &t, 2, 100, 1e-6, 42);

        // The same seed must reproduce the whole result, not just the labels.
        assert_eq!(result1.cluster, result2.cluster);
        assert_eq!(result1.centers, result2.centers);
        assert_eq!(result1.withinss, result2.withinss);
        assert_eq!(result1.iter, result2.iter);
    }

    #[test]
    fn test_kmeans_fd_withinss() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = kmeans_fd(&data, n, m, &t, 2, 100, 1e-6, 42);

        for &wss in &result.withinss {
            assert!(wss >= 0.0);
        }

        let sum: f64 = result.withinss.iter().sum();
        assert!((sum - result.tot_withinss).abs() < 1e-10);
    }

    #[test]
    fn test_kmeans_fd_centers_shape() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;
        let k = 3;

        let result = kmeans_fd(&data, n, m, &t, k, 100, 1e-6, 42);

        assert_eq!(result.centers.len(), k * m);
    }

    #[test]
    fn test_kmeans_fd_invalid_input() {
        let t = uniform_grid(30);

        // No curves at all.
        let result = kmeans_fd(&[], 0, 30, &t, 2, 100, 1e-6, 42);
        assert!(result.cluster.is_empty());
        assert!(!result.converged);

        // More clusters than curves.
        let data = vec![0.0; 5 * 30];
        let result = kmeans_fd(&data, 5, 30, &t, 10, 100, 1e-6, 42);
        assert!(result.cluster.is_empty());
    }

    #[test]
    fn test_kmeans_fd_single_cluster() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let data = vec![0.0; n * m];

        let result = kmeans_fd(&data, n, m, &t, 1, 100, 1e-6, 42);

        for &c in &result.cluster {
            assert_eq!(c, 0);
        }
    }

    // ---- fuzzy_cmeans_fd ----

    #[test]
    fn test_fuzzy_cmeans_fd_basic() {
        let m = 50;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = fuzzy_cmeans_fd(&data, n, m, &t, 2, 2.0, 100, 1e-6, 42);

        assert_eq!(result.membership.len(), n * 2);
        assert!(result.iter > 0);
    }

    #[test]
    fn test_fuzzy_cmeans_fd_membership_sums_to_one() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;
        let k = 2;

        let result = fuzzy_cmeans_fd(&data, n, m, &t, k, 2.0, 100, 1e-6, 42);

        for i in 0..n {
            let sum: f64 = (0..k).map(|c| result.membership[i + c * n]).sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "Membership should sum to 1, got {}",
                sum
            );
        }
    }

    #[test]
    fn test_fuzzy_cmeans_fd_membership_in_range() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = fuzzy_cmeans_fd(&data, n, m, &t, 2, 2.0, 100, 1e-6, 42);

        for &mem in &result.membership {
            assert!((0.0..=1.0 + 1e-10).contains(&mem));
        }
    }

    #[test]
    fn test_fuzzy_cmeans_fd_fuzziness_effect() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result_low = fuzzy_cmeans_fd(&data, n, m, &t, 2, 1.5, 100, 1e-6, 42);
        let result_high = fuzzy_cmeans_fd(&data, n, m, &t, 2, 3.0, 100, 1e-6, 42);

        let entropy_low: f64 = result_low
            .membership
            .iter()
            .map(|&m| if m > 1e-10 { -m * m.ln() } else { 0.0 })
            .sum();

        let entropy_high: f64 = result_high
            .membership
            .iter()
            .map(|&m| if m > 1e-10 { -m * m.ln() } else { 0.0 })
            .sum();

        assert!(
            entropy_high >= entropy_low - 0.1,
            "Higher fuzziness should give higher entropy"
        );
    }

    #[test]
    fn test_fuzzy_cmeans_fd_invalid_fuzziness() {
        let t = uniform_grid(30);
        let data = vec![0.0; 10 * 30];

        let result = fuzzy_cmeans_fd(&data, 10, 30, &t, 2, 1.0, 100, 1e-6, 42);
        assert!(result.membership.is_empty());

        let result = fuzzy_cmeans_fd(&data, 10, 30, &t, 2, 0.5, 100, 1e-6, 42);
        assert!(result.membership.is_empty());
    }

    #[test]
    fn test_fuzzy_cmeans_fd_centers_shape() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let k = 3;
        let data = vec![0.0; n * m];

        let result = fuzzy_cmeans_fd(&data, n, m, &t, k, 2.0, 100, 1e-6, 42);

        assert_eq!(result.centers.len(), k * m);
    }

    // ---- silhouette_score ----

    #[test]
    fn test_silhouette_score_well_separated() {
        let m = 30;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let cluster = two_cluster_labels(n_per);

        let scores = silhouette_score(&data, n, m, &t, &cluster);

        assert_eq!(scores.len(), n);

        let mean_score: f64 = scores.iter().sum::<f64>() / n as f64;
        assert!(
            mean_score > 0.5,
            "Well-separated clusters should have high silhouette: {}",
            mean_score
        );
    }

    #[test]
    fn test_silhouette_score_range() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let cluster = two_cluster_labels(n_per);

        let scores = silhouette_score(&data, n, m, &t, &cluster);

        for &s in &scores {
            assert!((-1.0 - 1e-10..=1.0 + 1e-10).contains(&s));
        }
    }

    #[test]
    fn test_silhouette_score_single_cluster() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let data = vec![0.0; n * m];

        let cluster = vec![0usize; n];

        let scores = silhouette_score(&data, n, m, &t, &cluster);

        for &s in &scores {
            assert!(s.abs() < 1e-10);
        }
    }

    #[test]
    fn test_silhouette_score_invalid_input() {
        let t = uniform_grid(30);

        let scores = silhouette_score(&[], 0, 30, &t, &[]);
        assert!(scores.is_empty());

        // Label vector shorter than the number of curves.
        let data = vec![0.0; 10 * 30];
        let cluster = vec![0; 5];
        let scores = silhouette_score(&data, 10, 30, &t, &cluster);
        assert!(scores.is_empty());
    }

    // ---- calinski_harabasz ----

    #[test]
    fn test_calinski_harabasz_well_separated() {
        let m = 30;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let cluster = two_cluster_labels(n_per);

        let ch = calinski_harabasz(&data, n, m, &t, &cluster);

        assert!(
            ch > 1.0,
            "Well-separated clusters should have high CH: {}",
            ch
        );
    }

    #[test]
    fn test_calinski_harabasz_positive() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let cluster = two_cluster_labels(n_per);

        let ch = calinski_harabasz(&data, n, m, &t, &cluster);

        assert!(ch >= 0.0, "CH index should be non-negative");
    }

    #[test]
    fn test_calinski_harabasz_single_cluster() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let data = vec![0.0; n * m];

        let cluster = vec![0usize; n];

        let ch = calinski_harabasz(&data, n, m, &t, &cluster);

        assert!(ch.abs() < 1e-10);
    }

    #[test]
    fn test_calinski_harabasz_invalid_input() {
        let t = uniform_grid(30);

        let ch = calinski_harabasz(&[], 0, 30, &t, &[]);
        assert!(ch.abs() < 1e-10);
    }
}