1use crate::elastic_fpca::{sphere_karcher_mean, warps_to_normalized_psi};
8use crate::error::FdarError;
9use crate::matrix::FdMatrix;
10use crate::warping::{exp_map_sphere, inner_product_l2, inv_exp_map_sphere, l2_norm_l2};
11
12use super::KarcherMeanResult;
13
/// Result of a horizontal functional principal nested spheres (FPNS) analysis,
/// as produced by [`horiz_fpns`].
///
/// `ncomp` below is the number of components actually computed. This is the
/// requested count clamped to at most `n - 1` and `m`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct FpnsResult {
    /// Principal directions, one per row (`ncomp x m`). Row `k` is the direction
    /// extracted at step `k`, in the tangent space at `subsphere_means[k]`.
    pub components: FdMatrix,
    /// Score of each observation on each component (`n x ncomp`). Entry `(i, k)`
    /// is the L2 inner product of observation `i`'s shooting vector at step `k`
    /// with row `k` of `components`.
    pub scores: FdMatrix,
    /// Per-component sum of squared scores divided by `n - 1`.
    pub explained_variance: Vec<f64>,
    /// Sphere Karcher mean about which each step's shooting vectors were
    /// computed, one vector of length `m` per component.
    pub subsphere_means: Vec<Vec<f64>>,
}
29
30fn top_singular_vector(mat: &FdMatrix, n: usize, p: usize) -> Vec<f64> {
37 let mut u: Vec<f64> = if n > 0 { mat.row(0) } else { vec![1.0; p] };
39
40 let norm = u.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-15);
42 for v in &mut u {
43 *v /= norm;
44 }
45
46 for _ in 0..200 {
48 let mut w = vec![0.0; n];
50 for i in 0..n {
51 let mut s = 0.0;
52 for j in 0..p {
53 s += mat[(i, j)] * u[j];
54 }
55 w[i] = s;
56 }
57
58 let mut u_new = vec![0.0; p];
60 for j in 0..p {
61 let mut s = 0.0;
62 for i in 0..n {
63 s += mat[(i, j)] * w[i];
64 }
65 u_new[j] = s;
66 }
67
68 let new_norm = u_new.iter().map(|&v| v * v).sum::<f64>().sqrt();
70 if new_norm < 1e-15 {
71 break;
72 }
73 for v in &mut u_new {
74 *v /= new_norm;
75 }
76
77 let dot: f64 = u.iter().zip(u_new.iter()).map(|(&a, &b)| a * b).sum();
79 u = u_new;
80 if (1.0 - dot.abs()) < 1e-10 {
81 break;
82 }
83 }
84
85 u
86}
87
88#[must_use = "expensive computation whose result should not be discarded"]
105pub fn horiz_fpns(
106 karcher: &KarcherMeanResult,
107 argvals: &[f64],
108 ncomp: usize,
109) -> Result<FpnsResult, FdarError> {
110 let (n, m) = karcher.gammas.shape();
111 if n < 2 || m < 2 || ncomp < 1 || argvals.len() != m {
112 return Err(FdarError::InvalidDimension {
113 parameter: "gammas/argvals/ncomp",
114 expected: "n >= 2, m >= 2, ncomp >= 1, argvals.len() == m".to_string(),
115 actual: format!(
116 "n={n}, m={m}, ncomp={ncomp}, argvals.len()={}",
117 argvals.len()
118 ),
119 });
120 }
121 let ncomp = ncomp.min(n - 1).min(m);
122
123 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
124
125 let psis = warps_to_normalized_psi(&karcher.gammas, argvals);
127
128 let mut mu = sphere_karcher_mean(&psis, &time, 50);
130
131 let mut current_psis = psis;
133
134 let mut components = FdMatrix::zeros(ncomp, m);
135 let mut scores = FdMatrix::zeros(n, ncomp);
136 let mut explained_variance = Vec::with_capacity(ncomp);
137 let mut subsphere_means = Vec::with_capacity(ncomp);
138
139 for k in 0..ncomp {
140 let mut shooting = FdMatrix::zeros(n, m);
142 for i in 0..n {
143 let v = inv_exp_map_sphere(&mu, ¤t_psis[i], &time);
144 for j in 0..m {
145 shooting[(i, j)] = v[j];
146 }
147 }
148
149 let e_k = top_singular_vector(&shooting, n, m);
151
152 for j in 0..m {
154 components[(k, j)] = e_k[j];
155 }
156
157 let mut score_vec = vec![0.0; n];
159 for i in 0..n {
160 let v_i: Vec<f64> = (0..m).map(|j| shooting[(i, j)]).collect();
161 let s = inner_product_l2(&v_i, &e_k, &time);
162 scores[(i, k)] = s;
163 score_vec[i] = s;
164 }
165
166 let var_k = score_vec.iter().map(|s| s * s).sum::<f64>() / (n - 1) as f64;
168 explained_variance.push(var_k);
169
170 subsphere_means.push(mu.clone());
172
173 if k + 1 < ncomp {
176 let mut new_psis = Vec::with_capacity(n);
179 for i in 0..n {
180 let v_i: Vec<f64> = (0..m).map(|j| shooting[(i, j)]).collect();
181 let s = score_vec[i];
183 let v_perp: Vec<f64> = v_i
184 .iter()
185 .zip(e_k.iter())
186 .map(|(&v, &e)| v - s * e)
187 .collect();
188 let psi_new = exp_map_sphere(&mu, &v_perp, &time);
190 new_psis.push(psi_new);
191 }
192
193 mu = sphere_karcher_mean(&new_psis, &time, 50);
195 let mu_norm = l2_norm_l2(&mu, &time);
197 if mu_norm > 1e-10 {
198 for v in &mut mu {
199 *v /= mu_norm;
200 }
201 }
202
203 current_psis = new_psis;
204 }
205 }
206
207 Ok(FpnsResult {
208 components,
209 scores,
210 explained_variance,
211 subsphere_means,
212 })
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alignment::karcher_mean;
    use std::f64::consts::PI;

    /// Number of curves in the shared fixture.
    const N_CURVES: usize = 15;
    /// Number of grid points in the shared fixture.
    const N_POINTS: usize = 51;

    /// Generates `n` sine curves on a uniform `m`-point grid over `[0, 1]`.
    /// Each curve gets its own phase shift and amplitude, so alignment has
    /// non-trivial warps to estimate.
    fn generate_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            let shift = 0.1 * (i as f64 - n as f64 / 2.0);
            let scale = 1.0 + 0.2 * (i as f64 / n as f64);
            for j in 0..m {
                data[(i, j)] = scale * (2.0 * PI * (t[j] + shift)).sin();
            }
        }
        (data, t)
    }

    /// Aligns the shared fixture and returns the Karcher mean result with its grid.
    fn aligned_fixture() -> (KarcherMeanResult, Vec<f64>) {
        let (data, t) = generate_test_data(N_CURVES, N_POINTS);
        let km = karcher_mean(&data, &t, 10, 1e-4, 0.0);
        (km, t)
    }

    /// Every output has the documented shape for the requested component count.
    #[test]
    fn fpns_basic_dimensions() {
        let (km, t) = aligned_fixture();
        let ncomp = 3;
        let result = horiz_fpns(&km, &t, ncomp).expect("horiz_fpns should succeed");

        assert_eq!(result.scores.shape(), (N_CURVES, ncomp));
        assert_eq!(result.components.shape(), (ncomp, N_POINTS));
        assert_eq!(result.explained_variance.len(), ncomp);
        assert_eq!(result.subsphere_means.len(), ncomp);
        for mean in &result.subsphere_means {
            assert_eq!(mean.len(), N_POINTS);
        }
    }

    /// Explained variances are non-negative, and the first component is not
    /// dwarfed by the second. This is only a loose ordering check: each step
    /// recomputes the sphere mean, so variances are measured about different points.
    #[test]
    fn fpns_variance_decreasing() {
        let (km, t) = aligned_fixture();
        let result = horiz_fpns(&km, &t, 3).expect("horiz_fpns should succeed");

        for ev in &result.explained_variance {
            assert!(
                *ev >= -1e-10,
                "Explained variance should be non-negative: {ev}"
            );
        }

        if let [first, second, ..] = result.explained_variance[..] {
            assert!(
                first >= second * 0.5,
                "First component should capture substantial variance: {} vs {}",
                first,
                second
            );
        }
    }

    /// Each subsphere mean lies (approximately) on the unit L2 sphere of the grid.
    #[test]
    fn fpns_subsphere_means_on_sphere() {
        let (km, t) = aligned_fixture();
        let result = horiz_fpns(&km, &t, 3).expect("horiz_fpns should succeed");

        for (k, mean) in result.subsphere_means.iter().enumerate() {
            let norm = l2_norm_l2(mean, &t);
            assert!(
                (norm - 1.0).abs() < 0.1,
                "Subsphere mean {k} should have approximately unit L2 norm, got {norm}"
            );
        }
    }

    /// The single-component path (no projection step) runs and yields one of everything.
    #[test]
    fn fpns_ncomp_one_smoke() {
        let (km, t) = aligned_fixture();
        let result = horiz_fpns(&km, &t, 1).expect("horiz_fpns should succeed with ncomp=1");

        assert_eq!(result.scores.shape(), (N_CURVES, 1));
        assert_eq!(result.components.shape(), (1, N_POINTS));
        assert_eq!(result.explained_variance.len(), 1);
        assert_eq!(result.subsphere_means.len(), 1);
        assert!(result.explained_variance[0] >= 0.0);
    }
}