// fdars_core/alignment/nd.rs
1//! Multidimensional (R^d) SRSF transforms and elastic alignment.
2
3use super::srsf::reparameterize_curve;
4use super::{
5    dp_alignment_core, dp_edge_weight, dp_grid_solve, dp_lambda_penalty, dp_path_to_gamma,
6};
7use crate::error::FdarError;
8use crate::helpers::{cumulative_trapz, l2_distance, simpsons_weights};
9use crate::iter_maybe_parallel;
10use crate::matrix::{FdCurveSet, FdMatrix};
11#[cfg(feature = "parallel")]
12use rayon::iter::ParallelIterator;
13
/// Result of aligning multidimensional (R^d) curves.
///
/// Produced by [`elastic_align_pair_nd`]: a single warp `gamma` is shared
/// across all d dimensions, and `f_aligned` holds the warped second curve.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct AlignmentResultNd {
    /// Optimal warping function (length m), same for all dimensions.
    pub gamma: Vec<f64>,
    /// Aligned curve: d vectors, each length m.
    pub f_aligned: Vec<Vec<f64>>,
    /// Elastic distance after alignment (L2 between SRSFs, summed over dims).
    pub distance: f64,
}
25
26/// Scale derivative vector at one point by 1/√‖f'‖, writing into result_dims.
27#[inline]
28fn srsf_scale_point(derivs: &[FdMatrix], result_dims: &mut [FdMatrix], i: usize, j: usize) {
29    let d = derivs.len();
30    let norm_sq: f64 = derivs.iter().map(|dd| dd[(i, j)].powi(2)).sum();
31    let norm = norm_sq.sqrt();
32    if norm < 1e-15 {
33        for k in 0..d {
34            result_dims[k][(i, j)] = 0.0;
35        }
36    } else {
37        let scale = 1.0 / norm.sqrt();
38        for k in 0..d {
39            result_dims[k][(i, j)] = derivs[k][(i, j)] * scale;
40        }
41    }
42}
43
44/// Compute the SRSF transform for multidimensional (R^d) curves.
45///
46/// For f: \[0,1\] → R^d, the SRSF is q(t) = f'(t) / √‖f'(t)‖ where ‖·‖ is the
47/// Euclidean norm in R^d. For d=1 this reduces to `sign(f') · √|f'|`.
48///
49/// # Arguments
50/// * `data` — Set of n curves in R^d, each with m evaluation points
51/// * `argvals` — Evaluation points (length m)
52///
53/// # Returns
54/// `FdCurveSet` of SRSF values with the same shape as input.
55pub fn srsf_transform_nd(data: &FdCurveSet, argvals: &[f64]) -> FdCurveSet {
56    let d = data.ndim();
57    let n = data.ncurves();
58    let m = data.npoints();
59
60    if d == 0 || n == 0 || m == 0 || argvals.len() != m {
61        return FdCurveSet {
62            dims: (0..d).map(|_| FdMatrix::zeros(n, m)).collect(),
63        };
64    }
65
66    let derivs: Vec<FdMatrix> = data
67        .dims
68        .iter()
69        .map(|dim_mat| crate::fdata::deriv_1d(dim_mat, argvals, 1))
70        .collect();
71
72    let mut result_dims: Vec<FdMatrix> = (0..d).map(|_| FdMatrix::zeros(n, m)).collect();
73    for i in 0..n {
74        for j in 0..m {
75            srsf_scale_point(&derivs, &mut result_dims, i, j);
76        }
77    }
78
79    FdCurveSet { dims: result_dims }
80}
81
82/// Reconstruct an R^d curve from its SRSF.
83///
84/// Given d-dimensional SRSF vectors and initial point f0, reconstructs:
85/// `f_k(t) = f0_k + ∫₀ᵗ q_k(s) · ‖q(s)‖ ds` for each dimension k.
86///
87/// # Arguments
88/// * `q` — SRSF: d vectors, each length m
89/// * `argvals` — Evaluation points (length m)
90/// * `f0` — Initial values in R^d (length d)
91///
92/// # Returns
93/// Reconstructed curve: d vectors, each length m.
94pub fn srsf_inverse_nd(q: &[Vec<f64>], argvals: &[f64], f0: &[f64]) -> Vec<Vec<f64>> {
95    let d = q.len();
96    if d == 0 {
97        return Vec::new();
98    }
99    let m = q[0].len();
100    if m == 0 {
101        return vec![Vec::new(); d];
102    }
103
104    // Compute ||q(t)|| at each time point
105    let norms: Vec<f64> = (0..m)
106        .map(|j| {
107            let norm_sq: f64 = q.iter().map(|qk| qk[j].powi(2)).sum();
108            norm_sq.sqrt()
109        })
110        .collect();
111
112    // For each dimension, integrand = q_k(t) * ||q(t)||
113    let mut result = Vec::with_capacity(d);
114    for k in 0..d {
115        let integrand: Vec<f64> = (0..m).map(|j| q[k][j] * norms[j]).collect();
116        let integral = cumulative_trapz(&integrand, argvals);
117        let curve: Vec<f64> = integral.iter().map(|&v| f0[k] + v).collect();
118        result.push(curve);
119    }
120
121    result
122}
123
124/// Core DP alignment for R^d SRSFs.
125///
126/// Same DP grid and coprime neighborhood as `dp_alignment_core`, but edge weight
127/// is the sum of `dp_edge_weight` over d dimensions.
128fn dp_alignment_core_nd(
129    q1: &[Vec<f64>],
130    q2: &[Vec<f64>],
131    argvals: &[f64],
132    lambda: f64,
133) -> Vec<f64> {
134    let d = q1.len();
135    let m = argvals.len();
136    if m < 2 || d == 0 {
137        return argvals.to_vec();
138    }
139
140    // For d=1, delegate to existing implementation for exact backward compat
141    if d == 1 {
142        return dp_alignment_core(&q1[0], &q2[0], argvals, lambda);
143    }
144
145    // Normalize each dimension's SRSF to unit L2 norm
146    let q1n: Vec<Vec<f64>> = q1
147        .iter()
148        .map(|qk| {
149            let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
150            qk.iter().map(|&v| v / norm).collect()
151        })
152        .collect();
153    let q2n: Vec<Vec<f64>> = q2
154        .iter()
155        .map(|qk| {
156            let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
157            qk.iter().map(|&v| v / norm).collect()
158        })
159        .collect();
160
161    let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
162        let w: f64 = (0..d)
163            .map(|k| dp_edge_weight(&q1n[k], &q2n[k], argvals, sc, tc, sr, tr))
164            .sum();
165        w + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
166    });
167
168    dp_path_to_gamma(&path, argvals)
169}
170
171/// Align an R^d curve f2 to f1 using the elastic framework.
172///
173/// Finds the optimal warping γ (shared across all dimensions) such that
174/// f2∘γ is as close as possible to f1 in the elastic metric.
175///
176/// # Arguments
177/// * `f1` — Target curves (d dimensions)
178/// * `f2` — Curves to align (d dimensions)
179/// * `argvals` — Evaluation points (length m)
180/// * `lambda` — Penalty weight (0.0 = no penalty)
181pub fn elastic_align_pair_nd(
182    f1: &FdCurveSet,
183    f2: &FdCurveSet,
184    argvals: &[f64],
185    lambda: f64,
186) -> AlignmentResultNd {
187    let d = f1.ndim();
188    let m = f1.npoints();
189
190    // Compute SRSFs
191    let q1_set = srsf_transform_nd(f1, argvals);
192    let q2_set = srsf_transform_nd(f2, argvals);
193
194    // Extract first curve from each dimension
195    let q1: Vec<Vec<f64>> = q1_set.dims.iter().map(|dm| dm.row(0)).collect();
196    let q2: Vec<Vec<f64>> = q2_set.dims.iter().map(|dm| dm.row(0)).collect();
197
198    // DP alignment using summed cost over dimensions
199    let gamma = dp_alignment_core_nd(&q1, &q2, argvals, lambda);
200
201    // Apply warping to f2 in each dimension
202    let f_aligned: Vec<Vec<f64>> = f2
203        .dims
204        .iter()
205        .map(|dm| {
206            let row = dm.row(0);
207            reparameterize_curve(&row, argvals, &gamma)
208        })
209        .collect();
210
211    // Compute elastic distance: sum of squared L2 distances between aligned SRSFs
212    let f_aligned_set = {
213        let dims: Vec<FdMatrix> = f_aligned
214            .iter()
215            .map(|fa| {
216                FdMatrix::from_slice(fa, 1, m).expect("dimension invariant: data.len() == n * m")
217            })
218            .collect();
219        FdCurveSet { dims }
220    };
221    let q_aligned = srsf_transform_nd(&f_aligned_set, argvals);
222    let weights = simpsons_weights(argvals);
223
224    let mut dist_sq = 0.0;
225    for k in 0..d {
226        let q1k = q1_set.dims[k].row(0);
227        let qak = q_aligned.dims[k].row(0);
228        let d_k = l2_distance(&q1k, &qak, &weights);
229        dist_sq += d_k * d_k;
230    }
231
232    AlignmentResultNd {
233        gamma,
234        f_aligned,
235        distance: dist_sq.sqrt(),
236    }
237}
238
239/// Elastic distance between two R^d curves.
240///
241/// Aligns f2 to f1 and returns the post-alignment SRSF distance.
242pub fn elastic_distance_nd(f1: &FdCurveSet, f2: &FdCurveSet, argvals: &[f64], lambda: f64) -> f64 {
243    elastic_align_pair_nd(f1, f2, argvals, lambda).distance
244}
245
246// ─── Karcher Mean for N-d Curves ─────────────────────────────────────────
247
/// Result of the Karcher mean computation for multidimensional (R^d) curves.
///
/// Produced by [`karcher_mean_nd`]; consumed by [`karcher_covariance_nd`]
/// and [`pca_nd`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct KarcherMeanResultNd {
    /// Karcher mean curve: d vectors of length m.
    pub mean: Vec<Vec<f64>>,
    /// SRSF of the Karcher mean: d vectors of length m.
    pub mean_srsf: Vec<Vec<f64>>,
    /// Final warping functions (n x m), centered via `sqrt_mean_inverse`.
    pub gammas: FdMatrix,
    /// Curves aligned to the mean: d matrices, each n x m.
    pub aligned_data: Vec<FdMatrix>,
    /// Number of iterations used.
    pub n_iter: usize,
    /// Whether the relative SRSF change fell below the tolerance.
    pub converged: bool,
}
265
/// Result of PCA on aligned multidimensional (R^d) curves.
///
/// Produced by [`pca_nd`] from a [`KarcherMeanResultNd`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct PcaNdResult {
    /// PC scores (n x ncomp).
    pub scores: FdMatrix,
    /// Principal components per dimension: d matrices, each ncomp x m.
    pub components: Vec<FdMatrix>,
    /// Explained variance for each component.
    pub explained_variance: Vec<f64>,
    /// Cumulative proportion of variance explained.
    pub cumulative_variance: Vec<f64>,
    /// Covariance eigenvalues (same as explained_variance for convenience).
    pub covariance_eigenvalues: Vec<f64>,
}
281
282/// Compute SRSF for a single R^d curve (d vectors of length m).
283fn srsf_single_nd(curve: &[Vec<f64>], argvals: &[f64]) -> Vec<Vec<f64>> {
284    let m = argvals.len();
285    let dims: Vec<FdMatrix> = curve
286        .iter()
287        .map(|c| FdMatrix::from_slice(c, 1, m).expect("dimension invariant: data.len() == n * m"))
288        .collect();
289    let cs = FdCurveSet { dims };
290    let q_set = srsf_transform_nd(&cs, argvals);
291    q_set.dims.iter().map(|dm| dm.row(0)).collect()
292}
293
/// Compute the relative change between two N-d mean SRSFs.
///
/// Returns ‖new − old‖ / max(‖old‖, 1e-10), accumulated over all dimensions
/// and grid points (the floor avoids division by a zero old norm).
fn relative_change_nd(old: &[Vec<f64>], new: &[Vec<f64>]) -> f64 {
    let (diff_sq, old_sq) = old
        .iter()
        .zip(new.iter())
        .fold((0.0_f64, 0.0_f64), |acc, (qo, qn)| {
            qo.iter().zip(qn.iter()).fold(acc, |(ds, os), (&a, &b)| {
                (ds + (a - b) * (a - b), os + a * a)
            })
        });
    diff_sq.sqrt() / old_sq.sqrt().max(1e-10)
}
306
307/// Select the curve whose SRSF is closest to the pointwise mean SRSF.
308///
309/// Returns the index of the template curve.
310fn select_template_nd(data: &[FdCurveSet], srsfs: &[Vec<Vec<f64>>]) -> usize {
311    let n = data.len();
312    let d = srsfs[0].len();
313    let m = srsfs[0][0].len();
314
315    // Compute pointwise mean SRSF
316    let mut mean_q: Vec<Vec<f64>> = vec![vec![0.0; m]; d];
317    for q in srsfs {
318        for k in 0..d {
319            for j in 0..m {
320                mean_q[k][j] += q[k][j];
321            }
322        }
323    }
324    for k in 0..d {
325        for j in 0..m {
326            mean_q[k][j] /= n as f64;
327        }
328    }
329
330    // Find curve closest to mean
331    let mut min_dist = f64::INFINITY;
332    let mut min_idx = 0;
333    for (i, q) in srsfs.iter().enumerate() {
334        let mut dist_sq = 0.0;
335        for k in 0..d {
336            for j in 0..m {
337                dist_sq += (q[k][j] - mean_q[k][j]).powi(2);
338            }
339        }
340        if dist_sq < min_dist {
341            min_dist = dist_sq;
342            min_idx = i;
343        }
344    }
345    min_idx
346}
347
/// Compute the Karcher (Frechet) mean for multidimensional (R^d) curves.
///
/// Iteratively aligns all N-d curves to the current mean estimate in SRSF space,
/// computes the pointwise mean of aligned SRSFs per dimension, and reconstructs
/// the mean curve.
///
/// # Arguments
/// * `data` — Slice of n `FdCurveSet`s, each with d dimensions and m evaluation points
/// * `argvals` — Evaluation points (length m)
/// * `max_iter` — Maximum number of iterations
/// * `tol` — Convergence tolerance (relative SRSF change)
/// * `lambda` — Roughness penalty weight (0.0 = no penalty)
///
/// # Errors
/// Returns `FdarError::InvalidDimension` if inputs are inconsistent.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn karcher_mean_nd(
    data: &[FdCurveSet],
    argvals: &[f64],
    max_iter: usize,
    tol: f64,
    lambda: f64,
) -> Result<KarcherMeanResultNd, FdarError> {
    let n = data.len();
    // A mean over fewer than two curves is not meaningful.
    if n < 2 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "at least 2 curves".to_string(),
            actual: format!("{n}"),
        });
    }

    let d = data[0].ndim();
    let m = data[0].npoints();
    if d == 0 || m < 2 || argvals.len() != m {
        return Err(FdarError::InvalidDimension {
            parameter: "data/argvals",
            expected: format!("d > 0, m >= 2, argvals.len() == m (m={m})"),
            actual: format!("d={d}, m={m}, argvals.len()={}", argvals.len()),
        });
    }

    // Verify all curves have the same dimensions
    for (i, cs) in data.iter().enumerate() {
        if cs.ndim() != d || cs.npoints() != m {
            return Err(FdarError::InvalidDimension {
                parameter: "data",
                expected: format!("all curves d={d}, m={m}"),
                actual: format!("curve {i}: d={}, m={}", cs.ndim(), cs.npoints()),
            });
        }
    }

    // Extract curves as Vec<Vec<f64>> (d vectors) per observation
    let curves: Vec<Vec<Vec<f64>>> = (0..n)
        .map(|i| data[i].dims.iter().map(|dm| dm.row(0)).collect())
        .collect();

    // Compute SRSFs for all curves
    let srsfs: Vec<Vec<Vec<f64>>> = curves.iter().map(|c| srsf_single_nd(c, argvals)).collect();

    // Select template (closest to mean SRSF) as the initial mean estimate;
    // starting from an actual observation keeps the first iteration stable.
    let template_idx = select_template_nd(data, &srsfs);
    let mut mu_q = srsfs[template_idx].clone();
    let mut mu_f = curves[template_idx].clone();

    // Iterative alignment loop
    let mut converged = false;
    let mut n_iter = 0;
    let mut gammas = FdMatrix::zeros(n, m);

    for iter in 0..max_iter {
        n_iter = iter + 1;

        // Align all curves to current mean (parallel when the "parallel"
        // feature is enabled; see `iter_maybe_parallel!`)
        let align_results: Vec<(Vec<f64>, Vec<Vec<f64>>)> = iter_maybe_parallel!(0..n)
            .map(|i| {
                // Build single-curve FdCurveSet for mean and curve i
                let mean_cs = {
                    let dims: Vec<FdMatrix> = mu_f
                        .iter()
                        .map(|v| {
                            FdMatrix::from_slice(v, 1, m)
                                .expect("dimension invariant: data.len() == n * m")
                        })
                        .collect();
                    FdCurveSet { dims }
                };
                let curve_cs = {
                    let dims: Vec<FdMatrix> = curves[i]
                        .iter()
                        .map(|v| {
                            FdMatrix::from_slice(v, 1, m)
                                .expect("dimension invariant: data.len() == n * m")
                        })
                        .collect();
                    FdCurveSet { dims }
                };

                let result = elastic_align_pair_nd(&mean_cs, &curve_cs, argvals, lambda);
                (result.gamma, result.f_aligned)
            })
            .collect();

        // Store gammas and compute aligned SRSFs
        let mut new_mu_q: Vec<Vec<f64>> = vec![vec![0.0; m]; d];
        for (i, (gamma, f_aligned)) in align_results.iter().enumerate() {
            for j in 0..m {
                gammas[(i, j)] = gamma[j];
            }

            // Compute SRSF of aligned curve
            let q_aligned = srsf_single_nd(f_aligned, argvals);
            for k in 0..d {
                for j in 0..m {
                    new_mu_q[k][j] += q_aligned[k][j];
                }
            }
        }
        for k in 0..d {
            for j in 0..m {
                new_mu_q[k][j] /= n as f64;
            }
        }

        // Check convergence (relative change between old and new mean SRSF)
        let rel = relative_change_nd(&mu_q, &new_mu_q);
        mu_q = new_mu_q;

        // Reconstruct mean curve from mean SRSF, anchored at the previous
        // mean's starting point (the SRSF discards the additive constant).
        let f0: Vec<f64> = mu_f.iter().map(|v| v[0]).collect();
        mu_f = srsf_inverse_nd(&mu_q, argvals, &f0);

        if rel < tol {
            converged = true;
            break;
        }
    }

    // Post-centering: center the warps via sqrt_mean_inverse so the average
    // warp is (approximately) the identity
    let gam_inv = super::sqrt_mean_inverse(&gammas, argvals);
    for i in 0..n {
        let gam_i: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
        let gam_centered = reparameterize_curve(&gam_i, argvals, &gam_inv);
        for j in 0..m {
            gammas[(i, j)] = gam_centered[j];
        }
    }

    // Recompute aligned data using final centered warps
    let mut aligned_data: Vec<FdMatrix> = (0..d).map(|_| FdMatrix::zeros(n, m)).collect();
    for i in 0..n {
        let gamma_i: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
        for k in 0..d {
            let f_aligned = reparameterize_curve(&curves[i][k], argvals, &gamma_i);
            for j in 0..m {
                aligned_data[k][(i, j)] = f_aligned[j];
            }
        }
    }

    // Recompute mean from final aligned data
    let mut mean: Vec<Vec<f64>> = vec![vec![0.0; m]; d];
    for k in 0..d {
        for j in 0..m {
            for i in 0..n {
                mean[k][j] += aligned_data[k][(i, j)];
            }
            mean[k][j] /= n as f64;
        }
    }

    // Recompute mean SRSF
    let mean_srsf = srsf_single_nd(&mean, argvals);

    Ok(KarcherMeanResultNd {
        mean,
        mean_srsf,
        gammas,
        aligned_data,
        n_iter,
        converged,
    })
}
532
533/// Compute the cross-dimensional covariance matrix of aligned N-d curves.
534///
535/// Stacks all d dimensions of aligned curves into a single (n x d*m) matrix,
536/// centers columns, and returns X^T X / (n-1) as a (d*m x d*m) covariance matrix.
537///
538/// # Errors
539/// Returns `FdarError::InvalidDimension` if d*m exceeds 10000 (to prevent
540/// excessive memory usage) or if input dimensions are inconsistent.
541#[must_use = "expensive computation whose result should not be discarded"]
542pub fn karcher_covariance_nd(
543    result: &KarcherMeanResultNd,
544    argvals: &[f64],
545) -> Result<FdMatrix, FdarError> {
546    let d = result.aligned_data.len();
547    if d == 0 {
548        return Err(FdarError::InvalidDimension {
549            parameter: "aligned_data",
550            expected: "d > 0".to_string(),
551            actual: "0".to_string(),
552        });
553    }
554    let (n, m) = result.aligned_data[0].shape();
555    if argvals.len() != m {
556        return Err(FdarError::InvalidDimension {
557            parameter: "argvals",
558            expected: format!("{m}"),
559            actual: format!("{}", argvals.len()),
560        });
561    }
562
563    let dm = d * m;
564    if dm > 10_000 {
565        return Err(FdarError::InvalidParameter {
566            parameter: "d*m",
567            message: format!(
568                "d*m = {dm} exceeds limit of 10000; covariance matrix would be too large"
569            ),
570        });
571    }
572
573    if n < 2 {
574        return Err(FdarError::InvalidDimension {
575            parameter: "aligned_data",
576            expected: "n >= 2".to_string(),
577            actual: format!("{n}"),
578        });
579    }
580
581    // Build (n x dm) stacked matrix
582    let mut stacked = FdMatrix::zeros(n, dm);
583    for k in 0..d {
584        for i in 0..n {
585            for j in 0..m {
586                stacked[(i, k * m + j)] = result.aligned_data[k][(i, j)];
587            }
588        }
589    }
590
591    // Center columns
592    let mut col_mean = vec![0.0; dm];
593    for j in 0..dm {
594        for i in 0..n {
595            col_mean[j] += stacked[(i, j)];
596        }
597        col_mean[j] /= n as f64;
598    }
599    for i in 0..n {
600        for j in 0..dm {
601            stacked[(i, j)] -= col_mean[j];
602        }
603    }
604
605    // Compute covariance: X^T X / (n-1)
606    let nf = (n - 1) as f64;
607    let mut cov = FdMatrix::zeros(dm, dm);
608    for p in 0..dm {
609        for q in p..dm {
610            let mut s = 0.0;
611            for i in 0..n {
612                s += stacked[(i, p)] * stacked[(i, q)];
613            }
614            s /= nf;
615            cov[(p, q)] = s;
616            cov[(q, p)] = s;
617        }
618    }
619
620    Ok(cov)
621}
622
/// Perform PCA on aligned multidimensional (R^d) curves.
///
/// Stacks aligned data from all d dimensions into an (n x d*m) matrix, centers,
/// computes the SVD, and extracts principal components and scores.
///
/// # Arguments
/// * `result` — Pre-computed Karcher mean result for N-d curves
/// * `argvals` — Evaluation points (length m)
/// * `ncomp` — Number of principal components to extract
///
/// # Errors
/// Returns `FdarError` if inputs are invalid or SVD fails.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn pca_nd(
    result: &KarcherMeanResultNd,
    argvals: &[f64],
    ncomp: usize,
) -> Result<PcaNdResult, FdarError> {
    let d = result.aligned_data.len();
    if d == 0 {
        return Err(FdarError::InvalidDimension {
            parameter: "aligned_data",
            expected: "d > 0".to_string(),
            actual: "0".to_string(),
        });
    }
    let (n, m) = result.aligned_data[0].shape();
    if n < 2 || m < 2 || ncomp < 1 || argvals.len() != m {
        return Err(FdarError::InvalidDimension {
            parameter: "aligned_data/argvals/ncomp",
            expected: "n >= 2, m >= 2, ncomp >= 1, argvals.len() == m".to_string(),
            actual: format!(
                "n={n}, m={m}, ncomp={ncomp}, argvals.len()={}",
                argvals.len()
            ),
        });
    }
    // At most n - 1 non-trivial components remain after centering.
    let ncomp = ncomp.min(n - 1);
    let dm = d * m;

    // Build (n x dm) stacked matrix and center
    let mut stacked = FdMatrix::zeros(n, dm);
    for k in 0..d {
        for i in 0..n {
            for j in 0..m {
                stacked[(i, k * m + j)] = result.aligned_data[k][(i, j)];
            }
        }
    }

    // Center columns
    let mut col_mean = vec![0.0; dm];
    for j in 0..dm {
        for i in 0..n {
            col_mean[j] += stacked[(i, j)];
        }
        col_mean[j] /= n as f64;
    }
    for i in 0..n {
        for j in 0..dm {
            stacked[(i, j)] -= col_mean[j];
        }
    }

    // Economy SVD: compute Gram matrix G = X X^T / (n-1), size n x n
    // (much smaller than dm x dm when dm >> n)
    let nf = (n - 1) as f64;
    let mut gram = FdMatrix::zeros(n, n);
    for i in 0..n {
        for j in i..n {
            let mut s = 0.0;
            for p in 0..dm {
                s += stacked[(i, p)] * stacked[(j, p)];
            }
            s /= nf;
            gram[(i, j)] = s;
            gram[(j, i)] = s;
        }
    }

    // Eigen-decompose Gram matrix via nalgebra SVD (symmetric, so SVD = eigendecomposition)
    use nalgebra::SVD;
    let svd = SVD::new(gram.to_dmatrix(), true, true);
    let u = svd.u.as_ref().ok_or_else(|| FdarError::ComputationFailed {
        operation: "SVD",
        detail: "SVD failed to compute U matrix for Gram matrix".to_string(),
    })?;

    // Eigenvalues of Gram = singular values of Gram = eigenvalues of X X^T / (n-1)
    // These are also the eigenvalues of the covariance matrix (for the top n components)
    let eigenvalues: Vec<f64> = svd.singular_values.iter().take(ncomp).copied().collect();

    // Scores: score_ik = u_ik * sqrt(lambda_k * (n-1))
    // Since Gram = X X^T / (n-1), and SVD(Gram) = U S U^T,
    // the scores of the data are: X V = U * sqrt(S * (n-1))
    let mut scores = FdMatrix::zeros(n, ncomp);
    for k in 0..ncomp {
        let scale = (eigenvalues[k] * nf).sqrt();
        for i in 0..n {
            scores[(i, k)] = u[(i, k)] * scale;
        }
    }

    // Loadings: V_k = X^T U_k / sqrt(lambda_k * (n-1))
    // Reshape back into d matrices of (ncomp x m)
    let mut components: Vec<FdMatrix> = (0..d).map(|_| FdMatrix::zeros(ncomp, m)).collect();
    for k in 0..ncomp {
        // Scale is floored so a null eigenvalue does not divide by zero.
        let scale = (eigenvalues[k] * nf).sqrt().max(1e-15);
        let mut loading = vec![0.0; dm];
        for p in 0..dm {
            let mut s = 0.0;
            for i in 0..n {
                s += stacked[(i, p)] * u[(i, k)];
            }
            loading[p] = s / scale;
        }

        // Distribute into per-dimension matrices
        for dim in 0..d {
            for j in 0..m {
                components[dim][(k, j)] = loading[dim * m + j];
            }
        }
    }

    // Cumulative variance; the denominator sums over ALL Gram eigenvalues,
    // not just the ncomp kept ones, so proportions are relative to total.
    let total_var: f64 = svd.singular_values.iter().sum();
    let mut cumulative_variance = Vec::with_capacity(ncomp);
    let mut running = 0.0;
    for ev in &eigenvalues {
        running += ev;
        cumulative_variance.push(if total_var > 0.0 {
            running / total_var
        } else {
            0.0
        });
    }

    // Explained variance = eigenvalues
    let explained_variance = eigenvalues.clone();
    let covariance_eigenvalues = eigenvalues;

    Ok(PcaNdResult {
        scores,
        components,
        explained_variance,
        cumulative_variance,
        covariance_eigenvalues,
    })
}
773
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Build n identical R^2 curves (circle-like) as FdCurveSets.
    ///
    /// Returns the curve sets together with the shared uniform grid on [0, 1].
    fn make_identical_curves(n: usize, m: usize) -> (Vec<FdCurveSet>, Vec<f64>) {
        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
        let dim0: Vec<f64> = t.iter().map(|&ti| (2.0 * PI * ti).sin()).collect();
        let dim1: Vec<f64> = t.iter().map(|&ti| (2.0 * PI * ti).cos()).collect();

        let data: Vec<FdCurveSet> = (0..n)
            .map(|_| {
                let m0 = FdMatrix::from_slice(&dim0, 1, m)
                    .expect("dimension invariant: data.len() == n * m");
                let m1 = FdMatrix::from_slice(&dim1, 1, m)
                    .expect("dimension invariant: data.len() == n * m");
                FdCurveSet { dims: vec![m0, m1] }
            })
            .collect();
        (data, t)
    }

    /// Build n shifted R^2 sine curves.
    ///
    /// Curve i is phase-shifted by 0.05 * (i - n/2), so the family straddles
    /// the unshifted curve symmetrically.
    fn make_shifted_curves(n: usize, m: usize) -> (Vec<FdCurveSet>, Vec<f64>) {
        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
        let data: Vec<FdCurveSet> = (0..n)
            .map(|i| {
                let shift = 0.05 * (i as f64 - n as f64 / 2.0);
                let dim0: Vec<f64> = t
                    .iter()
                    .map(|&ti| (2.0 * PI * (ti + shift)).sin())
                    .collect();
                let dim1: Vec<f64> = t
                    .iter()
                    .map(|&ti| (2.0 * PI * (ti + shift)).cos())
                    .collect();
                let m0 = FdMatrix::from_slice(&dim0, 1, m)
                    .expect("dimension invariant: data.len() == n * m");
                let m1 = FdMatrix::from_slice(&dim1, 1, m)
                    .expect("dimension invariant: data.len() == n * m");
                FdCurveSet { dims: vec![m0, m1] }
            })
            .collect();
        (data, t)
    }

    #[test]
    fn karcher_mean_nd_identical_curves() {
        let (data, t) = make_identical_curves(5, 31);
        let result = karcher_mean_nd(&data, &t, 10, 1e-4, 0.0).expect("should succeed");

        let d = 2;
        let m = 31;

        // Mean should be close to the input curves
        let input_dim0: Vec<f64> = t.iter().map(|&ti| (2.0 * PI * ti).sin()).collect();
        let input_dim1: Vec<f64> = t.iter().map(|&ti| (2.0 * PI * ti).cos()).collect();

        let max_diff_0: f64 = result.mean[0]
            .iter()
            .zip(input_dim0.iter())
            .map(|(&a, &b)| (a - b).abs())
            .fold(0.0_f64, f64::max);
        let max_diff_1: f64 = result.mean[1]
            .iter()
            .zip(input_dim1.iter())
            .map(|(&a, &b)| (a - b).abs())
            .fold(0.0_f64, f64::max);

        // Loose tolerances: discretization and SRSF round-tripping introduce
        // some error even for identical inputs.
        assert!(
            max_diff_0 < 0.3,
            "Mean dim 0 should be close to input, max diff = {max_diff_0}"
        );
        assert!(
            max_diff_1 < 0.3,
            "Mean dim 1 should be close to input, max diff = {max_diff_1}"
        );

        // Gammas should be near-identity
        let n = 5;
        for i in 0..n {
            for j in 0..m {
                let diff = (result.gammas[(i, j)] - t[j]).abs();
                assert!(
                    diff < 0.15,
                    "Warp for identical curves should be near identity: gamma[{i},{j}] diff = {diff}"
                );
            }
        }

        // Correct number of dimensions
        assert_eq!(result.mean.len(), d);
        assert_eq!(result.mean_srsf.len(), d);
        assert_eq!(result.aligned_data.len(), d);
    }

    #[test]
    fn karcher_mean_nd_output_dimensions() {
        let (data, t) = make_shifted_curves(8, 25);
        let result = karcher_mean_nd(&data, &t, 5, 1e-3, 0.0).expect("should succeed");

        let n = 8;
        let m = 25;
        let d = 2;

        // Every output field must respect the documented (n, m, d) shapes.
        assert_eq!(result.mean.len(), d);
        assert_eq!(result.mean_srsf.len(), d);
        for k in 0..d {
            assert_eq!(result.mean[k].len(), m);
            assert_eq!(result.mean_srsf[k].len(), m);
        }
        assert_eq!(result.gammas.shape(), (n, m));
        assert_eq!(result.aligned_data.len(), d);
        for k in 0..d {
            assert_eq!(result.aligned_data[k].shape(), (n, m));
        }
        assert!(result.n_iter <= 5);
    }

    #[test]
    fn karcher_mean_nd_convergence() {
        let (data, t) = make_shifted_curves(10, 31);
        let result = karcher_mean_nd(&data, &t, 20, 1e-3, 0.0).expect("should succeed");

        // With well-behaved shifted sine curves, algorithm should converge
        assert!(
            result.converged,
            "Algorithm should converge for shifted sine curves, n_iter={}",
            result.n_iter
        );
    }

    #[test]
    fn pca_nd_basic_properties() {
        let (data, t) = make_shifted_curves(10, 31);
        let km = karcher_mean_nd(&data, &t, 10, 1e-3, 0.0).expect("karcher_mean should succeed");
        let pca = pca_nd(&km, &t, 3).expect("pca_nd should succeed");

        let n = 10;
        let ncomp = 3;
        let m = 31;

        // Scores shape
        assert_eq!(pca.scores.shape(), (n, ncomp));

        // Components shape: d=2, each ncomp x m
        assert_eq!(pca.components.len(), 2);
        for comp in &pca.components {
            assert_eq!(comp.shape(), (ncomp, m));
        }

        // Explained variance: non-negative
        for ev in &pca.explained_variance {
            assert!(
                *ev >= -1e-10,
                "Explained variance should be non-negative: {ev}"
            );
        }

        // Explained variance should be approximately decreasing
        // (singular values come out sorted; tolerance covers float noise)
        for i in 1..pca.explained_variance.len() {
            assert!(
                pca.explained_variance[i] <= pca.explained_variance[i - 1] + 1e-8,
                "Explained variance should be decreasing: {} > {}",
                pca.explained_variance[i],
                pca.explained_variance[i - 1]
            );
        }

        // Cumulative variance should be increasing
        for i in 1..pca.cumulative_variance.len() {
            assert!(
                pca.cumulative_variance[i] >= pca.cumulative_variance[i - 1] - 1e-10,
                "Cumulative variance should be increasing"
            );
        }
    }

    #[test]
    fn karcher_covariance_nd_symmetric() {
        let (data, t) = make_shifted_curves(8, 21);
        let km = karcher_mean_nd(&data, &t, 5, 1e-3, 0.0).expect("karcher_mean should succeed");
        let cov = karcher_covariance_nd(&km, &t).expect("covariance should succeed");

        let dm = 2 * 21;
        assert_eq!(cov.shape(), (dm, dm));

        // Verify symmetry
        for p in 0..dm {
            for q in p..dm {
                let diff = (cov[(p, q)] - cov[(q, p)]).abs();
                assert!(
                    diff < 1e-12,
                    "Covariance should be symmetric at ({p},{q}): diff = {diff}"
                );
            }
        }
    }
}
973}