Skip to main content

fdars_core/
clustering.rs

1//! Clustering algorithms for functional data.
2//!
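//! This module provides k-means and fuzzy c-means clustering algorithms
//! for functional data, plus the silhouette score and the Calinski-Harabasz
//! index for assessing the quality of a clustering.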
5
6use crate::helpers::{l2_distance, simpsons_weights, NUMERICAL_EPS};
7use crate::matrix::FdMatrix;
8use crate::{iter_maybe_parallel, slice_maybe_parallel};
9use rand::prelude::*;
10#[cfg(feature = "parallel")]
11use rayon::iter::ParallelIterator;
12
13/// Result of k-means clustering.
14pub struct KmeansResult {
15    /// Cluster assignments for each observation
16    pub cluster: Vec<usize>,
17    /// Cluster centers (k x m matrix)
18    pub centers: FdMatrix,
19    /// Within-cluster sum of squares for each cluster
20    pub withinss: Vec<f64>,
21    /// Total within-cluster sum of squares
22    pub tot_withinss: f64,
23    /// Number of iterations
24    pub iter: usize,
25    /// Whether the algorithm converged
26    pub converged: bool,
27}
28
29/// K-means++ initialization: select initial centers with probability proportional to D^2.
30///
31/// # Arguments
32/// * `curves` - Vector of curve vectors
33/// * `k` - Number of clusters
34/// * `weights` - Integration weights for L2 distance
35/// * `rng` - Random number generator
36///
37/// # Returns
38/// Vector of k initial cluster centers
39fn kmeans_plusplus_init(
40    curves: &[Vec<f64>],
41    k: usize,
42    weights: &[f64],
43    rng: &mut StdRng,
44) -> Vec<Vec<f64>> {
45    let n = curves.len();
46    let mut centers: Vec<Vec<f64>> = Vec::with_capacity(k);
47
48    // First center: random
49    let first_idx = rng.gen_range(0..n);
50    centers.push(curves[first_idx].clone());
51
52    // Remaining centers: probability proportional to D^2
53    for _ in 1..k {
54        let distances: Vec<f64> = curves
55            .iter()
56            .map(|curve| {
57                centers
58                    .iter()
59                    .map(|c| l2_distance(curve, c, weights))
60                    .fold(f64::INFINITY, f64::min)
61            })
62            .collect();
63
64        let dist_sq: Vec<f64> = distances.iter().map(|d| d * d).collect();
65        let total: f64 = dist_sq.iter().sum();
66
67        if total < NUMERICAL_EPS {
68            let idx = rng.gen_range(0..n);
69            centers.push(curves[idx].clone());
70        } else {
71            let r = rng.gen::<f64>() * total;
72            let mut cumsum = 0.0;
73            let mut chosen = 0;
74            for (i, &d) in dist_sq.iter().enumerate() {
75                cumsum += d;
76                if cumsum >= r {
77                    chosen = i;
78                    break;
79                }
80            }
81            centers.push(curves[chosen].clone());
82        }
83    }
84
85    centers
86}
87
88/// Compute fuzzy membership values for a single observation.
89///
90/// # Arguments
91/// * `distances` - Distances from the observation to each cluster center
92/// * `exponent` - Exponent for fuzzy membership (2 / (fuzziness - 1))
93///
94/// # Returns
95/// Vector of membership values (one per cluster)
96fn compute_fuzzy_membership(distances: &[f64], exponent: f64) -> Vec<f64> {
97    let k = distances.len();
98    let mut membership = vec![0.0; k];
99
100    // Check if observation is very close to any center
101    for (c, &dist) in distances.iter().enumerate() {
102        if dist < NUMERICAL_EPS {
103            // Assign full membership to this cluster
104            membership[c] = 1.0;
105            return membership;
106        }
107    }
108
109    // Normal fuzzy membership computation
110    for c in 0..k {
111        let mut sum = 0.0;
112        for c2 in 0..k {
113            if distances[c2] > NUMERICAL_EPS {
114                sum += (distances[c] / distances[c2]).powf(exponent);
115            }
116        }
117        membership[c] = if sum > NUMERICAL_EPS { 1.0 / sum } else { 1.0 };
118    }
119
120    membership
121}
122
123/// Build an FdMatrix (k x m) from Vec<Vec<f64>> centers.
124fn centers_to_matrix(centers: &[Vec<f64>], k: usize, m: usize) -> FdMatrix {
125    let mut flat = vec![0.0; k * m];
126    for c in 0..k {
127        for j in 0..m {
128            flat[c + j * k] = centers[c][j];
129        }
130    }
131    FdMatrix::from_column_major(flat, k, m).unwrap()
132}
133
134/// Initialize a random membership matrix (n x k) with rows summing to 1.
135fn init_random_membership(n: usize, k: usize, rng: &mut StdRng) -> FdMatrix {
136    let mut membership = FdMatrix::zeros(n, k);
137    for i in 0..n {
138        let mut row_sum = 0.0;
139        for c in 0..k {
140            let val = rng.gen::<f64>();
141            membership[(i, c)] = val;
142            row_sum += val;
143        }
144        for c in 0..k {
145            membership[(i, c)] /= row_sum;
146        }
147    }
148    membership
149}
150
/// Bucket observation indices by cluster label.
///
/// Returns `k` lists; list `c` holds, in ascending order, the indices `i`
/// with `cluster[i] == c`. Panics if a label is `>= k`.
fn cluster_member_indices(cluster: &[usize], k: usize) -> Vec<Vec<usize>> {
    cluster
        .iter()
        .enumerate()
        .fold(vec![Vec::new(); k], |mut groups, (idx, &label)| {
            groups[label].push(idx);
            groups
        })
}
159
160/// Assign each curve to its nearest center, returning cluster indices.
161fn assign_clusters(curves: &[Vec<f64>], centers: &[Vec<f64>], weights: &[f64]) -> Vec<usize> {
162    slice_maybe_parallel!(curves)
163        .map(|curve| {
164            let mut best_cluster = 0;
165            let mut best_dist = f64::INFINITY;
166            for (c, center) in centers.iter().enumerate() {
167                let dist = l2_distance(curve, center, weights);
168                if dist < best_dist {
169                    best_dist = dist;
170                    best_cluster = c;
171                }
172            }
173            best_cluster
174        })
175        .collect()
176}
177
/// Recompute k-means centers as the mean of the curves assigned to each cluster.
///
/// A cluster with no members keeps its previous center from `centers`. Labels
/// `>= k` are ignored. Only the first `m` values of each curve are used.
///
/// # Returns
/// `k` centers of length `m`.
fn update_kmeans_centers(
    curves: &[Vec<f64>],
    assignments: &[usize],
    centers: &[Vec<f64>],
    k: usize,
    m: usize,
) -> Vec<Vec<f64>> {
    let mut sums = vec![vec![0.0; m]; k];
    let mut sizes = vec![0usize; k];

    // One pass over the observations; each accumulator still receives its
    // contributions in ascending observation order.
    for (i, &label) in assignments.iter().enumerate() {
        if label >= k {
            continue;
        }
        let curve = &curves[i];
        sizes[label] += 1;
        for (acc, &x) in sums[label].iter_mut().zip(&curve[..m]) {
            *acc += x;
        }
    }

    sums.into_iter()
        .zip(sizes)
        .enumerate()
        .map(|(c, (mut total, size))| {
            if size == 0 {
                centers[c].clone()
            } else {
                let denom = size as f64;
                total.iter_mut().for_each(|v| *v /= denom);
                total
            }
        })
        .collect()
}
213
214/// Compute within-cluster sum of squares for each cluster.
215fn compute_within_ss(
216    curves: &[Vec<f64>],
217    centers: &[Vec<f64>],
218    assignments: &[usize],
219    k: usize,
220    weights: &[f64],
221) -> Vec<f64> {
222    let mut withinss = vec![0.0; k];
223    for (i, curve) in curves.iter().enumerate() {
224        let c = assignments[i];
225        let dist = l2_distance(curve, &centers[c], weights);
226        withinss[c] += dist * dist;
227    }
228    withinss
229}
230
231/// Update fuzzy c-means cluster centers from membership values.
232fn update_fuzzy_centers(
233    curves: &[Vec<f64>],
234    membership: &FdMatrix,
235    k: usize,
236    m: usize,
237    fuzziness: f64,
238) -> Vec<Vec<f64>> {
239    let mut centers = vec![vec![0.0; m]; k];
240    for c in 0..k {
241        let mut numerator = vec![0.0; m];
242        let mut denominator = 0.0;
243
244        for (i, curve) in curves.iter().enumerate() {
245            let weight = membership[(i, c)].powf(fuzziness);
246            for j in 0..m {
247                numerator[j] += weight * curve[j];
248            }
249            denominator += weight;
250        }
251
252        if denominator > NUMERICAL_EPS {
253            for j in 0..m {
254                centers[c][j] = numerator[j] / denominator;
255            }
256        }
257    }
258    centers
259}
260
261/// Update fuzzy membership values and compute max change.
262fn update_fuzzy_membership_step(
263    curves: &[Vec<f64>],
264    centers: &[Vec<f64>],
265    old_membership: &FdMatrix,
266    k: usize,
267    exponent: f64,
268    weights: &[f64],
269) -> (FdMatrix, f64) {
270    let n = curves.len();
271    let mut new_membership = FdMatrix::zeros(n, k);
272    let mut max_change = 0.0;
273
274    for (i, curve) in curves.iter().enumerate() {
275        let distances: Vec<f64> = centers
276            .iter()
277            .map(|c| l2_distance(curve, c, weights))
278            .collect();
279
280        let memberships = compute_fuzzy_membership(&distances, exponent);
281
282        for c in 0..k {
283            new_membership[(i, c)] = memberships[c];
284            let change = (memberships[c] - old_membership[(i, c)]).abs();
285            if change > max_change {
286                max_change = change;
287            }
288        }
289    }
290
291    (new_membership, max_change)
292}
293
294/// Compute mean L2 distance from a curve to a set of curve indices.
295fn mean_cluster_distance(
296    curve: &[f64],
297    curves: &[Vec<f64>],
298    indices: &[usize],
299    weights: &[f64],
300) -> f64 {
301    if indices.is_empty() {
302        return 0.0;
303    }
304    let sum: f64 = indices
305        .iter()
306        .map(|&j| l2_distance(curve, &curves[j], weights))
307        .sum();
308    sum / indices.len() as f64
309}
310
/// Per-cluster means, overall mean and cluster sizes for a labelled sample.
///
/// Only the first `m` values of each curve are used. A cluster with no members
/// keeps an all-zero center. The caller must ensure `curves` is non-empty,
/// otherwise the overall mean divides by zero.
///
/// # Returns
/// `(centers, global_mean, counts)`, where `centers` has `k` rows of length
/// `m`, `global_mean` has length `m` and `counts[c]` is the size of cluster `c`.
fn compute_centers_and_global_mean(
    curves: &[Vec<f64>],
    assignments: &[usize],
    k: usize,
    m: usize,
) -> (Vec<Vec<f64>>, Vec<f64>, Vec<usize>) {
    let mut overall = vec![0.0; m];
    let mut per_cluster = vec![vec![0.0; m]; k];
    let mut sizes = vec![0usize; k];

    for (i, curve) in curves.iter().enumerate() {
        let label = assignments[i];
        sizes[label] += 1;
        for (j, &x) in curve[..m].iter().enumerate() {
            overall[j] += x;
            per_cluster[label][j] += x;
        }
    }

    let n = curves.len() as f64;
    overall.iter_mut().for_each(|v| *v /= n);
    for (center, &size) in per_cluster.iter_mut().zip(&sizes) {
        if size > 0 {
            let denom = size as f64;
            center.iter_mut().for_each(|v| *v /= denom);
        }
    }

    (per_cluster, overall, sizes)
}
348
349/// Run one k-means iteration: assign clusters, update centers, compute movement.
350fn kmeans_step(
351    curves: &[Vec<f64>],
352    centers: &[Vec<f64>],
353    weights: &[f64],
354    k: usize,
355    m: usize,
356) -> (Vec<usize>, Vec<Vec<f64>>, f64) {
357    let new_cluster = assign_clusters(curves, centers, weights);
358    let new_centers = update_kmeans_centers(curves, &new_cluster, centers, k, m);
359    let max_movement = centers
360        .iter()
361        .zip(new_centers.iter())
362        .map(|(old, new)| l2_distance(old, new, weights))
363        .fold(0.0, f64::max);
364    (new_cluster, new_centers, max_movement)
365}
366
367/// Run the k-means iteration loop until convergence or max iterations.
368fn kmeans_iterate(
369    curves: &[Vec<f64>],
370    mut centers: Vec<Vec<f64>>,
371    weights: &[f64],
372    k: usize,
373    m: usize,
374    max_iter: usize,
375    tol: f64,
376) -> (Vec<usize>, Vec<Vec<f64>>, usize, bool) {
377    let n = curves.len();
378    let mut cluster = vec![0usize; n];
379    let mut converged = false;
380    let mut iter = 0;
381
382    for iteration in 0..max_iter {
383        iter = iteration + 1;
384        let (new_cluster, new_centers, max_movement) = kmeans_step(curves, &centers, weights, k, m);
385
386        if new_cluster == cluster {
387            converged = true;
388            break;
389        }
390        cluster = new_cluster;
391        centers = new_centers;
392
393        if max_movement < tol {
394            converged = true;
395            break;
396        }
397    }
398
399    (cluster, centers, iter, converged)
400}
401
402/// K-means clustering for functional data.
403///
404/// # Arguments
405/// * `data` - Functional data matrix (n x m)
406/// * `argvals` - Evaluation points
407/// * `k` - Number of clusters
408/// * `max_iter` - Maximum iterations
409/// * `tol` - Convergence tolerance
410/// * `seed` - Random seed
411pub fn kmeans_fd(
412    data: &FdMatrix,
413    argvals: &[f64],
414    k: usize,
415    max_iter: usize,
416    tol: f64,
417    seed: u64,
418) -> KmeansResult {
419    let n = data.nrows();
420    let m = data.ncols();
421
422    if n == 0 || m == 0 || k == 0 || k > n || argvals.len() != m {
423        return KmeansResult {
424            cluster: Vec::new(),
425            centers: FdMatrix::zeros(0, 0),
426            withinss: Vec::new(),
427            tot_withinss: 0.0,
428            iter: 0,
429            converged: false,
430        };
431    }
432
433    let weights = simpsons_weights(argvals);
434    let mut rng = StdRng::seed_from_u64(seed);
435
436    // Extract curves
437    let curves = data.rows();
438
439    // K-means++ initialization using helper
440    let centers = kmeans_plusplus_init(&curves, k, &weights, &mut rng);
441
442    let (cluster, centers, iter, converged) =
443        kmeans_iterate(&curves, centers, &weights, k, m, max_iter, tol);
444
445    let withinss = compute_within_ss(&curves, &centers, &cluster, k, &weights);
446    let tot_withinss: f64 = withinss.iter().sum();
447    let centers_mat = centers_to_matrix(&centers, k, m);
448
449    KmeansResult {
450        cluster,
451        centers: centers_mat,
452        withinss,
453        tot_withinss,
454        iter,
455        converged,
456    }
457}
458
459/// Result of fuzzy c-means clustering.
460pub struct FuzzyCmeansResult {
461    /// Membership matrix (n x k)
462    pub membership: FdMatrix,
463    /// Cluster centers (k x m)
464    pub centers: FdMatrix,
465    /// Number of iterations
466    pub iter: usize,
467    /// Whether the algorithm converged
468    pub converged: bool,
469}
470
471/// Fuzzy c-means clustering for functional data.
472///
473/// # Arguments
474/// * `data` - Functional data matrix (n x m)
475/// * `argvals` - Evaluation points
476/// * `k` - Number of clusters
477/// * `fuzziness` - Fuzziness parameter (> 1)
478/// * `max_iter` - Maximum iterations
479/// * `tol` - Convergence tolerance
480/// * `seed` - Random seed
481pub fn fuzzy_cmeans_fd(
482    data: &FdMatrix,
483    argvals: &[f64],
484    k: usize,
485    fuzziness: f64,
486    max_iter: usize,
487    tol: f64,
488    seed: u64,
489) -> FuzzyCmeansResult {
490    let n = data.nrows();
491    let m = data.ncols();
492
493    if n == 0 || m == 0 || k == 0 || k > n || argvals.len() != m || fuzziness <= 1.0 {
494        return FuzzyCmeansResult {
495            membership: FdMatrix::zeros(0, 0),
496            centers: FdMatrix::zeros(0, 0),
497            iter: 0,
498            converged: false,
499        };
500    }
501
502    let weights = simpsons_weights(argvals);
503    let mut rng = StdRng::seed_from_u64(seed);
504
505    // Extract curves
506    let curves = data.rows();
507
508    let mut membership = init_random_membership(n, k, &mut rng);
509
510    let mut centers = vec![vec![0.0; m]; k];
511    let mut converged = false;
512    let mut iter = 0;
513    let exponent = 2.0 / (fuzziness - 1.0);
514
515    for iteration in 0..max_iter {
516        iter = iteration + 1;
517
518        centers = update_fuzzy_centers(&curves, &membership, k, m, fuzziness);
519
520        let (new_membership, max_change) =
521            update_fuzzy_membership_step(&curves, &centers, &membership, k, exponent, &weights);
522
523        membership = new_membership;
524
525        if max_change < tol {
526            converged = true;
527            break;
528        }
529    }
530
531    let centers_mat = centers_to_matrix(&centers, k, m);
532
533    FuzzyCmeansResult {
534        membership,
535        centers: centers_mat,
536        iter,
537        converged,
538    }
539}
540
541/// Compute silhouette score for clustering result.
542pub fn silhouette_score(data: &FdMatrix, argvals: &[f64], cluster: &[usize]) -> Vec<f64> {
543    let n = data.nrows();
544    let m = data.ncols();
545
546    if n == 0 || m == 0 || cluster.len() != n || argvals.len() != m {
547        return Vec::new();
548    }
549
550    let weights = simpsons_weights(argvals);
551    let curves = data.rows();
552
553    let k = cluster.iter().cloned().max().unwrap_or(0) + 1;
554    let members = cluster_member_indices(cluster, k);
555
556    iter_maybe_parallel!(0..n)
557        .map(|i| {
558            let my_cluster = cluster[i];
559
560            let same_indices: Vec<usize> = members[my_cluster]
561                .iter()
562                .copied()
563                .filter(|&j| j != i)
564                .collect();
565            let a_i = mean_cluster_distance(&curves[i], &curves, &same_indices, &weights);
566
567            let mut b_i = f64::INFINITY;
568            for c in 0..k {
569                if c != my_cluster && !members[c].is_empty() {
570                    b_i = b_i.min(mean_cluster_distance(
571                        &curves[i],
572                        &curves,
573                        &members[c],
574                        &weights,
575                    ));
576                }
577            }
578
579            if b_i.is_infinite() {
580                0.0
581            } else {
582                let max_ab = a_i.max(b_i);
583                if max_ab > NUMERICAL_EPS {
584                    (b_i - a_i) / max_ab
585                } else {
586                    0.0
587                }
588            }
589        })
590        .collect()
591}
592
593/// Compute Calinski-Harabasz index for clustering result.
594pub fn calinski_harabasz(data: &FdMatrix, argvals: &[f64], cluster: &[usize]) -> f64 {
595    let n = data.nrows();
596    let m = data.ncols();
597
598    if n == 0 || m == 0 || cluster.len() != n || argvals.len() != m {
599        return 0.0;
600    }
601
602    let weights = simpsons_weights(argvals);
603    let curves = data.rows();
604
605    let k = cluster.iter().cloned().max().unwrap_or(0) + 1;
606    if k < 2 {
607        return 0.0;
608    }
609
610    let (centers, global_mean, counts) = compute_centers_and_global_mean(&curves, cluster, k, m);
611
612    let mut bgss = 0.0;
613    for c in 0..k {
614        let dist = l2_distance(&centers[c], &global_mean, &weights);
615        bgss += counts[c] as f64 * dist * dist;
616    }
617
618    let wgss_vec = compute_within_ss(&curves, &centers, cluster, k, &weights);
619    let wgss: f64 = wgss_vec.iter().sum();
620
621    if wgss < NUMERICAL_EPS {
622        return f64::INFINITY;
623    }
624
625    (bgss / (k - 1) as f64) / (wgss / (n - k) as f64)
626}
627
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `len` equally spaced points on [0, 1], both endpoints included.
    fn uniform_grid(len: usize) -> Vec<f64> {
        let last = len.saturating_sub(1) as f64;
        (0..len).map(|i| i as f64 / last).collect()
    }

    /// Two well-separated groups of `n_per_cluster` jittered sine curves.
    ///
    /// Rows `0..n_per_cluster` oscillate around 0 and the remaining rows
    /// oscillate around 5.
    fn generate_two_clusters(n_per_cluster: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let n = 2 * n_per_cluster;
        let mut col_major = vec![0.0; n * m];

        for (group, &offset) in [0.0, 5.0].iter().enumerate() {
            for i in 0..n_per_cluster {
                let row = group * n_per_cluster + i;
                let jitter = 0.1 * (i as f64 / n_per_cluster as f64);
                for (j, &tj) in t.iter().enumerate() {
                    col_major[row + j * n] = (2.0 * PI * tj).sin() + offset + jitter;
                }
            }
        }

        (FdMatrix::from_column_major(col_major, n, m).unwrap(), t)
    }

    /// Ground-truth labels for `generate_two_clusters`: 0 for the first half, 1 after.
    fn true_labels(n_per: usize) -> Vec<usize> {
        (0..2 * n_per).map(|i| usize::from(i >= n_per)).collect()
    }

    /// Sum of `-p ln p` over all membership entries; entries <= 1e-10 count as 0.
    fn membership_entropy(result: &FuzzyCmeansResult) -> f64 {
        result
            .membership
            .as_slice()
            .iter()
            .map(|&p| if p > 1e-10 { -p * p.ln() } else { 0.0 })
            .sum()
    }

    // ---------------- k-means ----------------

    #[test]
    fn test_kmeans_fd_basic() {
        let (data, t) = generate_two_clusters(5, 50);
        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42);

        assert_eq!(result.cluster.len(), 10);
        assert!(result.converged);
        assert!((1..=100).contains(&result.iter));
    }

    #[test]
    fn test_kmeans_fd_finds_clusters() {
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, 50);
        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42);

        assert_eq!(result.cluster.len(), 2 * n_per);
        let (first_half, second_half) = result.cluster.split_at(n_per);
        let label_a = first_half[0];
        let label_b = second_half[0];

        assert_ne!(label_a, label_b, "Clusters should be different");
        assert!(first_half.iter().all(|&c| c == label_a));
        assert!(second_half.iter().all(|&c| c == label_b));
    }

    #[test]
    fn test_kmeans_fd_deterministic() {
        let (data, t) = generate_two_clusters(5, 30);
        let run = || kmeans_fd(&data, &t, 2, 100, 1e-6, 42);

        // Same seed, same clustering.
        assert_eq!(run().cluster, run().cluster);
    }

    #[test]
    fn test_kmeans_fd_withinss() {
        let (data, t) = generate_two_clusters(5, 30);
        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42);

        assert!(result.withinss.iter().all(|&wss| wss >= 0.0));
        let total: f64 = result.withinss.iter().sum();
        assert!((total - result.tot_withinss).abs() < 1e-10);
    }

    #[test]
    fn test_kmeans_fd_centers_shape() {
        let m = 30;
        let k = 3;
        let (data, t) = generate_two_clusters(5, m);
        let result = kmeans_fd(&data, &t, k, 100, 1e-6, 42);

        assert_eq!((result.centers.nrows(), result.centers.ncols()), (k, m));
    }

    #[test]
    fn test_kmeans_fd_invalid_input() {
        let t = uniform_grid(30);

        let empty = kmeans_fd(&FdMatrix::zeros(0, 0), &t, 2, 100, 1e-6, 42);
        assert!(empty.cluster.is_empty());
        assert!(!empty.converged);

        let too_many_clusters = kmeans_fd(&FdMatrix::zeros(5, 30), &t, 10, 100, 1e-6, 42);
        assert!(too_many_clusters.cluster.is_empty());
    }

    #[test]
    fn test_kmeans_fd_single_cluster() {
        let m = 30;
        let t = uniform_grid(m);
        let data = FdMatrix::zeros(10, m);
        let result = kmeans_fd(&data, &t, 1, 100, 1e-6, 42);

        assert!(result.cluster.iter().all(|&c| c == 0));
    }

    // ---------------- fuzzy c-means ----------------

    #[test]
    fn test_fuzzy_cmeans_fd_basic() {
        let (data, t) = generate_two_clusters(5, 50);
        let result = fuzzy_cmeans_fd(&data, &t, 2, 2.0, 100, 1e-6, 42);

        assert_eq!(result.membership.nrows(), 10);
        assert_eq!(result.membership.ncols(), 2);
        assert!(result.iter > 0);
    }

    #[test]
    fn test_fuzzy_cmeans_fd_membership_sums_to_one() {
        let k = 2;
        let (data, t) = generate_two_clusters(5, 30);
        let result = fuzzy_cmeans_fd(&data, &t, k, 2.0, 100, 1e-6, 42);

        for row in 0..10 {
            let total: f64 = (0..k).map(|c| result.membership[(row, c)]).sum();
            assert!(
                (total - 1.0).abs() < 1e-6,
                "Membership should sum to 1, got {}",
                total
            );
        }
    }

    #[test]
    fn test_fuzzy_cmeans_fd_membership_in_range() {
        let (data, t) = generate_two_clusters(5, 30);
        let result = fuzzy_cmeans_fd(&data, &t, 2, 2.0, 100, 1e-6, 42);

        let valid = 0.0..=1.0 + 1e-10;
        assert!(result
            .membership
            .as_slice()
            .iter()
            .all(|mem| valid.contains(mem)));
    }

    #[test]
    fn test_fuzzy_cmeans_fd_fuzziness_effect() {
        let (data, t) = generate_two_clusters(5, 30);
        let low = membership_entropy(&fuzzy_cmeans_fd(&data, &t, 2, 1.5, 100, 1e-6, 42));
        let high = membership_entropy(&fuzzy_cmeans_fd(&data, &t, 2, 3.0, 100, 1e-6, 42));

        // More fuzziness should spread the memberships out.
        assert!(
            high >= low - 0.1,
            "Higher fuzziness should give higher entropy"
        );
    }

    #[test]
    fn test_fuzzy_cmeans_fd_invalid_fuzziness() {
        let t = uniform_grid(30);
        let data = FdMatrix::zeros(10, 30);

        // Any fuzziness <= 1 must be rejected.
        for &fuzziness in &[1.0, 0.5] {
            let result = fuzzy_cmeans_fd(&data, &t, 2, fuzziness, 100, 1e-6, 42);
            assert!(result.membership.is_empty());
        }
    }

    #[test]
    fn test_fuzzy_cmeans_fd_centers_shape() {
        let m = 30;
        let k = 3;
        let t = uniform_grid(m);
        let data = FdMatrix::zeros(10, m);
        let result = fuzzy_cmeans_fd(&data, &t, k, 2.0, 100, 1e-6, 42);

        assert_eq!((result.centers.nrows(), result.centers.ncols()), (k, m));
    }

    // ---------------- silhouette ----------------

    #[test]
    fn test_silhouette_score_well_separated() {
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, 30);
        let scores = silhouette_score(&data, &t, &true_labels(n_per));

        assert_eq!(scores.len(), 2 * n_per);
        let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
        assert!(
            mean_score > 0.5,
            "Well-separated clusters should have high silhouette: {}",
            mean_score
        );
    }

    #[test]
    fn test_silhouette_score_range() {
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, 30);
        let scores = silhouette_score(&data, &t, &true_labels(n_per));

        let valid = -1.0 - 1e-10..=1.0 + 1e-10;
        assert!(scores.iter().all(|s| valid.contains(s)));
    }

    #[test]
    fn test_silhouette_score_single_cluster() {
        let m = 30;
        let n = 10;
        let labels = vec![0usize; n];
        let scores = silhouette_score(&FdMatrix::zeros(n, m), &uniform_grid(m), &labels);

        // With a single cluster every score is zero.
        assert!(scores.iter().all(|s| s.abs() < 1e-10));
    }

    #[test]
    fn test_silhouette_score_invalid_input() {
        let t = uniform_grid(30);

        assert!(silhouette_score(&FdMatrix::zeros(0, 0), &t, &[]).is_empty());

        // Label vector shorter than the number of curves.
        let wrong_len = vec![0; 5];
        assert!(silhouette_score(&FdMatrix::zeros(10, 30), &t, &wrong_len).is_empty());
    }

    // ---------------- Calinski-Harabasz ----------------

    #[test]
    fn test_calinski_harabasz_well_separated() {
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, 30);
        let ch = calinski_harabasz(&data, &t, &true_labels(n_per));

        assert!(
            ch > 1.0,
            "Well-separated clusters should have high CH: {}",
            ch
        );
    }

    #[test]
    fn test_calinski_harabasz_positive() {
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, 30);
        let ch = calinski_harabasz(&data, &t, &true_labels(n_per));

        assert!(ch >= 0.0, "CH index should be non-negative");
    }

    #[test]
    fn test_calinski_harabasz_single_cluster() {
        let m = 30;
        let n = 10;
        let labels = vec![0usize; n];
        let ch = calinski_harabasz(&FdMatrix::zeros(n, m), &uniform_grid(m), &labels);

        // A single cluster has no between-cluster structure: index is 0.
        assert!(ch.abs() < 1e-10);
    }

    #[test]
    fn test_calinski_harabasz_invalid_input() {
        let ch = calinski_harabasz(&FdMatrix::zeros(0, 0), &uniform_grid(30), &[]);
        assert!(ch.abs() < 1e-10);
    }
}