Skip to main content

fdars_core/
elastic_explain.rs

//! Elastic shape explainability: amplitude vs phase attribution.
//!
//! Decomposes elastic PCR predictions into "shape" (amplitude) and "timing" (phase)
//! contributions, with permutation-based variable importance.
//!
//! - [`elastic_pcr_attribution`] — Decompose predictions and compute importance scores

8use crate::alignment::srsf_transform;
9use crate::elastic_fpca::{
10    build_augmented_srsfs, center_matrix, shooting_vectors_from_psis, sphere_karcher_mean,
11    warps_to_normalized_psi,
12};
13use crate::elastic_regression::{ElasticPcrResult, PcaMethod};
14use crate::error::FdarError;
15use crate::matrix::FdMatrix;
16use rand::prelude::*;
17
/// Amplitude/phase attribution of an elastic PCR fit.
///
/// Produced by [`elastic_pcr_attribution`]. Contributions are expressed on the
/// response scale and exclude the intercept; importances measure how much R²
/// falls when the corresponding source is permuted across observations.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub struct ElasticAttributionResult {
    /// Amplitude (vertical) share of each observation's centered prediction (length n).
    pub amplitude_contribution: Vec<f64>,
    /// Phase (horizontal) share of each observation's centered prediction (length n).
    pub phase_contribution: Vec<f64>,
    /// Drop in R² when the amplitude scores are permuted.
    pub amplitude_importance: f64,
    /// Drop in R² when the phase scores are permuted.
    pub phase_importance: f64,
}

