1use crate::alignment::srsf_transform;
9use crate::elastic_fpca::{
10 build_augmented_srsfs, center_matrix, shooting_vectors_from_psis, sphere_karcher_mean,
11 warps_to_normalized_psi,
12};
13use crate::elastic_regression::{ElasticPcrResult, PcaMethod};
14use crate::error::FdarError;
15use crate::matrix::FdMatrix;
16use rand::prelude::*;
17
/// Per-observation split of an elastic PCR fit into amplitude and phase parts,
/// plus a permutation-based importance score for each part.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ElasticAttributionResult {
    /// Amplitude share of each observation's fitted value, excluding the intercept.
    /// For a phase-only (horizontal) fit every entry is `0.0`.
    pub amplitude_contribution: Vec<f64>,
    /// Phase share of each observation's fitted value, excluding the intercept.
    /// For an amplitude-only (vertical) fit every entry is `0.0`.
    pub phase_contribution: Vec<f64>,
    /// With `n_perm > 0`: drop in R² when the amplitude part is permuted across
    /// observations, floored at `0.0`. Always `0.0` for a phase-only fit.
    pub amplitude_importance: f64,
    /// With `n_perm > 0`: drop in R² when the phase part is permuted across
    /// observations, floored at `0.0`. Always `0.0` for an amplitude-only fit.
    pub phase_importance: f64,
}
31
32#[must_use = "expensive computation whose result should not be discarded"]
57pub fn elastic_pcr_attribution(
58 result: &ElasticPcrResult,
59 y: &[f64],
60 ncomp: usize,
61 n_perm: usize,
62 seed: u64,
63) -> Result<ElasticAttributionResult, FdarError> {
64 let n = result.fitted_values.len();
65 if y.len() != n {
66 return Err(FdarError::InvalidDimension {
67 parameter: "y",
68 expected: n.to_string(),
69 actual: y.len().to_string(),
70 });
71 }
72 if ncomp == 0 {
73 return Err(FdarError::InvalidParameter {
74 parameter: "ncomp",
75 message: "ncomp must be >= 1".into(),
76 });
77 }
78 if n < 2 {
79 return Err(FdarError::InvalidParameter {
80 parameter: "n",
81 message: "need at least 2 observations".into(),
82 });
83 }
84 let actual_ncomp = ncomp.min(result.coefficients.len());
85
86 match result.pca_method {
87 PcaMethod::Joint => attribution_joint(result, y, actual_ncomp, n_perm, seed),
88 PcaMethod::Vertical => {
89 let amp: Vec<f64> = result
91 .fitted_values
92 .iter()
93 .map(|&f| f - result.alpha)
94 .collect();
95 let phase = vec![0.0; n];
96 let amp_imp = permutation_importance_single(
97 y,
98 &result.fitted_values,
99 result.alpha,
100 &result.coefficients,
101 actual_ncomp,
102 n_perm,
103 seed,
104 );
105 Ok(ElasticAttributionResult {
106 amplitude_contribution: amp,
107 phase_contribution: phase,
108 amplitude_importance: amp_imp,
109 phase_importance: 0.0,
110 })
111 }
112 PcaMethod::Horizontal => {
113 let phase: Vec<f64> = result
115 .fitted_values
116 .iter()
117 .map(|&f| f - result.alpha)
118 .collect();
119 let amp = vec![0.0; n];
120 let phase_imp = permutation_importance_single(
121 y,
122 &result.fitted_values,
123 result.alpha,
124 &result.coefficients,
125 actual_ncomp,
126 n_perm,
127 seed,
128 );
129 Ok(ElasticAttributionResult {
130 amplitude_contribution: amp,
131 phase_contribution: phase,
132 amplitude_importance: 0.0,
133 phase_importance: phase_imp,
134 })
135 }
136 }
137}
138
139fn attribution_joint(
141 result: &ElasticPcrResult,
142 y: &[f64],
143 ncomp: usize,
144 n_perm: usize,
145 seed: u64,
146) -> Result<ElasticAttributionResult, FdarError> {
147 let joint = result
148 .joint_fpca
149 .as_ref()
150 .ok_or_else(|| FdarError::ComputationFailed {
151 operation: "elastic_pcr_attribution",
152 detail: "joint_fpca result missing from ElasticPcrResult; ensure elastic_pcr was called with PcaMethod::Combined".into(),
153 })?;
154 let km = &result.karcher;
155 let (n, m) = km.aligned_data.shape();
156 let m_aug = m + 1;
157
158 let qn = match &km.aligned_srsfs {
159 Some(srsfs) => srsfs.clone(),
160 None => {
161 let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
162 srsf_transform(&km.aligned_data, &argvals)
163 }
164 };
165
166 let q_aug = build_augmented_srsfs(&qn, &km.aligned_data, n, m);
167 let (_, mean_q) = center_matrix(&q_aug, n, m_aug);
168
169 let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
171 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
172 let psis = warps_to_normalized_psi(&km.gammas, &argvals);
173 let mu_psi = sphere_karcher_mean(&psis, &time, 50);
174 let shooting = shooting_vectors_from_psis(&psis, &mu_psi, &time);
175
176 let c = joint.balance_c;
177 let (amp_scores, phase_scores) = decompose_joint_scores(
178 &q_aug,
179 &mean_q,
180 &shooting,
181 &joint.vert_component,
182 &joint.horiz_component,
183 c,
184 n,
185 m_aug,
186 m,
187 ncomp,
188 );
189
190 let (amplitude_contribution, phase_contribution) =
191 compute_contributions(&_scores, &phase_scores, &result.coefficients, n, ncomp);
192
193 let r2_orig = compute_r2(y, &result.fitted_values);
195 let amplitude_importance = permutation_importance(
196 y,
197 result.alpha,
198 &result.coefficients,
199 &_scores,
200 &phase_scores,
201 ncomp,
202 n_perm,
203 seed,
204 true,
205 );
206 let phase_importance = permutation_importance(
207 y,
208 result.alpha,
209 &result.coefficients,
210 &_scores,
211 &phase_scores,
212 ncomp,
213 n_perm,
214 seed + 1_000_000,
215 false,
216 );
217
218 Ok(ElasticAttributionResult {
219 amplitude_contribution,
220 phase_contribution,
221 amplitude_importance: (r2_orig - amplitude_importance).max(0.0),
222 phase_importance: (r2_orig - phase_importance).max(0.0),
223 })
224}
225
226fn decompose_joint_scores(
228 q_aug: &FdMatrix,
229 mean_q: &[f64],
230 shooting: &FdMatrix,
231 vert_component: &FdMatrix,
232 horiz_component: &FdMatrix,
233 c: f64,
234 n: usize,
235 m_aug: usize,
236 m: usize,
237 ncomp: usize,
238) -> (FdMatrix, FdMatrix) {
239 let mut amp_scores = FdMatrix::zeros(n, ncomp);
240 let mut phase_scores = FdMatrix::zeros(n, ncomp);
241 for k in 0..ncomp {
242 for i in 0..n {
243 let mut amp_s = 0.0;
244 for j in 0..m_aug {
245 amp_s += (q_aug[(i, j)] - mean_q[j]) * vert_component[(k, j)];
246 }
247 amp_scores[(i, k)] = amp_s;
248
249 let mut phase_s = 0.0;
250 for j in 0..m {
251 phase_s += c * shooting[(i, j)] * horiz_component[(k, j)];
252 }
253 phase_scores[(i, k)] = phase_s;
254 }
255 }
256 (amp_scores, phase_scores)
257}
258
259fn compute_contributions(
261 amp_scores: &FdMatrix,
262 phase_scores: &FdMatrix,
263 coefficients: &[f64],
264 n: usize,
265 ncomp: usize,
266) -> (Vec<f64>, Vec<f64>) {
267 let mut amplitude_contribution = vec![0.0; n];
268 let mut phase_contribution = vec![0.0; n];
269 for i in 0..n {
270 for k in 0..ncomp {
271 amplitude_contribution[i] += coefficients[k] * amp_scores[(i, k)];
272 phase_contribution[i] += coefficients[k] * phase_scores[(i, k)];
273 }
274 }
275 (amplitude_contribution, phase_contribution)
276}
277
/// Coefficient of determination `1 - SS_res / SS_tot` of `fitted` against `y`.
///
/// Returns `0.0` whenever the total sum of squares is not strictly positive
/// (constant `y`, empty input, or NaN). Pairs are taken with `zip`, so only the
/// overlapping prefix of `y` and `fitted` enters the residual sum.
fn compute_r2(y: &[f64], fitted: &[f64]) -> f64 {
    let mean = y.iter().sum::<f64>() / y.len() as f64;
    let total_ss: f64 = y.iter().map(|&obs| (obs - mean).powi(2)).sum();
    if total_ss > 0.0 {
        let residual_ss: f64 = y
            .iter()
            .zip(fitted.iter())
            .map(|(&obs, &pred)| (obs - pred).powi(2))
            .sum();
        1.0 - residual_ss / total_ss
    } else {
        0.0
    }
}
294
295fn permutation_importance(
297 y: &[f64],
298 alpha: f64,
299 coefficients: &[f64],
300 amp_scores: &FdMatrix,
301 phase_scores: &FdMatrix,
302 ncomp: usize,
303 n_perm: usize,
304 seed: u64,
305 permute_amplitude: bool,
306) -> f64 {
307 let n = y.len();
308 if n_perm == 0 {
309 return compute_r2(y, &vec![alpha; n]);
310 }
311
312 let mut total_r2 = 0.0;
313 for p in 0..n_perm {
314 let mut rng = StdRng::seed_from_u64(seed.wrapping_add(p as u64));
315 let mut perm_idx: Vec<usize> = (0..n).collect();
316 perm_idx.shuffle(&mut rng);
317
318 let fitted = fitted_with_permuted_scores(
319 alpha,
320 coefficients,
321 amp_scores,
322 phase_scores,
323 &perm_idx,
324 n,
325 ncomp,
326 permute_amplitude,
327 );
328 total_r2 += compute_r2(y, &fitted);
329 }
330 total_r2 / n_perm as f64
331}
332
333fn fitted_with_permuted_scores(
335 alpha: f64,
336 coefficients: &[f64],
337 amp_scores: &FdMatrix,
338 phase_scores: &FdMatrix,
339 perm_idx: &[usize],
340 n: usize,
341 ncomp: usize,
342 permute_amplitude: bool,
343) -> Vec<f64> {
344 let mut fitted = vec![0.0; n];
345 for i in 0..n {
346 fitted[i] = alpha;
347 for k in 0..ncomp {
348 let amp_i = if permute_amplitude {
349 amp_scores[(perm_idx[i], k)]
350 } else {
351 amp_scores[(i, k)]
352 };
353 let phase_i = if permute_amplitude {
354 phase_scores[(i, k)]
355 } else {
356 phase_scores[(perm_idx[i], k)]
357 };
358 fitted[i] += coefficients[k] * (amp_i + phase_i);
359 }
360 }
361 fitted
362}
363
364fn permutation_importance_single(
366 y: &[f64],
367 fitted_values: &[f64],
368 alpha: f64,
369 _coefficients: &[f64],
370 _ncomp: usize,
371 n_perm: usize,
372 seed: u64,
373) -> f64 {
374 let n = y.len();
375 let r2_orig = compute_r2(y, fitted_values);
376 if n_perm == 0 {
377 return r2_orig;
378 }
379
380 let contribs: Vec<f64> = fitted_values.iter().map(|&f| f - alpha).collect();
382 let mut total_r2 = 0.0;
383 for p in 0..n_perm {
384 let mut rng = StdRng::seed_from_u64(seed.wrapping_add(p as u64));
385 let mut perm_idx: Vec<usize> = (0..n).collect();
386 perm_idx.shuffle(&mut rng);
387
388 let fitted_perm: Vec<f64> = (0..n).map(|i| alpha + contribs[perm_idx[i]]).collect();
389 total_r2 += compute_r2(y, &fitted_perm);
390 }
391 let avg_r2 = total_r2 / n_perm as f64;
392 (r2_orig - avg_r2).max(0.0)
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::elastic_regression::{elastic_pcr, PcaMethod};
    use std::f64::consts::PI;

    /// Builds `n` sine curves on `m` equispaced points in [0, 1].
    ///
    /// Curve `i` has amplitude `1 + 0.5 * i / n` and a per-observation phase shift.
    /// The response is the amplitude itself.
    fn generate_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>, Vec<f64>) {
        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
        let mut data = FdMatrix::zeros(n, m);
        let mut y = vec![0.0; n];
        for i in 0..n {
            let amp = 1.0 + 0.5 * (i as f64 / n as f64);
            let shift = 0.1 * (i as f64 - n as f64 / 2.0);
            for j in 0..m {
                data[(i, j)] = amp * (2.0 * PI * (t[j] + shift)).sin();
            }
            y[i] = amp;
        }
        (data, y, t)
    }

    #[test]
    fn test_elastic_attribution_joint_decomposition() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 10, 42).unwrap();

        assert_eq!(attr.amplitude_contribution.len(), 15);
        assert_eq!(attr.phase_contribution.len(), 15);

        // The two parts must add back up to the centered fit.
        for i in 0..15 {
            let sum = attr.amplitude_contribution[i] + attr.phase_contribution[i];
            let expected = result.fitted_values[i] - result.alpha;
            assert!(
                (sum - expected).abs() < 1e-6,
                "amp + phase should ≈ fitted - alpha at i={}: {} vs {}",
                i,
                sum,
                expected
            );
        }
    }

    #[test]
    fn test_elastic_attribution_vertical_only() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Vertical, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 10, 42).unwrap();

        for i in 0..15 {
            assert!(
                attr.phase_contribution[i].abs() < 1e-12,
                "phase_contribution should be 0 for vertical-only at i={}",
                i
            );
        }
        assert!(
            attr.phase_importance.abs() < 1e-12,
            "phase_importance should be 0 for vertical-only"
        );
    }

    #[test]
    fn test_elastic_attribution_vertical_amplitude_is_centered_fit() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Vertical, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 10, 42).unwrap();

        for i in 0..15 {
            let expected = result.fitted_values[i] - result.alpha;
            assert!(
                (attr.amplitude_contribution[i] - expected).abs() < 1e-12,
                "amplitude_contribution should equal fitted - alpha at i={}",
                i
            );
        }
        assert!(attr.amplitude_importance >= 0.0);
    }

    #[test]
    fn test_elastic_attribution_importance_nonnegative() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 20, 42).unwrap();

        assert!(
            attr.amplitude_importance >= 0.0,
            "amplitude_importance should be >= 0, got {}",
            attr.amplitude_importance
        );
        assert!(
            attr.phase_importance >= 0.0,
            "phase_importance should be >= 0, got {}",
            attr.phase_importance
        );
    }

    #[test]
    fn test_elastic_attribution_joint_without_permutations() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 0, 42).unwrap();

        assert!(attr.amplitude_importance.is_finite() && attr.amplitude_importance >= 0.0);
        assert!(attr.phase_importance.is_finite() && attr.phase_importance >= 0.0);
    }

    #[test]
    fn test_elastic_attribution_rejects_mismatched_y() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Vertical, 0.0, 5, 1e-3).unwrap();
        let err = elastic_pcr_attribution(&result, &y[..10], 3, 10, 42).unwrap_err();
        assert!(matches!(err, FdarError::InvalidDimension { .. }));
    }

    #[test]
    fn test_elastic_attribution_rejects_zero_ncomp() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Vertical, 0.0, 5, 1e-3).unwrap();
        let err = elastic_pcr_attribution(&result, &y, 0, 10, 42).unwrap_err();
        assert!(matches!(err, FdarError::InvalidParameter { .. }));
    }
}