// fdars_core/alignment/fpns.rs
//! Horizontal Functional Principal Nested Spheres (FPNS).
//!
//! Performs principal nested spheres analysis on warping functions
//! after elastic alignment. This provides a hierarchical decomposition
//! of phase variability by sequentially projecting onto great subspheres.

use crate::elastic_fpca::{sphere_karcher_mean, warps_to_normalized_psi};
use crate::error::FdarError;
use crate::matrix::FdMatrix;
use crate::warping::{exp_map_sphere, inner_product_l2, inv_exp_map_sphere, l2_norm_l2};

use super::KarcherMeanResult;

// ─── Types ──────────────────────────────────────────────────────────────────
/// Result of horizontal Functional Principal Nested Spheres (FPNS) analysis.
///
/// Produced by [`horiz_fpns`]. In the field docs below, `n` is the number of
/// observations, `ncomp` the number of extracted components, and `m_psi` the
/// number of evaluation points of the psi representation of the warps.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct FpnsResult {
    /// Principal direction components (ncomp x m_psi), one row per component.
    pub components: FdMatrix,
    /// Scores for each observation and component (n x ncomp).
    pub scores: FdMatrix,
    /// Explained variance for each component (length ncomp).
    pub explained_variance: Vec<f64>,
    /// Karcher mean on the subsphere at each level (ncomp vectors of length m_psi).
    pub subsphere_means: Vec<Vec<f64>>,
}
29
30// ─── Power Iteration ────────────────────────────────────────────────────────
31
32/// Find the top right singular vector of an (n x p) matrix via power iteration.
33///
34/// Computes the top eigenvector of V^T V without forming the full matrix.
35/// Returns the unit vector (length p).
36fn top_singular_vector(mat: &FdMatrix, n: usize, p: usize) -> Vec<f64> {
37    // Initialize with the first row (or a constant if degenerate)
38    let mut u: Vec<f64> = if n > 0 { mat.row(0) } else { vec![1.0; p] };
39
40    // Normalize initial vector
41    let norm = u.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-15);
42    for v in &mut u {
43        *v /= norm;
44    }
45
46    // Power iteration: u <- V^T (V u) / ||...||
47    for _ in 0..200 {
48        // Compute w = V u (length n)
49        let mut w = vec![0.0; n];
50        for i in 0..n {
51            let mut s = 0.0;
52            for j in 0..p {
53                s += mat[(i, j)] * u[j];
54            }
55            w[i] = s;
56        }
57
58        // Compute u_new = V^T w (length p)
59        let mut u_new = vec![0.0; p];
60        for j in 0..p {
61            let mut s = 0.0;
62            for i in 0..n {
63                s += mat[(i, j)] * w[i];
64            }
65            u_new[j] = s;
66        }
67
68        // Normalize
69        let new_norm = u_new.iter().map(|&v| v * v).sum::<f64>().sqrt();
70        if new_norm < 1e-15 {
71            break;
72        }
73        for v in &mut u_new {
74            *v /= new_norm;
75        }
76
77        // Check convergence: |1 - |u . u_new|| < tol
78        let dot: f64 = u.iter().zip(u_new.iter()).map(|(&a, &b)| a * b).sum();
79        u = u_new;
80        if (1.0 - dot.abs()) < 1e-10 {
81            break;
82        }
83    }
84
85    u
86}
// ─── Horizontal FPNS ────────────────────────────────────────────────────────

