1use super::pairwise::elastic_align_pair;
4use super::srsf::{reparameterize_curve, srsf_transform};
5use super::{dp_edge_weight, dp_grid_solve, dp_lambda_penalty};
6use crate::helpers::{l2_distance, linear_interp, simpsons_weights};
7use crate::matrix::FdMatrix;
8use crate::warping::normalize_warp;
9
/// Outcome of a (possibly landmark-constrained) elastic alignment of `f2` onto `f1`.
#[derive(Debug, Clone, PartialEq)]
pub struct ConstrainedAlignmentResult {
    /// Warping function evaluated at every point of the shared grid `argvals`.
    pub gamma: Vec<f64>,
    /// `f2` reparameterized by `gamma`, sampled on the same grid.
    pub f_aligned: Vec<f64>,
    /// Elastic (SRSF-space) distance between `f1` and the aligned `f2`.
    pub distance: f64,
    /// Landmark pairs `(t1, t2)` actually used as interior waypoints, snapped to grid
    /// values; `t1` lies on `f1`'s axis and `t2` on `f2`'s. Empty when unconstrained.
    pub enforced_landmarks: Vec<(f64, f64)>,
}
22
/// Return the index of the grid value in `argvals` nearest to `t_val`.
///
/// Ties are resolved in favour of the lower index (only a strictly smaller
/// distance replaces the current best).
///
/// # Panics
/// Panics if `argvals` is empty.
fn snap_to_grid(t_val: f64, argvals: &[f64]) -> usize {
    let seed_dist = (t_val - argvals[0]).abs();
    let (nearest, _) = argvals
        .iter()
        .enumerate()
        .skip(1)
        .fold((0usize, seed_dist), |(idx, dist), (i, &a)| {
            let candidate = (t_val - a).abs();
            if candidate < dist {
                (i, candidate)
            } else {
                (idx, dist)
            }
        });
    nearest
}
36
37fn dp_segment(
42 q1: &[f64],
43 q2: &[f64],
44 argvals: &[f64],
45 sc: usize,
46 ec: usize,
47 sr: usize,
48 er: usize,
49 lambda: f64,
50) -> Vec<(usize, usize)> {
51 let nc = ec - sc + 1;
52 let nr = er - sr + 1;
53
54 if nc <= 1 || nr <= 1 {
55 return vec![(sc, sr), (ec, er)];
56 }
57
58 let path = dp_grid_solve(nr, nc, |local_sr, local_sc, local_tr, local_tc| {
59 let gsr = sr + local_sr;
60 let gsc = sc + local_sc;
61 let gtr = sr + local_tr;
62 let gtc = sc + local_tc;
63 dp_edge_weight(q1, q2, argvals, gsc, gtc, gsr, gtr)
64 + dp_lambda_penalty(argvals, gsc, gtc, gsr, gtr, lambda)
65 });
66
67 path.iter().map(|&(lr, lc)| (sc + lc, sr + lr)).collect()
69}
70
71fn build_constrained_waypoints(
73 landmark_pairs: &[(f64, f64)],
74 argvals: &[f64],
75 m: usize,
76) -> Vec<(usize, usize)> {
77 let mut waypoints: Vec<(usize, usize)> = Vec::with_capacity(landmark_pairs.len() + 2);
78 waypoints.push((0, 0));
79 for &(tt, st) in landmark_pairs {
80 let tc = snap_to_grid(tt, argvals);
81 let tr = snap_to_grid(st, argvals);
82 if let Some(&(prev_c, prev_r)) = waypoints.last() {
83 if tc > prev_c && tr > prev_r {
84 waypoints.push((tc, tr));
85 }
86 }
87 }
88 let last = m - 1;
89 if let Some(&(prev_c, prev_r)) = waypoints.last() {
90 if prev_c != last || prev_r != last {
91 waypoints.push((last, last));
92 }
93 }
94 waypoints
95}
96
97fn segmented_dp_gamma(
99 q1n: &[f64],
100 q2n: &[f64],
101 argvals: &[f64],
102 waypoints: &[(usize, usize)],
103 lambda: f64,
104) -> Vec<f64> {
105 let mut full_path_tc: Vec<f64> = Vec::new();
106 let mut full_path_tr: Vec<f64> = Vec::new();
107
108 for seg in 0..(waypoints.len() - 1) {
109 let (sc, sr) = waypoints[seg];
110 let (ec, er) = waypoints[seg + 1];
111 let segment_path = dp_segment(q1n, q2n, argvals, sc, ec, sr, er, lambda);
112 let start = if seg > 0 { 1 } else { 0 };
113 for &(tc, tr) in &segment_path[start..] {
114 full_path_tc.push(argvals[tc]);
115 full_path_tr.push(argvals[tr]);
116 }
117 }
118
119 let mut gamma: Vec<f64> = argvals
120 .iter()
121 .map(|&t| linear_interp(&full_path_tc, &full_path_tr, t))
122 .collect();
123 normalize_warp(&mut gamma, argvals);
124 gamma
125}
126
127pub fn elastic_align_pair_constrained(
142 f1: &[f64],
143 f2: &[f64],
144 argvals: &[f64],
145 landmark_pairs: &[(f64, f64)],
146 lambda: f64,
147) -> ConstrainedAlignmentResult {
148 let m = f1.len();
149
150 if landmark_pairs.is_empty() {
151 let r = elastic_align_pair(f1, f2, argvals, lambda);
152 return ConstrainedAlignmentResult {
153 gamma: r.gamma,
154 f_aligned: r.f_aligned,
155 distance: r.distance,
156 enforced_landmarks: Vec::new(),
157 };
158 }
159
160 let f1_mat = FdMatrix::from_slice(f1, 1, m).expect("dimension invariant: data.len() == n * m");
162 let f2_mat = FdMatrix::from_slice(f2, 1, m).expect("dimension invariant: data.len() == n * m");
163 let q1_mat = srsf_transform(&f1_mat, argvals);
164 let q2_mat = srsf_transform(&f2_mat, argvals);
165 let q1: Vec<f64> = q1_mat.row(0);
166 let q2: Vec<f64> = q2_mat.row(0);
167 let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
168 let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
169 let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
170 let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
171
172 let waypoints = build_constrained_waypoints(landmark_pairs, argvals, m);
173 let gamma = segmented_dp_gamma(&q1n, &q2n, argvals, &waypoints, lambda);
174
175 let f_aligned = reparameterize_curve(f2, argvals, &gamma);
176 let f_aligned_mat =
177 FdMatrix::from_slice(&f_aligned, 1, m).expect("dimension invariant: data.len() == n * m");
178 let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
179 let q_aligned: Vec<f64> = q_aligned_mat.row(0);
180 let weights = simpsons_weights(argvals);
181 let distance = l2_distance(&q1, &q_aligned, &weights);
182
183 let enforced: Vec<(f64, f64)> = waypoints[1..waypoints.len() - 1]
184 .iter()
185 .map(|&(tc, tr)| (argvals[tc], argvals[tr]))
186 .collect();
187
188 ConstrainedAlignmentResult {
189 gamma,
190 f_aligned,
191 distance,
192 enforced_landmarks: enforced,
193 }
194}
195
196pub fn elastic_align_pair_with_landmarks(
210 f1: &[f64],
211 f2: &[f64],
212 argvals: &[f64],
213 kind: crate::landmark::LandmarkKind,
214 min_prominence: f64,
215 expected_count: usize,
216 lambda: f64,
217) -> ConstrainedAlignmentResult {
218 let lm1 = crate::landmark::detect_landmarks(f1, argvals, kind, min_prominence);
219 let lm2 = crate::landmark::detect_landmarks(f2, argvals, kind, min_prominence);
220
221 let n_match = if expected_count > 0 {
223 expected_count.min(lm1.len()).min(lm2.len())
224 } else {
225 lm1.len().min(lm2.len())
226 };
227
228 let pairs: Vec<(f64, f64)> = (0..n_match)
229 .map(|i| (lm1[i].position, lm2[i].position))
230 .collect();
231
232 elastic_align_pair_constrained(f1, f2, argvals, &pairs, lambda)
233}