use crate::elastic_fpca::{sphere_karcher_mean, warps_to_normalized_psi};
use crate::error::FdarError;
use crate::matrix::FdMatrix;
use crate::warping::{exp_map_sphere, inner_product_l2, inv_exp_map_sphere, l2_norm_l2};
use super::KarcherMeanResult;
/// Output of [`horiz_fpns`]: principal directions, per-observation scores,
/// per-level variances and per-level base points computed from the warping
/// functions (`gammas`) of a [`KarcherMeanResult`].
///
/// `ncomp` below denotes the effective number of components, i.e. the
/// requested value clamped to `min(n - 1, m)`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct FpnsResult {
    /// Principal directions, one per row (`ncomp × m`); row `k` is the
    /// direction extracted at nesting level `k`.
    pub components: FdMatrix,
    /// Scores (`n × ncomp`): entry `(i, k)` is the L2 inner product of
    /// observation `i`'s shooting vector at level `k` with component `k`.
    pub scores: FdMatrix,
    /// Per-level variance: the sum of squared scores of that level divided by
    /// `n - 1`.
    pub explained_variance: Vec<f64>,
    /// The mean used as base point at each level (`ncomp` vectors of length `m`).
    pub subsphere_means: Vec<Vec<f64>>,
}
/// Dominant right singular vector of the `n × p` matrix `mat`, computed by
/// power iteration on `matᵀ mat`.
///
/// The iteration starts from the first row of `mat` whose Euclidean norm
/// exceeds `1e-15` (falling back to row 0 if there is none, or to a vector of
/// ones when `n == 0`). It stops after at most 200 steps: when the direction
/// changes by less than `1e-10` (`1 - |cos| < 1e-10`), or when `matᵀ mat u`
/// vanishes.
///
/// The returned vector has unit Euclidean norm unless `mat` is numerically
/// zero, in which case the (zero) start vector is returned. Its sign is
/// arbitrary.
fn top_singular_vector(mat: &FdMatrix, n: usize, p: usize) -> Vec<f64> {
    const MAX_ITER: usize = 200;
    const TOL: f64 = 1e-10;
    const TINY: f64 = 1e-15;

    fn euclid_norm(x: &[f64]) -> f64 {
        x.iter().map(|&v| v * v).sum::<f64>().sqrt()
    }

    // Starting from an all-zero row (e.g. an observation sitting exactly at
    // the base point) would keep the iterate at zero forever, so begin from
    // the first row that actually carries signal.
    let start = (0..n)
        .find(|&i| {
            (0..p)
                .map(|j| mat[(i, j)] * mat[(i, j)])
                .sum::<f64>()
                .sqrt()
                > TINY
        })
        .unwrap_or(0);
    let mut u: Vec<f64> = if n > 0 { mat.row(start) } else { vec![1.0; p] };
    let norm = euclid_norm(&u).max(TINY);
    for v in &mut u {
        *v /= norm;
    }

    for _ in 0..MAX_ITER {
        // w = mat · u
        let w: Vec<f64> = (0..n)
            .map(|i| (0..p).map(|j| mat[(i, j)] * u[j]).sum::<f64>())
            .collect();
        // u_new = matᵀ · w
        let mut u_new: Vec<f64> = (0..p)
            .map(|j| (0..n).map(|i| mat[(i, j)] * w[i]).sum::<f64>())
            .collect();

        let new_norm = euclid_norm(&u_new);
        if new_norm < TINY {
            break;
        }
        for v in &mut u_new {
            *v /= new_norm;
        }

        // Converged once successive iterates are (anti-)parallel.
        let cos: f64 = u.iter().zip(&u_new).map(|(&a, &b)| a * b).sum();
        u = u_new;
        if 1.0 - cos.abs() < TOL {
            break;
        }
    }
    u
}
/// Horizontal FPNS of the warping functions stored in a Karcher mean result.
///
/// The warps in `karcher.gammas` are converted by [`warps_to_normalized_psi`]
/// into psi functions on a uniform grid over `[0, 1]`. Then, for each level
/// `k = 0..ncomp`:
///
/// 1. the shooting vector `v_i` of every current point is computed at the
///    current base point `mu` with [`inv_exp_map_sphere`];
/// 2. the dominant direction of the `v_i`, rescaled to unit L2 norm, becomes
///    row `k` of `components`;
/// 3. the score of observation `i` is `<v_i, e_k>_L2`, and the level variance
///    is the sum of squared scores divided by `n - 1`;
/// 4. if another level follows, `e_k` is projected out of every `v_i`. The
///    residuals are mapped back with [`exp_map_sphere`], and `mu` becomes the
///    L2-normalised [`sphere_karcher_mean`] of those new points.
///
/// # Arguments
/// * `karcher` - alignment result; its `gammas` matrix (`n × m`) is analysed.
/// * `argvals` - evaluation grid of the warps; must have length `m`.
/// * `ncomp` - requested number of components, clamped to `min(n - 1, m)`.
///
/// # Errors
/// Returns [`FdarError::InvalidDimension`] when `n < 2`, `m < 2`, `ncomp < 1`
/// or `argvals.len() != m`.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn horiz_fpns(
    karcher: &KarcherMeanResult,
    argvals: &[f64],
    ncomp: usize,
) -> Result<FpnsResult, FdarError> {
    /// Iteration budget passed to `sphere_karcher_mean`.
    const SPHERE_MEAN_ITERS: usize = 50;

    let (n, m) = karcher.gammas.shape();
    if n < 2 || m < 2 || ncomp < 1 || argvals.len() != m {
        return Err(FdarError::InvalidDimension {
            parameter: "gammas/argvals/ncomp",
            expected: "n >= 2, m >= 2, ncomp >= 1, argvals.len() == m".to_string(),
            actual: format!(
                "n={n}, m={m}, ncomp={ncomp}, argvals.len()={}",
                argvals.len()
            ),
        });
    }
    let ncomp = ncomp.min(n - 1).min(m);
    // The psi representation is handled on a uniform grid over [0, 1].
    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();

    let mut current_psis = warps_to_normalized_psi(&karcher.gammas, argvals);
    let mut mu = sphere_karcher_mean(&current_psis, &time, SPHERE_MEAN_ITERS);
    // Treat every base point alike, including the first one.
    normalize_l2_in_place(&mut mu, &time);

    let mut components = FdMatrix::zeros(ncomp, m);
    let mut scores = FdMatrix::zeros(n, ncomp);
    let mut explained_variance = Vec::with_capacity(ncomp);
    let mut subsphere_means = Vec::with_capacity(ncomp);

    for k in 0..ncomp {
        // Row i holds the shooting vector from mu towards the i-th point.
        let mut shooting = FdMatrix::zeros(n, m);
        for (i, psi) in current_psis.iter().enumerate() {
            let v = inv_exp_map_sphere(&mu, psi, &time);
            for j in 0..m {
                shooting[(i, j)] = v[j];
            }
        }

        // `top_singular_vector` yields a Euclidean-unit vector, but the scores
        // and the projection below use the L2 inner product on `time`. Rescale
        // to unit L2 norm so that `v - <v, e_k> e_k` really removes the e_k
        // component (and the variances are on the L2 scale).
        let mut e_k = top_singular_vector(&shooting, n, m);
        normalize_l2_in_place(&mut e_k, &time);
        for (j, &e) in e_k.iter().enumerate() {
            components[(k, j)] = e;
        }

        let tangents: Vec<Vec<f64>> = (0..n).map(|i| shooting.row(i)).collect();
        let score_vec: Vec<f64> = tangents
            .iter()
            .map(|v_i| inner_product_l2(v_i, &e_k, &time))
            .collect();
        for (i, &s) in score_vec.iter().enumerate() {
            scores[(i, k)] = s;
        }

        let var_k = score_vec.iter().map(|s| s * s).sum::<f64>() / (n - 1) as f64;
        explained_variance.push(var_k);
        subsphere_means.push(mu.clone());

        if k + 1 < ncomp {
            // Remove the extracted direction and map the residuals back onto
            // the sphere to obtain the points for the next level.
            let new_psis: Vec<_> = tangents
                .iter()
                .zip(&score_vec)
                .map(|(v_i, &s)| {
                    let v_perp: Vec<f64> =
                        v_i.iter().zip(&e_k).map(|(&v, &e)| v - s * e).collect();
                    exp_map_sphere(&mu, &v_perp, &time)
                })
                .collect();
            mu = sphere_karcher_mean(&new_psis, &time, SPHERE_MEAN_ITERS);
            normalize_l2_in_place(&mut mu, &time);
            current_psis = new_psis;
        }
    }

    Ok(FpnsResult {
        components,
        scores,
        explained_variance,
        subsphere_means,
    })
}

