// fdars_core/clustering.rs

1//! Clustering algorithms for functional data.
2//!
3//! This module provides k-means and fuzzy c-means clustering algorithms
4//! for functional data.
5
6use crate::helpers::{extract_curves, l2_distance, simpsons_weights, NUMERICAL_EPS};
7use crate::{iter_maybe_parallel, slice_maybe_parallel};
8use rand::prelude::*;
9#[cfg(feature = "parallel")]
10use rayon::iter::ParallelIterator;
11
/// Result of k-means clustering.
///
/// Derives `Debug`/`Clone`/`PartialEq` so results can be logged, duplicated,
/// and compared in tests like the other value types in this crate.
#[derive(Debug, Clone, PartialEq)]
pub struct KmeansResult {
    /// Cluster assignments for each observation
    pub cluster: Vec<usize>,
    /// Cluster centers (k x m matrix, column-major)
    pub centers: Vec<f64>,
    /// Within-cluster sum of squares for each cluster
    pub withinss: Vec<f64>,
    /// Total within-cluster sum of squares
    pub tot_withinss: f64,
    /// Number of iterations
    pub iter: usize,
    /// Whether the algorithm converged
    pub converged: bool,
}
27
28/// K-means++ initialization: select initial centers with probability proportional to D^2.
29///
30/// # Arguments
31/// * `curves` - Vector of curve vectors
32/// * `k` - Number of clusters
33/// * `weights` - Integration weights for L2 distance
34/// * `rng` - Random number generator
35///
36/// # Returns
37/// Vector of k initial cluster centers
38fn kmeans_plusplus_init(
39    curves: &[Vec<f64>],
40    k: usize,
41    weights: &[f64],
42    rng: &mut StdRng,
43) -> Vec<Vec<f64>> {
44    let n = curves.len();
45    let mut centers: Vec<Vec<f64>> = Vec::with_capacity(k);
46
47    // First center: random
48    let first_idx = rng.gen_range(0..n);
49    centers.push(curves[first_idx].clone());
50
51    // Remaining centers: probability proportional to D^2
52    for _ in 1..k {
53        let distances: Vec<f64> = curves
54            .iter()
55            .map(|curve| {
56                centers
57                    .iter()
58                    .map(|c| l2_distance(curve, c, weights))
59                    .fold(f64::INFINITY, f64::min)
60            })
61            .collect();
62
63        let dist_sq: Vec<f64> = distances.iter().map(|d| d * d).collect();
64        let total: f64 = dist_sq.iter().sum();
65
66        if total < NUMERICAL_EPS {
67            let idx = rng.gen_range(0..n);
68            centers.push(curves[idx].clone());
69        } else {
70            let r = rng.gen::<f64>() * total;
71            let mut cumsum = 0.0;
72            let mut chosen = 0;
73            for (i, &d) in dist_sq.iter().enumerate() {
74                cumsum += d;
75                if cumsum >= r {
76                    chosen = i;
77                    break;
78                }
79            }
80            centers.push(curves[chosen].clone());
81        }
82    }
83
84    centers
85}
86
87/// Compute fuzzy membership values for a single observation.
88///
89/// # Arguments
90/// * `distances` - Distances from the observation to each cluster center
91/// * `exponent` - Exponent for fuzzy membership (2 / (fuzziness - 1))
92///
93/// # Returns
94/// Vector of membership values (one per cluster)
95fn compute_fuzzy_membership(distances: &[f64], exponent: f64) -> Vec<f64> {
96    let k = distances.len();
97    let mut membership = vec![0.0; k];
98
99    // Check if observation is very close to any center
100    for (c, &dist) in distances.iter().enumerate() {
101        if dist < NUMERICAL_EPS {
102            // Assign full membership to this cluster
103            membership[c] = 1.0;
104            return membership;
105        }
106    }
107
108    // Normal fuzzy membership computation
109    for c in 0..k {
110        let mut sum = 0.0;
111        for c2 in 0..k {
112            if distances[c2] > NUMERICAL_EPS {
113                sum += (distances[c] / distances[c2]).powf(exponent);
114            }
115        }
116        membership[c] = if sum > NUMERICAL_EPS { 1.0 / sum } else { 1.0 };
117    }
118
119    membership
120}
121
122/// K-means clustering for functional data.
123///
124/// # Arguments
125/// * `data` - Column-major matrix (n x m)
126/// * `n` - Number of observations
127/// * `m` - Number of evaluation points
128/// * `argvals` - Evaluation points
129/// * `k` - Number of clusters
130/// * `max_iter` - Maximum iterations
131/// * `tol` - Convergence tolerance
132/// * `seed` - Random seed
133pub fn kmeans_fd(
134    data: &[f64],
135    n: usize,
136    m: usize,
137    argvals: &[f64],
138    k: usize,
139    max_iter: usize,
140    tol: f64,
141    seed: u64,
142) -> KmeansResult {
143    if n == 0 || m == 0 || k == 0 || k > n || argvals.len() != m {
144        return KmeansResult {
145            cluster: Vec::new(),
146            centers: Vec::new(),
147            withinss: Vec::new(),
148            tot_withinss: 0.0,
149            iter: 0,
150            converged: false,
151        };
152    }
153
154    let weights = simpsons_weights(argvals);
155    let mut rng = StdRng::seed_from_u64(seed);
156
157    // Extract curves using helper
158    let curves = extract_curves(data, n, m);
159
160    // K-means++ initialization using helper
161    let mut centers = kmeans_plusplus_init(&curves, k, &weights, &mut rng);
162
163    let mut cluster = vec![0usize; n];
164    let mut converged = false;
165    let mut iter = 0;
166
167    for iteration in 0..max_iter {
168        iter = iteration + 1;
169
170        // Assignment step
171        let new_cluster: Vec<usize> = slice_maybe_parallel!(curves)
172            .map(|curve| {
173                let mut best_cluster = 0;
174                let mut best_dist = f64::INFINITY;
175                for (c, center) in centers.iter().enumerate() {
176                    let dist = l2_distance(curve, center, &weights);
177                    if dist < best_dist {
178                        best_dist = dist;
179                        best_cluster = c;
180                    }
181                }
182                best_cluster
183            })
184            .collect();
185
186        // Check convergence
187        if new_cluster == cluster {
188            converged = true;
189            break;
190        }
191        cluster = new_cluster;
192
193        // Update step
194        let new_centers: Vec<Vec<f64>> = (0..k)
195            .map(|c| {
196                let members: Vec<usize> = cluster
197                    .iter()
198                    .enumerate()
199                    .filter(|(_, &cl)| cl == c)
200                    .map(|(i, _)| i)
201                    .collect();
202
203                if members.is_empty() {
204                    centers[c].clone()
205                } else {
206                    let mut center = vec![0.0; m];
207                    for &i in &members {
208                        for j in 0..m {
209                            center[j] += curves[i][j];
210                        }
211                    }
212                    let n_members = members.len() as f64;
213                    for j in 0..m {
214                        center[j] /= n_members;
215                    }
216                    center
217                }
218            })
219            .collect();
220
221        // Check convergence by center movement
222        let max_movement: f64 = centers
223            .iter()
224            .zip(new_centers.iter())
225            .map(|(old, new)| l2_distance(old, new, &weights))
226            .fold(0.0, f64::max);
227
228        centers = new_centers;
229
230        if max_movement < tol {
231            converged = true;
232            break;
233        }
234    }
235
236    // Compute within-cluster sum of squares
237    let mut withinss = vec![0.0; k];
238    for (i, curve) in curves.iter().enumerate() {
239        let c = cluster[i];
240        let dist = l2_distance(curve, &centers[c], &weights);
241        withinss[c] += dist * dist;
242    }
243    let tot_withinss: f64 = withinss.iter().sum();
244
245    // Flatten centers (column-major: k x m)
246    let mut centers_flat = vec![0.0; k * m];
247    for c in 0..k {
248        for j in 0..m {
249            centers_flat[c + j * k] = centers[c][j];
250        }
251    }
252
253    KmeansResult {
254        cluster,
255        centers: centers_flat,
256        withinss,
257        tot_withinss,
258        iter,
259        converged,
260    }
261}
262
/// Result of fuzzy c-means clustering.
///
/// Derives `Debug`/`Clone`/`PartialEq` so results can be logged, duplicated,
/// and compared in tests like the other value types in this crate.
#[derive(Debug, Clone, PartialEq)]
pub struct FuzzyCmeansResult {
    /// Membership matrix (n x k, column-major)
    pub membership: Vec<f64>,
    /// Cluster centers (k x m, column-major)
    pub centers: Vec<f64>,
    /// Number of iterations
    pub iter: usize,
    /// Whether the algorithm converged
    pub converged: bool,
}
274
275/// Fuzzy c-means clustering for functional data.
276///
277/// # Arguments
278/// * `data` - Column-major matrix (n x m)
279/// * `n` - Number of observations
280/// * `m` - Number of evaluation points
281/// * `argvals` - Evaluation points
282/// * `k` - Number of clusters
283/// * `fuzziness` - Fuzziness parameter (> 1)
284/// * `max_iter` - Maximum iterations
285/// * `tol` - Convergence tolerance
286/// * `seed` - Random seed
287pub fn fuzzy_cmeans_fd(
288    data: &[f64],
289    n: usize,
290    m: usize,
291    argvals: &[f64],
292    k: usize,
293    fuzziness: f64,
294    max_iter: usize,
295    tol: f64,
296    seed: u64,
297) -> FuzzyCmeansResult {
298    if n == 0 || m == 0 || k == 0 || k > n || argvals.len() != m || fuzziness <= 1.0 {
299        return FuzzyCmeansResult {
300            membership: Vec::new(),
301            centers: Vec::new(),
302            iter: 0,
303            converged: false,
304        };
305    }
306
307    let weights = simpsons_weights(argvals);
308    let mut rng = StdRng::seed_from_u64(seed);
309
310    // Extract curves using helper
311    let curves = extract_curves(data, n, m);
312
313    // Initialize membership matrix randomly
314    let mut membership = vec![0.0; n * k];
315    for i in 0..n {
316        let mut row_sum = 0.0;
317        for c in 0..k {
318            let val = rng.gen::<f64>();
319            membership[i + c * n] = val;
320            row_sum += val;
321        }
322        for c in 0..k {
323            membership[i + c * n] /= row_sum;
324        }
325    }
326
327    let mut centers = vec![vec![0.0; m]; k];
328    let mut converged = false;
329    let mut iter = 0;
330    let exponent = 2.0 / (fuzziness - 1.0);
331
332    for iteration in 0..max_iter {
333        iter = iteration + 1;
334
335        // Update centers
336        for c in 0..k {
337            let mut numerator = vec![0.0; m];
338            let mut denominator = 0.0;
339
340            for (i, curve) in curves.iter().enumerate() {
341                let weight = membership[i + c * n].powf(fuzziness);
342                for j in 0..m {
343                    numerator[j] += weight * curve[j];
344                }
345                denominator += weight;
346            }
347
348            if denominator > NUMERICAL_EPS {
349                for j in 0..m {
350                    centers[c][j] = numerator[j] / denominator;
351                }
352            }
353        }
354
355        // Update membership using helper
356        let mut new_membership = vec![0.0; n * k];
357        let mut max_change = 0.0;
358
359        for (i, curve) in curves.iter().enumerate() {
360            let distances: Vec<f64> = centers
361                .iter()
362                .map(|c| l2_distance(curve, c, &weights))
363                .collect();
364
365            // Use the helper function for membership computation
366            let memberships = compute_fuzzy_membership(&distances, exponent);
367
368            for c in 0..k {
369                new_membership[i + c * n] = memberships[c];
370                let change = (memberships[c] - membership[i + c * n]).abs();
371                if change > max_change {
372                    max_change = change;
373                }
374            }
375        }
376
377        membership = new_membership;
378
379        if max_change < tol {
380            converged = true;
381            break;
382        }
383    }
384
385    // Flatten centers (column-major: k x m)
386    let mut centers_flat = vec![0.0; k * m];
387    for c in 0..k {
388        for j in 0..m {
389            centers_flat[c + j * k] = centers[c][j];
390        }
391    }
392
393    FuzzyCmeansResult {
394        membership,
395        centers: centers_flat,
396        iter,
397        converged,
398    }
399}
400
401/// Compute silhouette score for clustering result.
402pub fn silhouette_score(
403    data: &[f64],
404    n: usize,
405    m: usize,
406    argvals: &[f64],
407    cluster: &[usize],
408) -> Vec<f64> {
409    if n == 0 || m == 0 || cluster.len() != n || argvals.len() != m {
410        return Vec::new();
411    }
412
413    let weights = simpsons_weights(argvals);
414    let curves = extract_curves(data, n, m);
415
416    let k = cluster.iter().cloned().max().unwrap_or(0) + 1;
417
418    iter_maybe_parallel!(0..n)
419        .map(|i| {
420            let my_cluster = cluster[i];
421
422            // a(i) = average distance to points in same cluster
423            let same_cluster: Vec<usize> = cluster
424                .iter()
425                .enumerate()
426                .filter(|(j, &c)| c == my_cluster && *j != i)
427                .map(|(j, _)| j)
428                .collect();
429
430            let a_i = if same_cluster.is_empty() {
431                0.0
432            } else {
433                let sum: f64 = same_cluster
434                    .iter()
435                    .map(|&j| l2_distance(&curves[i], &curves[j], &weights))
436                    .sum();
437                sum / same_cluster.len() as f64
438            };
439
440            // b(i) = min average distance to points in other clusters
441            let mut b_i = f64::INFINITY;
442            for c in 0..k {
443                if c == my_cluster {
444                    continue;
445                }
446
447                let other_cluster: Vec<usize> = cluster
448                    .iter()
449                    .enumerate()
450                    .filter(|(_, &cl)| cl == c)
451                    .map(|(j, _)| j)
452                    .collect();
453
454                if other_cluster.is_empty() {
455                    continue;
456                }
457
458                let avg_dist: f64 = other_cluster
459                    .iter()
460                    .map(|&j| l2_distance(&curves[i], &curves[j], &weights))
461                    .sum::<f64>()
462                    / other_cluster.len() as f64;
463
464                b_i = b_i.min(avg_dist);
465            }
466
467            if b_i.is_infinite() {
468                0.0
469            } else {
470                let max_ab = a_i.max(b_i);
471                if max_ab > NUMERICAL_EPS {
472                    (b_i - a_i) / max_ab
473                } else {
474                    0.0
475                }
476            }
477        })
478        .collect()
479}
480
481/// Compute Calinski-Harabasz index for clustering result.
482pub fn calinski_harabasz(
483    data: &[f64],
484    n: usize,
485    m: usize,
486    argvals: &[f64],
487    cluster: &[usize],
488) -> f64 {
489    if n == 0 || m == 0 || cluster.len() != n || argvals.len() != m {
490        return 0.0;
491    }
492
493    let weights = simpsons_weights(argvals);
494    let curves = extract_curves(data, n, m);
495
496    let k = cluster.iter().cloned().max().unwrap_or(0) + 1;
497    if k < 2 {
498        return 0.0;
499    }
500
501    // Global mean
502    let mut global_mean = vec![0.0; m];
503    for curve in &curves {
504        for j in 0..m {
505            global_mean[j] += curve[j];
506        }
507    }
508    for j in 0..m {
509        global_mean[j] /= n as f64;
510    }
511
512    // Cluster centers
513    let mut centers = vec![vec![0.0; m]; k];
514    let mut counts = vec![0usize; k];
515    for (i, curve) in curves.iter().enumerate() {
516        let c = cluster[i];
517        counts[c] += 1;
518        for j in 0..m {
519            centers[c][j] += curve[j];
520        }
521    }
522    for c in 0..k {
523        if counts[c] > 0 {
524            for j in 0..m {
525                centers[c][j] /= counts[c] as f64;
526            }
527        }
528    }
529
530    // Between-cluster sum of squares
531    let mut bgss = 0.0;
532    for c in 0..k {
533        let dist = l2_distance(&centers[c], &global_mean, &weights);
534        bgss += counts[c] as f64 * dist * dist;
535    }
536
537    // Within-cluster sum of squares
538    let mut wgss = 0.0;
539    for (i, curve) in curves.iter().enumerate() {
540        let c = cluster[i];
541        let dist = l2_distance(curve, &centers[c], &weights);
542        wgss += dist * dist;
543    }
544
545    if wgss < NUMERICAL_EPS {
546        return f64::INFINITY;
547    }
548
549    (bgss / (k - 1) as f64) / (wgss / (n - k) as f64)
550}
551
552#[cfg(test)]
553mod tests {
554    use super::*;
555    use std::f64::consts::PI;
556
    /// Generate a uniform grid of points
    ///
    /// NOTE(review): divides by `n - 1`, so `n` must be >= 2; all callers in
    /// this module pass n >= 30.
    fn uniform_grid(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
    }

    /// Generate two clearly separated clusters of curves
    ///
    /// Returns `(data, argvals)` where `data` is a column-major
    /// (2 * n_per_cluster) x m matrix: the first `n_per_cluster` rows are
    /// low-offset sine curves, the second are the same curves shifted up by 5.
    fn generate_two_clusters(n_per_cluster: usize, m: usize) -> (Vec<f64>, Vec<f64>) {
        let t = uniform_grid(m);
        let mut data = Vec::with_capacity(2 * n_per_cluster * m);

        // Cluster 0: sine waves with low amplitude
        for i in 0..n_per_cluster {
            for &ti in &t {
                data.push((2.0 * PI * ti).sin() + 0.1 * (i as f64 / n_per_cluster as f64));
            }
        }

        // Cluster 1: sine waves shifted up by 5
        for i in 0..n_per_cluster {
            for &ti in &t {
                data.push((2.0 * PI * ti).sin() + 5.0 + 0.1 * (i as f64 / n_per_cluster as f64));
            }
        }

        // Column-major reordering (curves were generated row-major above)
        let n = 2 * n_per_cluster;
        let mut col_major = vec![0.0; n * m];
        for i in 0..n {
            for j in 0..m {
                col_major[i + j * n] = data[i * m + j];
            }
        }

        (col_major, t)
    }
592
    // ============== K-means tests ==============

    #[test]
    fn test_kmeans_fd_basic() {
        let m = 50;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = kmeans_fd(&data, n, m, &t, 2, 100, 1e-6, 42);

        // One label per observation, converged within the iteration budget.
        assert_eq!(result.cluster.len(), n);
        assert!(result.converged);
        assert!(result.iter > 0 && result.iter <= 100);
    }

    #[test]
    fn test_kmeans_fd_finds_clusters() {
        let m = 50;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = kmeans_fd(&data, n, m, &t, 2, 100, 1e-6, 42);

        // First half should be one cluster, second half the other
        let cluster_0 = result.cluster[0];
        let cluster_1 = result.cluster[n_per];

        assert_ne!(cluster_0, cluster_1, "Clusters should be different");

        // Check that first half is in same cluster
        for i in 0..n_per {
            assert_eq!(result.cluster[i], cluster_0);
        }

        // Check that second half is in same cluster
        for i in n_per..n {
            assert_eq!(result.cluster[i], cluster_1);
        }
    }

    #[test]
    fn test_kmeans_fd_deterministic() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result1 = kmeans_fd(&data, n, m, &t, 2, 100, 1e-6, 42);
        let result2 = kmeans_fd(&data, n, m, &t, 2, 100, 1e-6, 42);

        // Same seed should give same results
        assert_eq!(result1.cluster, result2.cluster);
    }

    #[test]
    fn test_kmeans_fd_withinss() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = kmeans_fd(&data, n, m, &t, 2, 100, 1e-6, 42);

        // Within-cluster sum of squares should be non-negative
        for &wss in &result.withinss {
            assert!(wss >= 0.0);
        }

        // Total should equal sum
        let sum: f64 = result.withinss.iter().sum();
        assert!((sum - result.tot_withinss).abs() < 1e-10);
    }

    #[test]
    fn test_kmeans_fd_centers_shape() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;
        let k = 3;

        let result = kmeans_fd(&data, n, m, &t, k, 100, 1e-6, 42);

        // Centers should be k x m matrix (column-major)
        assert_eq!(result.centers.len(), k * m);
    }

    #[test]
    fn test_kmeans_fd_invalid_input() {
        let t = uniform_grid(30);

        // Empty data
        let result = kmeans_fd(&[], 0, 30, &t, 2, 100, 1e-6, 42);
        assert!(result.cluster.is_empty());
        assert!(!result.converged);

        // k > n
        let data = vec![0.0; 5 * 30];
        let result = kmeans_fd(&data, 5, 30, &t, 10, 100, 1e-6, 42);
        assert!(result.cluster.is_empty());
    }

    #[test]
    fn test_kmeans_fd_single_cluster() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        // All-zero curves: k = 1 must trivially assign everything to cluster 0.
        let data = vec![0.0; n * m];

        let result = kmeans_fd(&data, n, m, &t, 1, 100, 1e-6, 42);

        // All should be in cluster 0
        for &c in &result.cluster {
            assert_eq!(c, 0);
        }
    }
