use super::quality::{warp_complexity, warp_smoothness};
use super::{AlignmentResult, KarcherMeanResult};
use crate::error::FdarError;
use crate::helpers::simpsons_weights;
use crate::matrix::FdMatrix;
/// Per-curve alignment diagnostic assembled by `build_diagnostic`.
///
/// Combines the warp-quality measures, the pre/post alignment distances and
/// the threshold checks from a [`DiagnosticConfig`] into one record.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct AlignmentDiagnostic {
/// Index of the curve this record describes. `diagnose_alignment` passes
/// the row index; `diagnose_pairwise` always passes `0`.
pub curve_index: usize,
/// Value returned by `warp_complexity(gamma, argvals)` for the curve's warp.
pub warp_complexity: f64,
/// Value returned by `warp_smoothness(gamma, argvals)`; it is the quantity
/// compared against `DiagnosticConfig::max_bending_energy`.
pub warp_smoothness: f64,
/// `true` when `distance_ratio` exceeds `min_improvement_ratio` and the
/// pre-alignment distance exceeds `under_alignment_threshold`.
pub is_under_aligned: bool,
/// `true` when `warp_complexity` exceeds `over_alignment_threshold`.
pub is_over_aligned: bool,
/// `true` when any warp value is strictly smaller than its predecessor.
pub has_non_monotone: bool,
/// Weighted L2 distance measured after alignment.
pub residual: f64,
/// `residual` divided by the pre-alignment distance, or `0.0` when the
/// pre-alignment distance is not greater than `1e-15`.
pub distance_ratio: f64,
/// `true` when `issues` is non-empty.
pub flagged: bool,
/// Human-readable description of every check that fired for this curve.
pub issues: Vec<String>,
}
/// Thresholds that decide which alignment diagnostics get flagged.
#[derive(Debug, Clone, PartialEq)]
pub struct DiagnosticConfig {
    /// A curve is reported as over-aligned when its warp complexity is
    /// strictly greater than this value.
    pub over_alignment_threshold: f64,
    /// Minimum pre-alignment distance a curve must exceed before it can be
    /// reported as under-aligned.
    pub under_alignment_threshold: f64,
    /// A curve is reported when its warp smoothness (bending energy) is
    /// strictly greater than this value.
    pub max_bending_energy: f64,
    /// A curve is a candidate for under-alignment when its
    /// residual / pre-alignment distance ratio is strictly greater than this.
    pub min_improvement_ratio: f64,
}

