1use super::quality::{warp_complexity, warp_smoothness};
4use super::{AlignmentResult, KarcherMeanResult};
5use crate::error::FdarError;
6use crate::helpers::simpsons_weights;
7use crate::matrix::FdMatrix;
8
/// Per-curve diagnostic describing how well a single warp/alignment behaved.
///
/// Produced by `build_diagnostic`, which is called from `diagnose_alignment`
/// (one per curve) and `diagnose_pairwise` (a single diagnostic).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct AlignmentDiagnostic {
    /// Index of the curve this diagnostic refers to (row index for
    /// `diagnose_alignment`, always 0 for `diagnose_pairwise`).
    pub curve_index: usize,
    /// Value returned by `warp_complexity` for this curve's warping function.
    pub warp_complexity: f64,
    /// Value returned by `warp_smoothness`; compared against
    /// `DiagnosticConfig::max_bending_energy` and reported as "bending energy".
    pub warp_smoothness: f64,
    /// True when `distance_ratio` exceeds `min_improvement_ratio` and the
    /// pre-alignment distance exceeds `under_alignment_threshold`.
    pub is_under_aligned: bool,
    /// True when `warp_complexity` exceeds `over_alignment_threshold`.
    pub is_over_aligned: bool,
    /// True when some consecutive pair of warp samples strictly decreases.
    pub has_non_monotone: bool,
    /// Weighted L2 distance to the alignment target after alignment.
    pub residual: f64,
    /// `residual / pre-alignment distance`, or 0.0 when the pre-alignment
    /// distance is at most 1e-15.
    pub distance_ratio: f64,
    /// True when at least one entry is present in `issues`.
    pub flagged: bool,
    /// Human-readable descriptions of every problem detected for this curve.
    pub issues: Vec<String>,
}
36
/// Thresholds controlling which alignment problems are reported.
#[derive(Debug, Clone, PartialEq)]
pub struct DiagnosticConfig {
    /// Warp complexity strictly above this value marks a curve as over-aligned.
    pub over_alignment_threshold: f64,
    /// Minimum pre-alignment distance needed before a curve can be reported as
    /// under-aligned; curves already closer than this to their target are
    /// never flagged for poor improvement.
    pub under_alignment_threshold: f64,
    /// Upper bound on the warp smoothness (bending energy) score; larger values
    /// add an issue.
    pub max_bending_energy: f64,
    /// If `residual / pre-alignment distance` exceeds this ratio (and the
    /// under-alignment threshold is met), the curve is reported as under-aligned.
    pub min_improvement_ratio: f64,
}
50
51impl Default for DiagnosticConfig {
52 fn default() -> Self {
53 Self {
54 over_alignment_threshold: 1.0,
55 under_alignment_threshold: 1e-6,
56 max_bending_energy: 100.0,
57 min_improvement_ratio: 0.5,
58 }
59 }
60}
61
/// Aggregate result of `diagnose_alignment` over every curve in a dataset.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct AlignmentDiagnosticSummary {
    /// One diagnostic per curve, in row order of the input data.
    pub diagnostics: Vec<AlignmentDiagnostic>,
    /// Indices (in ascending order) of curves whose diagnostic is flagged.
    pub flagged_indices: Vec<usize>,
    /// Number of flagged curves; always equal to `flagged_indices.len()`.
    pub n_flagged: usize,
    /// `1 - n_flagged / n`, i.e. the fraction of unflagged curves; 1.0 when the
    /// dataset has no curves.
    pub health_score: f64,
}
75
/// Quadrature-weighted L2 distance `sqrt(sum_i w_i * (a_i - b_i)^2)`.
///
/// The sum runs over the length of `a`. The function panics if `b` or `weights`
/// is shorter than `a`, and any extra trailing entries in them are ignored.
fn weighted_l2(a: &[f64], b: &[f64], weights: &[f64]) -> f64 {
    let len = a.len();
    // Slicing to `a.len()` keeps the contract explicit: too-short inputs panic here,
    // longer ones are truncated.
    let (b, weights) = (&b[..len], &weights[..len]);
    a.iter()
        .zip(b)
        .zip(weights)
        .fold(0.0_f64, |acc, ((&x, &y), &w)| {
            let diff = x - y;
            acc + diff * diff * w
        })
        .sqrt()
}
87
/// Report whether `gamma` ever strictly decreases between consecutive samples.
///
/// Flat segments (equal neighbours) are allowed. Any comparison involving NaN is
/// false, so NaN samples never count as a decrease.
fn is_non_monotone(gamma: &[f64]) -> bool {
    gamma
        .iter()
        .zip(gamma.iter().skip(1))
        .any(|(prev, next)| next < prev)
}
92
93fn build_diagnostic(
95 curve_index: usize,
96 gamma: &[f64],
97 argvals: &[f64],
98 pre_distance: f64,
99 residual: f64,
100 config: &DiagnosticConfig,
101) -> AlignmentDiagnostic {
102 let wc = warp_complexity(gamma, argvals);
103 let ws = warp_smoothness(gamma, argvals);
104 let non_mono = is_non_monotone(gamma);
105
106 let distance_ratio = if pre_distance > 1e-15 {
107 residual / pre_distance
108 } else {
109 0.0
110 };
111
112 let is_over = wc > config.over_alignment_threshold;
113 let is_under = distance_ratio > config.min_improvement_ratio
114 && pre_distance > config.under_alignment_threshold;
115
116 let mut issues = Vec::new();
117 if is_over {
118 issues.push(format!(
119 "warp complexity {wc:.4} exceeds threshold {}",
120 config.over_alignment_threshold
121 ));
122 }
123 if is_under {
124 issues.push(format!(
125 "distance ratio {distance_ratio:.4} exceeds improvement threshold {}",
126 config.min_improvement_ratio
127 ));
128 }
129 if non_mono {
130 issues.push("warp contains non-monotone segments".to_string());
131 }
132 if ws > config.max_bending_energy {
133 issues.push(format!(
134 "bending energy {ws:.2} exceeds threshold {}",
135 config.max_bending_energy
136 ));
137 }
138
139 let flagged = !issues.is_empty();
140
141 AlignmentDiagnostic {
142 curve_index,
143 warp_complexity: wc,
144 warp_smoothness: ws,
145 is_under_aligned: is_under,
146 is_over_aligned: is_over,
147 has_non_monotone: non_mono,
148 residual,
149 distance_ratio,
150 flagged,
151 issues,
152 }
153}
154
155pub fn diagnose_alignment(
172 data: &FdMatrix,
173 karcher: &KarcherMeanResult,
174 argvals: &[f64],
175 config: &DiagnosticConfig,
176) -> Result<AlignmentDiagnosticSummary, FdarError> {
177 let (n, m) = data.shape();
178
179 if argvals.len() != m {
180 return Err(FdarError::InvalidDimension {
181 parameter: "argvals",
182 expected: format!("{m}"),
183 actual: format!("{}", argvals.len()),
184 });
185 }
186 if karcher.gammas.nrows() != n || karcher.gammas.ncols() != m {
187 return Err(FdarError::InvalidDimension {
188 parameter: "karcher.gammas",
189 expected: format!("{n} x {m}"),
190 actual: format!("{} x {}", karcher.gammas.nrows(), karcher.gammas.ncols()),
191 });
192 }
193 if karcher.aligned_data.nrows() != n || karcher.aligned_data.ncols() != m {
194 return Err(FdarError::InvalidDimension {
195 parameter: "karcher.aligned_data",
196 expected: format!("{n} x {m}"),
197 actual: format!(
198 "{} x {}",
199 karcher.aligned_data.nrows(),
200 karcher.aligned_data.ncols()
201 ),
202 });
203 }
204 if karcher.mean.len() != m {
205 return Err(FdarError::InvalidDimension {
206 parameter: "karcher.mean",
207 expected: format!("{m}"),
208 actual: format!("{}", karcher.mean.len()),
209 });
210 }
211
212 let weights = simpsons_weights(argvals);
213
214 let mut diagnostics = Vec::with_capacity(n);
215 let mut flagged_indices = Vec::new();
216
217 for i in 0..n {
218 let gamma_i: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
219
220 let fi = data.row(i);
222 let pre_distance = weighted_l2(&fi, &karcher.mean, &weights);
223
224 let fi_aligned = karcher.aligned_data.row(i);
226 let residual = weighted_l2(&fi_aligned, &karcher.mean, &weights);
227
228 let diag = build_diagnostic(i, &gamma_i, argvals, pre_distance, residual, config);
229 if diag.flagged {
230 flagged_indices.push(i);
231 }
232 diagnostics.push(diag);
233 }
234
235 let n_flagged = flagged_indices.len();
236 let health_score = if n > 0 {
237 1.0 - n_flagged as f64 / n as f64
238 } else {
239 1.0
240 };
241
242 Ok(AlignmentDiagnosticSummary {
243 diagnostics,
244 flagged_indices,
245 n_flagged,
246 health_score,
247 })
248}
249
250pub fn diagnose_pairwise(
256 f1: &[f64],
257 f2: &[f64],
258 result: &AlignmentResult,
259 argvals: &[f64],
260 config: &DiagnosticConfig,
261) -> AlignmentDiagnostic {
262 let weights = simpsons_weights(argvals);
263
264 let pre_distance = weighted_l2(f1, f2, &weights);
266
267 let residual = weighted_l2(f1, &result.f_aligned, &weights);
269
270 build_diagnostic(0, &result.gamma, argvals, pre_distance, residual, config)
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alignment::karcher_mean;
    use crate::alignment::pairwise::elastic_align_pair;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    /// Simulate `n` curves on a uniform grid of `m` points.
    ///
    /// The trailing `Some(99)` is presumably the RNG seed, keeping the data
    /// reproducible across runs.
    fn make_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let data = sim_fundata(n, &t, 3, EFunType::Fourier, EValType::Exponential, Some(99));
        (data, t)
    }

    #[test]
    fn diagnose_alignment_smoke() {
        let (data, t) = make_data(8, 30);
        let km = karcher_mean(&data, &t, 5, 1e-2, 0.0);
        let config = DiagnosticConfig::default();
        let summary = diagnose_alignment(&data, &km, &t, &config).unwrap();
        assert_eq!(summary.diagnostics.len(), 8);
        assert!(summary.health_score >= 0.0 && summary.health_score <= 1.0);
        assert_eq!(summary.n_flagged, summary.flagged_indices.len());
    }

    #[test]
    fn diagnose_alignment_identical_returns_low_complexity() {
        let t = uniform_grid(30);
        let curve: Vec<f64> = t.iter().map(|&x| x.sin()).collect();
        // `from_column_major` stores each column (every row's value at one grid
        // point) contiguously. Concatenating whole curves would scramble the
        // rows, so repeat each sample once per row instead.
        let vals: Vec<f64> = curve
            .iter()
            .flat_map(|&v| std::iter::repeat(v).take(5))
            .collect();
        let data = FdMatrix::from_column_major(vals, 5, 30).unwrap();
        // Sanity-check the layout: every row must be exactly `curve`.
        for i in 0..5 {
            for (j, &v) in curve.iter().enumerate() {
                assert_eq!(data[(i, j)], v);
            }
        }
        let km = karcher_mean(&data, &t, 5, 1e-3, 0.0);
        let config = DiagnosticConfig::default();
        let summary = diagnose_alignment(&data, &km, &t, &config).unwrap();
        assert_eq!(summary.diagnostics.len(), 5);
        for d in &summary.diagnostics {
            assert!(
                d.warp_complexity < 0.5,
                "curve {} warp_complexity {} should be small for identical data",
                d.curve_index,
                d.warp_complexity,
            );
        }
    }

    #[test]
    fn diagnose_alignment_rejects_shape_mismatch() {
        let (data, t) = make_data(6, 30);
        let km = karcher_mean(&data, &t, 3, 1e-2, 0.0);
        // A grid with the wrong number of points must be rejected.
        let bad_t = uniform_grid(20);
        let config = DiagnosticConfig::default();
        assert!(diagnose_alignment(&data, &km, &bad_t, &config).is_err());
    }

    #[test]
    fn diagnose_pairwise_smoke() {
        let t = uniform_grid(30);
        let f1: Vec<f64> = t.iter().map(|&x| (x * 6.0).sin()).collect();
        // Same shape as `f1`, shifted in phase.
        let f2: Vec<f64> = t.iter().map(|&x| ((x + 0.15) * 6.0).sin()).collect();
        let alignment = elastic_align_pair(&f1, &f2, &t, 0.0);
        let config = DiagnosticConfig::default();
        let diag = diagnose_pairwise(&f1, &f2, &alignment, &t, &config);
        assert!(diag.warp_complexity >= 0.0);
        assert!(diag.residual >= 0.0);
    }

    #[test]
    fn diagnose_pairwise_identical() {
        let t = uniform_grid(30);
        let f: Vec<f64> = t.iter().map(|&x| x.sin()).collect();
        let alignment = elastic_align_pair(&f, &f, &t, 0.0);
        let config = DiagnosticConfig::default();
        let diag = diagnose_pairwise(&f, &f, &alignment, &t, &config);
        assert!(
            diag.residual < 1e-3,
            "identical curves should have near-zero residual"
        );
        assert!(!diag.has_non_monotone);
    }
}
358}