1use crate::alignment::srsf_transform;
9use crate::elastic_fpca::{
10 build_augmented_srsfs, center_matrix, shooting_vectors_from_psis, sphere_karcher_mean,
11 warps_to_normalized_psi,
12};
13use crate::elastic_regression::{ElasticPcrResult, PcaMethod};
14use crate::matrix::FdMatrix;
15use rand::prelude::*;
16
/// Amplitude/phase attribution of an elastic principal-component regression fit.
///
/// Produced by `elastic_pcr_attribution`. The per-observation vectors have one
/// entry per observation of the fit; the importances are permutation-based
/// drops in R².
#[derive(Debug, Clone, PartialEq)]
pub struct ElasticAttributionResult {
    /// Per-observation contribution of amplitude (vertical) variation to the
    /// fitted value, measured relative to the intercept.
    pub amplitude_contribution: Vec<f64>,
    /// Per-observation contribution of phase (horizontal) variation to the
    /// fitted value, measured relative to the intercept.
    pub phase_contribution: Vec<f64>,
    /// Permutation importance of amplitude variation (a drop in R²);
    /// `0.0` for horizontal-only fits.
    pub amplitude_importance: f64,
    /// Permutation importance of phase variation (a drop in R²);
    /// `0.0` for vertical-only fits.
    pub phase_importance: f64,
}
28
29pub fn elastic_pcr_attribution(
45 result: &ElasticPcrResult,
46 y: &[f64],
47 ncomp: usize,
48 n_perm: usize,
49 seed: u64,
50) -> Option<ElasticAttributionResult> {
51 let n = result.fitted_values.len();
52 if y.len() != n || ncomp == 0 || n < 2 {
53 return None;
54 }
55 let actual_ncomp = ncomp.min(result.coefficients.len());
56
57 match result.pca_method {
58 PcaMethod::Joint => attribution_joint(result, y, actual_ncomp, n_perm, seed),
59 PcaMethod::Vertical => {
60 let amp: Vec<f64> = result
62 .fitted_values
63 .iter()
64 .map(|&f| f - result.alpha)
65 .collect();
66 let phase = vec![0.0; n];
67 let amp_imp = permutation_importance_single(
68 y,
69 &result.fitted_values,
70 result.alpha,
71 &result.coefficients,
72 actual_ncomp,
73 n_perm,
74 seed,
75 );
76 Some(ElasticAttributionResult {
77 amplitude_contribution: amp,
78 phase_contribution: phase,
79 amplitude_importance: amp_imp,
80 phase_importance: 0.0,
81 })
82 }
83 PcaMethod::Horizontal => {
84 let phase: Vec<f64> = result
86 .fitted_values
87 .iter()
88 .map(|&f| f - result.alpha)
89 .collect();
90 let amp = vec![0.0; n];
91 let phase_imp = permutation_importance_single(
92 y,
93 &result.fitted_values,
94 result.alpha,
95 &result.coefficients,
96 actual_ncomp,
97 n_perm,
98 seed,
99 );
100 Some(ElasticAttributionResult {
101 amplitude_contribution: amp,
102 phase_contribution: phase,
103 amplitude_importance: 0.0,
104 phase_importance: phase_imp,
105 })
106 }
107 }
108}
109
110fn attribution_joint(
112 result: &ElasticPcrResult,
113 y: &[f64],
114 ncomp: usize,
115 n_perm: usize,
116 seed: u64,
117) -> Option<ElasticAttributionResult> {
118 let joint = result.joint_fpca.as_ref()?;
119 let km = &result.karcher;
120 let (n, m) = km.aligned_data.shape();
121 let m_aug = m + 1;
122
123 let qn = match &km.aligned_srsfs {
124 Some(srsfs) => srsfs.clone(),
125 None => {
126 let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
127 srsf_transform(&km.aligned_data, &argvals)
128 }
129 };
130
131 let q_aug = build_augmented_srsfs(&qn, &km.aligned_data, n, m);
132 let (_, mean_q) = center_matrix(&q_aug, n, m_aug);
133
134 let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
136 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
137 let psis = warps_to_normalized_psi(&km.gammas, &argvals);
138 let mu_psi = sphere_karcher_mean(&psis, &time, 50);
139 let shooting = shooting_vectors_from_psis(&psis, &mu_psi, &time);
140
141 let c = joint.balance_c;
142 let (amp_scores, phase_scores) = decompose_joint_scores(
143 &q_aug,
144 &mean_q,
145 &shooting,
146 &joint.vert_component,
147 &joint.horiz_component,
148 c,
149 n,
150 m_aug,
151 m,
152 ncomp,
153 );
154
155 let (amplitude_contribution, phase_contribution) =
156 compute_contributions(&_scores, &phase_scores, &result.coefficients, n, ncomp);
157
158 let r2_orig = compute_r2(y, &result.fitted_values);
160 let amplitude_importance = permutation_importance(
161 y,
162 result.alpha,
163 &result.coefficients,
164 &_scores,
165 &phase_scores,
166 ncomp,
167 n_perm,
168 seed,
169 true,
170 );
171 let phase_importance = permutation_importance(
172 y,
173 result.alpha,
174 &result.coefficients,
175 &_scores,
176 &phase_scores,
177 ncomp,
178 n_perm,
179 seed + 1_000_000,
180 false,
181 );
182
183 Some(ElasticAttributionResult {
184 amplitude_contribution,
185 phase_contribution,
186 amplitude_importance: (r2_orig - amplitude_importance).max(0.0),
187 phase_importance: (r2_orig - phase_importance).max(0.0),
188 })
189}
190
191fn decompose_joint_scores(
193 q_aug: &FdMatrix,
194 mean_q: &[f64],
195 shooting: &FdMatrix,
196 vert_component: &FdMatrix,
197 horiz_component: &FdMatrix,
198 c: f64,
199 n: usize,
200 m_aug: usize,
201 m: usize,
202 ncomp: usize,
203) -> (FdMatrix, FdMatrix) {
204 let mut amp_scores = FdMatrix::zeros(n, ncomp);
205 let mut phase_scores = FdMatrix::zeros(n, ncomp);
206 for k in 0..ncomp {
207 for i in 0..n {
208 let mut amp_s = 0.0;
209 for j in 0..m_aug {
210 amp_s += (q_aug[(i, j)] - mean_q[j]) * vert_component[(k, j)];
211 }
212 amp_scores[(i, k)] = amp_s;
213
214 let mut phase_s = 0.0;
215 for j in 0..m {
216 phase_s += c * shooting[(i, j)] * horiz_component[(k, j)];
217 }
218 phase_scores[(i, k)] = phase_s;
219 }
220 }
221 (amp_scores, phase_scores)
222}
223
224fn compute_contributions(
226 amp_scores: &FdMatrix,
227 phase_scores: &FdMatrix,
228 coefficients: &[f64],
229 n: usize,
230 ncomp: usize,
231) -> (Vec<f64>, Vec<f64>) {
232 let mut amplitude_contribution = vec![0.0; n];
233 let mut phase_contribution = vec![0.0; n];
234 for i in 0..n {
235 for k in 0..ncomp {
236 amplitude_contribution[i] += coefficients[k] * amp_scores[(i, k)];
237 phase_contribution[i] += coefficients[k] * phase_scores[(i, k)];
238 }
239 }
240 (amplitude_contribution, phase_contribution)
241}
242
/// Coefficient of determination `1 - SS_res / SS_tot` of `fitted` against `y`.
///
/// `SS_tot` is taken over all of `y`. `SS_res` pairs `y` with `fitted`
/// element-wise and stops at the shorter slice.
///
/// Returns `0.0` whenever `SS_tot` is not strictly positive: constant or empty
/// `y`, or a NaN total.
fn compute_r2(y: &[f64], fitted: &[f64]) -> f64 {
    let count = y.len() as f64;
    let mean = y.iter().sum::<f64>() / count;
    let total_ss = y.iter().map(|&obs| (obs - mean).powi(2)).sum::<f64>();

    if total_ss > 0.0 {
        let residual_ss = y
            .iter()
            .zip(fitted.iter())
            .map(|(&obs, &pred)| (obs - pred).powi(2))
            .sum::<f64>();
        1.0 - residual_ss / total_ss
    } else {
        0.0
    }
}
259
260fn permutation_importance(
262 y: &[f64],
263 alpha: f64,
264 coefficients: &[f64],
265 amp_scores: &FdMatrix,
266 phase_scores: &FdMatrix,
267 ncomp: usize,
268 n_perm: usize,
269 seed: u64,
270 permute_amplitude: bool,
271) -> f64 {
272 let n = y.len();
273 if n_perm == 0 {
274 return compute_r2(y, &vec![alpha; n]);
275 }
276
277 let mut total_r2 = 0.0;
278 for p in 0..n_perm {
279 let mut rng = StdRng::seed_from_u64(seed.wrapping_add(p as u64));
280 let mut perm_idx: Vec<usize> = (0..n).collect();
281 perm_idx.shuffle(&mut rng);
282
283 let fitted = fitted_with_permuted_scores(
284 alpha,
285 coefficients,
286 amp_scores,
287 phase_scores,
288 &perm_idx,
289 n,
290 ncomp,
291 permute_amplitude,
292 );
293 total_r2 += compute_r2(y, &fitted);
294 }
295 total_r2 / n_perm as f64
296}
297
298fn fitted_with_permuted_scores(
300 alpha: f64,
301 coefficients: &[f64],
302 amp_scores: &FdMatrix,
303 phase_scores: &FdMatrix,
304 perm_idx: &[usize],
305 n: usize,
306 ncomp: usize,
307 permute_amplitude: bool,
308) -> Vec<f64> {
309 let mut fitted = vec![0.0; n];
310 for i in 0..n {
311 fitted[i] = alpha;
312 for k in 0..ncomp {
313 let amp_i = if permute_amplitude {
314 amp_scores[(perm_idx[i], k)]
315 } else {
316 amp_scores[(i, k)]
317 };
318 let phase_i = if !permute_amplitude {
319 phase_scores[(perm_idx[i], k)]
320 } else {
321 phase_scores[(i, k)]
322 };
323 fitted[i] += coefficients[k] * (amp_i + phase_i);
324 }
325 }
326 fitted
327}
328
329fn permutation_importance_single(
331 y: &[f64],
332 fitted_values: &[f64],
333 alpha: f64,
334 _coefficients: &[f64],
335 _ncomp: usize,
336 n_perm: usize,
337 seed: u64,
338) -> f64 {
339 let n = y.len();
340 let r2_orig = compute_r2(y, fitted_values);
341 if n_perm == 0 {
342 return r2_orig;
343 }
344
345 let contribs: Vec<f64> = fitted_values.iter().map(|&f| f - alpha).collect();
347 let mut total_r2 = 0.0;
348 for p in 0..n_perm {
349 let mut rng = StdRng::seed_from_u64(seed.wrapping_add(p as u64));
350 let mut perm_idx: Vec<usize> = (0..n).collect();
351 perm_idx.shuffle(&mut rng);
352
353 let fitted_perm: Vec<f64> = (0..n).map(|i| alpha + contribs[perm_idx[i]]).collect();
354 total_r2 += compute_r2(y, &fitted_perm);
355 }
356 let avg_r2 = total_r2 / n_perm as f64;
357 (r2_orig - avg_r2).max(0.0)
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use crate::elastic_regression::{elastic_pcr, PcaMethod};
    use std::f64::consts::PI;

    /// Phase-shifted sine curves whose amplitude grows with the row index;
    /// the response `y` is that amplitude. Returns `(data, y, grid)`.
    fn generate_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>, Vec<f64>) {
        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
        let mut data = FdMatrix::zeros(n, m);
        let mut y = vec![0.0; n];
        for i in 0..n {
            let amp = 1.0 + 0.5 * (i as f64 / n as f64);
            let shift = 0.1 * (i as f64 - n as f64 / 2.0);
            for j in 0..m {
                data[(i, j)] = amp * (2.0 * PI * (t[j] + shift)).sin();
            }
            y[i] = amp;
        }
        (data, y, t)
    }

    #[test]
    fn test_elastic_attribution_joint_decomposition() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 10, 42).unwrap();

        assert_eq!(attr.amplitude_contribution.len(), 15);
        assert_eq!(attr.phase_contribution.len(), 15);

        for i in 0..15 {
            // The two parts should reassemble the model's explained variation.
            let sum = attr.amplitude_contribution[i] + attr.phase_contribution[i];
            let expected = result.fitted_values[i] - result.alpha;
            assert!(
                (sum - expected).abs() < 1e-6,
                "amp + phase should ≈ fitted - alpha at i={}: {} vs {}",
                i,
                sum,
                expected
            );
        }
    }

    #[test]
    fn test_elastic_attribution_vertical_only() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Vertical, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 10, 42).unwrap();

        for i in 0..15 {
            assert!(
                attr.phase_contribution[i].abs() < 1e-12,
                "phase_contribution should be 0 for vertical-only at i={}",
                i
            );
        }
        assert!(
            attr.phase_importance.abs() < 1e-12,
            "phase_importance should be 0 for vertical-only"
        );
    }

    #[test]
    fn test_elastic_attribution_vertical_amplitude_matches_fitted() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Vertical, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 10, 42).unwrap();

        for i in 0..15 {
            assert_eq!(
                attr.amplitude_contribution[i],
                result.fitted_values[i] - result.alpha,
                "vertical-only amplitude contribution should equal fitted - alpha at i={}",
                i
            );
        }
        assert!(
            attr.amplitude_importance >= 0.0,
            "amplitude_importance should be >= 0, got {}",
            attr.amplitude_importance
        );
    }

    #[test]
    fn test_elastic_attribution_importance_nonnegative() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();
        let attr = elastic_pcr_attribution(&result, &y, 3, 20, 42).unwrap();

        assert!(
            attr.amplitude_importance >= 0.0,
            "amplitude_importance should be >= 0, got {}",
            attr.amplitude_importance
        );
        assert!(
            attr.phase_importance >= 0.0,
            "phase_importance should be >= 0, got {}",
            attr.phase_importance
        );
    }

    #[test]
    fn test_elastic_attribution_rejects_invalid_input() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();

        assert!(
            elastic_pcr_attribution(&result, &y[..14], 3, 10, 42).is_none(),
            "a response shorter than the fit must be rejected"
        );
        assert!(
            elastic_pcr_attribution(&result, &y, 0, 10, 42).is_none(),
            "ncomp == 0 must be rejected"
        );
    }

    #[test]
    fn test_elastic_attribution_deterministic_for_seed() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Joint, 0.0, 5, 1e-3).unwrap();
        let a = elastic_pcr_attribution(&result, &y, 3, 10, 7).unwrap();
        let b = elastic_pcr_attribution(&result, &y, 3, 10, 7).unwrap();

        // Tolerance rather than exact equality in case upstream reductions
        // reorder floating-point sums between runs.
        let close = |u: f64, v: f64| (u - v).abs() < 1e-10;
        for i in 0..15 {
            assert!(close(a.amplitude_contribution[i], b.amplitude_contribution[i]));
            assert!(close(a.phase_contribution[i], b.phase_contribution[i]));
        }
        assert!(close(a.amplitude_importance, b.amplitude_importance));
        assert!(close(a.phase_importance, b.phase_importance));
    }
}