32/// Decompose elastic PCR predictions into amplitude and phase contributions.
33///
34/// For joint FPCA, the joint eigenvectors split into vertical (amplitude) and
35/// horizontal (phase) parts. Each observation's score can be decomposed into
36/// amplitude and phase sub-scores based on these parts.
37///
38/// For vertical-only or horizontal-only models, the missing component
39/// contributes zero.
40///
41/// # Arguments
42/// * `result` — A fitted [`ElasticPcrResult`] (must have stored FPCA results)
43/// * `y` — Original scalar responses (length n)
44/// * `ncomp` — Number of components to use for attribution
45/// * `n_perm` — Number of permutation replicates for importance
46/// * `seed` — RNG seed for permutation reproducibility
47///
48/// # Errors
49///
50/// Returns [`FdarError::InvalidDimension`] if `y.len()` does not match the
51/// number of fitted values in `result`.
52/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero or there are
53/// fewer than 2 observations.
54/// Returns [`FdarError::ComputationFailed`] if the joint FPCA result is
55/// missing from a `PcaMethod::Joint` model.
56#[must_use = "expensive computation whose result should not be discarded"]
57pub fn elastic_pcr_attribution(
58    result: &ElasticPcrResult,
59    y: &[f64],
60    ncomp: usize,
61    n_perm: usize,
62    seed: u64,
63) -> Result<ElasticAttributionResult, FdarError> {
64    let n = result.fitted_values.len();
65    if y.len() != n {
66        return Err(FdarError::InvalidDimension {
67            parameter: "y",
68            expected: n.to_string(),
69            actual: y.len().to_string(),
70        });
71    }
72    if ncomp == 0 {
73        return Err(FdarError::InvalidParameter {
74            parameter: "ncomp",
75            message: "ncomp must be >= 1".into(),
76        });
77    }
78    if n < 2 {
79        return Err(FdarError::InvalidParameter {
80            parameter: "n",
81            message: "need at least 2 observations".into(),
82        });
83    }
84    let actual_ncomp = ncomp.min(result.coefficients.len());
85
86    match result.pca_method {
87        PcaMethod::Joint => attribution_joint(result, y, actual_ncomp, n_perm, seed),
88        PcaMethod::Vertical => {
89            // All contribution is amplitude, phase is zero
90            let amp: Vec<f64> = result
91                .fitted_values
92                .iter()
93                .map(|&f| f - result.alpha)
94                .collect();
95            let phase = vec![0.0; n];
96            let amp_imp = permutation_importance_single(
97                y,
98                &result.fitted_values,
99                result.alpha,
100                &result.coefficients,
101                actual_ncomp,
102                n_perm,
103                seed,
104            );
105            Ok(ElasticAttributionResult {
106                amplitude_contribution: amp,
107                phase_contribution: phase,
108                amplitude_importance: amp_imp,
109                phase_importance: 0.0,
110            })
111        }
112        PcaMethod::Horizontal => {
113            // All contribution is phase, amplitude is zero
114            let phase: Vec<f64> = result
115                .fitted_values
116                .iter()
117                .map(|&f| f - result.alpha)
118                .collect();
119            let amp = vec![0.0; n];
120            let phase_imp = permutation_importance_single(
121                y,
122                &result.fitted_values,
123                result.alpha,
124                &result.coefficients,
125                actual_ncomp,
126                n_perm,
127                seed,
128            );
129            Ok(ElasticAttributionResult {
130                amplitude_contribution: amp,
131                phase_contribution: phase,
132                amplitude_importance: 0.0,
133                phase_importance: phase_imp,
134            })
135        }
136    }
137}
138
139/// Joint FPCA attribution: decompose scores into amp and phase parts.
140fn attribution_joint(
141    result: &ElasticPcrResult,
142    y: &[f64],
143    ncomp: usize,
144    n_perm: usize,
145    seed: u64,
146) -> Result<ElasticAttributionResult, FdarError> {
147    let joint = result
148        .joint_fpca
149        .as_ref()
150        .ok_or_else(|| FdarError::ComputationFailed {
151            operation: "elastic_pcr_attribution",
152            detail: "joint_fpca result missing from ElasticPcrResult; ensure elastic_pcr was called with PcaMethod::Combined".into(),
153        })?;
154    let km = &result.karcher;
155    let (n, m) = km.aligned_data.shape();
156    let m_aug = m + 1;
157
158    let qn = match &km.aligned_srsfs {
159        Some(srsfs) => srsfs.clone(),
160        None => {
161            let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
162            srsf_transform(&km.aligned_data, &argvals)
163        }
164    };
165
166    let q_aug = build_augmented_srsfs(&qn, &km.aligned_data, n, m);
167    let (_, mean_q) = center_matrix(&q_aug, n, m_aug);
168
169    // Compute shooting vectors using shared helpers
170    let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
171    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
172    let psis = warps_to_normalized_psi(&km.gammas, &argvals);
173    let mu_psi = sphere_karcher_mean(&psis, &time, 50);
174    let shooting = shooting_vectors_from_psis(&psis, &mu_psi, &time);
175
176    let c = joint.balance_c;
177    let (amp_scores, phase_scores) = decompose_joint_scores(
178        &q_aug,
179        &mean_q,
180        &shooting,
181        &joint.vert_component,
182        &joint.horiz_component,
183        c,
184        n,
185        m_aug,
186        m,
187        ncomp,
188    );
189
190    let (amplitude_contribution, phase_contribution) =
191        compute_contributions(&amp_scores, &phase_scores, &result.coefficients, n, ncomp);
192
193    // Permutation importance
194    let r2_orig = compute_r2(y, &result.fitted_values);
195    let amplitude_importance = permutation_importance(
196        y,
197        result.alpha,
198        &result.coefficients,
199        &amp_scores,
200        &phase_scores,
201        ncomp,
202        n_perm,
203        seed,
204        true,
205    );
206    let phase_importance = permutation_importance(
207        y,
208        result.alpha,
209        &result.coefficients,
210        &amp_scores,
211        &phase_scores,
212        ncomp,
213        n_perm,
214        seed + 1_000_000,
215        false,
216    );
217
218    Ok(ElasticAttributionResult {
219        amplitude_contribution,
220        phase_contribution,
221        amplitude_importance: (r2_orig - amplitude_importance).max(0.0),
222        phase_importance: (r2_orig - phase_importance).max(0.0),
223    })
224}
225
226/// Decompose joint scores into amplitude and phase sub-scores.
227fn decompose_joint_scores(
228    q_aug: &FdMatrix,
229    mean_q: &[f64],
230    shooting: &FdMatrix,
231    vert_component: &FdMatrix,
232    horiz_component: &FdMatrix,
233    c: f64,
234    n: usize,
235    m_aug: usize,
236    m: usize,
237    ncomp: usize,
238) -> (FdMatrix, FdMatrix) {
239    let mut amp_scores = FdMatrix::zeros(n, ncomp);
240    let mut phase_scores = FdMatrix::zeros(n, ncomp);
241    for k in 0..ncomp {
242        for i in 0..n {
243            let mut amp_s = 0.0;
244            for j in 0..m_aug {
245                amp_s += (q_aug[(i, j)] - mean_q[j]) * vert_component[(k, j)];
246            }
247            amp_scores[(i, k)] = amp_s;
248
249            let mut phase_s = 0.0;
250            for j in 0..m {
251                phase_s += c * shooting[(i, j)] * horiz_component[(k, j)];
252            }
253            phase_scores[(i, k)] = phase_s;
254        }
255    }
256    (amp_scores, phase_scores)
257}
258
259/// Compute amplitude and phase contributions from decomposed scores.
260fn compute_contributions(
261    amp_scores: &FdMatrix,
262    phase_scores: &FdMatrix,
263    coefficients: &[f64],
264    n: usize,
265    ncomp: usize,
266) -> (Vec<f64>, Vec<f64>) {
267    let mut amplitude_contribution = vec![0.0; n];
268    let mut phase_contribution = vec![0.0; n];
269    for i in 0..n {
270        for k in 0..ncomp {
271            amplitude_contribution[i] += coefficients[k] * amp_scores[(i, k)];
272            phase_contribution[i] += coefficients[k] * phase_scores[(i, k)];
273        }
274    }
275    (amplitude_contribution, phase_contribution)
276}
277
/// Coefficient of determination `1 - SS_res / SS_tot` of `fitted` against `y`.
///
/// Returns `0.0` when `SS_tot` is not strictly positive (constant or empty `y`).
/// Residual pairs are formed with `zip`, so only the overlapping prefix of the
/// two slices contributes to `SS_res`.
fn compute_r2(y: &[f64], fitted: &[f64]) -> f64 {
    let mean = y.iter().sum::<f64>() / y.len() as f64;
    let total_ss: f64 = y.iter().map(|&obs| (obs - mean).powi(2)).sum();
    if total_ss > 0.0 {
        let residual_ss: f64 = y
            .iter()
            .zip(fitted)
            .map(|(&obs, &fit)| (obs - fit).powi(2))
            .sum();
        1.0 - residual_ss / total_ss
    } else {
        0.0
    }
}

