// fdars_core/clustering.rs
1//! Clustering algorithms for functional data.
2//!
3//! This module provides k-means and fuzzy c-means clustering algorithms
4//! for functional data.
5
6use crate::helpers::{l2_distance, simpsons_weights, NUMERICAL_EPS};
7use crate::matrix::FdMatrix;
8use crate::{iter_maybe_parallel, slice_maybe_parallel};
9use rand::prelude::*;
10#[cfg(feature = "parallel")]
11use rayon::iter::ParallelIterator;
12
13/// Result of k-means clustering.
/// Result of k-means clustering.
pub struct KmeansResult {
    /// Cluster assignments for each observation (values in `0..k`)
    pub cluster: Vec<usize>,
    /// Cluster centers (k x m matrix, one mean curve per row)
    pub centers: FdMatrix,
    /// Within-cluster sum of squares for each cluster
    pub withinss: Vec<f64>,
    /// Total within-cluster sum of squares (sum of `withinss`)
    pub tot_withinss: f64,
    /// Number of iterations actually run
    pub iter: usize,
    /// Whether the algorithm converged before hitting the iteration limit
    pub converged: bool,
}
28
29/// K-means++ initialization: select initial centers with probability proportional to D^2.
30///
31/// Uses incremental distance tracking: maintains a `min_dist_sq` vector and only
32/// computes distances to the newest center on each iteration, avoiding redundant
33/// distance computations to all existing centers.
34///
35/// # Arguments
36/// * `curves` - Flat row-major buffer of curves (n curves, each m values)
37/// * `n` - Number of curves
38/// * `m` - Number of evaluation points per curve
39/// * `k` - Number of clusters
40/// * `weights` - Integration weights for L2 distance
41/// * `rng` - Random number generator
42///
43/// # Returns
44/// Flat buffer of k initial cluster centers (k * m values)
45/// Select an index with probability proportional to the given weights.
46fn weighted_random_select(dist_sq: &[f64], rng: &mut StdRng) -> usize {
47    let total: f64 = dist_sq.iter().sum();
48    if total < NUMERICAL_EPS {
49        return rng.gen_range(0..dist_sq.len());
50    }
51    let r = rng.gen::<f64>() * total;
52    let mut cumsum = 0.0;
53    for (i, &d) in dist_sq.iter().enumerate() {
54        cumsum += d;
55        if cumsum >= r {
56            return i;
57        }
58    }
59    dist_sq.len() - 1
60}
61
62fn kmeans_plusplus_init(
63    curves: &[f64],
64    n: usize,
65    m: usize,
66    k: usize,
67    weights: &[f64],
68    rng: &mut StdRng,
69) -> Vec<f64> {
70    let mut centers = vec![0.0; k * m];
71
72    // First center: random
73    let first_idx = rng.gen_range(0..n);
74    centers[..m].copy_from_slice(&curves[first_idx * m..(first_idx + 1) * m]);
75
76    // Initialize min_dist_sq with squared distances to first center
77    let center0 = &centers[..m];
78    let mut min_dist_sq: Vec<f64> = (0..n)
79        .map(|i| {
80            let d = l2_distance(&curves[i * m..(i + 1) * m], center0, weights);
81            d * d
82        })
83        .collect();
84
85    // Remaining centers: probability proportional to D^2
86    for c_idx in 1..k {
87        let chosen = weighted_random_select(&min_dist_sq, rng);
88        centers[c_idx * m..(c_idx + 1) * m].copy_from_slice(&curves[chosen * m..(chosen + 1) * m]);
89
90        // Update min_dist_sq: only compute distance to the newest center
91        let new_center = &centers[c_idx * m..(c_idx + 1) * m];
92        for i in 0..n {
93            let d_sq = l2_distance(&curves[i * m..(i + 1) * m], new_center, weights).powi(2);
94            if d_sq < min_dist_sq[i] {
95                min_dist_sq[i] = d_sq;
96            }
97        }
98    }
99
100    centers
101}
102
103/// Compute fuzzy membership values for a single observation, writing into `out`.
104///
105/// # Arguments
106/// * `distances` - Distances from the observation to each cluster center
107/// * `k` - Number of clusters
108/// * `exponent` - Exponent for fuzzy membership (2 / (fuzziness - 1))
109/// * `out` - Output slice (length k) to write membership values into
110fn compute_fuzzy_membership_into(distances: &[f64], k: usize, exponent: f64, out: &mut [f64]) {
111    out[..k].fill(0.0);
112
113    // Check if observation is very close to any center
114    for (c, &dist) in distances[..k].iter().enumerate() {
115        if dist < NUMERICAL_EPS {
116            // Assign full membership to this cluster
117            out[c] = 1.0;
118            return;
119        }
120    }
121
122    // Normal fuzzy membership computation
123    for c in 0..k {
124        let mut sum = 0.0;
125        for c2 in 0..k {
126            if distances[c2] > NUMERICAL_EPS {
127                sum += (distances[c] / distances[c2]).powf(exponent);
128            }
129        }
130        out[c] = if sum > NUMERICAL_EPS { 1.0 / sum } else { 1.0 };
131    }
132}
133
134/// Build an FdMatrix (k x m) from flat row-major centers buffer.
135fn centers_to_matrix(centers: &[f64], k: usize, m: usize) -> FdMatrix {
136    let mut flat = vec![0.0; k * m];
137    for c in 0..k {
138        for j in 0..m {
139            flat[c + j * k] = centers[c * m + j];
140        }
141    }
142    FdMatrix::from_column_major(flat, k, m).unwrap()
143}
144
145/// Initialize a random membership matrix (n x k) with rows summing to 1.
146fn init_random_membership(n: usize, k: usize, rng: &mut StdRng) -> FdMatrix {
147    let mut membership = FdMatrix::zeros(n, k);
148    for i in 0..n {
149        let mut row_sum = 0.0;
150        for c in 0..k {
151            let val = rng.gen::<f64>();
152            membership[(i, c)] = val;
153            row_sum += val;
154        }
155        for c in 0..k {
156            membership[(i, c)] /= row_sum;
157        }
158    }
159    membership
160}
161
/// Group sample indices by their cluster assignment.
fn cluster_member_indices(cluster: &[usize], k: usize) -> Vec<Vec<usize>> {
    cluster
        .iter()
        .enumerate()
        .fold(vec![Vec::new(); k], |mut groups, (i, &c)| {
            groups[c].push(i);
            groups
        })
}
170
171/// Assign each curve to its nearest center, returning cluster indices.
172fn assign_clusters(
173    curves: &[f64],
174    n: usize,
175    m: usize,
176    centers: &[f64],
177    k: usize,
178    weights: &[f64],
179) -> Vec<usize> {
180    // Build a slice of curve slices for parallel iteration
181    let curve_indices: Vec<usize> = (0..n).collect();
182    slice_maybe_parallel!(curve_indices)
183        .map(|&i| {
184            let curve = &curves[i * m..(i + 1) * m];
185            let mut best_cluster = 0;
186            let mut best_dist = f64::INFINITY;
187            for c in 0..k {
188                let center = &centers[c * m..(c + 1) * m];
189                let dist = l2_distance(curve, center, weights);
190                if dist < best_dist {
191                    best_dist = dist;
192                    best_cluster = c;
193                }
194            }
195            best_cluster
196        })
197        .collect()
198}
199
/// Compute new cluster centers from curve assignments.
///
/// Each center becomes the pointwise mean of its member curves; a cluster
/// with no members keeps its previous center unchanged.
fn update_kmeans_centers(
    curves: &[f64],
    n: usize,
    m: usize,
    assignments: &[usize],
    old_centers: &[f64],
    k: usize,
) -> Vec<f64> {
    let mut centers = vec![0.0; k * m];
    let mut counts = vec![0usize; k];

    // Accumulate pointwise sums per cluster.
    for (i, &c) in assignments.iter().enumerate().take(n) {
        counts[c] += 1;
        let curve = &curves[i * m..(i + 1) * m];
        for (acc, &value) in centers[c * m..(c + 1) * m].iter_mut().zip(curve) {
            *acc += value;
        }
    }

    // Turn sums into means; empty clusters fall back to the old center.
    for (c, center) in centers.chunks_exact_mut(m).enumerate() {
        if counts[c] == 0 {
            center.copy_from_slice(&old_centers[c * m..(c + 1) * m]);
        } else {
            let members = counts[c] as f64;
            for value in center.iter_mut() {
                *value /= members;
            }
        }
    }

    centers
}
237
238/// Compute within-cluster sum of squares for each cluster.
239fn compute_within_ss(
240    curves: &[f64],
241    n: usize,
242    m: usize,
243    centers: &[f64],
244    assignments: &[usize],
245    k: usize,
246    weights: &[f64],
247) -> Vec<f64> {
248    let mut withinss = vec![0.0; k];
249    for i in 0..n {
250        let c = assignments[i];
251        let dist = l2_distance(
252            &curves[i * m..(i + 1) * m],
253            &centers[c * m..(c + 1) * m],
254            weights,
255        );
256        withinss[c] += dist * dist;
257    }
258    withinss
259}
260
261/// Update fuzzy c-means cluster centers from membership values.
262fn update_fuzzy_centers(
263    curves: &[f64],
264    n: usize,
265    m: usize,
266    membership: &FdMatrix,
267    k: usize,
268    fuzziness: f64,
269) -> Vec<f64> {
270    let mut centers = vec![0.0; k * m];
271    for c in 0..k {
272        let mut denominator = 0.0;
273        let center = &mut centers[c * m..(c + 1) * m];
274
275        for i in 0..n {
276            let weight = membership[(i, c)].powf(fuzziness);
277            let curve = &curves[i * m..(i + 1) * m];
278            for j in 0..m {
279                center[j] += weight * curve[j];
280            }
281            denominator += weight;
282        }
283
284        if denominator > NUMERICAL_EPS {
285            for j in 0..m {
286                center[j] /= denominator;
287            }
288        }
289    }
290    centers
291}
292
293/// Update fuzzy membership values and compute max change.
294fn update_fuzzy_membership_step(
295    curves: &[f64],
296    n: usize,
297    m: usize,
298    centers: &[f64],
299    k: usize,
300    old_membership: &FdMatrix,
301    exponent: f64,
302    weights: &[f64],
303) -> (FdMatrix, f64) {
304    let mut new_membership = FdMatrix::zeros(n, k);
305    let mut max_change = 0.0;
306    let mut distances = vec![0.0; k];
307    let mut memberships = vec![0.0; k];
308
309    for i in 0..n {
310        let curve = &curves[i * m..(i + 1) * m];
311        for c in 0..k {
312            distances[c] = l2_distance(curve, &centers[c * m..(c + 1) * m], weights);
313        }
314
315        compute_fuzzy_membership_into(&distances, k, exponent, &mut memberships);
316
317        for c in 0..k {
318            new_membership[(i, c)] = memberships[c];
319            let change = (memberships[c] - old_membership[(i, c)]).abs();
320            if change > max_change {
321                max_change = change;
322            }
323        }
324    }
325
326    (new_membership, max_change)
327}
328
329/// Compute mean L2 distance from a curve to a set of curve indices.
330fn mean_cluster_distance(
331    curve: &[f64],
332    curves: &[f64],
333    m: usize,
334    indices: &[usize],
335    weights: &[f64],
336) -> f64 {
337    if indices.is_empty() {
338        return 0.0;
339    }
340    let sum: f64 = indices
341        .iter()
342        .map(|&j| l2_distance(curve, &curves[j * m..(j + 1) * m], weights))
343        .sum();
344    sum / indices.len() as f64
345}
346
/// Compute cluster centers, global mean, and counts from curves and assignments.
///
/// Returns `(centers, global_mean, counts)` where `centers` is a flat k x m
/// row-major buffer of per-cluster mean curves, `global_mean` has length m,
/// and `counts[c]` is the number of curves assigned to cluster c. Empty
/// clusters keep an all-zero center.
fn compute_centers_and_global_mean(
    curves: &[f64],
    n: usize,
    m: usize,
    assignments: &[usize],
    k: usize,
) -> (Vec<f64>, Vec<f64>, Vec<usize>) {
    // Pointwise mean over all n curves.
    let mut global_mean = vec![0.0; m];
    for curve in curves.chunks_exact(m).take(n) {
        for (acc, &value) in global_mean.iter_mut().zip(curve) {
            *acc += value;
        }
    }
    for value in global_mean.iter_mut() {
        *value /= n as f64;
    }

    // Per-cluster sums and member counts.
    let mut centers = vec![0.0; k * m];
    let mut counts = vec![0usize; k];
    for (i, &c) in assignments.iter().enumerate().take(n) {
        counts[c] += 1;
        let curve = &curves[i * m..(i + 1) * m];
        for (acc, &value) in centers[c * m..(c + 1) * m].iter_mut().zip(curve) {
            *acc += value;
        }
    }

    // Convert sums into means; empty clusters stay at zero.
    for (c, center) in centers.chunks_exact_mut(m).enumerate() {
        if counts[c] > 0 {
            let members = counts[c] as f64;
            for value in center.iter_mut() {
                *value /= members;
            }
        }
    }

    (centers, global_mean, counts)
}
388
389/// Run one k-means iteration: assign clusters, update centers, compute movement.
390fn kmeans_step(
391    curves: &[f64],
392    n: usize,
393    m: usize,
394    centers: &[f64],
395    k: usize,
396    weights: &[f64],
397) -> (Vec<usize>, Vec<f64>, f64) {
398    let new_cluster = assign_clusters(curves, n, m, centers, k, weights);
399    let new_centers = update_kmeans_centers(curves, n, m, &new_cluster, centers, k);
400    let max_movement = (0..k)
401        .map(|c| {
402            l2_distance(
403                &centers[c * m..(c + 1) * m],
404                &new_centers[c * m..(c + 1) * m],
405                weights,
406            )
407        })
408        .fold(0.0, f64::max);
409    (new_cluster, new_centers, max_movement)
410}
411
412/// Run the k-means iteration loop until convergence or max iterations.
413fn kmeans_iterate(
414    curves: &[f64],
415    n: usize,
416    m: usize,
417    mut centers: Vec<f64>,
418    k: usize,
419    weights: &[f64],
420    max_iter: usize,
421    tol: f64,
422) -> (Vec<usize>, Vec<f64>, usize, bool) {
423    let mut cluster = vec![0usize; n];
424    let mut converged = false;
425    let mut iter = 0;
426
427    for iteration in 0..max_iter {
428        iter = iteration + 1;
429        let (new_cluster, new_centers, max_movement) =
430            kmeans_step(curves, n, m, &centers, k, weights);
431
432        if new_cluster == cluster {
433            converged = true;
434            break;
435        }
436        cluster = new_cluster;
437        centers = new_centers;
438
439        if max_movement < tol {
440            converged = true;
441            break;
442        }
443    }
444
445    (cluster, centers, iter, converged)
446}
447
448/// K-means clustering for functional data.
449///
450/// # Arguments
451/// * `data` - Functional data matrix (n x m)
452/// * `argvals` - Evaluation points
453/// * `k` - Number of clusters
454/// * `max_iter` - Maximum iterations
455/// * `tol` - Convergence tolerance
456/// * `seed` - Random seed
457pub fn kmeans_fd(
458    data: &FdMatrix,
459    argvals: &[f64],
460    k: usize,
461    max_iter: usize,
462    tol: f64,
463    seed: u64,
464) -> KmeansResult {
465    let n = data.nrows();
466    let m = data.ncols();
467
468    if n == 0 || m == 0 || k == 0 || k > n || argvals.len() != m {
469        return KmeansResult {
470            cluster: Vec::new(),
471            centers: FdMatrix::zeros(0, 0),
472            withinss: Vec::new(),
473            tot_withinss: 0.0,
474            iter: 0,
475            converged: false,
476        };
477    }
478
479    let weights = simpsons_weights(argvals);
480    let mut rng = StdRng::seed_from_u64(seed);
481
482    // Extract curves as flat row-major buffer
483    let curves = data.to_row_major();
484
485    // K-means++ initialization
486    let centers = kmeans_plusplus_init(&curves, n, m, k, &weights, &mut rng);
487
488    let (cluster, centers, iter, converged) =
489        kmeans_iterate(&curves, n, m, centers, k, &weights, max_iter, tol);
490
491    let withinss = compute_within_ss(&curves, n, m, &centers, &cluster, k, &weights);
492    let tot_withinss: f64 = withinss.iter().sum();
493    let centers_mat = centers_to_matrix(&centers, k, m);
494
495    KmeansResult {
496        cluster,
497        centers: centers_mat,
498        withinss,
499        tot_withinss,
500        iter,
501        converged,
502    }
503}
504
/// Result of fuzzy c-means clustering.
pub struct FuzzyCmeansResult {
    /// Membership matrix (n x k); each row sums to 1
    pub membership: FdMatrix,
    /// Cluster centers (k x m, one center curve per row)
    pub centers: FdMatrix,
    /// Number of iterations actually run
    pub iter: usize,
    /// Whether the algorithm converged before hitting the iteration limit
    pub converged: bool,
}
516
517/// Fuzzy c-means clustering for functional data.
518///
519/// # Arguments
520/// * `data` - Functional data matrix (n x m)
521/// * `argvals` - Evaluation points
522/// * `k` - Number of clusters
523/// * `fuzziness` - Fuzziness parameter (> 1)
524/// * `max_iter` - Maximum iterations
525/// * `tol` - Convergence tolerance
526/// * `seed` - Random seed
527pub fn fuzzy_cmeans_fd(
528    data: &FdMatrix,
529    argvals: &[f64],
530    k: usize,
531    fuzziness: f64,
532    max_iter: usize,
533    tol: f64,
534    seed: u64,
535) -> FuzzyCmeansResult {
536    let n = data.nrows();
537    let m = data.ncols();
538
539    if n == 0 || m == 0 || k == 0 || k > n || argvals.len() != m || fuzziness <= 1.0 {
540        return FuzzyCmeansResult {
541            membership: FdMatrix::zeros(0, 0),
542            centers: FdMatrix::zeros(0, 0),
543            iter: 0,
544            converged: false,
545        };
546    }
547
548    let weights = simpsons_weights(argvals);
549    let mut rng = StdRng::seed_from_u64(seed);
550
551    // Extract curves as flat row-major buffer
552    let curves = data.to_row_major();
553
554    let mut membership = init_random_membership(n, k, &mut rng);
555
556    let mut centers = vec![0.0; k * m];
557    let mut converged = false;
558    let mut iter = 0;
559    let exponent = 2.0 / (fuzziness - 1.0);
560
561    for iteration in 0..max_iter {
562        iter = iteration + 1;
563
564        centers = update_fuzzy_centers(&curves, n, m, &membership, k, fuzziness);
565
566        let (new_membership, max_change) = update_fuzzy_membership_step(
567            &curves,
568            n,
569            m,
570            &centers,
571            k,
572            &membership,
573            exponent,
574            &weights,
575        );
576
577        membership = new_membership;
578
579        if max_change < tol {
580            converged = true;
581            break;
582        }
583    }
584
585    let centers_mat = centers_to_matrix(&centers, k, m);
586
587    FuzzyCmeansResult {
588        membership,
589        centers: centers_mat,
590        iter,
591        converged,
592    }
593}
594
595/// Compute silhouette score for clustering result.
596pub fn silhouette_score(data: &FdMatrix, argvals: &[f64], cluster: &[usize]) -> Vec<f64> {
597    let n = data.nrows();
598    let m = data.ncols();
599
600    if n == 0 || m == 0 || cluster.len() != n || argvals.len() != m {
601        return Vec::new();
602    }
603
604    let weights = simpsons_weights(argvals);
605    let curves = data.to_row_major();
606
607    let k = cluster.iter().cloned().max().unwrap_or(0) + 1;
608    let members = cluster_member_indices(cluster, k);
609
610    iter_maybe_parallel!(0..n)
611        .map(|i| {
612            let my_cluster = cluster[i];
613            let curve_i = &curves[i * m..(i + 1) * m];
614
615            let same_indices: Vec<usize> = members[my_cluster]
616                .iter()
617                .copied()
618                .filter(|&j| j != i)
619                .collect();
620            let a_i = mean_cluster_distance(curve_i, &curves, m, &same_indices, &weights);
621
622            let mut b_i = f64::INFINITY;
623            for c in 0..k {
624                if c != my_cluster && !members[c].is_empty() {
625                    b_i = b_i.min(mean_cluster_distance(
626                        curve_i,
627                        &curves,
628                        m,
629                        &members[c],
630                        &weights,
631                    ));
632                }
633            }
634
635            if b_i.is_infinite() {
636                0.0
637            } else {
638                let max_ab = a_i.max(b_i);
639                if max_ab > NUMERICAL_EPS {
640                    (b_i - a_i) / max_ab
641                } else {
642                    0.0
643                }
644            }
645        })
646        .collect()
647}
648
649/// Compute Calinski-Harabasz index for clustering result.
650pub fn calinski_harabasz(data: &FdMatrix, argvals: &[f64], cluster: &[usize]) -> f64 {
651    let n = data.nrows();
652    let m = data.ncols();
653
654    if n == 0 || m == 0 || cluster.len() != n || argvals.len() != m {
655        return 0.0;
656    }
657
658    let weights = simpsons_weights(argvals);
659    let curves = data.to_row_major();
660
661    let k = cluster.iter().cloned().max().unwrap_or(0) + 1;
662    if k < 2 {
663        return 0.0;
664    }
665
666    let (centers, global_mean, counts) = compute_centers_and_global_mean(&curves, n, m, cluster, k);
667
668    let mut bgss = 0.0;
669    for c in 0..k {
670        let dist = l2_distance(&centers[c * m..(c + 1) * m], &global_mean, &weights);
671        bgss += counts[c] as f64 * dist * dist;
672    }
673
674    let wgss_vec = compute_within_ss(&curves, n, m, &centers, cluster, k, &weights);
675    let wgss: f64 = wgss_vec.iter().sum();
676
677    if wgss < NUMERICAL_EPS {
678        return f64::INFINITY;
679    }
680
681    (bgss / (k - 1) as f64) / (wgss / (n - k) as f64)
682}
683
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Generate a uniform grid of points
    fn uniform_grid(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
    }

    /// Generate two clearly separated clusters of curves as an FdMatrix
    fn generate_two_clusters(n_per_cluster: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let n = 2 * n_per_cluster;
        let mut col_major = vec![0.0; n * m];

        // Cluster 0: sine waves with low amplitude
        for i in 0..n_per_cluster {
            for (j, &ti) in t.iter().enumerate() {
                col_major[i + j * n] =
                    (2.0 * PI * ti).sin() + 0.1 * (i as f64 / n_per_cluster as f64);
            }
        }

        // Cluster 1: sine waves shifted up by 5
        for i in 0..n_per_cluster {
            for (j, &ti) in t.iter().enumerate() {
                col_major[(i + n_per_cluster) + j * n] =
                    (2.0 * PI * ti).sin() + 5.0 + 0.1 * (i as f64 / n_per_cluster as f64);
            }
        }

        (FdMatrix::from_column_major(col_major, n, m).unwrap(), t)
    }

    // ============== K-means tests ==============

    #[test]
    fn test_kmeans_fd_basic() {
        let m = 50;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42);

        assert_eq!(result.cluster.len(), n);
        assert!(result.converged);
        assert!(result.iter > 0 && result.iter <= 100);
    }

    #[test]
    fn test_kmeans_fd_finds_clusters() {
        let m = 50;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42);

        // First half should be one cluster, second half the other
        let cluster_0 = result.cluster[0];
        let cluster_1 = result.cluster[n_per];

        assert_ne!(cluster_0, cluster_1, "Clusters should be different");

        // Check that first half is in same cluster
        for i in 0..n_per {
            assert_eq!(result.cluster[i], cluster_0);
        }

        // Check that second half is in same cluster
        for i in n_per..n {
            assert_eq!(result.cluster[i], cluster_1);
        }
    }

    #[test]
    fn test_kmeans_fd_deterministic() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);

        let result1 = kmeans_fd(&data, &t, 2, 100, 1e-6, 42);
        let result2 = kmeans_fd(&data, &t, 2, 100, 1e-6, 42);

        // Same seed should give same results
        assert_eq!(result1.cluster, result2.cluster);
    }

    #[test]
    fn test_kmeans_fd_withinss() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);

        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42);

        // Within-cluster sum of squares should be non-negative
        for &wss in &result.withinss {
            assert!(wss >= 0.0);
        }

        // Total should equal sum
        let sum: f64 = result.withinss.iter().sum();
        assert!((sum - result.tot_withinss).abs() < 1e-10);
    }

    #[test]
    fn test_kmeans_fd_centers_shape() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let k = 3;

        let result = kmeans_fd(&data, &t, k, 100, 1e-6, 42);

        // Centers should be k x m matrix
        assert_eq!(result.centers.nrows(), k);
        assert_eq!(result.centers.ncols(), m);
    }

    #[test]
    fn test_kmeans_fd_invalid_input() {
        let t = uniform_grid(30);

        // Empty data
        let data = FdMatrix::zeros(0, 0);
        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42);
        assert!(result.cluster.is_empty());
        assert!(!result.converged);

        // k > n
        let data = FdMatrix::zeros(5, 30);
        let result = kmeans_fd(&data, &t, 10, 100, 1e-6, 42);
        assert!(result.cluster.is_empty());
    }

    #[test]
    fn test_kmeans_fd_single_cluster() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let data = FdMatrix::zeros(n, m);

        let result = kmeans_fd(&data, &t, 1, 100, 1e-6, 42);

        // All should be in cluster 0
        for &c in &result.cluster {
            assert_eq!(c, 0);
        }
    }

    // ============== Fuzzy C-means tests ==============

    #[test]
    fn test_fuzzy_cmeans_fd_basic() {
        let m = 50;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = fuzzy_cmeans_fd(&data, &t, 2, 2.0, 100, 1e-6, 42);

        assert_eq!(result.membership.nrows(), n);
        assert_eq!(result.membership.ncols(), 2);
        assert!(result.iter > 0);
    }

    #[test]
    fn test_fuzzy_cmeans_fd_membership_sums_to_one() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;
        let k = 2;

        let result = fuzzy_cmeans_fd(&data, &t, k, 2.0, 100, 1e-6, 42);

        // Each observation's membership should sum to 1
        for i in 0..n {
            let sum: f64 = (0..k).map(|c| result.membership[(i, c)]).sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "Membership should sum to 1, got {}",
                sum
            );
        }
    }

    #[test]
    fn test_fuzzy_cmeans_fd_membership_in_range() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);

        let result = fuzzy_cmeans_fd(&data, &t, 2, 2.0, 100, 1e-6, 42);

        // All memberships should be in [0, 1]
        for &mem in result.membership.as_slice() {
            assert!((0.0..=1.0 + 1e-10).contains(&mem));
        }
    }

    #[test]
    fn test_fuzzy_cmeans_fd_fuzziness_effect() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);

        let result_low = fuzzy_cmeans_fd(&data, &t, 2, 1.5, 100, 1e-6, 42);
        let result_high = fuzzy_cmeans_fd(&data, &t, 2, 3.0, 100, 1e-6, 42);

        // Higher fuzziness should give more diffuse memberships
        // Measure by entropy-like metric
        let entropy_low: f64 = result_low
            .membership
            .as_slice()
            .iter()
            .map(|&m| if m > 1e-10 { -m * m.ln() } else { 0.0 })
            .sum();

        let entropy_high: f64 = result_high
            .membership
            .as_slice()
            .iter()
            .map(|&m| if m > 1e-10 { -m * m.ln() } else { 0.0 })
            .sum();

        assert!(
            entropy_high >= entropy_low - 0.1,
            "Higher fuzziness should give higher entropy"
        );
    }

    #[test]
    fn test_fuzzy_cmeans_fd_invalid_fuzziness() {
        let t = uniform_grid(30);
        let data = FdMatrix::zeros(10, 30);

        // Fuzziness <= 1 should fail
        let result = fuzzy_cmeans_fd(&data, &t, 2, 1.0, 100, 1e-6, 42);
        assert!(result.membership.is_empty());

        let result = fuzzy_cmeans_fd(&data, &t, 2, 0.5, 100, 1e-6, 42);
        assert!(result.membership.is_empty());
    }

    #[test]
    fn test_fuzzy_cmeans_fd_centers_shape() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let k = 3;
        let data = FdMatrix::zeros(n, m);

        let result = fuzzy_cmeans_fd(&data, &t, k, 2.0, 100, 1e-6, 42);

        assert_eq!(result.centers.nrows(), k);
        assert_eq!(result.centers.ncols(), m);
    }

    // ============== Silhouette score tests ==============

    #[test]
    fn test_silhouette_score_well_separated() {
        let m = 30;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        // Perfect clustering: first half in 0, second in 1
        let cluster: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let scores = silhouette_score(&data, &t, &cluster);

        assert_eq!(scores.len(), n);

        // Well-separated clusters should have high silhouette scores
        let mean_score: f64 = scores.iter().sum::<f64>() / n as f64;
        assert!(
            mean_score > 0.5,
            "Well-separated clusters should have high silhouette: {}",
            mean_score
        );
    }

    #[test]
    fn test_silhouette_score_range() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let cluster: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let scores = silhouette_score(&data, &t, &cluster);

        // Silhouette scores should be in [-1, 1]
        for &s in &scores {
            assert!((-1.0 - 1e-10..=1.0 + 1e-10).contains(&s));
        }
    }

    #[test]
    fn test_silhouette_score_single_cluster() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let data = FdMatrix::zeros(n, m);

        // All in one cluster
        let cluster = vec![0usize; n];

        let scores = silhouette_score(&data, &t, &cluster);

        // Single cluster should give zeros
        for &s in &scores {
            assert!(s.abs() < 1e-10);
        }
    }

    #[test]
    fn test_silhouette_score_invalid_input() {
        let t = uniform_grid(30);

        // Empty data
        let data = FdMatrix::zeros(0, 0);
        let scores = silhouette_score(&data, &t, &[]);
        assert!(scores.is_empty());

        // Mismatched cluster length
        let data = FdMatrix::zeros(10, 30);
        let cluster = vec![0; 5]; // Wrong length
        let scores = silhouette_score(&data, &t, &cluster);
        assert!(scores.is_empty());
    }

    // ============== Calinski-Harabasz tests ==============

    #[test]
    fn test_calinski_harabasz_well_separated() {
        let m = 30;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let cluster: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let ch = calinski_harabasz(&data, &t, &cluster);

        // Well-separated clusters should have high CH index
        assert!(
            ch > 1.0,
            "Well-separated clusters should have high CH: {}",
            ch
        );
    }

    #[test]
    fn test_calinski_harabasz_positive() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let cluster: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let ch = calinski_harabasz(&data, &t, &cluster);

        assert!(ch >= 0.0, "CH index should be non-negative");
    }

    #[test]
    fn test_calinski_harabasz_single_cluster() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let data = FdMatrix::zeros(n, m);

        // All in one cluster
        let cluster = vec![0usize; n];

        let ch = calinski_harabasz(&data, &t, &cluster);

        // Single cluster should give 0
        assert!(ch.abs() < 1e-10);
    }

    #[test]
    fn test_calinski_harabasz_invalid_input() {
        let t = uniform_grid(30);

        // Empty data
        let data = FdMatrix::zeros(0, 0);
        let ch = calinski_harabasz(&data, &t, &[]);
        assert!(ch.abs() < 1e-10);
    }

    #[test]
    fn test_identical_curves_kmeans() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        // All curves identical: in the column-major layout, flat index `idx`
        // belongs to column j = idx / n, so the value must depend on j only.
        // (The previous `idx % m` mixed the row index into the value and
        // actually produced ten distinct curves.)
        let data_vec: Vec<f64> = (0..n * m)
            .map(|idx| (2.0 * PI * (idx / n) as f64 / (m - 1) as f64).sin())
            .collect();
        let data = FdMatrix::from_column_major(data_vec, n, m).unwrap();
        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42);
        // Should not panic with identical data
        assert_eq!(result.cluster.len(), n);
    }

    #[test]
    fn test_k_equals_n() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 5;
        let (data, _) = generate_two_clusters(n, m);
        let result = kmeans_fd(&data, &t, 2 * n, 100, 1e-6, 42);
        // k == n: each curve is its own cluster
        assert_eq!(result.cluster.len(), 2 * n);
    }

    #[test]
    fn test_n2_kmeans() {
        let m = 30;
        let t = uniform_grid(m);
        let mut data = FdMatrix::zeros(2, m);
        for j in 0..m {
            data[(0, j)] = 0.0;
            data[(1, j)] = 10.0;
        }
        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42);
        assert_eq!(result.cluster.len(), 2);
        assert_ne!(result.cluster[0], result.cluster[1]);
    }
}