1use crate::alignment::srsf_transform;
9use crate::elastic_fpca::{
10 build_augmented_srsfs, center_matrix, shooting_vectors_from_psis, sphere_karcher_mean,
11 warps_to_normalized_psi,
12};
13use crate::elastic_regression::{ElasticPcrResult, PcaMethod};
14use crate::error::FdarError;
15use crate::matrix::FdMatrix;
16use rand::prelude::*;
17
/// Amplitude/phase breakdown of an elastic PCR fit, as returned by
/// [`elastic_pcr_attribution`].
///
/// The two contribution vectors hold one entry per observation; the two
/// importance values are global, permutation-based summaries.
#[derive(Debug, Clone, PartialEq)]
pub struct ElasticAttributionResult {
    /// Amplitude share of each observation's `fitted - alpha`.
    ///
    /// For vertical-only models this is the whole `fitted - alpha`; for
    /// horizontal-only models every entry is `0.0`.
    pub amplitude_contribution: Vec<f64>,
    /// Phase share of each observation's `fitted - alpha`.
    ///
    /// For horizontal-only models this is the whole `fitted - alpha`; for
    /// vertical-only models every entry is `0.0`.
    pub phase_contribution: Vec<f64>,
    /// Loss in R² observed when the amplitude part is shuffled across
    /// observations (averaged over the permutations); `0.0` for
    /// horizontal-only models.
    pub amplitude_importance: f64,
    /// Loss in R² observed when the phase part is shuffled across
    /// observations (averaged over the permutations); `0.0` for
    /// vertical-only models.
    pub phase_importance: f64,
}
30
31#[must_use = "expensive computation whose result should not be discarded"]
56pub fn elastic_pcr_attribution(
57 result: &ElasticPcrResult,
58 y: &[f64],
59 ncomp: usize,
60 n_perm: usize,
61 seed: u64,
62) -> Result<ElasticAttributionResult, FdarError> {
63 let n = result.fitted_values.len();
64 if y.len() != n {
65 return Err(FdarError::InvalidDimension {
66 parameter: "y",
67 expected: n.to_string(),
68 actual: y.len().to_string(),
69 });
70 }
71 if ncomp == 0 {
72 return Err(FdarError::InvalidParameter {
73 parameter: "ncomp",
74 message: "ncomp must be >= 1".into(),
75 });
76 }
77 if n < 2 {
78 return Err(FdarError::InvalidParameter {
79 parameter: "n",
80 message: "need at least 2 observations".into(),
81 });
82 }
83 let actual_ncomp = ncomp.min(result.coefficients.len());
84
85 match result.pca_method {
86 PcaMethod::Joint => attribution_joint(result, y, actual_ncomp, n_perm, seed),
87 PcaMethod::Vertical => {
88 let amp: Vec<f64> = result
90 .fitted_values
91 .iter()
92 .map(|&f| f - result.alpha)
93 .collect();
94 let phase = vec![0.0; n];
95 let amp_imp = permutation_importance_single(
96 y,
97 &result.fitted_values,
98 result.alpha,
99 &result.coefficients,
100 actual_ncomp,
101 n_perm,
102 seed,
103 );
104 Ok(ElasticAttributionResult {
105 amplitude_contribution: amp,
106 phase_contribution: phase,
107 amplitude_importance: amp_imp,
108 phase_importance: 0.0,
109 })
110 }
111 PcaMethod::Horizontal => {
112 let phase: Vec<f64> = result
114 .fitted_values
115 .iter()
116 .map(|&f| f - result.alpha)
117 .collect();
118 let amp = vec![0.0; n];
119 let phase_imp = permutation_importance_single(
120 y,
121 &result.fitted_values,
122 result.alpha,
123 &result.coefficients,
124 actual_ncomp,
125 n_perm,
126 seed,
127 );
128 Ok(ElasticAttributionResult {
129 amplitude_contribution: amp,
130 phase_contribution: phase,
131 amplitude_importance: 0.0,
132 phase_importance: phase_imp,
133 })
134 }
135 }
136}
137
138fn attribution_joint(
140 result: &ElasticPcrResult,
141 y: &[f64],
142 ncomp: usize,
143 n_perm: usize,
144 seed: u64,
145) -> Result<ElasticAttributionResult, FdarError> {
146 let joint = result
147 .joint_fpca
148 .as_ref()
149 .ok_or_else(|| FdarError::ComputationFailed {
150 operation: "elastic_pcr_attribution",
151 detail: "joint_fpca result missing from ElasticPcrResult".into(),
152 })?;
153 let km = &result.karcher;
154 let (n, m) = km.aligned_data.shape();
155 let m_aug = m + 1;
156
157 let qn = match &km.aligned_srsfs {
158 Some(srsfs) => srsfs.clone(),
159 None => {
160 let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
161 srsf_transform(&km.aligned_data, &argvals)
162 }
163 };
164
165 let q_aug = build_augmented_srsfs(&qn, &km.aligned_data, n, m);
166 let (_, mean_q) = center_matrix(&q_aug, n, m_aug);
167
168 let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
170 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
171 let psis = warps_to_normalized_psi(&km.gammas, &argvals);
172 let mu_psi = sphere_karcher_mean(&psis, &time, 50);
173 let shooting = shooting_vectors_from_psis(&psis, &mu_psi, &time);
174
175 let c = joint.balance_c;
176 let (amp_scores, phase_scores) = decompose_joint_scores(
177 &q_aug,
178 &mean_q,
179 &shooting,
180 &joint.vert_component,
181 &joint.horiz_component,
182 c,
183 n,
184 m_aug,
185 m,
186 ncomp,
187 );
188
189 let (amplitude_contribution, phase_contribution) =
190 compute_contributions(&_scores, &phase_scores, &result.coefficients, n, ncomp);
191
192 let r2_orig = compute_r2(y, &result.fitted_values);
194 let amplitude_importance = permutation_importance(
195 y,
196 result.alpha,
197 &result.coefficients,
198 &_scores,
199 &phase_scores,
200 ncomp,
201 n_perm,
202 seed,
203 true,
204 );
205 let phase_importance = permutation_importance(
206 y,
207 result.alpha,
208 &result.coefficients,
209 &_scores,
210 &phase_scores,
211 ncomp,
212 n_perm,
213 seed + 1_000_000,
214 false,
215 );
216
217 Ok(ElasticAttributionResult {
218 amplitude_contribution,
219 phase_contribution,
220 amplitude_importance: (r2_orig - amplitude_importance).max(0.0),
221 phase_importance: (r2_orig - phase_importance).max(0.0),
222 })
223}
224
225fn decompose_joint_scores(
227 q_aug: &FdMatrix,
228 mean_q: &[f64],
229 shooting: &FdMatrix,
230 vert_component: &FdMatrix,
231 horiz_component: &FdMatrix,
232 c: f64,
233 n: usize,
234 m_aug: usize,
235 m: usize,
236 ncomp: usize,
237) -> (FdMatrix, FdMatrix) {
238 let mut amp_scores = FdMatrix::zeros(n, ncomp);
239 let mut phase_scores = FdMatrix::zeros(n, ncomp);
240 for k in 0..ncomp {
241 for i in 0..n {
242 let mut amp_s = 0.0;
243 for j in 0..m_aug {
244 amp_s += (q_aug[(i, j)] - mean_q[j]) * vert_component[(k, j)];
245 }
246 amp_scores[(i, k)] = amp_s;
247
248 let mut phase_s = 0.0;
249 for j in 0..m {
250 phase_s += c * shooting[(i, j)] * horiz_component[(k, j)];
251 }
252 phase_scores[(i, k)] = phase_s;
253 }
254 }
255 (amp_scores, phase_scores)
256}
257
258fn compute_contributions(
260 amp_scores: &FdMatrix,
261 phase_scores: &FdMatrix,
262 coefficients: &[f64],
263 n: usize,
264 ncomp: usize,
265) -> (Vec<f64>, Vec<f64>) {
266 let mut amplitude_contribution = vec![0.0; n];
267 let mut phase_contribution = vec![0.0; n];
268 for i in 0..n {
269 for k in 0..ncomp {
270 amplitude_contribution[i] += coefficients[k] * amp_scores[(i, k)];
271 phase_contribution[i] += coefficients[k] * phase_scores[(i, k)];
272 }
273 }
274 (amplitude_contribution, phase_contribution)
275}
276
/// Coefficient of determination `1 - SS_res / SS_tot` of `fitted` against `y`.
///
/// Returns `0.0` whenever the total sum of squares is not strictly positive
/// (constant `y`, empty `y`, or NaN input). Observations are paired with
/// `zip`, so surplus entries in the longer slice are ignored for `SS_res`.
fn compute_r2(y: &[f64], fitted: &[f64]) -> f64 {
    let mean = y.iter().sum::<f64>() / y.len() as f64;
    let total_ss = y.iter().map(|&obs| (obs - mean).powi(2)).sum::<f64>();

    // No variance to explain (or an undefined one): report zero.
    if total_ss.is_nan() || total_ss <= 0.0 {
        return 0.0;
    }

    let residual_ss = y
        .iter()
        .zip(fitted)
        .map(|(&obs, &pred)| (obs - pred).powi(2))
        .sum::<f64>();
    1.0 - residual_ss / total_ss
}
293
294fn permutation_importance(
296 y: &[f64],
297 alpha: f64,
298 coefficients: &[f64],
299 amp_scores: &FdMatrix,
300 phase_scores: &FdMatrix,
301 ncomp: usize,
302 n_perm: usize,
303 seed: u64,
304 permute_amplitude: bool,
305) -> f64 {
306 let n = y.len();
307 if n_perm == 0 {
308 return compute_r2(y, &vec![alpha; n]);
309 }
310
311 let mut total_r2 = 0.0;
312 for p in 0..n_perm {
313 let mut rng = StdRng::seed_from_u64(seed.wrapping_add(p as u64));
314 let mut perm_idx: Vec<usize> = (0..n).collect();
315 perm_idx.shuffle(&mut rng);
316
317 let fitted = fitted_with_permuted_scores(
318 alpha,
319 coefficients,
320 amp_scores,
321 phase_scores,
322 &perm_idx,
323 n,
324 ncomp,
325 permute_amplitude,
326 );
327 total_r2 += compute_r2(y, &fitted);
328 }
329 total_r2 / n_perm as f64
330}
331
332fn fitted_with_permuted_scores(
334 alpha: f64,
335 coefficients: &[f64],
336 amp_scores: &FdMatrix,
337 phase_scores: &FdMatrix,
338 perm_idx: &[usize],
339 n: usize,
340 ncomp: usize,
341 permute_amplitude: bool,
342) -> Vec<f64> {
343 let mut fitted = vec![0.0; n];
344 for i in 0..n {
345 fitted[i] = alpha;
346 for k in 0..ncomp {
347 let amp_i = if permute_amplitude {
348 amp_scores[(perm_idx[i], k)]
349 } else {
350 amp_scores[(i, k)]
351 };
352 let phase_i = if !permute_amplitude {
353 phase_scores[(perm_idx[i], k)]
354 } else {
355 phase_scores[(i, k)]
356 };
357 fitted[i] += coefficients[k] * (amp_i + phase_i);
358 }
359 }
360 fitted
361}
362
363fn permutation_importance_single(
365 y: &[f64],
366 fitted_values: &[f64],
367 alpha: f64,
368 _coefficients: &[f64],
369 _ncomp: usize,
370 n_perm: usize,
371 seed: u64,
372) -> f64 {
373 let n = y.len();
374 let r2_orig = compute_r2(y, fitted_values);
375 if n_perm == 0 {
376 return r2_orig;
377 }
378
379 let contribs: Vec<f64> = fitted_values.iter().map(|&f| f - alpha).collect();
381 let mut total_r2 = 0.0;
382 for p in 0..n_perm {
383 let mut rng = StdRng::seed_from_u64(seed.wrapping_add(p as u64));
384 let mut perm_idx: Vec<usize> = (0..n).collect();
385 perm_idx.shuffle(&mut rng);
386
387 let fitted_perm: Vec<f64> = (0..n).map(|i| alpha + contribs[perm_idx[i]]).collect();
388 total_r2 += compute_r2(y, &fitted_perm);
389 }
390 let avg_r2 = total_r2 / n_perm as f64;
391 (r2_orig - avg_r2).max(0.0)
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::elastic_regression::{elastic_pcr, PcaMethod};
    use std::f64::consts::PI;

    /// Builds `n` sine curves on a uniform `m`-point grid over `[0, 1]`.
    ///
    /// Curve `i` has amplitude `1 + 0.5 * i / n` and a horizontal shift of
    /// `0.1 * (i - n / 2)`; the response is the amplitude itself.
    /// Returns `(curves, response, grid)`.
    fn generate_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>, Vec<f64>) {
        let grid: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
        let mut curves = FdMatrix::zeros(n, m);
        let mut response = Vec::with_capacity(n);
        for i in 0..n {
            let amp = 1.0 + 0.5 * (i as f64 / n as f64);
            let shift = 0.1 * (i as f64 - n as f64 / 2.0);
            for (j, &tj) in grid.iter().enumerate() {
                curves[(i, j)] = amp * (2.0 * PI * (tj + shift)).sin();
            }
            response.push(amp);
        }
        (curves, response, grid)
    }

    /// Joint model: amplitude and phase contributions must add up to the
    /// centred fitted value for every observation.
    #[test]
    fn test_elastic_attribution_joint_decomposition() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 10, 42).unwrap();

        assert_eq!(attr.amplitude_contribution.len(), 15);
        assert_eq!(attr.phase_contribution.len(), 15);

        let parts = attr
            .amplitude_contribution
            .iter()
            .zip(&attr.phase_contribution);
        for (i, (amp, phase)) in parts.enumerate() {
            let sum = amp + phase;
            let expected = result.fitted_values[i] - result.alpha;
            assert!(
                (sum - expected).abs() < 1e-6,
                "amp + phase should ≈ fitted - alpha at i={}: {} vs {}",
                i,
                sum,
                expected
            );
        }
    }

    /// Vertical-only model: nothing may be attributed to phase.
    #[test]
    fn test_elastic_attribution_vertical_only() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Vertical, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 10, 42).unwrap();

        for (i, phase) in attr.phase_contribution[..15].iter().enumerate() {
            assert!(
                phase.abs() < 1e-12,
                "phase_contribution should be 0 for vertical-only at i={}",
                i
            );
        }
        assert!(
            attr.phase_importance.abs() < 1e-12,
            "phase_importance should be 0 for vertical-only"
        );
    }

    /// Joint model: both importances are floored at zero.
    #[test]
    fn test_elastic_attribution_importance_nonnegative() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 20, 42).unwrap();

        assert!(
            attr.amplitude_importance >= 0.0,
            "amplitude_importance should be >= 0, got {}",
            attr.amplitude_importance
        );
        assert!(
            attr.phase_importance >= 0.0,
            "phase_importance should be >= 0, got {}",
            attr.phase_importance
        );
    }
}