295/// Permutation importance: shuffle one component's scores, recompute fitted, return avg R².
296fn permutation_importance(
297    y: &[f64],
298    alpha: f64,
299    coefficients: &[f64],
300    amp_scores: &FdMatrix,
301    phase_scores: &FdMatrix,
302    ncomp: usize,
303    n_perm: usize,
304    seed: u64,
305    permute_amplitude: bool,
306) -> f64 {
307    let n = y.len();
308    if n_perm == 0 {
309        return compute_r2(y, &vec![alpha; n]);
310    }
311
312    let mut total_r2 = 0.0;
313    for p in 0..n_perm {
314        let mut rng = StdRng::seed_from_u64(seed.wrapping_add(p as u64));
315        let mut perm_idx: Vec<usize> = (0..n).collect();
316        perm_idx.shuffle(&mut rng);
317
318        let fitted = fitted_with_permuted_scores(
319            alpha,
320            coefficients,
321            amp_scores,
322            phase_scores,
323            &perm_idx,
324            n,
325            ncomp,
326            permute_amplitude,
327        );
328        total_r2 += compute_r2(y, &fitted);
329    }
330    total_r2 / n_perm as f64
331}
332
333/// Compute fitted values with one component's scores permuted.
334fn fitted_with_permuted_scores(
335    alpha: f64,
336    coefficients: &[f64],
337    amp_scores: &FdMatrix,
338    phase_scores: &FdMatrix,
339    perm_idx: &[usize],
340    n: usize,
341    ncomp: usize,
342    permute_amplitude: bool,
343) -> Vec<f64> {
344    let mut fitted = vec![0.0; n];
345    for i in 0..n {
346        fitted[i] = alpha;
347        for k in 0..ncomp {
348            let amp_i = if permute_amplitude {
349                amp_scores[(perm_idx[i], k)]
350            } else {
351                amp_scores[(i, k)]
352            };
353            let phase_i = if permute_amplitude {
354                phase_scores[(i, k)]
355            } else {
356                phase_scores[(perm_idx[i], k)]
357            };
358            fitted[i] += coefficients[k] * (amp_i + phase_i);
359        }
360    }
361    fitted
362}
363
364/// Permutation importance for single-component models (vert-only or horiz-only).
365fn permutation_importance_single(
366    y: &[f64],
367    fitted_values: &[f64],
368    alpha: f64,
369    _coefficients: &[f64],
370    _ncomp: usize,
371    n_perm: usize,
372    seed: u64,
373) -> f64 {
374    let n = y.len();
375    let r2_orig = compute_r2(y, fitted_values);
376    if n_perm == 0 {
377        return r2_orig;
378    }
379
380    // Extract per-obs contribution = fitted - alpha, then permute
381    let contribs: Vec<f64> = fitted_values.iter().map(|&f| f - alpha).collect();
382    let mut total_r2 = 0.0;
383    for p in 0..n_perm {
384        let mut rng = StdRng::seed_from_u64(seed.wrapping_add(p as u64));
385        let mut perm_idx: Vec<usize> = (0..n).collect();
386        perm_idx.shuffle(&mut rng);
387
388        let fitted_perm: Vec<f64> = (0..n).map(|i| alpha + contribs[perm_idx[i]]).collect();
389        total_r2 += compute_r2(y, &fitted_perm);
390    }
391    let avg_r2 = total_r2 / n_perm as f64;
392    (r2_orig - avg_r2).max(0.0)
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::elastic_regression::{elastic_pcr, PcaMethod};
    use std::f64::consts::PI;

    /// Sine curves whose amplitude grows with the index and whose phase shifts
    /// with the index; the response is the amplitude.
    fn generate_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>, Vec<f64>) {
        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
        let mut data = FdMatrix::zeros(n, m);
        let mut y = vec![0.0; n];
        for i in 0..n {
            let amp = 1.0 + 0.5 * (i as f64 / n as f64);
            let shift = 0.1 * (i as f64 - n as f64 / 2.0);
            for j in 0..m {
                data[(i, j)] = amp * (2.0 * PI * (t[j] + shift)).sin();
            }
            y[i] = amp;
        }
        (data, y, t)
    }

    #[test]
    fn test_elastic_attribution_joint_decomposition() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 10, 42).unwrap();

        assert_eq!(attr.amplitude_contribution.len(), 15);
        assert_eq!(attr.phase_contribution.len(), 15);

        // Verify: amp + phase ≈ fitted - alpha
        for i in 0..15 {
            let sum = attr.amplitude_contribution[i] + attr.phase_contribution[i];
            let expected = result.fitted_values[i] - result.alpha;
            assert!(
                (sum - expected).abs() < 1e-6,
                "amp + phase should ≈ fitted - alpha at i={}: {} vs {}",
                i,
                sum,
                expected
            );
        }
    }

    #[test]
    fn test_elastic_attribution_vertical_only() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Vertical, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 10, 42).unwrap();

        // Phase contribution should all be zero
        for i in 0..15 {
            assert!(
                attr.phase_contribution[i].abs() < 1e-12,
                "phase_contribution should be 0 for vertical-only at i={}",
                i
            );
        }
        assert!(
            attr.phase_importance.abs() < 1e-12,
            "phase_importance should be 0 for vertical-only"
        );
    }

    #[test]
    fn test_elastic_attribution_importance_nonnegative() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 20, 42).unwrap();

        assert!(
            attr.amplitude_importance >= 0.0,
            "amplitude_importance should be >= 0, got {}",
            attr.amplitude_importance
        );
        assert!(
            attr.phase_importance >= 0.0,
            "phase_importance should be >= 0, got {}",
            attr.phase_importance
        );
    }

    #[test]
    fn test_elastic_attribution_rejects_invalid_inputs() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();

        // Response length must match the number of fitted values.
        let err = elastic_pcr_attribution(&result, &y[..14], 3, 10, 42).unwrap_err();
        assert!(
            matches!(err, FdarError::InvalidDimension { .. }),
            "expected InvalidDimension, got {:?}",
            err
        );

        // At least one component must be requested.
        let err = elastic_pcr_attribution(&result, &y, 0, 10, 42).unwrap_err();
        assert!(
            matches!(err, FdarError::InvalidParameter { .. }),
            "expected InvalidParameter, got {:?}",
            err
        );
    }

    #[test]
    fn test_elastic_attribution_reproducible_with_same_seed() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();
        let first = elastic_pcr_attribution(&result, &y, 3, 10, 7).unwrap();
        let second = elastic_pcr_attribution(&result, &y, 3, 10, 7).unwrap();

        // A tolerance guards against any reduction-order noise in the helpers.
        assert!((first.amplitude_importance - second.amplitude_importance).abs() < 1e-10);
        assert!((first.phase_importance - second.phase_importance).abs() < 1e-10);
        for i in 0..15 {
            assert!(
                (first.amplitude_contribution[i] - second.amplitude_contribution[i]).abs() < 1e-10
            );
            assert!((first.phase_contribution[i] - second.phase_contribution[i]).abs() < 1e-10);
        }
    }
}