Skip to main content

fdars_core/explain/
importance.rs

1//! Permutation importance, pointwise importance, and conditional permutation importance.
2
3use super::helpers::{
4    clone_scores_matrix, compute_conditioning_bins, compute_score_variance,
5    logistic_accuracy_from_scores, permute_component, project_scores, shuffle_global,
6};
7use crate::error::FdarError;
8use crate::matrix::FdMatrix;
9use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
10use rand::prelude::*;
11
12// ===========================================================================
13// FPC Permutation Importance
14// ===========================================================================
15
/// Result of FPC permutation importance.
///
/// Produced by [`fpc_permutation_importance`] (metric = R^2) and
/// [`fpc_permutation_importance_logistic`] (metric = classification accuracy).
/// For every component `k`, `importance[k] == baseline_metric - permuted_metric[k]`,
/// so a positive value means the metric dropped when that component's scores
/// were shuffled across observations.
#[derive(Debug, Clone, PartialEq)]
pub struct FpcPermutationImportance {
    /// R^2 (or accuracy) drop per component (length ncomp):
    /// `baseline_metric - permuted_metric[k]`.
    pub importance: Vec<f64>,
    /// Baseline metric (R^2 or accuracy) before any component is permuted.
    pub baseline_metric: f64,
    /// Mean metric over the `n_perm` random permutations of each component
    /// (length ncomp).
    pub permuted_metric: Vec<f64>,
}
26
27/// Permutation importance for a linear functional regression (metric = R^2).
28///
29/// # Errors
30///
31/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, its column
32/// count does not match `fit.fpca.mean`, or `y.len()` does not match the row
33/// count.
34/// Returns [`FdarError::InvalidParameter`] if `n_perm` is zero.
35/// Returns [`FdarError::ComputationFailed`] if the total sum of squares is zero.
36///
37/// # Examples
38///
39/// ```
40/// use fdars_core::matrix::FdMatrix;
41/// use fdars_core::scalar_on_function::fregre_lm;
42/// use fdars_core::explain::fpc_permutation_importance;
43///
44/// let (n, m) = (20, 30);
45/// let data = FdMatrix::from_column_major(
46///     (0..n * m).map(|k| {
47///         let i = (k % n) as f64;
48///         let j = (k / n) as f64;
49///         ((i + 1.0) * j * 0.2).sin()
50///     }).collect(),
51///     n, m,
52/// ).unwrap();
53/// let y: Vec<f64> = (0..n).map(|i| (i as f64 * 0.5).sin()).collect();
54/// let fit = fregre_lm(&data, &y, None, 3).unwrap();
55/// let imp = fpc_permutation_importance(&fit, &data, &y, 10, 42).unwrap();
56/// assert_eq!(imp.importance.len(), 3);
57/// ```
58#[must_use = "expensive computation whose result should not be discarded"]
59pub fn fpc_permutation_importance(
60    fit: &FregreLmResult,
61    data: &FdMatrix,
62    y: &[f64],
63    n_perm: usize,
64    seed: u64,
65) -> Result<FpcPermutationImportance, FdarError> {
66    let (n, m) = data.shape();
67    if n == 0 {
68        return Err(FdarError::InvalidDimension {
69            parameter: "data",
70            expected: ">0 rows".into(),
71            actual: "0".into(),
72        });
73    }
74    if n != y.len() {
75        return Err(FdarError::InvalidDimension {
76            parameter: "y",
77            expected: format!("{n} (matching data rows)"),
78            actual: format!("{}", y.len()),
79        });
80    }
81    if m != fit.fpca.mean.len() {
82        return Err(FdarError::InvalidDimension {
83            parameter: "data",
84            expected: format!("{} columns", fit.fpca.mean.len()),
85            actual: format!("{m}"),
86        });
87    }
88    if n_perm == 0 {
89        return Err(FdarError::InvalidParameter {
90            parameter: "n_perm",
91            message: "must be > 0".into(),
92        });
93    }
94    let ncomp = fit.ncomp;
95    let scores = project_scores(
96        data,
97        &fit.fpca.mean,
98        &fit.fpca.rotation,
99        ncomp,
100        &fit.fpca.weights,
101    );
102
103    // Baseline R^2 -- compute from same FPC-only prediction used in permuted path
104    // to ensure consistent comparison (gamma terms are constant across permutations)
105    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
106    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
107    if ss_tot == 0.0 {
108        return Err(FdarError::ComputationFailed {
109            operation: "fpc_permutation_importance",
110            detail: "total sum of squares is zero; all response values may be identical — check your data".into(),
111        });
112    }
113    let identity_idx: Vec<usize> = (0..n).collect();
114    let ss_res_base = permuted_ss_res_linear(
115        &scores,
116        &fit.coefficients,
117        y,
118        n,
119        ncomp,
120        ncomp,
121        &identity_idx,
122    );
123    let baseline = 1.0 - ss_res_base / ss_tot;
124
125    let mut rng = StdRng::seed_from_u64(seed);
126    let mut importance = vec![0.0; ncomp];
127    let mut permuted_metric = vec![0.0; ncomp];
128
129    for k in 0..ncomp {
130        let mut sum_r2 = 0.0;
131        for _ in 0..n_perm {
132            let mut idx: Vec<usize> = (0..n).collect();
133            idx.shuffle(&mut rng);
134            let ss_res_perm =
135                permuted_ss_res_linear(&scores, &fit.coefficients, y, n, ncomp, k, &idx);
136            sum_r2 += 1.0 - ss_res_perm / ss_tot;
137        }
138        let mean_perm = sum_r2 / n_perm as f64;
139        permuted_metric[k] = mean_perm;
140        importance[k] = baseline - mean_perm;
141    }
142
143    Ok(FpcPermutationImportance {
144        importance,
145        baseline_metric: baseline,
146        permuted_metric,
147    })
148}
149
150/// Permutation importance for functional logistic regression (metric = accuracy).
151///
152/// # Errors
153///
154/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, its column
155/// count does not match `fit.fpca.mean`, or `y.len()` does not match the row
156/// count.
157/// Returns [`FdarError::InvalidParameter`] if `n_perm` is zero.
158#[must_use = "expensive computation whose result should not be discarded"]
159pub fn fpc_permutation_importance_logistic(
160    fit: &FunctionalLogisticResult,
161    data: &FdMatrix,
162    y: &[f64],
163    n_perm: usize,
164    seed: u64,
165) -> Result<FpcPermutationImportance, FdarError> {
166    let (n, m) = data.shape();
167    if n == 0 {
168        return Err(FdarError::InvalidDimension {
169            parameter: "data",
170            expected: ">0 rows".into(),
171            actual: "0".into(),
172        });
173    }
174    if n != y.len() {
175        return Err(FdarError::InvalidDimension {
176            parameter: "y",
177            expected: format!("{n} (matching data rows)"),
178            actual: format!("{}", y.len()),
179        });
180    }
181    if m != fit.fpca.mean.len() {
182        return Err(FdarError::InvalidDimension {
183            parameter: "data",
184            expected: format!("{} columns", fit.fpca.mean.len()),
185            actual: format!("{m}"),
186        });
187    }
188    if n_perm == 0 {
189        return Err(FdarError::InvalidParameter {
190            parameter: "n_perm",
191            message: "must be > 0".into(),
192        });
193    }
194    let ncomp = fit.ncomp;
195    let scores = project_scores(
196        data,
197        &fit.fpca.mean,
198        &fit.fpca.rotation,
199        ncomp,
200        &fit.fpca.weights,
201    );
202
203    let baseline: f64 = (0..n)
204        .filter(|&i| {
205            let pred = if fit.probabilities[i] >= 0.5 {
206                1.0
207            } else {
208                0.0
209            };
210            (pred - y[i]).abs() < 1e-10
211        })
212        .count() as f64
213        / n as f64;
214
215    let mut rng = StdRng::seed_from_u64(seed);
216    let mut importance = vec![0.0; ncomp];
217    let mut permuted_metric = vec![0.0; ncomp];
218
219    for k in 0..ncomp {
220        let mut sum_acc = 0.0;
221        for _ in 0..n_perm {
222            let mut perm_scores = clone_scores_matrix(&scores, n, ncomp);
223            shuffle_global(&mut perm_scores, &scores, k, n, &mut rng);
224            sum_acc += logistic_accuracy_from_scores(
225                &perm_scores,
226                fit.intercept,
227                &fit.coefficients,
228                y,
229                n,
230                ncomp,
231            );
232        }
233        let mean_acc = sum_acc / n_perm as f64;
234        permuted_metric[k] = mean_acc;
235        importance[k] = baseline - mean_acc;
236    }
237
238    Ok(FpcPermutationImportance {
239        importance,
240        baseline_metric: baseline,
241        permuted_metric,
242    })
243}
244
245/// Compute SS_res with component k shuffled by given index permutation.
246fn permuted_ss_res_linear(
247    scores: &FdMatrix,
248    coefficients: &[f64],
249    y: &[f64],
250    n: usize,
251    ncomp: usize,
252    k: usize,
253    perm_idx: &[usize],
254) -> f64 {
255    (0..n)
256        .map(|i| {
257            let mut yhat = coefficients[0];
258            for c in 0..ncomp {
259                let s = if c == k {
260                    scores[(perm_idx[i], c)]
261                } else {
262                    scores[(i, c)]
263                };
264                yhat += coefficients[1 + c] * s;
265            }
266            (y[i] - yhat).powi(2)
267        })
268        .sum()
269}
270
271// ===========================================================================
272// Pointwise Variable Importance
273// ===========================================================================
274
/// Result of pointwise variable importance analysis.
///
/// Produced by [`pointwise_importance`] and [`pointwise_importance_logistic`].
/// Let `beta_k = coefficients[1 + k]`, `phi_k(t_j) = rotation[(j, k)]` and
/// `var_k = score_variance[k]`. Component `k` then contributes
/// `(beta_k * phi_k(t_j))^2 * var_k` at grid point `t_j`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct PointwiseImportanceResult {
    /// Importance at each grid point (length m): the sum of
    /// `component_importance` over all components at that point.
    pub importance: Vec<f64>,
    /// `importance` divided by its total so it sums to 1 (length m). All zeros
    /// when that total is not strictly positive.
    pub importance_normalized: Vec<f64>,
    /// Per-component importance (ncomp x m); row `k` holds component `k`'s
    /// contribution at every grid point.
    pub component_importance: FdMatrix,
    /// Variance of each FPC score (length ncomp), computed from the fit's
    /// stored scores.
    pub score_variance: Vec<f64>,
}
288
289/// Pointwise variable importance for a linear functional regression model.
290///
291/// Measures how much X(t_j) contributes to prediction variance via the FPC decomposition.
292///
293/// # Errors
294///
295/// Returns [`FdarError::InvalidParameter`] if `fit.ncomp` is zero.
296/// Returns [`FdarError::InvalidDimension`] if the rotation matrix has zero rows
297/// or the scores matrix has fewer than 2 rows.
298#[must_use = "expensive computation whose result should not be discarded"]
299pub fn pointwise_importance(fit: &FregreLmResult) -> Result<PointwiseImportanceResult, FdarError> {
300    let ncomp = fit.ncomp;
301    let m = fit.fpca.rotation.nrows();
302    let n = fit.fpca.scores.nrows();
303    if ncomp == 0 {
304        return Err(FdarError::InvalidParameter {
305            parameter: "ncomp",
306            message: "must be > 0".into(),
307        });
308    }
309    if m == 0 {
310        return Err(FdarError::InvalidDimension {
311            parameter: "rotation",
312            expected: ">0 rows".into(),
313            actual: "0".into(),
314        });
315    }
316    if n < 2 {
317        return Err(FdarError::InvalidDimension {
318            parameter: "scores",
319            expected: ">=2 rows".into(),
320            actual: format!("{n}"),
321        });
322    }
323
324    let score_variance = compute_score_variance(&fit.fpca.scores, n, ncomp);
325    let (component_importance, importance, importance_normalized) =
326        compute_pointwise_importance_core(
327            &fit.coefficients,
328            &fit.fpca.rotation,
329            &score_variance,
330            ncomp,
331            m,
332        );
333
334    Ok(PointwiseImportanceResult {
335        importance,
336        importance_normalized,
337        component_importance,
338        score_variance,
339    })
340}
341
342/// Pointwise variable importance for a functional logistic regression model.
343///
344/// # Errors
345///
346/// Returns [`FdarError::InvalidParameter`] if `fit.ncomp` is zero.
347/// Returns [`FdarError::InvalidDimension`] if the rotation matrix has zero rows
348/// or the scores matrix has fewer than 2 rows.
349#[must_use = "expensive computation whose result should not be discarded"]
350pub fn pointwise_importance_logistic(
351    fit: &FunctionalLogisticResult,
352) -> Result<PointwiseImportanceResult, FdarError> {
353    let ncomp = fit.ncomp;
354    let m = fit.fpca.rotation.nrows();
355    let n = fit.fpca.scores.nrows();
356    if ncomp == 0 {
357        return Err(FdarError::InvalidParameter {
358            parameter: "ncomp",
359            message: "must be > 0".into(),
360        });
361    }
362    if m == 0 {
363        return Err(FdarError::InvalidDimension {
364            parameter: "rotation",
365            expected: ">0 rows".into(),
366            actual: "0".into(),
367        });
368    }
369    if n < 2 {
370        return Err(FdarError::InvalidDimension {
371            parameter: "scores",
372            expected: ">=2 rows".into(),
373            actual: format!("{n}"),
374        });
375    }
376
377    let score_variance = compute_score_variance(&fit.fpca.scores, n, ncomp);
378    let (component_importance, importance, importance_normalized) =
379        compute_pointwise_importance_core(
380            &fit.coefficients,
381            &fit.fpca.rotation,
382            &score_variance,
383            ncomp,
384            m,
385        );
386
387    Ok(PointwiseImportanceResult {
388        importance,
389        importance_normalized,
390        component_importance,
391        score_variance,
392    })
393}
394
395/// Compute component importance matrix and aggregated importance.
396fn compute_pointwise_importance_core(
397    coefficients: &[f64],
398    rotation: &FdMatrix,
399    score_variance: &[f64],
400    ncomp: usize,
401    m: usize,
402) -> (FdMatrix, Vec<f64>, Vec<f64>) {
403    let mut component_importance = FdMatrix::zeros(ncomp, m);
404    for k in 0..ncomp {
405        let ck = coefficients[1 + k];
406        for j in 0..m {
407            component_importance[(k, j)] = (ck * rotation[(j, k)]).powi(2) * score_variance[k];
408        }
409    }
410
411    let mut importance = vec![0.0; m];
412    for j in 0..m {
413        for k in 0..ncomp {
414            importance[j] += component_importance[(k, j)];
415        }
416    }
417
418    let total: f64 = importance.iter().sum();
419    let importance_normalized = if total > 0.0 {
420        importance.iter().map(|&v| v / total).collect()
421    } else {
422        vec![0.0; m]
423    };
424
425    (component_importance, importance, importance_normalized)
426}
427
428// ===========================================================================
429// Conditional Permutation Importance
430// ===========================================================================
431
/// Result of conditional permutation importance.
///
/// Produced by [`conditional_permutation_importance`] (metric = R^2) and
/// [`conditional_permutation_importance_logistic`] (metric = accuracy).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ConditionalPermutationImportanceResult {
    /// Conditional importance per FPC component, length ncomp:
    /// `baseline_metric - permuted_metric[k]`.
    pub importance: Vec<f64>,
    /// Baseline metric (R^2 or accuracy) before any permutation.
    pub baseline_metric: f64,
    /// Mean metric after conditional permutation of each component (using the
    /// conditioning bins computed for that component), length ncomp.
    pub permuted_metric: Vec<f64>,
    /// Unconditional (standard) permutation importance for comparison, length
    /// ncomp: `baseline_metric` minus the mean metric under the unconditional
    /// permutation.
    pub unconditional_importance: Vec<f64>,
}
445
446/// Conditional permutation importance for a linear functional regression model.
447///
448/// # Errors
449///
450/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, its column
451/// count does not match `fit.fpca.mean`, or `y.len()` does not match the row
452/// count.
453/// Returns [`FdarError::InvalidParameter`] if `n_perm` or `n_bins` is zero.
454/// Returns [`FdarError::ComputationFailed`] if the total sum of squares is zero.
455#[must_use = "expensive computation whose result should not be discarded"]
456pub fn conditional_permutation_importance(
457    fit: &FregreLmResult,
458    data: &FdMatrix,
459    y: &[f64],
460    scalar_covariates: Option<&FdMatrix>,
461    n_bins: usize,
462    n_perm: usize,
463    seed: u64,
464) -> Result<ConditionalPermutationImportanceResult, FdarError> {
465    let (n, m) = data.shape();
466    if n == 0 {
467        return Err(FdarError::InvalidDimension {
468            parameter: "data",
469            expected: ">0 rows".into(),
470            actual: "0".into(),
471        });
472    }
473    if n != y.len() {
474        return Err(FdarError::InvalidDimension {
475            parameter: "y",
476            expected: format!("{n} (matching data rows)"),
477            actual: format!("{}", y.len()),
478        });
479    }
480    if m != fit.fpca.mean.len() {
481        return Err(FdarError::InvalidDimension {
482            parameter: "data",
483            expected: format!("{} columns", fit.fpca.mean.len()),
484            actual: format!("{m}"),
485        });
486    }
487    if n_perm == 0 {
488        return Err(FdarError::InvalidParameter {
489            parameter: "n_perm",
490            message: "must be > 0".into(),
491        });
492    }
493    if n_bins == 0 {
494        return Err(FdarError::InvalidParameter {
495            parameter: "n_bins",
496            message: "must be > 0".into(),
497        });
498    }
499    let _ = scalar_covariates;
500    let ncomp = fit.ncomp;
501    let scores = project_scores(
502        data,
503        &fit.fpca.mean,
504        &fit.fpca.rotation,
505        ncomp,
506        &fit.fpca.weights,
507    );
508
509    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
510    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
511    if ss_tot == 0.0 {
512        return Err(FdarError::ComputationFailed {
513            operation: "conditional_permutation_importance",
514            detail: "total sum of squares is zero; all response values may be identical — check your data".into(),
515        });
516    }
517    let ss_res_base: f64 = fit.residuals.iter().map(|r| r * r).sum();
518    let baseline = 1.0 - ss_res_base / ss_tot;
519
520    let predict_r2 = |score_mat: &FdMatrix| -> f64 {
521        let ss_res: f64 = (0..n)
522            .map(|i| {
523                let mut yhat = fit.coefficients[0];
524                for c in 0..ncomp {
525                    yhat += fit.coefficients[1 + c] * score_mat[(i, c)];
526                }
527                (y[i] - yhat).powi(2)
528            })
529            .sum();
530        1.0 - ss_res / ss_tot
531    };
532
533    let mut rng = StdRng::seed_from_u64(seed);
534    let mut importance = vec![0.0; ncomp];
535    let mut permuted_metric = vec![0.0; ncomp];
536    let mut unconditional_importance = vec![0.0; ncomp];
537
538    for k in 0..ncomp {
539        let bins = compute_conditioning_bins(&scores, ncomp, k, n, n_bins);
540        let (mean_cond, mean_uncond) =
541            permute_component(&scores, &bins, k, n, ncomp, n_perm, &mut rng, &predict_r2);
542        permuted_metric[k] = mean_cond;
543        importance[k] = baseline - mean_cond;
544        unconditional_importance[k] = baseline - mean_uncond;
545    }
546
547    Ok(ConditionalPermutationImportanceResult {
548        importance,
549        baseline_metric: baseline,
550        permuted_metric,
551        unconditional_importance,
552    })
553}
554
555/// Conditional permutation importance for a functional logistic regression model.
556///
557/// # Errors
558///
559/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, its column
560/// count does not match `fit.fpca.mean`, or `y.len()` does not match the row
561/// count.
562/// Returns [`FdarError::InvalidParameter`] if `n_perm` or `n_bins` is zero.
563#[must_use = "expensive computation whose result should not be discarded"]
564pub fn conditional_permutation_importance_logistic(
565    fit: &FunctionalLogisticResult,
566    data: &FdMatrix,
567    y: &[f64],
568    scalar_covariates: Option<&FdMatrix>,
569    n_bins: usize,
570    n_perm: usize,
571    seed: u64,
572) -> Result<ConditionalPermutationImportanceResult, FdarError> {
573    let (n, m) = data.shape();
574    if n == 0 {
575        return Err(FdarError::InvalidDimension {
576            parameter: "data",
577            expected: ">0 rows".into(),
578            actual: "0".into(),
579        });
580    }
581    if n != y.len() {
582        return Err(FdarError::InvalidDimension {
583            parameter: "y",
584            expected: format!("{n} (matching data rows)"),
585            actual: format!("{}", y.len()),
586        });
587    }
588    if m != fit.fpca.mean.len() {
589        return Err(FdarError::InvalidDimension {
590            parameter: "data",
591            expected: format!("{} columns", fit.fpca.mean.len()),
592            actual: format!("{m}"),
593        });
594    }
595    if n_perm == 0 {
596        return Err(FdarError::InvalidParameter {
597            parameter: "n_perm",
598            message: "must be > 0".into(),
599        });
600    }
601    if n_bins == 0 {
602        return Err(FdarError::InvalidParameter {
603            parameter: "n_bins",
604            message: "must be > 0".into(),
605        });
606    }
607    let _ = scalar_covariates;
608    let ncomp = fit.ncomp;
609    let scores = project_scores(
610        data,
611        &fit.fpca.mean,
612        &fit.fpca.rotation,
613        ncomp,
614        &fit.fpca.weights,
615    );
616
617    let baseline: f64 = (0..n)
618        .filter(|&i| {
619            let pred = if fit.probabilities[i] >= 0.5 {
620                1.0
621            } else {
622                0.0
623            };
624            (pred - y[i]).abs() < 1e-10
625        })
626        .count() as f64
627        / n as f64;
628
629    let predict_acc = |score_mat: &FdMatrix| -> f64 {
630        let correct: usize = (0..n)
631            .filter(|&i| {
632                let mut eta = fit.intercept;
633                for c in 0..ncomp {
634                    eta += fit.coefficients[1 + c] * score_mat[(i, c)];
635                }
636                let pred = if sigmoid(eta) >= 0.5 { 1.0 } else { 0.0 };
637                (pred - y[i]).abs() < 1e-10
638            })
639            .count();
640        correct as f64 / n as f64
641    };
642
643    let mut rng = StdRng::seed_from_u64(seed);
644    let mut importance = vec![0.0; ncomp];
645    let mut permuted_metric = vec![0.0; ncomp];
646    let mut unconditional_importance = vec![0.0; ncomp];
647
648    for k in 0..ncomp {
649        let bins = compute_conditioning_bins(&scores, ncomp, k, n, n_bins);
650        let (mean_cond, mean_uncond) =
651            permute_component(&scores, &bins, k, n, ncomp, n_perm, &mut rng, &predict_acc);
652        permuted_metric[k] = mean_cond;
653        importance[k] = baseline - mean_cond;
654        unconditional_importance[k] = baseline - mean_uncond;
655    }
656
657    Ok(ConditionalPermutationImportanceResult {
658        importance,
659        baseline_metric: baseline,
660        permuted_metric,
661        unconditional_importance,
662    })
663}