use super::dp_alignment_core;
use super::pairwise::elastic_align_pair;
use super::srsf::{reparameterize_curve, srsf_single};
use crate::error::FdarError;
use crate::helpers::{l2_distance, simpsons_weights};
use crate::iter_maybe_parallel;
use crate::matrix::FdMatrix;
#[cfg(feature = "parallel")]
use rayon::iter::ParallelIterator;
/// Result of aligning one closed curve to another with [`elastic_align_pair_closed`].
///
/// The second curve is first circularly shifted by `optimal_rotation` samples and
/// then elastically aligned to the first curve; the remaining fields come from that
/// pairwise alignment of the shifted curve.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ClosedAlignmentResult {
    /// Warping function produced by the pairwise alignment of `f1` and the rotated `f2`.
    pub gamma: Vec<f64>,
    /// The rotated `f2` after warping, as returned by the pairwise alignment.
    pub f_aligned: Vec<f64>,
    /// Distance reported by the pairwise alignment of `f1` and the rotated `f2`.
    pub distance: f64,
    /// Circular shift (in samples) applied to `f2` before warping: sample `j` of the
    /// shifted curve is `f2[(j + optimal_rotation) % f2.len()]`.
    pub optimal_rotation: usize,
}
/// Result of [`karcher_mean_closed`]: the estimated mean of a set of closed curves
/// together with the per-curve alignment from the final iteration.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ClosedKarcherMeanResult {
    /// Mean curve, reconstructed from `mean_srsf` with `srsf_inverse`.
    pub mean: Vec<f64>,
    /// Mean SRSF: the pointwise average of the aligned curves' SRSFs from the last
    /// iteration (the first curve's SRSF if no iteration ran).
    pub mean_srsf: Vec<f64>,
    /// `n x m` matrix; row `i` is the warping function for curve `i` from the last
    /// alignment pass.
    pub gammas: FdMatrix,
    /// `n x m` matrix; row `i` is curve `i`, circularly shifted by `rotations[i]` and
    /// then reparameterised by row `i` of `gammas`.
    pub aligned_data: FdMatrix,
    /// Per-curve circular shift (in samples) chosen in the last alignment pass.
    pub rotations: Vec<usize>,
    /// Number of alignment iterations performed (at most `max_iter`).
    pub n_iter: usize,
    /// `true` if the relative change of the mean SRSF dropped below `tol` within
    /// `max_iter` iterations.
    pub converged: bool,
}
/// Rotates the samples of `f` to the left by `k` positions, treating `f` as periodic.
///
/// Element `j` of the result is `f[(j + k) % f.len()]`. Shifts of `f.len()` or more
/// wrap around, and an empty slice yields an empty vector.
fn circular_shift(f: &[f64], k: usize) -> Vec<f64> {
    let mut rotated = f.to_vec();
    if !rotated.is_empty() {
        let len = rotated.len();
        rotated.rotate_left(k % len);
    }
    rotated
}
/// Finds the circular shift of `f2` that best aligns it with `f1`.
///
/// Each candidate rotation `k` is scored as follows:
/// 1. shift `f2` by `k` samples with [`circular_shift`];
/// 2. warp it onto `f1` with `dp_alignment_core`, working on the SRSFs;
/// 3. take the `l2_distance` between the SRSFs of `f1` and the warped curve,
///    weighted by `simpsons_weights(argvals)`.
///
/// To limit cost, the search runs in two passes with `step_size = max(m / 20, 1)`:
/// - a coarse pass scores every `step_size`-th rotation;
/// - a fine pass scores every rotation within `step_size - 1` of the coarse winner.
///
/// The fine window is taken modulo `m`, so it wraps across the curve's start point.
/// For example, a coarse winner of 0 also considers rotations `m - 1`, `m - 2`, ….
/// Ties keep the earliest-evaluated rotation.
///
/// Returns `(best_k, best_distance)`. For curves with fewer than two samples it
/// returns `(0, 0.0)`.
fn find_best_rotation(f1: &[f64], f2: &[f64], argvals: &[f64], lambda: f64) -> (usize, f64) {
    let m = f1.len();
    if m < 2 {
        return (0, 0.0);
    }
    let step_size = (m / 20).max(1);

    // The template SRSF and the quadrature weights do not depend on the rotation,
    // so compute them once rather than once per candidate.
    let q1 = srsf_single(f1, argvals);
    let weights = simpsons_weights(argvals);

    // Distance between `f1` and `f2` shifted by `k` samples after warping.
    let distance_at = |k: usize| -> f64 {
        let f2_rot = circular_shift(f2, k);
        let q2 = srsf_single(&f2_rot, argvals);
        let gamma = dp_alignment_core(&q1, &q2, argvals, lambda);
        let f_aligned = reparameterize_curve(&f2_rot, argvals, &gamma);
        let q_aligned = srsf_single(&f_aligned, argvals);
        l2_distance(&q1, &q_aligned, &weights)
    };

    let mut best_k = 0;
    let mut best_dist = f64::INFINITY;

    // Coarse pass: every `step_size`-th rotation.
    for k in (0..m).step_by(step_size) {
        let dist = distance_at(k);
        if dist < best_dist {
            best_dist = dist;
            best_k = k;
        }
    }

    // Fine pass: offsets -(step_size-1)..=-1 then 1..=(step_size-1) around the coarse
    // winner, reduced modulo `m`. Rotations are cyclic, so rotation `m - 1` is a
    // neighbour of rotation 0 and must be refined too. `coarse_best + m - d` cannot
    // underflow because `d < step_size <= m`.
    let coarse_best = best_k;
    let below = (1..step_size).rev().map(|d| coarse_best + m - d);
    let above = (1..step_size).map(|d| coarse_best + d);
    for k in below.chain(above).map(|k| k % m) {
        if k % step_size == 0 {
            continue; // already scored in the coarse pass
        }
        let dist = distance_at(k);
        if dist < best_dist {
            best_dist = dist;
            best_k = k;
        }
    }

    (best_k, best_dist)
}
/// Elastically aligns the closed curve `f2` to `f1`, optimising over both the
/// starting point (a circular shift of `f2`) and a time warping.
///
/// The best circular shift is found by `find_best_rotation`. The shifted curve is
/// then aligned to `f1` with `elastic_align_pair`, whose warp, aligned curve and
/// distance are returned together with the chosen shift.
///
/// # Errors
/// Returns [`FdarError::InvalidDimension`] in either of these cases:
/// - `f1`, `f2` and `argvals` do not all have the same length;
/// - the curves have fewer than two samples.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn elastic_align_pair_closed(
    f1: &[f64],
    f2: &[f64],
    argvals: &[f64],
    lambda: f64,
) -> Result<ClosedAlignmentResult, FdarError> {
    let m = f1.len();
    let lengths_agree = f2.len() == m && argvals.len() == m;
    if !lengths_agree {
        return Err(FdarError::InvalidDimension {
            parameter: "f1/f2/argvals",
            expected: format!("equal lengths, f1 has {m}"),
            actual: format!("f2 has {}, argvals has {}", f2.len(), argvals.len()),
        });
    }
    if m < 2 {
        return Err(FdarError::InvalidDimension {
            parameter: "f1",
            expected: "length >= 2".to_string(),
            actual: format!("length {m}"),
        });
    }

    let (rotation, _) = find_best_rotation(f1, f2, argvals, lambda);
    let aligned = elastic_align_pair(f1, &circular_shift(f2, rotation), argvals, lambda);

    Ok(ClosedAlignmentResult {
        gamma: aligned.gamma,
        f_aligned: aligned.f_aligned,
        distance: aligned.distance,
        optimal_rotation: rotation,
    })
}
/// Elastic distance between two closed curves.
///
/// This is the `distance` field of [`elastic_align_pair_closed`]. That distance is
/// computed after the best circular shift and warping of `f2` onto `f1`.
///
/// # Errors
/// Propagates the dimension errors of [`elastic_align_pair_closed`].
#[must_use = "expensive computation whose result should not be discarded"]
pub fn elastic_distance_closed(
    f1: &[f64],
    f2: &[f64],
    argvals: &[f64],
    lambda: f64,
) -> Result<f64, FdarError> {
    elastic_align_pair_closed(f1, f2, argvals, lambda).map(|aligned| aligned.distance)
}
/// Estimates a Karcher mean of closed curves under rotation-aware elastic alignment.
///
/// The first row of `data` is the initial template. Each iteration does four things:
/// 1. aligns every row to the current template with [`elastic_align_pair_closed`];
/// 2. averages the SRSFs of the aligned curves pointwise to get a new mean SRSF;
/// 3. stops if the relative L2 change of the mean SRSF is below `tol`;
/// 4. otherwise rebuilds the template from the new mean SRSF with `srsf_inverse`.
///
/// # Arguments
/// * `data` - `n x m` matrix, one curve per row.
/// * `argvals` - sampling grid of length `m`.
/// * `max_iter` - maximum number of alignment iterations. With `0`, no alignment
///   runs: warps are the identity, rotations are zero, and the mean is rebuilt
///   from the first curve's SRSF.
/// * `tol` - relative tolerance on the change of the mean SRSF.
/// * `lambda` - regularisation parameter forwarded unchanged to the pairwise alignment.
///
/// # Errors
/// Returns [`FdarError::InvalidDimension`] in any of these cases:
/// - `argvals.len()` differs from the number of columns;
/// - there are fewer than two columns;
/// - `data` has no rows.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn karcher_mean_closed(
    data: &FdMatrix,
    argvals: &[f64],
    max_iter: usize,
    tol: f64,
    lambda: f64,
) -> Result<ClosedKarcherMeanResult, FdarError> {
    let (n, m) = data.shape();
    if m != argvals.len() {
        return Err(FdarError::InvalidDimension {
            parameter: "argvals",
            expected: format!("length {m}"),
            actual: format!("length {}", argvals.len()),
        });
    }
    if m < 2 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "ncols >= 2".to_string(),
            actual: format!("ncols = {m}"),
        });
    }
    if n == 0 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "nrows > 0".to_string(),
            actual: "nrows = 0".to_string(),
        });
    }

    let mut mu: Vec<f64> = data.row(0);
    let mut mu_q = srsf_single(&mu, argvals);

    // Start every warp at the identity (gamma evaluated at the grid itself) rather
    // than at all zeros. If no alignment pass runs (`max_iter == 0`), the returned
    // `gammas` and `aligned_data` then describe the unwarped input instead of every
    // curve reparameterised onto a single point.
    let mut gammas = FdMatrix::zeros(n, m);
    for i in 0..n {
        for (j, &t) in argvals.iter().enumerate() {
            gammas[(i, j)] = t;
        }
    }
    let mut rotations = vec![0usize; n];
    let mut converged = false;
    let mut n_iter = 0;

    for iter in 0..max_iter {
        n_iter = iter + 1;

        // Align every curve to the current template (in parallel when enabled) and
        // keep the SRSF of each aligned curve for the averaging step.
        let align_results: Vec<(ClosedAlignmentResult, Vec<f64>)> = iter_maybe_parallel!(0..n)
            .map(|i| {
                let fi = data.row(i);
                let res = elastic_align_pair_closed(&mu, &fi, argvals, lambda)
                    .expect("dimension invariant: all curves have length m");
                let q_warped = srsf_single(&res.f_aligned, argvals);
                (res, q_warped)
            })
            .collect();

        // Record warps/rotations and accumulate the pointwise SRSF mean.
        let mut mu_q_new = vec![0.0; m];
        for (i, (res, q_aligned)) in align_results.iter().enumerate() {
            for j in 0..m {
                gammas[(i, j)] = res.gamma[j];
                mu_q_new[j] += q_aligned[j];
            }
            rotations[i] = res.optimal_rotation;
        }
        mu_q_new.iter_mut().for_each(|v| *v /= n as f64);

        // Relative L2 change of the mean SRSF; the denominator is floored so a
        // near-zero previous mean does not divide by zero.
        let diff_norm: f64 = mu_q
            .iter()
            .zip(mu_q_new.iter())
            .map(|(&a, &b)| (a - b).powi(2))
            .sum::<f64>()
            .sqrt();
        let old_norm: f64 = mu_q.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-10);
        let rel = diff_norm / old_norm;
        mu_q = mu_q_new;
        if rel < tol {
            converged = true;
            break;
        }
        mu = crate::alignment::srsf::srsf_inverse(&mu_q, argvals, mu[0]);
    }

    // Rebuild the aligned curves from the final rotations and warps.
    let mut aligned_data = FdMatrix::zeros(n, m);
    for i in 0..n {
        let fi = data.row(i);
        let f_rotated = circular_shift(&fi, rotations[i]);
        let gamma_i: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
        let f_aligned = reparameterize_curve(&f_rotated, argvals, &gamma_i);
        for j in 0..m {
            aligned_data[(i, j)] = f_aligned[j];
        }
    }

    // When the loop exits via `break`, `mu` has not yet been rebuilt from the final
    // `mu_q`, so reconstruct it here.
    mu = crate::alignment::srsf::srsf_inverse(&mu_q, argvals, mu[0]);
    Ok(ClosedKarcherMeanResult {
        mean: mu,
        mean_srsf: mu_q,
        gammas,
        aligned_data,
        rotations,
        n_iter,
        converged,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;
    use std::f64::consts::PI;

    /// Evaluates `g` at every grid point.
    fn sample(argvals: &[f64], g: impl Fn(f64) -> f64) -> Vec<f64> {
        argvals.iter().map(|&t| g(t)).collect()
    }

    #[test]
    fn closed_align_identity() {
        let argvals = uniform_grid(30);
        let f = sample(&argvals, |t| (2.0 * PI * t).sin());

        let result = elastic_align_pair_closed(&f, &f, &argvals, 0.0).unwrap();

        assert!(
            result.distance < 0.1,
            "identical closed curves should have near-zero distance, got {}",
            result.distance
        );
        assert_eq!(
            result.optimal_rotation, 0,
            "identical curves should need no rotation"
        );
    }

    #[test]
    fn closed_align_shifted() {
        let argvals = uniform_grid(40);
        let f1 = sample(&argvals, |t| (2.0 * PI * t).sin() + 0.5 * t);
        let f2 = circular_shift(&f1, 5);

        let result = elastic_align_pair_closed(&f1, &f2, &argvals, 0.0).unwrap();

        assert!(
            result.distance < 1.0,
            "distance after closed alignment should be small, got {}",
            result.distance
        );
    }

    #[test]
    fn closed_distance_symmetric() {
        let argvals = uniform_grid(25);
        let f1 = sample(&argvals, |t| (2.0 * PI * t).sin());
        let f2 = sample(&argvals, |t| (2.0 * PI * t).cos());

        let d12 = elastic_distance_closed(&f1, &f2, &argvals, 0.0).unwrap();
        let d21 = elastic_distance_closed(&f2, &f1, &argvals, 0.0).unwrap();

        assert!(
            d12 >= 0.0 && d12.is_finite(),
            "d12 should be non-negative finite, got {d12}"
        );
        assert!(
            d21 >= 0.0 && d21.is_finite(),
            "d21 should be non-negative finite, got {d21}"
        );
        assert!(
            d12.max(d21) < 2.0 * d12.min(d21) + 0.5,
            "closed distances should be in comparable range: d12={d12:.4}, d21={d21:.4}"
        );
    }

    #[test]
    fn closed_karcher_mean_smoke() {
        let (n, m) = (5, 25);
        let argvals = uniform_grid(m);

        // Column-major layout: all `n` curves at grid point 0, then grid point 1, ...
        // Curve `i` is a sine wave phase-shifted by `0.1 * i`.
        let data_flat: Vec<f64> = (0..m)
            .flat_map(|j| {
                let t = argvals[j];
                (0..n).map(move |i| (2.0 * PI * (t + i as f64 * 0.1)).sin())
            })
            .collect();
        let data = FdMatrix::from_column_major(data_flat, n, m).unwrap();

        let result = karcher_mean_closed(&data, &argvals, 10, 1e-3, 0.0).unwrap();

        assert_eq!(result.mean.len(), m);
        assert_eq!(result.mean_srsf.len(), m);
        assert_eq!(result.gammas.shape(), (n, m));
        assert_eq!(result.aligned_data.shape(), (n, m));
        assert_eq!(result.rotations.len(), n);
        assert!(result.n_iter <= 10);
    }
}