impl Default for DiagnosticConfig {
    /// Returns the default thresholds: `1.0` (over-alignment), `1e-6`
    /// (under-alignment), `100.0` (bending energy) and `0.5` (improvement ratio).
    fn default() -> Self {
        DiagnosticConfig {
            min_improvement_ratio: 0.5,
            max_bending_energy: 100.0,
            under_alignment_threshold: 1e-6,
            over_alignment_threshold: 1.0,
        }
    }
}
/// Aggregate result of `diagnose_alignment` over every curve in a dataset.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct AlignmentDiagnosticSummary {
/// One diagnostic per curve, in row order of the input data.
pub diagnostics: Vec<AlignmentDiagnostic>,
/// Row indices of the curves whose diagnostic has `flagged == true`.
pub flagged_indices: Vec<usize>,
/// Number of flagged curves; equals `flagged_indices.len()`.
pub n_flagged: usize,
/// `1.0 - n_flagged / n`, where `n` is the number of curves; `1.0` when
/// there are no curves.
pub health_score: f64,
}
/// Weighted L2 distance `sqrt(sum_i weights[i] * (a[i] - b[i])^2)`.
///
/// The sum runs over the indices of `a`; any extra trailing entries of `b`
/// or `weights` are ignored.
///
/// # Panics
///
/// Panics if `b` or `weights` is shorter than `a`.
fn weighted_l2(a: &[f64], b: &[f64], weights: &[f64]) -> f64 {
    a.iter()
        .enumerate()
        // Seed with +0.0 and index `b`/`weights` directly so the accumulation
        // order and the out-of-bounds panic stay exactly as in a plain loop.
        .fold(0.0, |acc, (k, &lhs)| {
            let diff = lhs - b[k];
            acc + diff * diff * weights[k]
        })
        .sqrt()
}
/// Returns `true` when some value in `gamma` is strictly smaller than the
/// value immediately before it. Equal neighbours do not count as a decrease.
fn is_non_monotone(gamma: &[f64]) -> bool {
    // Pair every element with its successor.
    gamma
        .iter()
        .zip(gamma.iter().skip(1))
        .any(|(earlier, later)| later < earlier)
}
/// Assembles the [`AlignmentDiagnostic`] for a single curve.
///
/// * `curve_index` - index stored verbatim in the result.
/// * `gamma` / `argvals` - passed to `warp_complexity`, `warp_smoothness`
///   and the monotonicity check.
/// * `pre_distance` - distance before alignment.
/// * `residual` - distance after alignment.
/// * `config` - thresholds used to decide which issues are reported.
///
/// The diagnostic is flagged whenever at least one issue message is produced.
fn build_diagnostic(
    curve_index: usize,
    gamma: &[f64],
    argvals: &[f64],
    pre_distance: f64,
    residual: f64,
    config: &DiagnosticConfig,
) -> AlignmentDiagnostic {
    let wc = warp_complexity(gamma, argvals);
    let ws = warp_smoothness(gamma, argvals);
    let has_non_monotone = is_non_monotone(gamma);

    // Avoid dividing by a (near-)zero pre-alignment distance; the ratio is
    // reported as 0.0 in that case.
    let distance_ratio = match pre_distance {
        d if d > 1e-15 => residual / d,
        _ => 0.0,
    };

    let is_over_aligned = wc > config.over_alignment_threshold;
    let is_under_aligned = pre_distance > config.under_alignment_threshold
        && distance_ratio > config.min_improvement_ratio;

    // Each check contributes at most one message; the order below is the
    // order in which messages appear in `issues`.
    let issues: Vec<String> = vec![
        is_over_aligned.then(|| {
            format!(
                "warp complexity {wc:.4} exceeds threshold {}",
                config.over_alignment_threshold
            )
        }),
        is_under_aligned.then(|| {
            format!(
                "distance ratio {distance_ratio:.4} exceeds improvement threshold {}",
                config.min_improvement_ratio
            )
        }),
        has_non_monotone.then(|| "warp contains non-monotone segments".to_string()),
        (ws > config.max_bending_energy).then(|| {
            format!(
                "bending energy {ws:.2} exceeds threshold {}",
                config.max_bending_energy
            )
        }),
    ]
    .into_iter()
    .flatten()
    .collect();

    AlignmentDiagnostic {
        curve_index,
        warp_complexity: wc,
        warp_smoothness: ws,
        is_under_aligned,
        is_over_aligned,
        has_non_monotone,
        residual,
        distance_ratio,
        flagged: !issues.is_empty(),
        issues,
    }
}
pub fn diagnose_alignment(
data: &FdMatrix,
karcher: &KarcherMeanResult,
argvals: &[f64],
config: &DiagnosticConfig,
) -> Result<AlignmentDiagnosticSummary, FdarError> {
let (n, m) = data.shape();
if argvals.len() != m {
return Err(FdarError::InvalidDimension {
parameter: "argvals",
expected: format!("{m}"),
actual: format!("{}", argvals.len()),
});
}
if karcher.gammas.nrows() != n || karcher.gammas.ncols() != m {
return Err(FdarError::InvalidDimension {
parameter: "karcher.gammas",
expected: format!("{n} x {m}"),
actual: format!("{} x {}", karcher.gammas.nrows(), karcher.gammas.ncols()),
});
}
if karcher.aligned_data.nrows() != n || karcher.aligned_data.ncols() != m {
return Err(FdarError::InvalidDimension {
parameter: "karcher.aligned_data",
expected: format!("{n} x {m}"),
actual: format!(
"{} x {}",
karcher.aligned_data.nrows(),
karcher.aligned_data.ncols()
),
});
}
if karcher.mean.len() != m {
return Err(FdarError::InvalidDimension {
parameter: "karcher.mean",
expected: format!("{m}"),
actual: format!("{}", karcher.mean.len()),
});
}
let weights = simpsons_weights(argvals);
let mut diagnostics = Vec::with_capacity(n);
let mut flagged_indices = Vec::new();
for i in 0..n {
let gamma_i: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
let fi = data.row(i);
let pre_distance = weighted_l2(&fi, &karcher.mean, &weights);
let fi_aligned = karcher.aligned_data.row(i);
let residual = weighted_l2(&fi_aligned, &karcher.mean, &weights);
let diag = build_diagnostic(i, &gamma_i, argvals, pre_distance, residual, config);
if diag.flagged {
flagged_indices.push(i);
}
diagnostics.push(diag);
}
let n_flagged = flagged_indices.len();
let health_score = if n > 0 {
1.0 - n_flagged as f64 / n as f64
} else {
1.0
};
Ok(AlignmentDiagnosticSummary {
diagnostics,
flagged_indices,
n_flagged,
health_score,
})
}
/// Diagnoses a single pairwise alignment.
///
/// The pre-alignment distance is the weighted L2 distance between `f1` and
/// `f2`; the residual is the weighted L2 distance between `f1` and
/// `result.f_aligned`. Weights come from `simpsons_weights(argvals)`. The
/// returned diagnostic always has `curve_index == 0`.
///
/// # Panics
///
/// Panics if `f2`, `result.f_aligned`, or the weight vector is shorter than
/// `f1`, because the distance is computed by indexing over `f1`'s length.
pub fn diagnose_pairwise(
    f1: &[f64],
    f2: &[f64],
    result: &AlignmentResult,
    argvals: &[f64],
    config: &DiagnosticConfig,
) -> AlignmentDiagnostic {
    let quadrature = simpsons_weights(argvals);
    let before = weighted_l2(f1, f2, &quadrature);
    let after = weighted_l2(f1, &result.f_aligned, &quadrature);
    build_diagnostic(0, &result.gamma, argvals, before, after, config)
}
// Unit tests for `diagnose_alignment` and `diagnose_pairwise`.
#[cfg(test)]
mod tests {
use super::*;
use crate::alignment::karcher_mean;
use crate::alignment::pairwise::elastic_align_pair;
use crate::simulation::{sim_fundata, EFunType, EValType};
use crate::test_helpers::uniform_grid;
// Builds simulated data via `sim_fundata` on the grid `uniform_grid(m)` and
// returns it together with that grid.
// NOTE(review): `Some(99)` is presumably a fixed RNG seed — confirm against
// `sim_fundata`.
fn make_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
let t = uniform_grid(m);
let data = sim_fundata(n, &t, 3, EFunType::Fourier, EValType::Exponential, Some(99));
(data, t)
}
// Smoke test: one diagnostic per curve, health score within [0, 1], and
// `n_flagged` consistent with `flagged_indices`.
#[test]
fn diagnose_alignment_smoke() {
let (data, t) = make_data(8, 30);
let km = karcher_mean(&data, &t, 5, 1e-2, 0.0);
let config = DiagnosticConfig::default();
let summary = diagnose_alignment(&data, &km, &t, &config).unwrap();
assert_eq!(summary.diagnostics.len(), 8);
assert!(summary.health_score >= 0.0 && summary.health_score <= 1.0);
assert_eq!(summary.n_flagged, summary.flagged_indices.len());
}
// Intended to check that identical curves yield small warp complexity.
#[test]
fn diagnose_alignment_identical_returns_low_complexity() {
let t = uniform_grid(30);
let curve: Vec<f64> = t.iter().map(|&x| x.sin()).collect();
let mut vals = Vec::with_capacity(5 * 30);
for _ in 0..5 {
vals.extend_from_slice(&curve);
}
// NOTE(review): `vals` is filled curve after curve (five consecutive copies
// of `curve`), yet it is handed to `from_column_major`. If that constructor
// interprets the buffer as column-major, the five rows are NOT identical
// sine curves — confirm the layout and rebuild `vals` if needed.
let data = FdMatrix::from_column_major(vals, 5, 30).unwrap();
let km = karcher_mean(&data, &t, 5, 1e-3, 0.0);
let config = DiagnosticConfig::default();
let summary = diagnose_alignment(&data, &km, &t, &config).unwrap();
assert_eq!(summary.diagnostics.len(), 5);
for d in &summary.diagnostics {
assert!(
d.warp_complexity < 0.5,
"curve {} warp_complexity {} should be small for identical data",
d.curve_index,
d.warp_complexity,
);
}
}
// A grid whose length differs from the data's column count must be rejected.
#[test]
fn diagnose_alignment_rejects_shape_mismatch() {
let (data, t) = make_data(6, 30);
let km = karcher_mean(&data, &t, 3, 1e-2, 0.0);
let bad_t = uniform_grid(20);
let config = DiagnosticConfig::default();
assert!(diagnose_alignment(&data, &km, &bad_t, &config).is_err());
}
// Smoke test on two phase-shifted sine curves: the reported warp complexity
// and residual must be non-negative.
#[test]
fn diagnose_pairwise_smoke() {
let t = uniform_grid(30);
let f1: Vec<f64> = t.iter().map(|&x| (x * 6.0).sin()).collect();
let f2: Vec<f64> = t.iter().map(|&x| ((x + 0.15) * 6.0).sin()).collect();
let alignment = elastic_align_pair(&f1, &f2, &t, 0.0);
let config = DiagnosticConfig::default();
let diag = diagnose_pairwise(&f1, &f2, &alignment, &t, &config);
assert!(diag.warp_complexity >= 0.0);
assert!(diag.residual >= 0.0);
}
// Aligning a curve with itself should leave a near-zero residual and a warp
// without decreasing steps.
#[test]
fn diagnose_pairwise_identical() {
let t = uniform_grid(30);
let f: Vec<f64> = t.iter().map(|&x| x.sin()).collect();
let alignment = elastic_align_pair(&f, &f, &t, 0.0);
let config = DiagnosticConfig::default();
let diag = diagnose_pairwise(&f, &f, &alignment, &t, &config);
assert!(
diag.residual < 1e-3,
"identical curves should have near-zero residual"
);
assert!(!diag.has_non_monotone);
}
}