// fdars_core/clustering.rs

//! Clustering algorithms for functional data.
//!
//! This module provides k-means and fuzzy c-means clustering algorithms
//! for functional data.

use crate::helpers::{extract_curves, l2_distance, simpsons_weights, NUMERICAL_EPS};
use crate::{iter_maybe_parallel, slice_maybe_parallel};
use rand::prelude::*;
#[cfg(feature = "parallel")]
use rayon::iter::ParallelIterator;
/// Result of k-means clustering.
///
/// Derives `Debug` so results can be logged/inspected and `Clone` so callers
/// can keep a copy of a run's outcome; public result types should be `Debug`.
#[derive(Debug, Clone)]
pub struct KmeansResult {
    /// Cluster assignments for each observation
    pub cluster: Vec<usize>,
    /// Cluster centers (k x m matrix, column-major)
    pub centers: Vec<f64>,
    /// Within-cluster sum of squares for each cluster
    pub withinss: Vec<f64>,
    /// Total within-cluster sum of squares
    pub tot_withinss: f64,
    /// Number of iterations
    pub iter: usize,
    /// Whether the algorithm converged
    pub converged: bool,
}
27
28/// K-means++ initialization: select initial centers with probability proportional to D^2.
29///
30/// # Arguments
31/// * `curves` - Vector of curve vectors
32/// * `k` - Number of clusters
33/// * `weights` - Integration weights for L2 distance
34/// * `rng` - Random number generator
35///
36/// # Returns
37/// Vector of k initial cluster centers
38fn kmeans_plusplus_init(
39    curves: &[Vec<f64>],
40    k: usize,
41    weights: &[f64],
42    rng: &mut StdRng,
43) -> Vec<Vec<f64>> {
44    let n = curves.len();
45    let mut centers: Vec<Vec<f64>> = Vec::with_capacity(k);
46
47    // First center: random
48    let first_idx = rng.gen_range(0..n);
49    centers.push(curves[first_idx].clone());
50
51    // Remaining centers: probability proportional to D^2
52    for _ in 1..k {
53        let distances: Vec<f64> = curves
54            .iter()
55            .map(|curve| {
56                centers
57                    .iter()
58                    .map(|c| l2_distance(curve, c, weights))
59                    .fold(f64::INFINITY, f64::min)
60            })
61            .collect();
62
63        let dist_sq: Vec<f64> = distances.iter().map(|d| d * d).collect();
64        let total: f64 = dist_sq.iter().sum();
65
66        if total < NUMERICAL_EPS {
67            let idx = rng.gen_range(0..n);
68            centers.push(curves[idx].clone());
69        } else {
70            let r = rng.gen::<f64>() * total;
71            let mut cumsum = 0.0;
72            let mut chosen = 0;
73            for (i, &d) in dist_sq.iter().enumerate() {
74                cumsum += d;
75                if cumsum >= r {
76                    chosen = i;
77                    break;
78                }
79            }
80            centers.push(curves[chosen].clone());
81        }
82    }
83
84    centers
85}
86
87/// Compute fuzzy membership values for a single observation.
88///
89/// # Arguments
90/// * `distances` - Distances from the observation to each cluster center
91/// * `exponent` - Exponent for fuzzy membership (2 / (fuzziness - 1))
92///
93/// # Returns
94/// Vector of membership values (one per cluster)
95fn compute_fuzzy_membership(distances: &[f64], exponent: f64) -> Vec<f64> {
96    let k = distances.len();
97    let mut membership = vec![0.0; k];
98
99    // Check if observation is very close to any center
100    for (c, &dist) in distances.iter().enumerate() {
101        if dist < NUMERICAL_EPS {
102            // Assign full membership to this cluster
103            membership[c] = 1.0;
104            return membership;
105        }
106    }
107
108    // Normal fuzzy membership computation
109    for c in 0..k {
110        let mut sum = 0.0;
111        for c2 in 0..k {
112            if distances[c2] > NUMERICAL_EPS {
113                sum += (distances[c] / distances[c2]).powf(exponent);
114            }
115        }
116        membership[c] = if sum > NUMERICAL_EPS { 1.0 / sum } else { 1.0 };
117    }
118
119    membership
120}
121
/// Flatten `Vec<Vec<f64>>` centers into a column-major `Vec<f64>` (k rows, m columns).
fn flatten_centers_colmajor(centers: &[Vec<f64>], k: usize, m: usize) -> Vec<f64> {
    let mut out = vec![0.0; k * m];
    // Column-major layout: element (c, j) lives at index c + j * k.
    for j in 0..m {
        for c in 0..k {
            out[c + j * k] = centers[c][j];
        }
    }
    out
}
132
133/// Initialize a random membership matrix (n x k, column-major) with rows summing to 1.
134fn init_random_membership(n: usize, k: usize, rng: &mut StdRng) -> Vec<f64> {
135    let mut membership = vec![0.0; n * k];
136    for i in 0..n {
137        let mut row_sum = 0.0;
138        for c in 0..k {
139            let val = rng.gen::<f64>();
140            membership[i + c * n] = val;
141            row_sum += val;
142        }
143        for c in 0..k {
144            membership[i + c * n] /= row_sum;
145        }
146    }
147    membership
148}
149
/// Group sample indices by their cluster assignment.
///
/// Returns a vector of k buckets; bucket `c` holds the (ascending) indices
/// of all samples assigned to cluster `c`.
fn cluster_member_indices(cluster: &[usize], k: usize) -> Vec<Vec<usize>> {
    let mut groups: Vec<Vec<usize>> = vec![Vec::new(); k];
    cluster
        .iter()
        .enumerate()
        .for_each(|(sample, &assignment)| groups[assignment].push(sample));
    groups
}
158
159/// Assign each curve to its nearest center, returning cluster indices.
160fn assign_clusters(curves: &[Vec<f64>], centers: &[Vec<f64>], weights: &[f64]) -> Vec<usize> {
161    slice_maybe_parallel!(curves)
162        .map(|curve| {
163            let mut best_cluster = 0;
164            let mut best_dist = f64::INFINITY;
165            for (c, center) in centers.iter().enumerate() {
166                let dist = l2_distance(curve, center, weights);
167                if dist < best_dist {
168                    best_dist = dist;
169                    best_cluster = c;
170                }
171            }
172            best_cluster
173        })
174        .collect()
175}
176
/// Compute new cluster centers from curve assignments.
///
/// Each center becomes the pointwise mean of its member curves; a cluster
/// with no members keeps its previous center so it is not lost.
///
/// Perf fix: the previous version re-filtered the full assignment vector
/// once per cluster (O(k * n)); this version accumulates sums and counts in
/// a single O(n) pass over the curves.
fn update_kmeans_centers(
    curves: &[Vec<f64>],
    assignments: &[usize],
    centers: &[Vec<f64>],
    k: usize,
    m: usize,
) -> Vec<Vec<f64>> {
    // One pass: per-cluster coordinate sums and member counts.
    let mut sums = vec![vec![0.0; m]; k];
    let mut counts = vec![0usize; k];
    for (i, curve) in curves.iter().enumerate() {
        let c = assignments[i];
        counts[c] += 1;
        for j in 0..m {
            sums[c][j] += curve[j];
        }
    }

    (0..k)
        .map(|c| {
            if counts[c] == 0 {
                // Empty cluster: retain the old center.
                centers[c].clone()
            } else {
                let n_members = counts[c] as f64;
                sums[c].iter().map(|&s| s / n_members).collect()
            }
        })
        .collect()
}
212
213/// Compute within-cluster sum of squares for each cluster.
214fn compute_within_ss(
215    curves: &[Vec<f64>],
216    centers: &[Vec<f64>],
217    assignments: &[usize],
218    k: usize,
219    weights: &[f64],
220) -> Vec<f64> {
221    let mut withinss = vec![0.0; k];
222    for (i, curve) in curves.iter().enumerate() {
223        let c = assignments[i];
224        let dist = l2_distance(curve, &centers[c], weights);
225        withinss[c] += dist * dist;
226    }
227    withinss
228}
229
230/// Update fuzzy c-means cluster centers from membership values.
231fn update_fuzzy_centers(
232    curves: &[Vec<f64>],
233    membership: &[f64],
234    n: usize,
235    k: usize,
236    m: usize,
237    fuzziness: f64,
238) -> Vec<Vec<f64>> {
239    let mut centers = vec![vec![0.0; m]; k];
240    for c in 0..k {
241        let mut numerator = vec![0.0; m];
242        let mut denominator = 0.0;
243
244        for (i, curve) in curves.iter().enumerate() {
245            let weight = membership[i + c * n].powf(fuzziness);
246            for j in 0..m {
247                numerator[j] += weight * curve[j];
248            }
249            denominator += weight;
250        }
251
252        if denominator > NUMERICAL_EPS {
253            for j in 0..m {
254                centers[c][j] = numerator[j] / denominator;
255            }
256        }
257    }
258    centers
259}
260
261/// Update fuzzy membership values and compute max change.
262fn update_fuzzy_membership_step(
263    curves: &[Vec<f64>],
264    centers: &[Vec<f64>],
265    old_membership: &[f64],
266    n: usize,
267    k: usize,
268    exponent: f64,
269    weights: &[f64],
270) -> (Vec<f64>, f64) {
271    let mut new_membership = vec![0.0; n * k];
272    let mut max_change = 0.0;
273
274    for (i, curve) in curves.iter().enumerate() {
275        let distances: Vec<f64> = centers
276            .iter()
277            .map(|c| l2_distance(curve, c, weights))
278            .collect();
279
280        let memberships = compute_fuzzy_membership(&distances, exponent);
281
282        for c in 0..k {
283            new_membership[i + c * n] = memberships[c];
284            let change = (memberships[c] - old_membership[i + c * n]).abs();
285            if change > max_change {
286                max_change = change;
287            }
288        }
289    }
290
291    (new_membership, max_change)
292}
293
294/// Compute mean L2 distance from a curve to a set of curve indices.
295fn mean_cluster_distance(
296    curve: &[f64],
297    curves: &[Vec<f64>],
298    indices: &[usize],
299    weights: &[f64],
300) -> f64 {
301    if indices.is_empty() {
302        return 0.0;
303    }
304    let sum: f64 = indices
305        .iter()
306        .map(|&j| l2_distance(curve, &curves[j], weights))
307        .sum();
308    sum / indices.len() as f64
309}
310
/// Compute cluster centers, global mean, and counts from curves and assignments.
///
/// Returns `(centers, global_mean, counts)` where `centers[c]` is the
/// pointwise mean of cluster `c`'s members (all zeros for an empty cluster),
/// `global_mean` is the pointwise mean over every curve, and `counts[c]` is
/// cluster `c`'s size.
fn compute_centers_and_global_mean(
    curves: &[Vec<f64>],
    assignments: &[usize],
    k: usize,
    m: usize,
) -> (Vec<Vec<f64>>, Vec<f64>, Vec<usize>) {
    let n = curves.len();

    // Pointwise mean over the whole sample.
    let mut global_mean = vec![0.0; m];
    for curve in curves {
        for j in 0..m {
            global_mean[j] += curve[j];
        }
    }
    for v in global_mean.iter_mut() {
        *v /= n as f64;
    }

    // Per-cluster coordinate sums and sizes...
    let mut centers = vec![vec![0.0; m]; k];
    let mut counts = vec![0usize; k];
    for (i, curve) in curves.iter().enumerate() {
        let c = assignments[i];
        counts[c] += 1;
        for j in 0..m {
            centers[c][j] += curve[j];
        }
    }
    // ...then normalize the non-empty clusters.
    for c in 0..k {
        if counts[c] > 0 {
            let size = counts[c] as f64;
            centers[c].iter_mut().for_each(|v| *v /= size);
        }
    }

    (centers, global_mean, counts)
}
348
349/// Run one k-means iteration: assign clusters, update centers, compute movement.
350fn kmeans_step(
351    curves: &[Vec<f64>],
352    centers: &[Vec<f64>],
353    weights: &[f64],
354    k: usize,
355    m: usize,
356) -> (Vec<usize>, Vec<Vec<f64>>, f64) {
357    let new_cluster = assign_clusters(curves, centers, weights);
358    let new_centers = update_kmeans_centers(curves, &new_cluster, centers, k, m);
359    let max_movement = centers
360        .iter()
361        .zip(new_centers.iter())
362        .map(|(old, new)| l2_distance(old, new, weights))
363        .fold(0.0, f64::max);
364    (new_cluster, new_centers, max_movement)
365}
366
367/// Run the k-means iteration loop until convergence or max iterations.
368fn kmeans_iterate(
369    curves: &[Vec<f64>],
370    mut centers: Vec<Vec<f64>>,
371    weights: &[f64],
372    k: usize,
373    m: usize,
374    max_iter: usize,
375    tol: f64,
376) -> (Vec<usize>, Vec<Vec<f64>>, usize, bool) {
377    let n = curves.len();
378    let mut cluster = vec![0usize; n];
379    let mut converged = false;
380    let mut iter = 0;
381
382    for iteration in 0..max_iter {
383        iter = iteration + 1;
384        let (new_cluster, new_centers, max_movement) = kmeans_step(curves, &centers, weights, k, m);
385
386        if new_cluster == cluster {
387            converged = true;
388            break;
389        }
390        cluster = new_cluster;
391        centers = new_centers;
392
393        if max_movement < tol {
394            converged = true;
395            break;
396        }
397    }
398
399    (cluster, centers, iter, converged)
400}
401
402/// K-means clustering for functional data.
403///
404/// # Arguments
405/// * `data` - Column-major matrix (n x m)
406/// * `n` - Number of observations
407/// * `m` - Number of evaluation points
408/// * `argvals` - Evaluation points
409/// * `k` - Number of clusters
410/// * `max_iter` - Maximum iterations
411/// * `tol` - Convergence tolerance
412/// * `seed` - Random seed
413pub fn kmeans_fd(
414    data: &[f64],
415    n: usize,
416    m: usize,
417    argvals: &[f64],
418    k: usize,
419    max_iter: usize,
420    tol: f64,
421    seed: u64,
422) -> KmeansResult {
423    if n == 0 || m == 0 || k == 0 || k > n || argvals.len() != m {
424        return KmeansResult {
425            cluster: Vec::new(),
426            centers: Vec::new(),
427            withinss: Vec::new(),
428            tot_withinss: 0.0,
429            iter: 0,
430            converged: false,
431        };
432    }
433
434    let weights = simpsons_weights(argvals);
435    let mut rng = StdRng::seed_from_u64(seed);
436
437    // Extract curves using helper
438    let curves = extract_curves(data, n, m);
439
440    // K-means++ initialization using helper
441    let centers = kmeans_plusplus_init(&curves, k, &weights, &mut rng);
442
443    let (cluster, centers, iter, converged) =
444        kmeans_iterate(&curves, centers, &weights, k, m, max_iter, tol);
445
446    let withinss = compute_within_ss(&curves, &centers, &cluster, k, &weights);
447    let tot_withinss: f64 = withinss.iter().sum();
448    let centers_flat = flatten_centers_colmajor(&centers, k, m);
449
450    KmeansResult {
451        cluster,
452        centers: centers_flat,
453        withinss,
454        tot_withinss,
455        iter,
456        converged,
457    }
458}
459
/// Result of fuzzy c-means clustering.
///
/// Derives `Debug` so results can be logged/inspected and `Clone` so callers
/// can keep a copy of a run's outcome; public result types should be `Debug`.
#[derive(Debug, Clone)]
pub struct FuzzyCmeansResult {
    /// Membership matrix (n x k, column-major)
    pub membership: Vec<f64>,
    /// Cluster centers (k x m, column-major)
    pub centers: Vec<f64>,
    /// Number of iterations
    pub iter: usize,
    /// Whether the algorithm converged
    pub converged: bool,
}
471
472/// Fuzzy c-means clustering for functional data.
473///
474/// # Arguments
475/// * `data` - Column-major matrix (n x m)
476/// * `n` - Number of observations
477/// * `m` - Number of evaluation points
478/// * `argvals` - Evaluation points
479/// * `k` - Number of clusters
480/// * `fuzziness` - Fuzziness parameter (> 1)
481/// * `max_iter` - Maximum iterations
482/// * `tol` - Convergence tolerance
483/// * `seed` - Random seed
484pub fn fuzzy_cmeans_fd(
485    data: &[f64],
486    n: usize,
487    m: usize,
488    argvals: &[f64],
489    k: usize,
490    fuzziness: f64,
491    max_iter: usize,
492    tol: f64,
493    seed: u64,
494) -> FuzzyCmeansResult {
495    if n == 0 || m == 0 || k == 0 || k > n || argvals.len() != m || fuzziness <= 1.0 {
496        return FuzzyCmeansResult {
497            membership: Vec::new(),
498            centers: Vec::new(),
499            iter: 0,
500            converged: false,
501        };
502    }
503
504    let weights = simpsons_weights(argvals);
505    let mut rng = StdRng::seed_from_u64(seed);
506
507    // Extract curves using helper
508    let curves = extract_curves(data, n, m);
509
510    let mut membership = init_random_membership(n, k, &mut rng);
511
512    let mut centers = vec![vec![0.0; m]; k];
513    let mut converged = false;
514    let mut iter = 0;
515    let exponent = 2.0 / (fuzziness - 1.0);
516
517    for iteration in 0..max_iter {
518        iter = iteration + 1;
519
520        centers = update_fuzzy_centers(&curves, &membership, n, k, m, fuzziness);
521
522        let (new_membership, max_change) =
523            update_fuzzy_membership_step(&curves, &centers, &membership, n, k, exponent, &weights);
524
525        membership = new_membership;
526
527        if max_change < tol {
528            converged = true;
529            break;
530        }
531    }
532
533    let centers_flat = flatten_centers_colmajor(&centers, k, m);
534
535    FuzzyCmeansResult {
536        membership,
537        centers: centers_flat,
538        iter,
539        converged,
540    }
541}
542
543/// Compute silhouette score for clustering result.
544pub fn silhouette_score(
545    data: &[f64],
546    n: usize,
547    m: usize,
548    argvals: &[f64],
549    cluster: &[usize],
550) -> Vec<f64> {
551    if n == 0 || m == 0 || cluster.len() != n || argvals.len() != m {
552        return Vec::new();
553    }
554
555    let weights = simpsons_weights(argvals);
556    let curves = extract_curves(data, n, m);
557
558    let k = cluster.iter().cloned().max().unwrap_or(0) + 1;
559    let members = cluster_member_indices(cluster, k);
560
561    iter_maybe_parallel!(0..n)
562        .map(|i| {
563            let my_cluster = cluster[i];
564
565            let same_indices: Vec<usize> = members[my_cluster]
566                .iter()
567                .copied()
568                .filter(|&j| j != i)
569                .collect();
570            let a_i = mean_cluster_distance(&curves[i], &curves, &same_indices, &weights);
571
572            let mut b_i = f64::INFINITY;
573            for c in 0..k {
574                if c != my_cluster && !members[c].is_empty() {
575                    b_i = b_i.min(mean_cluster_distance(
576                        &curves[i],
577                        &curves,
578                        &members[c],
579                        &weights,
580                    ));
581                }
582            }
583
584            if b_i.is_infinite() {
585                0.0
586            } else {
587                let max_ab = a_i.max(b_i);
588                if max_ab > NUMERICAL_EPS {
589                    (b_i - a_i) / max_ab
590                } else {
591                    0.0
592                }
593            }
594        })
595        .collect()
596}
597
598/// Compute Calinski-Harabasz index for clustering result.
599pub fn calinski_harabasz(
600    data: &[f64],
601    n: usize,
602    m: usize,
603    argvals: &[f64],
604    cluster: &[usize],
605) -> f64 {
606    if n == 0 || m == 0 || cluster.len() != n || argvals.len() != m {
607        return 0.0;
608    }
609
610    let weights = simpsons_weights(argvals);
611    let curves = extract_curves(data, n, m);
612
613    let k = cluster.iter().cloned().max().unwrap_or(0) + 1;
614    if k < 2 {
615        return 0.0;
616    }
617
618    let (centers, global_mean, counts) = compute_centers_and_global_mean(&curves, cluster, k, m);
619
620    let mut bgss = 0.0;
621    for c in 0..k {
622        let dist = l2_distance(&centers[c], &global_mean, &weights);
623        bgss += counts[c] as f64 * dist * dist;
624    }
625
626    let wgss_vec = compute_within_ss(&curves, &centers, cluster, k, &weights);
627    let wgss: f64 = wgss_vec.iter().sum();
628
629    if wgss < NUMERICAL_EPS {
630        return f64::INFINITY;
631    }
632
633    (bgss / (k - 1) as f64) / (wgss / (n - k) as f64)
634}
635
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    // ---- shared fixtures ----

    /// Generate a uniform grid of points
    fn uniform_grid(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
    }

    /// Generate two clearly separated clusters of curves
    fn generate_two_clusters(n_per_cluster: usize, m: usize) -> (Vec<f64>, Vec<f64>) {
        let t = uniform_grid(m);
        let mut data = Vec::with_capacity(2 * n_per_cluster * m);

        // Cluster 0: sine waves with low amplitude
        for i in 0..n_per_cluster {
            for &ti in &t {
                data.push((2.0 * PI * ti).sin() + 0.1 * (i as f64 / n_per_cluster as f64));
            }
        }

        // Cluster 1: sine waves shifted up by 5
        for i in 0..n_per_cluster {
            for &ti in &t {
                data.push((2.0 * PI * ti).sin() + 5.0 + 0.1 * (i as f64 / n_per_cluster as f64));
            }
        }

        // Column-major reordering
        let n = 2 * n_per_cluster;
        let mut col_major = vec![0.0; n * m];
        for i in 0..n {
            for j in 0..m {
                col_major[i + j * n] = data[i * m + j];
            }
        }

        (col_major, t)
    }

    // ============== K-means tests ==============

    #[test]
    fn test_kmeans_fd_basic() {
        let m = 50;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = kmeans_fd(&data, n, m, &t, 2, 100, 1e-6, 42);

        assert_eq!(result.cluster.len(), n);
        assert!(result.converged);
        assert!(result.iter > 0 && result.iter <= 100);
    }

    #[test]
    fn test_kmeans_fd_finds_clusters() {
        let m = 50;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = kmeans_fd(&data, n, m, &t, 2, 100, 1e-6, 42);

        // First half should be one cluster, second half the other
        let cluster_0 = result.cluster[0];
        let cluster_1 = result.cluster[n_per];

        assert_ne!(cluster_0, cluster_1, "Clusters should be different");

        // Check that first half is in same cluster
        for i in 0..n_per {
            assert_eq!(result.cluster[i], cluster_0);
        }

        // Check that second half is in same cluster
        for i in n_per..n {
            assert_eq!(result.cluster[i], cluster_1);
        }
    }

    // Reproducibility: the same seed must yield identical assignments.
    #[test]
    fn test_kmeans_fd_deterministic() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result1 = kmeans_fd(&data, n, m, &t, 2, 100, 1e-6, 42);
        let result2 = kmeans_fd(&data, n, m, &t, 2, 100, 1e-6, 42);

        // Same seed should give same results
        assert_eq!(result1.cluster, result2.cluster);
    }

    #[test]
    fn test_kmeans_fd_withinss() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = kmeans_fd(&data, n, m, &t, 2, 100, 1e-6, 42);

        // Within-cluster sum of squares should be non-negative
        for &wss in &result.withinss {
            assert!(wss >= 0.0);
        }

        // Total should equal sum
        let sum: f64 = result.withinss.iter().sum();
        assert!((sum - result.tot_withinss).abs() < 1e-10);
    }

    #[test]
    fn test_kmeans_fd_centers_shape() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;
        let k = 3;

        let result = kmeans_fd(&data, n, m, &t, k, 100, 1e-6, 42);

        // Centers should be k x m matrix (column-major)
        assert_eq!(result.centers.len(), k * m);
    }

    // Invalid inputs must return the empty sentinel result, not panic.
    #[test]
    fn test_kmeans_fd_invalid_input() {
        let t = uniform_grid(30);

        // Empty data
        let result = kmeans_fd(&[], 0, 30, &t, 2, 100, 1e-6, 42);
        assert!(result.cluster.is_empty());
        assert!(!result.converged);

        // k > n
        let data = vec![0.0; 5 * 30];
        let result = kmeans_fd(&data, 5, 30, &t, 10, 100, 1e-6, 42);
        assert!(result.cluster.is_empty());
    }

    #[test]
    fn test_kmeans_fd_single_cluster() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let data = vec![0.0; n * m];

        let result = kmeans_fd(&data, n, m, &t, 1, 100, 1e-6, 42);

        // All should be in cluster 0
        for &c in &result.cluster {
            assert_eq!(c, 0);
        }
    }

    // ============== Fuzzy C-means tests ==============

    #[test]
    fn test_fuzzy_cmeans_fd_basic() {
        let m = 50;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = fuzzy_cmeans_fd(&data, n, m, &t, 2, 2.0, 100, 1e-6, 42);

        assert_eq!(result.membership.len(), n * 2);
        assert!(result.iter > 0);
    }

    #[test]
    fn test_fuzzy_cmeans_fd_membership_sums_to_one() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;
        let k = 2;

        let result = fuzzy_cmeans_fd(&data, n, m, &t, k, 2.0, 100, 1e-6, 42);

        // Each observation's membership should sum to 1
        for i in 0..n {
            let sum: f64 = (0..k).map(|c| result.membership[i + c * n]).sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "Membership should sum to 1, got {}",
                sum
            );
        }
    }

    #[test]
    fn test_fuzzy_cmeans_fd_membership_in_range() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = fuzzy_cmeans_fd(&data, n, m, &t, 2, 2.0, 100, 1e-6, 42);

        // All memberships should be in [0, 1]
        for &mem in &result.membership {
            assert!((0.0..=1.0 + 1e-10).contains(&mem));
        }
    }

    #[test]
    fn test_fuzzy_cmeans_fd_fuzziness_effect() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result_low = fuzzy_cmeans_fd(&data, n, m, &t, 2, 1.5, 100, 1e-6, 42);
        let result_high = fuzzy_cmeans_fd(&data, n, m, &t, 2, 3.0, 100, 1e-6, 42);

        // Higher fuzziness should give more diffuse memberships
        // Measure by entropy-like metric
        let entropy_low: f64 = result_low
            .membership
            .iter()
            .map(|&m| if m > 1e-10 { -m * m.ln() } else { 0.0 })
            .sum();

        let entropy_high: f64 = result_high
            .membership
            .iter()
            .map(|&m| if m > 1e-10 { -m * m.ln() } else { 0.0 })
            .sum();

        assert!(
            entropy_high >= entropy_low - 0.1,
            "Higher fuzziness should give higher entropy"
        );
    }

    #[test]
    fn test_fuzzy_cmeans_fd_invalid_fuzziness() {
        let t = uniform_grid(30);
        let data = vec![0.0; 10 * 30];

        // Fuzziness <= 1 should fail
        let result = fuzzy_cmeans_fd(&data, 10, 30, &t, 2, 1.0, 100, 1e-6, 42);
        assert!(result.membership.is_empty());

        let result = fuzzy_cmeans_fd(&data, 10, 30, &t, 2, 0.5, 100, 1e-6, 42);
        assert!(result.membership.is_empty());
    }

    #[test]
    fn test_fuzzy_cmeans_fd_centers_shape() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let k = 3;
        let data = vec![0.0; n * m];

        let result = fuzzy_cmeans_fd(&data, n, m, &t, k, 2.0, 100, 1e-6, 42);

        assert_eq!(result.centers.len(), k * m);
    }

    // ============== Silhouette score tests ==============

    #[test]
    fn test_silhouette_score_well_separated() {
        let m = 30;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        // Perfect clustering: first half in 0, second in 1
        let cluster: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let scores = silhouette_score(&data, n, m, &t, &cluster);

        assert_eq!(scores.len(), n);

        // Well-separated clusters should have high silhouette scores
        let mean_score: f64 = scores.iter().sum::<f64>() / n as f64;
        assert!(
            mean_score > 0.5,
            "Well-separated clusters should have high silhouette: {}",
            mean_score
        );
    }

    #[test]
    fn test_silhouette_score_range() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let cluster: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let scores = silhouette_score(&data, n, m, &t, &cluster);

        // Silhouette scores should be in [-1, 1]
        for &s in &scores {
            assert!((-1.0 - 1e-10..=1.0 + 1e-10).contains(&s));
        }
    }

    #[test]
    fn test_silhouette_score_single_cluster() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let data = vec![0.0; n * m];

        // All in one cluster
        let cluster = vec![0usize; n];

        let scores = silhouette_score(&data, n, m, &t, &cluster);

        // Single cluster should give zeros
        for &s in &scores {
            assert!(s.abs() < 1e-10);
        }
    }

    #[test]
    fn test_silhouette_score_invalid_input() {
        let t = uniform_grid(30);

        // Empty data
        let scores = silhouette_score(&[], 0, 30, &t, &[]);
        assert!(scores.is_empty());

        // Mismatched cluster length
        let data = vec![0.0; 10 * 30];
        let cluster = vec![0; 5]; // Wrong length
        let scores = silhouette_score(&data, 10, 30, &t, &cluster);
        assert!(scores.is_empty());
    }

    // ============== Calinski-Harabasz tests ==============

    #[test]
    fn test_calinski_harabasz_well_separated() {
        let m = 30;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let cluster: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let ch = calinski_harabasz(&data, n, m, &t, &cluster);

        // Well-separated clusters should have high CH index
        assert!(
            ch > 1.0,
            "Well-separated clusters should have high CH: {}",
            ch
        );
    }

    #[test]
    fn test_calinski_harabasz_positive() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let cluster: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let ch = calinski_harabasz(&data, n, m, &t, &cluster);

        assert!(ch >= 0.0, "CH index should be non-negative");
    }

    #[test]
    fn test_calinski_harabasz_single_cluster() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let data = vec![0.0; n * m];

        // All in one cluster
        let cluster = vec![0usize; n];

        let ch = calinski_harabasz(&data, n, m, &t, &cluster);

        // Single cluster should give 0
        assert!(ch.abs() < 1e-10);
    }

    #[test]
    fn test_calinski_harabasz_invalid_input() {
        let t = uniform_grid(30);

        // Empty data
        let ch = calinski_harabasz(&[], 0, 30, &t, &[]);
        assert!(ch.abs() < 1e-10);
    }
}