Skip to main content

fdars_core/
clustering.rs

1//! Clustering algorithms for functional data.
2//!
3//! This module provides k-means and fuzzy c-means clustering algorithms
4//! for functional data.
5
6use crate::error::FdarError;
7use crate::helpers::{l2_distance, simpsons_weights, NUMERICAL_EPS};
8use crate::matrix::FdMatrix;
9use crate::{iter_maybe_parallel, maybe_par_chunks_mut_enumerate, slice_maybe_parallel};
10use rand::prelude::*;
11#[cfg(feature = "parallel")]
12use rayon::iter::ParallelIterator;
13
14/// Result of k-means clustering.
15#[derive(Debug, Clone, PartialEq)]
16#[non_exhaustive]
17pub struct KmeansResult {
18    /// Cluster assignments for each observation
19    pub cluster: Vec<usize>,
20    /// Cluster centers (k x m matrix)
21    pub centers: FdMatrix,
22    /// Within-cluster sum of squares for each cluster
23    pub withinss: Vec<f64>,
24    /// Total within-cluster sum of squares
25    pub tot_withinss: f64,
26    /// Number of iterations
27    pub iter: usize,
28    /// Whether the algorithm converged
29    pub converged: bool,
30}
31
32impl KmeansResult {
33    /// Assign new observations to the nearest cluster center.
34    ///
35    /// For each row in `data`, computes the weighted L2 distance to each
36    /// cluster center using Simpson's integration weights derived from
37    /// `argvals`, and assigns the observation to the cluster with the
38    /// minimum distance.
39    ///
40    /// # Arguments
41    /// * `data` - Matrix (n_new x m) of new observations
42    /// * `argvals` - Evaluation points (length m)
43    ///
44    /// # Errors
45    ///
46    /// Returns [`FdarError::InvalidDimension`] if the number of columns in
47    /// `data` does not match the number of columns in the cluster centers, or
48    /// if `argvals.len()` does not match the number of columns.
49    ///
50    /// # Examples
51    ///
52    /// ```
53    /// use fdars_core::matrix::FdMatrix;
54    /// use fdars_core::clustering::kmeans_fd;
55    ///
56    /// let argvals: Vec<f64> = (0..20).map(|i| i as f64 / 19.0).collect();
57    /// let data = FdMatrix::from_column_major(
58    ///     (0..200).map(|i| (i as f64 * 0.1).sin()).collect(),
59    ///     10, 20,
60    /// ).unwrap();
61    /// let result = kmeans_fd(&data, &argvals, 2, 100, 1e-6, 42).unwrap();
62    ///
63    /// // Predict cluster assignments for new data
64    /// let new_data = FdMatrix::from_column_major(
65    ///     (0..60).map(|i| (i as f64 * 0.1).sin()).collect(),
66    ///     3, 20,
67    /// ).unwrap();
68    /// let assignments = result.predict(&new_data, &argvals).unwrap();
69    /// assert_eq!(assignments.len(), 3);
70    /// assert!(assignments.iter().all(|&c| c < 2));
71    /// ```
72    pub fn predict(&self, data: &FdMatrix, argvals: &[f64]) -> Result<Vec<usize>, FdarError> {
73        let (n, m) = data.shape();
74        let m_centers = self.centers.ncols();
75        let k = self.centers.nrows();
76        if m != m_centers {
77            return Err(FdarError::InvalidDimension {
78                parameter: "data",
79                expected: format!("{m_centers} columns"),
80                actual: format!("{m} columns"),
81            });
82        }
83        if argvals.len() != m {
84            return Err(FdarError::InvalidDimension {
85                parameter: "argvals",
86                expected: format!("{m}"),
87                actual: format!("{}", argvals.len()),
88            });
89        }
90
91        let weights = simpsons_weights(argvals);
92        let curves = data.to_row_major();
93        let centers = self.centers.to_row_major();
94        Ok(assign_clusters(&curves, n, m, &centers, k, &weights))
95    }
96}
97
98/// K-means++ initialization: select initial centers with probability proportional to D^2.
99///
100/// Uses incremental distance tracking: maintains a `min_dist_sq` vector and only
101/// computes distances to the newest center on each iteration, avoiding redundant
102/// distance computations to all existing centers.
103///
104/// # Arguments
105/// * `curves` - Flat row-major buffer of curves (n curves, each m values)
106/// * `n` - Number of curves
107/// * `m` - Number of evaluation points per curve
108/// * `k` - Number of clusters
109/// * `weights` - Integration weights for L2 distance
110/// * `rng` - Random number generator
111///
112/// # Returns
113/// Flat buffer of k initial cluster centers (k * m values)
114/// Select an index with probability proportional to the given weights.
115fn weighted_random_select(dist_sq: &[f64], rng: &mut StdRng) -> usize {
116    let total: f64 = dist_sq.iter().sum();
117    if total < NUMERICAL_EPS {
118        return rng.gen_range(0..dist_sq.len());
119    }
120    let r = rng.gen::<f64>() * total;
121    let mut cumsum = 0.0;
122    for (i, &d) in dist_sq.iter().enumerate() {
123        cumsum += d;
124        if cumsum >= r {
125            return i;
126        }
127    }
128    dist_sq.len() - 1
129}
130
131fn kmeans_plusplus_init(
132    curves: &[f64],
133    n: usize,
134    m: usize,
135    k: usize,
136    weights: &[f64],
137    rng: &mut StdRng,
138) -> Vec<f64> {
139    let mut centers = vec![0.0; k * m];
140
141    // First center: random
142    let first_idx = rng.gen_range(0..n);
143    centers[..m].copy_from_slice(&curves[first_idx * m..(first_idx + 1) * m]);
144
145    // Initialize min_dist_sq with squared distances to first center
146    let center0 = &centers[..m];
147    let mut min_dist_sq: Vec<f64> = (0..n)
148        .map(|i| {
149            let d = l2_distance(&curves[i * m..(i + 1) * m], center0, weights);
150            d * d
151        })
152        .collect();
153
154    // Remaining centers: probability proportional to D^2
155    for c_idx in 1..k {
156        let chosen = weighted_random_select(&min_dist_sq, rng);
157        centers[c_idx * m..(c_idx + 1) * m].copy_from_slice(&curves[chosen * m..(chosen + 1) * m]);
158
159        // Update min_dist_sq: only compute distance to the newest center
160        let new_center = &centers[c_idx * m..(c_idx + 1) * m];
161        maybe_par_chunks_mut_enumerate!(min_dist_sq, 1, |(i, chunk): (usize, &mut [f64])| {
162            let d_sq = l2_distance(&curves[i * m..(i + 1) * m], new_center, weights).powi(2);
163            if d_sq < chunk[0] {
164                chunk[0] = d_sq;
165            }
166        });
167    }
168
169    centers
170}
171
172/// Compute fuzzy membership values for a single observation, writing into `out`.
173///
174/// # Arguments
175/// * `distances` - Distances from the observation to each cluster center
176/// * `k` - Number of clusters
177/// * `exponent` - Exponent for fuzzy membership (2 / (fuzziness - 1))
178/// * `out` - Output slice (length k) to write membership values into
179fn compute_fuzzy_membership_into(distances: &[f64], k: usize, exponent: f64, out: &mut [f64]) {
180    out[..k].fill(0.0);
181
182    // Check if observation is very close to any center
183    for (c, &dist) in distances[..k].iter().enumerate() {
184        if dist < NUMERICAL_EPS {
185            // Assign full membership to this cluster
186            out[c] = 1.0;
187            return;
188        }
189    }
190
191    // Normal fuzzy membership computation
192    for c in 0..k {
193        let mut sum = 0.0;
194        for c2 in 0..k {
195            if distances[c2] > NUMERICAL_EPS {
196                sum += (distances[c] / distances[c2]).powf(exponent);
197            }
198        }
199        out[c] = if sum > NUMERICAL_EPS { 1.0 / sum } else { 1.0 };
200    }
201}
202
203/// Build an FdMatrix (k x m) from flat row-major centers buffer.
204fn centers_to_matrix(centers: &[f64], k: usize, m: usize) -> FdMatrix {
205    let mut flat = vec![0.0; k * m];
206    for c in 0..k {
207        for j in 0..m {
208            flat[c + j * k] = centers[c * m + j];
209        }
210    }
211    FdMatrix::from_column_major(flat, k, m).expect("dimension invariant: data.len() == n * m")
212}
213
214/// Initialize a random membership matrix (n x k) with rows summing to 1.
215fn init_random_membership(n: usize, k: usize, rng: &mut StdRng) -> FdMatrix {
216    let mut membership = FdMatrix::zeros(n, k);
217    for i in 0..n {
218        let mut row_sum = 0.0;
219        for c in 0..k {
220            let val = rng.gen::<f64>();
221            membership[(i, c)] = val;
222            row_sum += val;
223        }
224        for c in 0..k {
225            membership[(i, c)] /= row_sum;
226        }
227    }
228    membership
229}
230
/// Group observation indices by cluster label.
///
/// Returns `k` buckets. Bucket `c` lists, in increasing order, the indices
/// `i` with `cluster[i] == c`. Panics if a label is `>= k`.
fn cluster_member_indices(cluster: &[usize], k: usize) -> Vec<Vec<usize>> {
    cluster
        .iter()
        .enumerate()
        .fold(vec![Vec::new(); k], |mut groups, (idx, &label)| {
            groups[label].push(idx);
            groups
        })
}
239
240/// Assign each curve to its nearest center, returning cluster indices.
241fn assign_clusters(
242    curves: &[f64],
243    n: usize,
244    m: usize,
245    centers: &[f64],
246    k: usize,
247    weights: &[f64],
248) -> Vec<usize> {
249    // Build a slice of curve slices for parallel iteration
250    let curve_indices: Vec<usize> = (0..n).collect();
251    slice_maybe_parallel!(curve_indices)
252        .map(|&i| {
253            let curve = &curves[i * m..(i + 1) * m];
254            let mut best_cluster = 0;
255            let mut best_dist = f64::INFINITY;
256            for c in 0..k {
257                let center = &centers[c * m..(c + 1) * m];
258                let dist = l2_distance(curve, center, weights);
259                if dist < best_dist {
260                    best_dist = dist;
261                    best_cluster = c;
262                }
263            }
264            best_cluster
265        })
266        .collect()
267}
268
/// Recompute k-means centers as the mean of each cluster's member curves.
///
/// A cluster that received no members keeps its center from `old_centers`.
/// All buffers are flat and row-major (`k * m` values for centers).
fn update_kmeans_centers(
    curves: &[f64],
    n: usize,
    m: usize,
    assignments: &[usize],
    old_centers: &[f64],
    k: usize,
) -> Vec<f64> {
    let mut sums = vec![0.0; k * m];
    let mut sizes = vec![0usize; k];

    for (i, &label) in assignments[..n].iter().enumerate() {
        sizes[label] += 1;
        let row = &curves[i * m..(i + 1) * m];
        for (acc, &x) in sums[label * m..(label + 1) * m].iter_mut().zip(row) {
            *acc += x;
        }
    }

    for (label, &size) in sizes.iter().enumerate() {
        let span = label * m..(label + 1) * m;
        if size == 0 {
            sums[span.clone()].copy_from_slice(&old_centers[span]);
        } else {
            let denom = size as f64;
            sums[span].iter_mut().for_each(|v| *v /= denom);
        }
    }

    sums
}
306
307/// Compute within-cluster sum of squares for each cluster.
308fn compute_within_ss(
309    curves: &[f64],
310    n: usize,
311    m: usize,
312    centers: &[f64],
313    assignments: &[usize],
314    k: usize,
315    weights: &[f64],
316) -> Vec<f64> {
317    let mut withinss = vec![0.0; k];
318    for i in 0..n {
319        let c = assignments[i];
320        let dist = l2_distance(
321            &curves[i * m..(i + 1) * m],
322            &centers[c * m..(c + 1) * m],
323            weights,
324        );
325        withinss[c] += dist * dist;
326    }
327    withinss
328}
329
330/// Update fuzzy c-means cluster centers from membership values.
331fn update_fuzzy_centers(
332    curves: &[f64],
333    n: usize,
334    m: usize,
335    membership: &FdMatrix,
336    k: usize,
337    fuzziness: f64,
338) -> Vec<f64> {
339    let mut centers = vec![0.0; k * m];
340    for c in 0..k {
341        let mut denominator = 0.0;
342        let center = &mut centers[c * m..(c + 1) * m];
343
344        for i in 0..n {
345            let weight = membership[(i, c)].powf(fuzziness);
346            let curve = &curves[i * m..(i + 1) * m];
347            for j in 0..m {
348                center[j] += weight * curve[j];
349            }
350            denominator += weight;
351        }
352
353        if denominator > NUMERICAL_EPS {
354            for j in 0..m {
355                center[j] /= denominator;
356            }
357        }
358    }
359    centers
360}
361
362/// Update fuzzy membership values and compute max change.
363fn update_fuzzy_membership_step(
364    curves: &[f64],
365    n: usize,
366    m: usize,
367    centers: &[f64],
368    k: usize,
369    old_membership: &FdMatrix,
370    exponent: f64,
371    weights: &[f64],
372) -> (FdMatrix, f64) {
373    let mut new_membership = FdMatrix::zeros(n, k);
374    let mut max_change = 0.0;
375    let mut distances = vec![0.0; k];
376    let mut memberships = vec![0.0; k];
377
378    for i in 0..n {
379        let curve = &curves[i * m..(i + 1) * m];
380        for c in 0..k {
381            distances[c] = l2_distance(curve, &centers[c * m..(c + 1) * m], weights);
382        }
383
384        compute_fuzzy_membership_into(&distances, k, exponent, &mut memberships);
385
386        for c in 0..k {
387            new_membership[(i, c)] = memberships[c];
388            let change = (memberships[c] - old_membership[(i, c)]).abs();
389            if change > max_change {
390                max_change = change;
391            }
392        }
393    }
394
395    (new_membership, max_change)
396}
397
398/// Compute mean L2 distance from a curve to a set of curve indices.
399fn mean_cluster_distance(
400    curve: &[f64],
401    curves: &[f64],
402    m: usize,
403    indices: &[usize],
404    weights: &[f64],
405) -> f64 {
406    if indices.is_empty() {
407        return 0.0;
408    }
409    let sum: f64 = indices
410        .iter()
411        .map(|&j| l2_distance(curve, &curves[j * m..(j + 1) * m], weights))
412        .sum();
413    sum / indices.len() as f64
414}
415
/// Compute per-cluster mean curves, the grand mean curve, and cluster sizes.
///
/// Returns `(centers, global_mean, counts)`. `centers` is a flat row-major
/// `k * m` buffer in which empty clusters stay all-zero. `global_mean` has
/// `m` values, and `counts[c]` is the number of curves labelled `c`.
fn compute_centers_and_global_mean(
    curves: &[f64],
    n: usize,
    m: usize,
    assignments: &[usize],
    k: usize,
) -> (Vec<f64>, Vec<f64>, Vec<usize>) {
    let mut grand = vec![0.0; m];
    let mut sums = vec![0.0; k * m];
    let mut sizes = vec![0usize; k];

    // A single pass feeds both the overall and the per-cluster accumulators;
    // each accumulator still receives its terms in observation order.
    for i in 0..n {
        let row = &curves[i * m..(i + 1) * m];
        let label = assignments[i];
        sizes[label] += 1;
        let bucket = &mut sums[label * m..(label + 1) * m];
        for ((g, b), &x) in grand.iter_mut().zip(bucket.iter_mut()).zip(row) {
            *g += x;
            *b += x;
        }
    }

    let n_f = n as f64;
    grand.iter_mut().for_each(|g| *g /= n_f);

    for (label, &size) in sizes.iter().enumerate() {
        if size > 0 {
            let denom = size as f64;
            sums[label * m..(label + 1) * m]
                .iter_mut()
                .for_each(|v| *v /= denom);
        }
    }

    (sums, grand, sizes)
}
457
458/// Run one k-means iteration: assign clusters, update centers, compute movement.
459fn kmeans_step(
460    curves: &[f64],
461    n: usize,
462    m: usize,
463    centers: &[f64],
464    k: usize,
465    weights: &[f64],
466) -> (Vec<usize>, Vec<f64>, f64) {
467    let new_cluster = assign_clusters(curves, n, m, centers, k, weights);
468    let new_centers = update_kmeans_centers(curves, n, m, &new_cluster, centers, k);
469    let max_movement = (0..k)
470        .map(|c| {
471            l2_distance(
472                &centers[c * m..(c + 1) * m],
473                &new_centers[c * m..(c + 1) * m],
474                weights,
475            )
476        })
477        .fold(0.0, f64::max);
478    (new_cluster, new_centers, max_movement)
479}
480
481/// Run the k-means iteration loop until convergence or max iterations.
482fn kmeans_iterate(
483    curves: &[f64],
484    n: usize,
485    m: usize,
486    mut centers: Vec<f64>,
487    k: usize,
488    weights: &[f64],
489    max_iter: usize,
490    tol: f64,
491) -> (Vec<usize>, Vec<f64>, usize, bool) {
492    let mut cluster = vec![0usize; n];
493    let mut converged = false;
494    let mut iter = 0;
495
496    for iteration in 0..max_iter {
497        iter = iteration + 1;
498        let (new_cluster, new_centers, max_movement) =
499            kmeans_step(curves, n, m, &centers, k, weights);
500
501        if new_cluster == cluster {
502            converged = true;
503            break;
504        }
505        cluster = new_cluster;
506        centers = new_centers;
507
508        if max_movement < tol {
509            converged = true;
510            break;
511        }
512    }
513
514    (cluster, centers, iter, converged)
515}
516
517/// K-means clustering for functional data.
518///
519/// # Arguments
520/// * `data` - Functional data matrix (n x m)
521/// * `argvals` - Evaluation points
522/// * `k` - Number of clusters
523/// * `max_iter` - Maximum iterations
524/// * `tol` - Convergence tolerance
525/// * `seed` - Random seed
526///
527/// # Examples
528///
529/// ```
530/// use fdars_core::matrix::FdMatrix;
531/// use fdars_core::clustering::kmeans_fd;
532///
533/// // 10 curves at 20 evaluation points
534/// let argvals: Vec<f64> = (0..20).map(|i| i as f64 / 19.0).collect();
535/// let data = FdMatrix::from_column_major(
536///     (0..200).map(|i| (i as f64 * 0.1).sin()).collect(),
537///     10, 20,
538/// ).unwrap();
539/// let result = kmeans_fd(&data, &argvals, 2, 100, 1e-6, 42).unwrap();
540/// assert_eq!(result.cluster.len(), 10);
541/// assert_eq!(result.centers.nrows(), 2);
542/// assert!(result.converged);
543/// ```
544#[must_use = "expensive computation whose result should not be discarded"]
545pub fn kmeans_fd(
546    data: &FdMatrix,
547    argvals: &[f64],
548    k: usize,
549    max_iter: usize,
550    tol: f64,
551    seed: u64,
552) -> Result<KmeansResult, FdarError> {
553    let n = data.nrows();
554    let m = data.ncols();
555
556    if n == 0 || m == 0 {
557        return Err(FdarError::InvalidDimension {
558            parameter: "data",
559            expected: "non-empty matrix".into(),
560            actual: format!("{n}x{m}"),
561        });
562    }
563    if k == 0 {
564        return Err(FdarError::InvalidParameter {
565            parameter: "k",
566            message: "number of clusters must be > 0".into(),
567        });
568    }
569    if k > n {
570        return Err(FdarError::InvalidParameter {
571            parameter: "k",
572            message: format!("k={k} exceeds number of observations n={n}"),
573        });
574    }
575    if argvals.len() != m {
576        return Err(FdarError::InvalidDimension {
577            parameter: "argvals",
578            expected: format!("{m}"),
579            actual: format!("{}", argvals.len()),
580        });
581    }
582
583    let weights = simpsons_weights(argvals);
584    let mut rng = StdRng::seed_from_u64(seed);
585
586    // Extract curves as flat row-major buffer
587    let curves = data.to_row_major();
588
589    // K-means++ initialization
590    let centers = kmeans_plusplus_init(&curves, n, m, k, &weights, &mut rng);
591
592    let (cluster, centers, iter, converged) =
593        kmeans_iterate(&curves, n, m, centers, k, &weights, max_iter, tol);
594
595    let withinss = compute_within_ss(&curves, n, m, &centers, &cluster, k, &weights);
596    let tot_withinss: f64 = withinss.iter().sum();
597    let centers_mat = centers_to_matrix(&centers, k, m);
598
599    Ok(KmeansResult {
600        cluster,
601        centers: centers_mat,
602        withinss,
603        tot_withinss,
604        iter,
605        converged,
606    })
607}
608
609/// Result of fuzzy c-means clustering.
610#[derive(Debug, Clone, PartialEq)]
611#[non_exhaustive]
612pub struct FuzzyCmeansResult {
613    /// Membership matrix (n x k)
614    pub membership: FdMatrix,
615    /// Cluster centers (k x m)
616    pub centers: FdMatrix,
617    /// Fuzziness parameter used during fitting
618    pub fuzziness: f64,
619    /// Number of iterations
620    pub iter: usize,
621    /// Whether the algorithm converged
622    pub converged: bool,
623}
624
625impl FuzzyCmeansResult {
626    /// Compute fuzzy membership values for new observations.
627    ///
628    /// For each new observation, computes the weighted L2 distance to each
629    /// cluster center and derives fuzzy membership values using the same
630    /// fuzziness parameter used during fitting.
631    ///
632    /// Each row of the returned matrix sums to 1.0, with values in \[0, 1\]
633    /// indicating the degree of membership in each cluster.
634    ///
635    /// # Arguments
636    /// * `data` - Matrix (n_new x m) of new observations
637    /// * `argvals` - Evaluation points (length m)
638    ///
639    /// # Errors
640    ///
641    /// Returns [`FdarError::InvalidDimension`] if the number of columns in
642    /// `data` does not match the number of columns in the cluster centers, or
643    /// if `argvals.len()` does not match the number of columns.
644    ///
645    /// # Examples
646    ///
647    /// ```
648    /// use fdars_core::matrix::FdMatrix;
649    /// use fdars_core::clustering::fuzzy_cmeans_fd;
650    ///
651    /// let argvals: Vec<f64> = (0..20).map(|i| i as f64 / 19.0).collect();
652    /// let data = FdMatrix::from_column_major(
653    ///     (0..200).map(|i| (i as f64 * 0.1).sin()).collect(),
654    ///     10, 20,
655    /// ).unwrap();
656    /// let result = fuzzy_cmeans_fd(&data, &argvals, 2, 2.0, 100, 1e-6, 42).unwrap();
657    ///
658    /// // Predict membership for new data
659    /// let new_data = FdMatrix::from_column_major(
660    ///     (0..60).map(|i| (i as f64 * 0.1).sin()).collect(),
661    ///     3, 20,
662    /// ).unwrap();
663    /// let membership = result.predict(&new_data, &argvals).unwrap();
664    /// assert_eq!(membership.shape(), (3, 2));
665    /// // Each row should sum to 1
666    /// for i in 0..3 {
667    ///     let sum: f64 = (0..2).map(|c| membership[(i, c)]).sum();
668    ///     assert!((sum - 1.0).abs() < 1e-6);
669    /// }
670    /// ```
671    pub fn predict(&self, data: &FdMatrix, argvals: &[f64]) -> Result<FdMatrix, FdarError> {
672        let (n, m) = data.shape();
673        let m_centers = self.centers.ncols();
674        let k = self.centers.nrows();
675        if m != m_centers {
676            return Err(FdarError::InvalidDimension {
677                parameter: "data",
678                expected: format!("{m_centers} columns"),
679                actual: format!("{m} columns"),
680            });
681        }
682        if argvals.len() != m {
683            return Err(FdarError::InvalidDimension {
684                parameter: "argvals",
685                expected: format!("{m}"),
686                actual: format!("{}", argvals.len()),
687            });
688        }
689
690        let weights = simpsons_weights(argvals);
691        let curves = data.to_row_major();
692        let centers = self.centers.to_row_major();
693        let exponent = 2.0 / (self.fuzziness - 1.0);
694
695        let mut membership = FdMatrix::zeros(n, k);
696        let mut distances = vec![0.0; k];
697        let mut memberships = vec![0.0; k];
698
699        for i in 0..n {
700            let curve = &curves[i * m..(i + 1) * m];
701            for c in 0..k {
702                distances[c] = l2_distance(curve, &centers[c * m..(c + 1) * m], &weights);
703            }
704
705            compute_fuzzy_membership_into(&distances, k, exponent, &mut memberships);
706
707            for c in 0..k {
708                membership[(i, c)] = memberships[c];
709            }
710        }
711
712        Ok(membership)
713    }
714}
715
716/// Fuzzy c-means clustering for functional data.
717///
718/// # Arguments
719/// * `data` - Functional data matrix (n x m)
720/// * `argvals` - Evaluation points
721/// * `k` - Number of clusters
722/// * `fuzziness` - Fuzziness parameter (> 1)
723/// * `max_iter` - Maximum iterations
724/// * `tol` - Convergence tolerance
725/// * `seed` - Random seed
726#[must_use = "expensive computation whose result should not be discarded"]
727pub fn fuzzy_cmeans_fd(
728    data: &FdMatrix,
729    argvals: &[f64],
730    k: usize,
731    fuzziness: f64,
732    max_iter: usize,
733    tol: f64,
734    seed: u64,
735) -> Result<FuzzyCmeansResult, FdarError> {
736    let n = data.nrows();
737    let m = data.ncols();
738
739    if n == 0 || m == 0 {
740        return Err(FdarError::InvalidDimension {
741            parameter: "data",
742            expected: "non-empty matrix".into(),
743            actual: format!("{n}x{m}"),
744        });
745    }
746    if k == 0 {
747        return Err(FdarError::InvalidParameter {
748            parameter: "k",
749            message: "number of clusters must be > 0".into(),
750        });
751    }
752    if k > n {
753        return Err(FdarError::InvalidParameter {
754            parameter: "k",
755            message: format!("k={k} exceeds number of observations n={n}"),
756        });
757    }
758    if argvals.len() != m {
759        return Err(FdarError::InvalidDimension {
760            parameter: "argvals",
761            expected: format!("{m}"),
762            actual: format!("{}", argvals.len()),
763        });
764    }
765    if fuzziness <= 1.0 {
766        return Err(FdarError::InvalidParameter {
767            parameter: "fuzziness",
768            message: format!("fuzziness must be > 1.0, got {fuzziness}"),
769        });
770    }
771
772    let weights = simpsons_weights(argvals);
773    let mut rng = StdRng::seed_from_u64(seed);
774
775    // Extract curves as flat row-major buffer
776    let curves = data.to_row_major();
777
778    let mut membership = init_random_membership(n, k, &mut rng);
779
780    let mut centers = vec![0.0; k * m];
781    let mut converged = false;
782    let mut iter = 0;
783    let exponent = 2.0 / (fuzziness - 1.0);
784
785    for iteration in 0..max_iter {
786        iter = iteration + 1;
787
788        centers = update_fuzzy_centers(&curves, n, m, &membership, k, fuzziness);
789
790        let (new_membership, max_change) = update_fuzzy_membership_step(
791            &curves,
792            n,
793            m,
794            &centers,
795            k,
796            &membership,
797            exponent,
798            &weights,
799        );
800
801        membership = new_membership;
802
803        if max_change < tol {
804            converged = true;
805            break;
806        }
807    }
808
809    let centers_mat = centers_to_matrix(&centers, k, m);
810
811    Ok(FuzzyCmeansResult {
812        membership,
813        centers: centers_mat,
814        fuzziness,
815        iter,
816        converged,
817    })
818}
819
820/// Compute silhouette score for clustering result.
821///
822/// # Examples
823///
824/// ```
825/// use fdars_core::matrix::FdMatrix;
826/// use fdars_core::clustering::silhouette_score;
827///
828/// let argvals: Vec<f64> = (0..10).map(|i| i as f64 / 9.0).collect();
829/// let data = FdMatrix::from_column_major(
830///     (0..60).map(|i| (i as f64 * 0.1).sin()).collect(),
831///     6, 10,
832/// ).unwrap();
833/// let cluster = vec![0, 0, 0, 1, 1, 1];
834/// let scores = silhouette_score(&data, &argvals, &cluster);
835/// assert_eq!(scores.len(), 6);
836/// // Silhouette scores are in [-1, 1]
837/// assert!(scores.iter().all(|&s| s >= -1.0 - 1e-10 && s <= 1.0 + 1e-10));
838/// ```
839#[must_use = "expensive computation whose result should not be discarded"]
840pub fn silhouette_score(data: &FdMatrix, argvals: &[f64], cluster: &[usize]) -> Vec<f64> {
841    let n = data.nrows();
842    let m = data.ncols();
843
844    if n == 0 || m == 0 || cluster.len() != n || argvals.len() != m {
845        return Vec::new();
846    }
847
848    let weights = simpsons_weights(argvals);
849    let curves = data.to_row_major();
850
851    let k = cluster.iter().copied().max().unwrap_or(0) + 1;
852    let members = cluster_member_indices(cluster, k);
853
854    iter_maybe_parallel!(0..n)
855        .map(|i| {
856            let my_cluster = cluster[i];
857            let curve_i = &curves[i * m..(i + 1) * m];
858
859            let same_indices: Vec<usize> = members[my_cluster]
860                .iter()
861                .copied()
862                .filter(|&j| j != i)
863                .collect();
864            let a_i = mean_cluster_distance(curve_i, &curves, m, &same_indices, &weights);
865
866            let mut b_i = f64::INFINITY;
867            for c in 0..k {
868                if c != my_cluster && !members[c].is_empty() {
869                    b_i = b_i.min(mean_cluster_distance(
870                        curve_i,
871                        &curves,
872                        m,
873                        &members[c],
874                        &weights,
875                    ));
876                }
877            }
878
879            if b_i.is_infinite() {
880                0.0
881            } else {
882                let max_ab = a_i.max(b_i);
883                if max_ab > NUMERICAL_EPS {
884                    (b_i - a_i) / max_ab
885                } else {
886                    0.0
887                }
888            }
889        })
890        .collect()
891}
892
893/// Silhouette score from a precomputed distance matrix.
894///
895/// Works with any distance matrix (elastic, DTW, Lp, or custom).
896#[must_use = "expensive computation whose result should not be discarded"]
897pub fn silhouette_score_from_distances(dist_mat: &FdMatrix, cluster: &[usize]) -> Vec<f64> {
898    let n = dist_mat.nrows();
899    if n == 0 || dist_mat.ncols() != n || cluster.len() != n {
900        return Vec::new();
901    }
902
903    let k = cluster.iter().copied().max().unwrap_or(0) + 1;
904    let members = cluster_member_indices(cluster, k);
905
906    (0..n)
907        .map(|i| {
908            let my_cluster = cluster[i];
909
910            // a(i) = mean distance to same-cluster members
911            let same: Vec<usize> = members[my_cluster]
912                .iter()
913                .copied()
914                .filter(|&j| j != i)
915                .collect();
916            let a_i = if same.is_empty() {
917                0.0
918            } else {
919                same.iter().map(|&j| dist_mat[(i, j)]).sum::<f64>() / same.len() as f64
920            };
921
922            // b(i) = min over other clusters of mean distance
923            let mut b_i = f64::INFINITY;
924            for c in 0..k {
925                if c != my_cluster && !members[c].is_empty() {
926                    let mean_d = members[c].iter().map(|&j| dist_mat[(i, j)]).sum::<f64>()
927                        / members[c].len() as f64;
928                    b_i = b_i.min(mean_d);
929                }
930            }
931
932            if b_i.is_infinite() {
933                0.0
934            } else {
935                let max_ab = a_i.max(b_i);
936                if max_ab > 1e-15 {
937                    (b_i - a_i) / max_ab
938                } else {
939                    0.0
940                }
941            }
942        })
943        .collect()
944}
945
946/// Compute Calinski-Harabasz index for clustering result.
947#[must_use = "expensive computation whose result should not be discarded"]
948pub fn calinski_harabasz(data: &FdMatrix, argvals: &[f64], cluster: &[usize]) -> f64 {
949    let n = data.nrows();
950    let m = data.ncols();
951
952    if n == 0 || m == 0 || cluster.len() != n || argvals.len() != m {
953        return 0.0;
954    }
955
956    let weights = simpsons_weights(argvals);
957    let curves = data.to_row_major();
958
959    let k = cluster.iter().copied().max().unwrap_or(0) + 1;
960    if k < 2 {
961        return 0.0;
962    }
963
964    let (centers, global_mean, counts) = compute_centers_and_global_mean(&curves, n, m, cluster, k);
965
966    let mut bgss = 0.0;
967    for c in 0..k {
968        let dist = l2_distance(&centers[c * m..(c + 1) * m], &global_mean, &weights);
969        bgss += counts[c] as f64 * dist * dist;
970    }
971
972    let wgss_vec = compute_within_ss(&curves, n, m, &centers, cluster, k, &weights);
973    let wgss: f64 = wgss_vec.iter().sum();
974
975    if wgss < NUMERICAL_EPS {
976        return f64::INFINITY;
977    }
978
979    (bgss / (k - 1) as f64) / (wgss / (n - k) as f64)
980}
981
982/// Calinski-Harabasz index from a precomputed distance matrix.
983///
984/// Uses the distance-based formulation: CH = [B/(k-1)] / [W/(n-k)]
985/// where B = total between-cluster distance, W = total within-cluster distance.
986#[must_use = "expensive computation whose result should not be discarded"]
987pub fn calinski_harabasz_from_distances(dist_mat: &FdMatrix, cluster: &[usize]) -> f64 {
988    let n = dist_mat.nrows();
989    if n == 0 || dist_mat.ncols() != n || cluster.len() != n {
990        return 0.0;
991    }
992
993    let k = cluster.iter().copied().max().unwrap_or(0) + 1;
994    if k < 2 || n <= k {
995        return 0.0;
996    }
997
998    // Total dispersion: sum of all pairwise squared distances
999    let total_disp: f64 = (0..n)
1000        .flat_map(|i| ((i + 1)..n).map(move |j| dist_mat[(i, j)].powi(2)))
1001        .sum::<f64>();
1002
1003    // Within-cluster dispersion
1004    let members = cluster_member_indices(cluster, k);
1005    let mut within = 0.0;
1006    for c in 0..k {
1007        let nc = members[c].len();
1008        if nc < 2 {
1009            continue;
1010        }
1011        for ii in 0..nc {
1012            for jj in (ii + 1)..nc {
1013                within += dist_mat[(members[c][ii], members[c][jj])].powi(2);
1014            }
1015        }
1016    }
1017
1018    let between = total_disp - within;
1019    // Normalize: account for cluster sizes
1020    let w_norm = within / (n - k) as f64;
1021    let b_norm = between / (k - 1) as f64;
1022
1023    if w_norm > 1e-15 {
1024        b_norm / w_norm
1025    } else {
1026        0.0
1027    }
1028}
1029
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;
    use std::f64::consts::PI;

    /// Generate two clearly separated clusters of curves as an FdMatrix
    fn generate_two_clusters(n_per_cluster: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let n = 2 * n_per_cluster;
        let mut col_major = vec![0.0; n * m];

        // Cluster 0: sine waves with low amplitude
        for i in 0..n_per_cluster {
            for (j, &ti) in t.iter().enumerate() {
                col_major[i + j * n] =
                    (2.0 * PI * ti).sin() + 0.1 * (i as f64 / n_per_cluster as f64);
            }
        }

        // Cluster 1: sine waves shifted up by 5
        for i in 0..n_per_cluster {
            for (j, &ti) in t.iter().enumerate() {
                col_major[(i + n_per_cluster) + j * n] =
                    (2.0 * PI * ti).sin() + 5.0 + 0.1 * (i as f64 / n_per_cluster as f64);
            }
        }

        (FdMatrix::from_column_major(col_major, n, m).unwrap(), t)
    }

    /// Build an n x n matrix of absolute differences between 1-D points.
    fn abs_distance_matrix(points: &[f64]) -> FdMatrix {
        let n = points.len();
        let mut dist = FdMatrix::zeros(n, n);
        for (i, &a) in points.iter().enumerate() {
            for (j, &b) in points.iter().enumerate() {
                dist[(i, j)] = (a - b).abs();
            }
        }
        dist
    }

    // ============== K-means tests ==============

    #[test]
    fn test_kmeans_fd_basic() {
        let m = 50;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42).unwrap();

        assert_eq!(result.cluster.len(), n);
        assert!(result.converged);
        assert!(result.iter > 0 && result.iter <= 100);
    }

    #[test]
    fn test_kmeans_fd_finds_clusters() {
        let m = 50;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42).unwrap();

        // First half should be one cluster, second half the other
        let cluster_0 = result.cluster[0];
        let cluster_1 = result.cluster[n_per];

        assert_ne!(cluster_0, cluster_1, "Clusters should be different");

        // Check that first half is in same cluster
        for i in 0..n_per {
            assert_eq!(result.cluster[i], cluster_0);
        }

        // Check that second half is in same cluster
        for i in n_per..n {
            assert_eq!(result.cluster[i], cluster_1);
        }
    }

    #[test]
    fn test_kmeans_fd_deterministic() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);

        let result1 = kmeans_fd(&data, &t, 2, 100, 1e-6, 42).unwrap();
        let result2 = kmeans_fd(&data, &t, 2, 100, 1e-6, 42).unwrap();

        // Same seed should give same results
        assert_eq!(result1.cluster, result2.cluster);
    }

    #[test]
    fn test_kmeans_fd_withinss() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);

        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42).unwrap();

        // Within-cluster sum of squares should be non-negative
        for &wss in &result.withinss {
            assert!(wss >= 0.0);
        }

        // Total should equal sum
        let sum: f64 = result.withinss.iter().sum();
        assert!((sum - result.tot_withinss).abs() < 1e-10);
    }

    #[test]
    fn test_kmeans_fd_centers_shape() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let k = 3;

        let result = kmeans_fd(&data, &t, k, 100, 1e-6, 42).unwrap();

        // Centers should be k x m matrix
        assert_eq!(result.centers.nrows(), k);
        assert_eq!(result.centers.ncols(), m);
    }

    #[test]
    fn test_kmeans_fd_invalid_input() {
        let t = uniform_grid(30);

        // Empty data
        let data = FdMatrix::zeros(0, 0);
        assert!(kmeans_fd(&data, &t, 2, 100, 1e-6, 42).is_err());

        // k > n
        let data = FdMatrix::zeros(5, 30);
        assert!(kmeans_fd(&data, &t, 10, 100, 1e-6, 42).is_err());
    }

    #[test]
    fn test_kmeans_fd_single_cluster() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let data = FdMatrix::zeros(n, m);

        let result = kmeans_fd(&data, &t, 1, 100, 1e-6, 42).unwrap();

        // All should be in cluster 0
        for &c in &result.cluster {
            assert_eq!(c, 0);
        }
    }

    // ============== Fuzzy C-means tests ==============

    #[test]
    fn test_fuzzy_cmeans_fd_basic() {
        let m = 50;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let result = fuzzy_cmeans_fd(&data, &t, 2, 2.0, 100, 1e-6, 42).unwrap();

        assert_eq!(result.membership.nrows(), n);
        assert_eq!(result.membership.ncols(), 2);
        assert!(result.iter > 0);
    }

    #[test]
    fn test_fuzzy_cmeans_fd_membership_sums_to_one() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;
        let k = 2;

        let result = fuzzy_cmeans_fd(&data, &t, k, 2.0, 100, 1e-6, 42).unwrap();

        // Each observation's membership should sum to 1
        for i in 0..n {
            let sum: f64 = (0..k).map(|c| result.membership[(i, c)]).sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "Membership should sum to 1, got {}",
                sum
            );
        }
    }

    #[test]
    fn test_fuzzy_cmeans_fd_membership_in_range() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);

        let result = fuzzy_cmeans_fd(&data, &t, 2, 2.0, 100, 1e-6, 42).unwrap();

        // All memberships should be in [0, 1]
        for &mem in result.membership.as_slice() {
            assert!((0.0..=1.0 + 1e-10).contains(&mem));
        }
    }

    #[test]
    fn test_fuzzy_cmeans_fd_fuzziness_effect() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);

        let result_low = fuzzy_cmeans_fd(&data, &t, 2, 1.5, 100, 1e-6, 42).unwrap();
        let result_high = fuzzy_cmeans_fd(&data, &t, 2, 3.0, 100, 1e-6, 42).unwrap();

        // Higher fuzziness should give more diffuse memberships
        // Measure by entropy-like metric
        let entropy_low: f64 = result_low
            .membership
            .as_slice()
            .iter()
            .map(|&m| if m > 1e-10 { -m * m.ln() } else { 0.0 })
            .sum();

        let entropy_high: f64 = result_high
            .membership
            .as_slice()
            .iter()
            .map(|&m| if m > 1e-10 { -m * m.ln() } else { 0.0 })
            .sum();

        assert!(
            entropy_high >= entropy_low - 0.1,
            "Higher fuzziness should give higher entropy"
        );
    }

    #[test]
    fn test_fuzzy_cmeans_fd_invalid_fuzziness() {
        let t = uniform_grid(30);
        let data = FdMatrix::zeros(10, 30);

        // Fuzziness <= 1 should fail
        assert!(fuzzy_cmeans_fd(&data, &t, 2, 1.0, 100, 1e-6, 42).is_err());
        assert!(fuzzy_cmeans_fd(&data, &t, 2, 0.5, 100, 1e-6, 42).is_err());
    }

    #[test]
    fn test_fuzzy_cmeans_fd_centers_shape() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let k = 3;
        let data = FdMatrix::zeros(n, m);

        let result = fuzzy_cmeans_fd(&data, &t, k, 2.0, 100, 1e-6, 42).unwrap();

        assert_eq!(result.centers.nrows(), k);
        assert_eq!(result.centers.ncols(), m);
    }

    // ============== Silhouette score tests ==============

    #[test]
    fn test_silhouette_score_well_separated() {
        let m = 30;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        // Perfect clustering: first half in 0, second in 1
        let cluster: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let scores = silhouette_score(&data, &t, &cluster);

        assert_eq!(scores.len(), n);

        // Well-separated clusters should have high silhouette scores
        let mean_score: f64 = scores.iter().sum::<f64>() / n as f64;
        assert!(
            mean_score > 0.5,
            "Well-separated clusters should have high silhouette: {}",
            mean_score
        );
    }

    #[test]
    fn test_silhouette_score_range() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let cluster: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let scores = silhouette_score(&data, &t, &cluster);

        // Silhouette scores should be in [-1, 1]
        for &s in &scores {
            assert!((-1.0 - 1e-10..=1.0 + 1e-10).contains(&s));
        }
    }

    #[test]
    fn test_silhouette_score_single_cluster() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let data = FdMatrix::zeros(n, m);

        // All in one cluster
        let cluster = vec![0usize; n];

        let scores = silhouette_score(&data, &t, &cluster);

        // Single cluster should give zeros
        for &s in &scores {
            assert!(s.abs() < 1e-10);
        }
    }

    #[test]
    fn test_silhouette_score_invalid_input() {
        let t = uniform_grid(30);

        // Empty data
        let data = FdMatrix::zeros(0, 0);
        let scores = silhouette_score(&data, &t, &[]);
        assert!(scores.is_empty());

        // Mismatched cluster length
        let data = FdMatrix::zeros(10, 30);
        let cluster = vec![0; 5]; // Wrong length
        let scores = silhouette_score(&data, &t, &cluster);
        assert!(scores.is_empty());
    }

    // ============== Calinski-Harabasz tests ==============

    #[test]
    fn test_calinski_harabasz_well_separated() {
        let m = 30;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let cluster: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let ch = calinski_harabasz(&data, &t, &cluster);

        // Well-separated clusters should have high CH index
        assert!(
            ch > 1.0,
            "Well-separated clusters should have high CH: {}",
            ch
        );
    }

    #[test]
    fn test_calinski_harabasz_positive() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let n = 2 * n_per;

        let cluster: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let ch = calinski_harabasz(&data, &t, &cluster);

        assert!(ch >= 0.0, "CH index should be non-negative");
    }

    #[test]
    fn test_calinski_harabasz_single_cluster() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        let data = FdMatrix::zeros(n, m);

        // All in one cluster
        let cluster = vec![0usize; n];

        let ch = calinski_harabasz(&data, &t, &cluster);

        // Single cluster should give 0
        assert!(ch.abs() < 1e-10);
    }

    #[test]
    fn test_calinski_harabasz_invalid_input() {
        let t = uniform_grid(30);

        // Empty data
        let data = FdMatrix::zeros(0, 0);
        let ch = calinski_harabasz(&data, &t, &[]);
        assert!(ch.abs() < 1e-10);
    }

    #[test]
    fn test_calinski_harabasz_argvals_mismatch() {
        // argvals length (20) disagrees with the number of columns (30)
        let data = FdMatrix::zeros(10, 30);
        let t = uniform_grid(20);
        let cluster: Vec<usize> = (0..10).map(|i| i % 2).collect();

        let ch = calinski_harabasz(&data, &t, &cluster);
        assert!(ch.abs() < 1e-10);
    }

    // ============== Calinski-Harabasz (distance matrix) tests ==============

    #[test]
    fn test_calinski_harabasz_from_distances_invalid_input() {
        // Empty
        let empty = FdMatrix::zeros(0, 0);
        assert!(calinski_harabasz_from_distances(&empty, &[]).abs() < 1e-10);

        // Non-square
        let rect = FdMatrix::zeros(3, 4);
        assert!(calinski_harabasz_from_distances(&rect, &[0, 1, 1]).abs() < 1e-10);

        // Mismatched cluster length
        let dist = abs_distance_matrix(&[0.0, 1.0, 2.0]);
        assert!(calinski_harabasz_from_distances(&dist, &[0, 1]).abs() < 1e-10);
    }

    #[test]
    fn test_calinski_harabasz_from_distances_degenerate() {
        let dist = abs_distance_matrix(&[0.0, 1.0, 2.0]);

        // Single cluster
        assert!(calinski_harabasz_from_distances(&dist, &[0, 0, 0]).abs() < 1e-10);

        // n == k: every observation in its own cluster
        assert!(calinski_harabasz_from_distances(&dist, &[0, 1, 2]).abs() < 1e-10);
    }

    #[test]
    fn test_calinski_harabasz_from_distances_well_separated() {
        let dist = abs_distance_matrix(&[0.0, 1.0, 10.0, 11.0]);
        let ch = calinski_harabasz_from_distances(&dist, &[0, 0, 1, 1]);

        assert!(
            ch > 1.0,
            "Well-separated clusters should have high CH: {}",
            ch
        );
    }

    #[test]
    fn test_identical_curves_kmeans() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 10;
        // All curves identical
        let data_vec: Vec<f64> = (0..n * m)
            .map(|idx| (2.0 * PI * (idx % m) as f64 / (m - 1) as f64).sin())
            .collect();
        let data = FdMatrix::from_column_major(data_vec, n, m).unwrap();
        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42).unwrap();
        // Should not panic with identical data
        assert_eq!(result.cluster.len(), n);
    }

    #[test]
    fn test_k_equals_n() {
        let m = 30;
        let t = uniform_grid(m);
        let n = 5;
        // `generate_two_clusters(n, m)` yields 2 * n curves, so k = 2 * n
        // equals the number of observations: each curve is its own cluster.
        let (data, _) = generate_two_clusters(n, m);
        let result = kmeans_fd(&data, &t, 2 * n, 100, 1e-6, 42).unwrap();
        assert_eq!(result.cluster.len(), 2 * n);
    }

    #[test]
    fn test_n2_kmeans() {
        let m = 30;
        let t = uniform_grid(m);
        let mut data = FdMatrix::zeros(2, m);
        for j in 0..m {
            data[(0, j)] = 0.0;
            data[(1, j)] = 10.0;
        }
        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42).unwrap();
        assert_eq!(result.cluster.len(), 2);
        assert_ne!(result.cluster[0], result.cluster[1]);
    }

    // ============== KmeansResult::predict tests ==============

    #[test]
    fn test_kmeans_predict_shape() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);

        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42).unwrap();

        let new_data = FdMatrix::zeros(3, m);
        let assignments = result.predict(&new_data, &t).unwrap();
        assert_eq!(assignments.len(), 3);
    }

    #[test]
    fn test_kmeans_predict_reproduces_training() {
        let m = 30;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);

        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42).unwrap();

        // Predicting on training data should reproduce original assignments
        let predicted = result.predict(&data, &t).unwrap();
        assert_eq!(predicted, result.cluster);
    }

    #[test]
    fn test_kmeans_predict_correct_cluster() {
        let m = 30;
        let n_per = 10;
        let (data, t) = generate_two_clusters(n_per, m);

        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42).unwrap();

        // Create a new curve clearly in cluster 0 (low amplitude sine)
        let mut new_data = FdMatrix::zeros(1, m);
        for j in 0..m {
            new_data[(0, j)] = (2.0 * PI * t[j]).sin();
        }
        let pred = result.predict(&new_data, &t).unwrap();
        assert_eq!(pred[0], result.cluster[0]); // same cluster as first training group

        // Create a new curve clearly in cluster 1 (shifted up by 5)
        let mut new_data2 = FdMatrix::zeros(1, m);
        for j in 0..m {
            new_data2[(0, j)] = (2.0 * PI * t[j]).sin() + 5.0;
        }
        let pred2 = result.predict(&new_data2, &t).unwrap();
        assert_eq!(pred2[0], result.cluster[n_per]); // same cluster as second group
    }

    #[test]
    fn test_kmeans_predict_dimension_mismatch() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let result = kmeans_fd(&data, &t, 2, 100, 1e-6, 42).unwrap();

        // Wrong number of columns
        let wrong_data = FdMatrix::zeros(3, 20);
        assert!(result.predict(&wrong_data, &t).is_err());

        // Wrong argvals length
        let new_data = FdMatrix::zeros(3, m);
        let wrong_t: Vec<f64> = (0..20).map(|i| i as f64 / 19.0).collect();
        assert!(result.predict(&new_data, &wrong_t).is_err());
    }

    // ============== FuzzyCmeansResult::predict tests ==============

    #[test]
    fn test_fuzzy_predict_shape() {
        let m = 30;
        let n_per = 5;
        let k = 2;
        let (data, t) = generate_two_clusters(n_per, m);

        let result = fuzzy_cmeans_fd(&data, &t, k, 2.0, 100, 1e-6, 42).unwrap();

        let new_data = FdMatrix::zeros(3, m);
        let membership = result.predict(&new_data, &t).unwrap();
        assert_eq!(membership.shape(), (3, k));
    }

    #[test]
    fn test_fuzzy_predict_membership_sums_to_one() {
        let m = 30;
        let n_per = 5;
        let k = 2;
        let (data, t) = generate_two_clusters(n_per, m);

        let result = fuzzy_cmeans_fd(&data, &t, k, 2.0, 100, 1e-6, 42).unwrap();

        let new_data = FdMatrix::zeros(4, m);
        let membership = result.predict(&new_data, &t).unwrap();

        for i in 0..4 {
            let sum: f64 = (0..k).map(|c| membership[(i, c)]).sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "Row {} membership should sum to 1, got {}",
                i,
                sum
            );
        }
    }

    #[test]
    fn test_fuzzy_predict_membership_in_range() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);

        let result = fuzzy_cmeans_fd(&data, &t, 2, 2.0, 100, 1e-6, 42).unwrap();

        let new_data = FdMatrix::zeros(4, m);
        let membership = result.predict(&new_data, &t).unwrap();

        for &mem in membership.as_slice() {
            assert!((0.0..=1.0 + 1e-10).contains(&mem));
        }
    }

    #[test]
    fn test_fuzzy_predict_reproduces_training() {
        let m = 30;
        let n_per = 5;
        let n = 2 * n_per;
        let k = 2;
        let (data, t) = generate_two_clusters(n_per, m);

        let result = fuzzy_cmeans_fd(&data, &t, k, 2.0, 100, 1e-6, 42).unwrap();

        // Predicting on training data should reproduce similar memberships
        let predicted = result.predict(&data, &t).unwrap();
        for i in 0..n {
            for c in 0..k {
                assert!(
                    (predicted[(i, c)] - result.membership[(i, c)]).abs() < 1e-4,
                    "Membership mismatch at ({}, {}): {} vs {}",
                    i,
                    c,
                    predicted[(i, c)],
                    result.membership[(i, c)]
                );
            }
        }
    }

    #[test]
    fn test_fuzzy_predict_dimension_mismatch() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);
        let result = fuzzy_cmeans_fd(&data, &t, 2, 2.0, 100, 1e-6, 42).unwrap();

        // Wrong number of columns
        let wrong_data = FdMatrix::zeros(3, 20);
        assert!(result.predict(&wrong_data, &t).is_err());

        // Wrong argvals length
        let new_data = FdMatrix::zeros(3, m);
        let wrong_t: Vec<f64> = (0..20).map(|i| i as f64 / 19.0).collect();
        assert!(result.predict(&new_data, &wrong_t).is_err());
    }

    #[test]
    fn test_fuzzy_predict_fuzziness_stored() {
        let m = 30;
        let n_per = 5;
        let (data, t) = generate_two_clusters(n_per, m);

        let result = fuzzy_cmeans_fd(&data, &t, 2, 2.5, 100, 1e-6, 42).unwrap();
        assert!((result.fuzziness - 2.5).abs() < 1e-10);
    }
}