1use super::karcher::karcher_mean;
8use super::srsf::{srsf_inverse, srsf_transform};
9use super::KarcherMeanResult;
10use crate::iter_maybe_parallel;
11use crate::matrix::FdMatrix;
12use crate::smoothing::nadaraya_watson;
13use crate::warping::{exp_map_sphere, inv_exp_map_sphere, l2_norm_l2};
14#[cfg(feature = "parallel")]
15use rayon::iter::ParallelIterator;
16
/// Result of the transported square-root velocity function (TSRVF) transform.
///
/// Built by `tsrvf_from_alignment` / `tsrvf_from_alignment_with_method` and consumed by
/// `tsrvf_inverse`, which needs `mean_srsf`, `mean_srsf_norm`, `srsf_norms` and
/// `initial_values` to map the tangent vectors back to curves.
#[derive(Debug, Clone, PartialEq)]
pub struct TsrvfResult {
    /// Tangent vectors at the unit-normalized mean SRSF, one row per curve.
    pub tangent_vectors: FdMatrix,
    /// Mean curve, copied from the Karcher mean result.
    pub mean: Vec<f64>,
    /// Kernel-smoothed mean SRSF. It is NOT normalized; divide by `mean_srsf_norm` to get the
    /// point on the unit sphere used as the tangent-space base.
    pub mean_srsf: Vec<f64>,
    /// L2 norm of `mean_srsf`, evaluated on a uniform grid over `[0, 1]`.
    pub mean_srsf_norm: f64,
    /// L2 norm of each curve's smoothed aligned SRSF, used to rescale when inverting.
    pub srsf_norms: Vec<f64>,
    /// First sample of each aligned curve; the starting value passed to `srsf_inverse`.
    pub initial_values: Vec<f64>,
    /// Warping functions, copied from the Karcher mean result.
    pub gammas: FdMatrix,
    /// Convergence flag, copied from the Karcher mean result.
    pub converged: bool,
}
39
/// How each aligned SRSF is mapped into the tangent space at the mean SRSF.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TransportMethod {
    /// Inverse exponential map taken directly at the mean (the default).
    #[default]
    LogMap,
    /// Take the tangent at the curve's own unit SRSF and carry it to the mean with
    /// Schild's ladder.
    SchildsLadder,
    /// Take the tangent at the curve's own unit SRSF and carry it to the mean with the
    /// pole ladder.
    PoleLadder,
}
51
52fn smooth_aligned_srsfs(srsf: &FdMatrix, m: usize) -> FdMatrix {
60 let n = srsf.nrows();
61 let time: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
62 let bandwidth = 2.0 / (m - 1) as f64;
63
64 let mut smoothed = FdMatrix::zeros(n, m);
65 for i in 0..n {
66 let qi = srsf.row(i);
67 let qi_smooth = nadaraya_watson(&time, &qi, &time, bandwidth, "gaussian");
68 for j in 0..m {
69 smoothed[(i, j)] = qi_smooth[j];
70 }
71 }
72 smoothed
73}
74
75pub(super) fn parallel_transport_schilds(
79 v: &[f64],
80 from: &[f64],
81 to: &[f64],
82 time: &[f64],
83) -> Vec<f64> {
84 let v_norm = l2_norm_l2(v, time);
85 if v_norm < 1e-10 {
86 return vec![0.0; v.len()];
87 }
88
89 let endpoint = exp_map_sphere(from, v, time);
91
92 let log_to_ep = inv_exp_map_sphere(to, &endpoint, time);
94
95 let half_log: Vec<f64> = log_to_ep.iter().map(|&x| 0.5 * x).collect();
97 let midpoint = exp_map_sphere(to, &half_log, time);
98
99 let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
101 log_to_mid.iter().map(|&x| 2.0 * x).collect()
102}
103
104pub(super) fn parallel_transport_pole(
106 v: &[f64],
107 from: &[f64],
108 to: &[f64],
109 time: &[f64],
110) -> Vec<f64> {
111 let v_norm = l2_norm_l2(v, time);
112 if v_norm < 1e-10 {
113 return vec![0.0; v.len()];
114 }
115
116 let neg_v: Vec<f64> = v.iter().map(|&x| -x).collect();
118 let pole = exp_map_sphere(from, &neg_v, time);
119
120 let log_to_pole = inv_exp_map_sphere(to, &pole, time);
122
123 let half_log: Vec<f64> = log_to_pole.iter().map(|&x| 0.5 * x).collect();
125 let midpoint = exp_map_sphere(to, &half_log, time);
126
127 let log_to_mid = inv_exp_map_sphere(to, &midpoint, time);
129 log_to_mid.iter().map(|&x| -2.0 * x).collect()
130}
131
132pub fn tsrvf_transform(
145 data: &FdMatrix,
146 argvals: &[f64],
147 max_iter: usize,
148 tol: f64,
149 lambda: f64,
150) -> TsrvfResult {
151 let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
152 tsrvf_from_alignment(&karcher, argvals)
153}
154
155pub fn tsrvf_from_alignment(karcher: &KarcherMeanResult, argvals: &[f64]) -> TsrvfResult {
167 let (n, m) = karcher.aligned_data.shape();
168
169 let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
171
172 let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
184
185 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
189 let bandwidth = 2.0 / (m - 1) as f64;
190 let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian");
191 let mean_norm = l2_norm_l2(&mean_srsf_smooth, &time);
192
193 let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
194 mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
195 } else {
196 vec![0.0; m]
197 };
198
199 let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
201 .map(|i| {
202 let qi = aligned_srsf.row(i);
203 l2_norm_l2(&qi, &time)
204 })
205 .collect();
206
207 let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
208 .map(|i| {
209 let qi = aligned_srsf.row(i);
210 let qi_norm = srsf_norms[i];
211
212 if qi_norm < 1e-10 || mean_norm < 1e-10 {
213 return vec![0.0; m];
214 }
215
216 let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
218
219 inv_exp_map_sphere(&mu_unit, &qi_unit, &time)
221 })
222 .collect();
223
224 let mut tangent_vectors = FdMatrix::zeros(n, m);
226 for i in 0..n {
227 for j in 0..m {
228 tangent_vectors[(i, j)] = tangent_data[i][j];
229 }
230 }
231
232 let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
235
236 TsrvfResult {
237 tangent_vectors,
238 mean: karcher.mean.clone(),
239 mean_srsf: mean_srsf_smooth,
240 mean_srsf_norm: mean_norm,
241 srsf_norms,
242 initial_values,
243 gammas: karcher.gammas.clone(),
244 converged: karcher.converged,
245 }
246}
247
248pub fn tsrvf_inverse(tsrvf: &TsrvfResult, argvals: &[f64]) -> FdMatrix {
260 let (n, m) = tsrvf.tangent_vectors.shape();
261
262 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
263
264 let mu_unit: Vec<f64> = if tsrvf.mean_srsf_norm > 1e-10 {
266 tsrvf
267 .mean_srsf
268 .iter()
269 .map(|&q| q / tsrvf.mean_srsf_norm)
270 .collect()
271 } else {
272 vec![0.0; m]
273 };
274
275 let curves: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
276 .map(|i| {
277 let vi = tsrvf.tangent_vectors.row(i);
278
279 let qi_unit = exp_map_sphere(&mu_unit, &vi, &time);
281
282 let qi: Vec<f64> = qi_unit.iter().map(|&q| q * tsrvf.srsf_norms[i]).collect();
284
285 srsf_inverse(&qi, argvals, tsrvf.initial_values[i])
287 })
288 .collect();
289
290 let mut result = FdMatrix::zeros(n, m);
291 for i in 0..n {
292 for j in 0..m {
293 result[(i, j)] = curves[i][j];
294 }
295 }
296 result
297}
298
299pub fn tsrvf_transform_with_method(
303 data: &FdMatrix,
304 argvals: &[f64],
305 max_iter: usize,
306 tol: f64,
307 lambda: f64,
308 method: TransportMethod,
309) -> TsrvfResult {
310 let karcher = karcher_mean(data, argvals, max_iter, tol, lambda);
311 tsrvf_from_alignment_with_method(&karcher, argvals, method)
312}
313
314pub fn tsrvf_from_alignment_with_method(
321 karcher: &KarcherMeanResult,
322 argvals: &[f64],
323 method: TransportMethod,
324) -> TsrvfResult {
325 if method == TransportMethod::LogMap {
326 return tsrvf_from_alignment(karcher, argvals);
327 }
328
329 let (n, m) = karcher.aligned_data.shape();
330 let aligned_srsf = srsf_transform(&karcher.aligned_data, argvals);
331 let aligned_srsf = smooth_aligned_srsfs(&aligned_srsf, m);
332 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
333 let bandwidth = 2.0 / (m - 1) as f64;
334 let mean_srsf_smooth = nadaraya_watson(&time, &karcher.mean_srsf, &time, bandwidth, "gaussian");
335 let mean_norm = l2_norm_l2(&mean_srsf_smooth, &time);
336
337 let mu_unit: Vec<f64> = if mean_norm > 1e-10 {
338 mean_srsf_smooth.iter().map(|&q| q / mean_norm).collect()
339 } else {
340 vec![0.0; m]
341 };
342
343 let srsf_norms: Vec<f64> = iter_maybe_parallel!(0..n)
344 .map(|i| {
345 let qi = aligned_srsf.row(i);
346 l2_norm_l2(&qi, &time)
347 })
348 .collect();
349
350 let tangent_data: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
351 .map(|i| {
352 let qi = aligned_srsf.row(i);
353 let qi_norm = srsf_norms[i];
354
355 if qi_norm < 1e-10 || mean_norm < 1e-10 {
356 return vec![0.0; m];
357 }
358
359 let qi_unit: Vec<f64> = qi.iter().map(|&q| q / qi_norm).collect();
360
361 let v_at_qi = inv_exp_map_sphere(&qi_unit, &mu_unit, &time);
363 let neg_v: Vec<f64> = v_at_qi.iter().map(|&x| -x).collect();
364
365 match method {
367 TransportMethod::SchildsLadder => {
368 parallel_transport_schilds(&neg_v, &qi_unit, &mu_unit, &time)
369 }
370 TransportMethod::PoleLadder => {
371 parallel_transport_pole(&neg_v, &qi_unit, &mu_unit, &time)
372 }
373 TransportMethod::LogMap => unreachable!(),
374 }
375 })
376 .collect();
377
378 let mut tangent_vectors = FdMatrix::zeros(n, m);
379 for i in 0..n {
380 for j in 0..m {
381 tangent_vectors[(i, j)] = tangent_data[i][j];
382 }
383 }
384
385 let initial_values: Vec<f64> = (0..n).map(|i| karcher.aligned_data[(i, 0)]).collect();
386
387 TsrvfResult {
388 tangent_vectors,
389 mean: karcher.mean.clone(),
390 mean_srsf: mean_srsf_smooth,
391 mean_srsf_norm: mean_norm,
392 srsf_norms,
393 initial_values,
394 gammas: karcher.gammas.clone(),
395 converged: karcher.converged,
396 }
397}