1use super::karcher::karcher_mean;
8use super::srsf::{srsf_inverse, srsf_transform};
9use super::KarcherMeanResult;
10use crate::iter_maybe_parallel;
11use crate::matrix::FdMatrix;
12use crate::smoothing::nadaraya_watson;
13use crate::warping::{exp_map_sphere, inv_exp_map_sphere, l2_norm_l2};
14#[cfg(feature = "parallel")]
15use rayon::iter::ParallelIterator;
16
/// Result of a TSRVF (transported SRSF) transform of a set of curves.
///
/// Produced by `tsrvf_from_alignment` / `tsrvf_from_alignment_with_method` and
/// consumed by `tsrvf_inverse` to reconstruct the curves.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct TsrvfResult {
    /// `n x m` matrix, one row per curve: the tangent vector of that curve's
    /// unit-normalized, smoothed aligned SRSF at the unit-normalized mean SRSF.
    /// A row is all zeros when the curve's SRSF norm or the mean SRSF norm is
    /// below `1e-10`.
    pub tangent_vectors: FdMatrix,
    /// Karcher mean function, copied from the alignment result.
    pub mean: Vec<f64>,
    /// Kernel-smoothed Karcher mean SRSF (not normalized).
    pub mean_srsf: Vec<f64>,
    /// Norm of `mean_srsf` as computed by `l2_norm_l2` on a uniform [0, 1] grid;
    /// `mean_srsf / mean_srsf_norm` is the base point of the tangent space.
    pub mean_srsf_norm: f64,
    /// Per-curve norms of the smoothed aligned SRSFs (same grid as above); used by
    /// `tsrvf_inverse` to undo the unit normalization.
    pub srsf_norms: Vec<f64>,
    /// First sample of each aligned curve; used as the integration constant when
    /// inverting the SRSF.
    pub initial_values: Vec<f64>,
    /// Warping functions, copied from the alignment result.
    pub gammas: FdMatrix,
    /// Convergence flag, copied from the alignment result.
    pub converged: bool,
}
40
/// Strategy used to express each aligned SRSF as a tangent vector at the mean SRSF.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum TransportMethod {
    /// Default: apply the inverse exponential (log) map at the mean directly.
    #[default]
    LogMap,
    /// Take the log map at the curve's SRSF toward the mean, negate it, and move it
    /// to the mean with `parallel_transport_schilds` (Schild's ladder).
    SchildsLadder,
    /// Like `SchildsLadder`, but move the vector with `parallel_transport_pole`
    /// (pole ladder).
    PoleLadder,
}
53
54fn smooth_aligned_srsfs(srsf: &FdMatrix, m: usize) -> FdMatrix {
62 let n = srsf.nrows();
63 let time: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
64 let bandwidth = 2.0 / (m - 1) as f64;
65
66 let mut smoothed = FdMatrix::zeros(n, m);
67 for i in 0..n {
68 let qi = srsf.row(i);
69 let qi_smooth = nadaraya_watson(&time, &qi, &time, bandwidth, "gaussian")
71 .expect("smoothing valid SRSF data should not fail");
72 for j in 0..m {
73 smoothed[(i, j)] = qi_smooth[j];
74 }
75 }
76 smoothed
77}
78
79pub(super) fn parallel_transport_schilds(
83 v: &[f64],
84 from: &[f64],
85 to: &[f64],
86 time: &[f64],
87) -> Vec<f64> {
88 let v_norm = l2_norm_l2(v, time);
89 if v_norm < 1e-10 {
90 return vec![0.0; v.len()];
91 }
92
93 let endpoint = exp_map_sphere(from, v, time);
95
96 let log_to_ep = inv_exp_map_sphere(to, &endpoint, time);
98
99 let half_log: Vec<f64> = log_to_ep.iter().map(|&x| 0.5 * x).collect();
101 let midpoint = exp_map_sphere(to, &half_log, time);
102
103 let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
105 log_to_mid.iter().map(|&x| 2.0 * x).collect()
106}
107
108pub(super) fn parallel_transport_pole(
110 v: &[f64],
111 from: &[f64],
112 to: &[f64],
113 time: &[f64],
114) -> Vec<f64> {
115 let v_norm = l2_norm_l2(v, time);
116 if v_norm < 1e-10 {
117 return vec![0.0; v.len()];
118 }
119
120 let neg_v: Vec<f64> = v.iter().map(|&x| -x).collect();
122 let pole = exp_map_sphere(from, &neg_v, time);
123
124 let log_to_pole = inv_exp_map_sphere(to, &pole, time);
126
127 let half_log: Vec<f64> = log_to_pole.iter().map(|&x| 0.5 * x).collect();
129 let midpoint = exp_map_sphere(to, &half_log, time);
130
131 let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
133 log_to_mid.iter().map(|&x| -2.0 * x).collect()
134}
135
136pub fn tsrvf_transform(
149 data: &FdMatrix,
150 argvals: &[f64],
151 max_iter: usize,
152 tol: f64,
153 lambda: f64,
154) -> TsrvfResult {
155 let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
156 tsrvf_from_alignment(&karcher, argvals)
157}
158
159pub fn tsrvf_from_alignment(karcher: &KarcherMeanResult, argvals: &[f64]) -> TsrvfResult {
171 let (n, m) = karcher.aligned_data.shape();
172
173 let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
175
176 let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
188
189 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
193 let bandwidth = 2.0 / (m - 1) as f64;
194 let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian")
195 .expect("smoothing valid mean SRSF should not fail");
196 let mean_norm = l2_norm_l2(&mean_srsf_smooth, &time);
197
198 let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
199 mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
200 } else {
201 vec![0.0; m]
202 };
203
204 let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
206 .map(|i| {
207 let qi = aligned_srsf.row(i);
208 l2_norm_l2(&qi, &time)
209 })
210 .collect();
211
212 let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
213 .map(|i| {
214 let qi = aligned_srsf.row(i);
215 let qi_norm = srsf_norms[i];
216
217 if qi_norm < 1e-10 || mean_norm < 1e-10 {
218 return vec![0.0; m];
219 }
220
221 let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
223
224 inv_exp_map_sphere(&mu_unit, &qi_unit, &time)
226 })
227 .collect();
228
229 let mut tangent_vectors = FdMatrix::zeros(n, m);
231 for i in 0..n {
232 for j in 0..m {
233 tangent_vectors[(i, j)] = tangent_data[i][j];
234 }
235 }
236
237 let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
240
241 TsrvfResult {
242 tangent_vectors,
243 mean: karcher.mean.clone(),
244 mean_srsf: mean_srsf_smooth,
245 mean_srsf_norm: mean_norm,
246 srsf_norms,
247 initial_values,
248 gammas: karcher.gammas.clone(),
249 converged: karcher.converged,
250 }
251}
252
253pub fn tsrvf_inverse(tsrvf: &TsrvfResult, argvals: &[f64]) -> FdMatrix {
265 let (n, m) = tsrvf.tangent_vectors.shape();
266
267 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
268
269 let mu_unit: Vec<f64> = if tsrvf.mean_srsf_norm > 1e-10 {
271 tsrvf
272 .mean_srsf
273 .iter()
274 .map(|&q| q / tsrvf.mean_srsf_norm)
275 .collect()
276 } else {
277 vec![0.0; m]
278 };
279
280 let curves: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
281 .map(|i| {
282 let vi = tsrvf.tangent_vectors.row(i);
283
284 let qi_unit = exp_map_sphere(&mu_unit, &vi, &time);
286
287 let qi: Vec<f64> = qi_unit.iter().map(|&q| q * tsrvf.srsf_norms[i]).collect();
289
290 srsf_inverse(&qi, argvals, tsrvf.initial_values[i])
292 })
293 .collect();
294
295 let mut result = FdMatrix::zeros(n, m);
296 for i in 0..n {
297 for j in 0..m {
298 result[(i, j)] = curves[i][j];
299 }
300 }
301 result
302}
303
304pub fn tsrvf_transform_with_method(
308 data: &FdMatrix,
309 argvals: &[f64],
310 max_iter: usize,
311 tol: f64,
312 lambda: f64,
313 method: TransportMethod,
314) -> TsrvfResult {
315 let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
316 tsrvf_from_alignment_with_method(&karcher, argvals, method)
317}
318
319pub fn tsrvf_from_alignment_with_method(
326 karcher: &KarcherMeanResult,
327 argvals: &[f64],
328 method: TransportMethod,
329) -> TsrvfResult {
330 if method == TransportMethod::LogMap {
331 return tsrvf_from_alignment(karcher, argvals);
332 }
333
334 let (n, m) = karcher.aligned_data.shape();
335 let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
336 let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
337 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
338 let bandwidth = 2.0 / (m - 1) as f64;
339 let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian")
340 .expect("smoothing valid mean SRSF should not fail");
341 let mean_norm = l2_norm_l2(&mean_srsf_smooth, &time);
342
343 let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
344 mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
345 } else {
346 vec![0.0; m]
347 };
348
349 let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
350 .map(|i| {
351 let qi = aligned_srsf.row(i);
352 l2_norm_l2(&qi, &time)
353 })
354 .collect();
355
356 let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
357 .map(|i| {
358 let qi = aligned_srsf.row(i);
359 let qi_norm = srsf_norms[i];
360
361 if qi_norm < 1e-10 || mean_norm < 1e-10 {
362 return vec![0.0; m];
363 }
364
365 let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
366
367 let v_at_qi = inv_exp_map_sphere(&qi_unit, &mu_unit, &time);
369 let neg_v: Vec<f64> = v_at_qi.iter().map(|&x| -x).collect();
370
371 match method {
373 TransportMethod::SchildsLadder => {
374 parallel_transport_schilds(&neg_v, &qi_unit, &mu_unit, &time)
375 }
376 TransportMethod::PoleLadder => {
377 parallel_transport_pole(&neg_v, &qi_unit, &mu_unit, &time)
378 }
379 TransportMethod::LogMap => unreachable!(),
380 }
381 })
382 .collect();
383
384 let mut tangent_vectors = FdMatrix::zeros(n, m);
385 for i in 0..n {
386 for j in 0..m {
387 tangent_vectors[(i, j)] = tangent_data[i][j];
388 }
389 }
390
391 let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
392
393 TsrvfResult {
394 tangent_vectors,
395 mean: karcher.mean.clone(),
396 mean_srsf: mean_srsf_smooth,
397 mean_srsf_norm: mean_norm,
398 srsf_norms,
399 initial_values,
400 gammas: karcher.gammas.clone(),
401 converged: karcher.converged,
402 }
403}