Skip to main content

fdars_core/
elastic_explain.rs

1//! Elastic shape explainability: amplitude vs phase attribution.
2//!
3//! Decomposes elastic PCR predictions into "shape" (amplitude) and "timing" (phase)
4//! contributions, with permutation-based variable importance.
5//!
6//! - [`elastic_pcr_attribution`] — Decompose predictions and compute importance scores
7
8use crate::alignment::srsf_transform;
9use crate::elastic_fpca::{
10    build_augmented_srsfs, center_matrix, shooting_vectors_from_psis, sphere_karcher_mean,
11    warps_to_normalized_psi,
12};
13use crate::elastic_regression::{ElasticPcrResult, PcaMethod};
14use crate::error::FdarError;
15use crate::matrix::FdMatrix;
16use rand::prelude::*;
17
/// Result of elastic amplitude/phase attribution.
///
/// Returned by [`elastic_pcr_attribution`]. Contributions are expressed on the
/// response scale and exclude the model intercept `alpha`. For joint models,
/// `amplitude_contribution[i] + phase_contribution[i]` approximately equals
/// `fitted_values[i] - alpha` when all fitted components are attributed.
#[derive(Debug, Clone, PartialEq)]
pub struct ElasticAttributionResult {
    /// Per-observation amplitude (shape) contribution to the fitted value,
    /// excluding the intercept (length n). All zeros for horizontal-only models.
    pub amplitude_contribution: Vec<f64>,
    /// Per-observation phase (timing) contribution to the fitted value,
    /// excluding the intercept (length n). All zeros for vertical-only models.
    pub phase_contribution: Vec<f64>,
    /// R² lost when the amplitude part is shuffled across observations.
    ///
    /// With `n_perm > 0` this is the model's R² minus the mean R² over the
    /// permutation replicates, clamped at zero. With `n_perm == 0` no
    /// permutations run and a fallback value is reported instead.
    /// Always 0 for horizontal-only models.
    pub amplitude_importance: f64,
    /// R² lost when the phase part is shuffled across observations.
    ///
    /// Computed like `amplitude_importance`, but permuting the phase part.
    /// Always 0 for vertical-only models.
    pub phase_importance: f64,
}
30
31/// Decompose elastic PCR predictions into amplitude and phase contributions.
32///
33/// For joint FPCA, the joint eigenvectors split into vertical (amplitude) and
34/// horizontal (phase) parts. Each observation's score can be decomposed into
35/// amplitude and phase sub-scores based on these parts.
36///
37/// For vertical-only or horizontal-only models, the missing component
38/// contributes zero.
39///
40/// # Arguments
41/// * `result` — A fitted [`ElasticPcrResult`] (must have stored FPCA results)
42/// * `y` — Original scalar responses (length n)
43/// * `ncomp` — Number of components to use for attribution
44/// * `n_perm` — Number of permutation replicates for importance
45/// * `seed` — RNG seed for permutation reproducibility
46///
47/// # Errors
48///
49/// Returns [`FdarError::InvalidDimension`] if `y.len()` does not match the
50/// number of fitted values in `result`.
51/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero or there are
52/// fewer than 2 observations.
53/// Returns [`FdarError::ComputationFailed`] if the joint FPCA result is
54/// missing from a `PcaMethod::Joint` model.
55#[must_use = "expensive computation whose result should not be discarded"]
56pub fn elastic_pcr_attribution(
57    result: &ElasticPcrResult,
58    y: &[f64],
59    ncomp: usize,
60    n_perm: usize,
61    seed: u64,
62) -> Result<ElasticAttributionResult, FdarError> {
63    let n = result.fitted_values.len();
64    if y.len() != n {
65        return Err(FdarError::InvalidDimension {
66            parameter: "y",
67            expected: n.to_string(),
68            actual: y.len().to_string(),
69        });
70    }
71    if ncomp == 0 {
72        return Err(FdarError::InvalidParameter {
73            parameter: "ncomp",
74            message: "ncomp must be >= 1".into(),
75        });
76    }
77    if n < 2 {
78        return Err(FdarError::InvalidParameter {
79            parameter: "n",
80            message: "need at least 2 observations".into(),
81        });
82    }
83    let actual_ncomp = ncomp.min(result.coefficients.len());
84
85    match result.pca_method {
86        PcaMethod::Joint => attribution_joint(result, y, actual_ncomp, n_perm, seed),
87        PcaMethod::Vertical => {
88            // All contribution is amplitude, phase is zero
89            let amp: Vec<f64> = result
90                .fitted_values
91                .iter()
92                .map(|&f| f - result.alpha)
93                .collect();
94            let phase = vec![0.0; n];
95            let amp_imp = permutation_importance_single(
96                y,
97                &result.fitted_values,
98                result.alpha,
99                &result.coefficients,
100                actual_ncomp,
101                n_perm,
102                seed,
103            );
104            Ok(ElasticAttributionResult {
105                amplitude_contribution: amp,
106                phase_contribution: phase,
107                amplitude_importance: amp_imp,
108                phase_importance: 0.0,
109            })
110        }
111        PcaMethod::Horizontal => {
112            // All contribution is phase, amplitude is zero
113            let phase: Vec<f64> = result
114                .fitted_values
115                .iter()
116                .map(|&f| f - result.alpha)
117                .collect();
118            let amp = vec![0.0; n];
119            let phase_imp = permutation_importance_single(
120                y,
121                &result.fitted_values,
122                result.alpha,
123                &result.coefficients,
124                actual_ncomp,
125                n_perm,
126                seed,
127            );
128            Ok(ElasticAttributionResult {
129                amplitude_contribution: amp,
130                phase_contribution: phase,
131                amplitude_importance: 0.0,
132                phase_importance: phase_imp,
133            })
134        }
135    }
136}
137
138/// Joint FPCA attribution: decompose scores into amp and phase parts.
139fn attribution_joint(
140    result: &ElasticPcrResult,
141    y: &[f64],
142    ncomp: usize,
143    n_perm: usize,
144    seed: u64,
145) -> Result<ElasticAttributionResult, FdarError> {
146    let joint = result
147        .joint_fpca
148        .as_ref()
149        .ok_or_else(|| FdarError::ComputationFailed {
150            operation: "elastic_pcr_attribution",
151            detail: "joint_fpca result missing from ElasticPcrResult".into(),
152        })?;
153    let km = &result.karcher;
154    let (n, m) = km.aligned_data.shape();
155    let m_aug = m + 1;
156
157    let qn = match &km.aligned_srsfs {
158        Some(srsfs) => srsfs.clone(),
159        None => {
160            let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
161            srsf_transform(&km.aligned_data, &argvals)
162        }
163    };
164
165    let q_aug = build_augmented_srsfs(&qn, &km.aligned_data, n, m);
166    let (_, mean_q) = center_matrix(&q_aug, n, m_aug);
167
168    // Compute shooting vectors using shared helpers
169    let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
170    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
171    let psis = warps_to_normalized_psi(&km.gammas, &argvals);
172    let mu_psi = sphere_karcher_mean(&psis, &time, 50);
173    let shooting = shooting_vectors_from_psis(&psis, &mu_psi, &time);
174
175    let c = joint.balance_c;
176    let (amp_scores, phase_scores) = decompose_joint_scores(
177        &q_aug,
178        &mean_q,
179        &shooting,
180        &joint.vert_component,
181        &joint.horiz_component,
182        c,
183        n,
184        m_aug,
185        m,
186        ncomp,
187    );
188
189    let (amplitude_contribution, phase_contribution) =
190        compute_contributions(&amp_scores, &phase_scores, &result.coefficients, n, ncomp);
191
192    // Permutation importance
193    let r2_orig = compute_r2(y, &result.fitted_values);
194    let amplitude_importance = permutation_importance(
195        y,
196        result.alpha,
197        &result.coefficients,
198        &amp_scores,
199        &phase_scores,
200        ncomp,
201        n_perm,
202        seed,
203        true,
204    );
205    let phase_importance = permutation_importance(
206        y,
207        result.alpha,
208        &result.coefficients,
209        &amp_scores,
210        &phase_scores,
211        ncomp,
212        n_perm,
213        seed + 1_000_000,
214        false,
215    );
216
217    Ok(ElasticAttributionResult {
218        amplitude_contribution,
219        phase_contribution,
220        amplitude_importance: (r2_orig - amplitude_importance).max(0.0),
221        phase_importance: (r2_orig - phase_importance).max(0.0),
222    })
223}
224
225/// Decompose joint scores into amplitude and phase sub-scores.
226fn decompose_joint_scores(
227    q_aug: &FdMatrix,
228    mean_q: &[f64],
229    shooting: &FdMatrix,
230    vert_component: &FdMatrix,
231    horiz_component: &FdMatrix,
232    c: f64,
233    n: usize,
234    m_aug: usize,
235    m: usize,
236    ncomp: usize,
237) -> (FdMatrix, FdMatrix) {
238    let mut amp_scores = FdMatrix::zeros(n, ncomp);
239    let mut phase_scores = FdMatrix::zeros(n, ncomp);
240    for k in 0..ncomp {
241        for i in 0..n {
242            let mut amp_s = 0.0;
243            for j in 0..m_aug {
244                amp_s += (q_aug[(i, j)] - mean_q[j]) * vert_component[(k, j)];
245            }
246            amp_scores[(i, k)] = amp_s;
247
248            let mut phase_s = 0.0;
249            for j in 0..m {
250                phase_s += c * shooting[(i, j)] * horiz_component[(k, j)];
251            }
252            phase_scores[(i, k)] = phase_s;
253        }
254    }
255    (amp_scores, phase_scores)
256}
257
258/// Compute amplitude and phase contributions from decomposed scores.
259fn compute_contributions(
260    amp_scores: &FdMatrix,
261    phase_scores: &FdMatrix,
262    coefficients: &[f64],
263    n: usize,
264    ncomp: usize,
265) -> (Vec<f64>, Vec<f64>) {
266    let mut amplitude_contribution = vec![0.0; n];
267    let mut phase_contribution = vec![0.0; n];
268    for i in 0..n {
269        for k in 0..ncomp {
270            amplitude_contribution[i] += coefficients[k] * amp_scores[(i, k)];
271            phase_contribution[i] += coefficients[k] * phase_scores[(i, k)];
272        }
273    }
274    (amplitude_contribution, phase_contribution)
275}
276
/// Coefficient of determination `1 - SS_res / SS_tot` of `fitted` against `y`.
///
/// Returns 0.0 when the total sum of squares is not positive: constant or
/// empty `y`, or when it is NaN. The result can be negative when `fitted` is
/// worse than predicting the mean. Pairs are taken up to the shorter of the
/// two slices for the residual sum.
fn compute_r2(y: &[f64], fitted: &[f64]) -> f64 {
    let count = y.len() as f64;
    let mean = y.iter().sum::<f64>() / count;
    let total: f64 = y.iter().map(|&obs| (obs - mean).powi(2)).sum();
    if total > 0.0 {
        let residual: f64 = y
            .iter()
            .zip(fitted.iter())
            .map(|(&obs, &pred)| (obs - pred).powi(2))
            .sum();
        1.0 - residual / total
    } else {
        0.0
    }
}
293
294/// Permutation importance: shuffle one component's scores, recompute fitted, return avg R².
295fn permutation_importance(
296    y: &[f64],
297    alpha: f64,
298    coefficients: &[f64],
299    amp_scores: &FdMatrix,
300    phase_scores: &FdMatrix,
301    ncomp: usize,
302    n_perm: usize,
303    seed: u64,
304    permute_amplitude: bool,
305) -> f64 {
306    let n = y.len();
307    if n_perm == 0 {
308        return compute_r2(y, &vec![alpha; n]);
309    }
310
311    let mut total_r2 = 0.0;
312    for p in 0..n_perm {
313        let mut rng = StdRng::seed_from_u64(seed.wrapping_add(p as u64));
314        let mut perm_idx: Vec<usize> = (0..n).collect();
315        perm_idx.shuffle(&mut rng);
316
317        let fitted = fitted_with_permuted_scores(
318            alpha,
319            coefficients,
320            amp_scores,
321            phase_scores,
322            &perm_idx,
323            n,
324            ncomp,
325            permute_amplitude,
326        );
327        total_r2 += compute_r2(y, &fitted);
328    }
329    total_r2 / n_perm as f64
330}
331
332/// Compute fitted values with one component's scores permuted.
333fn fitted_with_permuted_scores(
334    alpha: f64,
335    coefficients: &[f64],
336    amp_scores: &FdMatrix,
337    phase_scores: &FdMatrix,
338    perm_idx: &[usize],
339    n: usize,
340    ncomp: usize,
341    permute_amplitude: bool,
342) -> Vec<f64> {
343    let mut fitted = vec![0.0; n];
344    for i in 0..n {
345        fitted[i] = alpha;
346        for k in 0..ncomp {
347            let amp_i = if permute_amplitude {
348                amp_scores[(perm_idx[i], k)]
349            } else {
350                amp_scores[(i, k)]
351            };
352            let phase_i = if !permute_amplitude {
353                phase_scores[(perm_idx[i], k)]
354            } else {
355                phase_scores[(i, k)]
356            };
357            fitted[i] += coefficients[k] * (amp_i + phase_i);
358        }
359    }
360    fitted
361}
362
363/// Permutation importance for single-component models (vert-only or horiz-only).
364fn permutation_importance_single(
365    y: &[f64],
366    fitted_values: &[f64],
367    alpha: f64,
368    _coefficients: &[f64],
369    _ncomp: usize,
370    n_perm: usize,
371    seed: u64,
372) -> f64 {
373    let n = y.len();
374    let r2_orig = compute_r2(y, fitted_values);
375    if n_perm == 0 {
376        return r2_orig;
377    }
378
379    // Extract per-obs contribution = fitted - alpha, then permute
380    let contribs: Vec<f64> = fitted_values.iter().map(|&f| f - alpha).collect();
381    let mut total_r2 = 0.0;
382    for p in 0..n_perm {
383        let mut rng = StdRng::seed_from_u64(seed.wrapping_add(p as u64));
384        let mut perm_idx: Vec<usize> = (0..n).collect();
385        perm_idx.shuffle(&mut rng);
386
387        let fitted_perm: Vec<f64> = (0..n).map(|i| alpha + contribs[perm_idx[i]]).collect();
388        total_r2 += compute_r2(y, &fitted_perm);
389    }
390    let avg_r2 = total_r2 / n_perm as f64;
391    (r2_orig - avg_r2).max(0.0)
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::elastic_regression::{elastic_pcr, PcaMethod};
    use std::f64::consts::PI;

    /// Build `n` phase-shifted sine curves on `m` grid points; `y[i]` is curve i's amplitude.
    fn generate_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>, Vec<f64>) {
        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
        let mut data = FdMatrix::zeros(n, m);
        let mut y = vec![0.0; n];
        for i in 0..n {
            let amp = 1.0 + 0.5 * (i as f64 / n as f64);
            let shift = 0.1 * (i as f64 - n as f64 / 2.0);
            for j in 0..m {
                data[(i, j)] = amp * (2.0 * PI * (t[j] + shift)).sin();
            }
            y[i] = amp;
        }
        (data, y, t)
    }

    #[test]
    fn test_elastic_attribution_joint_decomposition() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 10, 42).unwrap();

        assert_eq!(attr.amplitude_contribution.len(), 15);
        assert_eq!(attr.phase_contribution.len(), 15);

        // Verify: amp + phase ≈ fitted - alpha
        for i in 0..15 {
            let sum = attr.amplitude_contribution[i] + attr.phase_contribution[i];
            let expected = result.fitted_values[i] - result.alpha;
            assert!(
                (sum - expected).abs() < 1e-6,
                "amp + phase should ≈ fitted - alpha at i={}: {} vs {}",
                i,
                sum,
                expected
            );
        }
    }

    #[test]
    fn test_elastic_attribution_vertical_only() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Vertical, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 10, 42).unwrap();

        // Phase contribution should all be zero
        for i in 0..15 {
            assert!(
                attr.phase_contribution[i].abs() < 1e-12,
                "phase_contribution should be 0 for vertical-only at i={}",
                i
            );
        }
        assert!(
            attr.phase_importance.abs() < 1e-12,
            "phase_importance should be 0 for vertical-only"
        );
    }

    #[test]
    fn test_elastic_attribution_horizontal_only() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Horizontal, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 10, 42).unwrap();

        // Amplitude is identically zero; phase carries the whole centred prediction.
        for i in 0..15 {
            assert!(
                attr.amplitude_contribution[i].abs() < 1e-12,
                "amplitude_contribution should be 0 for horizontal-only at i={}",
                i
            );
            let expected = result.fitted_values[i] - result.alpha;
            assert!(
                (attr.phase_contribution[i] - expected).abs() < 1e-12,
                "phase_contribution should equal fitted - alpha at i={}",
                i
            );
        }
        assert!(
            attr.amplitude_importance.abs() < 1e-12,
            "amplitude_importance should be 0 for horizontal-only"
        );
        assert!(
            attr.phase_importance >= 0.0,
            "phase_importance should be >= 0, got {}",
            attr.phase_importance
        );
    }

    #[test]
    fn test_elastic_attribution_rejects_bad_input() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Vertical, 0.0, 5, 1e-3).unwrap();

        assert!(matches!(
            elastic_pcr_attribution(&result, &y[..14], 3, 10, 42),
            Err(FdarError::InvalidDimension { .. })
        ));
        assert!(matches!(
            elastic_pcr_attribution(&result, &y, 0, 10, 42),
            Err(FdarError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_elastic_attribution_importance_nonnegative() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 20, 42).unwrap();

        assert!(
            attr.amplitude_importance >= 0.0,
            "amplitude_importance should be >= 0, got {}",
            attr.amplitude_importance
        );
        assert!(
            attr.phase_importance >= 0.0,
            "phase_importance should be >= 0, got {}",
            attr.phase_importance
        );
    }
}