Skip to main content

fdars_core/alignment/
partial_match.rs

1//! Elastic partial matching: find the best-aligned subcurve of a longer curve.
2//!
3//! Standard elastic alignment requires both curves to span the full domain.
4//! Partial matching relaxes this: given a template curve and a longer curve,
5//! it finds the contiguous subdomain of the longer curve that best matches
6//! the template in the elastic metric.
7
8use super::srsf::srsf_single;
9use super::{dp_edge_weight, dp_lambda_penalty, dp_path_to_gamma};
10use crate::error::FdarError;
11use crate::helpers::{l2_distance, simpsons_weights};
12
13// ─── Types ──────────────────────────────────────────────────────────────────
14
/// Tuning knobs for elastic partial matching.
///
/// Build one with [`Default::default`] and override individual fields with
/// struct-update syntax, e.g.
/// `PartialMatchConfig { min_span: 0.3, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct PartialMatchConfig {
    /// Weight of the roughness penalty used during the dynamic-programming
    /// alignment; `0.0` disables the penalty.
    pub lambda: f64,
    /// Lower bound on how much of the target curve a candidate match may
    /// cover, expressed as a fraction in `(0, 1]`. Defaults to `0.5`.
    pub min_span: f64,
}

impl Default for PartialMatchConfig {
    /// No roughness penalty and a minimum span of half the target.
    fn default() -> Self {
        let lambda = 0.0;
        let min_span = 0.5;
        Self { lambda, min_span }
    }
}
33
/// Outcome of elastic partial matching.
///
/// `start_index..=end_index` identifies the winning window on the target
/// grid; `gamma` is the elastic warp found for that window.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct PartialMatchResult {
    /// Index into the target grid at which the best-matching window starts.
    pub start_index: usize,
    /// Index into the target grid at which the best-matching window ends
    /// (inclusive).
    pub end_index: usize,
    /// Warping function sampled on the template grid (`argvals_template`).
    ///
    /// Its values lie in the template domain, not in target coordinates.
    /// The matched window is resampled at `m_t` equally spaced points
    /// between `argvals_target[start_index]` and `argvals_target[end_index]`.
    /// That resampled curve is then treated as living on the template grid,
    /// and `gamma` warps it onto the template. If no window produced a usable
    /// distance, this is the template grid itself (identity warp).
    pub gamma: Vec<f64>,
    /// Elastic distance of the best match.
    ///
    /// This is the quadrature-weighted L2 distance between the template's
    /// SRSF and the SRSF of the warped window. It is `f64::INFINITY` if no
    /// candidate window produced a distance below infinity.
    pub distance: f64,
    /// Length of the matched window in target argvals,
    /// `argvals_target[end_index] - argvals_target[start_index]`, divided by
    /// the length of the whole target domain.
    ///
    /// It is `1.0` when the target domain has zero length.
    pub domain_fraction: f64,
}
49
50// ─── Public API ─────────────────────────────────────────────────────────────
51
52/// Find the best elastic partial match of `template` within `target`.
53///
54/// Slides a variable-length window over the target curve and performs
55/// elastic alignment of the template to each window, returning the
56/// window position and warp with minimum elastic distance.
57///
58/// # Arguments
59/// * `template` — Short template curve (length m_t)
60/// * `target` — Longer target curve (length m_f)
61/// * `argvals_template` — Evaluation points for the template (length m_t)
62/// * `argvals_target` — Evaluation points for the target (length m_f)
63/// * `config` — Partial matching configuration
64///
65/// # Errors
66/// Returns [`FdarError::InvalidDimension`] if lengths are inconsistent.
67/// Returns [`FdarError::InvalidParameter`] if `min_span` is not in (0, 1].
68#[must_use = "expensive computation whose result should not be discarded"]
69pub fn elastic_partial_match(
70    template: &[f64],
71    target: &[f64],
72    argvals_template: &[f64],
73    argvals_target: &[f64],
74    config: &PartialMatchConfig,
75) -> Result<PartialMatchResult, FdarError> {
76    let m_t = template.len();
77    let m_f = target.len();
78
79    if m_t != argvals_template.len() {
80        return Err(FdarError::InvalidDimension {
81            parameter: "argvals_template",
82            expected: format!("{m_t}"),
83            actual: format!("{}", argvals_template.len()),
84        });
85    }
86    if m_f != argvals_target.len() {
87        return Err(FdarError::InvalidDimension {
88            parameter: "argvals_target",
89            expected: format!("{m_f}"),
90            actual: format!("{}", argvals_target.len()),
91        });
92    }
93    if m_t < 2 || m_f < 2 {
94        return Err(FdarError::InvalidDimension {
95            parameter: "template/target",
96            expected: "length >= 2".to_string(),
97            actual: format!("template={m_t}, target={m_f}"),
98        });
99    }
100    if config.min_span <= 0.0 || config.min_span > 1.0 {
101        return Err(FdarError::InvalidParameter {
102            parameter: "min_span",
103            message: format!("must be in (0, 1], got {}", config.min_span),
104        });
105    }
106
107    let q_template = srsf_single(template, argvals_template);
108
109    // Minimum window size (in grid points) based on min_span
110    let min_window = ((m_f as f64 * config.min_span).ceil() as usize).max(2);
111
112    let mut best_start = 0;
113    let mut best_end = m_f - 1;
114    let mut best_dist = f64::INFINITY;
115    let mut best_gamma = argvals_template.to_vec();
116
117    // Iterate over window sizes from min_window to m_f
118    // Use a coarse grid of window sizes for efficiency
119    let n_sizes = 5.min(m_f - min_window + 1);
120    let sizes: Vec<usize> = if n_sizes <= 1 {
121        vec![m_f]
122    } else {
123        (0..n_sizes)
124            .map(|i| min_window + i * (m_f - min_window) / (n_sizes - 1))
125            .collect()
126    };
127
128    for &win_size in &sizes {
129        let step = (win_size / 10).max(1);
130        let mut start = 0;
131        while start + win_size <= m_f {
132            let end = start + win_size - 1;
133
134            // Extract sub-argvals and sub-curve
135            let sub_argvals: Vec<f64> = (0..m_t)
136                .map(|i| {
137                    argvals_target[start]
138                        + (argvals_target[end] - argvals_target[start]) * i as f64
139                            / (m_t - 1) as f64
140                })
141                .collect();
142
143            // Interpolate target onto sub_argvals
144            let sub_target: Vec<f64> = sub_argvals
145                .iter()
146                .map(|&t| interp_target(target, argvals_target, t))
147                .collect();
148
149            let q_sub = srsf_single(&sub_target, argvals_template);
150
151            // DP alignment on the shared template grid
152            let gamma = dp_align_partial(&q_template, &q_sub, argvals_template, config.lambda);
153
154            // Compute distance
155            let sub_aligned: Vec<f64> = argvals_template
156                .iter()
157                .map(|&t| {
158                    interp_target(
159                        &sub_target,
160                        argvals_template,
161                        interp_target(&gamma, argvals_template, t),
162                    )
163                })
164                .collect();
165            let q_aligned = srsf_single(&sub_aligned, argvals_template);
166            let weights = simpsons_weights(argvals_template);
167            let dist = l2_distance(&q_template, &q_aligned, &weights);
168
169            if dist < best_dist {
170                best_dist = dist;
171                best_start = start;
172                best_end = end;
173                best_gamma = gamma;
174            }
175
176            start += step;
177        }
178    }
179
180    let total_domain = argvals_target[m_f - 1] - argvals_target[0];
181    let match_domain = argvals_target[best_end] - argvals_target[best_start];
182    let domain_fraction = if total_domain > 0.0 {
183        match_domain / total_domain
184    } else {
185        1.0
186    };
187
188    Ok(PartialMatchResult {
189        start_index: best_start,
190        end_index: best_end,
191        gamma: best_gamma,
192        distance: best_dist,
193        domain_fraction,
194    })
195}
196
197// ─── Helpers ────────────────────────────────────────────────────────────────
198
/// Piecewise-linear interpolation at `t` of the samples `values`, which are
/// taken at the points of `grid`.
///
/// * Outside `[grid[0], grid[n - 1]]`, the nearest end-point value is
///   returned (constant extrapolation).
/// * An empty grid yields `0.0`.
/// * `grid` is assumed sorted ascending, and `values` is indexed with the
///   same positions as `grid`.
fn interp_target(values: &[f64], grid: &[f64], t: f64) -> f64 {
    let n = grid.len();
    let (first, last) = match (grid.first(), grid.last()) {
        (Some(&a), Some(&b)) => (a, b),
        _ => return 0.0,
    };
    if t <= first {
        return values[0];
    }
    if t >= last {
        return values[n - 1];
    }

    // Bisect for the bracketing interval grid[left] <= t < grid[right].
    let (mut left, mut right) = (0_usize, n - 1);
    while left + 1 < right {
        let mid = left + (right - left) / 2;
        if grid[mid] <= t {
            left = mid;
        } else {
            right = mid;
        }
    }

    let weight = (t - grid[left]) / (grid[right] - grid[left]);
    values[left] * (1.0 - weight) + values[right] * weight
}
225
226/// DP alignment for partial matching (same grid for both SRSFs).
227fn dp_align_partial(q1: &[f64], q2: &[f64], argvals: &[f64], lambda: f64) -> Vec<f64> {
228    let m = argvals.len();
229    if m < 2 {
230        return argvals.to_vec();
231    }
232
233    let norm1 = q1.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
234    let norm2 = q2.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
235    let q1n: Vec<f64> = q1.iter().map(|&v| v / norm1).collect();
236    let q2n: Vec<f64> = q2.iter().map(|&v| v / norm2).collect();
237
238    let path = super::dp_grid_solve(m, m, |sr, sc, tr, tc| {
239        dp_edge_weight(&q1n, &q2n, argvals, sc, tc, sr, tr)
240            + dp_lambda_penalty(argvals, sc, tc, sr, tr, lambda)
241    });
242
243    dp_path_to_gamma(&path, argvals)
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;

    #[test]
    fn partial_match_identity() {
        let m = 30;
        let t = uniform_grid(m);
        let f: Vec<f64> = t.iter().map(|&x| (x * 6.0).sin()).collect();

        let config = PartialMatchConfig {
            min_span: 0.5,
            ..Default::default()
        };
        let result = elastic_partial_match(&f, &f, &t, &t, &config).unwrap();

        assert!(
            result.distance < 0.5,
            "matching a curve to itself should give small distance, got {}",
            result.distance
        );
    }

    #[test]
    fn partial_match_subcurve() {
        let m = 40;
        let t = uniform_grid(m);
        let target: Vec<f64> = t.iter().map(|&x| (x * 6.0).sin()).collect();

        // Template is roughly the middle portion of the target.
        let m_t = 20;
        let t_template = uniform_grid(m_t);
        let template: Vec<f64> = t_template
            .iter()
            .map(|&x| ((x * 0.5 + 0.25) * 6.0).sin())
            .collect();

        let config = PartialMatchConfig {
            min_span: 0.3,
            ..Default::default()
        };
        let result = elastic_partial_match(&template, &target, &t_template, &t, &config).unwrap();

        assert!(result.start_index < result.end_index);
        assert!(result.end_index < m);
        assert!(result.domain_fraction >= 0.3);
        assert_eq!(result.gamma.len(), m_t);
    }

    #[test]
    fn partial_match_full_span_uses_whole_target() {
        let m = 25;
        let t = uniform_grid(m);
        let f: Vec<f64> = t.iter().map(|&x| (x * 4.0).cos()).collect();

        // min_span = 1.0 leaves exactly one candidate window: the whole target.
        let config = PartialMatchConfig {
            min_span: 1.0,
            ..Default::default()
        };
        let result = elastic_partial_match(&f, &f, &t, &t, &config).unwrap();

        assert_eq!(result.start_index, 0);
        assert_eq!(result.end_index, m - 1);
        assert!((result.domain_fraction - 1.0).abs() < 1e-12);
        assert_eq!(result.gamma.len(), m);
    }

    #[test]
    fn partial_match_rejects_bad_min_span() {
        let t = uniform_grid(10);
        let f: Vec<f64> = t.iter().map(|&x| x * x).collect();
        let config = PartialMatchConfig {
            min_span: 0.0,
            ..Default::default()
        };
        assert!(elastic_partial_match(&f, &f, &t, &t, &config).is_err());
    }

    #[test]
    fn partial_match_rejects_min_span_above_one() {
        let t = uniform_grid(10);
        let f: Vec<f64> = t.iter().map(|&x| x * x).collect();
        let config = PartialMatchConfig {
            min_span: 1.5,
            ..Default::default()
        };
        assert!(matches!(
            elastic_partial_match(&f, &f, &t, &t, &config),
            Err(FdarError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn partial_match_rejects_mismatched_argvals() {
        let t = uniform_grid(10);
        let f: Vec<f64> = t.iter().map(|&x| x * x).collect();
        let short = uniform_grid(9);
        let config = PartialMatchConfig::default();

        assert!(matches!(
            elastic_partial_match(&f, &f, &short, &t, &config),
            Err(FdarError::InvalidDimension { .. })
        ));
        assert!(matches!(
            elastic_partial_match(&f, &f, &t, &short, &config),
            Err(FdarError::InvalidDimension { .. })
        ));
    }

    #[test]
    fn partial_match_rejects_too_short_curves() {
        let t = uniform_grid(10);
        let f: Vec<f64> = t.iter().map(|&x| x * x).collect();
        let one = [0.0];
        let config = PartialMatchConfig::default();

        assert!(matches!(
            elastic_partial_match(&one, &f, &one, &t, &config),
            Err(FdarError::InvalidDimension { .. })
        ));
        assert!(matches!(
            elastic_partial_match(&f, &one, &t, &one, &config),
            Err(FdarError::InvalidDimension { .. })
        ));
    }

    #[test]
    fn partial_match_config_default() {
        let config = PartialMatchConfig::default();
        assert!((config.lambda - 0.0).abs() < f64::EPSILON);
        assert!((config.min_span - 0.5).abs() < f64::EPSILON);
    }
}