1use super::pairwise::elastic_align_pair;
8use super::srsf::{reparameterize_curve, srsf_single};
9use super::{dp_alignment_core, AlignmentResult};
10use crate::error::FdarError;
11use crate::helpers::{l2_distance, linear_interp, simpsons_weights};
12use crate::warping::normalize_warp;
13
/// Tuning knobs for [`elastic_align_pair_multires`].
///
/// Start from [`MultiresConfig::default`] and use struct-update syntax to
/// override individual fields.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiresConfig {
    /// Ratio between the full grid size and the coarse grid used for the
    /// initial DP alignment: the coarse grid has `m / coarsen_factor` points,
    /// but never fewer than 4. Must be at least 2.
    pub coarsen_factor: usize,
    /// Upper bound on the number of gradient-refinement sweeps run on the full
    /// grid after the coarse alignment; refinement may stop earlier.
    pub n_refine_steps: usize,
    /// Step size of the gradient-descent updates applied during refinement.
    pub step_size: f64,
    /// Value forwarded as `lambda` to the underlying DP alignment (both the
    /// coarse stage and the short-curve fallback). Presumably a warp
    /// regularisation weight — confirm against `dp_alignment_core`.
    pub lambda: f64,
}

impl Default for MultiresConfig {
    /// Coarsen by 4, run at most 10 refinement sweeps with step size 0.01, and
    /// use `lambda = 0.0`.
    fn default() -> Self {
        let coarsen_factor = 4;
        let n_refine_steps = 10;
        let step_size = 0.01;
        let lambda = 0.0;
        Self {
            coarsen_factor,
            n_refine_steps,
            step_size,
            lambda,
        }
    }
}
42
43#[must_use = "expensive computation whose result should not be discarded"]
64pub fn elastic_align_pair_multires(
65 f1: &[f64],
66 f2: &[f64],
67 argvals: &[f64],
68 config: &MultiresConfig,
69) -> Result<AlignmentResult, FdarError> {
70 let m = f1.len();
71
72 if m != f2.len() || m != argvals.len() {
73 return Err(FdarError::InvalidDimension {
74 parameter: "f1/f2/argvals",
75 expected: format!("equal lengths, f1 has {m}"),
76 actual: format!("f2 has {}, argvals has {}", f2.len(), argvals.len()),
77 });
78 }
79 if m < 2 {
80 return Err(FdarError::InvalidDimension {
81 parameter: "f1",
82 expected: "length >= 2".to_string(),
83 actual: format!("length {m}"),
84 });
85 }
86 if config.coarsen_factor < 2 {
87 return Err(FdarError::InvalidParameter {
88 parameter: "coarsen_factor",
89 message: format!("must be >= 2, got {}", config.coarsen_factor),
90 });
91 }
92
93 if m < 2 * config.coarsen_factor {
95 let result = elastic_align_pair(f1, f2, argvals, config.lambda);
96 return Ok(result);
97 }
98
99 let q1 = srsf_single(f1, argvals);
100 let q2 = srsf_single(f2, argvals);
101
102 let m_coarse = (m / config.coarsen_factor).max(4);
104 let coarse_argvals = subsample_grid(argvals, m_coarse);
105 let coarse_q1 = subsample_values(&q1, argvals, &coarse_argvals);
106 let coarse_q2 = subsample_values(&q2, argvals, &coarse_argvals);
107
108 let coarse_gamma = dp_alignment_core(&coarse_q1, &coarse_q2, &coarse_argvals, config.lambda);
109
110 let mut gamma: Vec<f64> = argvals
112 .iter()
113 .map(|&t| linear_interp(&coarse_argvals, &coarse_gamma, t))
114 .collect();
115 normalize_warp(&mut gamma, argvals);
116
117 for _ in 0..config.n_refine_steps {
119 let f2_warped = reparameterize_curve(f2, argvals, &gamma);
121 let q2_warped = srsf_single(&f2_warped, argvals);
122
123 let h = 1.0 / (m as f64 * 10.0);
126 let weights = simpsons_weights(argvals);
127 let _current_dist = l2_distance(&q1, &q2_warped, &weights);
128
129 let mut improved = false;
130 for j in 1..m - 1 {
131 let orig = gamma[j];
133
134 gamma[j] = orig + h;
135 if gamma[j] <= gamma[j - 1] || gamma[j] >= gamma[j + 1] {
137 gamma[j] = orig;
138 continue;
139 }
140
141 let f2_pert = reparameterize_curve(f2, argvals, &gamma);
142 let q2_pert = srsf_single(&f2_pert, argvals);
143 let dist_plus = l2_distance(&q1, &q2_pert, &weights);
144
145 gamma[j] = orig - h;
146 if gamma[j] <= gamma[j - 1] || gamma[j] >= gamma[j + 1] {
147 gamma[j] = orig;
148 continue;
149 }
150
151 let f2_pert2 = reparameterize_curve(f2, argvals, &gamma);
152 let q2_pert2 = srsf_single(&f2_pert2, argvals);
153 let dist_minus = l2_distance(&q1, &q2_pert2, &weights);
154
155 let grad = (dist_plus - dist_minus) / (2.0 * h);
157
158 let new_val = orig - config.step_size * grad;
160 let lo = gamma[j - 1] + 1e-12;
162 let hi = gamma[j + 1] - 1e-12;
163 gamma[j] = new_val.clamp(lo, hi);
164
165 if (gamma[j] - orig).abs() > 1e-15 {
166 improved = true;
167 }
168 }
169
170 if !improved {
171 break;
172 }
173
174 normalize_warp(&mut gamma, argvals);
175 }
176
177 let f_aligned = reparameterize_curve(f2, argvals, &gamma);
179 let q_aligned = srsf_single(&f_aligned, argvals);
180 let weights = simpsons_weights(argvals);
181 let distance = l2_distance(&q1, &q_aligned, &weights);
182
183 Ok(AlignmentResult {
184 gamma,
185 f_aligned,
186 distance,
187 })
188}
189
/// Builds a coarse grid of `m_coarse` points spaced evenly in *index* space
/// over `argvals`, linearly interpolating between neighbouring grid values.
///
/// If `m_coarse >= argvals.len()` the grid is returned unchanged. A request for
/// one point returns the first grid value, and a request for zero points returns
/// an empty grid. The spacing formula below divides by `m_coarse - 1` and would
/// otherwise produce NaN for a single point.
fn subsample_grid(argvals: &[f64], m_coarse: usize) -> Vec<f64> {
    let m = argvals.len();
    if m_coarse >= m {
        return argvals.to_vec();
    }
    if m_coarse < 2 {
        // Guard against 0/0 in the index spacing; m_coarse < m so this is in bounds.
        return argvals[..m_coarse].to_vec();
    }
    (0..m_coarse)
        .map(|i| {
            // Fractional index, evaluated as (i * (m - 1)) / (m_coarse - 1) so the
            // last point lands exactly on index m - 1.
            let idx_f = i as f64 * (m - 1) as f64 / (m_coarse - 1) as f64;
            let lo = idx_f.floor() as usize;
            let hi = idx_f.ceil().min((m - 1) as f64) as usize;
            let frac = idx_f - lo as f64;
            argvals[lo] * (1.0 - frac) + argvals[hi] * frac
        })
        .collect()
}
208
209fn subsample_values(values: &[f64], fine_grid: &[f64], coarse_grid: &[f64]) -> Vec<f64> {
211 coarse_grid
212 .iter()
213 .map(|&t| linear_interp(fine_grid, values, t))
214 .collect()
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;

    #[test]
    fn multires_identity() {
        let m = 50;
        let t = uniform_grid(m);
        let f: Vec<f64> = t.iter().map(|&x| (x * 6.0).sin()).collect();

        let config = MultiresConfig::default();
        let result = elastic_align_pair_multires(&f, &f, &t, &config).unwrap();

        assert!(
            result.distance < 0.5,
            "identical curves should have near-zero distance, got {}",
            result.distance
        );
    }

    #[test]
    fn multires_phase_shifted() {
        let m = 60;
        let t = uniform_grid(m);
        let f1: Vec<f64> = t.iter().map(|&x| (x * 6.0).sin()).collect();
        let f2: Vec<f64> = t.iter().map(|&x| ((x + 0.1) * 6.0).sin()).collect();

        let config = MultiresConfig::default();
        let result = elastic_align_pair_multires(&f1, &f2, &t, &config).unwrap();

        let standard = elastic_align_pair(&f1, &f2, &t, 0.0);
        assert!(
            result.distance < standard.distance * 2.0 + 0.5,
            "multi-res distance ({}) should be comparable to standard ({})",
            result.distance,
            standard.distance,
        );
    }

    /// The multi-resolution path (not just the fallback) must return one warp
    /// value and one aligned value per grid point.
    #[test]
    fn multires_output_lengths_match_grid() {
        let m = 40;
        let t = uniform_grid(m);
        let f1: Vec<f64> = t.iter().map(|&x| (x * 4.0).cos()).collect();
        let f2: Vec<f64> = t.iter().map(|&x| ((x + 0.05) * 4.0).cos()).collect();

        let config = MultiresConfig::default();
        let result = elastic_align_pair_multires(&f1, &f2, &t, &config).unwrap();
        assert_eq!(result.gamma.len(), m);
        assert_eq!(result.f_aligned.len(), m);
    }

    #[test]
    fn multires_falls_back_short_curves() {
        let m = 6;
        let t = uniform_grid(m);
        let f1: Vec<f64> = t.iter().map(|&x| x * x).collect();
        let f2: Vec<f64> = t.iter().map(|&x| x * x + 0.1).collect();

        let config = MultiresConfig {
            coarsen_factor: 4,
            ..Default::default()
        };
        let result = elastic_align_pair_multires(&f1, &f2, &t, &config).unwrap();
        assert_eq!(result.gamma.len(), m);
        assert_eq!(result.f_aligned.len(), m);
    }

    #[test]
    fn multires_rejects_bad_coarsen_factor() {
        let t = uniform_grid(20);
        let f: Vec<f64> = t.to_vec();
        let config = MultiresConfig {
            coarsen_factor: 1,
            ..Default::default()
        };
        assert!(matches!(
            elastic_align_pair_multires(&f, &f, &t, &config),
            Err(FdarError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn multires_rejects_mismatched_lengths() {
        let t = uniform_grid(20);
        let f1: Vec<f64> = t.to_vec();
        let f2: Vec<f64> = t[..19].to_vec();
        let result = elastic_align_pair_multires(&f1, &f2, &t, &MultiresConfig::default());
        assert!(matches!(result, Err(FdarError::InvalidDimension { .. })));
    }

    #[test]
    fn multires_rejects_single_point() {
        let t = [0.0];
        let f = [1.0];
        let result = elastic_align_pair_multires(&f, &f, &t, &MultiresConfig::default());
        assert!(matches!(result, Err(FdarError::InvalidDimension { .. })));
    }

    #[test]
    fn multires_config_default() {
        let config = MultiresConfig::default();
        assert_eq!(config.coarsen_factor, 4);
        assert_eq!(config.n_refine_steps, 10);
        assert!((config.step_size - 0.01).abs() < f64::EPSILON);
        assert!((config.lambda - 0.0).abs() < f64::EPSILON);
    }
}