Skip to main content

fdars_core/alignment/
multires.rs

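//! Multi-resolution elastic alignment: coarse DP + fine gradient refinement.
//!
//! Standard DP alignment has O(m²) complexity. Multi-resolution alignment
//! runs DP on a coarsened grid first (O((m/c)²) for coarsening factor `c`),
//! then refines the warp by gradient descent on the original resolution.
//! The refinement uses finite differences that re-warp the whole curve twice
//! per interior grid point in every sweep, so its cost also grows at least
//! quadratically in m. Whether the combination beats full-resolution DP
//! therefore depends on the number of refinement sweeps.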
6
7use super::pairwise::elastic_align_pair;
8use super::srsf::{reparameterize_curve, srsf_single};
9use super::{dp_alignment_core, AlignmentResult};
10use crate::error::FdarError;
11use crate::helpers::{l2_distance, linear_interp, simpsons_weights};
12use crate::warping::normalize_warp;
13
14// ─── Types ──────────────────────────────────────────────────────────────────
15
/// Configuration for multi-resolution alignment.
///
/// Construct with [`MultiresConfig::default`] and override individual fields
/// with struct-update syntax, e.g.
/// `MultiresConfig { coarsen_factor: 8, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiresConfig {
    /// Coarsening factor: the coarse grid has `m / coarsen_factor` points
    /// (integer division), but never fewer than 4.
    /// Must be >= 2. Default 4.
    ///
    /// Curves with `m < 2 * coarsen_factor` skip the coarse stage and are
    /// aligned with standard full-resolution DP instead.
    pub coarsen_factor: usize,
    /// Maximum number of gradient refinement sweeps on the fine grid.
    /// Refinement stops early once a sweep leaves the warp unchanged, and
    /// `0` keeps the interpolated coarse warp as-is.
    /// Default 10.
    pub n_refine_steps: usize,
    /// Gradient descent step size for refinement.
    /// Default 0.01.
    pub step_size: f64,
    /// Roughness penalty for elastic alignment (0.0 = no penalty). Default 0.0.
    ///
    /// Passed to the DP stage (coarse DP, or the short-curve fallback). The
    /// gradient refinement stage minimizes the unpenalized SRSF L2 distance.
    pub lambda: f64,
}

impl Default for MultiresConfig {
    /// Coarsen by 4, up to 10 refinement sweeps with step 0.01, no roughness penalty.
    fn default() -> Self {
        Self {
            coarsen_factor: 4,
            n_refine_steps: 10,
            step_size: 0.01,
            lambda: 0.0,
        }
    }
}
42
43// ─── Public API ─────────────────────────────────────────────────────────────
44
45/// Align curve `f2` to `f1` using multi-resolution elastic alignment.
46///
47/// 1. **Coarse stage**: Subsample both SRSFs to a coarser grid, run DP,
48///    interpolate the resulting warp back to full resolution.
49/// 2. **Fine stage**: Starting from the coarse warp, run gradient descent
50///    steps to locally refine the warp on the full-resolution grid.
51///
52/// For short curves (m < 2 * coarsen_factor), falls back to standard DP.
53///
54/// # Arguments
55/// * `f1` — Target curve (length m)
56/// * `f2` — Curve to align (length m)
57/// * `argvals` — Evaluation points (length m)
58/// * `config` — Multi-resolution configuration
59///
60/// # Errors
61/// Returns [`FdarError::InvalidDimension`] if lengths do not match or m < 2.
62/// Returns [`FdarError::InvalidParameter`] if `coarsen_factor < 2`.
63#[must_use = "expensive computation whose result should not be discarded"]
64pub fn elastic_align_pair_multires(
65    f1: &[f64],
66    f2: &[f64],
67    argvals: &[f64],
68    config: &MultiresConfig,
69) -> Result<AlignmentResult, FdarError> {
70    let m = f1.len();
71
72    if m != f2.len() || m != argvals.len() {
73        return Err(FdarError::InvalidDimension {
74            parameter: "f1/f2/argvals",
75            expected: format!("equal lengths, f1 has {m}"),
76            actual: format!("f2 has {}, argvals has {}", f2.len(), argvals.len()),
77        });
78    }
79    if m < 2 {
80        return Err(FdarError::InvalidDimension {
81            parameter: "f1",
82            expected: "length >= 2".to_string(),
83            actual: format!("length {m}"),
84        });
85    }
86    if config.coarsen_factor < 2 {
87        return Err(FdarError::InvalidParameter {
88            parameter: "coarsen_factor",
89            message: format!("must be >= 2, got {}", config.coarsen_factor),
90        });
91    }
92
93    // For short curves, fall back to standard alignment
94    if m < 2 * config.coarsen_factor {
95        let result = elastic_align_pair(f1, f2, argvals, config.lambda);
96        return Ok(result);
97    }
98
99    let q1 = srsf_single(f1, argvals);
100    let q2 = srsf_single(f2, argvals);
101
102    // ── Stage 1: Coarse DP ──
103    let m_coarse = (m / config.coarsen_factor).max(4);
104    let coarse_argvals = subsample_grid(argvals, m_coarse);
105    let coarse_q1 = subsample_values(&q1, argvals, &coarse_argvals);
106    let coarse_q2 = subsample_values(&q2, argvals, &coarse_argvals);
107
108    let coarse_gamma = dp_alignment_core(&coarse_q1, &coarse_q2, &coarse_argvals, config.lambda);
109
110    // Interpolate coarse warp to fine grid
111    let mut gamma: Vec<f64> = argvals
112        .iter()
113        .map(|&t| linear_interp(&coarse_argvals, &coarse_gamma, t))
114        .collect();
115    normalize_warp(&mut gamma, argvals);
116
117    // ── Stage 2: Gradient refinement ──
118    for _ in 0..config.n_refine_steps {
119        // Compute current cost and gradient
120        let f2_warped = reparameterize_curve(f2, argvals, &gamma);
121        let q2_warped = srsf_single(&f2_warped, argvals);
122
123        // Approximate gradient: dJ/dγ_j ≈ -2(q1_j - q2_warped_j) * dq2/dγ_j
124        // We use a finite-difference approximation for simplicity
125        let h = 1.0 / (m as f64 * 10.0);
126        let weights = simpsons_weights(argvals);
127        let _current_dist = l2_distance(&q1, &q2_warped, &weights);
128
129        let mut improved = false;
130        for j in 1..m - 1 {
131            // Perturb gamma[j] and measure cost change
132            let orig = gamma[j];
133
134            gamma[j] = orig + h;
135            // Ensure monotonicity
136            if gamma[j] <= gamma[j - 1] || gamma[j] >= gamma[j + 1] {
137                gamma[j] = orig;
138                continue;
139            }
140
141            let f2_pert = reparameterize_curve(f2, argvals, &gamma);
142            let q2_pert = srsf_single(&f2_pert, argvals);
143            let dist_plus = l2_distance(&q1, &q2_pert, &weights);
144
145            gamma[j] = orig - h;
146            if gamma[j] <= gamma[j - 1] || gamma[j] >= gamma[j + 1] {
147                gamma[j] = orig;
148                continue;
149            }
150
151            let f2_pert2 = reparameterize_curve(f2, argvals, &gamma);
152            let q2_pert2 = srsf_single(&f2_pert2, argvals);
153            let dist_minus = l2_distance(&q1, &q2_pert2, &weights);
154
155            // Central difference gradient
156            let grad = (dist_plus - dist_minus) / (2.0 * h);
157
158            // Gradient step
159            let new_val = orig - config.step_size * grad;
160            // Clamp to maintain monotonicity
161            let lo = gamma[j - 1] + 1e-12;
162            let hi = gamma[j + 1] - 1e-12;
163            gamma[j] = new_val.clamp(lo, hi);
164
165            if (gamma[j] - orig).abs() > 1e-15 {
166                improved = true;
167            }
168        }
169
170        if !improved {
171            break;
172        }
173
174        normalize_warp(&mut gamma, argvals);
175    }
176
177    // ── Final alignment ──
178    let f_aligned = reparameterize_curve(f2, argvals, &gamma);
179    let q_aligned = srsf_single(&f_aligned, argvals);
180    let weights = simpsons_weights(argvals);
181    let distance = l2_distance(&q1, &q_aligned, &weights);
182
183    Ok(AlignmentResult {
184        gamma,
185        f_aligned,
186        distance,
187    })
188}
189
190// ─── Helpers ────────────────────────────────────────────────────────────────
191
/// Create a uniform subsample of a grid.
///
/// Returns `m_coarse` points placed at evenly spaced fractional indices
/// `i * (m - 1) / (m_coarse - 1)` of `argvals`. Values between grid points are
/// linearly interpolated. For `m_coarse >= 2` the first and last grid values
/// are always included.
///
/// Edge cases:
/// * `m_coarse >= argvals.len()` returns the grid unchanged.
/// * `m_coarse == 1` yields just the first grid point.
/// * `m_coarse == 0` yields an empty vector.
fn subsample_grid(argvals: &[f64], m_coarse: usize) -> Vec<f64> {
    let m = argvals.len();
    if m_coarse >= m {
        return argvals.to_vec();
    }
    // With fewer than two points there is no spacing: (m_coarse - 1) would be
    // zero and the index formula below would produce 0/0 = NaN.
    if m_coarse <= 1 {
        return argvals[..m_coarse].to_vec();
    }
    (0..m_coarse)
        .map(|i| {
            // Multiply before dividing so the last index is exactly m - 1.
            let idx_f = i as f64 * (m - 1) as f64 / (m_coarse - 1) as f64;
            let lo = idx_f.floor() as usize;
            let hi = idx_f.ceil().min((m - 1) as f64) as usize;
            let frac = idx_f - lo as f64;
            argvals[lo] * (1.0 - frac) + argvals[hi] * frac
        })
        .collect()
}
208
209/// Interpolate values from the fine grid to a coarser grid.
210fn subsample_values(values: &[f64], fine_grid: &[f64], coarse_grid: &[f64]) -> Vec<f64> {
211    coarse_grid
212        .iter()
213        .map(|&t| linear_interp(fine_grid, values, t))
214        .collect()
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;

    /// Evaluate `curve` at every point of `grid`.
    fn sample(grid: &[f64], curve: impl Fn(f64) -> f64) -> Vec<f64> {
        grid.iter().copied().map(curve).collect()
    }

    #[test]
    fn multires_identity() {
        let grid = uniform_grid(50);
        let curve = sample(&grid, |x| (x * 6.0).sin());
        let config = MultiresConfig::default();

        let aligned = elastic_align_pair_multires(&curve, &curve, &grid, &config).unwrap();

        assert!(
            aligned.distance < 0.5,
            "identical curves should have near-zero distance, got {}",
            aligned.distance
        );
    }

    #[test]
    fn multires_phase_shifted() {
        let grid = uniform_grid(60);
        let target = sample(&grid, |x| (x * 6.0).sin());
        let shifted = sample(&grid, |x| ((x + 0.1) * 6.0).sin());
        let config = MultiresConfig::default();

        let multires = elastic_align_pair_multires(&target, &shifted, &grid, &config).unwrap();

        // Reference: full-resolution DP alignment without roughness penalty.
        let standard = elastic_align_pair(&target, &shifted, &grid, 0.0);
        // The coarse-to-fine result may lose a little accuracy, but not much.
        assert!(
            multires.distance < standard.distance * 2.0 + 0.5,
            "multi-res distance ({}) should be comparable to standard ({})",
            multires.distance,
            standard.distance,
        );
    }

    #[test]
    fn multires_falls_back_short_curves() {
        // m = 6 < 2 * coarsen_factor, so the standard DP path is taken.
        let m = 6;
        let grid = uniform_grid(m);
        let target = sample(&grid, |x| x * x);
        let offset = sample(&grid, |x| x * x + 0.1);
        let config = MultiresConfig {
            coarsen_factor: 4,
            ..Default::default()
        };

        let aligned = elastic_align_pair_multires(&target, &offset, &grid, &config).unwrap();

        assert_eq!(aligned.gamma.len(), m);
        assert_eq!(aligned.f_aligned.len(), m);
    }

    #[test]
    fn multires_rejects_bad_coarsen_factor() {
        let grid = uniform_grid(20);
        let identity_curve: Vec<f64> = grid.to_vec();
        let config = MultiresConfig {
            coarsen_factor: 1,
            ..Default::default()
        };

        let outcome = elastic_align_pair_multires(&identity_curve, &identity_curve, &grid, &config);

        assert!(outcome.is_err());
    }

    #[test]
    fn multires_config_default() {
        let defaults = MultiresConfig::default();
        assert_eq!(defaults.coarsen_factor, 4);
        assert_eq!(defaults.n_refine_steps, 10);
        assert!((defaults.step_size - 0.01).abs() < f64::EPSILON);
        assert!((defaults.lambda - 0.0).abs() < f64::EPSILON);
    }
}