711
    // ============== Fuzzy C-means tests ==============

    #[test]
    fn test_fuzzy_cmeans_fd_basic() {
        let m = 50;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = fuzzy_cmeans_fd(&data, n, m, &t, 2, 2.0, 100, 1e-6, 42);

        // Membership matrix is n x k (column-major).
        assert_eq!(result.membership.len(), n * 2);
        assert!(result.iter > 0);
    }

    #[test]
    fn test_fuzzy_cmeans_fd_membership_sums_to_one() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;
        let k = 2;

        let result = fuzzy_cmeans_fd(&data, n, m, &t, k, 2.0, 100, 1e-6, 42);

        // Each observation's membership should sum to 1
        for i in 0..n {
            let sum: f64 = (0..k).map(|c| result.membership[i + c * n]).sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "Membership should sum to 1, got {}",
                sum
            );
        }
    }

    #[test]
    fn test_fuzzy_cmeans_fd_membership_in_range() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = fuzzy_cmeans_fd(&data, n, m, &t, 2, 2.0, 100, 1e-6, 42);

        // All memberships should be in [0, 1]
        for &mem in &result.membership {
            assert!((0.0..=1.0 + 1e-10).contains(&mem));
        }
    }

    #[test]
    fn test_fuzzy_cmeans_fd_fuzziness_effect() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result_low = fuzzy_cmeans_fd(&data, n, m, &t, 2, 1.5, 100, 1e-6, 42);
        let result_high = fuzzy_cmeans_fd(&data, n, m, &t, 2, 3.0, 100, 1e-6, 42);

        // Higher fuzziness should give more diffuse memberships
        // Measure by entropy-like metric
        let entropy_low: f64 = result_low
            .membership
            .iter()
            .map(|&m| if m > 1e-10 { -m * m.ln() } else { 0.0 })
            .sum();

        let entropy_high: f64 = result_high
            .membership
            .iter()
            .map(|&m| if m > 1e-10 { -m * m.ln() } else { 0.0 })
            .sum();

        // Small slack (0.1) keeps the comparison robust to numerical noise.
        assert!(
            entropy_high >= entropy_low - 0.1,
            "Higher fuzziness should give higher entropy"
        );
    }

    #[test]
    fn test_fuzzy_cmeans_fd_invalid_fuzziness() {
        let t = uniform_grid(30);
        let data = vec![0.0; 10 * 30];

        // Fuzziness <= 1 should fail
        let result = fuzzy_cmeans_fd(&data, 10, 30, &t, 2, 1.0, 100, 1e-6, 42);
        assert!(result.membership.is_empty());

        let result = fuzzy_cmeans_fd(&data, 10, 30, &t, 2, 0.5, 100, 1e-6, 42);
        assert!(result.membership.is_empty());
    }

    #[test]
    fn test_fuzzy_cmeans_fd_centers_shape() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let k = 3;
        let data = vec![0.0; n * m];

        // Centers should be a k x m matrix (column-major).
        let result = fuzzy_cmeans_fd(&data, n, m, &t, k, 2.0, 100, 1e-6, 42);

        assert_eq!(result.centers.len(), k * m);
    }
