Skip to main content

fdars_core/alignment/
diagnostics.rs

1//! Registration failure detection and alignment diagnostics.
2
3use super::quality::{warp_complexity, warp_smoothness};
4use super::{AlignmentResult, KarcherMeanResult};
5use crate::error::FdarError;
6use crate::helpers::simpsons_weights;
7use crate::matrix::FdMatrix;
8
9// ─── Types ───────────────────────────────────────────────────────────────────
10
/// Alignment report for one curve: warp statistics, residuals, and any
/// issues detected against a [`DiagnosticConfig`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct AlignmentDiagnostic {
    /// Row of the curve in the input data; pairwise diagnostics always use 0.
    pub curve_index: usize,
    /// Warp complexity (geodesic distance from the warp to the identity).
    pub warp_complexity: f64,
    /// Warp smoothness (bending energy of the warp).
    pub warp_smoothness: f64,
    /// Set when `distance_ratio` exceeds `min_improvement_ratio` while the
    /// pre-alignment distance exceeds `under_alignment_threshold`, i.e. the
    /// alignment barely reduced the residual.
    pub is_under_aligned: bool,
    /// Set when `warp_complexity` exceeds `over_alignment_threshold`.
    pub is_over_aligned: bool,
    /// Set when some warp value is smaller than the value just before it.
    pub has_non_monotone: bool,
    /// Weighted L2 residual measured after alignment.
    pub residual: f64,
    /// `residual` divided by the pre-alignment distance; reported as 0.0 when
    /// the pre-alignment distance is not greater than 1e-15.
    pub distance_ratio: f64,
    /// Set when `issues` is non-empty.
    pub flagged: bool,
    /// One human-readable message per detected issue.
    pub issues: Vec<String>,
}
36
/// Thresholds controlling which alignment issues get flagged.
///
/// All comparisons are strict (`>`), so a value exactly equal to a threshold
/// is not flagged.
#[derive(Debug, Clone, PartialEq)]
pub struct DiagnosticConfig {
    /// Warp complexity above which the curve is flagged as over-aligned.
    pub over_alignment_threshold: f64,
    /// Minimum pre-alignment distance for the under-alignment check to apply.
    ///
    /// A curve is only flagged as under-aligned when its pre-alignment
    /// distance is above this value; curves that already sat (almost) on the
    /// target are never reported as under-aligned, whatever their ratio.
    pub under_alignment_threshold: f64,
    /// Bending energy (warp smoothness) above which the warp is reported as
    /// too irregular and the curve is flagged.
    pub max_bending_energy: f64,
    /// Largest acceptable distance ratio (post-alignment residual divided by
    /// pre-alignment distance).
    ///
    /// A ratio above this value means the alignment removed too little of the
    /// original distance, and the curve is flagged as under-aligned (subject
    /// to `under_alignment_threshold`).
    pub min_improvement_ratio: f64,
}
50
51impl Default for DiagnosticConfig {
52    fn default() -> Self {
53        Self {
54            over_alignment_threshold: 1.0,
55            under_alignment_threshold: 1e-6,
56            max_bending_energy: 100.0,
57            min_improvement_ratio: 0.5,
58        }
59    }
60}
61
62/// Summary of diagnostics across all curves.
63#[derive(Debug, Clone, PartialEq)]
64#[non_exhaustive]
65pub struct AlignmentDiagnosticSummary {
66    /// Per-curve diagnostics.
67    pub diagnostics: Vec<AlignmentDiagnostic>,
68    /// Indices of flagged curves.
69    pub flagged_indices: Vec<usize>,
70    /// Number of flagged curves.
71    pub n_flagged: usize,
72    /// Overall health score in [0, 1]: fraction of curves that are *not* flagged.
73    pub health_score: f64,
74}
75
76// ─── Helpers ─────────────────────────────────────────────────────────────────
77
/// Weighted L2 distance `sqrt(sum_i (a[i] - b[i])^2 * weights[i])`.
///
/// The length of `a` drives the computation: extra trailing entries of `b`
/// and `weights` are ignored, and the function panics if either is shorter
/// than `a`.
fn weighted_l2(a: &[f64], b: &[f64], weights: &[f64]) -> f64 {
    // Trim up front so a short `b` / `weights` still panics rather than
    // silently truncating the sum.
    let len = a.len();
    let (b, weights) = (&b[..len], &weights[..len]);

    a.iter()
        .zip(b)
        .zip(weights)
        .fold(0.0, |acc, ((&x, &y), &w)| {
            let diff = x - y;
            acc + diff * diff * w
        })
        .sqrt()
}
87
/// Reports whether the warp ever decreases, i.e. whether some value is
/// strictly smaller than the one immediately before it.
///
/// Slices with fewer than two values are never non-monotone.
fn is_non_monotone(gamma: &[f64]) -> bool {
    gamma
        .iter()
        .zip(gamma.iter().skip(1))
        .any(|(prev, next)| next < prev)
}
92
93/// Build a diagnostic for one curve given its warp, pre-distance, and residual.
94fn build_diagnostic(
95    curve_index: usize,
96    gamma: &[f64],
97    argvals: &[f64],
98    pre_distance: f64,
99    residual: f64,
100    config: &DiagnosticConfig,
101) -> AlignmentDiagnostic {
102    let wc = warp_complexity(gamma, argvals);
103    let ws = warp_smoothness(gamma, argvals);
104    let non_mono = is_non_monotone(gamma);
105
106    let distance_ratio = if pre_distance > 1e-15 {
107        residual / pre_distance
108    } else {
109        0.0
110    };
111
112    let is_over = wc > config.over_alignment_threshold;
113    let is_under = distance_ratio > config.min_improvement_ratio
114        && pre_distance > config.under_alignment_threshold;
115
116    let mut issues = Vec::new();
117    if is_over {
118        issues.push(format!(
119            "warp complexity {wc:.4} exceeds threshold {}",
120            config.over_alignment_threshold
121        ));
122    }
123    if is_under {
124        issues.push(format!(
125            "distance ratio {distance_ratio:.4} exceeds improvement threshold {}",
126            config.min_improvement_ratio
127        ));
128    }
129    if non_mono {
130        issues.push("warp contains non-monotone segments".to_string());
131    }
132    if ws > config.max_bending_energy {
133        issues.push(format!(
134            "bending energy {ws:.2} exceeds threshold {}",
135            config.max_bending_energy
136        ));
137    }
138
139    let flagged = !issues.is_empty();
140
141    AlignmentDiagnostic {
142        curve_index,
143        warp_complexity: wc,
144        warp_smoothness: ws,
145        is_under_aligned: is_under,
146        is_over_aligned: is_over,
147        has_non_monotone: non_mono,
148        residual,
149        distance_ratio,
150        flagged,
151        issues,
152    }
153}
154
155// ─── Public API ──────────────────────────────────────────────────────────────
156
157/// Diagnose alignment quality for every curve after a Karcher mean computation.
158///
159/// For each curve the function computes warp complexity, smoothness, pre- and
160/// post-alignment residuals, and checks for non-monotone warps and insufficient
161/// improvement. Curves with any issue are flagged.
162///
163/// # Arguments
164/// * `data`    — Original (unaligned) functional data (n x m).
165/// * `karcher` — Result of [`super::karcher::karcher_mean`].
166/// * `argvals` — Evaluation grid (length m).
167/// * `config`  — Diagnostic thresholds.
168///
169/// # Errors
170/// Returns `FdarError::InvalidDimension` on shape mismatches.
171pub fn diagnose_alignment(
172    data: &FdMatrix,
173    karcher: &KarcherMeanResult,
174    argvals: &[f64],
175    config: &DiagnosticConfig,
176) -> Result<AlignmentDiagnosticSummary, FdarError> {
177    let (n, m) = data.shape();
178
179    if argvals.len() != m {
180        return Err(FdarError::InvalidDimension {
181            parameter: "argvals",
182            expected: format!("{m}"),
183            actual: format!("{}", argvals.len()),
184        });
185    }
186    if karcher.gammas.nrows() != n || karcher.gammas.ncols() != m {
187        return Err(FdarError::InvalidDimension {
188            parameter: "karcher.gammas",
189            expected: format!("{n} x {m}"),
190            actual: format!("{} x {}", karcher.gammas.nrows(), karcher.gammas.ncols()),
191        });
192    }
193    if karcher.aligned_data.nrows() != n || karcher.aligned_data.ncols() != m {
194        return Err(FdarError::InvalidDimension {
195            parameter: "karcher.aligned_data",
196            expected: format!("{n} x {m}"),
197            actual: format!(
198                "{} x {}",
199                karcher.aligned_data.nrows(),
200                karcher.aligned_data.ncols()
201            ),
202        });
203    }
204    if karcher.mean.len() != m {
205        return Err(FdarError::InvalidDimension {
206            parameter: "karcher.mean",
207            expected: format!("{m}"),
208            actual: format!("{}", karcher.mean.len()),
209        });
210    }
211
212    let weights = simpsons_weights(argvals);
213
214    let mut diagnostics = Vec::with_capacity(n);
215    let mut flagged_indices = Vec::new();
216
217    for i in 0..n {
218        let gamma_i: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
219
220        // Pre-alignment distance: ||f_i - mean||
221        let fi = data.row(i);
222        let pre_distance = weighted_l2(&fi, &karcher.mean, &weights);
223
224        // Post-alignment residual: ||f_i_aligned - mean||
225        let fi_aligned = karcher.aligned_data.row(i);
226        let residual = weighted_l2(&fi_aligned, &karcher.mean, &weights);
227
228        let diag = build_diagnostic(i, &gamma_i, argvals, pre_distance, residual, config);
229        if diag.flagged {
230            flagged_indices.push(i);
231        }
232        diagnostics.push(diag);
233    }
234
235    let n_flagged = flagged_indices.len();
236    let health_score = if n > 0 {
237        1.0 - n_flagged as f64 / n as f64
238    } else {
239        1.0
240    };
241
242    Ok(AlignmentDiagnosticSummary {
243        diagnostics,
244        flagged_indices,
245        n_flagged,
246        health_score,
247    })
248}
249
250/// Diagnose a single pairwise alignment.
251///
252/// Examines the warp produced by [`super::pairwise::elastic_align_pair`] and
253/// checks for over-alignment, under-alignment, non-monotonicity, and excessive
254/// bending energy.
255pub fn diagnose_pairwise(
256    f1: &[f64],
257    f2: &[f64],
258    result: &AlignmentResult,
259    argvals: &[f64],
260    config: &DiagnosticConfig,
261) -> AlignmentDiagnostic {
262    let weights = simpsons_weights(argvals);
263
264    // Pre-alignment L2 distance
265    let pre_distance = weighted_l2(f1, f2, &weights);
266
267    // Post-alignment residual: ||f1 - f2_aligned||
268    let residual = weighted_l2(f1, &result.f_aligned, &weights);
269
270    build_diagnostic(0, &result.gamma, argvals, pre_distance, residual, config)
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alignment::karcher_mean;
    use crate::alignment::pairwise::elastic_align_pair;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    /// Simulate `n` curves on a uniform grid of `m` points and return them
    /// together with the grid. The trailing `Some(99)` is presumably a fixed
    /// RNG seed that keeps the fixture reproducible.
    fn make_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let data = sim_fundata(n, &t, 3, EFunType::Fourier, EValType::Exponential, Some(99));
        (data, t)
    }

    /// The summary has one diagnostic per curve and internally consistent
    /// counts.
    #[test]
    fn diagnose_alignment_smoke() {
        let (data, t) = make_data(8, 30);
        let km = karcher_mean(&data, &t, 5, 1e-2, 0.0);
        let config = DiagnosticConfig::default();
        let summary = diagnose_alignment(&data, &km, &t, &config).unwrap();
        assert_eq!(summary.diagnostics.len(), 8);
        assert!((0.0..=1.0).contains(&summary.health_score));
        assert_eq!(summary.n_flagged, summary.flagged_indices.len());
    }

    /// Identical curves should need (almost) no warping.
    #[test]
    fn diagnose_alignment_identical_returns_low_complexity() {
        // When data is identical, warp complexity should be small even though
        // post-centering numerics may not yield exactly the identity warp.
        let n = 5;
        let t = uniform_grid(30);
        let curve: Vec<f64> = t.iter().map(|&x| x.sin()).collect();
        let m = curve.len();

        // `from_column_major` wants the buffer laid out column by column:
        // all n rows of grid point 0, then all n rows of grid point 1, ...
        // Appending whole curves back to back would be row-major and would
        // scramble the samples into n different curves.
        let mut vals = Vec::with_capacity(n * m);
        for &v in &curve {
            vals.extend(std::iter::repeat(v).take(n));
        }
        let data = FdMatrix::from_column_major(vals, n, m).unwrap();

        // Guard the fixture itself: every row really is the same curve.
        for i in 1..n {
            assert_eq!(data.row(i), data.row(0), "row {i} differs from row 0");
        }

        let km = karcher_mean(&data, &t, 5, 1e-3, 0.0);
        let config = DiagnosticConfig::default();
        let summary = diagnose_alignment(&data, &km, &t, &config).unwrap();
        assert_eq!(summary.diagnostics.len(), n);
        // All warp complexities should be small (near identity)
        for d in &summary.diagnostics {
            assert!(
                d.warp_complexity < 0.5,
                "curve {} warp_complexity {} should be small for identical data",
                d.curve_index,
                d.warp_complexity,
            );
        }
    }

    /// A grid whose length disagrees with the data is rejected.
    #[test]
    fn diagnose_alignment_rejects_shape_mismatch() {
        let (data, t) = make_data(6, 30);
        let km = karcher_mean(&data, &t, 3, 1e-2, 0.0);
        let bad_t = uniform_grid(20);
        let config = DiagnosticConfig::default();
        assert!(diagnose_alignment(&data, &km, &bad_t, &config).is_err());
    }

    /// Pairwise diagnostics on two phase-shifted sines produce sane values.
    #[test]
    fn diagnose_pairwise_smoke() {
        let t = uniform_grid(30);
        let f1: Vec<f64> = t.iter().map(|&x| (x * 6.0).sin()).collect();
        let f2: Vec<f64> = t.iter().map(|&x| ((x + 0.15) * 6.0).sin()).collect();
        let alignment = elastic_align_pair(&f1, &f2, &t, 0.0);
        let config = DiagnosticConfig::default();
        let diag = diagnose_pairwise(&f1, &f2, &alignment, &t, &config);
        assert!(diag.warp_complexity >= 0.0);
        assert!(diag.residual >= 0.0);
    }

    /// Aligning a curve with itself leaves no residual and a monotone warp.
    #[test]
    fn diagnose_pairwise_identical() {
        let t = uniform_grid(30);
        let f: Vec<f64> = t.iter().map(|&x| x.sin()).collect();
        let alignment = elastic_align_pair(&f, &f, &t, 0.0);
        let config = DiagnosticConfig::default();
        let diag = diagnose_pairwise(&f, &f, &alignment, &t, &config);
        assert!(
            diag.residual < 1e-3,
            "identical curves should have near-zero residual"
        );
        assert!(!diag.has_non_monotone);
    }
}