Skip to main content

fdars_core/alignment/
constrained.rs

1//! Landmark-constrained elastic alignment.
2
3use super::pairwise::elastic_align_pair;
4use super::srsf::{reparameterize_curve, srsf_transform};
5use super::{dp_edge_weight, dp_grid_solve, dp_lambda_penalty};
6use crate::helpers::{l2_distance, linear_interp, simpsons_weights};
7use crate::matrix::FdMatrix;
8use crate::warping::normalize_warp;
9
/// Output of a landmark-constrained elastic alignment of one curve onto another.
#[derive(Clone, Debug, PartialEq)]
pub struct ConstrainedAlignmentResult {
    /// The optimal warping function, sampled at the `m` evaluation points.
    pub gamma: Vec<f64>,
    /// The warped source curve, i.e. f2∘γ, sampled at the `m` evaluation points.
    pub f_aligned: Vec<f64>,
    /// Elastic distance measured after the alignment has been applied.
    pub distance: f64,
    /// Landmark pairs that were actually enforced, after snapping to the grid.
    /// Each entry is `(target_t, source_t)`.
    pub enforced_landmarks: Vec<(f64, f64)>,
}
22
/// Return the index of the grid point in `argvals` closest to `t_val`.
///
/// The grid is scanned front to back and a later point only wins when it is
/// *strictly* closer, so ties resolve to the lowest index.
///
/// # Panics
/// Panics if `argvals` is empty.
fn snap_to_grid(t_val: f64, argvals: &[f64]) -> usize {
    let initial_gap = (t_val - argvals[0]).abs();
    let (nearest, _) = argvals.iter().enumerate().skip(1).fold(
        (0usize, initial_gap),
        |(idx, gap), (candidate, &point)| {
            let candidate_gap = (t_val - point).abs();
            if candidate_gap < gap {
                (candidate, candidate_gap)
            } else {
                (idx, gap)
            }
        },
    );
    nearest
}
36
37/// Run DP on a rectangular sub-grid `[sc..=ec] × [sr..=er]`.
38///
39/// Uses global indices for `dp_edge_weight`. Returns the path segment
40/// as a list of `(tc_idx, tr_idx)` pairs from start to end.
41fn dp_segment(
42    q1: &[f64],
43    q2: &[f64],
44    argvals: &[f64],
45    sc: usize,
46    ec: usize,
47    sr: usize,
48    er: usize,
49    lambda: f64,
50) -> Vec<(usize, usize)> {
51    let nc = ec - sc + 1;
52    let nr = er - sr + 1;
53
54    if nc <= 1 || nr <= 1 {
55        return vec![(sc, sr), (ec, er)];
56    }
57
58    let path = dp_grid_solve(nr, nc, |local_sr, local_sc, local_tr, local_tc| {
59        let gsr = sr + local_sr;
60        let gsc = sc + local_sc;
61        let gtr = sr + local_tr;
62        let gtc = sc + local_tc;
63        dp_edge_weight(q1, q2, argvals, gsc, gtc, gsr, gtr)
64            + dp_lambda_penalty(argvals, gsc, gtc, gsr, gtr, lambda)
65    });
66
67    // Convert local indices to global
68    path.iter().map(|&(lr, lc)| (sc + lc, sr + lr)).collect()
69}
70
71/// Build DP waypoints from landmark pairs: snap to grid, deduplicate, add endpoints.
72fn build_constrained_waypoints(
73    landmark_pairs: &[(f64, f64)],
74    argvals: &[f64],
75    m: usize,
76) -> Vec<(usize, usize)> {
77    let mut waypoints: Vec<(usize, usize)> = Vec::with_capacity(landmark_pairs.len() + 2);
78    waypoints.push((0, 0));
79    for &(tt, st) in landmark_pairs {
80        let tc = snap_to_grid(tt, argvals);
81        let tr = snap_to_grid(st, argvals);
82        if let Some(&(prev_c, prev_r)) = waypoints.last() {
83            if tc > prev_c && tr > prev_r {
84                waypoints.push((tc, tr));
85            }
86        }
87    }
88    let last = m - 1;
89    if let Some(&(prev_c, prev_r)) = waypoints.last() {
90        if prev_c != last || prev_r != last {
91            waypoints.push((last, last));
92        }
93    }
94    waypoints
95}
96
97/// Run DP segments between consecutive waypoints and assemble into a gamma warp.
98fn segmented_dp_gamma(
99    q1n: &[f64],
100    q2n: &[f64],
101    argvals: &[f64],
102    waypoints: &[(usize, usize)],
103    lambda: f64,
104) -> Vec<f64> {
105    let mut full_path_tc: Vec<f64> = Vec::new();
106    let mut full_path_tr: Vec<f64> = Vec::new();
107
108    for seg in 0..(waypoints.len() - 1) {
109        let (sc, sr) = waypoints[seg];
110        let (ec, er) = waypoints[seg + 1];
111        let segment_path = dp_segment(q1n, q2n, argvals, sc, ec, sr, er, lambda);
112        let start = if seg > 0 { 1 } else { 0 };
113        for &(tc, tr) in &segment_path[start..] {
114            full_path_tc.push(argvals[tc]);
115            full_path_tr.push(argvals[tr]);
116        }
117    }
118
119    let mut gamma: Vec<f64> = argvals
120        .iter()
121        .map(|&t| linear_interp(&full_path_tc, &full_path_tr, t))
122        .collect();
123    normalize_warp(&mut gamma, argvals);
124    gamma
125}
126
127/// Align f2 to f1 with landmark constraints.
128///
129/// Landmark pairs define waypoints on the DP grid. Between consecutive waypoints,
130/// an independent smaller DP is run. The resulting warp passes through all landmarks.
131///
132/// # Arguments
133/// * `f1` — Target curve (length m)
134/// * `f2` — Curve to align (length m)
135/// * `argvals` — Evaluation points (length m)
136/// * `landmark_pairs` — `(target_t, source_t)` pairs in increasing order
137/// * `lambda` — Penalty weight
138///
139/// # Returns
140/// [`ConstrainedAlignmentResult`] with warp, aligned curve, and enforced landmarks.
141pub fn elastic_align_pair_constrained(
142    f1: &[f64],
143    f2: &[f64],
144    argvals: &[f64],
145    landmark_pairs: &[(f64, f64)],
146    lambda: f64,
147) -> ConstrainedAlignmentResult {
148    let m = f1.len();
149
150    if landmark_pairs.is_empty() {
151        let r = elastic_align_pair(f1, f2, argvals, lambda);
152        return ConstrainedAlignmentResult {
153            gamma: r.gamma,
154            f_aligned: r.f_aligned,
155            distance: r.distance,
156            enforced_landmarks: Vec::new(),
157        };
158    }
159
160    // Compute & normalize SRSFs
161    let f1_mat = FdMatrix::from_slice(f1, 1, m).expect("dimension invariant: data.len() == n * m");
162    let f2_mat = FdMatrix::from_slice(f2, 1, m).expect("dimension invariant: data.len() == n * m");
163    let q1_mat = srsf_transform(&f1_mat, argvals);
164    let q2_mat = srsf_transform(&f2_mat, argvals);
165    let q1: Vec<f64> = q1_mat.row(0);
166    let q2: Vec<f64> = q2_mat.row(0);
167    let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
168    let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
169    let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
170    let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
171
172    let waypoints = build_constrained_waypoints(landmark_pairs, argvals, m);
173    let gamma = segmented_dp_gamma(&q1n, &q2n, argvals, &waypoints, lambda);
174
175    let f_aligned = reparameterize_curve(f2, argvals, &gamma);
176    let f_aligned_mat =
177        FdMatrix::from_slice(&f_aligned, 1, m).expect("dimension invariant: data.len() == n * m");
178    let q_aligned_mat = srsf_transform(&f_aligned_mat, argvals);
179    let q_aligned: Vec<f64> = q_aligned_mat.row(0);
180    let weights = simpsons_weights(argvals);
181    let distance = l2_distance(&q1, &q_aligned, &weights);
182
183    let enforced: Vec<(f64, f64)> = waypoints[1..waypoints.len() - 1]
184        .iter()
185        .map(|&(tc, tr)| (argvals[tc], argvals[tr]))
186        .collect();
187
188    ConstrainedAlignmentResult {
189        gamma,
190        f_aligned,
191        distance,
192        enforced_landmarks: enforced,
193    }
194}
195
196/// Align f2 to f1 with automatic landmark detection and elastic constraints.
197///
198/// Detects landmarks in both curves, matches them, and uses the matches
199/// as constraints for segmented DP alignment.
200///
201/// # Arguments
202/// * `f1` — Target curve (length m)
203/// * `f2` — Curve to align (length m)
204/// * `argvals` — Evaluation points (length m)
205/// * `kind` — Type of landmarks to detect
206/// * `min_prominence` — Minimum prominence for landmark detection
207/// * `expected_count` — Expected number of landmarks (0 = all detected)
208/// * `lambda` — Penalty weight
209pub fn elastic_align_pair_with_landmarks(
210    f1: &[f64],
211    f2: &[f64],
212    argvals: &[f64],
213    kind: crate::landmark::LandmarkKind,
214    min_prominence: f64,
215    expected_count: usize,
216    lambda: f64,
217) -> ConstrainedAlignmentResult {
218    let lm1 = crate::landmark::detect_landmarks(f1, argvals, kind, min_prominence);
219    let lm2 = crate::landmark::detect_landmarks(f2, argvals, kind, min_prominence);
220
221    // Match landmarks by order (take min count)
222    let n_match = if expected_count > 0 {
223        expected_count.min(lm1.len()).min(lm2.len())
224    } else {
225        lm1.len().min(lm2.len())
226    };
227
228    let pairs: Vec<(f64, f64)> = (0..n_match)
229        .map(|i| (lm1[i].position, lm2[i].position))
230        .collect();
231
232    elastic_align_pair_constrained(f1, f2, argvals, &pairs, lambda)
233}