818
    // ============== Silhouette score tests ==============

    #[test]
    fn test_silhouette_score_well_separated() {
        let m = 30;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        // Perfect clustering: first half in 0, second in 1
        let cluster: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let scores = silhouette_score(&data, n, m, &t, &cluster);

        assert_eq!(scores.len(), n);

        // Well-separated clusters should have high silhouette scores
        let mean_score: f64 = scores.iter().sum::<f64>() / n as f64;
        assert!(
            mean_score > 0.5,
            "Well-separated clusters should have high silhouette: {}",
            mean_score
        );
    }

    #[test]
    fn test_silhouette_score_range() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let cluster: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let scores = silhouette_score(&data, n, m, &t, &cluster);

        // Silhouette scores should be in [-1, 1]
        for &s in &scores {
            assert!((-1.0 - 1e-10..=1.0 + 1e-10).contains(&s));
        }
    }

    #[test]
    fn test_silhouette_score_single_cluster() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let data = vec![0.0; n * m];

        // All in one cluster
        let cluster = vec![0usize; n];

        let scores = silhouette_score(&data, n, m, &t, &cluster);

        // Single cluster should give zeros (b(i) is undefined, no other cluster)
        for &s in &scores {
            assert!(s.abs() < 1e-10);
        }
    }

    #[test]
    fn test_silhouette_score_invalid_input() {
        let t = uniform_grid(30);

        // Empty data
        let scores = silhouette_score(&[], 0, 30, &t, &[]);
        assert!(scores.is_empty());

        // Mismatched cluster length
        let data = vec![0.0; 10 * 30];
        let cluster = vec![0; 5]; // Wrong length
        let scores = silhouette_score(&data, 10, 30, &t, &cluster);
        assert!(scores.is_empty());
    }
893
    // ============== Calinski-Harabasz tests ==============

    #[test]
    fn test_calinski_harabasz_well_separated() {
        let m = 30;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let cluster: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let ch = calinski_harabasz(&data, n, m, &t, &cluster);

        // Well-separated clusters should have high CH index
        assert!(
            ch > 1.0,
            "Well-separated clusters should have high CH: {}",
            ch
        );
    }

    #[test]
    fn test_calinski_harabasz_positive() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let cluster: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let ch = calinski_harabasz(&data, n, m, &t, &cluster);

        assert!(ch >= 0.0, "CH index should be non-negative");
    }

    #[test]
    fn test_calinski_harabasz_single_cluster() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let data = vec![0.0; n * m];

        // All in one cluster
        let cluster = vec![0usize; n];

        let ch = calinski_harabasz(&data, n, m, &t, &cluster);

        // Single cluster should give 0 (index undefined for k < 2)
        assert!(ch.abs() < 1e-10);
    }
944
945    #[test]
946    fn test_calinski_harabasz_invalid_input() {
947        let t = uniform_grid(30);
948
949        // Empty data
950        let ch = calinski_harabasz(&[], 0, 30, &t, &[]);
951        assert!(ch.abs() < 1e-10);
952    }
953}