/// Scales `v` in place to unit L2 norm on the grid `time`; vectors whose norm
/// is at most `1e-10` are left unchanged.
fn normalize_l2_in_place(v: &mut [f64], time: &[f64]) {
    let norm = l2_norm_l2(v, time);
    if norm > 1e-10 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alignment::karcher_mean;
    use crate::matrix::FdMatrix;
    use std::f64::consts::PI;

    /// Number of curves in the shared test sample.
    const N_CURVES: usize = 15;
    /// Number of grid points per curve in the shared test sample.
    const N_POINTS: usize = 51;

    /// `n` sine curves on a uniform `m`-point grid over [0, 1]. Curve `i` is
    /// shifted by `0.1 * (i - n / 2)` and scaled by `1 + 0.2 * i / n`.
    fn generate_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let denom = (m - 1) as f64;
        let grid: Vec<f64> = (0..m).map(|j| j as f64 / denom).collect();
        let mut curves = FdMatrix::zeros(n, m);
        for row in 0..n {
            let offset = 0.1 * (row as f64 - n as f64 / 2.0);
            let amplitude = 1.0 + 0.2 * (row as f64 / n as f64);
            for (col, &x) in grid.iter().enumerate() {
                curves[(row, col)] = amplitude * (2.0 * PI * (x + offset)).sin();
            }
        }
        (curves, grid)
    }

    /// Aligns the shared sample with `karcher_mean` and runs `horiz_fpns`
    /// with `ncomp` components, panicking with `msg` on failure.
    fn fit(ncomp: usize, msg: &str) -> FpnsResult {
        let (data, grid) = generate_test_data(N_CURVES, N_POINTS);
        let aligned = karcher_mean(&data, &grid, 10, 1e-4, 0.0);
        horiz_fpns(&aligned, &grid, ncomp).expect(msg)
    }

    #[test]
    fn fpns_basic_dimensions() {
        let ncomp = 3;
        let result = fit(ncomp, "horiz_fpns should succeed");
        assert_eq!(result.scores.shape(), (N_CURVES, ncomp));
        assert_eq!(result.components.shape(), (ncomp, N_POINTS));
        assert_eq!(result.explained_variance.len(), ncomp);
        assert_eq!(result.subsphere_means.len(), ncomp);
        for mean in &result.subsphere_means {
            assert_eq!(mean.len(), N_POINTS);
        }
    }

    #[test]
    fn fpns_variance_decreasing() {
        let result = fit(3, "horiz_fpns should succeed");
        let variances = &result.explained_variance;
        for &ev in variances {
            assert!(
                ev >= -1e-10,
                "Explained variance should be non-negative: {ev}"
            );
        }
        if let [first, second, ..] = variances.as_slice() {
            assert!(
                *first >= *second * 0.5,
                "First component should capture substantial variance: {} vs {}",
                first,
                second
            );
        }
    }

    #[test]
    fn fpns_subsphere_means_on_sphere() {
        let result = fit(3, "horiz_fpns should succeed");
        let time: Vec<f64> = (0..N_POINTS).map(|i| i as f64 / 50.0).collect();
        for (k, mean) in result.subsphere_means.iter().enumerate() {
            let norm = l2_norm_l2(mean, &time);
            assert!(
                (norm - 1.0).abs() < 0.1,
                "Subsphere mean {k} should have approximately unit L2 norm, got {norm}"
            );
        }
    }

    #[test]
    fn fpns_ncomp_one_smoke() {
        let result = fit(1, "horiz_fpns should succeed with ncomp=1");
        assert_eq!(result.scores.shape(), (N_CURVES, 1));
        assert_eq!(result.components.shape(), (1, N_POINTS));
        assert_eq!(result.explained_variance.len(), 1);
        assert_eq!(result.subsphere_means.len(), 1);
        assert!(result.explained_variance[0] >= 0.0);
    }
}