Skip to main content

fdars_core/alignment/
nd.rs

1//! Multidimensional (R^d) SRSF transforms and elastic alignment.
2
3use super::srsf::reparameterize_curve;
4use super::{
5    dp_alignment_core, dp_edge_weight, dp_grid_solve, dp_lambda_penalty, dp_path_to_gamma,
6};
7use crate::helpers::{cumulative_trapz, l2_distance, simpsons_weights};
8use crate::matrix::{FdCurveSet, FdMatrix};
9
/// Outcome of elastically aligning one multidimensional (R^d) curve onto another.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct AlignmentResultNd {
    /// Warping function γ sampled on the evaluation grid (length m).
    /// A single warp is shared by every coordinate of the curve.
    pub gamma: Vec<f64>,
    /// The warped curve, stored as d coordinate vectors of length m each.
    pub f_aligned: Vec<Vec<f64>>,
    /// Elastic distance measured once the alignment has been applied.
    pub distance: f64,
}
21
22/// Scale derivative vector at one point by 1/√‖f'‖, writing into result_dims.
23#[inline]
24fn srsf_scale_point(derivs: &[FdMatrix], result_dims: &mut [FdMatrix], i: usize, j: usize) {
25    let d = derivs.len();
26    let norm_sq: f64 = derivs.iter().map(|dd| dd[(i, j)].powi(2)).sum();
27    let norm = norm_sq.sqrt();
28    if norm < 1e-15 {
29        for k in 0..d {
30            result_dims[k][(i, j)] = 0.0;
31        }
32    } else {
33        let scale = 1.0 / norm.sqrt();
34        for k in 0..d {
35            result_dims[k][(i, j)] = derivs[k][(i, j)] * scale;
36        }
37    }
38}
39
40/// Compute the SRSF transform for multidimensional (R^d) curves.
41///
42/// For f: \[0,1\] → R^d, the SRSF is q(t) = f'(t) / √‖f'(t)‖ where ‖·‖ is the
43/// Euclidean norm in R^d. For d=1 this reduces to `sign(f') · √|f'|`.
44///
45/// # Arguments
46/// * `data` — Set of n curves in R^d, each with m evaluation points
47/// * `argvals` — Evaluation points (length m)
48///
49/// # Returns
50/// `FdCurveSet` of SRSF values with the same shape as input.
51pub fn srsf_transform_nd(data: &FdCurveSet, argvals: &[f64]) -> FdCurveSet {
52    let d = data.ndim();
53    let n = data.ncurves();
54    let m = data.npoints();
55
56    if d == 0 || n == 0 || m == 0 || argvals.len() != m {
57        return FdCurveSet {
58            dims: (0..d).map(|_| FdMatrix::zeros(n, m)).collect(),
59        };
60    }
61
62    let derivs: Vec<FdMatrix> = data
63        .dims
64        .iter()
65        .map(|dim_mat| crate::fdata::deriv_1d(dim_mat, argvals, 1))
66        .collect();
67
68    let mut result_dims: Vec<FdMatrix> = (0..d).map(|_| FdMatrix::zeros(n, m)).collect();
69    for i in 0..n {
70        for j in 0..m {
71            srsf_scale_point(&derivs, &mut result_dims, i, j);
72        }
73    }
74
75    FdCurveSet { dims: result_dims }
76}
77
78/// Reconstruct an R^d curve from its SRSF.
79///
80/// Given d-dimensional SRSF vectors and initial point f0, reconstructs:
81/// `f_k(t) = f0_k + ∫₀ᵗ q_k(s) · ‖q(s)‖ ds` for each dimension k.
82///
83/// # Arguments
84/// * `q` — SRSF: d vectors, each length m
85/// * `argvals` — Evaluation points (length m)
86/// * `f0` — Initial values in R^d (length d)
87///
88/// # Returns
89/// Reconstructed curve: d vectors, each length m.
90pub fn srsf_inverse_nd(q: &[Vec<f64>], argvals: &[f64], f0: &[f64]) -> Vec<Vec<f64>> {
91    let d = q.len();
92    if d == 0 {
93        return Vec::new();
94    }
95    let m = q[0].len();
96    if m == 0 {
97        return vec![Vec::new(); d];
98    }
99
100    // Compute ||q(t)|| at each time point
101    let norms: Vec<f64> = (0..m)
102        .map(|j| {
103            let norm_sq: f64 = q.iter().map(|qk| qk[j].powi(2)).sum();
104            norm_sq.sqrt()
105        })
106        .collect();
107
108    // For each dimension, integrand = q_k(t) * ||q(t)||
109    let mut result = Vec::with_capacity(d);
110    for k in 0..d {
111        let integrand: Vec<f64> = (0..m).map(|j| q[k][j] * norms[j]).collect();
112        let integral = cumulative_trapz(&integrand, argvals);
113        let curve: Vec<f64> = integral.iter().map(|&v| f0[k] + v).collect();
114        result.push(curve);
115    }
116
117    result
118}
119
120/// Core DP alignment for R^d SRSFs.
121///
122/// Same DP grid and coprime neighborhood as `dp_alignment_core`, but edge weight
123/// is the sum of `dp_edge_weight` over d dimensions.
124fn dp_alignment_core_nd(
125    q1: &[Vec<f64>],
126    q2: &[Vec<f64>],
127    argvals: &[f64],
128    lambda: f64,
129) -> Vec<f64> {
130    let d = q1.len();
131    let m = argvals.len();
132    if m < 2 || d == 0 {
133        return argvals.to_vec();
134    }
135
136    // For d=1, delegate to existing implementation for exact backward compat
137    if d == 1 {
138        return dp_alignment_core(&q1[0], &q2[0], argvals, lambda);
139    }
140
141    // Normalize each dimension's SRSF to unit L2 norm
142    let q1n: Vec<Vec<f64>> = q1
143        .iter()
144        .map(|qk| {
145            let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
146            qk.iter().map(|&v| v / norm).collect()
147        })
148        .collect();
149    let q2n: Vec<Vec<f64>> = q2
150        .iter()
151        .map(|qk| {
152            let norm = qk.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
153            qk.iter().map(|&v| v / norm).collect()
154        })
155        .collect();
156
157    let path = dp_grid_solve(m, m, |sr, sc, tr, tc| {
158        let w: f64 = (0..d)
159            .map(|k| dp_edge_weight(&q1n[k], &q2n[k], argvals, sc, tc, sr, tr))
160            .sum();
161        w + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
162    });
163
164    dp_path_to_gamma(&path, argvals)
165}
166
167/// Align an R^d curve f2 to f1 using the elastic framework.
168///
169/// Finds the optimal warping γ (shared across all dimensions) such that
170/// f2∘γ is as close as possible to f1 in the elastic metric.
171///
172/// # Arguments
173/// * `f1` — Target curves (d dimensions)
174/// * `f2` — Curves to align (d dimensions)
175/// * `argvals` — Evaluation points (length m)
176/// * `lambda` — Penalty weight (0.0 = no penalty)
177pub fn elastic_align_pair_nd(
178    f1: &FdCurveSet,
179    f2: &FdCurveSet,
180    argvals: &[f64],
181    lambda: f64,
182) -> AlignmentResultNd {
183    let d = f1.ndim();
184    let m = f1.npoints();
185
186    // Compute SRSFs
187    let q1_set = srsf_transform_nd(f1, argvals);
188    let q2_set = srsf_transform_nd(f2, argvals);
189
190    // Extract first curve from each dimension
191    let q1: Vec<Vec<f64>> = q1_set.dims.iter().map(|dm| dm.row(0)).collect();
192    let q2: Vec<Vec<f64>> = q2_set.dims.iter().map(|dm| dm.row(0)).collect();
193
194    // DP alignment using summed cost over dimensions
195    let gamma = dp_alignment_core_nd(&q1, &q2, argvals, lambda);
196
197    // Apply warping to f2 in each dimension
198    let f_aligned: Vec<Vec<f64>> = f2
199        .dims
200        .iter()
201        .map(|dm| {
202            let row = dm.row(0);
203            reparameterize_curve(&row, argvals, &gamma)
204        })
205        .collect();
206
207    // Compute elastic distance: sum of squared L2 distances between aligned SRSFs
208    let f_aligned_set = {
209        let dims: Vec<FdMatrix> = f_aligned
210            .iter()
211            .map(|fa| {
212                FdMatrix::from_slice(fa, 1, m).expect("dimension invariant: data.len() == n * m")
213            })
214            .collect();
215        FdCurveSet { dims }
216    };
217    let q_aligned = srsf_transform_nd(&f_aligned_set, argvals);
218    let weights = simpsons_weights(argvals);
219
220    let mut dist_sq = 0.0;
221    for k in 0..d {
222        let q1k = q1_set.dims[k].row(0);
223        let qak = q_aligned.dims[k].row(0);
224        let d_k = l2_distance(&q1k, &qak, &weights);
225        dist_sq += d_k * d_k;
226    }
227
228    AlignmentResultNd {
229        gamma,
230        f_aligned,
231        distance: dist_sq.sqrt(),
232    }
233}
234
235/// Elastic distance between two R^d curves.
236///
237/// Aligns f2 to f1 and returns the post-alignment SRSF distance.
238pub fn elastic_distance_nd(f1: &FdCurveSet, f2: &FdCurveSet, argvals: &[f64], lambda: f64) -> f64 {
239    elastic_align_pair_nd(f1, f2, argvals, lambda).distance
240}