use super::pairwise::elastic_align_pair;
use super::srsf::{reparameterize_curve, srsf_single};
use super::{dp_alignment_core, AlignmentResult};
use crate::error::FdarError;
use crate::helpers::{l2_distance, linear_interp, simpsons_weights};
use crate::warping::normalize_warp;
#[derive(Debug, Clone, PartialEq)]
pub struct MultiresConfig {
pub coarsen_factor: usize,
pub n_refine_steps: usize,
pub step_size: f64,
pub lambda: f64,
}
impl Default for MultiresConfig {
fn default() -> Self {
Self {
coarsen_factor: 4,
n_refine_steps: 10,
step_size: 0.01,
lambda: 0.0,
}
}
}
#[must_use = "expensive computation whose result should not be discarded"]
pub fn elastic_align_pair_multires(
f1: &[f64],
f2: &[f64],
argvals: &[f64],
config: &MultiresConfig,
) -> Result<AlignmentResult, FdarError> {
let m = f1.len();
if m != f2.len() || m != argvals.len() {
return Err(FdarError::InvalidDimension {
parameter: "f1/f2/argvals",
expected: format!("equal lengths, f1 has {m}"),
actual: format!("f2 has {}, argvals has {}", f2.len(), argvals.len()),
});
}
if m < 2 {
return Err(FdarError::InvalidDimension {
parameter: "f1",
expected: "length >= 2".to_string(),
actual: format!("length {m}"),
});
}
if config.coarsen_factor < 2 {
return Err(FdarError::InvalidParameter {
parameter: "coarsen_factor",
message: format!("must be >= 2, got {}", config.coarsen_factor),
});
}
if m < 2 * config.coarsen_factor {
let result = elastic_align_pair(f1, f2, argvals, config.lambda);
return Ok(result);
}
let q1 = srsf_single(f1, argvals);
let q2 = srsf_single(f2, argvals);
let m_coarse = (m / config.coarsen_factor).max(4);
let coarse_argvals = subsample_grid(argvals, m_coarse);
let coarse_q1 = subsample_values(&q1, argvals, &coarse_argvals);
let coarse_q2 = subsample_values(&q2, argvals, &coarse_argvals);
let coarse_gamma = dp_alignment_core(&coarse_q1, &coarse_q2, &coarse_argvals, config.lambda);
let mut gamma: Vec<f64> = argvals
.iter()
.map(|&t| linear_interp(&coarse_argvals, &coarse_gamma, t))
.collect();
normalize_warp(&mut gamma, argvals);
for _ in 0..config.n_refine_steps {
let f2_warped = reparameterize_curve(f2, argvals, &gamma);
let q2_warped = srsf_single(&f2_warped, argvals);
let h = 1.0 / (m as f64 * 10.0);
let weights = simpsons_weights(argvals);
let _current_dist = l2_distance(&q1, &q2_warped, &weights);
let mut improved = false;
for j in 1..m - 1 {
let orig = gamma[j];
gamma[j] = orig + h;
if gamma[j] <= gamma[j - 1] || gamma[j] >= gamma[j + 1] {
gamma[j] = orig;
continue;
}
let f2_pert = reparameterize_curve(f2, argvals, &gamma);
let q2_pert = srsf_single(&f2_pert, argvals);
let dist_plus = l2_distance(&q1, &q2_pert, &weights);
gamma[j] = orig - h;
if gamma[j] <= gamma[j - 1] || gamma[j] >= gamma[j + 1] {
gamma[j] = orig;
continue;
}
let f2_pert2 = reparameterize_curve(f2, argvals, &gamma);
let q2_pert2 = srsf_single(&f2_pert2, argvals);
let dist_minus = l2_distance(&q1, &q2_pert2, &weights);
let grad = (dist_plus - dist_minus) / (2.0 * h);
let new_val = orig - config.step_size * grad;
let lo = gamma[j - 1] + 1e-12;
let hi = gamma[j + 1] - 1e-12;
gamma[j] = new_val.clamp(lo, hi);
if (gamma[j] - orig).abs() > 1e-15 {
improved = true;
}
}
if !improved {
break;
}
normalize_warp(&mut gamma, argvals);
}
let f_aligned = reparameterize_curve(f2, argvals, &gamma);
let q_aligned = srsf_single(&f_aligned, argvals);
let weights = simpsons_weights(argvals);
let distance = l2_distance(&q1, &q_aligned, &weights);
Ok(AlignmentResult {
gamma,
f_aligned,
distance,
})
}
fn subsample_grid(argvals: &[f64], m_coarse: usize) -> Vec<f64> {
let m = argvals.len();
if m_coarse >= m {
return argvals.to_vec();
}
(0..m_coarse)
.map(|i| {
let idx_f = i as f64 * (m - 1) as f64 / (m_coarse - 1) as f64;
let lo = idx_f.floor() as usize;
let hi = idx_f.ceil().min((m - 1) as f64) as usize;
let frac = idx_f - lo as f64;
argvals[lo] * (1.0 - frac) + argvals[hi] * frac
})
.collect()
}
fn subsample_values(values: &[f64], fine_grid: &[f64], coarse_grid: &[f64]) -> Vec<f64> {
coarse_grid
.iter()
.map(|&t| linear_interp(fine_grid, values, t))
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_helpers::uniform_grid;
#[test]
fn multires_identity() {
let m = 50;
let t = uniform_grid(m);
let f: Vec<f64> = t.iter().map(|&x| (x * 6.0).sin()).collect();
let config = MultiresConfig::default();
let result = elastic_align_pair_multires(&f, &f, &t, &config).unwrap();
assert!(
result.distance < 0.5,
"identical curves should have near-zero distance, got {}",
result.distance
);
}
#[test]
fn multires_phase_shifted() {
let m = 60;
let t = uniform_grid(m);
let f1: Vec<f64> = t.iter().map(|&x| (x * 6.0).sin()).collect();
let f2: Vec<f64> = t.iter().map(|&x| ((x + 0.1) * 6.0).sin()).collect();
let config = MultiresConfig::default();
let result = elastic_align_pair_multires(&f1, &f2, &t, &config).unwrap();
let standard = elastic_align_pair(&f1, &f2, &t, 0.0);
assert!(
result.distance < standard.distance * 2.0 + 0.5,
"multi-res distance ({}) should be comparable to standard ({})",
result.distance,
standard.distance,
);
}
#[test]
fn multires_falls_back_short_curves() {
let m = 6;
let t = uniform_grid(m);
let f1: Vec<f64> = t.iter().map(|&x| x * x).collect();
let f2: Vec<f64> = t.iter().map(|&x| x * x + 0.1).collect();
let config = MultiresConfig {
coarsen_factor: 4,
..Default::default()
};
let result = elastic_align_pair_multires(&f1, &f2, &t, &config).unwrap();
assert_eq!(result.gamma.len(), m);
assert_eq!(result.f_aligned.len(), m);
}
#[test]
fn multires_rejects_bad_coarsen_factor() {
let t = uniform_grid(20);
let f: Vec<f64> = t.to_vec();
let config = MultiresConfig {
coarsen_factor: 1,
..Default::default()
};
assert!(elastic_align_pair_multires(&f, &f, &t, &config).is_err());
}
#[test]
fn multires_config_default() {
let config = MultiresConfig::default();
assert_eq!(config.coarsen_factor, 4);
assert_eq!(config.n_refine_steps, 10);
assert!((config.step_size - 0.01).abs() < f64::EPSILON);
assert!((config.lambda - 0.0).abs() < f64::EPSILON);
}
}