Skip to main content

fdars_core/alignment/
quality.rs

1//! Alignment quality metrics: warp complexity, smoothness, variance decomposition,
2//! and pairwise consistency.
3
4use super::pairwise::elastic_align_pair;
5use super::srsf::compose_warps;
6use super::KarcherMeanResult;
7use crate::helpers::{gradient_uniform, l2_distance, simpsons_weights};
8use crate::matrix::FdMatrix;
9
/// Comprehensive alignment quality assessment.
///
/// Produced by [`alignment_quality`]: per-warp diagnostics plus a decomposition
/// of the sample variance into an amplitude part (variance left after
/// alignment) and a phase part (variance removed by alignment).
#[derive(Debug, Clone, PartialEq)]
pub struct AlignmentQuality {
    /// Per-curve geodesic distance from warp to identity.
    pub warp_complexity: Vec<f64>,
    /// Mean of `warp_complexity` over curves.
    pub mean_warp_complexity: f64,
    /// Per-curve bending energy ∫(γ'')² dt.
    pub warp_smoothness: Vec<f64>,
    /// Mean of `warp_smoothness` (bending energy) over curves.
    pub mean_warp_smoothness: f64,
    /// Total variance of the original data: (1/n) Σ ∫(f_i − mean_orig)² dt.
    pub total_variance: f64,
    /// Amplitude variance of the aligned data: (1/n) Σ ∫(f_i^aligned − mean_aligned)² dt.
    pub amplitude_variance: f64,
    /// Phase variance: `total_variance − amplitude_variance`, clamped to be ≥ 0.
    pub phase_variance: f64,
    /// Fraction of the total variance attributed to phase:
    /// `phase_variance / total_variance`, or 0 when the total variance is negligible.
    ///
    /// Despite the name, this is a phase-to-*total* ratio, not phase-to-amplitude.
    pub phase_amplitude_ratio: f64,
    /// Per-time-point ratio of aligned to original cross-sectional variance
    /// (`aligned_var / orig_var`); 1.0 where the original variance is negligible.
    pub pointwise_variance_ratio: Vec<f64>,
    /// Mean of `pointwise_variance_ratio` over time points.
    ///
    /// This is a ratio, not a difference: values below 1 mean the aligned data
    /// has smaller pointwise variance than the original, values above 1 larger.
    pub mean_variance_reduction: f64,
}
34
35/// Compute warp complexity: geodesic distance from a warp to the identity.
36///
37/// This is `arccos(⟨ψ, ψ_id⟩)` on the Hilbert sphere.
38pub fn warp_complexity(gamma: &[f64], argvals: &[f64]) -> f64 {
39    crate::warping::phase_distance(gamma, argvals)
40}
41
42/// Compute warp smoothness (bending energy): ∫(γ'')² dt.
43pub fn warp_smoothness(gamma: &[f64], argvals: &[f64]) -> f64 {
44    let m = gamma.len();
45    if m < 3 {
46        return 0.0;
47    }
48
49    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
50    let gam_prime = gradient_uniform(gamma, h);
51    let gam_pprime = gradient_uniform(&gam_prime, h);
52
53    let integrand: Vec<f64> = gam_pprime.iter().map(|&g| g * g).collect();
54    crate::helpers::trapz(&integrand, argvals)
55}
56
57/// Compute comprehensive alignment quality metrics.
58///
59/// # Arguments
60/// * `data` — Original functional data (n × m)
61/// * `karcher` — Pre-computed Karcher mean result
62/// * `argvals` — Evaluation points (length m)
63pub fn alignment_quality(
64    data: &FdMatrix,
65    karcher: &KarcherMeanResult,
66    argvals: &[f64],
67) -> AlignmentQuality {
68    let (n, m) = data.shape();
69    let weights = simpsons_weights(argvals);
70
71    // Per-curve warp complexity and smoothness
72    let wc: Vec<f64> = (0..n)
73        .map(|i| {
74            let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
75            warp_complexity(&gamma, argvals)
76        })
77        .collect();
78    let ws: Vec<f64> = (0..n)
79        .map(|i| {
80            let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
81            warp_smoothness(&gamma, argvals)
82        })
83        .collect();
84
85    let mean_wc = wc.iter().sum::<f64>() / n as f64;
86    let mean_ws = ws.iter().sum::<f64>() / n as f64;
87
88    // Compute original mean
89    let orig_mean = crate::fdata::mean_1d(data);
90
91    // Total variance
92    let total_var: f64 = (0..n)
93        .map(|i| {
94            let fi = data.row(i);
95            let d = l2_distance(&fi, &orig_mean, &weights);
96            d * d
97        })
98        .sum::<f64>()
99        / n as f64;
100
101    // Aligned mean
102    let aligned_mean = crate::fdata::mean_1d(&karcher.aligned_data);
103
104    // Amplitude variance
105    let amp_var: f64 = (0..n)
106        .map(|i| {
107            let fi = karcher.aligned_data.row(i);
108            let d = l2_distance(&fi, &aligned_mean, &weights);
109            d * d
110        })
111        .sum::<f64>()
112        / n as f64;
113
114    let phase_var = (total_var - amp_var).max(0.0);
115    let ratio = if total_var > 1e-10 {
116        phase_var / total_var
117    } else {
118        0.0
119    };
120
121    // Pointwise variance ratio
122    let mut pw_ratio = vec![0.0; m];
123    for j in 0..m {
124        let col_orig = data.column(j);
125        let mean_orig_j = col_orig.iter().sum::<f64>() / n as f64;
126        let var_orig: f64 = col_orig
127            .iter()
128            .map(|&v| (v - mean_orig_j).powi(2))
129            .sum::<f64>()
130            / n as f64;
131
132        let col_aligned = karcher.aligned_data.column(j);
133        let mean_aligned_j = col_aligned.iter().sum::<f64>() / n as f64;
134        let var_aligned: f64 = col_aligned
135            .iter()
136            .map(|&v| (v - mean_aligned_j).powi(2))
137            .sum::<f64>()
138            / n as f64;
139
140        pw_ratio[j] = if var_orig > 1e-15 {
141            var_aligned / var_orig
142        } else {
143            1.0
144        };
145    }
146
147    let mean_vr = pw_ratio.iter().sum::<f64>() / m as f64;
148
149    AlignmentQuality {
150        warp_complexity: wc,
151        mean_warp_complexity: mean_wc,
152        warp_smoothness: ws,
153        mean_warp_smoothness: mean_ws,
154        total_variance: total_var,
155        amplitude_variance: amp_var,
156        phase_variance: phase_var,
157        phase_amplitude_ratio: ratio,
158        pointwise_variance_ratio: pw_ratio,
159        mean_variance_reduction: mean_vr,
160    }
161}
162
/// Generate triplet indices `(i, j, k)` with `i < j < k < n`, in lexicographic order.
///
/// At most `max_triplets` triplets are returned; `0` means no cap (all C(n, 3)
/// triplets). Fewer than three indices yield an empty vector.
fn triplet_indices(n: usize, max_triplets: usize) -> Vec<(usize, usize, usize)> {
    // Cap lazily with `take` instead of computing C(n, 3) up front: the closed form
    // `n * (n - 1) * (n - 2) / 6` underflows for n < 2 and can overflow for large n.
    let cap = if max_triplets > 0 {
        max_triplets
    } else {
        usize::MAX
    };
    (0..n)
        .flat_map(|i| ((i + 1)..n).flat_map(move |j| ((j + 1)..n).map(move |k| (i, j, k))))
        .take(cap)
        .collect()
}
176
177/// Compute the warp deviation for one triplet: ‖γ_ij∘γ_jk − γ_ik‖_L2.
178fn triplet_warp_deviation(
179    data: &FdMatrix,
180    argvals: &[f64],
181    weights: &[f64],
182    i: usize,
183    j: usize,
184    k: usize,
185    lambda: f64,
186) -> f64 {
187    let fi = data.row(i);
188    let fj = data.row(j);
189    let fk = data.row(k);
190    let rij = elastic_align_pair(&fi, &fj, argvals, lambda);
191    let rjk = elastic_align_pair(&fj, &fk, argvals, lambda);
192    let rik = elastic_align_pair(&fi, &fk, argvals, lambda);
193    let composed = compose_warps(&rij.gamma, &rjk.gamma, argvals);
194    l2_distance(&composed, &rik.gamma, weights)
195}
196
197/// Measure pairwise alignment consistency via triplet checks.
198///
199/// For triplets (i,j,k), checks `γ_ij ∘ γ_jk ≈ γ_ik` by measuring the L2
200/// deviation of the composed warp from the direct warp.
201///
202/// # Arguments
203/// * `data` — Functional data (n × m)
204/// * `argvals` — Evaluation points (length m)
205/// * `lambda` — Penalty weight
206/// * `max_triplets` — Maximum number of triplets to check (0 = all)
207pub fn pairwise_consistency(
208    data: &FdMatrix,
209    argvals: &[f64],
210    lambda: f64,
211    max_triplets: usize,
212) -> f64 {
213    let n = data.nrows();
214    if n < 3 {
215        return 0.0;
216    }
217
218    let weights = simpsons_weights(argvals);
219    let triplets = triplet_indices(n, max_triplets);
220    if triplets.is_empty() {
221        return 0.0;
222    }
223
224    let total_dev: f64 = triplets
225        .iter()
226        .map(|&(i, j, k)| triplet_warp_deviation(data, argvals, &weights, i, j, k, lambda))
227        .sum();
228    total_dev / triplets.len() as f64
229}