1use super::pairwise::elastic_align_pair;
4use super::srsf::{reparameterize_curve, srsf_transform};
5use super::{dp_edge_weight, dp_grid_solve, dp_lambda_penalty};
6use crate::helpers::{l2_distance, linear_interp, simpsons_weights};
7use crate::matrix::FdMatrix;
8use crate::warping::normalize_warp;
9
/// Outcome of a landmark-constrained pairwise elastic alignment.
///
/// Returned by `elastic_align_pair_constrained` and
/// `elastic_align_pair_with_landmarks`.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct ConstrainedAlignmentResult {
    /// Warping function, one value per evaluation grid point.
    pub gamma: Vec<f64>,
    /// The second curve (`f2`) after reparameterisation by `gamma`.
    pub f_aligned: Vec<f64>,
    /// L2 distance between the SRSF of `f1` and the SRSF of `f_aligned`.
    pub distance: f64,
    /// Landmark pairs that were actually used as interior waypoints, reported
    /// as grid values `(position on f1's axis, position on f2's axis)` after
    /// snapping to the evaluation grid. Empty when no landmark was enforced.
    pub enforced_landmarks: Vec<(f64, f64)>,
}
23
/// Returns the index of the entry of `argvals` closest to `t_val`.
///
/// Distances are compared with a strict `<`, so on a tie the lowest index
/// wins, and a candidate whose distance is NaN never replaces the current
/// best.
///
/// # Panics
///
/// Panics if `argvals` is empty.
fn snap_to_grid(t_val: f64, argvals: &[f64]) -> usize {
    // Seed with the first grid point; indexing keeps the empty-slice panic.
    let seed = (0_usize, (t_val - argvals[0]).abs());

    let (nearest, _gap) = argvals
        .iter()
        .enumerate()
        .skip(1)
        .fold(seed, |(idx, gap), (i, &node)| {
            let candidate = (t_val - node).abs();
            if candidate < gap {
                (i, candidate)
            } else {
                (idx, gap)
            }
        });

    nearest
}
37
38fn dp_segment(
43 q1: &[f64],
44 q2: &[f64],
45 argvals: &[f64],
46 sc: usize,
47 ec: usize,
48 sr: usize,
49 er: usize,
50 lambda: f64,
51) -> Vec<(usize, usize)> {
52 let nc = ec - sc + 1;
53 let nr = er - sr + 1;
54
55 if nc <= 1 || nr <= 1 {
56 return vec![(sc, sr), (ec, er)];
57 }
58
59 let path = dp_grid_solve(nr, nc, |local_sr, local_sc, local_tr, local_tc| {
60 let gsr = sr + local_sr;
61 let gsc = sc + local_sc;
62 let gtr = sr + local_tr;
63 let gtc = sc + local_tc;
64 dp_edge_weight(q1, q2, argvals, gsc, gtc, gsr, gtr)
65 + dp_lambda_penalty(argvals, gsc, gtc, gsr, gtr, lambda)
66 });
67
68 path.iter().map(|&(lr, lc)| (sc + lc, sr + lr)).collect()
70}
71
72fn build_constrained_waypoints(
74 landmark_pairs: &[(f64, f64)],
75 argvals: &[f64],
76 m: usize,
77) -> Vec<(usize, usize)> {
78 let mut waypoints: Vec<(usize, usize)> = Vec::with_capacity(landmark_pairs.len() + 2);
79 waypoints.push((0, 0));
80 for &(tt, st) in landmark_pairs {
81 let tc = snap_to_grid(tt, argvals);
82 let tr = snap_to_grid(st, argvals);
83 if let Some(&(prev_c, prev_r)) = waypoints.last() {
84 if tc > prev_c && tr > prev_r {
85 waypoints.push((tc, tr));
86 }
87 }
88 }
89 let last = m - 1;
90 if let Some(&(prev_c, prev_r)) = waypoints.last() {
91 if prev_c != last || prev_r != last {
92 waypoints.push((last, last));
93 }
94 }
95 waypoints
96}
97
98fn segmented_dp_gamma(
100 q1n: &[f64],
101 q2n: &[f64],
102 argvals: &[f64],
103 waypoints: &[(usize, usize)],
104 lambda: f64,
105) -> Vec<f64> {
106 let mut full_path_tc: Vec<f64> = Vec::new();
107 let mut full_path_tr: Vec<f64> = Vec::new();
108
109 for seg in 0..(waypoints.len() - 1) {
110 let (sc, sr) = waypoints[seg];
111 let (ec, er) = waypoints[seg + 1];
112 let segment_path = dp_segment(q1n, q2n, argvals, sc, ec, sr, er, lambda);
113 let start = usize::from(seg > 0);
114 for &(tc, tr) in &segment_path[start..] {
115 full_path_tc.push(argvals[tc]);
116 full_path_tr.push(argvals[tr]);
117 }
118 }
119
120 let mut gamma: Vec<f64> = argvals
121 .iter()
122 .map(|&t| linear_interp(&full_path_tc, &full_path_tr, t))
123 .collect();
124 normalize_warp(&mut gamma, argvals);
125 gamma
126}
127
128pub fn elastic_align_pair_constrained(
143 f1: &[f64],
144 f2: &[f64],
145 argvals: &[f64],
146 landmark_pairs: &[(f64, f64)],
147 lambda: f64,
148) -> ConstrainedAlignmentResult {
149 let m = f1.len();
150
151 if landmark_pairs.is_empty() {
152 let r = elastic_align_pair(f1, f2, argvals, lambda);
153 return ConstrainedAlignmentResult {
154 gamma: r.gamma,
155 f_aligned: r.f_aligned,
156 distance: r.distance,
157 enforced_landmarks: Vec::new(),
158 };
159 }
160
161 let f1_mat = FdMatrix::from_slice(f1, 1, m).expect("dimension invariant: data.len() == n * m");
163 let f2_mat = FdMatrix::from_slice(f2, 1, m).expect("dimension invariant: data.len() == n * m");
164 let q1_mat = srsf_transform(&f1_mat, argvals);
165 let q2_mat = srsf_transform(&f2_mat, argvals);
166 let q1: Vec<f64> = q1_mat.row(0);
167 let q2: Vec<f64> = q2_mat.row(0);
168 let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
169 let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
170 let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
171 let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
172
173 let waypoints = build_constrained_waypoints(landmark_pairs, argvals, m);
174 let gamma = segmented_dp_gamma(&q1n, &q2n, argvals, &waypoints, lambda);
175
176 let f_aligned = reparameterize_curve(f2, argvals, &gamma);
177 let f_aligned_mat =
178 FdMatrix::from_slice(&f_aligned, 1, m).expect("dimension invariant: data.len() == n * m");
179 let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
180 let q_aligned: Vec<f64> = q_aligned_mat.row(0);
181 let weights = simpsons_weights(argvals);
182 let distance = l2_distance(&q1, &q_aligned, &weights);
183
184 let enforced: Vec<(f64, f64)> = waypoints[1..waypoints.len() - 1]
185 .iter()
186 .map(|&(tc, tr)| (argvals[tc], argvals[tr]))
187 .collect();
188
189 ConstrainedAlignmentResult {
190 gamma,
191 f_aligned,
192 distance,
193 enforced_landmarks: enforced,
194 }
195}
196
197pub fn elastic_align_pair_with_landmarks(
211 f1: &[f64],
212 f2: &[f64],
213 argvals: &[f64],
214 kind: crate::landmark::LandmarkKind,
215 min_prominence: f64,
216 expected_count: usize,
217 lambda: f64,
218) -> ConstrainedAlignmentResult {
219 let lm1 = crate::landmark::detect_landmarks(f1, argvals, kind, min_prominence);
220 let lm2 = crate::landmark::detect_landmarks(f2, argvals, kind, min_prominence);
221
222 let n_match = if expected_count > 0 {
224 expected_count.min(lm1.len()).min(lm2.len())
225 } else {
226 lm1.len().min(lm2.len())
227 };
228
229 let pairs: Vec<(f64, f64)> = (0..n_match)
230 .map(|i| (lm1[i].position, lm2[i].position))
231 .collect();
232
233 elastic_align_pair_constrained(f1, f2, argvals, &pairs, lambda)
234}