1use super::dp_alignment_core;
4use super::nd::{elastic_align_pair_nd, srsf_transform_nd};
5use super::srsf::{reparameterize_curve, srsf_inverse, srsf_single};
6use crate::error::FdarError;
7use crate::helpers::{l2_distance, simpsons_weights};
8use crate::matrix::{FdCurveSet, FdMatrix};
9use crate::warping::{exp_map_sphere, gam_to_psi, inv_exp_map_sphere, normalize_warp, psi_to_gam};
10
/// A sampled elastic geodesic path between two univariate curves, as produced by
/// [`curve_geodesic`].
///
/// All per-step data share the same step index `k`, which corresponds to
/// `parameter_values[k]`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct GeodesicPath {
    /// Curves along the path, one per row (`n_points x m`); row `k` is the curve at
    /// `parameter_values[k]`.
    pub curves: FdMatrix,
    /// Warping functions along the path, one per row (`n_points x m`), expressed on the
    /// original `argvals` domain; they move from the identity warp at `t = 0` toward the
    /// optimal alignment warp at `t = 1`.
    pub warps: FdMatrix,
    /// For each step, the Simpson-weighted L2 distance between the SRSF of the start curve
    /// and the interpolated SRSF at that step.
    pub distances: Vec<f64>,
    /// Evenly spaced path parameters in `[0, 1]`, from 0 to 1 inclusive.
    pub parameter_values: Vec<f64>,
}
24
/// A sampled elastic geodesic path between two multivariate curves, as produced by
/// [`curve_geodesic_nd`].
///
/// All per-step data share the same step index `k`, which corresponds to
/// `parameter_values[k]`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct GeodesicPathNd {
    /// One `n_points x m` matrix per curve dimension; row `k` of matrix `dd` is coordinate
    /// `dd` of the curve at `parameter_values[k]`.
    pub curves: Vec<FdMatrix>,
    /// Warping functions along the path, one per row (`n_points x m`), shared by all
    /// dimensions and expressed on the original `argvals` domain.
    pub warps: FdMatrix,
    /// For each step, the Simpson-weighted L2 distance between the start SRSF and the
    /// interpolated SRSF at that step, combined over dimensions as the root of the sum of
    /// squared per-dimension distances.
    pub distances: Vec<f64>,
    /// Evenly spaced path parameters in `[0, 1]`, from 0 to 1 inclusive.
    pub parameter_values: Vec<f64>,
}
38
39#[must_use = "expensive computation whose result should not be discarded"]
57pub fn curve_geodesic(
58 f1: &[f64],
59 f2: &[f64],
60 argvals: &[f64],
61 n_points: usize,
62 lambda: f64,
63) -> Result<GeodesicPath, FdarError> {
64 let m = f1.len();
65
66 if m < 2 {
68 return Err(FdarError::InvalidDimension {
69 parameter: "f1",
70 expected: "length >= 2".to_string(),
71 actual: format!("length {m}"),
72 });
73 }
74 if f2.len() != m {
75 return Err(FdarError::InvalidDimension {
76 parameter: "f2",
77 expected: format!("length {m}"),
78 actual: format!("length {}", f2.len()),
79 });
80 }
81 if argvals.len() != m {
82 return Err(FdarError::InvalidDimension {
83 parameter: "argvals",
84 expected: format!("length {m}"),
85 actual: format!("length {}", argvals.len()),
86 });
87 }
88 if n_points < 2 {
89 return Err(FdarError::InvalidParameter {
90 parameter: "n_points",
91 message: format!("must be >= 2, got {n_points}"),
92 });
93 }
94
95 let q1 = srsf_single(f1, argvals);
97 let q2 = srsf_single(f2, argvals);
98 let gamma = dp_alignment_core(&q1, &q2, argvals, lambda);
99 let f2_aligned = reparameterize_curve(f2, argvals, &gamma);
100 let q2a = srsf_single(&f2_aligned, argvals);
101
102 let t0 = argvals[0];
104 let domain = argvals[m - 1] - t0;
105 let time_01: Vec<f64> = (0..m).map(|j| (j as f64) / (m - 1) as f64).collect();
106 let binsize = 1.0 / (m - 1) as f64;
107
108 let gamma_01: Vec<f64> = gamma.iter().map(|&g| (g - t0) / domain).collect();
109 let psi = gam_to_psi(&gamma_01, binsize);
110 let psi_id = gam_to_psi(&time_01, binsize);
111 let v = inv_exp_map_sphere(&psi_id, &psi, &time_01);
112
113 let weights = simpsons_weights(argvals);
115
116 let parameter_values: Vec<f64> = (0..n_points)
118 .map(|k| k as f64 / (n_points - 1) as f64)
119 .collect();
120
121 let mut curves = FdMatrix::zeros(n_points, m);
122 let mut warps = FdMatrix::zeros(n_points, m);
123 let mut distances = Vec::with_capacity(n_points);
124
125 for (k, &t_k) in parameter_values.iter().enumerate() {
126 let scaled_v: Vec<f64> = v.iter().map(|&vi| t_k * vi).collect();
128 let psi_k = exp_map_sphere(&psi_id, &scaled_v, &time_01);
129 let mut gamma_k_01 = psi_to_gam(&psi_k, &time_01);
130 for j in 0..m {
132 gamma_k_01[j] = t0 + gamma_k_01[j] * domain;
133 }
134 normalize_warp(&mut gamma_k_01, argvals);
135
136 let q_k: Vec<f64> = (0..m).map(|j| (1.0 - t_k) * q1[j] + t_k * q2a[j]).collect();
138
139 let f0_k = (1.0 - t_k) * f1[0] + t_k * f2_aligned[0];
141 let f_k = srsf_inverse(&q_k, argvals, f0_k);
142
143 let dist = l2_distance(&q1, &q_k, &weights);
145
146 for j in 0..m {
147 curves[(k, j)] = f_k[j];
148 warps[(k, j)] = gamma_k_01[j];
149 }
150 distances.push(dist);
151 }
152
153 Ok(GeodesicPath {
154 curves,
155 warps,
156 distances,
157 parameter_values,
158 })
159}
160
161#[must_use = "expensive computation whose result should not be discarded"]
177pub fn curve_geodesic_nd(
178 f1: &FdCurveSet,
179 f2: &FdCurveSet,
180 argvals: &[f64],
181 n_points: usize,
182 lambda: f64,
183) -> Result<GeodesicPathNd, FdarError> {
184 let d = f1.ndim();
185 let m = f1.npoints();
186
187 if d == 0 {
189 return Err(FdarError::InvalidDimension {
190 parameter: "f1",
191 expected: "ndim >= 1".to_string(),
192 actual: "ndim 0".to_string(),
193 });
194 }
195 if f2.ndim() != d {
196 return Err(FdarError::InvalidDimension {
197 parameter: "f2",
198 expected: format!("ndim {d}"),
199 actual: format!("ndim {}", f2.ndim()),
200 });
201 }
202 if f2.npoints() != m {
203 return Err(FdarError::InvalidDimension {
204 parameter: "f2",
205 expected: format!("{m} points"),
206 actual: format!("{} points", f2.npoints()),
207 });
208 }
209 if m < 2 {
210 return Err(FdarError::InvalidDimension {
211 parameter: "f1",
212 expected: "npoints >= 2".to_string(),
213 actual: format!("npoints {m}"),
214 });
215 }
216 if argvals.len() != m {
217 return Err(FdarError::InvalidDimension {
218 parameter: "argvals",
219 expected: format!("length {m}"),
220 actual: format!("length {}", argvals.len()),
221 });
222 }
223 if n_points < 2 {
224 return Err(FdarError::InvalidParameter {
225 parameter: "n_points",
226 message: format!("must be >= 2, got {n_points}"),
227 });
228 }
229
230 let result = elastic_align_pair_nd(f1, f2, argvals, lambda);
232 let gamma = &result.gamma;
233
234 let q1_set = srsf_transform_nd(f1, argvals);
236 let f2_aligned_set = {
237 let dims: Vec<FdMatrix> = result
238 .f_aligned
239 .iter()
240 .map(|fa| FdMatrix::from_slice(fa, 1, m).expect("dimension invariant"))
241 .collect();
242 FdCurveSet { dims }
243 };
244 let q2a_set = srsf_transform_nd(&f2_aligned_set, argvals);
245
246 let q1: Vec<Vec<f64>> = q1_set.dims.iter().map(|dm| dm.row(0)).collect();
247 let q2a: Vec<Vec<f64>> = q2a_set.dims.iter().map(|dm| dm.row(0)).collect();
248
249 let t0 = argvals[0];
251 let domain = argvals[m - 1] - t0;
252 let time_01: Vec<f64> = (0..m).map(|j| (j as f64) / (m - 1) as f64).collect();
253 let binsize = 1.0 / (m - 1) as f64;
254
255 let gamma_01: Vec<f64> = gamma.iter().map(|&g| (g - t0) / domain).collect();
256 let psi = gam_to_psi(&gamma_01, binsize);
257 let psi_id = gam_to_psi(&time_01, binsize);
258 let v = inv_exp_map_sphere(&psi_id, &psi, &time_01);
259
260 let weights = simpsons_weights(argvals);
262
263 let parameter_values: Vec<f64> = (0..n_points)
265 .map(|k| k as f64 / (n_points - 1) as f64)
266 .collect();
267
268 let mut dim_curves: Vec<FdMatrix> = (0..d).map(|_| FdMatrix::zeros(n_points, m)).collect();
269 let mut warps_mat = FdMatrix::zeros(n_points, m);
270 let mut distances = Vec::with_capacity(n_points);
271
272 for (k, &t_k) in parameter_values.iter().enumerate() {
273 let scaled_v: Vec<f64> = v.iter().map(|&vi| t_k * vi).collect();
275 let psi_k = exp_map_sphere(&psi_id, &scaled_v, &time_01);
276 let mut gamma_k_01 = psi_to_gam(&psi_k, &time_01);
277 for j in 0..m {
278 gamma_k_01[j] = t0 + gamma_k_01[j] * domain;
279 }
280 normalize_warp(&mut gamma_k_01, argvals);
281
282 for j in 0..m {
283 warps_mat[(k, j)] = gamma_k_01[j];
284 }
285
286 let mut dist_sq = 0.0;
288 for dd in 0..d {
289 let q_k: Vec<f64> = (0..m)
290 .map(|j| (1.0 - t_k) * q1[dd][j] + t_k * q2a[dd][j])
291 .collect();
292
293 let f0_k = (1.0 - t_k) * f1.dims[dd][(0, 0)] + t_k * result.f_aligned[dd][0];
294 let f_k = srsf_inverse(&q_k, argvals, f0_k);
295
296 let d_k = l2_distance(&q1[dd], &q_k, &weights);
297 dist_sq += d_k * d_k;
298
299 for j in 0..m {
300 dim_curves[dd][(k, j)] = f_k[j];
301 }
302 }
303 distances.push(dist_sq.sqrt());
304 }
305
306 Ok(GeodesicPathNd {
307 curves: dim_curves,
308 warps: warps_mat,
309 distances,
310 parameter_values,
311 })
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;
    use std::f64::consts::PI;

    /// Evaluates `f` at every point of `grid`.
    fn sample(grid: &[f64], f: impl Fn(f64) -> f64) -> Vec<f64> {
        grid.iter().map(|&x| f(x)).collect()
    }

    /// Largest absolute pointwise difference between two sampled curves.
    fn max_abs_diff(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b)
            .map(|(&x, &y)| (x - y).abs())
            .fold(0.0_f64, f64::max)
    }

    /// Builds a single-curve set with one sampled component per dimension.
    fn curve_set(components: &[Vec<f64>]) -> FdCurveSet {
        let dims: Vec<FdMatrix> = components
            .iter()
            .map(|c| FdMatrix::from_slice(c, 1, c.len()).unwrap())
            .collect();
        FdCurveSet::from_dims(dims).unwrap()
    }

    #[test]
    fn geodesic_endpoints_match() {
        let m = 51;
        let grid = uniform_grid(m);
        let start = sample(&grid, |x| (2.0 * PI * x).sin());
        let end = sample(&grid, |x| (2.0 * PI * x).cos());

        let path = curve_geodesic(&start, &end, &grid, 5, 0.0).unwrap();

        // The first sample should reconstruct the starting curve.
        let max_diff_start = max_abs_diff(&path.curves.row(0), &start);
        assert!(
            max_diff_start < 0.5,
            "At t=0 curve should approximate f1, max diff = {max_diff_start}"
        );

        let last = path.parameter_values.len() - 1;
        assert_eq!(path.curves.row(last).len(), m);
    }

    #[test]
    fn geodesic_distances_nonneg() {
        let grid = uniform_grid(41);
        let start = sample(&grid, |x| (2.0 * PI * x).sin());
        let end = sample(&grid, |x| 0.5 * (4.0 * PI * x).sin());

        let path = curve_geodesic(&start, &end, &grid, 6, 0.0).unwrap();

        for (k, d) in path.distances.iter().copied().enumerate() {
            assert!(d >= 0.0, "Distance at k={k} should be >= 0, got {d}");
        }
    }

    #[test]
    fn geodesic_identical_curves() {
        let grid = uniform_grid(41);
        let f = sample(&grid, |x| (2.0 * PI * x).sin());

        let path = curve_geodesic(&f, &f, &grid, 4, 0.0).unwrap();

        // Every sample along a self-geodesic should stay close to the input curve.
        for k in 0..path.parameter_values.len() {
            let max_diff = max_abs_diff(&path.curves.row(k), &f);
            assert!(
                max_diff < 0.5,
                "Identical curve geodesic: curve at k={k} deviates by {max_diff}"
            );
        }

        for (k, d) in path.distances.iter().copied().enumerate() {
            assert!(
                d < 1.0,
                "Identical curve geodesic: distance at k={k} = {d}, expected near 0"
            );
        }
    }

    #[test]
    fn geodesic_nd_dimensions() {
        let m = 31;
        let d = 2;
        let n_points = 4;
        let grid = uniform_grid(m);

        let f1 = curve_set(&[
            sample(&grid, |x| (2.0 * PI * x).sin()),
            sample(&grid, |x| (2.0 * PI * x).cos()),
        ]);
        let f2 = curve_set(&[sample(&grid, |x| x * x), grid.to_vec()]);

        let path = curve_geodesic_nd(&f1, &f2, &grid, n_points, 0.0).unwrap();

        assert_eq!(path.curves.len(), d, "Should have d dimension matrices");
        for (dd, dim_mat) in path.curves.iter().enumerate() {
            assert_eq!(
                dim_mat.shape(),
                (n_points, m),
                "Dimension {dd} matrix shape mismatch"
            );
        }
        assert_eq!(path.warps.shape(), (n_points, m));
        assert_eq!(path.distances.len(), n_points);
        assert_eq!(path.parameter_values.len(), n_points);
    }
}