Skip to main content

fdars_core/alignment/
constrained.rs

1//! Landmark-constrained elastic alignment.
2
3use super::pairwise::elastic_align_pair;
4use super::srsf::{reparameterize_curve, srsf_transform};
5use super::{dp_edge_weight, dp_grid_solve, dp_lambda_penalty};
6use crate::helpers::{l2_distance, linear_interp, simpsons_weights};
7use crate::matrix::FdMatrix;
8use crate::warping::normalize_warp;
9
/// Outcome of an elastic alignment that was forced through landmark pairs.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ConstrainedAlignmentResult {
    /// Warping function γ sampled on the evaluation grid (length m).
    pub gamma: Vec<f64>,
    /// The warped source curve f2∘γ, sampled on the same grid (length m).
    pub f_aligned: Vec<f64>,
    /// Elastic distance measured once the source has been aligned.
    pub distance: f64,
    /// Landmark pairs that were actually enforced, expressed as
    /// `(target_t, source_t)` grid values (i.e. after snapping to the grid).
    pub enforced_landmarks: Vec<(f64, f64)>,
}
23
/// Return the index of the entry of `argvals` closest to `t_val`.
///
/// Closeness is the absolute difference. The comparison is strict, so when
/// two grid points are equally close the one with the lower index wins.
///
/// # Panics
/// Panics if `argvals` is empty.
fn snap_to_grid(t_val: f64, argvals: &[f64]) -> usize {
    let gap = |a: f64| (t_val - a).abs();

    // Seed the search with the first grid point, then keep whichever
    // candidate is strictly closer as we walk the remaining points.
    let seed = (0usize, gap(argvals[0]));
    let (nearest, _) = argvals
        .iter()
        .enumerate()
        .skip(1)
        .fold(seed, |(idx, dist), (i, &a)| {
            let d = gap(a);
            if d < dist {
                (i, d)
            } else {
                (idx, dist)
            }
        });
    nearest
}
37
38/// Run DP on a rectangular sub-grid `[sc..=ec] × [sr..=er]`.
39///
40/// Uses global indices for `dp_edge_weight`. Returns the path segment
41/// as a list of `(tc_idx, tr_idx)` pairs from start to end.
42fn dp_segment(
43    q1: &[f64],
44    q2: &[f64],
45    argvals: &[f64],
46    sc: usize,
47    ec: usize,
48    sr: usize,
49    er: usize,
50    lambda: f64,
51) -> Vec<(usize, usize)> {
52    let nc = ec - sc + 1;
53    let nr = er - sr + 1;
54
55    if nc <= 1 || nr <= 1 {
56        return vec![(sc, sr), (ec, er)];
57    }
58
59    let path = dp_grid_solve(nr, nc, |local_sr, local_sc, local_tr, local_tc| {
60        let gsr = sr + local_sr;
61        let gsc = sc + local_sc;
62        let gtr = sr + local_tr;
63        let gtc = sc + local_tc;
64        dp_edge_weight(q1, q2, argvals, gsc, gtc, gsr, gtr)
65            + dp_lambda_penalty(argvals, gsc, gtc, gsr, gtr, lambda)
66    });
67
68    // Convert local indices to global
69    path.iter().map(|&(lr, lc)| (sc + lc, sr + lr)).collect()
70}
71
72/// Build DP waypoints from landmark pairs: snap to grid, deduplicate, add endpoints.
73fn build_constrained_waypoints(
74    landmark_pairs: &[(f64, f64)],
75    argvals: &[f64],
76    m: usize,
77) -> Vec<(usize, usize)> {
78    let mut waypoints: Vec<(usize, usize)> = Vec::with_capacity(landmark_pairs.len() + 2);
79    waypoints.push((0, 0));
80    for &(tt, st) in landmark_pairs {
81        let tc = snap_to_grid(tt, argvals);
82        let tr = snap_to_grid(st, argvals);
83        if let Some(&(prev_c, prev_r)) = waypoints.last() {
84            if tc > prev_c && tr > prev_r {
85                waypoints.push((tc, tr));
86            }
87        }
88    }
89    let last = m - 1;
90    if let Some(&(prev_c, prev_r)) = waypoints.last() {
91        if prev_c != last || prev_r != last {
92            waypoints.push((last, last));
93        }
94    }
95    waypoints
96}
97
98/// Run DP segments between consecutive waypoints and assemble into a gamma warp.
99fn segmented_dp_gamma(
100    q1n: &[f64],
101    q2n: &[f64],
102    argvals: &[f64],
103    waypoints: &[(usize, usize)],
104    lambda: f64,
105) -> Vec<f64> {
106    let mut full_path_tc: Vec<f64> = Vec::new();
107    let mut full_path_tr: Vec<f64> = Vec::new();
108
109    for seg in 0..(waypoints.len() - 1) {
110        let (sc, sr) = waypoints[seg];
111        let (ec, er) = waypoints[seg + 1];
112        let segment_path = dp_segment(q1n, q2n, argvals, sc, ec, sr, er, lambda);
113        let start = usize::from(seg > 0);
114        for &(tc, tr) in &segment_path[start..] {
115            full_path_tc.push(argvals[tc]);
116            full_path_tr.push(argvals[tr]);
117        }
118    }
119
120    let mut gamma: Vec<f64> = argvals
121        .iter()
122        .map(|&t| linear_interp(&full_path_tc, &full_path_tr, t))
123        .collect();
124    normalize_warp(&mut gamma, argvals);
125    gamma
126}
127
128/// Align f2 to f1 with landmark constraints.
129///
130/// Landmark pairs define waypoints on the DP grid. Between consecutive waypoints,
131/// an independent smaller DP is run. The resulting warp passes through all landmarks.
132///
133/// # Arguments
134/// * `f1` — Target curve (length m)
135/// * `f2` — Curve to align (length m)
136/// * `argvals` — Evaluation points (length m)
137/// * `landmark_pairs` — `(target_t, source_t)` pairs in increasing order
138/// * `lambda` — Penalty weight
139///
140/// # Returns
141/// [`ConstrainedAlignmentResult`] with warp, aligned curve, and enforced landmarks.
142pub fn elastic_align_pair_constrained(
143    f1: &[f64],
144    f2: &[f64],
145    argvals: &[f64],
146    landmark_pairs: &[(f64, f64)],
147    lambda: f64,
148) -> ConstrainedAlignmentResult {
149    let m = f1.len();
150
151    if landmark_pairs.is_empty() {
152        let r = elastic_align_pair(f1, f2, argvals, lambda);
153        return ConstrainedAlignmentResult {
154            gamma: r.gamma,
155            f_aligned: r.f_aligned,
156            distance: r.distance,
157            enforced_landmarks: Vec::new(),
158        };
159    }
160
161    // Compute & normalize SRSFs
162    let f1_mat = FdMatrix::from_slice(f1, 1, m).expect("dimension invariant: data.len() == n * m");
163    let f2_mat = FdMatrix::from_slice(f2, 1, m).expect("dimension invariant: data.len() == n * m");
164    let q1_mat = srsf_transform(&f1_mat, argvals);
165    let q2_mat = srsf_transform(&f2_mat, argvals);
166    let q1: Vec<f64> = q1_mat.row(0);
167    let q2: Vec<f64> = q2_mat.row(0);
168    let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
169    let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
170    let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
171    let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
172
173    let waypoints = build_constrained_waypoints(landmark_pairs, argvals, m);
174    let gamma = segmented_dp_gamma(&q1n, &q2n, argvals, &waypoints, lambda);
175
176    let f_aligned = reparameterize_curve(f2, argvals, &gamma);
177    let f_aligned_mat =
178        FdMatrix::from_slice(&f_aligned, 1, m).expect("dimension invariant: data.len() == n * m");
179    let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
180    let q_aligned: Vec<f64> = q_aligned_mat.row(0);
181    let weights = simpsons_weights(argvals);
182    let distance = l2_distance(&q1, &q_aligned, &weights);
183
184    let enforced: Vec<(f64, f64)> = waypoints[1..waypoints.len() - 1]
185        .iter()
186        .map(|&(tc, tr)| (argvals[tc], argvals[tr]))
187        .collect();
188
189    ConstrainedAlignmentResult {
190        gamma,
191        f_aligned,
192        distance,
193        enforced_landmarks: enforced,
194    }
195}
196
197/// Align f2 to f1 with automatic landmark detection and elastic constraints.
198///
199/// Detects landmarks in both curves, matches them, and uses the matches
200/// as constraints for segmented DP alignment.
201///
202/// # Arguments
203/// * `f1` — Target curve (length m)
204/// * `f2` — Curve to align (length m)
205/// * `argvals` — Evaluation points (length m)
206/// * `kind` — Type of landmarks to detect
207/// * `min_prominence` — Minimum prominence for landmark detection
208/// * `expected_count` — Expected number of landmarks (0 = all detected)
209/// * `lambda` — Penalty weight
210pub fn elastic_align_pair_with_landmarks(
211    f1: &[f64],
212    f2: &[f64],
213    argvals: &[f64],
214    kind: crate::landmark::LandmarkKind,
215    min_prominence: f64,
216    expected_count: usize,
217    lambda: f64,
218) -> ConstrainedAlignmentResult {
219    let lm1 = crate::landmark::detect_landmarks(f1, argvals, kind, min_prominence);
220    let lm2 = crate::landmark::detect_landmarks(f2, argvals, kind, min_prominence);
221
222    // Match landmarks by order (take min count)
223    let n_match = if expected_count > 0 {
224        expected_count.min(lm1.len()).min(lm2.len())
225    } else {
226        lm1.len().min(lm2.len())
227    };
228
229    let pairs: Vec<(f64, f64)> = (0..n_match)
230        .map(|i| (lm1[i].position, lm2[i].position))
231        .collect();
232
233    elastic_align_pair_constrained(f1, f2, argvals, &pairs, lambda)
234}