Skip to main content

fdars_core/
elastic_explain.rs

//! Elastic shape explainability: amplitude vs. phase attribution.
//!
//! Decomposes elastic PCR predictions into "shape" (amplitude) and "timing" (phase)
//! contributions, with permutation-based variable importance.
//!
//! - [`elastic_pcr_attribution`] — Decompose predictions and compute importance scores
7
8use crate::alignment::srsf_transform;
9use crate::elastic_fpca::{
10    build_augmented_srsfs, center_matrix, shooting_vectors_from_psis, sphere_karcher_mean,
11    warps_to_normalized_psi,
12};
13use crate::elastic_regression::{ElasticPcrResult, PcaMethod};
14use crate::matrix::FdMatrix;
15use rand::prelude::*;
16
/// Result of elastic amplitude/phase attribution.
///
/// Returned by [`elastic_pcr_attribution`]. Both contribution vectors hold one
/// entry per observation; both importance scores are differences between the
/// R² of the fitted model and the average R² after permutation, clamped at zero.
#[derive(Debug, Clone, PartialEq)]
pub struct ElasticAttributionResult {
    /// Per-observation amplitude contribution (length n).
    pub amplitude_contribution: Vec<f64>,
    /// Per-observation phase contribution (length n).
    pub phase_contribution: Vec<f64>,
    /// R² drop from permuting amplitude scores (never negative).
    pub amplitude_importance: f64,
    /// R² drop from permuting phase scores (never negative).
    pub phase_importance: f64,
}
28
29/// Decompose elastic PCR predictions into amplitude and phase contributions.
30///
31/// For joint FPCA, the joint eigenvectors split into vertical (amplitude) and
32/// horizontal (phase) parts. Each observation's score can be decomposed into
33/// amplitude and phase sub-scores based on these parts.
34///
35/// For vertical-only or horizontal-only models, the missing component
36/// contributes zero.
37///
38/// # Arguments
39/// * `result` — A fitted [`ElasticPcrResult`] (must have stored FPCA results)
40/// * `y` — Original scalar responses (length n)
41/// * `ncomp` — Number of components to use for attribution
42/// * `n_perm` — Number of permutation replicates for importance
43/// * `seed` — RNG seed for permutation reproducibility
44pub fn elastic_pcr_attribution(
45    result: &ElasticPcrResult,
46    y: &[f64],
47    ncomp: usize,
48    n_perm: usize,
49    seed: u64,
50) -> Option<ElasticAttributionResult> {
51    let n = result.fitted_values.len();
52    if y.len() != n || ncomp == 0 || n < 2 {
53        return None;
54    }
55    let actual_ncomp = ncomp.min(result.coefficients.len());
56
57    match result.pca_method {
58        PcaMethod::Joint => attribution_joint(result, y, actual_ncomp, n_perm, seed),
59        PcaMethod::Vertical => {
60            // All contribution is amplitude, phase is zero
61            let amp: Vec<f64> = result
62                .fitted_values
63                .iter()
64                .map(|&f| f - result.alpha)
65                .collect();
66            let phase = vec![0.0; n];
67            let amp_imp = permutation_importance_single(
68                y,
69                &result.fitted_values,
70                result.alpha,
71                &result.coefficients,
72                actual_ncomp,
73                n_perm,
74                seed,
75            );
76            Some(ElasticAttributionResult {
77                amplitude_contribution: amp,
78                phase_contribution: phase,
79                amplitude_importance: amp_imp,
80                phase_importance: 0.0,
81            })
82        }
83        PcaMethod::Horizontal => {
84            // All contribution is phase, amplitude is zero
85            let phase: Vec<f64> = result
86                .fitted_values
87                .iter()
88                .map(|&f| f - result.alpha)
89                .collect();
90            let amp = vec![0.0; n];
91            let phase_imp = permutation_importance_single(
92                y,
93                &result.fitted_values,
94                result.alpha,
95                &result.coefficients,
96                actual_ncomp,
97                n_perm,
98                seed,
99            );
100            Some(ElasticAttributionResult {
101                amplitude_contribution: amp,
102                phase_contribution: phase,
103                amplitude_importance: 0.0,
104                phase_importance: phase_imp,
105            })
106        }
107    }
108}
109
110/// Joint FPCA attribution: decompose scores into amp and phase parts.
111fn attribution_joint(
112    result: &ElasticPcrResult,
113    y: &[f64],
114    ncomp: usize,
115    n_perm: usize,
116    seed: u64,
117) -> Option<ElasticAttributionResult> {
118    let joint = result.joint_fpca.as_ref()?;
119    let km = &result.karcher;
120    let (n, m) = km.aligned_data.shape();
121    let m_aug = m + 1;
122
123    let qn = match &km.aligned_srsfs {
124        Some(srsfs) => srsfs.clone(),
125        None => {
126            let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
127            srsf_transform(&km.aligned_data, &argvals)
128        }
129    };
130
131    let q_aug = build_augmented_srsfs(&qn, &km.aligned_data, n, m);
132    let (_, mean_q) = center_matrix(&q_aug, n, m_aug);
133
134    // Compute shooting vectors using shared helpers
135    let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
136    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
137    let psis = warps_to_normalized_psi(&km.gammas, &argvals);
138    let mu_psi = sphere_karcher_mean(&psis, &time, 50);
139    let shooting = shooting_vectors_from_psis(&psis, &mu_psi, &time);
140
141    let c = joint.balance_c;
142    let (amp_scores, phase_scores) = decompose_joint_scores(
143        &q_aug,
144        &mean_q,
145        &shooting,
146        &joint.vert_component,
147        &joint.horiz_component,
148        c,
149        n,
150        m_aug,
151        m,
152        ncomp,
153    );
154
155    let (amplitude_contribution, phase_contribution) =
156        compute_contributions(&amp_scores, &phase_scores, &result.coefficients, n, ncomp);
157
158    // Permutation importance
159    let r2_orig = compute_r2(y, &result.fitted_values);
160    let amplitude_importance = permutation_importance(
161        y,
162        result.alpha,
163        &result.coefficients,
164        &amp_scores,
165        &phase_scores,
166        ncomp,
167        n_perm,
168        seed,
169        true,
170    );
171    let phase_importance = permutation_importance(
172        y,
173        result.alpha,
174        &result.coefficients,
175        &amp_scores,
176        &phase_scores,
177        ncomp,
178        n_perm,
179        seed + 1_000_000,
180        false,
181    );
182
183    Some(ElasticAttributionResult {
184        amplitude_contribution,
185        phase_contribution,
186        amplitude_importance: (r2_orig - amplitude_importance).max(0.0),
187        phase_importance: (r2_orig - phase_importance).max(0.0),
188    })
189}
190
191/// Decompose joint scores into amplitude and phase sub-scores.
192fn decompose_joint_scores(
193    q_aug: &FdMatrix,
194    mean_q: &[f64],
195    shooting: &FdMatrix,
196    vert_component: &FdMatrix,
197    horiz_component: &FdMatrix,
198    c: f64,
199    n: usize,
200    m_aug: usize,
201    m: usize,
202    ncomp: usize,
203) -> (FdMatrix, FdMatrix) {
204    let mut amp_scores = FdMatrix::zeros(n, ncomp);
205    let mut phase_scores = FdMatrix::zeros(n, ncomp);
206    for k in 0..ncomp {
207        for i in 0..n {
208            let mut amp_s = 0.0;
209            for j in 0..m_aug {
210                amp_s += (q_aug[(i, j)] - mean_q[j]) * vert_component[(k, j)];
211            }
212            amp_scores[(i, k)] = amp_s;
213
214            let mut phase_s = 0.0;
215            for j in 0..m {
216                phase_s += c * shooting[(i, j)] * horiz_component[(k, j)];
217            }
218            phase_scores[(i, k)] = phase_s;
219        }
220    }
221    (amp_scores, phase_scores)
222}
223
224/// Compute amplitude and phase contributions from decomposed scores.
225fn compute_contributions(
226    amp_scores: &FdMatrix,
227    phase_scores: &FdMatrix,
228    coefficients: &[f64],
229    n: usize,
230    ncomp: usize,
231) -> (Vec<f64>, Vec<f64>) {
232    let mut amplitude_contribution = vec![0.0; n];
233    let mut phase_contribution = vec![0.0; n];
234    for i in 0..n {
235        for k in 0..ncomp {
236            amplitude_contribution[i] += coefficients[k] * amp_scores[(i, k)];
237            phase_contribution[i] += coefficients[k] * phase_scores[(i, k)];
238        }
239    }
240    (amplitude_contribution, phase_contribution)
241}
242
/// Coefficient of determination of `fitted` against the responses `y`.
///
/// Returns `1 - SS_res / SS_tot`. When the total sum of squares is not
/// strictly positive (constant or empty `y`), the ratio is undefined and `0.0`
/// is returned instead. Pairs are formed by zipping, so any surplus entries in
/// the longer slice are ignored for the residual sum.
fn compute_r2(y: &[f64], fitted: &[f64]) -> f64 {
    let mean = y.iter().sum::<f64>() / y.len() as f64;
    let total: f64 = y.iter().map(|&obs| (obs - mean).powi(2)).sum();

    if total > 0.0 {
        let residual: f64 = y
            .iter()
            .zip(fitted.iter())
            .map(|(&obs, &fit)| (obs - fit).powi(2))
            .sum();
        1.0 - residual / total
    } else {
        0.0
    }
}
259
260/// Permutation importance: shuffle one component's scores, recompute fitted, return avg R².
261fn permutation_importance(
262    y: &[f64],
263    alpha: f64,
264    coefficients: &[f64],
265    amp_scores: &FdMatrix,
266    phase_scores: &FdMatrix,
267    ncomp: usize,
268    n_perm: usize,
269    seed: u64,
270    permute_amplitude: bool,
271) -> f64 {
272    let n = y.len();
273    if n_perm == 0 {
274        return compute_r2(y, &vec![alpha; n]);
275    }
276
277    let mut total_r2 = 0.0;
278    for p in 0..n_perm {
279        let mut rng = StdRng::seed_from_u64(seed.wrapping_add(p as u64));
280        let mut perm_idx: Vec<usize> = (0..n).collect();
281        perm_idx.shuffle(&mut rng);
282
283        let fitted = fitted_with_permuted_scores(
284            alpha,
285            coefficients,
286            amp_scores,
287            phase_scores,
288            &perm_idx,
289            n,
290            ncomp,
291            permute_amplitude,
292        );
293        total_r2 += compute_r2(y, &fitted);
294    }
295    total_r2 / n_perm as f64
296}
297
298/// Compute fitted values with one component's scores permuted.
299fn fitted_with_permuted_scores(
300    alpha: f64,
301    coefficients: &[f64],
302    amp_scores: &FdMatrix,
303    phase_scores: &FdMatrix,
304    perm_idx: &[usize],
305    n: usize,
306    ncomp: usize,
307    permute_amplitude: bool,
308) -> Vec<f64> {
309    let mut fitted = vec![0.0; n];
310    for i in 0..n {
311        fitted[i] = alpha;
312        for k in 0..ncomp {
313            let amp_i = if permute_amplitude {
314                amp_scores[(perm_idx[i], k)]
315            } else {
316                amp_scores[(i, k)]
317            };
318            let phase_i = if !permute_amplitude {
319                phase_scores[(perm_idx[i], k)]
320            } else {
321                phase_scores[(i, k)]
322            };
323            fitted[i] += coefficients[k] * (amp_i + phase_i);
324        }
325    }
326    fitted
327}
328
329/// Permutation importance for single-component models (vert-only or horiz-only).
330fn permutation_importance_single(
331    y: &[f64],
332    fitted_values: &[f64],
333    alpha: f64,
334    _coefficients: &[f64],
335    _ncomp: usize,
336    n_perm: usize,
337    seed: u64,
338) -> f64 {
339    let n = y.len();
340    let r2_orig = compute_r2(y, fitted_values);
341    if n_perm == 0 {
342        return r2_orig;
343    }
344
345    // Extract per-obs contribution = fitted - alpha, then permute
346    let contribs: Vec<f64> = fitted_values.iter().map(|&f| f - alpha).collect();
347    let mut total_r2 = 0.0;
348    for p in 0..n_perm {
349        let mut rng = StdRng::seed_from_u64(seed.wrapping_add(p as u64));
350        let mut perm_idx: Vec<usize> = (0..n).collect();
351        perm_idx.shuffle(&mut rng);
352
353        let fitted_perm: Vec<f64> = (0..n).map(|i| alpha + contribs[perm_idx[i]]).collect();
354        total_r2 += compute_r2(y, &fitted_perm);
355    }
356    let avg_r2 = total_r2 / n_perm as f64;
357    (r2_orig - avg_r2).max(0.0)
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use crate::elastic_regression::{elastic_pcr, PcaMethod};
    use std::f64::consts::PI;

    /// Number of curves used by every test.
    const N_OBS: usize = 15;
    /// Number of grid points per curve.
    const N_GRID: usize = 51;

    /// Build `n` sine curves on a uniform `m`-point grid over `[0, 1]`.
    ///
    /// Curve `i` has amplitude `1 + 0.5 * i / n` and a phase shift of
    /// `0.1 * (i - n / 2)`; the response is the amplitude. Returns
    /// `(curves, responses, grid)`.
    fn generate_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>, Vec<f64>) {
        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
        let mut data = FdMatrix::zeros(n, m);
        let mut y = Vec::with_capacity(n);
        for i in 0..n {
            let amp = 1.0 + 0.5 * (i as f64 / n as f64);
            let shift = 0.1 * (i as f64 - n as f64 / 2.0);
            for (j, &tj) in t.iter().enumerate() {
                data[(i, j)] = amp * (2.0 * PI * (tj + shift)).sin();
            }
            y.push(amp);
        }
        (data, y, t)
    }

    /// Joint model: the two contributions must add up to the centred fit.
    #[test]
    fn test_elastic_attribution_joint_decomposition() {
        let (data, y, t) = generate_test_data(N_OBS, N_GRID);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 10, 42).unwrap();

        assert_eq!(attr.amplitude_contribution.len(), 15);
        assert_eq!(attr.phase_contribution.len(), 15);

        // Verify: amp + phase ≈ fitted - alpha
        for i in 0..N_OBS {
            let expected = result.fitted_values[i] - result.alpha;
            let sum = attr.amplitude_contribution[i] + attr.phase_contribution[i];
            let gap = (sum - expected).abs();
            assert!(
                gap < 1e-6,
                "amp + phase should ≈ fitted - alpha at i={}: {} vs {}",
                i,
                sum,
                expected
            );
        }
    }

    /// Vertical-only model: nothing may be attributed to phase.
    #[test]
    fn test_elastic_attribution_vertical_only() {
        let (data, y, t) = generate_test_data(N_OBS, N_GRID);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Vertical, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 10, 42).unwrap();

        // Phase contribution should all be zero
        for i in 0..N_OBS {
            let magnitude = attr.phase_contribution[i].abs();
            assert!(
                magnitude < 1e-12,
                "phase_contribution should be 0 for vertical-only at i={}",
                i
            );
        }
        let importance_magnitude = attr.phase_importance.abs();
        assert!(
            importance_magnitude < 1e-12,
            "phase_importance should be 0 for vertical-only"
        );
    }

    /// Joint model: both importance scores are clamped to be non-negative.
    #[test]
    fn test_elastic_attribution_importance_nonnegative() {
        let (data, y, t) = generate_test_data(N_OBS, N_GRID);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 20, 42).unwrap();

        let amp_imp = attr.amplitude_importance;
        assert!(
            amp_imp >= 0.0,
            "amplitude_importance should be >= 0, got {}",
            attr.amplitude_importance
        );
        let phase_imp = attr.phase_importance;
        assert!(
            phase_imp >= 0.0,
            "phase_importance should be >= 0, got {}",
            attr.phase_importance
        );
    }
}