use super::karcher::karcher_mean;
use super::srsf::{srsf_inverse, srsf_transform};
use super::KarcherMeanResult;
use crate::iter_maybe_parallel;
use crate::matrix::FdMatrix;
use crate::smoothing::nadaraya_watson;
use crate::warping::{exp_map_sphere, inv_exp_map_sphere, l2_norm_l2};
#[cfg(feature = "parallel")]
use rayon::iter::ParallelIterator;
/// Result of a TSRVF (transported SRSF) transform.
///
/// Produced by `tsrvf_from_alignment` / `tsrvf_from_alignment_with_method` and consumed by
/// `tsrvf_inverse` to reconstruct curves. Marked `#[non_exhaustive]`, so it cannot be
/// constructed with a struct literal outside this crate.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct TsrvfResult {
/// Tangent vectors, one row per curve (`n × m`, same shape as the aligned data). Row `i`
/// represents curve `i` in the tangent space at the unit-normalised mean SRSF; rows whose
/// SRSF norm (or the mean's norm) is below `1e-10` are all zeros.
pub tangent_vectors: FdMatrix,
/// Karcher mean curve, copied from the alignment result.
pub mean: Vec<f64>,
/// Mean SRSF after Gaussian Nadaraya–Watson smoothing on a uniform `[0, 1]` grid
/// (not normalised to unit length).
pub mean_srsf: Vec<f64>,
/// L2 norm of `mean_srsf`, computed with `l2_norm_l2` on the uniform `[0, 1]` grid.
pub mean_srsf_norm: f64,
/// Per-curve L2 norm of the smoothed aligned SRSF; `tsrvf_inverse` uses it to rescale the
/// unit-sphere point back to an SRSF.
pub srsf_norms: Vec<f64>,
/// Value of each aligned curve at its first sample; passed to `srsf_inverse` as the
/// integration constant during reconstruction.
pub initial_values: Vec<f64>,
/// Warping functions copied from the alignment result (`KarcherMeanResult::gammas`);
/// presumably one row per curve — confirm against `karcher_mean`.
pub gammas: FdMatrix,
/// Whether the Karcher mean iteration reported convergence.
pub converged: bool,
}
/// Strategy for mapping each aligned SRSF into the tangent space at the mean SRSF.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum TransportMethod {
/// Default: apply the inverse exponential (log) map at the mean directly
/// (handled by `tsrvf_from_alignment`).
#[default]
LogMap,
/// Transport the vector `-log_{q_i}(mu)` from each curve's point to the mean with
/// `parallel_transport_schilds`.
SchildsLadder,
/// Transport the vector `-log_{q_i}(mu)` from each curve's point to the mean with
/// `parallel_transport_pole`.
PoleLadder,
}
/// Smooths every row of an SRSF matrix with a Gaussian Nadaraya–Watson kernel.
///
/// Each row is treated as a signal sampled on a uniform grid of `m` points over `[0, 1]`.
/// It is re-evaluated on that same grid with a bandwidth of two grid spacings.
///
/// # Panics
///
/// Panics if `nadaraya_watson` reports an error for any row.
fn smooth_aligned_srsfs(srsf: &FdMatrix, m: usize) -> FdMatrix {
    let spacing_denominator = (m - 1) as f64;
    let h = 2.0 / spacing_denominator;
    let grid: Vec<f64> = (0..m).map(|k| k as f64 / spacing_denominator).collect();

    let rows = srsf.nrows();
    let mut out = FdMatrix::zeros(rows, m);
    for row in 0..rows {
        let raw = srsf.row(row);
        let fitted = nadaraya_watson(&grid, &raw, &grid, h, "gaussian")
            .expect("smoothing valid SRSF data should not fail");
        for (col, &value) in fitted[..m].iter().enumerate() {
            out[(row, col)] = value;
        }
    }
    out
}
/// Parallel-transports the tangent vector `v` from the sphere point `from` to the sphere
/// point `to` using one rung of Schild's ladder.
///
/// All geodesic operations use `exp_map_sphere` / `inv_exp_map_sphere` with the L2 inner
/// product on `time`:
/// 1. `x0 = exp_from(v)`, the tip of `v`.
/// 2. `mid`, the midpoint of the geodesic from `x0` to `to`.
/// 3. `x1 = exp_from(2 · log_from(mid))`. This extends the geodesic from `from` through
///    `mid` to twice its length, closing the geodesic parallelogram.
/// 4. The transported vector is `log_to(x1)`.
///
/// The construction is exact in flat space and an approximation on the sphere. It becomes
/// unreliable when the geodesic from `x0` to `to` approaches length π, because the midpoint
/// is then ill-defined.
///
/// Returns a zero vector when `v` has (numerically) zero norm.
pub(super) fn parallel_transport_schilds(
    v: &[f64],
    from: &[f64],
    to: &[f64],
    time: &[f64],
) -> Vec<f64> {
    let v_norm = l2_norm_l2(v, time);
    if v_norm < 1e-10 {
        return vec![0.0; v.len()];
    }

    // Tip of `v` on the sphere.
    let x0 = exp_map_sphere(from, v, time);

    // Midpoint of the diagonal x0 -> to. It must be taken along the x0–to geodesic;
    // measuring it from `to` and logging back at `to` would cancel out.
    let x0_to_target = inv_exp_map_sphere(&x0, to, time);
    let half_diag: Vec<f64> = x0_to_target.iter().map(|&x| 0.5 * x).collect();
    let mid = exp_map_sphere(&x0, &half_diag, time);

    // Second diagonal: from `from` through `mid`, extended to twice its length.
    let from_to_mid = inv_exp_map_sphere(from, &mid, time);
    let doubled: Vec<f64> = from_to_mid.iter().map(|&x| 2.0 * x).collect();
    let x1 = exp_map_sphere(from, &doubled, time);

    // The fourth vertex of the parallelogram, seen from `to`, is the transported vector.
    inv_exp_map_sphere(to, &x1, time)
}
/// Parallel-transports the tangent vector `v` from the sphere point `from` to the sphere
/// point `to` using one rung of the pole ladder.
///
/// All geodesic operations use `exp_map_sphere` / `inv_exp_map_sphere` with the L2 inner
/// product on `time`:
/// 1. `pole`, the midpoint of the geodesic from `from` to `to`.
/// 2. `x0 = exp_from(v)`, the tip of `v`.
/// 3. `x1 = exp_pole(-log_pole(x0))`, the geodesic reflection of `x0` through the pole.
/// 4. The reflection reverses orientation, so the transported vector is `-log_to(x1)`.
///
/// The construction is exact in flat space and an approximation on the sphere. It becomes
/// unreliable when the geodesic from the pole to `x0` approaches length π.
///
/// Returns a zero vector when `v` has (numerically) zero norm.
pub(super) fn parallel_transport_pole(
    v: &[f64],
    from: &[f64],
    to: &[f64],
    time: &[f64],
) -> Vec<f64> {
    let v_norm = l2_norm_l2(v, time);
    if v_norm < 1e-10 {
        return vec![0.0; v.len()];
    }

    // Pole: midpoint of the geodesic from `from` to `to`. It must be built from both
    // endpoints, never from `exp_from(-v)` alone. For v = -log_from(to), that point is `to`
    // itself and the result would collapse to zero.
    let from_to_target = inv_exp_map_sphere(from, to, time);
    let half_way: Vec<f64> = from_to_target.iter().map(|&x| 0.5 * x).collect();
    let pole = exp_map_sphere(from, &half_way, time);

    // Tip of `v`, reflected through the pole.
    let x0 = exp_map_sphere(from, v, time);
    let pole_to_tip = inv_exp_map_sphere(&pole, &x0, time);
    let mirrored: Vec<f64> = pole_to_tip.iter().map(|&x| -x).collect();
    let x1 = exp_map_sphere(&pole, &mirrored, time);

    // Point reflection flips the vector, so negate the log at the destination.
    inv_exp_map_sphere(to, &x1, time)
        .iter()
        .map(|&x| -x)
        .collect()
}
/// Computes the TSRVF representation of `data` in two steps:
/// 1. Align the curves with `karcher_mean`, passing through `max_iter`, `tol` and `lambda`.
/// 2. Map every aligned SRSF to the tangent space at the mean using the log map
///    (see `tsrvf_from_alignment`).
///
/// # Panics
///
/// Panics under the same conditions as `tsrvf_from_alignment`.
pub fn tsrvf_transform(
    data: &FdMatrix,
    argvals: &[f64],
    max_iter: usize,
    tol: f64,
    lambda: f64,
) -> TsrvfResult {
    tsrvf_from_alignment(
        &karcher_mean(data, argvals, max_iter, tol, lambda),
        argvals,
    )
}
/// Builds the TSRVF representation from an existing Karcher-mean alignment using the
/// inverse exponential (log) map at the mean.
///
/// Procedure:
/// - Convert the aligned curves to SRSFs and smooth each row with a Gaussian
///   Nadaraya–Watson kernel on a uniform `[0, 1]` grid.
/// - Smooth the alignment's mean SRSF the same way and scale it to unit L2 norm.
/// - Scale each smoothed SRSF to unit L2 norm and map it into the tangent space at the
///   unit mean with `inv_exp_map_sphere`.
///
/// A row whose SRSF norm (or the mean's norm) is below `1e-10` yields an all-zero tangent
/// vector. The per-row norms and the first value of each aligned curve are kept so that
/// `tsrvf_inverse` can rebuild the curves.
///
/// # Panics
///
/// Panics if Nadaraya–Watson smoothing fails.
pub fn tsrvf_from_alignment(karcher: &KarcherMeanResult, argvals: &[f64]) -> TsrvfResult {
    const EPS: f64 = 1e-10;
    let (n, m) = karcher.aligned_data.shape();

    let q_aligned = smooth_aligned_srsfs(&srsf_transform(&karcher.aligned_data, argvals), m);

    let spacing_denominator = (m - 1) as f64;
    let grid: Vec<f64> = (0..m).map(|k| k as f64 / spacing_denominator).collect();
    let mean_q = nadaraya_watson(
        &grid,
        &karcher.mean_srsf,
        &grid,
        2.0 / spacing_denominator,
        "gaussian",
    )
    .expect("smoothing valid mean SRSF should not fail");
    let mean_norm = l2_norm_l2(&mean_q, &grid);

    // Unit-length mean on the Hilbert sphere (all zeros if the mean is degenerate).
    let mu_unit: Vec<f64> = match mean_norm > EPS {
        true => mean_q.iter().map(|&q| q / mean_norm).collect(),
        false => vec![0.0; m],
    };

    let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
        .map(|i| l2_norm_l2(&q_aligned.row(i), &grid))
        .collect();

    let tangent_rows: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
        .map(|i| {
            let norm_i = srsf_norms[i];
            if norm_i < EPS || mean_norm < EPS {
                return vec![0.0; m];
            }
            let q_unit: Vec<f64> = q_aligned.row(i).iter().map(|&q| q / norm_i).collect();
            inv_exp_map_sphere(&mu_unit, &q_unit, &grid)
        })
        .collect();

    let mut tangent_vectors = FdMatrix::zeros(n, m);
    for (i, row) in tangent_rows.iter().enumerate() {
        for (j, &x) in row[..m].iter().enumerate() {
            tangent_vectors[(i, j)] = x;
        }
    }

    let initial_values: Vec<f64> = (0..n)
        .map(|row| karcher.aligned_data[(row, 0)])
        .collect();

    TsrvfResult {
        tangent_vectors,
        mean: karcher.mean.clone(),
        mean_srsf: mean_q,
        mean_srsf_norm: mean_norm,
        srsf_norms,
        initial_values,
        gammas: karcher.gammas.clone(),
        converged: karcher.converged,
    }
}
/// Reconstructs curves from a TSRVF representation.
///
/// For each row `i`:
/// 1. Map the tangent vector back onto the sphere with `exp_map_sphere` at the unit-normalised
///    mean SRSF.
/// 2. Rescale the result by the stored SRSF norm `srsf_norms[i]`.
/// 3. Integrate it with `srsf_inverse`, starting from `initial_values[i]`.
///
/// Returns an `n × m` matrix with one reconstructed curve per row.
pub fn tsrvf_inverse(tsrvf: &TsrvfResult, argvals: &[f64]) -> FdMatrix {
    let (n, m) = tsrvf.tangent_vectors.shape();
    // `(m - 1)` is evaluated only when the grid is non-empty, so `m == 0` never underflows.
    let grid: Vec<f64> = (0..m).map(|k| k as f64 / (m - 1) as f64).collect();

    let norm = tsrvf.mean_srsf_norm;
    let mu_unit: Vec<f64> = match norm > 1e-10 {
        true => tsrvf.mean_srsf.iter().map(|&q| q / norm).collect(),
        false => vec![0.0; m],
    };

    let curves: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
        .map(|i| {
            let on_sphere = exp_map_sphere(&mu_unit, &tsrvf.tangent_vectors.row(i), &grid);
            let q: Vec<f64> = on_sphere
                .iter()
                .map(|&u| u * tsrvf.srsf_norms[i])
                .collect();
            srsf_inverse(&q, argvals, tsrvf.initial_values[i])
        })
        .collect();

    let mut result = FdMatrix::zeros(n, m);
    for (i, curve) in curves.iter().enumerate() {
        for (j, &y) in curve[..m].iter().enumerate() {
            result[(i, j)] = y;
        }
    }
    result
}
/// Computes the TSRVF representation of `data` with a selectable tangent-space mapping.
///
/// The curves are aligned with `karcher_mean` (using `max_iter`, `tol` and `lambda`). The
/// result is then passed to `tsrvf_from_alignment_with_method` together with `method`.
pub fn tsrvf_transform_with_method(
    data: &FdMatrix,
    argvals: &[f64],
    max_iter: usize,
    tol: f64,
    lambda: f64,
    method: TransportMethod,
) -> TsrvfResult {
    tsrvf_from_alignment_with_method(
        &karcher_mean(data, argvals, max_iter, tol, lambda),
        argvals,
        method,
    )
}
pub fn tsrvf_from_alignment_with_method(
karcher: &KarcherMeanResult,
argvals: &[f64],
method: TransportMethod,
) -> TsrvfResult {
if method == TransportMethod::LogMap {
return tsrvf_from_alignment(karcher, argvals);
}
let (n, m) = karcher.aligned_data.shape();
let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
let bandwidth = 2.0 / (m - 1) as f64;
let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian")
.expect("smoothing valid mean SRSF should not fail");
let mean_norm = l2_norm_l2(&mean_srsf_smooth, &time);
let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
} else {
vec![0.0; m]
};
let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
.map(|i| {
let qi = aligned_srsf.row(i);
l2_norm_l2(&qi, &time)
})
.collect();
let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
.map(|i| {
let qi = aligned_srsf.row(i);
let qi_norm = srsf_norms[i];
if qi_norm < 1e-10 || mean_norm < 1e-10 {
return vec![0.0; m];
}
let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
let v_at_qi = inv_exp_map_sphere(&qi_unit, &mu_unit, &time);
let neg_v: Vec<f64> = v_at_qi.iter().map(|&x| -x).collect();
match method {
TransportMethod::SchildsLadder => {
parallel_transport_schilds(&neg_v, &qi_unit, &mu_unit, &time)
}
TransportMethod::PoleLadder => {
parallel_transport_pole(&neg_v, &qi_unit, &mu_unit, &time)
}
TransportMethod::LogMap => unreachable!(),
}
})
.collect();
let mut tangent_vectors = FdMatrix::zeros(n, m);
for i in 0..n {
for j in 0..m {
tangent_vectors[(i, j)] = tangent_data[i][j];
}
}
let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
TsrvfResult {
tangent_vectors,
mean: karcher.mean.clone(),
mean_srsf: mean_srsf_smooth,
mean_srsf_norm: mean_norm,
srsf_norms,
initial_values,
gammas: karcher.gammas.clone(),
converged: karcher.converged,
}
}