/// Perform horizontal Functional Principal Nested Spheres (FPNS) analysis.
///
/// Decomposes phase variability into nested principal directions on the
/// Hilbert sphere. Each component captures the dominant mode of variation
/// on the current subsphere, after which data is projected to a lower-dimensional
/// subsphere for subsequent components.
///
/// # Arguments
/// * `karcher` — Pre-computed Karcher mean result (with gammas)
/// * `argvals` — Evaluation points (length m)
/// * `ncomp` — Number of principal nested sphere components to extract;
///   silently capped at `min(n - 1, m)`.
///
/// # Errors
/// Returns `FdarError::InvalidDimension` if `n < 2`, `m < 2`, `ncomp < 1`,
/// or `argvals.len() != m`.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn horiz_fpns(
    karcher: &KarcherMeanResult,
    argvals: &[f64],
    ncomp: usize,
) -> Result<FpnsResult, FdarError> {
    let (n, m) = karcher.gammas.shape();
    if n < 2 || m < 2 || ncomp < 1 || argvals.len() != m {
        return Err(FdarError::InvalidDimension {
            parameter: "gammas/argvals/ncomp",
            expected: "n >= 2, m >= 2, ncomp >= 1, argvals.len() == m".to_string(),
            actual: format!(
                "n={n}, m={m}, ncomp={ncomp}, argvals.len()={}",
                argvals.len()
            ),
        });
    }
    // Cap at n - 1 (at most n - 1 directions from n points) and m (ambient dim).
    let ncomp = ncomp.min(n - 1).min(m);

    // Uniform [0, 1] grid used for all sphere geometry below.
    // NOTE(review): `argvals` is only passed to `warps_to_normalized_psi`;
    // every inner product / exp map integrates over this uniform grid instead —
    // confirm this is intentional (psi is presumably defined on [0, 1]).
    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();

    // Step 1: Convert warps to psi on Hilbert sphere
    let psis = warps_to_normalized_psi(&karcher.gammas, argvals);

    // Step 2: Compute Karcher mean on sphere (50 = max iterations)
    let mut mu = sphere_karcher_mean(&psis, &time, 50);

    // Working copy of psis that will be projected at each level
    let mut current_psis = psis;

    let mut components = FdMatrix::zeros(ncomp, m);
    let mut scores = FdMatrix::zeros(n, ncomp);
    let mut explained_variance = Vec::with_capacity(ncomp);
    let mut subsphere_means = Vec::with_capacity(ncomp);

    // One pass per nested component: find the principal direction in the
    // tangent space at the current mean, score against it, then project the
    // data onto the orthogonal subsphere for the next level.
    for k in 0..ncomp {
        // Step 3: Compute shooting (tangent) vectors from current mean.
        // Row i of `shooting` is the inverse exponential map of psi_i at mu.
        let mut shooting = FdMatrix::zeros(n, m);
        for i in 0..n {
            let v = inv_exp_map_sphere(&mu, &current_psis[i], &time);
            for j in 0..m {
                shooting[(i, j)] = v[j];
            }
        }

        // Step 4a: Find principal direction via power iteration on shooting vectors
        let e_k = top_singular_vector(&shooting, n, m);

        // Store component
        for j in 0..m {
            components[(k, j)] = e_k[j];
        }

        // Step 4d: Compute scores: score_ik = <v_i, e_k>_L2
        let mut score_vec = vec![0.0; n];
        for i in 0..n {
            let v_i: Vec<f64> = (0..m).map(|j| shooting[(i, j)]).collect();
            let s = inner_product_l2(&v_i, &e_k, &time);
            scores[(i, k)] = s;
            score_vec[i] = s;
        }

        // Step 4e: Explained variance (sample variance of the scores; the
        // scores are taken as centered since the tangent vectors are shot
        // from the Karcher mean).
        let var_k = score_vec.iter().map(|s| s * s).sum::<f64>() / (n - 1) as f64;
        explained_variance.push(var_k);

        // Store subsphere mean (the base point used for this level's scores)
        subsphere_means.push(mu.clone());

        // Step 4f-j: Project onto subsphere (remove component along e_k)
        // Only needed if there are more components to extract
        if k + 1 < ncomp {
            // For each psi_i, compute the perpendicular shooting vector
            // then map back to sphere for next iteration
            let mut new_psis = Vec::with_capacity(n);
            for i in 0..n {
                let v_i: Vec<f64> = (0..m).map(|j| shooting[(i, j)]).collect();
                // Remove component along e_k: v_perp = v_i - <v_i, e_k> * e_k
                let s = score_vec[i];
                let v_perp: Vec<f64> = v_i
                    .iter()
                    .zip(e_k.iter())
                    .map(|(&v, &e)| v - s * e)
                    .collect();
                // Map back to sphere from current mean
                let psi_new = exp_map_sphere(&mu, &v_perp, &time);
                new_psis.push(psi_new);
            }

            // Compute new Karcher mean on the subsphere
            mu = sphere_karcher_mean(&new_psis, &time, 50);
            // Normalize mu to unit sphere (guard avoids dividing by ~0 when
            // the mean degenerates numerically)
            let mu_norm = l2_norm_l2(&mu, &time);
            if mu_norm > 1e-10 {
                for v in &mut mu {
                    *v /= mu_norm;
                }
            }

            current_psis = new_psis;
        }
    }

    Ok(FpnsResult {
        components,
        scores,
        explained_variance,
        subsphere_means,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alignment::karcher_mean;
    use crate::matrix::FdMatrix;
    use std::f64::consts::PI;

    /// Build a synthetic sample: `n` sine curves on a uniform grid of `m`
    /// points, each with a small per-curve phase shift and amplitude scaling.
    fn generate_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            let shift = 0.1 * (i as f64 - n as f64 / 2.0);
            let scale = 1.0 + 0.2 * (i as f64 / n as f64);
            for (j, &tj) in t.iter().enumerate() {
                data[(i, j)] = scale * (2.0 * PI * (tj + shift)).sin();
            }
        }
        (data, t)
    }

    #[test]
    fn fpns_basic_dimensions() {
        let ncomp = 3;
        let (data, t) = generate_test_data(15, 51);
        let km = karcher_mean(&data, &t, 10, 1e-4, 0.0);
        let result = horiz_fpns(&km, &t, ncomp).expect("horiz_fpns should succeed");

        assert_eq!(result.scores.shape(), (15, ncomp));
        assert_eq!(result.components.shape(), (ncomp, 51));
        assert_eq!(result.explained_variance.len(), ncomp);
        assert_eq!(result.subsphere_means.len(), ncomp);
        assert!(result.subsphere_means.iter().all(|mean| mean.len() == 51));
    }

    #[test]
    fn fpns_variance_decreasing() {
        let (data, t) = generate_test_data(15, 51);
        let km = karcher_mean(&data, &t, 10, 1e-4, 0.0);
        let result = horiz_fpns(&km, &t, 3).expect("horiz_fpns should succeed");

        // Every explained variance must be non-negative (up to roundoff).
        for ev in &result.explained_variance {
            assert!(
                *ev >= -1e-10,
                "Explained variance should be non-negative: {ev}"
            );
        }

        // FPNS does not strictly order variances the way PCA does, so we only
        // require that the first component captures a substantial share
        // relative to the second (loose check for well-behaved data).
        if let [first, second, ..] = result.explained_variance[..] {
            assert!(
                first >= second * 0.5,
                "First component should capture substantial variance: {} vs {}",
                first,
                second
            );
        }
    }

    #[test]
    fn fpns_subsphere_means_on_sphere() {
        let (data, t) = generate_test_data(15, 51);
        let km = karcher_mean(&data, &t, 10, 1e-4, 0.0);
        let result = horiz_fpns(&km, &t, 3).expect("horiz_fpns should succeed");

        let time: Vec<f64> = (0..51).map(|i| i as f64 / 50.0).collect();

        for (k, mean) in result.subsphere_means.iter().enumerate() {
            let norm = l2_norm_l2(mean, &time);
            assert!(
                (norm - 1.0).abs() < 0.1,
                "Subsphere mean {k} should have approximately unit L2 norm, got {norm}"
            );
        }
    }

    #[test]
    fn fpns_ncomp_one_smoke() {
        let (data, t) = generate_test_data(15, 51);
        let km = karcher_mean(&data, &t, 10, 1e-4, 0.0);
        let result = horiz_fpns(&km, &t, 1).expect("horiz_fpns should succeed with ncomp=1");

        assert_eq!(result.scores.shape(), (15, 1));
        assert_eq!(result.components.shape(), (1, 51));
        assert_eq!(result.explained_variance.len(), 1);
        assert_eq!(result.subsphere_means.len(), 1);
        assert!(result.explained_variance[0] >= 0.0);
    }
}
308}